diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index ee1898f..8df2d91 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -11,92 +11,51 @@ export symbolics_lookup, extract_symexprs, apply_rewrites, merge_equations, to_a const DECA_EQUALITY_SYMBOL = (==) -to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, node::TraversalNode) = to_symbolics(d, symvar_lookup, node.index, Val(node.name)) - function symbolics_lookup(d::SummationDecapode) - lookup = Dict{Symbol, BasicSymbolic}() - for i in parts(d, :Var) - push!(lookup, d[i, :name] => decavar_to_symbolics(d, i)) - end - lookup + Dict{Symbol, BasicSymbolic}(map(parts(d, :Var)) do i + (d[i, :name], decavar_to_symbolics(d, i)) + end) end function decavar_to_symbolics(d::SummationDecapode, index::Int; space = :I) - var = d[index, :name] new_type = SymbolicUtils.symtype(Deca.DECQuantity, d[index, :type], space) - - SymbolicUtils.Sym{new_type}(var) -end - -function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, ::Val{:Op1}) - input_sym = symvar_lookup[d[d[op_index, :src], :name]] - output_sym = symvar_lookup[d[d[op_index, :tgt], :name]] - - op_sym = getfield(@__MODULE__, d[op_index, :op1]) - - S = promote_symtype(op_sym, input_sym) - rhs = SymbolicUtils.Term{S}(op_sym, [input_sym]) - SymbolicEquation{Symbolic}(output_sym, rhs) + SymbolicUtils.Sym{new_type}(d[index, :name]) end -function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, ::Val{:Op2}) - input1_sym = symvar_lookup[d[d[op_index, :proj1], :name]] - input2_sym = symvar_lookup[d[d[op_index, :proj2], :name]] - output_sym = symvar_lookup[d[d[op_index, :res], :name]] +function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, op_type::Symbol) + input_syms = getindex.(Ref(symvar_lookup), d[edge_inputs(d,op_index,Val(op_type)), :name]) + output_sym = getindex.(Ref(symvar_lookup), d[edge_output(d,op_index,Val(op_type)), :name]) - op_sym = getfield(@__MODULE__, d[op_index, :op2]) + op_sym = getfield(@__MODULE__, edge_function(d,op_index,Val(op_type))) - S = promote_symtype(op_sym, input1_sym, input2_sym) - rhs = SymbolicUtils.Term{S}(op_sym, [input1_sym, input2_sym]) + S = promote_symtype(op_sym, input_syms...) + rhs = SymbolicUtils.Term{S}(op_sym, input_syms) SymbolicEquation{Symbolic}(output_sym, rhs) end -#XXX: Always converting + -> .+ here since summation doesn't store the style of addition -function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, ::Val{:Σ}) - syms_array = [symvar_lookup[var] for var in d[d[incident(d, op_index, :summation), :summand], :name]] - output_sym = symvar_lookup[d[d[op_index, :sum], :name]] - - S = promote_symtype(+, syms_array...) - rhs = SymbolicUtils.Term{S}(+, syms_array) - SymbolicEquation{Symbolic}(output_sym,rhs) -end - function symbolic_rewriting(old_d::SummationDecapode, rewriter = nothing) - d = deepcopy(old_d) - - infer_types!(d) + d = infer_types!(deepcopy(old_d)) symvar_lookup = symbolics_lookup(d) eqns = merge_equations(d, symvar_lookup, extract_symexprs(d, symvar_lookup)) - if !isnothing(rewriter) - eqns = map(rewriter, eqns) - end - - to_acset(d, eqns) + to_acset(d, isnothing(rewriter) ? eqns : map(rewriter, eqns)) end function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}) - sym_list = SymbolicEquation{Symbolic}[] - for node in topological_sort_edges(d) - retrieve_name(d, node) != DerivOp || continue # This is not part of ThDEC - push!(sym_list, to_symbolics(d, symvar_lookup, node)) + non_tangents = filter(x -> retrieve_name(d, x) != DerivOp, topological_sort_edges(d)) + map(non_tangents) do node + to_symbolics(d, symvar_lookup, node.index, node.name) end - sym_list end function merge_equations(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, symexpr_list::Vector{SymbolicEquation{Symbolic}}) - - eqn_lookup = Dict() - final_list = [] - for node in start_nodes(d) + eqn_lookup = Dict{Any,Any}(map(start_nodes(d)) do node sym = symvar_lookup[d[node, :name]] - push!(eqn_lookup, (sym => sym)) - end - - final_nodes = infer_terminal_names(d) + (sym, sym) + end) for expr in symexpr_list @@ -112,7 +71,7 @@ function merge_equations(d::SummationDecapode, symvar_lookup::Dict{Symbol, Basic push!(eqn_lookup, (expr.lhs => merged_rhs)) - if expr.lhs.name in final_nodes + if expr.lhs.name in infer_terminal_names(d) push!(final_list, formed_deca_eqn(expr.lhs, merged_rhs)) end end @@ -123,22 +82,13 @@ end formed_deca_eqn(lhs, rhs) = SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [lhs, rhs]) function apply_rewrites(symexprs, rewriter) - - rewritten_list = [] - for sym in symexprs + map(symexprs) do sym res_sym = rewriter(sym) - rewritten_sym = isnothing(res_sym) ? sym : res_sym - push!(rewritten_list, rewritten_sym) + isnothing(res_sym) ? sym : res_sym end - - rewritten_list end -""" -og_d = original reference decapode which provides type information, state and terminal information -""" -function to_acset(og_d, sym_exprs) - +function to_acset(d::SummationDecapode, sym_exprs) #TODO: This step is breaking up summations final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs) @@ -151,28 +101,18 @@ function to_acset(og_d, sym_exprs) end sym => nothing end - map(recursive_descent, final_exprs) - - deca_block = quote end - - states = infer_states(og_d) - terminals = infer_terminals(og_d) + foreach(recursive_descent, final_exprs) - deca_type_gen = idx -> :($(og_d[idx, :name])::$(og_d[idx, :type])) - - append!(deca_block.args, map(deca_type_gen, vcat(states, terminals))) - - for op1 in parts(og_d, :Op1) - if og_d[op1, :op1] == DerivOp - push!(deca_block.args, :($(og_d[og_d[op1, :tgt], :name]) == $DerivOp($(og_d[og_d[op1, :src], :name])))) - end + states_terminals = map([infer_states(d)..., infer_terminals(d)...]) do idx + :($(d[idx, :name])::$(d[idx, :type])) end - append!(deca_block.args, final_exprs) - - d = SummationDecapode(parse_decapode(deca_block)) - - infer_types!(d) + tangents = map(incident(d, DerivOp, :op1)) do op1 + :($(d[d[op1, :tgt], :name]) == $DerivOp($(d[d[op1, :src], :name]))) + end - d + deca_block = quote end + deca_block.args = [states_terminals..., tangents..., final_exprs...] + infer_types!(SummationDecapode(parse_decapode(deca_block))) end + diff --git a/src/graph_traversal.jl b/src/graph_traversal.jl index f2875b2..cc47435 100644 --- a/src/graph_traversal.jl +++ b/src/graph_traversal.jl @@ -1,13 +1,34 @@ using DiagrammaticEquations using ACSets -export TraversalNode, topological_sort_edges, n_ops, retrieve_name, start_nodes +export TraversalNode, topological_sort_edges, n_ops, retrieve_name, start_nodes, edge_inputs, edge_outputs, edge_function struct TraversalNode{T} index::Int name::T end +edge_inputs(d::SummationDecapode, idx::Int, ::Val{:Op1}) = + [d[idx,:src]] +edge_inputs(d::SummationDecapode, idx::Int, ::Val{:Op2}) = + [d[idx,:proj1], d[idx,:proj2]] +edge_inputs(d::SummationDecapode, idx::Int, ::Val{:Σ}) = + d[incident(d, idx, :summation), :summand] + +edge_output(d::SummationDecapode, idx::Int, ::Val{:Op1}) = + d[idx,:tgt] +edge_output(d::SummationDecapode, idx::Int, ::Val{:Op2}) = + d[idx,:res] +edge_output(d::SummationDecapode, idx::Int, ::Val{:Σ}) = + d[idx, :sum] + +edge_function(d::SummationDecapode, idx::Int, ::Val{:Op1}) = + d[idx,:op1] +edge_function(d::SummationDecapode, idx::Int, ::Val{:Op2}) = + d[idx,:op2] +edge_function(d::SummationDecapode, idx::Int, ::Val{:Σ}) = + :+ + function topological_sort_edges(d::SummationDecapode) visited_Var = falses(nparts(d, :Var)) visited_Var[start_nodes(d)] .= true