From ed3d20fa3a179955c38ed8b548fa0ed9952b3b7b Mon Sep 17 00:00:00 2001 From: Luke Morris Date: Thu, 22 Aug 2024 14:35:57 -0400 Subject: [PATCH] Add type rules for vectorfields --- src/acset.jl | 40 +++++++++++------ src/deca/deca_acset.jl | 26 +++++++---- test/language.jl | 97 ++++++++++++++++++++++++++++++++++++------ 3 files changed, 129 insertions(+), 34 deletions(-) diff --git a/src/acset.jl b/src/acset.jl index 9665afd..5366459 100644 --- a/src/acset.jl +++ b/src/acset.jl @@ -158,13 +158,16 @@ end # A collection of DecaType getters # TODO: This should be replaced by using a type hierarchy const ALL_TYPES = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, - :Literal, :Parameter, :Constant, :infer] + :PVF, :DVF, + :Literal, :Parameter, :Constant, :infer] const FORM_TYPES = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2] const PRIMALFORM_TYPES = [:Form0, :Form1, :Form2] const DUALFORM_TYPES = [:DualForm0, :DualForm1, :DualForm2] -const NONFORM_TYPES = [:Constant, :Parameter, :Literal, :infer] +const VECTORFIELD_TYPES = [:PVF, :DVF] + +const NON_EC_TYPES = [:Constant, :Parameter, :Literal, :infer] const USER_TYPES = [:Constant, :Parameter] const NUMBER_TYPES = [:Literal] const INFER_TYPES = [:infer] @@ -427,12 +430,12 @@ function safe_modifytype!(d::SummationDecapode, var_idx::Int, new_type::Symbol) end """ - filterfor_forms(types::AbstractVector{Symbol}) + filterfor_ec_types(types::AbstractVector{Symbol}) -Return any form type symbols. +Return any form or vector-field type symbols. """ -function filterfor_forms(types::AbstractVector{Symbol}) - conditions = x -> !(x in NONFORM_TYPES) +function filterfor_ec_types(types::AbstractVector{Symbol}) + conditions = x -> !(x in NON_EC_TYPES) filter(conditions, types) end @@ -447,16 +450,16 @@ function infer_sum_types!(d::SummationDecapode, Σ_idx::Int) types = d[idxs, :type] all(t != :infer for t in types) && return applied # We need not infer - forms = unique(filterfor_forms(types)) + ec_types = unique(filterfor_ec_types(types)) - form = @match length(forms) begin + ec_type = @match length(ec_types) begin 0 => return applied # We can not infer - 1 => only(forms) - _ => error("Type mismatch in summation $Σ_idx, all the following forms appear: $forms") + 1 => only(ec_types) + _ => error("Type mismatch in summation $Σ_idx, all the following forms appear: $ec_types") end for idx in idxs - applied |= safe_modifytype!(d, idx, form) + applied |= safe_modifytype!(d, idx, ec_type) end return applied @@ -489,11 +492,24 @@ function apply_inference_rule_op2!(d::SummationDecapode, op2_id, rule) score_res = (rule.res_type == type_res) check_op = (d[op2_id, :op2] in rule.op_names) - if(check_op && (score_proj1 + score_proj2 + score_res == 2)) + if check_op && (score_proj1 + score_proj2 + score_res == 2) mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], rule.proj1_type) mod_proj2 = safe_modifytype!(d, d[op2_id, :proj2], rule.proj2_type) mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type) return mod_proj1 || mod_proj2 || mod_res + # Special logic for exponentiation: + elseif d[op2_id, :op2] == :^ && + (type_proj1 == :Form0 && (type_proj2 == :infer || type_res == :infer)) || + (type_res == :Form0 && (type_proj1 == :infer || type_proj2 == :infer)) + mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], :Form0) + mod_res = safe_modifytype!(d, d[op2_id, :res], :Form0) + return mod_proj1 || mod_res + elseif d[op2_id, :op2] == :^ && + (type_proj1 == :DualForm0 && (type_proj2 == :infer || type_res == :infer)) || + (type_res == :DualForm0 && (type_proj1 == :infer || type_proj2 == :infer)) + mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], :DualForm0) + mod_res = safe_modifytype!(d, d[op2_id, :res], :DualForm0) + return mod_proj1 || mod_res end return false diff --git a/src/deca/deca_acset.jl b/src/deca/deca_acset.jl index 55520e3..c5fa67a 100644 --- a/src/deca/deca_acset.jl +++ b/src/deca/deca_acset.jl @@ -32,10 +32,17 @@ op1_inf_rules_1D = [ # Rules for the averaging operator (src_type = :Form0, tgt_type = :Form1, op_names = [:avg₀₁, :avg_01]), + + # Rules for ♯. + (src_type = :Form1, tgt_type = :PVF, op_names = [:♯, :♯ᵖᵖ]), + (src_type = :DualForm1, tgt_type = :DVF, op_names = [:♯, :♯ᵈᵈ]), + + # Rules for ♭. + (src_type = :DVF, tgt_type = :Form1, op_names = [:♭, :♭ᵈᵖ]), # Rules for magnitude/ norm - (src_type = :Form0, tgt_type = :Form0, op_names = [:mag, :norm]), - (src_type = :Form1, tgt_type = :Form1, op_names = [:mag, :norm])] + (src_type = :PVF, tgt_type = :Form0, op_names = [:mag, :norm]), + (src_type = :DVF, tgt_type = :DualForm0, op_names = [:mag, :norm])] op2_inf_rules_1D = [ # Rules for ∧₀₀, ∧₁₀, ∧₀₁ @@ -133,13 +140,16 @@ op1_inf_rules_2D = [ (src_type = :DualForm1, tgt_type = :DualForm1, op_names = [:neg, :(-)]), (src_type = :DualForm2, tgt_type = :DualForm2, op_names = [:neg, :(-)]), + # Rules for ♯. + (src_type = :Form1, tgt_type = :PVF, op_names = [:♯, :♯ᵖᵖ]), + (src_type = :DualForm1, tgt_type = :DVF, op_names = [:♯, :♯ᵈᵈ]), + + # Rules for ♭. + (src_type = :DVF, tgt_type = :Form1, op_names = [:♭, :♭ᵈᵖ]), + # Rules for magnitude/ norm - (src_type = :Form0, tgt_type = :Form0, op_names = [:norm, :mag]), - (src_type = :Form1, tgt_type = :Form1, op_names = [:norm, :mag]), - (src_type = :Form2, tgt_type = :Form2, op_names = [:norm, :mag]), - (src_type = :DualForm0, tgt_type = :DualForm0, op_names = [:norm, :mag]), - (src_type = :DualForm1, tgt_type = :DualForm1, op_names = [:norm, :mag]), - (src_type = :DualForm2, tgt_type = :DualForm2, op_names = [:norm, :mag])] + (src_type = :PVF, tgt_type = :Form0, op_names = [:norm, :mag]), + (src_type = :DVF, tgt_type = :DualForm0, op_names = [:norm, :mag])] op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ # Rules for ∧₁₁, ∧₂₀, ∧₀₂ diff --git a/test/language.jl b/test/language.jl index 3e3ab4e..991c7dd 100644 --- a/test/language.jl +++ b/test/language.jl @@ -356,13 +356,14 @@ end @test issetequal([:V,:X,:k], infer_state_names(oscillator)) end -import DiagrammaticEquations: ALL_TYPES, FORM_TYPES, PRIMALFORM_TYPES, DUALFORM_TYPES, - NONFORM_TYPES, USER_TYPES, NUMBER_TYPES, INFER_TYPES, NONINFERABLE_TYPES +import DiagrammaticEquations: ALL_TYPES, FORM_TYPES, PRIMALFORM_TYPES, + DUALFORM_TYPES, VECTORFIELD_TYPES, NON_EC_TYPES, USER_TYPES, NUMBER_TYPES, + INFER_TYPES, NONINFERABLE_TYPES @testset "Type Retrival" begin type_groups = [ALL_TYPES, FORM_TYPES, PRIMALFORM_TYPES, DUALFORM_TYPES, - NONFORM_TYPES, USER_TYPES, NUMBER_TYPES, INFER_TYPES, NONINFERABLE_TYPES] + NON_EC_TYPES, USER_TYPES, NUMBER_TYPES, INFER_TYPES, NONINFERABLE_TYPES] # No repeated types @@ -374,12 +375,12 @@ import DiagrammaticEquations: ALL_TYPES, FORM_TYPES, PRIMALFORM_TYPES, DUALFORM_ no_overlaps(types_1, types_2) = isempty(intersect(types_1, types_2)) # Collections of these types should be the same - @test equal_types(ALL_TYPES, vcat(FORM_TYPES, NONFORM_TYPES)) - @test equal_types(FORM_TYPES, vcat(PRIMALFORM_TYPES, DUALFORM_TYPES)) - @test equal_types(NONINFERABLE_TYPES, vcat(USER_TYPES, NUMBER_TYPES)) + @test equal_types(ALL_TYPES, FORM_TYPES ∪ VECTORFIELD_TYPES ∪ NON_EC_TYPES) + @test equal_types(FORM_TYPES, PRIMALFORM_TYPES ∪ DUALFORM_TYPES) + @test equal_types(NONINFERABLE_TYPES, USER_TYPES ∪ NUMBER_TYPES) # Proper seperation of types - @test no_overlaps(FORM_TYPES, NONFORM_TYPES) + @test no_overlaps(FORM_TYPES ∪ VECTORFIELD_TYPES, NON_EC_TYPES) @test no_overlaps(PRIMALFORM_TYPES, DUALFORM_TYPES) @test no_overlaps(NONINFERABLE_TYPES, FORM_TYPES) @test INFER_TYPES == [:infer] @@ -400,9 +401,9 @@ end import DiagrammaticEquations: safe_modifytype @testset "Safe Type Modification" begin - all_types = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :Literal, :Constant, :Parameter, :infer] + all_types = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :Literal, :Constant, :Parameter, :PVF, :DVF, :infer] bad_sources = [:Literal, :Constant, :Parameter] - good_sources = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :infer] + good_sources = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :PVF, :DVF, :infer] for tgt in all_types for src in bad_sources @@ -425,13 +426,13 @@ import DiagrammaticEquations: safe_modifytype end end -import DiagrammaticEquations: filterfor_forms +import DiagrammaticEquations: filterfor_ec_types @testset "Form Type Retrieval" begin - all_types = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :Literal, :Constant, :Parameter, :infer] - @test filterfor_forms(all_types) == [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2] - @test isempty(filterfor_forms(Symbol[])) - @test isempty(filterfor_forms([:Literal, :Constant, :Parameter, :infer])) + all_types = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :Literal, :Constant, :Parameter, :PVF, :DVF, :infer] + @test filterfor_ec_types(all_types) == [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :PVF, :DVF] + @test isempty(filterfor_ec_types(Symbol[])) + @test isempty(filterfor_ec_types([:Literal, :Constant, :Parameter, :infer])) end @testset "Type Inference" begin @@ -831,6 +832,38 @@ end @test_throws "Type mismatch in summation" infer_types!(d) end + # Test #25: Infer between flattened and sharpened vector fields. + let + d = @decapode begin + A::Form1 + B::DualForm1 + C::PVF + D::DVF + + A == ♭(E) + B == ♭(F) + C == ♯(G) + D == ♯(H) + + I::Form1 + J::DualForm1 + K::PVF + L::DVF + + M == ♯(I) + N == ♯(J) + O == ♭(K) + P == ♭(L) + end + infer_types!(d) + + # TODO: Update this as more sharps and flats are released. + names_types_expected = Set([(:A, :Form1), (:B, :DualForm1), (:C, :PVF), (:D, :DVF), + (:E, :DVF), (:F, :infer), (:G, :Form1), (:H, :DualForm1), + (:I, :Form1), (:J, :DualForm1), (:K, :PVF), (:L, :DVF), + (:M, :PVF), (:N, :DVF), (:O, :infer), (:P, :Form1)]) + @test test_nametype_equality(d, names_types_expected) + end end @testset "Overloading Resolution" begin @@ -1048,6 +1081,42 @@ end op2s_hx = HeatXfer[:op2] op2s_expected_hx = [:*, :/, :/, :L₀, :/, :L₁, :*, :/, :*, :i₁, :/, :*, :*, :L₀] @test op2s_hx == op2s_expected_hx + + # Infer types and resolve overloads for the Halfar equation. + let + d = @decapode begin + h::Form0 + Γ::Form1 + n::Constant + + ∂ₜ(h) == ∘(⋆, d, ⋆)(Γ * d(h) ∧ (mag(♯(d(h)))^(n-1)) ∧ (h^(n+2))) + end + d = expand_operators(d) + infer_types!(d) + resolve_overloads!(d) + @test d == @acset SummationDecapode{Any, Any, Symbol} begin + Var = 19 + TVar = 1 + Op1 = 8 + Op2 = 6 + Σ = 1 + Summand = 2 + src = [1, 1, 1, 13, 12, 6, 18, 19] + tgt = [4, 9, 13, 12, 11, 18, 19, 4] + proj1 = [2, 3, 11, 8, 1, 7] + proj2 = [9, 15, 14, 10, 5, 16] + res = [8, 14, 10, 7, 16, 6] + incl = [4] + summand = [3, 17] + summation = [1, 1] + sum = [5] + op1 = [:∂ₜ, :d₀, :d₀, :♯, :mag, :⋆₁, :dual_d₁, :⋆₀⁻¹] + op2 = [:*, :-, :^, :∧₁₀, :^, :∧₁₀] + type = [:Form0, :Form1, :Constant, :Form0, :infer, :Form1, :Form1, :Form1, :Form1, :Form0, :Form0, :PVF, :Form1, :infer, :Literal, :Form0, :Literal, :DualForm1, :DualForm2] + name = [:h, :Γ, :n, :ḣ, :sum_1, Symbol("•2"), Symbol("•3"), Symbol("•4"), Symbol("•5"), Symbol("•6"), Symbol("•7"), Symbol("•8"), Symbol("•9"), Symbol("•10"), Symbol("1"), Symbol("•11"), Symbol("2"), Symbol("•_6_1"), Symbol("•_6_2")] + end + end + end @testset "Compilation Transformation" begin