making offline loading of STFR default

This commit is contained in:
Florian Förster 2024-12-20 14:32:06 +01:00
parent 80a35c4658
commit ef5743bc85
2 changed files with 6 additions and 3 deletions

View File

@ -27,6 +27,7 @@ from lang_main.types import (
SpacyModel, SpacyModel,
STFRBackends, STFRBackends,
STFRDeviceTypes, STFRDeviceTypes,
STFRModelArgs,
) )
@ -78,9 +79,9 @@ def load_sentence_transformer(
similarity_func: SimilarityFunction = SimilarityFunction.COSINE, similarity_func: SimilarityFunction = SimilarityFunction.COSINE,
backend: STFRBackends = STFRBackends.TORCH, backend: STFRBackends = STFRBackends.TORCH,
device: STFRDeviceTypes = STFRDeviceTypes.CPU, device: STFRDeviceTypes = STFRDeviceTypes.CPU,
local_files_only: bool = False, local_files_only: bool = True,
model_save_folder: str | None = None, model_save_folder: str | None = None,
model_kwargs: dict[str, Any] | None = None, model_kwargs: STFRModelArgs | dict[str, Any] | None = None,
) -> SentenceTransformer: ) -> SentenceTransformer:
return SentenceTransformer( return SentenceTransformer(
model_name_or_path=model_name, model_name_or_path=model_name,
@ -89,7 +90,7 @@ def load_sentence_transformer(
device=device, device=device,
cache_folder=model_save_folder, cache_folder=model_save_folder,
local_files_only=local_files_only, local_files_only=local_files_only,
model_kwargs=model_kwargs, model_kwargs=model_kwargs, # type: ignore
) )

View File

@ -42,6 +42,7 @@ def test_load_sentence_transformer(
backend=STFRBackends.TORCH, backend=STFRBackends.TORCH,
device=STFRDeviceTypes.CPU, device=STFRDeviceTypes.CPU,
model_kwargs=None, model_kwargs=None,
local_files_only=False,
) )
assert isinstance(model, SentenceTransformer) assert isinstance(model, SentenceTransformer)
@ -70,6 +71,7 @@ def test_load_sentence_transformer_onnx(model_name, similarity_func) -> None:
backend=STFRBackends.ONNX, backend=STFRBackends.ONNX,
device=STFRDeviceTypes.CPU, device=STFRDeviceTypes.CPU,
model_kwargs=stfr_model_args_onnx, # type: ignore model_kwargs=stfr_model_args_onnx, # type: ignore
local_files_only=False,
) )
assert isinstance(model, SentenceTransformer) assert isinstance(model, SentenceTransformer)