Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/bytecode/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,10 @@ F8E4M3FN(table::TypeTable) = simple_type!(table, SimpleType.F8E4M3FN)
F8E5M2(table::TypeTable) = simple_type!(table, SimpleType.F8E5M2)
Token(table::TypeTable) = simple_type!(table, SimpleType.Token)

function tile_type!(table::TypeTable, dtype::TypeId, shape::AbstractVector{<:Integer})
function tile_type!(table::TypeTable, dtype::TypeId, shape::TileShape)
buf = UInt8[CompositeType.Tile]
encode_varint!(buf, dtype.id)
encode_int_list!(buf, shape, 8) # 8-byte integers
encode_int_list!(buf, collect(shape), 8) # 8-byte integers
_get_or_create!(table, buf)
end

Expand All @@ -145,22 +145,22 @@ function pointer_type!(table::TypeTable, pointee::TypeId)
end

function tensor_view_type!(table::TypeTable, dtype::TypeId,
shape::AbstractVector{<:Integer},
shape::TileShape,
strides::AbstractVector{<:Integer})
buf = UInt8[CompositeType.TensorView]
encode_varint!(buf, dtype.id)
encode_int_list!(buf, shape, 8)
encode_int_list!(buf, collect(shape), 8)
encode_int_list!(buf, strides, 8)
_get_or_create!(table, buf)
end

function partition_view_type!(table::TypeTable,
tile_shape::AbstractVector{<:Integer},
tile_shape::TileShape,
tensor_view::TypeId,
dim_map::AbstractVector{<:Integer},
padding_value::PaddingValue.T)
buf = UInt8[CompositeType.PartitionView]
encode_int_list!(buf, tile_shape, 4) # 4-byte integers
encode_int_list!(buf, collect(tile_shape), 4) # 4-byte integers
encode_varint!(buf, tensor_view.id)
encode_int_list!(buf, dim_map, 4)
encode_padding_value!(buf, padding_value)
Expand Down
10 changes: 5 additions & 5 deletions src/compiler/codegen/control_flow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type),
# Map carried values (body.args)
for i in 1:n_carries
body_arg = body_blk.args[i]
shape = extract_tile_shape(body_arg.type)
shape = RowMajorShape(extract_tile_shape(body_arg.type))
tv = CGVal(block_args[i + 1], result_types[i], body_arg.type, shape)
ctx[body_arg] = tv
end
Expand Down Expand Up @@ -240,7 +240,7 @@ function emit_loop_op!(ctx::CGCtx, op::LoopOp, @nospecialize(parent_result_type)
# Map carried values (body.args)
for i in 1:n_carries
body_arg = body_blk.args[i]
shape = extract_tile_shape(body_arg.type)
shape = RowMajorShape(extract_tile_shape(body_arg.type))
ctx[body_arg] = CGVal(block_args[i], result_types[i], body_arg.type, shape)
end

Expand Down Expand Up @@ -315,7 +315,7 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ
# Map carried values (before.args)
for i in 1:n_carries
before_arg = before_blk.args[i]
shape = extract_tile_shape(before_arg.type)
shape = RowMajorShape(extract_tile_shape(before_arg.type))
ctx[before_arg] = CGVal(block_args[i], result_types[i], before_arg.type, shape)
end

Expand Down Expand Up @@ -368,7 +368,7 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ
if tv !== nothing
ctx[after_arg] = tv
else
shape = extract_tile_shape(after_arg.type)
shape = RowMajorShape(extract_tile_shape(after_arg.type))
ctx[after_arg] = CGVal(block_args[i], result_types[i], after_arg.type, shape)
end
end
Expand Down Expand Up @@ -525,6 +525,6 @@ function emit_loop_getfield!(ctx::CGCtx, args::Vector{Any})
v = ref_cgval.v[field_idx]
elem_type = ref_cgval.jltype.parameters[field_idx]
type_id = tile_type_for_julia!(ctx, elem_type)
shape = extract_tile_shape(elem_type)
shape = RowMajorShape(extract_tile_shape(elem_type))
CGVal(v, type_id, elem_type, shape)
end
102 changes: 18 additions & 84 deletions src/compiler/codegen/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,64 +2,6 @@
#
# Core types (CGVal, CGCtx) and helper functions for Tile IR code generation.

#=============================================================================
Type-safe shape wrappers: Julia (column-major) ↔ Tile IR (row-major)
=============================================================================#

# Tile IR is natively row-major: shapes are stored with the slowest-varying dimension first.
# Julia is column-major: shapes are stored with the fastest-varying dimension first.
# Converting between them is a simple reversal. The Shape{O} wrapper ensures we don't
# accidentally mix up conventions — IR operations accept only RowMajorShape, while
# user-facing shapes from Julia are ColMajorShape.

abstract type StorageOrder end
struct RowMajor <: StorageOrder end
struct ColMajor <: StorageOrder end

struct Shape{O<:StorageOrder}
dims::Vector{Int}
end

const RowMajorShape = Shape{RowMajor}
const ColMajorShape = Shape{ColMajor}

# Conversion constructors
RowMajorShape(s::RowMajorShape) = s
RowMajorShape(s::ColMajorShape) = RowMajorShape(reverse(s.dims))
RowMajorShape(t::Tuple) = RowMajorShape(ColMajorShape(collect(Int, t)))

ColMajorShape(s::ColMajorShape) = s
ColMajorShape(s::RowMajorShape) = ColMajorShape(reverse(s.dims))
ColMajorShape(t::Tuple) = ColMajorShape(collect(Int, t))

# Forward common operations to .dims
Base.length(s::Shape) = length(s.dims)
Base.isempty(s::Shape) = isempty(s.dims)
Base.getindex(s::Shape, i) = s.dims[i]
Base.setindex!(s::Shape, v, i) = (s.dims[i] = v; s)
Base.copy(s::Shape{O}) where O = Shape{O}(copy(s.dims))
Base.:(==)(a::Shape{O}, b::Shape{O}) where O = a.dims == b.dims
Base.iterate(s::Shape, state...) = iterate(s.dims, state...)
Base.eachindex(s::Shape) = eachindex(s.dims)
Base.collect(s::Shape) = s.dims
TupleType(s::Shape) = Tuple{s.dims...}

# Scalar (0-D) shape — storage order is irrelevant for zero dimensions
struct ScalarShape end
Base.length(::ScalarShape) = 0
Base.isempty(::ScalarShape) = true
Base.collect(::ScalarShape) = Int[]
TupleType(::ScalarShape) = Tuple{}

Base.:(==)(::ScalarShape, ::ScalarShape) = true
Base.:(==)(::ScalarShape, ::Shape) = false
Base.:(==)(::Shape, ::ScalarShape) = false

# Cross-type conversions (must be after ScalarShape definition)
ColMajorShape(::ScalarShape) = ColMajorShape(Int[])

const TileShape = Union{RowMajorShape, ScalarShape}

#=============================================================================
IRError: Exception type for IR compilation errors
=============================================================================#
Expand Down Expand Up @@ -403,30 +345,30 @@ end
function _tile_type_for_julia!(tt::TypeTable, @nospecialize(T::Type))
# Scalar types -> 0-D tile
if T === Bool
return tile_type!(tt, I1(tt), Int[])
return tile_type!(tt, I1(tt), ScalarShape())
elseif T === Int8 || T === UInt8
return tile_type!(tt, I8(tt), Int[])
return tile_type!(tt, I8(tt), ScalarShape())
elseif T === Int16 || T === UInt16
return tile_type!(tt, I16(tt), Int[])
return tile_type!(tt, I16(tt), ScalarShape())
elseif T === Int32 || T === UInt32
return tile_type!(tt, I32(tt), Int[])
return tile_type!(tt, I32(tt), ScalarShape())
elseif T === Int64 || T === UInt64
return tile_type!(tt, I64(tt), Int[])
return tile_type!(tt, I64(tt), ScalarShape())
elseif T === Float16
return tile_type!(tt, F16(tt), Int[])
return tile_type!(tt, F16(tt), ScalarShape())
elseif T === BFloat16
return tile_type!(tt, BF16(tt), Int[])
return tile_type!(tt, BF16(tt), ScalarShape())
elseif T === Float32
return tile_type!(tt, F32(tt), Int[])
return tile_type!(tt, F32(tt), ScalarShape())
elseif T === Float64
return tile_type!(tt, F64(tt), Int[])
return tile_type!(tt, F64(tt), ScalarShape())
end

# Pointers -> 0-D tile of pointer type
if T <: Ptr
elem_dtype = julia_to_tile_dtype!(tt, eltype(T))
ptr_type = pointer_type!(tt, elem_dtype)
return tile_type!(tt, ptr_type, Int[])
return tile_type!(tt, ptr_type, ScalarShape())
end

# Tile{T, Shape} -> tile type with shape
Expand All @@ -440,30 +382,22 @@ function _tile_type_for_julia!(tt::TypeTable, @nospecialize(T::Type))
throw(IRError("Tile shape must be a tuple, got: $shape_param"))
end
elem_dtype = julia_to_tile_dtype!(tt, eltype(T))
shape = RowMajorShape(shape_param)
return tile_type!(tt, elem_dtype, collect(shape))
shape = RowMajorShape(ColMajorShape(shape_param))
return tile_type!(tt, elem_dtype, shape)
end

return nothing
end

"""
tile_type_and_shape_for_julia!(ctx, T) -> (TypeId, RowMajorShape)
tile_type_and_shape_for_julia!(ctx, T) -> (TypeId, TileShape)

Get the Tile IR type and shape for a Julia type.
"""
function tile_type_and_shape_for_julia!(ctx::CGCtx, @nospecialize(T))
actual_type = CC.widenconst(T)
type_id = tile_type_for_julia!(ctx, actual_type)

# Extract shape from Tile types (in Tile IR row-major order)
shape = if actual_type <: Tile
RowMajorShape(size(actual_type))
else
ScalarShape()
end

return (type_id, shape)
return (type_id, RowMajorShape(extract_tile_shape(actual_type)))
end

#=============================================================================
Expand Down Expand Up @@ -547,15 +481,15 @@ end
#-----------------------------------------------------------------------------

"""
extract_tile_shape(T) -> RowMajorShape
extract_tile_shape(T) -> Union{ColMajorShape, ScalarShape}

Extract shape from a Tile{T, Shape} type in Tile IR (row-major) order.
Returns empty shape if not a Tile type.
Extract shape from a Tile{T, Shape} type in Julia's column-major convention.
Returns ScalarShape for non-Tile types.
"""
function extract_tile_shape(@nospecialize(T))
T = CC.widenconst(T)
if T <: Tile
return RowMajorShape(size(T))
return ColMajorShape(size(T))
end
ScalarShape()
end
12 changes: 6 additions & 6 deletions src/compiler/intrinsics/arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ function emit_binop!(ctx::CGCtx, args, encoder::Function; kwargs...)
dtype = julia_to_tile_dtype!(tt, elem_type)
if isempty(lhs_tv.shape)
bv = broadcast_tile_to_shape!(cb, tt, lhs_tv, result_shape, dtype)
lhs_tv = CGVal(bv, tile_type!(tt, dtype, collect(result_shape)), elem_type,
lhs_tv = CGVal(bv, tile_type!(tt, dtype, result_shape), elem_type,
result_shape, nothing, lhs_tv.constant, nothing)
elseif isempty(rhs_tv.shape)
bv = broadcast_tile_to_shape!(cb, tt, rhs_tv, result_shape, dtype)
rhs_tv = CGVal(bv, tile_type!(tt, dtype, collect(result_shape)), elem_type,
rhs_tv = CGVal(bv, tile_type!(tt, dtype, result_shape), elem_type,
result_shape, nothing, rhs_tv.constant, nothing)
end
else
Expand All @@ -68,7 +68,7 @@ function emit_binop!(ctx::CGCtx, args, encoder::Function; kwargs...)
end

dtype = julia_to_tile_dtype!(tt, elem_type)
result_type_id = tile_type!(tt, dtype, collect(result_shape))
result_type_id = tile_type!(tt, dtype, result_shape)

result_v = encoder(cb, result_type_id, lhs_tv.v, rhs_tv.v; kwargs...)

Expand All @@ -88,7 +88,7 @@ function emit_unop!(ctx::CGCtx, args, encoder::Function; kwargs...)
result_jltype = source.jltype

dtype = julia_to_tile_dtype!(tt, elem_type)
result_type_id = tile_type!(tt, dtype, collect(result_shape))
result_type_id = tile_type!(tt, dtype, result_shape)

result_v = encoder(cb, result_type_id, source.v; kwargs...)

Expand Down Expand Up @@ -149,7 +149,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.cmpi), args)
result_shape = lhs.shape

bool_dtype = I1(tt)
result_type_id = tile_type!(tt, bool_dtype, collect(result_shape))
result_type_id = tile_type!(tt, bool_dtype, result_shape)

result_v = encode_CmpIOp!(cb, result_type_id, lhs.v, rhs.v; predicate, signedness)
lhs_type = CC.widenconst(lhs.jltype)
Expand Down Expand Up @@ -295,7 +295,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.cmpf), args)
result_shape = lhs.shape

bool_dtype = I1(tt)
result_type_id = tile_type!(tt, bool_dtype, collect(result_shape))
result_type_id = tile_type!(tt, bool_dtype, result_shape)

result_v = encode_CmpFOp!(cb, result_type_id, lhs.v, rhs.v; predicate)
lhs_type = CC.widenconst(lhs.jltype)
Expand Down
4 changes: 2 additions & 2 deletions src/compiler/intrinsics/atomics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_cas), args)
elem_type = eltype(ptr_type)

dtype = julia_to_tile_dtype!(tt, elem_type)
result_tile_type = tile_type!(tt, dtype, collect(shape))
result_tile_type = tile_type!(tt, dtype, shape)
token_type = Token(tt)

# Emit atomic CAS
Expand Down Expand Up @@ -101,7 +101,7 @@ function emit_atomic_rmw!(ctx::CGCtx, args::AbstractVector, mode::AtomicRMWMode.

# Create result type
dtype = julia_to_tile_dtype!(tt, elem_type)
result_tile_type = tile_type!(tt, dtype, collect(shape))
result_tile_type = tile_type!(tt, dtype, shape)
token_type = Token(tt)

# Use float add mode for floating point types
Expand Down
10 changes: 5 additions & 5 deletions src/compiler/intrinsics/conversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.exti), args)
signedness = @something get_constant(ctx, args[3]) throw(IRError("exti: requires compile-time signedness"))

dtype = julia_to_tile_dtype!(tt, target_type)
result_type_id = tile_type!(tt, dtype, collect(source.shape))
result_type_id = tile_type!(tt, dtype, source.shape)

result_v = encode_ExtIOp!(cb, result_type_id, source.v; signedness)
src_type = CC.widenconst(source.jltype)
Expand All @@ -43,7 +43,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.ftof), args)
target_type = @something get_constant(ctx, args[2]) throw(IRError("ftof: requires compile-time target type"))

dtype = julia_to_tile_dtype!(tt, target_type)
result_type_id = tile_type!(tt, dtype, collect(source.shape))
result_type_id = tile_type!(tt, dtype, source.shape)

result_v = encode_FToFOp!(cb, result_type_id, source.v)
src_type = CC.widenconst(source.jltype)
Expand All @@ -68,7 +68,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.ftoi), args)
signedness = @something get_constant(ctx, args[3]) throw(IRError("ftoi: requires compile-time signedness"))

dtype = julia_to_tile_dtype!(tt, target_type)
result_type_id = tile_type!(tt, dtype, collect(source.shape))
result_type_id = tile_type!(tt, dtype, source.shape)

result_v = encode_FToIOp!(cb, result_type_id, source.v; signedness)
src_type = CC.widenconst(source.jltype)
Expand All @@ -93,7 +93,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.itof), args)
signedness = @something get_constant(ctx, args[3]) throw(IRError("itof: requires compile-time signedness"))

dtype = julia_to_tile_dtype!(tt, target_type)
result_type_id = tile_type!(tt, dtype, collect(source.shape))
result_type_id = tile_type!(tt, dtype, source.shape)

result_v = encode_IToFOp!(cb, result_type_id, source.v; signedness)
src_type = CC.widenconst(source.jltype)
Expand All @@ -117,7 +117,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.trunci), args)
target_type = @something get_constant(ctx, args[2]) throw(IRError("trunci: requires compile-time target type"))

dtype = julia_to_tile_dtype!(tt, target_type)
result_type_id = tile_type!(tt, dtype, collect(source.shape))
result_type_id = tile_type!(tt, dtype, source.shape)

result_v = encode_TruncIOp!(cb, result_type_id, source.v)
src_type = CC.widenconst(source.jltype)
Expand Down
Loading
Loading