#!/usr/bin/env python3
"""
NN-xTB module helpers for the Rush Python client.
NN-xTB reparameterizes xTB with a neural network to approach DFT-level accuracy
while keeping xTB-like speed. The method supports arbitrary charge and spin
states, although the `nnxtb` helper below currently always submits a neutral
system (charge 0) and exposes only the spin multiplicity. It is well suited to
large-scale screening where fast per-atom forces or vibrational frequencies are
needed. Frequency calculations are more expensive.
"""
import sys
from dataclasses import dataclass
from pathlib import Path
from string import Template
from gql.transport.exceptions import TransportQueryError
from .client import (
RunOpts,
RunSpec,
_get_project_id,
_submit_rex,
collect_run,
upload_object,
)
from .utils import optional_str
[docs]
@dataclass
class NnxtbResults:
"""
Parsed nn-xTB results.
Use this to load JSON output from the Rush object store. When calling
`nnxtb(..., collect=True)`, the return value includes a `path` to the JSON
output. After reading the json into a dict, you can pass it to this class
like `NnxtbResults(**data)`.
"""
energy_mev: float
forces_mev_per_angstrom: list[tuple[float, float, float]] | None
frequencies_inv_cm: list[float] | None
def __init__(
self, energy_mev, forces_mev_per_angstrom=None, frequencies_inv_cm=None
):
self.energy_mev = energy_mev
self.forces_mev_per_angstrom = forces_mev_per_angstrom
self.frequencies_inv_cm = frequencies_inv_cm
def nnxtb(
    topology_path: Path | str,
    compute_forces: bool | None = None,
    compute_frequencies: bool | None = None,
    multiplicity: int | None = None,
    run_spec: RunSpec = RunSpec(gpus=1, storage=100),
    run_opts: RunOpts = RunOpts(),
    collect: bool = False,
):
    """
    Run NN-xTB on the system in the QDX topology file at `topology_path`.

    The topology file is uploaded, a Rex program calling `nnxtb_rex_s` is
    built from the arguments, and the program is submitted to Rush.

    Args:
        topology_path: Path to a TRC topology JSON file.
        compute_forces: Whether to compute per-atom forces.
            Defaults to true (None leaves the option unset).
        compute_frequencies: Whether to compute vibrational frequencies.
            Defaults to false (None leaves the option unset).
        multiplicity: Spin multiplicity. Defaults to 1 (singlet).
        run_spec: Rush compute resources to request.
        run_opts: Rush run metadata.
        collect: Whether to wait for completion and return outputs.

    Returns:
        `collect_run(run_id)` when `collect` is true, otherwise the submitted
        run's ID. If submission or collection raises a GraphQL
        `TransportQueryError`, the error message(s) are printed to stderr and
        None is returned.

    Note:
        The molecular charge is currently fixed at 0 (neutral); it is not
        exposed as a parameter.
    """
    # Upload inputs
    topology_vobj = upload_object(topology_path)

    # Charge is not configurable through this helper yet; always submit a
    # neutral system.
    charge = 0

    # Build the Rex program. `optional_str` renders each optional setting in
    # Rex syntax (presumably `Some ...` / `None` — see `.utils`).
    rex = Template("""let
obj_j = λ j →
VirtualObject { path = j, format = ObjectFormat::json, size = 0 },
nnxtb = λ topology →
nnxtb_rex_s
($run_spec)
(nnxtb_rex::NnxtbConfig {
compute_forces = $maybe_compute_forces,
compute_frequencies = $maybe_compute_frequencies,
charge = $maybe_charge,
multiplicity = $maybe_multiplicity,
})
(obj_j topology)
in
nnxtb "$topology_vobj_path"
""").substitute(
        run_spec=run_spec._to_rex(),
        maybe_compute_forces=optional_str(compute_forces),
        maybe_compute_frequencies=optional_str(compute_frequencies),
        maybe_charge=f"Some (int {charge})",
        maybe_multiplicity=optional_str(multiplicity),
        topology_vobj_path=topology_vobj["path"],
    )

    try:
        run_id = _submit_rex(_get_project_id(), rex, run_opts)
        return collect_run(run_id) if collect else run_id
    except TransportQueryError as e:
        if e.errors:
            for error in e.errors:
                print(f"Error: {error['message']}", file=sys.stderr)
        else:
            # No structured GraphQL errors attached: report the exception text
            # so the failure is not swallowed without any output.
            print(f"Error: {e}", file=sys.stderr)
        return None