Source code for batter.config.simulation

from __future__ import annotations

from typing import Any, Dict, List, Mapping, Optional, Literal, TYPE_CHECKING, Tuple
from pathlib import Path
from pydantic import (
    BaseModel,
    Field,
    ConfigDict,
    PrivateAttr,
    field_validator,
    model_validator,
)
import re
import os
from loguru import logger
from batter.utils import COMPONENTS_LAMBDA_DICT
from batter.config.utils import coerce_yes_no
from batter.config.remd import RemdArgs

if TYPE_CHECKING:
    from batter.config.run import CreateArgs, FESimArgs

# Component codes accepted by this module: the keys of COMPONENTS_LAMBDA_DICT.
# Used below to validate user-supplied ``n_steps`` / ``component_lambdas`` keys
# and to build the per-component step dictionaries.
FEP_COMPONENTS = list(COMPONENTS_LAMBDA_DICT.keys())
# Anchor atom syntax: optional leading ':', a numeric residue id, '@', atom name.
_ANCHOR_RE = re.compile(r"^:?\d+@[\w\d]+$")  # e.g., ":85@CA" or "85@CA"

# NOTE(review): not referenced in the visible part of this module; presumably
# the components that skip membrane handling elsewhere -- confirm with callers.
MEMBRANE_EXEMPT_COMPONENTS = {"y", "m"}

# Maps a run protocol name (lower-case) to the ``fe_type`` that
# ``SimulationConfig.from_sections`` selects when no explicit fe_type is given.
PROTOCOL_TO_FE_TYPE = {
    "abfe": "uno_rest",
    "rbfe": "relative",
    "asfe": "asfe",
    "md": "md",
}


class SimulationConfig(BaseModel):
    """
    Simulation configuration for ABFE/ASFE/RBFE workflows.

    Values are fed by RunConfig.resolved_sim_config(), which merges `create:` and `fe_sim:`.

    User-facing fields are validated by the field validators below.  The
    "derived" fields (``rng``, ``ion_def``, ``neut``, ``dic_n_steps``, ``rest``,
    ``components``, ``dec_method`` for non-custom types, ``component_lambdas``
    and ``membrane_simulation``) are recomputed by the ``_finalize`` model
    validator on every construction, so values passed in for them are
    overwritten.
    """

    @classmethod
    def from_sections(
        cls,
        create: "CreateArgs",
        fe: "FESimArgs",
        *,
        protocol: str | None = None,
        fe_type: str | None = None,
        slurm_header_dir: Path | None = None,
        run_remd: str | bool | None = None,
    ) -> "SimulationConfig":
        """Construct a :class:`SimulationConfig` from run sections.

        Parameters
        ----------
        create : CreateArgs
            System creation inputs taken from the ``create`` YAML section.
        fe : FESimArgs
            Free-energy simulation overrides from the ``fe_sim`` section.
            A plain mapping is also accepted; values are looked up first as
            attributes, then as mapping keys.
        protocol : str, optional
            Run protocol name (case-insensitive).  When ``fe_type`` is not
            given it is translated through ``PROTOCOL_TO_FE_TYPE``; it also
            selects which components must have a positive step count.
        fe_type : str, optional
            Explicit FE type.  Takes precedence over ``protocol`` and over
            ``fe.fe_type``.  Falls back to ``"md"`` if nothing resolves.
        slurm_header_dir : Path, optional
            Directory with Slurm header templates; defaults to ``~/.batter``.
        run_remd : {"yes","no"}, optional
            Whether REMD execution is enabled (controls submission only; REMD
            inputs are always written during preparation).

        Returns
        -------
        SimulationConfig
            Fully merged simulation configuration ready for downstream use.

        Raises
        ------
        ValueError
            On unknown component codes, non-ascending lambda schedules,
            missing or non-positive steps for a protocol's required
            components, negative ``analysis_start_step`` / ``n_bootstraps``,
            an ``analysis_start_step`` not below the largest step count, or a
            ``remd`` entry that is neither ``RemdArgs`` nor a mapping.
            Model validation errors raised by ``cls(**merged)`` also propagate.
        """
        # Geometry search defaults applied only when the create section leaves
        # them unset (None).
        l1_range = create.l1_range if create.l1_range is not None else 6.0
        min_adis = create.min_adis if create.min_adis is not None else 3.0
        max_adis = create.max_adis if create.max_adis is not None else 7.0

        create_data: dict[str, Any] = {
            "system_name": create.system_name or "unnamed_system",
            "receptor_ff": create.receptor_ff,
            "ligand_ff": create.ligand_ff,
            "lipid_ff": create.lipid_ff,
            "lipid_mol": list(create.lipid_mol or []),
            "other_mol": list(create.other_mol or []),
            "water_model": create.water_model,
            "neutralize_only": coerce_yes_no(create.neutralize_only),
            "ion_conc": float(create.ion_conc),
            "cation": create.cation,
            "anion": create.anion,
            "solv_shell": float(create.solv_shell),
            "protein_align": create.protein_align,
            "l1_range": float(l1_range),
            "min_adis": float(min_adis),
            "max_adis": float(max_adis),
        }

        def _fe_attr(name: str, default):
            """Read ``name`` from ``fe`` (attribute first, then mapping key).

            ``default`` may be a value or a zero-argument factory; a callable
            default is invoked so mutable defaults are never shared.
            """
            if hasattr(fe, name):
                return getattr(fe, name)
            if isinstance(fe, Mapping) and name in fe:
                return fe[name]
            return default() if callable(default) else default

        # fe_type precedence: explicit argument > protocol mapping > fe section > "md".
        resolved_fe_type = fe_type
        if resolved_fe_type is None and protocol:
            resolved_fe_type = PROTOCOL_TO_FE_TYPE.get(protocol.lower())
        if resolved_fe_type is None:
            resolved_fe_type = _fe_attr("fe_type", None)
        if resolved_fe_type is None:
            resolved_fe_type = "md"

        proto_key = (protocol or "").lower()

        def _coerce_step_dict(name: str, mapping: Mapping[str, Any]) -> dict[str, int]:
            """Normalize ``{component: steps}`` to lower-case keys and int values."""
            out: dict[str, int] = {}
            for comp, value in (mapping or {}).items():
                if not isinstance(comp, str):
                    raise ValueError(
                        f"{name} keys must be single-letter component codes; got {comp!r}"
                    )
                comp_key = comp.strip().lower()
                if len(comp_key) != 1:
                    raise ValueError(
                        f"{name} keys must be single letters (got {comp!r})."
                    )
                if comp_key not in FEP_COMPONENTS:
                    raise ValueError(
                        f"Unknown component '{comp_key}' in {name}; valid components: {', '.join(sorted(FEP_COMPONENTS))}."
                    )
                out[comp_key] = int(value)
            return out

        def _coerce_lambda_list(name: str, seq: Any) -> List[float]:
            """Parse a lambda schedule (string, sequence or scalar) into floats.

            Strings are split on commas/whitespace.  The result must be
            non-decreasing.
            """
            if seq is None:
                return []
            if isinstance(seq, str):
                parts = [p for p in re.split(r"[,\s]+", seq.strip()) if p]
                out = [float(p) for p in parts]
            elif isinstance(seq, (list, tuple)):
                out = [float(x) for x in seq]
            else:
                out = [float(seq)]
            if out and any(left > right for left, right in zip(out, out[1:])):
                raise ValueError(f"{name} values must be in ascending order.")
            return out

        n_steps = _coerce_step_dict(
            "n_steps",
            dict(_fe_attr("n_steps", lambda: {"x": 300_000, "y": 300_000}) or {}),
        )
        # RBFE always needs an 'x' entry; supply the default when it is omitted.
        if proto_key == "rbfe" and "x" not in n_steps:
            n_steps["x"] = 300_000

        base_lambdas = _coerce_lambda_list("lambdas", _fe_attr("lambdas", list) or [])

        component_lambda_map: dict[str, List[float]] = {}
        raw_component_lambdas = dict(_fe_attr("component_lambdas", dict) or {})
        for comp, seq in raw_component_lambdas.items():
            comp_key = str(comp).strip().lower()
            if comp_key not in FEP_COMPONENTS:
                raise ValueError(
                    f"Unknown component '{comp_key}' in component_lambdas; valid components: {', '.join(sorted(FEP_COMPONENTS))}."
                )
            component_lambda_map[comp_key] = _coerce_lambda_list(
                f"component_lambdas['{comp_key}']", seq
            )

        required_components = {
            "abfe": ["z"],
            "asfe": ["y", "m"],
            "rbfe": ["x"],
        }.get(proto_key, [])
        for comp in required_components:
            if comp not in n_steps:
                raise ValueError(
                    f"{proto_key.upper()} protocol requires steps for component '{comp}'. Add {comp}_n_steps."
                )
            if n_steps[comp] <= 0:
                raise ValueError(
                    f"{proto_key.upper()} protocol requires positive steps for component '{comp}'."
                )

        eq_steps_value = int(_fe_attr("eq_steps", 1_000_000))
        # 0 disables equilibration; any other value below the 2500-step
        # minimum is raised to it.  Exactly 2500 is already valid, so it must
        # not trigger the warning.
        if eq_steps_value != 0 and eq_steps_value < 2500:
            logger.warning(
                "Setting fe_sim.eq_steps to 2500 (minimum allowed non-zero value). "
                "If you don't want equilibration, set fe_sim.eq_steps=0."
            )
            eq_steps_value = 2500

        fe_release_eq = [0.0]
        extra_conf_rest = create.extra_conformation_restraints
        extra_restraints = create.extra_restraints

        # A missing or None value is treated as 0 for both analysis knobs.
        analysis_start_step_val = int(_fe_attr("analysis_start_step", 0) or 0)
        if analysis_start_step_val < 0:
            raise ValueError("analysis_start_step must be >= 0.")

        n_bootstraps_val = int(_fe_attr("n_bootstraps", 0) or 0)
        if n_bootstraps_val < 0:
            raise ValueError("n_bootstraps must be >= 0.")

        # Coarse pre-check against the largest configured step count; the
        # model validator repeats the check per active component.
        max_fe_steps = max((int(v) for v in n_steps.values() if v is not None), default=0)
        if max_fe_steps and analysis_start_step_val >= max_fe_steps:
            raise ValueError(
                f"analysis_start_step ({analysis_start_step_val}) must be smaller than fe_total_step ({max_fe_steps})."
            )

        remd_settings = _fe_attr("remd", RemdArgs)
        if not isinstance(remd_settings, RemdArgs):
            if not isinstance(remd_settings, Mapping):
                raise ValueError(
                    "fe_sim.remd must be a mapping of REMD settings (nstlim); "
                    "toggle execution with run.remd."
                )
            remd_settings = RemdArgs(**remd_settings)
        remd_nstlim = int(remd_settings.nstlim)
        remd_flag = coerce_yes_no(run_remd or "no") or "no"

        fe_data: dict[str, Any] = {
            "fe_type": resolved_fe_type,
            "dec_int": _fe_attr("dec_int", "mbar"),
            "remd": remd_flag,
            "remd_nstlim": remd_nstlim,
            "rocklin_correction": coerce_yes_no(_fe_attr("rocklin_correction", "no")),
            "enable_mcwat": coerce_yes_no(_fe_attr("enable_mcwat", "yes")),
            "lambdas": base_lambdas,
            "component_windows": component_lambda_map,
            "blocks": int(_fe_attr("blocks", 0)),
            "lig_buffer": float(_fe_attr("lig_buffer", 15.0)),
            "lig_distance_force": float(_fe_attr("lig_distance_force", 5.0)),
            "lig_angle_force": float(_fe_attr("lig_angle_force", 250.0)),
            "lig_dihcf_force": float(_fe_attr("lig_dihcf_force", 0.0)),
            "rec_com_force": float(_fe_attr("rec_com_force", 10.0)),
            "lig_com_force": float(_fe_attr("lig_com_force", 10.0)),
            "buffer_x": float(_fe_attr("buffer_x", 10.0)),
            "buffer_y": float(_fe_attr("buffer_y", 10.0)),
            "buffer_z": float(_fe_attr("buffer_z", 15.0)),
            "temperature": float(_fe_attr("temperature", 298.15)),
            "dt": float(_fe_attr("dt", 0.004)),
            "hmr": coerce_yes_no(_fe_attr("hmr", "yes")),
            "release_eq": fe_release_eq,
            "eq_steps": eq_steps_value,
            "ntpr": int(_fe_attr("ntpr", 100)),
            "ntwr": int(_fe_attr("ntwr", 10_000)),
            "ntwe": int(_fe_attr("ntwe", 0)),
            "ntwx": int(_fe_attr("ntwx", 50_000)),
            "cut": float(_fe_attr("cut", 9.0)),
            "gamma_ln": float(_fe_attr("gamma_ln", 1.0)),
            "barostat": int(_fe_attr("barostat", 2)),
            "unbound_threshold": float(_fe_attr("unbound_threshold", 8.0)),
            "analysis_start_step": analysis_start_step_val,
            "n_bootstraps": n_bootstraps_val,
            "slurm_header_dir": Path(slurm_header_dir or (Path.home() / ".batter")),
        }

        # Restraint inputs override the user's barostat choice: conformation
        # restraints turn on ``infe`` and force barostat 2; otherwise any
        # ``extra_restraints`` value (even an empty one) forces barostat 1.
        infe_flag = bool(extra_conf_rest)
        if extra_conf_rest:
            fe_data["barostat"] = 2
        elif extra_restraints is not None:
            fe_data["barostat"] = 1

        n_steps_dict: dict[str, int] = {
            f"{comp}_n_steps": int(n_steps[comp]) for comp in sorted(n_steps)
        }

        merged: dict[str, Any] = {
            **create_data,
            **fe_data,
            "n_steps_dict": n_steps_dict,
            "infe": infe_flag,
        }
        return cls(**merged)

    # NOTE(review): model_config is disabled in the original source; kept as-is.
    # model_config = ConfigDict(extra="ignore", populate_by_name=True, validate_default=True)

    # --- Required / core ---
    system_name: str = Field(..., description="System name (required)")
    fe_type: Literal[
        "custom",
        "rest",
        "sdr",
        "dd",
        "sdr-rest",
        "express",
        "relative",
        "uno",
        "uno_com",
        "uno_rest",
        "self",
        "uno_dd",
        "dd-rest",
        "asfe",
        "md",
    ] = Field(..., description="Free-energy protocol type")

    # --- Global switches ---
    dec_int: Literal["mbar", "ti"] = Field(
        "mbar", description="Integration method (mbar/ti)"
    )
    remd: Literal["yes", "no"] = Field(
        "no",
        description="Enable REMD execution (submission only; inputs are always prepared).",
    )
    remd_nstlim: int = Field(
        100, description="Steps per REMD segment (applied to ``mdin-*-remd`` copies)."
    )
    slurm_header_dir: Path = Field(
        default_factory=lambda: Path.home() / ".batter",
        description="Directory containing user Slurm header templates.",
    )
    infe: bool = Field(
        False, description="Enable NFE (infinite) equilibration when true."
    )

    # --- Anchors / molecular definitions ---
    p1: str = Field("", description='Anchor P1 "RESID@ATOM" (e.g., "85@CA")')
    p2: str = Field("", description='Anchor P2 "RESID@ATOM"')
    p3: str = Field("", description='Anchor P3 "RESID@ATOM"')
    other_mol: List[str] = Field(default_factory=list, description="Other co-binders")
    lipid_mol: List[str] = Field(default_factory=list, description="Lipid molecules")
    solv_shell: Optional[float] = Field(
        15.0, description="Initial solvent shell radius (Å)"
    )
    rocklin_correction: Literal["yes", "no"] = Field(
        "no", description="Rocklin correction"
    )

    # --- FE controls / analysis ---
    release_eq: List[float] = Field(
        default_factory=lambda: [0.0],
        description="Equilibration release weights (derived; fixed to [0.0]).",
    )
    ti_points: Optional[int] = Field(0, description="(#) TI points (not implemented)")
    lambdas: List[float] = Field(
        default_factory=list, description="default lambda values"
    )
    component_windows: Dict[str, List[float]] = Field(
        default_factory=dict, description="Per-component lambda values for overrides"
    )
    sdr_dist: Optional[float] = Field(0.0, description="SDR placement distance (Å)")
    dec_method: Optional[str] = Field(
        None, description="Decoupling method (set for fe_type='custom')"
    )
    blocks: int = Field(0, description="MBAR blocks")
    unbound_threshold: float = Field(
        8.0,
        ge=0.0,
        description="Distance (Å) between ligand COMs that classifies equilibration as unbound.",
    )
    analysis_start_step: int = Field(
        0,
        ge=0,
        description="Analyze only steps after this (per FE window).",
    )
    n_bootstraps: int = Field(
        0,
        ge=0,
        description="Number of MBAR bootstrap resamples used during FE analysis.",
    )

    # --- Force constants ---
    lig_distance_force: float = Field(
        0.0, description="Ligand COM distance spring (kcal/mol/Å^2)"
    )
    lig_angle_force: float = Field(
        0.0, description="Ligand angle/dihedral spring (kcal/mol/rad^2)"
    )
    lig_dihcf_force: float = Field(
        0.0, description="Ligand dihedral spring (kcal/mol/rad^2)"
    )
    rec_com_force: float = Field(0.0, description="Protein COM spring")
    lig_com_force: float = Field(0.0, description="Ligand COM spring")

    # --- Solvent / box ---
    water_model: Literal["SPCE", "TIP4PEW", "TIP3P", "TIP3PF", "OPC"] = Field(
        "TIP3P", description="Water model"
    )
    buffer_x: float = Field(10.0, description="Box buffer X (Å)")
    buffer_y: float = Field(10.0, description="Box buffer Y (Å)")
    buffer_z: float = Field(15.0, description="Box buffer Z (Å)")
    lig_buffer: float = Field(10.0, description="Ligand box buffer (Å)")

    # --- Ions ---
    neutralize_only: Literal["yes", "no"] = Field("no", description="Neutralize only")
    cation: str = Field("Na+", description="Cation species")
    anion: str = Field("Cl-", description="Anion species")
    ion_conc: float = Field(0.15, description="Target salt concentration (M)")

    # --- Simulation params ---
    hmr: Literal["yes", "no"] = Field("no", description="Hydrogen mass repartitioning")
    enable_mcwat: Literal["yes", "no"] = Field(
        "yes",
        description="Enable MC water exchange moves during equilibration templates.",
    )
    temperature: float = Field(298.15, description="Temperature (K)")
    eq_steps: int = Field(
        1_000_000, description="Total equilibration steps (entire run)."
    )
    n_steps_dict: Dict[str, int] = Field(
        default_factory=lambda: {f"{comp}_n_steps": 1_000_000 for comp in FEP_COMPONENTS},
        description="Per-component steps (keys: '{comp}_n_steps')",
    )

    # --- L1 search (optional) ---
    l1_x: Optional[float] = Field(None, description="L1 center offset X (Å)")
    l1_y: Optional[float] = Field(None, description="L1 center offset Y (Å)")
    l1_z: Optional[float] = Field(None, description="L1 center offset Z (Å)")
    l1_range: Optional[float] = Field(None, description="L1 search radius (Å)")
    min_adis: Optional[float] = Field(None, description="Min anchor distance (Å)")
    max_adis: Optional[float] = Field(None, description="Max anchor distance (Å)")

    # --- Amber i/o ---
    ntpr: int = Field(100, description="Print energy every ntpr steps")
    ntwr: int = Field(10_000, description="Write restart every ntwr steps")
    ntwe: int = Field(0, description="Write energy every ntwe steps")
    ntwx: int = Field(2500, description="Write trajectory every ntwx steps")
    cut: float = Field(9.0, description="Nonbonded cutoff (Å)")
    gamma_ln: float = Field(1.0, description="Langevin γ (ps^-1)")
    barostat: Literal[1, 2] = Field(2, description="1=Berendsen, 2=MC barostat")
    dt: float = Field(0.004, description="Time step (ps)")
    all_atoms: Literal["yes", "no"] = Field("no", description="save all atoms for FE")

    # --- Force fields ---
    receptor_ff: str = Field("protein.ff14SB", description="Receptor FF")
    ligand_ff: str = Field("gaff2", description="Ligand FF")
    lipid_ff: str = Field("lipid21", description="Lipid FF")

    # --- Derived/public state (not user-set) ---
    ligand_dict: Dict[str, Any] = Field(
        default_factory=dict, description="Ligand dictionary"
    )
    rng: int = Field(0, description="Range of release_eq")
    ion_def: List[Any] = Field(
        default_factory=list, description="Ion tuple [cation, anion, conc]"
    )
    dic_n_steps: Dict[str, int] = Field(
        default_factory=dict, description="Steps per component"
    )
    rest: List[float] = Field(
        default_factory=list, description="Packed restraint constants"
    )
    neut: str = Field("", description="Alias of neutralize_only")
    protein_align: str = Field("name CA", description="Alignment selection")
    receptor_segment: Optional[str] = Field(
        None, description="Segment to embed in membrane"
    )

    # --- Private/internal runtime ---
    components: List[str] = Field(
        default_factory=list, description="List of components (v, o, z, etc.)"
    )
    component_lambdas: Dict[str, List[float]] = Field(
        default_factory=dict, description="Lambda schedule for each component"
    )
    membrane_simulation: bool = Field(
        default=True, description="Whether system includes a membrane"
    )

    # ---------------- validators / coercers ----------------
    @field_validator("fe_type", "dec_int", mode="before")
    @classmethod
    def _lower_enums(cls, v: Any) -> Any:
        """Lower-case enum-like inputs so the Literal checks are case-insensitive."""
        return v if v is None else str(v).lower()

    @field_validator(
        "neutralize_only",
        "hmr",
        "rocklin_correction",
        "enable_mcwat",
        "remd",
        mode="before",
    )
    @classmethod
    def _coerce_yes_no(cls, v: Any) -> str | None:
        """Map booleans, numbers and common true/false strings to "yes"/"no".

        ``None`` is passed through unchanged; anything unrecognized raises
        ``ValueError``.
        """
        if v is None:
            return None
        # bool is tested first because it is a subclass of int.
        if isinstance(v, bool):
            return "yes" if v else "no"
        if isinstance(v, (int, float)):
            return "yes" if v else "no"
        if isinstance(v, str):
            s = v.strip().lower()
            if s in {"yes", "no"}:
                return s
            if s in {"true", "t", "1"}:
                return "yes"
            if s in {"false", "f", "0"}:
                return "no"
        raise ValueError(f"Invalid yes/no: {v!r}")

    @field_validator("lambdas", mode="before")
    @classmethod
    def _parse_lambdas(cls, v: Any) -> Any:
        """
        Accept a YAML list, or a single space/comma-separated string.
        """
        if v is None:
            return []
        if isinstance(v, str):
            # split on commas or whitespace
            parts = [p for p in re.split(r"[,\s]+", v.strip()) if p]
            return [float(x) for x in parts]
        return v

    @field_validator("barostat", mode="before")
    @classmethod
    def _coerce_barostat(cls, v: Any) -> Any:
        """Turn numeric strings and floats into ints before the Literal[1, 2] check.

        Values that cannot be converted are returned untouched so that the
        regular field validation reports the error.
        """
        if isinstance(v, str):
            text = v.strip()
            if not text:
                return v
            try:
                return int(text)
            except ValueError:
                return v
        if isinstance(v, float):
            return int(v)
        return v

    @field_validator("p1", "p2", "p3")
    @classmethod
    def _validate_anchor(cls, v: str) -> str:
        """Require non-empty anchors to match ``[:]RESID@ATOM``; empty is allowed."""
        if not v:
            return v
        if not _ANCHOR_RE.match(v):
            raise ValueError(f"Anchor must look like ':85@CA' (got {v!r})")
        return v

    @model_validator(mode="after")
    def _finalize(self) -> "SimulationConfig":
        """Populate derived fields and run cross-field checks.

        Order matters: derived fields first, then component resolution, then
        the per-component step and lambda checks that depend on it, and
        finally the membrane/water box checks.

        Raises
        ------
        NotImplementedError
            If ``dec_int`` is ``"ti"``.
        ValueError
            For a custom fe_type without ``dec_method``, non-positive steps or
            a too-large ``analysis_start_step`` for an active component,
            missing lambdas, or too-small box buffers in water simulations.
        """
        # TI not implemented
        if self.dec_int == "ti":
            raise NotImplementedError("TI integration not implemented; use 'mbar'.")

        self._populate_derived_fields()

        # friendly notices
        if self.buffer_z == 0:
            logger.debug(
                "buffer_z=0; automatic Z buffer will be applied for membranes."
            )

        self._resolve_components()
        self._validate_component_steps()
        self._resolve_component_lambdas()

        # membrane simulation if lipids defined
        self.membrane_simulation = len(self.lipid_mol) > 0
        if self.membrane_simulation:
            self._check_membrane_compatibility()
        else:
            self._check_water_compatibility()
        return self

    def _populate_derived_fields(self) -> None:
        """Fill the legacy aliases/packed fields computed from user inputs."""
        self.rng = len(self.release_eq) - 1
        self.ion_def = [self.cation, self.anion, self.ion_conc]
        self.neut = self.neutralize_only

        # Every known component gets an entry (0 when not configured); only
        # the active ones are checked later.
        self.dic_n_steps.clear()
        for comp in FEP_COMPONENTS:
            self.dic_n_steps[comp] = int(self.n_steps_dict.get(f"{comp}_n_steps", 0))

        # pack restraints (order-sensitive, matches legacy)
        self.rest = [
            0,
            0,
            self.lig_distance_force,
            self.lig_angle_force,
            self.lig_dihcf_force,
            self.rec_com_force,
            self.lig_com_force,
        ]

    def _resolve_components(self) -> None:
        """Set ``components`` (and ``dec_method``) from ``fe_type``.

        ``custom`` keeps the user-supplied ``dec_method`` and has no
        components; every other type overwrites both.
        """
        if self.fe_type == "custom":
            if self.dec_method is None:
                raise ValueError(
                    "For fe_type='custom', set dec_method to one of: dd, sdr, exchange."
                )
            self.components = []
            return

        # fe_type -> (active components, decoupling method)
        layouts: dict[str, tuple[tuple[str, ...], str | None]] = {
            "rest": (("c", "a", "l", "t", "r"), "dd"),
            "sdr": (("e", "v"), "sdr"),
            "dd": (("e", "v", "f", "w"), "dd"),
            "sdr-rest": (("c", "a", "l", "t", "r", "e", "v"), "sdr"),
            "express": (("m", "n", "e", "v"), "sdr"),
            "dd-rest": (("c", "a", "l", "t", "r", "e", "v", "f", "w"), "dd"),
            "relative": (("x",), "exchange"),
            "uno": (("m", "n", "o"), "sdr"),
            "uno_rest": (("z",), "sdr"),
            "uno_com": (("o",), "sdr"),
            "self": (("s",), "sdr"),
            "uno_dd": (("z", "y"), "dd"),
            "asfe": (("y", "m"), "sdr"),
            "md": ((), None),
        }
        components, dec_method = layouts[self.fe_type]
        self.components = list(components)
        self.dec_method = dec_method

    def _validate_component_steps(self) -> None:
        """Check step counts and the analysis offset for active components only."""
        for comp in self.components:
            steps = self.dic_n_steps.get(comp, 0)
            if steps <= 0:
                raise ValueError(
                    f"{comp}: steps must be > 0 (key '{comp}_n_steps')."
                )
            if self.analysis_start_step >= steps:
                raise ValueError(
                    f"analysis_start_step ({self.analysis_start_step}) must be < {comp}_n_steps ({steps})."
                )

    def _resolve_component_lambdas(self) -> None:
        """Build ``component_lambdas``: per-component override, else ``lambdas``."""
        self.component_lambdas.clear()
        for comp in self.components:
            lambdas = self.component_windows.get(comp) or []
            if not lambdas:
                lambdas = self.lambdas
                if not lambdas:
                    if self.fe_type == "relative" and comp == "x":
                        raise ValueError(
                            "RBFE requires a lambda schedule for component 'x'. "
                            "Set fe_sim.lambdas (or fe_sim.component_lambdas.x / x_lambdas)."
                        )
                    raise ValueError(f"No lambdas defined for component '{comp}'.")
                logger.debug(
                    f"No per-component lambdas for '{comp}'; using default lambdas."
                )
            # Copy so components never share one list object with each other
            # or with ``self.lambdas`` / ``self.component_windows``.
            self.component_lambdas[comp] = list(lambdas)

    def _check_membrane_compatibility(self) -> None:
        """Hook for membrane-specific checks; currently a no-op."""
        pass

    def _check_water_compatibility(self) -> None:
        """Require every box buffer to be at least 10.0 Å without a membrane."""
        for dim, buf in zip(
            ("X", "Y", "Z"), (self.buffer_x, self.buffer_y, self.buffer_z)
        ):
            if buf < 10.0:
                raise ValueError(
                    f"For water simulations, buffer_{dim.lower()} must be >= 10.0 Å (got {buf})."
                )

    # convenience
    def to_dict(self) -> Dict[str, Any]:
        """Return the configuration as a plain dict (``model_dump()``)."""
        return self.model_dump()