Skip to content

Commit

Permalink
Define Floyd-Warshall algorithm on Decapodes
Browse files Browse the repository at this point in the history
  • Loading branch information
lukem12345 committed Aug 24, 2024
1 parent 2a6269c commit 2512181
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 43 deletions.
93 changes: 50 additions & 43 deletions src/graph_traversal.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using DiagrammaticEquations
using ACSets
using DataStructures

export TraversalNode, topological_sort_edges, number_of_ops, retrieve_name

Expand All @@ -8,58 +9,63 @@ struct TraversalNode{T}
name::T
end

function topological_sort_edges(d::SummationDecapode)
visited_Var = falses(nparts(d, :Var))
visited_Var[start_nodes(d)] .= true

# TODO: Collect these visited arrays into one structure indexed by :Op1, :Op2, and :Σ
visited_1 = falses(nparts(d, :Op1))
visited_2 = falses(nparts(d, :Op2))
visited_Σ = falses(nparts(d, ))

# FIXME: this is a quadratic implementation of topological_sort inlined in here.
op_order = TraversalNode{Symbol}[]
number_of_ops(d::SummationDecapode) = nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, )

for _ in 1:number_of_ops(d)
for op in parts(d, :Op1)
if !visited_1[op] && visited_Var[d[op, :src]]
start_nodes(d::SummationDecapode) = vcat(infer_states(d), incident(d, :Literal, :type))

visited_1[op] = true
visited_Var[d[op, :tgt]] = true

push!(op_order, TraversalNode(op, :Op1))
end
end

for op in parts(d, :Op2)
if !visited_2[op] && visited_Var[d[op, :proj1]] && visited_Var[d[op, :proj2]]
visited_2[op] = true
visited_Var[d[op, :res]] = true
push!(op_order, TraversalNode(op, :Op2))
end
end

for op in parts(d, )
args = subpart(d, incident(d, op, :summation), :summand)
if !visited_Σ[op] && all(visited_Var[args])
visited_Σ[op] = true
visited_Var[d[op, :sum]] = true
push!(op_order, TraversalNode(op, ))
#https://en.wikipedia.org/wiki/Floyd–Warshall_algorithm#Pseudocode
function floyd_warshall(d::SummationDecapode)
# Init dists.
V = nparts(d, :Var)
dist = fill(Inf, (V, V))
foreach(parts(d,:Op1)) do e
dist[d[e,:src], d[e,:tgt]] = 1
end
foreach(parts(d,:Op2)) do e
dist[d[e,:proj1], d[e,:res]] = 1
dist[d[e,:proj2], d[e,:res]] = 1
end
foreach(parts(d,:Summand)) do e
dist[d[e,:summand], d[e,[:summation, :sum]]] = 1
end
for v in 1:V
dist[v,v] = 0
end
# Floyd-Warshall
for k in 1:V
for i in 1:V
for j in 1:V
if dist[i,j] > dist[i,k] + dist[k,j]
dist[i,j] = dist[i,k] + dist[k,j]
end
end
end
end

@assert length(op_order) == number_of_ops(d)

op_order
dist
end

function number_of_ops(d::SummationDecapode)
return nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, )
function topological_sort_verts(d::SummationDecapode)
m = floyd_warshall(d)
map(parts(d,:Var)) do v
maximum(filter(!isinf, m[start_nodes(d),v]))
end
# Call sortperm for the vertex ordering.

This comment has been minimized.

Copy link
@jpfairbanks

jpfairbanks Aug 24, 2024

Member

This comment is ambiguous. Can you add a docstring that says something like:

The vector returned by this function maps each vertex to the order that it would be traversed in a topological sort traversal. If you want a list of vertices in the order of traversal call sortperm on the output.

This comment has been minimized.

Copy link
@lukem12345

lukem12345 Aug 24, 2024

Author Member

Good idea. I used this almost verbatim

end

function start_nodes(d::SummationDecapode)
return vcat(infer_states(d), incident(d, :Literal, :type))
function topological_sort_edges(d::SummationDecapode)
tsv = topological_sort_verts(d)
op_order = [TraversalNode.(parts(d,:Op1), :Op1)...,
TraversalNode.(parts(d,:Op2), :Op2)...,
TraversalNode.(parts(d,), )...]
function by(x)
@match x.name begin
:Op1 => tsv[d[x.index,:src]]
:Op2 => max(tsv[d[x.index,:proj1]], tsv[d[x.index,:proj1]])
=> maximum(tsv[d[incident(d,x.index,:summation),:summand]])
_ => error("Unknown function type")
end
end
sort(op_order, by = by)
end

function retrieve_name(d::SummationDecapode, tsr::TraversalNode)
Expand All @@ -70,3 +76,4 @@ function retrieve_name(d::SummationDecapode, tsr::TraversalNode)
_ => error("$(tsr.name) is not a valid table for names")
end
end

3 changes: 3 additions & 0 deletions test/graph_traversal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,14 @@ end
@test retrieve_name(multi_op1_deca, edge) == test_name
end

# XXX Do cycle-detection with FW by using ∞ on the diagonal.
#=
cyclic = @decapode begin
B == g(A)
A == f(B)
end
@test_throws AssertionError topological_sort_edges(cyclic)
=#

just_op2 = @decapode begin
C == A * B
Expand Down

0 comments on commit 2512181

Please sign in to comment.