diff --git a/src/FastAlmostBandedMatrices.jl b/src/FastAlmostBandedMatrices.jl index 29c56e8..b95c4c6 100644 --- a/src/FastAlmostBandedMatrices.jl +++ b/src/FastAlmostBandedMatrices.jl @@ -14,6 +14,79 @@ import MatrixFactorizations: QR, QRPackedQ, getQ, getR, QRPackedQLayout, AdjQRPa @reexport using BandedMatrices +# ------------------ +# DisjointRange - for zero-allocation colsupport +# ------------------ + +""" + DisjointRange{T} + +A lazy representation of the union of two ranges, supporting iteration and indexing +without heap allocation. +""" +struct DisjointRange{T<:Integer, R1<:AbstractUnitRange{T}, R2<:AbstractUnitRange{T}} <: + AbstractVector{T} + r1::R1 + r2::R2 +end + +Base.size(d::DisjointRange) = (length(d.r1) + length(d.r2),) +Base.length(d::DisjointRange) = length(d.r1) + length(d.r2) + +@inline function Base.getindex(d::DisjointRange, i::Integer) + @boundscheck checkbounds(d, i) + n1 = length(d.r1) + if i <= n1 + return @inbounds d.r1[i] + else + return @inbounds d.r2[i - n1] + end +end + +Base.IndexStyle(::Type{<:DisjointRange}) = IndexLinear() + +@inline function Base.iterate(d::DisjointRange) + if !isempty(d.r1) + val, state = iterate(d.r1) + return val, (1, state) + elseif !isempty(d.r2) + val, state = iterate(d.r2) + return val, (2, state) + else + return nothing + end +end + +@inline function Base.iterate(d::DisjointRange, state) + which, inner_state = state + if which == 1 + next = iterate(d.r1, inner_state) + if next !== nothing + return next[1], (1, next[2]) + else + # Switch to r2 + if !isempty(d.r2) + val, new_state = iterate(d.r2) + return val, (2, new_state) + else + return nothing + end + end + else + next = iterate(d.r2, inner_state) + if next !== nothing + return next[1], (2, next[2]) + else + return nothing + end + end +end + +Base.first(d::DisjointRange) = isempty(d.r1) ? first(d.r2) : first(d.r1) +Base.last(d::DisjointRange) = isempty(d.r2) ? last(d.r1) : last(d.r2) +Base.minimum(d::DisjointRange) = min(minimum(d.r1), minimum(d.r2)) +Base.maximum(d::DisjointRange) = max(maximum(d.r1), maximum(d.r2)) + # ------------------ # AlmostBandedMatrix # ------------------ @@ -138,7 +211,7 @@ end function Base.fill!(A::AlmostBandedMatrix, v) fill!(bandpart(A), v) fill!(fillpart(A), v) - return nothing + return A end @inline function colsupport(::AbstractAlmostBandedLayout, A, j) @@ -151,7 +224,8 @@ end if isempty(sup) return Base.OneTo(r) else - return vcat(Base.OneTo(min(r, minimum(sup) - 1)), sup) + # Use DisjointRange to avoid heap allocation from vcat + return DisjointRange(Base.OneTo(min(r, minimum(sup) - 1)), sup) end end end diff --git a/test/Project.toml b/test/Project.toml index bb5b1cf..0c703ae 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,8 @@ [deps] +AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MatrixFactorizations = "a3b82374-2e81-5b9e-98ce-41277c0e4c87" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/test/alloc_tests.jl b/test/alloc_tests.jl new file mode 100644 index 0000000..d826b2c --- /dev/null +++ b/test/alloc_tests.jl @@ -0,0 +1,134 @@ +using AllocCheck +using BenchmarkTools +using FastAlmostBandedMatrices +using FastAlmostBandedMatrices: DisjointRange +using ArrayLayouts: colsupport, rowsupport +using Test + +@testset "Allocation Tests" begin + @testset "DisjointRange - Zero Allocations" begin + # Test that DisjointRange operations don't allocate + r1 = Base.OneTo(5) + r2 = 10:15 + dr = DisjointRange(r1, r2) + + # Test length + allocs = @allocated length(dr) + @test allocs == 0 + + # Test getindex + allocs = @allocated dr[3] + @test allocs == 0 + + allocs = @allocated dr[8] + @test allocs == 0 + + # Test first/last + allocs = @allocated first(dr) + @test allocs == 0 + + allocs = @allocated last(dr) + @test allocs == 0 + + # Test iteration (after warmup) + sum_test = 0 + for x in dr + sum_test += x + end + allocs = @allocated begin + s = 0 + for x in dr + s += x + end + s + end + @test allocs == 0 + end + + @testset "colsupport - Zero Allocations" begin + n = 100 + m = 2 + B = brand(Float64, n, n, m + 1, m) + F = rand(Float64, m, n) + A = AlmostBandedMatrix(B, F) + + # Warmup + colsupport(A, 5) + colsupport(A, 50) + + # Test colsupport for j <= l+u (should return OneTo, no allocation) + allocs = @allocated colsupport(A, 5) + @test allocs == 0 + + # Test colsupport for j > l+u (now returns DisjointRange instead of vcat) + allocs = @allocated colsupport(A, 50) + @test allocs == 0 + end + + @testset "rowsupport - Zero Allocations" begin + n = 100 + m = 2 + B = brand(Float64, n, n, m + 1, m) + F = rand(Float64, m, n) + A = AlmostBandedMatrix(B, F) + + # Warmup + rowsupport(A, 1) + rowsupport(A, 50) + + # Test rowsupport (always returns UnitRange, no allocation) + allocs = @allocated rowsupport(A, 1) + @test allocs == 0 + + allocs = @allocated rowsupport(A, 50) + @test allocs == 0 + end + + @testset "getindex/setindex! - Zero Allocations" begin + n = 100 + m = 2 + B = brand(Float64, n, n, m + 1, m) + F = rand(Float64, m, n) + A = AlmostBandedMatrix(B, F) + + # Warmup + _ = A[50, 50] + A[50, 50] = 1.0 + + # Test getindex + allocs = @allocated A[50, 50] + @test allocs == 0 + + # Test setindex! in band part + allocs = @allocated A[50, 50] = 2.0 + @test allocs == 0 + + # Test setindex! in fill part + allocs = @allocated A[1, 50] = 3.0 + @test allocs == 0 + + # Test setindex! in overlapping part + allocs = @allocated A[1, 1] = 4.0 + @test allocs == 0 + end + + @testset "bandpart/fillpart - Zero Allocations" begin + n = 100 + m = 2 + B = brand(Float64, n, n, m + 1, m) + F = rand(Float64, m, n) + A = AlmostBandedMatrix(B, F) + + # Warmup + bandpart(A) + fillpart(A) + + # Test bandpart + allocs = @allocated bandpart(A) + @test allocs == 0 + + # Test fillpart + allocs = @allocated fillpart(A) + @test allocs == 0 + end +end diff --git a/test/runtests.jl b/test/runtests.jl index f2941be..233ef38 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -107,3 +107,8 @@ using SafeTestsets, Test @test length(A1.fill.nzval) == 2 end end + +# Allocation tests run separately to avoid precompilation interference +if get(ENV, "GROUP", "all") == "all" || get(ENV, "GROUP", "all") == "nopre" + include("alloc_tests.jl") +end