Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Work on summation inference #49

Merged
merged 11 commits into from
Jun 24, 2024
114 changes: 69 additions & 45 deletions src/acset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,32 @@ function make_sum_mult_unique!(d::AbstractNamedDecapode)
end
end

# A collection of DecaType getters
# TODO: This should be replaced by using a type hierarchy
const ALL_TYPES = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2,
:Literal, :Parameter, :Constant, :infer]

const FORM_TYPES = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2]
const PRIMALFORM_TYPES = [:Form0, :Form1, :Form2]
const DUALFORM_TYPES = [:DualForm0, :DualForm1, :DualForm2]

const NONFORM_TYPES = [:Constant, :Parameter, :Literal, :infer]
const USER_TYPES = [:Constant, :Parameter]
const NUMBER_TYPES = [:Literal]
const INFER_TYPES = [:infer]

# Types that can not ever be inferred
const NONINFERABLE_TYPES = [:Constant, :Parameter, :Literal]

function get_unsupportedtypes(types)
lukem12345 marked this conversation as resolved.
Show resolved Hide resolved
setdiff(types, ALL_TYPES)
end

# Note: This hard-bakes in Form0 through Form2, and higher Forms are not
# allowed.
function recognize_types(d::AbstractNamedDecapode)
types = d[:type]
unrecognized_types = setdiff(d[:type], [:Form0, :Form1, :Form2, :DualForm0,
:DualForm1, :DualForm2, :Literal, :Parameter,
:Constant, :infer])
unrecognized_types = get_unsupportedtypes(types)
isempty(unrecognized_types) ||
error("Types $unrecognized_types are not recognized. CHECK: $types")
end
Expand Down Expand Up @@ -349,51 +368,53 @@ This function accepts an original type and a new type and determines if the orig
can be safely overwritten by the new type.
"""
function safe_modifytype(org_type::Symbol, new_type::Symbol)
modify = (org_type == :infer && !(new_type == :Literal || new_type == :Constant || new_type == :Parameter))
modify = (org_type in INFER_TYPES && !(new_type in NONINFERABLE_TYPES))
return (modify, modify ? new_type : org_type)
end

"""
safe_modifytype!(d::SummationDecapode, var_idx::Int, org_type::Symbol, new_type::Symbol)
safe_modifytype!(d::SummationDecapode, var_idx::Int, new_type::Symbol)

This function calls `safe_modifytype` to safely modify a Decapode's variable type.
"""
function safe_modifytype!(d::SummationDecapode, var_idx::Int, org_type::Symbol, new_type::Symbol)
modify, d[var_idx, :type] = safe_modifytype(org_type, new_type)
function safe_modifytype!(d::SummationDecapode, var_idx::Int, new_type::Symbol)
modify, d[var_idx, :type] = safe_modifytype(d[var_idx, :type], new_type)
return modify
end

# ! Warning: This is changing types to :Constant when they weren't originally.
# ! This should be refactored to only change types into Forms
function infer_summands_and_summations!(d::SummationDecapode)
# Note that we are not doing any type checking here!
# i.e. We are not checking for this: [Form0, Form1, Form0].
"""
filterfor_forms(types::AbstractVector{Symbol})

Return any form type symbols.
"""
function filterfor_forms(types::AbstractVector{Symbol})
conditions = x -> !(x in NONFORM_TYPES)
GeorgeR227 marked this conversation as resolved.
Show resolved Hide resolved
filter(conditions, types)
end

function infer_sum_types!(d::SummationDecapode, Σ_idx::Int)
# Note that we are not doing any type checking here for users!
# i.e. We are not checking the underlying types of Constant or Parameter
applied = false
GeorgeR227 marked this conversation as resolved.
Show resolved Hide resolved
for Σ_idx in parts(d, :Σ)
summands = d[:summand][incident(d, Σ_idx, :summation)]
sum = d[:sum][Σ_idx]
idxs = [summands; sum]
types = d[:type][idxs]
all(t != :infer for t in types) && continue # We need not infer
all(t == :infer for t in types) && continue # We can not infer

known_types = types[findall(!=(:infer), types)]
if :Literal ∈ known_types
# If anything is a Literal, then anything not inferred is a Constant.
inferred_type = :Constant
elseif !isnothing(findfirst(!=(:Constant), known_types))
# If anything is a Form, then any term in this sum is the same kind of Form.
# Note that we are not explicitly changing Constants to Forms here,
# although we should consider doing so.
inferred_type = known_types[findfirst(!=(:Constant), known_types)]
else
# All terms are now a mix of Constant or infer. Set them all to Constant.
inferred_type = :Constant
end
to_infer_idxs = filter(i -> d[:type][i] == :infer, idxs)
d[to_infer_idxs, :type] = inferred_type
applied = true

summands = d[incident(d, Σ_idx, :summation), :summand]
sum = d[Σ_idx, :sum]
idxs = [summands; sum]
types = d[idxs, :type]
all(t != :infer for t in types) && return applied # We need not infer

forms = unique(filterfor_forms(types))

form = @match length(forms) begin
0 => return applied # We can not infer
1 => only(forms)
_ => error("Type mismatch in summation $Σ_idx, all the following forms appear: $forms")
end

for idx in idxs
GeorgeR227 marked this conversation as resolved.
Show resolved Hide resolved
applied |= safe_modifytype!(d, idx, form)
end

return applied
end

Expand All @@ -406,8 +427,8 @@ function apply_inference_rule_op1!(d::SummationDecapode, op1_id, rule)
check_op = (d[op1_id, :op1] in rule.op_names)

if(check_op && (score_src + score_tgt == 1))
mod_src = safe_modifytype!(d, d[op1_id, :src], type_src, rule.src_type)
mod_tgt = safe_modifytype!(d, d[op1_id, :tgt], type_tgt, rule.tgt_type)
mod_src = safe_modifytype!(d, d[op1_id, :src], rule.src_type)
mod_tgt = safe_modifytype!(d, d[op1_id, :tgt], rule.tgt_type)
return mod_src || mod_tgt
end

Expand All @@ -425,9 +446,9 @@ function apply_inference_rule_op2!(d::SummationDecapode, op2_id, rule)
check_op = (d[op2_id, :op2] in rule.op_names)

if(check_op && (score_proj1 + score_proj2 + score_res == 2))
mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], type_proj1, rule.proj1_type)
mod_proj2 = safe_modifytype!(d, d[op2_id, :proj2], type_proj2, rule.proj2_type)
mod_res = safe_modifytype!(d, d[op2_id, :res], type_res, rule.res_type)
mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], rule.proj1_type)
mod_proj2 = safe_modifytype!(d, d[op2_id, :proj2], rule.proj2_type)
mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type)
return mod_proj1 || mod_proj2 || mod_res
end

Expand All @@ -452,7 +473,7 @@ function infer_types!(d::SummationDecapode, op1_rules::Vector{NamedTuple{(:src_t
types_known_op1[incident(d, :infer, [:src, :type])] .= false
types_known_op1[incident(d, :infer, [:tgt, :type])] .= false

types_known_op2 = zeros(Bool, nparts(d, :Op2))
types_known_op2 = ones(Bool, nparts(d, :Op2))
types_known_op2[incident(d, :infer, [:proj1, :type])] .= false
types_known_op2[incident(d, :infer, [:proj2, :type])] .= false
types_known_op2[incident(d, :infer, [:res, :type])] .= false
Expand All @@ -467,7 +488,7 @@ function infer_types!(d::SummationDecapode, op1_rules::Vector{NamedTuple{(:src_t
this_applied = apply_inference_rule_op1!(d, op1_idx, rule)

types_known_op1[op1_idx] = this_applied
applied = applied || this_applied
applied |= this_applied
end
end

Expand All @@ -478,11 +499,14 @@ function infer_types!(d::SummationDecapode, op1_rules::Vector{NamedTuple{(:src_t
this_applied = apply_inference_rule_op2!(d, op2_idx, rule)

types_known_op2[op2_idx] = this_applied
applied = applied || this_applied
applied |= this_applied
end
end

applied = applied || infer_summands_and_summations!(d)
for Σ_idx in parts(d, :Σ)
applied |= infer_sum_types!(d, Σ_idx)
end

applied || break # Break if no rules were applied.
end

Expand Down
Loading
Loading