Skip to content

Commit

Permalink
Implement shared memory multithreaded sheaf laplacians
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerhanks committed Dec 18, 2024
1 parent 499b183 commit 9ef27d8
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 6 deletions.
5 changes: 4 additions & 1 deletion src/DistributedSheaves.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ function random_distributed_sheaf(num_nodes, edge_probability, restriction_map_d
# Spawn sheaf nodes on every worker
node_refs = Future[]
for w in workers
push!(node_refs, @spawnat w SheafNode(w, n,
push!(node_refs, @spawnat w DistributedSheafNode(w, n,
Dict{Int32, SparseMatrixCSC{Float32, Int32}}(), # restriction maps
#Dict{Int32, Matrix{Float32}}(),
Dict{Int32, RemoteChannel}(), # inbound channels
Dict{Int32, RemoteChannel}(), rand(n))) # outbound channels, initial state
end
Expand All @@ -59,6 +60,8 @@ function random_distributed_sheaf(num_nodes, edge_probability, restriction_map_d
# A should live on proc i and B should live on proc j
Aref = @spawnat i sprand(n,n,p)
Bref = @spawnat j sprand(n,n,p)
#Aref = @spawnat i rand(n,n)
#Bref = @spawnat j rand(n,n)

remote_do(node_ref -> fetch(node_ref).neighbors[j] = fetch(Aref), i, node_refs[i-1])
remote_do(node_ref -> fetch(node_ref).neighbors[i] = fetch(Bref), j, node_refs[j-1])
Expand Down
21 changes: 16 additions & 5 deletions src/SheafNodes.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,35 @@
module SheafNodes

export DistributedSheafNode, ThreadedSheafNode, add_neighbor!, neighbors

using Distributed
using SparseArrays

# Common supertype for sheaf nodes that exchange edge messages with neighbors.
abstract type AbstractSheafNode end

# A sheaf node whose neighbors live on other threads of the same process.
# Communication happens over in-memory `Channel`s, one pair per incident edge,
# keyed by the neighbor's id.
mutable struct ThreadedSheafNode <: AbstractSheafNode
    id::Int32                               # this node's identifier
    dimension::Int32                        # stalk dimension (length of x)
    neighbors::Dict{Int32, AbstractMatrix}  # neighbor id => restriction map
    in_channels::Dict{Int32, Channel}       # neighbor id => inbound channel
    out_channels::Dict{Int32, Channel}      # neighbor id => outbound channel
    x::Vector{Float32}                      # current local state
end

# A sheaf node whose neighbors live on other worker processes.
# Communication happens over Distributed `RemoteChannel`s, one pair per
# incident edge, keyed by the neighbor's id.
mutable struct DistributedSheafNode <: AbstractSheafNode
    id::Int32                                # this node's identifier (worker id)
    dimension::Int32                         # stalk dimension (length of x)
    neighbors::Dict{Int32, AbstractMatrix}   # neighbor id => restriction map
    in_channels::Dict{Int32, RemoteChannel}  # neighbor id => inbound channel
    out_channels::Dict{Int32, RemoteChannel} # neighbor id => outbound channel
    x::Vector{Float32}                       # current local state
end

"""
    add_neighbor!(s::AbstractSheafNode, n_id::Int32, restriction_map)

Register `restriction_map` as the restriction map from `s`'s stalk onto the
edge shared with neighbor `n_id`, overwriting any existing entry.
"""
function add_neighbor!(s::AbstractSheafNode, n_id::Int32, restriction_map)
    return s.neighbors[n_id] = restriction_map
end

"""
    neighbors(s::AbstractSheafNode)

Return a `Vector` of the ids of `s`'s registered neighbors.
"""
function neighbors(s::AbstractSheafNode)
    return collect(keys(s.neighbors))
end

Expand Down
97 changes: 97 additions & 0 deletions src/ThreadedSheaves.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
include("SheafNodes.jl")
using .SheafNodes
using Base.Threads
using SparseArrays
using LinearAlgebra

# Perform one synchronous sheaf-Laplacian diffusion step for a single node.
#
# Protocol: every edge channel is primed with one message before the first
# step, so each call consumes exactly one inbound message per neighbor and
# publishes exactly one outbound message per neighbor, keeping the channels
# balanced for the next iteration.
function local_laplacian_step!(node, step_size)
    x_old = node.x
    # Accumulate in Float32 to match node.x and avoid promoting everything to
    # Float64 (the Float64 result was silently converted back on assignment).
    delta_x = zeros(Float32, node.dimension)

    for (n, rm) in node.neighbors
        outgoing_edge_val = rm * x_old
        # Blocks until the neighbor's value for this edge is available.
        incoming_edge_val = take!(node.in_channels[n])
        delta_x += rm' * (outgoing_edge_val - incoming_edge_val)
    end
    x_new = x_old - step_size * delta_x

    # Publish this node's updated edge values for the next iteration.
    for (n, rm) in node.neighbors
        put!(node.out_channels[n], rm * x_new)
    end

    node.x = x_new
    return x_new
end

# Run one synchronized Laplacian step on every node, one task per iteration
# chunk via Threads.@threads. Iterations are independent: each node only
# mutates its own state and talks over its own channels.
#
# `step_size` accepts any `Real`: the previous `::Float32` restriction raised
# a MethodError when callers passed a Float64 literal such as `0.1`.
function laplacian_step!(nodes, step_size::Real)
    Threads.@threads for node in nodes
        local_laplacian_step!(node, step_size)
    end
    return nothing
end

"""
    random_threaded_sheaf(num_nodes, edge_probability,
                          restriction_map_dimension, restriction_map_density)

Construct a random sheaf over an Erdős–Rényi graph on `num_nodes` vertices
where each edge appears independently with probability `edge_probability`.
Each restriction map is an `n`×`n` sparse Float32 matrix of density `p`,
with `n = restriction_map_dimension` and `p = restriction_map_density`.

Every edge gets a pair of size-1 `Channel`s primed with the initial edge
values, so `laplacian_step!` can be called immediately on the result.
"""
function random_threaded_sheaf(num_nodes, edge_probability, restriction_map_dimension, restriction_map_density)
    nodes = ThreadedSheafNode[]
    coin()::Bool = rand() < edge_probability
    n, p = restriction_map_dimension, restriction_map_density
    for i in 1:num_nodes
        # Draw the state as Float32 directly; the node fields are Float32, so
        # drawing Float64 would only force a conversion.
        push!(nodes, ThreadedSheafNode(i, n,
                                       Dict{Int32, AbstractMatrix}(),
                                       Dict{Int32, Channel}(),
                                       Dict{Int32, Channel}(), rand(Float32, n)))
    end

    for i in 1:num_nodes
        for j in i+1:num_nodes
            if coin()
                # Float32 maps keep every product in Float32 end to end.
                A = sprand(Float32, n, n, p)
                B = sprand(Float32, n, n, p)

                nodes[i].neighbors[j] = A
                nodes[j].neighbors[i] = B

                i_to_j_channel = Channel{Vector{Float32}}(1)
                j_to_i_channel = Channel{Vector{Float32}}(1)

                # Prime each channel with the initial edge value so the first
                # take! in local_laplacian_step! does not deadlock.
                nodes[i].in_channels[j] = j_to_i_channel
                nodes[i].out_channels[j] = i_to_j_channel
                put!(i_to_j_channel, A * nodes[i].x)

                nodes[j].in_channels[i] = i_to_j_channel
                nodes[j].out_channels[i] = j_to_i_channel
                put!(j_to_i_channel, B * nodes[j].x)
            end
        end
    end
    return nodes
end


# Total distance from consensus: for each node, sum over its incident edges
# the norm of the discrepancy between the incoming and outgoing edge values.
# Each edge contributes from both endpoints, so edges are counted twice.
function distance_from_consensus(nodes)
    total_distance = 0.0
    for node in nodes
        node_distance = 0.0
        # Pair channels by neighbor id instead of zipping the two Dicts,
        # which silently relied on both iterating in the same order.
        for (neighbor, in_channel) in node.in_channels
            out_channel = node.out_channels[neighbor]
            # fetch() peeks without consuming, so the channel protocol used by
            # local_laplacian_step! stays balanced.
            node_distance += norm(fetch(in_channel) - fetch(out_channel))
        end
        total_distance += node_distance
    end
    return total_distance
end

# Run `num_iters` synchronized Laplacian steps over `nodes` and return the
# distance from consensus measured after each step (a Vector of length
# `num_iters`). The original loop ran `1:num_iters+1` — an off-by-one that
# performed one extra iteration beyond what the argument requested.
function iterate_laplacian!(nodes, step_size, num_iters)
    distances = Float64[]

    for _ in 1:num_iters
        laplacian_step!(nodes, step_size)
        push!(distances, distance_from_consensus(nodes))
    end
    return distances
end

# Randomly reinitialize every node's state vector in place.
# NOTE(review): this mutates its argument; Julia convention would name it
# `random_initialization!`, but the name is kept for caller compatibility.
function random_initialization(nodes)
    for node in nodes
        # Draw Float32 directly so no Float64 -> Float32 conversion happens
        # when assigning to the Float32-typed `x` field.
        node.x = rand(Float32, node.dimension)
    end
    return nothing
end

0 comments on commit 9ef27d8

Please sign in to comment.