Source code for batter.pipeline.payloads

from __future__ import annotations

from pathlib import Path
from typing import Any, Dict, Mapping, Optional

from pydantic import BaseModel, ConfigDict, Field, model_validator

from batter.config.simulation import SimulationConfig


[docs] class SystemParams(BaseModel): """ System-level inputs shared by multiple pipeline steps. This wrapper normalises common fields (paths, anchor atoms, etc.) while still allowing arbitrary extra keys. Paths are converted to :class:`pathlib.Path` instances, making downstream usage safer. Parameters ---------- param_outdir : Path, optional Directory where ligand parameter outputs should be written. system_name : str, optional Logical system name propagated to child steps. protein_input, system_input, system_coordinate : Path, optional Paths to the protein topology/coordinate inputs if supplied. ligand_paths : dict[str, Path] Mapping of ligand identifiers to staged files. yaml_dir : Path, optional Directory containing the originating YAML (useful for resolving relatives). anchor_atoms : tuple[str, ...] Anchor atom labels used for restraint placement. extra_restraints : str, optional Optional positional restraint selection string. extra_restraint_fc : float, optional Force constant (kcal/mol/Å^2) applied to ``extra_restraints``. extra_conformation_restraints : Path, optional Path to a conformational restraint JSON file. """ model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) param_outdir: Optional[Path] = None system_name: Optional[str] = None protein_input: Optional[Path] = None system_input: Optional[Path] = None system_coordinate: Optional[Path] = None ligand_paths: Dict[str, Path] = Field(default_factory=dict) yaml_dir: Optional[Path] = None anchor_atoms: tuple[str, ...] = () extra_restraints: Optional[str] = None extra_restraint_fc: Optional[float] = None extra_conformation_restraints: Optional[Path] = None @model_validator(mode="before") @classmethod def _coerce(cls, value: Any) -> Any: if isinstance(value, SystemParams): return value.to_mapping() if isinstance(value, Mapping): return dict(value) raise TypeError(f"Cannot construct SystemParams from {type(value)!r}") def __getitem__(self, item: str) -> Any: """ Return a field or extra value by key. Parameters ---------- item : str Key to fetch. Returns ------- Any Stored value for the key. Raises ------ KeyError If the key is not present. """ if item in type(self).model_fields: return getattr(self, item) if self.model_extra is not None and item in self.model_extra: return self.model_extra[item] raise KeyError(item)
[docs] def get(self, item: str, default: Any = None) -> Any: """ Safe lookup for a field or extra value with a default. Parameters ---------- item : str Key to fetch. default : Any, optional Value returned when the key is missing or None. Returns ------- Any Requested value or the default. """ if item in type(self).model_fields: value = getattr(self, item) return default if value is None else value if self.model_extra is not None: return self.model_extra.get(item, default) return default
[docs] def to_mapping(self) -> Dict[str, Any]: """ Convert the model (including extras) to a plain dictionary. Returns ------- dict[str, Any] Merged view of standard fields and extras. """ data = self.model_dump() if self.model_extra: data.update(self.model_extra) return data
[docs] def copy_with(self, **updates: Any) -> "SystemParams": """ Create a new :class:`SystemParams` with additional updates. Parameters ---------- **updates Keyword overrides applied atop the existing data. Returns ------- SystemParams A new instance incorporating the updates. """ data = self.to_mapping() data.update(updates) return SystemParams(**data)
[docs] class StepPayload(BaseModel): """ Typed payload passed to pipeline step handlers. The payload binds the :class:`~batter.config.simulation.SimulationConfig` and :class:`SystemParams` objects used by most handlers while permitting arbitrary extra values for backwards compatibility or specialised needs. Parameters ---------- sim : SimulationConfig, optional Resolved simulation configuration for the step. sys_params : SystemParams, optional Shared system-level parameters. """ model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) sim: Optional[SimulationConfig] = None sys_params: Optional[SystemParams] = None @model_validator(mode="before") @classmethod def _coerce(cls, value: Any) -> Any: if isinstance(value, StepPayload): return value.to_mapping() if isinstance(value, Mapping): return dict(value) raise TypeError(f"Cannot construct StepPayload from {type(value)!r}") @model_validator(mode="after") def _coerce_nested(self) -> "StepPayload": if self.sys_params is not None and not isinstance(self.sys_params, SystemParams): object.__setattr__(self, "sys_params", SystemParams(self.sys_params)) if self.sim is not None and not isinstance(self.sim, SimulationConfig): object.__setattr__(self, "sim", SimulationConfig.model_validate(self.sim)) return self def __getitem__(self, item: str) -> Any: """ Return a stored value by key, searching typed fields first. Parameters ---------- item : str Key to fetch. Returns ------- Any Stored value. Raises ------ KeyError If the key is not present. """ if item in type(self).model_fields: return getattr(self, item) if self.model_extra is not None and item in self.model_extra: return self.model_extra[item] raise KeyError(item)
[docs] def get(self, item: str, default: Any = None) -> Any: """ Safe lookup for a payload value with a default. Parameters ---------- item : str Key to fetch. default : Any, optional Value returned when the key is missing or None. Returns ------- Any Requested value or the default. """ if item in type(self).model_fields: value = getattr(self, item) return default if value is None else value if self.model_extra is not None: return self.model_extra.get(item, default) return default
[docs] def to_mapping(self) -> Dict[str, Any]: """ Convert the payload (including extras) to a plain dictionary. Returns ------- dict[str, Any] Merged representation of fields and extras. """ data = self.model_dump() if self.model_extra: data.update(self.model_extra) return data
[docs] def copy_with(self, **updates: Any) -> "StepPayload": """ Create a new :class:`StepPayload` with additional updates. Parameters ---------- **updates Keyword overrides applied to the current payload. Returns ------- StepPayload New payload containing the merged data. """ data = self.to_mapping() data.update(updates) return StepPayload(**data)