"""CLI helper to pre-download SentenceTransformer models and optionally
convert them to ONNX, with dynamic quantisation (AVX2/uint8) and/or
O3 graph optimisation."""

import argparse
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, cast

from sentence_transformers.backend import (
    export_dynamic_quantized_onnx_model,
    export_optimized_onnx_model,
)

from lang_main.constants import stfr_model_args_default
from lang_main.model_loader import (
    MODEL_BASE_FOLDER,
    STFR_BACKEND,
    STFR_DEVICE,
    STFR_MODEL_ARGS,
    STFR_MODEL_NAME,
    STFR_SIMILARITY,
    load_sentence_transformer,
)
from lang_main.types import (
    SentenceTransformer,
    STFRBackends,
    STFRModelArgs,
    STFRONNXFilenames,
)


@dataclass
class TypedArgumentParser:
    """Typed view of the parsed argparse namespace (used via ``cast`` in
    ``main`` so attribute access is checked statically)."""

    default: bool
    model: str
    path: Path | None
    convert: bool
    optim: bool
    quant: bool


def get_model_name_from_repo(
    full_model_name: str,
) -> str:
    """Return the bare model name from a Hugging Face repo id.

    Example: ``'sentence-transformers/all-MiniLM-L6-v2'`` ->
    ``'all-MiniLM-L6-v2'``. A name without a ``/`` is returned unchanged.
    """
    return full_model_name.split('/')[-1]


def _preload_STFR_model(
    model_name_repo: str,
    backend: STFRBackends,
    model_kwargs: STFRModelArgs | dict[str, Any] | None,
    target_folder: Path | str | None,
) -> SentenceTransformer:
    """Download and load a SentenceTransformer model.

    Always forces a fresh download (``force_download=True``,
    ``local_files_only=False``) so the helper can be used to refresh a
    local cache.

    Parameters
    ----------
    model_name_repo:
        Full Hugging Face repo id of the model.
    backend:
        Inference backend to load the model with (e.g. TORCH, ONNX).
    model_kwargs:
        Extra keyword arguments forwarded to the model constructor.
    target_folder:
        Folder to save the model under; ``None`` uses the loader default.
    """
    save_folder: str | None = None
    if target_folder is not None:
        save_folder = str(target_folder)
    return load_sentence_transformer(
        model_name=model_name_repo,
        similarity_func=STFR_SIMILARITY,
        backend=backend,
        device=STFR_DEVICE,
        model_kwargs=model_kwargs,
        model_save_folder=save_folder,
        local_files_only=False,
        force_download=True,
    )


def _load_config_STFR_model() -> None:
    """Pre-download the model specified by the package configuration
    (``STFR_MODEL_NAME`` / ``STFR_BACKEND`` / ``STFR_MODEL_ARGS``)."""
    _ = _preload_STFR_model(
        model_name_repo=STFR_MODEL_NAME,
        backend=STFR_BACKEND,
        model_kwargs=STFR_MODEL_ARGS,
        target_folder=None,
    )


def _model_conversion(
    model_name_repo: str,
    quant: bool,
    optimise: bool,
    target_folder: Path | None,
) -> None:
    """Convert a model to ONNX and optionally quantise/optimise it.

    The base (torch) model is downloaded first, then re-loaded with the
    ONNX backend and saved under ``<base>/converted/<model_name>``. When
    ``quant`` and/or ``optimise`` is set, the corresponding derived ONNX
    weights are exported and the plain ONNX base weights are deleted.

    Parameters
    ----------
    model_name_repo:
        Full Hugging Face repo id of the model.
    quant:
        Export a dynamically quantised ONNX model (AVX2 config).
    optimise:
        Export an O3-optimised ONNX model.
    target_folder:
        Base folder for downloads/exports; ``None`` uses
        ``MODEL_BASE_FOLDER``.
    """
    model_name = get_model_name_from_repo(model_name_repo)
    base_folder: Path = MODEL_BASE_FOLDER
    if target_folder is not None:
        base_folder = target_folder
    # avoid nesting 'converted/converted/...' when the caller already
    # points at a 'converted' folder
    if base_folder.stem == 'converted':
        export_folder = (base_folder / model_name).resolve()
    else:
        export_folder = (base_folder / 'converted' / model_name).resolve()
    # attempt to download base model if not present
    _ = _preload_STFR_model(
        model_name_repo=model_name_repo,
        backend=STFRBackends.TORCH,
        model_kwargs=stfr_model_args_default,
        target_folder=base_folder,
    )
    model_onnx = _preload_STFR_model(
        model_name_repo=model_name_repo,
        backend=STFRBackends.ONNX,
        model_kwargs=None,
        target_folder=base_folder,
    )
    model_onnx.save_pretrained(path=str(export_folder), safe_serialization=True)
    path_export_onnx_base = export_folder / 'onnx' / 'model.onnx'
    assert path_export_onnx_base.exists(), 'ONNX base weights not existing'
    print(f'Saved converted ONNX model under: {path_export_onnx_base}')

    remove_base_weights = False
    if quant:
        export_dynamic_quantized_onnx_model(
            model_onnx, quantization_config='avx2', model_name_or_path=str(export_folder)
        )
        path_export_onnx_quant = export_folder / STFRONNXFilenames.ONNX_Q_UINT8
        assert path_export_onnx_quant.exists(), 'ONNX quant weights not existing'
        print(f'Saved quantised ONNX model under: {path_export_onnx_quant}')
        remove_base_weights = True
    if optimise:
        export_optimized_onnx_model(
            model_onnx, optimization_config='O3', model_name_or_path=str(export_folder)
        )
        path_export_onnx_optim = export_folder / STFRONNXFilenames.ONNX_OPT_O3
        assert path_export_onnx_optim.exists(), 'ONNX optimised weights not existing'
        print(f'Saved optimised ONNX model under: {path_export_onnx_optim}')
        remove_base_weights = True
    if remove_base_weights:
        # BUGFIX: the base weights were previously removed inside BOTH
        # branches, so running with --quant AND --optim raised
        # FileNotFoundError on the second os.remove. Remove exactly once.
        os.remove(path_export_onnx_base)


def main() -> None:
    """Parse CLI arguments and dispatch to config preload or conversion.

    Raises
    ------
    ValueError
        If ``--default`` is combined with ``--convert``.
    FileNotFoundError / NotADirectoryError
        If ``--path`` does not point at an existing directory.
    """
    parser = argparse.ArgumentParser(
        prog='STFR-Model-Loader',
        description=(
            'Helper program to pre-download SentenceTransformer models '
            'and convert them to different formats if desired'
        ),
    )
    parser.add_argument(
        '-d', '--default', action='store_true', help='load model from default config'
    )
    parser.add_argument(
        '-m',
        '--model',
        default=STFR_MODEL_NAME,
        help='model to load (full repo name from Hugging Face Hub)',
    )
    parser.add_argument(
        '-p', '--path', type=Path, default=None, help='path to save models to'
    )
    parser.add_argument(
        '-c', '--convert', action='store_true', help='convert model to ONNX format'
    )
    parser.add_argument(
        '-o',
        '--optim',
        action='store_true',
        help=(
            # BUGFIX: help text previously contained a double comma
            'optimise ONNX model with O3 profile, model is '
            'always converted to ONNX beforehand'
        ),
    )
    parser.add_argument(
        '--quant',
        action='store_true',
        help=(
            'quantise model with "AVX2" configuration, model is always '
            'converted to ONNX beforehand'
        ),
    )
    args = cast(TypedArgumentParser, parser.parse_args())
    use_default_model = args.default
    convert_model = args.convert
    optimise_model = args.optim
    quantise_model = args.quant

    if use_default_model and convert_model:
        raise ValueError('Loading default model does not allow model conversion')

    path_models: Path | None = None
    if args.path is not None:
        path_models = args.path.resolve()
        # explicit raises instead of assert: assertions are stripped
        # under `python -O`, which would skip this user-input validation
        if not path_models.exists():
            raise FileNotFoundError('model saving path not existing')
        if not path_models.is_dir():
            raise NotADirectoryError('model saving path not a directory')

    if use_default_model:
        _load_config_STFR_model()
    else:
        _model_conversion(
            model_name_repo=args.model,
            quant=quantise_model,
            optimise=optimise_model,
            target_folder=path_models,
        )