added test cases

This commit is contained in:
Florian Förster
2025-01-22 16:54:15 +01:00
parent 30fe71e80a
commit fb28b8548b
28 changed files with 17721 additions and 17 deletions

View File

@@ -321,7 +321,7 @@ def test_pipe_add_graph_metrics():
def test_pipe_rescale_graph_edge_weights(tk_graph):
rescaled_tkg, rescaled_undir = graphs.pipe_rescale_graph_edge_weights(tk_graph)
assert rescaled_tkg[2][1]['weight'] == pytest.approx(1.0)
assert rescaled_tkg[1][2]['weight'] == pytest.approx(0.0952)
assert rescaled_tkg[1][2]['weight'] == pytest.approx(0.095238)
assert rescaled_undir[2][1]['weight'] == pytest.approx(1.0)
assert rescaled_undir[1][2]['weight'] == pytest.approx(1.0)
@@ -331,7 +331,7 @@ def test_rescale_edge_weights(import_graph, request):
test_graph = request.getfixturevalue(import_graph)
rescaled_graph = graphs.rescale_edge_weights(test_graph)
assert rescaled_graph[2][1]['weight'] == pytest.approx(1.0)
assert rescaled_graph[1][2]['weight'] == pytest.approx(0.0952)
assert rescaled_graph[1][2]['weight'] == pytest.approx(0.095238)
@pytest.mark.parametrize('import_graph', ['graph', 'tk_graph'])

View File

@@ -72,7 +72,7 @@ def test_calc_delta_to_repair(data_pre_cleaned, convert_to_days):
def test_non_relevant_obj_ids(data_pre_cleaned):
feature_uniqueness = 'HObjektText'
feature_obj_id = 'ObjektID'
threshold = 1
threshold = 2
data = data_pre_cleaned.copy()
data.at[0, feature_obj_id] = 1
ids_to_ignore = tl._non_relevant_obj_ids(
@@ -88,7 +88,7 @@ def test_non_relevant_obj_ids(data_pre_cleaned):
def test_remove_non_relevant_obj_ids(data_pre_cleaned):
feature_uniqueness = 'HObjektText'
feature_obj_id = 'ObjektID'
threshold = 1
threshold = 2
data = data_pre_cleaned.copy()
data.at[0, feature_obj_id] = 1

View File

@@ -25,8 +25,6 @@ from lang_main.types import LanguageModels
@pytest.mark.parametrize(
'model_name',
[
STFRModelTypes.ALL_DISTILROBERTA_V1,
STFRModelTypes.ALL_MINI_LM_L12_V2,
STFRModelTypes.ALL_MINI_LM_L6_V2,
STFRModelTypes.ALL_MPNET_BASE_V2,
],
@@ -47,6 +45,25 @@ def test_load_sentence_transformer(
assert isinstance(model, SentenceTransformer)
def test_preprocess_STFR_model_name() -> None:
model_name_not_exist = 'TestModel'
ret_model_name = model_loader._preprocess_STFR_model_name(
model_name=model_name_not_exist, backend=STFRBackends.TORCH, force_download=True
)
assert ret_model_name == model_name_not_exist
ret_model_name = model_loader._preprocess_STFR_model_name(
model_name=model_name_not_exist, backend=STFRBackends.TORCH, force_download=False
)
assert ret_model_name == model_name_not_exist
model_name_exist = STFRModelTypes.E5_BASE_STS_EN_DE
backend_exist = STFRBackends.ONNX
with pytest.raises(FileNotFoundError):
_ = model_loader._preprocess_STFR_model_name(
model_name=model_name_exist, backend=backend_exist, force_download=False
)
@pytest.mark.parametrize(
'similarity_func',
[
@@ -57,8 +74,6 @@ def test_load_sentence_transformer(
@pytest.mark.parametrize(
'model_name',
[
STFRModelTypes.ALL_DISTILROBERTA_V1,
STFRModelTypes.ALL_MINI_LM_L12_V2,
STFRModelTypes.ALL_MINI_LM_L6_V2,
STFRModelTypes.ALL_MPNET_BASE_V2,
],
@@ -108,6 +123,14 @@ def test_instantiate_spacy_model():
assert isinstance(model, Language)
def test_fail_instantiate_spacy_model():
with pytest.raises(KeyError):
_ = model_loader.instantiate_model(
model_load_map=model_loader.MODEL_LOADER_MAP,
model='test', # type: ignore
) # type: ignore
@pytest.mark.mload
def test_instantiate_stfr_model():
model = model_loader.instantiate_model(