add basic functionality and corresponding tests

This commit is contained in:
Florian Förster 2025-03-13 16:36:18 +01:00
parent 48881e882c
commit 2db39b536e
15 changed files with 1026 additions and 2 deletions

1
.gitignore vendored
View File

@ -1,5 +1,6 @@
# own # own
*.code-workspace *.code-workspace
prototypes/
# Byte-compiled / optimized / DLL files # Byte-compiled / optimized / DLL files
__pycache__/ __pycache__/

13
pdm.lock generated
View File

@ -5,7 +5,7 @@
groups = ["default", "dev", "lint", "nb", "tests"] groups = ["default", "dev", "lint", "nb", "tests"]
strategy = ["inherit_metadata"] strategy = ["inherit_metadata"]
lock_version = "4.5.0" lock_version = "4.5.0"
content_hash = "sha256:7fbf0fb5e93b92622653d3030d79837827c74020ce1d12cbe64fe40be0c04c46" content_hash = "sha256:c0f37f44c762f301bbedf08036c9a24479a5b134b3616e6b5446cb3fb6ef8f14"
[[metadata.targets]] [[metadata.targets]]
requires_python = ">=3.11" requires_python = ">=3.11"
@ -2184,6 +2184,17 @@ files = [
{file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"},
] ]
[[package]]
name = "tzdata"
version = "2025.1"
requires_python = ">=2"
summary = "Provider of IANA time zone data"
groups = ["default"]
files = [
{file = "tzdata-2025.1-py2.py3-none-any.whl", hash = "sha256:7e127113816800496f027041c570f50bcd464a020098a3b6b199517772303639"},
{file = "tzdata-2025.1.tar.gz", hash = "sha256:24894909e88cdb28bd1636c6887801df64cb485bd593f2fd83ef29075a81d694"},
]
[[package]] [[package]]
name = "uri-template" name = "uri-template"
version = "1.3.0" version = "1.3.0"

View File

@ -5,7 +5,7 @@ description = "basic cross-project tools for Python-based d-opt projects"
authors = [ authors = [
{name = "Florian Förster", email = "f.foerster@d-opt.com"}, {name = "Florian Förster", email = "f.foerster@d-opt.com"},
] ]
dependencies = [] dependencies = ["tzdata>=2025.1"]
requires-python = ">=3.11" requires-python = ">=3.11"
readme = "README.md" readme = "README.md"
license = {text = "LicenseRef-Proprietary"} license = {text = "LicenseRef-Proprietary"}

View File

@ -0,0 +1,25 @@
from __future__ import annotations
import tomllib
from pathlib import Path
from typing import Any
def load_toml(
path_to_toml: str | Path,
print_success: bool = False,
) -> dict[str, Any]:
if isinstance(path_to_toml, str):
path_to_toml = Path(path_to_toml)
if not path_to_toml.exists():
raise FileNotFoundError(f"Config file seems not to exist under: >{path_to_toml}<")
path_to_toml = path_to_toml.with_suffix(".toml")
with open(path_to_toml, "rb") as f:
data = tomllib.load(f)
if print_success: # pragma: no cover
print("Loaded TOML config file successfully.", flush=True)
return data

View File

@ -0,0 +1,79 @@
from __future__ import annotations
from collections.abc import Iterator, MutableMapping
from typing import Any, TypeAlias, TypeVar
# recursive alias for arbitrarily nested lists/tuples/sets of leaf values
FlattableObject: TypeAlias = (
    list["FlattableObject | Any"]
    | tuple["FlattableObject | Any", ...]
    | set["FlattableObject | Any"]
)
# key/value type parameters for DualDict
K = TypeVar("K")
V = TypeVar("V")
def flatten(
    obj: FlattableObject,
) -> Iterator[Any]:
    """Lazily yield the leaves of an arbitrarily nested list/tuple/set.

    Parameters
    ----------
    obj : FlattableObject
        nested combination of lists, tuples and sets

    Yields
    ------
    Iterator[Any]
        the non-container elements in depth-first order
    """
    # explicit stack of iterators instead of recursion; yields in the
    # same depth-first order as a recursive walk
    pending = [iter(obj)]
    while pending:
        try:
            item = next(pending[-1])
        except StopIteration:
            pending.pop()
            continue
        # only lists, tuples and sets are descended into
        if isinstance(item, (list, tuple, set)):
            pending.append(iter(item))
        else:
            yield item
class DualDict(MutableMapping[K, V]):
    """Mutable mapping that also maintains the inverted value-to-key map.

    Values must be unique and hashable; inserting a duplicate value
    raises ValueError and leaves the mapping unchanged.
    """

    def __init__(self, **kwargs: V):
        self._store: dict[K, V] = dict(**kwargs)
        self._inverted = self._calc_inverted()

    def __str__(self) -> str:
        return str(self._store) + str(self._inverted)

    def __repr__(self) -> str:
        return self.__str__()

    @property
    def inverted(self) -> dict[V, K]:
        # value -> key mapping, kept in sync with every mutation
        return self._inverted

    def _calc_inverted(self, store: dict[K, V] | None = None) -> dict[V, K]:
        """Build the value->key map for *store* (defaults to the live store).

        Raises
        ------
        ValueError
            if two keys map to the same value
        """
        if store is None:
            store = self._store
        invert = {val: key for key, val in store.items()}
        if len(invert) != len(store):
            raise ValueError("DualDict does not support identical values")
        return invert

    def __setitem__(self, key: K, value: V) -> None:
        # bug fix: validate on a candidate copy first, so a duplicate
        # value does not leave the mapping partially updated (the old
        # code mutated _store before _calc_inverted could raise)
        candidate = {**self._store, key: value}
        self._inverted = self._calc_inverted(candidate)
        self._store = candidate

    def __getitem__(self, key: K) -> V:
        return self._store[key]

    def __delitem__(self, key: K) -> None:
        del self._store[key]
        self._inverted = self._calc_inverted()

    def __iter__(self) -> Iterator[K]:
        return iter(self._store)

    def __len__(self) -> int:
        return len(self._store)

    def update(self, **kwargs: V) -> None:
        # same transactional behavior as __setitem__: commit only after
        # the merged store validates
        candidate = {**self._store, **kwargs}
        self._inverted = self._calc_inverted(candidate)
        self._store = candidate

234
src/dopt_basics/datetime.py Normal file
View File

@ -0,0 +1,234 @@
from __future__ import annotations
import enum
import zoneinfo as tz_info
from datetime import datetime as Datetime
from datetime import timedelta as Timedelta
from datetime import timezone as Timezone
from datetime import tzinfo as TZInfo
from typing import Final
from dopt_basics.enums import enum_str_values_as_frzset
class TimeUnitsDatetime(enum.StrEnum):
    """components of a datetime object (singular, lowercase string values)"""

    YEAR = enum.auto()
    MONTH = enum.auto()
    DAY = enum.auto()
    HOUR = enum.auto()
    MINUTE = enum.auto()
    SECOND = enum.auto()
    MICROSECOND = enum.auto()
class TimeUnitsTimedelta(enum.StrEnum):
    """time units accepted by ``timedelta`` as keyword arguments
    (plural, lowercase string values)"""

    WEEKS = enum.auto()
    DAYS = enum.auto()
    HOURS = enum.auto()
    MINUTES = enum.auto()
    SECONDS = enum.auto()
    MILLISECONDS = enum.auto()
    MICROSECONDS = enum.auto()
TIMEZONE_CEST: Final[tz_info.ZoneInfo] = tz_info.ZoneInfo("Europe/Berlin")
TIMEZONE_UTC: Final[Timezone] = Timezone.utc
def get_timestamp(
    tz: TZInfo = TIMEZONE_UTC,
    with_time: bool = False,
) -> str:
    """Return the current date in *tz* as a string, optionally with time.

    Parameters
    ----------
    tz : TZInfo, optional
        time zone for the timestamp, by default TIMEZONE_UTC
    with_time : bool, optional
        append ``--HH-MM-SS`` to the date, by default False

    Returns
    -------
    str
        formatted timestamp
    """
    fmt = r"%Y-%m-%d--%H-%M-%S" if with_time else r"%Y-%m-%d"
    return current_time_tz(tz).strftime(fmt)
def timedelta_from_val(
    val: float,
    time_unit: TimeUnitsTimedelta,
) -> Timedelta:
    """create Python timedelta object by choosing time value and time unit

    Parameters
    ----------
    val : float
        duration
    time_unit : TimeUnitsTimedelta
        target time unit; plain strings matching a member value are
        accepted as well

    Returns
    -------
    Timedelta
        timedelta object corresponding to the given values

    Raises
    ------
    ValueError
        if chosen time unit not implemented
    """
    try:
        TimeUnitsTimedelta(time_unit)
    except ValueError as err:
        allowed_time_units = enum_str_values_as_frzset(TimeUnitsTimedelta)
        # chain the original lookup failure for easier debugging
        raise ValueError(
            f"Time unit >>{time_unit}<< not supported. Choose from {allowed_time_units}"
        ) from err
    # StrEnum members are str subclasses, so they work as keyword names
    kwargs = {time_unit: val}
    return Timedelta(**kwargs)
def dt_with_tz_UTC(
    *args,
    **kwargs,
) -> Datetime:
    """Build a datetime from the given arguments with tzinfo forced to UTC.

    Accepts the same positional/keyword arguments as ``datetime`` itself,
    except ``tzinfo`` (passing it would raise a duplicate-argument error).
    """
    return Datetime(*args, **kwargs, tzinfo=TIMEZONE_UTC)
def round_td_by_seconds(
    td: Timedelta,
    round_to_next_seconds: int = 1,
) -> Timedelta:
    """Round a timedelta to the nearest multiple of ``round_to_next_seconds``.

    Parameters
    ----------
    td : Timedelta
        duration to round
    round_to_next_seconds : int, optional
        rounding step in seconds, by default 1

    Returns
    -------
    Timedelta
        duration rounded to the nearest step
    """
    step = round_to_next_seconds
    # round() halves ties to even, matching the original behavior
    num_steps = round(td.total_seconds() / step)
    return Timedelta(seconds=num_steps * step)
def current_time_tz(
    tz: TZInfo = TIMEZONE_UTC,
    cut_microseconds: bool = False,
) -> Datetime:
    """current time as datetime object with
    associated time zone information (UTC by default)

    Parameters
    ----------
    tz : TZInfo, optional
        time zone information, by default TIMEZONE_UTC
    cut_microseconds : bool, optional
        drop the microsecond component from the result, by default False

    Returns
    -------
    Datetime
        datetime object with corresponding time zone
    """
    if cut_microseconds:
        return Datetime.now(tz=tz).replace(microsecond=0)
    else:
        return Datetime.now(tz=tz)
def add_timedelta_with_tz(
    starting_dt: Datetime,
    td: Timedelta,
) -> Datetime:
    """Add *td* to *starting_dt*, doing the arithmetic in UTC.

    The result is converted back to the starting datetime's own
    time zone.

    Parameters
    ----------
    starting_dt : Datetime
        time-zone-aware starting point in time
    td : Timedelta
        duration as timedelta object

    Returns
    -------
    Datetime
        time-zone-aware end point

    Raises
    ------
    ValueError
        if ``starting_dt`` carries no time zone information
    """
    source_tz = starting_dt.tzinfo
    if source_tz is None:
        raise ValueError("The provided starting date does not contain time zone information.")
    # perform the addition in UTC, then convert back to the source zone
    ending_dt_utc = starting_dt.astimezone(TIMEZONE_UTC) + td
    return ending_dt_utc.astimezone(tz=source_tz)
def validate_dt_UTC(
    dt: Datetime,
) -> None:
    """Ensure *dt* is timezone-aware and references UTC.

    Parameters
    ----------
    dt : Datetime
        datetime object whose tzinfo is checked

    Raises
    ------
    ValueError
        if the datetime's tzinfo is not UTC
    """
    # guard-clause form: accept and return early on the valid case
    if dt.tzinfo == TIMEZONE_UTC:
        return
    raise ValueError(
        f"Datetime object {dt} does not contain necessary UTC time zone information"
    )
def dt_to_timezone(
    dt: Datetime,
    target_tz: TZInfo = TIMEZONE_CEST,
) -> Datetime:
    """convert a datetime object from one timezone to another

    Parameters
    ----------
    dt : Datetime
        datetime with time zone information
    target_tz : TZInfo, optional
        target time zone information, by default TIMEZONE_CEST

    Returns
    -------
    Datetime
        datetime object adjusted to given local time zone

    Raises
    ------
    ValueError
        if datetime object does not contain time zone information
    """
    if dt.tzinfo is None:
        # no time zone information
        raise ValueError("The provided starting date does not contain time zone information.")
    # transform to given target time zone
    dt_local_tz = dt.astimezone(tz=target_tz)
    return dt_local_tz
def cut_dt_microseconds(
    dt: Datetime,
) -> Datetime:
    """Return a copy of *dt* with the microsecond component set to zero."""
    return dt.replace(microsecond=0)

22
src/dopt_basics/enums.py Normal file
View File

@ -0,0 +1,22 @@
from __future__ import annotations
from enum import StrEnum
from typing import Type
def enum_str_values_as_frzset(
enum_class: Type[StrEnum],
) -> frozenset[str]:
"""returns the values of an StrEnum class as a frozenset
Parameters
----------
enum_cls : Any
Enum class
Returns
-------
frozenset
values of the Enum class
"""
return frozenset(val.value for val in enum_class)

166
src/dopt_basics/paths.py Normal file
View File

@ -0,0 +1,166 @@
from __future__ import annotations
import shutil
from collections.abc import Sequence
from pathlib import Path
from dopt_basics.datetime import TIMEZONE_CEST, get_timestamp
def create_folder(
    path: Path,
    delete_existing: bool = False,
) -> None:
    """Create *path* (including parents); optionally wipe an existing one.

    Parameters
    ----------
    path : Path
        directory to create
    delete_existing : bool, optional
        remove the directory tree first if it already exists,
        by default False
    """
    already_there = path.exists()
    if already_there and delete_existing:
        shutil.rmtree(path)
    # exist_ok makes the call idempotent when not deleting
    path.mkdir(parents=True, exist_ok=True)
def prepare_save_path(
    root_folder: Path,
    dirs: Sequence[str] | None,
    filename: str | None,
    suffix: str | None,
    include_timestamp: bool = False,
    create_folder: bool = False,
) -> Path:
    """Build a save path below ``root_folder`` from optional sub-dirs,
    filename and suffix.

    Parameters
    ----------
    root_folder : Path
        base folder under which the path is built
    dirs : Sequence[str] | None
        sub-directory names joined in order, None for no sub-dirs
    filename : str | None
        target file name; must be provided together with ``suffix``
    suffix : str | None
        file suffix, with or without leading dot
    include_timestamp : bool, optional
        prepend a CEST timestamp to the filename, by default False
    create_folder : bool, optional
        create the parent folder if missing, by default False

    Returns
    -------
    Path
        resolved target path

    Raises
    ------
    ValueError
        if neither dirs nor filename/suffix are given, if filename and
        suffix are not given together, if a timestamp is requested
        without a filename, or if the suffix is just a dot
    """
    if not any((dirs, filename, suffix)):
        raise ValueError("Dirs or filename must be provided")
    if not (
        all(x is None for x in (filename, suffix))
        or all(x is not None for x in (filename, suffix))
    ):
        raise ValueError("Filename and suffix must be provided together")
    if include_timestamp and filename is None:
        raise ValueError("Timestamp only with filename")

    folders = "/".join(dirs) if dirs is not None else ""
    filename = "" if filename is None else filename
    if include_timestamp:
        timestamp = get_timestamp(tz=TIMEZONE_CEST, with_time=True)
        # prepend the timestamp to the chosen filename
        filename = f"{timestamp}_{filename}"

    if suffix is None:
        suffix = ""
    elif suffix == ".":
        raise ValueError("Suffix can not be just dot.")
    elif not suffix.startswith("."):
        # normalize the suffix to carry a leading dot
        suffix = f".{suffix}"

    pth_parent = (root_folder / folders).resolve()
    if create_folder and not pth_parent.exists():
        pth_parent.mkdir(parents=True)
    return (pth_parent / filename).with_suffix(suffix)
def search_cwd(
    glob_pattern: str,
) -> Path | None:
    """Return the first entry in the current working directory matching
    *glob_pattern*.

    Parameters
    ----------
    glob_pattern : str
        glob pattern to match against the CWD's contents

    Returns
    -------
    Path | None
        first matching path, None when nothing matches
    """
    # glob is lazy; next() with a default avoids materializing all matches
    return next(Path.cwd().glob(glob_pattern), None)
def search_file_iterative(
    starting_path: Path,
    glob_pattern: str,
    stop_folder_name: str | None = None,
) -> Path | None:
    """Iteratively searches the parent directories of the starting path
    and look for files matching the glob pattern. The starting path is not
    searched, only its parents. Therefore the starting path can also point
    to a file. The folder in which it is placed in will be searched.
    Returns the first match encountered.
    The parent of the stop folder will be searched if it exists.

    Parameters
    ----------
    starting_path : Path
        non-inclusive starting path
    glob_pattern : str, optional
        pattern to look for, first match will be returned
    stop_folder_name : str, optional
        name of the last folder in the directory tree where search should stop
        (parent is searched), by default None

    Returns
    -------
    Path | None
        Path if corresponding object was found, None otherwise
    """
    file_path: Path | None = None
    stop_folder_reached: bool = False
    # walk upwards: starting_path.parents yields the closest parent first
    for search_path in starting_path.parents:
        res = tuple(search_path.glob(glob_pattern))
        if res:
            file_path = res[0]
            break
        elif stop_folder_reached:
            # the flag was set in the previous iteration (the stop folder
            # itself), so its parent has now been globbed -> stop here
            break
        if stop_folder_name is not None and search_path.name == stop_folder_name:
            # library is placed inside a whole python installation for deployment
            # if this folder is reached, only look up one parent above
            stop_folder_reached = True
    return file_path
def search_folder_path(
    starting_path: Path,
    stop_folder_name: str | None = None,
) -> Path | None:
    """Walk the parents of *starting_path* looking for a folder with the
    given name; return that folder's parent on a match.

    Example:
        starting_path = path/to/start/folder
        stop_folder_name = 'to'
        returned path = 'path/'

    Parameters
    ----------
    starting_path : Path
        non-inclusive starting path
    stop_folder_name : str, optional
        name of the last folder in the directory tree to search, by default None

    Returns
    -------
    Path | None
        Path if corresponding base path was found, None otherwise
    """
    # without a name to look for there can be no match
    if stop_folder_name is None:
        return None
    for candidate in starting_path.parents:
        if candidate.name == stop_folder_name:
            # library is placed inside a whole python installation for
            # deployment; only look up to this folder
            return candidate.parent
    return None

View File

@ -0,0 +1,2 @@
[test]
entry = 'test123'

12
tests/conftest.py Normal file
View File

@ -0,0 +1,12 @@
from pathlib import Path
import pytest
@pytest.fixture(scope="session")
def root_data_folder() -> Path:
    """session-scoped path to the shared test data folder under the CWD"""
    pth = Path.cwd() / "tests/_test_data/"
    assert pth.exists()
    assert pth.is_dir()
    return pth

36
tests/test_configs.py Normal file
View File

@ -0,0 +1,36 @@
from pathlib import Path
import pytest
from dopt_basics import configs
@pytest.fixture(scope="module")
def config_file(root_data_folder) -> Path:
    """path to the sample TOML config used by the loader tests"""
    pth = root_data_folder / "config.toml"
    assert pth.exists()
    assert pth.is_file()
    return pth
def test_load_toml_SuccessPath(config_file):
    """loading via a Path object returns the parsed TOML content"""
    cfg = configs.load_toml(config_file)
    assert isinstance(cfg, dict)
    assert "test" in cfg
    assert cfg["test"]["entry"] == "test123"
def test_load_toml_SuccessStringPath(config_file):
    """loading via a plain string path works identically to a Path"""
    str_pth = str(config_file)
    cfg = configs.load_toml(str_pth)
    assert isinstance(cfg, dict)
    assert "test" in cfg
    assert cfg["test"]["entry"] == "test123"
def test_load_toml_FailWrongPath(tmp_path):
    """a non-existing path raises FileNotFoundError"""
    wrong_pth = tmp_path / "config.toml"
    with pytest.raises(FileNotFoundError):
        _ = configs.load_toml(wrong_pth)

View File

@ -0,0 +1,64 @@
import pytest
from dopt_basics import datastructures as dst
def test_flatten():
    """nested lists/tuples flatten to their leaves in depth-first order"""
    nested_iterable = ([1, 2], [[3], [4, 5]], [6, [7, 8, 9]])
    target = tuple(i for i in range(1, 10))
    ret_iter = dst.flatten(nested_iterable)
    ret = tuple(ret_iter)
    assert ret == target
def test_DualDict():
    """DualDict mirrors a plain dict and keeps its inverted view in sync"""
    base_dict: dict[str, int] = {"test1": 1, "test2": 2, "test3": 3}
    inverted_dict: dict[int, str] = {1: "test1", 2: "test2", 3: "test3"}
    # sanity check of the hand-written reference mapping
    assert all((key == inverted_dict[value] for key, value in base_dict.items()))
    dual_dict: dst.DualDict[str, int] = dst.DualDict(test1=1, test2=2, test3=3)
    assert all((key in dual_dict for key in base_dict.keys()))
    assert all((base_dict[key] == dual_dict[key] for key in base_dict.keys()))
    assert all((key == dual_dict.inverted[value] for key, value in base_dict.items()))
    assert all(
        (inverted_dict[key] == dual_dict.inverted[key] for key in inverted_dict.keys())
    )
    # insertion and deletion keep both directions consistent
    base_dict["test_add"] = 5
    dual_dict["test_add"] = 5
    assert len(base_dict) == len(dual_dict)
    assert len(dual_dict) == len(dual_dict.inverted)
    del base_dict["test_add"]
    del dual_dict["test_add"]
    assert len(base_dict) == len(dual_dict)
    assert len(dual_dict) == len(dual_dict.inverted)
    for key_base, key_dd in zip(base_dict, dual_dict):
        assert key_base == key_dd
def test_DualDict_update_Success():
    """update() with keyword args behaves like dict.update"""
    base_dict: dict[str, int] = {"test1": 1, "test2": 2, "test3": 3}
    dual_dict: dst.DualDict[str, int] = dst.DualDict(test1=1, test2=2, test3=3)
    update = dict(test3=4, test4=5)
    base_dict.update(**update)
    dual_dict.update(**update)
    assert all((key in dual_dict for key in base_dict.keys()))
    assert all((base_dict[key] == dual_dict[key] for key in base_dict.keys()))
    assert all((key == dual_dict.inverted[value] for key, value in base_dict.items()))
def test_DualDict_update_FailIdenticalValues():
    """duplicate values raise ValueError on construction and on update"""
    base_dict: dict[str, int] = {"test1": 1, "test2": 2, "test3": 3}
    with pytest.raises(ValueError):
        _: dst.DualDict[str, int] = dst.DualDict(test1=1, test2=3, test3=3)
    dual_dict: dst.DualDict[str, int] = dst.DualDict(test1=1, test2=2, test3=3)
    update = dict(test3=4, test4=4)
    base_dict.update(**update)
    with pytest.raises(ValueError):
        dual_dict.update(**update)

148
tests/test_datetime.py Normal file
View File

@ -0,0 +1,148 @@
from datetime import UTC, datetime, timedelta
from unittest.mock import patch
import pytest
from dopt_basics import datetime as datetime_
from dopt_basics.datetime import TIMEZONE_CEST, TimeUnitsTimedelta
def test_dt_with_UTC():
    """dt_with_tz_UTC builds the same datetime as passing tzinfo=UTC"""
    year = 2024
    month = 3
    day = 28
    hour = 3
    minute = 0
    dt_target = datetime(year, month, day, hour, minute, tzinfo=UTC)
    dt_ret = datetime_.dt_with_tz_UTC(year, month, day, hour, minute)
    assert dt_target == dt_ret
@pytest.mark.parametrize(
    "time_unit, expected",
    [
        ("hours", timedelta(hours=2.0)),
        ("minutes", timedelta(minutes=2.0)),
        ("seconds", timedelta(seconds=2.0)),
        ("milliseconds", timedelta(milliseconds=2.0)),
        ("microseconds", timedelta(microseconds=2.0)),
        (TimeUnitsTimedelta.HOURS, timedelta(hours=2.0)),
        (TimeUnitsTimedelta.MINUTES, timedelta(minutes=2.0)),
    ],
)
def test_timedelta_from_val_Success(time_unit, expected):
    """plain strings and enum members both map to the right timedelta"""
    val = 2.0
    td = datetime_.timedelta_from_val(val, time_unit)
    assert td == expected
def test_timedelta_from_val_FailWrongTimeUnit():
    """an unsupported time unit raises ValueError"""
    val = 2.0
    time_unit = "years"
    with pytest.raises(ValueError):
        datetime_.timedelta_from_val(val, time_unit)  # type: ignore
def test_round_td_by_seconds():
    """sub-second remainder is rounded away at a one-second step"""
    hours = 2.0
    minutes = 30.0
    seconds = 30.0
    microseconds = 600
    td = timedelta(hours=hours, minutes=minutes, seconds=seconds, microseconds=microseconds)
    rounded_td = datetime_.round_td_by_seconds(td, round_to_next_seconds=1)
    assert rounded_td == timedelta(hours=2.0, minutes=30.0, seconds=30.0)
def test_current_time_tz():
    """Datetime.now is patched; result is tz-aware with and without
    the microsecond cut"""
    tz = datetime_.TIMEZONE_UTC
    mock_dt = datetime(2024, 6, 1, 12, 15, 30, 1000, tzinfo=tz)
    with patch("dopt_basics.datetime.Datetime") as mock_obj:
        mock_obj.now.return_value = mock_dt
        ret = datetime_.current_time_tz(cut_microseconds=False)
    assert ret.tzinfo is not None
    assert ret == mock_dt
    with patch("dopt_basics.datetime.Datetime") as mock_obj:
        mock_obj.now.return_value = mock_dt
        ret = datetime_.current_time_tz(cut_microseconds=True)
    # microseconds stripped from the mocked 'now'
    target = datetime(2024, 6, 1, 12, 15, 30, tzinfo=tz)
    assert ret.tzinfo is not None
    assert ret == target
def test_get_timestamp_WithTime():
    """with_time=True renders date and time separated by double dashes"""
    tz = datetime_.TIMEZONE_UTC
    mock_dt = datetime(2024, 6, 1, 12, 15, 30, 1000, tzinfo=tz)
    with patch("dopt_basics.datetime.Datetime") as mock_obj:
        mock_obj.now.return_value = mock_dt
        ret = datetime_.get_timestamp(tz=tz, with_time=True)
    target = "2024-06-01--12-15-30"
    assert ret == target
def test_get_timestamp_WithoutTime():
    """with_time=False renders the date only"""
    tz = datetime_.TIMEZONE_UTC
    mock_dt = datetime(2024, 6, 1, 12, 15, 30, 1000, tzinfo=tz)
    with patch("dopt_basics.datetime.Datetime") as mock_obj:
        mock_obj.now.return_value = mock_dt
        ret = datetime_.get_timestamp(tz=tz, with_time=False)
    target = "2024-06-01"
    assert ret == target
def test_add_timedelta_FailWithoutTZInfo():
    """a naive starting datetime raises ValueError"""
    year = 2024
    month = 3
    day = 30
    hour = 3
    minute = 0
    dt = datetime(year, month, day, hour, minute)
    td = timedelta(hours=2.0)
    with pytest.raises(ValueError):
        datetime_.add_timedelta_with_tz(dt, td)
def test_add_timedelta_with_tz():
    """addition across the CEST DST transition stays correct"""
    year = 2024
    month = 3
    day = 30
    hour = 23
    minute = 0
    dt = datetime(year, month, day, hour, minute, tzinfo=TIMEZONE_CEST)
    td = timedelta(hours=6.0)
    new_dt = datetime_.add_timedelta_with_tz(dt, td)
    assert new_dt == datetime(2024, 3, 31, 6, 0, tzinfo=TIMEZONE_CEST)
def test_validate_dt_UTC_Success():
    """a UTC-aware datetime passes validation without raising"""
    dt = datetime(2024, 3, 30, 0, 0, tzinfo=UTC)
    datetime_.validate_dt_UTC(dt)
def test_validate_dt_FailWrongTZInfo():
    """a non-UTC time zone raises ValueError"""
    dt = datetime(2024, 3, 30, 0, 0, tzinfo=TIMEZONE_CEST)
    with pytest.raises(ValueError):
        datetime_.validate_dt_UTC(dt)
def test_dt_to_timezone_Success():
    """UTC datetime converts to the expected CET wall-clock time"""
    dt = datetime(2024, 3, 30, 2, 0, tzinfo=UTC)
    new_dt = datetime_.dt_to_timezone(dt, TIMEZONE_CEST)
    assert new_dt == datetime(2024, 3, 30, 3, tzinfo=TIMEZONE_CEST)
def test_dt_to_timezone_FailWithoutTZInfo():
    """a naive datetime raises ValueError"""
    dt = datetime(2024, 3, 30, 2, 0)
    with pytest.raises(ValueError):
        datetime_.dt_to_timezone(dt, TIMEZONE_CEST)
def test_cut_microseconds():
    """the microsecond component is zeroed"""
    dt = datetime(2024, 3, 30, 2, 0, 0, 600)
    new_dt = datetime_.cut_dt_microseconds(dt)
    assert new_dt == datetime(2024, 3, 30, 2, 0, 0, 0)

17
tests/test_enums.py Normal file
View File

@ -0,0 +1,17 @@
import enum
from dopt_basics import enums
def test_enum_str_values_as_frzset():
    """member values of an ad-hoc StrEnum are extracted completely"""
    class TestEnum(enum.StrEnum):
        T1 = enum.auto()
        T2 = enum.auto()
        T3 = enum.auto()
    # auto() on StrEnum yields lowercased member names
    target_vals = frozenset(("t1", "t2", "t3"))
    extracted_vals = enums.enum_str_values_as_frzset(TestEnum)
    diff = target_vals.difference(extracted_vals)
    assert len(diff) == 0

207
tests/test_paths.py Normal file
View File

@ -0,0 +1,207 @@
from pathlib import Path
import pytest
from dopt_basics import paths
FILE_SEARCH = "test.txt"
@pytest.fixture(scope="module")
def base_folder(tmp_path_factory) -> Path:
    """module-scoped nested folder structure inside a temp directory"""
    folder_structure = "path/to/base/folder/"
    pth = tmp_path_factory.mktemp("search")
    pth = pth / folder_structure
    pth.mkdir(parents=True, exist_ok=True)
    return pth
@pytest.fixture(scope="module")
def target_file_pth(base_folder) -> Path:
    """search target file placed three levels above the base folder"""
    # place in folder 'path' of TMP path
    target_folder = base_folder.parents[2]
    target_file = target_folder / FILE_SEARCH
    with open(target_file, "w") as file:
        file.write("TEST")
    return target_file
@pytest.mark.parametrize(
    "delete_existing",
    [True, False],
)
def test_create_folder(tmp_path, delete_existing):
    """folder is created and a second call is idempotent for both flags"""
    target_dir = tmp_path / "test"
    assert not target_dir.exists()
    paths.create_folder(target_dir, delete_existing=delete_existing)
    assert target_dir.exists()
    assert target_dir.is_dir()
    paths.create_folder(target_dir, delete_existing=delete_existing)
    assert target_dir.exists()
    assert target_dir.is_dir()
def test_prepare_save_path_SuccessWithCreate(tmp_path):
    """dirs-only call with create_folder=True creates the directory"""
    base_folder = tmp_path
    dirs = ("target", "dir")
    filename = None
    suffix = None
    target_pth = tmp_path / "/".join(dirs)
    res_pth = paths.prepare_save_path(base_folder, dirs, filename, suffix, create_folder=True)
    assert res_pth.exists()
    assert res_pth == target_pth
def test_prepare_save_path_SuccessWithCreateTimestamp(tmp_path):
    """timestamped filename still gets its parent folder created"""
    base_folder = tmp_path
    dirs = ("target", "dir")
    filename = "test"
    suffix = ".pkl"
    res_pth = paths.prepare_save_path(
        base_folder, dirs, filename, suffix, create_folder=True, include_timestamp=True
    )
    assert res_pth.parent.exists()
def test_prepare_save_path_SuccessWithoutCreate(tmp_path):
    """without create_folder the path is computed but not created"""
    base_folder = tmp_path
    dirs = ("target", "dir")
    filename = None
    suffix = None
    target_pth = tmp_path / "/".join(dirs)
    res_pth = paths.prepare_save_path(
        base_folder, dirs, filename, suffix, create_folder=False
    )
    assert not res_pth.exists()
    assert res_pth == target_pth
def test_prepare_save_path_FailNoTargets(tmp_path):
    """neither dirs nor filename/suffix raises ValueError"""
    base_folder = tmp_path
    dirs = None
    filename = None
    suffix = None
    with pytest.raises(ValueError):
        _ = paths.prepare_save_path(
            base_folder,
            dirs,
            filename,
            suffix,
            create_folder=False,
        )
def test_prepare_save_path_FailNoFilenameSuffix(tmp_path):
    """filename and suffix must be given together, either way around"""
    base_folder = tmp_path
    dirs = None
    filename = None
    suffix = "pkl"
    with pytest.raises(ValueError):
        _ = paths.prepare_save_path(
            base_folder,
            dirs,
            filename,
            suffix,
            create_folder=False,
        )
    filename = "test"
    suffix = None
    with pytest.raises(ValueError):
        _ = paths.prepare_save_path(
            base_folder,
            dirs,
            filename,
            suffix,
            create_folder=False,
        )
def test_prepare_save_path_FailTimestampWithoutFilename(tmp_path):
    """include_timestamp without a filename raises ValueError"""
    base_folder = tmp_path
    dirs = ["test"]
    filename = None
    suffix = None
    with pytest.raises(ValueError):
        _ = paths.prepare_save_path(
            base_folder,
            dirs,
            filename,
            suffix,
            create_folder=False,
            include_timestamp=True,
        )
def test_prepare_save_path_FailBadSuffix(tmp_path):
    """a suffix consisting of just a dot raises ValueError"""
    base_folder = tmp_path
    dirs = None
    filename = "test"
    suffix = "."
    with pytest.raises(ValueError):
        _ = paths.prepare_save_path(
            base_folder,
            dirs,
            filename,
            suffix,
            create_folder=False,
            include_timestamp=False,
        )
def test_prepare_save_path_SuccessSuffixAddDot(tmp_path):
    """a suffix without a leading dot is normalized to '.suffix'"""
    base_folder = tmp_path
    dirs = None
    filename = "test"
    suffix = "pkl"
    # expected: the filename with the normalized dotted suffix directly
    # below the base folder (restores the expression garbled by the
    # diff renderer into '(unknown)')
    target_path = tmp_path / f"{filename}.{suffix}"
    ret_path = paths.prepare_save_path(
        base_folder,
        dirs,
        filename,
        suffix,
        create_folder=False,
        include_timestamp=False,
    )
    assert ret_path == target_path
def test_search_cwd(monkeypatch, base_folder, target_file_pth):
    """search_cwd finds the target only when the CWD holds the file"""
    # first: CWD without the target file -> no match
    monkeypatch.setattr(Path, "cwd", lambda: base_folder)
    assert Path.cwd() == base_folder
    ret = paths.search_cwd(FILE_SEARCH)
    assert ret is None
    # second: CWD is the folder containing the target file -> match
    target_folder = target_file_pth.parent
    monkeypatch.setattr(Path, "cwd", lambda: target_folder)
    assert Path.cwd() == target_folder
    ret = paths.search_cwd(FILE_SEARCH)
    assert ret is not None
    assert ret == target_file_pth
@pytest.mark.parametrize("stop_folder_name", ["to", "base", None])
def test_search_file_iterative(base_folder, target_file_pth, stop_folder_name):
    """the search reaches the target only when the stop folder allows it"""
    # target in parent of 'to': 'path'
    ret = paths.search_file_iterative(base_folder, FILE_SEARCH, stop_folder_name)
    if stop_folder_name == "to" or stop_folder_name is None:
        assert ret is not None
        assert ret.name == FILE_SEARCH
        assert ret == target_file_pth
    elif stop_folder_name == "base":
        assert ret is None
def test_search_folder_path(base_folder):
    """the parent of a named stop folder is returned, else None"""
    stop_folder = "123"  # should not exist
    found = paths.search_folder_path(base_folder, stop_folder_name=stop_folder)
    assert found is None
    stop_folder = "to"
    found = paths.search_folder_path(base_folder, stop_folder_name=stop_folder)
    assert found is not None
    assert found.name == "path"
    stop_folder = None
    found = paths.search_folder_path(base_folder, stop_folder_name=stop_folder)
    assert found is None