generated from AlgebraicJulia/AlgebraicTemplate.jl
-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introducing Operator types #84
Open
GeorgeR227
wants to merge
19
commits into
main
Choose a base branch
from
gr/impv_infer
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+584
−492
Open
Changes from 15 commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
93494d8
Made inference rules generic
GeorgeR227 5bc8f8c
Added support for overloading using Operator
GeorgeR227 868f520
Cleaned up inf/res rules into single set
GeorgeR227 87cfccf
Added Symbol Operator constructor
GeorgeR227 8f8188d
Removed HeatXfer and added other tests
GeorgeR227 4fbe011
Basic type checking added
GeorgeR227 78823f8
Much improved type checking
GeorgeR227 17dba52
Don't throw error, return false
GeorgeR227 c389f1d
Some cleanup
GeorgeR227 c8c6d92
Complete type checking
GeorgeR227 5fc4c85
Moved exports
GeorgeR227 5e9a12a
Use explicit minimum idiom in op type checking
lukem12345 92a40c9
Fixed test
GeorgeR227 b4ba787
Fix typo
GeorgeR227 94f99cd
Changed check_operator
GeorgeR227 d81ea0b
Fix docstrings in acset.jl
GeorgeR227 ce81442
Change not equals
GeorgeR227 80bdec5
First pass at ruleset ambiguity checking
GeorgeR227 1f9086d
Dual types work for PartialT
GeorgeR227 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,8 @@ using ACSets.InterTypes | |
|
||
@intertypes "decapodeacset.it" module decapodeacset end | ||
|
||
import Base.show | ||
|
||
using .decapodeacset | ||
|
||
# Transferring pointers | ||
|
@@ -362,7 +364,7 @@ function find_chains(d::SummationDecapode; | |
|
||
filter!(x -> passes_white_list(d[x, :op1]), chain_starts) | ||
filter!(x -> passes_black_list(d[x, :op1]), chain_starts) | ||
|
||
s = Stack{Int64}() | ||
foreach(x -> push!(s, x), chain_starts) | ||
while !isempty(s) | ||
|
@@ -440,6 +442,50 @@ function filterfor_ec_types(types::AbstractVector{Symbol}) | |
filter(conditions, types) | ||
end | ||
|
||
struct Operator{T} | ||
res_type::T | ||
src_types::AbstractVector{T} | ||
op_name::Symbol | ||
aliases::AbstractVector{Symbol} | ||
|
||
function Operator{T}(res_type::T, src_types::AbstractVector{T}, op_name, aliases = Symbol[]) where T | ||
new(res_type, src_types, op_name, aliases) | ||
end | ||
|
||
function Operator{T}(res_type::T, src_type::T, op_name, aliases = Symbol[]) where T | ||
new(res_type, T[src_type], op_name, aliases) | ||
end | ||
|
||
function Operator(res_type::Symbol, src_type::Union{Symbol, AbstractVector{Symbol}}, op_name, aliases = Symbol[]) | ||
Operator{Symbol}(res_type, src_type, op_name, aliases) | ||
end | ||
end | ||
|
||
function same_type_rules_op(op_name::Symbol, types::AbstractVector{Symbol}, arity::Int, g_aliases::AbstractVector{Symbol} = Symbol[], sp_aliases::AbstractVector = Symbol[]) | ||
@assert isempty(sp_aliases) || length(types) == length(sp_aliases) | ||
map(1:length(types)) do i | ||
aliases = isempty(sp_aliases) ? g_aliases : vcat(g_aliases, sp_aliases[i]) | ||
Operator{Symbol}(types[i], repeat([types[i]], arity), op_name, aliases) | ||
end | ||
end | ||
|
||
function arithmetic_operators(op_name::Symbol, broadcasted::Bool, arity::Int = 2) | ||
@match (broadcasted, arity) begin | ||
(true, 2) => bin_broad_arith_ops(op_name) | ||
_ => error("This type of arithmetic operator is not yet supported or may not be valid.") | ||
end | ||
end | ||
|
||
function bin_broad_arith_ops(op_name) | ||
all_ops = map(t -> Operator{Symbol}(t, [t, t], op_name), FORM_TYPES) | ||
for type in vcat(USER_TYPES, NUMBER_TYPES) | ||
append!(all_ops, map(t -> Operator{Symbol}(t, [t, type], op_name), FORM_TYPES)) | ||
append!(all_ops, map(t -> Operator{Symbol}(t, [type, t], op_name), FORM_TYPES)) | ||
end | ||
|
||
all_ops | ||
end | ||
|
||
function infer_sum_types!(d::SummationDecapode, Σ_idx::Int) | ||
# Note that we are not doing any type checking here for users! | ||
# i.e. We are not checking the underlying types of Constant or Parameter | ||
|
@@ -466,36 +512,105 @@ function infer_sum_types!(d::SummationDecapode, Σ_idx::Int) | |
return applied | ||
end | ||
|
||
function apply_inference_rule_op1!(d::SummationDecapode, op1_id, rule) | ||
score_src = (rule.src_type == d[d[op1_id, :src], :type]) | ||
score_tgt = (rule.tgt_type == d[d[op1_id, :tgt], :type]) | ||
""" | ||
check_operator(d::SummationDecapode, op_id, rule, edge_val; check_name::Bool = false, check_aliases::Bool = false, ignore_infers::Bool = false, ignore_usertypes::Bool = false) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. docstring formatting |
||
|
||
check_op = (d[op1_id, :op1] in rule.op_names) | ||
if(check_op && (score_src + score_tgt == 1)) | ||
mod_src = safe_modifytype!(d, d[op1_id, :src], rule.src_type) | ||
mod_tgt = safe_modifytype!(d, d[op1_id, :tgt], rule.tgt_type) | ||
return mod_src || mod_tgt | ||
Cross references a given operator's name and its input/ouput types with a given rule. It | ||
reutrns the number of differences in the types. If the rule does not apply to this operator, | ||
which is checked by naming matching, the type difference is Inf. | ||
""" | ||
function check_operator(d::SummationDecapode, op_id, rule, edge_val; check_name::Bool = false, check_aliases::Bool = false, ignore_infers::Bool = false, ignore_usertypes::Bool = false) | ||
inputs = edge_inputs(d, op_id, edge_val) | ||
output = edge_output(d, op_id, edge_val) | ||
|
||
max_score = length(inputs) + length(output) | ||
|
||
rule_types = vcat(rule.src_types, rule.res_type) | ||
deca_types = vcat(d[inputs, :type], d[output, :type]) | ||
|
||
score = mapreduce(+, rule_types, deca_types; init = 0) do rule_t, deca_t | ||
if ignore_infers && deca_t in INFER_TYPES | ||
return 1 | ||
elseif ignore_usertypes && deca_t in USER_TYPES | ||
return 1 | ||
else | ||
return rule_t == deca_t | ||
end | ||
end | ||
|
||
dop_name = edge_function(d, op_id, edge_val) | ||
|
||
named = check_name && dop_name == rule.op_name | ||
aliased = check_aliases && dop_name in rule.aliases | ||
|
||
return (named || aliased) ? max_score - score : Inf | ||
end | ||
|
||
function apply_inference_rule!(d::SummationDecapode, op_id, rule, edge_val) | ||
|
||
type_diff = check_operator(d, op_id, rule, edge_val; check_name = true, check_aliases = true) | ||
|
||
if type_diff == 1 | ||
vars = vcat(edge_inputs(d, op_id, edge_val), edge_output(d, op_id, edge_val)) | ||
types = vcat(rule.src_types, rule.res_type) | ||
return any(map(vars, types) do var, type | ||
safe_modifytype!(d, var, type) | ||
end) | ||
end | ||
|
||
return false | ||
end | ||
|
||
function apply_inference_rule_op2!(d::SummationDecapode, op2_id, rule) | ||
score_proj1 = (rule.proj1_type == d[d[op2_id, :proj1], :type]) | ||
score_proj2 = (rule.proj2_type == d[d[op2_id, :proj2], :type]) | ||
score_res = (rule.res_type == d[d[op2_id, :res], :type]) | ||
function apply_overloading_rule!(d::SummationDecapode, op_id, rule, edge_val) | ||
|
||
type_diff = check_operator(d, op_id, rule, edge_val; check_aliases = true) | ||
|
||
check_op = (d[op2_id, :op2] in rule.op_names) | ||
if check_op && (score_proj1 + score_proj2 + score_res == 2) | ||
mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], rule.proj1_type) | ||
mod_proj2 = safe_modifytype!(d, d[op2_id, :proj2], rule.proj2_type) | ||
mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type) | ||
return mod_proj1 || mod_proj2 || mod_res | ||
if type_diff == 0 | ||
set_edge_label!(d, op_id, rule.op_name, edge_val) | ||
return true | ||
end | ||
|
||
return false | ||
end | ||
|
||
struct DecaTypeError{T} | ||
rule::Operator{T} | ||
idx::Int | ||
table::Symbol | ||
end | ||
|
||
Base.show(io::IO, type_error::DecaTypeError{T}) where T = println("Operator at index $(type_error.idx) in table $(type_error.table) is not correctly typed. Perhaps the operator was meant to be $(type_error.rule)?") | ||
|
||
struct DecaTypeExeception{T} <: Exception | ||
type_errors::Vector{DecaTypeError{T}} | ||
end | ||
|
||
function Base.show(io::IO, type_except::DecaTypeExeception{T}) where T | ||
map(x -> Base.show(io, x), type_except.type_errors) | ||
end | ||
|
||
function run_typechecking(d::SummationDecapode, type_rules::AbstractVector{Operator{Symbol}}) | ||
|
||
type_errors = DecaTypeError{Symbol}[] | ||
|
||
for table in [:Op1, :Op2] | ||
for op_idx in parts(d, table) | ||
type_error = run_typechecking_for_op(d, op_idx, type_rules, Val(table)) | ||
if type_error !== nothing | ||
push!(type_errors, type_error) | ||
end | ||
end | ||
end | ||
|
||
return type_errors | ||
end | ||
|
||
function run_typechecking_for_op(d::SummationDecapode, op_id, type_rules, edge_val::Val{table}) where table | ||
min_diff, min_rule_idx = findmin(type_rules) do rule | ||
check_operator(d, op_id, rule, edge_val; check_name = true, check_aliases = true, ignore_infers = true, ignore_usertypes = true) | ||
end | ||
min_diff in [0,Inf] ? nothing : DecaTypeError{Symbol}(type_rules[min_rule_idx], op_id, table) | ||
end | ||
|
||
# TODO: Although the big-O complexity is the same, it might be more efficent on | ||
# average to iterate over edges then rules, instead of rules then edges. This | ||
|
@@ -506,7 +621,7 @@ end | |
|
||
Infer types of Vars given rules wherein one type is known and the other not. | ||
""" | ||
function infer_types!(d::SummationDecapode, op1_rules::Vector{NamedTuple{(:src_type, :tgt_type, :op_names), Tuple{Symbol, Symbol, Vector{Symbol}}}}, op2_rules::Vector{NamedTuple{(:proj1_type, :proj2_type, :res_type, :op_names), Tuple{Symbol, Symbol, Symbol, Vector{Symbol}}}}) | ||
function infer_types!(d::SummationDecapode, type_rules::AbstractVector{Operator{Symbol}}) | ||
|
||
# This is an optimization so we do not "visit" a row which has no infer types. | ||
# It could be deleted if found to be not worth maintainability tradeoff. | ||
|
@@ -519,28 +634,23 @@ function infer_types!(d::SummationDecapode, op1_rules::Vector{NamedTuple{(:src_t | |
types_known_op2[incident(d, :infer, [:proj2, :type])] .= false | ||
types_known_op2[incident(d, :infer, [:res, :type])] .= false | ||
|
||
types_known = Dict{Symbol, Vector{Bool}}(:Op1 => types_known_op1, :Op2 => types_known_op2) | ||
|
||
while true | ||
applied = false | ||
|
||
for rule in op1_rules | ||
for op1_idx in parts(d, :Op1) | ||
types_known_op1[op1_idx] && continue | ||
|
||
this_applied = apply_inference_rule_op1!(d, op1_idx, rule) | ||
|
||
types_known_op1[op1_idx] = this_applied | ||
applied |= this_applied | ||
end | ||
end | ||
for table in [:Op1, :Op2] | ||
for op_idx in parts(d, table) | ||
types_known[table][op_idx] && continue | ||
|
||
for rule in op2_rules | ||
for op2_idx in parts(d, :Op2) | ||
types_known_op2[op2_idx] && continue | ||
for rule in type_rules | ||
this_applied = apply_inference_rule!(d, op_idx, rule, Val(table)) | ||
|
||
this_applied = apply_inference_rule_op2!(d, op2_idx, rule) | ||
types_known[table][op_idx] = this_applied | ||
applied |= this_applied | ||
end | ||
|
||
types_known_op2[op2_idx] = this_applied | ||
applied |= this_applied | ||
end | ||
end | ||
|
||
|
@@ -554,38 +664,53 @@ function infer_types!(d::SummationDecapode, op1_rules::Vector{NamedTuple{(:src_t | |
d | ||
end | ||
|
||
|
||
|
||
""" function resolve_overloads!(d::SummationDecapode, op1_rules::Vector{NamedTuple{(:src_type, :tgt_type, :resolved_name, :op), NTuple{4, Symbol}}}) | ||
|
||
Resolve function overloads based on types of src and tgt. | ||
""" | ||
function resolve_overloads!(d::SummationDecapode, op1_rules::Vector{NamedTuple{(:src_type, :tgt_type, :resolved_name, :op), NTuple{4, Symbol}}}, op2_rules::Vector{NamedTuple{(:proj1_type, :proj2_type, :res_type, :resolved_name, :op), NTuple{5, Symbol}}}) | ||
for op1_idx in parts(d, :Op1) | ||
src = d[:src][op1_idx]; tgt = d[:tgt][op1_idx]; op1 = d[:op1][op1_idx] | ||
src_type = d[:type][src]; tgt_type = d[:type][tgt] | ||
for rule in op1_rules | ||
if op1 == rule[:op] && src_type == rule[:src_type] && tgt_type == rule[:tgt_type] | ||
d[op1_idx, :op1] = rule[:resolved_name] | ||
break | ||
end | ||
end | ||
end | ||
|
||
for op2_idx in parts(d, :Op2) | ||
proj1 = d[:proj1][op2_idx]; proj2 = d[:proj2][op2_idx]; res = d[:res][op2_idx]; op2 = d[:op2][op2_idx] | ||
proj1_type = d[:type][proj1]; proj2_type = d[:type][proj2]; res_type = d[:type][res] | ||
for rule in op2_rules | ||
if op2 == rule[:op] && proj1_type == rule[:proj1_type] && proj2_type == rule[:proj2_type] && res_type == rule[:res_type] | ||
d[op2_idx, :op2] = rule[:resolved_name] | ||
break | ||
function resolve_overloads!(d::SummationDecapode, resolve_rules::AbstractVector{Operator{Symbol}}) | ||
for rule in resolve_rules | ||
for table in [:Op1, :Op2] | ||
for op_idx in parts(d, table) | ||
apply_overloading_rule!(d, op_idx, rule, Val(table)) | ||
end | ||
end | ||
end | ||
|
||
d | ||
end | ||
|
||
""" | ||
type_check(d::SummationDecapode, type_rules::AbstractVector{Operator{Symbol}}) | ||
|
||
Takes a Decapode and a set of rules and checks to see if the operators that are in the Decapode | ||
contain a valid configuration of input/output types. If an operator in the Decapode does not | ||
contain a rule in the rule set it will be seen as valid. | ||
|
||
In the case of a type error a DecaTypeExeception is thrown. Otherwise true is returned. | ||
""" | ||
function type_check(d::SummationDecapode, type_rules::AbstractVector{Operator{Symbol}}) | ||
GeorgeR227 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
type_errors = run_typechecking(d, type_rules) | ||
|
||
isempty(type_errors) && return true | ||
|
||
throw(DecaTypeExeception{Symbol}(type_errors)) | ||
return false | ||
end | ||
|
||
|
||
""" | ||
infer_resolve!(d::SummationDecapode, operators::AbstractVector{Operator{Symbol}}) | ||
|
||
Runs type inference, overload resolution and type checking in that order. | ||
""" | ||
function infer_resolve!(d::SummationDecapode, operators::AbstractVector{Operator{Symbol}}) | ||
infer_types!(d, operators) | ||
resolve_overloads!(d, operators) | ||
type_check(d, operators) | ||
|
||
d | ||
end | ||
|
||
function replace_names!(d::SummationDecapode, op1_repls::Vector{Pair{Symbol, Any}}, op2_repls::Vector{Pair{Symbol, Symbol}}) | ||
for (orig,repl) in op1_repls | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks like it is trying to emulate the style of having an abstract type for an "Operator", can this idiom be used explicitly.