diff --git a/notebooks/lang_main.xml b/notebooks/lang_main.xml
deleted file mode 100644
index 93adff8..0000000
--- a/notebooks/lang_main.xml
+++ /dev/null
@@ -1,128 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
diff --git a/notebooks/styles_template.xml b/notebooks/styles_template.xml
deleted file mode 100644
index a2090a9..0000000
--- a/notebooks/styles_template.xml
+++ /dev/null
@@ -1,123 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
diff --git a/notebooks/test.graphml b/notebooks/test.graphml
deleted file mode 100644
index 58011f8..0000000
--- a/notebooks/test.graphml
+++ /dev/null
@@ -1,37 +0,0 @@
-
-
-
-
-
-
- 14
-
-
- 10
-
-
- 6
-
-
- 12
-
-
- 1
-
-
- 2
-
-
- 5
-
-
- 3
-
-
- 6
-
-
- 4
-
-
-
diff --git a/notebooks/tk_graph_built.graphml b/notebooks/tk_graph_built.graphml
deleted file mode 100644
index 73538e4..0000000
--- a/notebooks/tk_graph_built.graphml
+++ /dev/null
@@ -1,73 +0,0 @@
-
-
-
-
-
-
- 2
-
-
- 1
-
-
- 4
-
-
- 1
-
-
- 1
-
-
- 1
-
-
- 2
-
-
- 1
-
-
- 1
-
-
- 1
-
-
- 1
-
-
- 1
-
-
- 1
-
-
- 1
-
-
- 1
-
-
- 1
-
-
- 1
-
-
- 1
-
-
- 1
-
-
- 1
-
-
- 1
-
-
- 1
-
-
-
diff --git a/notebooks/tk_graph_built.pkl b/notebooks/tk_graph_built.pkl
deleted file mode 100644
index 15b3a5d..0000000
Binary files a/notebooks/tk_graph_built.pkl and /dev/null differ
diff --git a/pyproject.toml b/pyproject.toml
index 6a7298e..12ad577 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "lang-main"
-version = "0.1.0a9"
+version = "0.1.0a14"
description = "Several tools to analyse TOM's data with strong focus on language processing"
authors = [
{name = "d-opt GmbH, resp. Florian Förster", email = "f.foerster@d-opt.com"},
@@ -154,7 +154,7 @@ directory = "reports/coverage"
[tool.bumpversion]
-current_version = "0.1.0a9"
+current_version = "0.1.0a14"
parse = """(?x)
 (?P<major>0|[1-9]\\d*)\\.
 (?P<minor>0|[1-9]\\d*)\\.
diff --git a/src/lang_main/analysis/preprocessing.py b/src/lang_main/analysis/preprocessing.py
index dcebabd..b5b986b 100644
--- a/src/lang_main/analysis/preprocessing.py
+++ b/src/lang_main/analysis/preprocessing.py
@@ -193,7 +193,9 @@ def analyse_feature(
result_df = pd.concat([result_df, conc_df], ignore_index=True)
- result_df = result_df.sort_values(by='num_occur', ascending=False).copy()
+ result_df = result_df.sort_values(
+ by=['num_occur', 'len'], ascending=[False, False]
+ ).copy()
return (result_df,)
diff --git a/src/lang_main/constants.py b/src/lang_main/constants.py
index 7106572..88263b1 100644
--- a/src/lang_main/constants.py
+++ b/src/lang_main/constants.py
@@ -88,14 +88,17 @@ SPACY_MODEL_NAME: Final[str | SpacyModelTypes] = os.environ.get(
'LANG_MAIN_SPACY_MODEL', SpacyModelTypes.DE_CORE_NEWS_SM
)
STFR_MODEL_NAME: Final[str | STFRModelTypes] = os.environ.get(
- 'LANG_MAIN_STFR_MODEL', STFRModelTypes.ALL_MPNET_BASE_V2
+ 'LANG_MAIN_STFR_MODEL', STFRModelTypes.E5_BASE_STS_EN_DE
)
+STFR_CUSTOM_MODELS: Final[dict[tuple[STFRModelTypes, STFRBackends], bool]] = {
+ (STFRModelTypes.E5_BASE_STS_EN_DE, STFRBackends.ONNX): True,
+}
STFR_DEVICE: Final[STFRDeviceTypes] = STFRDeviceTypes.CPU
STFR_SIMILARITY: Final[SimilarityFunction] = SimilarityFunction.COSINE
STFR_BACKEND: Final[str | STFRBackends] = os.environ.get(
'LANG_MAIN_STFR_BACKEND', STFRBackends.TORCH
)
-stfr_model_args_default: STFRModelArgs = {}
+stfr_model_args_default: STFRModelArgs = {'torch_dtype': 'float32'}
stfr_model_args_onnx: STFRModelArgs = {
'file_name': STFRQuantFilenames.ONNX_Q_UINT8,
'provider': ONNXExecutionProvider.CPU,
diff --git a/src/lang_main/model_loader.py b/src/lang_main/model_loader.py
index 99e9dc1..7e49250 100644
--- a/src/lang_main/model_loader.py
+++ b/src/lang_main/model_loader.py
@@ -12,8 +12,10 @@ from typing import (
from sentence_transformers import SentenceTransformer, SimilarityFunction
from lang_main.constants import (
+ MODEL_BASE_FOLDER,
SPACY_MODEL_NAME,
STFR_BACKEND,
+ STFR_CUSTOM_MODELS,
STFR_DEVICE,
STFR_MODEL_ARGS,
STFR_MODEL_NAME,
@@ -28,6 +30,7 @@ from lang_main.types import (
STFRBackends,
STFRDeviceTypes,
STFRModelArgs,
+ STFRModelTypes,
)
@@ -74,22 +77,67 @@ def load_spacy(
return pretrained_model
+def _preprocess_STFR_model_name(
+ model_name: STFRModelTypes,
+ backend: STFRBackends,
+) -> str:
+ """utility function to parse specific model names to their
+ local file paths per backend
+ necessary for models not present on the Huggingface Hub (like
+ own pretrained or optimised models)
+ only if chosen model and backend in combination are defined a local
+ file path is generated
+
+ Parameters
+ ----------
+ model_name : STFRModelTypes
+ model name given by configuration
+ backend: STFRBackends
+ backend given by configuration
+
+ Returns
+ -------
+ str
+ model name or specific file path if applicable
+ """
+ combination = (model_name, backend)
+ model_name_or_path: str
+ if combination in STFR_CUSTOM_MODELS and STFR_CUSTOM_MODELS[combination]:
+ # !! defined that each model is placed in a folder with its model name
+ # !! without any user names
+ folder_name = model_name.split('/')[-1]
+ model_path = MODEL_BASE_FOLDER / folder_name
+ if not model_path.exists():
+ raise FileNotFoundError(
+ f'Target model >{model_name}< not found under {model_path}'
+ )
+ model_name_or_path = str(model_path)
+ else:
+ model_name_or_path = model_name
+
+ return model_name_or_path
+
+
def load_sentence_transformer(
- model_name: str,
+ model_name: STFRModelTypes,
similarity_func: SimilarityFunction = SimilarityFunction.COSINE,
backend: STFRBackends = STFRBackends.TORCH,
device: STFRDeviceTypes = STFRDeviceTypes.CPU,
local_files_only: bool = True,
+ trust_remote_code: bool = False,
model_save_folder: str | None = None,
model_kwargs: STFRModelArgs | dict[str, Any] | None = None,
) -> SentenceTransformer:
+ model_name_or_path = _preprocess_STFR_model_name(model_name=model_name, backend=backend)
+
return SentenceTransformer(
- model_name_or_path=model_name,
+ model_name_or_path=model_name_or_path,
similarity_fn_name=similarity_func,
backend=backend, # type: ignore Literal matches Enum
device=device,
cache_folder=model_save_folder,
local_files_only=local_files_only,
+ trust_remote_code=trust_remote_code,
model_kwargs=model_kwargs, # type: ignore
)
@@ -99,7 +147,7 @@ MODEL_LOADER_MAP: Final[ModelLoaderMap] = {
LanguageModels.SENTENCE_TRANSFORMER: {
'func': load_sentence_transformer,
'kwargs': {
- 'model_name': STFR_MODEL_NAME,
+            'model_name': STFR_MODEL_NAME,
'similarity_func': STFR_SIMILARITY,
'backend': STFR_BACKEND,
'device': STFR_DEVICE,
diff --git a/src/lang_main/types.py b/src/lang_main/types.py
index a521c47..b10a585 100644
--- a/src/lang_main/types.py
+++ b/src/lang_main/types.py
@@ -50,6 +50,12 @@ class STFRModelTypes(enum.StrEnum):
ALL_DISTILROBERTA_V1 = 'all-distilroberta-v1'
ALL_MINI_LM_L12_V2 = 'all-MiniLM-L12-v2'
ALL_MINI_LM_L6_V2 = 'all-MiniLM-L6-v2'
+ GERMAN_SEMANTIC_STS_V2 = 'aari1995/German_Semantic_STS_V2'
+ PARAPHRASE_MULTI_MPNET_BASE_V2 = 'paraphrase-multilingual-mpnet-base-v2'
+ JINAAI_BASE_DE_V2 = (
+ 'jinaai/jina-embeddings-v2-base-de' # only for testing, non-commercial
+ )
+ E5_BASE_STS_EN_DE = 'danielheinz/e5-base-sts-en-de'
class SpacyModelTypes(enum.StrEnum):
@@ -63,7 +69,15 @@ class STFRQuantFilenames(enum.StrEnum):
ONNX_Q_UINT8 = 'onnx/model_quint8_avx2.onnx'
+TorchDTypes: TypeAlias = Literal[
+ 'float16',
+ 'bfloat16',
+ 'float32',
+]
+
+
class STFRModelArgs(TypedDict):
+ torch_dtype: NotRequired[TorchDTypes]
provider: NotRequired[ONNXExecutionProvider]
file_name: NotRequired[STFRQuantFilenames]
export: NotRequired[bool]
diff --git a/test.ps1 b/test.ps1
deleted file mode 100644
index 4592715..0000000
--- a/test.ps1
+++ /dev/null
@@ -1 +0,0 @@
-pdm run coverage run -p -m pytest -n 6
\ No newline at end of file