lang-main/tests/test_model_loader.py
2024-12-20 14:32:06 +01:00

118 lines
3.0 KiB
Python

import pytest
from sentence_transformers import SentenceTransformer
from spacy.language import Language
from lang_main import model_loader
from lang_main.constants import (
SimilarityFunction,
SpacyModelTypes,
STFRBackends,
STFRDeviceTypes,
STFRModelTypes,
stfr_model_args_onnx,
)
from lang_main.errors import LanguageModelNotFoundError
from lang_main.types import LanguageModels
@pytest.mark.parametrize(
    'similarity_func',
    (
        SimilarityFunction.COSINE,
        SimilarityFunction.DOT,
    ),
)
@pytest.mark.parametrize(
    'model_name',
    (
        STFRModelTypes.ALL_DISTILROBERTA_V1,
        STFRModelTypes.ALL_MINI_LM_L12_V2,
        STFRModelTypes.ALL_MINI_LM_L6_V2,
        STFRModelTypes.ALL_MPNET_BASE_V2,
    ),
)
@pytest.mark.mload
def test_load_sentence_transformer(
    model_name,
    similarity_func,
) -> None:
    """Load each listed sentence-transformer model with the PyTorch backend.

    Every model/similarity-function combination is passed to
    ``model_loader.load_sentence_transformer`` on the CPU, with no extra
    model kwargs and ``local_files_only=False``. The result must be a
    ``SentenceTransformer`` instance.
    """
    # Options shared by every parametrized case; only the model name and
    # the similarity function vary.
    loader_options = {
        'backend': STFRBackends.TORCH,
        'device': STFRDeviceTypes.CPU,
        'model_kwargs': None,
        'local_files_only': False,
    }
    loaded = model_loader.load_sentence_transformer(
        model_name=model_name,
        similarity_func=similarity_func,
        **loader_options,
    )
    assert isinstance(loaded, SentenceTransformer)
@pytest.mark.parametrize(
    'similarity_func',
    (SimilarityFunction.COSINE, SimilarityFunction.DOT),
)
@pytest.mark.parametrize(
    'model_name',
    (
        STFRModelTypes.ALL_DISTILROBERTA_V1,
        STFRModelTypes.ALL_MINI_LM_L12_V2,
        STFRModelTypes.ALL_MINI_LM_L6_V2,
        STFRModelTypes.ALL_MPNET_BASE_V2,
    ),
)
@pytest.mark.mload
def test_load_sentence_transformer_onnx(model_name, similarity_func) -> None:
    """Load each listed sentence-transformer model with the ONNX backend.

    Same matrix as the PyTorch test, but the backend is ``STFRBackends.ONNX``
    and ``stfr_model_args_onnx`` is forwarded as ``model_kwargs``. The result
    must still be a ``SentenceTransformer`` instance.
    """
    onnx_model = model_loader.load_sentence_transformer(
        model_name=model_name,
        similarity_func=similarity_func,
        backend=STFRBackends.ONNX,
        device=STFRDeviceTypes.CPU,
        # NOTE(review): the ignore silences a static type mismatch on
        # model_kwargs; confirm whether it is still required.
        model_kwargs=stfr_model_args_onnx,  # type: ignore
        local_files_only=False,
    )
    assert isinstance(onnx_model, SentenceTransformer)
@pytest.mark.parametrize(
    'model_name',
    (
        SpacyModelTypes.DE_CORE_NEWS_SM,
        SpacyModelTypes.DE_CORE_NEWS_MD,
        SpacyModelTypes.DE_CORE_NEWS_LG,
        SpacyModelTypes.DE_DEP_NEWS_TRF,
    ),
)
@pytest.mark.mload
def test_load_spacy_model(model_name):
    """Load each listed ``SpacyModelTypes`` model and check it is a spaCy ``Language``."""
    nlp = model_loader.load_spacy(model_name=model_name)
    assert isinstance(nlp, Language)
def test_load_spacy_model_fail() -> None:
    """An unknown spaCy model name must raise ``LanguageModelNotFoundError``."""
    model_name = 'not_existing'
    with pytest.raises(LanguageModelNotFoundError):
        # The call is expected to raise, so its result is intentionally not
        # bound to a name (any assignment here could never complete).
        model_loader.load_spacy(model_name)
@pytest.mark.mload
def test_instantiate_spacy_model():
    """Check that ``instantiate_model`` builds a spaCy ``Language`` from ``MODEL_LOADER_MAP``.

    The model is requested as ``LanguageModels.SPACY``.
    """
    loader_map = model_loader.MODEL_LOADER_MAP
    instance = model_loader.instantiate_model(
        model_load_map=loader_map,
        model=LanguageModels.SPACY,
    )
    assert isinstance(instance, Language)
@pytest.mark.mload
def test_instantiate_stfr_model():
    """Check that ``instantiate_model`` builds a ``SentenceTransformer`` from ``MODEL_LOADER_MAP``.

    The model is requested as ``LanguageModels.SENTENCE_TRANSFORMER``.
    """
    loader_map = model_loader.MODEL_LOADER_MAP
    instance = model_loader.instantiate_model(
        model_load_map=loader_map,
        model=LanguageModels.SENTENCE_TRANSFORMER,
    )
    assert isinstance(instance, SentenceTransformer)