"""
batter.orchestrate.run
======================
Top-level orchestration entry for BATTER runs.
This module wires:
YAML (RunConfig) → shared system build → bulk ligand staging →
single param job ("param_ligands") → per-ligand pipelines → FE record save.
"""
from __future__ import annotations
from datetime import datetime, timezone
import json
import smtplib
import shutil
from pathlib import Path
from typing import Any, Dict, List, Literal
from smtplib import SMTPException
import pandas as pd
import yaml
from loguru import logger
from batter.config.run import RunConfig
from batter.systems.core import SimSystem
from batter.exec.local import LocalBackend
from batter.exec.slurm import SlurmBackend
from batter.pipeline.pipeline import Pipeline
from batter.pipeline.step import Step
from batter.pipeline.payloads import StepPayload
from batter.runtime.portable import ArtifactStore
from batter.runtime.fe_repo import FEResultsRepository
from batter.exec.slurm_mgr import SlurmJobManager
from batter.orchestrate.backend import register_local_handlers
from batter.orchestrate.ligands import (
discover_staged_ligands,
resolve_ligand_map,
)
from batter.orchestrate.markers import (
handle_phase_failures,
partition_children_by_status,
run_phase_skipping_done,
is_done,
)
from batter.orchestrate.pipeline_utils import select_pipeline
from batter.orchestrate.results_io import (
extract_ligand_metadata,
save_fe_records,
)
from batter.orchestrate.run_support import (
compute_run_signature as _compute_run_signature,
generate_run_id,
ligand_names_path as _ligand_names_path,
load_stored_ligand_names as _load_stored_ligand_names,
payload_path as _payload_path,
resolve_signature_conflict as _resolve_signature_conflict,
select_run_id,
select_system_builder as _select_system_builder,
stored_payload as _stored_payload,
stored_signature as _stored_signature,
store_ligand_names as _store_ligand_names,
)
# Steps executed once at the execution root (run_dir) by a run-scoped
# SimSystem rather than inside each per-ligand pipeline; they are stripped
# from the per-ligand clone of the pipeline template.
_PARENT_ONLY_STEP_NAMES = frozenset({"system_prep", "system_prep_asfe", "param_ligands"})
# Orchestration phase name -> pipeline step names belonging to that phase.
# Used to carve the per-ligand pipeline into sequential phases (most phases
# map to a single step; "prepare_fe" also covers "prepare_fe_windows").
_PHASE_STEP_NAMES: dict[str, frozenset[str]] = {
    "prepare_equil": frozenset({"prepare_equil"}),
    "equil": frozenset({"equil"}),
    "equil_analysis": frozenset({"equil_analysis"}),
    "pre_prepare_fe": frozenset({"pre_prepare_fe"}),
    "pre_fe_equil": frozenset({"pre_fe_equil"}),
    "prepare_fe": frozenset({"prepare_fe", "prepare_fe_windows"}),
    "fe_equil": frozenset({"fe_equil"}),
    "fe": frozenset({"fe"}),
    "analyze": frozenset({"analyze"}),
}
def _options_dict(value: Any) -> Dict[str, Any]:
    """Coerce an options container into a plain ``dict``.

    * ``None`` becomes ``{}``.
    * Objects exposing ``model_dump`` (pydantic-style models) are dumped in
      JSON mode, excluding ``None`` and unset fields.
    * Plain dicts keep only non-``None`` values, with keys converted to ``str``.
    * Anything else is passed to ``dict(...)`` unchanged.
    """
    if value is None:
        return {}
    if hasattr(value, "model_dump"):
        dumped = value.model_dump(mode="json", exclude_none=True, exclude_unset=True)
        return dict(dumped)
    if not isinstance(value, dict):
        return dict(value)
    filtered: Dict[str, Any] = {}
    for key, val in value.items():
        if val is None:
            continue
        filtered[str(key)] = val
    return filtered
def _rbfe_mapper_options(rbfe_cfg) -> Dict[str, Dict[str, Any]]:
    """Collect per-atom-mapper option dicts from an RBFE config.

    Always returns the keys ``"kartograf"`` and ``"lomap"`` (in that order).
    A missing config, or a missing section on the config, yields ``{}``.
    """
    mapper_names = ("kartograf", "lomap")
    if rbfe_cfg is None:
        return {name: {} for name in mapper_names}
    return {
        name: _options_dict(getattr(rbfe_cfg, name, None)) for name in mapper_names
    }
def _slurm_registry_path(run_dir: Path) -> Path:
    """Return the SLURM job-registry path for ``run_dir``.

    The registry lives at ``artifacts/slurm/queue.jsonl``. Older runs kept it
    at ``.slurm/queue.jsonl``; if only that legacy file exists it is moved into
    the new location, or copied when the move raises.
    """
    registry = run_dir / "artifacts" / "slurm" / "queue.jsonl"
    legacy = run_dir / ".slurm" / "queue.jsonl"
    needs_migration = legacy.exists() and not registry.exists()
    if not needs_migration:
        return registry
    registry.parent.mkdir(parents=True, exist_ok=True)
    try:
        legacy.replace(registry)
    except Exception:
        # Rename can fail (e.g. across filesystems); keep the data by copying.
        shutil.copy2(legacy, registry)
    return registry
def _store_run_yaml_copy(run_dir: Path, yaml_path: Path) -> None:
    """Snapshot the run YAML to ``artifacts/config/run_config.yaml``.

    An existing snapshot is left untouched, so the first YAML stored for a run
    is preserved. A failed copy is logged as a warning instead of raised.
    """
    config_dir = run_dir / "artifacts" / "config"
    config_dir.mkdir(parents=True, exist_ok=True)
    snapshot = config_dir / "run_config.yaml"
    if not snapshot.exists():
        try:
            shutil.copy2(yaml_path, snapshot)
        except Exception as exc:
            logger.warning(f"Could not store run YAML copy at {snapshot}: {exc}")
def _clear_failure_markers(run_dir: Path) -> None:
    """Delete failure/retry markers and progress caches for a run.

    Under ``run_dir/simulations`` every ``FAILED``, ``ATTEMPT_FAILED`` and
    ``job_attempt.txt`` file is removed, as are the ``*.csv`` files inside any
    ``progress`` directory. The ``run_dir/artifacts/progress`` folder is then
    removed. Nothing happens when ``simulations`` does not exist. Individual
    deletion errors are ignored.
    """
    sims = run_dir / "simulations"
    if not sims.exists():
        return

    def _try_unlink(target: Path) -> bool:
        try:
            target.unlink()
        except Exception:
            return False
        return True

    # Generators keep deletion interleaved with the directory walk.
    marker_files = (
        hit
        for marker in ("FAILED", "ATTEMPT_FAILED", "job_attempt.txt")
        for hit in sims.rglob(marker)
    )
    removed = sum(_try_unlink(hit) for hit in marker_files)

    progress_csvs = (
        csv_file
        for progress_dir in sims.rglob("progress")
        if progress_dir.is_dir()
        for csv_file in progress_dir.glob("*.csv")
    )
    removed += sum(_try_unlink(csv_file) for csv_file in progress_csvs)

    if removed:
        logger.info(f"[cleanup] Removed {removed} failure/progress marker(s).")

    cache_dir = run_dir / "artifacts" / "progress"
    if not cache_dir.exists():
        return
    try:
        shutil.rmtree(cache_dir)
        logger.info(f"[cleanup] Removed progress cache folder: {cache_dir}")
    except Exception:
        logger.warning(f"[cleanup] Failed to remove progress cache folder: {cache_dir}")
def _build_per_ligand_pipeline(tpl: Pipeline, sim_cfg_updated: Any) -> Pipeline:
    """Derive the per-ligand pipeline from the run template ``tpl``.

    Steps listed in ``_PARENT_ONLY_STEP_NAMES`` are dropped, together with any
    ``requires`` edges that point at them. Every remaining payload is copied.
    A copied payload that carries a ``sim`` config has it replaced by
    ``sim_cfg_updated``.
    """

    def _clone(step: Step) -> Step:
        new_payload = None if step.payload is None else step.payload.copy_with()
        if new_payload is not None and new_payload.sim is not None:
            new_payload = new_payload.copy_with(sim=sim_cfg_updated)
        kept_requires = [
            dep for dep in step.requires if dep not in _PARENT_ONLY_STEP_NAMES
        ]
        return Step(name=step.name, requires=kept_requires, payload=new_payload)

    return Pipeline(
        [
            _clone(step)
            for step in tpl.ordered_steps()
            if step.name not in _PARENT_ONLY_STEP_NAMES
        ]
    )
def _format_summary_float(value: Any) -> str:
    """Format one FE summary cell for terminal/email tables.

    Missing values (``None``/NaN/``pd.NA``) become ``""``. Numeric-like values
    are rendered with three decimals. Anything that cannot be checked or
    converted falls back to ``str(value)``.
    """
    try:
        missing = pd.isna(value)
        return "" if missing else f"{float(value):.3f}"
    except Exception:
        return str(value)
def _build_run_summary_table(
    repo: FEResultsRepository, run_id: str
) -> str | None:
    """Render the FE results of ``run_id`` as a plain-text table.

    Rows come from ``repo.index()`` filtered to ``run_id``. The columns shown
    are:

    * ``original_name`` (falls back to ``ligand`` when blank)
    * ``ligand`` and ``mol_name``
    * ``total_dG`` and ``total_se``
    * ``status`` and ``failure_reason``

    Missing columns are added empty. Rows are stably ordered by status, then
    name, then ligand. Returns ``None`` when the index cannot be loaded or has
    no rows for this run.
    """
    try:
        index_df = repo.index()
    except Exception as exc:
        logger.warning(f"Could not load FE summary for run '{run_id}': {exc}")
        return None
    if index_df.empty or "run_id" not in index_df.columns:
        return None
    rows = index_df.loc[index_df["run_id"] == run_id].copy()
    if rows.empty:
        return None

    columns = [
        "original_name",
        "ligand",
        "mol_name",
        "total_dG",
        "total_se",
        "status",
        "failure_reason",
    ]
    for name in columns:
        if name not in rows.columns:
            rows[name] = pd.NA

    rows["ligand"] = rows["ligand"].fillna("")
    display_name = rows["original_name"].fillna("")
    rows["original_name"] = display_name.mask(display_name == "", rows["ligand"])

    rows = rows[columns].sort_values(
        ["status", "original_name", "ligand"], na_position="last", kind="stable"
    )
    for name in ("original_name", "ligand", "mol_name", "status", "failure_reason"):
        rows[name] = rows[name].fillna("")
    for name in ("total_dG", "total_se"):
        rows[name] = rows[name].map(_format_summary_float)

    with pd.option_context("display.max_columns", None, "display.width", 160):
        return rows.to_string(index=False)
def _select_phase_pipeline(pipeline: Pipeline, step_names: frozenset[str]) -> Pipeline:
    """Extract the sub-pipeline for one phase.

    Keeps only the steps whose names are in ``step_names``, in the order
    ``pipeline.ordered_steps()`` yields them. Dependency edges that point
    outside the kept set are dropped. Payload objects are reused, not copied.
    """
    kept = [s for s in pipeline.ordered_steps() if s.name in step_names]
    kept_names = frozenset(s.name for s in kept)
    rebuilt: list[Step] = []
    for s in kept:
        local_deps = [dep for dep in s.requires if dep in kept_names]
        rebuilt.append(Step(name=s.name, requires=local_deps, payload=s.payload))
    return Pipeline(rebuilt)
def _run_phase_with_failure_policy(
    phase: Pipeline,
    children: List[SimSystem],
    phase_name: str,
    backend,
    *,
    max_workers: int | None = None,
    on_failure: str | None = None,
    job_mgr: SlurmJobManager | None = None,
    dry_run: bool = False,
    dry_run_message: str | None = None,
) -> tuple[List[SimSystem], bool]:
    """Execute ``phase`` for ``children`` and apply the failure policy.

    If the phase did not finish and a ``job_mgr`` is given, this waits for all
    submitted jobs. In dry-run mode it stops as soon as the manager reports a
    triggered submission.

    With ``on_failure="retry"`` (case-insensitive), failed children are retried
    once within this invocation. Failures that remain are then handled as
    ``"prune"``. Any other policy is handed to ``handle_phase_failures``
    unchanged.

    Returns
    -------
    tuple[list[SimSystem], bool]
        The surviving systems, and whether the caller should exit early
        because of dry-run mode.
    """
    policy = (on_failure or "").lower()
    survivors = children
    retried = False
    while True:
        all_done = run_phase_skipping_done(
            phase,
            survivors,
            phase_name,
            backend,
            max_workers=max_workers,
            on_failure=on_failure,
        )
        if not all_done and job_mgr is not None:
            job_mgr.wait_all()
            if dry_run and job_mgr.triggered:
                logger.success(
                    dry_run_message
                    or f"[DRY-RUN] Reached first SLURM submission point ({phase_name}). Exiting without submitting."
                )
                return survivors, True

        _, failed = partition_children_by_status(survivors, phase_name)
        if not failed:
            return survivors, False
        # Only a single in-run retry is allowed; otherwise fall through to final handling.
        if policy != "retry" or retried:
            break

        failed_names = ", ".join(c.meta.get("ligand", c.name) for c in failed)
        logger.warning(
            f"[{phase_name}] Retrying {len(failed)} ligand(s) once in this run: {failed_names}"
        )
        survivors = handle_phase_failures(survivors, phase_name, "retry")
        retried = True

    final_policy = "prune" if policy == "retry" else on_failure
    return handle_phase_failures(survivors, phase_name, final_policy), False
def _require_konnektor() -> None:
    """Raise a helpful ``ImportError`` when the optional ``konnektor`` package is missing."""
    try:
        import konnektor  # noqa: F401  (presence check only)
    except ImportError as exc:
        raise ImportError(
            "konnektor package is not installed, install it with `conda install konnektor`"
        ) from exc


def _build_rbfe_network_plan(
    ligands: List[str],
    lig_map: Dict[str, str],
    rbfe_cfg,
    config_dir: Path,
) -> dict:
    """Plan the RBFE transformation network and write ``rbfe_network.json``.

    Pair sources, in order of precedence:

    1. ``rbfe_cfg.mapping_file``: pairs loaded from a user-supplied file.
    2. ``rbfe_cfg.mapping == "konnektor"``: konnektor with
       ``rbfe_cfg.konnektor_layout``. Requires the ``konnektor`` package.
    3. ``"default"``, ``"star"`` or ``"first"`` (also used when ``mapping`` is
       empty): konnektor star layout. If that fails for any reason, including
       konnektor not being installed, this falls back to the built-in mapping
       function.
    4. Any other name: resolved through ``resolve_mapping_fn``.

    Parameters
    ----------
    ligands : list of str
        Ligand names. Falsy entries are ignored; the rest are sanitized.
    lig_map : dict
        Ligand name -> input structure path, used by konnektor. Keys may be
        raw or already-sanitized names.
    rbfe_cfg
        RBFE network settings. Reads ``mapping_file``, ``mapping``,
        ``konnektor_layout``, ``atom_mapper``, ``kartograf``/``lomap`` options
        and ``both_directions``.
    config_dir : pathlib.Path
        Directory receiving ``rbfe_network.json`` (and the konnektor plot).

    Returns
    -------
    dict
        ``network.to_mapping()`` augmented with provenance keys such as
        ``mapping``/``mapping_file``, ``atom_mapper`` and mapper options.

    Raises
    ------
    RuntimeError
        If fewer than two usable ligands are given.
    ImportError
        If ``mapping == "konnektor"`` but konnektor is not installed.
    """
    from batter.rbfe import (
        RBFENetwork,
        resolve_mapping_fn,
        load_mapping_file,
        konnektor_pairs,
    )
    from batter.config.utils import sanitize_ligand_name

    available = [sanitize_ligand_name(x) for x in ligands if x]
    if len(available) < 2:
        raise RuntimeError("RBFE requires at least two ligands.")
    config_dir.mkdir(parents=True, exist_ok=True)

    # Callers may key lig_map by raw or sanitized names; exact keys win on collision.
    lig_paths: Dict[str, Any] = {
        **{sanitize_ligand_name(str(name)): p for name, p in lig_map.items()},
        **lig_map,
    }

    mapping_source: Dict[str, Any] = {}
    atom_mapper = str(getattr(rbfe_cfg, "atom_mapper", "kartograf") or "kartograf").lower()
    mapper_options = _rbfe_mapper_options(rbfe_cfg)

    def _konnektor_network(layout):
        """Build the network from konnektor-generated pairs using ``layout``."""
        found = konnektor_pairs(
            available,
            {name: Path(lig_paths[name]) for name in available},
            layout=layout,
            plot_path=config_dir / "rbfe_network.png",
            atom_mapper=atom_mapper,
            kartograf_options=mapper_options["kartograf"],
            lomap_options=mapper_options["lomap"],
        )
        return RBFENetwork.from_ligands(available, mapping_fn=lambda _: found)

    if rbfe_cfg.mapping_file:
        file_pairs = load_mapping_file(Path(rbfe_cfg.mapping_file))
        network = RBFENetwork.from_ligands(available, mapping_fn=lambda _: file_pairs)
        mapping_source["mapping_file"] = str(rbfe_cfg.mapping_file)
    else:
        mapping_name = rbfe_cfg.mapping or "default"
        if mapping_name == "konnektor":
            # Explicitly requested: fail loudly with install instructions.
            _require_konnektor()
            network = _konnektor_network(rbfe_cfg.konnektor_layout)
            mapping_source["mapping"] = "konnektor"
            if rbfe_cfg.konnektor_layout:
                mapping_source["konnektor_layout"] = rbfe_cfg.konnektor_layout
        elif mapping_name in {"default", "star", "first"}:
            try:
                network = _konnektor_network("star")
                mapping_source["mapping"] = mapping_name
                mapping_source["konnektor_layout"] = "star"
            except Exception as exc:
                logger.warning(
                    f"RBFE default mapping requested StarNetworkGenerator but failed "
                    f"({exc}); falling back to internal default mapping."
                )
                network = RBFENetwork.from_ligands(
                    available, mapping_fn=resolve_mapping_fn(mapping_name)
                )
                mapping_source["mapping"] = mapping_name
        else:
            network = RBFENetwork.from_ligands(
                available, mapping_fn=resolve_mapping_fn(mapping_name)
            )
            mapping_source["mapping"] = mapping_name

    mapping_source["atom_mapper"] = atom_mapper
    if any(mapper_options.values()):
        mapping_source["mapper_options"] = mapper_options
    selected_mapper_options = mapper_options.get(atom_mapper, {})
    if selected_mapper_options:
        mapping_source["atom_mapper_options"] = selected_mapper_options

    payload = network.to_mapping()
    if bool(getattr(rbfe_cfg, "both_directions", False)):
        # Emit every edge in both orientations, keeping first-seen order without duplicates.
        bidirectional_pairs: List[List[str]] = []
        seen: set[tuple[str, str]] = set()
        for ref, alt in payload.get("pairs", []):
            for pair in ((ref, alt), (alt, ref)):
                if pair in seen:
                    continue
                seen.add(pair)
                bidirectional_pairs.append([pair[0], pair[1]])
        payload["pairs"] = bidirectional_pairs
        mapping_source["both_directions"] = True
    payload.update(mapping_source)

    rbfe_network_path = config_dir / "rbfe_network.json"
    rbfe_network_path.write_text(json.dumps(payload, indent=2))
    # Report the number of pairs actually persisted (doubled when both_directions is on).
    logger.info(
        f"RBFE network planned: {len(network.ligands)} ligands, "
        f"{len(payload.get('pairs') or [])} pairs "
        f"(both_directions={mapping_source.get('both_directions', False)})"
    )
    return payload
def _maybe_regenerate_rbfe_network_after_pruning(
    *,
    available_ligands: List[str],
    lig_map: Dict[str, str],
    payload: Dict[str, Any],
    rbfe_cfg,
    config_dir: Path,
) -> Dict[str, Any]:
    """Re-plan the RBFE network when planned ligands were pruned earlier.

    ``payload`` is returned unchanged in any of these cases:

    * no ligand listed in ``payload["ligands"]`` is missing from
      ``available_ligands``;
    * fewer than two ligands survive;
    * a survivor has no input path in ``lig_map``;
    * regeneration raises.

    Otherwise the network rebuilt from the surviving ligands is returned (and
    persisted by ``_build_rbfe_network_plan``).
    """
    from batter.config.utils import sanitize_ligand_name

    survivors = [sanitize_ligand_name(name) for name in available_ligands if name]
    survivor_set = set(survivors)
    planned = [
        sanitize_ligand_name(str(name))
        for name in (payload.get("ligands") or [])
        if name
    ]
    dropped = [name for name in planned if name not in survivor_set]
    # Nothing pruned, or too few ligands left to form any pair: keep the plan.
    if not dropped or len(survivors) < 2:
        return payload

    paths_by_name = {
        sanitize_ligand_name(str(name)): path for name, path in lig_map.items()
    }
    no_path = [name for name in survivors if name not in paths_by_name]
    if no_path:
        logger.warning(
            f"RBFE network regeneration skipped: missing ligand input path(s) for {', '.join(no_path)}."
        )
        return payload

    logger.warning(
        f"Detected {len(dropped)} pruned ligand(s) before RBFE transformations ({', '.join(dropped)}). "
        "Regenerating RBFE network on remaining ligands."
    )
    try:
        regenerated = _build_rbfe_network_plan(
            survivors,
            {name: paths_by_name[name] for name in survivors},
            rbfe_cfg,
            config_dir,
        )
    except Exception as exc:
        logger.warning(
            f"RBFE network regeneration failed ({exc}); continuing with existing network payload."
        )
        return payload

    logger.info(
        f"RBFE network regenerated after pruning: {len(regenerated.get('ligands') or [])} ligands, "
        f"{len(regenerated.get('pairs') or [])} pairs."
    )
    return regenerated
def _materialize_extra_conf_restraints(
    source: Path | str | None, run_dir: Path, yaml_dir: Path
) -> Path | None:
    """Snapshot the extra conformation-restraints file into ``artifacts/config``.

    A relative ``source`` is resolved against ``yaml_dir``. A snapshot with the
    same file name that already exists is reused as-is; otherwise the source
    is copied.

    Returns the snapshot path. Returns ``None`` when ``source`` is empty, when
    the copy fails, or when neither the source nor a snapshot exists (the last
    two cases are logged as warnings).
    """
    if not source:
        return None
    origin = Path(source)
    if not origin.is_absolute():
        origin = (yaml_dir / origin).resolve()

    snapshot_dir = run_dir / "artifacts" / "config"
    snapshot_dir.mkdir(parents=True, exist_ok=True)
    snapshot = snapshot_dir / origin.name
    if snapshot.exists():
        return snapshot

    if not origin.exists():
        logger.warning(
            f"extra_conformation_restraints missing at {origin} and no stored copy under {snapshot}"
        )
        return None
    try:
        shutil.copy2(origin, snapshot)
    except Exception as exc:
        logger.warning(f"Could not copy extra_conformation_restraints from {origin}: {exc}")
        return None
    return snapshot
[docs]
def run_from_yaml(
    path: Path | str,
    on_failure: Literal["prune", "raise", "retry"] | None = None,
    run_overrides: Dict[str, Any] | None = None,
) -> None:
    """Execute a BATTER workflow described by a YAML file.

    Loads the run configuration, then delegates to ``_run_from_yaml_impl``. If
    the run raises, a failure notification is sent via
    ``_notify_run_failure`` and the original exception is re-raised.

    Parameters
    ----------
    path : str or pathlib.Path
        Path to the top-level run YAML file.
    on_failure : {"prune", "raise", "retry"}, optional
        Override for the failure policy applied to ligand pipelines.
    run_overrides : dict, optional
        Overrides applied to the ``run`` section.

    Raises
    ------
    Exception
        Any error from loading the config, which is not notified because the
        notification settings come from the config. Also any error from the
        run itself, which is notified and then re-raised unchanged.
    """
    path = Path(path)
    logger.info(f"Starting BATTER run from {path}")
    # Config must load successfully before email settings are available.
    rc = RunConfig.load(path)
    # Filled in by the implementation once the run id/dir are known, so the
    # failure notification can reference them.
    run_state: dict[str, Any] = {
        "run_id": None,
        "run_dir": Path(getattr(rc.run, "output_folder", path.parent)),
    }
    try:
        _run_from_yaml_impl(
            path,
            rc,
            on_failure=on_failure,
            run_overrides=run_overrides,
            run_state=run_state,
        )
    except Exception as exc:
        try:
            _notify_run_failure(
                rc,
                run_state.get("run_id"),
                run_state.get("run_dir"),
                exc,
            )
        except Exception as notify_exc:
            # A broken notifier must not replace the run's own error.
            logger.warning(f"Could not send run-failure notification: {notify_exc}")
        raise
def _run_from_yaml_impl(
path: Path | str,
rc: RunConfig,
on_failure: Literal["prune", "raise", "retry"] = None,
run_overrides: Dict[str, Any] | None = None,
run_state: dict[str, Any] | None = None,
) -> None:
"""Execute a BATTER workflow described by a YAML file.
Parameters
----------
path : str or pathlib.Path
Path to the top-level run YAML file.
on_failure : {"prune", "raise", "retry"}, optional
Override for the failure policy applied to ligand pipelines.
run_overrides : dict, optional
Overrides applied to the ``run`` section (e.g., only FE preparation).
"""
path = Path(path)
if run_overrides:
logger.info(f"Applying run overrides: {run_overrides}")
rc = rc.model_copy(update={"run": rc.run.model_copy(update=run_overrides)})
if on_failure:
rc.run.on_failure = on_failure
logger.info(
"Run configuration:\n{}",
yaml.safe_dump(rc.model_dump(mode="json"), sort_keys=False)
)
yaml_dir = path.parent
# ligand params output directory
if rc.create.param_outdir is None:
rc.create.param_outdir = str(rc.run.output_folder / "ligand_params")
else:
logger.info(
f"Using user-specified ligand param_outdir: {rc.create.param_outdir}"
)
sim_cfg = rc.resolved_sim_config()
logger.info(f"Loaded simulation config for system: {sim_cfg.system_name}")
# Backend
if rc.backend == "slurm":
backend = SlurmBackend()
else:
backend = LocalBackend()
register_local_handlers(backend)
# Shared System Build (system-level assets live under sys.root)
builder = _select_system_builder(rc.protocol, rc.run.system_type)
requested_run_id = getattr(rc.run, "run_id", "auto")
config_signature, config_payload = _compute_run_signature(path, run_overrides)
while True:
run_id, run_dir = select_run_id(
rc.run.output_folder,
rc.protocol,
rc.create.system_name,
requested_run_id,
)
if run_state is not None:
run_state["run_id"] = run_id
run_state["run_dir"] = run_dir
stored_sig, sig_path = _stored_signature(run_dir)
stored_payload = _stored_payload(run_dir)
if _resolve_signature_conflict(
stored_sig,
config_signature,
requested_run_id,
rc.run.allow_run_id_mismatch,
run_id=run_id,
run_dir=run_dir,
stored_payload=stored_payload,
current_payload=config_payload,
):
break
logger.info(
f"Existing execution {run_dir} uses different configuration hash ({stored_sig[:12]}); creating a fresh run.",
)
requested_run_id = "auto"
run_id = generate_run_id(rc.protocol, rc.create.system_name)
run_dir = Path(rc.run.output_folder) / "executions" / run_id
run_dir.mkdir(parents=True, exist_ok=True)
continue
logger.info(f"Using run_id='{run_id}' under {run_dir}")
_, sig_path = _stored_signature(run_dir)
_store_run_yaml_copy(run_dir, path)
# Ligands
lig_original_names: Dict[str, str] = {}
staged_lig_map = discover_staged_ligands(run_dir)
stored_names = _load_stored_ligand_names(run_dir)
if staged_lig_map:
lig_map = staged_lig_map
lig_original_names = stored_names
if lig_original_names:
logger.debug(
"Loaded %d original ligand names from %s",
len(lig_original_names),
_ligand_names_path(run_dir),
)
logger.info(
f"Resuming with {len(lig_map)} staged ligands discovered under {run_dir}"
)
else:
# Fall back to YAML resolution (requires original paths/files to exist)
lig_map, lig_original_names = resolve_ligand_map(rc, yaml_dir)
if lig_original_names:
_store_ligand_names(run_dir, lig_original_names)
rc.create.ligand_paths = {k: str(v) for k, v in lig_map.items()}
# Build system-prep params exactly once (after run_dir is known)
extra_conf_path = _materialize_extra_conf_restraints(
rc.create.extra_conformation_restraints, run_dir, yaml_dir
)
sys_params = {
"param_outdir": str(rc.create.param_outdir),
"system_name": rc.create.system_name,
"protein_input": str(rc.create.protein_input),
"system_input": str(rc.create.system_input) if rc.create.system_input else None,
"system_coordinate": (
str(rc.create.system_coordinate) if rc.create.system_coordinate else None
),
"ligand_paths": rc.create.ligand_paths,
"anchor_atoms": list(rc.create.anchor_atoms or []),
"protein_align": str(rc.create.protein_align),
"lipid_mol": list(rc.create.lipid_mol or []),
"other_mol": list(rc.create.other_mol or []),
"ligand_ff": rc.create.ligand_ff,
"retain_lig_prot": bool(rc.create.retain_lig_prot),
"charge": rc.create.param_charge,
"yaml_dir": str(yaml_dir),
"extra_restraints": rc.create.extra_restraints,
"extra_restraint_fc": rc.create.extra_restraint_fc,
"extra_conformation_restraints": extra_conf_path
or rc.create.extra_conformation_restraints,
}
base_meta = {}
if rc.protocol == "rbfe":
base_meta["mode"] = "RBFE"
sys_exec = SimSystem(name=rc.create.system_name, root=run_dir, meta=base_meta)
sys_exec = builder.build(sys_exec, rc.create)
sig_path.parent.mkdir(parents=True, exist_ok=True)
sig_path.write_text(config_signature + "\n")
_payload_path(run_dir).write_text(
json.dumps(config_payload, sort_keys=True, indent=2)
)
# Per-execution run directory (auto-resume latest when 'auto')
logger.add(run_dir / "batter.run.log", level="DEBUG")
dry_run = rc.run.dry_run
if dry_run:
logger.warning("DRY RUN mode enabled: no SLURM jobs will be submitted.")
# SLURM manager (registry per execution)
slurm_flags = rc.run.slurm.to_sbatch_flags() if rc.run.slurm else None
batch_mode = bool(getattr(rc.run, "batch_mode", False))
if batch_mode:
raise NotImplementedError('batch mode not implemented')
batch_poll = 10.0 if batch_mode else 60 * 15
registry_file = None if batch_mode else _slurm_registry_path(run_dir)
job_mgr = SlurmJobManager(
poll_s=batch_poll,
max_retries=3,
resubmit_backoff_s=30,
registry_file=registry_file,
dry_run=dry_run,
sbatch_flags=slurm_flags,
#batch_mode=batch_mode,
#batch_gpus=getattr(rc.run, "batch_gpus", None),
#gpus_per_task=getattr(rc.run, "batch_gpus_per_task", 1),
#srun_extra=getattr(rc.run, "batch_srun_extra", None),
max_active_jobs=rc.run.max_active_jobs,
partition=rc.run.slurm.partition if rc.run.slurm else None,
)
# Build pipeline with explicit sys_params
tpl = select_pipeline(
rc.protocol,
sim_cfg,
rc.run.only_fe_preparation,
sys_params=sys_params,
partition=rc.run.slurm.partition if rc.run.slurm else None,
)
# Run parent-only steps at run_dir by using a run-scoped SimSystem
run_sys = SimSystem(
name=f"{sys_exec.name}:{run_id}",
root=run_dir,
protein=sys_exec.protein,
topology=sys_exec.topology,
coordinates=sys_exec.coordinates,
ligands=tuple(), # parent steps don't need per-ligand sdf
lipid_mol=sys_exec.lipid_mol,
other_mol=sys_exec.other_mol,
anchors=sys_exec.anchors,
meta=sys_exec.meta,
)
# Stage ligands under this execution
lig_root = run_dir / "simulations"
lig_root.mkdir(parents=True, exist_ok=True)
for lig_name, lig_path in lig_map.items():
sub = lig_root / lig_name / "inputs"
if not (sub / f"ligand{Path(lig_path).suffix}").exists():
builder.make_child_for_ligand(sys_exec, lig_name, lig_path)
logger.debug(f"Staged {len(lig_map)} ligand subsystems under {lig_root}")
parent_failure = False
parent_only = Pipeline(
[step for step in tpl.ordered_steps() if step.name in _PARENT_ONLY_STEP_NAMES]
)
if parent_only.ordered_steps():
names = [s.name for s in parent_only.ordered_steps()]
logger.debug(f"Executing parent-only steps at {run_dir}: {names}")
for step in parent_only.ordered_steps():
if is_done(run_sys, step.name):
logger.info(f"[skip] {step.name}: finished.")
continue
try:
step_params = step.params
if isinstance(step_params, dict):
step_params = dict(step_params)
step_params["on_failure"] = rc.run.on_failure
elif hasattr(step_params, "copy_with"):
step_params = step_params.copy_with(on_failure=rc.run.on_failure)
backend.run(step, run_sys, step_params)
except Exception as exc:
if step.name == "param_ligands" and (rc.run.on_failure or "").lower() in {"prune", "retry"}:
parent_failure = True
logger.error(
"[param_ligands] encountered error with on_failure=%s: %s — continuing with successful ligands only.",
rc.run.on_failure,
exc,
)
break
raise
config_dir = run_dir / "artifacts" / "config"
config_dir.mkdir(parents=True, exist_ok=True)
overrides_path = config_dir / "sim_overrides.json"
sim_cfg_updated = sim_cfg
if rc.protocol == "rbfe":
rbfe_network_path = config_dir / "rbfe_network.json"
if not rbfe_network_path.exists():
from batter.config.run import RBFENetworkArgs
rbfe_cfg = rc.rbfe or RBFENetworkArgs()
_build_rbfe_network_plan(
list(lig_map.keys()), lig_map, rbfe_cfg, config_dir
)
if overrides_path.exists():
upd = json.loads(overrides_path.read_text()) or {}
sim_cfg_updated = sim_cfg.model_copy(
update={k: v for k, v in upd.items() if v is not None}
)
from batter.config.io import write_yaml_config
write_yaml_config(sim_cfg_updated, config_dir / "sim.resolved.yaml")
run_meta_path = config_dir / "run_meta.json"
run_meta_path.write_text(
json.dumps(
{
"protocol": rc.protocol,
"backend": rc.backend,
"system_name": rc.create.system_name,
"run_id": run_id,
},
indent=2,
)
)
per_lig = _build_per_ligand_pipeline(tpl, sim_cfg_updated)
phase_prepare_equil = _select_phase_pipeline(per_lig, _PHASE_STEP_NAMES["prepare_equil"])
phase_equil = _select_phase_pipeline(per_lig, _PHASE_STEP_NAMES["equil"])
phase_equil_analysis = _select_phase_pipeline(per_lig, _PHASE_STEP_NAMES["equil_analysis"])
phase_pre_prepare_fe = _select_phase_pipeline(per_lig, _PHASE_STEP_NAMES["pre_prepare_fe"])
phase_pre_fe_equil = _select_phase_pipeline(per_lig, _PHASE_STEP_NAMES["pre_fe_equil"])
phase_prepare_fe = _select_phase_pipeline(per_lig, _PHASE_STEP_NAMES["prepare_fe"])
phase_fe_equil = _select_phase_pipeline(per_lig, _PHASE_STEP_NAMES["fe_equil"])
phase_fe = _select_phase_pipeline(per_lig, _PHASE_STEP_NAMES["fe"])
phase_analyze = _select_phase_pipeline(per_lig, _PHASE_STEP_NAMES["analyze"])
param_idx_path = run_dir / "artifacts" / "ligand_params" / "index.json"
if not param_idx_path.exists():
if parent_failure and (rc.run.on_failure or "").lower() in {"prune", "retry"}:
logger.warning(
"Parametrization failed and no ligand param index was written; continuing with 0 ligands due to on_failure=%s.",
rc.run.on_failure,
)
param_index = {"ligands": []}
else:
raise FileNotFoundError(f"Missing ligand param index: {param_idx_path}")
else:
param_index = json.loads(param_idx_path.read_text())
param_dir_dict = {e["residue_name"]: e["store_dir"] for e in param_index["ligands"]}
lig_resname_map = {}
for entry in param_index["ligands"]:
lig = entry.get("ligand")
resn = entry.get("residue_name")
lig_resname_map[lig] = resn
children_all: List[SimSystem] = []
for lig_name, resn in lig_resname_map.items():
d = run_dir / "simulations" / lig_name
child_meta = sys_exec.meta.merge(
ligand=lig_name,
residue_name=resn,
param_dir_dict=param_dir_dict,
)
children_all.append(
SimSystem(
name=f"{sys_exec.name}:{lig_name}:{run_id}",
root=d,
protein=sys_exec.protein,
topology=sys_exec.topology,
coordinates=sys_exec.coordinates,
ligands=tuple([d / "inputs" / "ligand.sdf"]),
lipid_mol=sys_exec.lipid_mol,
other_mol=sys_exec.other_mol,
anchors=sys_exec.anchors,
meta=child_meta,
)
)
children = children_all
fe_children_all: List[SimSystem] = children_all
if getattr(rc.run, "clean_failures", False):
_clear_failure_markers(run_dir)
# --------------------
# PHASE 1: prepare_equil (parallel)
# --------------------
if phase_prepare_equil.ordered_steps():
children, should_exit = _run_phase_with_failure_policy(
phase_prepare_equil,
children,
"prepare_equil",
backend,
max_workers=rc.run.max_workers,
on_failure=rc.run.on_failure,
)
if should_exit:
return
else:
logger.info("[skip] prepare_equil: no steps in this protocol.")
# --------------------
# PHASE 2: equil (parallel) → must COMPLETE for all ligands
# --------------------
def _inject_mgr(
p: Pipeline, stage_name: str, extra_payload: dict[str, Any] | None = None
) -> Pipeline:
job_mgr.set_stage(stage_name)
patched = []
for s in p.ordered_steps():
base_payload = s.payload or StepPayload()
updates = {"job_mgr": job_mgr, "job_stage": stage_name}
if rc.run.max_active_jobs is not None:
updates["max_active_jobs"] = rc.run.max_active_jobs
updates["batch_mode"] = batch_mode
updates["batch_run_root"] = run_dir / "batch_run"
updates["batch_gpus"] = getattr(rc.run, "batch_gpus", None)
updates["batch_gpus_per_task"] = getattr(rc.run, "batch_gpus_per_task", 1)
if extra_payload:
updates.update(extra_payload)
payload = base_payload.copy_with(**updates)
patched.append(Step(name=s.name, requires=s.requires, payload=payload))
return Pipeline(patched)
def _inject_payload(p: Pipeline, **updates: Any) -> Pipeline:
patched = []
for s in p.ordered_steps():
base_payload = s.payload or StepPayload()
payload = base_payload.copy_with(**updates)
patched.append(Step(name=s.name, requires=s.requires, payload=payload))
return Pipeline(patched)
phase_equil = _inject_mgr(phase_equil, "equil")
if phase_equil.ordered_steps():
children, should_exit = _run_phase_with_failure_policy(
phase_equil,
children,
"equil",
backend,
max_workers=1,
on_failure=rc.run.on_failure,
job_mgr=job_mgr,
dry_run=dry_run,
dry_run_message="[DRY-RUN] Reached first SLURM submission point (equil). Exiting without submitting.",
)
if should_exit:
return
else:
logger.info("[skip] equil: no steps in this protocol.")
# --------------------
# PHASE 2.5: equil_analysis (parallel) → prune UNBOUND if requested
# --------------------
# prune UNBOUND ligands before FE prep
unbound_children: list[SimSystem] = []
def _filter_bound(children_list):
keep = []
for c in children_list:
if (c.root / "equil" / "UNBOUND").exists():
lig = c.meta.get("ligand", c.name)
logger.warning(f"Pruning UNBOUND ligand after equil: {lig}")
unbound_children.append(c)
continue
keep.append(c)
return keep
if phase_equil_analysis.ordered_steps():
children, should_exit = _run_phase_with_failure_policy(
phase_equil_analysis,
children,
"equil_analysis",
backend,
max_workers=rc.run.max_workers,
on_failure=rc.run.on_failure,
)
if should_exit:
return
children = _filter_bound(children)
else:
logger.info("[skip] equil_analysis: no steps in this protocol.")
# --------------------
# PHASE 2.6: pre_prepare_fe (RBFE ligand prep) → z-1 only
# --------------------
if phase_pre_prepare_fe.ordered_steps():
phase_pre_prepare_fe = _inject_payload(
phase_pre_prepare_fe,
components=["z"],
component_lambdas={"z": [0.0]},
phase_name="pre_prepare_fe",
)
children, should_exit = _run_phase_with_failure_policy(
phase_pre_prepare_fe,
children,
"pre_prepare_fe",
backend,
max_workers=rc.run.max_workers,
on_failure=rc.run.on_failure,
)
if should_exit:
return
else:
logger.info("[skip] pre_prepare_fe: no steps in this protocol.")
# --------------------
# PHASE 2.7: pre_fe_equil → must COMPLETE for all ligands
# --------------------
phase_pre_fe_equil = _inject_mgr(
phase_pre_fe_equil,
"pre_fe_equil",
extra_payload={"phase_name": "pre_fe_equil", "extra_env": {"SKIP_WINDOW_EQ": "1"}},
)
if phase_pre_fe_equil.ordered_steps():
children, should_exit = _run_phase_with_failure_policy(
phase_pre_fe_equil,
children,
"pre_fe_equil",
backend,
max_workers=1,
on_failure=rc.run.on_failure,
job_mgr=job_mgr,
dry_run=dry_run,
dry_run_message="[DRY-RUN] Reached first SLURM submission point (pre_fe_equil). Exiting without submitting.",
)
if should_exit:
return
else:
logger.info("[skip] pre_fe_equil: no steps in this protocol.")
# --------------------
# RBFE: build transformation systems (pairs) after pre_fe_equil
# --------------------
if rc.protocol == "rbfe":
from batter.rbfe import RBFENetwork
from batter.config.utils import sanitize_ligand_name
from batter.config.run import RBFENetworkArgs
available = [c.meta.get("ligand") for c in children if c.meta.get("ligand")]
if len(available) < 2:
raise RuntimeError(
"RBFE requires at least two ligands that completed equilibration."
)
available = [sanitize_ligand_name(x) for x in available]
available_set = set(available)
rbfe_network_path = config_dir / "rbfe_network.json"
if rbfe_network_path.exists():
payload = json.loads(rbfe_network_path.read_text())
else:
rbfe_cfg = rc.rbfe or RBFENetworkArgs()
payload = _build_rbfe_network_plan(
list(lig_map.keys()), lig_map, rbfe_cfg, config_dir
)
if (rc.run.on_failure or "").lower() in {"prune", "retry"}:
rbfe_cfg = rc.rbfe or RBFENetworkArgs()
payload = _maybe_regenerate_rbfe_network_after_pruning(
available_ligands=available,
lig_map=lig_map,
payload=payload,
rbfe_cfg=rbfe_cfg,
config_dir=config_dir,
)
pairs = payload.get("pairs") or []
if not pairs:
raise RuntimeError("RBFE mapping produced no ligand pairs.")
cleaned_pairs = []
for pair in pairs:
if not isinstance(pair, (list, tuple)) or len(pair) != 2:
raise RuntimeError(f"RBFE mapping entries must be 2-tuples; got {pair!r}.")
cleaned_pairs.append(
(
sanitize_ligand_name(str(pair[0])),
sanitize_ligand_name(str(pair[1])),
)
)
if (rc.run.on_failure or "").lower() in {"prune", "retry"}:
pruned = [
p
for p in cleaned_pairs
if p[0] in available_set and p[1] in available_set
]
if not pruned:
raise RuntimeError(
"RBFE mapping does not include any available ligands after pruning."
)
if len(pruned) != len(cleaned_pairs):
logger.warning(
"Pruned %d RBFE pair(s) due to on_failure=%s.",
len(cleaned_pairs) - len(pruned),
rc.run.on_failure,
)
cleaned_pairs = pruned
network = RBFENetwork.from_ligands(available, mapping_fn=lambda _: cleaned_pairs)
atom_mapper = str(
payload.get("atom_mapper")
or getattr(rc.rbfe, "atom_mapper", "kartograf")
or "kartograf"
).lower()
payload_mapper_options = payload.get("mapper_options")
if isinstance(payload_mapper_options, dict):
mapper_options = {
"kartograf": _options_dict(payload_mapper_options.get("kartograf")),
"lomap": _options_dict(payload_mapper_options.get("lomap")),
}
else:
mapper_options = _rbfe_mapper_options(rc.rbfe)
payload_atom_mapper_options = payload.get("atom_mapper_options")
if isinstance(payload_atom_mapper_options, dict):
atom_mapper_options = _options_dict(payload_atom_mapper_options)
else:
atom_mapper_options = mapper_options.get(atom_mapper, {})
# Build transformation systems under simulations/transformations/
trans_root = run_dir / "simulations" / "transformations"
trans_root.mkdir(parents=True, exist_ok=True)
rbfe_children: List[SimSystem] = []
for ref, alt in network.pairs:
pair_id = f"{ref}~{alt}"
pair_dir = trans_root / pair_id
inputs_dir = pair_dir / "inputs"
inputs_dir.mkdir(parents=True, exist_ok=True)
ref_src = Path(lig_map[ref])
alt_src = Path(lig_map[alt])
ref_dst = inputs_dir / f"{ref}{ref_src.suffix}"
alt_dst = inputs_dir / f"{alt}{alt_src.suffix}"
if not ref_dst.exists():
shutil.copy2(ref_src, ref_dst)
if not alt_dst.exists():
shutil.copy2(alt_src, alt_dst)
resn_ref = lig_resname_map.get(ref)
resn_alt = lig_resname_map.get(alt)
if not resn_ref or not resn_alt:
raise RuntimeError(
f"Missing residue names for RBFE pair {pair_id}: {ref}={resn_ref}, {alt}={resn_alt}."
)
pair_meta = sys_exec.meta.merge(
ligand=pair_id,
residue_name=resn_ref,
mode="RBFE",
param_dir_dict=param_dir_dict,
pair_id=pair_id,
ligand_ref=ref,
ligand_alt=alt,
residue_ref=resn_ref,
residue_alt=resn_alt,
input_ref=str(ref_dst),
input_alt=str(alt_dst),
atom_mapper=atom_mapper,
atom_mapper_options=atom_mapper_options,
kartograf_options=mapper_options.get("kartograf", {}),
lomap_options=mapper_options.get("lomap", {}),
)
rbfe_children.append(
SimSystem(
name=f"{sys_exec.name}:{pair_id}:{run_id}",
root=pair_dir,
protein=sys_exec.protein,
topology=sys_exec.topology,
coordinates=sys_exec.coordinates,
ligands=tuple([ref_dst, alt_dst]),
lipid_mol=sys_exec.lipid_mol,
other_mol=sys_exec.other_mol,
anchors=sys_exec.anchors,
meta=pair_meta,
)
)
# Switch to transformation systems for FE stages/results
children = rbfe_children
fe_children_all = rbfe_children
# --------------------
# PHASE 3: prepare_fe (parallel)
# --------------------
if phase_prepare_fe.ordered_steps():
children, should_exit = _run_phase_with_failure_policy(
phase_prepare_fe,
children,
"prepare_fe",
backend,
max_workers=rc.run.max_workers,
on_failure=rc.run.on_failure,
)
if should_exit:
return
else:
logger.info("[skip] prepare_fe: no steps in this protocol.")
# --------------------
# PHASE 4: fe_equil → must COMPLETE for all ligands
# --------------------
phase_fe_equil = _inject_mgr(phase_fe_equil, "fe_equil")
if phase_fe_equil.ordered_steps():
children, should_exit = _run_phase_with_failure_policy(
phase_fe_equil,
children,
"fe_equil",
backend,
max_workers=1,
on_failure=rc.run.on_failure,
job_mgr=job_mgr,
dry_run=dry_run,
dry_run_message="[DRY-RUN] Reached first SLURM submission point (fe_equil). Exiting without submitting.",
)
if should_exit:
return
else:
logger.info("[skip] fe_equil: no steps in this protocol.")
# --------------------
# PHASE 5: fe → must COMPLETE for all ligands
# --------------------
phase_fe = _inject_mgr(phase_fe, "fe")
has_fe_phase = bool(phase_fe.ordered_steps())
if has_fe_phase:
children, should_exit = _run_phase_with_failure_policy(
phase_fe,
children,
"fe",
backend,
max_workers=1,
on_failure=rc.run.on_failure,
job_mgr=job_mgr,
dry_run=dry_run,
dry_run_message="[DRY-RUN] Reached first SLURM submission point (fe). Exiting without submitting.",
)
if should_exit:
return
else:
logger.info("[skip] fe: no steps in this protocol.")
# --------------------
# PHASE 6: analyze (parallel)
# --------------------
def _inject_analysis_workers(p: Pipeline) -> Pipeline:
patched = []
for s in p.ordered_steps():
payload = (s.payload or StepPayload()).copy_with(
analysis_n_workers=rc.run.max_workers
)
patched.append(Step(name=s.name, requires=s.requires, payload=payload))
return Pipeline(patched)
phase_analyze = _inject_analysis_workers(phase_analyze)
if phase_analyze.ordered_steps():
children, should_exit = _run_phase_with_failure_policy(
phase_analyze,
children,
"analyze",
backend,
max_workers=rc.run.max_workers,
on_failure=rc.run.on_failure,
)
if should_exit:
return
else:
logger.info("[skip] analyze: no steps in this protocol.")
# --------------------
# FE record save
# --------------------
if not has_fe_phase:
logger.info(
"FE production skipped (--only-equil); ending run without FE record export."
)
return
store = ArtifactStore(rc.run.output_folder)
repo = FEResultsRepository(store)
analysis_start_step = sim_cfg_updated.analysis_start_step
if analysis_start_step is not None:
analysis_start_step = int(analysis_start_step)
failures: list[tuple[str, str, str]] = []
if rc.protocol != "rbfe":
for child in unbound_children:
ligand = child.meta["ligand"]
reason = "UNBOUND detected during equilibration"
canonical_smiles, original_name, original_path = extract_ligand_metadata(
child, lig_original_names
)
repo.record_failure(
run_id=run_id,
ligand=ligand,
system_name=sim_cfg_updated.system_name,
temperature=sim_cfg_updated.temperature,
status="unbound",
reason=reason,
canonical_smiles=canonical_smiles,
original_name=original_name,
original_path=original_path,
protocol=rc.protocol,
analysis_start_step=analysis_start_step,
)
failures.append((ligand, "unbound", reason))
failures.extend(
save_fe_records(
run_dir=run_dir,
run_id=run_id,
children_all=fe_children_all,
sim_cfg_updated=sim_cfg_updated,
repo=repo,
protocol=rc.protocol,
analysis_start_step=analysis_start_step,
)
)
if failures:
failed = ", ".join(
[f"{n} ({status}: {reason})" for n, status, reason in failures]
)
logger.warning(f"{len(failures)} ligand(s) had post-run issues: {failed}")
if rc.protocol == "rbfe":
try:
from batter.analysis.cinnabar import auto_write_rbfe_cinnabar_for_run
cinnabar_export = auto_write_rbfe_cinnabar_for_run(rc.run.output_folder, run_id)
logger.info(
f"Wrote RBFE Cinnabar bundle for run '{run_id}' to {cinnabar_export['output_dir']}."
)
if cinnabar_export.get("absolute_warning"):
logger.warning(str(cinnabar_export["absolute_warning"]))
if cinnabar_export.get("replicate_note"):
logger.info(str(cinnabar_export["replicate_note"]))
except Exception as exc:
logger.warning(
f"Automatic RBFE Cinnabar export failed for run '{run_id}': {exc}"
)
summary_table = _build_run_summary_table(repo, run_id)
if summary_table:
logger.info(f"Final FE summary for run '{run_id}':\n{summary_table}")
else:
logger.warning(
f"No FE summary rows found for run '{run_id}' after FE record export."
)
logger.success(
f"All phases completed {run_dir}. FE records saved to repository {rc.run.output_folder}/results/."
)
_notify_run_completion(rc, run_id, run_dir, failures, summary_table=summary_table)
def _notify_run_completion(
    rc: RunConfig,
    run_id: str,
    run_dir: Path,
    failures: list[tuple[str, str, str]],
    summary_table: str | None = None,
) -> None:
    """Send the "run completed" notification for ``run_id``.

    Forwards to ``_notify_run_status`` with ``status="completed"``, passing
    along the per-ligand ``(ligand, status, reason)`` failure tuples and the
    optional pre-rendered FE summary table.
    """
    notify_kwargs = {
        "status": "completed",
        "run_id": run_id,
        "run_dir": run_dir,
        "failures": failures,
        "summary_table": summary_table,
    }
    _notify_run_status(rc, **notify_kwargs)
def _notify_run_failure(
    rc: RunConfig,
    run_id: str | None,
    run_dir: Path | None,
    error: Exception,
) -> None:
    """Send the "run failed" notification, including ``error`` in the email.

    ``run_id`` and ``run_dir`` may be ``None`` when the failure happened
    before they were known; ``_notify_run_status`` substitutes fallbacks
    for them in the message text.
    """
    # Positional order in _notify_run_status: rc, status, run_id, run_dir.
    _notify_run_status(rc, "failed", run_id, run_dir, error=error)
def _notify_run_status(
rc: RunConfig,
status: Literal["completed", "failed"],
run_id: str | None,
run_dir: Path | None,
failures: list[tuple[str, str, str]] | None = None,
error: Exception | None = None,
summary_table: str | None = None,
) -> None:
recipient = rc.run.email_on_completion
if not recipient:
return
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
display_run_id = run_id or "unknown"
if status == "failed":
subject = f"BATTER run '{display_run_id}' of {rc.create.system_name} failed"
else:
subject = f"BATTER run '{display_run_id}' of {rc.create.system_name} completed"
results_path = Path(rc.run.output_folder) / "results"
body_lines = ["Hi there!", ""]
if status == "failed":
body_lines.extend(
[
f"Your BATTER run '{rc.create.system_name}' (run_id='{display_run_id}') failed at {timestamp} UTC.",
f"Protocol: {rc.protocol}",
f"Last known run path: {run_dir or rc.run.output_folder}",
"",
]
)
if error is not None:
body_lines.extend(
[
"Error:",
str(error),
"",
]
)
body_lines.append("FE records may be incomplete because the run exited early.")
else:
body_lines.extend(
[
f"Your BATTER run '{rc.create.system_name}' (run_id='{display_run_id}') completed at {timestamp} UTC.",
f"Protocol: {rc.protocol}",
f"Output folder: {run_dir}",
f"FE records stored under: {results_path}",
"",
]
)
if failures:
body_lines.append(
"The following ligand(s) had post-run issues (see logs for additional context):"
)
for ligand, failure_status, reason in failures:
body_lines.append(f"- {ligand} ({failure_status}): {reason}")
else:
body_lines.append("No ligand failures were detected.")
if summary_table:
body_lines.extend(
[
"",
"Final FE summary:",
"",
summary_table,
]
)
body_lines.extend(
[
"",
"Best wishes,",
"BATTER",
]
)
message_body = "\n".join(body_lines)
sender = rc.run.email_sender
if not sender:
logger.warning(
"No sender email configured; cannot send completion notification. set `run.email_sender` in your YAML."
)
return
message = (
f"From: batter <{sender}>\n"
f"To: {recipient}\n"
f"Subject: {subject}\n\n"
f"{message_body}"
)
try:
with smtplib.SMTP("localhost") as smtp:
smtp.sendmail(sender, [recipient], message)
logger.info(f"Sent completion notification to {recipient}")
except SMTPException as exc:
logger.warning(f"Failed to send completion email to {recipient}: {exc}")
except Exception as exc: # pragma: no cover - best-effort notification
logger.warning(f"Unexpected error while sending completion email: {exc}")