prepare database writing operations for pipeline stats
This commit is contained in:
@@ -297,7 +297,7 @@ def _process_sales(
|
||||
score_mae=best_score_mae,
|
||||
score_r2=best_score_r2,
|
||||
best_start_year=best_start_year,
|
||||
XGB_params=best_params,
|
||||
xgb_params=best_params,
|
||||
)
|
||||
pipe.stats(stats)
|
||||
|
||||
|
||||
@@ -2,18 +2,17 @@ from pathlib import Path
|
||||
|
||||
import sqlalchemy as sql
|
||||
|
||||
from delta_barth.constants import DB_ECHO
|
||||
|
||||
# ** meta
|
||||
metadata = sql.MetaData()
|
||||
|
||||
|
||||
def get_engine(
|
||||
db_path: Path,
|
||||
echo: bool = False,
|
||||
) -> sql.Engine:
|
||||
path = db_path.resolve()
|
||||
connection_str: str = f"sqlite:///{str(path)}"
|
||||
engine = sql.create_engine(connection_str, echo=DB_ECHO)
|
||||
engine = sql.create_engine(connection_str, echo=echo)
|
||||
return engine
|
||||
|
||||
|
||||
@@ -31,8 +30,8 @@ sf_stats = sql.Table(
|
||||
"sales_forecast_statistics",
|
||||
metadata,
|
||||
sql.Column("id", sql.Integer, primary_key=True),
|
||||
sql.Column("error_code", sql.Integer),
|
||||
sql.Column("error_msg", sql.String(length=200)),
|
||||
sql.Column("status_code", sql.Integer),
|
||||
sql.Column("status_dscr", sql.String(length=200)),
|
||||
sql.Column("length_dataset", sql.Integer),
|
||||
sql.Column("score_mae", sql.Float, nullable=True),
|
||||
sql.Column("score_r2", sql.Float, nullable=True),
|
||||
|
||||
@@ -4,14 +4,17 @@ from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Final
|
||||
|
||||
import requests
|
||||
import sqlalchemy as sql
|
||||
from dopt_basics.io import combine_route
|
||||
|
||||
import delta_barth.logging
|
||||
from delta_barth import databases as db
|
||||
from delta_barth.api.common import (
|
||||
LoginRequest,
|
||||
LoginResponse,
|
||||
validate_credentials,
|
||||
)
|
||||
from delta_barth.constants import DB_ECHO
|
||||
from delta_barth.errors import STATUS_HANDLER
|
||||
from delta_barth.logging import logger_session as logger
|
||||
from delta_barth.types import DelBarApiError, Status
|
||||
@@ -36,9 +39,13 @@ class Session:
|
||||
def __init__(
|
||||
self,
|
||||
base_headers: HttpContentHeaders,
|
||||
db_folder: str = "data",
|
||||
logging_folder: str = "logs",
|
||||
) -> None:
|
||||
self._data_path: Path | None = None
|
||||
self._db_path: Path | None = None
|
||||
self._db_folder = db_folder
|
||||
self._db_engine: sql.Engine | None = None
|
||||
self._logging_dir: Path | None = None
|
||||
self._logging_folder = logging_folder
|
||||
self._creds: ApiCredentials | None = None
|
||||
@@ -48,13 +55,31 @@ class Session:
|
||||
self._logged_in: bool = False
|
||||
|
||||
def setup(self) -> None:
|
||||
self.setup_logging()
|
||||
self._setup_db_management()
|
||||
self._setup_logging()
|
||||
|
||||
@property
|
||||
def data_path(self) -> Path:
|
||||
assert self._data_path is not None, "accessed data path not set"
|
||||
return self._data_path
|
||||
|
||||
@property
|
||||
def db_engine(self) -> sql.Engine:
|
||||
assert self._db_engine is not None, "accessed database engine not set"
|
||||
return self._db_engine
|
||||
|
||||
@property
|
||||
def db_path(self) -> Path:
|
||||
if self._db_path is not None:
|
||||
return self._db_path
|
||||
|
||||
db_root = (self.data_path / self._db_folder).resolve()
|
||||
db_path = db_root / "dopt-data.db"
|
||||
if not db_root.exists():
|
||||
db_root.mkdir(parents=False)
|
||||
self._db_path = db_path
|
||||
return self._db_path
|
||||
|
||||
@property
|
||||
def logging_dir(self) -> Path:
|
||||
if self._logging_dir is not None:
|
||||
@@ -66,7 +91,12 @@ class Session:
|
||||
self._logging_dir = logging_dir
|
||||
return self._logging_dir
|
||||
|
||||
def setup_logging(self) -> None:
|
||||
def _setup_db_management(self) -> None:
|
||||
self._db_engine = db.get_engine(self.db_path, echo=DB_ECHO)
|
||||
db.metadata.create_all(self._db_engine)
|
||||
logger.info("[SESSION] Successfully setup DB management")
|
||||
|
||||
def _setup_logging(self) -> None:
|
||||
delta_barth.logging.setup_logging(self.logging_dir)
|
||||
logger.info("[SESSION] Successfully setup logging")
|
||||
|
||||
|
||||
@@ -161,10 +161,11 @@ class SalesForecastStatistics(Statistics):
|
||||
score_mae: float | None = None
|
||||
score_r2: float | None = None
|
||||
best_start_year: int | None = None
|
||||
XGB_params: BestParametersXGBRegressor | None = None
|
||||
xgb_params: BestParametersXGBRegressor | None = None
|
||||
|
||||
|
||||
class BestParametersXGBRegressor(t.TypedDict):
|
||||
forecast_id: t.NotRequired[int]
|
||||
n_estimators: int
|
||||
learning_rate: float
|
||||
max_depth: int
|
||||
|
||||
Reference in New Issue
Block a user