60 lines
2.0 KiB
Python
60 lines
2.0 KiB
Python
import pandas as pd
|
|
import pytest
|
|
|
|
from delta_barth.analysis import forecast, parse
|
|
from delta_barth.errors import FeaturesMissingError
|
|
|
|
|
|
def test_check_needed_features():
    """`_check_needed_features` accepts a frame holding every required feature
    and raises `FeaturesMissingError` as soon as one of them is absent."""
    required = {"feat1", "feat2", "feat3"}
    row = [[1, 2, 3, 4, 5]]

    complete = pd.DataFrame(
        data=row, columns=["feat1", "feat2", "feat3", "feat4", "feat5"]
    )
    # All required columns are present (plus extras): must not raise.
    parse._check_needed_features(complete, required)

    # "feat1" is replaced by "featX", so one required feature is missing.
    incomplete = pd.DataFrame(
        data=row, columns=["featX", "feat2", "feat3", "feat4", "feat5"]
    )
    with pytest.raises(FeaturesMissingError):
        parse._check_needed_features(incomplete, required)
|
|
|
|
|
|
def test_map_features_to_targets():
    """`_map_features_to_targets` renames the mapped columns to their target
    names and leaves unmapped columns untouched."""
    feature_map = {"feat1": "feat10", "feat2": "feat20", "feat5": "feat50"}
    source = pd.DataFrame(
        data=[[1, 2, 3, 4, 5]], columns=["feat1", "feat2", "feat3", "feat4", "feat5"]
    )

    mapped = parse._map_features_to_targets(source, feature_map)

    # Renamed targets and unmapped columns must both be present.
    for expected in ("feat10", "feat20", "feat50", "feat3", "feat4"):
        assert expected in mapped.columns
    # The original names of renamed columns must be gone.
    for removed in ("feat1", "feat2", "feat5"):
        assert removed not in mapped.columns
|
|
|
|
|
|
def test_preprocess_features(exmpl_api_sales_prognosis_resp):
    """`parse.process_features` renames API fields to their target names.

    The example API response (a fixture) is parsed into a frame with
    `forecast._parse_api_resp_to_df`, passed through `parse.process_features`
    together with a field mapping and the set of required target features,
    and the result is checked column-wise.
    """
    resp = exmpl_api_sales_prognosis_resp
    df = forecast._parse_api_resp_to_df(resp)
    feat_mapping: dict[str, str] = {
        "artikelId": "artikel_refid",
        "firmaId": "firma_refid",
        "betrag": "betrag",
        "menge": "menge",
        "buchungsDatum": "buchungs_datum",
    }
    target_features: frozenset[str] = frozenset(
        (
            "firma_refid",
            "betrag",
            "buchungs_datum",
        )
    )

    # Precondition: every source field of the mapping exists in the parsed frame.
    assert all(feat in df.columns for feat in feat_mapping)

    data = parse.process_features(df, feat_mapping, target_features)

    # No columns are dropped or added, but at least one name has changed.
    assert len(data.columns) == len(df.columns)
    assert (data.columns != df.columns).any()
    # Every mapping target, and therefore every required feature, is present.
    assert all(feat in data.columns for feat in feat_mapping.values())
    assert all(feat in data.columns for feat in target_features)
    # Source names that were actually renamed are gone; identity entries
    # such as "betrag" -> "betrag" legitimately keep their name.
    renamed = [src for src, dst in feat_mapping.items() if src != dst]
    assert all(feat not in data.columns for feat in renamed)
|