From 30057be31fd8730758e438a1f73b71312d08f72f Mon Sep 17 00:00:00 2001 From: "Nikolenko.Sergei" Date: Sat, 25 Apr 2026 02:47:44 +0300 Subject: [PATCH 1/4] Add Chai fast runner --- abcfold/abcfold.py | 6 +- abcfold/argparse_utils.py | 5 + abcfold/chai1/run_chai1_fast.py | 299 ++++++++++++++++++++++++++++++++ tests/test_argparse.py | 10 +- tests/test_run_chai.py | 61 +++++++ 5 files changed, 379 insertions(+), 2 deletions(-) create mode 100644 abcfold/chai1/run_chai1_fast.py diff --git a/abcfold/abcfold.py b/abcfold/abcfold.py index 8a06eb4..3eef438 100644 --- a/abcfold/abcfold.py +++ b/abcfold/abcfold.py @@ -209,7 +209,11 @@ def run(args, config, defaults, config_file): successful_runs.append(boltz_success) if args.chai1: - from abcfold.chai1.run_chai1 import run_chai + if args.chai_fast: + from abcfold.chai1.run_chai1_fast import \ + run_chai_fast as run_chai + else: + from abcfold.chai1.run_chai1 import run_chai template_hits_path = None if args.templates and args.mmseqs2: diff --git a/abcfold/argparse_utils.py b/abcfold/argparse_utils.py index f4306bd..fa1ab17 100644 --- a/abcfold/argparse_utils.py +++ b/abcfold/argparse_utils.py @@ -238,6 +238,11 @@ def chai_argparse_util(parser): action="store_true", help="Run Chai-1", ) + parser.add_argument( + "--chai_fast", + action="store_true", + help=dedent("[optional] Run Chai-1 through the persistent fast worker."), + ) return parser diff --git a/abcfold/chai1/run_chai1_fast.py b/abcfold/chai1/run_chai1_fast.py new file mode 100644 index 0000000..19ddd5f --- /dev/null +++ b/abcfold/chai1/run_chai1_fast.py @@ -0,0 +1,299 @@ +import json +import logging +import subprocess +import tempfile +import textwrap +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from typing import Iterable, Union + +from abcfold.chai1.af3_to_chai import ChaiFasta +from abcfold.chai1.check_install import ensure_chai_env +from abcfold.chai1.run_chai1 import normalize_device + +logger = logging.getLogger("logger") + + +def normalize_devices(gpus: str | None) -> list[str | None]: + if gpus is None: + return [None] + if gpus == "cpu": + return ["cpu"] + if gpus == "all": + return ["cuda"] + + devices: list[str | None] = [] + for gpu in gpus.split(","): + device = normalize_device(gpu.strip()) + if device is not None: + devices.append(device) + if not devices: + raise ValueError("No valid Chai devices were requested.") + return devices + + +def run_chai_fast( + input_json: Union[str, Path], + output_dir: Union[str, Path], + config: dict, + save_input: bool = False, + test: bool = False, + number_of_models: int = 5, + num_recycles: int = 10, + use_templates_server: bool = False, + template_hits_path: Path | None = None, + device: str | None = None, +) -> bool: + return run_chai_batch( + [input_json], + output_dir, + config=config, + save_input=save_input, + test=test, + number_of_models=number_of_models, + num_recycles=num_recycles, + use_templates_server=use_templates_server, + template_hits_paths=( + {Path(input_json): template_hits_path} + if template_hits_path is not None + else None + ), + devices=device, + nested_outputs=False, + postprocess=False, + ) + + +def run_chai_batch( + input_jsons: Iterable[Union[str, Path]], + output_dir: Union[str, Path], + config: dict, + save_input: bool = False, + test: bool = False, + number_of_models: int = 5, + num_recycles: int = 10, + use_templates_server: bool = False, + template_hits_paths: dict[Path, Path | None] | None = None, + devices: str | None = None, + nested_outputs: bool = True, + postprocess: bool = True, +) -> bool: + input_paths = [Path(path) for path in input_jsons] + if not input_paths: + logger.error("No Chai input JSON files were provided") + return False + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + device_slots = normalize_devices(devices) + + env = None + if not test: + logger.debug("Checking if Chai-1 is installed") + env = ensure_chai_env(config=config) + + with tempfile.TemporaryDirectory() as temp_dir_str: + temp_dir = Path(temp_dir_str) + tasks = _prepare_tasks( + input_paths=input_paths, + output_dir=output_dir, + config=config, + save_input=save_input, + use_templates_server=use_templates_server, + template_hits_paths=template_hits_paths, + nested_outputs=nested_outputs, + work_root=temp_dir, + ) + if test: + logger.info("Skipping Chai fast backend execution in test mode") + return True + + manifests = _write_device_manifests( + tasks=tasks, + devices=device_slots, + work_root=temp_dir, + number_of_models=number_of_models, + num_recycles=num_recycles, + ) + worker_script = _write_worker_script(temp_dir) + repo_root = Path(__file__).resolve().parents[2] + + def run_manifest(manifest_path: Path) -> bool: + assert env is not None + cmd = [ + "python", + str(worker_script), + "--manifest", + str(manifest_path), + "--repo-root", + str(repo_root), + ] + try: + env.run(cmd) + except subprocess.CalledProcessError as e: + _write_error_log(Path(output_dir), e) + return False + return True + + if len(manifests) > 1: + with ThreadPoolExecutor(max_workers=len(manifests)) as pool: + futures = [ + pool.submit(run_manifest, manifest) for manifest in manifests + ] + if not all(future.result() for future in as_completed(futures)): + return False + else: + if not run_manifest(manifests[0]): + return False + + if postprocess: + _postprocess_cases(tasks, config=config, save_input=save_input) + + logger.info("Chai fast run complete") + return True + + +def _prepare_tasks( + input_paths: list[Path], + output_dir: Path, + config: dict, + save_input: bool, + use_templates_server: bool, + template_hits_paths: dict[Path, Path | None] | None, + nested_outputs: bool, + work_root: Path, +) -> list[dict]: + tasks = [] + for input_json in input_paths: + with input_json.open("r") as f: + input_params = json.load(f) + case_id = input_params.get("name") or input_json.stem + case_output_dir = ( + output_dir / "outputs" / case_id if nested_outputs else output_dir + ) + case_output_dir.mkdir(parents=True, exist_ok=True) + working_dir = case_output_dir if save_input else work_root / case_id + working_dir.mkdir(parents=True, exist_ok=True) + + chai_fasta = ChaiFasta(working_dir, config=config) + chai_fasta.json_to_fasta(input_json) + template_hits_path = ( + template_hits_paths.get(input_json) if template_hits_paths else None + ) + + for seed in chai_fasta.seeds: + seed_output_dir = case_output_dir / f"chai_output_seed-{seed}" + seed_output_dir.mkdir(parents=True, exist_ok=True) + tasks.append( + { + "input_json": str(input_json), + "case_id": case_id, + "case_output_dir": str(case_output_dir), + "fasta": str(chai_fasta.fasta), + "msa_dir": str(chai_fasta.working_dir), + "constraints": str(chai_fasta.constraints), + "output_dir": str(seed_output_dir), + "seed": seed, + "use_templates_server": use_templates_server, + "template_hits_path": ( + str(template_hits_path) + if template_hits_path is not None + else None + ), + } + ) + return tasks + + +def _write_device_manifests( + tasks: list[dict], + devices: list[str | None], + work_root: Path, + number_of_models: int, + num_recycles: int, +) -> list[Path]: + per_device: list[list[dict]] = [[] for _ in devices] + for index, task in enumerate(tasks): + device_index = index % len(devices) + task = dict(task) + task["device"] = devices[device_index] + task["number_of_models"] = number_of_models + task["num_recycles"] = num_recycles + per_device[device_index].append(task) + + manifests = [] + for index, device_tasks in enumerate(per_device): + if not device_tasks: + continue + manifest_path = work_root / f"chai_fast_manifest_{index}.json" + manifest_path.write_text(json.dumps(device_tasks, indent=2)) + manifests.append(manifest_path) + return manifests + + +def _write_worker_script(work_root: Path) -> Path: + worker_script = work_root / "chai_fast_worker.py" + worker_script.write_text(textwrap.dedent(""" + import argparse + import json + import sys + from pathlib import Path + + parser = argparse.ArgumentParser() + parser.add_argument("--manifest", required=True, type=Path) + parser.add_argument("--repo-root", required=True, type=Path) + args = parser.parse_args() + + sys.path.insert(0, str(args.repo_root)) + + from abcfold.chai1.chai import run_inference_wrapper + + tasks = json.loads(args.manifest.read_text()) + for task in tasks: + output_dir = Path(task["output_dir"]) + output_dir.mkdir(parents=True, exist_ok=True) + constraint_path = Path(task["constraints"]) + if not constraint_path.exists(): + constraint_path = None + template_hits_path = task.get("template_hits_path") + if template_hits_path is not None: + template_hits_path = Path(template_hits_path) + run_inference_wrapper( + Path(task["fasta"]), + output_dir=output_dir, + msa_directory=Path(task["msa_dir"]), + constraint_path=constraint_path, + use_templates_server=task["use_templates_server"], + template_hits_path=template_hits_path, + num_trunk_recycles=task["num_recycles"], + num_diffn_timesteps=200, + num_diffn_samples=task["number_of_models"], + seed=task["seed"], + device=task["device"], + ) + """)) + return worker_script + + +def _postprocess_cases(tasks: list[dict], config: dict, save_input: bool) -> None: + from abcfold.output.chai import ChaiOutput + + seen_cases: dict[str, dict] = {} + for task in tasks: + seen_cases.setdefault(task["case_id"], task) + + for case_id, task in seen_cases.items(): + case_output_dir = Path(task["case_output_dir"]) + chai_output_dirs: list[Union[str, Path]] = list( + case_output_dir.glob("chai_output*") + ) + with Path(task["input_json"]).open("r") as f: + input_params = json.load(f) + ChaiOutput(chai_output_dirs, input_params, case_id, config, save_input) + + +def _write_error_log(output_dir: Path, error: subprocess.CalledProcessError) -> None: + stderr = error.stderr or "" + output_err_file = output_dir / "chai_fast_error.log" + output_err_file.write_text(stderr) + logger.error("Chai fast run failed. Error log is in %s", output_err_file) diff --git a/tests/test_argparse.py b/tests/test_argparse.py index 9f5cc7e..86b7eaf 100644 --- a/tests/test_argparse.py +++ b/tests/test_argparse.py @@ -1,7 +1,7 @@ import argparse from abcfold.argparse_utils import (alphafold_argparse_util, - boltz_argparse_util, + boltz_argparse_util, chai_argparse_util, custom_template_argpase_util, main_argpase_util, mmseqs2_argparse_util) @@ -41,6 +41,14 @@ def test_boltz_crystal_mode_argparse_util(): assert args.boltz_num_workers == 2 +def test_chai_fast_argparse_util(): + parser = argparse.ArgumentParser() + parser = chai_argparse_util(parser) + args = parser.parse_args(["--chai1", "--chai_fast"]) + assert args.chai1 + assert args.chai_fast + + def test_custom_template_argpase_util(): parser = argparse.ArgumentParser() parser = custom_template_argpase_util(parser) diff --git a/tests/test_run_chai.py b/tests/test_run_chai.py index 6f5d663..cc14b41 100644 --- a/tests/test_run_chai.py +++ b/tests/test_run_chai.py @@ -1,5 +1,6 @@ import os import tempfile +from pathlib import Path import pytest @@ -59,3 +60,63 @@ def test_run_chai(test_data): except Exception as e: print(e) assert False + + +def test_normalize_chai_fast_devices(): + from abcfold.chai1.run_chai1_fast import normalize_devices + + assert normalize_devices(None) == [None] + assert normalize_devices("cpu") == ["cpu"] + assert normalize_devices("all") == ["cuda"] + assert normalize_devices("0,1") == ["cuda:0", "cuda:1"] + + +def test_run_chai_batch_test_mode_creates_native_layout(monkeypatch, test_data): + from abcfold.chai1 import run_chai1_fast + + def fail_if_called(*args, **kwargs): + raise AssertionError("test mode must not create or use the Chai env") + + monkeypatch.setattr(run_chai1_fast, "ensure_chai_env", fail_if_called) + + with tempfile.TemporaryDirectory() as temp_dir: + run_ok = run_chai1_fast.run_chai_batch( + [test_data.test_inputA_json, test_data.test_inputDNA_json], + temp_dir, + config=test_data.config_dict, + test=True, + devices="0,1", + ) + + assert run_ok + assert Path(temp_dir, "outputs", "2PV7", "chai_output_seed-1").is_dir() + assert Path(temp_dir, "outputs", "DNA_example", "chai_output_seed-1").is_dir() + + +def test_chai_fast_manifests_preserve_quality_settings(): + from abcfold.chai1.run_chai1_fast import (_write_device_manifests, + _write_worker_script) + + tasks = [ + {"seed": 1, "output_dir": "out-1"}, + {"seed": 2, "output_dir": "out-2"}, + ] + with tempfile.TemporaryDirectory() as temp_dir: + manifests = _write_device_manifests( + tasks, + devices=["cuda:0", "cuda:1"], + work_root=Path(temp_dir), + number_of_models=5, + num_recycles=10, + ) + + assert len(manifests) == 2 + manifest_text = "\n".join(path.read_text() for path in manifests) + assert '"device": "cuda:0"' in manifest_text + assert '"device": "cuda:1"' in manifest_text + assert '"number_of_models": 5' in manifest_text + assert '"num_recycles": 10' in manifest_text + assert ( + "num_diffn_timesteps=200" + in _write_worker_script(Path(temp_dir)).read_text() + ) From e907ba0efbbb81ac82b6457017400ffeb9327cb4 Mon Sep 17 00:00:00 2001 From: "Nikolenko.Sergei" Date: Tue, 28 Apr 2026 12:13:54 +0300 Subject: [PATCH 2/4] Make Chai fast runner the default --- README.md | 7 +++++++ abcfold/abcfold.py | 6 +----- abcfold/argparse_utils.py | 5 ----- docs/chai_fast_runner.md | 42 +++++++++++++++++++++++++++++++++++++++ tests/test_argparse.py | 5 ++--- 5 files changed, 52 insertions(+), 13 deletions(-) create mode 100644 docs/chai_fast_runner.md diff --git a/README.md b/README.md index 2a03fd1..9620d46 100644 --- a/README.md +++ b/README.md @@ -136,6 +136,13 @@ However, there you may wish to use the following flags to add run time options s #### OpenFold3 arguments - `--inference_ckpt_path` [optional] Path for model checkpoint to be used for inference. If not specified, will attempt to find or download parameters in ~/.openfold3/ +#### Chai-1 fast runner + +`--chai1` uses ABCFold's persistent Chai worker by default. The worker keeps +the same Chai quality settings as the standard runner while reducing repeated +process setup overhead and distributing seed jobs across the requested +`--gpus` slots. See [Chai Fast Runner](docs/chai_fast_runner.md) for details. + #### Template arguments - `--templates`: Flag to enable a template search diff --git a/abcfold/abcfold.py b/abcfold/abcfold.py index 3eef438..77e5794 100644 --- a/abcfold/abcfold.py +++ b/abcfold/abcfold.py @@ -209,11 +209,7 @@ def run(args, config, defaults, config_file): successful_runs.append(boltz_success) if args.chai1: - if args.chai_fast: - from abcfold.chai1.run_chai1_fast import \ - run_chai_fast as run_chai - else: - from abcfold.chai1.run_chai1 import run_chai + from abcfold.chai1.run_chai1_fast import run_chai_fast as run_chai template_hits_path = None if args.templates and args.mmseqs2: diff --git a/abcfold/argparse_utils.py b/abcfold/argparse_utils.py index fa1ab17..f4306bd 100644 --- a/abcfold/argparse_utils.py +++ b/abcfold/argparse_utils.py @@ -238,11 +238,6 @@ def chai_argparse_util(parser): action="store_true", help="Run Chai-1", ) - parser.add_argument( - "--chai_fast", - action="store_true", - help=dedent("[optional] Run Chai-1 through the persistent fast worker."), - ) return parser diff --git a/docs/chai_fast_runner.md b/docs/chai_fast_runner.md new file mode 100644 index 0000000..1e18fa2 --- /dev/null +++ b/docs/chai_fast_runner.md @@ -0,0 +1,42 @@ +# Chai Fast Runner + +ABCFold runs Chai-1 through the fast runner by default when `--chai1` is set. +There is no separate `--chai_fast` flag. + +The fast runner prepares the same Chai FASTA, MSA directory, and constraints as +the standard runner, then executes seed jobs through persistent worker +processes. This avoids repeated process setup overhead and lets ABCFold assign +jobs across the requested `--gpus` device slots. + +## Usage + +```bash +uv run abcfold input.json output_dir --chai1 --gpus 0,1 +``` + +Use `--gpus all` to let Chai choose CUDA, `--gpus cpu` for CPU execution, or a +comma-separated list such as `--gpus 0,1` to split seed jobs across devices. + +## Output Layout + +For a normal ABCFold run, the fast runner preserves the existing Chai output +layout: + +- `chai_output_seed-/` +- `pred.model_idx_.cif` +- `scores.model_idx_.npz` +- ABCFold post-processing and visualization outputs + +Batch helper calls can also write native nested output directories under +`outputs//chai_output_seed-/` for downstream export workflows. + +## Quality Settings + +The fast runner does not lower Chai quality settings. It forwards: + +- `--number_of_models` to Chai diffusion sample count. +- `--num_recycles` to Chai trunk recycle count. +- 200 diffusion timesteps, matching the existing ABCFold Chai runner. + +The speedup comes from orchestration and worker reuse, not from reducing model +quality. diff --git a/tests/test_argparse.py b/tests/test_argparse.py index 86b7eaf..2fcb4f6 100644 --- a/tests/test_argparse.py +++ b/tests/test_argparse.py @@ -41,12 +41,11 @@ def test_boltz_crystal_mode_argparse_util(): assert args.boltz_num_workers == 2 -def test_chai_fast_argparse_util(): +def test_chai_argparse_util(): parser = argparse.ArgumentParser() parser = chai_argparse_util(parser) - args = parser.parse_args(["--chai1", "--chai_fast"]) + args = parser.parse_args(["--chai1"]) assert args.chai1 - assert args.chai_fast def test_custom_template_argpase_util(): From 63ff74728dbcb800c1778b511b74e48231060748 Mon Sep 17 00:00:00 2001 From: "Nikolenko.Sergei" Date: Tue, 28 Apr 2026 12:48:47 +0300 Subject: [PATCH 3/4] Remove legacy Chai runner --- abcfold/chai1/run_chai1.py | 219 -------------------------------- abcfold/chai1/run_chai1_fast.py | 15 ++- tests/test_run_chai.py | 37 +----- 3 files changed, 19 insertions(+), 252 deletions(-) delete mode 100644 abcfold/chai1/run_chai1.py diff --git a/abcfold/chai1/run_chai1.py b/abcfold/chai1/run_chai1.py deleted file mode 100644 index 5217b08..0000000 --- a/abcfold/chai1/run_chai1.py +++ /dev/null @@ -1,219 +0,0 @@ -import logging -import os -import shutil -import subprocess -import tempfile -from pathlib import Path -from typing import Union - -from abcfold.chai1.af3_to_chai import ChaiFasta -from abcfold.chai1.check_install import ensure_chai_env - -logger = logging.getLogger("logger") -os.environ["DISABLE_PANDERA_IMPORT_WARNING"] = "True" - - -def normalize_device(gpus: str | None) -> str | None: - if gpus is None: - return None - if gpus == "cpu": - return "cpu" - if gpus == "all": - return "cuda" - - # Validate and normalize the GPU list. - gpu_ids = [] - for gpu in gpus.split(","): - gpu = gpu.strip() - if not gpu.isdigit(): - raise ValueError(f"Invalid GPU ID: {gpu}") - gpu_ids.append(gpu) - - # Chai accepts a single device, so use the first requested GPU. - return f"cuda:{gpu_ids[0]}" - - -def run_chai( - input_json: Union[str, Path], - output_dir: Union[str, Path], - config: dict, - save_input: bool = False, - test: bool = False, - number_of_models: int = 5, - num_recycles: int = 10, - use_templates_server: bool = False, - template_hits_path: Path | None = None, - device: str | None = None, -) -> bool: - """ - Run Chai-1 using the input JSON file - - Args: - input_json (Union[str, Path]): Path to the input JSON file - output_dir (Union[str, Path]): Path to the output directory - config (dict): Configuration dictionary - save_input (bool): If True, save the input fasta file and MSA to the output - directory - test (bool): If True, run the test command - number_of_models (int): Number of models to generate - num_recycles (int): Number of trunk recycles - use_templates_server (bool): If True, use templates from the server - template_hits_path (Path): Path to the template hits m8 file - device (str | None): If specified, use the specified GPU - - Returns: - Bool: True if the Chai-1 run was successful, False otherwise - - """ - input_json = Path(input_json) - output_dir = Path(output_dir) - - logger.debug("Checking if Chai-1 is installed") - env = ensure_chai_env(config=config) - - with tempfile.TemporaryDirectory() as temp_dir: - working_dir = Path(temp_dir) - if save_input: - logger.info("Saving input fasta file and msa to the output directory") - working_dir = output_dir - working_dir.mkdir(parents=True, exist_ok=True) - - chai_fasta = ChaiFasta(working_dir, config=config) - chai_fasta.json_to_fasta(input_json) - - out_fasta = chai_fasta.fasta - msa_dir = chai_fasta.working_dir - out_constraints = chai_fasta.constraints - - normalized_device = normalize_device(device) - for seed in chai_fasta.seeds: - chai_output_dir = output_dir / f"chai_output_seed-{seed}" - - logger.info(f"Running Chai-1 using seed {seed}") - cmd = ( - generate_chai_command( - out_fasta, - msa_dir, - out_constraints, - chai_output_dir, - number_of_models, - num_recycles=num_recycles, - seed=seed, - use_templates_server=use_templates_server, - template_hits_path=template_hits_path, - device=normalized_device, - ) - if not test - else generate_chai_test_command() - ) - - try: - env.run(cmd) - except subprocess.CalledProcessError as e: - stderr = e.stderr or "" - if stderr: - if chai_output_dir.exists(): - output_err_file = chai_output_dir / "chai_error.log" - else: - output_err_file = chai_output_dir.parent / "chai_error.log" - output_err_file.write_text(stderr) - logger.error( - "Chai-1 run failed. Error log is in %s", output_err_file - ) - else: - logger.error("Chai-1 run failed") - return False - - logger.info("Chai-1 run complete") - return True - - -def generate_chai_command( - input_fasta: Union[str, Path], - msa_dir: Union[str, Path], - input_constraints: Union[str, Path], - output_dir: Union[str, Path], - number_of_models: int = 5, - num_recycles: int = 10, - seed: int = 42, - use_templates_server: bool = False, - template_hits_path: Path | None = None, - device: str | None = None, -) -> list: - """ - Generate the Chai-1 command - - Args: - input_fasta (Union[str, Path]): Path to the input fasta file - msa_dir (Union[str, Path]): Path to the MSA directory - input_constraints (Union[str, Path]): Path to the input constraints file - output_dir (Union[str, Path]): Path to the output directory - number_of_models (int): Number of models to generate - num_recycles (int): Number of trunk recycles - seed (int): Seed for the random number generator - use_templates_server (bool): If True, use templates from the server - template_hits_path (Path): Path to the template hits m8 file - device (str | None): If specified, use the specified GPU - - Returns: - list: The Chai-1 command - - """ - - chai_exe = Path(__file__).parent / "chai.py" - cmd = [ - "python", - str(chai_exe), - "fold", - str(input_fasta) - ] - - if Path(msa_dir).exists(): - cmd += ["--msa-directory", str(msa_dir)] - if Path(input_constraints).exists(): - cmd += ["--constraint-path", str(input_constraints)] - - cmd += ["--num-diffn-samples", str(number_of_models)] - # Do not lower this without full validation: 5 diffusion steps produced physically invalid structures. - cmd += ["--num-diffn-timesteps", "200"] - cmd += ["--num-trunk-recycles", str(num_recycles)] - cmd += ["--seed", str(seed)] - - assert not ( - use_templates_server and template_hits_path - ), "Cannot specify both templates server and path" - - if shutil.which("kalign") is None and (use_templates_server or template_hits_path): - logger.warning( - "kalign not found, skipping template search kalign is required. \ -Please install kalign to use templates with Chai-1." - ) - else: - if use_templates_server: - cmd += ["--use-templates-server"] - if template_hits_path: - cmd += ["--template-hits-path", str(template_hits_path)] - - if device is not None and device != "all": - cmd += ["--device", device] - - cmd += [str(output_dir)] - - return cmd - - -def generate_chai_test_command() -> list: - """ - Generate the Chai-1 test command - - Args: - None - - Returns: - list: The Chai-1 test command - """ - return [ - "chai-lab", - "fold", - "--help", - ] diff --git a/abcfold/chai1/run_chai1_fast.py b/abcfold/chai1/run_chai1_fast.py index 19ddd5f..6598207 100644 --- a/abcfold/chai1/run_chai1_fast.py +++ b/abcfold/chai1/run_chai1_fast.py @@ -9,11 +9,24 @@ from abcfold.chai1.af3_to_chai import ChaiFasta from abcfold.chai1.check_install import ensure_chai_env -from abcfold.chai1.run_chai1 import normalize_device logger = logging.getLogger("logger") +def normalize_device(device: str | None) -> str | None: + if device is None: + return None + if device == "cpu": + return "cpu" + if device == "all": + return "cuda" + if "," in device: + device = device.split(",")[0] + if device.isdigit(): + return f"cuda:{device}" + return device + + def normalize_devices(gpus: str | None) -> list[str | None]: if gpus is None: return [None] diff --git a/tests/test_run_chai.py b/tests/test_run_chai.py index cc14b41..f77d37b 100644 --- a/tests/test_run_chai.py +++ b/tests/test_run_chai.py @@ -7,50 +7,23 @@ try: import chai_lab # noqa F401 - run_chai1 = True + chai_lab_installed = True except ImportError: - run_chai1 = False - - -@pytest.mark.skipif(not run_chai1, reason="chai_lab not installed") -def test_generate_chai_command(test_data): - from abcfold.chai1.run_chai1 import generate_chai_command - - input_fasta = "/road/to/nowhere.fasta" - msa_dir = "/road/to/nowhere" - with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as fp: - constraints = fp.name - output_dir = "/road/to/nowhere" - - cmd = generate_chai_command( - input_fasta=input_fasta, - msa_dir=msa_dir, - input_constraints=constraints, - output_dir=output_dir, - ) - - assert cmd[1].endswith("chai.py") - assert "fold" in cmd - assert input_fasta in cmd - assert msa_dir in cmd - assert constraints in cmd - assert output_dir in cmd - assert "--num-diffn-samples" in cmd - assert "5" in cmd + chai_lab_installed = False @pytest.mark.skipif( - os.getenv("CI") == "true" and not run_chai1, + os.getenv("CI") == "true" and not chai_lab_installed, reason="Skipping test in CI environment", ) def test_run_chai(test_data): pytest.importorskip("chai_lab") - from abcfold.chai1.run_chai1 import run_chai + from abcfold.chai1.run_chai1_fast import run_chai_fast with tempfile.TemporaryDirectory() as temp_dir: try: - run_chai( + run_chai_fast( test_data.test_inputA_json, temp_dir, save_input=True, From 1fd7958122cc3dfd60d011f677ba409969452c32 Mon Sep 17 00:00:00 2001 From: "Nikolenko.Sergei" Date: Wed, 29 Apr 2026 20:25:38 +0300 Subject: [PATCH 4/4] Document Chai fast smoke checks --- README.md | 3 ++- docs/chai_fast_runner.md | 16 +++++++++++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 9620d46..32340b0 100644 --- a/README.md +++ b/README.md @@ -141,7 +141,8 @@ However, there you may wish to use the following flags to add run time options s `--chai1` uses ABCFold's persistent Chai worker by default. The worker keeps the same Chai quality settings as the standard runner while reducing repeated process setup overhead and distributing seed jobs across the requested -`--gpus` slots. See [Chai Fast Runner](docs/chai_fast_runner.md) for details. +`--gpus` slots. `--chai_fast` was intentionally removed; use `--chai1`. +See [Chai Fast Runner](docs/chai_fast_runner.md) for smoke checks and details. #### Template arguments diff --git a/docs/chai_fast_runner.md b/docs/chai_fast_runner.md index 1e18fa2..d7d0eba 100644 --- a/docs/chai_fast_runner.md +++ b/docs/chai_fast_runner.md @@ -1,7 +1,7 @@ # Chai Fast Runner ABCFold runs Chai-1 through the fast runner by default when `--chai1` is set. -There is no separate `--chai_fast` flag. +There is no separate `--chai_fast` flag; it was intentionally removed. The fast runner prepares the same Chai FASTA, MSA directory, and constraints as the standard runner, then executes seed jobs through persistent worker @@ -17,6 +17,20 @@ uv run abcfold input.json output_dir --chai1 --gpus 0,1 Use `--gpus all` to let Chai choose CUDA, `--gpus cpu` for CPU execution, or a comma-separated list such as `--gpus 0,1` to split seed jobs across devices. +## Smoke Validation + +Check the installed CLI without starting inference: + +```bash +uv run abcfold --help +``` + +The help should include `--chai1` and should not include `--chai_fast`: + +```bash +uv run abcfold --help | grep chai +``` + ## Output Layout For a normal ABCFold run, the fast runner preserves the existing Chai output