from __future__ import annotations
import os
import tempfile
from filelock import FileLock
from dataclasses import dataclass
from datetime import datetime
from datetime import timezone
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional
import pandas as pd
from pydantic import BaseModel, Field, field_validator
from loguru import logger
import json
import shutil
from .portable import ArtifactStore, Artifact
# Public API re-exported by ``from <module> import *``.
__all__ = [
    "WindowResult",
    "FERecord",
    "FEResultsRepository",
]
[docs]
class WindowResult(BaseModel):
    """
    Free-energy estimate for one lambda window of one component.

    Parameters
    ----------
    component : str
        Component key (e.g., 'e', 'v', 'z').
    lam : float
        Lambda value; expected to lie in [0, 1] (not validated here).
    dG : float
        Free-energy increment for this window (kcal/mol).
    dG_se : float
        Standard error of ``dG`` (kcal/mol); defaults to 0.0.
    n_samples : int
        Number of samples (or effective sample size); defaults to 0.
    meta : dict
        Free-form extra metadata; defaults to an empty dict per instance.
    """

    component: str
    lam: float
    dG: float
    dG_se: float = Field(default=0.0)
    n_samples: int = Field(default=0)
    # default_factory gives every instance its own dict (no shared mutable default)
    meta: dict[str, Any] = Field(default_factory=dict)
[docs]
class FERecord(BaseModel):
    """
    A full FE result bundle (portable, versioned).

    Parameters
    ----------
    run_id : str
        Unique run identifier.
    ligand : str
        Ligand identifier.
    mol_name : str
        Molecule resname.
    system_name : str
        Logical system name.
    fe_type : str
        Protocol type (e.g., 'uno_rest', 'asfe').
    temperature : float
        Simulation temperature (K).
    method : {"mbar","ti"}
        Integration method.
    total_dG : float
        Total free energy (kcal/mol).
    total_se : float
        Standard error (kcal/mol).
    components : list[str]
        Active components in this run.
    created_at : str
        ISO-8601 timestamp in UTC, second precision, with an explicit
        ``+00:00`` offset (e.g. ``2024-01-01T12:00:00+00:00``). This is what
        ``datetime.isoformat`` emits for an aware UTC datetime; it is not the
        ``Z`` suffix form.
    windows : list[WindowResult]
        Per-window results.
    canonical_smiles : str, optional
        Canonicalised ligand SMILES captured during parameterization.
    original_name : str, optional
        Original ligand identifier or title when known.
    original_path : str, optional
        Source path of the ligand before staging.
    protocol : str
        Logical protocol used to generate the result (e.g., ``"abfe"``).
    analysis_start_step : int, optional
        First production step included in analysis.
    n_bootstraps : int, optional
        Number of MBAR bootstrap resamples used during analysis.
    include_in_analysis : bool
        Whether downstream aggregate analyses, such as Cinnabar export, should use
        this record.
    status : {"success","failed","unbound"}
        Final status recorded for the ligand.

    Notes
    -----
    ``analysis_start_step`` and ``n_bootstraps`` are coerced leniently: ``None``,
    ``pd.NA``, blank strings and values that ``int()`` cannot convert all become
    ``None`` instead of raising a validation error.
    """

    run_id: str
    ligand: str
    mol_name: str
    system_name: str
    fe_type: str
    temperature: float
    method: Literal["mbar", "ti"] = "mbar"
    total_dG: float
    total_se: float = 0.0
    components: List[str] = Field(default_factory=list)
    # aware UTC datetime -> "YYYY-MM-DDTHH:MM:SS+00:00"
    created_at: str = Field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat(timespec="seconds")
    )
    windows: List[WindowResult] = Field(default_factory=list)
    canonical_smiles: str | None = None
    original_name: str | None = None
    original_path: str | None = None
    protocol: str = "abfe"
    analysis_start_step: int | None = None
    n_bootstraps: int | None = None
    include_in_analysis: bool = True
    status: Literal["success", "failed", "unbound"] = "success"

    @field_validator("analysis_start_step", "n_bootstraps", mode="before")
    @classmethod
    def _coerce_optional_int(cls, v: Any) -> Any:
        """Map missing/blank/unparseable inputs to ``None``, otherwise ``int(v)``."""
        if v is None or v is pd.NA:
            return None
        if isinstance(v, str) and not v.strip():
            return None
        try:
            return int(v)
        except (TypeError, ValueError):
            # e.g. float('nan') or non-numeric text: treat as "not recorded"
            return None
[docs]
class FEResultsRepository:
    """Filesystem-backed repository of free-energy results.

    Everything lives under ``<store.root>/results``:

    * ``index.csv``: one row per
      ``(run_id, ligand, analysis_start_step, n_bootstraps)``. Every
      read/modify/write of it is serialized through ``.index.csv.lock``.
    * ``<run_id>/<ligand>/record.json``: serialized :class:`FERecord`,
      written by :meth:`save`.
    * ``<run_id>/<ligand>/failure.json``: failure marker, written by
      :meth:`record_failure` and removed by :meth:`save`.
    * ``<run_id>/<ligand>/Results/``: optional copy of raw artifacts
      made by :meth:`save`.

    Parameters
    ----------
    store : ArtifactStore
        Store whose ``root`` directory hosts the ``results`` tree.
    """

    # Column order used when (re)writing ``index.csv`` in ``_append_index_row``.
    _INDEX_COLUMNS: tuple[str, ...] = (
        "run_id",
        "ligand",
        "mol_name",
        "system_name",
        "temperature",
        "total_dG",
        "total_se",
        "canonical_smiles",
        "original_name",
        "original_path",
        "protocol",
        "analysis_start_step",
        "n_bootstraps",
        "include_in_analysis",
        "status",
        "failure_reason",
        "created_at",
    )
    # Leading column order of the frame returned by ``index()``; ``status`` and
    # ``failure_reason`` are appended after these.
    _PUBLIC_INDEX_COLUMNS: tuple[str, ...] = (
        "run_id",
        "ligand",
        "mol_name",
        "system_name",
        "temperature",
        "total_dG",
        "total_se",
        "canonical_smiles",
        "original_name",
        "original_path",
        "protocol",
        "analysis_start_step",
        "n_bootstraps",
        "include_in_analysis",
        "created_at",
    )
    # Key columns read back as text. Left to dtype inference, pandas parses
    # purely numeric identifiers (run_id "20240101", ligand "007") as numbers,
    # which breaks equality with the str keys callers pass and strips leading
    # zeros when the index is rewritten.
    _INDEX_TEXT_COLUMNS: tuple[str, ...] = ("run_id", "ligand")
    # Obsolete columns that ``index()`` hides from callers.
    _LEGACY_INDEX_COLUMNS: tuple[str, ...] = (
        "fe_type",
        "components",
        "method",
        "sim_range",
    )
    # Case-insensitive string spellings accepted by ``_normalize_bool``.
    _TRUE_TOKENS = frozenset(
        {"1", "true", "t", "yes", "y", "on", "enabled", "include", "included"}
    )
    _FALSE_TOKENS = frozenset(
        {"0", "false", "f", "no", "n", "off", "disabled", "exclude", "excluded"}
    )

    def __init__(self, store: "ArtifactStore") -> None:
        self.store = store
        self._root = store.root / "results"
        self._idx = self._root / "index.csv"
        self._idx_lock = self._root / ".index.csv.lock"

    def _lig_dir(self, run_id: str, ligand: str) -> Path:
        """Return the per-ligand directory ``results/<run_id>/<ligand>``."""
        return self._root / run_id / ligand

    def ligand_dir(self, run_id: str, ligand: str) -> Path:
        """Return the per-ligand result directory (not created here)."""
        return self._lig_dir(run_id, ligand)

    # ------------------------------------------------------------------ #
    # index.csv I/O helpers
    # ------------------------------------------------------------------ #
    def _read_index_csv(self) -> pd.DataFrame:
        """Read ``index.csv``, keeping the key columns as strings."""
        return pd.read_csv(
            self._idx, dtype={col: str for col in self._INDEX_TEXT_COLUMNS}
        )

    def _publish_index_file(self, tmp_path: str) -> None:
        """Atomically move a fully written temp file over ``index.csv``."""
        # ``mkstemp`` creates files as 0600. The shared FE index is intended to
        # be inspectable by collaborators, while remaining writable by owner
        # only. Fix the mode *before* the rename so the published file is never
        # briefly unreadable to others.
        os.chmod(tmp_path, 0o644)
        os.replace(tmp_path, self._idx)

    def _write_index_atomic(self, df: pd.DataFrame) -> None:
        """Write *df* to ``index.csv`` via temp file + fsync + replace."""
        fd, tmp = tempfile.mkstemp(
            prefix=self._idx.name + ".", suffix=".tmp", dir=str(self._idx.parent)
        )
        try:
            with os.fdopen(fd, "w", encoding="utf-8", newline="") as f:
                df.to_csv(f, index=False)
                f.flush()
                os.fsync(f.fileno())
            # publish only after the handle is closed (required on Windows)
            self._publish_index_file(tmp)
        finally:
            # After a successful replace the temp path is gone; otherwise
            # remove the partial file.
            try:
                os.unlink(tmp)
            except FileNotFoundError:
                pass

    # ------------------------------------------------------------------ #
    # value normalization
    # ------------------------------------------------------------------ #
    @staticmethod
    def _normalize_optional_int(value: Any) -> int | None:
        """Return ``int(value)``, or ``None`` for missing/blank/unparseable input."""
        if value is None or value is pd.NA:
            return None
        if isinstance(value, str):
            value = value.strip()
            if not value:
                return None
        try:
            return int(value)
        except (TypeError, ValueError):
            # covers float NaN read from CSV and non-numeric text
            return None

    @staticmethod
    def _normalize_n_bootstraps(value: Any) -> int:
        """Like ``_normalize_optional_int`` but maps "missing" to 0."""
        normalized = FEResultsRepository._normalize_optional_int(value)
        return 0 if normalized is None else normalized

    @staticmethod
    def _normalize_bool(value: Any, *, default: bool = True) -> bool:
        """Interpret CSV/user values as bool, using *default* for missing values.

        Recognized strings (case-insensitive) are in ``_TRUE_TOKENS`` /
        ``_FALSE_TOKENS``. Other values fall back to ``pd.isna`` (missing ->
        *default*) and then plain truthiness.
        """
        if value is None or value is pd.NA:
            return bool(default)
        if isinstance(value, str):
            text = value.strip().lower()
            if not text:
                return bool(default)
            if text in FEResultsRepository._TRUE_TOKENS:
                return True
            if text in FEResultsRepository._FALSE_TOKENS:
                return False
        try:
            if pd.isna(value):
                return bool(default)
        except Exception:
            # pd.isna on list-likes returns an array whose truth value may be
            # ambiguous; fall through to plain truthiness in that case.
            pass
        return bool(value)

    def _normalize_row(self, row: dict[str, Any]) -> dict[str, Any]:
        """Return a copy of *row* with defaults filled for missing index fields."""
        normalized = dict(row)
        normalized.setdefault("temperature", pd.NA)
        normalized.setdefault("total_dG", pd.NA)
        normalized.setdefault("total_se", pd.NA)
        normalized.setdefault("canonical_smiles", "")
        normalized.setdefault("original_name", "")
        normalized.setdefault("original_path", "")
        normalized.setdefault("protocol", "")
        normalized.setdefault("analysis_start_step", pd.NA)
        normalized.setdefault("n_bootstraps", 0)
        normalized.setdefault("include_in_analysis", True)
        normalized.setdefault(
            "created_at", datetime.now(timezone.utc).isoformat(timespec="seconds")
        )
        normalized.setdefault("status", "success")
        normalized.setdefault("failure_reason", "")
        return normalized

    def _append_index_row(self, row: dict[str, Any]) -> None:
        """Upsert *row* into ``index.csv``.

        Rows sharing ``(run_id, ligand, analysis_start_step, n_bootstraps)``
        with *row* are removed and *row* is appended. If such a row existed and
        *row* asks to be included, the existing ``include_in_analysis`` flag is
        carried over, so a manual exclusion survives a re-save.
        """
        row = self._normalize_row(row)
        cols = list(self._INDEX_COLUMNS)
        # serialize all index read/modify/write
        self._idx.parent.mkdir(parents=True, exist_ok=True)
        # (optionally: FileLock(...).acquire(timeout=120) if you want a timeout)
        with FileLock(str(self._idx_lock)):
            if self._idx.exists():
                df = self._read_index_csv()
            else:
                df = pd.DataFrame(columns=cols)
            for col in cols:
                if col not in df.columns:
                    df[col] = pd.NA

            row_step = self._normalize_optional_int(row.get("analysis_start_step"))
            row_bootstraps = self._normalize_n_bootstraps(row.get("n_bootstraps"))
            step_series = df["analysis_start_step"].map(self._normalize_optional_int)
            bootstrap_series = df["n_bootstraps"].map(self._normalize_n_bootstraps)
            same_step = (
                step_series.isna() if row_step is None else step_series == row_step
            )
            same_key = (
                (df["run_id"] == row["run_id"])
                & (df["ligand"] == row["ligand"])
                & same_step
                & (bootstrap_series == row_bootstraps)
            )

            existing = df.loc[same_key]
            if not existing.empty and self._normalize_bool(
                row.get("include_in_analysis"), default=True
            ):
                row["include_in_analysis"] = self._normalize_bool(
                    existing.iloc[-1].get("include_in_analysis"),
                    default=True,
                )
            logger.info(
                "Updating index for run_id={}, ligand={}, analysis_start_step={}, n_bootstraps={}",
                row["run_id"],
                row["ligand"],
                row_step,
                row_bootstraps,
            )
            df = df.loc[~same_key].reset_index(drop=True)

            # append/upsert row; rebuilding from records avoids concat dtype
            # warnings with all-NA columns
            records = df[cols].to_dict(orient="records")
            records.append({col: row.get(col, pd.NA) for col in cols})
            df = pd.DataFrame.from_records(records, columns=cols)

            self._write_index_atomic(df)

    # ------------------------------------------------------------------ #
    # public API
    # ------------------------------------------------------------------ #
    def save(self, rec: FERecord, copy_from: Path | None = None) -> None:
        """Persist *rec* and upsert its index row.

        Writes ``record.json``, removes any stale ``failure.json``, optionally
        replaces ``Results/`` with a copy of *copy_from* (when it exists), then
        upserts by ``(run_id, ligand, analysis_start_step, n_bootstraps)``.
        """
        lig_dir = self._lig_dir(rec.run_id, rec.ligand)
        lig_dir.mkdir(parents=True, exist_ok=True)
        # clear any stale failure marker when writing a success record
        (lig_dir / "failure.json").unlink(missing_ok=True)
        # model_dump(mode="json") turns nested WindowResult models into plain
        # dicts; json.dumps(rec.__dict__) raised TypeError whenever windows
        # was non-empty.
        payload = rec.model_dump(mode="json")
        (lig_dir / "record.json").write_text(json.dumps(payload, indent=2))
        # optional: copy raw Results/ in, keeping raw artifacts alongside the record
        if copy_from and copy_from.exists():
            shutil.rmtree(lig_dir / "Results", ignore_errors=True)
            shutil.copytree(copy_from, lig_dir / "Results")

        step = rec.analysis_start_step
        bootstraps = rec.n_bootstraps
        row = {
            "run_id": rec.run_id,
            "ligand": rec.ligand,
            "mol_name": rec.mol_name,
            "system_name": rec.system_name,
            "temperature": rec.temperature,
            "total_dG": rec.total_dG,
            "total_se": rec.total_se,
            "canonical_smiles": rec.canonical_smiles or "",
            "original_name": rec.original_name or "",
            "original_path": rec.original_path or "",
            "protocol": rec.protocol,
            "analysis_start_step": int(step) if step is not None else pd.NA,
            "n_bootstraps": int(bootstraps) if bootstraps is not None else 0,
            "include_in_analysis": rec.include_in_analysis,
            "created_at": rec.created_at,
            "status": rec.status,
            "failure_reason": pd.NA,
        }
        self._append_index_row(row)

    def index(self) -> "pd.DataFrame":
        """Return the results index with a stable column set.

        Legacy columns are dropped. Missing columns are added (``n_bootstraps``
        -> 0, ``include_in_analysis`` -> True, others -> NA). ``n_bootstraps``
        NAs become 0, ``include_in_analysis`` is normalized to bool and
        ``failure_reason`` NAs become "". Columns are ordered as
        ``_PUBLIC_INDEX_COLUMNS`` followed by ``status`` and ``failure_reason``.
        Reads without taking the lock; writers replace the file atomically.
        """
        cols = list(self._PUBLIC_INDEX_COLUMNS)
        if self._idx.exists():
            df = self._read_index_csv()
        else:
            df = pd.DataFrame(columns=cols)

        legacy = [col for col in self._LEGACY_INDEX_COLUMNS if col in df.columns]
        df = df.drop(columns=legacy)

        for key in ("status", "failure_reason"):
            if key not in df.columns:
                df[key] = pd.NA
        for col in cols:
            if col in df.columns:
                continue
            if col == "n_bootstraps":
                df[col] = 0
            elif col == "include_in_analysis":
                df[col] = True
            else:
                df[col] = pd.NA

        df["n_bootstraps"] = df["n_bootstraps"].fillna(0)
        df["include_in_analysis"] = df["include_in_analysis"].map(
            lambda value: self._normalize_bool(value, default=True)
        )
        df["failure_reason"] = df["failure_reason"].fillna("")
        return df[cols + ["status", "failure_reason"]]

    def set_analysis_inclusion(
        self,
        *,
        run_id: str,
        ligand: str,
        include: bool,
        analysis_start_step: int | None = None,
        n_bootstraps: int | None = None,
    ) -> int:
        """Set ``include_in_analysis`` for matching rows in ``results/index.csv``.

        Rows are matched on ``run_id`` and ``ligand``, and additionally on
        ``analysis_start_step`` / ``n_bootstraps`` when those are given.

        Returns
        -------
        int
            Number of rows updated. The file is not rewritten when this is 0.

        Raises
        ------
        FileNotFoundError
            If ``index.csv`` does not exist.
        """
        if not self._idx.exists():
            raise FileNotFoundError(f"Missing FE results index: {self._idx}")
        with FileLock(str(self._idx_lock)):
            df = self._read_index_csv()
            for col, fill in (
                ("run_id", pd.NA),
                ("ligand", pd.NA),
                ("analysis_start_step", pd.NA),
                ("n_bootstraps", 0),
                ("include_in_analysis", True),
            ):
                if col not in df.columns:
                    df[col] = fill

            mask = (df["run_id"].astype(str) == str(run_id)) & (
                df["ligand"].astype(str) == str(ligand)
            )
            if analysis_start_step is not None:
                step_series = df["analysis_start_step"].map(self._normalize_optional_int)
                mask &= step_series == int(analysis_start_step)
            if n_bootstraps is not None:
                bootstrap_series = df["n_bootstraps"].map(self._normalize_n_bootstraps)
                mask &= bootstrap_series == int(n_bootstraps)

            n_updated = int(mask.sum())
            if n_updated == 0:
                return 0
            df.loc[mask, "include_in_analysis"] = bool(include)
            self._write_index_atomic(df)
        return n_updated

    def record_failure(
        self,
        run_id: str,
        ligand: str,
        system_name: str,
        temperature: float,
        *,
        status: Literal["failed", "unbound"],
        reason: str | None = None,
        canonical_smiles: str | None = None,
        original_name: str | None = None,
        original_path: str | None = None,
        protocol: str = "abfe",
        analysis_start_step: int | None = None,
        n_bootstraps: int | None = None,
    ) -> None:
        """Write ``failure.json`` for the ligand and upsert a failure index row.

        The index row has NA ``total_dG``/``total_se`` and an empty
        ``mol_name``. ``failure_reason`` is *reason* (or ""). An existing
        ``record.json`` is not touched.
        """
        lig_dir = self._lig_dir(run_id, ligand)
        lig_dir.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now(timezone.utc).isoformat(timespec="seconds")
        failure_detail = {
            "run_id": run_id,
            "ligand": ligand,
            "status": status,
            "reason": reason or "",
            "protocol": protocol,
            "timestamp": timestamp,
        }
        (lig_dir / "failure.json").write_text(json.dumps(failure_detail, indent=2))

        row = {
            "run_id": run_id,
            "ligand": ligand,
            "mol_name": "",
            "system_name": system_name,
            "temperature": temperature,
            "total_dG": pd.NA,
            "total_se": pd.NA,
            "canonical_smiles": canonical_smiles or "",
            "original_name": original_name or "",
            "original_path": original_path or "",
            "protocol": protocol,
            "analysis_start_step": (
                int(analysis_start_step) if analysis_start_step is not None else pd.NA
            ),
            "n_bootstraps": int(n_bootstraps) if n_bootstraps is not None else 0,
            "include_in_analysis": True,
            "status": status,
            "failure_reason": reason or "",
            "created_at": timestamp,
        }
        self._append_index_row(row)

    def load(self, run_id: str, ligand: str) -> FERecord:
        """Load ``record.json`` for ``(run_id, ligand)`` into an :class:`FERecord`.

        ``components`` may be stored either as a list or as a comma-separated
        string. Empty pieces are dropped, so ``""`` yields ``[]``.
        """
        path = self._lig_dir(run_id, ligand) / "record.json"
        data = json.loads(path.read_text())
        components = data.get("components", [])
        if isinstance(components, str):
            components = [part.strip() for part in components.split(",") if part.strip()]
        data["components"] = components
        return FERecord(**data)