Source code for rush.mol

"""
Provides data structures and helpers for molecular systems and structures:

- Classes for Rush Topology, Residues, Chains, and TRC types.
- Element types and bonds.
- Fragment type to represent fragmented systems.

Quick Links
-----------

- :class:`rush.mol.TRC`
- :class:`rush.mol.Topology`
- :class:`rush.mol.Residues`
- :class:`rush.mol.Chains`
"""

import json
import sys
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum, IntEnum
from pathlib import Path
from typing import Self


[docs] class Element(IntEnum): """Represents all relevant elements.""" X = 0 H = 1 He = 2 Li = 3 Be = 4 B = 5 C = 6 N = 7 O = 8 # noqa: E741 F = 9 Ne = 10 Na = 11 Mg = 12 Al = 13 Si = 14 P = 15 S = 16 Cl = 17 Ar = 18 K = 19 Ca = 20 Sc = 21 Ti = 22 V = 23 Cr = 24 Mn = 25 Fe = 26 Co = 27 Ni = 28 Cu = 29 Zn = 30 Ga = 31 Ge = 32 As = 33 Se = 34 Br = 35 Kr = 36
[docs] @classmethod def from_str(cls, symbol: str) -> Self: """Parse element from string symbol.""" # First try the symbol as-is (for proper case like "Fe") try: return cls[symbol] except KeyError: pass # Try uppercase (for "FE" -> "Fe") symbol_upper = symbol.upper() try: # Check all enum members for case-insensitive match for member in cls: if member.name.upper() == symbol_upper: return member except Exception: pass # Try common variations if symbol_upper in ["D"]: # Deuterium -> Hydrogen return cls.H raise ValueError(f"Unknown element symbol: {symbol}")
def __str__(self) -> str: return self.name
[docs] class AtomRef: """Reference to an atom by index.""" def __init__(self, value: int): if value < 0: raise ValueError("Atom index must be non-negative") self.value = value def __eq__(self, other): return isinstance(other, AtomRef) and self.value == other.value def __hash__(self): return hash(self.value) def __repr__(self): return f"AtomRef({self.value})" def __int__(self): return self.value
[docs] class FragmentRef: """Reference to a fragment by index.""" def __init__(self, value: int): if value < 0: raise ValueError("Fragment index must be non-negative") self.value = value def __eq__(self, other): return isinstance(other, FragmentRef) and self.value == other.value def __hash__(self): return hash(self.value) def __repr__(self): return f"FragmentRef({self.value})" def __int__(self): return self.value
[docs] class ResidueRef: """Reference to a residue by index.""" def __init__(self, value: int): if value < 0: raise ValueError("Residue index must be non-negative") self.value = value def __eq__(self, other): return isinstance(other, ResidueRef) and self.value == other.value def __hash__(self): return hash(self.value) def __repr__(self): return f"ResidueRef({self.value})" def __int__(self): return self.value
[docs] class ChainRef: """Reference to a chain by index.""" def __init__(self, value: int): if value < 0: raise ValueError("Chain index must be non-negative") self.value = value def __eq__(self, other): return isinstance(other, ChainRef) and self.value == other.value def __hash__(self): return hash(self.value) def __repr__(self): return f"ChainRef({self.value})" def __int__(self): return self.value
[docs] @dataclass class FormalCharge: """Formal charge of an atom.""" charge: int def __repr__(self): return f"FormalCharge({self.charge})" def __int__(self): return self.charge
[docs] @dataclass class PartialCharge: """Partial charge of an atom.""" charge: float def __repr__(self): return f"PartialCharge({self.charge})" def __float__(self): return self.charge
[docs] class BondOrder(IntEnum): """Bond order enum.""" Single = 1 Double = 2 Triple = 3 OneAndAHalf = 4 # Partial bond (e.g. amide bond) Ring = 5 # Aromatic
[docs] @dataclass class Bond: """Bond between two atoms.""" atom1: AtomRef atom2: AtomRef order: BondOrder def __post_init__(self): if self.atom1.value == self.atom2.value: raise ValueError("Bond cannot connect an atom to itself")
[docs] class Fragment: """Fragment containing a list of atoms.""" def __init__(self, atoms: list[AtomRef] | list[int] | None = None): # Store as list of integers to match JSON serialization if atoms is None: self.atoms = [] else: self.atoms = [ atom.value if isinstance(atom, AtomRef) else atom for atom in atoms ] def __len__(self) -> int: return len(self.atoms) def __iter__(self): # Return AtomRef objects when iterating return (AtomRef(atom) for atom in self.atoms) def __eq__(self, other): return isinstance(other, Fragment) and self.atoms == other.atoms def __repr__(self): return f"Fragment({[AtomRef(a) for a in self.atoms]})"
[docs] class SchemaVersion(Enum): """Schema version for the topology format.""" V1 = "v1" V2 = "v2"
[docs] @dataclass class Topology: """Topology contains all atom information.""" schema_version: SchemaVersion = SchemaVersion.V2 # Element of each atom symbols: list[Element] = field(default_factory=list) # XYZ coordinates of each atom (3 * len(symbols)) geometry: list[float] = field(default_factory=list) # Optional atom labels labels: list[str] | None = None # Optional formal charges formal_charges: list[FormalCharge] | None = None # Optional partial charges partial_charges: list[PartialCharge] | None = None # Optional connectivity connectivity: list[Bond] | None = None # Optional velocities (3 * len(symbols)) velocities: list[float] | None = None # Optional fragments fragments: list[Fragment] | None = None # Optional fragment charges fragment_formal_charges: list[FormalCharge] | None = None fragment_partial_charges: list[PartialCharge] | None = None
[docs] @staticmethod def from_json(json_content: str | Path | dict) -> "Topology": if isinstance(json_content, str): topology_data = json.loads(json_content) elif isinstance(json_content, Path): with json_content.open() as f: topology_data = json.load(f) elif isinstance(json_content, dict): topology_data = json_content else: print( "WARNING: Tried to load Topology from JSON but " "it wasn't a str, Path, or dict!" ) topology_data = json_content topology = Topology() # Default, could parse from schema_version topology.schema_version = SchemaVersion.V2 topology.symbols = [Element.from_str(s) for s in topology_data["symbols"]] topology.geometry = topology_data["geometry"] if "labels" in topology_data and topology_data["labels"]: topology.labels = topology_data["labels"] if "formal_charges" in topology_data and topology_data["formal_charges"]: topology.formal_charges = [ FormalCharge(c) for c in topology_data["formal_charges"] ] if "partial_charges" in topology_data and topology_data["partial_charges"]: topology.partial_charges = [ PartialCharge(c) for c in topology_data["partial_charges"] ] if "velocities" in topology_data and topology_data["velocities"]: topology.velocities = topology_data["velocities"] if "connectivity" in topology_data and topology_data["connectivity"]: # Connectivity is a list of [atom1, atom2, bond_order] # BondOrder enum: 1=Single, 2=Double, 3=Triple, 4=OneAndAHalf (partial), 5=Ring (aromatic) bonds = [] for bond_data in topology_data["connectivity"]: if isinstance(bond_data, list) and len(bond_data) >= 2: atom1_idx = bond_data[0] atom2_idx = bond_data[1] bond_order_val = bond_data[2] # Support old version mapping: 254 -> 4 (OneAndAHalf/partial), 255 -> 5 (Ring/aromatic) if bond_order_val == 254: bond_order_val = 4 elif bond_order_val == 255: bond_order_val = 5 bond_order = BondOrder(bond_order_val) bonds.append( Bond(AtomRef(atom1_idx), AtomRef(atom2_idx), bond_order) ) topology.connectivity = bonds if "fragments" in topology_data and topology_data["fragments"]: topology.fragments = [Fragment(frag) for frag in topology_data["fragments"]] if ( "fragment_formal_charges" in topology_data and topology_data["fragment_formal_charges"] ): topology.fragment_formal_charges = [ FormalCharge(c) for c in topology_data["fragment_formal_charges"] ] if ( "fragment_partial_charges" in topology_data and topology_data["fragment_partial_charges"] ): topology.fragment_partial_charges = [ PartialCharge(c) for c in topology_data["fragment_partial_charges"] ] return topology
[docs] def to_json(self) -> dict[str, object]: topology_dict: dict[str, object] = { "schema_version": "0.2.0", "symbols": [str(symbol) for symbol in self.symbols], "geometry": self.geometry, } if self.labels is not None: topology_dict["labels"] = self.labels if self.formal_charges is not None: topology_dict["formal_charges"] = [c.charge for c in self.formal_charges] if self.partial_charges is not None: topology_dict["partial_charges"] = [c.charge for c in self.partial_charges] if self.connectivity is not None: topology_dict["connectivity"] = [ [bond.atom1.value, bond.atom2.value, bond.order.value] for bond in self.connectivity ] else: topology_dict["connectivity"] = [] if self.velocities is not None: topology_dict["velocities"] = self.velocities if self.fragments is not None: topology_dict["fragments"] = [fragment.atoms for fragment in self.fragments] else: topology_dict["fragments"] = [] if self.fragment_formal_charges is not None: topology_dict["fragment_formal_charges"] = [ c.charge for c in self.fragment_formal_charges ] else: topology_dict["fragment_formal_charges"] = [] if self.fragment_partial_charges is not None: topology_dict["fragment_partial_charges"] = [ c.charge for c in self.fragment_partial_charges ] return topology_dict
[docs] def check(self) -> None: """Validate the topology structure.""" # Check geometry length if len(self.geometry) != len(self.symbols) * 3: raise ValueError( f"Geometry length {len(self.geometry)} != symbols length {len(self.symbols)} * 3" ) # Check optional field lengths if self.labels is not None and len(self.labels) != len(self.symbols): raise ValueError( f"Labels length {len(self.labels)} != symbols length {len(self.symbols)}" ) if self.partial_charges is not None and len(self.partial_charges) != len( self.symbols ): raise ValueError( f"Partial charges length {len(self.partial_charges)} != symbols length {len(self.symbols)}" ) if self.formal_charges is not None and len(self.formal_charges) != len( self.symbols ): raise ValueError( f"Formal charges length {len(self.formal_charges)} != symbols length {len(self.symbols)}" ) if ( self.velocities is not None and len(self.velocities) != len(self.symbols) * 3 ): raise ValueError( f"Velocities length {len(self.velocities)} != symbols length {len(self.symbols)} * 3" ) # Check connectivity if self.connectivity is not None: for bond in self.connectivity: if bond.atom1.value >= len(self.symbols) or bond.atom2.value >= len( self.symbols ): raise ValueError( f"Bond references invalid atom indices: {bond.atom1.value}, {bond.atom2.value}" ) # Check fragments if self.fragments is not None: atom_set = set() for fragment in self.fragments: for atom_idx in fragment.atoms: if atom_idx >= len(self.symbols): raise ValueError( f"Fragment references invalid atom index: {atom_idx}" ) if atom_idx in atom_set: raise ValueError( f"Atom {atom_idx} appears in multiple fragments" ) atom_set.add(atom_idx) if len(atom_set) != len(self.symbols): raise ValueError("Not all atoms are assigned to fragments")
[docs] def distance_between_atoms(self, atom1: AtomRef, atom2: AtomRef) -> float: """Calculate distance between two atoms.""" if atom1.value >= len(self.symbols) or atom2.value >= len(self.symbols): raise ValueError("Invalid atom indices") i1, i2 = atom1.value * 3, atom2.value * 3 dx = self.geometry[i1] - self.geometry[i2] dy = self.geometry[i1 + 1] - self.geometry[i2 + 1] dz = self.geometry[i1 + 2] - self.geometry[i2 + 2] return (dx * dx + dy * dy + dz * dz) ** 0.5
[docs] def distance_to_point( self, atom: AtomRef, point: tuple[float, float, float] ) -> float: """Calculate distance from atom to a point.""" if atom.value >= len(self.symbols): raise ValueError("Invalid atom index") i = atom.value * 3 dx = self.geometry[i] - point[0] dy = self.geometry[i + 1] - point[1] dz = self.geometry[i + 2] - point[2] return (dx * dx + dy * dy + dz * dz) ** 0.5
[docs] def get_atoms_near_point( self, point: tuple[float, float, float], threshold: float, atom_indices: list[int] | None = None, ) -> list[int]: """Get atom indices within threshold distance of a point.""" if atom_indices is None: atom_indices = list(range(len(self.symbols))) near_atoms = [] for atom_idx in atom_indices: if atom_idx >= len(self.symbols): continue distance = self.distance_to_point(AtomRef(atom_idx), point) if distance <= threshold: near_atoms.append(atom_idx) return near_atoms
[docs] def get_fragments_near_fragment( self, frag_idx: int, threshold: float, atom_indices: list[int] | None = None, ) -> list[FragmentRef]: """Get fragment indices within threshold distance of another fragment.""" if not self.fragments: return [] if atom_indices is None: atom_indices = list(range(len(self.symbols))) near_atoms = set() for atom_idx in self.fragments[frag_idx]: atom_idx = int(atom_idx) if atom_idx >= len(self.symbols): print("Warning: bad atom index {atom_index}", file=sys.stderr) continue near_atoms |= { AtomRef(a) for a in self.get_atoms_near_point( ( self.geometry[atom_idx * 3], self.geometry[atom_idx * 3 + 1], self.geometry[atom_idx * 3 + 2], ), threshold, ) } return [ FragmentRef(i) for (i, f) in enumerate(self.fragments) if (i != frag_idx and not near_atoms.isdisjoint(f)) ]
[docs] def extend(self, other: Self) -> None: """Extend this topology with atoms from another topology.""" offset = len(self.symbols) # Extend basic arrays self.symbols.extend(other.symbols) self.geometry.extend(other.geometry) # Extend optional arrays if self.labels is not None and other.labels is not None: self.labels.extend(other.labels) elif self.labels is not None and other.labels is None: self.labels.extend([""] * len(other.symbols)) if self.partial_charges is not None and other.partial_charges is not None: self.partial_charges.extend(other.partial_charges) elif self.partial_charges is not None and other.partial_charges is None: self.partial_charges.extend([PartialCharge(0.0)] * len(other.symbols)) if self.formal_charges is not None and other.formal_charges is not None: self.formal_charges.extend(other.formal_charges) elif self.formal_charges is not None and other.formal_charges is None: self.formal_charges.extend([FormalCharge(0)] * len(other.symbols)) if self.velocities is not None and other.velocities is not None: self.velocities.extend(other.velocities) elif self.velocities is not None and other.velocities is None: self.velocities.extend([0.0] * (len(other.symbols) * 3)) # Update connectivity with offset if other.connectivity is not None: if self.connectivity is None: self.connectivity = [] for bond in other.connectivity: new_bond = Bond( AtomRef(bond.atom1.value + offset), AtomRef(bond.atom2.value + offset), bond.order, ) self.connectivity.append(new_bond) # Update fragments with offset if self.fragments is not None and other.fragments is not None: for fragment in other.fragments: new_atoms = [AtomRef(atom + offset) for atom in fragment.atoms] self.fragments.append(Fragment(new_atoms)) elif self.fragments is not None and other.fragments is None: # Create a single fragment for all new atoms new_atoms = [AtomRef(i + offset) for i in range(len(other.symbols))] self.fragments.append(Fragment(new_atoms)) # Extend fragment charges if other.fragment_formal_charges is not None: if self.fragment_formal_charges is None: self.fragment_formal_charges = [] self.fragment_formal_charges.extend(other.fragment_formal_charges) if other.fragment_partial_charges is not None: if self.fragment_partial_charges is None: self.fragment_partial_charges = [] self.fragment_partial_charges.extend(other.fragment_partial_charges)
[docs] def new_topology_from_residue_subset( self, residue_subset: list["Residue"] ) -> "Topology": """Create a new topology containing only atoms from specified residues.""" new_topology = Topology(schema_version=self.schema_version) # Collect all atom indices from residues atom_indices = [] for residue in residue_subset: atom_indices.extend(residue.atoms) # Already integers # Build atom mapping old_to_new = {old_idx: new_idx for new_idx, old_idx in enumerate(atom_indices)} # Copy basic data new_topology.symbols = [self.symbols[i] for i in atom_indices] new_topology.geometry = [] for i in atom_indices: new_topology.geometry.extend(self.geometry[i * 3 : (i + 1) * 3]) # Copy optional data if self.labels: new_topology.labels = [self.labels[i] for i in atom_indices] if self.partial_charges: new_topology.partial_charges = [ self.partial_charges[i] for i in atom_indices ] if self.formal_charges: new_topology.formal_charges = [self.formal_charges[i] for i in atom_indices] if self.velocities: new_topology.velocities = [] for i in atom_indices: new_topology.velocities.extend(self.velocities[i * 3 : (i + 1) * 3]) # Copy connectivity (only bonds between atoms in subset) if self.connectivity: new_topology.connectivity = [] for bond in self.connectivity: if bond.atom1.value in old_to_new and bond.atom2.value in old_to_new: new_bond = Bond( AtomRef(old_to_new[bond.atom1.value]), AtomRef(old_to_new[bond.atom2.value]), bond.order, ) new_topology.connectivity.append(new_bond) return new_topology
[docs] class AminoAcidSeq(Enum): """Amino acid sequence names.""" GLY = "GLY" ALA = "ALA" VAL = "VAL" LEU = "LEU" ILE = "ILE" PRO = "PRO" SER = "SER" THR = "THR" ASN = "ASN" GLN = "GLN" CYS = "CYS" CYD = "CYD" CYX = "CYX" MET = "MET" PHE = "PHE" TYR = "TYR" TYD = "TYD" TRP = "TRP" ASP = "ASP" ASH = "ASH" GLU = "GLU" GLH = "GLH" HIS = "HIS" HIN = "HIN" HID = "HID" HIE = "HIE" HIP = "HIP" LYS = "LYS" LYD = "LYD" LYN = "LYN" ARG = "ARG" HYP = "HYP" ACE = "ACE" BNC = "BNC" NME = "NME" NMA = "NMA" NHH = "NHH" UNK = "UNK"
[docs] @classmethod def is_amino_acid(cls, residue_name: str) -> bool: """Check if a residue name is a known amino acid.""" try: cls(residue_name.upper()) return True except ValueError: return False
[docs] def to_single_letter(self) -> str: """Convert to single letter code.""" mapping = { "GLY": "G", "ALA": "A", "VAL": "V", "LEU": "L", "ILE": "I", "PRO": "P", "SER": "S", "THR": "T", "ASN": "N", "GLN": "Q", "CYS": "C", "CYD": "C", "CYX": "C", "MET": "M", "PHE": "F", "TYR": "Y", "TYD": "Y", "TRP": "W", "ASP": "D", "ASH": "D", "GLU": "E", "GLH": "E", "HIS": "H", "HIN": "H", "HID": "H", "HIE": "H", "HIP": "H", "LYS": "K", "LYD": "K", "LYN": "K", "ARG": "R", "HYP": "O", } return mapping.get(self.value, "X")
[docs] class Residue: """A residue containing a list of atoms.""" def __init__(self, atoms: list[AtomRef] | list[int] | None = None): # Store as list of integers to match JSON serialization if atoms is None: self.atoms = [] else: self.atoms = [ atom.value if isinstance(atom, AtomRef) else atom for atom in atoms ] def __len__(self) -> int: return len(self.atoms) def __iter__(self): # Return AtomRef objects when iterating return (AtomRef(atom) for atom in self.atoms)
[docs] def contains(self, atom: AtomRef) -> bool: return atom.value in self.atoms
def __eq__(self, other): return isinstance(other, Residue) and self.atoms == other.atoms def __repr__(self): return f"Residue({[AtomRef(a) for a in self.atoms]})"
[docs] @dataclass class Residues: """Collection of residues with metadata.""" # List of residues residues: list[Residue] = field(default_factory=list) # Sequence names (e.g., amino acid names) seqs: list[str] = field(default_factory=list) # Sequence numbers seq_ns: list[int] = field(default_factory=list) # Insertion codes insertion_codes: list[str] = field(default_factory=list) # WARN: Deprecated labeled: list[ResidueRef] | None = None # WARN: Deprecated labels: list[list[str]] | None = None
[docs] @staticmethod def from_json(json_content: str | Path | dict) -> "Residues": if isinstance(json_content, str): residues_data = json.loads(json_content) elif isinstance(json_content, Path): with json_content.open() as f: residues_data = json.load(f) elif isinstance(json_content, dict): residues_data = json_content else: print( "WARNING: Tried to load Residues from JSON but " "it wasn't a str, Path, or dict!" ) residues_data = json_content residues = Residues() residues.residues = [Residue(res) for res in residues_data["residues"]] residues.seqs = residues_data["seqs"] residues.seq_ns = residues_data["seq_ns"] residues.insertion_codes = residues_data["insertion_codes"] if residues_data.get("labeled"): residues.labeled = [ResidueRef(r) for r in residues_data["labeled"]] if residues_data.get("labels"): residues.labels = residues_data["labels"] return residues
[docs] def to_json(self) -> dict[str, object]: residues_dict: dict[str, object] = { "residues": [residue.atoms for residue in self.residues], "seqs": self.seqs, "seq_ns": self.seq_ns, "insertion_codes": self.insertion_codes, } if self.labeled is not None: residues_dict["labeled"] = [r.value for r in self.labeled] if self.labels is not None: residues_dict["labels"] = self.labels return residues_dict
[docs] def check(self) -> None: """Validate the residues structure.""" if len(self.seqs) != len(self.residues): raise ValueError( f"Seqs length {len(self.seqs)} != residues length {len(self.residues)}" ) if len(self.seq_ns) != len(self.residues): raise ValueError( f"Seq_ns length {len(self.seq_ns)} != residues length {len(self.residues)}" ) if len(self.insertion_codes) != len(self.residues): raise ValueError( f"Insertion codes length {len(self.insertion_codes)} != residues length {len(self.residues)}" )
[docs] def is_amino_acid(self, index: int) -> bool: """Check if residue at index is an amino acid.""" if index >= len(self.seqs): return False return AminoAcidSeq.is_amino_acid(self.seqs[index])
[docs] def amino_acid_indices(self) -> list[int]: """Get indices of amino acid residues.""" return [i for i in range(len(self.seqs)) if self.is_amino_acid(i)]
[docs] def non_amino_acid_indices(self) -> list[int]: """Get indices of non-amino acid residues.""" return [i for i in range(len(self.seqs)) if not self.is_amino_acid(i)]
[docs] def extend(self, other: Self) -> None: """Extend this residues collection with another.""" # Calculate atom offset for renumbering offset = sum(len(residue.atoms) for residue in self.residues) # Calculate residue offset before extending (number of residues in self before merge) residue_offset = len(self.residues) # Extend residues with renumbered atoms for residue in other.residues: new_atoms = [atom + offset for atom in residue.atoms] self.residues.append(Residue(new_atoms)) # Extend metadata self.seqs.extend(other.seqs) self.seq_ns.extend(other.seq_ns) self.insertion_codes.extend(other.insertion_codes) # Handle labeled residues and labels if other.labeled is not None: if self.labeled is None: self.labeled = [] # Renumber residue references for ref in other.labeled: if isinstance(ref, ResidueRef): self.labeled.append(ResidueRef(ref.value + residue_offset)) elif isinstance(ref, int): self.labeled.append(ref + residue_offset) else: self.labeled.append(ref) if other.labels is not None: if self.labels is None: self.labels = [] # Copy labels (they're lists, so we need to copy them) for label in other.labels: if isinstance(label, list): self.labels.append(label.copy()) else: self.labels.append(label)
[docs] def new_residues_from_subset(self, residue_refs: list[ResidueRef]) -> "Residues": """Create new residues collection from a subset of residue references.""" new_residues = Residues() offset = 0 for residue_ref in residue_refs: if residue_ref.value >= len(self.residues): continue # Get original residue original_residue = self.residues[residue_ref.value] residue_len = len(original_residue.atoms) # Create new residue with renumbered atoms new_atoms = [offset + i for i in range(residue_len)] new_residues.residues.append(Residue(new_atoms)) # Copy metadata new_residues.seqs.append(self.seqs[residue_ref.value]) new_residues.seq_ns.append(self.seq_ns[residue_ref.value]) new_residues.insertion_codes.append(self.insertion_codes[residue_ref.value]) offset += residue_len return new_residues
[docs] class Chain: """A chain containing a list of residues.""" def __init__(self, residues: list[ResidueRef] | list[int] | None = None): # Store as list of integers to match JSON serialization if residues is None: self.residues = [] else: self.residues = [ res.value if isinstance(res, ResidueRef) else res for res in residues ] def __len__(self) -> int: return len(self.residues) def __iter__(self): # Return ResidueRef objects when iterating return (ResidueRef(res) for res in self.residues)
[docs] def contains(self, residue: ResidueRef) -> bool: return residue.value in self.residues
def __eq__(self, other): return isinstance(other, Chain) and self.residues == other.residues def __repr__(self): return f"Chain({[ResidueRef(r) for r in self.residues]})"
[docs] @dataclass class Chains: """Collection of chains with secondary structure information.""" # List of chains chains: list[Chain] = field(default_factory=list) # Optional alpha helix residues alpha_helices: list[ResidueRef] | None = None # Optional beta sheet residues beta_sheets: list[ResidueRef] | None = None # WARN: Deprecated labeled: list[ChainRef] | None = None # WARN: Deprecated labels: list[list[str]] | None = None
[docs] @staticmethod def from_json(json_content: str | Path | dict) -> "Chains": if isinstance(json_content, str): chains_data = json.loads(json_content) elif isinstance(json_content, Path): with json_content.open() as f: chains_data = json.load(f) elif isinstance(json_content, dict): chains_data = json_content else: print( "WARNING: Tried to load Chains from JSON but " "it wasn't a str, Path, or dict!" ) chains_data = json_content chains = Chains() chains.chains = [Chain(chain) for chain in chains_data["chains"]] if chains_data.get("alpha_helices"): chains.alpha_helices = [ResidueRef(r) for r in chains_data["alpha_helices"]] if chains_data.get("beta_sheets"): chains.beta_sheets = [ResidueRef(r) for r in chains_data["beta_sheets"]] if chains_data.get("labeled"): chains.labeled = [ChainRef(c) for c in chains_data["labeled"]] if chains_data.get("labels"): chains.labels = chains_data["labels"] return chains
[docs] def to_json(self) -> dict[str, object]: chains_dict: dict[str, object] = { "chains": [chain.residues for chain in self.chains], } if self.alpha_helices is not None: chains_dict["alpha_helices"] = [r.value for r in self.alpha_helices] if self.beta_sheets is not None: chains_dict["beta_sheets"] = [r.value for r in self.beta_sheets] if self.labeled is not None: chains_dict["labeled"] = [r.value for r in self.labeled] if self.labels is not None: chains_dict["labels"] = self.labels return chains_dict
[docs] def check(self) -> None: """Validate the chains structure.""" # Basic validation - more complex checks could be added pass
[docs] def extend(self, other: Self) -> None: """Extend this chains collection with another.""" # Calculate residue offset residue_offset = sum(len(chain.residues) for chain in self.chains) # Extend chains with renumbered residue references for chain in other.chains: new_residue_refs = [ref + residue_offset for ref in chain.residues] self.chains.append(Chain(new_residue_refs)) # Extend secondary structure info if self.alpha_helices is not None and other.alpha_helices is not None: new_alpha_helices = [ ref.value + residue_offset for ref in other.alpha_helices ] self.alpha_helices.extend([ResidueRef(ref) for ref in new_alpha_helices]) if self.beta_sheets is not None and other.beta_sheets is not None: new_beta_sheets = [ref.value + residue_offset for ref in other.beta_sheets] self.beta_sheets.extend([ResidueRef(ref) for ref in new_beta_sheets])
[docs] def new_chains_from_residue_subset( self, residue_refs: list[ResidueRef] ) -> "Chains": """Create new chains collection from a subset of residue references.""" new_chains = Chains() # Create mapping from old residue indices to new ones old_to_new_residue = {ref.value: i for i, ref in enumerate(residue_refs)} # Group residues by their original chains chain_to_new_residues = defaultdict(list) for new_idx, residue_ref in enumerate(residue_refs): # Find which chain this residue belonged to for chain_idx, chain in enumerate(self.chains): if residue_ref.value in chain.residues: chain_to_new_residues[chain_idx].append(new_idx) break # Create new chains for chain_idx in sorted(chain_to_new_residues.keys()): new_chain_residues = chain_to_new_residues[chain_idx] # Sort by original sequence order original_chain = self.chains[chain_idx] new_chain_residues.sort( key=lambda new_idx: original_chain.residues.index( residue_refs[new_idx].value ) ) new_chains.chains.append(Chain(new_chain_residues)) # Filter secondary structure info if self.alpha_helices: new_alpha_helices = [] for residue_ref in self.alpha_helices: if residue_ref.value in old_to_new_residue: new_alpha_helices.append( ResidueRef(old_to_new_residue[residue_ref.value]) ) new_chains.alpha_helices = new_alpha_helices if new_alpha_helices else None if self.beta_sheets: new_beta_sheets = [] for residue_ref in self.beta_sheets: if residue_ref.value in old_to_new_residue: new_beta_sheets.append( ResidueRef(old_to_new_residue[residue_ref.value]) ) new_chains.beta_sheets = new_beta_sheets if new_beta_sheets else None return new_chains
[docs] @dataclass class TRC: """ Combined Topology, Residues, and Chains structure. This is the main structure for representing molecular systems on the Rush platform. """ topology: Topology = field(default_factory=Topology) residues: Residues = field(default_factory=Residues) chains: Chains = field(default_factory=Chains)
[docs] def check(self) -> None: """Validate the entire TRC structure.""" self.topology.check() self.residues.check() self.chains.check() # Check that all atoms are in residues atom_set = set() for residue in self.residues.residues: for atom_idx in residue.atoms: if atom_idx in atom_set: raise ValueError(f"Atom {atom_idx} appears in multiple residues") atom_set.add(atom_idx) if len(atom_set) != len(self.topology.symbols): raise ValueError("Not all atoms are assigned to residues") # Check that all residues are in chains residue_set = set() for chain in self.chains.chains: for residue_idx in chain.residues: if residue_idx >= len(self.residues.residues): raise ValueError( f"Chain references invalid residue index: {residue_idx}" ) if residue_idx in residue_set: raise ValueError( f"Residue {residue_idx} appears in multiple chains" ) residue_set.add(residue_idx) if len(residue_set) != len(self.residues.residues): raise ValueError("Not all residues are assigned to chains")
[docs] def extend(self, other: Self) -> None: """Extend this TRC with another TRC.""" self.topology.extend(other.topology) self.residues.extend(other.residues) self.chains.extend(other.chains)
[docs] def new_trc_from_residue_subset(self, residue_refs: list[ResidueRef]) -> "TRC": """Create new TRC from a subset of residue references.""" # Get residue subset residue_subset = [self.residues.residues[ref.value] for ref in residue_refs] return TRC( topology=self.topology.new_topology_from_residue_subset(residue_subset), residues=self.residues.new_residues_from_subset(residue_refs), chains=self.chains.new_chains_from_residue_subset(residue_refs), )
[docs] @dataclass(frozen=True) class ResidueId: """Unique identifier for a residue.""" chain_id: str sequence_number: int insertion_code: str residue_name: str def __str__(self) -> str: return f"{self.chain_id}_{self.sequence_number:>9}_{self.insertion_code}_{self.residue_name}"