141 lines
3.8 KiB
Python
141 lines
3.8 KiB
Python
import pytest
|
|
from sentence_transformers import SentenceTransformer
|
|
from spacy.language import Language
|
|
|
|
from lang_main import model_loader
|
|
from lang_main.constants import (
|
|
SimilarityFunction,
|
|
SpacyModelTypes,
|
|
STFRBackends,
|
|
STFRDeviceTypes,
|
|
STFRModelTypes,
|
|
stfr_model_args_onnx,
|
|
)
|
|
from lang_main.errors import LanguageModelNotFoundError
|
|
from lang_main.types import LanguageModels
|
|
|
|
|
|
@pytest.mark.parametrize(
    'similarity_func',
    (SimilarityFunction.COSINE, SimilarityFunction.DOT),
)
@pytest.mark.parametrize(
    'model_name',
    (STFRModelTypes.ALL_MINI_LM_L6_V2, STFRModelTypes.ALL_MPNET_BASE_V2),
)
@pytest.mark.mload
def test_load_sentence_transformer(
    model_name,
    similarity_func,
) -> None:
    """Load each model/similarity combination with the torch backend on CPU.

    The loader is called with ``model_kwargs=None`` and
    ``local_files_only=False``. The returned object must be a
    ``SentenceTransformer`` instance.
    """
    # Options shared by every parametrized case (torch backend, CPU only).
    loader_options = {
        'backend': STFRBackends.TORCH,
        'device': STFRDeviceTypes.CPU,
        'model_kwargs': None,
        'local_files_only': False,
    }
    loaded = model_loader.load_sentence_transformer(
        model_name=model_name,
        similarity_func=similarity_func,
        **loader_options,
    )
    assert isinstance(loaded, SentenceTransformer)
|
|
|
|
|
|
def test_preprocess_STFR_model_name() -> None:
    """Check how ``_preprocess_STFR_model_name`` resolves model names.

    With the torch backend, an unknown model name is returned unchanged
    whether ``force_download`` is ``True`` or ``False``. With the ONNX
    backend and ``force_download=False``, the known model
    ``STFRModelTypes.E5_BASE_STS_EN_DE`` raises ``FileNotFoundError``.
    NOTE(review): presumably this is because no local ONNX files exist for
    it; confirm against ``_preprocess_STFR_model_name``.
    """
    unknown_name = 'TestModel'
    # Same two calls as before: force_download=True first, then False.
    for force in (True, False):
        resolved = model_loader._preprocess_STFR_model_name(
            model_name=unknown_name,
            backend=STFRBackends.TORCH,
            force_download=force,
        )
        assert resolved == unknown_name

    with pytest.raises(FileNotFoundError):
        model_loader._preprocess_STFR_model_name(
            model_name=STFRModelTypes.E5_BASE_STS_EN_DE,
            backend=STFRBackends.ONNX,
            force_download=False,
        )
|
|
|
|
|
|
@pytest.mark.parametrize(
    'similarity_func',
    (SimilarityFunction.COSINE, SimilarityFunction.DOT),
)
@pytest.mark.parametrize(
    'model_name',
    (STFRModelTypes.ALL_MINI_LM_L6_V2, STFRModelTypes.ALL_MPNET_BASE_V2),
)
@pytest.mark.mload
def test_load_sentence_transformer_onnx(model_name, similarity_func) -> None:
    """Load each model/similarity combination with the ONNX backend on CPU.

    The loader receives ``stfr_model_args_onnx`` as ``model_kwargs`` and
    ``local_files_only=False``. The returned object must be a
    ``SentenceTransformer`` instance.
    """
    # Options shared by every parametrized case (ONNX backend, CPU only).
    loader_options = {
        'backend': STFRBackends.ONNX,
        'device': STFRDeviceTypes.CPU,
        'model_kwargs': stfr_model_args_onnx,  # type: ignore
        'local_files_only': False,
    }
    loaded = model_loader.load_sentence_transformer(
        model_name=model_name,
        similarity_func=similarity_func,
        **loader_options,
    )
    assert isinstance(loaded, SentenceTransformer)
|
|
|
|
|
|
@pytest.mark.parametrize(
    'model_name',
    (
        SpacyModelTypes.DE_CORE_NEWS_SM,
        SpacyModelTypes.DE_CORE_NEWS_MD,
        SpacyModelTypes.DE_CORE_NEWS_LG,
        SpacyModelTypes.DE_DEP_NEWS_TRF,
    ),
)
@pytest.mark.mload
def test_load_spacy_model(model_name) -> None:
    """Loading each listed German spaCy model yields a ``Language`` object."""
    nlp = model_loader.load_spacy(model_name=model_name)
    assert isinstance(nlp, Language)
|
|
|
|
|
|
def test_load_spacy_model_fail() -> None:
    """An unknown spaCy model name makes ``load_spacy`` raise ``LanguageModelNotFoundError``."""
    model_name = 'not_existing'
    with pytest.raises(LanguageModelNotFoundError):
        # The result is never used: only the raised exception matters.
        _ = model_loader.load_spacy(model_name)
|
|
|
|
|
|
@pytest.mark.mload
def test_instantiate_spacy_model() -> None:
    """Instantiating ``LanguageModels.SPACY`` through ``MODEL_LOADER_MAP`` yields a spaCy ``Language``."""
    loader_map = model_loader.MODEL_LOADER_MAP
    nlp = model_loader.instantiate_model(
        model_load_map=loader_map,
        model=LanguageModels.SPACY,
    )
    assert isinstance(nlp, Language)
|
|
|
|
|
|
def test_fail_instantiate_spacy_model() -> None:
    """An unknown model key makes ``instantiate_model`` raise ``KeyError``.

    Despite the test name, nothing here is spaCy-specific: the key ``'test'``
    is simply absent from ``MODEL_LOADER_MAP``.
    """
    with pytest.raises(KeyError):
        _ = model_loader.instantiate_model(
            model_load_map=model_loader.MODEL_LOADER_MAP,
            # Deliberately not a LanguageModels member, hence the ignore.
            model='test',  # type: ignore
        )
|
|
|
|
|
|
@pytest.mark.mload
def test_instantiate_stfr_model() -> None:
    """Instantiating ``LanguageModels.SENTENCE_TRANSFORMER`` through ``MODEL_LOADER_MAP`` yields a ``SentenceTransformer``."""
    loader_map = model_loader.MODEL_LOADER_MAP
    encoder = model_loader.instantiate_model(
        model_load_map=loader_map,
        model=LanguageModels.SENTENCE_TRANSFORMER,
    )
    assert isinstance(encoder, SentenceTransformer)
|