integrate database writing procedures for logging purposes
This commit was merged in pull request #9.
This commit is contained in:
@@ -4,12 +4,19 @@ from unittest.mock import patch
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
import sqlalchemy as sql
|
||||
from pydantic import ValidationError
|
||||
|
||||
from delta_barth import databases as db
|
||||
from delta_barth.analysis import forecast as fc
|
||||
from delta_barth.api.requests import SalesPrognosisResponse, SalesPrognosisResponseEntry
|
||||
from delta_barth.errors import STATUS_HANDLER
|
||||
from delta_barth.types import DualDict, PipeResult
|
||||
from delta_barth.types import (
|
||||
BestParametersXGBRegressor,
|
||||
DualDict,
|
||||
PipeResult,
|
||||
SalesForecastStatistics,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@@ -123,6 +130,96 @@ def test_parse_df_to_results_InvalidData(invalid_results):
|
||||
_ = fc._parse_df_to_results(invalid_results)
|
||||
|
||||
|
||||
def test_write_sales_forecast_stats_small(session):
|
||||
eng = session.db_engine
|
||||
code = 0
|
||||
descr = "Test case to write stats"
|
||||
length = 32
|
||||
stats = SalesForecastStatistics(code, descr, length)
|
||||
# execute
|
||||
with patch("delta_barth.analysis.forecast.SESSION", session):
|
||||
fc._write_sales_forecast_stats(stats)
|
||||
# read
|
||||
with eng.begin() as conn:
|
||||
res = conn.execute(sql.select(db.sf_stats))
|
||||
|
||||
inserted = tuple(res.mappings())[0]
|
||||
data = dict(**inserted)
|
||||
del data["id"]
|
||||
result = SalesForecastStatistics(**data)
|
||||
assert result.status_code == code
|
||||
assert result.status_dscr == descr
|
||||
assert result.length_dataset == length
|
||||
assert result.score_mae is None
|
||||
assert result.score_r2 is None
|
||||
assert result.best_start_year is None
|
||||
assert result.xgb_params is None
|
||||
|
||||
|
||||
def test_write_sales_forecast_stats_large(session):
|
||||
eng = session.db_engine
|
||||
code = 0
|
||||
descr = "Test case to write stats"
|
||||
length = 32
|
||||
score_mae = 3.54
|
||||
score_r2 = 0.56
|
||||
best_start_year = 2020
|
||||
xgb_params = BestParametersXGBRegressor(
|
||||
n_estimators=2,
|
||||
learning_rate=0.3,
|
||||
max_depth=2,
|
||||
min_child_weight=5,
|
||||
gamma=0.5,
|
||||
subsample=0.8,
|
||||
colsample_bytree=5.25,
|
||||
early_stopping_rounds=5,
|
||||
)
|
||||
stats = SalesForecastStatistics(
|
||||
code,
|
||||
descr,
|
||||
length,
|
||||
score_mae,
|
||||
score_r2,
|
||||
best_start_year,
|
||||
xgb_params,
|
||||
)
|
||||
# execute
|
||||
with patch("delta_barth.analysis.forecast.SESSION", session):
|
||||
fc._write_sales_forecast_stats(stats)
|
||||
# read
|
||||
with eng.begin() as conn:
|
||||
res_stats = conn.execute(sql.select(db.sf_stats))
|
||||
res_xgb = conn.execute(sql.select(db.sf_XGB))
|
||||
# reconstruct best XGB parameters
|
||||
inserted_xgb = tuple(res_xgb.mappings())[0]
|
||||
data_xgb = dict(**inserted_xgb)
|
||||
del data_xgb["id"]
|
||||
xgb_stats = BestParametersXGBRegressor(**data_xgb)
|
||||
# reconstruct other statistics
|
||||
inserted = tuple(res_stats.mappings())[0]
|
||||
data_inserted = dict(**inserted)
|
||||
stats_id_fk = data_inserted["id"] # foreign key in XGB parameters
|
||||
del data_inserted["id"]
|
||||
stats = SalesForecastStatistics(**data_inserted, xgb_params=xgb_stats)
|
||||
assert stats.status_code == code
|
||||
assert stats.status_dscr == descr
|
||||
assert stats.length_dataset == length
|
||||
assert stats.score_mae == pytest.approx(score_mae)
|
||||
assert stats.score_r2 == pytest.approx(score_r2)
|
||||
assert stats.best_start_year == best_start_year
|
||||
assert stats.xgb_params is not None
|
||||
# compare xgb_stats
|
||||
assert stats.xgb_params["forecast_id"] == stats_id_fk # type: ignore
|
||||
assert stats.xgb_params["n_estimators"] == 2
|
||||
assert stats.xgb_params["learning_rate"] == pytest.approx(0.3)
|
||||
assert stats.xgb_params["max_depth"] == 2
|
||||
assert stats.xgb_params["min_child_weight"] == 5
|
||||
assert stats.xgb_params["gamma"] == pytest.approx(0.5)
|
||||
assert stats.xgb_params["subsample"] == pytest.approx(0.8)
|
||||
assert stats.xgb_params["colsample_bytree"] == pytest.approx(5.25)
|
||||
assert stats.xgb_params["early_stopping_rounds"] == 5
|
||||
|
||||
|
||||
def test_preprocess_sales_Success(
|
||||
exmpl_api_sales_prognosis_resp,
|
||||
feature_map,
|
||||
@@ -319,16 +416,25 @@ def test_export_on_fail():
|
||||
|
||||
|
||||
@patch("delta_barth.analysis.forecast.SALES_BASE_NUM_DATAPOINTS_MONTHS", 1)
|
||||
def test_pipeline_sales_prognosis(exmpl_api_sales_prognosis_resp):
|
||||
def mock_request(*args, **kwargs): # pragma: no cover
|
||||
return exmpl_api_sales_prognosis_resp, STATUS_HANDLER.SUCCESS
|
||||
|
||||
def test_pipeline_sales_forecast_SuccessDbWrite(exmpl_api_sales_prognosis_resp, session):
|
||||
with patch(
|
||||
"delta_barth.analysis.forecast.get_sales_prognosis_data",
|
||||
# new=mock_request,
|
||||
) as mock:
|
||||
mock.return_value = exmpl_api_sales_prognosis_resp, STATUS_HANDLER.SUCCESS
|
||||
result = fc.pipeline_sales(None) # type: ignore
|
||||
with patch("delta_barth.analysis.forecast.SESSION", session):
|
||||
result = fc.pipeline_sales_forecast(None) # type: ignore
|
||||
print(result)
|
||||
assert result.status == STATUS_HANDLER.SUCCESS
|
||||
assert len(result.response.daten) > 0
|
||||
|
||||
|
||||
@patch("delta_barth.analysis.forecast.SALES_BASE_NUM_DATAPOINTS_MONTHS", 1)
|
||||
def test_pipeline_sales_forecast_FailDbWrite(exmpl_api_sales_prognosis_resp):
|
||||
with patch(
|
||||
"delta_barth.analysis.forecast.get_sales_prognosis_data",
|
||||
) as mock:
|
||||
mock.return_value = exmpl_api_sales_prognosis_resp, STATUS_HANDLER.SUCCESS
|
||||
result = fc.pipeline_sales_forecast(None) # type: ignore
|
||||
print(result)
|
||||
assert result.status == STATUS_HANDLER.SUCCESS
|
||||
assert len(result.response.daten) > 0
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
from dataclasses import asdict
|
||||
|
||||
import pytest
|
||||
import sqlalchemy as sql
|
||||
|
||||
from delta_barth import databases as db
|
||||
from delta_barth.types import BestParametersXGBRegressor, SalesForecastStatistics
|
||||
|
||||
|
||||
def test_get_engine(tmp_path):
|
||||
@@ -13,107 +9,3 @@ def test_get_engine(tmp_path):
|
||||
assert isinstance(engine, sql.Engine)
|
||||
assert "sqlite" in str(engine.url)
|
||||
assert db_path.parent.name in str(engine.url)
|
||||
|
||||
|
||||
def test_write_sales_forecast_statistics_small(session):
|
||||
eng = session.db_engine
|
||||
code = 0
|
||||
descr = "Test case to write stats"
|
||||
length = 32
|
||||
stats = SalesForecastStatistics(code, descr, length)
|
||||
_ = stats.xgb_params
|
||||
|
||||
stats_to_write = asdict(stats)
|
||||
_ = stats_to_write.pop("xgb_params")
|
||||
|
||||
with eng.begin() as conn:
|
||||
res = conn.execute(sql.insert(db.sf_stats).values(stats_to_write))
|
||||
_ = res.inserted_primary_key[0]
|
||||
|
||||
with eng.begin() as conn:
|
||||
res = conn.execute(sql.select(db.sf_stats))
|
||||
|
||||
inserted = tuple(res.mappings())[0]
|
||||
data = dict(**inserted)
|
||||
del data["id"]
|
||||
result = SalesForecastStatistics(**data)
|
||||
assert result.status_code == code
|
||||
assert result.status_dscr == descr
|
||||
assert result.length_dataset == length
|
||||
assert result.score_mae is None
|
||||
assert result.score_r2 is None
|
||||
assert result.best_start_year is None
|
||||
assert result.xgb_params is None
|
||||
|
||||
|
||||
def test_write_sales_forecast_statistics_large(session):
|
||||
eng = session.db_engine
|
||||
code = 0
|
||||
descr = "Test case to write stats"
|
||||
length = 32
|
||||
score_mae = 3.54
|
||||
score_r2 = 0.56
|
||||
best_start_year = 2020
|
||||
xgb_params = BestParametersXGBRegressor(
|
||||
n_estimators=2,
|
||||
learning_rate=0.3,
|
||||
max_depth=2,
|
||||
min_child_weight=5,
|
||||
gamma=0.5,
|
||||
subsample=0.8,
|
||||
colsample_bytree=5.25,
|
||||
early_stopping_rounds=5,
|
||||
)
|
||||
|
||||
stats = SalesForecastStatistics(
|
||||
code,
|
||||
descr,
|
||||
length,
|
||||
score_mae,
|
||||
score_r2,
|
||||
best_start_year,
|
||||
xgb_params,
|
||||
)
|
||||
xgb_params = stats.xgb_params
|
||||
assert xgb_params is not None
|
||||
|
||||
stats_to_write = asdict(stats)
|
||||
_ = stats_to_write.pop("xgb_params")
|
||||
|
||||
with eng.begin() as conn:
|
||||
res = conn.execute(sql.insert(db.sf_stats).values(stats_to_write))
|
||||
sf_id = res.inserted_primary_key[0]
|
||||
xgb_params["forecast_id"] = sf_id
|
||||
res = conn.execute(sql.insert(db.sf_XGB).values(xgb_params))
|
||||
|
||||
with eng.begin() as conn:
|
||||
res_stats = conn.execute(sql.select(db.sf_stats))
|
||||
res_xgb = conn.execute(sql.select(db.sf_XGB))
|
||||
# reconstruct best XGB parameters
|
||||
inserted_xgb = tuple(res_xgb.mappings())[0]
|
||||
data_xgb = dict(**inserted_xgb)
|
||||
del data_xgb["id"]
|
||||
xgb_stats = BestParametersXGBRegressor(**data_xgb)
|
||||
# reconstruct other statistics
|
||||
inserted = tuple(res_stats.mappings())[0]
|
||||
data_inserted = dict(**inserted)
|
||||
stats_id_fk = data_inserted["id"] # foreign key in XGB parameters
|
||||
del data_inserted["id"]
|
||||
stats = SalesForecastStatistics(**data_inserted, xgb_params=xgb_stats)
|
||||
assert stats.status_code == code
|
||||
assert stats.status_dscr == descr
|
||||
assert stats.length_dataset == length
|
||||
assert stats.score_mae == pytest.approx(score_mae)
|
||||
assert stats.score_r2 == pytest.approx(score_r2)
|
||||
assert stats.best_start_year == best_start_year
|
||||
assert stats.xgb_params is not None
|
||||
# compare xgb_stats
|
||||
assert stats.xgb_params["forecast_id"] == stats_id_fk # type: ignore
|
||||
assert stats.xgb_params["n_estimators"] == 2
|
||||
assert stats.xgb_params["learning_rate"] == pytest.approx(0.3)
|
||||
assert stats.xgb_params["max_depth"] == 2
|
||||
assert stats.xgb_params["min_child_weight"] == 5
|
||||
assert stats.xgb_params["gamma"] == pytest.approx(0.5)
|
||||
assert stats.xgb_params["subsample"] == pytest.approx(0.8)
|
||||
assert stats.xgb_params["colsample_bytree"] == pytest.approx(5.25)
|
||||
assert stats.xgb_params["early_stopping_rounds"] == 5
|
||||
|
||||
Reference in New Issue
Block a user