import pytest
from sentence_transformers import SentenceTransformer
from spacy.language import Language

from lang_main import model_loader
from lang_main.constants import (
    STFR_MODEL_ARGS_ONNX,
    SimilarityFunction,
    SpacyModelTypes,
    STFRBackends,
    STFRDeviceTypes,
    STFRModelTypes,
)
from lang_main.types import LanguageModels

# Parameter sets shared by the torch and ONNX sentence-transformer tests, so
# both backends are exercised over exactly the same combinations.
_SIMILARITY_FUNCS = [
    SimilarityFunction.COSINE,
    SimilarityFunction.DOT,
]
_STFR_MODEL_NAMES = [
    STFRModelTypes.ALL_DISTILROBERTA_V1,
    STFRModelTypes.ALL_MINI_LM_L12_V2,
    STFRModelTypes.ALL_MINI_LM_L6_V2,
    STFRModelTypes.ALL_MPNET_BASE_V2,
]
# spaCy pipelines covered by the spaCy loading test.
_SPACY_MODEL_NAMES = [
    SpacyModelTypes.DE_CORE_NEWS_SM,
    SpacyModelTypes.DE_CORE_NEWS_MD,
    SpacyModelTypes.DE_CORE_NEWS_LG,
    SpacyModelTypes.DE_DEP_NEWS_TRF,
]


@pytest.mark.parametrize('similarity_func', _SIMILARITY_FUNCS)
@pytest.mark.parametrize('model_name', _STFR_MODEL_NAMES)
@pytest.mark.mload
def test_load_sentence_transformer(
    model_name,
    similarity_func,
) -> None:
    """Load a sentence transformer with the torch backend on CPU.

    Runs once per (model name, similarity function) combination and checks
    that the loader returns a ``SentenceTransformer`` instance.
    """
    stfr_model = model_loader.load_sentence_transformer(
        model_name=model_name,
        similarity_func=similarity_func,
        backend=STFRBackends.TORCH,
        device=STFRDeviceTypes.CPU,
        model_kwargs=None,
    )

    assert isinstance(stfr_model, SentenceTransformer)


@pytest.mark.parametrize('similarity_func', _SIMILARITY_FUNCS)
@pytest.mark.parametrize('model_name', _STFR_MODEL_NAMES)
@pytest.mark.mload
def test_load_sentence_transformer_onnx(
    model_name,
    similarity_func,
) -> None:
    """Load a sentence transformer with the ONNX backend on CPU.

    Same parameter grid as the torch test, but passes ``STFR_MODEL_ARGS_ONNX``
    as the model keyword arguments.
    """
    stfr_model = model_loader.load_sentence_transformer(
        model_name=model_name,
        similarity_func=similarity_func,
        backend=STFRBackends.ONNX,
        device=STFRDeviceTypes.CPU,
        model_kwargs=STFR_MODEL_ARGS_ONNX,  # type: ignore
    )

    assert isinstance(stfr_model, SentenceTransformer)


@pytest.mark.parametrize('model_name', _SPACY_MODEL_NAMES)
@pytest.mark.mload
def test_load_spacy_model(
    model_name,
) -> None:
    """Load each listed spaCy model and check it is a ``Language`` object."""
    nlp = model_loader.load_spacy(
        model_name=model_name,
    )

    assert isinstance(nlp, Language)


@pytest.mark.mload
def test_instantiate_spacy_model() -> None:
    """Instantiate the spaCy entry of ``MODEL_LOADER_MAP`` via the generic loader."""
    nlp = model_loader.instantiate_model(
        model_load_map=model_loader.MODEL_LOADER_MAP,
        model=LanguageModels.SPACY,
    )

    assert isinstance(nlp, Language)


@pytest.mark.mload
def test_instantiate_stfr_model() -> None:
    """Instantiate the sentence-transformer entry of ``MODEL_LOADER_MAP``."""
    stfr_model = model_loader.instantiate_model(
        model_load_map=model_loader.MODEL_LOADER_MAP,
        model=LanguageModels.SENTENCE_TRANSFORMER,
    )

    assert isinstance(stfr_model, SentenceTransformer)