diff --git a/pdm.lock b/pdm.lock index ef7827f..84ab23c 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev", "lint", "nb", "tests"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:d51351adbafc599b97f8b3c9047ad0c7b8607d47cff5874121f546af04793ee2" +content_hash = "sha256:4931e32f8c146a72ad5b0a13c02485ea5ddc727de32fbe7c5e9314bbab05966c" [[metadata.targets]] requires_python = ">=3.11" @@ -648,6 +648,51 @@ files = [ {file = "fqdn-1.5.1.tar.gz", hash = "sha256:105ed3677e767fb5ca086a0c1f4bb66ebc3c100be518f0e0d755d9eae164d89f"}, ] +[[package]] +name = "greenlet" +version = "3.1.1" +requires_python = ">=3.7" +summary = "Lightweight in-process concurrent programming" +groups = ["default"] +marker = "(platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\") and python_version < \"3.14\"" +files = [ + {file = "greenlet-3.1.1-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:e4d333e558953648ca09d64f13e6d8f0523fa705f51cae3f03b5983489958c70"}, + {file = "greenlet-3.1.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09fc016b73c94e98e29af67ab7b9a879c307c6731a2c9da0db5a7d9b7edd1159"}, + {file = "greenlet-3.1.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d5e975ca70269d66d17dd995dafc06f1b06e8cb1ec1e9ed54c1d1e4a7c4cf26e"}, + {file = "greenlet-3.1.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3b2813dc3de8c1ee3f924e4d4227999285fd335d1bcc0d2be6dc3f1f6a318ec1"}, + {file = "greenlet-3.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e347b3bfcf985a05e8c0b7d462ba6f15b1ee1c909e2dcad795e49e91b152c383"}, + {file = "greenlet-3.1.1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9e8f8c9cb53cdac7ba9793c276acd90168f416b9ce36799b9b885790f8ad6c0a"}, + {file = "greenlet-3.1.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:62ee94988d6b4722ce0028644418d93a52429e977d742ca2ccbe1c4f4a792511"}, + {file = "greenlet-3.1.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1776fd7f989fc6b8d8c8cb8da1f6b82c5814957264d1f6cf818d475ec2bf6395"}, + {file = "greenlet-3.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:48ca08c771c268a768087b408658e216133aecd835c0ded47ce955381105ba39"}, + {file = "greenlet-3.1.1-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:4afe7ea89de619adc868e087b4d2359282058479d7cfb94970adf4b55284574d"}, + {file = "greenlet-3.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f406b22b7c9a9b4f8aa9d2ab13d6ae0ac3e85c9a809bd590ad53fed2bf70dc79"}, + {file = "greenlet-3.1.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c3a701fe5a9695b238503ce5bbe8218e03c3bcccf7e204e455e7462d770268aa"}, + {file = "greenlet-3.1.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2846930c65b47d70b9d178e89c7e1a69c95c1f68ea5aa0a58646b7a96df12441"}, + {file = "greenlet-3.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:99cfaa2110534e2cf3ba31a7abcac9d328d1d9f1b95beede58294a60348fba36"}, + {file = "greenlet-3.1.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1443279c19fca463fc33e65ef2a935a5b09bb90f978beab37729e1c3c6c25fe9"}, + {file = "greenlet-3.1.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b7cede291382a78f7bb5f04a529cb18e068dd29e0fb27376074b6d0317bf4dd0"}, + {file = "greenlet-3.1.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:23f20bb60ae298d7d8656c6ec6db134bca379ecefadb0b19ce6f19d1f232a942"}, + {file = "greenlet-3.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:7124e16b4c55d417577c2077be379514321916d5790fa287c9ed6f23bd2ffd01"}, + {file = "greenlet-3.1.1-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:05175c27cb459dcfc05d026c4232f9de8913ed006d42713cb8a5137bd49375f1"}, + {file = "greenlet-3.1.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:935e943ec47c4afab8965954bf49bfa639c05d4ccf9ef6e924188f762145c0ff"}, + {file = "greenlet-3.1.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:667a9706c970cb552ede35aee17339a18e8f2a87a51fba2ed39ceeeb1004798a"}, + {file = "greenlet-3.1.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b8a678974d1f3aa55f6cc34dc480169d58f2e6d8958895d68845fa4ab566509e"}, + {file = "greenlet-3.1.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:efc0f674aa41b92da8c49e0346318c6075d734994c3c4e4430b1c3f853e498e4"}, + {file = "greenlet-3.1.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0153404a4bb921f0ff1abeb5ce8a5131da56b953eda6e14b88dc6bbc04d2049e"}, + {file = "greenlet-3.1.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:275f72decf9932639c1c6dd1013a1bc266438eb32710016a1c742df5da6e60a1"}, + {file = "greenlet-3.1.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:c4aab7f6381f38a4b42f269057aee279ab0fc7bf2e929e3d4abfae97b682a12c"}, + {file = "greenlet-3.1.1-cp313-cp313-win_amd64.whl", hash = "sha256:b42703b1cf69f2aa1df7d1030b9d77d3e584a70755674d60e710f0af570f3761"}, + {file = "greenlet-3.1.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1695e76146579f8c06c1509c7ce4dfe0706f49c6831a817ac04eebb2fd02011"}, + {file = "greenlet-3.1.1-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7876452af029456b3f3549b696bb36a06db7c90747740c5302f74a9e9fa14b13"}, + {file = "greenlet-3.1.1-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4ead44c85f8ab905852d3de8d86f6f8baf77109f9da589cb4fa142bd3b57b475"}, + {file = "greenlet-3.1.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8320f64b777d00dd7ccdade271eaf0cad6636343293a25074cc5566160e4de7b"}, + {file = "greenlet-3.1.1-cp313-cp313t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6510bf84a6b643dabba74d3049ead221257603a253d0a9873f55f6a59a65f822"}, + {file = "greenlet-3.1.1-cp313-cp313t-musllinux_1_1_aarch64.whl", hash = "sha256:04b013dc07c96f83134b1e99888e7a79979f1a247e2a9f59697fa14b5862ed01"}, + {file = "greenlet-3.1.1-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:411f015496fec93c1c8cd4e5238da364e1da7a124bcb293f085bf2860c32c6f6"}, + {file = "greenlet-3.1.1.tar.gz", hash = "sha256:4ce3ac6cdb6adf7946475d7ef31777c26d94bccc377e070a7986bd2d5c515467"}, +] + [[package]] name = "h11" version = "0.14.0" @@ -2273,6 +2318,46 @@ files = [ {file = "soupsieve-2.6.tar.gz", hash = "sha256:e2e68417777af359ec65daac1057404a3c8a5455bb8abc36f1a9866ab1a51abb"}, ] +[[package]] +name = "sqlalchemy" +version = "2.0.39" +requires_python = ">=3.7" +summary = "Database Abstraction Library" +groups = ["default"] +dependencies = [ + "greenlet!=0.4.17; (platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\") and python_version < \"3.14\"", + "importlib-metadata; python_version < \"3.8\"", + "typing-extensions>=4.6.0", +] +files = [ + {file = "sqlalchemy-2.0.39-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a28f9c238f1e143ff42ab3ba27990dfb964e5d413c0eb001b88794c5c4a528a9"}, + {file = "sqlalchemy-2.0.39-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:08cf721bbd4391a0e765fe0fe8816e81d9f43cece54fdb5ac465c56efafecb3d"}, + {file = "sqlalchemy-2.0.39-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7a8517b6d4005facdbd7eb4e8cf54797dbca100a7df459fdaff4c5123265c1cd"}, + {file = "sqlalchemy-2.0.39-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b2de1523d46e7016afc7e42db239bd41f2163316935de7c84d0e19af7e69538"}, + {file = "sqlalchemy-2.0.39-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:412c6c126369ddae171c13987b38df5122cb92015cba6f9ee1193b867f3f1530"}, + {file = "sqlalchemy-2.0.39-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:6b35e07f1d57b79b86a7de8ecdcefb78485dab9851b9638c2c793c50203b2ae8"}, + {file = "sqlalchemy-2.0.39-cp311-cp311-win32.whl", hash = "sha256:3eb14ba1a9d07c88669b7faf8f589be67871d6409305e73e036321d89f1d904e"}, + {file = "sqlalchemy-2.0.39-cp311-cp311-win_amd64.whl", hash = "sha256:78f1b79132a69fe8bd6b5d91ef433c8eb40688ba782b26f8c9f3d2d9ca23626f"}, + {file = "sqlalchemy-2.0.39-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c457a38351fb6234781d054260c60e531047e4d07beca1889b558ff73dc2014b"}, + {file = "sqlalchemy-2.0.39-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:018ee97c558b499b58935c5a152aeabf6d36b3d55d91656abeb6d93d663c0c4c"}, + {file = "sqlalchemy-2.0.39-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5493a8120d6fc185f60e7254fc056a6742f1db68c0f849cfc9ab46163c21df47"}, + {file = "sqlalchemy-2.0.39-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2cf5b5ddb69142511d5559c427ff00ec8c0919a1e6c09486e9c32636ea2b9dd"}, + {file = "sqlalchemy-2.0.39-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:9f03143f8f851dd8de6b0c10784363712058f38209e926723c80654c1b40327a"}, + {file = "sqlalchemy-2.0.39-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:06205eb98cb3dd52133ca6818bf5542397f1dd1b69f7ea28aa84413897380b06"}, + {file = "sqlalchemy-2.0.39-cp312-cp312-win32.whl", hash = "sha256:7f5243357e6da9a90c56282f64b50d29cba2ee1f745381174caacc50d501b109"}, + {file = "sqlalchemy-2.0.39-cp312-cp312-win_amd64.whl", hash = "sha256:2ed107331d188a286611cea9022de0afc437dd2d3c168e368169f27aa0f61338"}, + {file = "sqlalchemy-2.0.39-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:fe193d3ae297c423e0e567e240b4324d6b6c280a048e64c77a3ea6886cc2aa87"}, + {file = "sqlalchemy-2.0.39-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:79f4f502125a41b1b3b34449e747a6abfd52a709d539ea7769101696bdca6716"}, + {file = "sqlalchemy-2.0.39-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8a10ca7f8a1ea0fd5630f02feb055b0f5cdfcd07bb3715fc1b6f8cb72bf114e4"}, + {file = "sqlalchemy-2.0.39-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e6b0a1c7ed54a5361aaebb910c1fa864bae34273662bb4ff788a527eafd6e14d"}, + {file = "sqlalchemy-2.0.39-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:52607d0ebea43cf214e2ee84a6a76bc774176f97c5a774ce33277514875a718e"}, + {file = "sqlalchemy-2.0.39-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:c08a972cbac2a14810463aec3a47ff218bb00c1a607e6689b531a7c589c50723"}, + {file = "sqlalchemy-2.0.39-cp313-cp313-win32.whl", hash = "sha256:23c5aa33c01bd898f879db158537d7e7568b503b15aad60ea0c8da8109adf3e7"}, + {file = "sqlalchemy-2.0.39-cp313-cp313-win_amd64.whl", hash = "sha256:4dabd775fd66cf17f31f8625fc0e4cfc5765f7982f94dc09b9e5868182cb71c0"}, + {file = "sqlalchemy-2.0.39-py3-none-any.whl", hash = "sha256:a1c6b0a5e3e326a466d809b651c63f278b1256146a377a528b6938a279da334f"}, + {file = "sqlalchemy-2.0.39.tar.gz", hash = "sha256:5d2d1fe548def3267b4c70a8568f108d1fed7cbbeccb9cc166e05af2abc25c22"}, +] + [[package]] name = "stack-data" version = "0.6.3" diff --git a/pyproject.toml b/pyproject.toml index 633113b..fbc1e1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,11 +1,11 @@ [project] name = "delta-barth" -version = "0.4.1" +version = "0.5.0" description = "workflows and pipelines for the Python-based Plugin of Delta Barth's ERP system" authors = [ {name = "Florian Förster", email = "f.foerster@d-opt.com"}, ] -dependencies = ["scikit-learn>=1.6.1", "pandas>=2.2.3", "xgboost>=2.1.4", "joblib>=1.4.2", "typing-extensions>=4.12.2", "requests>=2.32.3", "pydantic>=2.10.6", "dopt-basics>=0.1.2"] +dependencies = ["scikit-learn>=1.6.1", "pandas>=2.2.3", "xgboost>=2.1.4", "joblib>=1.4.2", "typing-extensions>=4.12.2", "requests>=2.32.3", "pydantic>=2.10.6", "dopt-basics>=0.1.2", "SQLAlchemy>=2.0.39"] requires-python = ">=3.11" readme = "README.md" license = {text = "LicenseRef-Proprietary"} @@ -73,7 +73,7 @@ directory = "reports/coverage" [tool.bumpversion] -current_version = "0.4.1" +current_version = "0.5.0" parse = """(?x) (?P0|[1-9]\\d*)\\. (?P0|[1-9]\\d*)\\. diff --git a/src/delta_barth/analysis/forecast.py b/src/delta_barth/analysis/forecast.py index d494c15..cac4591 100644 --- a/src/delta_barth/analysis/forecast.py +++ b/src/delta_barth/analysis/forecast.py @@ -1,17 +1,21 @@ from __future__ import annotations import datetime +import math from collections.abc import Mapping, Set +from dataclasses import asdict from datetime import datetime as Datetime -from typing import TYPE_CHECKING, Final, cast +from typing import TYPE_CHECKING, Final, TypeAlias, cast import numpy as np import pandas as pd import scipy.stats +import sqlalchemy as sql from sklearn.metrics import mean_absolute_error, r2_score from sklearn.model_selection import KFold, RandomizedSearchCV from xgboost import XGBRegressor +from delta_barth import databases from delta_barth.analysis import parse from delta_barth.api.requests import ( SalesPrognosisResponse, @@ -28,18 +32,22 @@ from delta_barth.constants import ( SALES_MIN_NUM_DATAPOINTS, ) from delta_barth.errors import STATUS_HANDLER, wrap_result -from delta_barth.logging import logger_pipelines as logger +from delta_barth.logging import logger_db, logger_pipelines +from delta_barth.management import SESSION from delta_barth.types import ( BestParametersXGBRegressor, DualDict, ParamSearchXGBRegressor, PipeResult, + SalesForecastStatistics, ) if TYPE_CHECKING: - from delta_barth.api.common import Session + from delta_barth.session import Session from delta_barth.types import Status +ForecastPipe: TypeAlias = PipeResult[SalesPrognosisResultsExport, SalesForecastStatistics] + def _parse_api_resp_to_df( resp: SalesPrognosisResponse, @@ -73,6 +81,21 @@ def _parse_df_to_results( return SalesPrognosisResults(daten=tuple(df_formatted)) # type: ignore +def _write_sales_forecast_stats( + stats: SalesForecastStatistics, +) -> None: + stats_db = asdict(stats) + _ = stats_db.pop("xgb_params") + xgb_params = stats.xgb_params + + with SESSION.db_engine.begin() as conn: + res = conn.execute(sql.insert(databases.sf_stats).values(stats_db)) + sf_id = cast(int, res.inserted_primary_key[0]) # type: ignore + if xgb_params is not None: + xgb_params["forecast_id"] = sf_id + conn.execute(sql.insert(databases.sf_XGB).values(xgb_params)) + + @wrap_result() def _parse_api_resp_to_df_wrapped( resp: SalesPrognosisResponse, @@ -87,30 +110,18 @@ def _parse_df_to_results_wrapped( return _parse_df_to_results(data) -# ------------------------------------------------------------------------------ -# Input: -# DataFrame df mit Columns f_umsatz_fakt, firmen, art, v_warengrp -# kunde (muss enthalten sein in df['firmen']['firma_refid']) - -# Output: -# Integer umsetzung (Prognose möglich): 0 ja, 1 nein (zu wenig Daten verfügbar), -# 2 nein (Daten nicht für Prognose geeignet) -# DataFrame test: Jahr, Monat, Vorhersage -# ------------------------------------------------------------------------------- - - -# Prognose Umsatz je Firma - - -# TODO: check usage of separate exception and handle it in API function -# TODO set min number of data points as constant, not parameter +@wrap_result() +def _write_sales_forecast_stats_wrapped( + stats: SalesForecastStatistics, +) -> None: + return _write_sales_forecast_stats(stats) def _preprocess_sales( resp: SalesPrognosisResponse, feature_map: Mapping[str, str], target_features: Set[str], -) -> PipeResult[SalesPrognosisResultsExport]: +) -> ForecastPipe: """n = 1 Parameters @@ -127,7 +138,7 @@ def _preprocess_sales( PipeResult _description_ """ - pipe: PipeResult[SalesPrognosisResultsExport] = PipeResult(None, STATUS_HANDLER.SUCCESS) + pipe: ForecastPipe = PipeResult(None, STATUS_HANDLER.SUCCESS) res = _parse_api_resp_to_df_wrapped(resp) if res.status != STATUS_HANDLER.SUCCESS: @@ -149,10 +160,10 @@ def _preprocess_sales( def _process_sales( - pipe: PipeResult[SalesPrognosisResultsExport], + pipe: ForecastPipe, min_num_data_points: int, base_num_data_points_months: int, -) -> PipeResult[SalesPrognosisResultsExport]: +) -> ForecastPipe: """n = 1 Input-Data: fields: ["artikel_refid", "firma_refid", "betrag", "menge", "buchungs_datum"] @@ -182,9 +193,13 @@ def _process_sales( df_firma = data[(data["betrag"] > 0)] df_cust = df_firma.copy() df_cust = df_cust.sort_values(by=DATE_FEAT).reset_index() + len_ds = len(df_cust) - if len(df_cust) < min_num_data_points: - pipe.fail(STATUS_HANDLER.pipe_states.TOO_FEW_POINTS) + if len_ds < min_num_data_points: + status = STATUS_HANDLER.pipe_states.TOO_FEW_POINTS + pipe.fail(status) + stats = SalesForecastStatistics(status.code, status.description, len_ds) + pipe.stats(stats) return pipe df_cust["jahr"] = df_cust[DATE_FEAT].dt.year @@ -216,8 +231,8 @@ def _process_sales( } best_params: BestParametersXGBRegressor | None = None - best_score_mae: float = float("inf") - best_score_r2: float = float("inf") + best_score_mae: float | None = float("inf") + best_score_r2: float | None = None best_start_year: int | None = None too_few_month_points: bool = True forecast: pd.DataFrame | None = None @@ -252,7 +267,6 @@ def _process_sales( y_pred = rand.best_estimator_.predict(X_test) # type: ignore if len(np.unique(y_pred)) != 1: - # pp(y_pred) error = cast(float, mean_absolute_error(y_test, y_pred)) if error < best_score_mae: best_params = cast(BestParametersXGBRegressor, rand.best_params_) @@ -263,31 +277,44 @@ def _process_sales( forecast = test.copy() forecast.loc[:, "vorhersage"] = y_pred - # pp(best_params) - # pp(best_score_mae) - # pp(best_score_r2) - # pp(best_start_year) if forecast is not None: forecast = forecast.drop(SALES_FEAT, axis=1).reset_index(drop=True) - - # TODO log metrics + best_score_mae = best_score_mae if not math.isinf(best_score_mae) else None if too_few_month_points: - pipe.fail(STATUS_HANDLER.pipe_states.TOO_FEW_MONTH_POINTS) + status = STATUS_HANDLER.pipe_states.TOO_FEW_MONTH_POINTS + pipe.fail(status) + stats = SalesForecastStatistics(status.code, status.description, len_ds) + pipe.stats(stats) return pipe elif best_params is None: - pipe.fail(STATUS_HANDLER.pipe_states.NO_RELIABLE_FORECAST) + status = STATUS_HANDLER.pipe_states.NO_RELIABLE_FORECAST + pipe.fail(status) + stats = SalesForecastStatistics(status.code, status.description, len_ds) + pipe.stats(stats) return pipe assert forecast is not None, "forecast is None, but was attempted to be returned" - pipe.success(forecast, STATUS_HANDLER.SUCCESS) + status = STATUS_HANDLER.SUCCESS + pipe.success(forecast, status) + stats = SalesForecastStatistics( + status.code, + status.description, + len_ds, + score_mae=best_score_mae, + score_r2=best_score_r2, + best_start_year=best_start_year, + xgb_params=best_params, + ) + pipe.stats(stats) + return pipe def _postprocess_sales( - pipe: PipeResult[SalesPrognosisResultsExport], + pipe: ForecastPipe, feature_map: Mapping[str, str], -) -> PipeResult[SalesPrognosisResultsExport]: +) -> ForecastPipe: data = pipe.data assert data is not None, "processing not existing pipe result" # convert features back to original naming @@ -321,7 +348,7 @@ def _export_on_fail( return SalesPrognosisResultsExport(response=response, status=status) -def pipeline_sales( +def pipeline_sales_forecast( session: Session, company_id: int | None = None, start_date: Datetime | None = None, @@ -332,8 +359,8 @@ def pipeline_sales( start_date=start_date, ) if status != STATUS_HANDLER.SUCCESS: - logger.error( - "Error during sales prognosis data retrieval, Status: %s", + logger_pipelines.error( + "Error during sales forecast data retrieval, Status: %s", status, stack_info=True, ) @@ -345,8 +372,8 @@ def pipeline_sales( target_features=FEATURES_SALES_PROGNOSIS, ) if pipe.status != STATUS_HANDLER.SUCCESS: - logger.error( - "Error during sales prognosis preprocessing, Status: %s", + logger_pipelines.error( + "Error during sales forecast preprocessing, Status: %s", pipe.status, stack_info=True, ) @@ -357,9 +384,16 @@ def pipeline_sales( min_num_data_points=SALES_MIN_NUM_DATAPOINTS, base_num_data_points_months=SALES_BASE_NUM_DATAPOINTS_MONTHS, ) + if pipe.statistics is not None: + res = _write_sales_forecast_stats_wrapped(pipe.statistics) + if res.status != STATUS_HANDLER.SUCCESS: + logger_db.error( + "[DB] Error during write process of sales forecast statistics: %s", + res.status, + ) if pipe.status != STATUS_HANDLER.SUCCESS: - logger.error( - "Error during sales prognosis main processing, Status: %s", + logger_pipelines.error( + "Error during sales forecast main processing, Status: %s", pipe.status, stack_info=True, ) @@ -370,8 +404,8 @@ def pipeline_sales( feature_map=DualDict(), ) if pipe.status != STATUS_HANDLER.SUCCESS: - logger.error( - "Error during sales prognosis postprocessing, Status: %s", + logger_pipelines.error( + "Error during sales forecast postprocessing, Status: %s", pipe.status, stack_info=True, ) @@ -393,7 +427,7 @@ def pipeline_sales_dummy( data_pth = DUMMY_DATA_PATH / "exmp_sales_prognosis_output.pkl" assert data_pth.exists(), "sales forecast dummy data not existent" data = pd.read_pickle(data_pth) - pipe: PipeResult[SalesPrognosisResultsExport] = PipeResult(None, STATUS_HANDLER.SUCCESS) + pipe: ForecastPipe = PipeResult(None, STATUS_HANDLER.SUCCESS) res = _parse_df_to_results_wrapped(data) if res.status != STATUS_HANDLER.SUCCESS: diff --git a/src/delta_barth/api/common.py b/src/delta_barth/api/common.py index 688a35a..d542f12 100644 --- a/src/delta_barth/api/common.py +++ b/src/delta_barth/api/common.py @@ -1,236 +1,31 @@ from __future__ import annotations -from pathlib import Path -from typing import TYPE_CHECKING, Final +from typing import Final import requests from dopt_basics.io import combine_route from pydantic import BaseModel from requests import Response -import delta_barth.logging from delta_barth.errors import ( - STATUS_HANDLER, UnspecifiedRequestType, ) -from delta_barth.logging import logger_session as logger from delta_barth.types import ( ApiCredentials, - DelBarApiError, HttpRequestTypes, ) -if TYPE_CHECKING: - from delta_barth.types import HttpContentHeaders, Status + +# ** login +class LoginRequest(BaseModel): + userName: str + password: str + databaseName: str + mandantName: str -class Session: - def __init__( - self, - base_headers: HttpContentHeaders, - logging_folder: str = "logs", - ) -> None: - self._data_path: Path | None = None - self._logging_dir: Path | None = None - self._logging_folder = logging_folder - self._creds: ApiCredentials | None = None - self._base_url: str | None = None - self._headers = base_headers - self._session_token: str | None = None - self._logged_in: bool = False - - def setup(self) -> None: - self.setup_logging() - - @property - def data_path(self) -> Path: - assert self._data_path is not None, "accessed data path not set" - return self._data_path - - @property - def logging_dir(self) -> Path: - if self._logging_dir is not None: - return self._logging_dir - - logging_dir = self.data_path / self._logging_folder - if not logging_dir.exists(): - logging_dir.mkdir(parents=False) - self._logging_dir = logging_dir - return self._logging_dir - - def setup_logging(self) -> None: - delta_barth.logging.setup_logging(self.logging_dir) - logger.info("[SESSION] Successfully setup logging") - - @property - def creds(self) -> ApiCredentials: - assert self._creds is not None, "accessed credentials not set" - return self._creds - - def set_data_path( - self, - path: str, - ): - self._data_path = validate_path(path) - - def set_credentials( - self, - username: str, - password: str, - database: str, - mandant: str, - ) -> None: - if self.logged_in: - self.logout() - self._creds = validate_credentials( - username=username, - password=password, - database=database, - mandant=mandant, - ) - - @property - def base_url(self) -> str: - assert self._base_url is not None, "accessed base URL not set" - return self._base_url - - def set_base_url( - self, - base_url: str, - ) -> None: - if self.logged_in: - self.logout() - self._base_url = base_url - - @property - def headers(self) -> HttpContentHeaders: - return self._headers - - @property - def session_token(self) -> str | None: - return self._session_token - - @property - def logged_in(self) -> bool: - return self._logged_in - - def _add_session_token( - self, - token: str, - ) -> None: - assert self.session_token is None, "tried overwriting existing API session token" - self._session_token = token - self._headers.update(DelecoToken=token) - self._logged_in = True - - def _remove_session_token(self) -> None: - assert self.session_token is not None, ( - "tried to delete non-existing API session token" - ) - if "DelecoToken" in self.headers: - del self._headers["DelecoToken"] - self._session_token = None - self._logged_in = False - - def login( - self, - ) -> tuple[LoginResponse, Status]: - ROUTE: Final[str] = "user/login" - URL: Final = combine_route(self.base_url, ROUTE) - - login_req = LoginRequest( - userName=self.creds.username, - password=self.creds.password, - databaseName=self.creds.database, - mandantName=self.creds.mandant, - ) - resp = requests.put( - URL, - login_req.model_dump_json(), - headers=self.headers, # type: ignore - ) - - response: LoginResponse - status: Status - if resp.status_code == 200: - response = LoginResponse(**resp.json()) - status = STATUS_HANDLER.pipe_states.SUCCESS - self._add_session_token(response.token) - else: - response = LoginResponse(token="") - err = DelBarApiError(status_code=resp.status_code, **resp.json()) - status = STATUS_HANDLER.api_error(err) - - return response, status - - def logout( - self, - ) -> tuple[None, Status]: - ROUTE: Final[str] = "user/logout" - URL: Final = combine_route(self.base_url, ROUTE) - - resp = requests.put( - URL, - headers=self.headers, # type: ignore - ) - - response = None - status: Status - if resp.status_code == 200: - status = STATUS_HANDLER.SUCCESS - self._remove_session_token() - else: - err = DelBarApiError(status_code=resp.status_code, **resp.json()) - status = STATUS_HANDLER.api_error(err) - - return response, status - - def assert_login( - self, - ) -> tuple[LoginResponse, Status]: - # check if login token is still valid - # re-login if necessary - if self.session_token is None: - return self.login() - - # use known endpoint which requires a valid token in its header - # evaluate the response to decide if: - # current token is still valid, token is not valid, other errors occurred - ROUTE: Final[str] = "verkauf/umsatzprognosedaten" - URL: Final = combine_route(self.base_url, ROUTE) - params: dict[str, int] = {"FirmaId": 999999} - resp = requests.get( - URL, - params=params, - headers=self.headers, # type: ignore - ) - - response: LoginResponse - status: Status - if resp.status_code == 200: - response = LoginResponse(token=self.session_token) - status = STATUS_HANDLER.SUCCESS - elif resp.status_code == 401: - self._remove_session_token() - response, status = self.login() - else: - response = LoginResponse(token="") - err = DelBarApiError(status_code=resp.status_code, **resp.json()) - status = STATUS_HANDLER.api_error(err) - - return response, status - - -def validate_path( - str_path: str, -) -> Path: - path = Path(str_path).resolve() - if not path.exists(): - raise FileNotFoundError(f"Provided path >{path}< seems not to exist.") - elif not path.is_dir(): - raise FileNotFoundError(f"Provided path >{path}< seems not to be a directory.") - - return path +class LoginResponse(BaseModel): + token: str def validate_credentials( @@ -265,15 +60,3 @@ def ping( raise UnspecifiedRequestType(f"Request type {method} not defined for endpoint") return resp - - -# ** login -class LoginRequest(BaseModel): - userName: str - password: str - databaseName: str - mandantName: str - - -class LoginResponse(BaseModel): - token: str diff --git a/src/delta_barth/api/requests.py b/src/delta_barth/api/requests.py index 246bc4f..18fdd4f 100644 --- a/src/delta_barth/api/requests.py +++ b/src/delta_barth/api/requests.py @@ -11,7 +11,7 @@ from delta_barth.errors import STATUS_HANDLER from delta_barth.types import DelBarApiError, ExportResponse, ResponseType, Status if TYPE_CHECKING: - from delta_barth.api.common import Session + from delta_barth.session import Session # ** sales data diff --git a/src/delta_barth/constants.py b/src/delta_barth/constants.py index 4a85a6f..b2a49cd 100644 --- a/src/delta_barth/constants.py +++ b/src/delta_barth/constants.py @@ -20,6 +20,8 @@ LOGGING_TO_FILE: Final[bool] = True LOGGING_TO_STDERR: Final[bool] = True LOG_FILENAME: Final[str] = "dopt-delbar.log" +# ** databases +DB_ECHO: Final[bool] = True # ** error handling DEFAULT_INTERNAL_ERR_CODE: Final[int] = 100 diff --git a/src/delta_barth/databases.py b/src/delta_barth/databases.py new file mode 100644 index 0000000..6756040 --- /dev/null +++ b/src/delta_barth/databases.py @@ -0,0 +1,60 @@ +from pathlib import Path + +import sqlalchemy as sql + +# ** meta +metadata = sql.MetaData() + + +def get_engine( + db_path: Path, + echo: bool = False, +) -> sql.Engine: + path = db_path.resolve() + connection_str: str = f"sqlite:///{str(path)}" + engine = sql.create_engine(connection_str, echo=echo) + return engine + + +# ** table declarations +# ** ---- common +perf_meas = sql.Table( + "performance_measurement", + metadata, + sql.Column("id", sql.Integer, primary_key=True), + sql.Column("execution_duration", sql.Float), + sql.Column("pipeline_name", sql.String(length=30)), +) +# ** ---- forecasts +sf_stats = sql.Table( + "sales_forecast_statistics", + metadata, + sql.Column("id", sql.Integer, primary_key=True), + sql.Column("status_code", sql.Integer), + sql.Column("status_dscr", sql.String(length=200)), + sql.Column("length_dataset", sql.Integer), + sql.Column("score_mae", sql.Float, nullable=True), + sql.Column("score_r2", sql.Float, nullable=True), + sql.Column("best_start_year", sql.Integer, nullable=True), +) +sf_XGB = sql.Table( + "sales_forecast_XGB_parameters", + metadata, + sql.Column("id", sql.Integer, primary_key=True), + sql.Column( + "forecast_id", + sql.Integer, + sql.ForeignKey( + "sales_forecast_statistics.id", onupdate="CASCADE", ondelete="CASCADE" + ), + unique=True, + ), + sql.Column("n_estimators", sql.Integer), + sql.Column("learning_rate", sql.Float), + sql.Column("max_depth", sql.Integer), + sql.Column("min_child_weight", sql.Integer), + sql.Column("gamma", sql.Float), + sql.Column("subsample", sql.Float), + sql.Column("colsample_bytree", sql.Float), + sql.Column("early_stopping_rounds", sql.Integer), +) diff --git a/src/delta_barth/logging.py b/src/delta_barth/logging.py index adeb308..908a364 100644 --- a/src/delta_barth/logging.py +++ b/src/delta_barth/logging.py @@ -34,8 +34,10 @@ logger_session = logging.getLogger("delta_barth.session") logger_session.setLevel(logging.DEBUG) logger_wrapped_results = logging.getLogger("delta_barth.wrapped_results") logger_wrapped_results.setLevel(logging.DEBUG) -logger_pipelines = logging.getLogger("delta_barth.logger_pipelines") +logger_pipelines = logging.getLogger("delta_barth.pipelines") logger_pipelines.setLevel(logging.DEBUG) +logger_db = logging.getLogger("delta_barth.databases") +logger_db.setLevel(logging.DEBUG) def setup_logging( diff --git a/src/delta_barth/management.py b/src/delta_barth/management.py index 77a1ffc..15badc4 100644 --- a/src/delta_barth/management.py +++ b/src/delta_barth/management.py @@ -5,16 +5,17 @@ from __future__ import annotations from typing import Final -from delta_barth.api.common import Session from delta_barth.constants import HTTP_BASE_CONTENT_HEADERS +from delta_barth.session import Session SESSION: Final[Session] = Session(HTTP_BASE_CONTENT_HEADERS) -def set_data_path( - path: str, +def setup( + data_path: str, ) -> None: # pragma: no cover - SESSION.set_data_path(path) + SESSION.set_data_path(data_path) + SESSION.setup() def set_credentials( diff --git a/src/delta_barth/pipelines.py b/src/delta_barth/pipelines.py index a4aadd5..d2c2802 100644 --- a/src/delta_barth/pipelines.py +++ b/src/delta_barth/pipelines.py @@ -11,7 +11,9 @@ def pipeline_sales_forecast( company_id: int | None, start_date: Datetime | None, ) -> JsonExportResponse: - result = forecast.pipeline_sales(SESSION, company_id=company_id, start_date=start_date) + result = forecast.pipeline_sales_forecast( + SESSION, company_id=company_id, start_date=start_date + ) export = JsonExportResponse(result.model_dump_json()) return export diff --git a/src/delta_barth/session.py b/src/delta_barth/session.py new file mode 100644 index 0000000..e539ce9 --- /dev/null +++ b/src/delta_barth/session.py @@ -0,0 +1,259 @@ +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING, Final + +import requests +import sqlalchemy as sql +from dopt_basics.io import combine_route + +import delta_barth.logging +from delta_barth import databases as db +from delta_barth.api.common import ( + LoginRequest, + LoginResponse, + validate_credentials, +) +from delta_barth.constants import DB_ECHO +from delta_barth.errors import STATUS_HANDLER +from delta_barth.logging import logger_session as logger +from delta_barth.types import DelBarApiError, Status + +if TYPE_CHECKING: + from delta_barth.types import ApiCredentials, HttpContentHeaders + + +def validate_path( + str_path: str, +) -> Path: + path = Path(str_path).resolve() + if not path.exists(): + raise FileNotFoundError(f"Provided path >{path}< seems not to exist.") + elif not path.is_dir(): + raise FileNotFoundError(f"Provided path >{path}< seems not to be a directory.") + + return path + + +class Session: + def __init__( + self, + base_headers: HttpContentHeaders, + db_folder: str = "data", + logging_folder: str = "logs", + ) -> None: + self._data_path: Path | None = None + self._db_path: Path | None = None + self._db_folder = db_folder + self._db_engine: sql.Engine | None = None + self._logging_dir: Path | None = None + self._logging_folder = logging_folder + self._creds: ApiCredentials | None = None + self._base_url: str | None = None + self._headers = base_headers + self._session_token: str | None = None + self._logged_in: bool = False + + def setup(self) -> None: + self._setup_db_management() + self._setup_logging() + + @property + def data_path(self) -> Path: + assert self._data_path is not None, "accessed data path not set" + return self._data_path + + @property + def db_engine(self) -> sql.Engine: + assert self._db_engine is not None, "accessed database engine not set" + return self._db_engine + + @property + def db_path(self) -> Path: + if self._db_path is not None: + return self._db_path + + db_root = (self.data_path / self._db_folder).resolve() + db_path = db_root / "dopt-data.db" + if not db_root.exists(): + db_root.mkdir(parents=False) + self._db_path = db_path + return self._db_path + + @property + def logging_dir(self) -> Path: + if self._logging_dir is not None: + return self._logging_dir + + logging_dir = self.data_path / self._logging_folder + if not logging_dir.exists(): + logging_dir.mkdir(parents=False) + self._logging_dir = logging_dir + return self._logging_dir + + def _setup_db_management(self) -> None: + self._db_engine = db.get_engine(self.db_path, echo=DB_ECHO) + db.metadata.create_all(self._db_engine) + logger.info("[SESSION] Successfully setup DB management") + + def _setup_logging(self) -> None: + delta_barth.logging.setup_logging(self.logging_dir) + logger.info("[SESSION] Successfully setup logging") + + @property + def creds(self) -> ApiCredentials: + assert self._creds is not None, "accessed credentials not set" + return self._creds + + def set_data_path( + self, + path: str, + ): + self._data_path = validate_path(path) + + def set_credentials( + self, + username: str, + password: str, + database: str, + mandant: str, + ) -> None: + if self.logged_in: + self.logout() + self._creds = validate_credentials( + username=username, + password=password, + database=database, + mandant=mandant, + ) + + @property + def base_url(self) -> str: + assert self._base_url is not None, "accessed base URL not set" + return self._base_url + + def set_base_url( + self, + base_url: str, + ) -> None: + if self.logged_in: + self.logout() + self._base_url = base_url + + @property + def headers(self) -> HttpContentHeaders: + return self._headers + + @property + def session_token(self) -> str | None: + return self._session_token + + @property + def logged_in(self) -> bool: + return self._logged_in + + def _add_session_token( + self, + token: str, + ) -> None: + assert self.session_token is None, "tried overwriting existing API session token" + self._session_token = token + self._headers.update(DelecoToken=token) + self._logged_in = True + + def _remove_session_token(self) -> None: + assert self.session_token is not None, ( + "tried to delete non-existing API session token" + ) + if "DelecoToken" in self.headers: + del self._headers["DelecoToken"] + self._session_token = None + self._logged_in = False + + def login( + self, + ) -> tuple[LoginResponse, Status]: + ROUTE: Final[str] = "user/login" + URL: Final = combine_route(self.base_url, ROUTE) + + login_req = LoginRequest( + userName=self.creds.username, + password=self.creds.password, + databaseName=self.creds.database, + mandantName=self.creds.mandant, + ) + resp = requests.put( + URL, + login_req.model_dump_json(), + headers=self.headers, # type: ignore + ) + + response: LoginResponse + status: Status + if resp.status_code == 200: + response = LoginResponse(**resp.json()) + status = STATUS_HANDLER.pipe_states.SUCCESS + self._add_session_token(response.token) + else: + response = LoginResponse(token="") + err = DelBarApiError(status_code=resp.status_code, **resp.json()) + status = STATUS_HANDLER.api_error(err) + + return response, status + + def logout( + self, + ) -> tuple[None, Status]: + ROUTE: Final[str] = "user/logout" + URL: Final = combine_route(self.base_url, ROUTE) + + resp = requests.put( + URL, + headers=self.headers, # type: ignore + ) + + response = None + status: Status + if resp.status_code == 200: + status = STATUS_HANDLER.SUCCESS + self._remove_session_token() + else: + err = DelBarApiError(status_code=resp.status_code, **resp.json()) + status = STATUS_HANDLER.api_error(err) + + return response, status + + def assert_login( + self, + ) -> tuple[LoginResponse, Status]: + # check if login token is still valid + # re-login if necessary + if self.session_token is None: + return self.login() + + # use known endpoint which requires a valid token in its header + # evaluate the response to decide if: + # current token is still valid, token is not valid, other errors occurred + ROUTE: Final[str] = "verkauf/umsatzprognosedaten" + URL: Final = combine_route(self.base_url, ROUTE) + params: dict[str, int] = {"FirmaId": 999999} + resp = requests.get( + URL, + params=params, + headers=self.headers, # type: ignore + ) + + response: LoginResponse + status: Status + if resp.status_code == 200: + response = LoginResponse(token=self.session_token) + status = STATUS_HANDLER.SUCCESS + elif resp.status_code == 401: + self._remove_session_token() + response, status = self.login() + else: + response = LoginResponse(token="") + err = DelBarApiError(status_code=resp.status_code, **resp.json()) + status = STATUS_HANDLER.api_error(err) + + return response, status diff --git a/src/delta_barth/types.py b/src/delta_barth/types.py index c871d2e..c65506c 100644 --- a/src/delta_barth/types.py +++ b/src/delta_barth/types.py @@ -1,6 +1,7 @@ from __future__ import annotations import enum +import pprint import typing as t from collections.abc import Sequence from dataclasses import dataclass, field @@ -14,6 +15,7 @@ __all__ = ["DualDict"] # ** Pipeline state management StatusDescription: t.TypeAlias = tuple[str, int, str] R = t.TypeVar("R", bound="ExportResponse") +S = t.TypeVar("S", bound="Statistics") class IError(t.Protocol): @@ -28,6 +30,10 @@ class Status(BaseModel): message: SkipValidation[str] = "" api_server_error: SkipValidation[DelBarApiError | None] = None + def __str__(self) -> str: + py_repr = self.model_dump() + return pprint.pformat(py_repr, indent=4, sort_dicts=False) + class ResponseType(BaseModel): pass @@ -47,10 +53,11 @@ class DataPipeStates: @dataclass(slots=True) -class PipeResult(t.Generic[R]): +class PipeResult(t.Generic[R, S]): data: pd.DataFrame | None status: Status results: R | None = None + statistics: S | None = None def success( self, @@ -77,6 +84,12 @@ class PipeResult(t.Generic[R]): self.status = response.status self.results = response + def stats( + self, + statistics: S, + ) -> None: + self.statistics = statistics + JsonExportResponse = t.NewType("JsonExportResponse", str) JsonResponse = t.NewType("JsonResponse", str) @@ -121,6 +134,11 @@ HttpContentHeaders = t.TypedDict( ) +# ** statistics +class Statistics: + pass + + # ** forecasts @dataclass(slots=True) class CustomerDataSalesForecast: @@ -140,7 +158,19 @@ class ParamSearchXGBRegressor(t.TypedDict): early_stopping_rounds: Sequence[int] +@dataclass(slots=True, eq=False) +class SalesForecastStatistics(Statistics): + status_code: int + status_dscr: str + length_dataset: int + score_mae: float | None = None + score_r2: float | None = None + best_start_year: int | None = None + xgb_params: BestParametersXGBRegressor | None = None + + class BestParametersXGBRegressor(t.TypedDict): + forecast_id: t.NotRequired[int] n_estimators: int learning_rate: float max_depth: int diff --git a/tests/analysis/test_forecast.py b/tests/analysis/test_forecast.py index da95e52..d467ca0 100644 --- a/tests/analysis/test_forecast.py +++ b/tests/analysis/test_forecast.py @@ -1,17 +1,22 @@ -import importlib from datetime import datetime as Datetime from unittest.mock import patch import numpy as np import pandas as pd import pytest +import sqlalchemy as sql from pydantic import ValidationError -import delta_barth.analysis.forecast +from delta_barth import databases as db from delta_barth.analysis import forecast as fc from delta_barth.api.requests import SalesPrognosisResponse, SalesPrognosisResponseEntry from delta_barth.errors import STATUS_HANDLER -from delta_barth.types import DualDict, PipeResult +from delta_barth.types import ( + BestParametersXGBRegressor, + DualDict, + PipeResult, + SalesForecastStatistics, +) @pytest.fixture(scope="function") @@ -125,6 +130,96 @@ def test_parse_df_to_results_InvalidData(invalid_results): _ = fc._parse_df_to_results(invalid_results) +def test_write_sales_forecast_stats_small(session): + eng = session.db_engine + code = 0 + descr = "Test case to write stats" + length = 32 + stats = SalesForecastStatistics(code, descr, length) + # execute + with patch("delta_barth.analysis.forecast.SESSION", session): + fc._write_sales_forecast_stats(stats) + # read + with eng.begin() as conn: + res = conn.execute(sql.select(db.sf_stats)) + + inserted = tuple(res.mappings())[0] + data = dict(**inserted) + del data["id"] + result = SalesForecastStatistics(**data) + assert result.status_code == code + assert result.status_dscr == descr + assert result.length_dataset == length + assert result.score_mae is None + assert result.score_r2 is None + assert result.best_start_year is None + assert result.xgb_params is None + + +def test_write_sales_forecast_stats_large(session): + eng = session.db_engine + code = 0 + descr = "Test case to write stats" + length = 32 + score_mae = 3.54 + score_r2 = 0.56 + best_start_year = 2020 + xgb_params = BestParametersXGBRegressor( + n_estimators=2, + learning_rate=0.3, + max_depth=2, + min_child_weight=5, + gamma=0.5, + subsample=0.8, + colsample_bytree=5.25, + early_stopping_rounds=5, + ) + stats = SalesForecastStatistics( + code, + descr, + length, + score_mae, + score_r2, + best_start_year, + xgb_params, + ) + # execute + with patch("delta_barth.analysis.forecast.SESSION", session): + fc._write_sales_forecast_stats(stats) + # read + with eng.begin() as conn: + res_stats = conn.execute(sql.select(db.sf_stats)) + res_xgb = conn.execute(sql.select(db.sf_XGB)) + # reconstruct best XGB parameters + inserted_xgb = tuple(res_xgb.mappings())[0] + data_xgb = dict(**inserted_xgb) + del data_xgb["id"] + xgb_stats = BestParametersXGBRegressor(**data_xgb) + # reconstruct other statistics + inserted = tuple(res_stats.mappings())[0] + data_inserted = dict(**inserted) + stats_id_fk = data_inserted["id"] # foreign key in XGB parameters + del data_inserted["id"] + stats = SalesForecastStatistics(**data_inserted, xgb_params=xgb_stats) + assert stats.status_code == code + assert stats.status_dscr == descr + assert stats.length_dataset == length + assert stats.score_mae == pytest.approx(score_mae) + assert stats.score_r2 == pytest.approx(score_r2) + assert stats.best_start_year == best_start_year + assert stats.xgb_params is not None + # compare xgb_stats + assert stats.xgb_params["forecast_id"] == stats_id_fk # type: ignore + assert stats.xgb_params["n_estimators"] == 2 + assert stats.xgb_params["learning_rate"] == pytest.approx(0.3) + assert stats.xgb_params["max_depth"] == 2 + assert stats.xgb_params["min_child_weight"] == 5 + assert stats.xgb_params["gamma"] == pytest.approx(0.5) + assert stats.xgb_params["subsample"] == pytest.approx(0.8) + assert stats.xgb_params["colsample_bytree"] == pytest.approx(5.25) + assert stats.xgb_params["early_stopping_rounds"] == 5 + + def test_preprocess_sales_Success( exmpl_api_sales_prognosis_resp, feature_map, @@ -172,6 +267,14 @@ def test_process_sales_Success(sales_data_real_preproc): assert pipe.status == STATUS_HANDLER.SUCCESS assert pipe.data is not None assert pipe.results is None + assert pipe.statistics is not None + assert pipe.statistics.status_code == STATUS_HANDLER.SUCCESS.code + assert pipe.statistics.status_dscr == STATUS_HANDLER.SUCCESS.description + assert pipe.statistics.length_dataset is not None + assert pipe.statistics.score_mae is not None + assert pipe.statistics.score_r2 is not None + assert pipe.statistics.best_start_year is not None + assert pipe.statistics.xgb_params is not None def test_process_sales_FailTooFewPoints(sales_data_real_preproc): @@ -188,6 +291,16 @@ def test_process_sales_FailTooFewPoints(sales_data_real_preproc): assert pipe.status == STATUS_HANDLER.pipe_states.TOO_FEW_POINTS assert pipe.data is None assert pipe.results is None + assert pipe.statistics is not None + assert pipe.statistics.status_code == STATUS_HANDLER.pipe_states.TOO_FEW_POINTS.code + assert ( + pipe.statistics.status_dscr == STATUS_HANDLER.pipe_states.TOO_FEW_POINTS.description + ) + assert pipe.statistics.length_dataset is not None + assert pipe.statistics.score_mae is None + assert pipe.statistics.score_r2 is None + assert pipe.statistics.best_start_year is None + assert pipe.statistics.xgb_params is None def test_process_sales_FailTooFewMonthPoints(sales_data_real_preproc): @@ -203,6 +316,17 @@ def test_process_sales_FailTooFewMonthPoints(sales_data_real_preproc): assert pipe.status == STATUS_HANDLER.pipe_states.TOO_FEW_MONTH_POINTS assert pipe.data is None assert pipe.results is None + assert pipe.statistics is not None + assert pipe.statistics.status_code == STATUS_HANDLER.pipe_states.TOO_FEW_MONTH_POINTS.code + assert ( + pipe.statistics.status_dscr + == STATUS_HANDLER.pipe_states.TOO_FEW_MONTH_POINTS.description + ) + assert pipe.statistics.length_dataset is not None + assert pipe.statistics.score_mae is None + assert pipe.statistics.score_r2 is None + assert pipe.statistics.best_start_year is None + assert pipe.statistics.xgb_params is None def test_process_sales_FailNoReliableForecast(sales_data_real_preproc): @@ -237,6 +361,17 @@ def test_process_sales_FailNoReliableForecast(sales_data_real_preproc): assert pipe.status == STATUS_HANDLER.pipe_states.NO_RELIABLE_FORECAST assert pipe.data is None assert pipe.results is None + assert pipe.statistics is not None + assert pipe.statistics.status_code == STATUS_HANDLER.pipe_states.NO_RELIABLE_FORECAST.code + assert ( + pipe.statistics.status_dscr + == STATUS_HANDLER.pipe_states.NO_RELIABLE_FORECAST.description + ) + assert pipe.statistics.length_dataset is not None + assert pipe.statistics.score_mae is None + assert pipe.statistics.score_r2 is None + assert pipe.statistics.best_start_year is None + assert pipe.statistics.xgb_params is None def test_postprocess_sales_Success( @@ -281,16 +416,25 @@ def test_export_on_fail(): @patch("delta_barth.analysis.forecast.SALES_BASE_NUM_DATAPOINTS_MONTHS", 1) -def test_pipeline_sales_prognosis(exmpl_api_sales_prognosis_resp): - def mock_request(*args, **kwargs): # pragma: no cover - return exmpl_api_sales_prognosis_resp, STATUS_HANDLER.SUCCESS - +def test_pipeline_sales_forecast_SuccessDbWrite(exmpl_api_sales_prognosis_resp, session): with patch( "delta_barth.analysis.forecast.get_sales_prognosis_data", - # new=mock_request, ) as mock: mock.return_value = exmpl_api_sales_prognosis_resp, STATUS_HANDLER.SUCCESS - result = fc.pipeline_sales(None) # type: ignore + with patch("delta_barth.analysis.forecast.SESSION", session): + result = fc.pipeline_sales_forecast(None) # type: ignore + print(result) + assert result.status == STATUS_HANDLER.SUCCESS + assert len(result.response.daten) > 0 + + +@patch("delta_barth.analysis.forecast.SALES_BASE_NUM_DATAPOINTS_MONTHS", 1) +def test_pipeline_sales_forecast_FailDbWrite(exmpl_api_sales_prognosis_resp): + with patch( + "delta_barth.analysis.forecast.get_sales_prognosis_data", + ) as mock: + mock.return_value = exmpl_api_sales_prognosis_resp, STATUS_HANDLER.SUCCESS + result = fc.pipeline_sales_forecast(None) # type: ignore print(result) assert result.status == STATUS_HANDLER.SUCCESS assert len(result.response.daten) > 0 diff --git a/tests/api/conftest.py b/tests/api/conftest.py deleted file mode 100644 index 494a6b4..0000000 --- a/tests/api/conftest.py +++ /dev/null @@ -1,32 +0,0 @@ -from unittest.mock import patch - -import pytest - -from delta_barth.api import common -from delta_barth.constants import HTTP_BASE_CONTENT_HEADERS - - -@pytest.fixture(scope="function") -def session(credentials, api_base_url) -> common.Session: - session = common.Session(HTTP_BASE_CONTENT_HEADERS) - session.set_base_url(api_base_url) - session.set_credentials( - username=credentials["user"], - password=credentials["pwd"], - database=credentials["db"], - mandant=credentials["mandant"], - ) - - return session - - -@pytest.fixture -def mock_put(): - with patch("requests.put") as mock: - yield mock - - -@pytest.fixture -def mock_get(): - with patch("requests.get") as mock: - yield mock diff --git a/tests/api/test_common.py b/tests/api/test_common.py index f48f3cb..7e166b1 100644 --- a/tests/api/test_common.py +++ b/tests/api/test_common.py @@ -1,72 +1,13 @@ -from pathlib import Path -from unittest.mock import patch - import pytest from pydantic import ValidationError from delta_barth.api import common -from delta_barth.constants import ( - DEFAULT_API_ERR_CODE, - HTTP_BASE_CONTENT_HEADERS, - LOG_FILENAME, -) from delta_barth.errors import ( UnspecifiedRequestType, ) from delta_barth.types import HttpRequestTypes -def test_validate_path_Success(): - str_pth = str(Path.cwd()) - path = common.validate_path(str_pth) - assert path.name == Path.cwd().name - - -def test_validate_path_FailNotExisting(): - str_pth = str(Path.cwd() / "test") - with pytest.raises(FileNotFoundError, match=r"seems not to exist"): - _ = common.validate_path(str_pth) - - -def test_validate_path_FailNoDirectory(tmp_path): - file = tmp_path / "test.txt" - file.write_text("test", encoding="utf-8") - - str_pth = str(file) - with pytest.raises(FileNotFoundError, match=r"seems not to be a directory"): - _ = common.validate_path(str_pth) - - -def test_session_set_DataPath(tmp_path): - str_path = str(tmp_path) - session = common.Session(HTTP_BASE_CONTENT_HEADERS) - - assert session._data_path is None - - session.set_data_path(str_path) - assert session._data_path is not None - assert isinstance(session.data_path, Path) - - -@patch("delta_barth.logging.ENABLE_LOGGING", True) -@patch("delta_barth.logging.LOGGING_TO_FILE", True) -def test_session_setup_logging(tmp_path): - str_path = str(tmp_path) - foldername: str = "logging_test" - target_log_dir = tmp_path / foldername - - session = common.Session(HTTP_BASE_CONTENT_HEADERS, logging_folder=foldername) - session.set_data_path(str_path) - log_dir = session.logging_dir - assert log_dir.exists() - assert log_dir == target_log_dir - # write file - target_file = target_log_dir / LOG_FILENAME - assert not target_file.exists() - session.setup() # calls setup code for logging - assert target_file.exists() - - def test_validate_creds(credentials): creds = common.validate_credentials( username=credentials["user"], @@ -110,204 +51,3 @@ def test_ping(api_base_url): with pytest.raises(UnspecifiedRequestType): resp = common.ping(api_base_url, HttpRequestTypes.POST) - - -def test_session_set_ApiInfo_LoggedOut(credentials, api_base_url): - session = common.Session(HTTP_BASE_CONTENT_HEADERS) - - assert session.session_token is None - assert session._creds is None - assert session._base_url is None - - session.set_base_url(api_base_url) - assert session._base_url is not None - session.set_credentials( - username=credentials["user"], - password=credentials["pwd"], - database=credentials["db"], - mandant=credentials["mandant"], - ) - assert session._creds is not None - - assert session.session_token is None - assert not session.logged_in - - -@pytest.mark.api_con_required -def test_session_set_ApiInfo_LoggedIn(credentials, api_base_url): - session = common.Session(HTTP_BASE_CONTENT_HEADERS) - # prepare login - assert session.session_token is None - assert session._creds is None - assert session._base_url is None - session.set_base_url(api_base_url) - session.set_credentials( - username=credentials["user"], - password=credentials["pwd"], - database=credentials["db"], - mandant=credentials["mandant"], - ) - session.login() - assert session._base_url is not None - assert session.logged_in - # reset base URL - session.set_base_url(api_base_url) - assert session._base_url is not None - assert not session.logged_in - assert session.session_token is None - # reset credentials - session.login() - assert session.logged_in - session.set_credentials( - username=credentials["user"], - password=credentials["pwd"], - database=credentials["db"], - mandant=credentials["mandant"], - ) - assert session._creds is not None - assert not session.logged_in - assert session.session_token is None - - -@pytest.mark.api_con_required -def test_login_logout_Success(session, credentials): - assert not session.logged_in - - resp, status = session.login() - assert resp is not None - assert status.code == 0 - assert session.session_token is not None - resp, status = session.logout() - assert resp is None - assert status.code == 0 - assert session.session_token is None - assert "DelecoToken" not in session.headers - - session.set_credentials( - username=credentials["user"], - password="WRONG_PASSWORD", - database=credentials["db"], - mandant=credentials["mandant"], - ) - resp, status = session.login() - assert resp is not None - assert status.code == DEFAULT_API_ERR_CODE - assert status.api_server_error is not None - assert status.api_server_error.status_code == 409 - assert status.api_server_error.message == "Nutzer oder Passwort falsch." - - -def test_login_logout_FailApiServer(session, mock_put): - code = 401 - json = { - "message": "GenericError", - "code": "TestLogin", - "hints": "TestCase", - } - - mock_put.return_value.status_code = code - mock_put.return_value.json.return_value = json - resp, status = session.login() - assert resp is not None - assert not resp.token - assert status.code == 400 - assert status.api_server_error is not None - assert status.api_server_error.status_code == code - assert status.api_server_error.message == json["message"] - assert status.api_server_error.code == json["code"] - assert status.api_server_error.hints == json["hints"] - resp, status = session.logout() - assert resp is None - assert status.code == 400 - assert status.api_server_error is not None - assert status.api_server_error.status_code == code - assert status.api_server_error.message == json["message"] - assert status.api_server_error.code == json["code"] - assert status.api_server_error.hints == json["hints"] - - -@pytest.mark.api_con_required -def test_assert_login_SuccessLoggedOut(session): - assert session.session_token is None - assert session._creds is not None - # test logged out state - resp, status = session.assert_login() - assert resp is not None - assert status.code == 0 - assert session.session_token is not None - resp, status = session.logout() - assert status.code == 0 - - -@pytest.mark.api_con_required -def test_assert_login_SuccessStillLoggedIn(session): - assert session.session_token is None - assert session._creds is not None - resp, status = session.login() - resp, status = session.assert_login() - assert resp is not None - assert status.code == 0 - assert session.session_token is not None - resp, status = session.logout() - assert status.code == 0 - - -@pytest.mark.api_con_required -def test_assert_login_ReloginNoValidAuth(session, mock_get): - code = 401 - json = { - "message": "AuthentificationError", - "code": "TestAssertLoginAfter", - "hints": "TestCase", - } - mock_get.return_value.status_code = code - mock_get.return_value.json.return_value = json - - resp, status = session.login() - - resp, status = session.assert_login() - assert resp is not None - assert status.code == 0 - assert session.session_token is not None - resp, status = session.logout() - assert status.code == 0 - - -@pytest.mark.api_con_required -def test_assert_login_ReloginWrongToken(session): - # triggers code 401 - assert session.session_token is None - assert session._creds is not None - _, status = session.login() - assert status.code == 0 - session._session_token = "WRONGTOKEN" - resp, status = session.assert_login() - assert resp is not None - assert status.code == 0 - assert session.session_token is not None - resp, status = session.logout() - assert status.code == 0 - - -@pytest.mark.api_con_required -def test_assert_login_FailApiServer(session, mock_get): - code = 500 - json = { - "message": "ServerError", - "code": "TestExternalServerError", - "hints": "TestCase", - } - mock_get.return_value.status_code = code - mock_get.return_value.json.return_value = json - - resp, status = session.login() - - resp, status = session.assert_login() - assert resp is not None - assert not resp.token - assert status.code == 400 - assert status.api_server_error is not None - assert status.api_server_error.status_code == code - assert status.api_server_error.message == json["message"] - assert status.api_server_error.code == json["code"] - assert status.api_server_error.hints == json["hints"] diff --git a/tests/conftest.py b/tests/conftest.py index 7b3358e..a1b4735 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,11 +4,14 @@ import json import tomllib from pathlib import Path from typing import Any, cast +from unittest.mock import patch import pandas as pd import pytest +import delta_barth.session from delta_barth.api.requests import SalesPrognosisResponse +from delta_barth.constants import HTTP_BASE_CONTENT_HEADERS @pytest.fixture(scope="session") @@ -30,40 +33,6 @@ def api_base_url(credentials) -> str: return credentials["base_url"] -# TODO: maybe include in main package depending if needed in future -# TODO check deletion -# def _cvt_str_float(value: str) -> float: -# import locale - -# locale.setlocale(locale.LC_NUMERIC, "de_DE.UTF-8") -# return locale.atof(value) - - -# def _cvt_str_ts(value: str) -> Any: -# date = value.split("_")[0] - -# return pd.to_datetime(date, format="%Y%m%d", errors="coerce") - - -# @pytest.fixture(scope="session") -# def sales_data_db_export() -> pd.DataFrame: -# pwd = Path.cwd() -# assert "barth" in pwd.parent.name.lower(), "not in project root directory" -# data_pth = pwd / "./tests/_test_data/swm_f_umsatz_fakt.csv" -# assert data_pth.exists(), "file to sales data not found" -# data = pd.read_csv(data_pth, sep="\t") -# data["betrag"] = data["betrag"].apply(_cvt_str_float) -# data["buchungs_datum"] = data["buchungs_datum"].apply(_cvt_str_ts) -# data = data.dropna( -# how="any", -# subset=["firma_refid", "beleg_typ", "buchungs_datum", "betrag"], -# ignore_index=True, -# ) -# data["buchungs_datum"] = pd.to_datetime(data["buchungs_datum"]) - -# return data - - @pytest.fixture(scope="session") def sales_data_real() -> pd.DataFrame: pwd = Path.cwd() @@ -101,3 +70,32 @@ def exmpl_api_sales_prognosis_output() -> pd.DataFrame: assert data_pth.exists(), "file to API sales data not found" return pd.read_pickle(data_pth) + + +# ** sessions +@pytest.fixture(scope="function") +def session(credentials, api_base_url, tmp_path) -> delta_barth.session.Session: + session = delta_barth.session.Session(HTTP_BASE_CONTENT_HEADERS) + session.set_data_path(str(tmp_path)) + session.set_base_url(api_base_url) + session.set_credentials( + username=credentials["user"], + password=credentials["pwd"], + database=credentials["db"], + mandant=credentials["mandant"], + ) + session.setup() + + return session + + +@pytest.fixture +def mock_put(): + with patch("requests.put") as mock: + yield mock + + +@pytest.fixture +def mock_get(): + with patch("requests.get") as mock: + yield mock diff --git a/tests/test_databases.py b/tests/test_databases.py new file mode 100644 index 0000000..2b10318 --- /dev/null +++ b/tests/test_databases.py @@ -0,0 +1,11 @@ +import sqlalchemy as sql + +from delta_barth import databases as db + + +def test_get_engine(tmp_path): + db_path = tmp_path / "test_db.db" + engine = db.get_engine(db_path) + assert isinstance(engine, sql.Engine) + assert "sqlite" in str(engine.url) + assert db_path.parent.name in str(engine.url) diff --git a/tests/test_session.py b/tests/test_session.py new file mode 100644 index 0000000..780f1f7 --- /dev/null +++ b/tests/test_session.py @@ -0,0 +1,281 @@ +from pathlib import Path +from unittest.mock import patch + +import pytest + +import delta_barth.session +from delta_barth.constants import ( + DEFAULT_API_ERR_CODE, + HTTP_BASE_CONTENT_HEADERS, + LOG_FILENAME, +) + + +def test_validate_path_Success(): + str_pth = str(Path.cwd()) + path = delta_barth.session.validate_path(str_pth) + assert path.name == Path.cwd().name + + +def test_validate_path_FailNotExisting(): + str_pth = str(Path.cwd() / "test") + with pytest.raises(FileNotFoundError, match=r"seems not to exist"): + _ = delta_barth.session.validate_path(str_pth) + + +def test_validate_path_FailNoDirectory(tmp_path): + file = tmp_path / "test.txt" + file.write_text("test", encoding="utf-8") + + str_pth = str(file) + with pytest.raises(FileNotFoundError, match=r"seems not to be a directory"): + _ = delta_barth.session.validate_path(str_pth) + + +def test_session_set_DataPath(tmp_path): + str_path = str(tmp_path) + session = delta_barth.session.Session(HTTP_BASE_CONTENT_HEADERS) + + assert session._data_path is None + + session.set_data_path(str_path) + assert session._data_path is not None + assert isinstance(session.data_path, Path) + + +def test_session_setup_db_management(tmp_path): + str_path = str(tmp_path) + foldername: str = "data_test" + target_db_dir = tmp_path / foldername + + session = delta_barth.session.Session(HTTP_BASE_CONTENT_HEADERS, db_folder=foldername) + session.set_data_path(str_path) + db_path = session.db_path + assert db_path.parent.exists() + assert db_path.parent == target_db_dir + assert not db_path.exists() + session.setup() + assert session._db_engine is not None + assert db_path.exists() + + +@patch("delta_barth.logging.ENABLE_LOGGING", True) +@patch("delta_barth.logging.LOGGING_TO_FILE", True) +def test_session_setup_logging(tmp_path): + str_path = str(tmp_path) + foldername: str = "logging_test" + target_log_dir = tmp_path / foldername + + session = delta_barth.session.Session( + HTTP_BASE_CONTENT_HEADERS, logging_folder=foldername + ) + session.set_data_path(str_path) + log_dir = session.logging_dir + assert log_dir.exists() + assert log_dir == target_log_dir + # write file + target_file = target_log_dir / LOG_FILENAME + assert not target_file.exists() + session.setup() # calls setup code for logging + assert target_file.exists() + + +def test_session_set_ApiInfo_LoggedOut(credentials, api_base_url): + session = delta_barth.session.Session(HTTP_BASE_CONTENT_HEADERS) + + assert session.session_token is None + assert session._creds is None + assert session._base_url is None + + session.set_base_url(api_base_url) + assert session._base_url is not None + session.set_credentials( + username=credentials["user"], + password=credentials["pwd"], + database=credentials["db"], + mandant=credentials["mandant"], + ) + assert session._creds is not None + + assert session.session_token is None + assert not session.logged_in + + +@pytest.mark.api_con_required +def test_session_set_ApiInfo_LoggedIn(credentials, api_base_url): + session = delta_barth.session.Session(HTTP_BASE_CONTENT_HEADERS) + # prepare login + assert session.session_token is None + assert session._creds is None + assert session._base_url is None + session.set_base_url(api_base_url) + session.set_credentials( + username=credentials["user"], + password=credentials["pwd"], + database=credentials["db"], + mandant=credentials["mandant"], + ) + session.login() + assert session._base_url is not None + assert session.logged_in + # reset base URL + session.set_base_url(api_base_url) + assert session._base_url is not None + assert not session.logged_in + assert session.session_token is None + # reset credentials + session.login() + assert session.logged_in + session.set_credentials( + username=credentials["user"], + password=credentials["pwd"], + database=credentials["db"], + mandant=credentials["mandant"], + ) + assert session._creds is not None + assert not session.logged_in + assert session.session_token is None + + +@pytest.mark.api_con_required +def test_login_logout_Success(session, credentials): + assert not session.logged_in + + resp, status = session.login() + assert resp is not None + assert status.code == 0 + assert session.session_token is not None + resp, status = session.logout() + assert resp is None + assert status.code == 0 + assert session.session_token is None + assert "DelecoToken" not in session.headers + + session.set_credentials( + username=credentials["user"], + password="WRONG_PASSWORD", + database=credentials["db"], + mandant=credentials["mandant"], + ) + resp, status = session.login() + assert resp is not None + assert status.code == DEFAULT_API_ERR_CODE + assert status.api_server_error is not None + assert status.api_server_error.status_code == 409 + assert status.api_server_error.message == "Nutzer oder Passwort falsch." + + +def test_login_logout_FailApiServer(session, mock_put): + code = 401 + json = { + "message": "GenericError", + "code": "TestLogin", + "hints": "TestCase", + } + + mock_put.return_value.status_code = code + mock_put.return_value.json.return_value = json + resp, status = session.login() + assert resp is not None + assert not resp.token + assert status.code == 400 + assert status.api_server_error is not None + assert status.api_server_error.status_code == code + assert status.api_server_error.message == json["message"] + assert status.api_server_error.code == json["code"] + assert status.api_server_error.hints == json["hints"] + resp, status = session.logout() + assert resp is None + assert status.code == 400 + assert status.api_server_error is not None + assert status.api_server_error.status_code == code + assert status.api_server_error.message == json["message"] + assert status.api_server_error.code == json["code"] + assert status.api_server_error.hints == json["hints"] + + +@pytest.mark.api_con_required +def test_assert_login_SuccessLoggedOut(session): + assert session.session_token is None + assert session._creds is not None + # test logged out state + resp, status = session.assert_login() + assert resp is not None + assert status.code == 0 + assert session.session_token is not None + resp, status = session.logout() + assert status.code == 0 + + +@pytest.mark.api_con_required +def test_assert_login_SuccessStillLoggedIn(session): + assert session.session_token is None + assert session._creds is not None + resp, status = session.login() + resp, status = session.assert_login() + assert resp is not None + assert status.code == 0 + assert session.session_token is not None + resp, status = session.logout() + assert status.code == 0 + + +@pytest.mark.api_con_required +def test_assert_login_ReloginNoValidAuth(session, mock_get): + code = 401 + json = { + "message": "AuthentificationError", + "code": "TestAssertLoginAfter", + "hints": "TestCase", + } + mock_get.return_value.status_code = code + mock_get.return_value.json.return_value = json + + resp, status = session.login() + + resp, status = session.assert_login() + assert resp is not None + assert status.code == 0 + assert session.session_token is not None + resp, status = session.logout() + assert status.code == 0 + + +@pytest.mark.api_con_required +def test_assert_login_ReloginWrongToken(session): + # triggers code 401 + assert session.session_token is None + assert session._creds is not None + _, status = session.login() + assert status.code == 0 + session._session_token = "WRONGTOKEN" + resp, status = session.assert_login() + assert resp is not None + assert status.code == 0 + assert session.session_token is not None + resp, status = session.logout() + assert status.code == 0 + + +@pytest.mark.api_con_required +def test_assert_login_FailApiServer(session, mock_get): + code = 500 + json = { + "message": "ServerError", + "code": "TestExternalServerError", + "hints": "TestCase", + } + mock_get.return_value.status_code = code + mock_get.return_value.json.return_value = json + + resp, status = session.login() + + resp, status = session.assert_login() + assert resp is not None + assert not resp.token + assert status.code == 400 + assert status.api_server_error is not None + assert status.api_server_error.status_code == code + assert status.api_server_error.message == json["message"] + assert status.api_server_error.code == json["code"] + assert status.api_server_error.hints == json["hints"]