Skip to content

Commit

Permalink
Addition of generic graph struct
Browse files Browse the repository at this point in the history
This struct organizes the data into a more generic hypergraph that can then be routed through generic graph algorithms, like topo sort or F-W, without relying on the underlying ACSet structure.
  • Loading branch information
GeorgeR227 committed Aug 28, 2024
1 parent dfe0e03 commit db2274f
Show file tree
Hide file tree
Showing 5 changed files with 214 additions and 21 deletions.
2 changes: 1 addition & 1 deletion src/DiagrammaticEquations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ include("rewrite.jl")
include("pretty.jl")
include("colanguage.jl")
include("openoperators.jl")
include("graph_traversal.jl")
include("graph_interface.jl")
include("acset2symbolic.jl")
include("deca/Deca.jl")
include("learn/Learn.jl")
Expand Down
56 changes: 55 additions & 1 deletion src/acset2symbolic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,50 @@ using SymbolicUtils.Rewriters
using SymbolicUtils.Code
using MLStyle

import DiagrammaticEquations: HyperGraph, HyperGraphEdge, HyperGraphVertex, vertex_list, edge_list
import DiagrammaticEquations: topological_sort_edges

export TableData, extract_symexprs, number_of_ops, retrieve_name

const DECA_EQUALITY_SYMBOL = (==)

to_symbolics(d::SummationDecapode, node::TraversalNode) = to_symbolics(d, node.index, Val(node.name))
# Decapode graph conversion
struct TableData
table_index::Int
table_name::Symbol
end

HyperGraphVertex(d::SummationDecapode, index::Int) = HyperGraphVertex(index, TableData(index, :Var))

HyperGraphEdge(d::SummationDecapode, index::Int, ::Val{:Op1}) = HyperGraphEdge(d[index, :tgt], [d[index, :src]], TableData(index, :Op1))
HyperGraphEdge(d::SummationDecapode, index::Int, ::Val{:Op2}) = HyperGraphEdge(d[index, :res], [d[index,:proj1],d[index,:proj2]], TableData(index, :Op2))
HyperGraphEdge(d::SummationDecapode, index::Int, ::Val{:Σ}) = HyperGraphEdge(d[index, :sum], d[incident(d, index, :summation), :summand], TableData(index, ))

HyperGraph(d::SummationDecapode) = HyperGraph(vertex_list(d), edge_list(d), nothing)

vertex_list(d::SummationDecapode) = map(id -> HyperGraphVertex(d, id), parts(d, :Var))

function edge_list(d::SummationDecapode)
edges = HyperGraphEdge[]
for op_table in [:Op1, :Op2, ]
for op in parts(d, op_table)
if op_table == :Op1 && d[op, :op1] == DerivOp
continue
end
push!(edges, HyperGraphEdge(d, op, Val(op_table)))
end
end
edges
end

table_data(v::HyperGraphVertex) = v.metadata
table_data(v::HyperGraphEdge) = v.metadata

topological_sort_edges(d::SummationDecapode) = table_data.(topological_sort_edges(HyperGraph(d)))

# Decapode ACSet symbolics conversion

to_symbolics(d::SummationDecapode, data::TableData) = to_symbolics(d, data.table_index, Val(data.table_name))

function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Op1})
input_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :src], :name])
Expand Down Expand Up @@ -54,6 +95,19 @@ function apply_rewrites(d::SummationDecapode, rewriter)
rewritten_list
end

function number_of_ops(d::SummationDecapode)
return nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, )
end

function retrieve_name(d::SummationDecapode, data::TableData)
@match data.table_name begin
:Op1 => d[data.table_index, :op1]
:Op2 => d[data.table_index, :op2]
=> :+
_ => error("$(data.table_name) is a table without names")
end
end

# TODO: We need a way to get information like the d and ⋆ even when not in the ACSet
# @syms Δ(x) d(x) ⋆(x)
# lap_0_rule = @rule Δ(~x) => ⋆(d(⋆(d(~x))))
Expand Down
133 changes: 133 additions & 0 deletions src/graph_interface.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
using DiagrammaticEquations
using ACSets

export HyperGraph, HyperGraphVertex, HyperGraphEdge, vertex_list, edge_list
export topological_sort_edges, floyd_warshall

struct HyperGraphVertex
id::Int
metadata
end

# Assuming we only have a single target
struct HyperGraphEdge
tgt::Int
srcs::AbstractVector{Int}
metadata
end

struct HyperGraph
vertices::AbstractVector{HyperGraphVertex}
edges::AbstractVector{HyperGraphEdge}
metadata
end

# Returns a list of all vertices from ACSet as HyperGraphVertex
function vertex_list() end

# Returns a list of all edges from ACSet as HyperGraphEdge
function edge_list() end

num_vertices(g::HyperGraph) = length(g.vertices)
num_edges(g::HyperGraph) = length(g.edges)

# TODO: Clean this up to use better logic
function start_nodes(g::HyperGraph)
indices = HyperGraphVertex[]

for vertex in g.vertices
v_id = vertex.id

is_tgt = true
for edge in g.edges
if v_id == edge.tgt
is_tgt = false
break
end
end

if is_tgt
push!(indices, vertex)
end

end

indices
end

function has_unique_targets(g::HyperGraph)
seen_vertices = Set{Int}()
for edge in g.edges
if edge.tgt in seen_vertices
return false
end
push!(seen_vertices, edge.tgt)
end
return true
end

vertex_id(v::HyperGraphVertex) = return v.id

function topological_sort_edges(g::HyperGraph)
@assert has_unique_targets(g)

visited_vertices = falses(num_vertices(g))
visited_vertices[vertex_id.(start_nodes(g))] .= true

visited_edges = falses(num_edges(g))

edge_order = HyperGraphEdge[]

for _ in 1:num_edges(g)
for (idx, edge) in enumerate(g.edges)
if !visited_edges[idx] && all(visited_vertices[edge.srcs])
visited_edges[idx] = true
visited_vertices[edge.tgt] = true

push!(edge_order, edge)
end
end
end

@assert length(edge_order) == num_edges(g)

edge_order
end

"""
floyd_warshall(g::HyperGraph)
Return a |variable| × |variable| matrix of shortest paths via the Floyd-Warshall algorithm.
Taking the maximum of the non-infinite short paths from state variables induces a topological ordering.
https://en.wikipedia.org/wiki/Floyd–Warshall_algorithm
"""
function floyd_warshall(g::HyperGraph)
# Define weights.
w(e) = -1

# Init dists
V = num_vertices(g)
dist = fill(Inf, (V, V))
foreach(g.edges) do e
dist[(e.srcs), e.tgt] .= w(e)
end

for v in 1:V
dist[v,v] = 0
end

# Floyd-Warshall
for k in 1:V
for i in 1:V
for j in 1:V
if dist[i,j] > dist[i,k] + dist[k,j]
dist[i,j] = dist[i,k] + dist[k,j]
end
end
end
end

dist
end
42 changes: 24 additions & 18 deletions test/graph_traversal.jl → test/graph_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,17 @@ using ACSets
using MLStyle
using Test

function is_correct_length(d::SummationDecapode, result)
return length(result) == number_of_ops(d)
function is_topo_sort_ordered(result::AbstractVector{TableData})
seen_edges = Dict{Symbol, Int}(:Op1 => 0, :Op2 => 0, => 0)
for entry in result
table = entry.table_name
prev_seen = seen_edges[table]
if !(prev_seen < entry.table_index)
return false
end
seen_edges[table] = entry.table_index
end
return true
end

@testset "Topological Sort on Edges" begin
Expand All @@ -13,61 +22,58 @@ end
end
@test isempty(topological_sort_edges(no_edge))

one_op1_deca = @decapode begin
one_op1 = @decapode begin
F == f(S)
end
result = topological_sort_edges(one_op1_deca)
@test is_correct_length(one_op1_deca, result)
@test retrieve_name(one_op1_deca, only(result)) == :f
result = topological_sort_edges(one_op1)
@test retrieve_name(one_op1, only(result)) == :f
@test is_topo_sort_ordered(result)

multi_op1_deca = @decapode begin
multi_op1 = @decapode begin
F == c(b(a(S)))
end
result = topological_sort_edges(multi_op1_deca)
@test is_correct_length(multi_op1_deca, result)
result = topological_sort_edges(multi_op1)
for (edge, test_name) in zip(result, [:a, :b, :c])
@test retrieve_name(multi_op1_deca, edge) == test_name
@test retrieve_name(multi_op1, edge) == test_name
end
@test is_topo_sort_ordered(result)

# XXX Do cycle-detection with FW by using ∞ on the diagonal.
#=
cyclic = @decapode begin
B == g(A)
A == f(B)
end
@test_throws AssertionError topological_sort_edges(cyclic)
=#

just_op2 = @decapode begin
C == A * B
end
result = topological_sort_edges(just_op2)
@test is_correct_length(just_op2, result)
@test retrieve_name(just_op2, only(result)) == :*
@test is_topo_sort_ordered(result)

just_simple_sum = @decapode begin
C == A + B
end
result = topological_sort_edges(just_simple_sum)
@test is_correct_length(just_simple_sum, result)
@test retrieve_name(just_simple_sum, only(result)) == :+
@test is_topo_sort_ordered(result)

just_multi_sum = @decapode begin
F == A + B + C + D + E
end
result = topological_sort_edges(just_multi_sum)
@test is_correct_length(just_multi_sum, result)
@test retrieve_name(just_multi_sum, only(result)) == :+
@test is_topo_sort_ordered(result)

op_combo = @decapode begin
F == h(d(A) + f(g(B) * C) + D)
end
result = topological_sort_edges(op_combo)
@test is_correct_length(op_combo, result)
@test is_topo_sort_ordered(result)

sum_with_single_dependency = @decapode begin
F == A + f(A) + h(g(A))
end
result = topological_sort_edges(sum_with_single_dependency)
@test is_correct_length(sum_with_single_dependency, result)
@test is_topo_sort_ordered(result)
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,5 @@ end
end

@testset "Symbolic Rewriting" begin
include("graph_traversal.jl")
include("graph_interface.jl")
end

0 comments on commit db2274f

Please sign in to comment.