Source code for rush.smol_similarity

"""
smol-similarity module for the Rush Python client.

This module runs nearest-neighbor SMILES similarity search for one or more query
SMILES against one or more partition objects.

Usage::

    from rush import smol_similarity

    run = smol_similarity.smol_similarity_sumo(
        smol_partitions=["partition_a.json", "partition_b.json"],
        input_smis="queries.json",
    )
    result = run.fetch()
"""

import sys
from collections.abc import Iterator
from dataclasses import dataclass
from pathlib import Path
from string import Template
from typing import Any, Literal, Self

from gql.transport.exceptions import TransportQueryError

from ._rex import optional_str
from .objects import RushObject, upload_object
from .runs import Run, RunOpts, RunSpec
from .session import _submit_rex


# ---------------------------------------------------------------------------
# Input types
# ---------------------------------------------------------------------------


[docs] @dataclass(frozen=True) class SmolSimilarityConfig: """Optional runtime configuration for smol-similarity. Attributes: min_similarity: Optional lower bound for tanimoto similarity in [0, 1]. min_results: Minimum number of rows returned per query. max_results: Maximum number of rows returned per query. """ min_similarity: float | None = None min_results: int | None = None max_results: int | None = None
# ---------------------------------------------------------------------------
# Result/error types
# ---------------------------------------------------------------------------
[docs] @dataclass(frozen=True) class ExecutionError: """Per-query execution failure for a single input SMILES.""" stage: Literal["Subprocess", "OutputParse"] message: str @classmethod def from_raw(cls, raw: Any) -> Self: if not isinstance(raw, dict): raise ValueError( "smol_similarity_sumo item Err must be a dict, " f"got {type(raw).__name__}" ) stage = raw.get("stage") message = raw.get("message") if stage not in ("Subprocess", "OutputParse") or not isinstance(message, str): raise ValueError( "smol_similarity_sumo item Err missing valid stage/message fields" ) return cls(stage=stage, message=message)
@dataclass(frozen=True)
class Result:
    """Similarity hits parsed for one query SMILES.

    Attributes:
        smiles: SMILES strings returned for the query.
        similarities: Similarity values returned for the query.
    """

    smiles: list[str]
    similarities: list[float]
@dataclass(frozen=True)
class ResultPaths:
    """Location of a per-query result object saved to the workspace.

    Attributes:
        output: Path of the saved result object.
    """

    output: Path
@dataclass(frozen=True)
class ItemResultRef:
    """Reference to one per-query result, which may be success or error.

    Attributes:
        output: Descriptor of the stored result object when the query
            succeeded, otherwise ``None``.
        error: Parsed failure when the query failed, otherwise ``None``.
    """

    output: RushObject | None
    error: ExecutionError | None

    @classmethod
    def from_raw_output(cls, raw: Any) -> Self:
        """Parse one raw ``{"Ok": ...}`` / ``{"Err": ...}`` item.

        Args:
            raw: A single-key dict whose key is ``"Ok"`` (object descriptor
                dict) or ``"Err"`` (error payload).

        Returns:
            A reference holding either the output object or the error.

        Raises:
            ValueError: If ``raw`` is not a single-key dict, the key is
                neither ``"Ok"`` nor ``"Err"``, or the value is malformed.
        """
        if not isinstance(raw, dict) or len(raw) != 1:
            raise ValueError(
                "smol_similarity_sumo item output should be Result{Ok|Err}, "
                f"got {type(raw).__name__}"
            )
        if "Ok" in raw:
            value = raw["Ok"]
            if not isinstance(value, dict):
                raise ValueError(
                    "smol_similarity_sumo item Ok must be an object descriptor dict"
                )
            return cls(output=RushObject.from_dict(value), error=None)
        if "Err" in raw:
            return cls(output=None, error=ExecutionError.from_raw(raw["Err"]))
        raise ValueError("smol_similarity_sumo item output must have Ok or Err")

    def _require_output(self) -> RushObject:
        """Return ``self.output``, raising ``ValueError`` when it is ``None``."""
        if self.output is None:
            raise ValueError("smol_similarity_sumo item missing both output and error")
        return self.output

    def fetch(self) -> Result | ExecutionError:
        """Download and parse this item's result.

        Returns:
            The stored ``ExecutionError`` when the query failed; otherwise a
            ``Result`` built from the ``[smiles, similarities]`` payload.

        Raises:
            ValueError: If neither output nor error is set, or the payload is
                not two lists of equal length.
        """
        if self.error is not None:
            return self.error
        payload = self._require_output().fetch_list()
        if len(payload) != 2:
            raise ValueError(
                "smol_similarity_sumo item payload must be [smiles, similarities], "
                f"got list with {len(payload)} items"
            )
        raw_smiles, raw_similarities = payload
        if not isinstance(raw_smiles, list) or not isinstance(raw_similarities, list):
            raise ValueError(
                "smol_similarity_sumo item payload entries must both be lists"
            )
        # The two lists are parallel (one similarity per hit); a length
        # mismatch would silently misalign hits and scores downstream.
        if len(raw_smiles) != len(raw_similarities):
            raise ValueError(
                "smol_similarity_sumo item payload lists must have equal length, "
                f"got {len(raw_smiles)} smiles and "
                f"{len(raw_similarities)} similarities"
            )
        smiles = [str(smi) for smi in raw_smiles]
        similarities = [float(similarity) for similarity in raw_similarities]
        return Result(smiles=smiles, similarities=similarities)

    def save(self) -> ResultPaths | ExecutionError:
        """Save this item's result object.

        Returns:
            The stored ``ExecutionError`` when the query failed; otherwise a
            ``ResultPaths`` wrapping the value returned by the object's
            ``save()``.

        Raises:
            ValueError: If neither output nor error is set.
        """
        if self.error is not None:
            return self.error
        return ResultPaths(output=self._require_output().save())
@dataclass(frozen=True)
class ResultRef:
    """Ordered collection of per-query result references for one run.

    Supports ``len()``, integer indexing, and iteration over ``items``.

    Attributes:
        items: One ``ItemResultRef`` per entry of the raw run output.
    """

    items: list[ItemResultRef]

    def __iter__(self) -> Iterator[ItemResultRef]:
        return iter(self.items)

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, index: int) -> ItemResultRef:
        return self.items[index]

    @classmethod
    def from_raw_output(cls, res: Any) -> Self:
        """Parse raw ``collect_run`` output into a ``ResultRef``.

        Raises:
            ValueError: If ``res`` is not a non-empty list, or any item fails
                to parse.
        """
        if isinstance(res, list) and len(res) > 0:
            return cls(items=[ItemResultRef.from_raw_output(entry) for entry in res])
        item_count = len(res) if isinstance(res, list) else "?"
        raise ValueError(
            "smol_similarity_sumo should return a non-empty list of per-query results, "
            f"got {type(res).__name__} with {item_count} items"
        )

    def fetch(self) -> list[Result | ExecutionError]:
        """Download outputs and parse per-query results/errors."""
        return [ref.fetch() for ref in self.items]

    def save(self) -> list[ResultPaths | ExecutionError]:
        """Download outputs and save per-query result objects to the workspace."""
        return [ref.save() for ref in self.items]
def _to_uploaded_json_object(value: Path | str | RushObject) -> dict[str, Any]:
    """Return an object descriptor dict for *value*.

    A ``RushObject`` is converted with its own ``to_dict()``; any other value
    is passed to ``upload_object`` and that call's result is returned.
    """
    return value.to_dict() if isinstance(value, RushObject) else upload_object(value)


def _config_to_rex(config: SmolSimilarityConfig | None) -> str:
    """Render *config* as the Rex option expression used at submission.

    Returns the literal ``"None"`` when no config is given; otherwise a
    ``Some (...)`` record whose three fields are rendered with
    ``optional_str``.
    """
    if config is not None:
        rendered_fields = {
            "min_similarity": optional_str(config.min_similarity),
            "min_results": optional_str(config.min_results),
            "max_results": optional_str(config.max_results),
        }
        record_template = Template(
            "Some (smol_similarity_sumo::SmolSimilarityConfig { "
            "min_similarity = $min_similarity, "
            "min_results = $min_results, "
            "max_results = $max_results })"
        )
        return record_template.substitute(rendered_fields)
    return "None"


# ---------------------------------------------------------------------------
# Submission
# ---------------------------------------------------------------------------
def smol_similarity_sumo(
    smol_partitions: list[Path | str | RushObject] | Path | str | RushObject,
    input_smis: Path | str | RushObject,
    config: SmolSimilarityConfig | None = None,
    run_spec: RunSpec = RunSpec(cpus=4, gpus=0),
    run_opts: RunOpts = RunOpts(),
) -> Run[ResultRef]:
    """Submit a smol-similarity run.

    Args:
        smol_partitions: JSON object files, each containing a list of library
            SMILES. A single path/object is accepted and treated as a
            one-element list.
        input_smis: JSON object file containing query SMILES.
        config: Optional search controls.
        run_spec: Resource specification rendered into the submitted
            expression via ``run_spec._to_rex()``.
        run_opts: Options forwarded unchanged to ``_submit_rex``.

    Returns:
        A :class:`~rush.runs.Run` handle. Call ``.collect()`` for
        :class:`ResultRef`, then ``.fetch()`` or ``.save()``.

    Raises:
        ValueError: If ``smol_partitions`` is empty.
        TransportQueryError: Re-raised from submission after each reported
            error has been written to stderr.
    """
    # A bare str would otherwise be iterated character by character, and a
    # bare Path/RushObject is not iterable at all.
    if isinstance(smol_partitions, (str, Path, RushObject)):
        smol_partitions = [smol_partitions]
    # Fail before uploading anything: the search needs at least one library.
    if not smol_partitions:
        raise ValueError("smol_similarity_sumo requires at least one smol partition")

    partition_objects = [
        _to_uploaded_json_object(partition) for partition in smol_partitions
    ]
    input_smis_object = _to_uploaded_json_object(input_smis)

    partition_objects_rex = ", ".join(
        f'VirtualObject {{ path = "{obj["path"]}", format = ObjectFormat::json, size = 0 }}'
        for obj in partition_objects
    )
    rex = Template(
        "let cfg = $config, "
        "obj_j = λ j → VirtualObject { path = j, format = ObjectFormat::json, size = 0 } "
        "in smol_similarity_sumo_s ($run_spec) cfg [$smol_partition_objects] "
        '(obj_j "$input_smis_path") '
    ).substitute(
        run_spec=run_spec._to_rex(),
        config=_config_to_rex(config),
        smol_partition_objects=partition_objects_rex,
        input_smis_path=input_smis_object["path"],
    )

    try:
        return Run(_submit_rex(rex, run_opts), ResultRef)
    except TransportQueryError as e:
        for error in e.errors or ():
            # Entries are not guaranteed to be dicts with a "message" key;
            # a failed lookup here must not mask the original exception.
            message = error.get("message", error) if isinstance(error, dict) else error
            print(f"Error: {message}", file=sys.stderr)
        raise