diff --git a/src/dopt_sensor_anomalies/constants.py b/src/dopt_sensor_anomalies/constants.py index a8493da..d90be3a 100644 --- a/src/dopt_sensor_anomalies/constants.py +++ b/src/dopt_sensor_anomalies/constants.py @@ -9,6 +9,7 @@ THRESHOLD_BW: Final[int] = 63 BACKBONE: Final[str] = "'wide_resnet50_2" LAYERS: Final[tuple[str, str]] = ("layer1", "layer2", "layer3") RATIO: Final[float] = 0.01 +ANOMALY_THRESHOLD: Final[float] = 0.2 NUM_VALID_ELECTRODES: Final[int] = 6 HEATMAP_FILENAME_SUFFIX: Final[str] = "_Heatmap" diff --git a/src/dopt_sensor_anomalies/detection.py b/src/dopt_sensor_anomalies/detection.py index 5bd5612..6f7a8b2 100644 --- a/src/dopt_sensor_anomalies/detection.py +++ b/src/dopt_sensor_anomalies/detection.py @@ -176,7 +176,7 @@ def infer_image( output = model(input_tensor) anomaly_score = output.pred_score.item() - anomaly_label = bool(1 if anomaly_score >= 0.2 else 0) + anomaly_label = bool(1 if anomaly_score >= const.ANOMALY_THRESHOLD else 0) anomaly_map = output.anomaly_map.squeeze().cpu().numpy() img_np = np.array(pil_image)