From 009364c3dc1e61d101bd07d17924f7c41337ebe7 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 25 Mar 2026 22:02:43 +0100 Subject: [PATCH 1/4] batchedmatmul: align types. --- examples/batchmatmul.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/batchmatmul.py b/examples/batchmatmul.py index b0929639..a4f7874e 100644 --- a/examples/batchmatmul.py +++ b/examples/batchmatmul.py @@ -45,7 +45,7 @@ def batchmatmul_cutile_kernel(A, B, C, tm: ct.Constant[int], tn: ct.Constant[int # Example harness #============================================================================= -def prepare(*, benchmark: bool = False, Batch: int = None, M: int = None, K: int = None, N: int = None, dtype=np.float16): +def prepare(*, benchmark: bool = False, Batch: int = None, M: int = None, K: int = None, N: int = None, dtype=np.float32): """Allocate and initialize data for batch matmul.""" if Batch is None: Batch = 8 if benchmark else 4 From cbfbdc3893aa58730e773d9646c4c5689fcc9390 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 25 Mar 2026 22:03:12 +0100 Subject: [PATCH 2/4] layernorm: use contiguous layout. --- examples/layernorm.jl | 110 ++++++++++++++++++++++-------------------- 1 file changed, 57 insertions(+), 53 deletions(-) diff --git a/examples/layernorm.jl b/examples/layernorm.jl index 49289949..493e5947 100644 --- a/examples/layernorm.jl +++ b/examples/layernorm.jl @@ -1,5 +1,8 @@ # LayerNorm example - Julia port of cuTile Python's LayerNorm.py sample # +# Data layout: 2D tensors are stored as (N, M) so that the normalization +# dimension N is contiguous in Julia's column-major memory layout. +# # SPDX-License-Identifier: Apache-2.0 using CUDA @@ -11,10 +14,10 @@ import cuTile as ct Forward pass: computes mean/var, normalizes input, and applies affine transform. Args: - X: Input tensor (M, N). + X: Input tensor (N, M). W: Weight tensor (N,). B: Bias tensor (N,). - Y: Output tensor (M, N). + Y: Output tensor (N, M). Mean: Output mean tensor (M,). Rstd: Output reciprocal standard deviation tensor (M,). eps: Epsilon for numerical stability. @@ -25,44 +28,44 @@ function layer_norm_fwd(X::ct.TileArray{Float32, 2}, W::ct.TileArray{Float32, 1} Mean::ct.TileArray{Float32, 1}, Rstd::ct.TileArray{Float32, 1}, eps::Float32, TILE_N::Int) bid_m = ct.bid(1) - num_tiles = ct.num_tiles(X, 2, (1, TILE_N)) - N = size(X, 2) + num_tiles = ct.num_tiles(X, 1, (TILE_N, 1)) + N = size(X, 1) # Compute mean - mean = zeros(Float32, (1, TILE_N)) + mean = zeros(Float32, (TILE_N, 1)) j = Int32(1) while j <= num_tiles - tx = ct.load(X; index=(bid_m, j), shape=(1, TILE_N), padding_mode=ct.PaddingMode.Zero) + tx = ct.load(X; index=(j, bid_m), shape=(TILE_N, 1), padding_mode=ct.PaddingMode.Zero) mean = mean .+ tx j += Int32(1) end - mean = sum(mean; dims=2) / N + mean = sum(mean; dims=1) / N ct.store(Mean; index=bid_m, tile=mean) # Compute variance - var = zeros(Float32, (1, TILE_N)) + var = zeros(Float32, (TILE_N, 1)) j = Int32(1) while j <= num_tiles - tx = ct.load(X; index=(bid_m, j), shape=(1, TILE_N), padding_mode=ct.PaddingMode.Zero) + tx = ct.load(X; index=(j, bid_m), shape=(TILE_N, 1), padding_mode=ct.PaddingMode.Zero) # Mask for valid elements - mask = reshape(((j - Int32(1)) * Int32(TILE_N) .+ ct.arange(TILE_N)) .<= N, (1, TILE_N)) + mask = reshape(((j - Int32(1)) * Int32(TILE_N) .+ ct.arange(TILE_N)) .<= N, (TILE_N, 1)) centered_tx = ifelse.(mask, tx .- mean, 0.0f0) var = var .+ (centered_tx .^ 2.0f0) j += Int32(1) end - var = sum(var; dims=2) / N + var = sum(var; dims=1) / N rstd = 1.0f0 ./ sqrt.(var .+ eps) ct.store(Rstd; index=bid_m, tile=rstd) # Normalize and apply affine transformation j = Int32(1) while j <= num_tiles - tx = ct.load(X; index=(bid_m, j), shape=(1, TILE_N), padding_mode=ct.PaddingMode.Zero) - tw = reshape(ct.load(W; index=j, shape=(TILE_N,), padding_mode=ct.PaddingMode.Zero), (1, TILE_N)) - tb = reshape(ct.load(B; index=j, shape=(TILE_N,), padding_mode=ct.PaddingMode.Zero), (1, TILE_N)) + tx = ct.load(X; index=(j, bid_m), shape=(TILE_N, 1), padding_mode=ct.PaddingMode.Zero) + tw = reshape(ct.load(W; index=j, shape=(TILE_N,), padding_mode=ct.PaddingMode.Zero), (TILE_N, 1)) + tb = reshape(ct.load(B; index=j, shape=(TILE_N,), padding_mode=ct.PaddingMode.Zero), (TILE_N, 1)) ty = (tx .- mean) .* rstd ty = ty .* tw .+ tb - ct.store(Y; index=(bid_m, j), tile=ty) + ct.store(Y; index=(j, bid_m), tile=ty) j += Int32(1) end @@ -86,9 +89,9 @@ This gets inlined by Julia's compiler. bid_m and j are 1-indexed (block ID and tile index). """ @inline function bwd_helper(X, W, DY, bid_m, j, mean, rstd, TILE_N, N) - tx = ct.load(X; index=(bid_m, j), shape=(1, TILE_N), padding_mode=ct.PaddingMode.Zero) - tw = reshape(ct.load(W; index=j, shape=(TILE_N,), padding_mode=ct.PaddingMode.Zero), (1, TILE_N)) - tdy = ct.load(DY; index=(bid_m, j), shape=(1, TILE_N), padding_mode=ct.PaddingMode.Zero) + tx = ct.load(X; index=(j, bid_m), shape=(TILE_N, 1), padding_mode=ct.PaddingMode.Zero) + tw = reshape(ct.load(W; index=j, shape=(TILE_N,), padding_mode=ct.PaddingMode.Zero), (TILE_N, 1)) + tdy = ct.load(DY; index=(j, bid_m), shape=(TILE_N, 1), padding_mode=ct.PaddingMode.Zero) xhat = (tx .- mean) .* rstd wdy = tw .* tdy @@ -96,7 +99,7 @@ bid_m and j are 1-indexed (block ID and tile index). indices = ct.arange(TILE_N) offset = (j - Int32(1)) * Int32(TILE_N) global_indices = offset .+ indices - mask = reshape(global_indices .<= N, (1, TILE_N)) + mask = reshape(global_indices .<= N, (TILE_N, 1)) xhat_masked = ifelse.(mask, xhat, 0.0f0) wdy_masked = ifelse.(mask, wdy, 0.0f0) @@ -110,9 +113,9 @@ end Backward pass: computes gradient with respect to input X. Args: - DX: Output gradient with respect to X (M, N). - DY: Input gradient with respect to Y (M, N). - X: Input tensor (M, N). + DX: Output gradient with respect to X (N, M). + DY: Input gradient with respect to Y (N, M). + X: Input tensor (N, M). W: Weight tensor (N,). Mean: Mean tensor (M,). Rstd: Reciprocal standard deviation tensor (M,). @@ -123,16 +126,16 @@ function layer_norm_bwd_dx(DX::ct.TileArray{Float32, 2}, DY::ct.TileArray{Float3 Mean::ct.TileArray{Float32, 1}, Rstd::ct.TileArray{Float32, 1}, TILE_N::Int) bid_m = ct.bid(1) - num_tiles = ct.num_tiles(X, 2, (1, TILE_N)) - N = size(X, 2) + num_tiles = ct.num_tiles(X, 1, (TILE_N, 1)) + N = size(X, 1) # Load mean and rstd for this row mean = ct.load(Mean; index=bid_m, shape=(1,), padding_mode=ct.PaddingMode.Zero) rstd = ct.load(Rstd; index=bid_m, shape=(1,), padding_mode=ct.PaddingMode.Zero) # First pass: compute c1 and c2 reduction terms - c1 = zeros(Float32, (1, TILE_N)) - c2 = zeros(Float32, (1, TILE_N)) + c1 = zeros(Float32, (TILE_N, 1)) + c2 = zeros(Float32, (TILE_N, 1)) j = Int32(1) while j <= num_tiles _, xhat, wdy = bwd_helper(X, W, DY, bid_m, j, mean, rstd, TILE_N, N) @@ -140,15 +143,15 @@ function layer_norm_bwd_dx(DX::ct.TileArray{Float32, 2}, DY::ct.TileArray{Float3 c2 = c2 .+ wdy j += Int32(1) end - c1 = sum(c1; dims=2) / N - c2 = sum(c2; dims=2) / N + c1 = sum(c1; dims=1) / N + c2 = sum(c2; dims=1) / N # Second pass: compute dX j = Int32(1) while j <= num_tiles _, xhat, wdy = bwd_helper(X, W, DY, bid_m, j, mean, rstd, TILE_N, N) tdx = (wdy .- (xhat .* c1 .+ c2)) .* rstd - ct.store(DX; index=(bid_m, j), tile=tdx) + ct.store(DX; index=(j, bid_m), tile=tdx) j += Int32(1) end @@ -162,11 +165,11 @@ Backward pass part 1: computes dX and partial dW/dB. Accumulates partial gradients using atomic locks. Args: - DX: Output gradient with respect to X (M, N). - DY: Input gradient with respect to Y (M, N). + DX: Output gradient with respect to X (N, M). + DY: Input gradient with respect to Y (N, M). DW: Partial gradient with respect to W (N, GROUP_SIZE_M). DB: Partial gradient with respect to B (N, GROUP_SIZE_M). - X: Input tensor (M, N). + X: Input tensor (N, M). W: Weight tensor (N,). Mean: Mean tensor (M,). Rstd: Reciprocal standard deviation tensor (M,). @@ -181,8 +184,8 @@ function layer_norm_bwd_dx_partial_dwdb(DX::ct.TileArray{Float32, 2}, DY::ct.Til Locks::ct.TileArray{Int, 1}, GROUP_SIZE_M::Int, TILE_N::Int) bid_m = ct.bid(1) - num_tiles = ct.num_tiles(X, 2, (1, TILE_N)) - N = size(X, 2) + num_tiles = ct.num_tiles(X, 1, (TILE_N, 1)) + N = size(X, 1) group_bid_m = ((bid_m - Int32(1)) % Int32(GROUP_SIZE_M)) + Int32(1) # Load mean and rstd for this row @@ -190,8 +193,8 @@ function layer_norm_bwd_dx_partial_dwdb(DX::ct.TileArray{Float32, 2}, DY::ct.Til rstd = ct.load(Rstd; index=bid_m, shape=(1,), padding_mode=ct.PaddingMode.Zero) # First pass: compute c1 and c2 reduction terms - c1 = zeros(Float32, (1, TILE_N)) - c2 = zeros(Float32, (1, TILE_N)) + c1 = zeros(Float32, (TILE_N, 1)) + c2 = zeros(Float32, (TILE_N, 1)) j = Int32(1) while j <= num_tiles _, xhat, wdy = bwd_helper(X, W, DY, bid_m, j, mean, rstd, TILE_N, N) @@ -199,15 +202,15 @@ function layer_norm_bwd_dx_partial_dwdb(DX::ct.TileArray{Float32, 2}, DY::ct.Til c2 = c2 .+ wdy j += Int32(1) end - c1 = sum(c1; dims=2) / N - c2 = sum(c2; dims=2) / N + c1 = sum(c1; dims=1) / N + c2 = sum(c2; dims=1) / N # Second pass: compute dX and partial dW/dB j = Int32(1) while j <= num_tiles tdy, xhat, wdy = bwd_helper(X, W, DY, bid_m, j, mean, rstd, TILE_N, N) tdx = (wdy .- (xhat .* c1 .+ c2)) .* rstd - ct.store(DX; index=(bid_m, j), tile=tdx) + ct.store(DX; index=(j, bid_m), tile=tdx) partial_dw = reshape(tdy .* xhat, (TILE_N, 1)) partial_db = reshape(tdy, (TILE_N, 1)) @@ -279,16 +282,16 @@ function prepare(; benchmark::Bool=false, N::Int=benchmark ? 4096 : 256, eps::Float32=1f-5, GROUP_SIZE_M::Int=64) return (; - # Forward inputs/outputs - X = -2.3f0 .+ 0.5f0 .* CUDA.rand(Float32, M, N), + # Forward inputs/outputs — 2D tensors stored as (N, M) for contiguous N access + X = -2.3f0 .+ 0.5f0 .* CUDA.rand(Float32, N, M), W = CUDA.randn(Float32, N), B = CUDA.randn(Float32, N), - Y = CuArray{Float32}(undef, M, N), + Y = CuArray{Float32}(undef, N, M), Mean = CuArray{Float32}(undef, M), Rstd = CuArray{Float32}(undef, M), # Backward inputs/outputs - DY = 0.1f0 .* CUDA.randn(Float32, M, N), - DX = CuArray{Float32}(undef, M, N), + DY = 0.1f0 .* CUDA.randn(Float32, N, M), + DX = CuArray{Float32}(undef, N, M), DW_partial = CuArray{Float32}(undef, N, GROUP_SIZE_M), DB_partial = CuArray{Float32}(undef, N, GROUP_SIZE_M), Locks = CuArray{Int}(undef, GROUP_SIZE_M), @@ -345,27 +348,28 @@ end function verify(data, result) (; X, W, B, DY, N, eps) = data + # All data is (N, M): reduce along dim 1 (N), broadcast W/B along dim 2 (M) X_cpu = Array(X) W_cpu = Array(W) B_cpu = Array(B) DY_cpu = Array(DY) # Forward verification - expected_mean = vec(sum(X_cpu, dims=2) ./ N) - expected_var = vec(sum((X_cpu .- expected_mean) .^ 2, dims=2) ./ N) + expected_mean = sum(X_cpu, dims=1) ./ N # (1, M) + expected_var = sum((X_cpu .- expected_mean) .^ 2, dims=1) ./ N # (1, M) expected_rstd = 1.0f0 ./ sqrt.(expected_var .+ eps) - xhat = (X_cpu .- expected_mean) .* expected_rstd - expected_Y = xhat .* W_cpu' .+ B_cpu' + xhat = (X_cpu .- expected_mean) .* expected_rstd # (N, M) + expected_Y = xhat .* W_cpu .+ B_cpu # W/B are (N,), broadcast over M @assert isapprox(expected_Y, Array(result.Y); rtol=1e-2) "Y mismatch" # Backward verification - wdy = W_cpu' .* DY_cpu - c1 = sum(xhat .* wdy, dims=2) ./ N - c2 = sum(wdy, dims=2) ./ N + wdy = W_cpu .* DY_cpu # (N, M) + c1 = sum(xhat .* wdy, dims=1) ./ N # (1, M) + c2 = sum(wdy, dims=1) ./ N # (1, M) expected_DX = (wdy .- (xhat .* c1 .+ c2)) .* expected_rstd - expected_DW = vec(sum(DY_cpu .* xhat, dims=1)) - expected_DB = vec(sum(DY_cpu, dims=1)) + expected_DW = vec(sum(DY_cpu .* xhat, dims=2)) # reduce over M → (N,) + expected_DB = vec(sum(DY_cpu, dims=2)) # reduce over M → (N,) @assert isapprox(expected_DX, Array(result.DX); rtol=1e-2) "dX mismatch" @assert isapprox(expected_DW, Array(result.FINAL_DW); rtol=1e-2) "dW mismatch" From b07eeab38bfbb8f4b3b4e688603e9aff13846945 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 25 Mar 2026 22:28:51 +0100 Subject: [PATCH 3/4] fft: use natural storage. --- examples/fft.jl | 230 +++++++++++++++++++++--------------------------- 1 file changed, 101 insertions(+), 129 deletions(-) diff --git a/examples/fft.jl b/examples/fft.jl index c4159d28..80b75feb 100644 --- a/examples/fft.jl +++ b/examples/fft.jl @@ -3,9 +3,8 @@ # This implements a 3-stage Cooley-Tukey FFT decomposition. The FFT of size N is decomposed # as N = F0 * F1 * F2, allowing efficient tensor factorization. # -# Key difference from Python: Julia uses column-major storage, so reshape dimensions are -# swapped and right-multiply (X @ W) is used instead of left-multiply (W @ X) to process -# rows instead of columns, achieving the same strided element access pattern. +# All shapes are the reverse of Python's row-major shapes, so that the memory layout is +# identical. No extra permutedims are needed for batch dimension shuffling. # # SPDX-License-Identifier: Apache-2.0 @@ -14,22 +13,19 @@ import cuTile as ct using Test using FFTW -# FFT kernel - 3-stage Cooley-Tukey decomposition (column-major version) +# FFT kernel - 3-stage Cooley-Tukey decomposition # -# The key insight: In Python row-major, reshape (F0, F1F2) puts stride-F1F2 elements -# in columns. In Julia column-major, reshape (F1F2, F0) puts stride-F0 elements in rows. -# We use right-multiply X @ W instead of W @ X to process rows instead of columns. -# -# Input/output memory layout: (D, BS, N2D) where D=2 for real/imag interleaving. -# Internally, BS is permuted to trailing position for batched matmul convention. +# Python row-major shape (A, B, C) ↔ Julia col-major shape (C, B, A) — same memory layout. +# Python left-multiply W @ X ↔ Julia right-multiply X * W (batch dims trailing). +# Python ct.permute(x, (0,2,3,1)) ↔ Julia permutedims(x, (3,1,2,4)). function fft_kernel( - x_packed_in::ct.TileArray{Float32, 3}, # Input (D, BS, N2D) - natural Julia complex layout - y_packed_out::ct.TileArray{Float32, 3}, # Output (D, BS, N2D) - W0::ct.TileArray{Float32, 3}, # W0 (F0, F0, 2) DFT matrix - W1::ct.TileArray{Float32, 3}, # W1 (F1, F1, 2) - W2::ct.TileArray{Float32, 3}, # W2 (F2, F2, 2) - T0::ct.TileArray{Float32, 3}, # T0 (F1F2, F0, 2) twiddle factors - T1::ct.TileArray{Float32, 3}, # T1 (F0F2, F1, 2) twiddle factors + x_packed_in::ct.TileArray{Float32, 3}, # Input (D, N2D, BS) + y_packed_out::ct.TileArray{Float32, 3}, # Output (D, N2D, BS) + W0::ct.TileArray{Float32, 3}, # W0 (2, F0, F0) DFT matrix + W1::ct.TileArray{Float32, 3}, # W1 (2, F1, F1) + W2::ct.TileArray{Float32, 3}, # W2 (2, F2, F2) + T0::ct.TileArray{Float32, 3}, # T0 (2, F1F2, F0) twiddle factors + T1::ct.TileArray{Float32, 3}, # T1 (2, F2, F1) twiddle factors n_const::Int, f0_const::Int, f1_const::Int, @@ -41,7 +37,6 @@ function fft_kernel( d_const::Int, n2d_const::Int ) - # Extract constant values N = n_const F0 = f0_const F1 = f1_const @@ -56,156 +51,137 @@ function fft_kernel( bid = ct.bid(1) # --- Load Input Data --- - # Input is (D, BS, N2D) where D=2 for real/imag. Load and permute BS to trailing. - X_ri_mem = reshape(ct.load(x_packed_in; index=(1, bid, 1), shape=(D, BS, N2D)), (2, BS, N)) - X_ri = permutedims(X_ri_mem, (1, 3, 2)) # (2, N, BS) — trailing batch + # Input is (D, N2D, BS). Load and reshape to (2, N, BS). + X_ri = reshape(ct.load(x_packed_in; index=(1, 1, bid), shape=(D, N2D, BS)), (2, N, BS)) - # Split real and imaginary parts (extract from first dimension) - X_r = reshape(ct.extract(X_ri, (1, 1, 1), (1, N, BS)), (F1F2, F0, BS)) - X_i = reshape(ct.extract(X_ri, (2, 1, 1), (1, N, BS)), (F1F2, F0, BS)) + # Split real and imaginary parts, reshape to 4D factored form + X_r = reshape(ct.extract(X_ri, (1, 1, 1), (1, N, BS)), (F2, F1, F0, BS)) + X_i = reshape(ct.extract(X_ri, (2, 1, 1), (1, N, BS)), (F2, F1, F0, BS)) # --- Load DFT Matrices --- - # W0 (F0 x F0) - for right-multiply X @ W0, batch dim trailing - W0_ri = reshape(ct.load(W0; index=(1, 1, 1), shape=(F0, F0, 2)), (F0, F0, 2)) - W0_r = ct.broadcast_to(reshape(ct.extract(W0_ri, (1, 1, 1), (F0, F0, 1)), (F0, F0, 1)), (F0, F0, BS)) - W0_i = ct.broadcast_to(reshape(ct.extract(W0_ri, (1, 1, 2), (F0, F0, 1)), (F0, F0, 1)), (F0, F0, BS)) + # W0 (F0 × F0): trailing batch dim 1 for broadcast in batched matmul + W0_ri = reshape(ct.load(W0; index=(1, 1, 1), shape=(2, F0, F0)), (2, F0, F0)) + W0_r = reshape(ct.extract(W0_ri, (1, 1, 1), (1, F0, F0)), (F0, F0, 1)) + W0_i = reshape(ct.extract(W0_ri, (2, 1, 1), (1, F0, F0)), (F0, F0, 1)) - # W1 (F1 x F1) - W1_ri = reshape(ct.load(W1; index=(1, 1, 1), shape=(F1, F1, 2)), (F1, F1, 2)) - W1_r = ct.broadcast_to(reshape(ct.extract(W1_ri, (1, 1, 1), (F1, F1, 1)), (F1, F1, 1)), (F1, F1, BS)) - W1_i = ct.broadcast_to(reshape(ct.extract(W1_ri, (1, 1, 2), (F1, F1, 1)), (F1, F1, 1)), (F1, F1, BS)) + # W1 (F1 × F1) + W1_ri = reshape(ct.load(W1; index=(1, 1, 1), shape=(2, F1, F1)), (2, F1, F1)) + W1_r = reshape(ct.extract(W1_ri, (1, 1, 1), (1, F1, F1)), (F1, F1, 1)) + W1_i = reshape(ct.extract(W1_ri, (2, 1, 1), (1, F1, F1)), (F1, F1, 1)) - # W2 (F2 x F2) - W2_ri = reshape(ct.load(W2; index=(1, 1, 1), shape=(F2, F2, 2)), (F2, F2, 2)) - W2_r = ct.broadcast_to(reshape(ct.extract(W2_ri, (1, 1, 1), (F2, F2, 1)), (F2, F2, 1)), (F2, F2, BS)) - W2_i = ct.broadcast_to(reshape(ct.extract(W2_ri, (1, 1, 2), (F2, F2, 1)), (F2, F2, 1)), (F2, F2, BS)) + # W2 (F2 × F2) + W2_ri = reshape(ct.load(W2; index=(1, 1, 1), shape=(2, F2, F2)), (2, F2, F2)) + W2_r = reshape(ct.extract(W2_ri, (1, 1, 1), (1, F2, F2)), (F2, F2, 1)) + W2_i = reshape(ct.extract(W2_ri, (2, 1, 1), (1, F2, F2)), (F2, F2, 1)) # --- Load Twiddle Factors --- - # T0 (F1F2, F0) - note swapped from Python's (F0, F1F2) - T0_ri = reshape(ct.load(T0; index=(1, 1, 1), shape=(F1F2, F0, 2)), (F1F2, F0, 2)) - T0_r = reshape(ct.extract(T0_ri, (1, 1, 1), (F1F2, F0, 1)), (N, 1)) - T0_i = reshape(ct.extract(T0_ri, (1, 1, 2), (F1F2, F0, 1)), (N, 1)) + # T0 (2, F1F2, F0) → flatten to (1, N) for element-wise multiply + T0_ri = reshape(ct.load(T0; index=(1, 1, 1), shape=(2, F1F2, F0)), (2, F1F2, F0)) + T0_r = reshape(ct.extract(T0_ri, (1, 1, 1), (1, F1F2, F0)), (1, N)) + T0_i = reshape(ct.extract(T0_ri, (2, 1, 1), (1, F1F2, F0)), (1, N)) - # T1 (F0F2, F1) - note swapped from Python's (F1, F2) - T1_ri = reshape(ct.load(T1; index=(1, 1, 1), shape=(F0F2, F1, 2)), (F0F2, F1, 2)) - T1_r = reshape(ct.extract(T1_ri, (1, 1, 1), (F0F2, F1, 1)), (F0F2 * F1, 1)) - T1_i = reshape(ct.extract(T1_ri, (1, 1, 2), (F0F2, F1, 1)), (F0F2 * F1, 1)) + # T1 (2, F2, F1) → flatten to (1, F1F2) for element-wise multiply + T1_ri = reshape(ct.load(T1; index=(1, 1, 1), shape=(2, F2, F1)), (2, F2, F1)) + T1_r = reshape(ct.extract(T1_ri, (1, 1, 1), (1, F2, F1)), (1, F1F2)) + T1_i = reshape(ct.extract(T1_ri, (2, 1, 1), (1, F2, F1)), (1, F1F2)) # --- Stage 0: F0-point DFT --- - # X is (F1F2, F0, BS), W0 is (F0, F0, BS) — trailing batch - # Right-multiply: X @ W0 processes each row (F1F2 rows, each with F0 elements) - X_r_ = X_r * W0_r - X_i * W0_i # (F1F2, F0, BS) @ (F0, F0, BS) → (F1F2, F0, BS) - X_i_ = X_r * W0_i + X_i * W0_r + # X: (F1F2, F0, BS) × W0: (F0, F0, 1) → (F1F2, F0, BS) + X_r = reshape(X_r, (F1F2, F0, BS)) + X_i = reshape(X_i, (F1F2, F0, BS)) + X_r_ = reshape(X_r * W0_r - X_i * W0_i, (1, N, BS)) + X_i_ = reshape(X_r * W0_i + X_i * W0_r, (1, N, BS)) # --- Twiddle & Permute 0 --- - # Reshape to (N, BS) for element-wise twiddle multiply - X_r_flat = reshape(X_r_, (N, BS)) - X_i_flat = reshape(X_i_, (N, BS)) - X_r2 = T0_r .* X_r_flat .- T0_i .* X_i_flat - X_i2 = T0_i .* X_r_flat .+ T0_r .* X_i_flat - - # Reshape and permute for stage 1 - # Reshape to (F2, F1, F0, BS) then permute to (F0F2, F1, BS) for stage 1 - X_r3 = reshape(X_r2, (F2, F1, F0, BS)) - X_i3 = reshape(X_i2, (F2, F1, F0, BS)) - X_r4 = permutedims(X_r3, (1, 3, 2, 4)) # (F2, F0, F1, BS) - X_i4 = permutedims(X_i3, (1, 3, 2, 4)) - X_r5 = reshape(X_r4, (F0F2, F1, BS)) - X_i5 = reshape(X_i4, (F0F2, F1, BS)) + X_r2 = T0_r .* X_r_ .- T0_i .* X_i_ + X_i2 = T0_i .* X_r_ .+ T0_r .* X_i_ + + # Reshape to 4D factored form and permute for stage 1 + X_r3 = permutedims(reshape(X_r2, (F2, F1, F0, BS)), (3, 1, 2, 4)) # → (F0, F2, F1, BS) + X_i3 = permutedims(reshape(X_i2, (F2, F1, F0, BS)), (3, 1, 2, 4)) # --- Stage 1: F1-point DFT --- - # X is (F0F2, F1, BS), W1 is (F1, F1, BS) - X_r6 = X_r5 * W1_r - X_i5 * W1_i - X_i6 = X_r5 * W1_i + X_i5 * W1_r + # Merge (F0, F2) → F0F2; X: (F0F2, F1, BS) × W1: (F1, F1, 1) → (F0F2, F1, BS) + X_r4 = reshape(X_r3, (F0F2, F1, BS)) + X_i4 = reshape(X_i3, (F0F2, F1, BS)) + X_r5 = reshape(X_r4 * W1_r - X_i4 * W1_i, (F0, F1F2, BS)) + X_i5 = reshape(X_r4 * W1_i + X_i4 * W1_r, (F0, F1F2, BS)) # --- Twiddle & Permute 1 --- - X_r_flat2 = reshape(X_r6, (N, BS)) - X_i_flat2 = reshape(X_i6, (N, BS)) - X_r7 = T1_r .* X_r_flat2 .- T1_i .* X_i_flat2 - X_i7 = T1_i .* X_r_flat2 .+ T1_r .* X_i_flat2 - - # Reshape and permute for stage 2 - X_r8 = reshape(X_r7, (F2, F0, F1, BS)) - X_i8 = reshape(X_i7, (F2, F0, F1, BS)) - X_r9 = permutedims(X_r8, (2, 3, 1, 4)) # (F0, F1, F2, BS) - X_i9 = permutedims(X_i8, (2, 3, 1, 4)) - X_r10 = reshape(X_r9, (F0F1, F2, BS)) - X_i10 = reshape(X_i9, (F0F1, F2, BS)) + X_r6 = T1_r .* X_r5 .- T1_i .* X_i5 + X_i6 = T1_i .* X_r5 .+ T1_r .* X_i5 + + # Reshape to 4D and permute for stage 2 + X_r7 = permutedims(reshape(X_r6, (F0, F2, F1, BS)), (3, 1, 2, 4)) # → (F1, F0, F2, BS) + X_i7 = permutedims(reshape(X_i6, (F0, F2, F1, BS)), (3, 1, 2, 4)) # --- Stage 2: F2-point DFT --- - # X is (F0F1, F2, BS), W2 is (F2, F2, BS) - X_r11 = X_r10 * W2_r - X_i10 * W2_i - X_i11 = X_r10 * W2_i + X_i10 * W2_r + # Merge (F1, F0) → F0F1; X: (F0F1, F2, BS) × W2: (F2, F2, 1) → (F0F1, F2, BS) + X_r8 = reshape(X_r7, (F0F1, F2, BS)) + X_i8 = reshape(X_i7, (F0F1, F2, BS)) + X_r9 = X_r8 * W2_r - X_i8 * W2_i + X_i9 = X_r8 * W2_i + X_i8 * W2_r - # --- Final Output --- - X_r_final = reshape(X_r11, (1, N, BS)) - X_i_final = reshape(X_i11, (1, N, BS)) + # --- Final permute --- + X_r10 = permutedims(reshape(X_r9, (F1, F0, F2, BS)), (2, 1, 3, 4)) # → (F0, F1, F2, BS) + X_i10 = permutedims(reshape(X_i9, (F1, F0, F2, BS)), (2, 1, 3, 4)) # --- Concatenate and Store --- - # Permute BS back to middle for memory layout (D, BS, N2D) - Y_ri = permutedims(reshape(ct.cat((X_r_final, X_i_final), 1), (D, N2D, BS)), (1, 3, 2)) - ct.store(y_packed_out; index=(1, bid, 1), tile=Y_ri) + X_r_final = reshape(X_r10, (1, N, BS)) + X_i_final = reshape(X_i10, (1, N, BS)) + Y_ri = reshape(ct.cat((X_r_final, X_i_final), 1), (D, N2D, BS)) + ct.store(y_packed_out; index=(1, 1, bid), tile=Y_ri) return end # Helper: Generate DFT matrix W_n^{ij} = exp(-2πi * ij / n) +# Stored as (2, size, size) — reversed from Python's (size, size, 2). +# DFT matrices are symmetric in i,j so no transpose needed. function dft_matrix(size::Int) W = zeros(ComplexF32, size, size) for i in 0:size-1, j in 0:size-1 W[i+1, j+1] = exp(-2π * im * i * j / size) end - result = zeros(Float32, size, size, 2) - result[:, :, 1] = Float32.(real.(W)) - result[:, :, 2] = Float32.(imag.(W)) + result = zeros(Float32, 2, size, size) + result[1, :, :] = Float32.(real.(W)) + result[2, :, :] = Float32.(imag.(W)) return result end -# Generate twiddle factors T0 for column-major layout (F1F2, F0) -# In Julia column-major, position (j, i) in (F1F2, F0) has linear index j + i*F1F2 -# This corresponds to Python's position (i, j) in (F0, F1F2) with linear index i*F1F2 + j -# The twiddle value is ω_N^{i * j} +# Generate twiddle factors T0: (2, F1F2, F0) — reversed from Python's (F0, F1F2, 2) function make_twiddles_T0(F0::Int, F1F2::Int, N::Int) - T0 = zeros(Float32, F1F2, F0, 2) - for j in 0:F1F2-1, i in 0:F0-1 + T0 = zeros(Float32, 2, F1F2, F0) + for i in 0:F0-1, j in 0:F1F2-1 val = exp(-2π * im * i * j / N) - T0[j+1, i+1, 1] = Float32(real(val)) - T0[j+1, i+1, 2] = Float32(imag(val)) + T0[1, j+1, i+1] = Float32(real(val)) + T0[2, j+1, i+1] = Float32(imag(val)) end return T0 end -# Generate twiddle factors T1 for column-major layout (F0F2, F1) -# After stage 0 and permute, data is in (F0F2, F1) layout -# The twiddle value is ω_{F1F2}^{j * k} where j is F1 index and k is F2 index within F0F2 -function make_twiddles_T1(F0::Int, F1::Int, F2::Int) - F0F2 = F0 * F2 - F1F2 = F1 * F2 - T1 = zeros(Float32, F0F2, F1, 2) - for k in 0:F0F2-1, j in 0:F1-1 - # k encodes (f0, f2) = (k ÷ F2, k % F2) after permute - f2 = k % F2 - val = exp(-2π * im * j * f2 / F1F2) - T1[k+1, j+1, 1] = Float32(real(val)) - T1[k+1, j+1, 2] = Float32(imag(val)) +# Generate twiddle factors T1: (2, F2, F1) — reversed from Python's (F1, F2, 2) +function make_twiddles_T1(F1::Int, F2::Int, F1F2::Int) + T1 = zeros(Float32, 2, F2, F1) + for j in 0:F1-1, k in 0:F2-1 + val = exp(-2π * im * j * k / F1F2) + T1[1, k+1, j+1] = Float32(real(val)) + T1[2, k+1, j+1] = Float32(imag(val)) end return T1 end -# Generate all W and T matrices for column-major algorithm +# Generate all W and T matrices function make_twiddles(factors::NTuple{3, Int}) F0, F1, F2 = factors N = F0 * F1 * F2 F1F2 = F1 * F2 - # DFT matrices (same for row/column-major since symmetric) W0 = dft_matrix(F0) W1 = dft_matrix(F1) W2 = dft_matrix(F2) - - # Column-major twiddle factors T0 = make_twiddles_T0(F0, F1F2, N) - T1 = make_twiddles_T1(F0, F1, F2) + T1 = make_twiddles_T1(F1, F2, F1F2) return (W0, W1, W2, T0, T1) end @@ -223,9 +199,10 @@ function prepare(; benchmark::Bool=false, @assert (n * 2) % atom_packing_dim == 0 "N*2 must be divisible by atom_packing_dim" CUDA.seed!(42) - input = CUDA.randn(ComplexF32, batch, n) + # Store as (n, batch) so reinterpret gives (2, n, batch) = (D, N2D, batch) + # This matches Python's (batch, N2D, D) row-major in memory. + input = CUDA.randn(ComplexF32, n, batch) - # Pre-compute twiddles (one-time CPU cost) W0, W1, W2, T0, T1 = make_twiddles(factors) W0_gpu = CuArray(W0) W1_gpu = CuArray(W1) @@ -233,11 +210,10 @@ function prepare(; benchmark::Bool=false, T0_gpu = CuArray(T0) T1_gpu = CuArray(T1) - # Pack input D = atom_packing_dim N2D = n * 2 ÷ D - x_packed = reinterpret(reshape, Float32, input) - y_packed = CuArray{Float32}(undef, D, batch, N2D) + x_packed = reinterpret(reshape, Float32, input) # (2, n, batch) = (D, N2D, batch) + y_packed = CuArray{Float32}(undef, D, N2D, batch) return (; input, x_packed, y_packed, @@ -276,7 +252,7 @@ function run(data; nruns::Int=1, warmup::Int=0) push!(times, t * 1000) # ms end - # Unpack output + # Unpack output: (2, n, batch) → ComplexF32(n, batch) y_complex = reinterpret(reshape, ComplexF32, y_packed) output = copy(y_complex) @@ -284,7 +260,8 @@ function run(data; nruns::Int=1, warmup::Int=0) end function verify(data, result) - reference = FFTW.fft(Array(data.input), 2) + # FFT along dim 1 (the n dimension) + reference = FFTW.fft(Array(data.input), 1) @assert isapprox(Array(result.output), reference, rtol=1e-4) end @@ -296,16 +273,13 @@ function run_others(data; nruns::Int=1, warmup::Int=0) (; input, batch, n) = data results = Dict{String, Vector{Float64}}() - output_cufft = similar(input) - - # CUFFT via CUDA.CUFFT CUDA.@sync for _ in 1:warmup - CUDA.CUFFT.fft!(copy(input), 2) + CUDA.CUFFT.fft!(copy(input), 1) end times_cufft = Float64[] for _ in 1:nruns input_copy = copy(input) - t = CUDA.@elapsed CUDA.CUFFT.fft!(input_copy, 2) + t = CUDA.@elapsed CUDA.CUFFT.fft!(input_copy, 1) push!(times_cufft, t * 1000) end results["cuFFT"] = times_cufft @@ -320,7 +294,6 @@ end function main() println("--- Running cuTile FFT Example ---") - # Configuration BATCH_SIZE = 2 FFT_SIZE = 8 FFT_FACTORS = (2, 2, 2) @@ -332,7 +305,6 @@ function main() println(" FFT Factors: $FFT_FACTORS") println(" Atom Packing Dim: $ATOM_PACKING_DIM") - # Use prepare/run/verify pattern data = prepare(; batch=BATCH_SIZE, n=FFT_SIZE, factors=FFT_FACTORS, atom_packing_dim=ATOM_PACKING_DIM) println("\nInput data shape: $(size(data.input)), dtype: $(eltype(data.input))") From 96c765be549e4c21d31cad870cc49b9b4228661e Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 25 Mar 2026 22:30:29 +0100 Subject: [PATCH 4/4] Update README. --- README.md | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index e92c0bc4..0e29f0d6 100644 --- a/README.md +++ b/README.md @@ -96,16 +96,16 @@ Benchmarks comparing cuTile.jl against cuTile Python on an RTX 5080: | Kernel | Julia | Python | Status | |--------|-------|--------|--------| -| Vector Addition | 813 GB/s | 834 GB/s | OK (-3%) | -| Matrix Transpose | 769 GB/s | 795 GB/s | OK (-3%) | -| Matrix Multiplication | 48.3 TFLOPS | 48.6 TFLOPS | OK (=) | -| Layer Normalization | 254 GB/s | 683 GB/s | https://github.com/JuliaGPU/cuTile.jl/issues/1 (-63%) | -| Batch Matrix Multiply | 31.7 TFLOPS | 31.6 TFLOPS | OK (=) | -| FFT (3-stage Cooley-Tukey) | 508 μs | 230 μs | (-55%) | - -Compute-intensive kernels (matmul, batch matmul) perform identically to Python. Memory-bound -kernels (vadd, transpose) are within ~3% of Python. The layernorm kernel is slower due to -conservative token threading in the compiler (see https://github.com/JuliaGPU/cuTile.jl/issues/1). +| Vector Addition | 840 GB/s | 844 GB/s | OK (=) | +| Matrix Transpose | 806 GB/s | 816 GB/s | OK (-1%) | +| Layer Normalization | 1074 GB/s | 761 GB/s | OK (+41%) | +| Matrix Multiplication | 36.8 TFLOPS | 50.7 TFLOPS | -27% | +| Batch Matrix Multiply | 28.3 TFLOPS | 40.0 TFLOPS | -29% | +| FFT (3-stage Cooley-Tukey) | 571 μs | 192 μs | -66% | + +Memory-bound kernels (vadd, transpose, layernorm) match or beat Python. Compute-intensive +kernels (matmul, batch matmul, FFT) are slower due to conservative token threading in the +generated Tile IR, which serializes loads that could otherwise be pipelined. ## Supported Operations