"""Pipeline execution for ``batter.pipeline.pipeline``: run :class:`Step` objects in dependency (topological) order."""

from __future__ import annotations

from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Any, Dict, List

from .step import Step, ExecResult


__all__ = ["Pipeline", "PipelineState"]


[docs] @dataclass(slots=True) class PipelineState: """ In-memory state of a pipeline execution. Attributes ---------- results : dict[str, ExecResult] Per-step execution results. """ results: Dict[str, ExecResult] = field(default_factory=dict)
class Pipeline:
    """
    Directed acyclic pipeline of :class:`Step` objects.

    Parameters
    ----------
    steps : list[Step]
        Steps that form a DAG. Dependencies are given by ``Step.requires``.
        Any iterable is accepted: it is copied into a list once, because the
        steps are traversed several times during construction (a generator
        would otherwise be exhausted after the first pass).

    Raises
    ------
    ValueError
        If two steps share a name, a step requires an unknown step name, or
        the dependencies contain a cycle.

    Notes
    -----
    - A simple **topological sort** (Kahn's algorithm) is performed once, at
      construction time; for a given input order the result is deterministic.
    - Backends must implement a ``run(step, system, params) -> ExecResult``
      method; :meth:`run` passes ``step.params`` as the third argument.
    """

    def __init__(self, steps: List[Step]) -> None:
        # Materialise once so generators / later caller-side mutation cannot
        # desynchronise ``self.steps`` from the computed execution order.
        self.steps: List[Step] = list(steps)
        self._step_by_name: Dict[str, Step] = {step.name: step for step in self.steps}
        # Filled in by _toposort: step name -> names of steps that require it.
        self._graph: Dict[str, List[str]] = {}
        self._validate_unique_names()
        self._order: List[Step] = self._toposort()

    # ------------------- public -------------------

    def run(self, backend, system) -> Dict[str, ExecResult]:
        """
        Execute steps in topological order.

        Parameters
        ----------
        backend
            Object providing ``run(step, system, params) -> ExecResult``; it is
            called once per step as ``backend.run(step, system, step.params)``.
        system
            The :class:`~batter.systems.core.SimSystem` descriptor, forwarded
            unchanged to every ``backend.run`` call.

        Returns
        -------
        dict[str, ExecResult]
            Mapping from step name to execution result, in execution order.

        Raises
        ------
        RuntimeError
            If a required dependency has not been produced. This is a
            defensive check: the order computed at construction time places
            every step after all of its requirements.
        """
        state = PipelineState()
        for step in self._order:
            for req in step.requires:
                if req not in state.results:
                    raise RuntimeError(f"Dependency {req!r} missing before running {step.name!r}")
            # NOTE(review): the type-ignore suggests ``Step`` may not declare a
            # ``params`` attribute statically -- confirm it exists at runtime.
            res = backend.run(step, system, step.params)  # type: ignore[attr-defined]
            state.results[step.name] = res
        return state.results

    def ordered_steps(self) -> List[Step]:
        """Return steps in execution order (a new list on every call)."""
        return list(self._order)

    def describe(self) -> List[Dict[str, Any]]:
        """
        Return a serialisable summary of the pipeline.

        Returns
        -------
        list of dict
            One entry per step, in execution order. Each entry contains
            ``name``, ``requires``, and ``payload_type`` keys; ``payload_type``
            is the class name of ``step.payload``, or ``None`` when the step
            has no payload.
        """
        summary: List[Dict[str, Any]] = []
        for step in self._order:
            summary.append(
                {
                    "name": step.name,
                    "requires": list(step.requires),
                    "payload_type": type(step.payload).__name__ if step.payload is not None else None,
                }
            )
        return summary

    def adjacency(self) -> Dict[str, List[str]]:
        """
        Return the adjacency list describing the DAG.

        Returns
        -------
        dict[str, list[str]]
            Mapping of each step to the steps that depend on it. Every step
            is present as a key (steps nothing depends on map to ``[]``).
            The lists are copies, so callers may mutate them freely.
        """
        return {name: list(children) for name, children in self._graph.items()}

    def dependencies(self, step_name: str) -> List[str]:
        """
        Retrieve the declared dependencies for ``step_name``.

        Parameters
        ----------
        step_name : str
            Step identifier.

        Returns
        -------
        list[str]
            Names of prerequisite steps.

        Raises
        ------
        KeyError
            If ``step_name`` does not exist in the pipeline.
        """
        try:
            return list(self._step_by_name[step_name].requires)
        except KeyError as exc:  # pragma: no cover - defensive branch
            raise KeyError(f"Unknown step: {step_name}") from exc

    # ------------------- internals -------------------

    def _validate_unique_names(self) -> None:
        """Raise ``ValueError`` listing every step name used more than once."""
        # Single pass instead of ``names.count`` per name (which is O(n^2)).
        seen = set()
        dupes = set()
        for step in self.steps:
            if step.name in seen:
                dupes.add(step.name)
            seen.add(step.name)
        if dupes:
            raise ValueError(f"Duplicate step names: {sorted(dupes)}")

    def _toposort(self) -> List[Step]:
        """
        Order steps with Kahn's algorithm and record the adjacency list.

        Side effect: sets ``self._graph`` (step name -> dependent step names,
        with every step present as a key).

        Returns
        -------
        list[Step]
            Steps such that every step appears after all of its requirements.

        Raises
        ------
        ValueError
            If a step requires an unknown step name, or if the graph contains
            a cycle (the message lists the steps that could not be scheduled:
            cycle members and anything downstream of them).
        """
        nodes = self._step_by_name
        # Pre-seed every node so leaves appear in the graph explicitly.
        children: Dict[str, List[str]] = {name: [] for name in nodes}
        indeg: Dict[str, int] = {name: 0 for name in nodes}
        for s in self.steps:
            for r in s.requires:
                if r not in nodes:
                    raise ValueError(f"Unknown dependency {r!r} for step {s.name!r}")
                children[r].append(s.name)
                indeg[s.name] += 1

        # Seed the queue with steps that have no requirements, in input order.
        q = deque(nodes[n] for n, d in indeg.items() if d == 0)
        order: List[Step] = []
        while q:
            u = q.popleft()
            order.append(u)
            for v in children[u.name]:
                indeg[v] -= 1
                if indeg[v] == 0:
                    q.append(nodes[v])

        if len(order) != len(nodes):
            # Any node whose in-degree never reached zero is blocked by a cycle.
            stuck = sorted(name for name, d in indeg.items() if d > 0)
            raise ValueError(f"Cycle detected in pipeline dependencies: {stuck}")

        # Freeze graph for later introspection via adjacency().
        self._graph = children
        return order