Skip to content

Commit

Permalink
Added DECQuantity types
Browse files Browse the repository at this point in the history
Also switched to using SymbolicsUtils' `substitute`. Still needs tests and code needs to be cleaned up.
  • Loading branch information
GeorgeR227 committed Sep 16, 2024
1 parent 69619ce commit 378d6a1
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 56 deletions.
97 changes: 42 additions & 55 deletions src/acset2symbolic.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using DiagrammaticEquations
using ACSets
using SymbolicUtils
using SymbolicUtils.Rewriters
using SymbolicUtils.Code
Expand All @@ -8,30 +9,37 @@ export extract_symexprs, apply_rewrites, merge_equations, to_acset

const DECA_EQUALITY_SYMBOL = (==)

to_symbolics(d::SummationDecapode, node::TraversalNode) = to_symbolics(d, node.index, Val(node.name))
to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, node::TraversalNode) = to_symbolics(d, symvar_lookup, node.index, Val(node.name))

Check warning on line 12 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L12

Added line #L12 was not covered by tests

function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Op1})
input_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :src], :name])
output_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :tgt], :name])
op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Number}}(d[op_index, :op1])
function symbolics_lookup(d::SummationDecapode)
lookup = Dict{Symbol, SymbolicUtils.BasicSymbolic}()
for i in parts(d, :Var)
push!(lookup, d[i, :name] => decavar_to_symbolics(d, i))
end
lookup

Check warning on line 19 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L14-L19

Added lines #L14 - L19 were not covered by tests
end

function decavar_to_symbolics(d::SummationDecapode, index::Int; space = :I)
var = d[index, :name]
new_type = symtype(Deca.DECQuantity, d[index, :type], space)
SymbolicUtils.Sym{new_type}(var)

Check warning on line 25 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L22-L25

Added lines #L22 - L25 were not covered by tests
end

# input_sym = setmetadata(input_sym, Sort, oldtype_to_new(d[d[op_index, :src], :type]))
# output_sym = setmetadata(output_sym, Sort, oldtype_to_new(d[d[op_index, :tgt], :type]))
function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, op_index::Int, ::Val{:Op1})
input_sym = symvar_lookup[d[d[op_index, :src], :name]]
output_sym = symvar_lookup[d[d[op_index, :tgt], :name]]
op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Number}}(d[op_index, :op1])

Check warning on line 31 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L28-L31

Added lines #L28 - L31 were not covered by tests

rhs = SymbolicUtils.Term{Number}(op_sym, [input_sym])
SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs])

Check warning on line 34 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L33-L34

Added lines #L33 - L34 were not covered by tests
end

function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Op2})
input1_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :proj1], :name])
input2_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :proj2], :name])
output_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :res], :name])
function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, op_index::Int, ::Val{:Op2})
input1_sym = symvar_lookup[d[d[op_index, :proj1], :name]]
input2_sym = symvar_lookup[d[d[op_index, :proj2], :name]]
output_sym = symvar_lookup[d[d[op_index, :res], :name]]
op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number, Number}, Number}}(d[op_index, :op2])

Check warning on line 41 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L37-L41

Added lines #L37 - L41 were not covered by tests

# input1_sym = setmetadata(input1_sym, Sort, oldtype_to_new(d[d[op_index, :proj1], :type]))
# input2_sym = setmetadata(input2_sym, Sort, oldtype_to_new(d[d[op_index, :proj2], :type]))
# output_sym = setmetadata(output_sym, Sort, oldtype_to_new(d[d[op_index, :res], :type]))

rhs = SymbolicUtils.Term{Number}(op_sym, [input1_sym, input2_sym])
SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs])

Check warning on line 44 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L43-L44

Added lines #L43 - L44 were not covered by tests
end
Expand All @@ -41,27 +49,23 @@ end
# Expr(EQUALITY_SYMBOL, c.output, Expr(:call, Expr(:., :+), c.inputs...))
# end

# function oldtype_to_new(old::Symbol, space::Space = Space(:I, 2))::Sort
# @match old begin
# :Form0 => PrimalForm(0, space)
# :Form1 => PrimalForm(1, space)
# :Form2 => PrimalForm(2, space)
function symbolic_rewriting(old_d::SummationDecapode)
d = deepcopy(old_d)

Check warning on line 53 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L52-L53

Added lines #L52 - L53 were not covered by tests

infer_types!(d)
resolve_overloads!(d)

Check warning on line 56 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L55-L56

Added lines #L55 - L56 were not covered by tests

# :DualForm0 => DualForm(0, space)
# :DualForm1 => DualForm(1, space)
# :DualForm2 => DualForm(2, space)
symvar_lookup = symbolics_lookup(d)

Check warning on line 58 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L58

Added line #L58 was not covered by tests

# :Constant => Scalar()
# :Parameter => Scalar()
# end
# end
symexprs = extract_symexprs(d, symvar_lookup)

Check warning on line 60 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L60

Added line #L60 was not covered by tests
end

function extract_symexprs(d::SummationDecapode)
function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic})
topo_list = topological_sort_edges(d)
sym_list = []
for node in topo_list
retrieve_name(d, node) != DerivOp || continue
push!(sym_list, to_symbolics(d, node))
push!(sym_list, to_symbolics(d, symvar_lookup, node))
end
sym_list

Check warning on line 70 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L63-L70

Added lines #L63 - L70 were not covered by tests
end
Expand All @@ -78,52 +82,35 @@ function apply_rewrites(d::SummationDecapode, rewriter)
rewritten_list

Check warning on line 82 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L82

Added line #L82 was not covered by tests
end

function merge_equations(d::SummationDecapode, rewritten_syms)
function merge_equations(d::SummationDecapode, symvar_lookup::Dict{Symbol, SymbolicUtils.BasicSymbolic}, rewritten_syms)

Check warning on line 85 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L85

Added line #L85 was not covered by tests

lookup = Dict()
eqn_lookup = Dict()

Check warning on line 87 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L87

Added line #L87 was not covered by tests

final_list = []

Check warning on line 89 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L89

Added line #L89 was not covered by tests

for node in start_nodes(d)
sym = SymbolicUtils.Sym{Number}(d[node, :name])
# sym = setmetadata(sym, Sort, oldtype_to_new(d[node, :type]))
push!(lookup, (sym => sym))
sym = symvar_lookup[d[node, :name]]
push!(eqn_lookup, (sym => sym))
end

Check warning on line 94 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L91-L94

Added lines #L91 - L94 were not covered by tests

final_nodes = infer_terminal_names(d)

Check warning on line 96 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L96

Added line #L96 was not covered by tests

for expr in rewritten_syms

Check warning on line 98 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L98

Added line #L98 was not covered by tests
lhs = expr_lhs(expr)
rhs = expr_rhs(expr)

recursive_descent(rhs, lookup)
merged_eqn = SymbolicUtils.substitute(expr, eqn_lookup)
lhs = merged_eqn.arguments[1]
rhs = merged_eqn.arguments[2]

Check warning on line 102 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L100-L102

Added lines #L100 - L102 were not covered by tests

push!(lookup, (lhs => rhs))
push!(eqn_lookup, (lhs => rhs))

Check warning on line 104 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L104

Added line #L104 was not covered by tests

if lhs.name in final_nodes
push!(final_list, expr)
push!(final_list, merged_eqn)

Check warning on line 107 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L106-L107

Added lines #L106 - L107 were not covered by tests
end
end

Check warning on line 109 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L109

Added line #L109 was not covered by tests

final_list

Check warning on line 111 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L111

Added line #L111 was not covered by tests
end

expr_lhs(expr) = expr.arguments[1]
expr_rhs(expr) = expr.arguments[2]

function recursive_descent(expr, lookup)
# @show expr
for (i, arg) in enumerate(expr.arguments)
# @show arg
if arg in keys(lookup)
expr.arguments[i] = lookup[arg]
else
recursive_descent(arg, lookup)
end
end
return expr
end

function to_acset(og_d, sym_exprs)
final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs)
map(x -> x.args[1] = :(==), final_exprs)

Check warning on line 116 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L114-L116

Added lines #L114 - L116 were not covered by tests
Expand Down
2 changes: 1 addition & 1 deletion src/graph_traversal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,6 @@ function retrieve_name(d::SummationDecapode, tsr::TraversalNode)
:Op1 => d[tsr.index, :op1]
:Op2 => d[tsr.index, :op2]
=> :+
_ => error("$(tsr.name) is not a valid table for names")
_ => error("$(tsr.name) is a table without names")

Check warning on line 70 in src/graph_traversal.jl

View check run for this annotation

Codecov / codecov/patch

src/graph_traversal.jl#L70

Added line #L70 was not covered by tests
end
end

0 comments on commit 378d6a1

Please sign in to comment.