from __future__ import annotations
from typing import Any, Dict, List, Mapping, Optional, Literal, TYPE_CHECKING, Tuple
from pathlib import Path
from pydantic import (
BaseModel,
Field,
ConfigDict,
PrivateAttr,
field_validator,
model_validator,
)
import re
import os
from loguru import logger
from batter.utils import COMPONENTS_LAMBDA_DICT
from batter.config.utils import coerce_yes_no
from batter.config.remd import RemdArgs
if TYPE_CHECKING:
from batter.config.run import CreateArgs, FESimArgs
# Single-letter FEP component codes, in the key order of COMPONENTS_LAMBDA_DICT.
FEP_COMPONENTS = [*COMPONENTS_LAMBDA_DICT]

# Anchor atom spec "RESID@ATOM" with an optional leading colon,
# e.g. ":85@CA" or "85@CA".
_ANCHOR_RE = re.compile(r"^:?\d+@[\w\d]+$")

# Component codes flagged as exempt from membrane handling. Not referenced in
# this module's visible code; presumably consumed elsewhere (verify callers).
MEMBRANE_EXEMPT_COMPONENTS = set("ym")

# Lower-case run protocol -> fe_type used by SimulationConfig.from_sections
# when the caller does not pass fe_type explicitly.
PROTOCOL_TO_FE_TYPE = dict(
    abfe="uno_rest",
    rbfe="relative",
    asfe="asfe",
    md="md",
)
[docs]
class SimulationConfig(BaseModel):
"""
Simulation configuration for ABFE/ASFE/RBFE workflows.
Values are fed by RunConfig.resolved_sim_config(), which merges `create:` and `fe_sim:`.
"""
[docs]
@classmethod
def from_sections(
cls,
create: "CreateArgs",
fe: "FESimArgs",
*,
protocol: str | None = None,
fe_type: str | None = None,
slurm_header_dir: Path | None = None,
run_remd: str | bool | None = None,
) -> "SimulationConfig":
"""Construct a :class:`SimulationConfig` from run sections.
Parameters
----------
create : CreateArgs
System creation inputs taken from the ``create`` YAML section.
fe : FESimArgs
Free-energy simulation overrides from the ``fe_sim`` section.
run_remd : {"yes","no"}, optional
Whether REMD execution is enabled (controls submission only; REMD inputs are
always written during preparation).
Returns
-------
SimulationConfig
Fully merged simulation configuration ready for downstream use.
"""
l1_range = create.l1_range if create.l1_range is not None else 6.0
min_adis = create.min_adis if create.min_adis is not None else 3.0
max_adis = create.max_adis if create.max_adis is not None else 7.0
create_data: dict[str, Any] = {
"system_name": create.system_name or "unnamed_system",
"receptor_ff": create.receptor_ff,
"ligand_ff": create.ligand_ff,
"lipid_ff": create.lipid_ff,
"lipid_mol": list(create.lipid_mol or []),
"other_mol": list(create.other_mol or []),
"water_model": create.water_model,
"neutralize_only": coerce_yes_no(create.neutralize_only),
"ion_conc": float(create.ion_conc),
"cation": create.cation,
"anion": create.anion,
"solv_shell": float(create.solv_shell),
"protein_align": create.protein_align,
"l1_range": float(l1_range),
"min_adis": float(min_adis),
"max_adis": float(max_adis),
}
def _fe_attr(name: str, default):
if hasattr(fe, name):
return getattr(fe, name)
if isinstance(fe, Mapping) and name in fe:
return fe[name]
return default() if callable(default) else default
resolved_fe_type = fe_type
if resolved_fe_type is None and protocol:
resolved_fe_type = PROTOCOL_TO_FE_TYPE.get(protocol.lower())
if resolved_fe_type is None:
resolved_fe_type = _fe_attr("fe_type", lambda: None)
if resolved_fe_type is None:
resolved_fe_type = "md"
proto_key = (protocol or "").lower()
def _coerce_step_dict(name: str, mapping: Mapping[str, Any]) -> dict[str, int]:
out: dict[str, int] = {}
for comp, value in (mapping or {}).items():
if not isinstance(comp, str):
raise ValueError(
f"{name} keys must be single-letter component codes; got {comp!r}"
)
comp_key = comp.strip().lower()
if len(comp_key) != 1:
raise ValueError(
f"{name} keys must be single letters (got {comp!r})."
)
if comp_key not in FEP_COMPONENTS:
raise ValueError(
f"Unknown component '{comp_key}' in {name}; valid components: {', '.join(sorted(FEP_COMPONENTS))}."
)
out[comp_key] = int(value)
return out
def _coerce_lambda_list(name: str, seq: Any) -> List[float]:
if seq is None:
return []
if isinstance(seq, str):
parts = [p for p in re.split(r"[,\s]+", seq.strip()) if p]
out = [float(p) for p in parts]
elif isinstance(seq, (list, tuple)):
out = [float(x) for x in seq]
else:
out = [float(seq)]
if out and any(left > right for left, right in zip(out, out[1:])):
raise ValueError(f"{name} values must be in ascending order.")
return out
n_steps = _coerce_step_dict(
"n_steps",
dict(_fe_attr("n_steps", lambda: {"x": 300_000, "y": 300_000}) or {}),
)
if proto_key == "rbfe" and "x" not in n_steps:
n_steps["x"] = 300_000
base_lambdas = _coerce_lambda_list("lambdas", _fe_attr("lambdas", list) or [])
component_lambda_map: dict[str, List[float]] = {}
raw_component_lambdas = dict(_fe_attr("component_lambdas", dict) or {})
for comp, seq in raw_component_lambdas.items():
comp_key = str(comp).strip().lower()
if comp_key not in FEP_COMPONENTS:
raise ValueError(
f"Unknown component '{comp_key}' in component_lambdas; valid components: {', '.join(sorted(FEP_COMPONENTS))}."
)
component_lambda_map[comp_key] = _coerce_lambda_list(
f"component_lambdas['{comp_key}']", seq
)
required_components = {
"abfe": ["z"],
"asfe": ["y", "m"],
"rbfe": ["x"],
}.get(proto_key, [])
for comp in required_components:
if comp not in n_steps:
raise ValueError(
f"{proto_key.upper()} protocol requires steps for component '{comp}'. Add {comp}_n_steps."
)
if n_steps[comp] <= 0:
raise ValueError(
f"{proto_key.upper()} protocol requires positive steps for component '{comp}'."
)
eq_steps_raw = int(_fe_attr("eq_steps", lambda: 1_000_000))
if eq_steps_raw <= 2500 and eq_steps_raw != 0:
logger.warning(
"Setting fe_sim.eq_steps to 2500 (minimum allowed non-zero value). "
"If you don't want equilibration, set fe_sim.eq_steps=0."
)
eq_steps_raw = 2500
eq_steps_value = eq_steps_raw
fe_release_eq = [0.0]
extra_conf_rest = create.extra_conformation_restraints
extra_restraints = create.extra_restraints
analysis_start_step_val = 0
if hasattr(fe, "analysis_start_step"):
analysis_start_step_val = int(getattr(fe, "analysis_start_step") or 0)
elif isinstance(fe, Mapping) and "analysis_start_step" in fe:
analysis_start_step_val = int(fe.get("analysis_start_step") or 0)
if analysis_start_step_val < 0:
raise ValueError("analysis_start_step must be >= 0.")
n_bootstraps_val = 0
if hasattr(fe, "n_bootstraps"):
n_bootstraps_val = int(getattr(fe, "n_bootstraps") or 0)
elif isinstance(fe, Mapping) and "n_bootstraps" in fe:
n_bootstraps_val = int(fe.get("n_bootstraps") or 0)
if n_bootstraps_val < 0:
raise ValueError("n_bootstraps must be >= 0.")
max_fe_steps = max((int(v) for v in n_steps.values() if v is not None), default=0)
if max_fe_steps and analysis_start_step_val >= max_fe_steps:
raise ValueError(
f"analysis_start_step ({analysis_start_step_val}) must be smaller than fe_total_step ({max_fe_steps})."
)
remd_settings = _fe_attr("remd", lambda: RemdArgs())
if isinstance(remd_settings, RemdArgs):
remd_nstlim = int(remd_settings.nstlim)
elif isinstance(remd_settings, Mapping):
remd_settings = RemdArgs(**remd_settings)
remd_nstlim = int(remd_settings.nstlim)
else:
raise ValueError(
"fe_sim.remd must be a mapping of REMD settings (nstlim); "
"toggle execution with run.remd."
)
remd_flag = coerce_yes_no(run_remd or "no") or "no"
fe_data: dict[str, Any] = {
"fe_type": resolved_fe_type,
"dec_int": _fe_attr("dec_int", lambda: "mbar"),
"remd": remd_flag,
"remd_nstlim": remd_nstlim,
"rocklin_correction": coerce_yes_no(
_fe_attr("rocklin_correction", lambda: "no")
),
"enable_mcwat": coerce_yes_no(_fe_attr("enable_mcwat", lambda: "yes")),
"lambdas": base_lambdas,
"component_windows": component_lambda_map,
"blocks": int(_fe_attr("blocks", lambda: 0)),
"lig_buffer": float(_fe_attr("lig_buffer", lambda: 15.0)),
"lig_distance_force": float(_fe_attr("lig_distance_force", lambda: 5.0)),
"lig_angle_force": float(_fe_attr("lig_angle_force", lambda: 250.0)),
"lig_dihcf_force": float(_fe_attr("lig_dihcf_force", lambda: 0.0)),
"rec_com_force": float(_fe_attr("rec_com_force", lambda: 10.0)),
"lig_com_force": float(_fe_attr("lig_com_force", lambda: 10.0)),
"buffer_x": float(_fe_attr("buffer_x", lambda: 10.0)),
"buffer_y": float(_fe_attr("buffer_y", lambda: 10.0)),
"buffer_z": float(_fe_attr("buffer_z", lambda: 15.0)),
"temperature": float(_fe_attr("temperature", lambda: 298.15)),
"dt": float(_fe_attr("dt", lambda: 0.004)),
"hmr": coerce_yes_no(_fe_attr("hmr", lambda: "yes")),
"release_eq": fe_release_eq,
"eq_steps": eq_steps_value,
"ntpr": int(_fe_attr("ntpr", lambda: 100)),
"ntwr": int(_fe_attr("ntwr", lambda: 10_000)),
"ntwe": int(_fe_attr("ntwe", lambda: 0)),
"ntwx": int(_fe_attr("ntwx", lambda: 50_000)),
"cut": float(_fe_attr("cut", lambda: 9.0)),
"gamma_ln": float(_fe_attr("gamma_ln", lambda: 1.0)),
"barostat": int(_fe_attr("barostat", lambda: 2)),
"unbound_threshold": float(_fe_attr("unbound_threshold", lambda: 8.0)),
"analysis_start_step": analysis_start_step_val,
"n_bootstraps": n_bootstraps_val,
"slurm_header_dir": Path(slurm_header_dir or (Path.home() / ".batter")),
}
infe_flag = bool(extra_conf_rest)
if extra_conf_rest:
fe_data["barostat"] = 2
elif extra_restraints is not None:
fe_data["barostat"] = 1
n_steps_dict: dict[str, int] = {}
for comp in sorted(n_steps):
n_steps_dict[f"{comp}_n_steps"] = int(n_steps.get(comp, 0))
merged: dict[str, Any] = {
**create_data,
**fe_data,
"n_steps_dict": n_steps_dict,
"infe": infe_flag,
}
return cls(**merged)
# model_config = ConfigDict(extra="ignore", populate_by_name=True, validate_default=True)
# --- Required / core ---
system_name: str = Field(..., description="System name (required)")
fe_type: Literal[
"custom",
"rest",
"sdr",
"dd",
"sdr-rest",
"express",
"relative",
"uno",
"uno_com",
"uno_rest",
"self",
"uno_dd",
"dd-rest",
"asfe",
"md",
] = Field(..., description="Free-energy protocol type")
# --- Global switches ---
dec_int: Literal["mbar", "ti"] = Field(
"mbar", description="Integration method (mbar/ti)"
)
remd: Literal["yes", "no"] = Field(
"no",
description="Enable REMD execution (submission only; inputs are always prepared).",
)
remd_nstlim: int = Field(
100, description="Steps per REMD segment (applied to ``mdin-*-remd`` copies)."
)
slurm_header_dir: Path = Field(
default_factory=lambda: Path.home() / ".batter",
description="Directory containing user Slurm header templates.",
)
infe: bool = Field(
False, description="Enable NFE (infinite) equilibration when true."
)
# --- Anchors / molecular definitions ---
p1: str = Field("", description='Anchor P1 "RESID@ATOM" (e.g., "85@CA")')
p2: str = Field("", description='Anchor P2 "RESID@ATOM"')
p3: str = Field("", description='Anchor P3 "RESID@ATOM"')
other_mol: List[str] = Field(default_factory=list, description="Other co-binders")
lipid_mol: List[str] = Field(default_factory=list, description="Lipid molecules")
solv_shell: Optional[float] = Field(
15.0, description="Initial solvent shell radius (Å)"
)
rocklin_correction: Literal["yes", "no"] = Field(
"no", description="Rocklin correction"
)
# --- FE controls / analysis ---
release_eq: List[float] = Field(
default_factory=lambda: [0.0],
description="Equilibration release weights (derived; fixed to [0.0]).",
)
ti_points: Optional[int] = Field(0, description="(#) TI points (not implemented)")
lambdas: List[float] = Field(
default_factory=list, description="default lambda values"
)
component_windows: Dict[str, List[float]] = Field(
default_factory=dict, description="Per-component lambda values for overrides"
)
sdr_dist: Optional[float] = Field(0.0, description="SDR placement distance (Å)")
dec_method: Optional[str] = Field(
None, description="Decoupling method (set for fe_type='custom')"
)
blocks: int = Field(0, description="MBAR blocks")
unbound_threshold: float = Field(
8.0,
ge=0.0,
description="Distance (Å) between ligand COMs that classifies equilibration as unbound.",
)
analysis_start_step: int = Field(
0,
ge=0,
description="Analyze only steps after this (per FE window).",
)
n_bootstraps: int = Field(
0,
ge=0,
description="Number of MBAR bootstrap resamples used during FE analysis.",
)
# --- Force constants ---
lig_distance_force: float = Field(
0.0, description="Ligand COM distance spring (kcal/mol/Å^2)"
)
lig_angle_force: float = Field(
0.0, description="Ligand angle/dihedral spring (kcal/mol/rad^2)"
)
lig_dihcf_force: float = Field(
0.0, description="Ligand dihedral spring (kcal/mol/rad^2)"
)
rec_com_force: float = Field(0.0, description="Protein COM spring")
lig_com_force: float = Field(0.0, description="Ligand COM spring")
# --- Solvent / box ---
water_model: Literal["SPCE", "TIP4PEW", "TIP3P", "TIP3PF", "OPC"] = Field(
"TIP3P", description="Water model"
)
buffer_x: float = Field(10.0, description="Box buffer X (Å)")
buffer_y: float = Field(10.0, description="Box buffer Y (Å)")
buffer_z: float = Field(15.0, description="Box buffer Z (Å)")
lig_buffer: float = Field(10.0, description="Ligand box buffer (Å)")
# --- Ions ---
neutralize_only: Literal["yes", "no"] = Field("no", description="Neutralize only")
cation: str = Field("Na+", description="Cation species")
anion: str = Field("Cl-", description="Anion species")
ion_conc: float = Field(0.15, description="Target salt concentration (M)")
# --- Simulation params ---
hmr: Literal["yes", "no"] = Field("no", description="Hydrogen mass repartitioning")
enable_mcwat: Literal["yes", "no"] = Field(
"yes",
description="Enable MC water exchange moves during equilibration templates.",
)
temperature: float = Field(298.15, description="Temperature (K)")
eq_steps: int = Field(
1_000_000, description="Total equilibration steps (entire run)."
)
n_steps_dict: Dict[str, int] = Field(
default_factory=lambda: {f"{comp}_n_steps": 1_000_000 for comp in FEP_COMPONENTS},
description="Per-component steps (keys: '{comp}_n_steps')",
)
# --- L1 search (optional) ---
l1_x: Optional[float] = Field(None, description="L1 center offset X (Å)")
l1_y: Optional[float] = Field(None, description="L1 center offset Y (Å)")
l1_z: Optional[float] = Field(None, description="L1 center offset Z (Å)")
l1_range: Optional[float] = Field(None, description="L1 search radius (Å)")
min_adis: Optional[float] = Field(None, description="Min anchor distance (Å)")
max_adis: Optional[float] = Field(None, description="Max anchor distance (Å)")
# --- Amber i/o ---
ntpr: int = Field(100, description="Print energy every ntpr steps")
ntwr: int = Field(10_000, description="Write restart every ntwr steps")
ntwe: int = Field(0, description="Write energy every ntwe steps")
ntwx: int = Field(2500, description="Write trajectory every ntwx steps")
cut: float = Field(9.0, description="Nonbonded cutoff (Å)")
gamma_ln: float = Field(1.0, description="Langevin γ (ps^-1)")
barostat: Literal[1, 2] = Field(2, description="1=Berendsen, 2=MC barostat")
dt: float = Field(0.004, description="Time step (ps)")
all_atoms: Literal["yes", "no"] = Field("no", description="save all atoms for FE")
# --- Force fields ---
receptor_ff: str = Field("protein.ff14SB", description="Receptor FF")
ligand_ff: str = Field("gaff2", description="Ligand FF")
lipid_ff: str = Field("lipid21", description="Lipid FF")
# --- Derived/public state (not user-set) ---
ligand_dict: Dict[str, Any] = Field(
default_factory=dict, description="Ligand dictionary"
)
rng: int = Field(0, description="Range of release_eq")
ion_def: List[Any] = Field(
default_factory=list, description="Ion tuple [cation, anion, conc]"
)
dic_n_steps: Dict[str, int] = Field(
default_factory=dict, description="Steps per component"
)
rest: List[float] = Field(
default_factory=list, description="Packed restraint constants"
)
neut: str = Field("", description="Alias of neutralize_only")
protein_align: str = Field("name CA", description="Alignment selection")
receptor_segment: Optional[str] = Field(
None, description="Segment to embed in membrane"
)
# --- Private/internal runtime ---
components: List[str] = Field(
default_factory=list, description="List of components (v, o, z, etc.)"
)
component_lambdas: Dict[str, List[float]] = Field(
default_factory=dict, description="Lambda schedule for each component"
)
membrane_simulation: bool = Field(
default=True, description="Whether system includes a membrane"
)
# ---------------- validators / coercers ----------------
@field_validator("fe_type", "dec_int", mode="before")
@classmethod
def _lower_enums(cls, v: Any) -> Any:
return v if v is None else str(v).lower()
@field_validator(
"neutralize_only",
"hmr",
"rocklin_correction",
"enable_mcwat",
"remd",
mode="before",
)
@classmethod
def _coerce_yes_no(cls, v: Any) -> str | None:
if v is None:
return None
if isinstance(v, bool):
return "yes" if v else "no"
if isinstance(v, (int, float)):
return "yes" if v else "no"
if isinstance(v, str):
s = v.strip().lower()
if s in {"yes", "no"}:
return s
if s in {"true", "t", "1"}:
return "yes"
if s in {"false", "f", "0"}:
return "no"
raise ValueError(f"Invalid yes/no: {v!r}")
@field_validator("lambdas", mode="before")
@classmethod
def _parse_lambdas(cls, v: Any) -> Any:
"""
Accept a YAML list, or a single space/comma-separated string.
"""
if v is None:
return []
if isinstance(v, str):
# split on commas or whitespace
parts = [p for p in re.split(r"[,\s]+", v.strip()) if p]
return [float(x) for x in parts]
return v
@field_validator("barostat", mode="before")
@classmethod
def _coerce_barostat(cls, v: Any) -> Any:
if isinstance(v, str):
text = v.strip()
if not text:
return v
try:
return int(text)
except ValueError:
return v
if isinstance(v, float):
return int(v)
return v
@field_validator("p1", "p2", "p3")
@classmethod
def _validate_anchor(cls, v: str) -> str:
if not v:
return v
if not _ANCHOR_RE.match(v):
raise ValueError(f"Anchor must look like ':85@CA' (got {v!r})")
return v
@model_validator(mode="after")
def _finalize(self) -> "SimulationConfig":
# TI not implemented
if self.dec_int == "ti":
raise NotImplementedError("TI integration not implemented; use 'mbar'.")
# derived fields
self.rng = len(self.release_eq) - 1
self.ion_def = [self.cation, self.anion, self.ion_conc]
self.neut = self.neutralize_only
# stage dicts (copy from n_steps_dict only for ACTIVE components)
self.dic_n_steps.clear()
for comp in FEP_COMPONENTS:
k2 = f"{comp}_n_steps"
self.dic_n_steps[comp] = int(self.n_steps_dict.get(k2, 0))
# pack restraints (order-sensitive, matches legacy)
self.rest = [
0,
0,
self.lig_distance_force,
self.lig_angle_force,
self.lig_dihcf_force,
self.rec_com_force,
self.lig_com_force,
]
# friendly notices
if self.buffer_z == 0:
logger.debug(
"buffer_z=0; automatic Z buffer will be applied for membranes."
)
# Set components/dec_method by fe_type
match self.fe_type:
case "custom":
if self.dec_method is None:
raise ValueError(
"For fe_type='custom', set dec_method to one of: dd, sdr, exchange."
)
self.components = []
case "rest":
self.components, self.dec_method = ["c", "a", "l", "t", "r"], "dd"
case "sdr":
self.components, self.dec_method = ["e", "v"], "sdr"
case "dd":
self.components, self.dec_method = ["e", "v", "f", "w"], "dd"
case "sdr-rest":
self.components, self.dec_method = [
"c",
"a",
"l",
"t",
"r",
"e",
"v",
], "sdr"
case "express":
self.components, self.dec_method = ["m", "n", "e", "v"], "sdr"
case "dd-rest":
self.components, self.dec_method = [
"c",
"a",
"l",
"t",
"r",
"e",
"v",
"f",
"w",
], "dd"
case "relative":
self.components, self.dec_method = ["x"], "exchange"
case "uno":
self.components, self.dec_method = ["m", "n", "o"], "sdr"
case "uno_rest":
self.components, self.dec_method = ["z"], "sdr"
case "uno_com":
self.components, self.dec_method = ["o"], "sdr"
case "self":
self.components, self.dec_method = ["s"], "sdr"
case "uno_dd":
self.components, self.dec_method = ["z", "y"], "dd"
case "asfe":
self.components, self.dec_method = ["y", "m"], "sdr"
case "md":
self.components, self.dec_method = [], None
# sanity checks for active components only
for comp in self.components:
s2 = self.dic_n_steps.get(comp, 0)
if s2 <= 0:
raise ValueError(
f"{comp}: steps must be > 0 (key '{comp}_n_steps')."
)
if self.analysis_start_step >= s2:
raise ValueError(
f"analysis_start_step ({self.analysis_start_step}) must be < {comp}_n_steps ({s2})."
)
# update per-component lambdas
self.component_lambdas.clear()
for comp in self.components:
lambdas = self.component_windows.get(comp) or []
if not lambdas:
lambdas = self.lambdas
if not lambdas:
if self.fe_type == "relative" and comp == "x":
raise ValueError(
"RBFE requires a lambda schedule for component 'x'. "
"Set fe_sim.lambdas (or fe_sim.component_lambdas.x / x_lambdas)."
)
raise ValueError(f"No lambdas defined for component '{comp}'.")
logger.debug(
f"No per-component lambdas for '{comp}'; using default lambdas."
)
self.component_lambdas[comp] = lambdas
# membrane simulation if lipids defined
self.membrane_simulation = len(self.lipid_mol) > 0
if self.membrane_simulation:
self._check_membrane_compatibility()
else:
self._check_water_compatibility()
return self
def _check_membrane_compatibility(self) -> None:
pass
def _check_water_compatibility(self) -> None:
# make sure buffer_x/y/z is > 10.0 Å
for dim, buf in zip(
("X", "Y", "Z"), (self.buffer_x, self.buffer_y, self.buffer_z)
):
if buf < 10.0:
raise ValueError(
f"For water simulations, buffer_{dim.lower()} must be >= 10.0 Å (got {buf})."
)
# convenience
[docs]
def to_dict(self) -> Dict[str, Any]:
return self.model_dump()