diff --git a/src/lang_main/analysis/graphs.py b/src/lang_main/analysis/graphs.py index 13425ba..64fb6d3 100644 --- a/src/lang_main/analysis/graphs.py +++ b/src/lang_main/analysis/graphs.py @@ -282,7 +282,7 @@ def filter_graph_by_node_degree( def filter_graph_by_number_edges( graph: TokenGraph, - limit: int, + limit: int | None, property: str = 'weight', descending: bool = True, ) -> TokenGraph: @@ -290,7 +290,10 @@ def filter_graph_by_number_edges( # edges original = set(graph.edges(data=property)) # type: ignore original_sorted = sorted(original, key=lambda tup: tup[2], reverse=descending) - chosen = set(original_sorted[:limit]) + if limit is not None: + chosen = set(original_sorted[:limit]) + else: + chosen = set(original_sorted) edges_to_drop = original.difference(chosen) graph.remove_edges_from(edges_to_drop) diff --git a/src/lang_main/constants.py b/src/lang_main/constants.py index 88263b1..9dab9ed 100644 --- a/src/lang_main/constants.py +++ b/src/lang_main/constants.py @@ -1,7 +1,7 @@ from enum import Enum # noqa: I001 from importlib.util import find_spec from pathlib import Path -from typing import Final +from typing import Final, cast import os from sentence_transformers import SimilarityFunction @@ -19,7 +19,7 @@ from lang_main.types import ( STFRDeviceTypes, STFRModelArgs, STFRModelTypes, - STFRQuantFilenames, # noqa: F401 + STFRONNXFilenames, # noqa: F401 SpacyModelTypes, ) @@ -85,7 +85,7 @@ os.environ['SENTENCE_TRANSFORMERS_HOME'] = str(MODEL_BASE_FOLDER) # LANG_MAIN_STFR_BACKEND : STFR backend, choice between "torch" and "onnx" SPACY_MODEL_NAME: Final[str | SpacyModelTypes] = os.environ.get( - 'LANG_MAIN_SPACY_MODEL', SpacyModelTypes.DE_CORE_NEWS_SM + 'LANG_MAIN_SPACY_MODEL', SpacyModelTypes.DE_DEP_NEWS_TRF ) STFR_MODEL_NAME: Final[str | STFRModelTypes] = os.environ.get( 'LANG_MAIN_STFR_MODEL', STFRModelTypes.E5_BASE_STS_EN_DE @@ -95,12 +95,12 @@ STFR_CUSTOM_MODELS: Final[dict[tuple[STFRModelTypes, STFRBackends], bool]] = { } STFR_DEVICE: Final[STFRDeviceTypes] = STFRDeviceTypes.CPU STFR_SIMILARITY: Final[SimilarityFunction] = SimilarityFunction.COSINE -STFR_BACKEND: Final[str | STFRBackends] = os.environ.get( - 'LANG_MAIN_STFR_BACKEND', STFRBackends.TORCH +STFR_BACKEND: Final[STFRBackends] = cast( + STFRBackends, os.environ.get('LANG_MAIN_STFR_BACKEND', STFRBackends.TORCH) ) stfr_model_args_default: STFRModelArgs = {'torch_dtype': 'float32'} stfr_model_args_onnx: STFRModelArgs = { - 'file_name': STFRQuantFilenames.ONNX_Q_UINT8, + 'file_name': STFRONNXFilenames.ONNX_Q_UINT8, 'provider': ONNXExecutionProvider.CPU, 'export': False, } @@ -131,7 +131,11 @@ THRESHOLD_SIMILARITY: Final[float] = CONFIG['preprocess']['threshold_similarity' # ** graph postprocessing EDGE_WEIGHT_DECIMALS: Final[int] = 4 -THRESHOLD_EDGE_NUMBER: Final[int] = CONFIG['graph_postprocessing']['threshold_edge_number'] +threshold_edge_number: int | None = None +cfg_threshold_edge_number: int = CONFIG['graph_postprocessing']['threshold_edge_number'] +if cfg_threshold_edge_number >= 0: + threshold_edge_number = cfg_threshold_edge_number +THRESHOLD_EDGE_NUMBER: Final[int | None] = threshold_edge_number PROPERTY_NAME_DEGREE_WEIGHTED: Final[str] = 'degree_weighted' PROPERTY_NAME_BETWEENNESS_CENTRALITY: Final[str] = 'betweenness_centrality' PROPERTY_NAME_IMPORTANCE: Final[str] = 'importance' diff --git a/src/lang_main/lang_main_config.toml b/src/lang_main/lang_main_config.toml index 7b187f0..ca63d52 100644 --- a/src/lang_main/lang_main_config.toml +++ b/src/lang_main/lang_main_config.toml @@ -29,7 +29,7 @@ date_cols = [ "ErstellungsDatum", ] threshold_amount_characters = 5 -threshold_similarity = 0.8 +threshold_similarity = 0.9 [graph_postprocessing] threshold_edge_number = 330