From 2d03ed389d42933e0b32f5b130acfa5671e9f2a5 Mon Sep 17 00:00:00 2001 From: Julia Varga Date: Mon, 26 Jan 2026 23:07:52 +0200 Subject: [PATCH 1/2] added early stopping by specified score --- run/run_af3.py | 95 +++++++++++++++++++++----------------------------- 1 file changed, 40 insertions(+), 55 deletions(-) diff --git a/run/run_af3.py b/run/run_af3.py index 722424a..a054017 100644 --- a/run/run_af3.py +++ b/run/run_af3.py @@ -237,16 +237,7 @@ def extract_inference_results_and_maybe_embeddings( @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) class ResultsForSeed: - """Stores the inference results (diffusion samples) for a single seed. - - Attributes: - seed: The seed used to generate the samples. - inference_results: The inference results, one per sample. - full_fold_input: The fold input that must also include the results of - running the data pipeline - MSA and templates. - embeddings: The final trunk single and pair embeddings, if requested. - """ - + """Stores the inference results (diffusion samples) for a single seed.""" seed: int inference_results: Sequence[model.InferenceResult] full_fold_input: folding_input.Input @@ -259,6 +250,8 @@ def predict_structure( buckets: Sequence[int] | None = None, ref_max_modified_date: datetime.date | None = None, conformer_max_iterations: int | None = None, + early_stop_metric: str | None = None, + early_stop_threshold: float | None = None, ) -> Sequence[ResultsForSeed]: """Runs the full inference pipeline to predict structures for each seed.""" @@ -283,6 +276,7 @@ def predict_structure( ) all_inference_start_time = time.time() all_inference_results = [] + for seed, example in zip(fold_input.rng_seeds, featurised_examples): print(f'Running model inference with seed {seed}...') inference_start_time = time.time() @@ -312,11 +306,26 @@ def predict_structure( embeddings=embeddings, ) ) + + # Check early stopping condition + if early_stop_metric is not None and early_stop_threshold is not None: + should_stop = any( + float(ir.metadata.get(early_stop_metric, 0)) >= early_stop_threshold + for ir in inference_results + ) + if should_stop: + print( + f'Early stopping: {early_stop_metric} >= {early_stop_threshold} ' + f'(seed {seed})' + ) + break + print( 'Running model inference and extracting output structures with' - f' {len(fold_input.rng_seeds)} seed(s) took' + f' {len(all_inference_results)} seed(s) took' f' {time.time() - all_inference_start_time:.2f} seconds.' ) + return all_inference_results @@ -336,8 +345,12 @@ def write_outputs( all_inference_results: Sequence[ResultsForSeed], output_dir: os.PathLike[str] | str, job_name: str, + ranking_metric: str | None = None, ) -> None: """Writes outputs to the specified output directory.""" + if ranking_metric is None: + ranking_metric = 'ranking_score' + ranking_scores = [] max_ranking_score = None max_ranking_result = None @@ -357,7 +370,7 @@ def write_outputs( output_dir=sample_dir, name=f'{job_name}_seed-{seed}_sample-{sample_idx}', ) - ranking_score = float(result.metadata['ranking_score']) + ranking_score = float(result.metadata.get(ranking_metric, 0)) ranking_scores.append((seed, sample_idx, ranking_score)) if max_ranking_score is None or ranking_score > max_ranking_score: max_ranking_score = ranking_score @@ -372,21 +385,18 @@ def write_outputs( name=f'{job_name}_seed-{seed}', ) - if max_ranking_result is not None: # True iff ranking_scores non-empty. + if max_ranking_result is not None: post_processing.write_output( inference_result=max_ranking_result, output_dir=output_dir, - # The output terms of use are the same for all seeds/samples. terms_of_use=output_terms, name=job_name, ) - # Save csv of ranking scores with seeds and sample indices, to allow easier - # comparison of ranking scores across different runs. with open( os.path.join(output_dir, f'{job_name}_ranking_scores.csv'), 'wt' ) as f: writer = csv.writer(f) - writer.writerow(['seed', 'sample', 'ranking_score']) + writer.writerow(['seed', 'sample', ranking_metric]) writer.writerows(ranking_scores) @@ -416,6 +426,8 @@ def process_fold_input( ref_max_modified_date: datetime.date | None = None, conformer_max_iterations: int | None = None, force_output_dir: bool = False, + early_stop_metric: str | None = None, + early_stop_threshold: float | None = None, ) -> folding_input.Input: ... @@ -430,6 +442,8 @@ def process_fold_input( ref_max_modified_date: datetime.date | None = None, conformer_max_iterations: int | None = None, force_output_dir: bool = False, + early_stop_metric: str | None = None, + early_stop_threshold: float | None = None, ) -> Sequence[ResultsForSeed]: ... @@ -443,38 +457,10 @@ def process_fold_input( ref_max_modified_date: datetime.date | None = None, conformer_max_iterations: int | None = None, force_output_dir: bool = False, + early_stop_metric: str | None = None, + early_stop_threshold: float | None = None, ) -> folding_input.Input | Sequence[ResultsForSeed]: - """Runs data pipeline and/or inference on a single fold input. - - Args: - fold_input: Fold input to process. - data_pipeline_config: Data pipeline config to use. If None, skip the data - pipeline. - model_runner: Model runner to use. If None, skip inference. - output_dir: Output directory to write to. - buckets: Bucket sizes to pad the data to, to avoid excessive re-compilation - of the model. If None, calculate the appropriate bucket size from the - number of tokens. If not None, must be a sequence of at least one integer, - in strictly increasing order. Will raise an error if the number of tokens - is more than the largest bucket size. - ref_max_modified_date: Optional maximum date that controls whether to allow - use of model coordinates for a chemical component from the CCD if RDKit - conformer generation fails and the component does not have ideal - coordinates set. Only for components that have been released before this - date the model coordinates can be used as a fallback. - conformer_max_iterations: Optional override for maximum number of iterations - to run for RDKit conformer search. - force_output_dir: If True, do not create a new output directory even if the - existing one is non-empty. Instead use the existing output directory and - potentially overwrite existing files. If False, create a new timestamped - output directory instead if the existing one is non-empty. - - Returns: - The processed fold input, or the inference results for each seed. - - Raises: - ValueError: If the fold input has no chains. - """ + """Runs data pipeline and/or inference on a single fold input.""" print(f'\nRunning fold job {fold_input.name}...') if not fold_input.chains: @@ -517,12 +503,15 @@ def process_fold_input( buckets=buckets, ref_max_modified_date=ref_max_modified_date, conformer_max_iterations=conformer_max_iterations, + early_stop_metric=early_stop_metric, + early_stop_threshold=early_stop_threshold, ) - print(f'Writing outputs with {len(fold_input.rng_seeds)} seed(s)...') + print(f'Writing outputs with {len(all_inference_results)} seed(s)...') write_outputs( all_inference_results=all_inference_results, output_dir=output_dir, job_name=fold_input.sanitised_name(), + ranking_metric=early_stop_metric, ) output = all_inference_results @@ -612,8 +601,6 @@ def main(args_dict: Dict[str, Any]) -> None: max_template_date = datetime.date.fromisoformat(args_dict["max_template_date"]) if args_dict["run_data_pipeline"]: - # We skip this (by setting run_data_pipeline=False) since we handle MSAs - # and templates differently. expand_path = lambda x: replace_db_dir(x, DB_DIR.value) data_pipeline_config = pipeline.DataPipelineConfig( jackhmmer_binary_path=_JACKHMMER_BINARY_PATH.value, @@ -658,7 +645,6 @@ def main(args_dict: Dict[str, Any]) -> None: device=devices[args_dict["gpu_device"]], model_dir=pathlib.Path(args_dict["model_dir"]), ) - # Check we can load the model parameters before launching anything. print('Checking that model parameters can be loaded...') _ = model_runner.model_params else: @@ -678,6 +664,8 @@ def main(args_dict: Dict[str, Any]) -> None: ref_max_modified_date=max_template_date, conformer_max_iterations=args_dict["conformer_max_iterations"], force_output_dir=args_dict["force_output_dir"], + early_stop_metric=args_dict["early_stop_metric"], + early_stop_threshold=args_dict["early_stop_threshold"], ) num_fold_inputs += 1 @@ -685,13 +673,10 @@ def main(args_dict: Dict[str, Any]) -> None: if __name__ == '__main__': - # Work around for a known XLA issue: - # https://github.com/google-deepmind/alphafold3/blob/main/docs/performance.md#compilation-time-workaround-with-xla-flags os.environ["XLA_FLAGS"] = "--xla_gpu_enable_triton_gemm=false" args_dict = get_af3_args() - # Add required flag for CUDA compute capability 7.x if args_dict["cuda_compute_7x"]: os.environ["XLA_FLAGS"] = "--xla_disable_hlo_passes=custom-kernel-fusion-rewriter" From dc124987a19841ac53953fe259e84b30a6fb4ca4 Mon Sep 17 00:00:00 2001 From: Julia Varga Date: Mon, 26 Jan 2026 23:08:13 +0200 Subject: [PATCH 2/2] added early stopping by specified score --- run/af3_utils.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/run/af3_utils.py b/run/af3_utils.py index 6c22398..dda6a7f 100644 --- a/run/af3_utils.py +++ b/run/af3_utils.py @@ -149,6 +149,22 @@ def get_af3_parser() -> FileArgumentParser: " AlphaFold 3 will use the seeds as provided in the input JSON." ) + # Early stopping arguments. + parser.add_argument( + "--early_stop_metric", + type=str, + default=None, + help="Metric to use for early stopping and ranking (e.g., 'actifptm', 'ranking_score'). " + "If set, this metric is also used for selecting the best result." + ) + parser.add_argument( + "--early_stop_threshold", + type=float, + default=None, + help="Stop processing seeds when any sample achieves score >= this threshold. " + "Requires --early_stop_metric to be set." + ) + # Control which stages to run. parser.add_argument( "--run_inference", @@ -276,6 +292,8 @@ def get_af3_args(arg_file: Optional[str] = None) -> Dict[str, Any]: if args.num_seeds is not None: if args.num_seeds < 1: raise ValueError("--num_seeds must be greater than or equal to 1.") + if args.early_stop_threshold is not None and args.early_stop_metric is None: + raise ValueError("--early_stop_threshold requires --early_stop_metric to be set.") return vars(args)