From 5b84cc8a93097832e2a0fa492bba913f17904bba Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Sat, 28 Sep 2024 10:49:14 -0400 Subject: [PATCH] Further improvement of acset2symbolics Remove special DerivOp handling, fixed bug where multiple equations with the same variable result were being dropped, more tests to cover these cases and further clean up. --- src/acset2symbolic.jl | 40 ++++++++++++++++++++-------------------- src/sym_rewrite.jl | 2 ++ test/acset2symbolic.jl | 30 ++++++++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 20 deletions(-) diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index c383ca2..b2da836 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -3,6 +3,7 @@ using ACSets using SymbolicUtils using SymbolicUtils: BasicSymbolic, Symbolic +# TODO: Expose only the symbolic_rewriting function export symbolics_lookup, extract_symexprs, apply_rewrites, merge_equations, to_acset, symbolic_rewriting const DECA_EQUALITY_SYMBOL = (==) @@ -30,13 +31,13 @@ end function symbolic_rewriting(old_d::SummationDecapode, rewriter = nothing) d = infer_types!(deepcopy(old_d)) eqns = merge_equations(d) - to_acset(d, isnothing(rewriter) ? eqns : map(rewriter, eqns)) + to_acset(d, apply_rewrites(eqns, rewriter)) end +apply_rewrites(eqns, rewriter) = isnothing(rewriter) ? eqns : map(rewriter, eqns) + function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}) - sym_list = SymbolicEquation{Symbolic}[] - non_tangents = Iterators.filter(x -> retrieve_name(d, x) != DerivOp, topological_sort_edges(d)) - map(non_tangents) do node + map(topological_sort_edges(d)) do node to_symbolics(d, symvar_lookup, node.index, node.name) end end @@ -50,26 +51,22 @@ function merge_equations(d::SummationDecapode) symvar_lookup = symbolics_lookup(d) symexpr_list = extract_symexprs(d, symvar_lookup) - eqn_lookup = Dict{Any,Any}(map(start_nodes(d)) do i - sym = symvar_lookup[d[i, :name]] - (sym, sym) - end) + eqn_lookup = Dict() + + terminal_vars = infer_terminal_names(d) + terminal_eqns = SymbolicEquation{Symbolic}[] + foreach(symexpr_list) do x push!(eqn_lookup, (x.lhs => SymbolicUtils.substitute(x.rhs, eqn_lookup))) + if x.lhs.name in terminal_vars + push!(terminal_eqns, SymbolicEquation{Symbolic}(x.lhs, eqn_lookup[x.lhs])) + end end - terminals = Iterators.filter(x -> x.lhs.name in infer_terminal_names(d), symexpr_list) - map(x -> formed_deca_eqn(x.lhs, eqn_lookup[x.lhs]), terminals) + formed_deca_eqn.(terminal_eqns) end -formed_deca_eqn(lhs, rhs) = SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [lhs, rhs]) - -function apply_rewrites(symexprs, rewriter) - map(symexprs) do sym - res_sym = rewriter(sym) - isnothing(res_sym) ? sym : res_sym - end -end +formed_deca_eqn(symeqn::SymbolicEquation{Symbolic}) = SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [symeqn.lhs, symeqn.rhs]) function to_acset(d::SummationDecapode, sym_exprs) outer_types = map([infer_states(d)..., infer_terminals(d)...]) do i @@ -90,8 +87,11 @@ function to_acset(d::SummationDecapode, sym_exprs) end reify!(final_exprs) - deca_block = quote end - deca_block.args = [outer_types..., tangents..., final_exprs...] + deca_block = quote + $(outer_types...) + $(final_exprs...) + end + ∘(infer_types!, SummationDecapode, parse_decapode)(deca_block) end diff --git a/src/sym_rewrite.jl b/src/sym_rewrite.jl index e63903c..cc6e4b4 100644 --- a/src/sym_rewrite.jl +++ b/src/sym_rewrite.jl @@ -1,3 +1,5 @@ +# TODO: Delete this file + using Test using DiagrammaticEquations using SymbolicUtils diff --git a/test/acset2symbolic.jl b/test/acset2symbolic.jl index 89de387..f31efbb 100644 --- a/test/acset2symbolic.jl +++ b/test/acset2symbolic.jl @@ -68,6 +68,36 @@ using Catlab all_ops_res[5, :name] = :D all_ops_res[6, :name] = :C @test all_ops ≃ all_ops_res + + with_deriv = @decapode begin + A::Form0 + Ȧ::Form0 + + ∂ₜ(A) == Ȧ + Ȧ == Δ(A) + end + + @test with_deriv == symbolic_rewriting(with_deriv) + + repeated_vars = @decapode begin + A::Form0 + B::Form1 + C::Form1 + + C == d(A) + C == Δ(B) + C == d(A) + end + + @test repeated_vars == symbolic_rewriting(repeated_vars) + + # TODO: This is broken because of the terminals issue in #77 + self_changing = @decapode begin + A::Form0 + A == ∂ₜ(A) + end + + @test_broken repeated_vars == symbolic_rewriting(self_changing) end function expr_rewriter(rules::Vector)