"""Tests for the database helpers in ``delta_barth.databases``."""

from dataclasses import asdict

import pytest
import sqlalchemy as sql

from delta_barth import databases as db
from delta_barth.types import BestParametersXGBRegressor, SalesForecastStatistics


def test_get_engine(tmp_path):
    """``get_engine`` returns a SQLite engine whose URL references the DB location."""
    db_path = tmp_path / "test_db.db"
    engine = db.get_engine(db_path)
    assert isinstance(engine, sql.Engine)
    assert "sqlite" in str(engine.url)
    assert db_path.parent.name in str(engine.url)


def test_write_sales_forecast_statistics_small(session):
    """Round-trip a statistics record that only sets the mandatory fields.

    The record is written to ``sf_stats``, read back by its primary key and
    rebuilt; every optional field must come back as ``None``.
    """
    eng = session.db_engine
    code = 0
    descr = "Test case to write stats"
    length = 32
    stats = SalesForecastStatistics(code, descr, length)
    assert stats.xgb_params is None

    # Only the scalar statistics are inserted into ``sf_stats``; the XGB
    # parameters are not part of this insert.
    stats_to_write = asdict(stats)
    _ = stats_to_write.pop("xgb_params")

    with eng.begin() as conn:
        res = conn.execute(sql.insert(db.sf_stats).values(stats_to_write))
        stats_id = res.inserted_primary_key[0]

    with eng.begin() as conn:
        # Read back exactly the row written above instead of "the first row":
        # this does not rely on an empty table or on row order, and ``one()``
        # fails if the row is missing or duplicated.
        query = sql.select(db.sf_stats).where(db.sf_stats.c.id == stats_id)
        inserted = conn.execute(query).mappings().one()

    data = dict(inserted)
    del data["id"]  # the row id is not a field of SalesForecastStatistics
    result = SalesForecastStatistics(**data)

    assert result.status_code == code
    assert result.status_dscr == descr
    assert result.length_dataset == length
    assert result.score_mae is None
    assert result.score_r2 is None
    assert result.best_start_year is None
    assert result.xgb_params is None


def test_write_sales_forecast_statistics_large(session):
    """Round-trip a fully populated statistics record including XGB parameters.

    The scalar statistics go to ``sf_stats`` and the XGB parameters to
    ``sf_XGB``, linked through ``forecast_id``; both rows are written in one
    transaction, read back by primary key and reassembled.
    """
    eng = session.db_engine
    code = 0
    descr = "Test case to write stats"
    length = 32
    score_mae = 3.54
    score_r2 = 0.56
    best_start_year = 2020
    xgb_params = BestParametersXGBRegressor(
        n_estimators=2,
        learning_rate=0.3,
        max_depth=2,
        min_child_weight=5,
        gamma=0.5,
        subsample=0.8,
        colsample_bytree=5.25,
        early_stopping_rounds=5,
    )
    stats = SalesForecastStatistics(
        code,
        descr,
        length,
        score_mae,
        score_r2,
        best_start_year,
        xgb_params,
    )
    xgb_params = stats.xgb_params
    assert xgb_params is not None

    stats_to_write = asdict(stats)
    _ = stats_to_write.pop("xgb_params")

    with eng.begin() as conn:
        res = conn.execute(sql.insert(db.sf_stats).values(stats_to_write))
        sf_id = res.inserted_primary_key[0]
        # link the XGB parameter row to the statistics row just created
        xgb_params["forecast_id"] = sf_id
        res = conn.execute(sql.insert(db.sf_XGB).values(xgb_params))
        xgb_id = res.inserted_primary_key[0]

    with eng.begin() as conn:
        # Fetch each row by its own primary key and materialise it right away,
        # so no result cursor is left open while the next query runs.
        # Selecting the XGB row by ``id`` (not by ``forecast_id``) keeps the
        # foreign-key assertion below meaningful.
        stats_query = sql.select(db.sf_stats).where(db.sf_stats.c.id == sf_id)
        inserted = conn.execute(stats_query).mappings().one()
        xgb_query = sql.select(db.sf_XGB).where(db.sf_XGB.c.id == xgb_id)
        inserted_xgb = conn.execute(xgb_query).mappings().one()

    # reconstruct best XGB parameters
    data_xgb = dict(inserted_xgb)
    del data_xgb["id"]
    xgb_stats = BestParametersXGBRegressor(**data_xgb)

    # reconstruct other statistics
    data_inserted = dict(inserted)
    stats_id_fk = data_inserted["id"]  # foreign key in XGB parameters
    del data_inserted["id"]
    result = SalesForecastStatistics(**data_inserted, xgb_params=xgb_stats)

    assert result.status_code == code
    assert result.status_dscr == descr
    assert result.length_dataset == length
    assert result.score_mae == pytest.approx(score_mae)
    assert result.score_r2 == pytest.approx(score_r2)
    assert result.best_start_year == best_start_year
    assert result.xgb_params is not None
    # compare xgb_stats
    assert result.xgb_params["forecast_id"] == stats_id_fk  # type: ignore
    assert result.xgb_params["n_estimators"] == 2
    assert result.xgb_params["learning_rate"] == pytest.approx(0.3)
    assert result.xgb_params["max_depth"] == 2
    assert result.xgb_params["min_child_weight"] == 5
    assert result.xgb_params["gamma"] == pytest.approx(0.5)
    assert result.xgb_params["subsample"] == pytest.approx(0.8)
    assert result.xgb_params["colsample_bytree"] == pytest.approx(5.25)
    assert result.xgb_params["early_stopping_rounds"] == 5