import argparse
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, cast

from lang_main.constants import stfr_model_args_default
from lang_main.model_loader import (
    MODEL_BASE_FOLDER,
    STFR_BACKEND,
    STFR_DEVICE,
    STFR_MODEL_ARGS,
    STFR_MODEL_NAME,
    STFR_SIMILARITY,
    load_sentence_transformer,
)
from lang_main.types import (
    SentenceTransformer,
    STFRBackends,
    STFRModelArgs,
    STFRONNXFilenames,
)
from sentence_transformers.backend import (
    export_dynamic_quantized_onnx_model,
    export_optimized_onnx_model,
)


@dataclass
class TypedArgumentParser:
    """Typed view of the parsed command line arguments."""

    default: bool
    model: str
    path: Path | None
    convert: bool
    optim: bool
    quant: bool


def get_model_name_from_repo(
    full_model_name: str,
) -> str:
    """Return the bare model name from a Hugging Face repo ID, e.g. 'org/model' -> 'model'."""
    return full_model_name.split('/')[-1]


def _preload_STFR_model(
    model_name_repo: str,
    backend: STFRBackends,
    model_kwargs: STFRModelArgs | dict[str, Any] | None,
    target_folder: Path | str | None,
) -> SentenceTransformer:
    """Download and load a SentenceTransformer model with the given backend."""
    save_folder: str | None = None
    if target_folder is not None:
        save_folder = str(target_folder)

    return load_sentence_transformer(
        model_name=model_name_repo,
        similarity_func=STFR_SIMILARITY,
        backend=backend,
        device=STFR_DEVICE,
        model_kwargs=model_kwargs,
        model_save_folder=save_folder,
        local_files_only=False,
        force_download=True,
    )


def _load_config_STFR_model() -> None:
    """Pre-download the SentenceTransformer model defined in the default config."""
    _ = _preload_STFR_model(
        model_name_repo=STFR_MODEL_NAME,
        backend=STFR_BACKEND,
        model_kwargs=STFR_MODEL_ARGS,
        target_folder=None,
    )


def _model_conversion(
    model_name_repo: str,
    quant: bool,
    optimise: bool,
    target_folder: Path | None,
) -> None:
    """Export a model to ONNX and optionally quantise and/or optimise the exported weights."""
    model_name = get_model_name_from_repo(model_name_repo)
    base_folder: Path = MODEL_BASE_FOLDER
    if target_folder is not None:
        base_folder = target_folder

    # place exports under '<base>/converted/<model_name>' without doubling up
    # if the base folder already points at a 'converted' directory
    if base_folder.stem == 'converted':
        export_folder = (base_folder / model_name).resolve()
    else:
        export_folder = (base_folder / 'converted' / model_name).resolve()

    # attempt to download base model if not present
    _ = _preload_STFR_model(
        model_name_repo=model_name_repo,
        backend=STFRBackends.TORCH,
        model_kwargs=stfr_model_args_default,
        target_folder=base_folder,
    )

    # reload with the ONNX backend and save the converted weights into the export folder
    model_onnx = _preload_STFR_model(
        model_name_repo=model_name_repo,
        backend=STFRBackends.ONNX,
        model_kwargs=None,
        target_folder=base_folder,
    )
    model_onnx.save_pretrained(path=str(export_folder), safe_serialization=True)
    path_export_onnx_base = export_folder / 'onnx' / 'model.onnx'
    assert path_export_onnx_base.exists(), 'ONNX base weights not existing'
    print(f'Saved converted ONNX model under: {path_export_onnx_base}')

    if quant:
        # dynamic uint8 quantisation targeting AVX2-capable CPUs
        export_dynamic_quantized_onnx_model(
            model_onnx, quantization_config='avx2', model_name_or_path=str(export_folder)
        )
        path_export_onnx_quant = export_folder / STFRONNXFilenames.ONNX_Q_UINT8
        assert path_export_onnx_quant.exists(), 'ONNX quant weights not existing'
        print(f'Saved quantised ONNX model under: {path_export_onnx_quant}')
        os.remove(path_export_onnx_base)
    if optimise:
        # graph optimisation with the O3 profile
        export_optimized_onnx_model(
            model_onnx, optimization_config='O3', model_name_or_path=str(export_folder)
        )
        path_export_onnx_optim = export_folder / STFRONNXFilenames.ONNX_OPT_O3
        assert path_export_onnx_optim.exists(), 'ONNX optimised weights not existing'
        print(f'Saved optimised ONNX model under: {path_export_onnx_optim}')
        # the base weights may already have been removed by the quantisation step above
        if path_export_onnx_base.exists():
            os.remove(path_export_onnx_base)


def main() -> None:
    parser = argparse.ArgumentParser(
        prog='STFR-Model-Loader',
        description=(
            'Helper program to pre-download SentenceTransformer models '
            'and convert them to different formats if desired'
        ),
    )
    parser.add_argument(
        '-d', '--default', action='store_true', help='load model from default config'
    )
    parser.add_argument(
        '-m',
        '--model',
        default=STFR_MODEL_NAME,
        help='model to load (full repo name from Hugging Face Hub)',
    )
    parser.add_argument(
        '-p', '--path', type=Path, default=None, help='path to save models to'
    )
    parser.add_argument(
        '-c', '--convert', action='store_true', help='convert model to ONNX format'
    )
    parser.add_argument(
        '-o',
        '--optim',
        action='store_true',
        help=(
            'optimise ONNX model with O3 profile, model is '
            'always converted to ONNX beforehand'
        ),
    )
    # parser.add_argument('--onnx', action='store_true', help='use ONNX backend')
    parser.add_argument(
        '--quant',
        action='store_true',
        help=(
            'quantise model with "AVX2" configuration, model is always '
            'converted to ONNX beforehand'
        ),
    )

    args = cast(TypedArgumentParser, parser.parse_args())
    use_default_model = args.default
    convert_model = args.convert
    optimise_model = args.optim
    quantise_model = args.quant

    if use_default_model and convert_model:
        raise ValueError('Loading default model does not allow model conversion')

    path_models: Path | None = None
    if args.path is not None:
        path_models = args.path.resolve()
        assert path_models.exists(), 'model saving path not existing'
        assert path_models.is_dir(), 'model saving path not a directory'

    if use_default_model:
        _load_config_STFR_model()
    else:
        _model_conversion(
            model_name_repo=args.model,
            quant=quantise_model,
            optimise=optimise_model,
            target_folder=path_models,
        )
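

# Entry-point guard so the helper can be run directly (assumed; standard pattern for argparse scripts).
if __name__ == '__main__':
    main()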