From 378d6a1ce3399caf8d50b6e26cd08364a37be6be Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Mon, 16 Sep 2024 14:03:16 -0400 Subject: [PATCH] Added DECQuantity types Also switched to using SymbolicsUtils' `substitute`. Still needs tests and code needs to be cleaned up. --- src/acset2symbolic.jl | 97 ++++++++++++++++++------------------------ src/graph_traversal.jl | 2 +- 2 files changed, 43 insertions(+), 56 deletions(-) diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index f84b3f2..802d182 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -1,4 +1,5 @@ using DiagrammaticEquations +using ACSets using SymbolicUtils using SymbolicUtils.Rewriters using SymbolicUtils.Code @@ -8,30 +9,37 @@ export extract_symexprs, apply_rewrites, merge_equations, to_acset const DECA_EQUALITY_SYMBOL = (==) -to_symbolics(d::SummationDecapode, node::TraversalNode) = to_symbolics(d, node.index, Val(node.name)) +to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, node::TraversalNode) = to_symbolics(d, symvar_lookup, node.index, Val(node.name)) -function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Op1}) - input_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :src], :name]) - output_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :tgt], :name]) - op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Number}}(d[op_index, :op1]) +function symbolics_lookup(d::SummationDecapode) + lookup = Dict{Symbol, SymbolicUtils.BasicSymbolic}() + for i in parts(d, :Var) + push!(lookup, d[i, :name] => decavar_to_symbolics(d, i)) + end + lookup +end + +function decavar_to_symbolics(d::SummationDecapode, index::Int; space = :I) + var = d[index, :name] + new_type = symtype(Deca.DECQuantity, d[index, :type], space) + SymbolicUtils.Sym{new_type}(var) +end - # input_sym = setmetadata(input_sym, Sort, oldtype_to_new(d[d[op_index, :src], :type])) - # output_sym = setmetadata(output_sym, Sort, oldtype_to_new(d[d[op_index, :tgt], :type])) +function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, op_index::Int, ::Val{:Op1}) + input_sym = symvar_lookup[d[d[op_index, :src], :name]] + output_sym = symvar_lookup[d[d[op_index, :tgt], :name]] + op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Number}}(d[op_index, :op1]) rhs = SymbolicUtils.Term{Number}(op_sym, [input_sym]) SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) end -function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Op2}) - input1_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :proj1], :name]) - input2_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :proj2], :name]) - output_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :res], :name]) +function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, op_index::Int, ::Val{:Op2}) + input1_sym = symvar_lookup[d[d[op_index, :proj1], :name]] + input2_sym = symvar_lookup[d[d[op_index, :proj2], :name]] + output_sym = symvar_lookup[d[d[op_index, :res], :name]] op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number, Number}, Number}}(d[op_index, :op2]) - # input1_sym = setmetadata(input1_sym, Sort, oldtype_to_new(d[d[op_index, :proj1], :type])) - # input2_sym = setmetadata(input2_sym, Sort, oldtype_to_new(d[d[op_index, :proj2], :type])) - # output_sym = setmetadata(output_sym, Sort, oldtype_to_new(d[d[op_index, :res], :type])) - rhs = SymbolicUtils.Term{Number}(op_sym, [input1_sym, input2_sym]) SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs]) end @@ -41,27 +49,23 @@ end # Expr(EQUALITY_SYMBOL, c.output, Expr(:call, Expr(:., :+), c.inputs...)) # end -# function oldtype_to_new(old::Symbol, space::Space = Space(:I, 2))::Sort -# @match old begin -# :Form0 => PrimalForm(0, space) -# :Form1 => PrimalForm(1, space) -# :Form2 => PrimalForm(2, space) +function symbolic_rewriting(old_d::SummationDecapode) + d = deepcopy(old_d) + + infer_types!(d) + resolve_overloads!(d) -# :DualForm0 => DualForm(0, space) -# :DualForm1 => DualForm(1, space) -# :DualForm2 => DualForm(2, space) + symvar_lookup = symbolics_lookup(d) -# :Constant => Scalar() -# :Parameter => Scalar() -# end -# end + symexprs = extract_symexprs(d, symvar_lookup) +end -function extract_symexprs(d::SummationDecapode) +function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}) topo_list = topological_sort_edges(d) sym_list = [] for node in topo_list retrieve_name(d, node) != DerivOp || continue - push!(sym_list, to_symbolics(d, node)) + push!(sym_list, to_symbolics(d, symvar_lookup, node)) end sym_list end @@ -78,52 +82,35 @@ function apply_rewrites(d::SummationDecapode, rewriter) rewritten_list end -function merge_equations(d::SummationDecapode, rewritten_syms) +function merge_equations(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, rewritten_syms) - lookup = Dict() + eqn_lookup = Dict() final_list = [] for node in start_nodes(d) - sym = SymbolicUtils.Sym{Number}(d[node, :name]) - # sym = setmetadata(sym, Sort, oldtype_to_new(d[node, :type])) - push!(lookup, (sym => sym)) + sym = symvar_lookup[d[node, :name]] + push!(eqn_lookup, (sym => sym)) end final_nodes = infer_terminal_names(d) for expr in rewritten_syms - lhs = expr_lhs(expr) - rhs = expr_rhs(expr) - recursive_descent(rhs, lookup) + merged_eqn = SymbolicUtils.substitute(expr, eqn_lookup) + lhs = merged_eqn.arguments[1] + rhs = merged_eqn.arguments[2] - push!(lookup, (lhs => rhs)) + push!(eqn_lookup, (lhs => rhs)) if lhs.name in final_nodes - push!(final_list, expr) + push!(final_list, merged_eqn) end end final_list end -expr_lhs(expr) = expr.arguments[1] -expr_rhs(expr) = expr.arguments[2] - -function recursive_descent(expr, lookup) - # @show expr - for (i, arg) in enumerate(expr.arguments) - # @show arg - if arg in keys(lookup) - expr.arguments[i] = lookup[arg] - else - recursive_descent(arg, lookup) - end - end - return expr -end - function to_acset(og_d, sym_exprs) final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs) map(x -> x.args[1] = :(==), final_exprs) diff --git a/src/graph_traversal.jl b/src/graph_traversal.jl index 1fd78e2..43b8860 100644 --- a/src/graph_traversal.jl +++ b/src/graph_traversal.jl @@ -67,6 +67,6 @@ function retrieve_name(d::SummationDecapode, tsr::TraversalNode) :Op1 => d[tsr.index, :op1] :Op2 => d[tsr.index, :op2] :Σ => :+ - _ => error("$(tsr.name) is not a valid table for names") + _ => error("$(tsr.name) is a table without names") end end