#!/usr/bin/env python3
import json
import sys
from dataclasses import dataclass
from pathlib import Path
from string import Template
from tempfile import NamedTemporaryFile
from gql.transport.exceptions import TransportQueryError
from rush.convert import from_json, from_pdb
from .client import (
RunOpts,
RunSpec,
_get_project_id,
_submit_rex,
collect_run,
upload_object,
)
from .utils import dict_to_vec_of_tuples_str, optional_str
[docs]
@dataclass
class Modification:
    """A single modification entry for a protein sequence.

    Instances are referenced by ``ProteinSequence.modifications``; the class
    itself only stores the two values below and performs no validation.
    """

    # Location of the modification within the sequence.
    # NOTE(review): indexing base (0- or 1-based) is not shown here — confirm.
    position: int

    # Code identifying the modification.
    # NOTE(review): presumably a Chemical Component Dictionary id — confirm.
    ccd: str
[docs]
@dataclass
class ProteinSequence:
    """Protein entry of a Boltz-2 run, rendered as ``boltz2_rex::Sequence::Protein``.

    The MSA may be given either as an already-uploaded object (a mapping that
    has a ``"path"`` key) or as a local file path, in which case it is uploaded
    the first time the sequence is rendered.
    """

    # Chain identifiers; rendered as a rex list of double-quoted strings.
    id: list[str]
    # Sequence text, inserted verbatim between double quotes.
    sequence: str
    # Uploaded-object mapping (must contain "path"), or a local path to upload.
    msa: dict[str, str] | Path | str
    # NOTE: accepted but not serialized; `_to_rex` always emits `modifications = None`.
    modifications: list[Modification] | None = None
    # Rendered through `optional_str`.
    cyclic: bool | None = None

    def _to_rex(self):
        """Return the rex expression describing this protein sequence.

        Side effect: when ``msa`` is a ``Path`` or ``str`` it is uploaded with
        ``upload_object`` and ``self.msa`` is replaced by the upload result, so
        rendering the same instance again does not upload the file twice.

        A warning is emitted when ``modifications`` is non-empty, because the
        generated expression hard-codes ``modifications = None``.
        """
        import warnings

        if isinstance(self.msa, (Path, str)):
            # Keep the upload result on the instance so it is reused later.
            self.msa = upload_object(self.msa)

        if self.modifications:
            # Previously dropped without any notice; make the data loss visible.
            warnings.warn(
                "ProteinSequence.modifications is not serialized; "
                "the sequence is submitted with modifications = None",
                stacklevel=2,
            )

        # Built outside the outer f-string: nesting the same quote character
        # inside an f-string only parses on Python 3.12+.
        quoted_ids = ", ".join(f'"{chain_id}"' for chain_id in self.id)

        return Template(
            """(boltz2_rex::Sequence::Protein {
id = $id,
sequence = "$sequence",
msa = VirtualObject { path = "$msa", format = ObjectFormat::bin, size = 0 },
modifications = None,
cyclic = $cyclic,
})"""
        ).substitute(
            id=f"[{quoted_ids}]",
            sequence=self.sequence,
            msa=self.msa["path"],
            cyclic=optional_str(self.cyclic),
        )
[docs]
@dataclass
class LigandSequence:
    """Ligand entry of a Boltz-2 run, rendered as ``boltz2_rex::Sequence::Ligand``."""

    # Chain identifiers; rendered as a rex list of double-quoted strings.
    id: list[str]
    # SMILES text, inserted verbatim between double quotes.
    smiles: str

    def _to_rex(self):
        """Return the rex expression describing this ligand."""
        quoted_ids = ", ".join(f'"{chain_id}"' for chain_id in self.id)
        ligand_template = Template(
            """(boltz2_rex::Sequence::Ligand {
id = $id,
smiles = "$smiles",
})"""
        )
        return ligand_template.substitute(
            id="[" + quoted_ids + "]",
            smiles=self.smiles,
        )
[docs]
def _upload_template_trc(template_path):
    """Load a template structure and upload its topology, residues and chains.

    A file whose suffix is ``.pdb`` (any case) is parsed with ``from_pdb``;
    anything else is read as JSON and passed to ``from_json``. If the parser
    returns a list it must hold exactly one TRC.

    Returns:
        A ``(topology_path, residues_path, chains_path)`` tuple taken from the
        ``"path"`` entry of each ``upload_object`` result.

    Raises:
        ValueError: if the file yields a list that does not contain exactly
            one TRC.
    """
    template_path = Path(template_path)
    with open(template_path, encoding="utf-8") as f:
        if template_path.suffix.lower() == ".pdb":
            trc = from_pdb(f.read())
        else:
            trc = from_json(json.load(f))

    if isinstance(trc, list):
        if len(trc) != 1:
            raise ValueError(
                f"Expected 1 TRC in {template_path}, found {len(trc)}"
            )
        trc = trc[0]

    # NOTE: the temp files are re-opened by name while still open here, which
    # the tempfile documentation says is not supported on Windows.
    with (
        NamedTemporaryFile(mode="w") as t_f,
        NamedTemporaryFile(mode="w") as r_f,
        NamedTemporaryFile(mode="w") as c_f,
    ):
        for part, tmp in (
            (trc.topology, t_f),
            (trc.residues, r_f),
            (trc.chains, c_f),
        ):
            json.dump(part.to_json(), tmp)
            # The upload reads the file by name, so the buffered JSON has to
            # reach the disk first.
            tmp.flush()
        topology_vobj = upload_object(t_f.name)
        residues_vobj = upload_object(r_f.name)
        chains_vobj = upload_object(c_f.name)

    return topology_vobj["path"], residues_vobj["path"], chains_vobj["path"]


def boltz(
    sequences: list[ProteinSequence | LigandSequence],
    recycling_steps: int | None = None,
    sampling_steps: int | None = None,
    diffusion_samples: int | None = None,
    step_scale: float | None = None,
    affinity_binder_chain_id: str | None = None,
    affinity_mw_correction: bool | None = None,
    sampling_steps_affinity: int | None = None,
    diffusion_samples_affinity: int | None = None,
    max_msa_seqs: int | None = None,
    subsample_msa: bool | None = None,
    num_subsampled_msa: int | None = None,
    use_potentials: bool | None = None,
    seed: int | None = None,
    template_path: Path | str | None = None,
    template_threshold_angstroms: float | None = None,
    template_chain_mapping: dict[str, str] | None = None,
    run_spec: RunSpec = RunSpec(gpus=1),
    run_opts: RunOpts = RunOpts(),
    collect: bool = False,
):
    """Build and submit a Boltz-2 rex program for the given sequences.

    Args:
        sequences: Entries to predict; each is rendered with its ``_to_rex()``
            method. Must not be empty.
        recycling_steps, sampling_steps, diffusion_samples, step_scale,
        affinity_binder_chain_id, affinity_mw_correction,
        sampling_steps_affinity, diffusion_samples_affinity, max_msa_seqs,
        subsample_msa, num_subsampled_msa, use_potentials, seed,
        template_threshold_angstroms: Each value is passed through
            ``optional_str`` into the ``boltz2_rex::Boltz2Config`` field of
            the same name. Their meaning is defined by Boltz-2, not here.
        template_path: Optional structure template. ``.pdb`` files are parsed
            with ``from_pdb``, everything else as JSON via ``from_json``. Its
            topology, residues and chains are uploaded as JSON objects.
        template_chain_mapping: Optional mapping rendered with
            ``dict_to_vec_of_tuples_str`` and wrapped in ``Some``.
        run_spec: Rendered with ``run_spec._to_rex()`` into the program.
        run_opts: Passed to ``_submit_rex`` unchanged.
        collect: When true, return ``collect_run(run_id)`` instead of the id.

    Note:
        The ``run_spec`` and ``run_opts`` defaults are created once, when the
        function is defined, and are therefore shared by every call that does
        not pass its own.

    Returns:
        The run id from ``_submit_rex``, or the result of ``collect_run`` when
        ``collect`` is true. ``None`` if a ``TransportQueryError`` was raised;
        the error is printed to stderr in that case.

    Raises:
        ValueError: if ``sequences`` is empty, or the template file does not
            contain exactly one TRC.
    """
    if not sequences:
        # An empty list would render as "[ , ]"; fail before uploading anything.
        raise ValueError("boltz requires at least one sequence")

    # If necessary, upload template TRC inputs
    has_template = template_path is not None
    if has_template:
        topology_path, residues_path, chains_path = _upload_template_trc(
            template_path
        )
    else:
        topology_path = residues_path = chains_path = ""

    # Every plain option goes through optional_str and feeds $maybe_<name>.
    config_values = {
        "recycling_steps": recycling_steps,
        "sampling_steps": sampling_steps,
        "diffusion_samples": diffusion_samples,
        "step_scale": step_scale,
        "affinity_binder_chain_id": affinity_binder_chain_id,
        "affinity_mw_correction": affinity_mw_correction,
        "sampling_steps_affinity": sampling_steps_affinity,
        "diffusion_samples_affinity": diffusion_samples_affinity,
        "max_msa_seqs": max_msa_seqs,
        "subsample_msa": subsample_msa,
        "num_subsampled_msa": num_subsampled_msa,
        "use_potentials": use_potentials,
        "seed": seed,
        "template_threshold_angstroms": template_threshold_angstroms,
    }
    substitutions = {
        f"maybe_{name}": optional_str(value)
        for name, value in config_values.items()
    }

    if template_chain_mapping is not None:
        chain_mapping_rex = (
            f"(Some {dict_to_vec_of_tuples_str(template_chain_mapping)})"
        )
    else:
        chain_mapping_rex = "None"

    # Plain concatenation: a backslash inside an f-string expression only
    # parses on Python 3.12+.
    sequences_rex = (
        "[\n " + ",\n ".join(seq._to_rex() for seq in sequences) + ",\n ]"
    )

    # Run rex
    rex = Template("""let
obj_j = λ j →
VirtualObject { path = j, format = ObjectFormat::json, size = 0 },
boltz = λ topology residues chains →
boltz2_rex_s
($run_spec)
(boltz2_rex::Boltz2Config {
recycling_steps = $maybe_recycling_steps,
sampling_steps = $maybe_sampling_steps,
diffusion_samples = $maybe_diffusion_samples,
step_scale = $maybe_step_scale,
affinity_binder_chain_id = $maybe_affinity_binder_chain_id,
affinity_mw_correction = $maybe_affinity_mw_correction,
sampling_steps_affinity = $maybe_sampling_steps_affinity,
diffusion_samples_affinity = $maybe_diffusion_samples_affinity,
max_msa_seqs = $maybe_max_msa_seqs,
subsample_msa = $maybe_subsample_msa,
num_subsampled_msa = $maybe_num_subsampled_msa,
use_potentials = $maybe_use_potentials,
seed = $maybe_seed,
template_threshold_angstroms = $maybe_template_threshold_angstroms,
template_chain_mapping = $maybe_template_chain_mapping,
})
$sequences
$template_trc_expr
in
boltz "$topology_vobj_path" "$residues_vobj_path" "$chains_vobj_path"
""").substitute(
        run_spec=run_spec._to_rex(),
        maybe_template_chain_mapping=chain_mapping_rex,
        sequences=sequences_rex,
        template_trc_expr=(
            "(Some ((obj_j topology), (obj_j residues), (obj_j chains)) )"
            if has_template
            else "None"
        ),
        topology_vobj_path=topology_path,
        residues_vobj_path=residues_path,
        chains_vobj_path=chains_path,
        **substitutions,
    )

    try:
        run_id = _submit_rex(_get_project_id(), rex, run_opts)
        return collect_run(run_id) if collect else run_id
    except TransportQueryError as e:
        # Best-effort reporting is kept (no re-raise), but an exception that
        # carries no structured errors is no longer swallowed without output.
        if e.errors:
            messages = [error["message"] for error in e.errors]
        else:
            messages = [str(e)]
        for message in messages:
            print(f"Error: {message}", file=sys.stderr)
        return None