generated from dopt-python/py311
add anomaly threshold as parameter, related to #20
This commit is contained in:
parent
c169d6d1cf
commit
7a6188c647
@ -30,11 +30,13 @@ def sensor_anomalies_detection(
|
||||
user_img_path: str,
|
||||
pixels_per_metric_X: float,
|
||||
pixels_per_metric_Y: float,
|
||||
anomaly_threshold: float,
|
||||
) -> int:
|
||||
res = detection.pipeline(
|
||||
user_img_path=user_img_path,
|
||||
pixels_per_metric_X=pixels_per_metric_X,
|
||||
pixels_per_metric_Y=pixels_per_metric_Y,
|
||||
anomaly_threshold=anomaly_threshold,
|
||||
)
|
||||
if res.status.code != 0:
|
||||
_print_error_state(res.status, out_stream=sys.stderr)
|
||||
|
||||
@ -9,7 +9,7 @@ THRESHOLD_BW: Final[int] = 63
|
||||
BACKBONE: Final[str] = "wide_resnet50_2"
|
||||
LAYERS: Final[tuple[str, ...]] = ("layer1", "layer2", "layer3")
|
||||
RATIO: Final[float] = 0.01
|
||||
ANOMALY_THRESHOLD: Final[float] = 0.14
|
||||
ANOMALY_THRESHOLD_DEFAULT: Final[float] = 0.14
|
||||
|
||||
NUM_VALID_ELECTRODES: Final[int] = 6
|
||||
HEATMAP_FILENAME_SUFFIX: Final[str] = "_Heatmap"
|
||||
|
||||
@ -182,6 +182,7 @@ def measure_length(
|
||||
def infer_image(
|
||||
image: npt.NDArray[np.uint8],
|
||||
model: Patchcore,
|
||||
anomaly_threshold: float,
|
||||
) -> t.InferenceResult:
|
||||
torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model.to(torch_device)
|
||||
@ -200,7 +201,7 @@ def infer_image(
|
||||
output = model(input_tensor)
|
||||
|
||||
anomaly_score = output.pred_score.item()
|
||||
anomaly_label = bool(1 if anomaly_score >= const.ANOMALY_THRESHOLD else 0)
|
||||
anomaly_label = bool(1 if anomaly_score >= anomaly_threshold else 0)
|
||||
anomaly_map = output.anomaly_map.squeeze().cpu().numpy()
|
||||
|
||||
img_np = np.array(pil_image)
|
||||
@ -219,6 +220,7 @@ def anomaly_detection(
|
||||
detection_models: t.DetectionModels,
|
||||
data_csv: t.CsvData,
|
||||
sensor_images: t.SensorImages,
|
||||
anomaly_threshold: float,
|
||||
) -> None:
|
||||
file_stem = img_path.stem
|
||||
folder_path = img_path.parent
|
||||
@ -233,7 +235,7 @@ def anomaly_detection(
|
||||
checkpoint = torch.load(detection_models[side])
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
|
||||
result = infer_image(image, model)
|
||||
result = infer_image(image, model, anomaly_threshold)
|
||||
data_csv.extend([int(result.anomaly_label)])
|
||||
|
||||
ax = axes[i]
|
||||
@ -265,6 +267,7 @@ def pipeline(
|
||||
user_img_path: str,
|
||||
pixels_per_metric_X: float,
|
||||
pixels_per_metric_Y: float,
|
||||
anomaly_threshold: float,
|
||||
) -> None:
|
||||
file_path = Path(user_img_path)
|
||||
if not file_path.exists():
|
||||
@ -283,4 +286,5 @@ def pipeline(
|
||||
detection_models=DETECTION_MODELS,
|
||||
data_csv=data_csv,
|
||||
sensor_images=sensor_images,
|
||||
anomaly_threshold=anomaly_threshold,
|
||||
)
|
||||
|
||||
@ -86,6 +86,7 @@ def measure_length(
|
||||
def infer_image(
|
||||
image: npt.NDArray[np.uint8],
|
||||
model: Patchcore,
|
||||
anomaly_threshold: float,
|
||||
) -> t.InferenceResult:
|
||||
"""evaluate one image
|
||||
|
||||
@ -113,6 +114,7 @@ def anomaly_detection(
|
||||
detection_models: t.DetectionModels,
|
||||
data_csv: t.CsvData,
|
||||
sensor_images: t.SensorImages,
|
||||
anomaly_threshold: float,
|
||||
) -> None:
|
||||
"""load the model, call function for anomaly detection and store the results
|
||||
|
||||
@ -134,6 +136,7 @@ def pipeline(
|
||||
user_img_path: str,
|
||||
pixels_per_metric_X: float,
|
||||
pixels_per_metric_Y: float,
|
||||
anomaly_threshold: float,
|
||||
) -> None:
|
||||
"""full pipeline defined by the agreed requirements
|
||||
wrapped as result pattern, handle errors on higher abstraction level
|
||||
|
||||
@ -104,6 +104,7 @@ def test_isolated_pipeline(results_folder, path_img_with_failure_TrainedModel):
|
||||
detection_models=DETECTION_MODELS,
|
||||
data_csv=data_csv,
|
||||
sensor_images=sensor_images,
|
||||
anomaly_threshold=constants.ANOMALY_THRESHOLD_DEFAULT,
|
||||
)
|
||||
# check files for existence
|
||||
root_img = path_img_with_failure_TrainedModel.parent
|
||||
@ -124,7 +125,12 @@ def test_full_pipeline_wrapped_FailImagePath(setup_temp_dir):
|
||||
pixels_per_metric_X: float = 0.251
|
||||
pixels_per_metric_Y: float = 0.251
|
||||
|
||||
ret = detect.pipeline(img_path, pixels_per_metric_X, pixels_per_metric_Y)
|
||||
ret = detect.pipeline(
|
||||
img_path,
|
||||
pixels_per_metric_X,
|
||||
pixels_per_metric_Y,
|
||||
constants.ANOMALY_THRESHOLD_DEFAULT,
|
||||
)
|
||||
assert ret.status != result_pattern.STATUS_HANDLER.SUCCESS
|
||||
assert ret.status.ExceptionType is FileNotFoundError
|
||||
assert ret.status.message == MESSAGE
|
||||
@ -140,7 +146,12 @@ def test_full_pipeline_wrapped_FailElectrodeCount(path_img_with_failure_Electrod
|
||||
pixels_per_metric_X: float = 0.251
|
||||
pixels_per_metric_Y: float = 0.251
|
||||
|
||||
ret = detect.pipeline(img_path, pixels_per_metric_X, pixels_per_metric_Y)
|
||||
ret = detect.pipeline(
|
||||
img_path,
|
||||
pixels_per_metric_X,
|
||||
pixels_per_metric_Y,
|
||||
constants.ANOMALY_THRESHOLD_DEFAULT,
|
||||
)
|
||||
assert ret.status != result_pattern.STATUS_HANDLER.SUCCESS
|
||||
assert ret.status.ExceptionType is errors.InvalidElectrodeCount
|
||||
assert MESSAGE in ret.status.message
|
||||
@ -164,7 +175,12 @@ def test_full_pipeline_wrapped_Success(results_folder, path_img_with_failure_Tra
|
||||
pixels_per_metric_X: float = 0.251
|
||||
pixels_per_metric_Y: float = 0.251
|
||||
|
||||
ret = detect.pipeline(img_path, pixels_per_metric_X, pixels_per_metric_Y)
|
||||
ret = detect.pipeline(
|
||||
img_path,
|
||||
pixels_per_metric_X,
|
||||
pixels_per_metric_Y,
|
||||
constants.ANOMALY_THRESHOLD_DEFAULT,
|
||||
)
|
||||
assert ret.status == result_pattern.STATUS_HANDLER.SUCCESS
|
||||
assert ret.status.code == 0
|
||||
assert ret.status.ExceptionType is None
|
||||
|
||||
@ -54,7 +54,10 @@ def test_sensor_anomalies_detection_FailImagePath(setup_temp_dir):
|
||||
|
||||
with patch("sys.stderr", new_callable=StringIO) as mock_err:
|
||||
ret = _interface.sensor_anomalies_detection(
|
||||
img_path, pixels_per_metric_X, pixels_per_metric_Y
|
||||
img_path,
|
||||
pixels_per_metric_X,
|
||||
pixels_per_metric_Y,
|
||||
constants.ANOMALY_THRESHOLD_DEFAULT,
|
||||
)
|
||||
captured = mock_err.getvalue()
|
||||
assert ret != 0
|
||||
@ -72,7 +75,10 @@ def test_sensor_anomalies_detection_FailElectrodeCount(path_img_with_failure_Ele
|
||||
|
||||
with patch("sys.stderr", new_callable=StringIO) as mock_err:
|
||||
ret = _interface.sensor_anomalies_detection(
|
||||
img_path, pixels_per_metric_X, pixels_per_metric_Y
|
||||
img_path,
|
||||
pixels_per_metric_X,
|
||||
pixels_per_metric_Y,
|
||||
constants.ANOMALY_THRESHOLD_DEFAULT,
|
||||
)
|
||||
captured = mock_err.getvalue()
|
||||
assert ret != 0
|
||||
@ -99,7 +105,10 @@ def test_sensor_anomalies_detection_Success(
|
||||
pixels_per_metric_Y: float = 0.251
|
||||
|
||||
ret = _interface.sensor_anomalies_detection(
|
||||
img_path, pixels_per_metric_X, pixels_per_metric_Y
|
||||
img_path,
|
||||
pixels_per_metric_X,
|
||||
pixels_per_metric_Y,
|
||||
constants.ANOMALY_THRESHOLD_DEFAULT,
|
||||
)
|
||||
|
||||
assert ret == 0
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user