Skip to content

Commit

Permalink
Remove metadata usage
Browse files Browse the repository at this point in the history
This needs to switch to use the new type system
  • Loading branch information
GeorgeR227 committed Sep 13, 2024
1 parent 5f78860 commit 69619ce
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 60 deletions.
2 changes: 0 additions & 2 deletions src/DiagrammaticEquations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,6 @@ include("graph_traversal.jl")
include("deca/Deca.jl")
include("learn/Learn.jl")
include("SymbolicUtilsInterop.jl")
include("ThDEC.jl")
include("decasymbolic.jl")
include("acset2symbolic.jl")

@reexport using .Deca
Expand Down
45 changes: 19 additions & 26 deletions src/acset2symbolic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ using SymbolicUtils.Rewriters
using SymbolicUtils.Code
using MLStyle

import DiagrammaticEquations.ThDEC: Space

export extract_symexprs, apply_rewrites, merge_equations
export extract_symexprs, apply_rewrites, merge_equations, to_acset

const DECA_EQUALITY_SYMBOL = (==)

Expand All @@ -17,8 +15,8 @@ function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Op1})
output_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :tgt], :name])
op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Number}}(d[op_index, :op1])

Check warning on line 16 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L13-L16

Added lines #L13 - L16 were not covered by tests

input_sym = setmetadata(input_sym, Sort, oldtype_to_new(d[d[op_index, :src], :type]))
output_sym = setmetadata(output_sym, Sort, oldtype_to_new(d[d[op_index, :tgt], :type]))
# input_sym = setmetadata(input_sym, Sort, oldtype_to_new(d[d[op_index, :src], :type]))
# output_sym = setmetadata(output_sym, Sort, oldtype_to_new(d[d[op_index, :tgt], :type]))

rhs = SymbolicUtils.Term{Number}(op_sym, [input_sym])
SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs])

Check warning on line 22 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L21-L22

Added lines #L21 - L22 were not covered by tests
Expand All @@ -30,9 +28,9 @@ function to_symbolics(d::SummationDecapode, op_index::Int, ::Val{:Op2})
output_sym = SymbolicUtils.Sym{Number}(d[d[op_index, :res], :name])
op_sym = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number, Number}, Number}}(d[op_index, :op2])

Check warning on line 29 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L25-L29

Added lines #L25 - L29 were not covered by tests

input1_sym = setmetadata(input1_sym, Sort, oldtype_to_new(d[d[op_index, :proj1], :type]))
input2_sym = setmetadata(input2_sym, Sort, oldtype_to_new(d[d[op_index, :proj2], :type]))
output_sym = setmetadata(output_sym, Sort, oldtype_to_new(d[d[op_index, :res], :type]))
# input1_sym = setmetadata(input1_sym, Sort, oldtype_to_new(d[d[op_index, :proj1], :type]))
# input2_sym = setmetadata(input2_sym, Sort, oldtype_to_new(d[d[op_index, :proj2], :type]))
# output_sym = setmetadata(output_sym, Sort, oldtype_to_new(d[d[op_index, :res], :type]))

rhs = SymbolicUtils.Term{Number}(op_sym, [input1_sym, input2_sym])
SymbolicUtils.Term{Number}(DECA_EQUALITY_SYMBOL, [output_sym, rhs])

Check warning on line 36 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L35-L36

Added lines #L35 - L36 were not covered by tests
Expand All @@ -43,20 +41,20 @@ end
# Expr(EQUALITY_SYMBOL, c.output, Expr(:call, Expr(:., :+), c.inputs...))
# end

function oldtype_to_new(old::Symbol, space::Space = Space(:I, 2))::Sort
@match old begin
:Form0 => PrimalForm(0, space)
:Form1 => PrimalForm(1, space)
:Form2 => PrimalForm(2, space)
# function oldtype_to_new(old::Symbol, space::Space = Space(:I, 2))::Sort
# @match old begin
# :Form0 => PrimalForm(0, space)
# :Form1 => PrimalForm(1, space)
# :Form2 => PrimalForm(2, space)

:DualForm0 => DualForm(0, space)
:DualForm1 => DualForm(1, space)
:DualForm2 => DualForm(2, space)
# :DualForm0 => DualForm(0, space)
# :DualForm1 => DualForm(1, space)
# :DualForm2 => DualForm(2, space)

:Constant => Scalar()
:Parameter => Scalar()
end
end
# :Constant => Scalar()
# :Parameter => Scalar()
# end
# end

function extract_symexprs(d::SummationDecapode)
topo_list = topological_sort_edges(d)
Expand Down Expand Up @@ -88,7 +86,7 @@ function merge_equations(d::SummationDecapode, rewritten_syms)

for node in start_nodes(d)
sym = SymbolicUtils.Sym{Number}(d[node, :name])

Check warning on line 88 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L87-L88

Added lines #L87 - L88 were not covered by tests
sym = setmetadata(sym, Sort, oldtype_to_new(d[node, :type]))
# sym = setmetadata(sym, Sort, oldtype_to_new(d[node, :type]))
push!(lookup, (sym => sym))
end

Check warning on line 91 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L90-L91

Added lines #L90 - L91 were not covered by tests

Expand Down Expand Up @@ -148,8 +146,3 @@ function to_acset(og_d, sym_exprs)

d

Check warning on line 147 in src/acset2symbolic.jl

View check run for this annotation

Codecov / codecov/patch

src/acset2symbolic.jl#L147

Added line #L147 was not covered by tests
end

# TODO: We need a way to get information like the d and ⋆ even when not in the ACSet
# @syms Δ(x) d(x) ⋆(x)
# lap_0_rule = @rule Δ(~x) => ⋆(d(⋆(d(~x))))
# rewriter = Postwalk(RestartedChain([lap_0_rule]))
36 changes: 4 additions & 32 deletions src/sym_rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@ using DiagrammaticEquations
using SymbolicUtils
using MLStyle

test_space = Space(:X, 2)

test_type = Form(0, false, test_space)

Heat = @decapode begin
C::Form0
D::Constant
Expand All @@ -15,18 +11,6 @@ end
infer_types!(Heat)
resolve_overloads!(Heat)

function isform_zero(x)
getmetadata(x, Sort).dim == 0
end

function isform_one(x)
getmetadata(x, Sort).dim == 1
end

function isform_two(x)
getmetadata(x, Sort).dim == 2
end

@syms Δ(x) d(x) (x) Δ₀(x) Δ₁(x) Δ₂(x) d₀(x)

lap_0_convert = @rule Δ₀(~x) => Δ(~x)
Expand All @@ -37,9 +21,9 @@ d_0_convert = @rule d₀(~x) => d(~x)

overloaders = [lap_0_convert, lap_1_convert, lap_2_convert, d_0_convert]

lap_0_rule = @rule Δ(~x::(isform_zero)) => (d((d(~x))))
lap_1_rule = @rule Δ(~x::(isform_one)) => d((d((~x)))) + (d((d(~x))))
lap_2_rule = @rule Δ(~x::(isform_two)) => d((d((~x))))
lap_0_rule = @rule Δ(~x) => (d((d(~x))))
lap_1_rule = @rule Δ(~x) => d((d((~x)))) + (d((d(~x))))
lap_2_rule = @rule Δ(~x) => d((d((~x))))

openers = [lap_0_rule, lap_1_rule, lap_2_rule]

Expand All @@ -61,18 +45,6 @@ optm_rewriter = SymbolicUtils.Postwalk(

res_merge_exprs = map(optm_rewriter, merge_exprs)

final_exprs = SymbolicUtils.Code.toexpr.(res_merge_exprs)
map(x -> x.args[1] = :(==), final_exprs)

to_decapode = quote
C::Form0
D::Constant
Ċ::Form0
end

append!(to_decapode.args, final_exprs)

test = parse_decapode(to_decapode)
deca_test = SummationDecapode(test)
deca_test = to_acset(Heat, res_merge_exprs)
infer_types!(deca_test)
resolve_overloads!(deca_test)

0 comments on commit 69619ce

Please sign in to comment.