diff --git a/src/compiler/codegen/kernel.jl b/src/compiler/codegen/kernel.jl index db1e6344..79e9a61d 100644 --- a/src/compiler/codegen/kernel.jl +++ b/src/compiler/codegen/kernel.jl @@ -119,7 +119,7 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8}, # Scalar: emit ConstantOp bytes = constant_to_bytes(val, T) v = encode_ConstantOp!(ctx.cb, type_id, bytes) - tv = CGVal(v, type_id, T, Int[], nothing, Some(val), nothing) + tv = CGVal(v, type_id, T, ScalarShape(), nothing, Some(val), nothing) else # Non-primitive (tuple etc.): ghost with constant tv = ghost_value(T, val) diff --git a/src/compiler/codegen/utils.jl b/src/compiler/codegen/utils.jl index 540d3c13..0d9c22b9 100644 --- a/src/compiler/codegen/utils.jl +++ b/src/compiler/codegen/utils.jl @@ -2,6 +2,64 @@ # # Core types (CGVal, CGCtx) and helper functions for Tile IR code generation. +#============================================================================= + Type-safe shape wrappers: Julia (column-major) ↔ Tile IR (row-major) +=============================================================================# + +# Tile IR is natively row-major: shapes are stored with the slowest-varying dimension first. +# Julia is column-major: shapes are stored with the fastest-varying dimension first. +# Converting between them is a simple reversal. The Shape{O} wrapper ensures we don't +# accidentally mix up conventions — IR operations accept only RowMajorShape, while +# user-facing shapes from Julia are ColMajorShape. 
+ +abstract type StorageOrder end +struct RowMajor <: StorageOrder end +struct ColMajor <: StorageOrder end + +struct Shape{O<:StorageOrder} + dims::Vector{Int} +end + +const RowMajorShape = Shape{RowMajor} +const ColMajorShape = Shape{ColMajor} + +# Conversion constructors +RowMajorShape(s::RowMajorShape) = s +RowMajorShape(s::ColMajorShape) = RowMajorShape(reverse(s.dims)) +RowMajorShape(t::Tuple) = RowMajorShape(ColMajorShape(collect(Int, t))) + +ColMajorShape(s::ColMajorShape) = s +ColMajorShape(s::RowMajorShape) = ColMajorShape(reverse(s.dims)) +ColMajorShape(t::Tuple) = ColMajorShape(collect(Int, t)) + +# Forward common operations to .dims +Base.length(s::Shape) = length(s.dims) +Base.isempty(s::Shape) = isempty(s.dims) +Base.getindex(s::Shape, i) = s.dims[i] +Base.setindex!(s::Shape, v, i) = (s.dims[i] = v; s) +Base.copy(s::Shape{O}) where O = Shape{O}(copy(s.dims)) +Base.:(==)(a::Shape{O}, b::Shape{O}) where O = a.dims == b.dims +Base.iterate(s::Shape, state...) = iterate(s.dims, state...) 
+Base.eachindex(s::Shape) = eachindex(s.dims) +Base.collect(s::Shape) = s.dims +TupleType(s::Shape) = Tuple{s.dims...} + +# Scalar (0-D) shape — storage order is irrelevant for zero dimensions +struct ScalarShape end +Base.length(::ScalarShape) = 0 +Base.isempty(::ScalarShape) = true +Base.collect(::ScalarShape) = Int[] +TupleType(::ScalarShape) = Tuple{} + +Base.:(==)(::ScalarShape, ::ScalarShape) = true +Base.:(==)(::ScalarShape, ::Shape) = false +Base.:(==)(::Shape, ::ScalarShape) = false + +# Cross-type conversions (must be after ScalarShape definition) +ColMajorShape(::ScalarShape) = ColMajorShape(Int[]) + +const TileShape = Union{RowMajorShape, ScalarShape} + #============================================================================= IRError: Exception type for IR compilation errors =============================================================================# @@ -69,7 +127,7 @@ struct CGVal v::Union{Value, Vector{Value}, Nothing} # Single value, multi-value, or nothing type_id::Union{TypeId, Nothing} # Tile IR type (nothing for lazy refs or multi-value) jltype::Any # Original Julia type - shape::Vector{Int} # Tile shape (empty for scalars) + shape::TileShape # Tile shape (ScalarShape for scalars) # Lazy argument reference: (arg_idx, [field_indices...]) # e.g., (1, [2, 1]) means "argument 1, field 2, sub-field 1" arg_ref::Union{Tuple{Int, Vector{Int}}, Nothing} @@ -79,18 +137,18 @@ end # Convenience constructors for concrete values CGVal(v::Value, type_id::TypeId, @nospecialize(jltype)) = - CGVal(v, type_id, jltype, Int[], nothing, nothing, nothing) + CGVal(v, type_id, jltype, ScalarShape(), nothing, nothing, nothing) -CGVal(v::Value, type_id::TypeId, @nospecialize(jltype), shape::Vector{Int}) = +CGVal(v::Value, type_id::TypeId, @nospecialize(jltype), shape::TileShape) = CGVal(v, type_id, jltype, shape, nothing, nothing, nothing) # Constructor for multi-value results (from loops, ifs) CGVal(v::Vector{Value}, @nospecialize(jltype)) = - CGVal(v, nothing, 
jltype, Int[], nothing, nothing, nothing) + CGVal(v, nothing, jltype, ScalarShape(), nothing, nothing, nothing) # Constructor for lazy argument references function arg_ref_value(arg_idx::Int, chain::Vector{Int}, @nospecialize(jltype)) - CGVal(nothing, nothing, jltype, Int[], (arg_idx, chain), nothing, nothing) + CGVal(nothing, nothing, jltype, ScalarShape(), (arg_idx, chain), nothing, nothing) end """ @@ -99,8 +157,8 @@ end Create a ghost value (zero-size singleton with no runtime representation). Optionally stores a compile-time constant value. """ -ghost_value(@nospecialize(jltype)) = CGVal(nothing, TypeId(-1), jltype, Int[], nothing, nothing, nothing) -ghost_value(@nospecialize(jltype), constant) = CGVal(nothing, TypeId(-1), jltype, Int[], nothing, Some(constant), nothing) +ghost_value(@nospecialize(jltype)) = CGVal(nothing, TypeId(-1), jltype, ScalarShape(), nothing, nothing, nothing) +ghost_value(@nospecialize(jltype), constant) = CGVal(nothing, TypeId(-1), jltype, ScalarShape(), nothing, Some(constant), nothing) """ tuple_value(jltype, component_refs, component_constants) -> CGVal @@ -115,7 +173,7 @@ function tuple_value(@nospecialize(jltype), component_refs::Vector{Any}, compone else nothing end - CGVal(nothing, TypeId(-1), jltype, Int[], nothing, constant, component_refs) + CGVal(nothing, TypeId(-1), jltype, ScalarShape(), nothing, constant, component_refs) end """ @@ -382,15 +440,15 @@ function _tile_type_for_julia!(tt::TypeTable, @nospecialize(T::Type)) throw(IRError("Tile shape must be a tuple, got: $shape_param")) end elem_dtype = julia_to_tile_dtype!(tt, eltype(T)) - shape = collect(Int, shape_param) - return tile_type!(tt, elem_dtype, shape) + shape = RowMajorShape(shape_param) + return tile_type!(tt, elem_dtype, collect(shape)) end return nothing end """ - tile_type_and_shape_for_julia!(ctx, T) -> (TypeId, Vector{Int}) + tile_type_and_shape_for_julia!(ctx, T) -> (TypeId, RowMajorShape) Get the Tile IR type and shape for a Julia type. 
""" @@ -398,11 +456,11 @@ function tile_type_and_shape_for_julia!(ctx::CGCtx, @nospecialize(T)) actual_type = CC.widenconst(T) type_id = tile_type_for_julia!(ctx, actual_type) - # Extract shape from Tile types + # Extract shape from Tile types (in Tile IR row-major order) shape = if actual_type <: Tile - collect(Int, size(actual_type)) + RowMajorShape(size(actual_type)) else - Int[] + ScalarShape() end return (type_id, shape) @@ -489,14 +547,15 @@ end #----------------------------------------------------------------------------- """ - extract_tile_shape(T) -> Vector{Int} + extract_tile_shape(T) -> RowMajorShape -Extract shape from a Tile{T, Shape} type, returning Int[] if not a Tile type. +Extract shape from a Tile{T, Shape} type in Tile IR (row-major) order. +Returns empty shape if not a Tile type. """ function extract_tile_shape(@nospecialize(T)) T = CC.widenconst(T) if T <: Tile - return collect(Int, size(T)) + return RowMajorShape(size(T)) end - Int[] + ScalarShape() end diff --git a/src/compiler/codegen/values.jl b/src/compiler/codegen/values.jl index 95b2406b..41b13f2a 100644 --- a/src/compiler/codegen/values.jl +++ b/src/compiler/codegen/values.jl @@ -19,7 +19,7 @@ function emit_value!(ctx::CGCtx, val::Integer) type_id = tile_type_for_julia!(ctx, jltype) bytes = reinterpret(UInt8, [jltype(val)]) v = encode_ConstantOp!(ctx.cb, type_id, collect(bytes)) - CGVal(v, type_id, jltype, Int[], nothing, Some(val), nothing) + CGVal(v, type_id, jltype, ScalarShape(), nothing, Some(val), nothing) end function emit_value!(ctx::CGCtx, val::AbstractFloat) @@ -27,7 +27,7 @@ function emit_value!(ctx::CGCtx, val::AbstractFloat) type_id = tile_type_for_julia!(ctx, jltype) bytes = reinterpret(UInt8, [jltype(val)]) v = encode_ConstantOp!(ctx.cb, type_id, collect(bytes)) - CGVal(v, type_id, jltype, Int[], nothing, Some(val), nothing) + CGVal(v, type_id, jltype, ScalarShape(), nothing, Some(val), nothing) end function emit_value!(ctx::CGCtx, node::QuoteNode) @@ -67,7 +67,7 @@ 
function emit_value!(ctx::CGCtx, ref::GlobalRef) if type_id !== nothing bytes = constant_to_bytes(val, T) v = encode_ConstantOp!(ctx.cb, type_id, bytes) - return CGVal(v, type_id, T, Int[], nothing, Some(val), nothing) + return CGVal(v, type_id, T, ScalarShape(), nothing, Some(val), nothing) end end ghost_value(T, val) diff --git a/src/compiler/intrinsics/arithmetic.jl b/src/compiler/intrinsics/arithmetic.jl index 43989479..e4583292 100644 --- a/src/compiler/intrinsics/arithmetic.jl +++ b/src/compiler/intrinsics/arithmetic.jl @@ -52,15 +52,15 @@ function emit_binop!(ctx::CGCtx, args, encoder::Function; kwargs...) dtype = julia_to_tile_dtype!(tt, elem_type) if isempty(lhs_tv.shape) bv = broadcast_tile_to_shape!(cb, tt, lhs_tv, result_shape, dtype) - lhs_tv = CGVal(bv, tile_type!(tt, dtype, result_shape), elem_type, + lhs_tv = CGVal(bv, tile_type!(tt, dtype, collect(result_shape)), elem_type, result_shape, nothing, lhs_tv.constant, nothing) elseif isempty(rhs_tv.shape) bv = broadcast_tile_to_shape!(cb, tt, rhs_tv, result_shape, dtype) - rhs_tv = CGVal(bv, tile_type!(tt, dtype, result_shape), elem_type, + rhs_tv = CGVal(bv, tile_type!(tt, dtype, collect(result_shape)), elem_type, result_shape, nothing, rhs_tv.constant, nothing) end else - result_shape = Int[] + result_shape = ScalarShape() end result_jltype = lhs_tv.jltype else @@ -68,7 +68,7 @@ function emit_binop!(ctx::CGCtx, args, encoder::Function; kwargs...) end dtype = julia_to_tile_dtype!(tt, elem_type) - result_type_id = tile_type!(tt, dtype, result_shape) + result_type_id = tile_type!(tt, dtype, collect(result_shape)) result_v = encoder(cb, result_type_id, lhs_tv.v, rhs_tv.v; kwargs...) @@ -88,7 +88,7 @@ function emit_unop!(ctx::CGCtx, args, encoder::Function; kwargs...) 
result_jltype = source.jltype dtype = julia_to_tile_dtype!(tt, elem_type) - result_type_id = tile_type!(tt, dtype, result_shape) + result_type_id = tile_type!(tt, dtype, collect(result_shape)) result_v = encoder(cb, result_type_id, source.v; kwargs...) @@ -149,7 +149,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.cmpi), args) result_shape = lhs.shape bool_dtype = I1(tt) - result_type_id = tile_type!(tt, bool_dtype, result_shape) + result_type_id = tile_type!(tt, bool_dtype, collect(result_shape)) result_v = encode_CmpIOp!(cb, result_type_id, lhs.v, rhs.v; predicate, signedness) lhs_type = CC.widenconst(lhs.jltype) @@ -295,7 +295,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.cmpf), args) result_shape = lhs.shape bool_dtype = I1(tt) - result_type_id = tile_type!(tt, bool_dtype, result_shape) + result_type_id = tile_type!(tt, bool_dtype, collect(result_shape)) result_v = encode_CmpFOp!(cb, result_type_id, lhs.v, rhs.v; predicate) lhs_type = CC.widenconst(lhs.jltype) diff --git a/src/compiler/intrinsics/atomics.jl b/src/compiler/intrinsics/atomics.jl index 1cf6487d..f994d310 100644 --- a/src/compiler/intrinsics/atomics.jl +++ b/src/compiler/intrinsics/atomics.jl @@ -72,7 +72,8 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_cas), args) end ctx.token = new_token - CGVal(old_val, result_tile_type, Tile{elem_type, Tuple{shape...}}, collect(shape)) + julia_shape = ColMajorShape(shape) + CGVal(old_val, result_tile_type, Tile{elem_type, TupleType(julia_shape)}, shape) end # cuda_tile.atomic_rmw_tko (shared helper for atomic RMW operations) @@ -129,7 +130,8 @@ function emit_atomic_rmw!(ctx::CGCtx, args::AbstractVector, mode::AtomicRMWMode. 
end ctx.token = new_token - CGVal(old_val, result_tile_type, Tile{elem_type, Tuple{shape...}}, collect(shape)) + julia_shape = ColMajorShape(shape) + CGVal(old_val, result_tile_type, Tile{elem_type, TupleType(julia_shape)}, shape) end # cuda_tile.atomic_rmw_tko variants diff --git a/src/compiler/intrinsics/conversions.jl b/src/compiler/intrinsics/conversions.jl index ba9f7844..53420b08 100644 --- a/src/compiler/intrinsics/conversions.jl +++ b/src/compiler/intrinsics/conversions.jl @@ -19,7 +19,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.exti), args) signedness = @something get_constant(ctx, args[3]) throw(IRError("exti: requires compile-time signedness")) dtype = julia_to_tile_dtype!(tt, target_type) - result_type_id = tile_type!(tt, dtype, source.shape) + result_type_id = tile_type!(tt, dtype, collect(source.shape)) result_v = encode_ExtIOp!(cb, result_type_id, source.v; signedness) src_type = CC.widenconst(source.jltype) @@ -43,7 +43,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.ftof), args) target_type = @something get_constant(ctx, args[2]) throw(IRError("ftof: requires compile-time target type")) dtype = julia_to_tile_dtype!(tt, target_type) - result_type_id = tile_type!(tt, dtype, source.shape) + result_type_id = tile_type!(tt, dtype, collect(source.shape)) result_v = encode_FToFOp!(cb, result_type_id, source.v) src_type = CC.widenconst(source.jltype) @@ -68,7 +68,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.ftoi), args) signedness = @something get_constant(ctx, args[3]) throw(IRError("ftoi: requires compile-time signedness")) dtype = julia_to_tile_dtype!(tt, target_type) - result_type_id = tile_type!(tt, dtype, source.shape) + result_type_id = tile_type!(tt, dtype, collect(source.shape)) result_v = encode_FToIOp!(cb, result_type_id, source.v; signedness) src_type = CC.widenconst(source.jltype) @@ -93,7 +93,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.itof), args) signedness = @something 
get_constant(ctx, args[3]) throw(IRError("itof: requires compile-time signedness")) dtype = julia_to_tile_dtype!(tt, target_type) - result_type_id = tile_type!(tt, dtype, source.shape) + result_type_id = tile_type!(tt, dtype, collect(source.shape)) result_v = encode_IToFOp!(cb, result_type_id, source.v; signedness) src_type = CC.widenconst(source.jltype) @@ -117,7 +117,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.trunci), args) target_type = @something get_constant(ctx, args[2]) throw(IRError("trunci: requires compile-time target type")) dtype = julia_to_tile_dtype!(tt, target_type) - result_type_id = tile_type!(tt, dtype, source.shape) + result_type_id = tile_type!(tt, dtype, collect(source.shape)) result_v = encode_TruncIOp!(cb, result_type_id, source.v) src_type = CC.widenconst(source.jltype) diff --git a/src/compiler/intrinsics/core.jl b/src/compiler/intrinsics/core.jl index 7faa9937..ddce599a 100644 --- a/src/compiler/intrinsics/core.jl +++ b/src/compiler/intrinsics/core.jl @@ -41,11 +41,11 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.broadcast), args) source_type = CC.widenconst(source.jltype) source_elem = eltype(source_type) - # Extract target shape + # Extract target shape (from Julia user code) and reverse to Tile IR order target_shape_tuple = @something get_constant(ctx, args[2]) throw(IRError("broadcast() shape must be a compile-time constant")) target_shape_tuple isa Tuple || throw(IRError("broadcast() shape must be a tuple, got $(typeof(target_shape_tuple))")) - target_shape = collect(Int, target_shape_tuple) - validate_tile_shape(target_shape, "broadcast") + validate_tile_shape(collect(Int, target_shape_tuple), "broadcast") + target_shape = RowMajorShape(target_shape_tuple) # If already the right shape, return unchanged if source.shape == target_shape @@ -55,19 +55,19 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.broadcast), args) # Use the existing broadcast helper dtype = julia_to_tile_dtype!(tt, 
source_elem) result_v = broadcast_tile_to_shape!(cb, tt, source, target_shape, dtype) - result_type_id = tile_type!(tt, dtype, target_shape) + result_type_id = tile_type!(tt, dtype, collect(target_shape)) - CGVal(result_v, result_type_id, Tile{source_elem, Tuple{target_shape...}}, target_shape) + CGVal(result_v, result_type_id, Tile{source_elem, Tuple{target_shape_tuple...}}, target_shape) end """ - broadcast_tile_to_shape!(cb, tt, tv::CGVal, target_shape::Vector{Int}, dtype::TypeId) -> Value + broadcast_tile_to_shape!(cb, tt, tv::CGVal, target_shape::RowMajorShape, dtype::TypeId) -> Value Broadcast a tile to a target shape by inserting ReshapeOp (for trailing 1s) and BroadcastOp. Returns the value after broadcasting, or the original value if shapes already match. """ function broadcast_tile_to_shape!(cb::CodeBuilder, tt::TypeTable, tv::CGVal, - target_shape::Vector{Int}, dtype::TypeId) + target_shape::RowMajorShape, dtype::TypeId) src_shape = tv.shape # Already the right shape? @@ -78,19 +78,18 @@ function broadcast_tile_to_shape!(cb::CodeBuilder, tt::TypeTable, tv::CGVal, current_val = tv.v current_shape = src_shape - # Step 1: Add trailing 1s via ReshapeOp if needed (dimension mismatch) - # Follows Julia convention: (n,) pads to (n, 1) — first dimension aligns. + # Step 1: Add leading 1s via ReshapeOp if needed (dimension mismatch) + # In Tile IR row-major order, Julia's trailing singleton padding becomes leading 1s. 
if length(current_shape) < length(target_shape) n_extra = length(target_shape) - length(current_shape) - new_shape = vcat(current_shape, fill(1, n_extra)) - reshaped_type = tile_type!(tt, dtype, new_shape) + current_shape = RowMajorShape(vcat(fill(1, n_extra), collect(current_shape))) + reshaped_type = tile_type!(tt, dtype, collect(current_shape)) current_val = encode_ReshapeOp!(cb, reshaped_type, current_val) - current_shape = new_shape end # Step 2: Broadcast dimensions that are 1 to target size if current_shape != target_shape - broadcast_type = tile_type!(tt, dtype, target_shape) + broadcast_type = tile_type!(tt, dtype, collect(target_shape)) current_val = encode_BroadcastOp!(cb, broadcast_type, current_val) end @@ -137,16 +136,17 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.cat), args) axis_val = @something get_constant(ctx, args[2]) throw(IRError("cat() axis must be a compile-time constant")) axis_val isa Integer || throw(IRError("cat() axis must be an integer, got $(typeof(axis_val))")) - # Handle negative axis + # Handle negative axis and flip to Tile IR order lhs_shape = lhs.shape ndims = length(lhs_shape) - axis = axis_val < 0 ? ndims + axis_val : axis_val + julia_axis = axis_val < 0 ? 
ndims + axis_val : axis_val + tileir_axis = ndims - 1 - julia_axis - # Compute output shape - concatenate along the axis + # Compute output shape - concatenate along the axis (in Tile IR order) rhs_shape = rhs.shape - output_shape = collect(Int, lhs_shape) - output_shape[axis + 1] += rhs_shape[axis + 1] # 1-based indexing - validate_tile_shape(output_shape, "cat") + output_shape = copy(lhs_shape) + output_shape[tileir_axis + 1] += rhs_shape[tileir_axis + 1] + validate_tile_shape(collect(output_shape), "cat") # Get element type lhs_type = CC.widenconst(lhs.jltype) @@ -154,12 +154,13 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.cat), args) # Create output tile type dtype = julia_to_tile_dtype!(tt, elem_type) - output_tile_type = tile_type!(tt, dtype, output_shape) + output_tile_type = tile_type!(tt, dtype, collect(output_shape)) - # Emit CatOp (axis is 0-indexed for bytecode) - result = encode_CatOp!(cb, output_tile_type, lhs.v, rhs.v, axis) + # Emit CatOp (Tile IR axis) + result = encode_CatOp!(cb, output_tile_type, lhs.v, rhs.v, tileir_axis) - CGVal(result, output_tile_type, Tile{elem_type, Tuple{output_shape...}}, output_shape) + julia_output = ColMajorShape(output_shape) + CGVal(result, output_tile_type, Tile{elem_type, TupleType(julia_output)}, output_shape) end # cuda_tile.constant @@ -175,17 +176,17 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.constant), args) cb = ctx.cb tt = ctx.tt - # Extract shape + # Extract shape (from Julia user code) and reverse to Tile IR order shape = @something get_constant(ctx, args[1]) throw(IRError("fill() shape must be a compile-time constant")) shape isa Tuple || throw(IRError("fill() shape must be a tuple, got $(typeof(shape))")) - tile_shape = collect(Int, shape) - validate_tile_shape(tile_shape, "fill") + validate_tile_shape(collect(Int, shape), "fill") + tile_shape = RowMajorShape(shape) # Extract dtype from Type{T} argument elem_type = @something get_constant(ctx, args[3]) 
throw(IRError("constant() requires a compile-time element type")) dtype = julia_to_tile_dtype!(tt, elem_type) - tile_type = tile_type!(tt, dtype, tile_shape) + tile_type = tile_type!(tt, dtype, collect(tile_shape)) tv = emit_value!(ctx, args[2]) tv === nothing && throw(IRError("fill() value must be a constant or a runtime scalar")) @@ -198,7 +199,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.constant), args) result = broadcast_tile_to_shape!(cb, tt, tv, tile_shape, dtype) end - CGVal(result, tile_type, Tile{elem_type, Tuple{tile_shape...}}, tile_shape) + CGVal(result, tile_type, Tile{elem_type, Tuple{shape...}}, tile_shape) end # TODO: cuda_tile.entry @@ -221,26 +222,27 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.extract), args) source = emit_value!(ctx, args[1]) source === nothing && throw(IRError("Cannot resolve source operand for extract()")) - # Extract index + # Extract index (reverse for Tile IR order) index_tuple = @something get_constant(ctx, args[2]) throw(IRError("extract() index must be a compile-time constant")) index_tuple isa Tuple || throw(IRError("extract() index must be a tuple, got $(typeof(index_tuple))")) + # Extract shape (reverse for Tile IR order) shape_tuple = @something get_constant(ctx, args[3]) throw(IRError("extract() shape must be a compile-time constant")) shape_tuple isa Tuple || throw(IRError("extract() shape must be a tuple, got $(typeof(shape_tuple))")) - output_shape = collect(Int, shape_tuple) - validate_tile_shape(output_shape, "extract") + validate_tile_shape(collect(Int, shape_tuple), "extract") + output_shape = RowMajorShape(shape_tuple) # Get element type elem_type = eltype(CC.widenconst(source.jltype)) # Create output tile type dtype = julia_to_tile_dtype!(tt, elem_type) - output_tile_type = tile_type!(tt, dtype, output_shape) + output_tile_type = tile_type!(tt, dtype, collect(output_shape)) - # Create constant index values (0D i32 tiles) + # Create constant index values (0D i32 tiles), 
reversed for Tile IR order scalar_i32 = tile_type!(tt, I32(tt), Int[]) index_vals = Value[] - for idx in index_tuple + for idx in reverse(index_tuple) idx_bytes = collect(reinterpret(UInt8, [Int32(idx)])) idx_val = encode_ConstantOp!(cb, scalar_i32, idx_bytes) push!(index_vals, idx_val) @@ -249,7 +251,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.extract), args) # Emit ExtractOp result = encode_ExtractOp!(cb, output_tile_type, source.v, index_vals) - CGVal(result, output_tile_type, Tile{elem_type, Tuple{output_shape...}}, output_shape) + CGVal(result, output_tile_type, Tile{elem_type, Tuple{shape_tuple...}}, output_shape) end # TODO: cuda_tile.get_global @@ -296,22 +298,22 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.iota), args) cb = ctx.cb tt = ctx.tt - # Extract shape + # Extract shape (from Julia) and reverse to Tile IR order shape = @something get_constant(ctx, args[1]) throw(IRError("iota() shape must be a compile-time constant")) shape isa Tuple || throw(IRError("iota() shape must be a tuple, got $(typeof(shape))")) - tile_shape = collect(Int, shape) - validate_tile_shape(tile_shape, "arange") + validate_tile_shape(collect(Int, shape), "arange") + tile_shape = RowMajorShape(shape) # Extract dtype from Type{T} argument elem_type = @something get_constant(ctx, args[2]) throw(IRError("iota() requires a compile-time element type")) dtype = julia_to_tile_dtype!(tt, elem_type) - tile_type = tile_type!(tt, dtype, tile_shape) + tile_type = tile_type!(tt, dtype, collect(tile_shape)) # Emit IotaOp result = encode_IotaOp!(cb, tile_type) - CGVal(result, tile_type, Tile{elem_type, Tuple{tile_shape...}}, tile_shape) + CGVal(result, tile_type, Tile{elem_type, Tuple{shape...}}, tile_shape) end # cuda_tile.mmaf, cuda_tile.mmai @@ -364,13 +366,13 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.offset), args) ptr_elem_type = eltype(base_ptr_type) # T from Ptr{T} elem_dtype = julia_to_tile_dtype!(tt, ptr_elem_type) ptr_dtype = 
pointer_type!(tt, elem_dtype) - ptr_tile_type = tile_type!(tt, ptr_dtype, tile_shape) + ptr_tile_type = tile_type!(tt, ptr_dtype, collect(tile_shape)) # Broadcast base pointer to tile shape ndims = length(tile_shape) if ndims > 0 - ones_shape = fill(1, ndims) - reshaped_ptr_type = tile_type!(tt, ptr_dtype, ones_shape) + ones_shape = RowMajorShape(fill(1, ndims)) + reshaped_ptr_type = tile_type!(tt, ptr_dtype, collect(ones_shape)) base_ptr_reshaped = encode_ReshapeOp!(cb, reshaped_ptr_type, base_ptr) base_ptr_tile = encode_BroadcastOp!(cb, ptr_tile_type, base_ptr_reshaped) else @@ -380,7 +382,8 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.offset), args) # Compute offset pointers: base_ptr + offsets (element offset) pointers = encode_OffsetOp!(cb, ptr_tile_type, base_ptr_tile, offsets) - result_jltype = Tile{Ptr{ptr_elem_type}, Tuple{tile_shape...}} + julia_shape = ColMajorShape(tile_shape) + result_jltype = Tile{Ptr{ptr_elem_type}, TupleType(julia_shape)} CGVal(pointers, ptr_tile_type, result_jltype, tile_shape) end @@ -410,28 +413,30 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.permute), args) input_shape = source.shape isempty(input_shape) && throw(IRError("Cannot determine tile shape for permute()")) - # Extract permutation + # Extract permutation (0-indexed Julia order) and transform to Tile IR order perm_tuple = @something get_constant(ctx, args[2]) throw(IRError("permute() permutation must be a compile-time constant")) perm_tuple isa Tuple || throw(IRError("permute() permutation must be a tuple, got $(typeof(perm_tuple))")) - # Convert to 0-indexed vector for bytecode - permutation = collect(Int, perm_tuple) + julia_perm = collect(Int, perm_tuple) + n = length(julia_perm) + # Transform: q[i'] = n-1 - p[n-1-i'] (maps Julia perm to Tile IR perm) + tileir_perm = [n - 1 - julia_perm[n - i] for i in 0:n-1] - # Compute output shape based on permutation - # permutation[i] tells us which input dimension goes to output position i - 
output_shape = [input_shape[p + 1] for p in permutation] + # Compute output shape based on Tile IR permutation (input_shape is already Tile IR order) + output_shape = RowMajorShape([input_shape[q + 1] for q in tileir_perm]) # Get element type elem_type = eltype(CC.widenconst(source.jltype)) # Create output tile type dtype = julia_to_tile_dtype!(tt, elem_type) - output_tile_type = tile_type!(tt, dtype, output_shape) + output_tile_type = tile_type!(tt, dtype, collect(output_shape)) - # Emit PermuteOp - result = encode_PermuteOp!(cb, output_tile_type, source.v, permutation) + # Emit PermuteOp with Tile IR permutation + result = encode_PermuteOp!(cb, output_tile_type, source.v, tileir_perm) - CGVal(result, output_tile_type, Tile{elem_type, Tuple{output_shape...}}, output_shape) + julia_output = ColMajorShape(output_shape) + CGVal(result, output_tile_type, Tile{elem_type, TupleType(julia_output)}, output_shape) end @@ -470,7 +475,7 @@ function emit_reduce!(ctx::CGCtx, args) end for ref in first_tv.tuple] N = length(tile_tvs) - axis = @something get_constant(ctx, args[2]) throw(IRError("reduce() axis must be a compile-time constant")) + julia_axis = @something get_constant(ctx, args[2]) throw(IRError("reduce() axis must be a compile-time constant")) func = @something get_constant(ctx, args[3]) throw(IRError("reduce() combiner function must be a compile-time constant")) id_tv = emit_value!(ctx, args[4]) @@ -480,12 +485,16 @@ function emit_reduce!(ctx::CGCtx, args) throw(IRError("reduce() identity values must be compile-time constants"))) for ref in id_tv.tuple] - # Get shapes from the first tile + # Get shapes from the first tile (already in Tile IR order) input_shape = tile_tvs[1].shape isempty(input_shape) && throw(IRError("Cannot reduce scalar tile")) + # Flip axis from Julia 0-indexed to Tile IR order + ndim = length(input_shape) + axis = ndim - 1 - julia_axis + # ReduceOp removes the dimension; we'll reshape after to reintroduce it as size 1 - reduced_shape = 
Int[input_shape[i] for i in eachindex(input_shape) if i != axis + 1] + reduced_shape = RowMajorShape([input_shape[i] for i in eachindex(input_shape) if i != axis + 1]) # Build per-operand types and values elem_types = Type[] @@ -500,7 +509,7 @@ function emit_reduce!(ctx::CGCtx, args) push!(elem_types, etype) dtype = julia_to_tile_dtype!(tt, etype) push!(dtypes, dtype) - push!(reduced_tile_types, tile_type!(tt, dtype, reduced_shape)) + push!(reduced_tile_types, tile_type!(tt, dtype, collect(reduced_shape))) push!(scalar_tile_types, tile_type!(tt, dtype, Int[])) push!(operand_values, tv.v::Value) push!(identities, make_identity_val(identity_vals[k], dtype, etype)) @@ -526,13 +535,14 @@ function emit_reduce!(ctx::CGCtx, args) output_shape = copy(input_shape) output_shape[axis + 1] = 1 + julia_output = ColMajorShape(output_shape) reshaped_values = Value[] component_types = Type[] for (k, res) in enumerate(results) - out_type = tile_type!(tt, dtypes[k], output_shape) + out_type = tile_type!(tt, dtypes[k], collect(output_shape)) reshaped_val = encode_ReshapeOp!(cb, out_type, res) push!(reshaped_values, reshaped_val) - push!(component_types, Tile{elem_types[k], Tuple{output_shape...}}) + push!(component_types, Tile{elem_types[k], TupleType(julia_output)}) end # Return multi-value CGVal (tuple) @@ -578,53 +588,22 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.reshape), args) source = emit_value!(ctx, args[1]) source === nothing && throw(IRError("Cannot resolve source operand for reshape()")) - # Extract target shape + # Extract target shape (from Julia) and reverse to Tile IR order target_shape_tuple = @something get_constant(ctx, args[2]) throw(IRError("reshape() shape must be a compile-time constant")) target_shape_tuple isa Tuple || throw(IRError("reshape() shape must be a tuple, got $(typeof(target_shape_tuple))")) - target_shape = collect(Int, target_shape_tuple) - validate_tile_shape(target_shape, "reshape") - - # Get element type and source shape - 
source_type = CC.widenconst(source.jltype) - elem_type = eltype(source_type) - source_shape = collect(Int, size(source_type)) + validate_tile_shape(collect(Int, target_shape_tuple), "reshape") + target_shape = RowMajorShape(target_shape_tuple) + # Get element type + elem_type = eltype(CC.widenconst(source.jltype)) dtype = julia_to_tile_dtype!(tt, elem_type) - # Tile IR's ReshapeOp uses row-major element ordering, but Julia uses column-major. - # To achieve Julia's column-major reshape semantics, we need to: - # 1. Permute source to row-major order (reverse dims) if ndim > 1 - # 2. Reshape with reversed target shape - # 3. Permute result back to column-major order (reverse dims) if ndim > 1 - - current_val = source.v - current_shape = source_shape - - # Step 1: Permute source if >1 dimension (column-major → row-major) - if length(current_shape) > 1 - perm = collect(length(current_shape)-1:-1:0) # 0-indexed reverse - permuted_shape = reverse(current_shape) - perm_type_id = tile_type!(tt, dtype, permuted_shape) - current_val = encode_PermuteOp!(cb, perm_type_id, current_val, perm) - current_shape = permuted_shape - end + # Tile IR shapes are already in row-major order, so ReshapeOp's row-major element + # ordering matches directly. No permutes needed! 
+ result_type_id = tile_type!(tt, dtype, collect(target_shape)) + result = encode_ReshapeOp!(cb, result_type_id, source.v) - # Step 2: ReshapeOp with reversed target shape - reversed_target = reverse(target_shape) - reshape_type_id = tile_type!(tt, dtype, reversed_target) - current_val = encode_ReshapeOp!(cb, reshape_type_id, current_val) - current_shape = reversed_target - - # Step 3: Permute result back if >1 dimension (row-major → column-major) - if length(target_shape) > 1 - perm = collect(length(target_shape)-1:-1:0) # 0-indexed reverse - result_type_id = tile_type!(tt, dtype, target_shape) - current_val = encode_PermuteOp!(cb, result_type_id, current_val, perm) - else - result_type_id = tile_type!(tt, dtype, target_shape) - end - - CGVal(current_val, result_type_id, Tile{elem_type, Tuple{target_shape...}}, target_shape) + CGVal(result, result_type_id, Tile{elem_type, Tuple{target_shape_tuple...}}, target_shape) end # cuda_tile.scan @@ -653,7 +632,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.scan), args) end for ref in first_tv.tuple] N = length(tile_tvs) - axis = @something get_constant(ctx, args[2]) throw(IRError("scan() axis must be a compile-time constant")) + julia_axis = @something get_constant(ctx, args[2]) throw(IRError("scan() axis must be a compile-time constant")) func = @something get_constant(ctx, args[3]) throw(IRError("scan() combiner function must be a compile-time constant")) id_tv = emit_value!(ctx, args[4]) @@ -669,10 +648,14 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.scan), args) reverse = reverse_val === true end - # Get shapes from the first tile + # Get shapes from the first tile (already in Tile IR order) input_shape = tile_tvs[1].shape isempty(input_shape) && throw(IRError("Cannot scan scalar tile")) + # Flip axis from Julia 0-indexed to Tile IR order + ndim = length(input_shape) + axis = ndim - 1 - julia_axis + # For scan, output shape is same as input shape output_shape = copy(input_shape) @@ -689,7 
+672,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.scan), args) push!(elem_types, etype) dtype = julia_to_tile_dtype!(tt, etype) push!(dtypes, dtype) - push!(output_tile_types, tile_type!(tt, dtype, output_shape)) + push!(output_tile_types, tile_type!(tt, dtype, collect(output_shape))) push!(scalar_tile_types, tile_type!(tt, dtype, Int[])) push!(operand_values, tv.v::Value) push!(identities, make_identity_val(identity_vals[k], dtype, etype)) @@ -712,9 +695,10 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.scan), args) end # Return multi-value CGVal (tuple) + julia_output = ColMajorShape(output_shape) component_types = Type[] for k in 1:N - push!(component_types, Tile{elem_types[k], Tuple{output_shape...}}) + push!(component_types, Tile{elem_types[k], TupleType(julia_output)}) end jltype = Tuple{component_types...} return CGVal(results, jltype) @@ -779,7 +763,7 @@ end function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.from_scalar), args) tv = emit_value!(ctx, args[1]) tv === nothing && throw(IRError("Cannot resolve scalar for from_scalar")) - shape_type = @something get_constant(ctx, args[2]) throw(IRError("from_scalar() shape must be a compile-time constant")) + shape_type = @something get_constant(ctx, args[2]) throw(IRError("from_scalar() shape must be a compile-time constant")) T = CC.widenconst(tv.jltype) CGVal(tv.v, tv.type_id, Tile{T, shape_type}, tv.shape, nothing, nothing, nothing) end diff --git a/src/compiler/intrinsics/julia.jl b/src/compiler/intrinsics/julia.jl index c7dc562d..41888cea 100644 --- a/src/compiler/intrinsics/julia.jl +++ b/src/compiler/intrinsics/julia.jl @@ -56,7 +56,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(===), args) result_v = encode_CmpIOp!(cb, result_type_id, lhs.v, rhs.v; predicate=ComparisonPredicate.Equal, signedness=Signedness.Signed) - CGVal(result_v, result_type_id, Bool, Int[]) + CGVal(result_v, result_type_id, Bool, ScalarShape()) end # built-in: tuple diff --git 
a/src/compiler/intrinsics/memory.jl b/src/compiler/intrinsics/memory.jl index d4c52eac..2bfd34fc 100644 --- a/src/compiler/intrinsics/memory.jl +++ b/src/compiler/intrinsics/memory.jl @@ -29,7 +29,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.load_ptr_tko), args) ptr_type = eltype(ptrs_type) # Ptr{T} from Tile{Ptr{T}, S} elem_type = eltype(ptr_type) # T from Ptr{T} dtype = julia_to_tile_dtype!(tt, elem_type) - result_tile_type = tile_type!(tt, dtype, tile_shape) + result_tile_type = tile_type!(tt, dtype, collect(tile_shape)) token_type = Token(tt) # Extract latency hint (args[2]) @@ -61,7 +61,8 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.load_ptr_tko), args) end ctx.token = new_token - result_jltype = Tile{elem_type, Tuple{tile_shape...}} + julia_shape = ColMajorShape(tile_shape) + result_jltype = Tile{elem_type, TupleType(julia_shape)} CGVal(tile_val, result_tile_type, result_jltype, tile_shape) end diff --git a/src/compiler/intrinsics/misc.jl b/src/compiler/intrinsics/misc.jl index fa1c4ba6..32c47ef4 100644 --- a/src/compiler/intrinsics/misc.jl +++ b/src/compiler/intrinsics/misc.jl @@ -40,18 +40,22 @@ function emit_assume_ops!(ctx::CGCtx, array_val::Value, size_vals::Vector{Value} end # Divisibility assumes for sizes - for (i, div_by) in enumerate(array_spec.shape_div_by) - if div_by > 0 && i <= length(size_vals) - size_vals[i] = encode_AssumeOp!(cb, scalar_type, size_vals[i], DivBy(div_by)) + # ArraySpec fields are in Julia order; size_vals are in Tile IR order (reversed) + ndim = length(size_vals) + for (julia_i, div_by) in enumerate(array_spec.shape_div_by) + tileir_i = ndim + 1 - julia_i # Reverse index mapping + if div_by > 0 && tileir_i <= length(size_vals) + size_vals[tileir_i] = encode_AssumeOp!(cb, scalar_type, size_vals[tileir_i], DivBy(div_by)) end end # Divisibility assumes for strides - only for dynamic strides - for (i, div_by) in enumerate(array_spec.stride_div_by) - if div_by > 0 && i <= length(stride_vals) + for 
(julia_i, div_by) in enumerate(array_spec.stride_div_by) + tileir_i = ndim + 1 - julia_i # Reverse index mapping + if div_by > 0 && tileir_i <= length(stride_vals) # Skip if this stride is static (not DYNAMIC_SHAPE) - if tv_strides === nothing || tv_strides[i] == DYNAMIC_SHAPE - stride_vals[i] = encode_AssumeOp!(cb, scalar_type, stride_vals[i], DivBy(div_by)) + if tv_strides === nothing || tv_strides[tileir_i] == DYNAMIC_SHAPE + stride_vals[tileir_i] = encode_AssumeOp!(cb, scalar_type, stride_vals[tileir_i], DivBy(div_by)) end end end diff --git a/src/compiler/intrinsics/views.jl b/src/compiler/intrinsics/views.jl index 4b05391b..560a7795 100644 --- a/src/compiler/intrinsics/views.jl +++ b/src/compiler/intrinsics/views.jl @@ -14,7 +14,7 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.get_index_space_shape), pv_arg === nothing && throw(IRError("get_index_space_shape() requires a PartitionView argument")) pv_arg.v === nothing && throw(IRError("get_index_space_shape() requires a materialized PartitionView")) - # Get axis (0-indexed) + # Get axis (0-indexed Julia) and flip to Tile IR order axis = @something get_constant(ctx, args[2]) throw(IRError("get_index_space_shape() axis must be a compile-time constant")) axis = Int(axis) @@ -22,6 +22,9 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.get_index_space_shape), pv_arg.constant === nothing && throw(IRError("get_index_space_shape(): PartitionView missing ndim info")) ndim = something(pv_arg.constant) + # Flip axis for row-major Tile IR: Julia dim 0 → Tile IR dim ndim-1 + tileir_axis = ndim - 1 - axis + # Create result types for all dimensions scalar_i32 = tile_type!(tt, I32(tt), Int[]) result_types = fill(scalar_i32, ndim) @@ -29,9 +32,9 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.get_index_space_shape), # Emit GetIndexSpaceShapeOp shape_vals = encode_GetIndexSpaceShapeOp!(cb, result_types, pv_arg.v) - # Return the value for the requested axis + # Return the value for the 
requested axis (in Tile IR order) # shape_vals is a single Value when ndim == 1, otherwise a Tuple - result_val = ndim == 1 ? shape_vals : shape_vals[axis + 1] + result_val = ndim == 1 ? shape_vals : shape_vals[tileir_axis + 1] CGVal(result_val, scalar_i32, Int32) end @@ -63,12 +66,13 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.load_partition_view), a ndim = something(pv_arg.constant) # Extract tile shape from PartitionView type (PartitionView{T, N, Shape}) + # Reverse to Tile IR row-major order pv_type = CC.widenconst(pv_arg.jltype) elem_type = eltype(pv_type) - tile_shape = collect(Int, size(pv_type)) + tile_shape = RowMajorShape(size(pv_type)) dtype = julia_to_tile_dtype!(tt, elem_type) - tile_type = tile_type!(tt, dtype, tile_shape) + tile_type = tile_type!(tt, dtype, collect(tile_shape)) token_type = Token(tt) latency = @something get_constant(ctx, args[2]) throw(IRError("load_partition_view(): latency must be a compile-time constant")) @@ -97,8 +101,9 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.load_partition_view), a index_jl_type = isempty(unique_types) ? 
Int32 : unique_types[1] # Int32 only for 0D case index_type = tile_type_for_julia!(ctx, index_jl_type) - # Pad indices if needed + # Pad indices if needed, then reverse for Tile IR row-major order index_vals = pad_indices(ctx, index_vals, ndim, index_type, index_jl_type) + reverse!(index_vals) # Create optimization hints if provided optimization_hints = create_optimization_hints(ctx, latency, allow_tma_val) @@ -108,7 +113,8 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.load_partition_view), a token=ctx.token, optimization_hints) ctx.token = new_token - CGVal(tile_val, tile_type, Tile{elem_type, Tuple{tile_shape...}}, tile_shape) + julia_shape = ColMajorShape(tile_shape) + CGVal(tile_val, tile_type, Tile{elem_type, TupleType(julia_shape)}, tile_shape) end function pad_indices(ctx::CGCtx, index_vals::Vector{Value}, ndim::Int, idx_type::TypeId, idx_jl_type::Type) @@ -135,10 +141,11 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.make_partition_view), a tv === nothing && throw(IRError("make_partition_view() requires a TensorView argument")) # Shape from user call (e.g., load(arr, idx, (16,))) + # Reverse to Tile IR row-major order shape = @something get_constant(ctx, args[2]) throw(IRError("make_partition_view() shape must be a compile-time constant")) shape isa Tuple || throw(IRError("make_partition_view() shape must be a tuple, got $(typeof(shape))")) - tile_shape = collect(Int, shape) - validate_tile_shape(tile_shape, "load") + validate_tile_shape(collect(Int, shape), "load") + tile_shape = RowMajorShape(shape) padding_value = if length(args) >= 3 convert_enum(PaddingValue, @something get_constant(ctx, args[3]) throw(IRError("padding_mode must be a compile-time constant"))) @@ -151,19 +158,21 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.make_partition_view), a elem_type = eltype(tv.jltype) ndim = length(tile_shape) - # Extract order (arg 4) + # Extract order (arg 4) and reverse for Tile IR row-major order # nothing → identity 
dim_map, (2,1) → [1, 0] (1-indexed → 0-indexed) order_val = @something get_constant(ctx, args[4]) throw(IRError("make_partition_view() order must be a compile-time constant")) if order_val === nothing dim_map = collect(0:ndim-1) else - dim_map = collect(Int, map(p -> p - 1, order_val)) + # Convert Julia dim_map to Tile IR: reverse and remap indices + julia_dim_map = collect(Int, map(p -> p - 1, order_val)) + dim_map = [ndim - 1 - julia_dim_map[ndim - i] for i in 0:ndim-1] end - pv_type = partition_view_type!(ctx.tt, tile_shape, tv_type, dim_map, padding_value) + pv_type = partition_view_type!(ctx.tt, collect(tile_shape), tv_type, dim_map, padding_value) partition = encode_MakePartitionViewOp!(ctx.cb, pv_type, tensor_view) - CGVal(partition, pv_type, PartitionView{elem_type, ndim, Tuple{tile_shape...}}, Int[], nothing, Some(ndim), nothing) + CGVal(partition, pv_type, PartitionView{elem_type, ndim, Tuple{shape...}}, ScalarShape(), nothing, Some(ndim), nothing) end """ @@ -214,30 +223,33 @@ function cache_tensor_view!(ctx::CGCtx, arg_idx::Int, size_elem_type = eltype(fieldtype(tilearray_type, :sizes)) scalar_size_type = tile_type_for_julia!(ctx, size_elem_type) - # Sizes are passed through directly from parameters - size_vals = Value[sizes_from_arg[i] for i in 1:ndim] + # Sizes in Julia column-major order from parameters + julia_size_vals = Value[sizes_from_arg[i] for i in 1:ndim] - # Strides from parameters or compute for contiguous arrays + # Strides from parameters or compute for contiguous arrays (Julia column-major order) if strides_from_arg !== nothing && length(strides_from_arg) >= ndim - stride_vals = Value[strides_from_arg[i] for i in 1:ndim] + julia_stride_vals = Value[strides_from_arg[i] for i in 1:ndim] else # Compute column-major strides: stride[1]=1, stride[i]=stride[i-1]*size[i-1] - # This matches Julia's memory layout where the first dimension is contiguous - stride_vals = Value[] + julia_stride_vals = Value[] for dim in 1:ndim if dim == 1 stride_bytes 
= reinterpret(UInt8, [size_elem_type(1)]) - push!(stride_vals, encode_ConstantOp!(cb, scalar_size_type, collect(stride_bytes))) + push!(julia_stride_vals, encode_ConstantOp!(cb, scalar_size_type, collect(stride_bytes))) else - push!(stride_vals, encode_MulIOp!(cb, scalar_size_type, stride_vals[end], size_vals[dim-1])) + push!(julia_stride_vals, encode_MulIOp!(cb, scalar_size_type, julia_stride_vals[end], julia_size_vals[dim-1])) end end end - # TensorView type - tv_shape = fill(DYNAMIC_SHAPE, ndim) + # Reverse sizes and strides for Tile IR row-major order + size_vals = reverse(julia_size_vals) + stride_vals = reverse(julia_stride_vals) + + # TensorView type (strides also in Tile IR order: last dim = contiguous) + tv_shape = RowMajorShape(fill(DYNAMIC_SHAPE, ndim)) tv_strides = compute_tensor_view_strides(spec, ndim) - tv_type = tensor_view_type!(tt, dtype, tv_shape, tv_strides) + tv_type = tensor_view_type!(tt, dtype, collect(tv_shape), tv_strides) # Emit AssumeOps for optimization hints if spec !== nothing @@ -268,8 +280,8 @@ function compute_tensor_view_strides(array_spec::Union{ArraySpec, Nothing}, ndim strides = fill(DYNAMIC_SHAPE, ndim) if array_spec !== nothing && array_spec.contiguous && ndim >= 1 - # Contiguous column-major array: first stride is statically known to be 1 - strides[1] = 1 + # Contiguous column-major array: Julia stride[1]=1 becomes Tile IR stride[ndim]=1 + strides[ndim] = 1 end return strides @@ -395,8 +407,9 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.store_partition_view), index_jl_type = isempty(unique_types) ? 
Int32 : unique_types[1] # Int32 only for 0D case index_type = tile_type_for_julia!(ctx, index_jl_type) - # Pad indices if needed + # Pad indices if needed, then reverse for Tile IR row-major order index_vals = pad_indices(ctx, index_vals, actual_ndim, index_type, index_jl_type) + reverse!(index_vals) # Create optimization hints if provided optimization_hints = create_optimization_hints(ctx, latency, allow_tma_val) diff --git a/src/language/operations.jl b/src/language/operations.jl index 0d83d14f..1d8edce3 100644 --- a/src/language/operations.jl +++ b/src/language/operations.jl @@ -928,9 +928,10 @@ end _muladd(a, b, acc, Val(ndims(a)), Val(ndims(b))) end -# 2D × 2D: direct MmaFOp with type promotion +# 2D × 2D: MmaFOp with swapped operands for row-major Tile IR +# Julia (M,K)*(K,N) → TileIR (K,M)*(N,K) → mmaf(b,a,acc) → TileIR (N,M) → Julia (M,N) @inline function _muladd(a::Tile, b::Tile, acc::Tile, ::Val{2}, ::Val{2}) - Intrinsics.mma(a, b, acc) + Intrinsics.mma(b, a, acc) end # Vec-mat (1D × 2D): reshape (M,) → (M, 1), MmaFOp, acc is already (M, N) @@ -965,11 +966,11 @@ end # Batched matmul (≥3D × ≥3D): trailing batch dims with broadcast # Julia convention: first two dims are matrix (M,K)/(K,N), trailing dims are batch. -# MmaFOp expects exactly 3D tiles (B, M, K), so we: +# With row-major Tile IR shapes, Julia (M,K,B) → TileIR (B,K,M), so: # 1. Broadcast batch dims to a common shape -# 2. Permute trailing batch → leading -# 3. Flatten multiple batch dims into one for MmaFOp -# 4. Unflatten + permute back after +# 2. Flatten batch dims into one via reshape (no permute needed!) +# 3. MmaFOp with swapped operands: mmaf(b, a, acc) +# 4. 
Unflatten batch dims via reshape @generated function _muladd(a::Tile{T1, SA}, b::Tile{T2, SB}, acc::Tile{T3, SC}, ::Val{NA}, ::Val{NB}) where {T1, T2, T3, SA, SB, SC, NA, NB} sa = Tuple(SA.parameters) @@ -992,14 +993,15 @@ end a_bc = broadcast_to(reshape(a, $((M, K, a_batch_padded...))), $((M, K, batch_shape...))) b_bc = broadcast_to(reshape(b, $((K, N, b_batch_padded...))), $((K, N, batch_shape...))) acc_bc = broadcast_to(acc, $((M, N, batch_shape...))) - # Flatten batch dims to one (still trailing), then permute to leading - a_3d = permutedims(reshape(a_bc, $((M, K, B_flat))), (3, 1, 2)) - b_3d = permutedims(reshape(b_bc, $((K, N, B_flat))), (3, 1, 2)) - acc_3d = permutedims(reshape(acc_bc, $((M, N, B_flat))), (3, 1, 2)) - # MmaFOp - result_3d = Intrinsics.mma(a_3d, b_3d, acc_3d) - # Permute back to trailing, unflatten batch dims - reshape(permutedims(result_3d, (2, 3, 1)), $((M, N, batch_shape...))) + # Flatten batch dims to one — no permute needed since row-major Tile IR + # already has batch as the leading (slowest) dimension + a_3d = reshape(a_bc, $((M, K, B_flat))) + b_3d = reshape(b_bc, $((K, N, B_flat))) + acc_3d = reshape(acc_bc, $((M, N, B_flat))) + # MmaFOp with swapped operands for row-major convention + result_3d = Intrinsics.mma(b_3d, a_3d, acc_3d) + # Unflatten batch dims + reshape(result_3d, $((M, N, batch_shape...))) end end diff --git a/test/codegen/integration.jl b/test/codegen/integration.jl index d8d923c2..d64bca75 100644 --- a/test/codegen/integration.jl +++ b/test/codegen/integration.jl @@ -232,8 +232,9 @@ end @check "loop iter_values" # The store MUST use a column index derived from loopIdx, not the spinloop result # After 1→0 index conversion, the store uses (loopIdx - 1) + # With row-major Tile IR, indices are reversed: Julia (row, col) → TileIR [col, row] @check "[[IDX:%.+]] = subi %loopIdx" - @check "store_view_tko{{.*}}[%{{[^,]+}}, [[IDX]]]" + @check "store_view_tko{{.*}}[[[IDX]], %{{[^,]+}}]" 
code_tiled(Tuple{ct.TileArray{Float32,2,spec}, ct.TileArray{Int32,1,spec1d}, Int32, ct.Constant{Int,4}, ct.Constant{Int,4}}) do DB, Locks, num_iters, GROUP_SIZE_M, TILE_N bid_m = ct.bid(1) diff --git a/test/codegen/operations.jl b/test/codegen/operations.jl index 0a39f8b3..691b0785 100644 --- a/test/codegen/operations.jl +++ b/test/codegen/operations.jl @@ -88,13 +88,13 @@ spec4d = ct.ArraySpec{4}(16, true) # TODO: unpack - unpack tiles @testset "reshape" begin - # 2D -> 1D reshape (emits pre-permute for column-major conversion) + # 2D -> 1D reshape (direct, no permutes with row-major Tile IR) @test @filecheck begin @check_label "entry" + @check_not "permute" code_tiled(Tuple{ct.TileArray{Float32,2,spec2d}, ct.TileArray{Float32,1,spec1d}}) do a, b pid = ct.bid(1) tile = ct.load(a, pid, (4, 8)) - @check "permute" # pre-permute for 2D source @check "reshape" reshaped = reshape(tile, (32,)) ct.store(b, pid, reshaped) @@ -102,29 +102,28 @@ spec4d = ct.ArraySpec{4}(16, true) end end - # 1D -> 2D reshape (emits post-permute for column-major conversion) + # 1D -> 2D reshape (direct, no permutes) @test @filecheck begin @check_label "entry" + @check_not "permute" code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, ct.TileArray{Float32,2,spec2d}}) do a, b pid = ct.bid(1) tile = ct.load(a, pid, (64,)) @check "reshape" - @check "permute" # post-permute for 2D result reshaped = reshape(tile, (8, 8)) ct.store(b, pid, reshaped) return end end - # 3D -> 2D reshape (emits pre-permute and post-permute) + # 3D -> 2D reshape (direct, no permutes) @test @filecheck begin @check_label "entry" + @check_not "permute" code_tiled(Tuple{ct.TileArray{Float32,3,spec3d}, ct.TileArray{Float32,2,spec2d}}) do a, b pid = ct.bid(1) tile = ct.load(a, pid, (2, 4, 8)) - @check "permute" # pre-permute for 3D source @check "reshape" - @check "permute" # post-permute for 2D result reshaped = reshape(tile, (2, 32)) ct.store(b, pid, reshaped) return @@ -145,15 +144,14 @@ spec4d = ct.ArraySpec{4}(16, true) end 
end - # 2D -> 2D reshape (different shape, emits both permutes) + # 2D -> 2D reshape (different shape, direct, no permutes) @test @filecheck begin @check_label "entry" + @check_not "permute" code_tiled(Tuple{ct.TileArray{Float32,2,spec2d}}) do a pid = ct.bid(1) tile = ct.load(a, pid, (4, 8)) - @check "permute" # pre-permute @check "reshape" - @check "permute" # post-permute reshaped = reshape(tile, (8, 4)) ct.store(a, pid, reshaped) return @@ -534,9 +532,8 @@ spec4d = ct.ArraySpec{4}(16, true) bidx = ct.bid(1) tile_a = ct.load(a, bidx, (32, 16, 1)) tile_b = ct.load(b, bidx, (16, 32, 4)) - # batched: broadcast + permute + mma + permute + # batched: broadcast + reshape + mma (no permutes with row-major Tile IR) @check "broadcast" - @check "permute" @check "mma" result = tile_a * tile_b ct.store(c, bidx, result) @@ -564,7 +561,6 @@ spec4d = ct.ArraySpec{4}(16, true) bidx = ct.bid(1) tile_a = ct.load(a, bidx, (16, 8, 2, 4)) tile_b = ct.load(b, bidx, (8, 16, 2, 4)) - @check "permute" @check "mma" result = tile_a * tile_b ct.store(c, bidx, result)