Skip to content
Merged
4 changes: 4 additions & 0 deletions src/compiler/codegen.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Codegen: Julia IR -> Tile IR bytecode

include("codegen/utils.jl")
include("codegen/irutils.jl") # SSAMap/Block mutation helpers
include("codegen/passes/token_keys.jl") # TokenKey, TokenRole, ACQUIRE_TOKEN_KEY
include("codegen/passes/alias_analysis.jl") # alias_analysis_pass!
include("codegen/passes/token_order.jl") # token_order_pass!
include("codegen/kernel.jl")
include("codegen/control_flow.jl")
include("codegen/statements.jl")
Expand Down
361 changes: 125 additions & 236 deletions src/compiler/codegen/control_flow.jl

Large diffs are not rendered by default.

104 changes: 104 additions & 0 deletions src/compiler/codegen/irutils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# StructuredIRCode / SSAMap mutation utilities
#
# Helpers for passes that modify the structured IR in place.
# Inspired by Julia's IncrementalCompact (Compiler/src/ssair/ir.jl).

"""
new_ssa_idx!(sci::StructuredIRCode) -> Int

Allocate a fresh SSA index from the StructuredIRCode.
"""
function new_ssa_idx!(sci::StructuredIRCode)
sci.max_ssa_idx += 1
return sci.max_ssa_idx
end

"""
new_block_arg!(block::Block, sci::StructuredIRCode, @nospecialize(typ)) -> BlockArg

Add a new BlockArg to a block, allocating a fresh ID.
"""
function new_block_arg!(block::Block, sci::StructuredIRCode, @nospecialize(typ))
id = new_ssa_idx!(sci)
arg = BlockArg(id, typ)
push!(block.args, arg)
return arg
end

"""
Base.pushfirst!(m::SSAMap, (idx, stmt, typ)::Tuple{Int,Any,Any})

Prepend a statement at the beginning of an SSAMap.
"""
function Base.pushfirst!(m::SSAMap, (idx, stmt, typ)::Tuple{Int,Any,Any})
pushfirst!(m.ssa_idxes, idx)
pushfirst!(m.stmts, stmt)
pushfirst!(m.types, typ)
return nothing
end

"""
insert_before!(m::SSAMap, before_idx::Int, new_idx::Int, stmt, typ)

Insert a new entry before the entry with SSA index `before_idx`.
"""
function insert_before!(m::SSAMap, before_idx::Int, new_idx::Int, stmt, typ)
pos = findfirst(==(before_idx), m.ssa_idxes)
pos === nothing && throw(KeyError(before_idx))
insert!(m.ssa_idxes, pos, new_idx)
insert!(m.stmts, pos, stmt)
insert!(m.types, pos, typ)
return nothing
end

"""
insert_after!(m::SSAMap, after_idx::Int, new_idx::Int, stmt, typ)

Insert a new entry after the entry with SSA index `after_idx`.
"""
function insert_after!(m::SSAMap, after_idx::Int, new_idx::Int, stmt, typ)
pos = findfirst(==(after_idx), m.ssa_idxes)
pos === nothing && throw(KeyError(after_idx))
insert!(m.ssa_idxes, pos + 1, new_idx)
insert!(m.stmts, pos + 1, stmt)
insert!(m.types, pos + 1, typ)
return nothing
end

"""
update_type!(m::SSAMap, ssa_idx::Int, @nospecialize(new_type))

Update the type annotation for an existing SSAMap entry.
"""
function update_type!(m::SSAMap, ssa_idx::Int, @nospecialize(new_type))
pos = findfirst(==(ssa_idx), m.ssa_idxes)
pos === nothing && throw(KeyError(ssa_idx))
m.types[pos] = new_type
return nothing
end

"""
resolve_call(stmt) -> (resolved_func, operands) or nothing

Extract the resolved function and operands from a `:call` or `:invoke` Expr.
Shared by alias analysis and token ordering passes.
"""
function resolve_call(stmt)
stmt isa Expr || return nothing
if stmt.head === :call
func_ref = stmt.args[1]
operands = @view stmt.args[2:end]
elseif stmt.head === :invoke
func_ref = stmt.args[2]
operands = @view stmt.args[3:end]
else
return nothing
end
resolved = if func_ref isa GlobalRef
try; getfield(func_ref.mod, func_ref.name); catch; nothing; end
else
func_ref
end
resolved === nothing && return nothing
return (resolved, operands)
end
17 changes: 10 additions & 7 deletions src/compiler/codegen/kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,17 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
create_tensor_views!(ctx, arg_idx, argtype, Int[])
end

# Create memory ordering token
token_type = Token(tt)
ctx.token_type = token_type
ctx.token = encode_MakeTokenOp!(cb, token_type)

# Hoist early returns out of IfOp regions (tileiras rejects ReturnOp inside IfOp)
# Hoist early returns BEFORE token ordering — hoist_returns! rewrites
# ReturnNode terminators to YieldOp, which the token pass then extends.
hoist_returns!(ctx.sci.entry)

# Run alias analysis and token ordering pass on the structured IR.
alias_result = alias_analysis_pass!(sci)
token_order_pass!(sci, alias_result)

# Cache the token bytecode type for codegen
ctx.token_type = Token(tt)

# Emit the structured IR (uses original Julia SSA indices everywhere)
emit_block!(ctx, ctx.sci.entry)

Expand Down Expand Up @@ -314,7 +317,7 @@ function emit_subprogram!(ctx::CGCtx, func, arg_types::Vector,

# 3. Create sub-context
sub_ctx = CGCtx(; ctx.cb, ctx.tt, sci,
ctx.token, ctx.token_type,
ctx.token_type,
ctx.type_cache, ctx.sm_arch,
ctx.cache)

Expand Down
236 changes: 236 additions & 0 deletions src/compiler/codegen/passes/alias_analysis.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
# Alias Analysis Pass
#
# Fixed-point alias analysis over StructuredIRCode. Determines which memory
# operations may access the same underlying data (i.e., which SSA values
# point into the same allocation).
#
# WHY: The token ordering pass needs alias information to decide which memory
# operations require token dependencies between them. Without alias analysis,
# all memory ops would be serialized through a single token chain — correct,
# but overly conservative. With per-alias-set information, independent memory
# regions (e.g., separate kernel arguments) get independent token chains,
# enabling more parallelism in the generated Tile IR.
#
# HOW: Each pointer-containing kernel argument starts in its own alias set.
# Alias sets propagate forward through:
# - getfield (for TileArray.ptr field access)
# - pointer arithmetic (+, -)
# - view constructors (make_tensor_view, make_partition_view)
# - pointer passthroughs (bitcast, assume_aligned, etc.)
# Unknown operations conservatively produce ALIAS_UNIVERSE (may alias anything).
# Fixed-point iteration handles back-edges from loops.
#
# OUTPUT: Dict{Any, AliasSet} mapping SSA values and Arguments to their alias
# sets, consumed by token_order_pass!.

"""
AliasTracker

Tracks alias sets for each SSA value during fixed-point analysis.
"""
mutable struct AliasTracker
dirty::Bool
aliases::Dict{Any, AliasSet} # SSAValue/Argument/SlotNumber -> AliasSet
end

AliasTracker() = AliasTracker(false, Dict{Any, AliasSet}())

function Base.getindex(tracker::AliasTracker, key)
return get(tracker.aliases, key, ALIAS_UNIVERSE)
end

function Base.setindex!(tracker::AliasTracker, value::AliasSet, key)
current = get(tracker.aliases, key, nothing)
if current !== value
tracker.dirty = true
tracker.aliases[key] = value
end
return
end

"""
alias_analysis_pass!(sci::StructuredIRCode) -> Dict{Any, AliasSet}

Perform fixed-point alias analysis on structured IR.
Returns mapping from SSA values to alias sets.
"""
function alias_analysis_pass!(sci::StructuredIRCode)
tracker = AliasTracker()

# Initialize: each argument gets its own alias set
for (idx, argtype) in enumerate(sci.argtypes)
argtype_unwrapped = CC.widenconst(argtype)
if contains_pointers(argtype_unwrapped)
arg_ref = Argument(idx)
tracker[arg_ref] = Set{Any}([arg_ref])
end
end

# Fixed-point iteration
iteration = 0
max_iterations = 100

tracker.dirty = true
while tracker.dirty && iteration < max_iterations
tracker.dirty = false
iteration += 1

analyze_block!(tracker, sci.entry)
end

@debug "Alias analysis converged in $iteration iterations"

return tracker.aliases
end

"""
propagate!(tracker::AliasTracker, from, to)

Propagate alias set from `from` to `to`.
Uses direct assignment when `to` is uninitialized, union otherwise.
"""
function propagate!(tracker::AliasTracker, from, to)
from_aliases = tracker[from]

if from_aliases === ALIAS_UNIVERSE
# Propagating UNIVERSE is always conservative
tracker[to] = ALIAS_UNIVERSE
return
end

if haskey(tracker.aliases, to)
# Target already has an alias set union with it
to_aliases = tracker.aliases[to]
new_aliases = union(from_aliases, to_aliases)
if new_aliases != to_aliases
tracker[to] = new_aliases
end
else
# Target not yet analyzed assign directly
tracker[to] = from_aliases
end
return
end

"""
analyze_block!(tracker::AliasTracker, block)

Analyze all statements in a block, recursing into nested control flow.
"""
function analyze_block!(tracker::AliasTracker, block)
for (ssa_idx, entry) in block.body
if entry.stmt isa ControlFlowOp
analyze_control_flow!(tracker, entry.stmt)
else
analyze_statement!(tracker, SSAValue(ssa_idx), entry.stmt)
end
end
return
end

# Recurse into nested control flow regions
function analyze_control_flow!(tracker::AliasTracker, op::IfOp)
analyze_block!(tracker, op.then_region)
return analyze_block!(tracker, op.else_region)
end

function analyze_control_flow!(tracker::AliasTracker, op::ForOp)
return analyze_block!(tracker, op.body)
end

function analyze_control_flow!(tracker::AliasTracker, op::WhileOp)
analyze_block!(tracker, op.before)
return analyze_block!(tracker, op.after)
end

function analyze_control_flow!(tracker::AliasTracker, op::LoopOp)
return analyze_block!(tracker, op.body)
end

# Fallback for unknown control flow ops
function analyze_control_flow!(::AliasTracker, ::ControlFlowOp)
return
end

"""
analyze_statement!(tracker::AliasTracker, ssa::SSAValue, stmt)

Analyze a single statement and propagate aliases.
Handles both `:call` and `:invoke` expression forms.
"""
function analyze_statement!(tracker::AliasTracker, ssa::SSAValue, stmt)
call = resolve_call(stmt)
if call !== nothing
resolved_func, operands = call

# Also need the raw func ref for GlobalRef comparisons
func = stmt.head === :call ? stmt.args[1] : stmt.args[2]

# getfield: propagate from parent
if func === GlobalRef(Core, :getfield) && length(operands) >= 1
field = length(operands) >= 2 ? operands[2] : nothing

# For TileArray.ptr field access, propagate pointer alias
if field isa QuoteNode && field.value === :ptr
propagate!(tracker, operands[1], ssa)
else
# Conservatively mark as UNIVERSE for non-pointer fields
tracker[ssa] = ALIAS_UNIVERSE
end

# Pointer arithmetic: propagate from pointer operand
elseif func === GlobalRef(Base, :+) || func === GlobalRef(Base, :-)
for arg in operands
# Find the pointer argument and propagate
arg_aliases = tracker[arg]
if arg_aliases !== ALIAS_UNIVERSE && arg_aliases isa Set
propagate!(tracker, arg, ssa)
break
end
end

# View construction: propagate alias from first operand
elseif is_view_constructor(resolved_func) || is_pointer_passthrough(resolved_func)
if length(operands) >= 1
propagate!(tracker, operands[1], ssa)
end

# Default: unknown operation -> UNIVERSE
else
tracker[ssa] = ALIAS_UNIVERSE
end

elseif stmt isa ReturnNode
# No alias propagation needed

else
# Unknown statement type -> conservative
tracker[ssa] = ALIAS_UNIVERSE
end
return
end

# Helper functions
contains_pointers(T) = T <: Ptr || T <: TileArray || (T <: Tile && eltype(T) <: Ptr)

"""
is_view_constructor(func) -> Bool

Check if a resolved function is a tensor/partition view constructor.
These propagate alias identity from their first operand.
"""
function is_view_constructor(func)
return func === Intrinsics.make_tensor_view ||
func === Intrinsics.make_partition_view
end

function is_pointer_passthrough(func)
func === GlobalRef(Core.Intrinsics, :bitcast) && return true

# Safely check by name to avoid UndefVarError if intrinsics aren't exposed
if func isa Core.IntrinsicFunction || func isa Function
n = nameof(func)
return n === :bitcast || n === :assume_div_by || n === :assume_bounded || n === :assume_aligned
end
return false
end
Loading
Loading