Skip to content

Commit

Permalink
Further improvement of acset2symbolics
Browse files Browse the repository at this point in the history
Remove special DerivOp handling, fixed bug where multiple equations with the same variable result were being dropped, more tests to cover these cases and further clean up.
  • Loading branch information
GeorgeR227 committed Sep 28, 2024
1 parent 367414d commit 5b84cc8
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 20 deletions.
40 changes: 20 additions & 20 deletions src/acset2symbolic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using ACSets
using SymbolicUtils
using SymbolicUtils: BasicSymbolic, Symbolic

# TODO: Expose only the symbolic_rewriting function
export symbolics_lookup, extract_symexprs, apply_rewrites, merge_equations, to_acset, symbolic_rewriting

const DECA_EQUALITY_SYMBOL = (==)
Expand Down Expand Up @@ -30,13 +31,13 @@ end
function symbolic_rewriting(old_d::SummationDecapode, rewriter = nothing)
d = infer_types!(deepcopy(old_d))
eqns = merge_equations(d)
to_acset(d, isnothing(rewriter) ? eqns : map(rewriter, eqns))
to_acset(d, apply_rewrites(eqns, rewriter))
end

apply_rewrites(eqns, rewriter) = isnothing(rewriter) ? eqns : map(rewriter, eqns)

function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic})
sym_list = SymbolicEquation{Symbolic}[]
non_tangents = Iterators.filter(x -> retrieve_name(d, x) != DerivOp, topological_sort_edges(d))
map(non_tangents) do node
map(topological_sort_edges(d)) do node
to_symbolics(d, symvar_lookup, node.index, node.name)
end
end
Expand All @@ -50,26 +51,22 @@ function merge_equations(d::SummationDecapode)
symvar_lookup = symbolics_lookup(d)
symexpr_list = extract_symexprs(d, symvar_lookup)

eqn_lookup = Dict{Any,Any}(map(start_nodes(d)) do i
sym = symvar_lookup[d[i, :name]]
(sym, sym)
end)
eqn_lookup = Dict()

terminal_vars = infer_terminal_names(d)
terminal_eqns = SymbolicEquation{Symbolic}[]

foreach(symexpr_list) do x
push!(eqn_lookup, (x.lhs => SymbolicUtils.substitute(x.rhs, eqn_lookup)))
if x.lhs.name in terminal_vars
push!(terminal_eqns, SymbolicEquation{Symbolic}(x.lhs, eqn_lookup[x.lhs]))
end
end

terminals = Iterators.filter(x -> x.lhs.name in infer_terminal_names(d), symexpr_list)
map(x -> formed_deca_eqn(x.lhs, eqn_lookup[x.lhs]), terminals)
formed_deca_eqn.(terminal_eqns)
end

formed_deca_eqn(lhs, rhs) = SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [lhs, rhs])

function apply_rewrites(symexprs, rewriter)
map(symexprs) do sym
res_sym = rewriter(sym)
isnothing(res_sym) ? sym : res_sym
end
end
formed_deca_eqn(symeqn::SymbolicEquation{Symbolic}) = SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [symeqn.lhs, symeqn.rhs])

function to_acset(d::SummationDecapode, sym_exprs)
outer_types = map([infer_states(d)..., infer_terminals(d)...]) do i
Expand All @@ -90,8 +87,11 @@ function to_acset(d::SummationDecapode, sym_exprs)
end
reify!(final_exprs)

deca_block = quote end
deca_block.args = [outer_types..., tangents..., final_exprs...]
deca_block = quote
$(outer_types...)
$(final_exprs...)
end

(infer_types!, SummationDecapode, parse_decapode)(deca_block)
end

2 changes: 2 additions & 0 deletions src/sym_rewrite.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# TODO: Delete this file

using Test
using DiagrammaticEquations
using SymbolicUtils
Expand Down
30 changes: 30 additions & 0 deletions test/acset2symbolic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,36 @@ using Catlab
all_ops_res[5, :name] = :D
all_ops_res[6, :name] = :C
@test all_ops all_ops_res

with_deriv = @decapode begin
A::Form0
::Form0

∂ₜ(A) ==
== Δ(A)
end

@test with_deriv == symbolic_rewriting(with_deriv)

repeated_vars = @decapode begin
A::Form0
B::Form1
C::Form1

C == d(A)
C == Δ(B)
C == d(A)
end

@test repeated_vars == symbolic_rewriting(repeated_vars)

# TODO: This is broken because of the terminals issue in #77
self_changing = @decapode begin
A::Form0
A == ∂ₜ(A)
end

@test_broken repeated_vars == symbolic_rewriting(self_changing)
end

function expr_rewriter(rules::Vector)
Expand Down

0 comments on commit 5b84cc8

Please sign in to comment.