from __future__ import annotations import multiprocessing as mp from collections.abc import Iterable, Sequence from typing import Any, TypeVar import psutil T = TypeVar("T") class MPPool: def __init__(self) -> None: self.num_workers = psutil.cpu_count(logical=False) or 4 self.pool = mp.Pool(processes=self.num_workers) def chunk_data( self, data: list[T], chunk_size: int | None = None, ) -> Sequence[Sequence[T]]: if chunk_size is None: chunk_size = max(1, len(data) // self.num_workers) chunks = [data[i : i + chunk_size] for i in range(0, len(data), chunk_size)] if len(chunks) > self.num_workers: open_chunk = chunks[-1] for idx, entry in enumerate(open_chunk): chunks[idx].append(entry) del chunks[-1] return chunks def stop(self) -> None: self.pool.close() self.pool.join()