diff --git a/Project.toml b/Project.toml index 7e82876c5..b46d82c91 100644 --- a/Project.toml +++ b/Project.toml @@ -46,7 +46,7 @@ GPUArrays = "11.3.1" JET = "0.9, 0.10, 0.11" LRUCache = "1.0.2" LinearAlgebra = "1" -MatrixAlgebraKit = "0.6.4" +MatrixAlgebraKit = "0.6.5" Mooncake = "0.5" OhMyThreads = "0.8.0" Printf = "1" diff --git a/ext/TensorKitMooncakeExt/linalg.jl b/ext/TensorKitMooncakeExt/linalg.jl index 6af92f7be..a4c448015 100644 --- a/ext/TensorKitMooncakeExt/linalg.jl +++ b/ext/TensorKitMooncakeExt/linalg.jl @@ -86,3 +86,48 @@ function Mooncake.rrule!!(::CoDual{typeof(inv)}, A_ΔA::CoDual{<:AbstractTensorM return Ainv_ΔAinv, inv_pullback end + +# single-output projections: project_hermitian!, project_antihermitian! +for (f!, f, adj) in ( + (:project_hermitian!, :project_hermitian, :project_hermitian_adjoint), + (:project_antihermitian!, :project_antihermitian, :project_antihermitian_adjoint), + ) + @eval begin + function Mooncake.rrule!!(f_df::CoDual{typeof($f!)}, A_dA::CoDual{<:AbstractTensorMap}, arg_darg::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + arg, darg = A_dA === arg_darg ? (A, dA) : arrayify(arg_darg) + + # don't need to copy/restore A since projections don't mutate input + argc = copy(arg) + arg = $f!(A, arg, Mooncake.primal(alg_dalg)) + + function $adj(::NoRData) + $f!(darg) + if dA !== darg + add!(dA, darg) + MatrixAlgebraKit.zero!(darg) + end + copy!(arg, argc) + return ntuple(Returns(NoRData()), 4) + end + + return arg_darg, $adj + end + + function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractTensorMap}, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + output = $f(A, Mooncake.primal(alg_dalg)) + output_doutput = Mooncake.zero_fcodual(output) + + doutput = last(arrayify(output_doutput)) + function $adj(::NoRData) + # TODO: need accumulating projection to avoid intermediate here + add!(dA, $f(doutput)) + MatrixAlgebraKit.zero!(doutput) + return ntuple(Returns(NoRData()), 3) + end + + return output_doutput, $adj + end + end +end diff --git a/src/factorizations/factorizations.jl b/src/factorizations/factorizations.jl index 0c2272ac8..bac5cbcd9 100644 --- a/src/factorizations/factorizations.jl +++ b/src/factorizations/factorizations.jl @@ -19,8 +19,8 @@ using TensorOperations: Index2Tuple using MatrixAlgebraKit import MatrixAlgebraKit as MAK using MatrixAlgebraKit: AbstractAlgorithm, TruncatedAlgorithm, DiagonalAlgorithm -using MatrixAlgebraKit: TruncationStrategy, NoTruncation, TruncationByValue, - TruncationByError, TruncationIntersection, TruncationByFilter, TruncationByOrder +using MatrixAlgebraKit: TruncationStrategy, NoTruncation, TruncationByValue, TruncationByError, + TruncationIntersection, TruncationUnion, TruncationByFilter, TruncationByOrder using MatrixAlgebraKit: diagview include("utility.jl") diff --git a/src/factorizations/truncation.jl b/src/factorizations/truncation.jl index 007565fd0..1115f4345 100644 --- a/src/factorizations/truncation.jl +++ b/src/factorizations/truncation.jl @@ -265,21 +265,45 @@ function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationSpace) return SectorDict(c => MAK.findtruncated_svd(d, blockstrategy(c)) for (c, d) in pairs(values)) end +# The implementations below assume that the `SectorDict` always contains an entry for every block sector +# for example, if a block gets fully truncated, inds[c] = Int[]. +# This is always the case in the implementations above. + function MAK.findtruncated(values::SectorVector, strategy::TruncationIntersection) inds = map(Base.Fix1(MAK.findtruncated, values), strategy.components) - return SectorDict( - c => mapreduce( - Base.Fix2(getindex, c), MatrixAlgebraKit._ind_intersect, inds - ) for c in intersect(map(keys, inds)...) - ) + @assert _allequal(keys, inds) "missing blocks are not supported right now" + sectors = keys(first(inds)) + vals = map(keys(first(inds))) do c + mapreduce(Base.Fix2(getindex, c), MatrixAlgebraKit._ind_intersect, inds) + end + return SectorDict{eltype(sectors), eltype(vals)}(sectors, vals) end function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationIntersection) inds = map(Base.Fix1(MAK.findtruncated_svd, values), strategy.components) - return SectorDict( - c => mapreduce( - Base.Fix2(getindex, c), MatrixAlgebraKit._ind_intersect, inds - ) for c in intersect(map(keys, inds)...) - ) + @assert _allequal(keys, inds) "missing blocks are not supported right now" + sectors = keys(first(inds)) + vals = map(keys(first(inds))) do c + mapreduce(Base.Fix2(getindex, c), MatrixAlgebraKit._ind_intersect, inds) + end + return SectorDict{eltype(sectors), eltype(vals)}(sectors, vals) +end +function MAK.findtruncated(values::SectorVector, strategy::TruncationUnion) + inds = map(Base.Fix1(MAK.findtruncated, values), strategy.components) + @assert _allequal(keys, inds) "missing blocks are not supported right now" + sectors = keys(first(inds)) + vals = map(keys(first(inds))) do c + mapreduce(Base.Fix2(getindex, c), MatrixAlgebraKit._ind_union, inds) + end + return SectorDict{eltype(sectors), eltype(vals)}(sectors, vals) +end +function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationUnion) + inds = map(Base.Fix1(MAK.findtruncated_svd, values), strategy.components) + @assert _allequal(keys, inds) "missing blocks are not supported right now" + sectors = keys(first(inds)) + vals = map(keys(first(inds))) do c + mapreduce(Base.Fix2(getindex, c), MatrixAlgebraKit._ind_union, inds) + end + return SectorDict{eltype(sectors), eltype(vals)}(sectors, vals) end # Truncation error diff --git a/src/tensors/abstracttensor.jl b/src/tensors/abstracttensor.jl index 2d7239460..245085d57 100644 --- a/src/tensors/abstracttensor.jl +++ b/src/tensors/abstracttensor.jl @@ -313,8 +313,6 @@ end #------------------------------------------------------------ InnerProductStyle(t::AbstractTensorMap) = InnerProductStyle(typeof(t)) -blocktype(t::AbstractTensorMap) = blocktype(typeof(t)) - numout(t::AbstractTensorMap) = numout(typeof(t)) numin(t::AbstractTensorMap) = numin(typeof(t)) numind(t::AbstractTensorMap) = numind(typeof(t)) @@ -441,6 +439,7 @@ See also [`blocks`](@ref), [`blocksectors`](@ref), [`blockdim`](@ref) and [`hasb Return the type of the matrix blocks of a tensor. """ blocktype +blocktype(t::AbstractTensorMap) = blocktype(typeof(t)) function blocktype(::Type{T}) where {T <: AbstractTensorMap} return Core.Compiler.return_type(block, Tuple{T, sectortype(T)}) end diff --git a/src/tensors/tensor.jl b/src/tensors/tensor.jl index 615c67775..598356f04 100644 --- a/src/tensors/tensor.jl +++ b/src/tensors/tensor.jl @@ -455,9 +455,7 @@ block(t::TensorMap, c::Sector) = blocks(t)[c] blocks(t::TensorMap) = BlockIterator(t, fusionblockstructure(t).blockstructure) -function blocktype(::Type{TT}) where {TT <: TensorMap} - A = storagetype(TT) - T = eltype(A) +function blocktype(::Type{TensorMap{T, S, N₁, N₂, A}}) where {T, S, N₁, N₂, A <: Vector{T}} return Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}} end diff --git a/test/cuda/tensors.jl b/test/cuda/tensors.jl index 0fad13473..7bdd90f9d 100644 --- a/test/cuda/tensors.jl +++ b/test/cuda/tensors.jl @@ -98,8 +98,8 @@ for V in spacelist next = @constinferred Nothing iterate(bs, state) b2 = @constinferred block(t, first(blocksectors(t))) @test b1 == b2 - @test_broken eltype(bs) === Pair{typeof(c), typeof(b1)} - @test_broken typeof(b1) === TensorKit.blocktype(t) + @test eltype(bs) === Pair{typeof(c), typeof(b1)} + @test typeof(b1) === TensorKit.blocktype(t) @test typeof(c) === sectortype(t) end end @@ -162,8 +162,8 @@ for V in spacelist next = @constinferred Nothing iterate(bs, state) b2 = @constinferred block(t', first(blocksectors(t'))) @test b1 == b2 - @test_broken eltype(bs) === Pair{typeof(c), typeof(b1)} - @test_broken typeof(b1) === TensorKit.blocktype(t') + @test eltype(bs) === Pair{typeof(c), typeof(b1)} + @test typeof(b1) === TensorKit.blocktype(t') @test typeof(c) === sectortype(t) # linear algebra @test isa(@constinferred(norm(t)), real(T)) diff --git a/test/tensors/factorizations.jl b/test/tensors/factorizations.jl index 48dfaf6f1..de326cc13 100644 --- a/test/tensors/factorizations.jl +++ b/test/tensors/factorizations.jl @@ -310,6 +310,15 @@ for V in spacelist @test norm(t - U5 * S5 * Vᴴ5) ≈ ϵ5 atol = eps(real(T))^(4 / 5) @test minimum(diagview(S5)) >= λ test_dim_isapprox(domain(S5), nvals) + + trunc = truncrank(nvals) | trunctol(; atol = λ - 10eps(λ)) + U5, S5, Vᴴ5, ϵ5 = @constinferred svd_trunc(t; trunc) + @test t * Vᴴ5' ≈ U5 * S5 + @test isisometric(U5) + @test isisometric(Vᴴ5; side = :right) + @test norm(t - U5 * S5 * Vᴴ5) ≈ ϵ5 atol = eps(real(T))^(4 / 5) + @test minimum(diagview(S5)) >= λ + test_dim_isapprox(domain(S5), nvals) end end