Source code for batter.runtime.fe_repo

from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime
from datetime import timezone

from pathlib import Path
from typing import Any, Dict, List, Literal, Optional

import pandas as pd
from pydantic import BaseModel, Field
import json
import shutil

from .portable import ArtifactStore, Artifact


# Names re-exported by ``from batter.runtime.fe_repo import *``.
__all__ = [
    "WindowResult",
    "FERecord",
    "FEResultsRepository",
]


class WindowResult(BaseModel):
    """
    Free-energy contribution of a single lambda window of one component.

    Parameters
    ----------
    component : str
        Component key (e.g., 'e', 'v', 'z').
    lam : float
        Lambda value of the window, expected in [0, 1] (not validated here).
    dG : float
        Free-energy increment (kcal/mol).
    dG_se : float, default 0.0
        Standard error of ``dG`` (kcal/mol).
    n_samples : int, default 0
        Number of samples (or effective sample size).
    meta : dict, default empty
        Free-form extra metadata; every instance gets its own fresh dict.
    """

    # Required: which window this is and its free-energy increment.
    component: str
    lam: float
    dG: float
    # Optional statistics, zero when not supplied.
    dG_se: float = Field(default=0.0)
    n_samples: int = Field(default=0)
    # default_factory avoids sharing one dict between instances.
    meta: Dict[str, Any] = Field(default_factory=dict)


class FERecord(BaseModel):
    """
    A full FE result bundle (portable, versioned).

    Parameters
    ----------
    run_id : str
        Unique run identifier.
    ligand : str
        Ligand identifier.
    mol_name : str
        Molecule resname.
    system_name : str
        Logical system name.
    fe_type : str
        Protocol type (e.g., 'uno_rest', 'asfe').
    temperature : float
        Simulation temperature (K).
    method : {"mbar","ti"}
        Integration method (default ``"mbar"``).
    total_dG : float
        Total free energy (kcal/mol).
    total_se : float
        Standard error (kcal/mol); defaults to 0.0.
    components : list[str]
        Active components in this run.
    created_at : str
        ISO-8601 timestamp in UTC with seconds precision and an explicit
        ``+00:00`` offset (e.g. ``"2024-01-01T12:00:00+00:00"``; note: not a
        ``Z`` suffix). Defaults to the time the record is constructed.
    windows : list[WindowResult]
        Per-window results.
    canonical_smiles : str, optional
        Canonicalised ligand SMILES captured during parameterization.
    original_name : str, optional
        Original ligand identifier or title when known.
    original_path : str, optional
        Source path of the ligand before staging.
    protocol : str
        Logical protocol used to generate the result (default ``"abfe"``).
    sim_range : tuple[int, int], optional
        (start, end) lambda range used for analysis.
    status : {"success","failed","unbound"}
        Final status recorded for the ligand (default ``"success"``).
    """

    run_id: str
    ligand: str
    mol_name: str
    system_name: str
    fe_type: str
    temperature: float
    method: Literal["mbar", "ti"] = "mbar"
    total_dG: float
    total_se: float = 0.0
    components: List[str] = Field(default_factory=list)
    # isoformat() on an aware UTC datetime renders the offset as "+00:00".
    created_at: str = Field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat(timespec="seconds")
    )
    windows: List[WindowResult] = Field(default_factory=list)
    canonical_smiles: str | None = None
    original_name: str | None = None
    original_path: str | None = None
    protocol: str = "abfe"
    sim_range: tuple[int, int] | None = None
    status: Literal["success", "failed", "unbound"] = "success"


[docs] class FEResultsRepository: def __init__(self, store: "ArtifactStore") -> None: self.store = store self._root = store.root / "results" self._idx = self._root / "index.csv" def _lig_dir(self, run_id: str, ligand: str) -> Path: return self._root / run_id / ligand def _normalize_row(self, row: dict[str, Any]) -> dict[str, Any]: normalized = dict(row) normalized.setdefault("temperature", pd.NA) normalized.setdefault("total_dG", pd.NA) normalized.setdefault("total_se", pd.NA) normalized.setdefault("canonical_smiles", "") normalized.setdefault("original_name", "") normalized.setdefault("original_path", "") normalized.setdefault("protocol", "") normalized.setdefault("sim_range", "") normalized.setdefault( "created_at", datetime.now(timezone.utc).isoformat(timespec="seconds") ) normalized.setdefault("status", "success") normalized.setdefault("failure_reason", "") return normalized def _append_index_row(self, row: dict[str, Any]) -> None: row = self._normalize_row(row) cols = [ "run_id", "ligand", "mol_name", "system_name", "temperature", "total_dG", "total_se", "canonical_smiles", "original_name", "original_path", "protocol", "sim_range", "status", "failure_reason", "created_at", ] if self._idx.exists(): df = pd.read_csv(self._idx) # Remove existing entry with same (run_id, ligand) if {"run_id", "ligand"}.issubset(df.columns): df = df[ ~((df["run_id"] == row["run_id"]) & (df["ligand"] == row["ligand"])) ].copy() else: # Start from an empty DataFrame with the right columns df = pd.DataFrame(columns=cols) # Ensure all expected columns exist for col in cols: if col not in df.columns: df[col] = pd.NA # Append the new row without using concat # Make sure we only write known columns; fill missing with NA new_row = {col: row.get(col, pd.NA) for col in cols} if df.empty: df = pd.DataFrame([new_row], columns=cols) else: df.loc[len(df)] = new_row # Enforce column order df = df[cols] df.to_csv(self._idx, index=False)
[docs] def save(self, rec: FERecord, copy_from: Path | None = None) -> None: lig_dir = self._lig_dir(rec.run_id, rec.ligand) lig_dir.mkdir(parents=True, exist_ok=True) # clear any stale failure marker when writing a success record (lig_dir / "failure.json").unlink(missing_ok=True) # write JSON record (lig_dir / "record.json").write_text(json.dumps(rec.__dict__, indent=2)) # optional: copy raw Results/ in if copy_from and copy_from.exists(): # keep raw artifacts alongside the record shutil.rmtree(lig_dir / "Results", ignore_errors=True) shutil.copytree(copy_from, lig_dir / "Results") # update index table (append-or-upsert by (run_id, ligand)) row = { "run_id": rec.run_id, "ligand": rec.ligand, "mol_name": rec.mol_name, "system_name": rec.system_name, "temperature": rec.temperature, "total_dG": rec.total_dG, "total_se": rec.total_se, "canonical_smiles": rec.canonical_smiles or "", "original_name": rec.original_name or "", "original_path": rec.original_path or "", "protocol": rec.protocol, "sim_range": rec.sim_range if rec.sim_range is not None else "", "created_at": rec.created_at, "status": rec.status, "failure_reason": "", } self._append_index_row(row)
[docs] def index(self) -> "pd.DataFrame": cols = [ "run_id", "ligand", "mol_name", "system_name", "temperature", "total_dG", "total_se", "canonical_smiles", "original_name", "original_path", "protocol", "sim_range", "created_at", ] if self._idx.exists(): df = pd.read_csv(self._idx) else: df = pd.DataFrame(columns=cols) # drop old columns if present for drop in ("fe_type", "components", "method"): if drop in df.columns: df = df.drop(columns=[drop]) # ensure columns exist for key in ("status", "failure_reason"): if key not in df.columns: df[key] = pd.NA for col in cols: if col not in df.columns: df[col] = pd.NA df["failure_reason"] = df["failure_reason"].fillna("") return df[cols + ["status", "failure_reason"]]
[docs] def record_failure( self, run_id: str, ligand: str, system_name: str, temperature: float, *, status: Literal["failed", "unbound"], reason: str | None = None, canonical_smiles: str | None = None, original_name: str | None = None, original_path: str | None = None, protocol: str = "abfe", sim_range: tuple[int, int] | None = None, ) -> None: lig_dir = self._lig_dir(run_id, ligand) lig_dir.mkdir(parents=True, exist_ok=True) failure_detail = { "run_id": run_id, "ligand": ligand, "status": status, "reason": reason or "", "protocol": protocol, "timestamp": datetime.now(timezone.utc).isoformat(timespec="seconds"), } (lig_dir / "failure.json").write_text(json.dumps(failure_detail, indent=2)) row = { "run_id": run_id, "ligand": ligand, "mol_name": "", "system_name": system_name, "temperature": temperature, "total_dG": pd.NA, "total_se": pd.NA, "canonical_smiles": canonical_smiles or "", "original_name": original_name or "", "original_path": original_path or "", "protocol": protocol, "sim_range": sim_range or "", "status": status, "failure_reason": reason or "", "created_at": failure_detail["timestamp"], } self._append_index_row(row)
[docs] def load(self, run_id: str, ligand: str) -> FERecord: p = self._lig_dir(run_id, ligand) / "record.json" d = json.loads(p.read_text()) d["components"] = ( d.get("components", "").split(",") if isinstance(d.get("components"), str) else d.get("components", []) ) return FERecord(**d)