From 7a6188c64700096a9dc3e46fcb572cc716afc9a8 Mon Sep 17 00:00:00 2001 From: foefl Date: Tue, 6 Jan 2026 08:28:16 +0100 Subject: [PATCH] add anomaly threshold as parameter, related to #20 --- src/dopt_sensor_anomalies/_interface.py | 2 ++ src/dopt_sensor_anomalies/constants.py | 2 +- src/dopt_sensor_anomalies/detection.py | 8 ++++++-- src/dopt_sensor_anomalies/detection.pyi | 3 +++ tests/test_detection.py | 22 +++++++++++++++++++--- tests/test_interface.py | 15 ++++++++++++--- 6 files changed, 43 insertions(+), 9 deletions(-) diff --git a/src/dopt_sensor_anomalies/_interface.py b/src/dopt_sensor_anomalies/_interface.py index 1195f17..8b6a0a9 100644 --- a/src/dopt_sensor_anomalies/_interface.py +++ b/src/dopt_sensor_anomalies/_interface.py @@ -30,11 +30,13 @@ def sensor_anomalies_detection( user_img_path: str, pixels_per_metric_X: float, pixels_per_metric_Y: float, + anomaly_threshold: float, ) -> int: res = detection.pipeline( user_img_path=user_img_path, pixels_per_metric_X=pixels_per_metric_X, pixels_per_metric_Y=pixels_per_metric_Y, + anomaly_threshold=anomaly_threshold, ) if res.status.code != 0: _print_error_state(res.status, out_stream=sys.stderr) diff --git a/src/dopt_sensor_anomalies/constants.py b/src/dopt_sensor_anomalies/constants.py index 1e7b6d7..88f8d0b 100644 --- a/src/dopt_sensor_anomalies/constants.py +++ b/src/dopt_sensor_anomalies/constants.py @@ -9,7 +9,7 @@ THRESHOLD_BW: Final[int] = 63 BACKBONE: Final[str] = "wide_resnet50_2" LAYERS: Final[tuple[str, ...]] = ("layer1", "layer2", "layer3") RATIO: Final[float] = 0.01 -ANOMALY_THRESHOLD: Final[float] = 0.14 +ANOMALY_THRESHOLD_DEFAULT: Final[float] = 0.14 NUM_VALID_ELECTRODES: Final[int] = 6 HEATMAP_FILENAME_SUFFIX: Final[str] = "_Heatmap" diff --git a/src/dopt_sensor_anomalies/detection.py b/src/dopt_sensor_anomalies/detection.py index 2415541..af97907 100644 --- a/src/dopt_sensor_anomalies/detection.py +++ b/src/dopt_sensor_anomalies/detection.py @@ -182,6 +182,7 @@ def measure_length( def infer_image( image: npt.NDArray[np.uint8], model: Patchcore, + anomaly_threshold: float, ) -> t.InferenceResult: torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(torch_device) @@ -200,7 +201,7 @@ def infer_image( output = model(input_tensor) anomaly_score = output.pred_score.item() - anomaly_label = bool(1 if anomaly_score >= const.ANOMALY_THRESHOLD else 0) + anomaly_label = bool(1 if anomaly_score >= anomaly_threshold else 0) anomaly_map = output.anomaly_map.squeeze().cpu().numpy() img_np = np.array(pil_image) @@ -219,6 +220,7 @@ def anomaly_detection( detection_models: t.DetectionModels, data_csv: t.CsvData, sensor_images: t.SensorImages, + anomaly_threshold: float, ) -> None: file_stem = img_path.stem folder_path = img_path.parent @@ -233,7 +235,7 @@ def anomaly_detection( checkpoint = torch.load(detection_models[side]) model.load_state_dict(checkpoint["model_state_dict"]) - result = infer_image(image, model) + result = infer_image(image, model, anomaly_threshold) data_csv.extend([int(result.anomaly_label)]) ax = axes[i] @@ -265,6 +267,7 @@ def pipeline( user_img_path: str, pixels_per_metric_X: float, pixels_per_metric_Y: float, + anomaly_threshold: float, ) -> None: file_path = Path(user_img_path) if not file_path.exists(): @@ -283,4 +286,5 @@ def pipeline( detection_models=DETECTION_MODELS, data_csv=data_csv, sensor_images=sensor_images, + anomaly_threshold=anomaly_threshold, ) diff --git a/src/dopt_sensor_anomalies/detection.pyi b/src/dopt_sensor_anomalies/detection.pyi index eaf06b3..65b8cb4 100644 --- a/src/dopt_sensor_anomalies/detection.pyi +++ b/src/dopt_sensor_anomalies/detection.pyi @@ -86,6 +86,7 @@ def measure_length( def infer_image( image: npt.NDArray[np.uint8], model: Patchcore, + anomaly_threshold: float, ) -> t.InferenceResult: """evaluate one image @@ -113,6 +114,7 @@ def anomaly_detection( detection_models: t.DetectionModels, data_csv: t.CsvData, sensor_images: t.SensorImages, + anomaly_threshold: float, ) -> None: """load the model, call function for anomaly detection and store the results @@ -134,6 +136,7 @@ def pipeline( user_img_path: str, pixels_per_metric_X: float, pixels_per_metric_Y: float, + anomaly_threshold: float, ) -> None: """full pipeline defined by the agreed requirements wrapped as result pattern, handle errors on higher abstraction level diff --git a/tests/test_detection.py b/tests/test_detection.py index 0bb00bc..273cf97 100644 --- a/tests/test_detection.py +++ b/tests/test_detection.py @@ -104,6 +104,7 @@ def test_isolated_pipeline(results_folder, path_img_with_failure_TrainedModel): detection_models=DETECTION_MODELS, data_csv=data_csv, sensor_images=sensor_images, + anomaly_threshold=constants.ANOMALY_THRESHOLD_DEFAULT, ) # check files for existence root_img = path_img_with_failure_TrainedModel.parent @@ -124,7 +125,12 @@ def test_full_pipeline_wrapped_FailImagePath(setup_temp_dir): pixels_per_metric_X: float = 0.251 pixels_per_metric_Y: float = 0.251 - ret = detect.pipeline(img_path, pixels_per_metric_X, pixels_per_metric_Y) + ret = detect.pipeline( + img_path, + pixels_per_metric_X, + pixels_per_metric_Y, + constants.ANOMALY_THRESHOLD_DEFAULT, + ) assert ret.status != result_pattern.STATUS_HANDLER.SUCCESS assert ret.status.ExceptionType is FileNotFoundError assert ret.status.message == MESSAGE @@ -140,7 +146,12 @@ def test_full_pipeline_wrapped_FailElectrodeCount(path_img_with_failure_Electrod pixels_per_metric_X: float = 0.251 pixels_per_metric_Y: float = 0.251 - ret = detect.pipeline(img_path, pixels_per_metric_X, pixels_per_metric_Y) + ret = detect.pipeline( + img_path, + pixels_per_metric_X, + pixels_per_metric_Y, + constants.ANOMALY_THRESHOLD_DEFAULT, + ) assert ret.status != result_pattern.STATUS_HANDLER.SUCCESS assert ret.status.ExceptionType is errors.InvalidElectrodeCount assert MESSAGE in ret.status.message @@ -164,7 +175,12 @@ def test_full_pipeline_wrapped_Success(results_folder, path_img_with_failure_Tra pixels_per_metric_X: float = 0.251 pixels_per_metric_Y: float = 0.251 - ret = detect.pipeline(img_path, pixels_per_metric_X, pixels_per_metric_Y) + ret = detect.pipeline( + img_path, + pixels_per_metric_X, + pixels_per_metric_Y, + constants.ANOMALY_THRESHOLD_DEFAULT, + ) assert ret.status == result_pattern.STATUS_HANDLER.SUCCESS assert ret.status.code == 0 assert ret.status.ExceptionType is None diff --git a/tests/test_interface.py b/tests/test_interface.py index 4a907a8..f55aa6e 100644 --- a/tests/test_interface.py +++ b/tests/test_interface.py @@ -54,7 +54,10 @@ def test_sensor_anomalies_detection_FailImagePath(setup_temp_dir): with patch("sys.stderr", new_callable=StringIO) as mock_err: ret = _interface.sensor_anomalies_detection( - img_path, pixels_per_metric_X, pixels_per_metric_Y + img_path, + pixels_per_metric_X, + pixels_per_metric_Y, + constants.ANOMALY_THRESHOLD_DEFAULT, ) captured = mock_err.getvalue() assert ret != 0 @@ -72,7 +75,10 @@ def test_sensor_anomalies_detection_FailElectrodeCount(path_img_with_failure_Ele with patch("sys.stderr", new_callable=StringIO) as mock_err: ret = _interface.sensor_anomalies_detection( - img_path, pixels_per_metric_X, pixels_per_metric_Y + img_path, + pixels_per_metric_X, + pixels_per_metric_Y, + constants.ANOMALY_THRESHOLD_DEFAULT, ) captured = mock_err.getvalue() assert ret != 0 @@ -99,7 +105,10 @@ def test_sensor_anomalies_detection_Success( pixels_per_metric_Y: float = 0.251 ret = _interface.sensor_anomalies_detection( - img_path, pixels_per_metric_X, pixels_per_metric_Y + img_path, + pixels_per_metric_X, + pixels_per_metric_Y, + constants.ANOMALY_THRESHOLD_DEFAULT, ) assert ret == 0