
fix(hip): small-kernel correctness and test coverage for CDNA3 (activation, norm, rope) #220

Open
demandal25 wants to merge 6 commits into ROCm:amd-integration from demandal25:port-small-kernels-to-hip

Conversation

demandal25 (Collaborator) commented May 5, 2026

Summary

  • Fix UB in activation.cu: hipStream_t stream = stream inside each DISPATCH lambda was a C++ self-initialisation (the RHS refers to the freshly declared, uninitialized variable, not the outer getCurrentHIPStream() return value). Moved stream acquisition outside the lambda. Also removed dead dynamicSmemBytes locals and an unused 'auto kernel = act_and_mul_kernel<c_type, gelu>' binding in gelu_and_mul.
  • Fix shfl_xor_sync width for CDNA3 (math_hip.h): changed the __shfl_xor(x, lane_mask, width) call from width=32 to width=64, resolving the FIXME left in place. All existing callers use lane_mask ≤ 16, so the exchange pattern is identical; the change is semantically correct for 64-lane wavefronts and future-proofs callers.
  • Restore C++17 compatibility in mma_hip.h: std::bit_cast (C++20) in the bf16 rowsum path caused JIT recompile failures under -std=c++17. Replaced with __builtin_memcpy.
  • Fix gemma_fused_add_rms_norm reference (test_norm_hip.py): reference was accumulating x + residual in bf16 (~15% error vs kernel's fp32 promotion). Fixed to match kernel behaviour.
  • Add bfloat16 coverage to all norm tests: rmsnorm, fused_add_rmsnorm, gemma_rmsnorm, gemma_fused_add_rmsnorm were fp16-only; bf16 dispatch path was completely untested.
  • Add test_activation_hip.py: first test coverage for silu_and_mul, gelu_tanh_and_mul, gelu_and_mul on ROCm — 60 tests across fp16/bf16 and 10 shapes covering partial wavefronts, wave64-unaligned sizes, and the 1024-thread cap.
  • Fix rope test wildcard import: replace from rope_reference import * with explicit named imports from tests.test_helpers.rope_reference.

Test plan

  • pytest tests/rocm_tests/test_norm_hip.py -m "not slow" — 672 passed (fp16 + bf16)
  • pytest tests/rocm_tests/test_activation_hip.py -m "not slow" — 60 passed (fp16 + bf16)
  • pytest tests/rocm_tests/test_rope_hip.py -m "not slow" — 7850 passed, 5 skipped
  • benchmarks/rocm_benchmarks/bench_fa2_prefill.py --timing-only — no regression (seq=4096 causal ~72.9 TFLOPS)

🤖 Generated with Claude Code

demandal25 and others added 5 commits May 5, 2026 01:43
…n mma_hip.h

- math_hip.h: change shfl_xor_sync width 32→64; all callers use lane_mask ≤ 16
  so the exchange pattern is identical, but 64 is semantically correct for CDNA3
  wavefronts and future-proofs callers with larger lane_mask values. Resolves
  the FIXME left by diptorupd.
- mma_hip.h: replace std::bit_cast (C++20) with __builtin_memcpy for the bf16
  ones vector initialisation; fixes JIT recompile failure under -std=c++17.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
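For readers unfamiliar with wave64, the width change described in this commit amounts to the following. This is a minimal sketch, assuming a CUDA-style wrapper of roughly this shape in math_hip.h; the real name and signature may differ.

    #include <hip/hip_runtime.h>

    // Sketch of a shfl_xor_sync-style wrapper for CDNA3 (illustrative only).
    template <typename T>
    __device__ __forceinline__ T shfl_xor_sync(T x, int lane_mask) {
      // CDNA3 wavefronts are 64 lanes wide. width=32 would split the wavefront
      // into two independent 32-lane halves; width=64 exchanges across the full
      // wavefront. For lane_mask <= 16 the pairing is identical either way, but
      // 64 is the semantically correct width and also covers larger lane_mask
      // values, e.g. a butterfly reduction step using lane_mask = 32.
      return __shfl_xor(x, lane_mask, /*width=*/64);
    }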
hipStream_t stream = stream inside the DISPATCH lambda was a C++ self-
initialisation: the RHS 'stream' referred to the newly declared variable
(uninitialised), not the outer at::hip::getCurrentHIPStream() return value.
Move stream acquisition outside the lambda so the capture is well-defined.

Also removes: dead dynamicSmemBytes local (inlined as 0), dead
'auto kernel = act_and_mul_kernel<c_type, gelu>' (assigned, never used).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
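A minimal, self-contained sketch of the bug pattern described in this commit (not the actual activation.cu code; the dispatch helper and kernel names below are stand-ins):

    #include <hip/hip_runtime.h>

    // Stand-in for the per-dtype DISPATCH macro; only the lambda structure matters.
    template <typename F>
    void dispatch(F&& f) { f(); }

    __global__ void act_kernel(float* out, const float* in, int d) {}

    void launch_before(float* out, const float* in, int d, dim3 grid, dim3 block) {
      dispatch([&] {
        // UB: the RHS 'stream' names the variable being declared, which is still
        // uninitialised, so the kernel launches on an indeterminate stream.
        hipStream_t stream = stream;
        act_kernel<<<grid, block, 0, stream>>>(out, in, d);
      });
    }

    void launch_after(float* out, const float* in, int d, dim3 grid, dim3 block,
                      hipStream_t stream /* e.g. at::hip::getCurrentHIPStream() */) {
      dispatch([&] {
        // 'stream' now refers to the captured outer value, which is well defined.
        act_kernel<<<grid, block, 0, stream>>>(out, in, d);
      });
    }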
…n norm tests

- Add torch.bfloat16 to all four norm test parametrizations (rmsnorm,
  fused_add_rmsnorm, gemma_rmsnorm, gemma_fused_add_rmsnorm); previously
  only fp16 was exercised, leaving the bf16 dispatch path untested.
- Fix gemma_fused_add_rms_norm reference: was computing x + residual in
  bf16 (up to 15% error), now promotes to float32 before adding to match
  the kernel's accumulation order.
- Use dtype-aware tolerances: 1.6e-2 for bfloat16, 1e-3 for float16.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
… kernels

First test coverage for silu_and_mul, gelu_tanh_and_mul, and gelu_and_mul
on ROCm. Tests fp16 and bfloat16 across 10 shapes that cover partial
wavefronts, wave64-unaligned sizes (5504), typical LLM FFN dims, and the
1024-thread blockDim cap. References use fp32 accumulation to match the
kernel's internal cast behaviour.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…d imports

from rope_reference import * relied on tests/test_helpers being on sys.path
with no package qualification. Switch to the canonical module path
tests.test_helpers.rope_reference and import only the three symbols used
(RotaryEmbedding, apply_rotary_emb, precompute_freqs_cis).

Also parenthesise the inline ternary inside the kv_last_page_len_e list
literal for ruff/black-compatible formatting.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
demandal25 marked this pull request as ready for review May 7, 2026 19:26
Copilot AI review requested due to automatic review settings May 7, 2026 19:26

Copilot AI left a comment


Pull request overview

This PR addresses ROCm/HIP correctness issues on CDNA3 and expands test coverage for small-kernel paths (activation, norm, RoPE), with a focus on bf16 and wave64 behavior.

Changes:

  • Fixes HIP kernel launch undefined behavior in fused activation dispatch by acquiring the HIP stream outside the dispatch lambda.
  • Updates HIP shuffle XOR wrapper to use wave64 width for CDNA3 and replaces a C++20 std::bit_cast usage with a C++17-compatible memcpy-based approach.
  • Improves ROCm test correctness and coverage: fixes Gemma fused-add RMSNorm reference to match fp32 promotion, adds bf16 coverage across norm tests, adds new fused-activation tests, and replaces a RoPE wildcard import with explicit imports.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 4 comments.

Summary per file:

  • tests/rocm_tests/test_rope_hip.py: Replaces wildcard RoPE reference import with explicit imports; minor formatting tweak in a tensor construction.
  • tests/rocm_tests/test_norm_hip.py: Fixes Gemma fused-add RMSNorm reference precision and adds bf16 coverage/tolerances across norm tests.
  • tests/rocm_tests/test_activation_hip.py: Adds new ROCm tests for fused activation kernels across fp16/bf16 and multiple shapes.
  • include/gpu_iface/backend/hip/mma_hip.h: Removes the C++20 std::bit_cast dependency in the bf16 rowsum path via a memcpy-based bit-cast.
  • include/gpu_iface/backend/hip/math_hip.h: Adjusts the HIP shfl_xor_sync wrapper to use width=64 for CDNA3 wavefronts.
  • flashinfer/csrc_rocm/activation.cu: Fixes stream self-initialization UB in dispatch lambdas; simplifies kernel launches by removing dead locals.


Comment on lines 223 to 227 of include/gpu_iface/backend/hip/mma_hip.h:

      constexpr uint32_t bf16_one_pair = 0x3F803F80u;  // two bf16 1.0 values packed
      constexpr uint64_t bf16_ones = (uint64_t{bf16_one_pair} << 32) | bf16_one_pair;
    - f16x4 b = std::bit_cast<f16x4>(bf16_ones);
    + f16x4 b;
    + __builtin_memcpy(&b, &bf16_ones, sizeof(f16x4));
      out = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a, b, c, 0, 0, 0);
Outdated Copilot comment thread on tests/rocm_tests/test_activation_hip.py, lines +38 to +51:
gate, up = x.float()[..., :d], x.float()[..., d:]
return (gate / (1.0 + torch.exp(-gate)) * up).to(x.dtype)


def _gelu_tanh_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
gate, up = x.float()[..., :d], x.float()[..., d:]
y = torch.nn.functional.gelu(gate, approximate="tanh") * up
return y.to(x.dtype)


def _gelu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
gate, up = x.float()[..., :d], x.float()[..., d:]
(Two further outdated Copilot comment threads are anchored on the same lines, +38 to +51, of tests/rocm_tests/test_activation_hip.py.)
…ation_hip.py

mma_hip.h:
- Remove unused #include <bit> (was only needed for std::bit_cast which was
  replaced with __builtin_memcpy in a prior commit).
- Add static_assert(sizeof(f16x4) == sizeof(bf16_ones)) before the memcpy
  type-pun so a layout mismatch is caught at compile time rather than
  silently producing a partial copy.

test_activation_hip.py:
- Each reference function called x.float() twice, allocating two fp32
  copies of the input. Convert once to x_f32 and slice gate/up from it.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
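The static_assert guard mentioned in this commit, sketched in isolation; f16x4_t and make_bf16_ones are hypothetical stand-ins for whatever mma_hip.h actually names these, and the constants are the ones shown in the review excerpt earlier:

    #include <cstdint>
    #include <cstring>

    // Stand-in for the 8-byte packed-half vector type fed to the MFMA intrinsic.
    typedef short f16x4_t __attribute__((vector_size(8)));

    f16x4_t make_bf16_ones() {
      constexpr uint32_t bf16_one_pair = 0x3F803F80u;  // two packed bf16 1.0 values
      constexpr uint64_t bf16_ones = (uint64_t{bf16_one_pair} << 32) | bf16_one_pair;
      // Guard the type-pun: a size mismatch becomes a compile error rather than a
      // silent partial copy.
      static_assert(sizeof(f16x4_t) == sizeof(bf16_ones), "f16x4 must be 8 bytes");
      f16x4_t b;
      __builtin_memcpy(&b, &bf16_ones, sizeof(f16x4_t));
      return b;
    }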