Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 92 additions & 84 deletions src/compiler/intrinsics/atomics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,114 +30,113 @@ function memory_scope_to_scope(scope::Int)
end
end

"""
atomic_tfunc(ptrs) -> Type

Shared tfunc for atomic operations (add, xchg, cas).
Always returns Tile{T, S}, even for 0D (S = Tuple{}).
"""
function atomic_tfunc(𝕃, @nospecialize(ptrs), @nospecialize args...)
ptrs_type = CC.widenconst(ptrs)
ptrs_type isa DataType && ptrs_type <: Tile || return nothing
ptr_type = eltype(ptrs_type)
ptr_type <: Ptr || return nothing
T = eltype(ptr_type)
S = ptrs_type.parameters[2]
return Tile{T, S}
end

# cuda_tile.atomic_cas_tko
@intrinsic atomic_cas(array, index, expected, desired,
memory_order, memory_scope)
tfunc(𝕃, ::typeof(Intrinsics.atomic_cas), @nospecialize(array), @nospecialize args...) = eltype(CC.widenconst(array))
@intrinsic atomic_cas(ptr_tile, expected, desired, mask, memory_order, memory_scope)
function tfunc(𝕃, ::typeof(Intrinsics.atomic_cas), @nospecialize(ptrs), @nospecialize args...)
atomic_tfunc(𝕃, ptrs, args...)
end
efunc(::typeof(Intrinsics.atomic_cas), effects::CC.Effects) =
CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_cas), args)
cb = ctx.cb
tt = ctx.tt

# args: (array, index, expected, desired, memory_order, memory_scope)
array_arg = args[1]
# args: (ptr_tile, expected, desired, mask, memory_order, memory_scope)
ptr_tv = emit_value!(ctx, args[1])
ptr_tv === nothing && throw(IRError("atomic CAS requires ptr_tile"))
expected_tv = emit_value!(ctx, args[2])
expected_tv === nothing && throw(IRError("atomic CAS requires expected value"))
desired_tv = emit_value!(ctx, args[3])
desired_tv === nothing && throw(IRError("atomic CAS requires desired value"))

# Get array info
arg_idx = extract_argument_index(array_arg)
is_tilearray = arg_idx !== nothing && is_destructured_arg(ctx, arg_idx)
# Check if mask is provided (ghost Nothing = no mask)
has_mask = get_constant(ctx, args[4]) !== nothing

if !is_tilearray
throw(IRError("atomic_cas requires a TileArray argument"))
end
memory_order = @something get_constant(ctx, args[5]) throw(IRError("atomic CAS requires constant memory_order"))
memory_scope = @something get_constant(ctx, args[6]) throw(IRError("atomic CAS requires constant memory_scope"))

ptr_vals = get_arg_flat_values(ctx, arg_idx, :ptr)
isempty(ptr_vals) && throw(IRError("Cannot get ptr from TileArray argument"))
array_val = ptr_vals[1]
tilearray_type = get_arg_type(ctx, arg_idx)
elem_type = eltype(tilearray_type)
shape = ptr_tv.shape

# Get expected and desired values
expected_tv = emit_value!(ctx, args[3])
expected_tv === nothing && throw(IRError("atomic_cas requires expected value"))
desired_tv = emit_value!(ctx, args[4])
desired_tv === nothing && throw(IRError("atomic_cas requires desired value"))
# Get element type from pointer tile: Tile{Ptr{T}, S} -> T
ptrs_type = CC.widenconst(ptr_tv.jltype)
ptr_type = eltype(ptrs_type)
elem_type = eltype(ptr_type)

# Get memory order and scope from args
memory_order = @something get_constant(ctx, args[5]) throw(IRError("atomic_cas requires constant memory_order"))
memory_scope = @something get_constant(ctx, args[6]) throw(IRError("atomic_cas requires constant memory_scope"))

# Create result type (0D tile of element type)
dtype = julia_to_tile_dtype!(tt, elem_type)
result_tile_type = tile_type!(tt, dtype, Int[])
result_tile_type = tile_type!(tt, dtype, collect(shape))
token_type = Token(tt)

# Get index and create pointer type
index_tv = emit_value!(ctx, args[2])
ptr_type = pointer_type!(tt, dtype)
ptr_tile_type = tile_type!(tt, ptr_type, Int[])

# Compute pointer using OffsetOp (handles any integer index type)
pointers = encode_OffsetOp!(cb, ptr_tile_type, array_val, index_tv.v)

# Emit atomic CAS
mem_ordering = memory_order_to_semantics(memory_order)
mem_scope = memory_scope_to_scope(memory_scope)

old_val, new_token = encode_AtomicCASPtrOp!(cb, result_tile_type, token_type, pointers,
expected_tv.v, desired_tv.v;
token=ctx.token,
memory_ordering=mem_ordering,
memory_scope=mem_scope)
old_val, new_token = if has_mask
mask_tv = emit_value!(ctx, args[4])
mask_tv === nothing && throw(IRError("atomic CAS: cannot resolve mask"))
encode_AtomicCASPtrOp!(cb, result_tile_type, token_type,
ptr_tv.v, expected_tv.v, desired_tv.v;
mask=mask_tv.v,
token=ctx.token,
memory_ordering=mem_ordering,
memory_scope=mem_scope)
else
encode_AtomicCASPtrOp!(cb, result_tile_type, token_type,
ptr_tv.v, expected_tv.v, desired_tv.v;
token=ctx.token,
memory_ordering=mem_ordering,
memory_scope=mem_scope)
end
ctx.token = new_token

# Return scalar type (not Tile) to match the intrinsic signature
CGVal(old_val, result_tile_type, elem_type, Int[])
CGVal(old_val, result_tile_type, Tile{elem_type, Tuple{shape...}}, collect(shape))
end

# cuda_tile.atomic_rmw_tko (shared helper for atomic RMW operations)
function emit_atomic_rmw!(ctx::CGCtx, args::AbstractVector, mode::AtomicRMWMode)
cb = ctx.cb
tt = ctx.tt

# args: (array, index, val, memory_order, memory_scope)
array_arg = args[1]
# args: (ptr_tile, val, mask, memory_order, memory_scope)
ptr_tv = emit_value!(ctx, args[1])
ptr_tv === nothing && throw(IRError("atomic RMW requires ptr_tile"))
val_tv = emit_value!(ctx, args[2])
val_tv === nothing && throw(IRError("atomic RMW requires value"))

# Get array info
arg_idx = extract_argument_index(array_arg)
is_tilearray = arg_idx !== nothing && is_destructured_arg(ctx, arg_idx)
# Check if mask is provided (ghost Nothing = no mask)
has_mask = get_constant(ctx, args[3]) !== nothing

if !is_tilearray
throw(IRError("atomic operations require a TileArray argument"))
end
# Get memory order and scope from args
memory_order = @something get_constant(ctx, args[4]) throw(IRError("atomic RMW requires constant memory_order"))
memory_scope = @something get_constant(ctx, args[5]) throw(IRError("atomic RMW requires constant memory_scope"))

ptr_vals = get_arg_flat_values(ctx, arg_idx, :ptr)
isempty(ptr_vals) && throw(IRError("Cannot get ptr from TileArray argument"))
array_val = ptr_vals[1]
tilearray_type = get_arg_type(ctx, arg_idx)
elem_type = eltype(tilearray_type)
shape = ptr_tv.shape

# Get update value
val_tv = emit_value!(ctx, args[3])
val_tv === nothing && throw(IRError("atomic operation requires value"))
# Get element type from pointer tile: Tile{Ptr{T}, S} -> T
ptrs_type = CC.widenconst(ptr_tv.jltype)
ptr_type = eltype(ptrs_type)
elem_type = eltype(ptr_type)

# Get memory order and scope from args
memory_order = @something get_constant(ctx, args[4]) throw(IRError("atomic operation requires constant memory_order"))
memory_scope = @something get_constant(ctx, args[5]) throw(IRError("atomic operation requires constant memory_scope"))

# Create result type (0D tile of element type)
# Create result type
dtype = julia_to_tile_dtype!(tt, elem_type)
result_tile_type = tile_type!(tt, dtype, Int[])
result_tile_type = tile_type!(tt, dtype, collect(shape))
token_type = Token(tt)

# Get index and create pointer type
index_tv = emit_value!(ctx, args[2])
ptr_type = pointer_type!(tt, dtype)
ptr_tile_type = tile_type!(tt, ptr_type, Int[])

# Compute pointer using OffsetOp (handles any integer index type)
pointers = encode_OffsetOp!(cb, ptr_tile_type, array_val, index_tv.v)

# Use float add mode for floating point types
actual_mode = mode
if mode == AtomicADD && elem_type <: AbstractFloat
Expand All @@ -148,30 +147,39 @@ function emit_atomic_rmw!(ctx::CGCtx, args::AbstractVector, mode::AtomicRMWMode)
mem_ordering = memory_order_to_semantics(memory_order)
mem_scope = memory_scope_to_scope(memory_scope)

old_val, new_token = encode_AtomicRMWPtrOp!(cb, result_tile_type, token_type, pointers,
val_tv.v, actual_mode;
token=ctx.token,
memory_ordering=mem_ordering,
memory_scope=mem_scope)
old_val, new_token = if has_mask
mask_tv = emit_value!(ctx, args[3])
mask_tv === nothing && throw(IRError("atomic RMW: cannot resolve mask"))
encode_AtomicRMWPtrOp!(cb, result_tile_type, token_type,
ptr_tv.v, val_tv.v, actual_mode;
mask=mask_tv.v,
token=ctx.token,
memory_ordering=mem_ordering,
memory_scope=mem_scope)
else
encode_AtomicRMWPtrOp!(cb, result_tile_type, token_type,
ptr_tv.v, val_tv.v, actual_mode;
token=ctx.token,
memory_ordering=mem_ordering,
memory_scope=mem_scope)
end
ctx.token = new_token

# Return scalar type (not Tile) to match the intrinsic signature
CGVal(old_val, result_tile_type, elem_type, Int[])
CGVal(old_val, result_tile_type, Tile{elem_type, Tuple{shape...}}, collect(shape))
end

# cuda_tile.atomic_rmw_tko with XCHG
@intrinsic atomic_xchg(array, index, val, memory_order, memory_scope)
tfunc(𝕃, ::typeof(Intrinsics.atomic_xchg), @nospecialize(array), @nospecialize args...) = eltype(CC.widenconst(array))
@intrinsic atomic_xchg(ptr_tile, val, mask, memory_order, memory_scope)
tfunc(𝕃, ::typeof(Intrinsics.atomic_xchg), @nospecialize args...) = atomic_tfunc(𝕃, args...)
efunc(::typeof(Intrinsics.atomic_xchg), effects::CC.Effects) =
CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_xchg), args)
emit_atomic_rmw!(ctx, args, AtomicXCHG)
end

# cuda_tile.atomic_rmw_tko with ADD
@intrinsic atomic_add(array, index, val,
memory_order, memory_scope)
tfunc(𝕃, ::typeof(Intrinsics.atomic_add), @nospecialize(array), @nospecialize args...) = eltype(CC.widenconst(array))
@intrinsic atomic_add(ptr_tile, val, mask, memory_order, memory_scope)
tfunc(𝕃, ::typeof(Intrinsics.atomic_add), @nospecialize args...) = atomic_tfunc(𝕃, args...)
efunc(::typeof(Intrinsics.atomic_add), effects::CC.Effects) =
CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_add), args)
Expand Down
107 changes: 89 additions & 18 deletions src/language/atomics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,61 @@ module MemScope
const System = 2
end

# ============================================================================
# Pointer/mask helpers
#
# Both scalar and tile-indexed paths compute (ptr_tile, mask, shape) here,
# then pass to a single set of intrinsics.
# ============================================================================

# Scalar index -> 0D pointer tile, no mask
@inline function _atomic_ptr_and_mask(array::TileArray{T}, index::Integer) where {T}
idx_0 = Tile(Int32(index - One()))
ptr_tile = Intrinsics.offset(array.ptr, idx_0)
(ptr_tile, nothing, ())
end

# N-D tile indices -> N-D pointer tile with bounds mask
@inline function _atomic_ptr_and_mask(array::TileArray{T, N},
indices::NTuple{N, Tile{<:Integer}}) where {T, N}
# Convert each index to 0-indexed
indices_0 = ntuple(Val(N)) do d
indices[d] .- one(eltype(indices[d]))
end

# Broadcast all index tiles to a common shape
S = reduce(broadcast_shape, ntuple(d -> size(indices[d]), Val(N)))

# Broadcast and convert to Int32
indices_i32 = ntuple(Val(N)) do d
convert(Tile{Int32}, broadcast_to(indices_0[d], S))
end

# Linear index: sum(idx[d] * stride[d])
linear_idx = reduce(.+, ntuple(Val(N)) do d
indices_i32[d] .* broadcast_to(Tile(array.strides[d]), S)
end)

ptr_tile = Intrinsics.offset(array.ptr, linear_idx)

# Bounds mask: 0 <= idx[d] < size[d] for all d
zero_bc = broadcast_to(Tile(Int32(0)), S)
mask = reduce(.&, ntuple(Val(N)) do d
(indices_i32[d] .>= zero_bc) .& (indices_i32[d] .< broadcast_to(Tile(size(array, d)), S))
end)

(ptr_tile, mask, S)
end

# 1D convenience: single Tile -> 1-tuple
@inline function _atomic_ptr_and_mask(array::TileArray{T, 1}, indices::Tile{<:Integer}) where {T}
_atomic_ptr_and_mask(array, (indices,))
end

# ============================================================================
# Atomic CAS
# ============================================================================

"""
atomic_cas(array::TileArray, index, expected, desired; memory_order, memory_scope) -> T

Expand All @@ -40,43 +95,59 @@ while ct.atomic_cas(locks, idx, Int32(0), Int32(1); memory_order=ct.MemoryOrder.
end
```
"""
@inline function atomic_cas(array::TileArray{T}, index, expected::T, desired::T;
@inline function atomic_cas(array::TileArray{T}, indices,
expected::TileOrScalar{T}, desired::TileOrScalar{T};
memory_order::Int=MemoryOrder.AcqRel,
memory_scope::Int=MemScope.Device) where {T}
Intrinsics.atomic_cas(array, index - One(), expected, desired, memory_order, memory_scope)
ptr_tile, mask, S = _atomic_ptr_and_mask(array, indices)
expected_bc = S === () ? Tile(expected) : broadcast_to(Tile(expected), S)
desired_bc = S === () ? Tile(desired) : broadcast_to(Tile(desired), S)
result = Intrinsics.atomic_cas(ptr_tile, expected_bc, desired_bc, mask,
memory_order, memory_scope)
S === () ? Intrinsics.to_scalar(result) : result
end

# ============================================================================
# Atomic RMW operations (atomic_add, atomic_xchg)
# ============================================================================

"""
atomic_xchg(array::TileArray, index, val; memory_order, memory_scope) -> T
atomic_add(array::TileArray, index, val; memory_order, memory_scope) -> T

Atomic exchange. Atomically replaces the value at `index` with `val` and returns
Atomic addition. Atomically adds `val` to the value at `index` and returns
the original value. Index is 1-indexed.

# Example
```julia
# Spin-lock release
ct.atomic_xchg(locks, idx, Int32(0); memory_order=ct.MemoryOrder.Release)
old_val = ct.atomic_add(counters, idx, Int32(1))
```
"""
@inline function atomic_xchg(array::TileArray{T}, index, val::T;
memory_order::Int=MemoryOrder.AcqRel,
memory_scope::Int=MemScope.Device) where {T}
Intrinsics.atomic_xchg(array, index - One(), val, memory_order, memory_scope)
end
function atomic_add end

"""
atomic_add(array::TileArray, index, val; memory_order, memory_scope) -> T
atomic_xchg(array::TileArray, index, val; memory_order, memory_scope) -> T

Atomic addition. Atomically adds `val` to the value at `index` and returns
Atomic exchange. Atomically replaces the value at `index` with `val` and returns
the original value. Index is 1-indexed.

# Example
```julia
old_val = ct.atomic_add(counters, idx, Int32(1))
# Spin-lock release
ct.atomic_xchg(locks, idx, Int32(0); memory_order=ct.MemoryOrder.Release)
```
"""
@inline function atomic_add(array::TileArray{T}, index, val::T;
memory_order::Int=MemoryOrder.AcqRel,
memory_scope::Int=MemScope.Device) where {T}
Intrinsics.atomic_add(array, index - One(), val, memory_order, memory_scope)
function atomic_xchg end

for op in (:add, :xchg)
fname = Symbol(:atomic_, op)
intrinsic = Symbol(:atomic_, op)

@eval @inline function $fname(array::TileArray{T}, indices, val::TileOrScalar{T};
memory_order::Int=MemoryOrder.AcqRel,
memory_scope::Int=MemScope.Device) where {T}
ptr_tile, mask, S = _atomic_ptr_and_mask(array, indices)
val_bc = S === () ? Tile(val) : broadcast_to(Tile(val), S)
result = Intrinsics.$intrinsic(ptr_tile, val_bc, mask, memory_order, memory_scope)
S === () ? Intrinsics.to_scalar(result) : result
end
end
Loading