"""SQLite engine factory and SQLAlchemy Core table definitions."""

from pathlib import Path
import sqlalchemy as sql
# ** meta
# Shared schema registry: every table declared in this module is attached to it.
metadata = sql.MetaData()
def get_engine(
    db_path: Path,
    echo: bool = False,
) -> sql.Engine:
    """Create a SQLAlchemy engine for the SQLite database file at ``db_path``.

    Args:
        db_path: Location of the SQLite database file. It is made absolute
            (and symlinks are resolved) with ``Path.resolve()`` before use.
        echo: Forwarded to ``create_engine``; when true, SQLAlchemy logs the
            statements it emits.

    Returns:
        A new engine whose URL points at the resolved database path.
    """
    # Build the URL from components instead of interpolating the path into a
    # "sqlite:///..." string: a formatted string is re-parsed by SQLAlchemy,
    # which treats everything after a "?" in the path as URL query arguments.
    url = sql.URL.create(drivername="sqlite", database=str(db_path.resolve()))
    return sql.create_engine(url, echo=echo)
# ** table declarations
# ** ---- common
# Table "performance_measurement": a pipeline name (declared as VARCHAR(30))
# together with a floating-point execution duration.
# NOTE(review): the unit of ``execution_duration`` is not visible here - confirm.
perf_meas = sql.Table(
    "performance_measurement",
    metadata,
    # surrogate primary key
    sql.Column("id", sql.Integer, primary_key=True),
    sql.Column("pipeline_name", sql.String(30)),
    sql.Column("execution_duration", sql.Float),
)
# ** ---- forecasts
# Table "sales_forecast_statistics": status information, dataset length and
# optional score columns. It is the parent table referenced by the
# ``forecast_id`` foreign key of "sales_forecast_XGB_parameters".
sf_stats = sql.Table(
    "sales_forecast_statistics",
    metadata,
    # surrogate primary key
    sql.Column("id", sql.Integer, primary_key=True),
    # numeric status plus its textual description (declared as VARCHAR(200))
    sql.Column("status_code", sql.Integer),
    sql.Column("status_dscr", sql.String(200)),
    sql.Column("length_dataset", sql.Integer),
    # columns that are explicitly declared nullable, in their original order
    *(
        sql.Column(col_name, col_type, nullable=True)
        for col_name, col_type in (
            ("score_mae", sql.Float),
            ("score_r2", sql.Float),
            ("best_start_year", sql.Integer),
        )
    ),
)
# Table "sales_forecast_XGB_parameters": presumably the XGBoost hyperparameters
# used for a forecast (the column names match XGBoost parameter names - confirm).
# ``forecast_id`` is unique, so at most one parameter row can reference a given
# "sales_forecast_statistics" row.
# NOTE(review): SQLite only enforces foreign keys (and therefore the CASCADE
# rules below) when ``PRAGMA foreign_keys=ON`` is set on the connection; nothing
# in the visible part of this module sets it - confirm it is enabled elsewhere.
sf_XGB = sql.Table(
    "sales_forecast_XGB_parameters",
    metadata,
    # surrogate primary key
    sql.Column("id", sql.Integer, primary_key=True),
    sql.Column(
        "forecast_id",
        sql.Integer,
        sql.ForeignKey(
            "sales_forecast_statistics.id",
            onupdate="CASCADE",
            ondelete="CASCADE",
        ),
        unique=True,
    ),
    # plain value columns, generated in their original order
    *(
        sql.Column(col_name, col_type)
        for col_name, col_type in (
            ("n_estimators", sql.Integer),
            ("learning_rate", sql.Float),
            ("max_depth", sql.Integer),
            ("min_child_weight", sql.Integer),
            ("gamma", sql.Float),
            ("subsample", sql.Float),
            ("colsample_bytree", sql.Float),
            ("early_stopping_rounds", sql.Integer),
        )
    ),
)