"""Prepare complex systems (protein/ligand/membrane) for simulations."""
from __future__ import annotations
from collections import Counter
import contextlib
import itertools
import json
import os
import shutil
import string
from importlib import resources
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union
import MDAnalysis as mda
import numpy as np
import pandas as pd
from MDAnalysis.analysis import align
from MDAnalysis.analysis.dssp import DSSP
from loguru import logger
from batter._internal.templates import BUILD_FILES_DIR as build_files_orig
from batter.orchestrate.state_registry import register_phase_state
from batter.pipeline.payloads import StepPayload, SystemParams
from batter.pipeline.step import ExecResult, Step
from batter.systems.core import SimSystem
from batter.utils.builder_utils import find_anchor_atoms
_PROTEIN_BREAK_CA_DISTANCE_CUTOFF_A = 10.0
_CHAIN_ID_ALPHABET = string.ascii_uppercase + string.ascii_lowercase + string.digits
_XY_ROTATION_REFINE_DEGREES = (45.0, 15.0, 5.0, 1.0)
def _as_abs(p: str | Path | None, base: Path) -> Path | None:
if p is None:
return None
p = Path(p)
return p if p.is_absolute() else (base / p).resolve()
def _copy(src: Path, dst: Path) -> None:
dst.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(src, dst)
def _chain_id_from_index(index: int) -> str:
    """Map a zero-based fragment index to a single-character chain ID.

    Raises
    ------
    ValueError
        If *index* is at or beyond the number of characters available in
        ``_CHAIN_ID_ALPHABET``.
    """
    if index < len(_CHAIN_ID_ALPHABET):
        return _CHAIN_ID_ALPHABET[index]
    raise ValueError(
        "Too many protein fragments to encode in single-character PDB chain IDs. "
        f"Found fragment index {index + 1}, but only {len(_CHAIN_ID_ALPHABET)} IDs are available."
    )
def _get_single_ca_position(residue) -> np.ndarray | None:
ca_atoms = residue.atoms.select_atoms("name CA")
if ca_atoms.n_atoms != 1:
return None
return np.asarray(ca_atoms.positions[0], dtype=float)
def _rotation_matrix_x(angle_deg: float) -> np.ndarray:
angle = np.deg2rad(angle_deg)
c = float(np.cos(angle))
s = float(np.sin(angle))
return np.array(
[[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]],
dtype=float,
)
def _rotation_matrix_y(angle_deg: float) -> np.ndarray:
angle = np.deg2rad(angle_deg)
c = float(np.cos(angle))
s = float(np.sin(angle))
return np.array(
[[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]],
dtype=float,
)
def _rotation_matrix_z(angle_deg: float) -> np.ndarray:
angle = np.deg2rad(angle_deg)
c = float(np.cos(angle))
s = float(np.sin(angle))
return np.array(
[[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]],
dtype=float,
)
def _apply_rotation(coords: np.ndarray, rotation: np.ndarray) -> np.ndarray:
"""Apply a column-vector rotation matrix to row-vector coordinates."""
return np.asarray(coords, dtype=float) @ rotation.T
def _xy_box_score(coords: np.ndarray) -> tuple[float, float, float]:
spans = np.ptp(coords, axis=0)
return (
float(spans[0] * spans[1]),
float(spans[2]),
float(spans[0] + spans[1]),
)
def _score_lt(
candidate: tuple[float, float, float],
current: tuple[float, float, float],
*,
tol: float = 1e-6,
) -> bool:
for cand_val, curr_val in zip(candidate, current):
if cand_val < curr_val - tol:
return True
if cand_val > curr_val + tol:
return False
return False
def _principal_axis_rotations(coords: np.ndarray) -> list[np.ndarray]:
_, _, vh = np.linalg.svd(coords, full_matrices=False)
axes = vh.T
if np.linalg.det(axes) < 0.0:
axes[:, -1] *= -1.0
rotations: list[np.ndarray] = [np.eye(3, dtype=float)]
for perm in itertools.permutations(range(3)):
permuted = axes[:, perm]
for signs in itertools.product((-1.0, 1.0), repeat=3):
basis = permuted * np.asarray(signs, dtype=float)
if np.linalg.det(basis) <= 0.0:
continue
rotations.append(basis.T)
return rotations
def _refine_xy_box_rotation(
    coords: np.ndarray,
    initial_rotation: np.ndarray,
    *,
    step_degrees: tuple[float, ...] = _XY_ROTATION_REFINE_DEGREES,
) -> tuple[np.ndarray, tuple[float, float, float]]:
    """Greedily refine *initial_rotation* to lower the XY box score of *coords*.

    For each step size in *step_degrees*, the current best rotation is
    perturbed by every combination of ``-step``, ``0`` and ``+step`` degrees
    about x, y and z (composed as ``Rz @ Ry @ Rx``, excluding the all-zero
    combination).  The best-scoring perturbation replaces the current rotation
    and the search repeats at the same step size until no perturbation scores
    lower according to ``_score_lt``.

    Returns
    -------
    tuple
        ``(rotation, score)`` of the best orientation found.
    """
    best_rotation = np.asarray(initial_rotation, dtype=float)
    best_score = _xy_box_score(_apply_rotation(coords, best_rotation))
    for step_deg in step_degrees:
        offsets = (-step_deg, 0.0, step_deg)
        # The perturbations depend only on the step size, so build them once.
        perturbations = [
            _rotation_matrix_z(dz) @ _rotation_matrix_y(dy) @ _rotation_matrix_x(dx)
            for dx, dy, dz in itertools.product(offsets, repeat=3)
            if not (dx == dy == dz == 0.0)
        ]
        improved = True
        while improved:
            improved = False
            round_rotation, round_score = best_rotation, best_score
            for perturbation in perturbations:
                trial_rotation = perturbation @ best_rotation
                trial_score = _xy_box_score(_apply_rotation(coords, trial_rotation))
                if _score_lt(trial_score, round_score):
                    round_rotation, round_score = trial_rotation, trial_score
                    improved = True
            if improved:
                best_rotation, best_score = round_rotation, round_score
    return best_rotation, best_score
def _find_min_xy_box_rotation(
    coords: np.ndarray,
) -> tuple[np.ndarray, tuple[float, float, float], tuple[float, float, float]]:
    """Search for a rotation that minimises the XY box score of *coords*.

    The coordinates are centered on their mean, every candidate from
    ``_principal_axis_rotations`` is refined with ``_refine_xy_box_rotation``,
    and the best refined result is kept only if it beats the unrotated score.

    Returns
    -------
    tuple
        ``(rotation, score_before, score_after)``.  The rotation is the
        identity when no candidate improves on the input orientation or when
        fewer than two points are given.

    Raises
    ------
    ValueError
        If *coords* is not an ``(N, 3)`` array.
    """
    coords = np.asarray(coords, dtype=float)
    if coords.ndim != 2 or coords.shape[1] != 3:
        raise ValueError(
            f"Expected an (N, 3) coordinate array for XY box optimization, got {coords.shape}."
        )
    n_points = coords.shape[0]
    if n_points < 2:
        # Nothing to optimise; score a single origin point when the input is empty.
        trivial_input = coords if n_points else np.zeros((1, 3), dtype=float)
        trivial_score = _xy_box_score(trivial_input)
        return np.eye(3, dtype=float), trivial_score, trivial_score

    centered = coords - coords.mean(axis=0, keepdims=True)
    before_score = _xy_box_score(centered)
    winner = (np.eye(3, dtype=float), before_score)
    for seed_rotation in _principal_axis_rotations(centered):
        challenger = _refine_xy_box_rotation(centered, seed_rotation)
        if _score_lt(challenger[1], winner[1]):
            winner = challenger
    return winner[0], before_score, winner[1]
def _split_residues_on_breaks(
    residues,
    *,
    segid: str,
    chain_id: str,
    ca_distance_cutoff: float = _PROTEIN_BREAK_CA_DISTANCE_CUTOFF_A,
) -> tuple[list[list[Any]], list[str]]:
    """Split an ordered run of residues wherever the chain appears broken.

    A break is declared between two consecutive residues when their resids are
    not consecutive integers, or when both have a single ``CA`` atom and those
    atoms are more than *ca_distance_cutoff* apart.

    Parameters
    ----------
    residues
        Iterable of residue objects exposing ``resid`` and ``atoms``.
    segid, chain_id : str
        Used only to label the warning messages.
    ca_distance_cutoff : float
        Maximum allowed distance between consecutive ``CA`` atoms.

    Returns
    -------
    tuple
        ``(fragments, warnings)``: the residues grouped into contiguous
        fragments, and one message per detected break.
    """
    residue_list = list(residues)
    if not residue_list:
        return [], []

    active_fragment: list[Any] = [residue_list[0]]
    fragments: list[list[Any]] = [active_fragment]
    warnings: list[str] = []
    for prev_residue, curr_residue in zip(residue_list[:-1], residue_list[1:]):
        prev_resid = int(prev_residue.resid)
        curr_resid = int(curr_residue.resid)
        reasons: list[str] = []
        if curr_resid - prev_resid != 1:
            reasons.append(f"resid discontinuity ({prev_resid} -> {curr_resid})")

        prev_ca = _get_single_ca_position(prev_residue)
        curr_ca = _get_single_ca_position(curr_residue)
        # The distance test needs an unambiguous CA on both sides.
        if not (prev_ca is None or curr_ca is None):
            ca_distance = float(np.linalg.norm(curr_ca - prev_ca))
            if ca_distance > ca_distance_cutoff:
                reasons.append(
                    f"C-alpha distance {ca_distance:.1f} A > {ca_distance_cutoff:.1f} A"
                )

        if reasons:
            warnings.append(
                "Detected a protein break in system_prep "
                f"(segid={segid or '?'}, chain={chain_id or '?'}) "
                f"between residues {prev_resid} and {curr_resid}: "
                + "; ".join(reasons)
                + ". BATTER will split these residues into separate segments/chains."
            )
            active_fragment = []
            fragments.append(active_fragment)
        active_fragment.append(curr_residue)
    return fragments, warnings
def _group_residues_by_source_identity(residues) -> list[list[Any]]:
residue_list = list(residues)
if not residue_list:
return []
groups: list[list[Any]] = [[residue_list[0]]]
prev_chain_id = str(residue_list[0].atoms.chainIDs[0]).strip() if len(residue_list[0].atoms) else ""
prev_segid = str(residue_list[0].segid).strip()
for residue in residue_list[1:]:
chain_id = str(residue.atoms.chainIDs[0]).strip() if len(residue.atoms) else ""
segid = str(residue.segid).strip()
if chain_id != prev_chain_id or segid != prev_segid:
groups.append([])
groups[-1].append(residue)
prev_chain_id = chain_id
prev_segid = segid
return groups
def _protein_segid_overrides(universe: mda.Universe) -> tuple[dict[int, str], int]:
    """
    Build per-atom segid overrides to canonicalize segids within each protein residue.

    Some input PDBs carry a segid on heavy atoms but leave hydrogens blank.
    MDAnalysis then parses those atoms as separate residues/segments on reload.
    Compute a residue-level canonical segid so aligned intermediates can be
    rewritten with consistent per-residue segids before they are reloaded.

    Atoms are pooled by ``(chainID, resid, resname)``.  For every pool holding
    more than one distinct (stripped) segid, the most common non-empty segid is
    chosen for all of its atoms.

    Returns
    -------
    tuple
        ``(overrides, n_residues)`` where *overrides* maps atom index to the
        canonical segid and *n_residues* counts the pools that were mixed.
        ``({}, 0)`` when the universe has no segids or no protein atoms.
    """
    try:
        universe.atoms.segids
    except AttributeError:
        # No segid information at all: nothing to normalize.
        return {}, 0

    protein_atoms = universe.select_atoms("protein")
    if not protein_atoms.n_atoms:
        return {}, 0

    members_by_residue: dict[tuple[str, int, str], list[int]] = {}
    for atom in protein_atoms:
        residue_key = (
            str(getattr(atom, "chainID", "")).strip(),
            int(atom.resid),
            str(atom.resname).strip(),
        )
        members_by_residue.setdefault(residue_key, []).append(int(atom.index))

    segid_overrides: dict[int, str] = {}
    normalized_count = 0
    for atom_indices in members_by_residue.values():
        labels = [str(label).strip() for label in universe.atoms[atom_indices].segids]
        if len(set(labels)) < 2:
            # Already consistent within this residue.
            continue
        named = [label for label in labels if label]
        canonical_segid = Counter(named).most_common(1)[0][0] if named else labels[0]
        segid_overrides.update(dict.fromkeys(atom_indices, canonical_segid))
        normalized_count += 1
    return segid_overrides, normalized_count
def _write_pdb_with_normalized_protein_segids(
    universe: mda.Universe,
    output_path: Path,
) -> int:
    """
    Write a PDB while normalizing mixed per-atom protein segids per residue.

    All atoms of *universe* are written to *output_path*.  When
    ``_protein_segid_overrides`` reports mixed segids, the file is re-read and
    the segid field (columns 73-76) of the affected ``ATOM``/``HETATM`` records
    is replaced with the canonical segid.  Records are matched to atoms by
    their order in the file.

    Parameters
    ----------
    universe : mda.Universe
        Universe whose atoms are written.
    output_path : Path
        Destination PDB file.

    Returns
    -------
    int
        Number of residues whose segids were normalized.
    """
    segid_overrides, normalized_count = _protein_segid_overrides(universe)
    universe.atoms.write(output_path.as_posix())
    if not segid_overrides:
        return normalized_count

    segid_start, segid_stop = 72, 76  # zero-based slice of PDB columns 73-76
    segid_width = segid_stop - segid_start
    # Look the indices up once instead of building an Atom object per record.
    atom_indices = universe.atoms.indices

    rewritten_lines: list[str] = []
    atom_counter = 0
    for line in output_path.read_text().splitlines(True):
        if line.startswith(("ATOM", "HETATM")):
            atom_index = int(atom_indices[atom_counter])
            atom_counter += 1
            canonical_segid = segid_overrides.get(atom_index)
            if canonical_segid is not None:
                stripped = line.rstrip("\n")
                if len(stripped) < segid_stop:
                    stripped = stripped.ljust(segid_stop)
                # Truncate to the field width: a longer segid would push the
                # element and charge columns out of place.
                field = canonical_segid[:segid_width].ljust(segid_width)
                line = f"{stripped[:segid_start]}{field}{stripped[segid_stop:]}\n"
        rewritten_lines.append(line)
    output_path.write_text("".join(rewritten_lines))
    return normalized_count
def _select_fragment_atoms(
universe: mda.Universe,
residues: list[Any],
*,
chain_id: str,
segid: str,
):
resid_seq = " ".join(str(int(residue.resid)) for residue in residues)
selectors: list[str] = []
if chain_id:
selectors.append(f"protein and chainID {chain_id} and resid {resid_seq}")
if segid:
selectors.append(f"protein and segid {segid} and resid {resid_seq}")
selectors.append(f"protein and resid {resid_seq}")
for selector in selectors:
selection = universe.select_atoms(selector)
if selection.n_residues == len(residues):
return selection
raise ValueError(
"Could not match a protein fragment back to the aligned protein using "
f"segid={segid!r}, chainID={chain_id!r}, residues={[int(r.resid) for r in residues]}."
)
def _ensure_pdb(lig_path: Path, out_dir: Path) -> Path:
"""
Ensure a PDB exists for ligand file; if not PDB, convert via RDKit.
Returns the path to a PDB file.
"""
if lig_path.suffix.lower() == ".pdb":
return lig_path
try:
from rdkit import Chem
except Exception as e:
raise RuntimeError(
f"Ligand {lig_path} is not PDB; RDKit is required to convert SDF/MOL2 → PDB."
) from e
out_dir.mkdir(parents=True, exist_ok=True)
out_pdb = out_dir / f"{lig_path.stem}.pdb"
if lig_path.suffix.lower() == ".sdf":
suppl = Chem.SDMolSupplier(str(lig_path), removeHs=False)
mols = [m for m in suppl if m is not None]
if not mols:
raise ValueError(f"RDKit could not read any molecule from {lig_path}")
Chem.MolToPDBFile(mols[0], str(out_pdb))
elif lig_path.suffix.lower() == ".mol2":
mol = Chem.MolFromMol2File(str(lig_path), removeHs=False, sanitize=False)
if mol is None:
raise ValueError(f"RDKit could not read {lig_path}")
Chem.MolToPDBFile(mol, str(out_pdb))
elif lig_path.suffix.lower() == "pdb":
_copy(lig_path, out_pdb)
else:
raise ValueError(f"Unsupported ligand format: {lig_path.suffix} for {lig_path}")
return out_pdb
class _SystemPrepRunner:
def __init__(self, system: SimSystem, yaml_dir: Path) -> None:
self.system = system
self.yaml_dir = yaml_dir
self.output_dir = system.root
self.ligands_folder = self.output_dir / "all-ligands"
self.ligandff_folder = self.output_dir / "artifacts" / "ligands"
self.ligandff_folder.mkdir(parents=True, exist_ok=True)
# state
self._system_name: str = ""
self._protein_input: str = ""
self._system_topology: str | None = None
self._system_coordinate: str | None = None
self.receptor_segment: str | None = None
self.protein_align: str = "name CA and resid 60 to 250"
self.receptor_ff: str = "protein.ff14SB"
self.retain_lig_prot: bool = True
self.ligand_ph: float = 7.4
self.lipid_mol: List[str] = []
self.membrane_simulation: bool = False
self.lipid_ff: str = "lipid21"
self.overwrite: bool = False
self.verbose: bool = False
self.ligand_dict: Dict[str, str] = {}
self.unique_mol_names: List[str] = []
self.system_dimensions = np.zeros(3)
# alignment intermediates
self._protein_aligned_pdb: str | None = None
self._system_aligned_pdb: str | None = None
self.mobile_coord: np.ndarray | None = None
self.ref_coord: np.ndarray | None = None
self.mobile_com: np.ndarray | None = None
self.ref_com: np.ndarray | None = None
self.box_rotation_matrix: np.ndarray = np.eye(3)
# anchors
self.anchor_atoms: List[str] = []
self.ligand_anchor_atom: str | None = None
self.l1_x = self.l1_y = self.l1_z = None
self.l1_range = None
self.p1 = self.p2 = self.p3 = None
@property
def system_name(self) -> str:
return self._system_name
def _resolve_input_path(self, p: str) -> str:
ap = _as_abs(p, self.yaml_dir)
if ap is None:
raise ValueError("unexpected None path")
return str(ap)
@contextlib.contextmanager
def _change_dir(self, path: Path):
cwd = Path.cwd()
try:
os.chdir(path)
yield
finally:
os.chdir(cwd)
def _prepare_membrane(self):
"""
Convert input lipid names to lipid21 set (PC/PA/OL for POPC) via lookup CSV.
"""
logger.debug("Input: membrane system")
# read charmmlipid2amber file
charmm_csv_path = resources.files("batter") / "data/charmmlipid2amber.csv"
charmm_amber_lipid_df = pd.read_csv(charmm_csv_path, header=1, sep=",")
lipid_mol = list(self.lipid_mol)
logger.debug(f"Converting lipid input: {lipid_mol}")
amber_lipid_mol = charmm_amber_lipid_df.query("residue in @lipid_mol")[
"replace"
]
amber_lipid_mol = (
amber_lipid_mol.apply(lambda x: x.split()[1]).unique().tolist()
)
# extend instead of replacing so that we can have both
lipid_mol.extend(amber_lipid_mol)
self.lipid_mol = lipid_mol
logger.debug(f"New lipid_mol list: {self.lipid_mol}")
def _run_input_protein_dssp(self) -> Dict[str, Any]:
"""
Run DSSP on the input protein structure and persist the assignments.
"""
dssp_npy = self.ligands_folder / "protein_input_dssp.npy"
dssp_json = self.ligands_folder / "protein_input_dssp.json"
try:
u_prot = mda.Universe(self._protein_input)
dssp_ana = DSSP(u_prot.select_atoms('protein and not resname NMA ACE')).run()
dssp_array = np.asarray(dssp_ana.results["dssp"])
except Exception as exc:
try:
logger.warning(f"Failed to run DSSP on full protein input {self._protein_input}, trying with last residue removed")
dssp_ana = DSSP(u_prot.select_atoms('protein and not resname NMA ACE').residues[:-1].atoms).run()
dssp_array = np.asarray(dssp_ana.results["dssp"])
except Exception as exc:
logger.warning(
f"Failed to run DSSP on protein input {self._protein_input}: {exc}. No secondary-structure conditioned restraints. "
"If you want to debug, please run `DSSP` in MDAnalysis on the input protein file."
)
dssp_array = np.array([])
np.save(dssp_npy, dssp_array)
dssp_json.write_text(json.dumps(dssp_array.tolist()))
return {
"npy": str(dssp_npy),
"json": str(dssp_json),
"shape": list(dssp_array.shape),
"results": dssp_array.tolist(),
}
def _get_alignment(self):
"""
Prepare for the alignment of the protein and ligand to the system.
"""
logger.debug("Getting the alignment of the protein and ligand to the system")
# translate the cog of protein to the origin
#
u_prot = mda.Universe(self._protein_input)
u_sys = mda.Universe(self._system_input_pdb, format="XPDB")
cog_prot = u_sys.select_atoms("protein and name CA C N O").center_of_geometry()
u_sys.atoms.positions -= cog_prot
# get translation-rotation matrix
mobile = u_prot.select_atoms(self.protein_align).select_atoms(
"name CA and not resname NMA ACE"
)
ref = u_sys.select_atoms(self.protein_align).select_atoms(
"name CA and not resname NMA ACE"
)
if mobile.n_atoms != ref.n_atoms:
raise ValueError(
f"Number of atoms in the alignment selection is different: protein_input: "
f"{mobile.n_atoms} and system_input {ref.n_atoms} \n"
f"The selection string is {self.protein_align} and name CA and not resname NMA ACE\n"
f"protein selected resids: {mobile.residues.resids}\n"
f"system selected resids: {ref.residues.resids}\n"
"set `protein_align` to a selection string that has the same number of atoms in both files"
"when running `create_system`."
)
mobile_com = mobile.center(weights=None)
ref_com = ref.center(weights=None)
mobile_coord = mobile.positions - mobile_com
ref_coord = ref.positions - ref_com
_ = align._fit_to(
mobile_coordinates=mobile_coord,
ref_coordinates=ref_coord,
mobile_atoms=u_prot.atoms,
mobile_com=mobile_com,
ref_com=ref_com,
)
cog_prot = u_prot.select_atoms("protein and name CA C N O").center_of_geometry()
u_prot.atoms.positions -= cog_prot
self.box_rotation_matrix = np.eye(3)
if self._system_topology is None:
protein_atoms = u_prot.select_atoms("protein and not resname NMA ACE")
if protein_atoms.n_atoms >= 2:
rotation_matrix, score_before, score_after = _find_min_xy_box_rotation(
protein_atoms.positions
)
if _score_lt(score_after, score_before):
u_prot.atoms.positions = _apply_rotation(
u_prot.atoms.positions, rotation_matrix
)
u_sys.atoms.positions = _apply_rotation(
u_sys.atoms.positions, rotation_matrix
)
self.box_rotation_matrix = rotation_matrix
logger.info(
"Optimized protein orientation for smaller XY box area without system_input: "
f"{score_before[0]:.2f} -> {score_after[0]:.2f} A^2 "
f"(z span {score_before[1]:.2f} -> {score_after[1]:.2f} A)."
)
final_ref = u_prot.select_atoms(self.protein_align).select_atoms(
"name CA and not resname NMA ACE"
)
final_ref_com = final_ref.center(weights=None)
final_ref_coord = final_ref.positions - final_ref_com
protein_aligned_path = self.ligands_folder / "protein_aligned.pdb"
system_aligned_path = self.ligands_folder / "system_aligned.pdb"
normalized_prot_residues = _write_pdb_with_normalized_protein_segids(
u_prot, protein_aligned_path
)
normalized_sys_residues = _write_pdb_with_normalized_protein_segids(
u_sys, system_aligned_path
)
if normalized_prot_residues or normalized_sys_residues:
logger.warning(
"Detected mixed per-atom protein segid assignments; normalized segids "
f"for {normalized_prot_residues} residue(s) in the aligned protein and "
f"{normalized_sys_residues} residue(s) in the aligned system before grouping."
)
self._protein_aligned_pdb = str(protein_aligned_path)
self._system_aligned_pdb = str(system_aligned_path)
# store these for ligand alignment
self.mobile_com = mobile_com
self.mobile_coord = mobile_coord
self.ref_com = final_ref_com
self.ref_coord = final_ref_coord
def _process_system(self):
"""
Generate the protein, reference, and lipid (if applicable) files.
We will align the protein_input to the system_topology because
the system_topology is generated by dabble and may be shifted;
we want to align the protein to the system so the membrane is
properly positioned.
"""
logger.debug("Processing the system")
if not self._protein_aligned_pdb or not self._system_aligned_pdb:
raise RuntimeError("Alignment not computed. Call _get_alignment() first.")
u_prot = mda.Universe(self._protein_aligned_pdb)
u_sys = mda.Universe(self._system_aligned_pdb, format="XPDB")
try:
u_sys.atoms.chainIDs
except AttributeError:
u_sys.add_TopologyAttr("chainIDs")
try:
u_prot.atoms.chainIDs
except AttributeError:
u_prot.add_TopologyAttr("chainIDs")
memb_seg = u_sys.add_Segment(segid="MEMB")
water_seg = u_sys.add_Segment(segid="WATR")
protein_fragment_groups: list[tuple[Any, str]] = []
fragment_chain_index = 0
protein_source_groups = _group_residues_by_source_identity(
u_prot.select_atoms("protein").residues
)
for source_group in protein_source_groups:
chain_id = (
str(source_group[0].atoms.chainIDs[0]).strip() if len(source_group[0].atoms) else ""
)
segid = str(source_group[0].segid).strip()
residue_groups, split_warnings = _split_residues_on_breaks(
source_group,
segid=segid,
chain_id=chain_id,
)
for warning_message in split_warnings:
logger.warning(warning_message)
for residues in residue_groups:
new_chain_id = _chain_id_from_index(fragment_chain_index)
fragment_chain_index += 1
prot_selection = _select_fragment_atoms(
u_prot,
residues,
chain_id=chain_id,
segid=segid,
)
prot_selection.atoms.chainIDs = new_chain_id
protein_fragment_groups.append((prot_selection, new_chain_id))
comp_2_combined = []
if self.receptor_segment:
protein_anchor = u_prot.select_atoms(
f"segid {self.receptor_segment} and protein"
)
other_protein = u_prot.select_atoms(
f"not segid {self.receptor_segment} and protein"
)
comp_2_combined.append(protein_anchor)
comp_2_combined.append(other_protein)
else:
comp_2_combined.append(u_prot.select_atoms("protein"))
for prot_selection, new_chain_id in protein_fragment_groups:
prot_selection.residues.segments = u_prot.add_Segment(segid=new_chain_id)
if self.membrane_simulation:
membrane_ag = u_sys.select_atoms(f'resname {" ".join(self.lipid_mol)}')
if len(membrane_ag) == 0:
logger.warning(
f"No membrane atoms found with resname {self.lipid_mol}. Available resnames are {list(np.unique(u_sys.atoms.resnames))}. "
"Please check the lipid_mol parameter.",
)
else:
with open(f"{build_files_orig}/memb_opls2charmm.json", "r") as f:
MEMB_OPLS_2_CHARMM_DICT = json.load(f)
if np.any(membrane_ag.names == "O1"):
if np.any(membrane_ag.residues.resnames != "POPC"):
raise ValueError(
f"Found OPLS lipid name {membrane_ag.residues.resnames}, only 'POPC' is supported."
)
# convert the lipid names to CHARMM names
membrane_ag.names = [
MEMB_OPLS_2_CHARMM_DICT.get(name, name)
for name in membrane_ag.names
]
logger.info("Converting OPLS lipid names to CHARMM names.")
membrane_ag.chainIDs = "M"
membrane_ag.residues.segments = memb_seg
logger.debug(f"Number of lipid molecules: {membrane_ag.n_residues}")
comp_2_combined.append(membrane_ag)
else:
membrane_ag = u_sys.atoms[[]] # empty selection
# gather water (and ions) around protein/membrane
water_ag = u_sys.select_atoms(
"byres (((resname SPC and name O) or water) and around 15 (protein or group memb))",
memb=membrane_ag,
)
logger.debug(f"Number of water molecules: {water_ag.n_residues}")
ion_ag = u_sys.select_atoms(
"byres (resname SOD POT CLA NA CL and around 5 (protein))"
)
logger.debug(f"Number of ion molecules: {ion_ag.n_residues}")
# normalize ion names
ion_ag.select_atoms("resname SOD").names = "Na+"
ion_ag.select_atoms("resname SOD").residues.resnames = "Na+"
ion_ag.select_atoms("resname NA").names = "Na+"
ion_ag.select_atoms("resname NA").residues.resnames = "Na+"
ion_ag.select_atoms("resname POT").names = "K+"
ion_ag.select_atoms("resname POT").residues.resnames = "K+"
ion_ag.select_atoms("resname CLA").names = "Cl-"
ion_ag.select_atoms("resname CLA").residues.resnames = "Cl-"
ion_ag.select_atoms("resname CL").names = "Cl-"
ion_ag.select_atoms("resname CL").residues.resnames = "Cl-"
water_ag = water_ag + ion_ag
water_ag.chainIDs = "W"
water_ag.residues.segments = water_seg
if len(water_ag) == 0:
logger.warning(
f"No water molecules found in the system. Available resnames are {np.unique(u_sys.atoms.resnames)}. "
"Please check the system_topology and system_coordinate files.",
)
else:
comp_2_combined.append(water_ag)
u_merged = mda.Merge(*comp_2_combined)
water = u_merged.select_atoms("water or resname SPC")
if len(water) != 0:
logger.debug(
f"Number of water molecules in merged system: {water.n_residues}"
)
logger.debug(f"Water atom names: {water.residues[0].atoms.names}")
# Normalize water O names for tleap
water.select_atoms("name OW").names = "O"
water.select_atoms("name OH2").names = "O"
box_dim = np.zeros(6)
if len(self.system_dimensions) == 3:
box_dim[:3] = self.system_dimensions
box_dim[3:] = 90.0
elif len(self.system_dimensions) == 6:
box_dim = self.system_dimensions
else:
raise ValueError(f"Invalid system_dimensions: {self.system_dimensions}")
u_merged.dimensions = box_dim
charmm_2_std_resname_map = {
"HIS": "HIE", # generic HIS → HIE
"HSD": "HID", # δ-protonated
"HSE": "HIE", # ε-protonated
"HIP": "HIP", # doubly protonated
}
def infer_histidine_resname(res) -> str:
"""
Infer HID/HIE/HIP from explicit hydrogens, if present.
Falls back to HIE when ambiguous or hydrogens absent.
"""
# Atom names are the most informative for histidine protonation
atom_names = {a.name.upper() for a in res.atoms}
# Common naming across force fields: HD1 on ND1, HE2 on NE2
has_hd1 = "HD1" in atom_names
has_he2 = "HE2" in atom_names
if has_hd1 and has_he2:
logger.warning(f"Found both HD1 and HE2 in residue {res.resname} {res.resid}; setting to HIP")
return "HIP"
if has_hd1:
return "HID"
if has_he2:
return "HIE"
# If hydrogens exist but aren't named HD1/HE2, we can't reliably infer
# (or hydrogens are absent entirely). Default to HIE.
return "HIE"
# replace CHARMM specific resname
for res in u_merged.residues:
# if the protein contains hydrogen and use a generic HIS name, get the correct resname based on protonation
if res.resname == "HIS":
new_name = infer_histidine_resname(res)
else:
new_name = charmm_2_std_resname_map.get(res.resname, res.resname)
res.resname = new_name
charmm_2_std_resname_map = {
("ILE", "CD"): "CD1",
}
# replace CHARMM specific atom name
for atom in u_merged.atoms:
new_name = charmm_2_std_resname_map.get((atom.resname, atom.name), atom.name)
atom.name = new_name
u_merged.atoms.write(f"{self.ligands_folder}/{self.system_name}.pdb")
protein_ref = u_prot.select_atoms("protein")
protein_ref.write(f"{self.ligands_folder}/reference.pdb")
def _align_2_system(self, mobile_atoms):
"""
Apply the stored rigid-body transform to bring a ligand into system frame.
"""
_ = align._fit_to(
mobile_coordinates=self.mobile_coord,
ref_coordinates=self.ref_coord,
mobile_atoms=mobile_atoms,
mobile_com=self.mobile_com,
ref_com=self.ref_com,
)
def _prepare_all_ligands(self):
"""
Prepare ligand ligands for the system from input ligand files (PDB/SDF/MOL2).
"""
logger.debug("prepare ligands")
new_ligand_dict: Dict[str, str] = {}
# name order is deterministic
for i, (name, ligand_path) in enumerate(sorted(self.ligand_dict.items())):
name_up = name.upper()
ligand_file = _ensure_pdb(Path(ligand_path), self.ligandff_folder)
u = mda.Universe(str(ligand_file))
try:
u.atoms.chainIDs
except AttributeError:
u.add_TopologyAttr("chainIDs")
lig_seg = u.add_Segment(segid="LIG")
u.atoms.chainIDs = "L"
u.atoms.residues.segments = lig_seg
u.atoms.residues.resnames = "lig"
logger.debug(f"Processing ligand {i}: {ligand_path}")
self._align_2_system(u.atoms)
out_ligand = f"{self.ligands_folder}/{name}.pdb"
u.atoms.write(out_ligand)
new_ligand_dict[name] = out_ligand
self.ligand_dict = new_ligand_dict
# -----------------------
# Orchestrated entry
# -----------------------
def run(
self,
*,
system_name: str,
protein_input: str,
ligand_paths: Dict[str, str],
anchor_atoms: List[str],
system_topology: str | None = None,
ligand_anchor_atom: str | None = None,
receptor_segment: str | None = None,
system_coordinate: str | None = None,
protein_align: str = "name CA and resid 60 to 250",
receptor_ff: str = "protein.ff14SB",
retain_lig_prot: bool = True,
ligand_ph: float = 7.4,
lipid_mol: List[str] = [],
lipid_ff: str = "lipid21",
unbound_threshold: float | None = None,
overwrite: bool = False,
verbose: bool = False,
) -> Dict[str, Any]:
self._system_name = system_name
self._protein_input = self._resolve_input_path(protein_input)
self._system_topology = self._resolve_input_path(system_topology) if system_topology else None
self._system_coordinate = (
self._resolve_input_path(system_coordinate)
if system_coordinate
else None
)
self.ligand_dict = {k: self._resolve_input_path(v) for k, v in ligand_paths.items()}
# prefer the provided keys for naming
self.unique_mol_names = [k.upper() for k in ligand_paths.keys()]
self.receptor_segment = receptor_segment
self.protein_align = protein_align
self.receptor_ff = receptor_ff
self.retain_lig_prot = retain_lig_prot
self.ligand_ph = ligand_ph
self.overwrite = overwrite
self.lipid_mol = lipid_mol or []
self.membrane_simulation = bool(self.lipid_mol)
self.lipid_ff = lipid_ff
# sanity checks
if not Path(self._protein_input).exists():
raise FileNotFoundError(f"Protein input file not found: {protein_input}")
for p in self.ligand_dict.values():
if not Path(p).exists():
raise FileNotFoundError(f"Ligand file not found: {p}")
if self._system_coordinate and not Path(self._system_coordinate).exists():
raise FileNotFoundError(
f"System coordinate file not found: {system_coordinate}"
)
# Directories
self.ligands_folder.mkdir(parents=True, exist_ok=True)
dssp_result = self._run_input_protein_dssp()
# Box dimensions
if self.membrane_simulation or self._system_topology is not None:
u_sys = mda.Universe(self._system_topology, format="XPDB")
if self._system_coordinate:
with open(self._system_coordinate) as f:
lines = f.readlines()
box = np.array([float(x) for x in lines[-1].split()])
self.system_dimensions = box
u_sys.load_new(self._system_coordinate, format="INPCRD")
else:
try:
self.system_dimensions = u_sys.dimensions[:3]
except TypeError:
if self.membrane_simulation:
raise ValueError(
"No box dimensions found in system_topology; required for membrane systems."
)
protein = u_sys.select_atoms("protein")
padding = 10.0
box_x = (
protein.positions[:, 0].max()
- protein.positions[:, 0].min()
+ 2 * padding
)
box_y = (
protein.positions[:, 1].max()
- protein.positions[:, 1].min()
+ 2 * padding
)
box_z = (
protein.positions[:, 2].max()
- protein.positions[:, 2].min()
+ 2 * padding
)
self.system_dimensions = np.array([box_x, box_y, box_z])
logger.warning(
"No box dimensions in system_topology. Using default 10 Å padding around protein. "
f"Box dimensions: {self.system_dimensions}"
)
u_sys.atoms.write(f"{self.ligands_folder}/system_input.pdb")
self._system_input_pdb = f"{self.ligands_folder}/system_input.pdb"
else:
self._system_input_pdb = self._protein_input
if (
self.membrane_simulation
and (u_sys.atoms.dimensions is None or not u_sys.atoms.dimensions.any())
and self._system_coordinate is None
):
raise ValueError(
"No box dimensions found in system_topology or system_coordinate when lipid system is on."
)
# membrane remapping (if any)
if self.membrane_simulation:
self._prepare_membrane()
# Align protein to system, save aligned files, compute translation
self._get_alignment()
# Build reference & docked PDBs
self._process_system()
# Make <ligand>.pdb for each ligand by translation-only
self._prepare_all_ligands()
# Anchors from first ligand + protein
u_prot = mda.Universe(f"{self.output_dir}/all-ligands/reference.pdb")
first_ligand_path = sorted(self.ligand_dict.values())[0]
u_lig = mda.Universe(first_ligand_path)
lig_sdf = str(Path(ligand_paths[self.unique_mol_names[0]]))
l1_x, l1_y, l1_z, p1, p2, p3, l1_range = find_anchor_atoms(
u_prot,
u_lig,
lig_sdf,
anchor_atoms,
ligand_anchor_atom,
unbound_threshold=unbound_threshold,
)
self.anchor_atoms = anchor_atoms
self.ligand_anchor_atom = ligand_anchor_atom
self.l1_x, self.l1_y, self.l1_z = l1_x, l1_y, l1_z
self.p1, self.p2, self.p3 = p1, p2, p3
self.l1_range = l1_range
# manifest for downstream steps
manifest = {
"system_name": self._system_name,
"reference": str(self.ligands_folder / "reference.pdb"),
"docked": str(self.ligands_folder / f"{self._system_name}.pdb"),
"ligands": dict(self.ligand_dict),
"dssp": dssp_result,
"anchors": {"p1": self.p1, "p2": self.p2, "p3": self.p3},
"l1": {
"x": self.l1_x,
"y": self.l1_y,
"z": self.l1_z,
"range": self.l1_range,
},
"membrane": (
{"lipid_mol": self.lipid_mol, "lipid_ff": self.lipid_ff}
if self.membrane_simulation
else None
),
}
(self.ligands_folder / "manifest.json").write_text(
json.dumps(manifest, indent=2)
)
logger.debug("System loaded and prepared.")
return manifest
# -----------------------
# Pipeline step handler
# -----------------------
def system_prep(step: Step, system: SimSystem, params: Dict[str, Any]) -> ExecResult:
    """Prepare a system by aligning components and generating reference structures.

    The payload's system parameters are forwarded to ``_SystemPrepRunner.run``.
    The resulting anchors, ligand position values and lipid names are written
    to ``artifacts/config/sim_overrides.json`` under ``system.root``, and that
    file is registered as the marker of the ``system_prep`` phase.

    Parameters
    ----------
    step : Step
        Pipeline metadata (unused).
    system : SimSystem
        Simulation system descriptor.
    params : dict
        Handler payload validated into :class:`StepPayload`.

    Returns
    -------
    ExecResult
        Paths to generated reference structures and a metadata dictionary with
        anchor and membrane information.
    """
    logger.info(f"[system_prep] Preparing system in {system.root}")

    payload = StepPayload.model_validate(params)
    sys_params = payload.sys_params or SystemParams()
    yaml_dir = Path(sys_params["yaml_dir"]).resolve()

    # The system parameters take precedence; the simulation config (default
    # 8.0) is only the fallback.
    fallback_threshold = getattr(payload.sim, "unbound_threshold", 8.0)
    threshold_val = sys_params.get("unbound_threshold", fallback_threshold)
    if threshold_val is None:
        unbound_threshold = None
    else:
        unbound_threshold = float(threshold_val)

    runner = _SystemPrepRunner(system, yaml_dir)
    manifest = runner.run(
        system_name=sys_params["system_name"],
        protein_input=sys_params["protein_input"],
        system_topology=sys_params.get("system_input", None),
        ligand_paths=sys_params["ligand_paths"],
        anchor_atoms=list(sys_params.get("anchor_atoms", [])),
        ligand_anchor_atom=sys_params.get("ligand_anchor_atom"),
        receptor_segment=sys_params.get("receptor_segment"),
        system_coordinate=sys_params.get("system_coordinate"),
        protein_align=sys_params.get("protein_align", "name CA and resid 60 to 250"),
        receptor_ff=sys_params.get("receptor_ff", "protein.ff14SB"),
        retain_lig_prot=bool(sys_params.get("retain_lig_prot", True)),
        ligand_ph=float(sys_params.get("ligand_ph", 7.4)),
        lipid_mol=list(sys_params.get("lipid_mol", [])),
        lipid_ff=sys_params.get("lipid_ff", "lipid21"),
        unbound_threshold=unbound_threshold,
        overwrite=bool(sys_params.get("overwrite", False)),
        verbose=bool(sys_params.get("verbose", False)),
    )

    ligands_dir = system.root / "all-ligands"
    outputs = [
        ligands_dir / "reference.pdb",
        ligands_dir / f"{sys_params['system_name']}.pdb",
    ]

    anchors = manifest["anchors"]
    ligand_position = manifest["l1"]
    membrane = manifest["membrane"]
    updates = {
        "p1": anchors["p1"],
        "p2": anchors["p2"],
        "p3": anchors["p3"],
        "l1_x": ligand_position["x"],
        "l1_y": ligand_position["y"],
        "l1_z": ligand_position["z"],
        "l1_range": ligand_position["range"],
        "lipid_mol": membrane["lipid_mol"] if membrane else [],
    }

    overrides_path = system.root / "artifacts" / "config" / "sim_overrides.json"
    overrides_path.parent.mkdir(parents=True, exist_ok=True)
    overrides_path.write_text(json.dumps(updates, indent=2))

    # The overrides file doubles as the completion marker of this phase.
    marker_rel = overrides_path.relative_to(system.root).as_posix()
    register_phase_state(
        system.root,
        "system_prep",
        required=[[marker_rel]],
        success=[[marker_rel]],
    )

    logger.info("[system_prep] System preparation complete.")
    info = {"system_prep_ok": True, **manifest, "sim_updates": updates}
    return ExecResult(outputs, info)