Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 76 additions & 2 deletions src/FastAlmostBandedMatrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,79 @@ import MatrixFactorizations: QR, QRPackedQ, getQ, getR, QRPackedQLayout, AdjQRPa

@reexport using BandedMatrices

# ------------------
# DisjointRange - for zero-allocation colsupport
# ------------------

"""
DisjointRange{T}

A lazy representation of the union of two ranges, supporting iteration and indexing
without heap allocation.
"""
struct DisjointRange{T<:Integer, R1<:AbstractUnitRange{T}, R2<:AbstractUnitRange{T}} <:
AbstractVector{T}
r1::R1
r2::R2
end

Base.size(d::DisjointRange) = (length(d.r1) + length(d.r2),)
Base.length(d::DisjointRange) = length(d.r1) + length(d.r2)

@inline function Base.getindex(d::DisjointRange, i::Integer)
@boundscheck checkbounds(d, i)
n1 = length(d.r1)
if i <= n1
return @inbounds d.r1[i]
else
return @inbounds d.r2[i - n1]
end
end

Base.IndexStyle(::Type{<:DisjointRange}) = IndexLinear()

@inline function Base.iterate(d::DisjointRange)
if !isempty(d.r1)
val, state = iterate(d.r1)
return val, (1, state)
elseif !isempty(d.r2)
val, state = iterate(d.r2)
return val, (2, state)
else
return nothing
end
end

@inline function Base.iterate(d::DisjointRange, state)
which, inner_state = state
if which == 1
next = iterate(d.r1, inner_state)
if next !== nothing
return next[1], (1, next[2])
else
# Switch to r2
if !isempty(d.r2)
val, new_state = iterate(d.r2)
return val, (2, new_state)
else
return nothing
end
end
else
next = iterate(d.r2, inner_state)
if next !== nothing
return next[1], (2, next[2])
else
return nothing
end
end
end

Base.first(d::DisjointRange) = isempty(d.r1) ? first(d.r2) : first(d.r1)
Base.last(d::DisjointRange) = isempty(d.r2) ? last(d.r1) : last(d.r2)
Base.minimum(d::DisjointRange) = min(minimum(d.r1), minimum(d.r2))
Base.maximum(d::DisjointRange) = max(maximum(d.r1), maximum(d.r2))

# ------------------
# AlmostBandedMatrix
# ------------------
Expand Down Expand Up @@ -138,7 +211,7 @@ end
function Base.fill!(A::AlmostBandedMatrix, v)
fill!(bandpart(A), v)
fill!(fillpart(A), v)
return nothing
return A
end

@inline function colsupport(::AbstractAlmostBandedLayout, A, j)
Expand All @@ -151,7 +224,8 @@ end
if isempty(sup)
return Base.OneTo(r)
else
return vcat(Base.OneTo(min(r, minimum(sup) - 1)), sup)
# Use DisjointRange to avoid heap allocation from vcat
return DisjointRange(Base.OneTo(min(r, minimum(sup) - 1)), sup)
end
end
end
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
[deps]
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MatrixFactorizations = "a3b82374-2e81-5b9e-98ce-41277c0e4c87"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
134 changes: 134 additions & 0 deletions test/alloc_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
using AllocCheck
using BenchmarkTools
using FastAlmostBandedMatrices
using FastAlmostBandedMatrices: DisjointRange
using ArrayLayouts: colsupport, rowsupport
using Test

@testset "Allocation Tests" begin
@testset "DisjointRange - Zero Allocations" begin
# Test that DisjointRange operations don't allocate
r1 = Base.OneTo(5)
r2 = 10:15
dr = DisjointRange(r1, r2)

# Test length
allocs = @allocated length(dr)
@test allocs == 0

# Test getindex
allocs = @allocated dr[3]
@test allocs == 0

allocs = @allocated dr[8]
@test allocs == 0

# Test first/last
allocs = @allocated first(dr)
@test allocs == 0

allocs = @allocated last(dr)
@test allocs == 0

# Test iteration (after warmup)
sum_test = 0
for x in dr
sum_test += x
end
allocs = @allocated begin
s = 0
for x in dr
s += x
end
s
end
@test allocs == 0
end

@testset "colsupport - Zero Allocations" begin
n = 100
m = 2
B = brand(Float64, n, n, m + 1, m)
F = rand(Float64, m, n)
A = AlmostBandedMatrix(B, F)

# Warmup
colsupport(A, 5)
colsupport(A, 50)

# Test colsupport for j <= l+u (should return OneTo, no allocation)
allocs = @allocated colsupport(A, 5)
@test allocs == 0

# Test colsupport for j > l+u (now returns DisjointRange instead of vcat)
allocs = @allocated colsupport(A, 50)
@test allocs == 0
end

@testset "rowsupport - Zero Allocations" begin
n = 100
m = 2
B = brand(Float64, n, n, m + 1, m)
F = rand(Float64, m, n)
A = AlmostBandedMatrix(B, F)

# Warmup
rowsupport(A, 1)
rowsupport(A, 50)

# Test rowsupport (always returns UnitRange, no allocation)
allocs = @allocated rowsupport(A, 1)
@test allocs == 0

allocs = @allocated rowsupport(A, 50)
@test allocs == 0
end

@testset "getindex/setindex! - Zero Allocations" begin
n = 100
m = 2
B = brand(Float64, n, n, m + 1, m)
F = rand(Float64, m, n)
A = AlmostBandedMatrix(B, F)

# Warmup
_ = A[50, 50]
A[50, 50] = 1.0

# Test getindex
allocs = @allocated A[50, 50]
@test allocs == 0

# Test setindex! in band part
allocs = @allocated A[50, 50] = 2.0
@test allocs == 0

# Test setindex! in fill part
allocs = @allocated A[1, 50] = 3.0
@test allocs == 0

# Test setindex! in overlapping part
allocs = @allocated A[1, 1] = 4.0
@test allocs == 0
end

@testset "bandpart/fillpart - Zero Allocations" begin
n = 100
m = 2
B = brand(Float64, n, n, m + 1, m)
F = rand(Float64, m, n)
A = AlmostBandedMatrix(B, F)

# Warmup
bandpart(A)
fillpart(A)

# Test bandpart
allocs = @allocated bandpart(A)
@test allocs == 0

# Test fillpart
allocs = @allocated fillpart(A)
@test allocs == 0
end
end
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,8 @@ using SafeTestsets, Test
@test length(A1.fill.nzval) == 2
end
end

# Allocation tests run separately to avoid precompilation interference
if get(ENV, "GROUP", "all") == "all" || get(ENV, "GROUP", "all") == "nopre"
include("alloc_tests.jl")
end
Loading