"""Execution backend for running pipelines locally."""
from __future__ import annotations
import os
import traceback
from dataclasses import dataclass, field
from typing import Callable, Dict, Iterable, List, Mapping, Optional, Tuple
from loguru import logger
from joblib import Parallel, delayed
from batter.exec.base import ExecBackend
from batter.pipeline.pipeline import Pipeline
from batter.pipeline.step import ExecResult, Step
from batter.systems.core import SimSystem
# Signature of a step handler: it is invoked as ``handler(step, system, params)``
# and returns the ``ExecResult`` for that step.
Handler = Callable[[Step, SimSystem, Mapping], ExecResult]
def _run_pipeline_task(
pipeline: Pipeline,
backend: "LocalBackend",
sys: SimSystem,
) -> Tuple[str, Mapping[str, ExecResult] | None, BaseException | None, str | None]:
"""Execute ``pipeline`` for a single system.
Parameters
----------
pipeline :
Pipeline instance to execute.
backend :
Backend used to dispatch individual steps.
sys :
Simulation system descriptor.
Returns
-------
tuple of (str, Mapping[str, ExecResult] or None, BaseException or None, str or None)
Tuple containing the system name, the step results if successful, the
raised exception otherwise, and formatted traceback text captured inside
the worker process. The structure is joblib-friendly.
"""
try:
results = pipeline.run(backend, sys)
return sys.name, results, None, None
except BaseException as exc: # pragma: no cover - propagated to parent
return sys.name, None, exc, traceback.format_exc()
def _format_failure_detail(exc: BaseException, tb_text: str | None = None) -> str:
"""Return useful failure detail for parent-process logging."""
text = (tb_text or "").strip()
if text:
return text
return "".join(traceback.format_exception_only(type(exc), exc)).strip()
def _failure_summary_line(exc: BaseException, tb_text: str | None = None) -> str:
    """Condense a failure into a single line.

    Parameters
    ----------
    exc :
        The exception that was raised.
    tb_text :
        Previously captured traceback text, if any.

    Returns
    -------
    str
        The last non-blank line (stripped) of the failure detail, or
        ``repr(exc)`` when the detail has no non-blank line.
    """
    detail = _format_failure_detail(exc, tb_text)
    # Walk the lines bottom-up, stripped, and keep the first non-empty one.
    stripped_bottom_up = map(str.strip, reversed(detail.splitlines()))
    last_nonblank = next(filter(None, stripped_bottom_up), None)
    if last_nonblank is None:
        return repr(exc)
    return last_nonblank
def _log_failures(
    prefix: str,
    errors: Mapping[str, BaseException],
    tracebacks: Mapping[str, str | None],
) -> None:
    """Emit error-level log records describing every failed system.

    One overview record lists how many systems failed and their names; it is
    followed by one record per system carrying the failure detail.

    Parameters
    ----------
    prefix :
        Label placed at the start of every log message.
    errors :
        Exception raised for each failed system, keyed by system name.
    tracebacks :
        Captured traceback text keyed by system name; entries may be missing
        or ``None``.
    """
    failed_names = ", ".join(errors.keys())
    logger.error("{}: {} system(s) failed: {}", prefix, len(errors), failed_names)

    for system_name, failure in errors.items():
        detail = _format_failure_detail(failure, tracebacks.get(system_name))
        logger.error("{}: failure details for {}\n{}", prefix, system_name, detail)
def _parallel_failure_message(
    prefix: str,
    errors: Mapping[str, BaseException],
    tracebacks: Mapping[str, str | None],
) -> str:
    """Compose the multi-line text used for the aggregated failure exception.

    Parameters
    ----------
    prefix :
        Label placed at the start of the headline.
    errors :
        Exception raised for each failed system, keyed by system name.
    tracebacks :
        Captured traceback text keyed by system name; entries may be missing
        or ``None``.

    Returns
    -------
    str
        A headline naming the failed systems, a ``Failure summaries:`` line,
        then one ``- name: summary`` bullet per failed system.
    """
    headline = f"{prefix}: failures encountered for {', '.join(errors.keys())}"
    bullets = [
        f"- {system_name}: {_failure_summary_line(failure, tracebacks.get(system_name))}"
        for system_name, failure in errors.items()
    ]
    return "\n".join([headline, "Failure summaries:", *bullets])
@dataclass
class LocalBackend(ExecBackend):
    """In-process execution backend with optional parallel orchestration.

    Steps are dispatched to handlers registered through :meth:`register`;
    a step without a registered handler is treated as a no-op.

    Parameters
    ----------
    max_workers : int, optional
        Maximum number of worker processes to use when :meth:`run_parallel`
        is invoked. ``None`` lets the backend auto-detect resources; ``0`` or
        ``1`` forces serial execution.
    """

    name: str = "local"
    _handlers: Dict[str, Handler] = field(default_factory=dict)
    _max_workers: Optional[int] = None

    def __init__(self, max_workers: Optional[int] = None):
        # ``@dataclass`` leaves an ``__init__`` defined in the class body in
        # place, so this is the constructor that actually runs.
        # NOTE(review): attributes are assigned through ``object.__setattr__``,
        # presumably because the base class restricts normal attribute
        # assignment - confirm against ``ExecBackend``.
        object.__setattr__(self, "name", "local")
        object.__setattr__(self, "_handlers", {})
        object.__setattr__(self, "_max_workers", max_workers)

    # ---------- registry ----------
    def register(self, step_name: str, handler: Handler) -> None:
        """Register a callable to execute ``step_name``.

        Registering the same ``step_name`` again replaces the previous handler.

        Parameters
        ----------
        step_name : str
            Identifier of the step (matches :class:`batter.pipeline.step.Step.name`).
        handler : Callable[[Step, SimSystem, Mapping], ExecResult]
            Function responsible for executing the step.
        """
        self._handlers[step_name] = handler

    # ---------- ExecBackend ----------
    def run(self, step: Step, system: SimSystem, params: Mapping) -> ExecResult:
        """Execute ``step`` for ``system`` on the local machine.

        Parameters
        ----------
        step :
            Pipeline step metadata; ``step.name`` selects the handler.
        system :
            Simulation system descriptor.
        params :
            Step parameters, passed to the handler unchanged.

        Returns
        -------
        ExecResult
            Whatever the registered handler returns, or an empty result
            (no job ids, no artifacts) when no handler is registered.
        """
        handler = self._handlers.get(step.name)
        if handler is None:
            logger.debug("LOCAL: no handler for step {!r}; treating as no-op.", step.name)
            return ExecResult(job_ids=[], artifacts={})
        logger.debug("LOCAL: executing step {!r}", step.name)
        return handler(step, system, params)

    # ---------- failure aggregation ----------
    @staticmethod
    def _raise_on_failures(
        prefix: str,
        errors: Mapping[str, BaseException],
        traces: Mapping[str, str | None],
    ) -> None:
        """Log every failure and raise one ``RuntimeError``; no-op if ``errors`` is empty."""
        if not errors:
            return
        _log_failures(prefix, errors, traces)
        # Chain from the first recorded failure so its traceback stays reachable.
        raise RuntimeError(
            _parallel_failure_message(prefix, errors, traces)
        ) from next(iter(errors.values()))

    # ---------- parallel pipeline runner (process-based via joblib) ----------
    def run_parallel(
        self,
        pipeline: Pipeline,
        systems: Iterable[SimSystem],
        *,
        max_workers: Optional[int] = None,
        description: str = "",
        batch_size: str | int = "auto",
        verbose: int = 10,
        prefer: str = "processes",
        backend: Optional[str] = None,
    ) -> Dict[str, Mapping[str, ExecResult]]:
        """Execute ``pipeline`` for multiple systems in parallel.

        Every system is attempted even if an earlier one fails; failures are
        collected, logged, and reported together at the end.

        Parameters
        ----------
        pipeline :
            Pipeline object providing the sequence of steps to execute.
        systems : Iterable[SimSystem]
            Collection of systems to process. An empty collection returns
            ``{}`` immediately.
        max_workers : int, optional
            Override the configured worker cap; ``None`` falls back to the
            value provided at construction time. ``0`` or ``1`` runs the
            systems one after another in the current process.
        description : str, optional
            Human-readable label used in debug logging.
        batch_size, verbose, prefer, backend :
            Joblib configuration knobs forwarded to :class:`joblib.Parallel`.

        Returns
        -------
        dict
            Mapping of ``system.name`` to per-step results.

        Raises
        ------
        RuntimeError
            When one or more systems fail. The message lists a one-line
            summary per failed system and the exception is chained from the
            first recorded failure.
        KeyboardInterrupt
            In serial mode, a user interrupt aborts the run immediately
            instead of being recorded as a per-system failure.
        """
        systems = list(systems)
        if not systems:
            return {}

        worker_cap = max_workers if max_workers is not None else self._max_workers

        out: Dict[str, Mapping[str, ExecResult]] = {}
        errors: Dict[str, BaseException] = {}
        traces: Dict[str, str | None] = {}

        if worker_cap in (0, 1):
            logger.debug(
                "LOCAL(parallel): running serially for {} system(s) (max_workers={}) — {}",
                len(systems),
                worker_cap,
                description,
            )
            for system in systems:
                try:
                    out[system.name] = pipeline.run(self, system)
                except KeyboardInterrupt:
                    # A user abort must stop the whole run; recording it and
                    # moving on would start the next system regardless.
                    raise
                except BaseException as exc:  # pragma: no cover - passthrough
                    errors[system.name] = exc
                    traces[system.name] = traceback.format_exc()
            self._raise_on_failures("LOCAL(parallel-serial)", errors, traces)
            return out

        # Never ask for more workers than there are systems to process.
        if worker_cap is None:
            cpu_count = os.cpu_count() or 1
            worker_cap = min(len(systems), cpu_count)
        else:
            worker_cap = min(worker_cap, len(systems))

        logger.debug(
            "LOCAL(parallel): joblib(loky) with n_jobs={} for {} system(s) — {}",
            worker_cap,
            len(systems),
            description,
        )
        results: List[
            Tuple[str, Mapping[str, ExecResult] | None, BaseException | None, str | None]
        ] = Parallel(
            n_jobs=worker_cap,
            backend=backend,
            prefer=prefer,
            batch_size=batch_size,
            verbose=verbose,
            max_nbytes=None,  # disable joblib's automatic memmapping of large arrays
        )(
            delayed(_run_pipeline_task)(pipeline, self, system)
            for system in systems
        )

        for name, res, err, tb_text in results:
            if err is None and res is not None:
                out[name] = res
                logger.debug("LOCAL(parallel): finished {}", name)
            else:
                # A task that reports neither results nor an exception is
                # still treated as a failure.
                errors[name] = err if err is not None else RuntimeError("Unknown error")
                traces[name] = tb_text
        self._raise_on_failures("LOCAL(parallel)", errors, traces)
        return out