fix(hip): small-kernel correctness and test coverage for CDNA3 (activation, norm, rope) #220
Open
demandal25 wants to merge 6 commits into ROCm:amd-integration from
Conversation
…n mma_hip.h

- math_hip.h: change shfl_xor_sync width 32→64; all callers use lane_mask ≤ 16, so the exchange pattern is identical, but 64 is semantically correct for CDNA3 wavefronts and future-proofs callers with larger lane_mask values. Resolves the FIXME left by diptorupd.
- mma_hip.h: replace std::bit_cast (C++20) with __builtin_memcpy for the bf16 ones-vector initialisation; fixes a JIT recompile failure under -std=c++17.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
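A minimal sketch of the wrapper shape this commit describes, assuming a templated free function in math_hip.h (name and signature are assumptions, not the header's actual contents):

```cpp
#include <hip/hip_runtime.h>

// Sketch only: CDNA3 wavefronts are 64 lanes wide, so the XOR butterfly
// should span the full wave. For lane_mask <= 16 the source lane is the
// same as with width 32, because XORing with a mask below 32 never moves
// a lane across a 32-lane boundary.
template <typename T>
__device__ __forceinline__ T shfl_xor_sync(T x, int lane_mask) {
  return __shfl_xor(x, lane_mask, /*width=*/64);  // previously width 32
}
```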
activation.cu: hipStream_t stream = stream inside the DISPATCH lambda was a C++ self-initialisation: the RHS 'stream' referred to the newly declared variable (uninitialised), not the outer at::hip::getCurrentHIPStream() return value. Move stream acquisition outside the lambda so the capture is well-defined.

Also removes: dead dynamicSmemBytes local (inlined as 0), dead 'auto kernel = act_and_mul_kernel<c_type, gelu>' (assigned, never used).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
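A minimal sketch of the pitfall and the fix; the DISPATCH machinery and kernel launch are elided, and the ATen header path is an assumption:

```cpp
#include <hip/hip_runtime.h>
#include <ATen/hip/HIPContext.h>  // at::hip::getCurrentHIPStream (path assumed)

void dispatch_example() {
  // Before: the RHS `stream` names the brand-new local being declared, so it
  // is read while still uninitialised -- classic C++ self-initialisation UB.
  auto broken = [&] {
    hipStream_t stream = stream;  // reads an indeterminate value
    (void)stream;
  };

  // After: acquire the stream once, outside the lambda, so the capture
  // refers to a well-defined value.
  hipStream_t stream = at::hip::getCurrentHIPStream();
  auto fixed = [&] {
    // kernel<<<grid, block, /*smem=*/0, stream>>>(...);  // launch elided
    (void)stream;
  };
  (void)broken;
  (void)fixed;
}
```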
…n norm tests

- Add torch.bfloat16 to all four norm test parametrizations (rmsnorm, fused_add_rmsnorm, gemma_rmsnorm, gemma_fused_add_rmsnorm); previously only fp16 was exercised, leaving the bf16 dispatch path untested.
- Fix the gemma_fused_add_rms_norm reference: it was computing x + residual in bf16 (up to 15% error); it now promotes to float32 before adding, matching the kernel's accumulation order.
- Use dtype-aware tolerances: 1.6e-2 for bfloat16, 1e-3 for float16.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
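A small standalone C++ illustration of why the bf16-domain addition diverges. Round-to-nearest-even is approximated by truncation for brevity, and the values are illustrative, not taken from the tests:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Crude bf16 rounding: keep the top 16 bits of the fp32 representation
// (sign, 8 exponent bits, 7 mantissa bits). Real bf16 rounds to nearest
// even, but truncation is close enough to demonstrate the effect.
static float round_to_bf16(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof bits);
  bits &= 0xFFFF0000u;
  std::memcpy(&x, &bits, sizeof bits);
  return x;
}

int main() {
  float x = round_to_bf16(1.0f);
  float residual = round_to_bf16(0.0039f);  // below bf16's ulp at 1.0 (2^-7)

  float bf16_domain = round_to_bf16(x + residual);  // old reference behaviour
  float fp32_domain = x + residual;                 // kernel: promote, then add

  // The bf16-domain sum collapses back to 1.0, losing the residual entirely;
  // the fp32-domain sum keeps it.
  std::printf("bf16-domain add: %.6f\n", bf16_domain);
  std::printf("fp32-domain add: %.6f\n", fp32_domain);
}
```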
… kernels

First test coverage for silu_and_mul, gelu_tanh_and_mul, and gelu_and_mul on ROCm. Tests fp16 and bfloat16 across 10 shapes that cover partial wavefronts, wave64-unaligned sizes (5504), typical LLM FFN dims, and the 1024-thread blockDim cap. References use fp32 accumulation to match the kernel's internal cast behaviour.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
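A sketch of the per-element contract the new references assume for silu_and_mul (promote to fp32, compute, cast back); the function name and template here are illustrative, not the kernel source:

```cpp
#include <cmath>

// Per-element semantics: the input row is laid out as [gate | up]; each
// element promotes its half-precision operands to fp32, applies SiLU to
// the gate, multiplies by up, and casts back to the input dtype. T would
// be __half or __hip_bfloat16 in the real kernel.
template <typename T>
T silu_and_mul_one(T gate, T up) {
  float g = static_cast<float>(gate);
  float u = static_cast<float>(up);
  float y = g / (1.0f + std::exp(-g)) * u;  // SiLU(g) * u, fp32 throughout
  return static_cast<T>(y);
}
```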
…d imports

from rope_reference import * relied on tests/test_helpers being on sys.path with no package qualification. Switch to the canonical module path tests.test_helpers.rope_reference and import only the three symbols used (RotaryEmbedding, apply_rotary_emb, precompute_freqs_cis). Also parenthesise the inline ternary inside the kv_last_page_len_e list literal for ruff/black-compatible formatting.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Pull request overview
This PR addresses ROCm/HIP correctness issues on CDNA3 and expands test coverage for small-kernel paths (activation, norm, RoPE), with a focus on bf16 and wave64 behavior.
Changes:
- Fixes HIP kernel launch undefined behavior in fused activation dispatch by acquiring the HIP stream outside the dispatch lambda.
- Updates the HIP shuffle XOR wrapper to use wave64 width for CDNA3 and replaces a C++20 std::bit_cast usage with a C++17-compatible memcpy-based approach.
- Improves ROCm test correctness and coverage: fixes the Gemma fused-add RMSNorm reference to match fp32 promotion, adds bf16 coverage across norm tests, adds new fused-activation tests, and replaces a RoPE wildcard import with explicit imports.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| tests/rocm_tests/test_rope_hip.py | Replaces wildcard RoPE reference import with explicit imports; minor formatting tweak in a tensor construction. |
| tests/rocm_tests/test_norm_hip.py | Fixes Gemma fused-add RMSNorm reference precision and adds bf16 coverage/tolerances across norm tests. |
| tests/rocm_tests/test_activation_hip.py | Adds new ROCm tests for fused activation kernels across fp16/bf16 and multiple shapes. |
| include/gpu_iface/backend/hip/mma_hip.h | Removes C++20 std::bit_cast dependency in bf16 rowsum path via memcpy-based bit-cast. |
| include/gpu_iface/backend/hip/math_hip.h | Adjusts HIP shfl_xor_sync wrapper to use width=64 for CDNA3 wavefronts. |
| flashinfer/csrc_rocm/activation.cu | Fixes stream self-initialization UB in dispatch lambdas; simplifies kernel launches by removing dead locals. |
Comment on lines 223 to 227
```diff
   constexpr uint32_t bf16_one_pair = 0x3F803F80u;  // two bf16 1.0 values packed
   constexpr uint64_t bf16_ones = (uint64_t{bf16_one_pair} << 32) | bf16_one_pair;
-  f16x4 b = std::bit_cast<f16x4>(bf16_ones);
+  f16x4 b;
+  __builtin_memcpy(&b, &bf16_ones, sizeof(f16x4));
   out = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a, b, c, 0, 0, 0);
```
Comment on lines +38 to +51

```python
    gate, up = x.float()[..., :d], x.float()[..., d:]
    return (gate / (1.0 + torch.exp(-gate)) * up).to(x.dtype)


def _gelu_tanh_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    gate, up = x.float()[..., :d], x.float()[..., d:]
    y = torch.nn.functional.gelu(gate, approximate="tanh") * up
    return y.to(x.dtype)


def _gelu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    gate, up = x.float()[..., :d], x.float()[..., d:]
```
…ation_hip.py

mma_hip.h:
- Remove unused #include <bit> (only needed for std::bit_cast, which was replaced with __builtin_memcpy in a prior commit).
- Add static_assert(sizeof(f16x4) == sizeof(bf16_ones)) before the memcpy type-pun so a layout mismatch is caught at compile time rather than silently producing a partial copy.

test_activation_hip.py:
- Each reference function called x.float() twice, allocating two fp32 copies of the input. Convert once to x_f32 and slice gate/up from it.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
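A sketch of the guarded type-pun this commit describes; the f16x4 typedef here is an assumption standing in for the header's real vector type:

```cpp
#include <cstdint>

// Stand-in for the 4-lane 16-bit vector type fed to the MFMA intrinsic.
typedef uint16_t f16x4 __attribute__((ext_vector_type(4)));

inline f16x4 bf16_ones_vector() {
  constexpr uint32_t bf16_one_pair = 0x3F803F80u;  // two packed bf16 1.0s
  constexpr uint64_t bf16_ones = (uint64_t{bf16_one_pair} << 32) | bf16_one_pair;
  // Turns a layout mismatch into a compile error instead of a short memcpy.
  static_assert(sizeof(f16x4) == sizeof(bf16_ones), "bit-cast size mismatch");
  f16x4 b;
  __builtin_memcpy(&b, &bf16_ones, sizeof(f16x4));  // C++17-safe bit-cast
  return b;
}
```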
Summary
- activation.cu: hipStream_t stream = stream inside each DISPATCH lambda was a C++ self-initialisation (the RHS refers to the uninitialized variable, not the outer getCurrentHIPStream() return). Moved stream acquisition outside the lambda. Also removes dead dynamicSmemBytes locals and an unused auto kernel = binding in gelu_and_mul.
- shfl_xor_sync width for CDNA3 (math_hip.h): changed __shfl_xor(x, lane_mask, 32) → width=64, resolving the FIXME left in place. All existing callers use lane_mask ≤ 16, so the exchange pattern is identical; the change is semantically correct for 64-lane wavefronts and future-proofs callers.
- mma_hip.h: std::bit_cast (C++20) in the bf16 rowsum path caused JIT recompile failures under -std=c++17. Replaced with __builtin_memcpy.
- gemma_fused_add_rms_norm reference (test_norm_hip.py): the reference was accumulating x + residual in bf16 (~15% error vs the kernel's fp32 promotion). Fixed to match kernel behaviour.
- bf16 coverage: rmsnorm, fused_add_rmsnorm, gemma_rmsnorm, and gemma_fused_add_rmsnorm were fp16-only; the bf16 dispatch path was completely untested.
- test_activation_hip.py: first test coverage for silu_and_mul, gelu_tanh_and_mul, and gelu_and_mul on ROCm: 60 tests across fp16/bf16 and 10 shapes covering partial wavefronts, wave64-unaligned sizes, and the 1024-thread cap.
- Replaced from rope_reference import * with explicit named imports from tests.test_helpers.rope_reference.

Test plan
- pytest tests/rocm_tests/test_norm_hip.py -m "not slow": 672 passed (fp16 + bf16)
- pytest tests/rocm_tests/test_activation_hip.py -m "not slow": 60 passed (fp16 + bf16)
- pytest tests/rocm_tests/test_rope_hip.py -m "not slow": 7850 passed, 5 skipped
- benchmarks/rocm_benchmarks/bench_fa2_prefill.py --timing-only: no regression (seq=4096 causal ~72.9 TFLOPS)

🤖 Generated with Claude Code