From d408c26d3747c354545e90b5d53f184603aff1bd Mon Sep 17 00:00:00 2001 From: Luke Morris Date: Thu, 26 Sep 2024 23:24:27 -0400 Subject: [PATCH] Convert to symbolics inside merge_equations --- src/acset2symbolic.jl | 77 ++++++++++++++++++------------------------ src/graph_traversal.jl | 2 +- 2 files changed, 34 insertions(+), 45 deletions(-) diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index 8df2d91..94afff2 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -1,11 +1,9 @@ using DiagrammaticEquations using ACSets +using MLStyle using SymbolicUtils using SymbolicUtils.Rewriters -using SymbolicUtils.Code -using MLStyle - -import SymbolicUtils: BasicSymbolic, Symbolic +using SymbolicUtils: BasicSymbolic, Symbolic export symbolics_lookup, extract_symexprs, apply_rewrites, merge_equations, to_acset, symbolic_rewriting, symbolics_lookup @@ -17,16 +15,16 @@ function symbolics_lookup(d::SummationDecapode) end) end -function decavar_to_symbolics(d::SummationDecapode, index::Int; space = :I) - new_type = SymbolicUtils.symtype(Deca.DECQuantity, d[index, :type], space) - SymbolicUtils.Sym{new_type}(d[index, :name]) +function decavar_to_symbolics(d::SummationDecapode, idx::Int; space = :I) + new_type = SymbolicUtils.symtype(Deca.DECQuantity, d[idx, :type], space) + SymbolicUtils.Sym{new_type}(d[idx, :name]) end -function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, op_type::Symbol) - input_syms = getindex.(Ref(symvar_lookup), d[edge_inputs(d,op_index,Val(op_type)), :name]) - output_sym = getindex.(Ref(symvar_lookup), d[edge_output(d,op_index,Val(op_type)), :name]) +function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_idx::Int, op_type::Symbol) + input_syms = getindex.(Ref(symvar_lookup), d[edge_inputs(d,op_idx,Val(op_type)), :name]) + output_sym = getindex.(Ref(symvar_lookup), d[edge_output(d,op_idx,Val(op_type)), :name]) - op_sym = getfield(@__MODULE__, edge_function(d,op_index,Val(op_type))) + op_sym = getfield(@__MODULE__, edge_function(d,op_idx,Val(op_type))) S = promote_symtype(op_sym, input_syms...) rhs = SymbolicUtils.Term{S}(op_sym, input_syms) @@ -35,10 +33,7 @@ end function symbolic_rewriting(old_d::SummationDecapode, rewriter = nothing) d = infer_types!(deepcopy(old_d)) - - symvar_lookup = symbolics_lookup(d) - eqns = merge_equations(d, symvar_lookup, extract_symexprs(d, symvar_lookup)) - + eqns = merge_equations(d) to_acset(d, isnothing(rewriter) ? eqns : map(rewriter, eqns)) end @@ -49,34 +44,29 @@ function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, Basi end end -function merge_equations(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, symexpr_list::Vector{SymbolicEquation{Symbolic}}) - final_list = [] +# XXX SymbolicUtils.substitute swaps the order of multiplication. +# example: @decapode begin +# u::Form0 +# G::Form0 +# κ::Constant +# ∂ₜ(G) == κ*★(d(★(d(u)))) +# end +# will have the kappa*var term rewritten to var*kappa +function merge_equations(d::SummationDecapode) + symvar_lookup = symbolics_lookup(d) + symexpr_list = extract_symexprs(d, symvar_lookup) eqn_lookup = Dict{Any,Any}(map(start_nodes(d)) do node sym = symvar_lookup[d[node, :name]] (sym, sym) end) - - for expr in symexpr_list - - # XXX SymbolicUtils.substitute swaps the order of multiplication. - # example: @decapode begin - # u::Form0 - # G::Form0 - # κ::Constant - # ∂ₜ(G) == κ*★(d(★(d(u)))) - # end - # will have the kappa*var term rewritten to var*kappa + foreach(symexpr_list) do expr merged_rhs = SymbolicUtils.substitute(expr.rhs, eqn_lookup) - push!(eqn_lookup, (expr.lhs => merged_rhs)) - - if expr.lhs.name in infer_terminal_names(d) - push!(final_list, formed_deca_eqn(expr.lhs, merged_rhs)) - end end - final_list + terminals = filter(x -> x.lhs.name in infer_terminal_names(d), symexpr_list) + map(x -> formed_deca_eqn(x.lhs, eqn_lookup[x.lhs]), terminals) end formed_deca_eqn(lhs, rhs) = SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [lhs, rhs]) @@ -89,9 +79,16 @@ function apply_rewrites(symexprs, rewriter) end function to_acset(d::SummationDecapode, sym_exprs) + outer_types = map([infer_states(d)..., infer_terminals(d)...]) do idx + :($(d[idx, :name])::$(d[idx, :type])) + end + + tangents = map(incident(d, DerivOp, :op1)) do op1 + :($(d[d[op1, :tgt], :name]) == $DerivOp($(d[d[op1, :src], :name]))) + end + #TODO: This step is breaking up summations final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs) - recursive_descent = @λ begin e::Expr => begin if e.head == :call @@ -103,16 +100,8 @@ function to_acset(d::SummationDecapode, sym_exprs) end foreach(recursive_descent, final_exprs) - states_terminals = map([infer_states(d)..., infer_terminals(d)...]) do idx - :($(d[idx, :name])::$(d[idx, :type])) - end - - tangents = map(incident(d, DerivOp, :op1)) do op1 - :($(d[d[op1, :tgt], :name]) == $DerivOp($(d[d[op1, :src], :name]))) - end - deca_block = quote end - deca_block.args = [states_terminals..., tangents..., final_exprs...] + deca_block.args = [outer_types..., tangents..., final_exprs...] infer_types!(SummationDecapode(parse_decapode(deca_block))) end diff --git a/src/graph_traversal.jl b/src/graph_traversal.jl index cc47435..6266240 100644 --- a/src/graph_traversal.jl +++ b/src/graph_traversal.jl @@ -1,7 +1,7 @@ using DiagrammaticEquations using ACSets -export TraversalNode, topological_sort_edges, n_ops, retrieve_name, start_nodes, edge_inputs, edge_outputs, edge_function +export TraversalNode, topological_sort_edges, n_ops, retrieve_name, start_nodes, edge_inputs, edge_output, edge_function struct TraversalNode{T} index::Int