"""State-function-based free-energy correction for RBFE networks.

Acknowledgement
---------------
This module implements the matrix-based State-Function Based Free Energy
Correction (SFC) workflow for BATTER's analysis API, following the article
and supporting information cited below.

Reference
---------
Liu, R.; Lai, Y.; Yao, Y.; Huang, W.; Zhong, Y.; Luo, H.-B.; Li, Z.
State Function-Based Correction: A Simple and Efficient Free-Energy Correction
Algorithm for Large-Scale Relative Binding Free-Energy Calculations.
J. Phys. Chem. Lett. 2025, 16, 23, 5763-5768.
doi:10.1021/acs.jpclett.5c01119

The historical ``cycle_closure_*`` function names are kept for compatibility
with the existing BATTER Cinnabar integration. They now run SFC/WSFC rather
than the earlier cycle-enumeration WCC algorithm.
"""
from __future__ import annotations
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Sequence
import numpy as np
import pandas as pd
# Public API. Both the historical ``cycle_closure_*`` / ``CycleClosure*``
# names and their ``state_function_correction`` / ``StateFunctionCorrection*``
# aliases are exported.
__all__ = [
    "CycleClosureEdge",
    "CycleClosureResult",
    "StateFunctionCorrectionEdge",
    "StateFunctionCorrectionResult",
    "calculate_cycle_closure",
    "calculate_state_function_correction",
    "cycle_closure_from_dataframe",
    "cycle_closure_from_file",
    "read_cycle_closure_file",
    "read_state_function_correction_file",
    "state_function_correction_from_dataframe",
    "state_function_correction_from_file",
]

# Value substituted for an edge uncertainty of exactly zero, so that the
# inverse-square WSFC weights derived from it stay finite.
SFC_MIN_UNCERTAINTY = 1.0e-6
@dataclass(frozen=True)
class CycleClosureEdge:
    """One directed RBFE edge used as SFC input.

    Parameters
    ----------
    label_a, label_b
        Ligand labels defining the edge direction.
    ddg
        Relative free energy for ``label_a -> label_b``.
    uncertainties
        Optional standard-error columns. Each supplied column creates one WSFC
        estimate using uncertainty-derived weights.

    Notes
    -----
    The dataclass itself performs no validation. When an edge is passed to
    ``calculate_cycle_closure`` its labels are stripped and must be non-empty
    and distinct, ``ddg`` must be finite, and uncertainties must be finite
    and non-negative. Exact zeros are replaced by ``SFC_MIN_UNCERTAINTY``.
    """

    label_a: str
    label_b: str
    ddg: float
    uncertainties: tuple[float, ...] = ()
@dataclass(frozen=True)
class CycleClosureResult:
    """SFC result tables and metadata.

    Attributes
    ----------
    reference
        Label of the ligand whose free energy is pinned.
    reference_free_energy
        Free energy assigned to ``reference``.
    node_results
        Per-ligand table with columns ``label``, ``dG_<scheme>``,
        ``path_dependent_error_<scheme>`` (max incident edge residual),
        ``path_independent_error_<scheme>`` (RMS incident edge residual), and
        the unsuffixed ``path_dependent_error`` / ``path_independent_error``
        taken from the last scheme.
    edge_results
        Per-edge table with columns ``labelA``, ``labelB``, ``ddG_<scheme>``,
        ``pair_error_<scheme>`` (``|input ddG - corrected ddG|``), and the
        unsuffixed ``pair_error`` taken from the last scheme.
    cycles
        Kept for compatibility with the earlier cycle-enumeration WCC code.
        SFC does not enumerate cycles, and ``calculate_cycle_closure`` leaves
        this empty.
    iterations, converged
        One entry per scheme. The direct least-squares solve reports ``1``
        and ``True``.
    method
        Correction method identifier (``"sfc"``).
    schemes
        Scheme names in solve order: ``"sfc"`` followed by ``"wsfc1"``,
        ``"wsfc2"``, ... (one per uncertainty column).

    Notes
    -----
    ``node_results`` and ``edge_results`` are ``pandas.DataFrame`` objects.
    As a result, the dataclass-generated ``__hash__`` raises ``TypeError``,
    and ``==`` between two results holding distinct tables may raise instead
    of returning a bool.
    """

    reference: str
    reference_free_energy: float
    node_results: pd.DataFrame
    edge_results: pd.DataFrame
    cycles: tuple[tuple[str, ...], ...] = ()
    iterations: tuple[int, ...] = ()
    converged: tuple[bool, ...] = ()
    method: str = "sfc"
    schemes: tuple[str, ...] = ()
# Descriptive SFC names for the historical ``CycleClosure*`` classes; both
# spellings are bound to the very same class objects.
StateFunctionCorrectionEdge, StateFunctionCorrectionResult = (
    CycleClosureEdge,
    CycleClosureResult,
)
def _coerce_edges(
    edges: Iterable[CycleClosureEdge | Sequence[object]],
) -> tuple[CycleClosureEdge, ...]:
    """Validate and normalize raw SFC edges.

    Each item is either a ``CycleClosureEdge`` or a sequence laid out as
    ``(label_a, label_b, ddg, *uncertainties)``. Labels are converted to
    ``str`` and stripped, and numeric values are converted to ``float``.
    Uncertainties of exactly zero are replaced by ``SFC_MIN_UNCERTAINTY``.

    Raises
    ------
    ValueError
        If any of the following holds:

        - a sequence has fewer than three values;
        - a label is empty;
        - an edge joins a ligand to itself;
        - a ddG is non-finite;
        - an uncertainty is non-finite or negative;
        - no edges are given;
        - edges disagree on the number of uncertainty columns.
    """
    validated: list[CycleClosureEdge] = []
    for item in edges:
        if isinstance(item, CycleClosureEdge):
            source = item
        elif len(item) < 3:
            raise ValueError("SFC edge sequences must contain at least 3 values.")
        else:
            first, second, raw_ddg, *raw_errors = item
            source = CycleClosureEdge(
                label_a=str(first),
                label_b=str(second),
                ddg=float(raw_ddg),
                uncertainties=tuple(float(err) for err in raw_errors),
            )

        start = str(source.label_a).strip()
        end = str(source.label_b).strip()
        if not (start and end):
            raise ValueError("SFC edge labels cannot be empty.")
        if start == end:
            raise ValueError("SFC edges cannot connect a ligand to itself.")
        if not math.isfinite(float(source.ddg)):
            raise ValueError(f"Non-finite SFC ddG for {start}->{end}.")

        errors = tuple(float(err) for err in source.uncertainties)
        if not all(math.isfinite(err) and err >= 0 for err in errors):
            raise ValueError(
                f"Uncertainties for {start}->{end} must be finite and >= 0."
            )
        validated.append(
            CycleClosureEdge(
                label_a=start,
                label_b=end,
                ddg=float(source.ddg),
                # An exact zero would yield an infinite inverse-square weight.
                uncertainties=tuple(
                    err if err != 0 else SFC_MIN_UNCERTAINTY for err in errors
                ),
            )
        )

    if not validated:
        raise ValueError("SFC requires at least one edge.")
    if len({len(edge.uncertainties) for edge in validated}) > 1:
        raise ValueError("All SFC edges must use the same uncertainty columns.")
    return tuple(validated)
def _ordered_labels(
edges: Sequence[CycleClosureEdge],
reference: str | None,
) -> list[str]:
labels: list[str] = []
for edge in edges:
for label in (edge.label_a, edge.label_b):
if label not in labels:
labels.append(label)
if reference is None:
return labels
reference = str(reference).strip()
if reference not in labels:
raise ValueError(f"Reference ligand {reference!r} is not present in the graph.")
return [reference, *[label for label in labels if label != reference]]
def _design_matrix(
edges: Sequence[CycleClosureEdge],
labels: Sequence[str],
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
label_index = {label: idx for idx, label in enumerate(labels)}
n_edges = len(edges)
n_labels = len(labels)
n_uncertainty_cols = len(edges[0].uncertainties)
a_matrix = np.zeros((n_edges, n_labels), dtype=float)
b_vector = np.zeros(n_edges, dtype=float)
uncertainty_matrix = np.zeros((n_edges, n_uncertainty_cols), dtype=float)
for row_idx, edge in enumerate(edges):
idx_a = label_index[edge.label_a]
idx_b = label_index[edge.label_b]
a_matrix[row_idx, idx_a] = -1.0
a_matrix[row_idx, idx_b] = 1.0
b_vector[row_idx] = edge.ddg
for col_idx, uncertainty in enumerate(edge.uncertainties):
uncertainty_matrix[row_idx, col_idx] = uncertainty
return a_matrix, b_vector, uncertainty_matrix
def _validate_connected_system(a_matrix: np.ndarray, n_labels: int) -> None:
rank = int(np.linalg.matrix_rank(a_matrix))
if rank < n_labels - 1:
raise ValueError(
"SFC requires a connected RBFE graph; the design matrix is rank "
f"deficient (rank={rank}, expected at least {n_labels - 1})."
)
def _uncertainty_weights(uncertainties: np.ndarray) -> np.ndarray:
total = float(np.sum(uncertainties))
if not math.isfinite(total) or total <= 0:
raise ValueError("SFC uncertainty weights require positive uncertainties.")
normalized = uncertainties / total
return 1.0 / np.square(normalized)
def _solve_state_function(
a_matrix: np.ndarray,
b_vector: np.ndarray,
*,
weights: np.ndarray | None,
reference_index: int,
reference_free_energy: float,
reference_weight: float,
) -> np.ndarray:
n_labels = a_matrix.shape[1]
ref_row = np.zeros((1, n_labels), dtype=float)
ref_row[0, reference_index] = 1.0
a_aug = np.vstack([a_matrix, ref_row])
b_aug = np.concatenate([b_vector, [float(reference_free_energy)]])
if weights is None:
weights_aug = np.ones(a_aug.shape[0], dtype=float)
weights_aug[-1] = float(reference_weight)
else:
if len(weights) != len(b_vector):
raise ValueError("SFC weights must match the number of RBFE edges.")
weights_aug = np.concatenate([np.asarray(weights, dtype=float), [reference_weight]])
if np.any(~np.isfinite(weights_aug)) or np.any(weights_aug <= 0):
raise ValueError("SFC weights must be finite and > 0.")
sqrt_weights = np.sqrt(weights_aug)
weighted_a = a_aug * sqrt_weights[:, None]
weighted_b = b_aug * sqrt_weights
solution, *_ = np.linalg.lstsq(weighted_a, weighted_b, rcond=None)
# A constant shift leaves every predicted ddG unchanged and makes the
# reported reference free energy exact instead of merely high-weighted.
solution = solution + (float(reference_free_energy) - solution[reference_index])
return solution
def _edge_dataframe(
edges: Sequence[CycleClosureEdge],
labels: Sequence[str],
scheme_vectors: dict[str, np.ndarray],
selected_scheme: str,
) -> pd.DataFrame:
label_index = {label: idx for idx, label in enumerate(labels)}
records: list[dict[str, float | str]] = []
for edge in edges:
idx_a = label_index[edge.label_a]
idx_b = label_index[edge.label_b]
record: dict[str, float | str] = {
"labelA": edge.label_a,
"labelB": edge.label_b,
}
for scheme, vector in scheme_vectors.items():
predicted = float(vector[idx_b] - vector[idx_a])
record[f"ddG_{scheme}"] = predicted
record[f"pair_error_{scheme}"] = abs(float(edge.ddg) - predicted)
record["pair_error"] = float(record[f"pair_error_{selected_scheme}"])
records.append(record)
return pd.DataFrame.from_records(records)
def _node_error_vectors(
edges: Sequence[CycleClosureEdge],
labels: Sequence[str],
vector: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
label_index = {label: idx for idx, label in enumerate(labels)}
incident_errors: list[list[float]] = [[] for _ in labels]
for edge in edges:
idx_a = label_index[edge.label_a]
idx_b = label_index[edge.label_b]
predicted = float(vector[idx_b] - vector[idx_a])
residual = abs(float(edge.ddg) - predicted)
incident_errors[idx_a].append(residual)
incident_errors[idx_b].append(residual)
max_error = np.zeros(len(labels), dtype=float)
rms_error = np.zeros(len(labels), dtype=float)
for idx, errors in enumerate(incident_errors):
if not errors:
continue
arr = np.asarray(errors, dtype=float)
max_error[idx] = float(np.max(arr))
rms_error[idx] = float(np.sqrt(np.mean(np.square(arr))))
return max_error, rms_error
def _node_dataframe(
    edges: Sequence[CycleClosureEdge],
    labels: Sequence[str],
    scheme_vectors: dict[str, np.ndarray],
    selected_scheme: str,
) -> pd.DataFrame:
    """Tabulate per-ligand free energies and incident-residual summaries.

    Columns:

    - ``label``;
    - per scheme: ``dG_<scheme>``, ``path_dependent_error_<scheme>`` (max
      incident edge residual) and ``path_independent_error_<scheme>`` (RMS
      incident edge residual);
    - the unsuffixed ``path_dependent_error`` and ``path_independent_error``,
      copied from ``selected_scheme``.
    """
    errors_by_scheme = {
        name: _node_error_vectors(edges, labels, values)
        for name, values in scheme_vectors.items()
    }

    rows: list[dict[str, float | str]] = []
    for i, name in enumerate(labels):
        row: dict[str, float | str] = {"label": name}
        for scheme, values in scheme_vectors.items():
            worst, rms = errors_by_scheme[scheme]
            row.update(
                {
                    f"dG_{scheme}": float(values[i]),
                    f"path_dependent_error_{scheme}": float(worst[i]),
                    f"path_independent_error_{scheme}": float(rms[i]),
                }
            )
        chosen_max, chosen_rms = errors_by_scheme[selected_scheme]
        row["path_dependent_error"] = float(chosen_max[i])
        row["path_independent_error"] = float(chosen_rms[i])
        rows.append(row)
    return pd.DataFrame.from_records(rows)
def calculate_cycle_closure(
    edges: Iterable[CycleClosureEdge | Sequence[object]],
    *,
    reference: str | None = None,
    reference_free_energy: float = 0.0,
    reference_weight: float = 1e6,
    require_cycles: bool | None = None,
    **_compat_kwargs,
) -> CycleClosureResult:
    """Run SFC/WSFC correction on an RBFE graph.

    ``require_cycles`` and extra keyword arguments are accepted for
    compatibility with the previous WCC implementation and are ignored. SFC
    does not enumerate cycles and can operate on any connected RBFE graph.

    Parameters
    ----------
    edges
        ``CycleClosureEdge`` objects or ``(label_a, label_b, ddg,
        *uncertainties)`` sequences. Every edge must carry the same number of
        uncertainties.
    reference
        Ligand whose free energy is pinned. Defaults to the first label
        encountered in ``edges``.
    reference_free_energy
        Value assigned to the reference ligand. Must be finite.
    reference_weight
        Weight of the reference-pinning row in the least-squares system.
        Must be finite and > 0.

    Returns
    -------
    CycleClosureResult
        Always contains the unweighted ``"sfc"`` scheme, plus one
        ``"wsfc<k>"`` scheme per uncertainty column. The unsuffixed error
        columns come from the last scheme (``"sfc"`` when no uncertainties
        are supplied).

    Raises
    ------
    ValueError
        For invalid edges, an unknown reference, a disconnected graph,
        non-positive or non-finite weights, or a non-finite
        ``reference_free_energy``.
    """
    # Reject NaN/inf up front instead of letting them propagate through the
    # solve and the final reference shift into every free energy.
    reference_free_energy = float(reference_free_energy)
    if not math.isfinite(reference_free_energy):
        raise ValueError("SFC reference_free_energy must be finite.")

    coerced_edges = _coerce_edges(edges)
    labels = _ordered_labels(coerced_edges, reference)
    # _ordered_labels places the (stripped) reference, or the first-seen label,
    # at index 0.
    reference = labels[0]
    a_matrix, b_vector, uncertainty_matrix = _design_matrix(coerced_edges, labels)
    _validate_connected_system(a_matrix, len(labels))

    solve_kwargs = {
        "reference_index": 0,
        "reference_free_energy": reference_free_energy,
        "reference_weight": reference_weight,
    }
    scheme_vectors: dict[str, np.ndarray] = {
        "sfc": _solve_state_function(a_matrix, b_vector, weights=None, **solve_kwargs)
    }
    for col_idx in range(uncertainty_matrix.shape[1]):
        scheme_vectors[f"wsfc{col_idx + 1}"] = _solve_state_function(
            a_matrix,
            b_vector,
            weights=_uncertainty_weights(uncertainty_matrix[:, col_idx]),
            **solve_kwargs,
        )

    schemes = tuple(scheme_vectors)
    # The last scheme (last WSFC column, or SFC if there are no uncertainties)
    # fills the unsuffixed error columns.
    selected_scheme = schemes[-1]
    return CycleClosureResult(
        reference=reference,
        reference_free_energy=reference_free_energy,
        node_results=_node_dataframe(coerced_edges, labels, scheme_vectors, selected_scheme),
        edge_results=_edge_dataframe(coerced_edges, labels, scheme_vectors, selected_scheme),
        cycles=(),
        iterations=tuple(1 for _ in schemes),
        converged=tuple(True for _ in schemes),
        method="sfc",
        schemes=schemes,
    )
def _first_existing_column(df: pd.DataFrame, candidates: Sequence[str]) -> str | None:
for column in candidates:
if column in df.columns:
return column
return None
def cycle_closure_from_dataframe(
    df: pd.DataFrame,
    *,
    label_a_col: str = "labelA",
    label_b_col: str = "labelB",
    ddg_col: str | None = None,
    uncertainty_cols: Sequence[str] | None = None,
    reference: str | None = None,
    reference_free_energy: float = 0.0,
    **kwargs,
) -> CycleClosureResult:
    """Build SFC input from a dataframe and run the correction.

    Parameters
    ----------
    df
        One row per directed RBFE edge.
    label_a_col, label_b_col
        Columns holding the ligand labels of each edge.
    ddg_col
        Column with the relative free energy. When ``None``, the first
        present column among ``calc_DDG``, ``DDG (kcal/mol)``, ``DDG``,
        ``ddG``, ``ddg``, ``dG`` and ``total_dG`` is used.
    uncertainty_cols
        Standard-error column(s). Each one yields a WSFC scheme. A single
        column name may also be given as a plain string. When ``None``, the
        first present column among ``calc_dDDG``,
        ``uncertainty (kcal/mol)``, ``uncertainty``, ``dDDG``, ``ddG_error``,
        ``total_se`` and ``std`` is used, if any.
    reference, reference_free_energy, **kwargs
        Forwarded to ``calculate_cycle_closure``.

    Raises
    ------
    ValueError
        If the ddG column cannot be inferred or a requested column is
        missing, plus anything raised by ``calculate_cycle_closure``.
    """
    if ddg_col is None:
        ddg_col = _first_existing_column(
            df,
            ("calc_DDG", "DDG (kcal/mol)", "DDG", "ddG", "ddg", "dG", "total_dG"),
        )
        if ddg_col is None:
            raise ValueError("Could not infer the SFC ddG column.")

    if uncertainty_cols is None:
        uncertainty_col = _first_existing_column(
            df,
            (
                "calc_dDDG",
                "uncertainty (kcal/mol)",
                "uncertainty",
                "dDDG",
                "ddG_error",
                "total_se",
                "std",
            ),
        )
        uncertainty_cols = [uncertainty_col] if uncertainty_col is not None else []
    elif isinstance(uncertainty_cols, str):
        # A bare string would otherwise be unpacked into single characters.
        uncertainty_cols = [uncertainty_cols]
    else:
        # Materialize once: a one-shot iterator would be exhausted by the
        # column check below before the edges are built.
        uncertainty_cols = list(uncertainty_cols)

    required = {label_a_col, label_b_col, ddg_col, *uncertainty_cols}
    missing = required - set(df.columns)
    if missing:
        raise ValueError(f"Missing SFC dataframe columns: {sorted(missing, key=str)}")

    # Read whole columns rather than using ``df.iterrows()``. iterrows yields
    # one Series per row and does not preserve dtypes, so all-numeric rows
    # were upcast to float and integer ligand labels such as 1 became "1.0".
    uncertainty_values = [df[column].tolist() for column in uncertainty_cols]
    edges = [
        CycleClosureEdge(
            label_a=str(label_a),
            label_b=str(label_b),
            ddg=float(ddg),
            uncertainties=tuple(float(value) for value in uncertainties),
        )
        for label_a, label_b, ddg, *uncertainties in zip(
            df[label_a_col].tolist(),
            df[label_b_col].tolist(),
            df[ddg_col].tolist(),
            *uncertainty_values,
        )
    ]
    return calculate_cycle_closure(
        edges,
        reference=reference,
        reference_free_energy=reference_free_energy,
        **kwargs,
    )
def read_cycle_closure_file(path: str | Path) -> pd.DataFrame:
    """Read a whitespace-delimited SFC input file.

    The first three columns are named ``labelA``, ``labelB``, and ``ddG``.
    Additional columns are treated as standard-error columns named ``std1``,
    ``std2``, etc. Text after ``#`` is treated as a comment.

    The two label columns are read as strings, so ligand names such as
    ``007`` or ``1e3`` are kept verbatim instead of being parsed as numbers.

    Raises
    ------
    ValueError
        If the file has fewer than three columns. An empty file raises
        pandas' ``EmptyDataError``, which is a ``ValueError`` subclass.
    """
    input_path = Path(path)
    df = pd.read_csv(
        input_path,
        sep=r"\s+",
        header=None,
        comment="#",
        # Without this, numeric-looking ligand labels are converted to int or
        # float (e.g. "007" -> 7, "1e3" -> 1000.0).
        dtype={0: str, 1: str},
    )
    if df.shape[1] < 3:
        raise ValueError("SFC input files need at least three columns.")
    columns = ["labelA", "labelB", "ddG"]
    columns.extend(f"std{idx}" for idx in range(1, df.shape[1] - 2))
    df.columns = columns
    return df
def cycle_closure_from_file(
    path: str | Path,
    *,
    reference: str | None = None,
    reference_free_energy: float = 0.0,
    **kwargs,
) -> CycleClosureResult:
    """Read an SFC-style input file and run state-function correction.

    The file is parsed by ``read_cycle_closure_file``. Its ``ddG`` column
    supplies the edge ddG, and every ``std<k>`` column is used as an
    uncertainty column. Thus ``k`` standard-error columns yield the WSFC
    schemes ``wsfc1`` ... ``wsfc<k>`` in addition to ``sfc``.

    Remaining keyword arguments are forwarded to
    ``cycle_closure_from_dataframe``. They must not include ``ddg_col`` or
    ``uncertainty_cols``, which are set here.
    """
    df = read_cycle_closure_file(path)
    uncertainty_cols = [column for column in df.columns if column.startswith("std")]
    return cycle_closure_from_dataframe(
        df,
        ddg_col="ddG",
        uncertainty_cols=uncertainty_cols,
        reference=reference,
        reference_free_energy=reference_free_energy,
        **kwargs,
    )
# ``state_function_correction`` spellings of the public entry points. Each is
# a plain alias bound to the same function object as its ``cycle_closure``
# counterpart.
read_state_function_correction_file = read_cycle_closure_file
calculate_state_function_correction = calculate_cycle_closure
state_function_correction_from_dataframe = cycle_closure_from_dataframe
state_function_correction_from_file = cycle_closure_from_file