import sys
from string import Template
from gql.transport.exceptions import TransportQueryError
from rush.client import (
RunError,
RunOpts,
RunSpec,
_get_project_id,
_submit_rex,
collect_run,
)
from rush.utils import bool_to_str, float_to_str
[docs]
def auto3d(
smis: list[str],
k: int = 1,
batchsize_atoms: int = 1024,
capacity: int = 40,
convergence_threshold: float = 0.003,
enumerate_isomer: bool = True,
enumerate_tautomer: bool = False,
max_confs: int | None = None,
opt_steps: int = 5000,
patience: int = 1000,
threshold: float = 0.3,
run_spec: RunSpec = RunSpec(),
run_opts: RunOpts = RunOpts(),
collect=False,
):
"""
Runs Auto3D on a list of SMILES strings, returning either the TRC structure
or an error string.
"""
rex = Template("""let
auto3d = λ smis →
try_auto3d_rex
default_runspec_gpu
(auto3d_rex::Auto3dOptions {
k = Some (int $k),
batchsize_atoms = Some $batchsize_atoms,
capacity = Some $capacity,
convergence_threshold = Some $convergence_threshold,
enumerate_isomer = Some $enumerate_isomer,
enumerate_tautomer = Some $enumerate_tautomer,
job_name = None,
max_confs = $max_confs,
memory = None,
mpi_np = Some 4,
opt_steps = Some $opt_steps,
optimizing_engine = Some auto3d_rex::Auto3dOptimizingEngines::AIMNET,
patience = Some $patience,
threshold = Some $threshold,
verbose = Some false,
window = None,
})
$smis
in
auto3d $smis
""").substitute(
smis=f"[{', '.join([f'"{smi}"' for smi in smis])}]",
k=k,
batchsize_atoms=batchsize_atoms,
capacity=capacity,
convergence_threshold=float_to_str(convergence_threshold),
enumerate_isomer=bool_to_str(enumerate_isomer),
enumerate_tautomer=bool_to_str(enumerate_tautomer),
max_confs=max_confs,
opt_steps=opt_steps,
patience=patience,
threshold=float_to_str(threshold),
run_spec=run_spec._to_rex(),
)
try:
run_id = _submit_rex(_get_project_id(), rex, run_opts)
if not collect:
return run_id
result = collect_run(run_id)
if isinstance(result, RunError):
return result
def is_result_type(result):
return (
isinstance(result, dict)
and len(result) == 1
and ("Ok" in result or "Err" in result)
)
# TODO: no special cases for Result unwrapping
return [
next(iter(r_i.values())) if is_result_type(r_i) else r_i for r_i in result
]
except TransportQueryError as e:
if e.errors:
print("Error:", file=sys.stderr)
for error in e.errors:
print(f" {error['message']}", file=sys.stderr)