changing model loading behaviour

Florian Förster 2025-01-16 16:02:22 +01:00
parent 42e2185f62
commit cf655eb00a
2 changed files with 24 additions and 10 deletions

View File

@@ -22,6 +22,7 @@ from lang_main.constants import (
     STFR_SIMILARITY,
 )
 from lang_main.errors import LanguageModelNotFoundError
+from lang_main.loggers import logger_config as logger
 from lang_main.types import (
     LanguageModels,
     Model,
@@ -73,17 +74,19 @@ def load_spacy(
             )
         )
     pretrained_model = cast(SpacyModel, spacy_model_obj.load())
+    logger.info('[MODEL LOADING] Loaded model >>%s<< successfully', model_name)
     return pretrained_model


 def _preprocess_STFR_model_name(
-    model_name: STFRModelTypes,
+    model_name: STFRModelTypes | str,
     backend: STFRBackends,
+    force_download: bool = False,
 ) -> str:
     """utility function to parse specific model names to their
     local file paths per backend

-    necessary for models not present on the Huggingface Hub (like
+    necessary for models not present on the Hugging Face Hub (like
     own pretrained or optimised models)
     only if chosen model and backend in combination are defined a local
     file path is generated
@@ -94,6 +97,9 @@ def _preprocess_STFR_model_name(
         model name given by configuration
     backend: STFRBackends
         backend given by configuration
+    force_download: bool
+        try to download model even if it is configured as local,
+        by default: False

     Returns
     -------
@@ -102,7 +108,9 @@ def _preprocess_STFR_model_name(
     """
     combination = (model_name, backend)
     model_name_or_path: str
-    if combination in STFR_CUSTOM_MODELS and STFR_CUSTOM_MODELS[combination]:
+    if force_download:
+        model_name_or_path = model_name
+    elif combination in STFR_CUSTOM_MODELS and STFR_CUSTOM_MODELS[combination]:
         # !! defined that each model is placed in a folder with its model name
         # !! without any user names
         folder_name = model_name.split('/')[-1]
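
Note: a minimal sketch of the new fast path, assuming STFRBackends has an ONNX member and using a hypothetical model name (neither is taken from this commit). With force_download set, the configured name is passed through untouched, so Sentence Transformers resolves it against the Hugging Face Hub instead of a local folder:

    # Hypothetical call; 'org/my-model' and the ONNX backend member are
    # assumptions, not values from this repository.
    name = _preprocess_STFR_model_name(
        model_name='org/my-model',
        backend=STFRBackends.ONNX,
        force_download=True,
    )
    assert name == 'org/my-model'  # Hub name returned verbatim, no local path
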
@@ -119,7 +127,7 @@


 def load_sentence_transformer(
-    model_name: STFRModelTypes,
+    model_name: STFRModelTypes | str,
     similarity_func: SimilarityFunction = SimilarityFunction.COSINE,
     backend: STFRBackends = STFRBackends.TORCH,
     device: STFRDeviceTypes = STFRDeviceTypes.CPU,
@@ -127,10 +135,12 @@ def load_sentence_transformer(
     trust_remote_code: bool = False,
     model_save_folder: str | None = None,
     model_kwargs: STFRModelArgs | dict[str, Any] | None = None,
+    force_download: bool = False,
 ) -> SentenceTransformer:
-    model_name_or_path = _preprocess_STFR_model_name(model_name=model_name, backend=backend)
-
-    return SentenceTransformer(
+    model_name_or_path = _preprocess_STFR_model_name(
+        model_name=model_name, backend=backend, force_download=force_download
+    )
+    model = SentenceTransformer(
         model_name_or_path=model_name_or_path,
         similarity_fn_name=similarity_func,
         backend=backend,  # type: ignore Literal matches Enum
@@ -140,6 +150,9 @@ def load_sentence_transformer(
         trust_remote_code=trust_remote_code,
         model_kwargs=model_kwargs,  # type: ignore
     )
+
+    logger.info('[MODEL LOADING] Loaded model >>%s<< successfully', model_name)
+    return model


 # ** configured model builder functions
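
Note: the net effect on callers, sketched with a public Sentence Transformers checkpoint (the model name is illustrative, not taken from this commit). force_download now flows through to the name preprocessing, so a model that is normally served from a local folder can be re-fetched from the Hub on demand:

    model = load_sentence_transformer(
        model_name='sentence-transformers/all-MiniLM-L6-v2',  # example name
        force_download=True,
    )
    embeddings = model.encode(['a test sentence'])
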
@@ -147,7 +160,7 @@ MODEL_LOADER_MAP: Final[ModelLoaderMap] = {
     LanguageModels.SENTENCE_TRANSFORMER: {
         'func': load_sentence_transformer,
         'kwargs': {
-            'model_name_or_path': STFR_MODEL_NAME,
+            'model_name': STFR_MODEL_NAME,
             'similarity_func': STFR_SIMILARITY,
             'backend': STFR_BACKEND,
             'device': STFR_DEVICE,
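
Note: the kwargs key changes from 'model_name_or_path' to 'model_name' because the map feeds load_sentence_transformer, whose parameter is model_name; 'model_name_or_path' is the underlying SentenceTransformer constructor argument. A sketch of the dispatch this implies, assuming entries keep the {'func': ..., 'kwargs': ...} shape shown above:

    entry = MODEL_LOADER_MAP[LanguageModels.SENTENCE_TRANSFORMER]
    model = entry['func'](**entry['kwargs'])  # keys must match the loader's signature
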

View File

@@ -65,8 +65,9 @@ class SpacyModelTypes(enum.StrEnum):
     DE_DEP_NEWS_TRF = 'de_dep_news_trf'


-class STFRQuantFilenames(enum.StrEnum):
+class STFRONNXFilenames(enum.StrEnum):
     ONNX_Q_UINT8 = 'onnx/model_quint8_avx2.onnx'
+    ONNX_OPT_O3 = 'onnx/model_O3.onnx'


 TorchDTypes: TypeAlias = Literal[
@@ -79,7 +80,7 @@ TorchDTypes: TypeAlias = Literal[


 class STFRModelArgs(TypedDict):
     torch_dtype: NotRequired[TorchDTypes]
     provider: NotRequired[ONNXExecutionProvider]
-    file_name: NotRequired[STFRQuantFilenames]
+    file_name: NotRequired[STFRONNXFilenames]
     export: NotRequired[bool]
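
Note: the rename from STFRQuantFilenames to STFRONNXFilenames fits the new member, since the O3 graph is optimised rather than quantised. A sketch of how the enum is meant to be used through STFRModelArgs; the provider string is an assumption based on standard onnxruntime provider names, not a value from this commit:

    model_kwargs: STFRModelArgs = {
        'file_name': STFRONNXFilenames.ONNX_OPT_O3,  # new O3-optimised ONNX graph
        'provider': 'CPUExecutionProvider',          # assumed onnxruntime provider
    }
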