fix: correct bidirectional attention masking in LlamaBidirectionalModel #1349

Open
oliverholworthy wants to merge 1 commit into main from oholworthy/biencoder-bidirectional-attention

Conversation

@oliverholworthy
Contributor

Update the biencoder LlamaBidirectionalModel to handle masking correctly with newer transformers versions and different attention implementations.

  • Replace the _update_causal_mask override with a proper _create_bidirectional_mask method that produces the correct mask format for each attention backend (SDPA/eager, Flash Attention 2, and the native transformers >= 5.0 path)
  • The previous implementation only checked for zeros in an already-expanded mask and could silently pass incorrect mask shapes/dtypes to the attention layer

Changelog

The old _update_causal_mask had three problems:

  1. Removed upstream: transformers 4.53 deleted _update_causal_mask from LlamaModel, so our override became a dead method that is never called. The model silently falls back to causal masking when using the sdpa or eager attn_implementation.
  2. Wrong mask format for SDPA/eager: it returned None or passed through the raw 2D mask without expanding to 4D or casting to float, which could cause shape mismatches or dtype errors in scaled_dot_product_attention.
  3. No awareness of attention backend: Flash Attention 2 expects a 2D (batch, seq_len) mask, while SDPA/eager need a 4D (batch, 1, seq_len, seq_len) float mask. The old code didn't distinguish between them.
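To make point 3 concrete, here is a minimal, self-contained sketch (not the PR's code) of the two mask formats the backends expect; the tensor values are illustrative:

```python
import torch

# 2D padding mask straight from the tokenizer: 1 = real token, 0 = padding.
attention_mask = torch.tensor([[1, 1, 1, 0]])  # (batch, seq_len)
dtype = torch.float16

# Flash Attention 2 consumes this 2D (batch, seq_len) mask as-is.
fa2_mask = attention_mask

# SDPA/eager expect a 4D additive float mask of shape
# (batch, 1, seq_len, seq_len): 0.0 where attention is allowed and a large
# negative value where it is blocked. With no causal component, every query
# position can see every non-padding key position, i.e. bidirectional attention.
seq_len = attention_mask.shape[1]
expanded = attention_mask[:, None, None, :].to(dtype)   # (batch, 1, 1, seq_len)
sdpa_mask = (1.0 - expanded) * torch.finfo(dtype).min   # invert to additive form
sdpa_mask = sdpa_mask.expand(-1, 1, seq_len, seq_len)   # broadcast over query positions
```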

The new _create_bidirectional_mask method:

  • Uses create_bidirectional_mask from transformers.masking_utils when available (transformers >= 5.0), which handles all backends natively.
  • Falls back to backend-aware logic for older transformers: passes the 2D mask directly for flash_attention_2, and uses _prepare_4d_attention_mask for SDPA/eager.
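A minimal sketch of that backend-aware fallback, under the assumptions stated in the comments (function and argument names are illustrative, not the PR's exact code):

```python
from typing import Optional

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask


def _bidirectional_mask_fallback(
    attention_mask: Optional[torch.Tensor],
    inputs_embeds: torch.Tensor,
    attn_implementation: str,
) -> Optional[torch.Tensor]:
    """Build a bidirectional (padding-only) mask for transformers < 5.0.

    On transformers >= 5.0 the PR instead uses create_bidirectional_mask from
    transformers.masking_utils, which performs this backend dispatch natively.
    """
    if attention_mask is None:
        # No padding: every backend can attend to all positions unmasked.
        return None

    if attn_implementation == "flash_attention_2":
        # Flash Attention 2 takes the raw 2D (batch, seq_len) padding mask.
        return attention_mask

    # SDPA/eager: expand to a 4D additive float mask (batch, 1, tgt_len, src_len)
    # in the embedding dtype. Only padding is masked (no causal triangle),
    # so attention remains bidirectional.
    return _prepare_4d_attention_mask(
        attention_mask,
        dtype=inputs_embeds.dtype,
        tgt_len=inputs_embeds.shape[1],
    )
```

Keeping the mask padding-only is the key point: the causal (lower-triangular) component is exactly what the override is supposed to remove, since a bi-encoder embedding model needs every token to attend to both past and future positions.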

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?

If you haven't finished some of the above items, you can still open a "Draft" PR.

Additional Information

  • Related to # (issue)
