"""Handlers that queue free-energy equilibration and production jobs."""
from __future__ import annotations
import os
import subprocess
import shlex
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional
from loguru import logger
from batter.exec.slurm_mgr import SlurmJobManager, SlurmJobSpec
from batter.orchestrate.state_registry import register_phase_state
from batter.pipeline.payloads import StepPayload
from batter.pipeline.step import ExecResult, Step
from batter.systems.core import SimSystem
from batter.utils import components_under
from batter.exec.handlers.batch import render_batch_slurm_script, _tpl_path
from textwrap import dedent
from batter._internal.templates import RUN_FILES_DIR as RUN_FILES_ORIG
def _equil_window_dir(root: Path, comp: str) -> Path:
"""Return the equilibration window directory for ``comp``."""
return root / "fe" / comp / f"{comp}-1"
def _production_window_dirs(root: Path, comp: str) -> List[Path]:
"""Return production window directories for ``comp``."""
base = root / "fe" / comp
if not base.exists():
return []
out: List[Path] = []
for p in sorted(base.iterdir()):
if not p.is_dir():
continue
if p.name == f"{comp}-1":
continue
if p.name.startswith(comp):
tail = p.name[len(comp) :]
if tail and tail.lstrip("-").isdigit():
out.append(p)
return out
def _spec_from_dir(
workdir: Path,
*,
finished_name: str,
job_name: str,
stage: str | None,
failed_name: str = "FAILED",
script_rel: str = "SLURMM-run",
header_name: str | None = None,
header_template: Path | None = None,
header_root: Path | None = None,
extra_env: Optional[Dict[str, str]] = None,
batch_script: Path | None = None,
submit_dir: Path | None = None,
) -> SlurmJobSpec:
"""Build a :class:`SlurmJobSpec` for ``workdir``."""
script_name = batch_script.name if batch_script else script_rel
return SlurmJobSpec(
workdir=workdir,
script_rel=script_name,
finished_name=finished_name,
failed_name=failed_name,
name=job_name,
header_name=header_name,
header_template=header_template,
header_root=header_root,
stage=stage,
extra_env=extra_env or {},
batch_script=batch_script,
submit_dir=submit_dir,
)
[docs]
def fe_equil_handler(
step: Step, system: SimSystem, params: Dict[str, Any]
) -> ExecResult:
"""Queue equilibration jobs for each component of a ligand.
Parameters
----------
step, system : ignored
Included for parity with the handler signature.
params : dict
Handler payload containing the job manager and configuration values.
Returns
-------
ExecResult
Number of jobs enqueued (without waiting for completion).
"""
payload = StepPayload.model_validate(params)
lig = system.meta.get("ligand", system.name)
max_jobs = int(payload.get("max_active_jobs", 500))
stage = payload.get("job_stage") or "fe_equil"
phase_name = payload.get("phase_name") or "fe_equil"
job_mgr = payload.get("job_mgr")
if not isinstance(job_mgr, SlurmJobManager):
raise ValueError(
"[fe_equil] payload['job_mgr'] must be an instance of SlurmJobManager."
)
comps = components_under(system.root)
if not comps:
raise FileNotFoundError(
f"[fe_equil:{lig}] No components found under {system.root/'fe'}"
)
# quota enforced inside the job manager before each add
register_phase_state(
system.root,
phase_name,
required=[
["fe/{comp}/{comp}-1/EQ_FINISHED"],
["fe/{comp}/{comp}-1/FAILED"],
],
success=[["fe/{comp}/{comp}-1/EQ_FINISHED"]],
failure=[["fe/{comp}/{comp}-1/FAILED"]],
)
count = 0
for comp in comps:
wd = _equil_window_dir(system.root, comp)
if not wd.exists():
logger.warning(
f"[fe_equil:{lig}] missing equil window dir: {wd} — skipping"
)
continue
env = {"ONLY_EQ": "1", "INPCRD": "full.inpcrd"}
extra_env = payload.get("extra_env") or {}
for key, value in extra_env.items():
env[str(key)] = str(value)
job_name = f"fep_{os.path.abspath(system.root)}_{comp}_fe_equil"
spec = _spec_from_dir(
wd,
finished_name="EQ_FINISHED",
job_name=job_name,
stage=stage,
header_name="SLURMM-Am.header",
header_template=RUN_FILES_ORIG / "SLURMM-Am.header",
header_root=Path(getattr(payload.get("sim"), "slurm_header_dir", Path.home() / ".batter"))
if payload.get("sim")
else None,
extra_env=env,
)
job_mgr.add(spec)
count += 1
if count == 0:
raise RuntimeError(f"[fe_equil:{lig}] No component equil windows to submit.")
logger.debug(f"[fe_equil:{lig}] enqueued {count} component equil job(s).")
# Don’t claim success/terminal state; we’re not waiting here.
return ExecResult(job_ids=[], artifacts={"count": count})
[docs]
def fe_handler(step: Step, system: SimSystem, params: Dict[str, Any]) -> ExecResult:
"""Queue production jobs for each component/window combination.
Parameters
----------
step, system : ignored
Provided for handler API compatibility.
params : dict
Handler payload containing the job manager and configuration values.
Returns
-------
ExecResult
Number of jobs enqueued (without waiting for completion).
"""
payload = StepPayload.model_validate(params)
lig = system.meta.get("ligand", system.name)
max_jobs = int(payload.get("max_active_jobs", 500))
stage = payload.get("job_stage") or "fe"
remd_enabled = False
if payload.sim is not None:
remd_enabled = str(getattr(payload.sim, "remd", "no")).lower() == "yes"
batch_mode = bool(payload.get("batch_mode"))
job_mgr = payload.get("job_mgr")
if not isinstance(job_mgr, SlurmJobManager):
raise ValueError(
"[fe] payload['job_mgr'] must be an instance of SlurmJobManager."
)
comps = components_under(system.root)
if not comps:
raise FileNotFoundError(
f"[fe:{lig}] No components found under {system.root/'fe'}"
)
if remd_enabled:
register_phase_state(
system.root,
"fe",
required=[["fe/{comp}/FINISHED"], ["fe/{comp}/FAILED"]],
success=[["fe/{comp}/FINISHED"]],
failure=[["fe/{comp}/FAILED"]],
)
else:
register_phase_state(
system.root,
"fe",
required=[
["fe/{comp}/{comp}{win:02d}/FINISHED"],
["fe/{comp}/{comp}{win:02d}/FAILED"],
],
success=[["fe/{comp}/{comp}{win:02d}/FINISHED"]],
failure=[["fe/{comp}/{comp}{win:02d}/FAILED"]],
)
count = 0
if batch_mode and not remd_enabled:
run_root = system.root.parent.parent if system.root.name else system.root
batch_root = payload.get("batch_run_root") or (run_root / "batch_run")
lig = system.meta.get("ligand", system.name)
safe_lig = lig.replace("/", "_")
lig_batch_dir = batch_root / safe_lig
batch_gpus = payload.get("batch_gpus")
helper = _write_ligand_fe_batch_runner(
system_root=system.root,
helper_root=lig_batch_dir,
ligand=lig,
batch_gpus=batch_gpus,
gpus_per_task=payload.get("batch_gpus_per_task") or 1,
)
extra_sbatch: list[str] = []
if batch_gpus:
extra_sbatch += ["--gres", f"gpu:{batch_gpus}"]
batch_script = render_batch_slurm_script(
batch_root=lig_batch_dir,
target_dir=lig_batch_dir,
run_script=helper.name,
env=None,
system_name=getattr(payload.get("sim"), "system_name", system.name),
stage="fe",
pose=safe_lig,
header_root=getattr(payload.get("sim"), "slurm_header_dir", None),
)
spec = SlurmJobSpec(
workdir=lig_batch_dir,
script_rel=batch_script.name,
finished_name=f"fe_{safe_lig}.FINISHED",
failed_name=f"fe_{safe_lig}.FAILED",
name=f"fe_{safe_lig}",
stage=stage,
batch_script=batch_script,
submit_dir=lig_batch_dir,
header_name="SLURMM-BATCH.header",
header_template=_tpl_path("SLURMM-BATCH.header"),
header_root=Path(getattr(payload.get("sim"), "slurm_header_dir", Path.home() / ".batter"))
if payload.get("sim")
else None,
extra_sbatch=extra_sbatch,
)
job_mgr.add(spec)
return ExecResult(job_ids=[], artifacts={"batch_run": batch_root})
for comp in comps:
if remd_enabled:
comp_dir = system.root / "fe" / comp
job_name = f"fep_{os.path.abspath(system.root)}_{comp}_remd"
spec = SlurmJobSpec(
workdir=comp_dir,
script_rel="SLURMM-BATCH-remd",
finished_name="FINISHED",
failed_name="FAILED",
name=job_name,
stage=stage,
header_name="SLURMM-BATCH-remd.header",
header_template=Path(__file__).resolve().parent.parent
/ "_internal"
/ "templates"
/ "remd_run_files"
/ "SLURMM-BATCH-remd.header",
header_root=Path(getattr(payload.get("sim"), "slurm_header_dir", Path.home() / ".batter"))
if payload.get("sim")
else None,
)
job_mgr.add(spec)
count += 1
continue
for wd in _production_window_dirs(system.root, comp):
env = {"INPCRD": f"../{comp}-1/eqnpt04.rst7"}
job_name = f"fep_{os.path.abspath(system.root)}_{comp}_{wd.name}_fe"
batch_script = None
if batch_mode:
batch_root = payload.get("batch_run_root") or (
system.root.parent.parent / "batch_run"
)
batch_script = render_batch_slurm_script(
batch_root=batch_root,
target_dir=wd,
run_script="run-local.bash",
env=env,
system_name=getattr(payload.get("sim"), "system_name", system.name),
stage="fe",
pose=f"{system.meta.get('ligand', system.name)}_{comp}_{wd.name}",
header_root=getattr(payload.get("sim"), "slurm_header_dir", None),
)
spec = _spec_from_dir(
wd,
finished_name="FINISHED",
job_name=job_name,
stage=stage,
header_name="SLURMM-Am.header",
header_template=RUN_FILES_ORIG / "SLURMM-Am.header",
header_root=Path(getattr(payload.get("sim"), "slurm_header_dir", Path.home() / ".batter"))
if payload.get("sim")
else None,
extra_env=env,
batch_script=batch_script,
submit_dir=batch_script.parent if batch_script else None,
)
job_mgr.add(spec)
count += 1
if count == 0:
raise RuntimeError(f"[fe:{lig}] No production windows to submit.")
logger.debug(f"[fe:{lig}] enqueued {count} production job(s).")
# Don’t claim success/terminal state; we’re not waiting here.
return ExecResult(job_ids=[], artifacts={"count": count})
def _write_ligand_fe_batch_runner(
*,
system_root: Path,
helper_root: Path,
ligand: str,
batch_gpus: int | None = None,
gpus_per_task: int = 1,
) -> Path:
"""Render a helper that launches all production windows for a single ligand in parallel."""
helper_root.mkdir(parents=True, exist_ok=True)
safe_lig = ligand.replace("/", "_")
helper = helper_root / f"run_fe_{safe_lig}.sh"
gpus_per_task = max(1, int(gpus_per_task))
gpu_line = (
f'TOTAL_GPUS="{batch_gpus}"'
if batch_gpus
else 'TOTAL_GPUS="${SLURM_GPUS_ON_NODE:-1}"'
)
text = (
dedent(
f"""
#!/usr/bin/env bash
set -euo pipefail
{gpu_line}
GPUS_PER_TASK={gpus_per_task}
if [[ -z "$TOTAL_GPUS" ]]; then
if [[ -n "${{SLURM_GPUS:-}}" ]]; then TOTAL_GPUS="${{SLURM_GPUS}}"; else TOTAL_GPUS="1"; fi
fi
slots=$((TOTAL_GPUS / GPUS_PER_TASK))
if [[ $slots -lt 1 ]]; then slots=1; fi
status=0
declare -a pids=()
running=0
for w in "{(system_root / 'fe').as_posix()}"/*/*; do
comp_dir=$(dirname "$w")
comp=$(basename "$comp_dir")
base=$(basename "$w")
if [[ "$base" == "$comp-1" ]]; then
continue
fi
if [[ -x "$w/run-local.bash" ]]; then
echo "[batter-batch] fe running $w"
(
cd "$w"
srun -N 1 -n 1 --gpus-per-task $GPUS_PER_TASK /bin/bash run-local.bash
) &
pids+=($!)
running=$((running + 1))
if [[ $running -ge $slots ]]; then
if wait -n; then :; else status=$?; fi
running=$((running - 1))
fi
fi
done
for pid in "${{pids[@]:-}}"; do
if wait "$pid"; then :; else status=$?; fi
done
if [[ $status -eq 0 ]]; then
touch "{(helper_root / f'fe_{safe_lig}.FINISHED').as_posix()}"
else
touch "{(helper_root / f'fe_{safe_lig}.FAILED').as_posix()}"
fi
exit $status
"""
).strip()
+ "\n"
)
helper.write_text(text)
try:
helper.chmod(0o755)
except Exception:
pass
return helper