From cc17f4f16ee22f17f201119ee2cfff355280fa4f Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Sun, 15 Feb 2026 19:12:50 +0100 Subject: [PATCH 1/9] add autotuning to Experimental submodule --- Project.toml | 1 + ext/CUDAExt.jl | 4 +- ext/autotune/autotune.jl | 266 +++++++++++++++++++++++++++++++++++++ src/Experimental.jl | 62 +++++++++ src/cuTile.jl | 2 + test/execution/autotune.jl | 245 ++++++++++++++++++++++++++++++++++ 6 files changed, 579 insertions(+), 1 deletion(-) create mode 100644 ext/autotune/autotune.jl create mode 100644 src/Experimental.jl create mode 100644 test/execution/autotune.jl diff --git a/Project.toml b/Project.toml index c554faf0..1cbaea16 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ CUDA_Compiler_jll = "d1e2174e-dfdc-576e-b43e-73b79eb1aca8" CUDA_Tile_jll = "2068806d-a867-5dbd-af0e-42c2eb5d895d" CompilerCaching = "9db33cc3-5358-4881-8759-fa4194144afd" IRStructurizer = "93e32bba-5bb8-402b-805d-ffb066edee93" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [weakdeps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" diff --git a/ext/CUDAExt.jl b/ext/CUDAExt.jl index d81003d9..6ddf303f 100644 --- a/ext/CUDAExt.jl +++ b/ext/CUDAExt.jl @@ -8,7 +8,7 @@ using CompilerCaching: CacheView, method_instance, results import Core.Compiler as CC -using CUDA: CuModule, CuFunction, cudacall, device, capability +using CUDA: CUDA, CuModule, CuFunction, cudacall, device, capability using CUDA_Compiler_jll public launch @@ -201,4 +201,6 @@ Other values pass through unchanged. to_tile_arg(x) = x to_tile_arg(arr::AbstractArray) = TileArray(arr) +include("autotune/autotune.jl") + end diff --git a/ext/autotune/autotune.jl b/ext/autotune/autotune.jl new file mode 100644 index 00000000..a73747eb --- /dev/null +++ b/ext/autotune/autotune.jl @@ -0,0 +1,266 @@ +import cuTile.Experimental: autotune_launch, clear_autotune_cache +using cuTile.Experimental: AbstractSearchSpace, CartesianSpace, FixedSpace + +using Random + +const AUTOTUNE_LOCK = ReentrantLock() +const AUTOTUNE_CACHE = Dict{Any, Dict{Any, Any}}() + +struct VerificationError <: Exception + msg::String +end + +const TUNING_PRESETS = ( + fast = (warmup=1, reps=3, refine_topk=0, refine_reps=0), + default = (warmup=2, reps=6, refine_topk=2, refine_reps=8), + thorough = (warmup=2, reps=9, refine_topk=4, refine_reps=16), +) + +function normalize_tuning(tuning::NamedTuple) + preset = get(tuning, :preset, :default) + preset isa Symbol || throw(ArgumentError("tuning.preset must be a Symbol")) + hasproperty(TUNING_PRESETS, preset) || + throw(ArgumentError("Unknown preset `$preset`; use :fast, :default, or :thorough")) + + base = merge(getproperty(TUNING_PRESETS, preset), + (seed=nothing, force=false, precompile_workers=Threads.nthreads())) + + # Apply user overrides (excluding :preset) + overrides = NamedTuple(k => v for (k, v) in pairs(tuning) if k !== :preset) + return merge(base, overrides) +end + +# Extract hint fields (occupancy, num_ctas) from a config for launch kwargs +function hints_from_cfg(cfg) + n = hasproperty(cfg, :num_ctas) ? cfg.num_ctas : nothing + o = hasproperty(cfg, :occupancy) ? cfg.occupancy : nothing + return (num_ctas=n, occupancy=o) +end + +function time_ms(run_once::Function, get_args::Function; + warmup::Int, reps::Int, verify::Union{Nothing, Function}=nothing) + CUDA.synchronize() + for _ in 1:max(warmup, verify !== nothing ? 1 : 0) + run_once(get_args()) + end + + if verify !== nothing + CUDA.synchronize() + verify() || throw(VerificationError("config produced incorrect output")) + end + + best_ms = Inf32 + for _ in 1:reps + args = get_args() + CUDA.synchronize() + elapsed_s = CUDA.@elapsed run_once(args) + CUDA.synchronize() + best_ms = min(best_ms, Float32(elapsed_s * 1000)) + end + return best_ms +end + +function eval_cfg(@nospecialize(f), cfg, grid_fn::Function, args_fn::Function; + sm_arch::String, opt_level::Int, warmup::Int, reps::Int, + verify::Union{Nothing, Function}=nothing) + run_once = args -> cuTile.launch(f, grid_fn(cfg), args...; + sm_arch, opt_level, hints_from_cfg(cfg)...) + return time_ms(run_once, () -> args_fn(cfg); warmup, reps, verify) +end + +function precompile_cfg(@nospecialize(f), cfg, grid_fn::Function, args_fn::Function; + sm_arch::String, opt_level::Int) + grid_fn(cfg) + args = args_fn(cfg) + tile_args = map(to_tile_arg, args) + + # Mirror launch's Constant handling + unwrapped_types = map(tile_args) do arg + arg isa Constant ? constant_eltype(typeof(arg)) : typeof(arg) + end + argtypes = Tuple{unwrapped_types...} + + world = Base.get_world_counter() + mi = method_instance(f, argtypes; world) + mi === nothing && throw(MethodError(f, argtypes)) + + has_consts = any(x -> x isa Constant, tile_args) + const_argtypes = if has_consts + cats = Any[CC.Const(f)] + for arg in tile_args + push!(cats, arg isa Constant ? CC.Const(arg[]) : typeof(arg)) + end + cats + else + nothing + end + + hints = hints_from_cfg(cfg) + opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=hints.num_ctas, occupancy=hints.occupancy) + cache = CacheView{CuTileResults}((:cuTile, opts), world) + emit_function(cache, mi; const_argtypes) +end + +function precompile_candidates(@nospecialize(f), configs::Vector{Any}, + grid_fn::Function, args_fn::Function; + sm_arch::String, opt_level::Int, workers::Int) + isempty(configs) && return configs, nothing + iszero(workers) && return configs, nothing + + workers = min(workers, Threads.nthreads(), length(configs)) + compiled = fill(true, length(configs)) + errors = Vector{Any}(nothing, length(configs)) + sem = Base.Semaphore(workers) + + @sync for (i, cfg) in enumerate(configs) + Threads.@spawn Base.acquire(sem) do + try + precompile_cfg(f, cfg, grid_fn, args_fn; sm_arch, opt_level) + catch err + compiled[i] = false + errors[i] = (cfg, err) + end + end + end + + first_err = nothing + for e in errors + if e !== nothing + first_err = e + break + end + end + + return configs[compiled], first_err +end + +function measure_candidates(@nospecialize(f), configs::Vector{Any}, + grid_fn::Function, args_fn::Function; + sm_arch::String, opt_level::Int, warmup::Int, reps::Int, + verify::Union{Nothing, Function}=nothing) + record = Tuple{Any, Float32}[] + first_error = nothing + for cfg in configs + ms = try + eval_cfg(f, cfg, grid_fn, args_fn; sm_arch, opt_level, warmup, reps, verify) + catch err + err isa VerificationError && @warn "Config $cfg failed verification, skipping" + first_error === nothing && (first_error = (cfg, err)) + continue + end + push!(record, (cfg, ms)) + end + return record, first_error +end + +function find_or_tune(@nospecialize(f), space::AbstractSearchSpace, rng::AbstractRNG, + grid_fn::Function, args_fn::Function, tuning; + sm_arch::String, opt_level::Int, kernel_key, arg_key, + verify::Union{Nothing, Function}=nothing) + if !tuning.force + entry = lock(AUTOTUNE_LOCK) do + per_kernel = get(AUTOTUNE_CACHE, kernel_key, nothing) + per_kernel !== nothing ? get(per_kernel, arg_key, nothing) : nothing + end + entry !== nothing && return entry, true + end + + checker = verify !== nothing ? verify() : nothing + + trials = collect(space) + + trials = Any[trials...] + trials, precompile_error = precompile_candidates(f, trials, grid_fn, args_fn; + sm_arch, opt_level, workers=tuning.precompile_workers) + + record, first_error = measure_candidates(f, trials, grid_fn, args_fn; + sm_arch, opt_level, warmup=tuning.warmup, reps=tuning.reps, verify=checker) + + if isempty(record) + # Prefer showing the precompile error (more informative) over the benchmark error + err_info = first_error !== nothing ? first_error : precompile_error + if err_info === nothing + throw(ArgumentError("No valid config found in search space")) + else + cfg, err = err_info + throw(ArgumentError( + "No valid config found. First failure for cfg=$cfg: $(sprint(showerror, err))")) + end + end + + # Refinement: re-benchmark top K with more reps to stabilize the winner + if tuning.refine_topk > 0 && length(record) > 1 + sort!(record, by=last) + top_configs = Any[first(r) for r in record[1:min(tuning.refine_topk, length(record))]] + refined, _ = measure_candidates(f, top_configs, grid_fn, args_fn; + sm_arch, opt_level, warmup=tuning.warmup, reps=tuning.refine_reps) + if !isempty(refined) + record = refined + end + end + + _, best_idx = findmin(last, record) + candidate = (; best_config=record[best_idx][1], tuning_record=record) + + return lock(AUTOTUNE_LOCK) do + per_kernel = get!(Dict{Any,Any}, AUTOTUNE_CACHE, kernel_key) + if !tuning.force && haskey(per_kernel, arg_key) + per_kernel[arg_key], true + else + per_kernel[arg_key] = candidate + candidate, false + end + end +end + +function autotune_launch(@nospecialize(f), space::AbstractSearchSpace, + grid_fn::Function, args_fn::Function; + key=nothing, + key_fn::Union{Nothing, Function}=nothing, + launch_args_fn::Union{Nothing, Function}=nothing, + verify::Union{Nothing, Function}=nothing, + tuning::NamedTuple=NamedTuple(), + sm_arch::String=default_sm_arch(), + opt_level::Int=3) + tuning = normalize_tuning(tuning) + rng = tuning.seed !== nothing ? MersenneTwister(tuning.seed) : Random.default_rng() + + kernel_key = (f, sm_arch, opt_level) + arg_key = key !== nothing ? key : (key_fn !== nothing ? key_fn() : nothing) + + entry, cache_hit = find_or_tune(f, space, rng, grid_fn, args_fn, tuning; + sm_arch, opt_level, kernel_key, arg_key, verify) + + cfg = entry.best_config + grid = grid_fn(cfg) + args = launch_args_fn !== nothing ? launch_args_fn(cfg) : args_fn(cfg) + + cuTile.launch(f, grid, args...; sm_arch, opt_level, hints_from_cfg(cfg)...) + + return (; tuned_config=cfg, grid, tuning_record=copy(entry.tuning_record), cache_hit) +end + +# Convenience: accept plain Vector (→ FixedSpace) or NamedTuple (→ CartesianSpace) +function autotune_launch(@nospecialize(f), configs, grid_fn::Function, args_fn::Function; kwargs...) + space = configs isa NamedTuple ? CartesianSpace(configs) : FixedSpace(configs) + return autotune_launch(f, space, grid_fn, args_fn; kwargs...) +end + +function clear_autotune_cache(; kernel=nothing, key=nothing) + lock(AUTOTUNE_LOCK) do + if kernel === nothing + key === nothing || throw(ArgumentError("`key` requires `kernel`")) + empty!(AUTOTUNE_CACHE) + return nothing + end + + for kernel_key in collect(keys(AUTOTUNE_CACHE)) + kernel_key isa Tuple || continue + kernel_key[1] === kernel || continue + per_kernel = AUTOTUNE_CACHE[kernel_key] + key === nothing ? empty!(per_kernel) : pop!(per_kernel, key, nothing) + isempty(per_kernel) && delete!(AUTOTUNE_CACHE, kernel_key) + end + end + return nothing +end diff --git a/src/Experimental.jl b/src/Experimental.jl new file mode 100644 index 00000000..219b9612 --- /dev/null +++ b/src/Experimental.jl @@ -0,0 +1,62 @@ +module Experimental + +autotune_launch(args...; kwargs...) = + error("Please import CUDA.jl before using `cuTile.autotune_launch`.") +clear_autotune_cache(args...; kwargs...) = + error("Please import CUDA.jl before using `cuTile.clear_autotune_cache`.") + +abstract type AbstractSearchSpace end + +Base.length(s::AbstractSearchSpace) = count(_ -> true, s) + +struct FixedSpace{names,NT<:NamedTuple{names}} <: AbstractSearchSpace + elements::Vector{NT} +end + +Base.iterate(space::FixedSpace, args...) = iterate(space.elements, args...) + +struct CartesianSpace{names,NT<:NamedTuple{names,<:Tuple{Vararg{Tuple}}}} <: AbstractSearchSpace + constraint::Function + axes::NT +end + +CartesianSpace(axes::NamedTuple) = CartesianSpace(Returns(true), axes) +CartesianSpace(; axes...) = CartesianSpace(NamedTuple(axes)) +CartesianSpace(constraint::Function; axes...) = CartesianSpace(constraint, NamedTuple(axes)) + +function Base.iterate(space::CartesianSpace{names}, state=nothing) where names + to_cfg = vals -> NamedTuple{names}(vals) + inner = state === nothing ? + Iterators.filter(space.constraint ∘ to_cfg, + Iterators.product(map(Tuple, values(space.axes))...)) : + state.inner + result = isnothing(state) ? iterate(inner) : iterate(inner, state.cursor) + isnothing(result) && return nothing + vals, cursor = result + cfg = to_cfg(vals) + return cfg, (; inner, cursor) +end + + +function _withconfig_rewrite(ex, cfg) + if ex isa Expr + if ex.head === :$ + key = ex.args[1] + key isa Symbol || throw(ArgumentError( + "@withconfig placeholders must be symbols like `\$tile`")) + return :($(cfg).$(key)) + elseif ex.head === :quote + return ex + end + return Expr(ex.head, map(arg -> _withconfig_rewrite(arg, cfg), ex.args)...) + end + return ex +end + +macro withconfig(ex) + cfg = gensym(:cfg) + rewritten = _withconfig_rewrite(ex, cfg) + return esc(:($cfg -> $rewritten)) +end + +end diff --git a/src/cuTile.jl b/src/cuTile.jl index d4ae69ec..e1f1958b 100644 --- a/src/cuTile.jl +++ b/src/cuTile.jl @@ -41,4 +41,6 @@ include("language/atomics.jl") public launch launch(args...) = error("Please import CUDA.jl before using `cuTile.launch`.") +include("Experimental.jl") + end # module cuTile diff --git a/test/execution/autotune.jl b/test/execution/autotune.jl new file mode 100644 index 00000000..7450215c --- /dev/null +++ b/test/execution/autotune.jl @@ -0,0 +1,245 @@ +using CUDA + +const Exp = ct.Experimental + +@testset "Autotune" begin + @testset "@withconfig" begin + grid_fn = Exp.@withconfig (cld(1024, $tile), 17) + @test grid_fn((; tile=64)) == (16, 17) + + args_fn = Exp.@withconfig (ct.Constant($tile), $occ, 42) + args = args_fn((; tile=32, occ=2)) + @test args[1] isa ct.Constant + @test args[1][] == 32 + @test args[2] == 2 + @test args[3] == 42 + end + + function vadd_kernel(a::ct.TileArray{Float32,1}, + b::ct.TileArray{Float32,1}, + c::ct.TileArray{Float32,1}, + tile::Int) + pid = ct.bid(1) + ta = ct.load(a, pid, (tile[],)) + tb = ct.load(b, pid, (tile[],)) + ct.store(c, pid, ta + tb) + return nothing + end + + function inplace_add_kernel(x::ct.TileArray{Float32,1}, + tile::Int) + pid = ct.bid(1) + tx = ct.load(x, pid, (tile[],)) + ct.store(x, pid, tx .+ 1f0) + return nothing + end + + n = 512 + a = CUDA.fill(1f0, n) + b = CUDA.fill(2f0, n) + c = CUDA.zeros(Float32, n) + + configs = [ + (; tile=16, occupancy=nothing, num_ctas=nothing), + (; tile=32, occupancy=2, num_ctas=nothing), + (; tile=64, occupancy=4, num_ctas=2), + ] + args_fn = cfg -> (a, b, c, ct.Constant(cfg.tile)) + grid_fn = cfg -> cld(n, cfg.tile) + + @testset "basic tuning" begin + Exp.clear_autotune_cache() + result = Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key=(:basic, n), + tuning=(preset=:fast, refine_topk=0), + ) + @test !result.cache_hit + @test result.tuned_config in configs + @test !isempty(result.tuning_record) + @test Array(c) ≈ fill(3f0, n) + end + + @testset "cache hit" begin + fill!(c, 0f0) + result = Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key=(:basic, n), + tuning=(preset=:fast, refine_topk=0), + ) + @test result.cache_hit + @test Array(c) ≈ fill(3f0, n) + end + + @testset "force retune" begin + fill!(c, 0f0) + result = Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key=(:basic, n), + tuning=(preset=:fast, refine_topk=0, force=true), + ) + @test !result.cache_hit + @test Array(c) ≈ fill(3f0, n) + end + + @testset "CartesianSpace" begin + Exp.clear_autotune_cache() + fill!(c, 0f0) + space = Exp.CartesianSpace(; + tile=(16, 32), occupancy=(nothing, 2), num_ctas=(nothing,)) + result = Exp.autotune_launch( + vadd_kernel, space, grid_fn, args_fn; + key=(:cartesian, n), + tuning=(preset=:fast, refine_topk=0), + ) + @test hasproperty(result.tuned_config, :tile) + @test hasproperty(result.tuned_config, :occupancy) + @test Array(c) ≈ fill(3f0, n) + end + + @testset "CartesianSpace with constraint" begin + Exp.clear_autotune_cache() + fill!(c, 0f0) + space = Exp.CartesianSpace( + cfg -> cfg.tile == 16; + tile=(16, 32, 64), occupancy=(nothing,), num_ctas=(nothing,)) + result = Exp.autotune_launch( + vadd_kernel, space, grid_fn, args_fn; + key=(:constrained, n), + tuning=(preset=:fast, refine_topk=0), + ) + @test result.tuned_config.tile == 16 + @test Array(c) ≈ fill(3f0, n) + end + + @testset "NamedTuple convenience → CartesianSpace" begin + Exp.clear_autotune_cache() + fill!(c, 0f0) + result = Exp.autotune_launch( + vadd_kernel, + (tile=(16, 32), occupancy=(nothing,), num_ctas=(nothing,)), + grid_fn, args_fn; + key=(:nt_convenience, n), + tuning=(preset=:fast, refine_topk=0), + ) + @test result.tuned_config.tile in (16, 32) + @test Array(c) ≈ fill(3f0, n) + end + + @testset "launch_args_fn (inplace kernel)" begin + x = CUDA.zeros(Float32, n) + original_x = Array(x) + Exp.clear_autotune_cache() + result = Exp.autotune_launch( + inplace_add_kernel, + [(; tile=16), (; tile=32)], + grid_fn, + cfg -> (copy(x), ct.Constant(cfg.tile)); + launch_args_fn=cfg -> (x, ct.Constant(cfg.tile)), + key=(:inplace, n), + tuning=(preset=:fast, refine_topk=0), + ) + @test !result.cache_hit + @test Array(x) == original_x .+ 1f0 + end + + @testset "refinement" begin + Exp.clear_autotune_cache() + fill!(c, 0f0) + result = Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key=(:refine, n), + tuning=(warmup=1, reps=2, refine_topk=2, refine_reps=4), + ) + @test !result.cache_hit + # Refinement record replaces initial — has at most refine_topk entries + @test length(result.tuning_record) <= 2 + @test Array(c) ≈ fill(3f0, n) + end + + @testset "verify" begin + Exp.clear_autotune_cache() + fill!(c, 0f0) + verify_called = Ref(false) + result = Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key=(:verify, n), + tuning=(preset=:fast, refine_topk=0), + verify=() -> let + ref = Array(a) .+ Array(b) + verify_called[] = true + () -> (CUDA.@allowscalar all(isapprox.(Array(c), ref, atol=1f-5))) + end, + ) + @test verify_called[] + @test Array(c) ≈ fill(3f0, n) + end + + @testset "clear cache per-kernel per-key" begin + Exp.clear_autotune_cache() + Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key=(:k1, n), tuning=(preset=:fast, refine_topk=0)) + Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key=(:k2, n), tuning=(preset=:fast, refine_topk=0)) + + # Clear only one key + Exp.clear_autotune_cache(kernel=vadd_kernel, key=(:k1, n)) + fill!(c, 0f0) + r1 = Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key=(:k1, n), tuning=(preset=:fast, refine_topk=0)) + @test !r1.cache_hit # was cleared + + fill!(c, 0f0) + r2 = Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key=(:k2, n), tuning=(preset=:fast, refine_topk=0)) + @test r2.cache_hit # still cached + end + + @testset "shared key across shapes" begin + Exp.clear_autotune_cache() + n2 = 1024 + a2 = CUDA.fill(1f0, n2) + b2 = CUDA.fill(2f0, n2) + c2 = CUDA.zeros(Float32, n2) + shared_key = (:shape_agnostic, eltype(a)) + + Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key=shared_key, tuning=(preset=:fast, refine_topk=0)) + + fill!(c2, 0f0) + result = Exp.autotune_launch( + vadd_kernel, configs, + cfg -> cld(n2, cfg.tile), + cfg -> (a2, b2, c2, ct.Constant(cfg.tile)); + key=shared_key, tuning=(preset=:fast, refine_topk=0)) + @test result.cache_hit + @test result.grid == cld(n2, result.tuned_config.tile) + @test Array(c2) ≈ fill(3f0, n2) + end + + @testset "key_fn" begin + Exp.clear_autotune_cache() + call_count = Ref(0) + my_key_fn = () -> begin + call_count[] += 1 + return (:dynamic, Float32) + end + + fill!(c, 0f0) + r1 = Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key_fn=my_key_fn, tuning=(preset=:fast, refine_topk=0)) + r2 = Exp.autotune_launch( + vadd_kernel, configs, grid_fn, args_fn; + key_fn=my_key_fn, tuning=(preset=:fast, refine_topk=0)) + @test !r1.cache_hit + @test r2.cache_hit + @test call_count[] == 2 + @test Array(c) ≈ fill(3f0, n) + end +end From 71afeb6297bd9dd187d70057798ccf7ad4485c9f Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Fri, 20 Feb 2026 23:31:40 +0100 Subject: [PATCH 2/9] adjust tuning presets --- ext/autotune/autotune.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ext/autotune/autotune.jl b/ext/autotune/autotune.jl index a73747eb..f52e297f 100644 --- a/ext/autotune/autotune.jl +++ b/ext/autotune/autotune.jl @@ -11,9 +11,9 @@ struct VerificationError <: Exception end const TUNING_PRESETS = ( - fast = (warmup=1, reps=3, refine_topk=0, refine_reps=0), - default = (warmup=2, reps=6, refine_topk=2, refine_reps=8), - thorough = (warmup=2, reps=9, refine_topk=4, refine_reps=16), + fast = (warmup=1, reps=3, refine_topk=0, refine_reps=2), + default = (warmup=2, reps=5, refine_topk=2, refine_reps=4), + thorough = (warmup=2, reps=7, refine_topk=4, refine_reps=6), ) function normalize_tuning(tuning::NamedTuple) From c22e19f0a2d864a822dc7a63ed6ea619c5288f4d Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Fri, 20 Feb 2026 23:31:57 +0100 Subject: [PATCH 3/9] add EMIT_CODE_LOCK --- ext/CUDAExt.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ext/CUDAExt.jl b/ext/CUDAExt.jl index 6ddf303f..629b63a9 100644 --- a/ext/CUDAExt.jl +++ b/ext/CUDAExt.jl @@ -13,6 +13,8 @@ using CUDA_Compiler_jll public launch +const EMIT_CODE_LOCK = ReentrantLock() + """ emit_binary(cache, mi; const_argtypes=nothing) -> Vector{UInt8} @@ -20,7 +22,9 @@ Binary phase: compile Tile IR bytecode to CUBIN using tileiras. """ function emit_binary(cache::CacheView, mi::Core.MethodInstance; const_argtypes::Union{Vector{Any}, Nothing}=nothing) - bytecode = emit_code(cache, mi; const_argtypes) + bytecode = lock(EMIT_CODE_LOCK) do + emit_code(cache, mi; const_argtypes) + end ci = get(cache, mi) res = const_argtypes !== nothing ? results(cache, ci, const_argtypes) : results(cache, ci) From 3f335dfe05f9b684664c211f205cace999725eb0 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Tue, 24 Feb 2026 11:42:47 +0100 Subject: [PATCH 4/9] shared inference cache --- ext/autotune/autotune.jl | 4 +++- src/compiler/interface.jl | 8 +++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/ext/autotune/autotune.jl b/ext/autotune/autotune.jl index f52e297f..f864cc13 100644 --- a/ext/autotune/autotune.jl +++ b/ext/autotune/autotune.jl @@ -170,8 +170,10 @@ function find_or_tune(@nospecialize(f), space::AbstractSearchSpace, rng::Abstrac trials = collect(space) trials = Any[trials...] - trials, precompile_error = precompile_candidates(f, trials, grid_fn, args_fn; + trials, precompile_error = Base.ScopedValues.with(cuTile._SCOPED_INF_CACHE => CC.InferenceResult[]) do + precompile_candidates(f, trials, grid_fn, args_fn; sm_arch, opt_level, workers=tuning.precompile_workers) + end record, first_error = measure_candidates(f, trials, grid_fn, args_fn; sm_arch, opt_level, warmup=tuning.warmup, reps=tuning.reps, verify=checker) diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index 1f0dee93..d88f0fc7 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -35,9 +35,15 @@ struct cuTileInterpreter <: CC.AbstractInterpreter opt_params::CC.OptimizationParams end +# Scoped inference cache: reuse callee inference results (e.g. kwarg sorters) +# across interpreter instances within a compilation scope. Set by autotuning +# to share callee inference across configs; unset falls back to fresh cache. +using Base.ScopedValues: ScopedValue, with +const _SCOPED_INF_CACHE = ScopedValue{Vector{CC.InferenceResult}}() + function cuTileInterpreter(cache::CacheView; always_inline::Bool=true) method_table = get_method_table_view(cache.world) - inf_cache = Vector{CC.InferenceResult}() + inf_cache = isassigned(_SCOPED_INF_CACHE) ? _SCOPED_INF_CACHE[] : Vector{CC.InferenceResult}() inf_params = CC.InferenceParams() opt_params = if always_inline CC.OptimizationParams(; inline_cost_threshold=typemax(Int)) From 655b59340e856983a563d103d3b668fa0fa3957d Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Tue, 24 Feb 2026 11:43:11 +0100 Subject: [PATCH 5/9] handle interrupts while autotuning --- ext/autotune/autotune.jl | 36 +++++++++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/ext/autotune/autotune.jl b/ext/autotune/autotune.jl index f864cc13..b6696a29 100644 --- a/ext/autotune/autotune.jl +++ b/ext/autotune/autotune.jl @@ -111,16 +111,30 @@ function precompile_candidates(@nospecialize(f), configs::Vector{Any}, compiled = fill(true, length(configs)) errors = Vector{Any}(nothing, length(configs)) sem = Base.Semaphore(workers) - - @sync for (i, cfg) in enumerate(configs) - Threads.@spawn Base.acquire(sem) do - try - precompile_cfg(f, cfg, grid_fn, args_fn; sm_arch, opt_level) - catch err - compiled[i] = false - errors[i] = (cfg, err) + cancelled = Threads.Atomic{Bool}(false) + + try + @sync for (i, cfg) in enumerate(configs) + Threads.@spawn begin + cancelled[] && return + Base.acquire(sem) do + cancelled[] && return + try + precompile_cfg(f, cfg, grid_fn, args_fn; sm_arch, opt_level) + catch err + compiled[i] = false + errors[i] = (cfg, err) + end + end end end + catch e + cancelled[] = true + e isa InterruptException || rethrow() + @warn "Precompilation interrupted, waiting for in-flight workers…" + # @sync already waits for spawned tasks before propagating, + # but the atomic flag ensures queued ones exit early. + rethrow() end first_err = nothing @@ -144,6 +158,10 @@ function measure_candidates(@nospecialize(f), configs::Vector{Any}, ms = try eval_cfg(f, cfg, grid_fn, args_fn; sm_arch, opt_level, warmup, reps, verify) catch err + if err isa InterruptException + @warn "Benchmarking interrupted after $(length(record)) configs" + break + end err isa VerificationError && @warn "Config $cfg failed verification, skipping" first_error === nothing && (first_error = (cfg, err)) continue @@ -172,7 +190,7 @@ function find_or_tune(@nospecialize(f), space::AbstractSearchSpace, rng::Abstrac trials = Any[trials...] trials, precompile_error = Base.ScopedValues.with(cuTile._SCOPED_INF_CACHE => CC.InferenceResult[]) do precompile_candidates(f, trials, grid_fn, args_fn; - sm_arch, opt_level, workers=tuning.precompile_workers) + sm_arch, opt_level, workers=tuning.precompile_workers) end record, first_error = measure_candidates(f, trials, grid_fn, args_fn; From 96edb5c763dcc970f8b1e0b1ab78769c52f5c4a6 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Tue, 24 Feb 2026 11:44:11 +0100 Subject: [PATCH 6/9] remove withconfig macro --- src/Experimental.jl | 22 ---------------------- test/execution/autotune.jl | 11 ----------- 2 files changed, 33 deletions(-) diff --git a/src/Experimental.jl b/src/Experimental.jl index 219b9612..f6d53546 100644 --- a/src/Experimental.jl +++ b/src/Experimental.jl @@ -37,26 +37,4 @@ function Base.iterate(space::CartesianSpace{names}, state=nothing) where names return cfg, (; inner, cursor) end - -function _withconfig_rewrite(ex, cfg) - if ex isa Expr - if ex.head === :$ - key = ex.args[1] - key isa Symbol || throw(ArgumentError( - "@withconfig placeholders must be symbols like `\$tile`")) - return :($(cfg).$(key)) - elseif ex.head === :quote - return ex - end - return Expr(ex.head, map(arg -> _withconfig_rewrite(arg, cfg), ex.args)...) - end - return ex -end - -macro withconfig(ex) - cfg = gensym(:cfg) - rewritten = _withconfig_rewrite(ex, cfg) - return esc(:($cfg -> $rewritten)) -end - end diff --git a/test/execution/autotune.jl b/test/execution/autotune.jl index 7450215c..d6e705e9 100644 --- a/test/execution/autotune.jl +++ b/test/execution/autotune.jl @@ -3,17 +3,6 @@ using CUDA const Exp = ct.Experimental @testset "Autotune" begin - @testset "@withconfig" begin - grid_fn = Exp.@withconfig (cld(1024, $tile), 17) - @test grid_fn((; tile=64)) == (16, 17) - - args_fn = Exp.@withconfig (ct.Constant($tile), $occ, 42) - args = args_fn((; tile=32, occ=2)) - @test args[1] isa ct.Constant - @test args[1][] == 32 - @test args[2] == 2 - @test args[3] == 42 - end function vadd_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,1}, From cab015ffb26544a21293ce58a499cdef069257d1 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Fri, 13 Mar 2026 21:41:55 +0000 Subject: [PATCH 7/9] fix --- ext/autotune/autotune.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ext/autotune/autotune.jl b/ext/autotune/autotune.jl index b6696a29..63e79a79 100644 --- a/ext/autotune/autotune.jl +++ b/ext/autotune/autotune.jl @@ -96,7 +96,9 @@ function precompile_cfg(@nospecialize(f), cfg, grid_fn::Function, args_fn::Funct end hints = hints_from_cfg(cfg) - opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=hints.num_ctas, occupancy=hints.occupancy) + bytecode_version = check_tile_ir_support() + opts = (sm_arch=sm_arch, opt_level=opt_level, num_ctas=hints.num_ctas, occupancy=hints.occupancy, + bytecode_version=bytecode_version) cache = CacheView{CuTileResults}((:cuTile, opts), world) emit_function(cache, mi; const_argtypes) end From d651181f0db45b8400fc3163eaabcb35a54f037a Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Sat, 14 Mar 2026 22:20:25 +0000 Subject: [PATCH 8/9] add setup argument --- ext/autotune/autotune.jl | 36 ++++++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/ext/autotune/autotune.jl b/ext/autotune/autotune.jl index 63e79a79..653a7d2f 100644 --- a/ext/autotune/autotune.jl +++ b/ext/autotune/autotune.jl @@ -38,9 +38,11 @@ function hints_from_cfg(cfg) end function time_ms(run_once::Function, get_args::Function; - warmup::Int, reps::Int, verify::Union{Nothing, Function}=nothing) + warmup::Int, reps::Int, verify::Union{Nothing, Function}=nothing, + reset::Union{Nothing, Function}=nothing) CUDA.synchronize() for _ in 1:max(warmup, verify !== nothing ? 1 : 0) + reset !== nothing && reset() run_once(get_args()) end @@ -51,6 +53,7 @@ function time_ms(run_once::Function, get_args::Function; best_ms = Inf32 for _ in 1:reps + reset !== nothing && reset() args = get_args() CUDA.synchronize() elapsed_s = CUDA.@elapsed run_once(args) @@ -62,10 +65,11 @@ end function eval_cfg(@nospecialize(f), cfg, grid_fn::Function, args_fn::Function; sm_arch::String, opt_level::Int, warmup::Int, reps::Int, - verify::Union{Nothing, Function}=nothing) + verify::Union{Nothing, Function}=nothing, + reset::Union{Nothing, Function}=nothing) run_once = args -> cuTile.launch(f, grid_fn(cfg), args...; sm_arch, opt_level, hints_from_cfg(cfg)...) - return time_ms(run_once, () -> args_fn(cfg); warmup, reps, verify) + return time_ms(run_once, () -> args_fn(cfg); warmup, reps, verify, reset) end function precompile_cfg(@nospecialize(f), cfg, grid_fn::Function, args_fn::Function; @@ -153,12 +157,13 @@ end function measure_candidates(@nospecialize(f), configs::Vector{Any}, grid_fn::Function, args_fn::Function; sm_arch::String, opt_level::Int, warmup::Int, reps::Int, - verify::Union{Nothing, Function}=nothing) + verify::Union{Nothing, Function}=nothing, + reset::Union{Nothing, Function}=nothing) record = Tuple{Any, Float32}[] first_error = nothing for cfg in configs ms = try - eval_cfg(f, cfg, grid_fn, args_fn; sm_arch, opt_level, warmup, reps, verify) + eval_cfg(f, cfg, grid_fn, args_fn; sm_arch, opt_level, warmup, reps, verify, reset) catch err if err isa InterruptException @warn "Benchmarking interrupted after $(length(record)) configs" @@ -176,16 +181,18 @@ end function find_or_tune(@nospecialize(f), space::AbstractSearchSpace, rng::AbstractRNG, grid_fn::Function, args_fn::Function, tuning; sm_arch::String, opt_level::Int, kernel_key, arg_key, - verify::Union{Nothing, Function}=nothing) + verify::Union{Nothing, Function}=nothing, + setup::Union{Nothing, Function}=nothing) if !tuning.force entry = lock(AUTOTUNE_LOCK) do per_kernel = get(AUTOTUNE_CACHE, kernel_key, nothing) per_kernel !== nothing ? get(per_kernel, arg_key, nothing) : nothing end - entry !== nothing && return entry, true + entry !== nothing && return entry, true, nothing end checker = verify !== nothing ? verify() : nothing + reset = setup !== nothing ? setup() : nothing trials = collect(space) @@ -196,7 +203,7 @@ function find_or_tune(@nospecialize(f), space::AbstractSearchSpace, rng::Abstrac end record, first_error = measure_candidates(f, trials, grid_fn, args_fn; - sm_arch, opt_level, warmup=tuning.warmup, reps=tuning.reps, verify=checker) + sm_arch, opt_level, warmup=tuning.warmup, reps=tuning.reps, verify=checker, reset) if isempty(record) # Prefer showing the precompile error (more informative) over the benchmark error @@ -215,7 +222,7 @@ function find_or_tune(@nospecialize(f), space::AbstractSearchSpace, rng::Abstrac sort!(record, by=last) top_configs = Any[first(r) for r in record[1:min(tuning.refine_topk, length(record))]] refined, _ = measure_candidates(f, top_configs, grid_fn, args_fn; - sm_arch, opt_level, warmup=tuning.warmup, reps=tuning.refine_reps) + sm_arch, opt_level, warmup=tuning.warmup, reps=tuning.refine_reps, reset) if !isempty(refined) record = refined end @@ -224,7 +231,7 @@ function find_or_tune(@nospecialize(f), space::AbstractSearchSpace, rng::Abstrac _, best_idx = findmin(last, record) candidate = (; best_config=record[best_idx][1], tuning_record=record) - return lock(AUTOTUNE_LOCK) do + entry, cache_hit = lock(AUTOTUNE_LOCK) do per_kernel = get!(Dict{Any,Any}, AUTOTUNE_CACHE, kernel_key) if !tuning.force && haskey(per_kernel, arg_key) per_kernel[arg_key], true @@ -233,6 +240,7 @@ function find_or_tune(@nospecialize(f), space::AbstractSearchSpace, rng::Abstrac candidate, false end end + return entry, cache_hit, reset end function autotune_launch(@nospecialize(f), space::AbstractSearchSpace, @@ -241,6 +249,7 @@ function autotune_launch(@nospecialize(f), space::AbstractSearchSpace, key_fn::Union{Nothing, Function}=nothing, launch_args_fn::Union{Nothing, Function}=nothing, verify::Union{Nothing, Function}=nothing, + setup::Union{Nothing, Function}=nothing, tuning::NamedTuple=NamedTuple(), sm_arch::String=default_sm_arch(), opt_level::Int=3) @@ -250,13 +259,16 @@ function autotune_launch(@nospecialize(f), space::AbstractSearchSpace, kernel_key = (f, sm_arch, opt_level) arg_key = key !== nothing ? key : (key_fn !== nothing ? key_fn() : nothing) - entry, cache_hit = find_or_tune(f, space, rng, grid_fn, args_fn, tuning; - sm_arch, opt_level, kernel_key, arg_key, verify) + entry, cache_hit, reset = find_or_tune(f, space, rng, grid_fn, args_fn, tuning; + sm_arch, opt_level, kernel_key, arg_key, verify, setup) cfg = entry.best_config grid = grid_fn(cfg) args = launch_args_fn !== nothing ? launch_args_fn(cfg) : args_fn(cfg) + # Reset state before the final "real" launch + reset !== nothing && reset() + cuTile.launch(f, grid, args...; sm_arch, opt_level, hints_from_cfg(cfg)...) return (; tuned_config=cfg, grid, tuning_record=copy(entry.tuning_record), cache_hit) From a1b02adc86a49747c564dd6b953a21ddb90fc2e6 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Sat, 21 Mar 2026 11:35:32 +0000 Subject: [PATCH 9/9] fixes --- ext/autotune/autotune.jl | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/ext/autotune/autotune.jl b/ext/autotune/autotune.jl index 653a7d2f..6e0d171c 100644 --- a/ext/autotune/autotune.jl +++ b/ext/autotune/autotune.jl @@ -1,6 +1,8 @@ import cuTile.Experimental: autotune_launch, clear_autotune_cache using cuTile.Experimental: AbstractSearchSpace, CartesianSpace, FixedSpace +using CUDA: CUDA + using Random const AUTOTUNE_LOCK = ReentrantLock() @@ -64,7 +66,7 @@ function time_ms(run_once::Function, get_args::Function; end function eval_cfg(@nospecialize(f), cfg, grid_fn::Function, args_fn::Function; - sm_arch::String, opt_level::Int, warmup::Int, reps::Int, + sm_arch::VersionNumber, opt_level::Int, warmup::Int, reps::Int, verify::Union{Nothing, Function}=nothing, reset::Union{Nothing, Function}=nothing) run_once = args -> cuTile.launch(f, grid_fn(cfg), args...; @@ -73,7 +75,7 @@ function eval_cfg(@nospecialize(f), cfg, grid_fn::Function, args_fn::Function; end function precompile_cfg(@nospecialize(f), cfg, grid_fn::Function, args_fn::Function; - sm_arch::String, opt_level::Int) + sm_arch::VersionNumber, opt_level::Int) grid_fn(cfg) args = args_fn(cfg) tile_args = map(to_tile_arg, args) @@ -109,7 +111,7 @@ end function precompile_candidates(@nospecialize(f), configs::Vector{Any}, grid_fn::Function, args_fn::Function; - sm_arch::String, opt_level::Int, workers::Int) + sm_arch::VersionNumber, opt_level::Int, workers::Int) isempty(configs) && return configs, nothing iszero(workers) && return configs, nothing @@ -156,7 +158,7 @@ end function measure_candidates(@nospecialize(f), configs::Vector{Any}, grid_fn::Function, args_fn::Function; - sm_arch::String, opt_level::Int, warmup::Int, reps::Int, + sm_arch::VersionNumber, opt_level::Int, warmup::Int, reps::Int, verify::Union{Nothing, Function}=nothing, reset::Union{Nothing, Function}=nothing) record = Tuple{Any, Float32}[] @@ -180,7 +182,7 @@ end function find_or_tune(@nospecialize(f), space::AbstractSearchSpace, rng::AbstractRNG, grid_fn::Function, args_fn::Function, tuning; - sm_arch::String, opt_level::Int, kernel_key, arg_key, + sm_arch::VersionNumber, opt_level::Int, kernel_key, arg_key, verify::Union{Nothing, Function}=nothing, setup::Union{Nothing, Function}=nothing) if !tuning.force @@ -251,7 +253,7 @@ function autotune_launch(@nospecialize(f), space::AbstractSearchSpace, verify::Union{Nothing, Function}=nothing, setup::Union{Nothing, Function}=nothing, tuning::NamedTuple=NamedTuple(), - sm_arch::String=default_sm_arch(), + sm_arch::VersionNumber=default_sm_arch(), opt_level::Int=3) tuning = normalize_tuning(tuning) rng = tuning.seed !== nothing ? MersenneTwister(tuning.seed) : Random.default_rng()