From 51e1b80443b0205c7c88840885900bcecd956cea Mon Sep 17 00:00:00 2001 From: Amogh Date: Tue, 24 Mar 2026 08:43:10 +0530 Subject: [PATCH 1/3] Adding NNConv support for HeteroGraphConv --- GNNGraphs/src/gnnheterograph/query.jl | 10 ++++++++++ GNNlib/src/layers/conv.jl | 7 ++++--- GraphNeuralNetworks/src/layers/conv.jl | 2 ++ GraphNeuralNetworks/test/layers/heteroconv.jl | 17 ++++++++++++++++- 4 files changed, 32 insertions(+), 4 deletions(-) diff --git a/GNNGraphs/src/gnnheterograph/query.jl b/GNNGraphs/src/gnnheterograph/query.jl index 831e654eb..b5a953ced 100644 --- a/GNNGraphs/src/gnnheterograph/query.jl +++ b/GNNGraphs/src/gnnheterograph/query.jl @@ -89,3 +89,13 @@ function graph_indicator(g::GNNHeteroGraph, node_t::Symbol) end return gi end + +edge_features(g::GNNHeteroGraph, edge_t::EType) = begin + ds == g.edata[edge_t] + isempty(ds) ? nothing : first(values(ds)) +end + +edge_features(g::GNNHeteroGraph) = begin + ds = only(values(g.edata)) + isempty(ds) ? nothing : first(values(ds)) +end \ No newline at end of file diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl index 0827b4592..fdbefc3f3 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -257,11 +257,12 @@ end ####################### NNConv ###################################### -function nn_conv(l, g::GNNGraph, x::AbstractMatrix, e) +function nn_conv(l, g::AbstractGNNGraph, x, e) check_num_nodes(g, x) + xj, xi = expand_srcdst(g, x) message = Fix1(nn_conv_message, l) - m = propagate(message, g, l.aggr, xj = x, e = e) - return l.σ.(l.weight * x .+ m .+ l.bias) + m = propagate(message, g, l.aggr, xj = xj, e = e) + return l.σ.(l.weight * xi .+ m .+ l.bias) end function nn_conv_message(l, xi, xj, e) diff --git a/GraphNeuralNetworks/src/layers/conv.jl b/GraphNeuralNetworks/src/layers/conv.jl index e3cf30fea..a085ba848 100644 --- a/GraphNeuralNetworks/src/layers/conv.jl +++ b/GraphNeuralNetworks/src/layers/conv.jl @@ -718,6 +718,8 @@ end (l::NNConv)(g, x, e) = GNNlib.nn_conv(l, g, x, e) +(l::NNConv)(g, x) = GNNlib.nn_conv(l, g, x, edge_features(g)) + (l::NNConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g))) function Base.show(io::IO, l::NNConv) diff --git a/GraphNeuralNetworks/test/layers/heteroconv.jl b/GraphNeuralNetworks/test/layers/heteroconv.jl index 4a6d40ca3..5a888698e 100644 --- a/GraphNeuralNetworks/test/layers/heteroconv.jl +++ b/GraphNeuralNetworks/test/layers/heteroconv.jl @@ -155,7 +155,22 @@ x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3)) layers = HeteroGraphConv( (:A, :to, :B) => GCNConv(4 => 2, tanh), (:B, :to, :A) => GCNConv(4 => 2, tanh)); - y = layers(g, x); + y = layers(g, x); @test size(y.A) == (2,2) && size(y.B) == (2,3) end + + @testset "NNConv" begin + nA, nB = 2, 3 + nedges = 6 + g = rand_bipartite_heterograph((nA, nB), nedges; + edata = Dict((:A, :to, :B) => (e = rand(Float32, 5, nedges),), + (:B, :to, :A) => (e = rand(Float32, 5, nedges),)) + ) + x = (A = rand(Float32, 4, nA), B = rand(Float32, 4, nB)) + nn = Dense(5 => 4 * 2) + layers = HeteroGraphConv((:A, :to, :B) => NNConv(4 => 2, nn, tanh), + (:B, :to, :A) => NNConv(4 => 2, nn, tanh)) + y = layers(g, x) + @test size(y.A) == (2, nA) && size(y.B) == (2, nB) + end end From 0d09d2a4862e21ea22a8c43e2a3aed98108ecc39 Mon Sep 17 00:00:00 2001 From: Amogh Date: Tue, 24 Mar 2026 21:24:13 +0530 Subject: [PATCH 2/3] Refactoring --- GNNGraphs/src/gnnheterograph/query.jl | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/GNNGraphs/src/gnnheterograph/query.jl b/GNNGraphs/src/gnnheterograph/query.jl index b5a953ced..92deae9d8 100644 --- a/GNNGraphs/src/gnnheterograph/query.jl +++ b/GNNGraphs/src/gnnheterograph/query.jl @@ -90,12 +90,19 @@ function graph_indicator(g::GNNHeteroGraph, node_t::Symbol) return gi end -edge_features(g::GNNHeteroGraph, edge_t::EType) = begin - ds == g.edata[edge_t] - isempty(ds) ? nothing : first(values(ds)) -end - edge_features(g::GNNHeteroGraph) = begin - ds = only(values(g.edata)) - isempty(ds) ? nothing : first(values(ds)) + if isempty(g.edata) + return nothing + elseif length(g.edata) > 1 + @error "Multiple edge feature arrays, access directly through `g.edata`" + else + ds = only(values(g.edata)) + if isempty(ds) + return nothing + elseif length(ds) > 1 + @error "Multiple edge feature arrays, access directly through `g.edata`" + else + return first(values(ds)) + end + end end \ No newline at end of file From 032a220c0b09a17c1be058bc5754374f528ea043 Mon Sep 17 00:00:00 2001 From: Amogh Date: Thu, 26 Mar 2026 04:58:20 +0530 Subject: [PATCH 3/3] Aligning with changes in GMMConv --- GNNGraphs/src/gnnheterograph/query.jl | 29 ++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/GNNGraphs/src/gnnheterograph/query.jl b/GNNGraphs/src/gnnheterograph/query.jl index 92deae9d8..6fdad7735 100644 --- a/GNNGraphs/src/gnnheterograph/query.jl +++ b/GNNGraphs/src/gnnheterograph/query.jl @@ -90,19 +90,30 @@ function graph_indicator(g::GNNHeteroGraph, node_t::Symbol) return gi end -edge_features(g::GNNHeteroGraph) = begin +""" + edge_features(g::GNNHeteroGraph) + +Return the edge features for a heterogeneous graph with a single edge type. +If the graph has multiple edge types, this will error. +If no edge features are present, returns `nothing`. +""" +function edge_features(g::GNNHeteroGraph) if isempty(g.edata) return nothing elseif length(g.edata) > 1 - @error "Multiple edge feature arrays, access directly through `g.edata`" + @error "Multiple edge types present, access edge features directly through `g.edata[edge_t]`" else - ds = only(values(g.edata)) - if isempty(ds) - return nothing - elseif length(ds) > 1 - @error "Multiple edge feature arrays, access directly through `g.edata`" - else - return first(values(ds)) + edata = only(values(g.edata)) + if edata isa AbstractArray + return isempty(edata) ? nothing : edata + else + if isempty(edata) + return nothing + elseif length(edata) > 1 + @error "Multiple edge feature arrays present, access directly through `g.edata[edge_t]`" + else + return first(values(edata)) + end end end end \ No newline at end of file