From e35c1cd6be5ab2f133878f848331d320cd6f174c Mon Sep 17 00:00:00 2001 From: shreyas-omkar Date: Wed, 18 Feb 2026 16:46:14 +0530 Subject: [PATCH 01/10] Add alias-aware token threading for memory operations. Introduce alias analysis based token threading: - Group pointers into alias sets. - Maintain per-alias-set token chains. - Thread tokens only between potentially aliasing operations. - Conservatively fall back to the global set for unknown pointers. - Preserve existing control-flow token merging semantics. Enables independent memory operations to execute without unnecessary serialization. --- src/compiler/codegen.jl | 3 + src/compiler/codegen/alias_analysis.jl | 228 +++++++++++++++++++++++++ src/compiler/codegen/control_flow.jl | 41 ++++- src/compiler/codegen/kernel.jl | 22 ++- src/compiler/codegen/token_keys.jl | 38 +++++ src/compiler/codegen/token_order.jl | 130 ++++++++++++++ src/compiler/codegen/utils.jl | 55 +++++- src/compiler/intrinsics/memory.jl | 108 +++++++++--- src/compiler/intrinsics/views.jl | 60 +++++-- 9 files changed, 642 insertions(+), 43 deletions(-) create mode 100644 src/compiler/codegen/alias_analysis.jl create mode 100644 src/compiler/codegen/token_keys.jl create mode 100644 src/compiler/codegen/token_order.jl diff --git a/src/compiler/codegen.jl b/src/compiler/codegen.jl index 564aa8ea..6cb0af0e 100644 --- a/src/compiler/codegen.jl +++ b/src/compiler/codegen.jl @@ -1,6 +1,9 @@ # Codegen: Julia IR -> Tile IR bytecode include("codegen/utils.jl") +include("codegen/token_keys.jl") # Defines TokenKey, TokenRole, ACQUIRE_TOKEN_KEY +include("codegen/alias_analysis.jl") # Defines alias_analysis_pass! +include("codegen/token_order.jl") # Defines get_alias_set, get_input_token! 
include("codegen/kernel.jl") include("codegen/control_flow.jl") include("codegen/statements.jl") diff --git a/src/compiler/codegen/alias_analysis.jl b/src/compiler/codegen/alias_analysis.jl new file mode 100644 index 00000000..87dcfd7e --- /dev/null +++ b/src/compiler/codegen/alias_analysis.jl @@ -0,0 +1,228 @@ +""" + AliasTracker + +Tracks alias sets for each SSA value during fixed-point analysis. +""" +mutable struct AliasTracker + dirty::Bool + aliases::Dict{Any, AliasSet} # SSAValue/Argument/SlotNumber -> AliasSet +end + +AliasTracker() = AliasTracker(false, Dict{Any, AliasSet}()) + +function Base.getindex(tracker::AliasTracker, key) + return get(tracker.aliases, key, ALIAS_UNIVERSE) +end + +function Base.setindex!(tracker::AliasTracker, value::AliasSet, key) + current = get(tracker.aliases, key, nothing) + if current !== value + tracker.dirty = true + tracker.aliases[key] = value + end + return +end + +""" + alias_analysis_pass!(sci::StructuredIRCode) -> Dict{Any, AliasSet} + +Perform fixed-point alias analysis on structured IR. +Returns mapping from SSA values to alias sets. +""" +function alias_analysis_pass!(sci::StructuredIRCode) + tracker = AliasTracker() + + # Initialize: each argument gets its own alias set + for (idx, argtype) in enumerate(sci.argtypes) + argtype_unwrapped = CC.widenconst(argtype) + if contains_pointers(argtype_unwrapped) + arg_ref = Argument(idx) + tracker[arg_ref] = Set{Any}([arg_ref]) + end + end + + # Fixed-point iteration + iteration = 0 + max_iterations = 100 + + tracker.dirty = true + while tracker.dirty && iteration < max_iterations + tracker.dirty = false + iteration += 1 + + analyze_block!(tracker, sci.entry) + end + + @debug "Alias analysis converged in $iteration iterations" + + return tracker.aliases +end + +""" + propagate!(tracker::AliasTracker, from, to) + +Propagate alias set from `from` to `to`. +Uses direct assignment when `to` is uninitialized, union otherwise. 
+""" +function propagate!(tracker::AliasTracker, from, to) + from_aliases = tracker[from] + + if from_aliases === ALIAS_UNIVERSE + # Propagating UNIVERSE is always conservative + tracker[to] = ALIAS_UNIVERSE + return + end + + if haskey(tracker.aliases, to) + # Target already has an alias set union with it + to_aliases = tracker.aliases[to] + new_aliases = union(from_aliases, to_aliases) + if new_aliases != to_aliases + tracker[to] = new_aliases + end + else + # Target not yet analyzed assign directly + tracker[to] = from_aliases + end + return +end + +""" + analyze_block!(tracker::AliasTracker, block) + +Analyze all statements in a block, recursing into nested control flow. +""" +function analyze_block!(tracker::AliasTracker, block) + for (ssa_idx, entry) in block.body + if entry.stmt isa ControlFlowOp + analyze_control_flow!(tracker, entry.stmt) + else + analyze_statement!(tracker, SSAValue(ssa_idx), entry.stmt) + end + end + return +end + +# Recurse into nested control flow regions +function analyze_control_flow!(tracker::AliasTracker, op::IfOp) + analyze_block!(tracker, op.then_region) + return analyze_block!(tracker, op.else_region) +end + +function analyze_control_flow!(tracker::AliasTracker, op::ForOp) + return analyze_block!(tracker, op.body) +end + +function analyze_control_flow!(tracker::AliasTracker, op::WhileOp) + analyze_block!(tracker, op.before) + return analyze_block!(tracker, op.after) +end + +function analyze_control_flow!(tracker::AliasTracker, op::LoopOp) + return analyze_block!(tracker, op.body) +end + +# Fallback for unknown control flow ops +function analyze_control_flow!(::AliasTracker, ::ControlFlowOp) + return +end + +""" + analyze_statement!(tracker::AliasTracker, ssa::SSAValue, stmt) + +Analyze a single statement and propagate aliases. +Handles both `:call` and `:invoke` expression forms. 
+""" +function analyze_statement!(tracker::AliasTracker, ssa::SSAValue, stmt) + if stmt isa Expr && (stmt.head === :call || stmt.head === :invoke) + # Normalize :call and :invoke into (func, operands) + # :call -> args = [func, operands...] + # :invoke -> args = [MethodInstance, func, operands...] + if stmt.head === :call + func = stmt.args[1] + operands = @view stmt.args[2:end] + else # :invoke + func = stmt.args[2] + operands = @view stmt.args[3:end] + end + + # Resolve func to its runtime value for intrinsic matching. + # In :invoke, func may already be the function object (not a GlobalRef). + resolved_func = if func isa GlobalRef + try + getfield(func.mod, func.name) + catch + nothing + end + else + func # Direct function value (common in :invoke) + end + + # getfield: propagate from parent + if func === GlobalRef(Core, :getfield) && length(operands) >= 1 + field = length(operands) >= 2 ? operands[2] : nothing + + # For TileArray.ptr field access, propagate pointer alias + if field isa QuoteNode && field.value === :ptr + propagate!(tracker, operands[1], ssa) + else + # Conservatively mark as UNIVERSE for non-pointer fields + tracker[ssa] = ALIAS_UNIVERSE + end + + # Pointer arithmetic: propagate from pointer operand + elseif func === GlobalRef(Base, :+) || func === GlobalRef(Base, :-) + for arg in operands + # Find the pointer argument and propagate + arg_aliases = tracker[arg] + if arg_aliases !== ALIAS_UNIVERSE && arg_aliases isa Set + propagate!(tracker, arg, ssa) + break + end + end + + # View construction: propagate alias from first operand + elseif is_view_constructor(resolved_func) || is_pointer_passthrough(resolved_func) + if length(operands) >= 1 + propagate!(tracker, operands[1], ssa) + end + + # Default: unknown operation -> UNIVERSE + else + tracker[ssa] = ALIAS_UNIVERSE + end + + elseif stmt isa ReturnNode + # No alias propagation needed + + else + # Unknown statement type -> conservative + tracker[ssa] = ALIAS_UNIVERSE + end + return +end + +# 
Helper functions +contains_pointers(T) = T <: Ptr || T <: TileArray || (T <: Tile && eltype(T) <: Ptr) + +""" + is_view_constructor(func) -> Bool + +Check if a resolved function is a tensor/partition view constructor. +These propagate alias identity from their first operand. +""" +function is_view_constructor(func) + return func === Intrinsics.make_tensor_view || + func === Intrinsics.make_partition_view +end + +function is_pointer_passthrough(func) + func === GlobalRef(Core.Intrinsics, :bitcast) && return true + + # Safely check by name to avoid UndefVarError if intrinsics aren't exposed + if func isa Core.IntrinsicFunction || func isa Function + n = nameof(func) + return n === :bitcast || n === :assume_div_by || n === :assume_bounded || n === :assume_aligned + end + return false +end diff --git a/src/compiler/codegen/control_flow.jl b/src/compiler/codegen/control_flow.jl index b5a4db3d..692edb26 100644 --- a/src/compiler/codegen/control_flow.jl +++ b/src/compiler/codegen/control_flow.jl @@ -88,10 +88,14 @@ function emit_if_op!(ctx::CGCtx, op::IfOp, @nospecialize(parent_result_type), n_ # Save token before branches token_before = ctx.token + # Save token_map before branches + token_map_before = copy(ctx.token_map) + # Emit IfOp with callback-based region building then_body = function(_) saved_block_args = copy(ctx.block_args) ctx.token = token_before # Reset to pre-branch token + ctx.token_map = copy(token_map_before) # Reset token_map too emit_block!(ctx, then_blk) if then_blk.terminator === nothing encode_YieldOp!(ctx.cb, [ctx.token]) @@ -102,6 +106,7 @@ function emit_if_op!(ctx::CGCtx, op::IfOp, @nospecialize(parent_result_type), n_ else_body = function(_) saved_block_args = copy(ctx.block_args) ctx.token = token_before # Reset to pre-branch token + ctx.token_map = copy(token_map_before) # Reset token_map too emit_block!(ctx, else_blk) if else_blk.terminator === nothing encode_YieldOp!(ctx.cb, [ctx.token]) @@ -114,6 +119,12 @@ function emit_if_op!(ctx::CGCtx, 
op::IfOp, @nospecialize(parent_result_type), n_ # Last result is the merged token from both branches ctx.token = results[end] + # Merge token_map from both branches + # Conservatively reset to token_before for all keys + for key in keys(ctx.token_map) + ctx.token_map[key] = results[end] + end + # Store results at IfOp's SSA index (may be empty for void-returning ifs) ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type) end @@ -164,6 +175,9 @@ function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type), # Number of user result types (excluding token) n_user_results = n_carries + # Save token_map before loop + token_map_before = copy(ctx.token_map) + # Emit ForOp with callback-based region building body_builder = function(block_args) saved_block_args = copy(ctx.block_args) @@ -193,8 +207,11 @@ function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type), end results = encode_ForOp!(body_builder, cb, result_types, iv_type, lower_tv.v, upper_tv.v, step_tv.v, init_values) - # Last result is the token - ctx.token = results[end] + ctx.token = ctx.global_token + + for key in keys(token_map_before) + ctx.token_map[key] = ctx.global_token + end # Store results at the loop's SSA index (may be empty for void-returning loops) ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type) @@ -230,6 +247,9 @@ function emit_loop_op!(ctx::CGCtx, op::LoopOp, @nospecialize(parent_result_type) # Number of user result types (excluding token) n_user_results = n_carries + # Save token_map before loop + token_map_before = copy(ctx.token_map) + # Emit LoopOp with callback-based region building body_builder = function(block_args) saved_block_args = copy(ctx.block_args) @@ -263,8 +283,11 @@ function emit_loop_op!(ctx::CGCtx, op::LoopOp, @nospecialize(parent_result_type) end results = encode_LoopOp!(body_builder, cb, result_types, init_values) - # Last result is the token - ctx.token = results[end] + ctx.token = 
ctx.global_token + + for key in keys(token_map_before) + ctx.token_map[key] = ctx.global_token + end # Store results at the loop's SSA index (may be empty for void-returning loops) ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type) @@ -301,6 +324,9 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ # Number of user result types (excluding token) n_user_results = n_carries + # Save token_map before loop + token_map_before = copy(ctx.token_map) + # Emit WhileOp as cuda_tile.loop with conditional break pattern # MLIR structure: before { stmts; condition(cond) args } do { stmts; yield vals } # Emitted as: loop { before_stmts; if(!cond) { break } else { yield }; after_stmts; continue } @@ -393,8 +419,11 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ end results = encode_LoopOp!(body_builder, cb, result_types, init_values) - # Last result is the token - ctx.token = results[end] + ctx.token = ctx.global_token + + for key in keys(token_map_before) + ctx.token_map[key] = ctx.global_token + end # Store results at the loop's SSA index (may be empty for void-returning loops) ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type) diff --git a/src/compiler/codegen/kernel.jl b/src/compiler/codegen/kernel.jl index 79e9a61d..c722cbee 100644 --- a/src/compiler/codegen/kernel.jl +++ b/src/compiler/codegen/kernel.jl @@ -141,10 +141,30 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8}, create_tensor_views!(ctx, arg_idx, argtype, Int[]) end + # Run alias analysis FIRST + alias_result = alias_analysis_pass!(sci) + ctx.alias_result = alias_result + # Create memory ordering token token_type = Token(tt) ctx.token_type = token_type - ctx.token = encode_MakeTokenOp!(cb, token_type) + root_token = encode_MakeTokenOp!(cb, token_type) + + ctx.global_token = root_token + ctx.token = root_token + + # Initialize token map with root token for all alias sets + # 
Default: all tokens start at root + ctx.token_map = Dict{TokenKey, Value}() + + unique_alias_sets = Set(values(alias_result)) + for alias_set in unique_alias_sets + ctx.token_map[last_op_key(alias_set)] = root_token + ctx.token_map[last_store_key(alias_set)] = root_token + end + + # ACQUIRE token also starts at root + ctx.token_map[ACQUIRE_TOKEN_KEY] = root_token # Hoist early returns out of IfOp regions (tileiras rejects ReturnOp inside IfOp) hoist_returns!(ctx.sci.entry) diff --git a/src/compiler/codegen/token_keys.jl b/src/compiler/codegen/token_keys.jl new file mode 100644 index 00000000..07c448ec --- /dev/null +++ b/src/compiler/codegen/token_keys.jl @@ -0,0 +1,38 @@ +# Token role enum +@enum TokenRole LAST_OP LAST_STORE + +# Acquire token key (singleton) +struct AcquireTokenKey end +const ACQUIRE_TOKEN_KEY = AcquireTokenKey() + +# Alias token key (per alias set and role) +struct AliasTokenKey + alias_set::AliasSet + role::TokenRole +end + +# Union type for all token keys +const TokenKey = Union{AliasTokenKey, AcquireTokenKey} + +# Helper constructors +""" + last_op_key(alias_set::AliasSet) -> AliasTokenKey + +Create a TokenKey for the last operation (load or store) on an alias set. +""" +last_op_key(alias_set::AliasSet) = AliasTokenKey(alias_set, LAST_OP) + +""" + last_store_key(alias_set::AliasSet) -> AliasTokenKey + +Create a TokenKey for the last store operation on an alias set. 
+""" +last_store_key(alias_set::AliasSet) = AliasTokenKey(alias_set, LAST_STORE) + +# Make TokenKey hashable for use in Dict +Base.hash(key::AliasTokenKey, h::UInt) = hash((key.alias_set, key.role), h) +Base.:(==)(a::AliasTokenKey, b::AliasTokenKey) = + a.alias_set == b.alias_set && a.role == b.role + +Base.hash(::AcquireTokenKey, h::UInt) = hash(:ACQUIRE_TOKEN_KEY, h) +Base.:(==)(::AcquireTokenKey, ::AcquireTokenKey) = true diff --git a/src/compiler/codegen/token_order.jl b/src/compiler/codegen/token_order.jl new file mode 100644 index 00000000..6cee31c0 --- /dev/null +++ b/src/compiler/codegen/token_order.jl @@ -0,0 +1,130 @@ +""" + get_input_var(args) -> Any + +Extract the pointer/array variable from memory operation arguments. +""" +function get_input_var(args) + return args[1] +end + +""" + get_alias_set(ctx::CGCtx, var) -> AliasSet + +Get the alias set for a variable from analysis results. +""" +function get_alias_set(ctx::CGCtx, var) + # Trace to source + source = trace_to_source(ctx, var) + + # Lookup in alias results + return get(ctx.alias_result, source, ALIAS_UNIVERSE) +end + +""" + collect_join_tokens(ctx::CGCtx, token_key::TokenKey, memory_order=nothing) -> Vector{Value} + +Collect all tokens that need to be joined for synchronization. +Based on Python's `_collect_join_tokens`. 
+""" +function collect_join_tokens(ctx::CGCtx, token_key::TokenKey, memory_order = nothing) + tokens_to_join = [ctx.token_map[token_key]] + + for (other_key, other_token) in ctx.token_map + should_join = false + + # Join with ACQUIRE token + if other_key isa AcquireTokenKey + should_join = true + + # Join if alias sets overlap + elseif other_key isa AliasTokenKey && token_key isa AliasTokenKey + # Release memory order: join with all LAST_OP tokens + if memory_order !== nothing && has_release_order(memory_order) + should_join = other_key.role == LAST_OP + end + + # Alias set overlap: join if same role and sets overlap + if other_key.role == token_key.role + alias_overlap = !(other_key.alias_set isa AliasUniverse) && + !(token_key.alias_set isa AliasUniverse) && + !isempty(intersect(other_key.alias_set, token_key.alias_set)) + should_join = should_join || alias_overlap + end + end + + # Skip tokens already present by identity avoids join_tokens(%x,%x,%x) + already_present = any(t -> t === other_token, tokens_to_join) + if should_join && !already_present + push!(tokens_to_join, other_token) + end + end + + return tokens_to_join +end + +""" + get_input_token!(ctx::CGCtx, token_key::TokenKey, memory_order=nothing) + -> (Value, Union{Nothing, JoinOp}) + +Get the input token for an operation, potentially creating a join operation. +""" +function get_input_token!(ctx::CGCtx, token_key::TokenKey, memory_order = nothing) + + if !haskey(ctx.token_map, token_key) + @warn "Token key not found in token_map!" 
token_key available_keys=keys(ctx.token_map) + # Fallback to root token + return (ctx.token_map[ACQUIRE_TOKEN_KEY], nothing) + end + tokens_to_join = collect_join_tokens(ctx, token_key, memory_order) + + if length(tokens_to_join) == 1 + return (tokens_to_join[1], nothing) + end + + # Join multiple tokens + result_token = encode_JoinTokensOp!(ctx.cb, ctx.token_type, tokens_to_join) + + return (result_token, nothing) # Return nothing for join_op since its already been emitted +end + +""" + trace_to_source(ctx::CGCtx, var) -> Any + +Trace a value back to its original source (Argument, SSAValue). +""" +function trace_to_source(ctx::CGCtx, var) + # Returns if its an Argument or SSAValue + if var isa Argument || var isa SSAValue + return var + end + + # Resolve for SlothNumber + if var isa SlotNumber + tv = get(ctx.slots, var.id, nothing) + if tv !== nothing && is_arg_ref(tv) + arg_idx, _ = tv.arg_ref + return Argument(arg_idx) + end + end + + # Generic emit_value resolution + tv = emit_value!(ctx, var) + if tv !== nothing && is_arg_ref(tv) + arg_idx, _ = tv.arg_ref + return Argument(arg_idx) + end + + # Return as is for unknown + return var +end + +""" + has_release_order(memory_order) -> Bool + +Check if memory order has release semantics. +For now, returns false (no memory order support yet). +""" +function has_release_order(memory_order) + # TODO: Implement proper memory order checking when needed + return false +end diff --git a/src/compiler/codegen/utils.jl b/src/compiler/codegen/utils.jl index 393811b2..ab3a90b9 100644 --- a/src/compiler/codegen/utils.jl +++ b/src/compiler/codegen/utils.jl @@ -2,6 +2,38 @@ # # Core types (CGVal, CGCtx) and helper functions for Tile IR code generation. 
+ +#============================================================================= + Alias Analysis Types +=============================================================================# + +""" + AliasUniverse + +Singleton type representing the universal alias set (everything may alias everything). +""" +struct AliasUniverse end +const ALIAS_UNIVERSE = AliasUniverse() + +# Universe behaves specially in set operations +Base.union(::AliasUniverse, ::AliasUniverse) = ALIAS_UNIVERSE +Base.union(::AliasUniverse, other) = ALIAS_UNIVERSE +Base.union(other, ::AliasUniverse) = ALIAS_UNIVERSE +Base.intersect(::AliasUniverse, other) = other +Base.intersect(other, ::AliasUniverse) = other +Base.:(==)(::AliasUniverse, ::AliasUniverse) = true +Base.:(==)(::AliasUniverse, other) = false +Base.:(==)(other, ::AliasUniverse) = false + +""" + AliasSet + +Union type representing either a concrete set of values that may alias, +or the universal alias set (ALIAS_UNIVERSE). +""" +const AliasSet = Union{Set{Any}, AliasUniverse} + + #============================================================================= IRError: Exception type for IR compilation errors =============================================================================# @@ -165,7 +197,9 @@ mutable struct CGCtx tt::TypeTable sci::StructuredIRCode - # Memory ordering token + # Memory ordering token (kept for backward compatibility) + tokens::Dict{UInt64, Value} + global_token::Union{Value, Nothing} token::Union{Value, Nothing} token_type::Union{TypeId, Nothing} @@ -177,6 +211,10 @@ mutable struct CGCtx # Compilation cache (needed for combiner compilation) cache::CacheView + + # Alias-aware token system + alias_result::Dict{Any, AliasSet} # From alias analysis + token_map::Dict{Any, Value} # TokenKey -> current token Value end function CGCtx(; cb::CodeBuilder, tt::TypeTable, sci::StructuredIRCode, @@ -192,8 +230,19 @@ function CGCtx(; cb::CodeBuilder, tt::TypeTable, sci::StructuredIRCode, Dict{Int, CGVal}(), Dict{Tuple{Int, 
Vector{Int}}, Vector{Value}}(), Dict{Int, Type}(), - Dict{Any, Tuple{Value, TypeId}}(), - cb, tt, sci, token, token_type, type_cache, sm_arch, cache, + Dict{Int, Tuple{Value, TypeId}}(), + cb, + tt, + sci, + Dict{UInt64, Value}(), # tokens (old system) + nothing, # global_token (old system) + token, # token (old system) + token_type, + type_cache, + sm_arch, + cache, + Dict{Any, AliasSet}(), # alias_result + Dict{Any, Value}(), # token_map ) end diff --git a/src/compiler/intrinsics/memory.jl b/src/compiler/intrinsics/memory.jl index 5866bd66..f575ae58 100644 --- a/src/compiler/intrinsics/memory.jl +++ b/src/compiler/intrinsics/memory.jl @@ -39,6 +39,15 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.load_ptr_tko), args) mask_tv, has_mask = emit_optional_mask(ctx, args, 3) + # Get alias set use global token if unknown + alias_set = get_alias_set(ctx, args[1]) + input_token = if alias_set isa AliasUniverse + ctx.token + else + last_store_key_val = last_store_key(alias_set) + first(get_input_token!(ctx, last_store_key_val, nothing)) + end + if has_mask mask = mask_tv.v @@ -48,22 +57,39 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.load_ptr_tko), args) padding = padding_tv.v # Load with mask and padding - tile_val, new_token = encode_LoadPtrTkoOp!(cb, result_tile_type, token_type, pointers; - mask=mask, - padding_value=padding, - token=ctx.token, - optimization_hints) + tile_val, new_token = encode_LoadPtrTkoOp!( + cb, result_tile_type, token_type, pointers; + mask = mask, + padding_value = padding, + token = input_token, + optimization_hints + ) else # Load without mask - tile_val, new_token = encode_LoadPtrTkoOp!(cb, result_tile_type, token_type, pointers; - token=ctx.token, - optimization_hints) + tile_val, new_token = encode_LoadPtrTkoOp!( + cb, result_tile_type, token_type, pointers; + token = input_token, + optimization_hints + ) + end + + # Only track alias if we have a real alias set + if alias_set isa AliasUniverse + ctx.token = 
new_token + else + last_op_key_val = last_op_key(alias_set) + last_op_token = get(ctx.token_map, last_op_key_val, nothing) + if last_op_token === nothing || last_op_token === input_token || last_op_token === new_token + new_last_op_token = new_token + else + new_last_op_token = encode_JoinTokensOp!(ctx.cb, token_type, [last_op_token, new_token]) + end + ctx.token_map[last_op_key_val] = new_last_op_token end - ctx.token = new_token julia_shape = ColMajorShape(tile_shape) result_jltype = Tile{elem_type, TupleType(julia_shape)} - CGVal(tile_val, result_tile_type, result_jltype, tile_shape) + return CGVal(tile_val, result_tile_type, result_jltype, tile_shape) end # TODO: cuda_tile.make_token @@ -98,21 +124,57 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.store_ptr_tko), args) mask_tv, has_mask = emit_optional_mask(ctx, args, 4) - if has_mask - mask = mask_tv.v - - # Store with mask - new_token = encode_StorePtrTkoOp!(cb, token_type, pointers, values; - mask=mask, - token=ctx.token, - optimization_hints) + alias_set = get_alias_set(ctx, args[1]) + + if alias_set isa AliasUniverse + # Baseline behavior: use global token directly, no alias tracking overhead + if has_mask + mask = mask_tv.v + + # Store with mask + new_token = encode_StorePtrTkoOp!( + cb, token_type, pointers, values; + mask = mask, + token = ctx.token, + optimization_hints + ) + else + new_token = encode_StorePtrTkoOp!( + cb, token_type, pointers, values; + token = ctx.token, + optimization_hints + ) + end + ctx.token = new_token else - # Store without mask - new_token = encode_StorePtrTkoOp!(cb, token_type, pointers, values; - token=ctx.token, - optimization_hints) + last_op_key_val = last_op_key(alias_set) + last_store_key_val = last_store_key(alias_set) + + # Store depends on LAST_OP (write after read/write) + input_token, _ = get_input_token!(ctx, last_op_key_val, nothing) + + if has_mask + mask = mask_tv.v + + new_token = encode_StorePtrTkoOp!( + cb, token_type, pointers, values; + mask 
= mask, + token = input_token, + optimization_hints + ) + else + new_token = encode_StorePtrTkoOp!( + cb, token_type, pointers, values; + token = input_token, + optimization_hints + ) + end + + # Update both LAST_OP and LAST_STORE. + # Do NOT update ctx.token — alias-aware path uses token_map only. + ctx.token_map[last_op_key_val] = new_token + ctx.token_map[last_store_key_val] = new_token end - ctx.token = new_token nothing end diff --git a/src/compiler/intrinsics/views.jl b/src/compiler/intrinsics/views.jl index 11e1b99f..17756c64 100644 --- a/src/compiler/intrinsics/views.jl +++ b/src/compiler/intrinsics/views.jl @@ -108,13 +108,36 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.load_partition_view), a # Create optimization hints if provided optimization_hints = create_optimization_hints(ctx, latency, allow_tma_val) - # Load tile with token - tile_val, new_token = encode_LoadViewTkoOp!(cb, tile_type, token_type, pv_arg.v, index_vals; - token=ctx.token, optimization_hints) - ctx.token = new_token + # Get alias set fall back to simple token threading if unknown + alias_set = get_alias_set(ctx, args[1]) + + if alias_set isa AliasUniverse + # Baseline behavior: use global token directly, no alias tracking overhead + tile_val, result_token = encode_LoadViewTkoOp!( + cb, tile_type, token_type, pv_arg.v, index_vals; + token = ctx.token, optimization_hints + ) + ctx.token = result_token + else + last_store_key_val = last_store_key(alias_set) + input_token, _ = get_input_token!(ctx, last_store_key_val, nothing) + tile_val, result_token = encode_LoadViewTkoOp!( + cb, tile_type, token_type, pv_arg.v, index_vals; + token = input_token, optimization_hints + ) + last_op_key_val = last_op_key(alias_set) + last_op_token = get(ctx.token_map, last_op_key_val, result_token) + # Only join if last_op_token is not already in the causal chain of result_token. 
+ # result_token was produced from input_token, so if last_op_token === input_token + # the join is redundant — result_token already implies last_op_token. + new_last_op_token = last_op_token === input_token ? result_token : + encode_JoinTokensOp!(ctx.cb, token_type, [last_op_token, result_token]) + ctx.token_map[last_op_key_val] = new_last_op_token + # Do NOT update ctx.token — alias-aware path uses token_map only. + end julia_shape = ColMajorShape(tile_shape) - CGVal(tile_val, tile_type, Tile{elem_type, TupleType(julia_shape)}, tile_shape) + return CGVal(tile_val, tile_type, Tile{elem_type, TupleType(julia_shape)}, tile_shape) end function pad_indices(ctx::CGCtx, index_vals::Vector{Value}, ndim::Int, idx_type::TypeId, idx_jl_type::Type) @@ -414,11 +437,28 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.store_partition_view), # Create optimization hints if provided optimization_hints = create_optimization_hints(ctx, latency, allow_tma_val) - # Store tile with token + # Get alias set — fall back to simple token threading if unknown + alias_set = get_alias_set(ctx, args[1]) token_type = Token(tt) - new_token = encode_StoreViewTkoOp!(cb, token_type, tile_val, pv_arg.v, index_vals; - token=ctx.token, optimization_hints) - ctx.token = new_token - nothing + if alias_set isa AliasUniverse + result_token = encode_StoreViewTkoOp!( + cb, token_type, tile_val, pv_arg.v, index_vals; + token = ctx.token, optimization_hints + ) + ctx.token = result_token + else + last_op_key_val = last_op_key(alias_set) + last_store_key_val = last_store_key(alias_set) + input_token, _ = get_input_token!(ctx, last_op_key_val, nothing) + result_token = encode_StoreViewTkoOp!( + cb, token_type, tile_val, pv_arg.v, index_vals; + token = input_token, optimization_hints + ) + ctx.token_map[last_op_key_val] = result_token + ctx.token_map[last_store_key_val] = result_token + # Do NOT update ctx.token — alias-aware path uses token_map only. 
+ end + + return nothing end From 8c5ea50b08627aab592a3d4be6d03fa34a552cb8 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Thu, 26 Mar 2026 09:19:33 +0100 Subject: [PATCH 02/10] Refactor token ordering into a separate IR pass. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move token threading from inline codegen to a `token_order_pass!` that runs on StructuredIRCode before bytecode emission. The pass: - Inserts MakeTokenNode at function entry - Adds token arguments to memory operations (loads/stores/atomics) - Inserts JoinTokensNode and TokenResultNode for token flow tracking - Uses alias analysis to give independent arrays independent tokens This decouples token ordering decisions from codegen, matching cuTile Python's architecture (res/cutile-python/src/cuda/tile/_passes/token_order.py). Control flow token threading (loops, branches) is still handled by codegen conservatively; the pass only transforms straight-line code before the first control flow op. Per-alias loop carries will be added in a follow-up. Key changes: - New: codegen/irutils.jl — SSAMap mutation helpers (insert_before!, etc.) 
- New: codegen/passes/ directory for IR passes - Moved: alias_analysis.jl, token_keys.jl, token_order.jl → passes/ - Simplified: memory.jl, views.jl, atomics.jl — read token from IR args via extract_token_arg!(), fall back to ctx.token inside control flow - Simplified: control_flow.jl — removed token_map save/restore, kept single-token loop carry (conservative) - Removed: ctx.token_map, ctx.global_token, ctx.alias_result from CGCtx Co-Authored-By: Claude Opus 4.6 (1M context) --- src/compiler/codegen.jl | 7 +- src/compiler/codegen/control_flow.jl | 234 ++-------- src/compiler/codegen/irutils.jl | 66 +++ src/compiler/codegen/kernel.jl | 30 +- .../codegen/{ => passes}/alias_analysis.jl | 0 .../codegen/{ => passes}/token_keys.jl | 0 src/compiler/codegen/passes/token_order.jl | 419 ++++++++++++++++++ src/compiler/codegen/statements.jl | 51 ++- src/compiler/codegen/token_order.jl | 130 ------ src/compiler/codegen/utils.jl | 83 +++- src/compiler/intrinsics/atomics.jl | 18 +- src/compiler/intrinsics/memory.jl | 146 ++---- src/compiler/intrinsics/views.jl | 67 +-- src/cuTile.jl | 2 +- 14 files changed, 734 insertions(+), 519 deletions(-) create mode 100644 src/compiler/codegen/irutils.jl rename src/compiler/codegen/{ => passes}/alias_analysis.jl (100%) rename src/compiler/codegen/{ => passes}/token_keys.jl (100%) create mode 100644 src/compiler/codegen/passes/token_order.jl delete mode 100644 src/compiler/codegen/token_order.jl diff --git a/src/compiler/codegen.jl b/src/compiler/codegen.jl index 6cb0af0e..50336afc 100644 --- a/src/compiler/codegen.jl +++ b/src/compiler/codegen.jl @@ -1,9 +1,10 @@ # Codegen: Julia IR -> Tile IR bytecode include("codegen/utils.jl") -include("codegen/token_keys.jl") # Defines TokenKey, TokenRole, ACQUIRE_TOKEN_KEY -include("codegen/alias_analysis.jl") # Defines alias_analysis_pass! -include("codegen/token_order.jl") # Defines get_alias_set, get_input_token! 
+include("codegen/irutils.jl") # SSAMap/Block mutation helpers +include("codegen/passes/token_keys.jl") # TokenKey, TokenRole, ACQUIRE_TOKEN_KEY +include("codegen/passes/alias_analysis.jl") # alias_analysis_pass! +include("codegen/passes/token_order.jl") # token_order_pass! include("codegen/kernel.jl") include("codegen/control_flow.jl") include("codegen/statements.jl") diff --git a/src/compiler/codegen/control_flow.jl b/src/compiler/codegen/control_flow.jl index 692edb26..bc7aa51e 100644 --- a/src/compiler/codegen/control_flow.jl +++ b/src/compiler/codegen/control_flow.jl @@ -4,11 +4,6 @@ result_count(T) -> Int Compute the number of results from a Block.types entry. -Block.types contains Julia types: -- For Statement: Julia type → 1 result -- For ControlFlowOp with 0 results: Nothing → 0 results -- For ControlFlowOp with 1 result: Julia type → 1 result -- For ControlFlowOp with N results: Tuple{T1, T2, ...} → N results """ function result_count(@nospecialize(T)) T === Nothing && return 0 @@ -20,12 +15,8 @@ end emit_block!(ctx, block::Block) Emit bytecode for a structured IR block. -All SSA values use original Julia SSA indices (no local renumbering). -Values are stored in ctx.values by their original index. """ function emit_block!(ctx::CGCtx, block::Block; skip_terminator::Bool=false) - # Emit body items (interleaved expressions and control flow ops) - # SSAVector iteration yields (ssa_idx, entry) where entry has .stmt and .typ for (ssa_idx, entry) in block.body if entry.stmt isa ControlFlowOp n_results = result_count(entry.typ) @@ -35,20 +26,11 @@ function emit_block!(ctx::CGCtx, block::Block; skip_terminator::Bool=false) end end - # Emit terminator (unless skipped) if !skip_terminator && block.terminator !== nothing emit_terminator!(ctx, block.terminator) end end -""" - emit_control_flow_op!(ctx, op::ControlFlowOp, result_type, n_results, original_idx) - -Emit bytecode for a structured control flow operation. 
-Uses multiple dispatch on the concrete ControlFlowOp type. -Results are stored at indices assigned AFTER nested regions (DFS order). -original_idx is the original Julia SSA index for cross-block references. -""" emit_control_flow_op!(ctx::CGCtx, op::IfOp, @nospecialize(result_type), n_results::Int, original_idx::Int) = emit_if_op!(ctx, op, result_type, n_results, original_idx) emit_control_flow_op!(ctx::CGCtx, op::ForOp, @nospecialize(result_type), n_results::Int, original_idx::Int) = @@ -58,44 +40,40 @@ emit_control_flow_op!(ctx::CGCtx, op::WhileOp, @nospecialize(result_type), n_res emit_control_flow_op!(ctx::CGCtx, op::LoopOp, @nospecialize(result_type), n_results::Int, original_idx::Int) = emit_loop_op!(ctx, op, result_type, n_results, original_idx) +#============================================================================= + Control flow emitters + Token threading through control flow is still manual (conservative approach). + The token_order_pass handles straight-line code; control flow uses ctx.token. 
+=============================================================================# + function emit_if_op!(ctx::CGCtx, op::IfOp, @nospecialize(parent_result_type), n_results::Int, ssa_idx::Int) cb = ctx.cb then_blk = op.then_region else_blk = op.else_region - # Get condition value cond_tv = emit_value!(ctx, op.condition) cond_tv === nothing && throw(IRError("Cannot resolve condition for IfOp")) - # Determine result types from parent_result_type + # User result types result_types = TypeId[] - julia_result_types = Type[] if parent_result_type === Nothing # No results elseif parent_result_type <: Tuple for T in parent_result_type.parameters push!(result_types, tile_type_for_julia!(ctx, T)) - push!(julia_result_types, T) end else push!(result_types, tile_type_for_julia!(ctx, parent_result_type)) - push!(julia_result_types, parent_result_type) end n_user_results = length(result_types) - # Add token type as additional result (for memory ordering) + # Add token as additional result push!(result_types, ctx.token_type) - # Save token before branches token_before = ctx.token - # Save token_map before branches - token_map_before = copy(ctx.token_map) - - # Emit IfOp with callback-based region building then_body = function(_) saved_block_args = copy(ctx.block_args) - ctx.token = token_before # Reset to pre-branch token - ctx.token_map = copy(token_map_before) # Reset token_map too + ctx.token = token_before emit_block!(ctx, then_blk) if then_blk.terminator === nothing encode_YieldOp!(ctx.cb, [ctx.token]) @@ -105,8 +83,7 @@ function emit_if_op!(ctx::CGCtx, op::IfOp, @nospecialize(parent_result_type), n_ end else_body = function(_) saved_block_args = copy(ctx.block_args) - ctx.token = token_before # Reset to pre-branch token - ctx.token_map = copy(token_map_before) # Reset token_map too + ctx.token = token_before emit_block!(ctx, else_blk) if else_blk.terminator === nothing encode_YieldOp!(ctx.cb, [ctx.token]) @@ -116,25 +93,14 @@ function emit_if_op!(ctx::CGCtx, op::IfOp, 
@nospecialize(parent_result_type), n_ end results = encode_IfOp!(then_body, else_body, cb, result_types, cond_tv.v) - # Last result is the merged token from both branches ctx.token = results[end] - - # Merge token_map from both branches - # Conservatively reset to token_before for all keys - for key in keys(ctx.token_map) - ctx.token_map[key] = results[end] - end - - # Store results at IfOp's SSA index (may be empty for void-returning ifs) ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type) end function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type), n_results::Int, ssa_idx::Int) cb = ctx.cb - tt = ctx.tt body_blk = op.body - # Get bounds values lower_tv = emit_value!(ctx, op.lower) upper_tv = emit_value!(ctx, op.upper) step_tv = emit_value!(ctx, op.step) @@ -142,62 +108,40 @@ function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type), (lower_tv === nothing || upper_tv === nothing || step_tv === nothing) && throw(IRError("Cannot resolve ForOp bounds")) - - # Assert all bounds have the same type lower_tv.jltype === upper_tv.jltype === step_tv.jltype || - throw(IRError("ForOp bounds must all have the same type: lower=$(lower_tv.jltype), upper=$(upper_tv.jltype), step=$(step_tv.jltype)")) + throw(IRError("ForOp bounds must all have the same type")) iv_jl_type = lower_tv.jltype iv_type = tile_type_for_julia!(ctx, iv_jl_type) - # Get init values (init_values are loop-carried values) + # Init values + token init_values = Value[] for init_val in op.init_values tv = emit_value!(ctx, init_val) (tv === nothing || tv.v === nothing) && throw(IRError("Cannot resolve ForOp init value")) push!(init_values, tv.v) end - # Add token as additional init value (for memory ordering) push!(init_values, ctx.token) - # Number of carries (init_values) - these are the loop results n_carries = length(op.init_values) - # Determine result types from carries (body.args) result_types = TypeId[] for i in 1:n_carries body_arg = 
body_blk.args[i] - type_id = tile_type_for_julia!(ctx, body_arg.type) - push!(result_types, type_id) + push!(result_types, tile_type_for_julia!(ctx, body_arg.type)) end - # Add token type as additional result (for memory ordering) push!(result_types, ctx.token_type) - # Number of user result types (excluding token) - n_user_results = n_carries - - # Save token_map before loop - token_map_before = copy(ctx.token_map) - - # Emit ForOp with callback-based region building body_builder = function(block_args) saved_block_args = copy(ctx.block_args) - # Tile IR block args layout: [iv, carries..., token] - # Julia IR body.args layout: [carries...] - - # Map the induction variable iv_tv = CGVal(block_args[1], iv_type, iv_jl_type) ctx[iv_arg] = iv_tv - # Map carried values (body.args) for i in 1:n_carries body_arg = body_blk.args[i] shape = RowMajorShape(extract_tile_shape(body_arg.type)) - tv = CGVal(block_args[i + 1], result_types[i], body_arg.type, shape) - ctx[body_arg] = tv + ctx[body_arg] = CGVal(block_args[i + 1], result_types[i], body_arg.type, shape) end - - # Set token from last block arg ctx.token = block_args[end] emit_block!(ctx, body_blk) @@ -207,71 +151,43 @@ function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type), end results = encode_ForOp!(body_builder, cb, result_types, iv_type, lower_tv.v, upper_tv.v, step_tv.v, init_values) - ctx.token = ctx.global_token - - for key in keys(token_map_before) - ctx.token_map[key] = ctx.global_token - end - - # Store results at the loop's SSA index (may be empty for void-returning loops) - ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type) + ctx.token = results[end] + ctx.values[ssa_idx] = CGVal(results[1:n_carries], parent_result_type) end function emit_loop_op!(ctx::CGCtx, op::LoopOp, @nospecialize(parent_result_type), n_results::Int, ssa_idx::Int) cb = ctx.cb body_blk = op.body - # Get init values (init_values are loop-carried values) init_values = Value[] for init_val in 
op.init_values tv = emit_value!(ctx, init_val) (tv === nothing || tv.v === nothing) && throw(IRError("Cannot resolve LoopOp init value")) push!(init_values, tv.v) end - # Add token as additional init value (for memory ordering) push!(init_values, ctx.token) - # Number of carries (init_values) - these are the loop results n_carries = length(op.init_values) - # Determine result types from carries (body.args) result_types = TypeId[] for i in 1:n_carries body_arg = body_blk.args[i] - type_id = tile_type_for_julia!(ctx, body_arg.type) - push!(result_types, type_id) + push!(result_types, tile_type_for_julia!(ctx, body_arg.type)) end - # Add token type as additional result (for memory ordering) push!(result_types, ctx.token_type) - # Number of user result types (excluding token) - n_user_results = n_carries - - # Save token_map before loop - token_map_before = copy(ctx.token_map) - - # Emit LoopOp with callback-based region building body_builder = function(block_args) saved_block_args = copy(ctx.block_args) - # Tile IR block args layout: [carries..., token] - # Julia IR body.args layout: [carries...] - - # Map carried values (body.args) for i in 1:n_carries body_arg = body_blk.args[i] shape = RowMajorShape(extract_tile_shape(body_arg.type)) ctx[body_arg] = CGVal(block_args[i], result_types[i], body_arg.type, shape) end - - # Set token from last block arg ctx.token = block_args[end] emit_block!(ctx, body_blk) - # In Tile IR, if the loop body ends with an IfOp (even one with continue/break - # in all branches), the if is NOT a terminator. We need an explicit terminator - # after the if. Add an unreachable ContinueOp as fallback terminator. 
if body_blk.terminator === nothing fallback_operands = copy(block_args) fallback_operands[end] = ctx.token @@ -283,14 +199,8 @@ function emit_loop_op!(ctx::CGCtx, op::LoopOp, @nospecialize(parent_result_type) end results = encode_LoopOp!(body_builder, cb, result_types, init_values) - ctx.token = ctx.global_token - - for key in keys(token_map_before) - ctx.token_map[key] = ctx.global_token - end - - # Store results at the loop's SSA index (may be empty for void-returning loops) - ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type) + ctx.token = results[end] + ctx.values[ssa_idx] = CGVal(results[1:n_carries], parent_result_type) end function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_type), n_results::Int, ssa_idx::Int) @@ -298,75 +208,46 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ before_blk = op.before after_blk = op.after - # Get init values (init_values are loop-carried values) init_values = Value[] for init_val in op.init_values tv = emit_value!(ctx, init_val) (tv === nothing || tv.v === nothing) && throw(IRError("Cannot resolve WhileOp init value: $init_val")) push!(init_values, tv.v) end - # Add token as additional init value (for memory ordering) push!(init_values, ctx.token) - # Number of carries (init_values) - these are the loop results n_carries = length(op.init_values) - # Determine result types from carries (before.args) result_types = TypeId[] for i in 1:n_carries before_arg = before_blk.args[i] - type_id = tile_type_for_julia!(ctx, before_arg.type) - push!(result_types, type_id) + push!(result_types, tile_type_for_julia!(ctx, before_arg.type)) end - # Add token type as additional result (for memory ordering) push!(result_types, ctx.token_type) - # Number of user result types (excluding token) - n_user_results = n_carries - - # Save token_map before loop - token_map_before = copy(ctx.token_map) - - # Emit WhileOp as cuda_tile.loop with conditional break pattern - # 
MLIR structure: before { stmts; condition(cond) args } do { stmts; yield vals } - # Emitted as: loop { before_stmts; if(!cond) { break } else { yield }; after_stmts; continue } - # This structure keeps the "after" statements at LoopOp body level, avoiding - # nested region issues when "after" contains loops. body_builder = function(block_args) saved_block_args = copy(ctx.block_args) - # Tile IR block args layout: [carries..., token] - # Julia IR before.args layout: [carries...] - - # Map carried values (before.args) for i in 1:n_carries before_arg = before_blk.args[i] shape = RowMajorShape(extract_tile_shape(before_arg.type)) ctx[before_arg] = CGVal(block_args[i], result_types[i], before_arg.type, shape) end - - # Set token from last block arg ctx.token = block_args[end] - # Emit "before" region emit_block!(ctx, before_blk) - # Get condition from ConditionOp terminator cond_op = before_blk.terminator cond_op isa ConditionOp || throw(IRError("WhileOp before region must end with ConditionOp")) cond_tv = emit_value!(ctx, cond_op.condition) - (cond_tv === nothing || cond_tv.v === nothing) && throw(IRError("Cannot resolve WhileOp condition: $(cond_op.condition)")) + (cond_tv === nothing || cond_tv.v === nothing) && throw(IRError("Cannot resolve WhileOp condition")) - # Emit conditional break: if (cond) { yield } else { break } - # This keeps nested loops in "after" at LoopOp body level then_body = function(_) - # Just yield (empty) - control continues to after_stmts encode_YieldOp!(ctx.cb, Value[]) end else_body = function(_) - # Break with ConditionOp args (become loop results) break_operands = Value[] for arg in cond_op.args tv = emit_value!(ctx, arg) @@ -381,13 +262,9 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ encode_BreakOp!(ctx.cb, break_operands) end - # Emit IfOp with NO results: if (cond) continue flow, else break - if_result_types = TypeId[] - encode_IfOp!(then_body, else_body, cb, if_result_types, cond_tv.v) + 
encode_IfOp!(then_body, else_body, cb, TypeId[], cond_tv.v) - # Now emit "after" region at LoopOp body level (not inside IfOp!) - # Map "after" region block args - carries from ConditionOp.args - for i in 1:n_carries + for i in 1:length(after_blk.args) after_arg = after_blk.args[i] if i <= length(cond_op.args) tv = emit_value!(ctx, cond_op.args[i]) @@ -400,10 +277,8 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ end end - # Emit "after" region body (skip terminator - we emit ContinueOp instead) emit_block!(ctx, after_blk; skip_terminator=true) - # Emit ContinueOp with yield values from after region's YieldOp continue_operands = Value[] if after_blk.terminator isa YieldOp for val in after_blk.terminator.values @@ -419,91 +294,57 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ end results = encode_LoopOp!(body_builder, cb, result_types, init_values) - ctx.token = ctx.global_token - - for key in keys(token_map_before) - ctx.token_map[key] = ctx.global_token - end - - # Store results at the loop's SSA index (may be empty for void-returning loops) - ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type) + ctx.token = results[end] + ctx.values[ssa_idx] = CGVal(results[1:n_carries], parent_result_type) end -""" - emit_terminator!(ctx, terminator) +#============================================================================= + Terminators + Token is appended manually for control flow threading (conservative approach). +=============================================================================# -Emit bytecode for a block terminator. 
-""" function emit_terminator!(ctx::CGCtx, node::ReturnNode) emit_return!(ctx, node) end function emit_terminator!(ctx::CGCtx, op::YieldOp) - # Collect yield operands operands = Value[] for val in op.values tv = emit_value!(ctx, val) tv !== nothing && tv.v !== nothing && push!(operands, tv.v) end - # Append current token for memory ordering push!(operands, ctx.token) encode_YieldOp!(ctx.cb, operands) end function emit_terminator!(ctx::CGCtx, op::ContinueOp) - # Collect continue operands (updated carried values) operands = Value[] for val in op.values tv = emit_value!(ctx, val) tv !== nothing && tv.v !== nothing && push!(operands, tv.v) end - # Append current token for memory ordering push!(operands, ctx.token) encode_ContinueOp!(ctx.cb, operands) end function emit_terminator!(ctx::CGCtx, op::BreakOp) - # Collect break operands (final values) operands = Value[] for val in op.values tv = emit_value!(ctx, val) tv !== nothing && tv.v !== nothing && push!(operands, tv.v) end - # Append current token for memory ordering push!(operands, ctx.token) encode_BreakOp!(ctx.cb, operands) end -function emit_terminator!(ctx::CGCtx, ::Nothing) - # No terminator, nothing to emit -end - -function emit_terminator!(ctx::CGCtx, ::ConditionOp) - # ConditionOp is handled specially by emit_while_op!, not emitted as a terminator -end +function emit_terminator!(ctx::CGCtx, ::Nothing) end +function emit_terminator!(ctx::CGCtx, ::ConditionOp) end #============================================================================= Early Return Hoisting - - tileiras rejects ReturnNode (cuda_tile.return) inside IfOp (cuda_tile.if) - regions. This pre-pass rewrites the structured IR so that ReturnNode only - appears at the top level, replacing nested returns with YieldOp. =============================================================================# -""" - hoist_returns!(block::Block) - -Rewrite `ReturnNode` terminators inside `IfOp` regions into `YieldOp`, -hoisting the return to the parent block. 
Operates recursively so that -nested early returns (multiple successive `if ... return end` patterns) -are handled automatically. - -Only handles the case where BOTH branches of an IfOp terminate with -ReturnNode (REGION_TERMINATION with 3 children). The 2-child case -(early return inside a loop) is not handled. -""" function hoist_returns!(block::Block) - # First, recurse into all nested control flow ops for (_, entry) in block.body stmt = entry.stmt if stmt isa IfOp @@ -519,29 +360,22 @@ function hoist_returns!(block::Block) end end - # Now check: does this block contain an IfOp where both branches return? - # If so, replace branch ReturnNodes with YieldOp and set block terminator. for (_, entry) in block.body entry.stmt isa IfOp || continue op = entry.stmt::IfOp op.then_region.terminator isa ReturnNode || continue op.else_region.terminator isa ReturnNode || continue - # Both branches return — hoist to parent block. - # Replace branch terminators with YieldOp (void — no values to yield). op.then_region.terminator = YieldOp() op.else_region.terminator = YieldOp() block.terminator = ReturnNode(nothing) end end -""" - emit_getfield!(ctx, args) -> Union{CGVal, Nothing} +#============================================================================= + Loop getfield extraction +=============================================================================# -Handle getfield on multi-value results (loops, ifs). Returns CGVal if handled, -nothing if this is not a multi-value extraction and normal handling should proceed. -This is a compile-time lookup - no Tile IR is emitted. 
-""" function emit_loop_getfield!(ctx::CGCtx, args::Vector{Any}) length(args) >= 2 || return nothing args[1] isa SSAValue || return nothing diff --git a/src/compiler/codegen/irutils.jl b/src/compiler/codegen/irutils.jl new file mode 100644 index 00000000..58c47629 --- /dev/null +++ b/src/compiler/codegen/irutils.jl @@ -0,0 +1,66 @@ +# StructuredIRCode / SSAMap mutation utilities +# +# Helpers for passes that modify the structured IR in place. +# Inspired by Julia's IncrementalCompact (Compiler/src/ssair/ir.jl). + +""" + new_ssa_idx!(sci::StructuredIRCode) -> Int + +Allocate a fresh SSA index from the StructuredIRCode. +""" +function new_ssa_idx!(sci::StructuredIRCode) + sci.max_ssa_idx += 1 + return sci.max_ssa_idx +end + +""" + new_block_arg!(block::Block, sci::StructuredIRCode, @nospecialize(typ)) -> BlockArg + +Add a new BlockArg to a block, allocating a fresh ID. +""" +function new_block_arg!(block::Block, sci::StructuredIRCode, @nospecialize(typ)) + id = new_ssa_idx!(sci) + arg = BlockArg(id, typ) + push!(block.args, arg) + return arg +end + +""" + Base.pushfirst!(m::SSAMap, (idx, stmt, typ)::Tuple{Int,Any,Any}) + +Prepend a statement at the beginning of an SSAMap. +""" +function Base.pushfirst!(m::SSAMap, (idx, stmt, typ)::Tuple{Int,Any,Any}) + pushfirst!(m.ssa_idxes, idx) + pushfirst!(m.stmts, stmt) + pushfirst!(m.types, typ) + return nothing +end + +""" + insert_before!(m::SSAMap, before_idx::Int, new_idx::Int, stmt, typ) + +Insert a new entry before the entry with SSA index `before_idx`. +""" +function insert_before!(m::SSAMap, before_idx::Int, new_idx::Int, stmt, typ) + pos = findfirst(==(before_idx), m.ssa_idxes) + pos === nothing && throw(KeyError(before_idx)) + insert!(m.ssa_idxes, pos, new_idx) + insert!(m.stmts, pos, stmt) + insert!(m.types, pos, typ) + return nothing +end + +""" + insert_after!(m::SSAMap, after_idx::Int, new_idx::Int, stmt, typ) + +Insert a new entry after the entry with SSA index `after_idx`. 
+""" +function insert_after!(m::SSAMap, after_idx::Int, new_idx::Int, stmt, typ) + pos = findfirst(==(after_idx), m.ssa_idxes) + pos === nothing && throw(KeyError(after_idx)) + insert!(m.ssa_idxes, pos + 1, new_idx) + insert!(m.stmts, pos + 1, stmt) + insert!(m.types, pos + 1, typ) + return nothing +end diff --git a/src/compiler/codegen/kernel.jl b/src/compiler/codegen/kernel.jl index c722cbee..81f85483 100644 --- a/src/compiler/codegen/kernel.jl +++ b/src/compiler/codegen/kernel.jl @@ -141,30 +141,14 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8}, create_tensor_views!(ctx, arg_idx, argtype, Int[]) end - # Run alias analysis FIRST + # Run alias analysis and token ordering pass on the structured IR. + # This inserts MakeTokenNode, JoinTokensNode, TokenResultNode into the IR + # and threads tokens through control flow (loop carries, branch yields). alias_result = alias_analysis_pass!(sci) - ctx.alias_result = alias_result + token_order_pass!(sci, alias_result) - # Create memory ordering token - token_type = Token(tt) - ctx.token_type = token_type - root_token = encode_MakeTokenOp!(cb, token_type) - - ctx.global_token = root_token - ctx.token = root_token - - # Initialize token map with root token for all alias sets - # Default: all tokens start at root - ctx.token_map = Dict{TokenKey, Value}() - - unique_alias_sets = Set(values(alias_result)) - for alias_set in unique_alias_sets - ctx.token_map[last_op_key(alias_set)] = root_token - ctx.token_map[last_store_key(alias_set)] = root_token - end - - # ACQUIRE token also starts at root - ctx.token_map[ACQUIRE_TOKEN_KEY] = root_token + # Cache the token bytecode type for codegen + ctx.token_type = Token(tt) # Hoist early returns out of IfOp regions (tileiras rejects ReturnOp inside IfOp) hoist_returns!(ctx.sci.entry) @@ -334,7 +318,7 @@ function emit_subprogram!(ctx::CGCtx, func, arg_types::Vector, # 3. 
Create sub-context sub_ctx = CGCtx(; ctx.cb, ctx.tt, sci, - ctx.token, ctx.token_type, + ctx.token_type, ctx.type_cache, ctx.sm_arch, ctx.cache) diff --git a/src/compiler/codegen/alias_analysis.jl b/src/compiler/codegen/passes/alias_analysis.jl similarity index 100% rename from src/compiler/codegen/alias_analysis.jl rename to src/compiler/codegen/passes/alias_analysis.jl diff --git a/src/compiler/codegen/token_keys.jl b/src/compiler/codegen/passes/token_keys.jl similarity index 100% rename from src/compiler/codegen/token_keys.jl rename to src/compiler/codegen/passes/token_keys.jl diff --git a/src/compiler/codegen/passes/token_order.jl b/src/compiler/codegen/passes/token_order.jl new file mode 100644 index 00000000..2a1ff767 --- /dev/null +++ b/src/compiler/codegen/passes/token_order.jl @@ -0,0 +1,419 @@ +# Token ordering pass +# +# Transforms a StructuredIRCode by inserting token operations (MakeToken, JoinTokens, +# TokenResult) and threading tokens through control flow (loop carries, branch yields). +# After this pass, codegen simply emits what's in the IR — no manual token threading. +# +# Mirrors cuTile Python's `token_order_pass` (res/cutile-python/src/cuda/tile/_passes/token_order.py). + +using Core: SSAValue, Argument, SlotNumber + +#============================================================================= + Memory effect classification +=============================================================================# + +@enum MemoryEffect MEM_NONE MEM_LOAD MEM_STORE + +""" + MemoryEffects + +Per-block summary of which alias sets are read/written and whether any +acquire-ordered operation appears. 
+""" +struct MemoryEffects + effects::Dict{AliasSet, MemoryEffect} + has_acquire::Bool +end + +MemoryEffects() = MemoryEffects(Dict{AliasSet, MemoryEffect}(), false) + +function Base.merge!(a::MemoryEffects, b::MemoryEffects) + for (alias_set, effect) in b.effects + existing = get(a.effects, alias_set, MEM_NONE) + a.effects[alias_set] = max(existing, effect) + end + return MemoryEffects(a.effects, a.has_acquire | b.has_acquire) +end + +function Base.union(a::MemoryEffects, b::MemoryEffects) + result = Dict{AliasSet, MemoryEffect}() + for (k, v) in a.effects + result[k] = v + end + for (k, v) in b.effects + existing = get(result, k, MEM_NONE) + result[k] = max(existing, v) + end + return MemoryEffects(result, a.has_acquire | b.has_acquire) +end + +const EMPTY_MEMORY_EFFECTS = MemoryEffects() + +#============================================================================= + Resolve functions from IR expressions +=============================================================================# + +""" + resolve_call(stmt) -> (func, operands) or nothing + +Extract the resolved function value and operands from a :call or :invoke Expr. +""" +function resolve_call(stmt) + stmt isa Expr || return nothing + if stmt.head === :call + func_ref = stmt.args[1] + operands = @view stmt.args[2:end] + elseif stmt.head === :invoke + func_ref = stmt.args[2] + operands = @view stmt.args[3:end] + else + return nothing + end + resolved = if func_ref isa GlobalRef + try + getfield(func_ref.mod, func_ref.name) + catch + nothing + end + else + func_ref + end + resolved === nothing && return nothing + return (resolved, operands) +end + +""" + classify_memory_op(resolved_func) -> MemoryEffect + +Classify a resolved function as a memory operation. +Returns MEM_NONE, MEM_LOAD, or MEM_STORE; atomic intrinsics are classified as MEM_STORE. 
+""" +function classify_memory_op(resolved_func) + if resolved_func === Intrinsics.load_partition_view || + resolved_func === Intrinsics.load_ptr_tko + return MEM_LOAD + elseif resolved_func === Intrinsics.store_partition_view || + resolved_func === Intrinsics.store_ptr_tko + return MEM_STORE + elseif is_atomic_intrinsic(resolved_func) + return MEM_STORE # Atomics are read-modify-write, treat as store for ordering + else + return MEM_NONE + end +end + +function is_atomic_intrinsic(func) + isdefined(Intrinsics, :atomic_cas) && func === Intrinsics.atomic_cas && return true + for op in (:atomic_xchg, :atomic_add, :atomic_max, :atomic_min, + :atomic_or, :atomic_and, :atomic_xor) + isdefined(Intrinsics, op) && func === getfield(Intrinsics, op) && return true + end + return false +end + +""" + get_alias_set_for_operand(alias_result, operand) -> AliasSet + +Look up the alias set for an operand (the first arg of a memory op). +""" +function get_alias_set_for_operand(alias_result::Dict{Any, AliasSet}, operand) + if operand isa SSAValue || operand isa Argument || operand isa SlotNumber + return get(alias_result, operand, ALIAS_UNIVERSE) + end + return ALIAS_UNIVERSE +end + +#============================================================================= + Compute per-block memory effects +=============================================================================# + +""" + compute_block_memory_effects!(block, alias_result, cache) + +Compute memory effects for a block and all nested blocks, storing results in `cache`. 
+""" +function compute_block_memory_effects!(block::Block, alias_result::Dict{Any, AliasSet}, + cache::Dict{UInt64, MemoryEffects}) + block_id = objectid(block) + haskey(cache, block_id) && return cache[block_id] + + effects = MemoryEffects() + for (ssa_idx, entry) in block.body + if entry.stmt isa ControlFlowOp + nested = compute_cf_memory_effects!(entry.stmt, alias_result, cache) + effects = union(effects, nested) + else + call = resolve_call(entry.stmt) + call === nothing && continue + resolved_func, operands = call + mem_effect = classify_memory_op(resolved_func) + mem_effect == MEM_NONE && continue + alias_set = get_alias_set_for_operand(alias_result, first(operands)) + existing = get(effects.effects, alias_set, MEM_NONE) + effects.effects[alias_set] = max(existing, mem_effect) + end + end + cache[block_id] = effects + return effects +end + +function compute_cf_memory_effects!(op::IfOp, alias_result, cache) + then_eff = compute_block_memory_effects!(op.then_region, alias_result, cache) + else_eff = compute_block_memory_effects!(op.else_region, alias_result, cache) + return union(then_eff, else_eff) +end + +function compute_cf_memory_effects!(op::ForOp, alias_result, cache) + return compute_block_memory_effects!(op.body, alias_result, cache) +end + +function compute_cf_memory_effects!(op::WhileOp, alias_result, cache) + before_eff = compute_block_memory_effects!(op.before, alias_result, cache) + after_eff = compute_block_memory_effects!(op.after, alias_result, cache) + return union(before_eff, after_eff) +end + +function compute_cf_memory_effects!(op::LoopOp, alias_result, cache) + return compute_block_memory_effects!(op.body, alias_result, cache) +end + +compute_cf_memory_effects!(::ControlFlowOp, alias_result, cache) = EMPTY_MEMORY_EFFECTS + +#============================================================================= + Token map (IR-level, using SSAValue/BlockArg instead of bytecode Values) 
+=============================================================================#
+
+# IRToken: how a token is referred to inside the IR — an SSAValue, a BlockArg,
+# or nothing. Left as `Any` so that all three fit.
+const IRToken = Any
+
+"""
+    collect_join_tokens_ir(token_key, token_map, memory_order=nothing) -> Vector{IRToken}
+
+IR-level equivalent of Python's `_collect_join_tokens`.
+
+Return the tokens that an operation keyed by `token_key` has to wait on. The first
+element is always `token_map[token_key]`; further tokens from `token_map` are appended
+(each at most once, compared by identity) when
+- their key is an `AcquireTokenKey`, or
+- both keys are `AliasTokenKey`s and either `memory_order` has release semantics and
+  the other key's role is `LAST_OP`, or the roles match and the alias sets intersect.
+
+NOTE(review): a key whose alias set is `AliasUniverse` never counts as overlapping
+here, so such a chain is only joined with the ACQUIRE token — confirm that unknown
+pointers are ordered against the other alias sets somewhere else.
+"""
+function collect_join_tokens_ir(token_key::TokenKey, token_map::Dict{TokenKey, IRToken},
+                                memory_order=nothing)
+    joined = IRToken[token_map[token_key]]
+
+    for (candidate_key, candidate_tok) in token_map
+        wanted = if candidate_key isa AcquireTokenKey
+            true
+        elseif candidate_key isa AliasTokenKey && token_key isa AliasTokenKey
+            # Release: join with all LAST_OP tokens
+            via_release = memory_order !== nothing && has_release_order(memory_order) &&
+                          candidate_key.role == LAST_OP
+            # Alias set overlap: same role and sets overlap
+            via_overlap = candidate_key.role == token_key.role &&
+                          !(candidate_key.alias_set isa AliasUniverse) &&
+                          !(token_key.alias_set isa AliasUniverse) &&
+                          !isempty(intersect(candidate_key.alias_set, token_key.alias_set))
+            via_release || via_overlap
+        else
+            false
+        end
+
+        wanted || continue
+        # Identity check keeps the same token from being listed twice.
+        any(t -> t === candidate_tok, joined) && continue
+        push!(joined, candidate_tok)
+    end
+
+    return joined
+end
+
+"""
+    get_input_token_ir!(sci, block, before_ssa, token_key, token_map, memory_order=nothing)
+        -> IRToken
+
+Return the token a memory operation should consume.
+
+If `token_key` is not in `token_map`, the ACQUIRE token is returned. Otherwise the
+tokens from `collect_join_tokens_ir` are used: a single token is returned as is, and
+several tokens are merged by inserting a `JoinTokensNode` into `block` before
+`before_ssa`, whose `SSAValue` is returned.
+"""
+function get_input_token_ir!(sci::StructuredIRCode, block::Block, before_ssa::Int,
+                             token_key::TokenKey, token_map::Dict{TokenKey, IRToken},
+                             memory_order=nothing)
+    # Fallback to ACQUIRE token
+    haskey(token_map, token_key) || return token_map[ACQUIRE_TOKEN_KEY]
+
+    pending = collect_join_tokens_ir(token_key, token_map, memory_order)
+    length(pending) == 1 && return pending[1]
+
+    # Insert JoinTokensNode before the memory op
+    join_idx = new_ssa_idx!(sci)
+    insert_before!(block.body, before_ssa, join_idx, JoinTokensNode(pending), TOKEN_TYPE)
+    return SSAValue(join_idx)
+end
+
+# Placeholder: no memory order is treated as having release semantics yet.
+has_release_order(memory_order) = false
+
+
+#=============================================================================
+ The main pass
+=============================================================================#
+
+"""
+    token_order_pass!(sci::StructuredIRCode, alias_result::Dict{Any, AliasSet})
+
+Rewrite `sci` in place so that memory ordering tokens are explicit in the IR:
+- a `MakeTokenNode` is inserted at the start of the entry block,
+- `JoinTokensNode`s are inserted where several tokens have to be merged,
+- a `TokenResultNode` is inserted after each transformed memory op to capture its
+  result token,
+- the input token is appended as an extra argument to each transformed memory op call.
+
+Only statements of the entry block that precede its first control-flow op are
+transformed. Nested regions are left untouched, and no per-alias-set tokens are
+carried through loops or branches yet.
+"""
+function token_order_pass!(sci::StructuredIRCode, alias_result::Dict{Any, AliasSet})
+    # Compute per-block memory effects. The return value is not needed here; the call
+    # fills `effects_cache`, which is handed on to transform_block!.
+    effects_cache = Dict{UInt64, MemoryEffects}()
+    compute_block_memory_effects!(sci.entry, alias_result, effects_cache)
+
+    # Create root token (MakeTokenNode) at entry
+    root_ssa = new_ssa_idx!(sci)
+    pushfirst!(sci.entry.body, (root_ssa, MakeTokenNode(), TOKEN_TYPE))
+    root_token = SSAValue(root_ssa)
+
+    # Initialize token map: all alias sets start at root token.
+    # Only alias sets that occur as values of `alias_result` are seeded; lookups for
+    # any other alias set have to cope with a missing key.
+    token_map = Dict{TokenKey, IRToken}()
+    unique_alias_sets = Set(values(alias_result))
+    for alias_set in unique_alias_sets
+        token_map[last_op_key(alias_set)] = root_token
+        token_map[last_store_key(alias_set)] = root_token
+    end
+    token_map[ACQUIRE_TOKEN_KEY] = root_token
+
+    # Transform the entry block
+    transform_block!(sci, sci.entry, alias_result, token_map, effects_cache, nothing, nothing)
+
+    return nothing
+end
+
+#=============================================================================
+ Block transformation (recursive)
+=============================================================================#
+
+"""
+    transform_block!(sci, block, alias_result, token_map, effects_cache,
+                     innermost_loop_effects, ifelse_effects)
+
+Walk the statements of `block` and thread tokens through its memory operations.
+`token_map` is updated in place and reflects the token state after the last
+transformed statement.
+
+Statements are transformed only up to the first `ControlFlowOp` in `block`; that op,
+its nested regions and every statement after it are left untouched.
+`effects_cache`, `innermost_loop_effects` and `ifelse_effects` are accepted but not
+used yet.
+"""
+function transform_block!(sci::StructuredIRCode, block::Block,
+                          alias_result::Dict{Any, AliasSet},
+                          token_map::Dict{TokenKey, IRToken},
+                          effects_cache::Dict{UInt64, MemoryEffects},
+                          innermost_loop_effects::Union{MemoryEffects, Nothing},
+                          ifelse_effects::Union{MemoryEffects, Nothing})
+
+    # Collect SSA indices first to avoid iterator invalidation from insertions.
+    ssa_indices = collect(Int, block.body.ssa_idxes)
+
+    # Track whether we've seen a control flow op. Once we hit one,
+    # we stop transforming memory ops because the token state after the CF op
+    # is managed by codegen (ctx.token), not by the pass's token_map.
+    seen_control_flow = false
+
+    for ssa_idx in ssa_indices
+        entry = get(block.body, ssa_idx, nothing)
+        entry === nothing && continue
+
+        if entry.stmt isa ControlFlowOp
+            seen_control_flow = true
+            # Don't recurse into nested regions (conservative approach)
+        elseif !seen_control_flow
+            transform_statement!(sci, block, ssa_idx, entry.stmt,
+                                 alias_result, token_map)
+        end
+    end
+end
+
+"""
+    transform_statement!(sci, block, ssa_idx, stmt, alias_result, token_map)
+
+Transform a single statement. Statements that do not resolve to a call, or whose
+callee is classified as `MEM_NONE`, are left alone.
+
+For a load or store, the input token is appended to `stmt.args`, a `TokenResultNode`
+is inserted after the statement, and `token_map` is updated for the alias set of the
+first operand:
+- load: consumes the `LAST_STORE` token; `LAST_OP` becomes the join of the previous
+  `LAST_OP` token and the load's result token.
+- store: consumes the `LAST_OP` token; `LAST_OP` and `LAST_STORE` both become the
+  store's result token.
+
+An alias set without an entry in `token_map` is treated as starting from the ACQUIRE
+token, matching the fallback in `get_input_token_ir!`.
+"""
+function transform_statement!(sci::StructuredIRCode, block::Block, ssa_idx::Int, stmt,
+                              alias_result::Dict{Any, AliasSet},
+                              token_map::Dict{TokenKey, IRToken})
+    call = resolve_call(stmt)
+    call === nothing && return
+    resolved_func, operands = call
+    mem_effect = classify_memory_op(resolved_func)
+    mem_effect == MEM_NONE && return
+
+    alias_set = get_alias_set_for_operand(alias_result, first(operands))
+
+    if mem_effect == MEM_LOAD
+        # Load depends on LAST_STORE (read-after-write)
+        input_token = get_input_token_ir!(sci, block, ssa_idx,
+                                          last_store_key(alias_set), token_map)
+
+        # Add token arg to the call
+        push!(stmt.args, input_token)
+
+        # Insert TokenResultNode after the load
+        result_ssa = new_ssa_idx!(sci)
+        insert_after!(block.body, ssa_idx, result_ssa, TokenResultNode(ssa_idx), TOKEN_TYPE)
+
+        result_token = SSAValue(result_ssa)
+
+        # Update LAST_OP: eagerly join with existing last_op token.
+        # The key may be missing for an alias set that was never seeded; fall back to
+        # the ACQUIRE token instead of throwing a KeyError, as get_input_token_ir! does.
+        lop_key = last_op_key(alias_set)
+        last_op_tok = get(() -> token_map[ACQUIRE_TOKEN_KEY], token_map, lop_key)
+        join_ssa = new_ssa_idx!(sci)
+        insert_after!(block.body, result_ssa, join_ssa,
+                      JoinTokensNode([last_op_tok, result_token]), TOKEN_TYPE)
+        token_map[lop_key] = SSAValue(join_ssa)
+
+    elseif mem_effect == MEM_STORE
+        # Store depends on LAST_OP (write-after-read, write-after-write)
+        input_token = get_input_token_ir!(sci, block, ssa_idx,
+                                          last_op_key(alias_set), token_map)
+
+        # Add token arg to the call
+        push!(stmt.args, input_token)
+
+        # Insert TokenResultNode after the store
+        result_ssa = new_ssa_idx!(sci)
+        insert_after!(block.body, ssa_idx, result_ssa, TokenResultNode(ssa_idx), TOKEN_TYPE)
+
+        result_token = SSAValue(result_ssa)
+
+        # Update both LAST_OP and LAST_STORE
+        token_map[last_op_key(alias_set)] = result_token
+        token_map[last_store_key(alias_set)] = result_token
+    end
+end
+
+
+#=============================================================================
+ Control flow transformation (conservative)
+
+ For this initial port, control flow ops are handled conservatively:
+ - Memory ops inside nested blocks get the root token (from the enclosing scope)
+ - No per-alias token carries through loops or branches
+ - Token state is unchanged after control flow ops
+
+ This matches the original inline approach's conservative behavior.
+ TODO: Add per-alias token carries (matching Python's token_order_pass).
+=============================================================================#
+
+# For the conservative approach, control flow regions are NOT transformed by the pass.
+# Memory ops inside loops/branches use ctx.token (the loop-carried or pre-branch token)
+# which is managed manually by the codegen's control flow emitters.
+# The pass only transforms straight-line code in the entry block.
+# Placeholder: transform_block! does not call this function.
+function transform_control_flow!(sci::StructuredIRCode, parent_block::Block,
+                                 ssa_idx::Int, op::ControlFlowOp, @nospecialize(result_type),
+                                 alias_result, token_map, effects_cache)
+    # Do nothing — codegen handles control flow token threading conservatively.
+    # TODO: Transform nested regions once per-alias loop carries are implemented.
+end + diff --git a/src/compiler/codegen/statements.jl b/src/compiler/codegen/statements.jl index df15534c..fd324a3b 100644 --- a/src/compiler/codegen/statements.jl +++ b/src/compiler/codegen/statements.jl @@ -7,8 +7,15 @@ Emit bytecode for a single SSA statement. The ssa_idx is the original Julia SSA index to store the result at. """ function emit_statement!(ctx::CGCtx, @nospecialize(stmt), ssa_idx::Int, @nospecialize(result_type)) + ctx.current_ssa_idx = ssa_idx tv = nothing - if stmt isa ReturnNode + if stmt isa MakeTokenNode + tv = emit_make_token!(ctx) + elseif stmt isa JoinTokensNode + tv = emit_join_tokens!(ctx, stmt) + elseif stmt isa TokenResultNode + tv = emit_token_result!(ctx, stmt) + elseif stmt isa ReturnNode emit_return!(ctx, stmt) elseif stmt isa Expr tv = emit_expr!(ctx, stmt, result_type) @@ -65,3 +72,45 @@ function emit_return!(ctx::CGCtx, node::ReturnNode) end end end + +#============================================================================= + Token IR node emission +=============================================================================# + +function emit_make_token!(ctx::CGCtx) + token_type = ctx.token_type + if token_type === nothing + token_type = Token(ctx.tt) + ctx.token_type = token_type + end + v = encode_MakeTokenOp!(ctx.cb, token_type) + ctx.token = v # Set as current token for control flow threading + return CGVal(v, token_type, TokenType) +end + +function emit_join_tokens!(ctx::CGCtx, node::JoinTokensNode) + tokens = Value[] + for tok_ref in node.tokens + tv = emit_value!(ctx, tok_ref) + tv === nothing && throw(IRError("JoinTokensNode: cannot resolve token operand $tok_ref")) + push!(tokens, tv.v) + end + # Deduplicate by identity (avoid join_tokens(%x, %x)) + unique_tokens = Value[] + for t in tokens + any(u -> u === t, unique_tokens) || push!(unique_tokens, t) + end + if length(unique_tokens) == 1 + return CGVal(unique_tokens[1], ctx.token_type, TokenType) + end + v = encode_JoinTokensOp!(ctx.cb, ctx.token_type, 
unique_tokens) + return CGVal(v, ctx.token_type, TokenType) +end + +function emit_token_result!(ctx::CGCtx, node::TokenResultNode) + # The memory op at node.mem_op_ssa should have stored its result token + v = get(ctx.result_tokens, node.mem_op_ssa, nothing) + v === nothing && throw(IRError("TokenResultNode: no result token for memory op at SSA %$(node.mem_op_ssa)")) + return CGVal(v, ctx.token_type, TokenType) +end + diff --git a/src/compiler/codegen/token_order.jl b/src/compiler/codegen/token_order.jl deleted file mode 100644 index 6cee31c0..00000000 --- a/src/compiler/codegen/token_order.jl +++ /dev/null @@ -1,130 +0,0 @@ -""" - get_input_var(args) -> Any - -Extract the pointer/array variable from memory operation arguments. -""" -function get_input_var(args) - return args[1] -end - -""" - get_alias_set(ctx::CGCtx, var) -> AliasSet - -Get the alias set for a variable from analysis results. -""" -function get_alias_set(ctx::CGCtx, var) - # Trace to source - source = trace_to_source(ctx, var) - - # Lookup in alias results - return get(ctx.alias_result, source, ALIAS_UNIVERSE) -end - -""" - collect_join_tokens(ctx::CGCtx, token_key::TokenKey, memory_order=nothing) -> Vector{Value} - -Collect all tokens that need to be joined for synchronization. -Based on Python's `_collect_join_tokens`. 
-""" -function collect_join_tokens(ctx::CGCtx, token_key::TokenKey, memory_order = nothing) - tokens_to_join = [ctx.token_map[token_key]] - - for (other_key, other_token) in ctx.token_map - should_join = false - - # Join with ACQUIRE token - if other_key isa AcquireTokenKey - should_join = true - - # Join if alias sets overlap - elseif other_key isa AliasTokenKey && token_key isa AliasTokenKey - # Release memory order: join with all LAST_OP tokens - if memory_order !== nothing && has_release_order(memory_order) - should_join = other_key.role == LAST_OP - end - - # Alias set overlap: join if same role and sets overlap - if other_key.role == token_key.role - alias_overlap = !(other_key.alias_set isa AliasUniverse) && - !(token_key.alias_set isa AliasUniverse) && - !isempty(intersect(other_key.alias_set, token_key.alias_set)) - should_join = should_join || alias_overlap - end - end - - # Skip tokens already present by identity avoids join_tokens(%x,%x,%x) - already_present = any(t -> t === other_token, tokens_to_join) - if should_join && !already_present - push!(tokens_to_join, other_token) - end - end - - return tokens_to_join -end - -""" - get_input_token!(ctx::CGCtx, token_key::TokenKey, memory_order=nothing) - -> (Value, Union{Nothing, JoinOp}) - -Get the input token for an operation, potentially creating a join operation. -""" -function get_input_token!(ctx::CGCtx, token_key::TokenKey, memory_order = nothing) - - if !haskey(ctx.token_map, token_key) - @warn "Token key not found in token_map!" 
token_key available_keys=keys(ctx.token_map) - # Fallback to root token - return (ctx.token_map[ACQUIRE_TOKEN_KEY], nothing) - end - tokens_to_join = collect_join_tokens(ctx, token_key, memory_order) - - if length(tokens_to_join) == 1 - return (tokens_to_join[1], nothing) - end - - # Join multiple tokens - result_token = encode_JoinTokensOp!(ctx.cb, ctx.token_type, tokens_to_join) - - return (result_token, nothing) # Return nothing for join_op since its already been emitted -end - -""" - trace_to_source(ctx::CGCtx, var) -> Any - -Trace a value back to its original source (Argument, SSAValue). -""" -function trace_to_source(ctx::CGCtx, var) - # Returns if its an Argument or SSAValue - if var isa Argument || var isa SSAValue - return var - end - - # Resolve for SlothNumber - if var isa SlotNumber - tv = get(ctx.slots, var.id, nothing) - if tv !== nothing && is_arg_ref(tv) - arg_idx, _ = tv.arg_ref - return Argument(arg_idx) - end - end - - # Generic emit_value resolution - tv = emit_value!(ctx, var) - if tv !== nothing && is_arg_ref(tv) - arg_idx, _ = tv.arg_ref - return Argument(arg_idx) - end - - # Return as is for unknown - return var -end - -""" - has_release_order(memory_order) -> Bool - -Check if memory order has release semantics. -For now, returns false (no memory order support yet). -""" -function has_release_order(memory_order) - # TODO: Implement proper memory order checking when needed - return false -end diff --git a/src/compiler/codegen/utils.jl b/src/compiler/codegen/utils.jl index ab3a90b9..c3e4d0e4 100644 --- a/src/compiler/codegen/utils.jl +++ b/src/compiler/codegen/utils.jl @@ -34,6 +34,59 @@ or the universal alias set (ALIAS_UNIVERSE). const AliasSet = Union{Set{Any}, AliasUniverse} +#============================================================================= + Token IR Node Types (inserted by token_order_pass!) 
+=============================================================================# + +""" + TokenType + +Sentinel type used in StructuredIRCode to mark SSA values and BlockArgs +that represent memory ordering tokens. Not a runtime type. +""" +struct TokenType end +const TOKEN_TYPE = TokenType() + +""" + MakeTokenNode + +IR statement node: creates the root memory ordering token at kernel entry. +Inserted by `token_order_pass!`. Emitted as `encode_MakeTokenOp!` during codegen. +""" +struct MakeTokenNode end + +""" + JoinTokensNode + +IR statement node: merges multiple token values into one. +Inserted by `token_order_pass!`. Emitted as `encode_JoinTokensOp!` during codegen. +""" +struct JoinTokensNode + tokens::Vector{Any} # SSAValue or BlockArg references to token values +end + +""" + TokenResultNode + +IR statement node: represents the result token produced by a memory operation. +The memory op at `mem_op_ssa` produces both a data value and a token; this node +extracts the token. Codegen resolves this via `ctx.result_tokens[mem_op_ssa]`. +""" +struct TokenResultNode + mem_op_ssa::Int # SSA index of the memory operation that produced this token +end + +# Note: IfTokenResultNode was considered but removed in favor of conservative +# IfOp handling (no token carries through branches). Can be added later for +# more precise token tracking. + +""" + is_token_type(typ) -> Bool + +Check whether a type annotation in the structured IR represents a token. 
+""" +is_token_type(@nospecialize(typ)) = typ isa TokenType + #============================================================================= IRError: Exception type for IR compilation errors =============================================================================# @@ -197,12 +250,21 @@ mutable struct CGCtx tt::TypeTable sci::StructuredIRCode - # Memory ordering token (kept for backward compatibility) - tokens::Dict{UInt64, Value} - global_token::Union{Value, Nothing} - token::Union{Value, Nothing} + # Token bytecode type (cached for encoding token operations) token_type::Union{TypeId, Nothing} + # Current token for control flow threading (loops, branches). + # Set by MakeTokenNode emission, updated by control flow emitters. + token::Union{Value, Nothing} + + # Result tokens from memory ops: mem_op SSA index → bytecode Value + # Populated during codegen when emitting memory ops with token args. + # Read by TokenResultNode emission. + result_tokens::Dict{Int, Value} + + # Current SSA index being emitted (set by emit_statement!) 
+ current_ssa_idx::Int + # Type cache: Julia type -> TypeId type_cache::Dict{Type, TypeId} @@ -211,14 +273,9 @@ mutable struct CGCtx # Compilation cache (needed for combiner compilation) cache::CacheView - - # Alias-aware token system - alias_result::Dict{Any, AliasSet} # From alias analysis - token_map::Dict{Any, Value} # TokenKey -> current token Value end function CGCtx(; cb::CodeBuilder, tt::TypeTable, sci::StructuredIRCode, - token::Union{Value, Nothing} = nothing, token_type::Union{TypeId, Nothing} = nothing, type_cache::Dict{Type, TypeId} = Dict{Type, TypeId}(), sm_arch::Union{VersionNumber, Nothing} = nothing, @@ -234,15 +291,13 @@ function CGCtx(; cb::CodeBuilder, tt::TypeTable, sci::StructuredIRCode, cb, tt, sci, - Dict{UInt64, Value}(), # tokens (old system) - nothing, # global_token (old system) - token, # token (old system) token_type, + nothing, # token + Dict{Int, Value}(), # result_tokens + 0, # current_ssa_idx type_cache, sm_arch, cache, - Dict{Any, AliasSet}(), # alias_result - Dict{Any, Value}(), # token_map ) end diff --git a/src/compiler/intrinsics/atomics.jl b/src/compiler/intrinsics/atomics.jl index 9ea4484b..41c03403 100644 --- a/src/compiler/intrinsics/atomics.jl +++ b/src/compiler/intrinsics/atomics.jl @@ -28,6 +28,9 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_cas), args) cb = ctx.cb tt = ctx.tt + # Extract input token from last arg (added by token_order_pass!) 
+ input_token = extract_token_arg!(ctx, args) + # args: (ptr_tile, expected, desired, mask, memory_order, memory_scope) ptr_tv = emit_value!(ctx, args[1]) ptr_tv === nothing && throw(IRError("atomic CAS requires ptr_tile")) @@ -60,16 +63,18 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_cas), args) encode_AtomicCASPtrOp!(cb, result_tile_type, token_type, ptr_tv.v, expected_tv.v, desired_tv.v; mask=mask_tv.v, - token=ctx.token, + token=input_token, memory_ordering=mem_ordering, memory_scope=mem_scope) else encode_AtomicCASPtrOp!(cb, result_tile_type, token_type, ptr_tv.v, expected_tv.v, desired_tv.v; - token=ctx.token, + token=input_token, memory_ordering=mem_ordering, memory_scope=mem_scope) end + # Store result token for TokenResultNode and update ctx.token for control flow + ctx.result_tokens[ctx.current_ssa_idx] = new_token ctx.token = new_token julia_shape = ColMajorShape(shape) @@ -81,6 +86,9 @@ function emit_atomic_rmw!(ctx::CGCtx, args::AbstractVector, mode::AtomicRMWMode. cb = ctx.cb tt = ctx.tt + # Extract input token from last arg (added by token_order_pass!) + input_token = extract_token_arg!(ctx, args) + # args: (ptr_tile, val, mask, memory_order, memory_scope) ptr_tv = emit_value!(ctx, args[1]) ptr_tv === nothing && throw(IRError("atomic RMW requires ptr_tile")) @@ -118,16 +126,18 @@ function emit_atomic_rmw!(ctx::CGCtx, args::AbstractVector, mode::AtomicRMWMode. 
encode_AtomicRMWPtrOp!(cb, result_tile_type, token_type, ptr_tv.v, val_tv.v, actual_mode; mask=mask_tv.v, - token=ctx.token, + token=input_token, memory_ordering=mem_ordering, memory_scope=mem_scope) else encode_AtomicRMWPtrOp!(cb, result_tile_type, token_type, ptr_tv.v, val_tv.v, actual_mode; - token=ctx.token, + token=input_token, memory_ordering=mem_ordering, memory_scope=mem_scope) end + # Store result token for TokenResultNode and update ctx.token for control flow + ctx.result_tokens[ctx.current_ssa_idx] = new_token ctx.token = new_token julia_shape = ColMajorShape(shape) diff --git a/src/compiler/intrinsics/memory.jl b/src/compiler/intrinsics/memory.jl index f575ae58..81f88ec5 100644 --- a/src/compiler/intrinsics/memory.jl +++ b/src/compiler/intrinsics/memory.jl @@ -1,7 +1,5 @@ # Memory -# TODO: cuda_tile.join_tokens - # cuda_tile.load_ptr_tko @intrinsic load_ptr_tko(ptrs, latency=nothing, mask=nothing, padding=nothing) function tfunc(𝕃, ::typeof(Intrinsics.load_ptr_tko), @nospecialize(ptrs), @nospecialize args...) @@ -17,83 +15,52 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.load_ptr_tko), args) cb = ctx.cb tt = ctx.tt + # Extract input token from last arg (added by token_order_pass!) + input_token = extract_token_arg!(ctx, args) + # args: (ptrs, latency, mask?, padding?) 
- # Get pointer tile (arg 1) ptrs_tv = emit_value!(ctx, args[1]) ptrs_tv === nothing && throw(IRError("load_ptr_tko: cannot resolve pointer tile")) pointers = ptrs_tv.v tile_shape = ptrs_tv.shape - # Get element type from pointer tile type (Tile{Ptr{T}, S}) ptrs_type = CC.widenconst(ptrs_tv.jltype) - ptr_type = eltype(ptrs_type) # Ptr{T} from Tile{Ptr{T}, S} - elem_type = eltype(ptr_type) # T from Ptr{T} + ptr_type = eltype(ptrs_type) + elem_type = eltype(ptr_type) dtype = julia_to_tile_dtype!(tt, elem_type) result_tile_type = tile_type!(tt, dtype, tile_shape) token_type = Token(tt) - # Extract latency hint (args[2]) latency = @something get_constant(ctx, args[2]) throw(IRError("latency must be a compile-time constant")) - optimization_hints = create_optimization_hints(ctx, latency) - mask_tv, has_mask = emit_optional_mask(ctx, args, 3) - # Get alias set use global token if unknown - alias_set = get_alias_set(ctx, args[1]) - input_token = if alias_set isa AliasUniverse - ctx.token - else - last_store_key_val = last_store_key(alias_set) - first(get_input_token!(ctx, last_store_key_val, nothing)) - end - if has_mask mask = mask_tv.v - - # Get padding tile (arg 4) padding_tv = emit_value!(ctx, args[4]) padding_tv === nothing && throw(IRError("load_ptr_tko: cannot resolve padding tile")) padding = padding_tv.v - # Load with mask and padding tile_val, new_token = encode_LoadPtrTkoOp!( cb, result_tile_type, token_type, pointers; - mask = mask, - padding_value = padding, - token = input_token, - optimization_hints + mask, padding_value = padding, token = input_token, optimization_hints ) else - # Load without mask tile_val, new_token = encode_LoadPtrTkoOp!( cb, result_tile_type, token_type, pointers; - token = input_token, - optimization_hints + token = input_token, optimization_hints ) end - # Only track alias if we have a real alias set - if alias_set isa AliasUniverse - ctx.token = new_token - else - last_op_key_val = last_op_key(alias_set) - last_op_token = 
get(ctx.token_map, last_op_key_val, nothing) - if last_op_token === nothing || last_op_token === input_token || last_op_token === new_token - new_last_op_token = new_token - else - new_last_op_token = encode_JoinTokensOp!(ctx.cb, token_type, [last_op_token, new_token]) - end - ctx.token_map[last_op_key_val] = new_last_op_token - end + # Store result token for TokenResultNode and update ctx.token for control flow + ctx.result_tokens[ctx.current_ssa_idx] = new_token + ctx.token = new_token julia_shape = ColMajorShape(tile_shape) result_jltype = Tile{elem_type, TupleType(julia_shape)} return CGVal(tile_val, result_tile_type, result_jltype, tile_shape) end -# TODO: cuda_tile.make_token - # cuda_tile.store_ptr_tko @intrinsic store_ptr_tko(ptrs::Tile{Ptr{T}, S}, values::Tile{T, S}, latency::Union{Int, Nothing}, @@ -105,13 +72,13 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.store_ptr_tko), args) cb = ctx.cb tt = ctx.tt - # args: (ptrs, values, latency, mask?) - # Get pointer tile (arg 1) + # Extract input token from last arg (added by token_order_pass!) 
+ input_token = extract_token_arg!(ctx, args) + ptrs_tv = emit_value!(ctx, args[1]) ptrs_tv === nothing && throw(IRError("store_ptr_tko: cannot resolve pointer tile")) pointers = ptrs_tv.v - # Get value tile (arg 2) values_tv = emit_value!(ctx, args[2]) values_tv === nothing && throw(IRError("store_ptr_tko: cannot resolve values tile")) values = values_tv.v @@ -119,62 +86,47 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.store_ptr_tko), args) token_type = Token(tt) latency = @something get_constant(ctx, args[3]) throw(IRError("latency must be a compile-time constant")) - optimization_hints = create_optimization_hints(ctx, latency) - mask_tv, has_mask = emit_optional_mask(ctx, args, 4) - alias_set = get_alias_set(ctx, args[1]) - - if alias_set isa AliasUniverse - # Baseline behavior: use global token directly, no alias tracking overhead - if has_mask - mask = mask_tv.v - - # Store with mask - new_token = encode_StorePtrTkoOp!( - cb, token_type, pointers, values; - mask = mask, - token = ctx.token, - optimization_hints - ) - else - new_token = encode_StorePtrTkoOp!( - cb, token_type, pointers, values; - token = ctx.token, - optimization_hints - ) - end - ctx.token = new_token + if has_mask + mask = mask_tv.v + new_token = encode_StorePtrTkoOp!( + cb, token_type, pointers, values; + mask, token = input_token, optimization_hints + ) else - last_op_key_val = last_op_key(alias_set) - last_store_key_val = last_store_key(alias_set) - - # Store depends on LAST_OP (write after read/write) - input_token, _ = get_input_token!(ctx, last_op_key_val, nothing) - - if has_mask - mask = mask_tv.v - - new_token = encode_StorePtrTkoOp!( - cb, token_type, pointers, values; - mask = mask, - token = input_token, - optimization_hints - ) - else - new_token = encode_StorePtrTkoOp!( - cb, token_type, pointers, values; - token = input_token, - optimization_hints - ) - end - - # Update both LAST_OP and LAST_STORE. 
- # Do NOT update ctx.token — alias-aware path uses token_map only. - ctx.token_map[last_op_key_val] = new_token - ctx.token_map[last_store_key_val] = new_token + new_token = encode_StorePtrTkoOp!( + cb, token_type, pointers, values; + token = input_token, optimization_hints + ) end + # Store result token for TokenResultNode and update ctx.token for control flow + ctx.result_tokens[ctx.current_ssa_idx] = new_token + ctx.token = new_token + nothing end + +""" + extract_token_arg!(ctx, args) -> Value + +Check if the last argument is a token SSAValue (added by token_order_pass!). +If so, pop it from args and return the resolved bytecode Value. +If no token arg is present (memory op inside a control flow region not yet +transformed by the pass), fall back to `ctx.token`. +""" +function extract_token_arg!(ctx::CGCtx, args) + if !isempty(args) + last_arg = args[end] + tv = emit_value!(ctx, last_arg) + if tv !== nothing && tv.jltype === TokenType + pop!(args) + return tv.v + end + end + # Fallback: use ctx.token (for memory ops inside loops/branches) + ctx.token !== nothing && return ctx.token + throw(IRError("Memory op has no token argument and ctx.token is not set")) +end diff --git a/src/compiler/intrinsics/views.jl b/src/compiler/intrinsics/views.jl index 17756c64..77e87a9f 100644 --- a/src/compiler/intrinsics/views.jl +++ b/src/compiler/intrinsics/views.jl @@ -56,6 +56,9 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.load_partition_view), a cb = ctx.cb tt = ctx.tt + # Extract input token from last arg (added by token_order_pass!) 
+ input_token = extract_token_arg!(ctx, args) + # args: (partition_view, latency, allow_tma, indices) pv_arg = emit_value!(ctx, args[1]) pv_arg === nothing && throw(IRError("load_partition_view() requires a PartitionView argument")) @@ -108,33 +111,14 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.load_partition_view), a # Create optimization hints if provided optimization_hints = create_optimization_hints(ctx, latency, allow_tma_val) - # Get alias set fall back to simple token threading if unknown - alias_set = get_alias_set(ctx, args[1]) + tile_val, result_token = encode_LoadViewTkoOp!( + cb, tile_type, token_type, pv_arg.v, index_vals; + token = input_token, optimization_hints + ) - if alias_set isa AliasUniverse - # Baseline behavior: use global token directly, no alias tracking overhead - tile_val, result_token = encode_LoadViewTkoOp!( - cb, tile_type, token_type, pv_arg.v, index_vals; - token = ctx.token, optimization_hints - ) - ctx.token = result_token - else - last_store_key_val = last_store_key(alias_set) - input_token, _ = get_input_token!(ctx, last_store_key_val, nothing) - tile_val, result_token = encode_LoadViewTkoOp!( - cb, tile_type, token_type, pv_arg.v, index_vals; - token = input_token, optimization_hints - ) - last_op_key_val = last_op_key(alias_set) - last_op_token = get(ctx.token_map, last_op_key_val, result_token) - # Only join if last_op_token is not already in the causal chain of result_token. - # result_token was produced from input_token, so if last_op_token === input_token - # the join is redundant — result_token already implies last_op_token. - new_last_op_token = last_op_token === input_token ? result_token : - encode_JoinTokensOp!(ctx.cb, token_type, [last_op_token, result_token]) - ctx.token_map[last_op_key_val] = new_last_op_token - # Do NOT update ctx.token — alias-aware path uses token_map only. 
- end + # Store result token for TokenResultNode and update ctx.token for control flow + ctx.result_tokens[ctx.current_ssa_idx] = result_token + ctx.token = result_token julia_shape = ColMajorShape(tile_shape) return CGVal(tile_val, tile_type, Tile{elem_type, TupleType(julia_shape)}, tile_shape) @@ -374,6 +358,9 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.store_partition_view), cb = ctx.cb tt = ctx.tt + # Extract input token from last arg (added by token_order_pass!) + input_token = extract_token_arg!(ctx, args) + # args: (partition_view, tile, latency, allow_tma, indices) pv_arg = emit_value!(ctx, args[1]) pv_arg === nothing && throw(IRError("store_partition_view() requires a PartitionView argument")) @@ -437,28 +424,16 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.store_partition_view), # Create optimization hints if provided optimization_hints = create_optimization_hints(ctx, latency, allow_tma_val) - # Get alias set — fall back to simple token threading if unknown - alias_set = get_alias_set(ctx, args[1]) token_type = Token(tt) - if alias_set isa AliasUniverse - result_token = encode_StoreViewTkoOp!( - cb, token_type, tile_val, pv_arg.v, index_vals; - token = ctx.token, optimization_hints - ) - ctx.token = result_token - else - last_op_key_val = last_op_key(alias_set) - last_store_key_val = last_store_key(alias_set) - input_token, _ = get_input_token!(ctx, last_op_key_val, nothing) - result_token = encode_StoreViewTkoOp!( - cb, token_type, tile_val, pv_arg.v, index_vals; - token = input_token, optimization_hints - ) - ctx.token_map[last_op_key_val] = result_token - ctx.token_map[last_store_key_val] = result_token - # Do NOT update ctx.token — alias-aware path uses token_map only. 
- end + result_token = encode_StoreViewTkoOp!( + cb, token_type, tile_val, pv_arg.v, index_vals; + token = input_token, optimization_hints + ) + + # Store result token for TokenResultNode and update ctx.token for control flow + ctx.result_tokens[ctx.current_ssa_idx] = result_token + ctx.token = result_token return nothing end diff --git a/src/cuTile.jl b/src/cuTile.jl index 3549b2f3..2987e826 100644 --- a/src/cuTile.jl +++ b/src/cuTile.jl @@ -1,7 +1,7 @@ module cuTile using IRStructurizer -using IRStructurizer: Block, ControlFlowOp, BlockArg, +using IRStructurizer: Block, ControlFlowOp, BlockArg, SSAMap, YieldOp, ContinueOp, BreakOp, ConditionOp, IfOp, ForOp, WhileOp, LoopOp, Undef From 605cbd8666b956db80e46df3ec279ab042c5dda0 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Thu, 26 Mar 2026 10:38:13 +0100 Subject: [PATCH 03/10] Make TokenType a first-class type and add per-alias loop token carries. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The key insight: adding one line to `tile_type_for_julia!` to handle TokenType eliminates all token-specific conditionals from codegen. Loop emitters, getfield extraction, and type mapping become uniform — tokens are just another type flowing through the same code paths. Changes: - tile_type_for_julia! maps TokenType → Token(tt) (the 1-line fix) - extract_tile_shape handles TokenType (returns ScalarShape) - token_order_pass! now recurses into loops and branches: - Adds per-alias-set token carries (init_values, BlockArgs, terminators) - Updates SSAMap types via update_type! to include TokenType parameters - Inserts Core.getfield for token result extraction after loops/ifs - control_flow.jl simplified: no is_token_type branches, trusts parent_result_type - Terminators no longer manually append ctx.token — pass handles it - ctx.token removed from CGCtx entirely This is a WIP: 196/202 codegen tests pass. 
6 integration tests with complex loop patterns (spinloop, nested loops) have BoundsErrors from update_type! producing incorrect parameter counts — to be fixed next. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/compiler/codegen/control_flow.jl | 242 ++++----- src/compiler/codegen/irutils.jl | 12 + src/compiler/codegen/passes/token_order.jl | 575 ++++++++++++++------- src/compiler/codegen/statements.jl | 1 - src/compiler/codegen/utils.jl | 10 +- src/compiler/intrinsics/atomics.jl | 6 +- src/compiler/intrinsics/memory.jl | 29 +- src/compiler/intrinsics/views.jl | 6 +- 8 files changed, 523 insertions(+), 358 deletions(-) diff --git a/src/compiler/codegen/control_flow.jl b/src/compiler/codegen/control_flow.jl index bc7aa51e..9410f9e2 100644 --- a/src/compiler/codegen/control_flow.jl +++ b/src/compiler/codegen/control_flow.jl @@ -1,21 +1,11 @@ # Structured IR Emission -""" - result_count(T) -> Int - -Compute the number of results from a Block.types entry. -""" function result_count(@nospecialize(T)) T === Nothing && return 0 T <: Tuple && return length(T.parameters) return 1 end -""" - emit_block!(ctx, block::Block) - -Emit bytecode for a structured IR block. 
-""" function emit_block!(ctx::CGCtx, block::Block; skip_terminator::Bool=false) for (ssa_idx, entry) in block.body if entry.stmt isa ControlFlowOp @@ -25,78 +15,59 @@ function emit_block!(ctx::CGCtx, block::Block; skip_terminator::Bool=false) emit_statement!(ctx, entry.stmt, ssa_idx, entry.typ) end end - if !skip_terminator && block.terminator !== nothing emit_terminator!(ctx, block.terminator) end end -emit_control_flow_op!(ctx::CGCtx, op::IfOp, @nospecialize(result_type), n_results::Int, original_idx::Int) = - emit_if_op!(ctx, op, result_type, n_results, original_idx) -emit_control_flow_op!(ctx::CGCtx, op::ForOp, @nospecialize(result_type), n_results::Int, original_idx::Int) = - emit_for_op!(ctx, op, result_type, n_results, original_idx) -emit_control_flow_op!(ctx::CGCtx, op::WhileOp, @nospecialize(result_type), n_results::Int, original_idx::Int) = - emit_while_op!(ctx, op, result_type, n_results, original_idx) -emit_control_flow_op!(ctx::CGCtx, op::LoopOp, @nospecialize(result_type), n_results::Int, original_idx::Int) = - emit_loop_op!(ctx, op, result_type, n_results, original_idx) +emit_control_flow_op!(ctx::CGCtx, op::IfOp, @nospecialize(rt), n::Int, idx::Int) = emit_if_op!(ctx, op, rt, n, idx) +emit_control_flow_op!(ctx::CGCtx, op::ForOp, @nospecialize(rt), n::Int, idx::Int) = emit_for_op!(ctx, op, rt, n, idx) +emit_control_flow_op!(ctx::CGCtx, op::WhileOp, @nospecialize(rt), n::Int, idx::Int) = emit_while_op!(ctx, op, rt, n, idx) +emit_control_flow_op!(ctx::CGCtx, op::LoopOp, @nospecialize(rt), n::Int, idx::Int) = emit_loop_op!(ctx, op, rt, n, idx) #============================================================================= - Control flow emitters - Token threading through control flow is still manual (conservative approach). - The token_order_pass handles straight-line code; control flow uses ctx.token. 
+ IfOp =============================================================================# function emit_if_op!(ctx::CGCtx, op::IfOp, @nospecialize(parent_result_type), n_results::Int, ssa_idx::Int) cb = ctx.cb - then_blk = op.then_region - else_blk = op.else_region cond_tv = emit_value!(ctx, op.condition) cond_tv === nothing && throw(IRError("Cannot resolve condition for IfOp")) - # User result types + # Build result types uniformly from parent_result_type (pass set it correctly) result_types = TypeId[] - if parent_result_type === Nothing - # No results - elseif parent_result_type <: Tuple - for T in parent_result_type.parameters - push!(result_types, tile_type_for_julia!(ctx, T)) + if parent_result_type !== Nothing + if parent_result_type <: Tuple + for T in parent_result_type.parameters + push!(result_types, tile_type_for_julia!(ctx, T)) + end + else + push!(result_types, tile_type_for_julia!(ctx, parent_result_type)) end - else - push!(result_types, tile_type_for_julia!(ctx, parent_result_type)) end - n_user_results = length(result_types) - # Add token as additional result - push!(result_types, ctx.token_type) - - token_before = ctx.token then_body = function(_) - saved_block_args = copy(ctx.block_args) - ctx.token = token_before - emit_block!(ctx, then_blk) - if then_blk.terminator === nothing - encode_YieldOp!(ctx.cb, [ctx.token]) - end - empty!(ctx.block_args) - merge!(ctx.block_args, saved_block_args) + saved = copy(ctx.block_args) + emit_block!(ctx, op.then_region) + op.then_region.terminator === nothing && encode_YieldOp!(ctx.cb, Value[]) + empty!(ctx.block_args); merge!(ctx.block_args, saved) end else_body = function(_) - saved_block_args = copy(ctx.block_args) - ctx.token = token_before - emit_block!(ctx, else_blk) - if else_blk.terminator === nothing - encode_YieldOp!(ctx.cb, [ctx.token]) - end - empty!(ctx.block_args) - merge!(ctx.block_args, saved_block_args) + saved = copy(ctx.block_args) + emit_block!(ctx, op.else_region) + op.else_region.terminator === 
nothing && encode_YieldOp!(ctx.cb, Value[]) + empty!(ctx.block_args); merge!(ctx.block_args, saved) end results = encode_IfOp!(then_body, else_body, cb, result_types, cond_tv.v) - ctx.token = results[end] - ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type) + ctx.values[ssa_idx] = CGVal(results, parent_result_type) end +#============================================================================= + ForOp +=============================================================================# + function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type), n_results::Int, ssa_idx::Int) cb = ctx.cb body_blk = op.body @@ -110,51 +81,44 @@ function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type), throw(IRError("Cannot resolve ForOp bounds")) lower_tv.jltype === upper_tv.jltype === step_tv.jltype || throw(IRError("ForOp bounds must all have the same type")) - iv_jl_type = lower_tv.jltype - iv_type = tile_type_for_julia!(ctx, iv_jl_type) + iv_jl_type = lower_tv.jltype + iv_type = tile_type_for_julia!(ctx, iv_jl_type) - # Init values + token + # Emit ALL init values (user + token carries from pass) init_values = Value[] for init_val in op.init_values tv = emit_value!(ctx, init_val) (tv === nothing || tv.v === nothing) && throw(IRError("Cannot resolve ForOp init value")) push!(init_values, tv.v) end - push!(init_values, ctx.token) - n_carries = length(op.init_values) - - result_types = TypeId[] - for i in 1:n_carries - body_arg = body_blk.args[i] - push!(result_types, tile_type_for_julia!(ctx, body_arg.type)) - end - push!(result_types, ctx.token_type) + # Build result types uniformly from block args + n_carries = length(body_blk.args) + result_types = TypeId[tile_type_for_julia!(ctx, arg.type) for arg in body_blk.args] body_builder = function(block_args) - saved_block_args = copy(ctx.block_args) - - iv_tv = CGVal(block_args[1], iv_type, iv_jl_type) - ctx[iv_arg] = iv_tv - + saved = copy(ctx.block_args) + # Tile IR 
layout: [iv, carries...] + ctx[iv_arg] = CGVal(block_args[1], iv_type, iv_jl_type) for i in 1:n_carries - body_arg = body_blk.args[i] - shape = RowMajorShape(extract_tile_shape(body_arg.type)) - ctx[body_arg] = CGVal(block_args[i + 1], result_types[i], body_arg.type, shape) + arg = body_blk.args[i] + shape = RowMajorShape(extract_tile_shape(arg.type)) + ctx[arg] = CGVal(block_args[i + 1], result_types[i], arg.type, shape) end - ctx.token = block_args[end] - emit_block!(ctx, body_blk) - - empty!(ctx.block_args) - merge!(ctx.block_args, saved_block_args) + empty!(ctx.block_args); merge!(ctx.block_args, saved) end - results = encode_ForOp!(body_builder, cb, result_types, iv_type, lower_tv.v, upper_tv.v, step_tv.v, init_values) + results = encode_ForOp!(body_builder, cb, result_types, iv_type, + lower_tv.v, upper_tv.v, step_tv.v, init_values) - ctx.token = results[end] - ctx.values[ssa_idx] = CGVal(results[1:n_carries], parent_result_type) + # Trust parent_result_type (pass set it via update_type!) 
+ ctx.values[ssa_idx] = CGVal(results, parent_result_type) end +#============================================================================= + LoopOp +=============================================================================# + function emit_loop_op!(ctx::CGCtx, op::LoopOp, @nospecialize(parent_result_type), n_results::Int, ssa_idx::Int) cb = ctx.cb body_blk = op.body @@ -165,44 +129,32 @@ function emit_loop_op!(ctx::CGCtx, op::LoopOp, @nospecialize(parent_result_type) (tv === nothing || tv.v === nothing) && throw(IRError("Cannot resolve LoopOp init value")) push!(init_values, tv.v) end - push!(init_values, ctx.token) - n_carries = length(op.init_values) - - result_types = TypeId[] - for i in 1:n_carries - body_arg = body_blk.args[i] - push!(result_types, tile_type_for_julia!(ctx, body_arg.type)) - end - push!(result_types, ctx.token_type) + n_carries = length(body_blk.args) + result_types = TypeId[tile_type_for_julia!(ctx, arg.type) for arg in body_blk.args] body_builder = function(block_args) - saved_block_args = copy(ctx.block_args) - + saved = copy(ctx.block_args) for i in 1:n_carries - body_arg = body_blk.args[i] - shape = RowMajorShape(extract_tile_shape(body_arg.type)) - ctx[body_arg] = CGVal(block_args[i], result_types[i], body_arg.type, shape) + arg = body_blk.args[i] + shape = RowMajorShape(extract_tile_shape(arg.type)) + ctx[arg] = CGVal(block_args[i], result_types[i], arg.type, shape) end - ctx.token = block_args[end] - emit_block!(ctx, body_blk) - if body_blk.terminator === nothing - fallback_operands = copy(block_args) - fallback_operands[end] = ctx.token - encode_ContinueOp!(ctx.cb, fallback_operands) + encode_ContinueOp!(ctx.cb, copy(block_args)) end - - empty!(ctx.block_args) - merge!(ctx.block_args, saved_block_args) + empty!(ctx.block_args); merge!(ctx.block_args, saved) end results = encode_LoopOp!(body_builder, cb, result_types, init_values) - ctx.token = results[end] - ctx.values[ssa_idx] = CGVal(results[1:n_carries], parent_result_type) 
+ ctx.values[ssa_idx] = CGVal(results, parent_result_type) end +#============================================================================= + WhileOp — lowered to LoopOp pattern in codegen +=============================================================================# + function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_type), n_results::Int, ssa_idx::Int) cb = ctx.cb before_blk = op.before @@ -214,26 +166,18 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ (tv === nothing || tv.v === nothing) && throw(IRError("Cannot resolve WhileOp init value: $init_val")) push!(init_values, tv.v) end - push!(init_values, ctx.token) - n_carries = length(op.init_values) - - result_types = TypeId[] - for i in 1:n_carries - before_arg = before_blk.args[i] - push!(result_types, tile_type_for_julia!(ctx, before_arg.type)) - end - push!(result_types, ctx.token_type) + n_carries = length(before_blk.args) + result_types = TypeId[tile_type_for_julia!(ctx, arg.type) for arg in before_blk.args] body_builder = function(block_args) - saved_block_args = copy(ctx.block_args) + saved = copy(ctx.block_args) for i in 1:n_carries - before_arg = before_blk.args[i] - shape = RowMajorShape(extract_tile_shape(before_arg.type)) - ctx[before_arg] = CGVal(block_args[i], result_types[i], before_arg.type, shape) + arg = before_blk.args[i] + shape = RowMajorShape(extract_tile_shape(arg.type)) + ctx[arg] = CGVal(block_args[i], result_types[i], arg.type, shape) end - ctx.token = block_args[end] emit_block!(ctx, before_blk) @@ -243,10 +187,7 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ cond_tv = emit_value!(ctx, cond_op.condition) (cond_tv === nothing || cond_tv.v === nothing) && throw(IRError("Cannot resolve WhileOp condition")) - then_body = function(_) - encode_YieldOp!(ctx.cb, Value[]) - end - + then_body = (_) -> encode_YieldOp!(ctx.cb, Value[]) else_body = function(_) break_operands = Value[] for arg in 
cond_op.args @@ -254,26 +195,32 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ tv !== nothing && tv.v !== nothing && push!(break_operands, tv.v) end if isempty(break_operands) - for i in 1:n_carries + append!(break_operands, block_args[1:n_carries]) + else + # Append token carries (block_args beyond user carries from ConditionOp) + n_user = length(break_operands) + for i in (n_user + 1):n_carries push!(break_operands, block_args[i]) end end - push!(break_operands, ctx.token) encode_BreakOp!(ctx.cb, break_operands) end - encode_IfOp!(then_body, else_body, cb, TypeId[], cond_tv.v) for i in 1:length(after_blk.args) - after_arg = after_blk.args[i] + arg = after_blk.args[i] if i <= length(cond_op.args) tv = emit_value!(ctx, cond_op.args[i]) if tv !== nothing - ctx[after_arg] = tv + ctx[arg] = tv else - shape = RowMajorShape(extract_tile_shape(after_arg.type)) - ctx[after_arg] = CGVal(block_args[i], result_types[i], after_arg.type, shape) + shape = RowMajorShape(extract_tile_shape(arg.type)) + ctx[arg] = CGVal(block_args[i], result_types[i], arg.type, shape) end + else + # Token carries beyond ConditionOp.args: use block_args directly + ctx[arg] = CGVal(block_args[i], result_types[i], arg.type, + RowMajorShape(extract_tile_shape(arg.type))) end end @@ -286,26 +233,24 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ tv !== nothing && tv.v !== nothing && push!(continue_operands, tv.v) end end - push!(continue_operands, ctx.token) + # Ensure token carries are included even if YieldOp didn't resolve them + while length(continue_operands) < n_carries + push!(continue_operands, block_args[length(continue_operands) + 1]) + end encode_ContinueOp!(ctx.cb, continue_operands) - empty!(ctx.block_args) - merge!(ctx.block_args, saved_block_args) + empty!(ctx.block_args); merge!(ctx.block_args, saved) end results = encode_LoopOp!(body_builder, cb, result_types, init_values) - ctx.token = results[end] - 
ctx.values[ssa_idx] = CGVal(results[1:n_carries], parent_result_type) + ctx.values[ssa_idx] = CGVal(results, parent_result_type) end #============================================================================= - Terminators - Token is appended manually for control flow threading (conservative approach). + Terminators — tokens already in op.values from token_order_pass! =============================================================================# -function emit_terminator!(ctx::CGCtx, node::ReturnNode) - emit_return!(ctx, node) -end +emit_terminator!(ctx::CGCtx, node::ReturnNode) = emit_return!(ctx, node) function emit_terminator!(ctx::CGCtx, op::YieldOp) operands = Value[] @@ -313,7 +258,6 @@ function emit_terminator!(ctx::CGCtx, op::YieldOp) tv = emit_value!(ctx, val) tv !== nothing && tv.v !== nothing && push!(operands, tv.v) end - push!(operands, ctx.token) encode_YieldOp!(ctx.cb, operands) end @@ -323,7 +267,6 @@ function emit_terminator!(ctx::CGCtx, op::ContinueOp) tv = emit_value!(ctx, val) tv !== nothing && tv.v !== nothing && push!(operands, tv.v) end - push!(operands, ctx.token) encode_ContinueOp!(ctx.cb, operands) end @@ -333,12 +276,11 @@ function emit_terminator!(ctx::CGCtx, op::BreakOp) tv = emit_value!(ctx, val) tv !== nothing && tv.v !== nothing && push!(operands, tv.v) end - push!(operands, ctx.token) encode_BreakOp!(ctx.cb, operands) end -function emit_terminator!(ctx::CGCtx, ::Nothing) end -function emit_terminator!(ctx::CGCtx, ::ConditionOp) end +emit_terminator!(ctx::CGCtx, ::Nothing) = nothing +emit_terminator!(ctx::CGCtx, ::ConditionOp) = nothing #============================================================================= Early Return Hoisting @@ -359,13 +301,11 @@ function hoist_returns!(block::Block) hoist_returns!(stmt.body) end end - for (_, entry) in block.body entry.stmt isa IfOp || continue op = entry.stmt::IfOp op.then_region.terminator isa ReturnNode || continue op.else_region.terminator isa ReturnNode || continue - 
op.then_region.terminator = YieldOp() op.else_region.terminator = YieldOp() block.terminator = ReturnNode(nothing) @@ -373,7 +313,7 @@ function hoist_returns!(block::Block) end #============================================================================= - Loop getfield extraction + Loop getfield extraction — uniform, no token special cases =============================================================================# function emit_loop_getfield!(ctx::CGCtx, args::Vector{Any}) diff --git a/src/compiler/codegen/irutils.jl b/src/compiler/codegen/irutils.jl index 58c47629..e90e968e 100644 --- a/src/compiler/codegen/irutils.jl +++ b/src/compiler/codegen/irutils.jl @@ -64,3 +64,15 @@ function insert_after!(m::SSAMap, after_idx::Int, new_idx::Int, stmt, typ) insert!(m.types, pos + 1, typ) return nothing end + +""" + update_type!(m::SSAMap, ssa_idx::Int, @nospecialize(new_type)) + +Update the type annotation for an existing SSAMap entry. +""" +function update_type!(m::SSAMap, ssa_idx::Int, @nospecialize(new_type)) + pos = findfirst(==(ssa_idx), m.ssa_idxes) + pos === nothing && throw(KeyError(ssa_idx)) + m.types[pos] = new_type + return nothing +end diff --git a/src/compiler/codegen/passes/token_order.jl b/src/compiler/codegen/passes/token_order.jl index 2a1ff767..d153ffeb 100644 --- a/src/compiler/codegen/passes/token_order.jl +++ b/src/compiler/codegen/passes/token_order.jl @@ -1,8 +1,8 @@ # Token ordering pass # # Transforms a StructuredIRCode by inserting token operations (MakeToken, JoinTokens, -# TokenResult) and threading tokens through control flow (loop carries, branch yields). -# After this pass, codegen simply emits what's in the IR — no manual token threading. +# TokenResult) and threading per-alias-set tokens through control flow. +# After this pass, codegen emits what's in the IR — no manual token threading. # # Mirrors cuTile Python's `token_order_pass` (res/cutile-python/src/cuda/tile/_passes/token_order.py). 
@@ -17,8 +17,7 @@ using Core: SSAValue, Argument, SlotNumber """ MemoryEffects -Per-block summary of which alias sets are read/written and whether any -acquire-ordered operation appears. +Per-block summary of which alias sets are read/written. """ struct MemoryEffects effects::Dict{AliasSet, MemoryEffect} @@ -27,22 +26,11 @@ end MemoryEffects() = MemoryEffects(Dict{AliasSet, MemoryEffect}(), false) -function Base.merge!(a::MemoryEffects, b::MemoryEffects) - for (alias_set, effect) in b.effects - existing = get(a.effects, alias_set, MEM_NONE) - a.effects[alias_set] = max(existing, effect) - end - return MemoryEffects(a.effects, a.has_acquire | b.has_acquire) -end - function Base.union(a::MemoryEffects, b::MemoryEffects) result = Dict{AliasSet, MemoryEffect}() - for (k, v) in a.effects - result[k] = v - end + for (k, v) in a.effects; result[k] = v; end for (k, v) in b.effects - existing = get(result, k, MEM_NONE) - result[k] = max(existing, v) + result[k] = max(get(result, k, MEM_NONE), v) end return MemoryEffects(result, a.has_acquire | b.has_acquire) end @@ -50,14 +38,9 @@ end const EMPTY_MEMORY_EFFECTS = MemoryEffects() #============================================================================= - Resolve functions from IR expressions + Resolve and classify IR expressions =============================================================================# -""" - resolve_call(stmt) -> (func, operands) or nothing - -Extract the resolved function value and operands from a :call or :invoke Expr. 
-""" function resolve_call(stmt) stmt isa Expr || return nothing if stmt.head === :call @@ -70,11 +53,7 @@ function resolve_call(stmt) return nothing end resolved = if func_ref isa GlobalRef - try - getfield(func_ref.mod, func_ref.name) - catch - nothing - end + try; getfield(func_ref.mod, func_ref.name); catch; nothing; end else func_ref end @@ -82,12 +61,6 @@ function resolve_call(stmt) return (resolved, operands) end -""" - classify_memory_op(resolved_func) -> (MemoryEffect, Bool) - -Classify a resolved function as a memory operation. -Returns (effect, is_store) where effect is MEM_NONE/MEM_LOAD/MEM_STORE. -""" function classify_memory_op(resolved_func) if resolved_func === Intrinsics.load_partition_view || resolved_func === Intrinsics.load_ptr_tko @@ -96,7 +69,7 @@ function classify_memory_op(resolved_func) resolved_func === Intrinsics.store_ptr_tko return MEM_STORE elseif is_atomic_intrinsic(resolved_func) - return MEM_STORE # Atomics are read-modify-write, treat as store for ordering + return MEM_STORE else return MEM_NONE end @@ -111,11 +84,6 @@ function is_atomic_intrinsic(func) return false end -""" - get_alias_set_for_operand(alias_result, operand) -> AliasSet - -Look up the alias set for an operand (the first arg of a memory op). -""" function get_alias_set_for_operand(alias_result::Dict{Any, AliasSet}, operand) if operand isa SSAValue || operand isa Argument || operand isa SlotNumber return get(alias_result, operand, ALIAS_UNIVERSE) @@ -127,18 +95,13 @@ end Compute per-block memory effects =============================================================================# -""" - compute_block_memory_effects!(block, alias_result, cache) - -Compute memory effects for a block and all nested blocks, storing results in `cache`. 
-""" function compute_block_memory_effects!(block::Block, alias_result::Dict{Any, AliasSet}, cache::Dict{UInt64, MemoryEffects}) block_id = objectid(block) haskey(cache, block_id) && return cache[block_id] effects = MemoryEffects() - for (ssa_idx, entry) in block.body + for (_, entry) in block.body if entry.stmt isa ControlFlowOp nested = compute_cf_memory_effects!(entry.stmt, alias_result, cache) effects = union(effects, nested) @@ -149,64 +112,44 @@ function compute_block_memory_effects!(block::Block, alias_result::Dict{Any, Ali mem_effect = classify_memory_op(resolved_func) mem_effect == MEM_NONE && continue alias_set = get_alias_set_for_operand(alias_result, first(operands)) - existing = get(effects.effects, alias_set, MEM_NONE) - effects.effects[alias_set] = max(existing, mem_effect) + effects.effects[alias_set] = max(get(effects.effects, alias_set, MEM_NONE), mem_effect) + # Track acquire ordering for atomics + if is_atomic_intrinsic(resolved_func) + effects = MemoryEffects(effects.effects, true) + end end end cache[block_id] = effects return effects end -function compute_cf_memory_effects!(op::IfOp, alias_result, cache) - then_eff = compute_block_memory_effects!(op.then_region, alias_result, cache) - else_eff = compute_block_memory_effects!(op.else_region, alias_result, cache) - return union(then_eff, else_eff) -end - -function compute_cf_memory_effects!(op::ForOp, alias_result, cache) - return compute_block_memory_effects!(op.body, alias_result, cache) -end - -function compute_cf_memory_effects!(op::WhileOp, alias_result, cache) - before_eff = compute_block_memory_effects!(op.before, alias_result, cache) - after_eff = compute_block_memory_effects!(op.after, alias_result, cache) - return union(before_eff, after_eff) -end - -function compute_cf_memory_effects!(op::LoopOp, alias_result, cache) - return compute_block_memory_effects!(op.body, alias_result, cache) -end - -compute_cf_memory_effects!(::ControlFlowOp, alias_result, cache) = EMPTY_MEMORY_EFFECTS 
+compute_cf_memory_effects!(op::IfOp, ar, c) = + union(compute_block_memory_effects!(op.then_region, ar, c), + compute_block_memory_effects!(op.else_region, ar, c)) +compute_cf_memory_effects!(op::ForOp, ar, c) = compute_block_memory_effects!(op.body, ar, c) +compute_cf_memory_effects!(op::LoopOp, ar, c) = compute_block_memory_effects!(op.body, ar, c) +compute_cf_memory_effects!(op::WhileOp, ar, c) = + union(compute_block_memory_effects!(op.before, ar, c), + compute_block_memory_effects!(op.after, ar, c)) +compute_cf_memory_effects!(::ControlFlowOp, _, _) = EMPTY_MEMORY_EFFECTS #============================================================================= - Token map (IR-level, using SSAValue/BlockArg instead of bytecode Values) + Token map (IR-level, SSAValue/BlockArg) =============================================================================# -# IRToken: an SSAValue, BlockArg, or nothing (for tokens in the IR) const IRToken = Any -""" - collect_join_tokens_ir(token_key, token_map, memory_order=nothing) -> Vector{IRToken} - -IR-level equivalent of Python's `_collect_join_tokens`. -Collects all token IR values that need to be joined for the given token_key. 
-""" function collect_join_tokens_ir(token_key::TokenKey, token_map::Dict{TokenKey, IRToken}, memory_order=nothing) tokens_to_join = IRToken[token_map[token_key]] - for (other_key, other_tok) in token_map should_join = false - if other_key isa AcquireTokenKey should_join = true elseif other_key isa AliasTokenKey && token_key isa AliasTokenKey - # Release: join with all LAST_OP tokens if memory_order !== nothing && has_release_order(memory_order) should_join = other_key.role == LAST_OP end - # Alias set overlap: same role and sets overlap if other_key.role == token_key.role alias_overlap = !(other_key.alias_set isa AliasUniverse) && !(token_key.alias_set isa AliasUniverse) && @@ -214,132 +157,107 @@ function collect_join_tokens_ir(token_key::TokenKey, token_map::Dict{TokenKey, I should_join = should_join || alias_overlap end end - if should_join && !any(t -> t === other_tok, tokens_to_join) push!(tokens_to_join, other_tok) end end - return tokens_to_join end -""" - get_input_token_ir!(sci, block, before_ssa, token_key, token_map, memory_order=nothing) - -> IRToken - -Get the input token for a memory operation. If multiple tokens need joining, -inserts a JoinTokensNode into the block before `before_ssa` and returns its SSAValue. 
-""" function get_input_token_ir!(sci::StructuredIRCode, block::Block, before_ssa::Int, token_key::TokenKey, token_map::Dict{TokenKey, IRToken}, memory_order=nothing) - if !haskey(token_map, token_key) - # Fallback to ACQUIRE token - return token_map[ACQUIRE_TOKEN_KEY] - end - - tokens_to_join = collect_join_tokens_ir(token_key, token_map, memory_order) - - if length(tokens_to_join) == 1 - return tokens_to_join[1] - end - - # Insert JoinTokensNode before the memory op + haskey(token_map, token_key) || return token_map[ACQUIRE_TOKEN_KEY] + tokens = collect_join_tokens_ir(token_key, token_map, memory_order) + length(tokens) == 1 && return tokens[1] join_ssa = new_ssa_idx!(sci) - insert_before!(block.body, before_ssa, join_ssa, JoinTokensNode(tokens_to_join), TOKEN_TYPE) + insert_before!(block.body, before_ssa, join_ssa, JoinTokensNode(tokens), TOKEN_TYPE) return SSAValue(join_ssa) end has_release_order(memory_order) = false - #============================================================================= - The main pass + Control flow exit tokens (matching Python's _get_cf_exit_tokens) =============================================================================# """ - token_order_pass!(sci::StructuredIRCode, alias_result::Dict{Any, AliasSet}) + get_cf_exit_tokens(effects, token_map) -> Vector{IRToken} -Transform a StructuredIRCode by inserting explicit token operations. -Modifies the IR in place: -- Inserts MakeTokenNode at function entry -- Inserts JoinTokensNode where tokens need merging -- Inserts TokenResultNode after memory ops to capture their result tokens -- Adds token as extra argument to memory op calls -- Adds per-alias-set token carries through loops and branches - -After this pass, codegen emits tokens from the IR without manual threading. +Collect current tokens for each alias set with memory effects. +These are appended to ContinueOp/BreakOp/YieldOp when leaving a CF region. 
""" +function get_cf_exit_tokens(effects::MemoryEffects, token_map::Dict{TokenKey, IRToken}) + tokens = IRToken[] + for (alias_set, effect) in effects.effects + effect == MEM_NONE && continue + if effect == MEM_LOAD + push!(tokens, token_map[last_op_key(alias_set)]) + elseif effect == MEM_STORE + push!(tokens, token_map[last_op_key(alias_set)]) + push!(tokens, token_map[last_store_key(alias_set)]) + end + end + if effects.has_acquire + push!(tokens, token_map[ACQUIRE_TOKEN_KEY]) + end + return tokens +end + +#============================================================================= + The main pass +=============================================================================# + function token_order_pass!(sci::StructuredIRCode, alias_result::Dict{Any, AliasSet}) - # Compute per-block memory effects effects_cache = Dict{UInt64, MemoryEffects}() compute_block_memory_effects!(sci.entry, alias_result, effects_cache) - # Create root token (MakeTokenNode) at entry + # Insert root MakeTokenNode at entry root_ssa = new_ssa_idx!(sci) pushfirst!(sci.entry.body, (root_ssa, MakeTokenNode(), TOKEN_TYPE)) root_token = SSAValue(root_ssa) - # Initialize token map: all alias sets start at root token + # Initialize: all alias sets start at root token token_map = Dict{TokenKey, IRToken}() - unique_alias_sets = Set(values(alias_result)) - for alias_set in unique_alias_sets + for alias_set in Set(values(alias_result)) token_map[last_op_key(alias_set)] = root_token token_map[last_store_key(alias_set)] = root_token end token_map[ACQUIRE_TOKEN_KEY] = root_token - # Transform the entry block transform_block!(sci, sci.entry, alias_result, token_map, effects_cache, nothing, nothing) - return nothing end #============================================================================= - Block transformation (recursive) + Block transformation =============================================================================# -""" - transform_block!(sci, block, alias_result, token_map, 
effects_cache, - innermost_loop_info, ifelse_info) - -Walk a block's statements and transform memory/control-flow ops for token ordering. -Modifies `token_map` in place to reflect the token state after the block. -""" function transform_block!(sci::StructuredIRCode, block::Block, alias_result::Dict{Any, AliasSet}, token_map::Dict{TokenKey, IRToken}, effects_cache::Dict{UInt64, MemoryEffects}, - innermost_loop_effects::Union{MemoryEffects, Nothing}, + loop_effects::Union{MemoryEffects, Nothing}, ifelse_effects::Union{MemoryEffects, Nothing}) - - # Collect SSA indices first to avoid iterator invalidation from insertions. + # Snapshot indices to avoid invalidation from insertions ssa_indices = collect(Int, block.body.ssa_idxes) - # Track whether we've seen a control flow op. Once we hit one, - # we stop transforming memory ops because the token state after the CF op - # is managed by codegen (ctx.token), not by the pass's token_map. - seen_control_flow = false - for ssa_idx in ssa_indices entry = get(block.body, ssa_idx, nothing) entry === nothing && continue - if entry.stmt isa ControlFlowOp - seen_control_flow = true - # Don't recurse into nested regions (conservative approach) - elseif !seen_control_flow + transform_control_flow!(sci, block, ssa_idx, entry.stmt, entry.typ, + alias_result, token_map, effects_cache) + else transform_statement!(sci, block, ssa_idx, entry.stmt, alias_result, token_map) end end -end -""" - transform_statement!(sci, block, ssa_idx, stmt, alias_result, token_map) + # Append exit tokens to the block's terminator (for loops and branches) + transform_terminator!(block, token_map, loop_effects, ifelse_effects) +end -Transform a single statement. If it's a memory operation, insert token input/output nodes. 
-""" function transform_statement!(sci::StructuredIRCode, block::Block, ssa_idx::Int, stmt, alias_result::Dict{Any, AliasSet}, token_map::Dict{TokenKey, IRToken}) @@ -352,20 +270,15 @@ function transform_statement!(sci::StructuredIRCode, block::Block, ssa_idx::Int, alias_set = get_alias_set_for_operand(alias_result, first(operands)) if mem_effect == MEM_LOAD - # Load depends on LAST_STORE (read-after-write) input_token = get_input_token_ir!(sci, block, ssa_idx, last_store_key(alias_set), token_map) - - # Add token arg to the call push!(stmt.args, input_token) - # Insert TokenResultNode after the load result_ssa = new_ssa_idx!(sci) insert_after!(block.body, ssa_idx, result_ssa, TokenResultNode(ssa_idx), TOKEN_TYPE) - result_token = SSAValue(result_ssa) - # Update LAST_OP: eagerly join with existing last_op token + # Eagerly join with last_op token (Python line 176-179) lop_key = last_op_key(alias_set) last_op_tok = token_map[lop_key] join_ssa = new_ssa_idx!(sci) @@ -374,46 +287,360 @@ function transform_statement!(sci::StructuredIRCode, block::Block, ssa_idx::Int, token_map[lop_key] = SSAValue(join_ssa) elseif mem_effect == MEM_STORE - # Store depends on LAST_OP (write-after-read, write-after-write) input_token = get_input_token_ir!(sci, block, ssa_idx, last_op_key(alias_set), token_map) - - # Add token arg to the call push!(stmt.args, input_token) - # Insert TokenResultNode after the store result_ssa = new_ssa_idx!(sci) insert_after!(block.body, ssa_idx, result_ssa, TokenResultNode(ssa_idx), TOKEN_TYPE) - result_token = SSAValue(result_ssa) - # Update both LAST_OP and LAST_STORE token_map[last_op_key(alias_set)] = result_token token_map[last_store_key(alias_set)] = result_token + + # Atomics with acquire semantics update the ACQUIRE token + if is_atomic_intrinsic(resolved_func) + token_map[ACQUIRE_TOKEN_KEY] = result_token + end end end +function transform_terminator!(block::Block, token_map::Dict{TokenKey, IRToken}, + loop_effects::Union{MemoryEffects, Nothing}, + 
ifelse_effects::Union{MemoryEffects, Nothing}) + term = block.terminator + term === nothing && return + effects = if (term isa ContinueOp || term isa BreakOp) && loop_effects !== nothing + loop_effects + elseif term isa YieldOp && ifelse_effects !== nothing + ifelse_effects + elseif term isa YieldOp && loop_effects !== nothing + # WhileOp after-block: YieldOp values become ContinueOp in codegen + loop_effects + else + nothing + end + effects === nothing && return + append!(term.values, get_cf_exit_tokens(effects, token_map)) +end #============================================================================= - Control flow transformation (conservative) + Control flow transformation +=============================================================================# - For this initial port, control flow ops are handled conservatively: - - Memory ops inside nested blocks get the root token (from the enclosing scope) - - No per-alias token carries through loops or branches - - Token state is unchanged after control flow ops +# --- Loops (ForOp, LoopOp) --- +# Matching Python's Loop handling (token_order.py:228-280) - This matches the original inline approach's conservative behavior. - TODO: Add per-alias token carries (matching Python's token_order_pass). -=============================================================================# +function transform_control_flow!(sci::StructuredIRCode, parent_block::Block, + ssa_idx::Int, op::ForOp, @nospecialize(result_type), + alias_result, token_map, effects_cache) + transform_loop!(sci, parent_block, ssa_idx, op, op.body, alias_result, + token_map, effects_cache) +end -# For the conservative approach, control flow regions are NOT transformed by the pass. -# Memory ops inside loops/branches use ctx.token (the loop-carried or pre-branch token) -# which is managed manually by the codegen's control flow emitters. -# The pass only transforms straight-line code in the entry block. 
function transform_control_flow!(sci::StructuredIRCode, parent_block::Block, - ssa_idx::Int, op::ControlFlowOp, @nospecialize(result_type), + ssa_idx::Int, op::LoopOp, @nospecialize(result_type), + alias_result, token_map, effects_cache) + transform_loop!(sci, parent_block, ssa_idx, op, op.body, alias_result, + token_map, effects_cache) +end + +""" + transform_loop!(sci, parent_block, ssa_idx, op, body, alias_result, token_map, effects_cache) + +Add per-alias-set token carries to a loop. For each alias set with memory effects +in the body, creates init_values, BlockArgs, and terminator exit tokens. +Then recurses into the body with the body-scoped token_map. +After the loop, inserts getfield extractions for the token results. +""" +function transform_loop!(sci::StructuredIRCode, parent_block::Block, + ssa_idx::Int, op::Union{ForOp, LoopOp}, body::Block, + alias_result::Dict{Any, AliasSet}, + token_map::Dict{TokenKey, IRToken}, + effects_cache::Dict{UInt64, MemoryEffects}) + body_effects = get(effects_cache, objectid(body), EMPTY_MEMORY_EFFECTS) + + body_token_map = copy(token_map) + result_token_map = copy(token_map) + + # Track the number of user carries (before we add tokens) + n_user_carries = length(op.init_values) + + # Add per-alias token carries (matching Python lines 245-264) + carry_idx = n_user_carries # 0-based index into results after user carries + for (alias_set, effect) in body_effects.effects + effect == MEM_NONE && continue + if effect >= MEM_LOAD + carry_idx += 1 + push!(op.init_values, token_map[last_op_key(alias_set)]) + body_arg = new_block_arg!(body, sci, TOKEN_TYPE) + body_token_map[last_op_key(alias_set)] = body_arg + # result_token_map will be updated below with getfield SSAs + end + if effect == MEM_STORE + carry_idx += 1 + push!(op.init_values, token_map[last_store_key(alias_set)]) + body_arg = new_block_arg!(body, sci, TOKEN_TYPE) + body_token_map[last_store_key(alias_set)] = body_arg + end + end + if body_effects.has_acquire + carry_idx 
+= 1 + push!(op.init_values, token_map[ACQUIRE_TOKEN_KEY]) + body_arg = new_block_arg!(body, sci, TOKEN_TYPE) + body_token_map[ACQUIRE_TOKEN_KEY] = body_arg + end + + n_total_carries = length(op.init_values) + + # Recurse into body with body-scoped token map + transform_block!(sci, body, alias_result, body_token_map, effects_cache, + body_effects, nothing) + + # After the loop: insert getfield extractions for token results. + # These reference the loop's SSA result (which will be a tuple of all carries). + # Update the loop's type in the parent SSAMap to include token types. + if n_total_carries > n_user_carries + # Build extended type: Tuple{user_types..., TokenType...} + old_type = get(parent_block.body, ssa_idx, nothing) + if old_type !== nothing + user_types = if old_type.typ === Nothing + Type[] + elseif old_type.typ <: Tuple + collect(Type, old_type.typ.parameters) + else + Type[old_type.typ] + end + token_types = fill(TokenType, n_total_carries - n_user_carries) + new_type = Tuple{user_types..., token_types...} + update_type!(parent_block.body, ssa_idx, new_type) + end + + # Insert getfield SSAs after the loop for each token result + last_inserted = ssa_idx + carry_idx = n_user_carries + for (alias_set, effect) in body_effects.effects + effect == MEM_NONE && continue + if effect >= MEM_LOAD + carry_idx += 1 + gf_ssa = new_ssa_idx!(sci) + gf_expr = Expr(:call, GlobalRef(Core, :getfield), SSAValue(ssa_idx), carry_idx) + insert_after!(parent_block.body, last_inserted, gf_ssa, gf_expr, TOKEN_TYPE) + result_token_map[last_op_key(alias_set)] = SSAValue(gf_ssa) + last_inserted = gf_ssa + end + if effect == MEM_STORE + carry_idx += 1 + gf_ssa = new_ssa_idx!(sci) + gf_expr = Expr(:call, GlobalRef(Core, :getfield), SSAValue(ssa_idx), carry_idx) + insert_after!(parent_block.body, last_inserted, gf_ssa, gf_expr, TOKEN_TYPE) + result_token_map[last_store_key(alias_set)] = SSAValue(gf_ssa) + last_inserted = gf_ssa + end + end + if body_effects.has_acquire + carry_idx += 1 
+ gf_ssa = new_ssa_idx!(sci) + gf_expr = Expr(:call, GlobalRef(Core, :getfield), SSAValue(ssa_idx), carry_idx) + insert_after!(parent_block.body, last_inserted, gf_ssa, gf_expr, TOKEN_TYPE) + result_token_map[ACQUIRE_TOKEN_KEY] = SSAValue(gf_ssa) + end + end + + # Update parent's token_map with loop result tokens + merge!(token_map, result_token_map) +end + +# --- WhileOp --- +# WhileOp has before/after regions. We treat it similarly to a loop but need to +# handle both regions. + +function transform_control_flow!(sci::StructuredIRCode, parent_block::Block, + ssa_idx::Int, op::WhileOp, @nospecialize(result_type), + alias_result, token_map, effects_cache) + before_effects = get(effects_cache, objectid(op.before), EMPTY_MEMORY_EFFECTS) + after_effects = get(effects_cache, objectid(op.after), EMPTY_MEMORY_EFFECTS) + loop_effects = union(before_effects, after_effects) + + body_token_map = copy(token_map) + result_token_map = copy(token_map) + n_user_carries = length(op.init_values) + + # Add per-alias token carries to before/after blocks + for (alias_set, effect) in loop_effects.effects + effect == MEM_NONE && continue + if effect >= MEM_LOAD + push!(op.init_values, token_map[last_op_key(alias_set)]) + before_arg = new_block_arg!(op.before, sci, TOKEN_TYPE) + after_arg = new_block_arg!(op.after, sci, TOKEN_TYPE) + body_token_map[last_op_key(alias_set)] = before_arg + end + if effect == MEM_STORE + push!(op.init_values, token_map[last_store_key(alias_set)]) + before_arg = new_block_arg!(op.before, sci, TOKEN_TYPE) + after_arg = new_block_arg!(op.after, sci, TOKEN_TYPE) + body_token_map[last_store_key(alias_set)] = before_arg + end + end + if loop_effects.has_acquire + push!(op.init_values, token_map[ACQUIRE_TOKEN_KEY]) + before_arg = new_block_arg!(op.before, sci, TOKEN_TYPE) + after_arg = new_block_arg!(op.after, sci, TOKEN_TYPE) + body_token_map[ACQUIRE_TOKEN_KEY] = before_arg + end + + n_total_carries = length(op.init_values) + + # Build after_token_map from after 
block's args (not before's) + after_token_map = copy(token_map) + after_arg_idx = n_user_carries + for (alias_set, effect) in loop_effects.effects + effect == MEM_NONE && continue + if effect >= MEM_LOAD + after_arg_idx += 1 + after_token_map[last_op_key(alias_set)] = op.after.args[after_arg_idx] + end + if effect == MEM_STORE + after_arg_idx += 1 + after_token_map[last_store_key(alias_set)] = op.after.args[after_arg_idx] + end + end + if loop_effects.has_acquire + after_arg_idx += 1 + after_token_map[ACQUIRE_TOKEN_KEY] = op.after.args[after_arg_idx] + end + + # Transform before and after regions + transform_block!(sci, op.before, alias_result, body_token_map, effects_cache, + loop_effects, nothing) + transform_block!(sci, op.after, alias_result, after_token_map, effects_cache, + loop_effects, nothing) + + # Insert getfield extractions for token results (same as transform_loop!) + if n_total_carries > n_user_carries + old_type = get(parent_block.body, ssa_idx, nothing) + if old_type !== nothing + user_types = if old_type.typ === Nothing + Type[] + elseif old_type.typ <: Tuple + collect(Type, old_type.typ.parameters) + else + Type[old_type.typ] + end + token_types = fill(TokenType, n_total_carries - n_user_carries) + new_type = Tuple{user_types..., token_types...} + update_type!(parent_block.body, ssa_idx, new_type) + end + + last_inserted = ssa_idx + carry_idx = n_user_carries + for (alias_set, effect) in loop_effects.effects + effect == MEM_NONE && continue + if effect >= MEM_LOAD + carry_idx += 1 + gf_ssa = new_ssa_idx!(sci) + gf_expr = Expr(:call, GlobalRef(Core, :getfield), SSAValue(ssa_idx), carry_idx) + insert_after!(parent_block.body, last_inserted, gf_ssa, gf_expr, TOKEN_TYPE) + result_token_map[last_op_key(alias_set)] = SSAValue(gf_ssa) + last_inserted = gf_ssa + end + if effect == MEM_STORE + carry_idx += 1 + gf_ssa = new_ssa_idx!(sci) + gf_expr = Expr(:call, GlobalRef(Core, :getfield), SSAValue(ssa_idx), carry_idx) + insert_after!(parent_block.body, 
last_inserted, gf_ssa, gf_expr, TOKEN_TYPE) + result_token_map[last_store_key(alias_set)] = SSAValue(gf_ssa) + last_inserted = gf_ssa + end + end + if loop_effects.has_acquire + carry_idx += 1 + gf_ssa = new_ssa_idx!(sci) + gf_expr = Expr(:call, GlobalRef(Core, :getfield), SSAValue(ssa_idx), carry_idx) + insert_after!(parent_block.body, last_inserted, gf_ssa, gf_expr, TOKEN_TYPE) + result_token_map[ACQUIRE_TOKEN_KEY] = SSAValue(gf_ssa) + end + end + + merge!(token_map, result_token_map) +end + +# --- IfOp --- +# Matching Python's IfElse handling (token_order.py:294-334) + +function transform_control_flow!(sci::StructuredIRCode, parent_block::Block, + ssa_idx::Int, op::IfOp, @nospecialize(result_type), alias_result, token_map, effects_cache) - # Do nothing — codegen handles control flow token threading conservatively. - # TODO: Transform nested regions once per-alias loop carries are implemented. + then_effects = get(effects_cache, objectid(op.then_region), EMPTY_MEMORY_EFFECTS) + else_effects = get(effects_cache, objectid(op.else_region), EMPTY_MEMORY_EFFECTS) + merged_effects = union(then_effects, else_effects) + + # Transform both branches (they extend their YieldOps with exit tokens) + then_map = copy(token_map) + transform_block!(sci, op.then_region, alias_result, then_map, effects_cache, + nothing, merged_effects) + else_map = copy(token_map) + transform_block!(sci, op.else_region, alias_result, else_map, effects_cache, + nothing, merged_effects) + + # Count token results and insert getfield extractions + n_token_results = 0 + for (_, effect) in merged_effects.effects + effect == MEM_NONE && continue + n_token_results += (effect == MEM_LOAD) ? 1 : 2 + end + n_token_results += merged_effects.has_acquire ? 
1 : 0 + + if n_token_results > 0 + # Update IfOp type to include token results + old_type = get(parent_block.body, ssa_idx, nothing) + if old_type !== nothing + user_types = if old_type.typ === Nothing + Type[] + elseif old_type.typ <: Tuple + collect(Type, old_type.typ.parameters) + else + Type[old_type.typ] + end + token_types = fill(TokenType, n_token_results) + new_type = Tuple{user_types..., token_types...} + update_type!(parent_block.body, ssa_idx, new_type) + end + + # Insert getfield extractions for token results + last_inserted = ssa_idx + result_idx = length(user_types) + for (alias_set, effect) in merged_effects.effects + effect == MEM_NONE && continue + if effect >= MEM_LOAD + result_idx += 1 + gf_ssa = new_ssa_idx!(sci) + gf_expr = Expr(:call, GlobalRef(Core, :getfield), SSAValue(ssa_idx), result_idx) + insert_after!(parent_block.body, last_inserted, gf_ssa, gf_expr, TOKEN_TYPE) + token_map[last_op_key(alias_set)] = SSAValue(gf_ssa) + last_inserted = gf_ssa + end + if effect == MEM_STORE + result_idx += 1 + gf_ssa = new_ssa_idx!(sci) + gf_expr = Expr(:call, GlobalRef(Core, :getfield), SSAValue(ssa_idx), result_idx) + insert_after!(parent_block.body, last_inserted, gf_ssa, gf_expr, TOKEN_TYPE) + token_map[last_store_key(alias_set)] = SSAValue(gf_ssa) + last_inserted = gf_ssa + end + end + if merged_effects.has_acquire + result_idx += 1 + gf_ssa = new_ssa_idx!(sci) + gf_expr = Expr(:call, GlobalRef(Core, :getfield), SSAValue(ssa_idx), result_idx) + insert_after!(parent_block.body, last_inserted, gf_ssa, gf_expr, TOKEN_TYPE) + token_map[ACQUIRE_TOKEN_KEY] = SSAValue(gf_ssa) + end + end end +# Fallback +function transform_control_flow!(sci::StructuredIRCode, parent_block::Block, + ssa_idx::Int, op::ControlFlowOp, @nospecialize(result_type), + alias_result, token_map, effects_cache) +end diff --git a/src/compiler/codegen/statements.jl b/src/compiler/codegen/statements.jl index fd324a3b..b3de499e 100644 --- a/src/compiler/codegen/statements.jl +++ 
b/src/compiler/codegen/statements.jl @@ -84,7 +84,6 @@ function emit_make_token!(ctx::CGCtx) ctx.token_type = token_type end v = encode_MakeTokenOp!(ctx.cb, token_type) - ctx.token = v # Set as current token for control flow threading return CGVal(v, token_type, TokenType) end diff --git a/src/compiler/codegen/utils.jl b/src/compiler/codegen/utils.jl index c3e4d0e4..c410fe02 100644 --- a/src/compiler/codegen/utils.jl +++ b/src/compiler/codegen/utils.jl @@ -84,8 +84,9 @@ end is_token_type(typ) -> Bool Check whether a type annotation in the structured IR represents a token. +Handles both instances (`TOKEN_TYPE`) and the type itself (`TokenType`). """ -is_token_type(@nospecialize(typ)) = typ isa TokenType +is_token_type(@nospecialize(typ)) = typ isa TokenType || typ === TokenType #============================================================================= IRError: Exception type for IR compilation errors @@ -253,10 +254,6 @@ mutable struct CGCtx # Token bytecode type (cached for encoding token operations) token_type::Union{TypeId, Nothing} - # Current token for control flow threading (loops, branches). - # Set by MakeTokenNode emission, updated by control flow emitters. - token::Union{Value, Nothing} - # Result tokens from memory ops: mem_op SSA index → bytecode Value # Populated during codegen when emitting memory ops with token args. # Read by TokenResultNode emission. @@ -292,7 +289,6 @@ function CGCtx(; cb::CodeBuilder, tt::TypeTable, sci::StructuredIRCode, tt, sci, token_type, - nothing, # token Dict{Int, Value}(), # result_tokens 0, # current_ssa_idx type_cache, @@ -434,6 +430,7 @@ Get or create a Tile IR type for a Julia type. With `throw_error=false`, returns `nothing` instead of throwing if the type has no Tile IR representation. 
""" function tile_type_for_julia!(ctx::CGCtx, @nospecialize(T); throw_error::Bool=true) + is_token_type(T) && return Token(ctx.tt) actual_type = CC.widenconst(T) cached = get(ctx.type_cache, actual_type, nothing) cached !== nothing && return cached @@ -591,6 +588,7 @@ Extract shape from a Tile{T, Shape} type in Julia's column-major convention. Returns ScalarShape for non-Tile types. """ function extract_tile_shape(@nospecialize(T)) + is_token_type(T) && return ScalarShape() T = CC.widenconst(T) if T <: Tile return ColMajorShape(size(T)) diff --git a/src/compiler/intrinsics/atomics.jl b/src/compiler/intrinsics/atomics.jl index 41c03403..ec76bca9 100644 --- a/src/compiler/intrinsics/atomics.jl +++ b/src/compiler/intrinsics/atomics.jl @@ -73,9 +73,8 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_cas), args) memory_ordering=mem_ordering, memory_scope=mem_scope) end - # Store result token for TokenResultNode and update ctx.token for control flow + # Store result token for TokenResultNode ctx.result_tokens[ctx.current_ssa_idx] = new_token - ctx.token = new_token julia_shape = ColMajorShape(shape) CGVal(old_val, result_tile_type, Tile{elem_type, TupleType(julia_shape)}, shape) @@ -136,9 +135,8 @@ function emit_atomic_rmw!(ctx::CGCtx, args::AbstractVector, mode::AtomicRMWMode. 
memory_ordering=mem_ordering, memory_scope=mem_scope) end - # Store result token for TokenResultNode and update ctx.token for control flow + # Store result token for TokenResultNode ctx.result_tokens[ctx.current_ssa_idx] = new_token - ctx.token = new_token julia_shape = ColMajorShape(shape) CGVal(old_val, result_tile_type, Tile{elem_type, TupleType(julia_shape)}, shape) diff --git a/src/compiler/intrinsics/memory.jl b/src/compiler/intrinsics/memory.jl index 81f88ec5..47edb07c 100644 --- a/src/compiler/intrinsics/memory.jl +++ b/src/compiler/intrinsics/memory.jl @@ -52,9 +52,8 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.load_ptr_tko), args) ) end - # Store result token for TokenResultNode and update ctx.token for control flow + # Store result token for TokenResultNode ctx.result_tokens[ctx.current_ssa_idx] = new_token - ctx.token = new_token julia_shape = ColMajorShape(tile_shape) result_jltype = Tile{elem_type, TupleType(julia_shape)} @@ -102,9 +101,8 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.store_ptr_tko), args) ) end - # Store result token for TokenResultNode and update ctx.token for control flow + # Store result token for TokenResultNode ctx.result_tokens[ctx.current_ssa_idx] = new_token - ctx.token = new_token nothing end @@ -112,21 +110,16 @@ end """ extract_token_arg!(ctx, args) -> Value -Check if the last argument is a token SSAValue (added by token_order_pass!). -If so, pop it from args and return the resolved bytecode Value. -If no token arg is present (memory op inside a control flow region not yet -transformed by the pass), fall back to `ctx.token`. +Extract the token argument (last arg, added by token_order_pass!) from a memory op call. +Pops the token from args and returns the resolved bytecode Value. 
""" function extract_token_arg!(ctx::CGCtx, args) - if !isempty(args) - last_arg = args[end] - tv = emit_value!(ctx, last_arg) - if tv !== nothing && tv.jltype === TokenType - pop!(args) - return tv.v - end + isempty(args) && throw(IRError("Memory op has no arguments")) + last_arg = args[end] + tv = emit_value!(ctx, last_arg) + if tv !== nothing && tv.jltype === TokenType + pop!(args) + return tv.v end - # Fallback: use ctx.token (for memory ops inside loops/branches) - ctx.token !== nothing && return ctx.token - throw(IRError("Memory op has no token argument and ctx.token is not set")) + throw(IRError("Memory op missing token argument (token_order_pass! not run?)")) end diff --git a/src/compiler/intrinsics/views.jl b/src/compiler/intrinsics/views.jl index 77e87a9f..e747c7e7 100644 --- a/src/compiler/intrinsics/views.jl +++ b/src/compiler/intrinsics/views.jl @@ -116,9 +116,8 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.load_partition_view), a token = input_token, optimization_hints ) - # Store result token for TokenResultNode and update ctx.token for control flow + # Store result token for TokenResultNode ctx.result_tokens[ctx.current_ssa_idx] = result_token - ctx.token = result_token julia_shape = ColMajorShape(tile_shape) return CGVal(tile_val, tile_type, Tile{elem_type, TupleType(julia_shape)}, tile_shape) @@ -431,9 +430,8 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.store_partition_view), token = input_token, optimization_hints ) - # Store result token for TokenResultNode and update ctx.token for control flow + # Store result token for TokenResultNode ctx.result_tokens[ctx.current_ssa_idx] = result_token - ctx.token = result_token return nothing end From 652bbb2fa303fccab817b15af03d7dc18bac8480 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Thu, 26 Mar 2026 11:00:21 +0100 Subject: [PATCH 04/10] Fix type mismatch and loop terminator token threading. 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three fixes: 1. Build loop result types from block args (authoritative source) instead of from old_type.typ which may be Nothing for void loops — fixes BoundsError in emit_loop_getfield! for nested spinloop patterns. 2. Thread parent_loop_effects through IfOp branches so that ContinueOp/BreakOp inside IfOp (common for LoopOp→IfOp while-loop patterns) get their token exit values appended. This was the cause of the "continue op operand mismatch" errors for for-loops with memory ops in the body. 3. Add ForOp body fallback ContinueOp (matching LoopOp) for completeness. All 202 codegen tests pass. GPU execution of spinlock (hang.jl) still has a token carry issue — break/continue inside the inner loop carry initial block args instead of the updated CAS result tokens. To be debugged next. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/compiler/codegen/control_flow.jl | 7 ++- src/compiler/codegen/passes/token_order.jl | 67 +++++++++------------- 2 files changed, 34 insertions(+), 40 deletions(-) diff --git a/src/compiler/codegen/control_flow.jl b/src/compiler/codegen/control_flow.jl index 9410f9e2..46ce68ae 100644 --- a/src/compiler/codegen/control_flow.jl +++ b/src/compiler/codegen/control_flow.jl @@ -106,6 +106,11 @@ function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type), ctx[arg] = CGVal(block_args[i + 1], result_types[i], arg.type, shape) end emit_block!(ctx, body_blk) + # If body has no terminator, emit a ContinueOp with all block args + if body_blk.terminator === nothing + # block_args[1] is iv, block_args[2:end] are carries + encode_ContinueOp!(ctx.cb, block_args[2:end]) + end empty!(ctx.block_args); merge!(ctx.block_args, saved) end results = encode_ForOp!(body_builder, cb, result_types, iv_type, @@ -233,7 +238,7 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ tv !== nothing && tv.v !== nothing && 
push!(continue_operands, tv.v) end end - # Ensure token carries are included even if YieldOp didn't resolve them + # Ensure all carries (including tokens from pass) are in the ContinueOp while length(continue_operands) < n_carries push!(continue_operands, block_args[length(continue_operands) + 1]) end diff --git a/src/compiler/codegen/passes/token_order.jl b/src/compiler/codegen/passes/token_order.jl index d153ffeb..aaac3172 100644 --- a/src/compiler/codegen/passes/token_order.jl +++ b/src/compiler/codegen/passes/token_order.jl @@ -247,7 +247,7 @@ function transform_block!(sci::StructuredIRCode, block::Block, entry === nothing && continue if entry.stmt isa ControlFlowOp transform_control_flow!(sci, block, ssa_idx, entry.stmt, entry.typ, - alias_result, token_map, effects_cache) + alias_result, token_map, effects_cache, loop_effects) else transform_statement!(sci, block, ssa_idx, entry.stmt, alias_result, token_map) @@ -333,14 +333,14 @@ end function transform_control_flow!(sci::StructuredIRCode, parent_block::Block, ssa_idx::Int, op::ForOp, @nospecialize(result_type), - alias_result, token_map, effects_cache) + alias_result, token_map, effects_cache, parent_loop_effects=nothing) transform_loop!(sci, parent_block, ssa_idx, op, op.body, alias_result, token_map, effects_cache) end function transform_control_flow!(sci::StructuredIRCode, parent_block::Block, ssa_idx::Int, op::LoopOp, @nospecialize(result_type), - alias_result, token_map, effects_cache) + alias_result, token_map, effects_cache, parent_loop_effects=nothing) transform_loop!(sci, parent_block, ssa_idx, op, op.body, alias_result, token_map, effects_cache) end @@ -398,23 +398,15 @@ function transform_loop!(sci::StructuredIRCode, parent_block::Block, body_effects, nothing) # After the loop: insert getfield extractions for token results. - # These reference the loop's SSA result (which will be a tuple of all carries). - # Update the loop's type in the parent SSAMap to include token types. 
+ # Update the loop's type to include ALL carries (user + token) so that + # codegen's getfield extraction works correctly. if n_total_carries > n_user_carries - # Build extended type: Tuple{user_types..., TokenType...} - old_type = get(parent_block.body, ssa_idx, nothing) - if old_type !== nothing - user_types = if old_type.typ === Nothing - Type[] - elseif old_type.typ <: Tuple - collect(Type, old_type.typ.parameters) - else - Type[old_type.typ] - end - token_types = fill(TokenType, n_total_carries - n_user_carries) - new_type = Tuple{user_types..., token_types...} - update_type!(parent_block.body, ssa_idx, new_type) - end + # Build result type from block args (authoritative source of all carries) + all_types = Type[is_token_type(arg.type) ? TokenType : arg.type + for arg in body.args] + new_type = isempty(all_types) ? Nothing : Tuple{all_types...} + update_type!(parent_block.body, ssa_idx, new_type) + # Insert getfield SSAs after the loop for each token result last_inserted = ssa_idx @@ -457,7 +449,7 @@ end function transform_control_flow!(sci::StructuredIRCode, parent_block::Block, ssa_idx::Int, op::WhileOp, @nospecialize(result_type), - alias_result, token_map, effects_cache) + alias_result, token_map, effects_cache, parent_loop_effects=nothing) before_effects = get(effects_cache, objectid(op.before), EMPTY_MEMORY_EFFECTS) after_effects = get(effects_cache, objectid(op.after), EMPTY_MEMORY_EFFECTS) loop_effects = union(before_effects, after_effects) @@ -516,21 +508,14 @@ function transform_control_flow!(sci::StructuredIRCode, parent_block::Block, transform_block!(sci, op.after, alias_result, after_token_map, effects_cache, loop_effects, nothing) - # Insert getfield extractions for token results (same as transform_loop!) 
+ # Insert getfield extractions for token results if n_total_carries > n_user_carries - old_type = get(parent_block.body, ssa_idx, nothing) - if old_type !== nothing - user_types = if old_type.typ === Nothing - Type[] - elseif old_type.typ <: Tuple - collect(Type, old_type.typ.parameters) - else - Type[old_type.typ] - end - token_types = fill(TokenType, n_total_carries - n_user_carries) - new_type = Tuple{user_types..., token_types...} - update_type!(parent_block.body, ssa_idx, new_type) - end + # Build result type from before block args (authoritative source of all carries) + all_types = Type[is_token_type(arg.type) ? TokenType : arg.type + for arg in op.before.args] + new_type = isempty(all_types) ? Nothing : Tuple{all_types...} + update_type!(parent_block.body, ssa_idx, new_type) + last_inserted = ssa_idx carry_idx = n_user_carries @@ -570,18 +555,19 @@ end function transform_control_flow!(sci::StructuredIRCode, parent_block::Block, ssa_idx::Int, op::IfOp, @nospecialize(result_type), - alias_result, token_map, effects_cache) + alias_result, token_map, effects_cache, parent_loop_effects=nothing) then_effects = get(effects_cache, objectid(op.then_region), EMPTY_MEMORY_EFFECTS) else_effects = get(effects_cache, objectid(op.else_region), EMPTY_MEMORY_EFFECTS) merged_effects = union(then_effects, else_effects) - # Transform both branches (they extend their YieldOps with exit tokens) + # Transform both branches. Pass parent_loop_effects so that ContinueOp/BreakOp + # inside branches (common for LoopOp→IfOp patterns) get token exit values. 
then_map = copy(token_map) transform_block!(sci, op.then_region, alias_result, then_map, effects_cache, - nothing, merged_effects) + parent_loop_effects, merged_effects) else_map = copy(token_map) transform_block!(sci, op.else_region, alias_result, else_map, effects_cache, - nothing, merged_effects) + parent_loop_effects, merged_effects) # Count token results and insert getfield extractions n_token_results = 0 @@ -594,6 +580,7 @@ function transform_control_flow!(sci::StructuredIRCode, parent_block::Block, if n_token_results > 0 # Update IfOp type to include token results old_type = get(parent_block.body, ssa_idx, nothing) + old_type = get(parent_block.body, ssa_idx, nothing) if old_type !== nothing user_types = if old_type.typ === Nothing Type[] @@ -605,6 +592,8 @@ function transform_control_flow!(sci::StructuredIRCode, parent_block::Block, token_types = fill(TokenType, n_token_results) new_type = Tuple{user_types..., token_types...} update_type!(parent_block.body, ssa_idx, new_type) + else + user_types = Type[] end # Insert getfield extractions for token results @@ -642,5 +631,5 @@ end # Fallback function transform_control_flow!(sci::StructuredIRCode, parent_block::Block, ssa_idx::Int, op::ControlFlowOp, @nospecialize(result_type), - alias_result, token_map, effects_cache) + alias_result, token_map, effects_cache, parent_loop_effects=nothing) end From 1dc3e0334814ee30bfa583824faf2b9d34046758 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Thu, 26 Mar 2026 11:15:07 +0100 Subject: [PATCH 05/10] =?UTF-8?q?Fix=20spinlock=20token=20carry:=20propaga?= =?UTF-8?q?te=20before=E2=86=92after=20and=20extend=20ConditionOp.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two fixes for WhileOp token threading: 1. After transforming the before block (which contains the CAS), propagate the updated token_map to after_token_map. 
Previously, after_token_map used the initial BlockArgs, so the YieldOp (which becomes ContinueOp in codegen) carried stale tokens instead of the CAS result. 2. Extend ConditionOp.args with exit tokens (like ContinueOp/BreakOp). The codegen-generated BreakOp reads from cond_op.args, so token values must be present there for the break path to carry the right tokens. Result: both break and continue in the spinlock loop now carry %result_token_12 (CAS acquire token) instead of the initial block args. Status: 202/202 codegen, hang.jl passes, 1581/1585 full suite. Remaining: 3 early-return device tests + layernorm example. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/compiler/codegen/passes/token_order.jl | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/compiler/codegen/passes/token_order.jl b/src/compiler/codegen/passes/token_order.jl index aaac3172..51649aac 100644 --- a/src/compiler/codegen/passes/token_order.jl +++ b/src/compiler/codegen/passes/token_order.jl @@ -310,12 +310,19 @@ function transform_terminator!(block::Block, token_map::Dict{TokenKey, IRToken}, ifelse_effects::Union{MemoryEffects, Nothing}) term = block.terminator term === nothing && return + + # ConditionOp (WhileOp before-block): extend args with exit tokens so that + # the codegen-generated BreakOp carries them. 
+ if term isa ConditionOp && loop_effects !== nothing + append!(term.args, get_cf_exit_tokens(loop_effects, token_map)) + return + end + effects = if (term isa ContinueOp || term isa BreakOp) && loop_effects !== nothing loop_effects elseif term isa YieldOp && ifelse_effects !== nothing ifelse_effects elseif term isa YieldOp && loop_effects !== nothing - # WhileOp after-block: YieldOp values become ContinueOp in codegen loop_effects else nothing @@ -502,9 +509,17 @@ function transform_control_flow!(sci::StructuredIRCode, parent_block::Block, after_token_map[ACQUIRE_TOKEN_KEY] = op.after.args[after_arg_idx] end - # Transform before and after regions + # Transform before region (may update body_token_map, e.g., CAS in condition) transform_block!(sci, op.before, alias_result, body_token_map, effects_cache, loop_effects, nothing) + + # Propagate before's final token state to after_token_map. + # The after block receives values from before's ConditionOp, so it should + # see the token state AFTER the before block's transformations (e.g., CAS result). + for (key, val) in body_token_map + after_token_map[key] = val + end + transform_block!(sci, op.after, alias_result, after_token_map, effects_cache, loop_effects, nothing) From cd73ead4eabf388e551bb03043fbf117b56efdac Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Thu, 26 Mar 2026 11:23:14 +0100 Subject: [PATCH 06/10] Run hoist_returns! before token_order_pass!. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit hoist_returns! replaces ReturnNode terminators in IfOp branches with empty YieldOp(). If the token pass runs first and extends the IfOp with token yields, hoist_returns! wipes them out — causing "then branch does not yield anything" errors. Fix: run hoist_returns! first so the token pass sees normalized YieldOps. Full suite: 1586/1587 pass (layernorm dW mismatch is pre-existing). 
Co-Authored-By: Claude Opus 4.6 (1M context) --- src/compiler/codegen/kernel.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/compiler/codegen/kernel.jl b/src/compiler/codegen/kernel.jl index 81f85483..9e49dede 100644 --- a/src/compiler/codegen/kernel.jl +++ b/src/compiler/codegen/kernel.jl @@ -141,18 +141,17 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8}, create_tensor_views!(ctx, arg_idx, argtype, Int[]) end + # Hoist early returns BEFORE token ordering — hoist_returns! rewrites + # ReturnNode terminators to YieldOp, which the token pass then extends. + hoist_returns!(ctx.sci.entry) + # Run alias analysis and token ordering pass on the structured IR. - # This inserts MakeTokenNode, JoinTokensNode, TokenResultNode into the IR - # and threads tokens through control flow (loop carries, branch yields). alias_result = alias_analysis_pass!(sci) token_order_pass!(sci, alias_result) # Cache the token bytecode type for codegen ctx.token_type = Token(tt) - # Hoist early returns out of IfOp regions (tileiras rejects ReturnOp inside IfOp) - hoist_returns!(ctx.sci.entry) - # Emit the structured IR (uses original Julia SSA indices everywhere) emit_block!(ctx, ctx.sci.entry) From bb5e142c87f57e687c53b44fc1954bbb4e708c0d Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Thu, 26 Mar 2026 11:44:50 +0100 Subject: [PATCH 07/10] Implement release memory ordering for atomic token joins. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The has_release_order() stub returned false, meaning release-ordered atomics didn't join with all LAST_OP tokens. Per the Tile IR memory model: "When you use a release operation, you need to token-order all memory events that must stay before the release to the release itself." 
Without this, the release atomic_xchg in spinlock patterns didn't depend on the data store's token — the store's writes weren't guaranteed visible before the lock release, causing data corruption in the layernorm backward kernel. Fix: extract memory_order from atomic call args in the IR and pass it through to collect_join_tokens_ir, which already had the release join logic (line 150-152) but was never triggered. All 1586 tests pass including layernorm backward. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/compiler/codegen/passes/token_order.jl | 32 ++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/src/compiler/codegen/passes/token_order.jl b/src/compiler/codegen/passes/token_order.jl index 51649aac..0e5acfe8 100644 --- a/src/compiler/codegen/passes/token_order.jl +++ b/src/compiler/codegen/passes/token_order.jl @@ -175,7 +175,32 @@ function get_input_token_ir!(sci::StructuredIRCode, block::Block, before_ssa::In return SSAValue(join_ssa) end -has_release_order(memory_order) = false +function has_release_order(memory_order) + memory_order === nothing && return false + # MemoryOrder enum: Release=3, AcqRel=4 + return memory_order === MemoryOrder.Release || memory_order === MemoryOrder.AcqRel +end + +""" + extract_memory_order(resolved_func, operands) -> Union{MemoryOrder.T, Nothing} + +Extract the compile-time memory_order from an atomic intrinsic's operands. +""" +function extract_memory_order(resolved_func, operands) + is_atomic_intrinsic(resolved_func) || return nothing + # CAS: (ptr, expected, desired, mask, memory_order, memory_scope) + # RMW: (ptr, val, mask, memory_order, memory_scope) + mo_idx = resolved_func === Intrinsics.atomic_cas ? 
5 : 4 + mo_idx > length(operands) && return nothing + mo_arg = operands[mo_idx] + # The memory_order is typically a compile-time constant (QuoteNode or literal) + if mo_arg isa QuoteNode + return mo_arg.value + elseif mo_arg isa MemoryOrder.T + return mo_arg + end + return nothing +end #============================================================================= Control flow exit tokens (matching Python's _get_cf_exit_tokens) @@ -287,8 +312,11 @@ function transform_statement!(sci::StructuredIRCode, block::Block, ssa_idx::Int, token_map[lop_key] = SSAValue(join_ssa) elseif mem_effect == MEM_STORE + # For release-ordered atomics, join with ALL LAST_OP tokens (memory fence) + memory_order = extract_memory_order(resolved_func, operands) input_token = get_input_token_ir!(sci, block, ssa_idx, - last_op_key(alias_set), token_map) + last_op_key(alias_set), token_map, + memory_order) push!(stmt.args, input_token) result_ssa = new_ssa_idx!(sci) From 96c06e0ccc36a98ada607cfe8171a618565c936f Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Thu, 26 Mar 2026 11:57:19 +0100 Subject: [PATCH 08/10] Simplify: extract shared helpers, remove duplication. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Based on code review findings: - Move resolve_call() to irutils.jl (shared by alias_analysis + token_order) - Use resolve_call() in alias_analysis instead of inline call/invoke normalization - Extract insert_token_result_getfields!() helper — replaces 3 copy-pasted blocks (~25 lines each) in transform_loop!, WhileOp, and IfOp transforms - Remove dead after_arg assignments in WhileOp (side-effect-only calls) - Remove duplicate old_type = get(...) line in IfOp transform - Remove const IRToken = Any alias (no type safety value) Net: -48 lines, 3 copy-paste blocks → 1 shared function. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- src/compiler/codegen/irutils.jl | 26 +++ src/compiler/codegen/passes/alias_analysis.jl | 27 +-- src/compiler/codegen/passes/token_order.jl | 189 ++++++------------ 3 files changed, 97 insertions(+), 145 deletions(-) diff --git a/src/compiler/codegen/irutils.jl b/src/compiler/codegen/irutils.jl index e90e968e..4ba9e007 100644 --- a/src/compiler/codegen/irutils.jl +++ b/src/compiler/codegen/irutils.jl @@ -76,3 +76,29 @@ function update_type!(m::SSAMap, ssa_idx::Int, @nospecialize(new_type)) m.types[pos] = new_type return nothing end + +""" + resolve_call(stmt) -> (resolved_func, operands) or nothing + +Extract the resolved function and operands from a `:call` or `:invoke` Expr. +Shared by alias analysis and token ordering passes. +""" +function resolve_call(stmt) + stmt isa Expr || return nothing + if stmt.head === :call + func_ref = stmt.args[1] + operands = @view stmt.args[2:end] + elseif stmt.head === :invoke + func_ref = stmt.args[2] + operands = @view stmt.args[3:end] + else + return nothing + end + resolved = if func_ref isa GlobalRef + try; getfield(func_ref.mod, func_ref.name); catch; nothing; end + else + func_ref + end + resolved === nothing && return nothing + return (resolved, operands) +end diff --git a/src/compiler/codegen/passes/alias_analysis.jl b/src/compiler/codegen/passes/alias_analysis.jl index 87dcfd7e..ef2333e8 100644 --- a/src/compiler/codegen/passes/alias_analysis.jl +++ b/src/compiler/codegen/passes/alias_analysis.jl @@ -134,29 +134,12 @@ Analyze a single statement and propagate aliases. Handles both `:call` and `:invoke` expression forms. """ function analyze_statement!(tracker::AliasTracker, ssa::SSAValue, stmt) - if stmt isa Expr && (stmt.head === :call || stmt.head === :invoke) - # Normalize :call and :invoke into (func, operands) - # :call -> args = [func, operands...] - # :invoke -> args = [MethodInstance, func, operands...] 
- if stmt.head === :call - func = stmt.args[1] - operands = @view stmt.args[2:end] - else # :invoke - func = stmt.args[2] - operands = @view stmt.args[3:end] - end + call = resolve_call(stmt) + if call !== nothing + resolved_func, operands = call - # Resolve func to its runtime value for intrinsic matching. - # In :invoke, func may already be the function object (not a GlobalRef). - resolved_func = if func isa GlobalRef - try - getfield(func.mod, func.name) - catch - nothing - end - else - func # Direct function value (common in :invoke) - end + # Also need the raw func ref for GlobalRef comparisons + func = stmt.head === :call ? stmt.args[1] : stmt.args[2] # getfield: propagate from parent if func === GlobalRef(Core, :getfield) && length(operands) >= 1 diff --git a/src/compiler/codegen/passes/token_order.jl b/src/compiler/codegen/passes/token_order.jl index 0e5acfe8..47c391f5 100644 --- a/src/compiler/codegen/passes/token_order.jl +++ b/src/compiler/codegen/passes/token_order.jl @@ -41,26 +41,6 @@ const EMPTY_MEMORY_EFFECTS = MemoryEffects() Resolve and classify IR expressions =============================================================================# -function resolve_call(stmt) - stmt isa Expr || return nothing - if stmt.head === :call - func_ref = stmt.args[1] - operands = @view stmt.args[2:end] - elseif stmt.head === :invoke - func_ref = stmt.args[2] - operands = @view stmt.args[3:end] - else - return nothing - end - resolved = if func_ref isa GlobalRef - try; getfield(func_ref.mod, func_ref.name); catch; nothing; end - else - func_ref - end - resolved === nothing && return nothing - return (resolved, operands) -end - function classify_memory_op(resolved_func) if resolved_func === Intrinsics.load_partition_view || resolved_func === Intrinsics.load_ptr_tko @@ -137,11 +117,9 @@ compute_cf_memory_effects!(::ControlFlowOp, _, _) = EMPTY_MEMORY_EFFECTS Token map (IR-level, SSAValue/BlockArg) 
=============================================================================# -const IRToken = Any - -function collect_join_tokens_ir(token_key::TokenKey, token_map::Dict{TokenKey, IRToken}, +function collect_join_tokens_ir(token_key::TokenKey, token_map::Dict{TokenKey, Any}, memory_order=nothing) - tokens_to_join = IRToken[token_map[token_key]] + tokens_to_join = Any[token_map[token_key]] for (other_key, other_tok) in token_map should_join = false if other_key isa AcquireTokenKey @@ -165,7 +143,7 @@ function collect_join_tokens_ir(token_key::TokenKey, token_map::Dict{TokenKey, I end function get_input_token_ir!(sci::StructuredIRCode, block::Block, before_ssa::Int, - token_key::TokenKey, token_map::Dict{TokenKey, IRToken}, + token_key::TokenKey, token_map::Dict{TokenKey, Any}, memory_order=nothing) haskey(token_map, token_key) || return token_map[ACQUIRE_TOKEN_KEY] tokens = collect_join_tokens_ir(token_key, token_map, memory_order) @@ -207,13 +185,13 @@ end =============================================================================# """ - get_cf_exit_tokens(effects, token_map) -> Vector{IRToken} + get_cf_exit_tokens(effects, token_map) -> Vector{Any} Collect current tokens for each alias set with memory effects. These are appended to ContinueOp/BreakOp/YieldOp when leaving a CF region. 
""" -function get_cf_exit_tokens(effects::MemoryEffects, token_map::Dict{TokenKey, IRToken}) - tokens = IRToken[] +function get_cf_exit_tokens(effects::MemoryEffects, token_map::Dict{TokenKey, Any}) + tokens = Any[] for (alias_set, effect) in effects.effects effect == MEM_NONE && continue if effect == MEM_LOAD @@ -243,7 +221,7 @@ function token_order_pass!(sci::StructuredIRCode, alias_result::Dict{Any, AliasS root_token = SSAValue(root_ssa) # Initialize: all alias sets start at root token - token_map = Dict{TokenKey, IRToken}() + token_map = Dict{TokenKey, Any}() for alias_set in Set(values(alias_result)) token_map[last_op_key(alias_set)] = root_token token_map[last_store_key(alias_set)] = root_token @@ -260,7 +238,7 @@ end function transform_block!(sci::StructuredIRCode, block::Block, alias_result::Dict{Any, AliasSet}, - token_map::Dict{TokenKey, IRToken}, + token_map::Dict{TokenKey, Any}, effects_cache::Dict{UInt64, MemoryEffects}, loop_effects::Union{MemoryEffects, Nothing}, ifelse_effects::Union{MemoryEffects, Nothing}) @@ -285,7 +263,7 @@ end function transform_statement!(sci::StructuredIRCode, block::Block, ssa_idx::Int, stmt, alias_result::Dict{Any, AliasSet}, - token_map::Dict{TokenKey, IRToken}) + token_map::Dict{TokenKey, Any}) call = resolve_call(stmt) call === nothing && return resolved_func, operands = call @@ -333,7 +311,7 @@ function transform_statement!(sci::StructuredIRCode, block::Block, ssa_idx::Int, end end -function transform_terminator!(block::Block, token_map::Dict{TokenKey, IRToken}, +function transform_terminator!(block::Block, token_map::Dict{TokenKey, Any}, loop_effects::Union{MemoryEffects, Nothing}, ifelse_effects::Union{MemoryEffects, Nothing}) term = block.terminator @@ -381,17 +359,61 @@ function transform_control_flow!(sci::StructuredIRCode, parent_block::Block, end """ - transform_loop!(sci, parent_block, ssa_idx, op, body, alias_result, token_map, effects_cache) + insert_token_result_getfields!(sci, parent_block, ssa_idx, n_user, 
effects, token_map) -Add per-alias-set token carries to a loop. For each alias set with memory effects -in the body, creates init_values, BlockArgs, and terminator exit tokens. -Then recurses into the body with the body-scoped token_map. -After the loop, inserts getfield extractions for the token results. +Insert getfield extractions after a loop/if for each per-alias token result. +Updates `token_map` with SSAValues pointing to the extracted tokens. +Also updates the SSAMap type to include TokenType parameters. +""" +function insert_token_result_getfields!(sci::StructuredIRCode, parent_block::Block, + ssa_idx::Int, block_args, n_user::Int, + effects::MemoryEffects, token_map::Dict{TokenKey, Any}) + n_total = length(block_args) + n_total > n_user || return + + # Update result type to include all carries + all_types = Type[is_token_type(arg.type) ? TokenType : arg.type for arg in block_args] + update_type!(parent_block.body, ssa_idx, isempty(all_types) ? Nothing : Tuple{all_types...}) + + last_inserted = ssa_idx + carry_idx = n_user + for (alias_set, effect) in effects.effects + effect == MEM_NONE && continue + if effect >= MEM_LOAD + carry_idx += 1 + gf_ssa = new_ssa_idx!(sci) + insert_after!(parent_block.body, last_inserted, gf_ssa, + Expr(:call, GlobalRef(Core, :getfield), SSAValue(ssa_idx), carry_idx), TOKEN_TYPE) + token_map[last_op_key(alias_set)] = SSAValue(gf_ssa) + last_inserted = gf_ssa + end + if effect == MEM_STORE + carry_idx += 1 + gf_ssa = new_ssa_idx!(sci) + insert_after!(parent_block.body, last_inserted, gf_ssa, + Expr(:call, GlobalRef(Core, :getfield), SSAValue(ssa_idx), carry_idx), TOKEN_TYPE) + token_map[last_store_key(alias_set)] = SSAValue(gf_ssa) + last_inserted = gf_ssa + end + end + if effects.has_acquire + carry_idx += 1 + gf_ssa = new_ssa_idx!(sci) + insert_after!(parent_block.body, last_inserted, gf_ssa, + Expr(:call, GlobalRef(Core, :getfield), SSAValue(ssa_idx), carry_idx), TOKEN_TYPE) + token_map[ACQUIRE_TOKEN_KEY] = SSAValue(gf_ssa) + 
end +end + +""" + transform_loop!(...) + +Add per-alias-set token carries to a loop. """ function transform_loop!(sci::StructuredIRCode, parent_block::Block, ssa_idx::Int, op::Union{ForOp, LoopOp}, body::Block, alias_result::Dict{Any, AliasSet}, - token_map::Dict{TokenKey, IRToken}, + token_map::Dict{TokenKey, Any}, effects_cache::Dict{UInt64, MemoryEffects}) body_effects = get(effects_cache, objectid(body), EMPTY_MEMORY_EFFECTS) @@ -432,49 +454,8 @@ function transform_loop!(sci::StructuredIRCode, parent_block::Block, transform_block!(sci, body, alias_result, body_token_map, effects_cache, body_effects, nothing) - # After the loop: insert getfield extractions for token results. - # Update the loop's type to include ALL carries (user + token) so that - # codegen's getfield extraction works correctly. - if n_total_carries > n_user_carries - # Build result type from block args (authoritative source of all carries) - all_types = Type[is_token_type(arg.type) ? TokenType : arg.type - for arg in body.args] - new_type = isempty(all_types) ? 
Nothing : Tuple{all_types...} - update_type!(parent_block.body, ssa_idx, new_type) - - - # Insert getfield SSAs after the loop for each token result - last_inserted = ssa_idx - carry_idx = n_user_carries - for (alias_set, effect) in body_effects.effects - effect == MEM_NONE && continue - if effect >= MEM_LOAD - carry_idx += 1 - gf_ssa = new_ssa_idx!(sci) - gf_expr = Expr(:call, GlobalRef(Core, :getfield), SSAValue(ssa_idx), carry_idx) - insert_after!(parent_block.body, last_inserted, gf_ssa, gf_expr, TOKEN_TYPE) - result_token_map[last_op_key(alias_set)] = SSAValue(gf_ssa) - last_inserted = gf_ssa - end - if effect == MEM_STORE - carry_idx += 1 - gf_ssa = new_ssa_idx!(sci) - gf_expr = Expr(:call, GlobalRef(Core, :getfield), SSAValue(ssa_idx), carry_idx) - insert_after!(parent_block.body, last_inserted, gf_ssa, gf_expr, TOKEN_TYPE) - result_token_map[last_store_key(alias_set)] = SSAValue(gf_ssa) - last_inserted = gf_ssa - end - end - if body_effects.has_acquire - carry_idx += 1 - gf_ssa = new_ssa_idx!(sci) - gf_expr = Expr(:call, GlobalRef(Core, :getfield), SSAValue(ssa_idx), carry_idx) - insert_after!(parent_block.body, last_inserted, gf_ssa, gf_expr, TOKEN_TYPE) - result_token_map[ACQUIRE_TOKEN_KEY] = SSAValue(gf_ssa) - end - end - - # Update parent's token_map with loop result tokens + insert_token_result_getfields!(sci, parent_block, ssa_idx, body.args, + n_user_carries, body_effects, result_token_map) merge!(token_map, result_token_map) end @@ -499,13 +480,13 @@ function transform_control_flow!(sci::StructuredIRCode, parent_block::Block, if effect >= MEM_LOAD push!(op.init_values, token_map[last_op_key(alias_set)]) before_arg = new_block_arg!(op.before, sci, TOKEN_TYPE) - after_arg = new_block_arg!(op.after, sci, TOKEN_TYPE) + new_block_arg!(op.after, sci, TOKEN_TYPE) body_token_map[last_op_key(alias_set)] = before_arg end if effect == MEM_STORE push!(op.init_values, token_map[last_store_key(alias_set)]) before_arg = new_block_arg!(op.before, sci, TOKEN_TYPE) - 
after_arg = new_block_arg!(op.after, sci, TOKEN_TYPE) + new_block_arg!(op.after, sci, TOKEN_TYPE) body_token_map[last_store_key(alias_set)] = before_arg end end @@ -551,45 +532,8 @@ function transform_control_flow!(sci::StructuredIRCode, parent_block::Block, transform_block!(sci, op.after, alias_result, after_token_map, effects_cache, loop_effects, nothing) - # Insert getfield extractions for token results - if n_total_carries > n_user_carries - # Build result type from before block args (authoritative source of all carries) - all_types = Type[is_token_type(arg.type) ? TokenType : arg.type - for arg in op.before.args] - new_type = isempty(all_types) ? Nothing : Tuple{all_types...} - update_type!(parent_block.body, ssa_idx, new_type) - - - last_inserted = ssa_idx - carry_idx = n_user_carries - for (alias_set, effect) in loop_effects.effects - effect == MEM_NONE && continue - if effect >= MEM_LOAD - carry_idx += 1 - gf_ssa = new_ssa_idx!(sci) - gf_expr = Expr(:call, GlobalRef(Core, :getfield), SSAValue(ssa_idx), carry_idx) - insert_after!(parent_block.body, last_inserted, gf_ssa, gf_expr, TOKEN_TYPE) - result_token_map[last_op_key(alias_set)] = SSAValue(gf_ssa) - last_inserted = gf_ssa - end - if effect == MEM_STORE - carry_idx += 1 - gf_ssa = new_ssa_idx!(sci) - gf_expr = Expr(:call, GlobalRef(Core, :getfield), SSAValue(ssa_idx), carry_idx) - insert_after!(parent_block.body, last_inserted, gf_ssa, gf_expr, TOKEN_TYPE) - result_token_map[last_store_key(alias_set)] = SSAValue(gf_ssa) - last_inserted = gf_ssa - end - end - if loop_effects.has_acquire - carry_idx += 1 - gf_ssa = new_ssa_idx!(sci) - gf_expr = Expr(:call, GlobalRef(Core, :getfield), SSAValue(ssa_idx), carry_idx) - insert_after!(parent_block.body, last_inserted, gf_ssa, gf_expr, TOKEN_TYPE) - result_token_map[ACQUIRE_TOKEN_KEY] = SSAValue(gf_ssa) - end - end - + insert_token_result_getfields!(sci, parent_block, ssa_idx, op.before.args, + n_user_carries, loop_effects, result_token_map) merge!(token_map, 
result_token_map) end @@ -623,7 +567,6 @@ function transform_control_flow!(sci::StructuredIRCode, parent_block::Block, if n_token_results > 0 # Update IfOp type to include token results old_type = get(parent_block.body, ssa_idx, nothing) - old_type = get(parent_block.body, ssa_idx, nothing) if old_type !== nothing user_types = if old_type.typ === Nothing Type[] From a8abe60661b0de62458eb13e921708cfd2cd941c Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Thu, 26 Mar 2026 12:18:32 +0100 Subject: [PATCH 09/10] Fix three token ordering memory model bugs. 1. ACQUIRE_TOKEN_KEY was updated for all atomics instead of only acquire/acq_rel-ordered ones, over-constraining relaxed atomics. 2. has_acquire effect was set unconditionally for all atomics in compute_block_memory_effects!, causing unnecessary token carries. 3. ALIAS_UNIVERSE was treated as overlapping with nothing instead of everything, potentially missing token dependencies for unknown aliases. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/compiler/codegen/passes/token_order.jl | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/src/compiler/codegen/passes/token_order.jl b/src/compiler/codegen/passes/token_order.jl index 47c391f5..78be5010 100644 --- a/src/compiler/codegen/passes/token_order.jl +++ b/src/compiler/codegen/passes/token_order.jl @@ -93,9 +93,12 @@ function compute_block_memory_effects!(block::Block, alias_result::Dict{Any, Ali mem_effect == MEM_NONE && continue alias_set = get_alias_set_for_operand(alias_result, first(operands)) effects.effects[alias_set] = max(get(effects.effects, alias_set, MEM_NONE), mem_effect) - # Track acquire ordering for atomics + # Track acquire ordering for acquire/acq_rel atomics only if is_atomic_intrinsic(resolved_func) - effects = MemoryEffects(effects.effects, true) + mo = extract_memory_order(resolved_func, operands) + if has_acquire_order(mo) + effects = MemoryEffects(effects.effects, true) + end end end end @@ -129,8 +132,8 
@@ function collect_join_tokens_ir(token_key::TokenKey, token_map::Dict{TokenKey, A should_join = other_key.role == LAST_OP end if other_key.role == token_key.role - alias_overlap = !(other_key.alias_set isa AliasUniverse) && - !(token_key.alias_set isa AliasUniverse) && + alias_overlap = (other_key.alias_set isa AliasUniverse) || + (token_key.alias_set isa AliasUniverse) || !isempty(intersect(other_key.alias_set, token_key.alias_set)) should_join = should_join || alias_overlap end @@ -155,10 +158,14 @@ end function has_release_order(memory_order) memory_order === nothing && return false - # MemoryOrder enum: Release=3, AcqRel=4 return memory_order === MemoryOrder.Release || memory_order === MemoryOrder.AcqRel end +function has_acquire_order(memory_order) + memory_order === nothing && return false + return memory_order === MemoryOrder.Acquire || memory_order === MemoryOrder.AcqRel +end + """ extract_memory_order(resolved_func, operands) -> Union{MemoryOrder.T, Nothing} @@ -304,8 +311,8 @@ function transform_statement!(sci::StructuredIRCode, block::Block, ssa_idx::Int, token_map[last_op_key(alias_set)] = result_token token_map[last_store_key(alias_set)] = result_token - # Atomics with acquire semantics update the ACQUIRE token - if is_atomic_intrinsic(resolved_func) + # Only acquire/acq_rel atomics update the ACQUIRE token + if is_atomic_intrinsic(resolved_func) && has_acquire_order(memory_order) token_map[ACQUIRE_TOKEN_KEY] = result_token end end From b50c3ca61a18e2e4e4273385bcd5d6dbe05c91c8 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Thu, 26 Mar 2026 13:14:33 +0100 Subject: [PATCH 10/10] Add comments. 
--- src/compiler/codegen/control_flow.jl | 93 +++++++++++++++++-- src/compiler/codegen/passes/alias_analysis.jl | 25 +++++ src/compiler/codegen/passes/token_keys.jl | 3 + src/compiler/codegen/passes/token_order.jl | 34 ++++++- 4 files changed, 144 insertions(+), 11 deletions(-) diff --git a/src/compiler/codegen/control_flow.jl b/src/compiler/codegen/control_flow.jl index 46ce68ae..853c510b 100644 --- a/src/compiler/codegen/control_flow.jl +++ b/src/compiler/codegen/control_flow.jl @@ -1,12 +1,30 @@ # Structured IR Emission +""" + result_count(T) -> Int + +Compute the number of results from a Block.types entry. +Block.types contains Julia types: +- For Statement: Julia type → 1 result +- For ControlFlowOp with 0 results: Nothing → 0 results +- For ControlFlowOp with 1 result: Julia type → 1 result +- For ControlFlowOp with N results: Tuple{T1, T2, ...} → N results +""" function result_count(@nospecialize(T)) T === Nothing && return 0 T <: Tuple && return length(T.parameters) return 1 end +""" + emit_block!(ctx, block::Block) + +Emit bytecode for a structured IR block. +All SSA values use original Julia SSA indices (no local renumbering). +Values are stored in ctx.values by their original index. +""" function emit_block!(ctx::CGCtx, block::Block; skip_terminator::Bool=false) + # SSAVector iteration yields (ssa_idx, entry) where entry has .stmt and .typ for (ssa_idx, entry) in block.body if entry.stmt isa ControlFlowOp n_results = result_count(entry.typ) @@ -20,6 +38,14 @@ function emit_block!(ctx::CGCtx, block::Block; skip_terminator::Bool=false) end end +""" + emit_control_flow_op!(ctx, op::ControlFlowOp, result_type, n_results, original_idx) + +Emit bytecode for a structured control flow operation. +Uses multiple dispatch on the concrete ControlFlowOp type. +Results are stored at indices assigned AFTER nested regions (DFS order). +original_idx is the original Julia SSA index for cross-block references. 
+""" emit_control_flow_op!(ctx::CGCtx, op::IfOp, @nospecialize(rt), n::Int, idx::Int) = emit_if_op!(ctx, op, rt, n, idx) emit_control_flow_op!(ctx::CGCtx, op::ForOp, @nospecialize(rt), n::Int, idx::Int) = emit_for_op!(ctx, op, rt, n, idx) emit_control_flow_op!(ctx::CGCtx, op::WhileOp, @nospecialize(rt), n::Int, idx::Int) = emit_while_op!(ctx, op, rt, n, idx) @@ -32,10 +58,12 @@ emit_control_flow_op!(ctx::CGCtx, op::LoopOp, @nospecialize(rt), n::Int, idx::In function emit_if_op!(ctx::CGCtx, op::IfOp, @nospecialize(parent_result_type), n_results::Int, ssa_idx::Int) cb = ctx.cb + # Get condition value cond_tv = emit_value!(ctx, op.condition) cond_tv === nothing && throw(IRError("Cannot resolve condition for IfOp")) - # Build result types uniformly from parent_result_type (pass set it correctly) + # Determine result types from parent_result_type (token_order_pass! already + # updated the type to include any token carries via update_type!) result_types = TypeId[] if parent_result_type !== Nothing if parent_result_type <: Tuple @@ -72,6 +100,7 @@ function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type), cb = ctx.cb body_blk = op.body + # Get bounds values lower_tv = emit_value!(ctx, op.lower) upper_tv = emit_value!(ctx, op.upper) step_tv = emit_value!(ctx, op.step) @@ -98,7 +127,9 @@ function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type), body_builder = function(block_args) saved = copy(ctx.block_args) - # Tile IR layout: [iv, carries...] + + # Tile IR block args layout: [iv, carries...] + # (carries include both user values and token carries added by token_order_pass!) 
ctx[iv_arg] = CGVal(block_args[1], iv_type, iv_jl_type) for i in 1:n_carries arg = body_blk.args[i] @@ -106,9 +137,8 @@ function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type), ctx[arg] = CGVal(block_args[i + 1], result_types[i], arg.type, shape) end emit_block!(ctx, body_blk) - # If body has no terminator, emit a ContinueOp with all block args + # If body has no terminator, emit a ContinueOp with all carried values if body_blk.terminator === nothing - # block_args[1] is iv, block_args[2:end] are carries encode_ContinueOp!(ctx.cb, block_args[2:end]) end empty!(ctx.block_args); merge!(ctx.block_args, saved) @@ -116,7 +146,6 @@ function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type), results = encode_ForOp!(body_builder, cb, result_types, iv_type, lower_tv.v, upper_tv.v, step_tv.v, init_values) - # Trust parent_result_type (pass set it via update_type!) ctx.values[ssa_idx] = CGVal(results, parent_result_type) end @@ -140,12 +169,18 @@ function emit_loop_op!(ctx::CGCtx, op::LoopOp, @nospecialize(parent_result_type) body_builder = function(block_args) saved = copy(ctx.block_args) + + # Tile IR block args layout: [carries...] + # (includes both user values and token carries added by token_order_pass!) for i in 1:n_carries arg = body_blk.args[i] shape = RowMajorShape(extract_tile_shape(arg.type)) ctx[arg] = CGVal(block_args[i], result_types[i], arg.type, shape) end emit_block!(ctx, body_blk) + # In Tile IR, if the loop body ends with an IfOp (even one with continue/break + # in all branches), the if is NOT a terminator. We need an explicit terminator + # after the if. Add an unreachable ContinueOp as fallback terminator. 
if body_blk.terminator === nothing encode_ContinueOp!(ctx.cb, copy(block_args)) end @@ -158,6 +193,11 @@ end #============================================================================= WhileOp — lowered to LoopOp pattern in codegen + + MLIR structure: before { stmts; condition(cond) args } do { stmts; yield vals } + Emitted as: loop { before_stmts; if(!cond) { break } else { yield }; after_stmts; continue } + This structure keeps the "after" statements at LoopOp body level, avoiding + nested region issues when "after" contains loops. =============================================================================# function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_type), n_results::Int, ssa_idx::Int) @@ -178,22 +218,29 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ body_builder = function(block_args) saved = copy(ctx.block_args) + # Tile IR block args layout: [carries...] + # (includes both user values and token carries added by token_order_pass!) 
for i in 1:n_carries arg = before_blk.args[i] shape = RowMajorShape(extract_tile_shape(arg.type)) ctx[arg] = CGVal(block_args[i], result_types[i], arg.type, shape) end + # Emit "before" region emit_block!(ctx, before_blk) + # Get condition from ConditionOp terminator cond_op = before_blk.terminator cond_op isa ConditionOp || throw(IRError("WhileOp before region must end with ConditionOp")) cond_tv = emit_value!(ctx, cond_op.condition) (cond_tv === nothing || cond_tv.v === nothing) && throw(IRError("Cannot resolve WhileOp condition")) + # Emit conditional break: if (cond) { yield } else { break } + # This keeps nested loops in "after" at LoopOp body level then_body = (_) -> encode_YieldOp!(ctx.cb, Value[]) else_body = function(_) + # Break with ConditionOp args (become loop results) break_operands = Value[] for arg in cond_op.args tv = emit_value!(ctx, arg) @@ -212,6 +259,8 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ end encode_IfOp!(then_body, else_body, cb, TypeId[], cond_tv.v) + # Map "after" region block args from ConditionOp.args (user carries) + # and block_args (token carries beyond ConditionOp.args) for i in 1:length(after_blk.args) arg = after_blk.args[i] if i <= length(cond_op.args) @@ -229,8 +278,10 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ end end + # Emit "after" region body (skip terminator — we emit ContinueOp instead) emit_block!(ctx, after_blk; skip_terminator=true) + # Emit ContinueOp with yield values from after region's YieldOp continue_operands = Value[] if after_blk.terminator isa YieldOp for val in after_blk.terminator.values @@ -252,9 +303,15 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ end #============================================================================= - Terminators — tokens already in op.values from token_order_pass! + Terminators + Token values are already in op.values (appended by token_order_pass!). 
=============================================================================# +""" + emit_terminator!(ctx, terminator) + +Emit bytecode for a block terminator. +""" emit_terminator!(ctx::CGCtx, node::ReturnNode) = emit_return!(ctx, node) function emit_terminator!(ctx::CGCtx, op::YieldOp) @@ -285,12 +342,29 @@ function emit_terminator!(ctx::CGCtx, op::BreakOp) end emit_terminator!(ctx::CGCtx, ::Nothing) = nothing +# ConditionOp is handled specially by emit_while_op!, not emitted as a terminator emit_terminator!(ctx::CGCtx, ::ConditionOp) = nothing #============================================================================= Early Return Hoisting + + tileiras rejects ReturnNode (cuda_tile.return) inside IfOp (cuda_tile.if) + regions. This pre-pass rewrites the structured IR so that ReturnNode only + appears at the top level, replacing nested returns with YieldOp. =============================================================================# +""" + hoist_returns!(block::Block) + +Rewrite `ReturnNode` terminators inside `IfOp` regions into `YieldOp`, +hoisting the return to the parent block. Operates recursively so that +nested early returns (multiple successive `if ... return end` patterns) +are handled automatically. + +Only handles the case where BOTH branches of an IfOp terminate with +ReturnNode (REGION_TERMINATION with 3 children). The 2-child case +(early return inside a loop) is not handled. +""" function hoist_returns!(block::Block) for (_, entry) in block.body stmt = entry.stmt @@ -321,6 +395,13 @@ end Loop getfield extraction — uniform, no token special cases =============================================================================# +""" + emit_loop_getfield!(ctx, args) -> Union{CGVal, Nothing} + +Handle getfield on multi-value results (loops, ifs). Returns CGVal if handled, +nothing if this is not a multi-value extraction and normal handling should proceed. +This is a compile-time lookup — no Tile IR is emitted. 
+""" function emit_loop_getfield!(ctx::CGCtx, args::Vector{Any}) length(args) >= 2 || return nothing args[1] isa SSAValue || return nothing diff --git a/src/compiler/codegen/passes/alias_analysis.jl b/src/compiler/codegen/passes/alias_analysis.jl index ef2333e8..4df3b2fb 100644 --- a/src/compiler/codegen/passes/alias_analysis.jl +++ b/src/compiler/codegen/passes/alias_analysis.jl @@ -1,3 +1,28 @@ +# Alias Analysis Pass +# +# Fixed-point alias analysis over StructuredIRCode. Determines which memory +# operations may access the same underlying data (i.e., which SSA values +# point into the same allocation). +# +# WHY: The token ordering pass needs alias information to decide which memory +# operations require token dependencies between them. Without alias analysis, +# all memory ops would be serialized through a single token chain — correct, +# but overly conservative. With per-alias-set information, independent memory +# regions (e.g., separate kernel arguments) get independent token chains, +# enabling more parallelism in the generated Tile IR. +# +# HOW: Each pointer-containing kernel argument starts in its own alias set. +# Alias sets propagate forward through: +# - getfield (for TileArray.ptr field access) +# - pointer arithmetic (+, -) +# - view constructors (make_tensor_view, make_partition_view) +# - pointer passthroughs (bitcast, assume_aligned, etc.) +# Unknown operations conservatively produce ALIAS_UNIVERSE (may alias anything). +# Fixed-point iteration handles back-edges from loops. +# +# OUTPUT: Dict{Any, AliasSet} mapping SSA values and Arguments to their alias +# sets, consumed by token_order_pass!. + """ AliasTracker diff --git a/src/compiler/codegen/passes/token_keys.jl b/src/compiler/codegen/passes/token_keys.jl index 07c448ec..231469ca 100644 --- a/src/compiler/codegen/passes/token_keys.jl +++ b/src/compiler/codegen/passes/token_keys.jl @@ -1,3 +1,6 @@ +# Token map key types for the token ordering pass. 
+# Each key identifies a token lane: per-alias-set (LAST_OP / LAST_STORE) or ACQUIRE. + # Token role enum @enum TokenRole LAST_OP LAST_STORE diff --git a/src/compiler/codegen/passes/token_order.jl b/src/compiler/codegen/passes/token_order.jl index 78be5010..2587f07b 100644 --- a/src/compiler/codegen/passes/token_order.jl +++ b/src/compiler/codegen/passes/token_order.jl @@ -1,10 +1,34 @@ -# Token ordering pass +# Token Ordering Pass # -# Transforms a StructuredIRCode by inserting token operations (MakeToken, JoinTokens, -# TokenResult) and threading per-alias-set tokens through control flow. -# After this pass, codegen emits what's in the IR — no manual token threading. +# Transforms a StructuredIRCode by inserting explicit token operations +# (MakeTokenNode, JoinTokensNode, TokenResultNode) and adding token carries +# to loop/branch control flow. After this pass, codegen simply emits what +# the IR says — no manual token threading in control_flow.jl or intrinsics. # -# Mirrors cuTile Python's `token_order_pass` (res/cutile-python/src/cuda/tile/_passes/token_order.py). +# WHY: Tile IR uses a token-based memory ordering model (similar to LLVM's +# token type). Every memory operation (load, store, atomic) consumes an input +# token and produces an output token. The chain of tokens defines the +# happens-before ordering between memory accesses. +# +# HOW: The pass maintains a `token_map: Dict{TokenKey, Any}` mapping each +# (alias_set, role) pair to its current token SSA value. Two roles exist per +# alias set: +# - LAST_OP: token from the most recent load or store (WAR/WAW tracking; input to stores) +# - LAST_STORE: token from the most recent store only (RAW tracking; input to loads) +# Plus a global ACQUIRE token for acquire-ordered atomics. +# +# For loads, the input token comes from LAST_STORE of the same alias set +# (read-after-write dependency). For stores, the input token joins all +# LAST_OP tokens of overlapping alias sets (write-after-read + write-after-write).
+# Release-ordered atomics additionally join ALL LAST_OP tokens across all alias +# sets (memory fence semantics). Acquire-ordered atomics update the global +# ACQUIRE token. +# +# The pass adds token carries to loops (init_values + block args + terminator +# operands) and token results to IfOp types, then inserts getfield extractions +# after control flow ops to update the parent scope's token_map. +# +# Mirrors cuTile Python's `token_order_pass`. using Core: SSAValue, Argument, SlotNumber