diff --git a/GNNGraphs/src/gnnheterograph/query.jl b/GNNGraphs/src/gnnheterograph/query.jl
index 831e654eb..6fdad7735 100644
--- a/GNNGraphs/src/gnnheterograph/query.jl
+++ b/GNNGraphs/src/gnnheterograph/query.jl
@@ -89,3 +89,31 @@ function graph_indicator(g::GNNHeteroGraph, node_t::Symbol)
     end
     return gi
 end
+
+"""
+    edge_features(g::GNNHeteroGraph)
+
+Return the edge features for a heterogeneous graph with a single edge type.
+If the graph has multiple edge types, this will error.
+If no edge features are present, returns `nothing`.
+"""
+function edge_features(g::GNNHeteroGraph)
+    if isempty(g.edata)
+        return nothing
+    elseif length(g.edata) > 1
+        # throw (not just log) so the documented "this will error" contract holds
+        throw(ArgumentError("Multiple edge types present, access edge features directly through `g.edata[edge_t]`"))
+    else
+        edata = only(values(g.edata))
+        if edata isa AbstractArray
+            return isempty(edata) ? nothing : edata
+        else
+            if isempty(edata)
+                return nothing
+            elseif length(edata) > 1
+                throw(ArgumentError("Multiple edge feature arrays present, access directly through `g.edata[edge_t]`"))
+            else
+                return first(values(edata))
+            end
+        end
+    end
+end
\ No newline at end of file
diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl
index 3b70747c8..a4d531638 100644
--- a/GNNLux/src/layers/conv.jl
+++ b/GNNLux/src/layers/conv.jl
@@ -1446,7 +1446,9 @@ function LuxCore.parameterlength(l::GMMConv)
     return n
 end
 
-function (l::GMMConv)(g::GNNGraph, x, e, ps, st)
+(l::GMMConv)(g, x, ps, st) = l(g, x, edge_features(g), ps, st)
+
+function (l::GMMConv)(g::AbstractGNNGraph, x, e, ps, st)
     dense_x = StatefulLuxLayer{true}(l.dense_x, ps.dense_x, _getstate(st, :dense_x))
     m = (; ps.mu, ps.sigma_inv, dense_x, l.σ, l.ch, l.K, l.residual, bias = _getbias(ps))
     return GNNlib.gmm_conv(m, g, x, e), st
diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl
index 0827b4592..7bcdcdd11 100644
--- a/GNNlib/src/layers/conv.jl
+++ b/GNNlib/src/layers/conv.jl
@@ -369,29 +369,33 @@ end
 
 ####################### GMMConv ######################################
 
-function gmm_conv(l, g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
+function gmm_conv(l, g::AbstractGNNGraph, x, e)
     (nin, ein), out = l.ch #Notational Simplicity
-    @assert (ein == size(e)[1]&&g.num_edges == size(e)[2]) "Pseudo-cordinate dimension is not equal to (ein,num_edge)"
+    num_edges = g.num_edges isa Integer ? g.num_edges : only(values(g.num_edges))
+    size(e, 2) == num_edges || throw(DimensionMismatch("Expected $num_edges edges but got $(size(e, 2))"))
+    size(e, 1) == ein || throw(DimensionMismatch("Edge feature dimension mismatch: expected $ein, got $(size(e, 1))"))
 
-    num_edges = g.num_edges
+    xj, xi = expand_srcdst(g, x)
+
+    num_edges = size(e, 2)
 
     w = reshape(e, (ein, 1, num_edges))
     mu = reshape(l.mu, (ein, l.K, 1))
 
     w = @. -((w - mu)^2) / 2
     w = w .* reshape(l.sigma_inv .^ 2, (ein, l.K, 1))
-    w = exp.(sum(w, dims = 1)) # (1, K, num_edge) 
+    w = exp.(sum(w, dims = 1)) # (1, K, num_edge)
 
-    xj = reshape(l.dense_x(x), (out, l.K, :)) # (out, K, num_nodes) 
+    xj = reshape(l.dense_x(xj), (out, l.K, :)) # (out, K, num_nodes)
 
     m = propagate(e_mul_xj, g, mean, xj = xj, e = w)
-    m = dropdims(mean(m, dims = 2), dims = 2) # (out, num_nodes) 
+    m = dropdims(mean(m, dims = 2), dims = 2) # (out, num_nodes)
 
     m = l.σ.(m .+ l.bias)
 
     if l.residual
-        if size(x, 1) == size(m, 1)
-            m += x
+        if size(xi, 1) == size(m, 1)
+            m += xi
         else
             @warn "Residual not applied : output feature is not equal to input_feature"
         end
diff --git a/GraphNeuralNetworks/src/layers/conv.jl b/GraphNeuralNetworks/src/layers/conv.jl
index e3cf30fea..873ebe3c8 100644
--- a/GraphNeuralNetworks/src/layers/conv.jl
+++ b/GraphNeuralNetworks/src/layers/conv.jl
@@ -1135,7 +1135,7 @@ function GMMConv(ch::Pair{NTuple{2, Int}, Int},
     GMMConv(mu, sigma_inv, b, σ, ch, K, dense_x, residual)
 end
 
-(l::GMMConv)(g::GNNGraph, x, e) = GNNlib.gmm_conv(l, g, x, e)
+(l::GMMConv)(g::AbstractGNNGraph, x, e = edge_features(g)) = GNNlib.gmm_conv(l, g, x, e)
 
 (l::GMMConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g)))
 
diff --git a/GraphNeuralNetworks/test/layers/heteroconv.jl b/GraphNeuralNetworks/test/layers/heteroconv.jl
index 4a6d40ca3..3be610ec9 100644
--- a/GraphNeuralNetworks/test/layers/heteroconv.jl
+++ b/GraphNeuralNetworks/test/layers/heteroconv.jl
@@ -149,13 +149,25 @@
         y = layers(hg, x);
         @test size(y.A) == (2, 2) && size(y.B) == (2, 3)
     end
-    
+
     @testset "GCNConv" begin
         g = rand_bipartite_heterograph((2,3), 6)
         x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3))
         layers = HeteroGraphConv( (:A, :to, :B) => GCNConv(4 => 2, tanh),
                                   (:B, :to, :A) => GCNConv(4 => 2, tanh));
-        y = layers(g, x); 
+        y = layers(g, x);
         @test size(y.A) == (2,2) && size(y.B) == (2,3)
     end
+
+    @testset "GMMConv" begin
+        ein = 4
+        g = rand_bipartite_heterograph((2, 3), 6;
+               edata=Dict((:A, :to, :B) => (x=rand(Float32, ein, 6),),
+                          (:B, :to, :A) => (x=rand(Float32, ein, 6),)))
+        x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3))
+        layers = HeteroGraphConv((:A, :to, :B) => GMMConv((4, ein) => 2, K=2),
+                                 (:B, :to, :A) => GMMConv((4, ein) => 2, K=2));
+        y = layers(g, x);
+        @test size(y.A) == (2, 2) && size(y.B) == (2, 3)
+    end
 end