# Source code for batter.exec.local

"""Execution backend for running pipelines locally."""

from __future__ import annotations

import os
import traceback
from dataclasses import dataclass, field
from typing import Callable, Dict, Iterable, List, Mapping, Optional, Tuple

from loguru import logger
from joblib import Parallel, delayed

from batter.exec.base import ExecBackend
from batter.pipeline.pipeline import Pipeline
from batter.pipeline.step import ExecResult, Step
from batter.systems.core import SimSystem

# Callable signature for a registered step handler:
# ``handler(step, system, params) -> ExecResult``.
Handler = Callable[
    [Step, SimSystem, Mapping],
    ExecResult,
]


def _run_pipeline_task(
    pipeline: Pipeline,
    backend: "LocalBackend",
    sys: SimSystem,
) -> Tuple[str, Mapping[str, ExecResult] | None, BaseException | None, str | None]:
    """Execute ``pipeline`` for a single system.

    Parameters
    ----------
    pipeline :
        Pipeline instance to execute.
    backend :
        Backend used to dispatch individual steps.
    sys :
        Simulation system descriptor.

    Returns
    -------
    tuple of (str, Mapping[str, ExecResult] or None, BaseException or None, str or None)
        Tuple containing the system name, the step results if successful, the
        raised exception otherwise, and formatted traceback text captured inside
        the worker process. The structure is joblib-friendly.
    """
    try:
        results = pipeline.run(backend, sys)
        return sys.name, results, None, None
    except BaseException as exc:  # pragma: no cover - propagated to parent
        return sys.name, None, exc, traceback.format_exc()


def _format_failure_detail(exc: BaseException, tb_text: str | None = None) -> str:
    """Return useful failure detail for parent-process logging."""
    text = (tb_text or "").strip()
    if text:
        return text
    return "".join(traceback.format_exception_only(type(exc), exc)).strip()


def _failure_summary_line(exc: BaseException, tb_text: str | None = None) -> str:
    """Return a compact one-line failure summary."""
    detail = _format_failure_detail(exc, tb_text)
    for line in reversed(detail.splitlines()):
        stripped = line.strip()
        if stripped:
            return stripped
    return repr(exc)


def _log_failures(
    prefix: str,
    errors: Mapping[str, BaseException],
    tracebacks: Mapping[str, str | None],
) -> None:
    """Emit failure diagnostics to the log from the parent process.

    Writes one error record listing every failed system name. It then writes
    one record per failed system containing the captured traceback or, when
    none was captured, the exception's type and message.
    """
    failed_names = ", ".join(errors.keys())
    logger.error("{}: {} system(s) failed: {}", prefix, len(errors), failed_names)
    for system_name, error in errors.items():
        detail = _format_failure_detail(error, tracebacks.get(system_name))
        logger.error("{}: failure details for {}\n{}", prefix, system_name, detail)


def _parallel_failure_message(
    prefix: str,
    errors: Mapping[str, BaseException],
    tracebacks: Mapping[str, str | None],
) -> str:
    """Compose the message of the ``RuntimeError`` raised after failed runs.

    The first line names every failed system. A ``Failure summaries:`` line
    follows, then one ``- name: summary`` bullet per system. This keeps the
    root cause visible in the exception text itself, not only in the logs.
    """
    header = f"{prefix}: failures encountered for {', '.join(errors.keys())}"
    bullets = [
        f"- {name}: {_failure_summary_line(exc, tracebacks.get(name))}"
        for name, exc in errors.items()
    ]
    return "\n".join([header, "Failure summaries:", *bullets])


@dataclass
class LocalBackend(ExecBackend):
    """In-process execution backend with optional parallel orchestration.

    Parameters
    ----------
    max_workers : int, optional
        Maximum number of worker processes to use when :meth:`run_parallel` is
        invoked. ``None`` lets the backend auto-detect resources; ``0`` or ``1``
        forces serial execution.
    """

    name: str = "local"
    _handlers: Dict[str, Handler] = field(default_factory=dict)
    _max_workers: Optional[int] = None

    def __init__(self, max_workers: Optional[int] = None):
        # @dataclass keeps an __init__ defined in the class body rather than
        # generating one. The fields above still feed the generated
        # __repr__/__eq__.
        # NOTE(review): object.__setattr__ suggests a base class may guard
        # attribute assignment — confirm against ExecBackend.
        object.__setattr__(self, "name", "local")
        object.__setattr__(self, "_handlers", {})
        object.__setattr__(self, "_max_workers", max_workers)

    # ---------- registry ----------

    def register(self, step_name: str, handler: Handler) -> None:
        """Register a callable to execute ``step_name``.

        Parameters
        ----------
        step_name : str
            Identifier of the step (matches :class:`batter.pipeline.step.Step.name`).
        handler : Callable[[Step, SimSystem, Mapping], ExecResult]
            Function responsible for executing the step. A later registration
            for the same name replaces the earlier one.
        """
        self._handlers[step_name] = handler

    # ---------- ExecBackend ----------

    def run(self, step: Step, system: SimSystem, params: Mapping) -> ExecResult:
        """Execute ``step`` for ``system`` on the local machine.

        Parameters
        ----------
        step :
            Pipeline step metadata.
        system :
            Simulation system descriptor.
        params :
            Step parameters, typically generated by the orchestration layer.

        Returns
        -------
        ExecResult
            The handler's result. A step with no registered handler is treated
            as a no-op and yields an empty ``ExecResult`` (no job ids, no
            artifacts).
        """
        handler = self._handlers.get(step.name)
        if handler is None:
            logger.debug("LOCAL: no handler for step {!r}; treating as no-op.", step.name)
            return ExecResult(job_ids=[], artifacts={})
        logger.debug("LOCAL: executing step {!r}", step.name)
        return handler(step, system, params)

    # ---------- parallel pipeline runner (process-based via joblib) ----------

    def run_parallel(
        self,
        pipeline: Pipeline,
        systems: Iterable[SimSystem],
        *,
        max_workers: Optional[int] = None,
        description: str = "",
        batch_size: str | int = "auto",
        verbose: int = 10,
        prefer: str = "processes",
        backend: Optional[str] = None,
    ) -> Dict[str, Mapping[str, ExecResult]]:
        """Execute ``pipeline`` for multiple systems in parallel.

        Parameters
        ----------
        pipeline :
            Pipeline object providing the sequence of steps to execute.
        systems : Iterable[SimSystem]
            Collection of systems to process.
        max_workers : int, optional
            Override the configured worker cap; ``None`` falls back to the value
            provided at construction time. ``0`` or ``1`` runs the systems
            serially in this process. Other values are capped at the number of
            systems. Negative values are forwarded to joblib unchanged.
        description : str, optional
            Human-readable label used in debug logging.
        batch_size, verbose, prefer, backend :
            Joblib configuration knobs forwarded to :class:`joblib.Parallel`.

        Returns
        -------
        dict
            Mapping of ``system.name`` to per-step results.

        Raises
        ------
        RuntimeError
            When one or more systems fail. All systems are attempted first. The
            message summarises every failure, and the first failure is chained
            as ``__cause__``.
        KeyboardInterrupt, SystemExit
            In serial mode these propagate immediately instead of being
            recorded as a per-system failure.
        """
        systems = list(systems)
        if not systems:
            return {}

        worker_cap = max_workers if max_workers is not None else self._max_workers
        if worker_cap in (0, 1):
            return self._run_serial(pipeline, systems, worker_cap, description)

        if worker_cap is None:
            worker_cap = min(len(systems), os.cpu_count() or 1)
        else:
            # Cap at the number of systems; negative joblib values pass through.
            worker_cap = min(worker_cap, len(systems))

        logger.debug(
            "LOCAL(parallel): joblib(loky) with n_jobs={} for {} system(s) — {}",
            worker_cap,
            len(systems),
            description,
        )
        results: List[
            Tuple[str, Mapping[str, ExecResult] | None, BaseException | None, str | None]
        ] = Parallel(
            n_jobs=worker_cap,
            backend=backend,
            prefer=prefer,
            batch_size=batch_size,
            verbose=verbose,
            # Disable joblib's automatic memmapping of large array arguments.
            max_nbytes=None,
        )(delayed(_run_pipeline_task)(pipeline, self, system) for system in systems)

        out: Dict[str, Mapping[str, ExecResult]] = {}
        errors: Dict[str, BaseException] = {}
        traces: Dict[str, str | None] = {}
        for name, res, err, tb_text in results:
            if err is None and res is not None:
                out[name] = res
                logger.debug("LOCAL(parallel): finished {}", name)
            else:
                # A task that produced neither results nor an exception is
                # still reported as a failure.
                errors[name] = err or RuntimeError("Unknown error")
                traces[name] = tb_text
        self._raise_on_failures("LOCAL(parallel)", errors, traces)
        return out

    def _run_serial(
        self,
        pipeline: Pipeline,
        systems: List[SimSystem],
        worker_cap: int,
        description: str,
    ) -> Dict[str, Mapping[str, ExecResult]]:
        """Run ``pipeline`` for each system in turn inside this process.

        Per-system failures are collected so the remaining systems still run.
        ``KeyboardInterrupt`` and ``SystemExit`` propagate immediately.
        """
        logger.debug(
            "LOCAL(parallel): running serially for {} system(s) (max_workers={}) — {}",
            len(systems),
            worker_cap,
            description,
        )
        out: Dict[str, Mapping[str, ExecResult]] = {}
        errors: Dict[str, BaseException] = {}
        traces: Dict[str, str | None] = {}
        for system in systems:
            try:
                out[system.name] = pipeline.run(self, system)
            except (KeyboardInterrupt, SystemExit):
                # Honour Ctrl-C / explicit exits instead of recording them and
                # moving on to the next (potentially long-running) system.
                raise
            except BaseException as exc:  # pragma: no cover - passthrough
                errors[system.name] = exc
                traces[system.name] = traceback.format_exc()
        self._raise_on_failures("LOCAL(parallel-serial)", errors, traces)
        return out

    @staticmethod
    def _raise_on_failures(
        prefix: str,
        errors: Mapping[str, BaseException],
        traces: Mapping[str, str | None],
    ) -> None:
        """Log and raise ``RuntimeError`` if any system failed; otherwise no-op.

        The first recorded exception is chained as the ``__cause__``.
        """
        if not errors:
            return
        _log_failures(prefix, errors, traces)
        raise RuntimeError(
            _parallel_failure_message(prefix, errors, traces)
        ) from next(iter(errors.values()))