Source code for rush.hyper._hyper_solvate_sumo

"""Hyper solvation wrapper for the Rush Python client."""

from __future__ import annotations

import sys
from dataclasses import dataclass
from pathlib import Path
from string import Template
from typing import Any, Self

from gql.transport.exceptions import TransportQueryError

from .._rex import optional_str
from ..mol import TRC
from ..objects import RushObject
from ..runs import Run, RunOpts, RunSpec
from ..session import _submit_rex
from ._common import (
    ItemError,
    TRCInput,
    _fetch_trc,
    _format_rex_list,
    _parse_batch_outputs,
    _to_rex_json_obj,
    _upload_trc_object,
)


@dataclass(frozen=True)
class HyperConfig:
    """Config for :func:`hyper_solvate_sumo`.

    Immutable bundle of optional settings. Every field defaults to ``None``;
    the values are rendered into the Rex ``hyper_solvate_sumo::HyperConfig``
    record by ``_to_rex_solvate_config``.

    Attributes:
        max_inputs: Optional integer setting (presumably an upper bound on
            the number of inputs -- confirm against the Hyper service).
        padding_nm: Optional float setting (presumably solvent padding in
            nanometres -- confirm against the Hyper service).
        seed: Optional integer seed.
        timeout_seconds: Optional integer timeout, in seconds per its name.
    """

    max_inputs: int | None = None
    padding_nm: float | None = None
    seed: int | None = None
    timeout_seconds: int | None = None
[docs] @dataclass(frozen=True) class TRCBatchResultRef: """Result reference for TRC-producing Hyper batch entrypoints.""" items: list[RushObject | ItemError]
[docs] @classmethod def from_raw_output(cls, raw: Any) -> Self: parsed = _parse_batch_outputs(raw, _parse_trc_item, "hyper TRC batch") return cls(items=parsed)
def __getitem__(self, index: int) -> RushObject | ItemError: return self.items[index] def __len__(self) -> int: return len(self.items)
[docs] def fetch(self) -> list[TRC | ItemError]: out: list[TRC | ItemError] = [] for item in self.items: if isinstance(item, ItemError): out.append(item) else: out.append(_fetch_trc(item)) return out
[docs] def save(self) -> list[Path | ItemError]: return [ item if isinstance(item, ItemError) else item.save(ext="json") for item in self.items ]
def _parse_trc_item(raw: Any) -> RushObject: if not isinstance(raw, dict): raise ValueError(f"Expected TRC output object, got {type(raw).__name__}") return RushObject.from_dict(raw) def _to_rex_solvate_config(config: HyperConfig | None) -> str: if config is None: return "None" return Template( """Some (hyper_solvate_sumo::HyperConfig { max_inputs = $max_inputs, padding_nm = $padding_nm, seed = $seed, timeout_seconds = $timeout_seconds, })""" ).substitute( max_inputs=optional_str(config.max_inputs), padding_nm=optional_str(config.padding_nm), seed=optional_str(config.seed), timeout_seconds=optional_str(config.timeout_seconds), )
def hyper_solvate_sumo(
    input_trcs: list[TRCInput],
    config: HyperConfig | None = None,
    # NOTE(review): these defaults are built once, when the function is
    # defined, and shared by every call -- safe only if RunSpec/RunOpts are
    # never mutated; confirm.
    run_spec: RunSpec = RunSpec(target="Bullet"),
    run_opts: RunOpts = RunOpts(),
) -> Run[TRCBatchResultRef]:
    """Submit Hyper solvation for one or more TRC inputs.

    Each input is uploaded with ``_upload_trc_object`` and converted to a Rex
    JSON object expression; the resulting list, the run spec and the optional
    config are then substituted into a ``hyper_solvate_sumo_s`` Rex call,
    which is submitted via ``_submit_rex``.

    Args:
        input_trcs: TRC inputs to solvate, one batch item each.
        config: Optional solvation settings; ``None`` is sent as Rex ``None``.
        run_spec: Run specification rendered into the Rex call.
        run_opts: Options forwarded to the submission.

    Returns:
        A :class:`Run` whose result type is :class:`TRCBatchResultRef`.

    Raises:
        TransportQueryError: Re-raised unchanged after the messages of any
            errors attached to it have been printed to stderr.
    """
    input_exprs = [_to_rex_json_obj(_upload_trc_object(item)) for item in input_trcs]
    rex = Template(
        """hyper_solvate_sumo_s ($run_spec) ($config) $inputs"""
    ).substitute(
        run_spec=run_spec._to_rex(),
        config=_to_rex_solvate_config(config),
        inputs=_format_rex_list(input_exprs),
    )
    try:
        return Run(_submit_rex(rex, run_opts), TRCBatchResultRef)
    except TransportQueryError as e:
        for error in e.errors or []:
            # An entry without a "message" (or one that is not subscriptable)
            # must not raise here: that would mask the TransportQueryError
            # being reported. Fall back to printing the raw entry.
            try:
                message = error["message"]
            except (KeyError, TypeError, IndexError):
                message = error
            print(f"Error: {message}", file=sys.stderr)
        raise