changing model loading behaviour
parent 42e2185f62
commit cf655eb00a
@@ -22,6 +22,7 @@ from lang_main.constants import (
     STFR_SIMILARITY,
 )
 from lang_main.errors import LanguageModelNotFoundError
+from lang_main.loggers import logger_config as logger
 from lang_main.types import (
     LanguageModels,
     Model,
@@ -73,17 +74,19 @@ def load_spacy(
         )
     )
     pretrained_model = cast(SpacyModel, spacy_model_obj.load())
+    logger.info('[MODEL LOADING] Loaded model >>%s<< successfully', model_name)

     return pretrained_model


 def _preprocess_STFR_model_name(
-    model_name: STFRModelTypes,
+    model_name: STFRModelTypes | str,
     backend: STFRBackends,
+    force_download: bool = False,
 ) -> str:
     """utility function to parse specific model names to their
     local file paths per backend
-    necessary for models not present on the Huggingface Hub (like
+    necessary for models not present on the Hugging Face Hub (like
     own pretrained or optimised models)
     only if chosen model and backend in combination are defined a local
     file path is generated
@@ -94,6 +97,9 @@ def _preprocess_STFR_model_name(
         model name given by configuration
     backend: STFRBackends
         backend given by configuration
+    force_download: bool
+        try to download model even if it is configured as local,
+        by default: False

     Returns
     -------
@@ -102,7 +108,9 @@ def _preprocess_STFR_model_name(
     """
     combination = (model_name, backend)
     model_name_or_path: str
-    if combination in STFR_CUSTOM_MODELS and STFR_CUSTOM_MODELS[combination]:
+    if force_download:
+        model_name_or_path = model_name
+    elif combination in STFR_CUSTOM_MODELS and STFR_CUSTOM_MODELS[combination]:
         # !! defined that each model is placed in a folder with its model name
         # !! without any user names
         folder_name = model_name.split('/')[-1]
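To make the new precedence easier to follow, the sketch below restates the resolution order of this hunk in isolation. It is illustrative only: the contents of STFR_CUSTOM_MODELS and the local-path assembly that follows folder_name are not shown in the diff, so the example dictionary entry, LOCAL_MODEL_ROOT, and the helper's name are assumptions.

    # Illustrative sketch of the resolution order, not the repository's implementation.
    # Assumes STFR_CUSTOM_MODELS maps (model_name, backend) pairs to a truthy flag
    # and that local models live in a folder named after the last path segment.
    from pathlib import Path

    STFR_CUSTOM_MODELS = {('my-org/my-finetuned-model', 'onnx'): True}  # hypothetical entry
    LOCAL_MODEL_ROOT = Path('./models')  # hypothetical location of locally stored models


    def resolve_model_name(model_name: str, backend: str, force_download: bool = False) -> str:
        """Resolve a configured model name to a Hub id or a local path."""
        combination = (model_name, backend)
        if force_download:
            # new in this commit: skip local resolution entirely and let the
            # downstream loader fetch the model from the Hub
            return model_name
        if STFR_CUSTOM_MODELS.get(combination):
            # each custom model sits in a folder named after the model, without user names
            folder_name = model_name.split('/')[-1]
            return str(LOCAL_MODEL_ROOT / folder_name)
        return model_name

Because force_download is checked first, it wins over a locally configured model and falls through to the plain model name.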
@@ -119,7 +127,7 @@ def _preprocess_STFR_model_name(


 def load_sentence_transformer(
-    model_name: STFRModelTypes,
+    model_name: STFRModelTypes | str,
     similarity_func: SimilarityFunction = SimilarityFunction.COSINE,
     backend: STFRBackends = STFRBackends.TORCH,
     device: STFRDeviceTypes = STFRDeviceTypes.CPU,
@@ -127,10 +135,12 @@ def load_sentence_transformer(
     trust_remote_code: bool = False,
     model_save_folder: str | None = None,
     model_kwargs: STFRModelArgs | dict[str, Any] | None = None,
+    force_download: bool = False,
 ) -> SentenceTransformer:
-    model_name_or_path = _preprocess_STFR_model_name(model_name=model_name, backend=backend)
-
-    return SentenceTransformer(
+    model_name_or_path = _preprocess_STFR_model_name(
+        model_name=model_name, backend=backend, force_download=force_download
+    )
+    model = SentenceTransformer(
         model_name_or_path=model_name_or_path,
         similarity_fn_name=similarity_func,
         backend=backend, # type: ignore Literal matches Enum
@@ -140,6 +150,9 @@ def load_sentence_transformer(
         trust_remote_code=trust_remote_code,
         model_kwargs=model_kwargs, # type: ignore
     )
+    logger.info('[MODEL LOADING] Loaded model >>%s<< successfully', model_name)
+
+    return model


 # ** configured model builder functions
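A usage sketch of the reworked loader follows; the Hub model id and keyword values are illustrative, not the repository's configured defaults.

    # Example call site (illustrative values); assumes the loader module above is importable.
    from sentence_transformers import SimilarityFunction

    model = load_sentence_transformer(
        model_name='sentence-transformers/all-MiniLM-L6-v2',  # any Hub id or configured model
        similarity_func=SimilarityFunction.COSINE,
        force_download=True,  # new flag: bypass local-path resolution, fetch from the Hub
    )
    embeddings = model.encode(['model loading works'])

Binding the instance to model instead of returning the constructor call directly is what allows the success log line to be emitted after construction.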
@@ -147,7 +160,7 @@ MODEL_LOADER_MAP: Final[ModelLoaderMap] = {
     LanguageModels.SENTENCE_TRANSFORMER: {
         'func': load_sentence_transformer,
         'kwargs': {
-            'model_name_or_path': STFR_MODEL_NAME,
+            'model_name': STFR_MODEL_NAME,
             'similarity_func': STFR_SIMILARITY,
             'backend': STFR_BACKEND,
             'device': STFR_DEVICE,
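Renaming the kwargs key to 'model_name' keeps it aligned with the loader's new signature, which suggests the map entries are unpacked directly into their 'func'. A hedged sketch of such a dispatch (the consuming code is not part of this commit):

    # Hypothetical dispatch over MODEL_LOADER_MAP; the real consumer is not shown here.
    entry = MODEL_LOADER_MAP[LanguageModels.SENTENCE_TRANSFORMER]
    model = entry['func'](**entry['kwargs'])  # kwargs keys must match the loader's
                                              # signature, hence the key rename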
@@ -65,8 +65,9 @@ class SpacyModelTypes(enum.StrEnum):
     DE_DEP_NEWS_TRF = 'de_dep_news_trf'


-class STFRQuantFilenames(enum.StrEnum):
+class STFRONNXFilenames(enum.StrEnum):
     ONNX_Q_UINT8 = 'onnx/model_quint8_avx2.onnx'
+    ONNX_OPT_O3 = 'onnx/model_O3.onnx'


 TorchDTypes: TypeAlias = Literal[
@@ -79,7 +80,7 @@ TorchDTypes: TypeAlias = Literal[
 class STFRModelArgs(TypedDict):
     torch_dtype: NotRequired[TorchDTypes]
     provider: NotRequired[ONNXExecutionProvider]
-    file_name: NotRequired[STFRQuantFilenames]
+    file_name: NotRequired[STFRONNXFilenames]
     export: NotRequired[bool]

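The second file in the diff (presumably lang_main.types, given the imports above) renames the ONNX filename enum and adds an entry for an O3-optimised graph. As a rough illustration of how such values reach the ONNX backend of sentence-transformers, here is a standalone sketch; the model id is an example and the exact wiring in this repository is not shown:

    # Illustrative only: passing an alternative ONNX graph to the sentence-transformers
    # ONNX backend. The file names mirror the enum values added above.
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer(
        'sentence-transformers/all-MiniLM-L6-v2',  # example Hub id
        backend='onnx',
        model_kwargs={
            'file_name': 'onnx/model_O3.onnx',      # optimised graph (ONNX_OPT_O3)
            'provider': 'CPUExecutionProvider',     # ONNX Runtime execution provider
        },
    )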