Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
#include <hip/hip_runtime.h>
#include <hip/hip_bf16.h>
#include <cstdint>
#include "../util/cuda_runtime.h" //cuda::sm_arch

namespace te_mxfp4 {

Expand Down Expand Up @@ -99,50 +98,6 @@ __device__ __forceinline__ void bf16x4_to_float4(
v3 = uint_as_float(((uint32_t)((packed >> 48) & 0xFFFF)) << 16);
}

// ============================================================================
// WARP PRIMITIVES - AMD-Specific DPP/Swizzle Instructions
// ============================================================================

/*
* ds_swizzle Instructions
* -----------------------
* These perform intra-wavefront data exchange without shared memory.
* The offset parameter encodes the permutation pattern.
*
* Format: offset = (AND_mask << 10) | (OR_mask << 5) | XOR_mask
*
* Common patterns:
* - 0x041F: XOR with lane 1 (exchange with adjacent thread)
* - 0x081F: XOR with lane 2 (exchange 2 positions away)
* - 0x101F: XOR with lane 4 (exchange 4 positions away)
*
* Reference: AMD CDNA4 ISA, ds_swizzle_b32 (page 480)
*/

// Exchange `val` with the lane whose id differs in bit 0 (XOR distance 1)
// using ds_swizzle_b32 — an intra-wavefront permute that needs no LDS
// allocation. The offset encodes (AND_mask << 10) | (OR_mask << 5) | XOR_mask,
// so 0x041F = AND mask 0x1F (keep all 5 lane bits) with XOR mask 1.
// Fix: the original returned `result` uninitialized (undefined behavior)
// when __gfx1250__ is defined and the asm path is compiled out; fall back
// to the identity value on that path instead.
__device__ __forceinline__ float ds_swizzle_xor1(float val) {
float result = val;  // well-defined fallback when the asm below is compiled out
#ifndef __gfx1250__ //instruction not supported on this GPU
asm volatile(
"ds_swizzle_b32 %0, %1 offset:0x041F\n\t"
"s_waitcnt lgkmcnt(0)"  // ds_swizzle uses the LDS counter; wait for completion
: "=v"(result) : "v"(val)
);
#endif
return result;
}

// Exchange `val` with the lane whose id differs in bit 1 (XOR distance 2)
// via ds_swizzle_b32; offset 0x081F = AND mask 0x1F with XOR mask 2
// (encoding: (AND_mask << 10) | (OR_mask << 5) | XOR_mask).
// Fix: the original returned `result` uninitialized (undefined behavior)
// when __gfx1250__ is defined and the asm path is compiled out; fall back
// to the identity value on that path instead.
__device__ __forceinline__ float ds_swizzle_xor2(float val) {
float result = val;  // well-defined fallback when the asm below is compiled out
#ifndef __gfx1250__ //instruction not supported on this GPU
asm volatile(
"ds_swizzle_b32 %0, %1 offset:0x081F\n\t"
"s_waitcnt lgkmcnt(0)"  // ds_swizzle uses the LDS counter; wait for completion
: "=v"(result) : "v"(val)
);
#endif
return result;
}

// ============================================================================
// REDUCTION OPERATIONS - Finding Maximum Absolute Value
// ============================================================================
Expand All @@ -159,27 +114,14 @@ __device__ __forceinline__ float ds_swizzle_xor2(float val) {
* Step 3: XOR 1 - reduce 2 values to 1 (thread 0)
*/
// Butterfly max-reduction over groups of 8 lanes: partners at XOR distance
// 4, 2, then 1 exchange values and take fmaxf, so after three steps each
// lane in the group holds the group's maximum absolute value candidate.
//
// NOTE(review): this span appears to be a merged diff view — both the
// ds_swizzle_b32 path (guarded by #ifndef __gfx1250__) and the portable
// __shfl_xor path are present and perform the same exchanges. fmaxf is
// idempotent, so repeating a step does not change the result, but the
// duplication is almost certainly diff residue; confirm against the
// intended post-merge version of the file.
__device__ __forceinline__ float warp_reduce_max_8_dpp(float val) {
#ifndef __gfx1250__ //instruction not supported on this GPU
uint32_t v = float_as_uint(val);  // asm operand: integer bit-view of val
uint32_t tmp;                     // partner lane's value, raw bits

// Step 1: Exchange with thread 4 positions away
asm volatile("ds_swizzle_b32 %0, %1 offset:0x101F" : "=v"(tmp) : "v"(v));
asm volatile("s_waitcnt lgkmcnt(0)" :::);
val = fmaxf(val, uint_as_float(tmp));
v = float_as_uint(val);
val = fmaxf(val, __shfl_xor(val, 4));

// Step 2: Exchange with thread 2 positions away
asm volatile("ds_swizzle_b32 %0, %1 offset:0x081F" : "=v"(tmp) : "v"(v));
asm volatile("s_waitcnt lgkmcnt(0)" :::);
val = fmaxf(val, uint_as_float(tmp));
v = float_as_uint(val);
val = fmaxf(val, __shfl_xor(val, 2));

// Step 3: Exchange with adjacent thread
asm volatile("ds_swizzle_b32 %0, %1 offset:0x041F" : "=v"(tmp) : "v"(v));
asm volatile("s_waitcnt lgkmcnt(0)" :::);
val = fmaxf(val, uint_as_float(tmp));
#endif
val = fmaxf(val, __shfl_xor(val, 1));

return val;
}
Expand Down Expand Up @@ -218,10 +160,10 @@ __device__ __forceinline__ void hadamard16_inplace(
v3 = a1 - a3;

// Stage 2: Cross-thread exchange (XOR 1) - combine pairs
float p0 = ds_swizzle_xor1(v0);
float p1 = ds_swizzle_xor1(v1);
float p2 = ds_swizzle_xor1(v2);
float p3 = ds_swizzle_xor1(v3);
float p0 = __shfl_xor(v0, 1);
float p1 = __shfl_xor(v1, 1);
float p2 = __shfl_xor(v2, 1);
float p3 = __shfl_xor(v3, 1);

bool sign2 = (tid & 1);
v0 = sign2 ? (p0 - v0) : (p0 + v0);
Expand All @@ -230,10 +172,10 @@ __device__ __forceinline__ void hadamard16_inplace(
v3 = sign2 ? (p3 - v3) : (p3 + v3);

// Stage 3: Cross-thread exchange (XOR 2) - final combination
p0 = ds_swizzle_xor2(v0);
p1 = ds_swizzle_xor2(v1);
p2 = ds_swizzle_xor2(v2);
p3 = ds_swizzle_xor2(v3);
p0 = __shfl_xor(v0, 2);
p1 = __shfl_xor(v1, 2);
p2 = __shfl_xor(v2, 2);
p3 = __shfl_xor(v3, 2);

bool sign3 = (tid >> 1) & 1;
float t0 = sign3 ? (p0 - v0) : (p0 + v0);
Expand Down Expand Up @@ -738,10 +680,6 @@ inline void nvte_cast_transpose_mxfp4_fused_shuffle(
int colwise_scale_M_pad, int colwise_scale_N_pad,
hipStream_t stream
) {
//TODO: remove when enable HW code
if (transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) == 125) {
NVTE_ERROR("Hadamard transform is not yet supported on this GPU");
}
dim3 grid((M + te_mxfp4::BLOCK_M - 1) / te_mxfp4::BLOCK_M,
(N + te_mxfp4::BLOCK_N - 1) / te_mxfp4::BLOCK_N);
dim3 block(te_mxfp4::THREADS_PER_BLOCK);
Expand Down
121 changes: 1 addition & 120 deletions transformer_engine/common/hadamard_transform/hadamard_transform.cu
Original file line number Diff line number Diff line change
Expand Up @@ -498,23 +498,7 @@ __global__ void HadamardTransformKernel(const T* __restrict__ input, T* __restri

#ifdef __HIP_PLATFORM_AMD__

// Tiling / layout constants
//
// A 16-point WHT operates on tiles of kHadamardDim (16) elements.
// Each tile is processed by kThreadsPerWHT (4) threads, each holding
// kElemsPerThread (4) values, so one wavefront of kWarpSize (64) lanes
// handles kRowsPerWarp (16) independent tiles (= rows) simultaneously.
// kWarpsPerBlock wavefronts are combined into a thread-block that covers
// kRowsPerBlock (64) consecutive rows.
static constexpr int kHadamardDim = 16; // WHT dimension (H16)
static constexpr int kWarpSize = 64; // Wavefront width
static constexpr int kThreadsPerWHT = 4; // threads per 16-pt WHT
static constexpr int kElemsPerThread = 4; // elements each thread owns
static constexpr int kRowsPerWarp = kWarpSize / kThreadsPerWHT; // 16
static constexpr int kWarpsPerBlock = 4;
static constexpr int kRowsPerBlock = kRowsPerWarp * kWarpsPerBlock; // 64
static constexpr int kThreadsPerBlock = kWarpSize * kWarpsPerBlock; // 256
static constexpr float kHadamardScale = 0.25f; // 1/sqrt(16)
#include "wht16.cuh"

// Reduce per-warp amax values in warp 0 and atomically update a global amax.
__device__ __forceinline__ void reduce_block_amax(
Expand All @@ -527,26 +511,6 @@ __device__ __forceinline__ void reduce_block_amax(
atomicMaxFloat(global_amax, val);
}

// ds_swizzle: sub-wavefront exchange without LDS.
// Same instructions as cast_transpose_mxfp4_kernel_shuffled.cu.
// Exchange `v` with the wavefront lane at XOR distance 1 (offset 0x041F:
// AND/OR/XOR masks packed as (AND << 10) | (OR << 5) | XOR).
// Fixes: (1) stray GitHub review-UI text had been pasted into the function
// body, which does not compile; (2) `r` was returned uninitialized
// (undefined behavior) when __gfx1250__ is defined — fall back to the
// identity value on that path instead.
__device__ __forceinline__ float ds_swizzle_xor1(float v) {
float r = v;  // well-defined fallback when the asm path is compiled out
#ifndef __gfx1250__ //instruction not supported on this GPU
asm volatile("ds_swizzle_b32 %0, %1 offset:0x041F\n\t"
             "s_waitcnt lgkmcnt(0)" : "=v"(r) : "v"(v));
#endif
return r;
}

// Exchange `v` with the wavefront lane at XOR distance 2 (offset 0x081F:
// AND/OR/XOR masks packed as (AND << 10) | (OR << 5) | XOR).
// Fix: `r` was returned uninitialized (undefined behavior) when
// __gfx1250__ is defined and the asm path is compiled out; fall back to
// the identity value on that path instead.
__device__ __forceinline__ float ds_swizzle_xor2(float v) {
float r = v;  // well-defined fallback when the asm path is compiled out
#ifndef __gfx1250__ //instruction not supported on this GPU
asm volatile("ds_swizzle_b32 %0, %1 offset:0x081F\n\t"
             "s_waitcnt lgkmcnt(0)" : "=v"(r) : "v"(v));
#endif
return r;
}

// BF16 helpers
// Scalar conversion helpers between float and bfloat16 storage.
// to_f32 widens a __hip_bfloat16 to float; to_bf16 narrows a float to
// __hip_bfloat16 (rounding per the type's conversion operator).
__device__ __forceinline__ float to_f32(__hip_bfloat16 value) {
    return static_cast<float>(value);
}
__device__ __forceinline__ __hip_bfloat16 to_bf16(float value) {
    return static_cast<__hip_bfloat16>(value);
}
Expand All @@ -573,89 +537,6 @@ __device__ __forceinline__ uint64_t pack_bf16x4(float v0, float v1, float v2, fl
| ((uint64_t)bf16_to_bits(to_bf16(v3)) << 48);
}

// -----------------------------------------------------------------------
// 16-point WHT via the Kronecker trick (no shared memory)
// -----------------------------------------------------------------------
//
// 1. The vec operator
// vec() flattens a matrix into a column vector by stacking its
// columns one on top of the other:
//
// X = |a c| vec(X) = |a|
// |b d| |b|
// |c|
// |d|
//
// 2. The "Kronecker trick" for 1D -> 2D
// The fundamental identity that connects these concepts is:
//
// vec(B . X . A^T) = (A (x) B) . vec(X)
//
// For a 16-point Hadamard transform (H16 = H4 (x) H4),
// set A = H4 and B = H4. The formula becomes:
//
// H16 . x = vec(H4 . X . H4^T)
//
// 3. Data layout (column-major, one column per thread)
// Reshape the 16-element 1D vector x into a 4x4 matrix X
// by filling columns first:
//
// X = | x0 x4 x8 x12 | thread 0 holds col 0: v0..v3 = x0 ..x3
// | x1 x5 x9 x13 | thread 1 holds col 1: v0..v3 = x4 ..x7
// | x2 x6 x10 x14 | thread 2 holds col 2: v0..v3 = x8 ..x11
// | x3 x7 x11 x15 | thread 3 holds col 3: v0..v3 = x12..x15
//
// 4. Three-stage computation
// Stage 1 (local H4) : left-multiply H4 . X (within each thread)
// Stage 2 (xor-1 swap) : \ (across 4 threads)
// Stage 3 (xor-2 swap) : / right-multiply . H4^T together these two butterfly stages = H4^T
//
// Result: vec(H4 . X . H4^T) = H16 . x
//
// 5. Randomised Hadamard Transform (RHT)
// A diagonal sign matrix D (from sign_mask) is applied either
// before the WHT (apply_pre=true, forward) or after (inverse).
//
// Adapted from cast_transpose_mxfp4_kernel_shuffled.cu::hadamard16_inplace,
// extended with NV random_sign_mask (uint16_t bitmask).
// thread_in_group [0,3]: drives ds_swizzle polarity (identical to MLPerf tid & 3).
// apply_pre=true -> D before WHT (forward); false -> D after WHT (inverse).
// 16-point Walsh–Hadamard transform over 4 cooperating lanes, each holding
// 4 values (v0..v3) — one column of the 4x4 reshaped input (see the layout
// comment above). Computes vec(H4 . X . H4^T) = H16 . x in place, scaled by
// kHadamardScale, with an optional random diagonal sign matrix D applied
// before (apply_pre=true, forward RHT) or after (inverse) the transform.
//
// thread_in_group: lane index within the 4-thread group, range [0,3];
//                  selects this lane's 4 sign bits and the butterfly polarity.
// sign_mask:       16-bit bitmask; bit set => negate the corresponding element.
// apply_pre:       true = signs before WHT (forward), false = after (inverse).
__device__ __forceinline__ void wht16(
float& v0, float& v1, float& v2, float& v3,
int thread_in_group, uint16_t sign_mask, bool apply_pre) {
// Sign for this lane's k-th element: bit (thread_in_group*4 + k) of sign_mask.
auto sgn = [&](int k) -> float {
return ((sign_mask >> (thread_in_group * kElemsPerThread + k)) & 1u) ? -1.f : 1.f;
};

if (apply_pre) {
v0*=sgn(0); v1*=sgn(1); v2*=sgn(2); v3*=sgn(3);
}

// Stage 1: local H4
// Two levels of butterflies entirely within this thread's 4 registers:
// left-multiplies the 4x4 column by H4.
float a0=v0+v1, a1=v0-v1, a2=v2+v3, a3=v2-v3;
v0=a0+a2; v2=a0-a2; v1=a1+a3; v3=a1-a3;

// Stage 2: cross-thread XOR-1
// Exchange with the neighboring lane; odd lanes take (partner - self),
// even lanes (partner + self) — first butterfly level of the H4^T factor.
{ float p0=ds_swizzle_xor1(v0), p1=ds_swizzle_xor1(v1),
p2=ds_swizzle_xor1(v2), p3=ds_swizzle_xor1(v3);
bool up=(thread_in_group&1);
v0=up?(p0-v0):(p0+v0); v1=up?(p1-v1):(p1+v1);
v2=up?(p2-v2):(p2+v2); v3=up?(p3-v3):(p3+v3); }

// Stage 3: cross-thread XOR-2
// Exchange at lane distance 2; polarity from bit 1 of the lane index —
// second butterfly level, completing the right-multiplication by H4^T.
{ float p0=ds_swizzle_xor2(v0), p1=ds_swizzle_xor2(v1),
p2=ds_swizzle_xor2(v2), p3=ds_swizzle_xor2(v3);
bool up=(thread_in_group>>1)&1;
v0=up?(p0-v0):(p0+v0); v1=up?(p1-v1):(p1+v1);
v2=up?(p2-v2):(p2+v2); v3=up?(p3-v3):(p3+v3); }

// Normalize: kHadamardScale = 1/sqrt(16) keeps the transform orthonormal.
v0*=kHadamardScale; v1*=kHadamardScale; v2*=kHadamardScale; v3*=kHadamardScale;

if (!apply_pre) {
v0*=sgn(0); v1*=sgn(1); v2*=sgn(2); v3*=sgn(3);
}
}

// Grid: blockIdx.x = col tile [0, row_length/16)
// blockIdx.y = row batch [0, ceil(num_rows/64))
// Block: 256 threads = 4 wavefronts of 64 lanes.
Expand Down
23 changes: 4 additions & 19 deletions transformer_engine/common/hadamard_transform/wht16.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,6 @@ static constexpr int kRowsPerBlock = kRowsPerWarp * kWarpsPerBlock; // 64
static constexpr int kThreadsPerBlock = kWarpSize * kWarpsPerBlock; // 256
static constexpr float kHadamardScale = 0.25f;

// ds_swizzle: sub-wavefront exchange without LDS.
// Exchange `v` with the wavefront lane at XOR distance 1.
// Fix: the original emitted ds_swizzle_b32 unconditionally, but that
// instruction is not supported on gfx1250 (see the guards in the sibling
// kernels); __shfl_xor performs the same XOR-1 lane exchange portably and
// is what the rest of this file already uses for its cross-lane stages.
__device__ __forceinline__ float ds_swizzle_xor1(float v) {
    return __shfl_xor(v, 1);
}

// Exchange `v` with the wavefront lane at XOR distance 2.
// Fix: the original emitted ds_swizzle_b32 unconditionally, but that
// instruction is not supported on gfx1250 (see the guards in the sibling
// kernels); __shfl_xor performs the same XOR-2 lane exchange portably and
// is what the rest of this file already uses for its cross-lane stages.
__device__ __forceinline__ float ds_swizzle_xor2(float v) {
    return __shfl_xor(v, 2);
}

// -----------------------------------------------------------------------
// 16-point WHT via the Kronecker trick (no shared memory)
// -----------------------------------------------------------------------
Expand Down Expand Up @@ -101,15 +86,15 @@ __device__ __forceinline__ void wht16(
v0=a0+a2; v2=a0-a2; v1=a1+a3; v3=a1-a3;

// Stage 2: cross-thread XOR-1
{ float p0=ds_swizzle_xor1(v0), p1=ds_swizzle_xor1(v1),
p2=ds_swizzle_xor1(v2), p3=ds_swizzle_xor1(v3);
{ float p0=__shfl_xor(v0, 1), p1=__shfl_xor(v1, 1),
p2=__shfl_xor(v2, 1), p3=__shfl_xor(v3, 1);
bool up=(thread_in_group&1);
v0=up?(p0-v0):(p0+v0); v1=up?(p1-v1):(p1+v1);
v2=up?(p2-v2):(p2+v2); v3=up?(p3-v3):(p3+v3); }

// Stage 3: cross-thread XOR-2
{ float p0=ds_swizzle_xor2(v0), p1=ds_swizzle_xor2(v1),
p2=ds_swizzle_xor2(v2), p3=ds_swizzle_xor2(v3);
{ float p0=__shfl_xor(v0, 2), p1=__shfl_xor(v1, 2),
p2=__shfl_xor(v2, 2), p3=__shfl_xor(v3, 2);
bool up=(thread_in_group>>1)&1;
v0=up?(p0-v0):(p0+v0); v1=up?(p1-v1):(p1+v1);
v2=up?(p2-v2):(p2+v2); v3=up?(p3-v3):(p3+v3); }
Expand Down
Loading