# Source code for batter.exec.handlers.fe_analysis

"""Run post-processing analysis on free-energy simulations."""

from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Dict, List, Tuple, Optional

from loguru import logger

from batter.analysis.analysis import analyze_lig_task
from batter.orchestrate.state_registry import register_phase_state
from batter.pipeline.payloads import StepPayload
from batter.pipeline.step import ExecResult, Step
from batter.systems.core import SimSystem
from batter.utils import components_under


def _production_window_indices(fe_root: Path, comp: str) -> List[int]:
    """
    Return the sorted production-window indices found under ``fe_root/comp``.

    A production window is a sub-directory named ``<comp><N>`` where ``N`` is a
    plain run of ASCII decimal digits (``N >= 0``). The equilibration directory
    ``<comp>-1`` is intentionally skipped.

    Parameters
    ----------
    fe_root : Path
        Directory that holds one sub-directory per component.
    comp : str
        Component name; also the prefix of every window directory.

    Returns
    -------
    list of int
        Window indices in ascending order. Empty when ``fe_root/comp`` is
        missing or is not a directory.
    """
    base = fe_root / comp
    # ``is_dir`` rather than ``exists``: a stray regular file named like the
    # component would otherwise make ``iterdir`` raise NotADirectoryError.
    if not base.is_dir():
        return []

    indices: List[int] = []
    for entry in base.iterdir():
        if not entry.is_dir():
            continue
        name = entry.name
        if not name.startswith(comp):
            continue
        suffix = name[len(comp):]  # e.g., "0", "1", ...
        # Require plain digits instead of relying on ``int()``, which also
        # accepts "+2", "3_0" and " 1". This check rejects the equilibration
        # suffix "-1" and the empty suffix as well.
        if not (suffix.isascii() and suffix.isdigit()):
            continue
        indices.append(int(suffix))
    return sorted(indices)


def _infer_component_windows_dict(fe_root: Path, components: List[str]) -> Dict[str, List[int]]:
    """
    Rebuild the per-component window lists by scanning the FE directory tree.

    Every requested component gets an entry; a component for which no window
    directories are found maps to an empty list.
    """
    return {
        name: _production_window_indices(fe_root, name)
        for name in components
    }


def _parse_total_row(text: str) -> Tuple[Optional[float], Optional[float]]:
    """Extract the value and error from the first ``total`` row of *text*.

    Parameters
    ----------
    text : str
        Whitespace-separated tabular text (the contents of ``Results.dat``).

    Returns
    -------
    tuple
        ``(value, error)`` taken from columns 2 and 3 of the first row whose
        first field is ``total`` (case-insensitive). A field that cannot be
        parsed as a float, or a missing row, yields ``None``.
    """
    total_fe: Optional[float] = None
    total_se: Optional[float] = None
    for raw in text.splitlines():
        # ``split()`` with no separator splits on tabs and spaces alike and
        # never yields empty fields.
        parts = raw.split()
        if len(parts) >= 3 and parts[0].lower() == "total":
            try:
                total_fe = float(parts[1])
                total_se = float(parts[2])
            except ValueError:
                # Best effort: keep whatever parsed and leave the rest as
                # None, but make the problem visible in the log.
                logger.warning(f"Could not parse numeric values from total row: {raw!r}")
            # Only the first "total" row is considered.
            break
    return total_fe, total_se


def analyze_handler(step: Step, system: SimSystem, params: Dict[str, Any]) -> ExecResult:
    """Run FE analysis for a ligand rooted at ``<system.root>/fe``.

    Settings are resolved in increasing priority: built-in defaults, then the
    ``sim`` configuration carried by the payload (when present), then
    top-level payload keys. For systems whose ``mode`` metadata is ``RBFE``
    the component list is narrowed to ``["x"]`` when that directory exists
    and the Rocklin correction is switched off.

    After a successful analysis the handler writes the ``fe/analyze.ok``
    marker and registers the ``analyze`` phase state.

    Parameters
    ----------
    step : Step
        Pipeline metadata (unused).
    system : SimSystem
        Simulation system descriptor.
    params : dict
        Handler payload validated into :class:`StepPayload`.

    Returns
    -------
    ExecResult
        Mapping with the generated ``Results.dat`` and optional timeseries
        artefacts (plus ``rbfe_pair_summary.json`` for RBFE pairs).

    Raises
    ------
    FileNotFoundError
        If ``<system.root>/fe`` does not exist.
    Exception
        Any error raised by ``analyze_lig_task`` is logged and re-raised.
    """
    lig = system.meta.get("ligand")
    mol = system.meta.get("residue_name")
    payload = StepPayload.model_validate(params)
    sim_cfg = payload.sim

    fe_root = system.root / "fe"
    if not fe_root.exists():
        raise FileNotFoundError(f"[analyze:{lig}] Missing FE folder: {fe_root}")
    is_rbfe_pair = str(system.meta.get("mode", "")).upper() == "RBFE"

    # --- built-in defaults -------------------------------------------------
    default_components = components_under(fe_root)
    components: List[str] = list(default_components)
    temperature: float = 298.15
    water_model: str = "tip3p"
    rocklin_correction: bool = False
    n_workers_override = payload.get("analysis_n_workers", None)
    n_workers: int = int(n_workers_override) if n_workers_override is not None else 4
    rest: Tuple[str, ...] = tuple()
    sim_start_step: Optional[int] = None
    sim_n_bootstraps: int = 0
    sim_dt: float = 0.0
    sim_ntwx: int = 0

    # --- values from the simulation config ---------------------------------
    if sim_cfg is not None:
        if sim_cfg.components:
            components = list(sim_cfg.components)
        temperature = float(sim_cfg.temperature)
        water_model = str(sim_cfg.water_model).lower()
        rocklin_correction = bool(sim_cfg.rocklin_correction)
        rest = tuple(sim_cfg.rest)
        sim_start_step = int(getattr(sim_cfg, "analysis_start_step", 0))
        sim_n_bootstraps = int(getattr(sim_cfg, "n_bootstraps", 0))
        sim_dt = float(getattr(sim_cfg, "dt", 0.0))
        sim_ntwx = int(getattr(sim_cfg, "ntwx", 0))

    # --- top-level payload overrides ---------------------------------------
    components = list(payload.get("components", components))
    temperature = float(payload.get("temperature", temperature))
    water_model = str(payload.get("water_model", water_model)).lower()
    rocklin_correction = bool(payload.get("rocklin_correction", rocklin_correction))
    n_workers = int(payload.get("n_workers", n_workers))

    # RBFE pair analysis is currently x-component only.
    if is_rbfe_pair:
        if (fe_root / "x").exists():
            components = ["x"]
        # NOTE(review): the source this block was recovered from had lost its
        # indentation; the Rocklin reset is assumed to apply to every RBFE
        # pair, not only when the "x" directory exists. Confirm.
        rocklin_correction = False

    # Optional: analysis start step override; else use config default
    sim_start_step = int(payload.get("analysis_start_step", sim_start_step or 0))
    sim_n_bootstraps = int(payload.get("n_bootstraps", sim_n_bootstraps))
    sim_dt = float(payload.get("dt", sim_dt))
    sim_ntwx = int(payload.get("ntwx", sim_ntwx))

    # Try to reconstruct windows per component if the pipeline didn't inject it
    component_windows_dict = payload.get("component_windows_dict")
    if not component_windows_dict:
        component_windows_dict = _infer_component_windows_dict(fe_root, components)

    logger.debug(f"[analyze:{lig}] Starting FE analysis "
                 f"(components={components}, T={temperature}K, rocklin={rocklin_correction}, mol={mol})")

    try:
        analyze_lig_task(
            lig_path=fe_root,
            lig=lig,
            components=components,
            rest=rest,
            temperature=temperature,
            water_model=water_model,
            component_windows_dict=component_windows_dict,
            rocklin_correction=rocklin_correction,
            analysis_start_step=sim_start_step,
            raise_on_error=True,
            mol=mol,
            n_workers=n_workers,
            n_bootstraps=sim_n_bootstraps,
            dt=sim_dt,
            ntwx=sim_ntwx,
        )
    except Exception as e:
        logger.error(f"[analyze:{lig}] Analysis failed: {e}")
        raise

    # Collect artifacts
    results_dir = fe_root / "Results"
    # Bound unconditionally so later uses never depend on which branch ran.
    res_file = results_dir / "Results.dat"
    arts: Dict[str, Path] = {}
    if results_dir.exists():
        if res_file.exists():
            arts["results_dat"] = res_file
        ts_json = results_dir / "fe_timeseries.json"
        if ts_json.exists():
            arts["fe_timeseries_json"] = ts_json
        ts_png = results_dir / "fe_timeseries.png"
        if ts_png.exists():
            arts["fe_timeseries_png"] = ts_png

        if is_rbfe_pair and res_file.exists():
            summary_path = results_dir / "rbfe_pair_summary.json"
            total_fe, total_se = _parse_total_row(res_file.read_text())
            summary = {
                "pair_id": system.meta.get("pair_id", lig),
                "ligand_ref": system.meta.get("ligand_ref"),
                "ligand_alt": system.meta.get("ligand_alt"),
                "total_dg_kcal_mol": total_fe,
                "total_se_kcal_mol": total_se,
                "components": components,
            }
            summary_path.write_text(json.dumps(summary, indent=2) + "\n")
            arts["rbfe_pair_summary_json"] = summary_path

    # Create (or truncate) the completion marker; write_text closes the file
    # handle deterministically.
    analyzed_finished = fe_root / "analyze.ok"
    analyzed_finished.write_text("")

    analyze_rel = analyzed_finished.relative_to(system.root).as_posix()
    results_rel = res_file.relative_to(system.root).as_posix()
    register_phase_state(
        system.root,
        "analyze",
        required=[[analyze_rel, results_rel]],
        success=[[analyze_rel, results_rel]],
    )

    logger.debug(f"[analyze:{lig}] FE analysis done. Artifacts: {', '.join(p.name for p in arts.values()) or 'none'}")
    return ExecResult(job_ids=[], artifacts=arts)