Skip to content

Commit

Permalink
Clean out-of-order vector constructions
Browse files Browse the repository at this point in the history
  • Loading branch information
lukem12345 committed Sep 27, 2024
1 parent bc9ab00 commit 31ad602
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 93 deletions.
124 changes: 32 additions & 92 deletions src/acset2symbolic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,92 +11,51 @@ export symbolics_lookup, extract_symexprs, apply_rewrites, merge_equations, to_a

const DECA_EQUALITY_SYMBOL = (==)

to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, node::TraversalNode) = to_symbolics(d, symvar_lookup, node.index, Val(node.name))

function symbolics_lookup(d::SummationDecapode)
lookup = Dict{Symbol, BasicSymbolic}()
for i in parts(d, :Var)
push!(lookup, d[i, :name] => decavar_to_symbolics(d, i))
end
lookup
Dict{Symbol, BasicSymbolic}(map(parts(d, :Var)) do i
(d[i, :name], decavar_to_symbolics(d, i))
end)
end

function decavar_to_symbolics(d::SummationDecapode, index::Int; space = :I)
var = d[index, :name]
new_type = SymbolicUtils.symtype(Deca.DECQuantity, d[index, :type], space)

SymbolicUtils.Sym{new_type}(var)
end

function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, ::Val{:Op1})
input_sym = symvar_lookup[d[d[op_index, :src], :name]]
output_sym = symvar_lookup[d[d[op_index, :tgt], :name]]

op_sym = getfield(@__MODULE__, d[op_index, :op1])

S = promote_symtype(op_sym, input_sym)
rhs = SymbolicUtils.Term{S}(op_sym, [input_sym])
SymbolicEquation{Symbolic}(output_sym, rhs)
SymbolicUtils.Sym{new_type}(d[index, :name])
end

function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, ::Val{:Op2})
input1_sym = symvar_lookup[d[d[op_index, :proj1], :name]]
input2_sym = symvar_lookup[d[d[op_index, :proj2], :name]]
output_sym = symvar_lookup[d[d[op_index, :res], :name]]
function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, op_type::Symbol)
input_syms = getindex.(Ref(symvar_lookup), d[edge_inputs(d,op_index,Val(op_type)), :name])
output_sym = getindex.(Ref(symvar_lookup), d[edge_output(d,op_index,Val(op_type)), :name])

op_sym = getfield(@__MODULE__, d[op_index, :op2])
op_sym = getfield(@__MODULE__, edge_function(d,op_index,Val(op_type)))

S = promote_symtype(op_sym, input1_sym, input2_sym)
rhs = SymbolicUtils.Term{S}(op_sym, [input1_sym, input2_sym])
S = promote_symtype(op_sym, input_syms...)
rhs = SymbolicUtils.Term{S}(op_sym, input_syms)
SymbolicEquation{Symbolic}(output_sym, rhs)
end

#XXX: Always converting + -> .+ here since summation doesn't store the style of addition
function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, ::Val{:Σ})
syms_array = [symvar_lookup[var] for var in d[d[incident(d, op_index, :summation), :summand], :name]]
output_sym = symvar_lookup[d[d[op_index, :sum], :name]]

S = promote_symtype(+, syms_array...)
rhs = SymbolicUtils.Term{S}(+, syms_array)
SymbolicEquation{Symbolic}(output_sym,rhs)
end

function symbolic_rewriting(old_d::SummationDecapode, rewriter = nothing)
d = deepcopy(old_d)

infer_types!(d)
d = infer_types!(deepcopy(old_d))

symvar_lookup = symbolics_lookup(d)
eqns = merge_equations(d, symvar_lookup, extract_symexprs(d, symvar_lookup))

if !isnothing(rewriter)
eqns = map(rewriter, eqns)
end

to_acset(d, eqns)
to_acset(d, isnothing(rewriter) ? eqns : map(rewriter, eqns))
end

function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic})
sym_list = SymbolicEquation{Symbolic}[]
for node in topological_sort_edges(d)
retrieve_name(d, node) != DerivOp || continue # This is not part of ThDEC
push!(sym_list, to_symbolics(d, symvar_lookup, node))
non_tangents = filter(x -> retrieve_name(d, x) != DerivOp, topological_sort_edges(d))
map(non_tangents) do node
to_symbolics(d, symvar_lookup, node.index, node.name)
end
sym_list
end

function merge_equations(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, symexpr_list::Vector{SymbolicEquation{Symbolic}})

eqn_lookup = Dict()

final_list = []

for node in start_nodes(d)
eqn_lookup = Dict{Any,Any}(map(start_nodes(d)) do node
sym = symvar_lookup[d[node, :name]]
push!(eqn_lookup, (sym => sym))
end

final_nodes = infer_terminal_names(d)
(sym, sym)
end)

for expr in symexpr_list

Expand All @@ -112,7 +71,7 @@ function merge_equations(d::SummationDecapode, symvar_lookup::Dict{Symbol, Basic

push!(eqn_lookup, (expr.lhs => merged_rhs))

if expr.lhs.name in final_nodes
if expr.lhs.name in infer_terminal_names(d)
push!(final_list, formed_deca_eqn(expr.lhs, merged_rhs))
end
end
Expand All @@ -123,22 +82,13 @@ end
formed_deca_eqn(lhs, rhs) = SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [lhs, rhs])

function apply_rewrites(symexprs, rewriter)

rewritten_list = []
for sym in symexprs
map(symexprs) do sym
res_sym = rewriter(sym)
rewritten_sym = isnothing(res_sym) ? sym : res_sym
push!(rewritten_list, rewritten_sym)
isnothing(res_sym) ? sym : res_sym

Check warning on line 87 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L84-L87

Added lines #L84 - L87 were not covered by tests
end

rewritten_list
end

"""
og_d = original reference decapode which provides type information, state and terminal information
"""
function to_acset(og_d, sym_exprs)

function to_acset(d::SummationDecapode, sym_exprs)
#TODO: This step is breaking up summations
final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs)

Expand All @@ -151,28 +101,18 @@ function to_acset(og_d, sym_exprs)
end
sym => nothing
end
map(recursive_descent, final_exprs)

deca_block = quote end

states = infer_states(og_d)
terminals = infer_terminals(og_d)
foreach(recursive_descent, final_exprs)

deca_type_gen = idx -> :($(og_d[idx, :name])::$(og_d[idx, :type]))

append!(deca_block.args, map(deca_type_gen, vcat(states, terminals)))

for op1 in parts(og_d, :Op1)
if og_d[op1, :op1] == DerivOp
push!(deca_block.args, :($(og_d[og_d[op1, :tgt], :name]) == $DerivOp($(og_d[og_d[op1, :src], :name]))))
end
states_terminals = map([infer_states(d)..., infer_terminals(d)...]) do idx
:($(d[idx, :name])::$(d[idx, :type]))
end

append!(deca_block.args, final_exprs)

d = SummationDecapode(parse_decapode(deca_block))

infer_types!(d)
tangents = map(incident(d, DerivOp, :op1)) do op1
:($(d[d[op1, :tgt], :name]) == $DerivOp($(d[d[op1, :src], :name])))
end

d
deca_block = quote end
deca_block.args = [states_terminals..., tangents..., final_exprs...]
infer_types!(SummationDecapode(parse_decapode(deca_block)))
end

23 changes: 22 additions & 1 deletion src/graph_traversal.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,34 @@
using DiagrammaticEquations
using ACSets

export TraversalNode, topological_sort_edges, n_ops, retrieve_name, start_nodes
export TraversalNode, topological_sort_edges, n_ops, retrieve_name, start_nodes, edge_inputs, edge_outputs, edge_function

struct TraversalNode{T}
index::Int
name::T
end

edge_inputs(d::SummationDecapode, idx::Int, ::Val{:Op1}) =
[d[idx,:src]]
edge_inputs(d::SummationDecapode, idx::Int, ::Val{:Op2}) =
[d[idx,:proj1], d[idx,:proj2]]
edge_inputs(d::SummationDecapode, idx::Int, ::Val{:Σ}) =
d[incident(d, idx, :summation), :summand]

edge_output(d::SummationDecapode, idx::Int, ::Val{:Op1}) =
d[idx,:tgt]
edge_output(d::SummationDecapode, idx::Int, ::Val{:Op2}) =
d[idx,:res]
edge_output(d::SummationDecapode, idx::Int, ::Val{:Σ}) =
d[idx, :sum]

edge_function(d::SummationDecapode, idx::Int, ::Val{:Op1}) =
d[idx,:op1]
edge_function(d::SummationDecapode, idx::Int, ::Val{:Op2}) =
d[idx,:op2]
edge_function(d::SummationDecapode, idx::Int, ::Val{:Σ}) =
:+

function topological_sort_edges(d::SummationDecapode)
visited_Var = falses(nparts(d, :Var))
visited_Var[start_nodes(d)] .= true
Expand Down

0 comments on commit 31ad602

Please sign in to comment.