changing model loading behaviour

This commit is contained in:
Florian Förster 2025-01-16 16:02:22 +01:00
parent 42e2185f62
commit cf655eb00a
2 changed files with 24 additions and 10 deletions

View File

@ -22,6 +22,7 @@ from lang_main.constants import (
STFR_SIMILARITY,
)
from lang_main.errors import LanguageModelNotFoundError
from lang_main.loggers import logger_config as logger
from lang_main.types import (
LanguageModels,
Model,
@ -73,17 +74,19 @@ def load_spacy(
)
)
pretrained_model = cast(SpacyModel, spacy_model_obj.load())
logger.info('[MODEL LOADING] Loaded model >>%s<< successfully', model_name)
return pretrained_model
def _preprocess_STFR_model_name(
model_name: STFRModelTypes,
model_name: STFRModelTypes | str,
backend: STFRBackends,
force_download: bool = False,
) -> str:
"""utility function to parse specific model names to their
local file paths per backend
necessary for models not present on the Huggingface Hub (like
necessary for models not present on the Hugging Face Hub (like
own pretrained or optimised models)
a local file path is generated only if the chosen combination of
model and backend is defined
@ -94,6 +97,9 @@ def _preprocess_STFR_model_name(
model name given by configuration
backend: STFRBackends
backend given by configuration
force_download: bool
try to download model even if it is configured as local,
by default: False
Returns
-------
@ -102,7 +108,9 @@ def _preprocess_STFR_model_name(
"""
combination = (model_name, backend)
model_name_or_path: str
if combination in STFR_CUSTOM_MODELS and STFR_CUSTOM_MODELS[combination]:
if force_download:
model_name_or_path = model_name
elif combination in STFR_CUSTOM_MODELS and STFR_CUSTOM_MODELS[combination]:
# !! convention: each model is placed in a folder named after the model,
# !! without any user name prefix
folder_name = model_name.split('/')[-1]
@ -119,7 +127,7 @@ def _preprocess_STFR_model_name(
def load_sentence_transformer(
model_name: STFRModelTypes,
model_name: STFRModelTypes | str,
similarity_func: SimilarityFunction = SimilarityFunction.COSINE,
backend: STFRBackends = STFRBackends.TORCH,
device: STFRDeviceTypes = STFRDeviceTypes.CPU,
@ -127,10 +135,12 @@ def load_sentence_transformer(
trust_remote_code: bool = False,
model_save_folder: str | None = None,
model_kwargs: STFRModelArgs | dict[str, Any] | None = None,
force_download: bool = False,
) -> SentenceTransformer:
model_name_or_path = _preprocess_STFR_model_name(model_name=model_name, backend=backend)
return SentenceTransformer(
model_name_or_path = _preprocess_STFR_model_name(
model_name=model_name, backend=backend, force_download=force_download
)
model = SentenceTransformer(
model_name_or_path=model_name_or_path,
similarity_fn_name=similarity_func,
backend=backend, # type: ignore Literal matches Enum
@ -140,6 +150,9 @@ def load_sentence_transformer(
trust_remote_code=trust_remote_code,
model_kwargs=model_kwargs, # type: ignore
)
logger.info('[MODEL LOADING] Loaded model >>%s<< successfully', model_name)
return model
# ** configured model builder functions
@ -147,7 +160,7 @@ MODEL_LOADER_MAP: Final[ModelLoaderMap] = {
LanguageModels.SENTENCE_TRANSFORMER: {
'func': load_sentence_transformer,
'kwargs': {
'model_name_or_path': STFR_MODEL_NAME,
'model_name': STFR_MODEL_NAME,
'similarity_func': STFR_SIMILARITY,
'backend': STFR_BACKEND,
'device': STFR_DEVICE,

View File

@ -65,8 +65,9 @@ class SpacyModelTypes(enum.StrEnum):
DE_DEP_NEWS_TRF = 'de_dep_news_trf'
class STFRQuantFilenames(enum.StrEnum):
class STFRONNXFilenames(enum.StrEnum):
ONNX_Q_UINT8 = 'onnx/model_quint8_avx2.onnx'
ONNX_OPT_O3 = 'onnx/model_O3.onnx'
TorchDTypes: TypeAlias = Literal[
@ -79,7 +80,7 @@ TorchDTypes: TypeAlias = Literal[
class STFRModelArgs(TypedDict):
torch_dtype: NotRequired[TorchDTypes]
provider: NotRequired[ONNXExecutionProvider]
file_name: NotRequired[STFRQuantFilenames]
file_name: NotRequired[STFRONNXFilenames]
export: NotRequired[bool]