diff --git a/README.md b/README.md index 0e29f0d6..df020bf5 100644 --- a/README.md +++ b/README.md @@ -96,16 +96,12 @@ Benchmarks comparing cuTile.jl against cuTile Python on an RTX 5080: | Kernel | Julia | Python | Status | |--------|-------|--------|--------| -| Vector Addition | 840 GB/s | 844 GB/s | OK (=) | -| Matrix Transpose | 806 GB/s | 816 GB/s | OK (-1%) | -| Layer Normalization | 1074 GB/s | 761 GB/s | OK (+41%) | -| Matrix Multiplication | 36.8 TFLOPS | 50.7 TFLOPS | -27% | -| Batch Matrix Multiply | 28.3 TFLOPS | 40.0 TFLOPS | -29% | -| FFT (3-stage Cooley-Tukey) | 571 μs | 192 μs | -66% | - -Memory-bound kernels (vadd, transpose, layernorm) match or beat Python. Compute-intensive -kernels (matmul, batch matmul, FFT) are slower due to conservative token threading in the -generated Tile IR, which serializes loads that could otherwise be pipelined. +| Vector Addition | 841 GB/s | 847 GB/s | OK (=) | +| Matrix Transpose | 807 GB/s | 813 GB/s | OK (-1%) | +| Layer Normalization | 653 GB/s | 758 GB/s | -14% | +| Matrix Multiplication | 43.1 TFLOPS | 50.3 TFLOPS | -14% | +| Batch Matrix Multiply | 30.4 TFLOPS | 40.0 TFLOPS | -24% | +| FFT (3-stage Cooley-Tukey) | 620 μs | 486 μs | -28% | ## Supported Operations diff --git a/examples/fft.jl b/examples/fft.jl index 80b75feb..6b6de2fd 100644 --- a/examples/fft.jl +++ b/examples/fft.jl @@ -52,7 +52,7 @@ function fft_kernel( # --- Load Input Data --- # Input is (D, N2D, BS). Load and reshape to (2, N, BS). - X_ri = reshape(ct.load(x_packed_in; index=(1, 1, bid), shape=(D, N2D, BS)), (2, N, BS)) + X_ri = reshape(ct.load(x_packed_in; index=(Int32(1), Int32(1), bid), shape=(D, N2D, BS)), (2, N, BS)) # Split real and imaginary parts, reshape to 4D factored form X_r = reshape(ct.extract(X_ri, (1, 1, 1), (1, N, BS)), (F2, F1, F0, BS)) @@ -130,7 +130,7 @@ function fft_kernel( X_r_final = reshape(X_r10, (1, N, BS)) X_i_final = reshape(X_i10, (1, N, BS)) Y_ri = reshape(ct.cat((X_r_final, X_i_final), 1), (D, N2D, BS)) - ct.store(y_packed_out; index=(1, 1, bid), tile=Y_ri) + ct.store(y_packed_out; index=(Int32(1), Int32(1), bid), tile=Y_ri) return end @@ -192,15 +192,12 @@ end function prepare(; benchmark::Bool=false, batch::Int=benchmark ? 64 : 2, - n::Int=benchmark ? 512 : 8, factors::NTuple{3,Int}=benchmark ? (8, 8, 8) : (2, 2, 2), - atom_packing_dim::Int=2) - @assert factors[1] * factors[2] * factors[3] == n "Factors must multiply to N" + atom_packing_dim::Int=min(64, 2 * prod(factors))) + n = prod(factors) @assert (n * 2) % atom_packing_dim == 0 "N*2 must be divisible by atom_packing_dim" CUDA.seed!(42) - # Store as (n, batch) so reinterpret gives (2, n, batch) = (D, N2D, batch) - # This matches Python's (batch, N2D, D) row-major in memory. input = CUDA.randn(ComplexF32, n, batch) W0, W1, W2, T0, T1 = make_twiddles(factors) @@ -212,7 +209,10 @@ function prepare(; benchmark::Bool=false, D = atom_packing_dim N2D = n * 2 ÷ D - x_packed = reinterpret(reshape, Float32, input) # (2, n, batch) = (D, N2D, batch) + # Pack complex input as (D, N2D, batch) Float32 — matches Python's (batch, N2D, D) row-major. + # When D=2, reinterpret gives (2, n, batch) directly. For D>2, reshape the flat layout. + x_ri = reinterpret(reshape, Float32, input) # (2, n, batch) + x_packed = D == 2 ? x_ri : reshape(x_ri, D, N2D, batch) y_packed = CuArray{Float32}(undef, D, N2D, batch) return (; @@ -252,8 +252,9 @@ function run(data; nruns::Int=1, warmup::Int=0) push!(times, t * 1000) # ms end - # Unpack output: (2, n, batch) → ComplexF32(n, batch) - y_complex = reinterpret(reshape, ComplexF32, y_packed) + # Unpack output: (D, N2D, batch) → (2, n, batch) → ComplexF32(n, batch) + y_ri = D == 2 ? y_packed : reshape(y_packed, 2, n, batch) + y_complex = reinterpret(reshape, ComplexF32, y_ri) output = copy(y_complex) return (; output, times) @@ -294,18 +295,12 @@ end function main() println("--- Running cuTile FFT Example ---") - BATCH_SIZE = 2 - FFT_SIZE = 8 - FFT_FACTORS = (2, 2, 2) - ATOM_PACKING_DIM = 2 - + data = prepare() println(" Configuration:") - println(" FFT Size (N): $FFT_SIZE") - println(" Batch Size: $BATCH_SIZE") - println(" FFT Factors: $FFT_FACTORS") - println(" Atom Packing Dim: $ATOM_PACKING_DIM") - - data = prepare(; batch=BATCH_SIZE, n=FFT_SIZE, factors=FFT_FACTORS, atom_packing_dim=ATOM_PACKING_DIM) + println(" FFT Size (N): $(data.n)") + println(" Batch Size: $(data.batch)") + println(" FFT Factors: $(data.factors)") + println(" Atom Packing Dim: $(data.D)") println("\nInput data shape: $(size(data.input)), dtype: $(eltype(data.input))") result = run(data) diff --git a/examples/fft.py b/examples/fft.py index 8c60ab6d..369c9299 100644 --- a/examples/fft.py +++ b/examples/fft.py @@ -111,7 +111,7 @@ def fft_make_twiddles(factors, precision, device): # Example harness #============================================================================= -def prepare(*, benchmark: bool = False, batch: int = None, size: int = None, factors: tuple = None, atom_packing_dim: int = 2): +def prepare(*, benchmark: bool = False, batch: int = None, factors: tuple = None, atom_packing_dim: int = None): """Allocate and initialize data for FFT.""" if batch is None: batch = 64 if benchmark else 2 @@ -119,10 +119,7 @@ def prepare(*, benchmark: bool = False, batch: int = None, size: int = None, fac factors = (8, 8, 8) if benchmark else (2, 2, 2) F0, F1, F2 = factors N = F0 * F1 * F2 - if size is None: - size = N - assert size == N, f"size ({size}) must equal product of factors ({N})" - D = atom_packing_dim + D = min(64, N * 2) if atom_packing_dim is None else atom_packing_dim input_data = torch.randn(batch, N, dtype=torch.complex64, device='cuda') @@ -218,11 +215,12 @@ def run_others(data, *, nruns: int = 1, warmup: int = 0): # Main #============================================================================= -def test_fft(batch, size, factors, name=None): +def test_fft(batch, factors, name=None): """Test FFT with given parameters.""" + size = factors[0] * factors[1] * factors[2] name = name or f"fft batch={batch}, size={size}, factors={factors}" print(f"--- {name} ---") - data = prepare(batch=batch, size=size, factors=factors) + data = prepare(batch=batch, factors=factors) result = run(data) verify(data, result) print(" passed") @@ -231,8 +229,8 @@ def test_fft(batch, size, factors, name=None): def main(): print("--- cuTile FFT Examples ---\n") - test_fft(64, 512, (8, 8, 8)) - test_fft(32, 512, (8, 8, 8)) + test_fft(64, (8, 8, 8)) + test_fft(32, (8, 8, 8)) print("\n--- All FFT examples completed ---")