[KDA] fix internal output_final_state wrapper issue in SM90 #66
Conversation
Code Review
This pull request introduces an `output_final_state` flag to the `kda_fwd_prefill` kernel, making the final-state tensor allocation and store optional across the C++, Python, and CUDA layers. The changes include logic to handle null pointers in the kernel and a new test case to verify the behavior. Feedback was provided on the robustness of the C++ API: the `output_final_state` flag should be respected even when a buffer is explicitly provided, to prevent unintended memory stores.
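As a sketch of that suggestion (hypothetical names, not the actual cuLA API), the C++ layer could normalize the pointer before launch so the flag always wins:

```cuda
// Hypothetical helper: honor output_final_state even when a buffer is
// explicitly provided, so the kernel never performs an unintended store.
inline float* resolve_output_state(bool output_final_state, float* user_buffer) {
  if (!output_final_state) {
    return nullptr;  // kernel sees nullptr and skips the final-state store
  }
  return user_buffer;  // may be caller-provided or internally allocated
}
```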
I reran the benchmark after this change. The previous suspicious short-sequence result seems to have been benchmark noise.
Hi, it seems that #64 has similar state-output logic. Are there any differences?
From what I can see, #64 only partially handles the output-state path. This PR is narrower: it explicitly passes `output_final_state` from the Python wrapper down to the C++ API and skips both the final-state allocation and the global-memory store when it is not requested.
Got it, thanks. I will review this PR.
📌 Description
This PR is a follow-up to #63.
#63 fixed the Python wrapper behavior of the Hopper KDA fused prefill so that `final_state` is returned as `None` when `output_final_state=False`. However, that PR only changed the Python-side return value: the underlying C++/CUDA path still allocated an `output_state` buffer and passed a non-null `ptr_output_state` to the kernel, so the final state was still written internally even when the caller did not request it.

This PR passes `output_final_state` from the Python wrapper to the C++ API and avoids allocating or storing the final state when it is not requested. Specifically, this PR:

- extends the C++ API with an `output_final_state` flag;
- allocates `output_state` only when `output_final_state=True`;
- passes `nullptr` as `ptr_output_state` when the final state output is not requested;
- skips `kv_store()` in the SM90 KDA mainloop when `ptr_output_state == nullptr`;
- adds a test verifying that `output_final_state=False` returns `None` and produces the same output tensor as `output_final_state=True`.

This avoids unnecessary final-state allocation and the final global-memory store while preserving the output tensor.
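For illustration, here is a minimal CUDA sketch of the guarded path, assuming hypothetical, simplified names (`KdaParams`, `kda_fwd_prefill_sketch`, `launch_sketch`); the real cuLA kernel and parameter types differ:

```cuda
#include <cuda_runtime.h>

// Hypothetical, simplified parameter struct; the real cuLA KDA params differ.
struct KdaParams {
  float* ptr_output_state;  // nullptr when the final state is not requested
  int state_elems;          // number of elements in the final state
};

// Device-side sketch: the epilogue stores the final recurrent state only
// when a buffer was provided, mirroring the kv_store() guard in this PR.
__global__ void kda_fwd_prefill_sketch(KdaParams params) {
  // ... the main prefill computation would run here ...
  if (params.ptr_output_state != nullptr) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < params.state_elems) {
      params.ptr_output_state[i] = 0.0f;  // placeholder for the real state
    }
  }
  // With ptr_output_state == nullptr the final global-memory store is skipped.
}

// Host-side sketch: allocate the output-state buffer only on request and
// pass nullptr otherwise, as this PR does in the C++ API.
void launch_sketch(bool output_final_state, int state_elems) {
  KdaParams params{};
  params.state_elems = state_elems;
  if (output_final_state) {
    cudaMalloc(&params.ptr_output_state, state_elems * sizeof(float));
  } else {
    params.ptr_output_state = nullptr;  // no allocation, no store
  }
  kda_fwd_prefill_sketch<<<(state_elems + 255) / 256, 256>>>(params);
  if (params.ptr_output_state) cudaFree(params.ptr_output_state);
}
```

The key point of this design is that the guard is on the pointer itself, so the kernel needs no extra flag and the no-output path performs neither the allocation nor the store.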
🔍 Related Issues
Mentioned here.
🚀 Pull Request Checklist
Thank you for contributing to cuLA! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks before submitting with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
Tested with `pytest tests/test_kda_fused_fwd.py`.

⚡ Performance
Benchmarked with `python benchmarks/bench_kda_fused_fwd.py`.

Before:
1: `output_final_state = True`
2: `output_final_state = False`

After:

1: `output_final_state = True`
2: `output_final_state = False`
Reviewer Notes