Skip to content

Commit

Permalink
Convert to symbolics inside merge_equations
Browse files Browse the repository at this point in the history
  • Loading branch information
lukem12345 committed Sep 27, 2024
1 parent 31ad602 commit d408c26
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 45 deletions.
77 changes: 33 additions & 44 deletions src/acset2symbolic.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
using DiagrammaticEquations
using ACSets
using MLStyle
using SymbolicUtils
using SymbolicUtils.Rewriters
using SymbolicUtils.Code
using MLStyle

import SymbolicUtils: BasicSymbolic, Symbolic
using SymbolicUtils: BasicSymbolic, Symbolic

export symbolics_lookup, extract_symexprs, apply_rewrites, merge_equations, to_acset, symbolic_rewriting, symbolics_lookup

Expand All @@ -17,16 +15,16 @@ function symbolics_lookup(d::SummationDecapode)
end)
end

function decavar_to_symbolics(d::SummationDecapode, index::Int; space = :I)
new_type = SymbolicUtils.symtype(Deca.DECQuantity, d[index, :type], space)
SymbolicUtils.Sym{new_type}(d[index, :name])
function decavar_to_symbolics(d::SummationDecapode, idx::Int; space = :I)
new_type = SymbolicUtils.symtype(Deca.DECQuantity, d[idx, :type], space)
SymbolicUtils.Sym{new_type}(d[idx, :name])
end

function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, op_type::Symbol)
input_syms = getindex.(Ref(symvar_lookup), d[edge_inputs(d,op_index,Val(op_type)), :name])
output_sym = getindex.(Ref(symvar_lookup), d[edge_output(d,op_index,Val(op_type)), :name])
function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_idx::Int, op_type::Symbol)
input_syms = getindex.(Ref(symvar_lookup), d[edge_inputs(d,op_idx,Val(op_type)), :name])
output_sym = getindex.(Ref(symvar_lookup), d[edge_output(d,op_idx,Val(op_type)), :name])

op_sym = getfield(@__MODULE__, edge_function(d,op_index,Val(op_type)))
op_sym = getfield(@__MODULE__, edge_function(d,op_idx,Val(op_type)))

S = promote_symtype(op_sym, input_syms...)
rhs = SymbolicUtils.Term{S}(op_sym, input_syms)
Expand All @@ -35,10 +33,7 @@ end

function symbolic_rewriting(old_d::SummationDecapode, rewriter = nothing)
d = infer_types!(deepcopy(old_d))

symvar_lookup = symbolics_lookup(d)
eqns = merge_equations(d, symvar_lookup, extract_symexprs(d, symvar_lookup))

eqns = merge_equations(d)
to_acset(d, isnothing(rewriter) ? eqns : map(rewriter, eqns))
end

Expand All @@ -49,34 +44,29 @@ function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, Basi
end
end

function merge_equations(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, symexpr_list::Vector{SymbolicEquation{Symbolic}})
final_list = []
# XXX SymbolicUtils.substitute swaps the order of multiplication.
# example: @decapode begin
# u::Form0
# G::Form0
# κ::Constant
# ∂ₜ(G) == κ*★(d(★(d(u))))
# end
# will have the kappa*var term rewritten to var*kappa
function merge_equations(d::SummationDecapode)
symvar_lookup = symbolics_lookup(d)
symexpr_list = extract_symexprs(d, symvar_lookup)

eqn_lookup = Dict{Any,Any}(map(start_nodes(d)) do node
sym = symvar_lookup[d[node, :name]]
(sym, sym)
end)

for expr in symexpr_list

# XXX SymbolicUtils.substitute swaps the order of multiplication.
# example: @decapode begin
# u::Form0
# G::Form0
# κ::Constant
# ∂ₜ(G) == κ*★(d(★(d(u))))
# end
# will have the kappa*var term rewritten to var*kappa
foreach(symexpr_list) do expr
merged_rhs = SymbolicUtils.substitute(expr.rhs, eqn_lookup)

push!(eqn_lookup, (expr.lhs => merged_rhs))

if expr.lhs.name in infer_terminal_names(d)
push!(final_list, formed_deca_eqn(expr.lhs, merged_rhs))
end
end

final_list
terminals = filter(x -> x.lhs.name in infer_terminal_names(d), symexpr_list)
map(x -> formed_deca_eqn(x.lhs, eqn_lookup[x.lhs]), terminals)
end

formed_deca_eqn(lhs, rhs) = SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [lhs, rhs])
Expand All @@ -89,9 +79,16 @@ function apply_rewrites(symexprs, rewriter)
end

function to_acset(d::SummationDecapode, sym_exprs)
outer_types = map([infer_states(d)..., infer_terminals(d)...]) do idx
:($(d[idx, :name])::$(d[idx, :type]))
end

tangents = map(incident(d, DerivOp, :op1)) do op1
:($(d[d[op1, :tgt], :name]) == $DerivOp($(d[d[op1, :src], :name])))
end

#TODO: This step is breaking up summations
final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs)

recursive_descent = begin
e::Expr => begin
if e.head == :call
Expand All @@ -103,16 +100,8 @@ function to_acset(d::SummationDecapode, sym_exprs)
end
foreach(recursive_descent, final_exprs)

states_terminals = map([infer_states(d)..., infer_terminals(d)...]) do idx
:($(d[idx, :name])::$(d[idx, :type]))
end

tangents = map(incident(d, DerivOp, :op1)) do op1
:($(d[d[op1, :tgt], :name]) == $DerivOp($(d[d[op1, :src], :name])))
end

deca_block = quote end
deca_block.args = [states_terminals..., tangents..., final_exprs...]
deca_block.args = [outer_types..., tangents..., final_exprs...]
infer_types!(SummationDecapode(parse_decapode(deca_block)))
end

2 changes: 1 addition & 1 deletion src/graph_traversal.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using DiagrammaticEquations
using ACSets

export TraversalNode, topological_sort_edges, n_ops, retrieve_name, start_nodes, edge_inputs, edge_outputs, edge_function
export TraversalNode, topological_sort_edges, n_ops, retrieve_name, start_nodes, edge_inputs, edge_output, edge_function

struct TraversalNode{T}
index::Int
Expand Down

0 comments on commit d408c26

Please sign in to comment.