major overhaul of forecast pipeline (#21)

includes several aspects:

- harden forecast logic with additional error checks
- fix wrong behaviour
- ensure minimum data viability
- extrapolate for multiple data points into the future

fix #19

Co-authored-by: frasu
Reviewed-on: #21
Co-authored-by: foefl <f.foerster@d-opt.com>
Co-committed-by: foefl <f.foerster@d-opt.com>
This commit was merged in pull request #21.
This commit is contained in:
2025-04-16 09:24:33 +00:00
committed by Florian Förster
parent 6caa087efd
commit 063531a08e
6 changed files with 110 additions and 31 deletions

View File

@@ -1,4 +1,6 @@
import datetime
from datetime import datetime as Datetime
from pathlib import Path
from unittest.mock import patch
import numpy as np
@@ -255,6 +257,7 @@ def test_preprocess_sales_FailOnTargetFeature(
assert pipe.results is None
@pytest.mark.forecast
def test_process_sales_Success(sales_data_real_preproc):
data = sales_data_real_preproc.copy()
pipe = PipeResult(data, STATUS_HANDLER.SUCCESS)
@@ -277,6 +280,7 @@ def test_process_sales_Success(sales_data_real_preproc):
assert pipe.statistics.xgb_params is not None
@pytest.mark.forecast
def test_process_sales_FailTooFewPoints(sales_data_real_preproc):
data = sales_data_real_preproc.copy()
data = data.iloc[:20, :]
@@ -303,6 +307,7 @@ def test_process_sales_FailTooFewPoints(sales_data_real_preproc):
assert pipe.statistics.xgb_params is None
@pytest.mark.forecast
def test_process_sales_FailTooFewMonthPoints(sales_data_real_preproc):
data = sales_data_real_preproc.copy()
pipe = PipeResult(data, STATUS_HANDLER.SUCCESS)
@@ -329,8 +334,19 @@ def test_process_sales_FailTooFewMonthPoints(sales_data_real_preproc):
assert pipe.statistics.xgb_params is None
@pytest.mark.forecast
def test_process_sales_FailNoReliableForecast(sales_data_real_preproc):
data = sales_data_real_preproc.copy()
# prepare fake data
df = sales_data_real_preproc.copy()
f_dates = "buchungs_datum"
end = datetime.datetime.now()
start = df[f_dates].max()
fake_dates = pd.date_range(start, end, freq="MS")
fake_data = [(1234, 1014, 1024, 1000, 10, date) for date in fake_dates]
fake_df = pd.DataFrame(fake_data, columns=df.columns)
enhanced_df = pd.concat((df, fake_df), ignore_index=True)
data = enhanced_df.copy()
data["betrag"] = 10000
print(data["betrag"])
data = data.iloc[:20000, :]
@@ -340,7 +356,7 @@ def test_process_sales_FailNoReliableForecast(sales_data_real_preproc):
def __init__(self, *args, **kwargs) -> None:
class Predictor:
def predict(self, *args, **kwargs):
return np.array([1, 1, 1, 1])
return np.array([1, 1, 1, 1], dtype=np.float64)
self.best_estimator_ = Predictor()
@@ -354,7 +370,7 @@ def test_process_sales_FailNoReliableForecast(sales_data_real_preproc):
pipe = fc._process_sales(
pipe,
min_num_data_points=1,
base_num_data_points_months=-100,
base_num_data_points_months=1,
)
assert pipe.status != STATUS_HANDLER.SUCCESS

View File

@@ -1,17 +1,15 @@
import importlib
import json
from unittest.mock import patch
import pytest
import sqlalchemy as sql
import delta_barth.pipelines
from delta_barth import databases as db
from delta_barth import pipelines as pl
from delta_barth.errors import STATUS_HANDLER
def test_write_performance_metrics(session):
def test_write_performance_metrics_Success(session):
pipe_name = "test_pipe"
t_start = 20_000_000_000
t_end = 30_000_000_000
@@ -33,6 +31,20 @@ def test_write_performance_metrics(session):
assert metrics.execution_duration == 10
def test_write_performance_metrics_FailStartingTime(session):
pipe_name = "test_pipe"
t_start = 30_000_000_000
t_end = 20_000_000_000
with patch("delta_barth.pipelines.SESSION", session):
with pytest.raises(ValueError):
_ = pl._write_performance_metrics(
pipeline_name=pipe_name,
time_start=t_start,
time_end=t_end,
)
@patch("delta_barth.analysis.forecast.SALES_BASE_NUM_DATAPOINTS_MONTHS", 1)
def test_sales_prognosis_pipeline(exmpl_api_sales_prognosis_resp, session):
with patch(

View File

@@ -64,6 +64,7 @@ def test_session_setup_db_management(tmp_path):
@patch("delta_barth.logging.ENABLE_LOGGING", True)
@patch("delta_barth.logging.LOGGING_TO_FILE", True)
@patch("delta_barth.logging.LOGGING_TO_STDERR", True)
def test_session_setup_logging(tmp_path):
str_path = str(tmp_path)
foldername: str = "logging_test"