generated from AlgebraicJulia/AlgebraicTemplate.jl
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Converts ACSet to a series of Symbolic terms that can be rewritten with a provided rewriter
- Loading branch information
1 parent
23fbf3f
commit 2a6269c
Showing
6 changed files
with
204 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
using DiagrammaticEquations | ||
using SymbolicUtils | ||
using SymbolicUtils.Rewriters | ||
using SymbolicUtils.Code | ||
using MLStyle | ||
|
||
const DECA_EQUALITY_SYMBOL = (==) | ||
|
||
to_symbolics(d::SummationDecapode, node::TraversalNode) = to_symbolics(d, node.index, Val(node.name)) | ||
|
||
function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Op1}) | ||
input_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :src], :name]) | ||
output_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :tgt], :name]) | ||
op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Number}}(d[op_index, :op1]) | ||
|
||
rhs = SymbolicUtils.Term{Number}(op_sym, [input_sym]) | ||
SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) | ||
end | ||
|
||
function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Op2}) | ||
input1_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :proj1], :name]) | ||
input2_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :proj2], :name]) | ||
output_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :res], :name]) | ||
op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number, Number}, Number}}(d[op_index, :op2]) | ||
|
||
rhs = SymbolicUtils.Term{Number}(op_sym, [input1_sym, input2_sym]) | ||
SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) | ||
end | ||
|
||
#XXX: Always converting + -> .+ here since summation doesn't store the style of addition | ||
# function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Σ}) | ||
# Expr(EQUALITY_SYMBOL, c.output, Expr(:call, Expr(:., :+), c.inputs...)) | ||
# end | ||
|
||
function extract_symexprs(d::SummationDecapode) | ||
topo_list = topological_sort_edges(d) | ||
sym_list = [] | ||
for node in topo_list | ||
retrieve_name(d, node) != DerivOp || continue | ||
push!(sym_list, to_symbolics(d, node)) | ||
end | ||
sym_list | ||
end | ||
|
||
function apply_rewrites(d::SummationDecapode, rewriter) | ||
|
||
rewritten_list = [] | ||
for sym in extract_symexprs(d) | ||
res_sym = rewriter(sym) | ||
rewritten_sym = isnothing(res_sym) ? sym : res_sym | ||
push!(rewritten_list, rewritten_sym) | ||
end | ||
|
||
rewritten_list | ||
end | ||
|
||
# TODO: We need a way to get information like the d and ⋆ even when not in the ACSet | ||
# @syms Δ(x) d(x) ⋆(x) | ||
# lap_0_rule = @rule Δ(~x) => ⋆(d(⋆(d(~x)))) | ||
# rewriter = Postwalk(RestartedChain([lap_0_rule])) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
using DiagrammaticEquations | ||
using ACSets | ||
|
||
export TraversalNode, topological_sort_edges, number_of_ops, retrieve_name | ||
|
||
struct TraversalNode{T} | ||
index::Int | ||
name::T | ||
end | ||
|
||
function topological_sort_edges(d::SummationDecapode) | ||
visited_Var = falses(nparts(d, :Var)) | ||
visited_Var[start_nodes(d)] .= true | ||
|
||
# TODO: Collect these visited arrays into one structure indexed by :Op1, :Op2, and :Σ | ||
visited_1 = falses(nparts(d, :Op1)) | ||
visited_2 = falses(nparts(d, :Op2)) | ||
visited_Σ = falses(nparts(d, :Σ)) | ||
|
||
# FIXME: this is a quadratic implementation of topological_sort inlined in here. | ||
op_order = TraversalNode{Symbol}[] | ||
|
||
for _ in 1:number_of_ops(d) | ||
for op in parts(d, :Op1) | ||
if !visited_1[op] && visited_Var[d[op, :src]] | ||
|
||
visited_1[op] = true | ||
visited_Var[d[op, :tgt]] = true | ||
|
||
push!(op_order, TraversalNode(op, :Op1)) | ||
end | ||
end | ||
|
||
for op in parts(d, :Op2) | ||
if !visited_2[op] && visited_Var[d[op, :proj1]] && visited_Var[d[op, :proj2]] | ||
visited_2[op] = true | ||
visited_Var[d[op, :res]] = true | ||
push!(op_order, TraversalNode(op, :Op2)) | ||
end | ||
end | ||
|
||
for op in parts(d, :Σ) | ||
args = subpart(d, incident(d, op, :summation), :summand) | ||
if !visited_Σ[op] && all(visited_Var[args]) | ||
visited_Σ[op] = true | ||
visited_Var[d[op, :sum]] = true | ||
push!(op_order, TraversalNode(op, :Σ)) | ||
end | ||
end | ||
end | ||
|
||
@assert length(op_order) == number_of_ops(d) | ||
|
||
op_order | ||
end | ||
|
||
function number_of_ops(d::SummationDecapode) | ||
return nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, :Σ) | ||
end | ||
|
||
function start_nodes(d::SummationDecapode) | ||
return vcat(infer_states(d), incident(d, :Literal, :type)) | ||
end | ||
|
||
function retrieve_name(d::SummationDecapode, tsr::TraversalNode) | ||
@match tsr.name begin | ||
:Op1 => d[tsr.index, :op1] | ||
:Op2 => d[tsr.index, :op2] | ||
:Σ => :+ | ||
_ => error("$(tsr.name) is not a valid table for names") | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
using DiagrammaticEquations | ||
using ACSets | ||
using MLStyle | ||
using Test | ||
|
||
function is_correct_length(d::SummationDecapode, result) | ||
return length(result) == number_of_ops(d) | ||
end | ||
|
||
@testset "Topological Sort on Edges" begin | ||
no_edge = @decapode begin | ||
F == S | ||
end | ||
@test isempty(topological_sort_edges(no_edge)) | ||
|
||
one_op1_deca = @decapode begin | ||
F == f(S) | ||
end | ||
result = topological_sort_edges(one_op1_deca) | ||
@test is_correct_length(one_op1_deca, result) | ||
@test retrieve_name(one_op1_deca, only(result)) == :f | ||
|
||
multi_op1_deca = @decapode begin | ||
F == c(b(a(S))) | ||
end | ||
result = topological_sort_edges(multi_op1_deca) | ||
@test is_correct_length(multi_op1_deca, result) | ||
for (edge, test_name) in zip(result, [:a, :b, :c]) | ||
@test retrieve_name(multi_op1_deca, edge) == test_name | ||
end | ||
|
||
cyclic = @decapode begin | ||
B == g(A) | ||
A == f(B) | ||
end | ||
@test_throws AssertionError topological_sort_edges(cyclic) | ||
|
||
just_op2 = @decapode begin | ||
C == A * B | ||
end | ||
result = topological_sort_edges(just_op2) | ||
@test is_correct_length(just_op2, result) | ||
@test retrieve_name(just_op2, only(result)) == :* | ||
|
||
just_simple_sum = @decapode begin | ||
C == A + B | ||
end | ||
result = topological_sort_edges(just_simple_sum) | ||
@test is_correct_length(just_simple_sum, result) | ||
@test retrieve_name(just_simple_sum, only(result)) == :+ | ||
|
||
just_multi_sum = @decapode begin | ||
F == A + B + C + D + E | ||
end | ||
result = topological_sort_edges(just_multi_sum) | ||
@test is_correct_length(just_multi_sum, result) | ||
@test retrieve_name(just_multi_sum, only(result)) == :+ | ||
|
||
op_combo = @decapode begin | ||
F == h(d(A) + f(g(B) * C) + D) | ||
end | ||
result = topological_sort_edges(op_combo) | ||
@test is_correct_length(op_combo, result) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters