58 lines
1.5 KiB
Python
58 lines
1.5 KiB
Python
import os
|
|
from pathlib import Path
|
|
|
|
# NOTE(review): these LANG_MAIN_* variables are presumably read by `lang_main`
# when it is imported, which is why they are set before the `lang_main`
# imports below -- keep this ordering (verify against lang_main).
os.environ.update(
    {
        'LANG_MAIN_STOP_SEARCH_FOLDERNAME': 'tom-plugin',
        'LANG_MAIN_BASE_FOLDERNAME': 'tom-plugin',
    }
)
|
|
|
|
from lang_main.constants import SimilarityFunction
|
|
from lang_main.model_loader import load_sentence_transformer as load_stfr
|
|
from lang_main.types import (
|
|
ONNXExecutionProvider,
|
|
STFRBackends,
|
|
STFRDeviceTypes,
|
|
STFRModelArgs,
|
|
TorchDTypes,
|
|
)
|
|
|
|
# Model identifier (Hugging Face Hub style `org/name`) that `main` loads.
MODEL_NAME = 'mixedbread-ai/deepset-mxbai-embed-de-large-v1'

# Model kwargs used by `load_models(..., use_onnx=True)` with the ONNX backend.
# NOTE: alternatives previously tried here (kept as a note, not dead code):
#   - 'file_name': 'onnx/model_uint8.onnx'  (a different quantised export)
#   - 'torch_dtype': 'float32'
MODEL_ARGS: STFRModelArgs = {
    # Presumably: load an existing ONNX file instead of exporting one -- confirm.
    'export': False,
    # ONNX file to load from the model repository.
    'file_name': 'onnx/model_quantized.onnx',  # type: ignore
    # Run inference with the CPU execution provider.
    'provider': ONNXExecutionProvider.CPU,
}
|
|
|
|
# Folder that `load_models` saves/loads models in. Defaults to the original
# workstation location; set the LANG_MODELS_PATH environment variable to use
# a different folder (e.g. on another machine or OS).
MODEL_PATH = Path(os.environ.get('LANG_MODELS_PATH', r'A:\Arbeitsaufgaben\lang-models'))
|
|
|
|
|
|
def load_models(model_name: str, trust_remote: bool = False, use_onnx: bool = False):
    """Load a sentence-transformer model, storing it under ``MODEL_PATH``.

    Args:
        model_name: Model identifier forwarded to ``load_stfr``.
        trust_remote: Forwarded as ``trust_remote_code``; only enable this for
            model repositories you trust.
        use_onnx: If true, use the ONNX backend with a copy of ``MODEL_ARGS``;
            otherwise use the Torch backend with ``torch_dtype='float32'``.

    Returns:
        The object returned by ``load_stfr``.

    Raises:
        FileNotFoundError: If ``MODEL_PATH`` does not exist.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not MODEL_PATH.exists():
        raise FileNotFoundError(f'model path not existing: {MODEL_PATH}')

    if use_onnx:
        # Pass a copy so the loader can never mutate the shared module-level config.
        model_kwargs = MODEL_ARGS.copy()
        backend = STFRBackends.ONNX
    else:
        model_kwargs = {'torch_dtype': 'float32'}
        backend = STFRBackends.TORCH

    return load_stfr(
        model_name=model_name,  # type: ignore
        similarity_func=SimilarityFunction.COSINE,
        backend=backend,
        # Not restricted to already-available local files.
        local_files_only=False,
        trust_remote_code=trust_remote,
        model_save_folder=str(MODEL_PATH),
        model_kwargs=model_kwargs,
    )
|
|
|
|
|
|
def main():
    """Script entry point: load ``MODEL_NAME`` into ``MODEL_PATH``.

    Uses the Torch backend without remote code, i.e. the ``load_models``
    defaults. The returned model is discarded.
    """
    load_models(model_name=MODEL_NAME, trust_remote=False, use_onnx=False)


# Only run when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|