diff --git a/src/lang_main/model_loader.py b/src/lang_main/model_loader.py index 7dc0cf6..99e9dc1 100644 --- a/src/lang_main/model_loader.py +++ b/src/lang_main/model_loader.py @@ -27,6 +27,7 @@ from lang_main.types import ( SpacyModel, STFRBackends, STFRDeviceTypes, + STFRModelArgs, ) @@ -78,9 +79,9 @@ def load_sentence_transformer( similarity_func: SimilarityFunction = SimilarityFunction.COSINE, backend: STFRBackends = STFRBackends.TORCH, device: STFRDeviceTypes = STFRDeviceTypes.CPU, - local_files_only: bool = False, + local_files_only: bool = True, model_save_folder: str | None = None, - model_kwargs: dict[str, Any] | None = None, + model_kwargs: STFRModelArgs | dict[str, Any] | None = None, ) -> SentenceTransformer: return SentenceTransformer( model_name_or_path=model_name, @@ -89,7 +90,7 @@ def load_sentence_transformer( device=device, cache_folder=model_save_folder, local_files_only=local_files_only, - model_kwargs=model_kwargs, + model_kwargs=model_kwargs, # type: ignore ) diff --git a/tests/test_model_loader.py b/tests/test_model_loader.py index 914b67c..1127383 100644 --- a/tests/test_model_loader.py +++ b/tests/test_model_loader.py @@ -42,6 +42,7 @@ def test_load_sentence_transformer( backend=STFRBackends.TORCH, device=STFRDeviceTypes.CPU, model_kwargs=None, + local_files_only=False, ) assert isinstance(model, SentenceTransformer) @@ -70,6 +71,7 @@ def test_load_sentence_transformer_onnx(model_name, similarity_func) -> None: backend=STFRBackends.ONNX, device=STFRDeviceTypes.CPU, model_kwargs=stfr_model_args_onnx, # type: ignore + local_files_only=False, ) assert isinstance(model, SentenceTransformer)