diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl
index 3b70747c8..a8549e59e 100644
--- a/GNNLux/src/layers/conv.jl
+++ b/GNNLux/src/layers/conv.jl
@@ -421,7 +421,7 @@ function Base.show(io::IO, l::AGNNConv)
     print(io, ")")
 end
 
-function (l::AGNNConv)(g, x::AbstractMatrix, ps, st)
+function (l::AGNNConv)(g, x, ps, st)
     β = l.trainable ? ps.β : l.init_beta
     m = (; β, l.add_self_loops)
     return GNNlib.agnn_conv(m, g, x), st
diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl
index 0827b4592..faad398c0 100644
--- a/GNNlib/src/layers/conv.jl
+++ b/GNNlib/src/layers/conv.jl
@@ -334,18 +334,25 @@
 end
 
 ####################### AGNNConv ######################################
 
-function agnn_conv(l, g::GNNGraph, x::AbstractMatrix)
+function agnn_conv(l, g::AbstractGNNGraph, x)
     check_num_nodes(g, x)
     if l.add_self_loops
         g = add_self_loops(g)
     end
 
-    xn = x ./ sqrt.(sum(x .^ 2, dims = 1))
-    cos_dist = apply_edges(xi_dot_xj, g, xi = xn, xj = xn)
+    xj, xi = expand_srcdst(g, x)
+
+    xi_n = xi ./ sqrt.(sum(xi .^ 2, dims = 1))
+    if xj !== xi
+        xj_n = xj ./ sqrt.(sum(xj .^ 2, dims = 1))
+    else
+        xj_n = xi_n
+    end
+    cos_dist = apply_edges(xi_dot_xj, g, xi = xi_n, xj = xj_n)
     α = softmax_edge_neighbors(g, l.β .* cos_dist)
-    x = propagate(g, +; xj = x, e = α) do xi, xj, α
-        α .* xj
+    x = propagate(g, +; xj, e = α) do xi, xj, α
+        α .* xj
     end
 
     return x
diff --git a/GraphNeuralNetworks/test/layers/conv.jl b/GraphNeuralNetworks/test/layers/conv.jl
index 96137a3ce..28bfc2008 100644
--- a/GraphNeuralNetworks/test/layers/conv.jl
+++ b/GraphNeuralNetworks/test/layers/conv.jl
@@ -422,7 +422,14 @@
         g.graph isa AbstractSparseMatrix && continue
         @test size(l(g, g.x)) == (D_IN, g.num_nodes)
         test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_gpu = true, compare_finite_diff = false)
-    end
+    end
+    l_bip = AGNNConv(add_self_loops=false)
+    s = [1, 1, 2, 3]
+    t = [1, 2, 1, 2]
+    g = GNNGraph((s, t)) |> gpu
+    x = (randn(Float32, D_IN, 3) |> gpu, randn(Float32, D_IN, 2) |> gpu)
+    y = l_bip(g, x)
+    @test size(y) == (D_IN, 2)
 end
 
 @testitem "MEGNetConv" setup=[TolSnippet, TestModule] begin
diff --git a/GraphNeuralNetworks/test/layers/heteroconv.jl b/GraphNeuralNetworks/test/layers/heteroconv.jl
index 4a6d40ca3..ed3ae826c 100644
--- a/GraphNeuralNetworks/test/layers/heteroconv.jl
+++ b/GraphNeuralNetworks/test/layers/heteroconv.jl
@@ -155,7 +155,15 @@
         x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3))
         layers = HeteroGraphConv( (:A, :to, :B) => GCNConv(4 => 2, tanh),
                                   (:B, :to, :A) => GCNConv(4 => 2, tanh));
-        y = layers(g, x);
+        y = layers(g, x);
         @test size(y.A) == (2,2) && size(y.B) == (2,3)
     end
+
+    @testset "AGNNConv" begin
+        x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3))
+        layers = HeteroGraphConv((:A, :to, :B) => AGNNConv(),
+                                 (:B, :to, :A) => AGNNConv())
+        y = layers(hg, x)
+        @test size(y.A) == (4, 2) && size(y.B) == (4, 3)
+    end
 end