Conversation
Summary of Changes
Hello @AlpinDale, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request lays the groundwork for supporting Prefix Language Modeling (PrefixLM) within the FlexAttention mechanism. It introduces the necessary masking logic and configuration parameters to enable models like moondream3 to utilize PrefixLM, which is a crucial step for their integration in a subsequent PR. The changes primarily involve adding a new mask type and integrating it into the attention metadata and builder processes.
Code Review
This pull request introduces support for PrefixLM attention masking within the FlexAttention backend, laying the groundwork for moondream3 model support. The changes are well-contained, adding a new masking function prefixlm_mask_mod, updating FlexAttentionMetadata with relevant fields, and integrating the new logic alongside the existing causal and bidirectional attention mechanisms. The implementation is sound, and I have one suggestion to simplify the new masking logic for better readability.
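To make the mask concrete, here is a minimal sketch that evaluates the predicate from this PR on plain Python integers and renders the resulting pattern. The helper `render_mask` and the sizes used are illustrative assumptions and are not part of the PR; the real `prefixlm_mask_mod` receives index tensors rather than integers.

```python
def prefixlm_mask_mod(b, h, q_idx, kv_idx, prefix_len):
    # Predicate as written in this PR; b and h are unused in this sketch.
    return ((q_idx < prefix_len)
            | ((q_idx >= prefix_len) & (kv_idx >= prefix_len)
               & (q_idx >= kv_idx)))


def render_mask(seq_len, prefix_len):
    # Hypothetical helper: one string per query index,
    # "x" where attention is allowed and "." where it is masked.
    rows = []
    for q in range(seq_len):
        rows.append("".join(
            "x" if prefixlm_mask_mod(0, 0, q, kv, prefix_len) else "."
            for kv in range(seq_len)))
    return rows


for row in render_mask(seq_len=5, prefix_len=2):
    print(row)
```

With these values, the rows for query positions below `prefix_len` are fully populated, and the remaining rows are causal over key/value positions at or after `prefix_len`.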
/gemini review
Code Review
This pull request adds support for PrefixLM attention, which is a prerequisite for moondream3. The changes involve adding a new mask function for PrefixLM, updating metadata classes, and integrating the new logic into the attention backend.
My review focuses on improving code maintainability by reducing duplication and simplifying boolean logic. I've identified a few areas where the code can be refactored for better clarity and conciseness. Overall, the changes are in the right direction.
    return ((q_idx < prefix_len)
            | ((q_idx >= prefix_len) & (kv_idx >= prefix_len)
               & (q_idx >= kv_idx)))
The boolean expression for the mask can be simplified. The (q_idx >= prefix_len) check in the second part of the | operation is redundant. If the first part (q_idx < prefix_len) is false, then q_idx >= prefix_len is implicitly true. Removing this redundant check will make the code slightly more efficient and easier to read.
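The redundancy claim can be checked by brute force. The sketch below compares the submitted and simplified expressions over a small range of indices; the function names and the range bound are illustrative assumptions, and plain integers stand in for the index tensors used in the PR.

```python
from itertools import product


def original_mask(q_idx, kv_idx, prefix_len):
    # Expression as submitted in the PR.
    return ((q_idx < prefix_len)
            | ((q_idx >= prefix_len) & (kv_idx >= prefix_len)
               & (q_idx >= kv_idx)))


def simplified_mask(q_idx, kv_idx, prefix_len):
    # Expression with the redundant (q_idx >= prefix_len) term removed.
    return ((q_idx < prefix_len)
            | ((kv_idx >= prefix_len) & (q_idx >= kv_idx)))


def masks_agree(max_len):
    # Compare both forms for every (q_idx, kv_idx, prefix_len) in [0, max_len).
    return all(
        original_mask(q, kv, p) == simplified_mask(q, kv, p)
        for q, kv, p in product(range(max_len), repeat=3))


print(masks_agree(16))
```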
    return ((q_idx < prefix_len)
            | ((kv_idx >= prefix_len) & (q_idx >= kv_idx)))

    def get_prefixlm_mask_mod(self) -> _mask_mod_signature:
        """Creates the mask_mod function for PrefixLM.

        This function creates the combined mask mod function that handles:
            1. The paged attention block mapping
            2. The mapping from packed query sequences to logical query entries
            3. PrefixLM masking logic

        It also by default adds the decoding offset to the query indices.
        With this info we create the "logical" indices that are passed to
        mask_mod functions. This allows mask mod functions to be agnostic to
        layout of the query and key/value tensors.
        """
        assert self.doc_ids is not None

        def final_mask_mod(
            b: torch.Tensor,
            h: torch.Tensor,
            q_idx: torch.Tensor,
            physical_kv_idx: torch.Tensor,
        ) -> torch.Tensor:
            (is_valid, logical_q_idx,
             logical_kv_idx) = self._convert_physical_to_logical(
                 self.doc_ids, q_idx, physical_kv_idx)
            # Apply mask modification only for valid indices
            return torch.where(
                is_valid,
                prefixlm_mask_mod(b, h, logical_q_idx, logical_kv_idx, self.prefix_len),
                False,
            )

        return final_mask_mod
    if self.prefixlm:
        mask_mod = self.get_prefixlm_mask_mod()
        kv_len = self.total_cache_tokens
    elif self.causal:
        mask_mod = self.get_causal_mask_mod()
        kv_len = self.total_cache_tokens
The logic for prefixlm and causal attention is very similar, leading to duplicated code. You can combine these two branches to make the code more concise and maintainable.
    if self.prefixlm or self.causal:
        if self.prefixlm:
            mask_mod = self.get_prefixlm_mask_mod()
        else:
            mask_mod = self.get_causal_mask_mod()
        kv_len = self.total_cache_tokens
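As a rough check that the suggested restructuring preserves behavior, the sketch below runs both branch layouts over every combination of the two flags. `ToyMetadata`, its string placeholders, and the default of 128 tokens are hypothetical stand-ins for illustration only; the branch taken when neither flag is set is not shown in the excerpt, so the sketch simply returns `None` there.

```python
from itertools import product


class ToyMetadata:
    # Hypothetical stand-in for the metadata object; only the attributes
    # used by the two snippets are modelled.
    def __init__(self, prefixlm, causal, total_cache_tokens=128):
        self.prefixlm = prefixlm
        self.causal = causal
        self.total_cache_tokens = total_cache_tokens

    def get_prefixlm_mask_mod(self):
        return "prefixlm"  # placeholder for the real mask_mod callable

    def get_causal_mask_mod(self):
        return "causal"  # placeholder for the real mask_mod callable

    def select_original(self):
        # Branching as written in the PR.
        if self.prefixlm:
            mask_mod = self.get_prefixlm_mask_mod()
            kv_len = self.total_cache_tokens
        elif self.causal:
            mask_mod = self.get_causal_mask_mod()
            kv_len = self.total_cache_tokens
        else:
            return None  # remaining branch is not shown in the excerpt
        return mask_mod, kv_len

    def select_suggested(self):
        # Branching as proposed in the review.
        if self.prefixlm or self.causal:
            if self.prefixlm:
                mask_mod = self.get_prefixlm_mask_mod()
            else:
                mask_mod = self.get_causal_mask_mod()
            kv_len = self.total_cache_tokens
        else:
            return None  # remaining branch is not shown in the excerpt
        return mask_mod, kv_len


def selections_agree():
    # Compare both layouts for all four (prefixlm, causal) combinations.
    return all(
        ToyMetadata(p, c).select_original() == ToyMetadata(p, c).select_suggested()
        for p, c in product([False, True], repeat=2))


print(selections_agree())
```

In both layouts `prefixlm` takes precedence when both flags are set, so the nesting changes only where `kv_len` is assigned, not which mask is chosen.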
For moondream3 support, in a later PR.