generated from AlgebraicJulia/AlgebraicTemplate.jl
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'gr/acset2sym' of github.com:AlgebraicJulia/Diagrammatic…
…Equations.jl into gr/acset2sym
- Loading branch information
Showing
3 changed files
with
95 additions
and
189 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,178 +1,97 @@ | ||
using DiagrammaticEquations | ||
using ACSets | ||
using SymbolicUtils | ||
using SymbolicUtils.Rewriters | ||
using SymbolicUtils.Code | ||
using MLStyle | ||
using SymbolicUtils: BasicSymbolic, Symbolic | ||
|
||
import SymbolicUtils: BasicSymbolic, Symbolic | ||
|
||
export symbolics_lookup, extract_symexprs, apply_rewrites, merge_equations, to_acset, symbolic_rewriting, symbolics_lookup | ||
export symbolics_lookup, extract_symexprs, apply_rewrites, merge_equations, to_acset, symbolic_rewriting | ||
|
||
const DECA_EQUALITY_SYMBOL = (==) | ||
|
||
to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, node::TraversalNode) = to_symbolics(d, symvar_lookup, node.index, Val(node.name)) | ||
|
||
function symbolics_lookup(d::SummationDecapode) | ||
lookup = Dict{Symbol, BasicSymbolic}() | ||
for i in parts(d, :Var) | ||
push!(lookup, d[i, :name] => decavar_to_symbolics(d, i)) | ||
end | ||
lookup | ||
end | ||
|
||
function decavar_to_symbolics(d::SummationDecapode, index::Int; space = :I) | ||
var = d[index, :name] | ||
new_type = SymbolicUtils.symtype(Deca.DECQuantity, d[index, :type], space) | ||
|
||
SymbolicUtils.Sym{new_type}(var) | ||
Dict{Symbol, BasicSymbolic}(map(parts(d, :Var)) do i | ||
(d[i, :name], decavar_to_symbolics(d, i)) | ||
end) | ||
end | ||
|
||
function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, ::Val{:Op1}) | ||
input_sym = symvar_lookup[d[d[op_index, :src], :name]] | ||
output_sym = symvar_lookup[d[d[op_index, :tgt], :name]] | ||
|
||
op_sym = getfield(@__MODULE__, d[op_index, :op1]) | ||
|
||
S = promote_symtype(op_sym, input_sym) | ||
rhs = SymbolicUtils.Term{S}(op_sym, [input_sym]) | ||
SymbolicEquation{Symbolic}(output_sym, rhs) | ||
end | ||
|
||
function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, ::Val{:Op2}) | ||
input1_sym = symvar_lookup[d[d[op_index, :proj1], :name]] | ||
input2_sym = symvar_lookup[d[d[op_index, :proj2], :name]] | ||
output_sym = symvar_lookup[d[d[op_index, :res], :name]] | ||
|
||
op_sym = getfield(@__MODULE__, d[op_index, :op2]) | ||
|
||
S = promote_symtype(op_sym, input1_sym, input2_sym) | ||
rhs = SymbolicUtils.Term{S}(op_sym, [input1_sym, input2_sym]) | ||
SymbolicEquation{Symbolic}(output_sym, rhs) | ||
function decavar_to_symbolics(d::SummationDecapode, idx::Int; space = :I) | ||
new_type = SymbolicUtils.symtype(Deca.DECQuantity, d[idx, :type], space) | ||
SymbolicUtils.Sym{new_type}(d[idx, :name]) | ||
end | ||
|
||
#XXX: Always converting + -> .+ here since summation doesn't store the style of addition | ||
function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, ::Val{:Σ}) | ||
syms_array = [symvar_lookup[var] for var in d[d[incident(d, op_index, :summation), :summand], :name]] | ||
output_sym = symvar_lookup[d[d[op_index, :sum], :name]] | ||
function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_idx::Int, op_type::Symbol) | ||
input_syms = getindex.(Ref(symvar_lookup), d[edge_inputs(d,op_idx,Val(op_type)), :name]) | ||
output_sym = getindex.(Ref(symvar_lookup), d[edge_output(d,op_idx,Val(op_type)), :name]) | ||
op_sym = getfield(@__MODULE__, edge_function(d,op_idx,Val(op_type))) | ||
|
||
S = promote_symtype(+, syms_array...) | ||
rhs = SymbolicUtils.Term{S}(+, syms_array) | ||
SymbolicEquation{Symbolic}(output_sym,rhs) | ||
S = promote_symtype(op_sym, input_syms...) | ||
SymbolicEquation{Symbolic}(output_sym, SymbolicUtils.Term{S}(op_sym, input_syms)) | ||
end | ||
|
||
function symbolic_rewriting(old_d::SummationDecapode, rewriter = nothing) | ||
d = deepcopy(old_d) | ||
|
||
infer_types!(d) | ||
|
||
symvar_lookup = symbolics_lookup(d) | ||
eqns = merge_equations(d, symvar_lookup, extract_symexprs(d, symvar_lookup)) | ||
|
||
if !isnothing(rewriter) | ||
eqns = map(rewriter, eqns) | ||
end | ||
|
||
to_acset(d, eqns) | ||
d = infer_types!(deepcopy(old_d)) | ||
eqns = merge_equations(d) | ||
to_acset(d, isnothing(rewriter) ? eqns : map(rewriter, eqns)) | ||
end | ||
|
||
function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}) | ||
sym_list = SymbolicEquation{Symbolic}[] | ||
for node in topological_sort_edges(d) | ||
# retrieve_name(d, node) != DerivOp || continue # This is not part of ThDEC | ||
push!(sym_list, to_symbolics(d, symvar_lookup, node)) | ||
non_tangents = Iterators.filter(x -> retrieve_name(d, x) != DerivOp, topological_sort_edges(d)) | ||
map(non_tangents) do node | ||
to_symbolics(d, symvar_lookup, node.index, node.name) | ||
end | ||
sym_list | ||
end | ||
|
||
function merge_equations(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, symexpr_list::Vector{SymbolicEquation{Symbolic}}) | ||
|
||
eqn_lookup = Dict() | ||
|
||
final_list = [] | ||
|
||
for node in start_nodes(d) | ||
sym = symvar_lookup[d[node, :name]] | ||
push!(eqn_lookup, (sym => sym)) | ||
end | ||
|
||
final_nodes = infer_terminal_names(d) | ||
|
||
for expr in symexpr_list | ||
|
||
# XXX SymbolicUtils.substitute swaps the order of multiplication. | ||
# example: @decapode begin | ||
# u::Form0 | ||
# G::Form0 | ||
# κ::Constant | ||
# ∂ₜ(G) == κ*★(d(★(d(u)))) | ||
# end | ||
# will have the kappa*var term rewritten to var*kappa | ||
merged_rhs = SymbolicUtils.substitute(expr.rhs, eqn_lookup) | ||
|
||
push!(eqn_lookup, (expr.lhs => merged_rhs)) | ||
|
||
if expr.lhs.name in final_nodes | ||
push!(final_list, formed_deca_eqn(expr.lhs, merged_rhs)) | ||
end | ||
# XXX SymbolicUtils.substitute swaps the order of multiplication. | ||
# e.g. @decapode begin | ||
# ∂ₜ(G) == κ*u | ||
# end | ||
# will have the κ*u term rewritten to u*κ | ||
function merge_equations(d::SummationDecapode) | ||
symvar_lookup = symbolics_lookup(d) | ||
symexpr_list = extract_symexprs(d, symvar_lookup) | ||
|
||
eqn_lookup = Dict{Any,Any}(map(start_nodes(d)) do i | ||
sym = symvar_lookup[d[i, :name]] | ||
(sym, sym) | ||
end) | ||
foreach(symexpr_list) do x | ||
push!(eqn_lookup, (x.lhs => SymbolicUtils.substitute(x.rhs, eqn_lookup))) | ||
end | ||
|
||
final_list | ||
terminals = Iterators.filter(x -> x.lhs.name in infer_terminal_names(d), symexpr_list) | ||
map(x -> formed_deca_eqn(x.lhs, eqn_lookup[x.lhs]), terminals) | ||
end | ||
|
||
formed_deca_eqn(lhs, rhs) = SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [lhs, rhs]) | ||
|
||
function apply_rewrites(symexprs, rewriter) | ||
|
||
rewritten_list = [] | ||
for sym in symexprs | ||
map(symexprs) do sym | ||
res_sym = rewriter(sym) | ||
rewritten_sym = isnothing(res_sym) ? sym : res_sym | ||
push!(rewritten_list, rewritten_sym) | ||
isnothing(res_sym) ? sym : res_sym | ||
end | ||
|
||
rewritten_list | ||
end | ||
|
||
""" | ||
og_d = original reference decapode which provides type information, state and terminal information | ||
""" | ||
function to_acset(og_d, sym_exprs) | ||
function to_acset(d::SummationDecapode, sym_exprs) | ||
outer_types = map([infer_states(d)..., infer_terminals(d)...]) do i | ||
:($(d[i, :name])::$(d[i, :type])) | ||
end | ||
|
||
tangents = map(incident(d, DerivOp, :op1)) do op1 | ||
:($(d[d[op1, :tgt], :name]) == $DerivOp($(d[d[op1, :src], :name]))) | ||
end | ||
|
||
#TODO: This step is breaking up summations | ||
final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs) | ||
|
||
recursive_descent = @λ begin | ||
e::Expr => begin | ||
if e.head == :call | ||
e.args[1] = nameof(e.args[1]) | ||
map(recursive_descent, e.args[2:end]) | ||
end | ||
reify!(exprs) = foreach(exprs) do x | ||
if typeof(x)==Expr && x.head == :call | ||
x.args[1] = nameof(x.args[1]) | ||
reify!(x.args[2:end]) | ||
end | ||
sym => nothing | ||
end | ||
map(recursive_descent, final_exprs) | ||
reify!(final_exprs) | ||
|
||
deca_block = quote end | ||
|
||
states = infer_states(og_d) | ||
terminals = infer_terminals(og_d) | ||
|
||
deca_type_gen = idx -> :($(og_d[idx, :name])::$(og_d[idx, :type])) | ||
|
||
append!(deca_block.args, map(deca_type_gen, vcat(states, terminals))) | ||
|
||
for op1 in parts(og_d, :Op1) | ||
if og_d[op1, :op1] == DerivOp | ||
push!(deca_block.args, :($(og_d[og_d[op1, :tgt], :name]) == $DerivOp($(og_d[og_d[op1, :src], :name])))) | ||
end | ||
end | ||
|
||
append!(deca_block.args, final_exprs) | ||
|
||
d = SummationDecapode(parse_decapode(deca_block)) | ||
|
||
infer_types!(d) | ||
|
||
d | ||
deca_block.args = [outer_types..., tangents..., final_exprs...] | ||
∘(infer_types!, SummationDecapode, parse_decapode)(deca_block) | ||
end | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters