From ab9e7776626c6df63435bff71b1ddbd3bdcaf255 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Wed, 25 Mar 2026 16:33:41 -0400 Subject: [PATCH 1/5] add TruncationUnion implementation --- Project.toml | 2 +- src/factorizations/factorizations.jl | 4 ++-- src/factorizations/truncation.jl | 16 ++++++++++++++++ test/tensors/factorizations.jl | 9 +++++++++ 4 files changed, 28 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 7e82876c5..b46d82c91 100644 --- a/Project.toml +++ b/Project.toml @@ -46,7 +46,7 @@ GPUArrays = "11.3.1" JET = "0.9, 0.10, 0.11" LRUCache = "1.0.2" LinearAlgebra = "1" -MatrixAlgebraKit = "0.6.4" +MatrixAlgebraKit = "0.6.5" Mooncake = "0.5" OhMyThreads = "0.8.0" Printf = "1" diff --git a/src/factorizations/factorizations.jl b/src/factorizations/factorizations.jl index 0c2272ac8..bac5cbcd9 100644 --- a/src/factorizations/factorizations.jl +++ b/src/factorizations/factorizations.jl @@ -19,8 +19,8 @@ using TensorOperations: Index2Tuple using MatrixAlgebraKit import MatrixAlgebraKit as MAK using MatrixAlgebraKit: AbstractAlgorithm, TruncatedAlgorithm, DiagonalAlgorithm -using MatrixAlgebraKit: TruncationStrategy, NoTruncation, TruncationByValue, - TruncationByError, TruncationIntersection, TruncationByFilter, TruncationByOrder +using MatrixAlgebraKit: TruncationStrategy, NoTruncation, TruncationByValue, TruncationByError, + TruncationIntersection, TruncationUnion, TruncationByFilter, TruncationByOrder using MatrixAlgebraKit: diagview include("utility.jl") diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index 007565fd0..c2b60c1e3 100644 --- a/src/factorizations/truncation.jl +++ b/src/factorizations/truncation.jl @@ -281,6 +281,22 @@ function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationInterse ) for c in intersect(map(keys, inds)...) ) end +function MAK.findtruncated(values::SectorVector, strategy::TruncationUnion) + inds = map(Base.Fix1(MAK.findtruncated, values), strategy.components) + return SectorDict( + c => reduce( + MatrixAlgebraKit._ind_union, [ind[c] for ind in inds if haskey(ind, c)] + ) for c in union(map(keys, inds)...) + ) +end +function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationUnion) + inds = map(Base.Fix1(MAK.findtruncated_svd, values), strategy.components) + return SectorDict( + c => reduce( + MatrixAlgebraKit._ind_union, [ind[c] for ind in inds if haskey(ind, c)] + ) for c in union(map(keys, inds)...) + ) +end # Truncation error # ---------------- diff --git a/test/tensors/factorizations.jl b/test/tensors/factorizations.jl index 48dfaf6f1..de326cc13 100644 --- a/test/tensors/factorizations.jl +++ b/test/tensors/factorizations.jl @@ -310,6 +310,15 @@ for V in spacelist @test norm(t - U5 * S5 * Vᴴ5) ≈ ϵ5 atol = eps(real(T))^(4 / 5) @test minimum(diagview(S5)) >= λ test_dim_isapprox(domain(S5), nvals) + + trunc = truncrank(nvals) | trunctol(; atol = λ - 10eps(λ)) + U5, S5, Vᴴ5, ϵ5 = @constinferred svd_trunc(t; trunc) + @test t * Vᴴ5' ≈ U5 * S5 + @test isisometric(U5) + @test isisometric(Vᴴ5; side = :right) + @test norm(t - U5 * S5 * Vᴴ5) ≈ ϵ5 atol = eps(real(T))^(4 / 5) + @test minimum(diagview(S5)) >= λ + test_dim_isapprox(domain(S5), nvals) end end From c53e7627c07585001ff9d322f357f61015dee92e Mon Sep 17 00:00:00 2001 From: lkdvos Date: Wed, 25 Mar 2026 16:33:57 -0400 Subject: [PATCH 2/5] add projection mooncake rules --- ext/TensorKitMooncakeExt/linalg.jl | 45 ++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/ext/TensorKitMooncakeExt/linalg.jl b/ext/TensorKitMooncakeExt/linalg.jl index 6af92f7be..a4c448015 100644 --- a/ext/TensorKitMooncakeExt/linalg.jl +++ b/ext/TensorKitMooncakeExt/linalg.jl @@ -86,3 +86,48 @@ function Mooncake.rrule!!(::CoDual{typeof(inv)}, A_ΔA::CoDual{<:AbstractTensorM return Ainv_ΔAinv, inv_pullback end + +# single-output projections: project_hermitian!, project_antihermitian! +for (f!, f, adj) in ( + (:project_hermitian!, :project_hermitian, :project_hermitian_adjoint), + (:project_antihermitian!, :project_antihermitian, :project_antihermitian_adjoint), + ) + @eval begin + function Mooncake.rrule!!(f_df::CoDual{typeof($f!)}, A_dA::CoDual{<:AbstractTensorMap}, arg_darg::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + arg, darg = A_dA === arg_darg ? (A, dA) : arrayify(arg_darg) + + # don't need to copy/restore A since projections don't mutate input + argc = copy(arg) + arg = $f!(A, arg, Mooncake.primal(alg_dalg)) + + function $adj(::NoRData) + $f!(darg) + if dA !== darg + add!(dA, darg) + MatrixAlgebraKit.zero!(darg) + end + copy!(arg, argc) + return ntuple(Returns(NoRData()), 4) + end + + return arg_darg, $adj + end + + function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractTensorMap}, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + output = $f(A, Mooncake.primal(alg_dalg)) + output_doutput = Mooncake.zero_fcodual(output) + + doutput = last(arrayify(output_doutput)) + function $adj(::NoRData) + # TODO: need accumulating projection to avoid intermediate here + add!(dA, $f(doutput)) + MatrixAlgebraKit.zero!(doutput) + return ntuple(Returns(NoRData()), 3) + end + + return output_doutput, $adj + end + end +end From 18aa158b125e1b2892f23abae2eb23c751c14f35 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 26 Mar 2026 09:07:24 -0400 Subject: [PATCH 3/5] Ensure `blocktype` is correctly inferred for CuArray --- src/tensors/abstracttensor.jl | 3 +-- src/tensors/tensor.jl | 4 +--- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/tensors/abstracttensor.jl b/src/tensors/abstracttensor.jl index 2d7239460..245085d57 100644 --- a/src/tensors/abstracttensor.jl +++ b/src/tensors/abstracttensor.jl @@ -313,8 +313,6 @@ end #------------------------------------------------------------ InnerProductStyle(t::AbstractTensorMap) = InnerProductStyle(typeof(t)) -blocktype(t::AbstractTensorMap) = blocktype(typeof(t)) - numout(t::AbstractTensorMap) = numout(typeof(t)) numin(t::AbstractTensorMap) = numin(typeof(t)) numind(t::AbstractTensorMap) = numind(typeof(t)) @@ -441,6 +439,7 @@ See also [`blocks`](@ref), [`blocksectors`](@ref), [`blockdim`](@ref) and [`hasb Return the type of the matrix blocks of a tensor. """ blocktype +blocktype(t::AbstractTensorMap) = blocktype(typeof(t)) function blocktype(::Type{T}) where {T <: AbstractTensorMap} return Core.Compiler.return_type(block, Tuple{T, sectortype(T)}) end diff --git a/src/tensors/tensor.jl b/src/tensors/tensor.jl index 615c67775..598356f04 100644 --- a/src/tensors/tensor.jl +++ b/src/tensors/tensor.jl @@ -455,9 +455,7 @@ block(t::TensorMap, c::Sector) = blocks(t)[c] blocks(t::TensorMap) = BlockIterator(t, fusionblockstructure(t).blockstructure) -function blocktype(::Type{TT}) where {TT <: TensorMap} - A = storagetype(TT) - T = eltype(A) +function blocktype(::Type{TensorMap{T, S, N₁, N₂, A}}) where {T, S, N₁, N₂, A <: Vector{T}} return Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}} end From 5b83bf41a3d60f37ad8711ac37fba40e7ed66d1b Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 26 Mar 2026 09:30:39 -0400 Subject: [PATCH 4/5] mark tests as no longer broken --- test/cuda/tensors.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/cuda/tensors.jl b/test/cuda/tensors.jl index 0fad13473..7bdd90f9d 100644 --- a/test/cuda/tensors.jl +++ b/test/cuda/tensors.jl @@ -98,8 +98,8 @@ for V in spacelist next = @constinferred Nothing iterate(bs, state) b2 = @constinferred block(t, first(blocksectors(t))) @test b1 == b2 - @test_broken eltype(bs) === Pair{typeof(c), typeof(b1)} - @test_broken typeof(b1) === TensorKit.blocktype(t) + @test eltype(bs) === Pair{typeof(c), typeof(b1)} + @test typeof(b1) === TensorKit.blocktype(t) @test typeof(c) === sectortype(t) end end @@ -162,8 +162,8 @@ for V in spacelist next = @constinferred Nothing iterate(bs, state) b2 = @constinferred block(t', first(blocksectors(t'))) @test b1 == b2 - @test_broken eltype(bs) === Pair{typeof(c), typeof(b1)} - @test_broken typeof(b1) === TensorKit.blocktype(t') + @test eltype(bs) === Pair{typeof(c), typeof(b1)} + @test typeof(b1) === TensorKit.blocktype(t') @test typeof(c) === sectortype(t) # linear algebra @test isa(@constinferred(norm(t)), real(T)) From c6afd0b7315b2dfd75d674b466742598c122505e Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 26 Mar 2026 11:06:14 -0400 Subject: [PATCH 5/5] type stability improvements --- src/factorizations/truncation.jl | 48 +++++++++++++++++++------------- 1 file changed, 28 insertions(+), 20 deletions(-) diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index c2b60c1e3..1115f4345 100644 --- a/src/factorizations/truncation.jl +++ b/src/factorizations/truncation.jl @@ -265,37 +265,45 @@ function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationSpace) return SectorDict(c => MAK.findtruncated_svd(d, blockstrategy(c)) for (c, d) in pairs(values)) end +# The implementations below assume that the `SectorDict` always contains an entry for every block sector +# for example, if a block gets fully truncated, inds[c] = Int[]. +# This is always the case in the implementations above. + function MAK.findtruncated(values::SectorVector, strategy::TruncationIntersection) inds = map(Base.Fix1(MAK.findtruncated, values), strategy.components) - return SectorDict( - c => mapreduce( - Base.Fix2(getindex, c), MatrixAlgebraKit._ind_intersect, inds - ) for c in intersect(map(keys, inds)...) - ) + @assert _allequal(keys, inds) "missing blocks are not supported right now" + sectors = keys(first(inds)) + vals = map(keys(first(inds))) do c + mapreduce(Base.Fix2(getindex, c), MatrixAlgebraKit._ind_intersect, inds) + end + return SectorDict{eltype(sectors), eltype(vals)}(sectors, vals) end function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationIntersection) inds = map(Base.Fix1(MAK.findtruncated_svd, values), strategy.components) - return SectorDict( - c => mapreduce( - Base.Fix2(getindex, c), MatrixAlgebraKit._ind_intersect, inds - ) for c in intersect(map(keys, inds)...) - ) + @assert _allequal(keys, inds) "missing blocks are not supported right now" + sectors = keys(first(inds)) + vals = map(keys(first(inds))) do c + mapreduce(Base.Fix2(getindex, c), MatrixAlgebraKit._ind_intersect, inds) + end + return SectorDict{eltype(sectors), eltype(vals)}(sectors, vals) end function MAK.findtruncated(values::SectorVector, strategy::TruncationUnion) inds = map(Base.Fix1(MAK.findtruncated, values), strategy.components) - return SectorDict( - c => reduce( - MatrixAlgebraKit._ind_union, [ind[c] for ind in inds if haskey(ind, c)] - ) for c in union(map(keys, inds)...) - ) + @assert _allequal(keys, inds) "missing blocks are not supported right now" + sectors = keys(first(inds)) + vals = map(keys(first(inds))) do c + mapreduce(Base.Fix2(getindex, c), MatrixAlgebraKit._ind_union, inds) + end + return SectorDict{eltype(sectors), eltype(vals)}(sectors, vals) end function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationUnion) inds = map(Base.Fix1(MAK.findtruncated_svd, values), strategy.components) - return SectorDict( - c => reduce( - MatrixAlgebraKit._ind_union, [ind[c] for ind in inds if haskey(ind, c)] - ) for c in union(map(keys, inds)...) - ) + @assert _allequal(keys, inds) "missing blocks are not supported right now" + sectors = keys(first(inds)) + vals = map(keys(first(inds))) do c + mapreduce(Base.Fix2(getindex, c), MatrixAlgebraKit._ind_union, inds) + end + return SectorDict{eltype(sectors), eltype(vals)}(sectors, vals) end # Truncation error