"""Run post-processing analysis on free-energy simulations."""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any, Dict, List, Tuple, Optional
from loguru import logger
from batter.analysis.analysis import analyze_lig_task
from batter.orchestrate.state_registry import register_phase_state
from batter.pipeline.payloads import StepPayload
from batter.pipeline.step import ExecResult, Step
from batter.systems.core import SimSystem
from batter.utils import components_under
def _production_window_indices(fe_root: Path, comp: str) -> List[int]:
    """
    Return the sorted, unique window indices N found under ``fe_root/comp``.

    A window is a sub-directory named ``<comp><digits>`` (zero-padded names
    such as ``z00`` are accepted and map to ``0``). The equilibration
    directory ``<comp>-1`` is skipped because its suffix is not a plain run
    of digits.

    Parameters
    ----------
    fe_root : Path
        Directory that contains one sub-directory per component.
    comp : str
        Component name; also the prefix of every window directory.

    Returns
    -------
    list of int
        Ascending window indices without duplicates. Empty when
        ``fe_root/comp`` is missing or is not a directory.
    """
    base = fe_root / comp
    # ``is_dir`` (not ``exists``) so a stray regular file named like the
    # component yields "no windows" instead of raising from ``iterdir``.
    if not base.is_dir():
        return []

    found = set()
    for entry in base.iterdir():
        name = entry.name
        if not name.startswith(comp) or not entry.is_dir():
            continue
        tail = name[len(comp):]  # e.g. "0", "01", "12"
        # Accept plain ASCII digits only. ``int()`` alone would also accept
        # "+3", "1_0" or " 7", none of which are real window directory names.
        if tail.isascii() and tail.isdigit():
            found.add(int(tail))
    return sorted(found)
def _infer_component_windows_dict(fe_root: Path, components: List[str]) -> Dict[str, List[int]]:
    """
    Rebuild the per-component window lists by scanning the directory tree.

    Every entry of ``components`` becomes a key; its value is whatever
    ``_production_window_indices`` finds for it under ``fe_root`` (an empty
    list when the component has no windows).
    """
    return {
        name: _production_window_indices(fe_root, name)
        for name in components
    }
[docs]
def analyze_handler(step: Step, system: SimSystem, params: Dict[str, Any]) -> ExecResult:
    """Run FE analysis for a ligand rooted at ``<system.root>/fe``.

    Settings are resolved in layers, later layers winning: built-in defaults,
    then ``payload.sim`` (when present), then explicit top-level payload keys.
    For systems whose ``meta["mode"]`` is ``RBFE`` the component list is
    narrowed to ``["x"]`` when ``fe/x`` exists and the Rocklin correction is
    turned off.

    Side effects: writes the marker file ``fe/analyze.ok``, registers the
    ``analyze`` phase via ``register_phase_state`` and, for RBFE pairs with a
    ``Results.dat``, writes ``fe/Results/rbfe_pair_summary.json``.

    Parameters
    ----------
    step : Step
        Pipeline metadata (unused).
    system : SimSystem
        Simulation system descriptor.
    params : dict
        Handler payload validated into :class:`StepPayload`.

    Returns
    -------
    ExecResult
        Mapping with the generated ``Results.dat`` and optional timeseries
        artefacts.

    Raises
    ------
    FileNotFoundError
        If ``<system.root>/fe`` does not exist.
    Exception
        Anything raised by ``analyze_lig_task`` is logged and re-raised.
    """
    lig = system.meta.get("ligand")
    mol = system.meta.get("residue_name")
    payload = StepPayload.model_validate(params)
    sim_cfg = payload.sim

    fe_root = system.root / "fe"
    if not fe_root.exists():
        raise FileNotFoundError(f"[analyze:{lig}] Missing FE folder: {fe_root}")
    is_rbfe_pair = str(system.meta.get("mode", "")).upper() == "RBFE"

    # --- layer 1: built-in defaults ------------------------------------
    default_components = components_under(fe_root)
    components: List[str] = list(default_components)
    temperature: float = 298.15
    water_model: str = "tip3p"
    rocklin_correction: bool = False
    n_workers_override = payload.get("analysis_n_workers", None)
    n_workers: int = int(n_workers_override) if n_workers_override is not None else 4
    rest: Tuple[str, ...] = tuple()
    sim_start_step: Optional[int] = None
    sim_n_bootstraps: int = 0
    sim_dt: float = 0.0
    sim_ntwx: int = 0

    # --- layer 2: simulation config attached to the payload -------------
    if sim_cfg is not None:
        if sim_cfg.components:
            components = list(sim_cfg.components)
        temperature = float(sim_cfg.temperature)
        water_model = str(sim_cfg.water_model).lower()
        rocklin_correction = bool(sim_cfg.rocklin_correction)
        rest = tuple(sim_cfg.rest)
        sim_start_step = int(getattr(sim_cfg, "analysis_start_step", 0))
        sim_n_bootstraps = int(getattr(sim_cfg, "n_bootstraps", 0))
        sim_dt = float(getattr(sim_cfg, "dt", 0.0))
        sim_ntwx = int(getattr(sim_cfg, "ntwx", 0))

    # --- layer 3: explicit payload overrides ----------------------------
    components = list(payload.get("components", components))
    temperature = float(payload.get("temperature", temperature))
    water_model = str(payload.get("water_model", water_model)).lower()
    rocklin_correction = bool(payload.get("rocklin_correction", rocklin_correction))
    n_workers = int(payload.get("n_workers", n_workers))

    # RBFE pair analysis is currently x-component only.
    if is_rbfe_pair:
        if (fe_root / "x").exists():
            components = ["x"]
        # NOTE(review): Rocklin correction is disabled for every RBFE pair,
        # not only when the "x" folder exists — confirm this is intended.
        rocklin_correction = False

    # Optional: analysis start step override; else use config default
    sim_start_step = int(payload.get("analysis_start_step", sim_start_step or 0))
    sim_n_bootstraps = int(payload.get("n_bootstraps", sim_n_bootstraps))
    sim_dt = float(payload.get("dt", sim_dt))
    sim_ntwx = int(payload.get("ntwx", sim_ntwx))

    # Try to reconstruct windows per component if the pipeline didn't inject it
    component_windows_dict = payload.get("component_windows_dict")
    if not component_windows_dict:
        component_windows_dict = _infer_component_windows_dict(fe_root, components)

    logger.debug(f"[analyze:{lig}] Starting FE analysis "
                 f"(components={components}, T={temperature}K, rocklin={rocklin_correction}, mol={mol})")
    try:
        analyze_lig_task(
            lig_path=fe_root,
            lig=lig,
            components=components,
            rest=rest,
            temperature=temperature,
            water_model=water_model,
            component_windows_dict=component_windows_dict,
            rocklin_correction=rocklin_correction,
            analysis_start_step=sim_start_step,
            raise_on_error=True,
            mol=mol,
            n_workers=n_workers,
            n_bootstraps=sim_n_bootstraps,
            dt=sim_dt,
            ntwx=sim_ntwx,
        )
    except Exception as e:
        logger.error(f"[analyze:{lig}] Analysis failed: {e}")
        raise

    # Collect artifacts. ``res_file`` is bound unconditionally so the RBFE
    # check below can never hit an unbound local when ``Results`` is absent.
    results_dir = fe_root / "Results"
    res_file = results_dir / "Results.dat"
    arts: Dict[str, Path] = {}
    optional_outputs = (
        ("results_dat", res_file),
        ("fe_timeseries_json", results_dir / "fe_timeseries.json"),
        ("fe_timeseries_png", results_dir / "fe_timeseries.png"),
    )
    for key, candidate in optional_outputs:
        if candidate.exists():
            arts[key] = candidate

    if is_rbfe_pair and res_file.exists():
        summary_path = results_dir / "rbfe_pair_summary.json"
        total_fe, total_se = _parse_results_total(res_file)
        summary = {
            "pair_id": system.meta.get("pair_id", lig),
            "ligand_ref": system.meta.get("ligand_ref"),
            "ligand_alt": system.meta.get("ligand_alt"),
            "total_dg_kcal_mol": total_fe,
            "total_se_kcal_mol": total_se,
            "components": components,
        }
        summary_path.write_text(json.dumps(summary, indent=2) + "\n")
        arts["rbfe_pair_summary_json"] = summary_path

    # Empty marker file signalling that the analysis phase completed.
    analyzed_finished = fe_root / "analyze.ok"
    analyzed_finished.write_text("")

    analyze_rel = analyzed_finished.relative_to(system.root).as_posix()
    results_rel = res_file.relative_to(system.root).as_posix()
    register_phase_state(
        system.root,
        "analyze",
        required=[[analyze_rel, results_rel]],
        success=[[analyze_rel, results_rel]],
    )
    logger.debug(f"[analyze:{lig}] FE analysis done. Artifacts: {', '.join(p.name for p in arts.values()) or 'none'}")
    return ExecResult(job_ids=[], artifacts=arts)


def _parse_results_total(res_file: Path) -> Tuple[Optional[float], Optional[float]]:
    """Return ``(total_fe, total_se)`` from the first "total" row of *res_file*.

    The first whitespace-separated row with at least three fields whose
    leading field equals ``total`` (case-insensitive) is used; later rows are
    ignored. Parsing is best-effort: a value that cannot be converted to
    ``float`` (and any value after it) stays ``None`` and a warning is logged.
    ``(None, None)`` is returned when no such row exists.
    """
    total_fe: Optional[float] = None
    total_se: Optional[float] = None
    for raw in res_file.read_text().splitlines():
        # ``split()`` already treats tabs as whitespace and drops empty fields.
        fields = raw.split()
        if len(fields) < 3 or fields[0].lower() != "total":
            continue
        try:
            total_fe = float(fields[1])
            total_se = float(fields[2])
        except ValueError:
            logger.warning(f"[analyze] Could not parse total row in {res_file}: {raw!r}")
        break
    return total_fe, total_se