"""Execution backend for running pipelines locally."""
from __future__ import annotations
import os
from dataclasses import dataclass, field
from typing import Callable, Dict, Iterable, List, Mapping, Optional, Tuple
from loguru import logger
from joblib import Parallel, delayed
from batter.exec.base import ExecBackend
from batter.pipeline.pipeline import Pipeline
from batter.pipeline.step import ExecResult, Step
from batter.systems.core import SimSystem
# Signature every registered step handler must satisfy:
# ``handler(step, system, params) -> ExecResult``.
Handler = Callable[
    [Step, SimSystem, Mapping],
    ExecResult,
]
def _run_pipeline_task(
    pipeline: Pipeline,
    backend: "LocalBackend",
    sys: SimSystem,
) -> Tuple[str, Mapping[str, ExecResult] | None, BaseException | None]:
    """Run ``pipeline`` against one system and package the outcome.

    Any exception (including ``BaseException`` subclasses) is captured and
    returned instead of raised, so a joblib worker always hands a plain tuple
    back to the parent process.

    Parameters
    ----------
    pipeline :
        Pipeline whose ``run(backend, system)`` method is invoked.
    backend :
        Backend passed through to ``pipeline.run`` to dispatch the steps.
    sys :
        Simulation system to process; its ``name`` labels the outcome.

    Returns
    -------
    tuple of (str, Mapping[str, ExecResult] or None, BaseException or None)
        ``(name, results, None)`` on success, ``(name, None, exc)`` on failure.
    """
    try:
        step_results = pipeline.run(backend, sys)
    except BaseException as failure:  # pragma: no cover - propagated to parent
        return sys.name, None, failure
    return sys.name, step_results, None
@dataclass
class LocalBackend(ExecBackend):
    """In-process execution backend with optional parallel orchestration.

    Step execution is delegated to handlers registered with :meth:`register`;
    a step without a registered handler is treated as a no-op.

    Parameters
    ----------
    max_workers : int, optional
        Maximum number of worker processes to use when :meth:`run_parallel`
        is invoked. ``None`` lets the backend auto-detect resources; ``0`` or
        ``1`` forces serial execution.
    """

    name: str = "local"
    _handlers: Dict[str, Handler] = field(default_factory=dict)
    _max_workers: Optional[int] = None

    def __init__(self, max_workers: Optional[int] = None):
        # ``@dataclass`` keeps an explicitly defined ``__init__`` rather than
        # generating its own, so this constructor is the one callers get.
        # NOTE(review): ``object.__setattr__`` presumably bypasses a frozen or
        # guarded ``__setattr__`` on ``ExecBackend`` -- confirm against the base.
        object.__setattr__(self, "name", "local")
        object.__setattr__(self, "_handlers", {})
        object.__setattr__(self, "_max_workers", max_workers)

    # ---------- registry ----------

    def register(self, step_name: str, handler: Handler) -> None:
        """Register a callable to execute ``step_name``.

        Registering the same ``step_name`` again replaces the previous handler.

        Parameters
        ----------
        step_name : str
            Identifier of the step (matches :class:`batter.pipeline.step.Step.name`).
        handler : Callable[[Step, SimSystem, Mapping], ExecResult]
            Function responsible for executing the step.
        """
        self._handlers[step_name] = handler

    # ---------- ExecBackend ----------

    def run(self, step: Step, system: SimSystem, params: Mapping) -> ExecResult:
        """Execute ``step`` for ``system`` on the local machine.

        Parameters
        ----------
        step :
            Pipeline step metadata; ``step.name`` selects the handler.
        system :
            Simulation system descriptor.
        params :
            Step parameters, typically generated by the orchestration layer.

        Returns
        -------
        ExecResult
            The handler's result, or an empty ``ExecResult`` when no handler
            is registered for ``step.name``.
        """
        handler = self._handlers.get(step.name)
        if handler is None:
            logger.debug("LOCAL: no handler for step {!r}; treating as no-op.", step.name)
            return ExecResult(job_ids=[], artifacts={})
        logger.debug("LOCAL: executing step {!r}", step.name)
        return handler(step, system, params)

    # ---------- parallel pipeline runner (process-based via joblib) ----------

    def run_parallel(
        self,
        pipeline: Pipeline,
        systems: Iterable[SimSystem],
        *,
        max_workers: Optional[int] = None,
        description: str = "",
        batch_size: str | int = "auto",
        verbose: int = 10,
        prefer: str = "processes",
        backend: Optional[str] = None,
    ) -> Dict[str, Mapping[str, ExecResult]]:
        """Execute ``pipeline`` for multiple systems in parallel.

        When the effective worker cap is ``0`` or ``1`` the systems are run
        serially in this process; otherwise they are dispatched through
        :class:`joblib.Parallel`.

        Parameters
        ----------
        pipeline :
            Pipeline object providing the sequence of steps to execute.
        systems : Iterable[SimSystem]
            Collection of systems to process.
        max_workers : int, optional
            Override the configured worker cap; ``None`` falls back to the
            value provided at construction time.
        description : str, optional
            Human-readable label used in debug logging.
        batch_size, verbose, prefer, backend :
            Joblib configuration knobs forwarded to :class:`joblib.Parallel`.

        Returns
        -------
        dict
            Mapping of ``system.name`` to per-step results.

        Raises
        ------
        RuntimeError
            When one or more systems fail with an ``Exception``. The first
            failure is chained as the cause, and every failure is logged.
        KeyboardInterrupt, SystemExit
            Propagated unchanged rather than being wrapped in ``RuntimeError``.
        """
        systems = list(systems)
        if not systems:
            return {}

        names = [system.name for system in systems]
        if len(set(names)) != len(names):
            # Results are keyed by name, so repeated names would silently
            # overwrite each other in the returned mapping.
            logger.warning(
                "LOCAL(parallel): duplicate system names detected; results for "
                "repeated names will overwrite each other — {}",
                description,
            )

        worker_cap = max_workers if max_workers is not None else self._max_workers
        if worker_cap in (0, 1):
            return self._run_serial(pipeline, systems, worker_cap, description)

        if worker_cap is None:
            worker_cap = min(len(systems), os.cpu_count() or 1)
        else:
            # Never request more workers than there are systems; negative values
            # stay negative and keep joblib's "relative to CPU count" meaning.
            worker_cap = min(worker_cap, len(systems))

        logger.debug(
            "LOCAL(parallel): joblib(backend={}, prefer={}) with n_jobs={} for {} system(s) — {}",
            backend,
            prefer,
            worker_cap,
            len(systems),
            description,
        )
        results: List[Tuple[str, Mapping[str, ExecResult] | None, BaseException | None]] = Parallel(
            n_jobs=worker_cap,
            backend=backend,
            prefer=prefer,
            batch_size=batch_size,
            verbose=verbose,
            max_nbytes=None,  # disable joblib's automatic memmapping of large args
        )(
            delayed(_run_pipeline_task)(pipeline, self, system)
            for system in systems
        )

        out: Dict[str, Mapping[str, ExecResult]] = {}
        errors: Dict[str, BaseException] = {}
        for name, res, err in results:
            if err is None and res is not None:
                out[name] = res
                logger.debug("LOCAL(parallel): finished {}", name)
            elif err is not None and not isinstance(err, Exception):
                # KeyboardInterrupt / SystemExit captured by a worker must not be
                # downgraded to a RuntimeError; surface it unchanged.
                raise err
            else:
                errors[name] = err if err is not None else RuntimeError("Unknown error")
        self._raise_on_failures(errors, "LOCAL(parallel)")
        return out

    def _run_serial(
        self,
        pipeline: Pipeline,
        systems: List[SimSystem],
        worker_cap: Optional[int],
        description: str,
    ) -> Dict[str, Mapping[str, ExecResult]]:
        """Run ``pipeline`` for each system in this process, one after another."""
        logger.debug(
            "LOCAL(parallel): running serially for {} system(s) (max_workers={}) — {}",
            len(systems),
            worker_cap,
            description,
        )
        out: Dict[str, Mapping[str, ExecResult]] = {}
        errors: Dict[str, BaseException] = {}
        for system in systems:
            try:
                out[system.name] = pipeline.run(self, system)
            except Exception as exc:  # KeyboardInterrupt/SystemExit propagate
                errors[system.name] = exc
        self._raise_on_failures(errors, "LOCAL(parallel-serial)")
        return out

    @staticmethod
    def _raise_on_failures(errors: Mapping[str, BaseException], label: str) -> None:
        """Log each failure and raise one ``RuntimeError`` naming all failed systems.

        Does nothing when ``errors`` is empty. Every failure is logged with its
        exception (and traceback, when one is attached) so none are lost; the
        first failure is chained as the ``RuntimeError``'s cause.
        """
        if not errors:
            return
        for name, exc in errors.items():
            logger.opt(exception=exc).error("{}: system {!r} failed", label, name)
        failed = ", ".join(errors)
        logger.error("{}: {} system(s) failed: {}", label, len(errors), failed)
        raise RuntimeError(
            f"{label}: failures encountered for {failed}"
        ) from next(iter(errors.values()))