diff --git a/Project.toml b/Project.toml index 7e82876c5..a3bf2ea84 100644 --- a/Project.toml +++ b/Project.toml @@ -1,9 +1,10 @@ name = "TensorKit" uuid = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec" -authors = ["Jutho Haegeman, Lukas Devos"] version = "0.16.3" +authors = ["Jutho Haegeman, Lukas Devos"] [deps] +Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" @@ -41,6 +42,7 @@ CUDA = "5.9" ChainRulesCore = "1" ChainRulesTestUtils = "1" Combinatorics = "1" +Dictionaries = "0.4" FiniteDifferences = "0.12" GPUArrays = "11.3.1" JET = "0.9, 0.10, 0.11" diff --git a/docs/src/lib/tensors.md b/docs/src/lib/tensors.md index ea491a843..7b12cdf18 100644 --- a/docs/src/lib/tensors.md +++ b/docs/src/lib/tensors.md @@ -97,7 +97,6 @@ In `TensorMap` instances, all data is gathered in a single `AbstractVector`, whi To obtain information about the structure of the data, you can use: ```@docs -fusionblockstructure(::AbstractTensorMap) dim(::AbstractTensorMap) blocksectors(::AbstractTensorMap) hasblock(::AbstractTensorMap, ::Sector) diff --git a/ext/TensorKitMooncakeExt/utility.jl b/ext/TensorKitMooncakeExt/utility.jl index 779a4e017..ceb32d867 100644 --- a/ext/TensorKitMooncakeExt/utility.jl +++ b/ext/TensorKitMooncakeExt/utility.jl @@ -62,7 +62,8 @@ end Mooncake.tangent_type(::Type{<:VectorSpace}) = Mooncake.NoTangent Mooncake.tangent_type(::Type{<:HomSpace}) = Mooncake.NoTangent -@zero_derivative DefaultCtx Tuple{typeof(TensorKit.fusionblockstructure), Any} +@zero_derivative DefaultCtx Tuple{typeof(TensorKit.sectorstructure), Any} +@zero_derivative DefaultCtx Tuple{typeof(TensorKit.degeneracystructure), Any} @zero_derivative DefaultCtx Tuple{typeof(TensorKit.select), HomSpace, Index2Tuple} @zero_derivative DefaultCtx Tuple{typeof(TensorKit.flip), HomSpace, Any} diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 688e7363c..61adedcc2 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -118,6 +118,7 @@ const TO = TensorOperations using MatrixAlgebraKit +using Dictionaries: Dictionaries, Dictionary, Indices, gettoken, gettokenvalue, set! using LRUCache using OhMyThreads using ScopedValues @@ -218,6 +219,7 @@ end # Definitions and methods for tensors #------------------------------------- # general definitions +include("tensors/tensorstructure.jl") include("tensors/abstracttensor.jl") include("tensors/backends.jl") include("tensors/blockiterator.jl") diff --git a/src/auxiliary/dicts.jl b/src/auxiliary/dicts.jl index f43838010..bfdf79b88 100644 --- a/src/auxiliary/dicts.jl +++ b/src/auxiliary/dicts.jl @@ -263,3 +263,24 @@ function Base.:(==)(d1::SortedVectorDict, d2::SortedVectorDict) end return true end + +""" + Hashed(value, hashfunction = Base.hash, isequal = Base.isequal) + +Wrapper struct to alter the `hash` and `isequal` implementations of a given value. +This is useful in the contexts of dictionaries, where you either want to customize the hashfunction, +or consider various values as equal with a different notion of equality. +""" +struct Hashed{T, Hash, Eq} + val::T +end + +Hashed(val, hash = Base.hash, eq = Base.isequal) = Hashed{typeof(val), hash, eq}(val) + +Base.parent(h::Hashed) = h.val + +# hash overload +Base.hash(h::Hashed{T, Hash}, seed::UInt) where {T, Hash} = Hash(parent(h), seed) + +# isequal overload +Base.isequal(h1::H, h2::H) where {Eq, H <: Hashed{<:Any, <:Any, Eq}} = Eq(parent(h1), parent(h2)) diff --git a/src/spaces/homspace.jl b/src/spaces/homspace.jl index d032a1586..42d0f4c9e 100644 --- a/src/spaces/homspace.jl +++ b/src/spaces/homspace.jl @@ -41,8 +41,7 @@ spacetype(::Type{<:HomSpace{S}}) where {S} = S const TensorSpace{S <: ElementarySpace} = Union{S, ProductSpace{S}} const TensorMapSpace{S <: ElementarySpace, N₁, N₂} = HomSpace{ - S, ProductSpace{S, N₁}, - ProductSpace{S, N₂}, + S, ProductSpace{S, N₁}, ProductSpace{S, N₂}, } numout(::Type{TensorMapSpace{S, N₁, N₂}}) where {S, N₁, N₂} = N₁ @@ -62,49 +61,23 @@ end →(dom::VectorSpace, codom::VectorSpace) = ←(codom, dom) function Base.show(io::IO, W::HomSpace) - if length(W.codomain) == 1 - print(io, W.codomain[1]) - else - print(io, W.codomain) - end - print(io, " ← ") - return if length(W.domain) == 1 - print(io, W.domain[1]) - else - print(io, W.domain) - end + return print( + io, + numout(W) == 1 ? codomain(W)[1] : codomain(W), + " ← ", + numin(W) == 1 ? domain(W)[1] : domain(W) + ) end """ - blocksectors(W::HomSpace) + blocksectors(W::HomSpace) -> Indices{I} -Return an iterator over the different unique coupled sector labels, i.e. the intersection -of the different fusion outputs that can be obtained by fusing the sectors present in the -domain, as well as from the codomain. +Return an `Indices` of all coupled sectors for `W`. The result is cached based on the +sector structure of `W` (ignoring degeneracy dimensions). -See also [`hasblock`](@ref). +See also [`hasblock`](@ref), [`blockstructure`](@ref). """ -function blocksectors(W::HomSpace) - sectortype(W) === Trivial && - return OneOrNoneIterator(dim(domain(W)) != 0 && dim(codomain(W)) != 0, Trivial()) - - codom = codomain(W) - dom = domain(W) - N₁ = length(codom) - N₂ = length(dom) - I = sectortype(W) - if N₁ == N₂ == 0 - return allunits(I) - elseif N₁ == 0 - return filter!(isunit, collect(blocksectors(dom))) # module space cannot end in empty space - elseif N₂ == 0 - return filter!(isunit, collect(blocksectors(codom))) - elseif N₂ <= N₁ - return filter!(c -> hasblock(codom, c), collect(blocksectors(dom))) - else - return filter!(c -> hasblock(dom, c), collect(blocksectors(codom))) - end -end +blocksectors(W::HomSpace) = sectorstructure(W).blocksectors """ hasblock(W::HomSpace, c::Sector) @@ -116,27 +89,27 @@ See also [`blocksectors`](@ref). hasblock(W::HomSpace, c::Sector) = hasblock(codomain(W), c) && hasblock(domain(W), c) """ - dim(W::HomSpace) + dim(W::HomSpace) -> Int Return the total dimension of a `HomSpace`, i.e. the number of linearly independent morphisms that can be constructed within this space. """ -function dim(W::HomSpace) - d = 0 - for c in blocksectors(W) - d += blockdim(codomain(W), c) * blockdim(domain(W), c) - end - return d -end +dim(W::HomSpace) = degeneracystructure(W).totaldim dims(W::HomSpace) = (dims(codomain(W))..., dims(domain(W))...) """ - fusiontrees(W::HomSpace) + fusiontrees(W::HomSpace) -> Indices{Tuple{F₁,F₂}} -Return the fusiontrees corresponding to all valid fusion channels of a given `HomSpace`. +Return an `Indices` of all valid fusion tree pairs `(f₁, f₂)` for `W`, providing a +bijection to sequential integer positions via `gettoken`/`gettokenvalue`. The result is +cached based on the sector structure of `W` (ignoring degeneracy dimensions), so +`HomSpace`s that share the same sectors, dualities, and index count will reuse the same +object. + +See also [`sectorstructure`](@ref), [`subblockstructure`](@ref). """ -fusiontrees(W::HomSpace) = fusionblockstructure(W).fusiontreelist +fusiontrees(W::HomSpace) = sectorstructure(W).fusiontrees # Operations on HomSpaces # ----------------------- @@ -290,125 +263,3 @@ function removeunit(P::HomSpace, ::Val{i}) where {i} return codomain(P) ← removeunit(domain(P), Val(i - numout(P))) end end - -# Block and fusion tree ranges: structure information for building tensors -#-------------------------------------------------------------------------- - -# sizes, strides, offset -const StridedStructure{N} = Tuple{NTuple{N, Int}, NTuple{N, Int}, Int} - -struct FusionBlockStructure{I, N, F₁, F₂} - totaldim::Int - blockstructure::SectorDict{I, Tuple{Tuple{Int, Int}, UnitRange{Int}}} - fusiontreelist::Vector{Tuple{F₁, F₂}} - fusiontreestructure::Vector{StridedStructure{N}} - fusiontreeindices::FusionTreeDict{Tuple{F₁, F₂}, Int} -end - -function fusionblockstructuretype(W::HomSpace) - N₁ = length(codomain(W)) - N₂ = length(domain(W)) - N = N₁ + N₂ - I = sectortype(W) - F₁ = fusiontreetype(I, N₁) - F₂ = fusiontreetype(I, N₂) - return FusionBlockStructure{I, N, F₁, F₂} -end - -@cached function fusionblockstructure(W::HomSpace)::fusionblockstructuretype(W) - codom = codomain(W) - dom = domain(W) - N₁ = length(codom) - N₂ = length(dom) - I = sectortype(W) - F₁ = fusiontreetype(I, N₁) - F₂ = fusiontreetype(I, N₂) - - # output structure - blockstructure = SectorDict{I, Tuple{Tuple{Int, Int}, UnitRange{Int}}}() # size, range - fusiontreelist = Vector{Tuple{F₁, F₂}}() - fusiontreestructure = Vector{Tuple{NTuple{N₁ + N₂, Int}, NTuple{N₁ + N₂, Int}, Int}}() # size, strides, offset - - # temporary data structures - splittingtrees = Vector{F₁}() - splittingstructure = Vector{Tuple{Int, Int}}() - - # main computational routine - blockoffset = 0 - for c in blocksectors(W) - empty!(splittingtrees) - empty!(splittingstructure) - - offset₁ = 0 - for f₁ in fusiontrees(codom, c) - push!(splittingtrees, f₁) - d₁ = dim(codom, f₁.uncoupled) - push!(splittingstructure, (offset₁, d₁)) - offset₁ += d₁ - end - blockdim₁ = offset₁ - strides = (1, blockdim₁) - - offset₂ = 0 - for f₂ in fusiontrees(dom, c) - s₂ = f₂.uncoupled - d₂ = dim(dom, s₂) - for (f₁, (offset₁, d₁)) in zip(splittingtrees, splittingstructure) - push!(fusiontreelist, (f₁, f₂)) - totaloffset = blockoffset + offset₂ * blockdim₁ + offset₁ - subsz = (dims(codom, f₁.uncoupled)..., dims(dom, f₂.uncoupled)...) - @assert !any(isequal(0), subsz) - substr = _subblock_strides(subsz, (d₁, d₂), strides) - push!(fusiontreestructure, (subsz, substr, totaloffset)) - end - offset₂ += d₂ - end - blockdim₂ = offset₂ - blocksize = (blockdim₁, blockdim₂) - blocklength = blockdim₁ * blockdim₂ - blockrange = (blockoffset + 1):(blockoffset + blocklength) - blockoffset = last(blockrange) - blockstructure[c] = (blocksize, blockrange) - end - - fusiontreeindices = sizehint!( - FusionTreeDict{Tuple{F₁, F₂}, Int}(), length(fusiontreelist) - ) - for (i, f₁₂) in enumerate(fusiontreelist) - fusiontreeindices[f₁₂] = i - end - totaldim = blockoffset - structure = FusionBlockStructure( - totaldim, blockstructure, fusiontreelist, fusiontreestructure, fusiontreeindices - ) - return structure -end - -function _subblock_strides(subsz, sz, str) - sz_simplify = Strided.StridedViews._simplifydims(sz, str) - strides = Strided.StridedViews._computereshapestrides(subsz, sz_simplify...) - isnothing(strides) && - throw(ArgumentError("unexpected error in computing subblock strides")) - return strides -end - -function CacheStyle(::typeof(fusionblockstructure), W::HomSpace) - return GlobalLRUCache() -end - -# Diagonal ranges -#---------------- -# TODO: is this something we want to cache? -function diagonalblockstructure(W::HomSpace) - ((numin(W) == numout(W) == 1) && domain(W) == codomain(W)) || - throw(SpaceMismatch("Diagonal only support on V←V with a single space V")) - structure = SectorDict{sectortype(W), UnitRange{Int}}() # range - offset = 0 - dom = domain(W)[1] - for c in blocksectors(W) - d = dim(dom, c) - structure[c] = offset .+ (1:d) - offset += d - end - return structure -end diff --git a/src/spaces/productspace.jl b/src/spaces/productspace.jl index 4b75ce737..a887999f0 100644 --- a/src/spaces/productspace.jl +++ b/src/spaces/productspace.jl @@ -96,6 +96,9 @@ sectors(P::ProductSpace) = _sectors(P, sectortype(P)) function _sectors(P::ProductSpace{<:ElementarySpace, N}, ::Type{Trivial}) where {N} return OneOrNoneIterator(dim(P) != 0, ntuple(n -> Trivial(), N)) end +function _sectors(P::ProductSpace{<:ElementarySpace, 0}, ::Type{I}) where {I <: Sector} + return Iterators.map(u -> (u,), allunits(I)) +end function _sectors(P::ProductSpace{<:ElementarySpace, N}, ::Type{<:Sector}) where {N} return product(map(sectors, P.spaces)...) end @@ -147,23 +150,11 @@ that make up the `ProductSpace` instance. function blocksectors(P::ProductSpace{S, N}) where {S, N} I = sectortype(S) if I == Trivial - return OneOrNoneIterator(dim(P) != 0, Trivial()) + return dim(P) != 0 ? Indices([Trivial()]) : Indices(Trivial[]) end - bs = Vector{I}() - if N == 0 - append!(bs, allunits(I)) - elseif N == 1 - for s in sectors(P) - push!(bs, first(s)) - end - else - for s in sectors(P) - for c in ⊗(s...) - if !(c in bs) - push!(bs, c) - end - end - end + bs = Indices{I}() + for s in sectors(P), c in ⊗(s...) + set!(bs, c) end return sort!(bs) end diff --git a/src/tensors/abstracttensor.jl b/src/tensors/abstracttensor.jl index 2d7239460..af216d257 100644 --- a/src/tensors/abstracttensor.jl +++ b/src/tensors/abstracttensor.jl @@ -321,21 +321,13 @@ numind(t::AbstractTensorMap) = numind(typeof(t)) # tensor characteristics: data structure and properties #------------------------------------------------------ -""" - fusionblockstructure(t::AbstractTensorMap) -> TensorStructure - -Return the necessary structure information to decompose a tensor in blocks labeled by -coupled sectors and in subblocks labeled by a splitting-fusion tree couple. -""" -fusionblockstructure(t::AbstractTensorMap) = fusionblockstructure(space(t)) - """ dim(t::AbstractTensorMap) -> Int The total number of free parameters of a tensor, discounting the entries that are fixed by symmetry. This is also the dimension of the `HomSpace` on which the `TensorMap` is defined. """ -dim(t::AbstractTensorMap) = fusionblockstructure(t).totaldim +dim(t::AbstractTensorMap) = dim(space(t)) dims(t::AbstractTensorMap) = dims(space(t)) @@ -344,7 +336,7 @@ dims(t::AbstractTensorMap) = dims(space(t)) Return an iterator over all coupled sectors of a tensor. """ -blocksectors(t::AbstractTensorMap) = keys(fusionblockstructure(t).blockstructure) +blocksectors(t::AbstractTensorMap) = blocksectors(space(t)) """ hasblock(t::AbstractTensorMap, c::Sector) -> Bool @@ -353,46 +345,15 @@ Verify whether a tensor has a block corresponding to a coupled sector `c`. """ hasblock(t::AbstractTensorMap, c::Sector) = c ∈ blocksectors(t) -# TODO: convenience methods, do we need them? -# """ -# blocksize(t::AbstractTensorMap, c::Sector) -> Tuple{Int,Int} - -# Return the size of the matrix block of a tensor corresponding to a coupled sector `c`. - -# See also [`blockdim`](@ref) and [`blockrange`](@ref). -# """ -# function blocksize(t::AbstractTensorMap, c::Sector) -# return fusionblockstructure(t).blockstructure[c][1] -# end - -# """ -# blockdim(t::AbstractTensorMap, c::Sector) -> Int - -# Return the total dimension (length) of the matrix block of a tensor corresponding to -# a coupled sector `c`. - -# See also [`blocksize`](@ref) and [`blockrange`](@ref). -# """ -# function blockdim(t::AbstractTensorMap, c::Sector) -# return *(blocksize(t, c)...) -# end - -# """ -# blockrange(t::AbstractTensorMap, c::Sector) -> UnitRange{Int} - -# Return the range at which to find the matrix block of a tensor corresponding to a -# coupled sector `c`, within the total data vector of length `dim(t)`. -# """ -# function blockrange(t::AbstractTensorMap, c::Sector) -# return fusionblockstructure(t).blockstructure[c][2] -# end +blockstructure(t::AbstractTensorMap) = blockstructure(space(t)) +subblockstructure(t::AbstractTensorMap) = subblockstructure(space(t)) """ fusiontrees(t::AbstractTensorMap) Return an iterator over all splitting - fusion tree pairs of a tensor. """ -fusiontrees(t::AbstractTensorMap) = fusionblockstructure(t).fusiontreelist +fusiontrees(t::AbstractTensorMap) = fusiontrees(space(t)) fusiontreetype(t::AbstractTensorMap) = fusiontreetype(typeof(t)) function fusiontreetype(::Type{T}) where {T <: AbstractTensorMap} diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index 0070bc2d4..48e1aa1df 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -143,11 +143,9 @@ function block(b::BraidingTensor, s::Sector) return data end - structure = fusionblockstructure(b) - base_offset = first(structure.blockstructure[s][2]) - 1 + base_offset = first(blockstructure(b)[s][2]) - 1 - for ((f1, f2), (sz, str, off)) in - zip(structure.fusiontreelist, structure.fusiontreestructure) + for ((f1, f2), (sz, str, off)) in pairs(subblockstructure(space(b))) if (f1.uncoupled != reverse(f2.uncoupled)) || !(f1.coupled == f2.coupled == s) continue end diff --git a/src/tensors/tensor.jl b/src/tensors/tensor.jl index 615c67775..7e4f037cb 100644 --- a/src/tensors/tensor.jl +++ b/src/tensors/tensor.jl @@ -15,8 +15,7 @@ struct TensorMap{T, S <: IndexSpace, N₁, N₂, A <: DenseVector{T}} <: Abstrac function TensorMap{T, S, N₁, N₂, A}( ::UndefInitializer, space::TensorMapSpace{S, N₁, N₂} ) where {T, S <: IndexSpace, N₁, N₂, A <: DenseVector{T}} - d = fusionblockstructure(space).totaldim - data = A(undef, d) + data = A(undef, dim(space)) if !isbitstype(T) zerovector!(data) end @@ -31,8 +30,7 @@ struct TensorMap{T, S <: IndexSpace, N₁, N₂, A <: DenseVector{T}} <: Abstrac I = sectortype(S) T <: Real && !(sectorscalartype(I) <: Real) && @warn("Tensors with real data might be incompatible with sector type $I", maxlog = 1) - d = fusionblockstructure(space).totaldim - length(data) == d || throw(DimensionMismatch("invalid length of data")) + length(data) == dim(space) || throw(DimensionMismatch("invalid length of data")) return new{T, S, N₁, N₂, A}(data, space) end end @@ -453,7 +451,7 @@ end #------------------------------------------------- block(t::TensorMap, c::Sector) = blocks(t)[c] -blocks(t::TensorMap) = BlockIterator(t, fusionblockstructure(t).blockstructure) +blocks(t::TensorMap) = BlockIterator(t, blockstructure(space(t))) function blocktype(::Type{TT}) where {TT <: TensorMap} A = storagetype(TT) @@ -462,7 +460,7 @@ function blocktype(::Type{TT}) where {TT <: TensorMap} end function Base.iterate(iter::BlockIterator{<:TensorMap}, state...) - next = iterate(iter.structure, state...) + next = iterate(pairs(iter.structure), state...) isnothing(next) && return next (c, (sz, r)), newstate = next return c => reshape(view(iter.t.data, r), sz), newstate @@ -470,16 +468,18 @@ end function Base.getindex(iter::BlockIterator{<:TensorMap}, c::Sector) sectortype(iter.t) === typeof(c) || throw(SectorMismatch()) - (d₁, d₂), r = get(iter.structure, c) do - # is s is not a key, at least one of the two dimensions will be zero: + found, token = gettoken(iter.structure, c) + if found + (d₁, d₂), r = gettokenvalue(iter.structure, token) + return reshape(view(iter.t.data, r), (d₁, d₂)) + else + # if c is not a key, at least one of the two dimensions will be zero: # it then does not matter where exactly we construct a view in `t.data`, # as it will have length zero anyway - d₁′ = blockdim(codomain(iter.t), c) - d₂′ = blockdim(domain(iter.t), c) - l = d₁′ * d₂′ - return (d₁′, d₂′), 1:l + d₁ = blockdim(codomain(iter.t), c) + d₂ = blockdim(domain(iter.t), c) + return reshape(view(iter.t.data, 1:(d₁ * d₂)), (d₁, d₂)) end - return reshape(view(iter.t.data, r), (d₁, d₂)) end # Getting and setting the data at the subblock level @@ -487,13 +487,11 @@ end function subblock( t::TensorMap{T, S, N₁, N₂}, (f₁, f₂)::Tuple{FusionTree{I, N₁}, FusionTree{I, N₂}} ) where {T, S, N₁, N₂, I <: Sector} - structure = fusionblockstructure(t) - @boundscheck begin - haskey(structure.fusiontreeindices, (f₁, f₂)) || throw(SectorMismatch()) - end + fts = subblockstructure(space(t)) + found, token = gettoken(fts, (f₁, f₂)) + @boundscheck found || throw(SectorMismatch(lazy"fusion tree pair ($(f₁, f₂)) is not present")) @inbounds begin - i = structure.fusiontreeindices[(f₁, f₂)] - sz, str, offset = structure.fusiontreestructure[i] + sz, str, offset = gettokenvalue(fts, token) return StridedView(t.data, sz, str, offset) end end diff --git a/src/tensors/tensoroperations.jl b/src/tensors/tensoroperations.jl index 0820fe1af..2d6b5a9a5 100644 --- a/src/tensors/tensoroperations.jl +++ b/src/tensors/tensoroperations.jl @@ -9,8 +9,7 @@ function TO.tensoralloc( ::Type{TT}, structure::TensorMapSpace, istemp::Val, allocator = TO.DefaultAllocator() ) where {TT <: AbstractTensorMap} A = storagetype(TT) - dim = fusionblockstructure(structure).totaldim - data = TO.tensoralloc(A, dim, istemp, allocator) + data = TO.tensoralloc(A, dim(structure), istemp, allocator) TT′ = tensormaptype(spacetype(structure), numout(structure), numin(structure), typeof(data)) return TT′(data, structure) end diff --git a/src/tensors/tensorstructure.jl b/src/tensors/tensorstructure.jl new file mode 100644 index 000000000..94e812ddd --- /dev/null +++ b/src/tensors/tensorstructure.jl @@ -0,0 +1,276 @@ +# Block and fusion tree ranges: structure information for building tensors +#-------------------------------------------------------------------------- + +# sizes, strides, offset +const StridedStructure{N} = Tuple{NTuple{N, Int}, NTuple{N, Int}, Int} + +function sectorequal(W₁::HomSpace, W₂::HomSpace) + check_spacetype(W₁, W₂) + (numout(W₁) == numout(W₂) && numin(W₁) == numin(W₂)) || return false + for (w₁, w₂) in zip(codomain(W₁), codomain(W₂)) + isdual(w₁) == isdual(w₂) || return false + isequal(sectors(w₁), sectors(w₂)) || return false + end + for (w₁, w₂) in zip(domain(W₁), domain(W₂)) + isdual(w₁) == isdual(w₂) || return false + isequal(sectors(w₁), sectors(w₂)) || return false + end + return true +end +function sectorhash(W::HomSpace, h::UInt) + for w in codomain(W) + h = hash(sectors(w), hash(isdual(w), h)) + end + for w in domain(W) + h = hash(sectors(w), hash(isdual(w), h)) + end + return h +end + +""" + SectorStructure{I, F₁, F₂} + +Sector-only structure of a `HomSpace`: the coupled sectors and all valid fusion tree pairs, +depending only on which sectors appear (not their degeneracy dimensions). Shared across +`HomSpace`s with the same sector structure. + +## Fields +- `blocksectors`: `Indices` of all coupled sectors `c::I`. +- `fusiontrees`: `Indices` of all valid fusion tree pairs `(f₁, f₂)`, in canonical order. + +See also [`sectorstructure`](@ref), [`DegeneracyStructure`](@ref). +""" +struct SectorStructure{I, F₁, F₂} + blocksectors::Indices{I} + fusiontrees::Indices{Tuple{F₁, F₂}} +end + +Base.@assume_effects :foldable function sectorstructuretype(key::Hashed{S}) where {S <: HomSpace} + I = sectortype(S) + F₁ = fusiontreetype(I, numout(S)) + F₂ = fusiontreetype(I, numin(S)) + return SectorStructure{I, F₁, F₂} +end + +function _blocksectors(W::HomSpace) + sectortype(W) === Trivial && + return OneOrNoneIterator(dim(domain(W)) != 0 && dim(codomain(W)) != 0, Trivial()) + + return sort!(intersect(blocksectors(codomain(W)), blocksectors(domain(W)))) +end + +@doc """ + sectorstructure(W::HomSpace) -> SectorStructure + +Return the [`SectorStructure`](@ref) for `W`, containing the coupled sectors and fusion tree +pairs as `Indices`. The result is cached based on the sector structure of `W` (ignoring +degeneracy dimensions). + +See also [`degeneracystructure`](@ref), [`fusiontrees`](@ref), [`blocksectors`](@ref). +""" sectorstructure(::HomSpace) +sectorstructure(W::HomSpace) = sectorstructure(Hashed(W, sectorhash, sectorequal)) + +@cached function sectorstructure(key::Hashed{S})::sectorstructuretype(key) where {S <: HomSpace} + W = parent(key) + codom, dom = codomain(W), domain(W) + I = sectortype(S) + N₁, N₂ = numout(S), numin(S) + F₁ = fusiontreetype(I, N₁) + F₂ = fusiontreetype(I, N₂) + + bs = Vector{I}() + trees = Vector{Tuple{F₁, F₂}}() + + for c in _blocksectors(W) + push!(bs, c) + codom_start = length(trees) + 1 + n₁ = 0 + for f₂ in fusiontrees(dom, c) + if n₁ == 0 + # First f₂ for this sector: enumerate codomain trees and record how many there are. + for f₁ in fusiontrees(codom, c) + push!(trees, (f₁, f₂)) + end + n₁ = length(trees) - codom_start + 1 + else + # Subsequent f₂s: the codomain trees are already in the list at + # codom_start:codom_start+n₁-1, so read them back instead of recomputing. + for j in codom_start:(codom_start + n₁ - 1) + push!(trees, (trees[j][1], f₂)) + end + end + end + end + + return SectorStructure{I, F₁, F₂}(Indices(bs), Indices(trees)) +end + +CacheStyle(::typeof(sectorstructure), ::Hashed{S}) where {S <: HomSpace} = GlobalLRUCache() + + +""" + DegeneracyStructure{N} + +Degeneracy-dependent structure of a `HomSpace`: the block sizes, ranges, and sub-block +strides that depend on the degeneracy (multiplicity) dimensions. Specific to a given +`HomSpace` instance. + +## Fields +- `totaldim`: total number of elements in the flat data vector. +- `blockstructure`: `Vector` of `((d₁, d₂), range)` values, one per coupled sector, in the + same order as [`sectorstructure`](@ref)`.blocksectors`. +- `subblockstructure`: `Vector` of [`StridedStructure`](@ref) `(sizes, strides, offset)` + values, one per fusion tree pair, in the same order as + [`sectorstructure`](@ref)`.fusiontrees`. + +See also [`degeneracystructure`](@ref), [`SectorStructure`](@ref). +""" +struct DegeneracyStructure{N} + totaldim::Int + blockstructure::Vector{Tuple{Tuple{Int, Int}, UnitRange{Int}}} + subblockstructure::Vector{StridedStructure{N}} +end + +function degeneracystructuretype(W::HomSpace) + N = length(codomain(W)) + length(domain(W)) + return DegeneracyStructure{N} +end + +@doc """ + degeneracystructure(W::HomSpace) -> DegeneracyStructure + +Compute the [`DegeneracyStructure`](@ref) for `W`, describing block sizes, data ranges, and +sub-block strides. The result is cached per `HomSpace` instance (keyed by object identity, +since degeneracy dimensions affect the block sizes and offsets). + +See also [`sectorstructure`](@ref), [`blockstructure`](@ref), [`subblockstructure`](@ref). +""" degeneracystructure(::HomSpace) +@cached function degeneracystructure(W::HomSpace)::degeneracystructuretype(W) + codom = codomain(W) + dom = domain(W) + N = length(codom) + length(dom) + + ss = sectorstructure(W) + treelist = ss.fusiontrees + L = length(treelist) + structurevalues = sizehint!(Vector{StridedStructure{N}}(), L) + blockvalues = Vector{Tuple{Tuple{Int, Int}, UnitRange{Int}}}(undef, length(ss.blocksectors)) + + # temporary data structures + splittingstructure = Vector{NTuple{numout(W), Int}}() + + blockoffset = 0 + tree_index = 1 + block_index = 1 + while tree_index <= L + f₁, f₂ = gettokenvalue(treelist, tree_index) + c = f₁.coupled + + # compute subblock structure + # splitting tree data + empty!(splittingstructure) + offset₁ = 0 + for i in tree_index:L + f₁′, f₂′ = gettokenvalue(treelist, i) + f₂′ == f₂ || break + s₁ = f₁′.uncoupled + d₁s = dims(codom, s₁) + d₁ = prod(d₁s) + offset₁ += d₁ + push!(splittingstructure, d₁s) + end + blockdim₁ = offset₁ + n₁ = length(splittingstructure) + strides = (1, blockdim₁) + + # fusion tree data and combine + offset₂ = 0 + n₂ = 0 + for i in tree_index:n₁:L + f₁′, f₂′ = gettokenvalue(treelist, i) + f₂′.coupled == c || break + n₂ += 1 + s₂ = f₂′.uncoupled + d₂s = dims(dom, s₂) + d₂ = prod(d₂s) + offset₁ = 0 + for d₁s in splittingstructure + d₁ = prod(d₁s) + totaloffset = blockoffset + offset₂ * blockdim₁ + offset₁ + subsz = (d₁s..., d₂s...) + @assert !any(==(0), subsz) + substr = _subblock_strides(subsz, (d₁, d₂), strides) + push!(structurevalues, (subsz, substr, totaloffset)) + offset₁ += d₁ + end + offset₂ += d₂ + end + + # compute block structure + blockdim₂ = offset₂ + blockrange = (blockoffset + 1):(blockoffset + blockdim₁ * blockdim₂) + blockvalues[block_index] = ((blockdim₁, blockdim₂), blockrange) + + # reset + blockoffset = last(blockrange) + tree_index += n₁ * n₂ + block_index += 1 + end + @assert length(structurevalues) == L + + return DegeneracyStructure(blockoffset, blockvalues, structurevalues) +end + +function _subblock_strides(subsz, sz, str) + sz_simplify = Strided.StridedViews._simplifydims(sz, str) + strides = Strided.StridedViews._computereshapestrides(subsz, sz_simplify...) + isnothing(strides) && + throw(ArgumentError("unexpected error in computing subblock strides")) + return strides +end + +CacheStyle(::typeof(degeneracystructure), ::HomSpace) = GlobalLRUCache() + +# Public API: combining the two caches +#-------------------------------------- + +""" + blockstructure(W::HomSpace) -> Dictionary + +Return a `Dictionary` mapping each coupled sector `c::I` to a tuple `((d₁, d₂), r)`, +where `d₁` and `d₂` are the block dimensions for the codomain and domain respectively, +and `r` is the corresponding index range in the flat data vector. + +See also [`degeneracystructure`](@ref), [`subblockstructure`](@ref). +""" +function blockstructure(W::HomSpace) + return Dictionary(sectorstructure(W).blocksectors, degeneracystructure(W).blockstructure) +end + +""" + subblockstructure(W::HomSpace) -> Dictionary + +Return a `Dictionary` mapping each fusion tree pair `(f₁, f₂)` to its +[`StridedStructure`](@ref) `(sizes, strides, offset)`. + +See also [`degeneracystructure`](@ref), [`blockstructure`](@ref). +""" +function subblockstructure(W::HomSpace) + return Dictionary(sectorstructure(W).fusiontrees, degeneracystructure(W).subblockstructure) +end + +# Diagonal ranges +#---------------- +function diagonalblockstructure(W::HomSpace) + ((numin(W) == numout(W) == 1) && domain(W) == codomain(W)) || + throw(SpaceMismatch("Diagonal only support on V←V with a single space V")) + structure = SectorDict{sectortype(W), UnitRange{Int}}() # range + offset = 0 + dom = domain(W)[1] + for c in blocksectors(W) + d = dim(dom, c) + structure[c] = offset .+ (1:d) + offset += d + end + return structure +end diff --git a/src/tensors/treetransformers.jl b/src/tensors/treetransformers.jl index 36cd3926d..8577af908 100644 --- a/src/tensors/treetransformers.jl +++ b/src/tensors/treetransformers.jl @@ -16,20 +16,17 @@ end function AbelianTreeTransformer(transform, p, Vdst, Vsrc) t₀ = Base.time() permute(Vsrc, p) == Vdst || throw(SpaceMismatch("Incompatible spaces for permuting.")) - structure_dst = fusionblockstructure(Vdst) - structure_src = fusionblockstructure(Vsrc) - - L = length(structure_src.fusiontreelist) + fts_src = subblockstructure(Vsrc) + fts_dst = subblockstructure(Vdst) + L = length(fts_src) T = sectorscalartype(sectortype(Vdst)) N = numind(Vsrc) data = Vector{Tuple{T, StridedStructure{N}, StridedStructure{N}}}(undef, L) - for i in 1:L - f₁, f₂ = structure_src.fusiontreelist[i] + for (i, ((f₁, f₂), stridestructure_src)) in enumerate(pairs(fts_src)) (f₃, f₄), coeff = only(transform(f₁, f₂)) - j = structure_dst.fusiontreeindices[(f₃, f₄)] - stridestructure_dst = structure_dst.fusiontreestructure[j] - stridestructure_src = structure_src.fusiontreestructure[i] + _, token = gettoken(fts_dst, (f₃, f₄)) + stridestructure_dst = gettokenvalue(fts_dst, token) data[i] = (coeff, stridestructure_dst, stridestructure_src) end @@ -58,18 +55,21 @@ end function GenericTreeTransformer(transform, p, Vdst, Vsrc) t₀ = Base.time() permute(Vsrc, p) == Vdst || throw(SpaceMismatch("Incompatible spaces for permuting.")) - structure_dst = fusionblockstructure(Vdst) - fusionstructure_dst = structure_dst.fusiontreestructure - structure_src = fusionblockstructure(Vsrc) - fusionstructure_src = structure_src.fusiontreestructure + fusionstructure_dst = subblockstructure(Vdst) + fusionstructure_src = subblockstructure(Vsrc) I = sectortype(Vsrc) - uncoupleds_src = map(structure_src.fusiontreelist) do (f₁, f₂) + treelist_src = keys(fusionstructure_src) + treelist_dst = keys(fusionstructure_dst) + + uncoupleds_src = map(1:length(treelist_src)) do i + f₁, f₂ = gettokenvalue(treelist_src, i) return TupleTools.vcat(f₁.uncoupled, dual.(f₂.uncoupled)) end uncoupleds_src_unique = unique(uncoupleds_src) - uncoupleds_dst = map(structure_dst.fusiontreelist) do (f₁, f₂) + uncoupleds_dst = map(1:length(treelist_dst)) do i + f₁, f₂ = gettokenvalue(treelist_dst, i) return TupleTools.vcat(f₁.uncoupled, dual.(f₂.uncoupled)) end @@ -81,12 +81,12 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc) # TODO: this can be multithreaded for (i, uncoupled) in enumerate(uncoupleds_src_unique) inds_src = findall(==(uncoupled), uncoupleds_src) - fusiontrees_outer_src = structure_src.fusiontreelist[inds_src] + fusiontrees_outer_src = map(i -> gettokenvalue(treelist_src, i), inds_src) uncoupled_dst = TupleTools.getindices(uncoupled, (p[1]..., p[2]...)) inds_dst = findall(==(uncoupled_dst), uncoupleds_dst) - fusiontrees_outer_dst = structure_dst.fusiontreelist[inds_dst] + fusiontrees_outer_dst = map(i -> gettokenvalue(treelist_dst, i), inds_dst) matrix = zeros(sectorscalartype(I), length(inds_dst), length(inds_src)) for (row, (f₁, f₂)) in enumerate(fusiontrees_outer_src) @@ -127,9 +127,12 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc) return transformer end -function repack_transformer_structure(structures, ids) - sz = structures[first(ids)][1] - strides_offsets = map(i -> (structures[i][2], structures[i][3]), ids) +function repack_transformer_structure(structures::Dictionary, ids) + sz = gettokenvalue(structures, first(ids))[1] + strides_offsets = map(ids) do i + s = gettokenvalue(structures, i) + return (s[2], s[3]) + end return sz, strides_offsets end diff --git a/test/other/hashed.jl b/test/other/hashed.jl new file mode 100644 index 000000000..ae0c18915 --- /dev/null +++ b/test/other/hashed.jl @@ -0,0 +1,48 @@ +using Test, TestExtras +using TensorKit +using TensorKit: Hashed + +@testset "Hashed" begin + @testset "default constructor" begin + h1 = @constinferred Hashed(42) + h2 = Hashed(42) + @test isequal(h1, h2) + @test hash(h1) == hash(h2) + @test parent(h1) == 42 + end + + @testset "custom hash function" begin + # hash only the length, ignoring contents + lenhash = (v, seed) -> hash(length(v), seed) + h1 = Hashed([1, 2, 3], lenhash) + h2 = Hashed([4, 5, 6], lenhash) + @test hash(h1) == hash(h2) + h3 = Hashed([1, 2], lenhash) + @test hash(h1) != hash(h3) + end + + @testset "custom isequal" begin + # consider vectors equal if they have the same length + lenequal = (a, b) -> length(a) == length(b) + h1 = Hashed([1, 2, 3], Base.hash, lenequal) + h2 = Hashed([4, 5, 6], Base.hash, lenequal) + h3 = Hashed([1, 2], Base.hash, lenequal) + @test isequal(h1, h2) + @test !isequal(h1, h3) + end + + @testset "Dict key usage" begin + d = Dict(Hashed(1) => "one", Hashed(2) => "two") + @test d[Hashed(1)] == "one" + @test d[Hashed(2)] == "two" + @test length(d) == 2 + end + + @testset "Dict with custom hash and isequal" begin + lenhash = (v, seed) -> hash(length(v), seed) + lenequal = (a, b) -> length(a) == length(b) + d = Dict(Hashed([1, 2, 3], lenhash, lenequal) => "length3") + # lookup with different contents but same length should succeed + @test d[Hashed([7, 8, 9], lenhash, lenequal)] == "length3" + end +end