Source code for rush.exess._qmmm

"""
EXESS QM/MM simulations for the Rush Python client.

Quick Links
-----------

- :func:`rush.exess.qmmm`
- :class:`rush.exess.QMMMResult`
"""

import json
import sys
from dataclasses import dataclass
from pathlib import Path
from string import Template
from typing import Any

from gql.transport.exceptions import TransportQueryError

from rush import TRC, Residues, Topology, TRCRef
from rush._trc import to_residues_vobj, to_topology_vobj

from .._utils import optional_str
from ..client import (
    RunOpts,
    RunSpec,
    RushObject,
    _get_project_id,
    _submit_rex,
    fetch_object,
)
from ..run import RushRun
from ._energy import (
    AuxBasisT,
    BasisT,
    FragKeywords,
    KSDFTKeywords,
    MethodT,
    SCFKeywords,
    StandardOrientationT,
    System,
    _KSDFTDefault,
)

# ---------------------------------------------------------------------------
# Input types
# ---------------------------------------------------------------------------


[docs] @dataclass class Trajectory: """ Configure the output of QMMM runs. By default, will provide all atoms at every frame. """ #: Save every n frames to the trajectory, where n is the interval specified. interval: int | None = None #: The frame at which to start the trajectory. start: int | None = None #: The frame at which to end the trajectory. end: int | None = None #: Whether to include waters in the trajectory. Convenient for reducing output size. include_waters: int | None = None def _to_rex(self): return Template( """Some (exess_qmmm_rex::MDTrajectory { format = None, interval = $maybe_interval, start = $maybe_start, end = $maybe_end, include_waters = $maybe_include_waters, })""" ).substitute( maybe_interval=optional_str(self.interval), maybe_start=optional_str(self.start), maybe_end=optional_str(self.end), maybe_include_waters=optional_str(self.include_waters), )
[docs] @dataclass class Restraints: """ Restrain atoms using an external force proportional to its distance from its original position, scaled by `k` (larger values mean a stronger restraint). All atoms can be fixed by specifying `free_atoms = []`. """ #: Scaling factor for restraints (larger values mean a stronger restraint). k: float | None = None #: Which atoms to hold fixed. All fixed/free parameters are mutually exclusive. fixed_atoms: list[int] | None = None #: Which atoms to keep unfixed. All fixed/free parameters are mutually exclusive. free_atoms: list[int] | None = None #: Which fragments to hold fixed. All fixed/free parameters are mutually exclusive. fixed_fragments: list[int] | None = None #: Which fragments to keep unfixed. All fixed/free parameters are mutually exclusive. free_fragments: list[int] | None = None #: Flag to easily enable fixing all heavy atoms only. Mutually exclusive with fixed/free parameters. fix_heavy: bool | None = None def _to_rex(self): return Template( """Some (exess_rex::Restraints { k = $maybe_k, fixed_atoms = $maybe_fixed_atoms, free_atoms = $maybe_free_atoms, fixed_fragments = $maybe_fixed_fragments, free_fragments = $maybe_free_fragments, fix_heavy = $maybe_fix_heavy, })""" ).substitute( maybe_k=optional_str(self.k), maybe_fixed_atoms=optional_str(self.fixed_atoms), maybe_free_atoms=optional_str(self.free_atoms), maybe_fixed_fragments=optional_str(self.fixed_fragments), maybe_free_fragments=optional_str(self.free_fragments), maybe_fix_heavy=optional_str(self.fix_heavy), )
# ---------------------------------------------------------------------------
# Result types
# ---------------------------------------------------------------------------
@dataclass
class QMMMResult:
    """Parsed QM/MM outputs, built from the JSON object a run writes."""

    #: Geometries decoded from the output JSON. Presumably one coordinate list
    #: per saved trajectory frame (TODO: confirm against the EXESS schema).
    geometries: list[list[float]]
@dataclass(frozen=True)
class QMMMResultPaths:
    """Immutable record of where QM/MM outputs were saved locally."""

    #: Local path of the saved output object.
    output: Path
@dataclass(frozen=True)
class QMMMResultRef:
    """Lightweight reference to QM/MM outputs in the Rush object store."""

    #: Handle to the run's output object (decoded as JSON by :meth:`fetch`).
    output: RushObject

    @classmethod
    def from_raw_output(cls, res: Any) -> "QMMMResultRef":
        """Parse raw ``collect_run`` output into a ``QMMMResultRef``.

        Raises:
            ValueError: If *res* is not a dict, or is a dict without a string
                ``"path"`` entry.
        """
        if not isinstance(res, dict):
            raise ValueError(
                f"qmmm output received unexpected format: {type(res).__name__}"
            )
        # A dict is the right shape, but it must point at a stored object.
        if not isinstance(res.get("path"), str):
            raise ValueError(
                "qmmm output received unexpected format: expected a dict with a "
                f"string 'path' entry, got keys {sorted(map(str, res))}"
            )
        return cls(output=RushObject.from_dict(res))

    def fetch(self) -> QMMMResult:
        """Download QM/MM outputs and parse into Python objects.

        The output object is decoded as JSON and its top-level keys are passed
        as keyword arguments to :class:`QMMMResult`.
        """
        return QMMMResult(**json.loads(fetch_object(self.output.path)))

    def save(self) -> QMMMResultPaths:
        """Download QM/MM outputs and save to the workspace."""
        return QMMMResultPaths(output=self.output.save())
# ---------------------------------------------------------------------------
# Submission
# ---------------------------------------------------------------------------
[docs] def qmmm( mol: TRC | TRCRef | tuple[Path | str | RushObject | Topology, Path | str | RushObject | Residues] | Path | str | RushObject | Topology, n_timesteps: int, dt_ps: float = 2e-3, temperature_kelvin: float = 290.0, pressure_atm: float | None = None, restraints: Restraints | None = None, trajectory: Trajectory = Trajectory(), gradient_finite_difference_step_size: float | None = None, method: MethodT = "RestrictedKSDFT", basis: BasisT = "cc-pVDZ", aux_basis: AuxBasisT | None = None, standard_orientation: StandardOrientationT | None = None, force_cartesian_basis_sets: bool | None = None, scf_keywords: SCFKeywords | None = None, frag_keywords: FragKeywords = FragKeywords(), ksdft_keywords: KSDFTKeywords | _KSDFTDefault | None = _KSDFTDefault.DEFAULT, qm_fragments: list[int] | None = None, mm_fragments: list[int] | None = None, system: System | None = None, run_spec: RunSpec = RunSpec(gpus=1), run_opts: RunOpts = RunOpts(), ) -> RushRun[QMMMResultRef]: """ Submit a QM/MM simulation for the topology at *topology_path*. Returns a :class:`~rush.run.RushRun` handle. Call ``.fetch()`` to get the parsed trajectory, or ``.save()`` to write it to disk. 
""" ksdft_keywords = KSDFTKeywords.resolve(ksdft_keywords, method) # Upload inputs residues_vobj = None match mol: case TRC() | TRCRef(): topology_vobj = to_topology_vobj(mol.topology) residues_vobj = to_residues_vobj(mol.residues) case (t, r): topology_vobj = to_topology_vobj(t) residues_vobj = to_residues_vobj(r) case _: topology_vobj = to_topology_vobj(mol) # Run rex rex = Template("""let obj_j = λ j → VirtualObject { path = j, format = ObjectFormat::json, size = 0 }, exess = λ topology residues → exess_qmmm_rex_s ($run_spec) (exess_qmmm_rex::QMMMParams { schema_version = "0.2.0", model = Some (exess_qmmm_rex::Model { method = exess_qmmm_rex::Method::$method, basis = "$basis", aux_basis = $maybe_aux_basis, standard_orientation = $maybe_standard_orientation, force_cartesian_basis_sets = $maybe_force_cartesian_basis_sets, }), system = $system, keywords = exess_qmmm_rex::Keywords { scf = $maybe_scf_keywords, ks_dft = $maybe_ks_keywords, rtat = None, frag = $maybe_frag_keywords, boundary = None, log = None, dynamics = None, integrals = None, debug = None, export = None, guess = None, force_field = None, optimization = None, hessian = None, gradient = Some (exess_qmmm_rex::GradientKeywords { finite_difference_step_size = $maybe_gradient_finite_difference_step_size, method = Some exess_qmmm_rex::DerivativesMethod::Analytical, }), qmmm = Some (exess_qmmm_rex::QMMMKeywords { n_timesteps = $n_timesteps, dt_ps = $dt_ps, temperature_kelvin = $temperature_kelvin, pressure_atm = $maybe_pressure_atm, minimisation = None, trajectory = $trajectory, restraints = $maybe_restraints, energy_csv = None, }), machine_learning = None, regions = $maybe_regions, }, }) (obj_j topology) (Some (obj_j residues)) in exess "$topology_vobj_path" "$residues_vobj_path" """).substitute( run_spec=run_spec._to_rex(), method=method, basis=basis, maybe_aux_basis=optional_str(aux_basis), maybe_standard_orientation=optional_str( standard_orientation, "exess_rex::StandardOrientation::" ), 
maybe_force_cartesian_basis_sets=optional_str(force_cartesian_basis_sets), system=system._to_rex() if system is not None else "None", maybe_scf_keywords=( scf_keywords._to_rex() if scf_keywords is not None else "None" ), maybe_ks_keywords=( ksdft_keywords._to_rex() if ksdft_keywords is not None else "None" ), maybe_frag_keywords=( frag_keywords._to_rex() if frag_keywords is not None else "None" ), maybe_gradient_finite_difference_step_size=optional_str( gradient_finite_difference_step_size ), n_timesteps=n_timesteps, dt_ps=dt_ps, temperature_kelvin=temperature_kelvin, maybe_pressure_atm=optional_str(pressure_atm), trajectory=trajectory._to_rex(), maybe_restraints=restraints._to_rex() if restraints is not None else "None", maybe_regions=( Template( """Some (exess_qmmm_rex::RegionKeywords { qm_fragments = $maybe_qm_fragments, mm_fragments = $maybe_mm_fragments, ml_fragments = Some [], })""" ).substitute( maybe_qm_fragments=optional_str(qm_fragments), maybe_mm_fragments=optional_str(mm_fragments), ) if not (qm_fragments is None and mm_fragments is None) else "None" ), topology_vobj_path=topology_vobj["path"], residues_vobj_path=residues_vobj["path"] if residues_vobj is not None else "", ) try: return RushRun( _submit_rex(_get_project_id(), rex, run_opts), QMMMResultRef, ) except TransportQueryError as e: if e.errors: for error in e.errors: print(f"Error: {error['message']}", file=sys.stderr) raise