Source code for batter.exec.handlers.param_ligands

"""Parameterise ligands and populate per-ligand artifacts."""

from __future__ import annotations

import json
import os
import shutil
from pathlib import Path
from typing import Any, Dict, List, Tuple

from loguru import logger

from batter.orchestrate.state_registry import register_phase_state
from batter.param.ligand import (
    _convert_mol_name_to_unique,
    _hash_id,
    _rdkit_load,
    _canonical_payload,
    batch_ligand_process,
)
from batter.pipeline.payloads import StepPayload, SystemParams
from batter.pipeline.step import ExecResult, Step
from batter.systems.core import SimSystem

LIGAND_FILES = ["mol2", "prmtop", "sdf", "json", "frcmod", "inpcrd", "lib"]

[docs] def copy_ligand_params(src_dir: Path, child_dir: Path, residue_name: str) -> None: """Copy ``lig.*`` artifacts into ``child_dir/params`` using ``residue_name``.""" child_params = child_dir / "params" child_params.mkdir(parents=True, exist_ok=True) for ext in LIGAND_FILES: src = src_dir / f"lig.{ext}" if not src.exists(): continue # skip missing files gracefully dst = child_params / f"{residue_name}.{ext}" if dst.exists(): continue try: shutil.copy2(src, dst) logger.debug(f"Copied {src.name}{dst}") except Exception as e: logger.warning(f"Failed to copy {src} to {dst}: {e}")
def _resolve_outdir(template: str | Path, system: SimSystem) -> Path: """Resolve ``{WORK}`` placeholders against ``system.root``.""" if not isinstance(template, (str, Path)): raise TypeError("param_ligands.outdir must be a string") resolved = str(template).replace("{WORK}", system.root.as_posix()) return Path(resolved).expanduser().resolve() def _require(sys_params: SystemParams, key: str) -> Any: try: return sys_params[key] except KeyError as exc: raise KeyError(f"[param_ligands] Missing required sys_params[{key!r}]") from exc
[docs] def param_ligands(step: Step, system: SimSystem, params: Dict[str, Any]) -> ExecResult: """Run the ligand parametrisation pipeline and index results. Parameters ---------- step : Step Pipeline metadata (unused). system : SimSystem Simulation system descriptor. params : dict Handler payload validated into :class:`StepPayload`. Returns ------- ExecResult Mapping containing the parameter store path, JSON index, manifest, and raw hashes. """ payload = StepPayload.model_validate(params) sys_params = payload.sys_params or SystemParams() lig_root = system.root / "simulations" if not lig_root.exists(): raise FileNotFoundError(f"[param_ligands] No 'ligands/' at {system.root}. Did staging run?") outdir = _resolve_outdir(sys_params["param_outdir"], system) charge = _require(sys_params, "charge") ligand_ff = _require(sys_params, "ligand_ff") retain = bool(_require(sys_params, "retain_lig_prot")) lig_map = sys_params["ligand_paths"] outdir.mkdir(parents=True, exist_ok=True) logger.info(f"[param_ligands] {len(lig_map)} ligands") logger.info( f"[param_ligands] parameterizing" f"(charge={charge}, ff={ligand_ff}, retain H={retain})" ) artifacts_index_dir = system.root / "artifacts" / "ligand_params" artifacts_index_dir.mkdir(parents=True, exist_ok=True) index_path = artifacts_index_dir / "index.json" mode_lower = "" if isinstance(params, dict): mode_lower = str(params.get("on_failure") or "").lower() elif hasattr(payload, "model_extra") and payload.model_extra: mode_lower = str(payload.model_extra.get("on_failure") or "").lower() try: # Run batch parametrization into content-addressed subfolders # Returns (hash_ids_in_order, residue_names_in_order) hashes, unique = batch_ligand_process( ligand_paths=lig_map, output_path=outdir, retain_lig_prot=retain, ligand_ff=ligand_ff, charge_method=charge, overwrite=False, run_with_slurm=False, on_failure=mode_lower, ) if not hashes: raise RuntimeError("[param_ligands] No ligands processed (empty hash list).") except Exception as exc: # allow reuse of an existing index when present if index_path.exists(): logger.error( f"[param_ligands] encountered error but index exists; reusing cached ligands. Error: {exc}", ) existing_index = json.loads(index_path.read_text()) index_entries = existing_index.get("ligands", []) # write a manifest to keep downstream in sync manifest = artifacts_index_dir / "ligand_manifest.tsv" with manifest.open("w") as mf: for entry in index_entries: mf.write( f"{entry.get('ligand')}\t{entry.get('hash')}\t{entry.get('residue_name')}\n" ) marker_rel = index_path.relative_to(system.root).as_posix() register_phase_state( system.root, "param_ligands", required=[[marker_rel]], success=[[marker_rel]], ) return ExecResult( [], { "param_store": existing_index.get("store", str(outdir)), "index_json": str(index_path), "manifest_tsv": str(manifest), "hashes": [e.get("hash") for e in index_entries], }, ) # Attempt to salvage cached ligands: use existing param store entries only salvaged_hashes: List[str] = [] unique = {} for name, path in lig_map.items(): try: mol = _rdkit_load(path, retain_h=retain) smi = _canonical_payload(mol) hid = _hash_id(smi, ligand_ff=ligand_ff, retain_h=retain) cache_dir = outdir / hid if (cache_dir / "lig.prmtop").exists(): unique[str(path)] = (hid, smi) salvaged_hashes.append(hid) except Exception: continue if salvaged_hashes: logger.error( f"[param_ligands] encountered error; salvaged {len(salvaged_hashes)} cached ligands and will skip failures.", ) hashes = salvaged_hashes else: logger.error( f"[param_ligands] encountered error and no cached ligands could be salvaged: {exc}", ) raise # generate unique list of resnames only for ligands we have data for unique_resnames: Dict[str, str] = {} seen_resnames: set[str] = set() for i, (name, p) in enumerate(lig_map.items()): smiles_val = unique.get(str(p), (None, None))[1] if smiles_val is None: continue init_mol_name = name.lower() unique_resname = _convert_mol_name_to_unique( mol_name=init_mol_name, ind=i, smiles=smiles_val, exist_mol_names=seen_resnames, ) seen_resnames.add(unique_resname) unique_resnames[name] = unique_resname # Link artifacts per staged ligand and collect index rows index_entries: List[Dict[str, Any]] = [] linked: List[Tuple[str, str]] = [] # (name, hash) for name, d in lig_map.items(): if name not in unique_resnames: logger.warning(f"[param_ligands] Skipping ligand {name} due to parametrization failure.") continue hid = unique.get(str(d), (None, None))[0] if hid is None: logger.warning(f"[param_ligands] Missing hash for ligand {name}; skipping.") continue src_dir = outdir / hid meta_path = src_dir / "metadata.json" if not src_dir.exists() or not meta_path.exists(): logger.warning( f"[param_ligands] Missing params for staged ligand {name} at {src_dir}; skipping.", ) continue meta = json.loads(meta_path.read_text()) residue_name = unique_resnames.get(name) if residue_name is None: logger.warning(f"[param_ligands] Missing residue name for {name}; skipping.") continue title = meta.get("title", name) copy_ligand_params(src_dir, lig_root / Path(name), residue_name) linked.append((name, hid, residue_name)) index_entries.append( { "ligand": name, "hash": hid, "store_dir": str(src_dir), "linked_dir": str(lig_root / Path(name) / "params"), "residue_name": residue_name, "title": title, } ) # Save a machine-readable index for downstream steps index_payload = { "store": str(outdir), "ligands": index_entries, "config": { "ligand_ff": ligand_ff, "charge": charge, "retain_lig_prot": retain, }, } index_path = artifacts_index_dir / "index.json" index_path.write_text(json.dumps(index_payload, indent=2)) marker_rel = index_path.relative_to(system.root).as_posix() register_phase_state( system.root, "param_ligands", required=[[marker_rel]], success=[[marker_rel]], ) # also save a simple TSV manifest (name\t hash) manifest = artifacts_index_dir / "ligand_manifest.tsv" with manifest.open("w") as mf: for name, lh, rn in linked: mf.write(f"{name}\t{lh}\t{rn}\n") logger.debug(f"[param_ligands] Linked params for staged ligands: {linked}") logger.debug(f"[param_ligands] Wrote index → {index_path}") # Return rich metadata so downstream steps can consume without re-reading disk, if desired return ExecResult( [], { "param_store": str(outdir), "index_json": str(artifacts_index_dir / "index.json"), "manifest_tsv": str(manifest), "hashes": hashes, }, )