diff --git a/src/delta_barth/analysis/forecast.py b/src/delta_barth/analysis/forecast.py
index 03c81b7..3f00d0d 100644
--- a/src/delta_barth/analysis/forecast.py
+++ b/src/delta_barth/analysis/forecast.py
@@ -297,7 +297,7 @@ def _process_sales(
         score_mae=best_score_mae,
         score_r2=best_score_r2,
         best_start_year=best_start_year,
-        XGB_params=best_params,
+        xgb_params=best_params,
     )
     pipe.stats(stats)
 
diff --git a/src/delta_barth/databases.py b/src/delta_barth/databases.py
index 56b5999..6756040 100644
--- a/src/delta_barth/databases.py
+++ b/src/delta_barth/databases.py
@@ -2,18 +2,17 @@
 from pathlib import Path
 
 import sqlalchemy as sql
 
-from delta_barth.constants import DB_ECHO
-
 # ** meta
 metadata = sql.MetaData()
 
 
 def get_engine(
     db_path: Path,
+    echo: bool = False,
 ) -> sql.Engine:
     path = db_path.resolve()
     connection_str: str = f"sqlite:///{str(path)}"
-    engine = sql.create_engine(connection_str, echo=DB_ECHO)
+    engine = sql.create_engine(connection_str, echo=echo)
     return engine
 
@@ -31,8 +30,8 @@ sf_stats = sql.Table(
     "sales_forecast_statistics",
     metadata,
     sql.Column("id", sql.Integer, primary_key=True),
-    sql.Column("error_code", sql.Integer),
-    sql.Column("error_msg", sql.String(length=200)),
+    sql.Column("status_code", sql.Integer),
+    sql.Column("status_dscr", sql.String(length=200)),
     sql.Column("length_dataset", sql.Integer),
     sql.Column("score_mae", sql.Float, nullable=True),
     sql.Column("score_r2", sql.Float, nullable=True),
diff --git a/src/delta_barth/session.py b/src/delta_barth/session.py
index 941b4c8..e539ce9 100644
--- a/src/delta_barth/session.py
+++ b/src/delta_barth/session.py
@@ -4,14 +4,17 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Final
 
 import requests
+import sqlalchemy as sql
 from dopt_basics.io import combine_route
 
 import delta_barth.logging
+from delta_barth import databases as db
 from delta_barth.api.common import (
     LoginRequest,
     LoginResponse,
     validate_credentials,
 )
+from delta_barth.constants import DB_ECHO
 from delta_barth.errors import STATUS_HANDLER
 from delta_barth.logging import logger_session as logger
 from delta_barth.types import DelBarApiError, Status
@@ -36,9 +39,13 @@ class Session:
     def __init__(
         self,
         base_headers: HttpContentHeaders,
+        db_folder: str = "data",
         logging_folder: str = "logs",
     ) -> None:
         self._data_path: Path | None = None
+        self._db_path: Path | None = None
+        self._db_folder = db_folder
+        self._db_engine: sql.Engine | None = None
         self._logging_dir: Path | None = None
         self._logging_folder = logging_folder
         self._creds: ApiCredentials | None = None
@@ -48,13 +55,31 @@ class Session:
         self._logged_in: bool = False
 
     def setup(self) -> None:
-        self.setup_logging()
+        self._setup_db_management()
+        self._setup_logging()
 
     @property
     def data_path(self) -> Path:
         assert self._data_path is not None, "accessed data path not set"
         return self._data_path
 
+    @property
+    def db_engine(self) -> sql.Engine:
+        assert self._db_engine is not None, "accessed database engine not set"
+        return self._db_engine
+
+    @property
+    def db_path(self) -> Path:
+        if self._db_path is not None:
+            return self._db_path
+
+        db_root = (self.data_path / self._db_folder).resolve()
+        db_path = db_root / "dopt-data.db"
+        if not db_root.exists():
+            db_root.mkdir(parents=False)
+        self._db_path = db_path
+        return self._db_path
+
     @property
     def logging_dir(self) -> Path:
         if self._logging_dir is not None:
@@ -66,7 +91,12 @@ class Session:
         self._logging_dir = logging_dir
         return self._logging_dir
 
-    def setup_logging(self) -> None:
+    def _setup_db_management(self) -> None:
+        self._db_engine = db.get_engine(self.db_path, echo=DB_ECHO)
+        db.metadata.create_all(self._db_engine)
+        logger.info("[SESSION] Successfully setup DB management")
+
+    def _setup_logging(self) -> None:
         delta_barth.logging.setup_logging(self.logging_dir)
         logger.info("[SESSION] Successfully setup logging")
 
diff --git a/src/delta_barth/types.py b/src/delta_barth/types.py
index 43a0c13..67e5887 100644
--- a/src/delta_barth/types.py
+++ b/src/delta_barth/types.py
@@ -161,10 +161,11 @@ class SalesForecastStatistics(Statistics):
     score_mae: float | None = None
     score_r2: float | None = None
     best_start_year: int | None = None
-    XGB_params: BestParametersXGBRegressor | None = None
+    xgb_params: BestParametersXGBRegressor | None = None
 
 
 class BestParametersXGBRegressor(t.TypedDict):
+    forecast_id: t.NotRequired[int]
     n_estimators: int
     learning_rate: float
     max_depth: int
diff --git a/tests/analysis/test_forecast.py b/tests/analysis/test_forecast.py
index 5548578..2bd5202 100644
--- a/tests/analysis/test_forecast.py
+++ b/tests/analysis/test_forecast.py
@@ -158,7 +158,6 @@ def test_preprocess_sales_FailOnTargetFeature(
     assert pipe.results is None
 
 
-@pytest.mark.new
 def test_process_sales_Success(sales_data_real_preproc):
     data = sales_data_real_preproc.copy()
     pipe = PipeResult(data, STATUS_HANDLER.SUCCESS)
@@ -178,10 +177,9 @@ def test_process_sales_Success(sales_data_real_preproc):
     assert pipe.statistics.score_mae is not None
     assert pipe.statistics.score_r2 is not None
     assert pipe.statistics.best_start_year is not None
-    assert pipe.statistics.XGB_params is not None
+    assert pipe.statistics.xgb_params is not None
 
 
-@pytest.mark.new
 def test_process_sales_FailTooFewPoints(sales_data_real_preproc):
     data = sales_data_real_preproc.copy()
     data = data.iloc[:20, :]
@@ -205,10 +203,9 @@ def test_process_sales_FailTooFewPoints(sales_data_real_preproc):
     assert pipe.statistics.score_mae is None
     assert pipe.statistics.score_r2 is None
     assert pipe.statistics.best_start_year is None
-    assert pipe.statistics.XGB_params is None
+    assert pipe.statistics.xgb_params is None
 
 
-@pytest.mark.new
 def test_process_sales_FailTooFewMonthPoints(sales_data_real_preproc):
     data = sales_data_real_preproc.copy()
     pipe = PipeResult(data, STATUS_HANDLER.SUCCESS)
@@ -232,10 +229,9 @@ def test_process_sales_FailTooFewMonthPoints(sales_data_real_preproc):
     assert pipe.statistics.score_mae is None
     assert pipe.statistics.score_r2 is None
     assert pipe.statistics.best_start_year is None
-    assert pipe.statistics.XGB_params is None
+    assert pipe.statistics.xgb_params is None
 
 
-@pytest.mark.new
 def test_process_sales_FailNoReliableForecast(sales_data_real_preproc):
     data = sales_data_real_preproc.copy()
     data["betrag"] = 10000
@@ -278,7 +274,7 @@ def test_process_sales_FailNoReliableForecast(sales_data_real_preproc):
     assert pipe.statistics.score_mae is None
     assert pipe.statistics.score_r2 is None
     assert pipe.statistics.best_start_year is None
-    assert pipe.statistics.XGB_params is None
+    assert pipe.statistics.xgb_params is None
 
 
 def test_postprocess_sales_Success(
diff --git a/tests/test_databases.py b/tests/test_databases.py
new file mode 100644
index 0000000..3e5f8b2
--- /dev/null
+++ b/tests/test_databases.py
@@ -0,0 +1,119 @@
+from dataclasses import asdict
+
+import pytest
+import sqlalchemy as sql
+
+from delta_barth import databases as db
+from delta_barth.types import BestParametersXGBRegressor, SalesForecastStatistics
+
+
+def test_get_engine(tmp_path):
+    db_path = tmp_path / "test_db.db"
+    engine = db.get_engine(db_path)
+    assert isinstance(engine, sql.Engine)
+    assert "sqlite" in str(engine.url)
+    assert db_path.parent.name in str(engine.url)
+
+
+def test_write_sales_forecast_statistics_small(session):
+    eng = session.db_engine
+    code = 0
+    descr = "Test case to write stats"
+    length = 32
+    stats = SalesForecastStatistics(code, descr, length)
+    _ = stats.xgb_params
+
+    stats_to_write = asdict(stats)
+    _ = stats_to_write.pop("xgb_params")
+
+    with eng.begin() as conn:
+        res = conn.execute(sql.insert(db.sf_stats).values(stats_to_write))
+        _ = res.inserted_primary_key[0]
+
+    with eng.begin() as conn:
+        res = conn.execute(sql.select(db.sf_stats))
+        inserted = tuple(res.mappings())[0]
+
+    data = dict(**inserted)
+    del data["id"]
+    result = SalesForecastStatistics(**data)
+    assert result.status_code == code
+    assert result.status_dscr == descr
+    assert result.length_dataset == length
+    assert result.score_mae is None
+    assert result.score_r2 is None
+    assert result.best_start_year is None
+    assert result.xgb_params is None
+
+
+def test_write_sales_forecast_statistics_large(session):
+    eng = session.db_engine
+    code = 0
+    descr = "Test case to write stats"
+    length = 32
+    score_mae = 3.54
+    score_r2 = 0.56
+    best_start_year = 2020
+    xgb_params = BestParametersXGBRegressor(
+        n_estimators=2,
+        learning_rate=0.3,
+        max_depth=2,
+        min_child_weight=5,
+        gamma=0.5,
+        subsample=0.8,
+        colsample_bytree=5.25,
+        early_stopping_rounds=5,
+    )
+
+    stats = SalesForecastStatistics(
+        code,
+        descr,
+        length,
+        score_mae,
+        score_r2,
+        best_start_year,
+        xgb_params,
+    )
+    xgb_params = stats.xgb_params
+    assert xgb_params is not None
+
+    stats_to_write = asdict(stats)
+    _ = stats_to_write.pop("xgb_params")
+
+    with eng.begin() as conn:
+        res = conn.execute(sql.insert(db.sf_stats).values(stats_to_write))
+        sf_id = res.inserted_primary_key[0]
+        xgb_params["forecast_id"] = sf_id
+        res = conn.execute(sql.insert(db.sf_XGB).values(xgb_params))
+
+    with eng.begin() as conn:
+        res_stats = conn.execute(sql.select(db.sf_stats))
+        res_xgb = conn.execute(sql.select(db.sf_XGB))
+        # reconstruct best XGB parameters
+        inserted_xgb = tuple(res_xgb.mappings())[0]
+        data_xgb = dict(**inserted_xgb)
+        del data_xgb["id"]
+        xgb_stats = BestParametersXGBRegressor(**data_xgb)
+        # reconstruct other statistics
+        inserted = tuple(res_stats.mappings())[0]
+        data_inserted = dict(**inserted)
+        stats_id_fk = data_inserted["id"]  # foreign key in XGB parameters
+        del data_inserted["id"]
+        stats = SalesForecastStatistics(**data_inserted, xgb_params=xgb_stats)
+    assert stats.status_code == code
+    assert stats.status_dscr == descr
+    assert stats.length_dataset == length
+    assert stats.score_mae == pytest.approx(score_mae)
+    assert stats.score_r2 == pytest.approx(score_r2)
+    assert stats.best_start_year == best_start_year
+    assert stats.xgb_params is not None
+    # compare xgb_stats
+    assert stats.xgb_params["forecast_id"] == stats_id_fk  # type: ignore
+    assert stats.xgb_params["n_estimators"] == 2
+    assert stats.xgb_params["learning_rate"] == pytest.approx(0.3)
+    assert stats.xgb_params["max_depth"] == 2
+    assert stats.xgb_params["min_child_weight"] == 5
+    assert stats.xgb_params["gamma"] == pytest.approx(0.5)
+    assert stats.xgb_params["subsample"] == pytest.approx(0.8)
+    assert stats.xgb_params["colsample_bytree"] == pytest.approx(5.25)
+    assert stats.xgb_params["early_stopping_rounds"] == 5
diff --git a/tests/test_session.py b/tests/test_session.py
index 59506f2..780f1f7 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -43,6 +43,22 @@ def test_session_set_DataPath(tmp_path):
     assert isinstance(session.data_path, Path)
 
 
+def test_session_setup_db_management(tmp_path):
+    str_path = str(tmp_path)
+    foldername: str = "data_test"
+    target_db_dir = tmp_path / foldername
+
+    session = delta_barth.session.Session(HTTP_BASE_CONTENT_HEADERS, db_folder=foldername)
+    session.set_data_path(str_path)
+    db_path = session.db_path
+    assert db_path.parent.exists()
+    assert db_path.parent == target_db_dir
+    assert not db_path.exists()
+    session.setup()
+    assert session._db_engine is not None
+    assert db_path.exists()
+
+
 @patch("delta_barth.logging.ENABLE_LOGGING", True)
 @patch("delta_barth.logging.LOGGING_TO_FILE", True)
 def test_session_setup_logging(tmp_path):
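
For reviewers, a minimal end-to-end sketch of the new persistence path introduced by this patch (not part of the diff). It assumes placeholder request headers (the real HttpContentHeaders value is defined outside this diff) and a temporary data directory; Session, db.sf_stats, and db.sf_XGB are the names introduced or referenced above, and the insert pattern mirrors tests/test_databases.py.

# Usage sketch; assumptions: placeholder headers, temporary data directory.
from dataclasses import asdict
from tempfile import mkdtemp

import sqlalchemy as sql

import delta_barth.session
from delta_barth import databases as db
from delta_barth.types import BestParametersXGBRegressor, SalesForecastStatistics

base_headers = {"Content-Type": "application/json"}  # placeholder for HttpContentHeaders
session = delta_barth.session.Session(base_headers, db_folder="data")
session.set_data_path(mkdtemp())  # any existing working directory
session.setup()  # creates <data_path>/data/dopt-data.db and all tables

stats = SalesForecastStatistics(
    0,
    "forecast finished",
    128,
    score_mae=3.5,
    score_r2=0.9,
    best_start_year=2020,
    xgb_params=BestParametersXGBRegressor(
        n_estimators=100,
        learning_rate=0.1,
        max_depth=4,
        min_child_weight=5,
        gamma=0.5,
        subsample=0.8,
        colsample_bytree=0.8,
        early_stopping_rounds=5,
    ),
)
row = asdict(stats)
xgb = row.pop("xgb_params")  # stored in its own table, linked via forecast_id

with session.db_engine.begin() as conn:
    sf_id = conn.execute(sql.insert(db.sf_stats).values(row)).inserted_primary_key[0]
    xgb["forecast_id"] = sf_id
    conn.execute(sql.insert(db.sf_XGB).values(xgb))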