"""Minimal system-preparation handler for MASFE workflows."""
from __future__ import annotations
import json
import shutil
from pathlib import Path
from typing import Any, Dict, List
from loguru import logger
from batter.orchestrate.state_registry import register_phase_state
from batter.pipeline.payloads import StepPayload, SystemParams
from batter.pipeline.step import ExecResult, Step
from batter.systems.core import SimSystem
def _ensure_pdb(lig_path: Path, out_dir: Path) -> Path:
"""
Ensure a PDB exists for a ligand file; if not PDB, convert via RDKit.
Returns the path to the PDB file we wrote or found.
"""
lig_path = Path(lig_path)
if lig_path.suffix.lower() == ".pdb":
return lig_path
try:
from rdkit import Chem
except Exception as e:
raise RuntimeError(
f"Ligand {lig_path} is not PDB; RDKit is required to convert SDF/MOL2 → PDB."
) from e
out_dir.mkdir(parents=True, exist_ok=True)
out_pdb = out_dir / f"{lig_path.stem}.pdb"
if lig_path.suffix.lower() == ".sdf":
suppl = Chem.SDMolSupplier(str(lig_path), removeHs=False)
mols = [m for m in suppl if m is not None]
if not mols:
raise ValueError(f"RDKit could not read any molecule from {lig_path}")
Chem.MolToPDBFile(mols[0], str(out_pdb))
elif lig_path.suffix.lower() == ".mol2":
mol = Chem.MolFromMol2File(str(lig_path), removeHs=False, sanitize=False)
if mol is None:
raise ValueError(f"RDKit could not read {lig_path}")
Chem.MolToPDBFile(mol, str(out_pdb))
else:
raise ValueError(f"Unsupported ligand format: {lig_path.suffix} for {lig_path}")
return out_pdb
def _copy(src: Path, dst: Path) -> None:
dst.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(src, dst)
class _MASFESystemPrepRunner:
    """
    Minimal system_prep for MASFE (solvation FE):
    - No protein/topology/coordinates
    - Stage ligands into all-ligands/ as <NAME>.pdb (convert if needed)
    - Write a manifest for downstream handlers
    """

    def __init__(self, system: SimSystem) -> None:
        """Bind to ``system`` and create ``<system.root>/all-ligands`` if it is missing."""
        self.system = system
        self.output_dir = system.root
        self.ligand_stage_dir = self.output_dir / "all-ligands"
        self.ligand_stage_dir.mkdir(parents=True, exist_ok=True)

    def run(self, *, system_name: str, ligand_paths: Dict[str, str]) -> Dict[str, Any]:
        """
        Stage every ligand as ``all-ligands/<NAME>.pdb`` and write ``manifest.json``.

        Parameters
        ----------
        system_name : str
            Recorded verbatim in the manifest.
        ligand_paths : dict
            Maps ligand name to the path of its structure file. Names are
            upper-cased to form the staged file name and the manifest key.

        Returns
        -------
        dict
            The manifest that was written: ``system_name``, ``mode`` (always
            ``"MASFE"``) and ``ligands`` (upper-cased name -> staged PDB path).

        Raises
        ------
        FileNotFoundError
            If a ligand source file does not exist.
        ValueError
            If two ligand names differ only by case and would therefore be
            staged to the same file.
        """
        import tempfile

        logger.info(f"[MASFE system_prep] system={system_name}, ligands={len(ligand_paths)}")
        staged_map: Dict[str, str] = {}
        original_names: Dict[str, str] = {}

        # Conversions go to a scratch directory, not the staging directory.
        # Writing "<source-stem>.pdb" into all-ligands/ leaves stray files and
        # can overwrite an already-staged "<NAME>.pdb" whose name happens to
        # equal a later source file's stem.
        with tempfile.TemporaryDirectory(prefix="masfe-ligprep-") as scratch:
            scratch_dir = Path(scratch)
            for name, src in sorted(ligand_paths.items()):
                staged_name = name.upper()
                if staged_name in staged_map:
                    # Without this check the later ligand silently replaces
                    # the earlier one on disk and in the manifest.
                    raise ValueError(
                        f"Ligand names {original_names[staged_name]!r} and {name!r} "
                        f"both map to {staged_name}.pdb; ligand names must be "
                        "unique ignoring case."
                    )
                src_p = Path(src)
                if not src_p.exists():
                    raise FileNotFoundError(f"Ligand file not found: {src_p}")
                pdb = _ensure_pdb(src_p, scratch_dir)
                dst = self.ligand_stage_dir / f"{staged_name}.pdb"
                # A source that already is the staged file needs no copy.
                if pdb.resolve() != dst.resolve():
                    _copy(pdb, dst)
                staged_map[staged_name] = str(dst)
                original_names[staged_name] = name

        manifest = {
            "system_name": system_name,
            "mode": "MASFE",
            "ligands": staged_map,
        }
        (self.ligand_stage_dir / "manifest.json").write_text(
            json.dumps(manifest, indent=2), encoding="utf-8"
        )
        return manifest
# Public pipeline handler for the MASFE system-preparation step.
def system_prep_masfe(step: Step, system: SimSystem, params: Dict[str, Any]) -> ExecResult:
    """Stage ligands and write simulation overrides for a MASFE solvation system.

    Parameters
    ----------
    step : Step
        Pipeline metadata (unused).
    system : SimSystem
        Simulation system descriptor; all files are written under ``system.root``.
    params : dict
        Handler payload validated into :class:`StepPayload`. Its ``sys_params``
        must provide ``ligand_paths`` and ``system_name``; ``water_model``,
        ``ion_conc``, ``cation`` and ``anion`` are optional.

    Returns
    -------
    ExecResult
        Built from the ligand manifest path and an info dict holding the
        manifest fields, ``system_prep_ok`` and the ``sim_updates`` overrides.
    """
    root = system.root
    logger.info(f"[system_prep_masfe] Preparing solvation FE system in {system.root}")

    sys_params = StepPayload.model_validate(params).sys_params or SystemParams()

    # Ligand paths are looked up before the runner is built, so a missing key
    # fails before the staging directory is created.
    ligands = sys_params["ligand_paths"]
    prep_runner = _MASFESystemPrepRunner(system)
    manifest = prep_runner.run(system_name=sys_params["system_name"], ligand_paths=ligands)

    # Solvation settings, each falling back to a default when not supplied.
    overrides: Dict[str, Any] = {"is_solvation": True}
    for key, fallback in (
        ("water_model", "TIP3P"),
        ("ion_conc", 0.0),
        ("cation", "Na+"),
        ("anion", "Cl-"),
    ):
        overrides[key] = sys_params.get(key, fallback)

    config_dir = root / "artifacts" / "config"
    config_dir.mkdir(parents=True, exist_ok=True)
    overrides_path = config_dir / "sim_overrides.json"
    overrides_path.write_text(json.dumps(overrides, indent=2))

    # The overrides file, relative to the system root, is the marker for this phase.
    # NOTE(review): the phase is registered as "system_prep_asfe" (not "masfe");
    # confirm that this is the intended phase name.
    marker_rel = overrides_path.relative_to(root).as_posix()
    register_phase_state(
        root,
        "system_prep_asfe",
        required=[[marker_rel]],
        success=[[marker_rel]],
    )

    outputs = [root / "all-ligands" / "manifest.json"]
    info: Dict[str, Any] = {"system_prep_ok": True}
    info.update(manifest)
    info["sim_updates"] = overrides

    logger.info(f"[system_prep_masfe] Done (ligands: {len(manifest['ligands'])}).")
    return ExecResult(outputs, info)