Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,12 @@ Benchmarks comparing cuTile.jl against cuTile Python on an RTX 5080:

| Kernel | Julia | Python | Status |
|--------|-------|--------|--------|
| Vector Addition | 840 GB/s | 844 GB/s | OK (=) |
| Matrix Transpose | 806 GB/s | 816 GB/s | OK (-1%) |
| Layer Normalization | 1074 GB/s | 761 GB/s | OK (+41%) |
| Matrix Multiplication | 36.8 TFLOPS | 50.7 TFLOPS | -27% |
| Batch Matrix Multiply | 28.3 TFLOPS | 40.0 TFLOPS | -29% |
| FFT (3-stage Cooley-Tukey) | 571 μs | 192 μs | -66% |

Memory-bound kernels (vadd, transpose, layernorm) match or beat Python. Compute-intensive
kernels (matmul, batch matmul, FFT) are slower due to conservative token threading in the
generated Tile IR, which serializes loads that could otherwise be pipelined.
| Vector Addition | 841 GB/s | 847 GB/s | OK (=) |
| Matrix Transpose | 807 GB/s | 813 GB/s | OK (-1%) |
| Layer Normalization | 653 GB/s | 758 GB/s | -14% |
| Matrix Multiplication | 43.1 TFLOPS | 50.3 TFLOPS | -14% |
| Batch Matrix Multiply | 30.4 TFLOPS | 40.0 TFLOPS | -24% |
| FFT (3-stage Cooley-Tukey) | 620 μs | 486 μs | -28% |


## Supported Operations
Expand Down
37 changes: 16 additions & 21 deletions examples/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ function fft_kernel(

# --- Load Input Data ---
# Input is (D, N2D, BS). Load and reshape to (2, N, BS).
X_ri = reshape(ct.load(x_packed_in; index=(1, 1, bid), shape=(D, N2D, BS)), (2, N, BS))
X_ri = reshape(ct.load(x_packed_in; index=(Int32(1), Int32(1), bid), shape=(D, N2D, BS)), (2, N, BS))

# Split real and imaginary parts, reshape to 4D factored form
X_r = reshape(ct.extract(X_ri, (1, 1, 1), (1, N, BS)), (F2, F1, F0, BS))
Expand Down Expand Up @@ -130,7 +130,7 @@ function fft_kernel(
X_r_final = reshape(X_r10, (1, N, BS))
X_i_final = reshape(X_i10, (1, N, BS))
Y_ri = reshape(ct.cat((X_r_final, X_i_final), 1), (D, N2D, BS))
ct.store(y_packed_out; index=(1, 1, bid), tile=Y_ri)
ct.store(y_packed_out; index=(Int32(1), Int32(1), bid), tile=Y_ri)

return
end
Expand Down Expand Up @@ -192,15 +192,12 @@ end

function prepare(; benchmark::Bool=false,
batch::Int=benchmark ? 64 : 2,
n::Int=benchmark ? 512 : 8,
factors::NTuple{3,Int}=benchmark ? (8, 8, 8) : (2, 2, 2),
atom_packing_dim::Int=2)
@assert factors[1] * factors[2] * factors[3] == n "Factors must multiply to N"
atom_packing_dim::Int=min(64, 2 * prod(factors)))
n = prod(factors)
@assert (n * 2) % atom_packing_dim == 0 "N*2 must be divisible by atom_packing_dim"

CUDA.seed!(42)
# Store as (n, batch) so reinterpret gives (2, n, batch) = (D, N2D, batch)
# This matches Python's (batch, N2D, D) row-major in memory.
input = CUDA.randn(ComplexF32, n, batch)

W0, W1, W2, T0, T1 = make_twiddles(factors)
Expand All @@ -212,7 +209,10 @@ function prepare(; benchmark::Bool=false,

D = atom_packing_dim
N2D = n * 2 ÷ D
x_packed = reinterpret(reshape, Float32, input) # (2, n, batch) = (D, N2D, batch)
# Pack complex input as (D, N2D, batch) Float32 — matches Python's (batch, N2D, D) row-major.
# When D=2, reinterpret gives (2, n, batch) directly. For D>2, reshape the flat layout.
x_ri = reinterpret(reshape, Float32, input) # (2, n, batch)
x_packed = D == 2 ? x_ri : reshape(x_ri, D, N2D, batch)
y_packed = CuArray{Float32}(undef, D, N2D, batch)

return (;
Expand Down Expand Up @@ -252,8 +252,9 @@ function run(data; nruns::Int=1, warmup::Int=0)
push!(times, t * 1000) # ms
end

# Unpack output: (2, n, batch) → ComplexF32(n, batch)
y_complex = reinterpret(reshape, ComplexF32, y_packed)
# Unpack output: (D, N2D, batch) → (2, n, batch) → ComplexF32(n, batch)
y_ri = D == 2 ? y_packed : reshape(y_packed, 2, n, batch)
y_complex = reinterpret(reshape, ComplexF32, y_ri)
output = copy(y_complex)

return (; output, times)
Expand Down Expand Up @@ -294,18 +295,12 @@ end
function main()
println("--- Running cuTile FFT Example ---")

BATCH_SIZE = 2
FFT_SIZE = 8
FFT_FACTORS = (2, 2, 2)
ATOM_PACKING_DIM = 2

data = prepare()
println(" Configuration:")
println(" FFT Size (N): $FFT_SIZE")
println(" Batch Size: $BATCH_SIZE")
println(" FFT Factors: $FFT_FACTORS")
println(" Atom Packing Dim: $ATOM_PACKING_DIM")

data = prepare(; batch=BATCH_SIZE, n=FFT_SIZE, factors=FFT_FACTORS, atom_packing_dim=ATOM_PACKING_DIM)
println(" FFT Size (N): $(data.n)")
println(" Batch Size: $(data.batch)")
println(" FFT Factors: $(data.factors)")
println(" Atom Packing Dim: $(data.D)")
println("\nInput data shape: $(size(data.input)), dtype: $(eltype(data.input))")

result = run(data)
Expand Down
16 changes: 7 additions & 9 deletions examples/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,18 +111,15 @@ def fft_make_twiddles(factors, precision, device):
# Example harness
#=============================================================================

def prepare(*, benchmark: bool = False, batch: int = None, size: int = None, factors: tuple = None, atom_packing_dim: int = 2):
def prepare(*, benchmark: bool = False, batch: int = None, factors: tuple = None, atom_packing_dim: int = None):
"""Allocate and initialize data for FFT."""
if batch is None:
batch = 64 if benchmark else 2
if factors is None:
factors = (8, 8, 8) if benchmark else (2, 2, 2)
F0, F1, F2 = factors
N = F0 * F1 * F2
if size is None:
size = N
assert size == N, f"size ({size}) must equal product of factors ({N})"
D = atom_packing_dim
D = min(64, N * 2) if atom_packing_dim is None else atom_packing_dim

input_data = torch.randn(batch, N, dtype=torch.complex64, device='cuda')

Expand Down Expand Up @@ -218,11 +215,12 @@ def run_others(data, *, nruns: int = 1, warmup: int = 0):
# Main
#=============================================================================

def test_fft(batch, size, factors, name=None):
def test_fft(batch, factors, name=None):
"""Test FFT with given parameters."""
size = factors[0] * factors[1] * factors[2]
name = name or f"fft batch={batch}, size={size}, factors={factors}"
print(f"--- {name} ---")
data = prepare(batch=batch, size=size, factors=factors)
data = prepare(batch=batch, factors=factors)
result = run(data)
verify(data, result)
print(" passed")
Expand All @@ -231,8 +229,8 @@ def test_fft(batch, size, factors, name=None):
def main():
print("--- cuTile FFT Examples ---\n")

test_fft(64, 512, (8, 8, 8))
test_fft(32, 512, (8, 8, 8))
test_fft(64, (8, 8, 8))
test_fft(32, (8, 8, 8))

print("\n--- All FFT examples completed ---")

Expand Down
Loading