improved imports, dummy dataset generation
This commit is contained in:
@@ -18,7 +18,7 @@ p4c.py4cytoscape_logger.detail_logger.addHandler(logging.NullHandler())
|
||||
|
||||
# ** lang-main config
|
||||
logging.Formatter.converter = gmtime
|
||||
LOG_FMT: Final[str] = '%(asctime)s | %(module)s:%(levelname)s | %(message)s'
|
||||
LOG_FMT: Final[str] = '%(asctime)s | lang_main:%(module)s:%(levelname)s | %(message)s'
|
||||
LOG_DATE_FMT: Final[str] = '%Y-%m-%d %H:%M:%S +0000'
|
||||
logging.basicConfig(
|
||||
stream=sys.stdout,
|
||||
|
||||
@@ -70,7 +70,7 @@ def load_raw_data(
|
||||
filepath_or_buffer=path,
|
||||
sep=';',
|
||||
encoding='cp1252',
|
||||
parse_dates=date_cols,
|
||||
parse_dates=list(date_cols),
|
||||
dayfirst=True,
|
||||
)
|
||||
logger.info('Loaded dataset successfully.')
|
||||
@@ -278,7 +278,8 @@ def merge_similarity_dupl(
|
||||
return (merged_data,)
|
||||
|
||||
|
||||
#####################################################################
|
||||
# ** #################################################################################
|
||||
# TODO check removal
|
||||
def build_embedding_map(
|
||||
data: Series,
|
||||
model: GermanSpacyModel | SentenceTransformer,
|
||||
|
||||
@@ -8,10 +8,13 @@ from tqdm.auto import tqdm # TODO: check deletion
|
||||
from lang_main.analysis.shared import (
|
||||
candidates_by_index,
|
||||
entry_wise_cleansing,
|
||||
pattern_escape_seq_sentences,
|
||||
similar_index_connection_graph,
|
||||
similar_index_groups,
|
||||
)
|
||||
from lang_main.constants import (
|
||||
NAME_DELTA_FEAT_TO_NEXT_FAILURE,
|
||||
NAME_DELTA_FEAT_TO_REPAIR,
|
||||
)
|
||||
from lang_main.loggers import logger_timeline as logger
|
||||
from lang_main.types import (
|
||||
DataFrameTLFiltered,
|
||||
@@ -94,7 +97,7 @@ def calc_delta_to_repair(
|
||||
data: DataFrame,
|
||||
date_feature_start: str = 'ErstellungsDatum',
|
||||
date_feature_end: str = 'ErledigungsDatum',
|
||||
name_delta_feature: str = 'delta_to_repair',
|
||||
name_delta_feature: str = NAME_DELTA_FEAT_TO_REPAIR,
|
||||
convert_to_days: bool = True,
|
||||
) -> tuple[DataFrame]:
|
||||
logger.info('Calculating time differences between start and end of operations...')
|
||||
@@ -316,7 +319,7 @@ def filter_timeline_cands(
|
||||
def calc_delta_to_next_failure(
|
||||
data: DataFrameTLFiltered,
|
||||
date_feature: str = 'ErstellungsDatum',
|
||||
name_delta_feature: str = 'delta_to_next_failure',
|
||||
name_delta_feature: str = NAME_DELTA_FEAT_TO_NEXT_FAILURE,
|
||||
convert_to_days: bool = True,
|
||||
) -> DataFrameTLFiltered:
|
||||
data = data.copy()
|
||||
|
||||
@@ -5,9 +5,6 @@ from typing import Literal, cast, overload
|
||||
|
||||
from dateutil.parser import parse
|
||||
from pandas import DataFrame
|
||||
from spacy.language import Language as GermanSpacyModel
|
||||
from spacy.tokens.doc import Doc as SpacyDoc
|
||||
from spacy.tokens.token import Token as SpacyToken
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from lang_main.analysis.graphs import (
|
||||
@@ -15,7 +12,12 @@ from lang_main.analysis.graphs import (
|
||||
update_graph,
|
||||
)
|
||||
from lang_main.loggers import logger_token_analysis as logger
|
||||
from lang_main.types import PandasIndex
|
||||
from lang_main.types import (
|
||||
PandasIndex,
|
||||
SpacyDoc,
|
||||
SpacyModel,
|
||||
SpacyToken,
|
||||
)
|
||||
|
||||
# ** POS
|
||||
# POS_OF_INTEREST: frozenset[str] = frozenset(['NOUN', 'PROPN', 'ADJ', 'VERB', 'AUX'])
|
||||
@@ -147,7 +149,7 @@ def add_doc_info_to_graph(
|
||||
@overload
|
||||
def build_token_graph(
|
||||
data: DataFrame,
|
||||
model: GermanSpacyModel,
|
||||
model: SpacyModel,
|
||||
*,
|
||||
target_feature: str = ...,
|
||||
weights_feature: str | None = ...,
|
||||
@@ -161,7 +163,7 @@ def build_token_graph(
|
||||
@overload
|
||||
def build_token_graph(
|
||||
data: DataFrame,
|
||||
model: GermanSpacyModel,
|
||||
model: SpacyModel,
|
||||
*,
|
||||
target_feature: str = ...,
|
||||
weights_feature: str | None = ...,
|
||||
@@ -174,7 +176,7 @@ def build_token_graph(
|
||||
|
||||
def build_token_graph(
|
||||
data: DataFrame,
|
||||
model: GermanSpacyModel,
|
||||
model: SpacyModel,
|
||||
*,
|
||||
target_feature: str = 'entry',
|
||||
weights_feature: str | None = None,
|
||||
@@ -233,7 +235,7 @@ def build_token_graph(
|
||||
|
||||
def build_token_graph_simple(
|
||||
data: DataFrame,
|
||||
model: GermanSpacyModel,
|
||||
model: SpacyModel,
|
||||
) -> tuple[TokenGraph, dict[PandasIndex, SpacyDoc]]:
|
||||
graph = TokenGraph()
|
||||
model_input = cast(tuple[str], tuple(data['entry'].to_list()))
|
||||
@@ -264,7 +266,7 @@ def build_token_graph_simple(
|
||||
|
||||
def build_token_graph_old(
|
||||
data: DataFrame,
|
||||
model: GermanSpacyModel,
|
||||
model: SpacyModel,
|
||||
) -> tuple[TokenGraph]:
|
||||
# empty NetworkX directed graph
|
||||
# graph = nx.DiGraph()
|
||||
|
||||
@@ -1,12 +1,19 @@
|
||||
from pathlib import Path
|
||||
from typing import Final
|
||||
|
||||
import spacy
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from spacy.language import Language as GermanSpacyModel
|
||||
|
||||
# TODO check removal
|
||||
# import spacy
|
||||
# from sentence_transformers import SentenceTransformer
|
||||
# from spacy.language import Language as GermanSpacyModel
|
||||
from lang_main import CONFIG, CYTO_PATH_STYLESHEET
|
||||
from lang_main.types import CytoLayoutProperties, CytoLayouts, STFRDeviceTypes
|
||||
from lang_main import model_loader as m_load
|
||||
from lang_main.types import (
|
||||
CytoLayoutProperties,
|
||||
CytoLayouts,
|
||||
LanguageModels,
|
||||
ModelLoaderMap,
|
||||
STFRDeviceTypes,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'CONFIG',
|
||||
@@ -38,14 +45,33 @@ SKIP_TIME_ANALYSIS: Final[bool] = CONFIG['control']['time_analysis_skip']
|
||||
|
||||
|
||||
# ** models
|
||||
# ** sentence_transformers
|
||||
# ** loading
|
||||
SPACY_MODEL_NAME: Final[str] = 'de_dep_news_trf'
|
||||
STFR_MODEL_NAME: Final[str] = 'sentence-transformers/all-mpnet-base-v2'
|
||||
STFR_DEVICE: Final[STFRDeviceTypes] = STFRDeviceTypes.CPU
|
||||
STFR_MODEL: Final[SentenceTransformer] = SentenceTransformer(
|
||||
'sentence-transformers/all-mpnet-base-v2', device=STFR_DEVICE
|
||||
)
|
||||
MODEL_LOADER_MAP: Final[ModelLoaderMap] = {
|
||||
LanguageModels.SENTENCE_TRANSFORMER: {
|
||||
'func': m_load.load_sentence_transformer,
|
||||
'kwargs': {
|
||||
'model_name': STFR_MODEL_NAME,
|
||||
'device': STFR_DEVICE,
|
||||
},
|
||||
},
|
||||
LanguageModels.SPACY: {
|
||||
'func': m_load.load_spacy,
|
||||
'kwargs': {
|
||||
'model_name': SPACY_MODEL_NAME,
|
||||
},
|
||||
},
|
||||
}
|
||||
# ** sentence_transformers
|
||||
|
||||
# STFR_MODEL: Final[SentenceTransformer] = SentenceTransformer(
|
||||
# 'sentence-transformers/all-mpnet-base-v2', device=STFR_DEVICE
|
||||
# )
|
||||
|
||||
# ** spacy
|
||||
SPCY_MODEL: Final[GermanSpacyModel] = spacy.load('de_dep_news_trf')
|
||||
# SPCY_MODEL: Final[GermanSpacyModel] = spacy.load('de_dep_news_trf')
|
||||
|
||||
# ** export
|
||||
# ** preprocessing
|
||||
@@ -82,6 +108,7 @@ CYTO_STYLESHEET_NAME: Final[str] = 'lang_main'
|
||||
CYTO_SELECTION_PROPERTY: Final[str] = 'node_selection'
|
||||
CYTO_NUMBER_SUBGRAPHS: Final[int] = 5
|
||||
CYTO_ITER_NEIGHBOUR_DEPTH: Final[int] = 2
|
||||
CYTO_NETWORK_ZOOM_FACTOR: Final[float] = 0.96
|
||||
|
||||
# ** time_analysis.uniqueness
|
||||
THRESHOLD_UNIQUE_TEXTS: Final[int] = CONFIG['time_analysis']['uniqueness'][
|
||||
|
||||
Binary file not shown.
@@ -93,9 +93,10 @@ def get_entry_point(
|
||||
saving_path: Path,
|
||||
filename: str,
|
||||
file_ext: str = '.pkl',
|
||||
check_existence: bool = True,
|
||||
) -> Path:
|
||||
entry_point_path = (saving_path / filename).with_suffix(file_ext)
|
||||
if not entry_point_path.exists():
|
||||
if check_existence and not entry_point_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f'Could not find provided entry data under path: >>{entry_point_path}<<'
|
||||
)
|
||||
|
||||
53
src/lang_main/model_loader.py
Normal file
53
src/lang_main/model_loader.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal, overload
|
||||
|
||||
import spacy
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
from lang_main.types import (
|
||||
LanguageModels,
|
||||
Model,
|
||||
ModelLoaderMap,
|
||||
SpacyModel,
|
||||
STFRDeviceTypes,
|
||||
)
|
||||
|
||||
|
||||
@overload
|
||||
def instantiate_model(
|
||||
model_load_map: ModelLoaderMap,
|
||||
model: Literal[LanguageModels.SENTENCE_TRANSFORMER],
|
||||
) -> SentenceTransformer: ...
|
||||
|
||||
|
||||
@overload
|
||||
def instantiate_model(
|
||||
model_load_map: ModelLoaderMap,
|
||||
model: Literal[LanguageModels.SPACY],
|
||||
) -> SpacyModel: ...
|
||||
|
||||
|
||||
def instantiate_model(
|
||||
model_load_map: ModelLoaderMap,
|
||||
model: LanguageModels,
|
||||
) -> Model:
|
||||
if model not in model_load_map:
|
||||
raise KeyError(f'Model >>{model}<< not known. Choose from: {model_load_map.keys()}')
|
||||
builder_func = model_load_map[model]['func']
|
||||
func_kwargs = model_load_map[model]['kwargs']
|
||||
|
||||
return builder_func(**func_kwargs)
|
||||
|
||||
|
||||
def load_spacy(
|
||||
model_name: str,
|
||||
) -> SpacyModel:
|
||||
return spacy.load(model_name)
|
||||
|
||||
|
||||
def load_sentence_transformer(
|
||||
model_name: str,
|
||||
device: STFRDeviceTypes,
|
||||
) -> SentenceTransformer:
|
||||
return SentenceTransformer(model_name_or_path=model_name, device=device)
|
||||
@@ -1,5 +1,6 @@
|
||||
from pathlib import Path
|
||||
|
||||
from lang_main import model_loader as m_load
|
||||
from lang_main.analysis import graphs
|
||||
from lang_main.analysis.preprocessing import (
|
||||
analyse_feature,
|
||||
@@ -29,10 +30,9 @@ from lang_main.constants import (
|
||||
DATE_COLS,
|
||||
FEATURE_NAME_OBJ_ID,
|
||||
MODEL_INPUT_FEATURES,
|
||||
MODEL_LOADER_MAP,
|
||||
NAME_DELTA_FEAT_TO_REPAIR,
|
||||
SAVE_PATH_FOLDER,
|
||||
SPCY_MODEL,
|
||||
STFR_MODEL,
|
||||
THRESHOLD_AMOUNT_CHARACTERS,
|
||||
THRESHOLD_EDGE_WEIGHT,
|
||||
THRESHOLD_NUM_ACTIVITIES,
|
||||
@@ -43,7 +43,18 @@ from lang_main.constants import (
|
||||
)
|
||||
from lang_main.pipelines.base import Pipeline
|
||||
from lang_main.render import cytoscape as cyto
|
||||
from lang_main.types import EntryPoints
|
||||
from lang_main.types import EntryPoints, LanguageModels
|
||||
|
||||
# ** Models
|
||||
STFR_MODEL = m_load.instantiate_model(
|
||||
model_load_map=MODEL_LOADER_MAP,
|
||||
model=LanguageModels.SENTENCE_TRANSFORMER,
|
||||
)
|
||||
|
||||
SPACY_MODEL = m_load.instantiate_model(
|
||||
model_load_map=MODEL_LOADER_MAP,
|
||||
model=LanguageModels.SPACY,
|
||||
)
|
||||
|
||||
|
||||
# ** pipeline configuration
|
||||
@@ -61,7 +72,7 @@ def build_base_target_feature_pipe() -> Pipeline:
|
||||
pipe_target_feat.add(
|
||||
entry_wise_cleansing,
|
||||
{
|
||||
'target_feature': ('VorgangsBeschreibung',),
|
||||
'target_features': ('VorgangsBeschreibung',),
|
||||
'cleansing_func': clean_string_slim,
|
||||
},
|
||||
save_result=True,
|
||||
@@ -106,7 +117,6 @@ def build_base_target_feature_pipe() -> Pipeline:
|
||||
# ** Merge duplicates
|
||||
def build_merge_duplicates_pipe() -> Pipeline:
|
||||
pipe_merge = Pipeline(name='Merge_Duplicates', working_dir=SAVE_PATH_FOLDER)
|
||||
# pipe_merge.add(merge_similarity_dupl, save_result=True)
|
||||
pipe_merge.add(
|
||||
numeric_pre_filter_feature,
|
||||
{
|
||||
@@ -134,7 +144,7 @@ def build_tk_graph_pipe() -> Pipeline:
|
||||
pipe_token_analysis.add(
|
||||
build_token_graph,
|
||||
{
|
||||
'model': SPCY_MODEL,
|
||||
'model': SPACY_MODEL,
|
||||
'target_feature': 'entry',
|
||||
'weights_feature': 'num_occur',
|
||||
'batch_idx_feature': 'batched_idxs',
|
||||
|
||||
@@ -14,6 +14,7 @@ from lang_main.constants import (
|
||||
CYTO_ITER_NEIGHBOUR_DEPTH,
|
||||
CYTO_LAYOUT_NAME,
|
||||
CYTO_LAYOUT_PROPERTIES,
|
||||
CYTO_NETWORK_ZOOM_FACTOR,
|
||||
CYTO_NUMBER_SUBGRAPHS,
|
||||
CYTO_PATH_STYLESHEET,
|
||||
CYTO_SANDBOX_NAME,
|
||||
@@ -125,6 +126,17 @@ def reset_current_network_to_base() -> None:
|
||||
p4c.set_current_network(CYTO_BASE_NETWORK_NAME)
|
||||
|
||||
|
||||
def fit_content(
|
||||
zoom_factor: float = CYTO_NETWORK_ZOOM_FACTOR,
|
||||
network_name: str = CYTO_BASE_NETWORK_NAME,
|
||||
) -> None:
|
||||
p4c.hide_all_panels()
|
||||
p4c.fit_content(selected_only=False, network=network_name)
|
||||
zoom_current = p4c.get_network_zoom(network=network_name)
|
||||
zoom_new = zoom_current * zoom_factor
|
||||
p4c.set_network_zoom_bypass(zoom_new, bypass=False, network=network_name)
|
||||
|
||||
|
||||
def export_network_to_image(
|
||||
filename: str,
|
||||
target_folder: Path = SAVE_PATH_FOLDER,
|
||||
@@ -156,9 +168,10 @@ def export_network_to_image(
|
||||
if filetype == 'SVG':
|
||||
text_as_font = False
|
||||
|
||||
# close non-necessary windows and fit graph in frame before image display
|
||||
fit_content(network_name=network_name)
|
||||
# image is generated in sandbox directory and transferred to target destination
|
||||
# (preparation for remote instances of Cytoscape)
|
||||
# TODO close non-necessary windows before image display
|
||||
p4c.export_image(
|
||||
filename=filename,
|
||||
type=filetype,
|
||||
@@ -168,7 +181,6 @@ def export_network_to_image(
|
||||
export_text_as_font=text_as_font,
|
||||
page_size=pdf_export_page_size,
|
||||
)
|
||||
# TODO change back to Cytoscape 3.10 and above
|
||||
# TODO remove if Cytoscape >= 3.10.* is running in container
|
||||
# p4c.export_image(
|
||||
# filename=filename,
|
||||
@@ -211,7 +223,7 @@ def layout_network(
|
||||
logger.debug('Applying layout to network...')
|
||||
p4c.set_layout_properties(layout_name, layout_properties)
|
||||
p4c.layout_network(layout_name=layout_name, network=network_name)
|
||||
p4c.fit_content(selected_only=False, network=network_name)
|
||||
fit_content(network_name=network_name)
|
||||
logger.debug('Layout application to network successful.')
|
||||
|
||||
|
||||
@@ -245,7 +257,7 @@ def apply_style_to_network(
|
||||
"""
|
||||
logger.debug('Applying style to network...')
|
||||
styles_avail = cast(list[str], p4c.get_visual_style_names())
|
||||
if CYTO_STYLESHEET_NAME not in styles_avail:
|
||||
if style_name not in styles_avail:
|
||||
if not pth_to_stylesheet.exists():
|
||||
# existence for standard path verified at import, but not for other
|
||||
# provided paths
|
||||
@@ -278,7 +290,7 @@ def apply_style_to_network(
|
||||
node_size_property,
|
||||
number_scheme=scheme,
|
||||
mapping_type='c',
|
||||
style_name='lang_main',
|
||||
style_name=style_name,
|
||||
default_number=min_node_size,
|
||||
)
|
||||
p4c.set_node_size_mapping(**node_size_map)
|
||||
@@ -289,7 +301,7 @@ def apply_style_to_network(
|
||||
# p4c.set_node_size_bypass(nodes_SUID, new_sizes=min_node_size, network=network_name)
|
||||
# p4c.set_visual_style(style_name, network=network_name)
|
||||
# time.sleep(1) # if not waited image export could be without applied style
|
||||
p4c.fit_content(selected_only=False, network=network_name)
|
||||
fit_content(network_name=network_name)
|
||||
logger.debug('Style application to network successful.')
|
||||
|
||||
|
||||
@@ -384,7 +396,7 @@ def make_subnetwork(
|
||||
network=network_name,
|
||||
)
|
||||
p4c.set_current_network(subnetwork_name)
|
||||
p4c.fit_content(selected_only=False, network=subnetwork_name)
|
||||
|
||||
if export_image:
|
||||
time.sleep(1)
|
||||
export_network_to_image(
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import enum
|
||||
from collections.abc import Hashable
|
||||
from collections.abc import Callable, Hashable
|
||||
from typing import (
|
||||
Any,
|
||||
Literal,
|
||||
@@ -10,9 +10,20 @@ from typing import (
|
||||
|
||||
import numpy as np
|
||||
from pandas import DataFrame
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from spacy.language import Language as SpacyModel
|
||||
from spacy.tokens.doc import Doc as SpacyDoc
|
||||
from spacy.tokens.token import Token as SpacyToken
|
||||
from torch import Tensor
|
||||
|
||||
__all__ = [
|
||||
'SentenceTransformer',
|
||||
'SpacyModel',
|
||||
'SpacyDoc',
|
||||
'SpacyToken',
|
||||
'Tensor',
|
||||
]
|
||||
|
||||
|
||||
# ** logging
|
||||
class LoggingLevels(enum.IntEnum):
|
||||
@@ -23,6 +34,24 @@ class LoggingLevels(enum.IntEnum):
|
||||
CRITICAL = 50
|
||||
|
||||
|
||||
# ** models
|
||||
class LanguageModels(enum.StrEnum):
|
||||
SENTENCE_TRANSFORMER = enum.auto()
|
||||
SPACY = enum.auto()
|
||||
|
||||
|
||||
Model: TypeAlias = SentenceTransformer | SpacyModel
|
||||
ModelLoaderFunc: TypeAlias = Callable[..., Model]
|
||||
|
||||
|
||||
class ModelLoaderInfo(TypedDict):
|
||||
func: ModelLoaderFunc
|
||||
kwargs: dict[str, Any]
|
||||
|
||||
|
||||
ModelLoaderMap: TypeAlias = dict[LanguageModels, ModelLoaderInfo]
|
||||
|
||||
|
||||
# ** devices
|
||||
class STFRDeviceTypes(enum.StrEnum):
|
||||
CPU = enum.auto()
|
||||
|
||||
Reference in New Issue
Block a user