274 lines
8.9 KiB
Python
274 lines
8.9 KiB
Python
from __future__ import annotations
|
|
|
|
from abc import ABC, abstractmethod
|
|
from collections.abc import Callable
|
|
from pathlib import Path
|
|
from typing import Any, Never, cast
|
|
from typing_extensions import override
|
|
|
|
from lang_main.errors import (
|
|
NoPerformableActionError,
|
|
OutputInPipelineContainerError,
|
|
WrongActionTypeError,
|
|
)
|
|
from lang_main.io import load_pickle, save_pickle
|
|
from lang_main.loggers import logger_pipelines as logger
|
|
from lang_main.types import ResultHandling
|
|
|
|
|
|
class BasePipeline(ABC):
    """Abstract base for named, sequential processing pipelines.

    A pipeline holds an ordered list of callables (``actions``) together with
    their display names (``action_names``). Subclasses define how actions are
    registered (``add``) and how they are executed (``logic``); ``run`` wraps
    ``logic`` with the shared pre-run checks and post-run logging.
    """

    def __init__(
        self,
        name: str,
        working_dir: Path,
    ) -> None:
        """Create an empty pipeline.

        Args:
            name: Pipeline name, used in log messages and ``repr``.
            working_dir: Working directory of the pipeline (output path).
        """
        # init base class
        super().__init__()

        # name of pipeline
        self.name = name
        # working directory for pipeline == output path
        self.working_dir = working_dir

        # actions to perform during a pass and their names (index-aligned)
        self.actions: list[Callable] = []
        self.action_names: list[str] = []
        # progress tracking: 1-based index of the step being processed
        self.curr_proc_idx: int = 1

    def __repr__(self) -> str:
        return (
            f'{self.__class__.__name__}(name: {self.name}, '
            f'working dir: {self.working_dir}, contents: {self.action_names})'
        )

    def panic_wrong_action_type(
        self,
        action: Any,
        compatible_type: str,
    ) -> Never:
        """Reject an action whose type the pipeline cannot handle.

        Args:
            action: The offending object passed to ``add``.
            compatible_type: Name of the accepted type, used in the message.

        Raises:
            WrongActionTypeError: Always.
        """
        raise WrongActionTypeError(
            f'Action must be of type {compatible_type}, '
            f'but is of type >>{type(action)}<<.'
        )

    def prep_run(self) -> None:
        """Log the start of a run and reset progress tracking.

        Raises:
            NoPerformableActionError: If no actions have been registered.
        """
        logger.info('Starting pipeline >>%s<<...', self.name)
        # progress tracking restarts with every run
        self.curr_proc_idx = 1
        # check if performable actions available
        if not self.actions:
            raise NoPerformableActionError(
                'The pipeline does not contain any performable actions.'
            )

    def post_run(self) -> None:
        """Log the successful end of a run with the number of processed steps."""
        logger.info(
            'Processing pipeline >>%s<< successfully ended after %d steps.',
            self.name,
            # curr_proc_idx starts at 1 and is advanced after every step
            self.curr_proc_idx - 1,
        )

    # The gradual ``(*args: Any, **kwargs: Any)`` signatures below let
    # subclasses declare their own concrete parameters in ``@override``
    # methods without being flagged as incompatible overrides.
    @abstractmethod
    def add(self, *args: Any, **kwargs: Any) -> None:
        """Register an action with the pipeline."""

    @abstractmethod
    def logic(self, *args: Any, **kwargs: Any) -> Any:
        """Execute the registered actions and return the pipeline result."""

    def run(self, *args: Any, **kwargs: Any) -> Any:
        """Execute the pipeline.

        Calls ``prep_run``, then ``logic`` with the given arguments, then
        ``post_run``, and returns whatever ``logic`` returned.

        Raises:
            NoPerformableActionError: If no actions have been registered.
        """
        self.prep_run()
        ret = self.logic(*args, **kwargs)
        self.post_run()

        return ret
|
|
|
|
|
|
class PipelineContainer(BasePipeline):
    """Container for simple, independent actions without data flow.

    Actions are usually functions that take no arguments and return nothing.
    If an action returns anything other than ``None``, an
    ``OutputInPipelineContainerError`` is raised. A PipelineContainer is
    therefore a concatenation of independent procedures, executed in the
    order in which they were added. Calling ``run`` performs them.

    Individual actions can be skipped via the ``skip`` flag of ``add``. This
    allows easily configurable pipelines, e.g. driven by a user configuration.
    """

    def __init__(
        self,
        name: str,
        working_dir: Path,
    ) -> None:
        super().__init__(name=name, working_dir=working_dir)

        # per-action skip flags, index-aligned with ``self.actions``
        self.action_skip: list[bool] = []

    @override
    def add(
        self,
        action: Callable,
        skip: bool = False,
    ) -> None:
        """Append an action to the container.

        Args:
            action: Callable invoked without arguments during ``run``.
            skip: If ``True``, the action is registered but not executed.

        Raises:
            WrongActionTypeError: If ``action`` is not callable.
        """
        if not isinstance(action, Callable):
            self.panic_wrong_action_type(action=action, compatible_type=Callable.__name__)

        self.actions.append(action)
        # Callables such as ``functools.partial`` objects or instances that
        # implement ``__call__`` have no ``__name__``; use their type name
        # instead of failing with AttributeError.
        self.action_names.append(getattr(action, '__name__', type(action).__name__))
        self.action_skip.append(skip)

    @override
    def logic(self) -> None:
        """Execute all non-skipped actions in insertion order.

        Raises:
            OutputInPipelineContainerError: If an action returns a value
                other than ``None``.
        """
        # ``strict=True`` surfaces any misalignment of the parallel lists
        for action, action_name, skip in zip(
            self.actions, self.action_names, self.action_skip, strict=True
        ):
            if skip:
                logger.info('[No Calculation] Skipping >>%s<<...', action_name)
                self.curr_proc_idx += 1
                continue

            ret = action()
            if ret is not None:
                raise OutputInPipelineContainerError(
                    f'Output in PipelineContainers not allowed. Action {action_name} '
                    f'returned values in Container {self.name}.'
                )
            # processing tracking
            self.curr_proc_idx += 1
|
|
|
|
|
|
class Pipeline(BasePipeline):
    """Sequential pipeline that passes each action's result to the next one.

    Every action is called with the previous action's result (a tuple)
    unpacked as positional arguments, plus its own fixed keyword arguments.
    Step results can be pickled into ``working_dir`` and loaded back instead
    of being recomputed; the result of the last action is always saved.
    """

    def __init__(
        self,
        name: str,
        working_dir: Path,
    ) -> None:
        # init base class (sets name, working_dir, actions, action_names)
        super().__init__(name=name, working_dir=working_dir)

        # per-action settings, index-aligned with ``self.actions``
        self.actions_kwargs: list[dict[str, Any]] = []
        # (flag, optional filename) pairs controlling persistence per action
        self.save_results: ResultHandling = []
        self.load_results: ResultHandling = []
        # most recent step result; this is what ``save_step`` pickles
        self._intermediate_result: tuple[Any, ...] | None = None

    @override
    def add(
        self,
        action: Callable,
        action_kwargs: dict[str, Any] | None = None,
        save_result: bool = False,
        load_result: bool = False,
        filename: str | None = None,
    ) -> None:
        """Append an action to the pipeline.

        Args:
            action: Callable receiving the previous result unpacked as
                positional arguments.
            action_kwargs: Extra keyword arguments for ``action``; a shallow
                copy is stored.
            save_result: Pickle this step's result after computing it.
            load_result: Load this step's result from disk instead of
                calling ``action``.
            filename: Custom file name for saving/loading (suffix is forced
                to ``.pkl``); ``None`` derives a name from pipeline, step
                index and action name.

        Raises:
            WrongActionTypeError: If ``action`` is not callable.
        """
        if not isinstance(action, Callable):
            self.panic_wrong_action_type(action=action, compatible_type=Callable.__name__)

        self.actions.append(action)
        # Callables such as ``functools.partial`` objects or instances that
        # implement ``__call__`` have no ``__name__``; use their type name
        # instead of failing with AttributeError.
        self.action_names.append(getattr(action, '__name__', type(action).__name__))
        # copy so later changes to the caller's dict do not leak into the pipeline
        self.actions_kwargs.append({} if action_kwargs is None else action_kwargs.copy())
        self.save_results.append((save_result, filename))
        self.load_results.append((load_result, filename))

    def get_result_path(
        self,
        action_idx: int,
        filename: str | None,
    ) -> tuple[Path, str]:
        """Return the pickle path for a step together with the action's name.

        Args:
            action_idx: Index of the action in the pipeline.
            filename: Custom file name, or ``None`` for the generated default
                ``Pipe-<name>_Step-<curr_proc_idx>_<action name>.pkl``.
        """
        action_name = self.action_names[action_idx]
        if filename is None:
            # Append the suffix rather than using ``with_suffix``: a dot in the
            # pipeline or action name (e.g. ``v1.2``) would otherwise be taken
            # as an existing suffix and truncated, making all steps share
            # one file.
            target_filename = f'Pipe-{self.name}_Step-{self.curr_proc_idx}_{action_name}.pkl'
            target_path = self.working_dir / target_filename
        else:
            target_path = self.working_dir.joinpath(filename).with_suffix('.pkl')

        return target_path, action_name

    def load_step(
        self,
        action_idx: int,
        filename: str | None,
    ) -> tuple[Any, ...]:
        """Load the pickled result of a step.

        Raises:
            FileNotFoundError: If no result file exists for the step.
            TypeError: If the unpickled object is not a tuple.
        """
        target_path, action_name = self.get_result_path(action_idx, filename)

        if not target_path.exists():
            raise FileNotFoundError(
                f'No intermediate results for action >>{action_name}<< '
                f'under >>{target_path}<< found'
            )
        # NOTE(review): unpickling can execute arbitrary code; only point
        # ``working_dir`` at files written by this pipeline or a trusted source.
        # results should be tuple, but that is not guaranteed
        result_loaded = cast(tuple[Any, ...], load_pickle(target_path))
        if not isinstance(result_loaded, tuple):
            raise TypeError(f'Loaded results must be tuple, not {type(result_loaded)}')

        return result_loaded

    def save_step(
        self,
        action_idx: int,
        filename: str | None,
    ) -> None:
        """Pickle the current intermediate result for the given step."""
        target_path, _ = self.get_result_path(action_idx, filename)
        save_pickle(obj=self._intermediate_result, path=target_path)

    @override
    def logic(
        self,
        starting_values: tuple[Any, ...] | None = None,
    ) -> tuple[Any, ...] | None:
        """Run all actions in order, feeding each result into the next action.

        Args:
            starting_values: Positional arguments for the first action;
                ``None`` calls it with its keyword arguments only.

        Returns:
            The last step's result, wrapped in a 1-tuple if the action returned
            a single non-tuple value, or ``None`` if it returned ``None``. For a
            pipeline without actions, ``starting_values`` is returned unchanged.
        """
        # positional input of the next action: the starting values for the
        # first step, the previous step's (computed or loaded) result after that
        result = starting_values
        last_idx = len(self.actions) - 1

        for idx, (action, action_kwargs) in enumerate(zip(self.actions, self.actions_kwargs)):
            # loading
            do_load, load_filename = self.load_results[idx]
            if do_load:
                result = self.load_step(action_idx=idx, filename=load_filename)
                self._intermediate_result = result
                logger.info(
                    '[No Calculation] Loaded result for action >>%s<< successfully',
                    self.action_names[idx],
                )
                self.curr_proc_idx += 1
                continue

            # calculation
            if result is not None:
                ret = action(*result, **action_kwargs)
            else:
                ret = action(**action_kwargs)
            # wrap single return values so the next action can unpack them
            if ret is not None and not isinstance(ret, tuple):
                ret = (ret,)
            result = ret
            # save intermediate result
            self._intermediate_result = result

            # saving result locally, always save last action
            do_save, save_filename = self.save_results[idx]
            if do_save or idx == last_idx:
                self.save_step(action_idx=idx, filename=save_filename)
            # processing tracking
            self.curr_proc_idx += 1

        return result
|