# Source code for rush.auto3d

import sys
from string import Template

from gql.transport.exceptions import TransportQueryError

from rush.client import (
    RunError,
    RunOpts,
    RunSpec,
    _get_project_id,
    _submit_rex,
    collect_run,
)
from rush.utils import bool_to_str, float_to_str


def _rex_option(value) -> str:
    """Render an optional Python value as a Rex ``Option`` literal."""
    return "None" if value is None else f"Some {value}"


def _unwrap_result(item):
    """Return the payload of a single-key ``{"Ok": ...}``/``{"Err": ...}`` dict.

    Anything that is not such a dict is returned unchanged.
    """
    # TODO: no special cases for Result unwrapping
    if isinstance(item, dict) and len(item) == 1 and ("Ok" in item or "Err" in item):
        return next(iter(item.values()))
    return item


def auto3d(
    smis: list[str],
    k: int = 1,
    batchsize_atoms: int = 1024,
    capacity: int = 40,
    convergence_threshold: float = 0.003,
    enumerate_isomer: bool = True,
    enumerate_tautomer: bool = False,
    max_confs: int | None = None,
    opt_steps: int = 5000,
    patience: int = 1000,
    threshold: float = 0.3,
    run_spec: RunSpec = RunSpec(),
    run_opts: RunOpts = RunOpts(),
    collect: bool = False,
):
    """
    Runs Auto3D on a list of SMILES strings, returning either the TRC structure
    or an error string.

    The keyword options ``k`` through ``threshold`` are rendered into the
    ``auto3d_rex::Auto3dOptions`` record of the submitted Rex program under the
    same field names. ``max_confs=None`` is rendered as ``None``; any other
    value is rendered as ``Some <value>``.

    Args:
        smis: SMILES strings; each is emitted as a double-quoted Rex string.
        run_spec: Rendered via ``_to_rex()`` but currently not referenced by
            the Rex template (see the review note in the body).
        run_opts: Passed through to ``_submit_rex``.
        collect: When false, return right after submission; when true, wait
            for the run via ``collect_run`` and return its results.

    Returns:
        * the run id returned by ``_submit_rex`` when ``collect`` is false;
        * the ``RunError`` when ``collect_run`` returns one;
        * otherwise a list with one entry per collected result, where
          single-key ``{"Ok": ...}`` / ``{"Err": ...}`` dicts are replaced by
          their payload;
        * ``None`` when a ``TransportQueryError`` was caught (the error is
          reported on stderr).
    """
    # Build the list literal without nesting identical quote characters inside
    # an f-string, which only parses on Python 3.12+.
    smis_rex = "[" + ", ".join(f'"{smi}"' for smi in smis) + "]"

    # NOTE(review): the template hard-codes `default_runspec_gpu` and contains
    # no `$run_spec` placeholder, so `run_spec` does not influence the
    # submitted program. Left as-is to keep behaviour unchanged - confirm
    # whether it should be wired in.
    rex = Template("""let
  auto3d = λ smis →
    try_auto3d_rex
      default_runspec_gpu
      (auto3d_rex::Auto3dOptions {
        k = Some (int $k),
        batchsize_atoms = Some $batchsize_atoms,
        capacity = Some $capacity,
        convergence_threshold = Some $convergence_threshold,
        enumerate_isomer = Some $enumerate_isomer,
        enumerate_tautomer = Some $enumerate_tautomer,
        job_name = None,
        max_confs = $max_confs,
        memory = None,
        mpi_np = Some 4,
        opt_steps = Some $opt_steps,
        optimizing_engine = Some auto3d_rex::Auto3dOptimizingEngines::AIMNET,
        patience = Some $patience,
        threshold = Some $threshold,
        verbose = Some false,
        window = None,
      })
      $smis
in
  auto3d $smis
""").substitute(
        smis=smis_rex,
        k=k,
        batchsize_atoms=batchsize_atoms,
        capacity=capacity,
        convergence_threshold=float_to_str(convergence_threshold),
        enumerate_isomer=bool_to_str(enumerate_isomer),
        enumerate_tautomer=bool_to_str(enumerate_tautomer),
        # Option-typed field: an int must be wrapped in `Some`, like its peers.
        max_confs=_rex_option(max_confs),
        opt_steps=opt_steps,
        patience=patience,
        threshold=float_to_str(threshold),
        run_spec=run_spec._to_rex(),
    )

    try:
        run_id = _submit_rex(_get_project_id(), rex, run_opts)
        if not collect:
            return run_id
        result = collect_run(run_id)
    except TransportQueryError as e:
        if e.errors:
            print("Error:", file=sys.stderr)
            for error in e.errors:
                print(f" {error['message']}", file=sys.stderr)
        else:
            # Do not swallow the failure silently when the server supplied no
            # structured error list.
            print(f"Error: {e}", file=sys.stderr)
        return None

    if isinstance(result, RunError):
        return result
    return [_unwrap_result(item) for item in result]