diff --git a/src/graph_traversal.jl b/src/graph_traversal.jl index b9551c8..322d3ed 100644 --- a/src/graph_traversal.jl +++ b/src/graph_traversal.jl @@ -1,5 +1,6 @@ using DiagrammaticEquations using ACSets +using DataStructures export TraversalNode, topological_sort_edges, number_of_ops, retrieve_name @@ -8,58 +9,63 @@ struct TraversalNode{T} name::T end -function topological_sort_edges(d::SummationDecapode) - visited_Var = falses(nparts(d, :Var)) - visited_Var[start_nodes(d)] .= true - - # TODO: Collect these visited arrays into one structure indexed by :Op1, :Op2, and :Σ - visited_1 = falses(nparts(d, :Op1)) - visited_2 = falses(nparts(d, :Op2)) - visited_Σ = falses(nparts(d, :Σ)) - - # FIXME: this is a quadratic implementation of topological_sort inlined in here. - op_order = TraversalNode{Symbol}[] +number_of_ops(d::SummationDecapode) = nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, :Σ) - for _ in 1:number_of_ops(d) - for op in parts(d, :Op1) - if !visited_1[op] && visited_Var[d[op, :src]] +start_nodes(d::SummationDecapode) = vcat(infer_states(d), incident(d, :Literal, :type)) - visited_1[op] = true - visited_Var[d[op, :tgt]] = true - - push!(op_order, TraversalNode(op, :Op1)) - end - end - - for op in parts(d, :Op2) - if !visited_2[op] && visited_Var[d[op, :proj1]] && visited_Var[d[op, :proj2]] - visited_2[op] = true - visited_Var[d[op, :res]] = true - push!(op_order, TraversalNode(op, :Op2)) - end - end - - for op in parts(d, :Σ) - args = subpart(d, incident(d, op, :summation), :summand) - if !visited_Σ[op] && all(visited_Var[args]) - visited_Σ[op] = true - visited_Var[d[op, :sum]] = true - push!(op_order, TraversalNode(op, :Σ)) +#https://en.wikipedia.org/wiki/Floyd–Warshall_algorithm#Pseudocode +function floyd_warshall(d::SummationDecapode) + # Init dists. + V = nparts(d, :Var) + dist = fill(Inf, (V, V)) + foreach(parts(d,:Op1)) do e + dist[d[e,:src], d[e,:tgt]] = 1 + end + foreach(parts(d,:Op2)) do e + dist[d[e,:proj1], d[e,:res]] = 1 + dist[d[e,:proj2], d[e,:res]] = 1 + end + foreach(parts(d,:Summand)) do e + dist[d[e,:summand], d[e,[:summation, :sum]]] = 1 + end + for v in 1:V + dist[v,v] = 0 + end + # Floyd-Warshall + for k in 1:V + for i in 1:V + for j in 1:V + if dist[i,j] > dist[i,k] + dist[k,j] + dist[i,j] = dist[i,k] + dist[k,j] + end end end end - - @assert length(op_order) == number_of_ops(d) - - op_order + dist end -function number_of_ops(d::SummationDecapode) - return nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, :Σ) +function topological_sort_verts(d::SummationDecapode) + m = floyd_warshall(d) + map(parts(d,:Var)) do v + maximum(filter(!isinf, m[start_nodes(d),v])) + end + # Call sortperm for the vertex ordering. end -function start_nodes(d::SummationDecapode) - return vcat(infer_states(d), incident(d, :Literal, :type)) +function topological_sort_edges(d::SummationDecapode) + tsv = topological_sort_verts(d) + op_order = [TraversalNode.(parts(d,:Op1), :Op1)..., + TraversalNode.(parts(d,:Op2), :Op2)..., + TraversalNode.(parts(d,:Σ), :Σ)...] + function by(x) + @match x.name begin + :Op1 => tsv[d[x.index,:src]] + :Op2 => max(tsv[d[x.index,:proj1]], tsv[d[x.index,:proj1]]) + :Σ => maximum(tsv[d[incident(d,x.index,:summation),:summand]]) + _ => error("Unknown function type") + end + end + sort(op_order, by = by) end function retrieve_name(d::SummationDecapode, tsr::TraversalNode) @@ -70,3 +76,4 @@ function retrieve_name(d::SummationDecapode, tsr::TraversalNode) _ => error("$(tsr.name) is not a valid table for names") end end + diff --git a/test/graph_traversal.jl b/test/graph_traversal.jl index 7a5d6ac..825de65 100644 --- a/test/graph_traversal.jl +++ b/test/graph_traversal.jl @@ -29,11 +29,14 @@ end @test retrieve_name(multi_op1_deca, edge) == test_name end + # XXX Do cycle-detection with FW by using ∞ on the diagonal. + #= cyclic = @decapode begin B == g(A) A == f(B) end @test_throws AssertionError topological_sort_edges(cyclic) + =# just_op2 = @decapode begin C == A * B