From a92b40e921ebef0fd8013311a293d7447e0fbed6 Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Fri, 21 Jun 2024 14:04:02 -0400 Subject: [PATCH 01/11] Clean up summation infer This will now only propagate form types. If a single form is present then this will be inferred. If there are no forms then we skip and if there are multiple forms we error. --- src/acset.jl | 61 ++++++++++++++--------------- test/language.jl | 100 +++++++++++++++++++++++++++-------------------- 2 files changed, 87 insertions(+), 74 deletions(-) diff --git a/src/acset.jl b/src/acset.jl index 607c308..3fc0c61 100644 --- a/src/acset.jl +++ b/src/acset.jl @@ -354,45 +354,42 @@ function safe_modifytype(org_type::Symbol, new_type::Symbol) end """ - safe_modifytype!(d::SummationDecapode, var_idx::Int, org_type::Symbol, new_type::Symbol) + safe_modifytype!(d::SummationDecapode, var_idx::Int, new_type::Symbol) This function calls `safe_modifytype` to safely modify a Decapode's variable type. """ -function safe_modifytype!(d::SummationDecapode, var_idx::Int, org_type::Symbol, new_type::Symbol) - modify, d[var_idx, :type] = safe_modifytype(org_type, new_type) +function safe_modifytype!(d::SummationDecapode, var_idx::Int, new_type::Symbol) + modify, d[var_idx, :type] = safe_modifytype(d[var_idx, :type], new_type) return modify end +function filterfor_forms(types::AbstractVector{Symbol}) + conditions = x -> x != :Literal && x != :Constant && x != :Parameter && x != :infer + filter(conditions, types) +end + # ! Warning: This is changing types to :Constant when they weren't originally. -# ! This should be refactored to only change types into Forms function infer_summands_and_summations!(d::SummationDecapode) - # Note that we are not doing any type checking here! - # i.e. We are not checking for this: [Form0, Form1, Form0]. + # Note that we are not doing any type checking here for users! + # i.e. We are not checking the underlying types of Constant or Parameter applied = false + for Σ_idx in parts(d, :Σ) - summands = d[:summand][incident(d, Σ_idx, :summation)] - sum = d[:sum][Σ_idx] + summands = d[incident(d, Σ_idx, :summation), :summand] + sum = d[Σ_idx, :sum] idxs = [summands; sum] - types = d[:type][idxs] - all(t != :infer for t in types) && continue # We need not infer - all(t == :infer for t in types) && continue # We can not infer - - known_types = types[findall(!=(:infer), types)] - if :Literal ∈ known_types - # If anything is a Literal, then anything not inferred is a Constant. - inferred_type = :Constant - elseif !isnothing(findfirst(!=(:Constant), known_types)) - # If anything is a Form, then any term in this sum is the same kind of Form. - # Note that we are not explicitly changing Constants to Forms here, - # although we should consider doing so. - inferred_type = known_types[findfirst(!=(:Constant), known_types)] - else - # All terms are now a mix of Constant or infer. Set them all to Constant. - inferred_type = :Constant + + forms = unique(filterfor_forms(d[idxs, :type])) + + form = @match length(forms) begin + 0 => continue # We need not infer, We can not infer + 1 => only(forms) + _ => error("Type mismatch in summation") + end + + for idx in idxs + applied |= safe_modifytype!(d, idx, form) end - to_infer_idxs = filter(i -> d[:type][i] == :infer, idxs) - d[to_infer_idxs, :type] = inferred_type - applied = true end return applied end @@ -406,8 +403,8 @@ function apply_inference_rule_op1!(d::SummationDecapode, op1_id, rule) check_op = (d[op1_id, :op1] in rule.op_names) if(check_op && (score_src + score_tgt == 1)) - mod_src = safe_modifytype!(d, d[op1_id, :src], type_src, rule.src_type) - mod_tgt = safe_modifytype!(d, d[op1_id, :tgt], type_tgt, rule.tgt_type) + mod_src = safe_modifytype!(d, d[op1_id, :src], rule.src_type) + mod_tgt = safe_modifytype!(d, d[op1_id, :tgt], rule.tgt_type) return mod_src || mod_tgt end @@ -425,9 +422,9 @@ function apply_inference_rule_op2!(d::SummationDecapode, op2_id, rule) check_op = (d[op2_id, :op2] in rule.op_names) if(check_op && (score_proj1 + score_proj2 + score_res == 2)) - mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], type_proj1, rule.proj1_type) - mod_proj2 = safe_modifytype!(d, d[op2_id, :proj2], type_proj2, rule.proj2_type) - mod_res = safe_modifytype!(d, d[op2_id, :res], type_res, rule.res_type) + mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], rule.proj1_type) + mod_proj2 = safe_modifytype!(d, d[op2_id, :proj2], rule.proj2_type) + mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type) return mod_proj1 || mod_proj2 || mod_res end diff --git a/test/language.jl b/test/language.jl index a3352a1..ea68460 100644 --- a/test/language.jl +++ b/test/language.jl @@ -391,6 +391,10 @@ end Set(zip(d[:name], d[:type])) end + function test_nametype_equality(d::SummationDecapode, names_types_expected) + issetequal(get_name_type_pair(d), names_types_expected) + end + # The type of the tgt of ∂ₜ is inferred. Test1 = quote C::Form0{X} @@ -400,9 +404,8 @@ end infer_types!(t1) # We use set equality because we do not care about the order of the Var table. - names_types_1 = Set(zip(t1[:name], t1[:type])) names_types_expected_1 = Set([(:C, :Form0)]) - @test issetequal(names_types_1, names_types_expected_1) + @test test_nametype_equality(t1, names_types_expected_1) # The type of the src of ∂ₜ is inferred. Test2 = quote @@ -413,9 +416,8 @@ end t2[only(incident(t2, :C, :name)), :type] = :Form0 infer_types!(t2) - names_types_2 = Set(zip(t2[:name], t2[:type])) names_types_expected_2 = Set([(:C, :Form0)]) - @test issetequal(names_types_2, names_types_expected_2) + @test test_nametype_equality(t2, names_types_expected_2) # The type of the tgt of d is inferred. Test3 = quote @@ -432,9 +434,8 @@ end #t3_inferred = infer_types!(t3) infer_types!(t3) - names_types_3 = Set(zip(t3[:name], t3[:type])) names_types_expected_3 = Set([(:C, :Form0), (:D, :Form1), (:E, :Form2)]) - @test issetequal(names_types_3, names_types_expected_3) + @test test_nametype_equality(t3, names_types_expected_3) # The type of the src and tgt of d is inferred. Test4 = quote @@ -448,9 +449,8 @@ end #t4_inferred = infer_types!(t4) infer_types!(t4) - names_types_4 = Set(zip(t4[:name], t4[:type])) names_types_expected_4 = Set([(:C, :Form0), (:D, :Form1), (:E, :Form2)]) - @test issetequal(names_types_4, names_types_expected_4) + @test test_nametype_equality(t4, names_types_expected_4) # The type of the src of d is inferred. Test5 = quote @@ -461,9 +461,8 @@ end t5 = SummationDecapode(parse_decapode(Test5)) infer_types!(t5) - names_types_5 = Set(zip(t5[:name], t5[:type])) names_types_expected_5 = Set([(:C, :Form0), (:D, :Form1)]) - @test issetequal(names_types_5, names_types_expected_5) + @test test_nametype_equality(t5, names_types_expected_5) # The type of the src of d is inferred. Test6 = quote @@ -478,11 +477,10 @@ end t6 = SummationDecapode(parse_decapode(Test6)) infer_types!(t6) - names_types_6 = Set(zip(t6[:name], t6[:type])) names_types_expected_6 = Set([ (:A, :Form0), (:B, :Form1), (:C, :Form2), (:F, :DualForm2), (:E, :DualForm1), (:D, :DualForm0)]) - @test issetequal(names_types_6, names_types_expected_6) + @test test_nametype_equality(t6, names_types_expected_6) # The type of a summand is inferred. Test7 = quote @@ -521,12 +519,11 @@ end t8 = SummationDecapode(parse_decapode(Test8)) infer_types!(t8) - names_types_8 = Set(zip(t8[:name], t8[:type])) names_types_expected_8 = Set([ (:Γ, :Form0), (:A, :Form1), (:B, :Form1), (:C, :Form1), (:D, :Form1), (:E, :Form1), (:F, :Form1), (:Θ, :Form2)]) - @test issetequal(names_types_8, names_types_expected_8) + @test test_nametype_equality(t8, names_types_expected_8) function makeInferPathDeca(log_cycles; infer_path = false) cycles = 2 ^ log_cycles @@ -568,11 +565,10 @@ end t9 = expand_operators(t9) infer_types!(t9) - names_types_9 = Set(zip(t9[:name], t9[:type])) names_types_expected_9 = Set([ (:A, :Form0), (Symbol("•_1_", 1), :Form1), (Symbol("•_1_", 2), :Form2), (:B, :DualForm2), (Symbol("•_1_", 4), :DualForm1), (Symbol("•_1_", 3), :DualForm0)]) - @test issetequal(names_types_9, names_types_expected_9) + @test test_nametype_equality(t9, names_types_expected_9) # Basic op2 inference using ∧ Test10 = quote @@ -591,9 +587,8 @@ end t10 = SummationDecapode(parse_decapode(Test10)) infer_types!(t10) - names_types_10 = get_name_type_pair(t10) names_types_expected_10 = Set([(:F, :Form1), (:B, :Form0), (:C, :Form0), (:H, :Form1), (:A, :Form0), (:E, :Form1), (:D, :Form0)]) - @test issetequal(names_types_10, names_types_expected_10) + @test test_nametype_equality(t10, names_types_expected_10) # Basic op2 inference using L Test11 = quote @@ -608,9 +603,9 @@ end t11 = SummationDecapode(parse_decapode(Test11)) infer_types!(t11) - names_types_11 = get_name_type_pair(t11) + names_types_expected_11 = Set([(:A, :Form1), (:B, :DualForm0), (:C, :DualForm0), (:E, :DualForm1), (:D, :DualForm1)]) - @test issetequal(names_types_11, names_types_expected_11) + @test test_nametype_equality(t11, names_types_expected_11) # Basic op2 inference using i Test12 = quote @@ -623,9 +618,8 @@ end t12 = SummationDecapode(parse_decapode(Test12)) infer_types!(t12) - names_types_12 = get_name_type_pair(t12) names_types_expected_12 = Set([(:A, :Form1), (:B, :DualForm1), (:C, :DualForm0)]) - @test issetequal(names_types_12, names_types_expected_12) + @test test_nametype_equality(t12, names_types_expected_12) #2D op2 inference using ∧ Test13 = quote @@ -644,9 +638,8 @@ end t13 = SummationDecapode(parse_decapode(Test13)) infer_types!(t13) - names_types_13 = get_name_type_pair(t13) names_types_expected_13 = Set([(:E, :Form2), (:B, :Form1), (:C, :Form2), (:A, :Form1), (:F, :Form2), (:H, :Form2), (:D, :Form0)]) - @test issetequal(names_types_13, names_types_expected_13) + @test test_nametype_equality(t13, names_types_expected_13) # 2D op2 inference using L Test14 = quote @@ -659,9 +652,8 @@ end t14 = SummationDecapode(parse_decapode(Test14)) infer_types!(t14) - names_types_14 = get_name_type_pair(t14) names_types_expected_14 = Set([(:C, :DualForm2), (:A, :Form1), (:B, :DualForm2)]) - @test issetequal(names_types_14, names_types_expected_14) + @test test_nametype_equality(t14, names_types_expected_14) # 2D op2 inference using i Test15 = quote @@ -674,9 +666,8 @@ end t15 = SummationDecapode(parse_decapode(Test15)) infer_types!(t15) - names_types_15 = get_name_type_pair(t15) names_types_expected_15 = Set([(:A, :Form1), (:B, :DualForm2), (:C, :DualForm1)]) - @test issetequal(names_types_15, names_types_expected_15) + @test test_nametype_equality(t15, names_types_expected_15) # Special case of a summation with a mix of infer and Literal. t16 = @decapode begin @@ -684,9 +675,8 @@ end end infer_types!(t16) - names_types_16 = get_name_type_pair(t16) - names_types_expected_16 = Set([(:A, :Constant), (:C, :Constant), (:D, :Constant), (Symbol("2"), :Literal)]) - @test issetequal(names_types_16, names_types_expected_16) + names_types_expected_16 = Set([(:A, :infer), (:C, :infer), (:D, :infer), (Symbol("2"), :Literal)]) + @test test_nametype_equality(t16, names_types_expected_16) # Special case of a summation with a mix of Form, Constant, and infer. t17 = @decapode begin @@ -696,9 +686,8 @@ end end infer_types!(t17) - names_types_17 = get_name_type_pair(t17) names_types_expected_17 = Set([(:A, :Form0), (:C, :Constant), (:D, :Form0)]) - @test issetequal(names_types_17, names_types_expected_17) + @test test_nametype_equality(t17, names_types_expected_17) # Special case of a summation with a mix of infer and Constant. t18 = @decapode begin @@ -707,11 +696,13 @@ end end infer_types!(t18) - names_types_18 = get_name_type_pair(t18) - names_types_expected_18 = Set([(:A, :Constant), (:C, :Constant), (:D, :Constant)]) - @test issetequal(names_types_18, names_types_expected_18) - let d = @decapode begin + names_types_expected_18 = Set([(:A, :infer), (:C, :Constant), (:D, :infer)]) + @test test_nametype_equality(t18, names_types_expected_18) + + # Test #19: Prevent intermediates from converting to Constant + let + d = @decapode begin h::Form0 Γ::Form1 n::Constant @@ -721,19 +712,18 @@ end end infer_types!(d, op1_inf_rules_1D, op2_inf_rules_1D) - - # TODO: This is modifying an intermediate var to be :Constant, which is user-defined - @test_broken d[18, :type] != :Constant + @test d[18, :type] != :Constant end - let d = @decapode begin + # Test #20: Prevent intermediates from converting to Literal + let + d = @decapode begin h::Form0 Γ::Form1 n::Constant ḣ == ∂ₜ(h) ḣ == ∘(⋆, d, ⋆)(Γ * d(h) * avg₀₁(mag(♯(d(h)))^(n-1)) * avg₀₁(h^(n+2))) - end d = expand_operators(d) @@ -742,6 +732,32 @@ end @test d[8, :type] != :Literal end + # Test #21: Prevent intermediates from converting to Parameter + let + d = @decapode begin + C::Parameter + A == C + D + end + infer_types!(d) + + names_types_expected = Set([(:A, :infer), (:C, :Parameter), (:D, :infer)]) + @test test_nametype_equality(d, names_types_expected) + end + + # Test #22: Have source form information flow over Parameters and Constants + let + d = @decapode begin + B::Form0 + C::Parameter + D::Constant + A == C + D + B + end + infer_types!(d) + + names_types_expected = Set([(:A, :Form0), (:B, :Form0), (:C, :Parameter), (:D, :Constant)]) + @test test_nametype_equality(d, names_types_expected) + end + end @testset "Overloading Resolution" begin From 5fa05a6429368bf5a60e448ac43c56198ed26bd4 Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Fri, 21 Jun 2024 14:12:23 -0400 Subject: [PATCH 02/11] Loop only through vars to infer --- src/acset.jl | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/src/acset.jl b/src/acset.jl index 3fc0c61..a4ea4d1 100644 --- a/src/acset.jl +++ b/src/acset.jl @@ -363,12 +363,31 @@ function safe_modifytype!(d::SummationDecapode, var_idx::Int, new_type::Symbol) return modify end +""" + filterfor_forms(types::AbstractVector{Symbol}) + +Return any form type symbols. +""" function filterfor_forms(types::AbstractVector{Symbol}) conditions = x -> x != :Literal && x != :Constant && x != :Parameter && x != :infer filter(conditions, types) end -# ! Warning: This is changing types to :Constant when they weren't originally. +""" + filterfor_forms(types::AbstractVector{Symbol}) + +Return the indices of any variables with form types. +""" +function filterfor_forms(d::SummationDecapode, type_idxs::AbstractVector{Int}) + conditions = x -> d[x, :type] != :Literal && + d[x, :type] != :Constant && + d[x, :type] != :Parameter && + d[x, :type] != :infer + + filter(conditions, type_idxs) +end + + function infer_summands_and_summations!(d::SummationDecapode) # Note that we are not doing any type checking here for users! # i.e. We are not checking the underlying types of Constant or Parameter @@ -377,9 +396,9 @@ function infer_summands_and_summations!(d::SummationDecapode) for Σ_idx in parts(d, :Σ) summands = d[incident(d, Σ_idx, :summation), :summand] sum = d[Σ_idx, :sum] - idxs = [summands; sum] + idxs = filterfor_forms(d, [summands; sum]) - forms = unique(filterfor_forms(d[idxs, :type])) + forms = unique(d[idxs, :type]) form = @match length(forms) begin 0 => continue # We need not infer, We can not infer From fe904501772bae6f276d4caeaac6ba918204c24c Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Fri, 21 Jun 2024 14:35:49 -0400 Subject: [PATCH 03/11] Fixed error introduced --- src/acset.jl | 19 ++----------------- test/language.jl | 8 ++++++++ 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/src/acset.jl b/src/acset.jl index a4ea4d1..aa2cebc 100644 --- a/src/acset.jl +++ b/src/acset.jl @@ -373,21 +373,6 @@ function filterfor_forms(types::AbstractVector{Symbol}) filter(conditions, types) end -""" - filterfor_forms(types::AbstractVector{Symbol}) - -Return the indices of any variables with form types. -""" -function filterfor_forms(d::SummationDecapode, type_idxs::AbstractVector{Int}) - conditions = x -> d[x, :type] != :Literal && - d[x, :type] != :Constant && - d[x, :type] != :Parameter && - d[x, :type] != :infer - - filter(conditions, type_idxs) -end - - function infer_summands_and_summations!(d::SummationDecapode) # Note that we are not doing any type checking here for users! # i.e. We are not checking the underlying types of Constant or Parameter @@ -396,9 +381,9 @@ function infer_summands_and_summations!(d::SummationDecapode) for Σ_idx in parts(d, :Σ) summands = d[incident(d, Σ_idx, :summation), :summand] sum = d[Σ_idx, :sum] - idxs = filterfor_forms(d, [summands; sum]) + idxs = [summands; sum] - forms = unique(d[idxs, :type]) + forms = unique(filterfor_forms(d[idxs, :type])) form = @match length(forms) begin 0 => continue # We need not infer, We can not infer diff --git a/test/language.jl b/test/language.jl index ea68460..aa73f3d 100644 --- a/test/language.jl +++ b/test/language.jl @@ -384,6 +384,14 @@ import DiagrammaticEquations: safe_modifytype end end +import DiagrammaticEquations: filterfor_forms + +@testset "Form Type Retrieval" begin + all_types = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :Literal, :Constant, :Parameter, :infer] + @test filterfor_forms(all_types) == [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2] + @test isempty(filterfor_forms(Symbol[])) +end + @testset "Type Inference" begin # Warning, this testing depends on the fact that varname, form information is # unique within a decapode even though this is not enforced From ba0a3ae8ba50ee088a202a3c0311eaa5087bd080 Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Fri, 21 Jun 2024 14:45:45 -0400 Subject: [PATCH 04/11] Moved infer sums loop This brings its function signature more in line with the other inference applications. --- src/acset.jl | 38 +++++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/src/acset.jl b/src/acset.jl index aa2cebc..c0ef781 100644 --- a/src/acset.jl +++ b/src/acset.jl @@ -373,28 +373,29 @@ function filterfor_forms(types::AbstractVector{Symbol}) filter(conditions, types) end -function infer_summands_and_summations!(d::SummationDecapode) +function infer_summands_and_summations!(d::SummationDecapode, Σ_idx::Int) # Note that we are not doing any type checking here for users! # i.e. We are not checking the underlying types of Constant or Parameter applied = false - for Σ_idx in parts(d, :Σ) - summands = d[incident(d, Σ_idx, :summation), :summand] - sum = d[Σ_idx, :sum] - idxs = [summands; sum] + summands = d[incident(d, Σ_idx, :summation), :summand] + sum = d[Σ_idx, :sum] + idxs = [summands; sum] + types = d[idxs, :type] + all(t != :infer for t in types) && return applied # We need not infer - forms = unique(filterfor_forms(d[idxs, :type])) + forms = unique(filterfor_forms(types)) - form = @match length(forms) begin - 0 => continue # We need not infer, We can not infer - 1 => only(forms) - _ => error("Type mismatch in summation") - end + form = @match length(forms) begin + 0 => return applied # We can not infer + 1 => only(forms) + _ => error("Type mismatch in summation") + end - for idx in idxs - applied |= safe_modifytype!(d, idx, form) - end + for idx in idxs + applied |= safe_modifytype!(d, idx, form) end + return applied end @@ -468,7 +469,7 @@ function infer_types!(d::SummationDecapode, op1_rules::Vector{NamedTuple{(:src_t this_applied = apply_inference_rule_op1!(d, op1_idx, rule) types_known_op1[op1_idx] = this_applied - applied = applied || this_applied + applied |= this_applied end end @@ -479,11 +480,14 @@ function infer_types!(d::SummationDecapode, op1_rules::Vector{NamedTuple{(:src_t this_applied = apply_inference_rule_op2!(d, op2_idx, rule) types_known_op2[op2_idx] = this_applied - applied = applied || this_applied + applied |= this_applied end end - applied = applied || infer_summands_and_summations!(d) + for Σ_idx in parts(d, :Σ) + applied |= infer_summands_and_summations!(d, Σ_idx) + end + applied || break # Break if no rules were applied. end From e673649898011c2f5881a20411be2c90c4026495 Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Fri, 21 Jun 2024 14:50:41 -0400 Subject: [PATCH 05/11] Better error for sum type mismatch --- src/acset.jl | 2 +- test/language.jl | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/acset.jl b/src/acset.jl index c0ef781..bcce46d 100644 --- a/src/acset.jl +++ b/src/acset.jl @@ -389,7 +389,7 @@ function infer_summands_and_summations!(d::SummationDecapode, Σ_idx::Int) form = @match length(forms) begin 0 => return applied # We can not infer 1 => only(forms) - _ => error("Type mismatch in summation") + _ => error("Type mismatch in summation $Σ_idx, all the following forms appear $forms") end for idx in idxs diff --git a/test/language.jl b/test/language.jl index aa73f3d..1bfc571 100644 --- a/test/language.jl +++ b/test/language.jl @@ -766,6 +766,16 @@ end @test test_nametype_equality(d, names_types_expected) end + # Test #23: Check that different forms summed up error out + let + d = @decapode begin + B::Form0 + C::Form1 + A == C + D + B + end + @test_throws "Type mismatch in summation" infer_types!(d) + end + end @testset "Overloading Resolution" begin From 5ca691311d0914743f438e9feb62e9d8ae57ea67 Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Fri, 21 Jun 2024 14:56:16 -0400 Subject: [PATCH 06/11] Turn on op2 lookup --- src/acset.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/acset.jl b/src/acset.jl index bcce46d..047ecfa 100644 --- a/src/acset.jl +++ b/src/acset.jl @@ -454,7 +454,7 @@ function infer_types!(d::SummationDecapode, op1_rules::Vector{NamedTuple{(:src_t types_known_op1[incident(d, :infer, [:src, :type])] .= false types_known_op1[incident(d, :infer, [:tgt, :type])] .= false - types_known_op2 = zeros(Bool, nparts(d, :Op2)) + types_known_op2 = ones(Bool, nparts(d, :Op2)) types_known_op2[incident(d, :infer, [:proj1, :type])] .= false types_known_op2[incident(d, :infer, [:proj2, :type])] .= false types_known_op2[incident(d, :infer, [:res, :type])] .= false From 91ed7a41f6904a15438511f9966aaaa57c9f7150 Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Fri, 21 Jun 2024 15:14:12 -0400 Subject: [PATCH 07/11] Touching up --- src/acset.jl | 2 +- test/language.jl | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/acset.jl b/src/acset.jl index 047ecfa..377856a 100644 --- a/src/acset.jl +++ b/src/acset.jl @@ -389,7 +389,7 @@ function infer_summands_and_summations!(d::SummationDecapode, Σ_idx::Int) form = @match length(forms) begin 0 => return applied # We can not infer 1 => only(forms) - _ => error("Type mismatch in summation $Σ_idx, all the following forms appear $forms") + _ => error("Type mismatch in summation $Σ_idx, all the following forms appear: $forms") end for idx in idxs diff --git a/test/language.jl b/test/language.jl index 1bfc571..8f47588 100644 --- a/test/language.jl +++ b/test/language.jl @@ -704,7 +704,6 @@ end end infer_types!(t18) - names_types_expected_18 = Set([(:A, :infer), (:C, :Constant), (:D, :infer)]) @test test_nametype_equality(t18, names_types_expected_18) From cff0047db44fd9b7a3538d15f3e6dbea3593c105 Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Fri, 21 Jun 2024 15:21:48 -0400 Subject: [PATCH 08/11] A final test --- test/language.jl | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/test/language.jl b/test/language.jl index 8f47588..2b65aea 100644 --- a/test/language.jl +++ b/test/language.jl @@ -765,7 +765,21 @@ end @test test_nametype_equality(d, names_types_expected) end - # Test #23: Check that different forms summed up error out + # Test #23: Have target form information flow over Parameters and Constants + let + d = @decapode begin + A::Form0 + C::Parameter + D::Constant + A == C + D + B + end + infer_types!(d) + + names_types_expected = Set([(:A, :Form0), (:B, :Form0), (:C, :Parameter), (:D, :Constant)]) + @test test_nametype_equality(d, names_types_expected) + end + + # Test #24: Check that different forms summed up error out let d = @decapode begin B::Form0 From 0fe6987b000ac684a3fa75d6c2fabd33569cb989 Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Mon, 24 Jun 2024 12:18:56 -0400 Subject: [PATCH 09/11] Clarifying filterfor_forms --- src/acset.jl | 3 ++- test/language.jl | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/acset.jl b/src/acset.jl index 377856a..c3dddff 100644 --- a/src/acset.jl +++ b/src/acset.jl @@ -369,7 +369,8 @@ end Return any form type symbols. """ function filterfor_forms(types::AbstractVector{Symbol}) - conditions = x -> x != :Literal && x != :Constant && x != :Parameter && x != :infer + nonform_symbols = [:Literal, :Constant, :Parameter, :infer] + conditions = x -> !(x in nonform_symbols) filter(conditions, types) end diff --git a/test/language.jl b/test/language.jl index 2b65aea..da81399 100644 --- a/test/language.jl +++ b/test/language.jl @@ -390,6 +390,7 @@ import DiagrammaticEquations: filterfor_forms all_types = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :Literal, :Constant, :Parameter, :infer] @test filterfor_forms(all_types) == [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2] @test isempty(filterfor_forms(Symbol[])) + @test isempty(filterfor_forms([:Literal, :Constant, :Parameter, :infer])) end @testset "Type Inference" begin From 3dc34beecfb9125805510932a0935b2b5f367b4b Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Mon, 24 Jun 2024 13:46:20 -0400 Subject: [PATCH 10/11] Better type group handling I'm testing the use of getter functions to return us type groups to avoid us having to type them out. We can add more groups if needed. --- src/acset.jl | 31 +++++++++++++++++++++++++------ test/language.jl | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 6 deletions(-) diff --git a/src/acset.jl b/src/acset.jl index c3dddff..887edd0 100644 --- a/src/acset.jl +++ b/src/acset.jl @@ -155,13 +155,33 @@ function make_sum_mult_unique!(d::AbstractNamedDecapode) end end +# A collection of DecaType getters +# TODO: This should be replaced by using a type hierarchy +macro get_alltypes() + [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :Literal, :Parameter, :Constant, :infer] +end + +macro get_formtypes() return [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2] end +macro get_primalformtypes() return [:Form0, :Form1, :Form2] end +macro get_dualformtypes() return [:DualForm0, :DualForm1, :DualForm2] end + +macro get_nonformtypes() return [:Constant, :Parameter, :Literal, :infer] end +macro get_usertypes() return [:Constant, :Parameter] end +macro get_numbertypes() return [:Literal] end +macro get_infertypes() return [:infer] end + +# Types that can not ever be inferred +macro get_noninferabletypes() return [:Constant, :Parameter, :Literal] end + +function get_unsupportedtypes(types) + setdiff(types, @get_alltypes()) +end + # Note: This hard-bakes in Form0 through Form2, and higher Forms are not # allowed. function recognize_types(d::AbstractNamedDecapode) types = d[:type] - unrecognized_types = setdiff(d[:type], [:Form0, :Form1, :Form2, :DualForm0, - :DualForm1, :DualForm2, :Literal, :Parameter, - :Constant, :infer]) + unrecognized_types = get_unsupportedtypes(types) isempty(unrecognized_types) || error("Types $unrecognized_types are not recognized. CHECK: $types") end @@ -349,7 +369,7 @@ This function accepts an original type and a new type and determines if the orig can be safely overwritten by the new type. """ function safe_modifytype(org_type::Symbol, new_type::Symbol) - modify = (org_type == :infer && !(new_type == :Literal || new_type == :Constant || new_type == :Parameter)) + modify = (org_type in @get_infertypes() && !(new_type in @get_noninferabletypes())) return (modify, modify ? new_type : org_type) end @@ -369,8 +389,7 @@ end Return any form type symbols. """ function filterfor_forms(types::AbstractVector{Symbol}) - nonform_symbols = [:Literal, :Constant, :Parameter, :infer] - conditions = x -> !(x in nonform_symbols) + conditions = x -> !(x in @get_nonformtypes()) filter(conditions, types) end diff --git a/test/language.jl b/test/language.jl index da81399..3f944a8 100644 --- a/test/language.jl +++ b/test/language.jl @@ -356,6 +356,45 @@ end @test issetequal([:V,:X,:k], infer_state_names(oscillator)) end +import DiagrammaticEquations: @get_alltypes, @get_formtypes, @get_primalformtypes, @get_dualformtypes, +@get_nonformtypes, @get_usertypes, @get_numbertypes, @get_noninferabletypes, @get_infertypes +@testset "Type Retrival" begin + calls = [:@get_alltypes, :@get_formtypes, :@get_primalformtypes, :@get_dualformtypes, + :@get_nonformtypes, :@get_usertypes, :@get_numbertypes, :@get_noninferabletypes, :@get_infertypes] + + # No repeated types + for call in calls + res = @eval $(call) + @test unique(res) == res + end + + equal_types(types_1, types_2) = issetequal(Set(types_1), Set(types_2)) + no_overlaps(types_1, types_2) = isempty(intersect(types_1, types_2)) + + # Collections of these types should be the same + @test equal_types(@get_alltypes(), vcat(@get_formtypes(), @get_nonformtypes())) + @test equal_types(@get_formtypes(), vcat(@get_primalformtypes(), @get_dualformtypes())) + @test equal_types(@get_noninferabletypes(), vcat(@get_usertypes(), @get_numbertypes())) + + # Proper seperation of types + @test no_overlaps(@get_formtypes(), @get_nonformtypes()) + @test no_overlaps(@get_primalformtypes(), @get_dualformtypes()) + @test no_overlaps(@get_noninferabletypes(), @get_formtypes()) + @test @get_infertypes() == [:infer] + + @test no_overlaps(@get_formtypes(), @get_numbertypes()) + @test no_overlaps(@get_formtypes(), @get_usertypes()) + @test no_overlaps(@get_usertypes(), @get_numbertypes()) +end + +import DiagrammaticEquations: get_unsupportedtypes +@testset "Type Validation" begin + @test isempty(get_unsupportedtypes(@get_alltypes())) + @test [:A] == get_unsupportedtypes([:A]) + @test [:A] == get_unsupportedtypes([:Form1, :A]) + @test isempty(get_unsupportedtypes(Symbol[])) +end + import DiagrammaticEquations: safe_modifytype @testset "Safe Type Modification" begin From ecfe809063e522c3e804ae59fbe2b141c10ab350 Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Mon, 24 Jun 2024 15:34:00 -0400 Subject: [PATCH 11/11] Changed macros to const --- src/acset.jl | 31 +++++++++++++++---------------- test/language.jl | 42 ++++++++++++++++++++++-------------------- 2 files changed, 37 insertions(+), 36 deletions(-) diff --git a/src/acset.jl b/src/acset.jl index 887edd0..cb633e7 100644 --- a/src/acset.jl +++ b/src/acset.jl @@ -157,24 +157,23 @@ end # A collection of DecaType getters # TODO: This should be replaced by using a type hierarchy -macro get_alltypes() - [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :Literal, :Parameter, :Constant, :infer] -end +const ALL_TYPES = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, + :Literal, :Parameter, :Constant, :infer] -macro get_formtypes() return [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2] end -macro get_primalformtypes() return [:Form0, :Form1, :Form2] end -macro get_dualformtypes() return [:DualForm0, :DualForm1, :DualForm2] end +const FORM_TYPES = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2] +const PRIMALFORM_TYPES = [:Form0, :Form1, :Form2] +const DUALFORM_TYPES = [:DualForm0, :DualForm1, :DualForm2] -macro get_nonformtypes() return [:Constant, :Parameter, :Literal, :infer] end -macro get_usertypes() return [:Constant, :Parameter] end -macro get_numbertypes() return [:Literal] end -macro get_infertypes() return [:infer] end +const NONFORM_TYPES = [:Constant, :Parameter, :Literal, :infer] +const USER_TYPES = [:Constant, :Parameter] +const NUMBER_TYPES = [:Literal] +const INFER_TYPES = [:infer] # Types that can not ever be inferred -macro get_noninferabletypes() return [:Constant, :Parameter, :Literal] end +const NONINFERABLE_TYPES = [:Constant, :Parameter, :Literal] function get_unsupportedtypes(types) - setdiff(types, @get_alltypes()) + setdiff(types, ALL_TYPES) end # Note: This hard-bakes in Form0 through Form2, and higher Forms are not @@ -369,7 +368,7 @@ This function accepts an original type and a new type and determines if the orig can be safely overwritten by the new type. """ function safe_modifytype(org_type::Symbol, new_type::Symbol) - modify = (org_type in @get_infertypes() && !(new_type in @get_noninferabletypes())) + modify = (org_type in INFER_TYPES && !(new_type in NONINFERABLE_TYPES)) return (modify, modify ? new_type : org_type) end @@ -389,11 +388,11 @@ end Return any form type symbols. """ function filterfor_forms(types::AbstractVector{Symbol}) - conditions = x -> !(x in @get_nonformtypes()) + conditions = x -> !(x in NONFORM_TYPES) filter(conditions, types) end -function infer_summands_and_summations!(d::SummationDecapode, Σ_idx::Int) +function infer_sum_types!(d::SummationDecapode, Σ_idx::Int) # Note that we are not doing any type checking here for users! # i.e. We are not checking the underlying types of Constant or Parameter applied = false @@ -505,7 +504,7 @@ function infer_types!(d::SummationDecapode, op1_rules::Vector{NamedTuple{(:src_t end for Σ_idx in parts(d, :Σ) - applied |= infer_summands_and_summations!(d, Σ_idx) + applied |= infer_sum_types!(d, Σ_idx) end applied || break # Break if no rules were applied. diff --git a/test/language.jl b/test/language.jl index 3f944a8..a90b088 100644 --- a/test/language.jl +++ b/test/language.jl @@ -356,40 +356,42 @@ end @test issetequal([:V,:X,:k], infer_state_names(oscillator)) end -import DiagrammaticEquations: @get_alltypes, @get_formtypes, @get_primalformtypes, @get_dualformtypes, -@get_nonformtypes, @get_usertypes, @get_numbertypes, @get_noninferabletypes, @get_infertypes +import DiagrammaticEquations: ALL_TYPES, FORM_TYPES, PRIMALFORM_TYPES, DUALFORM_TYPES, + NONFORM_TYPES, USER_TYPES, NUMBER_TYPES, INFER_TYPES, NONINFERABLE_TYPES + @testset "Type Retrival" begin - calls = [:@get_alltypes, :@get_formtypes, :@get_primalformtypes, :@get_dualformtypes, - :@get_nonformtypes, :@get_usertypes, :@get_numbertypes, :@get_noninferabletypes, :@get_infertypes] + + type_groups = [ALL_TYPES, FORM_TYPES, PRIMALFORM_TYPES, DUALFORM_TYPES, + NONFORM_TYPES, USER_TYPES, NUMBER_TYPES, INFER_TYPES, NONINFERABLE_TYPES] + # No repeated types - for call in calls - res = @eval $(call) - @test unique(res) == res + for type_group in type_groups + @test allunique(type_group) end equal_types(types_1, types_2) = issetequal(Set(types_1), Set(types_2)) no_overlaps(types_1, types_2) = isempty(intersect(types_1, types_2)) # Collections of these types should be the same - @test equal_types(@get_alltypes(), vcat(@get_formtypes(), @get_nonformtypes())) - @test equal_types(@get_formtypes(), vcat(@get_primalformtypes(), @get_dualformtypes())) - @test equal_types(@get_noninferabletypes(), vcat(@get_usertypes(), @get_numbertypes())) + @test equal_types(ALL_TYPES, vcat(FORM_TYPES, NONFORM_TYPES)) + @test equal_types(FORM_TYPES, vcat(PRIMALFORM_TYPES, DUALFORM_TYPES)) + @test equal_types(NONINFERABLE_TYPES, vcat(USER_TYPES, NUMBER_TYPES)) # Proper seperation of types - @test no_overlaps(@get_formtypes(), @get_nonformtypes()) - @test no_overlaps(@get_primalformtypes(), @get_dualformtypes()) - @test no_overlaps(@get_noninferabletypes(), @get_formtypes()) - @test @get_infertypes() == [:infer] - - @test no_overlaps(@get_formtypes(), @get_numbertypes()) - @test no_overlaps(@get_formtypes(), @get_usertypes()) - @test no_overlaps(@get_usertypes(), @get_numbertypes()) + @test no_overlaps(FORM_TYPES, NONFORM_TYPES) + @test no_overlaps(PRIMALFORM_TYPES, DUALFORM_TYPES) + @test no_overlaps(NONINFERABLE_TYPES, FORM_TYPES) + @test INFER_TYPES == [:infer] + + @test no_overlaps(FORM_TYPES, NUMBER_TYPES) + @test no_overlaps(FORM_TYPES, USER_TYPES) + @test no_overlaps(USER_TYPES, NUMBER_TYPES) end import DiagrammaticEquations: get_unsupportedtypes @testset "Type Validation" begin - @test isempty(get_unsupportedtypes(@get_alltypes())) + @test isempty(get_unsupportedtypes(ALL_TYPES)) @test [:A] == get_unsupportedtypes([:A]) @test [:A] == get_unsupportedtypes([:Form1, :A]) @test isempty(get_unsupportedtypes(Symbol[])) @@ -819,7 +821,7 @@ end @test test_nametype_equality(d, names_types_expected) end - # Test #24: Check that different forms summed up error out + # Test #24: Summing mismatched forms throws an error let d = @decapode begin B::Form0