Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions run/af3_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,22 @@ def get_af3_parser() -> FileArgumentParser:
" AlphaFold 3 will use the seeds as provided in the input JSON."
)

# Early stopping arguments.
parser.add_argument(
"--early_stop_metric",
type=str,
default=None,
help="Metric to use for early stopping and ranking (e.g., 'actifptm', 'ranking_score'). "
"If set, this metric is also used for selecting the best result."
)
parser.add_argument(
"--early_stop_threshold",
type=float,
default=None,
help="Stop processing seeds when any sample achieves score >= this threshold. "
"Requires --early_stop_metric to be set."
)

# Control which stages to run.
parser.add_argument(
"--run_inference",
Expand Down Expand Up @@ -276,6 +292,8 @@ def get_af3_args(arg_file: Optional[str] = None) -> Dict[str, Any]:
if args.num_seeds is not None:
if args.num_seeds < 1:
raise ValueError("--num_seeds must be greater than or equal to 1.")
if args.early_stop_threshold is not None and args.early_stop_metric is None:
raise ValueError("--early_stop_threshold requires --early_stop_metric to be set.")

return vars(args)

Expand Down
95 changes: 40 additions & 55 deletions run/run_af3.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,16 +237,7 @@ def extract_inference_results_and_maybe_embeddings(

@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class ResultsForSeed:
"""Stores the inference results (diffusion samples) for a single seed.

Attributes:
seed: The seed used to generate the samples.
inference_results: The inference results, one per sample.
full_fold_input: The fold input that must also include the results of
running the data pipeline - MSA and templates.
embeddings: The final trunk single and pair embeddings, if requested.
"""

"""Stores the inference results (diffusion samples) for a single seed."""
seed: int
inference_results: Sequence[model.InferenceResult]
full_fold_input: folding_input.Input
Expand All @@ -259,6 +250,8 @@ def predict_structure(
buckets: Sequence[int] | None = None,
ref_max_modified_date: datetime.date | None = None,
conformer_max_iterations: int | None = None,
early_stop_metric: str | None = None,
early_stop_threshold: float | None = None,
) -> Sequence[ResultsForSeed]:
"""Runs the full inference pipeline to predict structures for each seed."""

Expand All @@ -283,6 +276,7 @@ def predict_structure(
)
all_inference_start_time = time.time()
all_inference_results = []

for seed, example in zip(fold_input.rng_seeds, featurised_examples):
print(f'Running model inference with seed {seed}...')
inference_start_time = time.time()
Expand Down Expand Up @@ -312,11 +306,26 @@ def predict_structure(
embeddings=embeddings,
)
)

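# Early stopping: stop processing further seeds as soon as any sample from
# this seed scores >= early_stop_threshold on early_stop_metric.
# NOTE(review): a metric absent from a sample's metadata falls back to 0 here,
# so a misspelled metric name is not reported as an error - confirm intended.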
if early_stop_metric is not None and early_stop_threshold is not None:
should_stop = any(
float(ir.metadata.get(early_stop_metric, 0)) >= early_stop_threshold
for ir in inference_results
)
if should_stop:
print(
f'Early stopping: {early_stop_metric} >= {early_stop_threshold} '
f'(seed {seed})'
)
break

print(
'Running model inference and extracting output structures with'
f' {len(fold_input.rng_seeds)} seed(s) took'
f' {len(all_inference_results)} seed(s) took'
f' {time.time() - all_inference_start_time:.2f} seconds.'
)

return all_inference_results


Expand All @@ -336,8 +345,12 @@ def write_outputs(
all_inference_results: Sequence[ResultsForSeed],
output_dir: os.PathLike[str] | str,
job_name: str,
ranking_metric: str | None = None,
) -> None:
"""Writes outputs to the specified output directory."""
if ranking_metric is None:
ranking_metric = 'ranking_score'

ranking_scores = []
max_ranking_score = None
max_ranking_result = None
Expand All @@ -357,7 +370,7 @@ def write_outputs(
output_dir=sample_dir,
name=f'{job_name}_seed-{seed}_sample-{sample_idx}',
)
ranking_score = float(result.metadata['ranking_score'])
ranking_score = float(result.metadata.get(ranking_metric, 0))
ranking_scores.append((seed, sample_idx, ranking_score))
if max_ranking_score is None or ranking_score > max_ranking_score:
max_ranking_score = ranking_score
Expand All @@ -372,21 +385,18 @@ def write_outputs(
name=f'{job_name}_seed-{seed}',
)

if max_ranking_result is not None: # True iff ranking_scores non-empty.
if max_ranking_result is not None:
post_processing.write_output(
inference_result=max_ranking_result,
output_dir=output_dir,
# The output terms of use are the same for all seeds/samples.
terms_of_use=output_terms,
name=job_name,
)
# Save csv of ranking scores with seeds and sample indices, to allow easier
# comparison of ranking scores across different runs.
with open(
os.path.join(output_dir, f'{job_name}_ranking_scores.csv'), 'wt'
) as f:
writer = csv.writer(f)
writer.writerow(['seed', 'sample', 'ranking_score'])
writer.writerow(['seed', 'sample', ranking_metric])
writer.writerows(ranking_scores)


Expand Down Expand Up @@ -416,6 +426,8 @@ def process_fold_input(
ref_max_modified_date: datetime.date | None = None,
conformer_max_iterations: int | None = None,
force_output_dir: bool = False,
early_stop_metric: str | None = None,
early_stop_threshold: float | None = None,
) -> folding_input.Input:
...

Expand All @@ -430,6 +442,8 @@ def process_fold_input(
ref_max_modified_date: datetime.date | None = None,
conformer_max_iterations: int | None = None,
force_output_dir: bool = False,
early_stop_metric: str | None = None,
early_stop_threshold: float | None = None,
) -> Sequence[ResultsForSeed]:
...

Expand All @@ -443,38 +457,10 @@ def process_fold_input(
ref_max_modified_date: datetime.date | None = None,
conformer_max_iterations: int | None = None,
force_output_dir: bool = False,
early_stop_metric: str | None = None,
early_stop_threshold: float | None = None,
) -> folding_input.Input | Sequence[ResultsForSeed]:
"""Runs data pipeline and/or inference on a single fold input.

Args:
fold_input: Fold input to process.
data_pipeline_config: Data pipeline config to use. If None, skip the data
pipeline.
model_runner: Model runner to use. If None, skip inference.
output_dir: Output directory to write to.
buckets: Bucket sizes to pad the data to, to avoid excessive re-compilation
of the model. If None, calculate the appropriate bucket size from the
number of tokens. If not None, must be a sequence of at least one integer,
in strictly increasing order. Will raise an error if the number of tokens
is more than the largest bucket size.
ref_max_modified_date: Optional maximum date that controls whether to allow
use of model coordinates for a chemical component from the CCD if RDKit
conformer generation fails and the component does not have ideal
coordinates set. Only for components that have been released before this
date the model coordinates can be used as a fallback.
conformer_max_iterations: Optional override for maximum number of iterations
to run for RDKit conformer search.
force_output_dir: If True, do not create a new output directory even if the
existing one is non-empty. Instead use the existing output directory and
potentially overwrite existing files. If False, create a new timestamped
output directory instead if the existing one is non-empty.

Returns:
The processed fold input, or the inference results for each seed.

Raises:
ValueError: If the fold input has no chains.
"""
"""Runs data pipeline and/or inference on a single fold input."""
print(f'\nRunning fold job {fold_input.name}...')

if not fold_input.chains:
Expand Down Expand Up @@ -517,12 +503,15 @@ def process_fold_input(
buckets=buckets,
ref_max_modified_date=ref_max_modified_date,
conformer_max_iterations=conformer_max_iterations,
early_stop_metric=early_stop_metric,
early_stop_threshold=early_stop_threshold,
)
print(f'Writing outputs with {len(fold_input.rng_seeds)} seed(s)...')
print(f'Writing outputs with {len(all_inference_results)} seed(s)...')
write_outputs(
all_inference_results=all_inference_results,
output_dir=output_dir,
job_name=fold_input.sanitised_name(),
ranking_metric=early_stop_metric,
)
output = all_inference_results

Expand Down Expand Up @@ -612,8 +601,6 @@ def main(args_dict: Dict[str, Any]) -> None:

max_template_date = datetime.date.fromisoformat(args_dict["max_template_date"])
if args_dict["run_data_pipeline"]:
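# This branch is normally skipped (by setting run_data_pipeline=False) since
# we handle MSAs and templates differently; the data pipeline config below is
# only built when the data pipeline is explicitly enabled.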
expand_path = lambda x: replace_db_dir(x, DB_DIR.value)
data_pipeline_config = pipeline.DataPipelineConfig(
jackhmmer_binary_path=_JACKHMMER_BINARY_PATH.value,
Expand Down Expand Up @@ -658,7 +645,6 @@ def main(args_dict: Dict[str, Any]) -> None:
device=devices[args_dict["gpu_device"]],
model_dir=pathlib.Path(args_dict["model_dir"]),
)
# Check we can load the model parameters before launching anything.
print('Checking that model parameters can be loaded...')
_ = model_runner.model_params
else:
Expand All @@ -678,20 +664,19 @@ def main(args_dict: Dict[str, Any]) -> None:
ref_max_modified_date=max_template_date,
conformer_max_iterations=args_dict["conformer_max_iterations"],
force_output_dir=args_dict["force_output_dir"],
early_stop_metric=args_dict["early_stop_metric"],
early_stop_threshold=args_dict["early_stop_threshold"],
)
num_fold_inputs += 1

print(f'Done running {num_fold_inputs} fold jobs.')


if __name__ == '__main__':
# Workaround for a known XLA issue:
# https://github.com/google-deepmind/alphafold3/blob/main/docs/performance.md#compilation-time-workaround-with-xla-flags
os.environ["XLA_FLAGS"] = "--xla_gpu_enable_triton_gemm=false"

args_dict = get_af3_args()

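# Set the flag required for CUDA compute capability 7.x.
# NOTE(review): the assignment below replaces the XLA_FLAGS value set above
# rather than appending to it, so --xla_gpu_enable_triton_gemm=false is
# dropped when cuda_compute_7x is enabled - confirm this is intended.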
if args_dict["cuda_compute_7x"]:
os.environ["XLA_FLAGS"] = "--xla_disable_hlo_passes=custom-kernel-fusion-rewriter"

Expand Down