Conversation

@Livinfly Livinfly commented Dec 22, 2025

What does this PR do?

Fix redundant padding in block_sparse_flash_attention.

  1. Previously, if seq_len was a multiple of block_size (e.g., 64 % 64 == 0), the code incorrectly added an extra full block of padding (padding=64) instead of zero. This is wasteful when seq_len is small, since the padded shape can be nearly double the original (see the padding sketch after the result below).

Logic test code:

block = 8
n = 16

for i in range(n):
    print(i, end=' ')
print('')

print('block:', block)

print('old:')
for i in range(n):
    mask = block - 1
    # old formula: yields a full block (8) when i is already a multiple of block
    print((block - (i & mask)), end=' ')
print('')

print('new:')
for i in range(n):
    mask = block - 1
    # new formula: the extra `& mask` turns that full block into 0
    print((block - (i & mask)) & mask, end=' ')
print('')

Result:

0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 
block: 8
old:
8 7 6 5 4 3 2 1 8 7 6 5 4 3 2 1 
new:
0 7 6 5 4 3 2 1 0 7 6 5 4 3 2 1
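
For reference, a minimal sketch of how the corrected formula affects the actual padding of the input tensors. The pad_to_block helper and the F.pad call are illustrative, not the repository's exact code, and the bitmask trick assumes block_size is a power of two:

import torch
import torch.nn.functional as F

def pad_to_block(x: torch.Tensor, block_size: int) -> torch.Tensor:
    # x: (batch, heads, seq_len, head_dim); pad seq_len up to a multiple of block_size.
    seq_len = x.shape[-2]
    mask = block_size - 1
    # old: pad = block_size - (seq_len & mask)  -> adds a full extra block when
    #      seq_len is already a multiple of block_size
    # new: the extra `& mask` makes that case pad by zero
    pad = (block_size - (seq_len & mask)) & mask
    return F.pad(x, (0, 0, 0, pad))   # pad only the seq_len dimension

x = torch.randn(1, 4, 64, 64)
print(pad_to_block(x, 64).shape)    # torch.Size([1, 4, 64, 64])  -- no longer doubled
print(pad_to_block(x, 128).shape)   # torch.Size([1, 4, 128, 64]) -- padded up as before
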
  2. Fix out-of-bounds access in block_sparse_flash_attention.

Previously, the shape of seqlens did not scale with batch_size, so the load on line 48, seqlen = tl.load(seqlens + off_hz // H) (which points exactly at seqlens[batch_id]), reads out of bounds for every batch index greater than zero.
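
A minimal sketch of the indexing, assuming seqlens is the per-batch length buffer the kernel reads; the tensor construction below is illustrative, not the repository's exact code:

import torch

B, H = 2, 4            # batch size and number of heads, matching the test configuration below
kv_len = 2048

# If seqlens is built with a single entry regardless of batch size ...
seqlens = torch.tensor([kv_len], dtype=torch.int32)

# ... the kernel's per-(batch, head) program computes batch_id = off_hz // H and
# loads seqlens[batch_id], which is out of bounds for every batch_id > 0.
for off_hz in range(B * H):
    batch_id = off_hz // H
    print(off_hz, batch_id, batch_id < seqlens.numel())   # False once batch_id == 1

# Sizing the buffer per batch keeps every load in bounds:
seqlens = torch.full((B,), kv_len, dtype=torch.int32)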

The following test result was obtained using

tl.device_print('seqlen: ', seqlen)
tl.device_print('idx_bz: ', idx_bz)  # idx_bz = idx_bz_h // n_heads, or, off_hz // H in your code

configured with

block_size_M = 128
block_size_N = 128
b, h, q_len, kv_len, d = 2, 4, 1024, 2048, 64

(these values are from my adapted version of the code, so minor details may differ from the repository).

I originally found the correctness error with the following test code:

# Imports assumed for this snippet (create_block_mask / flex_attention are the
# standard PyTorch FlexAttention entry points).
from functools import partial

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def topk_mask(b, h, q_idx, kv_idx, block_index, block_size_M=128, block_size_N=128):
    # Keep a (q, kv) position only if its kv block is in the top-k list selected
    # for that query block, and the position is causal.
    q_block_idx = q_idx // block_size_M
    kv_block_idx = kv_idx // block_size_N
    allowed_k_blocks = block_index[b, h, q_block_idx]
    is_in_topk = (allowed_k_blocks == kv_block_idx[..., None])

    is_causal_mask = q_idx >= kv_idx
    return is_in_topk.any(dim=-1) & is_causal_mask

...

block_mask = create_block_mask(
    partial(topk_mask, block_index=block_index,
            block_size_M=block_size_M, block_size_N=block_size_N),
    b, h, q_len_padded, k_len_padded,
    BLOCK_SIZE=(block_size_M, block_size_N),
    device="cuda",
    _compile=True,
)
print(block_mask)
out_flexattn = flex_attention(
    q_padded, k_padded, v_padded, block_mask=block_mask)
assert isinstance(out_flexattn, torch.Tensor)
out_flexattn = out_flexattn[..., :seq_lens, :]

# print(out_triton)
# print(out_flexattn)

print((out_triton - out_flexattn).max())
print((out_triton - out_flexattn).min())

assert torch.allclose(
    out_triton, out_flexattn, atol=1e-2), 'test sparse_attention_single_triton with FlexAttention, Wrong!'
print('test sparse_attention_single_triton with FlexAttention, Accept!')

The following screenshots show the tl.device_print output from the Triton kernel.

[two screenshots of the device_print log]

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Was this discussed/approved via a Github issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@iofu728

Signed-off-by: Livinfly <luojie3m@gmail.com>
@Livinfly Livinfly changed the title from "Fix(MInference): fix redundancy pad in block_sparse_flash_attention" to "Fix(MInference): fix redundancy padding and Out-of-Bounds Access in block_sparse_flash_attention" on Dec 22, 2025