Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/compiler/codegen/kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
# Scalar: emit ConstantOp
bytes = constant_to_bytes(val, T)
v = encode_ConstantOp!(ctx.cb, type_id, bytes)
tv = CGVal(v, type_id, T, Int[], nothing, Some(val), nothing)
tv = CGVal(v, type_id, T, ScalarShape(), nothing, Some(val), nothing)
else
# Non-primitive (tuple etc.): ghost with constant
tv = ghost_value(T, val)
Expand Down
95 changes: 77 additions & 18 deletions src/compiler/codegen/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,64 @@
#
# Core types (CGVal, CGCtx) and helper functions for Tile IR code generation.

#=============================================================================
Type-safe shape wrappers: Julia (column-major) ↔ Tile IR (row-major)
=============================================================================#

# Tile IR is natively row-major: shapes are stored with the slowest-varying dimension first.
# Julia is column-major: shapes are stored with the fastest-varying dimension first.
# Converting between them is a simple reversal. The Shape{O} wrapper ensures we don't
# accidentally mix up conventions — IR operations accept only RowMajorShape, while
# user-facing shapes from Julia are ColMajorShape.

abstract type StorageOrder end
struct RowMajor <: StorageOrder end
struct ColMajor <: StorageOrder end

struct Shape{O<:StorageOrder}
dims::Vector{Int}
end

const RowMajorShape = Shape{RowMajor}
const ColMajorShape = Shape{ColMajor}

# Conversion constructors
RowMajorShape(s::RowMajorShape) = s
RowMajorShape(s::ColMajorShape) = RowMajorShape(reverse(s.dims))
RowMajorShape(t::Tuple) = RowMajorShape(ColMajorShape(collect(Int, t)))

ColMajorShape(s::ColMajorShape) = s
ColMajorShape(s::RowMajorShape) = ColMajorShape(reverse(s.dims))
ColMajorShape(t::Tuple) = ColMajorShape(collect(Int, t))

# Forward common operations to .dims
Base.length(s::Shape) = length(s.dims)
Base.isempty(s::Shape) = isempty(s.dims)
Base.getindex(s::Shape, i) = s.dims[i]
Base.setindex!(s::Shape, v, i) = (s.dims[i] = v; s)
Base.copy(s::Shape{O}) where O = Shape{O}(copy(s.dims))
Base.:(==)(a::Shape{O}, b::Shape{O}) where O = a.dims == b.dims
Base.iterate(s::Shape, state...) = iterate(s.dims, state...)
Base.eachindex(s::Shape) = eachindex(s.dims)
Base.collect(s::Shape) = s.dims
TupleType(s::Shape) = Tuple{s.dims...}

# Scalar (0-D) shape — storage order is irrelevant for zero dimensions
struct ScalarShape end
Base.length(::ScalarShape) = 0
Base.isempty(::ScalarShape) = true
Base.collect(::ScalarShape) = Int[]
TupleType(::ScalarShape) = Tuple{}

Base.:(==)(::ScalarShape, ::ScalarShape) = true
Base.:(==)(::ScalarShape, ::Shape) = false
Base.:(==)(::Shape, ::ScalarShape) = false

# Cross-type conversions (must be after ScalarShape definition)
ColMajorShape(::ScalarShape) = ColMajorShape(Int[])

const TileShape = Union{RowMajorShape, ScalarShape}

#=============================================================================
IRError: Exception type for IR compilation errors
=============================================================================#
Expand Down Expand Up @@ -69,7 +127,7 @@ struct CGVal
v::Union{Value, Vector{Value}, Nothing} # Single value, multi-value, or nothing
type_id::Union{TypeId, Nothing} # Tile IR type (nothing for lazy refs or multi-value)
jltype::Any # Original Julia type
shape::Vector{Int} # Tile shape (empty for scalars)
shape::TileShape # Tile shape (ScalarShape for scalars)
# Lazy argument reference: (arg_idx, [field_indices...])
# e.g., (1, [2, 1]) means "argument 1, field 2, sub-field 1"
arg_ref::Union{Tuple{Int, Vector{Int}}, Nothing}
Expand All @@ -79,18 +137,18 @@ end

# Convenience constructors for concrete values
CGVal(v::Value, type_id::TypeId, @nospecialize(jltype)) =
CGVal(v, type_id, jltype, Int[], nothing, nothing, nothing)
CGVal(v, type_id, jltype, ScalarShape(), nothing, nothing, nothing)

CGVal(v::Value, type_id::TypeId, @nospecialize(jltype), shape::Vector{Int}) =
CGVal(v::Value, type_id::TypeId, @nospecialize(jltype), shape::TileShape) =
CGVal(v, type_id, jltype, shape, nothing, nothing, nothing)

# Constructor for multi-value results (from loops, ifs)
CGVal(v::Vector{Value}, @nospecialize(jltype)) =
CGVal(v, nothing, jltype, Int[], nothing, nothing, nothing)
CGVal(v, nothing, jltype, ScalarShape(), nothing, nothing, nothing)

# Constructor for lazy argument references
function arg_ref_value(arg_idx::Int, chain::Vector{Int}, @nospecialize(jltype))
CGVal(nothing, nothing, jltype, Int[], (arg_idx, chain), nothing, nothing)
CGVal(nothing, nothing, jltype, ScalarShape(), (arg_idx, chain), nothing, nothing)
end

"""
Expand All @@ -99,8 +157,8 @@ end
Create a ghost value (zero-size singleton with no runtime representation).
Optionally stores a compile-time constant value.
"""
ghost_value(@nospecialize(jltype)) = CGVal(nothing, TypeId(-1), jltype, Int[], nothing, nothing, nothing)
ghost_value(@nospecialize(jltype), constant) = CGVal(nothing, TypeId(-1), jltype, Int[], nothing, Some(constant), nothing)
ghost_value(@nospecialize(jltype)) = CGVal(nothing, TypeId(-1), jltype, ScalarShape(), nothing, nothing, nothing)
ghost_value(@nospecialize(jltype), constant) = CGVal(nothing, TypeId(-1), jltype, ScalarShape(), nothing, Some(constant), nothing)

"""
tuple_value(jltype, component_refs, component_constants) -> CGVal
Expand All @@ -115,7 +173,7 @@ function tuple_value(@nospecialize(jltype), component_refs::Vector{Any}, compone
else
nothing
end
CGVal(nothing, TypeId(-1), jltype, Int[], nothing, constant, component_refs)
CGVal(nothing, TypeId(-1), jltype, ScalarShape(), nothing, constant, component_refs)
end

"""
Expand Down Expand Up @@ -382,27 +440,27 @@ function _tile_type_for_julia!(tt::TypeTable, @nospecialize(T::Type))
throw(IRError("Tile shape must be a tuple, got: $shape_param"))
end
elem_dtype = julia_to_tile_dtype!(tt, eltype(T))
shape = collect(Int, shape_param)
return tile_type!(tt, elem_dtype, shape)
shape = RowMajorShape(shape_param)
return tile_type!(tt, elem_dtype, collect(shape))
end

return nothing
end

"""
tile_type_and_shape_for_julia!(ctx, T) -> (TypeId, Vector{Int})
tile_type_and_shape_for_julia!(ctx, T) -> (TypeId, RowMajorShape)

Get the Tile IR type and shape for a Julia type.
"""
function tile_type_and_shape_for_julia!(ctx::CGCtx, @nospecialize(T))
actual_type = CC.widenconst(T)
type_id = tile_type_for_julia!(ctx, actual_type)

# Extract shape from Tile types
# Extract shape from Tile types (in Tile IR row-major order)
shape = if actual_type <: Tile
collect(Int, size(actual_type))
RowMajorShape(size(actual_type))
else
Int[]
ScalarShape()
end

return (type_id, shape)
Expand Down Expand Up @@ -489,14 +547,15 @@ end
#-----------------------------------------------------------------------------

"""
extract_tile_shape(T) -> Vector{Int}
extract_tile_shape(T) -> RowMajorShape

Extract shape from a Tile{T, Shape} type, returning Int[] if not a Tile type.
Extract shape from a Tile{T, Shape} type in Tile IR (row-major) order.
Returns empty shape if not a Tile type.
"""
function extract_tile_shape(@nospecialize(T))
T = CC.widenconst(T)
if T <: Tile
return collect(Int, size(T))
return RowMajorShape(size(T))
end
Int[]
ScalarShape()
end
6 changes: 3 additions & 3 deletions src/compiler/codegen/values.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@ function emit_value!(ctx::CGCtx, val::Integer)
type_id = tile_type_for_julia!(ctx, jltype)
bytes = reinterpret(UInt8, [jltype(val)])
v = encode_ConstantOp!(ctx.cb, type_id, collect(bytes))
CGVal(v, type_id, jltype, Int[], nothing, Some(val), nothing)
CGVal(v, type_id, jltype, ScalarShape(), nothing, Some(val), nothing)
end

function emit_value!(ctx::CGCtx, val::AbstractFloat)
jltype = typeof(val)
type_id = tile_type_for_julia!(ctx, jltype)
bytes = reinterpret(UInt8, [jltype(val)])
v = encode_ConstantOp!(ctx.cb, type_id, collect(bytes))
CGVal(v, type_id, jltype, Int[], nothing, Some(val), nothing)
CGVal(v, type_id, jltype, ScalarShape(), nothing, Some(val), nothing)
end

function emit_value!(ctx::CGCtx, node::QuoteNode)
Expand Down Expand Up @@ -67,7 +67,7 @@ function emit_value!(ctx::CGCtx, ref::GlobalRef)
if type_id !== nothing
bytes = constant_to_bytes(val, T)
v = encode_ConstantOp!(ctx.cb, type_id, bytes)
return CGVal(v, type_id, T, Int[], nothing, Some(val), nothing)
return CGVal(v, type_id, T, ScalarShape(), nothing, Some(val), nothing)
end
end
ghost_value(T, val)
Expand Down
14 changes: 7 additions & 7 deletions src/compiler/intrinsics/arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,23 +52,23 @@ function emit_binop!(ctx::CGCtx, args, encoder::Function; kwargs...)
dtype = julia_to_tile_dtype!(tt, elem_type)
if isempty(lhs_tv.shape)
bv = broadcast_tile_to_shape!(cb, tt, lhs_tv, result_shape, dtype)
lhs_tv = CGVal(bv, tile_type!(tt, dtype, result_shape), elem_type,
lhs_tv = CGVal(bv, tile_type!(tt, dtype, collect(result_shape)), elem_type,
result_shape, nothing, lhs_tv.constant, nothing)
elseif isempty(rhs_tv.shape)
bv = broadcast_tile_to_shape!(cb, tt, rhs_tv, result_shape, dtype)
rhs_tv = CGVal(bv, tile_type!(tt, dtype, result_shape), elem_type,
rhs_tv = CGVal(bv, tile_type!(tt, dtype, collect(result_shape)), elem_type,
result_shape, nothing, rhs_tv.constant, nothing)
end
else
result_shape = Int[]
result_shape = ScalarShape()
end
result_jltype = lhs_tv.jltype
else
throw(IRError("Mixed tile/scalar operations should be handled at intrinsic level via Tile() and broadcast_to()"))
end

dtype = julia_to_tile_dtype!(tt, elem_type)
result_type_id = tile_type!(tt, dtype, result_shape)
result_type_id = tile_type!(tt, dtype, collect(result_shape))

result_v = encoder(cb, result_type_id, lhs_tv.v, rhs_tv.v; kwargs...)

Expand All @@ -88,7 +88,7 @@ function emit_unop!(ctx::CGCtx, args, encoder::Function; kwargs...)
result_jltype = source.jltype

dtype = julia_to_tile_dtype!(tt, elem_type)
result_type_id = tile_type!(tt, dtype, result_shape)
result_type_id = tile_type!(tt, dtype, collect(result_shape))

result_v = encoder(cb, result_type_id, source.v; kwargs...)

Expand Down Expand Up @@ -149,7 +149,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.cmpi), args)
result_shape = lhs.shape

bool_dtype = I1(tt)
result_type_id = tile_type!(tt, bool_dtype, result_shape)
result_type_id = tile_type!(tt, bool_dtype, collect(result_shape))

result_v = encode_CmpIOp!(cb, result_type_id, lhs.v, rhs.v; predicate, signedness)
lhs_type = CC.widenconst(lhs.jltype)
Expand Down Expand Up @@ -295,7 +295,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.cmpf), args)
result_shape = lhs.shape

bool_dtype = I1(tt)
result_type_id = tile_type!(tt, bool_dtype, result_shape)
result_type_id = tile_type!(tt, bool_dtype, collect(result_shape))

result_v = encode_CmpFOp!(cb, result_type_id, lhs.v, rhs.v; predicate)
lhs_type = CC.widenconst(lhs.jltype)
Expand Down
6 changes: 4 additions & 2 deletions src/compiler/intrinsics/atomics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_cas), args)
end
ctx.token = new_token

CGVal(old_val, result_tile_type, Tile{elem_type, Tuple{shape...}}, collect(shape))
julia_shape = ColMajorShape(shape)
CGVal(old_val, result_tile_type, Tile{elem_type, TupleType(julia_shape)}, shape)
end

# cuda_tile.atomic_rmw_tko (shared helper for atomic RMW operations)
Expand Down Expand Up @@ -129,7 +130,8 @@ function emit_atomic_rmw!(ctx::CGCtx, args::AbstractVector, mode::AtomicRMWMode.
end
ctx.token = new_token

CGVal(old_val, result_tile_type, Tile{elem_type, Tuple{shape...}}, collect(shape))
julia_shape = ColMajorShape(shape)
CGVal(old_val, result_tile_type, Tile{elem_type, TupleType(julia_shape)}, shape)
end

# cuda_tile.atomic_rmw_tko variants
Expand Down
10 changes: 5 additions & 5 deletions src/compiler/intrinsics/conversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.exti), args)
signedness = @something get_constant(ctx, args[3]) throw(IRError("exti: requires compile-time signedness"))

dtype = julia_to_tile_dtype!(tt, target_type)
result_type_id = tile_type!(tt, dtype, source.shape)
result_type_id = tile_type!(tt, dtype, collect(source.shape))

result_v = encode_ExtIOp!(cb, result_type_id, source.v; signedness)
src_type = CC.widenconst(source.jltype)
Expand All @@ -43,7 +43,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.ftof), args)
target_type = @something get_constant(ctx, args[2]) throw(IRError("ftof: requires compile-time target type"))

dtype = julia_to_tile_dtype!(tt, target_type)
result_type_id = tile_type!(tt, dtype, source.shape)
result_type_id = tile_type!(tt, dtype, collect(source.shape))

result_v = encode_FToFOp!(cb, result_type_id, source.v)
src_type = CC.widenconst(source.jltype)
Expand All @@ -68,7 +68,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.ftoi), args)
signedness = @something get_constant(ctx, args[3]) throw(IRError("ftoi: requires compile-time signedness"))

dtype = julia_to_tile_dtype!(tt, target_type)
result_type_id = tile_type!(tt, dtype, source.shape)
result_type_id = tile_type!(tt, dtype, collect(source.shape))

result_v = encode_FToIOp!(cb, result_type_id, source.v; signedness)
src_type = CC.widenconst(source.jltype)
Expand All @@ -93,7 +93,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.itof), args)
signedness = @something get_constant(ctx, args[3]) throw(IRError("itof: requires compile-time signedness"))

dtype = julia_to_tile_dtype!(tt, target_type)
result_type_id = tile_type!(tt, dtype, source.shape)
result_type_id = tile_type!(tt, dtype, collect(source.shape))

result_v = encode_IToFOp!(cb, result_type_id, source.v; signedness)
src_type = CC.widenconst(source.jltype)
Expand All @@ -117,7 +117,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.trunci), args)
target_type = @something get_constant(ctx, args[2]) throw(IRError("trunci: requires compile-time target type"))

dtype = julia_to_tile_dtype!(tt, target_type)
result_type_id = tile_type!(tt, dtype, source.shape)
result_type_id = tile_type!(tt, dtype, collect(source.shape))

result_v = encode_TruncIOp!(cb, result_type_id, source.v)
src_type = CC.widenconst(source.jltype)
Expand Down
Loading
Loading