changing model loading behaviour
This commit is contained in:
parent
42e2185f62
commit
cf655eb00a
@ -22,6 +22,7 @@ from lang_main.constants import (
|
|||||||
STFR_SIMILARITY,
|
STFR_SIMILARITY,
|
||||||
)
|
)
|
||||||
from lang_main.errors import LanguageModelNotFoundError
|
from lang_main.errors import LanguageModelNotFoundError
|
||||||
|
from lang_main.loggers import logger_config as logger
|
||||||
from lang_main.types import (
|
from lang_main.types import (
|
||||||
LanguageModels,
|
LanguageModels,
|
||||||
Model,
|
Model,
|
||||||
@ -73,17 +74,19 @@ def load_spacy(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
pretrained_model = cast(SpacyModel, spacy_model_obj.load())
|
pretrained_model = cast(SpacyModel, spacy_model_obj.load())
|
||||||
|
logger.info('[MODEL LOADING] Loaded model >>%s<< successfully', model_name)
|
||||||
|
|
||||||
return pretrained_model
|
return pretrained_model
|
||||||
|
|
||||||
|
|
||||||
def _preprocess_STFR_model_name(
|
def _preprocess_STFR_model_name(
|
||||||
model_name: STFRModelTypes,
|
model_name: STFRModelTypes | str,
|
||||||
backend: STFRBackends,
|
backend: STFRBackends,
|
||||||
|
force_download: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""utility function to parse specific model names to their
|
"""utility function to parse specific model names to their
|
||||||
local file paths per backend
|
local file paths per backend
|
||||||
necessary for models not present on the Huggingface Hub (like
|
necessary for models not present on the Hugging Face Hub (like
|
||||||
own pretrained or optimised models)
|
own pretrained or optimised models)
|
||||||
only if chosen model and backend in combination are defined a local
|
only if chosen model and backend in combination are defined a local
|
||||||
file path is generated
|
file path is generated
|
||||||
@ -94,6 +97,9 @@ def _preprocess_STFR_model_name(
|
|||||||
model name given by configuration
|
model name given by configuration
|
||||||
backend: STFRBackends
|
backend: STFRBackends
|
||||||
backend given by configuration
|
backend given by configuration
|
||||||
|
force_download: bool
|
||||||
|
try to download model even if it is configured as local,
|
||||||
|
by default: False
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@ -102,7 +108,9 @@ def _preprocess_STFR_model_name(
|
|||||||
"""
|
"""
|
||||||
combination = (model_name, backend)
|
combination = (model_name, backend)
|
||||||
model_name_or_path: str
|
model_name_or_path: str
|
||||||
if combination in STFR_CUSTOM_MODELS and STFR_CUSTOM_MODELS[combination]:
|
if force_download:
|
||||||
|
model_name_or_path = model_name
|
||||||
|
elif combination in STFR_CUSTOM_MODELS and STFR_CUSTOM_MODELS[combination]:
|
||||||
# !! defined that each model is placed in a folder with its model name
|
# !! defined that each model is placed in a folder with its model name
|
||||||
# !! without any user names
|
# !! without any user names
|
||||||
folder_name = model_name.split('/')[-1]
|
folder_name = model_name.split('/')[-1]
|
||||||
@ -119,7 +127,7 @@ def _preprocess_STFR_model_name(
|
|||||||
|
|
||||||
|
|
||||||
def load_sentence_transformer(
|
def load_sentence_transformer(
|
||||||
model_name: STFRModelTypes,
|
model_name: STFRModelTypes | str,
|
||||||
similarity_func: SimilarityFunction = SimilarityFunction.COSINE,
|
similarity_func: SimilarityFunction = SimilarityFunction.COSINE,
|
||||||
backend: STFRBackends = STFRBackends.TORCH,
|
backend: STFRBackends = STFRBackends.TORCH,
|
||||||
device: STFRDeviceTypes = STFRDeviceTypes.CPU,
|
device: STFRDeviceTypes = STFRDeviceTypes.CPU,
|
||||||
@ -127,10 +135,12 @@ def load_sentence_transformer(
|
|||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
model_save_folder: str | None = None,
|
model_save_folder: str | None = None,
|
||||||
model_kwargs: STFRModelArgs | dict[str, Any] | None = None,
|
model_kwargs: STFRModelArgs | dict[str, Any] | None = None,
|
||||||
|
force_download: bool = False,
|
||||||
) -> SentenceTransformer:
|
) -> SentenceTransformer:
|
||||||
model_name_or_path = _preprocess_STFR_model_name(model_name=model_name, backend=backend)
|
model_name_or_path = _preprocess_STFR_model_name(
|
||||||
|
model_name=model_name, backend=backend, force_download=force_download
|
||||||
return SentenceTransformer(
|
)
|
||||||
|
model = SentenceTransformer(
|
||||||
model_name_or_path=model_name_or_path,
|
model_name_or_path=model_name_or_path,
|
||||||
similarity_fn_name=similarity_func,
|
similarity_fn_name=similarity_func,
|
||||||
backend=backend, # type: ignore Literal matches Enum
|
backend=backend, # type: ignore Literal matches Enum
|
||||||
@ -140,6 +150,9 @@ def load_sentence_transformer(
|
|||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
model_kwargs=model_kwargs, # type: ignore
|
model_kwargs=model_kwargs, # type: ignore
|
||||||
)
|
)
|
||||||
|
logger.info('[MODEL LOADING] Loaded model >>%s<< successfully', model_name)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
# ** configured model builder functions
|
# ** configured model builder functions
|
||||||
@ -147,7 +160,7 @@ MODEL_LOADER_MAP: Final[ModelLoaderMap] = {
|
|||||||
LanguageModels.SENTENCE_TRANSFORMER: {
|
LanguageModels.SENTENCE_TRANSFORMER: {
|
||||||
'func': load_sentence_transformer,
|
'func': load_sentence_transformer,
|
||||||
'kwargs': {
|
'kwargs': {
|
||||||
'model_name_or_path': STFR_MODEL_NAME,
|
'model_name': STFR_MODEL_NAME,
|
||||||
'similarity_func': STFR_SIMILARITY,
|
'similarity_func': STFR_SIMILARITY,
|
||||||
'backend': STFR_BACKEND,
|
'backend': STFR_BACKEND,
|
||||||
'device': STFR_DEVICE,
|
'device': STFR_DEVICE,
|
||||||
|
|||||||
@ -65,8 +65,9 @@ class SpacyModelTypes(enum.StrEnum):
|
|||||||
DE_DEP_NEWS_TRF = 'de_dep_news_trf'
|
DE_DEP_NEWS_TRF = 'de_dep_news_trf'
|
||||||
|
|
||||||
|
|
||||||
class STFRQuantFilenames(enum.StrEnum):
|
class STFRONNXFilenames(enum.StrEnum):
|
||||||
ONNX_Q_UINT8 = 'onnx/model_quint8_avx2.onnx'
|
ONNX_Q_UINT8 = 'onnx/model_quint8_avx2.onnx'
|
||||||
|
ONNX_OPT_O3 = 'onnx/model_O3.onnx'
|
||||||
|
|
||||||
|
|
||||||
TorchDTypes: TypeAlias = Literal[
|
TorchDTypes: TypeAlias = Literal[
|
||||||
@ -79,7 +80,7 @@ TorchDTypes: TypeAlias = Literal[
|
|||||||
class STFRModelArgs(TypedDict):
|
class STFRModelArgs(TypedDict):
|
||||||
torch_dtype: NotRequired[TorchDTypes]
|
torch_dtype: NotRequired[TorchDTypes]
|
||||||
provider: NotRequired[ONNXExecutionProvider]
|
provider: NotRequired[ONNXExecutionProvider]
|
||||||
file_name: NotRequired[STFRQuantFilenames]
|
file_name: NotRequired[STFRONNXFilenames]
|
||||||
export: NotRequired[bool]
|
export: NotRequired[bool]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user