Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions GNNGraphs/src/gnnheterograph/query.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,31 @@ function graph_indicator(g::GNNHeteroGraph, node_t::Symbol)
end
return gi
end

"""
edge_features(g::GNNHeteroGraph)

Return the edge features for a heterogeneous graph with a single edge type.
If the graph has multiple edge types, this will error.
If no edge features are present, returns `nothing`.
"""
function edge_features(g::GNNHeteroGraph)
if isempty(g.edata)
return nothing
elseif length(g.edata) > 1
@error "Multiple edge types present, access edge features directly through `g.edata[edge_t]`"
else
edata = only(values(g.edata))
if edata isa AbstractArray
return isempty(edata) ? nothing : edata
else
if isempty(edata)
return nothing
elseif length(edata) > 1
@error "Multiple edge feature arrays present, access directly through `g.edata[edge_t]`"
else
return first(values(edata))
end
end
end
end
4 changes: 3 additions & 1 deletion GNNLux/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1446,7 +1446,9 @@ function LuxCore.parameterlength(l::GMMConv)
return n
end

function (l::GMMConv)(g::GNNGraph, x, e, ps, st)
(l::GMMConv)(g, x, ps, st) = l(g, x, edge_features(g), ps, st)

function (l::GMMConv)(g::AbstractGNNGraph, x, e, ps, st)
dense_x = StatefulLuxLayer{true}(l.dense_x, ps.dense_x, _getstate(st, :dense_x))
m = (; ps.mu, ps.sigma_inv, dense_x, l.σ, l.ch, l.K, l.residual, bias = _getbias(ps))
return GNNlib.gmm_conv(m, g, x, e), st
Expand Down
20 changes: 12 additions & 8 deletions GNNlib/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -369,29 +369,33 @@ end

####################### GMMConv ######################################

function gmm_conv(l, g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
function gmm_conv(l, g::AbstractGNNGraph, x, e)
(nin, ein), out = l.ch #Notational Simplicity

@assert (ein == size(e)[1]&&g.num_edges == size(e)[2]) "Pseudo-cordinate dimension is not equal to (ein,num_edge)"
num_edges = g.num_edges isa Integer ? g.num_edges : only(values(g.num_edges))
@assert size(e, 2) == num_edges "Expected $num_edges edges but got $(size(e, 2))"
@assert size(e, 1) == ein "Edge feature dimension mismatch: expected $ein, got $(size(e, 1))"

num_edges = g.num_edges
xj, xi = expand_srcdst(g, x)

num_edges = size(e, 2)
w = reshape(e, (ein, 1, num_edges))
mu = reshape(l.mu, (ein, l.K, 1))

w = @. -((w - mu)^2) / 2
w = w .* reshape(l.sigma_inv .^ 2, (ein, l.K, 1))
w = exp.(sum(w, dims = 1)) # (1, K, num_edge)
w = exp.(sum(w, dims = 1)) # (1, K, num_edge)

xj = reshape(l.dense_x(x), (out, l.K, :)) # (out, K, num_nodes)
xj = reshape(l.dense_x(xj), (out, l.K, :)) # (out, K, num_nodes)

m = propagate(e_mul_xj, g, mean, xj = xj, e = w)
m = dropdims(mean(m, dims = 2), dims = 2) # (out, num_nodes)
m = dropdims(mean(m, dims = 2), dims = 2) # (out, num_nodes)

m = l.σ.(m .+ l.bias)

if l.residual
if size(x, 1) == size(m, 1)
m += x
if size(xi, 1) == size(m, 1)
m += xi
else
@warn "Residual not applied : output feature is not equal to input_feature"
end
Expand Down
2 changes: 1 addition & 1 deletion GraphNeuralNetworks/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1135,7 +1135,7 @@ function GMMConv(ch::Pair{NTuple{2, Int}, Int},
GMMConv(mu, sigma_inv, b, σ, ch, K, dense_x, residual)
end

(l::GMMConv)(g::GNNGraph, x, e) = GNNlib.gmm_conv(l, g, x, e)
(l::GMMConv)(g::AbstractGNNGraph, x, e = edge_features(g)) = GNNlib.gmm_conv(l, g, x, e)

(l::GMMConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g)))

Expand Down
16 changes: 14 additions & 2 deletions GraphNeuralNetworks/test/layers/heteroconv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,25 @@
y = layers(hg, x);
@test size(y.A) == (2, 2) && size(y.B) == (2, 3)
end

@testset "GCNConv" begin
g = rand_bipartite_heterograph((2,3), 6)
x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3))
layers = HeteroGraphConv( (:A, :to, :B) => GCNConv(4 => 2, tanh),
(:B, :to, :A) => GCNConv(4 => 2, tanh));
y = layers(g, x);
y = layers(g, x);
@test size(y.A) == (2,2) && size(y.B) == (2,3)
end

@testset "GMMConv" begin
ein = 4
g = rand_bipartite_heterograph((2, 3), 6;
edata=Dict((:A, :to, :B) => (x=rand(Float32, ein, 6),),
(:B, :to, :A) => (x=rand(Float32, ein, 6),)))
x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3))
layers = HeteroGraphConv((:A, :to, :B) => GMMConv((4, ein) => 2, K=2),
(:B, :to, :A) => GMMConv((4, ein) => 2, K=2));
y = layers(g, x);
@test size(y.A) == (2, 2) && size(y.B) == (2, 3)
end
end
Loading