"""Public API for BATTER.
This module collects the stable entry points intended for external consumption.
They fall into four broad categories:
* **Configuration helpers** – load and dump ``RunConfig`` / ``SimulationConfig`` objects.
* **Execution** – orchestrate complete workflows from a YAML definition.
* **Portable results and analysis** – inspect and copy artifacts produced by a run, re-run FE analysis on an existing execution, and read Cinnabar export bundles.
* **Utilities** – clone the state of an execution for reproducibility.
Typical usage
-------------
Run a workflow from a top-level YAML::
from batter.api import run_from_yaml
run_from_yaml("examples/mabfe_example.yaml")
Inspect FE records stored in a work directory::
from batter.api import list_fe_runs, load_fe_run
runs = list_fe_runs("work/adrb2")
latest = runs.iloc[-1]["run_id"]
# pass ``ligand`` when the run contains more than one ligand
record = load_fe_run("work/adrb2", latest, ligand="LIG1")
Run FE analysis on an existing execution::
from batter.api import run_analysis_from_execution
run_analysis_from_execution("work/adrb2", latest, ligand="LIG1")
For more examples, refer to ``docs/getting_started.rst`` and the tutorials.
"""
from __future__ import annotations
import json
import shutil
from pathlib import Path
from typing import Any, TYPE_CHECKING, Sequence, Union
from loguru import logger
from tqdm import tqdm
from ._version import __version__ # semantic version string
from .config.simulation import SimulationConfig
from .config.run import RunConfig
from .config import (
load_run_config,
dump_run_config,
load_simulation_config as load_sim_config,
dump_simulation_config as save_sim_config,
)
from .orchestrate.run import run_from_yaml, save_fe_records
from .runtime.portable import ArtifactStore
from .runtime.fe_repo import FEResultsRepository, FERecord, WindowResult
from .utils.exec_clone import clone_execution
from .systems.core import SimSystem, SystemMeta
from batter.pipeline.payloads import StepPayload
from batter.pipeline.step import Step
if TYPE_CHECKING:
import pandas as pd # type: ignore[assignment] # for type hints only
# Public names re-exported by ``batter.api``. The list is assembled group by
# group (``__all__ += [...]`` is a form static type checkers understand) so each
# section of the public surface is easy to find.
__all__ = ["__version__"]  # semantic version string

# Configuration models and their load/dump helpers.
__all__ += [
    "SimulationConfig",
    "RunConfig",
    "load_run_config",
    "dump_run_config",
    "load_sim_config",
    "save_sim_config",
]

# Workflow orchestration.
__all__ += ["run_from_yaml"]

# Portable artifact store, FE result records, and system handles.
__all__ += [
    "ArtifactStore",
    "FEResultsRepository",
    "FERecord",
    "WindowResult",
    "SimSystem",
]

# FE result inspection and post-hoc analysis.
__all__ += [
    "list_fe_runs",
    "load_fe_run",
    "read_cinnabar_outputs",
    "run_analysis_from_execution",
]

# Execution utilities.
__all__ += ["clone_execution"]
def _resolve_execution_run(work_root: Path, run_id: str | None) -> tuple[str, Path]:
"""Resolve an execution directory, defaulting to the most recent run."""
requested = (run_id or "").strip() or None
runs_root = work_root / "executions"
if requested:
run_dir = runs_root / requested
if not run_dir.is_dir():
raise FileNotFoundError(
f"Run '{requested}' does not exist under {work_root}."
)
return requested, run_dir
if not runs_root.is_dir():
raise FileNotFoundError(f"No executions found under {work_root}.")
candidates = [p for p in runs_root.iterdir() if p.is_dir()]
if not candidates:
raise FileNotFoundError(f"No executions found under {work_root}.")
latest = max(candidates, key=lambda p: p.stat().st_mtime)
logger.info(
f"No run_id provided; using latest execution '{latest.name}' under {work_root}."
)
return latest.name, latest
[docs]
def list_fe_runs(work_dir: Union[str, Path]) -> "pd.DataFrame":
    """
    Index the FE results stored in a portable work directory.

    Parameters
    ----------
    work_dir : str or Path
        Root of the portable BATTER work directory (the same path accepted by
        :func:`load_fe_run`).

    Returns
    -------
    pandas.DataFrame
        The table returned by ``FEResultsRepository.index()``. Each row is one
        stored FE record, so a run with several ligands contributes one row per
        ligand. Columns include ``run_id``, ``ligand``, ``mol_name``,
        ``system_name``, ``temperature``, ``total_dG``, ``total_se``,
        ``canonical_smiles``, ``original_name``, ``original_path``,
        ``protocol``, ``analysis_start_step``, ``n_bootstraps``, ``status``,
        ``failure_reason``, and ``created_at``.
    """
    return FEResultsRepository(ArtifactStore(Path(work_dir))).index()
[docs]
def load_fe_run(
    work_dir: Union[str, Path], run_id: str, ligand: str | None = None
) -> FERecord:
    """
    Load a single FE record by ``run_id`` from a portable work directory.

    Parameters
    ----------
    work_dir : str or Path
        Root directory of the BATTER execution.
    run_id : str
        Identifier of the FE run to load (as returned by :func:`list_fe_runs`).
    ligand : str, optional
        Ligand identifier when multiple ligands were processed in the run. If
        omitted (or empty), the run's sole ligand is looked up in the results
        index and loaded automatically.

    Returns
    -------
    FERecord
        Structured record containing total ΔG, standard error, components, and
        per-window results.

    Raises
    ------
    KeyError
        When ``ligand`` is omitted and the results index has no record for
        ``run_id``.
    ValueError
        When ``ligand`` is omitted and records for more than one distinct
        ligand exist under ``run_id``.
    """
    repo = FEResultsRepository(ArtifactStore(Path(work_dir)))
    if ligand:
        return repo.load(run_id, ligand)

    df = repo.index()
    # An empty index may carry no columns at all; indexing ``df["run_id"]``
    # would then raise pandas' bare KeyError('run_id') instead of ours.
    if df.empty or "run_id" not in df.columns:
        raise KeyError(f"No FE records found for run_id '{run_id}'.")

    matches = df[df["run_id"] == run_id]
    if matches.empty:
        raise KeyError(f"No FE records found for run_id '{run_id}'.")

    # Ambiguity is about distinct ligands, not rows: repeated rows for the same
    # ligand (e.g. if a run was re-analyzed and re-indexed) are not ambiguous.
    # ``unique`` keeps first-appearance order, so the first row's ligand wins.
    ligands = matches["ligand"].unique()
    if len(ligands) > 1:
        available = ", ".join(str(name) for name in ligands)
        raise ValueError(
            f"Multiple ligands stored for run_id '{run_id}' ({available}). "
            "Call `load_fe_run` with the `ligand` argument or inspect `list_fe_runs`."
        )
    return repo.load(run_id, ligands[0])
[docs]
def read_cinnabar_outputs(
    bundle_dir: Union[str, Path],
    *,
    require_absolute: bool = False,
):
    """
    Load a Cinnabar export bundle that was previously written to disk.

    Thin wrapper around ``batter.analysis.cinnabar.read_cinnabar_outputs``.

    Parameters
    ----------
    bundle_dir : str or Path
        Directory holding ``cinnabar_relative.csv`` plus, optionally, the
        absolute and SFC correction CSVs written by the Cinnabar export.
    require_absolute : bool, optional
        If ``True``, raise when the bundle has no ``cinnabar_absolute.csv``.
        Defaults to ``False``.

    Returns
    -------
    tuple[pandas.DataFrame, pandas.DataFrame]
        The relative and absolute tables. Each carries the uncorrected columns
        and, when those outputs exist, the SFC correction columns; free-energy
        units are recorded in a ``unit`` column. ``*_uncorrected`` values come
        from Cinnabar's CSVs and ``*_cycle_closure`` values from the SFC CSVs.
    """
    # Deferred import: the Cinnabar reader is only loaded on first use
    # (presumably to keep ``import batter.api`` lightweight).
    from batter.analysis.cinnabar import read_cinnabar_outputs as _load_bundle

    loaded = _load_bundle(bundle_dir, require_absolute=require_absolute)
    return loaded
[docs]
def _collect_rbfe_pair_systems(
    run_dir: Path,
    run_id: str,
    system_name: Any,
    entries: Sequence[dict[str, Any]],
    param_dir_dict: dict[str, Any],
    requested: str | None,
) -> list[SimSystem]:
    """Build one ``SimSystem`` per RBFE pair directory of an execution.

    Pair directories are the sub-directories of
    ``<run_dir>/simulations/transformations``. A name of the form ``REF~ALT``
    is split into its reference and alternate endpoint ligands. When
    ``requested`` is set, only pairs whose id, reference ligand, or alternate
    ligand equals it are kept.

    Raises
    ------
    FileNotFoundError
        If the transformations directory or a selected pair's ``fe`` directory
        is missing.
    KeyError
        If ``requested`` is set but matches no pair.
    RuntimeError
        If no pair directories exist at all.
    """
    trans_root = run_dir / "simulations" / "transformations"
    if not trans_root.is_dir():
        raise FileNotFoundError(
            f"RBFE transformations directory not found for run '{run_id}': {trans_root}"
        )
    # ligand name -> residue name, used to annotate both endpoints of a pair.
    lig_resname_map = {
        entry.get("ligand"): entry.get("residue_name")
        for entry in entries
        if entry.get("ligand")
    }

    children: list[SimSystem] = []
    for pair_dir in sorted(p for p in trans_root.iterdir() if p.is_dir()):
        pair_id = pair_dir.name
        ref, alt = pair_id.split("~", 1) if "~" in pair_id else (None, None)
        # ``ligand`` is kept for RBFE compatibility: it may target an exact
        # pair id or either endpoint ligand.
        if requested and requested not in (pair_id, ref, alt):
            continue
        fe_root = pair_dir / "fe"
        if not fe_root.is_dir():
            raise FileNotFoundError(
                f"Simulation directory for RBFE pair '{pair_id}' is missing FE data at {fe_root}."
            )
        residue_ref = lig_resname_map.get(ref) if ref else None
        residue_alt = lig_resname_map.get(alt) if alt else None
        meta = SystemMeta.from_mapping(
            {
                "ligand": pair_id,
                "residue_name": residue_ref,
                "mode": "RBFE",
                "param_dir_dict": dict(param_dir_dict),
                "pair_id": pair_id,
                "ligand_ref": ref,
                "ligand_alt": alt,
                "residue_ref": residue_ref,
                "residue_alt": residue_alt,
            }
        )
        children.append(
            SimSystem(
                name=f"{system_name}:{pair_id}:{run_id}",
                root=pair_dir,
                meta=meta,
            )
        )

    if requested and not children:
        raise KeyError(
            f"RBFE target '{requested}' not present in run '{run_id}' "
            "(expected a pair id like 'LIG1~LIG2' or an endpoint ligand name)."
        )
    if not children:
        raise RuntimeError(f"No RBFE transformation pairs found in run '{run_id}'.")
    return children


def _collect_ligand_systems(
    run_dir: Path,
    run_id: str,
    system_name: Any,
    entries: Sequence[dict[str, Any]],
    param_dir_dict: dict[str, Any],
    requested: str | None,
) -> list[SimSystem]:
    """Build one ``SimSystem`` per ligand entry (``<run_dir>/simulations/<ligand>``).

    When ``requested`` is set, only the entry whose ``ligand`` equals it is kept.

    Raises
    ------
    FileNotFoundError
        If a selected ligand's simulation directory is missing.
    KeyError
        If ``requested`` is set but matches no ligand entry.
    """
    children: list[SimSystem] = []
    for entry in entries:
        lig_name = entry["ligand"]
        if requested and lig_name != requested:
            continue
        child_root = run_dir / "simulations" / lig_name
        if not child_root.is_dir():
            raise FileNotFoundError(
                f"Simulation directory for ligand '{lig_name}' was not found at {child_root}."
            )
        meta = SystemMeta(
            ligand=lig_name,
            residue_name=entry.get("residue_name"),
            param_dir_dict=dict(param_dir_dict),
        )
        children.append(
            SimSystem(
                name=f"{system_name}:{lig_name}:{run_id}",
                root=child_root,
                meta=meta,
            )
        )
    if requested and not children:
        raise KeyError(f"Ligand '{requested}' not present in run '{run_id}'.")
    return children


def _export_rbfe_cinnabar_bundle(work_root: Path, run_id: str) -> None:
    """Write the RBFE Cinnabar bundle for ``run_id``; any failure is only logged."""
    try:
        from batter.analysis.cinnabar import auto_write_rbfe_cinnabar_for_run

        cinnabar_export = auto_write_rbfe_cinnabar_for_run(work_root, run_id)
        logger.info(
            f"Wrote RBFE Cinnabar bundle for run '{run_id}' to {cinnabar_export['output_dir']}."
        )
        if cinnabar_export.get("absolute_warning"):
            logger.warning(str(cinnabar_export["absolute_warning"]))
        if cinnabar_export.get("replicate_note"):
            logger.info(str(cinnabar_export["replicate_note"]))
    except Exception as exc:
        # Deliberately best-effort: the export is an optional by-product of
        # analysis, so a failure here (including a malformed export result) is
        # reported without failing the caller.
        logger.warning(
            f"Automatic RBFE Cinnabar export failed for run '{run_id}': {exc}"
        )


def run_analysis_from_execution(
    work_dir: Union[str, Path],
    run_id: str | None = None,
    *,
    ligand: str | None = None,
    components: Sequence[str] | None = None,
    n_workers: int | None = None,
    analysis_start_step: int | None = None,
    n_bootstraps: int | None = None,
    overwrite: bool = True,
    raise_on_error: bool = True,
) -> None:
    """
    Run FE analysis for a partially finished/finished execution.

    The resolved simulation config, run metadata, and ligand index are read
    from ``<work_dir>/executions/<run_id>/artifacts``. The analysis handler is
    run once per ligand (or per RBFE transformation pair), and the resulting FE
    records are saved to the work directory's results repository. For RBFE runs
    a Cinnabar bundle is also exported on a best-effort basis.

    Parameters
    ----------
    work_dir : str or Path
        Root directory containing the portable execution store.
    run_id : str, optional
        Identifier of the execution (e.g., ``run-20240101``). When omitted,
        the most recently modified execution under ``<work_dir>/executions`` is used.
    ligand : str, optional
        Target to analyze when only a subset is wanted. For RBFE runs this may
        be an exact pair id (``LIG1~LIG2``) or either endpoint ligand name.
        Surrounding whitespace is ignored.
    components : sequence of str, optional
        Components to include during analysis (overrides ``sim_cfg.components``).
    n_workers : int, optional
        Number of worker processes requested for the analysis handler; passed
        as both ``analysis_n_workers`` and ``n_workers`` in the step payload.
    analysis_start_step : int, optional
        First production step to include in analysis (per window); overrides
        config. Must be >= 0.
    n_bootstraps : int, optional
        Number of MBAR bootstrap resamples; overrides config. Must be >= 0.
    overwrite : bool, optional
        When ``True`` (default), existing ``fe/Results`` and ``fe/analyze.ok``
        outputs are removed and the analysis is re-run. When ``False``, targets
        that already have both outputs are skipped.
    raise_on_error : bool, optional
        When ``True`` (default) propagate errors raised by the analysis handler
        (wrapped in ``RuntimeError``). Set to ``False`` to log the failure and
        continue with the remaining targets.

    Raises
    ------
    ValueError
        If ``analysis_start_step`` or ``n_bootstraps`` is negative.
    FileNotFoundError
        If the execution, its ``sim.resolved.yaml``, its ligand index, or a
        selected simulation directory is missing.
    KeyError
        If ``ligand`` matches nothing in the run.
    RuntimeError
        If the run records no ligands, an RBFE run has no transformation
        pairs, or the analysis handler fails while ``raise_on_error`` is set.
    """
    # Validate overrides before any filesystem work so bad input fails fast.
    if analysis_start_step is not None and analysis_start_step < 0:
        raise ValueError("analysis_start_step must be >= 0.")
    if n_bootstraps is not None and n_bootstraps < 0:
        raise ValueError("n_bootstraps must be >= 0.")

    work_root = Path(work_dir)
    run_id, run_dir = _resolve_execution_run(work_root, run_id)
    config_dir = run_dir / "artifacts" / "config"

    sim_cfg_path = config_dir / "sim.resolved.yaml"
    if not sim_cfg_path.exists():
        raise FileNotFoundError(
            f"Simulation configuration missing for run '{run_id}' at {sim_cfg_path}."
        )
    sim_cfg = load_sim_config(sim_cfg_path)

    run_meta_path = config_dir / "run_meta.json"
    run_meta: dict[str, Any] = {}
    if run_meta_path.exists():
        run_meta = json.loads(run_meta_path.read_text()) or {}
    # An explicit null/empty protocol is treated like a missing one.
    protocol = run_meta.get("protocol") or "abfe"
    system_name = run_meta.get("system_name") or sim_cfg.system_name

    index_path = run_dir / "artifacts" / "ligand_params" / "index.json"
    if not index_path.exists():
        raise FileNotFoundError(
            f"Ligand index missing for run '{run_id}' at {index_path}."
        )
    ligands_payload = json.loads(index_path.read_text()) or {}
    entries = ligands_payload.get("ligands", [])
    if not entries:
        raise RuntimeError(f"No ligands recorded for run '{run_id}'.")
    # residue name -> parameter store directory, shared by every target.
    param_dir_dict = {
        entry.get("residue_name"): entry.get("store_dir")
        for entry in entries
        if entry.get("residue_name") and entry.get("store_dir")
    }

    requested = (ligand or "").strip() or None
    is_rbfe = str(protocol).lower() == "rbfe"
    if is_rbfe:
        children = _collect_rbfe_pair_systems(
            run_dir, run_id, system_name, entries, param_dir_dict, requested
        )
        target_label = "transformations"
    else:
        children = _collect_ligand_systems(
            run_dir, run_id, system_name, entries, param_dir_dict, requested
        )
        target_label = "ligands"
    logger.info(f"Running analysis for {len(children)} {target_label} in run '{run_id}'.")
    logger.info(f"Number of workers: {n_workers}")

    payload_data: dict[str, Any] = {"sim": sim_cfg}
    if components:
        payload_data["components"] = list(components)
    if n_workers is not None:
        payload_data["analysis_n_workers"] = n_workers
        payload_data["n_workers"] = n_workers

    if analysis_start_step is not None:
        analysis_start_step_val = int(analysis_start_step)
        logger.info(f"Analysis start step set to: {analysis_start_step_val}")
    else:
        # ``or 0`` guards against an explicit None in the config, which would
        # otherwise make int() raise TypeError (same fallback as n_bootstraps).
        analysis_start_step_val = int(getattr(sim_cfg, "analysis_start_step", 0) or 0)
        logger.info(f"Analysis start step loaded: {analysis_start_step_val}")
    payload_data["analysis_start_step"] = analysis_start_step_val

    if n_bootstraps is not None:
        n_bootstraps_val = int(n_bootstraps)
        logger.info(f"MBAR bootstrap resamples set to: {n_bootstraps_val}")
    else:
        n_bootstraps_val = int(getattr(sim_cfg, "n_bootstraps", 0) or 0)
        logger.info(f"MBAR bootstrap resamples loaded: {n_bootstraps_val}")
    payload_data["n_bootstraps"] = n_bootstraps_val

    params = StepPayload(**payload_data).to_mapping()
    analyze_step = Step(name="analyze")
    from batter.exec.handlers.fe_analysis import analyze_handler

    def _analysis_outputs_present(fe_root: Path) -> bool:
        # Both the results table and the completion marker must exist.
        return (fe_root / "Results" / "Results.dat").exists() and (
            fe_root / "analyze.ok"
        ).exists()

    def _clear_analysis_outputs(fe_root: Path) -> None:
        shutil.rmtree(fe_root / "Results", ignore_errors=True)
        (fe_root / "analyze.ok").unlink(missing_ok=True)

    skipped = 0
    progress_unit = "pair" if is_rbfe else "ligand"
    for child in tqdm(children, desc="Running analysis", unit=progress_unit):
        fe_root = child.root / "fe"
        ligand_name = child.meta.get("ligand") or child.name
        if not overwrite and _analysis_outputs_present(fe_root):
            logger.info(
                f"Skipping analysis for ligand '{ligand_name}' (results already exist; overwrite=False)."
            )
            skipped += 1
            continue
        if overwrite:
            _clear_analysis_outputs(fe_root)
        try:
            analyze_handler(analyze_step, child, params)
        except Exception as exc:
            msg = f"Analysis failed for ligand '{ligand_name}' in run '{run_id}': {exc}"
            if raise_on_error:
                raise RuntimeError(msg) from exc
            logger.warning(msg)
    if skipped:
        logger.info(f"Skipped analysis for {skipped} ligand(s) with existing results.")

    # Skipped targets are included on purpose so their existing results are
    # recorded alongside the freshly analyzed ones.
    repo = FEResultsRepository(ArtifactStore(work_root))
    failures = save_fe_records(
        run_dir=run_dir,
        run_id=run_id,
        children_all=children,
        sim_cfg_updated=sim_cfg,
        repo=repo,
        protocol=protocol,
        analysis_start_step=analysis_start_step_val,
        n_bootstraps=n_bootstraps_val,
    )
    if failures:
        failed = ", ".join(
            f"{name} ({status}: {reason})" for name, status, reason in failures
        )
        logger.warning(f"Analysis recorded issues for run '{run_id}': {failed}")

    if is_rbfe:
        _export_rbfe_cinnabar_bundle(work_root, run_id)
    logger.info(f"Analysis complete for run '{run_id}' ({run_dir}).")