Skip to content

Commit

Permalink
Merge branch 'gr/acset2sym' of github.com:AlgebraicJulia/Diagrammatic…
Browse files Browse the repository at this point in the history
…Equations.jl into gr/acset2sym
  • Loading branch information
quffaro committed Sep 27, 2024
2 parents 2b3198f + 67079cb commit 367414d
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 189 deletions.
191 changes: 55 additions & 136 deletions src/acset2symbolic.jl
Original file line number Diff line number Diff line change
@@ -1,178 +1,97 @@
using DiagrammaticEquations
using ACSets
using SymbolicUtils
using SymbolicUtils.Rewriters
using SymbolicUtils.Code
using MLStyle
using SymbolicUtils: BasicSymbolic, Symbolic

import SymbolicUtils: BasicSymbolic, Symbolic

export symbolics_lookup, extract_symexprs, apply_rewrites, merge_equations, to_acset, symbolic_rewriting, symbolics_lookup
export symbolics_lookup, extract_symexprs, apply_rewrites, merge_equations, to_acset, symbolic_rewriting

const DECA_EQUALITY_SYMBOL = (==)

to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, node::TraversalNode) = to_symbolics(d, symvar_lookup, node.index, Val(node.name))

function symbolics_lookup(d::SummationDecapode)
lookup = Dict{Symbol, BasicSymbolic}()
for i in parts(d, :Var)
push!(lookup, d[i, :name] => decavar_to_symbolics(d, i))
end
lookup
end

function decavar_to_symbolics(d::SummationDecapode, index::Int; space = :I)
var = d[index, :name]
new_type = SymbolicUtils.symtype(Deca.DECQuantity, d[index, :type], space)

SymbolicUtils.Sym{new_type}(var)
Dict{Symbol, BasicSymbolic}(map(parts(d, :Var)) do i
(d[i, :name], decavar_to_symbolics(d, i))
end)
end

function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, ::Val{:Op1})
input_sym = symvar_lookup[d[d[op_index, :src], :name]]
output_sym = symvar_lookup[d[d[op_index, :tgt], :name]]

op_sym = getfield(@__MODULE__, d[op_index, :op1])

S = promote_symtype(op_sym, input_sym)
rhs = SymbolicUtils.Term{S}(op_sym, [input_sym])
SymbolicEquation{Symbolic}(output_sym, rhs)
end

function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, ::Val{:Op2})
input1_sym = symvar_lookup[d[d[op_index, :proj1], :name]]
input2_sym = symvar_lookup[d[d[op_index, :proj2], :name]]
output_sym = symvar_lookup[d[d[op_index, :res], :name]]

op_sym = getfield(@__MODULE__, d[op_index, :op2])

S = promote_symtype(op_sym, input1_sym, input2_sym)
rhs = SymbolicUtils.Term{S}(op_sym, [input1_sym, input2_sym])
SymbolicEquation{Symbolic}(output_sym, rhs)
function decavar_to_symbolics(d::SummationDecapode, idx::Int; space = :I)
new_type = SymbolicUtils.symtype(Deca.DECQuantity, d[idx, :type], space)
SymbolicUtils.Sym{new_type}(d[idx, :name])
end

#XXX: Always converting + -> .+ here since summation doesn't store the style of addition
function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, ::Val{:Σ})
syms_array = [symvar_lookup[var] for var in d[d[incident(d, op_index, :summation), :summand], :name]]
output_sym = symvar_lookup[d[d[op_index, :sum], :name]]
function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_idx::Int, op_type::Symbol)
input_syms = getindex.(Ref(symvar_lookup), d[edge_inputs(d,op_idx,Val(op_type)), :name])
output_sym = getindex.(Ref(symvar_lookup), d[edge_output(d,op_idx,Val(op_type)), :name])
op_sym = getfield(@__MODULE__, edge_function(d,op_idx,Val(op_type)))

S = promote_symtype(+, syms_array...)
rhs = SymbolicUtils.Term{S}(+, syms_array)
SymbolicEquation{Symbolic}(output_sym,rhs)
S = promote_symtype(op_sym, input_syms...)
SymbolicEquation{Symbolic}(output_sym, SymbolicUtils.Term{S}(op_sym, input_syms))
end

function symbolic_rewriting(old_d::SummationDecapode, rewriter = nothing)
d = deepcopy(old_d)

infer_types!(d)

symvar_lookup = symbolics_lookup(d)
eqns = merge_equations(d, symvar_lookup, extract_symexprs(d, symvar_lookup))

if !isnothing(rewriter)
eqns = map(rewriter, eqns)
end

to_acset(d, eqns)
d = infer_types!(deepcopy(old_d))
eqns = merge_equations(d)
to_acset(d, isnothing(rewriter) ? eqns : map(rewriter, eqns))
end

function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic})
sym_list = SymbolicEquation{Symbolic}[]
for node in topological_sort_edges(d)
# retrieve_name(d, node) != DerivOp || continue # This is not part of ThDEC
push!(sym_list, to_symbolics(d, symvar_lookup, node))
non_tangents = Iterators.filter(x -> retrieve_name(d, x) != DerivOp, topological_sort_edges(d))
map(non_tangents) do node
to_symbolics(d, symvar_lookup, node.index, node.name)
end
sym_list
end

function merge_equations(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, symexpr_list::Vector{SymbolicEquation{Symbolic}})

eqn_lookup = Dict()

final_list = []

for node in start_nodes(d)
sym = symvar_lookup[d[node, :name]]
push!(eqn_lookup, (sym => sym))
end

final_nodes = infer_terminal_names(d)

for expr in symexpr_list

# XXX SymbolicUtils.substitute swaps the order of multiplication.
# example: @decapode begin
# u::Form0
# G::Form0
# κ::Constant
# ∂ₜ(G) == κ*★(d(★(d(u))))
# end
# will have the kappa*var term rewritten to var*kappa
merged_rhs = SymbolicUtils.substitute(expr.rhs, eqn_lookup)

push!(eqn_lookup, (expr.lhs => merged_rhs))

if expr.lhs.name in final_nodes
push!(final_list, formed_deca_eqn(expr.lhs, merged_rhs))
end
# XXX SymbolicUtils.substitute swaps the order of multiplication.
# e.g. @decapode begin
# ∂ₜ(G) == κ*u
# end
# will have the κ*u term rewritten to u*κ
function merge_equations(d::SummationDecapode)
symvar_lookup = symbolics_lookup(d)
symexpr_list = extract_symexprs(d, symvar_lookup)

eqn_lookup = Dict{Any,Any}(map(start_nodes(d)) do i
sym = symvar_lookup[d[i, :name]]
(sym, sym)
end)
foreach(symexpr_list) do x
push!(eqn_lookup, (x.lhs => SymbolicUtils.substitute(x.rhs, eqn_lookup)))
end

final_list
terminals = Iterators.filter(x -> x.lhs.name in infer_terminal_names(d), symexpr_list)
map(x -> formed_deca_eqn(x.lhs, eqn_lookup[x.lhs]), terminals)
end

formed_deca_eqn(lhs, rhs) = SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [lhs, rhs])

function apply_rewrites(symexprs, rewriter)

rewritten_list = []
for sym in symexprs
map(symexprs) do sym
res_sym = rewriter(sym)
rewritten_sym = isnothing(res_sym) ? sym : res_sym
push!(rewritten_list, rewritten_sym)
isnothing(res_sym) ? sym : res_sym
end

rewritten_list
end

"""
og_d = original reference decapode which provides type information, state and terminal information
"""
function to_acset(og_d, sym_exprs)
function to_acset(d::SummationDecapode, sym_exprs)
outer_types = map([infer_states(d)..., infer_terminals(d)...]) do i
:($(d[i, :name])::$(d[i, :type]))
end

tangents = map(incident(d, DerivOp, :op1)) do op1
:($(d[d[op1, :tgt], :name]) == $DerivOp($(d[d[op1, :src], :name])))
end

#TODO: This step is breaking up summations
final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs)

recursive_descent = begin
e::Expr => begin
if e.head == :call
e.args[1] = nameof(e.args[1])
map(recursive_descent, e.args[2:end])
end
reify!(exprs) = foreach(exprs) do x
if typeof(x)==Expr && x.head == :call
x.args[1] = nameof(x.args[1])
reify!(x.args[2:end])
end
sym => nothing
end
map(recursive_descent, final_exprs)
reify!(final_exprs)

deca_block = quote end

states = infer_states(og_d)
terminals = infer_terminals(og_d)

deca_type_gen = idx -> :($(og_d[idx, :name])::$(og_d[idx, :type]))

append!(deca_block.args, map(deca_type_gen, vcat(states, terminals)))

for op1 in parts(og_d, :Op1)
if og_d[op1, :op1] == DerivOp
push!(deca_block.args, :($(og_d[og_d[op1, :tgt], :name]) == $DerivOp($(og_d[og_d[op1, :src], :name]))))
end
end

append!(deca_block.args, final_exprs)

d = SummationDecapode(parse_decapode(deca_block))

infer_types!(d)

d
deca_block.args = [outer_types..., tangents..., final_exprs...]
(infer_types!, SummationDecapode, parse_decapode)(deca_block)
end

79 changes: 40 additions & 39 deletions src/graph_traversal.jl
Original file line number Diff line number Diff line change
@@ -1,66 +1,66 @@
using DiagrammaticEquations
using ACSets

export TraversalNode, topological_sort_edges, n_ops, retrieve_name, start_nodes
export TraversalNode, topological_sort_edges, n_ops, retrieve_name, start_nodes, edge_inputs, edge_output, edge_function

struct TraversalNode{T}
index::Int
name::T
end

edge_inputs(d::SummationDecapode, idx::Int, ::Val{:Op1}) =
[d[idx,:src]]
edge_inputs(d::SummationDecapode, idx::Int, ::Val{:Op2}) =
[d[idx,:proj1], d[idx,:proj2]]
edge_inputs(d::SummationDecapode, idx::Int, ::Val{:Σ}) =
d[incident(d, idx, :summation), :summand]

edge_output(d::SummationDecapode, idx::Int, ::Val{:Op1}) =
d[idx,:tgt]
edge_output(d::SummationDecapode, idx::Int, ::Val{:Op2}) =
d[idx,:res]
edge_output(d::SummationDecapode, idx::Int, ::Val{:Σ}) =
d[idx, :sum]

edge_function(d::SummationDecapode, idx::Int, ::Val{:Op1}) =
d[idx,:op1]
edge_function(d::SummationDecapode, idx::Int, ::Val{:Op2}) =
d[idx,:op2]
edge_function(d::SummationDecapode, idx::Int, ::Val{:Σ}) =
:+

#XXX: This topological sort is O(n^2).
function topological_sort_edges(d::SummationDecapode)
visited_Var = falses(nparts(d, :Var))
visited_Var[start_nodes(d)] .= true
visited = Dict(:Op1 => falses(nparts(d, :Op1)),
:Op2 => falses(nparts(d, :Op2)), => falses(nparts(d, )))

# TODO: Collect these visited arrays into one structure indexed by :Op1, :Op2, and :Σ
visited_1 = falses(nparts(d, :Op1))
visited_2 = falses(nparts(d, :Op2))
visited_Σ = falses(nparts(d, ))

# FIXME: this is a quadratic implementation of topological_sort inlined in here.
op_order = TraversalNode{Symbol}[]

for _ in 1:n_ops(d)
for op in parts(d, :Op1)
if !visited_1[op] && visited_Var[d[op, :src]]

visited_1[op] = true
visited_Var[d[op, :tgt]] = true

push!(op_order, TraversalNode(op, :Op1))
end
end

for op in parts(d, :Op2)
if !visited_2[op] && visited_Var[d[op, :proj1]] && visited_Var[d[op, :proj2]]
visited_2[op] = true
visited_Var[d[op, :res]] = true
push!(op_order, TraversalNode(op, :Op2))
end
function visit(op, op_type)
if !visited[op_type][op] && all(visited_Var[edge_inputs(d,op,Val(op_type))])
visited[op_type][op] = true
visited_Var[edge_output(d,op,Val(op_type))] = true
push!(op_order, TraversalNode(op, op_type))
end
end

for op in parts(d, )
args = subpart(d, incident(d, op, :summation), :summand)
if !visited_Σ[op] && all(visited_Var[args])
visited_Σ[op] = true
visited_Var[d[op, :sum]] = true
push!(op_order, TraversalNode(op, ))
end
end
for _ in 1:n_ops(d)
visit.(parts(d,:Op1), :Op1)
visit.(parts(d,:Op2), :Op2)
visit.(parts(d,), )
end

@assert length(op_order) == n_ops(d)

op_order
end

function n_ops(d::SummationDecapode)
return nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, )
end
n_ops(d::SummationDecapode) =
nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, )

function start_nodes(d::SummationDecapode)
return vcat(infer_states(d), incident(d, :Literal, :type))
end
start_nodes(d::SummationDecapode) =
vcat(infer_states(d), incident(d, :Literal, :type))

function retrieve_name(d::SummationDecapode, tsr::TraversalNode)
@match tsr.name begin
Expand All @@ -70,3 +70,4 @@ function retrieve_name(d::SummationDecapode, tsr::TraversalNode)
_ => error("$(tsr.name) is a table without names")
end
end

14 changes: 0 additions & 14 deletions test/acset2symbolic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,17 +202,3 @@ end
@test Heat_open z

end

x=@decapode begin
u::Form0
∂ₜ(u) == u
end
symbolic_rewriting(x)
# if the `for op1 in parts(og_d, :Op1)...` block is removed, this is annihilated because x has no terminals

x=@decapode begin
u::Form0
v::Form0
∂ₜ(v) == u
end
symbolic_rewriting(x) # fine

0 comments on commit 367414d

Please sign in to comment.