added new test cases
This commit is contained in:
@@ -2,6 +2,7 @@ import networkx as nx
|
||||
import pytest
|
||||
|
||||
from lang_main.analysis import graphs
|
||||
from lang_main.errors import EmptyEdgesError, EmptyGraphError, EdgePropertyNotContainedError
|
||||
|
||||
TK_GRAPH_NAME = 'TEST_TOKEN_GRAPH'
|
||||
|
||||
@@ -40,13 +41,18 @@ def build_init_graph(token_graph: bool):
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def graph():
|
||||
def graph() -> graphs.DiGraph:
|
||||
return build_init_graph(token_graph=False)
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def tk_graph():
|
||||
return build_init_graph(token_graph=True)
|
||||
def tk_graph() -> graphs.TokenGraph:
|
||||
return build_init_graph(token_graph=True) # type: ignore
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def tk_graph_undirected(tk_graph) -> graphs.Graph:
|
||||
return tk_graph.undirected
|
||||
|
||||
|
||||
def test_graph_size(graph):
|
||||
@@ -61,7 +67,45 @@ def test_save_to_GraphML(graph, tmp_path):
|
||||
assert saved_file.exists()
|
||||
|
||||
|
||||
def test_metadata_retrieval(graph):
|
||||
def test_save_load_pickle_tk_graph(tk_graph, tmp_path):
|
||||
filename = 'test_save_tkg'
|
||||
tk_graph.to_pickle(tmp_path, filename)
|
||||
load_pth = (tmp_path / filename).with_suffix('.pkl')
|
||||
assert load_pth.exists()
|
||||
loaded_graph = graphs.TokenGraph.from_file(load_pth)
|
||||
assert loaded_graph.nodes == tk_graph.nodes
|
||||
assert loaded_graph.edges == tk_graph.edges
|
||||
filename = None
|
||||
tk_graph.to_pickle(tmp_path, filename)
|
||||
load_pth = (tmp_path / tk_graph.name).with_suffix('.pkl')
|
||||
assert load_pth.exists()
|
||||
loaded_graph = graphs.TokenGraph.from_file(load_pth)
|
||||
assert loaded_graph.nodes == tk_graph.nodes
|
||||
assert loaded_graph.edges == tk_graph.edges
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'import_graph,directed', [('tk_graph', True), ('tk_graph_undirected', False)]
|
||||
)
|
||||
def test_save_load_GraphML_tk_graph(import_graph, tk_graph, directed, tmp_path, request):
|
||||
test_graph = request.getfixturevalue(import_graph)
|
||||
filename = 'test_save_tkg'
|
||||
tk_graph.to_GraphML(tmp_path, filename, directed=directed)
|
||||
load_pth = (tmp_path / filename).with_suffix('.graphml')
|
||||
assert load_pth.exists()
|
||||
loaded_graph = graphs.TokenGraph.from_file(load_pth, node_type_graphml=int)
|
||||
assert loaded_graph.nodes == test_graph.nodes
|
||||
assert loaded_graph.edges == test_graph.edges
|
||||
filename = None
|
||||
tk_graph.to_GraphML(tmp_path, filename, directed=directed)
|
||||
load_pth = (tmp_path / tk_graph.name).with_suffix('.graphml')
|
||||
assert load_pth.exists()
|
||||
loaded_graph = graphs.TokenGraph.from_file(load_pth, node_type_graphml=int)
|
||||
assert loaded_graph.nodes == test_graph.nodes
|
||||
assert loaded_graph.edges == test_graph.edges
|
||||
|
||||
|
||||
def test_get_graph_metadata(graph):
|
||||
metadata = graphs.get_graph_metadata(graph)
|
||||
assert metadata['num_nodes'] == 4
|
||||
assert metadata['num_edges'] == 6
|
||||
@@ -72,7 +116,7 @@ def test_metadata_retrieval(graph):
|
||||
assert metadata['total_memory'] == 448
|
||||
|
||||
|
||||
def test_graph_update_batch():
|
||||
def test_update_graph_batch():
|
||||
graph_obj = build_init_graph(token_graph=False)
|
||||
graphs.update_graph(graph_obj, batch=((4, 5), (5, 6)), weight_connection=8)
|
||||
metadata = graphs.get_graph_metadata(graph_obj)
|
||||
@@ -82,7 +126,7 @@ def test_graph_update_batch():
|
||||
assert metadata['max_edge_weight'] == 8
|
||||
|
||||
|
||||
def test_graph_update_single_new():
|
||||
def test_update_graph_single_new():
|
||||
graph_obj = build_init_graph(token_graph=False)
|
||||
graphs.update_graph(graph_obj, parent=4, child=5, weight_connection=7)
|
||||
metadata = graphs.get_graph_metadata(graph_obj)
|
||||
@@ -92,7 +136,7 @@ def test_graph_update_single_new():
|
||||
assert metadata['max_edge_weight'] == 7
|
||||
|
||||
|
||||
def test_graph_update_single_existing():
|
||||
def test_update_graph_single_existing():
|
||||
graph_obj = build_init_graph(token_graph=False)
|
||||
graphs.update_graph(graph_obj, parent=1, child=4, weight_connection=5)
|
||||
metadata = graphs.get_graph_metadata(graph_obj)
|
||||
@@ -103,13 +147,13 @@ def test_graph_update_single_existing():
|
||||
|
||||
|
||||
@pytest.mark.parametrize('cast_int', [True, False])
|
||||
def test_graph_undirected_conversion(graph, cast_int):
|
||||
def test_convert_graph_to_undirected(graph, cast_int):
|
||||
graph_undir = graphs.convert_graph_to_undirected(graph, cast_int=cast_int)
|
||||
# edges: (1, 2, w=1) und (2, 1, w=6) --> undirected: (1, 2, w=7)
|
||||
assert graph_undir[1][2]['weight'] == pytest.approx(7.0)
|
||||
|
||||
|
||||
def test_graph_cytoscape_conversion(graph):
|
||||
def test_convert_graph_to_cytoscape(graph):
|
||||
cyto_graph, weight_data = graphs.convert_graph_to_cytoscape(graph)
|
||||
node = cyto_graph[0]
|
||||
edge = cyto_graph[-1]
|
||||
@@ -144,7 +188,17 @@ def test_tk_graph_properties(tk_graph):
|
||||
assert metadata_undirected['total_memory'] == 392
|
||||
|
||||
|
||||
def test_graph_degree_filter(tk_graph):
|
||||
def test_filter_graph_by_edge_weight(tk_graph):
|
||||
filtered_graph = graphs.filter_graph_by_edge_weight(
|
||||
tk_graph,
|
||||
bound_lower=2,
|
||||
bound_upper=5,
|
||||
)
|
||||
assert not filtered_graph.has_edge(1, 2)
|
||||
assert not filtered_graph.has_edge(2, 1)
|
||||
|
||||
|
||||
def test_filter_graph_by_node_degree(tk_graph):
|
||||
filtered_graph = graphs.filter_graph_by_node_degree(
|
||||
tk_graph,
|
||||
bound_lower=3,
|
||||
@@ -153,7 +207,7 @@ def test_graph_degree_filter(tk_graph):
|
||||
assert len(filtered_graph.nodes) == 2
|
||||
|
||||
|
||||
def test_graph_edge_number_filter(tk_graph):
|
||||
def test_filter_graph_by_number_edges(tk_graph):
|
||||
number_edges_limit = 1
|
||||
filtered_graph = graphs.filter_graph_by_number_edges(
|
||||
tk_graph,
|
||||
@@ -166,3 +220,75 @@ def test_graph_edge_number_filter(tk_graph):
|
||||
bound_upper=None,
|
||||
)
|
||||
assert len(filtered_graph.nodes) == 2, 'one edge should result in only two nodes'
|
||||
|
||||
|
||||
def test_add_weighted_degree():
|
||||
graph_obj = build_init_graph(token_graph=False)
|
||||
property_name = 'degree_weighted'
|
||||
graphs.add_weighted_degree(graph_obj, 'weight', property_name)
|
||||
assert graph_obj.nodes[1][property_name] == 14
|
||||
assert graph_obj.nodes[2][property_name] == 10
|
||||
assert graph_obj.nodes[3][property_name] == 6
|
||||
|
||||
|
||||
def test_static_graph_analysis():
|
||||
graph_obj = build_init_graph(token_graph=True)
|
||||
(graph_obj,) = graphs.static_graph_analysis(graph_obj) # type: ignore
|
||||
property_name = 'degree_weighted'
|
||||
assert graph_obj.nodes[1][property_name] == 14
|
||||
assert graph_obj.nodes[2][property_name] == 10
|
||||
assert graph_obj.nodes[3][property_name] == 6
|
||||
assert graph_obj.undirected.nodes[1][property_name] == 14
|
||||
assert graph_obj.undirected.nodes[2][property_name] == 10
|
||||
assert graph_obj.undirected.nodes[3][property_name] == 6
|
||||
|
||||
|
||||
def test_pipe_add_graph_metrics():
|
||||
graph_obj = build_init_graph(token_graph=False)
|
||||
graph_obj_undir = graphs.convert_graph_to_undirected(graph_obj, cast_int=True)
|
||||
graph_collection = graphs.pipe_add_graph_metrics(graph_obj, graph_obj_undir)
|
||||
property_name = 'degree_weighted'
|
||||
assert graph_collection[0].nodes[1][property_name] == 14
|
||||
assert graph_collection[0].nodes[2][property_name] == 10
|
||||
assert graph_collection[0].nodes[3][property_name] == 6
|
||||
assert graph_collection[1].nodes[1][property_name] == 14
|
||||
assert graph_collection[1].nodes[2][property_name] == 10
|
||||
assert graph_collection[1].nodes[3][property_name] == 6
|
||||
|
||||
|
||||
def test_pipe_rescale_graph_edge_weights(tk_graph):
|
||||
rescaled_tkg, rescaled_undir = graphs.pipe_rescale_graph_edge_weights(tk_graph)
|
||||
assert rescaled_tkg[2][1]['weight'] == pytest.approx(1.0)
|
||||
assert rescaled_tkg[1][2]['weight'] == pytest.approx(0.0952)
|
||||
assert rescaled_undir[2][1]['weight'] == pytest.approx(1.0)
|
||||
assert rescaled_undir[1][2]['weight'] == pytest.approx(1.0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('import_graph', ['graph', 'tk_graph'])
|
||||
def test_rescale_edge_weights(import_graph, request):
|
||||
test_graph = request.getfixturevalue(import_graph)
|
||||
rescaled_graph = graphs.rescale_edge_weights(test_graph)
|
||||
assert rescaled_graph[2][1]['weight'] == pytest.approx(1.0)
|
||||
assert rescaled_graph[1][2]['weight'] == pytest.approx(0.0952)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('import_graph', ['graph', 'tk_graph'])
|
||||
def test_verify_property(import_graph, request):
|
||||
test_graph = request.getfixturevalue(import_graph)
|
||||
test_property = 'centrality'
|
||||
with pytest.raises(EdgePropertyNotContainedError):
|
||||
graphs.verify_property(test_graph, property=test_property)
|
||||
test_property = 'weight'
|
||||
assert not graphs.verify_property(test_graph, property=test_property)
|
||||
|
||||
|
||||
def test_verify_non_empty_graph():
|
||||
graph = nx.Graph()
|
||||
with pytest.raises(EmptyGraphError):
|
||||
graphs.verify_non_empty_graph(graph)
|
||||
graph.add_nodes_from([1, 2, 3, 4])
|
||||
with pytest.raises(EmptyEdgesError):
|
||||
graphs.verify_non_empty_graph(graph, including_edges=True)
|
||||
assert not graphs.verify_non_empty_graph(graph, including_edges=False)
|
||||
graph.add_edges_from([(1, 2), (1, 3), (2, 4)])
|
||||
assert not graphs.verify_non_empty_graph(graph, including_edges=True)
|
||||
|
||||
@@ -2,8 +2,11 @@
|
||||
executed in in a pipeline
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from lang_main import model_loader
|
||||
from lang_main.analysis import preprocessing as ppc
|
||||
from lang_main.analysis import shared
|
||||
from lang_main.types import LanguageModels, STFRModelTypes
|
||||
|
||||
|
||||
def test_load_data(raw_data_path, raw_data_date_cols):
|
||||
@@ -71,3 +74,43 @@ def test_analyse_feature(raw_data_path, raw_data_date_cols):
|
||||
|
||||
(data,) = ppc.analyse_feature(data, target_feature=target_features[0])
|
||||
assert len(data) == 139
|
||||
|
||||
|
||||
def test_numeric_pre_filter_feature(data_analyse_feature, data_numeric_pre_filter_feature):
|
||||
# Dataset contains 139 entries. The feature "len" has a minimum value of 15,
|
||||
# which occurs only once. If all values >= are retained only one entry should be
|
||||
# filtered. This results in a total number of 138 entries.
|
||||
(data,) = ppc.numeric_pre_filter_feature(
|
||||
data=data_analyse_feature,
|
||||
feature='len',
|
||||
bound_lower=16,
|
||||
bound_upper=None,
|
||||
)
|
||||
assert len(data) == 138
|
||||
eval_merged = data[['entry', 'len', 'num_occur', 'num_assoc_obj_ids']]
|
||||
eval_benchmark = data_numeric_pre_filter_feature[
|
||||
['entry', 'len', 'num_occur', 'num_assoc_obj_ids']
|
||||
]
|
||||
assert bool((eval_merged == eval_benchmark).all(axis=None))
|
||||
|
||||
|
||||
def test_merge_similarity_duplicates(data_analyse_feature, data_merge_similarity_duplicates):
|
||||
cos_sim_threshold = 0.8
|
||||
# reduce dataset to 10 entries
|
||||
data = data_analyse_feature.iloc[:10]
|
||||
model = model_loader.load_sentence_transformer(
|
||||
model_name=STFRModelTypes.ALL_MPNET_BASE_V2,
|
||||
)
|
||||
(merged_data,) = ppc.merge_similarity_duplicates(
|
||||
data=data,
|
||||
model=model,
|
||||
cos_sim_threshold=cos_sim_threshold,
|
||||
)
|
||||
# constructed use case: with this threshold,
|
||||
# 2 out of 10 entries are merged into one
|
||||
assert len(merged_data) == 9
|
||||
eval_merged = merged_data[['entry', 'len', 'num_occur', 'num_assoc_obj_ids']]
|
||||
eval_benchmark = data_merge_similarity_duplicates[
|
||||
['entry', 'len', 'num_occur', 'num_assoc_obj_ids']
|
||||
]
|
||||
assert bool((eval_merged == eval_benchmark).all(axis=None))
|
||||
|
||||
79
tests/analysis/test_tokens.py
Normal file
79
tests/analysis/test_tokens.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from lang_main import model_loader
|
||||
from lang_main.analysis import graphs, tokens
|
||||
from lang_main.types import SpacyModelTypes
|
||||
|
||||
SENTENCE = (
|
||||
'Ich ging am 22.05. mit ID 0912393 schnell über die Wiese zu einem Menschen, '
|
||||
'um ihm zu helfen. Ich konnte nicht mit ansehen, wie er Probleme beim Tragen '
|
||||
'seiner Tasche hatte.'
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def spacy_model():
|
||||
model = model_loader.load_spacy(
|
||||
model_name=SpacyModelTypes.DE_CORE_NEWS_SM,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def test_pre_clean_word():
|
||||
string = 'Öl3bad2024prüfung'
|
||||
assert tokens.pre_clean_word(string) == 'Ölbadprüfung'
|
||||
|
||||
|
||||
def test_is_str_date():
|
||||
string = '22.05.'
|
||||
assert tokens.is_str_date(string, fuzzy=True)
|
||||
string = '22.05.2024'
|
||||
assert tokens.is_str_date(string)
|
||||
string = '22-05-2024'
|
||||
assert tokens.is_str_date(string)
|
||||
string = '9009090909'
|
||||
assert not tokens.is_str_date(string)
|
||||
string = 'hello347'
|
||||
assert not tokens.is_str_date(string)
|
||||
|
||||
|
||||
# TODO: depends on fixed Constants
|
||||
def test_obtain_relevant_descendants(spacy_model):
|
||||
doc = spacy_model(SENTENCE)
|
||||
sent1 = tuple(doc.sents)[0] # first sentence
|
||||
word1 = sent1[1] # word "ging" (POS:VERB)
|
||||
descendants1 = ('0912393', 'schnell', 'Wiese', 'Menschen')
|
||||
rel_descs = tokens.obtain_relevant_descendants(word1)
|
||||
rel_descs = tuple((token.text for token in rel_descs))
|
||||
assert descendants1 == rel_descs
|
||||
|
||||
sent2 = tuple(doc.sents)[1] # first sentence
|
||||
word2 = sent2[1] # word "konnte" (POS:AUX)
|
||||
descendants2 = ('mit', 'Probleme', 'Tragen', 'Tasche')
|
||||
rel_descs = tokens.obtain_relevant_descendants(word2)
|
||||
rel_descs = tuple((token.text for token in rel_descs))
|
||||
assert descendants2 == rel_descs
|
||||
|
||||
|
||||
def test_add_doc_info_to_graph(spacy_model):
|
||||
doc = spacy_model(SENTENCE)
|
||||
tk_graph = graphs.TokenGraph()
|
||||
tokens.add_doc_info_to_graph(tk_graph, doc, weight=2)
|
||||
assert len(tk_graph.nodes) == 11
|
||||
assert len(tk_graph.edges) == 17
|
||||
assert '0912393' in tk_graph.nodes
|
||||
|
||||
|
||||
def test_build_token_graph(
|
||||
data_merge_similarity_duplicates,
|
||||
spacy_model,
|
||||
data_tk_graph_built,
|
||||
):
|
||||
tk_graph, _ = tokens.build_token_graph(
|
||||
data=data_merge_similarity_duplicates,
|
||||
model=spacy_model,
|
||||
)
|
||||
assert len(tk_graph.nodes) == len(data_tk_graph_built.nodes)
|
||||
assert len(tk_graph.edges) == len(data_tk_graph_built.edges)
|
||||
Reference in New Issue
Block a user