Skip to content

Commit

Permalink
First pass at ruleset ambiguity checking
Browse files Browse the repository at this point in the history
Meant to check if the ruleset contains a situation where multiple rules may be applied in the same situation.
  • Loading branch information
GeorgeR227 committed Nov 14, 2024
1 parent ce81442 commit 80bdec5
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/DiagrammaticEquations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ apex, @relation, # Re-exported from Catlab
## acset
SchDecapode, SchNamedDecapode, AbstractDecapode, AbstractNamedDecapode, NamedDecapode, SummationDecapode,
contract_operators!, contract_operators, add_constant!, add_parameter, fill_names!, dot_rename!, is_expanded, expand_operators, infer_state_names, infer_terminal_names, recognize_types,
resolve_overloads!, replace_names!, type_check,
resolve_overloads!, replace_names!, type_check, check_rule_ambiguity,
transfer_parents!, transfer_children!,
unique_lits!,
## language
Expand Down
38 changes: 38 additions & 0 deletions src/acset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,44 @@ function bin_broad_arith_ops(op_name)
all_ops
end

# TODO: This could probably be implemented using a better version of `check_operator`
# TODO: Add printing of rules which are ambigious with each other
function check_rule_ambiguity(type_rules::AbstractVector{Operator{Symbol}})
ntype_rules = length(type_rules)
for idx1 in 1:ntype_rules
for idx2 in idx1+1:ntype_rules

rule1 = type_rules[idx1]
rule2 = type_rules[idx2]

if rule1.op_name == rule2.op_name || !isempty(intersect(rule1.aliases, rule2.aliases))
types1 = vcat(rule1.res_type, rule1.src_types)
types2 = vcat(rule2.res_type, rule2.src_types)

if length(types1) != length(types2)
continue
end

score = mapreduce(+, types1, types2; init = 0) do type1, type2
if type1 == type2
return 0
elseif type1 in NONINFERABLE_TYPES || type2 in NONINFERABLE_TYPES
return Inf
else
return 1
end
end

if score == 1 # Criteria for inferring a type
return false
end
end

end
end
return true
end

function infer_sum_types!(d::SummationDecapode, Σ_idx::Int)
# Note that we are not doing any type checking here for users!
# i.e. We are not checking the underlying types of Constant or Parameter
Expand Down
2 changes: 1 addition & 1 deletion src/deca/Deca.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import ..arithmetic_operators, ..same_type_rules_op

export normalize_unicode, varname, infer_types!, resolve_overloads!, type_check, infer_resolve!,
typename, spacename, recursive_delete_parents, recursive_delete_parents!, unicode!, vec_to_dec!,
op1_operators, op1_1D_bound_operators, op1_2D_bound_operators, op2_operators
op1_operators, op1_1D_bound_operators, op1_2D_bound_operators, op2_operators, default_operators

include("deca_acset.jl")
include("deca_visualization.jl")
Expand Down
38 changes: 38 additions & 0 deletions test/language.jl
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,44 @@ function test_nametype_equality(d::SummationDecapode, names_types_expected)
@test issetequal(get_name_type_pair(d), names_types_expected)
end

@testset "Ruleset ambiguity" begin
amb_forward_rules = [Operator(:Form0, [:Form0], :test), Operator(:Form1, [:Form0], :test)]
@test !check_rule_ambiguity(amb_forward_rules)

amb_back_rules = [Operator(:Form1, [:Form0], :test), Operator(:Form1, [:Form1], :test)]
@test !check_rule_ambiguity(amb_back_rules)

amb_large_rules = [Operator(:Form0, [:Form0, :Form1, :Form2], :test), Operator(:Form1, [:Form0, :Form1, :Form2], :test)]
@test !check_rule_ambiguity(amb_large_rules)

amb_large_back_rules = [Operator(:Form0, [:Form0, :Form1, :Form1], :test), Operator(:Form0, [:Form0, :Form1, :Form2], :test)]
@test !check_rule_ambiguity(amb_large_back_rules)

usertype_amb_rules = [Operator(:Form0, [:Constant], :test), Operator(:Form1, [:Constant], :test)]
@test !check_rule_ambiguity(usertype_amb_rules)

usertype_good_rules = [Operator(:Form0, [:Constant], :test), Operator(:Form0, [:Parameter], :test)]
@test check_rule_ambiguity(usertype_good_rules)

usertype_large_rules = [Operator(:Form0, [:Form0, :Constant], :test), Operator(:Form0, [:Form0, :Literal], :test)]
@test check_rule_ambiguity(usertype_large_rules)

different_rules = [Operator(:Form0, [:Form0], :test1), Operator(:Form1, [:Form0], :test2)]
@test check_rule_ambiguity(different_rules)

aliases_amb = [Operator(:Form0, [:Form0], :test1, [:test]), Operator(:Form1, [:Form0], :test2, [:test])]
@test !check_rule_ambiguity(aliases_amb)

diff_size_rules = [Operator(:Form0, [:Form0], :test), Operator(:Form1, [:Form0, :Form0], :test)]
@test check_rule_ambiguity(diff_size_rules)

same_rules = [Operator(:Form1, [:Form0], :test), Operator(:Form1, [:Form0], :test)]
@test check_rule_ambiguity(same_rules)

@test check_rule_ambiguity(default_operators(1))
@test check_rule_ambiguity(default_operators(2))
end

@testset "Type Inference" begin

# The type of the tgt of ∂ₜ is inferred.
Expand Down

0 comments on commit 80bdec5

Please sign in to comment.