Skip to content

Commit

Permalink
Take advantage of :infer in type rules
Browse files Browse the repository at this point in the history
  • Loading branch information
lukem12345 authored and jpfairbanks committed Aug 23, 2024
1 parent 79f27e7 commit 23fbf3f
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 27 deletions.
32 changes: 6 additions & 26 deletions src/acset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -466,13 +466,10 @@ function infer_sum_types!(d::SummationDecapode, Σ_idx::Int)
end

function apply_inference_rule_op1!(d::SummationDecapode, op1_id, rule)
type_src = d[d[op1_id, :src], :type]
type_tgt = d[d[op1_id, :tgt], :type]
score_src = (rule.src_type == d[d[op1_id, :src], :type])
score_tgt = (rule.tgt_type == d[d[op1_id, :tgt], :type])

score_src = (rule.src_type == type_src)
score_tgt = (rule.tgt_type == type_tgt)
check_op = (d[op1_id, :op1] in rule.op_names)

if(check_op && (score_src + score_tgt == 1))
mod_src = safe_modifytype!(d, d[op1_id, :src], rule.src_type)
mod_tgt = safe_modifytype!(d, d[op1_id, :tgt], rule.tgt_type)
Expand All @@ -483,33 +480,16 @@ function apply_inference_rule_op1!(d::SummationDecapode, op1_id, rule)
end

function apply_inference_rule_op2!(d::SummationDecapode, op2_id, rule)
type_proj1 = d[d[op2_id, :proj1], :type]
type_proj2 = d[d[op2_id, :proj2], :type]
type_res = d[d[op2_id, :res], :type]
score_proj1 = (rule.proj1_type == d[d[op2_id, :proj1], :type])
score_proj2 = (rule.proj2_type == d[d[op2_id, :proj2], :type])
score_res = (rule.res_type == d[d[op2_id, :res], :type])

score_proj1 = (rule.proj1_type == type_proj1)
score_proj2 = (rule.proj2_type == type_proj2)
score_res = (rule.res_type == type_res)
check_op = (d[op2_id, :op2] in rule.op_names)

if check_op && (score_proj1 + score_proj2 + score_res == 2)
mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], rule.proj1_type)
mod_proj2 = safe_modifytype!(d, d[op2_id, :proj2], rule.proj2_type)
mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type)
mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type)
return mod_proj1 || mod_proj2 || mod_res
# Special logic for exponentiation:
elseif d[op2_id, :op2] == :^ &&
(type_proj1 == :Form0 && (type_proj2 == :infer || type_res == :infer)) ||
(type_res == :Form0 && (type_proj1 == :infer || type_proj2 == :infer))
mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], :Form0)
mod_res = safe_modifytype!(d, d[op2_id, :res], :Form0)
return mod_proj1 || mod_res
elseif d[op2_id, :op2] == :^ &&
(type_proj1 == :DualForm0 && (type_proj2 == :infer || type_res == :infer)) ||
(type_res == :DualForm0 && (type_proj1 == :infer || type_proj2 == :infer))
mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], :DualForm0)
mod_res = safe_modifytype!(d, d[op2_id, :res], :DualForm0)
return mod_proj1 || mod_res
end

return false
Expand Down
6 changes: 5 additions & 1 deletion src/deca/deca_acset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,11 @@ op2_inf_rules_1D = [
(proj1_type = :Constant, proj2_type = :DualForm0, res_type = :DualForm0, op_names = [:/, :./, :*, :.*, :^, :.^]),
(proj1_type = :Constant, proj2_type = :DualForm1, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^]),
(proj1_type = :DualForm0, proj2_type = :Constant, res_type = :DualForm0, op_names = [:/, :./, :*, :.*, :^, :.^]),
(proj1_type = :DualForm1, proj2_type = :Constant, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^])]
(proj1_type = :DualForm1, proj2_type = :Constant, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^]),

# These rules contain infer:
(proj1_type = :Form0, proj2_type = :infer, res_type = :Form0, op_names = [:^]),
(proj1_type = :DualForm0, proj2_type = :infer, res_type = :DualForm0, op_names = [:^])]

"""
These are the default rules used to do type inference in the 2D exterior calculus.
Expand Down

0 comments on commit 23fbf3f

Please sign in to comment.