diff --git a/README.md b/README.md index 2a03fd1..32340b0 100644 --- a/README.md +++ b/README.md @@ -136,6 +136,14 @@ However, there you may wish to use the following flags to add run time options s #### OpenFold3 arguments - `--inference_ckpt_path` [optional] Path for model checkpoint to be used for inference. If not specified, will attempt to find or download parameters in ~/.openfold3/ +#### Chai-1 fast runner + +`--chai1` uses ABCFold's persistent Chai worker by default. The worker keeps +the same Chai quality settings as the standard runner while reducing repeated +process setup overhead and distributing seed jobs across the requested +`--gpus` slots. `--chai_fast` was intentionally removed; use `--chai1`. +See [Chai Fast Runner](docs/chai_fast_runner.md) for smoke checks and details. + #### Template arguments - `--templates`: Flag to enable a template search diff --git a/abcfold/abcfold.py b/abcfold/abcfold.py index 8a06eb4..77e5794 100644 --- a/abcfold/abcfold.py +++ b/abcfold/abcfold.py @@ -209,7 +209,7 @@ def run(args, config, defaults, config_file): successful_runs.append(boltz_success) if args.chai1: - from abcfold.chai1.run_chai1 import run_chai + from abcfold.chai1.run_chai1_fast import run_chai_fast as run_chai template_hits_path = None if args.templates and args.mmseqs2: diff --git a/abcfold/chai1/run_chai1.py b/abcfold/chai1/run_chai1.py deleted file mode 100644 index 5217b08..0000000 --- a/abcfold/chai1/run_chai1.py +++ /dev/null @@ -1,219 +0,0 @@ -import logging -import os -import shutil -import subprocess -import tempfile -from pathlib import Path -from typing import Union - -from abcfold.chai1.af3_to_chai import ChaiFasta -from abcfold.chai1.check_install import ensure_chai_env - -logger = logging.getLogger("logger") -os.environ["DISABLE_PANDERA_IMPORT_WARNING"] = "True" - - -def normalize_device(gpus: str | None) -> str | None: - if gpus is None: - return None - if gpus == "cpu": - return "cpu" - if gpus == "all": - return "cuda" - - # Validate and normalize the GPU list. - gpu_ids = [] - for gpu in gpus.split(","): - gpu = gpu.strip() - if not gpu.isdigit(): - raise ValueError(f"Invalid GPU ID: {gpu}") - gpu_ids.append(gpu) - - # Chai accepts a single device, so use the first requested GPU. - return f"cuda:{gpu_ids[0]}" - - -def run_chai( - input_json: Union[str, Path], - output_dir: Union[str, Path], - config: dict, - save_input: bool = False, - test: bool = False, - number_of_models: int = 5, - num_recycles: int = 10, - use_templates_server: bool = False, - template_hits_path: Path | None = None, - device: str | None = None, -) -> bool: - """ - Run Chai-1 using the input JSON file - - Args: - input_json (Union[str, Path]): Path to the input JSON file - output_dir (Union[str, Path]): Path to the output directory - config (dict): Configuration dictionary - save_input (bool): If True, save the input fasta file and MSA to the output - directory - test (bool): If True, run the test command - number_of_models (int): Number of models to generate - num_recycles (int): Number of trunk recycles - use_templates_server (bool): If True, use templates from the server - template_hits_path (Path): Path to the template hits m8 file - device (str | None): If specified, use the specified GPU - - Returns: - Bool: True if the Chai-1 run was successful, False otherwise - - """ - input_json = Path(input_json) - output_dir = Path(output_dir) - - logger.debug("Checking if Chai-1 is installed") - env = ensure_chai_env(config=config) - - with tempfile.TemporaryDirectory() as temp_dir: - working_dir = Path(temp_dir) - if save_input: - logger.info("Saving input fasta file and msa to the output directory") - working_dir = output_dir - working_dir.mkdir(parents=True, exist_ok=True) - - chai_fasta = ChaiFasta(working_dir, config=config) - chai_fasta.json_to_fasta(input_json) - - out_fasta = chai_fasta.fasta - msa_dir = chai_fasta.working_dir - out_constraints = chai_fasta.constraints - - normalized_device = normalize_device(device) - for seed in chai_fasta.seeds: - chai_output_dir = output_dir / f"chai_output_seed-{seed}" - - logger.info(f"Running Chai-1 using seed {seed}") - cmd = ( - generate_chai_command( - out_fasta, - msa_dir, - out_constraints, - chai_output_dir, - number_of_models, - num_recycles=num_recycles, - seed=seed, - use_templates_server=use_templates_server, - template_hits_path=template_hits_path, - device=normalized_device, - ) - if not test - else generate_chai_test_command() - ) - - try: - env.run(cmd) - except subprocess.CalledProcessError as e: - stderr = e.stderr or "" - if stderr: - if chai_output_dir.exists(): - output_err_file = chai_output_dir / "chai_error.log" - else: - output_err_file = chai_output_dir.parent / "chai_error.log" - output_err_file.write_text(stderr) - logger.error( - "Chai-1 run failed. Error log is in %s", output_err_file - ) - else: - logger.error("Chai-1 run failed") - return False - - logger.info("Chai-1 run complete") - return True - - -def generate_chai_command( - input_fasta: Union[str, Path], - msa_dir: Union[str, Path], - input_constraints: Union[str, Path], - output_dir: Union[str, Path], - number_of_models: int = 5, - num_recycles: int = 10, - seed: int = 42, - use_templates_server: bool = False, - template_hits_path: Path | None = None, - device: str | None = None, -) -> list: - """ - Generate the Chai-1 command - - Args: - input_fasta (Union[str, Path]): Path to the input fasta file - msa_dir (Union[str, Path]): Path to the MSA directory - input_constraints (Union[str, Path]): Path to the input constraints file - output_dir (Union[str, Path]): Path to the output directory - number_of_models (int): Number of models to generate - num_recycles (int): Number of trunk recycles - seed (int): Seed for the random number generator - use_templates_server (bool): If True, use templates from the server - template_hits_path (Path): Path to the template hits m8 file - device (str | None): If specified, use the specified GPU - - Returns: - list: The Chai-1 command - - """ - - chai_exe = Path(__file__).parent / "chai.py" - cmd = [ - "python", - str(chai_exe), - "fold", - str(input_fasta) - ] - - if Path(msa_dir).exists(): - cmd += ["--msa-directory", str(msa_dir)] - if Path(input_constraints).exists(): - cmd += ["--constraint-path", str(input_constraints)] - - cmd += ["--num-diffn-samples", str(number_of_models)] - # Do not lower this without full validation: 5 diffusion steps produced physically invalid structures. - cmd += ["--num-diffn-timesteps", "200"] - cmd += ["--num-trunk-recycles", str(num_recycles)] - cmd += ["--seed", str(seed)] - - assert not ( - use_templates_server and template_hits_path - ), "Cannot specify both templates server and path" - - if shutil.which("kalign") is None and (use_templates_server or template_hits_path): - logger.warning( - "kalign not found, skipping template search kalign is required. \ -Please install kalign to use templates with Chai-1." - ) - else: - if use_templates_server: - cmd += ["--use-templates-server"] - if template_hits_path: - cmd += ["--template-hits-path", str(template_hits_path)] - - if device is not None and device != "all": - cmd += ["--device", device] - - cmd += [str(output_dir)] - - return cmd - - -def generate_chai_test_command() -> list: - """ - Generate the Chai-1 test command - - Args: - None - - Returns: - list: The Chai-1 test command - """ - return [ - "chai-lab", - "fold", - "--help", - ] diff --git a/abcfold/chai1/run_chai1_fast.py b/abcfold/chai1/run_chai1_fast.py new file mode 100644 index 0000000..6598207 --- /dev/null +++ b/abcfold/chai1/run_chai1_fast.py @@ -0,0 +1,312 @@ +import json +import logging +import subprocess +import tempfile +import textwrap +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from typing import Iterable, Union + +from abcfold.chai1.af3_to_chai import ChaiFasta +from abcfold.chai1.check_install import ensure_chai_env + +logger = logging.getLogger("logger") + + +def normalize_device(device: str | None) -> str | None: + if device is None: + return None + if device == "cpu": + return "cpu" + if device == "all": + return "cuda" + if "," in device: + device = device.split(",")[0] + if device.isdigit(): + return f"cuda:{device}" + return device + + +def normalize_devices(gpus: str | None) -> list[str | None]: + if gpus is None: + return [None] + if gpus == "cpu": + return ["cpu"] + if gpus == "all": + return ["cuda"] + + devices: list[str | None] = [] + for gpu in gpus.split(","): + device = normalize_device(gpu.strip()) + if device is not None: + devices.append(device) + if not devices: + raise ValueError("No valid Chai devices were requested.") + return devices + + +def run_chai_fast( + input_json: Union[str, Path], + output_dir: Union[str, Path], + config: dict, + save_input: bool = False, + test: bool = False, + number_of_models: int = 5, + num_recycles: int = 10, + use_templates_server: bool = False, + template_hits_path: Path | None = None, + device: str | None = None, +) -> bool: + return run_chai_batch( + [input_json], + output_dir, + config=config, + save_input=save_input, + test=test, + number_of_models=number_of_models, + num_recycles=num_recycles, + use_templates_server=use_templates_server, + template_hits_paths=( + {Path(input_json): template_hits_path} + if template_hits_path is not None + else None + ), + devices=device, + nested_outputs=False, + postprocess=False, + ) + + +def run_chai_batch( + input_jsons: Iterable[Union[str, Path]], + output_dir: Union[str, Path], + config: dict, + save_input: bool = False, + test: bool = False, + number_of_models: int = 5, + num_recycles: int = 10, + use_templates_server: bool = False, + template_hits_paths: dict[Path, Path | None] | None = None, + devices: str | None = None, + nested_outputs: bool = True, + postprocess: bool = True, +) -> bool: + input_paths = [Path(path) for path in input_jsons] + if not input_paths: + logger.error("No Chai input JSON files were provided") + return False + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + device_slots = normalize_devices(devices) + + env = None + if not test: + logger.debug("Checking if Chai-1 is installed") + env = ensure_chai_env(config=config) + + with tempfile.TemporaryDirectory() as temp_dir_str: + temp_dir = Path(temp_dir_str) + tasks = _prepare_tasks( + input_paths=input_paths, + output_dir=output_dir, + config=config, + save_input=save_input, + use_templates_server=use_templates_server, + template_hits_paths=template_hits_paths, + nested_outputs=nested_outputs, + work_root=temp_dir, + ) + if test: + logger.info("Skipping Chai fast backend execution in test mode") + return True + + manifests = _write_device_manifests( + tasks=tasks, + devices=device_slots, + work_root=temp_dir, + number_of_models=number_of_models, + num_recycles=num_recycles, + ) + worker_script = _write_worker_script(temp_dir) + repo_root = Path(__file__).resolve().parents[2] + + def run_manifest(manifest_path: Path) -> bool: + assert env is not None + cmd = [ + "python", + str(worker_script), + "--manifest", + str(manifest_path), + "--repo-root", + str(repo_root), + ] + try: + env.run(cmd) + except subprocess.CalledProcessError as e: + _write_error_log(Path(output_dir), e) + return False + return True + + if len(manifests) > 1: + with ThreadPoolExecutor(max_workers=len(manifests)) as pool: + futures = [ + pool.submit(run_manifest, manifest) for manifest in manifests + ] + if not all(future.result() for future in as_completed(futures)): + return False + else: + if not run_manifest(manifests[0]): + return False + + if postprocess: + _postprocess_cases(tasks, config=config, save_input=save_input) + + logger.info("Chai fast run complete") + return True + + +def _prepare_tasks( + input_paths: list[Path], + output_dir: Path, + config: dict, + save_input: bool, + use_templates_server: bool, + template_hits_paths: dict[Path, Path | None] | None, + nested_outputs: bool, + work_root: Path, +) -> list[dict]: + tasks = [] + for input_json in input_paths: + with input_json.open("r") as f: + input_params = json.load(f) + case_id = input_params.get("name") or input_json.stem + case_output_dir = ( + output_dir / "outputs" / case_id if nested_outputs else output_dir + ) + case_output_dir.mkdir(parents=True, exist_ok=True) + working_dir = case_output_dir if save_input else work_root / case_id + working_dir.mkdir(parents=True, exist_ok=True) + + chai_fasta = ChaiFasta(working_dir, config=config) + chai_fasta.json_to_fasta(input_json) + template_hits_path = ( + template_hits_paths.get(input_json) if template_hits_paths else None + ) + + for seed in chai_fasta.seeds: + seed_output_dir = case_output_dir / f"chai_output_seed-{seed}" + seed_output_dir.mkdir(parents=True, exist_ok=True) + tasks.append( + { + "input_json": str(input_json), + "case_id": case_id, + "case_output_dir": str(case_output_dir), + "fasta": str(chai_fasta.fasta), + "msa_dir": str(chai_fasta.working_dir), + "constraints": str(chai_fasta.constraints), + "output_dir": str(seed_output_dir), + "seed": seed, + "use_templates_server": use_templates_server, + "template_hits_path": ( + str(template_hits_path) + if template_hits_path is not None + else None + ), + } + ) + return tasks + + +def _write_device_manifests( + tasks: list[dict], + devices: list[str | None], + work_root: Path, + number_of_models: int, + num_recycles: int, +) -> list[Path]: + per_device: list[list[dict]] = [[] for _ in devices] + for index, task in enumerate(tasks): + device_index = index % len(devices) + task = dict(task) + task["device"] = devices[device_index] + task["number_of_models"] = number_of_models + task["num_recycles"] = num_recycles + per_device[device_index].append(task) + + manifests = [] + for index, device_tasks in enumerate(per_device): + if not device_tasks: + continue + manifest_path = work_root / f"chai_fast_manifest_{index}.json" + manifest_path.write_text(json.dumps(device_tasks, indent=2)) + manifests.append(manifest_path) + return manifests + + +def _write_worker_script(work_root: Path) -> Path: + worker_script = work_root / "chai_fast_worker.py" + worker_script.write_text(textwrap.dedent(""" + import argparse + import json + import sys + from pathlib import Path + + parser = argparse.ArgumentParser() + parser.add_argument("--manifest", required=True, type=Path) + parser.add_argument("--repo-root", required=True, type=Path) + args = parser.parse_args() + + sys.path.insert(0, str(args.repo_root)) + + from abcfold.chai1.chai import run_inference_wrapper + + tasks = json.loads(args.manifest.read_text()) + for task in tasks: + output_dir = Path(task["output_dir"]) + output_dir.mkdir(parents=True, exist_ok=True) + constraint_path = Path(task["constraints"]) + if not constraint_path.exists(): + constraint_path = None + template_hits_path = task.get("template_hits_path") + if template_hits_path is not None: + template_hits_path = Path(template_hits_path) + run_inference_wrapper( + Path(task["fasta"]), + output_dir=output_dir, + msa_directory=Path(task["msa_dir"]), + constraint_path=constraint_path, + use_templates_server=task["use_templates_server"], + template_hits_path=template_hits_path, + num_trunk_recycles=task["num_recycles"], + num_diffn_timesteps=200, + num_diffn_samples=task["number_of_models"], + seed=task["seed"], + device=task["device"], + ) + """)) + return worker_script + + +def _postprocess_cases(tasks: list[dict], config: dict, save_input: bool) -> None: + from abcfold.output.chai import ChaiOutput + + seen_cases: dict[str, dict] = {} + for task in tasks: + seen_cases.setdefault(task["case_id"], task) + + for case_id, task in seen_cases.items(): + case_output_dir = Path(task["case_output_dir"]) + chai_output_dirs: list[Union[str, Path]] = list( + case_output_dir.glob("chai_output*") + ) + with Path(task["input_json"]).open("r") as f: + input_params = json.load(f) + ChaiOutput(chai_output_dirs, input_params, case_id, config, save_input) + + +def _write_error_log(output_dir: Path, error: subprocess.CalledProcessError) -> None: + stderr = error.stderr or "" + output_err_file = output_dir / "chai_fast_error.log" + output_err_file.write_text(stderr) + logger.error("Chai fast run failed. Error log is in %s", output_err_file) diff --git a/docs/chai_fast_runner.md b/docs/chai_fast_runner.md new file mode 100644 index 0000000..d7d0eba --- /dev/null +++ b/docs/chai_fast_runner.md @@ -0,0 +1,56 @@ +# Chai Fast Runner + +ABCFold runs Chai-1 through the fast runner by default when `--chai1` is set. +There is no separate `--chai_fast` flag; it was intentionally removed. + +The fast runner prepares the same Chai FASTA, MSA directory, and constraints as +the standard runner, then executes seed jobs through persistent worker +processes. This avoids repeated process setup overhead and lets ABCFold assign +jobs across the requested `--gpus` device slots. + +## Usage + +```bash +uv run abcfold input.json output_dir --chai1 --gpus 0,1 +``` + +Use `--gpus all` to let Chai choose CUDA, `--gpus cpu` for CPU execution, or a +comma-separated list such as `--gpus 0,1` to split seed jobs across devices. + +## Smoke Validation + +Check the installed CLI without starting inference: + +```bash +uv run abcfold --help +``` + +The help should include `--chai1` and should not include `--chai_fast`: + +```bash +uv run abcfold --help | grep chai +``` + +## Output Layout + +For a normal ABCFold run, the fast runner preserves the existing Chai output +layout: + +- `chai_output_seed-/` +- `pred.model_idx_.cif` +- `scores.model_idx_.npz` +- ABCFold post-processing and visualization outputs + +Batch helper calls can also write native nested output directories under +`outputs//chai_output_seed-/` for downstream export workflows. + +## Quality Settings + +The fast runner does not lower Chai quality settings. It forwards: + +- `--number_of_models` to Chai diffusion sample count. +- `--num_recycles` to Chai trunk recycle count. +- 200 diffusion timesteps, matching the existing ABCFold Chai runner. + +The speedup comes from orchestration and worker reuse, not from reducing model +quality. diff --git a/tests/test_argparse.py b/tests/test_argparse.py index 9f5cc7e..2fcb4f6 100644 --- a/tests/test_argparse.py +++ b/tests/test_argparse.py @@ -1,7 +1,7 @@ import argparse from abcfold.argparse_utils import (alphafold_argparse_util, - boltz_argparse_util, + boltz_argparse_util, chai_argparse_util, custom_template_argpase_util, main_argpase_util, mmseqs2_argparse_util) @@ -41,6 +41,13 @@ def test_boltz_crystal_mode_argparse_util(): assert args.boltz_num_workers == 2 +def test_chai_argparse_util(): + parser = argparse.ArgumentParser() + parser = chai_argparse_util(parser) + args = parser.parse_args(["--chai1"]) + assert args.chai1 + + def test_custom_template_argpase_util(): parser = argparse.ArgumentParser() parser = custom_template_argpase_util(parser) diff --git a/tests/test_run_chai.py b/tests/test_run_chai.py index 6f5d663..f77d37b 100644 --- a/tests/test_run_chai.py +++ b/tests/test_run_chai.py @@ -1,55 +1,29 @@ import os import tempfile +from pathlib import Path import pytest try: import chai_lab # noqa F401 - run_chai1 = True + chai_lab_installed = True except ImportError: - run_chai1 = False - - -@pytest.mark.skipif(not run_chai1, reason="chai_lab not installed") -def test_generate_chai_command(test_data): - from abcfold.chai1.run_chai1 import generate_chai_command - - input_fasta = "/road/to/nowhere.fasta" - msa_dir = "/road/to/nowhere" - with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as fp: - constraints = fp.name - output_dir = "/road/to/nowhere" - - cmd = generate_chai_command( - input_fasta=input_fasta, - msa_dir=msa_dir, - input_constraints=constraints, - output_dir=output_dir, - ) - - assert cmd[1].endswith("chai.py") - assert "fold" in cmd - assert input_fasta in cmd - assert msa_dir in cmd - assert constraints in cmd - assert output_dir in cmd - assert "--num-diffn-samples" in cmd - assert "5" in cmd + chai_lab_installed = False @pytest.mark.skipif( - os.getenv("CI") == "true" and not run_chai1, + os.getenv("CI") == "true" and not chai_lab_installed, reason="Skipping test in CI environment", ) def test_run_chai(test_data): pytest.importorskip("chai_lab") - from abcfold.chai1.run_chai1 import run_chai + from abcfold.chai1.run_chai1_fast import run_chai_fast with tempfile.TemporaryDirectory() as temp_dir: try: - run_chai( + run_chai_fast( test_data.test_inputA_json, temp_dir, save_input=True, @@ -59,3 +33,63 @@ def test_run_chai(test_data): except Exception as e: print(e) assert False + + +def test_normalize_chai_fast_devices(): + from abcfold.chai1.run_chai1_fast import normalize_devices + + assert normalize_devices(None) == [None] + assert normalize_devices("cpu") == ["cpu"] + assert normalize_devices("all") == ["cuda"] + assert normalize_devices("0,1") == ["cuda:0", "cuda:1"] + + +def test_run_chai_batch_test_mode_creates_native_layout(monkeypatch, test_data): + from abcfold.chai1 import run_chai1_fast + + def fail_if_called(*args, **kwargs): + raise AssertionError("test mode must not create or use the Chai env") + + monkeypatch.setattr(run_chai1_fast, "ensure_chai_env", fail_if_called) + + with tempfile.TemporaryDirectory() as temp_dir: + run_ok = run_chai1_fast.run_chai_batch( + [test_data.test_inputA_json, test_data.test_inputDNA_json], + temp_dir, + config=test_data.config_dict, + test=True, + devices="0,1", + ) + + assert run_ok + assert Path(temp_dir, "outputs", "2PV7", "chai_output_seed-1").is_dir() + assert Path(temp_dir, "outputs", "DNA_example", "chai_output_seed-1").is_dir() + + +def test_chai_fast_manifests_preserve_quality_settings(): + from abcfold.chai1.run_chai1_fast import (_write_device_manifests, + _write_worker_script) + + tasks = [ + {"seed": 1, "output_dir": "out-1"}, + {"seed": 2, "output_dir": "out-2"}, + ] + with tempfile.TemporaryDirectory() as temp_dir: + manifests = _write_device_manifests( + tasks, + devices=["cuda:0", "cuda:1"], + work_root=Path(temp_dir), + number_of_models=5, + num_recycles=10, + ) + + assert len(manifests) == 2 + manifest_text = "\n".join(path.read_text() for path in manifests) + assert '"device": "cuda:0"' in manifest_text + assert '"device": "cuda:1"' in manifest_text + assert '"number_of_models": 5' in manifest_text + assert '"num_recycles": 10' in manifest_text + assert ( + "num_diffn_timesteps=200" + in _write_worker_script(Path(temp_dir)).read_text() + )