Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions GNNGraphs/src/gnnheterograph/query.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,31 @@ function graph_indicator(g::GNNHeteroGraph, node_t::Symbol)
end
return gi
end

"""
edge_features(g::GNNHeteroGraph)

Return the edge features for a heterogeneous graph with a single edge type.
If the graph has multiple edge types, this will error.
If no edge features are present, returns `nothing`.
"""
function edge_features(g::GNNHeteroGraph)
if isempty(g.edata)
return nothing
elseif length(g.edata) > 1
@error "Multiple edge types present, access edge features directly through `g.edata[edge_t]`"
else
edata = only(values(g.edata))
if edata isa AbstractArray
return isempty(edata) ? nothing : edata
else
if isempty(edata)
return nothing
elseif length(edata) > 1
@error "Multiple edge feature arrays present, access directly through `g.edata[edge_t]`"
else
return first(values(edata))
end
end
end
end
7 changes: 4 additions & 3 deletions GNNlib/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,11 +257,12 @@ end

####################### NNConv ######################################

function nn_conv(l, g::GNNGraph, x::AbstractMatrix, e)
function nn_conv(l, g::AbstractGNNGraph, x, e)
check_num_nodes(g, x)
xj, xi = expand_srcdst(g, x)
message = Fix1(nn_conv_message, l)
m = propagate(message, g, l.aggr, xj = x, e = e)
return l.σ.(l.weight * x .+ m .+ l.bias)
m = propagate(message, g, l.aggr, xj = xj, e = e)
return l.σ.(l.weight * xi .+ m .+ l.bias)
end

function nn_conv_message(l, xi, xj, e)
Expand Down
2 changes: 2 additions & 0 deletions GraphNeuralNetworks/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,8 @@ end

(l::NNConv)(g, x, e) = GNNlib.nn_conv(l, g, x, e)

(l::NNConv)(g, x) = GNNlib.nn_conv(l, g, x, edge_features(g))

(l::NNConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g)))

function Base.show(io::IO, l::NNConv)
Expand Down
17 changes: 16 additions & 1 deletion GraphNeuralNetworks/test/layers/heteroconv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,22 @@
x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3))
layers = HeteroGraphConv( (:A, :to, :B) => GCNConv(4 => 2, tanh),
(:B, :to, :A) => GCNConv(4 => 2, tanh));
y = layers(g, x);
y = layers(g, x);
@test size(y.A) == (2,2) && size(y.B) == (2,3)
end

@testset "NNConv" begin
nA, nB = 2, 3
nedges = 6
g = rand_bipartite_heterograph((nA, nB), nedges;
edata = Dict((:A, :to, :B) => (e = rand(Float32, 5, nedges),),
(:B, :to, :A) => (e = rand(Float32, 5, nedges),))
)
x = (A = rand(Float32, 4, nA), B = rand(Float32, 4, nB))
nn = Dense(5 => 4 * 2)
layers = HeteroGraphConv((:A, :to, :B) => NNConv(4 => 2, nn, tanh),
(:B, :to, :A) => NNConv(4 => 2, nn, tanh))
y = layers(g, x)
@test size(y.A) == (2, nA) && size(y.B) == (2, nB)
end
end
Loading