added new test cases

This commit is contained in:
Florian Förster
2024-11-14 16:40:00 +01:00
parent 6781b4a132
commit 5a789b7605
20 changed files with 2339 additions and 94 deletions

View File

@@ -198,8 +198,10 @@ def filter_graph_by_edge_weight(
graph: TokenGraph,
bound_lower: int | None,
bound_upper: int | None,
property: str = 'weight',
) -> TokenGraph:
"""filters all edges which are within the provided bounds
inclusive limits: bound_lower <= edge_weight <= bound_upper are retained
Parameters
----------
@@ -216,12 +218,12 @@ def filter_graph_by_edge_weight(
original_graph_edges = copy.deepcopy(graph.edges)
filtered_graph = graph.copy()
if not any([bound_lower, bound_upper]):
if not any((bound_lower, bound_upper)):
logger.warning('No bounds provided, returning original graph.')
return filtered_graph
for edge in original_graph_edges:
weight = typing.cast(int, filtered_graph[edge[0]][edge[1]]['weight'])
weight = typing.cast(int, filtered_graph[edge[0]][edge[1]][property])
if bound_lower is not None and weight < bound_lower:
filtered_graph.remove_edge(edge[0], edge[1])
if bound_upper is not None and weight > bound_upper:
@@ -329,14 +331,12 @@ def static_graph_analysis(
Parameters
----------
tk_graph_directed : TokenGraph
token graph (directed) and with rescaled edge weights
tk_graph_undirected : Graph
token graph (undirected) and with rescaled edge weights
token graph (directed)
Returns
-------
tuple[TokenGraph, Graph]
token graph (directed) and undirected version with added weighted degree
tuple[TokenGraph]
token graph (directed) with included undirected version and calculated KPIs
"""
graph = graph.copy()
graph.perform_static_analysis()
@@ -559,12 +559,12 @@ class TokenGraph(DiGraph):
return hash(self.__key())
"""
def copy(self) -> TokenGraph:
def copy(self) -> Self:
"""returns a (deep) copy of the graph
Returns
-------
TokenGraph
Self
deep copy of the graph
"""
return copy.deepcopy(self)
@@ -669,7 +669,7 @@ class TokenGraph(DiGraph):
return token_graph, undirected
def perform_static_analysis(self):
def perform_static_analysis(self) -> None:
"""calculate different metrics directly on the data of the underlying graphs
(directed and undirected)
@@ -717,16 +717,11 @@ class TokenGraph(DiGraph):
saving_path = self._save_prepare(path=path, filename=filename)
if directed:
target_graph = self._directed
elif not directed and self._undirected is not None:
target_graph = self._undirected
target_graph = self.directed
else:
raise ValueError('No undirected graph available.')
target_graph = self.undirected
save_to_GraphML(graph=target_graph, saving_path=saving_path)
# saving_path = saving_path.with_suffix('.graphml')
# nx.write_graphml(G=target_graph, path=saving_path)
# logger.info('Successfully saved graph as GraphML file under %s.', saving_path)
def to_pickle(
self,
@@ -743,13 +738,14 @@ class TokenGraph(DiGraph):
filename to be given, by default None
"""
saving_path = self._save_prepare(path=path, filename=filename)
saving_path = saving_path.with_suffix('.pickle')
saving_path = saving_path.with_suffix('.pkl')
save_pickle(obj=self, path=saving_path)
@classmethod
def from_file(
cls,
path: Path,
node_type_graphml: type = str,
) -> Self:
# !! no validity checks for pickle files
# !! GraphML files not correct because not all properties
@@ -757,7 +753,7 @@ class TokenGraph(DiGraph):
# TODO REWORK
match path.suffix:
case '.graphml':
graph = typing.cast(Self, nx.read_graphml(path, node_type=int))
graph = typing.cast(Self, nx.read_graphml(path, node_type=node_type_graphml))
logger.info('Successfully loaded graph from GraphML file %s.', path)
case '.pkl' | '.pickle':
graph = typing.cast(Self, load_pickle(path))
@@ -767,17 +763,18 @@ class TokenGraph(DiGraph):
return graph
@classmethod
def from_pickle(
cls,
path: str | Path,
) -> Self:
if isinstance(path, str):
path = Path(path)
# TODO check removal
# @classmethod
# def from_pickle(
# cls,
# path: str | Path,
# ) -> Self:
# if isinstance(path, str):
# path = Path(path)
if path.suffix not in ('.pkl', '.pickle'):
raise ValueError('File format not supported.')
# if path.suffix not in ('.pkl', '.pickle'):
# raise ValueError('File format not supported.')
graph = typing.cast(Self, load_pickle(path))
# graph = typing.cast(Self, load_pickle(path))
return graph
# return graph

View File

@@ -205,6 +205,30 @@ def numeric_pre_filter_feature(
bound_lower: int | None,
bound_upper: int | None,
) -> tuple[DataFrame]:
"""filter DataFrame for a given numerical feature regarding their bounds
bounds are inclusive: entries (bound_lower <= entry <= bound_upper) are retained
Parameters
----------
data : DataFrame
DataFrame to filter
feature : str
feature name to filter
bound_lower : int | None
lower bound of values to retain
bound_upper : int | None
upper bound of values to retain
Returns
-------
tuple[DataFrame]
filtered DataFrame
Raises
------
ValueError
if no bounds are provided, at least one bound must be set
"""
if not any([bound_lower, bound_upper]):
raise ValueError('No bounds for filtering provided')
@@ -228,7 +252,7 @@ def numeric_pre_filter_feature(
# a more robust identification of duplicates negating negative side effects
# of several disturbances like typos, escape characters, etc.
# build mapping of embeddings for given model
def merge_similarity_dupl(
def merge_similarity_duplicates(
data: DataFrame,
model: SentenceTransformer,
cos_sim_threshold: float,

View File

@@ -11,6 +11,7 @@ from lang_main.analysis.graphs import (
TokenGraph,
update_graph,
)
from lang_main.analysis.shared import pattern_dates
from lang_main.constants import (
POS_INDIRECT,
POS_OF_INTEREST,
@@ -38,21 +39,40 @@ def is_str_date(
string: str,
fuzzy: bool = False,
) -> bool:
"""not stable function to test strings for dates, not 100 percent reliable
Parameters
----------
string : str
string to check for dates
fuzzy : bool, optional
whether to use dateutils.parser.pase fuzzy capability, by default False
Returns
-------
bool
indicates whether date was found or not
"""
try:
# check if string is a number
# if length is greater than 8, it is not a date
int(string)
if len(string) > 8:
if len(string) not in {2, 4}:
return False
except ValueError:
# not a number
pass
try:
parse(string, fuzzy=fuzzy)
parse(string, fuzzy=fuzzy, dayfirst=True, yearfirst=False)
return True
except ValueError:
return False
date_found: bool = False
match = pattern_dates.search(string)
if match is None:
return date_found
date_found = any(match.groups())
return date_found
def obtain_relevant_descendants(
@@ -106,7 +126,7 @@ def add_doc_info_to_graph(
if not (token.pos_ in POS_OF_INTEREST or token.tag_ in TAG_OF_INTEREST):
continue
# skip token which are dates or times
if is_str_date(string=token.text):
if token.pos_ == 'NUM' and is_str_date(string=token.text):
continue
relevant_descendants = obtain_relevant_descendants(token=token)
@@ -252,32 +272,33 @@ def build_token_graph_simple(
return graph, docs_mapping
def build_token_graph_old(
data: DataFrame,
model: SpacyModel,
) -> tuple[TokenGraph]:
# empty NetworkX directed graph
# graph = nx.DiGraph()
graph = TokenGraph()
# TODO check removal
# def build_token_graph_old(
# data: DataFrame,
# model: SpacyModel,
# ) -> tuple[TokenGraph]:
# # empty NetworkX directed graph
# # graph = nx.DiGraph()
# graph = TokenGraph()
for row in tqdm(data.itertuples(), total=len(data)):
# obtain properties from tuple
# attribute names must match with preprocessed data
entry_text = cast(str, row.entry)
weight = cast(int, row.num_occur)
# for row in tqdm(data.itertuples(), total=len(data)):
# # obtain properties from tuple
# # attribute names must match with preprocessed data
# entry_text = cast(str, row.entry)
# weight = cast(int, row.num_occur)
# get spacy model output
doc = model(entry_text)
# # get spacy model output
# doc = model(entry_text)
add_doc_info_to_graph(
graph=graph,
doc=doc,
weight=weight,
)
# add_doc_info_to_graph(
# graph=graph,
# doc=doc,
# weight=weight,
# )
# metadata
graph.update_metadata()
# convert to undirected
graph.to_undirected()
# # metadata
# graph.update_metadata()
# # convert to undirected
# graph.to_undirected()
return (graph,)
# return (graph,)

View File

@@ -43,6 +43,9 @@ LOGGING_TO_FILE: Final[bool] = CONFIG['logging']['file']
LOGGING_TO_STDERR: Final[bool] = CONFIG['logging']['stderr']
LOGGING_DEFAULT_GRAPHS: Final[bool] = False
# ** pickling
PICKLE_PROTOCOL_VERSION: Final[int] = 5
# ** paths
input_path_conf = Path.cwd() / Path(CONFIG['paths']['inputs'])
INPUT_PATH_FOLDER: Final[Path] = input_path_conf.resolve()
@@ -91,12 +94,7 @@ else:
STFR_MODEL_ARGS: Final[STFRModelArgs] = stfr_model_args
# ** language dependency analysis
# ** POS
# POS_OF_INTEREST: frozenset[str] = frozenset(['NOUN', 'PROPN', 'ADJ', 'VERB', 'AUX'])
# POS_OF_INTEREST: frozenset[str] = frozenset(['NOUN', 'ADJ', 'VERB', 'AUX'])
# POS_OF_INTEREST: frozenset[str] = frozenset(['NOUN', 'PROPN'])
# POS_OF_INTEREST: frozenset[str] = frozenset(['NOUN', 'PROPN', 'VERB', 'AUX'])
POS_OF_INTEREST: frozenset[str] = frozenset(['NOUN', 'PROPN', 'VERB', 'AUX', 'ADV'])
# POS_INDIRECT: frozenset[str] = frozenset(['AUX', 'VERB'])
POS_OF_INTEREST: frozenset[str] = frozenset(['NOUN', 'PROPN', 'VERB', 'AUX', 'ADV', 'NUM'])
POS_INDIRECT: frozenset[str] = frozenset(['AUX'])
# ** TAG
# TAG_OF_INTEREST: frozenset[str] = frozenset(['ADJD'])

View File

@@ -4,6 +4,7 @@ import shutil
from pathlib import Path
from typing import Any
from lang_main.constants import PICKLE_PROTOCOL_VERSION
from lang_main.loggers import logger_shared_helpers as logger
@@ -39,7 +40,7 @@ def save_pickle(
path: str | Path,
) -> None:
with open(path, 'wb') as file:
pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL)
pickle.dump(obj, file, protocol=PICKLE_PROTOCOL_VERSION)
logger.info('Saved file successfully under %s', path)
@@ -56,7 +57,7 @@ def encode_to_base64_str(
obj: Any,
encoding: str = 'utf-8',
) -> str:
serialised = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
serialised = pickle.dumps(obj, protocol=PICKLE_PROTOCOL_VERSION)
b64_bytes = base64.b64encode(serialised)
return b64_bytes.decode(encoding=encoding)

View File

@@ -5,7 +5,7 @@ from lang_main.analysis import graphs
from lang_main.analysis.preprocessing import (
analyse_feature,
load_raw_data,
merge_similarity_dupl,
merge_similarity_duplicates,
numeric_pre_filter_feature,
remove_duplicates,
remove_NA,
@@ -100,7 +100,7 @@ def build_merge_duplicates_pipe() -> Pipeline:
},
)
pipe_merge.add(
merge_similarity_dupl,
merge_similarity_duplicates,
{
'model': STFR_MODEL,
'cos_sim_threshold': THRESHOLD_SIMILARITY,