Skip to content

Commit

Permalink
Added Personalized PageRank Diffusion [ppr_diffusion function] (#427)
Browse files Browse the repository at this point in the history
* add ppr diffusion

* add ppr diffusion

* add function to GNNGraphs.jl

* :coo

* Update transform.jl

* try

* Made function non-mutating

uses SparseArrays

* Update src/GNNGraphs/transform.jl

rename args

Co-authored-by: Carlo Lucibello <[email protected]>

* Update test/GNNGraphs/transform.jl

clean code

Co-authored-by: Carlo Lucibello <[email protected]>

* Update test/GNNGraphs/transform.jl

remove unneeded line

Co-authored-by: Carlo Lucibello <[email protected]>

* Update src/GNNGraphs/transform.jl

args fix

Co-authored-by: Carlo Lucibello <[email protected]>

* Update src/GNNGraphs/transform.jl

rename var

Co-authored-by: Carlo Lucibello <[email protected]>

* empty weights

* indent

* fixes

* Update test/GNNGraphs/transform.jl

---------

Co-authored-by: Carlo Lucibello <[email protected]>
  • Loading branch information
rbSparky and CarloLucibello authored Jun 29, 2024
1 parent bcce0cf commit 3bcafbe
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/GNNGraphs/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ export add_nodes,
to_unidirected,
random_walk_pe,
remove_nodes,
ppr_diffusion,
drop_nodes,
# from Flux
batch,
Expand Down
46 changes: 46 additions & 0 deletions src/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1168,3 +1168,49 @@ ci2t(ci::AbstractVector{<:CartesianIndex}, dims) = ntuple(i -> map(x -> x[i], ci
@non_differentiable remove_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
@non_differentiable dense_zeros_like(x...)

"""
ppr_diffusion(g::GNNGraph{<:COO_T}, alpha =0.85f0) -> GNNGraph
Calculates the Personalized PageRank (PPR) diffusion based on the edge weight matrix of a GNNGraph and updates the graph with new edge weights derived from the PPR matrix.
References paper: [The pagerank citation ranking: Bringing order to the web](http://ilpubs.stanford.edu:8090/422)
The function performs the following steps:
1. Constructs a modified adjacency matrix `A` using the graph's edge weights, where `A` is adjusted by `(α - 1) * A + I`, with `α` being the damping factor (`alpha_f32`) and `I` the identity matrix.
2. Normalizes `A` to ensure each column sums to 1, representing transition probabilities.
3. Applies the PPR formula `α * (I + (α - 1) * A)^-1` to compute the diffusion matrix.
4. Updates the original edge weights of the graph based on the PPR diffusion matrix, assigning new weights for each edge from the PPR matrix.
# Arguments
- `g::GNNGraph`: The input graph for which PPR diffusion is to be calculated. It should have edge weights available.
- `alpha_f32::Float32`: The damping factor used in PPR calculation, controlling the teleport probability in the random walk. Defaults to `0.85f0`.
# Returns
- A new `GNNGraph` instance with the same structure as `g` but with updated edge weights according to the PPR diffusion calculation.
"""
function ppr_diffusion(g::GNNGraph{<:COO_T}; alpha = 0.85f0)
s, t = edge_index(g)
w = get_edge_weight(g)
if isnothing(w)
w = ones(Float32, g.num_edges)
end

N = g.num_nodes

initial_A = sparse(t, s, w, N, N)
scaled_A = (Float32(alpha) - 1) * initial_A

I_sparse = sparse(Diagonal(ones(Float32, N)))
A_sparse = I_sparse + scaled_A

A_dense = Matrix(A_sparse)

PPR = alpha * inv(A_dense)

new_w = [PPR[dst, src] for (src, dst) in zip(s, t)]

return GNNGraph((s, t, new_w),
g.num_nodes, length(s), g.num_graphs,
g.graph_indicator,
g.ndata, g.edata, g.gdata)
end
20 changes: 20 additions & 0 deletions test/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -595,4 +595,24 @@ end

@test g.graph[(:A, :to1, :A)][3] == vcat([2, 2, 2], fill(1, n))
end
end

@testset "ppr_diffusion" begin
if GRAPH_T == :coo
s = [1, 1, 2, 3]
t = [2, 3, 4, 5]
eweights = [0.1, 0.2, 0.3, 0.4]

g = GNNGraph(s, t, eweights)

g_new = ppr_diffusion(g)
w_new = get_edge_weight(g_new)

check_ew = Float32[0.012749999
0.025499998
0.038249996
0.050999995]

@test w_new check_ew
end
end

0 comments on commit 3bcafbe

Please sign in to comment.