diff --git a/src/compiler/codegen.jl b/src/compiler/codegen.jl index 564aa8ea..50336afc 100644 --- a/src/compiler/codegen.jl +++ b/src/compiler/codegen.jl @@ -1,6 +1,10 @@ # Codegen: Julia IR -> Tile IR bytecode include("codegen/utils.jl") +include("codegen/irutils.jl") # SSAMap/Block mutation helpers +include("codegen/passes/token_keys.jl") # TokenKey, TokenRole, ACQUIRE_TOKEN_KEY +include("codegen/passes/alias_analysis.jl") # alias_analysis_pass! +include("codegen/passes/token_order.jl") # token_order_pass! include("codegen/kernel.jl") include("codegen/control_flow.jl") include("codegen/statements.jl") diff --git a/src/compiler/codegen/control_flow.jl b/src/compiler/codegen/control_flow.jl index b5a4db3d..853c510b 100644 --- a/src/compiler/codegen/control_flow.jl +++ b/src/compiler/codegen/control_flow.jl @@ -24,7 +24,6 @@ All SSA values use original Julia SSA indices (no local renumbering). Values are stored in ctx.values by their original index. """ function emit_block!(ctx::CGCtx, block::Block; skip_terminator::Bool=false) - # Emit body items (interleaved expressions and control flow ops) # SSAVector iteration yields (ssa_idx, entry) where entry has .stmt and .typ for (ssa_idx, entry) in block.body if entry.stmt isa ControlFlowOp @@ -34,8 +33,6 @@ function emit_block!(ctx::CGCtx, block::Block; skip_terminator::Bool=false) emit_statement!(ctx, entry.stmt, ssa_idx, entry.typ) end end - - # Emit terminator (unless skipped) if !skip_terminator && block.terminator !== nothing emit_terminator!(ctx, block.terminator) end @@ -49,78 +46,58 @@ Uses multiple dispatch on the concrete ControlFlowOp type. Results are stored at indices assigned AFTER nested regions (DFS order). original_idx is the original Julia SSA index for cross-block references. 
""" -emit_control_flow_op!(ctx::CGCtx, op::IfOp, @nospecialize(result_type), n_results::Int, original_idx::Int) = - emit_if_op!(ctx, op, result_type, n_results, original_idx) -emit_control_flow_op!(ctx::CGCtx, op::ForOp, @nospecialize(result_type), n_results::Int, original_idx::Int) = - emit_for_op!(ctx, op, result_type, n_results, original_idx) -emit_control_flow_op!(ctx::CGCtx, op::WhileOp, @nospecialize(result_type), n_results::Int, original_idx::Int) = - emit_while_op!(ctx, op, result_type, n_results, original_idx) -emit_control_flow_op!(ctx::CGCtx, op::LoopOp, @nospecialize(result_type), n_results::Int, original_idx::Int) = - emit_loop_op!(ctx, op, result_type, n_results, original_idx) +emit_control_flow_op!(ctx::CGCtx, op::IfOp, @nospecialize(rt), n::Int, idx::Int) = emit_if_op!(ctx, op, rt, n, idx) +emit_control_flow_op!(ctx::CGCtx, op::ForOp, @nospecialize(rt), n::Int, idx::Int) = emit_for_op!(ctx, op, rt, n, idx) +emit_control_flow_op!(ctx::CGCtx, op::WhileOp, @nospecialize(rt), n::Int, idx::Int) = emit_while_op!(ctx, op, rt, n, idx) +emit_control_flow_op!(ctx::CGCtx, op::LoopOp, @nospecialize(rt), n::Int, idx::Int) = emit_loop_op!(ctx, op, rt, n, idx) + +#============================================================================= + IfOp +=============================================================================# function emit_if_op!(ctx::CGCtx, op::IfOp, @nospecialize(parent_result_type), n_results::Int, ssa_idx::Int) cb = ctx.cb - then_blk = op.then_region - else_blk = op.else_region # Get condition value cond_tv = emit_value!(ctx, op.condition) cond_tv === nothing && throw(IRError("Cannot resolve condition for IfOp")) - # Determine result types from parent_result_type + # Determine result types from parent_result_type (token_order_pass! already + # updated the type to include any token carries via update_type!) 
result_types = TypeId[] - julia_result_types = Type[] - if parent_result_type === Nothing - # No results - elseif parent_result_type <: Tuple - for T in parent_result_type.parameters - push!(result_types, tile_type_for_julia!(ctx, T)) - push!(julia_result_types, T) + if parent_result_type !== Nothing + if parent_result_type <: Tuple + for T in parent_result_type.parameters + push!(result_types, tile_type_for_julia!(ctx, T)) + end + else + push!(result_types, tile_type_for_julia!(ctx, parent_result_type)) end - else - push!(result_types, tile_type_for_julia!(ctx, parent_result_type)) - push!(julia_result_types, parent_result_type) end - n_user_results = length(result_types) - # Add token type as additional result (for memory ordering) - push!(result_types, ctx.token_type) - - # Save token before branches - token_before = ctx.token - # Emit IfOp with callback-based region building then_body = function(_) - saved_block_args = copy(ctx.block_args) - ctx.token = token_before # Reset to pre-branch token - emit_block!(ctx, then_blk) - if then_blk.terminator === nothing - encode_YieldOp!(ctx.cb, [ctx.token]) - end - empty!(ctx.block_args) - merge!(ctx.block_args, saved_block_args) + saved = copy(ctx.block_args) + emit_block!(ctx, op.then_region) + op.then_region.terminator === nothing && encode_YieldOp!(ctx.cb, Value[]) + empty!(ctx.block_args); merge!(ctx.block_args, saved) end else_body = function(_) - saved_block_args = copy(ctx.block_args) - ctx.token = token_before # Reset to pre-branch token - emit_block!(ctx, else_blk) - if else_blk.terminator === nothing - encode_YieldOp!(ctx.cb, [ctx.token]) - end - empty!(ctx.block_args) - merge!(ctx.block_args, saved_block_args) + saved = copy(ctx.block_args) + emit_block!(ctx, op.else_region) + op.else_region.terminator === nothing && encode_YieldOp!(ctx.cb, Value[]) + empty!(ctx.block_args); merge!(ctx.block_args, saved) end results = encode_IfOp!(then_body, else_body, cb, result_types, cond_tv.v) - # Last result is the merged 
token from both branches - ctx.token = results[end] - - # Store results at IfOp's SSA index (may be empty for void-returning ifs) - ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type) + ctx.values[ssa_idx] = CGVal(results, parent_result_type) end +#============================================================================= + ForOp +=============================================================================# + function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type), n_results::Int, ssa_idx::Int) cb = ctx.cb - tt = ctx.tt body_blk = op.body # Get bounds values @@ -131,197 +108,124 @@ function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type), (lower_tv === nothing || upper_tv === nothing || step_tv === nothing) && throw(IRError("Cannot resolve ForOp bounds")) - - # Assert all bounds have the same type lower_tv.jltype === upper_tv.jltype === step_tv.jltype || - throw(IRError("ForOp bounds must all have the same type: lower=$(lower_tv.jltype), upper=$(upper_tv.jltype), step=$(step_tv.jltype)")) - iv_jl_type = lower_tv.jltype - iv_type = tile_type_for_julia!(ctx, iv_jl_type) + throw(IRError("ForOp bounds must all have the same type")) + iv_jl_type = lower_tv.jltype + iv_type = tile_type_for_julia!(ctx, iv_jl_type) - # Get init values (init_values are loop-carried values) + # Emit ALL init values (user + token carries from pass) init_values = Value[] for init_val in op.init_values tv = emit_value!(ctx, init_val) (tv === nothing || tv.v === nothing) && throw(IRError("Cannot resolve ForOp init value")) push!(init_values, tv.v) end - # Add token as additional init value (for memory ordering) - push!(init_values, ctx.token) - - # Number of carries (init_values) - these are the loop results - n_carries = length(op.init_values) - - # Determine result types from carries (body.args) - result_types = TypeId[] - for i in 1:n_carries - body_arg = body_blk.args[i] - type_id = tile_type_for_julia!(ctx, 
body_arg.type) - push!(result_types, type_id) - end - # Add token type as additional result (for memory ordering) - push!(result_types, ctx.token_type) - # Number of user result types (excluding token) - n_user_results = n_carries + # Build result types uniformly from block args + n_carries = length(body_blk.args) + result_types = TypeId[tile_type_for_julia!(ctx, arg.type) for arg in body_blk.args] - # Emit ForOp with callback-based region building body_builder = function(block_args) - saved_block_args = copy(ctx.block_args) - - # Tile IR block args layout: [iv, carries..., token] - # Julia IR body.args layout: [carries...] - - # Map the induction variable - iv_tv = CGVal(block_args[1], iv_type, iv_jl_type) - ctx[iv_arg] = iv_tv + saved = copy(ctx.block_args) - # Map carried values (body.args) + # Tile IR block args layout: [iv, carries...] + # (carries include both user values and token carries added by token_order_pass!) + ctx[iv_arg] = CGVal(block_args[1], iv_type, iv_jl_type) for i in 1:n_carries - body_arg = body_blk.args[i] - shape = RowMajorShape(extract_tile_shape(body_arg.type)) - tv = CGVal(block_args[i + 1], result_types[i], body_arg.type, shape) - ctx[body_arg] = tv + arg = body_blk.args[i] + shape = RowMajorShape(extract_tile_shape(arg.type)) + ctx[arg] = CGVal(block_args[i + 1], result_types[i], arg.type, shape) end - - # Set token from last block arg - ctx.token = block_args[end] - emit_block!(ctx, body_blk) - - empty!(ctx.block_args) - merge!(ctx.block_args, saved_block_args) + # If body has no terminator, emit a ContinueOp with all carried values + if body_blk.terminator === nothing + encode_ContinueOp!(ctx.cb, block_args[2:end]) + end + empty!(ctx.block_args); merge!(ctx.block_args, saved) end - results = encode_ForOp!(body_builder, cb, result_types, iv_type, lower_tv.v, upper_tv.v, step_tv.v, init_values) - - # Last result is the token - ctx.token = results[end] + results = encode_ForOp!(body_builder, cb, result_types, iv_type, + lower_tv.v, 
upper_tv.v, step_tv.v, init_values) - # Store results at the loop's SSA index (may be empty for void-returning loops) - ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type) + ctx.values[ssa_idx] = CGVal(results, parent_result_type) end +#============================================================================= + LoopOp +=============================================================================# + function emit_loop_op!(ctx::CGCtx, op::LoopOp, @nospecialize(parent_result_type), n_results::Int, ssa_idx::Int) cb = ctx.cb body_blk = op.body - # Get init values (init_values are loop-carried values) init_values = Value[] for init_val in op.init_values tv = emit_value!(ctx, init_val) (tv === nothing || tv.v === nothing) && throw(IRError("Cannot resolve LoopOp init value")) push!(init_values, tv.v) end - # Add token as additional init value (for memory ordering) - push!(init_values, ctx.token) - - # Number of carries (init_values) - these are the loop results - n_carries = length(op.init_values) - - # Determine result types from carries (body.args) - result_types = TypeId[] - for i in 1:n_carries - body_arg = body_blk.args[i] - type_id = tile_type_for_julia!(ctx, body_arg.type) - push!(result_types, type_id) - end - # Add token type as additional result (for memory ordering) - push!(result_types, ctx.token_type) - # Number of user result types (excluding token) - n_user_results = n_carries + n_carries = length(body_blk.args) + result_types = TypeId[tile_type_for_julia!(ctx, arg.type) for arg in body_blk.args] - # Emit LoopOp with callback-based region building body_builder = function(block_args) - saved_block_args = copy(ctx.block_args) - - # Tile IR block args layout: [carries..., token] - # Julia IR body.args layout: [carries...] + saved = copy(ctx.block_args) - # Map carried values (body.args) + # Tile IR block args layout: [carries...] + # (includes both user values and token carries added by token_order_pass!) 
for i in 1:n_carries - body_arg = body_blk.args[i] - shape = RowMajorShape(extract_tile_shape(body_arg.type)) - ctx[body_arg] = CGVal(block_args[i], result_types[i], body_arg.type, shape) + arg = body_blk.args[i] + shape = RowMajorShape(extract_tile_shape(arg.type)) + ctx[arg] = CGVal(block_args[i], result_types[i], arg.type, shape) end - - # Set token from last block arg - ctx.token = block_args[end] - emit_block!(ctx, body_blk) - # In Tile IR, if the loop body ends with an IfOp (even one with continue/break # in all branches), the if is NOT a terminator. We need an explicit terminator # after the if. Add an unreachable ContinueOp as fallback terminator. if body_blk.terminator === nothing - fallback_operands = copy(block_args) - fallback_operands[end] = ctx.token - encode_ContinueOp!(ctx.cb, fallback_operands) + encode_ContinueOp!(ctx.cb, copy(block_args)) end - - empty!(ctx.block_args) - merge!(ctx.block_args, saved_block_args) + empty!(ctx.block_args); merge!(ctx.block_args, saved) end results = encode_LoopOp!(body_builder, cb, result_types, init_values) - # Last result is the token - ctx.token = results[end] - - # Store results at the loop's SSA index (may be empty for void-returning loops) - ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type) + ctx.values[ssa_idx] = CGVal(results, parent_result_type) end +#============================================================================= + WhileOp — lowered to LoopOp pattern in codegen + + MLIR structure: before { stmts; condition(cond) args } do { stmts; yield vals } + Emitted as: loop { before_stmts; if(!cond) { break } else { yield }; after_stmts; continue } + This structure keeps the "after" statements at LoopOp body level, avoiding + nested region issues when "after" contains loops. 
+=============================================================================# + function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_type), n_results::Int, ssa_idx::Int) cb = ctx.cb before_blk = op.before after_blk = op.after - # Get init values (init_values are loop-carried values) init_values = Value[] for init_val in op.init_values tv = emit_value!(ctx, init_val) (tv === nothing || tv.v === nothing) && throw(IRError("Cannot resolve WhileOp init value: $init_val")) push!(init_values, tv.v) end - # Add token as additional init value (for memory ordering) - push!(init_values, ctx.token) - # Number of carries (init_values) - these are the loop results - n_carries = length(op.init_values) + n_carries = length(before_blk.args) + result_types = TypeId[tile_type_for_julia!(ctx, arg.type) for arg in before_blk.args] - # Determine result types from carries (before.args) - result_types = TypeId[] - for i in 1:n_carries - before_arg = before_blk.args[i] - type_id = tile_type_for_julia!(ctx, before_arg.type) - push!(result_types, type_id) - end - # Add token type as additional result (for memory ordering) - push!(result_types, ctx.token_type) - - # Number of user result types (excluding token) - n_user_results = n_carries - - # Emit WhileOp as cuda_tile.loop with conditional break pattern - # MLIR structure: before { stmts; condition(cond) args } do { stmts; yield vals } - # Emitted as: loop { before_stmts; if(!cond) { break } else { yield }; after_stmts; continue } - # This structure keeps the "after" statements at LoopOp body level, avoiding - # nested region issues when "after" contains loops. body_builder = function(block_args) - saved_block_args = copy(ctx.block_args) + saved = copy(ctx.block_args) - # Tile IR block args layout: [carries..., token] - # Julia IR before.args layout: [carries...] - - # Map carried values (before.args) + # Tile IR block args layout: [carries...] 
+ # (includes both user values and token carries added by token_order_pass!) for i in 1:n_carries - before_arg = before_blk.args[i] - shape = RowMajorShape(extract_tile_shape(before_arg.type)) - ctx[before_arg] = CGVal(block_args[i], result_types[i], before_arg.type, shape) + arg = before_blk.args[i] + shape = RowMajorShape(extract_tile_shape(arg.type)) + ctx[arg] = CGVal(block_args[i], result_types[i], arg.type, shape) end - # Set token from last block arg - ctx.token = block_args[end] - # Emit "before" region emit_block!(ctx, before_blk) @@ -330,15 +234,11 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ cond_op isa ConditionOp || throw(IRError("WhileOp before region must end with ConditionOp")) cond_tv = emit_value!(ctx, cond_op.condition) - (cond_tv === nothing || cond_tv.v === nothing) && throw(IRError("Cannot resolve WhileOp condition: $(cond_op.condition)")) + (cond_tv === nothing || cond_tv.v === nothing) && throw(IRError("Cannot resolve WhileOp condition")) # Emit conditional break: if (cond) { yield } else { break } # This keeps nested loops in "after" at LoopOp body level - then_body = function(_) - # Just yield (empty) - control continues to after_stmts - encode_YieldOp!(ctx.cb, Value[]) - end - + then_body = (_) -> encode_YieldOp!(ctx.cb, Value[]) else_body = function(_) # Break with ConditionOp args (become loop results) break_operands = Value[] @@ -347,34 +247,38 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ tv !== nothing && tv.v !== nothing && push!(break_operands, tv.v) end if isempty(break_operands) - for i in 1:n_carries + append!(break_operands, block_args[1:n_carries]) + else + # Append token carries (block_args beyond user carries from ConditionOp) + n_user = length(break_operands) + for i in (n_user + 1):n_carries push!(break_operands, block_args[i]) end end - push!(break_operands, ctx.token) encode_BreakOp!(ctx.cb, break_operands) end + encode_IfOp!(then_body, 
else_body, cb, TypeId[], cond_tv.v) - # Emit IfOp with NO results: if (cond) continue flow, else break - if_result_types = TypeId[] - encode_IfOp!(then_body, else_body, cb, if_result_types, cond_tv.v) - - # Now emit "after" region at LoopOp body level (not inside IfOp!) - # Map "after" region block args - carries from ConditionOp.args - for i in 1:n_carries - after_arg = after_blk.args[i] + # Map "after" region block args from ConditionOp.args (user carries) + # and block_args (token carries beyond ConditionOp.args) + for i in 1:length(after_blk.args) + arg = after_blk.args[i] if i <= length(cond_op.args) tv = emit_value!(ctx, cond_op.args[i]) if tv !== nothing - ctx[after_arg] = tv + ctx[arg] = tv else - shape = RowMajorShape(extract_tile_shape(after_arg.type)) - ctx[after_arg] = CGVal(block_args[i], result_types[i], after_arg.type, shape) + shape = RowMajorShape(extract_tile_shape(arg.type)) + ctx[arg] = CGVal(block_args[i], result_types[i], arg.type, shape) end + else + # Token carries beyond ConditionOp.args: use block_args directly + ctx[arg] = CGVal(block_args[i], result_types[i], arg.type, + RowMajorShape(extract_tile_shape(arg.type))) end end - # Emit "after" region body (skip terminator - we emit ContinueOp instead) + # Emit "after" region body (skip terminator — we emit ContinueOp instead) emit_block!(ctx, after_blk; skip_terminator=true) # Emit ContinueOp with yield values from after region's YieldOp @@ -385,73 +289,61 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ tv !== nothing && tv.v !== nothing && push!(continue_operands, tv.v) end end - push!(continue_operands, ctx.token) + # Ensure all carries (including tokens from pass) are in the ContinueOp + while length(continue_operands) < n_carries + push!(continue_operands, block_args[length(continue_operands) + 1]) + end encode_ContinueOp!(ctx.cb, continue_operands) - empty!(ctx.block_args) - merge!(ctx.block_args, saved_block_args) + empty!(ctx.block_args); 
merge!(ctx.block_args, saved) end results = encode_LoopOp!(body_builder, cb, result_types, init_values) - # Last result is the token - ctx.token = results[end] - - # Store results at the loop's SSA index (may be empty for void-returning loops) - ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type) + ctx.values[ssa_idx] = CGVal(results, parent_result_type) end +#============================================================================= + Terminators + Token values are already in op.values (appended by token_order_pass!). +=============================================================================# + """ emit_terminator!(ctx, terminator) Emit bytecode for a block terminator. """ -function emit_terminator!(ctx::CGCtx, node::ReturnNode) - emit_return!(ctx, node) -end +emit_terminator!(ctx::CGCtx, node::ReturnNode) = emit_return!(ctx, node) function emit_terminator!(ctx::CGCtx, op::YieldOp) - # Collect yield operands operands = Value[] for val in op.values tv = emit_value!(ctx, val) tv !== nothing && tv.v !== nothing && push!(operands, tv.v) end - # Append current token for memory ordering - push!(operands, ctx.token) encode_YieldOp!(ctx.cb, operands) end function emit_terminator!(ctx::CGCtx, op::ContinueOp) - # Collect continue operands (updated carried values) operands = Value[] for val in op.values tv = emit_value!(ctx, val) tv !== nothing && tv.v !== nothing && push!(operands, tv.v) end - # Append current token for memory ordering - push!(operands, ctx.token) encode_ContinueOp!(ctx.cb, operands) end function emit_terminator!(ctx::CGCtx, op::BreakOp) - # Collect break operands (final values) operands = Value[] for val in op.values tv = emit_value!(ctx, val) tv !== nothing && tv.v !== nothing && push!(operands, tv.v) end - # Append current token for memory ordering - push!(operands, ctx.token) encode_BreakOp!(ctx.cb, operands) end -function emit_terminator!(ctx::CGCtx, ::Nothing) - # No terminator, nothing to emit -end - -function 
emit_terminator!(ctx::CGCtx, ::ConditionOp) - # ConditionOp is handled specially by emit_while_op!, not emitted as a terminator -end +emit_terminator!(ctx::CGCtx, ::Nothing) = nothing +# ConditionOp is handled specially by emit_while_op!, not emitted as a terminator +emit_terminator!(ctx::CGCtx, ::ConditionOp) = nothing #============================================================================= Early Return Hoisting @@ -474,7 +366,6 @@ ReturnNode (REGION_TERMINATION with 3 children). The 2-child case (early return inside a loop) is not handled. """ function hoist_returns!(block::Block) - # First, recurse into all nested control flow ops for (_, entry) in block.body stmt = entry.stmt if stmt isa IfOp @@ -489,29 +380,27 @@ function hoist_returns!(block::Block) hoist_returns!(stmt.body) end end - - # Now check: does this block contain an IfOp where both branches return? - # If so, replace branch ReturnNodes with YieldOp and set block terminator. for (_, entry) in block.body entry.stmt isa IfOp || continue op = entry.stmt::IfOp op.then_region.terminator isa ReturnNode || continue op.else_region.terminator isa ReturnNode || continue - - # Both branches return — hoist to parent block. - # Replace branch terminators with YieldOp (void — no values to yield). op.then_region.terminator = YieldOp() op.else_region.terminator = YieldOp() block.terminator = ReturnNode(nothing) end end +#============================================================================= + Loop getfield extraction — uniform, no token special cases +=============================================================================# + """ - emit_getfield!(ctx, args) -> Union{CGVal, Nothing} + emit_loop_getfield!(ctx, args) -> Union{CGVal, Nothing} Handle getfield on multi-value results (loops, ifs). Returns CGVal if handled, nothing if this is not a multi-value extraction and normal handling should proceed. -This is a compile-time lookup - no Tile IR is emitted. 
+This is a compile-time lookup — no Tile IR is emitted. """ function emit_loop_getfield!(ctx::CGCtx, args::Vector{Any}) length(args) >= 2 || return nothing diff --git a/src/compiler/codegen/irutils.jl b/src/compiler/codegen/irutils.jl new file mode 100644 index 00000000..4ba9e007 --- /dev/null +++ b/src/compiler/codegen/irutils.jl @@ -0,0 +1,104 @@ +# StructuredIRCode / SSAMap mutation utilities +# +# Helpers for passes that modify the structured IR in place. +# Inspired by Julia's IncrementalCompact (Compiler/src/ssair/ir.jl). + +""" + new_ssa_idx!(sci::StructuredIRCode) -> Int + +Allocate a fresh SSA index from the StructuredIRCode. +""" +function new_ssa_idx!(sci::StructuredIRCode) + sci.max_ssa_idx += 1 + return sci.max_ssa_idx +end + +""" + new_block_arg!(block::Block, sci::StructuredIRCode, @nospecialize(typ)) -> BlockArg + +Add a new BlockArg to a block, allocating a fresh ID. +""" +function new_block_arg!(block::Block, sci::StructuredIRCode, @nospecialize(typ)) + id = new_ssa_idx!(sci) + arg = BlockArg(id, typ) + push!(block.args, arg) + return arg +end + +""" + Base.pushfirst!(m::SSAMap, (idx, stmt, typ)::Tuple{Int,Any,Any}) + +Prepend a statement at the beginning of an SSAMap. +""" +function Base.pushfirst!(m::SSAMap, (idx, stmt, typ)::Tuple{Int,Any,Any}) + pushfirst!(m.ssa_idxes, idx) + pushfirst!(m.stmts, stmt) + pushfirst!(m.types, typ) + return nothing +end + +""" + insert_before!(m::SSAMap, before_idx::Int, new_idx::Int, stmt, typ) + +Insert a new entry before the entry with SSA index `before_idx`. +""" +function insert_before!(m::SSAMap, before_idx::Int, new_idx::Int, stmt, typ) + pos = findfirst(==(before_idx), m.ssa_idxes) + pos === nothing && throw(KeyError(before_idx)) + insert!(m.ssa_idxes, pos, new_idx) + insert!(m.stmts, pos, stmt) + insert!(m.types, pos, typ) + return nothing +end + +""" + insert_after!(m::SSAMap, after_idx::Int, new_idx::Int, stmt, typ) + +Insert a new entry after the entry with SSA index `after_idx`. 
+""" +function insert_after!(m::SSAMap, after_idx::Int, new_idx::Int, stmt, typ) + pos = findfirst(==(after_idx), m.ssa_idxes) + pos === nothing && throw(KeyError(after_idx)) + insert!(m.ssa_idxes, pos + 1, new_idx) + insert!(m.stmts, pos + 1, stmt) + insert!(m.types, pos + 1, typ) + return nothing +end + +""" + update_type!(m::SSAMap, ssa_idx::Int, @nospecialize(new_type)) + +Update the type annotation for an existing SSAMap entry. +""" +function update_type!(m::SSAMap, ssa_idx::Int, @nospecialize(new_type)) + pos = findfirst(==(ssa_idx), m.ssa_idxes) + pos === nothing && throw(KeyError(ssa_idx)) + m.types[pos] = new_type + return nothing +end + +""" + resolve_call(stmt) -> (resolved_func, operands) or nothing + +Extract the resolved function and operands from a `:call` or `:invoke` Expr. +Shared by alias analysis and token ordering passes. +""" +function resolve_call(stmt) + stmt isa Expr || return nothing + if stmt.head === :call + func_ref = stmt.args[1] + operands = @view stmt.args[2:end] + elseif stmt.head === :invoke + func_ref = stmt.args[2] + operands = @view stmt.args[3:end] + else + return nothing + end + resolved = if func_ref isa GlobalRef + try; getfield(func_ref.mod, func_ref.name); catch; nothing; end + else + func_ref + end + resolved === nothing && return nothing + return (resolved, operands) +end diff --git a/src/compiler/codegen/kernel.jl b/src/compiler/codegen/kernel.jl index 79e9a61d..9e49dede 100644 --- a/src/compiler/codegen/kernel.jl +++ b/src/compiler/codegen/kernel.jl @@ -141,14 +141,17 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8}, create_tensor_views!(ctx, arg_idx, argtype, Int[]) end - # Create memory ordering token - token_type = Token(tt) - ctx.token_type = token_type - ctx.token = encode_MakeTokenOp!(cb, token_type) - - # Hoist early returns out of IfOp regions (tileiras rejects ReturnOp inside IfOp) + # Hoist early returns BEFORE token ordering — hoist_returns! 
rewrites + # ReturnNode terminators to YieldOp, which the token pass then extends. hoist_returns!(ctx.sci.entry) + # Run alias analysis and token ordering pass on the structured IR. + alias_result = alias_analysis_pass!(sci) + token_order_pass!(sci, alias_result) + + # Cache the token bytecode type for codegen + ctx.token_type = Token(tt) + # Emit the structured IR (uses original Julia SSA indices everywhere) emit_block!(ctx, ctx.sci.entry) @@ -314,7 +317,7 @@ function emit_subprogram!(ctx::CGCtx, func, arg_types::Vector, # 3. Create sub-context sub_ctx = CGCtx(; ctx.cb, ctx.tt, sci, - ctx.token, ctx.token_type, + ctx.token_type, ctx.type_cache, ctx.sm_arch, ctx.cache) diff --git a/src/compiler/codegen/passes/alias_analysis.jl b/src/compiler/codegen/passes/alias_analysis.jl new file mode 100644 index 00000000..4df3b2fb --- /dev/null +++ b/src/compiler/codegen/passes/alias_analysis.jl @@ -0,0 +1,236 @@ +# Alias Analysis Pass +# +# Fixed-point alias analysis over StructuredIRCode. Determines which memory +# operations may access the same underlying data (i.e., which SSA values +# point into the same allocation). +# +# WHY: The token ordering pass needs alias information to decide which memory +# operations require token dependencies between them. Without alias analysis, +# all memory ops would be serialized through a single token chain — correct, +# but overly conservative. With per-alias-set information, independent memory +# regions (e.g., separate kernel arguments) get independent token chains, +# enabling more parallelism in the generated Tile IR. +# +# HOW: Each pointer-containing kernel argument starts in its own alias set. +# Alias sets propagate forward through: +# - getfield (for TileArray.ptr field access) +# - pointer arithmetic (+, -) +# - view constructors (make_tensor_view, make_partition_view) +# - pointer passthroughs (bitcast, assume_aligned, etc.) +# Unknown operations conservatively produce ALIAS_UNIVERSE (may alias anything). 
+# Fixed-point iteration handles back-edges from loops. +# +# OUTPUT: Dict{Any, AliasSet} mapping SSA values and Arguments to their alias +# sets, consumed by token_order_pass!. + +""" + AliasTracker + +Tracks alias sets for each SSA value during fixed-point analysis. +""" +mutable struct AliasTracker + dirty::Bool + aliases::Dict{Any, AliasSet} # SSAValue/Argument/SlotNumber -> AliasSet +end + +AliasTracker() = AliasTracker(false, Dict{Any, AliasSet}()) + +function Base.getindex(tracker::AliasTracker, key) + return get(tracker.aliases, key, ALIAS_UNIVERSE) +end + +function Base.setindex!(tracker::AliasTracker, value::AliasSet, key) + current = get(tracker.aliases, key, nothing) + if current !== value + tracker.dirty = true + tracker.aliases[key] = value + end + return +end + +""" + alias_analysis_pass!(sci::StructuredIRCode) -> Dict{Any, AliasSet} + +Perform fixed-point alias analysis on structured IR. +Returns mapping from SSA values to alias sets. +""" +function alias_analysis_pass!(sci::StructuredIRCode) + tracker = AliasTracker() + + # Initialize: each argument gets its own alias set + for (idx, argtype) in enumerate(sci.argtypes) + argtype_unwrapped = CC.widenconst(argtype) + if contains_pointers(argtype_unwrapped) + arg_ref = Argument(idx) + tracker[arg_ref] = Set{Any}([arg_ref]) + end + end + + # Fixed-point iteration + iteration = 0 + max_iterations = 100 + + tracker.dirty = true + while tracker.dirty && iteration < max_iterations + tracker.dirty = false + iteration += 1 + + analyze_block!(tracker, sci.entry) + end + + @debug "Alias analysis converged in $iteration iterations" + + return tracker.aliases +end + +""" + propagate!(tracker::AliasTracker, from, to) + +Propagate alias set from `from` to `to`. +Uses direct assignment when `to` is uninitialized, union otherwise. 
+""" +function propagate!(tracker::AliasTracker, from, to) + from_aliases = tracker[from] + + if from_aliases === ALIAS_UNIVERSE + # Propagating UNIVERSE is always conservative + tracker[to] = ALIAS_UNIVERSE + return + end + + if haskey(tracker.aliases, to) + # Target already has an alias set union with it + to_aliases = tracker.aliases[to] + new_aliases = union(from_aliases, to_aliases) + if new_aliases != to_aliases + tracker[to] = new_aliases + end + else + # Target not yet analyzed assign directly + tracker[to] = from_aliases + end + return +end + +""" + analyze_block!(tracker::AliasTracker, block) + +Analyze all statements in a block, recursing into nested control flow. +""" +function analyze_block!(tracker::AliasTracker, block) + for (ssa_idx, entry) in block.body + if entry.stmt isa ControlFlowOp + analyze_control_flow!(tracker, entry.stmt) + else + analyze_statement!(tracker, SSAValue(ssa_idx), entry.stmt) + end + end + return +end + +# Recurse into nested control flow regions +function analyze_control_flow!(tracker::AliasTracker, op::IfOp) + analyze_block!(tracker, op.then_region) + return analyze_block!(tracker, op.else_region) +end + +function analyze_control_flow!(tracker::AliasTracker, op::ForOp) + return analyze_block!(tracker, op.body) +end + +function analyze_control_flow!(tracker::AliasTracker, op::WhileOp) + analyze_block!(tracker, op.before) + return analyze_block!(tracker, op.after) +end + +function analyze_control_flow!(tracker::AliasTracker, op::LoopOp) + return analyze_block!(tracker, op.body) +end + +# Fallback for unknown control flow ops +function analyze_control_flow!(::AliasTracker, ::ControlFlowOp) + return +end + +""" + analyze_statement!(tracker::AliasTracker, ssa::SSAValue, stmt) + +Analyze a single statement and propagate aliases. +Handles both `:call` and `:invoke` expression forms. 
+"""
+function analyze_statement!(tracker::AliasTracker, ssa::SSAValue, stmt)
+    call = resolve_call(stmt)
+    if call !== nothing
+        resolved_func, operands = call
+
+        # Also need the raw func ref for GlobalRef comparisons
+        # (assumes :call puts the callee at args[1] and :invoke at args[2])
+        func = stmt.head === :call ? stmt.args[1] : stmt.args[2]
+
+        # getfield: propagate from parent
+        if func === GlobalRef(Core, :getfield) && length(operands) >= 1
+            field = length(operands) >= 2 ? operands[2] : nothing
+
+            # For TileArray.ptr field access, propagate pointer alias
+            if field isa QuoteNode && field.value === :ptr
+                propagate!(tracker, operands[1], ssa)
+            else
+                # Conservatively mark as UNIVERSE for non-pointer fields
+                tracker[ssa] = ALIAS_UNIVERSE
+            end
+
+        # Pointer arithmetic: propagate from pointer operand
+        elseif func === GlobalRef(Base, :+) || func === GlobalRef(Base, :-)
+            for arg in operands
+                # Find the first operand with a concrete alias set and propagate.
+                # If none is found, ssa is left unset and defaults to
+                # ALIAS_UNIVERSE on lookup (conservative).
+                arg_aliases = tracker[arg]
+                if arg_aliases !== ALIAS_UNIVERSE && arg_aliases isa Set
+                    propagate!(tracker, arg, ssa)
+                    break
+                end
+            end
+
+        # View construction: propagate alias from first operand
+        elseif is_view_constructor(resolved_func) || is_pointer_passthrough(resolved_func)
+            if length(operands) >= 1
+                propagate!(tracker, operands[1], ssa)
+            end
+
+        # Default: unknown operation -> UNIVERSE
+        else
+            tracker[ssa] = ALIAS_UNIVERSE
+        end
+
+    elseif stmt isa ReturnNode
+        # No alias propagation needed
+
+    else
+        # Unknown statement type -> conservative
+        tracker[ssa] = ALIAS_UNIVERSE
+    end
+    return
+end
+
+# Helper functions
+# NOTE(review): assumes T is a Type (widenconst output upstream); `<:` would
+# throw on non-Type inputs — confirm callers never pass e.g. Union{} wrappers.
+contains_pointers(T) = T <: Ptr || T <: TileArray || (T <: Tile && eltype(T) <: Ptr)
+
+"""
+    is_view_constructor(func) -> Bool
+
+Check if a resolved function is a tensor/partition view constructor.
+These propagate alias identity from their first operand.
+"""
+function is_view_constructor(func)
+    return func === Intrinsics.make_tensor_view ||
+           func === Intrinsics.make_partition_view
+end
+
+function is_pointer_passthrough(func)
+    func === GlobalRef(Core.Intrinsics, :bitcast) && return true
+
+    # Safely check by name to avoid UndefVarError if intrinsics aren't exposed
+    if func isa Core.IntrinsicFunction || func isa Function
+        n = nameof(func)
+        return n === :bitcast || n === :assume_div_by || n === :assume_bounded || n === :assume_aligned
+    end
+    return false
+end
diff --git a/src/compiler/codegen/passes/token_keys.jl b/src/compiler/codegen/passes/token_keys.jl
new file mode 100644
index 00000000..231469ca
--- /dev/null
+++ b/src/compiler/codegen/passes/token_keys.jl
@@ -0,0 +1,41 @@
+# Token map key types for the token ordering pass.
+# Each key identifies a token lane: per-alias-set (LAST_OP / LAST_STORE) or ACQUIRE.
+
+# Token role enum: LAST_OP tracks the most recent load-or-store per alias set,
+# LAST_STORE tracks the most recent store only.
+@enum TokenRole LAST_OP LAST_STORE
+
+# Acquire token key (singleton)
+struct AcquireTokenKey end
+const ACQUIRE_TOKEN_KEY = AcquireTokenKey()
+
+# Alias token key (per alias set and role). NOTE(review): alias_set may be a
+# Set{Any} used as a Dict key — its contents must not be mutated after
+# insertion, or the custom hash below becomes stale. Confirm alias sets are
+# frozen once alias_analysis_pass! returns.
+struct AliasTokenKey
+    alias_set::AliasSet
+    role::TokenRole
+end
+
+# Union type for all token keys
+const TokenKey = Union{AliasTokenKey, AcquireTokenKey}
+
+# Helper constructors
+"""
+    last_op_key(alias_set::AliasSet) -> AliasTokenKey
+
+Create a TokenKey for the last operation (load or store) on an alias set.
+"""
+last_op_key(alias_set::AliasSet) = AliasTokenKey(alias_set, LAST_OP)
+
+"""
+    last_store_key(alias_set::AliasSet) -> AliasTokenKey
+
+Create a TokenKey for the last store operation on an alias set.
+""" +last_store_key(alias_set::AliasSet) = AliasTokenKey(alias_set, LAST_STORE) + +# Make TokenKey hashable for use in Dict +Base.hash(key::AliasTokenKey, h::UInt) = hash((key.alias_set, key.role), h) +Base.:(==)(a::AliasTokenKey, b::AliasTokenKey) = + a.alias_set == b.alias_set && a.role == b.role + +Base.hash(::AcquireTokenKey, h::UInt) = hash(:ACQUIRE_TOKEN_KEY, h) +Base.:(==)(::AcquireTokenKey, ::AcquireTokenKey) = true diff --git a/src/compiler/codegen/passes/token_order.jl b/src/compiler/codegen/passes/token_order.jl new file mode 100644 index 00000000..2587f07b --- /dev/null +++ b/src/compiler/codegen/passes/token_order.jl @@ -0,0 +1,652 @@ +# Token Ordering Pass +# +# Transforms a StructuredIRCode by inserting explicit token operations +# (MakeTokenNode, JoinTokensNode, TokenResultNode) and adding token carries +# to loop/branch control flow. After this pass, codegen simply emits what +# the IR says — no manual token threading in control_flow.jl or intrinsics. +# +# WHY: Tile IR uses a token-based memory ordering model (similar to LLVM's +# token type). Every memory operation (load, store, atomic) consumes an input +# token and produces an output token. The chain of tokens defines the +# happens-before ordering between memory accesses. +# +# HOW: The pass maintains a `token_map: Dict{TokenKey, Any}` mapping each +# (alias_set, role) pair to its current token SSA value. Two roles exist per +# alias set: +# - LAST_OP: token from the most recent load or store (RAW/WAR tracking) +# - LAST_STORE: token from the most recent store only (WAW tracking) +# Plus a global ACQUIRE token for acquire-ordered atomics. +# +# For loads, the input token comes from LAST_STORE of the same alias set +# (read-after-write dependency). For stores, the input token joins all +# LAST_OP tokens of overlapping alias sets (write-after-read + write-after-write). +# Release-ordered atomics additionally join ALL LAST_OP tokens across all alias +# sets (memory fence semantics). 
Acquire-ordered atomics update the global +# ACQUIRE token. +# +# The pass adds token carries to loops (init_values + block args + terminator +# operands) and token results to IfOp types, then inserts getfield extractions +# after control flow ops to update the parent scope's token_map. +# +# Mirrors cuTile Python's `token_order_pass`. + +using Core: SSAValue, Argument, SlotNumber + +#============================================================================= + Memory effect classification +=============================================================================# + +@enum MemoryEffect MEM_NONE MEM_LOAD MEM_STORE + +""" + MemoryEffects + +Per-block summary of which alias sets are read/written. +""" +struct MemoryEffects + effects::Dict{AliasSet, MemoryEffect} + has_acquire::Bool +end + +MemoryEffects() = MemoryEffects(Dict{AliasSet, MemoryEffect}(), false) + +function Base.union(a::MemoryEffects, b::MemoryEffects) + result = Dict{AliasSet, MemoryEffect}() + for (k, v) in a.effects; result[k] = v; end + for (k, v) in b.effects + result[k] = max(get(result, k, MEM_NONE), v) + end + return MemoryEffects(result, a.has_acquire | b.has_acquire) +end + +const EMPTY_MEMORY_EFFECTS = MemoryEffects() + +#============================================================================= + Resolve and classify IR expressions +=============================================================================# + +function classify_memory_op(resolved_func) + if resolved_func === Intrinsics.load_partition_view || + resolved_func === Intrinsics.load_ptr_tko + return MEM_LOAD + elseif resolved_func === Intrinsics.store_partition_view || + resolved_func === Intrinsics.store_ptr_tko + return MEM_STORE + elseif is_atomic_intrinsic(resolved_func) + return MEM_STORE + else + return MEM_NONE + end +end + +function is_atomic_intrinsic(func) + isdefined(Intrinsics, :atomic_cas) && func === Intrinsics.atomic_cas && return true + for op in (:atomic_xchg, :atomic_add, :atomic_max, :atomic_min, + 
:atomic_or, :atomic_and, :atomic_xor) + isdefined(Intrinsics, op) && func === getfield(Intrinsics, op) && return true + end + return false +end + +function get_alias_set_for_operand(alias_result::Dict{Any, AliasSet}, operand) + if operand isa SSAValue || operand isa Argument || operand isa SlotNumber + return get(alias_result, operand, ALIAS_UNIVERSE) + end + return ALIAS_UNIVERSE +end + +#============================================================================= + Compute per-block memory effects +=============================================================================# + +function compute_block_memory_effects!(block::Block, alias_result::Dict{Any, AliasSet}, + cache::Dict{UInt64, MemoryEffects}) + block_id = objectid(block) + haskey(cache, block_id) && return cache[block_id] + + effects = MemoryEffects() + for (_, entry) in block.body + if entry.stmt isa ControlFlowOp + nested = compute_cf_memory_effects!(entry.stmt, alias_result, cache) + effects = union(effects, nested) + else + call = resolve_call(entry.stmt) + call === nothing && continue + resolved_func, operands = call + mem_effect = classify_memory_op(resolved_func) + mem_effect == MEM_NONE && continue + alias_set = get_alias_set_for_operand(alias_result, first(operands)) + effects.effects[alias_set] = max(get(effects.effects, alias_set, MEM_NONE), mem_effect) + # Track acquire ordering for acquire/acq_rel atomics only + if is_atomic_intrinsic(resolved_func) + mo = extract_memory_order(resolved_func, operands) + if has_acquire_order(mo) + effects = MemoryEffects(effects.effects, true) + end + end + end + end + cache[block_id] = effects + return effects +end + +compute_cf_memory_effects!(op::IfOp, ar, c) = + union(compute_block_memory_effects!(op.then_region, ar, c), + compute_block_memory_effects!(op.else_region, ar, c)) +compute_cf_memory_effects!(op::ForOp, ar, c) = compute_block_memory_effects!(op.body, ar, c) +compute_cf_memory_effects!(op::LoopOp, ar, c) = 
compute_block_memory_effects!(op.body, ar, c) +compute_cf_memory_effects!(op::WhileOp, ar, c) = + union(compute_block_memory_effects!(op.before, ar, c), + compute_block_memory_effects!(op.after, ar, c)) +compute_cf_memory_effects!(::ControlFlowOp, _, _) = EMPTY_MEMORY_EFFECTS + +#============================================================================= + Token map (IR-level, SSAValue/BlockArg) +=============================================================================# + +function collect_join_tokens_ir(token_key::TokenKey, token_map::Dict{TokenKey, Any}, + memory_order=nothing) + tokens_to_join = Any[token_map[token_key]] + for (other_key, other_tok) in token_map + should_join = false + if other_key isa AcquireTokenKey + should_join = true + elseif other_key isa AliasTokenKey && token_key isa AliasTokenKey + if memory_order !== nothing && has_release_order(memory_order) + should_join = other_key.role == LAST_OP + end + if other_key.role == token_key.role + alias_overlap = (other_key.alias_set isa AliasUniverse) || + (token_key.alias_set isa AliasUniverse) || + !isempty(intersect(other_key.alias_set, token_key.alias_set)) + should_join = should_join || alias_overlap + end + end + if should_join && !any(t -> t === other_tok, tokens_to_join) + push!(tokens_to_join, other_tok) + end + end + return tokens_to_join +end + +function get_input_token_ir!(sci::StructuredIRCode, block::Block, before_ssa::Int, + token_key::TokenKey, token_map::Dict{TokenKey, Any}, + memory_order=nothing) + haskey(token_map, token_key) || return token_map[ACQUIRE_TOKEN_KEY] + tokens = collect_join_tokens_ir(token_key, token_map, memory_order) + length(tokens) == 1 && return tokens[1] + join_ssa = new_ssa_idx!(sci) + insert_before!(block.body, before_ssa, join_ssa, JoinTokensNode(tokens), TOKEN_TYPE) + return SSAValue(join_ssa) +end + +function has_release_order(memory_order) + memory_order === nothing && return false + return memory_order === MemoryOrder.Release || memory_order === 
MemoryOrder.AcqRel
+end
+
+# True for Acquire and AcqRel orderings; `nothing` (no compile-time order
+# known) is treated as non-acquire.
+function has_acquire_order(memory_order)
+    memory_order === nothing && return false
+    return memory_order === MemoryOrder.Acquire || memory_order === MemoryOrder.AcqRel
+end
+
+"""
+    extract_memory_order(resolved_func, operands) -> Union{MemoryOrder.T, Nothing}
+
+Extract the compile-time memory_order from an atomic intrinsic's operands.
+Returns `nothing` when `resolved_func` is not atomic or the order operand is
+not a compile-time constant.
+"""
+function extract_memory_order(resolved_func, operands)
+    is_atomic_intrinsic(resolved_func) || return nothing
+    # CAS: (ptr, expected, desired, mask, memory_order, memory_scope)
+    # RMW: (ptr, val, mask, memory_order, memory_scope)
+    # Guard the atomic_cas reference with isdefined, mirroring
+    # is_atomic_intrinsic: an unconditional `Intrinsics.atomic_cas` would throw
+    # UndefVarError when only the RMW atomics are exposed by the module.
+    is_cas = isdefined(Intrinsics, :atomic_cas) && resolved_func === Intrinsics.atomic_cas
+    mo_idx = is_cas ? 5 : 4
+    mo_idx > length(operands) && return nothing
+    mo_arg = operands[mo_idx]
+    # The memory_order is typically a compile-time constant (QuoteNode or literal)
+    if mo_arg isa QuoteNode
+        return mo_arg.value
+    elseif mo_arg isa MemoryOrder.T
+        return mo_arg
+    end
+    return nothing
+end
+
+#=============================================================================
+  Control flow exit tokens (matching Python's _get_cf_exit_tokens)
+=============================================================================#
+
+"""
+    get_cf_exit_tokens(effects, token_map) -> Vector{Any}
+
+Collect current tokens for each alias set with memory effects.
+These are appended to ContinueOp/BreakOp/YieldOp when leaving a CF region.
+""" +function get_cf_exit_tokens(effects::MemoryEffects, token_map::Dict{TokenKey, Any}) + tokens = Any[] + for (alias_set, effect) in effects.effects + effect == MEM_NONE && continue + if effect == MEM_LOAD + push!(tokens, token_map[last_op_key(alias_set)]) + elseif effect == MEM_STORE + push!(tokens, token_map[last_op_key(alias_set)]) + push!(tokens, token_map[last_store_key(alias_set)]) + end + end + if effects.has_acquire + push!(tokens, token_map[ACQUIRE_TOKEN_KEY]) + end + return tokens +end + +#============================================================================= + The main pass +=============================================================================# + +function token_order_pass!(sci::StructuredIRCode, alias_result::Dict{Any, AliasSet}) + effects_cache = Dict{UInt64, MemoryEffects}() + compute_block_memory_effects!(sci.entry, alias_result, effects_cache) + + # Insert root MakeTokenNode at entry + root_ssa = new_ssa_idx!(sci) + pushfirst!(sci.entry.body, (root_ssa, MakeTokenNode(), TOKEN_TYPE)) + root_token = SSAValue(root_ssa) + + # Initialize: all alias sets start at root token + token_map = Dict{TokenKey, Any}() + for alias_set in Set(values(alias_result)) + token_map[last_op_key(alias_set)] = root_token + token_map[last_store_key(alias_set)] = root_token + end + token_map[ACQUIRE_TOKEN_KEY] = root_token + + transform_block!(sci, sci.entry, alias_result, token_map, effects_cache, nothing, nothing) + return nothing +end + +#============================================================================= + Block transformation +=============================================================================# + +function transform_block!(sci::StructuredIRCode, block::Block, + alias_result::Dict{Any, AliasSet}, + token_map::Dict{TokenKey, Any}, + effects_cache::Dict{UInt64, MemoryEffects}, + loop_effects::Union{MemoryEffects, Nothing}, + ifelse_effects::Union{MemoryEffects, Nothing}) + # Snapshot indices to avoid invalidation from insertions + 
ssa_indices = collect(Int, block.body.ssa_idxes) + + for ssa_idx in ssa_indices + entry = get(block.body, ssa_idx, nothing) + entry === nothing && continue + if entry.stmt isa ControlFlowOp + transform_control_flow!(sci, block, ssa_idx, entry.stmt, entry.typ, + alias_result, token_map, effects_cache, loop_effects) + else + transform_statement!(sci, block, ssa_idx, entry.stmt, + alias_result, token_map) + end + end + + # Append exit tokens to the block's terminator (for loops and branches) + transform_terminator!(block, token_map, loop_effects, ifelse_effects) +end + +function transform_statement!(sci::StructuredIRCode, block::Block, ssa_idx::Int, stmt, + alias_result::Dict{Any, AliasSet}, + token_map::Dict{TokenKey, Any}) + call = resolve_call(stmt) + call === nothing && return + resolved_func, operands = call + mem_effect = classify_memory_op(resolved_func) + mem_effect == MEM_NONE && return + + alias_set = get_alias_set_for_operand(alias_result, first(operands)) + + if mem_effect == MEM_LOAD + input_token = get_input_token_ir!(sci, block, ssa_idx, + last_store_key(alias_set), token_map) + push!(stmt.args, input_token) + + result_ssa = new_ssa_idx!(sci) + insert_after!(block.body, ssa_idx, result_ssa, TokenResultNode(ssa_idx), TOKEN_TYPE) + result_token = SSAValue(result_ssa) + + # Eagerly join with last_op token (Python line 176-179) + lop_key = last_op_key(alias_set) + last_op_tok = token_map[lop_key] + join_ssa = new_ssa_idx!(sci) + insert_after!(block.body, result_ssa, join_ssa, + JoinTokensNode([last_op_tok, result_token]), TOKEN_TYPE) + token_map[lop_key] = SSAValue(join_ssa) + + elseif mem_effect == MEM_STORE + # For release-ordered atomics, join with ALL LAST_OP tokens (memory fence) + memory_order = extract_memory_order(resolved_func, operands) + input_token = get_input_token_ir!(sci, block, ssa_idx, + last_op_key(alias_set), token_map, + memory_order) + push!(stmt.args, input_token) + + result_ssa = new_ssa_idx!(sci) + insert_after!(block.body, ssa_idx, 
result_ssa, TokenResultNode(ssa_idx), TOKEN_TYPE) + result_token = SSAValue(result_ssa) + + token_map[last_op_key(alias_set)] = result_token + token_map[last_store_key(alias_set)] = result_token + + # Only acquire/acq_rel atomics update the ACQUIRE token + if is_atomic_intrinsic(resolved_func) && has_acquire_order(memory_order) + token_map[ACQUIRE_TOKEN_KEY] = result_token + end + end +end + +function transform_terminator!(block::Block, token_map::Dict{TokenKey, Any}, + loop_effects::Union{MemoryEffects, Nothing}, + ifelse_effects::Union{MemoryEffects, Nothing}) + term = block.terminator + term === nothing && return + + # ConditionOp (WhileOp before-block): extend args with exit tokens so that + # the codegen-generated BreakOp carries them. + if term isa ConditionOp && loop_effects !== nothing + append!(term.args, get_cf_exit_tokens(loop_effects, token_map)) + return + end + + effects = if (term isa ContinueOp || term isa BreakOp) && loop_effects !== nothing + loop_effects + elseif term isa YieldOp && ifelse_effects !== nothing + ifelse_effects + elseif term isa YieldOp && loop_effects !== nothing + loop_effects + else + nothing + end + effects === nothing && return + append!(term.values, get_cf_exit_tokens(effects, token_map)) +end + +#============================================================================= + Control flow transformation +=============================================================================# + +# --- Loops (ForOp, LoopOp) --- +# Matching Python's Loop handling (token_order.py:228-280) + +function transform_control_flow!(sci::StructuredIRCode, parent_block::Block, + ssa_idx::Int, op::ForOp, @nospecialize(result_type), + alias_result, token_map, effects_cache, parent_loop_effects=nothing) + transform_loop!(sci, parent_block, ssa_idx, op, op.body, alias_result, + token_map, effects_cache) +end + +function transform_control_flow!(sci::StructuredIRCode, parent_block::Block, + ssa_idx::Int, op::LoopOp, @nospecialize(result_type), + 
alias_result, token_map, effects_cache, parent_loop_effects=nothing) + transform_loop!(sci, parent_block, ssa_idx, op, op.body, alias_result, + token_map, effects_cache) +end + +""" + insert_token_result_getfields!(sci, parent_block, ssa_idx, n_user, effects, token_map) + +Insert getfield extractions after a loop/if for each per-alias token result. +Updates `token_map` with SSAValues pointing to the extracted tokens. +Also updates the SSAMap type to include TokenType parameters. +""" +function insert_token_result_getfields!(sci::StructuredIRCode, parent_block::Block, + ssa_idx::Int, block_args, n_user::Int, + effects::MemoryEffects, token_map::Dict{TokenKey, Any}) + n_total = length(block_args) + n_total > n_user || return + + # Update result type to include all carries + all_types = Type[is_token_type(arg.type) ? TokenType : arg.type for arg in block_args] + update_type!(parent_block.body, ssa_idx, isempty(all_types) ? Nothing : Tuple{all_types...}) + + last_inserted = ssa_idx + carry_idx = n_user + for (alias_set, effect) in effects.effects + effect == MEM_NONE && continue + if effect >= MEM_LOAD + carry_idx += 1 + gf_ssa = new_ssa_idx!(sci) + insert_after!(parent_block.body, last_inserted, gf_ssa, + Expr(:call, GlobalRef(Core, :getfield), SSAValue(ssa_idx), carry_idx), TOKEN_TYPE) + token_map[last_op_key(alias_set)] = SSAValue(gf_ssa) + last_inserted = gf_ssa + end + if effect == MEM_STORE + carry_idx += 1 + gf_ssa = new_ssa_idx!(sci) + insert_after!(parent_block.body, last_inserted, gf_ssa, + Expr(:call, GlobalRef(Core, :getfield), SSAValue(ssa_idx), carry_idx), TOKEN_TYPE) + token_map[last_store_key(alias_set)] = SSAValue(gf_ssa) + last_inserted = gf_ssa + end + end + if effects.has_acquire + carry_idx += 1 + gf_ssa = new_ssa_idx!(sci) + insert_after!(parent_block.body, last_inserted, gf_ssa, + Expr(:call, GlobalRef(Core, :getfield), SSAValue(ssa_idx), carry_idx), TOKEN_TYPE) + token_map[ACQUIRE_TOKEN_KEY] = SSAValue(gf_ssa) + end +end + +""" + 
transform_loop!(...)
+
+Add per-alias-set token carries to a loop: pushes the current tokens onto the
+loop's init_values, mints matching body block-args, recurses into the body
+with a body-scoped token map, and finally extracts result tokens via
+getfields in the parent scope.
+"""
+function transform_loop!(sci::StructuredIRCode, parent_block::Block,
+                         ssa_idx::Int, op::Union{ForOp, LoopOp}, body::Block,
+                         alias_result::Dict{Any, AliasSet},
+                         token_map::Dict{TokenKey, Any},
+                         effects_cache::Dict{UInt64, MemoryEffects})
+    body_effects = get(effects_cache, objectid(body), EMPTY_MEMORY_EFFECTS)
+
+    body_token_map = copy(token_map)
+    result_token_map = copy(token_map)
+
+    # Track the number of user carries (before we add tokens)
+    n_user_carries = length(op.init_values)
+
+    # Add per-alias token carries (matching Python lines 245-264).
+    # Result positions after the user carries are recomputed inside
+    # insert_token_result_getfields!, so no running index is kept here.
+    # (Removed dead locals `carry_idx` / `n_total_carries`, which were
+    # written but never read.)
+    for (alias_set, effect) in body_effects.effects
+        effect == MEM_NONE && continue
+        if effect >= MEM_LOAD
+            push!(op.init_values, token_map[last_op_key(alias_set)])
+            body_arg = new_block_arg!(body, sci, TOKEN_TYPE)
+            body_token_map[last_op_key(alias_set)] = body_arg
+            # result_token_map will be updated below with getfield SSAs
+        end
+        if effect == MEM_STORE
+            push!(op.init_values, token_map[last_store_key(alias_set)])
+            body_arg = new_block_arg!(body, sci, TOKEN_TYPE)
+            body_token_map[last_store_key(alias_set)] = body_arg
+        end
+    end
+    if body_effects.has_acquire
+        push!(op.init_values, token_map[ACQUIRE_TOKEN_KEY])
+        body_arg = new_block_arg!(body, sci, TOKEN_TYPE)
+        body_token_map[ACQUIRE_TOKEN_KEY] = body_arg
+    end
+
+    # Recurse into body with body-scoped token map
+    transform_block!(sci, body, alias_result, body_token_map, effects_cache,
+                     body_effects, nothing)
+
+    insert_token_result_getfields!(sci, parent_block, ssa_idx, body.args,
+                                   n_user_carries, body_effects, result_token_map)
+    merge!(token_map, result_token_map)
+end
+
+# --- WhileOp ---
+# WhileOp has before/after regions. We treat it similarly to a loop but need to
+# handle both regions.
+ +function transform_control_flow!(sci::StructuredIRCode, parent_block::Block, + ssa_idx::Int, op::WhileOp, @nospecialize(result_type), + alias_result, token_map, effects_cache, parent_loop_effects=nothing) + before_effects = get(effects_cache, objectid(op.before), EMPTY_MEMORY_EFFECTS) + after_effects = get(effects_cache, objectid(op.after), EMPTY_MEMORY_EFFECTS) + loop_effects = union(before_effects, after_effects) + + body_token_map = copy(token_map) + result_token_map = copy(token_map) + n_user_carries = length(op.init_values) + + # Add per-alias token carries to before/after blocks + for (alias_set, effect) in loop_effects.effects + effect == MEM_NONE && continue + if effect >= MEM_LOAD + push!(op.init_values, token_map[last_op_key(alias_set)]) + before_arg = new_block_arg!(op.before, sci, TOKEN_TYPE) + new_block_arg!(op.after, sci, TOKEN_TYPE) + body_token_map[last_op_key(alias_set)] = before_arg + end + if effect == MEM_STORE + push!(op.init_values, token_map[last_store_key(alias_set)]) + before_arg = new_block_arg!(op.before, sci, TOKEN_TYPE) + new_block_arg!(op.after, sci, TOKEN_TYPE) + body_token_map[last_store_key(alias_set)] = before_arg + end + end + if loop_effects.has_acquire + push!(op.init_values, token_map[ACQUIRE_TOKEN_KEY]) + before_arg = new_block_arg!(op.before, sci, TOKEN_TYPE) + after_arg = new_block_arg!(op.after, sci, TOKEN_TYPE) + body_token_map[ACQUIRE_TOKEN_KEY] = before_arg + end + + n_total_carries = length(op.init_values) + + # Build after_token_map from after block's args (not before's) + after_token_map = copy(token_map) + after_arg_idx = n_user_carries + for (alias_set, effect) in loop_effects.effects + effect == MEM_NONE && continue + if effect >= MEM_LOAD + after_arg_idx += 1 + after_token_map[last_op_key(alias_set)] = op.after.args[after_arg_idx] + end + if effect == MEM_STORE + after_arg_idx += 1 + after_token_map[last_store_key(alias_set)] = op.after.args[after_arg_idx] + end + end + if loop_effects.has_acquire + 
after_arg_idx += 1 + after_token_map[ACQUIRE_TOKEN_KEY] = op.after.args[after_arg_idx] + end + + # Transform before region (may update body_token_map, e.g., CAS in condition) + transform_block!(sci, op.before, alias_result, body_token_map, effects_cache, + loop_effects, nothing) + + # Propagate before's final token state to after_token_map. + # The after block receives values from before's ConditionOp, so it should + # see the token state AFTER the before block's transformations (e.g., CAS result). + for (key, val) in body_token_map + after_token_map[key] = val + end + + transform_block!(sci, op.after, alias_result, after_token_map, effects_cache, + loop_effects, nothing) + + insert_token_result_getfields!(sci, parent_block, ssa_idx, op.before.args, + n_user_carries, loop_effects, result_token_map) + merge!(token_map, result_token_map) +end + +# --- IfOp --- +# Matching Python's IfElse handling (token_order.py:294-334) + +function transform_control_flow!(sci::StructuredIRCode, parent_block::Block, + ssa_idx::Int, op::IfOp, @nospecialize(result_type), + alias_result, token_map, effects_cache, parent_loop_effects=nothing) + then_effects = get(effects_cache, objectid(op.then_region), EMPTY_MEMORY_EFFECTS) + else_effects = get(effects_cache, objectid(op.else_region), EMPTY_MEMORY_EFFECTS) + merged_effects = union(then_effects, else_effects) + + # Transform both branches. Pass parent_loop_effects so that ContinueOp/BreakOp + # inside branches (common for LoopOp→IfOp patterns) get token exit values. 
+ then_map = copy(token_map) + transform_block!(sci, op.then_region, alias_result, then_map, effects_cache, + parent_loop_effects, merged_effects) + else_map = copy(token_map) + transform_block!(sci, op.else_region, alias_result, else_map, effects_cache, + parent_loop_effects, merged_effects) + + # Count token results and insert getfield extractions + n_token_results = 0 + for (_, effect) in merged_effects.effects + effect == MEM_NONE && continue + n_token_results += (effect == MEM_LOAD) ? 1 : 2 + end + n_token_results += merged_effects.has_acquire ? 1 : 0 + + if n_token_results > 0 + # Update IfOp type to include token results + old_type = get(parent_block.body, ssa_idx, nothing) + if old_type !== nothing + user_types = if old_type.typ === Nothing + Type[] + elseif old_type.typ <: Tuple + collect(Type, old_type.typ.parameters) + else + Type[old_type.typ] + end + token_types = fill(TokenType, n_token_results) + new_type = Tuple{user_types..., token_types...} + update_type!(parent_block.body, ssa_idx, new_type) + else + user_types = Type[] + end + + # Insert getfield extractions for token results + last_inserted = ssa_idx + result_idx = length(user_types) + for (alias_set, effect) in merged_effects.effects + effect == MEM_NONE && continue + if effect >= MEM_LOAD + result_idx += 1 + gf_ssa = new_ssa_idx!(sci) + gf_expr = Expr(:call, GlobalRef(Core, :getfield), SSAValue(ssa_idx), result_idx) + insert_after!(parent_block.body, last_inserted, gf_ssa, gf_expr, TOKEN_TYPE) + token_map[last_op_key(alias_set)] = SSAValue(gf_ssa) + last_inserted = gf_ssa + end + if effect == MEM_STORE + result_idx += 1 + gf_ssa = new_ssa_idx!(sci) + gf_expr = Expr(:call, GlobalRef(Core, :getfield), SSAValue(ssa_idx), result_idx) + insert_after!(parent_block.body, last_inserted, gf_ssa, gf_expr, TOKEN_TYPE) + token_map[last_store_key(alias_set)] = SSAValue(gf_ssa) + last_inserted = gf_ssa + end + end + if merged_effects.has_acquire + result_idx += 1 + gf_ssa = new_ssa_idx!(sci) + gf_expr = 
Expr(:call, GlobalRef(Core, :getfield), SSAValue(ssa_idx), result_idx) + insert_after!(parent_block.body, last_inserted, gf_ssa, gf_expr, TOKEN_TYPE) + token_map[ACQUIRE_TOKEN_KEY] = SSAValue(gf_ssa) + end + end +end + +# Fallback +function transform_control_flow!(sci::StructuredIRCode, parent_block::Block, + ssa_idx::Int, op::ControlFlowOp, @nospecialize(result_type), + alias_result, token_map, effects_cache, parent_loop_effects=nothing) +end diff --git a/src/compiler/codegen/statements.jl b/src/compiler/codegen/statements.jl index df15534c..b3de499e 100644 --- a/src/compiler/codegen/statements.jl +++ b/src/compiler/codegen/statements.jl @@ -7,8 +7,15 @@ Emit bytecode for a single SSA statement. The ssa_idx is the original Julia SSA index to store the result at. """ function emit_statement!(ctx::CGCtx, @nospecialize(stmt), ssa_idx::Int, @nospecialize(result_type)) + ctx.current_ssa_idx = ssa_idx tv = nothing - if stmt isa ReturnNode + if stmt isa MakeTokenNode + tv = emit_make_token!(ctx) + elseif stmt isa JoinTokensNode + tv = emit_join_tokens!(ctx, stmt) + elseif stmt isa TokenResultNode + tv = emit_token_result!(ctx, stmt) + elseif stmt isa ReturnNode emit_return!(ctx, stmt) elseif stmt isa Expr tv = emit_expr!(ctx, stmt, result_type) @@ -65,3 +72,44 @@ function emit_return!(ctx::CGCtx, node::ReturnNode) end end end + +#============================================================================= + Token IR node emission +=============================================================================# + +function emit_make_token!(ctx::CGCtx) + token_type = ctx.token_type + if token_type === nothing + token_type = Token(ctx.tt) + ctx.token_type = token_type + end + v = encode_MakeTokenOp!(ctx.cb, token_type) + return CGVal(v, token_type, TokenType) +end + +function emit_join_tokens!(ctx::CGCtx, node::JoinTokensNode) + tokens = Value[] + for tok_ref in node.tokens + tv = emit_value!(ctx, tok_ref) + tv === nothing && throw(IRError("JoinTokensNode: cannot resolve 
token operand $tok_ref")) + push!(tokens, tv.v) + end + # Deduplicate by identity (avoid join_tokens(%x, %x)) + unique_tokens = Value[] + for t in tokens + any(u -> u === t, unique_tokens) || push!(unique_tokens, t) + end + if length(unique_tokens) == 1 + return CGVal(unique_tokens[1], ctx.token_type, TokenType) + end + v = encode_JoinTokensOp!(ctx.cb, ctx.token_type, unique_tokens) + return CGVal(v, ctx.token_type, TokenType) +end + +function emit_token_result!(ctx::CGCtx, node::TokenResultNode) + # The memory op at node.mem_op_ssa should have stored its result token + v = get(ctx.result_tokens, node.mem_op_ssa, nothing) + v === nothing && throw(IRError("TokenResultNode: no result token for memory op at SSA %$(node.mem_op_ssa)")) + return CGVal(v, ctx.token_type, TokenType) +end + diff --git a/src/compiler/codegen/utils.jl b/src/compiler/codegen/utils.jl index 393811b2..c410fe02 100644 --- a/src/compiler/codegen/utils.jl +++ b/src/compiler/codegen/utils.jl @@ -2,6 +2,92 @@ # # Core types (CGVal, CGCtx) and helper functions for Tile IR code generation. + +#============================================================================= + Alias Analysis Types +=============================================================================# + +""" + AliasUniverse + +Singleton type representing the universal alias set (everything may alias everything). 
+""" +struct AliasUniverse end +const ALIAS_UNIVERSE = AliasUniverse() + +# Universe behaves specially in set operations +Base.union(::AliasUniverse, ::AliasUniverse) = ALIAS_UNIVERSE +Base.union(::AliasUniverse, other) = ALIAS_UNIVERSE +Base.union(other, ::AliasUniverse) = ALIAS_UNIVERSE +Base.intersect(::AliasUniverse, other) = other +Base.intersect(other, ::AliasUniverse) = other +Base.:(==)(::AliasUniverse, ::AliasUniverse) = true +Base.:(==)(::AliasUniverse, other) = false +Base.:(==)(other, ::AliasUniverse) = false + +""" + AliasSet + +Union type representing either a concrete set of values that may alias, +or the universal alias set (ALIAS_UNIVERSE). +""" +const AliasSet = Union{Set{Any}, AliasUniverse} + + +#============================================================================= + Token IR Node Types (inserted by token_order_pass!) +=============================================================================# + +""" + TokenType + +Sentinel type used in StructuredIRCode to mark SSA values and BlockArgs +that represent memory ordering tokens. Not a runtime type. +""" +struct TokenType end +const TOKEN_TYPE = TokenType() + +""" + MakeTokenNode + +IR statement node: creates the root memory ordering token at kernel entry. +Inserted by `token_order_pass!`. Emitted as `encode_MakeTokenOp!` during codegen. +""" +struct MakeTokenNode end + +""" + JoinTokensNode + +IR statement node: merges multiple token values into one. +Inserted by `token_order_pass!`. Emitted as `encode_JoinTokensOp!` during codegen. +""" +struct JoinTokensNode + tokens::Vector{Any} # SSAValue or BlockArg references to token values +end + +""" + TokenResultNode + +IR statement node: represents the result token produced by a memory operation. +The memory op at `mem_op_ssa` produces both a data value and a token; this node +extracts the token. Codegen resolves this via `ctx.result_tokens[mem_op_ssa]`. 
+""" +struct TokenResultNode + mem_op_ssa::Int # SSA index of the memory operation that produced this token +end + +# Note: IfTokenResultNode was considered but removed in favor of conservative +# IfOp handling (no token carries through branches). Can be added later for +# more precise token tracking. + +""" + is_token_type(typ) -> Bool + +Check whether a type annotation in the structured IR represents a token. +Handles both instances (`TOKEN_TYPE`) and the type itself (`TokenType`). +""" +is_token_type(@nospecialize(typ)) = typ isa TokenType || typ === TokenType + #============================================================================= IRError: Exception type for IR compilation errors =============================================================================# @@ -165,10 +251,17 @@ mutable struct CGCtx tt::TypeTable sci::StructuredIRCode - # Memory ordering token - token::Union{Value, Nothing} + # Token bytecode type (cached for encoding token operations) token_type::Union{TypeId, Nothing} + # Result tokens from memory ops: mem_op SSA index → bytecode Value + # Populated during codegen when emitting memory ops with token args. + # Read by TokenResultNode emission. + result_tokens::Dict{Int, Value} + + # Current SSA index being emitted (set by emit_statement!) 
+ current_ssa_idx::Int + # Type cache: Julia type -> TypeId type_cache::Dict{Type, TypeId} @@ -180,7 +273,6 @@ mutable struct CGCtx end function CGCtx(; cb::CodeBuilder, tt::TypeTable, sci::StructuredIRCode, - token::Union{Value, Nothing} = nothing, token_type::Union{TypeId, Nothing} = nothing, type_cache::Dict{Type, TypeId} = Dict{Type, TypeId}(), sm_arch::Union{VersionNumber, Nothing} = nothing, @@ -192,8 +284,16 @@ function CGCtx(; cb::CodeBuilder, tt::TypeTable, sci::StructuredIRCode, Dict{Int, CGVal}(), Dict{Tuple{Int, Vector{Int}}, Vector{Value}}(), Dict{Int, Type}(), - Dict{Any, Tuple{Value, TypeId}}(), - cb, tt, sci, token, token_type, type_cache, sm_arch, cache, + Dict{Int, Tuple{Value, TypeId}}(), + cb, + tt, + sci, + token_type, + Dict{Int, Value}(), # result_tokens + 0, # current_ssa_idx + type_cache, + sm_arch, + cache, ) end @@ -330,6 +430,7 @@ Get or create a Tile IR type for a Julia type. With `throw_error=false`, returns `nothing` instead of throwing if the type has no Tile IR representation. """ function tile_type_for_julia!(ctx::CGCtx, @nospecialize(T); throw_error::Bool=true) + is_token_type(T) && return Token(ctx.tt) actual_type = CC.widenconst(T) cached = get(ctx.type_cache, actual_type, nothing) cached !== nothing && return cached @@ -487,6 +588,7 @@ Extract shape from a Tile{T, Shape} type in Julia's column-major convention. Returns ScalarShape for non-Tile types. """ function extract_tile_shape(@nospecialize(T)) + is_token_type(T) && return ScalarShape() T = CC.widenconst(T) if T <: Tile return ColMajorShape(size(T)) diff --git a/src/compiler/intrinsics/atomics.jl b/src/compiler/intrinsics/atomics.jl index 9ea4484b..ec76bca9 100644 --- a/src/compiler/intrinsics/atomics.jl +++ b/src/compiler/intrinsics/atomics.jl @@ -28,6 +28,9 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_cas), args) cb = ctx.cb tt = ctx.tt + # Extract input token from last arg (added by token_order_pass!) 
+ input_token = extract_token_arg!(ctx, args) + # args: (ptr_tile, expected, desired, mask, memory_order, memory_scope) ptr_tv = emit_value!(ctx, args[1]) ptr_tv === nothing && throw(IRError("atomic CAS requires ptr_tile")) @@ -60,17 +63,18 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_cas), args) encode_AtomicCASPtrOp!(cb, result_tile_type, token_type, ptr_tv.v, expected_tv.v, desired_tv.v; mask=mask_tv.v, - token=ctx.token, + token=input_token, memory_ordering=mem_ordering, memory_scope=mem_scope) else encode_AtomicCASPtrOp!(cb, result_tile_type, token_type, ptr_tv.v, expected_tv.v, desired_tv.v; - token=ctx.token, + token=input_token, memory_ordering=mem_ordering, memory_scope=mem_scope) end - ctx.token = new_token + # Store result token for TokenResultNode + ctx.result_tokens[ctx.current_ssa_idx] = new_token julia_shape = ColMajorShape(shape) CGVal(old_val, result_tile_type, Tile{elem_type, TupleType(julia_shape)}, shape) @@ -81,6 +85,9 @@ function emit_atomic_rmw!(ctx::CGCtx, args::AbstractVector, mode::AtomicRMWMode. cb = ctx.cb tt = ctx.tt + # Extract input token from last arg (added by token_order_pass!) + input_token = extract_token_arg!(ctx, args) + # args: (ptr_tile, val, mask, memory_order, memory_scope) ptr_tv = emit_value!(ctx, args[1]) ptr_tv === nothing && throw(IRError("atomic RMW requires ptr_tile")) @@ -118,17 +125,18 @@ function emit_atomic_rmw!(ctx::CGCtx, args::AbstractVector, mode::AtomicRMWMode. 
encode_AtomicRMWPtrOp!(cb, result_tile_type, token_type, ptr_tv.v, val_tv.v, actual_mode; mask=mask_tv.v, - token=ctx.token, + token=input_token, memory_ordering=mem_ordering, memory_scope=mem_scope) else encode_AtomicRMWPtrOp!(cb, result_tile_type, token_type, ptr_tv.v, val_tv.v, actual_mode; - token=ctx.token, + token=input_token, memory_ordering=mem_ordering, memory_scope=mem_scope) end - ctx.token = new_token + # Store result token for TokenResultNode + ctx.result_tokens[ctx.current_ssa_idx] = new_token julia_shape = ColMajorShape(shape) CGVal(old_val, result_tile_type, Tile{elem_type, TupleType(julia_shape)}, shape) diff --git a/src/compiler/intrinsics/memory.jl b/src/compiler/intrinsics/memory.jl index 5866bd66..47edb07c 100644 --- a/src/compiler/intrinsics/memory.jl +++ b/src/compiler/intrinsics/memory.jl @@ -1,7 +1,5 @@ # Memory -# TODO: cuda_tile.join_tokens - # cuda_tile.load_ptr_tko @intrinsic load_ptr_tko(ptrs, latency=nothing, mask=nothing, padding=nothing) function tfunc(𝕃, ::typeof(Intrinsics.load_ptr_tko), @nospecialize(ptrs), @nospecialize args...) @@ -17,57 +15,51 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.load_ptr_tko), args) cb = ctx.cb tt = ctx.tt + # Extract input token from last arg (added by token_order_pass!) + input_token = extract_token_arg!(ctx, args) + # args: (ptrs, latency, mask?, padding?) 
- # Get pointer tile (arg 1) ptrs_tv = emit_value!(ctx, args[1]) ptrs_tv === nothing && throw(IRError("load_ptr_tko: cannot resolve pointer tile")) pointers = ptrs_tv.v tile_shape = ptrs_tv.shape - # Get element type from pointer tile type (Tile{Ptr{T}, S}) ptrs_type = CC.widenconst(ptrs_tv.jltype) - ptr_type = eltype(ptrs_type) # Ptr{T} from Tile{Ptr{T}, S} - elem_type = eltype(ptr_type) # T from Ptr{T} + ptr_type = eltype(ptrs_type) + elem_type = eltype(ptr_type) dtype = julia_to_tile_dtype!(tt, elem_type) result_tile_type = tile_type!(tt, dtype, tile_shape) token_type = Token(tt) - # Extract latency hint (args[2]) latency = @something get_constant(ctx, args[2]) throw(IRError("latency must be a compile-time constant")) - optimization_hints = create_optimization_hints(ctx, latency) - mask_tv, has_mask = emit_optional_mask(ctx, args, 3) if has_mask mask = mask_tv.v - - # Get padding tile (arg 4) padding_tv = emit_value!(ctx, args[4]) padding_tv === nothing && throw(IRError("load_ptr_tko: cannot resolve padding tile")) padding = padding_tv.v - # Load with mask and padding - tile_val, new_token = encode_LoadPtrTkoOp!(cb, result_tile_type, token_type, pointers; - mask=mask, - padding_value=padding, - token=ctx.token, - optimization_hints) + tile_val, new_token = encode_LoadPtrTkoOp!( + cb, result_tile_type, token_type, pointers; + mask, padding_value = padding, token = input_token, optimization_hints + ) else - # Load without mask - tile_val, new_token = encode_LoadPtrTkoOp!(cb, result_tile_type, token_type, pointers; - token=ctx.token, - optimization_hints) + tile_val, new_token = encode_LoadPtrTkoOp!( + cb, result_tile_type, token_type, pointers; + token = input_token, optimization_hints + ) end - ctx.token = new_token + + # Store result token for TokenResultNode + ctx.result_tokens[ctx.current_ssa_idx] = new_token julia_shape = ColMajorShape(tile_shape) result_jltype = Tile{elem_type, TupleType(julia_shape)} - CGVal(tile_val, result_tile_type, result_jltype, 
tile_shape) + return CGVal(tile_val, result_tile_type, result_jltype, tile_shape) end -# TODO: cuda_tile.make_token - # cuda_tile.store_ptr_tko @intrinsic store_ptr_tko(ptrs::Tile{Ptr{T}, S}, values::Tile{T, S}, latency::Union{Int, Nothing}, @@ -79,13 +71,13 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.store_ptr_tko), args) cb = ctx.cb tt = ctx.tt - # args: (ptrs, values, latency, mask?) - # Get pointer tile (arg 1) + # Extract input token from last arg (added by token_order_pass!) + input_token = extract_token_arg!(ctx, args) + ptrs_tv = emit_value!(ctx, args[1]) ptrs_tv === nothing && throw(IRError("store_ptr_tko: cannot resolve pointer tile")) pointers = ptrs_tv.v - # Get value tile (arg 2) values_tv = emit_value!(ctx, args[2]) values_tv === nothing && throw(IRError("store_ptr_tko: cannot resolve values tile")) values = values_tv.v @@ -93,26 +85,41 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.store_ptr_tko), args) token_type = Token(tt) latency = @something get_constant(ctx, args[3]) throw(IRError("latency must be a compile-time constant")) - optimization_hints = create_optimization_hints(ctx, latency) - mask_tv, has_mask = emit_optional_mask(ctx, args, 4) if has_mask mask = mask_tv.v - - # Store with mask - new_token = encode_StorePtrTkoOp!(cb, token_type, pointers, values; - mask=mask, - token=ctx.token, - optimization_hints) + new_token = encode_StorePtrTkoOp!( + cb, token_type, pointers, values; + mask, token = input_token, optimization_hints + ) else - # Store without mask - new_token = encode_StorePtrTkoOp!(cb, token_type, pointers, values; - token=ctx.token, - optimization_hints) + new_token = encode_StorePtrTkoOp!( + cb, token_type, pointers, values; + token = input_token, optimization_hints + ) end - ctx.token = new_token + + # Store result token for TokenResultNode + ctx.result_tokens[ctx.current_ssa_idx] = new_token nothing end + +""" + extract_token_arg!(ctx, args) -> Value + +Extract the token argument (last arg, added 
by token_order_pass!) from a memory op call. +Pops the token from args and returns the resolved bytecode Value. +""" +function extract_token_arg!(ctx::CGCtx, args) +    isempty(args) && throw(IRError("Memory op has no arguments")) +    last_arg = args[end] +    tv = emit_value!(ctx, last_arg) +    if tv !== nothing && is_token_type(tv.jltype) +        pop!(args) +        return tv.v +    end +    throw(IRError("Memory op missing token argument (token_order_pass! not run?)")) +end diff --git a/src/compiler/intrinsics/views.jl b/src/compiler/intrinsics/views.jl index 11e1b99f..e747c7e7 100644 --- a/src/compiler/intrinsics/views.jl +++ b/src/compiler/intrinsics/views.jl @@ -56,6 +56,9 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.load_partition_view), a cb = ctx.cb tt = ctx.tt + + # Extract input token from last arg (added by token_order_pass!) + input_token = extract_token_arg!(ctx, args) + # args: (partition_view, latency, allow_tma, indices) pv_arg = emit_value!(ctx, args[1]) pv_arg === nothing && throw(IRError("load_partition_view() requires a PartitionView argument")) @@ -108,13 +111,16 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.load_partition_view), a # Create optimization hints if provided optimization_hints = create_optimization_hints(ctx, latency, allow_tma_val) - # Load tile with token - tile_val, new_token = encode_LoadViewTkoOp!(cb, tile_type, token_type, pv_arg.v, index_vals; - token=ctx.token, optimization_hints) - ctx.token = new_token + tile_val, result_token = encode_LoadViewTkoOp!( + cb, tile_type, token_type, pv_arg.v, index_vals; + token = input_token, optimization_hints + ) + + # Store result token for TokenResultNode + ctx.result_tokens[ctx.current_ssa_idx] = result_token julia_shape = ColMajorShape(tile_shape) - CGVal(tile_val, tile_type, Tile{elem_type, TupleType(julia_shape)}, tile_shape) + return CGVal(tile_val, tile_type, Tile{elem_type, TupleType(julia_shape)}, tile_shape) end function pad_indices(ctx::CGCtx, index_vals::Vector{Value},
ndim::Int, idx_type::TypeId, idx_jl_type::Type) @@ -351,6 +357,9 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.store_partition_view), cb = ctx.cb tt = ctx.tt + # Extract input token from last arg (added by token_order_pass!) + input_token = extract_token_arg!(ctx, args) + # args: (partition_view, tile, latency, allow_tma, indices) pv_arg = emit_value!(ctx, args[1]) pv_arg === nothing && throw(IRError("store_partition_view() requires a PartitionView argument")) @@ -414,11 +423,15 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.store_partition_view), # Create optimization hints if provided optimization_hints = create_optimization_hints(ctx, latency, allow_tma_val) - # Store tile with token token_type = Token(tt) - new_token = encode_StoreViewTkoOp!(cb, token_type, tile_val, pv_arg.v, index_vals; - token=ctx.token, optimization_hints) - ctx.token = new_token - nothing + result_token = encode_StoreViewTkoOp!( + cb, token_type, tile_val, pv_arg.v, index_vals; + token = input_token, optimization_hints + ) + + # Store result token for TokenResultNode + ctx.result_tokens[ctx.current_ssa_idx] = result_token + + return nothing end diff --git a/src/cuTile.jl b/src/cuTile.jl index 3549b2f3..2987e826 100644 --- a/src/cuTile.jl +++ b/src/cuTile.jl @@ -1,7 +1,7 @@ module cuTile using IRStructurizer -using IRStructurizer: Block, ControlFlowOp, BlockArg, +using IRStructurizer: Block, ControlFlowOp, BlockArg, SSAMap, YieldOp, ContinueOp, BreakOp, ConditionOp, IfOp, ForOp, WhileOp, LoopOp, Undef