Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introducing Operator types #84

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/DiagrammaticEquations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ using Catlab
export
DerivOp, append_dot, normalize_unicode, infer_states, infer_types!,
# Deca
op1_res_rules_1D, op2_res_rules_1D, op1_res_rules_2D, op2_res_rules_2D,
op1_inf_rules_1D, op2_inf_rules_1D, op1_inf_rules_2D, op2_inf_rules_2D,
op1_operators, op1_1D_bound_operators, op1_2D_bound_operators, op2_operators,
recursive_delete_parents, spacename, varname, unicode!, vec_to_dec!,
## collages
Collage, collate,
Expand All @@ -18,8 +17,7 @@ apex, @relation, # Re-exported from Catlab
## acset
SchDecapode, SchNamedDecapode, AbstractDecapode, AbstractNamedDecapode, NamedDecapode, SummationDecapode,
contract_operators!, contract_operators, add_constant!, add_parameter, fill_names!, dot_rename!, is_expanded, expand_operators, infer_state_names, infer_terminal_names, recognize_types,
resolve_overloads!, replace_names!,
apply_inference_rule_op1!, apply_inference_rule_op2!,
resolve_overloads!, replace_names!, type_check,
transfer_parents!, transfer_children!,
unique_lits!,
## language
Expand All @@ -32,7 +30,9 @@ to_graphviz, # Re-exported from Catlab
## rewrite
average_rewrite,
## openoperators
transfer_parents!, transfer_children!, replace_op1!, replace_op2!, replace_all_op1s!, replace_all_op2s!
transfer_parents!, transfer_children!, replace_op1!, replace_op2!, replace_all_op1s!, replace_all_op2s!,
Operator, infer_resolve!, type_check, DecaTypeExeception


using Catlab.Theories
import Catlab.Theories: otimes, oplus, compose, ⊗, ⊕, ⋅, associate, associate_unit, Ob, Hom, dom, codom
Expand Down
237 changes: 181 additions & 56 deletions src/acset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ using ACSets.InterTypes

@intertypes "decapodeacset.it" module decapodeacset end

import Base.show

using .decapodeacset

# Transferring pointers
Expand Down Expand Up @@ -362,7 +364,7 @@ function find_chains(d::SummationDecapode;

filter!(x -> passes_white_list(d[x, :op1]), chain_starts)
filter!(x -> passes_black_list(d[x, :op1]), chain_starts)

s = Stack{Int64}()
foreach(x -> push!(s, x), chain_starts)
while !isempty(s)
Expand Down Expand Up @@ -440,6 +442,50 @@ function filterfor_ec_types(types::AbstractVector{Symbol})
filter(conditions, types)
end

struct Operator{T}
res_type::T
src_types::AbstractVector{T}
op_name::Symbol
aliases::AbstractVector{Symbol}

function Operator{T}(res_type::T, src_types::AbstractVector{T}, op_name, aliases = Symbol[]) where T
new(res_type, src_types, op_name, aliases)
end

function Operator{T}(res_type::T, src_type::T, op_name, aliases = Symbol[]) where T
new(res_type, T[src_type], op_name, aliases)
end

function Operator(res_type::Symbol, src_type::Union{Symbol, AbstractVector{Symbol}}, op_name, aliases = Symbol[])
Operator{Symbol}(res_type, src_type, op_name, aliases)
end
end

function same_type_rules_op(op_name::Symbol, types::AbstractVector{Symbol}, arity::Int, g_aliases::AbstractVector{Symbol} = Symbol[], sp_aliases::AbstractVector = Symbol[])
@assert isempty(sp_aliases) || length(types) == length(sp_aliases)
map(1:length(types)) do i
aliases = isempty(sp_aliases) ? g_aliases : vcat(g_aliases, sp_aliases[i])
Operator{Symbol}(types[i], repeat([types[i]], arity), op_name, aliases)
end
end

function arithmetic_operators(op_name::Symbol, broadcasted::Bool, arity::Int = 2)
@match (broadcasted, arity) begin
(true, 2) => bin_broad_arith_ops(op_name)
_ => error("This type of arithmetic operator is not yet supported or may not be valid.")
end
end

function bin_broad_arith_ops(op_name)
all_ops = map(t -> Operator{Symbol}(t, [t, t], op_name), FORM_TYPES)
for type in vcat(USER_TYPES, NUMBER_TYPES)
append!(all_ops, map(t -> Operator{Symbol}(t, [t, type], op_name), FORM_TYPES))
append!(all_ops, map(t -> Operator{Symbol}(t, [type, t], op_name), FORM_TYPES))
end

all_ops
end
Comment on lines +442 to +484
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like it is trying to emulate the style of having an abstract type for an "Operator", can this idiom be used explicitly.


function infer_sum_types!(d::SummationDecapode, Σ_idx::Int)
# Note that we are not doing any type checking here for users!
# i.e. We are not checking the underlying types of Constant or Parameter
Expand All @@ -466,36 +512,105 @@ function infer_sum_types!(d::SummationDecapode, Σ_idx::Int)
return applied
end

function apply_inference_rule_op1!(d::SummationDecapode, op1_id, rule)
score_src = (rule.src_type == d[d[op1_id, :src], :type])
score_tgt = (rule.tgt_type == d[d[op1_id, :tgt], :type])
"""
check_operator(d::SummationDecapode, op_id, rule, edge_val; check_name::Bool = false, check_aliases::Bool = false, ignore_infers::Bool = false, ignore_usertypes::Bool = false)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docstring formatting


check_op = (d[op1_id, :op1] in rule.op_names)
if(check_op && (score_src + score_tgt == 1))
mod_src = safe_modifytype!(d, d[op1_id, :src], rule.src_type)
mod_tgt = safe_modifytype!(d, d[op1_id, :tgt], rule.tgt_type)
return mod_src || mod_tgt
Cross references a given operator's name and its input/ouput types with a given rule. It
reutrns the number of differences in the types. If the rule does not apply to this operator,
which is checked by naming matching, the type difference is Inf.
"""
function check_operator(d::SummationDecapode, op_id, rule, edge_val; check_name::Bool = false, check_aliases::Bool = false, ignore_infers::Bool = false, ignore_usertypes::Bool = false)
inputs = edge_inputs(d, op_id, edge_val)
output = edge_output(d, op_id, edge_val)

max_score = length(inputs) + length(output)

rule_types = vcat(rule.src_types, rule.res_type)
deca_types = vcat(d[inputs, :type], d[output, :type])

score = mapreduce(+, rule_types, deca_types; init = 0) do rule_t, deca_t
if ignore_infers && deca_t in INFER_TYPES
return 1
elseif ignore_usertypes && deca_t in USER_TYPES
return 1
else
return rule_t == deca_t
end
end

dop_name = edge_function(d, op_id, edge_val)

named = check_name && dop_name == rule.op_name
aliased = check_aliases && dop_name in rule.aliases

return (named || aliased) ? max_score - score : Inf
end

function apply_inference_rule!(d::SummationDecapode, op_id, rule, edge_val)

type_diff = check_operator(d, op_id, rule, edge_val; check_name = true, check_aliases = true)

if type_diff == 1
vars = vcat(edge_inputs(d, op_id, edge_val), edge_output(d, op_id, edge_val))
types = vcat(rule.src_types, rule.res_type)
return any(map(vars, types) do var, type
safe_modifytype!(d, var, type)
end)
end

return false
end

function apply_inference_rule_op2!(d::SummationDecapode, op2_id, rule)
score_proj1 = (rule.proj1_type == d[d[op2_id, :proj1], :type])
score_proj2 = (rule.proj2_type == d[d[op2_id, :proj2], :type])
score_res = (rule.res_type == d[d[op2_id, :res], :type])
function apply_overloading_rule!(d::SummationDecapode, op_id, rule, edge_val)

type_diff = check_operator(d, op_id, rule, edge_val; check_aliases = true)

check_op = (d[op2_id, :op2] in rule.op_names)
if check_op && (score_proj1 + score_proj2 + score_res == 2)
mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], rule.proj1_type)
mod_proj2 = safe_modifytype!(d, d[op2_id, :proj2], rule.proj2_type)
mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type)
return mod_proj1 || mod_proj2 || mod_res
if type_diff == 0
set_edge_label!(d, op_id, rule.op_name, edge_val)
return true
end

return false
end

struct DecaTypeError{T}
rule::Operator{T}
idx::Int
table::Symbol
end

Base.show(io::IO, type_error::DecaTypeError{T}) where T = println("Operator at index $(type_error.idx) in table $(type_error.table) is not correctly typed. Perhaps the operator was meant to be $(type_error.rule)?")

struct DecaTypeExeception{T} <: Exception
type_errors::Vector{DecaTypeError{T}}
end

function Base.show(io::IO, type_except::DecaTypeExeception{T}) where T
map(x -> Base.show(io, x), type_except.type_errors)
end

function run_typechecking(d::SummationDecapode, type_rules::AbstractVector{Operator{Symbol}})

type_errors = DecaTypeError{Symbol}[]

for table in [:Op1, :Op2]
for op_idx in parts(d, table)
type_error = run_typechecking_for_op(d, op_idx, type_rules, Val(table))
if type_error !== nothing
push!(type_errors, type_error)
end
end
end

return type_errors
end

function run_typechecking_for_op(d::SummationDecapode, op_id, type_rules, edge_val::Val{table}) where table
min_diff, min_rule_idx = findmin(type_rules) do rule
check_operator(d, op_id, rule, edge_val; check_name = true, check_aliases = true, ignore_infers = true, ignore_usertypes = true)
end
min_diff in [0,Inf] ? nothing : DecaTypeError{Symbol}(type_rules[min_rule_idx], op_id, table)
end

# TODO: Although the big-O complexity is the same, it might be more efficent on
# average to iterate over edges then rules, instead of rules then edges. This
Expand All @@ -506,7 +621,7 @@ end

Infer types of Vars given rules wherein one type is known and the other not.
"""
function infer_types!(d::SummationDecapode, op1_rules::Vector{NamedTuple{(:src_type, :tgt_type, :op_names), Tuple{Symbol, Symbol, Vector{Symbol}}}}, op2_rules::Vector{NamedTuple{(:proj1_type, :proj2_type, :res_type, :op_names), Tuple{Symbol, Symbol, Symbol, Vector{Symbol}}}})
function infer_types!(d::SummationDecapode, type_rules::AbstractVector{Operator{Symbol}})

# This is an optimization so we do not "visit" a row which has no infer types.
# It could be deleted if found to be not worth maintainability tradeoff.
Expand All @@ -519,28 +634,23 @@ function infer_types!(d::SummationDecapode, op1_rules::Vector{NamedTuple{(:src_t
types_known_op2[incident(d, :infer, [:proj2, :type])] .= false
types_known_op2[incident(d, :infer, [:res, :type])] .= false

types_known = Dict{Symbol, Vector{Bool}}(:Op1 => types_known_op1, :Op2 => types_known_op2)

while true
applied = false

for rule in op1_rules
for op1_idx in parts(d, :Op1)
types_known_op1[op1_idx] && continue

this_applied = apply_inference_rule_op1!(d, op1_idx, rule)

types_known_op1[op1_idx] = this_applied
applied |= this_applied
end
end
for table in [:Op1, :Op2]
for op_idx in parts(d, table)
types_known[table][op_idx] && continue

for rule in op2_rules
for op2_idx in parts(d, :Op2)
types_known_op2[op2_idx] && continue
for rule in type_rules
this_applied = apply_inference_rule!(d, op_idx, rule, Val(table))

this_applied = apply_inference_rule_op2!(d, op2_idx, rule)
types_known[table][op_idx] = this_applied
applied |= this_applied
end

types_known_op2[op2_idx] = this_applied
applied |= this_applied
end
end

Expand All @@ -554,38 +664,53 @@ function infer_types!(d::SummationDecapode, op1_rules::Vector{NamedTuple{(:src_t
d
end



""" function resolve_overloads!(d::SummationDecapode, op1_rules::Vector{NamedTuple{(:src_type, :tgt_type, :resolved_name, :op), NTuple{4, Symbol}}})

Resolve function overloads based on types of src and tgt.
"""
function resolve_overloads!(d::SummationDecapode, op1_rules::Vector{NamedTuple{(:src_type, :tgt_type, :resolved_name, :op), NTuple{4, Symbol}}}, op2_rules::Vector{NamedTuple{(:proj1_type, :proj2_type, :res_type, :resolved_name, :op), NTuple{5, Symbol}}})
for op1_idx in parts(d, :Op1)
src = d[:src][op1_idx]; tgt = d[:tgt][op1_idx]; op1 = d[:op1][op1_idx]
src_type = d[:type][src]; tgt_type = d[:type][tgt]
for rule in op1_rules
if op1 == rule[:op] && src_type == rule[:src_type] && tgt_type == rule[:tgt_type]
d[op1_idx, :op1] = rule[:resolved_name]
break
end
end
end

for op2_idx in parts(d, :Op2)
proj1 = d[:proj1][op2_idx]; proj2 = d[:proj2][op2_idx]; res = d[:res][op2_idx]; op2 = d[:op2][op2_idx]
proj1_type = d[:type][proj1]; proj2_type = d[:type][proj2]; res_type = d[:type][res]
for rule in op2_rules
if op2 == rule[:op] && proj1_type == rule[:proj1_type] && proj2_type == rule[:proj2_type] && res_type == rule[:res_type]
d[op2_idx, :op2] = rule[:resolved_name]
break
function resolve_overloads!(d::SummationDecapode, resolve_rules::AbstractVector{Operator{Symbol}})
for rule in resolve_rules
for table in [:Op1, :Op2]
for op_idx in parts(d, table)
apply_overloading_rule!(d, op_idx, rule, Val(table))
end
end
end

d
end

"""
type_check(d::SummationDecapode, type_rules::AbstractVector{Operator{Symbol}})

Takes a Decapode and a set of rules and checks to see if the operators that are in the Decapode
contain a valid configuration of input/output types. If an operator in the Decapode does not
contain a rule in the rule set it will be seen as valid.

In the case of a type error a DecaTypeExeception is thrown. Otherwise true is returned.
"""
function type_check(d::SummationDecapode, type_rules::AbstractVector{Operator{Symbol}})
GeorgeR227 marked this conversation as resolved.
Show resolved Hide resolved
type_errors = run_typechecking(d, type_rules)

isempty(type_errors) && return true

throw(DecaTypeExeception{Symbol}(type_errors))
return false
end


"""
infer_resolve!(d::SummationDecapode, operators::AbstractVector{Operator{Symbol}})

Runs type inference, overload resolution and type checking in that order.
"""
function infer_resolve!(d::SummationDecapode, operators::AbstractVector{Operator{Symbol}})
infer_types!(d, operators)
resolve_overloads!(d, operators)
type_check(d, operators)

d
end

function replace_names!(d::SummationDecapode, op1_repls::Vector{Pair{Symbol, Any}}, op2_repls::Vector{Pair{Symbol, Symbol}})
for (orig,repl) in op1_repls
Expand Down
7 changes: 5 additions & 2 deletions src/deca/Deca.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@ using Catlab

using Reexport

import ..infer_types!, ..resolve_overloads!
import ..infer_types!, ..resolve_overloads!, ..type_check, ..infer_resolve!
import ..arithmetic_operators, ..same_type_rules_op

export normalize_unicode, varname, infer_types!, resolve_overloads!, typename, spacename, recursive_delete_parents, recursive_delete_parents!, unicode!, op1_res_rules_1D, op2_res_rules_1D, op1_res_rules_2D, op2_res_rules_2D, op1_inf_rules_1D, op2_inf_rules_1D, op1_inf_rules_2D, op2_inf_rules_2D, vec_to_dec!
export normalize_unicode, varname, infer_types!, resolve_overloads!, type_check, infer_resolve!,
typename, spacename, recursive_delete_parents, recursive_delete_parents!, unicode!, vec_to_dec!,
op1_operators, op1_1D_bound_operators, op1_2D_bound_operators, op2_operators

include("deca_acset.jl")
include("deca_visualization.jl")
Expand Down
Loading
Loading