prepare database writing operations for pipeline stats

This commit is contained in:
2025-03-27 16:29:01 +01:00
parent 7bb312d34e
commit 447a70486b
7 changed files with 178 additions and 17 deletions

View File

@@ -158,7 +158,6 @@ def test_preprocess_sales_FailOnTargetFeature(
assert pipe.results is None
@pytest.mark.new
def test_process_sales_Success(sales_data_real_preproc):
data = sales_data_real_preproc.copy()
pipe = PipeResult(data, STATUS_HANDLER.SUCCESS)
@@ -178,10 +177,9 @@ def test_process_sales_Success(sales_data_real_preproc):
assert pipe.statistics.score_mae is not None
assert pipe.statistics.score_r2 is not None
assert pipe.statistics.best_start_year is not None
assert pipe.statistics.XGB_params is not None
assert pipe.statistics.xgb_params is not None
@pytest.mark.new
def test_process_sales_FailTooFewPoints(sales_data_real_preproc):
data = sales_data_real_preproc.copy()
data = data.iloc[:20, :]
@@ -205,10 +203,9 @@ def test_process_sales_FailTooFewPoints(sales_data_real_preproc):
assert pipe.statistics.score_mae is None
assert pipe.statistics.score_r2 is None
assert pipe.statistics.best_start_year is None
assert pipe.statistics.XGB_params is None
assert pipe.statistics.xgb_params is None
@pytest.mark.new
def test_process_sales_FailTooFewMonthPoints(sales_data_real_preproc):
data = sales_data_real_preproc.copy()
pipe = PipeResult(data, STATUS_HANDLER.SUCCESS)
@@ -232,10 +229,9 @@ def test_process_sales_FailTooFewMonthPoints(sales_data_real_preproc):
assert pipe.statistics.score_mae is None
assert pipe.statistics.score_r2 is None
assert pipe.statistics.best_start_year is None
assert pipe.statistics.XGB_params is None
assert pipe.statistics.xgb_params is None
@pytest.mark.new
def test_process_sales_FailNoReliableForecast(sales_data_real_preproc):
data = sales_data_real_preproc.copy()
data["betrag"] = 10000
@@ -278,7 +274,7 @@ def test_process_sales_FailNoReliableForecast(sales_data_real_preproc):
assert pipe.statistics.score_mae is None
assert pipe.statistics.score_r2 is None
assert pipe.statistics.best_start_year is None
assert pipe.statistics.XGB_params is None
assert pipe.statistics.xgb_params is None
def test_postprocess_sales_Success(

119
tests/test_databases.py Normal file
View File

@@ -0,0 +1,119 @@
from dataclasses import asdict
import pytest
import sqlalchemy as sql
from delta_barth import databases as db
from delta_barth.types import BestParametersXGBRegressor, SalesForecastStatistics
def test_get_engine(tmp_path):
    """Check that ``db.get_engine`` builds a SQLite engine for a given file path."""
    database_file = tmp_path / "test_db.db"

    engine = db.get_engine(database_file)

    assert isinstance(engine, sql.Engine)
    # the engine URL has to name the SQLite dialect and contain the name of
    # the folder the database file lives in
    engine_url = str(engine.url)
    assert "sqlite" in engine_url
    assert database_file.parent.name in engine_url
def test_write_sales_forecast_statistics_small(session):
    """Round-trip forecast statistics without XGB parameters through the database.

    Only the three positional fields of ``SalesForecastStatistics`` are set.
    The row is inserted into ``db.sf_stats``, read back by the primary key
    returned from the insert and rebuilt into the dataclass. All remaining
    fields are expected to come back as ``None``.
    """
    eng = session.db_engine
    code = 0
    descr = "Test case to write stats"
    length = 32
    stats = SalesForecastStatistics(code, descr, length)

    # "xgb_params" is removed before the insert; it is not written to sf_stats
    stats_to_write = asdict(stats)
    stats_to_write.pop("xgb_params")

    with eng.begin() as conn:
        res = conn.execute(sql.insert(db.sf_stats).values(stats_to_write))
        sf_id = res.inserted_primary_key[0]

    with eng.begin() as conn:
        # Select exactly the row written above instead of "the first row in
        # the table", so the test does not depend on rows left behind by other
        # tests. The row is fetched while the connection is still open.
        query = sql.select(db.sf_stats).where(db.sf_stats.c.id == sf_id)
        inserted = conn.execute(query).mappings().one()

    # the surrogate key is not a field of the dataclass
    data = dict(inserted)
    del data["id"]
    result = SalesForecastStatistics(**data)

    assert result.status_code == code
    assert result.status_dscr == descr
    assert result.length_dataset == length
    assert result.score_mae is None
    assert result.score_r2 is None
    assert result.best_start_year is None
    assert result.xgb_params is None
def test_write_sales_forecast_statistics_large(session):
    """Round-trip complete forecast statistics, including XGB parameters.

    The statistics row is inserted into ``db.sf_stats`` and the XGB parameters
    into ``db.sf_XGB`` with ``forecast_id`` set to the primary key of the
    statistics row. Both rows are read back by the primary keys returned from
    the inserts, reassembled into ``SalesForecastStatistics`` and compared
    against the values that were written.
    """
    eng = session.db_engine
    code = 0
    descr = "Test case to write stats"
    length = 32
    score_mae = 3.54
    score_r2 = 0.56
    best_start_year = 2020
    xgb_params = BestParametersXGBRegressor(
        n_estimators=2,
        learning_rate=0.3,
        max_depth=2,
        min_child_weight=5,
        gamma=0.5,
        subsample=0.8,
        colsample_bytree=5.25,
        early_stopping_rounds=5,
    )

    stats = SalesForecastStatistics(
        code,
        descr,
        length,
        score_mae,
        score_r2,
        best_start_year,
        xgb_params,
    )
    xgb_params = stats.xgb_params
    assert xgb_params is not None

    # "xgb_params" is removed before the insert; it goes into its own table
    stats_to_write = asdict(stats)
    stats_to_write.pop("xgb_params")

    with eng.begin() as conn:
        res = conn.execute(sql.insert(db.sf_stats).values(stats_to_write))
        sf_id = res.inserted_primary_key[0]
        # link the parameter row to the statistics row written just before
        xgb_params["forecast_id"] = sf_id
        res = conn.execute(sql.insert(db.sf_XGB).values(xgb_params))
        xgb_id = res.inserted_primary_key[0]

    with eng.begin() as conn:
        # Read back exactly the two rows written above instead of "the first
        # row" of each table, so rows left behind by other tests cannot be
        # picked up. Rows are fetched while the connection is still open.
        stats_query = sql.select(db.sf_stats).where(db.sf_stats.c.id == sf_id)
        xgb_query = sql.select(db.sf_XGB).where(db.sf_XGB.c.id == xgb_id)
        inserted = conn.execute(stats_query).mappings().one()
        inserted_xgb = conn.execute(xgb_query).mappings().one()

    # reconstruct best XGB parameters
    data_xgb = dict(inserted_xgb)
    del data_xgb["id"]
    xgb_stats = BestParametersXGBRegressor(**data_xgb)

    # reconstruct other statistics
    data_inserted = dict(inserted)
    stats_id_fk = data_inserted["id"]  # foreign key in XGB parameters
    del data_inserted["id"]
    stats = SalesForecastStatistics(**data_inserted, xgb_params=xgb_stats)

    assert stats.status_code == code
    assert stats.status_dscr == descr
    assert stats.length_dataset == length
    assert stats.score_mae == pytest.approx(score_mae)
    assert stats.score_r2 == pytest.approx(score_r2)
    assert stats.best_start_year == best_start_year
    assert stats.xgb_params is not None

    # compare xgb_stats
    assert stats.xgb_params["forecast_id"] == stats_id_fk  # type: ignore
    assert stats.xgb_params["n_estimators"] == 2
    assert stats.xgb_params["learning_rate"] == pytest.approx(0.3)
    assert stats.xgb_params["max_depth"] == 2
    assert stats.xgb_params["min_child_weight"] == 5
    assert stats.xgb_params["gamma"] == pytest.approx(0.5)
    assert stats.xgb_params["subsample"] == pytest.approx(0.8)
    assert stats.xgb_params["colsample_bytree"] == pytest.approx(5.25)
    assert stats.xgb_params["early_stopping_rounds"] == 5

View File

@@ -43,6 +43,22 @@ def test_session_set_DataPath(tmp_path):
assert isinstance(session.data_path, Path)
def test_session_setup_db_management(tmp_path):
    """Check the database location of a session and its creation by ``setup()``.

    Before ``setup()`` the database folder is expected to exist while the
    database file does not; after ``setup()`` the engine is set and the file
    is present.
    """
    db_folder_name: str = "data_test"
    expected_db_dir = tmp_path / db_folder_name
    data_root = str(tmp_path)

    session = delta_barth.session.Session(
        HTTP_BASE_CONTENT_HEADERS,
        db_folder=db_folder_name,
    )
    session.set_data_path(data_root)
    db_path = session.db_path

    # state before setup: folder is there, database file is not
    db_dir = db_path.parent
    assert db_dir.exists()
    assert db_dir == expected_db_dir
    assert not db_path.exists()

    session.setup()

    # state after setup: engine is available and the file was created
    assert session._db_engine is not None
    assert db_path.exists()
@patch("delta_barth.logging.ENABLE_LOGGING", True)
@patch("delta_barth.logging.LOGGING_TO_FILE", True)
def test_session_setup_logging(tmp_path):