started adding comprehensive unit tests
This commit is contained in:
@@ -1,51 +0,0 @@
|
||||
import inspect
|
||||
import logging
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from time import gmtime
|
||||
from typing import Any, Final
|
||||
import warnings
|
||||
|
||||
from lang_main.io import load_toml_config
|
||||
|
||||
__all__ = [
|
||||
'CALLER_PATH',
|
||||
]
|
||||
|
||||
logging.Formatter.converter = gmtime
|
||||
LOG_FMT: Final[str] = '%(asctime)s | %(module)s:%(levelname)s | %(message)s'
|
||||
LOG_DATE_FMT: Final[str] = '%Y-%m-%d %H:%M:%S +0000'
|
||||
logging.basicConfig(
|
||||
stream=sys.stdout,
|
||||
format=LOG_FMT,
|
||||
datefmt=LOG_DATE_FMT,
|
||||
)
|
||||
|
||||
CONFIG_FILENAME: Final[str] = 'lang_main_config.toml'
|
||||
USE_INTERNAL_CONFIG: Final[bool] = True
|
||||
pkg_dir = Path(__file__).parent
|
||||
cfg_path_internal = pkg_dir / CONFIG_FILENAME
|
||||
caller_file = Path(inspect.stack()[-1].filename)
|
||||
CALLER_PATH: Final[Path] = caller_file.parent.resolve()
|
||||
|
||||
# load config data: internal/external
|
||||
if USE_INTERNAL_CONFIG:
|
||||
loaded_cfg = load_toml_config(path_to_toml=cfg_path_internal)
|
||||
else:
|
||||
cfg_path_external = CALLER_PATH / CONFIG_FILENAME
|
||||
if not caller_file.exists():
|
||||
warnings.warn('Caller file could not be correctly retrieved.')
|
||||
if not cfg_path_external.exists():
|
||||
shutil.copy(cfg_path_internal, cfg_path_external)
|
||||
sys.exit(
|
||||
(
|
||||
'No config file was found. A new one with default values was created '
|
||||
'in the execution path. Please fill in the necessary values and '
|
||||
'restart the programm.'
|
||||
)
|
||||
)
|
||||
# raise NotImplementedError("External config data not implemented yet.")
|
||||
loaded_cfg = load_toml_config(path_to_toml=cfg_path_external)
|
||||
|
||||
CONFIG: Final[dict[str, Any]] = loaded_cfg.copy()
|
||||
@@ -1,14 +1,19 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Final
|
||||
|
||||
from lang_main.config import load_toml_config
|
||||
|
||||
_has_py4cyto: bool = True
|
||||
try:
|
||||
import py4cytoscape as p4c
|
||||
except ImportError:
|
||||
_has_py4cyto = False
|
||||
|
||||
from lang_main.io import load_toml_config
|
||||
# ** external packages config
|
||||
# ** Huggingface Hub caching
|
||||
os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = 'set'
|
||||
|
||||
# ** py4cytoscape config
|
||||
if _has_py4cyto:
|
||||
@@ -20,6 +25,7 @@ if _has_py4cyto:
|
||||
p4c.py4cytoscape_logger.detail_logger.addHandler(logging.NullHandler())
|
||||
|
||||
# ** lang-main config
|
||||
BASE_FOLDERNAME: Final[str] = 'lang-main'
|
||||
CONFIG_FILENAME: Final[str] = 'lang_main_config.toml'
|
||||
CYTO_STYLESHEET_FILENAME: Final[str] = r'cytoscape_config/lang_main.xml'
|
||||
PREFER_INTERNAL_CONFIG: Final[bool] = False
|
||||
@@ -75,27 +81,71 @@ def search_iterative(
|
||||
pattern to look for, first match will be returned,
|
||||
by default CONFIG_FILENAME
|
||||
stop_folder_name : str, optional
|
||||
name of the last folder in the directory tree to search, by default 'python'
|
||||
name of the last folder in the directory tree to search, by default None
|
||||
|
||||
Returns
|
||||
-------
|
||||
Path | None
|
||||
Path if corresponding object was found, None otherwise
|
||||
"""
|
||||
cfg_path: Path | None = None
|
||||
file_path: Path | None = None
|
||||
stop_folder_reached: bool = False
|
||||
for it in range(len(starting_path.parents)):
|
||||
search_path = starting_path.parents[it] # do not look in library folder
|
||||
res = tuple(search_path.glob(glob_pattern))
|
||||
if res:
|
||||
cfg_path = res[0]
|
||||
file_path = res[0]
|
||||
break
|
||||
elif stop_folder_reached:
|
||||
break
|
||||
|
||||
if stop_folder_name is not None and search_path.name == stop_folder_name:
|
||||
# library is placed inside a whole python installation for deployment
|
||||
# if this folder is reached, only look up one parent above
|
||||
stop_folder_reached = True
|
||||
|
||||
return file_path
|
||||
|
||||
|
||||
def search_base_path(
|
||||
starting_path: Path,
|
||||
stop_folder_name: str | None = None,
|
||||
) -> Path | None:
|
||||
"""Iteratively searches the parent directories of the starting path
|
||||
and look for folders matching the given name. If a match is encountered,
|
||||
the parent path will be returned.
|
||||
|
||||
Example:
|
||||
starting_path = path/to/start/folder
|
||||
stop_folder_name = 'to'
|
||||
returned path = 'path/'
|
||||
|
||||
Parameters
|
||||
----------
|
||||
starting_path : Path
|
||||
non-inclusive starting path
|
||||
stop_folder_name : str, optional
|
||||
name of the last folder in the directory tree to search, by default None
|
||||
|
||||
Returns
|
||||
-------
|
||||
Path | None
|
||||
Path if corresponding base path was found, None otherwise
|
||||
"""
|
||||
stop_folder_path: Path | None = None
|
||||
base_path: Path | None = None
|
||||
for it in range(len(starting_path.parents)):
|
||||
search_path = starting_path.parents[it] # do not look in library folder
|
||||
if stop_folder_name is not None and search_path.name == stop_folder_name:
|
||||
# library is placed inside a whole python installation for deployment
|
||||
# only look up to this folder
|
||||
stop_folder_path = search_path
|
||||
break
|
||||
|
||||
return cfg_path
|
||||
if stop_folder_path is not None:
|
||||
base_path = stop_folder_path.parent
|
||||
|
||||
return base_path
|
||||
|
||||
|
||||
def load_cfg() -> dict[str, Any]:
|
||||
@@ -121,6 +171,10 @@ def load_cfg() -> dict[str, Any]:
|
||||
|
||||
|
||||
CONFIG: Final[dict[str, Any]] = load_cfg()
|
||||
base_parent_path = search_base_path(pkg_dir, stop_folder_name=BASE_FOLDERNAME)
|
||||
if base_parent_path is None:
|
||||
raise FileNotFoundError('Could not resolve base path of library')
|
||||
BASE_PATH: Final[Path] = base_parent_path
|
||||
|
||||
|
||||
# ** Cytoscape configuration
|
||||
|
||||
@@ -48,9 +48,9 @@ def save_to_GraphML(
|
||||
def get_graph_metadata(
|
||||
graph: Graph | DiGraph,
|
||||
logging: bool = LOGGING_DEFAULT_GRAPHS,
|
||||
) -> dict[str, int]:
|
||||
) -> dict[str, float]:
|
||||
# info about graph
|
||||
graph_info: dict[str, int] = {}
|
||||
graph_info: dict[str, float] = {}
|
||||
# nodes and edges
|
||||
num_nodes = len(graph.nodes)
|
||||
num_edges = len(graph.edges)
|
||||
@@ -96,15 +96,6 @@ def update_graph(
|
||||
child: Hashable | None = None,
|
||||
weight_connection: int | None = None,
|
||||
) -> None:
|
||||
# !! not necessary to check for existence of nodes
|
||||
# !! feature already implemented in NetworkX ``add_edge``
|
||||
"""
|
||||
# check if nodes already in Graph
|
||||
if parent not in graph:
|
||||
graph.add_node(parent)
|
||||
if child not in graph:
|
||||
graph.add_node(child)
|
||||
"""
|
||||
if weight_connection is None:
|
||||
weight_connection = 1
|
||||
# check if edge not in Graph
|
||||
@@ -115,9 +106,7 @@ def update_graph(
|
||||
graph.add_edge(parent, child, weight=weight_connection)
|
||||
else:
|
||||
# update edge
|
||||
weight = graph[parent][child]['weight']
|
||||
weight += weight_connection
|
||||
graph[parent][child]['weight'] = weight
|
||||
graph[parent][child]['weight'] += weight_connection
|
||||
|
||||
|
||||
# build undirected adjacency matrix
|
||||
@@ -249,7 +238,8 @@ def filter_graph_by_node_degree(
|
||||
bound_lower: int | None,
|
||||
bound_upper: int | None,
|
||||
) -> TokenGraph:
|
||||
"""filters all nodes which are within the provided bounds by their degree
|
||||
"""filters all nodes which are within the provided bounds by their degree,
|
||||
inclusive limits: bound_lower <= node_degree <= bound_upper are retained
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -266,13 +256,14 @@ def filter_graph_by_node_degree(
|
||||
# filter nodes by degree
|
||||
original_graph_nodes = copy.deepcopy(graph.nodes)
|
||||
filtered_graph = graph.copy()
|
||||
filtered_graph_degree = copy.deepcopy(filtered_graph.degree)
|
||||
|
||||
if not any([bound_lower, bound_upper]):
|
||||
logger.warning('No bounds provided, returning original graph.')
|
||||
return filtered_graph
|
||||
|
||||
for node in original_graph_nodes:
|
||||
degree = filtered_graph.degree[node] # type: ignore
|
||||
degree = cast(int, filtered_graph_degree[node]) # type: ignore
|
||||
if bound_lower is not None and degree < bound_lower:
|
||||
filtered_graph.remove_node(node)
|
||||
if bound_upper is not None and degree > bound_upper:
|
||||
@@ -540,9 +531,9 @@ class TokenGraph(DiGraph):
|
||||
self._name = name
|
||||
# directed and undirected graph data
|
||||
self._directed = self
|
||||
self._metadata_directed: dict[str, int] = {}
|
||||
self._metadata_directed: dict[str, float] = {}
|
||||
self._undirected: Graph | None = None
|
||||
self._metadata_undirected: dict[str, int] = {}
|
||||
self._metadata_undirected: dict[str, float] = {}
|
||||
# indicate rescaled weights
|
||||
self.rescaled_weights: bool = False
|
||||
|
||||
@@ -568,12 +559,12 @@ class TokenGraph(DiGraph):
|
||||
return hash(self.__key())
|
||||
"""
|
||||
|
||||
def copy(self) -> Self:
|
||||
def copy(self) -> TokenGraph:
|
||||
"""returns a (deep) copy of the graph
|
||||
|
||||
Returns
|
||||
-------
|
||||
Self
|
||||
TokenGraph
|
||||
deep copy of the graph
|
||||
"""
|
||||
return copy.deepcopy(self)
|
||||
@@ -594,11 +585,11 @@ class TokenGraph(DiGraph):
|
||||
return self._undirected
|
||||
|
||||
@property
|
||||
def metadata_directed(self) -> dict[str, int]:
|
||||
def metadata_directed(self) -> dict[str, float]:
|
||||
return self._metadata_directed
|
||||
|
||||
@property
|
||||
def metadata_undirected(self) -> dict[str, int]:
|
||||
def metadata_undirected(self) -> dict[str, float]:
|
||||
return self._metadata_undirected
|
||||
|
||||
@overload
|
||||
|
||||
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
# ** (1) dataset preparation: loading and simple preprocessing
|
||||
# following functions used to load a given dataset and perform simple
|
||||
# following functions are used to load a given dataset and perform simple
|
||||
# duplicate cleansing based on all properties
|
||||
def load_raw_data(
|
||||
path: Path,
|
||||
@@ -277,41 +277,41 @@ def merge_similarity_dupl(
|
||||
|
||||
# ** #################################################################################
|
||||
# TODO check removal
|
||||
def build_embedding_map(
|
||||
data: Series,
|
||||
model: GermanSpacyModel | SentenceTransformer,
|
||||
) -> tuple[dict[int, tuple[Embedding, str]], tuple[bool, bool]]:
|
||||
# dictionary with embeddings
|
||||
embeddings: dict[int, tuple[Embedding, str]] = {}
|
||||
is_spacy = False
|
||||
is_STRF = False
|
||||
# def build_embedding_map(
|
||||
# data: Series,
|
||||
# model: GermanSpacyModel | SentenceTransformer,
|
||||
# ) -> tuple[dict[int, tuple[Embedding, str]], tuple[bool, bool]]:
|
||||
# # dictionary with embeddings
|
||||
# embeddings: dict[int, tuple[Embedding, str]] = {}
|
||||
# is_spacy = False
|
||||
# is_STRF = False
|
||||
|
||||
if isinstance(model, GermanSpacyModel):
|
||||
is_spacy = True
|
||||
elif isinstance(model, SentenceTransformer):
|
||||
is_STRF = True
|
||||
# if isinstance(model, GermanSpacyModel):
|
||||
# is_spacy = True
|
||||
# elif isinstance(model, SentenceTransformer):
|
||||
# is_STRF = True
|
||||
|
||||
if not any((is_spacy, is_STRF)):
|
||||
raise NotImplementedError('Model type unknown')
|
||||
# if not any((is_spacy, is_STRF)):
|
||||
# raise NotImplementedError('Model type unknown')
|
||||
|
||||
for idx, text in tqdm(data.items(), total=len(data), mininterval=1.0):
|
||||
# verbose code: Pyright not inferring types correctly
|
||||
idx = cast(int, idx)
|
||||
text = cast(str, text)
|
||||
if is_spacy:
|
||||
model = cast(GermanSpacyModel, model)
|
||||
embd = cast(SpacyDoc, model(text))
|
||||
embeddings[idx] = (embd, text)
|
||||
# check for empty vectors
|
||||
if not embd.vector_norm:
|
||||
logger.debug('--- Unknown Words ---')
|
||||
logger.debug('embd.text: %s has no vector', embd.text)
|
||||
elif is_STRF:
|
||||
model = cast(SentenceTransformer, model)
|
||||
embd = cast(Tensor, model.encode(text, show_progress_bar=False))
|
||||
embeddings[idx] = (embd, text)
|
||||
# for idx, text in tqdm(data.items(), total=len(data), mininterval=1.0):
|
||||
# # verbose code: Pyright not inferring types correctly
|
||||
# idx = cast(int, idx)
|
||||
# text = cast(str, text)
|
||||
# if is_spacy:
|
||||
# model = cast(GermanSpacyModel, model)
|
||||
# embd = cast(SpacyDoc, model(text))
|
||||
# embeddings[idx] = (embd, text)
|
||||
# # check for empty vectors
|
||||
# if not embd.vector_norm:
|
||||
# logger.debug('--- Unknown Words ---')
|
||||
# logger.debug('embd.text: %s has no vector', embd.text)
|
||||
# elif is_STRF:
|
||||
# model = cast(SentenceTransformer, model)
|
||||
# embd = cast(Tensor, model.encode(text, show_progress_bar=False))
|
||||
# embeddings[idx] = (embd, text)
|
||||
|
||||
return embeddings, (is_spacy, is_STRF)
|
||||
# return embeddings, (is_spacy, is_STRF)
|
||||
|
||||
|
||||
# adapt interface
|
||||
@@ -320,276 +320,275 @@ def build_embedding_map(
|
||||
|
||||
|
||||
# build similarity matrix out of embeddings
|
||||
def build_cosSim_matrix(
|
||||
data: Series,
|
||||
model: GermanSpacyModel | SentenceTransformer,
|
||||
) -> tuple[DataFrame, dict[int, tuple[Embedding, str]]]:
|
||||
# build empty matrix
|
||||
df_index = data.index
|
||||
cosineSim_idx_matrix = pd.DataFrame(
|
||||
data=0.0, columns=df_index, index=df_index, dtype=np.float32
|
||||
)
|
||||
# def build_cosSim_matrix(
|
||||
# data: Series,
|
||||
# model: GermanSpacyModel | SentenceTransformer,
|
||||
# ) -> tuple[DataFrame, dict[int, tuple[Embedding, str]]]:
|
||||
# # build empty matrix
|
||||
# df_index = data.index
|
||||
# cosineSim_idx_matrix = pd.DataFrame(
|
||||
# data=0.0, columns=df_index, index=df_index, dtype=np.float32
|
||||
# )
|
||||
|
||||
logger.info('Start building embedding map...')
|
||||
# logger.info('Start building embedding map...')
|
||||
|
||||
# obtain embeddings based on used model
|
||||
embds, (is_spacy, is_STRF) = build_embedding_map(
|
||||
data=data,
|
||||
model=model,
|
||||
)
|
||||
# # obtain embeddings based on used model
|
||||
# embds, (is_spacy, is_STRF) = build_embedding_map(
|
||||
# data=data,
|
||||
# model=model,
|
||||
# )
|
||||
|
||||
logger.info('Embedding map built successfully.')
|
||||
# logger.info('Embedding map built successfully.')
|
||||
|
||||
# apply index based mapping for efficient handling of large texts
|
||||
combs = combinations(df_index, 2)
|
||||
total_combs = factorial(len(df_index)) // factorial(2) // factorial(len(df_index) - 2)
|
||||
# # apply index based mapping for efficient handling of large texts
|
||||
# combs = combinations(df_index, 2)
|
||||
# total_combs = factorial(len(df_index)) // factorial(2) // factorial(len(df_index) - 2)
|
||||
|
||||
logger.info('Start calculation of similarity scores...')
|
||||
# logger.info('Start calculation of similarity scores...')
|
||||
|
||||
for idx1, idx2 in tqdm(combs, total=total_combs, mininterval=1.0):
|
||||
# print(f"{idx1=}, {idx2=}")
|
||||
embd1 = embds[idx1][0]
|
||||
embd2 = embds[idx2][0]
|
||||
# for idx1, idx2 in tqdm(combs, total=total_combs, mininterval=1.0):
|
||||
# # print(f"{idx1=}, {idx2=}")
|
||||
# embd1 = embds[idx1][0]
|
||||
# embd2 = embds[idx2][0]
|
||||
|
||||
# calculate similarity based on model type
|
||||
if is_spacy:
|
||||
embd1 = cast(SpacyDoc, embds[idx1][0])
|
||||
embd2 = cast(SpacyDoc, embds[idx2][0])
|
||||
cosSim = embd1.similarity(embd2)
|
||||
elif is_STRF:
|
||||
embd1 = cast(Tensor, embds[idx1][0])
|
||||
embd2 = cast(Tensor, embds[idx2][0])
|
||||
cosSim = sentence_transformers.util.cos_sim(embd1, embd2)
|
||||
cosSim = cast(float, cosSim.item())
|
||||
# # calculate similarity based on model type
|
||||
# if is_spacy:
|
||||
# embd1 = cast(SpacyDoc, embds[idx1][0])
|
||||
# embd2 = cast(SpacyDoc, embds[idx2][0])
|
||||
# cosSim = embd1.similarity(embd2)
|
||||
# elif is_STRF:
|
||||
# embd1 = cast(Tensor, embds[idx1][0])
|
||||
# embd2 = cast(Tensor, embds[idx2][0])
|
||||
# cosSim = sentence_transformers.util.cos_sim(embd1, embd2)
|
||||
# cosSim = cast(float, cosSim.item())
|
||||
|
||||
cosineSim_idx_matrix.at[idx1, idx2] = cosSim
|
||||
# cosineSim_idx_matrix.at[idx1, idx2] = cosSim
|
||||
|
||||
logger.info('Similarity scores calculated successfully.')
|
||||
# logger.info('Similarity scores calculated successfully.')
|
||||
|
||||
return cosineSim_idx_matrix, embds
|
||||
# return cosineSim_idx_matrix, embds
|
||||
|
||||
|
||||
# obtain index pairs with cosine similarity
|
||||
# greater than or equal to given threshold value
|
||||
def filt_thresh_cosSim_matrix(
|
||||
cosineSim_idx_matrix: DataFrame,
|
||||
embds: dict[int, tuple[Embedding, str]],
|
||||
threshold: float,
|
||||
) -> tuple[Series, dict[int, tuple[Embedding, str]]]:
|
||||
"""filter similarity matrix by threshold value and return index pairs with
|
||||
a similarity score greater than the provided threshold
|
||||
# def filt_thresh_cosSim_matrix(
|
||||
# cosineSim_idx_matrix: DataFrame,
|
||||
# embds: dict[int, tuple[Embedding, str]],
|
||||
# threshold: float,
|
||||
# ) -> tuple[Series, dict[int, tuple[Embedding, str]]]:
|
||||
# """filter similarity matrix by threshold value and return index pairs with
|
||||
# a similarity score greater than the provided threshold
|
||||
|
||||
Parameters
|
||||
----------
|
||||
threshold : float
|
||||
similarity threshold
|
||||
cosineSim_idx_matrix : DataFrame
|
||||
similarity matrix
|
||||
# Parameters
|
||||
# ----------
|
||||
# threshold : float
|
||||
# similarity threshold
|
||||
# cosineSim_idx_matrix : DataFrame
|
||||
# similarity matrix
|
||||
|
||||
Returns
|
||||
-------
|
||||
Series
|
||||
series with multi index (index pairs) and corresponding similarity score
|
||||
"""
|
||||
cosineSim_filt = cast(
|
||||
Series, cosineSim_idx_matrix.where(cosineSim_idx_matrix >= threshold).stack()
|
||||
)
|
||||
# Returns
|
||||
# -------
|
||||
# Series
|
||||
# series with multi index (index pairs) and corresponding similarity score
|
||||
# """
|
||||
# cosineSim_filt = cast(
|
||||
# Series, cosineSim_idx_matrix.where(cosineSim_idx_matrix >= threshold).stack()
|
||||
# )
|
||||
|
||||
return cosineSim_filt, embds
|
||||
# return cosineSim_filt, embds
|
||||
|
||||
|
||||
def list_cosSim_dupl_candidates(
|
||||
cosineSim_filt: Series,
|
||||
embds: dict[int, tuple[Embedding, str]],
|
||||
save_candidates: bool = False,
|
||||
saving_path: Path | None = None,
|
||||
filename: str = 'CosSim-FilterCandidates',
|
||||
pipeline: Pipeline | None = None,
|
||||
) -> tuple[list[tuple[PandasIndex, PandasIndex]], dict[int, tuple[Embedding, str]]]:
|
||||
"""providing an overview of candidates with a similarity score greater than
|
||||
given threshold; more suitable for debugging purposes
|
||||
# def list_cosSim_dupl_candidates(
|
||||
# cosineSim_filt: Series,
|
||||
# embds: dict[int, tuple[Embedding, str]],
|
||||
# save_candidates: bool = False,
|
||||
# saving_path: Path | None = None,
|
||||
# filename: str = 'CosSim-FilterCandidates',
|
||||
# pipeline: Pipeline | None = None,
|
||||
# ) -> tuple[list[tuple[PandasIndex, PandasIndex]], dict[int, tuple[Embedding, str]]]:
|
||||
# """providing an overview of candidates with a similarity score greater than
|
||||
# given threshold; more suitable for debugging purposes
|
||||
|
||||
Returns
|
||||
-------
|
||||
DataFrame
|
||||
contains indices, corresponding texts and similarity score to evaluate results
|
||||
list[tuple[Index, Index]]
|
||||
list containing relevant index pairs for entries with similarity score greater than
|
||||
given threshold
|
||||
"""
|
||||
logger.info('Start gathering of similarity candidates...')
|
||||
# compare found duplicates
|
||||
columns: list[str] = ['idx1', 'text1', 'idx2', 'text2', 'score']
|
||||
df_candidates = pd.DataFrame(columns=columns)
|
||||
# Returns
|
||||
# -------
|
||||
# DataFrame
|
||||
# contains indices, corresponding texts and similarity score to evaluate results
|
||||
# list[tuple[Index, Index]]
|
||||
# list containing relevant index pairs for entries with similarity score greater than
|
||||
# given threshold
|
||||
# """
|
||||
# logger.info('Start gathering of similarity candidates...')
|
||||
# # compare found duplicates
|
||||
# columns: list[str] = ['idx1', 'text1', 'idx2', 'text2', 'score']
|
||||
# df_candidates = pd.DataFrame(columns=columns)
|
||||
|
||||
index_pairs: list[tuple[PandasIndex, PandasIndex]] = []
|
||||
# index_pairs: list[tuple[PandasIndex, PandasIndex]] = []
|
||||
|
||||
for (idx1, idx2), score in tqdm(cosineSim_filt.items(), total=len(cosineSim_filt)): # type: ignore
|
||||
# get text content from embedding as second tuple entry
|
||||
content = [
|
||||
[
|
||||
idx1,
|
||||
embds[idx1][1],
|
||||
idx2,
|
||||
embds[idx2][1],
|
||||
score,
|
||||
]
|
||||
]
|
||||
# add candidates to collection DataFrame
|
||||
df_conc = pd.DataFrame(columns=columns, data=content)
|
||||
if df_candidates.empty:
|
||||
df_candidates = df_conc.copy()
|
||||
else:
|
||||
df_candidates = pd.concat([df_candidates, df_conc])
|
||||
# save index pairs
|
||||
index_pairs.append((idx1, idx2))
|
||||
# for (idx1, idx2), score in tqdm(cosineSim_filt.items(), total=len(cosineSim_filt)): # type: ignore
|
||||
# # get text content from embedding as second tuple entry
|
||||
# content = [
|
||||
# [
|
||||
# idx1,
|
||||
# embds[idx1][1],
|
||||
# idx2,
|
||||
# embds[idx2][1],
|
||||
# score,
|
||||
# ]
|
||||
# ]
|
||||
# # add candidates to collection DataFrame
|
||||
# df_conc = pd.DataFrame(columns=columns, data=content)
|
||||
# if df_candidates.empty:
|
||||
# df_candidates = df_conc.copy()
|
||||
# else:
|
||||
# df_candidates = pd.concat([df_candidates, df_conc])
|
||||
# # save index pairs
|
||||
# index_pairs.append((idx1, idx2))
|
||||
|
||||
logger.info('Similarity candidates gathered successfully.')
|
||||
# logger.info('Similarity candidates gathered successfully.')
|
||||
|
||||
if save_candidates:
|
||||
if saving_path is None:
|
||||
raise ValueError(
|
||||
('Saving path must be provided if duplicate ' 'candidates should be saved.')
|
||||
)
|
||||
elif pipeline is not None:
|
||||
target_filename = (
|
||||
f'Pipe-{pipeline.name}_Step_{pipeline.curr_proc_idx}_' + filename + '.xlsx'
|
||||
)
|
||||
elif pipeline is None:
|
||||
target_filename = f'{filename}.xlsx'
|
||||
logger.info('Saving similarity candidates...')
|
||||
target_path = saving_path.joinpath(target_filename)
|
||||
df_candidates.to_excel(target_path)
|
||||
logger.info('Similarity candidates saved successfully to >>%s<<.', target_path)
|
||||
# if save_candidates:
|
||||
# if saving_path is None:
|
||||
# raise ValueError(
|
||||
# ('Saving path must be provided if duplicate ' 'candidates should be saved.')
|
||||
# )
|
||||
# elif pipeline is not None:
|
||||
# target_filename = (
|
||||
# f'Pipe-{pipeline.name}_Step_{pipeline.curr_proc_idx}_' + filename + '.xlsx'
|
||||
# )
|
||||
# elif pipeline is None:
|
||||
# target_filename = f'{filename}.xlsx'
|
||||
# logger.info('Saving similarity candidates...')
|
||||
# target_path = saving_path.joinpath(target_filename)
|
||||
# df_candidates.to_excel(target_path)
|
||||
# logger.info('Similarity candidates saved successfully to >>%s<<.', target_path)
|
||||
|
||||
return index_pairs, embds
|
||||
# return index_pairs, embds
|
||||
|
||||
|
||||
# TODO: change implementation fully to SentenceTransformer
|
||||
# usage of batch processing for embeddings, use candidate idx function
|
||||
# from time analysis --> moved to ``helpers.py``
|
||||
"""
|
||||
def similar_ids_connection_graph(
|
||||
similar_idx_pairs: list[tuple[PandasIndex, PandasIndex]],
|
||||
) -> tuple[Graph, dict[str, int]]:
|
||||
# build index graph to obtain graph of connected (similar) indices
|
||||
# use this graph to get connected components (indices which belong together)
|
||||
# retain semantic connection on whole dataset
|
||||
similar_id_graph = nx.Graph()
|
||||
for (idx1, idx2) in similar_idx_pairs:
|
||||
# inplace operation, parent/child do not really exist in undirected graph
|
||||
update_graph(graph=similar_id_graph, parent=idx1, child=idx2)
|
||||
|
||||
graph_info = get_graph_metadata(graph=similar_id_graph, logging=True)
|
||||
|
||||
return similar_id_graph, graph_info
|
||||
|
||||
def similar_ids_groups(
|
||||
dupl_id_graph: Graph,
|
||||
) -> Iterator[list[PandasIndex]]:
|
||||
# groups of connected indices
|
||||
ids_groups = cast(Iterator[set[PandasIndex]],
|
||||
nx.connected_components(G=dupl_id_graph))
|
||||
|
||||
for id_group in ids_groups:
|
||||
yield list(id_group)
|
||||
"""
|
||||
# def similar_ids_connection_graph(
|
||||
# similar_idx_pairs: list[tuple[PandasIndex, PandasIndex]],
|
||||
# ) -> tuple[Graph, dict[str, int]]:
|
||||
# # build index graph to obtain graph of connected (similar) indices
|
||||
# # use this graph to get connected components (indices which belong together)
|
||||
# # retain semantic connection on whole dataset
|
||||
# similar_id_graph = nx.Graph()
|
||||
# for (idx1, idx2) in similar_idx_pairs:
|
||||
# # inplace operation, parent/child do not really exist in undirected graph
|
||||
# update_graph(graph=similar_id_graph, parent=idx1, child=idx2)
|
||||
|
||||
# graph_info = get_graph_metadata(graph=similar_id_graph, logging=True)
|
||||
|
||||
# return similar_id_graph, graph_info
|
||||
|
||||
# def similar_ids_groups(
|
||||
# dupl_id_graph: Graph,
|
||||
# ) -> Iterator[list[PandasIndex]]:
|
||||
# # groups of connected indices
|
||||
# ids_groups = cast(Iterator[set[PandasIndex]],
|
||||
# nx.connected_components(G=dupl_id_graph))
|
||||
|
||||
# for id_group in ids_groups:
|
||||
# yield list(id_group)
|
||||
|
||||
|
||||
# merge duplicates
|
||||
def merge_similarity_dupl_old(
|
||||
data: DataFrame,
|
||||
dupl_idx_pairs: list[tuple[PandasIndex, PandasIndex]],
|
||||
) -> tuple[DataFrame]:
|
||||
# copy pre-cleaned data
|
||||
temp = data.copy()
|
||||
index = temp.index
|
||||
# logger.info("Start merging of similarity candidates...")
|
||||
# # merge duplicates
|
||||
# def merge_similarity_dupl_old(
|
||||
# data: DataFrame,
|
||||
# dupl_idx_pairs: list[tuple[PandasIndex, PandasIndex]],
|
||||
# ) -> tuple[DataFrame]:
|
||||
# # copy pre-cleaned data
|
||||
# temp = data.copy()
|
||||
# index = temp.index
|
||||
# # logger.info("Start merging of similarity candidates...")
|
||||
|
||||
# iterate over index pairs
|
||||
for i1, i2 in tqdm(dupl_idx_pairs):
|
||||
# if an entry does not exist any more, skip this pair
|
||||
if i1 not in index or i2 not in index:
|
||||
continue
|
||||
# # iterate over index pairs
|
||||
# for i1, i2 in tqdm(dupl_idx_pairs):
|
||||
# # if an entry does not exist any more, skip this pair
|
||||
# if i1 not in index or i2 not in index:
|
||||
# continue
|
||||
|
||||
# merge num occur
|
||||
num_occur1 = temp.at[i1, 'num_occur']
|
||||
num_occur2 = temp.at[i2, 'num_occur']
|
||||
new_num_occur = num_occur1 + num_occur2
|
||||
# # merge num occur
|
||||
# num_occur1 = temp.at[i1, 'num_occur']
|
||||
# num_occur2 = temp.at[i2, 'num_occur']
|
||||
# new_num_occur = num_occur1 + num_occur2
|
||||
|
||||
# merge associated object ids
|
||||
assoc_ids1 = temp.at[i1, 'assoc_obj_ids']
|
||||
assoc_ids2 = temp.at[i2, 'assoc_obj_ids']
|
||||
new_assoc_ids = np.append(assoc_ids1, assoc_ids2)
|
||||
new_assoc_ids = np.unique(new_assoc_ids.flatten())
|
||||
# # merge associated object ids
|
||||
# assoc_ids1 = temp.at[i1, 'assoc_obj_ids']
|
||||
# assoc_ids2 = temp.at[i2, 'assoc_obj_ids']
|
||||
# new_assoc_ids = np.append(assoc_ids1, assoc_ids2)
|
||||
# new_assoc_ids = np.unique(new_assoc_ids.flatten())
|
||||
|
||||
# recalculate num associated obj ids
|
||||
new_num_assoc_obj_ids = len(new_assoc_ids)
|
||||
# # recalculate num associated obj ids
|
||||
# new_num_assoc_obj_ids = len(new_assoc_ids)
|
||||
|
||||
# write properties to first entry
|
||||
temp.at[i1, 'num_occur'] = new_num_occur
|
||||
temp.at[i1, 'assoc_obj_ids'] = new_assoc_ids
|
||||
temp.at[i1, 'num_assoc_obj_ids'] = new_num_assoc_obj_ids
|
||||
# # write properties to first entry
|
||||
# temp.at[i1, 'num_occur'] = new_num_occur
|
||||
# temp.at[i1, 'assoc_obj_ids'] = new_assoc_ids
|
||||
# temp.at[i1, 'num_assoc_obj_ids'] = new_num_assoc_obj_ids
|
||||
|
||||
# drop second entry
|
||||
temp = temp.drop(index=i2)
|
||||
index = temp.index
|
||||
# # drop second entry
|
||||
# temp = temp.drop(index=i2)
|
||||
# index = temp.index
|
||||
|
||||
# logger.info("Similarity candidates merged successfully.")
|
||||
# # logger.info("Similarity candidates merged successfully.")
|
||||
|
||||
return (temp,)
|
||||
# return (temp,)
|
||||
|
||||
|
||||
# ** debugging and evaluation
|
||||
def choose_cosSim_dupl_candidates(
|
||||
cosineSim_filt: Series,
|
||||
embds: dict[int, tuple[Embedding, str]],
|
||||
) -> tuple[DataFrame, list[tuple[PandasIndex, PandasIndex]]]:
|
||||
"""providing an overview of candidates with a similarity score greater than
|
||||
given threshold, but decision is made manually by iterating through the candidates
|
||||
with user interaction; more suitable for debugging purposes
|
||||
# def choose_cosSim_dupl_candidates(
|
||||
# cosineSim_filt: Series,
|
||||
# embds: dict[int, tuple[Embedding, str]],
|
||||
# ) -> tuple[DataFrame, list[tuple[PandasIndex, PandasIndex]]]:
|
||||
# """providing an overview of candidates with a similarity score greater than
|
||||
# given threshold, but decision is made manually by iterating through the candidates
|
||||
# with user interaction; more suitable for debugging purposes
|
||||
|
||||
Returns
|
||||
-------
|
||||
DataFrame
|
||||
contains indices, corresponding texts and similarity score to evaluate results
|
||||
list[tuple[Index, Index]]
|
||||
list containing relevant index pairs for entries with similarity score greater than
|
||||
given threshold
|
||||
"""
|
||||
# Returns
|
||||
# -------
|
||||
# DataFrame
|
||||
# contains indices, corresponding texts and similarity score to evaluate results
|
||||
# list[tuple[Index, Index]]
|
||||
# list containing relevant index pairs for entries with similarity score greater than
|
||||
# given threshold
|
||||
# """
|
||||
|
||||
# compare found duplicates
|
||||
columns = ['idx1', 'text1', 'idx2', 'text2', 'score']
|
||||
df_candidates = pd.DataFrame(columns=columns)
|
||||
# # compare found duplicates
|
||||
# columns = ['idx1', 'text1', 'idx2', 'text2', 'score']
|
||||
# df_candidates = pd.DataFrame(columns=columns)
|
||||
|
||||
index_pairs: list[tuple[PandasIndex, PandasIndex]] = []
|
||||
# index_pairs: list[tuple[PandasIndex, PandasIndex]] = []
|
||||
|
||||
for (idx1, idx2), score in cosineSim_filt.items(): # type: ignore
|
||||
# get texts for comparison
|
||||
text1 = embds[idx1][1]
|
||||
text2 = embds[idx2][1]
|
||||
# get decision
|
||||
print('---------- New Decision ----------')
|
||||
print('text1:\n', text1, '\n', flush=True)
|
||||
print('text2:\n', text2, '\n', flush=True)
|
||||
decision = input('Please enter >>y<< if this is a duplicate, else hit enter:')
|
||||
# for (idx1, idx2), score in cosineSim_filt.items(): # type: ignore
|
||||
# # get texts for comparison
|
||||
# text1 = embds[idx1][1]
|
||||
# text2 = embds[idx2][1]
|
||||
# # get decision
|
||||
# print('---------- New Decision ----------')
|
||||
# print('text1:\n', text1, '\n', flush=True)
|
||||
# print('text2:\n', text2, '\n', flush=True)
|
||||
# decision = input('Please enter >>y<< if this is a duplicate, else hit enter:')
|
||||
|
||||
if not decision == 'y':
|
||||
continue
|
||||
# if not decision == 'y':
|
||||
# continue
|
||||
|
||||
# get text content from embedding as second tuple entry
|
||||
content = [
|
||||
[
|
||||
idx1,
|
||||
text1,
|
||||
idx2,
|
||||
text2,
|
||||
score,
|
||||
]
|
||||
]
|
||||
df_conc = pd.DataFrame(columns=columns, data=content)
|
||||
# # get text content from embedding as second tuple entry
|
||||
# content = [
|
||||
# [
|
||||
# idx1,
|
||||
# text1,
|
||||
# idx2,
|
||||
# text2,
|
||||
# score,
|
||||
# ]
|
||||
# ]
|
||||
# df_conc = pd.DataFrame(columns=columns, data=content)
|
||||
|
||||
df_candidates = pd.concat([df_candidates, df_conc])
|
||||
index_pairs.append((idx1, idx2))
|
||||
# df_candidates = pd.concat([df_candidates, df_conc])
|
||||
# index_pairs.append((idx1, idx2))
|
||||
|
||||
return df_candidates, index_pairs
|
||||
# return df_candidates, index_pairs
|
||||
|
||||
@@ -22,7 +22,7 @@ pattern_escape_newline = re.compile(r'[\n]+')
|
||||
pattern_escape_seq = re.compile(r'[\t\n\r\f\v]+')
|
||||
pattern_escape_seq_sentences = re.compile(r' *[\t\n\r\f\v]+')
|
||||
pattern_repeated_chars = re.compile(r'[,;.:!?\-_+]+(?=[,;.:!?\-_+])')
|
||||
pattern_dates = re.compile(r'(\d{1,2}\.)?(\d{1,2}\.)([\d]{2,4})?')
|
||||
pattern_dates = re.compile(r'(\d{1,2}\.)?(\d{1,2}\.)?([\d]{2,4})?')
|
||||
pattern_whitespace = re.compile(r'[ ]{2,}')
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ def clean_string_slim(string: str) -> str:
|
||||
cleaned entry
|
||||
"""
|
||||
# remove special chars
|
||||
string = pattern_escape_newline.sub('. ', string)
|
||||
# string = pattern_escape_newline.sub(' ', string)
|
||||
string = pattern_escape_seq.sub(' ', string)
|
||||
string = pattern_repeated_chars.sub('', string)
|
||||
# string = pattern_dates.sub('', string)
|
||||
@@ -127,7 +127,7 @@ def candidates_by_index(
|
||||
|
||||
def similar_index_connection_graph(
|
||||
similar_idx_pairs: Iterable[tuple[PandasIndex, PandasIndex]],
|
||||
) -> tuple[Graph, dict[str, int]]:
|
||||
) -> tuple[Graph, dict[str, float]]:
|
||||
# build index graph to obtain graph of connected (similar) indices
|
||||
# use this graph to get connected components (indices which belong together)
|
||||
# retain semantic connection on whole dataset
|
||||
|
||||
17
src/lang_main/config.py
Normal file
17
src/lang_main/config.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import tomllib
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def load_toml_config(
|
||||
path_to_toml: str | Path,
|
||||
) -> dict[str, Any]:
|
||||
with open(path_to_toml, 'rb') as f:
|
||||
data = tomllib.load(f)
|
||||
print('Loaded TOML config file successfully.', file=sys.stderr, flush=True)
|
||||
return data
|
||||
@@ -2,22 +2,21 @@ from enum import Enum # noqa: I001
|
||||
from importlib.util import find_spec
|
||||
from pathlib import Path
|
||||
from typing import Final
|
||||
import os
|
||||
|
||||
from sentence_transformers import SimilarityFunction
|
||||
|
||||
from lang_main import CONFIG, CYTO_PATH_STYLESHEET
|
||||
from lang_main import model_loader as m_load
|
||||
from lang_main import CONFIG, CYTO_PATH_STYLESHEET, BASE_PATH
|
||||
from lang_main.types import (
|
||||
CytoLayoutProperties,
|
||||
CytoLayouts,
|
||||
LanguageModels,
|
||||
ModelLoaderMap,
|
||||
ONNXExecutionProvider, # noqa: F401
|
||||
STFRBackends,
|
||||
STFRDeviceTypes,
|
||||
STFRModelArgs,
|
||||
STFRModels,
|
||||
STFRModelTypes,
|
||||
STFRQuantFilenames, # noqa: F401
|
||||
SpacyModelTypes,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -67,35 +66,29 @@ SKIP_TIME_ANALYSIS: Final[bool] = CONFIG['control']['time_analysis_skip']
|
||||
|
||||
# ** models
|
||||
# ** loading
|
||||
SPACY_MODEL_NAME: Final[str] = 'de_dep_news_trf'
|
||||
STFR_MODEL_NAME: Final[STFRModels] = STFRModels.ALL_MPNET_BASE_V2
|
||||
MODEL_BASE_FOLDER_NAME: Final[str] = 'lang-models'
|
||||
MODEL_BASE_FOLDER: Final[Path] = BASE_PATH / MODEL_BASE_FOLDER_NAME
|
||||
if not MODEL_BASE_FOLDER.exists():
|
||||
raise FileNotFoundError('Language model folder not found.')
|
||||
os.environ['SENTENCE_TRANSFORMERS_HOME'] = str(MODEL_BASE_FOLDER)
|
||||
SPACY_MODEL_NAME: Final[SpacyModelTypes] = SpacyModelTypes.DE_DEP_NEWS_TRF
|
||||
STFR_MODEL_NAME: Final[STFRModelTypes] = STFRModelTypes.ALL_MPNET_BASE_V2
|
||||
STFR_DEVICE: Final[STFRDeviceTypes] = STFRDeviceTypes.CPU
|
||||
STFR_SIMILARITY: Final[SimilarityFunction] = SimilarityFunction.COSINE
|
||||
STFR_BACKEND: Final[STFRBackends] = STFRBackends.TORCH
|
||||
STFR_MODEL_ARGS: Final[STFRModelArgs] = {}
|
||||
# STFR_MODEL_ARGS: Final[STFRModelArgs] = {
|
||||
# 'file_name': STFRQuantFilenames.ONNX_Q_UINT8,
|
||||
# 'provider': ONNXExecutionProvider.CPU,
|
||||
# 'export': False,
|
||||
# }
|
||||
MODEL_LOADER_MAP: Final[ModelLoaderMap] = {
|
||||
LanguageModels.SENTENCE_TRANSFORMER: {
|
||||
'func': m_load.load_sentence_transformer,
|
||||
'kwargs': {
|
||||
'model_name': STFR_MODEL_NAME,
|
||||
'similarity_func': STFR_SIMILARITY,
|
||||
'backend': STFR_BACKEND,
|
||||
'device': STFR_DEVICE,
|
||||
'model_kwargs': STFR_MODEL_ARGS,
|
||||
},
|
||||
},
|
||||
LanguageModels.SPACY: {
|
||||
'func': m_load.load_spacy,
|
||||
'kwargs': {
|
||||
'model_name': SPACY_MODEL_NAME,
|
||||
},
|
||||
},
|
||||
STFR_MODEL_ARGS_DEFAULT: STFRModelArgs = {}
|
||||
STFR_MODEL_ARGS_ONNX: STFRModelArgs = {
|
||||
'file_name': STFRQuantFilenames.ONNX_Q_UINT8,
|
||||
'provider': ONNXExecutionProvider.CPU,
|
||||
'export': False,
|
||||
}
|
||||
stfr_model_args: STFRModelArgs
|
||||
if STFR_BACKEND == STFRBackends.ONNX:
|
||||
stfr_model_args = STFR_MODEL_ARGS_ONNX
|
||||
else:
|
||||
stfr_model_args = STFR_MODEL_ARGS_DEFAULT
|
||||
|
||||
STFR_MODEL_ARGS: Final[STFRModelArgs] = stfr_model_args
|
||||
# ** language dependency analysis
|
||||
# ** POS
|
||||
# POS_OF_INTEREST: frozenset[str] = frozenset(['NOUN', 'PROPN', 'ADJ', 'VERB', 'AUX'])
|
||||
|
||||
@@ -1,3 +1,9 @@
|
||||
# ** meta exceptions
|
||||
class LanguageModelNotFoundError(Exception):
|
||||
"""Error raised if a given language model could not be loaded successfully"""
|
||||
|
||||
|
||||
# ** token graph exceptions
|
||||
class EdgePropertyNotContainedError(Exception):
|
||||
"""Error raised if a needed edge property is not contained in graph edges"""
|
||||
|
||||
@@ -21,8 +27,6 @@ class DependencyMissingError(Exception):
|
||||
|
||||
|
||||
# ** pipelines to perform given actions on dataset in a customisable manner
|
||||
|
||||
|
||||
class NoPerformableActionError(Exception):
|
||||
"""Error describing that no action is available in the current pipeline"""
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import base64
|
||||
import pickle
|
||||
import shutil
|
||||
import tomllib
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@@ -33,15 +32,6 @@ def create_saving_folder(
|
||||
)
|
||||
|
||||
|
||||
def load_toml_config(
|
||||
path_to_toml: str | Path,
|
||||
) -> dict[str, Any]:
|
||||
with open(path_to_toml, 'rb') as f:
|
||||
data = tomllib.load(f)
|
||||
logger.info('Loaded TOML config file successfully.')
|
||||
return data
|
||||
|
||||
|
||||
# saving and loading using pickle
|
||||
# careful: pickling from unknown sources can be dangerous
|
||||
def save_pickle(
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
# lang_main: Config file
|
||||
[info]
|
||||
pkg = 'lang_main'
|
||||
|
||||
[paths]
|
||||
inputs = './inputs/'
|
||||
|
||||
@@ -5,6 +5,7 @@ from time import gmtime
|
||||
from typing import Final
|
||||
|
||||
from lang_main.constants import (
|
||||
BASE_PATH,
|
||||
ENABLE_LOGGING,
|
||||
LOGGING_TO_FILE,
|
||||
LOGGING_TO_STDERR,
|
||||
@@ -15,11 +16,11 @@ from lang_main.types import LoggingLevels
|
||||
logging.Formatter.converter = gmtime
|
||||
LOG_FMT: Final[str] = '%(asctime)s | lang_main:%(module)s:%(levelname)s | %(message)s'
|
||||
LOG_DATE_FMT: Final[str] = '%Y-%m-%d %H:%M:%S +0000'
|
||||
LOG_FILE_PATH: Final[Path] = Path.cwd() / 'lang-main.log'
|
||||
# logging.basicConfig(
|
||||
# format=LOG_FMT,
|
||||
# datefmt=LOG_DATE_FMT,
|
||||
# )
|
||||
LOG_FILE_FOLDER: Final[Path] = BASE_PATH / 'logs'
|
||||
if not LOG_FILE_FOLDER.exists():
|
||||
LOG_FILE_FOLDER.mkdir(parents=True)
|
||||
|
||||
LOG_FILE_PATH: Final[Path] = LOG_FILE_FOLDER / 'lang-main.log'
|
||||
|
||||
# ** formatters
|
||||
logger_all_formater = logging.Formatter(fmt=LOG_FMT, datefmt=LOG_DATE_FMT)
|
||||
|
||||
@@ -1,16 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Final,
|
||||
Literal,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
import spacy
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from sentence_transformers import SentenceTransformer, SimilarityFunction
|
||||
|
||||
from lang_main.constants import STFR_SIMILARITY
|
||||
from lang_main.constants import (
|
||||
SPACY_MODEL_NAME,
|
||||
STFR_BACKEND,
|
||||
STFR_DEVICE,
|
||||
STFR_MODEL_ARGS,
|
||||
STFR_MODEL_NAME,
|
||||
STFR_SIMILARITY,
|
||||
)
|
||||
from lang_main.errors import LanguageModelNotFoundError
|
||||
from lang_main.types import (
|
||||
LanguageModels,
|
||||
Model,
|
||||
@@ -20,9 +29,6 @@ from lang_main.types import (
|
||||
STFRDeviceTypes,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sentence_transformers import SimilarityFunction
|
||||
|
||||
|
||||
@overload
|
||||
def instantiate_model(
|
||||
@@ -53,14 +59,27 @@ def instantiate_model(
|
||||
def load_spacy(
|
||||
model_name: str,
|
||||
) -> SpacyModel:
|
||||
return spacy.load(model_name)
|
||||
try:
|
||||
spacy_model_obj = importlib.import_module(SPACY_MODEL_NAME)
|
||||
except ModuleNotFoundError:
|
||||
raise LanguageModelNotFoundError(
|
||||
(
|
||||
f'Could not find spaCy model >>{model_name}<<. '
|
||||
f'Check if it is installed correctly.'
|
||||
)
|
||||
)
|
||||
pretrained_model = cast(SpacyModel, spacy_model_obj.load())
|
||||
|
||||
return pretrained_model
|
||||
|
||||
|
||||
def load_sentence_transformer(
|
||||
model_name: str,
|
||||
similarity_func: SimilarityFunction = STFR_SIMILARITY,
|
||||
similarity_func: SimilarityFunction = SimilarityFunction.COSINE,
|
||||
backend: STFRBackends = STFRBackends.TORCH,
|
||||
device: STFRDeviceTypes = STFRDeviceTypes.CPU,
|
||||
local_files_only: bool = False,
|
||||
model_save_folder: str | None = None,
|
||||
model_kwargs: dict[str, Any] | None = None,
|
||||
) -> SentenceTransformer:
|
||||
return SentenceTransformer(
|
||||
@@ -68,5 +87,28 @@ def load_sentence_transformer(
|
||||
similarity_fn_name=similarity_func,
|
||||
backend=backend, # type: ignore Literal matches Enum
|
||||
device=device,
|
||||
cache_folder=model_save_folder,
|
||||
local_files_only=local_files_only,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
|
||||
|
||||
# ** configured model builder functions
|
||||
MODEL_LOADER_MAP: Final[ModelLoaderMap] = {
|
||||
LanguageModels.SENTENCE_TRANSFORMER: {
|
||||
'func': load_sentence_transformer,
|
||||
'kwargs': {
|
||||
'model_name': STFR_MODEL_NAME,
|
||||
'similarity_func': STFR_SIMILARITY,
|
||||
'backend': STFR_BACKEND,
|
||||
'device': STFR_DEVICE,
|
||||
'model_kwargs': STFR_MODEL_ARGS,
|
||||
},
|
||||
},
|
||||
LanguageModels.SPACY: {
|
||||
'func': load_spacy,
|
||||
'kwargs': {
|
||||
'model_name': SPACY_MODEL_NAME,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -30,7 +30,6 @@ from lang_main.constants import (
|
||||
DATE_COLS,
|
||||
FEATURE_NAME_OBJ_ID,
|
||||
MODEL_INPUT_FEATURES,
|
||||
MODEL_LOADER_MAP,
|
||||
NAME_DELTA_FEAT_TO_REPAIR,
|
||||
SAVE_PATH_FOLDER,
|
||||
THRESHOLD_AMOUNT_CHARACTERS,
|
||||
@@ -41,6 +40,7 @@ from lang_main.constants import (
|
||||
THRESHOLD_UNIQUE_TEXTS,
|
||||
UNIQUE_CRITERION_FEATURE,
|
||||
)
|
||||
from lang_main.model_loader import MODEL_LOADER_MAP
|
||||
from lang_main.pipelines.base import Pipeline
|
||||
from lang_main.types import EntryPoints, LanguageModels
|
||||
|
||||
|
||||
@@ -45,13 +45,20 @@ class ONNXExecutionProvider(enum.StrEnum):
|
||||
CPU = 'CPUExecutionProvider'
|
||||
|
||||
|
||||
class STFRModels(enum.StrEnum):
|
||||
class STFRModelTypes(enum.StrEnum):
|
||||
ALL_MPNET_BASE_V2 = 'all-mpnet-base-v2'
|
||||
ALL_DISTILROBERTA_V1 = 'all-distilroberta-v1'
|
||||
ALL_MINI_LM_L12_V2 = 'all-MiniLM-L12-v2'
|
||||
ALL_MINI_LM_L6_V2 = 'all-MiniLM-L6-v2'
|
||||
|
||||
|
||||
class SpacyModelTypes(enum.StrEnum):
|
||||
DE_CORE_NEWS_SM = 'de_core_news_sm'
|
||||
DE_CORE_NEWS_MD = 'de_core_news_md'
|
||||
DE_CORE_NEWS_LG = 'de_core_news_lg'
|
||||
DE_DEP_NEWS_TRF = 'de_dep_news_trf'
|
||||
|
||||
|
||||
class STFRQuantFilenames(enum.StrEnum):
|
||||
ONNX_Q_UINT8 = 'onnx/model_quint8_avx2.onnx'
|
||||
|
||||
|
||||
Reference in New Issue
Block a user