# Source code for batter.analysis.analysis

from __future__ import annotations

import os
import re
import glob
import json
import math
import pickle
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Dict, Tuple, Optional

import numpy as np
import pandas as pd
import logging
from loguru import logger
from joblib import Parallel, delayed

from pymbar.timeseries import detect_equilibration, subsample_correlated_data
import MDAnalysis as mda
from MDAnalysis.lib.distances import calc_bonds, calc_angles, calc_dihedrals

from alchemlyb.estimators import MBAR
from alchemlyb.parsing.amber import extract_u_nk
from alchemlyb.convergence import forward_backward_convergence, block_average
from alchemlyb.visualisation import (
    plot_convergence,
    plot_mbar_overlap_matrix,
    plot_block_average,
)

import seaborn as sns
import matplotlib.pyplot as plt

from batter.utils import run_with_log, cpptraj
from batter.analysis.utils import exclude_outliers


COMPONENTS_DICT = {
    # analysed with RESTMBARAnalysis in analyze_lig_task
    "rest": ["a", "l", "t", "c", "r", "m", "n"],
    # analysed with MBARAnalysis in analyze_lig_task; 'm' appears in both lists
    # and analyze_lig_task checks "dd" first, so 'm' goes through MBARAnalysis
    "dd": ["e", "v", "f", "w", "x", "o", "s", "z", "y", "m"],
}

# sign that determines direction of contribution to total FE
# NOTE(review): components 'f', 'w', 's', 'a', 'l', 't', 'c', 'r' have no entry
# here, and analyze_lig_task indexes this dict directly with the component name
# (KeyError) -- confirm whether those components are still in use.
COMPONENT_DIRECTION_DICT = {
    "m": -1,
    "n": +1,
    "e": -1,
    "v": -1,
    "o": -1,
    "z": -1,
    "y": +1,
    "x": +1,
    "Boresch": -1,
}


class SilenceAlchemlybOnly:
    """Context manager that temporarily mutes alchemlyb log output.

    Entering disables loguru output for the ``alchemlyb``,
    ``alchemlyb.parsing`` and ``alchemlyb.parsing.amber`` loggers and raises
    the stdlib :mod:`logging` loggers of the same names to ``WARNING``.
    Exiting re-enables loguru output for those names and restores the stdlib
    levels captured on entry. ``__exit__`` returns ``None``, so exceptions
    raised inside the ``with`` body propagate normally.
    """

    _LOGGER_NAMES = ("alchemlyb", "alchemlyb.parsing", "alchemlyb.parsing.amber")

    def __enter__(self):
        for name in self._LOGGER_NAMES:
            logger.disable(name)
        self._py_loggers = [logging.getLogger(name) for name in self._LOGGER_NAMES]
        # remember original levels so __exit__ can put them back
        self._prev_levels = [std_logger.level for std_logger in self._py_loggers]
        for std_logger in self._py_loggers:
            std_logger.setLevel(logging.WARNING)

    def __exit__(self, *args):
        for name in self._LOGGER_NAMES:
            logger.enable(name)
        # guard against __exit__ being reached without a completed __enter__
        if not hasattr(self, "_py_loggers"):
            return
        for std_logger, level in zip(self._py_loggers, self._prev_levels):
            std_logger.setLevel(level)
def _is_incomplete_amber_out_error(exc: ValueError) -> bool: msg = str(exc) incomplete_markers = ( 'no "CONTROL DATA" section found', 'no "RESULTS" section found', 'no "ATOMIC" section found', "No starting simulation time", "no free energy section found", 'no valid "temp0" record found', "does not contain any data", ) return any(marker in msg for marker in incomplete_markers)
class FEAnalysisBase(ABC):
    """
    Minimal interface shared across component analysis routines.

    Subclasses implement :meth:`run_analysis` and :meth:`plot_convergence`
    and fill :attr:`results`.

    Attributes
    ----------
    results : dict
        Storage for the scalar FE, uncertainty, convergence tables, and FE
        time series generated by subclasses.
    """

    def __init__(self):
        self.results = dict(
            fe=None,  # scalar in 'energy_unit'
            fe_error=None,  # scalar (same unit)
            convergence={},  # dict of dataframes/arrays
            fe_timeseries=None,  # Nx2 array: [FE, FE_err] across progress fractions
        )

    @abstractmethod
    def run_analysis(self):
        ...

    @abstractmethod
    def plot_convergence(self, ax=None, **kwargs):
        ...

    @property
    def fe(self):
        """Scalar free energy stored in ``results["fe"]``."""
        return self.results["fe"]

    @property
    def fe_error(self):
        """Uncertainty stored in ``results["fe_error"]``."""
        return self.results["fe_error"]

    @property
    def convergence(self):
        """Dictionary of convergence tables stored in ``results["convergence"]``."""
        return self.results["convergence"]

    @property
    def fe_timeseries(self):
        """FE time series stored in ``results["fe_timeseries"]``."""
        return self.results["fe_timeseries"]

    def dump(self, filename="results.json"):
        """Store results to JSON (omit heavy convergence tables).

        ``fe``/``fe_error`` are written as plain floats (or ``null``), and a
        NumPy ``fe_timeseries`` is converted to nested lists.
        """

        def _as_float(value):
            return None if value is None else float(value)

        timeseries = self.fe_timeseries
        if isinstance(timeseries, np.ndarray):
            timeseries = timeseries.tolist()
        payload = {
            "fe": _as_float(self.fe),
            "fe_error": _as_float(self.fe_error),
            "fe_timeseries": timeseries,
        }
        with open(filename, "w") as fh:
            json.dump(payload, fh, indent=2)
class MBARAnalysis(FEAnalysisBase):
    """
    Post-process a single component with :class:`alchemlyb.estimators.MBAR`.

    Parameters
    ----------
    lig_folder : str
        Absolute path to the ligand work directory.
    component : str
        Component identifier (e.g., ``"e"`` or ``"m"``).
    windows : list[int]
        Lambda windows present for the component.
    temperature : float
        Simulation temperature in Kelvin.
    energy_unit : {"kcal/mol", "kJ/mol", "kT"}, optional
        Output energy unit. Internally every value is accumulated in units of
        ``kT`` and converted before publishing the results. This applies to
        ``fe``, ``fe_error``, ``fe_timeseries`` and the block time series.
    analysis_start_step : int, optional
        Discard frames with step <= this value before analysis.
    detect_equil : bool, optional
        When ``True`` the equilibration time of each window is detected and
        the pre-equilibrated portion is discarded.
    n_bootstraps : int, optional
        Number of bootstrap samples handed to :class:`MBAR`.
    n_jobs : int, optional
        Level of joblib parallelism when parsing windows.
    load : bool, optional
        When ``True`` reuse cached ``*_df_list.pickle`` files if available.
    dt : float
        MD time step in ps. Must be > 0; it converts ``analysis_start_step``
        into a time threshold.
    ntwx : int, optional
        Stored as ``self.ntwx`` (``0`` when ``None``).

    Raises
    ------
    ValueError
        If ``dt`` is not positive, the component folder does not exist, or
        ``energy_unit`` is unsupported.
    """

    def __init__(
        self,
        lig_folder: str,
        component: str,
        windows: List[int],
        temperature: float,
        energy_unit: str = "kcal/mol",
        analysis_start_step: int = 0,
        detect_equil: bool = True,
        n_bootstraps: int = 0,
        n_jobs: int = 6,
        load: bool = False,
        dt: float = 0.0,
        ntwx: Optional[int] = None,
    ):
        super().__init__()
        if dt is None or dt <= 0:
            raise ValueError("dt must be > 0 for analysis to convert steps to time.")

        self.lig_folder = lig_folder
        self.result_folder = f"{self.lig_folder}/Results"
        os.makedirs(self.result_folder, exist_ok=True)

        comp_folder = f"{self.lig_folder}/{component}"
        if not os.path.isdir(comp_folder):
            raise ValueError(f"Component folder not found: {comp_folder}")
        self.comp_folder = comp_folder
        self.component = component
        self.windows = windows
        self.temperature = float(temperature)

        if energy_unit not in ("kcal/mol", "kJ/mol", "kT"):
            raise ValueError("energy_unit must be 'kcal/mol', 'kJ/mol', or 'kT'")
        self.energy_unit = energy_unit
        # kT in kcal/mol (Boltzmann constant in kcal/mol/K times T)
        self.kT = 0.0019872041 * self.temperature

        self.analysis_start_step = max(0, int(analysis_start_step))
        self.dt = float(dt)  # validated > 0 above
        self.ntwx = ntwx if ntwx is not None else 0
        logger.debug(
            f"[MBARAnalysis:init] comp={component}, windows={windows}, "
            f"analysis_start_step={self.analysis_start_step}, dt={self.dt}, ntwx={self.ntwx}"
        )

        self.detect_equil = bool(detect_equil)
        self.n_bootstraps = int(n_bootstraps)
        self.n_jobs = int(n_jobs)
        self.load = bool(load)
        self._data_initialized = False

    # public props used after get_mbar_data()
    @property
    def u_df(self) -> pd.DataFrame:
        """Concatenated per-window potentials (set by :meth:`get_mbar_data`)."""
        return self._u_df

    @property
    def data_list(self) -> List[pd.DataFrame]:
        """Per-window potential frames (set by :meth:`get_mbar_data`)."""
        return self._data_list

    def _to_energy_unit(self, value_kT):
        """Convert a scalar or array from ``kT`` into ``self.energy_unit``."""
        if self.energy_unit == "kcal/mol":
            return value_kT * self.kT
        if self.energy_unit == "kJ/mol":
            return value_kT * self.kT * 4.184
        return value_kT  # already kT

    def get_mbar_data(self) -> None:
        """
        Parse and cache the not reduced potentials for all lambda windows.

        Notes
        -----
        Each window's frame has its own reference column (column ``i`` for
        window ``i``) subtracted. The concatenated dataframe is stored in
        ``self._u_df``, and the list of per-window frames is available via
        :attr:`data_list`.
        """
        pkl = f"{self.result_folder}/{self.component}_df_list.pickle"
        if self.load and os.path.exists(pkl):
            # NOTE: pickle.load executes arbitrary code; only load caches
            # written by this pipeline (see _get_data_list).
            with open(pkl, "rb") as f:
                df_list = pickle.load(f)
            logger.debug(f"[MBARAnalysis] Loaded cached data from {pkl}")
        else:
            logger.debug(f"[MBARAnalysis] Parsing data for component {self.component}")
            df_list = self._get_data_list()

        # get reduced df_list by subtracting the reference U from the lambda simulation
        self._data_list = [
            df.subtract(df.iloc[:, i], axis=0) for i, df in enumerate(df_list)
        ]
        self._u_df = pd.concat(self._data_list)
        self.timeseries = [
            df.index.get_level_values("time").values for df in self._data_list
        ]
        self._data_initialized = True

    def run_analysis(self) -> None:
        """
        Fit MBAR and publish the FE, error and convergence diagnostics.

        Side effects: writes ``<component>_mbar_delta_f.png``,
        ``<component>_results.pickle`` and ``<component>_results.json`` into
        ``result_folder``.
        """
        if not self._data_initialized:
            self.get_mbar_data()

        mbar = MBAR(n_bootstraps=self.n_bootstraps)
        mbar.fit(self.u_df)
        self._mbar = mbar

        # accumulate error in kT space (quadrature over adjacent windows), then convert
        err_kT = np.sqrt(
            sum(
                mbar.d_delta_f_.iloc[i, i + 1] ** 2
                for i in range(len(mbar.d_delta_f_) - 1)
            )
        )
        delta_kT = mbar.delta_f_.iloc[0, -1]
        self.results["fe"] = float(self._to_energy_unit(delta_kT))
        self.results["fe_error"] = float(self._to_energy_unit(err_kT))

        # debug plot: FE difference (kT) from window 0 to every window, with error bars
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.errorbar(
            range(len(mbar.delta_f_.columns)),
            mbar.delta_f_.iloc[0, :],
            yerr=mbar.d_delta_f_.iloc[0, :],
        )
        ax.set_xlabel("Lambda Window Index")
        ax.set_ylabel("Free Energy Difference (kT)")
        plt.title(f"MBAR Free Energy Differences for Component {self.component}")
        plt.tight_layout()
        plt.savefig(f"{self.result_folder}/{self.component}_mbar_delta_f.png", dpi=200)
        plt.close(fig)

        # Convergence summaries (raw tables stay in kT; alchemlyb plots convert)
        with SilenceAlchemlybOnly():
            tc = forward_backward_convergence(
                self.data_list, "MBAR", error_tol=100, method="default"
            )
        self.results["convergence"]["time_convergence"] = tc

        # fe_timeseries: N x 2 array (forward value, forward stderr) in energy_unit
        self.results["fe_timeseries"] = np.column_stack(
            [
                self._to_energy_unit(tc.Forward.values),
                self._to_energy_unit(tc.Forward_Error.values),
            ]
        )

        # block average (10 blocks)
        ba = block_average(self.data_list, estimator="MBAR", num=10, method="default")
        self.results["convergence"]["block_convergence"] = ba
        block_FE = self._to_energy_unit(ba.FE.values)
        block_FE_err = self._to_energy_unit(ba.FE_Error.values)
        # pack in a simple dataframe with sequential fraction labels
        self.results["convergence"]["block_timeseries"] = pd.DataFrame(
            {"FE": block_FE, "FE_Error": block_FE_err},
            index=np.linspace(0.1, 1.0, len(block_FE)),
        )

        self.results["convergence"]["overlap_matrix"] = mbar.overlap_matrix
        self.results["convergence"]["mbar"] = mbar

        # persist
        with open(f"{self.result_folder}/{self.component}_results.pickle", "wb") as f:
            pickle.dump(self.results, f)
        self.dump(f"{self.result_folder}/{self.component}_results.json")

    @staticmethod
    def _extract_all_for_window(
        win_i: int,
        comp_folder: str,
        component: str,
        temperature: float,
        analysis_start_step: int,
        truncate: bool,
        dt: float = 0.004,
        log_level: int = logging.WARNING,
    ) -> pd.DataFrame:
        """
        Extract potentials for a single window from its Amber ``.out`` files.

        Parameters
        ----------
        win_i : int
            Window index within ``self.windows``.
        comp_folder : str
            Component directory.
        component : str
            Component identifier.
        temperature : float
            Simulation temperature in Kelvin.
        analysis_start_step : int
            Discard frames with step <= this value.
        truncate : bool
            If ``True``, detect equilibration and discard early frames.
        dt : float
            Time step (ps) used to convert analysis_start_step into ps.
        log_level : int
            Accepted for interface compatibility; not used.

        Returns
        -------
        pandas.DataFrame
            Potentials for window ``win_i``. The reference-column subtraction
            happens later, in :meth:`get_mbar_data`.

        Raises
        ------
        FileNotFoundError
            If no (parseable) ``mdin-*.out`` / ``md-*.out`` files exist.
        ValueError
            If ``analysis_start_step`` lies beyond the last recorded time.
        """
        logger.debug(f"[MBARAnalysis] Extracting window {component}{win_i:02d}")
        win_dir = f"{comp_folder}/{component}{win_i:02d}"
        patterns = [f"{win_dir}/mdin-*.out", f"{win_dir}/md-*.out"]
        mdouts: List[str] = []
        for pat in patterns:
            mdouts.extend(glob.glob(pat))

        if mdouts:

            def _idx(path: str) -> int:
                # numeric suffix of mdin-NN.out / md-NN.out (0 if absent)
                base = os.path.basename(path)
                m = re.search(r"(?:mdin-|md-)(\d+)\.out$", base)
                return int(m.group(1)) if m else 0

            mdouts = sorted(set(mdouts), key=_idx)
        if not mdouts:
            raise FileNotFoundError(f"No Amber out files in {win_dir}")
        logger.debug(
            f"[MBARAnalysis] {component}{win_i:02d} using {len(mdouts)} mdout files"
        )

        dfs = []
        with SilenceAlchemlybOnly():
            for fn in mdouts:
                try:
                    df_part = extract_u_nk(
                        fn, T=temperature, reduced=False, raise_error=False
                    )
                except ValueError as exc:
                    if _is_incomplete_amber_out_error(exc):
                        logger.warning(
                            f"[MBARAnalysis] Skipping incomplete Amber out file: {fn} ({exc})"
                        )
                        continue
                    raise
                dfs.append(df_part)
        if not dfs:
            raise FileNotFoundError(f"No parseable Amber out files in {win_dir}")
        df = pd.concat(dfs)

        # Drop early frames if requested (convert steps -> ps)
        if analysis_start_step > 0:
            threshold = analysis_start_step * dt
            if threshold > df.index.get_level_values(0).max():
                raise ValueError(
                    f"[MBARAnalysis] {component}{win_i:02d} WARNING: "
                    f"analysis_start_step={analysis_start_step} exceeds max time "
                    f"in data ({df.index.get_level_values(0).max()/dt:.0f} steps)! "
                )
            logger.debug(
                f"[MBARAnalysis] {component}{win_i:02d} dropping frames <= {threshold} ps "
            )
            df = df[df.index.get_level_values(0) > threshold]
            # reduce index to start from zero time
            df.index = df.index.map(lambda t: (t[0] - threshold, *t[1:]))

        # Mixed precision spikes guard
        df = exclude_outliers(df, iclam=win_i)

        # detect_equilibration on the reference column of this window
        if truncate:
            with SilenceAlchemlybOnly():
                t0, g, Neff_max = detect_equilibration(df.iloc[:, win_i], nskip=10)
                df = df.iloc[t0:, :]
                indices = subsample_correlated_data(df.iloc[:, win_i], g=g)
                df = df.iloc[indices, :]
            logger.debug(
                f"[MBARAnalysis] {component}{win_i:02d} detected equilibration at after row {t0}"
            )

        # reference subtraction is done later in get_mbar_data()
        logger.debug(
            f"[MBARAnalysis] {component}{win_i:02d} final data shape: {df.shape}"
        )
        if df.empty:
            # reuse the full df
            df = pd.concat(dfs)
            logger.warning(
                f"[MBARAnalysis] {component}{win_i:02d} WARNING: returning untruncated data!"
            )
        return df

    def _get_data_list(self) -> List[pd.DataFrame]:
        """Extract all windows in parallel and cache them (pickle + attrs JSON)."""
        df_list = Parallel(n_jobs=self.n_jobs)(
            delayed(self._extract_all_for_window)(
                win_i=win_i,
                comp_folder=self.comp_folder,
                component=self.component,
                temperature=self.temperature,
                analysis_start_step=self.analysis_start_step,
                truncate=self.detect_equil,
                dt=self.dt,
            )
            for win_i in range(len(self.windows))
        )
        for df in df_list:
            df.attrs["temperature"] = self.temperature
            df.attrs["energy_unit"] = "kT"

        # save df_list info
        df_list_attrs = {
            "component": self.component,
            "temperature": self.temperature,
            "energy_unit": "kT",
            "analysis_start_step": self.analysis_start_step,
            "detect_equil": self.detect_equil,
            "dt": self.dt,
            "win_sizes": {i: len(df) for i, df in enumerate(df_list)},
        }
        with open(
            f"{self.result_folder}/{self.component}_df_list_attrs.json", "w"
        ) as f:
            json.dump(df_list_attrs, f, indent=2)
        with open(f"{self.result_folder}/{self.component}_df_list.pickle", "wb") as f:
            pickle.dump(df_list, f)
        return df_list

    def plot_time_convergence(self, ax=None, **kwargs):
        """Plot the forward/backward convergence table via alchemlyb."""
        df = self.results["convergence"]["time_convergence"]
        return plot_convergence(df, ax=ax, **kwargs)

    def plot_overlap_matrix(self, ax=None, **kwargs):
        """Plot the MBAR overlap matrix via alchemlyb."""
        return plot_mbar_overlap_matrix(
            self.results["convergence"]["overlap_matrix"], ax=ax, **kwargs
        )

    def plot_block_convergence(self, ax=None, **kwargs):
        """Plot the block-average table via alchemlyb."""
        df = self.results["convergence"]["block_convergence"]
        return plot_block_average(df, ax=ax, **kwargs)

    def plot_convergence(
        self, save_path: Optional[str] = None, title: Optional[str] = None
    ):
        """Draw time convergence, overlap matrix and block convergence in one figure.

        Saves to ``save_path`` (and closes the figure) when given, otherwise
        calls ``plt.show()``.
        """
        fig, axes = plt.subplot_mosaic(
            [["A", "A", "B"], ["C", "C", "C"]], figsize=(24, 14)
        )
        self.plot_time_convergence(
            ax=axes["A"], units=self.energy_unit, final_error=0.6
        )
        self.plot_overlap_matrix(ax=axes["B"])
        self.plot_block_convergence(
            ax=axes["C"], units=self.energy_unit, final_error=0.6
        )
        axes["A"].set_title("Time Convergence", fontsize=10)
        axes["B"].set_title("Overlap Matrix", fontsize=10)
        axes["C"].set_title("Block Convergence", fontsize=10)
        plt.tight_layout()
        if title:
            plt.suptitle(title, y=1.02)
        if save_path:
            fig.savefig(save_path, dpi=200)
            plt.close(fig)
        else:
            plt.show()
class RESTMBARAnalysis(MBARAnalysis):
    """
    MBAR analysis variant for restraint components that require cpptraj traces.

    Restraint targets and force constants are read from each window's
    ``disang.rest``. The restrained coordinates are extracted with cpptraj
    (:func:`generate_results_rest`), and every frame's restraint energy
    ``sum(k * (x - x0)**2) / kT`` is evaluated with every window's parameters.
    """

    # Last-token tags in ``disang.rest`` that count as restraints for each
    # component; lines with other tags (or other components) are ignored.
    _REST_TAGS = {
        "t": ("#Lig_TR",),
        "l": ("#Lig_C", "#Lig_D"),
        "c": ("#Lig_C", "#Lig_D"),
        "a": ("#Rec_C", "#Rec_D"),
        "r": ("#Rec_C", "#Rec_D"),
        "m": ("#Rec_C", "#Rec_D", "#Lig_TR", "#Lig_C", "#Lig_D"),
        "n": ("#Rec_C", "#Rec_D", "#Lig_TR", "#Lig_C", "#Lig_D"),
    }

    @staticmethod
    def _parse_restraint_line(cols: List[str]) -> Tuple[str, float, float]:
        """
        Parse one whitespace-split ``disang.rest`` restraint line.

        Returns ``(type, target, force_constant)``. ``type`` is ``"d"``,
        ``"a"`` or ``"t"`` for restraints on 2, 3 or 4 atoms, where the atom
        count is taken from the comma-separated list in token 2. The target
        comes from token 6 and the force constant from token 12 (commas
        stripped). For 3/4-atom restraints the constant is multiplied by
        ``(pi/180)**2``; presumably this converts a per-rad^2 constant for use
        with the degree-valued cpptraj output (TODO confirm).

        Raises
        ------
        ValueError
            If the restraint does not involve 2, 3 or 4 atoms.
        """
        target = float(cols[6].replace(",", ""))
        natoms = len(cols[2].split(",")) - 1
        if natoms == 2:
            return "d", target, float(cols[12].replace(",", ""))
        if natoms == 3:
            return "a", target, float(cols[12].replace(",", "")) * (np.pi / 180.0) ** 2
        if natoms == 4:
            return "t", target, float(cols[12].replace(",", "")) * (np.pi / 180.0) ** 2
        raise ValueError("Unknown restraint natoms")

    def _extract_restraints_from_windows(self):
        """
        Collect restraint parameters for every window.

        Returns
        -------
        rfc : numpy.ndarray
            ``(num_win, num_rest)`` force constants.
        req : numpy.ndarray
            ``(num_win, num_rest)`` target values.
        rty : list[str]
            Restraint types (``"d"``, ``"a"``, ``"t"``). The values from the
            last window parsed are kept.
        num_rest : int
            Number of tagged restraints found in window 00.
        """
        num_win = len(self.windows)
        component = self.component
        tags = self._REST_TAGS.get(component, ())

        def _tagged_cols(path: str):
            # yields the split columns of every line whose last token is a tag
            with open(path, "r") as fh:
                lines = fh.readlines()
            for line in lines:
                cols = line.split()
                if cols and cols[-1] in tags:
                    yield cols

        disang_file = f"{self.comp_folder}/{component}00/disang.rest"
        num_rest = sum(1 for _ in _tagged_cols(disang_file))

        rty = ["d"] * num_rest
        rfc = np.zeros([num_win, num_rest], dtype=float)
        req = np.zeros([num_win, num_rest], dtype=float)
        for win in range(num_win):
            dpath = f"{self.comp_folder}/{component}{win:02d}/disang.rest"
            for r, cols in enumerate(_tagged_cols(dpath)):
                kind, target, force = self._parse_restraint_line(cols)
                req[win, r] = target
                rty[r] = kind
                rfc[win, r] = force
        return rfc, req, rty, num_rest

    def _get_data_list(self) -> List[pd.DataFrame]:
        """Compute every window's restraint potentials in parallel and cache them."""
        rfc, req, rty, num_rest = self._extract_restraints_from_windows()
        df_list = Parallel(n_jobs=self.n_jobs)(
            delayed(self._extract_all_for_window)(
                win_i=win_i,
                comp_folder=self.comp_folder,
                component=self.component,
                temperature=self.temperature,
                analysis_start_step=self.analysis_start_step,
                truncate=self.detect_equil,
                rfc=rfc,
                req=req,
                rty=rty,
                num_rest=num_rest,
                num_win=len(self.windows),
                dt=self.dt,
                ntwx=self.ntwx,
            )
            for win_i in range(len(self.windows))
        )
        # set attrs BEFORE pickling so load=True caches carry temperature/unit
        for df in df_list:
            df.attrs["temperature"] = self.temperature
            df.attrs["energy_unit"] = "kT"
        with open(f"{self.result_folder}/{self.component}_df_list.pickle", "wb") as f:
            pickle.dump(df_list, f)
        return df_list

    @staticmethod
    def _extract_all_for_window(
        win_i: int,
        comp_folder: str,
        component: str,
        temperature: float,
        analysis_start_step: int,
        rfc: np.ndarray,
        req: np.ndarray,
        rty: List[str],
        num_rest: int,
        num_win: int,
        truncate: bool,
        dt: float,
        ntwx: int,
    ) -> pd.DataFrame:
        """
        Compute reduced potentials for REST components from restraint traces.

        ``analysis_start_step`` is converted to a frame count as
        ``ceil(analysis_start_step / ntwx)``. When ``ntwx <= 0`` it is used
        as a frame count directly and a warning is logged. ``dt`` is accepted
        but unused.

        Returns
        -------
        pandas.DataFrame
            Rows indexed by a ``(time, lambdas)`` MultiIndex (frame counter,
            ``float(win_i)``). There is one column per window holding that
            window's restraint energy in ``kT``.

        Raises
        ------
        FileNotFoundError
            If the window has no NetCDF trajectories.
        ValueError
            If ``analysis_start_step`` would discard every frame.
        """
        kT = 0.0019872041 * temperature
        win_dir = Path(f"{comp_folder}/{component}{win_i:02d}")
        cwd0 = Path.cwd()
        try:
            # cpptraj template/outputs (restraints.in, restraints.dat) are
            # relative to the window directory.
            # NOTE(review): os.chdir is process-wide; this assumes joblib runs
            # windows in separate processes rather than threads.
            os.chdir(win_dir)

            # enumerate mdin-XX.nc (or fallback md01.nc..)
            nc_list: List[str] = []
            nsims = len(glob.glob("mdin-*.nc"))
            for i in range(nsims):
                fn = f"mdin-{i:02d}.nc"
                if os.path.exists(fn):
                    nc_list.append(fn)
            if not nc_list:
                fallback = ["md01.nc", "md02.nc", "md03.nc", "md04.nc"]
                nc_list = [f for f in fallback if os.path.exists(f)]
            if not nc_list:
                raise FileNotFoundError("No NetCDF trajs for REST window")
            logger.debug(
                f"[RESTMBAR] {component}{win_i:02d} using {len(nc_list)} nc files"
            )

            # generate restraint traces via cpptraj; fall back to vacuum topology
            try:
                generate_results_rest(nc_list, component, blocks=5, top="full")
            except Exception as exc:
                logger.debug(
                    f"[RESTMBAR] {component}{win_i:02d} cpptraj with full topology "
                    f"failed ({exc}); retrying with vac"
                )
                generate_results_rest(nc_list, component, blocks=5, top="vac")

            with open("restraints.dat", "r") as fin:
                # skip comment/header lines and blank lines
                lines = [ln for ln in fin if ln.strip() and ln[0] not in "#@"]

            val = np.zeros((len(lines), num_rest), dtype=float)
            for n, line in enumerate(lines):
                cols = line.split()
                for r in range(num_rest):
                    value = float(cols[r + 1])
                    if rty[r] == "t":
                        # shift dihedral by +-360 so it lies within 180 of the target
                        if value < req[win_i, r] - 180.0:
                            value += 360.0
                        elif value > req[win_i, r] + 180.0:
                            value -= 360.0
                    val[n, r] = value

            # reduced potential at this window
            if component != "u":
                if rfc[win_i, 0] == 0:
                    # guard tiny zeros
                    tmp = np.ones((num_rest,), np.float64) * 1e-3
                    u = np.sum(tmp * (val - req[win_i]) ** 2 / kT, axis=1)
                else:
                    u = np.sum(rfc[win_i] * (val - req[win_i]) ** 2 / kT, axis=1)
            else:
                u = (rfc[win_i, 0] * (val[:, 0] - req[win_i, 0]) ** 2) / kT

            # Drop early frames if requested (convert steps -> frame index)
            start_idx = max(0, int(analysis_start_step))
            if analysis_start_step > 0 and ntwx > 0:
                # frames recorded every ntwx steps
                start_idx = max(0, int(math.ceil(analysis_start_step / float(ntwx))))
            elif start_idx > 0:
                logger.warning(
                    f"[RESTMBAR] {component}{win_i:02d} ntwx={ntwx} is not positive; "
                    f"treating analysis_start_step={analysis_start_step} as a frame count"
                )
            if start_idx > 0:
                if start_idx >= len(u):
                    raise ValueError(
                        f"[RESTMBAR] {component}{win_i:02d} analysis_start_step="
                        f"{analysis_start_step} (ntwx={ntwx}) discards all {len(u)} frames"
                    )
                logger.debug(
                    f"[RESTMBAR] {component}{win_i:02d} dropping first {start_idx} frames "
                    f"(analysis_start_step={analysis_start_step}, ntwx={ntwx})"
                )
                u = u[start_idx:]
                val = val[start_idx:]

            if truncate:
                with SilenceAlchemlybOnly():
                    t0, _, _ = detect_equilibration(u, nskip=10)
                u = u[t0:]
                val = val[t0:]

            # energy of every retained frame under every window's restraint set
            Upot = np.zeros((num_win, len(u)), np.float64)
            for w in range(num_win):
                if component != "u":
                    Upot[w] = np.sum(rfc[w] * (val - req[w]) ** 2 / kT, axis=1)
                else:
                    Upot[w] = (rfc[w, 0] * (val[:, 0] - req[w, 0]) ** 2) / kT

            # Pack like alchemlyb (time,lambdas) MultiIndex
            win_i_list = np.arange(num_win, dtype=np.float64)
            mbar_time = np.arange(len(u), dtype=np.float64)
            clambda = float(win_i)
            mbar_df = pd.DataFrame(
                Upot,
                index=np.array(win_i_list, dtype=np.float64),
                columns=pd.MultiIndex.from_arrays(
                    [mbar_time, np.repeat(clambda, len(mbar_time))],
                    names=["time", "lambdas"],
                ),
            ).T
            return mbar_df
        finally:
            os.chdir(cwd0)
class BoreschAnalysis(FEAnalysisBase):
    """
    Analytical free energy for the ligand's six ``#Lig_TR`` restraints.

    The six restraints are the P1-L1 distance, the P2-P1-L1 and P1-L1-L2
    angles, and the P3-P2-P1-L1, P2-P1-L1-L2 and P1-L1-L2-L3 dihedrals. Their
    target values are read from a disang file, and the free energy is
    obtained by numerical integration in :meth:`fe_int`.
    """

    # ``r2=<number>`` target value in a disang line. Accepts an optional sign,
    # leading-dot decimals and exponent notation.
    _R2_RE = re.compile(r"\br2=\s*([+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)")

    def __init__(self, disangfile, k_r, k_a, temperature):
        """
        Initialize the Boresch analysis with the disang file and parameters.

        Parameters
        ----------
        disangfile : str
            The path to the disang file containing the anchor atoms.
        k_r : float
            The force constant for the translation restraint.
        k_a : float
            The force constant for the angle and dihedral restraints.
            They are the same (they don't have to be).
        temperature : float
            The temperature in Kelvin for the analysis.

        Raises
        ------
        ValueError
            If ``k_r`` or ``k_a`` is not positive.
        """
        super().__init__()
        self.disangfile = disangfile
        self.k_r = k_r
        self.k_a = k_a
        # explicit checks (not ``assert``) so validation survives ``python -O``
        if not self.k_r > 0.0:
            raise ValueError("k_r must be positive")
        if not self.k_a > 0.0:
            raise ValueError("k_a must be positive")
        self.temperature = temperature

    def run_analysis(self):
        """
        Run the analytical analysis for Boresch restraint.

        Reads the first six ``#Lig_TR`` lines of ``disangfile``, takes their
        ``r2=`` targets, and stores the result of :meth:`fe_int` in
        ``results["fe"]`` with ``results["fe_error"] = 0.0``.

        Raises
        ------
        ValueError
            If fewer than six ``#Lig_TR`` lines exist or a line lacks ``r2=``.
        """
        logger.debug("Running analytical analysis for Boresch restraint")

        def _extract_r2_val(line: str) -> float:
            m = self._R2_RE.search(line)
            if not m:
                raise ValueError(f"Couldn't find r2= value in line: {line}")
            return float(m.group(1))

        # Read disang file to get anchor atoms
        with open(self.disangfile, "r") as f_in:
            tr_lines = [line.rstrip() for line in f_in if "#Lig_TR" in line]
        if len(tr_lines) < 6:
            raise ValueError(
                f"Expected 6 '#Lig_TR' restraint lines in {self.disangfile}, "
                f"found {len(tr_lines)}"
            )

        r0 = _extract_r2_val(tr_lines[0])  # P1–L1 distance (target at r2)
        a1_0 = _extract_r2_val(tr_lines[1])  # P2–P1–L1 angle
        t1_0 = _extract_r2_val(tr_lines[2])  # P3–P2–P1–L1 dihedral
        a2_0 = _extract_r2_val(tr_lines[3])  # P1–L1–L2 angle
        t2_0 = _extract_r2_val(tr_lines[4])  # P2–P1–L1–L2 dihedral
        t3_0 = _extract_r2_val(tr_lines[5])  # P1–L1–L2–L3 dihedral

        fe_bd = self.fe_int(
            r0, a1_0, t1_0, a2_0, t2_0, t3_0, self.k_r, self.k_a, self.temperature
        )
        self.results["fe"] = fe_bd
        self.results["fe_error"] = 0.0
        logger.debug(f"Analytical release ligand TR: {fe_bd:.2f} kcal/mol")

    def plot_convergence(self, ax=None, **kwargs):
        """
        no convergence for analytical results
        """
        pass

    @staticmethod
    def fe_int(r1_0, a1_0, t1_0, a2_0, t2_0, t3_0, k_r, k_a, temperature):
        """
        Calculate the analytical free energy of boresch restraint. (from BAT.py)

        The Boltzmann factor of each restraint is integrated with the
        trapezoidal rule on fixed grids. Angle and dihedral targets are given
        in degrees (converted with ``np.radians``), and dihedral offsets are
        wrapped into ``(-pi, pi)``.

        Returns
        -------
        float
            ``R*T*ln(I / (8*pi^2 * 1660))`` with ``R`` in kcal/mol/K, where
            ``I`` is the product of the six integrals. The 1660 factor is
            presumably the standard-state volume in A^3 (TODO confirm).
        """
        R = 1.987204118e-3  # kcal/mol-K, a.k.a. boltzman constant
        beta = 1 / (temperature * R)
        # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid
        trapezoid = getattr(np, "trapezoid", None) or np.trapz

        r1lb, r1ub, r1st = [0.0, 100.0, 0.0001]
        a1lb, a1ub, a1st = [0.0, np.pi, 0.00005]
        t1lb, t1ub, t1st = [-np.pi, np.pi, 0.00005]
        a2lb, a2ub, a2st = [0.0, np.pi, 0.00005]
        t2lb, t2ub, t2st = [-np.pi, np.pi, 0.00005]
        t3lb, t3ub, t3st = [-np.pi, np.pi, 0.00005]

        def dih_per(lb, ub, st, t_0):
            # offsets from the target, wrapped the same way as the BAT.py loop
            delta = np.arange(lb, ub, st) - np.radians(t_0)
            delta = np.where(delta >= np.pi, delta - (2 * np.pi), delta)
            delta = np.where(delta <= -np.pi, delta + (2 * np.pi), delta)
            return delta

        def f_r1(val):
            return (val**2) * np.exp(-beta * k_r * (val - r1_0) ** 2)

        def f_angle(val, a_0):
            return np.sin(val) * np.exp(-beta * k_a * (val - np.radians(a_0)) ** 2)

        def f_dih(delta):
            return np.exp(-beta * k_a * (delta) ** 2)

        # Integrate translation and rotation
        grid = np.arange(r1lb, r1ub, r1st)
        r1_int = trapezoid(f_r1(grid), grid)
        grid = np.arange(a1lb, a1ub, a1st)
        a1_int = trapezoid(f_angle(grid, a1_0), grid)
        grid = dih_per(t1lb, t1ub, t1st, t1_0)
        t1_int = trapezoid(f_dih(grid), grid)
        grid = np.arange(a2lb, a2ub, a2st)
        a2_int = trapezoid(f_angle(grid, a2_0), grid)
        grid = dih_per(t2lb, t2ub, t2st, t2_0)
        t2_int = trapezoid(f_dih(grid), grid)
        grid = dih_per(t3lb, t3ub, t3st, t3_0)
        t3_int = trapezoid(f_dih(grid), grid)

        return (
            R
            * temperature
            * np.log(
                (1 / (8.0 * np.pi * np.pi))
                * (1.0 / 1660.0)
                * r1_int
                * a1_int
                * t1_int
                * a2_int
                * t2_int
                * t3_int
            )
        )
def generate_results_rest(
    md_sim_files: List[str], comp: str, blocks: int = 5, top: str = "full"
) -> None:
    """
    Build a cpptraj input on the fly from the 'restraints.in' template in cwd
    and run cpptraj.

    Every template line containing ``trajin`` is dropped. The first line
    containing ``parm `` is rewritten to point at ``../{comp}-1/{top}.prmtop``,
    and one ``trajin`` line per file in *md_sim_files* is inserted right after
    it. The result is written to ``restraints_curr.in`` and cpptraj output
    goes to ``restraints.log``. *blocks* is accepted but not used.

    Raises
    ------
    ValueError
        If the template has no ``parm`` line (nothing is written).
    RuntimeError
        If the cpptraj command returns a non-zero code.
    """
    with open("restraints.in", "r") as template:
        kept = [line for line in template.readlines() if "trajin" not in line]

    parm_idx = next((i for i, line in enumerate(kept) if "parm " in line), None)
    if parm_idx is None:
        raise ValueError("restraints.in missing a 'parm' line.")

    kept[parm_idx] = re.sub(
        r"parm\s+(\S+)", f"parm ../{comp}-1/{top}.prmtop", kept[parm_idx]
    )

    head = kept[: parm_idx + 1]
    tail = kept[parm_idx + 1 :]
    trajins = [f"trajin {path}\n" for path in md_sim_files]
    with open("restraints_curr.in", "w") as out:
        out.writelines(head + trajins + tail)

    rc = run_with_log(f"{cpptraj} -i restraints_curr.in > restraints.log 2>&1")
    if rc != 0:
        raise RuntimeError("cpptraj failed; see restraints.log")
# ---- lig wrapper ------------------------------------------------------------
def analyze_lig_task(
    lig_path: str,
    lig: str,
    components: List[str],
    rest: Tuple[float, float, float, float, float],
    temperature: float,
    water_model: str,
    component_windows_dict: Dict[str, List[int]],
    rocklin_correction: bool = False,
    analysis_start_step: int = 0,
    raise_on_error: bool = True,
    mol: str = "LIG",
    n_workers: int = 4,
    n_bootstraps: int = 0,
    dt: float = 0.0,
    ntwx: int = 0,
):
    """
    Analyze one lig under lig_path for the requested components.

    Parameters
    ----------
    lig_path : str
        Ligand work directory; outputs go to ``{lig_path}/Results``.
    lig : str
        Ligand name (used in log messages).
    components : list[str]
        Components to analyze. Components in ``COMPONENTS_DICT["dd"]`` use
        :class:`MBARAnalysis`, otherwise ``COMPONENTS_DICT["rest"]`` ones use
        :class:`RESTMBARAnalysis`. Unknown components are skipped with a
        warning. If ``"v"``, ``"o"`` or ``"z"`` is present (checked in that
        order), an analytical Boresch term is added from its ``-1/disang.rest``.
    rest : tuple
        Restraint settings. ``rest[2]``/``rest[3]`` are passed as Boresch
        ``k_r``/``k_a``. Component ``"n"`` is skipped when both ``rest[1]`` and
        ``rest[4]`` are 0.
    temperature : float
        Temperature in Kelvin.
    water_model : str
        Forwarded to the Rocklin correction.
    component_windows_dict : dict
        Lambda windows per component.
    rocklin_correction : bool
        Apply the Rocklin correction when ``"y"`` is among the components
        and the ligand is charged.
    raise_on_error : bool
        Re-raise analysis errors. When ``False``, totals become NaN instead.
    mol : str
        Ligand residue name.
    n_workers, n_bootstraps, analysis_start_step, dt, ntwx :
        Forwarded to the MBAR analyses.

    Writes ``Results/Results.dat`` (one tab-separated ``name, signed FE,
    error`` line per term plus ``Total``), ``Results/fe_timeseries.json``
    and ``Results/fe_timeseries.png``.
    """
    os.makedirs(f"{lig_path}/Results", exist_ok=True)
    results_entries: List[str] = []
    LEN_FE_TIMESERIES = 10

    try:
        fe_values: List[float] = []
        fe_stds: List[float] = []
        fe_timeseries: Dict[str, np.ndarray] = {}

        # Analytical Boresch (if present); first of v/o/z that is requested
        boresch_file = next(
            (
                f"{lig_path}/{c}/{c}-1/disang.rest"
                for c in ("v", "o", "z")
                if c in components
            ),
            None,
        )
        if boresch_file:
            k_r, k_a = rest[2], rest[3]
            bor = BoreschAnalysis(
                disangfile=boresch_file, k_r=k_r, k_a=k_a, temperature=temperature
            )
            bor.run_analysis()
            signed_bd = COMPONENT_DIRECTION_DICT["Boresch"] * bor.results["fe"]
            fe_values.append(signed_bd)
            fe_stds.append(bor.results["fe_error"])
            fe_timeseries["Boresch"] = np.asarray([bor.results["fe"], 0.0])
            results_entries.append(
                f"Boresch\t{signed_bd:.2f}\t{bor.results['fe_error']:.2f}"
            )

        for comp in components:
            windows = component_windows_dict[comp]
            # skip 'n' if no conformational restraints are applied
            if comp == "n" and rest[1] == 0 and rest[4] == 0:
                logger.debug("Skipping 'n' (no conformational restraints).")
                continue
            logger.debug(
                f"[analyze_lig] {lig} comp={comp} windows={windows} "
                f"analysis_start_step={analysis_start_step}, n_bootstraps={n_bootstraps}, dt={dt}, ntwx={ntwx}"
            )

            # 'm' is listed in both groups; the "dd" check wins (MBARAnalysis)
            if comp in COMPONENTS_DICT["dd"]:
                analysis_cls = MBARAnalysis
            elif comp in COMPONENTS_DICT["rest"]:
                analysis_cls = RESTMBARAnalysis
            else:
                logger.warning(
                    f"[analyze_lig] {lig} unknown component '{comp}'; skipping."
                )
                continue

            ana = analysis_cls(
                lig_folder=lig_path,
                component=comp,
                windows=windows,
                temperature=temperature,
                analysis_start_step=analysis_start_step,
                n_bootstraps=n_bootstraps,
                load=False,
                n_jobs=n_workers,
                dt=dt,
                ntwx=ntwx,
            )
            ana.run_analysis()
            ana.plot_convergence(
                save_path=f"{lig_path}/Results/{comp}_convergence.png",
                title=f"Convergence for {comp} {mol}",
            )
            signed_fe = COMPONENT_DIRECTION_DICT[comp] * ana.results["fe"]
            fe_values.append(signed_fe)
            fe_stds.append(ana.results["fe_error"])
            fe_timeseries[comp] = ana.results["fe_timeseries"]
            results_entries.append(
                f"{comp}\t{signed_fe:.2f}\t{ana.results['fe_error']:.2f}"
            )

        logger.debug(f"[analyze_lig] {lig} combining components for total FE")
        logger.debug(f"    component FE values: {fe_values}")
        logger.debug(f"    component FE stds: {fe_stds}")

        # total FE and timeseries (sum in quadrature for std)
        fe_value = float(np.sum(fe_values)) if fe_values else float("nan")
        fe_std = (
            float(np.sqrt(np.sum(np.array(fe_stds) ** 2))) if fe_stds else float("nan")
        )

        fe_ts_val = np.zeros(LEN_FE_TIMESERIES, dtype=float)
        fe_ts_err2 = np.zeros(LEN_FE_TIMESERIES, dtype=float)
        for comp, ts in fe_timeseries.items():
            direction = COMPONENT_DIRECTION_DICT.get(comp, +1)
            if ts.ndim == 1:
                # scalar-only term (Boresch): shift every point by its value
                fe_ts_val += float(ts[0]) * direction
            else:
                # assume Nx2 (value, stderr)
                n = min(LEN_FE_TIMESERIES, ts.shape[0])
                fe_ts_val[:n] += ts[:n, 0] * direction
                fe_ts_err2[:n] += ts[:n, 1] ** 2
        fe_ts_err = np.sqrt(fe_ts_err2)

    except Exception as e:
        logger.error(f"Error during FE analysis for {lig}: {e}")
        if raise_on_error:
            raise
        fe_value = float("nan")
        fe_std = float("nan")
        fe_ts_val = np.full(LEN_FE_TIMESERIES, np.nan)
        fe_ts_err = np.full(LEN_FE_TIMESERIES, np.nan)

    # Optional Rocklin correction (component 'y')
    if rocklin_correction and "y" in components:
        from .rocklin import run_rocklin_correction

        universe = mda.Universe(
            f"{lig_path}/y/y-1/full.prmtop",
            f"{lig_path}/y/y-1/eq_output.pdb",
        )
        box = universe.dimensions[:3]
        lig_ag = universe.select_atoms(f"resname {mol}")
        if len(lig_ag) == 0:
            raise ValueError(
                f"No ligand atoms found for Rocklin correction with resname {mol}"
            )
        lig_netq = int(round(lig_ag.total_charge()))
        other_ag = universe.atoms - lig_ag
        other_netq = int(round(other_ag.total_charge()))
        if lig_netq == 0:
            logger.debug(
                f"Rocklin correction skipped: ligand netq={lig_netq}, other netq={other_netq}"
            )
        else:
            logger.debug(
                f"Rocklin correction with ligand netq={lig_netq}, other netq={other_netq}"
            )
            corr = run_rocklin_correction(
                universe=universe,
                mol_name=mol,
                box=box,
                lig_netq=lig_netq,
                other_netq=other_netq,
                temp=temperature,
                water_model=water_model,
            )
            fe_value += corr
            results_entries.append(f"Rocklin\t{corr:.2f}\t0.00")
            fe_ts_val += corr

    results_entries.append(f"Total\t{fe_value:.2f}\t{fe_std:.2f}")
    with open(f"{lig_path}/Results/Results.dat", "w") as f:
        f.write("\n".join(results_entries))

    with open(f"{lig_path}/Results/fe_timeseries.json", "w") as f:
        json.dump(
            {"fe_value": fe_ts_val.tolist(), "fe_std": fe_ts_err.tolist()}, f, indent=2
        )

    sns.set(style="whitegrid")
    fig, ax = plt.subplots(figsize=(6, 4))
    x = np.arange(1, LEN_FE_TIMESERIES + 1) / LEN_FE_TIMESERIES * 100.0
    ax.errorbar(x, fe_ts_val, yerr=fe_ts_err, fmt="-o", capsize=4)
    ax.axhline(fe_value, linestyle="--", label="FE value (±1 kcal/mol)")
    ax.fill_between(x, fe_value - 1.0, fe_value + 1.0, alpha=0.2)
    ax.set_xlabel("Simulation Progress (%)")
    ax.set_ylabel("Free Energy (kcal/mol)")
    ax.set_title(f"Free Energy Timeseries for {mol}")
    ax.legend(loc="upper right")
    plt.tight_layout()
    fig.savefig(f"{lig_path}/Results/fe_timeseries.png", dpi=200)
    plt.close(fig)
    return