diff --git a/src/lang_main/model_loader.py b/src/lang_main/model_loader.py index 7e49250..a00c3d2 100644 --- a/src/lang_main/model_loader.py +++ b/src/lang_main/model_loader.py @@ -22,6 +22,7 @@ from lang_main.constants import ( STFR_SIMILARITY, ) from lang_main.errors import LanguageModelNotFoundError +from lang_main.loggers import logger_config as logger from lang_main.types import ( LanguageModels, Model, @@ -73,17 +74,19 @@ def load_spacy( ) ) pretrained_model = cast(SpacyModel, spacy_model_obj.load()) + logger.info('[MODEL LOADING] Loaded model >>%s<< successfully', model_name) return pretrained_model def _preprocess_STFR_model_name( - model_name: STFRModelTypes, + model_name: STFRModelTypes | str, backend: STFRBackends, + force_download: bool = False, ) -> str: """utility function to parse specific model names to their local file paths per backend - necessary for models not present on the Huggingface Hub (like + necessary for models not present on the Hugging Face Hub (like own pretrained or optimised models) only if chosen model and backend in combination are defined a local file path is generated @@ -94,6 +97,9 @@ def _preprocess_STFR_model_name( model name given by configuration backend: STFRBackends backend given by configuration + force_download: bool + try to download model even if it is configured as local, + by default: False Returns ------- @@ -102,7 +108,9 @@ def _preprocess_STFR_model_name( """ combination = (model_name, backend) model_name_or_path: str - if combination in STFR_CUSTOM_MODELS and STFR_CUSTOM_MODELS[combination]: + if force_download: + model_name_or_path = model_name + elif combination in STFR_CUSTOM_MODELS and STFR_CUSTOM_MODELS[combination]: # !! defined that each model is placed in a folder with its model name # !! without any user names folder_name = model_name.split('/')[-1] @@ -119,7 +127,7 @@ def _preprocess_STFR_model_name( def load_sentence_transformer( - model_name: STFRModelTypes, + model_name: STFRModelTypes | str, similarity_func: SimilarityFunction = SimilarityFunction.COSINE, backend: STFRBackends = STFRBackends.TORCH, device: STFRDeviceTypes = STFRDeviceTypes.CPU, @@ -127,10 +135,12 @@ def load_sentence_transformer( trust_remote_code: bool = False, model_save_folder: str | None = None, model_kwargs: STFRModelArgs | dict[str, Any] | None = None, + force_download: bool = False, ) -> SentenceTransformer: - model_name_or_path = _preprocess_STFR_model_name(model_name=model_name, backend=backend) - - return SentenceTransformer( + model_name_or_path = _preprocess_STFR_model_name( + model_name=model_name, backend=backend, force_download=force_download + ) + model = SentenceTransformer( model_name_or_path=model_name_or_path, similarity_fn_name=similarity_func, backend=backend, # type: ignore Literal matches Enum @@ -140,6 +150,9 @@ def load_sentence_transformer( trust_remote_code=trust_remote_code, model_kwargs=model_kwargs, # type: ignore ) + logger.info('[MODEL LOADING] Loaded model >>%s<< successfully', model_name) + + return model # ** configured model builder functions @@ -147,7 +160,7 @@ MODEL_LOADER_MAP: Final[ModelLoaderMap] = { LanguageModels.SENTENCE_TRANSFORMER: { 'func': load_sentence_transformer, 'kwargs': { - 'model_name_or_path': STFR_MODEL_NAME, + 'model_name': STFR_MODEL_NAME, 'similarity_func': STFR_SIMILARITY, 'backend': STFR_BACKEND, 'device': STFR_DEVICE, diff --git a/src/lang_main/types.py b/src/lang_main/types.py index b10a585..eeec036 100644 --- a/src/lang_main/types.py +++ b/src/lang_main/types.py @@ -65,8 +65,9 @@ class SpacyModelTypes(enum.StrEnum): DE_DEP_NEWS_TRF = 'de_dep_news_trf' -class STFRQuantFilenames(enum.StrEnum): +class STFRONNXFilenames(enum.StrEnum): ONNX_Q_UINT8 = 'onnx/model_quint8_avx2.onnx' + ONNX_OPT_O3 = 'onnx/model_O3.onnx' TorchDTypes: TypeAlias = Literal[ @@ -79,7 +80,7 @@ TorchDTypes: TypeAlias = Literal[ class STFRModelArgs(TypedDict): torch_dtype: NotRequired[TorchDTypes] provider: NotRequired[ONNXExecutionProvider] - file_name: NotRequired[STFRQuantFilenames] + file_name: NotRequired[STFRONNXFilenames] export: NotRequired[bool]