# Source code for batter.rbfe

"""RBFE network helpers."""

from __future__ import annotations

from pathlib import Path
import json
from dataclasses import dataclass
from typing import Callable, Iterable, Sequence, Tuple, List, Any, Mapping
from loguru import logger

from batter.config.utils import sanitize_ligand_name
from rdkit import Chem
from rdkit.Geometry import Point3D
from rdkit.Chem import rdMolAlign, AllChem


def _normalize_atom_mapper(atom_mapper: str | None) -> str:
    mapper = str(atom_mapper or "kartograf").strip().lower()
    if mapper not in {"kartograf", "lomap"}:
        raise ValueError(
            f"Unknown atom mapper '{atom_mapper}'. Available: kartograf, lomap"
        )
    return mapper


def _mapper_options_dict(options: Any | None) -> dict[str, Any]:
    if options is None:
        return {}
    if hasattr(options, "model_dump"):
        return dict(options.model_dump(exclude_none=True, exclude_unset=True))
    if isinstance(options, Mapping):
        return {str(key): value for key, value in options.items() if value is not None}
    return dict(options)


def _lomap_mapper_kwargs(options: Any | None = None) -> dict[str, Any]:
    """Build keyword arguments for ``LomapAtomMapper``.

    Module defaults (``time=20``, ``threed=True``, ``max3d=1.5``,
    ``element_change=False``, ``shift=True``) are overridden by any option
    returned from :func:`_mapper_options_dict` for ``options``.
    """
    defaults: dict[str, Any] = dict(
        time=20,
        threed=True,
        max3d=1.5,
        element_change=False,
        shift=True,
    )
    # Later entries win, so user-supplied options replace the defaults.
    return {**defaults, **_mapper_options_dict(options)}


def _kartograf_mapper_kwargs(
    options: Any | None = None,
    *,
    atom_map_hydrogens_default: bool,
) -> dict[str, Any]:
    """Build keyword arguments for ``KartografAtomMapper``.

    Two pseudo-options are consumed here instead of being forwarded:

    * ``filter_element_changes`` (default ``True``) enables
      :func:`filter_element_changes` as an extra mapping filter.
    * ``filter_mismatched_attached_h_count`` (default ``False``) enables
      :func:`filter_mismatched_attached_h_count` as an extra mapping filter.

    User values for ``atom_map_hydrogens`` and
    ``map_hydrogens_on_hydrogens_only`` are discarded.
    ``atom_map_hydrogens`` always equals ``atom_map_hydrogens_default`` and
    ``map_hydrogens_on_hydrogens_only`` is always ``True``.
    Every other option overrides the defaults below.
    ``additional_mapping_filter_functions`` is always rebuilt from the two
    pseudo-options, replacing any user-supplied value.
    """
    overrides = _mapper_options_dict(options)
    wants_element_filter = overrides.pop("filter_element_changes", True)
    wants_h_count_filter = overrides.pop("filter_mismatched_attached_h_count", False)
    # Hydrogen handling is fixed by the caller, never by user options.
    for locked in ("atom_map_hydrogens", "map_hydrogens_on_hydrogens_only"):
        overrides.pop(locked, None)

    kwargs: dict[str, Any] = {
        "atom_max_distance": 0.95,
        "map_hydrogens_on_hydrogens_only": True,
        "atom_map_hydrogens": atom_map_hydrogens_default,
        "map_exact_ring_matches_only": True,
        "allow_partial_fused_rings": True,
        "allow_bond_breaks": False,
        **overrides,
    }
    kwargs["additional_mapping_filter_functions"] = [
        filter_fn
        for filter_fn, enabled in (
            (filter_element_changes, wants_element_filter),
            (filter_mismatched_attached_h_count, wants_h_count_filter),
        )
        if enabled
    ]
    return kwargs


def _build_konnektor_atom_mapper(
    atom_mapper: str,
    *,
    hmr: bool = True,
    kartograf_options: Any | None = None,
    lomap_options: Any | None = None,
):
    """Instantiate the atom mapper used when planning Konnektor networks.

    ``atom_mapper`` is validated with :func:`_normalize_atom_mapper`.

    * ``"lomap"`` returns a ``LomapAtomMapper`` built from
      :func:`_lomap_mapper_kwargs` (``lomap_options``).
    * ``"kartograf"`` returns the network Kartograf mapper
      (``kartograf_options``).

    NOTE(review): ``hmr`` is accepted but never read here. The Kartograf
    mapper is always built with ``atom_map_hydrogens_default=False``.
    Confirm whether HMR was meant to influence the mapper.
    """
    if _normalize_atom_mapper(atom_mapper) != "lomap":
        return _build_current_kartograf_atom_mapper_for_network(
            kartograf_options=kartograf_options
        )

    # Imported lazily so lomap is only needed when it is actually selected.
    from lomap import LomapAtomMapper

    return LomapAtomMapper(**_lomap_mapper_kwargs(lomap_options))


def _build_current_kartograf_atom_mapper_for_network(
    kartograf_options: Any | None = None,
):
    """Create the ``KartografAtomMapper`` used for RBFE network generation.

    Settings come from :func:`_kartograf_mapper_kwargs` with
    ``atom_map_hydrogens_default=False``, so ``atom_map_hydrogens`` is always
    ``False``. Other ``kartograf_options`` override that helper's defaults.
    """
    from kartograf.atom_mapper import KartografAtomMapper

    mapper_kwargs = _kartograf_mapper_kwargs(
        kartograf_options, atom_map_hydrogens_default=False
    )
    return KartografAtomMapper(**mapper_kwargs)


def filter_element_changes(
    molA: Chem.Mol, molB: Chem.Mol, mapping: dict[int, int]
) -> dict[int, int]:
    """Forces a mapping to exclude any alchemical element changes in the core.

    Returns a new dict holding only the ``i -> j`` entries of ``mapping`` where
    atom ``i`` of ``molA`` and atom ``j`` of ``molB`` have the same atomic
    number. ``mapping`` is not modified and entry order is preserved.
    """

    def same_element(idx_a: int, idx_b: int) -> bool:
        return (
            molA.GetAtomWithIdx(idx_a).GetAtomicNum()
            == molB.GetAtomWithIdx(idx_b).GetAtomicNum()
        )

    return {idx_a: idx_b for idx_a, idx_b in mapping.items() if same_element(idx_a, idx_b)}
def filter_mismatched_attached_h_count(
    molA: Chem.Mol, molB: Chem.Mol, mapping: dict[int, int]
) -> dict[int, int]:
    """
    Exclude mapped heavy-atom pairs where the number of directly attached H differs.
    This helps avoid HMR mass mismatches for 'common/core' atoms.

    Every entry of ``mapping`` is checked with
    ``GetTotalNumHs(includeNeighbors=True)`` on both atoms; no heavy-atom test
    is applied. A new dict is returned in the original entry order.
    """

    def attached_h(mol: Chem.Mol, idx: int) -> int:
        return mol.GetAtomWithIdx(idx).GetTotalNumHs(includeNeighbors=True)

    return {
        idx_a: idx_b
        for idx_a, idx_b in mapping.items()
        if attached_h(molA, idx_a) == attached_h(molB, idx_b)
    }
RBFEPair = Tuple[str, str] RBFEMapFn = Callable[[Sequence[str]], Iterable[RBFEPair]] def _dedupe_pairs(pairs: Iterable[RBFEPair]) -> List[RBFEPair]: seen: set[RBFEPair] = set() out: List[RBFEPair] = [] for pair in pairs: if pair in seen: continue seen.add(pair) out.append(pair) return out def _normalize_pair(pair: Any) -> RBFEPair: if isinstance(pair, str): if "~" in pair: left, right = (p.strip() for p in pair.split("~", 1)) elif "," in pair: left, right = (p.strip() for p in pair.split(",", 1)) else: parts = [p for p in pair.split() if p] if len(parts) != 2: raise ValueError(f"RBFE mapping line must contain 2 tokens: {pair!r}") left, right = parts elif isinstance(pair, (list, tuple)) and len(pair) == 2: left, right = pair else: raise ValueError(f"RBFE mapping entries must be 2-tuples; got {pair!r}.") return (sanitize_ligand_name(str(left)), sanitize_ligand_name(str(right))) def _pairs_from_data(data: Any) -> List[RBFEPair]: if isinstance(data, dict): if "pairs" in data: raw = data["pairs"] elif "edges" in data: raw = data["edges"] else: # adjacency mapping: {LIG1: [LIG2, LIG3], ...} raw = [] for src, targets in data.items(): if not isinstance(targets, (list, tuple)): raise ValueError( "RBFE mapping dict must map ligands to list of targets." ) for tgt in targets: raw.append([src, tgt]) return [_normalize_pair(p) for p in raw] if isinstance(data, list): return [_normalize_pair(p) for p in data] raise ValueError(f"Unsupported RBFE mapping data type: {type(data).__name__}")
def load_mapping_file(path: Path) -> List[RBFEPair]:
    """
    Load RBFE mapping pairs from a file.

    Supported formats:
    - JSON/YAML: list of pairs, or dict with 'pairs'/'edges', or adjacency mapping.
    - Text: one pair per line, separated by '~', ',' or whitespace.
      Blank lines and lines starting with '#' are skipped.

    Parameters
    ----------
    path : Path or str
        Mapping file. The format is chosen from the case-insensitive suffix:
        ``.json``, ``.yaml``/``.yml``; anything else is read as text.
        The file is decoded as UTF-8.

    Returns
    -------
    list of (ref, target)
        Sanitized ligand pairs in file order.

    Raises
    ------
    FileNotFoundError
        If ``path`` does not exist.
    ValueError
        If the file yields no pairs or contains malformed entries
        (``json.JSONDecodeError`` is also a ``ValueError``).
    """
    # Accept plain string paths too; previously they failed on ``.exists()``.
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"RBFE mapping file not found: {path}")

    # Explicit UTF-8 so parsing does not depend on the platform locale.
    text = path.read_text(encoding="utf-8")
    suffix = path.suffix.lower()
    if suffix == ".json":
        pairs = _pairs_from_data(json.loads(text))
    elif suffix in {".yaml", ".yml"}:
        import yaml  # only needed for YAML mapping files

        pairs = _pairs_from_data(yaml.safe_load(text))
    else:
        pairs = [
            _normalize_pair(line)
            for line in (raw.strip() for raw in text.splitlines())
            if line and not line.startswith("#")
        ]
    if not pairs:
        raise ValueError(f"RBFE mapping file produced no pairs: {path}")
    return pairs
def resolve_mapping_fn(name: str | None) -> RBFEMapFn:
    """
    Resolve an RBFE mapping function from its configured name.

    ``None``/empty selects :meth:`RBFENetwork.default_mapping`, as do the
    aliases ``default``, ``star`` and ``first``. Alias matching is
    case-insensitive and ignores surrounding whitespace.

    Raises
    ------
    ValueError
        For ``konnektor``, which needs ligand structures and must be resolved
        by the orchestrator, and for any other unrecognised name.
    """
    default_aliases = frozenset({"default", "star", "first"})
    if not name:
        return RBFENetwork.default_mapping

    key = str(name).strip().lower()
    if key == "konnektor":
        raise ValueError(
            "RBFE mapping 'konnektor' requires ligand inputs; it must be resolved "
            "in the orchestrator when building the network."
        )
    if key not in default_aliases:
        raise ValueError(f"Unknown RBFE mapping '{name}'. Available: default, konnektor")
    return RBFENetwork.default_mapping
def _load_rdkit_mol(path: Path): from rdkit import Chem suffix = path.suffix.lower() if suffix in {".sdf", ".sd"}: supplier = Chem.SDMolSupplier(str(path), removeHs=False) mol = supplier[0] if supplier and len(supplier) > 0 else None elif suffix == ".mol2": mol = Chem.MolFromMol2File(str(path), removeHs=False) elif suffix == ".pdb": from MDAnalysis import Universe u = Universe(str(path)) mol = u.atoms.convert_to("RDKIT") else: mol = Chem.MolFromMolFile(str(path), removeHs=False) if mol is None: raise ValueError(f"Failed to load ligand from {path} with RDKit.") return mol def _resolve_konnektor_generator(layout: str | None): try: from konnektor import network_planners as gen except ImportError as exc: raise RuntimeError( "Konnektor mapping requires the 'konnektor' package to be installed." ) from exc layout_key = (layout or "star").strip().lower() candidates: dict[str, type] = {} for name in dir(gen): if not name.endswith("NetworkGenerator"): continue cls = getattr(gen, name) short = name[: -len("NetworkGenerator")].lower() candidates[short] = cls candidates[name.lower()] = cls logger.debug(f'Available Konnektor network generators: {list(candidates.keys())}') if layout_key not in candidates: raise ValueError( f"Unknown Konnektor layout '{layout_key}'. 
Available: {', '.join(candidates.keys())}" ) return candidates[layout_key] def _pairs_from_konnektor_network(network) -> List[RBFEPair]: edges = getattr(network, "edges", None) if edges is None and hasattr(network, "to_edges"): edges = network.to_edges() if edges is None: raise RuntimeError("Konnektor network did not expose edges.") pairs: List[RBFEPair] = [] for edge in edges: if isinstance(edge, (list, tuple)) and len(edge) == 2: a, b = edge elif hasattr(edge, "componentA") and hasattr(edge, "componentB"): a, b = edge.componentA, edge.componentB elif hasattr(edge, "component1") and hasattr(edge, "component2"): a, b = edge.component1, edge.component2 elif hasattr(edge, "components"): comps = list(edge.components) if len(comps) != 2: raise RuntimeError("Konnektor edge did not include two components.") a, b = comps else: raise RuntimeError("Unsupported Konnektor edge object format.") name_a = sanitize_ligand_name(getattr(a, "name", str(a))) name_b = sanitize_ligand_name(getattr(b, "name", str(b))) pairs.append((name_a, name_b)) return pairs
def konnektor_pairs(
    ligands: Sequence[str],
    ligand_files: Mapping[str, Path],
    layout: str | None = None,
    plot_path: Path | None = None,
    hmr: bool = True,
    atom_mapper: str = "kartograf",
    kartograf_options: Any | None = None,
    lomap_options: Any | None = None,
) -> List[RBFEPair]:
    """
    Build RBFE pairs using Konnektor network planners.

    Parameters
    ----------
    ligands : Sequence[str]
        Ligand names; each must have an entry in ``ligand_files``.
    ligand_files : Mapping[str, Path]
        Ligand name -> structure file, loaded with RDKit.
    layout : str, optional
        Konnektor generator name (defaults to ``"star"``). The ``explicit``
        layout is rejected.
    plot_path : Path or str, optional
        When given, the network drawing is saved here and ``network.graphml``
        is written next to it. This is best-effort: failures are logged as
        warnings and do not abort pair generation.
    hmr, atom_mapper, kartograf_options, lomap_options
        Forwarded to the atom-mapper factory.

    Returns
    -------
    list of (ref, target)
        Sanitized ligand-name pairs taken from the generated network's edges.

    Raises
    ------
    RuntimeError
        If gufe/lomap/konnektor are unavailable or the generator API is not
        recognised.
    ValueError
        For the ``explicit`` or an unknown layout, or if the network has no edges.
    KeyError
        If a ligand has no entry in ``ligand_files``.
    """
    try:
        from gufe import SmallMoleculeComponent
        from lomap.gufe_bindings.scorers import default_lomap_score
    except ImportError as exc:
        raise RuntimeError(
            "Konnektor mapping requires 'gufe' and 'lomap' dependencies."
        ) from exc

    generator_cls = _resolve_konnektor_generator(layout)
    if generator_cls.__name__.lower().startswith("explicit"):
        raise ValueError(
            "Konnektor 'explicit' layout requires explicit edges; use rbfe.mapping_file."
        )
    mapper = _build_konnektor_atom_mapper(
        atom_mapper,
        hmr=hmr,
        kartograf_options=kartograf_options,
        lomap_options=lomap_options,
    )
    generator = generator_cls(mappers=mapper, scorer=default_lomap_score)

    components: List[SmallMoleculeComponent] = []
    for lig in ligands:
        try:
            lig_file = ligand_files[lig]
        except KeyError:
            # Same exception type as before, but say which ligand is missing.
            raise KeyError(f"No ligand file provided for ligand {lig!r}.") from None
        mol = _load_rdkit_mol(Path(lig_file))
        components.append(SmallMoleculeComponent(mol, name=lig))

    # Support several Konnektor generator API generations.
    if hasattr(generator, "generate_ligand_network"):
        network = generator.generate_ligand_network(components)
    elif hasattr(generator, "generate_network"):
        network = generator.generate_network(components)
    elif callable(generator):
        network = generator(components)
    else:
        raise RuntimeError("Unsupported Konnektor generator API.")

    if plot_path is not None:
        plot_path = Path(plot_path)
        try:
            from konnektor.visualization import draw_ligand_network

            fig = draw_ligand_network(network=network, title=getattr(network, "name", None))
            plot_path.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(plot_path, dpi=200)
            (plot_path.parent / "network.graphml").write_text(network.to_graphml())
        except Exception as exc:
            # Plotting is optional: never abort pair generation because of it,
            # but do not hide the failure either.
            logger.warning(
                f"Could not save Konnektor network plot/graphml under {plot_path.parent}: {exc}"
            )

    pairs = _pairs_from_konnektor_network(network)
    if not pairs:
        raise ValueError("Konnektor mapping produced no ligand pairs.")
    return pairs
def draw_explicit_konnektor_network(
    pairs: Sequence[Sequence[str] | tuple[str, str]],
    ligand_files: Mapping[str, Path],
    plot_path: Path,
    hmr: bool = True,
    atom_mapper: str = "kartograf",
    kartograf_options: Any | None = None,
    lomap_options: Any | None = None,
) -> None:
    """Build an explicit Konnektor network from pairs and draw it.

    This is a best-effort helper. It returns without drawing when the optional
    dependencies or the atom mapper are unavailable, or when no pair has files
    for both ligands. Drawing and graphml failures are logged, not raised.

    With the Kartograf mapper, each target ligand is shape-aligned onto its
    reference before the edge is created. If alignment fails, the input
    coordinates are used.

    The drawing is saved to ``plot_path`` and ``network.graphml`` is written
    next to it. Errors raised while loading ligand files still propagate.
    """
    mapper_name = _normalize_atom_mapper(atom_mapper)
    try:
        from konnektor.network_planners import ExplicitNetworkGenerator
        from konnektor.visualization import draw_ligand_network
        from gufe import SmallMoleculeComponent
        from lomap.gufe_bindings.scorers import default_lomap_score

        align_mol_shape = None
        if mapper_name == "kartograf":
            from kartograf.atom_aligner import align_mol_shape as _align_mol_shape

            align_mol_shape = _align_mol_shape
    except Exception as exc:
        logger.warning(
            f"Skipping explicit Konnektor network drawing; dependencies unavailable: {exc}"
        )
        return
    try:
        mapper = _build_konnektor_atom_mapper(
            mapper_name,
            hmr=hmr,
            kartograf_options=kartograf_options,
            lomap_options=lomap_options,
        )
    except Exception as exc:
        logger.warning(
            f"Skipping explicit Konnektor network drawing; could not build atom mapper: {exc}"
        )
        return

    comp_by_name: dict[str, SmallMoleculeComponent] = {}
    edges = []
    nodes_by_name: dict[str, SmallMoleculeComponent] = {}
    for ref, alt in pairs:
        name_a = str(ref)
        name_b = str(alt)
        # Pairs whose ligands have no structure file cannot be drawn.
        if name_a not in ligand_files or name_b not in ligand_files:
            continue
        if name_a not in comp_by_name:
            mol_a = _load_rdkit_mol(Path(ligand_files[name_a]))
            comp_by_name[name_a] = SmallMoleculeComponent(mol_a, name=name_a)
        if name_b not in comp_by_name:
            mol_b = _load_rdkit_mol(Path(ligand_files[name_b]))
            comp_by_name[name_b] = SmallMoleculeComponent(mol_b, name=name_b)
        comp_a = comp_by_name[name_a]
        comp_b = comp_by_name[name_b]
        if align_mol_shape is not None:
            try:
                comp_b = align_mol_shape(comp_b, ref_mol=comp_a)
            except Exception as exc:
                logger.debug(
                    f"Shape alignment of {name_b} onto {name_a} failed; "
                    f"using input coordinates: {exc}"
                )
        # NOTE(review): the aligned comp_b is used only for this edge. comp_by_name
        # keeps the unaligned component, so a ligand appearing in several edges
        # may be represented by differently positioned component objects.
        # Confirm this still renders as a single node.
        edges.append((comp_a, comp_b))
        nodes_by_name.setdefault(name_a, comp_a)
        nodes_by_name.setdefault(name_b, comp_b)

    if not edges:
        logger.debug("No drawable pairs for the explicit Konnektor network; nothing drawn.")
        return

    nodes = list(nodes_by_name.values())
    generator = ExplicitNetworkGenerator(mappers=mapper, scorer=default_lomap_score)
    plot_path = Path(plot_path)
    try:
        network = generator.generate_ligand_network(edges=edges, nodes=nodes)
        fig = draw_ligand_network(network=network, title=getattr(network, "name", None))
        plot_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(plot_path, dpi=200)
        (plot_path.parent / "network.graphml").write_text(network.to_graphml())
    except Exception as exc:
        logger.warning(f"Could not draw explicit Konnektor network to {plot_path}: {exc}")
@dataclass(frozen=True)
class RBFENetwork:
    """
    Immutable record of an RBFE simulation network as directed ligand pairs.

    Parameters
    ----------
    ligands : Tuple[str, ...]
        Ordered ligand identifiers participating in the network.
    pairs : Tuple[tuple[str, str], ...]
        Directed pairs describing simulations to run (reference, target).
    """

    # Ordered ligand identifiers; sanitized when built through ``from_ligands``.
    ligands: Tuple[str, ...]
    # Directed (reference, target) pairs; de-duplicated by ``from_ligands``.
    pairs: Tuple[RBFEPair, ...]

    @staticmethod
    def default_mapping(ligands: Sequence[str]) -> List[RBFEPair]:
        """
        Star mapping: pair the first ligand with every later ligand.

        Returns an empty list when fewer than two ligands are supplied.
        """
        if len(ligands) < 2:
            return []
        root, *others = ligands
        return [(root, other) for other in others]

    @classmethod
    def from_ligands(
        cls,
        ligands: Sequence[str],
        mapping_fn: RBFEMapFn | None = None,
    ) -> "RBFENetwork":
        """
        Create a validated network from ligand identifiers and a mapping function.

        Parameters
        ----------
        ligands : Sequence[str]
            Ordered ligand identifiers. Each is passed through
            ``sanitize_ligand_name(str(...))`` and must be unique afterwards.
        mapping_fn : callable, optional
            Receives the sanitized ligand list and returns ``(ref, target)``
            pairs. Defaults to :meth:`default_mapping`.

        Raises
        ------
        ValueError
            If fewer than two ligands are given, names collide after
            sanitization, the mapping yields nothing, or any pair is malformed,
            references an unknown ligand, or is a self-pair.
        """
        if not ligands:
            raise ValueError("RBFE network requires at least two ligands.")
        names = [sanitize_ligand_name(str(item)) for item in ligands]
        if len(names) < 2:
            raise ValueError("RBFE network requires at least two ligands.")
        if len(set(names)) != len(names):
            raise ValueError("RBFE network ligand identifiers must be unique.")

        mapper = mapping_fn or cls.default_mapping
        proposed = list(mapper(names))
        if not proposed:
            raise ValueError("RBFE mapping function returned no ligand pairs.")

        # Built after the mapper runs, which sees (and could alter) ``names``.
        known = set(names)
        validated: List[RBFEPair] = []
        for entry in proposed:
            is_pair = isinstance(entry, (list, tuple)) and len(entry) == 2
            if not is_pair:
                raise ValueError(f"RBFE mapping entries must be 2-tuples; got {entry!r}.")
            ref = str(entry[0])
            tgt = str(entry[1])
            if not (ref in known and tgt in known):
                raise ValueError(
                    f"RBFE mapping contains unknown ligand(s): {(ref, tgt)!r}."
                )
            if ref == tgt:
                raise ValueError("RBFE mapping cannot include self-pairs.")
            validated.append((ref, tgt))

        return cls(ligands=tuple(names), pairs=tuple(_dedupe_pairs(validated)))

    def to_mapping(self) -> dict:
        """
        Return a JSON-serializable ``{"ligands": [...], "pairs": [[ref, tgt], ...]}`` dict.
        """
        ligand_list = [*self.ligands]
        pair_lists = list(map(list, self.pairs))
        return {"ligands": ligand_list, "pairs": pair_lists}