diff --git a/src/acset.jl b/src/acset.jl index 607c308..cb633e7 100644 --- a/src/acset.jl +++ b/src/acset.jl @@ -155,13 +155,32 @@ function make_sum_mult_unique!(d::AbstractNamedDecapode) end end +# A collection of DecaType getters +# TODO: This should be replaced by using a type hierarchy +const ALL_TYPES = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, + :Literal, :Parameter, :Constant, :infer] + +const FORM_TYPES = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2] +const PRIMALFORM_TYPES = [:Form0, :Form1, :Form2] +const DUALFORM_TYPES = [:DualForm0, :DualForm1, :DualForm2] + +const NONFORM_TYPES = [:Constant, :Parameter, :Literal, :infer] +const USER_TYPES = [:Constant, :Parameter] +const NUMBER_TYPES = [:Literal] +const INFER_TYPES = [:infer] + +# Types that can not ever be inferred +const NONINFERABLE_TYPES = [:Constant, :Parameter, :Literal] + +function get_unsupportedtypes(types) + setdiff(types, ALL_TYPES) +end + # Note: This hard-bakes in Form0 through Form2, and higher Forms are not # allowed. function recognize_types(d::AbstractNamedDecapode) types = d[:type] - unrecognized_types = setdiff(d[:type], [:Form0, :Form1, :Form2, :DualForm0, - :DualForm1, :DualForm2, :Literal, :Parameter, - :Constant, :infer]) + unrecognized_types = get_unsupportedtypes(types) isempty(unrecognized_types) || error("Types $unrecognized_types are not recognized. CHECK: $types") end @@ -349,51 +368,53 @@ This function accepts an original type and a new type and determines if the orig can be safely overwritten by the new type. """ function safe_modifytype(org_type::Symbol, new_type::Symbol) - modify = (org_type == :infer && !(new_type == :Literal || new_type == :Constant || new_type == :Parameter)) + modify = (org_type in INFER_TYPES && !(new_type in NONINFERABLE_TYPES)) return (modify, modify ? new_type : org_type) end """ - safe_modifytype!(d::SummationDecapode, var_idx::Int, org_type::Symbol, new_type::Symbol) + safe_modifytype!(d::SummationDecapode, var_idx::Int, new_type::Symbol) This function calls `safe_modifytype` to safely modify a Decapode's variable type. """ -function safe_modifytype!(d::SummationDecapode, var_idx::Int, org_type::Symbol, new_type::Symbol) - modify, d[var_idx, :type] = safe_modifytype(org_type, new_type) +function safe_modifytype!(d::SummationDecapode, var_idx::Int, new_type::Symbol) + modify, d[var_idx, :type] = safe_modifytype(d[var_idx, :type], new_type) return modify end -# ! Warning: This is changing types to :Constant when they weren't originally. -# ! This should be refactored to only change types into Forms -function infer_summands_and_summations!(d::SummationDecapode) - # Note that we are not doing any type checking here! - # i.e. We are not checking for this: [Form0, Form1, Form0]. +""" + filterfor_forms(types::AbstractVector{Symbol}) + +Return any form type symbols. +""" +function filterfor_forms(types::AbstractVector{Symbol}) + conditions = x -> !(x in NONFORM_TYPES) + filter(conditions, types) +end + +function infer_sum_types!(d::SummationDecapode, Σ_idx::Int) + # Note that we are not doing any type checking here for users! + # i.e. We are not checking the underlying types of Constant or Parameter applied = false - for Σ_idx in parts(d, :Σ) - summands = d[:summand][incident(d, Σ_idx, :summation)] - sum = d[:sum][Σ_idx] - idxs = [summands; sum] - types = d[:type][idxs] - all(t != :infer for t in types) && continue # We need not infer - all(t == :infer for t in types) && continue # We can not infer - - known_types = types[findall(!=(:infer), types)] - if :Literal ∈ known_types - # If anything is a Literal, then anything not inferred is a Constant. - inferred_type = :Constant - elseif !isnothing(findfirst(!=(:Constant), known_types)) - # If anything is a Form, then any term in this sum is the same kind of Form. - # Note that we are not explicitly changing Constants to Forms here, - # although we should consider doing so. - inferred_type = known_types[findfirst(!=(:Constant), known_types)] - else - # All terms are now a mix of Constant or infer. Set them all to Constant. - inferred_type = :Constant - end - to_infer_idxs = filter(i -> d[:type][i] == :infer, idxs) - d[to_infer_idxs, :type] = inferred_type - applied = true + + summands = d[incident(d, Σ_idx, :summation), :summand] + sum = d[Σ_idx, :sum] + idxs = [summands; sum] + types = d[idxs, :type] + all(t != :infer for t in types) && return applied # We need not infer + + forms = unique(filterfor_forms(types)) + + form = @match length(forms) begin + 0 => return applied # We can not infer + 1 => only(forms) + _ => error("Type mismatch in summation $Σ_idx, all the following forms appear: $forms") end + + for idx in idxs + applied |= safe_modifytype!(d, idx, form) + end + return applied end @@ -406,8 +427,8 @@ function apply_inference_rule_op1!(d::SummationDecapode, op1_id, rule) check_op = (d[op1_id, :op1] in rule.op_names) if(check_op && (score_src + score_tgt == 1)) - mod_src = safe_modifytype!(d, d[op1_id, :src], type_src, rule.src_type) - mod_tgt = safe_modifytype!(d, d[op1_id, :tgt], type_tgt, rule.tgt_type) + mod_src = safe_modifytype!(d, d[op1_id, :src], rule.src_type) + mod_tgt = safe_modifytype!(d, d[op1_id, :tgt], rule.tgt_type) return mod_src || mod_tgt end @@ -425,9 +446,9 @@ function apply_inference_rule_op2!(d::SummationDecapode, op2_id, rule) check_op = (d[op2_id, :op2] in rule.op_names) if(check_op && (score_proj1 + score_proj2 + score_res == 2)) - mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], type_proj1, rule.proj1_type) - mod_proj2 = safe_modifytype!(d, d[op2_id, :proj2], type_proj2, rule.proj2_type) - mod_res = safe_modifytype!(d, d[op2_id, :res], type_res, rule.res_type) + mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], rule.proj1_type) + mod_proj2 = safe_modifytype!(d, d[op2_id, :proj2], rule.proj2_type) + mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type) return mod_proj1 || mod_proj2 || mod_res end @@ -452,7 +473,7 @@ function infer_types!(d::SummationDecapode, op1_rules::Vector{NamedTuple{(:src_t types_known_op1[incident(d, :infer, [:src, :type])] .= false types_known_op1[incident(d, :infer, [:tgt, :type])] .= false - types_known_op2 = zeros(Bool, nparts(d, :Op2)) + types_known_op2 = ones(Bool, nparts(d, :Op2)) types_known_op2[incident(d, :infer, [:proj1, :type])] .= false types_known_op2[incident(d, :infer, [:proj2, :type])] .= false types_known_op2[incident(d, :infer, [:res, :type])] .= false @@ -467,7 +488,7 @@ function infer_types!(d::SummationDecapode, op1_rules::Vector{NamedTuple{(:src_t this_applied = apply_inference_rule_op1!(d, op1_idx, rule) types_known_op1[op1_idx] = this_applied - applied = applied || this_applied + applied |= this_applied end end @@ -478,11 +499,14 @@ function infer_types!(d::SummationDecapode, op1_rules::Vector{NamedTuple{(:src_t this_applied = apply_inference_rule_op2!(d, op2_idx, rule) types_known_op2[op2_idx] = this_applied - applied = applied || this_applied + applied |= this_applied end end - applied = applied || infer_summands_and_summations!(d) + for Σ_idx in parts(d, :Σ) + applied |= infer_sum_types!(d, Σ_idx) + end + applied || break # Break if no rules were applied. end diff --git a/test/language.jl b/test/language.jl index a3352a1..a90b088 100644 --- a/test/language.jl +++ b/test/language.jl @@ -356,6 +356,47 @@ end @test issetequal([:V,:X,:k], infer_state_names(oscillator)) end +import DiagrammaticEquations: ALL_TYPES, FORM_TYPES, PRIMALFORM_TYPES, DUALFORM_TYPES, + NONFORM_TYPES, USER_TYPES, NUMBER_TYPES, INFER_TYPES, NONINFERABLE_TYPES + +@testset "Type Retrival" begin + + type_groups = [ALL_TYPES, FORM_TYPES, PRIMALFORM_TYPES, DUALFORM_TYPES, + NONFORM_TYPES, USER_TYPES, NUMBER_TYPES, INFER_TYPES, NONINFERABLE_TYPES] + + + # No repeated types + for type_group in type_groups + @test allunique(type_group) + end + + equal_types(types_1, types_2) = issetequal(Set(types_1), Set(types_2)) + no_overlaps(types_1, types_2) = isempty(intersect(types_1, types_2)) + + # Collections of these types should be the same + @test equal_types(ALL_TYPES, vcat(FORM_TYPES, NONFORM_TYPES)) + @test equal_types(FORM_TYPES, vcat(PRIMALFORM_TYPES, DUALFORM_TYPES)) + @test equal_types(NONINFERABLE_TYPES, vcat(USER_TYPES, NUMBER_TYPES)) + + # Proper seperation of types + @test no_overlaps(FORM_TYPES, NONFORM_TYPES) + @test no_overlaps(PRIMALFORM_TYPES, DUALFORM_TYPES) + @test no_overlaps(NONINFERABLE_TYPES, FORM_TYPES) + @test INFER_TYPES == [:infer] + + @test no_overlaps(FORM_TYPES, NUMBER_TYPES) + @test no_overlaps(FORM_TYPES, USER_TYPES) + @test no_overlaps(USER_TYPES, NUMBER_TYPES) +end + +import DiagrammaticEquations: get_unsupportedtypes +@testset "Type Validation" begin + @test isempty(get_unsupportedtypes(ALL_TYPES)) + @test [:A] == get_unsupportedtypes([:A]) + @test [:A] == get_unsupportedtypes([:Form1, :A]) + @test isempty(get_unsupportedtypes(Symbol[])) +end + import DiagrammaticEquations: safe_modifytype @testset "Safe Type Modification" begin @@ -384,6 +425,15 @@ import DiagrammaticEquations: safe_modifytype end end +import DiagrammaticEquations: filterfor_forms + +@testset "Form Type Retrieval" begin + all_types = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :Literal, :Constant, :Parameter, :infer] + @test filterfor_forms(all_types) == [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2] + @test isempty(filterfor_forms(Symbol[])) + @test isempty(filterfor_forms([:Literal, :Constant, :Parameter, :infer])) +end + @testset "Type Inference" begin # Warning, this testing depends on the fact that varname, form information is # unique within a decapode even though this is not enforced @@ -391,6 +441,10 @@ end Set(zip(d[:name], d[:type])) end + function test_nametype_equality(d::SummationDecapode, names_types_expected) + issetequal(get_name_type_pair(d), names_types_expected) + end + # The type of the tgt of ∂ₜ is inferred. Test1 = quote C::Form0{X} @@ -400,9 +454,8 @@ end infer_types!(t1) # We use set equality because we do not care about the order of the Var table. - names_types_1 = Set(zip(t1[:name], t1[:type])) names_types_expected_1 = Set([(:C, :Form0)]) - @test issetequal(names_types_1, names_types_expected_1) + @test test_nametype_equality(t1, names_types_expected_1) # The type of the src of ∂ₜ is inferred. Test2 = quote @@ -413,9 +466,8 @@ end t2[only(incident(t2, :C, :name)), :type] = :Form0 infer_types!(t2) - names_types_2 = Set(zip(t2[:name], t2[:type])) names_types_expected_2 = Set([(:C, :Form0)]) - @test issetequal(names_types_2, names_types_expected_2) + @test test_nametype_equality(t2, names_types_expected_2) # The type of the tgt of d is inferred. Test3 = quote @@ -432,9 +484,8 @@ end #t3_inferred = infer_types!(t3) infer_types!(t3) - names_types_3 = Set(zip(t3[:name], t3[:type])) names_types_expected_3 = Set([(:C, :Form0), (:D, :Form1), (:E, :Form2)]) - @test issetequal(names_types_3, names_types_expected_3) + @test test_nametype_equality(t3, names_types_expected_3) # The type of the src and tgt of d is inferred. Test4 = quote @@ -448,9 +499,8 @@ end #t4_inferred = infer_types!(t4) infer_types!(t4) - names_types_4 = Set(zip(t4[:name], t4[:type])) names_types_expected_4 = Set([(:C, :Form0), (:D, :Form1), (:E, :Form2)]) - @test issetequal(names_types_4, names_types_expected_4) + @test test_nametype_equality(t4, names_types_expected_4) # The type of the src of d is inferred. Test5 = quote @@ -461,9 +511,8 @@ end t5 = SummationDecapode(parse_decapode(Test5)) infer_types!(t5) - names_types_5 = Set(zip(t5[:name], t5[:type])) names_types_expected_5 = Set([(:C, :Form0), (:D, :Form1)]) - @test issetequal(names_types_5, names_types_expected_5) + @test test_nametype_equality(t5, names_types_expected_5) # The type of the src of d is inferred. Test6 = quote @@ -478,11 +527,10 @@ end t6 = SummationDecapode(parse_decapode(Test6)) infer_types!(t6) - names_types_6 = Set(zip(t6[:name], t6[:type])) names_types_expected_6 = Set([ (:A, :Form0), (:B, :Form1), (:C, :Form2), (:F, :DualForm2), (:E, :DualForm1), (:D, :DualForm0)]) - @test issetequal(names_types_6, names_types_expected_6) + @test test_nametype_equality(t6, names_types_expected_6) # The type of a summand is inferred. Test7 = quote @@ -521,12 +569,11 @@ end t8 = SummationDecapode(parse_decapode(Test8)) infer_types!(t8) - names_types_8 = Set(zip(t8[:name], t8[:type])) names_types_expected_8 = Set([ (:Γ, :Form0), (:A, :Form1), (:B, :Form1), (:C, :Form1), (:D, :Form1), (:E, :Form1), (:F, :Form1), (:Θ, :Form2)]) - @test issetequal(names_types_8, names_types_expected_8) + @test test_nametype_equality(t8, names_types_expected_8) function makeInferPathDeca(log_cycles; infer_path = false) cycles = 2 ^ log_cycles @@ -568,11 +615,10 @@ end t9 = expand_operators(t9) infer_types!(t9) - names_types_9 = Set(zip(t9[:name], t9[:type])) names_types_expected_9 = Set([ (:A, :Form0), (Symbol("•_1_", 1), :Form1), (Symbol("•_1_", 2), :Form2), (:B, :DualForm2), (Symbol("•_1_", 4), :DualForm1), (Symbol("•_1_", 3), :DualForm0)]) - @test issetequal(names_types_9, names_types_expected_9) + @test test_nametype_equality(t9, names_types_expected_9) # Basic op2 inference using ∧ Test10 = quote @@ -591,9 +637,8 @@ end t10 = SummationDecapode(parse_decapode(Test10)) infer_types!(t10) - names_types_10 = get_name_type_pair(t10) names_types_expected_10 = Set([(:F, :Form1), (:B, :Form0), (:C, :Form0), (:H, :Form1), (:A, :Form0), (:E, :Form1), (:D, :Form0)]) - @test issetequal(names_types_10, names_types_expected_10) + @test test_nametype_equality(t10, names_types_expected_10) # Basic op2 inference using L Test11 = quote @@ -608,9 +653,9 @@ end t11 = SummationDecapode(parse_decapode(Test11)) infer_types!(t11) - names_types_11 = get_name_type_pair(t11) + names_types_expected_11 = Set([(:A, :Form1), (:B, :DualForm0), (:C, :DualForm0), (:E, :DualForm1), (:D, :DualForm1)]) - @test issetequal(names_types_11, names_types_expected_11) + @test test_nametype_equality(t11, names_types_expected_11) # Basic op2 inference using i Test12 = quote @@ -623,9 +668,8 @@ end t12 = SummationDecapode(parse_decapode(Test12)) infer_types!(t12) - names_types_12 = get_name_type_pair(t12) names_types_expected_12 = Set([(:A, :Form1), (:B, :DualForm1), (:C, :DualForm0)]) - @test issetequal(names_types_12, names_types_expected_12) + @test test_nametype_equality(t12, names_types_expected_12) #2D op2 inference using ∧ Test13 = quote @@ -644,9 +688,8 @@ end t13 = SummationDecapode(parse_decapode(Test13)) infer_types!(t13) - names_types_13 = get_name_type_pair(t13) names_types_expected_13 = Set([(:E, :Form2), (:B, :Form1), (:C, :Form2), (:A, :Form1), (:F, :Form2), (:H, :Form2), (:D, :Form0)]) - @test issetequal(names_types_13, names_types_expected_13) + @test test_nametype_equality(t13, names_types_expected_13) # 2D op2 inference using L Test14 = quote @@ -659,9 +702,8 @@ end t14 = SummationDecapode(parse_decapode(Test14)) infer_types!(t14) - names_types_14 = get_name_type_pair(t14) names_types_expected_14 = Set([(:C, :DualForm2), (:A, :Form1), (:B, :DualForm2)]) - @test issetequal(names_types_14, names_types_expected_14) + @test test_nametype_equality(t14, names_types_expected_14) # 2D op2 inference using i Test15 = quote @@ -674,9 +716,8 @@ end t15 = SummationDecapode(parse_decapode(Test15)) infer_types!(t15) - names_types_15 = get_name_type_pair(t15) names_types_expected_15 = Set([(:A, :Form1), (:B, :DualForm2), (:C, :DualForm1)]) - @test issetequal(names_types_15, names_types_expected_15) + @test test_nametype_equality(t15, names_types_expected_15) # Special case of a summation with a mix of infer and Literal. t16 = @decapode begin @@ -684,9 +725,8 @@ end end infer_types!(t16) - names_types_16 = get_name_type_pair(t16) - names_types_expected_16 = Set([(:A, :Constant), (:C, :Constant), (:D, :Constant), (Symbol("2"), :Literal)]) - @test issetequal(names_types_16, names_types_expected_16) + names_types_expected_16 = Set([(:A, :infer), (:C, :infer), (:D, :infer), (Symbol("2"), :Literal)]) + @test test_nametype_equality(t16, names_types_expected_16) # Special case of a summation with a mix of Form, Constant, and infer. t17 = @decapode begin @@ -696,9 +736,8 @@ end end infer_types!(t17) - names_types_17 = get_name_type_pair(t17) names_types_expected_17 = Set([(:A, :Form0), (:C, :Constant), (:D, :Form0)]) - @test issetequal(names_types_17, names_types_expected_17) + @test test_nametype_equality(t17, names_types_expected_17) # Special case of a summation with a mix of infer and Constant. t18 = @decapode begin @@ -707,11 +746,12 @@ end end infer_types!(t18) - names_types_18 = get_name_type_pair(t18) - names_types_expected_18 = Set([(:A, :Constant), (:C, :Constant), (:D, :Constant)]) - @test issetequal(names_types_18, names_types_expected_18) + names_types_expected_18 = Set([(:A, :infer), (:C, :Constant), (:D, :infer)]) + @test test_nametype_equality(t18, names_types_expected_18) - let d = @decapode begin + # Test #19: Prevent intermediates from converting to Constant + let + d = @decapode begin h::Form0 Γ::Form1 n::Constant @@ -721,19 +761,18 @@ end end infer_types!(d, op1_inf_rules_1D, op2_inf_rules_1D) - - # TODO: This is modifying an intermediate var to be :Constant, which is user-defined - @test_broken d[18, :type] != :Constant + @test d[18, :type] != :Constant end - let d = @decapode begin + # Test #20: Prevent intermediates from converting to Literal + let + d = @decapode begin h::Form0 Γ::Form1 n::Constant ḣ == ∂ₜ(h) ḣ == ∘(⋆, d, ⋆)(Γ * d(h) * avg₀₁(mag(♯(d(h)))^(n-1)) * avg₀₁(h^(n+2))) - end d = expand_operators(d) @@ -742,6 +781,56 @@ end @test d[8, :type] != :Literal end + # Test #21: Prevent intermediates from converting to Parameter + let + d = @decapode begin + C::Parameter + A == C + D + end + infer_types!(d) + + names_types_expected = Set([(:A, :infer), (:C, :Parameter), (:D, :infer)]) + @test test_nametype_equality(d, names_types_expected) + end + + # Test #22: Have source form information flow over Parameters and Constants + let + d = @decapode begin + B::Form0 + C::Parameter + D::Constant + A == C + D + B + end + infer_types!(d) + + names_types_expected = Set([(:A, :Form0), (:B, :Form0), (:C, :Parameter), (:D, :Constant)]) + @test test_nametype_equality(d, names_types_expected) + end + + # Test #23: Have target form information flow over Parameters and Constants + let + d = @decapode begin + A::Form0 + C::Parameter + D::Constant + A == C + D + B + end + infer_types!(d) + + names_types_expected = Set([(:A, :Form0), (:B, :Form0), (:C, :Parameter), (:D, :Constant)]) + @test test_nametype_equality(d, names_types_expected) + end + + # Test #24: Summing mismatched forms throws an error + let + d = @decapode begin + B::Form0 + C::Form1 + A == C + D + B + end + @test_throws "Type mismatch in summation" infer_types!(d) + end + end @testset "Overloading Resolution" begin