Source code for batter.exec.handlers.prepare_fe

"""Prepare alchemical FE inputs for a ligand."""

from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Dict, Optional

from loguru import logger

from batter._internal.builders.fe_alchemical import AlchemicalFEBuilder
from batter.orchestrate.state_registry import register_phase_state
from batter._internal.ops import remd as remd_ops
from batter._internal.ops import batch as batch_ops
from batter.pipeline.payloads import StepPayload, SystemParams
from batter.pipeline.step import ExecResult, Step
from batter.systems.core import SimSystem


def _system_root_for(child_root: Path) -> Path:
    """work/<sys>/simulations/<lig> → work/<sys>"""
    # Prefer the parent of the "simulations" directory so this works for
    # both per-ligand (simulations/<lig>) and RBFE transformations
    # (simulations/transformations/<pair>).
    try:
        parts = child_root.parts
        if "simulations" in parts:
            idx = parts.index("simulations")
            return Path(*parts[:idx])
        return child_root.parents[1]
    except Exception:
        return child_root


def _load_param_dir_dict(system_root: Path) -> Dict[str, str]:
    """
    Read artifacts/ligand_params/index.json → {residue_name: store_dir}
    """
    index_json = system_root / "artifacts" / "ligand_params" / "index.json"
    data = json.loads(index_json.read_text())
    out: Dict[str, str] = {}
    for entry in data.get("ligands", []):
        resn = entry.get("residue_name")
        store_dir = entry.get("store_dir")
        if resn and store_dir:
            out[resn] = store_dir
    if not out:
        raise RuntimeError(f"No ligand param entries found in {index_json}")
    return out

def _sim_with_z_defaults(sim: Any) -> Any:
    """Return ``sim`` with the settings required by the ``z`` component.

    The object passed in is never modified: a deep copy is taken (at most
    once) the first time a field has to be patched. If nothing needs patching
    the very same object is returned.
    """
    patched = sim
    if getattr(patched, "dec_method", None) not in {"sdr", "dd"}:
        patched = sim.model_copy(deep=True)
        patched.dec_method = "sdr"
    if patched.dic_n_steps.get("z", 0) <= 0:
        if patched is sim:
            patched = sim.model_copy(deep=True)
        # Step-count fallback chain: the "x" component, then eq_steps, then a
        # fixed default of 100000.
        fallback = patched.dic_n_steps.get("x", 0) or patched.n_steps_dict.get(
            "x_n_steps", 0
        )
        if fallback <= 0:
            fallback = int(getattr(patched, "eq_steps", 0) or 0)
        if fallback <= 0:
            fallback = 100000
        patched.dic_n_steps["z"] = int(fallback)
    return patched


def prepare_fe_handler(
    step: Step, system: SimSystem, params: Dict[str, Any]
) -> ExecResult:
    """Construct the initial FE directory layout for a ligand.

    For every component a scaffold build (``win=-1``) is run in
    ``<system.root>/fe/<comp>``; afterwards ``<system.root>/fe/<phase>.ok`` is
    written and registered as the required/success marker of the phase.

    Parameters
    ----------
    step : Step
        Pipeline metadata (unused).
    system : SimSystem
        Simulation system descriptor.
    params : dict
        Handler payload validated into :class:`StepPayload`.

    Returns
    -------
    ExecResult
        Metadata describing the generated directories: ``prepare_fe_ok``
        (marker path) plus one ``<comp>_workdir`` entry per component.

    Raises
    ------
    ValueError
        If the payload has no simulation configuration, no components are
        configured, or ``system.meta`` lacks ``ligand`` / ``residue_name``.
    RuntimeError
        If the ligand parameter index contains no usable entry.
    KeyError
        If a component has no entry in the component lambda table.
    """
    payload = StepPayload.model_validate(params)
    if payload.sim is None:
        raise ValueError("[prepare_fe] Missing simulation configuration in payload.")
    sim = payload.sim
    partition = payload.get("partition") or payload.get("queue") or "normal"
    phase_name = payload.get("phase_name") or "prepare_fe"
    # Components given in the payload take precedence over the sim config.
    components = list(
        payload.get("components") or getattr(sim, "components", []) or []
    )
    if not components:
        raise ValueError("No components specified in sim config.")

    ligand = system.meta.get("ligand")
    residue_name = system.meta.get("residue_name")
    if not ligand or not residue_name:
        raise ValueError("System meta must include 'ligand' and 'residue_name'.")

    child_root = system.root
    system_root = _system_root_for(child_root)
    param_dir_dict = _load_param_dir_dict(system_root)

    comp_windows: dict = payload.get("component_lambdas") or sim.component_lambdas  # type: ignore[attr-defined]
    sys_params = payload.sys_params or SystemParams()
    extra_restraints: Optional[str] = sys_params.get("extra_restraints", None)
    extra_restraint_fc = float(sys_params.get("extra_restraint_fc", 10.0))
    extra_conformation_restraints: Optional[Path] = sys_params.get(
        "extra_conformation_restraints", None
    )
    # Forward only the pair-related metadata that is actually set.
    pair_meta = {
        key: system.meta.get(key)
        for key in (
            "pair_id",
            "ligand_ref",
            "ligand_alt",
            "residue_ref",
            "residue_alt",
            "input_ref",
            "input_alt",
            "atom_mapper",
            "atom_mapper_options",
            "kartograf_options",
            "lomap_options",
        )
        if system.meta.get(key) is not None
    }

    infe = bool(sim.infe)
    artifacts: Dict[str, Any] = {}
    logger.debug(
        f"[{phase_name}] start ligand={ligand} residue={residue_name} components={components}"
    )

    # Patch sim for pre-prepare cases (e.g., RBFE pre_prepare_fe uses z)
    if "z" in components:
        sim = _sim_with_z_defaults(sim)

    # Build per component (scaffold / templates only; win=-1)
    for comp in components:
        workdir = child_root / "fe" / comp
        workdir.mkdir(parents=True, exist_ok=True)
        logger.debug(f"[{phase_name}] building component '{comp}' in {workdir}")
        builder = AlchemicalFEBuilder(
            ligand=ligand,
            residue_name=residue_name,
            param_dir_dict=param_dir_dict,
            sim_config=sim,
            component=comp,
            component_windows=comp_windows[comp],
            working_dir=workdir,
            system_root=system_root,
            infe=infe,
            win=-1,
            extra={
                "extra_restraints": extra_restraints,
                "extra_restraint_fc": extra_restraint_fc,
                "extra_conformation_restraints": extra_conformation_restraints,
                "partition": partition,
                **pair_meta,
            },
        )
        builder.build()  # will create <comp>-1, amber templates, run files, etc.
        artifacts[f"{comp}_workdir"] = str(workdir)

    # emit the common OK marker used by the orchestrator
    marker = child_root / "fe" / f"{phase_name}.ok"
    marker.parent.mkdir(parents=True, exist_ok=True)
    marker.write_text("ok\n")
    logger.debug(f"[{phase_name}] finished ligand={ligand} → {marker}")

    marker_rel = marker.relative_to(system.root).as_posix()
    register_phase_state(
        system.root,
        phase_name,
        required=[[marker_rel]],
        success=[[marker_rel]],
    )
    return ExecResult(job_ids=[], artifacts={"prepare_fe_ok": marker, **artifacts})
# -----------------------------
# prepare_fe_windows (expand per-lambda windows)
# -----------------------------
def prepare_fe_windows_handler(
    step: Step, system: SimSystem, params: Dict[str, Any]
) -> ExecResult:
    """
    Expand FE windows for each component listed in the sim configuration.

    For every component this handler

    - runs one :class:`AlchemicalFEBuilder` build per lambda value, passing the
      0-based window index as ``win`` (the scaffold build of ``prepare_fe``
      uses ``win=-1``),
    - prepares the batch and REMD inputs of the component,

    and finally writes ``fe/artifacts/windows.json`` (window count and lambda
    schedule per component) and the ``fe/prepare_fe_windows.ok`` marker, and
    registers both ``prepare_fe.ok`` and ``prepare_fe_windows.ok`` as the
    required/success markers of the ``prepare_fe`` phase.

    Parameters
    ----------
    step : Step
        Pipeline metadata (unused).
    system : SimSystem
        Simulation system descriptor.
    params : dict
        Handler payload validated into :class:`StepPayload`.

    Returns
    -------
    ExecResult
        Carries the path of ``windows.json`` under ``windows_json``.

    Raises
    ------
    ValueError
        If the payload has no simulation configuration or ``system.meta``
        lacks ``ligand`` / ``residue_name``.
    RuntimeError
        If the sim configuration lists no components, or the ligand parameter
        index contains no usable entry.
    KeyError
        If a component has no entry in the component lambda table.
    """
    payload = StepPayload.model_validate(params)
    if payload.sim is None:
        raise ValueError(
            "[prepare_fe_windows] Missing simulation configuration in payload."
        )
    sim = payload.sim
    partition = payload.get("partition") or payload.get("queue") or "normal"
    components = list(getattr(sim, "components", []) or [])
    if not components:
        raise RuntimeError(
            "No components specified in sim config for FE window preparation."
        )

    ligand = system.meta.get("ligand")
    residue_name = system.meta.get("residue_name")
    if not ligand or not residue_name:
        raise ValueError("System meta must include 'ligand' and 'residue_name'.")

    child_root = system.root
    system_root = _system_root_for(child_root)
    param_dir_dict = _load_param_dir_dict(system_root)

    comp_windows: dict = payload.get("component_lambdas") or sim.component_lambdas  # type: ignore[attr-defined]
    sys_params = payload.sys_params or SystemParams()
    extra_restraints: Optional[str] = sys_params.get("extra_restraints", None)
    extra_restraint_fc = float(sys_params.get("extra_restraint_fc", 10.0))
    extra_conformation_restraints: Optional[Path] = sys_params.get(
        "extra_conformation_restraints", None
    )
    # Forward only the pair-related metadata that is actually set.
    pair_meta = {
        key: system.meta.get(key)
        for key in (
            "pair_id",
            "ligand_ref",
            "ligand_alt",
            "residue_ref",
            "residue_alt",
            "input_ref",
            "input_alt",
            "atom_mapper",
            "atom_mapper_options",
            "kartograf_options",
            "lomap_options",
        )
        if system.meta.get(key) is not None
    }

    # Conformation restraints switch NFE on and take precedence over plain
    # extra restraints when choosing the barostat.
    use_conformation_restraints = extra_conformation_restraints is not None
    infe = use_conformation_restraints
    if extra_restraints is not None or use_conformation_restraints:
        # Patch a private copy so the payload's sim config is left untouched
        # (same approach as the sim patching in prepare_fe_handler).
        sim = sim.model_copy(deep=True)
        # cannot do NFE with barostat 1 (Berendsen)
        sim.barostat = "2" if use_conformation_restraints else "1"

    builder_extra = {
        "extra_restraints": extra_restraints,
        "extra_restraint_fc": extra_restraint_fc,
        "extra_conformation_restraints": extra_conformation_restraints,
        "partition": partition,
        **pair_meta,
    }

    windows_summary: Dict[str, Any] = {}
    logger.debug(
        f"[prepare_fe_windows] start ligand={ligand} residue={residue_name} components={components}"
    )

    for comp in components:
        workdir = child_root / "fe" / comp
        lambdas = comp_windows[comp]
        n_windows = len(lambdas)
        for win_idx, _ in enumerate(lambdas):
            logger.debug(
                f"[prepare_fe_windows] {comp} → creating window {win_idx} in {workdir}"
            )
            builder = AlchemicalFEBuilder(
                ligand=ligand,
                residue_name=residue_name,
                param_dir_dict=param_dir_dict,
                sim_config=sim,
                component=comp,
                infe=infe,
                component_windows=lambdas,
                working_dir=workdir,
                system_root=system_root,
                win=win_idx,
                # fresh dict per builder so instances never share state
                extra=dict(builder_extra),
            )
            builder.build()

        windows_summary[comp] = {"n_windows": n_windows, "lambdas": lambdas}

        # Always write REMD inputs; run.remd controls whether they are submitted.
        batch_ops.prepare_batch_component(
            workdir,
            comp=comp,
            n_windows=n_windows,
            hmr=str(sim.hmr).lower() == "yes",
        )
        remd_ops.prepare_remd_component(
            workdir,
            comp=comp,
            sim=sim,
            n_windows=n_windows,
            partition=partition,
        )

    # write a canonical windows.json under fe/artifacts/
    windows_json = child_root / "fe" / "artifacts" / "windows.json"
    windows_json.parent.mkdir(parents=True, exist_ok=True)
    windows_json.write_text(json.dumps(windows_summary, indent=2) + "\n")

    # Empty marker file signalling that window expansion finished.
    prepare_finished = child_root / "fe" / "prepare_fe_windows.ok"
    prepare_finished.write_text("")

    windows_rel = prepare_finished.relative_to(system.root).as_posix()
    prepare_rel = (
        (child_root / "fe" / "prepare_fe.ok").relative_to(system.root).as_posix()
    )
    register_phase_state(
        system.root,
        "prepare_fe",
        required=[[prepare_rel, windows_rel]],
        success=[[prepare_rel, windows_rel]],
    )
    logger.debug(f"[prepare_fe_windows] finished ligand={ligand} → {windows_json}")
    return ExecResult(job_ids=[], artifacts={"windows_json": windows_json})