Source code for rush.convert.mmcif

"""
mmCIF file parsing functionality.
"""

from collections import OrderedDict, defaultdict

from ..mol import (
    TRC,
    AtomRef,
    Bond,
    BondOrder,
    Chain,
    ChainRef,
    Element,
    FormalCharge,
    Fragment,
    Residue,
    ResidueId,
    ResidueRef,
)


def _parse_mmcif_value(value: str) -> str:
    """Parse an mmCIF value, handling quoted strings and special characters."""
    value = value.strip()
    if value in (".", "?"):
        return ""
    if value.startswith("'") and value.endswith("'"):
        return value[1:-1]
    if value.startswith('"') and value.endswith('"'):
        return value[1:-1]
    return value


def _parse_mmcif_loop(
    lines: list[str], start_idx: int, prefix: str
) -> tuple[tuple[list[str], list[list[str]]] | None, int]:
    """
    Parse an mmCIF loop starting at start_idx.

    Returns:
        ((column_names, rows), next_idx) or (None, next_idx) if not a loop with the given prefix
    """
    i = start_idx

    # Check if this is a loop
    if i >= len(lines) or not lines[i].strip().startswith("loop_"):
        return (None, i)

    i += 1

    # Parse column names
    columns = []
    while i < len(lines):
        line = lines[i].strip()
        if not line or line.startswith("#"):
            i += 1
            continue
        if not line.startswith("_"):
            break
        if line.startswith(prefix):
            columns.append(line[len(prefix) :])
        elif (
            columns
        ):  # Started collecting columns for this prefix, now hit a different prefix
            break
        i += 1

    if not columns:
        return (None, i)

    # Parse data rows (may span multiple lines)
    rows = []
    while i < len(lines):
        line = lines[i].strip()
        if not line or line.startswith("#"):
            i += 1
            continue
        if line.startswith("_") or line.startswith("loop_"):
            break

        # Parse fields from current line (and additional lines if needed)
        fields = []
        current_line = lines[i]
        i += 1

        while len(fields) < len(columns):
            # Parse tokens from current_line
            tokens = []
            j = 0
            current_line_stripped = current_line.rstrip("\n\r")
            while j < len(current_line_stripped):
                # Skip whitespace
                while (
                    j < len(current_line_stripped) and current_line_stripped[j] in " \t"
                ):
                    j += 1
                if j >= len(current_line_stripped):
                    break

                # Check for quoted string
                if current_line_stripped[j] in ("'", '"'):
                    quote_char = current_line_stripped[j]
                    j += 1
                    start = j
                    while (
                        j < len(current_line_stripped)
                        and current_line_stripped[j] != quote_char
                    ):
                        j += 1
                    tokens.append(current_line_stripped[start:j])
                    j += 1  # Skip closing quote
                else:
                    # Unquoted value
                    start = j
                    while (
                        j < len(current_line_stripped)
                        and current_line_stripped[j] not in " \t"
                    ):
                        j += 1
                    tokens.append(current_line_stripped[start:j])

            fields.extend(tokens)

            # If we don't have enough fields yet, try to read the next line
            if len(fields) < len(columns):
                if i < len(lines):
                    next_line = lines[i].strip()
                    if (
                        next_line
                        and not next_line.startswith("_")
                        and not next_line.startswith("loop_")
                        and not next_line.startswith("data_")
                    ):
                        current_line = lines[i]
                        i += 1
                    else:
                        break
                else:
                    break

        if len(fields) == len(columns):
            rows.append(fields)

    return ((columns, rows), i)


def _build_trc_from_mmcif_atoms(
    atoms: list[dict],
    struct_conn_data: tuple[list[str], list[list[str]]] | None,
    comp_bond_data: tuple[list[str], list[list[str]]] | None,
) -> TRC:
    """Build a TRC from parsed mmCIF atoms.

    Args:
        atoms: Atom records for a single model. Each record is a dict that
            provides at least the keys read below: ``label_alt_id``,
            ``type_symbol``, ``Cartn_x``/``Cartn_y``/``Cartn_z``,
            ``label_atom_id``, ``pdbx_formal_charge``, ``auth_asym_id``,
            ``auth_seq_id``, ``pdbx_PDB_ins_code``, ``label_comp_id``,
            ``label_asym_id`` and ``label_seq_id``.
        struct_conn_data: ``(column_names, rows)`` of the ``_struct_conn``
            loop, or None. Every matched row is recorded with bond order 1.
        comp_bond_data: ``(column_names, rows)`` of the ``_chem_comp_bond``
            loop, or None. Its bonds are applied to every residue whose
            ``label_comp_id`` matches the row's ``comp_id``.

    Returns:
        A TRC whose topology, residues, chains, per-residue fragments and
        connectivity have been populated from the kept atoms.
    """
    trc = TRC()

    atom_labels = []
    atom_formal_charges = []
    atom_symbols = []
    geometry = []

    residue_data = OrderedDict()  # ResidueId -> topology indices of its atoms
    chain_data = defaultdict(set)  # auth_asym_id -> set of ResidueId
    atom_index_map = {}  # Original atom index to topology index

    for orig_idx, atom in enumerate(atoms):
        # Only process atoms with alternate location "A" or None
        alt_id = atom["label_alt_id"]
        if alt_id and alt_id != "A":
            continue

        # Parse element from type_symbol, ignoring non-alphabetic characters
        element_str = "".join(c for c in atom["type_symbol"] if c.isalpha())
        try:
            element = Element.from_str(element_str)
        except (ValueError, KeyError):
            element = Element.C  # Default to carbon

        topology_idx = len(atom_symbols)
        atom_index_map[orig_idx] = topology_idx

        atom_symbols.append(element)
        geometry.extend([atom["Cartn_x"], atom["Cartn_y"], atom["Cartn_z"]])
        atom_labels.append(atom["label_atom_id"])
        atom_formal_charges.append(atom["pdbx_formal_charge"])

        # Create residue identifier using auth fields and "~" for sorting
        residue_id = ResidueId(
            chain_id=atom["auth_asym_id"],
            sequence_number=atom["auth_seq_id"],
            insertion_code=atom["pdbx_PDB_ins_code"] or "~",
            residue_name=atom["label_comp_id"],
        )
        residue_data.setdefault(residue_id, []).append(topology_idx)
        chain_data[atom["auth_asym_id"]].add(residue_id)

    # Build topology
    trc.topology.symbols = atom_symbols
    trc.topology.geometry = geometry
    trc.topology.labels = atom_labels
    trc.topology.formal_charges = [
        FormalCharge(charge) for charge in atom_formal_charges
    ]

    # Sort residues by ResidueId
    sorted_residue_ids = sorted(
        residue_data.keys(),
        key=lambda rid: (
            rid.chain_id,
            rid.sequence_number,
            rid.insertion_code,
            rid.residue_name,
        ),
    )

    # Build residues
    residue_list = []
    seq_names = []
    seq_numbers = []
    insertion_codes_list = []

    for residue_id in sorted_residue_ids:
        residue_list.append(
            Residue([AtomRef(idx) for idx in residue_data[residue_id]])
        )
        seq_names.append(residue_id.residue_name)
        seq_numbers.append(residue_id.sequence_number)
        # Convert "~" back to empty string
        insertion_codes_list.append(
            "" if residue_id.insertion_code == "~" else residue_id.insertion_code
        )

    trc.residues.residues = residue_list
    trc.residues.seqs = seq_names
    trc.residues.seq_ns = seq_numbers
    trc.residues.insertion_codes = insertion_codes_list

    # Build chains
    chains = []
    residue_id_to_index = {rid: idx for idx, rid in enumerate(sorted_residue_ids)}
    chain_ids = sorted(chain_data.keys())

    for chain_id in chain_ids:
        sorted_chain_residue_ids = sorted(
            chain_data[chain_id],
            key=lambda rid: (rid.sequence_number, rid.insertion_code, rid.residue_name),
        )
        chains.append(
            Chain(
                [
                    ResidueRef(residue_id_to_index[rid])
                    for rid in sorted_chain_residue_ids
                    if rid in residue_id_to_index
                ]
            )
        )

    trc.chains.chains = chains
    trc.chains.labeled = [ChainRef(i) for i in range(len(chains))]
    trc.chains.labels = [[chain_id] for chain_id in chain_ids]

    # Create fragments (one per residue)
    trc.topology.fragments = [
        Fragment([AtomRef(atom_idx) for atom_idx in residue.atoms])
        for residue in trc.residues.residues
    ]

    # Calculate fragment formal charges (sum of the atom charges per residue)
    trc.topology.fragment_formal_charges = [
        FormalCharge(sum(atom_formal_charges[atom_idx] for atom_idx in residue.atoms))
        for residue in trc.residues.residues
    ]

    def row_value(row: list[str], col_idx: dict, name: str) -> str:
        """Return the normalized value of column `name` in `row`, or ""."""
        idx = col_idx.get(name)
        if idx is not None and idx < len(row):
            return _parse_mmcif_value(row[idx])
        return ""

    def row_int(row: list[str], col_idx: dict, name: str) -> int:
        """Return column `name` as an int; 0 when missing or not an integer."""
        val = row_value(row, col_idx, name)
        try:
            return int(val) if val else 0
        except ValueError:
            return 0

    # Build connectivity from struct_conn and chem_comp_bond.
    # Later writes for the same atom pair overwrite earlier ones, so
    # chem_comp_bond orders take precedence over struct_conn.
    connectivity_deduper = {}  # (min_idx, max_idx) -> bond_order

    # Parse struct_conn (inter-residue bonds)
    if struct_conn_data:
        columns, rows = struct_conn_data
        col_idx = {col: idx for idx, col in enumerate(columns)}

        # Index of the FIRST atom (in original order, including atoms that
        # were skipped above) for each label_ triple. This replaces a linear
        # scan over all atoms for every struct_conn row.
        first_atom_by_label = {}
        for idx, atom in enumerate(atoms):
            first_atom_by_label.setdefault(
                (atom["label_atom_id"], atom["label_asym_id"], atom["label_seq_id"]),
                idx,
            )

        for row in rows:
            # Partners are identified by label_ fields, not auth_
            partner1 = (
                row_value(row, col_idx, "ptnr1_label_atom_id"),
                row_value(row, col_idx, "ptnr1_label_asym_id"),
                row_int(row, col_idx, "ptnr1_label_seq_id"),
            )
            partner2 = (
                row_value(row, col_idx, "ptnr2_label_atom_id"),
                row_value(row, col_idx, "ptnr2_label_asym_id"),
                row_int(row, col_idx, "ptnr2_label_seq_id"),
            )

            atom1_orig_idx = first_atom_by_label.get(partner1)
            atom2_orig_idx = first_atom_by_label.get(partner2)
            if atom1_orig_idx is None or atom2_orig_idx is None:
                continue

            # The first match may be an atom that was filtered out (alternate
            # location); such connections are dropped.
            topo_idx1 = atom_index_map.get(atom1_orig_idx)
            topo_idx2 = atom_index_map.get(atom2_orig_idx)
            if topo_idx1 is None or topo_idx2 is None:
                continue

            # Every connection is stored as a single bond, whatever its
            # conn_type_id is.
            # NOTE(review): this includes non-covalent types if present in
            # the loop — confirm that is intended.
            pair = (min(topo_idx1, topo_idx2), max(topo_idx1, topo_idx2))
            connectivity_deduper[pair] = 1

    # Parse chem_comp_bond (intra-residue bonds)
    if comp_bond_data:
        columns, rows = comp_bond_data
        col_idx = {col: idx for idx, col in enumerate(columns)}

        # value_order code -> numeric bond order; unknown codes fall back to 1.
        # Codes are compared case-insensitively so "doub" and "DOUB" agree.
        bond_order_by_code = {"SING": 1, "DOUB": 2, "TRIP": 3, "QUAD": 4, "AROM": 5}

        # Build mapping of comp_id -> bonds
        comp_bonds = defaultdict(list)
        for row in rows:
            comp_bonds[row_value(row, col_idx, "comp_id")].append(
                (
                    row_value(row, col_idx, "atom_id_1"),
                    row_value(row, col_idx, "atom_id_2"),
                    row_value(row, col_idx, "value_order"),
                )
            )

        # Group atoms by residue for efficient lookup:
        # (comp_id, auth_asym_id, auth_seq_id) -> {label_atom_id: topology
        # index of the FIRST kept atom with that name}.
        # Note: Rust uses (comp_id, auth_asym_id, auth_seq_id) without insertion code
        residue_atoms = defaultdict(dict)
        for orig_idx, atom in enumerate(atoms):
            if atom["label_alt_id"] == "" or atom["label_alt_id"] == "A":
                topo_idx = atom_index_map.get(orig_idx)
                if topo_idx is not None:
                    key = (
                        atom["label_comp_id"],
                        atom["auth_asym_id"],
                        atom["auth_seq_id"],
                    )
                    residue_atoms[key].setdefault(atom["label_atom_id"], topo_idx)

        # Apply bond definitions to residues
        for (comp_id, _chain_id, _seq_id), first_by_name in residue_atoms.items():
            for atom_id_1, atom_id_2, value_order in comp_bonds.get(comp_id, ()):
                topo_idx1 = first_by_name.get(atom_id_1)
                topo_idx2 = first_by_name.get(atom_id_2)
                if topo_idx1 is None or topo_idx2 is None:
                    continue

                pair = (min(topo_idx1, topo_idx2), max(topo_idx1, topo_idx2))
                connectivity_deduper[pair] = bond_order_by_code.get(
                    value_order.upper(), 1
                )

    # Convert to Bond objects, ordered by (min_idx, max_idx)
    trc.topology.connectivity = [
        Bond(AtomRef(min_idx), AtomRef(max_idx), BondOrder(order))
        for (min_idx, max_idx), order in sorted(connectivity_deduper.items())
    ]

    return trc


def from_mmcif(mmcif_content: str) -> TRC | list[TRC]:
    """
    Parse mmCIF file contents into TRC structures.

    Atoms are read from the ``_atom_site`` loop and grouped by
    ``pdbx_PDB_model_num``; one TRC is built per model, in ascending numeric
    model order. Bonds come from the ``_struct_conn`` and ``_chem_comp_bond``
    loops when present. If a category appears in several loops, the last one
    parsed is used.

    Args:
        mmcif_content: String contents of an mmCIF file

    Returns:
        TRC structure or list of TRC structures. A single TRC is returned when
        exactly one model was built (or when the atom loop has no usable
        rows, in which case the TRC is empty); a list is returned when there
        are several models. When no ``_atom_site`` loop is found at all, a
        one-element list holding an empty TRC is returned.
    """

    def empty_trc() -> TRC:
        """Create a TRC with explicitly empty chain labels."""
        trc = TRC()
        trc.chains.labeled = []
        trc.chains.labels = []
        return trc

    def model_sort_key(model_num: str) -> tuple[int, int, str]:
        """Order numeric model ids numerically, non-numeric ones last."""
        try:
            return (0, int(model_num), model_num)
        except ValueError:
            return (1, 0, model_num)

    lines = mmcif_content.split("\n")

    # Parse loops
    atom_loop_data = None
    struct_conn_data = None
    comp_bond_data = None

    i = 0
    while i < len(lines):
        if not lines[i].strip().startswith("loop_"):
            i += 1
            continue

        # Try to parse atom_site loop
        result, next_i = _parse_mmcif_loop(lines, i, "_atom_site.")
        # Check if this has atom_site columns
        if result and any(
            "id" in col or "type_symbol" in col for col in result[0]
        ):
            atom_loop_data = result
            i = next_i
            continue

        # Try to parse struct_conn loop
        result, next_i = _parse_mmcif_loop(lines, i, "_struct_conn.")
        if result:
            struct_conn_data = result
            i = next_i
            continue

        # Try to parse chem_comp_bond loop
        result, next_i = _parse_mmcif_loop(lines, i, "_chem_comp_bond.")
        if result:
            comp_bond_data = result
            i = next_i
            continue

        # Some other category: resume after the part that was scanned
        i = next_i

    if not atom_loop_data:
        return [empty_trc()]

    columns, rows = atom_loop_data

    # Find column indices
    col_idx = {col: idx for idx, col in enumerate(columns)}

    def get_val(row: list[str], name: str, default: str = "") -> str:
        """Return the normalized value of a column, or `default` if empty."""
        idx = col_idx.get(name)
        if idx is not None and idx < len(row):
            val = _parse_mmcif_value(row[idx])
            return val if val else default
        return default

    def get_int(row: list[str], name: str) -> int | None:
        """Return a column as int, or None if it is empty or not an integer."""
        val = get_val(row, name)
        if not val:
            return None
        try:
            return int(val)
        except ValueError:
            return None

    def get_int_with_default(row: list[str], name: str, default: int = 0) -> int:
        """Return a column as int, falling back to `default`."""
        val = get_int(row, name)
        return val if val is not None else default

    def get_float(row: list[str], name: str, default: float = 0.0) -> float:
        """Return a column as float, falling back to `default`."""
        val = get_val(row, name)
        try:
            return float(val) if val else default
        except ValueError:
            return default

    # Parse atoms
    models = defaultdict(list)  # model_num -> list of atoms
    for row in rows:
        if len(row) < len(columns):
            continue

        # Parse auth_seq_id with fallback logic matching Rust:
        # auth_seq_id, then label_seq_id, then 0
        auth_seq_id_val = get_int(row, "auth_seq_id")
        if auth_seq_id_val is None:
            auth_seq_id_val = get_int(row, "label_seq_id")
        if auth_seq_id_val is None:
            auth_seq_id_val = 0

        atom = {
            "id": get_int_with_default(row, "id", 0),
            "type_symbol": get_val(row, "type_symbol", "C"),
            "label_atom_id": get_val(row, "label_atom_id", "C"),
            "label_alt_id": get_val(row, "label_alt_id", ""),
            "label_comp_id": get_val(row, "label_comp_id", "UNK"),
            "label_asym_id": get_val(row, "label_asym_id", "A"),
            "label_seq_id": get_int_with_default(row, "label_seq_id", 0),
            "pdbx_PDB_ins_code": get_val(row, "pdbx_PDB_ins_code", ""),
            "Cartn_x": get_float(row, "Cartn_x", 0.0),
            "Cartn_y": get_float(row, "Cartn_y", 0.0),
            "Cartn_z": get_float(row, "Cartn_z", 0.0),
            "occupancy": get_float(row, "occupancy", 1.0),
            "B_iso_or_equiv": get_float(row, "B_iso_or_equiv", 0.0),
            "pdbx_formal_charge": get_int_with_default(row, "pdbx_formal_charge", 0),
            "auth_asym_id": (
                get_val(row, "auth_asym_id", "") or get_val(row, "label_asym_id", "A")
            ),
            "auth_seq_id": auth_seq_id_val,
            "group_PDB": get_val(row, "group_PDB", "ATOM"),
            "pdbx_PDB_model_num": get_val(row, "pdbx_PDB_model_num", "1"),
        }

        models[atom["pdbx_PDB_model_num"]].append(atom)

    # Build TRC for each model. Model numbers are kept as strings, so they are
    # sorted numerically here; plain string sorting would put "10" before "2".
    trcs = [
        _build_trc_from_mmcif_atoms(models[model_num], struct_conn_data, comp_bond_data)
        for model_num in sorted(models, key=model_sort_key)
    ]

    if not trcs:
        trcs.append(empty_trc())

    if len(trcs) == 1:
        return trcs[0]
    return trcs