import os
from pathlib import Path

# Configure lang_main's folder discovery before importing its modules.
os.environ['LANG_MAIN_STOP_SEARCH_FOLDERNAME'] = 'tom-plugin'
os.environ['LANG_MAIN_BASE_FOLDERNAME'] = 'tom-plugin'

from lang_main.constants import SimilarityFunction
from lang_main.model_loader import load_sentence_transformer as load_stfr
from lang_main.types import (
    ONNXExecutionProvider,
    STFRBackends,
    STFRDeviceTypes,
    STFRModelArgs,
    TorchDTypes,
)

MODEL_NAME = 'mixedbread-ai/deepset-mxbai-embed-de-large-v1'
# Model arguments for the ONNX backend: load the quantized export and run it
# on the CPU execution provider.
MODEL_ARGS: STFRModelArgs = {
    # 'torch_dtype': 'float32',
    'export': False,
    # 'file_name': 'onnx/model_uint8.onnx',  # type: ignore
    'file_name': 'onnx/model_quantized.onnx',  # type: ignore
    'provider': ONNXExecutionProvider.CPU,
}
MODEL_PATH = Path(r'A:\Arbeitsaufgaben\lang-models')


def load_models(model_name: str, trust_remote: bool = False, use_onnx: bool = False):
    """Load a sentence-transformer model via lang_main, using the ONNX or Torch backend."""
    assert MODEL_PATH.exists(), 'model path does not exist'
    if use_onnx:
        model_kwargs = MODEL_ARGS
        backend = STFRBackends.ONNX
    else:
        model_kwargs = {'torch_dtype': 'float32'}
        backend = STFRBackends.TORCH
    stfr_model = load_stfr(
        model_name=model_name,  # type: ignore
        similarity_func=SimilarityFunction.COSINE,
        backend=backend,
        local_files_only=False,
        trust_remote_code=trust_remote,
        model_save_folder=str(MODEL_PATH),
        model_kwargs=model_kwargs,
    )
    return stfr_model


def main():
    load_models(MODEL_NAME)


if __name__ == '__main__':
    main()
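
# Usage sketch (assumption: load_stfr returns a sentence-transformers compatible
# model exposing .encode(), and the quantized ONNX file exists under MODEL_PATH):
#
#     model = load_models(MODEL_NAME, use_onnx=True)
#     embeddings = model.encode(['Ein kurzer Beispielsatz.'])
#     print(embeddings.shape)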