import pytest
from sentence_transformers import SentenceTransformer
from spacy.language import Language

from lang_main import model_loader
from lang_main.constants import (
    SimilarityFunction,
    SpacyModelTypes,
    STFRBackends,
    STFRDeviceTypes,
    STFRModelTypes,
    stfr_model_args_onnx,
)
from lang_main.errors import LanguageModelNotFoundError
from lang_main.types import LanguageModels

# Parameter sets shared by the Torch and ONNX sentence-transformer loading
# tests, so both backends are exercised against the same combinations.
SIMILARITY_FUNCS = (
    SimilarityFunction.COSINE,
    SimilarityFunction.DOT,
)
STFR_MODEL_NAMES = (
    STFRModelTypes.ALL_MINI_LM_L6_V2,
    STFRModelTypes.ALL_MPNET_BASE_V2,
)


@pytest.mark.parametrize('similarity_func', SIMILARITY_FUNCS)
@pytest.mark.parametrize('model_name', STFR_MODEL_NAMES)
@pytest.mark.mload
def test_load_sentence_transformer(
    model_name: STFRModelTypes,
    similarity_func: SimilarityFunction,
) -> None:
    """Check that the Torch backend on CPU yields a ``SentenceTransformer``."""
    model = model_loader.load_sentence_transformer(
        model_name=model_name,
        similarity_func=similarity_func,
        backend=STFRBackends.TORCH,
        device=STFRDeviceTypes.CPU,
        model_kwargs=None,
        local_files_only=False,
    )
    assert isinstance(model, SentenceTransformer)


def test_preprocess_STFR_model_name() -> None:
    """Check ``_preprocess_STFR_model_name`` for unknown and known model names.

    An unknown model name is returned unchanged with the Torch backend,
    whether ``force_download`` is ``True`` or ``False``. A known model name
    with the ONNX backend and ``force_download=False`` is expected to raise
    ``FileNotFoundError``. Presumably this is because no local ONNX export
    exists in the test environment (confirm against the loader).
    """
    model_name_not_exist = 'TestModel'
    for force_download in (True, False):
        ret_model_name = model_loader._preprocess_STFR_model_name(
            model_name=model_name_not_exist,
            backend=STFRBackends.TORCH,
            force_download=force_download,
        )
        assert ret_model_name == model_name_not_exist

    model_name_exist = STFRModelTypes.E5_BASE_STS_EN_DE
    backend_exist = STFRBackends.ONNX
    with pytest.raises(FileNotFoundError):
        model_loader._preprocess_STFR_model_name(
            model_name=model_name_exist,
            backend=backend_exist,
            force_download=False,
        )


@pytest.mark.parametrize('similarity_func', SIMILARITY_FUNCS)
@pytest.mark.parametrize('model_name', STFR_MODEL_NAMES)
@pytest.mark.mload
def test_load_sentence_transformer_onnx(
    model_name: STFRModelTypes,
    similarity_func: SimilarityFunction,
) -> None:
    """Check that the ONNX backend on CPU yields a ``SentenceTransformer``.

    Passes the project's ``stfr_model_args_onnx`` as ``model_kwargs``.
    """
    model = model_loader.load_sentence_transformer(
        model_name=model_name,
        similarity_func=similarity_func,
        backend=STFRBackends.ONNX,
        device=STFRDeviceTypes.CPU,
        model_kwargs=stfr_model_args_onnx,  # type: ignore
        local_files_only=False,
    )
    assert isinstance(model, SentenceTransformer)


@pytest.mark.parametrize(
    'model_name',
    [
        SpacyModelTypes.DE_CORE_NEWS_SM,
        SpacyModelTypes.DE_CORE_NEWS_MD,
        SpacyModelTypes.DE_CORE_NEWS_LG,
        SpacyModelTypes.DE_DEP_NEWS_TRF,
    ],
)
@pytest.mark.mload
def test_load_spacy_model(model_name: SpacyModelTypes) -> None:
    """Check that each listed spaCy model loads as a ``Language`` object."""
    model = model_loader.load_spacy(
        model_name=model_name,
    )
    assert isinstance(model, Language)


def test_load_spacy_model_fail() -> None:
    """Check that an unknown spaCy model name raises ``LanguageModelNotFoundError``."""
    model_name = 'not_existing'
    with pytest.raises(LanguageModelNotFoundError):
        # The return value is irrelevant; only the raised exception is checked.
        model_loader.load_spacy(model_name=model_name)


@pytest.mark.mload
def test_instantiate_spacy_model() -> None:
    """Check that the loader map builds a ``Language`` for ``LanguageModels.SPACY``."""
    model = model_loader.instantiate_model(
        model_load_map=model_loader.MODEL_LOADER_MAP,
        model=LanguageModels.SPACY,
    )
    assert isinstance(model, Language)


def test_fail_instantiate_spacy_model() -> None:
    """Check that instantiating a model key absent from the loader map raises ``KeyError``."""
    with pytest.raises(KeyError):
        model_loader.instantiate_model(
            model_load_map=model_loader.MODEL_LOADER_MAP,
            model='test',  # type: ignore
        )


@pytest.mark.mload
def test_instantiate_stfr_model() -> None:
    """Check that the loader map builds a ``SentenceTransformer`` for
    ``LanguageModels.SENTENCE_TRANSFORMER``."""
    model = model_loader.instantiate_model(
        model_load_map=model_loader.MODEL_LOADER_MAP,
        model=LanguageModels.SENTENCE_TRANSFORMER,
    )
    assert isinstance(model, SentenceTransformer)