diff --git a/src/bytecode/types.jl b/src/bytecode/types.jl index c38ae07c..2a0de2b8 100644 --- a/src/bytecode/types.jl +++ b/src/bytecode/types.jl @@ -131,10 +131,10 @@ F8E4M3FN(table::TypeTable) = simple_type!(table, SimpleType.F8E4M3FN) F8E5M2(table::TypeTable) = simple_type!(table, SimpleType.F8E5M2) Token(table::TypeTable) = simple_type!(table, SimpleType.Token) -function tile_type!(table::TypeTable, dtype::TypeId, shape::AbstractVector{<:Integer}) +function tile_type!(table::TypeTable, dtype::TypeId, shape::TileShape) buf = UInt8[CompositeType.Tile] encode_varint!(buf, dtype.id) - encode_int_list!(buf, shape, 8) # 8-byte integers + encode_int_list!(buf, collect(shape), 8) # 8-byte integers _get_or_create!(table, buf) end @@ -145,22 +145,22 @@ function pointer_type!(table::TypeTable, pointee::TypeId) end function tensor_view_type!(table::TypeTable, dtype::TypeId, - shape::AbstractVector{<:Integer}, + shape::TileShape, strides::AbstractVector{<:Integer}) buf = UInt8[CompositeType.TensorView] encode_varint!(buf, dtype.id) - encode_int_list!(buf, shape, 8) + encode_int_list!(buf, collect(shape), 8) encode_int_list!(buf, strides, 8) _get_or_create!(table, buf) end function partition_view_type!(table::TypeTable, - tile_shape::AbstractVector{<:Integer}, + tile_shape::TileShape, tensor_view::TypeId, dim_map::AbstractVector{<:Integer}, padding_value::PaddingValue.T) buf = UInt8[CompositeType.PartitionView] - encode_int_list!(buf, tile_shape, 4) # 4-byte integers + encode_int_list!(buf, collect(tile_shape), 4) # 4-byte integers encode_varint!(buf, tensor_view.id) encode_int_list!(buf, dim_map, 4) encode_padding_value!(buf, padding_value) diff --git a/src/compiler/codegen/control_flow.jl b/src/compiler/codegen/control_flow.jl index 50eed22b..b5a4db3d 100644 --- a/src/compiler/codegen/control_flow.jl +++ b/src/compiler/codegen/control_flow.jl @@ -178,7 +178,7 @@ function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type), # Map carried values 
(body.args) for i in 1:n_carries body_arg = body_blk.args[i] - shape = extract_tile_shape(body_arg.type) + shape = RowMajorShape(extract_tile_shape(body_arg.type)) tv = CGVal(block_args[i + 1], result_types[i], body_arg.type, shape) ctx[body_arg] = tv end @@ -240,7 +240,7 @@ function emit_loop_op!(ctx::CGCtx, op::LoopOp, @nospecialize(parent_result_type) # Map carried values (body.args) for i in 1:n_carries body_arg = body_blk.args[i] - shape = extract_tile_shape(body_arg.type) + shape = RowMajorShape(extract_tile_shape(body_arg.type)) ctx[body_arg] = CGVal(block_args[i], result_types[i], body_arg.type, shape) end @@ -315,7 +315,7 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ # Map carried values (before.args) for i in 1:n_carries before_arg = before_blk.args[i] - shape = extract_tile_shape(before_arg.type) + shape = RowMajorShape(extract_tile_shape(before_arg.type)) ctx[before_arg] = CGVal(block_args[i], result_types[i], before_arg.type, shape) end @@ -368,7 +368,7 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ if tv !== nothing ctx[after_arg] = tv else - shape = extract_tile_shape(after_arg.type) + shape = RowMajorShape(extract_tile_shape(after_arg.type)) ctx[after_arg] = CGVal(block_args[i], result_types[i], after_arg.type, shape) end end @@ -525,6 +525,6 @@ function emit_loop_getfield!(ctx::CGCtx, args::Vector{Any}) v = ref_cgval.v[field_idx] elem_type = ref_cgval.jltype.parameters[field_idx] type_id = tile_type_for_julia!(ctx, elem_type) - shape = extract_tile_shape(elem_type) + shape = RowMajorShape(extract_tile_shape(elem_type)) CGVal(v, type_id, elem_type, shape) end diff --git a/src/compiler/codegen/utils.jl b/src/compiler/codegen/utils.jl index 0d9c22b9..393811b2 100644 --- a/src/compiler/codegen/utils.jl +++ b/src/compiler/codegen/utils.jl @@ -2,64 +2,6 @@ # # Core types (CGVal, CGCtx) and helper functions for Tile IR code generation. 
-#============================================================================= - Type-safe shape wrappers: Julia (column-major) ↔ Tile IR (row-major) -=============================================================================# - -# Tile IR is natively row-major: shapes are stored with the slowest-varying dimension first. -# Julia is column-major: shapes are stored with the fastest-varying dimension first. -# Converting between them is a simple reversal. The Shape{O} wrapper ensures we don't -# accidentally mix up conventions — IR operations accept only RowMajorShape, while -# user-facing shapes from Julia are ColMajorShape. - -abstract type StorageOrder end -struct RowMajor <: StorageOrder end -struct ColMajor <: StorageOrder end - -struct Shape{O<:StorageOrder} - dims::Vector{Int} -end - -const RowMajorShape = Shape{RowMajor} -const ColMajorShape = Shape{ColMajor} - -# Conversion constructors -RowMajorShape(s::RowMajorShape) = s -RowMajorShape(s::ColMajorShape) = RowMajorShape(reverse(s.dims)) -RowMajorShape(t::Tuple) = RowMajorShape(ColMajorShape(collect(Int, t))) - -ColMajorShape(s::ColMajorShape) = s -ColMajorShape(s::RowMajorShape) = ColMajorShape(reverse(s.dims)) -ColMajorShape(t::Tuple) = ColMajorShape(collect(Int, t)) - -# Forward common operations to .dims -Base.length(s::Shape) = length(s.dims) -Base.isempty(s::Shape) = isempty(s.dims) -Base.getindex(s::Shape, i) = s.dims[i] -Base.setindex!(s::Shape, v, i) = (s.dims[i] = v; s) -Base.copy(s::Shape{O}) where O = Shape{O}(copy(s.dims)) -Base.:(==)(a::Shape{O}, b::Shape{O}) where O = a.dims == b.dims -Base.iterate(s::Shape, state...) = iterate(s.dims, state...) 
-Base.eachindex(s::Shape) = eachindex(s.dims) -Base.collect(s::Shape) = s.dims -TupleType(s::Shape) = Tuple{s.dims...} - -# Scalar (0-D) shape — storage order is irrelevant for zero dimensions -struct ScalarShape end -Base.length(::ScalarShape) = 0 -Base.isempty(::ScalarShape) = true -Base.collect(::ScalarShape) = Int[] -TupleType(::ScalarShape) = Tuple{} - -Base.:(==)(::ScalarShape, ::ScalarShape) = true -Base.:(==)(::ScalarShape, ::Shape) = false -Base.:(==)(::Shape, ::ScalarShape) = false - -# Cross-type conversions (must be after ScalarShape definition) -ColMajorShape(::ScalarShape) = ColMajorShape(Int[]) - -const TileShape = Union{RowMajorShape, ScalarShape} - #============================================================================= IRError: Exception type for IR compilation errors =============================================================================# @@ -403,30 +345,30 @@ end function _tile_type_for_julia!(tt::TypeTable, @nospecialize(T::Type)) # Scalar types -> 0-D tile if T === Bool - return tile_type!(tt, I1(tt), Int[]) + return tile_type!(tt, I1(tt), ScalarShape()) elseif T === Int8 || T === UInt8 - return tile_type!(tt, I8(tt), Int[]) + return tile_type!(tt, I8(tt), ScalarShape()) elseif T === Int16 || T === UInt16 - return tile_type!(tt, I16(tt), Int[]) + return tile_type!(tt, I16(tt), ScalarShape()) elseif T === Int32 || T === UInt32 - return tile_type!(tt, I32(tt), Int[]) + return tile_type!(tt, I32(tt), ScalarShape()) elseif T === Int64 || T === UInt64 - return tile_type!(tt, I64(tt), Int[]) + return tile_type!(tt, I64(tt), ScalarShape()) elseif T === Float16 - return tile_type!(tt, F16(tt), Int[]) + return tile_type!(tt, F16(tt), ScalarShape()) elseif T === BFloat16 - return tile_type!(tt, BF16(tt), Int[]) + return tile_type!(tt, BF16(tt), ScalarShape()) elseif T === Float32 - return tile_type!(tt, F32(tt), Int[]) + return tile_type!(tt, F32(tt), ScalarShape()) elseif T === Float64 - return tile_type!(tt, F64(tt), Int[]) + return 
tile_type!(tt, F64(tt), ScalarShape()) end # Pointers -> 0-D tile of pointer type if T <: Ptr elem_dtype = julia_to_tile_dtype!(tt, eltype(T)) ptr_type = pointer_type!(tt, elem_dtype) - return tile_type!(tt, ptr_type, Int[]) + return tile_type!(tt, ptr_type, ScalarShape()) end # Tile{T, Shape} -> tile type with shape @@ -440,30 +382,22 @@ function _tile_type_for_julia!(tt::TypeTable, @nospecialize(T::Type)) throw(IRError("Tile shape must be a tuple, got: $shape_param")) end elem_dtype = julia_to_tile_dtype!(tt, eltype(T)) - shape = RowMajorShape(shape_param) - return tile_type!(tt, elem_dtype, collect(shape)) + shape = RowMajorShape(ColMajorShape(shape_param)) + return tile_type!(tt, elem_dtype, shape) end return nothing end """ - tile_type_and_shape_for_julia!(ctx, T) -> (TypeId, RowMajorShape) + tile_type_and_shape_for_julia!(ctx, T) -> (TypeId, TileShape) Get the Tile IR type and shape for a Julia type. """ function tile_type_and_shape_for_julia!(ctx::CGCtx, @nospecialize(T)) actual_type = CC.widenconst(T) type_id = tile_type_for_julia!(ctx, actual_type) - - # Extract shape from Tile types (in Tile IR row-major order) - shape = if actual_type <: Tile - RowMajorShape(size(actual_type)) - else - ScalarShape() - end - - return (type_id, shape) + return (type_id, RowMajorShape(extract_tile_shape(actual_type))) end #============================================================================= @@ -547,15 +481,15 @@ end #----------------------------------------------------------------------------- """ - extract_tile_shape(T) -> RowMajorShape + extract_tile_shape(T) -> Union{ColMajorShape, ScalarShape} -Extract shape from a Tile{T, Shape} type in Tile IR (row-major) order. -Returns empty shape if not a Tile type. +Extract shape from a Tile{T, Shape} type in Julia's column-major convention. +Returns ScalarShape for non-Tile types. 
""" function extract_tile_shape(@nospecialize(T)) T = CC.widenconst(T) if T <: Tile - return RowMajorShape(size(T)) + return ColMajorShape(size(T)) end ScalarShape() end diff --git a/src/compiler/intrinsics/arithmetic.jl b/src/compiler/intrinsics/arithmetic.jl index e4583292..16af7396 100644 --- a/src/compiler/intrinsics/arithmetic.jl +++ b/src/compiler/intrinsics/arithmetic.jl @@ -52,11 +52,11 @@ function emit_binop!(ctx::CGCtx, args, encoder::Function; kwargs...) dtype = julia_to_tile_dtype!(tt, elem_type) if isempty(lhs_tv.shape) bv = broadcast_tile_to_shape!(cb, tt, lhs_tv, result_shape, dtype) - lhs_tv = CGVal(bv, tile_type!(tt, dtype, collect(result_shape)), elem_type, + lhs_tv = CGVal(bv, tile_type!(tt, dtype, result_shape), elem_type, result_shape, nothing, lhs_tv.constant, nothing) elseif isempty(rhs_tv.shape) bv = broadcast_tile_to_shape!(cb, tt, rhs_tv, result_shape, dtype) - rhs_tv = CGVal(bv, tile_type!(tt, dtype, collect(result_shape)), elem_type, + rhs_tv = CGVal(bv, tile_type!(tt, dtype, result_shape), elem_type, result_shape, nothing, rhs_tv.constant, nothing) end else @@ -68,7 +68,7 @@ function emit_binop!(ctx::CGCtx, args, encoder::Function; kwargs...) end dtype = julia_to_tile_dtype!(tt, elem_type) - result_type_id = tile_type!(tt, dtype, collect(result_shape)) + result_type_id = tile_type!(tt, dtype, result_shape) result_v = encoder(cb, result_type_id, lhs_tv.v, rhs_tv.v; kwargs...) @@ -88,7 +88,7 @@ function emit_unop!(ctx::CGCtx, args, encoder::Function; kwargs...) result_jltype = source.jltype dtype = julia_to_tile_dtype!(tt, elem_type) - result_type_id = tile_type!(tt, dtype, collect(result_shape)) + result_type_id = tile_type!(tt, dtype, result_shape) result_v = encoder(cb, result_type_id, source.v; kwargs...) 
@@ -149,7 +149,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.cmpi), args) result_shape = lhs.shape bool_dtype = I1(tt) - result_type_id = tile_type!(tt, bool_dtype, collect(result_shape)) + result_type_id = tile_type!(tt, bool_dtype, result_shape) result_v = encode_CmpIOp!(cb, result_type_id, lhs.v, rhs.v; predicate, signedness) lhs_type = CC.widenconst(lhs.jltype) @@ -295,7 +295,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.cmpf), args) result_shape = lhs.shape bool_dtype = I1(tt) - result_type_id = tile_type!(tt, bool_dtype, collect(result_shape)) + result_type_id = tile_type!(tt, bool_dtype, result_shape) result_v = encode_CmpFOp!(cb, result_type_id, lhs.v, rhs.v; predicate) lhs_type = CC.widenconst(lhs.jltype) diff --git a/src/compiler/intrinsics/atomics.jl b/src/compiler/intrinsics/atomics.jl index f994d310..9ea4484b 100644 --- a/src/compiler/intrinsics/atomics.jl +++ b/src/compiler/intrinsics/atomics.jl @@ -49,7 +49,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_cas), args) elem_type = eltype(ptr_type) dtype = julia_to_tile_dtype!(tt, elem_type) - result_tile_type = tile_type!(tt, dtype, collect(shape)) + result_tile_type = tile_type!(tt, dtype, shape) token_type = Token(tt) # Emit atomic CAS @@ -101,7 +101,7 @@ function emit_atomic_rmw!(ctx::CGCtx, args::AbstractVector, mode::AtomicRMWMode. 
# Create result type dtype = julia_to_tile_dtype!(tt, elem_type) - result_tile_type = tile_type!(tt, dtype, collect(shape)) + result_tile_type = tile_type!(tt, dtype, shape) token_type = Token(tt) # Use float add mode for floating point types diff --git a/src/compiler/intrinsics/conversions.jl b/src/compiler/intrinsics/conversions.jl index 53420b08..ba9f7844 100644 --- a/src/compiler/intrinsics/conversions.jl +++ b/src/compiler/intrinsics/conversions.jl @@ -19,7 +19,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.exti), args) signedness = @something get_constant(ctx, args[3]) throw(IRError("exti: requires compile-time signedness")) dtype = julia_to_tile_dtype!(tt, target_type) - result_type_id = tile_type!(tt, dtype, collect(source.shape)) + result_type_id = tile_type!(tt, dtype, source.shape) result_v = encode_ExtIOp!(cb, result_type_id, source.v; signedness) src_type = CC.widenconst(source.jltype) @@ -43,7 +43,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.ftof), args) target_type = @something get_constant(ctx, args[2]) throw(IRError("ftof: requires compile-time target type")) dtype = julia_to_tile_dtype!(tt, target_type) - result_type_id = tile_type!(tt, dtype, collect(source.shape)) + result_type_id = tile_type!(tt, dtype, source.shape) result_v = encode_FToFOp!(cb, result_type_id, source.v) src_type = CC.widenconst(source.jltype) @@ -68,7 +68,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.ftoi), args) signedness = @something get_constant(ctx, args[3]) throw(IRError("ftoi: requires compile-time signedness")) dtype = julia_to_tile_dtype!(tt, target_type) - result_type_id = tile_type!(tt, dtype, collect(source.shape)) + result_type_id = tile_type!(tt, dtype, source.shape) result_v = encode_FToIOp!(cb, result_type_id, source.v; signedness) src_type = CC.widenconst(source.jltype) @@ -93,7 +93,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.itof), args) signedness = @something get_constant(ctx, args[3]) 
throw(IRError("itof: requires compile-time signedness")) dtype = julia_to_tile_dtype!(tt, target_type) - result_type_id = tile_type!(tt, dtype, collect(source.shape)) + result_type_id = tile_type!(tt, dtype, source.shape) result_v = encode_IToFOp!(cb, result_type_id, source.v; signedness) src_type = CC.widenconst(source.jltype) @@ -117,7 +117,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.trunci), args) target_type = @something get_constant(ctx, args[2]) throw(IRError("trunci: requires compile-time target type")) dtype = julia_to_tile_dtype!(tt, target_type) - result_type_id = tile_type!(tt, dtype, collect(source.shape)) + result_type_id = tile_type!(tt, dtype, source.shape) result_v = encode_TruncIOp!(cb, result_type_id, source.v) src_type = CC.widenconst(source.jltype) diff --git a/src/compiler/intrinsics/core.jl b/src/compiler/intrinsics/core.jl index ddce599a..103d9573 100644 --- a/src/compiler/intrinsics/core.jl +++ b/src/compiler/intrinsics/core.jl @@ -45,7 +45,8 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.broadcast), args) target_shape_tuple = @something get_constant(ctx, args[2]) throw(IRError("broadcast() shape must be a compile-time constant")) target_shape_tuple isa Tuple || throw(IRError("broadcast() shape must be a tuple, got $(typeof(target_shape_tuple))")) validate_tile_shape(collect(Int, target_shape_tuple), "broadcast") - target_shape = RowMajorShape(target_shape_tuple) + julia_shape = ColMajorShape(target_shape_tuple) + target_shape = RowMajorShape(julia_shape) # If already the right shape, return unchanged if source.shape == target_shape @@ -55,7 +56,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.broadcast), args) # Use the existing broadcast helper dtype = julia_to_tile_dtype!(tt, source_elem) result_v = broadcast_tile_to_shape!(cb, tt, source, target_shape, dtype) - result_type_id = tile_type!(tt, dtype, collect(target_shape)) + result_type_id = tile_type!(tt, dtype, target_shape) CGVal(result_v, 
result_type_id, Tile{source_elem, Tuple{target_shape_tuple...}}, target_shape) end @@ -83,13 +84,13 @@ function broadcast_tile_to_shape!(cb::CodeBuilder, tt::TypeTable, tv::CGVal, if length(current_shape) < length(target_shape) n_extra = length(target_shape) - length(current_shape) current_shape = RowMajorShape(vcat(fill(1, n_extra), collect(current_shape))) - reshaped_type = tile_type!(tt, dtype, collect(current_shape)) + reshaped_type = tile_type!(tt, dtype, current_shape) current_val = encode_ReshapeOp!(cb, reshaped_type, current_val) end # Step 2: Broadcast dimensions that are 1 to target size if current_shape != target_shape - broadcast_type = tile_type!(tt, dtype, collect(target_shape)) + broadcast_type = tile_type!(tt, dtype, target_shape) current_val = encode_BroadcastOp!(cb, broadcast_type, current_val) end @@ -154,7 +155,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.cat), args) # Create output tile type dtype = julia_to_tile_dtype!(tt, elem_type) - output_tile_type = tile_type!(tt, dtype, collect(output_shape)) + output_tile_type = tile_type!(tt, dtype, output_shape) # Emit CatOp (Tile IR axis) result = encode_CatOp!(cb, output_tile_type, lhs.v, rhs.v, tileir_axis) @@ -180,13 +181,13 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.constant), args) shape = @something get_constant(ctx, args[1]) throw(IRError("fill() shape must be a compile-time constant")) shape isa Tuple || throw(IRError("fill() shape must be a tuple, got $(typeof(shape))")) validate_tile_shape(collect(Int, shape), "fill") - tile_shape = RowMajorShape(shape) + tile_shape = RowMajorShape(ColMajorShape(shape)) # Extract dtype from Type{T} argument elem_type = @something get_constant(ctx, args[3]) throw(IRError("constant() requires a compile-time element type")) dtype = julia_to_tile_dtype!(tt, elem_type) - tile_type = tile_type!(tt, dtype, collect(tile_shape)) + tile_type = tile_type!(tt, dtype, tile_shape) tv = emit_value!(ctx, args[2]) tv === nothing && 
throw(IRError("fill() value must be a constant or a runtime scalar")) @@ -230,17 +231,17 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.extract), args) shape_tuple = @something get_constant(ctx, args[3]) throw(IRError("extract() shape must be a compile-time constant")) shape_tuple isa Tuple || throw(IRError("extract() shape must be a tuple, got $(typeof(shape_tuple))")) validate_tile_shape(collect(Int, shape_tuple), "extract") - output_shape = RowMajorShape(shape_tuple) + output_shape = RowMajorShape(ColMajorShape(shape_tuple)) # Get element type elem_type = eltype(CC.widenconst(source.jltype)) # Create output tile type dtype = julia_to_tile_dtype!(tt, elem_type) - output_tile_type = tile_type!(tt, dtype, collect(output_shape)) + output_tile_type = tile_type!(tt, dtype, output_shape) # Create constant index values (0D i32 tiles), reversed for Tile IR order - scalar_i32 = tile_type!(tt, I32(tt), Int[]) + scalar_i32 = tile_type!(tt, I32(tt), ScalarShape()) index_vals = Value[] for idx in reverse(index_tuple) idx_bytes = collect(reinterpret(UInt8, [Int32(idx)])) @@ -263,7 +264,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.get_num_tile_blocks), a axis = @something get_constant(ctx, args[1]) throw(IRError("get_num_tile_blocks() axis must be a compile-time constant")) axis in (0, 1, 2) || throw(IRError("get_num_tile_blocks() axis must be 0, 1, or 2, got $axis")) - res_type = tile_type!(ctx.tt, I32(ctx.tt), Int[]) + res_type = tile_type!(ctx.tt, I32(ctx.tt), ScalarShape()) nb_x, nb_y, nb_z = encode_GetNumTileBlocksOp!(ctx.cb, res_type, res_type, res_type) CGVal((nb_x, nb_y, nb_z)[axis + 1], res_type, Int32) @@ -276,7 +277,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.get_tile_block_id), arg axis = @something get_constant(ctx, args[1]) throw(IRError("get_tile_block_id() axis must be a compile-time constant")) axis in (0, 1, 2) || throw(IRError("get_tile_block_id() axis must be 0, 1, or 2, got $axis")) - res_type = 
tile_type!(ctx.tt, I32(ctx.tt), Int[]) + res_type = tile_type!(ctx.tt, I32(ctx.tt), ScalarShape()) bid_x, bid_y, bid_z = encode_GetTileBlockIdOp!(ctx.cb, res_type, res_type, res_type) result = (bid_x, bid_y, bid_z)[axis + 1] @@ -302,13 +303,13 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.iota), args) shape = @something get_constant(ctx, args[1]) throw(IRError("iota() shape must be a compile-time constant")) shape isa Tuple || throw(IRError("iota() shape must be a tuple, got $(typeof(shape))")) validate_tile_shape(collect(Int, shape), "arange") - tile_shape = RowMajorShape(shape) + tile_shape = RowMajorShape(ColMajorShape(shape)) # Extract dtype from Type{T} argument elem_type = @something get_constant(ctx, args[2]) throw(IRError("iota() requires a compile-time element type")) dtype = julia_to_tile_dtype!(tt, elem_type) - tile_type = tile_type!(tt, dtype, collect(tile_shape)) + tile_type = tile_type!(tt, dtype, tile_shape) # Emit IotaOp result = encode_IotaOp!(cb, tile_type) @@ -366,13 +367,13 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.offset), args) ptr_elem_type = eltype(base_ptr_type) # T from Ptr{T} elem_dtype = julia_to_tile_dtype!(tt, ptr_elem_type) ptr_dtype = pointer_type!(tt, elem_dtype) - ptr_tile_type = tile_type!(tt, ptr_dtype, collect(tile_shape)) + ptr_tile_type = tile_type!(tt, ptr_dtype, tile_shape) # Broadcast base pointer to tile shape ndims = length(tile_shape) if ndims > 0 ones_shape = RowMajorShape(fill(1, ndims)) - reshaped_ptr_type = tile_type!(tt, ptr_dtype, collect(ones_shape)) + reshaped_ptr_type = tile_type!(tt, ptr_dtype, ones_shape) base_ptr_reshaped = encode_ReshapeOp!(cb, reshaped_ptr_type, base_ptr) base_ptr_tile = encode_BroadcastOp!(cb, ptr_tile_type, base_ptr_reshaped) else @@ -430,7 +431,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.permute), args) # Create output tile type dtype = julia_to_tile_dtype!(tt, elem_type) - output_tile_type = tile_type!(tt, dtype, collect(output_shape)) + 
output_tile_type = tile_type!(tt, dtype, output_shape) # Emit PermuteOp with Tile IR permutation result = encode_PermuteOp!(cb, output_tile_type, source.v, tileir_perm) @@ -509,8 +510,8 @@ function emit_reduce!(ctx::CGCtx, args) push!(elem_types, etype) dtype = julia_to_tile_dtype!(tt, etype) push!(dtypes, dtype) - push!(reduced_tile_types, tile_type!(tt, dtype, collect(reduced_shape))) - push!(scalar_tile_types, tile_type!(tt, dtype, Int[])) + push!(reduced_tile_types, tile_type!(tt, dtype, reduced_shape)) + push!(scalar_tile_types, tile_type!(tt, dtype, ScalarShape())) push!(operand_values, tv.v::Value) push!(identities, make_identity_val(identity_vals[k], dtype, etype)) end @@ -539,7 +540,7 @@ function emit_reduce!(ctx::CGCtx, args) reshaped_values = Value[] component_types = Type[] for (k, res) in enumerate(results) - out_type = tile_type!(tt, dtypes[k], collect(output_shape)) + out_type = tile_type!(tt, dtypes[k], output_shape) reshaped_val = encode_ReshapeOp!(cb, out_type, res) push!(reshaped_values, reshaped_val) push!(component_types, Tile{elem_types[k], TupleType(julia_output)}) @@ -592,7 +593,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.reshape), args) target_shape_tuple = @something get_constant(ctx, args[2]) throw(IRError("reshape() shape must be a compile-time constant")) target_shape_tuple isa Tuple || throw(IRError("reshape() shape must be a tuple, got $(typeof(target_shape_tuple))")) validate_tile_shape(collect(Int, target_shape_tuple), "reshape") - target_shape = RowMajorShape(target_shape_tuple) + target_shape = RowMajorShape(ColMajorShape(target_shape_tuple)) # Get element type elem_type = eltype(CC.widenconst(source.jltype)) @@ -600,7 +601,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.reshape), args) # Tile IR shapes are already in row-major order, so ReshapeOp's row-major element # ordering matches directly. No permutes needed! 
- result_type_id = tile_type!(tt, dtype, collect(target_shape)) + result_type_id = tile_type!(tt, dtype, target_shape) result = encode_ReshapeOp!(cb, result_type_id, source.v) CGVal(result, result_type_id, Tile{elem_type, Tuple{target_shape_tuple...}}, target_shape) @@ -672,8 +673,8 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.scan), args) push!(elem_types, etype) dtype = julia_to_tile_dtype!(tt, etype) push!(dtypes, dtype) - push!(output_tile_types, tile_type!(tt, dtype, collect(output_shape))) - push!(scalar_tile_types, tile_type!(tt, dtype, Int[])) + push!(output_tile_types, tile_type!(tt, dtype, output_shape)) + push!(scalar_tile_types, tile_type!(tt, dtype, ScalarShape())) push!(operand_values, tv.v::Value) push!(identities, make_identity_val(identity_vals[k], dtype, etype)) end diff --git a/src/compiler/intrinsics/julia.jl b/src/compiler/intrinsics/julia.jl index 41888cea..94170f97 100644 --- a/src/compiler/intrinsics/julia.jl +++ b/src/compiler/intrinsics/julia.jl @@ -51,7 +51,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(===), args) lhs = @something emit_value!(ctx, args[1]) throw(IRError("===: cannot resolve lhs")) rhs = @something emit_value!(ctx, args[2]) throw(IRError("===: cannot resolve rhs")) - result_type_id = tile_type!(tt, I1(tt), Int[]) + result_type_id = tile_type!(tt, I1(tt), ScalarShape()) result_v = encode_CmpIOp!(cb, result_type_id, lhs.v, rhs.v; predicate=ComparisonPredicate.Equal, signedness=Signedness.Signed) diff --git a/src/compiler/intrinsics/memory.jl b/src/compiler/intrinsics/memory.jl index 2bfd34fc..5866bd66 100644 --- a/src/compiler/intrinsics/memory.jl +++ b/src/compiler/intrinsics/memory.jl @@ -29,7 +29,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.load_ptr_tko), args) ptr_type = eltype(ptrs_type) # Ptr{T} from Tile{Ptr{T}, S} elem_type = eltype(ptr_type) # T from Ptr{T} dtype = julia_to_tile_dtype!(tt, elem_type) - result_tile_type = tile_type!(tt, dtype, collect(tile_shape)) + 
result_tile_type = tile_type!(tt, dtype, tile_shape) token_type = Token(tt) # Extract latency hint (args[2]) diff --git a/src/compiler/intrinsics/misc.jl b/src/compiler/intrinsics/misc.jl index 32c47ef4..78affb92 100644 --- a/src/compiler/intrinsics/misc.jl +++ b/src/compiler/intrinsics/misc.jl @@ -23,7 +23,7 @@ function emit_assume_ops!(ctx::CGCtx, array_val::Value, size_vals::Vector{Value} # Pointer alignment if array_spec.alignment > 0 ptr_dtype = pointer_type!(tt, dtype) - ptr_tile_type = tile_type!(tt, ptr_dtype, Int[]) + ptr_tile_type = tile_type!(tt, ptr_dtype, ScalarShape()) array_val = encode_AssumeOp!(cb, ptr_tile_type, array_val, DivBy(array_spec.alignment)) end diff --git a/src/compiler/intrinsics/views.jl b/src/compiler/intrinsics/views.jl index 560a7795..11e1b99f 100644 --- a/src/compiler/intrinsics/views.jl +++ b/src/compiler/intrinsics/views.jl @@ -26,7 +26,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.get_index_space_shape), tileir_axis = ndim - 1 - axis # Create result types for all dimensions - scalar_i32 = tile_type!(tt, I32(tt), Int[]) + scalar_i32 = tile_type!(tt, I32(tt), ScalarShape()) result_types = fill(scalar_i32, ndim) # Emit GetIndexSpaceShapeOp @@ -69,10 +69,10 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.load_partition_view), a # Reverse to Tile IR row-major order pv_type = CC.widenconst(pv_arg.jltype) elem_type = eltype(pv_type) - tile_shape = RowMajorShape(size(pv_type)) + tile_shape = RowMajorShape(ColMajorShape(size(pv_type))) dtype = julia_to_tile_dtype!(tt, elem_type) - tile_type = tile_type!(tt, dtype, collect(tile_shape)) + tile_type = tile_type!(tt, dtype, tile_shape) token_type = Token(tt) latency = @something get_constant(ctx, args[2]) throw(IRError("load_partition_view(): latency must be a compile-time constant")) @@ -145,7 +145,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.make_partition_view), a shape = @something get_constant(ctx, args[2]) 
throw(IRError("make_partition_view() shape must be a compile-time constant")) shape isa Tuple || throw(IRError("make_partition_view() shape must be a tuple, got $(typeof(shape))")) validate_tile_shape(collect(Int, shape), "load") - tile_shape = RowMajorShape(shape) + tile_shape = RowMajorShape(ColMajorShape(shape)) padding_value = if length(args) >= 3 convert_enum(PaddingValue, @something get_constant(ctx, args[3]) throw(IRError("padding_mode must be a compile-time constant"))) @@ -169,7 +169,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.make_partition_view), a dim_map = [ndim - 1 - julia_dim_map[ndim - i] for i in 0:ndim-1] end - pv_type = partition_view_type!(ctx.tt, collect(tile_shape), tv_type, dim_map, padding_value) + pv_type = partition_view_type!(ctx.tt, tile_shape, tv_type, dim_map, padding_value) partition = encode_MakePartitionViewOp!(ctx.cb, pv_type, tensor_view) CGVal(partition, pv_type, PartitionView{elem_type, ndim, Tuple{shape...}}, ScalarShape(), nothing, Some(ndim), nothing) @@ -249,7 +249,7 @@ function cache_tensor_view!(ctx::CGCtx, arg_idx::Int, # TensorView type (strides also in Tile IR order: last dim = contiguous) tv_shape = RowMajorShape(fill(DYNAMIC_SHAPE, ndim)) tv_strides = compute_tensor_view_strides(spec, ndim) - tv_type = tensor_view_type!(tt, dtype, collect(tv_shape), tv_strides) + tv_type = tensor_view_type!(tt, dtype, tv_shape, tv_strides) # Emit AssumeOps for optimization hints if spec !== nothing @@ -375,7 +375,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.store_partition_view), actual_tile_shape = tile_shape if length(tile_shape) == 0 actual_ndim = 1 - actual_tile_shape = [1] + actual_tile_shape = RowMajorShape([1]) tile_1d_type = tile_type!(tt, dtype, actual_tile_shape) tile_val = encode_ReshapeOp!(cb, tile_1d_type, tile_val) end diff --git a/src/cuTile.jl b/src/cuTile.jl index 69543b66..3549b2f3 100644 --- a/src/cuTile.jl +++ b/src/cuTile.jl @@ -17,6 +17,9 @@ using BFloat16s: BFloat16 using EnumX 
public BFloat16 +# Shared definitions +include("shapes.jl") + # Bytecode infrastructure include("bytecode/basic.jl") include("bytecode/types.jl")
diff --git a/src/shapes.jl b/src/shapes.jl
new file mode 100644
index 00000000..d303e772
--- /dev/null
+++ b/src/shapes.jl
@@ -0,0 +1,55 @@
+# Type-safe shape wrappers: Julia (column-major) ↔ Tile IR (row-major)
+#
+# Tile IR is natively row-major: shapes are stored with the slowest-varying dimension first.
+# Julia is column-major: shapes are stored with the fastest-varying dimension first.
+# Converting between them is a simple reversal. The Shape{O} wrapper ensures we don't
+# accidentally mix up conventions — IR operations accept only TileShape (RowMajorShape or
+# ScalarShape), while user-facing shapes from Julia are ColMajorShape.
+
+# Tag types naming the convention a Shape's dims are stored in.
+abstract type ShapeKind end
+struct RowMajor <: ShapeKind end
+struct ColMajor <: ShapeKind end
+struct Scalar <: ShapeKind end
+
+# A list of dimensions tagged with its convention `O`.
+struct Shape{O<:ShapeKind}
+    dims::Vector{Int}
+end
+
+const ScalarShape = Shape{Scalar}
+const RowMajorShape = Shape{RowMajor}
+const ColMajorShape = Shape{ColMajor}
+# Shapes of either the row-major or the scalar kind.
+const TileShape = Shape{<:Union{RowMajor, Scalar}}
+
+# A scalar shape has no dimensions.
+ScalarShape() = Shape{Scalar}(Int[])
+
+# Conversion constructors. A tuple or scalar shape is wrapped as-is; only
+# RowMajor ↔ ColMajor conversions reverse the dimension order.
+RowMajorShape(t::Tuple) = RowMajorShape(collect(Int, t))
+RowMajorShape(s::ScalarShape) = RowMajorShape(s.dims)
+RowMajorShape(s::RowMajorShape) = s
+RowMajorShape(s::ColMajorShape) = RowMajorShape(reverse(s.dims))
+
+ColMajorShape(t::Tuple) = ColMajorShape(collect(Int, t))
+ColMajorShape(s::ScalarShape) = ColMajorShape(s.dims)
+ColMajorShape(s::ColMajorShape) = s
+ColMajorShape(s::RowMajorShape) = ColMajorShape(reverse(s.dims))
+
+# Forward common operations to .dims
+Base.length(s::Shape) = length(s.dims)
+Base.isempty(s::Shape) = isempty(s.dims)
+Base.getindex(s::Shape, i) = s.dims[i]
+Base.setindex!(s::Shape, v, i) = (s.dims[i] = v; s)
+Base.copy(s::Shape{O}) where O = Shape{O}(copy(s.dims))
+Base.:(==)(a::Shape{O}, b::Shape{O}) where O = a.dims == b.dims
+# Keep `hash` consistent with `==` so equal shapes behave as equal Dict/Set keys.
+Base.hash(s::Shape{O}, h::UInt) where O = hash(s.dims, hash(O, h))
+Base.iterate(s::Shape, state...) = iterate(s.dims, state...)
+Base.eachindex(s::Shape) = eachindex(s.dims)
+# Returns the backing vector itself (no copy).
+Base.collect(s::Shape) = s.dims
+# Tuple type whose parameters are the dimensions, e.g. dims [2, 3] gives Tuple{2, 3}.
+TupleType(s::Shape) = Tuple{s.dims...}