95 lines
2.5 KiB
Python
95 lines
2.5 KiB
Python
import base64
|
|
import pickle
|
|
import shutil
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from lang_main.loggers import logger_shared_helpers as logger
|
|
|
|
|
|
# ** Lib
|
|
def create_saving_folder(
|
|
saving_path_folder: str | Path,
|
|
overwrite_existing: bool = False,
|
|
) -> None:
|
|
# check for existence of given path
|
|
if isinstance(saving_path_folder, str):
|
|
saving_path_folder = Path(saving_path_folder)
|
|
if not saving_path_folder.exists():
|
|
saving_path_folder.mkdir(parents=True)
|
|
else:
|
|
if overwrite_existing:
|
|
# overwrite if desired (deletes whole path and re-creates it)
|
|
shutil.rmtree(saving_path_folder)
|
|
saving_path_folder.mkdir(parents=True)
|
|
else:
|
|
logger.info(
|
|
(
|
|
'Path >>%s<< already exists and remained unchanged. If you want to '
|
|
'overwrite this path, use parameter >>overwrite_existing<<.',
|
|
),
|
|
saving_path_folder,
|
|
)
|
|
|
|
|
|
# saving and loading using pickle
|
|
# careful: unpickling data from untrusted sources can execute arbitrary code
|
|
def save_pickle(
|
|
obj: Any,
|
|
path: str | Path,
|
|
) -> None:
|
|
with open(path, 'wb') as file:
|
|
pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL)
|
|
logger.info('Saved file successfully under %s', path)
|
|
|
|
|
|
def load_pickle(
|
|
path: str | Path,
|
|
) -> Any:
|
|
with open(path, 'rb') as file:
|
|
obj = pickle.load(file)
|
|
logger.info('Loaded file successfully.')
|
|
return obj
|
|
|
|
|
|
def encode_to_base64_str(
    obj: Any,
    encoding: str = 'utf-8',
) -> str:
    """Pickle ``obj`` and return the payload as Base64 text.

    ``encoding`` is used to turn the Base64 bytes into ``str``. The result
    can be restored with ``decode_from_base64_str``.
    """
    return base64.b64encode(
        pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
    ).decode(encoding=encoding)
|
|
|
|
|
|
def encode_file_to_base64_str(
    path: Path,
    encoding: str = 'utf-8',
) -> str:
    """Return the raw bytes of the file at ``path`` as Base64 text.

    The file content is only Base64-encoded, not unpickled or otherwise
    interpreted. ``encoding`` converts the Base64 bytes into ``str``.
    """
    with open(path, 'rb') as source:
        raw_content = source.read()
    encoded = base64.b64encode(raw_content)
    return encoded.decode(encoding=encoding)
|
|
|
|
|
|
def decode_from_base64_str(
    b64_str: str,
    encoding: str = 'utf-8',
) -> Any:
    """Base64-decode ``b64_str`` and unpickle the result.

    This is the inverse of ``encode_to_base64_str``. ``encoding`` converts
    the input text into bytes before Base64 decoding.

    Warning: ``pickle.loads`` can execute arbitrary code; never call this on
    strings from untrusted sources.
    """
    payload = base64.b64decode(b64_str.encode(encoding=encoding))
    return pickle.loads(payload)
|
|
|
|
|
|
def get_entry_point(
|
|
saving_path: Path,
|
|
filename: str,
|
|
file_ext: str = '.pkl',
|
|
check_existence: bool = True,
|
|
) -> Path:
|
|
entry_point_path = (saving_path / filename).with_suffix(file_ext)
|
|
if check_existence and not entry_point_path.exists():
|
|
raise FileNotFoundError(
|
|
f'Could not find provided entry data under path: >>{entry_point_path}<<'
|
|
)
|
|
|
|
return entry_point_path
|