added new test cases
This commit is contained in:
@@ -198,8 +198,10 @@ def filter_graph_by_edge_weight(
|
||||
graph: TokenGraph,
|
||||
bound_lower: int | None,
|
||||
bound_upper: int | None,
|
||||
property: str = 'weight',
|
||||
) -> TokenGraph:
|
||||
"""filters all edges which are within the provided bounds
|
||||
inclusive limits: bound_lower <= edge_weight <= bound_upper are retained
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -216,12 +218,12 @@ def filter_graph_by_edge_weight(
|
||||
original_graph_edges = copy.deepcopy(graph.edges)
|
||||
filtered_graph = graph.copy()
|
||||
|
||||
if not any([bound_lower, bound_upper]):
|
||||
if not any((bound_lower, bound_upper)):
|
||||
logger.warning('No bounds provided, returning original graph.')
|
||||
return filtered_graph
|
||||
|
||||
for edge in original_graph_edges:
|
||||
weight = typing.cast(int, filtered_graph[edge[0]][edge[1]]['weight'])
|
||||
weight = typing.cast(int, filtered_graph[edge[0]][edge[1]][property])
|
||||
if bound_lower is not None and weight < bound_lower:
|
||||
filtered_graph.remove_edge(edge[0], edge[1])
|
||||
if bound_upper is not None and weight > bound_upper:
|
||||
@@ -329,14 +331,12 @@ def static_graph_analysis(
|
||||
Parameters
|
||||
----------
|
||||
tk_graph_directed : TokenGraph
|
||||
token graph (directed) and with rescaled edge weights
|
||||
tk_graph_undirected : Graph
|
||||
token graph (undirected) and with rescaled edge weights
|
||||
token graph (directed)
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple[TokenGraph, Graph]
|
||||
token graph (directed) and undirected version with added weighted degree
|
||||
tuple[TokenGraph]
|
||||
token graph (directed) with included undirected version and calculated KPIs
|
||||
"""
|
||||
graph = graph.copy()
|
||||
graph.perform_static_analysis()
|
||||
@@ -559,12 +559,12 @@ class TokenGraph(DiGraph):
|
||||
return hash(self.__key())
|
||||
"""
|
||||
|
||||
def copy(self) -> TokenGraph:
|
||||
def copy(self) -> Self:
|
||||
"""returns a (deep) copy of the graph
|
||||
|
||||
Returns
|
||||
-------
|
||||
TokenGraph
|
||||
Self
|
||||
deep copy of the graph
|
||||
"""
|
||||
return copy.deepcopy(self)
|
||||
@@ -669,7 +669,7 @@ class TokenGraph(DiGraph):
|
||||
|
||||
return token_graph, undirected
|
||||
|
||||
def perform_static_analysis(self):
|
||||
def perform_static_analysis(self) -> None:
|
||||
"""calculate different metrics directly on the data of the underlying graphs
|
||||
(directed and undirected)
|
||||
|
||||
@@ -717,16 +717,11 @@ class TokenGraph(DiGraph):
|
||||
saving_path = self._save_prepare(path=path, filename=filename)
|
||||
|
||||
if directed:
|
||||
target_graph = self._directed
|
||||
elif not directed and self._undirected is not None:
|
||||
target_graph = self._undirected
|
||||
target_graph = self.directed
|
||||
else:
|
||||
raise ValueError('No undirected graph available.')
|
||||
target_graph = self.undirected
|
||||
|
||||
save_to_GraphML(graph=target_graph, saving_path=saving_path)
|
||||
# saving_path = saving_path.with_suffix('.graphml')
|
||||
# nx.write_graphml(G=target_graph, path=saving_path)
|
||||
# logger.info('Successfully saved graph as GraphML file under %s.', saving_path)
|
||||
|
||||
def to_pickle(
|
||||
self,
|
||||
@@ -743,13 +738,14 @@ class TokenGraph(DiGraph):
|
||||
filename to be given, by default None
|
||||
"""
|
||||
saving_path = self._save_prepare(path=path, filename=filename)
|
||||
saving_path = saving_path.with_suffix('.pickle')
|
||||
saving_path = saving_path.with_suffix('.pkl')
|
||||
save_pickle(obj=self, path=saving_path)
|
||||
|
||||
@classmethod
|
||||
def from_file(
|
||||
cls,
|
||||
path: Path,
|
||||
node_type_graphml: type = str,
|
||||
) -> Self:
|
||||
# !! no validity checks for pickle files
|
||||
# !! GraphML files not correct because not all properties
|
||||
@@ -757,7 +753,7 @@ class TokenGraph(DiGraph):
|
||||
# TODO REWORK
|
||||
match path.suffix:
|
||||
case '.graphml':
|
||||
graph = typing.cast(Self, nx.read_graphml(path, node_type=int))
|
||||
graph = typing.cast(Self, nx.read_graphml(path, node_type=node_type_graphml))
|
||||
logger.info('Successfully loaded graph from GraphML file %s.', path)
|
||||
case '.pkl' | '.pickle':
|
||||
graph = typing.cast(Self, load_pickle(path))
|
||||
@@ -767,17 +763,18 @@ class TokenGraph(DiGraph):
|
||||
|
||||
return graph
|
||||
|
||||
@classmethod
|
||||
def from_pickle(
|
||||
cls,
|
||||
path: str | Path,
|
||||
) -> Self:
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
# TODO check removal
|
||||
# @classmethod
|
||||
# def from_pickle(
|
||||
# cls,
|
||||
# path: str | Path,
|
||||
# ) -> Self:
|
||||
# if isinstance(path, str):
|
||||
# path = Path(path)
|
||||
|
||||
if path.suffix not in ('.pkl', '.pickle'):
|
||||
raise ValueError('File format not supported.')
|
||||
# if path.suffix not in ('.pkl', '.pickle'):
|
||||
# raise ValueError('File format not supported.')
|
||||
|
||||
graph = typing.cast(Self, load_pickle(path))
|
||||
# graph = typing.cast(Self, load_pickle(path))
|
||||
|
||||
return graph
|
||||
# return graph
|
||||
|
||||
@@ -205,6 +205,30 @@ def numeric_pre_filter_feature(
|
||||
bound_lower: int | None,
|
||||
bound_upper: int | None,
|
||||
) -> tuple[DataFrame]:
|
||||
"""filter DataFrame for a given numerical feature regarding their bounds
|
||||
bounds are inclusive: entries (bound_lower <= entry <= bound_upper) are retained
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : DataFrame
|
||||
DataFrame to filter
|
||||
feature : str
|
||||
feature name to filter
|
||||
bound_lower : int | None
|
||||
lower bound of values to retain
|
||||
bound_upper : int | None
|
||||
upper bound of values to retain
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple[DataFrame]
|
||||
filtered DataFrame
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
if no bounds are provided, at least one bound must be set
|
||||
"""
|
||||
if not any([bound_lower, bound_upper]):
|
||||
raise ValueError('No bounds for filtering provided')
|
||||
|
||||
@@ -228,7 +252,7 @@ def numeric_pre_filter_feature(
|
||||
# a more robust identification of duplicates negating negative side effects
|
||||
# of several disturbances like typos, escape characters, etc.
|
||||
# build mapping of embeddings for given model
|
||||
def merge_similarity_dupl(
|
||||
def merge_similarity_duplicates(
|
||||
data: DataFrame,
|
||||
model: SentenceTransformer,
|
||||
cos_sim_threshold: float,
|
||||
|
||||
@@ -11,6 +11,7 @@ from lang_main.analysis.graphs import (
|
||||
TokenGraph,
|
||||
update_graph,
|
||||
)
|
||||
from lang_main.analysis.shared import pattern_dates
|
||||
from lang_main.constants import (
|
||||
POS_INDIRECT,
|
||||
POS_OF_INTEREST,
|
||||
@@ -38,21 +39,40 @@ def is_str_date(
|
||||
string: str,
|
||||
fuzzy: bool = False,
|
||||
) -> bool:
|
||||
"""not stable function to test strings for dates, not 100 percent reliable
|
||||
|
||||
Parameters
|
||||
----------
|
||||
string : str
|
||||
string to check for dates
|
||||
fuzzy : bool, optional
|
||||
whether to use dateutils.parser.pase fuzzy capability, by default False
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
indicates whether date was found or not
|
||||
"""
|
||||
try:
|
||||
# check if string is a number
|
||||
# if length is greater than 8, it is not a date
|
||||
int(string)
|
||||
if len(string) > 8:
|
||||
if len(string) not in {2, 4}:
|
||||
return False
|
||||
except ValueError:
|
||||
# not a number
|
||||
pass
|
||||
|
||||
try:
|
||||
parse(string, fuzzy=fuzzy)
|
||||
parse(string, fuzzy=fuzzy, dayfirst=True, yearfirst=False)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
date_found: bool = False
|
||||
match = pattern_dates.search(string)
|
||||
if match is None:
|
||||
return date_found
|
||||
date_found = any(match.groups())
|
||||
return date_found
|
||||
|
||||
|
||||
def obtain_relevant_descendants(
|
||||
@@ -106,7 +126,7 @@ def add_doc_info_to_graph(
|
||||
if not (token.pos_ in POS_OF_INTEREST or token.tag_ in TAG_OF_INTEREST):
|
||||
continue
|
||||
# skip token which are dates or times
|
||||
if is_str_date(string=token.text):
|
||||
if token.pos_ == 'NUM' and is_str_date(string=token.text):
|
||||
continue
|
||||
|
||||
relevant_descendants = obtain_relevant_descendants(token=token)
|
||||
@@ -252,32 +272,33 @@ def build_token_graph_simple(
|
||||
return graph, docs_mapping
|
||||
|
||||
|
||||
def build_token_graph_old(
|
||||
data: DataFrame,
|
||||
model: SpacyModel,
|
||||
) -> tuple[TokenGraph]:
|
||||
# empty NetworkX directed graph
|
||||
# graph = nx.DiGraph()
|
||||
graph = TokenGraph()
|
||||
# TODO check removal
|
||||
# def build_token_graph_old(
|
||||
# data: DataFrame,
|
||||
# model: SpacyModel,
|
||||
# ) -> tuple[TokenGraph]:
|
||||
# # empty NetworkX directed graph
|
||||
# # graph = nx.DiGraph()
|
||||
# graph = TokenGraph()
|
||||
|
||||
for row in tqdm(data.itertuples(), total=len(data)):
|
||||
# obtain properties from tuple
|
||||
# attribute names must match with preprocessed data
|
||||
entry_text = cast(str, row.entry)
|
||||
weight = cast(int, row.num_occur)
|
||||
# for row in tqdm(data.itertuples(), total=len(data)):
|
||||
# # obtain properties from tuple
|
||||
# # attribute names must match with preprocessed data
|
||||
# entry_text = cast(str, row.entry)
|
||||
# weight = cast(int, row.num_occur)
|
||||
|
||||
# get spacy model output
|
||||
doc = model(entry_text)
|
||||
# # get spacy model output
|
||||
# doc = model(entry_text)
|
||||
|
||||
add_doc_info_to_graph(
|
||||
graph=graph,
|
||||
doc=doc,
|
||||
weight=weight,
|
||||
)
|
||||
# add_doc_info_to_graph(
|
||||
# graph=graph,
|
||||
# doc=doc,
|
||||
# weight=weight,
|
||||
# )
|
||||
|
||||
# metadata
|
||||
graph.update_metadata()
|
||||
# convert to undirected
|
||||
graph.to_undirected()
|
||||
# # metadata
|
||||
# graph.update_metadata()
|
||||
# # convert to undirected
|
||||
# graph.to_undirected()
|
||||
|
||||
return (graph,)
|
||||
# return (graph,)
|
||||
|
||||
@@ -43,6 +43,9 @@ LOGGING_TO_FILE: Final[bool] = CONFIG['logging']['file']
|
||||
LOGGING_TO_STDERR: Final[bool] = CONFIG['logging']['stderr']
|
||||
LOGGING_DEFAULT_GRAPHS: Final[bool] = False
|
||||
|
||||
# ** pickling
|
||||
PICKLE_PROTOCOL_VERSION: Final[int] = 5
|
||||
|
||||
# ** paths
|
||||
input_path_conf = Path.cwd() / Path(CONFIG['paths']['inputs'])
|
||||
INPUT_PATH_FOLDER: Final[Path] = input_path_conf.resolve()
|
||||
@@ -91,12 +94,7 @@ else:
|
||||
STFR_MODEL_ARGS: Final[STFRModelArgs] = stfr_model_args
|
||||
# ** language dependency analysis
|
||||
# ** POS
|
||||
# POS_OF_INTEREST: frozenset[str] = frozenset(['NOUN', 'PROPN', 'ADJ', 'VERB', 'AUX'])
|
||||
# POS_OF_INTEREST: frozenset[str] = frozenset(['NOUN', 'ADJ', 'VERB', 'AUX'])
|
||||
# POS_OF_INTEREST: frozenset[str] = frozenset(['NOUN', 'PROPN'])
|
||||
# POS_OF_INTEREST: frozenset[str] = frozenset(['NOUN', 'PROPN', 'VERB', 'AUX'])
|
||||
POS_OF_INTEREST: frozenset[str] = frozenset(['NOUN', 'PROPN', 'VERB', 'AUX', 'ADV'])
|
||||
# POS_INDIRECT: frozenset[str] = frozenset(['AUX', 'VERB'])
|
||||
POS_OF_INTEREST: frozenset[str] = frozenset(['NOUN', 'PROPN', 'VERB', 'AUX', 'ADV', 'NUM'])
|
||||
POS_INDIRECT: frozenset[str] = frozenset(['AUX'])
|
||||
# ** TAG
|
||||
# TAG_OF_INTEREST: frozenset[str] = frozenset(['ADJD'])
|
||||
|
||||
@@ -4,6 +4,7 @@ import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from lang_main.constants import PICKLE_PROTOCOL_VERSION
|
||||
from lang_main.loggers import logger_shared_helpers as logger
|
||||
|
||||
|
||||
@@ -39,7 +40,7 @@ def save_pickle(
|
||||
path: str | Path,
|
||||
) -> None:
|
||||
with open(path, 'wb') as file:
|
||||
pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
pickle.dump(obj, file, protocol=PICKLE_PROTOCOL_VERSION)
|
||||
logger.info('Saved file successfully under %s', path)
|
||||
|
||||
|
||||
@@ -56,7 +57,7 @@ def encode_to_base64_str(
|
||||
obj: Any,
|
||||
encoding: str = 'utf-8',
|
||||
) -> str:
|
||||
serialised = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
serialised = pickle.dumps(obj, protocol=PICKLE_PROTOCOL_VERSION)
|
||||
b64_bytes = base64.b64encode(serialised)
|
||||
return b64_bytes.decode(encoding=encoding)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from lang_main.analysis import graphs
|
||||
from lang_main.analysis.preprocessing import (
|
||||
analyse_feature,
|
||||
load_raw_data,
|
||||
merge_similarity_dupl,
|
||||
merge_similarity_duplicates,
|
||||
numeric_pre_filter_feature,
|
||||
remove_duplicates,
|
||||
remove_NA,
|
||||
@@ -100,7 +100,7 @@ def build_merge_duplicates_pipe() -> Pipeline:
|
||||
},
|
||||
)
|
||||
pipe_merge.add(
|
||||
merge_similarity_dupl,
|
||||
merge_similarity_duplicates,
|
||||
{
|
||||
'model': STFR_MODEL,
|
||||
'cos_sim_threshold': THRESHOLD_SIMILARITY,
|
||||
|
||||
Reference in New Issue
Block a user