"""
EXESS geometry optimization for the Rush Python client.
Quick Links
-----------
- :func:`rush.exess.optimization`
- :class:`rush.exess.OptimizationResult`
"""
import json
import sys
from dataclasses import dataclass
from pathlib import Path
from string import Template
from typing import Any, Literal
from gql.transport.exceptions import TransportQueryError
from rush import TRC, Residues, TRCRef
from rush._trc import to_residues_vobj, to_topology_vobj
from .._utils import optional_str
from ..client import (
RunOpts,
RunSpec,
RushObject,
_get_project_id,
_submit_rex,
fetch_object,
)
from ..mol import Topology
from ..run import RushRun
from ._energy import (
AuxBasisT,
BasisT,
KSDFTKeywords,
MethodT,
SCFKeywords,
StandardOrientationT,
System,
_KSDFTDefault,
)
# ---------------------------------------------------------------------------
# Input types
# ---------------------------------------------------------------------------
[docs]
@dataclass
class OptimizationConvergenceCriteria:
metric: str | None = None
gradient_threshold: float | None = None
delta_energy_threshold: float | None = None
step_component_threshold: float | None = None
def _to_rex(self, reference_fragment: int | None = None):
return Template(
"""Some (exess_geo_opt_rex::OptimizationConvergenceCriteria {
metric = $maybe_metric,
gradient_threshold = $maybe_gradient_threshold,
delta_energy_threshold = $maybe_delta_energy_threshold,
step_component_threshold = $maybe_step_component_threshold,
})"""
).substitute(
maybe_metric=optional_str(self.metric), # TODO: enum prefix
maybe_gradient_threshold=optional_str(self.gradient_threshold),
maybe_delta_energy_threshold=optional_str(self.delta_energy_threshold),
maybe_step_component_threshold=optional_str(self.step_component_threshold),
)
type CoordinateSystemT = Literal["Cartesian", "NaturalInternal", "DelocalisedInternal"]
type HessianGuessTypeT = Literal["Identity", "ScaledIdentity", "Schlegel", "Lindh"]
type OptimizationAlgorithmTypeT = Literal[
"EigenvectorFollowing", "TrustRegionAugmentedHessian", "LBFGS"
]
[docs]
@dataclass
class TrustRegionKeywords:
initial_radius: float | None = None
max_radius: float | None = None
min_radius: float | None = None
increase_factor: float | None = None
decrease_factor: float | None = None
constrict_factor: float | None = None
increase_threshold: float | None = None
decrease_threshold: float | None = None
rejection_threshold: float | None = None
def _to_rex(self):
return Template(
"""Some (exess_geo_opt_rex::TrustRegionKeywords {
initial_radius = $maybe_initial_radius,
max_radius = $maybe_max_radius,
min_radius = $maybe_min_radius,
increase_factor = $maybe_increase_factor,
decrease_factor = $maybe_decrease_factor,
constrict_factor = $maybe_constrict_factor,
increase_threshold = $maybe_increase_threshold,
decrease_threshold = $maybe_decrease_threshold,
rejection_threshold = $maybe_rejection_threshold,
})"""
).substitute(
maybe_initial_radius=optional_str(self.initial_radius),
maybe_max_radius=optional_str(self.max_radius),
maybe_min_radius=optional_str(self.min_radius),
maybe_increase_factor=optional_str(self.increase_factor),
maybe_decrease_factor=optional_str(self.decrease_factor),
maybe_constrict_factor=optional_str(self.constrict_factor),
maybe_increase_threshold=optional_str(self.increase_threshold),
maybe_decrease_threshold=optional_str(self.decrease_threshold),
maybe_rejection_threshold=optional_str(self.rejection_threshold),
)
type LBFGSLinesearchT = Literal[
"MoreThuente", "BacktrackingArmijo", "BacktrackingWolfe", "BacktrackingStrongWolfe"
]
[docs]
@dataclass
class LBFGSKeywords:
    """Settings for the L-BFGS optimizer.

    Rendered by :meth:`_to_rex` as an ``exess_geo_opt_rex::LBFGSKeywords`` Rex
    literal. ``linesearch`` defaults to ``"BacktrackingStrongWolfe"`` and is
    emitted with the ``exess_geo_opt_rex::LBFGSLinesearch::`` enum prefix; the
    remaining fields are optional.
    """

    linesearch: LBFGSLinesearchT = "BacktrackingStrongWolfe"
    n_corrections: int | None = None
    epsilon: float | None = None
    max_linesearch: int | None = None
    gtol: float | None = None

    def _to_rex(self):
        """Return the ``Some (exess_geo_opt_rex::LBFGSKeywords {...})`` text."""
        rendered = {
            "maybe_linesearch": optional_str(
                self.linesearch, "exess_geo_opt_rex::LBFGSLinesearch::"
            ),
            "maybe_n_corrections": optional_str(self.n_corrections),
            "maybe_epsilon": optional_str(self.epsilon),
            "maybe_max_linesearch": optional_str(self.max_linesearch),
            "maybe_gtol": optional_str(self.gtol),
        }
        rex_template = Template(
            "Some (exess_geo_opt_rex::LBFGSKeywords {\n"
            "linesearch = $maybe_linesearch,\n"
            "n_corrections = $maybe_n_corrections,\n"
            "epsilon = $maybe_epsilon,\n"
            "max_linesearch = $maybe_max_linesearch,\n"
            "gtol = $maybe_gtol,\n"
            "})"
        )
        return rex_template.substitute(rendered)
[docs]
@dataclass
class OptimizationKeywords:
    """Geometry-optimizer settings rendered as ``exess_geo_opt_rex::OptimizationKeywords``.

    All fields are optional. Nested keyword objects (convergence criteria,
    L-BFGS and trust-region keywords) render themselves via their own
    ``_to_rex``; absent ones become Rex ``None``. Enum-valued string fields are
    passed to ``optional_str`` together with their Rex enum path prefix.
    ``max_iters`` is not a field: the caller supplies it to :meth:`_to_rex`.

    ``constraints`` is not serialized yet. Setting it emits a
    :class:`UserWarning`, and the value is otherwise ignored.
    """

    convergence_criteria: OptimizationConvergenceCriteria | None = None
    optimizer_reset_interval: int | None = None
    coordinate_system: CoordinateSystemT | None = None
    constraints: list[list[int]] | None = None
    hessian_guess: HessianGuessTypeT | None = None
    algorithm: OptimizationAlgorithmTypeT | None = None
    lbfgs_keywords: LBFGSKeywords | None = None
    frozen_distance_slippage_tolerance_angstroms: float | None = None
    frozen_angle_slippage_tolerance_degrees: float | None = None
    trust_region_keywords: TrustRegionKeywords | None = None
    fixed_atoms: list[int] | None = None
    free_atoms: list[int] | None = None
    fixed_fragments: list[int] | None = None
    free_fragments: list[int] | None = None
    fix_heavy: bool | None = None

    def _to_rex(self, max_iters: int):
        """Render these keywords plus *max_iters* as Rex source text.

        Args:
            max_iters: Maximum number of optimizer iterations; substituted
                verbatim into the ``max_iters`` slot.

        Returns:
            The ``Some (exess_geo_opt_rex::OptimizationKeywords {...})`` text.

        Warns:
            UserWarning: if ``constraints`` is non-empty, because it is not
                forwarded to the server yet.
        """
        if self.constraints:
            # Constraints are not serialized yet (see TODO below). Warn instead
            # of dropping them silently, so callers know they had no effect.
            import warnings

            warnings.warn(
                "OptimizationKeywords.constraints is not yet supported and will "
                "be ignored.",
                UserWarning,
                # When called from optimization(), this points at the user's call.
                stacklevel=3,
            )

        def nested(keywords):
            # Nested keyword dataclasses render themselves; absent ones are Rex None.
            return keywords._to_rex() if keywords is not None else "None"

        return Template(
            "Some (exess_geo_opt_rex::OptimizationKeywords {\n"
            "max_iters = $max_iters,\n"
            "convergence_criteria = $maybe_convergence_criteria,\n"
            "optimizer_reset_interval = $maybe_optimizer_reset_interval,\n"
            "coordinate_system = $maybe_coordinate_system,\n"
            "constraints = $maybe_constraints,\n"
            "hessian_guess = $maybe_hessian_guess,\n"
            "algorithm = $maybe_algorithm,\n"
            "lbfgs_keywords = $maybe_lbfgs_keywords,\n"
            "frozen_distance_slippage_tolerance_angstroms = "
            "$maybe_frozen_distance_slippage_tolerance_angstroms,\n"
            "frozen_angle_slippage_tolerance_degrees = "
            "$maybe_frozen_angle_slippage_tolerance_degrees,\n"
            "trust_region_keywords = $maybe_trust_region_keywords,\n"
            "fixed_atoms = $maybe_fixed_atoms,\n"
            "free_atoms = $maybe_free_atoms,\n"
            "fixed_fragments = $maybe_fixed_fragments,\n"
            "free_fragments = $maybe_free_fragments,\n"
            "fix_heavy = $maybe_fix_heavy,\n"
            "})"
        ).substitute(
            max_iters=max_iters,
            maybe_convergence_criteria=nested(self.convergence_criteria),
            maybe_optimizer_reset_interval=optional_str(self.optimizer_reset_interval),
            maybe_coordinate_system=optional_str(
                self.coordinate_system, "exess_geo_opt_rex::CoordinateSystem::"
            ),
            # TODO: serialize constraints. The intended form appears to be one list
            # of ``exess_geo_opt_rex::AtomRef (<atom>)`` entries per constraint;
            # confirm the Rex syntax before enabling.
            maybe_constraints="None",
            maybe_hessian_guess=optional_str(
                self.hessian_guess, "exess_geo_opt_rex::HessianGuessType::"
            ),
            maybe_algorithm=optional_str(
                self.algorithm, "exess_geo_opt_rex::OptimizationAlgorithmType::"
            ),
            maybe_lbfgs_keywords=nested(self.lbfgs_keywords),
            maybe_frozen_distance_slippage_tolerance_angstroms=optional_str(
                self.frozen_distance_slippage_tolerance_angstroms
            ),
            maybe_frozen_angle_slippage_tolerance_degrees=optional_str(
                self.frozen_angle_slippage_tolerance_degrees
            ),
            maybe_trust_region_keywords=nested(self.trust_region_keywords),
            maybe_fixed_atoms=optional_str(self.fixed_atoms),
            maybe_free_atoms=optional_str(self.free_atoms),
            maybe_fixed_fragments=optional_str(self.fixed_fragments),
            maybe_free_fragments=optional_str(self.free_fragments),
            maybe_fix_heavy=optional_str(self.fix_heavy),
        )
# ---------------------------------------------------------------------------
# Result types
# ---------------------------------------------------------------------------
[docs]
@dataclass
class OptimizationStep:
    """One record from the optimizer's per-step output.

    ``OptimizationResultRef.fetch`` builds one instance per JSON object in the
    ``steps`` output by keyword-unpacking that object.
    """

    # Total energy reported for this step (units are not stated in this module).
    total_energy: float
    # Largest gradient component reported for this step.
    max_gradient_component: float
[docs]
@dataclass
class OptimizationResult:
    """Parsed outputs of a geometry optimization.

    ``trajectory`` holds one :class:`Topology` per entry of the trajectory
    output. ``steps`` holds one :class:`OptimizationStep` per entry of the
    steps output.
    """

    trajectory: list[Topology]
    steps: list[OptimizationStep]
[docs]
@dataclass(frozen=True)
class OptimizationResultPaths:
    """Immutable pair of local paths where the optimization outputs were saved."""

    trajectory: Path  # saved trajectory output
    steps: Path  # saved per-step output
[docs]
@dataclass(frozen=True)
class OptimizationResultRef:
    """Lightweight reference to optimization outputs in the Rush object store.

    Holds one :class:`RushObject` handle each for the trajectory and the
    per-step outputs. Nothing is downloaded until :meth:`fetch` or
    :meth:`save` is called.
    """

    trajectory: RushObject
    steps: RushObject

    @classmethod
    def from_raw_output(cls, res: Any) -> "OptimizationResultRef":
        """Parse raw ``collect_run`` output into an ``OptimizationResultRef``.

        *res* must be a list of exactly two entries, ``[trajectory, steps]``,
        each accepted by ``RushObject.from_dict``.

        Raises:
            ValueError: if *res* is not a two-element list.
        """
        if isinstance(res, list) and len(res) == 2:
            raw_trajectory, raw_steps = res
            return cls(
                trajectory=RushObject.from_dict(raw_trajectory),
                steps=RushObject.from_dict(raw_steps),
            )
        item_count = len(res) if hasattr(res, "__len__") else "?"
        raise ValueError(
            "optimization should return exactly 2 outputs (trajectory + steps), "
            f"got {type(res).__name__} with {item_count} items."
        )

    def fetch(self) -> OptimizationResult:
        """Download optimization outputs and parse into Python objects.

        Both outputs are decoded as JSON lists. Trajectory entries go through
        ``Topology.from_json``; step entries are keyword-unpacked into
        :class:`OptimizationStep`.
        """
        frame_payloads = json.loads(fetch_object(self.trajectory.path))
        frames = [Topology.from_json(payload) for payload in frame_payloads]
        step_payloads = json.loads(fetch_object(self.steps.path))
        step_records = [OptimizationStep(**payload) for payload in step_payloads]
        return OptimizationResult(trajectory=frames, steps=step_records)

    def save(self) -> OptimizationResultPaths:
        """Download optimization outputs and save to the workspace.

        Returns the paths produced by ``RushObject.save`` for each output.
        """
        trajectory_path = self.trajectory.save()
        steps_path = self.steps.save()
        return OptimizationResultPaths(trajectory=trajectory_path, steps=steps_path)
# ---------------------------------------------------------------------------
# Submission
# ---------------------------------------------------------------------------
[docs]
def optimization(
mol: TRC
| TRCRef
| tuple[Path | str | RushObject | Topology, Path | str | RushObject | Residues]
| Path
| str
| RushObject
| Topology,
max_iters: int,
optimization_keywords: OptimizationKeywords = OptimizationKeywords(),
method: MethodT = "RestrictedKSDFT",
basis: BasisT = "cc-pVDZ",
aux_basis: AuxBasisT | None = None,
standard_orientation: StandardOrientationT | None = None,
force_cartesian_basis_sets: bool | None = None,
scf_keywords: SCFKeywords | None = None,
ksdft_keywords: KSDFTKeywords | _KSDFTDefault | None = _KSDFTDefault.DEFAULT,
qm_fragments: list[int] | None = None,
mm_fragments: list[int] | None = None,
system: System | None = None,
run_spec: RunSpec = RunSpec(gpus=1),
run_opts: RunOpts = RunOpts(),
) -> RushRun[OptimizationResultRef]:
"""
Submit a geometry optimization for the topology at *topology_path*.
Returns a :class:`~rush.run.RushRun` handle. Call ``.fetch()`` to get the
parsed trajectory and optimization steps, or ``.save()`` to write them to disk.
"""
ksdft_keywords = KSDFTKeywords.resolve(ksdft_keywords, method)
# Upload inputs
residues_vobj = None
match mol:
case TRC() | TRCRef():
topology_vobj = to_topology_vobj(mol.topology)
residues_vobj = to_residues_vobj(mol.residues)
case (t, r):
topology_vobj = to_topology_vobj(t)
residues_vobj = to_residues_vobj(r)
case _:
topology_vobj = to_topology_vobj(mol)
# Run rex
rex = Template("""let
obj_j = λ j →
VirtualObject { path = j, format = ObjectFormat::json, size = 0 },
exess = λ topology residues →
exess_geo_opt_rex_s
($run_spec)
(exess_geo_opt_rex::OptimizationParams {
schema_version = "0.2.0",
external_charges = None,
model = Some (exess_geo_opt_rex::Model {
method = exess_geo_opt_rex::Method::$method,
basis = "$basis",
aux_basis = $maybe_aux_basis,
standard_orientation = $maybe_standard_orientation,
force_cartesian_basis_sets = $maybe_force_cartesian_basis_sets,
}),
system = $maybe_system,
keywords = exess_geo_opt_rex::Keywords {
scf = $maybe_scf_keywords,
ks_dft = $maybe_ks_keywords,
rtat = None,
frag = None,
boundary = None,
log = None,
dynamics = None,
integrals = None,
debug = None,
export = None,
guess = None,
force_field = None,
optimization = $maybe_optimization_keywords,
hessian = None,
gradient = None,
qmmm = $maybe_qmmm_keywords,
machine_learning = None,
regions = $maybe_regions,
},
})
[ (obj_j topology) ]
$residues_expr
in
exess "$topology_vobj_path" "$residues_vobj_path"
""").substitute(
run_spec=run_spec._to_rex(),
method=method,
basis=basis,
maybe_aux_basis=optional_str(aux_basis),
maybe_standard_orientation=optional_str(
standard_orientation, "exess_rex::StandardOrientation::"
),
maybe_force_cartesian_basis_sets=optional_str(force_cartesian_basis_sets),
maybe_system=system._to_rex() if system is not None else "None",
maybe_scf_keywords=(
scf_keywords._to_rex() if scf_keywords is not None else "None"
),
maybe_ks_keywords=(
ksdft_keywords._to_rex() if ksdft_keywords is not None else "None"
),
maybe_optimization_keywords=(
optimization_keywords._to_rex(max_iters)
if optimization_keywords is not None
else "None"
),
maybe_qmmm_keywords=(
"""Some (exess_qmmm_rex::QMMMKeywords {
n_timesteps = 1,
dt_ps = 0.002,
temperature_kelvin = 290.0,
pressure_atm = None,
minimisation = None,
trajectory = None,
restraints = None,
energy_csv = None,
})"""
if mm_fragments or (qm_fragments is not None)
else "None"
),
maybe_regions=(
Template(
"""Some (exess_qmmm_rex::RegionKeywords {
qm_fragments = $maybe_qm_fragments,
mm_fragments = $maybe_mm_fragments,
ml_fragments = Some [],
})"""
).substitute(
maybe_qm_fragments=optional_str(qm_fragments),
maybe_mm_fragments=optional_str(mm_fragments),
)
if not (qm_fragments is None and mm_fragments is None)
else "None"
),
residues_expr=(
"(Some [ (obj_j residues) ])" if residues_vobj is not None else "None"
),
topology_vobj_path=topology_vobj["path"],
residues_vobj_path=residues_vobj["path"] if residues_vobj is not None else "",
)
try:
return RushRun(
_submit_rex(_get_project_id(), rex, run_opts),
OptimizationResultRef,
)
except TransportQueryError as e:
if e.errors:
for error in e.errors:
print(f"Error: {error['message']}", file=sys.stderr)
raise