From 48427d051850fed6fa8fce17c9cad872ace975d1 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Sun, 22 Mar 2026 14:10:50 -0400 Subject: [PATCH 01/12] add utility to customize hashing/equality --- src/auxiliary/dicts.jl | 21 ++++++++++++++++++ test/other/hashed.jl | 48 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+) create mode 100644 test/other/hashed.jl diff --git a/src/auxiliary/dicts.jl b/src/auxiliary/dicts.jl index f43838010..bfdf79b88 100644 --- a/src/auxiliary/dicts.jl +++ b/src/auxiliary/dicts.jl @@ -263,3 +263,24 @@ function Base.:(==)(d1::SortedVectorDict, d2::SortedVectorDict) end return true end + +""" + Hashed(value, hashfunction = Base.hash, isequal = Base.isequal) + +Wrapper struct to alter the `hash` and `isequal` implementations of a given value. +This is useful in the contexts of dictionaries, where you either want to customize the hashfunction, +or consider various values as equal with a different notion of equality. +""" +struct Hashed{T, Hash, Eq} + val::T +end + +Hashed(val, hash = Base.hash, eq = Base.isequal) = Hashed{typeof(val), hash, eq}(val) + +Base.parent(h::Hashed) = h.val + +# hash overload +Base.hash(h::Hashed{T, Hash}, seed::UInt) where {T, Hash} = Hash(parent(h), seed) + +# isequal overload +Base.isequal(h1::H, h2::H) where {Eq, H <: Hashed{<:Any, <:Any, Eq}} = Eq(parent(h1), parent(h2)) diff --git a/test/other/hashed.jl b/test/other/hashed.jl new file mode 100644 index 000000000..ae0c18915 --- /dev/null +++ b/test/other/hashed.jl @@ -0,0 +1,48 @@ +using Test, TestExtras +using TensorKit +using TensorKit: Hashed + +@testset "Hashed" begin + @testset "default constructor" begin + h1 = @constinferred Hashed(42) + h2 = Hashed(42) + @test isequal(h1, h2) + @test hash(h1) == hash(h2) + @test parent(h1) == 42 + end + + @testset "custom hash function" begin + # hash only the length, ignoring contents + lenhash = (v, seed) -> hash(length(v), seed) + h1 = Hashed([1, 2, 3], lenhash) + h2 = Hashed([4, 5, 6], lenhash) + @test hash(h1) == hash(h2) + h3 = Hashed([1, 2], lenhash) + @test hash(h1) != hash(h3) + end + + @testset "custom isequal" begin + # consider vectors equal if they have the same length + lenequal = (a, b) -> length(a) == length(b) + h1 = Hashed([1, 2, 3], Base.hash, lenequal) + h2 = Hashed([4, 5, 6], Base.hash, lenequal) + h3 = Hashed([1, 2], Base.hash, lenequal) + @test isequal(h1, h2) + @test !isequal(h1, h3) + end + + @testset "Dict key usage" begin + d = Dict(Hashed(1) => "one", Hashed(2) => "two") + @test d[Hashed(1)] == "one" + @test d[Hashed(2)] == "two" + @test length(d) == 2 + end + + @testset "Dict with custom hash and isequal" begin + lenhash = (v, seed) -> hash(length(v), seed) + lenequal = (a, b) -> length(a) == length(b) + d = Dict(Hashed([1, 2, 3], lenhash, lenequal) => "length3") + # lookup with different contents but same length should succeed + @test d[Hashed([7, 8, 9], lenhash, lenequal)] == "length3" + end +end From d038f96443b71d8c99de3c8f9766e80f64ba491b Mon Sep 17 00:00:00 2001 From: lkdvos Date: Mon, 23 Mar 2026 11:57:36 -0400 Subject: [PATCH 02/12] split out fusionblockstructure and fusiontreelist --- src/spaces/homspace.jl | 179 ++++++++++++++++++++++++++++++----------- 1 file changed, 133 insertions(+), 46 deletions(-) diff --git a/src/spaces/homspace.jl b/src/spaces/homspace.jl index d032a1586..ac9c7076a 100644 --- a/src/spaces/homspace.jl +++ b/src/spaces/homspace.jl @@ -136,7 +136,7 @@ dims(W::HomSpace) = (dims(codomain(W))..., dims(domain(W))...) Return the fusiontrees corresponding to all valid fusion channels of a given `HomSpace`. """ -fusiontrees(W::HomSpace) = fusionblockstructure(W).fusiontreelist +fusiontrees(W::HomSpace) = fusiontreelist(W).fusiontreelist # Operations on HomSpaces # ----------------------- @@ -297,12 +297,45 @@ end # sizes, strides, offset const StridedStructure{N} = Tuple{NTuple{N, Int}, NTuple{N, Int}, Int} +function sectorequal(W₁::HomSpace, W₂::HomSpace) + check_spacetype(W₁, W₂) + (numout(W₁) == numout(W₂) && numin(W₁) == numin(W₂)) || return false + for (w₁, w₂) in zip(codomain(W₁), codomain(W₂)) + isdual(w₁) == isdual(w₂) || return false + isequal(sectors(w₁), sectors(w₂)) || return false + end + for (w₁, w₂) in zip(domain(W₁), domain(W₂)) + isdual(w₁) == isdual(w₂) || return false + isequal(sectors(w₁), sectors(w₂)) || return false + end + return true +end +function sectorhash(W::HomSpace, h::UInt) + for w in codomain(W) + h = hash(sectors(w), hash(isdual(w), h)) + end + for w in domain(W) + h = hash(sectors(w), hash(isdual(w), h)) + end + return h +end + +""" + FusionTreeList{F₁, F₂} + +Charge-only structure encoding a bijection between the fusion tree pairs and a linear index. +This encodes the symmetry structure of a `HomSpace`, shared across all `HomSpace`s with the same `sectors` but varying degeneracies. +""" +struct FusionTreeList{F₁, F₂} + fusiontreelist::Vector{Tuple{F₁, F₂}} + fusiontreeindices::FusionTreeDict{Tuple{F₁, F₂}, Int} +end + struct FusionBlockStructure{I, N, F₁, F₂} totaldim::Int blockstructure::SectorDict{I, Tuple{Tuple{Int, Int}, UnitRange{Int}}} - fusiontreelist::Vector{Tuple{F₁, F₂}} fusiontreestructure::Vector{StridedStructure{N}} - fusiontreeindices::FusionTreeDict{Tuple{F₁, F₂}, Int} + treelist::FusionTreeList{F₁, F₂} end function fusionblockstructuretype(W::HomSpace) @@ -315,73 +348,129 @@ function fusionblockstructuretype(W::HomSpace) return FusionBlockStructure{I, N, F₁, F₂} end +Base.@assume_effects :foldable function fusiontreelisttype(key::Hashed{S}) where {S <: HomSpace} + I = sectortype(S) + F₁ = fusiontreetype(I, numout(S)) + F₂ = fusiontreetype(I, numin(S)) + return FusionTreeList{F₁, F₂} +end + +fusiontreelist(W::HomSpace) = fusiontreelist(Hashed(W, sectorhash, sectorequal)) + +@cached function fusiontreelist(key::Hashed{S})::fusiontreelisttype(key) where {S <: HomSpace} + W = parent(key) + codom, dom = codomain(W), domain(W) + I = sectortype(S) + N₁, N₂ = numout(S), numin(S) + F₁ = fusiontreetype(I, N₁) + F₂ = fusiontreetype(I, N₂) + + trees = Vector{Tuple{F₁, F₂}}() + + for c in blocksectors(W) + codom_start = length(trees) + 1 + n₁ = 0 + for f₂ in fusiontrees(dom, c) + if n₁ == 0 + # First f₂ for this sector: enumerate codomain trees and record how many there are. + for f₁ in fusiontrees(codom, c) + push!(trees, (f₁, f₂)) + end + n₁ = length(trees) - codom_start + 1 + else + # Subsequent f₂s: the codomain trees are already in the list at + # codom_start:codom_start+n₁-1, so read them back instead of recomputing. + for j in codom_start:(codom_start + n₁ - 1) + push!(trees, (trees[j][1], f₂)) + end + end + end + end + + treeindices = sizehint!( + FusionTreeDict{Tuple{F₁, F₂}, Int}(), length(trees) + ) + for (i, f₁₂) in enumerate(trees) + treeindices[f₁₂] = i + end + + return FusionTreeList{F₁, F₂}(trees, treeindices) +end + +CacheStyle(::typeof(fusiontreelist), ::Hashed{S}) where {S <: HomSpace} = GlobalLRUCache() + @cached function fusionblockstructure(W::HomSpace)::fusionblockstructuretype(W) codom = codomain(W) dom = domain(W) - N₁ = length(codom) - N₂ = length(dom) + N = length(codom) + length(dom) I = sectortype(W) - F₁ = fusiontreetype(I, N₁) - F₂ = fusiontreetype(I, N₂) - # output structure - blockstructure = SectorDict{I, Tuple{Tuple{Int, Int}, UnitRange{Int}}}() # size, range - fusiontreelist = Vector{Tuple{F₁, F₂}}() - fusiontreestructure = Vector{Tuple{NTuple{N₁ + N₂, Int}, NTuple{N₁ + N₂, Int}, Int}}() # size, strides, offset + treelist = fusiontreelist(W) + trees = treelist.fusiontreelist + L = length(trees) + fusiontreestructure = sizehint!(Vector{StridedStructure{N}}(), L) + blockstructure = SectorDict{I, Tuple{Tuple{Int, Int}, UnitRange{Int}}}() + # temporary data structures - splittingtrees = Vector{F₁}() - splittingstructure = Vector{Tuple{Int, Int}}() + splittingstructure = Vector{NTuple{numout(W), Int}}() - # main computational routine blockoffset = 0 - for c in blocksectors(W) - empty!(splittingtrees) - empty!(splittingstructure) + tree_index = 1 + while tree_index <= L + f₁, f₂ = trees[tree_index] + c = f₁.coupled + # compute subblock structure + # splitting tree data + empty!(splittingstructure) offset₁ = 0 - for f₁ in fusiontrees(codom, c) - push!(splittingtrees, f₁) - d₁ = dim(codom, f₁.uncoupled) - push!(splittingstructure, (offset₁, d₁)) + for (f₁′, f₂′) in view(trees, tree_index:L) + f₂′ == f₂ || break + s₁ = f₁′.uncoupled + d₁s = dims(codom, s₁) + d₁ = prod(d₁s) offset₁ += d₁ + push!(splittingstructure, d₁s) end blockdim₁ = offset₁ + n₁ = length(splittingstructure) strides = (1, blockdim₁) + # fusion tree data and combine offset₂ = 0 - for f₂ in fusiontrees(dom, c) - s₂ = f₂.uncoupled - d₂ = dim(dom, s₂) - for (f₁, (offset₁, d₁)) in zip(splittingtrees, splittingstructure) - push!(fusiontreelist, (f₁, f₂)) + n₂ = 0 + for (f₁′, f₂′) in view(trees, tree_index:n₁:L) + f₂′.coupled == c || break + n₂ += 1 + s₂ = f₂′.uncoupled + d₂s = dims(dom, s₂) + d₂ = prod(d₂s) + offset₁ = 0 + for d₁s in splittingstructure + d₁ = prod(d₁s) totaloffset = blockoffset + offset₂ * blockdim₁ + offset₁ - subsz = (dims(codom, f₁.uncoupled)..., dims(dom, f₂.uncoupled)...) - @assert !any(isequal(0), subsz) + subsz = (d₁s..., d₂s...) + @assert !any(==(0), subsz) substr = _subblock_strides(subsz, (d₁, d₂), strides) push!(fusiontreestructure, (subsz, substr, totaloffset)) + offset₁ += d₁ end offset₂ += d₂ end + + # compute block structure blockdim₂ = offset₂ - blocksize = (blockdim₁, blockdim₂) - blocklength = blockdim₁ * blockdim₂ - blockrange = (blockoffset + 1):(blockoffset + blocklength) + blockrange = (blockoffset + 1):(blockoffset + blockdim₁ * blockdim₂) + blockstructure[c] = ((blockdim₁, blockdim₂), blockrange) + + # reset blockoffset = last(blockrange) - blockstructure[c] = (blocksize, blockrange) + tree_index += n₁ * n₂ end + @assert length(fusiontreestructure) == L - fusiontreeindices = sizehint!( - FusionTreeDict{Tuple{F₁, F₂}, Int}(), length(fusiontreelist) - ) - for (i, f₁₂) in enumerate(fusiontreelist) - fusiontreeindices[f₁₂] = i - end - totaldim = blockoffset - structure = FusionBlockStructure( - totaldim, blockstructure, fusiontreelist, fusiontreestructure, fusiontreeindices - ) - return structure + return FusionBlockStructure(blockoffset, blockstructure, fusiontreestructure, treelist) end function _subblock_strides(subsz, sz, str) @@ -392,9 +481,7 @@ function _subblock_strides(subsz, sz, str) return strides end -function CacheStyle(::typeof(fusionblockstructure), W::HomSpace) - return GlobalLRUCache() -end +CacheStyle(::typeof(fusionblockstructure), W::HomSpace) = GlobalLRUCache() # Diagonal ranges #---------------- From 42d6a0c2ab58afd51cafad582c35d9fafd070949 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Mon, 23 Mar 2026 11:57:56 -0400 Subject: [PATCH 03/12] use new split throughout the code --- src/tensors/abstracttensor.jl | 2 +- src/tensors/braidingtensor.jl | 2 +- src/tensors/tensor.jl | 4 ++-- src/tensors/treetransformers.jl | 14 +++++++------- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/tensors/abstracttensor.jl b/src/tensors/abstracttensor.jl index 2d7239460..ecb6c11c8 100644 --- a/src/tensors/abstracttensor.jl +++ b/src/tensors/abstracttensor.jl @@ -392,7 +392,7 @@ hasblock(t::AbstractTensorMap, c::Sector) = c ∈ blocksectors(t) Return an iterator over all splitting - fusion tree pairs of a tensor. """ -fusiontrees(t::AbstractTensorMap) = fusionblockstructure(t).fusiontreelist +fusiontrees(t::AbstractTensorMap) = fusiontrees(space(t)) fusiontreetype(t::AbstractTensorMap) = fusiontreetype(typeof(t)) function fusiontreetype(::Type{T}) where {T <: AbstractTensorMap} diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index 0070bc2d4..d50342434 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -147,7 +147,7 @@ function block(b::BraidingTensor, s::Sector) base_offset = first(structure.blockstructure[s][2]) - 1 for ((f1, f2), (sz, str, off)) in - zip(structure.fusiontreelist, structure.fusiontreestructure) + zip(structure.treelist.fusiontreelist, structure.fusiontreestructure) if (f1.uncoupled != reverse(f2.uncoupled)) || !(f1.coupled == f2.coupled == s) continue end diff --git a/src/tensors/tensor.jl b/src/tensors/tensor.jl index 615c67775..4a4641c1c 100644 --- a/src/tensors/tensor.jl +++ b/src/tensors/tensor.jl @@ -489,10 +489,10 @@ function subblock( ) where {T, S, N₁, N₂, I <: Sector} structure = fusionblockstructure(t) @boundscheck begin - haskey(structure.fusiontreeindices, (f₁, f₂)) || throw(SectorMismatch()) + haskey(structure.treelist.fusiontreeindices, (f₁, f₂)) || throw(SectorMismatch()) end @inbounds begin - i = structure.fusiontreeindices[(f₁, f₂)] + i = structure.treelist.fusiontreeindices[(f₁, f₂)] sz, str, offset = structure.fusiontreestructure[i] return StridedView(t.data, sz, str, offset) end diff --git a/src/tensors/treetransformers.jl b/src/tensors/treetransformers.jl index 36cd3926d..e5e919e3d 100644 --- a/src/tensors/treetransformers.jl +++ b/src/tensors/treetransformers.jl @@ -19,15 +19,15 @@ function AbelianTreeTransformer(transform, p, Vdst, Vsrc) structure_dst = fusionblockstructure(Vdst) structure_src = fusionblockstructure(Vsrc) - L = length(structure_src.fusiontreelist) + L = length(structure_src.treelist.fusiontreelist) T = sectorscalartype(sectortype(Vdst)) N = numind(Vsrc) data = Vector{Tuple{T, StridedStructure{N}, StridedStructure{N}}}(undef, L) for i in 1:L - f₁, f₂ = structure_src.fusiontreelist[i] + f₁, f₂ = structure_src.treelist.fusiontreelist[i] (f₃, f₄), coeff = only(transform(f₁, f₂)) - j = structure_dst.fusiontreeindices[(f₃, f₄)] + j = structure_dst.treelist.fusiontreeindices[(f₃, f₄)] stridestructure_dst = structure_dst.fusiontreestructure[j] stridestructure_src = structure_src.fusiontreestructure[i] data[i] = (coeff, stridestructure_dst, stridestructure_src) @@ -64,12 +64,12 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc) fusionstructure_src = structure_src.fusiontreestructure I = sectortype(Vsrc) - uncoupleds_src = map(structure_src.fusiontreelist) do (f₁, f₂) + uncoupleds_src = map(structure_src.treelist.fusiontreelist) do (f₁, f₂) return TupleTools.vcat(f₁.uncoupled, dual.(f₂.uncoupled)) end uncoupleds_src_unique = unique(uncoupleds_src) - uncoupleds_dst = map(structure_dst.fusiontreelist) do (f₁, f₂) + uncoupleds_dst = map(structure_dst.treelist.fusiontreelist) do (f₁, f₂) return TupleTools.vcat(f₁.uncoupled, dual.(f₂.uncoupled)) end @@ -81,12 +81,12 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc) # TODO: this can be multithreaded for (i, uncoupled) in enumerate(uncoupleds_src_unique) inds_src = findall(==(uncoupled), uncoupleds_src) - fusiontrees_outer_src = structure_src.fusiontreelist[inds_src] + fusiontrees_outer_src = structure_src.treelist.fusiontreelist[inds_src] uncoupled_dst = TupleTools.getindices(uncoupled, (p[1]..., p[2]...)) inds_dst = findall(==(uncoupled_dst), uncoupleds_dst) - fusiontrees_outer_dst = structure_dst.fusiontreelist[inds_dst] + fusiontrees_outer_dst = structure_dst.treelist.fusiontreelist[inds_dst] matrix = zeros(sectorscalartype(I), length(inds_dst), length(inds_src)) for (row, (f₁, f₂)) in enumerate(fusiontrees_outer_src) From 180ae2b09cb9ae9e4848fcd15450d938ed2b8f28 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Mon, 23 Mar 2026 12:42:33 -0400 Subject: [PATCH 04/12] reorganize tensor structure computations --- src/TensorKit.jl | 1 + src/spaces/homspace.jl | 229 +-------------------------------- src/tensors/tensorstructure.jl | 206 +++++++++++++++++++++++++++++ 3 files changed, 214 insertions(+), 222 deletions(-) create mode 100644 src/tensors/tensorstructure.jl diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 688e7363c..aafd64562 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -218,6 +218,7 @@ end # Definitions and methods for tensors #------------------------------------- # general definitions +include("tensors/tensorstructure.jl") include("tensors/abstracttensor.jl") include("tensors/backends.jl") include("tensors/blockiterator.jl") diff --git a/src/spaces/homspace.jl b/src/spaces/homspace.jl index ac9c7076a..1b4b9fb70 100644 --- a/src/spaces/homspace.jl +++ b/src/spaces/homspace.jl @@ -41,8 +41,7 @@ spacetype(::Type{<:HomSpace{S}}) where {S} = S const TensorSpace{S <: ElementarySpace} = Union{S, ProductSpace{S}} const TensorMapSpace{S <: ElementarySpace, N₁, N₂} = HomSpace{ - S, ProductSpace{S, N₁}, - ProductSpace{S, N₂}, + S, ProductSpace{S, N₁}, ProductSpace{S, N₂}, } numout(::Type{TensorMapSpace{S, N₁, N₂}}) where {S, N₁, N₂} = N₁ @@ -62,17 +61,12 @@ end →(dom::VectorSpace, codom::VectorSpace) = ←(codom, dom) function Base.show(io::IO, W::HomSpace) - if length(W.codomain) == 1 - print(io, W.codomain[1]) - else - print(io, W.codomain) - end - print(io, " ← ") - return if length(W.domain) == 1 - print(io, W.domain[1]) - else - print(io, W.domain) - end + return print( + io, + numout(W) == 1 ? codomain(W)[1] : codomain(W), + " ← ", + numin(W) == 1 ? domain(W)[1] : domain(W) + ) end """ @@ -290,212 +284,3 @@ function removeunit(P::HomSpace, ::Val{i}) where {i} return codomain(P) ← removeunit(domain(P), Val(i - numout(P))) end end - -# Block and fusion tree ranges: structure information for building tensors -#-------------------------------------------------------------------------- - -# sizes, strides, offset -const StridedStructure{N} = Tuple{NTuple{N, Int}, NTuple{N, Int}, Int} - -function sectorequal(W₁::HomSpace, W₂::HomSpace) - check_spacetype(W₁, W₂) - (numout(W₁) == numout(W₂) && numin(W₁) == numin(W₂)) || return false - for (w₁, w₂) in zip(codomain(W₁), codomain(W₂)) - isdual(w₁) == isdual(w₂) || return false - isequal(sectors(w₁), sectors(w₂)) || return false - end - for (w₁, w₂) in zip(domain(W₁), domain(W₂)) - isdual(w₁) == isdual(w₂) || return false - isequal(sectors(w₁), sectors(w₂)) || return false - end - return true -end -function sectorhash(W::HomSpace, h::UInt) - for w in codomain(W) - h = hash(sectors(w), hash(isdual(w), h)) - end - for w in domain(W) - h = hash(sectors(w), hash(isdual(w), h)) - end - return h -end - -""" - FusionTreeList{F₁, F₂} - -Charge-only structure encoding a bijection between the fusion tree pairs and a linear index. -This encodes the symmetry structure of a `HomSpace`, shared across all `HomSpace`s with the same `sectors` but varying degeneracies. -""" -struct FusionTreeList{F₁, F₂} - fusiontreelist::Vector{Tuple{F₁, F₂}} - fusiontreeindices::FusionTreeDict{Tuple{F₁, F₂}, Int} -end - -struct FusionBlockStructure{I, N, F₁, F₂} - totaldim::Int - blockstructure::SectorDict{I, Tuple{Tuple{Int, Int}, UnitRange{Int}}} - fusiontreestructure::Vector{StridedStructure{N}} - treelist::FusionTreeList{F₁, F₂} -end - -function fusionblockstructuretype(W::HomSpace) - N₁ = length(codomain(W)) - N₂ = length(domain(W)) - N = N₁ + N₂ - I = sectortype(W) - F₁ = fusiontreetype(I, N₁) - F₂ = fusiontreetype(I, N₂) - return FusionBlockStructure{I, N, F₁, F₂} -end - -Base.@assume_effects :foldable function fusiontreelisttype(key::Hashed{S}) where {S <: HomSpace} - I = sectortype(S) - F₁ = fusiontreetype(I, numout(S)) - F₂ = fusiontreetype(I, numin(S)) - return FusionTreeList{F₁, F₂} -end - -fusiontreelist(W::HomSpace) = fusiontreelist(Hashed(W, sectorhash, sectorequal)) - -@cached function fusiontreelist(key::Hashed{S})::fusiontreelisttype(key) where {S <: HomSpace} - W = parent(key) - codom, dom = codomain(W), domain(W) - I = sectortype(S) - N₁, N₂ = numout(S), numin(S) - F₁ = fusiontreetype(I, N₁) - F₂ = fusiontreetype(I, N₂) - - trees = Vector{Tuple{F₁, F₂}}() - - for c in blocksectors(W) - codom_start = length(trees) + 1 - n₁ = 0 - for f₂ in fusiontrees(dom, c) - if n₁ == 0 - # First f₂ for this sector: enumerate codomain trees and record how many there are. - for f₁ in fusiontrees(codom, c) - push!(trees, (f₁, f₂)) - end - n₁ = length(trees) - codom_start + 1 - else - # Subsequent f₂s: the codomain trees are already in the list at - # codom_start:codom_start+n₁-1, so read them back instead of recomputing. - for j in codom_start:(codom_start + n₁ - 1) - push!(trees, (trees[j][1], f₂)) - end - end - end - end - - treeindices = sizehint!( - FusionTreeDict{Tuple{F₁, F₂}, Int}(), length(trees) - ) - for (i, f₁₂) in enumerate(trees) - treeindices[f₁₂] = i - end - - return FusionTreeList{F₁, F₂}(trees, treeindices) -end - -CacheStyle(::typeof(fusiontreelist), ::Hashed{S}) where {S <: HomSpace} = GlobalLRUCache() - -@cached function fusionblockstructure(W::HomSpace)::fusionblockstructuretype(W) - codom = codomain(W) - dom = domain(W) - N = length(codom) + length(dom) - I = sectortype(W) - - treelist = fusiontreelist(W) - trees = treelist.fusiontreelist - L = length(trees) - fusiontreestructure = sizehint!(Vector{StridedStructure{N}}(), L) - blockstructure = SectorDict{I, Tuple{Tuple{Int, Int}, UnitRange{Int}}}() - - - # temporary data structures - splittingstructure = Vector{NTuple{numout(W), Int}}() - - blockoffset = 0 - tree_index = 1 - while tree_index <= L - f₁, f₂ = trees[tree_index] - c = f₁.coupled - - # compute subblock structure - # splitting tree data - empty!(splittingstructure) - offset₁ = 0 - for (f₁′, f₂′) in view(trees, tree_index:L) - f₂′ == f₂ || break - s₁ = f₁′.uncoupled - d₁s = dims(codom, s₁) - d₁ = prod(d₁s) - offset₁ += d₁ - push!(splittingstructure, d₁s) - end - blockdim₁ = offset₁ - n₁ = length(splittingstructure) - strides = (1, blockdim₁) - - # fusion tree data and combine - offset₂ = 0 - n₂ = 0 - for (f₁′, f₂′) in view(trees, tree_index:n₁:L) - f₂′.coupled == c || break - n₂ += 1 - s₂ = f₂′.uncoupled - d₂s = dims(dom, s₂) - d₂ = prod(d₂s) - offset₁ = 0 - for d₁s in splittingstructure - d₁ = prod(d₁s) - totaloffset = blockoffset + offset₂ * blockdim₁ + offset₁ - subsz = (d₁s..., d₂s...) - @assert !any(==(0), subsz) - substr = _subblock_strides(subsz, (d₁, d₂), strides) - push!(fusiontreestructure, (subsz, substr, totaloffset)) - offset₁ += d₁ - end - offset₂ += d₂ - end - - # compute block structure - blockdim₂ = offset₂ - blockrange = (blockoffset + 1):(blockoffset + blockdim₁ * blockdim₂) - blockstructure[c] = ((blockdim₁, blockdim₂), blockrange) - - # reset - blockoffset = last(blockrange) - tree_index += n₁ * n₂ - end - @assert length(fusiontreestructure) == L - - return FusionBlockStructure(blockoffset, blockstructure, fusiontreestructure, treelist) -end - -function _subblock_strides(subsz, sz, str) - sz_simplify = Strided.StridedViews._simplifydims(sz, str) - strides = Strided.StridedViews._computereshapestrides(subsz, sz_simplify...) - isnothing(strides) && - throw(ArgumentError("unexpected error in computing subblock strides")) - return strides -end - -CacheStyle(::typeof(fusionblockstructure), W::HomSpace) = GlobalLRUCache() - -# Diagonal ranges -#---------------- -# TODO: is this something we want to cache? -function diagonalblockstructure(W::HomSpace) - ((numin(W) == numout(W) == 1) && domain(W) == codomain(W)) || - throw(SpaceMismatch("Diagonal only support on V←V with a single space V")) - structure = SectorDict{sectortype(W), UnitRange{Int}}() # range - offset = 0 - dom = domain(W)[1] - for c in blocksectors(W) - d = dim(dom, c) - structure[c] = offset .+ (1:d) - offset += d - end - return structure -end diff --git a/src/tensors/tensorstructure.jl b/src/tensors/tensorstructure.jl new file mode 100644 index 000000000..0063aa012 --- /dev/null +++ b/src/tensors/tensorstructure.jl @@ -0,0 +1,206 @@ +# Block and fusion tree ranges: structure information for building tensors +#-------------------------------------------------------------------------- + +# sizes, strides, offset +const StridedStructure{N} = Tuple{NTuple{N, Int}, NTuple{N, Int}, Int} + +function sectorequal(W₁::HomSpace, W₂::HomSpace) + check_spacetype(W₁, W₂) + (numout(W₁) == numout(W₂) && numin(W₁) == numin(W₂)) || return false + for (w₁, w₂) in zip(codomain(W₁), codomain(W₂)) + isdual(w₁) == isdual(w₂) || return false + isequal(sectors(w₁), sectors(w₂)) || return false + end + for (w₁, w₂) in zip(domain(W₁), domain(W₂)) + isdual(w₁) == isdual(w₂) || return false + isequal(sectors(w₁), sectors(w₂)) || return false + end + return true +end +function sectorhash(W::HomSpace, h::UInt) + for w in codomain(W) + h = hash(sectors(w), hash(isdual(w), h)) + end + for w in domain(W) + h = hash(sectors(w), hash(isdual(w), h)) + end + return h +end + +""" + FusionTreeList{F₁, F₂} + +Charge-only structure encoding a bijection between the fusion tree pairs and a linear index. +This encodes the symmetry structure of a `HomSpace`, shared across all `HomSpace`s with the same `sectors` but varying degeneracies. +""" +struct FusionTreeList{F₁, F₂} + fusiontreelist::Vector{Tuple{F₁, F₂}} + fusiontreeindices::FusionTreeDict{Tuple{F₁, F₂}, Int} +end + +struct FusionBlockStructure{I, N, F₁, F₂} + totaldim::Int + blockstructure::SectorDict{I, Tuple{Tuple{Int, Int}, UnitRange{Int}}} + fusiontreestructure::Vector{StridedStructure{N}} + treelist::FusionTreeList{F₁, F₂} +end + +function fusionblockstructuretype(W::HomSpace) + N₁ = length(codomain(W)) + N₂ = length(domain(W)) + N = N₁ + N₂ + I = sectortype(W) + F₁ = fusiontreetype(I, N₁) + F₂ = fusiontreetype(I, N₂) + return FusionBlockStructure{I, N, F₁, F₂} +end + +Base.@assume_effects :foldable function fusiontreelisttype(key::Hashed{S}) where {S <: HomSpace} + I = sectortype(S) + F₁ = fusiontreetype(I, numout(S)) + F₂ = fusiontreetype(I, numin(S)) + return FusionTreeList{F₁, F₂} +end + +fusiontreelist(W::HomSpace) = fusiontreelist(Hashed(W, sectorhash, sectorequal)) + +@cached function fusiontreelist(key::Hashed{S})::fusiontreelisttype(key) where {S <: HomSpace} + W = parent(key) + codom, dom = codomain(W), domain(W) + I = sectortype(S) + N₁, N₂ = numout(S), numin(S) + F₁ = fusiontreetype(I, N₁) + F₂ = fusiontreetype(I, N₂) + + trees = Vector{Tuple{F₁, F₂}}() + + for c in blocksectors(W) + codom_start = length(trees) + 1 + n₁ = 0 + for f₂ in fusiontrees(dom, c) + if n₁ == 0 + # First f₂ for this sector: enumerate codomain trees and record how many there are. + for f₁ in fusiontrees(codom, c) + push!(trees, (f₁, f₂)) + end + n₁ = length(trees) - codom_start + 1 + else + # Subsequent f₂s: the codomain trees are already in the list at + # codom_start:codom_start+n₁-1, so read them back instead of recomputing. + for j in codom_start:(codom_start + n₁ - 1) + push!(trees, (trees[j][1], f₂)) + end + end + end + end + + treeindices = sizehint!(FusionTreeDict{Tuple{F₁, F₂}, Int}(), length(trees)) + for (i, f₁₂) in enumerate(trees) + treeindices[f₁₂] = i + end + + return FusionTreeList{F₁, F₂}(trees, treeindices) +end + +CacheStyle(::typeof(fusiontreelist), ::Hashed{S}) where {S <: HomSpace} = GlobalLRUCache() + +@cached function fusionblockstructure(W::HomSpace)::fusionblockstructuretype(W) + codom = codomain(W) + dom = domain(W) + N = length(codom) + length(dom) + I = sectortype(W) + + treelist = fusiontreelist(W) + trees = treelist.fusiontreelist + L = length(trees) + fusiontreestructure = sizehint!(Vector{StridedStructure{N}}(), L) + blockstructure = SectorDict{I, Tuple{Tuple{Int, Int}, UnitRange{Int}}}() + + + # temporary data structures + splittingstructure = Vector{NTuple{numout(W), Int}}() + + blockoffset = 0 + tree_index = 1 + while tree_index <= L + f₁, f₂ = trees[tree_index] + c = f₁.coupled + + # compute subblock structure + # splitting tree data + empty!(splittingstructure) + offset₁ = 0 + for (f₁′, f₂′) in view(trees, tree_index:L) + f₂′ == f₂ || break + s₁ = f₁′.uncoupled + d₁s = dims(codom, s₁) + d₁ = prod(d₁s) + offset₁ += d₁ + push!(splittingstructure, d₁s) + end + blockdim₁ = offset₁ + n₁ = length(splittingstructure) + strides = (1, blockdim₁) + + # fusion tree data and combine + offset₂ = 0 + n₂ = 0 + for (f₁′, f₂′) in view(trees, tree_index:n₁:L) + f₂′.coupled == c || break + n₂ += 1 + s₂ = f₂′.uncoupled + d₂s = dims(dom, s₂) + d₂ = prod(d₂s) + offset₁ = 0 + for d₁s in splittingstructure + d₁ = prod(d₁s) + totaloffset = blockoffset + offset₂ * blockdim₁ + offset₁ + subsz = (d₁s..., d₂s...) + @assert !any(==(0), subsz) + substr = _subblock_strides(subsz, (d₁, d₂), strides) + push!(fusiontreestructure, (subsz, substr, totaloffset)) + offset₁ += d₁ + end + offset₂ += d₂ + end + + # compute block structure + blockdim₂ = offset₂ + blockrange = (blockoffset + 1):(blockoffset + blockdim₁ * blockdim₂) + blockstructure[c] = ((blockdim₁, blockdim₂), blockrange) + + # reset + blockoffset = last(blockrange) + tree_index += n₁ * n₂ + end + @assert length(fusiontreestructure) == L + + return FusionBlockStructure(blockoffset, blockstructure, fusiontreestructure, treelist) +end + +function _subblock_strides(subsz, sz, str) + sz_simplify = Strided.StridedViews._simplifydims(sz, str) + strides = Strided.StridedViews._computereshapestrides(subsz, sz_simplify...) + isnothing(strides) && + throw(ArgumentError("unexpected error in computing subblock strides")) + return strides +end + +CacheStyle(::typeof(fusionblockstructure), W::HomSpace) = GlobalLRUCache() + +# Diagonal ranges +#---------------- +# TODO: is this something we want to cache? +function diagonalblockstructure(W::HomSpace) + ((numin(W) == numout(W) == 1) && domain(W) == codomain(W)) || + throw(SpaceMismatch("Diagonal only support on V←V with a single space V")) + structure = SectorDict{sectortype(W), UnitRange{Int}}() # range + offset = 0 + dom = domain(W)[1] + for c in blocksectors(W) + d = dim(dom, c) + structure[c] = offset .+ (1:d) + offset += d + end + return structure +end From 406db6ea1772f33ac2748127355118c1bdf7306a Mon Sep 17 00:00:00 2001 From: lkdvos Date: Mon, 23 Mar 2026 13:40:46 -0400 Subject: [PATCH 05/12] update docstrings --- src/tensors/tensorstructure.jl | 43 +++++++++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/src/tensors/tensorstructure.jl b/src/tensors/tensorstructure.jl index 0063aa012..83105b9b7 100644 --- a/src/tensors/tensorstructure.jl +++ b/src/tensors/tensorstructure.jl @@ -32,12 +32,33 @@ end Charge-only structure encoding a bijection between the fusion tree pairs and a linear index. This encodes the symmetry structure of a `HomSpace`, shared across all `HomSpace`s with the same `sectors` but varying degeneracies. + +See also [`fusiontreelist`](@ref). """ struct FusionTreeList{F₁, F₂} fusiontreelist::Vector{Tuple{F₁, F₂}} fusiontreeindices::FusionTreeDict{Tuple{F₁, F₂}, Int} end +""" + FusionBlockStructure{I, N, F₁, F₂} + +Full block structure of a `HomSpace`, encoding how a tensor's flat data vector is +partitioned into symmetry blocks and sub-blocks indexed by fusion tree pairs. + +## Fields +- `totaldim`: total number of elements in the flat data vector. +- `blockstructure`: maps each coupled sector `c::I` to a tuple `((d₁, d₂), r)`, where + `d₁` and `d₂` are the block dimensions for the codomain and domain respectively, and + `r` is the corresponding index range in the flat data vector. +- `fusiontreestructure`: for each fusion tree pair `(f₁, f₂)` (in the same order as + `treelist`), a [`StridedStructure`](@ref) `(sizes, strides, offset)` describing the + sub-block as a strided view into the flat data vector. +- `treelist`: the underlying [`FusionTreeList`](@ref) providing the bijection between + fusion tree pairs and linear indices. + +See also [`fusionblockstructure`](@ref), [`FusionTreeList`](@ref). +""" struct FusionBlockStructure{I, N, F₁, F₂} totaldim::Int blockstructure::SectorDict{I, Tuple{Tuple{Int, Int}, UnitRange{Int}}} @@ -62,6 +83,16 @@ Base.@assume_effects :foldable function fusiontreelisttype(key::Hashed{S}) where return FusionTreeList{F₁, F₂} end +""" + fusiontreelist(W::HomSpace) -> FusionTreeList + +Return the [`FusionTreeList`](@ref) for `W`, enumerating all valid fusion tree pairs +`(f₁, f₂)` and providing a bijection to linear indices. The result is cached based on +the sector structure of `W` (ignoring degeneracy dimensions), so `HomSpace`s that share +the same sectors, dualities, and index count will reuse the same object. + +See also [`FusionTreeList`](@ref), [`fusionblockstructure`](@ref). +""" fusiontreelist(W::HomSpace) = fusiontreelist(Hashed(W, sectorhash, sectorequal)) @cached function fusiontreelist(key::Hashed{S})::fusiontreelisttype(key) where {S <: HomSpace} @@ -104,6 +135,17 @@ end CacheStyle(::typeof(fusiontreelist), ::Hashed{S}) where {S <: HomSpace} = GlobalLRUCache() +@doc """ + fusionblockstructure(W::HomSpace) -> FusionBlockStructure + +Compute the full [`FusionBlockStructure`](@ref) for `W`, describing how a tensor's flat +data vector is laid out in terms of symmetry blocks and fusion-tree sub-blocks. The result +is cached per `HomSpace` instance (keyed by object identity, not sector structure, since +degeneracy dimensions affect the block sizes and offsets). + +See also [`FusionBlockStructure`](@ref), [`fusiontreelist`](@ref). +""" fusionblockstructure(::HomSpace) + @cached function fusionblockstructure(W::HomSpace)::fusionblockstructuretype(W) codom = codomain(W) dom = domain(W) @@ -190,7 +232,6 @@ CacheStyle(::typeof(fusionblockstructure), W::HomSpace) = GlobalLRUCache() # Diagonal ranges #---------------- -# TODO: is this something we want to cache? function diagonalblockstructure(W::HomSpace) ((numin(W) == numout(W) == 1) && domain(W) == codomain(W)) || throw(SpaceMismatch("Diagonal only support on V←V with a single space V")) From f4b2a5644cb9cf5da8612cc6a2968b83ede34808 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Mon, 23 Mar 2026 14:14:01 -0400 Subject: [PATCH 06/12] switch to Dictionaries.Indices --- Project.toml | 4 ++- src/TensorKit.jl | 1 + src/spaces/homspace.jl | 2 +- src/tensors/braidingtensor.jl | 2 +- src/tensors/tensor.jl | 6 ++-- src/tensors/tensorstructure.jl | 54 ++++++++++++--------------------- src/tensors/treetransformers.jl | 16 +++++----- 7 files changed, 36 insertions(+), 49 deletions(-) diff --git a/Project.toml b/Project.toml index 7e82876c5..a3bf2ea84 100644 --- a/Project.toml +++ b/Project.toml @@ -1,9 +1,10 @@ name = "TensorKit" uuid = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec" -authors = ["Jutho Haegeman, Lukas Devos"] version = "0.16.3" +authors = ["Jutho Haegeman, Lukas Devos"] [deps] +Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" @@ -41,6 +42,7 @@ CUDA = "5.9" ChainRulesCore = "1" ChainRulesTestUtils = "1" Combinatorics = "1" +Dictionaries = "0.4" FiniteDifferences = "0.12" GPUArrays = "11.3.1" JET = "0.9, 0.10, 0.11" diff --git a/src/TensorKit.jl b/src/TensorKit.jl index aafd64562..df57da83b 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -118,6 +118,7 @@ const TO = TensorOperations using MatrixAlgebraKit +using Dictionaries: Dictionaries, Indices, gettoken, gettokenvalue using LRUCache using OhMyThreads using ScopedValues diff --git a/src/spaces/homspace.jl b/src/spaces/homspace.jl index 1b4b9fb70..1d1410cce 100644 --- a/src/spaces/homspace.jl +++ b/src/spaces/homspace.jl @@ -130,7 +130,7 @@ dims(W::HomSpace) = (dims(codomain(W))..., dims(domain(W))...) Return the fusiontrees corresponding to all valid fusion channels of a given `HomSpace`. """ -fusiontrees(W::HomSpace) = fusiontreelist(W).fusiontreelist +fusiontrees(W::HomSpace) = fusiontreelist(W) # Operations on HomSpaces # ----------------------- diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index d50342434..d028bf271 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -147,7 +147,7 @@ function block(b::BraidingTensor, s::Sector) base_offset = first(structure.blockstructure[s][2]) - 1 for ((f1, f2), (sz, str, off)) in - zip(structure.treelist.fusiontreelist, structure.fusiontreestructure) + zip(structure.treelist, structure.fusiontreestructure) if (f1.uncoupled != reverse(f2.uncoupled)) || !(f1.coupled == f2.coupled == s) continue end diff --git a/src/tensors/tensor.jl b/src/tensors/tensor.jl index 4a4641c1c..d8699ebf6 100644 --- a/src/tensors/tensor.jl +++ b/src/tensors/tensor.jl @@ -488,11 +488,9 @@ function subblock( t::TensorMap{T, S, N₁, N₂}, (f₁, f₂)::Tuple{FusionTree{I, N₁}, FusionTree{I, N₂}} ) where {T, S, N₁, N₂, I <: Sector} structure = fusionblockstructure(t) - @boundscheck begin - haskey(structure.treelist.fusiontreeindices, (f₁, f₂)) || throw(SectorMismatch()) - end + found, i = gettoken(structure.treelist, (f₁, f₂)) + @boundscheck found || throw(SectorMismatch(lazy"fusion tree pair ($(f₁, f₂)) is not present")) @inbounds begin - i = structure.treelist.fusiontreeindices[(f₁, f₂)] sz, str, offset = structure.fusiontreestructure[i] return StridedView(t.data, sz, str, offset) end diff --git a/src/tensors/tensorstructure.jl b/src/tensors/tensorstructure.jl index 83105b9b7..d6c5bd882 100644 --- a/src/tensors/tensorstructure.jl +++ b/src/tensors/tensorstructure.jl @@ -27,19 +27,6 @@ function sectorhash(W::HomSpace, h::UInt) return h end -""" - FusionTreeList{F₁, F₂} - -Charge-only structure encoding a bijection between the fusion tree pairs and a linear index. -This encodes the symmetry structure of a `HomSpace`, shared across all `HomSpace`s with the same `sectors` but varying degeneracies. - -See also [`fusiontreelist`](@ref). -""" -struct FusionTreeList{F₁, F₂} - fusiontreelist::Vector{Tuple{F₁, F₂}} - fusiontreeindices::FusionTreeDict{Tuple{F₁, F₂}, Int} -end - """ FusionBlockStructure{I, N, F₁, F₂} @@ -54,16 +41,16 @@ partitioned into symmetry blocks and sub-blocks indexed by fusion tree pairs. - `fusiontreestructure`: for each fusion tree pair `(f₁, f₂)` (in the same order as `treelist`), a [`StridedStructure`](@ref) `(sizes, strides, offset)` describing the sub-block as a strided view into the flat data vector. -- `treelist`: the underlying [`FusionTreeList`](@ref) providing the bijection between - fusion tree pairs and linear indices. +- `treelist`: an `Indices{Tuple{F₁,F₂}}` providing a bijection between fusion tree pairs + and sequential integer positions. -See also [`fusionblockstructure`](@ref), [`FusionTreeList`](@ref). +See also [`fusionblockstructure`](@ref), [`fusiontreelist`](@ref). """ struct FusionBlockStructure{I, N, F₁, F₂} totaldim::Int blockstructure::SectorDict{I, Tuple{Tuple{Int, Int}, UnitRange{Int}}} fusiontreestructure::Vector{StridedStructure{N}} - treelist::FusionTreeList{F₁, F₂} + treelist::Indices{Tuple{F₁, F₂}} end function fusionblockstructuretype(W::HomSpace) @@ -80,18 +67,19 @@ Base.@assume_effects :foldable function fusiontreelisttype(key::Hashed{S}) where I = sectortype(S) F₁ = fusiontreetype(I, numout(S)) F₂ = fusiontreetype(I, numin(S)) - return FusionTreeList{F₁, F₂} + return Indices{Tuple{F₁, F₂}} end """ - fusiontreelist(W::HomSpace) -> FusionTreeList + fusiontreelist(W::HomSpace) -> Indices{Tuple{F₁,F₂}} -Return the [`FusionTreeList`](@ref) for `W`, enumerating all valid fusion tree pairs -`(f₁, f₂)` and providing a bijection to linear indices. The result is cached based on -the sector structure of `W` (ignoring degeneracy dimensions), so `HomSpace`s that share -the same sectors, dualities, and index count will reuse the same object. +Return an `Indices` of all valid fusion tree pairs `(f₁, f₂)` for `W`, providing a +bijection to sequential integer positions via `gettoken`/`gettokenvalue`. The result is +cached based on the sector structure of `W` (ignoring degeneracy dimensions), so +`HomSpace`s that share the same sectors, dualities, and index count will reuse the same +object. -See also [`FusionTreeList`](@ref), [`fusionblockstructure`](@ref). +See also [`fusionblockstructure`](@ref). """ fusiontreelist(W::HomSpace) = fusiontreelist(Hashed(W, sectorhash, sectorequal)) @@ -125,12 +113,7 @@ fusiontreelist(W::HomSpace) = fusiontreelist(Hashed(W, sectorhash, sectorequal)) end end - treeindices = sizehint!(FusionTreeDict{Tuple{F₁, F₂}, Int}(), length(trees)) - for (i, f₁₂) in enumerate(trees) - treeindices[f₁₂] = i - end - - return FusionTreeList{F₁, F₂}(trees, treeindices) + return Indices(trees) end CacheStyle(::typeof(fusiontreelist), ::Hashed{S}) where {S <: HomSpace} = GlobalLRUCache() @@ -153,8 +136,7 @@ See also [`FusionBlockStructure`](@ref), [`fusiontreelist`](@ref). I = sectortype(W) treelist = fusiontreelist(W) - trees = treelist.fusiontreelist - L = length(trees) + L = length(treelist) fusiontreestructure = sizehint!(Vector{StridedStructure{N}}(), L) blockstructure = SectorDict{I, Tuple{Tuple{Int, Int}, UnitRange{Int}}}() @@ -165,14 +147,15 @@ See also [`FusionBlockStructure`](@ref), [`fusiontreelist`](@ref). blockoffset = 0 tree_index = 1 while tree_index <= L - f₁, f₂ = trees[tree_index] + f₁, f₂ = gettokenvalue(treelist, tree_index) c = f₁.coupled # compute subblock structure # splitting tree data empty!(splittingstructure) offset₁ = 0 - for (f₁′, f₂′) in view(trees, tree_index:L) + for i in tree_index:L + f₁′, f₂′ = gettokenvalue(treelist, i) f₂′ == f₂ || break s₁ = f₁′.uncoupled d₁s = dims(codom, s₁) @@ -187,7 +170,8 @@ See also [`FusionBlockStructure`](@ref), [`fusiontreelist`](@ref). # fusion tree data and combine offset₂ = 0 n₂ = 0 - for (f₁′, f₂′) in view(trees, tree_index:n₁:L) + for i in tree_index:n₁:L + f₁′, f₂′ = gettokenvalue(treelist, i) f₂′.coupled == c || break n₂ += 1 s₂ = f₂′.uncoupled diff --git a/src/tensors/treetransformers.jl b/src/tensors/treetransformers.jl index e5e919e3d..37a361931 100644 --- a/src/tensors/treetransformers.jl +++ b/src/tensors/treetransformers.jl @@ -19,15 +19,15 @@ function AbelianTreeTransformer(transform, p, Vdst, Vsrc) structure_dst = fusionblockstructure(Vdst) structure_src = fusionblockstructure(Vsrc) - L = length(structure_src.treelist.fusiontreelist) + L = length(structure_src.treelist) T = sectorscalartype(sectortype(Vdst)) N = numind(Vsrc) data = Vector{Tuple{T, StridedStructure{N}, StridedStructure{N}}}(undef, L) for i in 1:L - f₁, f₂ = structure_src.treelist.fusiontreelist[i] + f₁, f₂ = gettokenvalue(structure_src.treelist, i) (f₃, f₄), coeff = only(transform(f₁, f₂)) - j = structure_dst.treelist.fusiontreeindices[(f₃, f₄)] + _, j = gettoken(structure_dst.treelist, (f₃, f₄)) stridestructure_dst = structure_dst.fusiontreestructure[j] stridestructure_src = structure_src.fusiontreestructure[i] data[i] = (coeff, stridestructure_dst, stridestructure_src) @@ -64,12 +64,14 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc) fusionstructure_src = structure_src.fusiontreestructure I = sectortype(Vsrc) - uncoupleds_src = map(structure_src.treelist.fusiontreelist) do (f₁, f₂) + uncoupleds_src = map(1:length(structure_src.treelist)) do i + f₁, f₂ = gettokenvalue(structure_src.treelist, i) return TupleTools.vcat(f₁.uncoupled, dual.(f₂.uncoupled)) end uncoupleds_src_unique = unique(uncoupleds_src) - uncoupleds_dst = map(structure_dst.treelist.fusiontreelist) do (f₁, f₂) + uncoupleds_dst = map(1:length(structure_dst.treelist)) do i + f₁, f₂ = gettokenvalue(structure_dst.treelist, i) return TupleTools.vcat(f₁.uncoupled, dual.(f₂.uncoupled)) end @@ -81,12 +83,12 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc) # TODO: this can be multithreaded for (i, uncoupled) in enumerate(uncoupleds_src_unique) inds_src = findall(==(uncoupled), uncoupleds_src) - fusiontrees_outer_src = structure_src.treelist.fusiontreelist[inds_src] + fusiontrees_outer_src = map(i -> gettokenvalue(structure_src.treelist, i), inds_src) uncoupled_dst = TupleTools.getindices(uncoupled, (p[1]..., p[2]...)) inds_dst = findall(==(uncoupled_dst), uncoupleds_dst) - fusiontrees_outer_dst = structure_dst.treelist.fusiontreelist[inds_dst] + fusiontrees_outer_dst = map(i -> gettokenvalue(structure_dst.treelist, i), inds_dst) matrix = zeros(sectorscalartype(I), length(inds_dst), length(inds_src)) for (row, (f₁, f₂)) in enumerate(fusiontrees_outer_src) From 80506dfee4889a056855d15b411f4d3e63bd0e59 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Mon, 23 Mar 2026 14:41:31 -0400 Subject: [PATCH 07/12] switch even more to `Dictionaries.Dictionary` --- src/TensorKit.jl | 2 +- src/tensors/braidingtensor.jl | 3 +-- src/tensors/tensor.jl | 4 ++-- src/tensors/tensorstructure.jl | 22 ++++++++++----------- src/tensors/treetransformers.jl | 34 ++++++++++++++++++--------------- 5 files changed, 33 insertions(+), 32 deletions(-) diff --git a/src/TensorKit.jl b/src/TensorKit.jl index df57da83b..4de6159c9 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -118,7 +118,7 @@ const TO = TensorOperations using MatrixAlgebraKit -using Dictionaries: Dictionaries, Indices, gettoken, gettokenvalue +using Dictionaries: Dictionaries, Dictionary, Indices, gettoken, gettokenvalue using LRUCache using OhMyThreads using ScopedValues diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index d028bf271..0ca317026 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -146,8 +146,7 @@ function block(b::BraidingTensor, s::Sector) structure = fusionblockstructure(b) base_offset = first(structure.blockstructure[s][2]) - 1 - for ((f1, f2), (sz, str, off)) in - zip(structure.treelist, structure.fusiontreestructure) + for ((f1, f2), (sz, str, off)) in pairs(structure.fusiontreestructure) if (f1.uncoupled != reverse(f2.uncoupled)) || !(f1.coupled == f2.coupled == s) continue end diff --git a/src/tensors/tensor.jl b/src/tensors/tensor.jl index d8699ebf6..f81ee53fe 100644 --- a/src/tensors/tensor.jl +++ b/src/tensors/tensor.jl @@ -488,10 +488,10 @@ function subblock( t::TensorMap{T, S, N₁, N₂}, (f₁, f₂)::Tuple{FusionTree{I, N₁}, FusionTree{I, N₂}} ) where {T, S, N₁, N₂, I <: Sector} structure = fusionblockstructure(t) - found, i = gettoken(structure.treelist, (f₁, f₂)) + found, token = gettoken(structure.fusiontreestructure, (f₁, f₂)) @boundscheck found || throw(SectorMismatch(lazy"fusion tree pair ($(f₁, f₂)) is not present")) @inbounds begin - sz, str, offset = structure.fusiontreestructure[i] + sz, str, offset = gettokenvalue(structure.fusiontreestructure, token) return StridedView(t.data, sz, str, offset) end end diff --git a/src/tensors/tensorstructure.jl b/src/tensors/tensorstructure.jl index d6c5bd882..731de57e3 100644 --- a/src/tensors/tensorstructure.jl +++ b/src/tensors/tensorstructure.jl @@ -38,19 +38,17 @@ partitioned into symmetry blocks and sub-blocks indexed by fusion tree pairs. - `blockstructure`: maps each coupled sector `c::I` to a tuple `((d₁, d₂), r)`, where `d₁` and `d₂` are the block dimensions for the codomain and domain respectively, and `r` is the corresponding index range in the flat data vector. -- `fusiontreestructure`: for each fusion tree pair `(f₁, f₂)` (in the same order as - `treelist`), a [`StridedStructure`](@ref) `(sizes, strides, offset)` describing the - sub-block as a strided view into the flat data vector. -- `treelist`: an `Indices{Tuple{F₁,F₂}}` providing a bijection between fusion tree pairs - and sequential integer positions. +- `fusiontreestructure`: a `Dictionary` mapping each fusion tree pair `(f₁, f₂)` to a + [`StridedStructure`](@ref) `(sizes, strides, offset)` describing the sub-block as a + strided view into the flat data vector. The insertion order of the dictionary matches + the canonical enumeration order from [`fusiontreelist`](@ref). See also [`fusionblockstructure`](@ref), [`fusiontreelist`](@ref). """ struct FusionBlockStructure{I, N, F₁, F₂} totaldim::Int blockstructure::SectorDict{I, Tuple{Tuple{Int, Int}, UnitRange{Int}}} - fusiontreestructure::Vector{StridedStructure{N}} - treelist::Indices{Tuple{F₁, F₂}} + fusiontreestructure::Dictionary{Tuple{F₁, F₂}, StridedStructure{N}} end function fusionblockstructuretype(W::HomSpace) @@ -137,10 +135,9 @@ See also [`FusionBlockStructure`](@ref), [`fusiontreelist`](@ref). treelist = fusiontreelist(W) L = length(treelist) - fusiontreestructure = sizehint!(Vector{StridedStructure{N}}(), L) + structurevalues = sizehint!(Vector{StridedStructure{N}}(), L) blockstructure = SectorDict{I, Tuple{Tuple{Int, Int}, UnitRange{Int}}}() - # temporary data structures splittingstructure = Vector{NTuple{numout(W), Int}}() @@ -184,7 +181,7 @@ See also [`FusionBlockStructure`](@ref), [`fusiontreelist`](@ref). subsz = (d₁s..., d₂s...) @assert !any(==(0), subsz) substr = _subblock_strides(subsz, (d₁, d₂), strides) - push!(fusiontreestructure, (subsz, substr, totaloffset)) + push!(structurevalues, (subsz, substr, totaloffset)) offset₁ += d₁ end offset₂ += d₂ @@ -199,9 +196,10 @@ See also [`FusionBlockStructure`](@ref), [`fusiontreelist`](@ref). blockoffset = last(blockrange) tree_index += n₁ * n₂ end - @assert length(fusiontreestructure) == L + @assert length(structurevalues) == L - return FusionBlockStructure(blockoffset, blockstructure, fusiontreestructure, treelist) + fusiontreestructure = Dictionary(treelist, structurevalues) + return FusionBlockStructure(blockoffset, blockstructure, fusiontreestructure) end function _subblock_strides(subsz, sz, str) diff --git a/src/tensors/treetransformers.jl b/src/tensors/treetransformers.jl index 37a361931..ef211c4ac 100644 --- a/src/tensors/treetransformers.jl +++ b/src/tensors/treetransformers.jl @@ -19,17 +19,15 @@ function AbelianTreeTransformer(transform, p, Vdst, Vsrc) structure_dst = fusionblockstructure(Vdst) structure_src = fusionblockstructure(Vsrc) - L = length(structure_src.treelist) + L = length(structure_src.fusiontreestructure) T = sectorscalartype(sectortype(Vdst)) N = numind(Vsrc) data = Vector{Tuple{T, StridedStructure{N}, StridedStructure{N}}}(undef, L) - for i in 1:L - f₁, f₂ = gettokenvalue(structure_src.treelist, i) + for (i, ((f₁, f₂), stridestructure_src)) in enumerate(pairs(structure_src.fusiontreestructure)) (f₃, f₄), coeff = only(transform(f₁, f₂)) - _, j = gettoken(structure_dst.treelist, (f₃, f₄)) - stridestructure_dst = structure_dst.fusiontreestructure[j] - stridestructure_src = structure_src.fusiontreestructure[i] + _, token = gettoken(structure_dst.fusiontreestructure, (f₃, f₄)) + stridestructure_dst = gettokenvalue(structure_dst.fusiontreestructure, token) data[i] = (coeff, stridestructure_dst, stridestructure_src) end @@ -64,14 +62,17 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc) fusionstructure_src = structure_src.fusiontreestructure I = sectortype(Vsrc) - uncoupleds_src = map(1:length(structure_src.treelist)) do i - f₁, f₂ = gettokenvalue(structure_src.treelist, i) + treelist_src = keys(structure_src.fusiontreestructure) + treelist_dst = keys(structure_dst.fusiontreestructure) + + uncoupleds_src = map(1:length(treelist_src)) do i + f₁, f₂ = gettokenvalue(treelist_src, i) return TupleTools.vcat(f₁.uncoupled, dual.(f₂.uncoupled)) end uncoupleds_src_unique = unique(uncoupleds_src) - uncoupleds_dst = map(1:length(structure_dst.treelist)) do i - f₁, f₂ = gettokenvalue(structure_dst.treelist, i) + uncoupleds_dst = map(1:length(treelist_dst)) do i + f₁, f₂ = gettokenvalue(treelist_dst, i) return TupleTools.vcat(f₁.uncoupled, dual.(f₂.uncoupled)) end @@ -83,12 +84,12 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc) # TODO: this can be multithreaded for (i, uncoupled) in enumerate(uncoupleds_src_unique) inds_src = findall(==(uncoupled), uncoupleds_src) - fusiontrees_outer_src = map(i -> gettokenvalue(structure_src.treelist, i), inds_src) + fusiontrees_outer_src = map(i -> gettokenvalue(treelist_src, i), inds_src) uncoupled_dst = TupleTools.getindices(uncoupled, (p[1]..., p[2]...)) inds_dst = findall(==(uncoupled_dst), uncoupleds_dst) - fusiontrees_outer_dst = map(i -> gettokenvalue(structure_dst.treelist, i), inds_dst) + fusiontrees_outer_dst = map(i -> gettokenvalue(treelist_dst, i), inds_dst) matrix = zeros(sectorscalartype(I), length(inds_dst), length(inds_src)) for (row, (f₁, f₂)) in enumerate(fusiontrees_outer_src) @@ -129,9 +130,12 @@ function GenericTreeTransformer(transform, p, Vdst, Vsrc) return transformer end -function repack_transformer_structure(structures, ids) - sz = structures[first(ids)][1] - strides_offsets = map(i -> (structures[i][2], structures[i][3]), ids) +function repack_transformer_structure(structures::Dictionary, ids) + sz = gettokenvalue(structures, first(ids))[1] + strides_offsets = map(ids) do i + s = gettokenvalue(structures, i) + return (s[2], s[3]) + end return sz, strides_offsets end From dfbcb87244370e9d74e3f689d07044635462c4d7 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Mon, 23 Mar 2026 17:02:51 -0400 Subject: [PATCH 08/12] remove fusiontreelist --- src/spaces/homspace.jl | 6 ------ src/tensors/tensorstructure.jl | 16 ++++++++-------- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/src/spaces/homspace.jl b/src/spaces/homspace.jl index 1d1410cce..347e5551e 100644 --- a/src/spaces/homspace.jl +++ b/src/spaces/homspace.jl @@ -125,12 +125,6 @@ end dims(W::HomSpace) = (dims(codomain(W))..., dims(domain(W))...) -""" - fusiontrees(W::HomSpace) - -Return the fusiontrees corresponding to all valid fusion channels of a given `HomSpace`. -""" -fusiontrees(W::HomSpace) = fusiontreelist(W) # Operations on HomSpaces # ----------------------- diff --git a/src/tensors/tensorstructure.jl b/src/tensors/tensorstructure.jl index 731de57e3..902872255 100644 --- a/src/tensors/tensorstructure.jl +++ b/src/tensors/tensorstructure.jl @@ -43,7 +43,7 @@ partitioned into symmetry blocks and sub-blocks indexed by fusion tree pairs. strided view into the flat data vector. The insertion order of the dictionary matches the canonical enumeration order from [`fusiontreelist`](@ref). -See also [`fusionblockstructure`](@ref), [`fusiontreelist`](@ref). +See also [`fusionblockstructure`](@ref), [`fusiontrees`](@ref). """ struct FusionBlockStructure{I, N, F₁, F₂} totaldim::Int @@ -61,7 +61,7 @@ function fusionblockstructuretype(W::HomSpace) return FusionBlockStructure{I, N, F₁, F₂} end -Base.@assume_effects :foldable function fusiontreelisttype(key::Hashed{S}) where {S <: HomSpace} +Base.@assume_effects :foldable function fusiontreestype(key::Hashed{S}) where {S <: HomSpace} I = sectortype(S) F₁ = fusiontreetype(I, numout(S)) F₂ = fusiontreetype(I, numin(S)) @@ -69,7 +69,7 @@ Base.@assume_effects :foldable function fusiontreelisttype(key::Hashed{S}) where end """ - fusiontreelist(W::HomSpace) -> Indices{Tuple{F₁,F₂}} + fusiontrees(W::HomSpace) -> Indices{Tuple{F₁,F₂}} Return an `Indices` of all valid fusion tree pairs `(f₁, f₂)` for `W`, providing a bijection to sequential integer positions via `gettoken`/`gettokenvalue`. The result is @@ -79,9 +79,9 @@ object. See also [`fusionblockstructure`](@ref). """ -fusiontreelist(W::HomSpace) = fusiontreelist(Hashed(W, sectorhash, sectorequal)) +fusiontrees(W::HomSpace) = fusiontrees(Hashed(W, sectorhash, sectorequal)) -@cached function fusiontreelist(key::Hashed{S})::fusiontreelisttype(key) where {S <: HomSpace} +@cached function fusiontrees(key::Hashed{S})::fusiontreestype(key) where {S <: HomSpace} W = parent(key) codom, dom = codomain(W), domain(W) I = sectortype(S) @@ -114,7 +114,7 @@ fusiontreelist(W::HomSpace) = fusiontreelist(Hashed(W, sectorhash, sectorequal)) return Indices(trees) end -CacheStyle(::typeof(fusiontreelist), ::Hashed{S}) where {S <: HomSpace} = GlobalLRUCache() +CacheStyle(::typeof(fusiontrees), ::Hashed{S}) where {S <: HomSpace} = GlobalLRUCache() @doc """ fusionblockstructure(W::HomSpace) -> FusionBlockStructure @@ -124,7 +124,7 @@ data vector is laid out in terms of symmetry blocks and fusion-tree sub-blocks. is cached per `HomSpace` instance (keyed by object identity, not sector structure, since degeneracy dimensions affect the block sizes and offsets). -See also [`FusionBlockStructure`](@ref), [`fusiontreelist`](@ref). +See also [`FusionBlockStructure`](@ref), [`fusiontrees`](@ref). """ fusionblockstructure(::HomSpace) @cached function fusionblockstructure(W::HomSpace)::fusionblockstructuretype(W) @@ -133,7 +133,7 @@ See also [`FusionBlockStructure`](@ref), [`fusiontreelist`](@ref). N = length(codom) + length(dom) I = sectortype(W) - treelist = fusiontreelist(W) + treelist = fusiontrees(W) L = length(treelist) structurevalues = sizehint!(Vector{StridedStructure{N}}(), L) blockstructure = SectorDict{I, Tuple{Tuple{Int, Int}, UnitRange{Int}}}() From 0fac8ace8956c42520d4679c5c62eac194443a4a Mon Sep 17 00:00:00 2001 From: lkdvos Date: Mon, 23 Mar 2026 17:04:25 -0400 Subject: [PATCH 09/12] avoid storing fusiontrees in fusionblockstructure --- src/tensors/braidingtensor.jl | 2 +- src/tensors/tensor.jl | 6 +++--- src/tensors/tensorstructure.jl | 37 ++++++++++++++++++++------------- src/tensors/treetransformers.jl | 20 +++++++++--------- 4 files changed, 36 insertions(+), 29 deletions(-) diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index 0ca317026..6b2544bff 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -146,7 +146,7 @@ function block(b::BraidingTensor, s::Sector) structure = fusionblockstructure(b) base_offset = first(structure.blockstructure[s][2]) - 1 - for ((f1, f2), (sz, str, off)) in pairs(structure.fusiontreestructure) + for ((f1, f2), (sz, str, off)) in pairs(fusiontreestructure(space(b))) if (f1.uncoupled != reverse(f2.uncoupled)) || !(f1.coupled == f2.coupled == s) continue end diff --git a/src/tensors/tensor.jl b/src/tensors/tensor.jl index f81ee53fe..bacdc30d1 100644 --- a/src/tensors/tensor.jl +++ b/src/tensors/tensor.jl @@ -487,11 +487,11 @@ end function subblock( t::TensorMap{T, S, N₁, N₂}, (f₁, f₂)::Tuple{FusionTree{I, N₁}, FusionTree{I, N₂}} ) where {T, S, N₁, N₂, I <: Sector} - structure = fusionblockstructure(t) - found, token = gettoken(structure.fusiontreestructure, (f₁, f₂)) + fts = fusiontreestructure(space(t)) + found, token = gettoken(fts, (f₁, f₂)) @boundscheck found || throw(SectorMismatch(lazy"fusion tree pair ($(f₁, f₂)) is not present")) @inbounds begin - sz, str, offset = gettokenvalue(structure.fusiontreestructure, token) + sz, str, offset = gettokenvalue(fts, token) return StridedView(t.data, sz, str, offset) end end diff --git a/src/tensors/tensorstructure.jl b/src/tensors/tensorstructure.jl index 902872255..281d3388c 100644 --- a/src/tensors/tensorstructure.jl +++ b/src/tensors/tensorstructure.jl @@ -28,7 +28,7 @@ function sectorhash(W::HomSpace, h::UInt) end """ - FusionBlockStructure{I, N, F₁, F₂} + FusionBlockStructure{I, N} Full block structure of a `HomSpace`, encoding how a tensor's flat data vector is partitioned into symmetry blocks and sub-blocks indexed by fusion tree pairs. @@ -38,27 +38,23 @@ partitioned into symmetry blocks and sub-blocks indexed by fusion tree pairs. - `blockstructure`: maps each coupled sector `c::I` to a tuple `((d₁, d₂), r)`, where `d₁` and `d₂` are the block dimensions for the codomain and domain respectively, and `r` is the corresponding index range in the flat data vector. -- `fusiontreestructure`: a `Dictionary` mapping each fusion tree pair `(f₁, f₂)` to a - [`StridedStructure`](@ref) `(sizes, strides, offset)` describing the sub-block as a - strided view into the flat data vector. The insertion order of the dictionary matches - the canonical enumeration order from [`fusiontreelist`](@ref). +- `fusiontreestructure`: a `Vector` of [`StridedStructure`](@ref) `(sizes, strides, offset)` + values, one per fusion tree pair, in the canonical enumeration order from + [`fusiontrees`](@ref). Use `fusiontrees` to obtain the corresponding `Indices` of + fusion tree pairs. See also [`fusionblockstructure`](@ref), [`fusiontrees`](@ref). """ -struct FusionBlockStructure{I, N, F₁, F₂} +struct FusionBlockStructure{I, N} totaldim::Int blockstructure::SectorDict{I, Tuple{Tuple{Int, Int}, UnitRange{Int}}} - fusiontreestructure::Dictionary{Tuple{F₁, F₂}, StridedStructure{N}} + fusiontreestructure::Vector{StridedStructure{N}} end function fusionblockstructuretype(W::HomSpace) - N₁ = length(codomain(W)) - N₂ = length(domain(W)) - N = N₁ + N₂ + N = length(codomain(W)) + length(domain(W)) I = sectortype(W) - F₁ = fusiontreetype(I, N₁) - F₂ = fusiontreetype(I, N₂) - return FusionBlockStructure{I, N, F₁, F₂} + return FusionBlockStructure{I, N} end Base.@assume_effects :foldable function fusiontreestype(key::Hashed{S}) where {S <: HomSpace} @@ -198,8 +194,7 @@ See also [`FusionBlockStructure`](@ref), [`fusiontrees`](@ref). end @assert length(structurevalues) == L - fusiontreestructure = Dictionary(treelist, structurevalues) - return FusionBlockStructure(blockoffset, blockstructure, fusiontreestructure) + return FusionBlockStructure(blockoffset, blockstructure, structurevalues) end function _subblock_strides(subsz, sz, str) @@ -212,6 +207,18 @@ end CacheStyle(::typeof(fusionblockstructure), W::HomSpace) = GlobalLRUCache() +""" + fusiontreestructure(W::HomSpace) -> Dictionary + +Return a `Dictionary` mapping each fusion tree pair `(f₁, f₂)` to its +[`StridedStructure`](@ref) `(sizes, strides, offset)`. This wraps the cached +[`fusiontrees`](@ref) `Indices` together with the values stored in +[`fusionblockstructure`](@ref), with no data copying. +""" +function fusiontreestructure(W::HomSpace) + return Dictionary(fusiontrees(W), fusionblockstructure(W).fusiontreestructure) +end + # Diagonal ranges #---------------- function diagonalblockstructure(W::HomSpace) diff --git a/src/tensors/treetransformers.jl b/src/tensors/treetransformers.jl index ef211c4ac..191147a9c 100644 --- a/src/tensors/treetransformers.jl +++ b/src/tensors/treetransformers.jl @@ -19,15 +19,17 @@ function AbelianTreeTransformer(transform, p, Vdst, Vsrc) structure_dst = fusionblockstructure(Vdst) structure_src = fusionblockstructure(Vsrc) - L = length(structure_src.fusiontreestructure) + fts_src = fusiontreestructure(Vsrc) + fts_dst = fusiontreestructure(Vdst) + L = length(fts_src) T = sectorscalartype(sectortype(Vdst)) N = numind(Vsrc) data = Vector{Tuple{T, StridedStructure{N}, StridedStructure{N}}}(undef, L) - for (i, ((f₁, f₂), stridestructure_src)) in enumerate(pairs(structure_src.fusiontreestructure)) + for (i, ((f₁, f₂), stridestructure_src)) in enumerate(pairs(fts_src)) (f₃, f₄), coeff = only(transform(f₁, f₂)) - _, token = gettoken(structure_dst.fusiontreestructure, (f₃, f₄)) - stridestructure_dst = gettokenvalue(structure_dst.fusiontreestructure, token) + _, token = gettoken(fts_dst, (f₃, f₄)) + stridestructure_dst = gettokenvalue(fts_dst, token) data[i] = (coeff, stridestructure_dst, stridestructure_src) end @@ -56,14 +58,12 @@ end function GenericTreeTransformer(transform, p, Vdst, Vsrc) t₀ = Base.time() permute(Vsrc, p) == Vdst || throw(SpaceMismatch("Incompatible spaces for permuting.")) - structure_dst = fusionblockstructure(Vdst) - fusionstructure_dst = structure_dst.fusiontreestructure - structure_src = fusionblockstructure(Vsrc) - fusionstructure_src = structure_src.fusiontreestructure + fusionstructure_dst = fusiontreestructure(Vdst) + fusionstructure_src = fusiontreestructure(Vsrc) I = sectortype(Vsrc) - treelist_src = keys(structure_src.fusiontreestructure) - treelist_dst = keys(structure_dst.fusiontreestructure) + treelist_src = keys(fusionstructure_src) + treelist_dst = keys(fusionstructure_dst) uncoupleds_src = map(1:length(treelist_src)) do i f₁, f₂ = gettokenvalue(treelist_src, i) From 543953dbaf034c08237617cccae6db293b9a9701 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Tue, 24 Mar 2026 11:05:39 -0400 Subject: [PATCH 10/12] clean cache separation --- docs/src/lib/tensors.md | 1 - ext/TensorKitMooncakeExt/utility.jl | 3 +- src/spaces/homspace.jl | 30 ++---- src/tensors/abstracttensor.jl | 12 +-- src/tensors/braidingtensor.jl | 3 +- src/tensors/tensor.jl | 26 ++--- src/tensors/tensoroperations.jl | 3 +- src/tensors/tensorstructure.jl | 161 ++++++++++++++++++---------- src/tensors/treetransformers.jl | 3 - 9 files changed, 135 insertions(+), 107 deletions(-) diff --git a/docs/src/lib/tensors.md b/docs/src/lib/tensors.md index ea491a843..7b12cdf18 100644 --- a/docs/src/lib/tensors.md +++ b/docs/src/lib/tensors.md @@ -97,7 +97,6 @@ In `TensorMap` instances, all data is gathered in a single `AbstractVector`, whi To obtain information about the structure of the data, you can use: ```@docs -fusionblockstructure(::AbstractTensorMap) dim(::AbstractTensorMap) blocksectors(::AbstractTensorMap) hasblock(::AbstractTensorMap, ::Sector) diff --git a/ext/TensorKitMooncakeExt/utility.jl b/ext/TensorKitMooncakeExt/utility.jl index 779a4e017..ceb32d867 100644 --- a/ext/TensorKitMooncakeExt/utility.jl +++ b/ext/TensorKitMooncakeExt/utility.jl @@ -62,7 +62,8 @@ end Mooncake.tangent_type(::Type{<:VectorSpace}) = Mooncake.NoTangent Mooncake.tangent_type(::Type{<:HomSpace}) = Mooncake.NoTangent -@zero_derivative DefaultCtx Tuple{typeof(TensorKit.fusionblockstructure), Any} +@zero_derivative DefaultCtx Tuple{typeof(TensorKit.sectorstructure), Any} +@zero_derivative DefaultCtx Tuple{typeof(TensorKit.degeneracystructure), Any} @zero_derivative DefaultCtx Tuple{typeof(TensorKit.select), HomSpace, Index2Tuple} @zero_derivative DefaultCtx Tuple{typeof(TensorKit.flip), HomSpace, Any} diff --git a/src/spaces/homspace.jl b/src/spaces/homspace.jl index 347e5551e..9f5ba920c 100644 --- a/src/spaces/homspace.jl +++ b/src/spaces/homspace.jl @@ -69,16 +69,16 @@ function Base.show(io::IO, W::HomSpace) ) end -""" - blocksectors(W::HomSpace) +@doc """ + blocksectors(W::HomSpace) -> Indices{I} -Return an iterator over the different unique coupled sector labels, i.e. the intersection -of the different fusion outputs that can be obtained by fusing the sectors present in the -domain, as well as from the codomain. +Return an `Indices` of all coupled sectors for `W`. The result is cached based on the +sector structure of `W` (ignoring degeneracy dimensions). -See also [`hasblock`](@ref). -""" -function blocksectors(W::HomSpace) +See also [`hasblock`](@ref), [`blockstructure`](@ref). +""" blocksectors(::HomSpace) + +function _blocksectors(W::HomSpace) sectortype(W) === Trivial && return OneOrNoneIterator(dim(domain(W)) != 0 && dim(codomain(W)) != 0, Trivial()) @@ -109,23 +109,15 @@ See also [`blocksectors`](@ref). """ hasblock(W::HomSpace, c::Sector) = hasblock(codomain(W), c) && hasblock(domain(W), c) -""" - dim(W::HomSpace) +@doc """ + dim(W::HomSpace) -> Int Return the total dimension of a `HomSpace`, i.e. the number of linearly independent morphisms that can be constructed within this space. -""" -function dim(W::HomSpace) - d = 0 - for c in blocksectors(W) - d += blockdim(codomain(W), c) * blockdim(domain(W), c) - end - return d -end +""" dim(::HomSpace) dims(W::HomSpace) = (dims(codomain(W))..., dims(domain(W))...) - # Operations on HomSpaces # ----------------------- """ diff --git a/src/tensors/abstracttensor.jl b/src/tensors/abstracttensor.jl index ecb6c11c8..2ed4129e2 100644 --- a/src/tensors/abstracttensor.jl +++ b/src/tensors/abstracttensor.jl @@ -321,21 +321,13 @@ numind(t::AbstractTensorMap) = numind(typeof(t)) # tensor characteristics: data structure and properties #------------------------------------------------------ -""" - fusionblockstructure(t::AbstractTensorMap) -> TensorStructure - -Return the necessary structure information to decompose a tensor in blocks labeled by -coupled sectors and in subblocks labeled by a splitting-fusion tree couple. -""" -fusionblockstructure(t::AbstractTensorMap) = fusionblockstructure(space(t)) - """ dim(t::AbstractTensorMap) -> Int The total number of free parameters of a tensor, discounting the entries that are fixed by symmetry. This is also the dimension of the `HomSpace` on which the `TensorMap` is defined. """ -dim(t::AbstractTensorMap) = fusionblockstructure(t).totaldim +dim(t::AbstractTensorMap) = dim(space(t)) dims(t::AbstractTensorMap) = dims(space(t)) @@ -344,7 +336,7 @@ dims(t::AbstractTensorMap) = dims(space(t)) Return an iterator over all coupled sectors of a tensor. """ -blocksectors(t::AbstractTensorMap) = keys(fusionblockstructure(t).blockstructure) +blocksectors(t::AbstractTensorMap) = blocksectors(space(t)) """ hasblock(t::AbstractTensorMap, c::Sector) -> Bool diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index 6b2544bff..1a86e30b0 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -143,8 +143,7 @@ function block(b::BraidingTensor, s::Sector) return data end - structure = fusionblockstructure(b) - base_offset = first(structure.blockstructure[s][2]) - 1 + base_offset = first(blockstructure(b)[s][2]) - 1 for ((f1, f2), (sz, str, off)) in pairs(fusiontreestructure(space(b))) if (f1.uncoupled != reverse(f2.uncoupled)) || !(f1.coupled == f2.coupled == s) diff --git a/src/tensors/tensor.jl b/src/tensors/tensor.jl index bacdc30d1..26a9d127a 100644 --- a/src/tensors/tensor.jl +++ b/src/tensors/tensor.jl @@ -15,8 +15,7 @@ struct TensorMap{T, S <: IndexSpace, N₁, N₂, A <: DenseVector{T}} <: Abstrac function TensorMap{T, S, N₁, N₂, A}( ::UndefInitializer, space::TensorMapSpace{S, N₁, N₂} ) where {T, S <: IndexSpace, N₁, N₂, A <: DenseVector{T}} - d = fusionblockstructure(space).totaldim - data = A(undef, d) + data = A(undef, dim(space)) if !isbitstype(T) zerovector!(data) end @@ -31,8 +30,7 @@ struct TensorMap{T, S <: IndexSpace, N₁, N₂, A <: DenseVector{T}} <: Abstrac I = sectortype(S) T <: Real && !(sectorscalartype(I) <: Real) && @warn("Tensors with real data might be incompatible with sector type $I", maxlog = 1) - d = fusionblockstructure(space).totaldim - length(data) == d || throw(DimensionMismatch("invalid length of data")) + length(data) == dim(space) || throw(DimensionMismatch("invalid length of data")) return new{T, S, N₁, N₂, A}(data, space) end end @@ -453,7 +451,7 @@ end #------------------------------------------------- block(t::TensorMap, c::Sector) = blocks(t)[c] -blocks(t::TensorMap) = BlockIterator(t, fusionblockstructure(t).blockstructure) +blocks(t::TensorMap) = BlockIterator(t, blockstructure(space(t))) function blocktype(::Type{TT}) where {TT <: TensorMap} A = storagetype(TT) @@ -462,7 +460,7 @@ function blocktype(::Type{TT}) where {TT <: TensorMap} end function Base.iterate(iter::BlockIterator{<:TensorMap}, state...) - next = iterate(iter.structure, state...) + next = iterate(pairs(iter.structure), state...) isnothing(next) && return next (c, (sz, r)), newstate = next return c => reshape(view(iter.t.data, r), sz), newstate @@ -470,16 +468,18 @@ end function Base.getindex(iter::BlockIterator{<:TensorMap}, c::Sector) sectortype(iter.t) === typeof(c) || throw(SectorMismatch()) - (d₁, d₂), r = get(iter.structure, c) do - # is s is not a key, at least one of the two dimensions will be zero: + found, token = gettoken(iter.structure, c) + if found + (d₁, d₂), r = gettokenvalue(iter.structure, token) + return reshape(view(iter.t.data, r), (d₁, d₂)) + else + # if c is not a key, at least one of the two dimensions will be zero: # it then does not matter where exactly we construct a view in `t.data`, # as it will have length zero anyway - d₁′ = blockdim(codomain(iter.t), c) - d₂′ = blockdim(domain(iter.t), c) - l = d₁′ * d₂′ - return (d₁′, d₂′), 1:l + d₁ = blockdim(codomain(iter.t), c) + d₂ = blockdim(domain(iter.t), c) + return reshape(view(iter.t.data, 1:(d₁ * d₂)), (d₁, d₂)) end - return reshape(view(iter.t.data, r), (d₁, d₂)) end # Getting and setting the data at the subblock level diff --git a/src/tensors/tensoroperations.jl b/src/tensors/tensoroperations.jl index 0820fe1af..2d6b5a9a5 100644 --- a/src/tensors/tensoroperations.jl +++ b/src/tensors/tensoroperations.jl @@ -9,8 +9,7 @@ function TO.tensoralloc( ::Type{TT}, structure::TensorMapSpace, istemp::Val, allocator = TO.DefaultAllocator() ) where {TT <: AbstractTensorMap} A = storagetype(TT) - dim = fusionblockstructure(structure).totaldim - data = TO.tensoralloc(A, dim, istemp, allocator) + data = TO.tensoralloc(A, dim(structure), istemp, allocator) TT′ = tensormaptype(spacetype(structure), numout(structure), numin(structure), typeof(data)) return TT′(data, structure) end diff --git a/src/tensors/tensorstructure.jl b/src/tensors/tensorstructure.jl index 281d3388c..af40aed50 100644 --- a/src/tensors/tensorstructure.jl +++ b/src/tensors/tensorstructure.jl @@ -28,56 +28,70 @@ function sectorhash(W::HomSpace, h::UInt) end """ - FusionBlockStructure{I, N} + SectorStructure{I, F₁, F₂} -Full block structure of a `HomSpace`, encoding how a tensor's flat data vector is -partitioned into symmetry blocks and sub-blocks indexed by fusion tree pairs. +Sector-only structure of a `HomSpace`: the coupled sectors and all valid fusion tree pairs, +depending only on which sectors appear (not their degeneracy dimensions). Shared across +`HomSpace`s with the same sector structure. ## Fields -- `totaldim`: total number of elements in the flat data vector. -- `blockstructure`: maps each coupled sector `c::I` to a tuple `((d₁, d₂), r)`, where - `d₁` and `d₂` are the block dimensions for the codomain and domain respectively, and - `r` is the corresponding index range in the flat data vector. -- `fusiontreestructure`: a `Vector` of [`StridedStructure`](@ref) `(sizes, strides, offset)` - values, one per fusion tree pair, in the canonical enumeration order from - [`fusiontrees`](@ref). Use `fusiontrees` to obtain the corresponding `Indices` of - fusion tree pairs. - -See also [`fusionblockstructure`](@ref), [`fusiontrees`](@ref). -""" -struct FusionBlockStructure{I, N} - totaldim::Int - blockstructure::SectorDict{I, Tuple{Tuple{Int, Int}, UnitRange{Int}}} - fusiontreestructure::Vector{StridedStructure{N}} -end +- `blocksectors`: `Indices` of all coupled sectors `c::I`. +- `fusiontrees`: `Indices` of all valid fusion tree pairs `(f₁, f₂)`, in canonical order. -function fusionblockstructuretype(W::HomSpace) - N = length(codomain(W)) + length(domain(W)) - I = sectortype(W) - return FusionBlockStructure{I, N} +See also [`sectorstructure`](@ref), [`DegeneracyStructure`](@ref). +""" +struct SectorStructure{I, F₁, F₂} + blocksectors::Indices{I} + fusiontrees::Indices{Tuple{F₁, F₂}} end -Base.@assume_effects :foldable function fusiontreestype(key::Hashed{S}) where {S <: HomSpace} +Base.@assume_effects :foldable function sectorstructuretype(key::Hashed{S}) where {S <: HomSpace} I = sectortype(S) F₁ = fusiontreetype(I, numout(S)) F₂ = fusiontreetype(I, numin(S)) - return Indices{Tuple{F₁, F₂}} + return SectorStructure{I, F₁, F₂} end """ - fusiontrees(W::HomSpace) -> Indices{Tuple{F₁,F₂}} + DegeneracyStructure{N} -Return an `Indices` of all valid fusion tree pairs `(f₁, f₂)` for `W`, providing a -bijection to sequential integer positions via `gettoken`/`gettokenvalue`. The result is -cached based on the sector structure of `W` (ignoring degeneracy dimensions), so -`HomSpace`s that share the same sectors, dualities, and index count will reuse the same -object. +Degeneracy-dependent structure of a `HomSpace`: the block sizes, ranges, and sub-block +strides that depend on the degeneracy (multiplicity) dimensions. Specific to a given +`HomSpace` instance. -See also [`fusionblockstructure`](@ref). +## Fields +- `totaldim`: total number of elements in the flat data vector. +- `blockstructure`: `Vector` of `((d₁, d₂), range)` values, one per coupled sector, in the + same order as [`sectorstructure`](@ref)`.blocksectors`. +- `fusiontreestructure`: `Vector` of [`StridedStructure`](@ref) `(sizes, strides, offset)` + values, one per fusion tree pair, in the same order as + [`sectorstructure`](@ref)`.fusiontrees`. + +See also [`degeneracystructure`](@ref), [`SectorStructure`](@ref). """ -fusiontrees(W::HomSpace) = fusiontrees(Hashed(W, sectorhash, sectorequal)) +struct DegeneracyStructure{N} + totaldim::Int + blockstructure::Vector{Tuple{Tuple{Int, Int}, UnitRange{Int}}} + fusiontreestructure::Vector{StridedStructure{N}} +end + +function degeneracystructuretype(W::HomSpace) + N = length(codomain(W)) + length(domain(W)) + return DegeneracyStructure{N} +end + +@doc """ + sectorstructure(W::HomSpace) -> SectorStructure + +Return the [`SectorStructure`](@ref) for `W`, containing the coupled sectors and fusion tree +pairs as `Indices`. The result is cached based on the sector structure of `W` (ignoring +degeneracy dimensions). -@cached function fusiontrees(key::Hashed{S})::fusiontreestype(key) where {S <: HomSpace} +See also [`degeneracystructure`](@ref), [`fusiontrees`](@ref), [`blocksectors`](@ref). +""" sectorstructure(::HomSpace) +sectorstructure(W::HomSpace) = sectorstructure(Hashed(W, sectorhash, sectorequal)) + +@cached function sectorstructure(key::Hashed{S})::sectorstructuretype(key) where {S <: HomSpace} W = parent(key) codom, dom = codomain(W), domain(W) I = sectortype(S) @@ -85,9 +99,11 @@ fusiontrees(W::HomSpace) = fusiontrees(Hashed(W, sectorhash, sectorequal)) F₁ = fusiontreetype(I, N₁) F₂ = fusiontreetype(I, N₂) + bs = Vector{I}() trees = Vector{Tuple{F₁, F₂}}() - for c in blocksectors(W) + for c in _blocksectors(W) + push!(bs, c) codom_start = length(trees) + 1 n₁ = 0 for f₂ in fusiontrees(dom, c) @@ -107,38 +123,37 @@ fusiontrees(W::HomSpace) = fusiontrees(Hashed(W, sectorhash, sectorequal)) end end - return Indices(trees) + return SectorStructure{I, F₁, F₂}(Indices(bs), Indices(trees)) end -CacheStyle(::typeof(fusiontrees), ::Hashed{S}) where {S <: HomSpace} = GlobalLRUCache() +CacheStyle(::typeof(sectorstructure), ::Hashed{S}) where {S <: HomSpace} = GlobalLRUCache() @doc """ - fusionblockstructure(W::HomSpace) -> FusionBlockStructure - -Compute the full [`FusionBlockStructure`](@ref) for `W`, describing how a tensor's flat -data vector is laid out in terms of symmetry blocks and fusion-tree sub-blocks. The result -is cached per `HomSpace` instance (keyed by object identity, not sector structure, since -degeneracy dimensions affect the block sizes and offsets). + degeneracystructure(W::HomSpace) -> DegeneracyStructure -See also [`FusionBlockStructure`](@ref), [`fusiontrees`](@ref). -""" fusionblockstructure(::HomSpace) +Compute the [`DegeneracyStructure`](@ref) for `W`, describing block sizes, data ranges, and +sub-block strides. The result is cached per `HomSpace` instance (keyed by object identity, +since degeneracy dimensions affect the block sizes and offsets). -@cached function fusionblockstructure(W::HomSpace)::fusionblockstructuretype(W) +See also [`sectorstructure`](@ref), [`blockstructure`](@ref), [`fusiontreestructure`](@ref). +""" degeneracystructure(::HomSpace) +@cached function degeneracystructure(W::HomSpace)::degeneracystructuretype(W) codom = codomain(W) dom = domain(W) N = length(codom) + length(dom) - I = sectortype(W) - treelist = fusiontrees(W) + ss = sectorstructure(W) + treelist = ss.fusiontrees L = length(treelist) structurevalues = sizehint!(Vector{StridedStructure{N}}(), L) - blockstructure = SectorDict{I, Tuple{Tuple{Int, Int}, UnitRange{Int}}}() + blockvalues = Vector{Tuple{Tuple{Int, Int}, UnitRange{Int}}}(undef, length(ss.blocksectors)) # temporary data structures splittingstructure = Vector{NTuple{numout(W), Int}}() blockoffset = 0 tree_index = 1 + block_index = 1 while tree_index <= L f₁, f₂ = gettokenvalue(treelist, tree_index) c = f₁.coupled @@ -186,15 +201,16 @@ See also [`FusionBlockStructure`](@ref), [`fusiontrees`](@ref). # compute block structure blockdim₂ = offset₂ blockrange = (blockoffset + 1):(blockoffset + blockdim₁ * blockdim₂) - blockstructure[c] = ((blockdim₁, blockdim₂), blockrange) + blockvalues[block_index] = ((blockdim₁, blockdim₂), blockrange) # reset blockoffset = last(blockrange) tree_index += n₁ * n₂ + block_index += 1 end @assert length(structurevalues) == L - return FusionBlockStructure(blockoffset, blockstructure, structurevalues) + return DegeneracyStructure(blockoffset, blockvalues, structurevalues) end function _subblock_strides(subsz, sz, str) @@ -205,18 +221,51 @@ function _subblock_strides(subsz, sz, str) return strides end -CacheStyle(::typeof(fusionblockstructure), W::HomSpace) = GlobalLRUCache() +CacheStyle(::typeof(degeneracystructure), ::HomSpace) = GlobalLRUCache() + +# Public API: combining the two caches +#-------------------------------------- + +""" + fusiontrees(W::HomSpace) -> Indices{Tuple{F₁,F₂}} + +Return an `Indices` of all valid fusion tree pairs `(f₁, f₂)` for `W`, providing a +bijection to sequential integer positions via `gettoken`/`gettokenvalue`. The result is +cached based on the sector structure of `W` (ignoring degeneracy dimensions), so +`HomSpace`s that share the same sectors, dualities, and index count will reuse the same +object. + +See also [`sectorstructure`](@ref), [`fusiontreestructure`](@ref). +""" +fusiontrees(W::HomSpace) = sectorstructure(W).fusiontrees + +blocksectors(W::HomSpace) = sectorstructure(W).blocksectors + +dim(W::HomSpace) = degeneracystructure(W).totaldim + +""" + blockstructure(W::HomSpace) -> Dictionary + +Return a `Dictionary` mapping each coupled sector `c::I` to a tuple `((d₁, d₂), r)`, +where `d₁` and `d₂` are the block dimensions for the codomain and domain respectively, +and `r` is the corresponding index range in the flat data vector. + +See also [`degeneracystructure`](@ref), [`fusiontreestructure`](@ref). +""" +function blockstructure(W::HomSpace) + return Dictionary(sectorstructure(W).blocksectors, degeneracystructure(W).blockstructure) +end """ fusiontreestructure(W::HomSpace) -> Dictionary Return a `Dictionary` mapping each fusion tree pair `(f₁, f₂)` to its -[`StridedStructure`](@ref) `(sizes, strides, offset)`. This wraps the cached -[`fusiontrees`](@ref) `Indices` together with the values stored in -[`fusionblockstructure`](@ref), with no data copying. +[`StridedStructure`](@ref) `(sizes, strides, offset)`. + +See also [`degeneracystructure`](@ref), [`blockstructure`](@ref). """ function fusiontreestructure(W::HomSpace) - return Dictionary(fusiontrees(W), fusionblockstructure(W).fusiontreestructure) + return Dictionary(sectorstructure(W).fusiontrees, degeneracystructure(W).fusiontreestructure) end # Diagonal ranges diff --git a/src/tensors/treetransformers.jl b/src/tensors/treetransformers.jl index 191147a9c..c58748f66 100644 --- a/src/tensors/treetransformers.jl +++ b/src/tensors/treetransformers.jl @@ -16,9 +16,6 @@ end function AbelianTreeTransformer(transform, p, Vdst, Vsrc) t₀ = Base.time() permute(Vsrc, p) == Vdst || throw(SpaceMismatch("Incompatible spaces for permuting.")) - structure_dst = fusionblockstructure(Vdst) - structure_src = fusionblockstructure(Vsrc) - fts_src = fusiontreestructure(Vsrc) fts_dst = fusiontreestructure(Vdst) L = length(fts_src) From 1a3893dd930ed81886aaa87d54b175ddb6f34eb1 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Tue, 24 Mar 2026 13:00:42 -0400 Subject: [PATCH 11/12] rename to subblockstructure --- src/tensors/abstracttensor.jl | 35 ++------------------------------- src/tensors/braidingtensor.jl | 2 +- src/tensors/tensor.jl | 2 +- src/tensors/tensorstructure.jl | 16 +++++++-------- src/tensors/treetransformers.jl | 8 ++++---- 5 files changed, 16 insertions(+), 47 deletions(-) diff --git a/src/tensors/abstracttensor.jl b/src/tensors/abstracttensor.jl index 2ed4129e2..af216d257 100644 --- a/src/tensors/abstracttensor.jl +++ b/src/tensors/abstracttensor.jl @@ -345,39 +345,8 @@ Verify whether a tensor has a block corresponding to a coupled sector `c`. """ hasblock(t::AbstractTensorMap, c::Sector) = c ∈ blocksectors(t) -# TODO: convenience methods, do we need them? -# """ -# blocksize(t::AbstractTensorMap, c::Sector) -> Tuple{Int,Int} - -# Return the size of the matrix block of a tensor corresponding to a coupled sector `c`. - -# See also [`blockdim`](@ref) and [`blockrange`](@ref). -# """ -# function blocksize(t::AbstractTensorMap, c::Sector) -# return fusionblockstructure(t).blockstructure[c][1] -# end - -# """ -# blockdim(t::AbstractTensorMap, c::Sector) -> Int - -# Return the total dimension (length) of the matrix block of a tensor corresponding to -# a coupled sector `c`. - -# See also [`blocksize`](@ref) and [`blockrange`](@ref). -# """ -# function blockdim(t::AbstractTensorMap, c::Sector) -# return *(blocksize(t, c)...) -# end - -# """ -# blockrange(t::AbstractTensorMap, c::Sector) -> UnitRange{Int} - -# Return the range at which to find the matrix block of a tensor corresponding to a -# coupled sector `c`, within the total data vector of length `dim(t)`. -# """ -# function blockrange(t::AbstractTensorMap, c::Sector) -# return fusionblockstructure(t).blockstructure[c][2] -# end +blockstructure(t::AbstractTensorMap) = blockstructure(space(t)) +subblockstructure(t::AbstractTensorMap) = subblockstructure(space(t)) """ fusiontrees(t::AbstractTensorMap) diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index 1a86e30b0..48e1aa1df 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -145,7 +145,7 @@ function block(b::BraidingTensor, s::Sector) base_offset = first(blockstructure(b)[s][2]) - 1 - for ((f1, f2), (sz, str, off)) in pairs(fusiontreestructure(space(b))) + for ((f1, f2), (sz, str, off)) in pairs(subblockstructure(space(b))) if (f1.uncoupled != reverse(f2.uncoupled)) || !(f1.coupled == f2.coupled == s) continue end diff --git a/src/tensors/tensor.jl b/src/tensors/tensor.jl index 26a9d127a..7e4f037cb 100644 --- a/src/tensors/tensor.jl +++ b/src/tensors/tensor.jl @@ -487,7 +487,7 @@ end function subblock( t::TensorMap{T, S, N₁, N₂}, (f₁, f₂)::Tuple{FusionTree{I, N₁}, FusionTree{I, N₂}} ) where {T, S, N₁, N₂, I <: Sector} - fts = fusiontreestructure(space(t)) + fts = subblockstructure(space(t)) found, token = gettoken(fts, (f₁, f₂)) @boundscheck found || throw(SectorMismatch(lazy"fusion tree pair ($(f₁, f₂)) is not present")) @inbounds begin diff --git a/src/tensors/tensorstructure.jl b/src/tensors/tensorstructure.jl index af40aed50..a460960c3 100644 --- a/src/tensors/tensorstructure.jl +++ b/src/tensors/tensorstructure.jl @@ -63,7 +63,7 @@ strides that depend on the degeneracy (multiplicity) dimensions. Specific to a g - `totaldim`: total number of elements in the flat data vector. - `blockstructure`: `Vector` of `((d₁, d₂), range)` values, one per coupled sector, in the same order as [`sectorstructure`](@ref)`.blocksectors`. -- `fusiontreestructure`: `Vector` of [`StridedStructure`](@ref) `(sizes, strides, offset)` +- `subblockstructure`: `Vector` of [`StridedStructure`](@ref) `(sizes, strides, offset)` values, one per fusion tree pair, in the same order as [`sectorstructure`](@ref)`.fusiontrees`. @@ -72,7 +72,7 @@ See also [`degeneracystructure`](@ref), [`SectorStructure`](@ref). struct DegeneracyStructure{N} totaldim::Int blockstructure::Vector{Tuple{Tuple{Int, Int}, UnitRange{Int}}} - fusiontreestructure::Vector{StridedStructure{N}} + subblockstructure::Vector{StridedStructure{N}} end function degeneracystructuretype(W::HomSpace) @@ -135,7 +135,7 @@ Compute the [`DegeneracyStructure`](@ref) for `W`, describing block sizes, data sub-block strides. The result is cached per `HomSpace` instance (keyed by object identity, since degeneracy dimensions affect the block sizes and offsets). -See also [`sectorstructure`](@ref), [`blockstructure`](@ref), [`fusiontreestructure`](@ref). +See also [`sectorstructure`](@ref), [`blockstructure`](@ref), [`subblockstructure`](@ref). """ degeneracystructure(::HomSpace) @cached function degeneracystructure(W::HomSpace)::degeneracystructuretype(W) codom = codomain(W) @@ -235,7 +235,7 @@ cached based on the sector structure of `W` (ignoring degeneracy dimensions), so `HomSpace`s that share the same sectors, dualities, and index count will reuse the same object. -See also [`sectorstructure`](@ref), [`fusiontreestructure`](@ref). +See also [`sectorstructure`](@ref), [`subblockstructure`](@ref). """ fusiontrees(W::HomSpace) = sectorstructure(W).fusiontrees @@ -250,22 +250,22 @@ Return a `Dictionary` mapping each coupled sector `c::I` to a tuple `((d₁, d where `d₁` and `d₂` are the block dimensions for the codomain and domain respectively, and `r` is the corresponding index range in the flat data vector. -See also [`degeneracystructure`](@ref), [`fusiontreestructure`](@ref). +See also [`degeneracystructure`](@ref), [`subblockstructure`](@ref). """ function blockstructure(W::HomSpace) return Dictionary(sectorstructure(W).blocksectors, degeneracystructure(W).blockstructure) end """ - fusiontreestructure(W::HomSpace) -> Dictionary + subblockstructure(W::HomSpace) -> Dictionary Return a `Dictionary` mapping each fusion tree pair `(f₁, f₂)` to its [`StridedStructure`](@ref) `(sizes, strides, offset)`. See also [`degeneracystructure`](@ref), [`blockstructure`](@ref). """ -function fusiontreestructure(W::HomSpace) - return Dictionary(sectorstructure(W).fusiontrees, degeneracystructure(W).fusiontreestructure) +function subblockstructure(W::HomSpace) + return Dictionary(sectorstructure(W).fusiontrees, degeneracystructure(W).subblockstructure) end # Diagonal ranges diff --git a/src/tensors/treetransformers.jl b/src/tensors/treetransformers.jl index c58748f66..8577af908 100644 --- a/src/tensors/treetransformers.jl +++ b/src/tensors/treetransformers.jl @@ -16,8 +16,8 @@ end function AbelianTreeTransformer(transform, p, Vdst, Vsrc) t₀ = Base.time() permute(Vsrc, p) == Vdst || throw(SpaceMismatch("Incompatible spaces for permuting.")) - fts_src = fusiontreestructure(Vsrc) - fts_dst = fusiontreestructure(Vdst) + fts_src = subblockstructure(Vsrc) + fts_dst = subblockstructure(Vdst) L = length(fts_src) T = sectorscalartype(sectortype(Vdst)) N = numind(Vsrc) @@ -55,8 +55,8 @@ end function GenericTreeTransformer(transform, p, Vdst, Vsrc) t₀ = Base.time() permute(Vsrc, p) == Vdst || throw(SpaceMismatch("Incompatible spaces for permuting.")) - fusionstructure_dst = fusiontreestructure(Vdst) - fusionstructure_src = fusiontreestructure(Vsrc) + fusionstructure_dst = subblockstructure(Vdst) + fusionstructure_src = subblockstructure(Vsrc) I = sectortype(Vsrc) treelist_src = keys(fusionstructure_src) From 07ed7d5b93c2f61c7e9a0b4c5a8a1c1463755979 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Tue, 24 Mar 2026 13:56:54 -0400 Subject: [PATCH 12/12] some more simplifications --- src/TensorKit.jl | 2 +- src/spaces/homspace.jl | 45 +++++++++----------- src/spaces/productspace.jl | 23 ++++------- src/tensors/tensorstructure.jl | 75 +++++++++++++++------------------- 4 files changed, 60 insertions(+), 85 deletions(-) diff --git a/src/TensorKit.jl b/src/TensorKit.jl index 4de6159c9..61adedcc2 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -118,7 +118,7 @@ const TO = TensorOperations using MatrixAlgebraKit -using Dictionaries: Dictionaries, Dictionary, Indices, gettoken, gettokenvalue +using Dictionaries: Dictionaries, Dictionary, Indices, gettoken, gettokenvalue, set! using LRUCache using OhMyThreads using ScopedValues diff --git a/src/spaces/homspace.jl b/src/spaces/homspace.jl index 9f5ba920c..42d0f4c9e 100644 --- a/src/spaces/homspace.jl +++ b/src/spaces/homspace.jl @@ -69,36 +69,15 @@ function Base.show(io::IO, W::HomSpace) ) end -@doc """ +""" blocksectors(W::HomSpace) -> Indices{I} Return an `Indices` of all coupled sectors for `W`. The result is cached based on the sector structure of `W` (ignoring degeneracy dimensions). See also [`hasblock`](@ref), [`blockstructure`](@ref). -""" blocksectors(::HomSpace) - -function _blocksectors(W::HomSpace) - sectortype(W) === Trivial && - return OneOrNoneIterator(dim(domain(W)) != 0 && dim(codomain(W)) != 0, Trivial()) - - codom = codomain(W) - dom = domain(W) - N₁ = length(codom) - N₂ = length(dom) - I = sectortype(W) - if N₁ == N₂ == 0 - return allunits(I) - elseif N₁ == 0 - return filter!(isunit, collect(blocksectors(dom))) # module space cannot end in empty space - elseif N₂ == 0 - return filter!(isunit, collect(blocksectors(codom))) - elseif N₂ <= N₁ - return filter!(c -> hasblock(codom, c), collect(blocksectors(dom))) - else - return filter!(c -> hasblock(dom, c), collect(blocksectors(codom))) - end -end +""" +blocksectors(W::HomSpace) = sectorstructure(W).blocksectors """ hasblock(W::HomSpace, c::Sector) @@ -109,15 +88,29 @@ See also [`blocksectors`](@ref). """ hasblock(W::HomSpace, c::Sector) = hasblock(codomain(W), c) && hasblock(domain(W), c) -@doc """ +""" dim(W::HomSpace) -> Int Return the total dimension of a `HomSpace`, i.e. the number of linearly independent morphisms that can be constructed within this space. -""" dim(::HomSpace) +""" +dim(W::HomSpace) = degeneracystructure(W).totaldim dims(W::HomSpace) = (dims(codomain(W))..., dims(domain(W))...) +""" + fusiontrees(W::HomSpace) -> Indices{Tuple{F₁,F₂}} + +Return an `Indices` of all valid fusion tree pairs `(f₁, f₂)` for `W`, providing a +bijection to sequential integer positions via `gettoken`/`gettokenvalue`. The result is +cached based on the sector structure of `W` (ignoring degeneracy dimensions), so +`HomSpace`s that share the same sectors, dualities, and index count will reuse the same +object. + +See also [`sectorstructure`](@ref), [`subblockstructure`](@ref). +""" +fusiontrees(W::HomSpace) = sectorstructure(W).fusiontrees + # Operations on HomSpaces # ----------------------- """ diff --git a/src/spaces/productspace.jl b/src/spaces/productspace.jl index 4b75ce737..a887999f0 100644 --- a/src/spaces/productspace.jl +++ b/src/spaces/productspace.jl @@ -96,6 +96,9 @@ sectors(P::ProductSpace) = _sectors(P, sectortype(P)) function _sectors(P::ProductSpace{<:ElementarySpace, N}, ::Type{Trivial}) where {N} return OneOrNoneIterator(dim(P) != 0, ntuple(n -> Trivial(), N)) end +function _sectors(P::ProductSpace{<:ElementarySpace, 0}, ::Type{I}) where {I <: Sector} + return Iterators.map(u -> (u,), allunits(I)) +end function _sectors(P::ProductSpace{<:ElementarySpace, N}, ::Type{<:Sector}) where {N} return product(map(sectors, P.spaces)...) end @@ -147,23 +150,11 @@ that make up the `ProductSpace` instance. function blocksectors(P::ProductSpace{S, N}) where {S, N} I = sectortype(S) if I == Trivial - return OneOrNoneIterator(dim(P) != 0, Trivial()) + return dim(P) != 0 ? Indices([Trivial()]) : Indices(Trivial[]) end - bs = Vector{I}() - if N == 0 - append!(bs, allunits(I)) - elseif N == 1 - for s in sectors(P) - push!(bs, first(s)) - end - else - for s in sectors(P) - for c in ⊗(s...) - if !(c in bs) - push!(bs, c) - end - end - end + bs = Indices{I}() + for s in sectors(P), c in ⊗(s...) + set!(bs, c) end return sort!(bs) end diff --git a/src/tensors/tensorstructure.jl b/src/tensors/tensorstructure.jl index a460960c3..94e812ddd 100644 --- a/src/tensors/tensorstructure.jl +++ b/src/tensors/tensorstructure.jl @@ -52,32 +52,11 @@ Base.@assume_effects :foldable function sectorstructuretype(key::Hashed{S}) wher return SectorStructure{I, F₁, F₂} end -""" - DegeneracyStructure{N} - -Degeneracy-dependent structure of a `HomSpace`: the block sizes, ranges, and sub-block -strides that depend on the degeneracy (multiplicity) dimensions. Specific to a given -`HomSpace` instance. - -## Fields -- `totaldim`: total number of elements in the flat data vector. -- `blockstructure`: `Vector` of `((d₁, d₂), range)` values, one per coupled sector, in the - same order as [`sectorstructure`](@ref)`.blocksectors`. -- `subblockstructure`: `Vector` of [`StridedStructure`](@ref) `(sizes, strides, offset)` - values, one per fusion tree pair, in the same order as - [`sectorstructure`](@ref)`.fusiontrees`. +function _blocksectors(W::HomSpace) + sectortype(W) === Trivial && + return OneOrNoneIterator(dim(domain(W)) != 0 && dim(codomain(W)) != 0, Trivial()) -See also [`degeneracystructure`](@ref), [`SectorStructure`](@ref). -""" -struct DegeneracyStructure{N} - totaldim::Int - blockstructure::Vector{Tuple{Tuple{Int, Int}, UnitRange{Int}}} - subblockstructure::Vector{StridedStructure{N}} -end - -function degeneracystructuretype(W::HomSpace) - N = length(codomain(W)) + length(domain(W)) - return DegeneracyStructure{N} + return sort!(intersect(blocksectors(codomain(W)), blocksectors(domain(W)))) end @doc """ @@ -128,6 +107,35 @@ end CacheStyle(::typeof(sectorstructure), ::Hashed{S}) where {S <: HomSpace} = GlobalLRUCache() + +""" + DegeneracyStructure{N} + +Degeneracy-dependent structure of a `HomSpace`: the block sizes, ranges, and sub-block +strides that depend on the degeneracy (multiplicity) dimensions. Specific to a given +`HomSpace` instance. + +## Fields +- `totaldim`: total number of elements in the flat data vector. +- `blockstructure`: `Vector` of `((d₁, d₂), range)` values, one per coupled sector, in the + same order as [`sectorstructure`](@ref)`.blocksectors`. +- `subblockstructure`: `Vector` of [`StridedStructure`](@ref) `(sizes, strides, offset)` + values, one per fusion tree pair, in the same order as + [`sectorstructure`](@ref)`.fusiontrees`. + +See also [`degeneracystructure`](@ref), [`SectorStructure`](@ref). +""" +struct DegeneracyStructure{N} + totaldim::Int + blockstructure::Vector{Tuple{Tuple{Int, Int}, UnitRange{Int}}} + subblockstructure::Vector{StridedStructure{N}} +end + +function degeneracystructuretype(W::HomSpace) + N = length(codomain(W)) + length(domain(W)) + return DegeneracyStructure{N} +end + @doc """ degeneracystructure(W::HomSpace) -> DegeneracyStructure @@ -226,23 +234,6 @@ CacheStyle(::typeof(degeneracystructure), ::HomSpace) = GlobalLRUCache() # Public API: combining the two caches #-------------------------------------- -""" - fusiontrees(W::HomSpace) -> Indices{Tuple{F₁,F₂}} - -Return an `Indices` of all valid fusion tree pairs `(f₁, f₂)` for `W`, providing a -bijection to sequential integer positions via `gettoken`/`gettokenvalue`. The result is -cached based on the sector structure of `W` (ignoring degeneracy dimensions), so -`HomSpace`s that share the same sectors, dualities, and index count will reuse the same -object. - -See also [`sectorstructure`](@ref), [`subblockstructure`](@ref). -""" -fusiontrees(W::HomSpace) = sectorstructure(W).fusiontrees - -blocksectors(W::HomSpace) = sectorstructure(W).blocksectors - -dim(W::HomSpace) = degeneracystructure(W).totaldim - """ blockstructure(W::HomSpace) -> Dictionary