# Source code for rush.convert.pdb

"""
PDB file parsing and writing functionality.
"""

import sys
from collections import OrderedDict, defaultdict
from dataclasses import dataclass

from ..mol import (
    TRC,
    AminoAcidSeq,
    AtomRef,
    Bond,
    BondOrder,
    Chain,
    ChainRef,
    Element,
    FormalCharge,
    Fragment,
    Residue,
    ResidueId,
    ResidueRef,
)


@dataclass
class PDBAtom:
    """Represents a parsed PDB ATOM/HETATM record.

    Instances are produced by ``_parse_pdb_atom_line``. Optional columns that
    are blank in the record are stored as ``None`` or as the defaults noted on
    each field below.
    """

    atom_idx: int  # atom serial number (columns 7-11); not a 0-based position
    atom_name: str  # atom name (columns 13-16), whitespace-stripped
    alternate_location: str | None  # altLoc indicator (column 17); None if blank
    residue_name: str  # residue name (columns 18-20)
    chain_id: str  # chain identifier (column 22); "" if blank
    sequence_number: int  # residue sequence number (columns 23-26); 1 if blank
    residue_insertion: str | None  # insertion code (column 27); None if blank
    atom_x: float  # x coordinate (columns 31-38); 0.0 if blank
    atom_y: float  # y coordinate (columns 39-46); 0.0 if blank
    atom_z: float  # z coordinate (columns 47-54); 0.0 if blank
    occupancy: float  # occupancy (columns 55-60); defaults to 1.0
    temperature_factor: float  # temperature factor (columns 61-66); defaults to 0.0
    segment_id: str | None  # segment identifier (columns 73-76); None if absent
    element_symbol: Element  # element (columns 77-78), else inferred from atom_name
    charge: int | None  # formal charge (columns 79-80); None if absent


def _parse_pdb_charge(charge_str: str) -> int | None:
    """
    Parse the formal-charge field (columns 79-80) of an ATOM/HETATM record.

    Accepts the standard "<digits><sign>" form (e.g. "2+", "1-"), the
    "<sign><digits>" form (e.g. "+1", "-2"), a bare sign (magnitude 1), or a
    plain integer.

    Args:
        charge_str: Raw text of the charge columns (may be blank or padded).

    Returns:
        The signed charge, or None if the field is blank.

    Raises:
        ValueError: If the field is not a recognisable charge.
    """
    charge_str = charge_str.strip()
    if not charge_str:
        return None
    if charge_str[-1] in "+-":
        sign = 1 if charge_str[-1] == "+" else -1
        magnitude_str = charge_str[:-1]
        return sign * (int(magnitude_str) if magnitude_str else 1)
    return int(charge_str)


def _parse_pdb_atom_line(line: str, line_num: int) -> PDBAtom:
    """
    Parse a PDB ATOM or HETATM line.

    Fields are read from the fixed column positions of the PDB coordinate
    record. Optional trailing fields (occupancy, temperature factor, segment
    ID, element, charge) are detected by whether their column slice is
    non-blank, not by comparing the line length against the slice end. This
    matters because standard records are often exactly 78 or 80 characters
    long, and a strict ``len(line) > end`` test would skip the last field.

    Args:
        line: Raw text of the record.
        line_num: 1-based line number, used only in error messages.

    Returns:
        The parsed PDBAtom. Blank optional columns take these defaults:
        sequence number 1, coordinates 0.0, occupancy 1.0, temperature
        factor 0.0, and None for alternate location, insertion code, segment
        ID and charge.

    Raises:
        ValueError: If the line is shorter than 54 characters (too short to
            hold coordinates), or if any field fails to parse. Parse errors
            are prefixed with the line number.
    """
    if len(line) < 54:
        raise ValueError(f"Line {line_num}: ATOM/HETATM line too short")

    def column(start: int, end: int) -> str:
        # Slicing never raises on short lines; missing columns read as "".
        return line[start:end].strip()

    try:
        atom_name = column(12, 16)
        seq_str = column(22, 26)
        x_str, y_str, z_str = column(30, 38), column(38, 46), column(46, 54)
        occupancy_str = column(54, 60)
        temp_str = column(60, 66)

        element_str = column(76, 78)
        if not element_str:
            # No element columns: infer from the first letter of the atom
            # name, skipping leading digits as in hydrogen names like "1HB".
            element_str = next((ch for ch in atom_name if ch.isalpha()), "")
            if not element_str:
                raise ValueError(
                    f"cannot infer element from atom name {atom_name!r}"
                )

        return PDBAtom(
            atom_idx=int(column(6, 11)),
            atom_name=atom_name,
            alternate_location=column(16, 17) or None,
            residue_name=column(17, 20),
            chain_id=column(21, 22),
            sequence_number=int(seq_str) if seq_str else 1,
            residue_insertion=column(26, 27) or None,
            atom_x=float(x_str) if x_str else 0.0,
            atom_y=float(y_str) if y_str else 0.0,
            atom_z=float(z_str) if z_str else 0.0,
            occupancy=float(occupancy_str) if occupancy_str else 1.0,
            temperature_factor=float(temp_str) if temp_str else 0.0,
            segment_id=column(72, 76) or None,
            element_symbol=Element.from_str(element_str),
            charge=_parse_pdb_charge(column(78, 80)),
        )
    except (ValueError, IndexError) as e:
        raise ValueError(
            f"Line {line_num}: Error parsing ATOM/HETATM line: {e}"
        ) from e


def _parse_conect_line(line: str) -> list[int]:
    """Parse a CONECT line and return list of atom indices."""
    atom_idxs = []
    # CONECT format: positions 6-11, 11-16, 16-21, 21-26, 26-31 for atom indices
    start = 6
    while start < len(line):
        end = start + 5
        if end > len(line):
            end = len(line)
        atom_idx_str = line[start:end].strip()
        if atom_idx_str:
            try:
                atom_idxs.append(int(atom_idx_str))
            except ValueError:
                break
        else:
            break
        start = end
    return atom_idxs


def _connectivity_to_bonds(
    atom_ids: list[int], connectivity: list[tuple[int, int, int]]
) -> list[Bond]:
    """
    Convert (origin serial, target serial, order) tuples into Bond objects.

    Serials are mapped to positions in ``atom_ids``. Entries naming a serial
    that is not present (e.g. an atom that was not kept) are dropped. An
    entry whose reverse pair was already recorded is treated as the same bond
    listed from the other end and is ignored. Repeating the same directed
    pair increments the bond order by one, e.g. a pair listed twice becomes a
    double bond.

    Args:
        atom_ids: Serial number of each kept atom, indexed by atom position.
        connectivity: (origin serial, target serial, initial order) tuples.

    Returns:
        One Bond per distinct atom pair, with the lower atom index first.
    """
    # Map serial -> position once (O(n)) instead of calling list.index per
    # bond. The first occurrence wins, matching list.index for duplicate
    # serials.
    index_of: dict[int, int] = {}
    for position, serial in enumerate(atom_ids):
        index_of.setdefault(serial, position)

    bond_orders: dict[tuple[int, int], int] = {}  # (origin, target) -> order
    for origin_id, target_id, order in connectivity:
        origin_idx = index_of.get(origin_id)
        target_idx = index_of.get(target_id)
        if origin_idx is None or target_idx is None:
            continue
        # Reverse direction already seen: same bond, listed from the other end.
        if (target_idx, origin_idx) in bond_orders:
            continue
        key = (origin_idx, target_idx)
        if key in bond_orders:
            bond_orders[key] += 1  # repeated pair -> higher bond order
        else:
            bond_orders[key] = order

    return [
        Bond(
            AtomRef(min(origin_idx, target_idx)),
            AtomRef(max(origin_idx, target_idx)),
            BondOrder(order),
        )
        for (origin_idx, target_idx), order in bond_orders.items()
    ]


def _build_trc(
    atoms: list[PDBAtom],
    atom_ids: list[int],
    residue_data: OrderedDict,
    chain_data: dict[str, set[ResidueId]],
    connectivity: list[tuple[int, int, int]],
) -> TRC:
    """
    Build a TRC structure from parsed PDB data.

    Args:
        atoms: Kept atoms, in file order; an atom's position in this list is
            its atom index in the TRC.
        atom_ids: PDB serial number of each atom in ``atoms``.
        residue_data: ResidueId -> positions (in ``atoms``) of its atoms.
        chain_data: chain_id -> set of ResidueIds in that chain.
        connectivity: (origin serial, target serial, order) tuples local to
            this model.

    Returns:
        A TRC with topology, residues (sorted by chain, sequence number,
        insertion code, residue name), chains (sorted by chain ID), one
        fragment per residue, bonds, and per-atom and per-fragment formal
        charges.
    """
    trc = TRC()
    topology = trc.topology

    # Topology: symbols, flattened xyz geometry, labels
    topology.symbols = [atom.element_symbol for atom in atoms]
    topology.geometry = [
        coord for atom in atoms for coord in (atom.atom_x, atom.atom_y, atom.atom_z)
    ]
    topology.labels = [atom.atom_name for atom in atoms]

    # Formal charges (per atom); a missing charge counts as 0
    atom_formal_charges = [atom.charge or 0 for atom in atoms]
    topology.formal_charges = [FormalCharge(charge) for charge in atom_formal_charges]

    # Sort residues by ResidueId (chain_id, sequence_number, insertion_code,
    # residue_name). This matches the Rust BTreeMap ordering.
    sorted_residue_ids = sorted(
        residue_data,
        key=lambda rid: (
            rid.chain_id,
            rid.sequence_number,
            rid.insertion_code,
            rid.residue_name,
        ),
    )

    residue_list = []
    seq_names = []
    seq_numbers = []
    insertion_codes_list = []
    for residue_id in sorted_residue_ids:
        residue_list.append(Residue([AtomRef(idx) for idx in residue_data[residue_id]]))
        seq_names.append(residue_id.residue_name)
        seq_numbers.append(residue_id.sequence_number)
        # "~" is only a sort placeholder for a missing insertion code
        insertion_codes_list.append(
            "" if residue_id.insertion_code == "~" else residue_id.insertion_code
        )

    trc.residues.residues = residue_list
    trc.residues.seqs = seq_names
    trc.residues.seq_ns = seq_numbers
    trc.residues.insertion_codes = insertion_codes_list

    # Chains: within a chain, order residues by their global index. All
    # residues in a chain share chain_id, so this is (sequence_number,
    # insertion_code, residue_name) order. Including residue_name keeps the
    # order deterministic when two residues share a number and insertion
    # code; otherwise ties would follow set iteration order.
    residue_id_to_index = {rid: idx for idx, rid in enumerate(sorted_residue_ids)}
    chain_ids = sorted(chain_data)
    chains = []
    for chain_id in chain_ids:
        chain_residue_indices = sorted(
            residue_id_to_index[rid] for rid in chain_data[chain_id]
        )
        chains.append(Chain([ResidueRef(idx) for idx in chain_residue_indices]))

    trc.chains.chains = chains
    trc.chains.labeled = [ChainRef(i) for i in range(len(chains))]
    trc.chains.labels = [[chain_id] for chain_id in chain_ids]

    # Create fragments (one per residue)
    topology.fragments = [
        Fragment([AtomRef(atom_idx) for atom_idx in residue.atoms])
        for residue in trc.residues.residues
    ]

    topology.connectivity = _connectivity_to_bonds(atom_ids, connectivity)

    # Fragment formal charge = sum of the atom charges in each residue
    topology.fragment_formal_charges = [
        FormalCharge(sum(atom_formal_charges[atom_idx] for atom_idx in residue.atoms))
        for residue in trc.residues.residues
    ]

    return trc


def _apply_global_connectivity(
    trc: TRC, atom_ids: list[int], global_connectivity: list[tuple[int, int, int]]
):
    """
    Append bonds from CONECT records found outside any model to ``trc``.

    Args:
        trc: TRC to extend; its ``topology.connectivity`` is extended in place,
            or assigned if empty.
        atom_ids: PDB serial number of each atom in ``trc``, by atom index.
        global_connectivity: (origin serial, target serial, order) tuples.
    """
    if not global_connectivity:
        return

    additional_bonds = _connectivity_to_bonds(atom_ids, global_connectivity)

    # NOTE(review): these bonds are deduplicated among themselves but not
    # against bonds already present in trc.topology.connectivity.
    if trc.topology.connectivity:
        trc.topology.connectivity.extend(additional_bonds)
    else:
        trc.topology.connectivity = additional_bonds


def from_pdb(pdb_content: str) -> TRC | list[TRC]:
    """
    Parse PDB file content into TRC structures.

    ATOM/HETATM records are grouped into models. A model ends at an ENDMDL
    record, an END record, or the end of the input. Only atoms with no
    alternate location, or with alternate location "A", are kept. Atom
    records that fail to parse produce a warning on stderr and are skipped.
    CONECT records seen inside a model (after MODEL or any ATOM/HETATM in it)
    apply to that model only. CONECT records seen outside a model are applied
    to every model once all models have been parsed.

    Args:
        pdb_content: String content of a PDB file

    Returns:
        TRC structure or list of TRC structures (one per model in multi-model
        files). If no atoms are found, a single empty TRC is returned.
    """
    trcs = []
    trc_atom_ids = []  # atom serials for each TRC in ``trcs``
    global_connectivity = []  # (origin, target, order) outside any model

    # splitlines() also removes the "\r" of Windows "\r\n" line endings.
    line_iter = iter(enumerate(pdb_content.strip().splitlines(), 1))
    eof = False

    while not eof:
        # Storage for the current model
        atoms = []
        atom_ids = []
        residue_data = OrderedDict()  # ResidueId -> positions in ``atoms``
        chain_data = defaultdict(set)  # chain_id -> set of ResidueIds
        connectivity = []  # local connectivity for this model
        in_model = False

        while True:
            try:
                line_num, line = next(line_iter)
            except StopIteration:
                eof = True
                break

            # Classify by the stripped record name rather than requiring 6
            # characters: records such as "END" are only 3 characters long.
            record_type = line[:6].strip()
            if not record_type:
                continue

            if record_type == "MODEL":
                in_model = True
            elif record_type == "ENDMDL":
                break
            elif record_type in ("ATOM", "HETATM"):
                in_model = True
                try:
                    atom = _parse_pdb_atom_line(line, line_num)
                    # Keep only atoms with no alternate location or with
                    # alternate location "A"; skip "B", "C", etc.
                    if atom.alternate_location not in (None, "A"):
                        continue
                    # insertion_code uses "~" so a missing code sorts after all
                    # letters; the stored residue value is an empty string.
                    residue_id = ResidueId(
                        chain_id=atom.chain_id,
                        sequence_number=atom.sequence_number,
                        insertion_code=atom.residue_insertion or "~",
                        residue_name=atom.residue_name,
                    )
                except ValueError as e:
                    print(f"Warning: {e}", file=sys.stderr)
                    continue

                atoms.append(atom)
                atom_ids.append(atom.atom_idx)
                # Record the atom's position in ``atoms`` (not its serial).
                residue_data.setdefault(residue_id, []).append(len(atoms) - 1)
                chain_data[atom.chain_id].add(residue_id)
            elif record_type == "CONECT":
                try:
                    atom_idxs = _parse_conect_line(line)
                except (ValueError, IndexError):
                    continue
                if len(atom_idxs) >= 2:
                    origin, *targets = atom_idxs
                    bucket = connectivity if in_model else global_connectivity
                    bucket.extend((origin, target, 1) for target in targets)
            elif record_type == "END":
                break

        # A segment without atoms (e.g. trailing CONECT/END lines) is not a
        # model; the outer loop condition handles reaching end of input.
        if not atoms:
            continue

        trcs.append(_build_trc(atoms, atom_ids, residue_data, chain_data, connectivity))
        trc_atom_ids.append(atom_ids)

    # Apply global connectivity to all models
    for trc, atom_ids in zip(trcs, trc_atom_ids):
        _apply_global_connectivity(trc, atom_ids, global_connectivity)

    if not trcs:
        return TRC()
    return trcs[0] if len(trcs) == 1 else trcs
def to_pdb(trc: TRC) -> str:
    """
    Convert TRC structure to PDB format string.

    Each residue's atoms are written as one record per atom: ATOM when the
    residue name is an amino acid, HETATM otherwise. A final ``END`` record
    follows. Records use the fixed 80-column layout that ``from_pdb`` reads
    back.

    Details:
      * Serial numbers start at 1 and increase in the order records are written.
      * Occupancy and temperature factor are always 1.00 and 0.00.
      * The segment ID is left blank.
      * A non-zero formal charge is written as "<magnitude><sign>" (e.g. "1+");
        a zero charge is left blank.
      * Chain identifiers are assigned by chain index ("A", "B", ...). Residues
        in no chain, and chains beyond the 26th, get "A".
      * Atoms whose index is outside ``topology.symbols`` are skipped.
      * Bonds are not written (no CONECT records).

    Args:
        trc: TRC structure to convert

    Returns:
        PDB format string (records joined by "\\n", no trailing newline)
    """
    topology = trc.topology
    residues = trc.residues
    geometry = topology.geometry
    labels = topology.labels
    formal_charges = topology.formal_charges

    # Map each residue to the index of the chain that contains it.
    residue_to_chain = {}
    for chain_idx, chain in enumerate(trc.chains.chains):
        for residue_idx in chain.residues:
            residue_to_chain[residue_idx] = chain_idx

    lines = []
    # Kept separate from the atom index: PDB serials are 1-based and follow
    # write order.
    serial = 1
    for residue_idx, residue in enumerate(residues.residues):
        chain_idx = residue_to_chain.get(residue_idx, 0)
        chain_id = chr(ord("A") + chain_idx) if chain_idx < 26 else "A"
        residue_name = (
            residues.seqs[residue_idx] if residue_idx < len(residues.seqs) else "UNK"
        )
        seq_num = (
            residues.seq_ns[residue_idx] if residue_idx < len(residues.seq_ns) else 1
        )
        insertion_code = (
            residues.insertion_codes[residue_idx]
            if residue_idx < len(residues.insertion_codes)
            else ""
        )
        record_type = "ATOM" if AminoAcidSeq.is_amino_acid(residue_name) else "HETATM"

        for atom_idx in residue.atoms:
            if atom_idx >= len(topology.symbols):
                continue
            element = topology.symbols[atom_idx]
            atom_name = (
                labels[atom_idx] if labels and atom_idx < len(labels) else str(element)
            )
            x, y, z = (
                geometry[3 * atom_idx + axis]
                if 3 * atom_idx + axis < len(geometry)
                else 0.0
                for axis in range(3)
            )

            formal_charge = 0
            if formal_charges and atom_idx < len(formal_charges):
                formal_charge = formal_charges[atom_idx].charge
            if formal_charge:
                charge_field = f"{abs(formal_charge)}{'+' if formal_charge > 0 else '-'}"
            else:
                charge_field = "  "

            # Columns: record 1-6, serial 7-11, name 13-16, resName 18-20,
            # chain 22, resSeq 23-26, iCode 27, x/y/z 31-54, occupancy 55-60,
            # tempFactor 61-66, segID 73-76 (blank), element 77-78, charge 79-80.
            lines.append(
                f"{record_type:<6}{serial:>5} {atom_name:<4} {residue_name:>3} "
                f"{chain_id}{seq_num:>4}{insertion_code:<1}{'':3}"
                f"{x:>8.3f}{y:>8.3f}{z:>8.3f}{1.0:>6.2f}{0.0:>6.2f}{'':10}"
                f"{str(element):>2}{charge_field}"
            )
            serial += 1

    lines.append("END")
    return "\n".join(lines)