From 0b2d7c58bdc92ffcaeb65529a6369eb9d40bfb67 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Mon, 23 Feb 2026 11:35:15 +0100 Subject: [PATCH 1/6] Add tile-indexed methods for existing atomic operations --- src/compiler/intrinsics/atomics.jl | 116 ++++++++++++++++++++++ src/language/atomics.jl | 147 ++++++++++++++++++++++++++++ test/codegen/operations.jl | 58 +++++++++++ test/execution/atomics.jl | 148 +++++++++++++++++++++++++++++ 4 files changed, 469 insertions(+) diff --git a/src/compiler/intrinsics/atomics.jl b/src/compiler/intrinsics/atomics.jl index 9c480bf8..2a845391 100644 --- a/src/compiler/intrinsics/atomics.jl +++ b/src/compiler/intrinsics/atomics.jl @@ -177,3 +177,119 @@ efunc(::typeof(Intrinsics.atomic_add), effects::CC.Effects) = function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_add), args) emit_atomic_rmw!(ctx, args, AtomicADD) end + +# ============================================================================ +# Tile-indexed atomic operations +# These take pre-computed pointer tiles, value tiles, and masks. +# Used by the public API for tile-indexed atomic operations. +# ============================================================================ + +# Shared codegen helper for tile-indexed atomic RMW operations +function emit_atomic_rmw_tile!(ctx::CGCtx, args::AbstractVector, mode::AtomicRMWMode) + cb = ctx.cb + tt = ctx.tt + + # args: (ptr_tile, val, mask, memory_order, memory_scope) + ptr_tv = emit_value!(ctx, args[1]) + ptr_tv === nothing && throw(IRError("tile-indexed atomic RMW requires ptr_tile")) + val_tv = emit_value!(ctx, args[2]) + val_tv === nothing && throw(IRError("tile-indexed atomic RMW requires value")) + mask_tv = emit_value!(ctx, args[3]) + mask_tv === nothing && throw(IRError("tile-indexed atomic RMW requires mask")) + + memory_order = @something get_constant(ctx, args[4]) throw(IRError("tile-indexed atomic RMW requires constant memory_order")) + memory_scope = @something get_constant(ctx, args[5]) throw(IRError("tile-indexed atomic RMW requires constant memory_scope")) + + shape = val_tv.shape + elem_type = eltype(val_tv.jltype) + + dtype = julia_to_tile_dtype!(tt, elem_type) + result_tile_type = tile_type!(tt, dtype, collect(shape)) + token_type = Token(tt) + + # Auto-promote integer ADD to float ADD for floating-point types + actual_mode = mode + if mode == AtomicADD && elem_type <: AbstractFloat + actual_mode = AtomicADDF + end + + mem_ordering = memory_order_to_semantics(memory_order) + mem_scope = memory_scope_to_scope(memory_scope) + + old_val, new_token = encode_AtomicRMWPtrOp!(cb, result_tile_type, token_type, + ptr_tv.v, val_tv.v, actual_mode; + mask=mask_tv.v, + token=ctx.token, + memory_ordering=mem_ordering, + memory_scope=mem_scope) + ctx.token = new_token + + CGVal(old_val, result_tile_type, Tile{elem_type, Tuple{shape...}}, collect(shape)) +end + +# Tile-indexed atomic exchange +@intrinsic atomic_xchg_tile(ptr_tile, val, mask, memory_order, memory_scope) +function tfunc(𝕃, ::typeof(Intrinsics.atomic_xchg_tile), @nospecialize(ptrs), @nospecialize(val), @nospecialize args...) 
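+    # Result type: the (widened) Julia type of the value tile operand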
+ CC.widenconst(val) +end +efunc(::typeof(Intrinsics.atomic_xchg_tile), effects::CC.Effects) = + CC.Effects(effects; effect_free=CC.ALWAYS_FALSE) +function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_xchg_tile), args) + emit_atomic_rmw_tile!(ctx, args, AtomicXCHG) +end + +# Tile-indexed atomic addition +@intrinsic atomic_add_tile(ptr_tile, val, mask, memory_order, memory_scope) +function tfunc(𝕃, ::typeof(Intrinsics.atomic_add_tile), @nospecialize(ptrs), @nospecialize(val), @nospecialize args...) + CC.widenconst(val) +end +efunc(::typeof(Intrinsics.atomic_add_tile), effects::CC.Effects) = + CC.Effects(effects; effect_free=CC.ALWAYS_FALSE) +function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_add_tile), args) + emit_atomic_rmw_tile!(ctx, args, AtomicADD) +end + +# Tile-indexed atomic compare-and-swap +@intrinsic atomic_cas_tile(ptr_tile, expected, desired, mask, memory_order, memory_scope) +function tfunc(𝕃, ::typeof(Intrinsics.atomic_cas_tile), @nospecialize(ptrs), @nospecialize(expected), @nospecialize args...) + CC.widenconst(expected) +end +efunc(::typeof(Intrinsics.atomic_cas_tile), effects::CC.Effects) = + CC.Effects(effects; effect_free=CC.ALWAYS_FALSE) +function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_cas_tile), args) + cb = ctx.cb + tt = ctx.tt + + # args: (ptr_tile, expected, desired, mask, memory_order, memory_scope) + ptr_tv = emit_value!(ctx, args[1]) + ptr_tv === nothing && throw(IRError("tile-indexed atomic CAS requires ptr_tile")) + expected_tv = emit_value!(ctx, args[2]) + expected_tv === nothing && throw(IRError("tile-indexed atomic CAS requires expected value")) + desired_tv = emit_value!(ctx, args[3]) + desired_tv === nothing && throw(IRError("tile-indexed atomic CAS requires desired value")) + mask_tv = emit_value!(ctx, args[4]) + mask_tv === nothing && throw(IRError("tile-indexed atomic CAS requires mask")) + + memory_order = @something get_constant(ctx, args[5]) throw(IRError("tile-indexed atomic CAS requires constant memory_order")) + memory_scope = @something get_constant(ctx, args[6]) throw(IRError("tile-indexed atomic CAS requires constant memory_scope")) + + shape = expected_tv.shape + elem_type = eltype(expected_tv.jltype) + + dtype = julia_to_tile_dtype!(tt, elem_type) + result_tile_type = tile_type!(tt, dtype, collect(shape)) + token_type = Token(tt) + + mem_ordering = memory_order_to_semantics(memory_order) + mem_scope = memory_scope_to_scope(memory_scope) + + old_val, new_token = encode_AtomicCASPtrOp!(cb, result_tile_type, token_type, + ptr_tv.v, expected_tv.v, desired_tv.v; + mask=mask_tv.v, + token=ctx.token, + memory_ordering=mem_ordering, + memory_scope=mem_scope) + ctx.token = new_token + + CGVal(old_val, result_tile_type, Tile{elem_type, Tuple{shape...}}, collect(shape)) +end diff --git a/src/language/atomics.jl b/src/language/atomics.jl index 5405449a..c3ee8393 100644 --- a/src/language/atomics.jl +++ b/src/language/atomics.jl @@ -80,3 +80,150 @@ old_val = ct.atomic_add(counters, idx, Int32(1)) memory_scope::Int=MemScope.Device) where {T} Intrinsics.atomic_add(array, index - One(), val, memory_order, memory_scope) end + +# ============================================================================ +# Tile-indexed atomic operations (scatter-gather style indexing) +# These accept Tile indices to perform atomic operations on multiple elements. 
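+#
+# Usage sketch (mirroring the execution tests added in this patch; any
+# out-of-bounds indices are masked off):
+#
+#   indices = ct.arange((16,), Int)
+#   old_vals = ct.atomic_add(arr, indices, 1)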
+# ============================================================================ + +# --- Pointer/mask helpers (same pattern as gather/scatter in operations.jl) --- + +@inline function _atomic_ptrs_mask(array::TileArray{T, 1}, indices::Tile{I}) where {T, I <: Integer} + indices_0 = indices .- one(I) + indices_i32 = convert(Tile{Int32}, indices_0) + ptr_tile = Intrinsics.offset(array.ptr, indices_i32) + zero_0d = Tile(Int32(0)) + size_0d = Tile(size(array, 1)) + mask = (indices_i32 .>= zero_0d) .& (indices_i32 .< size_0d) + (ptr_tile, mask, size(indices)) +end + +@inline function _atomic_ptrs_mask(array::TileArray{T, 2}, + indices::Tuple{Tile{I0}, Tile{I1}}) where {T, I0 <: Integer, I1 <: Integer} + idx0_0 = indices[1] .- one(I0) + idx1_0 = indices[2] .- one(I1) + + S = broadcast_shape(size(indices[1]), size(indices[2])) + idx0_bc = broadcast_to(idx0_0, S) + idx1_bc = broadcast_to(idx1_0, S) + + idx0_i32 = convert(Tile{Int32}, idx0_bc) + idx1_i32 = convert(Tile{Int32}, idx1_bc) + + stride0_0d = Tile(array.strides[1]) + stride1_0d = Tile(array.strides[2]) + stride0 = broadcast_to(stride0_0d, S) + stride1 = broadcast_to(stride1_0d, S) + + linear_idx = idx0_i32 .* stride0 + idx1_i32 .* stride1 + ptr_tile = Intrinsics.offset(array.ptr, linear_idx) + + zero_0d = Tile(Int32(0)) + zero_bc = broadcast_to(zero_0d, S) + size0_bc = broadcast_to(Tile(size(array, 1)), S) + size1_bc = broadcast_to(Tile(size(array, 2)), S) + + mask0 = (idx0_i32 .>= zero_bc) .& (idx0_i32 .< size0_bc) + mask1 = (idx1_i32 .>= zero_bc) .& (idx1_i32 .< size1_bc) + mask = mask0 .& mask1 + + (ptr_tile, mask, S) +end + +# --- RMW operations (atomic_add, atomic_xchg) --- + +const _ATOMIC_RMW_OPS = ( + (:add, :atomic_add_tile), + (:xchg, :atomic_xchg_tile), +) + +for (op, intrinsic) in _ATOMIC_RMW_OPS + fname = Symbol(:atomic_, op) + + # 1D with scalar value + @eval @inline function $fname(array::TileArray{T, 1}, indices::Tile{I}, val::T; + memory_order::Int=MemoryOrder.AcqRel, + memory_scope::Int=MemScope.Device) where {T, I <: Integer} + ptr_tile, mask, S = _atomic_ptrs_mask(array, indices) + val_tile = broadcast_to(Tile(val), S) + Intrinsics.$intrinsic(ptr_tile, val_tile, mask, memory_order, memory_scope) + end + + # 1D with tile value + @eval @inline function $fname(array::TileArray{T, 1}, indices::Tile{I}, val::Tile{T}; + memory_order::Int=MemoryOrder.AcqRel, + memory_scope::Int=MemScope.Device) where {T, I <: Integer} + ptr_tile, mask, _ = _atomic_ptrs_mask(array, indices) + Intrinsics.$intrinsic(ptr_tile, val, mask, memory_order, memory_scope) + end + + # 2D with scalar value + @eval @inline function $fname(array::TileArray{T, 2}, + indices::Tuple{Tile{I0}, Tile{I1}}, val::T; + memory_order::Int=MemoryOrder.AcqRel, + memory_scope::Int=MemScope.Device) where {T, I0 <: Integer, I1 <: Integer} + ptr_tile, mask, S = _atomic_ptrs_mask(array, indices) + val_tile = broadcast_to(Tile(val), S) + Intrinsics.$intrinsic(ptr_tile, val_tile, mask, memory_order, memory_scope) + end + + # 2D with tile value + @eval @inline function $fname(array::TileArray{T, 2}, + indices::Tuple{Tile{I0}, Tile{I1}}, val::Tile{T}; + memory_order::Int=MemoryOrder.AcqRel, + memory_scope::Int=MemScope.Device) where {T, I0 <: Integer, I1 <: Integer} + ptr_tile, mask, S = _atomic_ptrs_mask(array, indices) + val_bc = broadcast_to(val, S) + Intrinsics.$intrinsic(ptr_tile, val_bc, mask, memory_order, memory_scope) + end +end + +# --- CAS operations (separate due to different signature) --- + +# 1D with scalar expected/desired +@inline function 
atomic_cas(array::TileArray{T, 1}, indices::Tile{I}, + expected::T, desired::T; + memory_order::Int=MemoryOrder.AcqRel, + memory_scope::Int=MemScope.Device) where {T, I <: Integer} + ptr_tile, mask, S = _atomic_ptrs_mask(array, indices) + expected_tile = broadcast_to(Tile(expected), S) + desired_tile = broadcast_to(Tile(desired), S) + Intrinsics.atomic_cas_tile(ptr_tile, expected_tile, desired_tile, mask, + memory_order, memory_scope) +end + +# 1D with tile expected/desired +@inline function atomic_cas(array::TileArray{T, 1}, indices::Tile{I}, + expected::Tile{T}, desired::Tile{T}; + memory_order::Int=MemoryOrder.AcqRel, + memory_scope::Int=MemScope.Device) where {T, I <: Integer} + ptr_tile, mask, _ = _atomic_ptrs_mask(array, indices) + Intrinsics.atomic_cas_tile(ptr_tile, expected, desired, mask, + memory_order, memory_scope) +end + +# 2D with scalar expected/desired +@inline function atomic_cas(array::TileArray{T, 2}, + indices::Tuple{Tile{I0}, Tile{I1}}, + expected::T, desired::T; + memory_order::Int=MemoryOrder.AcqRel, + memory_scope::Int=MemScope.Device) where {T, I0 <: Integer, I1 <: Integer} + ptr_tile, mask, S = _atomic_ptrs_mask(array, indices) + expected_tile = broadcast_to(Tile(expected), S) + desired_tile = broadcast_to(Tile(desired), S) + Intrinsics.atomic_cas_tile(ptr_tile, expected_tile, desired_tile, mask, + memory_order, memory_scope) +end + +# 2D with tile expected/desired +@inline function atomic_cas(array::TileArray{T, 2}, + indices::Tuple{Tile{I0}, Tile{I1}}, + expected::Tile{T}, desired::Tile{T}; + memory_order::Int=MemoryOrder.AcqRel, + memory_scope::Int=MemScope.Device) where {T, I0 <: Integer, I1 <: Integer} + ptr_tile, mask, S = _atomic_ptrs_mask(array, indices) + expected_bc = broadcast_to(expected, S) + desired_bc = broadcast_to(desired, S) + Intrinsics.atomic_cas_tile(ptr_tile, expected_bc, desired_bc, mask, + memory_order, memory_scope) +end diff --git a/test/codegen/operations.jl b/test/codegen/operations.jl index 8da55a90..57a2f31d 100644 --- a/test/codegen/operations.jl +++ b/test/codegen/operations.jl @@ -1418,6 +1418,64 @@ end end end + + @testset "tile-indexed atomic_cas_tko" begin + spec = ct.ArraySpec{1}(16, true) + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Int32,1,spec}}) do arr + @check "iota" + indices = ct.arange((16,), Int) + @check "offset" + @check "atomic_cas_tko" + ct.atomic_cas(arr, indices, Int32(0), Int32(1)) + return + end + end + end + + @testset "tile-indexed atomic_rmw_tko" begin + spec = ct.ArraySpec{1}(16, true) + # xchg + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Int32,1,spec}}) do arr + @check "iota" + indices = ct.arange((16,), Int) + @check "offset" + @check "atomic_rmw_tko" + ct.atomic_xchg(arr, indices, Int32(42)) + return + end + end + + # add (integer) + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Int32,1,spec}}) do arr + @check "iota" + indices = ct.arange((16,), Int) + @check "offset" + @check "atomic_rmw_tko" + ct.atomic_add(arr, indices, Int32(1)) + return + end + end + + # add (float) + spec_f32 = ct.ArraySpec{1}(16, true) + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float32,1,spec_f32}}) do arr + @check "iota" + indices = ct.arange((16,), Int) + @check "offset" + @check "atomic_rmw_tko" + ct.atomic_add(arr, indices, 1.5f0) + return + end + end + end end #========================================================================= diff --git a/test/execution/atomics.jl 
b/test/execution/atomics.jl index 81ffe193..60c93d42 100644 --- a/test/execution/atomics.jl +++ b/test/execution/atomics.jl @@ -166,6 +166,154 @@ end @test result == n_blocks end +# ============================================================================ +# Tile-indexed atomic operations (scatter-gather style indexing) +# ============================================================================ + +@testset "atomic_add tile-indexed 1D" begin + function atomic_add_tile_kernel(arr::ct.TileArray{Int,1}, TILE::Int) + bid = ct.bid(1) + base = (bid - 1) * TILE + indices = base .+ ct.arange((TILE,), Int) + ct.atomic_add(arr, indices, 1; + memory_order=ct.MemoryOrder.AcqRel) + return + end + + tile_size = 16 + n = 256 + n_blocks = div(n, tile_size) + arr = CUDA.zeros(Int, n) + + ct.launch(atomic_add_tile_kernel, n_blocks, arr, ct.Constant(tile_size)) + + @test all(Array(arr) .== 1) +end + +@testset "atomic_add tile-indexed returns old values" begin + function atomic_add_return_kernel(arr::ct.TileArray{Int,1}, out::ct.TileArray{Int,1}) + indices = ct.arange((16,), Int) + old_vals = ct.atomic_add(arr, indices, 1; + memory_order=ct.MemoryOrder.AcqRel) + ct.scatter(out, indices, old_vals) + return + end + + arr = CUDA.zeros(Int, 16) + out = CUDA.fill(Int(-1), 16) + + ct.launch(atomic_add_return_kernel, 1, arr, out) + + @test all(Array(out) .== 0) + @test all(Array(arr) .== 1) +end + +@testset "atomic_add tile-indexed Float32" begin + function atomic_add_f32_tile_kernel(arr::ct.TileArray{Float32,1}, TILE::Int) + bid = ct.bid(1) + base = (bid - 1) * TILE + indices = base .+ ct.arange((TILE,), Int) + ct.atomic_add(arr, indices, 1.5f0; + memory_order=ct.MemoryOrder.AcqRel) + return + end + + tile_size = 16 + n = 256 + n_blocks = div(n, tile_size) + arr = CUDA.zeros(Float32, n) + + ct.launch(atomic_add_f32_tile_kernel, n_blocks, arr, ct.Constant(tile_size)) + + @test all(isapprox.(Array(arr), 1.5f0)) +end + +@testset "atomic_add tile-indexed with tile values" begin + function atomic_add_tile_val_kernel(arr::ct.TileArray{Int,1}, + vals::ct.TileArray{Int,1}) + indices = ct.arange((16,), Int) + val_tile = ct.gather(vals, indices) + ct.atomic_add(arr, indices, val_tile; + memory_order=ct.MemoryOrder.AcqRel) + return + end + + arr = CUDA.zeros(Int, 16) + vals = CuArray(collect(Int, 1:16)) + + ct.launch(atomic_add_tile_val_kernel, 1, arr, vals) + + @test Array(arr) == collect(1:16) +end + +@testset "atomic_xchg tile-indexed" begin + function atomic_xchg_tile_kernel(arr::ct.TileArray{Int,1}) + indices = ct.arange((16,), Int) + ct.atomic_xchg(arr, indices, 42; + memory_order=ct.MemoryOrder.AcqRel) + return + end + + arr = CUDA.zeros(Int, 16) + + ct.launch(atomic_xchg_tile_kernel, 1, arr) + + @test all(Array(arr) .== 42) +end + +@testset "atomic_cas tile-indexed success" begin + function atomic_cas_tile_kernel(arr::ct.TileArray{Int,1}, out::ct.TileArray{Int,1}) + indices = ct.arange((16,), Int) + old_vals = ct.atomic_cas(arr, indices, 0, 1; + memory_order=ct.MemoryOrder.AcqRel) + ct.scatter(out, indices, old_vals) + return + end + + arr = CUDA.zeros(Int, 16) + out = CUDA.fill(Int(-1), 16) + + ct.launch(atomic_cas_tile_kernel, 1, arr, out) + + @test all(Array(out) .== 0) + @test all(Array(arr) .== 1) +end + +@testset "atomic_cas tile-indexed failure" begin + function atomic_cas_fail_kernel(arr::ct.TileArray{Int,1}, out::ct.TileArray{Int,1}) + indices = ct.arange((16,), Int) + old_vals = ct.atomic_cas(arr, indices, 0, 2; + memory_order=ct.MemoryOrder.AcqRel) + ct.scatter(out, indices, old_vals) + return + end + 
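+    # arr is pre-filled with 1, so expected=0 never matches and the swap to 2 must not happen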
+ arr = CUDA.fill(Int(1), 16) + out = CUDA.fill(Int(-1), 16) + + ct.launch(atomic_cas_fail_kernel, 1, arr, out) + + @test all(Array(out) .== 1) # old values returned + @test all(Array(arr) .== 1) # unchanged (CAS failed) +end + +@testset "atomic_add tile-indexed out-of-bounds" begin + function atomic_add_oob_kernel(arr::ct.TileArray{Int,1}) + # Index tile is larger than array — OOB elements should be masked + indices = ct.arange((16,), Int) + ct.atomic_add(arr, indices, 1; + memory_order=ct.MemoryOrder.AcqRel) + return + end + + arr = CUDA.zeros(Int, 8) + + ct.launch(atomic_add_oob_kernel, 1, arr) + + # Only first 8 elements should be updated + @test all(Array(arr) .== 1) +end + @testset "1D gather - simple" begin # Simple 1D gather: copy first 16 elements using gather function gather_simple_kernel(src::ct.TileArray{Float32,1}, dst::ct.TileArray{Float32,1}) From b6d55c8093a47909f09f0a845401e346a2a3cf57 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Mon, 23 Feb 2026 11:53:02 +0100 Subject: [PATCH 2/6] generalize to N dimensions --- src/language/atomics.jl | 147 ++++++++++++++++--------------------- test/codegen/operations.jl | 17 +++++ test/execution/atomics.jl | 18 +++++ 3 files changed, 100 insertions(+), 82 deletions(-) diff --git a/src/language/atomics.jl b/src/language/atomics.jl index c3ee8393..893c0bee 100644 --- a/src/language/atomics.jl +++ b/src/language/atomics.jl @@ -82,54 +82,48 @@ old_val = ct.atomic_add(counters, idx, Int32(1)) end # ============================================================================ -# Tile-indexed atomic operations (scatter-gather style indexing) +# Tile-indexed atomic operations # These accept Tile indices to perform atomic operations on multiple elements. # ============================================================================ -# --- Pointer/mask helpers (same pattern as gather/scatter in operations.jl) --- +# --- Pointer/mask helper (N-dimensional) --- -@inline function _atomic_ptrs_mask(array::TileArray{T, 1}, indices::Tile{I}) where {T, I <: Integer} - indices_0 = indices .- one(I) - indices_i32 = convert(Tile{Int32}, indices_0) - ptr_tile = Intrinsics.offset(array.ptr, indices_i32) - zero_0d = Tile(Int32(0)) - size_0d = Tile(size(array, 1)) - mask = (indices_i32 .>= zero_0d) .& (indices_i32 .< size_0d) - (ptr_tile, mask, size(indices)) -end - -@inline function _atomic_ptrs_mask(array::TileArray{T, 2}, - indices::Tuple{Tile{I0}, Tile{I1}}) where {T, I0 <: Integer, I1 <: Integer} - idx0_0 = indices[1] .- one(I0) - idx1_0 = indices[2] .- one(I1) +@inline function _atomic_ptrs_mask(array::TileArray{T, N}, + indices::NTuple{N, Tile{<:Integer}}) where {T, N} + # Convert each index to 0-indexed + indices_0 = ntuple(Val(N)) do d + indices[d] .- one(eltype(indices[d])) + end - S = broadcast_shape(size(indices[1]), size(indices[2])) - idx0_bc = broadcast_to(idx0_0, S) - idx1_bc = broadcast_to(idx1_0, S) + # Broadcast all index tiles to a common shape + S = reduce(broadcast_shape, ntuple(d -> size(indices[d]), Val(N))) - idx0_i32 = convert(Tile{Int32}, idx0_bc) - idx1_i32 = convert(Tile{Int32}, idx1_bc) + # Broadcast and convert to Int32 + indices_i32 = ntuple(Val(N)) do d + convert(Tile{Int32}, broadcast_to(indices_0[d], S)) + end - stride0_0d = Tile(array.strides[1]) - stride1_0d = Tile(array.strides[2]) - stride0 = broadcast_to(stride0_0d, S) - stride1 = broadcast_to(stride1_0d, S) + # Linear index: sum(idx[d] * stride[d]) + linear_idx = reduce(.+, ntuple(Val(N)) do d + indices_i32[d] .* broadcast_to(Tile(array.strides[d]), S) + end) - 
linear_idx = idx0_i32 .* stride0 + idx1_i32 .* stride1 ptr_tile = Intrinsics.offset(array.ptr, linear_idx) - zero_0d = Tile(Int32(0)) - zero_bc = broadcast_to(zero_0d, S) - size0_bc = broadcast_to(Tile(size(array, 1)), S) - size1_bc = broadcast_to(Tile(size(array, 2)), S) - - mask0 = (idx0_i32 .>= zero_bc) .& (idx0_i32 .< size0_bc) - mask1 = (idx1_i32 .>= zero_bc) .& (idx1_i32 .< size1_bc) - mask = mask0 .& mask1 + # Bounds mask: 0 <= idx[d] < size[d] for all d + zero_bc = broadcast_to(Tile(Int32(0)), S) + mask = reduce(.&, ntuple(Val(N)) do d + (indices_i32[d] .>= zero_bc) .& (indices_i32[d] .< broadcast_to(Tile(size(array, d)), S)) + end) (ptr_tile, mask, S) end +# 1D convenience: single Tile -> 1-tuple +@inline function _atomic_ptrs_mask(array::TileArray{T, 1}, indices::Tile{<:Integer}) where {T} + _atomic_ptrs_mask(array, (indices,)) +end + # --- RMW operations (atomic_add, atomic_xchg) --- const _ATOMIC_RMW_OPS = ( @@ -140,51 +134,48 @@ const _ATOMIC_RMW_OPS = ( for (op, intrinsic) in _ATOMIC_RMW_OPS fname = Symbol(:atomic_, op) - # 1D with scalar value - @eval @inline function $fname(array::TileArray{T, 1}, indices::Tile{I}, val::T; + # N-D with scalar value + @eval @inline function $fname(array::TileArray{T, N}, + indices::NTuple{N, Tile{<:Integer}}, val::T; memory_order::Int=MemoryOrder.AcqRel, - memory_scope::Int=MemScope.Device) where {T, I <: Integer} + memory_scope::Int=MemScope.Device) where {T, N} ptr_tile, mask, S = _atomic_ptrs_mask(array, indices) val_tile = broadcast_to(Tile(val), S) Intrinsics.$intrinsic(ptr_tile, val_tile, mask, memory_order, memory_scope) end - # 1D with tile value - @eval @inline function $fname(array::TileArray{T, 1}, indices::Tile{I}, val::Tile{T}; + # N-D with tile value + @eval @inline function $fname(array::TileArray{T, N}, + indices::NTuple{N, Tile{<:Integer}}, val::Tile{T}; memory_order::Int=MemoryOrder.AcqRel, - memory_scope::Int=MemScope.Device) where {T, I <: Integer} - ptr_tile, mask, _ = _atomic_ptrs_mask(array, indices) - Intrinsics.$intrinsic(ptr_tile, val, mask, memory_order, memory_scope) + memory_scope::Int=MemScope.Device) where {T, N} + ptr_tile, mask, S = _atomic_ptrs_mask(array, indices) + val_bc = broadcast_to(val, S) + Intrinsics.$intrinsic(ptr_tile, val_bc, mask, memory_order, memory_scope) end - # 2D with scalar value - @eval @inline function $fname(array::TileArray{T, 2}, - indices::Tuple{Tile{I0}, Tile{I1}}, val::T; + # 1D convenience: single Tile index + @eval @inline function $fname(array::TileArray{T, 1}, indices::Tile{<:Integer}, val::T; memory_order::Int=MemoryOrder.AcqRel, - memory_scope::Int=MemScope.Device) where {T, I0 <: Integer, I1 <: Integer} - ptr_tile, mask, S = _atomic_ptrs_mask(array, indices) - val_tile = broadcast_to(Tile(val), S) - Intrinsics.$intrinsic(ptr_tile, val_tile, mask, memory_order, memory_scope) + memory_scope::Int=MemScope.Device) where {T} + $fname(array, (indices,), val; memory_order, memory_scope) end - # 2D with tile value - @eval @inline function $fname(array::TileArray{T, 2}, - indices::Tuple{Tile{I0}, Tile{I1}}, val::Tile{T}; + @eval @inline function $fname(array::TileArray{T, 1}, indices::Tile{<:Integer}, val::Tile{T}; memory_order::Int=MemoryOrder.AcqRel, - memory_scope::Int=MemScope.Device) where {T, I0 <: Integer, I1 <: Integer} - ptr_tile, mask, S = _atomic_ptrs_mask(array, indices) - val_bc = broadcast_to(val, S) - Intrinsics.$intrinsic(ptr_tile, val_bc, mask, memory_order, memory_scope) + memory_scope::Int=MemScope.Device) where {T} + $fname(array, (indices,), val; memory_order, 
memory_scope) end end # --- CAS operations (separate due to different signature) --- -# 1D with scalar expected/desired -@inline function atomic_cas(array::TileArray{T, 1}, indices::Tile{I}, +# N-D with scalar expected/desired +@inline function atomic_cas(array::TileArray{T, N}, + indices::NTuple{N, Tile{<:Integer}}, expected::T, desired::T; memory_order::Int=MemoryOrder.AcqRel, - memory_scope::Int=MemScope.Device) where {T, I <: Integer} + memory_scope::Int=MemScope.Device) where {T, N} ptr_tile, mask, S = _atomic_ptrs_mask(array, indices) expected_tile = broadcast_to(Tile(expected), S) desired_tile = broadcast_to(Tile(desired), S) @@ -192,38 +183,30 @@ end memory_order, memory_scope) end -# 1D with tile expected/desired -@inline function atomic_cas(array::TileArray{T, 1}, indices::Tile{I}, +# N-D with tile expected/desired +@inline function atomic_cas(array::TileArray{T, N}, + indices::NTuple{N, Tile{<:Integer}}, expected::Tile{T}, desired::Tile{T}; memory_order::Int=MemoryOrder.AcqRel, - memory_scope::Int=MemScope.Device) where {T, I <: Integer} - ptr_tile, mask, _ = _atomic_ptrs_mask(array, indices) - Intrinsics.atomic_cas_tile(ptr_tile, expected, desired, mask, + memory_scope::Int=MemScope.Device) where {T, N} + ptr_tile, mask, S = _atomic_ptrs_mask(array, indices) + expected_bc = broadcast_to(expected, S) + desired_bc = broadcast_to(desired, S) + Intrinsics.atomic_cas_tile(ptr_tile, expected_bc, desired_bc, mask, memory_order, memory_scope) end -# 2D with scalar expected/desired -@inline function atomic_cas(array::TileArray{T, 2}, - indices::Tuple{Tile{I0}, Tile{I1}}, +# 1D convenience: single Tile index +@inline function atomic_cas(array::TileArray{T, 1}, indices::Tile{<:Integer}, expected::T, desired::T; memory_order::Int=MemoryOrder.AcqRel, - memory_scope::Int=MemScope.Device) where {T, I0 <: Integer, I1 <: Integer} - ptr_tile, mask, S = _atomic_ptrs_mask(array, indices) - expected_tile = broadcast_to(Tile(expected), S) - desired_tile = broadcast_to(Tile(desired), S) - Intrinsics.atomic_cas_tile(ptr_tile, expected_tile, desired_tile, mask, - memory_order, memory_scope) + memory_scope::Int=MemScope.Device) where {T} + atomic_cas(array, (indices,), expected, desired; memory_order, memory_scope) end -# 2D with tile expected/desired -@inline function atomic_cas(array::TileArray{T, 2}, - indices::Tuple{Tile{I0}, Tile{I1}}, +@inline function atomic_cas(array::TileArray{T, 1}, indices::Tile{<:Integer}, expected::Tile{T}, desired::Tile{T}; memory_order::Int=MemoryOrder.AcqRel, - memory_scope::Int=MemScope.Device) where {T, I0 <: Integer, I1 <: Integer} - ptr_tile, mask, S = _atomic_ptrs_mask(array, indices) - expected_bc = broadcast_to(expected, S) - desired_bc = broadcast_to(desired, S) - Intrinsics.atomic_cas_tile(ptr_tile, expected_bc, desired_bc, mask, - memory_order, memory_scope) + memory_scope::Int=MemScope.Device) where {T} + atomic_cas(array, (indices,), expected, desired; memory_order, memory_scope) end diff --git a/test/codegen/operations.jl b/test/codegen/operations.jl index 57a2f31d..89782c11 100644 --- a/test/codegen/operations.jl +++ b/test/codegen/operations.jl @@ -1434,6 +1434,23 @@ end end + @testset "tile-indexed 3D atomic_add" begin + spec3d = ct.ArraySpec{3}(16, true) + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Int32,3,spec3d}}) do arr + @check "iota" + i = ct.arange((4,), Int) + j = ct.arange((4,), Int) + k = ct.arange((4,), Int) + @check "offset" + @check "atomic_rmw_tko" + ct.atomic_add(arr, (i, j, k), Int32(1)) + return + end + 
end + end + @testset "tile-indexed atomic_rmw_tko" begin spec = ct.ArraySpec{1}(16, true) # xchg diff --git a/test/execution/atomics.jl b/test/execution/atomics.jl index 60c93d42..e757ed87 100644 --- a/test/execution/atomics.jl +++ b/test/execution/atomics.jl @@ -314,6 +314,24 @@ end @test all(Array(arr) .== 1) end +@testset "atomic_add tile-indexed 3D" begin + function atomic_add_3d_kernel(arr::ct.TileArray{Int,3}) + # 3D index tiles — each is length 4, will broadcast to (4,4,4) = 64 elements + i = ct.reshape(ct.arange((4,), Int), (4, 1, 1)) + j = ct.reshape(ct.arange((4,), Int), (1, 4, 1)) + k = ct.reshape(ct.arange((4,), Int), (1, 1, 4)) + ct.atomic_add(arr, (i, j, k), 1; + memory_order=ct.MemoryOrder.AcqRel) + return + end + + arr = CUDA.zeros(Int, 4, 4, 4) + + ct.launch(atomic_add_3d_kernel, 1, arr) + + @test all(Array(arr) .== 1) +end + @testset "1D gather - simple" begin # Simple 1D gather: copy first 16 elements using gather function gather_simple_kernel(src::ct.TileArray{Float32,1}, dst::ct.TileArray{Float32,1}) From 81efb472e7aebcbd73012c5bf0b3c65d60e139ee Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Mon, 2 Mar 2026 09:44:30 +0100 Subject: [PATCH 3/6] Unify scalar and tile-indexed atomic intrinsics Replace 6 intrinsics (3 scalar + 3 tile) with 3 unified ones that take (ptr_tile, val, mask, ...) matching Python cuTile's design. Both paths now compute pointers via Intrinsics.offset in the language layer, with mask=nothing for scalar indices (no mask in bytecode) and Tile{Bool} for tile-indexed (bounds mask passed through). Co-Authored-By: Claude Opus 4.6 --- src/compiler/intrinsics/atomics.jl | 303 ++++++++++------------------- src/language/atomics.jl | 175 +++++++++-------- test/codegen/operations.jl | 3 + 3 files changed, 199 insertions(+), 282 deletions(-) diff --git a/src/compiler/intrinsics/atomics.jl b/src/compiler/intrinsics/atomics.jl index 2a845391..69819d3c 100644 --- a/src/compiler/intrinsics/atomics.jl +++ b/src/compiler/intrinsics/atomics.jl @@ -30,69 +30,87 @@ function memory_scope_to_scope(scope::Int) end end +""" + atomic_tfunc(ptrs) -> Type + +Shared tfunc for atomic operations (add, xchg, cas). +Returns raw T for 0D pointer tiles, Tile{T, S} for N-D. +""" +function atomic_tfunc(𝕃, @nospecialize(ptrs), @nospecialize args...) + ptrs_type = CC.widenconst(ptrs) + ptrs_type isa DataType && ptrs_type <: Tile || return nothing + ptr_type = eltype(ptrs_type) + ptr_type <: Ptr || return nothing + T = eltype(ptr_type) + S = ptrs_type.parameters[2] + S === Tuple{} && return T + return Tile{T, S} +end + # cuda_tile.atomic_cas_tko -@intrinsic atomic_cas(array, index, expected, desired, - memory_order, memory_scope) -tfunc(𝕃, ::typeof(Intrinsics.atomic_cas), @nospecialize(array), @nospecialize args...) = eltype(CC.widenconst(array)) +@intrinsic atomic_cas(ptr_tile, expected, desired, mask, memory_order, memory_scope) +function tfunc(𝕃, ::typeof(Intrinsics.atomic_cas), @nospecialize(ptrs), @nospecialize args...) + atomic_tfunc(𝕃, ptrs, args...) 
+end efunc(::typeof(Intrinsics.atomic_cas), effects::CC.Effects) = CC.Effects(effects; effect_free=CC.ALWAYS_FALSE) function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_cas), args) cb = ctx.cb tt = ctx.tt - # args: (array, index, expected, desired, memory_order, memory_scope) - array_arg = args[1] - - # Get array info - arg_idx = extract_argument_index(array_arg) - is_tilearray = arg_idx !== nothing && is_destructured_arg(ctx, arg_idx) + # args: (ptr_tile, expected, desired, mask, memory_order, memory_scope) + ptr_tv = emit_value!(ctx, args[1]) + ptr_tv === nothing && throw(IRError("atomic CAS requires ptr_tile")) + expected_tv = emit_value!(ctx, args[2]) + expected_tv === nothing && throw(IRError("atomic CAS requires expected value")) + desired_tv = emit_value!(ctx, args[3]) + desired_tv === nothing && throw(IRError("atomic CAS requires desired value")) - if !is_tilearray - throw(IRError("atomic_cas requires a TileArray argument")) - end + # Check if mask is provided (ghost Nothing = no mask) + has_mask = get_constant(ctx, args[4]) !== nothing - ptr_vals = get_arg_flat_values(ctx, arg_idx, :ptr) - isempty(ptr_vals) && throw(IRError("Cannot get ptr from TileArray argument")) - array_val = ptr_vals[1] - tilearray_type = get_arg_type(ctx, arg_idx) - elem_type = eltype(tilearray_type) + memory_order = @something get_constant(ctx, args[5]) throw(IRError("atomic CAS requires constant memory_order")) + memory_scope = @something get_constant(ctx, args[6]) throw(IRError("atomic CAS requires constant memory_scope")) - # Get expected and desired values - expected_tv = emit_value!(ctx, args[3]) - expected_tv === nothing && throw(IRError("atomic_cas requires expected value")) - desired_tv = emit_value!(ctx, args[4]) - desired_tv === nothing && throw(IRError("atomic_cas requires desired value")) + shape = ptr_tv.shape - # Get memory order and scope from args - memory_order = @something get_constant(ctx, args[5]) throw(IRError("atomic_cas requires constant memory_order")) - memory_scope = @something get_constant(ctx, args[6]) throw(IRError("atomic_cas requires constant memory_scope")) + # Get element type from pointer tile: Tile{Ptr{T}, S} -> T + ptrs_type = CC.widenconst(ptr_tv.jltype) + ptr_type = eltype(ptrs_type) + elem_type = eltype(ptr_type) - # Create result type (0D tile of element type) dtype = julia_to_tile_dtype!(tt, elem_type) - result_tile_type = tile_type!(tt, dtype, Int[]) + result_tile_type = tile_type!(tt, dtype, collect(shape)) token_type = Token(tt) - # Get index and create pointer type - index_tv = emit_value!(ctx, args[2]) - ptr_type = pointer_type!(tt, dtype) - ptr_tile_type = tile_type!(tt, ptr_type, Int[]) - - # Compute pointer using OffsetOp (handles any integer index type) - pointers = encode_OffsetOp!(cb, ptr_tile_type, array_val, index_tv.v) - # Emit atomic CAS mem_ordering = memory_order_to_semantics(memory_order) mem_scope = memory_scope_to_scope(memory_scope) - old_val, new_token = encode_AtomicCASPtrOp!(cb, result_tile_type, token_type, pointers, - expected_tv.v, desired_tv.v; - token=ctx.token, - memory_ordering=mem_ordering, - memory_scope=mem_scope) + if has_mask + mask_tv = emit_value!(ctx, args[4]) + mask_tv === nothing && throw(IRError("atomic CAS: cannot resolve mask")) + old_val, new_token = encode_AtomicCASPtrOp!(cb, result_tile_type, token_type, + ptr_tv.v, expected_tv.v, desired_tv.v; + mask=mask_tv.v, + token=ctx.token, + memory_ordering=mem_ordering, + memory_scope=mem_scope) + else + old_val, new_token = encode_AtomicCASPtrOp!(cb, 
result_tile_type, token_type, + ptr_tv.v, expected_tv.v, desired_tv.v; + token=ctx.token, + memory_ordering=mem_ordering, + memory_scope=mem_scope) + end ctx.token = new_token - # Return scalar type (not Tile) to match the intrinsic signature - CGVal(old_val, result_tile_type, elem_type, Int[]) + # Return type depends on shape: raw T for 0D, Tile{T, S} for N-D + if isempty(shape) + CGVal(old_val, result_tile_type, elem_type, Int[]) + else + CGVal(old_val, result_tile_type, Tile{elem_type, Tuple{shape...}}, collect(shape)) + end end # cuda_tile.atomic_rmw_tko (shared helper for atomic RMW operations) @@ -100,44 +118,31 @@ function emit_atomic_rmw!(ctx::CGCtx, args::AbstractVector, mode::AtomicRMWMode) cb = ctx.cb tt = ctx.tt - # args: (array, index, val, memory_order, memory_scope) - array_arg = args[1] - - # Get array info - arg_idx = extract_argument_index(array_arg) - is_tilearray = arg_idx !== nothing && is_destructured_arg(ctx, arg_idx) + # args: (ptr_tile, val, mask, memory_order, memory_scope) + ptr_tv = emit_value!(ctx, args[1]) + ptr_tv === nothing && throw(IRError("atomic RMW requires ptr_tile")) + val_tv = emit_value!(ctx, args[2]) + val_tv === nothing && throw(IRError("atomic RMW requires value")) - if !is_tilearray - throw(IRError("atomic operations require a TileArray argument")) - end + # Check if mask is provided (ghost Nothing = no mask) + has_mask = get_constant(ctx, args[3]) !== nothing - ptr_vals = get_arg_flat_values(ctx, arg_idx, :ptr) - isempty(ptr_vals) && throw(IRError("Cannot get ptr from TileArray argument")) - array_val = ptr_vals[1] - tilearray_type = get_arg_type(ctx, arg_idx) - elem_type = eltype(tilearray_type) + # Get memory order and scope from args + memory_order = @something get_constant(ctx, args[4]) throw(IRError("atomic RMW requires constant memory_order")) + memory_scope = @something get_constant(ctx, args[5]) throw(IRError("atomic RMW requires constant memory_scope")) - # Get update value - val_tv = emit_value!(ctx, args[3]) - val_tv === nothing && throw(IRError("atomic operation requires value")) + shape = ptr_tv.shape - # Get memory order and scope from args - memory_order = @something get_constant(ctx, args[4]) throw(IRError("atomic operation requires constant memory_order")) - memory_scope = @something get_constant(ctx, args[5]) throw(IRError("atomic operation requires constant memory_scope")) + # Get element type from pointer tile: Tile{Ptr{T}, S} -> T + ptrs_type = CC.widenconst(ptr_tv.jltype) + ptr_type = eltype(ptrs_type) + elem_type = eltype(ptr_type) - # Create result type (0D tile of element type) + # Create result type dtype = julia_to_tile_dtype!(tt, elem_type) - result_tile_type = tile_type!(tt, dtype, Int[]) + result_tile_type = tile_type!(tt, dtype, collect(shape)) token_type = Token(tt) - # Get index and create pointer type - index_tv = emit_value!(ctx, args[2]) - ptr_type = pointer_type!(tt, dtype) - ptr_tile_type = tile_type!(tt, ptr_type, Int[]) - - # Compute pointer using OffsetOp (handles any integer index type) - pointers = encode_OffsetOp!(cb, ptr_tile_type, array_val, index_tv.v) - # Use float add mode for floating point types actual_mode = mode if mode == AtomicADD && elem_type <: AbstractFloat @@ -148,20 +153,35 @@ function emit_atomic_rmw!(ctx::CGCtx, args::AbstractVector, mode::AtomicRMWMode) mem_ordering = memory_order_to_semantics(memory_order) mem_scope = memory_scope_to_scope(memory_scope) - old_val, new_token = encode_AtomicRMWPtrOp!(cb, result_tile_type, token_type, pointers, - val_tv.v, actual_mode; - 
token=ctx.token, - memory_ordering=mem_ordering, - memory_scope=mem_scope) + if has_mask + mask_tv = emit_value!(ctx, args[3]) + mask_tv === nothing && throw(IRError("atomic RMW: cannot resolve mask")) + old_val, new_token = encode_AtomicRMWPtrOp!(cb, result_tile_type, token_type, + ptr_tv.v, val_tv.v, actual_mode; + mask=mask_tv.v, + token=ctx.token, + memory_ordering=mem_ordering, + memory_scope=mem_scope) + else + old_val, new_token = encode_AtomicRMWPtrOp!(cb, result_tile_type, token_type, + ptr_tv.v, val_tv.v, actual_mode; + token=ctx.token, + memory_ordering=mem_ordering, + memory_scope=mem_scope) + end ctx.token = new_token - # Return scalar type (not Tile) to match the intrinsic signature - CGVal(old_val, result_tile_type, elem_type, Int[]) + # Return type depends on shape: raw T for 0D, Tile{T, S} for N-D + if isempty(shape) + CGVal(old_val, result_tile_type, elem_type, Int[]) + else + CGVal(old_val, result_tile_type, Tile{elem_type, Tuple{shape...}}, collect(shape)) + end end # cuda_tile.atomic_rmw_tko with XCHG -@intrinsic atomic_xchg(array, index, val, memory_order, memory_scope) -tfunc(𝕃, ::typeof(Intrinsics.atomic_xchg), @nospecialize(array), @nospecialize args...) = eltype(CC.widenconst(array)) +@intrinsic atomic_xchg(ptr_tile, val, mask, memory_order, memory_scope) +tfunc(𝕃, ::typeof(Intrinsics.atomic_xchg), @nospecialize args...) = atomic_tfunc(𝕃, args...) efunc(::typeof(Intrinsics.atomic_xchg), effects::CC.Effects) = CC.Effects(effects; effect_free=CC.ALWAYS_FALSE) function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_xchg), args) @@ -169,127 +189,10 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_xchg), args) end # cuda_tile.atomic_rmw_tko with ADD -@intrinsic atomic_add(array, index, val, - memory_order, memory_scope) -tfunc(𝕃, ::typeof(Intrinsics.atomic_add), @nospecialize(array), @nospecialize args...) = eltype(CC.widenconst(array)) +@intrinsic atomic_add(ptr_tile, val, mask, memory_order, memory_scope) +tfunc(𝕃, ::typeof(Intrinsics.atomic_add), @nospecialize args...) = atomic_tfunc(𝕃, args...) efunc(::typeof(Intrinsics.atomic_add), effects::CC.Effects) = CC.Effects(effects; effect_free=CC.ALWAYS_FALSE) function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_add), args) emit_atomic_rmw!(ctx, args, AtomicADD) end - -# ============================================================================ -# Tile-indexed atomic operations -# These take pre-computed pointer tiles, value tiles, and masks. -# Used by the public API for tile-indexed atomic operations. 
-# ============================================================================ - -# Shared codegen helper for tile-indexed atomic RMW operations -function emit_atomic_rmw_tile!(ctx::CGCtx, args::AbstractVector, mode::AtomicRMWMode) - cb = ctx.cb - tt = ctx.tt - - # args: (ptr_tile, val, mask, memory_order, memory_scope) - ptr_tv = emit_value!(ctx, args[1]) - ptr_tv === nothing && throw(IRError("tile-indexed atomic RMW requires ptr_tile")) - val_tv = emit_value!(ctx, args[2]) - val_tv === nothing && throw(IRError("tile-indexed atomic RMW requires value")) - mask_tv = emit_value!(ctx, args[3]) - mask_tv === nothing && throw(IRError("tile-indexed atomic RMW requires mask")) - - memory_order = @something get_constant(ctx, args[4]) throw(IRError("tile-indexed atomic RMW requires constant memory_order")) - memory_scope = @something get_constant(ctx, args[5]) throw(IRError("tile-indexed atomic RMW requires constant memory_scope")) - - shape = val_tv.shape - elem_type = eltype(val_tv.jltype) - - dtype = julia_to_tile_dtype!(tt, elem_type) - result_tile_type = tile_type!(tt, dtype, collect(shape)) - token_type = Token(tt) - - # Auto-promote integer ADD to float ADD for floating-point types - actual_mode = mode - if mode == AtomicADD && elem_type <: AbstractFloat - actual_mode = AtomicADDF - end - - mem_ordering = memory_order_to_semantics(memory_order) - mem_scope = memory_scope_to_scope(memory_scope) - - old_val, new_token = encode_AtomicRMWPtrOp!(cb, result_tile_type, token_type, - ptr_tv.v, val_tv.v, actual_mode; - mask=mask_tv.v, - token=ctx.token, - memory_ordering=mem_ordering, - memory_scope=mem_scope) - ctx.token = new_token - - CGVal(old_val, result_tile_type, Tile{elem_type, Tuple{shape...}}, collect(shape)) -end - -# Tile-indexed atomic exchange -@intrinsic atomic_xchg_tile(ptr_tile, val, mask, memory_order, memory_scope) -function tfunc(𝕃, ::typeof(Intrinsics.atomic_xchg_tile), @nospecialize(ptrs), @nospecialize(val), @nospecialize args...) - CC.widenconst(val) -end -efunc(::typeof(Intrinsics.atomic_xchg_tile), effects::CC.Effects) = - CC.Effects(effects; effect_free=CC.ALWAYS_FALSE) -function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_xchg_tile), args) - emit_atomic_rmw_tile!(ctx, args, AtomicXCHG) -end - -# Tile-indexed atomic addition -@intrinsic atomic_add_tile(ptr_tile, val, mask, memory_order, memory_scope) -function tfunc(𝕃, ::typeof(Intrinsics.atomic_add_tile), @nospecialize(ptrs), @nospecialize(val), @nospecialize args...) - CC.widenconst(val) -end -efunc(::typeof(Intrinsics.atomic_add_tile), effects::CC.Effects) = - CC.Effects(effects; effect_free=CC.ALWAYS_FALSE) -function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_add_tile), args) - emit_atomic_rmw_tile!(ctx, args, AtomicADD) -end - -# Tile-indexed atomic compare-and-swap -@intrinsic atomic_cas_tile(ptr_tile, expected, desired, mask, memory_order, memory_scope) -function tfunc(𝕃, ::typeof(Intrinsics.atomic_cas_tile), @nospecialize(ptrs), @nospecialize(expected), @nospecialize args...) 
- CC.widenconst(expected) -end -efunc(::typeof(Intrinsics.atomic_cas_tile), effects::CC.Effects) = - CC.Effects(effects; effect_free=CC.ALWAYS_FALSE) -function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_cas_tile), args) - cb = ctx.cb - tt = ctx.tt - - # args: (ptr_tile, expected, desired, mask, memory_order, memory_scope) - ptr_tv = emit_value!(ctx, args[1]) - ptr_tv === nothing && throw(IRError("tile-indexed atomic CAS requires ptr_tile")) - expected_tv = emit_value!(ctx, args[2]) - expected_tv === nothing && throw(IRError("tile-indexed atomic CAS requires expected value")) - desired_tv = emit_value!(ctx, args[3]) - desired_tv === nothing && throw(IRError("tile-indexed atomic CAS requires desired value")) - mask_tv = emit_value!(ctx, args[4]) - mask_tv === nothing && throw(IRError("tile-indexed atomic CAS requires mask")) - - memory_order = @something get_constant(ctx, args[5]) throw(IRError("tile-indexed atomic CAS requires constant memory_order")) - memory_scope = @something get_constant(ctx, args[6]) throw(IRError("tile-indexed atomic CAS requires constant memory_scope")) - - shape = expected_tv.shape - elem_type = eltype(expected_tv.jltype) - - dtype = julia_to_tile_dtype!(tt, elem_type) - result_tile_type = tile_type!(tt, dtype, collect(shape)) - token_type = Token(tt) - - mem_ordering = memory_order_to_semantics(memory_order) - mem_scope = memory_scope_to_scope(memory_scope) - - old_val, new_token = encode_AtomicCASPtrOp!(cb, result_tile_type, token_type, - ptr_tv.v, expected_tv.v, desired_tv.v; - mask=mask_tv.v, - token=ctx.token, - memory_ordering=mem_ordering, - memory_scope=mem_scope) - ctx.token = new_token - - CGVal(old_val, result_tile_type, Tile{elem_type, Tuple{shape...}}, collect(shape)) -end diff --git a/src/language/atomics.jl b/src/language/atomics.jl index 893c0bee..0df27dea 100644 --- a/src/language/atomics.jl +++ b/src/language/atomics.jl @@ -25,71 +25,23 @@ module MemScope const System = 2 end -""" - atomic_cas(array::TileArray, index, expected, desired; memory_order, memory_scope) -> T - -Atomic compare-and-swap. Atomically compares the value at `index` with `expected`, -and if equal, replaces it with `desired`. Returns the original value. -Index is 1-indexed. - -# Example -```julia -# Spin-lock acquisition -while ct.atomic_cas(locks, idx, Int32(0), Int32(1); memory_order=ct.MemoryOrder.Acquire) == Int32(1) - # spin -end -``` -""" -@inline function atomic_cas(array::TileArray{T}, index, expected::T, desired::T; - memory_order::Int=MemoryOrder.AcqRel, - memory_scope::Int=MemScope.Device) where {T} - Intrinsics.atomic_cas(array, index - One(), expected, desired, memory_order, memory_scope) -end - -""" - atomic_xchg(array::TileArray, index, val; memory_order, memory_scope) -> T - -Atomic exchange. Atomically replaces the value at `index` with `val` and returns -the original value. Index is 1-indexed. - -# Example -```julia -# Spin-lock release -ct.atomic_xchg(locks, idx, Int32(0); memory_order=ct.MemoryOrder.Release) -``` -""" -@inline function atomic_xchg(array::TileArray{T}, index, val::T; - memory_order::Int=MemoryOrder.AcqRel, - memory_scope::Int=MemScope.Device) where {T} - Intrinsics.atomic_xchg(array, index - One(), val, memory_order, memory_scope) -end - -""" - atomic_add(array::TileArray, index, val; memory_order, memory_scope) -> T - -Atomic addition. Atomically adds `val` to the value at `index` and returns -the original value. Index is 1-indexed. 
- -# Example -```julia -old_val = ct.atomic_add(counters, idx, Int32(1)) -``` -""" -@inline function atomic_add(array::TileArray{T}, index, val::T; - memory_order::Int=MemoryOrder.AcqRel, - memory_scope::Int=MemScope.Device) where {T} - Intrinsics.atomic_add(array, index - One(), val, memory_order, memory_scope) -end - # ============================================================================ -# Tile-indexed atomic operations -# These accept Tile indices to perform atomic operations on multiple elements. +# Pointer/mask helpers +# +# Both scalar and tile-indexed paths compute (ptr_tile, mask, shape) here, +# then pass to a single set of intrinsics. # ============================================================================ -# --- Pointer/mask helper (N-dimensional) --- +# Scalar index -> 0D pointer tile, no mask +@inline function _atomic_ptr_and_mask(array::TileArray{T}, index::Integer) where {T} + idx_0 = Tile(Int32(index - One())) + ptr_tile = Intrinsics.offset(array.ptr, idx_0) + (ptr_tile, nothing, ()) +end -@inline function _atomic_ptrs_mask(array::TileArray{T, N}, - indices::NTuple{N, Tile{<:Integer}}) where {T, N} +# N-D tile indices -> N-D pointer tile with bounds mask +@inline function _atomic_ptr_and_mask(array::TileArray{T, N}, + indices::NTuple{N, Tile{<:Integer}}) where {T, N} # Convert each index to 0-indexed indices_0 = ntuple(Val(N)) do d indices[d] .- one(eltype(indices[d])) @@ -120,36 +72,69 @@ end end # 1D convenience: single Tile -> 1-tuple -@inline function _atomic_ptrs_mask(array::TileArray{T, 1}, indices::Tile{<:Integer}) where {T} - _atomic_ptrs_mask(array, (indices,)) +@inline function _atomic_ptr_and_mask(array::TileArray{T, 1}, indices::Tile{<:Integer}) where {T} + _atomic_ptr_and_mask(array, (indices,)) end -# --- RMW operations (atomic_add, atomic_xchg) --- +# ============================================================================ +# Atomic RMW operations (atomic_add, atomic_xchg) +# ============================================================================ + +""" + atomic_add(array::TileArray, index, val; memory_order, memory_scope) -> T + +Atomic addition. Atomically adds `val` to the value at `index` and returns +the original value. Index is 1-indexed. + +# Example +```julia +old_val = ct.atomic_add(counters, idx, Int32(1)) +``` +""" +function atomic_add end -const _ATOMIC_RMW_OPS = ( - (:add, :atomic_add_tile), - (:xchg, :atomic_xchg_tile), -) +""" + atomic_xchg(array::TileArray, index, val; memory_order, memory_scope) -> T -for (op, intrinsic) in _ATOMIC_RMW_OPS +Atomic exchange. Atomically replaces the value at `index` with `val` and returns +the original value. Index is 1-indexed. 
+ +# Example +```julia +# Spin-lock release +ct.atomic_xchg(locks, idx, Int32(0); memory_order=ct.MemoryOrder.Release) +``` +""" +function atomic_xchg end + +for op in (:add, :xchg) fname = Symbol(:atomic_, op) + intrinsic = Symbol(:atomic_, op) - # N-D with scalar value + # Scalar index, scalar val + @eval @inline function $fname(array::TileArray{T}, index, val::T; + memory_order::Int=MemoryOrder.AcqRel, + memory_scope::Int=MemScope.Device) where {T} + ptr_tile, mask, _ = _atomic_ptr_and_mask(array, index) + Intrinsics.$intrinsic(ptr_tile, Tile(val), mask, memory_order, memory_scope) + end + + # N-D tile indices, scalar val @eval @inline function $fname(array::TileArray{T, N}, indices::NTuple{N, Tile{<:Integer}}, val::T; memory_order::Int=MemoryOrder.AcqRel, memory_scope::Int=MemScope.Device) where {T, N} - ptr_tile, mask, S = _atomic_ptrs_mask(array, indices) + ptr_tile, mask, S = _atomic_ptr_and_mask(array, indices) val_tile = broadcast_to(Tile(val), S) Intrinsics.$intrinsic(ptr_tile, val_tile, mask, memory_order, memory_scope) end - # N-D with tile value + # N-D tile indices, tile val @eval @inline function $fname(array::TileArray{T, N}, indices::NTuple{N, Tile{<:Integer}}, val::Tile{T}; memory_order::Int=MemoryOrder.AcqRel, memory_scope::Int=MemScope.Device) where {T, N} - ptr_tile, mask, S = _atomic_ptrs_mask(array, indices) + ptr_tile, mask, S = _atomic_ptr_and_mask(array, indices) val_bc = broadcast_to(val, S) Intrinsics.$intrinsic(ptr_tile, val_bc, mask, memory_order, memory_scope) end @@ -168,32 +153,58 @@ for (op, intrinsic) in _ATOMIC_RMW_OPS end end -# --- CAS operations (separate due to different signature) --- +# ============================================================================ +# Atomic CAS operations +# ============================================================================ + +""" + atomic_cas(array::TileArray, index, expected, desired; memory_order, memory_scope) -> T + +Atomic compare-and-swap. Atomically compares the value at `index` with `expected`, +and if equal, replaces it with `desired`. Returns the original value. +Index is 1-indexed. 
+ +# Example +```julia +# Spin-lock acquisition +while ct.atomic_cas(locks, idx, Int32(0), Int32(1); memory_order=ct.MemoryOrder.Acquire) == Int32(1) + # spin +end +``` +""" +# Scalar index +@inline function atomic_cas(array::TileArray{T}, index, expected::T, desired::T; + memory_order::Int=MemoryOrder.AcqRel, + memory_scope::Int=MemScope.Device) where {T} + ptr_tile, mask, _ = _atomic_ptr_and_mask(array, index) + Intrinsics.atomic_cas(ptr_tile, Tile(expected), Tile(desired), mask, + memory_order, memory_scope) +end -# N-D with scalar expected/desired +# N-D tile indices, scalar expected/desired @inline function atomic_cas(array::TileArray{T, N}, indices::NTuple{N, Tile{<:Integer}}, expected::T, desired::T; memory_order::Int=MemoryOrder.AcqRel, memory_scope::Int=MemScope.Device) where {T, N} - ptr_tile, mask, S = _atomic_ptrs_mask(array, indices) + ptr_tile, mask, S = _atomic_ptr_and_mask(array, indices) expected_tile = broadcast_to(Tile(expected), S) desired_tile = broadcast_to(Tile(desired), S) - Intrinsics.atomic_cas_tile(ptr_tile, expected_tile, desired_tile, mask, - memory_order, memory_scope) + Intrinsics.atomic_cas(ptr_tile, expected_tile, desired_tile, mask, + memory_order, memory_scope) end -# N-D with tile expected/desired +# N-D tile indices, tile expected/desired @inline function atomic_cas(array::TileArray{T, N}, indices::NTuple{N, Tile{<:Integer}}, expected::Tile{T}, desired::Tile{T}; memory_order::Int=MemoryOrder.AcqRel, memory_scope::Int=MemScope.Device) where {T, N} - ptr_tile, mask, S = _atomic_ptrs_mask(array, indices) + ptr_tile, mask, S = _atomic_ptr_and_mask(array, indices) expected_bc = broadcast_to(expected, S) desired_bc = broadcast_to(desired, S) - Intrinsics.atomic_cas_tile(ptr_tile, expected_bc, desired_bc, mask, - memory_order, memory_scope) + Intrinsics.atomic_cas(ptr_tile, expected_bc, desired_bc, mask, + memory_order, memory_scope) end # 1D convenience: single Tile index diff --git a/test/codegen/operations.jl b/test/codegen/operations.jl index 89782c11..90f2f3b4 100644 --- a/test/codegen/operations.jl +++ b/test/codegen/operations.jl @@ -1384,6 +1384,7 @@ @check_label "entry" code_tiled(Tuple{ct.TileArray{Int32,1,spec}}) do locks bid = ct.bid(1) + @check "offset" @check "atomic_cas_tko" old = ct.atomic_cas(locks, bid, Int32(0), Int32(1); memory_order=ct.MemoryOrder.Acquire) @@ -1399,6 +1400,7 @@ @check_label "entry" code_tiled(Tuple{ct.TileArray{Int32,1,spec}}) do locks bid = ct.bid(1) + @check "offset" @check "atomic_rmw_tko" ct.atomic_xchg(locks, bid, Int32(0); memory_order=ct.MemoryOrder.Release) @@ -1412,6 +1414,7 @@ @check_label "entry" code_tiled(Tuple{ct.TileArray{Float32,1,spec_f32}}) do counters bid = ct.bid(1) + @check "offset" @check "atomic_rmw_tko" ct.atomic_add(counters, bid, 1.0f0) return From ffd18ff1e3f4d5cd1b433217b968001885afeb1a Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Mon, 2 Mar 2026 10:07:13 +0100 Subject: [PATCH 4/6] Make atomic intrinsics consistently tile-based Move scalar-to-tile conversion from the intrinsic layer to the language layer: atomic_tfunc and emit functions now always return Tile{T, S} (even for 0D), and scalar atomic methods unwrap via Intrinsics.to_scalar(). 
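For illustration, the scalar-index path (signatures abbreviated; the
exact code is in the language-layer diff below) now reads:

    ptr_tile, mask, _ = _atomic_ptr_and_mask(array, index)
    Intrinsics.to_scalar(
        Intrinsics.atomic_add(ptr_tile, Tile(val), mask, memory_order, memory_scope))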
Co-Authored-By: Claude Opus 4.6 --- src/compiler/intrinsics/atomics.jl | 17 +++-------------- src/language/atomics.jl | 8 +++++--- 2 files changed, 8 insertions(+), 17 deletions(-) diff --git a/src/compiler/intrinsics/atomics.jl b/src/compiler/intrinsics/atomics.jl index 69819d3c..7082e259 100644 --- a/src/compiler/intrinsics/atomics.jl +++ b/src/compiler/intrinsics/atomics.jl @@ -34,7 +34,7 @@ end atomic_tfunc(ptrs) -> Type Shared tfunc for atomic operations (add, xchg, cas). -Returns raw T for 0D pointer tiles, Tile{T, S} for N-D. +Always returns Tile{T, S}, even for 0D (S = Tuple{}). """ function atomic_tfunc(𝕃, @nospecialize(ptrs), @nospecialize args...) ptrs_type = CC.widenconst(ptrs) @@ -43,7 +43,6 @@ function atomic_tfunc(𝕃, @nospecialize(ptrs), @nospecialize args...) ptr_type <: Ptr || return nothing T = eltype(ptr_type) S = ptrs_type.parameters[2] - S === Tuple{} && return T return Tile{T, S} end @@ -105,12 +104,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_cas), args) end ctx.token = new_token - # Return type depends on shape: raw T for 0D, Tile{T, S} for N-D - if isempty(shape) - CGVal(old_val, result_tile_type, elem_type, Int[]) - else - CGVal(old_val, result_tile_type, Tile{elem_type, Tuple{shape...}}, collect(shape)) - end + CGVal(old_val, result_tile_type, Tile{elem_type, Tuple{shape...}}, collect(shape)) end # cuda_tile.atomic_rmw_tko (shared helper for atomic RMW operations) @@ -171,12 +165,7 @@ function emit_atomic_rmw!(ctx::CGCtx, args::AbstractVector, mode::AtomicRMWMode) end ctx.token = new_token - # Return type depends on shape: raw T for 0D, Tile{T, S} for N-D - if isempty(shape) - CGVal(old_val, result_tile_type, elem_type, Int[]) - else - CGVal(old_val, result_tile_type, Tile{elem_type, Tuple{shape...}}, collect(shape)) - end + CGVal(old_val, result_tile_type, Tile{elem_type, Tuple{shape...}}, collect(shape)) end # cuda_tile.atomic_rmw_tko with XCHG diff --git a/src/language/atomics.jl b/src/language/atomics.jl index 0df27dea..452b0444 100644 --- a/src/language/atomics.jl +++ b/src/language/atomics.jl @@ -116,7 +116,8 @@ for op in (:add, :xchg) memory_order::Int=MemoryOrder.AcqRel, memory_scope::Int=MemScope.Device) where {T} ptr_tile, mask, _ = _atomic_ptr_and_mask(array, index) - Intrinsics.$intrinsic(ptr_tile, Tile(val), mask, memory_order, memory_scope) + Intrinsics.to_scalar( + Intrinsics.$intrinsic(ptr_tile, Tile(val), mask, memory_order, memory_scope)) end # N-D tile indices, scalar val @@ -177,8 +178,9 @@ end memory_order::Int=MemoryOrder.AcqRel, memory_scope::Int=MemScope.Device) where {T} ptr_tile, mask, _ = _atomic_ptr_and_mask(array, index) - Intrinsics.atomic_cas(ptr_tile, Tile(expected), Tile(desired), mask, - memory_order, memory_scope) + Intrinsics.to_scalar( + Intrinsics.atomic_cas(ptr_tile, Tile(expected), Tile(desired), mask, + memory_order, memory_scope)) end # N-D tile indices, scalar expected/desired From 4dcb1666b626691a5f57a15762987ab94f2dc272 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Mon, 2 Mar 2026 12:07:04 +0100 Subject: [PATCH 5/6] Clean-up. 
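No functional change intended: the emitters now bind `old_val, new_token`
once through an `if`/`else` expression rather than assigning in both
branches, and the CAS methods move ahead of the RMW definitions in
src/language/atomics.jl. A sketch of the assignment style (arguments
elided):

    old_val, new_token = if has_mask
        encode_AtomicRMWPtrOp!(cb, ...; mask=mask_tv.v, token=ctx.token, ...)
    else
        encode_AtomicRMWPtrOp!(cb, ...; token=ctx.token, ...)
    end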
From 4dcb1666b626691a5f57a15762987ab94f2dc272 Mon Sep 17 00:00:00 2001
From: Tim Besard
Date: Mon, 2 Mar 2026 12:07:04 +0100
Subject: [PATCH 5/6] Clean-up.

---
 src/compiler/intrinsics/atomics.jl |  48 +++++-----
 src/language/atomics.jl            | 141 +++++++++++++++--------------
 2 files changed, 95 insertions(+), 94 deletions(-)

diff --git a/src/compiler/intrinsics/atomics.jl b/src/compiler/intrinsics/atomics.jl
index 7082e259..7a11f56a 100644
--- a/src/compiler/intrinsics/atomics.jl
+++ b/src/compiler/intrinsics/atomics.jl
@@ -86,21 +86,21 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_cas), args)
     mem_ordering = memory_order_to_semantics(memory_order)
     mem_scope = memory_scope_to_scope(memory_scope)
 
-    if has_mask
+    old_val, new_token = if has_mask
         mask_tv = emit_value!(ctx, args[4])
         mask_tv === nothing && throw(IRError("atomic CAS: cannot resolve mask"))
-        old_val, new_token = encode_AtomicCASPtrOp!(cb, result_tile_type, token_type,
-                                                    ptr_tv.v, expected_tv.v, desired_tv.v;
-                                                    mask=mask_tv.v,
-                                                    token=ctx.token,
-                                                    memory_ordering=mem_ordering,
-                                                    memory_scope=mem_scope)
+        encode_AtomicCASPtrOp!(cb, result_tile_type, token_type,
+                               ptr_tv.v, expected_tv.v, desired_tv.v;
+                               mask=mask_tv.v,
+                               token=ctx.token,
+                               memory_ordering=mem_ordering,
+                               memory_scope=mem_scope)
     else
-        old_val, new_token = encode_AtomicCASPtrOp!(cb, result_tile_type, token_type,
-                                                    ptr_tv.v, expected_tv.v, desired_tv.v;
-                                                    token=ctx.token,
-                                                    memory_ordering=mem_ordering,
-                                                    memory_scope=mem_scope)
+        encode_AtomicCASPtrOp!(cb, result_tile_type, token_type,
+                               ptr_tv.v, expected_tv.v, desired_tv.v;
+                               token=ctx.token,
+                               memory_ordering=mem_ordering,
+                               memory_scope=mem_scope)
     end
     ctx.token = new_token
 
@@ -147,21 +147,21 @@ function emit_atomic_rmw!(ctx::CGCtx, args::AbstractVector, mode::AtomicRMWMode)
     mem_ordering = memory_order_to_semantics(memory_order)
     mem_scope = memory_scope_to_scope(memory_scope)
 
-    if has_mask
+    old_val, new_token = if has_mask
         mask_tv = emit_value!(ctx, args[3])
         mask_tv === nothing && throw(IRError("atomic RMW: cannot resolve mask"))
-        old_val, new_token = encode_AtomicRMWPtrOp!(cb, result_tile_type, token_type,
-                                                    ptr_tv.v, val_tv.v, actual_mode;
-                                                    mask=mask_tv.v,
-                                                    token=ctx.token,
-                                                    memory_ordering=mem_ordering,
-                                                    memory_scope=mem_scope)
+        encode_AtomicRMWPtrOp!(cb, result_tile_type, token_type,
+                               ptr_tv.v, val_tv.v, actual_mode;
+                               mask=mask_tv.v,
+                               token=ctx.token,
+                               memory_ordering=mem_ordering,
+                               memory_scope=mem_scope)
     else
-        old_val, new_token = encode_AtomicRMWPtrOp!(cb, result_tile_type, token_type,
-                                                    ptr_tv.v, val_tv.v, actual_mode;
-                                                    token=ctx.token,
-                                                    memory_ordering=mem_ordering,
-                                                    memory_scope=mem_scope)
+        encode_AtomicRMWPtrOp!(cb, result_tile_type, token_type,
+                               ptr_tv.v, val_tv.v, actual_mode;
+                               token=ctx.token,
+                               memory_ordering=mem_ordering,
+                               memory_scope=mem_scope)
     end
     ctx.token = new_token
 
diff --git a/src/language/atomics.jl b/src/language/atomics.jl
index 452b0444..d5a13d28 100644
--- a/src/language/atomics.jl
+++ b/src/language/atomics.jl
@@ -76,6 +76,77 @@ end
     _atomic_ptr_and_mask(array, (indices,))
 end
 
+# ============================================================================
+# Atomic CAS operations
+# ============================================================================
+
+"""
+    atomic_cas(array::TileArray, index, expected, desired; memory_order, memory_scope) -> T
+
+Atomic compare-and-swap. Atomically compares the value at `index` with `expected`,
+and if equal, replaces it with `desired`. Returns the original value.
+Index is 1-indexed.
+
+# Example
+```julia
+# Spin-lock acquisition
+while ct.atomic_cas(locks, idx, Int32(0), Int32(1); memory_order=ct.MemoryOrder.Acquire) == Int32(1)
+    # spin
+end
+```
+"""
+# Scalar index
+@inline function atomic_cas(array::TileArray{T}, index, expected::T, desired::T;
+                            memory_order::Int=MemoryOrder.AcqRel,
+                            memory_scope::Int=MemScope.Device) where {T}
+    ptr_tile, mask, _ = _atomic_ptr_and_mask(array, index)
+    Intrinsics.to_scalar(
+        Intrinsics.atomic_cas(ptr_tile, Tile(expected), Tile(desired), mask,
+                              memory_order, memory_scope))
+end
+
+# N-D tile indices, scalar expected/desired
+@inline function atomic_cas(array::TileArray{T, N},
+                            indices::NTuple{N, Tile{<:Integer}},
+                            expected::T, desired::T;
+                            memory_order::Int=MemoryOrder.AcqRel,
+                            memory_scope::Int=MemScope.Device) where {T, N}
+    ptr_tile, mask, S = _atomic_ptr_and_mask(array, indices)
+    expected_tile = broadcast_to(Tile(expected), S)
+    desired_tile = broadcast_to(Tile(desired), S)
+    Intrinsics.atomic_cas(ptr_tile, expected_tile, desired_tile, mask,
+                          memory_order, memory_scope)
+end
+
+# N-D tile indices, tile expected/desired
+@inline function atomic_cas(array::TileArray{T, N},
+                            indices::NTuple{N, Tile{<:Integer}},
+                            expected::Tile{T}, desired::Tile{T};
+                            memory_order::Int=MemoryOrder.AcqRel,
+                            memory_scope::Int=MemScope.Device) where {T, N}
+    ptr_tile, mask, S = _atomic_ptr_and_mask(array, indices)
+    expected_bc = broadcast_to(expected, S)
+    desired_bc = broadcast_to(desired, S)
+    Intrinsics.atomic_cas(ptr_tile, expected_bc, desired_bc, mask,
+                          memory_order, memory_scope)
+end
+
+# 1D convenience: single Tile index
+@inline function atomic_cas(array::TileArray{T, 1}, indices::Tile{<:Integer},
+                            expected::T, desired::T;
+                            memory_order::Int=MemoryOrder.AcqRel,
+                            memory_scope::Int=MemScope.Device) where {T}
+    atomic_cas(array, (indices,), expected, desired; memory_order, memory_scope)
+end
+
+@inline function atomic_cas(array::TileArray{T, 1}, indices::Tile{<:Integer},
+                            expected::Tile{T}, desired::Tile{T};
+                            memory_order::Int=MemoryOrder.AcqRel,
+                            memory_scope::Int=MemScope.Device) where {T}
+    atomic_cas(array, (indices,), expected, desired; memory_order, memory_scope)
+end
+
+
 # ============================================================================
 # Atomic RMW operations (atomic_add, atomic_xchg)
 # ============================================================================
@@ -153,73 +224,3 @@ for op in (:add, :xchg)
         $fname(array, (indices,), val; memory_order, memory_scope)
     end
 end
-
-# ============================================================================
-# Atomic CAS operations
-# ============================================================================
-
-"""
-    atomic_cas(array::TileArray, index, expected, desired; memory_order, memory_scope) -> T
-
-Atomic compare-and-swap. Atomically compares the value at `index` with `expected`,
-and if equal, replaces it with `desired`. Returns the original value.
-Index is 1-indexed.
-
-# Example
-```julia
-# Spin-lock acquisition
-while ct.atomic_cas(locks, idx, Int32(0), Int32(1); memory_order=ct.MemoryOrder.Acquire) == Int32(1)
-    # spin
-end
-```
-"""
-# Scalar index
-@inline function atomic_cas(array::TileArray{T}, index, expected::T, desired::T;
-                            memory_order::Int=MemoryOrder.AcqRel,
-                            memory_scope::Int=MemScope.Device) where {T}
-    ptr_tile, mask, _ = _atomic_ptr_and_mask(array, index)
-    Intrinsics.to_scalar(
-        Intrinsics.atomic_cas(ptr_tile, Tile(expected), Tile(desired), mask,
-                              memory_order, memory_scope))
-end
-
-# N-D tile indices, scalar expected/desired
-@inline function atomic_cas(array::TileArray{T, N},
-                            indices::NTuple{N, Tile{<:Integer}},
-                            expected::T, desired::T;
-                            memory_order::Int=MemoryOrder.AcqRel,
-                            memory_scope::Int=MemScope.Device) where {T, N}
-    ptr_tile, mask, S = _atomic_ptr_and_mask(array, indices)
-    expected_tile = broadcast_to(Tile(expected), S)
-    desired_tile = broadcast_to(Tile(desired), S)
-    Intrinsics.atomic_cas(ptr_tile, expected_tile, desired_tile, mask,
-                          memory_order, memory_scope)
-end
-
-# N-D tile indices, tile expected/desired
-@inline function atomic_cas(array::TileArray{T, N},
-                            indices::NTuple{N, Tile{<:Integer}},
-                            expected::Tile{T}, desired::Tile{T};
-                            memory_order::Int=MemoryOrder.AcqRel,
-                            memory_scope::Int=MemScope.Device) where {T, N}
-    ptr_tile, mask, S = _atomic_ptr_and_mask(array, indices)
-    expected_bc = broadcast_to(expected, S)
-    desired_bc = broadcast_to(desired, S)
-    Intrinsics.atomic_cas(ptr_tile, expected_bc, desired_bc, mask,
-                          memory_order, memory_scope)
-end
-
-# 1D convenience: single Tile index
-@inline function atomic_cas(array::TileArray{T, 1}, indices::Tile{<:Integer},
-                            expected::T, desired::T;
-                            memory_order::Int=MemoryOrder.AcqRel,
-                            memory_scope::Int=MemScope.Device) where {T}
-    atomic_cas(array, (indices,), expected, desired; memory_order, memory_scope)
-end
-
-@inline function atomic_cas(array::TileArray{T, 1}, indices::Tile{<:Integer},
-                            expected::Tile{T}, desired::Tile{T};
-                            memory_order::Int=MemoryOrder.AcqRel,
-                            memory_scope::Int=MemScope.Device) where {T}
-    atomic_cas(array, (indices,), expected, desired; memory_order, memory_scope)
-end
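The intrinsics half of the clean-up above leans on `if` being an expression in Julia: both branches evaluate to the `(value, token)` pair, and a single destructuring bind runs afterwards instead of the assignment being repeated in each arm. A stand-alone illustration of the idiom (plain Julia, no kernel machinery):

```julia
# Both arms evaluate to a (value, token) tuple; the destructuring
# assignment happens once, on whichever arm was taken.
function emit(has_mask::Bool)
    old_val, new_token = if has_mask
        (:masked_value, :token_a)
    else
        (:plain_value, :token_b)
    end
    return old_val, new_token
end

@assert emit(true)  == (:masked_value, :token_a)
@assert emit(false) == (:plain_value, :token_b)
```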
From 81381297bb8d0d3586d04c0828e9c930aefb29e2 Mon Sep 17 00:00:00 2001
From: Tim Besard
Date: Mon, 2 Mar 2026 13:34:01 +0100
Subject: [PATCH 6/6] Further simplify.

---
 src/language/atomics.jl | 97 +++++------------------------------
 src/language/types.jl   | 17 +++++---
 2 files changed, 24 insertions(+), 90 deletions(-)

diff --git a/src/language/atomics.jl b/src/language/atomics.jl
index d5a13d28..535d1114 100644
--- a/src/language/atomics.jl
+++ b/src/language/atomics.jl
@@ -77,7 +77,7 @@ end
 end
 
 # ============================================================================
-# Atomic CAS operations
+# Atomic CAS
 # ============================================================================
 
 """
@@ -95,58 +95,18 @@
 Atomic compare-and-swap. Atomically compares the value at `index` with `expected`,
 and if equal, replaces it with `desired`. Returns the original value.
 Index is 1-indexed.
 
 # Example
 ```julia
 # Spin-lock acquisition
 while ct.atomic_cas(locks, idx, Int32(0), Int32(1); memory_order=ct.MemoryOrder.Acquire) == Int32(1)
     # spin
 end
 ```
 """
-# Scalar index
-@inline function atomic_cas(array::TileArray{T}, index, expected::T, desired::T;
-                            memory_order::Int=MemoryOrder.AcqRel,
-                            memory_scope::Int=MemScope.Device) where {T}
-    ptr_tile, mask, _ = _atomic_ptr_and_mask(array, index)
-    Intrinsics.to_scalar(
-        Intrinsics.atomic_cas(ptr_tile, Tile(expected), Tile(desired), mask,
-                              memory_order, memory_scope))
-end
-
-# N-D tile indices, scalar expected/desired
-@inline function atomic_cas(array::TileArray{T, N},
-                            indices::NTuple{N, Tile{<:Integer}},
-                            expected::T, desired::T;
-                            memory_order::Int=MemoryOrder.AcqRel,
-                            memory_scope::Int=MemScope.Device) where {T, N}
-    ptr_tile, mask, S = _atomic_ptr_and_mask(array, indices)
-    expected_tile = broadcast_to(Tile(expected), S)
-    desired_tile = broadcast_to(Tile(desired), S)
-    Intrinsics.atomic_cas(ptr_tile, expected_tile, desired_tile, mask,
-                          memory_order, memory_scope)
-end
-
-# N-D tile indices, tile expected/desired
-@inline function atomic_cas(array::TileArray{T, N},
-                            indices::NTuple{N, Tile{<:Integer}},
-                            expected::Tile{T}, desired::Tile{T};
-                            memory_order::Int=MemoryOrder.AcqRel,
-                            memory_scope::Int=MemScope.Device) where {T, N}
-    ptr_tile, mask, S = _atomic_ptr_and_mask(array, indices)
-    expected_bc = broadcast_to(expected, S)
-    desired_bc = broadcast_to(desired, S)
-    Intrinsics.atomic_cas(ptr_tile, expected_bc, desired_bc, mask,
-                          memory_order, memory_scope)
-end
-
-# 1D convenience: single Tile index
-@inline function atomic_cas(array::TileArray{T, 1}, indices::Tile{<:Integer},
-                            expected::T, desired::T;
-                            memory_order::Int=MemoryOrder.AcqRel,
-                            memory_scope::Int=MemScope.Device) where {T}
-    atomic_cas(array, (indices,), expected, desired; memory_order, memory_scope)
-end
-
-@inline function atomic_cas(array::TileArray{T, 1}, indices::Tile{<:Integer},
-                            expected::Tile{T}, desired::Tile{T};
-                            memory_order::Int=MemoryOrder.AcqRel,
-                            memory_scope::Int=MemScope.Device) where {T}
-    atomic_cas(array, (indices,), expected, desired; memory_order, memory_scope)
-end
+@inline function atomic_cas(array::TileArray{T}, indices,
+                            expected::TileOrScalar{T}, desired::TileOrScalar{T};
+                            memory_order::Int=MemoryOrder.AcqRel,
+                            memory_scope::Int=MemScope.Device) where {T}
+    ptr_tile, mask, S = _atomic_ptr_and_mask(array, indices)
+    expected_bc = S === () ? Tile(expected) : broadcast_to(Tile(expected), S)
+    desired_bc = S === () ? Tile(desired) : broadcast_to(Tile(desired), S)
+    result = Intrinsics.atomic_cas(ptr_tile, expected_bc, desired_bc, mask,
+                                   memory_order, memory_scope)
+    S === () ? Intrinsics.to_scalar(result) : result
 end
 
 # ============================================================================
 # Atomic RMW operations (atomic_add, atomic_xchg)
 # ============================================================================
@@ -182,45 +142,12 @@ for op in (:add, :xchg)
     fname = Symbol(:atomic_, op)
     intrinsic = Symbol(:atomic_, op)
 
-    # Scalar index, scalar val
-    @eval @inline function $fname(array::TileArray{T}, index, val::T;
-                                  memory_order::Int=MemoryOrder.AcqRel,
-                                  memory_scope::Int=MemScope.Device) where {T}
-        ptr_tile, mask, _ = _atomic_ptr_and_mask(array, index)
-        Intrinsics.to_scalar(
-            Intrinsics.$intrinsic(ptr_tile, Tile(val), mask, memory_order, memory_scope))
-    end
-
-    # N-D tile indices, scalar val
-    @eval @inline function $fname(array::TileArray{T, N},
-                                  indices::NTuple{N, Tile{<:Integer}}, val::T;
-                                  memory_order::Int=MemoryOrder.AcqRel,
-                                  memory_scope::Int=MemScope.Device) where {T, N}
-        ptr_tile, mask, S = _atomic_ptr_and_mask(array, indices)
-        val_tile = broadcast_to(Tile(val), S)
-        Intrinsics.$intrinsic(ptr_tile, val_tile, mask, memory_order, memory_scope)
-    end
-
-    # N-D tile indices, tile val
-    @eval @inline function $fname(array::TileArray{T, N},
-                                  indices::NTuple{N, Tile{<:Integer}}, val::Tile{T};
-                                  memory_order::Int=MemoryOrder.AcqRel,
-                                  memory_scope::Int=MemScope.Device) where {T, N}
-        ptr_tile, mask, S = _atomic_ptr_and_mask(array, indices)
-        val_bc = broadcast_to(val, S)
-        Intrinsics.$intrinsic(ptr_tile, val_bc, mask, memory_order, memory_scope)
-    end
-
-    # 1D convenience: single Tile index
-    @eval @inline function $fname(array::TileArray{T, 1}, indices::Tile{<:Integer}, val::T;
-                                  memory_order::Int=MemoryOrder.AcqRel,
-                                  memory_scope::Int=MemScope.Device) where {T}
-        $fname(array, (indices,), val; memory_order, memory_scope)
-    end
-
-    @eval @inline function $fname(array::TileArray{T, 1}, indices::Tile{<:Integer}, val::Tile{T};
-                                  memory_order::Int=MemoryOrder.AcqRel,
-                                  memory_scope::Int=MemScope.Device) where {T}
-        $fname(array, (indices,), val; memory_order, memory_scope)
+    @eval @inline function $fname(array::TileArray{T}, indices, val::TileOrScalar{T};
+                                  memory_order::Int=MemoryOrder.AcqRel,
+                                  memory_scope::Int=MemScope.Device) where {T}
+        ptr_tile, mask, S = _atomic_ptr_and_mask(array, indices)
+        val_bc = S === () ? Tile(val) : broadcast_to(Tile(val), S)
+        result = Intrinsics.$intrinsic(ptr_tile, val_bc, mask, memory_order, memory_scope)
+        S === () ? Intrinsics.to_scalar(result) : result
     end
 end
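Everything the removed methods did now hinges on the shape returned by `_atomic_ptr_and_mask`: `()` marks a scalar index (no broadcast, unwrap the 0D result), any other tuple takes the broadcast path. Since the shape follows from the argument types, the `S === ()` branch should fold away during specialization. A plain-Julia model of that dispatch (`shape_of` and `fill` are stand-ins for `_atomic_ptr_and_mask` and `broadcast_to`, not the package's real helpers):

```julia
# Toy model of the S === () dispatch used by the unified atomic methods.
shape_of(index::Integer) = ()      # scalar index -> empty shape
shape_of(dims::Tuple)    = dims    # tile indices -> broadcast shape

function unified_op(index_or_dims, val)
    S = shape_of(index_or_dims)
    # Broadcast only on the tile path; the scalar path stays a plain value.
    S === () ? val : fill(val, S)
end

@assert unified_op(3, 1.0f0) === 1.0f0             # scalar path
@assert size(unified_op((4, 4), 1.0f0)) == (4, 4)  # tile path
```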
diff --git a/src/language/types.jl b/src/language/types.jl
index 66397cac..edc9069e 100644
--- a/src/language/types.jl
+++ b/src/language/types.jl
@@ -1,5 +1,6 @@
 public TileArray, Tile, Constant, TFloat32, similar_type,
-       ScalarInt, ScalarFloat, TileInt, TileFloat, ScalarOrTileInt, ScalarOrTileFloat
+       ScalarInt, ScalarFloat, IntTile, FloatTile, TileOrInt, TileOrFloat,
+       TileOrScalar
 
 """
     ArraySpec{N}
@@ -250,6 +251,9 @@ In kernel code, this is compiled to a ConstantOp.
     Tile{T, Tuple{}}()
 end
 
+# No-op: pass-through for values already wrapped as Tile
+@inline Tile(tile::Tile) = tile
+
 #=============================================================================
  View Types
 =============================================================================#
@@ -367,16 +371,19 @@ const ScalarInt = Union{Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64}
 const ScalarFloat = Union{Float16, BFloat16, Float32, Float64, TFloat32}
 
 """Integer tile types."""
-const TileInt{S} = Tile{T, S} where {T <: ScalarInt}
+const IntTile{S} = Tile{T, S} where {T <: ScalarInt}
 
 """Floating-point tile types."""
-const TileFloat{S} = Tile{T, S} where {T <: ScalarFloat}
+const FloatTile{S} = Tile{T, S} where {T <: ScalarFloat}
+
+"""Scalar or tile of element type T."""
+const TileOrScalar{T} = Union{T, Tile{T}}
 
 """Integer values (scalar or tile)."""
-const ScalarOrTileInt = Union{ScalarInt, TileInt}
+const TileOrInt = TileOrScalar{<:ScalarInt}
 
 """Floating-point values (scalar or tile)."""
-const ScalarOrTileFloat = Union{ScalarFloat, TileFloat}
+const TileOrFloat = TileOrScalar{<:ScalarFloat}
 
 #=============================================================================
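The `TileOrScalar{T}` alias and the new `Tile(tile::Tile)` pass-through are what let the unified methods above accept either argument form without a wrap branch: `Tile(x)` wraps a scalar and is the identity on a tile. A minimal model of the pattern with stand-in types (not the package's real definitions):

```julia
# Stand-in types; only the union-alias-plus-idempotent-constructor
# pattern mirrors the patch.
struct DemoTile{T}
    value::T
end
DemoTile(t::DemoTile) = t                        # no-op pass-through
const DemoTileOrScalar{T} = Union{T, DemoTile{T}}

# One method accepts both forms; DemoTile() normalizes its argument.
normalize(x::DemoTileOrScalar{Float32}) = DemoTile(x)

@assert normalize(1.0f0) == DemoTile(1.0f0)
@assert normalize(DemoTile(1.0f0)) == DemoTile(1.0f0)
```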