31 changes: 21 additions & 10 deletions csrc/api/kda_sm90.cu
@@ -21,7 +21,7 @@

using OptionalTensor = std::optional<torch::Tensor>;

std::tuple<torch::Tensor, torch::Tensor>
std::tuple<torch::Tensor, OptionalTensor>
kda_fwd_prefill(
OptionalTensor output_,
OptionalTensor output_state_,
@@ -34,6 +34,7 @@ kda_fwd_prefill(
torch::Tensor const& cu_seqlens,
torch::Tensor workspace_buffer,
float scale,
bool output_final_state,
bool safe_gate) {
// Q, K, V: [packed_seq, H, D] (already packed by Python layer)
auto packed_seq = q.size(0);
@@ -52,12 +53,17 @@
{packed_seq, num_heads, head_size},
torch::TensorOptions().dtype(q.dtype()).device(q.device()));

// Allocate output state if not provided
torch::Tensor output_state = output_state_.has_value()
? output_state_.value()
: torch::zeros(
{num_seqs, num_heads, head_size, head_size},
torch::TensorOptions().dtype(torch::kFloat32).device(q.device()));
// output_final_state controls the API side effect. If it is false, ignore
// even an explicitly provided output_state_ buffer so the kernel skips the
// final-state store.
OptionalTensor output_state = std::nullopt;
if (output_final_state) {
output_state = output_state_.has_value()
? output_state_.value()
: torch::zeros(
{num_seqs, num_heads, head_size, head_size},
torch::TensorOptions().dtype(torch::kFloat32).device(q.device()));
}

// Validate dtypes
TORCH_CHECK(q.dtype() == torch::kBFloat16, "q must be bfloat16");
@@ -70,7 +76,10 @@ kda_fwd_prefill(
TORCH_CHECK(k.is_contiguous(), "k must be contiguous");
TORCH_CHECK(v.is_contiguous(), "v must be contiguous");
TORCH_CHECK(output.is_contiguous(), "output must be contiguous");
TORCH_CHECK(output_state.is_contiguous(), "output_state must be contiguous");
if (output_state.has_value()) {
TORCH_CHECK(output_state->dtype() == torch::kFloat32, "output_state must be float32");
TORCH_CHECK(output_state->is_contiguous(), "output_state must be contiguous");
}
TORCH_CHECK(cu_seqlens.is_contiguous(), "cu_seqlens must be contiguous");
TORCH_CHECK(workspace_buffer.is_contiguous(), "workspace_buffer must be contiguous");

@@ -87,6 +96,8 @@
"alpha shape must be [packed_seq, num_heads, head_size]");
alpha_ptr = alpha.data_ptr<float>();
}

float* output_state_ptr = output_state.has_value() ? output_state->data_ptr<float>() : nullptr;
if (beta_.has_value()) {
auto& beta = beta_.value();
TORCH_CHECK(
@@ -121,7 +132,7 @@
kda::sm90::launch_kda_fwd_prefill_kernel<Sm90, bf16, bf16, float, bf16>(
stream,
reinterpret_cast<bf16*>(output.data_ptr()),
output_state.data_ptr<float>(),
output_state_ptr,
reinterpret_cast<bf16 const*>(q.data_ptr()),
reinterpret_cast<bf16 const*>(k.data_ptr()),
reinterpret_cast<bf16 const*>(v.data_ptr()),
@@ -142,7 +153,7 @@
kda::sm90::launch_kda_fwd_prefill_kernel<Sm90, bf16, bf16, float, float>(
stream,
reinterpret_cast<bf16*>(output.data_ptr()),
output_state.data_ptr<float>(),
output_state_ptr,
reinterpret_cast<bf16 const*>(q.data_ptr()),
reinterpret_cast<bf16 const*>(k.data_ptr()),
reinterpret_cast<bf16 const*>(v.data_ptr()),
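For reference, a minimal sketch of the two call modes the new `output_final_state` flag enables on this binding. The `cula_cuda` module name, argument order, and q/k/v/state shapes come from this PR; the `cu_seqlens` dtype, workspace size, and all tensor values are illustrative assumptions. Real callers should go through the Python wrapper below, which packs inputs and sizes the workspace itself.

```python
# Hedged sketch of the updated kda_fwd_prefill binding (not a drop-in test).
# Shapes follow the checks in kda_sm90.cu: q/k/v are [packed_seq, H, D] bf16;
# the state is [num_seqs, H, D, D] fp32. The int32 cu_seqlens dtype and the
# workspace size are assumptions made for illustration only.
import torch
import cula_cuda  # compiled extension from this repo

packed_seq, H, D = 256, 4, 128
dev = "cuda"
q = torch.randn(packed_seq, H, D, dtype=torch.bfloat16, device=dev)
k = torch.randn(packed_seq, H, D, dtype=torch.bfloat16, device=dev)
v = torch.randn(packed_seq, H, D, dtype=torch.bfloat16, device=dev)
cu_seqlens = torch.tensor([0, 128, 256], dtype=torch.int32, device=dev)  # 2 seqs
workspace = torch.empty(1 << 22, dtype=torch.uint8, device=dev)  # illustrative size
scale = D ** -0.5

# output_final_state=False: the state slot comes back as None and the kernel
# skips the final-state store (kv_store) entirely.
o, ht = cula_cuda.kda_fwd_prefill(
    None, None, q, k, v, None, None, None,
    cu_seqlens, workspace, scale, False, True)
assert ht is None

# output_final_state=True: a [num_seqs, H, D, D] fp32 state is allocated by
# the binding and returned alongside the output.
o, ht = cula_cuda.kda_fwd_prefill(
    None, None, q, k, v, None, None, None,
    cu_seqlens, workspace, scale, True, True)
assert ht.shape == (2, H, D, D)
```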
3 changes: 2 additions & 1 deletion csrc/api/pybind.cu
@@ -51,7 +51,7 @@ ChunkKDAFwdRecompWU(
#endif

#if defined(CULA_SM90A_ENABLED)
std::tuple<torch::Tensor, torch::Tensor>
std::tuple<torch::Tensor, std::optional<torch::Tensor>>
kda_fwd_prefill(
std::optional<torch::Tensor> output_,
std::optional<torch::Tensor> output_state_,
@@ -64,6 +64,7 @@ kda_fwd_prefill(
torch::Tensor const& cu_seqlens,
torch::Tensor workspace_buffer,
float scale,
bool output_final_state,
bool safe_gate);
#endif

4 changes: 3 additions & 1 deletion csrc/kda/sm90/collective/mainloop_kda_fwd.hpp
@@ -1369,7 +1369,9 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd {
/*is_first_block_=*/cute::false_type{},
/*is_final_block_=*/cute::true_type{});
}
kv_store();
if (params.ptr_output_state != nullptr) {
kv_store();
}
}

template <class ProblemShape, class WorkDesc>
6 changes: 4 additions & 2 deletions cula/kda/hopper_fused_fwd.py
@@ -102,7 +102,8 @@ def forward(
workspace_buffer = _get_cache_buf("hopper_kda_fwd_workspace", workspace_size, q.device)

# call the C++ kernel
# Signature: kda_fwd_prefill(output_, output_state_, q, k, v, input_state_, alpha_, beta_, cu_seqlens, workspace, scale, safe_gate)
# Signature: kda_fwd_prefill(output_, output_state_, q, k, v, input_state_, alpha_, beta_, cu_seqlens,
# workspace, scale, output_final_state, safe_gate)
o, final_state = cula_cuda.kda_fwd_prefill(
None, # output_ (auto-allocate)
None, # output_state_ (auto-allocate)
@@ -115,13 +116,14 @@
cu_seqlens,
workspace_buffer,
scale,
output_final_state,
safe_gate,
)

# reshape back
o = rearrange(o, "(b t) h d -> b t h d", b=batch_size)

return o.to(q.dtype), final_state if output_final_state else None
return o.to(q.dtype), final_state

@staticmethod
@input_guard
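The wrapper above always passes `None` for `output_state_` and lets the binding allocate. A caller-managed buffer is still accepted, but under the new semantics it is ignored whenever `output_final_state` is false. A self-contained sketch of that path, under the same illustrative assumptions as the earlier snippet (assumed `cu_seqlens` dtype and workspace size):

```python
# Sketch of the caller-provided output_state_ path, per the comment in
# kda_sm90.cu: the buffer is used and returned when output_final_state=True,
# and ignored entirely when it is False.
import torch
import cula_cuda  # compiled extension from this repo

packed_seq, H, D = 256, 4, 128
dev = "cuda"
q = torch.randn(packed_seq, H, D, dtype=torch.bfloat16, device=dev)
k = torch.randn(packed_seq, H, D, dtype=torch.bfloat16, device=dev)
v = torch.randn(packed_seq, H, D, dtype=torch.bfloat16, device=dev)
cu_seqlens = torch.tensor([0, 128, 256], dtype=torch.int32, device=dev)  # 2 seqs
workspace = torch.empty(1 << 22, dtype=torch.uint8, device=dev)  # illustrative
state_buf = torch.zeros(2, H, D, D, dtype=torch.float32, device=dev)

o, ht = cula_cuda.kda_fwd_prefill(
    None, state_buf, q, k, v, None, None, None,
    cu_seqlens, workspace, D ** -0.5, True, True)
assert ht.data_ptr() == state_buf.data_ptr()  # provided buffer is written to

o, ht = cula_cuda.kda_fwd_prefill(
    None, state_buf, q, k, v, None, None, None,
    cu_seqlens, workspace, D ** -0.5, False, True)
assert ht is None  # buffer ignored; the kernel skips the final-state store
```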
47 changes: 47 additions & 0 deletions tests/test_kda_fused_fwd.py
@@ -173,6 +173,53 @@ def test_safe_gate_chunk(
assert_close("ht", ref_ht_fla_trans, tri_ht, 0.005)


def test_safe_gate_chunk_no_final_state():
cula_kda_fused_fwd = get_kda_fused_fwd(device)

B, T, H, D = 1, 63, 1, 128
dtype = torch.bfloat16

torch.manual_seed(42)
q = torch.rand(B, T, H, D, dtype=dtype, device=device)
k = torch.rand(B, T, H, D, dtype=dtype, device=device)
v = torch.rand(B, T, H, D, dtype=dtype, device=device)
g = F.logsigmoid(torch.randn(B, T, H, D, dtype=torch.float32, device=device)).clamp(-5, 0)
beta = torch.randn(B, T, H, dtype=torch.float32, device=device).sigmoid()
h0 = torch.randn(B, H, D, D, dtype=torch.float32, device=device)
h0_vk = h0.transpose(-1, -2).contiguous()

q = F.normalize(q, p=2, dim=-1)
k = F.normalize(k, p=2, dim=-1)

tri_no_state, tri_ht_no_state = cula_kda_fused_fwd(
q=q.clone(),
k=k.clone(),
v=v.clone(),
g=g.clone(),
beta=beta.clone(),
initial_state=h0_vk.clone(),
output_final_state=False,
safe_gate=True,
lower_bound=-5.0,
)

tri_with_state, tri_ht_with_state = cula_kda_fused_fwd(
q=q.clone(),
k=k.clone(),
v=v.clone(),
g=g.clone(),
beta=beta.clone(),
initial_state=h0_vk.clone(),
output_final_state=True,
safe_gate=True,
lower_bound=-5.0,
)

assert tri_ht_no_state is None
assert tri_ht_with_state is not None
assert_close("o", tri_with_state, tri_no_state, 0.005)


@pytest.mark.parametrize("beta_dtype", [torch.float32, torch.bfloat16], ids=["beta_fp32", "beta_bf16"])
@pytest.mark.parametrize(
("H", "D", "mask_p", "cu_seqlens", "dtype", "safe_gate"),