"""Tests for ``lang_main.model_loader``.

Tests marked ``mload`` load real language models through the loader
functions; the sentence-transformer tests call the loader with
``local_files_only=False``.
"""

import pytest
from sentence_transformers import SentenceTransformer
from spacy.language import Language

from lang_main import model_loader
from lang_main.constants import (
    SimilarityFunction,
    SpacyModelTypes,
    STFRBackends,
    STFRDeviceTypes,
    STFRModelTypes,
    stfr_model_args_onnx,
)
from lang_main.errors import LanguageModelNotFoundError
from lang_main.types import LanguageModels

# Shared parametrization values, so the torch and ONNX sentence-transformer
# tests always cover the same combinations (order kept stable for test IDs).
_SIMILARITY_FUNCS = [
    SimilarityFunction.COSINE,
    SimilarityFunction.DOT,
]
_STFR_MODEL_NAMES = [
    STFRModelTypes.ALL_DISTILROBERTA_V1,
    STFRModelTypes.ALL_MINI_LM_L12_V2,
    STFRModelTypes.ALL_MINI_LM_L6_V2,
    STFRModelTypes.ALL_MPNET_BASE_V2,
]
_SPACY_MODEL_NAMES = [
    SpacyModelTypes.DE_CORE_NEWS_SM,
    SpacyModelTypes.DE_CORE_NEWS_MD,
    SpacyModelTypes.DE_CORE_NEWS_LG,
    SpacyModelTypes.DE_DEP_NEWS_TRF,
]


@pytest.mark.parametrize('similarity_func', _SIMILARITY_FUNCS)
@pytest.mark.parametrize('model_name', _STFR_MODEL_NAMES)
@pytest.mark.mload
def test_load_sentence_transformer(
    model_name,
    similarity_func,
) -> None:
    """Load each sentence-transformer with the torch backend on CPU."""
    model = model_loader.load_sentence_transformer(
        model_name=model_name,
        similarity_func=similarity_func,
        backend=STFRBackends.TORCH,
        device=STFRDeviceTypes.CPU,
        model_kwargs=None,
        local_files_only=False,
    )
    assert isinstance(model, SentenceTransformer)


@pytest.mark.parametrize('similarity_func', _SIMILARITY_FUNCS)
@pytest.mark.parametrize('model_name', _STFR_MODEL_NAMES)
@pytest.mark.mload
def test_load_sentence_transformer_onnx(model_name, similarity_func) -> None:
    """Load each sentence-transformer with the ONNX backend on CPU.

    Passes the project's ``stfr_model_args_onnx`` as ``model_kwargs``.
    """
    model = model_loader.load_sentence_transformer(
        model_name=model_name,
        similarity_func=similarity_func,
        backend=STFRBackends.ONNX,
        device=STFRDeviceTypes.CPU,
        model_kwargs=stfr_model_args_onnx,  # type: ignore
        local_files_only=False,
    )
    assert isinstance(model, SentenceTransformer)


@pytest.mark.parametrize('model_name', _SPACY_MODEL_NAMES)
@pytest.mark.mload
def test_load_spacy_model(model_name) -> None:
    """Load each configured spaCy model and check a ``Language`` is returned."""
    model = model_loader.load_spacy(
        model_name=model_name,
    )
    assert isinstance(model, Language)


def test_load_spacy_model_fail() -> None:
    """An unknown spaCy model name raises ``LanguageModelNotFoundError``."""
    model_name = 'not_existing'
    with pytest.raises(LanguageModelNotFoundError):
        # The return value is irrelevant: the call itself must raise.
        model_loader.load_spacy(model_name)


@pytest.mark.mload
def test_instantiate_spacy_model() -> None:
    """``instantiate_model`` builds a spaCy ``Language`` from the loader map."""
    model = model_loader.instantiate_model(
        model_load_map=model_loader.MODEL_LOADER_MAP,
        model=LanguageModels.SPACY,
    )
    assert isinstance(model, Language)


@pytest.mark.mload
def test_instantiate_stfr_model() -> None:
    """``instantiate_model`` builds a ``SentenceTransformer`` from the loader map."""
    model = model_loader.instantiate_model(
        model_load_map=model_loader.MODEL_LOADER_MAP,
        model=LanguageModels.SENTENCE_TRANSFORMER,
    )
    assert isinstance(model, SentenceTransformer)