from __future__ import annotations
from pathlib import Path
import json
import re
from typing import Any, Dict, Optional, Literal, List, Mapping, Iterable, Tuple
from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator
from batter.config.simulation import PROTOCOL_TO_FE_TYPE, SimulationConfig
from batter.config.remd import RemdArgs
from batter.config.utils import (
coerce_yes_no,
expand_env_vars,
normalize_optional_path,
sanitize_user_ligand_name,
)
# ----------------------------- SLURM ---------------------------------
[docs]
class SlurmConfig(BaseModel):
"""
SLURM-specific configuration.
Parameters
----------
partition : str, optional
SLURM partition/queue name.
time : str, optional
Walltime in the ``HH:MM:SS`` format.
nodes : int, optional
Number of nodes to request.
ntasks_per_node : int, optional
Number of tasks per node.
mem_per_cpu : str, optional
Memory per CPU (e.g., ``16G``).
gres : str, optional
Generic resource string (e.g., GPU spec).
account : str, optional
Account to charge for jobs.
qos : str, optional
QoS string if required by the cluster.
constraint : str, optional
Constraint string passed to ``sbatch``.
extra_sbatch : list[str]
Additional arguments appended to the ``sbatch`` submission command.
"""
model_config = ConfigDict(extra="ignore")
partition: Optional[str] = Field(None, description="SLURM partition / queue")
time: Optional[str] = Field(None, description="Walltime, e.g. '04:00:00'")
nodes: Optional[int] = None
ntasks_per_node: Optional[int] = None
mem_per_cpu: Optional[str] = None
gres: Optional[str] = None
account: Optional[str] = None
qos: Optional[str] = None
constraint: Optional[str] = None
extra_sbatch: List[str] = Field(default_factory=list)
[docs]
def to_sbatch_flags(self) -> List[str]:
"""
Produce a flat list of ``sbatch`` command-line flags.
Returns
-------
list of str
Sequence suitable for passing to :func:`subprocess.run`.
"""
flags: List[str] = []
if self.partition:
flags += ["-p", self.partition]
if self.time:
flags += ["-t", self.time]
if self.nodes:
flags += ["-N", str(self.nodes)]
if self.ntasks_per_node:
flags += ["--ntasks-per-node", str(self.ntasks_per_node)]
if self.mem_per_cpu:
flags += ["--mem-per-cpu", self.mem_per_cpu]
if self.gres:
flags += ["--gres", self.gres]
if self.account:
flags += ["--account", self.account]
if self.qos:
flags += ["--qos", self.qos]
if self.constraint:
flags += ["--constraint", self.constraint]
flags += list(self.extra_sbatch or [])
return flags
# ----------------------------- Sections ---------------------------------
[docs]
class CreateArgs(BaseModel):
"""
Inputs for system creation and staging.
Notes
-----
This section mirrors the ``create`` block in the run YAML file.
"""
model_config = ConfigDict(extra="forbid")
system_name: Optional[str] = Field(
"unnamed_system",
description="Logical system name; used to label outputs when not provided.",
)
protein_input: Optional[Path] = Field(
None,
description="Path to the receptor structure (PDB/mmCIF).",
)
system_input: Optional[Path] = Field(
None,
description="Optional pre-built system topology (e.g., PRMTOP).",
)
system_coordinate: Optional[Path] = Field(
None,
description="Optional starting coordinates (e.g., INPCRD/RST7).",
)
protein_align: Optional[str] = Field(
"name CA",
description="Selection string used to align the protein prior to staging.",
)
# Ligand staging
ligand_paths: dict[str, Path | str] = Field(
default_factory=dict,
description="Mapping of ligand identifiers to structure files (relative paths are resolved at runtime).",
)
ligand_input: Optional[Path] = Field(
None,
description="Alternative JSON file describing ligands (dict or list).",
)
# Param settings
ligand_ff: str = Field(
"gaff2",
description="Ligand force field identifier passed to parameterization tools.",
)
retain_lig_prot: bool = Field(
True,
description="Whether to retain ligand protomers generated during staging.",
)
param_method: Literal["amber", "openff"] = Field(
"amber",
description="Parameterization backend to use for ligands.",
)
param_charge: str = Field(
"am1bcc",
description="Charge derivation method for ligands.",
)
param_outdir: Optional[Path] = Field(
None,
description="Optional override for the ligand parameter output directory.",
)
# Environment / anchors
anchor_atoms: list[str] = Field(
default_factory=list,
description="List of anchor atom selections used for restraint placement.",
)
lipid_mol: list[str] = Field(
default_factory=list,
description="Names of lipid molecules present in the system.",
)
other_mol: list[str] = Field(
default_factory=list,
description="Names of non-lipid cofactors or co-binders.",
)
overwrite: bool = Field(
True,
description="If true, overwrite existing artifacts in the staging directory.",
)
# Extra restraints
# position restraints on selected string
extra_restraints: Optional[str] = Field(
None,
description="Optional positional restraint specification string.",
)
extra_restraint_fc: float = Field(
10.0,
description="Force constant (kcal/mol/Å^2) applied to ``extra_restraints``.",
)
# additional conformational restraints file (NFE)
extra_conformation_restraints: Optional[Path] = Field(
None,
description="Path to conformational restraint JSON (used for NFE workflows).",
)
# Box / chemistry basics that are used before FE
receptor_ff: str = Field(
"protein.ff14SB",
description="Protein force-field identifier.",
)
lipid_ff: str = Field(
"lipid21",
description="Lipid force-field identifier.",
)
solv_shell: float = Field(
15.0,
description="Initial solvent shell radius (Å).",
)
cation: str = Field(
"Na+",
description="Cation species for ion placement.",
)
anion: str = Field(
"Cl-",
description="Anion species for ion placement.",
)
ion_conc: float = Field(
0.15,
description="Target salt concentration (M).",
)
neutralize_only: Literal["yes", "no"] = Field(
"no",
description='If ``"yes"``, neutralize the system without adding bulk salt.',
)
water_model: str = Field(
"TIP3P",
description="Water model used for solvation.",
)
l1_range: float = Field(
6.0,
description="Radius (Å) for L1 search when identifying pocket positions.",
)
min_adis: float = Field(
3.0,
description="Minimum anchor-atom distance used during pose selection (Å).",
)
max_adis: float = Field(
7.0,
description="Maximum anchor-atom distance used during pose selection (Å).",
)
@field_validator(
"protein_input",
"system_input",
"system_coordinate",
"param_outdir",
mode="before",
)
@classmethod
def _coerce_opt_paths(cls, v):
return normalize_optional_path(v)
@field_validator("ligand_paths", mode="before")
@classmethod
def _normalize_ligand_paths(cls, v):
"""
Accept:
- dict[str, path-like]
- list/tuple of path-like
- CSV string of paths
Normalize → dict[str, Path] with sanitized names from stems if needed.
"""
if v is None:
return {}
# CSV string
if isinstance(v, str):
items = [p.strip() for p in v.split(",") if p.strip()]
d: dict[str, Path] = {}
for s in items:
p = normalize_optional_path(s)
if p is None:
continue
d[sanitize_user_ligand_name(p.stem)] = p
return d
# mapping
if isinstance(v, Mapping):
out: dict[str, Path] = {}
for k, p in v.items():
path_obj = normalize_optional_path(p)
if path_obj is None:
continue
out[sanitize_user_ligand_name(str(k))] = path_obj
return out
# iterable of paths
if isinstance(v, Iterable):
d: dict[str, Path] = {}
for p in v:
path_obj = normalize_optional_path(p)
if path_obj is None:
continue
d[sanitize_user_ligand_name(path_obj.stem)] = path_obj
return d
raise ValueError(f"Unsupported ligand_paths type: {type(v).__name__}")
@field_validator("neutralize_only", mode="before")
@classmethod
def _coerce_create_yes_no(cls, v):
return coerce_yes_no(v)
@field_validator("ligand_input", mode="before")
@classmethod
def _coerce_ligand_input(cls, v):
return normalize_optional_path(v)
[docs]
def resolve_paths(self, base: Path) -> "CreateArgs":
"""
Return a copy where path fields are absolute relative to ``base``.
"""
updates: dict[str, object] = {}
path_fields = [
"protein_input",
"system_input",
"system_coordinate",
"param_outdir",
"ligand_input",
"extra_conformation_restraints",
]
for name in path_fields:
path_val = getattr(self, name)
if isinstance(path_val, Path) and not path_val.is_absolute():
updates[name] = (base / path_val).resolve()
if self.ligand_paths:
resolved = {}
for key, path_val in self.ligand_paths.items():
resolved[key] = (
(base / path_val).resolve()
if not path_val.is_absolute()
else path_val
)
updates["ligand_paths"] = resolved
return self.model_copy(update=updates)
@model_validator(mode="after")
def _require_ligands(self):
if not self.ligand_paths and not self.ligand_input:
raise ValueError(
"You must provide either `ligand_paths` or `ligand_input`."
)
return self
@model_validator(mode="after")
def _check_extra_restraints(self):
if self.extra_conformation_restraints and self.extra_restraints:
raise ValueError(
"Cannot specify both `extra_conformation_restraints` and `extra_restraints`."
)
if self.extra_conformation_restraints:
p = Path(self.extra_conformation_restraints)
if not p.exists():
raise ValueError(
f"extra_conformation_restraints file does not exist: {p}"
)
# (optional) schema check if you expect JSON:
try:
data = json.loads(p.read_text())
except Exception as e:
raise ValueError(f"Could not parse {p}: {e}")
if not isinstance(data, (list, tuple)) or not all(
isinstance(r, (list, tuple)) for r in data
):
raise ValueError(
"JSON must be a list of rows [dir, res1, res2, cutoff, k]."
)
return self
[docs]
class FESimArgs(BaseModel):
"""
Free-energy simulation knobs loaded from the ``fe_sim`` section.
The fields feed directly into :class:`batter.config.simulation.SimulationConfig`
overrides. ``fe_type`` is resolved internally from ``protocol`` rather than
being set by users.
"""
model_config = ConfigDict(extra="forbid")
@model_validator(mode="before")
@classmethod
def _reject_legacy_knobs(cls, data: Any) -> Any:
if not isinstance(data, Mapping):
return data
if "num_fe_extends" in data:
raise ValueError(
"fe_sim.num_fe_extends is no longer supported; set fe_sim.n_steps (or <comp>_n_steps) to the total production steps."
)
if "analysis_range" in data:
raise ValueError(
"fe_sim.analysis_range is no longer supported; set fe_sim.analysis_start_step to the first step to include in analysis."
)
return data
dec_int: str = Field(
"mbar",
description="Free-energy integration scheme (``mbar`` or ``ti``).",
)
remd: RemdArgs = Field(
default_factory=RemdArgs,
description="Replica-exchange MD controls (nstlim).",
)
rocklin_correction: Literal["yes", "no"] = Field(
"no",
description="Apply Rocklin correction during analysis.",
)
lambdas: List[float] = Field(
default_factory=list,
description="Default lambda schedule when component-specific overrides are not provided.",
)
component_lambdas: Dict[str, List[float]] = Field(
default_factory=dict,
description="Per-component lambda overrides (key = letter).",
)
blocks: int = Field(
0,
description="Number of MBAR blocks to use during analysis.",
)
lig_buffer: float = Field(
15.0,
description="Ligand-specific box buffer (Å) for solvation boxes.",
)
# Restraint forces
lig_distance_force: float = Field(
5.0,
description="Ligand COM distance restraint spring constant (kcal/mol/Å^2).",
)
lig_angle_force: float = Field(
250.0,
description="Ligand angle restraint spring constant (kcal/mol/rad^2).",
)
lig_dihcf_force: float = Field(
0.0,
description="Ligand dihedral restraint spring constant (kcal/mol/rad^2).",
)
rec_com_force: float = Field(
10.0,
description="Protein COM restraint spring constant (kcal/mol/Å^2).",
)
lig_com_force: float = Field(
10.0,
description="Ligand COM restraint spring constant (kcal/mol/Å^2).",
)
# Box padding (used by some builders)
buffer_x: float = Field(20.0, description="Box padding along X (Å).")
buffer_y: float = Field(20.0, description="Box padding along Y (Å).")
buffer_z: float = Field(20.0, description="Box padding along Z (Å).")
eq_steps: int = Field(
1_000_000,
ge=0,
description="Total equilibration steps (entire equilibration run).",
)
n_steps: Dict[str, int] = Field(
default_factory=lambda: {"x": 300_000, "y": 300_000},
description="Total production steps per component (key = letter).",
)
ntpr: int = Field(100, description="Energy print frequency.")
ntwr: int = Field(2_500, description="Restart write frequency.")
ntwe: int = Field(0, description="Energy write frequency (0 disables).")
ntwx: int = Field(25_000, description="Trajectory write frequency.")
cut: float = Field(9.0, description="Nonbonded cutoff (Å).")
gamma_ln: float = Field(1.0, description="Langevin gamma value (ps^-1).")
dt: float = Field(0.004, description="MD timestep (ps).")
hmr: Literal["yes", "no"] = Field(
"no", description="Hydrogen mass repartitioning toggle."
)
enable_mcwat: Literal["yes", "no"] = Field(
"yes",
description="Enable MC water exchange moves during equilibration (1 = on).",
)
temperature: float = Field(298.15, description="Simulation temperature (K).")
barostat: int = Field(2, description="Barostat selection (1=Berendsen, 2=MC).")
unbound_threshold: float = Field(
8.0,
ge=0.0,
description="Distance threshold (Å) used to flag ligands as unbound during equilibration analysis.",
)
analysis_start_step: int = Field(
0,
ge=0,
description="Only analyze FE production steps after this step (per window).",
)
n_bootstraps: int = Field(
0,
ge=0,
description="Number of MBAR bootstrap resamples used during FE analysis.",
)
@field_validator("rocklin_correction", "hmr", "enable_mcwat", mode="before")
@classmethod
def _coerce_fe_yes_no(cls, v):
return coerce_yes_no(v)
@field_validator("remd", mode="before")
@classmethod
def _coerce_remd(cls, v):
if isinstance(v, RemdArgs):
return v
if isinstance(v, Mapping):
return RemdArgs(**v)
if v is None:
return RemdArgs()
raise ValueError(
"fe_sim.remd only accepts REMD timing settings (nstlim); "
"use run.remd to enable or disable REMD submissions."
)
@field_validator("lambdas")
@classmethod
def _validate_lambdas(cls, v: List[float]) -> List[float]:
if not v:
return v
if any(left > right for left, right in zip(v, v[1:])):
raise ValueError("Lambda values must be in ascending order.")
return v
@field_validator(
"lig_distance_force",
"lig_angle_force",
"rec_com_force",
"lig_com_force",
)
@classmethod
def _validate_force_const(cls, value: float) -> float:
if value is None:
return value
if value <= 0.0:
raise ValueError("Force constants must be non-zero and positive.")
return value
@model_validator(mode="before")
@classmethod
def _ingest_component_lambda_fields(cls, data: Any) -> Any:
if not isinstance(data, Mapping):
return data
payload = dict(data)
comp_map = dict(payload.get("component_lambdas") or {})
def _parse_lambda_value(val: Any) -> List[float]:
if val is None:
return []
if isinstance(val, str):
parts = [p for p in re.split(r"[,\s]+", val.strip()) if p]
return [float(p) for p in parts]
if isinstance(val, (list, tuple)):
return [float(v) for v in val]
return [float(val)]
for key in list(payload.keys()):
m = re.match(r"^([a-z])_lambdas$", key)
if not m:
continue
comp = m.group(1)
comp_map.setdefault(comp, _parse_lambda_value(payload.pop(key)))
payload["component_lambdas"] = comp_map
return payload
@model_validator(mode="before")
@classmethod
def _ingest_legacy_step_fields(cls, data: Any) -> Any:
if not isinstance(data, Mapping):
return data
payload = dict(data)
n_steps = dict(payload.get("n_steps") or {})
# Allow legacy 'steps2' while migrating; raise on steps1
legacy_steps2 = dict(payload.pop("steps2", {}) or {})
for k, v in legacy_steps2.items():
n_steps.setdefault(k, v)
for key in list(payload.keys()):
m_n = re.match(r"^([a-z])_n_steps$", key)
if m_n:
comp = m_n.group(1)
val = payload.pop(key)
try:
val = int(val)
except Exception:
pass
n_steps.setdefault(comp, val)
continue
m = re.match(r"^([a-z])_steps([12])$", key)
if not m:
continue
comp, stage = m.groups()
val = payload.pop(key)
try:
val = int(val)
except Exception:
pass
if stage == "1":
raise ValueError(
f"{comp}_steps1 is no longer supported; set {comp}_n_steps to the total production steps."
)
n_steps.setdefault(comp, val)
payload["n_steps"] = n_steps
return payload
[docs]
class MDSimArgs(BaseModel):
"""
Simulation overrides used when ``protocol == "md"``.
These runs reuse the equilibration steps from ABFE but never schedule FE windows,
so only generic MD knobs are required (no lambdas, SDR restraints, etc.).
"""
model_config = ConfigDict(extra="forbid")
@model_validator(mode="before")
@classmethod
def _reject_legacy_knobs(cls, data: Any) -> Any:
if not isinstance(data, Mapping):
return data
if "num_fe_extends" in data:
raise ValueError(
"fe_sim.num_fe_extends is no longer supported; set fe_sim.n_steps (or <comp>_n_steps) to the total production steps."
)
if "analysis_range" in data:
raise ValueError(
"fe_sim.analysis_range is no longer supported; set fe_sim.analysis_start_step to the first step to include in analysis."
)
return data
dt: float = Field(0.004, description="MD timestep (ps).")
temperature: float = Field(298.15, description="Simulation temperature (K).")
eq_steps: int = Field(
100_000,
ge=0,
description="Total equilibration steps (entire equilibration run).",
)
ntpr: int = Field(100, description="Energy print frequency.")
ntwr: int = Field(10_000, description="Restart write frequency.")
ntwe: int = Field(0, description="Energy write frequency (0 disables).")
ntwx: int = Field(25_000, description="Trajectory write frequency.")
cut: float = Field(9.0, description="Nonbonded cutoff (Å).")
gamma_ln: float = Field(1.0, description="Langevin gamma value (ps^-1).")
barostat: int = Field(2, description="Barostat selection (1=Berendsen, 2=MC).")
hmr: Literal["yes", "no"] = Field(
"yes", description="Hydrogen mass repartitioning toggle."
)
enable_mcwat: Literal["yes", "no"] = Field(
"yes",
description="Enable MC water exchange moves during equilibration (1 = on).",
)
@field_validator("hmr", "enable_mcwat", mode="before")
@classmethod
def _coerce_hmr(cls, v):
return coerce_yes_no(v) or "no"
[docs]
class KartografMapperArgs(BaseModel):
"""Kartograf atom mapper option overrides for RBFE."""
model_config = ConfigDict(extra="forbid")
@model_validator(mode="before")
@classmethod
def _reject_hydrogen_mapping_overrides(cls, data: Any) -> Any:
if not isinstance(data, Mapping):
return data
restricted = {
"atom_map_hydrogens",
"map_hydrogens_on_hydrogens_only",
}.intersection(data)
if restricted:
fields = ", ".join(sorted(restricted))
raise ValueError(
f"rbfe.kartograf {fields} are fixed for AMBER compatibility and cannot be overridden."
)
return data
atom_max_distance: float = Field(
0.95,
description="Override KartografAtomMapper atom_max_distance.",
)
map_exact_ring_matches_only: bool = Field(
True,
description="Override KartografAtomMapper map_exact_ring_matches_only.",
)
allow_partial_fused_rings: bool = Field(
True,
description="Override KartografAtomMapper allow_partial_fused_rings.",
)
allow_bond_breaks: bool = Field(
False,
description="Override KartografAtomMapper allow_bond_breaks.",
)
filter_element_changes: bool = Field(
True,
description="Include BATTER's element-change mapping filter.",
)
filter_mismatched_attached_h_count: bool = Field(
False,
description="Include BATTER's attached-hydrogen-count mapping filter.",
)
[docs]
class LomapMapperArgs(BaseModel):
"""LoMap atom mapper option overrides for RBFE."""
model_config = ConfigDict(extra="forbid")
time: Optional[int] = Field(None, description="Override LomapAtomMapper time.")
threed: Optional[bool] = Field(None, description="Override LomapAtomMapper threed.")
max3d: Optional[float] = Field(None, description="Override LomapAtomMapper max3d.")
element_change: Optional[bool] = Field(
None,
description="Override LomapAtomMapper element_change.",
)
shift: Optional[bool] = Field(None, description="Override LomapAtomMapper shift.")
[docs]
class RBFENetworkArgs(BaseModel):
"""
RBFE network mapping controls.
Users can specify a mapping strategy by name (``mapping``) or provide
an explicit mapping file (``mapping_file``).
"""
model_config = ConfigDict(extra="forbid")
mapping: Optional[str] = Field(
"default",
description="Mapping strategy name (e.g., 'default', 'konnektor').",
)
atom_mapper: Literal["kartograf", "lomap"] = Field(
"kartograf",
description="Atom mapper backend for RBFE pair mapping ('kartograf' or 'lomap').",
)
kartograf: KartografMapperArgs = Field(
default_factory=KartografMapperArgs,
description="KartografAtomMapper option overrides.",
)
lomap: LomapMapperArgs = Field(
default_factory=LomapMapperArgs,
description="LomapAtomMapper option overrides.",
)
konnektor_layout: Optional[str] = Field(
None,
description="Optional Konnektor layout name (e.g., 'star', 'radial', 'maximal') used when mapping='konnektor'.",
)
both_directions: bool = Field(
False,
description="When true, run each mapped RBFE edge in both directions (A~B and B~A).",
)
mapping_file: Optional[Path] = Field(
None,
description="Optional path to a mapping file (JSON/YAML/text).",
)
[docs]
def resolve_paths(self, base: Path) -> "RBFENetworkArgs":
mf = self.mapping_file
if mf is not None and not mf.is_absolute():
mf = (base / mf).resolve()
return self.model_copy(update={"mapping_file": mf})
@field_validator("mapping", mode="before")
@classmethod
def _lower_mapping(cls, v):
if v is None:
return None
text = str(v).strip()
return text.lower() if text else None
@field_validator("atom_mapper", mode="before")
@classmethod
def _lower_atom_mapper(cls, v):
if v is None:
return "kartograf"
return str(v).strip().lower()
@model_validator(mode="after")
def _validate_mapping(self) -> "RBFENetworkArgs":
if self.mapping_file is None and not self.mapping:
raise ValueError(
"rbfe.mapping or rbfe.mapping_file must be provided."
)
return self
[docs]
class RunSection(BaseModel):
"""Run-related settings, including where outputs land."""
model_config = ConfigDict(extra="forbid")
output_folder: Path = Field(
...,
description="Directory where system artifacts and executions are stored.",
)
system_type: Literal["MABFE", "MASFE"] | None = Field(
None,
description=(
"Optional override for the system builder. When omitted, the orchestrator "
"chooses the builder based on the protocol."
),
)
only_fe_preparation: bool = Field(
False,
description="When true, stop the workflow after FE preparation.",
)
on_failure: Literal["raise", "prune", "retry"] = Field(
"raise",
description="Behavior on ligand failure: 'raise', 'prune', or 'retry' (clear FAILED sentinels and rerun once).",
)
max_workers: int | None = Field(
None,
description="Parallel workers for local backend (None = auto, 0 = serial).",
)
max_active_jobs: int | None = Field(
1000,
ge=0,
description="Max concurrent SLURM jobs for FE submissions (0 disables throttling).",
)
batch_mode: bool = Field(
False,
description="When true, run SLURM jobs inline via srun inside the manager allocation instead of submitting with sbatch.",
)
batch_gpus: int | None = Field(
None,
ge=0,
description="GPUs available to the manager process for batch_mode; auto-detected from SLURM env when omitted.",
)
batch_gpus_per_task: int = Field(
1,
ge=1,
description="GPUs to assign per task when batch_mode is enabled.",
)
batch_srun_extra: List[str] = Field(
default_factory=list,
description="Extra srun flags appended when launching tasks in batch_mode.",
)
dry_run: bool = Field(
False, description="Force dry-run mode regardless of YAML setting."
)
clean_failures: bool = Field(
False,
description="Clear FAILED markers, job_attempt.txt retry counters, and progress caches before rerunning.",
)
remd: Literal["yes", "no"] = Field(
"no",
description="Enable REMD execution.",
)
run_id: str = Field(
"auto", description="Run identifier to use (``auto`` picks latest)."
)
allow_run_id_mismatch: bool = Field(
False,
description=(
"When ``True``, allow reusing an explicit ``run_id`` even if the "
"configuration hash differs from the existing execution."
),
)
slurm_header_dir: Path | None = Field(
None,
description="Optional directory containing user Slurm headers (defaults to ~/.batter).",
)
email_sender: str = Field(
"nobody@stanford.edu",
description=("Sender address used for BATTER run-status email notifications."),
)
email_on_completion: str | None = Field(
None,
description=(
"Email address that should receive a notification once the run "
"finishes or aborts with an uncaught failure."
),
)
slurm: SlurmConfig = Field(default_factory=SlurmConfig)
[docs]
def resolve_paths(self, base: Path) -> "RunSection":
"""
Return a copy where ``output_folder`` is absolute relative to ``base``.
"""
folder = self.output_folder
if not folder.is_absolute():
folder = (base / folder).resolve()
hdr = self.slurm_header_dir
if hdr is not None and not hdr.is_absolute():
hdr = (base / hdr).resolve()
return self.model_copy(
update={
"output_folder": folder,
"slurm_header_dir": hdr,
"remd": coerce_yes_no(self.remd),
}
)
@field_validator("output_folder", mode="before")
@classmethod
def _coerce_output_folder(cls, v):
if v is None or (isinstance(v, str) and not v.strip()):
raise ValueError("`run.output_folder` is required.")
return Path(v)
@field_validator("remd", mode="before")
@classmethod
def _coerce_remd(cls, v):
return coerce_yes_no(v)
@field_validator("system_type", mode="before")
@classmethod
def _normalize_system_type(cls, value):
if value is None:
return None
text = str(value).strip().upper()
if text not in {"MABFE", "MASFE"}:
raise ValueError("run.system_type must be 'MABFE', 'MASFE', or omitted.")
return text
[docs]
class RunConfig(BaseModel):
"""Top-level YAML config."""
model_config = ConfigDict(extra="forbid")
version: int = Field(1, description="Schema version of the run configuration.")
protocol: Literal["abfe", "rbfe", "asfe", "md"] = Field(
"abfe", description="High-level protocol to execute."
)
backend: Literal["local", "slurm"] = Field(
"local", description="Execution backend."
)
create: CreateArgs = Field(..., description="Settings for system creation/staging.")
fe_sim: Dict[str, Any] | FESimArgs | MDSimArgs = Field(
default_factory=dict, description="Simulation parameter overrides."
)
run: RunSection = Field(
..., description="Execution controls and artifact destination."
)
rbfe: RBFENetworkArgs | None = Field(
default=None, description="RBFE network mapping configuration."
)
@model_validator(mode="after")
def _coerce_fe_sim_model(self) -> "RunConfig":
proto = getattr(self, "protocol", "abfe")
current = self.fe_sim
if proto == "md":
target = MDSimArgs
else:
target = FESimArgs
if isinstance(current, target):
return self
if isinstance(current, BaseModel):
payload = current.model_dump()
else:
payload = dict(current or {})
self.fe_sim = target.model_validate(payload)
return self
@model_validator(mode="after")
def _validate_rbfe_section(self) -> "RunConfig":
if self.rbfe is not None and self.protocol != "rbfe":
raise ValueError("The 'rbfe' section is only valid when protocol='rbfe'.")
return self
@field_validator("protocol", mode="before")
@classmethod
def _lower_protocol(cls, v):
return str(v).lower() if v else v
@field_validator("backend", mode="before")
@classmethod
def _lower_backend(cls, v):
return str(v).lower() if v else v
# disable backend slurm for now
@field_validator("backend", mode="before")
@classmethod
def _disable_slurm(cls, v):
if v == "slurm":
raise ValueError("SLURM backend is currently disabled.")
return v
[docs]
@classmethod
def load(cls, path: Path | str) -> "RunConfig":
"""Load and validate a run configuration from disk.
Parameters
----------
path : str or pathlib.Path
Location of the YAML file to parse.
Returns
-------
RunConfig
Fully validated configuration object.
"""
import yaml
p = Path(path)
data = yaml.safe_load(p.read_text()) or {}
data = expand_env_vars(data, base_dir=p.parent)
create_data = data.get("create") if isinstance(data, Mapping) else None
if isinstance(create_data, dict):
conf_restraints = create_data.get("extra_conformation_restraints")
if conf_restraints not in (None, ""):
conf_path = Path(str(conf_restraints))
if not conf_path.is_absolute():
create_data["extra_conformation_restraints"] = str(
(p.parent / conf_path).resolve()
)
cfg = cls.model_validate(data)
return cfg.with_base_dir(p.parent)
[docs]
@classmethod
def model_validate_yaml(cls, yaml_text: str) -> "RunConfig":
"""Validate a run configuration from an in-memory YAML string.
Parameters
----------
yaml_text : str
Raw YAML content describing the run configuration.
Returns
-------
RunConfig
Validated configuration model.
"""
import yaml
raw = yaml.safe_load(yaml_text) or {}
cfg = cls.model_validate(expand_env_vars(raw))
return cfg.with_base_dir(Path.cwd())
[docs]
def resolved_sim_config(self) -> SimulationConfig:
"""Build the effective simulation configuration for this run.
Returns
-------
SimulationConfig
Simulation parameters derived from ``create`` and ``fe_sim`` sections.
"""
fe_args = self.fe_sim
if self.protocol == "md":
if isinstance(fe_args, dict):
fe_args = MDSimArgs(**fe_args)
elif isinstance(fe_args, MDSimArgs):
pass
else:
fe_args = MDSimArgs.model_validate(fe_args)
else:
if isinstance(fe_args, dict):
fe_args = FESimArgs(**fe_args)
elif isinstance(fe_args, FESimArgs):
pass
else:
fe_args = FESimArgs.model_validate(fe_args)
desired_fe_type = PROTOCOL_TO_FE_TYPE.get(self.protocol)
return SimulationConfig.from_sections(
self.create,
fe_args,
protocol=self.protocol,
fe_type=desired_fe_type,
slurm_header_dir=self.run.slurm_header_dir,
run_remd=self.run.remd,
)
[docs]
def with_base_dir(self, base_dir: Path) -> "RunConfig":
"""
Return a copy with relative paths resolved against ``base_dir``.
"""
resolved_create = self.create.resolve_paths(base_dir)
resolved_run = self.run.resolve_paths(base_dir)
resolved_rbfe = self.rbfe.resolve_paths(base_dir) if self.rbfe else None
return self.model_copy(
update={
"create": resolved_create,
"run": resolved_run,
"rbfe": resolved_rbfe,
}
)
RunConfig.model_rebuild()