generated from dopt-python/py311
add anomaly threshold as parameter, related to #20
This commit is contained in:
@@ -30,11 +30,13 @@ def sensor_anomalies_detection(
|
||||
user_img_path: str,
|
||||
pixels_per_metric_X: float,
|
||||
pixels_per_metric_Y: float,
|
||||
anomaly_threshold: float,
|
||||
) -> int:
|
||||
res = detection.pipeline(
|
||||
user_img_path=user_img_path,
|
||||
pixels_per_metric_X=pixels_per_metric_X,
|
||||
pixels_per_metric_Y=pixels_per_metric_Y,
|
||||
anomaly_threshold=anomaly_threshold,
|
||||
)
|
||||
if res.status.code != 0:
|
||||
_print_error_state(res.status, out_stream=sys.stderr)
|
||||
|
||||
@@ -9,7 +9,7 @@ THRESHOLD_BW: Final[int] = 63
|
||||
BACKBONE: Final[str] = "wide_resnet50_2"
|
||||
LAYERS: Final[tuple[str, ...]] = ("layer1", "layer2", "layer3")
|
||||
RATIO: Final[float] = 0.01
|
||||
ANOMALY_THRESHOLD: Final[float] = 0.14
|
||||
ANOMALY_THRESHOLD_DEFAULT: Final[float] = 0.14
|
||||
|
||||
NUM_VALID_ELECTRODES: Final[int] = 6
|
||||
HEATMAP_FILENAME_SUFFIX: Final[str] = "_Heatmap"
|
||||
|
||||
@@ -182,6 +182,7 @@ def measure_length(
|
||||
def infer_image(
|
||||
image: npt.NDArray[np.uint8],
|
||||
model: Patchcore,
|
||||
anomaly_threshold: float,
|
||||
) -> t.InferenceResult:
|
||||
torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model.to(torch_device)
|
||||
@@ -200,7 +201,7 @@ def infer_image(
|
||||
output = model(input_tensor)
|
||||
|
||||
anomaly_score = output.pred_score.item()
|
||||
anomaly_label = bool(1 if anomaly_score >= const.ANOMALY_THRESHOLD else 0)
|
||||
anomaly_label = bool(1 if anomaly_score >= anomaly_threshold else 0)
|
||||
anomaly_map = output.anomaly_map.squeeze().cpu().numpy()
|
||||
|
||||
img_np = np.array(pil_image)
|
||||
@@ -219,6 +220,7 @@ def anomaly_detection(
|
||||
detection_models: t.DetectionModels,
|
||||
data_csv: t.CsvData,
|
||||
sensor_images: t.SensorImages,
|
||||
anomaly_threshold: float,
|
||||
) -> None:
|
||||
file_stem = img_path.stem
|
||||
folder_path = img_path.parent
|
||||
@@ -233,7 +235,7 @@ def anomaly_detection(
|
||||
checkpoint = torch.load(detection_models[side])
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
|
||||
result = infer_image(image, model)
|
||||
result = infer_image(image, model, anomaly_threshold)
|
||||
data_csv.extend([int(result.anomaly_label)])
|
||||
|
||||
ax = axes[i]
|
||||
@@ -265,6 +267,7 @@ def pipeline(
|
||||
user_img_path: str,
|
||||
pixels_per_metric_X: float,
|
||||
pixels_per_metric_Y: float,
|
||||
anomaly_threshold: float,
|
||||
) -> None:
|
||||
file_path = Path(user_img_path)
|
||||
if not file_path.exists():
|
||||
@@ -283,4 +286,5 @@ def pipeline(
|
||||
detection_models=DETECTION_MODELS,
|
||||
data_csv=data_csv,
|
||||
sensor_images=sensor_images,
|
||||
anomaly_threshold=anomaly_threshold,
|
||||
)
|
||||
|
||||
@@ -86,6 +86,7 @@ def measure_length(
|
||||
def infer_image(
|
||||
image: npt.NDArray[np.uint8],
|
||||
model: Patchcore,
|
||||
anomaly_threshold: float,
|
||||
) -> t.InferenceResult:
|
||||
"""evaluate one image
|
||||
|
||||
@@ -113,6 +114,7 @@ def anomaly_detection(
|
||||
detection_models: t.DetectionModels,
|
||||
data_csv: t.CsvData,
|
||||
sensor_images: t.SensorImages,
|
||||
anomaly_threshold: float,
|
||||
) -> None:
|
||||
"""load the model, call function for anomaly detection and store the results
|
||||
|
||||
@@ -134,6 +136,7 @@ def pipeline(
|
||||
user_img_path: str,
|
||||
pixels_per_metric_X: float,
|
||||
pixels_per_metric_Y: float,
|
||||
anomaly_threshold: float,
|
||||
) -> None:
|
||||
"""full pipeline defined by the agreed requirements
|
||||
wrapped as result pattern, handle errors on higher abstraction level
|
||||
|
||||
Reference in New Issue
Block a user