Allow disabling the edge number threshold (a negative config value means no limit)

This commit is contained in:
Florian Förster 2025-01-16 16:02:57 +01:00
parent cf655eb00a
commit 67fd541671
3 changed files with 17 additions and 10 deletions

View File

@ -282,7 +282,7 @@ def filter_graph_by_node_degree(
def filter_graph_by_number_edges(
graph: TokenGraph,
limit: int,
limit: int | None,
property: str = 'weight',
descending: bool = True,
) -> TokenGraph:
@ -290,7 +290,10 @@ def filter_graph_by_number_edges(
# edges
original = set(graph.edges(data=property)) # type: ignore
original_sorted = sorted(original, key=lambda tup: tup[2], reverse=descending)
chosen = set(original_sorted[:limit])
if limit is not None:
chosen = set(original_sorted[:limit])
else:
chosen = set(original_sorted)
edges_to_drop = original.difference(chosen)
graph.remove_edges_from(edges_to_drop)

View File

@ -1,7 +1,7 @@
from enum import Enum # noqa: I001
from importlib.util import find_spec
from pathlib import Path
from typing import Final
from typing import Final, cast
import os
from sentence_transformers import SimilarityFunction
@ -19,7 +19,7 @@ from lang_main.types import (
STFRDeviceTypes,
STFRModelArgs,
STFRModelTypes,
STFRQuantFilenames, # noqa: F401
STFRONNXFilenames, # noqa: F401
SpacyModelTypes,
)
@ -85,7 +85,7 @@ os.environ['SENTENCE_TRANSFORMERS_HOME'] = str(MODEL_BASE_FOLDER)
# LANG_MAIN_STFR_BACKEND : STFR backend, choice between "torch" and "onnx"
SPACY_MODEL_NAME: Final[str | SpacyModelTypes] = os.environ.get(
'LANG_MAIN_SPACY_MODEL', SpacyModelTypes.DE_CORE_NEWS_SM
'LANG_MAIN_SPACY_MODEL', SpacyModelTypes.DE_DEP_NEWS_TRF
)
STFR_MODEL_NAME: Final[str | STFRModelTypes] = os.environ.get(
'LANG_MAIN_STFR_MODEL', STFRModelTypes.E5_BASE_STS_EN_DE
@ -95,12 +95,12 @@ STFR_CUSTOM_MODELS: Final[dict[tuple[STFRModelTypes, STFRBackends], bool]] = {
}
STFR_DEVICE: Final[STFRDeviceTypes] = STFRDeviceTypes.CPU
STFR_SIMILARITY: Final[SimilarityFunction] = SimilarityFunction.COSINE
STFR_BACKEND: Final[str | STFRBackends] = os.environ.get(
'LANG_MAIN_STFR_BACKEND', STFRBackends.TORCH
STFR_BACKEND: Final[STFRBackends] = cast(
STFRBackends, os.environ.get('LANG_MAIN_STFR_BACKEND', STFRBackends.TORCH)
)
stfr_model_args_default: STFRModelArgs = {'torch_dtype': 'float32'}
stfr_model_args_onnx: STFRModelArgs = {
'file_name': STFRQuantFilenames.ONNX_Q_UINT8,
'file_name': STFRONNXFilenames.ONNX_Q_UINT8,
'provider': ONNXExecutionProvider.CPU,
'export': False,
}
@ -131,7 +131,11 @@ THRESHOLD_SIMILARITY: Final[float] = CONFIG['preprocess']['threshold_similarity'
# ** graph postprocessing
EDGE_WEIGHT_DECIMALS: Final[int] = 4
THRESHOLD_EDGE_NUMBER: Final[int] = CONFIG['graph_postprocessing']['threshold_edge_number']
threshold_edge_number: int | None = None
cfg_threshold_edge_number: int = CONFIG['graph_postprocessing']['threshold_edge_number']
if cfg_threshold_edge_number >= 0:
threshold_edge_number = cfg_threshold_edge_number
THRESHOLD_EDGE_NUMBER: Final[int | None] = threshold_edge_number
PROPERTY_NAME_DEGREE_WEIGHTED: Final[str] = 'degree_weighted'
PROPERTY_NAME_BETWEENNESS_CENTRALITY: Final[str] = 'betweenness_centrality'
PROPERTY_NAME_IMPORTANCE: Final[str] = 'importance'

View File

@ -29,7 +29,7 @@ date_cols = [
"ErstellungsDatum",
]
threshold_amount_characters = 5
threshold_similarity = 0.8
threshold_similarity = 0.9
[graph_postprocessing]
threshold_edge_number = 330