Pushed afc75f2 → c2e8505: "fix: fix access to global config", "fix: allow lack of predicted_aligned_error head"
Author: @sokrypton I think this is ready for merging. It's still strictly opt-in (as Pallas with Triton is only available for Ampere-architecture GPUs and up). You could improve performance a bit more by tuning block sizes and the number of warps in an input-shape-dependent manner; similarly, the `subbatch_size` global config setting could default to a memory-usage heuristic that selects subbatch sizes automatically.
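A minimal sketch of what such a memory-usage heuristic might look like (not from this PR; the function name, the fp32 logits estimate, and the 2 GiB default budget are illustrative assumptions):

```python
# Hypothetical heuristic for choosing a subbatch size from a memory budget.
# Not part of this PR; names and constants are illustrative assumptions.
def pick_subbatch_size(seq_len: int, num_heads: int,
                       memory_budget_bytes: int = 2 * 1024**3) -> int:
    """Largest power-of-two subbatch whose attention logits fit the budget."""
    # One query row of attention logits costs num_heads * seq_len * 4 bytes (fp32).
    bytes_per_query_row = num_heads * seq_len * 4
    size = 1
    while (size * 2 <= seq_len
           and size * 2 * bytes_per_query_row <= memory_budget_bytes):
        size *= 2
    return size

# e.g. pick_subbatch_size(759, 8) -> 512 under the default 2 GiB budget
```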
curtisdow1973-sys approved these changes on Oct 3, 2025.
Implements FlashAttention similarly to google-deepmind/alphafold#931
For a 759-residue protein with model_5, this improves runtime 2.2x on an L4 (37.3 $\rightarrow$ 16.9 seconds, with minibatching of 256 for non-flash attention to avoid OOM).
Here's a colab link showing the runtime improvement and no significant change in prediction output by visual inspection. I didn't want to rerun all the input prep, so I've used a colab with the AlphaFold input preparation and applied the fixes for ColabDesign.
Notes
Key variations from a reference FlashAttention kernel are:
- There are guards against the kernel being called for sequence lengths shorter than the block sizes specified for q and k; in that case it exits to the reference kernel (see the sketch below).
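As a rough illustration of that guard, here is a minimal dispatch sketch; the function names (`attention`, `reference_attention`), the default block sizes, and the `flash_kernel` hook are assumptions, not this PR's exact code:

```python
# Illustrative sketch of the short-sequence guard; names, default block sizes,
# and the flash_kernel hook are assumptions rather than the PR's exact code.
import jax
import jax.numpy as jnp

def reference_attention(q, k, v):
    # Plain softmax attention: (..., q_len, d), (..., k_len, d) -> (..., q_len, d).
    logits = jnp.einsum('...qd,...kd->...qk', q, k) / jnp.sqrt(q.shape[-1])
    return jnp.einsum('...qk,...kd->...qd', jax.nn.softmax(logits, axis=-1), v)

def attention(q, k, v, *, block_q: int = 128, block_k: int = 64, flash_kernel=None):
    # Guard: if either sequence is shorter than its block size (or no Pallas
    # kernel is available on this GPU), fall back to the reference path.
    if flash_kernel is None or q.shape[-2] < block_q or k.shape[-2] < block_k:
        return reference_attention(q, k, v)
    return flash_kernel(q, k, v, block_q=block_q, block_k=block_k)
```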
Comments
With `use_flash_attention=False` I haven't changed behaviour: here's a colab showing the same 37.3s runtime as the main branch.
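For reference, a hypothetical sketch of opting in; the config attribute path below mirrors how AlphaFold-style global config options such as `subbatch_size` are usually set and is not taken from this PR's code:

```python
# Hypothetical usage sketch: the _cfg attribute path is an assumption,
# not verified against this PR.
from colabdesign import mk_afdesign_model

model = mk_afdesign_model(protocol="fixbb")
# Opt in to the Pallas/Triton FlashAttention path (Ampere+ GPUs only);
# leaving it False keeps the existing attention implementation.
model._cfg.model.global_config.use_flash_attention = True  # assumed path
```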