diff --git a/src/DiagrammaticEquations.jl b/src/DiagrammaticEquations.jl index 558a839..7b3b87a 100644 --- a/src/DiagrammaticEquations.jl +++ b/src/DiagrammaticEquations.jl @@ -7,8 +7,7 @@ using Catlab export DerivOp, append_dot, normalize_unicode, infer_states, infer_types!, # Deca -op1_res_rules_1D, op2_res_rules_1D, op1_res_rules_2D, op2_res_rules_2D, -op1_inf_rules_1D, op2_inf_rules_1D, op1_inf_rules_2D, op2_inf_rules_2D, +op1_operators, op1_1D_bound_operators, op1_2D_bound_operators, op2_operators, recursive_delete_parents, spacename, varname, unicode!, vec_to_dec!, ## collages Collage, collate, @@ -18,8 +17,7 @@ apex, @relation, # Re-exported from Catlab ## acset SchDecapode, SchNamedDecapode, AbstractDecapode, AbstractNamedDecapode, NamedDecapode, SummationDecapode, contract_operators!, contract_operators, add_constant!, add_parameter, fill_names!, dot_rename!, is_expanded, expand_operators, infer_state_names, infer_terminal_names, recognize_types, -resolve_overloads!, replace_names!, -apply_inference_rule_op1!, apply_inference_rule_op2!, +resolve_overloads!, replace_names!, type_check, check_rule_ambiguity, transfer_parents!, transfer_children!, unique_lits!, ## language @@ -32,7 +30,9 @@ to_graphviz, # Re-exported from Catlab ## rewrite average_rewrite, ## openoperators -transfer_parents!, transfer_children!, replace_op1!, replace_op2!, replace_all_op1s!, replace_all_op2s! +transfer_parents!, transfer_children!, replace_op1!, replace_op2!, replace_all_op1s!, replace_all_op2s!, +Operator, infer_resolve!, type_check, DecaTypeExeception + using Catlab.Theories import Catlab.Theories: otimes, oplus, compose, ⊗, ⊕, ⋅, associate, associate_unit, Ob, Hom, dom, codom diff --git a/src/acset.jl b/src/acset.jl index 4cfdaac..c793e27 100644 --- a/src/acset.jl +++ b/src/acset.jl @@ -5,6 +5,8 @@ using ACSets.InterTypes @intertypes "decapodeacset.it" module decapodeacset end +import Base.show + using .decapodeacset # Transferring pointers @@ -362,7 +364,7 @@ function find_chains(d::SummationDecapode; filter!(x -> passes_white_list(d[x, :op1]), chain_starts) filter!(x -> passes_black_list(d[x, :op1]), chain_starts) - + s = Stack{Int64}() foreach(x -> push!(s, x), chain_starts) while !isempty(s) @@ -409,8 +411,7 @@ function add_parameter(d::AbstractNamedDecapode, k::Symbol) end -""" - safe_modifytype(org_type::Symbol, new_type::Symbol) +""" safe_modifytype(org_type::Symbol, new_type::Symbol) This function accepts an original type and a new type and determines if the original type can be safely overwritten by the new type. @@ -420,8 +421,7 @@ function safe_modifytype(org_type::Symbol, new_type::Symbol) return (modify, modify ? new_type : org_type) end -""" - safe_modifytype!(d::SummationDecapode, var_idx::Int, new_type::Symbol) +""" safe_modifytype!(d::SummationDecapode, var_idx::Int, new_type::Symbol) This function calls `safe_modifytype` to safely modify a Decapode's variable type. """ @@ -430,8 +430,7 @@ function safe_modifytype!(d::SummationDecapode, var_idx::Int, new_type::Symbol) return modify end -""" - filterfor_ec_types(types::AbstractVector{Symbol}) +""" filterfor_ec_types(types::AbstractVector{Symbol}) Return any form or vector-field type symbols. """ @@ -440,6 +439,88 @@ function filterfor_ec_types(types::AbstractVector{Symbol}) filter(conditions, types) end +struct Operator{T} + res_type::T + src_types::AbstractVector{T} + op_name::Symbol + aliases::AbstractVector{Symbol} + + function Operator{T}(res_type::T, src_types::AbstractVector{T}, op_name, aliases = Symbol[]) where T + new(res_type, src_types, op_name, aliases) + end + + function Operator{T}(res_type::T, src_type::T, op_name, aliases = Symbol[]) where T + new(res_type, T[src_type], op_name, aliases) + end + + function Operator(res_type::Symbol, src_type::Union{Symbol, AbstractVector{Symbol}}, op_name, aliases = Symbol[]) + Operator{Symbol}(res_type, src_type, op_name, aliases) + end +end + +function same_type_rules_op(op_name::Symbol, types::AbstractVector{Symbol}, arity::Int, g_aliases::AbstractVector{Symbol} = Symbol[], sp_aliases::AbstractVector = Symbol[]) + @assert isempty(sp_aliases) || length(types) == length(sp_aliases) + map(1:length(types)) do i + aliases = isempty(sp_aliases) ? g_aliases : vcat(g_aliases, sp_aliases[i]) + Operator{Symbol}(types[i], repeat([types[i]], arity), op_name, aliases) + end +end + +function arithmetic_operators(op_name::Symbol, broadcasted::Bool, arity::Int = 2) + @match (broadcasted, arity) begin + (true, 2) => bin_broad_arith_ops(op_name) + _ => error("This type of arithmetic operator is not yet supported or may not be valid.") + end +end + +function bin_broad_arith_ops(op_name) + all_ops = map(t -> Operator{Symbol}(t, [t, t], op_name), FORM_TYPES) + for type in vcat(USER_TYPES, NUMBER_TYPES) + append!(all_ops, map(t -> Operator{Symbol}(t, [t, type], op_name), FORM_TYPES)) + append!(all_ops, map(t -> Operator{Symbol}(t, [type, t], op_name), FORM_TYPES)) + end + + all_ops +end + +# TODO: This could probably be implemented using a better version of `check_operator` +# TODO: Add printing of rules which are ambigious with each other +function check_rule_ambiguity(type_rules::AbstractVector{Operator{Symbol}}) + ntype_rules = length(type_rules) + for idx1 in 1:ntype_rules + for idx2 in idx1+1:ntype_rules + + rule1 = type_rules[idx1] + rule2 = type_rules[idx2] + + if rule1.op_name == rule2.op_name || !isempty(intersect(rule1.aliases, rule2.aliases)) + types1 = vcat(rule1.res_type, rule1.src_types) + types2 = vcat(rule2.res_type, rule2.src_types) + + if length(types1) != length(types2) + continue + end + + score = mapreduce(+, types1, types2; init = 0) do type1, type2 + if type1 == type2 + return 0 + elseif type1 in NONINFERABLE_TYPES || type2 in NONINFERABLE_TYPES + return Inf + else + return 1 + end + end + + if score == 1 # Criteria for inferring a type + return false + end + end + + end + end + return true +end + function infer_sum_types!(d::SummationDecapode, Σ_idx::Int) # Note that we are not doing any type checking here for users! # i.e. We are not checking the underlying types of Constant or Parameter @@ -466,36 +547,104 @@ function infer_sum_types!(d::SummationDecapode, Σ_idx::Int) return applied end -function apply_inference_rule_op1!(d::SummationDecapode, op1_id, rule) - score_src = (rule.src_type == d[d[op1_id, :src], :type]) - score_tgt = (rule.tgt_type == d[d[op1_id, :tgt], :type]) +""" check_operator(d::SummationDecapode, op_id, rule, edge_val; check_name::Bool = false, check_aliases::Bool = false, ignore_infers::Bool = false, ignore_usertypes::Bool = false) + +Cross references a given operator's name and its input/ouput types with a given rule. It +reutrns the number of differences in the types. If the rule does not apply to this operator, +which is checked by naming matching, the type difference is Inf. +""" +function check_operator(d::SummationDecapode, op_id, rule, edge_val; check_name::Bool = false, check_aliases::Bool = false, ignore_infers::Bool = false, ignore_usertypes::Bool = false) + inputs = edge_inputs(d, op_id, edge_val) + output = edge_output(d, op_id, edge_val) + + max_score = length(inputs) + length(output) + + rule_types = vcat(rule.src_types, rule.res_type) + deca_types = vcat(d[inputs, :type], d[output, :type]) - check_op = (d[op1_id, :op1] in rule.op_names) - if(check_op && (score_src + score_tgt == 1)) - mod_src = safe_modifytype!(d, d[op1_id, :src], rule.src_type) - mod_tgt = safe_modifytype!(d, d[op1_id, :tgt], rule.tgt_type) - return mod_src || mod_tgt + score = mapreduce(+, rule_types, deca_types; init = 0) do rule_t, deca_t + if ignore_infers && deca_t in INFER_TYPES + return 1 + elseif ignore_usertypes && deca_t in USER_TYPES + return 1 + else + return rule_t == deca_t + end + end + + dop_name = edge_function(d, op_id, edge_val) + + named = check_name && dop_name == rule.op_name + aliased = check_aliases && dop_name in rule.aliases + + return (named || aliased) ? max_score - score : Inf +end + +function apply_inference_rule!(d::SummationDecapode, op_id, rule, edge_val) + + type_diff = check_operator(d, op_id, rule, edge_val; check_name = true, check_aliases = true) + + if type_diff == 1 + vars = vcat(edge_inputs(d, op_id, edge_val), edge_output(d, op_id, edge_val)) + types = vcat(rule.src_types, rule.res_type) + return any(map(vars, types) do var, type + safe_modifytype!(d, var, type) + end) end return false end -function apply_inference_rule_op2!(d::SummationDecapode, op2_id, rule) - score_proj1 = (rule.proj1_type == d[d[op2_id, :proj1], :type]) - score_proj2 = (rule.proj2_type == d[d[op2_id, :proj2], :type]) - score_res = (rule.res_type == d[d[op2_id, :res], :type]) +function apply_overloading_rule!(d::SummationDecapode, op_id, rule, edge_val) - check_op = (d[op2_id, :op2] in rule.op_names) - if check_op && (score_proj1 + score_proj2 + score_res == 2) - mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], rule.proj1_type) - mod_proj2 = safe_modifytype!(d, d[op2_id, :proj2], rule.proj2_type) - mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type) - return mod_proj1 || mod_proj2 || mod_res + type_diff = check_operator(d, op_id, rule, edge_val; check_aliases = true) + + if type_diff == 0 + set_edge_label!(d, op_id, rule.op_name, edge_val) + return true end return false end +struct DecaTypeError{T} + rule::Operator{T} + idx::Int + table::Symbol +end + +Base.show(io::IO, type_error::DecaTypeError{T}) where T = println("Operator at index $(type_error.idx) in table $(type_error.table) is not correctly typed. Perhaps the operator was meant to be $(type_error.rule)?") + +struct DecaTypeExeception{T} <: Exception + type_errors::Vector{DecaTypeError{T}} +end + +function Base.show(io::IO, type_except::DecaTypeExeception{T}) where T + map(x -> Base.show(io, x), type_except.type_errors) +end + +function run_typechecking(d::SummationDecapode, type_rules::AbstractVector{Operator{Symbol}}) + + type_errors = DecaTypeError{Symbol}[] + + for table in [:Op1, :Op2] + for op_idx in parts(d, table) + type_error = run_typechecking_for_op(d, op_idx, type_rules, Val(table)) + if type_error !== nothing + push!(type_errors, type_error) + end + end + end + + return type_errors +end + +function run_typechecking_for_op(d::SummationDecapode, op_id, type_rules, edge_val::Val{table}) where table + min_diff, min_rule_idx = findmin(type_rules) do rule + check_operator(d, op_id, rule, edge_val; check_name = true, check_aliases = true, ignore_infers = true, ignore_usertypes = true) + end + min_diff in [0,Inf] ? nothing : DecaTypeError{Symbol}(type_rules[min_rule_idx], op_id, table) +end # TODO: Although the big-O complexity is the same, it might be more efficent on # average to iterate over edges then rules, instead of rules then edges. This @@ -506,7 +655,7 @@ end Infer types of Vars given rules wherein one type is known and the other not. """ -function infer_types!(d::SummationDecapode, op1_rules::Vector{NamedTuple{(:src_type, :tgt_type, :op_names), Tuple{Symbol, Symbol, Vector{Symbol}}}}, op2_rules::Vector{NamedTuple{(:proj1_type, :proj2_type, :res_type, :op_names), Tuple{Symbol, Symbol, Symbol, Vector{Symbol}}}}) +function infer_types!(d::SummationDecapode, type_rules::AbstractVector{Operator{Symbol}}) # This is an optimization so we do not "visit" a row which has no infer types. # It could be deleted if found to be not worth maintainability tradeoff. @@ -519,28 +668,23 @@ function infer_types!(d::SummationDecapode, op1_rules::Vector{NamedTuple{(:src_t types_known_op2[incident(d, :infer, [:proj2, :type])] .= false types_known_op2[incident(d, :infer, [:res, :type])] .= false + types_known = Dict{Symbol, Vector{Bool}}(:Op1 => types_known_op1, :Op2 => types_known_op2) + while true applied = false - for rule in op1_rules - for op1_idx in parts(d, :Op1) - types_known_op1[op1_idx] && continue - - this_applied = apply_inference_rule_op1!(d, op1_idx, rule) - types_known_op1[op1_idx] = this_applied - applied |= this_applied - end - end + for table in [:Op1, :Op2] + for op_idx in parts(d, table) + types_known[table][op_idx] && continue - for rule in op2_rules - for op2_idx in parts(d, :Op2) - types_known_op2[op2_idx] && continue + for rule in type_rules + this_applied = apply_inference_rule!(d, op_idx, rule, Val(table)) - this_applied = apply_inference_rule_op2!(d, op2_idx, rule) + types_known[table][op_idx] = this_applied + applied |= this_applied + end - types_known_op2[op2_idx] = this_applied - applied |= this_applied end end @@ -554,31 +698,15 @@ function infer_types!(d::SummationDecapode, op1_rules::Vector{NamedTuple{(:src_t d end - - """ function resolve_overloads!(d::SummationDecapode, op1_rules::Vector{NamedTuple{(:src_type, :tgt_type, :resolved_name, :op), NTuple{4, Symbol}}}) Resolve function overloads based on types of src and tgt. """ -function resolve_overloads!(d::SummationDecapode, op1_rules::Vector{NamedTuple{(:src_type, :tgt_type, :resolved_name, :op), NTuple{4, Symbol}}}, op2_rules::Vector{NamedTuple{(:proj1_type, :proj2_type, :res_type, :resolved_name, :op), NTuple{5, Symbol}}}) - for op1_idx in parts(d, :Op1) - src = d[:src][op1_idx]; tgt = d[:tgt][op1_idx]; op1 = d[:op1][op1_idx] - src_type = d[:type][src]; tgt_type = d[:type][tgt] - for rule in op1_rules - if op1 == rule[:op] && src_type == rule[:src_type] && tgt_type == rule[:tgt_type] - d[op1_idx, :op1] = rule[:resolved_name] - break - end - end - end - - for op2_idx in parts(d, :Op2) - proj1 = d[:proj1][op2_idx]; proj2 = d[:proj2][op2_idx]; res = d[:res][op2_idx]; op2 = d[:op2][op2_idx] - proj1_type = d[:type][proj1]; proj2_type = d[:type][proj2]; res_type = d[:type][res] - for rule in op2_rules - if op2 == rule[:op] && proj1_type == rule[:proj1_type] && proj2_type == rule[:proj2_type] && res_type == rule[:res_type] - d[op2_idx, :op2] = rule[:resolved_name] - break +function resolve_overloads!(d::SummationDecapode, resolve_rules::AbstractVector{Operator{Symbol}}) + for rule in resolve_rules + for table in [:Op1, :Op2] + for op_idx in parts(d, table) + apply_overloading_rule!(d, op_idx, rule, Val(table)) end end end @@ -586,6 +714,35 @@ function resolve_overloads!(d::SummationDecapode, op1_rules::Vector{NamedTuple{( d end +""" type_check(d::SummationDecapode, type_rules::AbstractVector{Operator{Symbol}}) + +Takes a Decapode and a set of rules and checks to see if the operators that are in the Decapode +contain a valid configuration of input/output types. If an operator in the Decapode does not +contain a rule in the rule set it will be seen as valid. + +In the case of a type error a DecaTypeExeception is thrown. Otherwise true is returned. +""" +function type_check(d::SummationDecapode, type_rules::AbstractVector{Operator{Symbol}}) + type_errors = run_typechecking(d, type_rules) + + isempty(type_errors) && return true + + throw(DecaTypeExeception{Symbol}(type_errors)) + return false +end + + +""" infer_resolve!(d::SummationDecapode, operators::AbstractVector{Operator{Symbol}}) + +Runs type inference, overload resolution and type checking in that order. +""" +function infer_resolve!(d::SummationDecapode, operators::AbstractVector{Operator{Symbol}}) + infer_types!(d, operators) + resolve_overloads!(d, operators) + type_check(d, operators) + + d +end function replace_names!(d::SummationDecapode, op1_repls::Vector{Pair{Symbol, Any}}, op2_repls::Vector{Pair{Symbol, Symbol}}) for (orig,repl) in op1_repls diff --git a/src/deca/Deca.jl b/src/deca/Deca.jl index 3187189..c5d0d9e 100644 --- a/src/deca/Deca.jl +++ b/src/deca/Deca.jl @@ -6,9 +6,12 @@ using Catlab using Reexport -import ..infer_types!, ..resolve_overloads! +import ..infer_types!, ..resolve_overloads!, ..type_check, ..infer_resolve! +import ..arithmetic_operators, ..same_type_rules_op -export normalize_unicode, varname, infer_types!, resolve_overloads!, typename, spacename, recursive_delete_parents, recursive_delete_parents!, unicode!, op1_res_rules_1D, op2_res_rules_1D, op1_res_rules_2D, op2_res_rules_2D, op1_inf_rules_1D, op2_inf_rules_1D, op1_inf_rules_2D, op2_inf_rules_2D, vec_to_dec! +export normalize_unicode, varname, infer_types!, resolve_overloads!, type_check, infer_resolve!, +typename, spacename, recursive_delete_parents, recursive_delete_parents!, unicode!, vec_to_dec!, +op1_operators, op1_1D_bound_operators, op1_2D_bound_operators, op2_operators, default_operators include("deca_acset.jl") include("deca_visualization.jl") diff --git a/src/deca/deca_acset.jl b/src/deca/deca_acset.jl index e0f5580..67bbf1b 100644 --- a/src/deca/deca_acset.jl +++ b/src/deca/deca_acset.jl @@ -2,300 +2,145 @@ using ..DiagrammaticEquations # TODO: You could write a method which auto-generates these rules given degree N. """ -These are the default rules used to do type inference in the 1D exterior calculus. +These are the default rules used to do type inference/function resolution in the 1D/2D exterior calculus. """ -op1_inf_rules_1D = [ +op1_operators = [ # Rules for ∂ₜ - (src_type = :Form0, tgt_type = :Form0, op_names = [:∂ₜ,:dt]), - (src_type = :Form1, tgt_type = :Form1, op_names = [:∂ₜ,:dt]), - - # Rules for d - (src_type = :Form0, tgt_type = :Form1, op_names = [:d, :d₀]), - (src_type = :DualForm0, tgt_type = :DualForm1, op_names = [:d, :dual_d₀, :d̃₀]), - - # Rules for ⋆ - (src_type = :Form0, tgt_type = :DualForm1, op_names = [:★, :⋆, :⋆₀, :star]), - (src_type = :Form1, tgt_type = :DualForm0, op_names = [:★, :⋆, :⋆₁, :star]), - (src_type = :DualForm1, tgt_type = :Form0, op_names = [:★, :⋆, :⋆₀⁻¹, :star_inv]), - (src_type = :DualForm0, tgt_type = :Form1, op_names = [:★, :⋆, :⋆₁⁻¹, :star_inv]), - - # Rules for Δ - (src_type = :Form0, tgt_type = :Form0, op_names = [:Δ, :Δ₀, :lapl]), - (src_type = :Form1, tgt_type = :Form1, op_names = [:Δ, :Δ₁, :lapl]), + Operator(:Form0, :Form0, :∂ₜ, [:dt]), + Operator(:Form1, :Form1, :∂ₜ, [:dt]), + Operator(:Form2, :Form2, :∂ₜ, [:dt]), + Operator(:DualForm0, :DualForm0, :∂ₜ, [:dt]), + Operator(:DualForm1, :DualForm1, :∂ₜ, [:dt]), + Operator(:DualForm2, :DualForm2, :∂ₜ, [:dt]), + + # Rules for d. + Operator(:Form1, :Form0, :d₀, [:d]), + Operator(:Form2, :Form1, :d₁, [:d]), + Operator(:DualForm1, :DualForm0, :dual_d₀, [:d, :d̃₀]), + Operator(:DualForm2, :DualForm1, :dual_d₁, [:d, :d̃₁]), # Rules for δ - (src_type = :Form1, tgt_type = :Form0, op_names = [:δ, :δ₁, :codif]), - - # Rules for negation - (src_type = :Form0, tgt_type = :Form0, op_names = [:neg, :(-)]), - (src_type = :Form1, tgt_type = :Form1, op_names = [:neg, :(-)]), - - # Rules for the averaging operator - (src_type = :Form0, tgt_type = :Form1, op_names = [:avg₀₁, :avg_01]), - - # Rules for ♯. - (src_type = :Form1, tgt_type = :PVF, op_names = [:♯, :♯ᵖᵖ]), - (src_type = :DualForm1, tgt_type = :DVF, op_names = [:♯, :♯ᵈᵈ]), - - # Rules for ♭. - (src_type = :DVF, tgt_type = :Form1, op_names = [:♭, :♭ᵈᵖ]), + Operator(:Form0, :Form1, :δ₁, [:δ, :codif]), + Operator(:Form1, :Form2, :δ₂, [:δ, :codif]), - # Rules for magnitude/ norm - (src_type = :PVF, tgt_type = :Form0, op_names = [:mag, :norm]), - (src_type = :DVF, tgt_type = :DualForm0, op_names = [:mag, :norm])] + # Rules for ♯ + Operator(:PVF, :Form1, :♯ᵖᵖ, [:♯]), + Operator(:DVF, :DualForm1, :♯ᵈᵈ, [:♯]), -op2_inf_rules_1D = [ - # Rules for ∧₀₀, ∧₁₀, ∧₀₁ - (proj1_type = :Form0, proj2_type = :Form0, res_type = :Form0, op_names = [:∧, :∧₀₀, :wedge]), - (proj1_type = :Form1, proj2_type = :Form0, res_type = :Form1, op_names = [:∧, :∧₁₀, :wedge]), - (proj1_type = :Form0, proj2_type = :Form1, res_type = :Form1, op_names = [:∧, :∧₀₁, :wedge]), + # Rules for ♭ + Operator(:Form1, :DVF, :♭ᵈᵖ, [:♭]), - # Rules for L₀, L₁ - (proj1_type = :Form1, proj2_type = :DualForm0, res_type = :DualForm0, op_names = [:L, :L₀]), - (proj1_type = :Form1, proj2_type = :DualForm1, res_type = :DualForm1, op_names = [:L, :L₁]), - - # Rules for i₁ - (proj1_type = :Form1, proj2_type = :DualForm1, res_type = :DualForm0, op_names = [:i, :i₁]), - - # Rules for divison and multiplication - (proj1_type = :Form0, proj2_type = :Form0, res_type = :Form0, op_names = [:/, :./, :*, :.*, :^, :.^]), - (proj1_type = :Form1, proj2_type = :Form1, res_type = :Form1, op_names = [:/, :./, :*, :.*, :^, :.^]), - - # WARNING: This parameter type inference might be wrong, depending on what the user gives as a parameter - #= (proj1_type = :Parameter, proj2_type = :Form0, res_type = :Form0, op_names = [:/, :./, :*, :.*]), - (proj1_type = :Parameter, proj2_type = :Form1, res_type = :Form1, op_names = [:/, :./, :*, :.*]), - (proj1_type = :Parameter, proj2_type = :Form2, res_type = :Form2, op_names = [:/, :./, :*, :.*]), - - (proj1_type = :Form0, proj2_type = :Parameter, res_type = :Form0, op_names = [:/, :./, :*, :.*]), - (proj1_type = :Form1, proj2_type = :Parameter, res_type = :Form1, op_names = [:/, :./, :*, :.*]), - (proj1_type = :Form2, proj2_type = :Parameter, res_type = :Form2, op_names = [:/, :./, :*, :.*]),=# - - (proj1_type = :Form0, proj2_type = :Literal, res_type = :Form0, op_names = [:/, :./, :*, :.*, :^, :.^]), - (proj1_type = :Form1, proj2_type = :Literal, res_type = :Form1, op_names = [:/, :./, :*, :.*, :^, :.^]), - - (proj1_type = :DualForm0, proj2_type = :Literal, res_type = :DualForm0, op_names = [:/, :./, :*, :.*, :^, :.^]), - (proj1_type = :DualForm1, proj2_type = :Literal, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^]), + # Rules for Δ + Operator(:Form0, :Form0, :Δ₀, [:Δ, :∇², :lapl]), + Operator(:Form1, :Form1, :Δ₁, [:Δ, :∇², :lapl]), + Operator(:Form2, :Form2, :Δ₂, [:Δ, :∇², :lapl]), - (proj1_type = :Literal, proj2_type = :Form0, res_type = :Form0, op_names = [:/, :./, :*, :.*, :^, :.^]), - (proj1_type = :Literal, proj2_type = :Form1, res_type = :Form1, op_names = [:/, :./, :*, :.*, :^, :.^]), + # Rules for Δᵈ + Operator(:DualForm0, :DualForm0, :Δᵈ₀, [:Δ, :∇², :lapl]), + Operator(:DualForm1, :DualForm1, :Δᵈ₁, [:Δ, :∇², :lapl]), - (proj1_type = :Literal, proj2_type = :DualForm0, res_type = :DualForm0, op_names = [:/, :./, :*, :.*, :^, :.^]), - (proj1_type = :Literal, proj2_type = :DualForm1, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^]), + Operator(:Form0, :Form0, :(-), [:neg]), + Operator(:Form1, :Form1, :(-), [:neg]), + Operator(:Form2, :Form2, :(-), [:neg]), + Operator(:DualForm0, :DualForm0, :(-), [:neg]), + Operator(:DualForm1, :DualForm1, :(-), [:neg]), + Operator(:DualForm2, :DualForm2, :(-), [:neg]), - (proj1_type = :Constant, proj2_type = :Form0, res_type = :Form0, op_names = [:/, :./, :*, :.*, :^, :.^]), - (proj1_type = :Constant, proj2_type = :Form1, res_type = :Form1, op_names = [:/, :./, :*, :.*, :^, :.^]), - (proj1_type = :Form0, proj2_type = :Constant, res_type = :Form0, op_names = [:/, :./, :*, :.*, :^, :.^]), - (proj1_type = :Form1, proj2_type = :Constant, res_type = :Form1, op_names = [:/, :./, :*, :.*, :^, :.^]), + # Rules for the averaging operator + Operator(:Form1, :Form0, :avg₀₁, [:avg_01]), - (proj1_type = :Constant, proj2_type = :DualForm0, res_type = :DualForm0, op_names = [:/, :./, :*, :.*, :^, :.^]), - (proj1_type = :Constant, proj2_type = :DualForm1, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^]), - (proj1_type = :DualForm0, proj2_type = :Constant, res_type = :DualForm0, op_names = [:/, :./, :*, :.*, :^, :.^]), - (proj1_type = :DualForm1, proj2_type = :Constant, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^]), + Operator(:Form0, :PVF, :norm, [:mag]), + Operator(:DualForm0, :DVF, :norm, [:mag]) +] - # These rules contain infer: - (proj1_type = :Form0, proj2_type = :infer, res_type = :Form0, op_names = [:^]), - (proj1_type = :DualForm0, proj2_type = :infer, res_type = :DualForm0, op_names = [:^])] +op1_1D_bound_operators = [ -""" -These are the default rules used to do type inference in the 2D exterior calculus. -""" -op1_inf_rules_2D = [ - # Rules for ∂ₜ - (src_type = :Form0, tgt_type = :Form0, op_names = [:∂ₜ, :dt]), - (src_type = :Form1, tgt_type = :Form1, op_names = [:∂ₜ, :dt]), - (src_type = :Form2, tgt_type = :Form2, op_names = [:∂ₜ, :dt]), + # Rules for ⋆ + Operator(:DualForm1, :Form0, :⋆₀, [:★, :⋆, :star]), + Operator(:DualForm0, :Form1, :⋆₁, [:★, :⋆, :star]), + Operator(:Form0, :DualForm1, :⋆₀⁻¹, [:★, :⋆, :star, :star_inv]), + Operator(:Form1, :DualForm0, :⋆₁⁻¹, [:★, :⋆, :star, :star_inv]) +] - # Rules for d - (src_type = :Form0, tgt_type = :Form1, op_names = [:d, :d₀]), - (src_type = :Form1, tgt_type = :Form2, op_names = [:d, :d₁]), - (src_type = :DualForm0, tgt_type = :DualForm1, op_names = [:d, :dual_d₀, :d̃₀]), - (src_type = :DualForm1, tgt_type = :DualForm2, op_names = [:d, :dual_d₁, :d̃₁]), +op1_2D_bound_operators = [ # Rules for ⋆ - (src_type = :Form0, tgt_type = :DualForm2, op_names = [:★, :⋆, :⋆₀, :star]), - (src_type = :Form1, tgt_type = :DualForm1, op_names = [:★, :⋆, :⋆₁, :star]), - (src_type = :Form2, tgt_type = :DualForm0, op_names = [:★, :⋆, :⋆₂, :star]), + Operator(:DualForm2, :Form0, :⋆₀, [:★, :⋆, :star]), + Operator(:DualForm1, :Form1, :⋆₁, [:★, :⋆, :star]), + Operator(:DualForm0, :Form2, :⋆₂, [:★, :⋆, :star]), + Operator(:Form0, :DualForm2, :⋆₀⁻¹, [:★, :⋆, :star, :star_inv]), + Operator(:Form1, :DualForm1, :⋆₁⁻¹, [:★, :⋆, :star, :star_inv]), + Operator(:Form2, :DualForm0, :⋆₂⁻¹, [:★, :⋆, :star, :star_inv]) +] + +op2_operators = [ + # Rules for ∧. + Operator(:Form0, [:Form0, :Form0], :∧₀₀, [:∧, :wedge]), + Operator(:Form1, [:Form1, :Form0], :∧₁₀, [:∧, :wedge]), + Operator(:Form1, [:Form0, :Form1], :∧₀₁, [:∧, :wedge]), + Operator(:Form2, [:Form1, :Form1], :∧₁₁, [:∧, :wedge]), + Operator(:Form2, [:Form2, :Form0], :∧₂₀, [:∧, :wedge]), + Operator(:Form2, [:Form0, :Form2], :∧₀₂, [:∧, :wedge]), + + # Rules for L. + Operator(:DualForm0, [:Form1, :DualForm0], :L₀, [:L]), + Operator(:DualForm1, [:Form1, :DualForm1], :L₁, [:L]), + Operator(:DualForm2, [:Form1, :DualForm2], :L₂, [:L]), + + # TODO: Make consistent with other Lie's + Operator(:DualForm1, [:DualForm1, :DualForm1], :ℒ₁), + + # Rules for i. + Operator(:DualForm0, [:Form1, :DualForm1], :i₁, [:i]), + Operator(:DualForm1, [:Form1, :DualForm2], :i₂, [:i]), + Operator(:DualForm0, [:DualForm1, :DualForm1], :ι₁₁), + Operator(:DualForm1, [:DualForm1, :DualForm2], :ι₁₂), + + # Arthimetic rules + arithmetic_operators(:.-, true)..., + arithmetic_operators(:./, true)..., + arithmetic_operators(:.*, true)..., + arithmetic_operators(:.^, true)..., + + # TODO: Only labelled as broadcasted since Decapodes converts all these + # to their broadcasted forms. They really should have different rules. + arithmetic_operators(:-, true)..., + arithmetic_operators(:/, true)..., + arithmetic_operators(:*, true)..., + arithmetic_operators(:^, true)..., + + # TODO: Add some intermediate result type to avoid having infers + # Operator(:Form0, [:Form0, :infer], :^), + # Operator(:Form0, [:Form0, :infer], :.^), + # Operator(:Form1, [:Form1, :infer], :^), + # Operator(:Form1, [:Form1, :infer], :.^) +] - (src_type = :DualForm2, tgt_type = :Form0, op_names = [:★, :⋆, :⋆₀⁻¹, :star_inv]), - (src_type = :DualForm1, tgt_type = :Form1, op_names = [:★, :⋆, :⋆₁⁻¹, :star_inv]), - (src_type = :DualForm0, tgt_type = :Form2, op_names = [:★, :⋆, :⋆₂⁻¹, :star_inv]), - - # Rules for Δ - (src_type = :Form0, tgt_type = :Form0, op_names = [:Δ, :Δ₀, :lapl]), - (src_type = :Form1, tgt_type = :Form1, op_names = [:Δ, :Δ₁, :lapl]), - (src_type = :Form2, tgt_type = :Form2, op_names = [:Δ, :Δ₂, :lapl]), +# TODO: When SummationDecapodes are annotated with the degree of their space, +# use dispatch to choose the correct set of rules. +function default_operators(dim) + @assert 1 <= dim <= 2 + metric_free = vcat(op1_operators, op2_operators) + return vcat(metric_free, dim == 1 ? op1_1D_bound_operators : op1_2D_bound_operators) +end - # Rules for Δᵈ - (src_type = :DualForm0, tgt_type = :DualForm0, op_names = [:Δᵈ₀]), - (src_type = :DualForm1, tgt_type = :DualForm1, op_names = [:Δᵈ₁]), +infer_types!(d::SummationDecapode; dim = 2) = + infer_types!(d, default_operators(dim)) - # Rules for δ - (src_type = :Form1, tgt_type = :Form0, op_names = [:δ, :δ₁, :codif]), - (src_type = :Form2, tgt_type = :Form1, op_names = [:δ, :δ₂, :codif]), +resolve_overloads!(d::SummationDecapode; dim = 2) = + resolve_overloads!(d, default_operators(dim)) - # Rules for the averaging operator - (src_type = :Form0, tgt_type = :Form1, op_names = [:avg₀₁, :avg_01]), - - # Rules for negation - (src_type = :Form0, tgt_type = :Form0, op_names = [:neg, :(-)]), - (src_type = :Form1, tgt_type = :Form1, op_names = [:neg, :(-)]), - (src_type = :Form2, tgt_type = :Form2, op_names = [:neg, :(-)]), - (src_type = :DualForm0, tgt_type = :DualForm0, op_names = [:neg, :(-)]), - (src_type = :DualForm1, tgt_type = :DualForm1, op_names = [:neg, :(-)]), - (src_type = :DualForm2, tgt_type = :DualForm2, op_names = [:neg, :(-)]), - - # Rules for ♯. - (src_type = :Form1, tgt_type = :PVF, op_names = [:♯, :♯ᵖᵖ]), - (src_type = :DualForm1, tgt_type = :DVF, op_names = [:♯, :♯ᵈᵈ]), - - # Rules for ♭. - (src_type = :DVF, tgt_type = :Form1, op_names = [:♭, :♭ᵈᵖ]), - - # Rules for magnitude/ norm - (src_type = :PVF, tgt_type = :Form0, op_names = [:norm, :mag]), - (src_type = :DVF, tgt_type = :DualForm0, op_names = [:norm, :mag])] - -op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ - # Rules for ∧₁₁, ∧₂₀, ∧₀₂ - (proj1_type = :Form1, proj2_type = :Form1, res_type = :Form2, op_names = [:∧, :∧₁₁, :wedge]), - (proj1_type = :Form2, proj2_type = :Form0, res_type = :Form2, op_names = [:∧, :∧₂₀, :wedge]), - (proj1_type = :Form0, proj2_type = :Form2, res_type = :Form2, op_names = [:∧, :∧₀₂, :wedge]), - - # Rules for L₂ - (proj1_type = :Form1, proj2_type = :DualForm2, res_type = :DualForm2, op_names = [:L, :L₂]), - - # Rules for i₁ - (proj1_type = :Form1, proj2_type = :DualForm2, res_type = :DualForm1, op_names = [:i, :i₂]), - - # Rules for ℒ - (proj1_type = :DualForm1, proj2_type = :DualForm1, res_type = :DualForm1, op_names = [:ℒ₁]), - - # Rules for ι - (proj1_type = :DualForm1, proj2_type = :DualForm1, res_type = :DualForm0, op_names = [:ι₁₁]), - (proj1_type = :DualForm1, proj2_type = :DualForm2, res_type = :DualForm1, op_names = [:ι₁₂]), - - # Rules for subtraction - (proj1_type = :Form0, proj2_type = :Form0, res_type = :Form0, op_names = [:-, :.-]), - (proj1_type = :Form1, proj2_type = :Form1, res_type = :Form1, op_names = [:-, :.-]), - (proj1_type = :Form2, proj2_type = :Form2, res_type = :Form2, op_names = [:-, :.-]), - (proj1_type = :DualForm0, proj2_type = :DualForm0, res_type = :DualForm0, op_names = [:-, :.-]), - (proj1_type = :DualForm1, proj2_type = :DualForm1, res_type = :DualForm1, op_names = [:-, :.-]), - (proj1_type = :DualForm2, proj2_type = :DualForm2, res_type = :DualForm2, op_names = [:-, :.-]), - - # Rules for divison, multiplication, and exponentiation. - (proj1_type = :Form2, proj2_type = :Form2, res_type = :Form2, op_names = [:/, :./, :*, :.*, :^, :.^]), - (proj1_type = :Literal, proj2_type = :Form2, res_type = :Form2, op_names = [:/, :./, :*, :.*, :^, :.^]), - (proj1_type = :Form2, proj2_type = :Literal, res_type = :Form2, op_names = [:/, :./, :*, :.*, :^, :.^]), - (proj1_type = :DualForm2, proj2_type = :DualForm2, res_type = :DualForm2, op_names = [:/, :./, :*, :.*, :^, :.^]), - (proj1_type = :Literal, proj2_type = :DualForm2, res_type = :DualForm2, op_names = [:/, :./, :*, :.*, :^, :.^]), - (proj1_type = :DualForm2, proj2_type = :Literal, res_type = :DualForm2, op_names = [:/, :./, :*, :.*, :^, :.^])]) - - """ - These are the default rules used to do function resolution in the 1D exterior calculus. - """ - op1_res_rules_1D = [ - # Rules for d. - (src_type = :Form0, tgt_type = :Form1, resolved_name = :d₀, op = :d), - (src_type = :DualForm0, tgt_type = :DualForm1, resolved_name = :dual_d₀, op = :d), - # Rules for ⋆. - (src_type = :Form0, tgt_type = :DualForm1, resolved_name = :⋆₀, op = :⋆), - (src_type = :Form1, tgt_type = :DualForm0, resolved_name = :⋆₁, op = :⋆), - (src_type = :DualForm1, tgt_type = :Form0, resolved_name = :⋆₀⁻¹, op = :⋆), - (src_type = :DualForm0, tgt_type = :Form1, resolved_name = :⋆₁⁻¹, op = :⋆), - (src_type = :Form0, tgt_type = :DualForm1, resolved_name = :⋆₀, op = :star), - (src_type = :Form1, tgt_type = :DualForm0, resolved_name = :⋆₁, op = :star), - (src_type = :DualForm1, tgt_type = :Form0, resolved_name = :⋆₀⁻¹, op = :star), - (src_type = :DualForm0, tgt_type = :Form1, resolved_name = :⋆₁⁻¹, op = :star), - # Rules for δ. - (src_type = :Form1, tgt_type = :Form0, resolved_name = :δ₁, op = :δ), - (src_type = :Form1, tgt_type = :Form0, resolved_name = :δ₁, op = :codif), - # Rules for Δ - (src_type = :Form0, tgt_type = :Form0, resolved_name = :Δ₀, op = :Δ), - (src_type = :Form1, tgt_type = :Form1, resolved_name = :Δ₁, op = :Δ)] - - # We merge 1D and 2D rules since it seems op2 rules are metric-free. If - # this assumption is false, this needs to change. - op2_res_rules_1D = [ - # Rules for ∧. - (proj1_type = :Form0, proj2_type = :Form0, res_type = :Form0, resolved_name = :∧₀₀, op = :∧), - (proj1_type = :Form1, proj2_type = :Form0, res_type = :Form1, resolved_name = :∧₁₀, op = :∧), - (proj1_type = :Form0, proj2_type = :Form1, res_type = :Form1, resolved_name = :∧₀₁, op = :∧), - (proj1_type = :Form0, proj2_type = :Form0, res_type = :Form0, resolved_name = :∧₀₀, op = :wedge), - (proj1_type = :Form1, proj2_type = :Form0, res_type = :Form1, resolved_name = :∧₁₀, op = :wedge), - (proj1_type = :Form0, proj2_type = :Form1, res_type = :Form1, resolved_name = :∧₀₁, op = :wedge), - # Rules for L. - (proj1_type = :Form1, proj2_type = :DualForm0, res_type = :DualForm0, resolved_name = :L₀, op = :L), - (proj1_type = :Form1, proj2_type = :DualForm1, res_type = :DualForm1, resolved_name = :L₁, op = :L), - # Rules for i. - (proj1_type = :Form1, proj2_type = :DualForm1, res_type = :DualForm0, resolved_name = :i₁, op = :i)] - - - """ - These are the default rules used to do function resolution in the 2D exterior calculus. - """ - op1_res_rules_2D = [ - # Rules for d. - (src_type = :Form0, tgt_type = :Form1, resolved_name = :d₀, op = :d), - (src_type = :Form1, tgt_type = :Form2, resolved_name = :d₁, op = :d), - (src_type = :DualForm0, tgt_type = :DualForm1, resolved_name = :dual_d₀, op = :d), - (src_type = :DualForm1, tgt_type = :DualForm2, resolved_name = :dual_d₁, op = :d), - # Rules for ⋆. - (src_type = :Form0, tgt_type = :DualForm2, resolved_name = :⋆₀, op = :⋆), - (src_type = :Form1, tgt_type = :DualForm1, resolved_name = :⋆₁, op = :⋆), - (src_type = :Form2, tgt_type = :DualForm0, resolved_name = :⋆₂, op = :⋆), - (src_type = :DualForm2, tgt_type = :Form0, resolved_name = :⋆₀⁻¹, op = :⋆), - (src_type = :DualForm1, tgt_type = :Form1, resolved_name = :⋆₁⁻¹, op = :⋆), - (src_type = :DualForm0, tgt_type = :Form2, resolved_name = :⋆₂⁻¹, op = :⋆), - (src_type = :Form0, tgt_type = :DualForm2, resolved_name = :⋆₀, op = :star), - (src_type = :Form1, tgt_type = :DualForm1, resolved_name = :⋆₁, op = :star), - (src_type = :Form2, tgt_type = :DualForm0, resolved_name = :⋆₂, op = :star), - (src_type = :DualForm2, tgt_type = :Form0, resolved_name = :⋆₀⁻¹, op = :star), - (src_type = :DualForm1, tgt_type = :Form1, resolved_name = :⋆₁⁻¹, op = :star), - (src_type = :DualForm0, tgt_type = :Form2, resolved_name = :⋆₂⁻¹, op = :star), - # Rules for δ. - (src_type = :Form2, tgt_type = :Form1, resolved_name = :δ₂, op = :δ), - (src_type = :Form1, tgt_type = :Form0, resolved_name = :δ₁, op = :δ), - (src_type = :Form2, tgt_type = :Form1, resolved_name = :δ₂, op = :codif), - (src_type = :Form1, tgt_type = :Form0, resolved_name = :δ₁, op = :codif), - # Rules for ♯. - (src_type = :Form1, tgt_type = :PVF, resolved_name = :♯ᵖᵖ, op = :♯), - (src_type = :DualForm1, tgt_type = :DVF, resolved_name = :♯ᵈᵈ, op = :♯), - # Rules for ♭. - (src_type = :DVF, tgt_type = :Form1, resolved_name = :♭ᵈᵖ, op = :♭), - # Rules for ∇². - # TODO: Call this :nabla2 in ASCII? - (src_type = :Form0, tgt_type = :Form0, resolved_name = :∇²₀, op = :∇²), - (src_type = :Form1, tgt_type = :Form1, resolved_name = :∇²₁, op = :∇²), - (src_type = :Form2, tgt_type = :Form2, resolved_name = :∇²₂, op = :∇²), - # Rules for Δ. - (src_type = :Form0, tgt_type = :Form0, resolved_name = :Δ₀, op = :Δ), - (src_type = :Form1, tgt_type = :Form1, resolved_name = :Δ₁, op = :Δ), - (src_type = :Form1, tgt_type = :Form1, resolved_name = :Δ₂, op = :Δ), - (src_type = :Form0, tgt_type = :Form0, resolved_name = :Δ₀, op = :lapl), - (src_type = :Form1, tgt_type = :Form1, resolved_name = :Δ₁, op = :lapl), - (src_type = :Form1, tgt_type = :Form1, resolved_name = :Δ₂, op = :lapl)] - - # We merge 1D and 2D rules directly here since it seems op2 rules - # are metric-free. If this assumption is false, this needs to change. - op2_res_rules_2D = vcat(op2_res_rules_1D, [ - # Rules for ∧. - (proj1_type = :Form1, proj2_type = :Form1, res_type = :Form2, resolved_name = :∧₁₁, op = :∧), - (proj1_type = :Form2, proj2_type = :Form0, res_type = :Form2, resolved_name = :∧₂₀, op = :∧), - (proj1_type = :Form0, proj2_type = :Form2, res_type = :Form2, resolved_name = :∧₀₂, op = :∧), - (proj1_type = :Form1, proj2_type = :Form1, res_type = :Form2, resolved_name = :∧₁₁, op = :wedge), - (proj1_type = :Form2, proj2_type = :Form0, res_type = :Form2, resolved_name = :∧₂₀, op = :wedge), - (proj1_type = :Form0, proj2_type = :Form2, res_type = :Form2, resolved_name = :∧₀₂, op = :wedge), - # Rules for L. - (proj1_type = :Form1, proj2_type = :DualForm2, res_type = :DualForm2, resolved_name = :L₂, op = :L), - # (proj1_type = :Form1, proj2_type = :DualForm2, res_type = :DualForm2, resolved_name = :L₂ᵈ, op = :L), - # Rules for i. - (proj1_type = :Form1, proj2_type = :DualForm2, res_type = :DualForm1, resolved_name = :i₂, op = :i)]) +type_check(d::SummationDecapode; dim = 2) = + type_check(d, default_operators(dim)) -# TODO: When SummationDecapodes are annotated with the degree of their space, -# use dispatch to choose the correct set of rules. -infer_types!(d::SummationDecapode) = - infer_types!(d, op1_inf_rules_2D, op2_inf_rules_2D) +function infer_resolve!(d::SummationDecapode; dim = 2) + operators = default_operators(dim) + infer_types!(d, operators) + resolve_overloads!(d, operators) + type_check(d, operators) + d +end ascii_to_unicode_op1 = Pair{Symbol, Any}[ (:dt => :∂ₜ), @@ -349,12 +194,3 @@ function vec_to_dec!(d::SummationDecapode) d end - -# TODO: When SummationDecapodes are annotated with the degree of their space, -# use dispatch to choose the correct set of rules. -""" function resolve_overloads!(d::SummationDecapode) - -Resolve function overloads based on types of src and tgt. -""" -resolve_overloads!(d::SummationDecapode) = - resolve_overloads!(d, op1_res_rules_2D, op2_res_rules_2D) diff --git a/src/graph_traversal.jl b/src/graph_traversal.jl index e41048e..bf0367e 100644 --- a/src/graph_traversal.jl +++ b/src/graph_traversal.jl @@ -1,7 +1,7 @@ using DiagrammaticEquations using ACSets -export TraversalNode, topological_sort_edges, n_ops, retrieve_name, start_nodes, edge_inputs, edge_output, edge_function +export TraversalNode, topological_sort_edges, n_ops, retrieve_name, start_nodes, edge_inputs, edge_output, edge_function, set_edge_label! struct TraversalNode{T} index::Int @@ -29,6 +29,15 @@ edge_function(d::SummationDecapode, idx::Int, ::Val{:Op2}) = edge_function(d::SummationDecapode, idx::Int, ::Val{:Σ}) = :+ +set_edge_label!(d::SummationDecapode, idx::Int, new_label, ::Val{:Op1}) = + (d[idx,:op1] = new_label) + +set_edge_label!(d::SummationDecapode, idx::Int, new_label, ::Val{:Op2}) = +(d[idx,:op2] = new_label) + +set_edge_label!(d::SummationDecapode, idx::Int, new_label, ::Val{:Σ}) = nothing + + #XXX: This topological sort is O(n^2). function topological_sort_edges(d::SummationDecapode) visited_Var = falses(nparts(d, :Var)) @@ -70,4 +79,3 @@ function retrieve_name(d::SummationDecapode, tsr::TraversalNode) _ => error("$(tsr.name) is a table without names") end end - diff --git a/test/collages.jl b/test/collages.jl index 4c9cde9..381f2dc 100644 --- a/test/collages.jl +++ b/test/collages.jl @@ -211,7 +211,7 @@ infer_types!(ParamDiffusionCollage) incl = [3] op1 = Any[:∂ₜ, [:d, :⋆, :d, :⋆]] op2 = [:*, :rb1_leftwall, :rb2_rightwall, :rb3, :r0] - type = [:Form0, :Parameter, :Form0, :infer, :Form0, :Form0, :Form0, :Form0, :Form0, :Form0, :Parameter, :Parameter] + type = [:Form0, :Parameter, :Form0, :Form0, :Form0, :Form0, :Form0, :Form0, :Form0, :Form0, :Parameter, :Parameter] name = [:K, :A, :K̇, Symbol("•2"), :r1_K, :Kb1, :r2_K, :Kb2, :r3_K̇, :Null, :r4_A, :Ab] end diff --git a/test/language.jl b/test/language.jl index 25f817e..99eb34a 100644 --- a/test/language.jl +++ b/test/language.jl @@ -5,26 +5,11 @@ using MLStyle using Base.Iterators using DiagrammaticEquations +using DiagrammaticEquations.Deca -@testset "Parsing" begin - - # @present DiffusionSpace2D(FreeExtCalc2D) begin - # X::Space - # k::Hom(Form1(X), Form1(X)) # diffusivity of space, usually constant (scalar multiplication) - # proj₁_⁰⁰₀::Hom(Form0(X) ⊗ Form0(X), Form0(X)) - # proj₂_⁰⁰₀::Hom(Form0(X) ⊗ Form0(X), Form0(X)) - # sum₀::Hom(Form0(X) ⊗ Form0(X), Form0(X)) - # prod₀::Hom(Form0(X) ⊗ Form0(X), Form0(X)) - # end - - - # Diffusion = @decapode DiffusionSpace2D begin - # (C, Ċ₁, Ċ₂)::Form0{X} - # Ċ₁ == ⋆₀⁻¹{X}(dual_d₁{X}(⋆₁{X}(k(d₀{X}(C))))) - # Ċ₂ == ⋆₀⁻¹{X}(dual_d₁{X}(⋆₁{X}(d₀{X}(C)))) - # ∂ₜ{Form0{X}}(C) == Ċ₁ + Ċ₂ - # end +import DiagrammaticEquations: Judgement, filterfor_ec_types +@testset "Parsing" begin # Tests ####### @@ -429,16 +414,55 @@ import DiagrammaticEquations: filterfor_ec_types @test isempty(filterfor_ec_types([:Literal, :Constant, :Parameter, :infer])) end -@testset "Type Inference" begin - # Warning, this testing depends on the fact that varname, form information is - # unique within a decapode even though this is not enforced - function get_name_type_pair(d::SummationDecapode) - Set(zip(d[:name], d[:type])) - end +# Warning, this testing depends on the fact that varname, form information is +# unique within a decapode even though this is not enforced +function get_name_type_pair(d::SummationDecapode) + Set(zip(d[:name], d[:type])) +end - function test_nametype_equality(d::SummationDecapode, names_types_expected) - issetequal(get_name_type_pair(d), names_types_expected) - end +function test_nametype_equality(d::SummationDecapode, names_types_expected) + @test issetequal(get_name_type_pair(d), names_types_expected) +end + +@testset "Ruleset ambiguity" begin + amb_forward_rules = [Operator(:Form0, [:Form0], :test), Operator(:Form1, [:Form0], :test)] + @test !check_rule_ambiguity(amb_forward_rules) + + amb_back_rules = [Operator(:Form1, [:Form0], :test), Operator(:Form1, [:Form1], :test)] + @test !check_rule_ambiguity(amb_back_rules) + + amb_large_rules = [Operator(:Form0, [:Form0, :Form1, :Form2], :test), Operator(:Form1, [:Form0, :Form1, :Form2], :test)] + @test !check_rule_ambiguity(amb_large_rules) + + amb_large_back_rules = [Operator(:Form0, [:Form0, :Form1, :Form1], :test), Operator(:Form0, [:Form0, :Form1, :Form2], :test)] + @test !check_rule_ambiguity(amb_large_back_rules) + + usertype_amb_rules = [Operator(:Form0, [:Constant], :test), Operator(:Form1, [:Constant], :test)] + @test !check_rule_ambiguity(usertype_amb_rules) + + usertype_good_rules = [Operator(:Form0, [:Constant], :test), Operator(:Form0, [:Parameter], :test)] + @test check_rule_ambiguity(usertype_good_rules) + + usertype_large_rules = [Operator(:Form0, [:Form0, :Constant], :test), Operator(:Form0, [:Form0, :Literal], :test)] + @test check_rule_ambiguity(usertype_large_rules) + + different_rules = [Operator(:Form0, [:Form0], :test1), Operator(:Form1, [:Form0], :test2)] + @test check_rule_ambiguity(different_rules) + + aliases_amb = [Operator(:Form0, [:Form0], :test1, [:test]), Operator(:Form1, [:Form0], :test2, [:test])] + @test !check_rule_ambiguity(aliases_amb) + + diff_size_rules = [Operator(:Form0, [:Form0], :test), Operator(:Form1, [:Form0, :Form0], :test)] + @test check_rule_ambiguity(diff_size_rules) + + same_rules = [Operator(:Form1, [:Form0], :test), Operator(:Form1, [:Form0], :test)] + @test check_rule_ambiguity(same_rules) + + @test check_rule_ambiguity(default_operators(1)) + @test check_rule_ambiguity(default_operators(2)) +end + +@testset "Type Inference" begin # The type of the tgt of ∂ₜ is inferred. Test1 = quote @@ -450,7 +474,7 @@ end # We use set equality because we do not care about the order of the Var table. names_types_expected_1 = Set([(:C, :Form0)]) - @test test_nametype_equality(t1, names_types_expected_1) + test_nametype_equality(t1, names_types_expected_1) # The type of the src of ∂ₜ is inferred. Test2 = quote @@ -462,7 +486,7 @@ end infer_types!(t2) names_types_expected_2 = Set([(:C, :Form0)]) - @test test_nametype_equality(t2, names_types_expected_2) + test_nametype_equality(t2, names_types_expected_2) # The type of the tgt of d is inferred. Test3 = quote @@ -480,7 +504,7 @@ end infer_types!(t3) names_types_expected_3 = Set([(:C, :Form0), (:D, :Form1), (:E, :Form2)]) - @test test_nametype_equality(t3, names_types_expected_3) + test_nametype_equality(t3, names_types_expected_3) # The type of the src and tgt of d is inferred. Test4 = quote @@ -495,7 +519,7 @@ end infer_types!(t4) names_types_expected_4 = Set([(:C, :Form0), (:D, :Form1), (:E, :Form2)]) - @test test_nametype_equality(t4, names_types_expected_4) + test_nametype_equality(t4, names_types_expected_4) # The type of the src of d is inferred. Test5 = quote @@ -507,7 +531,7 @@ end infer_types!(t5) names_types_expected_5 = Set([(:C, :Form0), (:D, :Form1)]) - @test test_nametype_equality(t5, names_types_expected_5) + test_nametype_equality(t5, names_types_expected_5) # The type of the src of d is inferred. Test6 = quote @@ -525,7 +549,7 @@ end names_types_expected_6 = Set([ (:A, :Form0), (:B, :Form1), (:C, :Form2), (:F, :DualForm2), (:E, :DualForm1), (:D, :DualForm0)]) - @test test_nametype_equality(t6, names_types_expected_6) + test_nametype_equality(t6, names_types_expected_6) # The type of a summand is inferred. Test7 = quote @@ -568,7 +592,7 @@ end (:Γ, :Form0), (:A, :Form1), (:B, :Form1), (:C, :Form1), (:D, :Form1), (:E, :Form1), (:F, :Form1), (:Θ, :Form2)]) - @test test_nametype_equality(t8, names_types_expected_8) + test_nametype_equality(t8, names_types_expected_8) function makeInferPathDeca(log_cycles; infer_path = false) cycles = 2 ^ log_cycles @@ -613,7 +637,7 @@ end names_types_expected_9 = Set([ (:A, :Form0), (Symbol("•_1_", 1), :Form1), (Symbol("•_1_", 2), :Form2), (:B, :DualForm2), (Symbol("•_1_", 4), :DualForm1), (Symbol("•_1_", 3), :DualForm0)]) - @test test_nametype_equality(t9, names_types_expected_9) + test_nametype_equality(t9, names_types_expected_9) # Basic op2 inference using ∧ Test10 = quote @@ -633,7 +657,7 @@ end infer_types!(t10) names_types_expected_10 = Set([(:F, :Form1), (:B, :Form0), (:C, :Form0), (:H, :Form1), (:A, :Form0), (:E, :Form1), (:D, :Form0)]) - @test test_nametype_equality(t10, names_types_expected_10) + test_nametype_equality(t10, names_types_expected_10) # Basic op2 inference using L Test11 = quote @@ -650,7 +674,7 @@ end names_types_expected_11 = Set([(:A, :Form1), (:B, :DualForm0), (:C, :DualForm0), (:E, :DualForm1), (:D, :DualForm1)]) - @test test_nametype_equality(t11, names_types_expected_11) + test_nametype_equality(t11, names_types_expected_11) # Basic op2 inference using i Test12 = quote @@ -664,7 +688,7 @@ end infer_types!(t12) names_types_expected_12 = Set([(:A, :Form1), (:B, :DualForm1), (:C, :DualForm0)]) - @test test_nametype_equality(t12, names_types_expected_12) + test_nametype_equality(t12, names_types_expected_12) #2D op2 inference using ∧ Test13 = quote @@ -684,7 +708,7 @@ end infer_types!(t13) names_types_expected_13 = Set([(:E, :Form2), (:B, :Form1), (:C, :Form2), (:A, :Form1), (:F, :Form2), (:H, :Form2), (:D, :Form0)]) - @test test_nametype_equality(t13, names_types_expected_13) + test_nametype_equality(t13, names_types_expected_13) # 2D op2 inference using L Test14 = quote @@ -698,7 +722,7 @@ end infer_types!(t14) names_types_expected_14 = Set([(:C, :DualForm2), (:A, :Form1), (:B, :DualForm2)]) - @test test_nametype_equality(t14, names_types_expected_14) + test_nametype_equality(t14, names_types_expected_14) # 2D op2 inference using i Test15 = quote @@ -712,7 +736,7 @@ end infer_types!(t15) names_types_expected_15 = Set([(:A, :Form1), (:B, :DualForm2), (:C, :DualForm1)]) - @test test_nametype_equality(t15, names_types_expected_15) + test_nametype_equality(t15, names_types_expected_15) # Special case of a summation with a mix of infer and Literal. t16 = @decapode begin @@ -721,7 +745,7 @@ end infer_types!(t16) names_types_expected_16 = Set([(:A, :infer), (:C, :infer), (:D, :infer), (Symbol("2"), :Literal)]) - @test test_nametype_equality(t16, names_types_expected_16) + test_nametype_equality(t16, names_types_expected_16) # Special case of a summation with a mix of Form, Constant, and infer. t17 = @decapode begin @@ -732,7 +756,7 @@ end infer_types!(t17) names_types_expected_17 = Set([(:A, :Form0), (:C, :Constant), (:D, :Form0)]) - @test test_nametype_equality(t17, names_types_expected_17) + test_nametype_equality(t17, names_types_expected_17) # Special case of a summation with a mix of infer and Constant. t18 = @decapode begin @@ -742,7 +766,7 @@ end infer_types!(t18) names_types_expected_18 = Set([(:A, :infer), (:C, :Constant), (:D, :infer)]) - @test test_nametype_equality(t18, names_types_expected_18) + test_nametype_equality(t18, names_types_expected_18) # Test #19: Prevent intermediates from converting to Constant let @@ -755,7 +779,7 @@ end ḣ == ∘(⋆, d, ⋆)(Γ * d(h) * avg₀₁(mag(♯(d(h)))^(n-1)) * avg₀₁(h^(n+2))) end - infer_types!(d, op1_inf_rules_1D, op2_inf_rules_1D) + infer_types!(d, dim = 1) @test d[18, :type] != :Constant end @@ -771,7 +795,7 @@ end end d = expand_operators(d) - infer_types!(d, op1_inf_rules_1D, op2_inf_rules_1D) + infer_types!(d, dim = 1) @test d[8, :type] != :Literal end @@ -785,7 +809,7 @@ end infer_types!(d) names_types_expected = Set([(:A, :infer), (:C, :Parameter), (:D, :infer)]) - @test test_nametype_equality(d, names_types_expected) + test_nametype_equality(d, names_types_expected) end # Test #22: Have source form information flow over Parameters and Constants @@ -799,7 +823,7 @@ end infer_types!(d) names_types_expected = Set([(:A, :Form0), (:B, :Form0), (:C, :Parameter), (:D, :Constant)]) - @test test_nametype_equality(d, names_types_expected) + test_nametype_equality(d, names_types_expected) end # Test #23: Have target form information flow over Parameters and Constants @@ -813,7 +837,7 @@ end infer_types!(d) names_types_expected = Set([(:A, :Form0), (:B, :Form0), (:C, :Parameter), (:D, :Constant)]) - @test test_nametype_equality(d, names_types_expected) + test_nametype_equality(d, names_types_expected) end # Test #24: Summing mismatched forms throws an error @@ -856,13 +880,37 @@ end (:E, :DVF), (:F, :infer), (:G, :Form1), (:H, :DualForm1), (:I, :Form1), (:J, :DualForm1), (:K, :PVF), (:L, :DVF), (:M, :PVF), (:N, :DVF), (:O, :infer), (:P, :Form1)]) - @test test_nametype_equality(d, names_types_expected) + test_nametype_equality(d, names_types_expected) + end + + op_not_in_rules = @decapode begin + A::Form0 + C::Form0 + + B == fake(A) + D == fake2(A, C) end + op_not_in_rules_res = deepcopy(op_not_in_rules) + infer_types!(op_not_in_rules) + + @test op_not_in_rules == op_not_in_rules_res + @test !(op_not_in_rules === op_not_in_rules_res) + + only_type_dec = @decapode begin + A::Form0 + + B == fake(A) + C == d(A) + end + infer_types!(only_type_dec) + + only_type_dec_res = Set([(:A, :Form0), (:B, :infer), (:C, :Form1)]) + test_nametype_equality(only_type_dec, only_type_dec_res) end -@testset "Overloading Resolution" begin +@testset "Overloading Resolution and Type Checking" begin # d overloading is resolved. - Test1 = quote + t1 = @decapode begin A::Form0{X} B::Form1{X} C::Form2{X} @@ -874,15 +922,15 @@ end E == d(D) F == d(E) end - t1 = SummationDecapode(parse_decapode(Test1)) resolve_overloads!(t1) op1s_1 = t1[:op1] op1s_expected_1 = [:d₀ , :d₁, :dual_d₀, :dual_d₁] @test op1s_1 == op1s_expected_1 + @test type_check(t1) # ⋆ overloading is resolved. - Test2 = quote + t2 = @decapode begin C::Form0{X} D::DualForm2{X} E::Form1{X} @@ -896,8 +944,8 @@ end G == ⋆(H) H == ⋆(G) end - t2 = SummationDecapode(parse_decapode(Test2)) resolve_overloads!(t2) + @test type_check(t2) op1s_2 = t2[:op1] # Note: The Op1 table of the decapode created does not have the functions @@ -906,7 +954,7 @@ end @test op1s_2 == op1s_expected_2 # All overloading on the de Rahm complex is resolved. - Test3 = quote + t3 = @decapode begin A::Form0{X} B::Form1{X} C::Form2{X} @@ -924,7 +972,6 @@ end D == ⋆(C) C == ⋆(D) end - t3 = SummationDecapode(parse_decapode(Test3)) resolve_overloads!(t3) op1s_3 = t3[:op1] @@ -932,8 +979,9 @@ end # listed in the order in which they appear in Test2. op1s_expected_3 = [:d₀ , :d₁, :dual_d₀, :dual_d₁, :⋆₀ , :⋆₀⁻¹ , :⋆₁ , :⋆₁⁻¹ , :⋆₂ , :⋆₂⁻¹] @test op1s_3 == op1s_expected_3 + @test type_check(t3) - Test4 = quote + t4 = @decapode begin (A, C)::Form0{X} (B, D, E)::Form1{X} @@ -947,13 +995,13 @@ end G == L(B, J) H == i(B, J) end - t4 = SummationDecapode(parse_decapode(Test4)) resolve_overloads!(t4) op2s_4 = t4[:op2] op2s_expected_4 = [:∧₀₀ , :∧₀₁, :∧₁₀, :L₀, :L₁, :i₁] @test op2s_4 == op2s_expected_4 + @test type_check(t4) - Test5 = quote + t5 = @decapode begin A::Form0{X} B::Form1{X} (C, D, E, F)::Form2{X} @@ -967,14 +1015,125 @@ end G == L(B, I) H == i(B, I) end - t5 = SummationDecapode(parse_decapode(Test5)) resolve_overloads!(t5) op2s_5 = t5[:op2] op2s_expected_5 = [:∧₁₁, :∧₀₂, :∧₂₀, :L₂, :i₂] @test op2s_5 == op2s_expected_5 + @test type_check(t5) + + op_not_in_rules = @decapode begin + A::Form0 + B::Form0 + C::Form0 + D::Form0 + + B == fake(A) + D == fake2(A, C) + end + op_not_in_rules_res = deepcopy(op_not_in_rules) + resolve_overloads!(op_not_in_rules) + + @test op_not_in_rules == op_not_in_rules_res + @test op_not_in_rules !== op_not_in_rules_res + @test type_check(op_not_in_rules) + + only_type_dec = @decapode begin + A::Form0 + B::Form0 + C::Form1 + + B == fake(A) + C == d(A) + end + resolve_overloads!(only_type_dec) + + only_type_dec_res = Set([(:A, :Form0), (:B, :Form0), (:C, :Form1)]) + test_nametype_equality(only_type_dec, only_type_dec_res) + @test type_check(only_type_dec) + + poorly_type_deca = @decapode begin + (A,B)::Form0 + + B == d(A) + end + resolve_overloads!(poorly_type_deca) + @test_throws DecaTypeExeception type_check(poorly_type_deca) end @testset "Type Inference and Overloading Resolution Integration" begin + + start_just_op1s = @decapode begin + B::Form0 + A == d(d(B)) + end + + infer_resolve!(start_just_op1s) + test_nametype_equality(start_just_op1s, Set([(Symbol("•1"), :Form1), (:B, :Form0), (:A, :Form2)])) + @test start_just_op1s[:op1] == [:d₀, :d₁] + + end_just_op1s = @decapode begin + A::Form2 + A == d(d(B)) + end + infer_resolve!(end_just_op1s) + test_nametype_equality(end_just_op1s, Set([(Symbol("•1"), :Form1), (:B, :Form0), (:A, :Form2)])) + @test end_just_op1s[:op1] == [:d₀, :d₁] + + all_op_types = @decapode begin + (I1, I2)::Form0 + A == Δ(d(∧(I1, I2))) + δ(I3) + end + infer_resolve!(all_op_types) + @test all_op_types[3, :name] == :A && all_op_types[3, :type] == :Form1 + @test all_op_types[8, :name] == :I3 && all_op_types[8, :type] == :Form2 + + @test all_op_types[:op1] == [:d₀, :Δ₁, :δ₂] + @test all_op_types[:op2] == [:∧₀₀] + + arithmetic = @decapode begin + B::Form0 + A == (n+1) .* B + end + infer_resolve!(arithmetic) + + @test_broken arithmetic[2, :name] == :A && arithmetic[2, :type] == :Form0 + + # Infer types and resolve overloads for the Halfar equation. + # TODO: This test isn't passing right now because we can't ignore type of the power + # in the exponent. + let + d = @decapode begin + h::Form0 + Γ::Form1 + n::Constant + + ∂ₜ(h) == ∘(⋆, d, ⋆)(Γ * d(h) ∧ (mag(♯(d(h)))^(n-1)) ∧ (h^(n+2))) + end + d = expand_operators(d) + infer_resolve!(d) + @test_broken d == @acset SummationDecapode{Any, Any, Symbol} begin + Var = 19 + TVar = 1 + Op1 = 8 + Op2 = 6 + Σ = 1 + Summand = 2 + src = [1, 1, 1, 13, 12, 6, 18, 19] + tgt = [4, 9, 13, 12, 11, 18, 19, 4] + proj1 = [2, 3, 11, 8, 1, 7] + proj2 = [9, 15, 14, 10, 5, 16] + res = [8, 14, 10, 7, 16, 6] + incl = [4] + summand = [3, 17] + summation = [1, 1] + sum = [5] + op1 = [:∂ₜ, :d₀, :d₀, :♯ᵖᵖ, :norm, :⋆₁, :dual_d₁, :⋆₀⁻¹] + op2 = [:*, :-, :^, :∧₁₀, :^, :∧₁₀] + type = [:Form0, :Form1, :Constant, :Form0, :infer, :Form1, :Form1, :Form1, :Form1, :Form0, :Form0, :PVF, :Form1, :infer, :Literal, :Form0, :Literal, :DualForm1, :DualForm2] + name = [:h, :Γ, :n, :ḣ, :sum_1, Symbol("•2"), Symbol("•3"), Symbol("•4"), Symbol("•5"), Symbol("•6"), Symbol("•7"), Symbol("•8"), Symbol("•9"), Symbol("•10"), Symbol("1"), Symbol("•11"), Symbol("2"), Symbol("•_6_1"), Symbol("•_6_2")] + end + end + # Momentum-formulation of Navier Stokes on sphere DiffusionExprBody = quote (T, Ṫ)::Form0{X} @@ -1034,83 +1193,12 @@ end [Open(continuity, [:M, :ρ, :P, :T]), Open(NavierStokes, [:M, :ρ, :p, :T])]) HeatXfer = apex(heatXfer_cospan) + infer_types!(HeatXfer) + resolve_overloads!(HeatXfer) - bespoke_op1_inf_rules = [ - # Rules for avg. - (src_type = :Form0, tgt_type = :Form1, op_names = [:avg]), - # Rules for R₀. - (src_type = :Form0, tgt_type = :Form0, op_names = [:R₀])] - - #= bespoke_op2_inf_rules = [ - (proj1_type = :Parameter, proj2_type = :Form0, res_type = :Form0, op_names = [:/, :./, :*, :.*]), - (proj1_type = :Parameter, proj2_type = :Form1, res_type = :Form1, op_names = [:/, :./, :*, :.*]), - (proj1_type = :Parameter, proj2_type = :Form2, res_type = :Form2, op_names = [:/, :./, :*, :.*]), - - (proj1_type = :Form0, proj2_type = :Parameter, res_type = :Form0, op_names = [:/, :./, :*, :.*]), - (proj1_type = :Form1, proj2_type = :Parameter, res_type = :Form1, op_names = [:/, :./, :*, :.*]), - (proj1_type = :Form2, proj2_type = :Parameter, res_type = :Form2, op_names = [:/, :./, :*, :.*])] =# - - infer_types!(HeatXfer, vcat(bespoke_op1_inf_rules, op1_inf_rules_2D), - op2_inf_rules_2D) - - names_types_hx = zip(HeatXfer[:name], HeatXfer[:type]) - - names_types_expected_hx = [ - (:T, :Form0), (:continuity_Ṫ₁, :Form0), (:continuity_diffusion_ϕ, :DualForm1), (:continuity_diffusion_k, :Parameter), (Symbol("continuity_diffusion_•1"), :DualForm2), (Symbol("continuity_diffusion_•2"), :Form1), (Symbol("continuity_diffusion_•3"), :Form1), (:M, :Form1), (:continuity_advection_V, :Form1), (:ρ, :Form0), (:P, :Form0), (:continuity_Ṫₐ, :Form0), (Symbol("continuity_advection_•1"), :Form0), (Symbol("continuity_advection_•2"), :Form1), (Symbol("continuity_advection_•3"), :DualForm2), (Symbol("continuity_advection_•4"), :Form0), (Symbol("continuity_advection_•5"), :DualForm2), (:continuity_Ṫ, :Form0), (:navierstokes_Ṁ, :Form1), (:navierstokes_G, :Form1), (:navierstokes_V, :Form1), (:navierstokes_ṗ, :Form0), (:navierstokes_two, :Parameter), (:navierstokes_three, :Parameter), (:navierstokes_kᵥ, :Parameter), (Symbol("navierstokes_•1"), :DualForm2), (Symbol("navierstokes_•2"), :Form1), (Symbol("navierstokes_•3"), :Form1), (Symbol("navierstokes_•4"), :DualForm1), (Symbol("navierstokes_•5"), :DualForm1), (Symbol("navierstokes_•6"), :DualForm1), (Symbol("navierstokes_•7"), :Form1), (Symbol("navierstokes_•8"), :Form1), (Symbol("navierstokes_•9"), :Form1), (Symbol("navierstokes_•10"), :Form1), (Symbol("navierstokes_•11"), :Form1), (:navierstokes_sum_1, :Form1), (Symbol("navierstokes_•12"), :Form0), (Symbol("navierstokes_•13"), :Form1), (Symbol("navierstokes_•14"), :Form1), (Symbol("navierstokes_•15"), :Form0), (Symbol("navierstokes_•16"), :DualForm0), (Symbol("navierstokes_•17"), :DualForm1), (Symbol("navierstokes_•18"), :Form1), (Symbol("navierstokes_•19"), :Form1), (Symbol("navierstokes_•20"), :Form1), (:navierstokes_sum_2, :Form0), (Symbol("navierstokes_•21"), :Form1), (Symbol("navierstokes_•22"), :Form1), (Symbol("navierstokes_•23"), :DualForm2)] - - @test issetequal(names_types_hx, names_types_expected_hx) - - bespoke_op2_res_rules = [ - # Rules for L. - (proj1_type = :Form1, proj2_type = :DualForm2, res_type = :Form2, resolved_name = :L₀, op = :L), - (proj1_type = :Form1, proj2_type = :DualForm2, res_type = :DualForm2, resolved_name = :L₀, op = :L), - (proj1_type = :Form1, proj2_type = :Form1, res_type = :Form1, resolved_name = :L₁′, op = :L)] - - resolve_overloads!(HeatXfer, op1_res_rules_2D, - vcat(bespoke_op2_res_rules, op2_res_rules_2D)) - - op1s_hx = HeatXfer[:op1] - op1s_expected_hx = [:d₀, :⋆₁, :dual_d₁, :⋆₀⁻¹, :avg, :R₀, :⋆₀, :⋆₀⁻¹, :neg, :∂ₜ, :avg, :⋆₁, :neg, :avg, :Δ₁, :δ₁, :d₀, :⋆₁, :d₀, :avg, :d₀, :neg, :avg, :∂ₜ, :⋆₀, :⋆₀⁻¹, :neg, :∂ₜ] - @test op1s_hx == op1s_expected_hx - op2s_hx = HeatXfer[:op2] - op2s_expected_hx = [:*, :/, :/, :L₀, :/, :L₁, :*, :/, :*, :i₁, :/, :*, :*, :L₀] - @test op2s_hx == op2s_expected_hx - - # Infer types and resolve overloads for the Halfar equation. - let - d = @decapode begin - h::Form0 - Γ::Form1 - n::Constant - - ∂ₜ(h) == ∘(⋆, d, ⋆)(Γ * d(h) ∧ (mag(♯(d(h)))^(n-1)) ∧ (h^(n+2))) - end - d = expand_operators(d) - infer_types!(d) - resolve_overloads!(d) - @test d == @acset SummationDecapode{Any, Any, Symbol} begin - Var = 19 - TVar = 1 - Op1 = 8 - Op2 = 6 - Σ = 1 - Summand = 2 - src = [1, 1, 1, 13, 12, 6, 18, 19] - tgt = [4, 9, 13, 12, 11, 18, 19, 4] - proj1 = [2, 3, 11, 8, 1, 7] - proj2 = [9, 15, 14, 10, 5, 16] - res = [8, 14, 10, 7, 16, 6] - incl = [4] - summand = [3, 17] - summation = [1, 1] - sum = [5] - op1 = [:∂ₜ, :d₀, :d₀, :♯ᵖᵖ, :mag, :⋆₁, :dual_d₁, :⋆₀⁻¹] - op2 = [:*, :-, :^, :∧₁₀, :^, :∧₁₀] - type = [:Form0, :Form1, :Constant, :Form0, :infer, :Form1, :Form1, :Form1, :Form1, :Form0, :Form0, :PVF, :Form1, :infer, :Literal, :Form0, :Literal, :DualForm1, :DualForm2] - name = [:h, :Γ, :n, :ḣ, :sum_1, Symbol("•2"), Symbol("•3"), Symbol("•4"), Symbol("•5"), Symbol("•6"), Symbol("•7"), Symbol("•8"), Symbol("•9"), Symbol("•10"), Symbol("1"), Symbol("•11"), Symbol("2"), Symbol("•_6_1"), Symbol("•_6_2")] - end - end - + @test_throws DecaTypeExeception type_check(HeatXfer) + @test HeatXfer[12, :op2] == :* + @test HeatXfer[40, :type] == :DualForm1 && HeatXfer[39, :type] == :Form1 end @testset "Compilation Transformation" begin @@ -1125,7 +1213,7 @@ end t1_contracted = contract_operators(t1_expanded) @test t1_orig == t1_contracted # contract_operators does not mutate its argument. - @test t1_contracted !== t1_expanded + @test !(t1_contracted === t1_expanded) # contract_operators works on multiple chains. Test2 = quote @@ -1181,7 +1269,7 @@ end t6_rec_del = recursive_delete_parents(t6_orig, Vector{Int64}()) @test t6_orig == SummationDecapode{Any, Any, Symbol}() # recursive_delete_parents does not mutate its argument. - @test t6_orig !== t6_rec_del + @test !(t6_orig === t6_rec_del) # recursive_delete_parents deletes a chain of single-child parents. t7_orig = SummationDecapode(parse_decapode(quote