Module lang_main.model_loader

Functions

def instantiate_model(model_load_map: ModelLoaderMap, model: LanguageModels) -> sentence_transformers.SentenceTransformer.SentenceTransformer | spacy.language.Language
Expand source code
def instantiate_model(
    model_load_map: ModelLoaderMap,
    model: LanguageModels,
) -> Model:
    """Instantiate a language model from its registered loader entry.

    The entry stored under ``model`` in ``model_load_map`` must provide a
    callable under the key ``'func'`` and a mapping of keyword arguments
    under ``'kwargs'``; the callable is invoked with those keyword
    arguments unpacked and its result is returned unchanged.

    Args:
        model_load_map: Mapping from model identifiers to loader entries.
        model: Identifier of the model to build.

    Returns:
        Whatever the registered builder callable returns.

    Raises:
        KeyError: If ``model`` has no entry in ``model_load_map``.
    """
    if model in model_load_map:
        loader_entry = model_load_map[model]
        return loader_entry['func'](**loader_entry['kwargs'])

    raise KeyError(f'Model >>{model}<< not known. Choose from: {model_load_map.keys()}')
def load_sentence_transformer(model_name: STFRModelTypes | str,
similarity_func: SimilarityFunction = SimilarityFunction.COSINE,
backend: STFRBackends = STFRBackends.TORCH,
device: STFRDeviceTypes = STFRDeviceTypes.CPU,
local_files_only: bool = True,
trust_remote_code: bool = False,
model_save_folder: str | None = None,
model_kwargs: STFRModelArgs | dict[str, Any] | None = None,
force_download: bool = False) -> sentence_transformers.SentenceTransformer.SentenceTransformer
Expand source code
def load_sentence_transformer(
    model_name: STFRModelTypes | str,
    similarity_func: SimilarityFunction = SimilarityFunction.COSINE,
    backend: STFRBackends = STFRBackends.TORCH,
    device: STFRDeviceTypes = STFRDeviceTypes.CPU,
    local_files_only: bool = True,
    trust_remote_code: bool = False,
    model_save_folder: str | None = None,
    model_kwargs: STFRModelArgs | dict[str, Any] | None = None,
    force_download: bool = False,
) -> SentenceTransformer:
    """Construct a ``SentenceTransformer`` with the given configuration.

    Args:
        model_name: Model identifier. It is first passed, together with
            ``backend`` and ``force_download``, through
            ``_preprocess_STFR_model_name``; that helper's result is what
            ``SentenceTransformer`` receives as ``model_name_or_path``.
        similarity_func: Forwarded as ``similarity_fn_name``.
        backend: Forwarded to both the name preprocessing and the
            ``SentenceTransformer`` constructor.
        device: Forwarded as ``device``.
        local_files_only: Forwarded as ``local_files_only``.
        trust_remote_code: Forwarded as ``trust_remote_code``.
        model_save_folder: Forwarded as ``cache_folder``.
        model_kwargs: Forwarded as ``model_kwargs``.
        force_download: Only used by the name preprocessing step.

    Returns:
        The constructed ``SentenceTransformer`` instance.
    """
    resolved_name_or_path = _preprocess_STFR_model_name(
        model_name=model_name,
        backend=backend,
        force_download=force_download,
    )
    transformer = SentenceTransformer(
        model_name_or_path=resolved_name_or_path,
        device=device,
        backend=backend,  # type: ignore  # Literal matches Enum
        similarity_fn_name=similarity_func,
        cache_folder=model_save_folder,
        trust_remote_code=trust_remote_code,
        local_files_only=local_files_only,
        model_kwargs=model_kwargs,  # type: ignore
    )
    # The log reports the caller-supplied name, not the preprocessed path.
    logger.info('[MODEL LOADING] Loaded model >>%s<< successfully', model_name)
    return transformer
def load_spacy(model_name: str) -> spacy.language.Language
Expand source code
def load_spacy(
    model_name: str,
) -> SpacyModel:
    """Import an installed spaCy model package and load its pipeline.

    The package named ``model_name`` is imported with
    ``importlib.import_module`` and its module-level ``load()`` is called
    to build the pipeline.

    Args:
        model_name: Importable name of the spaCy model package.

    Returns:
        The object returned by the package's ``load()``.

    Raises:
        LanguageModelNotFoundError: If importing ``model_name`` raises
            ``ModuleNotFoundError``. The original error is chained as the
            ``__cause__`` so the module that was actually missing stays
            visible in the traceback.
    """
    try:
        spacy_model_obj = importlib.import_module(model_name)
    except ModuleNotFoundError as err:
        # Chain explicitly: the ModuleNotFoundError may name a missing
        # dependency of the package rather than the package itself, and
        # without ``from err`` the traceback would misleadingly report a
        # second error raised "during handling" of the first.
        raise LanguageModelNotFoundError(
            f'Could not find spaCy model >>{model_name}<<. '
            'Check if it is installed correctly.'
        ) from err
    pretrained_model = cast(SpacyModel, spacy_model_obj.load())
    logger.info('[MODEL LOADING] Loaded model >>%s<< successfully', model_name)

    return pretrained_model