"""Utilities for inspecting replica-exchange simulations."""
from __future__ import annotations
# Adapted from Amber FETools (refactored into a single class)
import os
from typing import Dict, List, Optional, Tuple
import numpy as np
from loguru import logger
__all__ = ["RemdLog", "plot_trajectory"]
class RemdLog:
    r"""
    Read and analyse AMBER ``remlog`` files.

    The parser reconstructs the replica $\leftrightarrow$ state mapping at each
    exchange step and reports high-level metrics such as average single-pass
    duration and the number of round trips.

    Parameters
    ----------
    inputfile : str
        Path to the ``remlog`` text file produced by AMBER.

    Raises
    ------
    FileNotFoundError
        If ``inputfile`` does not exist.
    ValueError
        If no exchange record can be parsed from the file, or a record does
        not start with two integer columns.
    """

    # Number of leading records inspected to infer the replica count.
    _REPLICA_SCAN_WINDOW = 200

    def __init__(self, inputfile: str):
        if not os.path.isfile(inputfile):
            raise FileNotFoundError(f"Input file '{inputfile}' does not exist.")
        self.inputfile: str = inputfile

        # Parsed data
        self.replica_trajectory: Optional[np.ndarray] = None  # shape: (n_replica, n_step+1)
        self.replica_state_count: Optional[np.ndarray] = None  # shape: (n_replica, n_state)
        self.replica_ex_count: Optional[np.ndarray] = None  # shape: (n_replica, n_state-1)
        self.replica_ex_succ: Optional[np.ndarray] = None  # shape: (n_replica, n_state-1)
        self.ARs: Optional[List[float]] = None  # neighbor acceptance ratios

        # Meta
        self.n_replica: Optional[int] = None
        self.n_step: Optional[int] = None

        self._read_log()

    def _read_log(self) -> None:
        """Parse ``self.inputfile`` and populate cached arrays."""
        (
            self.replica_trajectory,
            self.replica_state_count,
            self.replica_ex_count,
            self.replica_ex_succ,
            self.ARs,
            self.n_replica,
            self.n_step,
        ) = self._read_rem_log()

    def analyze(self) -> Dict[str, float | List[float]]:
        """
        Summarise the replica trajectory.

        Returns
        -------
        dict
            Dictionary with the same keys as :meth:`get_remd_info`.
        """
        return self._remd_analysis(self.replica_trajectory, self.ARs)

    @classmethod
    def get_remd_info(cls, inputfile: str) -> Dict[str, float | List[float]]:
        """
        Convenience helper that parses and analyses a ``remlog`` file.

        Parameters
        ----------
        inputfile : str
            Path to the ``remlog`` text file.

        Returns
        -------
        dict
            Same structure as :meth:`analyze`.
        """
        # The constructor already parses the file; parsing again would only
        # repeat the same work.
        return cls(inputfile).analyze()

    # ---------- Internals ----------
    def _read_rem_log(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[float], int, int]:
        """
        Parse the on-disk remlog file.

        Returns
        -------
        tuple
            ``(replica_trajectory, replica_state_count, replica_ex_count,
            replica_ex_succ, neighbor_acceptance_ratios, n_replica, n_step)``.

        Raises
        ------
        FileNotFoundError
            If the file disappeared since the instance was created.
        ValueError
            If the file holds no exchange records or the replica count cannot
            be inferred from them.
        """
        logger.info("Analyzing remlog file: {}", self.inputfile)
        # NOTE(review): this changes NumPy's process-wide print options as a
        # side effect of parsing. Kept for backward compatibility.
        np.set_printoptions(precision=2, linewidth=150, formatter={"int": "{:2d}".format})

        try:
            with open(self.inputfile, "r") as f:
                lines = f.readlines()
        except FileNotFoundError as err:
            raise FileNotFoundError(f"File '{self.inputfile}' not found.") from err

        rep, succ = self._parse_records(lines)
        logger.info("Done reading the remlog.")

        if not rep:
            raise ValueError("Failed to infer number of replicas from remlog.")
        # The replica count is the largest replica index among the leading
        # records (same heuristic as the original code, evaluated once).
        n_replica = max(rep[: self._REPLICA_SCAN_WINDOW])
        if n_replica <= 0:
            raise ValueError("Failed to infer number of replicas from remlog.")

        total_records = len(rep)
        if total_records % n_replica != 0:
            logger.warning(
                "Total records ({}) not divisible by n_replica ({}). "
                "Rounding down to an integer number of steps.",
                total_records,
                n_replica,
            )
        n_step = total_records // n_replica
        logger.info("# of Replicas: {} # of Steps: {}", n_replica, n_step)

        # Neighbor acceptance ratios come from the tail of the file: the
        # original code used the slice [-n_replica : -1].
        tail = lines[-n_replica:-1] if len(lines) >= n_replica else []
        ARs: List[float] = []
        for text in tail:
            bits = text.split()
            if not bits:
                continue
            try:
                ARs.append(float(bits[-1]))
            except ValueError:
                # Best effort, as in the original: tail lines whose last
                # token is not a number are ignored, never fatal.
                continue

        (
            replica_trajectory,
            replica_state_count,
            replica_ex_count,
            replica_ex_succ,
        ) = self._walk_exchanges(succ, n_replica, n_step)

        return (
            replica_trajectory,
            replica_state_count,
            replica_ex_count,
            replica_ex_succ,
            ARs,
            n_replica,
            n_step,
        )

    @staticmethod
    def _parse_records(lines: List[str]) -> Tuple[List[int], List[str]]:
        """Extract the replica index and success flag of every record line."""
        rep: List[int] = []
        succ: List[str] = []
        for line in lines:
            # Skip comment lines and blank lines. Lines read from a file keep
            # their newline, so emptiness must be tested after stripping.
            if line.startswith("#") or not line.strip():
                continue
            rep.append(int(line[0:6]))
            # The neighbour column is not used downstream, but parsing it
            # keeps malformed records failing loudly with ValueError.
            int(line[6:12])
            # Older/newer formats differ in where the success char sits:
            # prefer column 66 if it is T/F, else fall back to column 91.
            # Slicing past the end of a short line yields "".
            flag = line[66:67]
            succ.append(flag if flag in ("T", "F") else line[91:92])
        return rep, succ

    @staticmethod
    def _walk_exchanges(
        succ: List[str], n_replica: int, n_step: int
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Replay the exchange attempts and tabulate states and attempt counts.

        ``succ`` must hold at least ``n_step * n_replica`` flags; the caller
        guarantees this because ``n_step`` is ``len(succ) // n_replica``.
        """
        n_state = n_replica
        replica_trajectory = np.zeros((n_replica, n_step + 1), np.int64)
        replica_state_count = np.zeros((n_replica, n_state), np.int64)
        replica_ex_count = np.zeros((n_replica, n_state - 1), np.int64)
        replica_ex_succ = np.zeros((n_replica, n_state - 1), np.int64)

        rows = np.arange(n_replica)
        # current[r] is the 1-based state of replica r; every replica starts
        # in its own state at step 0.
        current = np.arange(1, n_replica + 1, dtype=np.int64)
        replica_trajectory[:, 0] = current
        replica_state_count[rows, rows] = 1
        # holder[s] is the replica currently in state s + 1. It is the inverse
        # of ``current`` and gives an O(1) lookup instead of an array scan.
        holder = list(range(n_replica))

        for m in range(n_step):
            # Odd-even neighbor swaps: the starting pair alternates per step.
            for i in range((m + 1) % 2, n_replica - 1, 2):
                lower, upper = holder[i], holder[i + 1]
                # The attempt is credited to the replica in the lower state.
                replica_ex_count[lower, i] += 1
                if succ[m * n_replica + i] == "T":
                    replica_ex_succ[lower, i] += 1
                    current[lower], current[upper] = i + 2, i + 1
                    holder[i], holder[i + 1] = upper, lower
            replica_trajectory[:, m + 1] = current
            # Each replica contributes one (row, state) pair, so the fancy
            # index has no duplicates and the in-place add is exact.
            replica_state_count[rows, current - 1] += 1

        return replica_trajectory, replica_state_count, replica_ex_count, replica_ex_succ

    @staticmethod
    def _remd_analysis(
        replica_trajectory: np.ndarray, ARs: List[float]
    ) -> Dict[str, float | List[float]]:
        """
        Compute REMD round-trip statistics from a replica/state table.

        A "single pass" is counted each time a replica that last touched one
        end state (1 or ``n_replica``) first reaches the opposite end state;
        its length is the number of steps between those two first arrivals.

        Parameters
        ----------
        replica_trajectory : numpy.ndarray
            Array of shape ``(n_replica, n_step + 1)`` describing which thermodynamic
            state each replica occupied at every step.
        ARs : list[float]
            Neighbor acceptance ratios parsed from the tail of the remlog.

        Returns
        -------
        dict
            Summary containing the average single-pass length, round trips per
            replica, total round trips, and the provided acceptance ratios.
            When no pass is found in at least one direction, the average is
            reported as ``1.0e8`` and both round-trip figures as ``0.0``.
        """
        n_replica = int(np.size(replica_trajectory, 0))
        n_frames = int(np.size(replica_trajectory, 1))  # n_step + 1 columns
        logger.info("Analyzing trajectory: n_replica={}, n_step={}", n_replica, n_frames - 1)

        # Pass lengths for 1 -> N and N -> 1, pooled over all replicas.
        h1n: List[int] = []
        hn1: List[int] = []

        for i in range(n_replica):
            first_step_at_1 = -1
            first_step_at_n = -1
            at_1 = False
            at_n = False
            for j in range(n_frames):
                state = replica_trajectory[i, j]
                # Two independent checks (not elif): with a single replica,
                # state 1 is also state N, as in the original logic.
                if state == 1 and not at_1:
                    at_1, at_n = True, False
                    first_step_at_1 = j
                    if first_step_at_n >= 0:
                        hn1.append(j - first_step_at_n)
                        first_step_at_n = -1
                if state == n_replica and not at_n:
                    at_n, at_1 = True, False
                    first_step_at_n = j
                    if first_step_at_1 >= 0:
                        h1n.append(j - first_step_at_1)
                        first_step_at_1 = -1

        output_data: Dict[str, float | List[float]] = {}
        if not h1n or not hn1:
            logger.warning("No single pass found (no 1↔N transitions detected).")
            output_data["Average single pass steps:"] = 1.0e8
            output_data["Round trips per replica:"] = 0.0
            output_data["Total round trips:"] = 0.0
            output_data["neighbor_acceptance_ratio"] = ARs
            return output_data

        hh = h1n + hn1
        # Two single passes (one in each direction) make one round trip.
        output_data["Average single pass steps:"] = float(np.mean(hh))
        output_data["Round trips per replica:"] = float(len(hh) / 2 / n_replica)
        output_data["Total round trips:"] = float(len(hh) / 2)
        output_data["neighbor_acceptance_ratio"] = ARs
        return output_data
def plot_trajectory(
    replica_trajectory: np.ndarray,
    figsize: Tuple[float, float] = (10, 6),
    alpha: float = 0.8,
    linewidth: float = 1.5,
    subplot: bool = False,
    ncols: int = 4,
) -> None:
    """
    Visualise the replica walk through thermodynamic states.

    The figure is displayed with ``plt.show()``; nothing is returned.

    Parameters
    ----------
    replica_trajectory : numpy.ndarray
        Array of shape ``(n_replica, n_step + 1)`` containing state indices.
    figsize : tuple, optional
        Base figure size. When ``subplot=True`` the width/height apply to each
        panel instead of the aggregate.
    alpha : float, optional
        Line transparency used for individual replica traces.
    linewidth : float, optional
        Width of trajectory lines.
    subplot : bool, optional
        When ``True``, render one subplot per replica; otherwise plot all
        replicas on a shared axis.
    ncols : int, optional
        Number of subplot columns when ``subplot=True``.

    Raises
    ------
    ValueError
        If ``subplot=True`` and ``ncols`` is smaller than 1.
    """
    import matplotlib.pyplot as plt  # deferred import to avoid heavy backends

    n_replica, n_step_plus1 = replica_trajectory.shape
    steps = np.arange(n_step_plus1)
    cmap = plt.cm.rainbow
    colors = [cmap(i / n_replica) for i in range(n_replica)]

    if not subplot:
        # --- Single axis plot ---
        # A legend is only readable for a small number of replicas.
        show_legend = n_replica <= 15
        fig, ax = plt.subplots(figsize=figsize)
        for i in range(n_replica):
            ax.plot(
                steps,
                replica_trajectory[i],
                color=colors[i],
                alpha=alpha,
                linewidth=linewidth,
                label=f"Replica {i+1}" if show_legend else None,
            )
        ax.set_xlabel("Step")
        ax.set_ylabel("State index")
        ax.set_title("Replica Trajectories")
        if show_legend:
            ax.legend(loc="best", fontsize=8)
        fig.tight_layout()
        plt.show()
        return

    # --- Subplot mode ---
    if ncols < 1:
        raise ValueError(f"ncols must be a positive integer, got {ncols!r}.")
    nrows = int(np.ceil(n_replica / ncols))
    fig_width = figsize[0] * ncols
    fig_height = figsize[1] * nrows
    # squeeze=False always yields a 2-D array of axes; without it a 1x1 grid
    # returns a bare Axes object and flattening fails.
    fig, axes = plt.subplots(
        nrows,
        ncols,
        figsize=(fig_width, fig_height),
        sharex=True,
        sharey=True,
        squeeze=False,
    )
    axes = axes.flatten()
    for i in range(n_replica):
        ax = axes[i]
        ax.plot(
            steps,
            replica_trajectory[i],
            color=colors[i],
            alpha=alpha,
            linewidth=linewidth,
        )
        ax.set_title(f"Replica {i+1}", fontsize=9)
        ax.tick_params(labelsize=8)
    # Hide unused subplots if n_replica is not a multiple of ncols
    for unused in axes[n_replica:]:
        fig.delaxes(unused)
    fig.suptitle("Replica Trajectories", fontsize=12)
    fig.text(0.5, 0.04, "Step", ha="center")
    fig.text(0.04, 0.5, "State index", va="center", rotation="vertical")
    plt.tight_layout(rect=[0.05, 0.05, 1, 0.95])
    plt.show()