Skip to content

Commit

Permalink
Initial attempt at rewriting
Browse files Browse the repository at this point in the history
Converts ACSet to a series of Symbolic terms that can be rewritten with a provided rewriter
  • Loading branch information
GeorgeR227 committed Aug 23, 2024
1 parent 23fbf3f commit 2a6269c
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@ ACSets = "227ef7b5-1206-438b-ac65-934d6da304b8"
Catlab = "134e5e36-593f-5add-ad60-77f754baafbe"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[compat]
ACSets = "0.2"
Catlab = "0.15, 0.16"
DataStructures = "0.18.13"
MLStyle = "0.4.17"
SymbolicUtils = "3.4"
Unicode = "1.6"
julia = "1.6"
2 changes: 2 additions & 0 deletions src/DiagrammaticEquations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ include("rewrite.jl")
include("pretty.jl")
include("colanguage.jl")
include("openoperators.jl")
include("graph_traversal.jl")
include("acset2symbolic.jl")
include("deca/Deca.jl")
include("learn/Learn.jl")

Expand Down
60 changes: 60 additions & 0 deletions src/acset2symbolic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
using DiagrammaticEquations
using SymbolicUtils
using SymbolicUtils.Rewriters
using SymbolicUtils.Code
using MLStyle

const DECA_EQUALITY_SYMBOL = (==)

to_symbolics(d::SummationDecapode, node::TraversalNode) = to_symbolics(d, node.index, Val(node.name))

Check warning on line 9 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L9

Added line #L9 was not covered by tests

function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Op1})
input_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :src], :name])
output_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :tgt], :name])
op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Number}}(d[op_index, :op1])

Check warning on line 14 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L11-L14

Added lines #L11 - L14 were not covered by tests

rhs = SymbolicUtils.Term{Number}(op_sym, [input_sym])
SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs])

Check warning on line 17 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L16-L17

Added lines #L16 - L17 were not covered by tests
end

function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Op2})
input1_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :proj1], :name])
input2_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :proj2], :name])
output_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :res], :name])
op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number, Number}, Number}}(d[op_index, :op2])

Check warning on line 24 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L20-L24

Added lines #L20 - L24 were not covered by tests

rhs = SymbolicUtils.Term{Number}(op_sym, [input1_sym, input2_sym])
SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs])

Check warning on line 27 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L26-L27

Added lines #L26 - L27 were not covered by tests
end

#XXX: Always converting + -> .+ here since summation doesn't store the style of addition
# function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Σ})
# Expr(EQUALITY_SYMBOL, c.output, Expr(:call, Expr(:., :+), c.inputs...))
# end

function extract_symexprs(d::SummationDecapode)
topo_list = topological_sort_edges(d)
sym_list = []
for node in topo_list
retrieve_name(d, node) != DerivOp || continue
push!(sym_list, to_symbolics(d, node))
end
sym_list

Check warning on line 42 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L35-L42

Added lines #L35 - L42 were not covered by tests
end

function apply_rewrites(d::SummationDecapode, rewriter)

Check warning on line 45 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L45

Added line #L45 was not covered by tests

rewritten_list = []
for sym in extract_symexprs(d)
res_sym = rewriter(sym)
rewritten_sym = isnothing(res_sym) ? sym : res_sym
push!(rewritten_list, rewritten_sym)
end

Check warning on line 52 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L47-L52

Added lines #L47 - L52 were not covered by tests

rewritten_list

Check warning on line 54 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L54

Added line #L54 was not covered by tests
end

# TODO: We need a way to get information like the d and ⋆ even when not in the ACSet
# @syms Δ(x) d(x) ⋆(x)
# lap_0_rule = @rule Δ(~x) => ⋆(d(⋆(d(~x))))
# rewriter = Postwalk(RestartedChain([lap_0_rule]))
72 changes: 72 additions & 0 deletions src/graph_traversal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
using DiagrammaticEquations
using ACSets

export TraversalNode, topological_sort_edges, number_of_ops, retrieve_name

struct TraversalNode{T}
index::Int
name::T
end

function topological_sort_edges(d::SummationDecapode)
visited_Var = falses(nparts(d, :Var))
visited_Var[start_nodes(d)] .= true

# TODO: Collect these visited arrays into one structure indexed by :Op1, :Op2, and :Σ
visited_1 = falses(nparts(d, :Op1))
visited_2 = falses(nparts(d, :Op2))
visited_Σ = falses(nparts(d, ))

# FIXME: this is a quadratic implementation of topological_sort inlined in here.
op_order = TraversalNode{Symbol}[]

for _ in 1:number_of_ops(d)
for op in parts(d, :Op1)
if !visited_1[op] && visited_Var[d[op, :src]]

visited_1[op] = true
visited_Var[d[op, :tgt]] = true

push!(op_order, TraversalNode(op, :Op1))
end
end

for op in parts(d, :Op2)
if !visited_2[op] && visited_Var[d[op, :proj1]] && visited_Var[d[op, :proj2]]
visited_2[op] = true
visited_Var[d[op, :res]] = true
push!(op_order, TraversalNode(op, :Op2))
end
end

for op in parts(d, )
args = subpart(d, incident(d, op, :summation), :summand)
if !visited_Σ[op] && all(visited_Var[args])
visited_Σ[op] = true
visited_Var[d[op, :sum]] = true
push!(op_order, TraversalNode(op, ))
end
end
end

@assert length(op_order) == number_of_ops(d)

op_order

Check warning on line 54 in src/graph_traversal.jl

View check run for this annotation

Codecov / codecov/patch

src/graph_traversal.jl#L54

Added line #L54 was not covered by tests
end

function number_of_ops(d::SummationDecapode)
return nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, )
end

function start_nodes(d::SummationDecapode)
return vcat(infer_states(d), incident(d, :Literal, :type))
end

function retrieve_name(d::SummationDecapode, tsr::TraversalNode)
@match tsr.name begin
:Op1 => d[tsr.index, :op1]
:Op2 => d[tsr.index, :op2]
=> :+
_ => error("$(tsr.name) is not a valid table for names")

Check warning on line 70 in src/graph_traversal.jl

View check run for this annotation

Codecov / codecov/patch

src/graph_traversal.jl#L70

Added line #L70 was not covered by tests
end
end
64 changes: 64 additions & 0 deletions test/graph_traversal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
using DiagrammaticEquations
using ACSets
using MLStyle
using Test

function is_correct_length(d::SummationDecapode, result)
return length(result) == number_of_ops(d)
end

@testset "Topological Sort on Edges" begin
no_edge = @decapode begin
F == S
end
@test isempty(topological_sort_edges(no_edge))

one_op1_deca = @decapode begin
F == f(S)
end
result = topological_sort_edges(one_op1_deca)
@test is_correct_length(one_op1_deca, result)
@test retrieve_name(one_op1_deca, only(result)) == :f

multi_op1_deca = @decapode begin
F == c(b(a(S)))
end
result = topological_sort_edges(multi_op1_deca)
@test is_correct_length(multi_op1_deca, result)
for (edge, test_name) in zip(result, [:a, :b, :c])
@test retrieve_name(multi_op1_deca, edge) == test_name
end

cyclic = @decapode begin
B == g(A)
A == f(B)
end
@test_throws AssertionError topological_sort_edges(cyclic)

just_op2 = @decapode begin
C == A * B
end
result = topological_sort_edges(just_op2)
@test is_correct_length(just_op2, result)
@test retrieve_name(just_op2, only(result)) == :*

just_simple_sum = @decapode begin
C == A + B
end
result = topological_sort_edges(just_simple_sum)
@test is_correct_length(just_simple_sum, result)
@test retrieve_name(just_simple_sum, only(result)) == :+

just_multi_sum = @decapode begin
F == A + B + C + D + E
end
result = topological_sort_edges(just_multi_sum)
@test is_correct_length(just_multi_sum, result)
@test retrieve_name(just_multi_sum, only(result)) == :+

op_combo = @decapode begin
F == h(d(A) + f(g(B) * C) + D)
end
result = topological_sort_edges(op_combo)
@test is_correct_length(op_combo, result)
end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,7 @@ end
@testset "Open Operators" begin
include("openoperators.jl")
end

@testset "Symbolic Rewriting" begin
include("graph_traversal.jl")
end

0 comments on commit 2a6269c

Please sign in to comment.