Add backtracking line search
tylerhanks committed Dec 19, 2024
1 parent 710cfa0 commit 2a045ff
Showing 1 changed file with 193 additions and 6 deletions.
199 changes: 193 additions & 6 deletions src/ThreadedSheaves.jl
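This commit adds an Armijo-style backtracking line search alongside the existing fixed-step diffusion: each iteration computes a per-node descent direction, then shrinks a trial step until the global consensus objective has decreased sufficiently. A minimal usage sketch of the new entry point, built from helpers defined in this file (the numeric arguments are illustrative assumptions, not values taken from the commit):

nodes = random_threaded_sheaf(50, 0.2, 3, 0.5)   # hypothetical sheaf parameters
random_initialization(nodes)
distances = iterate_laplacian!(nodes, 100)        # new method: step sizes chosen by line search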
@@ -4,16 +4,61 @@ using Base.Threads
using SparseArrays
using LinearAlgebra

#=function node_consensus_objective(node::ThreadedSheafNode, x, neighbor_xs::Dict{Int, Vector})
loss = 0.0
for (n, rm) in node.neighbors
outgoing_edge_val = rm*x
incoming_edge_val = neighbor_xs[n]
loss += LinearAlgebra.norm_sqr(outgoing_edge_val - incoming_edge_val)
end
return loss
end
function local_line_search(node::ThreadedSheafNode, x, delta_x, neighbor_xs::Dict{Int, Vector})
c = 0.5
τ = 0.5
m = (-delta_x)'*delta_x
t = c*m
a = .5
while node_consensus_objective(node, x, neighbor_xs) - node_consensus_objective(node, x + a*delta_x, neighbor_xs) < a*t
a = τ*a
end
println("Step size: $a")
return a
end
function local_laplacian_step!(node)
x_old = node.x
delta_x = zeros(node.dimension)
neighbor_xs = Dict{Int, Vector}()
for (n, rm) in node.neighbors
outgoing_edge_val = rm*x_old
incoming_edge_val = take!(node.in_channels[n])
neighbor_xs[n] = incoming_edge_val
delta_x -= rm'*(outgoing_edge_val - incoming_edge_val)
end
step_size = local_line_search(node, x_old, delta_x, neighbor_xs)
x_new = x_old + step_size*delta_x
for (n, rm) in node.neighbors
put!(node.out_channels[n], rm*x_new)
end
node.x = x_new
end=#


function local_laplacian_step!(node, step_size)
x_old = node.x
delta_x = zeros(node.dimension)

for (n, rm) in node.neighbors
outgoing_edge_val = rm*x_old
incoming_edge_val = take!(node.in_channels[n])
-        delta_x += rm'*(outgoing_edge_val - incoming_edge_val)
+        delta_x -= rm'*(outgoing_edge_val - incoming_edge_val)
end
-    x_new = x_old - step_size*delta_x
+    x_new = x_old + step_size*delta_x

for (n, rm) in node.neighbors
put!(node.out_channels[n], rm*x_new)
@@ -28,6 +73,107 @@ function laplacian_step!(nodes, step_size::Float32)
end
end

#=function laplacian_step!(nodes)
Threads.@threads for node in nodes
local_laplacian_step!(node)
end
end=#

# Compute the local update direction for a given node
function local_descent_direction(node)
x_old = node.x
delta_x = zeros(node.dimension)

for (n, rm) in node.neighbors
outgoing_edge_val = rm*x_old
incoming_edge_val = fetch(node.in_channels[n])
delta_x -= rm'*(outgoing_edge_val - incoming_edge_val)
end
return delta_x
end

function descent_direction!(nodes, results::Vector{Vector{Float32}})
# Allocate a shared memory array for results
#dimensions = (n -> n.dimension).nodes
#results = [Vector{Float32}(undef, d) for d in dimensions] # make a version that passes this in as an argument to override

Threads.@threads for i in eachindex(nodes)
results[i] = local_descent_direction(nodes[i])
end
end

function local_consensus_objective(node::ThreadedSheafNode, x::Vector{Float32})
loss = 0.0
for (n, rm) in node.neighbors
outgoing_edge_val = rm*x
incoming_edge_val = fetch(node.in_channels[n])
loss += LinearAlgebra.norm_sqr(outgoing_edge_val - incoming_edge_val)
end
return loss
end

function consensus_objective(nodes::Vector{ThreadedSheafNode}, xs::Vector{Vector{Float32}})
losses = Vector{Float64}(undef, length(nodes))
Threads.@threads for i in eachindex(nodes)
losses[i] = local_consensus_objective(nodes[i], xs[i])
end
return sum(losses)
end
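# Note: up to a constant factor, local_descent_direction returns the negative
# gradient of this consensus objective with respect to node.x, so the
# backtracking update below amounts to gradient descent on the total
# consensus loss.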

# Assumes length(x) == length(y) and all inner dimensions also match
function threaded_sum(x, y)
n = length(x)
m = length(x[1])
res = repeat([Vector{Float32}(undef, m)], n)
Threads.@threads for i in eachindex(x)
res[i] = x[i] + y[i]
end
return res
end

# Backtracking (Armijo) line search over the global consensus objective.
# Returns the accepted step size a.
function line_search(nodes, delta_x::Vector{Vector{Float32}})#, prev_a)
@assert length(nodes) == length(delta_x)
c = 0.75
τ = 0.25
ms = Vector{Float32}(undef, length(nodes))
Threads.@threads for i in eachindex(delta_x)
ms[i] = (-delta_x[i])'*delta_x[i]
end
m = sum(ms)
#println(m)
t = -c*m
#a = prev_a
a=Float32(.1)
x = [node.x for node in nodes]
while consensus_objective(nodes, x) - consensus_objective(nodes, threaded_sum(x, a .* delta_x)) < a*t
a = τ*a
end
#println("Step size: $a")
return a
end
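# Note: each delta_x[i] is a descent direction, so m <= 0 and t = -c*m >= 0; the
# loop above shrinks a by the factor τ until the Armijo sufficient-decrease
# condition f(x) - f(x + a*delta_x) >= a*t holds, where f is consensus_objective.
# Illustrative numbers: with c = 0.75 and sum(norm_sqr.(delta_x)) == 4, t = 3,
# so a step a is accepted once the objective has dropped by at least 3a.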

# Each node takes a step in the direction of a*delta_x.
# Also updates communication channels.
function update_nodes!(nodes, a, delta_x)
Threads.@threads for i in eachindex(nodes)
nodes[i].x += a*delta_x[i]
for (n, rm) in nodes[i].neighbors
# Consume from buffers
take!(nodes[i].in_channels[n])
put!(nodes[i].out_channels[n], rm*nodes[i].x)
end
end
end

# Overwrites delta_x
function laplacian_step!(nodes, delta_x::Vector{Vector{Float32}})#, prev_ss)
descent_direction!(nodes, delta_x)
#step_size = line_search(nodes, delta_x, prev_ss)
step_size = line_search(nodes, delta_x)
update_nodes!(nodes, step_size, delta_x)
#return step_size
end
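# Note: callers are expected to pre-allocate the delta_x buffers once and reuse
# them across iterations, as iterate_laplacian!(nodes, num_iters) does below:
#   delta_x = [Vector{Float32}(undef, n.dimension) for n in nodes]
#   laplacian_step!(nodes, delta_x)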

function random_threaded_sheaf(num_nodes, edge_probability, restriction_map_dimension, restriction_map_density)
nodes = ThreadedSheafNode[]
coin()::Bool = rand() < edge_probability
@@ -85,16 +231,57 @@ end
function iterate_laplacian!(nodes, step_size, num_iters)
distances = Float64[]

-    for _ in 1:num_iters+1
+    for _ in 1:num_iters
push!(distances, distance_from_consensus(nodes))
laplacian_step!(nodes, step_size)
end
push!(distances, distance_from_consensus(nodes))
return distances
end

function iterate_laplacian!(nodes, num_iters)
distances = Float64[]
dimensions = (n -> n.dimension).(nodes)
delta_x = [Vector{Float32}(undef, d) for d in dimensions]
#ss = 0.5
for _ in 1:num_iters
push!(distances, distance_from_consensus(nodes))
laplacian_step!(nodes, delta_x)#, ss)
end
push!(distances, distance_from_consensus(nodes))
return distances
end

#=function iterate_laplacian!(nodes, num_iters)
distances = Float64[]
for _ in 1:num_iters+1
laplacian_step!(nodes)
push!(distances, distance_from_consensus(nodes))
end
return distances
end=#

# Randomly reinitialize the nodes' states
-function random_initialization(nodes)
+function random_initialization(nodes::Vector{ThreadedSheafNode})
for node in nodes
-        node.x = rand(node.dimension)
+        x = rand(node.dimension)
+        node.x = x
for (n, rm) in node.neighbors
take!(node.out_channels[n])
put!(node.out_channels[n], rm*x)
end
end
end

# Initialize each node's state from the corresponding entry of xs and refresh
# its outgoing channel values.
function initialize!(nodes, xs)
for (x, node) in zip(xs, nodes)
node.x = x

for (n, rm) in node.neighbors
take!(node.out_channels[n])
put!(node.out_channels[n], rm*x)
end
end
end
end
