From 303962d5b983891074165786f34d4522b457c8d5 Mon Sep 17 00:00:00 2001 From: Luke Morris Date: Sun, 25 Aug 2024 19:40:08 -0400 Subject: [PATCH] Compute longest paths from terminals --- src/DiagrammaticEquations.jl | 2 +- src/graph_traversal.jl | 7 ++++--- test/graph_traversal.jl | 6 ++++++ 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/DiagrammaticEquations.jl b/src/DiagrammaticEquations.jl index 5ab376a..978428c 100644 --- a/src/DiagrammaticEquations.jl +++ b/src/DiagrammaticEquations.jl @@ -5,7 +5,7 @@ module DiagrammaticEquations using Catlab export -DerivOp, append_dot, normalize_unicode, infer_states, infer_types!, +DerivOp, append_dot, normalize_unicode, infer_states, infer_terminals, infer_types!, # Deca op1_res_rules_1D, op2_res_rules_1D, op1_res_rules_2D, op2_res_rules_2D, op1_inf_rules_1D, op2_inf_rules_1D, op1_inf_rules_2D, op2_inf_rules_2D, diff --git a/src/graph_traversal.jl b/src/graph_traversal.jl index 5e8c794..4e61538 100644 --- a/src/graph_traversal.jl +++ b/src/graph_traversal.jl @@ -29,7 +29,6 @@ number_of_ops(d::SummationDecapode) = nparts(d, :Op1) + nparts(d, :Op2) + nparts start_nodes(d::SummationDecapode) = vcat(infer_states(d), incident(d, :Literal, :type)) - # TODO: This could be Catlab'd. Hypergraph category? Migration to a DWD? """ function hyper_edge_list(d::SummationDecapode) @@ -55,11 +54,13 @@ Taking the maximum of the non-infinite short paths from state variables induces https://en.wikipedia.org/wiki/Floyd–Warshall_algorithm """ function floyd_warshall(d::SummationDecapode) + # Define weights. + w(e) = (length(e.dom) == 1 && e.name ∈ [:∂ₜ,:dt]) ? -Inf : -1 # Init dists V = nparts(d, :Var) dist = fill(Inf, (V, V)) foreach(hyper_edge_list(d)) do e - dist[(e.dom), e.cod] .= 1 + dist[(e.dom), e.cod] .= w(e) end for v in 1:V dist[v,v] = 0 @@ -86,7 +87,7 @@ The vector returned by this function maps each vertex to the order that it would function topological_sort_verts(d::SummationDecapode) m = floyd_warshall(d) map(parts(d,:Var)) do v - maximum(filter(!isinf, m[start_nodes(d),v])) + minimum(filter(!isinf, m[v,infer_terminals(d)])) end end diff --git a/test/graph_traversal.jl b/test/graph_traversal.jl index 825de65..9160b05 100644 --- a/test/graph_traversal.jl +++ b/test/graph_traversal.jl @@ -64,4 +64,10 @@ end end result = topological_sort_edges(op_combo) @test is_correct_length(op_combo, result) + + sum_with_single_dependency = @decapode begin + F == A + f(A) + h(g(A)) + end + result = topological_sort_edges(sum_with_single_dependency) + @test is_correct_length(sum_with_single_dependency, result) end