Skip to content

[Attention] feat: support PrefixLM#1526

Open
AlpinDale wants to merge 2 commits into
mainfrom
prefixlm
Open

[Attention] feat: support PrefixLM#1526
AlpinDale wants to merge 2 commits into
mainfrom
prefixlm

Conversation

@AlpinDale
Copy link
Copy Markdown
Collaborator

For moondream3 support, in a later PR.

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @AlpinDale, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request lays the groundwork for supporting Prefix Language Modeling (PrefixLM) within the FlexAttention mechanism. It introduces the necessary masking logic and configuration parameters to enable models like moondream3 to utilize PrefixLM, which is a crucial step for their integration in a subsequent PR. The changes primarily involve adding a new mask type and integrating it into the attention metadata and builder processes.

Highlights

  • PrefixLM Masking Function: Introduced a new prefixlm_mask_mod function that defines the specific masking logic for Prefix Language Modeling, allowing bidirectional attention within a prefix and causal attention for the suffix.
  • FlexAttentionMetadata Update: The FlexAttentionMetadata dataclass has been extended with prefixlm (boolean) and prefix_len (integer) fields to control and specify PrefixLM behavior.
  • Conditional Mask Application: Modified the FlexAttention class's build_block_mask and __post_init__ methods to dynamically select and apply the prefixlm_mask_mod when PrefixLM is enabled, alongside existing causal and bidirectional masking options.
  • Builder Method Enhancement: The FlexAttentionBuilder.build method now accepts prefixlm and prefix_len as parameters, allowing for the configuration of PrefixLM during the attention metadata construction.
  • Forward Pass Logic Adjustment: Updated the forward method in FlexAttention to incorporate the prefixlm flag when determining whether to apply non-causal attention, ensuring correct behavior for PrefixLM models.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for PrefixLM attention masking within the FlexAttention backend, laying the groundwork for moondream3 model support. The changes are well-contained, adding a new masking function prefixlm_mask_mod, updating FlexAttentionMetadata with relevant fields, and integrating the new logic alongside the existing causal and bidirectional attention mechanisms. The implementation is sound, and I have one suggestion to simplify the new masking logic for better readability.

Comment thread aphrodite/v1/attention/backends/flex_attention.py Outdated
@AlpinDale
Copy link
Copy Markdown
Collaborator Author

/gemini review

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds support for PrefixLM attention, which is a prerequisite for moondream3. The changes involve adding a new mask function for PrefixLM, updating metadata classes, and integrating the new logic into the attention backend.

My review focuses on improving code maintainability by reducing duplication and simplifying boolean logic. I've identified a few areas where the code can be refactored for better clarity and conciseness. Overall, the changes are in the right direction.

Comment on lines +263 to +265
return ((q_idx < prefix_len)
| ((q_idx >= prefix_len) & (kv_idx >= prefix_len)
& (q_idx >= kv_idx)))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The boolean expression for the mask can be simplified. The (q_idx >= prefix_len) check in the second part of the | operation is redundant. If the first part (q_idx < prefix_len) is false, then q_idx >= prefix_len is implicitly true. Removing this redundant check will make the code slightly more efficient and easier to read.

    return ((q_idx < prefix_len)
            | ((kv_idx >= prefix_len) & (q_idx >= kv_idx)))

Comment on lines +399 to +430
def get_prefixlm_mask_mod(self) -> _mask_mod_signature:
"""Creates the mask_mod function for PrefixLM.

This function creates the combined mask mod function that handles:
1. The paged attention block mapping
2. The mapping from packed query sequences to logical query entries
3. PrefixLM masking logic

It also by defaults adds the decoding offset to the query indices.
With this info we create the "logical" indices that are passed to
mask_mod functions. This allows mask mod functions to be agnostic to
layout of the query and key/value tensors.
"""
assert self.doc_ids is not None

def final_mask_mod(
b: torch.Tensor,
h: torch.Tensor,
q_idx: torch.Tensor,
physical_kv_idx: torch.Tensor,
) -> torch.Tensor:
(is_valid, logical_q_idx,
logical_kv_idx) = self._convert_physical_to_logical(
self.doc_ids, q_idx, physical_kv_idx)
# Apply mask modification only for valid indices
return torch.where(
is_valid,
prefixlm_mask_mod(b, h, logical_q_idx, logical_kv_idx, self.prefix_len),
False,
)

return final_mask_mod
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This new method get_prefixlm_mask_mod is almost identical to get_causal_mask_mod. This code duplication can make maintenance harder. Consider refactoring the common logic into a shared helper method to improve code quality and reduce redundancy.

Comment on lines +524 to 529
if self.prefixlm:
mask_mod = self.get_prefixlm_mask_mod()
kv_len = self.total_cache_tokens
elif self.causal:
mask_mod = self.get_causal_mask_mod()
kv_len = self.total_cache_tokens
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic for prefixlm and causal attention is very similar, leading to duplicated code. You can combine these two branches to make the code more concise and maintainable.

        if self.prefixlm or self.causal:
            if self.prefixlm:
                mask_mod = self.get_prefixlm_mask_mod()
            else:
                mask_mod = self.get_causal_mask_mod()
            kv_len = self.total_cache_tokens

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant