Skip to content

Commit

Permalink
More tests for canon overloading
Browse files Browse the repository at this point in the history
Problem right now is that operators that get overloaded resolve into our unicode operators. However, already typed operators that aren't in the unicode stay untouched. This is intended behavior but means downstream packages need to use canon names to work with all supported names.
  • Loading branch information
GeorgeR227 committed Jun 26, 2024
1 parent 510619a commit 8a845df
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 9 deletions.
4 changes: 3 additions & 1 deletion src/deca/deca_acset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -431,11 +431,13 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [
# Rules for Δ.
(src_type = :Form0, tgt_type = :Form0, resolved_name = :Δ₀, op = NOFORM_LAPLACE),
(src_type = :Form1, tgt_type = :Form1, resolved_name = :Δ₁, op = NOFORM_LAPLACE),
(src_type = :Form1, tgt_type = :Form1, resolved_name = :Δ₂, op = NOFORM_LAPLACE)]
(src_type = :Form2, tgt_type = :Form2, resolved_name = :Δ₂, op = NOFORM_LAPLACE),
# (src_type = :Form0, tgt_type = :Form0, resolved_name = :Δ₀, op = :lapl),
# (src_type = :Form1, tgt_type = :Form1, resolved_name = :Δ₁, op = :lapl),
# (src_type = :Form1, tgt_type = :Form1, resolved_name = :Δ₂, op = :lapl)]

(src_type = :Form0, tgt_type = :Form1, resolved_name = :avg₀₁, op = NOFORM_AVG)]

# We merge 1D and 2D rules directly here since it seems op2 rules
# are metric-free. If this assumption is false, this needs to change.
op2_res_rules_2D = vcat(op2_res_rules_1D, [
Expand Down
46 changes: 38 additions & 8 deletions test/language.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1033,7 +1033,7 @@ end
vcat(bespoke_op2_res_rules, op2_res_rules_2D))

op1s_hx = HeatXfer[:op1]
op1s_expected_hx = [:d₀, :₁, :dual_d₁, :₀⁻¹, :avg, :R₀, :₀, :₀⁻¹, :neg, :∂ₜ, :avg, :₁, :neg, :avg, :Δ₁, :δ₁, :d₀, :₁, :d₀, :avg, :d₀, :neg, :avg, :∂ₜ, :₀, :₀⁻¹, :neg, :∂ₜ]
op1s_expected_hx = [:d₀, :₁, :dual_d₁, :₀⁻¹, :avg₀₁, :R₀, :₀, :₀⁻¹, :neg, :∂ₜ, :avg₀₁, :₁, :neg, :avg₀₁, :Δ₁, :δ₁, :d₀, :₁, :d₀, :avg₀₁, :d₀, :neg, :avg₀₁, :∂ₜ, :₀, :₀⁻¹, :neg, :∂ₜ]
@test op1s_hx == op1s_expected_hx # Correct but probably by chance, see above
op2s_hx = HeatXfer[:op2]
op2s_expected_hx = [:*, :/, :/, :L₀, :/, :L₁, :*, :/, :*, :i₁, :/, :*, :*, :L₀]
Expand All @@ -1043,11 +1043,20 @@ end
@testset "Op1 Canonicalization" begin
function check_canontyping(control::SummationDecapode, test::SummationDecapode)
infer_test = infer_types!(deepcopy(test))
@test control[:type] == infer_test[:type]
infer_control = infer_types!(deepcopy(control))
@test infer_control[:type] == infer_test[:type]
@test infer_test[:op1] == test[:op1]
@test infer_test[:op2] == test[:op2]
end

function check_canonoverload(control::SummationDecapode, test::SummationDecapode)
over_test = resolve_overloads!(infer_types!(deepcopy(test)))
over_control = resolve_overloads!(infer_types!(deepcopy(control)))
@test get_canon_name(over_test[:op1]) == get_canon_name(over_control[:op1])
end

setup_basecase(d::SummationDecapode) = resolve_overloads!(infer_types!(deepcopy(d)))

# Test exterior derivative and hodge
gen_d1 = @decapode begin
A::Form0
Expand All @@ -1056,10 +1065,12 @@ end
D == hdg(hdg(C))
E == d(d(hdg(D)))
end
infer_types!(gen_d1)
@test gen_d1[:type] ==
[:Form0, :Form1, :DualForm2, :Form0, :DualForm2, :Form2,
:DualForm0, :Form1, :DualForm1, :Form2, :DualForm1, :DualForm0]

let # Check base case explicitly
d = setup_basecase(gen_d1)
@test d[:type] == [:Form0, :Form1, :DualForm2, :Form0, :DualForm2, :Form2, :DualForm0, :Form1, :DualForm1, :Form2, :DualForm1, :DualForm0]
@test get_canon_name(d[:op1]) == [:hdg_0, :invhdg_0, :d_0, :hdg_1, :invhdg_1, :d_1, :hdg_2, :invhdg_2, :hdg_2, :duald_0, :duald_1]
end

let # Ascii and tagged
d = @decapode begin
Expand All @@ -1070,6 +1081,7 @@ end
E == duald_1(duald_0(hdg_2(D)))
end
check_canontyping(gen_d1, d)
check_canonoverload(gen_d1, d)
end

let # Unicode and tagged
Expand All @@ -1081,6 +1093,7 @@ end
E == d̃₁(d̃₀((D)))
end
check_canontyping(gen_d1, d)
check_canonoverload(gen_d1, d)
end

let # Unicode and not tagged
Expand All @@ -1092,6 +1105,7 @@ end
E == (((D)))
end
check_canontyping(gen_d1, d)
check_canonoverload(gen_d1, d)
end

let # Combination of names
Expand All @@ -1103,6 +1117,7 @@ end
E == d̃₁(duald_0(hdg_2(D)))
end
check_canontyping(gen_d1, d)
check_canonoverload(gen_d1, d)
end

# Test laplacian and codifferential
Expand All @@ -1113,12 +1128,19 @@ end
infer_types!(gen_d2)
@test gen_d2[:type] == [:Form0, :Form0, :Form0, :Form1, :Form2, :Form2, :Form1, :Form1]

let # Check base case explicitly
d = setup_basecase(gen_d2)
@test d[:type] == [:Form0, :Form0, :Form0, :Form1, :Form2, :Form2, :Form1, :Form1]
@test get_canon_name(d[:op1]) == [:lapl_0, :d_0, :lapl_1, :d_1, :lapl_2, :codif_2, :codif_1]
end

let # Ascii and tagged
d = @decapode begin
A::Form0
B == codif_1(codif_2(lapl_2(d_1(lapl_1(d_0(lapl_0(A)))))))
end
check_canontyping(gen_d2, d)
check_canonoverload(gen_d2, d)
end

let # Unicode and tagged
Expand All @@ -1127,6 +1149,7 @@ end
B == δ₁(δ₂(Δ₂(d₁(Δ₁(d₀(Δ₀(A)))))))
end
check_canontyping(gen_d2, d)
check_canonoverload(gen_d2, d)
end

let # Unicode and not tagged
Expand All @@ -1135,6 +1158,7 @@ end
B == δ(δ(Δ(d(Δ(d(Δ(A)))))))
end
check_canontyping(gen_d2, d)
check_canonoverload(gen_d2, d)
end

let # Combination of names
Expand All @@ -1143,6 +1167,7 @@ end
B == δ(codif(Δ₂(d₁(lapl(d_0(Δ(A)))))))
end
check_canontyping(gen_d2, d)
check_canonoverload(gen_d2, d)
end

# Test average, neg and mag
Expand All @@ -1152,8 +1177,12 @@ end
B == neg(mag(avg(neg(mag(A)))))
D == neg(mag(C))
end
infer_types!(gen_d3)
@test gen_d3[:type] == [:Form0, :Form2, :Form1, :Form2, :Form1, :Form1, :Form0, :Form0, :Form2]

let # Check base case explicitly
d = setup_basecase(gen_d3)
@test d[:type] == [:Form0, :Form2, :Form1, :Form2, :Form1, :Form1, :Form0, :Form0, :Form2]
@test get_canon_name(d[:op1]) == [:mag, :-, :avg_01, :mag, :-, :mag, :-]
end

let # Alternate names
d = @decapode begin
Expand All @@ -1163,6 +1192,7 @@ end
D == -(norm(C))
end
check_canontyping(gen_d3, d)
check_canonoverload(gen_d3, d)
end

let # Typing respects typed exterior derivative
Expand Down

0 comments on commit 8a845df

Please sign in to comment.