From 763d3c1aac5dd84d2865b3fec4c45e018ce6b174 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20F=C3=B6rster?= Date: Wed, 5 Mar 2025 15:06:24 +0100 Subject: [PATCH] adapt forecast pipeline to new output format --- src/delta_barth/analysis/forecast.py | 6 +++--- tests/analysis/test_forecast.py | 6 ++---- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/delta_barth/analysis/forecast.py b/src/delta_barth/analysis/forecast.py index d964c4e..aefa4c8 100644 --- a/src/delta_barth/analysis/forecast.py +++ b/src/delta_barth/analysis/forecast.py @@ -8,7 +8,7 @@ import pandas as pd from sklearn.metrics import mean_squared_error from xgboost import XGBRegressor -from delta_barth._management import ERROR_HANDLER +from delta_barth._management import STATE_HANDLER from delta_barth.analysis import parse from delta_barth.constants import COL_MAP_SALES_PROGNOSIS, FEATURES_SALES_PROGNOSIS from delta_barth.types import CustomerDataSalesForecast, DataPipeStates, PipeResult @@ -105,7 +105,7 @@ def sales_per_customer( # check data availability if len(df_cust) < min_num_data_points: - return PipeResult(status=ERROR_HANDLER.pipe_states.TOO_FEW_POINTS, data=None) + return PipeResult(status=STATE_HANDLER.pipe_states.TOO_FEW_POINTS, data=None) else: # Entwicklung der Umsätze: definierte Zeiträume Monat df_cust["year"] = df_cust["date"].dt.year @@ -144,4 +144,4 @@ def sales_per_customer( test = test.reset_index(drop=True) # umsetzung, prognose - return PipeResult(status=ERROR_HANDLER.pipe_states.SUCCESS, data=test) + return PipeResult(status=STATE_HANDLER.pipe_states.SUCCESS, data=test) diff --git a/tests/analysis/test_forecast.py b/tests/analysis/test_forecast.py index 7fda954..5ee2112 100644 --- a/tests/analysis/test_forecast.py +++ b/tests/analysis/test_forecast.py @@ -1,5 +1,3 @@ -import pytest - from delta_barth.analysis import forecast as fc @@ -7,7 +5,7 @@ def test_sales_per_customer_success(sales_data): customer_id = 1133 res = fc.sales_per_customer(sales_data, customer_id) - assert res.status.status_code == 0 + assert res.status.code == 0 assert res.data is not None @@ -15,7 +13,7 @@ def test_sales_per_customer_too_few_data_points(sales_data): customer_id = 1000 res = fc.sales_per_customer(sales_data, customer_id) - assert res.status.status_code == 1 + assert res.status.code == 1 assert res.data is None