Skip to content

Commit

Permalink
fix lstm
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Dec 7, 2024
1 parent de896d5 commit 82c28b0
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 20 deletions.
4 changes: 2 additions & 2 deletions GNNLux/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "GNNLux"
uuid = "e8545f4d-a905-48ac-a8c4-ca114b98986d"
authors = ["Carlo Lucibello and contributors"]
version = "0.1.0"
version = "0.2.0-DEV"

[deps]
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
Expand All @@ -18,7 +18,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[compat]
ConcreteStructs = "0.2.3"
GNNGraphs = "1.3"
GNNlib = "0.2.3"
GNNlib = "1"
Lux = "1"
LuxCore = "1"
NNlib = "0.9.21"
Expand Down
2 changes: 1 addition & 1 deletion GNNlib/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "GNNlib"
uuid = "a6a84749-d869-43f8-aacc-be26a1996e48"
authors = ["Carlo Lucibello and contributors"]
version = "0.2.3"
version = "1.0.0-DEV"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
5 changes: 4 additions & 1 deletion GNNlib/src/layers/pool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,11 @@ topk_index(y::Adjoint, k::Int) = topk_index(y', k)
function set2set_pool(l, g::GNNGraph, x::AbstractMatrix)
n_in = size(x, 1)
qstar = zeros_like(x, (2*n_in, g.num_graphs))
h = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2))
c = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2))
for t in 1:l.num_iters
q = l.lstm(qstar) # [n_in, n_graphs]
h, c = l.lstm(qstar, (h, c)) # [n_in, n_graphs]
q = h
qn = broadcast_nodes(g, q) # [n_in, n_nodes]
α = softmax_nodes(g, sum(qn .* x, dims = 1)) # [1, n_nodes]
r = reduce_nodes(+, g, x .* α) # [n_in, n_graphs]
Expand Down
4 changes: 2 additions & 2 deletions GraphNeuralNetworks/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[compat]
ChainRulesCore = "1"
Flux = "0.15"
GNNGraphs = "1.0"
GNNlib = "0.2"
GNNGraphs = "1"
GNNlib = "1"
LinearAlgebra = "1"
MLUtils = "0.4"
MacroTools = "0.5"
Expand Down
21 changes: 8 additions & 13 deletions GraphNeuralNetworks/src/layers/pool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,24 +149,19 @@ end
Flux.@layer Set2Set

function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1)
@assert n_layers >= 1
@assert n_layers == 1 "multiple layers not implemented yet" #TODO
n_out = 2 * n_in

if n_layers == 1
lstm = LSTM(n_out => n_in)
else
layers = [LSTM(n_out => n_in)]
for _ in 2:n_layers
push!(layers, LSTM(n_in => n_in))
end
lstm = Chain(layers...)
end

lstm = LSTMCell(n_out => n_in)
return Set2Set(lstm, n_iters)
end

function initialstates(cell::LSTMCell)
h = zeros_like(cell.Wh, size(cell.Wh, 2))
c = zeros_like(cell.Wh, size(cell.Wh, 2))
return h, c
end

function (l::Set2Set)(g, x)
Flux.reset!(l.lstm)
return GNNlib.set2set_pool(l, g, x)
end

Expand Down
2 changes: 1 addition & 1 deletion GraphNeuralNetworks/test/layers/pool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ end

n_in = 3
n_iters = 2
n_layers = 1
n_layers = 1 #TODO test with more layers
g = batch([rand_graph(10, 40, graph_type = GRAPH_T) for _ in 1:5])
g = GNNGraph(g, ndata = rand(Float32, n_in, g.num_nodes))
l = Set2Set(n_in, n_iters, n_layers)
Expand Down

0 comments on commit 82c28b0

Please sign in to comment.