From 31ad602145ba5f9d4a901dc3ac9935c4704317a1 Mon Sep 17 00:00:00 2001 From: Luke Morris Date: Thu, 26 Sep 2024 22:26:28 -0400 Subject: [PATCH 1/4] Clean out-of-order vector constructions --- src/acset2symbolic.jl | 124 +++++++++++------------------------------ src/graph_traversal.jl | 23 +++++++- 2 files changed, 54 insertions(+), 93 deletions(-) diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index ee1898f..8df2d91 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -11,92 +11,51 @@ export symbolics_lookup, extract_symexprs, apply_rewrites, merge_equations, to_a const DECA_EQUALITY_SYMBOL = (==) -to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, node::TraversalNode) = to_symbolics(d, symvar_lookup, node.index, Val(node.name)) - function symbolics_lookup(d::SummationDecapode) - lookup = Dict{Symbol, BasicSymbolic}() - for i in parts(d, :Var) - push!(lookup, d[i, :name] => decavar_to_symbolics(d, i)) - end - lookup + Dict{Symbol, BasicSymbolic}(map(parts(d, :Var)) do i + (d[i, :name], decavar_to_symbolics(d, i)) + end) end function decavar_to_symbolics(d::SummationDecapode, index::Int; space = :I) - var = d[index, :name] new_type = SymbolicUtils.symtype(Deca.DECQuantity, d[index, :type], space) - - SymbolicUtils.Sym{new_type}(var) -end - -function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, ::Val{:Op1}) - input_sym = symvar_lookup[d[d[op_index, :src], :name]] - output_sym = symvar_lookup[d[d[op_index, :tgt], :name]] - - op_sym = getfield(@__MODULE__, d[op_index, :op1]) - - S = promote_symtype(op_sym, input_sym) - rhs = SymbolicUtils.Term{S}(op_sym, [input_sym]) - SymbolicEquation{Symbolic}(output_sym, rhs) + SymbolicUtils.Sym{new_type}(d[index, :name]) end -function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, ::Val{:Op2}) - input1_sym = symvar_lookup[d[d[op_index, :proj1], :name]] - input2_sym = symvar_lookup[d[d[op_index, :proj2], :name]] - output_sym = symvar_lookup[d[d[op_index, :res], :name]] +function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, op_type::Symbol) + input_syms = getindex.(Ref(symvar_lookup), d[edge_inputs(d,op_index,Val(op_type)), :name]) + output_sym = getindex.(Ref(symvar_lookup), d[edge_output(d,op_index,Val(op_type)), :name]) - op_sym = getfield(@__MODULE__, d[op_index, :op2]) + op_sym = getfield(@__MODULE__, edge_function(d,op_index,Val(op_type))) - S = promote_symtype(op_sym, input1_sym, input2_sym) - rhs = SymbolicUtils.Term{S}(op_sym, [input1_sym, input2_sym]) + S = promote_symtype(op_sym, input_syms...) + rhs = SymbolicUtils.Term{S}(op_sym, input_syms) SymbolicEquation{Symbolic}(output_sym, rhs) end -#XXX: Always converting + -> .+ here since summation doesn't store the style of addition -function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, ::Val{:Σ}) - syms_array = [symvar_lookup[var] for var in d[d[incident(d, op_index, :summation), :summand], :name]] - output_sym = symvar_lookup[d[d[op_index, :sum], :name]] - - S = promote_symtype(+, syms_array...) - rhs = SymbolicUtils.Term{S}(+, syms_array) - SymbolicEquation{Symbolic}(output_sym,rhs) -end - function symbolic_rewriting(old_d::SummationDecapode, rewriter = nothing) - d = deepcopy(old_d) - - infer_types!(d) + d = infer_types!(deepcopy(old_d)) symvar_lookup = symbolics_lookup(d) eqns = merge_equations(d, symvar_lookup, extract_symexprs(d, symvar_lookup)) - if !isnothing(rewriter) - eqns = map(rewriter, eqns) - end - - to_acset(d, eqns) + to_acset(d, isnothing(rewriter) ? eqns : map(rewriter, eqns)) end function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}) - sym_list = SymbolicEquation{Symbolic}[] - for node in topological_sort_edges(d) - retrieve_name(d, node) != DerivOp || continue # This is not part of ThDEC - push!(sym_list, to_symbolics(d, symvar_lookup, node)) + non_tangents = filter(x -> retrieve_name(d, x) != DerivOp, topological_sort_edges(d)) + map(non_tangents) do node + to_symbolics(d, symvar_lookup, node.index, node.name) end - sym_list end function merge_equations(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, symexpr_list::Vector{SymbolicEquation{Symbolic}}) - - eqn_lookup = Dict() - final_list = [] - for node in start_nodes(d) + eqn_lookup = Dict{Any,Any}(map(start_nodes(d)) do node sym = symvar_lookup[d[node, :name]] - push!(eqn_lookup, (sym => sym)) - end - - final_nodes = infer_terminal_names(d) + (sym, sym) + end) for expr in symexpr_list @@ -112,7 +71,7 @@ function merge_equations(d::SummationDecapode, symvar_lookup::Dict{Symbol, Basic push!(eqn_lookup, (expr.lhs => merged_rhs)) - if expr.lhs.name in final_nodes + if expr.lhs.name in infer_terminal_names(d) push!(final_list, formed_deca_eqn(expr.lhs, merged_rhs)) end end @@ -123,22 +82,13 @@ end formed_deca_eqn(lhs, rhs) = SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [lhs, rhs]) function apply_rewrites(symexprs, rewriter) - - rewritten_list = [] - for sym in symexprs + map(symexprs) do sym res_sym = rewriter(sym) - rewritten_sym = isnothing(res_sym) ? sym : res_sym - push!(rewritten_list, rewritten_sym) + isnothing(res_sym) ? sym : res_sym end - - rewritten_list end -""" -og_d = original reference decapode which provides type information, state and terminal information -""" -function to_acset(og_d, sym_exprs) - +function to_acset(d::SummationDecapode, sym_exprs) #TODO: This step is breaking up summations final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs) @@ -151,28 +101,18 @@ function to_acset(og_d, sym_exprs) end sym => nothing end - map(recursive_descent, final_exprs) - - deca_block = quote end - - states = infer_states(og_d) - terminals = infer_terminals(og_d) + foreach(recursive_descent, final_exprs) - deca_type_gen = idx -> :($(og_d[idx, :name])::$(og_d[idx, :type])) - - append!(deca_block.args, map(deca_type_gen, vcat(states, terminals))) - - for op1 in parts(og_d, :Op1) - if og_d[op1, :op1] == DerivOp - push!(deca_block.args, :($(og_d[og_d[op1, :tgt], :name]) == $DerivOp($(og_d[og_d[op1, :src], :name])))) - end + states_terminals = map([infer_states(d)..., infer_terminals(d)...]) do idx + :($(d[idx, :name])::$(d[idx, :type])) end - append!(deca_block.args, final_exprs) - - d = SummationDecapode(parse_decapode(deca_block)) - - infer_types!(d) + tangents = map(incident(d, DerivOp, :op1)) do op1 + :($(d[d[op1, :tgt], :name]) == $DerivOp($(d[d[op1, :src], :name]))) + end - d + deca_block = quote end + deca_block.args = [states_terminals..., tangents..., final_exprs...] + infer_types!(SummationDecapode(parse_decapode(deca_block))) end + diff --git a/src/graph_traversal.jl b/src/graph_traversal.jl index f2875b2..cc47435 100644 --- a/src/graph_traversal.jl +++ b/src/graph_traversal.jl @@ -1,13 +1,34 @@ using DiagrammaticEquations using ACSets -export TraversalNode, topological_sort_edges, n_ops, retrieve_name, start_nodes +export TraversalNode, topological_sort_edges, n_ops, retrieve_name, start_nodes, edge_inputs, edge_outputs, edge_function struct TraversalNode{T} index::Int name::T end +edge_inputs(d::SummationDecapode, idx::Int, ::Val{:Op1}) = + [d[idx,:src]] +edge_inputs(d::SummationDecapode, idx::Int, ::Val{:Op2}) = + [d[idx,:proj1], d[idx,:proj2]] +edge_inputs(d::SummationDecapode, idx::Int, ::Val{:Σ}) = + d[incident(d, idx, :summation), :summand] + +edge_output(d::SummationDecapode, idx::Int, ::Val{:Op1}) = + d[idx,:tgt] +edge_output(d::SummationDecapode, idx::Int, ::Val{:Op2}) = + d[idx,:res] +edge_output(d::SummationDecapode, idx::Int, ::Val{:Σ}) = + d[idx, :sum] + +edge_function(d::SummationDecapode, idx::Int, ::Val{:Op1}) = + d[idx,:op1] +edge_function(d::SummationDecapode, idx::Int, ::Val{:Op2}) = + d[idx,:op2] +edge_function(d::SummationDecapode, idx::Int, ::Val{:Σ}) = + :+ + function topological_sort_edges(d::SummationDecapode) visited_Var = falses(nparts(d, :Var)) visited_Var[start_nodes(d)] .= true From d408c26d3747c354545e90b5d53f184603aff1bd Mon Sep 17 00:00:00 2001 From: Luke Morris Date: Thu, 26 Sep 2024 23:24:27 -0400 Subject: [PATCH 2/4] Convert to symbolics inside merge_equations --- src/acset2symbolic.jl | 77 ++++++++++++++++++------------------------ src/graph_traversal.jl | 2 +- 2 files changed, 34 insertions(+), 45 deletions(-) diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index 8df2d91..94afff2 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -1,11 +1,9 @@ using DiagrammaticEquations using ACSets +using MLStyle using SymbolicUtils using SymbolicUtils.Rewriters -using SymbolicUtils.Code -using MLStyle - -import SymbolicUtils: BasicSymbolic, Symbolic +using SymbolicUtils: BasicSymbolic, Symbolic export symbolics_lookup, extract_symexprs, apply_rewrites, merge_equations, to_acset, symbolic_rewriting, symbolics_lookup @@ -17,16 +15,16 @@ function symbolics_lookup(d::SummationDecapode) end) end -function decavar_to_symbolics(d::SummationDecapode, index::Int; space = :I) - new_type = SymbolicUtils.symtype(Deca.DECQuantity, d[index, :type], space) - SymbolicUtils.Sym{new_type}(d[index, :name]) +function decavar_to_symbolics(d::SummationDecapode, idx::Int; space = :I) + new_type = SymbolicUtils.symtype(Deca.DECQuantity, d[idx, :type], space) + SymbolicUtils.Sym{new_type}(d[idx, :name]) end -function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_index::Int, op_type::Symbol) - input_syms = getindex.(Ref(symvar_lookup), d[edge_inputs(d,op_index,Val(op_type)), :name]) - output_sym = getindex.(Ref(symvar_lookup), d[edge_output(d,op_index,Val(op_type)), :name]) +function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_idx::Int, op_type::Symbol) + input_syms = getindex.(Ref(symvar_lookup), d[edge_inputs(d,op_idx,Val(op_type)), :name]) + output_sym = getindex.(Ref(symvar_lookup), d[edge_output(d,op_idx,Val(op_type)), :name]) - op_sym = getfield(@__MODULE__, edge_function(d,op_index,Val(op_type))) + op_sym = getfield(@__MODULE__, edge_function(d,op_idx,Val(op_type))) S = promote_symtype(op_sym, input_syms...) rhs = SymbolicUtils.Term{S}(op_sym, input_syms) @@ -35,10 +33,7 @@ end function symbolic_rewriting(old_d::SummationDecapode, rewriter = nothing) d = infer_types!(deepcopy(old_d)) - - symvar_lookup = symbolics_lookup(d) - eqns = merge_equations(d, symvar_lookup, extract_symexprs(d, symvar_lookup)) - + eqns = merge_equations(d) to_acset(d, isnothing(rewriter) ? eqns : map(rewriter, eqns)) end @@ -49,34 +44,29 @@ function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, Basi end end -function merge_equations(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, symexpr_list::Vector{SymbolicEquation{Symbolic}}) - final_list = [] +# XXX SymbolicUtils.substitute swaps the order of multiplication. +# example: @decapode begin +# u::Form0 +# G::Form0 +# κ::Constant +# ∂ₜ(G) == κ*★(d(★(d(u)))) +# end +# will have the kappa*var term rewritten to var*kappa +function merge_equations(d::SummationDecapode) + symvar_lookup = symbolics_lookup(d) + symexpr_list = extract_symexprs(d, symvar_lookup) eqn_lookup = Dict{Any,Any}(map(start_nodes(d)) do node sym = symvar_lookup[d[node, :name]] (sym, sym) end) - - for expr in symexpr_list - - # XXX SymbolicUtils.substitute swaps the order of multiplication. - # example: @decapode begin - # u::Form0 - # G::Form0 - # κ::Constant - # ∂ₜ(G) == κ*★(d(★(d(u)))) - # end - # will have the kappa*var term rewritten to var*kappa + foreach(symexpr_list) do expr merged_rhs = SymbolicUtils.substitute(expr.rhs, eqn_lookup) - push!(eqn_lookup, (expr.lhs => merged_rhs)) - - if expr.lhs.name in infer_terminal_names(d) - push!(final_list, formed_deca_eqn(expr.lhs, merged_rhs)) - end end - final_list + terminals = filter(x -> x.lhs.name in infer_terminal_names(d), symexpr_list) + map(x -> formed_deca_eqn(x.lhs, eqn_lookup[x.lhs]), terminals) end formed_deca_eqn(lhs, rhs) = SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [lhs, rhs]) @@ -89,9 +79,16 @@ function apply_rewrites(symexprs, rewriter) end function to_acset(d::SummationDecapode, sym_exprs) + outer_types = map([infer_states(d)..., infer_terminals(d)...]) do idx + :($(d[idx, :name])::$(d[idx, :type])) + end + + tangents = map(incident(d, DerivOp, :op1)) do op1 + :($(d[d[op1, :tgt], :name]) == $DerivOp($(d[d[op1, :src], :name]))) + end + #TODO: This step is breaking up summations final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs) - recursive_descent = @λ begin e::Expr => begin if e.head == :call @@ -103,16 +100,8 @@ function to_acset(d::SummationDecapode, sym_exprs) end foreach(recursive_descent, final_exprs) - states_terminals = map([infer_states(d)..., infer_terminals(d)...]) do idx - :($(d[idx, :name])::$(d[idx, :type])) - end - - tangents = map(incident(d, DerivOp, :op1)) do op1 - :($(d[d[op1, :tgt], :name]) == $DerivOp($(d[d[op1, :src], :name]))) - end - deca_block = quote end - deca_block.args = [states_terminals..., tangents..., final_exprs...] + deca_block.args = [outer_types..., tangents..., final_exprs...] infer_types!(SummationDecapode(parse_decapode(deca_block))) end diff --git a/src/graph_traversal.jl b/src/graph_traversal.jl index cc47435..6266240 100644 --- a/src/graph_traversal.jl +++ b/src/graph_traversal.jl @@ -1,7 +1,7 @@ using DiagrammaticEquations using ACSets -export TraversalNode, topological_sort_edges, n_ops, retrieve_name, start_nodes, edge_inputs, edge_outputs, edge_function +export TraversalNode, topological_sort_edges, n_ops, retrieve_name, start_nodes, edge_inputs, edge_output, edge_function struct TraversalNode{T} index::Int From 3cd624e4c917a0cf01e4841bdbd9a5ee87388302 Mon Sep 17 00:00:00 2001 From: Luke Morris Date: Thu, 26 Sep 2024 23:54:52 -0400 Subject: [PATCH 3/4] Reduce cases of topological sort --- src/graph_traversal.jl | 56 ++++++++++++++---------------------------- 1 file changed, 18 insertions(+), 38 deletions(-) diff --git a/src/graph_traversal.jl b/src/graph_traversal.jl index 6266240..e41048e 100644 --- a/src/graph_traversal.jl +++ b/src/graph_traversal.jl @@ -29,59 +29,38 @@ edge_function(d::SummationDecapode, idx::Int, ::Val{:Op2}) = edge_function(d::SummationDecapode, idx::Int, ::Val{:Σ}) = :+ +#XXX: This topological sort is O(n^2). function topological_sort_edges(d::SummationDecapode) visited_Var = falses(nparts(d, :Var)) visited_Var[start_nodes(d)] .= true + visited = Dict(:Op1 => falses(nparts(d, :Op1)), + :Op2 => falses(nparts(d, :Op2)), :Σ => falses(nparts(d, :Σ))) - # TODO: Collect these visited arrays into one structure indexed by :Op1, :Op2, and :Σ - visited_1 = falses(nparts(d, :Op1)) - visited_2 = falses(nparts(d, :Op2)) - visited_Σ = falses(nparts(d, :Σ)) - - # FIXME: this is a quadratic implementation of topological_sort inlined in here. op_order = TraversalNode{Symbol}[] - for _ in 1:n_ops(d) - for op in parts(d, :Op1) - if !visited_1[op] && visited_Var[d[op, :src]] - - visited_1[op] = true - visited_Var[d[op, :tgt]] = true - - push!(op_order, TraversalNode(op, :Op1)) - end - end - - for op in parts(d, :Op2) - if !visited_2[op] && visited_Var[d[op, :proj1]] && visited_Var[d[op, :proj2]] - visited_2[op] = true - visited_Var[d[op, :res]] = true - push!(op_order, TraversalNode(op, :Op2)) - end + function visit(op, op_type) + if !visited[op_type][op] && all(visited_Var[edge_inputs(d,op,Val(op_type))]) + visited[op_type][op] = true + visited_Var[edge_output(d,op,Val(op_type))] = true + push!(op_order, TraversalNode(op, op_type)) end + end - for op in parts(d, :Σ) - args = subpart(d, incident(d, op, :summation), :summand) - if !visited_Σ[op] && all(visited_Var[args]) - visited_Σ[op] = true - visited_Var[d[op, :sum]] = true - push!(op_order, TraversalNode(op, :Σ)) - end - end + for _ in 1:n_ops(d) + visit.(parts(d,:Op1), :Op1) + visit.(parts(d,:Op2), :Op2) + visit.(parts(d,:Σ), :Σ) end @assert length(op_order) == n_ops(d) - op_order end -function n_ops(d::SummationDecapode) - return nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, :Σ) -end +n_ops(d::SummationDecapode) = + nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, :Σ) -function start_nodes(d::SummationDecapode) - return vcat(infer_states(d), incident(d, :Literal, :type)) -end +start_nodes(d::SummationDecapode) = + vcat(infer_states(d), incident(d, :Literal, :type)) function retrieve_name(d::SummationDecapode, tsr::TraversalNode) @match tsr.name begin @@ -91,3 +70,4 @@ function retrieve_name(d::SummationDecapode, tsr::TraversalNode) _ => error("$(tsr.name) is a table without names") end end + From 67079cbfd4c5af0b189d74e20cb82e35864a084a Mon Sep 17 00:00:00 2001 From: Luke Morris Date: Fri, 27 Sep 2024 11:32:49 -0400 Subject: [PATCH 4/4] Reify via recursive function, not lambda case --- src/acset2symbolic.jl | 49 +++++++++++++++++-------------------------- 1 file changed, 19 insertions(+), 30 deletions(-) diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index 94afff2..cfef6fa 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -1,11 +1,9 @@ using DiagrammaticEquations using ACSets -using MLStyle using SymbolicUtils -using SymbolicUtils.Rewriters using SymbolicUtils: BasicSymbolic, Symbolic -export symbolics_lookup, extract_symexprs, apply_rewrites, merge_equations, to_acset, symbolic_rewriting, symbolics_lookup +export symbolics_lookup, extract_symexprs, apply_rewrites, merge_equations, to_acset, symbolic_rewriting const DECA_EQUALITY_SYMBOL = (==) @@ -23,12 +21,10 @@ end function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_idx::Int, op_type::Symbol) input_syms = getindex.(Ref(symvar_lookup), d[edge_inputs(d,op_idx,Val(op_type)), :name]) output_sym = getindex.(Ref(symvar_lookup), d[edge_output(d,op_idx,Val(op_type)), :name]) - op_sym = getfield(@__MODULE__, edge_function(d,op_idx,Val(op_type))) S = promote_symtype(op_sym, input_syms...) - rhs = SymbolicUtils.Term{S}(op_sym, input_syms) - SymbolicEquation{Symbolic}(output_sym, rhs) + SymbolicEquation{Symbolic}(output_sym, SymbolicUtils.Term{S}(op_sym, input_syms)) end function symbolic_rewriting(old_d::SummationDecapode, rewriter = nothing) @@ -38,34 +34,30 @@ function symbolic_rewriting(old_d::SummationDecapode, rewriter = nothing) end function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}) - non_tangents = filter(x -> retrieve_name(d, x) != DerivOp, topological_sort_edges(d)) + non_tangents = Iterators.filter(x -> retrieve_name(d, x) != DerivOp, topological_sort_edges(d)) map(non_tangents) do node to_symbolics(d, symvar_lookup, node.index, node.name) end end # XXX SymbolicUtils.substitute swaps the order of multiplication. -# example: @decapode begin -# u::Form0 -# G::Form0 -# κ::Constant -# ∂ₜ(G) == κ*★(d(★(d(u)))) +# e.g. @decapode begin +# ∂ₜ(G) == κ*u # end -# will have the kappa*var term rewritten to var*kappa +# will have the κ*u term rewritten to u*κ function merge_equations(d::SummationDecapode) symvar_lookup = symbolics_lookup(d) symexpr_list = extract_symexprs(d, symvar_lookup) - eqn_lookup = Dict{Any,Any}(map(start_nodes(d)) do node - sym = symvar_lookup[d[node, :name]] + eqn_lookup = Dict{Any,Any}(map(start_nodes(d)) do i + sym = symvar_lookup[d[i, :name]] (sym, sym) end) - foreach(symexpr_list) do expr - merged_rhs = SymbolicUtils.substitute(expr.rhs, eqn_lookup) - push!(eqn_lookup, (expr.lhs => merged_rhs)) + foreach(symexpr_list) do x + push!(eqn_lookup, (x.lhs => SymbolicUtils.substitute(x.rhs, eqn_lookup))) end - terminals = filter(x -> x.lhs.name in infer_terminal_names(d), symexpr_list) + terminals = Iterators.filter(x -> x.lhs.name in infer_terminal_names(d), symexpr_list) map(x -> formed_deca_eqn(x.lhs, eqn_lookup[x.lhs]), terminals) end @@ -79,8 +71,8 @@ function apply_rewrites(symexprs, rewriter) end function to_acset(d::SummationDecapode, sym_exprs) - outer_types = map([infer_states(d)..., infer_terminals(d)...]) do idx - :($(d[idx, :name])::$(d[idx, :type])) + outer_types = map([infer_states(d)..., infer_terminals(d)...]) do i + :($(d[i, :name])::$(d[i, :type])) end tangents = map(incident(d, DerivOp, :op1)) do op1 @@ -89,19 +81,16 @@ function to_acset(d::SummationDecapode, sym_exprs) #TODO: This step is breaking up summations final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs) - recursive_descent = @λ begin - e::Expr => begin - if e.head == :call - e.args[1] = nameof(e.args[1]) - map(recursive_descent, e.args[2:end]) - end + reify!(exprs) = foreach(exprs) do x + if typeof(x)==Expr && x.head == :call + x.args[1] = nameof(x.args[1]) + reify!(x.args[2:end]) end - sym => nothing end - foreach(recursive_descent, final_exprs) + reify!(final_exprs) deca_block = quote end deca_block.args = [outer_types..., tangents..., final_exprs...] - infer_types!(SummationDecapode(parse_decapode(deca_block))) + ∘(infer_types!, SummationDecapode, parse_decapode)(deca_block) end