added new test cases
This commit is contained in: parent 6781b4a132, commit 5a789b7605

notebooks/misc.ipynb (1917 changes): file diff suppressed because it is too large
notebooks/test.graphml (37 lines, new file):
@@ -0,0 +1,37 @@
+<?xml version='1.0' encoding='utf-8'?>
+<graphml xmlns="http://graphml.graphdrawing.org/xmlns" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://graphml.graphdrawing.org/xmlns http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd">
+  <key id="d1" for="edge" attr.name="weight" attr.type="long" />
+  <key id="d0" for="node" attr.name="degree_weighted" attr.type="long" />
+  <graph edgedefault="directed">
+    <node id="1">
+      <data key="d0">14</data>
+    </node>
+    <node id="2">
+      <data key="d0">10</data>
+    </node>
+    <node id="3">
+      <data key="d0">6</data>
+    </node>
+    <node id="4">
+      <data key="d0">12</data>
+    </node>
+    <edge source="1" target="2">
+      <data key="d1">1</data>
+    </edge>
+    <edge source="1" target="3">
+      <data key="d1">2</data>
+    </edge>
+    <edge source="1" target="4">
+      <data key="d1">5</data>
+    </edge>
+    <edge source="2" target="4">
+      <data key="d1">3</data>
+    </edge>
+    <edge source="2" target="1">
+      <data key="d1">6</data>
+    </edge>
+    <edge source="3" target="4">
+      <data key="d1">4</data>
+    </edge>
+  </graph>
+</graphml>
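The fixture encodes the same four-node graph the unit tests build in code (edge weights 1, 2, 5, 3, 6, 4 and precomputed degree_weighted values 14, 10, 6, 12). A minimal sketch of reading it back with NetworkX; the file path is illustrative and the node_type cast mirrors the node_type_graphml keyword introduced further down:

import networkx as nx

# node ids are stored as strings in GraphML; cast back to int for comparisons
g = nx.read_graphml('notebooks/test.graphml', node_type=int)

assert g.nodes[1]['degree_weighted'] == 14   # key d0, node attribute
assert g[1][2]['weight'] == 1                # key d1, edge 1 -> 2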
@@ -118,6 +118,8 @@ exclude_also = [
     "def __repr__",
     "def __str__",
     "@overload",
+    "if logging",
+    "if TYPE_CHECKING",
 ]

 [tool.coverage.html]
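The two new patterns extend coverage's exclude_also list so guard blocks that only run for the type checker or optional logging are not counted as missed lines. An illustrative (not project-specific) block that the "if TYPE_CHECKING" pattern would exclude:

from typing import TYPE_CHECKING

if TYPE_CHECKING:  # excluded from coverage by the new "if TYPE_CHECKING" pattern
    from pandas import DataFrame  # import only needed for type annotations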
@@ -198,8 +198,10 @@ def filter_graph_by_edge_weight(
     graph: TokenGraph,
     bound_lower: int | None,
     bound_upper: int | None,
+    property: str = 'weight',
 ) -> TokenGraph:
     """filters all edges which are within the provided bounds
+    inclusive limits: bound_lower <= edge_weight <= bound_upper are retained

     Parameters
     ----------
@@ -216,12 +218,12 @@ def filter_graph_by_edge_weight(
     original_graph_edges = copy.deepcopy(graph.edges)
     filtered_graph = graph.copy()

-    if not any([bound_lower, bound_upper]):
+    if not any((bound_lower, bound_upper)):
         logger.warning('No bounds provided, returning original graph.')
         return filtered_graph

     for edge in original_graph_edges:
-        weight = typing.cast(int, filtered_graph[edge[0]][edge[1]]['weight'])
+        weight = typing.cast(int, filtered_graph[edge[0]][edge[1]][property])
         if bound_lower is not None and weight < bound_lower:
             filtered_graph.remove_edge(edge[0], edge[1])
         if bound_upper is not None and weight > bound_upper:
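A hedged usage sketch of the extended signature; the keyword names come from the hunk above, while the graph object and the bounds are illustrative:

# keep only edges whose 'weight' lies in the inclusive range [2, 5];
# any other numeric edge attribute can be selected via the new 'property' parameter
filtered = filter_graph_by_edge_weight(
    graph=tk_graph,
    bound_lower=2,
    bound_upper=5,
    property='weight',
)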
@@ -329,14 +331,12 @@ def static_graph_analysis(
     Parameters
     ----------
     tk_graph_directed : TokenGraph
-        token graph (directed) and with rescaled edge weights
-    tk_graph_undirected : Graph
-        token graph (undirected) and with rescaled edge weights
+        token graph (directed)

     Returns
     -------
-    tuple[TokenGraph, Graph]
-        token graph (directed) and undirected version with added weighted degree
+    tuple[TokenGraph]
+        token graph (directed) with included undirected version and calculated KPIs
     """
     graph = graph.copy()
     graph.perform_static_analysis()
@@ -559,12 +559,12 @@ class TokenGraph(DiGraph):
             return hash(self.__key())
     """

-    def copy(self) -> TokenGraph:
+    def copy(self) -> Self:
         """returns a (deep) copy of the graph

         Returns
         -------
-        TokenGraph
+        Self
             deep copy of the graph
         """
         return copy.deepcopy(self)
@@ -669,7 +669,7 @@ class TokenGraph(DiGraph):

         return token_graph, undirected

-    def perform_static_analysis(self):
+    def perform_static_analysis(self) -> None:
         """calculate different metrics directly on the data of the underlying graphs
         (directed and undirected)

@@ -717,16 +717,11 @@ class TokenGraph(DiGraph):
         saving_path = self._save_prepare(path=path, filename=filename)

         if directed:
-            target_graph = self._directed
-        elif not directed and self._undirected is not None:
-            target_graph = self._undirected
+            target_graph = self.directed
         else:
-            raise ValueError('No undirected graph available.')
+            target_graph = self.undirected

         save_to_GraphML(graph=target_graph, saving_path=saving_path)
-        # saving_path = saving_path.with_suffix('.graphml')
-        # nx.write_graphml(G=target_graph, path=saving_path)
-        # logger.info('Successfully saved graph as GraphML file under %s.', saving_path)

     def to_pickle(
         self,
@@ -743,13 +738,14 @@ class TokenGraph(DiGraph):
             filename to be given, by default None
         """
         saving_path = self._save_prepare(path=path, filename=filename)
-        saving_path = saving_path.with_suffix('.pickle')
+        saving_path = saving_path.with_suffix('.pkl')
         save_pickle(obj=self, path=saving_path)

     @classmethod
     def from_file(
         cls,
         path: Path,
+        node_type_graphml: type = str,
     ) -> Self:
         # !! no validity checks for pickle files
         # !! GraphML files not correct because not all properties
@@ -757,7 +753,7 @@ class TokenGraph(DiGraph):
         # TODO REWORK
         match path.suffix:
             case '.graphml':
-                graph = typing.cast(Self, nx.read_graphml(path, node_type=int))
+                graph = typing.cast(Self, nx.read_graphml(path, node_type=node_type_graphml))
                 logger.info('Successfully loaded graph from GraphML file %s.', path)
             case '.pkl' | '.pickle':
                 graph = typing.cast(Self, load_pickle(path))
@@ -767,17 +763,18 @@ class TokenGraph(DiGraph):

         return graph

-    @classmethod
-    def from_pickle(
-        cls,
-        path: str | Path,
-    ) -> Self:
-        if isinstance(path, str):
-            path = Path(path)
-
-        if path.suffix not in ('.pkl', '.pickle'):
-            raise ValueError('File format not supported.')
-
-        graph = typing.cast(Self, load_pickle(path))
-
-        return graph
+    # TODO check removal
+    # @classmethod
+    # def from_pickle(
+    #     cls,
+    #     path: str | Path,
+    # ) -> Self:
+    #     if isinstance(path, str):
+    #         path = Path(path)
+
+    #     if path.suffix not in ('.pkl', '.pickle'):
+    #         raise ValueError('File format not supported.')
+
+    #     graph = typing.cast(Self, load_pickle(path))
+
+    #     return graph
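Taken together, the to_pickle/to_GraphML/from_file changes suggest a round trip along these lines. This is a sketch only: the output directory and the tkg variable are illustrative, while the method names, the .pkl suffix and node_type_graphml=int mirror the hunks above and the tests below:

from pathlib import Path

out_dir = Path('/tmp/graphs')  # illustrative target directory

tkg.to_pickle(out_dir, 'my_graph')                        # now written as my_graph.pkl
restored = TokenGraph.from_file(out_dir / 'my_graph.pkl')

tkg.to_GraphML(out_dir, 'my_graph', directed=False)       # undirected variant via self.undirected
restored_undirected = TokenGraph.from_file(
    out_dir / 'my_graph.graphml',
    node_type_graphml=int,  # GraphML stores node ids as str by default
)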
@@ -205,6 +205,30 @@ def numeric_pre_filter_feature(
     bound_lower: int | None,
     bound_upper: int | None,
 ) -> tuple[DataFrame]:
+    """filter DataFrame for a given numerical feature regarding their bounds
+    bounds are inclusive: entries (bound_lower <= entry <= bound_upper) are retained
+
+    Parameters
+    ----------
+    data : DataFrame
+        DataFrame to filter
+    feature : str
+        feature name to filter
+    bound_lower : int | None
+        lower bound of values to retain
+    bound_upper : int | None
+        upper bound of values to retain
+
+    Returns
+    -------
+    tuple[DataFrame]
+        filtered DataFrame
+
+    Raises
+    ------
+    ValueError
+        if no bounds are provided, at least one bound must be set
+    """
     if not any([bound_lower, bound_upper]):
         raise ValueError('No bounds for filtering provided')

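A behavioural sketch of what the new docstring promises, on a toy DataFrame (the column name and bound values are made up; the actual implementation is not part of this hunk):

import pandas as pd

df = pd.DataFrame({'entry': ['a', 'b', 'c'], 'len': [15, 16, 40]})

bound_lower, bound_upper = 16, None    # at least one bound must be set
mask = pd.Series(True, index=df.index)
if bound_lower is not None:
    mask &= df['len'] >= bound_lower   # inclusive lower bound
if bound_upper is not None:
    mask &= df['len'] <= bound_upper   # inclusive upper bound
filtered = df[mask]                    # only the row with len == 15 is dropped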
@@ -228,7 +252,7 @@ def numeric_pre_filter_feature(
 # a more robust identification of duplicates negating negative side effects
 # of several disturbances like typos, escape characters, etc.
 # build mapping of embeddings for given model
-def merge_similarity_dupl(
+def merge_similarity_duplicates(
     data: DataFrame,
     model: SentenceTransformer,
     cos_sim_threshold: float,

@@ -11,6 +11,7 @@ from lang_main.analysis.graphs import (
     TokenGraph,
     update_graph,
 )
+from lang_main.analysis.shared import pattern_dates
 from lang_main.constants import (
     POS_INDIRECT,
     POS_OF_INTEREST,
@@ -38,21 +39,40 @@ def is_str_date(
     string: str,
     fuzzy: bool = False,
 ) -> bool:
+    """unstable heuristic to test strings for dates, not 100 percent reliable
+
+    Parameters
+    ----------
+    string : str
+        string to check for dates
+    fuzzy : bool, optional
+        whether to use the dateutil.parser.parse fuzzy capability, by default False
+
+    Returns
+    -------
+    bool
+        indicates whether a date was found or not
+    """
     try:
         # check if string is a number
         # if length is greater than 8, it is not a date
         int(string)
-        if len(string) > 8:
+        if len(string) not in {2, 4}:
             return False
     except ValueError:
         # not a number
         pass

     try:
-        parse(string, fuzzy=fuzzy)
+        parse(string, fuzzy=fuzzy, dayfirst=True, yearfirst=False)
         return True
     except ValueError:
-        return False
+        date_found: bool = False
+        match = pattern_dates.search(string)
+        if match is None:
+            return date_found
+        date_found = any(match.groups())
+        return date_found


 def obtain_relevant_descendants(
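The expected behaviour of the reworked check, mirroring the new test cases added in tests/analysis/test_tokens.py further down (the regex behind pattern_dates is not part of this diff; the import path is assumed from the test imports):

from lang_main.analysis.tokens import is_str_date

assert is_str_date('22.05.', fuzzy=True)   # day-first parsing via dayfirst=True
assert is_str_date('22.05.2024')
assert not is_str_date('9009090909')       # purely numeric, but length not in {2, 4}
assert not is_str_date('hello347')         # neither parseable nor matched by pattern_dates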
@@ -106,7 +126,7 @@ def add_doc_info_to_graph(
         if not (token.pos_ in POS_OF_INTEREST or token.tag_ in TAG_OF_INTEREST):
             continue
         # skip tokens which are dates or times
-        if is_str_date(string=token.text):
+        if token.pos_ == 'NUM' and is_str_date(string=token.text):
             continue

         relevant_descendants = obtain_relevant_descendants(token=token)
@@ -252,32 +272,33 @@ def build_token_graph_simple(
     return graph, docs_mapping


-def build_token_graph_old(
-    data: DataFrame,
-    model: SpacyModel,
-) -> tuple[TokenGraph]:
-    # empty NetworkX directed graph
-    # graph = nx.DiGraph()
-    graph = TokenGraph()
-
-    for row in tqdm(data.itertuples(), total=len(data)):
-        # obtain properties from tuple
-        # attribute names must match with preprocessed data
-        entry_text = cast(str, row.entry)
-        weight = cast(int, row.num_occur)
-
-        # get spacy model output
-        doc = model(entry_text)
-
-        add_doc_info_to_graph(
-            graph=graph,
-            doc=doc,
-            weight=weight,
-        )
-
-    # metadata
-    graph.update_metadata()
-    # convert to undirected
-    graph.to_undirected()
-
-    return (graph,)
+# TODO check removal
+# def build_token_graph_old(
+#     data: DataFrame,
+#     model: SpacyModel,
+# ) -> tuple[TokenGraph]:
+#     # empty NetworkX directed graph
+#     # graph = nx.DiGraph()
+#     graph = TokenGraph()
+
+#     for row in tqdm(data.itertuples(), total=len(data)):
+#         # obtain properties from tuple
+#         # attribute names must match with preprocessed data
+#         entry_text = cast(str, row.entry)
+#         weight = cast(int, row.num_occur)
+
+#         # get spacy model output
+#         doc = model(entry_text)
+
+#         add_doc_info_to_graph(
+#             graph=graph,
+#             doc=doc,
+#             weight=weight,
+#         )
+
+#     # metadata
+#     graph.update_metadata()
+#     # convert to undirected
+#     graph.to_undirected()
+
+#     return (graph,)
@@ -43,6 +43,9 @@ LOGGING_TO_FILE: Final[bool] = CONFIG['logging']['file']
 LOGGING_TO_STDERR: Final[bool] = CONFIG['logging']['stderr']
 LOGGING_DEFAULT_GRAPHS: Final[bool] = False

+# ** pickling
+PICKLE_PROTOCOL_VERSION: Final[int] = 5
+
 # ** paths
 input_path_conf = Path.cwd() / Path(CONFIG['paths']['inputs'])
 INPUT_PATH_FOLDER: Final[Path] = input_path_conf.resolve()
@@ -91,12 +94,7 @@ else:
 STFR_MODEL_ARGS: Final[STFRModelArgs] = stfr_model_args
 # ** language dependency analysis
 # ** POS
-# POS_OF_INTEREST: frozenset[str] = frozenset(['NOUN', 'PROPN', 'ADJ', 'VERB', 'AUX'])
-# POS_OF_INTEREST: frozenset[str] = frozenset(['NOUN', 'ADJ', 'VERB', 'AUX'])
-# POS_OF_INTEREST: frozenset[str] = frozenset(['NOUN', 'PROPN'])
-# POS_OF_INTEREST: frozenset[str] = frozenset(['NOUN', 'PROPN', 'VERB', 'AUX'])
-POS_OF_INTEREST: frozenset[str] = frozenset(['NOUN', 'PROPN', 'VERB', 'AUX', 'ADV'])
-# POS_INDIRECT: frozenset[str] = frozenset(['AUX', 'VERB'])
+POS_OF_INTEREST: frozenset[str] = frozenset(['NOUN', 'PROPN', 'VERB', 'AUX', 'ADV', 'NUM'])
 POS_INDIRECT: frozenset[str] = frozenset(['AUX'])
 # ** TAG
 # TAG_OF_INTEREST: frozenset[str] = frozenset(['ADJD'])
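Adding NUM to POS_OF_INTEREST means numeral tokens can now reach the graph builder, which is presumably why the date check above was restricted to NUM tokens. A small spaCy sketch of that part-of-speech filter; the model loading call is illustrative, the project itself goes through model_loader:

import spacy

POS_OF_INTEREST = frozenset(['NOUN', 'PROPN', 'VERB', 'AUX', 'ADV', 'NUM'])  # as in the new constant

nlp = spacy.load('de_core_news_sm')  # illustrative; any German pipeline works
doc = nlp('Ich ging am 22.05. schnell über die Wiese.')

# keep only tokens whose coarse POS tag is of interest (date-like NUM tokens are dropped later)
kept = [token.text for token in doc if token.pos_ in POS_OF_INTEREST]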
@@ -4,6 +4,7 @@ import shutil
 from pathlib import Path
 from typing import Any

+from lang_main.constants import PICKLE_PROTOCOL_VERSION
 from lang_main.loggers import logger_shared_helpers as logger


@@ -39,7 +40,7 @@ def save_pickle(
     path: str | Path,
 ) -> None:
     with open(path, 'wb') as file:
-        pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL)
+        pickle.dump(obj, file, protocol=PICKLE_PROTOCOL_VERSION)
     logger.info('Saved file successfully under %s', path)


@@ -56,7 +57,7 @@ def encode_to_base64_str(
     obj: Any,
     encoding: str = 'utf-8',
 ) -> str:
-    serialised = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
+    serialised = pickle.dumps(obj, protocol=PICKLE_PROTOCOL_VERSION)
     b64_bytes = base64.b64encode(serialised)
     return b64_bytes.decode(encoding=encoding)
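Pinning the protocol to a named constant keeps pickles stable across interpreter versions, since pickle.HIGHEST_PROTOCOL can change between Python releases while protocol 5 stays fixed. A self-contained sketch of the dump/encode path with the pinned protocol (payload is illustrative):

import base64
import pickle

PICKLE_PROTOCOL_VERSION = 5  # mirrors the new constant in lang_main.constants

payload = {'nodes': 4, 'edges': 6}
raw = pickle.dumps(payload, protocol=PICKLE_PROTOCOL_VERSION)

encoded = base64.b64encode(raw).decode('utf-8')          # as in encode_to_base64_str
assert pickle.loads(base64.b64decode(encoded)) == payload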
@@ -5,7 +5,7 @@ from lang_main.analysis import graphs
 from lang_main.analysis.preprocessing import (
     analyse_feature,
     load_raw_data,
-    merge_similarity_dupl,
+    merge_similarity_duplicates,
     numeric_pre_filter_feature,
     remove_duplicates,
     remove_NA,
@@ -100,7 +100,7 @@ def build_merge_duplicates_pipe() -> Pipeline:
         },
     )
     pipe_merge.add(
-        merge_similarity_dupl,
+        merge_similarity_duplicates,
         {
             'model': STFR_MODEL,
             'cos_sim_threshold': THRESHOLD_SIMILARITY,
tests/_comparison_results/analyse_feature.pkl (new binary file, not shown)
tests/_comparison_results/merge_cands.xlsx (new binary file, not shown)
tests/_comparison_results/merge_similarity_candidates.pkl (new binary file, not shown)
tests/_comparison_results/numeric_pre_filter.pkl (new binary file, not shown)
tests/_comparison_results/tk_graph_built.pkl (new binary file, not shown)
@@ -2,6 +2,7 @@ import networkx as nx
 import pytest

 from lang_main.analysis import graphs
+from lang_main.errors import EmptyEdgesError, EmptyGraphError, EdgePropertyNotContainedError

 TK_GRAPH_NAME = 'TEST_TOKEN_GRAPH'

@@ -40,13 +41,18 @@ def build_init_graph(token_graph: bool):


 @pytest.fixture(scope='module')
-def graph():
+def graph() -> graphs.DiGraph:
     return build_init_graph(token_graph=False)


 @pytest.fixture(scope='module')
-def tk_graph():
-    return build_init_graph(token_graph=True)
+def tk_graph() -> graphs.TokenGraph:
+    return build_init_graph(token_graph=True)  # type: ignore


+@pytest.fixture(scope='module')
+def tk_graph_undirected(tk_graph) -> graphs.Graph:
+    return tk_graph.undirected
+
+
 def test_graph_size(graph):
@@ -61,7 +67,45 @@ def test_save_to_GraphML(graph, tmp_path):
     assert saved_file.exists()


-def test_metadata_retrieval(graph):
+def test_save_load_pickle_tk_graph(tk_graph, tmp_path):
+    filename = 'test_save_tkg'
+    tk_graph.to_pickle(tmp_path, filename)
+    load_pth = (tmp_path / filename).with_suffix('.pkl')
+    assert load_pth.exists()
+    loaded_graph = graphs.TokenGraph.from_file(load_pth)
+    assert loaded_graph.nodes == tk_graph.nodes
+    assert loaded_graph.edges == tk_graph.edges
+    filename = None
+    tk_graph.to_pickle(tmp_path, filename)
+    load_pth = (tmp_path / tk_graph.name).with_suffix('.pkl')
+    assert load_pth.exists()
+    loaded_graph = graphs.TokenGraph.from_file(load_pth)
+    assert loaded_graph.nodes == tk_graph.nodes
+    assert loaded_graph.edges == tk_graph.edges
+
+
+@pytest.mark.parametrize(
+    'import_graph,directed', [('tk_graph', True), ('tk_graph_undirected', False)]
+)
+def test_save_load_GraphML_tk_graph(import_graph, tk_graph, directed, tmp_path, request):
+    test_graph = request.getfixturevalue(import_graph)
+    filename = 'test_save_tkg'
+    tk_graph.to_GraphML(tmp_path, filename, directed=directed)
+    load_pth = (tmp_path / filename).with_suffix('.graphml')
+    assert load_pth.exists()
+    loaded_graph = graphs.TokenGraph.from_file(load_pth, node_type_graphml=int)
+    assert loaded_graph.nodes == test_graph.nodes
+    assert loaded_graph.edges == test_graph.edges
+    filename = None
+    tk_graph.to_GraphML(tmp_path, filename, directed=directed)
+    load_pth = (tmp_path / tk_graph.name).with_suffix('.graphml')
+    assert load_pth.exists()
+    loaded_graph = graphs.TokenGraph.from_file(load_pth, node_type_graphml=int)
+    assert loaded_graph.nodes == test_graph.nodes
+    assert loaded_graph.edges == test_graph.edges
+
+
+def test_get_graph_metadata(graph):
     metadata = graphs.get_graph_metadata(graph)
     assert metadata['num_nodes'] == 4
     assert metadata['num_edges'] == 6
@@ -72,7 +116,7 @@ def test_metadata_retrieval(graph):
     assert metadata['total_memory'] == 448


-def test_graph_update_batch():
+def test_update_graph_batch():
     graph_obj = build_init_graph(token_graph=False)
     graphs.update_graph(graph_obj, batch=((4, 5), (5, 6)), weight_connection=8)
     metadata = graphs.get_graph_metadata(graph_obj)
@@ -82,7 +126,7 @@ def test_graph_update_batch():
     assert metadata['max_edge_weight'] == 8


-def test_graph_update_single_new():
+def test_update_graph_single_new():
     graph_obj = build_init_graph(token_graph=False)
     graphs.update_graph(graph_obj, parent=4, child=5, weight_connection=7)
     metadata = graphs.get_graph_metadata(graph_obj)
@@ -92,7 +136,7 @@ def test_graph_update_single_new():
     assert metadata['max_edge_weight'] == 7


-def test_graph_update_single_existing():
+def test_update_graph_single_existing():
     graph_obj = build_init_graph(token_graph=False)
     graphs.update_graph(graph_obj, parent=1, child=4, weight_connection=5)
     metadata = graphs.get_graph_metadata(graph_obj)
@@ -103,13 +147,13 @@ def test_graph_update_single_existing():


 @pytest.mark.parametrize('cast_int', [True, False])
-def test_graph_undirected_conversion(graph, cast_int):
+def test_convert_graph_to_undirected(graph, cast_int):
     graph_undir = graphs.convert_graph_to_undirected(graph, cast_int=cast_int)
     # edges: (1, 2, w=1) and (2, 1, w=6) --> undirected: (1, 2, w=7)
     assert graph_undir[1][2]['weight'] == pytest.approx(7.0)


-def test_graph_cytoscape_conversion(graph):
+def test_convert_graph_to_cytoscape(graph):
     cyto_graph, weight_data = graphs.convert_graph_to_cytoscape(graph)
     node = cyto_graph[0]
     edge = cyto_graph[-1]
@@ -144,7 +188,17 @@ def test_tk_graph_properties(tk_graph):
     assert metadata_undirected['total_memory'] == 392


-def test_graph_degree_filter(tk_graph):
+def test_filter_graph_by_edge_weight(tk_graph):
+    filtered_graph = graphs.filter_graph_by_edge_weight(
+        tk_graph,
+        bound_lower=2,
+        bound_upper=5,
+    )
+    assert not filtered_graph.has_edge(1, 2)
+    assert not filtered_graph.has_edge(2, 1)
+
+
+def test_filter_graph_by_node_degree(tk_graph):
     filtered_graph = graphs.filter_graph_by_node_degree(
         tk_graph,
         bound_lower=3,
@@ -153,7 +207,7 @@ def test_graph_degree_filter(tk_graph):
     assert len(filtered_graph.nodes) == 2


-def test_graph_edge_number_filter(tk_graph):
+def test_filter_graph_by_number_edges(tk_graph):
     number_edges_limit = 1
     filtered_graph = graphs.filter_graph_by_number_edges(
         tk_graph,
@@ -166,3 +220,75 @@ def test_graph_edge_number_filter(tk_graph):
         bound_upper=None,
     )
     assert len(filtered_graph.nodes) == 2, 'one edge should result in only two nodes'
+
+
+def test_add_weighted_degree():
+    graph_obj = build_init_graph(token_graph=False)
+    property_name = 'degree_weighted'
+    graphs.add_weighted_degree(graph_obj, 'weight', property_name)
+    assert graph_obj.nodes[1][property_name] == 14
+    assert graph_obj.nodes[2][property_name] == 10
+    assert graph_obj.nodes[3][property_name] == 6
+
+
+def test_static_graph_analysis():
+    graph_obj = build_init_graph(token_graph=True)
+    (graph_obj,) = graphs.static_graph_analysis(graph_obj)  # type: ignore
+    property_name = 'degree_weighted'
+    assert graph_obj.nodes[1][property_name] == 14
+    assert graph_obj.nodes[2][property_name] == 10
+    assert graph_obj.nodes[3][property_name] == 6
+    assert graph_obj.undirected.nodes[1][property_name] == 14
+    assert graph_obj.undirected.nodes[2][property_name] == 10
+    assert graph_obj.undirected.nodes[3][property_name] == 6
+
+
+def test_pipe_add_graph_metrics():
+    graph_obj = build_init_graph(token_graph=False)
+    graph_obj_undir = graphs.convert_graph_to_undirected(graph_obj, cast_int=True)
+    graph_collection = graphs.pipe_add_graph_metrics(graph_obj, graph_obj_undir)
+    property_name = 'degree_weighted'
+    assert graph_collection[0].nodes[1][property_name] == 14
+    assert graph_collection[0].nodes[2][property_name] == 10
+    assert graph_collection[0].nodes[3][property_name] == 6
+    assert graph_collection[1].nodes[1][property_name] == 14
+    assert graph_collection[1].nodes[2][property_name] == 10
+    assert graph_collection[1].nodes[3][property_name] == 6
+
+
+def test_pipe_rescale_graph_edge_weights(tk_graph):
+    rescaled_tkg, rescaled_undir = graphs.pipe_rescale_graph_edge_weights(tk_graph)
+    assert rescaled_tkg[2][1]['weight'] == pytest.approx(1.0)
+    assert rescaled_tkg[1][2]['weight'] == pytest.approx(0.0952)
+    assert rescaled_undir[2][1]['weight'] == pytest.approx(1.0)
+    assert rescaled_undir[1][2]['weight'] == pytest.approx(1.0)
+
+
+@pytest.mark.parametrize('import_graph', ['graph', 'tk_graph'])
+def test_rescale_edge_weights(import_graph, request):
+    test_graph = request.getfixturevalue(import_graph)
+    rescaled_graph = graphs.rescale_edge_weights(test_graph)
+    assert rescaled_graph[2][1]['weight'] == pytest.approx(1.0)
+    assert rescaled_graph[1][2]['weight'] == pytest.approx(0.0952)
+
+
+@pytest.mark.parametrize('import_graph', ['graph', 'tk_graph'])
+def test_verify_property(import_graph, request):
+    test_graph = request.getfixturevalue(import_graph)
+    test_property = 'centrality'
+    with pytest.raises(EdgePropertyNotContainedError):
+        graphs.verify_property(test_graph, property=test_property)
+    test_property = 'weight'
+    assert not graphs.verify_property(test_graph, property=test_property)
+
+
+def test_verify_non_empty_graph():
+    graph = nx.Graph()
+    with pytest.raises(EmptyGraphError):
+        graphs.verify_non_empty_graph(graph)
+    graph.add_nodes_from([1, 2, 3, 4])
+    with pytest.raises(EmptyEdgesError):
+        graphs.verify_non_empty_graph(graph, including_edges=True)
+    assert not graphs.verify_non_empty_graph(graph, including_edges=False)
+    graph.add_edges_from([(1, 2), (1, 3), (2, 4)])
+    assert not graphs.verify_non_empty_graph(graph, including_edges=True)
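The expected values in these new assertions follow directly from the six weighted edges of the test fixture (the same graph that notebooks/test.graphml encodes). A quick cross-check with plain NetworkX, treating in- and out-edges together as the weighted degree:

import networkx as nx

g = nx.DiGraph()
g.add_weighted_edges_from(
    [(1, 2, 1), (1, 3, 2), (1, 4, 5), (2, 4, 3), (2, 1, 6), (3, 4, 4)]
)

# weighted degree = sum of weights over in- and out-edges, matching the asserted 14/10/6/12
assert dict(g.degree(weight='weight')) == {1: 14, 2: 10, 3: 6, 4: 12}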
@@ -2,8 +2,11 @@
 executed in a pipeline
 """

+from pathlib import Path
+from lang_main import model_loader
 from lang_main.analysis import preprocessing as ppc
 from lang_main.analysis import shared
+from lang_main.types import LanguageModels, STFRModelTypes


 def test_load_data(raw_data_path, raw_data_date_cols):
@@ -71,3 +74,43 @@ def test_analyse_feature(raw_data_path, raw_data_date_cols):

     (data,) = ppc.analyse_feature(data, target_feature=target_features[0])
     assert len(data) == 139
+
+
+def test_numeric_pre_filter_feature(data_analyse_feature, data_numeric_pre_filter_feature):
+    # Dataset contains 139 entries. The feature "len" has a minimum value of 15,
+    # which occurs only once. If all values >= 16 are retained, only one entry should be
+    # filtered. This results in a total number of 138 entries.
+    (data,) = ppc.numeric_pre_filter_feature(
+        data=data_analyse_feature,
+        feature='len',
+        bound_lower=16,
+        bound_upper=None,
+    )
+    assert len(data) == 138
+    eval_merged = data[['entry', 'len', 'num_occur', 'num_assoc_obj_ids']]
+    eval_benchmark = data_numeric_pre_filter_feature[
+        ['entry', 'len', 'num_occur', 'num_assoc_obj_ids']
+    ]
+    assert bool((eval_merged == eval_benchmark).all(axis=None))
+
+
+def test_merge_similarity_duplicates(data_analyse_feature, data_merge_similarity_duplicates):
+    cos_sim_threshold = 0.8
+    # reduce dataset to 10 entries
+    data = data_analyse_feature.iloc[:10]
+    model = model_loader.load_sentence_transformer(
+        model_name=STFRModelTypes.ALL_MPNET_BASE_V2,
+    )
+    (merged_data,) = ppc.merge_similarity_duplicates(
+        data=data,
+        model=model,
+        cos_sim_threshold=cos_sim_threshold,
+    )
+    # constructed use case: with this threshold,
+    # 2 out of 10 entries are merged into one
+    assert len(merged_data) == 9
+    eval_merged = merged_data[['entry', 'len', 'num_occur', 'num_assoc_obj_ids']]
+    eval_benchmark = data_merge_similarity_duplicates[
+        ['entry', 'len', 'num_occur', 'num_assoc_obj_ids']
+    ]
+    assert bool((eval_merged == eval_benchmark).all(axis=None))
tests/analysis/test_tokens.py (79 lines, new file):
@@ -0,0 +1,79 @@
+from pathlib import Path
+
+import pytest
+
+from lang_main import model_loader
+from lang_main.analysis import graphs, tokens
+from lang_main.types import SpacyModelTypes
+
+SENTENCE = (
+    'Ich ging am 22.05. mit ID 0912393 schnell über die Wiese zu einem Menschen, '
+    'um ihm zu helfen. Ich konnte nicht mit ansehen, wie er Probleme beim Tragen '
+    'seiner Tasche hatte.'
+)
+
+
+@pytest.fixture(scope='module')
+def spacy_model():
+    model = model_loader.load_spacy(
+        model_name=SpacyModelTypes.DE_CORE_NEWS_SM,
+    )
+    return model
+
+
+def test_pre_clean_word():
+    string = 'Öl3bad2024prüfung'
+    assert tokens.pre_clean_word(string) == 'Ölbadprüfung'
+
+
+def test_is_str_date():
+    string = '22.05.'
+    assert tokens.is_str_date(string, fuzzy=True)
+    string = '22.05.2024'
+    assert tokens.is_str_date(string)
+    string = '22-05-2024'
+    assert tokens.is_str_date(string)
+    string = '9009090909'
+    assert not tokens.is_str_date(string)
+    string = 'hello347'
+    assert not tokens.is_str_date(string)
+
+
+# TODO: depends on fixed Constants
+def test_obtain_relevant_descendants(spacy_model):
+    doc = spacy_model(SENTENCE)
+    sent1 = tuple(doc.sents)[0]  # first sentence
+    word1 = sent1[1]  # word "ging" (POS:VERB)
+    descendants1 = ('0912393', 'schnell', 'Wiese', 'Menschen')
+    rel_descs = tokens.obtain_relevant_descendants(word1)
+    rel_descs = tuple((token.text for token in rel_descs))
+    assert descendants1 == rel_descs
+
+    sent2 = tuple(doc.sents)[1]  # second sentence
+    word2 = sent2[1]  # word "konnte" (POS:AUX)
+    descendants2 = ('mit', 'Probleme', 'Tragen', 'Tasche')
+    rel_descs = tokens.obtain_relevant_descendants(word2)
+    rel_descs = tuple((token.text for token in rel_descs))
+    assert descendants2 == rel_descs
+
+
+def test_add_doc_info_to_graph(spacy_model):
+    doc = spacy_model(SENTENCE)
+    tk_graph = graphs.TokenGraph()
+    tokens.add_doc_info_to_graph(tk_graph, doc, weight=2)
+    assert len(tk_graph.nodes) == 11
+    assert len(tk_graph.edges) == 17
+    assert '0912393' in tk_graph.nodes
+
+
+def test_build_token_graph(
+    data_merge_similarity_duplicates,
+    spacy_model,
+    data_tk_graph_built,
+):
+    tk_graph, _ = tokens.build_token_graph(
+        data=data_merge_similarity_duplicates,
+        model=spacy_model,
+    )
+    assert len(tk_graph.nodes) == len(data_tk_graph_built.nodes)
+    assert len(tk_graph.edges) == len(data_tk_graph_built.edges)
@@ -1,5 +1,7 @@
 from pathlib import Path

+from lang_main.analysis import graphs
+import pandas as pd
 import pytest

 DATE_COLS: tuple[str, ...] = (
@@ -12,7 +14,7 @@ DATE_COLS: tuple[str, ...] = (

 @pytest.fixture(scope='session')
 def raw_data_path():
-    pth_data = Path('./tests/Dummy_Dataset_N_1000.csv')
+    pth_data = Path('./tests/_comparison_results/Dummy_Dataset_N_1000.csv')
     assert pth_data.exists()

     return pth_data
@@ -21,3 +23,27 @@ def raw_data_path():
 @pytest.fixture(scope='session')
 def raw_data_date_cols():
     return DATE_COLS
+
+
+@pytest.fixture(scope='session')
+def data_analyse_feature() -> pd.DataFrame:
+    pth_data = Path('./tests/_comparison_results/analyse_feature.pkl')
+    return pd.read_pickle(pth_data)
+
+
+@pytest.fixture(scope='session')
+def data_numeric_pre_filter_feature() -> pd.DataFrame:
+    pth_data = Path('./tests/_comparison_results/numeric_pre_filter.pkl')
+    return pd.read_pickle(pth_data)
+
+
+@pytest.fixture(scope='session')
+def data_merge_similarity_duplicates() -> pd.DataFrame:
+    pth_data = Path('./tests/_comparison_results/merge_similarity_candidates.pkl')
+    return pd.read_pickle(pth_data)
+
+
+@pytest.fixture(scope='session')
+def data_tk_graph_built():
+    pth_data = Path('./tests/_comparison_results/tk_graph_built.pkl')
+    return graphs.TokenGraph.from_file(pth_data)