diff --git a/src/dopt_sensor_anomalies/detection.py b/src/dopt_sensor_anomalies/detection.py index cfb4ce3..5bd5612 100644 --- a/src/dopt_sensor_anomalies/detection.py +++ b/src/dopt_sensor_anomalies/detection.py @@ -176,7 +176,7 @@ def infer_image( output = model(input_tensor) anomaly_score = output.pred_score.item() - anomaly_label = 1 if anomaly_score >= .2 else 0 + anomaly_label = bool(1 if anomaly_score >= 0.2 else 0) anomaly_map = output.anomaly_map.squeeze().cpu().numpy() img_np = np.array(pil_image)