From 503b06acc1ee58d74677954dce6138fd4ccfc821 Mon Sep 17 00:00:00 2001 From: Amogh Date: Sat, 21 Mar 2026 01:53:00 +0530 Subject: [PATCH 1/3] Updated self loop logic and tests --- GraphNeuralNetworks/test/layers/conv.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/GraphNeuralNetworks/test/layers/conv.jl b/GraphNeuralNetworks/test/layers/conv.jl index 96137a3ce..28bfc2008 100644 --- a/GraphNeuralNetworks/test/layers/conv.jl +++ b/GraphNeuralNetworks/test/layers/conv.jl @@ -422,7 +422,14 @@ end g.graph isa AbstractSparseMatrix && continue @test size(l(g, g.x)) == (D_IN, g.num_nodes) test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_gpu = true, compare_finite_diff = false) - end + end + l_bip = AGNNConv(add_self_loops=false) + s = [1, 1, 2, 3] + t = [1, 2, 1, 2] + g = GNNGraph((s, t)) |> gpu + x = (randn(Float32, D_IN, 3) |> gpu, randn(Float32, D_IN, 2) |> gpu) + y = l_bip(g, x) + @test size(y) == (D_IN, 2) end @testitem "MEGNetConv" setup=[TolSnippet, TestModule] begin From 4b5d41b4010c34aca988a9350361d2f7ab5a2bed Mon Sep 17 00:00:00 2001 From: Amogh Date: Sun, 22 Mar 2026 18:29:52 +0530 Subject: [PATCH 2/3] Restored interface --- GNNLux/src/layers/conv.jl | 2 +- GNNlib/src/layers/conv.jl | 21 ++++++++++++++----- GraphNeuralNetworks/test/layers/heteroconv.jl | 10 ++++++++- 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index 3b70747c8..a8549e59e 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -421,7 +421,7 @@ function Base.show(io::IO, l::AGNNConv) print(io, ")") end -function (l::AGNNConv)(g, x::AbstractMatrix, ps, st) +function (l::AGNNConv)(g, x, ps, st) β = l.trainable ? ps.β : l.init_beta m = (; β, l.add_self_loops) return GNNlib.agnn_conv(m, g, x), st diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl index 0827b4592..b101ff9b7 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -334,23 +334,34 @@ end ####################### AGNNConv ###################################### -function agnn_conv(l, g::GNNGraph, x::AbstractMatrix) +function agnn_conv(l, g::AbstractGNNGraph, x) check_num_nodes(g, x) if l.add_self_loops g = add_self_loops(g) end - xn = x ./ sqrt.(sum(x .^ 2, dims = 1)) - cos_dist = apply_edges(xi_dot_xj, g, xi = xn, xj = xn) + xj, xi = expand_srcdst(g, x) + + xi_n = xi ./ sqrt.(sum(xi .^ 2, dims = 1)) + xj_n = xj ./ sqrt.(sum(xj .^ 2, dims = 1)) + cos_dist = apply_edges(xi_dot_xj, g, xi = xi_n, xj = xj_n) α = softmax_edge_neighbors(g, l.β .* cos_dist) - x = propagate(g, +; xj = x, e = α) do xi, xj, α - α .* xj + x = propagate(g, +; xj, e = α) do xi, xj, α + α .* xj end return x end +""" + _has_same_node_types(g::GNNHeteroGraph) + +Return true if all edge types in the heterogeneous graph have the same source and +target node types (i.e., no bipartite relations). +""" +_has_same_node_types(g::GNNHeteroGraph) = all(et -> et[1] == et[3], g.etypes) + ####################### MegNetConv ###################################### function megnet_conv(l, g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix) diff --git a/GraphNeuralNetworks/test/layers/heteroconv.jl b/GraphNeuralNetworks/test/layers/heteroconv.jl index 4a6d40ca3..ed3ae826c 100644 --- a/GraphNeuralNetworks/test/layers/heteroconv.jl +++ b/GraphNeuralNetworks/test/layers/heteroconv.jl @@ -155,7 +155,15 @@ x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3)) layers = HeteroGraphConv( (:A, :to, :B) => GCNConv(4 => 2, tanh), (:B, :to, :A) => GCNConv(4 => 2, tanh)); - y = layers(g, x); + y = layers(g, x); @test size(y.A) == (2,2) && size(y.B) == (2,3) end + + @testset "AGNNConv" begin + x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3)) + layers = HeteroGraphConv((:A, :to, :B) => AGNNConv(), + (:B, :to, :A) => AGNNConv()) + y = layers(hg, x) + @test size(y.A) == (4, 2) && size(y.B) == (4, 3) + end end From f43f822576574a7d04c1f7c470d8c6687535b47e Mon Sep 17 00:00:00 2001 From: i-Amogh Date: Sun, 22 Mar 2026 21:33:13 +0530 Subject: [PATCH 3/3] Changes attended --- GNNlib/src/layers/conv.jl | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl index b101ff9b7..faad398c0 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -343,7 +343,11 @@ function agnn_conv(l, g::AbstractGNNGraph, x) xj, xi = expand_srcdst(g, x) xi_n = xi ./ sqrt.(sum(xi .^ 2, dims = 1)) - xj_n = xj ./ sqrt.(sum(xj .^ 2, dims = 1)) + if xj !== xi + xj_n = xj ./ sqrt.(sum(xj .^ 2, dims = 1)) + else + xj_n = xi_n + end cos_dist = apply_edges(xi_dot_xj, g, xi = xi_n, xj = xj_n) α = softmax_edge_neighbors(g, l.β .* cos_dist) @@ -354,14 +358,6 @@ function agnn_conv(l, g::AbstractGNNGraph, x) return x end -""" - _has_same_node_types(g::GNNHeteroGraph) - -Return true if all edge types in the heterogeneous graph have the same source and -target node types (i.e., no bipartite relations). -""" -_has_same_node_types(g::GNNHeteroGraph) = all(et -> et[1] == et[3], g.etypes) - ####################### MegNetConv ###################################### function megnet_conv(l, g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)