diff --git a/transformer_engine/common/cast/mxfp4/cast_transpose_mxfp4_shuffled.cuh b/transformer_engine/common/cast/mxfp4/cast_transpose_mxfp4_shuffled.cuh index 932e06a4a..7baa20350 100644 --- a/transformer_engine/common/cast/mxfp4/cast_transpose_mxfp4_shuffled.cuh +++ b/transformer_engine/common/cast/mxfp4/cast_transpose_mxfp4_shuffled.cuh @@ -35,7 +35,6 @@ #include #include #include -#include "../util/cuda_runtime.h" //cuda::sm_arch namespace te_mxfp4 { @@ -99,50 +98,6 @@ __device__ __forceinline__ void bf16x4_to_float4( v3 = uint_as_float(((uint32_t)((packed >> 48) & 0xFFFF)) << 16); } -// ============================================================================ -// WARP PRIMITIVES - AMD-Specific DPP/Swizzle Instructions -// ============================================================================ - -/* - * ds_swizzle Instructions - * ----------------------- - * These perform intra-wavefront data exchange without shared memory. - * The offset parameter encodes the permutation pattern. - * - * Format: offset = (AND_mask << 10) | (OR_mask << 5) | XOR_mask - * - * Common patterns: - * - 0x041F: XOR with lane 1 (exchange with adjacent thread) - * - 0x081F: XOR with lane 2 (exchange 2 positions away) - * - 0x101F: XOR with lane 4 (exchange 4 positions away) - * - * Reference: AMD CDNA4 ISA, ds_swizzle_b32 (page 480) - */ - -__device__ __forceinline__ float ds_swizzle_xor1(float val) { - float result; -#ifndef __gfx1250__ //instruction not supported on this GPU - asm volatile( - "ds_swizzle_b32 %0, %1 offset:0x041F\n\t" - "s_waitcnt lgkmcnt(0)" - : "=v"(result) : "v"(val) - ); -#endif - return result; -} - -__device__ __forceinline__ float ds_swizzle_xor2(float val) { - float result; -#ifndef __gfx1250__ //instruction not supported on this GPU - asm volatile( - "ds_swizzle_b32 %0, %1 offset:0x081F\n\t" - "s_waitcnt lgkmcnt(0)" - : "=v"(result) : "v"(val) - ); -#endif - return result; -} - // ============================================================================ // REDUCTION OPERATIONS - Finding Maximum Absolute Value // ============================================================================ @@ -159,27 +114,14 @@ __device__ __forceinline__ float ds_swizzle_xor2(float val) { * Step 3: XOR 1 - reduce 2 values to 1 (thread 0) */ __device__ __forceinline__ float warp_reduce_max_8_dpp(float val) { -#ifndef __gfx1250__ //instruction not supported on this GPU - uint32_t v = float_as_uint(val); - uint32_t tmp; - // Step 1: Exchange with thread 4 positions away - asm volatile("ds_swizzle_b32 %0, %1 offset:0x101F" : "=v"(tmp) : "v"(v)); - asm volatile("s_waitcnt lgkmcnt(0)" :::); - val = fmaxf(val, uint_as_float(tmp)); - v = float_as_uint(val); + val = fmaxf(val, __shfl_xor(val, 4)); // Step 2: Exchange with thread 2 positions away - asm volatile("ds_swizzle_b32 %0, %1 offset:0x081F" : "=v"(tmp) : "v"(v)); - asm volatile("s_waitcnt lgkmcnt(0)" :::); - val = fmaxf(val, uint_as_float(tmp)); - v = float_as_uint(val); + val = fmaxf(val, __shfl_xor(val, 2)); // Step 3: Exchange with adjacent thread - asm volatile("ds_swizzle_b32 %0, %1 offset:0x041F" : "=v"(tmp) : "v"(v)); - asm volatile("s_waitcnt lgkmcnt(0)" :::); - val = fmaxf(val, uint_as_float(tmp)); -#endif + val = fmaxf(val, __shfl_xor(val, 1)); return val; } @@ -218,10 +160,10 @@ __device__ __forceinline__ void hadamard16_inplace( v3 = a1 - a3; // Stage 2: Cross-thread exchange (XOR 1) - combine pairs - float p0 = ds_swizzle_xor1(v0); - float p1 = ds_swizzle_xor1(v1); - float p2 = ds_swizzle_xor1(v2); - float p3 = ds_swizzle_xor1(v3); + float p0 = __shfl_xor(v0, 1); + float p1 = __shfl_xor(v1, 1); + float p2 = __shfl_xor(v2, 1); + float p3 = __shfl_xor(v3, 1); bool sign2 = (tid & 1); v0 = sign2 ? (p0 - v0) : (p0 + v0); @@ -230,10 +172,10 @@ __device__ __forceinline__ void hadamard16_inplace( v3 = sign2 ? (p3 - v3) : (p3 + v3); // Stage 3: Cross-thread exchange (XOR 2) - final combination - p0 = ds_swizzle_xor2(v0); - p1 = ds_swizzle_xor2(v1); - p2 = ds_swizzle_xor2(v2); - p3 = ds_swizzle_xor2(v3); + p0 = __shfl_xor(v0, 2); + p1 = __shfl_xor(v1, 2); + p2 = __shfl_xor(v2, 2); + p3 = __shfl_xor(v3, 2); bool sign3 = (tid >> 1) & 1; float t0 = sign3 ? (p0 - v0) : (p0 + v0); @@ -738,10 +680,6 @@ inline void nvte_cast_transpose_mxfp4_fused_shuffle( int colwise_scale_M_pad, int colwise_scale_N_pad, hipStream_t stream ) { - //TODO: remove when enable HW code - if (transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) == 125) { - NVTE_ERROR("Hadamard transform is not yet supported on this GPU"); - } dim3 grid((M + te_mxfp4::BLOCK_M - 1) / te_mxfp4::BLOCK_M, (N + te_mxfp4::BLOCK_N - 1) / te_mxfp4::BLOCK_N); dim3 block(te_mxfp4::THREADS_PER_BLOCK); diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index 32dc5fe7c..105ca26b2 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -498,23 +498,7 @@ __global__ void HadamardTransformKernel(const T* __restrict__ input, T* __restri #ifdef __HIP_PLATFORM_AMD__ -// Tiling / layout constants -// -// A 16-point WHT operates on tiles of kHadamardDim (16) elements. -// Each tile is processed by kThreadsPerWHT (4) threads, each holding -// kElemsPerThread (4) values, so one wavefront of kWarpSize (64) lanes -// handles kRowsPerWarp (16) independent tiles (= rows) simultaneously. -// kWarpsPerBlock wavefronts are combined into a thread-block that covers -// kRowsPerBlock (64) consecutive rows. -static constexpr int kHadamardDim = 16; // WHT dimension (H16) -static constexpr int kWarpSize = 64; // Wavefront width -static constexpr int kThreadsPerWHT = 4; // threads per 16-pt WHT -static constexpr int kElemsPerThread = 4; // elements each thread owns -static constexpr int kRowsPerWarp = kWarpSize / kThreadsPerWHT; // 16 -static constexpr int kWarpsPerBlock = 4; -static constexpr int kRowsPerBlock = kRowsPerWarp * kWarpsPerBlock; // 64 -static constexpr int kThreadsPerBlock = kWarpSize * kWarpsPerBlock; // 256 -static constexpr float kHadamardScale = 0.25f; // 1/sqrt(16) +#include "wht16.cuh" // Reduce per-warp amax values in warp 0 and atomically update a global amax. __device__ __forceinline__ void reduce_block_amax( @@ -527,26 +511,6 @@ __device__ __forceinline__ void reduce_block_amax( atomicMaxFloat(global_amax, val); } -// ds_swizzle: sub-wavefront exchange without LDS. -// Same instructions as cast_transpose_mxfp4_kernel_shuffled.cu. -__device__ __forceinline__ float ds_swizzle_xor1(float v) { - float r; -#ifndef __gfx1250__ //instruction not supported on this GPU - asm volatile("ds_swizzle_b32 %0, %1 offset:0x041F\n\t" - "s_waitcnt lgkmcnt(0)" : "=v"(r) : "v"(v)); -#endif - return r; -} - -__device__ __forceinline__ float ds_swizzle_xor2(float v) { - float r; -#ifndef __gfx1250__ //instruction not supported on this GPU - asm volatile("ds_swizzle_b32 %0, %1 offset:0x081F\n\t" - "s_waitcnt lgkmcnt(0)" : "=v"(r) : "v"(v)); -#endif - return r; -} - // BF16 helpers __device__ __forceinline__ float to_f32 (__hip_bfloat16 v) { return static_cast(v); } __device__ __forceinline__ __hip_bfloat16 to_bf16(float v) { return static_cast<__hip_bfloat16>(v); } @@ -573,89 +537,6 @@ __device__ __forceinline__ uint64_t pack_bf16x4(float v0, float v1, float v2, fl | ((uint64_t)bf16_to_bits(to_bf16(v3)) << 48); } -// ----------------------------------------------------------------------- -// 16-point WHT via the Kronecker trick (no shared memory) -// ----------------------------------------------------------------------- -// -// 1. The vec operator -// vec() flattens a matrix into a column vector by stacking its -// columns one on top of the other: -// -// X = |a c| vec(X) = |a| -// |b d| |b| -// |c| -// |d| -// -// 2. The "Kronecker trick" for 1D -> 2D -// The fundamental identity that connects these concepts is: -// -// vec(B . X . A^T) = (A (x) B) . vec(X) -// -// For a 16-point Hadamard transform (H16 = H4 (x) H4), -// set A = H4 and B = H4. The formula becomes: -// -// H16 . x = vec(H4 . X . H4^T) -// -// 3. Data layout (column-major, one column per thread) -// Reshape the 16-element 1D vector x into a 4x4 matrix X -// by filling columns first: -// -// X = | x0 x4 x8 x12 | thread 0 holds col 0: v0..v3 = x0 ..x3 -// | x1 x5 x9 x13 | thread 1 holds col 1: v0..v3 = x4 ..x7 -// | x2 x6 x10 x14 | thread 2 holds col 2: v0..v3 = x8 ..x11 -// | x3 x7 x11 x15 | thread 3 holds col 3: v0..v3 = x12..x15 -// -// 4. Three-stage computation -// Stage 1 (local H4) : left-multiply H4 . X (within each thread) -// Stage 2 (xor-1 swap) : \ (across 4 threads) -// Stage 3 (xor-2 swap) : / right-multiply . H4^T together these two butterfly stages = H4^T -// -// Result: vec(H4 . X . H4^T) = H16 . x -// -// 5. Randomised Hadamard Transform (RHT) -// A diagonal sign matrix D (from sign_mask) is applied either -// before the WHT (apply_pre=true, forward) or after (inverse). -// -// Adapted from cast_transpose_mxfp4_kernel_shuffled.cu::hadamard16_inplace, -// extended with NV random_sign_mask (uint16_t bitmask). -// thread_in_group [0,3]: drives ds_swizzle polarity (identical to MLPerf tid & 3). -// apply_pre=true -> D before WHT (forward); false -> D after WHT (inverse). -__device__ __forceinline__ void wht16( - float& v0, float& v1, float& v2, float& v3, - int thread_in_group, uint16_t sign_mask, bool apply_pre) { - auto sgn = [&](int k) -> float { - return ((sign_mask >> (thread_in_group * kElemsPerThread + k)) & 1u) ? -1.f : 1.f; - }; - - if (apply_pre) { - v0*=sgn(0); v1*=sgn(1); v2*=sgn(2); v3*=sgn(3); - } - - // Stage 1: local H4 - float a0=v0+v1, a1=v0-v1, a2=v2+v3, a3=v2-v3; - v0=a0+a2; v2=a0-a2; v1=a1+a3; v3=a1-a3; - - // Stage 2: cross-thread XOR-1 - { float p0=ds_swizzle_xor1(v0), p1=ds_swizzle_xor1(v1), - p2=ds_swizzle_xor1(v2), p3=ds_swizzle_xor1(v3); - bool up=(thread_in_group&1); - v0=up?(p0-v0):(p0+v0); v1=up?(p1-v1):(p1+v1); - v2=up?(p2-v2):(p2+v2); v3=up?(p3-v3):(p3+v3); } - - // Stage 3: cross-thread XOR-2 - { float p0=ds_swizzle_xor2(v0), p1=ds_swizzle_xor2(v1), - p2=ds_swizzle_xor2(v2), p3=ds_swizzle_xor2(v3); - bool up=(thread_in_group>>1)&1; - v0=up?(p0-v0):(p0+v0); v1=up?(p1-v1):(p1+v1); - v2=up?(p2-v2):(p2+v2); v3=up?(p3-v3):(p3+v3); } - - v0*=kHadamardScale; v1*=kHadamardScale; v2*=kHadamardScale; v3*=kHadamardScale; - - if (!apply_pre) { - v0*=sgn(0); v1*=sgn(1); v2*=sgn(2); v3*=sgn(3); - } -} - // Grid: blockIdx.x = col tile [0, row_length/16) // blockIdx.y = row batch [0, ceil(num_rows/64)) // Block: 256 threads = 4 wavefronts of 64 lanes. diff --git a/transformer_engine/common/hadamard_transform/wht16.cuh b/transformer_engine/common/hadamard_transform/wht16.cuh index 490ebbb6d..4bd538a69 100644 --- a/transformer_engine/common/hadamard_transform/wht16.cuh +++ b/transformer_engine/common/hadamard_transform/wht16.cuh @@ -23,21 +23,6 @@ static constexpr int kRowsPerBlock = kRowsPerWarp * kWarpsPerBlock; // 64 static constexpr int kThreadsPerBlock = kWarpSize * kWarpsPerBlock; // 256 static constexpr float kHadamardScale = 0.25f; -// ds_swizzle: sub-wavefront exchange without LDS. -__device__ __forceinline__ float ds_swizzle_xor1(float v) { - float r; - asm volatile("ds_swizzle_b32 %0, %1 offset:0x041F\n\t" - "s_waitcnt lgkmcnt(0)" : "=v"(r) : "v"(v)); - return r; -} - -__device__ __forceinline__ float ds_swizzle_xor2(float v) { - float r; - asm volatile("ds_swizzle_b32 %0, %1 offset:0x081F\n\t" - "s_waitcnt lgkmcnt(0)" : "=v"(r) : "v"(v)); - return r; -} - // ----------------------------------------------------------------------- // 16-point WHT via the Kronecker trick (no shared memory) // ----------------------------------------------------------------------- @@ -101,15 +86,15 @@ __device__ __forceinline__ void wht16( v0=a0+a2; v2=a0-a2; v1=a1+a3; v3=a1-a3; // Stage 2: cross-thread XOR-1 - { float p0=ds_swizzle_xor1(v0), p1=ds_swizzle_xor1(v1), - p2=ds_swizzle_xor1(v2), p3=ds_swizzle_xor1(v3); + { float p0=__shfl_xor(v0, 1), p1=__shfl_xor(v1, 1), + p2=__shfl_xor(v2, 1), p3=__shfl_xor(v3, 1); bool up=(thread_in_group&1); v0=up?(p0-v0):(p0+v0); v1=up?(p1-v1):(p1+v1); v2=up?(p2-v2):(p2+v2); v3=up?(p3-v3):(p3+v3); } // Stage 3: cross-thread XOR-2 - { float p0=ds_swizzle_xor2(v0), p1=ds_swizzle_xor2(v1), - p2=ds_swizzle_xor2(v2), p3=ds_swizzle_xor2(v3); + { float p0=__shfl_xor(v0, 2), p1=__shfl_xor(v1, 2), + p2=__shfl_xor(v2, 2), p3=__shfl_xor(v3, 2); bool up=(thread_in_group>>1)&1; v0=up?(p0-v0):(p0+v0); v1=up?(p1-v1):(p1+v1); v2=up?(p2-v2):(p2+v2); v3=up?(p3-v3):(p3+v3); }