114 lines
2.7 KiB
Python
114 lines
2.7 KiB
Python
import pytest
|
|
from sentence_transformers import SentenceTransformer
|
|
from spacy.language import Language
|
|
|
|
from lang_main import model_loader
|
|
from lang_main.constants import (
|
|
STFR_MODEL_ARGS_ONNX,
|
|
SimilarityFunction,
|
|
SpacyModelTypes,
|
|
STFRBackends,
|
|
STFRDeviceTypes,
|
|
STFRModelTypes,
|
|
)
|
|
from lang_main.types import LanguageModels
|
|
|
|
|
|
@pytest.mark.parametrize(
    'similarity_func',
    (SimilarityFunction.COSINE, SimilarityFunction.DOT),
)
@pytest.mark.parametrize(
    'model_name',
    (
        STFRModelTypes.ALL_DISTILROBERTA_V1,
        STFRModelTypes.ALL_MINI_LM_L12_V2,
        STFRModelTypes.ALL_MINI_LM_L6_V2,
        STFRModelTypes.ALL_MPNET_BASE_V2,
    ),
)
# NOTE(review): `mload` is a project-specific marker (presumably tagging tests
# that load real models so they can be selected/deselected) — confirm in the
# pytest configuration.
@pytest.mark.mload
def test_load_sentence_transformer(
    model_name,
    similarity_func,
) -> None:
    """Load each listed SentenceTransformer model with the Torch backend on CPU.

    Runs once per (model_name, similarity_func) combination and checks that
    ``model_loader.load_sentence_transformer`` returns a ``SentenceTransformer``
    when no extra model kwargs are supplied.
    """
    transformer = model_loader.load_sentence_transformer(
        model_name=model_name,
        similarity_func=similarity_func,
        backend=STFRBackends.TORCH,
        device=STFRDeviceTypes.CPU,
        model_kwargs=None,
    )
    assert isinstance(transformer, SentenceTransformer)
|
|
|
|
|
|
@pytest.mark.parametrize(
    'similarity_func',
    (SimilarityFunction.COSINE, SimilarityFunction.DOT),
)
@pytest.mark.parametrize(
    'model_name',
    (
        STFRModelTypes.ALL_DISTILROBERTA_V1,
        STFRModelTypes.ALL_MINI_LM_L12_V2,
        STFRModelTypes.ALL_MINI_LM_L6_V2,
        STFRModelTypes.ALL_MPNET_BASE_V2,
    ),
)
# NOTE(review): `mload` is a project-specific marker (presumably tagging tests
# that load real models) — confirm in the pytest configuration.
@pytest.mark.mload
def test_load_sentence_transformer_onnx(
    model_name,
    similarity_func,
) -> None:
    """Load each listed SentenceTransformer model with the ONNX backend on CPU.

    Runs once per (model_name, similarity_func) combination, passing
    ``STFR_MODEL_ARGS_ONNX`` as the model kwargs, and checks that the loader
    returns a ``SentenceTransformer``.
    """
    onnx_transformer = model_loader.load_sentence_transformer(
        model_name=model_name,
        similarity_func=similarity_func,
        backend=STFRBackends.ONNX,
        device=STFRDeviceTypes.CPU,
        # NOTE(review): the ignore presumably silences a mismatch between the
        # declared type of STFR_MODEL_ARGS_ONNX and the `model_kwargs`
        # annotation — confirm and narrow if possible.
        model_kwargs=STFR_MODEL_ARGS_ONNX,  # type: ignore
    )
    assert isinstance(onnx_transformer, SentenceTransformer)
|
|
|
|
|
|
@pytest.mark.parametrize(
    'model_name',
    [
        SpacyModelTypes.DE_CORE_NEWS_SM,
        SpacyModelTypes.DE_CORE_NEWS_MD,
        SpacyModelTypes.DE_CORE_NEWS_LG,
        SpacyModelTypes.DE_DEP_NEWS_TRF,
    ],
)
@pytest.mark.mload
def test_load_spacy_model(
    model_name,
) -> None:
    """Load each listed spaCy pipeline and check it is a ``Language`` object.

    Runs once per ``SpacyModelTypes`` member in the parametrize list and calls
    ``model_loader.load_spacy`` with that model name.
    """
    # `-> None` matches the sibling tests; an unannotated def would have its
    # body skipped by mypy under default settings.
    model = model_loader.load_spacy(
        model_name=model_name,
    )
    assert isinstance(model, Language)
|
|
|
|
|
|
@pytest.mark.mload
def test_instantiate_spacy_model() -> None:
    """Check that ``instantiate_model`` builds a spaCy ``Language``.

    Uses ``model_loader.MODEL_LOADER_MAP`` as the loader mapping and requests
    ``LanguageModels.SPACY``.
    """
    model = model_loader.instantiate_model(
        model_load_map=model_loader.MODEL_LOADER_MAP,
        model=LanguageModels.SPACY,
    )
    assert isinstance(model, Language)
|
|
|
|
|
|
@pytest.mark.mload
def test_instantiate_stfr_model() -> None:
    """Check that ``instantiate_model`` builds a ``SentenceTransformer``.

    Uses ``model_loader.MODEL_LOADER_MAP`` as the loader mapping and requests
    ``LanguageModels.SENTENCE_TRANSFORMER``.
    """
    model = model_loader.instantiate_model(
        model_load_map=model_loader.MODEL_LOADER_MAP,
        model=LanguageModels.SENTENCE_TRANSFORMER,
    )
    assert isinstance(model, SentenceTransformer)
|