diff --git a/src/dopt_sensor_anomalies/_find_paths.py b/src/dopt_sensor_anomalies/_find_paths.py new file mode 100644 index 0000000..7033ac3 --- /dev/null +++ b/src/dopt_sensor_anomalies/_find_paths.py @@ -0,0 +1,37 @@ +from pathlib import Path + +import dopt_basics.io + +from dopt_sensor_anomalies import types as t +from dopt_sensor_anomalies.constants import LIB_ROOT_PATH, MODEL_FOLDER_NAME, STOP_FOLDER_NAME + + +def get_model_folder() -> Path: + path_found = dopt_basics.io.search_folder_path( + starting_path=LIB_ROOT_PATH, stop_folder_name=STOP_FOLDER_NAME + ) + if path_found is None: + raise FileNotFoundError( + "The model folder was not found in the application's root directory." + ) + + return path_found / MODEL_FOLDER_NAME + + +def get_detection_models(model_folder: Path) -> t.DetectionModels: + left_model_search = tuple(model_folder.glob("*left_hand_side*.pth")) + if not left_model_search: + raise ValueError("No model for the left hand side found.") + if len(left_model_search) > 1: + raise ValueError("Too many models for the left hand side found.") + + right_model_search = tuple(model_folder.glob("*right_hand_side*.pth")) + if not right_model_search: + raise ValueError("No model for the right hand side found.") + elif len(right_model_search) > 1: + raise ValueError("Too many models for the right hand side found.") + + left_model = left_model_search[0] + right_model = right_model_search[0] + + return t.DetectionModels(left=left_model, right=right_model) diff --git a/src/dopt_sensor_anomalies/constants.py b/src/dopt_sensor_anomalies/constants.py index fe344eb..a525708 100644 --- a/src/dopt_sensor_anomalies/constants.py +++ b/src/dopt_sensor_anomalies/constants.py @@ -1,5 +1,11 @@ +from pathlib import Path from typing import Final +LIB_ROOT_PATH: Final[Path] = Path(__file__).parent +STOP_FOLDER_NAME: Final[str] = "python" +MODEL_FOLDER_NAME: Final[str] = "models" + +# TODO: remove comment THRESHOLD_BW: Final[int] = 63 # threshold to distringuish black (electrodes) and white areas # model_path = [ # r"C:\Users\demon\Documents\EKF\Modelle\patchcore_model_links.pth", @@ -8,3 +14,8 @@ THRESHOLD_BW: Final[int] = 63 # threshold to distringuish black (electrodes) an BACKBONE: Final[str] = "resnet18" # parameters for AI model LAYERS: Final[tuple[str, str]] = ("layer1", "layer2") RATIO: Final[float] = 0.05 + +NUM_VALID_ELECTRODES: Final[int] = 6 +# TODO: Remove? +# CONTOUR_EXPORT_FILENAME_SUFFIX: Final[str] = "_all_contours" +HEATMAP_FILENAME_SUFFIX: Final[str] = "_Heatmap" diff --git a/src/dopt_sensor_anomalies/detection.py b/src/dopt_sensor_anomalies/detection.py index 3f4ded9..90aea1f 100644 --- a/src/dopt_sensor_anomalies/detection.py +++ b/src/dopt_sensor_anomalies/detection.py @@ -1,80 +1,94 @@ import csv from os import path +from pathlib import Path +from typing import Any, Final, cast # Image.MAX_IMAGE_PIXELS = None import cv2 +import imutils import matplotlib.pyplot as plt -from anomalib.data import Folder +import numpy as np +import numpy.typing as npt +import torch from anomalib.engine import Engine -from anomalib.metrics import AUROC, F1Score from anomalib.models import Patchcore -from imutils import contours, grab_contours, is_cv2, perspective -from numpy import all, array, linalg -from numpy import max as npmax -from numpy import min as npmin -from numpy import sum as npsum +from imutils import contours, perspective from pandas import DataFrame from PIL import Image from scipy.spatial import distance as dist -from torch import as_tensor, cuda, device, float32, load, no_grad from torchvision.transforms.v2.functional import to_dtype, to_image +import dopt_sensor_anomalies._find_paths from dopt_sensor_anomalies import constants as const +from dopt_sensor_anomalies import errors +from dopt_sensor_anomalies import types as t # input parameters: user-defined -file_path = r"C:\Users\demon\Documents\EKF\Analyse_fuer_Florian\bild2.bmp" -pixelsPerMetricX = 0.251 -pixelsPerMetricY = 0.251 - - -# internal parameters - configuration -schwellwert = 63 # threshold to distinguish black (electrodes) and white areas -model_path = [ - r"C:\Users\demon\Documents\EKF\Modelle\patchcore_model_links.pth", - r"C:\Users\demon\Documents\EKF\Modelle\patchcore_model_rechts.pth", -] # path to anomaly detection models +file_path: Path = Path(r"C:\Users\demon\Documents\EKF\Analyse_fuer_Florian\bild2.bmp") +pixels_per_metric_X: float = 0.251 +pixels_per_metric_Y: float = 0.251 # measuring -def midpoint(ptA, ptB): - # ---------------------------- - # To identify the midpoint of a 2D area - # Input: - # ptA (numpy.ndarray of shape (2, )): tuple of coordinates x, y - # ptB (numpy.ndarray of shape (2, )): tuple of coordinates x, y - # Output (tuple (float, float)): - # tuple of midpoint coordinates - # ---------------------------- +def midpoint(ptA: npt.NDArray, ptB: npt.NDArray) -> tuple[float, float]: + """to identify the midpoint of a 2D area + Parameters + ---------- + ptA : npt.NDArray + tuple of coordinates x, y; shape (2, ) + ptB : npt.NDArray + tuple of coordinates x, y; shape (2, ) + + Returns + ------- + tuple[float, float] + tuple of midpoint coordinates + """ return ((ptA[0] + ptB[0]) * 0.5, (ptA[1] + ptB[1]) * 0.5) -def check_box_redundancy(box1, box2, tolerance=5): - # ---------------------------- - # To check if bounding box has already been identified and is just a redundant one - # Input: - # box1 (tuple(float, float), (float, float), float)): tuple of box values: ((center_x, center_y), (width, height), angle) - # box2 (tuple(float, float), (float, float), float)): tuple of box values: ((center_x, center_y), (width, height), angle) - # tolerance (float): distance threshold for width and height - # Output (Boole): - # redundancy evaluation - # ---------------------------- +def check_box_redundancy( + box1: t.Box, + box2: t.Box, + tolerance: float = 5.0, +) -> bool: + """to check if bounding box has already been identified and is just a redundant one + Parameters + ---------- + box1 : t.Box + tuple of box values: ((center_x, center_y), (width, height), angle) + box2 : t.Box + tuple of box values: ((center_x, center_y), (width, height), angle) + tolerance : float, optional + distance threshold for width and height, by default 5.0 + + Returns + ------- + bool + redundancy evaluation + """ # unpack the boxes - (c1, s1, a1) = box1 - (c2, s2, a2) = box2 - - # sort width and height such that (w, h) == (h, w) is treated the same (might have been recognized in different orders) + c1, s1, _ = box1 + c2, s2, _ = box2 + # sort width and height such that (w, h) == (h, w) is treated the same + # (might have been recognized in different orders) s1 = sorted(s1) s2 = sorted(s2) - center_dist = linalg.norm(array(c1) - array(c2)) - size_diff = linalg.norm(array(s1) - array(s2)) + center_dist = cast(float, np.linalg.norm(np.array(c1) - np.array(c2))) + size_diff = cast(float, np.linalg.norm(np.array(s1) - np.array(s2))) return center_dist < tolerance and size_diff < tolerance -def measure_length(file_path, pixelsPerMetricX, pixelsPerMetricY): +# ** main function +def measure_length( + file_path: Path, + pixels_per_metric_X: float, + pixels_per_metric_Y: float, +) -> tuple[list[str | int], t.SensorImages]: # ---------------------------- # To identify the midpoint of a 2D area # Input: @@ -86,53 +100,39 @@ def measure_length(file_path, pixelsPerMetricX, pixelsPerMetricY): # image of left sensor # image of right sensor # ---------------------------- - - file = path.basename(file_path) - # extract file name and ending separately - name, endung = path.splitext(file) - # extract folder path - folder_path = path.dirname(file_path) - - # for data output - data_csv = [] - - # read - image = cv2.imread(file_path) - - # check if image was read + file_stem = file_path.stem + folder_path = file_path.parent + data_csv: list[str | int] = [] + image = cv2.imread(str(file_path)) if image is None: - error = "error: no image read" - return + raise errors.ImageNotReadError(f"Image could not be read from: >{file_path}<") - # crop image cropped = image[500:1500, 100 : image.shape[1] - 100] - - # store original image for later output orig = cropped.copy() - height, width = cropped.shape[0], cropped.shape[1] + # TODO: check removal + # height, width = cropped.shape[0], cropped.shape[1] # change colours in the image to black and white gray = cv2.cvtColor(cropped, cv2.COLOR_BGR2GRAY) _, binary = cv2.threshold(gray, const.THRESHOLD_BW, 255, cv2.THRESH_BINARY) - # perform edge detection, identify rectangular shapes kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5)) closed = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel) edged = cv2.Canny(closed, 50, 100) - # find contours in the edge map cnts = cv2.findContours(edged.copy(), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) - cnts = grab_contours(cnts) + cnts = imutils.grab_contours(cnts) if cnts is None: - print(f"{file}: offenbar nichts gefunden") - return None, None + raise errors.ContourCalculationError( + "No contours were found in the provided image. Can not continue analysis." + ) # sort the contours from left to right (i.e., use x coordinates) cnts, _ = contours.sort_contours(cnts) + # TODO: remove??? # bounding_boxes = list(set([cv2.boundingRect(c) for c in cnts])) # cnts = [c for _, c in sorted(zip(bounding_boxes, cnts), key=lambda b: b[0][0])] - # min_area = 1000 # adjust as needed # filtered_cnts = [c for c in cnts if cv2.contourArea(c) > min_area] @@ -140,27 +140,32 @@ def measure_length(file_path, pixelsPerMetricX, pixelsPerMetricY): # get x coordinates of bounding boxes x_coords = [cv2.boundingRect(c)[0] for c in cnts] # check if x coordinates are sorted in increasing order - is_sorted = all(x1 <= x2 for x1, x2 in zip(x_coords, x_coords[1:])) + is_sorted = np.all(x1 <= x2 for x1, x2 in zip(x_coords, x_coords[1:])) # type: ignore if not is_sorted: - error = ( - "contour detection not valid: contours are not properly sorted from left to right" + raise errors.ContourCalculationError( + "Contour detection not valid: contours are not " + "properly sorted from left to right." ) - return None, None + ################################################################## + # TODO: Remove?? # ---------------------------------------- just for internal evaluation --------------------------------------- output_image = gray.copy() # ---------------------------------------- just for internal evaluation --------------------------------------- + ################################################################## # to store only electrodes contours and nothing redundant - accepted_boxes = [] - filtered_cnts = [] + accepted_boxes: list[t.Box] = [] + filtered_cnts: list[Any] = [] # loop over the contours individually for c in cnts: # compute the rotated bounding box of the contour - rbox = cv2.minAreaRect(c) - box = cv2.cv.BoxPoints(rbox) if is_cv2() else cv2.boxPoints(rbox) - box = array(box, dtype="int") + rbox = cast(t.Box, cv2.minAreaRect(c)) + # !! should only be newer OpenCV versions + # box = cv2.cv.BoxPoints(rbox) if is_cv2() else cv2.boxPoints(rbox) + box = cv2.boxPoints(rbox) + box = np.array(box, dtype="int") # order the points in the contour in top-left, top-right, bottom-right, and bottom-left box = perspective.order_points(box) @@ -193,8 +198,8 @@ def measure_length(file_path, pixelsPerMetricX, pixelsPerMetricY): filtered_cnts.append(c) # compute the size of the electrode object - dimA = dA / pixelsPerMetricY # y - dimB = dB / pixelsPerMetricX # x + dimA = dA / pixels_per_metric_Y # y + dimB = dB / pixels_per_metric_X # x data_csv.extend( [ @@ -204,6 +209,8 @@ def measure_length(file_path, pixelsPerMetricX, pixelsPerMetricY): ] ) + ################################################################## + # TODO: Remove?? # ---------------------------------------- just for internal evaluation --------------------------------------- count = 1 # loop over the original points and draw everything @@ -252,48 +259,58 @@ def measure_length(file_path, pixelsPerMetricX, pixelsPerMetricY): # cv2.imwrite(path.join(folder_path, f'{name}_contour_{count}.png'), output_image) count += 1 - cv2.imwrite(path.join(folder_path, f"{name}_all_contours.png"), output_image) + cv2.imwrite(str(folder_path / f"{file_stem}_all_contours.png"), output_image) # ---------------------------------------- just for internal evaluation --------------------------------------- + ################################################################## if not filtered_cnts: - error = "contour detection not valid: no contours recognized" - return None, None + raise errors.ContourCalculationError( + "Contour detection not valid: no contours recognized" + ) # if incorrect number of electrodes has been identified - if len(filtered_cnts) != 6: - print("falsche Anzahl an Elektroden identifiziert", len(filtered_cnts)) - data_csv = [-1] * 6 - return data_csv, None + num_contours = len(filtered_cnts) + if num_contours != const.NUM_VALID_ELECTRODES: + raise errors.InvalidElectrodeCount( + f"Number of counted electroedes does not match the " + f"expected value: count = {num_contours}, expected = {const.NUM_VALID_ELECTRODES}" + ) - else: - # identify left and right sensor areas - x_min = min(npmin(c[:, 0, 0]) for c in filtered_cnts) - 20 - x_max = max(npmax(c[:, 0, 0]) for c in filtered_cnts) + 20 - y_min = min(npmin(c[:, 0, 1]) for c in filtered_cnts) - 20 - y_max = max(npmax(c[:, 0, 1]) for c in filtered_cnts) + 20 + # identify left and right sensor areas + x_min = min(np.min(c[:, 0, 0]) for c in filtered_cnts) - 20 + x_max = max(np.max(c[:, 0, 0]) for c in filtered_cnts) + 20 + y_min = min(np.min(c[:, 0, 1]) for c in filtered_cnts) - 20 + y_max = max(np.max(c[:, 0, 1]) for c in filtered_cnts) + 20 - rightmost_x_third = max(filtered_cnts[2][:, 0, 0]) - leftmost_x_fourth = min(filtered_cnts[3][:, 0, 0]) - x_middle = rightmost_x_third + int((leftmost_x_fourth - rightmost_x_third) / 2.0) + rightmost_x_third = max(filtered_cnts[2][:, 0, 0]) + leftmost_x_fourth = min(filtered_cnts[3][:, 0, 0]) + x_middle = rightmost_x_third + int((leftmost_x_fourth - rightmost_x_third) / 2.0) - # perform further cropping and separation of left and right sensor - cropped_sensor_left = orig[y_min:y_max, x_min:x_middle] - cropped_sensor_right = orig[y_min:y_max, x_middle:x_max] + # perform further cropping and separation of left and right sensor + cropped_sensor_left = orig[y_min:y_max, x_min:x_middle] + cropped_sensor_right = orig[y_min:y_max, x_middle:x_max] - # ---------------------------------------- just for internal evaluation --------------------------------------- - # save the cropped images for left and right sensor - try: - cv2.imwrite(path.join(folder_path, f"{name}_left.png"), cropped_sensor_left) - cv2.imwrite(path.join(folder_path, f"{name}_right.png"), cropped_sensor_right) - except: - print("not possible") - # ---------------------------------------- just for internal evaluation --------------------------------------- + ################################################################## + # TODO: Remove?? + # ---------------------------------------- just for internal evaluation --------------------------------------- + # save the cropped images for left and right sensor + try: + cv2.imwrite(path.join(folder_path, f"{file_stem}_left.png"), cropped_sensor_left) + cv2.imwrite(path.join(folder_path, f"{file_stem}_right.png"), cropped_sensor_right) + except Exception as err: + print(f"not possible: Error: {err}") + # ---------------------------------------- just for internal evaluation --------------------------------------- + ################################################################## - return data_csv, (cropped_sensor_left, cropped_sensor_right) + return data_csv, t.SensorImages(left=cropped_sensor_left, right=cropped_sensor_right) +# helper function # anomaly detection -def infer_image(image, model): +def infer_image( + image: npt.NDArray, + model: Patchcore, +) -> t.InferenceResult: # ---------------------------- # To evaluate the image # Input: @@ -306,20 +323,24 @@ def infer_image(image, model): # anomaly_label (bool): anomaly detected (1) or not (0) # ---------------------------- - torch_device = device("cuda" if cuda.is_available() else "cpu") + torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(torch_device) image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # this is optional pil_image = Image.fromarray(image_rgb) - image = pil_image.convert("RGB") + pil_image = pil_image.convert("RGB") input_tensor = ( - to_dtype(to_image(image), float32, scale=True) if as_tensor else array(image) / 255.0 + to_dtype(to_image(pil_image), torch.float32, scale=True) + if torch.as_tensor # ?? Question: Wie passt diese Funktion hier rein? + # ?? Konvertiert, aber wird zur Evaluation der Aussage genutzt (sollte immer wahr sein?) + else np.array(pil_image) / 255.0 ) + # ?? Ist das immer ein Torch-Tensor? Falls nicht, müsste die Methode geändert werden input_tensor = input_tensor.unsqueeze(0) input_tensor = input_tensor.to(torch_device) model.eval() - with no_grad(): + with torch.no_grad(): output = model(input_tensor) anomaly_score = output.pred_score.item() @@ -327,13 +348,19 @@ def infer_image(image, model): anomaly_map = output.anomaly_map.squeeze().cpu().numpy() # resize heatmap to original image size - img_np = array(image) + img_np = np.array(pil_image) anomaly_map_resized = cv2.resize(anomaly_map, (img_np.shape[1], img_np.shape[0])) return img_np, anomaly_map_resized, anomaly_score, anomaly_label -def anomaly_detection(file_path, data_csv, sensor_images): +# ** main function +def anomaly_detection( + file_path: Path, + detection_models: t.DetectionModels, + data_csv: list[str | int], + sensor_images: t.SensorImages, +) -> None: # ---------------------------- # To load the model, call function for anomaly detection and store the results # Input: @@ -342,38 +369,37 @@ def anomaly_detection(file_path, data_csv, sensor_images): # Output: # none # ---------------------------- - file = path.basename(file_path) - # extract file name and ending separately - name, endung = path.splitext(file) - # extract folder path - folder_path = path.dirname(file_path) + file_stem = file_path.stem + folder_path = file_path.parent # reconstruct the model and initialize the engine model = Patchcore( backbone=const.BACKBONE, layers=const.LAYERS, coreset_sampling_ratio=const.RATIO ) + # ?? benötigt? Wird nicht genutzt engine = Engine() # preparation for plot - fig, axes = plt.subplots(1, 2, figsize=(12, 6)) + _, axes = plt.subplots(1, 2, figsize=(12, 6)) # loop over left and right sensor - for i, image in enumerate(sensor_images): - # load the model - checkpoint = load(model_path[i]) + for i, (side, image) in enumerate(sensor_images.items()): + # Ich habe die Modellpfade als Funktionsparameter hinzugefügt + image = cast(npt.NDArray, image) + checkpoint = torch.load(detection_models[side]) model.load_state_dict(checkpoint["model_state_dict"]) - # evaluate image - img_np, anomaly_map_resized, score, label = infer_image(image, model) + _, anomaly_map_resized, score, label = infer_image(image, model) + ################################################################## + # TODO: Remove?? # ---------------------------------------- just for internal evaluation --------------------------------------- print(score) # ---------------------------------------- just for internal evaluation --------------------------------------- + ################################################################## - # add result to data_csv data_csv.extend([int(label)]) - # store heatmap ax = axes[i] ax.axis("off") ax.imshow(image, alpha=0.8) @@ -381,14 +407,15 @@ def anomaly_detection(file_path, data_csv, sensor_images): plt.subplots_adjust(wspace=0, hspace=0) plt.savefig( - path.join(folder_path, f"{name}_Heatmap.png"), bbox_inches="tight", pad_inches=0 + (folder_path / f"{file_stem}{const.HEATMAP_FILENAME_SUFFIX}.png"), + bbox_inches="tight", + pad_inches=0, ) plt.close() - # save csv file df = DataFrame([data_csv]) df.to_csv( - path.join(folder_path, f"{name}.csv"), + (folder_path / f"{file_stem}.csv"), mode="w", index=False, header=False, @@ -396,9 +423,27 @@ def anomaly_detection(file_path, data_csv, sensor_images): sep=";", ) - return +def pipeline( + user_file_path: str, + pixels_per_metric_X: float, + pixels_per_metric_Y: float, +) -> None: + file_path = Path(user_file_path) + if not file_path.exists(): + raise FileNotFoundError("The provided path seems not to exist") -data_csv, sensors = measure_length(file_path, pixelsPerMetricX, pixelsPerMetricY) + MODEL_FOLDER: Final[Path] = dopt_sensor_anomalies._find_paths.get_model_folder() + DETECTION_MODELS: Final[t.DetectionModels] = ( + dopt_sensor_anomalies._find_paths.get_detection_models(MODEL_FOLDER) + ) -anomaly_detection(file_path, data_csv, sensors) + data_csv, sensor_images = measure_length( + file_path, pixels_per_metric_X, pixels_per_metric_Y + ) + anomaly_detection( + file_path=file_path, + detection_models=DETECTION_MODELS, + data_csv=data_csv, + sensor_images=sensor_images, + ) diff --git a/src/dopt_sensor_anomalies/errors.py b/src/dopt_sensor_anomalies/errors.py new file mode 100644 index 0000000..89b6b9b --- /dev/null +++ b/src/dopt_sensor_anomalies/errors.py @@ -0,0 +1,10 @@ +class ImageNotReadError(Exception): + """thrown if image was not read successfully""" + + +class ContourCalculationError(Exception): + """thrown if contour detection was not successful""" + + +class InvalidElectrodeCount(Exception): + """thrown if the number of electrodes does not match the expected value""" diff --git a/src/dopt_sensor_anomalies/types.py b/src/dopt_sensor_anomalies/types.py new file mode 100644 index 0000000..259f251 --- /dev/null +++ b/src/dopt_sensor_anomalies/types.py @@ -0,0 +1,18 @@ +import dataclasses as dc +from pathlib import Path +from typing import TypeAlias, TypedDict + +import numpy.typing as npt + +Box: TypeAlias = tuple[tuple[float, float], tuple[float, float], float] +InferenceResult: TypeAlias = tuple[npt.NDArray, npt.NDArray, float, bool] + + +class SensorImages(TypedDict): + left: npt.NDArray + right: npt.NDArray + + +class DetectionModels(TypedDict): + left: Path + right: Path diff --git a/tests/test_find_paths.py b/tests/test_find_paths.py new file mode 100644 index 0000000..696cd73 --- /dev/null +++ b/tests/test_find_paths.py @@ -0,0 +1,102 @@ +from pathlib import Path +from unittest.mock import patch + +import pytest + +from dopt_sensor_anomalies import _find_paths + + +@pytest.fixture(scope="module", autouse=True) +def setup_temp_dir(tmp_path_factory): + tmp_dir = tmp_path_factory.mktemp("root") + folder_structure = "lib/folder" + pth = tmp_dir / folder_structure + pth.mkdir(parents=True, exist_ok=True) + + with patch("dopt_sensor_anomalies._find_paths.LIB_ROOT_PATH", pth): + yield + + +@pytest.fixture() +def temp_model_folder_empty(tmp_path_factory) -> Path: + return tmp_path_factory.mktemp("empty") + + +@pytest.fixture() +def temp_model_folder_full(tmp_path_factory) -> Path: + folder = tmp_path_factory.mktemp("full") + left_hand_model = folder / "this_file_contains_the_left_hand_side_model.pth" + right_hand_model = folder / "this_file_contains_the_right_hand_side_model.pth" + left_hand_model.touch() + right_hand_model.touch() + return folder + + +@pytest.fixture() +def temp_model_folder_only_left(tmp_path_factory) -> Path: + folder = tmp_path_factory.mktemp("only_left") + left_hand_model = folder / "this_file_contains_the_left_hand_side_model.pth" + left_hand_model.touch() + return folder + + +@pytest.fixture() +def temp_model_folder_only_right(tmp_path_factory) -> Path: + folder = tmp_path_factory.mktemp("only_right") + right_hand_model = folder / "this_file_contains_the_right_hand_side_model.pth" + right_hand_model.touch() + return folder + + +@patch("dopt_sensor_anomalies._find_paths.STOP_FOLDER_NAME", "not-found") +def test_get_model_folder_Fail_NotFound(): + with pytest.raises(FileNotFoundError): + _ = _find_paths.get_model_folder() + + +@patch("dopt_sensor_anomalies._find_paths.STOP_FOLDER_NAME", "lib") +def test_get_model_folder_Success(): + ret = _find_paths.get_model_folder() + assert ret is not None + assert ret.name == _find_paths.MODEL_FOLDER_NAME + + +def test_get_detection_models_FailEmptyDir(temp_model_folder_empty): + with pytest.raises(ValueError): + _ = _find_paths.get_detection_models(temp_model_folder_empty) + + +def test_get_detection_models_FailOnlyLeft(temp_model_folder_only_left): + with pytest.raises(ValueError): + _ = _find_paths.get_detection_models(temp_model_folder_only_left) + + +def test_get_detection_models_FailOnlyRight(temp_model_folder_only_right): + with pytest.raises(ValueError): + _ = _find_paths.get_detection_models(temp_model_folder_only_right) + + +def test_get_detection_models_FailTooManyLeft(temp_model_folder_full): + right_hand_model = ( + temp_model_folder_full / "this_file_contains_the_left_hand_side_model2.pth" + ) + right_hand_model.touch() + + with pytest.raises(ValueError): + _ = _find_paths.get_detection_models(temp_model_folder_full) + + +def test_get_detection_models_FailTooManyRight(temp_model_folder_full): + right_hand_model = ( + temp_model_folder_full / "this_file_contains_the_right_hand_side_model2.pth" + ) + right_hand_model.touch() + + with pytest.raises(ValueError): + _ = _find_paths.get_detection_models(temp_model_folder_full) + + +def test_get_detection_models_Success(temp_model_folder_full): + models = _find_paths.get_detection_models(temp_model_folder_full) + assert "left_hand" in models["left"].name + assert "right_hand" in models["right"].name