initial commit V0.1.0
This commit is contained in:
3
src/tom_plugin/__init__.py
Normal file
3
src/tom_plugin/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from tom_plugin import _env_vars as env
|
||||
|
||||
env.set()
|
||||
46
src/tom_plugin/_env_vars.py
Normal file
46
src/tom_plugin/_env_vars.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import os
|
||||
from typing import Final
|
||||
|
||||
|
||||
# ** ENV VARS
|
||||
def set() -> None:
|
||||
library_mode = os.environ.get('DOPT_TOM_PLUGIN_LIBRARY_USAGE', None)
|
||||
LIBRARY_MODE: Final[bool] = bool(library_mode)
|
||||
|
||||
if LIBRARY_MODE:
|
||||
_set_lib_mode()
|
||||
else:
|
||||
_set_app_mode(
|
||||
spacy_model=None,
|
||||
STFR_model=None,
|
||||
)
|
||||
|
||||
|
||||
def _set_lib_mode() -> None:
|
||||
os.environ['LANG_MAIN_STFR_BACKEND'] = 'onnx'
|
||||
|
||||
|
||||
def _set_app_mode(
|
||||
spacy_model: str | None = None,
|
||||
STFR_model: str | None = None,
|
||||
) -> None:
|
||||
os.environ['LANG_MAIN_STFR_BACKEND'] = 'onnx'
|
||||
os.environ['LANG_MAIN_STOP_SEARCH_FOLDERNAME'] = 'tom-plugin'
|
||||
os.environ['LANG_MAIN_BASE_FOLDERNAME'] = 'tom-plugin'
|
||||
|
||||
if spacy_model is not None:
|
||||
_set_spacy_model(spacy_model)
|
||||
if STFR_model is not None:
|
||||
_set_STFR_model(STFR_model)
|
||||
|
||||
|
||||
def _set_spacy_model(
|
||||
model_name: str = 'de_core_news_md',
|
||||
) -> None:
|
||||
os.environ['LANG_MAIN_SPACY_MODEL'] = model_name
|
||||
|
||||
|
||||
def _set_STFR_model(
|
||||
model_name: str = 'all-mpnet-base-v2',
|
||||
) -> None:
|
||||
os.environ['LANG_MAIN_SPACY_MODEL'] = model_name
|
||||
0
src/tom_plugin/_tools/__init__.py
Normal file
0
src/tom_plugin/_tools/__init__.py
Normal file
194
src/tom_plugin/_tools/_load_model.py
Normal file
194
src/tom_plugin/_tools/_load_model.py
Normal file
@@ -0,0 +1,194 @@
|
||||
import argparse
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
from lang_main.constants import stfr_model_args_default
|
||||
from lang_main.model_loader import (
|
||||
MODEL_BASE_FOLDER,
|
||||
STFR_BACKEND,
|
||||
STFR_DEVICE,
|
||||
STFR_MODEL_ARGS,
|
||||
STFR_MODEL_NAME,
|
||||
STFR_SIMILARITY,
|
||||
load_sentence_transformer,
|
||||
)
|
||||
from lang_main.types import (
|
||||
SentenceTransformer,
|
||||
STFRBackends,
|
||||
STFRModelArgs,
|
||||
STFRONNXFilenames,
|
||||
)
|
||||
from sentence_transformers.backend import (
|
||||
export_dynamic_quantized_onnx_model,
|
||||
export_optimized_onnx_model,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TypedArgumentParser:
|
||||
default: bool
|
||||
model: str
|
||||
path: Path | None
|
||||
convert: bool
|
||||
optim: bool
|
||||
quant: bool
|
||||
|
||||
|
||||
def get_model_name_from_repo(
|
||||
full_model_name: str,
|
||||
) -> str:
|
||||
return full_model_name.split('/')[-1]
|
||||
|
||||
|
||||
def _preload_STFR_model(
|
||||
model_name_repo: str,
|
||||
backend: STFRBackends,
|
||||
model_kwargs: STFRModelArgs | dict[str, Any] | None,
|
||||
target_folder: Path | str | None,
|
||||
) -> SentenceTransformer:
|
||||
save_folder: str | None = None
|
||||
if target_folder is not None:
|
||||
save_folder = str(target_folder)
|
||||
|
||||
return load_sentence_transformer(
|
||||
model_name=model_name_repo,
|
||||
similarity_func=STFR_SIMILARITY,
|
||||
backend=backend,
|
||||
device=STFR_DEVICE,
|
||||
model_kwargs=model_kwargs,
|
||||
model_save_folder=save_folder,
|
||||
local_files_only=False,
|
||||
force_download=True,
|
||||
)
|
||||
|
||||
|
||||
def _load_config_STFR_model() -> None:
|
||||
_ = _preload_STFR_model(
|
||||
model_name_repo=STFR_MODEL_NAME,
|
||||
backend=STFR_BACKEND,
|
||||
model_kwargs=STFR_MODEL_ARGS,
|
||||
target_folder=None,
|
||||
)
|
||||
|
||||
|
||||
def _model_conversion(
|
||||
model_name_repo: str,
|
||||
quant: bool,
|
||||
optimise: bool,
|
||||
target_folder: Path | None,
|
||||
) -> None:
|
||||
model_name = get_model_name_from_repo(model_name_repo)
|
||||
base_folder: Path = MODEL_BASE_FOLDER
|
||||
if target_folder is not None:
|
||||
base_folder = target_folder
|
||||
|
||||
if base_folder.stem == 'converted':
|
||||
export_folder = (base_folder / model_name).resolve()
|
||||
else:
|
||||
export_folder = (base_folder / 'converted' / model_name).resolve()
|
||||
|
||||
# attempt to download base model if not present
|
||||
_ = _preload_STFR_model(
|
||||
model_name_repo=model_name_repo,
|
||||
backend=STFRBackends.TORCH,
|
||||
model_kwargs=stfr_model_args_default,
|
||||
target_folder=base_folder,
|
||||
)
|
||||
|
||||
model_onnx = _preload_STFR_model(
|
||||
model_name_repo=model_name_repo,
|
||||
backend=STFRBackends.ONNX,
|
||||
model_kwargs=None,
|
||||
target_folder=base_folder,
|
||||
)
|
||||
model_onnx.save_pretrained(path=str(export_folder), safe_serialization=True)
|
||||
path_export_onnx_base = export_folder / 'onnx' / 'model.onnx'
|
||||
assert path_export_onnx_base.exists(), 'ONNX base weights not existing'
|
||||
print(f'Saved converted ONNX model under: {path_export_onnx_base}')
|
||||
|
||||
if quant:
|
||||
export_dynamic_quantized_onnx_model(
|
||||
model_onnx, quantization_config='avx2', model_name_or_path=str(export_folder)
|
||||
)
|
||||
path_export_onnx_quant = export_folder / STFRONNXFilenames.ONNX_Q_UINT8
|
||||
assert path_export_onnx_quant.exists(), 'ONNX quant weights not existing'
|
||||
print(f'Saved quantised ONNX model under: {path_export_onnx_quant}')
|
||||
os.remove(path_export_onnx_base)
|
||||
if optimise:
|
||||
export_optimized_onnx_model(
|
||||
model_onnx, optimization_config='O3', model_name_or_path=str(export_folder)
|
||||
)
|
||||
path_export_onnx_optim = export_folder / STFRONNXFilenames.ONNX_OPT_O3
|
||||
assert path_export_onnx_optim.exists(), 'ONNX optimised weights not existing'
|
||||
print(f'Saved optimised ONNX model under: {path_export_onnx_optim}')
|
||||
os.remove(path_export_onnx_base)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
prog='STFR-Model-Loader',
|
||||
description=(
|
||||
'Helper program to pre-download SentenceTransformer models '
|
||||
'and convert them to different formats if desired'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
'-d', '--default', action='store_true', help='load model from default config'
|
||||
)
|
||||
parser.add_argument(
|
||||
'-m',
|
||||
'--model',
|
||||
default=STFR_MODEL_NAME,
|
||||
help='model to load (full repo name from Hugging Face Hub)',
|
||||
)
|
||||
parser.add_argument(
|
||||
'-p', '--path', type=Path, default=None, help='path to save models to'
|
||||
)
|
||||
parser.add_argument(
|
||||
'-c', '--convert', action='store_true', help='convert model to ONNX format'
|
||||
)
|
||||
parser.add_argument(
|
||||
'-o',
|
||||
'--optim',
|
||||
action='store_true',
|
||||
help=(
|
||||
'optimise ONNX model with O3 profile, , model is '
|
||||
'always converted to ONNX beforehand'
|
||||
),
|
||||
)
|
||||
# parser.add_argument('--onnx', action='store_true', help='use ONNX backend')
|
||||
parser.add_argument(
|
||||
'--quant',
|
||||
action='store_true',
|
||||
help=(
|
||||
'quantise model with "AVX2" configuration, model is always '
|
||||
'converted to ONNX beforehand'
|
||||
),
|
||||
)
|
||||
|
||||
args = cast(TypedArgumentParser, parser.parse_args())
|
||||
use_default_model = args.default
|
||||
convert_model = args.convert
|
||||
optimise_model = args.optim
|
||||
quantise_model = args.quant
|
||||
|
||||
if use_default_model and convert_model:
|
||||
raise ValueError('Loading default model does not allow model conversion')
|
||||
|
||||
path_models: Path | None = None
|
||||
if args.path is not None:
|
||||
path_models = args.path.resolve()
|
||||
assert path_models.exists(), 'model saving path not existing'
|
||||
assert path_models.is_dir(), 'model saving path not a directory'
|
||||
|
||||
if args.default:
|
||||
_load_config_STFR_model()
|
||||
else:
|
||||
_model_conversion(
|
||||
model_name_repo=args.model,
|
||||
quant=quantise_model,
|
||||
optimise=optimise_model,
|
||||
target_folder=path_models,
|
||||
)
|
||||
29
src/tom_plugin/_tools/_run.py
Normal file
29
src/tom_plugin/_tools/_run.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import argparse
|
||||
import time
|
||||
from datetime import timedelta
|
||||
|
||||
from tom_plugin.pipeline import run_on_csv_data
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
prog='TOM-Plugin-Demo-Runner',
|
||||
description='integration testing of provided pipelines in TOM-Plugin',
|
||||
)
|
||||
subparsers = parser.add_subparsers(dest='subparser')
|
||||
parser_csv = subparsers.add_parser('csv', help='run on CSV data')
|
||||
parser_csv.add_argument('id', help='ID for data set')
|
||||
parser_csv.add_argument('filename', help='filename from configured input directory')
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.subparser == 'csv':
|
||||
t1 = time.perf_counter()
|
||||
run_on_csv_data(args.id, args.filename)
|
||||
t2 = time.perf_counter()
|
||||
run_time = t2 - t1
|
||||
td = timedelta(seconds=run_time)
|
||||
print(f'Application runtime was. {td}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
2
src/tom_plugin/env_vars.txt
Normal file
2
src/tom_plugin/env_vars.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
# list of all library's environment variables
|
||||
DOPT_TOM_PLUGIN_LIBRARY_USAGE : indicate that this wrapper application is in library mode (used to set different environment variables)
|
||||
276
src/tom_plugin/pipeline.py
Normal file
276
src/tom_plugin/pipeline.py
Normal file
@@ -0,0 +1,276 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import typing
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
from lang_main.analysis.graphs import (
|
||||
Graph,
|
||||
TokenGraph,
|
||||
save_to_GraphML,
|
||||
)
|
||||
from lang_main.constants import (
|
||||
CYTO_BASE_NETWORK_NAME,
|
||||
INPUT_PATH_FOLDER,
|
||||
SAVE_PATH_FOLDER,
|
||||
SKIP_GRAPH_POSTPROCESSING,
|
||||
SKIP_GRAPH_RESCALING,
|
||||
SKIP_GRAPH_STATIC_RENDERING,
|
||||
SKIP_PREPROCESSING,
|
||||
SKIP_TIME_ANALYSIS,
|
||||
SKIP_TOKEN_ANALYSIS,
|
||||
)
|
||||
from lang_main.errors import DependencyMissingError
|
||||
from lang_main.io import create_saving_folder, get_entry_point, load_pickle
|
||||
from lang_main.pipelines.base import Pipeline
|
||||
from lang_main.pipelines.predefined import (
|
||||
build_base_target_feature_pipe,
|
||||
build_merge_duplicates_pipe,
|
||||
build_timeline_pipe,
|
||||
build_tk_graph_pipe,
|
||||
build_tk_graph_post_pipe,
|
||||
build_tk_graph_render_pipe,
|
||||
build_tk_graph_rescaling_pipe,
|
||||
)
|
||||
from lang_main.types import (
|
||||
EntryPoints,
|
||||
ObjectID,
|
||||
PandasIndex,
|
||||
SpacyDoc,
|
||||
TimelineCandidates,
|
||||
)
|
||||
from pandas import DataFrame
|
||||
|
||||
# ** build pipelines
|
||||
pipe_target_feat_on_csv = build_base_target_feature_pipe()
|
||||
pipe_merge = build_merge_duplicates_pipe()
|
||||
pipe_token_analysis = build_tk_graph_pipe()
|
||||
pipe_graph_postprocessing = build_tk_graph_post_pipe()
|
||||
pipe_graph_rescaling = build_tk_graph_rescaling_pipe(
|
||||
save_result=True,
|
||||
exit_point=EntryPoints.TK_GRAPH_ANALYSIS_RESCALED,
|
||||
)
|
||||
pipe_timeline = build_timeline_pipe()
|
||||
|
||||
pipe_static_graph_rendering: Pipeline | None = None
|
||||
# rendering depending on optional dependencies
|
||||
try:
|
||||
pipe_static_graph_rendering = build_tk_graph_render_pipe(
|
||||
with_subgraphs=True,
|
||||
base_network_name=CYTO_BASE_NETWORK_NAME,
|
||||
)
|
||||
except (ImportError, DependencyMissingError):
|
||||
pass
|
||||
|
||||
|
||||
all_pipes: tuple[Pipeline | None, ...] = (
|
||||
pipe_target_feat_on_csv,
|
||||
pipe_merge,
|
||||
pipe_token_analysis,
|
||||
pipe_graph_postprocessing,
|
||||
pipe_graph_rescaling,
|
||||
pipe_static_graph_rendering,
|
||||
pipe_timeline,
|
||||
)
|
||||
|
||||
|
||||
# ENV variable: LANG_MAIN_SAVE_FOLDER : path for saving folder of current run
|
||||
# ENV variable: LANG_MAIN_INPUT_DATA : path for input data of current run
|
||||
def get_save_folder() -> Path:
|
||||
save_folder_env = os.environ.get('LANG_MAIN_SAVE_FOLDER', None)
|
||||
assert save_folder_env is not None, 'saving folder not defined as ENV variable'
|
||||
save_folder = Path(save_folder_env)
|
||||
assert save_folder.exists(), 'save folder does not exist'
|
||||
|
||||
return save_folder
|
||||
|
||||
|
||||
def get_path_to_dataset() -> Path:
|
||||
data_pth_env = os.environ.get('LANG_MAIN_INPUT_DATA', None)
|
||||
assert data_pth_env is not None, 'path to dataset not defined as ENV variable'
|
||||
data_pth = Path(data_pth_env)
|
||||
assert data_pth.exists(), 'path to dataset does not exist'
|
||||
|
||||
return data_pth
|
||||
|
||||
|
||||
def _set_save_folder(
|
||||
target_folder: Path,
|
||||
) -> None:
|
||||
# save_folder = get_save_folder()
|
||||
|
||||
for pipe in all_pipes:
|
||||
if pipe is not None:
|
||||
pipe.working_dir = target_folder
|
||||
|
||||
|
||||
# ** preparation
|
||||
def _prepare_run_on_csv(
|
||||
id: str,
|
||||
filename: str,
|
||||
) -> tuple[Path, Path]:
|
||||
# output directory for intermediate results
|
||||
print(f'Saving path: {SAVE_PATH_FOLDER}', flush=True)
|
||||
target_folder = SAVE_PATH_FOLDER / id
|
||||
create_saving_folder(
|
||||
saving_path_folder=target_folder,
|
||||
overwrite_existing=True,
|
||||
)
|
||||
assert target_folder.exists(), 'target folder not existing after creation'
|
||||
# data set
|
||||
data_pth = (INPUT_PATH_FOLDER / filename).with_suffix('.csv')
|
||||
|
||||
assert data_pth.exists(), 'path to data not existing'
|
||||
assert data_pth.is_file(), 'data is not a file'
|
||||
print(f'Data path: {data_pth}', flush=True)
|
||||
|
||||
return target_folder, data_pth
|
||||
|
||||
|
||||
# ** preprocessing pipeline
|
||||
def _run_preprocessing_on_csv(
|
||||
target_folder: Path,
|
||||
data_pth: Path,
|
||||
) -> Path:
|
||||
# data_pth = get_path_to_dataset()
|
||||
# run pipelines
|
||||
ret = typing.cast(
|
||||
tuple[DataFrame], pipe_target_feat_on_csv.run(starting_values=(data_pth,))
|
||||
)
|
||||
target_feat_data = ret[0]
|
||||
_ = typing.cast(tuple[DataFrame], pipe_merge.run(starting_values=(target_feat_data,)))
|
||||
|
||||
return target_folder
|
||||
|
||||
|
||||
# ** token analysis
|
||||
def _run_token_analysis(
|
||||
target_folder: Path,
|
||||
) -> Path:
|
||||
# load entry point
|
||||
# save_folder = get_save_folder()
|
||||
entry_point_path = get_entry_point(target_folder, EntryPoints.TOKEN_ANALYSIS)
|
||||
loaded_results = cast(tuple[DataFrame], load_pickle(entry_point_path))
|
||||
preprocessed_data = loaded_results[0]
|
||||
# build token graph
|
||||
(tk_graph, _) = typing.cast(
|
||||
tuple[TokenGraph, dict[PandasIndex, SpacyDoc] | None],
|
||||
pipe_token_analysis.run(starting_values=(preprocessed_data,)),
|
||||
)
|
||||
tk_graph.to_GraphML(target_folder, filename='TokenGraph', directed=False)
|
||||
|
||||
return target_folder
|
||||
|
||||
|
||||
def _run_graph_postprocessing(
|
||||
target_folder: Path,
|
||||
) -> Path:
|
||||
# load entry point
|
||||
# save_folder = get_save_folder()
|
||||
entry_point_path = get_entry_point(target_folder, EntryPoints.TK_GRAPH_POST)
|
||||
loaded_results = cast(
|
||||
tuple[TokenGraph, dict[PandasIndex, SpacyDoc] | None],
|
||||
load_pickle(entry_point_path),
|
||||
)
|
||||
tk_graph = loaded_results[0]
|
||||
# filter graph by edge weight and remove single nodes (no connection)
|
||||
ret = cast(tuple[TokenGraph], pipe_graph_postprocessing.run(starting_values=(tk_graph,)))
|
||||
tk_graph_filtered = ret[0]
|
||||
tk_graph_filtered.to_GraphML(
|
||||
target_folder, filename='TokenGraph-filtered', directed=False
|
||||
)
|
||||
|
||||
return target_folder
|
||||
|
||||
|
||||
def _run_graph_edge_rescaling(
|
||||
target_folder: Path,
|
||||
) -> Path:
|
||||
# load entry point
|
||||
# save_folder = get_save_folder()
|
||||
entry_point_path = get_entry_point(target_folder, EntryPoints.TK_GRAPH_ANALYSIS)
|
||||
loaded_results = cast(
|
||||
tuple[TokenGraph],
|
||||
load_pickle(entry_point_path),
|
||||
)
|
||||
tk_graph = loaded_results[0]
|
||||
tk_graph_rescaled, tk_graph_rescaled_undirected = cast(
|
||||
tuple[TokenGraph, Graph], pipe_graph_rescaling.run(starting_values=(tk_graph,))
|
||||
)
|
||||
tk_graph_rescaled.to_GraphML(
|
||||
target_folder, filename='TokenGraph-directed-rescaled', directed=False
|
||||
)
|
||||
save_to_GraphML(
|
||||
tk_graph_rescaled_undirected,
|
||||
saving_path=target_folder,
|
||||
filename='TokenGraph-undirected-rescaled',
|
||||
)
|
||||
|
||||
return target_folder
|
||||
|
||||
|
||||
def _run_static_graph_rendering(
|
||||
target_folder: Path,
|
||||
) -> Path:
|
||||
# load entry point
|
||||
# save_folder = get_save_folder()
|
||||
entry_point_path = get_entry_point(
|
||||
target_folder,
|
||||
EntryPoints.TK_GRAPH_ANALYSIS_RESCALED,
|
||||
)
|
||||
loaded_results = cast(
|
||||
tuple[TokenGraph, Graph],
|
||||
load_pickle(entry_point_path),
|
||||
)
|
||||
_ = loaded_results[0]
|
||||
tk_graph_rescaled_undirected = loaded_results[1]
|
||||
|
||||
if pipe_static_graph_rendering is not None:
|
||||
_ = pipe_static_graph_rendering.run(starting_values=(tk_graph_rescaled_undirected,))
|
||||
|
||||
return target_folder
|
||||
|
||||
|
||||
# ** time analysis
|
||||
def _run_time_analysis(
|
||||
target_folder: Path,
|
||||
) -> Path:
|
||||
# load entry point
|
||||
# save_folder = get_save_folder()
|
||||
entry_point_path = get_entry_point(target_folder, EntryPoints.TIMELINE)
|
||||
loaded_results = cast(tuple[DataFrame], load_pickle(entry_point_path))
|
||||
preprocessed_data = loaded_results[0]
|
||||
|
||||
_ = cast(
|
||||
tuple[TimelineCandidates, dict[ObjectID, str]],
|
||||
pipe_timeline.run(starting_values=(preprocessed_data,)),
|
||||
)
|
||||
|
||||
return target_folder
|
||||
|
||||
|
||||
def _build_pipeline_container(
|
||||
target_folder: Path,
|
||||
) -> Pipeline:
|
||||
# save_folder = get_save_folder()
|
||||
# container = PipelineContainer(name='Pipeline-Container-Base', working_dir=target_folder)
|
||||
container = Pipeline(name='Pipeline-Base', working_dir=target_folder)
|
||||
container.add(_run_preprocessing_on_csv, skip=SKIP_PREPROCESSING)
|
||||
container.add(_run_token_analysis, skip=SKIP_TOKEN_ANALYSIS)
|
||||
container.add(_run_graph_postprocessing, skip=SKIP_GRAPH_POSTPROCESSING)
|
||||
container.add(_run_graph_edge_rescaling, skip=SKIP_GRAPH_RESCALING)
|
||||
container.add(_run_static_graph_rendering, skip=SKIP_GRAPH_STATIC_RENDERING)
|
||||
container.add(_run_time_analysis, skip=SKIP_TIME_ANALYSIS)
|
||||
|
||||
return container
|
||||
|
||||
|
||||
def run_on_csv_data(
|
||||
id: str,
|
||||
filename: str,
|
||||
) -> None:
|
||||
target_folder, data_pth = _prepare_run_on_csv(id=id, filename=filename)
|
||||
_set_save_folder(target_folder)
|
||||
procedure = _build_pipeline_container(target_folder)
|
||||
procedure.run(starting_values=(target_folder, data_pth))
|
||||
Reference in New Issue
Block a user