From 2a3b186070b3a5a9d79e0e719e688f5426478435 Mon Sep 17 00:00:00 2001 From: aurorarossi Date: Tue, 14 Mar 2023 11:09:20 +0100 Subject: [PATCH 1/8] Add weightandsum pool --- src/layers/pool.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/layers/pool.jl b/src/layers/pool.jl index ef2627534..023ee55db 100644 --- a/src/layers/pool.jl +++ b/src/layers/pool.jl @@ -143,3 +143,14 @@ function topk_index(y::AbstractVector, k::Int) end topk_index(y::Adjoint, k::Int) = topk_index(y', k) + +struct WeigthAndSumPool + in_feats::Int +end + +function (ws::WeigthAndSumPool)(g::GNNGraph, x::AbstractArray) + atom_weighting = Dense(ws.in_feats, 1, sigmoid; bias = true) + return reduce_nodes(+, g, atom_weighting(x) .* x ) +end + +(ws::WeigthAndSumPool)(g::GNNGraph) = GNNGraph(g, gdata = ws(g, node_features(g))) \ No newline at end of file From 5a91eef9f68dbfed426ca17fe477e0c6bb797a10 Mon Sep 17 00:00:00 2001 From: aurorarossi Date: Tue, 14 Mar 2023 11:09:33 +0100 Subject: [PATCH 2/8] Export layer --- src/GraphNeuralNetworks.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl index 06dfa178a..ea97f8558 100644 --- a/src/GraphNeuralNetworks.jl +++ b/src/GraphNeuralNetworks.jl @@ -71,6 +71,7 @@ export GlobalAttentionPool, TopKPool, topk_index, + WeigthAndSumPool, # mldatasets mldataset2gnngraph From fd41210b1d45590f8be0e2cfbb7385263565f839 Mon Sep 17 00:00:00 2001 From: aurorarossi Date: Tue, 14 Mar 2023 13:21:42 +0100 Subject: [PATCH 3/8] Add CUDA support and docstring --- src/layers/pool.jl | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/layers/pool.jl b/src/layers/pool.jl index 023ee55db..94bea3377 100644 --- a/src/layers/pool.jl +++ b/src/layers/pool.jl @@ -144,13 +144,23 @@ end topk_index(y::Adjoint, k::Int) = topk_index(y', k) +""" + WeigthAndSumPool(in_feats) + +WeigthAndSum sum pooling layer. Compute the weights for each node and perform a weighted sum. +""" struct WeigthAndSumPool in_feats::Int end function (ws::WeigthAndSumPool)(g::GNNGraph, x::AbstractArray) atom_weighting = Dense(ws.in_feats, 1, sigmoid; bias = true) - return reduce_nodes(+, g, atom_weighting(x) .* x ) + return reduce_nodes(+, g, atom_weighting(x) .* x) +end + +function (ws::WeigthAndSumPool)(g::GNNGraph, x::CuArray) + atom_weighting = Dense(ws.in_feats, 1, sigmoid; bias = true) |> gpu + return reduce_nodes(+, g, atom_weighting(x) .* x) end (ws::WeigthAndSumPool)(g::GNNGraph) = GNNGraph(g, gdata = ws(g, node_features(g))) \ No newline at end of file From 85752abf00e8f489fde42d9056a34aba2f5df444 Mon Sep 17 00:00:00 2001 From: aurorarossi Date: Tue, 14 Mar 2023 13:22:06 +0100 Subject: [PATCH 4/8] Add test `WeightAndSum` --- test/layers/pool.jl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/test/layers/pool.jl b/test/layers/pool.jl index 69d848ced..9300904c8 100644 --- a/test/layers/pool.jl +++ b/test/layers/pool.jl @@ -61,4 +61,20 @@ @test topk_index(X, 4) == [1, 2, 3, 4] @test topk_index(X', 4) == [1, 2, 3, 4] end + + @testset "WeigthAndSumPool" begin + n = 3 + chin = 5 + ng = 3 + + ws = WeigthAndSumPool(chin) + g = GNNGraph(rand_graph(n, 4), ndata = rand(Float32, chin, n), graph_type = GRAPH_T) + + test_layer(ws, g, rtol = 1e-5, outtype = :graph,outsize = (chin, 1)) + g_batch = Flux.batch([GNNGraph(rand_graph(n, 4), + ndata = rand(Float32, chin, n), + graph_type = GRAPH_T) + for i in 1:ng]) + test_layer(ws, g_batch, rtol = 1e-5,outtype = :graph, outsize = (chin, ng)) + end end From d9f3171e1ed77dbce581c8e9dc1f3c2dc4c67c9f Mon Sep 17 00:00:00 2001 From: aurorarossi Date: Sun, 26 Mar 2023 22:21:06 +0200 Subject: [PATCH 5/8] Move Dense layer --- src/layers/pool.jl | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/layers/pool.jl b/src/layers/pool.jl index 94bea3377..1e73f0950 100644 --- a/src/layers/pool.jl +++ b/src/layers/pool.jl @@ -151,15 +151,25 @@ WeigthAndSum sum pooling layer. Compute the weights for each node and perform a """ struct WeigthAndSumPool in_feats::Int + dense_layer::Dense +end + +@functor WeigthAndSumPool + +Flux.trainable(ws::WeigthAndSumPool) = (ws.dense_layer) + +function WeigthAndSumPool(in_feats::Int) + dense_layer = Dense(in_feats, 1, sigmoid; bias = true) + WeigthAndSumPool(in_feats, dense_layer) end function (ws::WeigthAndSumPool)(g::GNNGraph, x::AbstractArray) - atom_weighting = Dense(ws.in_feats, 1, sigmoid; bias = true) + atom_weighting = ws.dense_layer return reduce_nodes(+, g, atom_weighting(x) .* x) end function (ws::WeigthAndSumPool)(g::GNNGraph, x::CuArray) - atom_weighting = Dense(ws.in_feats, 1, sigmoid; bias = true) |> gpu + atom_weighting = ws.dense_layer |> gpu return reduce_nodes(+, g, atom_weighting(x) .* x) end From 4921e9f3c0a8917514955fb14a6632ff01b2f65e Mon Sep 17 00:00:00 2001 From: aurorarossi Date: Sun, 26 Mar 2023 22:30:39 +0200 Subject: [PATCH 6/8] Improve docstring --- src/layers/pool.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/layers/pool.jl b/src/layers/pool.jl index 1e73f0950..c3b7e8a18 100644 --- a/src/layers/pool.jl +++ b/src/layers/pool.jl @@ -147,7 +147,8 @@ topk_index(y::Adjoint, k::Int) = topk_index(y', k) """ WeigthAndSumPool(in_feats) -WeigthAndSum sum pooling layer. Compute the weights for each node and perform a weighted sum. +WeigthAndSum sum pooling layer. +Takes a graph and the node features as inputs, computes the weights for each node and perform a weighted sum. """ struct WeigthAndSumPool in_feats::Int From 191463018ac067a5de30b87df32060a11567f5d6 Mon Sep 17 00:00:00 2001 From: aurorarossi Date: Tue, 28 Mar 2023 11:21:11 +0200 Subject: [PATCH 7/8] Fix and add example --- src/layers/pool.jl | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/layers/pool.jl b/src/layers/pool.jl index c3b7e8a18..b7b581cf6 100644 --- a/src/layers/pool.jl +++ b/src/layers/pool.jl @@ -149,6 +149,18 @@ topk_index(y::Adjoint, k::Int) = topk_index(y', k) WeigthAndSum sum pooling layer. Takes a graph and the node features as inputs, computes the weights for each node and perform a weighted sum. + +# Example + +```julia +n = 3 +chin = 5 + +ws = WeigthAndSumPool(chin) +g = GNNGraph(rand_graph(3, 4), ndata = rand(Float32, chin, 3), graph_type = GRAPH_T) + +u = ws(g, g.ndata.x) +``` """ struct WeigthAndSumPool in_feats::Int @@ -157,8 +169,6 @@ end @functor WeigthAndSumPool -Flux.trainable(ws::WeigthAndSumPool) = (ws.dense_layer) - function WeigthAndSumPool(in_feats::Int) dense_layer = Dense(in_feats, 1, sigmoid; bias = true) WeigthAndSumPool(in_feats, dense_layer) From 9d0e12d968f6f811758f51f8a9db1d20abec3e5e Mon Sep 17 00:00:00 2001 From: Aurora Rossi <65721467+aurorarossi@users.noreply.github.com> Date: Fri, 18 Aug 2023 21:51:50 +0200 Subject: [PATCH 8/8] Fix bigger graph Co-authored-by: Carlo Lucibello --- src/layers/pool.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/pool.jl b/src/layers/pool.jl index 16ae5d49c..b03fbf59f 100644 --- a/src/layers/pool.jl +++ b/src/layers/pool.jl @@ -157,7 +157,7 @@ n = 3 chin = 5 ws = WeigthAndSumPool(chin) -g = GNNGraph(rand_graph(3, 4), ndata = rand(Float32, chin, 3), graph_type = GRAPH_T) +g = GNNGraph(rand_graph(30, 50), ndata = rand(Float32, chin, 30)) u = ws(g, g.ndata.x) ```