added test cases
This commit is contained in:
@@ -321,7 +321,7 @@ def test_pipe_add_graph_metrics():
|
||||
def test_pipe_rescale_graph_edge_weights(tk_graph):
|
||||
rescaled_tkg, rescaled_undir = graphs.pipe_rescale_graph_edge_weights(tk_graph)
|
||||
assert rescaled_tkg[2][1]['weight'] == pytest.approx(1.0)
|
||||
assert rescaled_tkg[1][2]['weight'] == pytest.approx(0.0952)
|
||||
assert rescaled_tkg[1][2]['weight'] == pytest.approx(0.095238)
|
||||
assert rescaled_undir[2][1]['weight'] == pytest.approx(1.0)
|
||||
assert rescaled_undir[1][2]['weight'] == pytest.approx(1.0)
|
||||
|
||||
@@ -331,7 +331,7 @@ def test_rescale_edge_weights(import_graph, request):
|
||||
test_graph = request.getfixturevalue(import_graph)
|
||||
rescaled_graph = graphs.rescale_edge_weights(test_graph)
|
||||
assert rescaled_graph[2][1]['weight'] == pytest.approx(1.0)
|
||||
assert rescaled_graph[1][2]['weight'] == pytest.approx(0.0952)
|
||||
assert rescaled_graph[1][2]['weight'] == pytest.approx(0.095238)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('import_graph', ['graph', 'tk_graph'])
|
||||
|
||||
@@ -72,7 +72,7 @@ def test_calc_delta_to_repair(data_pre_cleaned, convert_to_days):
|
||||
def test_non_relevant_obj_ids(data_pre_cleaned):
|
||||
feature_uniqueness = 'HObjektText'
|
||||
feature_obj_id = 'ObjektID'
|
||||
threshold = 1
|
||||
threshold = 2
|
||||
data = data_pre_cleaned.copy()
|
||||
data.at[0, feature_obj_id] = 1
|
||||
ids_to_ignore = tl._non_relevant_obj_ids(
|
||||
@@ -88,7 +88,7 @@ def test_non_relevant_obj_ids(data_pre_cleaned):
|
||||
def test_remove_non_relevant_obj_ids(data_pre_cleaned):
|
||||
feature_uniqueness = 'HObjektText'
|
||||
feature_obj_id = 'ObjektID'
|
||||
threshold = 1
|
||||
threshold = 2
|
||||
data = data_pre_cleaned.copy()
|
||||
data.at[0, feature_obj_id] = 1
|
||||
|
||||
|
||||
@@ -25,8 +25,6 @@ from lang_main.types import LanguageModels
|
||||
@pytest.mark.parametrize(
|
||||
'model_name',
|
||||
[
|
||||
STFRModelTypes.ALL_DISTILROBERTA_V1,
|
||||
STFRModelTypes.ALL_MINI_LM_L12_V2,
|
||||
STFRModelTypes.ALL_MINI_LM_L6_V2,
|
||||
STFRModelTypes.ALL_MPNET_BASE_V2,
|
||||
],
|
||||
@@ -47,6 +45,25 @@ def test_load_sentence_transformer(
|
||||
assert isinstance(model, SentenceTransformer)
|
||||
|
||||
|
||||
def test_preprocess_STFR_model_name() -> None:
|
||||
model_name_not_exist = 'TestModel'
|
||||
ret_model_name = model_loader._preprocess_STFR_model_name(
|
||||
model_name=model_name_not_exist, backend=STFRBackends.TORCH, force_download=True
|
||||
)
|
||||
assert ret_model_name == model_name_not_exist
|
||||
ret_model_name = model_loader._preprocess_STFR_model_name(
|
||||
model_name=model_name_not_exist, backend=STFRBackends.TORCH, force_download=False
|
||||
)
|
||||
assert ret_model_name == model_name_not_exist
|
||||
|
||||
model_name_exist = STFRModelTypes.E5_BASE_STS_EN_DE
|
||||
backend_exist = STFRBackends.ONNX
|
||||
with pytest.raises(FileNotFoundError):
|
||||
_ = model_loader._preprocess_STFR_model_name(
|
||||
model_name=model_name_exist, backend=backend_exist, force_download=False
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'similarity_func',
|
||||
[
|
||||
@@ -57,8 +74,6 @@ def test_load_sentence_transformer(
|
||||
@pytest.mark.parametrize(
|
||||
'model_name',
|
||||
[
|
||||
STFRModelTypes.ALL_DISTILROBERTA_V1,
|
||||
STFRModelTypes.ALL_MINI_LM_L12_V2,
|
||||
STFRModelTypes.ALL_MINI_LM_L6_V2,
|
||||
STFRModelTypes.ALL_MPNET_BASE_V2,
|
||||
],
|
||||
@@ -108,6 +123,14 @@ def test_instantiate_spacy_model():
|
||||
assert isinstance(model, Language)
|
||||
|
||||
|
||||
def test_fail_instantiate_spacy_model():
|
||||
with pytest.raises(KeyError):
|
||||
_ = model_loader.instantiate_model(
|
||||
model_load_map=model_loader.MODEL_LOADER_MAP,
|
||||
model='test', # type: ignore
|
||||
) # type: ignore
|
||||
|
||||
|
||||
@pytest.mark.mload
|
||||
def test_instantiate_stfr_model():
|
||||
model = model_loader.instantiate_model(
|
||||
|
||||
Reference in New Issue
Block a user