From f61db64020015e8029c52cf61ecd91203b36f147 Mon Sep 17 00:00:00 2001 From: Amogh Date: Sun, 22 Mar 2026 07:43:02 +0530 Subject: [PATCH 1/2] feat: Add GatedGraphConv support for HeteroGraphConv --- GNNlib/src/layers/conv.jl | 25 +++++++++++++------ GraphNeuralNetworks/test/layers/heteroconv.jl | 8 ++++++ 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl index 0827b4592..b661d1157 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -215,19 +215,28 @@ end ####################### GatedGraphConv ###################################### -function gated_graph_conv(l, g::GNNGraph, x::AbstractMatrix) +function gated_graph_conv(l, g::AbstractGNNGraph, x) check_num_nodes(g, x) - m, n = size(x) + xj, xi = expand_srcdst(g, x) + + h = xi + m, n = size(h) @assert m <= l.dims "number of input features must be less or equal to output features." if m < l.dims - xpad = zeros_like(x, (l.dims - m, n)) - x = vcat(x, xpad) + xpad = zeros_like(h, (l.dims - m, n)) + h = vcat(h, xpad) end - h = x + + mj, nj = size(xj) + if mj < l.dims + xpad = zeros_like(xj, (l.dims - mj, nj)) + xj = vcat(xj, xpad) + end + for i in 1:(l.num_layers) - m = view(l.weight, :, :, i) * h - m = propagate(copy_xj, g, l.aggr; xj = m) - _, h = l.gru(m, h) + msg = view(l.weight, :, :, i) * xj + msg = propagate(copy_xj, g, l.aggr; xj = msg) + _, h = l.gru(msg, h) end return h end diff --git a/GraphNeuralNetworks/test/layers/heteroconv.jl b/GraphNeuralNetworks/test/layers/heteroconv.jl index 4a6d40ca3..0d3c1b4a2 100644 --- a/GraphNeuralNetworks/test/layers/heteroconv.jl +++ b/GraphNeuralNetworks/test/layers/heteroconv.jl @@ -158,4 +158,12 @@ y = layers(g, x); @test size(y.A) == (2,2) && size(y.B) == (2,3) end + + @testset "GatedGraphConv" begin + x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3)) + layers = HeteroGraphConv((:A, :to, :B) => GatedGraphConv(4, 2), + (:B, :to, :A) => GatedGraphConv(4, 2)); + y = layers(hg, x); + @test size(y.A) == (4, 2) && size(y.B) == (4, 3) + end end From 63107b73967f2a1dfc25177cfe891923e763fc69 Mon Sep 17 00:00:00 2001 From: Amogh Date: Sun, 22 Mar 2026 22:05:25 +0530 Subject: [PATCH 2/2] Optimizations --- GNNlib/src/layers/conv.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl index b661d1157..d31dbd3af 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -227,10 +227,14 @@ function gated_graph_conv(l, g::AbstractGNNGraph, x) h = vcat(h, xpad) end - mj, nj = size(xj) - if mj < l.dims - xpad = zeros_like(xj, (l.dims - mj, nj)) - xj = vcat(xj, xpad) + if xj !== xi + mj, nj = size(xj) + if mj < l.dims + xpad = zeros_like(xj, (l.dims - mj, nj)) + xj = vcat(xj, xpad) + end + else + xj = h end for i in 1:(l.num_layers)