2025-10-23 09:49:22 +02:00

75 lines
2.2 KiB
Python

from pathlib import Path
import dopt_basics.io
from dopt_sensor_anomalies import types as t
from dopt_sensor_anomalies.constants import LIB_ROOT_PATH, MODEL_FOLDER_NAME, STOP_FOLDER_NAME
def get_model_folder() -> Path:
    """retrieves the folder which contains the trained and needed models for anomaly detection

    The application's root folder is looked up via
    ``dopt_basics.io.search_folder_path``, starting at ``LIB_ROOT_PATH`` with
    ``STOP_FOLDER_NAME`` as the stop folder name. The model folder is expected
    to be the directory ``MODEL_FOLDER_NAME`` directly inside that root folder.

    Returns
    -------
    Path
        the path of the directory containing the models

    Raises
    ------
    FileNotFoundError
        raised if the application's root folder is not found or the model
        folder is not present (or exists but is not a directory)
    """
    path_found = dopt_basics.io.search_folder_path(
        starting_path=LIB_ROOT_PATH, stop_folder_name=STOP_FOLDER_NAME
    )
    # the search helper signals "not found" by returning None
    if path_found is None:
        raise FileNotFoundError("The application's root directory could not be determined.")

    model_folder = path_found / MODEL_FOLDER_NAME
    # use is_dir() rather than exists(): a plain file with this name must not be
    # accepted, otherwise callers globbing inside it would silently find nothing
    if not model_folder.is_dir():
        raise FileNotFoundError(
            "The model folder was not found in the application's root directory."
        )
    return model_folder
def _find_single_model(model_folder: Path, side: str) -> Path:
    """return the single ``*<side>_hand_side*.pth`` model file inside ``model_folder``

    Parameters
    ----------
    model_folder : Path
        the folder to search (non-recursively) for model files
    side : str
        the side identifier used in the file name pattern, e.g. ``"left"`` or ``"right"``

    Returns
    -------
    Path
        path to the one matching model file

    Raises
    ------
    ValueError
        raised if no file or more than one file matches the pattern
    """
    candidates = tuple(model_folder.glob(f"*{side}_hand_side*.pth"))
    if not candidates:
        raise ValueError(f"No model for the {side} hand side found.")
    if len(candidates) > 1:
        raise ValueError(f"Too many models for the {side} hand side found.")
    return candidates[0]


def get_detection_models(
    model_folder: Path,
) -> t.DetectionModels:
    """retrieve the model paths both for the left and the right side as a TypedDict

    Each side's model is located by matching ``*left_hand_side*.pth`` or
    ``*right_hand_side*.pth`` directly inside ``model_folder``. Exactly one file
    must match per side. The left side is validated first.

    Parameters
    ----------
    model_folder : Path
        the found path to the folder containing the models

    Returns
    -------
    t.DetectionModels
        TypedDict with keys "left" and "right" holding the corresponding paths to each model

    Raises
    ------
    ValueError
        raised if no model file or too many model files are found for one side
    """
    left_model = _find_single_model(model_folder, "left")
    right_model = _find_single_model(model_folder, "right")
    return t.DetectionModels(left=left_model, right=right_model)