# Source code for rush.pbsa

"""
PBSA module for the Rush Python client.

Computes solvation energies using the Poisson-Boltzmann Surface Area method.

Usage::

    from rush import pbsa

    result = pbsa.solvation_energy("mol.json", ...).fetch()
    print(result.solvation_energy)
"""

import sys
from dataclasses import asdict, dataclass
from pathlib import Path
from string import Template
from typing import Any

from gql.transport.exceptions import TransportQueryError

from rush import TRC, Topology, TRCRef
from rush._trc import to_topology_vobj

from ._utils import float_to_str
from .client import (
    RunOpts,
    RunSpec,
    RushObject,
    _get_project_id,
    _json_content_name,
    _submit_rex,
    save_json,
)
from .run import RushRun

# ---------------------------------------------------------------------------
# Result types
# ---------------------------------------------------------------------------


@dataclass
class Result:
    """Solvation energy breakdown produced by a PBSA calculation.

    Every field is an energy expressed in Hartrees.
    """

    # Solvation energy value reported by PBSA.
    solvation_energy: float
    # Polar component of the solvation energy.
    polar_solvation_energy: float
    # Nonpolar component of the solvation energy.
    nonpolar_solvation_energy: float
@dataclass(frozen=True)
class ResultPaths:
    """Immutable record of where saved PBSA output lives in the workspace."""

    # Path of the JSON file holding the saved PBSA results.
    output: Path
@dataclass(frozen=True)
class ResultRef:
    """Lightweight reference to PBSA output.

    PBSA results are small enough to be returned inline (three floats), so no
    object store download is needed: ``fetch`` and ``save`` work directly on
    the values stored on this object.
    """

    # Solvation energy (first element of the raw run output).
    solvation_energy: float
    # Polar solvation energy (second element of the raw run output).
    polar_solvation_energy: float
    # Nonpolar solvation energy (third element of the raw run output).
    nonpolar_solvation_energy: float

    @classmethod
    def from_raw_output(cls, res: Any) -> "ResultRef":
        """Parse raw ``collect_run`` output into a ``ResultRef``.

        Args:
            res: Raw run output. Expected to be a list (or tuple) of exactly
                three numeric values, in the order solvation, polar,
                nonpolar.

        Returns:
            A ``ResultRef`` holding the three values converted to ``float``.

        Raises:
            ValueError: If ``res`` is not a three-element list/tuple, or if
                any element cannot be converted to ``float``.
        """
        if not isinstance(res, (list, tuple)) or len(res) != 3:
            count = len(res) if hasattr(res, "__len__") else "?"
            raise ValueError(
                f"pbsa should return exactly 3 float outputs, "
                f"got {type(res).__name__} with {count} items."
            )
        try:
            solvation, polar, nonpolar = (float(value) for value in res)
        except (TypeError, ValueError) as err:
            # float(None) raises TypeError rather than ValueError; normalize so
            # callers only need to handle ValueError for any malformed payload.
            raise ValueError(
                f"pbsa outputs must all be numeric, got {res!r}."
            ) from err
        return cls(
            solvation_energy=solvation,
            polar_solvation_energy=polar,
            nonpolar_solvation_energy=nonpolar,
        )

    def fetch(self) -> "Result":
        """Return parsed PBSA results (no download needed — data is inline).

        Returns:
            A new ``Result`` carrying the same three energy values.
        """
        return Result(
            solvation_energy=self.solvation_energy,
            polar_solvation_energy=self.polar_solvation_energy,
            nonpolar_solvation_energy=self.nonpolar_solvation_energy,
        )

    def save(self) -> "ResultPaths":
        """Save PBSA results as JSON to the workspace.

        The saved JSON object is ``asdict`` of the ``Result`` returned by
        ``fetch``. The file name is derived via
        ``_json_content_name("pbsa_output", ...)`` from that same content.

        Returns:
            A ``ResultPaths`` whose ``output`` is the value returned by
            ``save_json``.
        """
        output_json = asdict(self.fetch())
        return ResultPaths(
            output=save_json(
                output_json,
                name=_json_content_name("pbsa_output", output_json),
            ),
        )
# ---------------------------------------------------------------------------
# Submission
# ---------------------------------------------------------------------------
def solvation_energy(
    mol: TRC | TRCRef | Path | str | RushObject | Topology,
    solute_dielectric: float,
    solvent_dielectric: float,
    solvent_radius: float,
    ion_concentration: float,
    temperature: float,
    spacing: float,
    sasa_gamma: float,
    sasa_beta: float,
    sasa_n_samples: int,
    convergence: float,
    box_size_factor: float,
    run_spec: RunSpec = RunSpec(gpus=1),
    run_opts: RunOpts = RunOpts(),
) -> RushRun[ResultRef]:
    """
    Submit a PBSA solvation energy calculation for the topology given by *mol*.

    *mol* is converted with ``to_topology_vobj``. The resulting object's path
    is passed to the ``pbsa_rex_s`` rex function, together with a
    ``pbsa_rex::PBSAParameters`` record built from the remaining arguments.

    Args:
        mol: Molecule/topology input; any form accepted by
            ``to_topology_vobj``.
        solute_dielectric, solvent_dielectric, solvent_radius,
        ion_concentration, temperature, spacing, sasa_gamma, sasa_beta,
        convergence, box_size_factor: Same-named fields of
            ``pbsa_rex::PBSAParameters``, rendered with ``float_to_str``.
            NOTE(review): units are defined by the PBSA service and are not
            stated here; confirm and document them.
        sasa_n_samples: Same-named ``PBSAParameters`` field, inserted into
            the rex unformatted.
        run_spec: Resources for the run (defaults to one GPU).
        run_opts: Options forwarded to ``_submit_rex``.

    Returns:
        A :class:`~rush.run.RushRun` handle. Call ``.fetch()`` to get the
        parsed result, or ``.save()`` to write it to disk as JSON.

    Raises:
        TransportQueryError: If submission fails. Any server-reported error
            messages are printed to stderr before the exception is re-raised.
    """
    # Upload inputs
    topology_vobj = to_topology_vobj(mol)

    # Build the rex program that wraps the topology path as a JSON virtual
    # object and calls the PBSA function with the given parameters.
    rex = Template(
        """let
  obj_j = λ j → VirtualObject { path = j, format = ObjectFormat::json, size = 0 },
  pbsa = λ topology →
    pbsa_rex_s ($run_spec) (pbsa_rex::PBSAParameters {
      solute_dielectric = $solute_dielectric,
      solvent_dielectric = $solvent_dielectric,
      solvent_radius = $solvent_radius,
      ion_concentration = $ion_concentration,
      temperature = $temperature,
      spacing = $spacing,
      sasa_gamma = $sasa_gamma,
      sasa_beta = $sasa_beta,
      sasa_n_samples = $sasa_n_samples,
      convergence = $convergence,
      box_size_factor = $box_size_factor,
    }) (obj_j topology)
in
  pbsa "$topology_vobj_path"
"""
    ).substitute(
        run_spec=run_spec._to_rex(),
        solute_dielectric=float_to_str(solute_dielectric),
        solvent_dielectric=float_to_str(solvent_dielectric),
        solvent_radius=float_to_str(solvent_radius),
        ion_concentration=float_to_str(ion_concentration),
        temperature=float_to_str(temperature),
        spacing=float_to_str(spacing),
        sasa_gamma=float_to_str(sasa_gamma),
        sasa_beta=float_to_str(sasa_beta),
        sasa_n_samples=sasa_n_samples,
        convergence=float_to_str(convergence),
        box_size_factor=float_to_str(box_size_factor),
        topology_vobj_path=topology_vobj["path"],
    )

    try:
        return RushRun(
            _submit_rex(_get_project_id(), rex, run_opts),
            ResultRef,
        )
    except TransportQueryError as e:
        # Surface server-side messages before propagating. Tolerate entries
        # that are not dicts or lack a "message" key, so reporting can never
        # raise (e.g. KeyError) and mask the original TransportQueryError.
        for error in e.errors or []:
            message = error.get("message", error) if isinstance(error, dict) else error
            print(f"Error: {message}", file=sys.stderr)
        raise