diff --git a/csrc/api/kda_sm90.cu b/csrc/api/kda_sm90.cu
index 7acd068..e8f3a54 100644
--- a/csrc/api/kda_sm90.cu
+++ b/csrc/api/kda_sm90.cu
@@ -21,7 +21,7 @@
 
 using OptionalTensor = std::optional<torch::Tensor>;
 
-std::tuple<torch::Tensor, torch::Tensor>
+std::tuple<torch::Tensor, std::optional<torch::Tensor>>
 kda_fwd_prefill(
     OptionalTensor output_,
     OptionalTensor output_state_,
@@ -34,6 +34,7 @@ kda_fwd_prefill(
     torch::Tensor const& cu_seqlens,
     torch::Tensor workspace_buffer,
     float scale,
+    bool output_final_state,
     bool safe_gate) {
   // Q, K, V: [packed_seq, H, D] (already packed by Python layer)
   auto packed_seq = q.size(0);
@@ -52,12 +53,17 @@ kda_fwd_prefill(
       {packed_seq, num_heads, head_size},
       torch::TensorOptions().dtype(q.dtype()).device(q.device()));
 
-  // Allocate output state if not provided
-  torch::Tensor output_state = output_state_.has_value()
-      ? output_state_.value()
-      : torch::zeros(
-            {num_seqs, num_heads, head_size, head_size},
-            torch::TensorOptions().dtype(torch::kFloat32).device(q.device()));
+  // output_final_state controls the API side effect. If it is false, ignore
+  // even an explicitly provided output_state_ buffer so the kernel skips the
+  // final-state store.
+  OptionalTensor output_state = std::nullopt;
+  if (output_final_state) {
+    output_state = output_state_.has_value()
+        ? output_state_.value()
+        : torch::zeros(
+              {num_seqs, num_heads, head_size, head_size},
+              torch::TensorOptions().dtype(torch::kFloat32).device(q.device()));
+  }
 
   // Validate dtypes
   TORCH_CHECK(q.dtype() == torch::kBFloat16, "q must be bfloat16");
@@ -70,7 +76,10 @@ kda_fwd_prefill(
   TORCH_CHECK(k.is_contiguous(), "k must be contiguous");
   TORCH_CHECK(v.is_contiguous(), "v must be contiguous");
   TORCH_CHECK(output.is_contiguous(), "output must be contiguous");
-  TORCH_CHECK(output_state.is_contiguous(), "output_state must be contiguous");
+  if (output_state.has_value()) {
+    TORCH_CHECK(output_state->dtype() == torch::kFloat32, "output_state must be float32");
+    TORCH_CHECK(output_state->is_contiguous(), "output_state must be contiguous");
+  }
   TORCH_CHECK(cu_seqlens.is_contiguous(), "cu_seqlens must be contiguous");
   TORCH_CHECK(workspace_buffer.is_contiguous(), "workspace_buffer must be contiguous");
 
@@ -87,6 +96,8 @@ kda_fwd_prefill(
         "alpha shape must be [packed_seq, num_heads, head_size]");
     alpha_ptr = alpha.data_ptr();
   }
+
+  float* output_state_ptr = output_state.has_value() ? output_state->data_ptr<float>() : nullptr;
   if (beta_.has_value()) {
     auto& beta = beta_.value();
     TORCH_CHECK(
@@ -121,7 +132,7 @@ kda_fwd_prefill(
   kda::sm90::launch_kda_fwd_prefill_kernel(
       stream,
       reinterpret_cast(output.data_ptr()),
-      output_state.data_ptr<float>(),
+      output_state_ptr,
       reinterpret_cast(q.data_ptr()),
       reinterpret_cast(k.data_ptr()),
       reinterpret_cast(v.data_ptr()),
@@ -142,7 +153,7 @@ kda_fwd_prefill(
   kda::sm90::launch_kda_fwd_prefill_kernel(
       stream,
       reinterpret_cast(output.data_ptr()),
-      output_state.data_ptr<float>(),
+      output_state_ptr,
       reinterpret_cast(q.data_ptr()),
       reinterpret_cast(k.data_ptr()),
       reinterpret_cast(v.data_ptr()),
diff --git a/csrc/api/pybind.cu b/csrc/api/pybind.cu
index ba2deb6..d14a41c 100644
--- a/csrc/api/pybind.cu
+++ b/csrc/api/pybind.cu
@@ -51,7 +51,7 @@ ChunkKDAFwdRecompWU(
 #endif
 
 #if defined(CULA_SM90A_ENABLED)
-std::tuple<torch::Tensor, torch::Tensor>
+std::tuple<torch::Tensor, std::optional<torch::Tensor>>
 kda_fwd_prefill(
     std::optional<torch::Tensor> output_,
     std::optional<torch::Tensor> output_state_,
@@ -64,6 +64,7 @@ kda_fwd_prefill(
     torch::Tensor const& cu_seqlens,
     torch::Tensor workspace_buffer,
     float scale,
+    bool output_final_state,
     bool safe_gate);
 #endif
 
diff --git a/csrc/kda/sm90/collective/mainloop_kda_fwd.hpp b/csrc/kda/sm90/collective/mainloop_kda_fwd.hpp
index d22814d..1f78169 100644
--- a/csrc/kda/sm90/collective/mainloop_kda_fwd.hpp
+++ b/csrc/kda/sm90/collective/mainloop_kda_fwd.hpp
@@ -1369,7 +1369,9 @@ struct FlatMainloopTmaWarpSpecializedKdaFwd {
           /*is_first_block_=*/cute::false_type{},
           /*is_final_block_=*/cute::true_type{});
     }
-    kv_store();
+    if (params.ptr_output_state != nullptr) {
+      kv_store();
+    }
   }
 
   template
diff --git a/cula/kda/hopper_fused_fwd.py b/cula/kda/hopper_fused_fwd.py
index 152cfd3..cc42827 100644
--- a/cula/kda/hopper_fused_fwd.py
+++ b/cula/kda/hopper_fused_fwd.py
@@ -102,7 +102,8 @@ def forward(
         workspace_buffer = _get_cache_buf("hopper_kda_fwd_workspace", workspace_size, q.device)
 
         # call the C++ kernel
-        # Signature: kda_fwd_prefill(output_, output_state_, q, k, v, input_state_, alpha_, beta_, cu_seqlens, workspace, scale, safe_gate)
+        # Signature: kda_fwd_prefill(output_, output_state_, q, k, v, input_state_, alpha_, beta_, cu_seqlens,
+        #            workspace, scale, output_final_state, safe_gate)
         o, final_state = cula_cuda.kda_fwd_prefill(
             None,  # output_ (auto-allocate)
             None,  # output_state_ (auto-allocate)
@@ -115,13 +116,14 @@ def forward(
             cu_seqlens,
             workspace_buffer,
             scale,
+            output_final_state,
             safe_gate,
         )
 
         # reshape back
         o = rearrange(o, "(b t) h d -> b t h d", b=batch_size)
 
-        return o.to(q.dtype), final_state if output_final_state else None
+        return o.to(q.dtype), final_state
 
     @staticmethod
     @input_guard
diff --git a/tests/test_kda_fused_fwd.py b/tests/test_kda_fused_fwd.py
index b9b59f9..9c32552 100644
--- a/tests/test_kda_fused_fwd.py
+++ b/tests/test_kda_fused_fwd.py
@@ -173,6 +173,53 @@ def test_safe_gate_chunk(
     assert_close("ht", ref_ht_fla_trans, tri_ht, 0.005)
 
 
+def test_safe_gate_chunk_no_final_state():
+    cula_kda_fused_fwd = get_kda_fused_fwd(device)
+
+    B, T, H, D = 1, 63, 1, 128
+    dtype = torch.bfloat16
+
+    torch.manual_seed(42)
+    q = torch.rand(B, T, H, D, dtype=dtype, device=device)
+    k = torch.rand(B, T, H, D, dtype=dtype, device=device)
+    v = torch.rand(B, T, H, D, dtype=dtype, device=device)
+    g = F.logsigmoid(torch.randn(B, T, H, D, dtype=torch.float32, device=device)).clamp(-5, 0)
+    beta = torch.randn(B, T, H, dtype=torch.float32, device=device).sigmoid()
+    h0 = torch.randn(B, H, D, D, dtype=torch.float32, device=device)
+    h0_vk = h0.transpose(-1, -2).contiguous()
+
+    q = F.normalize(q, p=2, dim=-1)
+    k = F.normalize(k, p=2, dim=-1)
+
+    tri_no_state, tri_ht_no_state = cula_kda_fused_fwd(
+        q=q.clone(),
+        k=k.clone(),
+        v=v.clone(),
+        g=g.clone(),
+        beta=beta.clone(),
+        initial_state=h0_vk.clone(),
+        output_final_state=False,
+        safe_gate=True,
+        lower_bound=-5.0,
+    )
+
+    tri_with_state, tri_ht_with_state = cula_kda_fused_fwd(
+        q=q.clone(),
+        k=k.clone(),
+        v=v.clone(),
+        g=g.clone(),
+        beta=beta.clone(),
+        initial_state=h0_vk.clone(),
+        output_final_state=True,
+        safe_gate=True,
+        lower_bound=-5.0,
+    )
+
+    assert tri_ht_no_state is None
+    assert tri_ht_with_state is not None
+    assert_close("o", tri_with_state, tri_no_state, 0.005)
+
+
 @pytest.mark.parametrize("beta_dtype", [torch.float32, torch.bfloat16], ids=["beta_fp32", "beta_bf16"])
 @pytest.mark.parametrize(
     ("H", "D", "mask_p", "cu_seqlens", "dtype", "safe_gate"),