tom-plugin/src/tom_plugin/_tools/_load_model.py
2025-01-23 12:05:13 +01:00

195 lines
5.8 KiB
Python

import argparse
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, cast
from lang_main.constants import stfr_model_args_default
from lang_main.model_loader import (
MODEL_BASE_FOLDER,
STFR_BACKEND,
STFR_DEVICE,
STFR_MODEL_ARGS,
STFR_MODEL_NAME,
STFR_SIMILARITY,
load_sentence_transformer,
)
from lang_main.types import (
SentenceTransformer,
STFRBackends,
STFRModelArgs,
STFRONNXFilenames,
)
from sentence_transformers.backend import (
export_dynamic_quantized_onnx_model,
export_optimized_onnx_model,
)
@dataclass
class TypedArgumentParser:
    """Typed view of the parsed CLI namespace.

    Used only as a `typing.cast` target for `parser.parse_args()` in `main()`
    so attribute access on the namespace is type-checked; it is never
    instantiated directly.
    """

    # load the model from the default config (`-d`/`--default`)
    default: bool
    # full Hugging Face Hub repo name of the model (`-m`/`--model`)
    model: str
    # optional folder to save models to (`-p`/`--path`); None -> library default
    path: Path | None
    # convert the model to ONNX format (`-c`/`--convert`)
    convert: bool
    # optimise the ONNX model with the O3 profile (`-o`/`--optim`)
    optim: bool
    # quantise the ONNX model with the AVX2 config (`--quant`)
    quant: bool
def get_model_name_from_repo(
    full_model_name: str,
) -> str:
    """Return the bare model name, i.e. the last ``/``-separated segment.

    A plain name without any ``/`` is returned unchanged, matching the
    behaviour of taking the last component of a Hub repo ID like
    ``org/model-name``.
    """
    _, _, bare_name = full_model_name.rpartition('/')
    # rpartition yields ('', '', full_model_name) when no '/' is present,
    # so the last element is always the desired name
    return bare_name
def _preload_STFR_model(
    model_name_repo: str,
    backend: STFRBackends,
    model_kwargs: STFRModelArgs | dict[str, Any] | None,
    target_folder: Path | str | None,
) -> SentenceTransformer:
    """Force-download and load a SentenceTransformer model.

    Parameters
    ----------
    model_name_repo:
        Full Hugging Face Hub repo name of the model.
    backend:
        Inference backend to load the model with (e.g. Torch or ONNX).
    model_kwargs:
        Extra keyword arguments forwarded to the model constructor, or None.
    target_folder:
        Folder to save the model under; None uses the library default.

    Returns
    -------
    SentenceTransformer
        The freshly downloaded model instance.
    """
    # the loader expects a plain string path (or None for the default location)
    save_folder: str | None = None if target_folder is None else str(target_folder)
    return load_sentence_transformer(
        model_name=model_name_repo,
        similarity_func=STFR_SIMILARITY,
        backend=backend,
        device=STFR_DEVICE,
        model_kwargs=model_kwargs,
        model_save_folder=save_folder,
        # always hit the Hub and re-download to guarantee fresh weights
        local_files_only=False,
        force_download=True,
    )
def _load_config_STFR_model() -> None:
    """Pre-download the model described by the default configuration.

    Uses the configured model name, backend and model kwargs; the model is
    saved to the library's default location and the instance is discarded —
    only the download side effect matters here.
    """
    _preload_STFR_model(
        model_name_repo=STFR_MODEL_NAME,
        backend=STFR_BACKEND,
        model_kwargs=STFR_MODEL_ARGS,
        target_folder=None,
    )
def _model_conversion(
    model_name_repo: str,
    quant: bool,
    optimise: bool,
    target_folder: Path | None,
) -> None:
    """Convert a Hub model to ONNX and optionally quantise/optimise it.

    Downloads the Torch and ONNX variants of the model, saves the ONNX export
    under ``<base>/converted/<model_name>``, then optionally produces a
    dynamically quantised (AVX2) and/or optimised (O3) ONNX file. When either
    derived file is produced, the plain base ONNX weights are deleted.

    Parameters
    ----------
    model_name_repo:
        Full Hugging Face Hub repo name of the model.
    quant:
        Produce a dynamically quantised (uint8, AVX2) ONNX model.
    optimise:
        Produce an O3-optimised ONNX model.
    target_folder:
        Base folder for downloads/exports; None uses ``MODEL_BASE_FOLDER``.

    Raises
    ------
    AssertionError
        If an expected exported weights file is missing afterwards.
    """
    model_name = get_model_name_from_repo(model_name_repo)
    base_folder: Path = MODEL_BASE_FOLDER if target_folder is None else target_folder
    # avoid nesting a second 'converted' level if the target already is one
    if base_folder.stem == 'converted':
        export_folder = (base_folder / model_name).resolve()
    else:
        export_folder = (base_folder / 'converted' / model_name).resolve()
    # attempt to download base model if not present
    _ = _preload_STFR_model(
        model_name_repo=model_name_repo,
        backend=STFRBackends.TORCH,
        model_kwargs=stfr_model_args_default,
        target_folder=base_folder,
    )
    model_onnx = _preload_STFR_model(
        model_name_repo=model_name_repo,
        backend=STFRBackends.ONNX,
        model_kwargs=None,
        target_folder=base_folder,
    )
    model_onnx.save_pretrained(path=str(export_folder), safe_serialization=True)
    path_export_onnx_base = export_folder / 'onnx' / 'model.onnx'
    assert path_export_onnx_base.exists(), 'ONNX base weights not existing'
    print(f'Saved converted ONNX model under: {path_export_onnx_base}')
    if quant:
        export_dynamic_quantized_onnx_model(
            model_onnx, quantization_config='avx2', model_name_or_path=str(export_folder)
        )
        path_export_onnx_quant = export_folder / STFRONNXFilenames.ONNX_Q_UINT8
        assert path_export_onnx_quant.exists(), 'ONNX quant weights not existing'
        print(f'Saved quantised ONNX model under: {path_export_onnx_quant}')
    if optimise:
        export_optimized_onnx_model(
            model_onnx, optimization_config='O3', model_name_or_path=str(export_folder)
        )
        path_export_onnx_optim = export_folder / STFRONNXFilenames.ONNX_OPT_O3
        assert path_export_onnx_optim.exists(), 'ONNX optimised weights not existing'
        print(f'Saved optimised ONNX model under: {path_export_onnx_optim}')
    # BUGFIX: the base weights were previously removed inside BOTH branches,
    # so requesting quant AND optimise together crashed on the second
    # os.remove with FileNotFoundError. Delete the file exactly once, after
    # all derived exports have been produced.
    if quant or optimise:
        path_export_onnx_base.unlink(missing_ok=True)
def main() -> None:
    """CLI entry point: pre-download a SentenceTransformer model and
    optionally convert it to ONNX with quantisation/optimisation.

    Raises
    ------
    ValueError
        If ``--default`` is combined with ``--convert``.
    AssertionError
        If ``--path`` does not exist or is not a directory.
    """
    parser = argparse.ArgumentParser(
        prog='STFR-Model-Loader',
        description=(
            'Helper program to pre-download SentenceTransformer models '
            'and convert them to different formats if desired'
        ),
    )
    parser.add_argument(
        '-d', '--default', action='store_true', help='load model from default config'
    )
    parser.add_argument(
        '-m',
        '--model',
        default=STFR_MODEL_NAME,
        help='model to load (full repo name from Hugging Face Hub)',
    )
    parser.add_argument(
        '-p', '--path', type=Path, default=None, help='path to save models to'
    )
    parser.add_argument(
        '-c', '--convert', action='store_true', help='convert model to ONNX format'
    )
    parser.add_argument(
        '-o',
        '--optim',
        action='store_true',
        # fixed typo: duplicated comma in the original help text
        help=(
            'optimise ONNX model with O3 profile, model is '
            'always converted to ONNX beforehand'
        ),
    )
    parser.add_argument(
        '--quant',
        action='store_true',
        help=(
            'quantise model with "AVX2" configuration, model is always '
            'converted to ONNX beforehand'
        ),
    )
    args = cast(TypedArgumentParser, parser.parse_args())
    use_default_model = args.default
    convert_model = args.convert
    optimise_model = args.optim
    quantise_model = args.quant
    if use_default_model and convert_model:
        raise ValueError('Loading default model does not allow model conversion')
    path_models: Path | None = None
    if args.path is not None:
        path_models = args.path.resolve()
        assert path_models.exists(), 'model saving path not existing'
        assert path_models.is_dir(), 'model saving path not a directory'
    if use_default_model:
        _load_config_STFR_model()
    else:
        # NOTE(review): `--convert` currently has no effect here — conversion
        # always runs on the non-default path; confirm whether that is intended.
        _model_conversion(
            model_name_repo=args.model,
            quant=quantise_model,
            optimise=optimise_model,
            target_folder=path_models,
        )