diff --git a/src/acset.jl b/src/acset.jl index 5366459..34d1d3c 100644 --- a/src/acset.jl +++ b/src/acset.jl @@ -466,13 +466,10 @@ function infer_sum_types!(d::SummationDecapode, Σ_idx::Int) end function apply_inference_rule_op1!(d::SummationDecapode, op1_id, rule) - type_src = d[d[op1_id, :src], :type] - type_tgt = d[d[op1_id, :tgt], :type] + score_src = (rule.src_type == d[d[op1_id, :src], :type]) + score_tgt = (rule.tgt_type == d[d[op1_id, :tgt], :type]) - score_src = (rule.src_type == type_src) - score_tgt = (rule.tgt_type == type_tgt) check_op = (d[op1_id, :op1] in rule.op_names) - if(check_op && (score_src + score_tgt == 1)) mod_src = safe_modifytype!(d, d[op1_id, :src], rule.src_type) mod_tgt = safe_modifytype!(d, d[op1_id, :tgt], rule.tgt_type) @@ -483,33 +480,16 @@ function apply_inference_rule_op1!(d::SummationDecapode, op1_id, rule) end function apply_inference_rule_op2!(d::SummationDecapode, op2_id, rule) - type_proj1 = d[d[op2_id, :proj1], :type] - type_proj2 = d[d[op2_id, :proj2], :type] - type_res = d[d[op2_id, :res], :type] + score_proj1 = (rule.proj1_type == d[d[op2_id, :proj1], :type]) + score_proj2 = (rule.proj2_type == d[d[op2_id, :proj2], :type]) + score_res = (rule.res_type == d[d[op2_id, :res], :type]) - score_proj1 = (rule.proj1_type == type_proj1) - score_proj2 = (rule.proj2_type == type_proj2) - score_res = (rule.res_type == type_res) check_op = (d[op2_id, :op2] in rule.op_names) - if check_op && (score_proj1 + score_proj2 + score_res == 2) mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], rule.proj1_type) mod_proj2 = safe_modifytype!(d, d[op2_id, :proj2], rule.proj2_type) - mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type) + mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type) return mod_proj1 || mod_proj2 || mod_res - # Special logic for exponentiation: - elseif d[op2_id, :op2] == :^ && - (type_proj1 == :Form0 && (type_proj2 == :infer || type_res == :infer)) || - (type_res == :Form0 && (type_proj1 == :infer || type_proj2 == :infer)) - mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], :Form0) - mod_res = safe_modifytype!(d, d[op2_id, :res], :Form0) - return mod_proj1 || mod_res - elseif d[op2_id, :op2] == :^ && - (type_proj1 == :DualForm0 && (type_proj2 == :infer || type_res == :infer)) || - (type_res == :DualForm0 && (type_proj1 == :infer || type_proj2 == :infer)) - mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], :DualForm0) - mod_res = safe_modifytype!(d, d[op2_id, :res], :DualForm0) - return mod_proj1 || mod_res end return false diff --git a/src/deca/deca_acset.jl b/src/deca/deca_acset.jl index b557553..4334c4b 100644 --- a/src/deca/deca_acset.jl +++ b/src/deca/deca_acset.jl @@ -90,7 +90,11 @@ op2_inf_rules_1D = [ (proj1_type = :Constant, proj2_type = :DualForm0, res_type = :DualForm0, op_names = [:/, :./, :*, :.*, :^, :.^]), (proj1_type = :Constant, proj2_type = :DualForm1, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^]), (proj1_type = :DualForm0, proj2_type = :Constant, res_type = :DualForm0, op_names = [:/, :./, :*, :.*, :^, :.^]), - (proj1_type = :DualForm1, proj2_type = :Constant, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^])] + (proj1_type = :DualForm1, proj2_type = :Constant, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^]), + + # These rules contain infer: + (proj1_type = :Form0, proj2_type = :infer, res_type = :Form0, op_names = [:^]), + (proj1_type = :DualForm0, proj2_type = :infer, res_type = :DualForm0, op_names = [:^])] """ These are the default rules used to do type inference in the 2D exterior calculus.