Merge "metrics_database": first database integration for metrics logging #9
@@ -297,7 +297,7 @@ def _process_sales(
         score_mae=best_score_mae,
         score_r2=best_score_r2,
         best_start_year=best_start_year,
-        XGB_params=best_params,
+        xgb_params=best_params,
     )
     pipe.stats(stats)

@@ -2,18 +2,17 @@ from pathlib import Path

 import sqlalchemy as sql

-from delta_barth.constants import DB_ECHO
-
 # ** meta
 metadata = sql.MetaData()


 def get_engine(
     db_path: Path,
+    echo: bool = False,
 ) -> sql.Engine:
     path = db_path.resolve()
     connection_str: str = f"sqlite:///{str(path)}"
-    engine = sql.create_engine(connection_str, echo=DB_ECHO)
+    engine = sql.create_engine(connection_str, echo=echo)
     return engine

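With the hunk above, `echo` becomes a regular parameter of `get_engine` instead of being read from the module-level `DB_ECHO` constant. A quick usage sketch under that assumption (temporary directory used only for illustration):

```python
# Minimal usage sketch for the refactored get_engine: echo is passed explicitly.
import tempfile
from pathlib import Path

import sqlalchemy as sql

from delta_barth import databases as db

with tempfile.TemporaryDirectory() as tmp:
    engine = db.get_engine(Path(tmp) / "metrics.db", echo=True)  # SQL statements are logged
    db.metadata.create_all(engine)                                # tables registered on metadata
    assert isinstance(engine, sql.Engine)
```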
@@ -31,8 +30,8 @@ sf_stats = sql.Table(
     "sales_forecast_statistics",
     metadata,
     sql.Column("id", sql.Integer, primary_key=True),
-    sql.Column("error_code", sql.Integer),
-    sql.Column("error_msg", sql.String(length=200)),
+    sql.Column("status_code", sql.Integer),
+    sql.Column("status_dscr", sql.String(length=200)),
     sql.Column("length_dataset", sql.Integer),
     sql.Column("score_mae", sql.Float, nullable=True),
     sql.Column("score_r2", sql.Float, nullable=True),
@@ -4,14 +4,17 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Final

 import requests
+import sqlalchemy as sql
 from dopt_basics.io import combine_route

 import delta_barth.logging
+from delta_barth import databases as db
 from delta_barth.api.common import (
     LoginRequest,
     LoginResponse,
     validate_credentials,
 )
+from delta_barth.constants import DB_ECHO
 from delta_barth.errors import STATUS_HANDLER
 from delta_barth.logging import logger_session as logger
 from delta_barth.types import DelBarApiError, Status
@@ -36,9 +39,13 @@ class Session:
     def __init__(
         self,
         base_headers: HttpContentHeaders,
+        db_folder: str = "data",
         logging_folder: str = "logs",
     ) -> None:
         self._data_path: Path | None = None
+        self._db_path: Path | None = None
+        self._db_folder = db_folder
+        self._db_engine: sql.Engine | None = None
         self._logging_dir: Path | None = None
         self._logging_folder = logging_folder
         self._creds: ApiCredentials | None = None
@@ -48,13 +55,31 @@ class Session:
         self._logged_in: bool = False

     def setup(self) -> None:
-        self.setup_logging()
+        self._setup_db_management()
+        self._setup_logging()

     @property
     def data_path(self) -> Path:
         assert self._data_path is not None, "accessed data path not set"
         return self._data_path

+    @property
+    def db_engine(self) -> sql.Engine:
+        assert self._db_engine is not None, "accessed database engine not set"
+        return self._db_engine
+
+    @property
+    def db_path(self) -> Path:
+        if self._db_path is not None:
+            return self._db_path
+
+        db_root = (self.data_path / self._db_folder).resolve()
+        db_path = db_root / "dopt-data.db"
+        if not db_root.exists():
+            db_root.mkdir(parents=False)
+        self._db_path = db_path
+        return self._db_path
+
     @property
     def logging_dir(self) -> Path:
         if self._logging_dir is not None:
@@ -66,7 +91,12 @@ class Session:
         self._logging_dir = logging_dir
         return self._logging_dir

-    def setup_logging(self) -> None:
+    def _setup_db_management(self) -> None:
+        self._db_engine = db.get_engine(self.db_path, echo=DB_ECHO)
+        db.metadata.create_all(self._db_engine)
+        logger.info("[SESSION] Successfully setup DB management")
+
+    def _setup_logging(self) -> None:
         delta_barth.logging.setup_logging(self.logging_dir)
         logger.info("[SESSION] Successfully setup logging")

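`setup()` now initializes the database before logging, and both `db_path` and `db_engine` depend on `set_data_path` having been called first. A sketch of the expected call order, mirroring `test_session_setup_db_management` further down; `HTTP_BASE_CONTENT_HEADERS` is a stand-in for the header constant used by the test suite:

```python
# Call-order sketch (mirrors the new session test at the end of this PR).
import delta_barth.session

# Stand-in value; the real constant is defined elsewhere in the test suite.
HTTP_BASE_CONTENT_HEADERS = {"Content-Type": "application/json"}

session = delta_barth.session.Session(HTTP_BASE_CONTENT_HEADERS, db_folder="data")
session.set_data_path("/path/to/workdir")  # db_path is derived lazily from data_path
session.setup()                            # _setup_db_management() first, then _setup_logging()
engine = session.db_engine                 # asserts if accessed before setup()
```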
@@ -161,10 +161,11 @@ class SalesForecastStatistics(Statistics):
     score_mae: float | None = None
     score_r2: float | None = None
     best_start_year: int | None = None
-    XGB_params: BestParametersXGBRegressor | None = None
+    xgb_params: BestParametersXGBRegressor | None = None


 class BestParametersXGBRegressor(t.TypedDict):
+    forecast_id: t.NotRequired[int]
     n_estimators: int
     learning_rate: float
     max_depth: int
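Renaming the field to `xgb_params` and adding the `NotRequired` `forecast_id` key lets the XGB parameters be stored in their own table and linked back to the statistics row after insert. A sketch of that split, following the pattern used in the new tests (values are examples only):

```python
# Statistics/parameters split, following tests/test_databases.py; example values only.
from dataclasses import asdict

from delta_barth.types import BestParametersXGBRegressor, SalesForecastStatistics

params = BestParametersXGBRegressor(
    n_estimators=2, learning_rate=0.3, max_depth=2, min_child_weight=5,
    gamma=0.5, subsample=0.8, colsample_bytree=0.9, early_stopping_rounds=5,
)
stats = SalesForecastStatistics(0, "success", 128, 3.5, 0.6, 2020, params)

row = asdict(stats)
xgb_row = row.pop("xgb_params")  # row -> sales_forecast_statistics, xgb_row -> XGB table
xgb_row["forecast_id"] = 1       # NotRequired key: filled in once the stats row id is known
```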
@@ -158,7 +158,6 @@ def test_preprocess_sales_FailOnTargetFeature(
     assert pipe.results is None


-@pytest.mark.new
 def test_process_sales_Success(sales_data_real_preproc):
     data = sales_data_real_preproc.copy()
     pipe = PipeResult(data, STATUS_HANDLER.SUCCESS)
@@ -178,10 +177,9 @@ def test_process_sales_Success(sales_data_real_preproc):
     assert pipe.statistics.score_mae is not None
     assert pipe.statistics.score_r2 is not None
     assert pipe.statistics.best_start_year is not None
-    assert pipe.statistics.XGB_params is not None
+    assert pipe.statistics.xgb_params is not None


-@pytest.mark.new
 def test_process_sales_FailTooFewPoints(sales_data_real_preproc):
     data = sales_data_real_preproc.copy()
     data = data.iloc[:20, :]
@@ -205,10 +203,9 @@ def test_process_sales_FailTooFewPoints(sales_data_real_preproc):
     assert pipe.statistics.score_mae is None
     assert pipe.statistics.score_r2 is None
     assert pipe.statistics.best_start_year is None
-    assert pipe.statistics.XGB_params is None
+    assert pipe.statistics.xgb_params is None


-@pytest.mark.new
 def test_process_sales_FailTooFewMonthPoints(sales_data_real_preproc):
     data = sales_data_real_preproc.copy()
     pipe = PipeResult(data, STATUS_HANDLER.SUCCESS)
@@ -232,10 +229,9 @@ def test_process_sales_FailTooFewMonthPoints(sales_data_real_preproc):
     assert pipe.statistics.score_mae is None
     assert pipe.statistics.score_r2 is None
     assert pipe.statistics.best_start_year is None
-    assert pipe.statistics.XGB_params is None
+    assert pipe.statistics.xgb_params is None


-@pytest.mark.new
 def test_process_sales_FailNoReliableForecast(sales_data_real_preproc):
     data = sales_data_real_preproc.copy()
     data["betrag"] = 10000
@@ -278,7 +274,7 @@ def test_process_sales_FailNoReliableForecast(sales_data_real_preproc):
     assert pipe.statistics.score_mae is None
     assert pipe.statistics.score_r2 is None
     assert pipe.statistics.best_start_year is None
-    assert pipe.statistics.XGB_params is None
+    assert pipe.statistics.xgb_params is None


 def test_postprocess_sales_Success(
tests/test_databases.py (new file, 119 lines)
@@ -0,0 +1,119 @@
+from dataclasses import asdict
+
+import pytest
+import sqlalchemy as sql
+
+from delta_barth import databases as db
+from delta_barth.types import BestParametersXGBRegressor, SalesForecastStatistics
+
+
+def test_get_engine(tmp_path):
+    db_path = tmp_path / "test_db.db"
+    engine = db.get_engine(db_path)
+    assert isinstance(engine, sql.Engine)
+    assert "sqlite" in str(engine.url)
+    assert db_path.parent.name in str(engine.url)
+
+
+def test_write_sales_forecast_statistics_small(session):
+    eng = session.db_engine
+    code = 0
+    descr = "Test case to write stats"
+    length = 32
+    stats = SalesForecastStatistics(code, descr, length)
+    _ = stats.xgb_params
+
+    stats_to_write = asdict(stats)
+    _ = stats_to_write.pop("xgb_params")
+
+    with eng.begin() as conn:
+        res = conn.execute(sql.insert(db.sf_stats).values(stats_to_write))
+        _ = res.inserted_primary_key[0]
+
+    with eng.begin() as conn:
+        res = conn.execute(sql.select(db.sf_stats))
+
+        inserted = tuple(res.mappings())[0]
+        data = dict(**inserted)
+        del data["id"]
+        result = SalesForecastStatistics(**data)
+        assert result.status_code == code
+        assert result.status_dscr == descr
+        assert result.length_dataset == length
+        assert result.score_mae is None
+        assert result.score_r2 is None
+        assert result.best_start_year is None
+        assert result.xgb_params is None
+
+
+def test_write_sales_forecast_statistics_large(session):
+    eng = session.db_engine
+    code = 0
+    descr = "Test case to write stats"
+    length = 32
+    score_mae = 3.54
+    score_r2 = 0.56
+    best_start_year = 2020
+    xgb_params = BestParametersXGBRegressor(
+        n_estimators=2,
+        learning_rate=0.3,
+        max_depth=2,
+        min_child_weight=5,
+        gamma=0.5,
+        subsample=0.8,
+        colsample_bytree=5.25,
+        early_stopping_rounds=5,
+    )
+
+    stats = SalesForecastStatistics(
+        code,
+        descr,
+        length,
+        score_mae,
+        score_r2,
+        best_start_year,
+        xgb_params,
+    )
+    xgb_params = stats.xgb_params
+    assert xgb_params is not None
+
+    stats_to_write = asdict(stats)
+    _ = stats_to_write.pop("xgb_params")
+
+    with eng.begin() as conn:
+        res = conn.execute(sql.insert(db.sf_stats).values(stats_to_write))
+        sf_id = res.inserted_primary_key[0]
+        xgb_params["forecast_id"] = sf_id
+        res = conn.execute(sql.insert(db.sf_XGB).values(xgb_params))
+
+    with eng.begin() as conn:
+        res_stats = conn.execute(sql.select(db.sf_stats))
+        res_xgb = conn.execute(sql.select(db.sf_XGB))
+        # reconstruct best XGB parameters
+        inserted_xgb = tuple(res_xgb.mappings())[0]
+        data_xgb = dict(**inserted_xgb)
+        del data_xgb["id"]
+        xgb_stats = BestParametersXGBRegressor(**data_xgb)
+        # reconstruct other statistics
+        inserted = tuple(res_stats.mappings())[0]
+        data_inserted = dict(**inserted)
+        stats_id_fk = data_inserted["id"]  # foreign key in XGB parameters
+        del data_inserted["id"]
+        stats = SalesForecastStatistics(**data_inserted, xgb_params=xgb_stats)
+        assert stats.status_code == code
+        assert stats.status_dscr == descr
+        assert stats.length_dataset == length
+        assert stats.score_mae == pytest.approx(score_mae)
+        assert stats.score_r2 == pytest.approx(score_r2)
+        assert stats.best_start_year == best_start_year
+        assert stats.xgb_params is not None
+        # compare xgb_stats
+        assert stats.xgb_params["forecast_id"] == stats_id_fk  # type: ignore
+        assert stats.xgb_params["n_estimators"] == 2
+        assert stats.xgb_params["learning_rate"] == pytest.approx(0.3)
+        assert stats.xgb_params["max_depth"] == 2
+        assert stats.xgb_params["min_child_weight"] == 5
+        assert stats.xgb_params["gamma"] == pytest.approx(0.5)
+        assert stats.xgb_params["subsample"] == pytest.approx(0.8)
+        assert stats.xgb_params["colsample_bytree"] == pytest.approx(5.25)
+        assert stats.xgb_params["early_stopping_rounds"] == 5
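The tests above insert into `db.sf_XGB`, whose definition is not part of the visible hunks. Purely for orientation, a hypothetical table definition consistent with how the tests use it; the table name and column types are guesses, not the actual schema:

```python
# Hypothetical sf_XGB definition inferred from the tests above; the real one
# lives in the databases module and may differ (name, types, constraints).
import sqlalchemy as sql

metadata = sql.MetaData()

sf_XGB = sql.Table(
    "sales_forecast_xgb_parameters",  # assumed name
    metadata,
    sql.Column("id", sql.Integer, primary_key=True),
    sql.Column("forecast_id", sql.Integer),  # references sales_forecast_statistics.id
    sql.Column("n_estimators", sql.Integer),
    sql.Column("learning_rate", sql.Float),
    sql.Column("max_depth", sql.Integer),
    sql.Column("min_child_weight", sql.Integer),
    sql.Column("gamma", sql.Float),
    sql.Column("subsample", sql.Float),
    sql.Column("colsample_bytree", sql.Float),
    sql.Column("early_stopping_rounds", sql.Integer),
)
```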
@@ -43,6 +43,22 @@ def test_session_set_DataPath(tmp_path):
     assert isinstance(session.data_path, Path)


+def test_session_setup_db_management(tmp_path):
+    str_path = str(tmp_path)
+    foldername: str = "data_test"
+    target_db_dir = tmp_path / foldername
+
+    session = delta_barth.session.Session(HTTP_BASE_CONTENT_HEADERS, db_folder=foldername)
+    session.set_data_path(str_path)
+    db_path = session.db_path
+    assert db_path.parent.exists()
+    assert db_path.parent == target_db_dir
+    assert not db_path.exists()
+    session.setup()
+    assert session._db_engine is not None
+    assert db_path.exists()
+
+
 @patch("delta_barth.logging.ENABLE_LOGGING", True)
 @patch("delta_barth.logging.LOGGING_TO_FILE", True)
 def test_session_setup_logging(tmp_path):