diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 2fe90ff..f2376d9 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -20,7 +20,7 @@ jobs: with: fetch-depth: 0 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -31,9 +31,11 @@ jobs: python -m pip install . - name: Run pre-commit hooks run: | - git fetch origin main - BASE_REF="$(git merge-base HEAD origin/main)" - pre-commit run --from-ref "$BASE_REF" --to-ref HEAD + if [ "${{ github.event_name }}" = "pull_request" ]; then + pre-commit run --from-ref origin/${{ github.base_ref }} --to-ref HEAD + else + pre-commit run --from-ref HEAD~1 --to-ref HEAD + fi - name: Test with coverage run: | coverage run --source=abcfold --module pytest --verbose tests && coverage report --show-missing diff --git a/README.md b/README.md index 2a03fd1..b133064 100644 --- a/README.md +++ b/README.md @@ -192,6 +192,67 @@ Below are scripts for adding MMseqs2 MSAs and custom templates to AlphaFold3 inp > [!WARNING] > These scripts will only modify the input JSON files, I.E. they will NOT run AlphaFold3, Boltz, Chai-1, OpenFold3 and Protenix. +### Scoring existing Boltz complexes and poses + +ABCFold also includes a Boltz2 utility for scoring existing complex coordinates +or fixed-receptor ligand poses without running Boltz diffusion sampling. See +[Boltz Existing-Structure Scoring](docs/boltz_existing_scoring.md) for +confidence scoring, affinity scoring, and `--reuse_trunk` examples. 
+ +After installation, use `boltz-score-existing` directly: + +```bash +boltz-score-existing poses.sdf --receptor receptor.pdb --affinity +``` + +For a checkout managed with `uv`, either run through `uv`: + +```bash +uv run boltz-score-existing poses.sdf --receptor receptor.pdb --affinity +``` + +or activate the project environment first: + +```bash +source .venv/bin/activate +boltz-score-existing poses.sdf --receptor receptor.pdb --affinity +``` + +The module form +`python -m abcfold.boltz.score_existing` remains available for development +checkouts. + +### Docking ligand SMILES into a crystal pocket with Boltz + +ABCFold also includes a Boltz-native docking wrapper for the case where a +crystal receptor and pocket are known, but the ligand pose should be generated +from SMILES by Boltz. See +[Boltz Crystal-Pocket Docking](docs/boltz_crystal_docking.md) for crystal +template, pocket constraint, reference-ligand, and affinity examples. + +After installation, use `boltz-dock-crystal` directly: + +```bash +boltz-dock-crystal receptor.pdb "CCO" --pocket_residue A:145 --affinity +``` + +For a checkout managed with `uv`, either run through `uv`: + +```bash +uv run boltz-dock-crystal receptor.pdb "CCO" --pocket_residue A:145 --affinity +``` + +or activate the project environment first: + +```bash +source .venv/bin/activate +boltz-dock-crystal receptor.pdb "CCO" --pocket_residue A:145 --affinity +``` + +The module form +`python -m abcfold.boltz.dock_crystal` remains available for development +checkouts. 
+ ### Adding MMseqs2 MSAs and templates To add MMseqs2 MSAs and templates to the AlphaFold3 input JSON, you can use the `mmseqs2msa`: diff --git a/abcfold/boltz/__init__.py b/abcfold/boltz/__init__.py new file mode 100644 index 0000000..c2278e0 --- /dev/null +++ b/abcfold/boltz/__init__.py @@ -0,0 +1 @@ +"""Boltz utility entrypoints for ABCFold.""" diff --git a/abcfold/boltz/dock_crystal.py b/abcfold/boltz/dock_crystal.py new file mode 100644 index 0000000..af2a70f --- /dev/null +++ b/abcfold/boltz/dock_crystal.py @@ -0,0 +1,574 @@ +"""Run Boltz-native ligand docking against a crystal receptor template.""" + +from __future__ import annotations + +import argparse +import configparser +import json +import shutil +import subprocess +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from Bio.Data.PDBData import protein_letters_3to1_extended +from Bio.PDB.PDBParser import PDBParser + +from abcfold.boltz.check_install import ensure_boltz_env +from abcfold.output.utils import verify_config_file + + +@dataclass(frozen=True) +class ProteinResidue: + chain_id: str + pdb_number: int + sequence_index: int + residue: Any + + +@dataclass(frozen=True) +class ProteinChain: + chain_id: str + sequence: str + residues: list[ProteinResidue] + + +@dataclass(frozen=True) +class DockingInput: + yaml_path: Path + command: list[str] + contacts: list[list[str | int]] + + +def _parse_args(argv: list[str] | None = None) -> argparse.Namespace: + parser = argparse.ArgumentParser( + prog="boltz-dock-crystal", + description=( + "Dock a ligand SMILES with Boltz while constraining the protein to " + "a crystal receptor template and the ligand to a pocket." 
+ ), + ) + parser.add_argument("receptor", type=Path, help="Crystal receptor PDB file.") + parser.add_argument("smiles", help="Ligand SMILES to dock into the receptor.") + parser.add_argument( + "--out_dir", + type=Path, + default=Path("boltz_crystal_docking"), + help="Directory for the generated YAML and Boltz outputs.", + ) + parser.add_argument( + "--protein_chain", + action="append", + dest="protein_chains", + help=( + "Protein chain to include. Repeat for multi-chain receptors. " + "Defaults to all protein chains in the first model." + ), + ) + parser.add_argument( + "--ligand_chain_id", + default="L", + help="Boltz chain id assigned to the docked ligand.", + ) + parser.add_argument( + "--pocket_residue", + action="append", + default=[], + help=( + "Pocket residue in CHAIN:RESNUM form. Repeat or use comma-separated " + "values. RESNUM is PDB numbering by default." + ), + ) + parser.add_argument( + "--pocket_numbering", + choices=["pdb", "sequence"], + default="pdb", + help=( + "Interpret --pocket_residue numbers as PDB residue numbers or " + "sequence indices." + ), + ) + parser.add_argument( + "--reference_ligand_chain", + help="Optional ligand chain in the receptor PDB used to infer pocket residues.", + ) + parser.add_argument( + "--pocket_cutoff", + type=float, + default=6.0, + help=( + "Distance cutoff in Angstrom for --reference_ligand_chain pocket " + "inference." 
+ ), + ) + parser.add_argument( + "--max_distance", + type=float, + default=6.0, + help="Boltz pocket max_distance in Angstrom.", + ) + parser.add_argument( + "--template_threshold", + type=float, + default=1.0, + help="Allowed template deviation in Angstrom when force_template is enabled.", + ) + parser.add_argument( + "--no_force_template", + action="store_true", + help="Do not force the protein backbone toward the crystal template.", + ) + parser.add_argument( + "--no_force_pocket", + action="store_true", + help="Do not force the ligand toward the pocket contacts.", + ) + parser.add_argument( + "--affinity", + action="store_true", + help="Ask Boltz to predict ligand affinity for the generated pose.", + ) + parser.add_argument( + "--use_msa_server", + action="store_true", + help="Let Boltz query the MSA server instead of using msa: empty.", + ) + parser.add_argument( + "--no_use_potentials", + action="store_true", + help="Do not pass --use_potentials to boltz predict.", + ) + parser.add_argument( + "--diffusion_samples", + type=int, + default=25, + help="Number of Boltz diffusion samples.", + ) + parser.add_argument( + "--recycling_steps", + type=int, + default=10, + help="Number of Boltz recycling steps.", + ) + parser.add_argument( + "--sampling_steps", + type=int, + default=200, + help="Number of Boltz diffusion sampling steps.", + ) + parser.add_argument( + "--step_scale", + type=float, + help="Optional Boltz diffusion step scale.", + ) + parser.add_argument( + "--cache", + type=Path, + default=Path.home() / ".boltz", + help="Boltz cache directory.", + ) + parser.add_argument( + "--devices", + type=int, + default=1, + help="Number of devices passed to boltz predict.", + ) + parser.add_argument( + "--accelerator", + choices=["gpu", "cpu", "tpu"], + default="gpu", + help="Boltz accelerator.", + ) + parser.add_argument( + "--output_format", + choices=["mmcif", "pdb"], + default="mmcif", + help="Boltz output structure format.", + ) + parser.add_argument( + 
"--runner", + choices=["abcfold-env", "path"], + default="abcfold-env", + help=( + "Run through ABCFold's managed Boltz micromamba env, or use boltz " + "from PATH." + ), + ) + parser.add_argument( + "--config-file", + type=Path, + default=Path.home() / ".abcfold_config.ini", + help="ABCFold config used when --runner abcfold-env is selected.", + ) + parser.add_argument( + "--dry_run", + action="store_true", + help="Write the YAML and command file without running Boltz.", + ) + return parser.parse_args(argv) + + +def _read_receptor(path: Path) -> Any: + if path.suffix.lower() != ".pdb": + raise ValueError("Boltz crystal docking currently expects a receptor PDB file.") + return PDBParser(QUIET=True).get_structure(path.stem, str(path)) + + +def _one_letter(resname: str) -> str | None: + return protein_letters_3to1_extended.get(resname.upper()) + + +def _extract_protein_chains( + structure: Any, + selected_chain_ids: set[str] | None, +) -> list[ProteinChain]: + chains = [] + for chain in structure[0]: + if selected_chain_ids is not None and chain.id not in selected_chain_ids: + continue + + residues = [] + letters = [] + sequence_index = 1 + for residue in chain: + if residue.id[0] != " ": + continue + letter = _one_letter(residue.resname) + if letter is None: + continue + letters.append(letter) + residues.append( + ProteinResidue( + chain_id=chain.id, + pdb_number=int(residue.id[1]), + sequence_index=sequence_index, + residue=residue, + ) + ) + sequence_index += 1 + + if residues: + chains.append( + ProteinChain( + chain_id=chain.id, + sequence="".join(letters), + residues=residues, + ) + ) + + if not chains: + raise ValueError("No protein chains were found in the receptor PDB.") + return chains + + +def _parse_pocket_residue_tokens(tokens: list[str]) -> list[tuple[str, int]]: + parsed = [] + for token_group in tokens: + for token in token_group.split(","): + token = token.strip() + if not token: + continue + if ":" not in token: + raise ValueError(f"Pocket residue 
must use CHAIN:RESNUM: {token}") + chain_id, residue_number = token.split(":", 1) + if not chain_id: + raise ValueError(f"Pocket residue is missing a chain id: {token}") + parsed.append((chain_id, int(residue_number))) + return parsed + + +def _contacts_from_tokens( + protein_chains: list[ProteinChain], + tokens: list[str], + numbering: str, +) -> list[list[str | int]]: + contacts: list[list[str | int]] = [] + chain_lookup = {chain.chain_id: chain for chain in protein_chains} + + for chain_id, residue_number in _parse_pocket_residue_tokens(tokens): + chain = chain_lookup.get(chain_id) + if chain is None: + raise ValueError( + f"Pocket chain {chain_id} is not in the receptor proteins." + ) + + for residue in chain.residues: + number = ( + residue.pdb_number + if numbering == "pdb" + else residue.sequence_index + ) + if number == residue_number: + contacts.append([chain_id, residue.sequence_index]) + break + else: + raise ValueError( + f"Pocket residue {chain_id}:{residue_number} was not found " + f"with {numbering} numbering." 
+ ) + + return contacts + + +def _squared_distance(atom_a: Any, atom_b: Any) -> float: + delta = atom_a.coord - atom_b.coord + return float(delta.dot(delta)) + + +def _ligand_atoms_from_chain(structure: Any, chain_id: str) -> list[Any]: + if chain_id not in structure[0]: + raise ValueError(f"Reference ligand chain {chain_id} was not found.") + atoms = [ + atom + for residue in structure[0][chain_id] + if residue.id[0] != " " + for atom in residue + if atom.element != "H" + ] + if not atoms: + raise ValueError(f"Reference ligand chain {chain_id} has no ligand atoms.") + return atoms + + +def _contacts_from_reference_ligand( + protein_chains: list[ProteinChain], + structure: Any, + ligand_chain_id: str, + cutoff: float, +) -> list[list[str | int]]: + ligand_atoms = _ligand_atoms_from_chain(structure, ligand_chain_id) + cutoff_sq = cutoff * cutoff + contacts: list[list[str | int]] = [] + + for chain in protein_chains: + for protein_residue in chain.residues: + is_contact = any( + _squared_distance(atom, ligand_atom) <= cutoff_sq + for atom in protein_residue.residue + if atom.element != "H" + for ligand_atom in ligand_atoms + ) + if is_contact: + contacts.append([chain.chain_id, protein_residue.sequence_index]) + + if not contacts: + raise ValueError( + f"No pocket residues were found within {cutoff:g} A of ligand " + f"chain {ligand_chain_id}." 
+ ) + return contacts + + +def _dedupe_contacts(contacts: list[list[str | int]]) -> list[list[str | int]]: + seen = set() + deduped: list[list[str | int]] = [] + for chain_id, residue_index in contacts: + key = (str(chain_id), int(residue_index)) + if key in seen: + continue + seen.add(key) + deduped.append([key[0], key[1]]) + return deduped + + +def _yaml_scalar(value: Any) -> str: + if isinstance(value, bool): + return "true" if value else "false" + if isinstance(value, (int, float)): + return f"{value:g}" if isinstance(value, float) else str(value) + return json.dumps(str(value)) + + +def _yaml_flow_list(values: list[Any]) -> str: + rendered = [] + for value in values: + if isinstance(value, list): + rendered.append(_yaml_flow_list(value)) + else: + rendered.append(_yaml_scalar(value)) + return "[" + ", ".join(rendered) + "]" + + +def _render_yaml( + receptor: Path, + protein_chains: list[ProteinChain], + ligand_chain_id: str, + smiles: str, + contacts: list[list[str | int]], + args: argparse.Namespace, +) -> str: + lines = ["version: 1", "sequences:"] + for chain in protein_chains: + lines.extend([ + " - protein:", + f" id: {_yaml_scalar(chain.chain_id)}", + f" sequence: {_yaml_scalar(chain.sequence)}", + ]) + if not args.use_msa_server: + lines.append(" msa: empty") + + lines.extend([ + " - ligand:", + f" id: {_yaml_scalar(ligand_chain_id)}", + f" smiles: {_yaml_scalar(smiles)}", + "constraints:", + " - pocket:", + f" binder: {_yaml_scalar(ligand_chain_id)}", + f" contacts: {_yaml_flow_list(contacts)}", + f" max_distance: {_yaml_scalar(args.max_distance)}", + f" force: {_yaml_scalar(not args.no_force_pocket)}", + "templates:", + f" - pdb: {_yaml_scalar(str(receptor))}", + ]) + + chain_ids = [chain.chain_id for chain in protein_chains] + template_ids = [f"{chain.chain_id}1" for chain in protein_chains] + if len(chain_ids) == 1: + lines.append(f" chain_id: {_yaml_scalar(chain_ids[0])}") + lines.append(f" template_id: {_yaml_scalar(template_ids[0])}") + else: + 
def generate_boltz_crystal_dock_command(
    input_yaml: Path,
    output_dir: Path,
    args: argparse.Namespace,
) -> list[str]:
    """Assemble the ``boltz predict`` argument vector for a docking run.

    Args:
        input_yaml: Boltz input YAML to predict from.
        output_dir: Directory passed to ``--out_dir``.
        args: Parsed CLI options; the sampling, cache, device, accelerator,
            output-format, MSA-server, potentials and step-scale settings
            are read from it.

    Returns:
        The command as a list of arguments, starting with ``boltz``.
    """
    command = [
        "boltz",
        "predict",
        str(input_yaml),
        "--out_dir",
        str(output_dir),
        "--override",
    ]

    # Options whose values are stringified before being appended.
    stringified_options = (
        ("--diffusion_samples", args.diffusion_samples),
        ("--recycling_steps", args.recycling_steps),
        ("--sampling_steps", args.sampling_steps),
        ("--cache", args.cache),
        ("--devices", args.devices),
    )
    for flag, setting in stringified_options:
        command += [flag, str(setting)]

    # These two are passed through exactly as parsed.
    command += [
        "--accelerator",
        args.accelerator,
        "--output_format",
        args.output_format,
    ]

    if args.use_msa_server:
        command.append("--use_msa_server")
    if not args.no_use_potentials:
        command.append("--use_potentials")
    if args.step_scale is not None:
        command += ["--step_scale", str(args.step_scale)]
    return command
selected_chain_ids = ( + set(args.protein_chains) + if args.protein_chains is not None + else None + ) + protein_chains = _extract_protein_chains(structure, selected_chain_ids) + contacts = _contacts_from_tokens( + protein_chains, + args.pocket_residue, + args.pocket_numbering, + ) + if args.reference_ligand_chain is not None: + contacts.extend( + _contacts_from_reference_ligand( + protein_chains, + structure, + args.reference_ligand_chain, + args.pocket_cutoff, + ) + ) + contacts = _dedupe_contacts(contacts) + if not contacts: + raise ValueError( + "No pocket contacts were provided. Use --pocket_residue or " + "--reference_ligand_chain." + ) + + yaml_path = out_dir / "boltz_crystal_dock.yaml" + yaml_text = _render_yaml( + receptor, + protein_chains, + args.ligand_chain_id, + args.smiles, + contacts, + args, + ) + yaml_path.write_text(yaml_text) + + command = generate_boltz_crystal_dock_command(yaml_path, out_dir, args) + (out_dir / "boltz_crystal_dock_command.json").write_text( + json.dumps(command, indent=2) + "\n" + ) + return DockingInput(yaml_path=yaml_path, command=command, contacts=contacts) + + +def run_crystal_docking(args: argparse.Namespace) -> DockingInput: + docking_input = prepare_crystal_docking_input(args) + if args.dry_run: + return docking_input + + if args.runner == "abcfold-env": + env = ensure_boltz_env(config=_load_config(args.config_file)) + env.run(docking_input.command, capture_output=True) + else: + subprocess.run(docking_input.command, check=True) + return docking_input + + +def main(argv: list[str] | None = None) -> None: + args = _parse_args(argv) + docking_input = run_crystal_docking(args) + print(f"Boltz input YAML: {docking_input.yaml_path}") + print(f"Pocket contacts: {len(docking_input.contacts)}") + if args.dry_run: + print("Dry run complete; Boltz was not executed.") + + +if __name__ == "__main__": + main() diff --git a/abcfold/boltz/score_existing.py b/abcfold/boltz/score_existing.py new file mode 100644 index 0000000..5ff949a 
--- /dev/null +++ b/abcfold/boltz/score_existing.py @@ -0,0 +1,1146 @@ +"""Score existing complexes with Boltz2 confidence and optional affinity heads. + +This module intentionally bypasses Boltz diffusion sampling. It parses one or +more ready PDB/mmCIF complexes, featurizes them with Boltz2, runs the trunk plus +confidence module with ``skip_run_structure=True``, and writes score files. +""" + +from __future__ import annotations + +import argparse +import csv +import hashlib +import json +import re +from dataclasses import asdict, dataclass, replace +from pathlib import Path +from typing import Any + +SUMMARY_KEYS = [ + "confidence_score", + "ptm", + "iptm", + "ligand_iptm", + "protein_iptm", + "complex_plddt", + "complex_iplddt", + "complex_pde", + "complex_ipde", +] + +AFFINITY_KEYS = [ + "affinity_pred_value", + "affinity_probability_binary", + "affinity_pred_value1", + "affinity_probability_binary1", + "affinity_pred_value2", + "affinity_probability_binary2", +] + + +@dataclass(frozen=True) +class LigandPose: + path: Path + pose_index: int + mol: Any + + +def _safe_id(path: Path) -> str: + """Return a Boltz-safe record id derived from a structure path.""" + return re.sub(r"[^A-Za-z0-9_.-]+", "_", path.stem).strip("_") or "complex" + + +def _parse_args(argv: list[str] | None = None) -> argparse.Namespace: + parser = argparse.ArgumentParser( + prog="boltz-score-existing", + description="Score existing PDB/mmCIF complexes with Boltz2 confidence.", + ) + parser.add_argument( + "structures", + nargs="+", + type=Path, + help=( + "Existing complex structure files (.cif/.mmcif/.pdb), or ligand " + "SDF files when --receptor is set." + ), + ) + parser.add_argument( + "--receptor", + type=Path, + help=( + "Protein receptor PDB. When set, positional inputs are ligand SDF " + "poses scored as receptor-ligand complexes." 
+ ), + ) + parser.add_argument( + "--ligand_chain_id", + default="L", + help="Chain id assigned to SDF ligands in --receptor mode.", + ) + parser.add_argument( + "--out_dir", + type=Path, + default=Path("boltz_existing_scores"), + help="Directory for confidence JSON/NPZ outputs.", + ) + parser.add_argument( + "--cache", + type=Path, + default=Path.home() / ".boltz", + help="Boltz cache containing mols/, boltz2_conf.ckpt, and boltz2_aff.ckpt.", + ) + parser.add_argument( + "--affinity", + action="store_true", + help=( + "Also run the Boltz2 affinity head on the provided coordinates. " + "This uses boltz2_aff.ckpt by default and still skips diffusion." + ), + ) + parser.add_argument( + "--affinity_checkpoint", + type=Path, + help="Optional path to a Boltz2 affinity checkpoint.", + ) + parser.add_argument( + "--no_affinity_mw_correction", + action="store_true", + help="Disable Boltz2 molecular-weight correction for affinity predictions.", + ) + parser.add_argument( + "--device", + choices=["auto", "cpu", "cuda", "mps"], + default="auto", + help="Torch device. auto prefers CUDA, then MPS, then CPU.", + ) + parser.add_argument( + "--recycling_steps", + type=int, + default=0, + help="Pairformer recycling steps before confidence scoring.", + ) + parser.add_argument( + "--reuse_trunk", + action="store_true", + help=( + "Cache trunk embeddings from the first structure and reuse them for " + "the remaining structures. Only valid for the same target topology." + ), + ) + parser.add_argument( + "--write_full_pae", + action="store_true", + help="Write full PAE matrices. Disabled by default to keep screening fast.", + ) + parser.add_argument( + "--write_full_pde", + action="store_true", + help="Write full PDE matrices. 
Disabled by default to keep screening fast.", + ) + parser.add_argument( + "--use_kernels", + action="store_true", + help="Enable optional CUDA kernels if they are installed.", + ) + parser.add_argument( + "--no_download", + action="store_true", + help="Do not download missing Boltz2 weights or molecule cache.", + ) + return parser.parse_args(argv) + + +def _import_boltz() -> dict[str, Any]: + """Import Boltz only when the scoring command actually runs.""" + import numpy as np + import torch + from boltz.data import const + from boltz.data.feature.featurizerv2 import Boltz2Featurizer + from boltz.data.module.inferencev2 import collate + from boltz.data.mol import load_canonicals, load_molecules + from boltz.data.parse.schema import parse_boltz_schema + from boltz.data.tokenize.boltz2 import Boltz2Tokenizer + from boltz.data.types import Coords, Input + from boltz.main import (Boltz2DiffusionParams, BoltzSteeringParams, + MSAModuleArgs, PairformerArgsV2, download_boltz2) + from boltz.model.models.boltz2 import Boltz2 + + return { + "np": np, + "torch": torch, + "const": const, + "Boltz2Featurizer": Boltz2Featurizer, + "load_canonicals": load_canonicals, + "load_molecules": load_molecules, + "collate": collate, + "parse_boltz_schema": parse_boltz_schema, + "Boltz2Tokenizer": Boltz2Tokenizer, + "Coords": Coords, + "Input": Input, + "Boltz2DiffusionParams": Boltz2DiffusionParams, + "BoltzSteeringParams": BoltzSteeringParams, + "MSAModuleArgs": MSAModuleArgs, + "PairformerArgsV2": PairformerArgsV2, + "download_boltz2": download_boltz2, + "Boltz2": Boltz2, + } + + +def _select_device(torch: Any, device: str) -> Any: + if device == "auto": + if torch.cuda.is_available(): + return torch.device("cuda") + if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available(): + return torch.device("mps") + return torch.device("cpu") + selected = torch.device(device) + if selected.type == "cuda" and not torch.cuda.is_available(): + raise RuntimeError("CUDA was 
requested but is not available.") + if ( + selected.type == "mps" + and ( + not getattr(torch.backends, "mps", None) + or not torch.backends.mps.is_available() + ) + ): + raise RuntimeError("MPS was requested but is not available.") + return selected + + +def _read_bio_structure(path: Path, record_id: str): + from Bio.PDB.MMCIFParser import MMCIFParser + from Bio.PDB.PDBParser import PDBParser + + suffix = path.suffix.lower() + if suffix == ".pdb": + return PDBParser(QUIET=True).get_structure(record_id, str(path)) + if suffix in {".cif", ".mmcif"}: + return MMCIFParser(QUIET=True).get_structure(record_id, str(path)) + raise ValueError(f"Unsupported structure suffix for {path}: {path.suffix}") + + +def _protein_residues(chain: Any, boltz: dict[str, Any]) -> list[Any]: + protein_tokens = set(boltz["const"].prot_token_to_letter) + return [ + residue + for residue in chain + if residue.id[0] == " " and residue.resname in protein_tokens + ] + + +def _ligand_residues(chain: Any) -> list[Any]: + return [ + residue + for residue in chain + if residue.id[0] != " " + ] + + +def _schema_from_bio_structure(record_id: str, structure: Any, boltz: dict[str, Any]): + sequences = [] + model = structure[0] + + for chain in model: + protein_residues = _protein_residues(chain, boltz) + if protein_residues: + sequence = "".join( + boltz["const"].prot_token_to_letter.get(residue.resname, "X") + for residue in protein_residues + ) + sequences.append({ + "protein": { + "id": chain.id, + "sequence": sequence, + "msa": "empty", + } + }) + continue + + for residue in _ligand_residues(chain): + sequences.append({ + "ligand": { + "id": chain.id, + "ccd": residue.resname, + } + }) + + if not sequences: + raise ValueError(f"No supported protein or CCD ligand chains in {record_id}.") + + return { + "version": 1, + "sequences": sequences, + } + + +def _with_affinity_property( + schema: dict[str, Any], + ligand_chain_id: str | None = None, +) -> dict[str, Any]: + if ligand_chain_id is None: + 
ligand_ids = [] + for item in schema["sequences"]: + ligand = item.get("ligand") + if ligand is None: + continue + ligand_id = ligand["id"] + if isinstance(ligand_id, str): + ligand_ids.append(ligand_id) + else: + ligand_ids.extend(ligand_id) + + if len(ligand_ids) != 1: + raise ValueError( + "Affinity scoring requires exactly one ligand chain unless " + "--ligand_chain_id identifies the binder." + ) + ligand_chain_id = ligand_ids[0] + + return { + **schema, + "properties": [ + *schema.get("properties", []), + {"affinity": {"binder": ligand_chain_id}}, + ], + } + + +def _protein_schema_items_from_bio_structure( + record_id: str, + structure: Any, + boltz: dict[str, Any], +) -> list[dict[str, Any]]: + sequences = [] + model = structure[0] + + for chain in model: + protein_residues = _protein_residues(chain, boltz) + if not protein_residues: + continue + + sequence = "".join( + boltz["const"].prot_token_to_letter.get(residue.resname, "X") + for residue in protein_residues + ) + sequences.append({ + "protein": { + "id": chain.id, + "sequence": sequence, + "msa": "empty", + } + }) + + if not sequences: + raise ValueError(f"No supported protein chains in receptor {record_id}.") + + return sequences + + +def _coords_by_chain_residue(residues: list[Any]) -> list[dict[str, Any]]: + residue_coords = [] + for residue in residues: + atom_coords = { + atom.name.strip(): tuple(float(x) for x in atom.coord) + for atom in residue + } + residue_coords.append(atom_coords) + return residue_coords + + +def _transfer_coordinates( + target_structure: Any, + bio_structure: Any, + boltz: dict[str, Any], +): + model = bio_structure[0] + atoms = target_structure.atoms.copy() + coords_data = [] + + for chain in target_structure.chains: + chain_name = str(chain["name"]) + if chain_name not in model: + raise ValueError(f"Structure is missing chain {chain_name}.") + + bio_chain = model[chain_name] + if int(chain["mol_type"]) == boltz["const"].chain_type_ids["NONPOLYMER"]: + source_residues = 
_ligand_residues(bio_chain) + else: + source_residues = _protein_residues(bio_chain, boltz) + + res_start = int(chain["res_idx"]) + res_end = res_start + int(chain["res_num"]) + target_residues = target_structure.residues[res_start:res_end] + if len(source_residues) != len(target_residues): + raise ValueError( + f"Chain {chain_name} residue count mismatch: " + f"{len(source_residues)} in structure, " + f"{len(target_residues)} in target." + ) + + source_coords = _coords_by_chain_residue(source_residues) + for residue, atom_lookup in zip(target_residues, source_coords): + atom_start = int(residue["atom_idx"]) + atom_end = atom_start + int(residue["atom_num"]) + for atom_idx in range(atom_start, atom_end): + atom_name = str(atoms[atom_idx]["name"]).strip() + coords = atom_lookup.get(atom_name) + if coords is None: + atoms[atom_idx]["coords"] = (0.0, 0.0, 0.0) + atoms[atom_idx]["is_present"] = False + else: + atoms[atom_idx]["coords"] = coords + atoms[atom_idx]["is_present"] = True + + coords_data = [(coord,) for coord in atoms["coords"]] + coords = boltz["np"].array(coords_data, dtype=boltz["Coords"]) + return replace(target_structure, atoms=atoms, coords=coords) + + +def _transfer_receptor_ligand_coordinates( + target_structure: Any, + receptor_structure: Any, + ligand_coords_by_chain: dict[str, dict[str, tuple[float, float, float]]], + boltz: dict[str, Any], +): + model = receptor_structure[0] + atoms = target_structure.atoms.copy() + nonpolymer_type = boltz["const"].chain_type_ids["NONPOLYMER"] + + for chain in target_structure.chains: + chain_name = str(chain["name"]) + res_start = int(chain["res_idx"]) + res_end = res_start + int(chain["res_num"]) + target_residues = target_structure.residues[res_start:res_end] + + if int(chain["mol_type"]) == nonpolymer_type: + atom_lookup = ligand_coords_by_chain.get(chain_name) + if atom_lookup is None: + raise ValueError(f"Missing ligand coordinates for chain {chain_name}.") + if len(target_residues) != 1: + raise 
ValueError( + f"Only single-residue SDF ligands are supported; " + f"chain {chain_name} has {len(target_residues)} residues." + ) + residue_coord_blocks = [atom_lookup] + else: + if chain_name not in model: + raise ValueError(f"Receptor is missing chain {chain_name}.") + source_residues = _protein_residues(model[chain_name], boltz) + if len(source_residues) != len(target_residues): + raise ValueError( + f"Chain {chain_name} residue count mismatch: " + f"{len(source_residues)} in receptor, " + f"{len(target_residues)} in target." + ) + residue_coord_blocks = _coords_by_chain_residue(source_residues) + + for residue, atom_lookup in zip(target_residues, residue_coord_blocks): + atom_start = int(residue["atom_idx"]) + atom_end = atom_start + int(residue["atom_num"]) + for atom_idx in range(atom_start, atom_end): + atom_name = str(atoms[atom_idx]["name"]).strip() + coords = atom_lookup.get(atom_name) + if coords is None: + atoms[atom_idx]["coords"] = (0.0, 0.0, 0.0) + atoms[atom_idx]["is_present"] = False + else: + atoms[atom_idx]["coords"] = coords + atoms[atom_idx]["is_present"] = True + + coords_data = [(coord,) for coord in atoms["coords"]] + coords = boltz["np"].array(coords_data, dtype=boltz["Coords"]) + return replace(target_structure, atoms=atoms, coords=coords) + + +def _read_sdf_poses(path: Path) -> list[LigandPose]: + from rdkit import Chem + + if path.suffix.lower() != ".sdf": + raise ValueError(f"--receptor mode expects SDF ligand poses, got {path}.") + + supplier = Chem.SDMolSupplier(str(path), sanitize=True, removeHs=False) + poses = [ + LigandPose(path=path, pose_index=idx, mol=mol) + for idx, mol in enumerate(supplier, start=1) + if mol is not None + ] + if not poses: + raise ValueError(f"No readable ligand poses in {path}.") + + for pose in poses: + if pose.mol.GetNumConformers() == 0: + raise ValueError(f"Ligand pose {path}#{pose.pose_index} has no conformer.") + return poses + + +def _smiles_from_ligand_mol(mol: Any) -> str: + from rdkit import Chem 
+ + mol_no_h = Chem.RemoveHs(Chem.Mol(mol), sanitize=False) + return Chem.MolToSmiles(mol_no_h, isomericSmiles=True) + + +def _ligand_coords_by_atom_name( + source_mol: Any, + target_mol: Any, +) -> dict[str, tuple[float, float, float]]: + from rdkit import Chem + + source_no_h = Chem.RemoveHs(Chem.Mol(source_mol), sanitize=False) + match = source_no_h.GetSubstructMatch(target_mol, useChirality=False) + if not match: + source_no_h = Chem.RemoveHs(Chem.Mol(source_mol), sanitize=True) + match = source_no_h.GetSubstructMatch(target_mol, useChirality=False) + if not match: + raise ValueError("Could not map SDF ligand atoms onto the Boltz ligand.") + if len(match) != target_mol.GetNumAtoms(): + raise ValueError("Incomplete SDF ligand atom mapping.") + + conformer = source_no_h.GetConformer() + atom_coords = {} + for target_idx, source_idx in enumerate(match): + target_atom = target_mol.GetAtomWithIdx(target_idx) + atom_name = target_atom.GetProp("name") + pos = conformer.GetAtomPosition(source_idx) + atom_coords[atom_name] = (float(pos.x), float(pos.y), float(pos.z)) + return atom_coords + + +def _ligand_target_mol(target: Any, ligand_chain_id: str) -> Any: + for chain in target.structure.chains: + if str(chain["name"]) != ligand_chain_id: + continue + res_idx = int(chain["res_idx"]) + res_name = str(target.structure.residues[res_idx]["name"]) + return target.extra_mols[res_name] + raise ValueError(f"Target is missing ligand chain {ligand_chain_id}.") + + +def _target_from_structure( + path: Path, + mols: dict[str, Any], + mol_dir: Path, + boltz: dict[str, Any], + compute_affinity: bool, +) -> Any: + record_id = _safe_id(path) + bio_structure = _read_bio_structure(path, record_id) + schema = _schema_from_bio_structure(record_id, bio_structure, boltz) + if compute_affinity: + schema = _with_affinity_property(schema) + target = boltz["parse_boltz_schema"]( + record_id, + schema, + mols, + str(mol_dir), + True, + ) + structure = _transfer_coordinates(target.structure, 
bio_structure, boltz) + return replace(target, structure=structure) + + +def _target_from_receptor_ligand( + record_id: str, + receptor_structure: Any, + ligand_pose: LigandPose, + ligand_chain_id: str, + mols: dict[str, Any], + mol_dir: Path, + boltz: dict[str, Any], + compute_affinity: bool, +) -> Any: + schema = { + "version": 1, + "sequences": [ + *_protein_schema_items_from_bio_structure( + record_id, + receptor_structure, + boltz, + ), + { + "ligand": { + "id": ligand_chain_id, + "smiles": _smiles_from_ligand_mol(ligand_pose.mol), + } + }, + ], + } + if compute_affinity: + schema = _with_affinity_property(schema, ligand_chain_id) + target = boltz["parse_boltz_schema"]( + record_id, + schema, + mols, + str(mol_dir), + True, + ) + ligand_coords = _ligand_coords_by_atom_name( + ligand_pose.mol, + _ligand_target_mol(target, ligand_chain_id), + ) + structure = _transfer_receptor_ligand_coordinates( + target.structure, + receptor_structure, + {ligand_chain_id: ligand_coords}, + boltz, + ) + return replace(target, structure=structure) + + +def _features_from_structure( + structure_path: Path, + tokenizer: Any, + featurizer: Any, + canonical_mols: dict[str, Any], + mol_dir: Path, + boltz: dict[str, Any], + compute_affinity: bool, +) -> tuple[dict[str, Any], str]: + record_id = _safe_id(structure_path) + target = _target_from_structure( + structure_path, + canonical_mols.copy(), + mol_dir, + boltz, + compute_affinity, + ) + + input_data = boltz["Input"]( + target.structure, + {}, + record=target.record, + residue_constraints=target.residue_constraints, + templates=target.templates, + extra_mols=target.extra_mols, + ) + tokenized = tokenizer.tokenize(input_data) + + molecules = {} + molecules.update(canonical_mols) + mol_names = set(tokenized.tokens["res_name"].tolist()) - set(molecules) + if mol_names: + molecules.update(boltz["load_molecules"](str(mol_dir), sorted(mol_names))) + + features = featurizer.process( + tokenized, + molecules=molecules, + 
random=boltz["np"].random.default_rng(42), + training=False, + max_atoms=None, + max_tokens=None, + max_seqs=boltz["const"].max_msa_seqs, + pad_to_max_seqs=False, + single_sequence_prop=0.0, + compute_frames=True, + inference_pocket_constraints=None, + inference_contact_constraints=None, + compute_constraint_features=True, + compute_affinity=compute_affinity, + ) + features["record"] = target.record + return features, record_id + + +def _features_from_ligand_pose( + ligand_pose: LigandPose, + receptor_structure: Any, + ligand_chain_id: str, + tokenizer: Any, + featurizer: Any, + canonical_mols: dict[str, Any], + mol_dir: Path, + boltz: dict[str, Any], + compute_affinity: bool, +) -> tuple[dict[str, Any], str]: + record_id = _safe_id(ligand_pose.path) + if ligand_pose.pose_index > 1: + record_id = f"{record_id}_pose{ligand_pose.pose_index}" + target = _target_from_receptor_ligand( + record_id, + receptor_structure, + ligand_pose, + ligand_chain_id, + canonical_mols.copy(), + mol_dir, + boltz, + compute_affinity, + ) + + input_data = boltz["Input"]( + target.structure, + {}, + record=target.record, + residue_constraints=target.residue_constraints, + templates=target.templates, + extra_mols=target.extra_mols, + ) + tokenized = tokenizer.tokenize(input_data) + + molecules = {} + molecules.update(canonical_mols) + molecules.update(target.extra_mols) + mol_names = set(tokenized.tokens["res_name"].tolist()) - set(molecules) + if mol_names: + molecules.update(boltz["load_molecules"](str(mol_dir), sorted(mol_names))) + + features = featurizer.process( + tokenized, + molecules=molecules, + random=boltz["np"].random.default_rng(42), + training=False, + max_atoms=None, + max_tokens=None, + max_seqs=boltz["const"].max_msa_seqs, + pad_to_max_seqs=False, + single_sequence_prop=0.0, + compute_frames=True, + inference_pocket_constraints=None, + inference_contact_constraints=None, + compute_constraint_features=True, + compute_affinity=compute_affinity, + ) + features["record"] = 
target.record + return features, record_id + + +def _load_model(args: argparse.Namespace, device: Any, boltz: dict[str, Any]) -> Any: + checkpoint = ( + args.affinity_checkpoint + if args.affinity_checkpoint is not None + else args.cache / ("boltz2_aff.ckpt" if args.affinity else "boltz2_conf.ckpt") + ) + checkpoint = checkpoint.expanduser().resolve() + if not checkpoint.exists(): + raise FileNotFoundError( + f"Boltz2 checkpoint not found: {checkpoint}. " + "Run without --no_download or provide --cache." + ) + + torch = boltz["torch"] + diffusion_params = boltz["Boltz2DiffusionParams"]() + pairformer_args = boltz["PairformerArgsV2"]() + msa_args = boltz["MSAModuleArgs"]( + subsample_msa=False, + num_subsampled_msa=boltz["const"].max_msa_seqs, + use_paired_feature=True, + ) + steering_args = boltz["BoltzSteeringParams"]() + use_kernels = args.use_kernels and ( + device.type == "cuda" + and torch.cuda.get_device_properties(device).major >= 8 + ) + model = boltz["Boltz2"].load_from_checkpoint( + checkpoint, + strict=True, + predict_args={}, + map_location="cpu", + diffusion_process_args=asdict(diffusion_params), + ema=False, + use_kernels=use_kernels, + pairformer_args=asdict(pairformer_args), + msa_args=asdict(msa_args), + steering_args=asdict(steering_args), + affinity_mw_correction=not args.no_affinity_mw_correction, + skip_run_structure=True, + ) + if args.affinity: + if not hasattr(model, "affinity_module") and not hasattr( + model, + "affinity_module1", + ): + raise RuntimeError( + f"Checkpoint does not expose a Boltz2 affinity head: {checkpoint}" + ) + # Boltz2.forward expects diffusion-produced sample_atom_coords for its + # built-in affinity path. This CLI scores externally supplied coords, so + # affinity is called explicitly after the no-diffusion forward pass. 
+ model.affinity_prediction = False + model.eval() + model.to(device) + return model + + +def _batch_to_device(batch: dict[str, Any], device: Any, torch: Any) -> dict[str, Any]: + moved = {} + for key, value in batch.items(): + moved[key] = value.to(device) if isinstance(value, torch.Tensor) else value + return moved + + +def _batch_signature(batch: dict[str, Any]) -> str: + """Create a topology signature used to guard trunk reuse.""" + hasher = hashlib.sha256() + for key in ["res_type", "asym_id", "mol_type", "token_pad_mask", "atom_pad_mask"]: + value = batch[key].detach().cpu().contiguous() + hasher.update(key.encode()) + hasher.update(str(tuple(value.shape)).encode()) + hasher.update(value.numpy().tobytes()) + + atom_to_token = batch["atom_to_token"].detach().cpu().argmax(dim=-1).contiguous() + hasher.update(b"atom_to_token_argmax") + hasher.update(str(tuple(atom_to_token.shape)).encode()) + hasher.update(atom_to_token.numpy().tobytes()) + return hasher.hexdigest() + + +def _confidence_from_cached_trunk( + model: Any, + batch: dict[str, Any], + cached: dict[str, Any], +) -> dict[str, Any]: + s_inputs = model.input_embedder(batch) + out = model.confidence_module( + s_inputs=s_inputs.detach(), + s=cached["s"].detach(), + z=cached["z"].detach(), + x_pred=batch["coords"].repeat_interleave(1, 0), + feats=batch, + pred_distogram_logits=cached["pdistogram"][:, :, :, 0].detach(), + multiplicity=1, + run_sequentially=True, + use_kernels=model.use_kernels, + ) + out.update({ + "s": cached["s"], + "z": cached["z"], + "pdistogram": cached["pdistogram"], + }) + return out + + +def _affinity_from_existing_coords( + model: Any, + batch: dict[str, Any], + out: dict[str, Any], + torch: Any, +) -> dict[str, Any]: + if "affinity_token_mask" not in batch: + raise ValueError( + "Affinity features are missing. Build features with compute_affinity=True." 
+ ) + + pad_token_mask = batch["token_pad_mask"][0] + rec_mask = (batch["mol_type"][0] == 0) * pad_token_mask + lig_mask = batch["affinity_token_mask"][0].to(torch.bool) * pad_token_mask + cross_pair_mask = ( + lig_mask[:, None] * rec_mask[None, :] + + rec_mask[:, None] * lig_mask[None, :] + + lig_mask[:, None] * lig_mask[None, :] + ) + z_affinity = out["z"] * cross_pair_mask[None, :, :, None] + coords_affinity = batch["coords"].detach() + if coords_affinity.ndim == 3: + coords_affinity = coords_affinity[:, None] + if coords_affinity.ndim != 4: + raise ValueError( + f"Expected coordinate tensor with 3 or 4 dimensions, got " + f"{tuple(coords_affinity.shape)}." + ) + + s_inputs = model.input_embedder(batch, affinity=True) + if getattr(model, "affinity_ensemble", False): + dict_out_affinity1 = model.affinity_module1( + s_inputs=s_inputs.detach(), + z=z_affinity.detach(), + x_pred=coords_affinity, + feats=batch, + multiplicity=1, + use_kernels=model.use_kernels, + ) + dict_out_affinity2 = model.affinity_module2( + s_inputs=s_inputs.detach(), + z=z_affinity.detach(), + x_pred=coords_affinity, + feats=batch, + multiplicity=1, + use_kernels=model.use_kernels, + ) + dict_out_affinity1["affinity_probability_binary"] = torch.sigmoid( + dict_out_affinity1["affinity_logits_binary"] + ) + dict_out_affinity2["affinity_probability_binary"] = torch.sigmoid( + dict_out_affinity2["affinity_logits_binary"] + ) + affinity_out = { + "affinity_pred_value": ( + dict_out_affinity1["affinity_pred_value"] + + dict_out_affinity2["affinity_pred_value"] + ) + / 2, + "affinity_probability_binary": ( + dict_out_affinity1["affinity_probability_binary"] + + dict_out_affinity2["affinity_probability_binary"] + ) + / 2, + "affinity_pred_value1": dict_out_affinity1["affinity_pred_value"], + "affinity_probability_binary1": dict_out_affinity1[ + "affinity_probability_binary" + ], + "affinity_pred_value2": dict_out_affinity2["affinity_pred_value"], + "affinity_probability_binary2": dict_out_affinity2[ + 
"affinity_probability_binary" + ], + } + if getattr(model, "affinity_mw_correction", False): + model_coef = 1.03525938 + mw_coef = -0.59992683 + bias = 2.83288489 + mw = batch["affinity_mw"][0] ** 0.3 + affinity_out["affinity_pred_value"] = ( + model_coef * affinity_out["affinity_pred_value"] + mw_coef * mw + bias + ) + return affinity_out + + dict_out_affinity = model.affinity_module( + s_inputs=s_inputs.detach(), + z=z_affinity.detach(), + x_pred=coords_affinity, + feats=batch, + multiplicity=1, + use_kernels=model.use_kernels, + ) + return { + "affinity_pred_value": dict_out_affinity["affinity_pred_value"], + "affinity_probability_binary": torch.sigmoid( + dict_out_affinity["affinity_logits_binary"] + ), + } + + +def _scalar(out: dict[str, Any], key: str) -> float: + value = out[key] + if hasattr(value, "detach"): + value = value.detach().float().cpu().reshape(-1) + return float(value[0].item()) + return float(value) + + +def _confidence_summary(out: dict[str, Any], torch: Any) -> dict[str, float]: + iptm = out["iptm"] + ptm = out["ptm"] + ranking_term = iptm if not torch.allclose(iptm, torch.zeros_like(iptm)) else ptm + confidence_score = (4 * out["complex_plddt"] + ranking_term) / 5 + scored = { + key: _scalar(out, key) + for key in SUMMARY_KEYS + if key != "confidence_score" + } + scored["confidence_score"] = float( + confidence_score.detach().float().cpu().reshape(-1)[0].item() + ) + return {key: scored[key] for key in SUMMARY_KEYS} + + +def _affinity_summary(out: dict[str, Any]) -> dict[str, float]: + return { + key: _scalar(out, key) + for key in AFFINITY_KEYS + if key in out + } + + +def _serialise_pair_chains_iptm(value: Any) -> Any: + if hasattr(value, "detach"): + return value.detach().float().cpu().tolist() + if isinstance(value, dict): + return { + str(k): _serialise_pair_chains_iptm(v) + for k, v in value.items() + } + return value + + +def _write_outputs( + out_dir: Path, + record_id: str, + source_path: Path, + summary: dict[str, float], + out: 
dict[str, Any], + write_full_pae: bool, + write_full_pde: bool, +) -> dict[str, Any]: + import numpy as np + + record_dir = out_dir / record_id + record_dir.mkdir(parents=True, exist_ok=True) + + json_payload = { + "source_structure": str(source_path), + **summary, + "pair_chains_iptm": _serialise_pair_chains_iptm( + out.get("pair_chains_iptm", {}) + ), + } + confidence_path = record_dir / f"confidence_{record_id}.json" + confidence_path.write_text(json.dumps(json_payload, indent=4)) + + plddt = out["plddt"].detach().float().cpu().numpy() + np.savez_compressed(record_dir / f"plddt_{record_id}.npz", plddt=plddt) + + if write_full_pae and "pae" in out: + pae = out["pae"].detach().float().cpu().numpy() + np.savez_compressed(record_dir / f"pae_{record_id}.npz", pae=pae) + if write_full_pde and "pde" in out: + pde = out["pde"].detach().float().cpu().numpy() + np.savez_compressed(record_dir / f"pde_{record_id}.npz", pde=pde) + + if "affinity_pred_value" in summary: + affinity_payload = { + key: summary[key] + for key in AFFINITY_KEYS + if key in summary + } + affinity_path = record_dir / f"affinity_{record_id}.json" + affinity_path.write_text(json.dumps(affinity_payload, indent=4)) + else: + affinity_path = None + + return { + "structure": str(source_path), + "record_id": record_id, + **summary, + "confidence_json": str(confidence_path), + **({"affinity_json": str(affinity_path)} if affinity_path else {}), + } + + +def score_existing_complexes(args: argparse.Namespace) -> list[dict[str, Any]]: + boltz = _import_boltz() + torch = boltz["torch"] + device = _select_device(torch, args.device) + args.cache = args.cache.expanduser().resolve() + args.cache.mkdir(parents=True, exist_ok=True) + mol_dir = args.cache / "mols" + if not args.no_download: + boltz["download_boltz2"](args.cache) + if not mol_dir.exists(): + raise FileNotFoundError( + f"Boltz2 molecule cache not found: {mol_dir}. " + "Run without --no_download or provide --cache." 
+ ) + + args.out_dir.mkdir(parents=True, exist_ok=True) + canonical_mols = boltz["load_canonicals"](str(mol_dir)) + tokenizer = boltz["Boltz2Tokenizer"]() + featurizer = boltz["Boltz2Featurizer"]() + model = _load_model(args, device, boltz) + receptor_structure = None + if args.receptor is not None: + args.receptor = args.receptor.expanduser().resolve() + receptor_structure = _read_bio_structure(args.receptor, "receptor") + + rows = [] + trunk_cache = None + trunk_signature = None + seen_record_ids: dict[str, int] = {} + with torch.inference_mode(): + for input_path in args.structures: + input_path = input_path.expanduser().resolve() + feature_inputs: list[Path | LigandPose] + if receptor_structure is None: + feature_inputs = [input_path] + else: + feature_inputs = list(_read_sdf_poses(input_path)) + + for feature_input in feature_inputs: + source_path = ( + feature_input.path + if isinstance(feature_input, LigandPose) + else feature_input + ) + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + if isinstance(feature_input, LigandPose): + assert receptor_structure is not None + features, record_id = _features_from_ligand_pose( + feature_input, + receptor_structure, + args.ligand_chain_id, + tokenizer, + featurizer, + canonical_mols, + mol_dir, + boltz, + args.affinity, + ) + else: + features, record_id = _features_from_structure( + feature_input, + tokenizer, + featurizer, + canonical_mols, + mol_dir, + boltz, + args.affinity, + ) + batch = boltz["collate"]([features]) + batch = _batch_to_device(batch, device, torch) + signature = _batch_signature(batch) + if args.reuse_trunk and trunk_cache is not None: + if signature != trunk_signature: + raise ValueError( + "--reuse_trunk requires all structures to have the same " + "token and atom topology as the first structure." 
+ ) + out = _confidence_from_cached_trunk(model, batch, trunk_cache) + else: + out = model( + batch, + recycling_steps=args.recycling_steps, + num_sampling_steps=None, + diffusion_samples=1, + max_parallel_samples=None, + run_confidence_sequentially=True, + ) + if args.reuse_trunk and trunk_cache is None: + trunk_cache = { + "s": out["s"], + "z": out["z"], + "pdistogram": out["pdistogram"], + } + trunk_signature = signature + summary = _confidence_summary(out, torch) + if args.affinity: + affinity_out = _affinity_from_existing_coords( + model, + batch, + out, + torch, + ) + out.update(affinity_out) + summary.update(_affinity_summary(out)) + + seen_count = seen_record_ids.get(record_id, 0) + seen_record_ids[record_id] = seen_count + 1 + output_id = ( + record_id if seen_count == 0 else f"{record_id}_{seen_count}" + ) + rows.append( + _write_outputs( + args.out_dir, + output_id, + source_path, + summary, + out, + args.write_full_pae, + args.write_full_pde, + ) + ) + + summary_path = args.out_dir / "scores.csv" + with summary_path.open("w", newline="") as handle: + writer = csv.DictWriter(handle, fieldnames=list(rows[0].keys())) + writer.writeheader() + writer.writerows(rows) + return rows + + +def main(argv: list[str] | None = None) -> None: + args = _parse_args(argv) + rows = score_existing_complexes(args) + print(f"Wrote scores for {len(rows)} structures to {args.out_dir / 'scores.csv'}") + + +if __name__ == "__main__": + main() diff --git a/docs/boltz_crystal_docking.md b/docs/boltz_crystal_docking.md new file mode 100644 index 0000000..aa43f7a --- /dev/null +++ b/docs/boltz_crystal_docking.md @@ -0,0 +1,130 @@ +# Boltz Crystal-Pocket Docking + +`abcfold.boltz.dock_crystal` runs Boltz structure generation with a fixed +crystal receptor as a template, a ligand SMILES, and pocket constraints. Use it +when you know the receptor structure and pocket, but want Boltz to generate the +ligand pose instead of only scoring an existing pose. 
+ +This is a Boltz-native docking/co-folding mode. It is not a classical force +field minimizer: Boltz still runs diffusion sampling, but the protein is guided +toward the crystal template and the ligand is guided into the pocket. + +## Command + +Installed ABCFold environments expose this as: + +```bash +boltz-dock-crystal --help +``` + +For a local checkout managed with `uv`, use either: + +```bash +uv run boltz-dock-crystal --help +``` + +or activate the environment before using the command directly: + +```bash +source .venv/bin/activate +boltz-dock-crystal --help +``` + +The development-module equivalent is: + +```bash +python -m abcfold.boltz.dock_crystal --help +``` + +## Explicit Pocket Residues + +```bash +boltz-dock-crystal \ + crystal_receptor.pdb \ + "CCOc1ccc(...)" \ + --protein_chain A \ + --pocket_residue A:145 \ + --pocket_residue A:146 \ + --pocket_residue A:189 \ + --out_dir boltz_crystal_dock \ + --affinity +``` + +`--pocket_residue` uses PDB residue numbering by default. The wrapper converts +those residues to the sequence indices required by Boltz constraints. Use +`--pocket_numbering sequence` if the input numbers are already Boltz sequence +indices. + +## Infer Pocket From a Crystal Ligand + +If the receptor PDB still contains a reference ligand chain, the wrapper can +infer pocket residues by distance: + +```bash +boltz-dock-crystal \ + crystal_complex.pdb \ + "CCOc1ccc(...)" \ + --protein_chain A \ + --reference_ligand_chain L \ + --pocket_cutoff 6.0 \ + --out_dir boltz_crystal_dock \ + --affinity +``` + +Only protein chains are written to the Boltz sequence section. The reference +ligand is used to choose pocket residues; the docked ligand still comes from +the SMILES argument. 
+ +## Template and Pocket Strength + +By default, the generated YAML includes: + +- `templates.force: true` +- `templates.threshold: 1.0` +- `constraints.pocket.force: true` +- `constraints.pocket.max_distance: 6.0` + +These settings keep the protein close to the crystal receptor and steer the +ligand into the pocket. They can be relaxed: + +```bash +boltz-dock-crystal \ + crystal_receptor.pdb \ + "CCOc1ccc(...)" \ + --pocket_residue A:145 \ + --template_threshold 2.0 \ + --max_distance 8.0 +``` + +Use `--no_force_template` or `--no_force_pocket` only when you want Boltz to be +less constrained. + +## Accuracy and Runtime + +This mode runs Boltz diffusion, so it is much slower than +`abcfold.boltz.score_existing`, which only scores supplied coordinates. The +default docking settings use: + +- `--diffusion_samples 25` +- `--recycling_steps 10` +- `--sampling_steps 200` +- `--use_potentials` + +For a quick dry run that only writes the Boltz YAML and command: + +```bash +boltz-dock-crystal \ + crystal_receptor.pdb \ + "CCOc1ccc(...)" \ + --pocket_residue A:145 \ + --dry_run +``` + +The output directory contains: + +- `boltz_crystal_dock.yaml` +- `boltz_crystal_dock_command.json` +- Boltz prediction outputs when `--dry_run` is not used + +Use `--use_msa_server` to let Boltz fetch MSAs. Without it, the wrapper writes +`msa: empty` for speed and offline execution. diff --git a/docs/boltz_existing_scoring.md b/docs/boltz_existing_scoring.md new file mode 100644 index 0000000..3a9a69e --- /dev/null +++ b/docs/boltz_existing_scoring.md @@ -0,0 +1,185 @@ +# Boltz Existing-Structure Scoring + +`abcfold.boltz.score_existing` scores already-built complexes with Boltz2 +without running Boltz diffusion sampling. Use it when protein or complex +coordinates already exist and you want Boltz2 confidence and, optionally, +Boltz2 affinity estimates for those coordinates. + +This is a scoring utility, not a structure-generation or local-minimization +tool. 
It does not move the protein or ligand. + +## Command + +Installed ABCFold environments expose this as: + +```bash +boltz-score-existing --help +``` + +For a local checkout managed with `uv`, use either: + +```bash +uv run boltz-score-existing --help +``` + +or activate the environment before using the command directly: + +```bash +source .venv/bin/activate +boltz-score-existing --help +``` + +The development-module equivalent is: + +```bash +python -m abcfold.boltz.score_existing --help +``` + +## What It Computes + +By default, the command writes Boltz2 confidence metrics: + +- `confidence_score` +- `ptm` +- `iptm` +- `ligand_iptm` +- `protein_iptm` +- `complex_plddt` +- `complex_iplddt` +- `complex_pde` +- `complex_ipde` + +With `--affinity`, it also writes Boltz2 affinity outputs from +`boltz2_aff.ckpt`: + +- `affinity_pred_value` +- `affinity_probability_binary` +- `affinity_pred_value1` +- `affinity_probability_binary1` +- `affinity_pred_value2` +- `affinity_probability_binary2` + +The affinity mode uses the supplied coordinates directly. It does not run +diffusion to generate `sample_atom_coords`. + +## Score Ready PDB or mmCIF Complexes + +Use this mode when each input file already contains the protein and ligand +chains in one `.pdb`, `.cif`, or `.mmcif` file. + +```bash +boltz-score-existing \ + complex_1.cif complex_2.cif \ + --out_dir boltz_existing_scores \ + --cache ~/.boltz \ + --device cuda \ + --no_download +``` + +For affinity estimates on the same coordinates: + +```bash +boltz-score-existing \ + complex_1.cif complex_2.cif \ + --out_dir boltz_existing_scores_affinity \ + --cache ~/.boltz \ + --device cuda \ + --no_download \ + --affinity +``` + +Affinity scoring for ready complex files currently requires exactly one ligand +chain unless a future wrapper identifies the binder explicitly. + +## Score Receptor PDB Plus Ligand SDF Poses + +Use this mode for docking-style outputs where the receptor is fixed and ligand +poses are stored as SDF files. 
This is the format used by DEKOIS2/Matcha-style +and HEDGEHOG-style pose scoring. + +```bash +boltz-score-existing \ + poses.sdf \ + --receptor receptor.pdb \ + --out_dir boltz_pose_scores \ + --cache ~/.boltz \ + --device cuda \ + --no_download +``` + +For confidence plus affinity: + +```bash +boltz-score-existing \ + poses.sdf \ + --receptor receptor.pdb \ + --out_dir boltz_pose_scores_affinity \ + --cache ~/.boltz \ + --device cuda \ + --no_download \ + --affinity +``` + +If several SDF files are provided, each readable conformer is scored. The +ligand is represented from the SDF molecule, while the receptor coordinates are +read from the receptor PDB. + +## Reusing the Trunk + +`--reuse_trunk` caches the first trunk output and reuses it for later inputs. +This is only valid when all scored structures have the same token and atom +topology, for example multiple poses of the same ligand against the same +protein. + +```bash +boltz-score-existing \ + same_ligand_poses.sdf \ + --receptor receptor.pdb \ + --out_dir boltz_pose_scores_reuse \ + --cache ~/.boltz \ + --device cuda \ + --no_download \ + --affinity \ + --reuse_trunk +``` + +Do not use `--reuse_trunk` for different ligands. The command checks the input +topology and fails if later structures do not match the first one. + +## Outputs + +The output directory contains: + +- `scores.csv`: one summary row per scored structure or pose. +- `{record_id}/confidence_{record_id}.json`: confidence metrics and + pair-chain `ipTM`. +- `{record_id}/plddt_{record_id}.npz`: compressed pLDDT array. +- `{record_id}/affinity_{record_id}.json`: affinity outputs when `--affinity` + is enabled. +- Optional `pae_*.npz` and `pde_*.npz` files when `--write_full_pae` or + `--write_full_pde` is enabled. + +## Runtime Notes + +The command avoids the expensive Boltz diffusion sampling path, so it is faster +than full Boltz structure prediction. It still loads the Boltz2 checkpoint and +runs the trunk, so the first score is not instantaneous.
On GPU, scoring +additional poses in the same process is much cheaper than the cold start. + +The utility currently sets protein MSAs to `empty`, so proteins are scored in +single-sequence mode. That is useful for fast pose screening, but confidence and +affinity values should be calibrated against the intended benchmark before +using them as final decision metrics. + +## Required Cache Files + +The default cache is `~/.boltz`. For confidence scoring, it must contain: + +- `mols/` +- `boltz2_conf.ckpt` + +For `--affinity`, it must also contain: + +- `boltz2_aff.ckpt` + +Omit `--no_download` if the cache should be populated automatically by Boltz. diff --git a/pyproject.toml b/pyproject.toml index 8f04af3..1f20e46 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,6 +91,8 @@ abcfold = ["data/config.ini"] [project.scripts] abcfold = "abcfold.abcfold:main" +boltz-score-existing = "abcfold.boltz.score_existing:main" +boltz-dock-crystal = "abcfold.boltz.dock_crystal:main" mmseqs2msa = "abcfold.scripts.add_mmseqs_msa:main" custom_templates = "abcfold.scripts.add_custom_template:main" ipsae = "abcfold.scripts.ipsae:main" diff --git a/tests/test_boltz_dock_crystal.py b/tests/test_boltz_dock_crystal.py new file mode 100644 index 0000000..60afdf3 --- /dev/null +++ b/tests/test_boltz_dock_crystal.py @@ -0,0 +1,115 @@ +from pathlib import Path + +from abcfold.boltz.dock_crystal import (_parse_args, + generate_boltz_crystal_dock_command, + prepare_crystal_docking_input) + + +def atom_line( + record: str, + serial: int, + atom: str, + resname: str, + chain: str, + resnum: int, + x_coord: float, + y_coord: float, + element: str, +) -> str: + return ( + f"{record:<6}{serial:5d} {atom:<4} {resname:>3} {chain}{resnum:4d}" + f" {x_coord:8.3f}{y_coord:8.3f}{0.0:8.3f}" + f" 1.00 20.00 {element:>2}" + ) + + +def write_receptor(path: Path) -> None: + path.write_text( + "\n".join([ + atom_line("ATOM", 1, "N", "ALA", "A", 1, 0.0, 0.0, "N"), + atom_line("ATOM", 2, "CA", "ALA", "A", 1, 
1.5, 0.0, "C"), + atom_line("ATOM", 3, "C", "ALA", "A", 1, 2.5, 1.0, "C"), + atom_line("ATOM", 4, "N", "GLY", "A", 2, 4.0, 1.0, "N"), + atom_line("ATOM", 5, "CA", "GLY", "A", 2, 5.0, 2.0, "C"), + atom_line("ATOM", 6, "C", "GLY", "A", 2, 6.0, 1.0, "C"), + atom_line("ATOM", 7, "N", "SER", "A", 3, 8.0, 1.0, "N"), + atom_line("ATOM", 8, "CA", "SER", "A", 3, 9.0, 2.0, "C"), + atom_line("ATOM", 9, "C", "SER", "A", 3, 10.0, 1.0, "C"), + atom_line("HETATM", 10, "C1", "LIG", "L", 1, 5.1, 2.1, "C"), + atom_line("HETATM", 11, "O1", "LIG", "L", 1, 5.8, 2.1, "O"), + "TER", + "END", + "", + ]) + ) + + +def test_prepare_crystal_docking_input_with_pdb_numbered_pocket(tmp_path): + receptor = tmp_path / "receptor.pdb" + write_receptor(receptor) + args = _parse_args([ + str(receptor), + "CCO", + "--out_dir", + str(tmp_path / "out"), + "--pocket_residue", + "A:2", + "--affinity", + "--dry_run", + ]) + + prepared = prepare_crystal_docking_input(args) + yaml_text = prepared.yaml_path.read_text() + + assert prepared.contacts == [["A", 2]] + assert "sequence: \"AGS\"" in yaml_text + assert "smiles: \"CCO\"" in yaml_text + assert "contacts: [[\"A\", 2]]" in yaml_text + assert "templates:" in yaml_text + assert "template_id: \"A1\"" in yaml_text + assert "properties:" in yaml_text + assert "--use_potentials" in prepared.command + + +def test_prepare_crystal_docking_input_can_infer_pocket(tmp_path): + receptor = tmp_path / "receptor.pdb" + write_receptor(receptor) + args = _parse_args([ + str(receptor), + "CCO", + "--out_dir", + str(tmp_path / "out"), + "--reference_ligand_chain", + "L", + "--pocket_cutoff", + "1.0", + "--dry_run", + ]) + + prepared = prepare_crystal_docking_input(args) + + assert prepared.contacts == [["A", 2]] + + +def test_generate_boltz_crystal_dock_command(tmp_path): + args = _parse_args([ + str(tmp_path / "receptor.pdb"), + "CCO", + "--pocket_residue", + "A:1", + "--use_msa_server", + "--step_scale", + "1.5", + "--dry_run", + ]) + + cmd = 
generate_boltz_crystal_dock_command( + tmp_path / "input.yaml", + tmp_path / "out", + args, + ) + + assert cmd[:2] == ["boltz", "predict"] + assert "--use_msa_server" in cmd + assert "--use_potentials" in cmd + assert "--step_scale" in cmd diff --git a/tests/test_boltz_entrypoints.py b/tests/test_boltz_entrypoints.py new file mode 100644 index 0000000..9e620e9 --- /dev/null +++ b/tests/test_boltz_entrypoints.py @@ -0,0 +1,11 @@ +from importlib.metadata import entry_points + + +def test_boltz_console_entrypoints_are_registered(): + scripts = { + entry_point.name: entry_point.value + for entry_point in entry_points(group="console_scripts") + } + + assert scripts["boltz-score-existing"] == "abcfold.boltz.score_existing:main" + assert scripts["boltz-dock-crystal"] == "abcfold.boltz.dock_crystal:main"