diff --git a/GNNGraphs/src/gnnheterograph/query.jl b/GNNGraphs/src/gnnheterograph/query.jl
index 831e654eb..6fdad7735 100644
--- a/GNNGraphs/src/gnnheterograph/query.jl
+++ b/GNNGraphs/src/gnnheterograph/query.jl
@@ -89,3 +89,31 @@ function graph_indicator(g::GNNHeteroGraph, node_t::Symbol)
     end
     return gi
 end
+
+"""
+    edge_features(g::GNNHeteroGraph)
+
+Return the edge features for a heterogeneous graph with a single edge type.
+If the graph has multiple edge types, this will error.
+If no edge features are present, returns `nothing`.
+"""
+function edge_features(g::GNNHeteroGraph)
+    if isempty(g.edata)
+        return nothing
+    elseif length(g.edata) > 1
+        error("Multiple edge types present, access edge features directly through `g.edata[edge_t]`")
+    else
+        edata = only(values(g.edata))
+        if edata isa AbstractArray
+            return isempty(edata) ? nothing : edata
+        else
+            if isempty(edata)
+                return nothing
+            elseif length(edata) > 1
+                error("Multiple edge feature arrays present, access directly through `g.edata[edge_t]`")
+            else
+                return first(values(edata))
+            end
+        end
+    end
+end
\ No newline at end of file
diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl
index 0827b4592..fdbefc3f3 100644
--- a/GNNlib/src/layers/conv.jl
+++ b/GNNlib/src/layers/conv.jl
@@ -257,11 +257,12 @@
 
 ####################### NNConv ######################################
 
-function nn_conv(l, g::GNNGraph, x::AbstractMatrix, e)
+function nn_conv(l, g::AbstractGNNGraph, x, e)
     check_num_nodes(g, x)
+    xj, xi = expand_srcdst(g, x)
     message = Fix1(nn_conv_message, l)
-    m = propagate(message, g, l.aggr, xj = x, e = e)
-    return l.σ.(l.weight * x .+ m .+ l.bias)
+    m = propagate(message, g, l.aggr, xj = xj, e = e)
+    return l.σ.(l.weight * xi .+ m .+ l.bias)
 end
 
 function nn_conv_message(l, xi, xj, e)
diff --git a/GraphNeuralNetworks/src/layers/conv.jl b/GraphNeuralNetworks/src/layers/conv.jl
index e3cf30fea..a085ba848 100644
--- a/GraphNeuralNetworks/src/layers/conv.jl
+++ b/GraphNeuralNetworks/src/layers/conv.jl
@@ -718,6 +718,8 @@ end
 
 (l::NNConv)(g, x, e) = GNNlib.nn_conv(l, g, x, e)
 
+(l::NNConv)(g, x) = GNNlib.nn_conv(l, g, x, edge_features(g))
+
 (l::NNConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g)))
 
 function Base.show(io::IO, l::NNConv)
diff --git a/GraphNeuralNetworks/test/layers/heteroconv.jl b/GraphNeuralNetworks/test/layers/heteroconv.jl
index 4a6d40ca3..5a888698e 100644
--- a/GraphNeuralNetworks/test/layers/heteroconv.jl
+++ b/GraphNeuralNetworks/test/layers/heteroconv.jl
@@ -155,7 +155,22 @@
         x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3))
         layers = HeteroGraphConv( (:A, :to, :B) => GCNConv(4 => 2, tanh),
                                   (:B, :to, :A) => GCNConv(4 => 2, tanh));
-        y = layers(g, x); 
+        y = layers(g, x);
         @test size(y.A) == (2,2) && size(y.B) == (2,3)
     end
+
+    @testset "NNConv" begin
+        nA, nB = 2, 3
+        nedges = 6
+        g = rand_bipartite_heterograph((nA, nB), nedges;
+            edata = Dict((:A, :to, :B) => (e = rand(Float32, 5, nedges),),
+                         (:B, :to, :A) => (e = rand(Float32, 5, nedges),))
+        )
+        x = (A = rand(Float32, 4, nA), B = rand(Float32, 4, nB))
+        nn = Dense(5 => 4 * 2)
+        layers = HeteroGraphConv((:A, :to, :B) => NNConv(4 => 2, nn, tanh),
+                                 (:B, :to, :A) => NNConv(4 => 2, nn, tanh))
+        y = layers(g, x)
+        @test size(y.A) == (2, nA) && size(y.B) == (2, nB)
+    end
 end