Source code for batter.exec.local

"""Execution backend for running pipelines locally."""

from __future__ import annotations

import os
from dataclasses import dataclass, field
from typing import Callable, Dict, Iterable, List, Mapping, Optional, Tuple

from loguru import logger
from joblib import Parallel, delayed

from batter.exec.base import ExecBackend
from batter.pipeline.pipeline import Pipeline
from batter.pipeline.step import ExecResult, Step
from batter.systems.core import SimSystem

# Signature of a step handler: it is called with the step, the target system
# and the step parameters, and its return value is handed back as the step's
# ExecResult.
Handler = Callable[[Step, SimSystem, Mapping], ExecResult]


def _run_pipeline_task(
    pipeline: Pipeline,
    backend: "LocalBackend",
    sys: SimSystem,
) -> Tuple[str, Mapping[str, ExecResult] | None, BaseException | None]:
    """Run ``pipeline`` against one system and report the outcome as a tuple.

    Nothing is raised from ``pipeline.run``: whatever it raises is captured
    and handed back to the caller as the third tuple element.

    Parameters
    ----------
    pipeline :
        Pipeline whose ``run`` method is invoked.
    backend :
        Backend forwarded to ``pipeline.run`` as the step dispatcher.
    sys :
        System the pipeline is executed for; its ``name`` labels the result.

    Returns
    -------
    tuple of (str, Mapping[str, ExecResult] or None, BaseException or None)
        ``(sys.name, results, None)`` when the pipeline completes, or
        ``(sys.name, None, exc)`` when ``pipeline.run`` raises ``exc``.
    """
    outcome = None
    failure = None
    try:
        outcome = pipeline.run(backend, sys)
    except BaseException as caught:  # pragma: no cover - propagated to parent
        # Returned rather than re-raised so the caller decides how to report it.
        failure = caught
    return sys.name, outcome, failure


@dataclass
class LocalBackend(ExecBackend):
    """In-process execution backend with optional parallel orchestration.

    Parameters
    ----------
    max_workers : int, optional
        Maximum number of worker processes to use when :meth:`run_parallel`
        is invoked. ``None`` lets the backend auto-detect resources; ``0`` or
        ``1`` forces serial execution.
    """

    name: str = "local"
    _handlers: Dict[str, Handler] = field(default_factory=dict)
    _max_workers: Optional[int] = None

    def __init__(self, max_workers: Optional[int] = None):
        # An explicit __init__ takes precedence over the dataclass-generated
        # one. Attributes are set through object.__setattr__, which bypasses
        # any __setattr__ override on the class hierarchy.
        object.__setattr__(self, "name", "local")
        object.__setattr__(self, "_handlers", {})
        object.__setattr__(self, "_max_workers", max_workers)

    # ---------- registry ----------
    def register(self, step_name: str, handler: Handler) -> None:
        """Register a callable to execute ``step_name``.

        A later registration under the same name replaces the earlier one.

        Parameters
        ----------
        step_name : str
            Identifier of the step (matches
            :class:`batter.pipeline.step.Step.name`).
        handler : Callable[[Step, SimSystem, Mapping], ExecResult]
            Function responsible for executing the step.
        """
        self._handlers[step_name] = handler

    # ---------- ExecBackend ----------
    def run(self, step: Step, system: SimSystem, params: Mapping) -> ExecResult:
        """Execute ``step`` for ``system`` on the local machine.

        Parameters
        ----------
        step :
            Pipeline step metadata; ``step.name`` selects the handler.
        system :
            Simulation system descriptor, forwarded to the handler.
        params :
            Step parameters, forwarded to the handler.

        Returns
        -------
        ExecResult
            Whatever the registered handler returns, or an ``ExecResult`` with
            empty ``job_ids`` and ``artifacts`` when no handler is registered
            for ``step.name`` (the step is treated as a no-op).
        """
        handler = self._handlers.get(step.name)
        if handler is None:
            logger.debug(
                "LOCAL: no handler for step {!r}; treating as no-op.", step.name
            )
            return ExecResult(job_ids=[], artifacts={})

        logger.debug("LOCAL: executing step {!r}", step.name)
        return handler(step, system, params)

    # ---------- failure reporting ----------
    @staticmethod
    def _raise_on_failures(label: str, errors: Mapping[str, BaseException]) -> None:
        """Log and raise a ``RuntimeError`` when ``errors`` is non-empty.

        Parameters
        ----------
        label : str
            Prefix for the log line and the exception message.
        errors : Mapping[str, BaseException]
            Failures keyed by system name; an empty mapping is a no-op.

        Raises
        ------
        RuntimeError
            Naming every failed system, chained from the first recorded
            failure.
        """
        if not errors:
            return

        failed = ", ".join(errors)
        logger.error(label + ": {} system(s) failed: {}", len(errors), failed)
        raise RuntimeError(
            f"{label}: failures encountered for {failed}"
        ) from next(iter(errors.values()))

    # ---------- parallel pipeline runner (process-based via joblib) ----------
    def run_parallel(
        self,
        pipeline: Pipeline,
        systems: Iterable[SimSystem],
        *,
        max_workers: Optional[int] = None,
        description: str = "",
        batch_size: str | int = "auto",
        verbose: int = 0,
        prefer: str = "processes",
        backend: Optional[str] = None,
    ) -> Dict[str, Mapping[str, ExecResult]]:
        """Execute ``pipeline`` for multiple systems in parallel.

        Parameters
        ----------
        pipeline :
            Pipeline object providing the sequence of steps to execute.
        systems : Iterable[SimSystem]
            Collection of systems to process. Consumed once into a list.
        max_workers : int, optional
            Override the configured worker cap; ``None`` falls back to the
            value provided at construction time. ``0`` or ``1`` runs the
            systems one after another in the current process.
        description : str, optional
            Human-readable label used in debug logging.
        batch_size, verbose, prefer, backend :
            Joblib configuration knobs forwarded to :class:`joblib.Parallel`.

        Returns
        -------
        dict
            Mapping of ``system.name`` to per-step results. Empty when
            ``systems`` is empty.

        Raises
        ------
        RuntimeError
            When one or more systems fail; chained from the first recorded
            failure.
        KeyboardInterrupt, SystemExit
            Re-raised immediately when they occur during serial execution, so
            an interrupt stops the run instead of moving on to the next
            system.
        """
        systems = list(systems)
        if not systems:
            return {}

        worker_cap = max_workers if max_workers is not None else self._max_workers

        out: Dict[str, Mapping[str, ExecResult]] = {}
        errors: Dict[str, BaseException] = {}

        if worker_cap in (0, 1):
            logger.debug(
                "LOCAL(parallel): running serially for {} system(s) (max_workers={}) — {}",
                len(systems),
                worker_cap,
                description,
            )
            for system in systems:
                try:
                    out[system.name] = pipeline.run(self, system)
                except (KeyboardInterrupt, SystemExit):
                    # An interrupt or exit request must stop the whole run,
                    # not be recorded as one failed system.
                    raise
                except BaseException as exc:  # pragma: no cover - passthrough
                    errors[system.name] = exc
            self._raise_on_failures("LOCAL(parallel-serial)", errors)
            return out

        # Never ask for more workers than there are systems to process.
        if worker_cap is None:
            worker_cap = min(len(systems), os.cpu_count() or 1)
        else:
            worker_cap = min(worker_cap, len(systems))

        # NOTE: the "loky" label below is fixed text; the backend joblib
        # actually uses is governed by the ``backend`` / ``prefer`` arguments.
        logger.debug(
            "LOCAL(parallel): joblib(loky) with n_jobs={} for {} system(s) — {}",
            worker_cap,
            len(systems),
            description,
        )

        results: List[
            Tuple[str, Mapping[str, ExecResult] | None, BaseException | None]
        ] = Parallel(
            n_jobs=worker_cap,
            backend=backend,
            prefer=prefer,
            batch_size=batch_size,
            verbose=verbose,
        )(delayed(_run_pipeline_task)(pipeline, self, system) for system in systems)

        for name, res, err in results:
            if err is None and res is not None:
                out[name] = res
                logger.debug("LOCAL(parallel): finished {}", name)
            else:
                # A task that returned neither results nor an exception is
                # still reported as a failure.
                errors[name] = (
                    err if err is not None else RuntimeError("Unknown error")
                )

        self._raise_on_failures("LOCAL(parallel)", errors)
        return out