hyvideo/modules/attenion.py (70 changes: 53 additions & 17 deletions)
@@ -106,23 +106,44 @@ def attention(
                 q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
             )
         else:
-            attn1 = F.scaled_dot_product_attention(
-                q[:, :, :cu_seqlens_q[1]],
-                k[:, :, :cu_seqlens_kv[1]],
-                v[:, :, :cu_seqlens_kv[1]],
-                attn_mask=attn_mask,
-                dropout_p=drop_rate,
-                is_causal=causal
-            )
-            attn2 = F.scaled_dot_product_attention(
-                q[:, :, cu_seqlens_q[1]:],
-                k[:, :, cu_seqlens_kv[1]:],
-                v[:, :, cu_seqlens_kv[1]:],
-                attn_mask=None,
-                dropout_p=drop_rate,
-                is_causal=False
-            )
-            x = torch.cat([attn1, attn2], dim=2)
+            # q shape after pre_attn_layout: [batch_size, heads, seq_len, dim]
+            # cu_seqlens has 2*batch_size+1 entries encoding (valid, padding) pairs
+            # per batch item. Iterate over all batch items to handle batch_size > 1.
+            seq_len_q = q.shape[2]
+            seq_len_kv = k.shape[2]
+            x = torch.empty_like(q)
+            for i in range(batch_size):
+                # Per-batch valid lengths (cu_seqlens stores absolute offsets
+                # into the flattened sequence; subtract the batch base offset)
+                valid_q = (cu_seqlens_q[2 * i + 1] - cu_seqlens_q[2 * i]).item()
+                valid_kv = (cu_seqlens_kv[2 * i + 1] - cu_seqlens_kv[2 * i]).item()
+
+                # Attend over valid (image + text) tokens
+                attn1 = F.scaled_dot_product_attention(
+                    q[i : i + 1, :, :valid_q],
+                    k[i : i + 1, :, :valid_kv],
+                    v[i : i + 1, :, :valid_kv],
+                    attn_mask=attn_mask,
+                    dropout_p=drop_rate,
+                    is_causal=causal,
+                )
+                x[i : i + 1, :, :valid_q] = attn1
+
+                # Attend over padding tokens (no mask, no causal)
+                pad_q = seq_len_q - valid_q
+                pad_kv = seq_len_kv - valid_kv
+                if pad_q > 0 and pad_kv > 0:
+                    attn2 = F.scaled_dot_product_attention(
+                        q[i : i + 1, :, valid_q:],
+                        k[i : i + 1, :, valid_kv:],
+                        v[i : i + 1, :, valid_kv:],
+                        attn_mask=None,
+                        dropout_p=drop_rate,
+                        is_causal=False,
+                    )
+                    x[i : i + 1, :, valid_q:] = attn2
+                elif pad_q > 0:
+                    x[i : i + 1, :, valid_q:] = 0
     elif mode == "flash":
         x = flash_attn_varlen_func(
             q,
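
Note: the loop above relies on the (valid, padding) pair layout of cu_seqlens described in the comments. A minimal standalone sketch of that layout, with made-up values (seq_len and valid_lens are illustrative, not the module's real numbers):

import torch

# Hypothetical example: batch of 2, padded to seq_len 10, with 7 and 4
# valid tokens respectively. cu_seqlens stores cumulative offsets into the
# flattened sequence, alternating valid-end / padding-end boundaries, so it
# has 2 * batch_size + 1 entries.
seq_len = 10
valid_lens = [7, 4]
offsets = [0]
for v in valid_lens:
    offsets.append(offsets[-1] + v)              # end of this item's valid tokens
    offsets.append(offsets[-1] + (seq_len - v))  # end of this item's padding
cu_seqlens = torch.tensor(offsets, dtype=torch.int32)  # tensor([0, 7, 10, 14, 20])

# Recover per-item lengths the same way the diff does
for i in range(len(valid_lens)):
    valid = (cu_seqlens[2 * i + 1] - cu_seqlens[2 * i]).item()
    pad = (cu_seqlens[2 * i + 2] - cu_seqlens[2 * i + 1]).item()
    print(f"batch {i}: valid={valid}, pad={pad}")  # 7/3, then 4/6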
@@ -143,6 +164,21 @@ def attention(
         b, a, s, _ = q.shape
         s1 = k.size(2)
         attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
+
+        # Build a block-diagonal mask from cu_seqlens so that tokens in
+        # different segments (valid vs padding) cannot attend to each other.
+        # Without this, vanilla attention bleeds across segment boundaries,
+        # causing severe quality degradation compared to flash attention.
+        if cu_seqlens_q is not None:
+            for i in range(batch_size):
+                # Per-batch valid lengths
+                valid_q = (cu_seqlens_q[2 * i + 1] - cu_seqlens_q[2 * i]).item()
+                valid_kv = (cu_seqlens_kv[2 * i + 1] - cu_seqlens_kv[2 * i]).item()
+                # Valid query tokens must not attend to padding key tokens
+                attn_bias[i, :, :valid_q, valid_kv:] = float("-inf")
+                # Padding query tokens must not attend to valid key tokens
+                attn_bias[i, :, valid_q:, :valid_kv] = float("-inf")
+
         if causal:
             # Only applied to self attention
             assert (
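
Note: the block-diagonal bias added in this hunk is equivalent to running attention on each segment independently, which is why it restores parity with the flash path. A self-contained sketch under assumed shapes (b, heads, s, d and valid_lens are made up for illustration):

import torch
import torch.nn.functional as F

b, heads, s, d = 2, 4, 10, 8
q, k, v = (torch.randn(b, heads, s, d) for _ in range(3))
valid_lens = [7, 4]  # hypothetical per-item valid token counts

# -inf bias entries become zero attention weights after softmax
attn_bias = torch.zeros(b, heads, s, s, dtype=q.dtype)
for i, valid in enumerate(valid_lens):
    attn_bias[i, :, :valid, valid:] = float("-inf")  # valid q -> padding k
    attn_bias[i, :, valid:, :valid] = float("-inf")  # padding q -> valid k

x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)

# Masked attention matches attending over the valid segment alone
v0 = valid_lens[0]
ref = F.scaled_dot_product_attention(q[:1, :, :v0], k[:1, :, :v0], v[:1, :, :v0])
assert torch.allclose(x[:1, :, :v0], ref, atol=1e-5)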