From ef5743bc853bcbc4cdfd906767566c68b46b5c5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20F=C3=B6rster?= Date: Fri, 20 Dec 2024 14:32:06 +0100 Subject: [PATCH] making offline loading of STFR default --- src/lang_main/model_loader.py | 7 ++++--- tests/test_model_loader.py | 2 ++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/lang_main/model_loader.py b/src/lang_main/model_loader.py index 7dc0cf6..99e9dc1 100644 --- a/src/lang_main/model_loader.py +++ b/src/lang_main/model_loader.py @@ -27,6 +27,7 @@ from lang_main.types import ( SpacyModel, STFRBackends, STFRDeviceTypes, + STFRModelArgs, ) @@ -78,9 +79,9 @@ def load_sentence_transformer( similarity_func: SimilarityFunction = SimilarityFunction.COSINE, backend: STFRBackends = STFRBackends.TORCH, device: STFRDeviceTypes = STFRDeviceTypes.CPU, - local_files_only: bool = False, + local_files_only: bool = True, model_save_folder: str | None = None, - model_kwargs: dict[str, Any] | None = None, + model_kwargs: STFRModelArgs | dict[str, Any] | None = None, ) -> SentenceTransformer: return SentenceTransformer( model_name_or_path=model_name, @@ -89,7 +90,7 @@ def load_sentence_transformer( device=device, cache_folder=model_save_folder, local_files_only=local_files_only, - model_kwargs=model_kwargs, + model_kwargs=model_kwargs, # type: ignore ) diff --git a/tests/test_model_loader.py b/tests/test_model_loader.py index 914b67c..1127383 100644 --- a/tests/test_model_loader.py +++ b/tests/test_model_loader.py @@ -42,6 +42,7 @@ def test_load_sentence_transformer( backend=STFRBackends.TORCH, device=STFRDeviceTypes.CPU, model_kwargs=None, + local_files_only=False, ) assert isinstance(model, SentenceTransformer) @@ -70,6 +71,7 @@ def test_load_sentence_transformer_onnx(model_name, similarity_func) -> None: backend=STFRBackends.ONNX, device=STFRDeviceTypes.CPU, model_kwargs=stfr_model_args_onnx, # type: ignore + local_files_only=False, ) assert isinstance(model, SentenceTransformer)