several changes for local models
This commit is contained in:
@@ -193,7 +193,9 @@ def analyse_feature(
|
||||
|
||||
result_df = pd.concat([result_df, conc_df], ignore_index=True)
|
||||
|
||||
result_df = result_df.sort_values(by='num_occur', ascending=False).copy()
|
||||
result_df = result_df.sort_values(
|
||||
by=['num_occur', 'len'], ascending=[False, False]
|
||||
).copy()
|
||||
|
||||
return (result_df,)
|
||||
|
||||
|
||||
@@ -88,14 +88,17 @@ SPACY_MODEL_NAME: Final[str | SpacyModelTypes] = os.environ.get(
|
||||
'LANG_MAIN_SPACY_MODEL', SpacyModelTypes.DE_CORE_NEWS_SM
|
||||
)
|
||||
STFR_MODEL_NAME: Final[str | STFRModelTypes] = os.environ.get(
|
||||
'LANG_MAIN_STFR_MODEL', STFRModelTypes.ALL_MPNET_BASE_V2
|
||||
'LANG_MAIN_STFR_MODEL', STFRModelTypes.E5_BASE_STS_EN_DE
|
||||
)
|
||||
STFR_CUSTOM_MODELS: Final[dict[tuple[STFRModelTypes, STFRBackends], bool]] = {
|
||||
(STFRModelTypes.E5_BASE_STS_EN_DE, STFRBackends.ONNX): True,
|
||||
}
|
||||
STFR_DEVICE: Final[STFRDeviceTypes] = STFRDeviceTypes.CPU
|
||||
STFR_SIMILARITY: Final[SimilarityFunction] = SimilarityFunction.COSINE
|
||||
STFR_BACKEND: Final[str | STFRBackends] = os.environ.get(
|
||||
'LANG_MAIN_STFR_BACKEND', STFRBackends.TORCH
|
||||
)
|
||||
stfr_model_args_default: STFRModelArgs = {}
|
||||
stfr_model_args_default: STFRModelArgs = {'torch_dtype': 'float32'}
|
||||
stfr_model_args_onnx: STFRModelArgs = {
|
||||
'file_name': STFRQuantFilenames.ONNX_Q_UINT8,
|
||||
'provider': ONNXExecutionProvider.CPU,
|
||||
|
||||
@@ -12,8 +12,10 @@ from typing import (
|
||||
from sentence_transformers import SentenceTransformer, SimilarityFunction
|
||||
|
||||
from lang_main.constants import (
|
||||
MODEL_BASE_FOLDER,
|
||||
SPACY_MODEL_NAME,
|
||||
STFR_BACKEND,
|
||||
STFR_CUSTOM_MODELS,
|
||||
STFR_DEVICE,
|
||||
STFR_MODEL_ARGS,
|
||||
STFR_MODEL_NAME,
|
||||
@@ -28,6 +30,7 @@ from lang_main.types import (
|
||||
STFRBackends,
|
||||
STFRDeviceTypes,
|
||||
STFRModelArgs,
|
||||
STFRModelTypes,
|
||||
)
|
||||
|
||||
|
||||
@@ -74,22 +77,67 @@ def load_spacy(
|
||||
return pretrained_model
|
||||
|
||||
|
||||
def _preprocess_STFR_model_name(
|
||||
model_name: STFRModelTypes,
|
||||
backend: STFRBackends,
|
||||
) -> str:
|
||||
"""utility function to parse specific model names to their
|
||||
local file paths per backend
|
||||
necessary for models not present on the Huggingface Hub (like
|
||||
own pretrained or optimised models)
|
||||
only if chosen model and backend in combination are defined a local
|
||||
file path is generated
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_name : STFRModelTypes
|
||||
model name given by configuration
|
||||
backend: STFRBackends
|
||||
backend given by configuration
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
model name or specific file path if applicable
|
||||
"""
|
||||
combination = (model_name, backend)
|
||||
model_name_or_path: str
|
||||
if combination in STFR_CUSTOM_MODELS and STFR_CUSTOM_MODELS[combination]:
|
||||
# !! defined that each model is placed in a folder with its model name
|
||||
# !! without any user names
|
||||
folder_name = model_name.split('/')[-1]
|
||||
model_path = MODEL_BASE_FOLDER / folder_name
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f'Target model >{model_name}< not found under {model_path}'
|
||||
)
|
||||
model_name_or_path = str(model_path)
|
||||
else:
|
||||
model_name_or_path = model_name
|
||||
|
||||
return model_name_or_path
|
||||
|
||||
|
||||
def load_sentence_transformer(
|
||||
model_name: str,
|
||||
model_name: STFRModelTypes,
|
||||
similarity_func: SimilarityFunction = SimilarityFunction.COSINE,
|
||||
backend: STFRBackends = STFRBackends.TORCH,
|
||||
device: STFRDeviceTypes = STFRDeviceTypes.CPU,
|
||||
local_files_only: bool = True,
|
||||
trust_remote_code: bool = False,
|
||||
model_save_folder: str | None = None,
|
||||
model_kwargs: STFRModelArgs | dict[str, Any] | None = None,
|
||||
) -> SentenceTransformer:
|
||||
model_name_or_path = _preprocess_STFR_model_name(model_name=model_name, backend=backend)
|
||||
|
||||
return SentenceTransformer(
|
||||
model_name_or_path=model_name,
|
||||
model_name_or_path=model_name_or_path,
|
||||
similarity_fn_name=similarity_func,
|
||||
backend=backend, # type: ignore Literal matches Enum
|
||||
device=device,
|
||||
cache_folder=model_save_folder,
|
||||
local_files_only=local_files_only,
|
||||
trust_remote_code=trust_remote_code,
|
||||
model_kwargs=model_kwargs, # type: ignore
|
||||
)
|
||||
|
||||
@@ -99,7 +147,7 @@ MODEL_LOADER_MAP: Final[ModelLoaderMap] = {
|
||||
LanguageModels.SENTENCE_TRANSFORMER: {
|
||||
'func': load_sentence_transformer,
|
||||
'kwargs': {
|
||||
'model_name': STFR_MODEL_NAME,
|
||||
'model_name_or_path': STFR_MODEL_NAME,
|
||||
'similarity_func': STFR_SIMILARITY,
|
||||
'backend': STFR_BACKEND,
|
||||
'device': STFR_DEVICE,
|
||||
|
||||
@@ -50,6 +50,12 @@ class STFRModelTypes(enum.StrEnum):
|
||||
ALL_DISTILROBERTA_V1 = 'all-distilroberta-v1'
|
||||
ALL_MINI_LM_L12_V2 = 'all-MiniLM-L12-v2'
|
||||
ALL_MINI_LM_L6_V2 = 'all-MiniLM-L6-v2'
|
||||
GERMAN_SEMANTIC_STS_V2 = 'aari1995/German_Semantic_STS_V2'
|
||||
PARAPHRASE_MULTI_MPNET_BASE_V2 = 'paraphrase-multilingual-mpnet-base-v2'
|
||||
JINAAI_BASE_DE_V2 = (
|
||||
'jinaai/jina-embeddings-v2-base-de' # only for testing, non-commercial
|
||||
)
|
||||
E5_BASE_STS_EN_DE = 'danielheinz/e5-base-sts-en-de'
|
||||
|
||||
|
||||
class SpacyModelTypes(enum.StrEnum):
|
||||
@@ -63,7 +69,15 @@ class STFRQuantFilenames(enum.StrEnum):
|
||||
ONNX_Q_UINT8 = 'onnx/model_quint8_avx2.onnx'
|
||||
|
||||
|
||||
TorchDTypes: TypeAlias = Literal[
|
||||
'float16',
|
||||
'bfloat16',
|
||||
'float32',
|
||||
]
|
||||
|
||||
|
||||
class STFRModelArgs(TypedDict):
|
||||
torch_dtype: NotRequired[TorchDTypes]
|
||||
provider: NotRequired[ONNXExecutionProvider]
|
||||
file_name: NotRequired[STFRQuantFilenames]
|
||||
export: NotRequired[bool]
|
||||
|
||||
Reference in New Issue
Block a user