Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ GPUArrays = "11.3.1"
JET = "0.9, 0.10, 0.11"
LRUCache = "1.0.2"
LinearAlgebra = "1"
MatrixAlgebraKit = "0.6.4"
MatrixAlgebraKit = "0.6.5"
Mooncake = "0.5"
OhMyThreads = "0.8.0"
Printf = "1"
Expand Down
45 changes: 45 additions & 0 deletions ext/TensorKitMooncakeExt/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,48 @@ function Mooncake.rrule!!(::CoDual{typeof(inv)}, A_ΔA::CoDual{<:AbstractTensorM

return Ainv_ΔAinv, inv_pullback
end

# single-output projections: project_hermitian!, project_antihermitian!
for (f!, f, adj) in (
(:project_hermitian!, :project_hermitian, :project_hermitian_adjoint),
(:project_antihermitian!, :project_antihermitian, :project_antihermitian_adjoint),
)
@eval begin
function Mooncake.rrule!!(f_df::CoDual{typeof($f!)}, A_dA::CoDual{<:AbstractTensorMap}, arg_darg::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
A, dA = arrayify(A_dA)
arg, darg = A_dA === arg_darg ? (A, dA) : arrayify(arg_darg)

# don't need to copy/restore A since projections don't mutate input
argc = copy(arg)
arg = $f!(A, arg, Mooncake.primal(alg_dalg))

function $adj(::NoRData)
$f!(darg)
if dA !== darg
add!(dA, darg)
MatrixAlgebraKit.zero!(darg)
end
copy!(arg, argc)
return ntuple(Returns(NoRData()), 4)
end

return arg_darg, $adj
end

function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractTensorMap}, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
A, dA = arrayify(A_dA)
output = $f(A, Mooncake.primal(alg_dalg))
output_doutput = Mooncake.zero_fcodual(output)

doutput = last(arrayify(output_doutput))
function $adj(::NoRData)
# TODO: need accumulating projection to avoid intermediate here
add!(dA, $f(doutput))
MatrixAlgebraKit.zero!(doutput)
return ntuple(Returns(NoRData()), 3)
end

return output_doutput, $adj
end
end
end
4 changes: 2 additions & 2 deletions src/factorizations/factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ using TensorOperations: Index2Tuple
using MatrixAlgebraKit
import MatrixAlgebraKit as MAK
using MatrixAlgebraKit: AbstractAlgorithm, TruncatedAlgorithm, DiagonalAlgorithm
using MatrixAlgebraKit: TruncationStrategy, NoTruncation, TruncationByValue,
TruncationByError, TruncationIntersection, TruncationByFilter, TruncationByOrder
using MatrixAlgebraKit: TruncationStrategy, NoTruncation, TruncationByValue, TruncationByError,
TruncationIntersection, TruncationUnion, TruncationByFilter, TruncationByOrder
using MatrixAlgebraKit: diagview

include("utility.jl")
Expand Down
44 changes: 34 additions & 10 deletions src/factorizations/truncation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -265,21 +265,45 @@ function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationSpace)
return SectorDict(c => MAK.findtruncated_svd(d, blockstrategy(c)) for (c, d) in pairs(values))
end

# The implementations below assume that the `SectorDict` always contains an entry for every block sector
# for example, if a block gets fully truncated, inds[c] = Int[].
# This is always the case in the implementations above.

function MAK.findtruncated(values::SectorVector, strategy::TruncationIntersection)
inds = map(Base.Fix1(MAK.findtruncated, values), strategy.components)
return SectorDict(
c => mapreduce(
Base.Fix2(getindex, c), MatrixAlgebraKit._ind_intersect, inds
) for c in intersect(map(keys, inds)...)
)
@assert _allequal(keys, inds) "missing blocks are not supported right now"
sectors = keys(first(inds))
vals = map(keys(first(inds))) do c
mapreduce(Base.Fix2(getindex, c), MatrixAlgebraKit._ind_intersect, inds)
end
return SectorDict{eltype(sectors), eltype(vals)}(sectors, vals)
end
function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationIntersection)
inds = map(Base.Fix1(MAK.findtruncated_svd, values), strategy.components)
return SectorDict(
c => mapreduce(
Base.Fix2(getindex, c), MatrixAlgebraKit._ind_intersect, inds
) for c in intersect(map(keys, inds)...)
)
@assert _allequal(keys, inds) "missing blocks are not supported right now"
sectors = keys(first(inds))
vals = map(keys(first(inds))) do c
mapreduce(Base.Fix2(getindex, c), MatrixAlgebraKit._ind_intersect, inds)
end
return SectorDict{eltype(sectors), eltype(vals)}(sectors, vals)
end
function MAK.findtruncated(values::SectorVector, strategy::TruncationUnion)
inds = map(Base.Fix1(MAK.findtruncated, values), strategy.components)
@assert _allequal(keys, inds) "missing blocks are not supported right now"
sectors = keys(first(inds))
vals = map(keys(first(inds))) do c
mapreduce(Base.Fix2(getindex, c), MatrixAlgebraKit._ind_union, inds)
end
return SectorDict{eltype(sectors), eltype(vals)}(sectors, vals)
end
function MAK.findtruncated_svd(values::SectorVector, strategy::TruncationUnion)
inds = map(Base.Fix1(MAK.findtruncated_svd, values), strategy.components)
@assert _allequal(keys, inds) "missing blocks are not supported right now"
sectors = keys(first(inds))
vals = map(keys(first(inds))) do c
mapreduce(Base.Fix2(getindex, c), MatrixAlgebraKit._ind_union, inds)
end
return SectorDict{eltype(sectors), eltype(vals)}(sectors, vals)
end

# Truncation error
Expand Down
3 changes: 1 addition & 2 deletions src/tensors/abstracttensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,6 @@ end
#------------------------------------------------------------
InnerProductStyle(t::AbstractTensorMap) = InnerProductStyle(typeof(t))

blocktype(t::AbstractTensorMap) = blocktype(typeof(t))

numout(t::AbstractTensorMap) = numout(typeof(t))
numin(t::AbstractTensorMap) = numin(typeof(t))
numind(t::AbstractTensorMap) = numind(typeof(t))
Expand Down Expand Up @@ -441,6 +439,7 @@ See also [`blocks`](@ref), [`blocksectors`](@ref), [`blockdim`](@ref) and [`hasb

Return the type of the matrix blocks of a tensor.
""" blocktype
blocktype(t::AbstractTensorMap) = blocktype(typeof(t))
function blocktype(::Type{T}) where {T <: AbstractTensorMap}
return Core.Compiler.return_type(block, Tuple{T, sectortype(T)})
end
Expand Down
4 changes: 1 addition & 3 deletions src/tensors/tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -455,9 +455,7 @@ block(t::TensorMap, c::Sector) = blocks(t)[c]

blocks(t::TensorMap) = BlockIterator(t, fusionblockstructure(t).blockstructure)

function blocktype(::Type{TT}) where {TT <: TensorMap}
A = storagetype(TT)
T = eltype(A)
function blocktype(::Type{TensorMap{T, S, N₁, N₂, A}}) where {T, S, N₁, N₂, A <: Vector{T}}
return Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}
end

Expand Down
8 changes: 4 additions & 4 deletions test/cuda/tensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ for V in spacelist
next = @constinferred Nothing iterate(bs, state)
b2 = @constinferred block(t, first(blocksectors(t)))
@test b1 == b2
@test_broken eltype(bs) === Pair{typeof(c), typeof(b1)}
@test_broken typeof(b1) === TensorKit.blocktype(t)
@test eltype(bs) === Pair{typeof(c), typeof(b1)}
@test typeof(b1) === TensorKit.blocktype(t)
@test typeof(c) === sectortype(t)
end
end
Expand Down Expand Up @@ -162,8 +162,8 @@ for V in spacelist
next = @constinferred Nothing iterate(bs, state)
b2 = @constinferred block(t', first(blocksectors(t')))
@test b1 == b2
@test_broken eltype(bs) === Pair{typeof(c), typeof(b1)}
@test_broken typeof(b1) === TensorKit.blocktype(t')
@test eltype(bs) === Pair{typeof(c), typeof(b1)}
@test typeof(b1) === TensorKit.blocktype(t')
@test typeof(c) === sectortype(t)
# linear algebra
@test isa(@constinferred(norm(t)), real(T))
Expand Down
9 changes: 9 additions & 0 deletions test/tensors/factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,15 @@ for V in spacelist
@test norm(t - U5 * S5 * Vᴴ5) ≈ ϵ5 atol = eps(real(T))^(4 / 5)
@test minimum(diagview(S5)) >= λ
test_dim_isapprox(domain(S5), nvals)

trunc = truncrank(nvals) | trunctol(; atol = λ - 10eps(λ))
U5, S5, Vᴴ5, ϵ5 = @constinferred svd_trunc(t; trunc)
@test t * Vᴴ5' ≈ U5 * S5
@test isisometric(U5)
@test isisometric(Vᴴ5; side = :right)
@test norm(t - U5 * S5 * Vᴴ5) ≈ ϵ5 atol = eps(real(T))^(4 / 5)
@test minimum(diagview(S5)) >= λ
test_dim_isapprox(domain(S5), nvals)
end
end

Expand Down
Loading