Merge "metrics_database": first database integration for metrics logging #9
File: pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "delta-barth"
-version = "0.5.0dev0"
+version = "0.5.0"
 description = "workflows and pipelines for the Python-based Plugin of Delta Barth's ERP system"
 authors = [
     {name = "Florian Förster", email = "f.foerster@d-opt.com"},
@@ -73,7 +73,7 @@ directory = "reports/coverage"
 
 
 [tool.bumpversion]
-current_version = "0.5.0dev0"
+current_version = "0.5.0"
 parse = """(?x)
     (?P<major>0|[1-9]\\d*)\\.
     (?P<minor>0|[1-9]\\d*)\\.

Module: delta_barth.analysis.forecast

@@ -3,16 +3,19 @@ from __future__ import annotations
 import datetime
 import math
 from collections.abc import Mapping, Set
+from dataclasses import asdict
 from datetime import datetime as Datetime
 from typing import TYPE_CHECKING, Final, TypeAlias, cast
 
 import numpy as np
 import pandas as pd
 import scipy.stats
+import sqlalchemy as sql
 from sklearn.metrics import mean_absolute_error, r2_score
 from sklearn.model_selection import KFold, RandomizedSearchCV
 from xgboost import XGBRegressor
 
+from delta_barth import databases
 from delta_barth.analysis import parse
 from delta_barth.api.requests import (
     SalesPrognosisResponse,
@@ -29,7 +32,8 @@ from delta_barth.constants import (
     SALES_MIN_NUM_DATAPOINTS,
 )
 from delta_barth.errors import STATUS_HANDLER, wrap_result
-from delta_barth.logging import logger_pipelines as logger
+from delta_barth.logging import logger_db, logger_pipelines
+from delta_barth.management import SESSION
 from delta_barth.types import (
     BestParametersXGBRegressor,
     DualDict,
@@ -77,6 +81,21 @@ def _parse_df_to_results(
     return SalesPrognosisResults(daten=tuple(df_formatted))  # type: ignore
 
 
+def _write_sales_forecast_stats(
+    stats: SalesForecastStatistics,
+) -> None:
+    stats_db = asdict(stats)
+    _ = stats_db.pop("xgb_params")
+    xgb_params = stats.xgb_params
+
+    with SESSION.db_engine.begin() as conn:
+        res = conn.execute(sql.insert(databases.sf_stats).values(stats_db))
+        sf_id = cast(int, res.inserted_primary_key[0])  # type: ignore
+        if xgb_params is not None:
+            xgb_params["forecast_id"] = sf_id
+            conn.execute(sql.insert(databases.sf_XGB).values(xgb_params))
+
+
 @wrap_result()
 def _parse_api_resp_to_df_wrapped(
     resp: SalesPrognosisResponse,
@@ -91,23 +110,11 @@ def _parse_df_to_results_wrapped(
     return _parse_df_to_results(data)
 
 
-# ------------------------------------------------------------------------------
-# Input:
-#     DataFrame df mit Columns f_umsatz_fakt, firmen, art, v_warengrp
-#     kunde (muss enthalten sein in df['firmen']['firma_refid'])
-# Output:
-#     Integer umsetzung (Prognose möglich): 0 ja, 1 nein (zu wenig Daten verfügbar),
-#     2 nein (Daten nicht für Prognose geeignet)
-#     DataFrame test: Jahr, Monat, Vorhersage
-# -------------------------------------------------------------------------------
-
-
-# Prognose Umsatz je Firma
-
-
-# TODO: check usage of separate exception and handle it in API function
-# TODO set min number of data points as constant, not parameter
-
+@wrap_result()
+def _write_sales_forecast_stats_wrapped(
+    stats: SalesForecastStatistics,
+) -> None:
+    return _write_sales_forecast_stats(stats)
 
 
 def _preprocess_sales(
@@ -341,7 +348,7 @@ def _export_on_fail(
     return SalesPrognosisResultsExport(response=response, status=status)
 
 
-def pipeline_sales(
+def pipeline_sales_forecast(
     session: Session,
     company_id: int | None = None,
     start_date: Datetime | None = None,
@@ -352,8 +359,8 @@ def pipeline_sales(
         start_date=start_date,
     )
     if status != STATUS_HANDLER.SUCCESS:
-        logger.error(
-            "Error during sales prognosis data retrieval, Status: %s",
+        logger_pipelines.error(
+            "Error during sales forecast data retrieval, Status: %s",
             status,
             stack_info=True,
         )
@@ -365,8 +372,8 @@ def pipeline_sales(
         target_features=FEATURES_SALES_PROGNOSIS,
     )
     if pipe.status != STATUS_HANDLER.SUCCESS:
-        logger.error(
-            "Error during sales prognosis preprocessing, Status: %s",
+        logger_pipelines.error(
+            "Error during sales forecast preprocessing, Status: %s",
             pipe.status,
             stack_info=True,
         )
@@ -377,9 +384,16 @@ def pipeline_sales(
         min_num_data_points=SALES_MIN_NUM_DATAPOINTS,
         base_num_data_points_months=SALES_BASE_NUM_DATAPOINTS_MONTHS,
     )
+    if pipe.statistics is not None:
+        res = _write_sales_forecast_stats_wrapped(pipe.statistics)
+        if res.status != STATUS_HANDLER.SUCCESS:
+            logger_db.error(
+                "[DB] Error during write process of sales forecast statistics: %s",
+                res.status,
+            )
     if pipe.status != STATUS_HANDLER.SUCCESS:
-        logger.error(
-            "Error during sales prognosis main processing, Status: %s",
+        logger_pipelines.error(
+            "Error during sales forecast main processing, Status: %s",
             pipe.status,
             stack_info=True,
         )
@@ -390,8 +404,8 @@ def pipeline_sales(
         feature_map=DualDict(),
     )
     if pipe.status != STATUS_HANDLER.SUCCESS:
-        logger.error(
-            "Error during sales prognosis postprocessing, Status: %s",
+        logger_pipelines.error(
+            "Error during sales forecast postprocessing, Status: %s",
             pipe.status,
             stack_info=True,
         )
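
Note: _write_sales_forecast_stats above inserts into databases.sf_stats and databases.sf_XGB, whose definitions are not part of this diff. Below is a minimal sketch of SQLAlchemy Core table objects the function could run against; the column names mirror the fields written in this diff, while the concrete column types, nullability, and table-name strings are assumptions.

import sqlalchemy as sql

metadata = sql.MetaData()

# Parent table: one row per forecast run (column names taken from this
# diff, types assumed).
sf_stats = sql.Table(
    "sf_stats",
    metadata,
    sql.Column("id", sql.Integer, primary_key=True),
    sql.Column("status_code", sql.Integer),
    sql.Column("status_dscr", sql.String),
    sql.Column("length_dataset", sql.Integer),
    sql.Column("score_mae", sql.Float, nullable=True),
    sql.Column("score_r2", sql.Float, nullable=True),
    sql.Column("best_start_year", sql.Integer, nullable=True),
)

# Child table: tuned XGBoost parameters, linked back to their run via
# forecast_id, which is how _write_sales_forecast_stats sets the key.
sf_XGB = sql.Table(
    "sf_XGB",
    metadata,
    sql.Column("id", sql.Integer, primary_key=True),
    sql.Column("forecast_id", sql.Integer, sql.ForeignKey("sf_stats.id")),
    sql.Column("n_estimators", sql.Integer),
    sql.Column("learning_rate", sql.Float),
    sql.Column("max_depth", sql.Integer),
    sql.Column("min_child_weight", sql.Integer),
    sql.Column("gamma", sql.Float),
    sql.Column("subsample", sql.Float),
    sql.Column("colsample_bytree", sql.Float),
    sql.Column("early_stopping_rounds", sql.Integer),
)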

Module: delta_barth.logging

@@ -34,8 +34,10 @@ logger_session = logging.getLogger("delta_barth.session")
 logger_session.setLevel(logging.DEBUG)
 logger_wrapped_results = logging.getLogger("delta_barth.wrapped_results")
 logger_wrapped_results.setLevel(logging.DEBUG)
-logger_pipelines = logging.getLogger("delta_barth.logger_pipelines")
+logger_pipelines = logging.getLogger("delta_barth.pipelines")
 logger_pipelines.setLevel(logging.DEBUG)
+logger_db = logging.getLogger("delta_barth.databases")
+logger_db.setLevel(logging.DEBUG)
 
 
 def setup_logging(

API export wrapper

@@ -11,7 +11,9 @@ def pipeline_sales_forecast(
     company_id: int | None,
     start_date: Datetime | None,
 ) -> JsonExportResponse:
-    result = forecast.pipeline_sales(SESSION, company_id=company_id, start_date=start_date)
+    result = forecast.pipeline_sales_forecast(
+        SESSION, company_id=company_id, start_date=start_date
+    )
     export = JsonExportResponse(result.model_dump_json())
 
     return export

Status model

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import enum
+import pprint
 import typing as t
 from collections.abc import Sequence
 from dataclasses import dataclass, field
@@ -29,6 +30,10 @@ class Status(BaseModel):
     message: SkipValidation[str] = ""
     api_server_error: SkipValidation[DelBarApiError | None] = None
 
+    def __str__(self) -> str:
+        py_repr = self.model_dump()
+        return pprint.pformat(py_repr, indent=4, sort_dicts=False)
+
 
 class ResponseType(BaseModel):
     pass
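
Note: the new __str__ delegates to pprint.pformat with indent=4 and sort_dicts=False, so fields render in declaration order and short dumps stay on one line. A standalone sketch of that behaviour; the field set below is illustrative, since only message and api_server_error appear in this diff.

import pprint

# Stand-in for Status.model_dump(); "code" is a hypothetical field.
py_repr = {
    "code": 0,
    "message": "success",
    "api_server_error": None,
}
print(pprint.pformat(py_repr, indent=4, sort_dicts=False))
# Fits on one line here; longer dumps wrap one key per line with
# four-space indentation, keeping the declared key order.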

Forecast pipeline tests

@@ -4,12 +4,19 @@ from unittest.mock import patch
 import numpy as np
 import pandas as pd
 import pytest
+import sqlalchemy as sql
 from pydantic import ValidationError
 
+from delta_barth import databases as db
 from delta_barth.analysis import forecast as fc
 from delta_barth.api.requests import SalesPrognosisResponse, SalesPrognosisResponseEntry
 from delta_barth.errors import STATUS_HANDLER
-from delta_barth.types import DualDict, PipeResult
+from delta_barth.types import (
+    BestParametersXGBRegressor,
+    DualDict,
+    PipeResult,
+    SalesForecastStatistics,
+)
 
 
 @pytest.fixture(scope="function")
@@ -123,6 +130,96 @@ def test_parse_df_to_results_InvalidData(invalid_results):
     _ = fc._parse_df_to_results(invalid_results)
 
 
+def test_write_sales_forecast_stats_small(session):
+    eng = session.db_engine
+    code = 0
+    descr = "Test case to write stats"
+    length = 32
+    stats = SalesForecastStatistics(code, descr, length)
+    # execute
+    with patch("delta_barth.analysis.forecast.SESSION", session):
+        fc._write_sales_forecast_stats(stats)
+    # read
+    with eng.begin() as conn:
+        res = conn.execute(sql.select(db.sf_stats))
+
+    inserted = tuple(res.mappings())[0]
+    data = dict(**inserted)
+    del data["id"]
+    result = SalesForecastStatistics(**data)
+    assert result.status_code == code
+    assert result.status_dscr == descr
+    assert result.length_dataset == length
+    assert result.score_mae is None
+    assert result.score_r2 is None
+    assert result.best_start_year is None
+    assert result.xgb_params is None
+
+
+def test_write_sales_forecast_stats_large(session):
+    eng = session.db_engine
+    code = 0
+    descr = "Test case to write stats"
+    length = 32
+    score_mae = 3.54
+    score_r2 = 0.56
+    best_start_year = 2020
+    xgb_params = BestParametersXGBRegressor(
+        n_estimators=2,
+        learning_rate=0.3,
+        max_depth=2,
+        min_child_weight=5,
+        gamma=0.5,
+        subsample=0.8,
+        colsample_bytree=5.25,
+        early_stopping_rounds=5,
+    )
+    stats = SalesForecastStatistics(
+        code,
+        descr,
+        length,
+        score_mae,
+        score_r2,
+        best_start_year,
+        xgb_params,
+    )
+    # execute
+    with patch("delta_barth.analysis.forecast.SESSION", session):
+        fc._write_sales_forecast_stats(stats)
+    # read
+    with eng.begin() as conn:
+        res_stats = conn.execute(sql.select(db.sf_stats))
+        res_xgb = conn.execute(sql.select(db.sf_XGB))
+    # reconstruct best XGB parameters
+    inserted_xgb = tuple(res_xgb.mappings())[0]
+    data_xgb = dict(**inserted_xgb)
+    del data_xgb["id"]
+    xgb_stats = BestParametersXGBRegressor(**data_xgb)
+    # reconstruct other statistics
+    inserted = tuple(res_stats.mappings())[0]
+    data_inserted = dict(**inserted)
+    stats_id_fk = data_inserted["id"]  # foreign key in XGB parameters
+    del data_inserted["id"]
+    stats = SalesForecastStatistics(**data_inserted, xgb_params=xgb_stats)
+    assert stats.status_code == code
+    assert stats.status_dscr == descr
+    assert stats.length_dataset == length
+    assert stats.score_mae == pytest.approx(score_mae)
+    assert stats.score_r2 == pytest.approx(score_r2)
+    assert stats.best_start_year == best_start_year
+    assert stats.xgb_params is not None
+    # compare xgb_stats
+    assert stats.xgb_params["forecast_id"] == stats_id_fk  # type: ignore
+    assert stats.xgb_params["n_estimators"] == 2
+    assert stats.xgb_params["learning_rate"] == pytest.approx(0.3)
+    assert stats.xgb_params["max_depth"] == 2
+    assert stats.xgb_params["min_child_weight"] == 5
+    assert stats.xgb_params["gamma"] == pytest.approx(0.5)
+    assert stats.xgb_params["subsample"] == pytest.approx(0.8)
+    assert stats.xgb_params["colsample_bytree"] == pytest.approx(5.25)
+    assert stats.xgb_params["early_stopping_rounds"] == 5
+
+
 def test_preprocess_sales_Success(
     exmpl_api_sales_prognosis_resp,
     feature_map,
@@ -319,16 +416,25 @@ def test_export_on_fail():
 
 
 @patch("delta_barth.analysis.forecast.SALES_BASE_NUM_DATAPOINTS_MONTHS", 1)
-def test_pipeline_sales_prognosis(exmpl_api_sales_prognosis_resp):
-    def mock_request(*args, **kwargs):  # pragma: no cover
-        return exmpl_api_sales_prognosis_resp, STATUS_HANDLER.SUCCESS
-
+def test_pipeline_sales_forecast_SuccessDbWrite(exmpl_api_sales_prognosis_resp, session):
     with patch(
         "delta_barth.analysis.forecast.get_sales_prognosis_data",
-        # new=mock_request,
     ) as mock:
         mock.return_value = exmpl_api_sales_prognosis_resp, STATUS_HANDLER.SUCCESS
-        result = fc.pipeline_sales(None)  # type: ignore
+        with patch("delta_barth.analysis.forecast.SESSION", session):
+            result = fc.pipeline_sales_forecast(None)  # type: ignore
+    print(result)
+    assert result.status == STATUS_HANDLER.SUCCESS
+    assert len(result.response.daten) > 0
+
+
+@patch("delta_barth.analysis.forecast.SALES_BASE_NUM_DATAPOINTS_MONTHS", 1)
+def test_pipeline_sales_forecast_FailDbWrite(exmpl_api_sales_prognosis_resp):
+    with patch(
+        "delta_barth.analysis.forecast.get_sales_prognosis_data",
+    ) as mock:
+        mock.return_value = exmpl_api_sales_prognosis_resp, STATUS_HANDLER.SUCCESS
+        result = fc.pipeline_sales_forecast(None)  # type: ignore
     print(result)
     assert result.status == STATUS_HANDLER.SUCCESS
     assert len(result.response.daten) > 0
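
Note: the tests above take a session fixture that this diff does not show; all they rely on is an object whose db_engine attribute points at a throwaway database containing the two metrics tables. A minimal sketch of such a fixture, assuming the table metadata is exposed as db.metadata and that a plain namespace object is an acceptable stand-in for the real session.

from types import SimpleNamespace

import pytest
import sqlalchemy as sql

from delta_barth import databases as db


@pytest.fixture(scope="function")
def session(tmp_path):
    # Fresh on-disk SQLite database per test, with the metrics tables created.
    engine = sql.create_engine(f"sqlite:///{tmp_path / 'metrics.db'}")
    db.metadata.create_all(engine)  # assumes the tables share one MetaData
    return SimpleNamespace(db_engine=engine)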

Database tests

@@ -1,10 +1,6 @@
-from dataclasses import asdict
-
-import pytest
 import sqlalchemy as sql
 
 from delta_barth import databases as db
-from delta_barth.types import BestParametersXGBRegressor, SalesForecastStatistics
 
 
 def test_get_engine(tmp_path):
@@ -13,107 +9,3 @@ def test_get_engine(tmp_path):
     assert isinstance(engine, sql.Engine)
     assert "sqlite" in str(engine.url)
     assert db_path.parent.name in str(engine.url)
-
-
-def test_write_sales_forecast_statistics_small(session):
-    eng = session.db_engine
-    code = 0
-    descr = "Test case to write stats"
-    length = 32
-    stats = SalesForecastStatistics(code, descr, length)
-    _ = stats.xgb_params
-
-    stats_to_write = asdict(stats)
-    _ = stats_to_write.pop("xgb_params")
-
-    with eng.begin() as conn:
-        res = conn.execute(sql.insert(db.sf_stats).values(stats_to_write))
-        _ = res.inserted_primary_key[0]
-
-    with eng.begin() as conn:
-        res = conn.execute(sql.select(db.sf_stats))
-
-    inserted = tuple(res.mappings())[0]
-    data = dict(**inserted)
-    del data["id"]
-    result = SalesForecastStatistics(**data)
-    assert result.status_code == code
-    assert result.status_dscr == descr
-    assert result.length_dataset == length
-    assert result.score_mae is None
-    assert result.score_r2 is None
-    assert result.best_start_year is None
-    assert result.xgb_params is None
-
-
-def test_write_sales_forecast_statistics_large(session):
-    eng = session.db_engine
-    code = 0
-    descr = "Test case to write stats"
-    length = 32
-    score_mae = 3.54
-    score_r2 = 0.56
-    best_start_year = 2020
-    xgb_params = BestParametersXGBRegressor(
-        n_estimators=2,
-        learning_rate=0.3,
-        max_depth=2,
-        min_child_weight=5,
-        gamma=0.5,
-        subsample=0.8,
-        colsample_bytree=5.25,
-        early_stopping_rounds=5,
-    )
-
-    stats = SalesForecastStatistics(
-        code,
-        descr,
-        length,
-        score_mae,
-        score_r2,
-        best_start_year,
-        xgb_params,
-    )
-    xgb_params = stats.xgb_params
-    assert xgb_params is not None
-
-    stats_to_write = asdict(stats)
-    _ = stats_to_write.pop("xgb_params")
-
-    with eng.begin() as conn:
-        res = conn.execute(sql.insert(db.sf_stats).values(stats_to_write))
-        sf_id = res.inserted_primary_key[0]
-        xgb_params["forecast_id"] = sf_id
-        res = conn.execute(sql.insert(db.sf_XGB).values(xgb_params))
-
-    with eng.begin() as conn:
-        res_stats = conn.execute(sql.select(db.sf_stats))
-        res_xgb = conn.execute(sql.select(db.sf_XGB))
-    # reconstruct best XGB parameters
-    inserted_xgb = tuple(res_xgb.mappings())[0]
-    data_xgb = dict(**inserted_xgb)
-    del data_xgb["id"]
-    xgb_stats = BestParametersXGBRegressor(**data_xgb)
-    # reconstruct other statistics
-    inserted = tuple(res_stats.mappings())[0]
-    data_inserted = dict(**inserted)
-    stats_id_fk = data_inserted["id"]  # foreign key in XGB parameters
-    del data_inserted["id"]
-    stats = SalesForecastStatistics(**data_inserted, xgb_params=xgb_stats)
-    assert stats.status_code == code
-    assert stats.status_dscr == descr
-    assert stats.length_dataset == length
-    assert stats.score_mae == pytest.approx(score_mae)
-    assert stats.score_r2 == pytest.approx(score_r2)
-    assert stats.best_start_year == best_start_year
-    assert stats.xgb_params is not None
-    # compare xgb_stats
-    assert stats.xgb_params["forecast_id"] == stats_id_fk  # type: ignore
-    assert stats.xgb_params["n_estimators"] == 2
-    assert stats.xgb_params["learning_rate"] == pytest.approx(0.3)
-    assert stats.xgb_params["max_depth"] == 2
-    assert stats.xgb_params["min_child_weight"] == 5
-    assert stats.xgb_params["gamma"] == pytest.approx(0.5)
-    assert stats.xgb_params["subsample"] == pytest.approx(0.8)
-    assert stats.xgb_params["colsample_bytree"] == pytest.approx(5.25)
-    assert stats.xgb_params["early_stopping_rounds"] == 5