diff --git a/Project.toml b/Project.toml index b40b43a..69f41e4 100644 --- a/Project.toml +++ b/Project.toml @@ -2,7 +2,7 @@ name = "DiagrammaticEquations" uuid = "6f00c28b-6bed-4403-80fa-30e0dc12f317" license = "MIT" authors = ["James Fairbanks", "Andrew Baas", "Evan Patterson", "Luke Morris", "George Rauta"] -version = "0.1.6" +version = "0.1.7" [deps] ACSets = "227ef7b5-1206-438b-ac65-934d6da304b8" diff --git a/src/DiagrammaticEquations.jl b/src/DiagrammaticEquations.jl index be9cace..558a839 100644 --- a/src/DiagrammaticEquations.jl +++ b/src/DiagrammaticEquations.jl @@ -2,6 +2,8 @@ """ module DiagrammaticEquations +using Catlab + export DerivOp, append_dot, normalize_unicode, infer_states, infer_types!, # Deca @@ -12,6 +14,7 @@ recursive_delete_parents, spacename, varname, unicode!, vec_to_dec!, Collage, collate, ## composition oapply, unique_by, unique_by!, OpenSummationDecapodeOb, OpenSummationDecapode, Open, default_composition_diagram, +apex, @relation, # Re-exported from Catlab ## acset SchDecapode, SchNamedDecapode, AbstractDecapode, AbstractNamedDecapode, NamedDecapode, SummationDecapode, contract_operators!, contract_operators, add_constant!, add_parameter, fill_names!, dot_rename!, is_expanded, expand_operators, infer_state_names, infer_terminal_names, recognize_types, @@ -25,12 +28,12 @@ unique_lits!, Plus, AppCirc1, Var, Tan, App1, App2, ## visualization to_graphviz_property_graph, typename, draw_composition, +to_graphviz, # Re-exported from Catlab ## rewrite average_rewrite, ## openoperators transfer_parents!, transfer_children!, replace_op1!, replace_op2!, replace_all_op1s!, replace_all_op2s! -using Catlab using Catlab.Theories import Catlab.Theories: otimes, oplus, compose, ⊗, ⊕, ⋅, associate, associate_unit, Ob, Hom, dom, codom using Catlab.Programs @@ -62,6 +65,7 @@ include("pretty.jl") include("colanguage.jl") include("openoperators.jl") include("symbolictheoryutils.jl") +include("graph_traversal.jl") include("deca/Deca.jl") include("learn/Learn.jl") include("SymbolicUtilsInterop.jl") @@ -69,4 +73,6 @@ include("SymbolicUtilsInterop.jl") @reexport using .Deca @reexport using .SymbolicUtilsInterop +include("acset2symbolic.jl") + end diff --git a/src/SymbolicUtilsInterop.jl b/src/SymbolicUtilsInterop.jl index 61502a8..c95fb54 100644 --- a/src/SymbolicUtilsInterop.jl +++ b/src/SymbolicUtilsInterop.jl @@ -1,6 +1,8 @@ module SymbolicUtilsInterop -using ..DiagrammaticEquations: AbstractDecapode, Quantity +using ACSets +using ..DiagrammaticEquations: AbstractDecapode, Quantity, DerivOp +using ..DiagrammaticEquations: recognize_types, fill_names!, make_sum_mult_unique! import ..DiagrammaticEquations: eval_eq!, SummationDecapode using ..decapodes using ..Deca @@ -14,6 +16,7 @@ struct SymbolicEquation{E} lhs::E rhs::E end +export SymbolicEquation Base.show(io::IO, e::SymbolicEquation) = begin print(io, e.lhs); print(io, " == "); print(io, e.rhs) @@ -48,7 +51,7 @@ function decapodes.Term(t::SymbolicUtils.BasicSymbolic) decapodes.Plus(termargs) elseif op == * decapodes.Mult(termargs) - elseif op == ∂ₜ + elseif op ∈ [DerivOp, ∂ₜ] decapodes.Tan(only(termargs)) elseif length(args) == 1 decapodes.App1(nameof(op, symtype.(args)...), termargs...) @@ -59,6 +62,9 @@ function decapodes.Term(t::SymbolicUtils.BasicSymbolic) end end end +# TODO subtraction is not parsed as such. e.g., +# a, b = @syms a::Scalar b::Scalar +# Term(a-b) = Plus(Term[Var(:a), Mult(Term[Lit(Symbol("-1")), Var(:b)])) decapodes.Term(x::Real) = decapodes.Lit(Symbol(x)) @@ -82,9 +88,9 @@ Example: SymbolicUtils.BasicSymbolic(context, Term(a)) ``` """ -function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,DataType}, t::decapodes.Term, __module__=@__MODULE__) +function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,DataType}, t::decapodes.Term) # user must import symbols into scope - ! = (f -> getfield(__module__, f)) + ! = (f -> getfield(@__MODULE__, f)) @match t begin Var(name) => SymbolicUtils.Sym{context[name]}(name) Lit(v) => Meta.parse(string(v)) @@ -95,17 +101,17 @@ function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,DataType}, t::decapode # see test/language.jl (f, x) -> (!(f))(x), fs; - init=BasicSymbolic(context, arg, __module__) + init=BasicSymbolic(context, arg) ) - App1(f, x) => (!(f))(BasicSymbolic(context, x, __module__)) - App2(f, x, y) => (!(f))(BasicSymbolic(context, x, __module__), BasicSymbolic(context, y, __module__)) - Plus(xs) => +(BasicSymbolic.(Ref(context), xs, Ref(__module__))...) - Mult(xs) => *(BasicSymbolic.(Ref(context), xs, Ref(__module__))...) - Tan(x) => ∂ₜ(BasicSymbolic(context, x, __module__)) + App1(f, x) => (!(f))(BasicSymbolic(context, x)) + App2(f, x, y) => (!(f))(BasicSymbolic(context, x), BasicSymbolic(context, y)) + Plus(xs) => +(BasicSymbolic.(Ref(context), xs)...) + Mult(xs) => *(BasicSymbolic.(Ref(context), xs)...) + Tan(x) => (!(DerivOp))(BasicSymbolic(context, x)) end end -function SymbolicContext(d::decapodes.DecaExpr, __module__=@__MODULE__) +function SymbolicContext(d::decapodes.DecaExpr) # associates each var to its sort... context = map(d.context) do j j.var => symtype(Deca.DECQuantity, j.dim, j.space) @@ -116,13 +122,13 @@ function SymbolicContext(d::decapodes.DecaExpr, __module__=@__MODULE__) end context = Dict{Symbol,DataType}(context) eqs = map(d.equations) do eq - SymbolicEquation{Symbolic}(BasicSymbolic.(Ref(context), [eq.lhs, eq.rhs], Ref(__module__))...) + SymbolicEquation{Symbolic}(BasicSymbolic.(Ref(context), [eq.lhs, eq.rhs])...) end SymbolicContext(vars, eqs) end function eval_eq!(eq::SymbolicEquation, d::AbstractDecapode, syms::Dict{Symbol, Int}, deletions::Vector{Int}) - eval_eq!(Equation(Term(eq.lhs), Term(eq.rhs)), d, syms, deletions) + eval_eq!(Eq(Term(eq.lhs), Term(eq.rhs)), d, syms, deletions) end """ function SummationDecapode(e::SymbolicContext) """ @@ -132,7 +138,7 @@ function SummationDecapode(e::SymbolicContext) foreach(e.vars) do var # convert Sort(var)::PrimalForm0 --> :Form0 - var_id = add_part!(d, :Var, name=var.name, type=nameof(Sort(var))) + var_id = add_part!(d, :Var, name=var.name, type=nameof(symtype(var))) symbol_table[var.name] = var_id end diff --git a/src/acset.jl b/src/acset.jl index 9665afd..4cfdaac 100644 --- a/src/acset.jl +++ b/src/acset.jl @@ -158,13 +158,16 @@ end # A collection of DecaType getters # TODO: This should be replaced by using a type hierarchy const ALL_TYPES = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, - :Literal, :Parameter, :Constant, :infer] + :PVF, :DVF, + :Literal, :Parameter, :Constant, :infer] const FORM_TYPES = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2] const PRIMALFORM_TYPES = [:Form0, :Form1, :Form2] const DUALFORM_TYPES = [:DualForm0, :DualForm1, :DualForm2] -const NONFORM_TYPES = [:Constant, :Parameter, :Literal, :infer] +const VECTORFIELD_TYPES = [:PVF, :DVF] + +const NON_EC_TYPES = [:Constant, :Parameter, :Literal, :infer] const USER_TYPES = [:Constant, :Parameter] const NUMBER_TYPES = [:Literal] const INFER_TYPES = [:infer] @@ -184,6 +187,7 @@ function recognize_types(d::AbstractNamedDecapode) isempty(unrecognized_types) || error("Types $unrecognized_types are not recognized. CHECK: $types") end +export recognize_types """ is_expanded(d::AbstractNamedDecapode) @@ -427,12 +431,12 @@ function safe_modifytype!(d::SummationDecapode, var_idx::Int, new_type::Symbol) end """ - filterfor_forms(types::AbstractVector{Symbol}) + filterfor_ec_types(types::AbstractVector{Symbol}) -Return any form type symbols. +Return any form or vector-field type symbols. """ -function filterfor_forms(types::AbstractVector{Symbol}) - conditions = x -> !(x in NONFORM_TYPES) +function filterfor_ec_types(types::AbstractVector{Symbol}) + conditions = x -> !(x in NON_EC_TYPES) filter(conditions, types) end @@ -447,29 +451,26 @@ function infer_sum_types!(d::SummationDecapode, Σ_idx::Int) types = d[idxs, :type] all(t != :infer for t in types) && return applied # We need not infer - forms = unique(filterfor_forms(types)) + ec_types = unique(filterfor_ec_types(types)) - form = @match length(forms) begin + ec_type = @match length(ec_types) begin 0 => return applied # We can not infer - 1 => only(forms) - _ => error("Type mismatch in summation $Σ_idx, all the following forms appear: $forms") + 1 => only(ec_types) + _ => error("Type mismatch in summation $Σ_idx, all the following forms appear: $ec_types") end for idx in idxs - applied |= safe_modifytype!(d, idx, form) + applied |= safe_modifytype!(d, idx, ec_type) end return applied end function apply_inference_rule_op1!(d::SummationDecapode, op1_id, rule) - type_src = d[d[op1_id, :src], :type] - type_tgt = d[d[op1_id, :tgt], :type] + score_src = (rule.src_type == d[d[op1_id, :src], :type]) + score_tgt = (rule.tgt_type == d[d[op1_id, :tgt], :type]) - score_src = (rule.src_type == type_src) - score_tgt = (rule.tgt_type == type_tgt) check_op = (d[op1_id, :op1] in rule.op_names) - if(check_op && (score_src + score_tgt == 1)) mod_src = safe_modifytype!(d, d[op1_id, :src], rule.src_type) mod_tgt = safe_modifytype!(d, d[op1_id, :tgt], rule.tgt_type) @@ -480,19 +481,15 @@ function apply_inference_rule_op1!(d::SummationDecapode, op1_id, rule) end function apply_inference_rule_op2!(d::SummationDecapode, op2_id, rule) - type_proj1 = d[d[op2_id, :proj1], :type] - type_proj2 = d[d[op2_id, :proj2], :type] - type_res = d[d[op2_id, :res], :type] + score_proj1 = (rule.proj1_type == d[d[op2_id, :proj1], :type]) + score_proj2 = (rule.proj2_type == d[d[op2_id, :proj2], :type]) + score_res = (rule.res_type == d[d[op2_id, :res], :type]) - score_proj1 = (rule.proj1_type == type_proj1) - score_proj2 = (rule.proj2_type == type_proj2) - score_res = (rule.res_type == type_res) check_op = (d[op2_id, :op2] in rule.op_names) - - if(check_op && (score_proj1 + score_proj2 + score_res == 2)) + if check_op && (score_proj1 + score_proj2 + score_res == 2) mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], rule.proj1_type) mod_proj2 = safe_modifytype!(d, d[op2_id, :proj2], rule.proj2_type) - mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type) + mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type) return mod_proj1 || mod_proj2 || mod_res end diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl new file mode 100644 index 0000000..330fa76 --- /dev/null +++ b/src/acset2symbolic.jl @@ -0,0 +1,81 @@ +using DiagrammaticEquations +using ACSets +using SymbolicUtils +using SymbolicUtils: BasicSymbolic, Symbolic + +export symbolic_rewriting + +const EQUALITY = (==) +const SymEqSym = SymbolicEquation{Symbolic} + +function symbolics_lookup(d::SummationDecapode) + Dict{Symbol, BasicSymbolic}(map(d[:name],d[:type]) do name,type + (name, decavar_to_symbolics(name, type)) + end) +end + +function decavar_to_symbolics(var_name::Symbol, var_type::Symbol, space = :I) + new_type = SymbolicUtils.symtype(Deca.DECQuantity, var_type, space) + SymbolicUtils.Sym{new_type}(var_name) +end + +function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_idx::Int, op_type::Symbol) + input_syms = getindex.(Ref(symvar_lookup), d[edge_inputs(d,op_idx,Val(op_type)), :name]) + output_sym = getindex.(Ref(symvar_lookup), d[edge_output(d,op_idx,Val(op_type)), :name]) + op_sym = getfield(@__MODULE__, edge_function(d,op_idx,Val(op_type))) + + S = promote_symtype(op_sym, input_syms...) + SymEqSym(output_sym, SymbolicUtils.Term{S}(op_sym, input_syms)) +end + +function to_symbolics(d::SummationDecapode) + symvar_lookup = symbolics_lookup(d) + map(e -> to_symbolics(d, symvar_lookup, e.index, e.name), topological_sort_edges(d)) +end + +function symbolic_rewriting(d::SummationDecapode, rewriter=identity) + d′ = infer_types!(deepcopy(d)) + eqns = merge_equations(d′) + to_acset(d′, map(rewriter, eqns)) +end + +# XXX SymbolicUtils.substitute swaps the order of multiplication. +# e.g. ∂ₜ(G) == κ*u becomes ∂ₜ(G) == u*κ +function merge_equations(d::SummationDecapode) + eqn_lookup, terminal_eqns = Dict(), SymEqSym[] + + foreach(to_symbolics(d)) do x + sub = SymbolicUtils.substitute(x.rhs, eqn_lookup) + push!(eqn_lookup, (x.lhs => sub)) + if x.lhs.name in infer_terminal_names(d) + push!(terminal_eqns, SymEqSym(x.lhs, sub)) + end + end + + map(terminal_eqns) do eqn + SymbolicUtils.Term{Number}(EQUALITY, [eqn.lhs, eqn.rhs]) + end +end + +function to_acset(d::SummationDecapode, sym_exprs) + literals = incident(d, :Literal, :type) + + outer_types = map([infer_states(d)..., infer_terminals(d)..., literals...]) do i + :($(d[i, :name])::$(d[i, :type])) + end + + #TODO: This step is breaking up summations + final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs) + reify!(exprs) = foreach(exprs) do x + if typeof(x) == Expr && x.head == :call + x.args[1] = nameof(x.args[1]) + reify!(x.args[2:end]) + end + end + reify!(final_exprs) + + deca_block = quote end + deca_block.args = [outer_types..., final_exprs...] + + ∘(infer_types!, SummationDecapode, parse_decapode)(deca_block) +end diff --git a/src/deca/ThDEC.jl b/src/deca/ThDEC.jl index 383274e..4fa4695 100644 --- a/src/deca/ThDEC.jl +++ b/src/deca/ThDEC.jl @@ -24,8 +24,16 @@ export DECQuantity # this ensures symtype doesn't recurse endlessly SymbolicUtils.symtype(::Type{S}) where S<:DECQuantity = S -struct Scalar <: DECQuantity end -export Scalar +abstract type AbstractScalar <: DECQuantity end + +struct InferredType <: DECQuantity end +export InferredType + +struct Scalar <: AbstractScalar end +struct Parameter <: AbstractScalar end +struct Const <: AbstractScalar end +struct Literal <: AbstractScalar end +export Scalar, Parameter, Const, Literal struct FormParams dim::Int @@ -85,6 +93,18 @@ Base.nameof(u::Type{<:DualForm}) = Symbol("DualForm"*"$(dim(u))") # ACTIVE PATTERNS +@active PatInferredType(T) begin + if T <: InferredType + Some(InferredType) + end +end + +@active PatInferredTypes(T) begin + if any(S->S<:InferredType, T) + Some(InferredType) + end +end + @active PatForm(T) begin if T <: Form Some(T) @@ -107,7 +127,7 @@ end export PatFormDim @active PatScalar(T) begin - if T <: Scalar + if T <: AbstractScalar Some(T) end end @@ -153,6 +173,7 @@ export isDualForm, isForm0, isForm1, isForm2 @operator d(S)::DECQuantity begin @match S begin + PatInferredType(_) => InferredType PatFormParams([i,d,s,n]) => Form{i+1,d,s,n} _ => throw(ExteriorDerivativeError(S)) end @@ -162,6 +183,7 @@ end @operator ★(S)::DECQuantity begin @match S begin + PatInferredType(_) => InferredType PatFormParams([i,d,s,n]) => Form{n-i,d,s,n} _ => throw(HodgeStarError(S)) end @@ -171,17 +193,23 @@ end @operator Δ(S)::DECQuantity begin @match S begin + PatInferredType(_) => InferredType PatForm(_) => promote_symtype(★ ∘ d ∘ ★ ∘ d, S) _ => throw(LaplacianError(S)) end @rule Δ(~x::isForm0) => ★(d(★(d(~x)))) @rule Δ(~x::isForm1) => ★(d(★(d(~x)))) + d(★(d(★(~x)))) + @rule Δ(~x::isForm2) => d(★(d(★(~x)))) end +@alias (Δ₀, Δ₁, Δ₂) => Δ + +# TODO: Determine what we need to do for .+ @operator +(S1, S2)::DECQuantity begin @match (S1, S2) begin + PatInferredTypes(_) => InferredType (PatScalar(_), PatScalar(_)) => Scalar - (PatScalar(_), PatFormParams([i,d,s,n])) || (PatFormParams([i,d,s,n]), PatScalar(_)) => S1 # commutativity + (PatScalar(_), PatFormParams([i,d,s,n])) || (PatFormParams([i,d,s,n]), PatScalar(_)) => Form{i, d, s, n} # commutativity (PatFormParams([i1,d1,s1,n1]), PatFormParams([i2,d2,s2,n2])) => begin if (i1 == i2) && (d1 == d2) && (s1 == s2) && (n1 == n2) Form{i1, d1, s1, n1} @@ -193,10 +221,13 @@ end end end -@operator -(S1, S2)::DECQuantity begin +(S1, S2) end +@operator -(S1, S2)::DECQuantity begin + promote_symtype(+, S1, S2) +end @operator *(S1, S2)::DECQuantity begin @match (S1, S2) begin + PatInferredTypes(_) => InferredType (PatScalar(_), PatScalar(_)) => Scalar (PatScalar(_), PatFormParams([i,d,s,n])) || (PatFormParams([i,d,s,n]), PatScalar(_)) => Form{i,d,s,n} _ => throw(BinaryOpError("multiply", S1, S2)) @@ -205,6 +236,7 @@ end @operator ∧(S1, S2)::DECQuantity begin @match (S1, S2) begin + PatInferredTypes(_) => InferredType (PatFormParams([i1,d1,s1,n1]), PatFormParams([i2,d2,s2,n2])) => begin (d1 == d2) && (s1 == s2) && (n1 == n2) || throw(WedgeOpError(S1, S2)) if i1 + i2 <= n1 @@ -219,9 +251,10 @@ end abstract type SortError <: Exception end -# struct WedgeDimError <: SortError end - -Base.nameof(s::Scalar) = :Constant +Base.nameof(s::Union{Literal,Type{Literal}}) = :Literal +Base.nameof(s::Union{Const, Type{Const}}) = :Constant +Base.nameof(s::Union{Parameter, Type{Parameter}}) = :Parameter +Base.nameof(s::Union{Scalar, Type{Scalar}}) = :Scalar function Base.nameof(f::Form; with_dim_parameter=false) dual = isdual(f) ? "Dual" : "" @@ -254,6 +287,8 @@ function Base.nameof(::typeof(∧), s1, s2) Symbol("∧$(as_sub(dim(s1)))$(as_sub(dim(s2)))") end +Base.nameof(::typeof(∂ₜ), s) = Symbol("∂ₜ($(nameof(s)))") + Base.nameof(::typeof(d), s) = Symbol("d$(as_sub(dim(s)))") Base.nameof(::typeof(Δ), s) = :Δ @@ -263,15 +298,20 @@ function Base.nameof(::typeof(★), s) Symbol("★$(as_sub(isdual(s) ? dim(space(s)) - dim(s) : dim(s)))$(inv)") end -function SymbolicUtils.symtype(::Type{<:Quantity}, qty::Symbol, space::Symbol) +# TODO: Check that form type is no larger than the ambient dimension +function SymbolicUtils.symtype(::Type{<:Quantity}, qty::Symbol, space::Symbol, dim::Int = 2) @match qty begin - :Scalar || :Constant => Scalar - :Form0 => PrimalForm{0, space, 1} - :Form1 => PrimalForm{1, space, 1} - :Form2 => PrimalForm{2, space, 1} - :DualForm0 => DualForm{0, space, 1} - :DualForm1 => DualForm{1, space, 1} - :DualForm2 => DualForm{2, space, 1} + :Scalar => Scalar + :Constant => Const + :Parameter => Parameter + :Literal => Literal + :Form0 => PrimalForm{0, space, dim} + :Form1 => PrimalForm{1, space, dim} + :Form2 => PrimalForm{2, space, dim} + :DualForm0 => DualForm{0, space, dim} + :DualForm1 => DualForm{1, space, dim} + :DualForm2 => DualForm{2, space, dim} + :Infer => InferredType _ => error("Received $qty") end end diff --git a/src/deca/deca_acset.jl b/src/deca/deca_acset.jl index 55520e3..e0f5580 100644 --- a/src/deca/deca_acset.jl +++ b/src/deca/deca_acset.jl @@ -5,7 +5,7 @@ using ..DiagrammaticEquations These are the default rules used to do type inference in the 1D exterior calculus. """ op1_inf_rules_1D = [ - # Rules for ∂ₜ + # Rules for ∂ₜ (src_type = :Form0, tgt_type = :Form0, op_names = [:∂ₜ,:dt]), (src_type = :Form1, tgt_type = :Form1, op_names = [:∂ₜ,:dt]), @@ -14,10 +14,10 @@ op1_inf_rules_1D = [ (src_type = :DualForm0, tgt_type = :DualForm1, op_names = [:d, :dual_d₀, :d̃₀]), # Rules for ⋆ - (src_type = :Form0, tgt_type = :DualForm1, op_names = [:⋆, :⋆₀, :star]), - (src_type = :Form1, tgt_type = :DualForm0, op_names = [:⋆, :⋆₁, :star]), - (src_type = :DualForm1, tgt_type = :Form0, op_names = [:⋆, :⋆₀⁻¹, :star_inv]), - (src_type = :DualForm0, tgt_type = :Form1, op_names = [:⋆, :⋆₁⁻¹, :star_inv]), + (src_type = :Form0, tgt_type = :DualForm1, op_names = [:★, :⋆, :⋆₀, :star]), + (src_type = :Form1, tgt_type = :DualForm0, op_names = [:★, :⋆, :⋆₁, :star]), + (src_type = :DualForm1, tgt_type = :Form0, op_names = [:★, :⋆, :⋆₀⁻¹, :star_inv]), + (src_type = :DualForm0, tgt_type = :Form1, op_names = [:★, :⋆, :⋆₁⁻¹, :star_inv]), # Rules for Δ (src_type = :Form0, tgt_type = :Form0, op_names = [:Δ, :Δ₀, :lapl]), @@ -29,13 +29,20 @@ op1_inf_rules_1D = [ # Rules for negation (src_type = :Form0, tgt_type = :Form0, op_names = [:neg, :(-)]), (src_type = :Form1, tgt_type = :Form1, op_names = [:neg, :(-)]), - + # Rules for the averaging operator (src_type = :Form0, tgt_type = :Form1, op_names = [:avg₀₁, :avg_01]), - + + # Rules for ♯. + (src_type = :Form1, tgt_type = :PVF, op_names = [:♯, :♯ᵖᵖ]), + (src_type = :DualForm1, tgt_type = :DVF, op_names = [:♯, :♯ᵈᵈ]), + + # Rules for ♭. + (src_type = :DVF, tgt_type = :Form1, op_names = [:♭, :♭ᵈᵖ]), + # Rules for magnitude/ norm - (src_type = :Form0, tgt_type = :Form0, op_names = [:mag, :norm]), - (src_type = :Form1, tgt_type = :Form1, op_names = [:mag, :norm])] + (src_type = :PVF, tgt_type = :Form0, op_names = [:mag, :norm]), + (src_type = :DVF, tgt_type = :DualForm0, op_names = [:mag, :norm])] op2_inf_rules_1D = [ # Rules for ∧₀₀, ∧₁₀, ∧₀₁ @@ -45,7 +52,7 @@ op2_inf_rules_1D = [ # Rules for L₀, L₁ (proj1_type = :Form1, proj2_type = :DualForm0, res_type = :DualForm0, op_names = [:L, :L₀]), - (proj1_type = :Form1, proj2_type = :DualForm1, res_type = :DualForm1, op_names = [:L, :L₁]), + (proj1_type = :Form1, proj2_type = :DualForm1, res_type = :DualForm1, op_names = [:L, :L₁]), # Rules for i₁ (proj1_type = :Form1, proj2_type = :DualForm1, res_type = :DualForm0, op_names = [:i, :i₁]), @@ -62,10 +69,10 @@ op2_inf_rules_1D = [ (proj1_type = :Form0, proj2_type = :Parameter, res_type = :Form0, op_names = [:/, :./, :*, :.*]), (proj1_type = :Form1, proj2_type = :Parameter, res_type = :Form1, op_names = [:/, :./, :*, :.*]), (proj1_type = :Form2, proj2_type = :Parameter, res_type = :Form2, op_names = [:/, :./, :*, :.*]),=# - + (proj1_type = :Form0, proj2_type = :Literal, res_type = :Form0, op_names = [:/, :./, :*, :.*, :^, :.^]), (proj1_type = :Form1, proj2_type = :Literal, res_type = :Form1, op_names = [:/, :./, :*, :.*, :^, :.^]), - + (proj1_type = :DualForm0, proj2_type = :Literal, res_type = :DualForm0, op_names = [:/, :./, :*, :.*, :^, :.^]), (proj1_type = :DualForm1, proj2_type = :Literal, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^]), @@ -79,11 +86,15 @@ op2_inf_rules_1D = [ (proj1_type = :Constant, proj2_type = :Form1, res_type = :Form1, op_names = [:/, :./, :*, :.*, :^, :.^]), (proj1_type = :Form0, proj2_type = :Constant, res_type = :Form0, op_names = [:/, :./, :*, :.*, :^, :.^]), (proj1_type = :Form1, proj2_type = :Constant, res_type = :Form1, op_names = [:/, :./, :*, :.*, :^, :.^]), - + (proj1_type = :Constant, proj2_type = :DualForm0, res_type = :DualForm0, op_names = [:/, :./, :*, :.*, :^, :.^]), (proj1_type = :Constant, proj2_type = :DualForm1, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^]), (proj1_type = :DualForm0, proj2_type = :Constant, res_type = :DualForm0, op_names = [:/, :./, :*, :.*, :^, :.^]), - (proj1_type = :DualForm1, proj2_type = :Constant, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^])] + (proj1_type = :DualForm1, proj2_type = :Constant, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^]), + + # These rules contain infer: + (proj1_type = :Form0, proj2_type = :infer, res_type = :Form0, op_names = [:^]), + (proj1_type = :DualForm0, proj2_type = :infer, res_type = :DualForm0, op_names = [:^])] """ These are the default rules used to do type inference in the 2D exterior calculus. @@ -101,13 +112,13 @@ op1_inf_rules_2D = [ (src_type = :DualForm1, tgt_type = :DualForm2, op_names = [:d, :dual_d₁, :d̃₁]), # Rules for ⋆ - (src_type = :Form0, tgt_type = :DualForm2, op_names = [:⋆, :⋆₀, :star]), - (src_type = :Form1, tgt_type = :DualForm1, op_names = [:⋆, :⋆₁, :star]), - (src_type = :Form2, tgt_type = :DualForm0, op_names = [:⋆, :⋆₂, :star]), + (src_type = :Form0, tgt_type = :DualForm2, op_names = [:★, :⋆, :⋆₀, :star]), + (src_type = :Form1, tgt_type = :DualForm1, op_names = [:★, :⋆, :⋆₁, :star]), + (src_type = :Form2, tgt_type = :DualForm0, op_names = [:★, :⋆, :⋆₂, :star]), - (src_type = :DualForm2, tgt_type = :Form0, op_names = [:⋆, :⋆₀⁻¹, :star_inv]), - (src_type = :DualForm1, tgt_type = :Form1, op_names = [:⋆, :⋆₁⁻¹, :star_inv]), - (src_type = :DualForm0, tgt_type = :Form2, op_names = [:⋆, :⋆₂⁻¹, :star_inv]), + (src_type = :DualForm2, tgt_type = :Form0, op_names = [:★, :⋆, :⋆₀⁻¹, :star_inv]), + (src_type = :DualForm1, tgt_type = :Form1, op_names = [:★, :⋆, :⋆₁⁻¹, :star_inv]), + (src_type = :DualForm0, tgt_type = :Form2, op_names = [:★, :⋆, :⋆₂⁻¹, :star_inv]), # Rules for Δ (src_type = :Form0, tgt_type = :Form0, op_names = [:Δ, :Δ₀, :lapl]), @@ -133,14 +144,17 @@ op1_inf_rules_2D = [ (src_type = :DualForm1, tgt_type = :DualForm1, op_names = [:neg, :(-)]), (src_type = :DualForm2, tgt_type = :DualForm2, op_names = [:neg, :(-)]), + # Rules for ♯. + (src_type = :Form1, tgt_type = :PVF, op_names = [:♯, :♯ᵖᵖ]), + (src_type = :DualForm1, tgt_type = :DVF, op_names = [:♯, :♯ᵈᵈ]), + + # Rules for ♭. + (src_type = :DVF, tgt_type = :Form1, op_names = [:♭, :♭ᵈᵖ]), + # Rules for magnitude/ norm - (src_type = :Form0, tgt_type = :Form0, op_names = [:norm, :mag]), - (src_type = :Form1, tgt_type = :Form1, op_names = [:norm, :mag]), - (src_type = :Form2, tgt_type = :Form2, op_names = [:norm, :mag]), - (src_type = :DualForm0, tgt_type = :DualForm0, op_names = [:norm, :mag]), - (src_type = :DualForm1, tgt_type = :DualForm1, op_names = [:norm, :mag]), - (src_type = :DualForm2, tgt_type = :DualForm2, op_names = [:norm, :mag])] - + (src_type = :PVF, tgt_type = :Form0, op_names = [:norm, :mag]), + (src_type = :DVF, tgt_type = :DualForm0, op_names = [:norm, :mag])] + op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ # Rules for ∧₁₁, ∧₂₀, ∧₀₂ (proj1_type = :Form1, proj2_type = :Form1, res_type = :Form2, op_names = [:∧, :∧₁₁, :wedge]), @@ -148,7 +162,7 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ (proj1_type = :Form0, proj2_type = :Form2, res_type = :Form2, op_names = [:∧, :∧₀₂, :wedge]), # Rules for L₂ - (proj1_type = :Form1, proj2_type = :DualForm2, res_type = :DualForm2, op_names = [:L, :L₂]), + (proj1_type = :Form1, proj2_type = :DualForm2, res_type = :DualForm2, op_names = [:L, :L₂]), # Rules for i₁ (proj1_type = :Form1, proj2_type = :DualForm2, res_type = :DualForm1, op_names = [:i, :i₂]), @@ -159,7 +173,7 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ # Rules for ι (proj1_type = :DualForm1, proj2_type = :DualForm1, res_type = :DualForm0, op_names = [:ι₁₁]), (proj1_type = :DualForm1, proj2_type = :DualForm2, res_type = :DualForm1, op_names = [:ι₁₂]), - + # Rules for subtraction (proj1_type = :Form0, proj2_type = :Form0, res_type = :Form0, op_names = [:-, :.-]), (proj1_type = :Form1, proj2_type = :Form1, res_type = :Form1, op_names = [:-, :.-]), @@ -198,7 +212,7 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ # Rules for Δ (src_type = :Form0, tgt_type = :Form0, resolved_name = :Δ₀, op = :Δ), (src_type = :Form1, tgt_type = :Form1, resolved_name = :Δ₁, op = :Δ)] - + # We merge 1D and 2D rules since it seems op2 rules are metric-free. If # this assumption is false, this needs to change. op2_res_rules_1D = [ @@ -214,8 +228,8 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ (proj1_type = :Form1, proj2_type = :DualForm1, res_type = :DualForm1, resolved_name = :L₁, op = :L), # Rules for i. (proj1_type = :Form1, proj2_type = :DualForm1, res_type = :DualForm0, resolved_name = :i₁, op = :i)] - - + + """ These are the default rules used to do function resolution in the 2D exterior calculus. """ @@ -243,6 +257,11 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ (src_type = :Form1, tgt_type = :Form0, resolved_name = :δ₁, op = :δ), (src_type = :Form2, tgt_type = :Form1, resolved_name = :δ₂, op = :codif), (src_type = :Form1, tgt_type = :Form0, resolved_name = :δ₁, op = :codif), + # Rules for ♯. + (src_type = :Form1, tgt_type = :PVF, resolved_name = :♯ᵖᵖ, op = :♯), + (src_type = :DualForm1, tgt_type = :DVF, resolved_name = :♯ᵈᵈ, op = :♯), + # Rules for ♭. + (src_type = :DVF, tgt_type = :Form1, resolved_name = :♭ᵈᵖ, op = :♭), # Rules for ∇². # TODO: Call this :nabla2 in ASCII? (src_type = :Form0, tgt_type = :Form0, resolved_name = :∇²₀, op = :∇²), @@ -255,7 +274,7 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ (src_type = :Form0, tgt_type = :Form0, resolved_name = :Δ₀, op = :lapl), (src_type = :Form1, tgt_type = :Form1, resolved_name = :Δ₁, op = :lapl), (src_type = :Form1, tgt_type = :Form1, resolved_name = :Δ₂, op = :lapl)] - + # We merge 1D and 2D rules directly here since it seems op2 rules # are metric-free. If this assumption is false, this needs to change. op2_res_rules_2D = vcat(op2_res_rules_1D, [ @@ -271,7 +290,7 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ # (proj1_type = :Form1, proj2_type = :DualForm2, res_type = :DualForm2, resolved_name = :L₂ᵈ, op = :L), # Rules for i. (proj1_type = :Form1, proj2_type = :DualForm2, res_type = :DualForm1, resolved_name = :i₂, op = :i)]) - + # TODO: When SummationDecapodes are annotated with the degree of their space, # use dispatch to choose the correct set of rules. infer_types!(d::SummationDecapode) = @@ -339,4 +358,3 @@ Resolve function overloads based on types of src and tgt. """ resolve_overloads!(d::SummationDecapode) = resolve_overloads!(d, op1_res_rules_2D, op2_res_rules_2D) - diff --git a/src/graph_traversal.jl b/src/graph_traversal.jl new file mode 100644 index 0000000..e41048e --- /dev/null +++ b/src/graph_traversal.jl @@ -0,0 +1,73 @@ +using DiagrammaticEquations +using ACSets + +export TraversalNode, topological_sort_edges, n_ops, retrieve_name, start_nodes, edge_inputs, edge_output, edge_function + +struct TraversalNode{T} + index::Int + name::T +end + +edge_inputs(d::SummationDecapode, idx::Int, ::Val{:Op1}) = + [d[idx,:src]] +edge_inputs(d::SummationDecapode, idx::Int, ::Val{:Op2}) = + [d[idx,:proj1], d[idx,:proj2]] +edge_inputs(d::SummationDecapode, idx::Int, ::Val{:Σ}) = + d[incident(d, idx, :summation), :summand] + +edge_output(d::SummationDecapode, idx::Int, ::Val{:Op1}) = + d[idx,:tgt] +edge_output(d::SummationDecapode, idx::Int, ::Val{:Op2}) = + d[idx,:res] +edge_output(d::SummationDecapode, idx::Int, ::Val{:Σ}) = + d[idx, :sum] + +edge_function(d::SummationDecapode, idx::Int, ::Val{:Op1}) = + d[idx,:op1] +edge_function(d::SummationDecapode, idx::Int, ::Val{:Op2}) = + d[idx,:op2] +edge_function(d::SummationDecapode, idx::Int, ::Val{:Σ}) = + :+ + +#XXX: This topological sort is O(n^2). +function topological_sort_edges(d::SummationDecapode) + visited_Var = falses(nparts(d, :Var)) + visited_Var[start_nodes(d)] .= true + visited = Dict(:Op1 => falses(nparts(d, :Op1)), + :Op2 => falses(nparts(d, :Op2)), :Σ => falses(nparts(d, :Σ))) + + op_order = TraversalNode{Symbol}[] + + function visit(op, op_type) + if !visited[op_type][op] && all(visited_Var[edge_inputs(d,op,Val(op_type))]) + visited[op_type][op] = true + visited_Var[edge_output(d,op,Val(op_type))] = true + push!(op_order, TraversalNode(op, op_type)) + end + end + + for _ in 1:n_ops(d) + visit.(parts(d,:Op1), :Op1) + visit.(parts(d,:Op2), :Op2) + visit.(parts(d,:Σ), :Σ) + end + + @assert length(op_order) == n_ops(d) + op_order +end + +n_ops(d::SummationDecapode) = + nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, :Σ) + +start_nodes(d::SummationDecapode) = + vcat(infer_states(d), incident(d, :Literal, :type)) + +function retrieve_name(d::SummationDecapode, tsr::TraversalNode) + @match tsr.name begin + :Op1 => d[tsr.index, :op1] + :Op2 => d[tsr.index, :op2] + :Σ => :+ + _ => error("$(tsr.name) is a table without names") + end +end + diff --git a/src/symbolictheoryutils.jl b/src/symbolictheoryutils.jl index 0c03a0a..7746559 100644 --- a/src/symbolictheoryutils.jl +++ b/src/symbolictheoryutils.jl @@ -43,7 +43,7 @@ Creates an operator `foo` with arguments which are types in a given Theory. This (@rule expr1) ... (@rule exprN) -end +end ``` builds ``` @@ -73,7 +73,7 @@ end """ macro operator(head, body) - # parse body + # parse head ph = @λ begin Expr(:call, foo, Expr(:(::), vars..., theory)) => (foo, vars, theory) Expr(:(::), Expr(:call, foo, vars...), theory) => (foo, vars, theory) @@ -81,7 +81,7 @@ macro operator(head, body) end (f, types, Theory) = ph(head) - # Passing types to functions requires that we type the signature with ::Type{T}. + # Passing types to functions requires that we type the signature with ::Type{T}. # This means that the user would have to write `my_op(::Type{T1}, ::Type{T2}, ...)` # As a convenience to the user, we allow them to specify the signature using just the types themselves: # `my_op(T1, T2, ...)` @@ -89,15 +89,15 @@ macro operator(head, body) sort_constraints = [:($S<:$Theory) for S in types] arity = length(sort_types) - # Parse the body for @rule calls. + # Parse the body for @rule calls. block, rulecalls = @match Base.remove_linenums!(body) begin Expr(:block, block, rules...) => (block, rules) s => nothing end - + # initialize the result result = quote end - + # construct the function on basic symbolics argnames = [gensym(:x) for _ in 1:arity] argclaus = [:($a::Symbolic) for a in argnames] @@ -107,15 +107,21 @@ macro operator(head, body) s = promote_symtype($f, $(argnames...)) SymbolicUtils.Term{s}($f, Any[$(argnames...)]) end + export $f + + # Base.show(io::IO, ::typeof($f)) = print(io, $f) end) - # if there are rewriting rules, add a method which accepts the function symbol and its arity (to prevent shadowing on operators like `-`) + # if there are rewriting rules, add a method which accepts the function symbol and its arity + # (to prevent shadowing on operators like `-`) if !isempty(rulecalls) push!(result.args, quote function rules(::typeof($f), ::Val{$arity}) [($(rulecalls...))] end + + rules(::typeof($f)) = rules($f, Val{1}) end) end @@ -150,18 +156,18 @@ macro alias(body) result = quote end foreach(aliases) do alias push!(result.args, - esc(quote + quote function $alias(s...) - $rep(s...) + $rep(s...) end export $alias + Base.nameof(::typeof($alias), s) = Symbol("$alias") - end)) + end) end - result + return esc(result) end export @alias alias(x) = error("$x has no aliases") export alias - diff --git a/test/Project.toml b/test/Project.toml index 97b4819..4f3a02a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,4 +6,5 @@ CombinatorialSpaces = "b1c52339-7909-45ad-8b6a-6e388f7c67f2" DiagrammaticEquations = "6f00c28b-6bed-4403-80fa-30e0dc12f317" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" +SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/acset2symbolic.jl b/test/acset2symbolic.jl new file mode 100644 index 0000000..0212a7e --- /dev/null +++ b/test/acset2symbolic.jl @@ -0,0 +1,286 @@ +using Test +using DiagrammaticEquations +using SymbolicUtils: Fixpoint, Prewalk, Postwalk, Chain, @rule +using Catlab + +(≃) = is_isomorphic + +@testset "Basic Roundtrip" begin + op1_only = @decapode begin + A::Form0 + B::Form1 + B == d(A) + end + + @test op1_only == symbolic_rewriting(op1_only) + + + op2_only = @decapode begin + A::Constant + B::Form0 + C::Form0 + + C == A * B + end + + @test op2_only == symbolic_rewriting(op2_only) + + sum_only = @decapode begin + A::Form2 + B::Form2 + C::Form2 + + C == A + B + end + + @test sum_only == symbolic_rewriting(sum_only) + + + multi_sum = @decapode begin + A::Form2 + B::Form2 + C::Form2 + D::Form2 + + D == (A + B) + C + end + infer_types!(multi_sum) + + # TODO: This is correct but the symbolics is splitting up the sum + @test multi_sum == symbolic_rewriting(multi_sum) + + + all_ops = @decapode begin + A::Constant + B::Form0 + C::Form1 + D::Form1 + E::Form1 + F::Form1 + + C == d(B) + D == A * C + F == D + E + end + + # This loses the intermediate names C and D + all_ops_res = symbolic_rewriting(all_ops) + all_ops_res[5, :name] = :D + all_ops_res[6, :name] = :C + @test all_ops ≃ all_ops_res + + with_deriv = @decapode begin + A::Form0 + Ȧ::Form0 + + ∂ₜ(A) == Ȧ + Ȧ == Δ(A) + end + + @test with_deriv == symbolic_rewriting(with_deriv) + + repeated_vars = @decapode begin + A::Form0 + B::Form1 + C::Form1 + + C == d(A) + C == Δ(B) + C == d(A) + end + + @test repeated_vars == symbolic_rewriting(repeated_vars) + + # TODO: This is broken because of the terminals issue in #77 + self_changing = @decapode begin + c_exp == ∂ₜ(c_exp) + end + + @test_broken repeated_vars == symbolic_rewriting(self_changing) + + literal = @decapode begin + A::Form0 + B::Form0 + + B == A * 2 + end + + @test literal == symbolic_rewriting(literal) + + parameter = @decapode begin + A::Form0 + P::Parameter + B::Form0 + + B == A * P + end + + @test parameter == symbolic_rewriting(parameter) + + constant = @decapode begin + A::Form0 + C::Constant + B::Form0 + + B == A * C + end + + @test constant == symbolic_rewriting(constant) +end + +function expr_rewriter(rules::Vector) + return Fixpoint(Prewalk(Fixpoint(Chain(rules)))) +end + +@testset "Basic Rewriting" begin + op1s = @decapode begin + A::Form0 + B::Form2 + C::Form2 + + C == B + B + d(d(A)) + end + + dd_0 = @rule d(d(~x)) => 0 + + op1s_rewritten = symbolic_rewriting(op1s, expr_rewriter([dd_0])) + + op1s_equiv = @decapode begin + A::Form0 + B::Form2 + C::Form2 + + C == 2 * B + end + + @test op1s_equiv == op1s_rewritten + + + op2s = @decapode begin + A::Form0 + B::Form0 + C::Form0 + D::Form0 + + + D == ∧(∧(A, B), C) + end + + wdg_assoc = @rule ∧(∧(~x, ~y), ~z) => ∧(~x, ∧(~y, ~z)) + + op2s_rewritten = symbolic_rewriting(op2s, expr_rewriter([wdg_assoc])) + + op2s_equiv = @decapode begin + A::Form0 + B::Form0 + C::Form0 + D::Form0 + + + D == ∧(A, ∧(B, C)) + end + infer_types!(op2s_equiv) + + @test op2s_equiv == op2s_rewritten + + + distr_d = @decapode begin + A::Form0 + B::Form1 + C::Form2 + + C == d(∧(A, B)) + end + infer_types!(distr_d) + + leibniz = @rule d(∧(~x, ~y)) => ∧(d(~x), ~y) + ∧(~x, d(~y)) + + distr_d_rewritten = symbolic_rewriting(distr_d, expr_rewriter([leibniz])) + + distr_d_res = @decapode begin + A::Form0 + B::Form1 + C::Form2 + + C == ∧(d(A), B) + ∧(A, d(B)) + end + infer_types!(distr_d_res) + + @test distr_d_res == distr_d_rewritten +end + +@testset "Heat" begin + Heat = @decapode begin + C::Form0 + G::Form0 + D::Constant + ∂ₜ(G) == D*Δ(C) + end + infer_types!(Heat) + + # Same up to re-naming + Heat[5, :name] = Symbol("•1") + @test Heat == symbolic_rewriting(Heat) + + Heat_open = @decapode begin + C::Form0 + G::Form0 + D::Constant + ∂ₜ(G) == D*★(d(★(d(C)))) + end + infer_types!(Heat_open) + + Heat_open[8, :name] = Symbol("•1") + Heat_open[5, :name] = Symbol("•2") + Heat_open[6, :name] = Symbol("•3") + Heat_open[7, :name] = Symbol("•4") + + @test Heat_open ≃ symbolic_rewriting(Heat, expr_rewriter(rules(Δ, Val(1)))) +end + +@testset "Phytodynamics" begin + Phytodynamics = @decapode begin + (n,w)::Form0 + m::Constant + ∂ₜ(n) == w + m*n + Δ(n) + end + infer_types!(Phytodynamics) + test_phy = symbolic_rewriting(Phytodynamics) +end + +@testset "Literals" begin + Heat = parse_decapode(quote + C::Form0 + G::Form0 + ∂ₜ(G) == 3*Δ(C) + end) + context = SymbolicContext(Heat) + SummationDecapode(context) + +end + +@testset "Parameters" begin + + Heat = @decapode begin + u::Form0 + G::Form0 + κ::Parameter + ∂ₜ(G) == Δ(u)*κ + end + infer_types!(Heat) + + Heat_open = @decapode begin + u::Form0 + G::Form0 + κ::Parameter + ∂ₜ(G) == ★(d(★(d(u))))*κ + end + infer_types!(Heat_open) + + Heat_open[7, :name] = Symbol("•4") + Heat_open[8, :name] = Symbol("•1") + + z = symbolic_rewriting(Heat, expr_rewriter(rules(Δ, Val(1)))) + @test Heat_open ≃ z + +end diff --git a/test/composition.jl b/test/composition.jl index f0c5c02..408cf37 100644 --- a/test/composition.jl +++ b/test/composition.jl @@ -2,15 +2,12 @@ using Test using DiagrammaticEquations using DiagrammaticEquations.Deca using Catlab -using Catlab.WiringDiagrams -using Catlab.Programs -using Catlab.CategoricalAlgebra # import DiagrammaticEquations: OpenSummationDecapode, Open, oapply, oapply_rename # @testset "Composition" begin # Simplest possible decapode relation. -Trivial = @decapode begin +Trivial = @decapode begin H::Form0{X} end diff --git a/test/decasymbolic.jl b/test/decasymbolic.jl index 0b5b7e9..e9bb0df 100644 --- a/test/decasymbolic.jl +++ b/test/decasymbolic.jl @@ -6,31 +6,49 @@ using SymbolicUtils using SymbolicUtils: symtype, promote_symtype, Symbolic using MLStyle +import DiagrammaticEquations: rules + # load up some variable variables and expressions -a, b = @syms a::Scalar b::Scalar -u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} -ω, η = @syms ω::PrimalForm{1, :X, 2} η::DualForm{2, :X, 2} -ϕ, ψ = @syms ϕ::PrimalVF{:X, 2} ψ::DualVF{:X, 2} +ϐ, = @syms ϐ::InferredType # \varbeta +ℓ, = @syms ℓ::Literal +c, t = @syms c::Const t::Parameter +a, b = @syms a::Scalar b::Scalar +u, du = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} +ω, η = @syms ω::PrimalForm{1, :X, 2} η::DualForm{2, :X, 2} +ϕ, ψ = @syms ϕ::PrimalVF{:X, 2} ψ::DualVF{:X, 2} # TODO would be nice to pass the space globally to avoid duplication - @testset "Term Construction" begin - + + @test symtype(ϐ) == InferredType + @test symtype(ℓ) == Literal + @test symtype(c) == Const + @test symtype(t) == Parameter @test symtype(a) == Scalar + @test symtype(u) == PrimalForm{0, :X, 2} @test symtype(ω) == PrimalForm{1, :X, 2} @test symtype(η) == DualForm{2, :X, 2} @test symtype(ϕ) == PrimalVF{:X, 2} @test symtype(ψ) == DualVF{:X, 2} + @test symtype(c + t) == Scalar + @test symtype(t + t) == Scalar + @test symtype(c + c) == Scalar + @test symtype(t + ϐ) == InferredType + @test symtype(u ∧ ω) == PrimalForm{1, :X, 2} @test symtype(ω ∧ ω) == PrimalForm{2, :X, 2} + @test symtype(u ∧ ϐ) == InferredType + # @test_throws ThDEC.SortError ThDEC.♯(u) @test symtype(Δ(u) + Δ(u)) == PrimalForm{0, :X, 2} # test unary operator conversion to decaexpr @test Term(1) == Lit(Symbol("1")) @test Term(a) == Var(:a) + @test Term(c) == Var(:c) + @test Term(t) == Var(:t) @test Term(∂ₜ(u)) == Tan(Var(:u)) @test Term(★(ω)) == App1(:★₁, Var(:ω)) @@ -44,95 +62,97 @@ u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} @test promote_symtype(+, a, b) == Scalar @test promote_symtype(∧, u, u) == PrimalForm{0, :X, 2} @test promote_symtype(∧, u, ω) == PrimalForm{1, :X, 2} + @test promote_symtype(-, a) == Scalar + @test promote_symtype(-, u, u) == PrimalForm{0, :X, 2} # test composition @test promote_symtype(d ∘ d, u) == PrimalForm{2, :X, 2} end -@testset "Operator definition" begin +# this is not nabla but "bizarro Δ" +del_expand_0, del_expand_1 = @operator ∇(S)::DECQuantity begin + @match S begin + PatScalar(_) => error("Argument of type $S is invalid") + PatForm(_) => promote_symtype(★ ∘ d ∘ ★ ∘ d, S) + end + @rule ∇(~x::isForm0) => ★(d(★(d(~x)))) + @rule ∇(~x::isForm1) => ★(d(★(d(~x)))) + d(★(d(★(~x)))) +end; + +# we will test is new operator +(r0, r1, r2) = @operator ρ(S)::DECQuantity begin + S <: Form ? Scalar : Form + @rule ρ(~x::isForm0) => 0 + @rule ρ(~x::isForm1) => 1 + @rule ρ(~x::isForm2) => 2 +end + +R, = @operator φ(S1, S2, S3)::DECQuantity begin + let T1=S1, T2=S2, T3=S3 + Scalar + end + @rule φ(2(~x::isForm0), 2(~y::isForm0), 2(~z::isForm0)) => 2*φ(~x,~y,~z) +end - # this is not nabla but "bizarro Δ" - del_expand_0, del_expand_1 = - @operator ∇(S)::DECQuantity begin - @match S begin - PatScalar(_) => error("Argument of type $S is invalid") - PatForm(_) => promote_symtype(★ ∘ d ∘ ★ ∘ d, S) - end - @rule ∇(~x::isForm0) => ★(d(★(d(~x)))) - @rule ∇(~x::isForm1) => ★(d(★(d(~x)))) + d(★(d(★(~x)))) - end; +@alias (φ′,) => φ +@testset "Operator definition" begin + + # ∇ @test_throws Exception ∇(b) @test symtype(∇(u)) == PrimalForm{0, :X ,2} @test promote_symtype(∇, u) == PrimalForm{0, :X, 2} - @test isequal(del_expand_0(∇(u)), ★(d(★(d(u))))) - - # we will test is new operator - (r0, r1, r2) = @operator ρ(S)::DECQuantity begin - if S <: Form - Scalar - else - Form - end - @rule ρ(~x::isForm0) => 0 - @rule ρ(~x::isForm1) => 1 - @rule ρ(~x::isForm2) => 2 - end - + + # ρ @test symtype(ρ(u)) == Scalar - R, = @operator φ(S1, S2, S3)::DECQuantity begin - let T1=S1, T2=S2, T3=S3 - Scalar - end - @rule φ(2(~x::isForm0), 2(~y::isForm0), 2(~z::isForm0)) => 2*φ(~x,~y,~z) - end - - # TODO we need to alias rewriting rules - @alias (φ′,) => φ - + # R @test isequal(R(φ(2u,2u,2u)), R(φ′(2u,2u,2u))) + # TODO we need to alias rewriting rules end @testset "Conversion" begin - context = Dict(:a => Scalar(),:b => Scalar() - ,:u => PrimalForm(0, X),:du => PrimalForm(1, X)) - js = [Judgement(:u, :Form0, :X) - ,Judgement(:∂ₜu, :Form0, :X) - ,Judgement(:Δu, :Form0, :X)] - eqs = [Eq(Var(:∂ₜu) - , AppCirc1([:⋆₂⁻¹, :d₁, :⋆₁, :d₀], Var(:u))) - , Eq(Tan(Var(:u)), Var(:∂ₜu))] - heat_eq = DecaExpr(js, eqs) - - symb_heat_eq = DecaSymbolic(lookup, heat_eq) - deca_expr = DecaExpr(symb_heat_eq) - -end + Exp = @decapode begin + u::Form0 + v::Form0 + ∂ₜ(v) == u + end + context = SymbolicContext(Term(Exp)) + Exp′ = SummationDecapode(DecaExpr(context)) -@testset "Moving between DecaExpr and DecaSymbolic" begin - - @test js == deca_expr.context - - # eqs in the left has AppCirc1[vector, term] - # deca_expr.equations on the right has nested App1 - # expected behavior is that nested AppCirc1 is preserved - @test_broken eqs == deca_expr.equations - # use expand_operators to get rid of parentheses - # infer_types and resolve_overloads - -end + # does roundtripping work + @test Exp == Exp′ -# convert both into ACSets then is_iso them -@testset "" begin + Heat = @decapode begin + u::Form0 + v::Form0 + κ::Constant + ∂ₜ(v) == Δ(u)*κ + end + infer_types!(Heat) + context = SymbolicContext(Term(Heat)) + Heat′ = SummationDecapode(DecaExpr(context)) - Σ = DiagrammaticEquations.SummationDecapode(deca_expr) - Δ = DiagrammaticEquations.SummationDecapode(symb_heat_eq) - @test Σ == Δ + @test Heat == Heat′ + TumorInvasion = @decapode begin + (C,fC)::Form0 + (Dif,Kd,Cmax)::Constant + ∂ₜ(C) == Dif * Δ(C) + fC - C * Kd + end + infer_types!(TumorInvasion) + context = SymbolicContext(Term(TumorInvasion)) + TumorInvasion′ = SummationDecapode(DecaExpr(context)) + + # new terms introduced because Symbolics converts subtraction expressions + # e.g., a - b => +(a, -b) + @test_broken TumorInvasion == TumorInvasion′ + # TI' has (11, Literal, -1) and (12, infer, mult_1) + # Op1 (2, 1, 4, 7) should be (2, 4, 1, 7) + # Sum is (1, 6), (2, 10) end diff --git a/test/graph_traversal.jl b/test/graph_traversal.jl new file mode 100644 index 0000000..b09fd24 --- /dev/null +++ b/test/graph_traversal.jl @@ -0,0 +1,64 @@ +using DiagrammaticEquations +using ACSets +using MLStyle +using Test + +function is_correct_length(d::SummationDecapode, result) + return length(result) == n_ops(d) +end + +@testset "Topological Sort on Edges" begin + no_edge = @decapode begin + F == S + end + @test isempty(topological_sort_edges(no_edge)) + + one_op1_deca = @decapode begin + F == f(S) + end + result = topological_sort_edges(one_op1_deca) + @test is_correct_length(one_op1_deca, result) + @test retrieve_name(one_op1_deca, only(result)) == :f + + multi_op1_deca = @decapode begin + F == c(b(a(S))) + end + result = topological_sort_edges(multi_op1_deca) + @test is_correct_length(multi_op1_deca, result) + for (edge, test_name) in zip(result, [:a, :b, :c]) + @test retrieve_name(multi_op1_deca, edge) == test_name + end + + cyclic = @decapode begin + B == g(A) + A == f(B) + end + @test_throws AssertionError topological_sort_edges(cyclic) + + just_op2 = @decapode begin + C == A * B + end + result = topological_sort_edges(just_op2) + @test is_correct_length(just_op2, result) + @test retrieve_name(just_op2, only(result)) == :* + + just_simple_sum = @decapode begin + C == A + B + end + result = topological_sort_edges(just_simple_sum) + @test is_correct_length(just_simple_sum, result) + @test retrieve_name(just_simple_sum, only(result)) == :+ + + just_multi_sum = @decapode begin + F == A + B + C + D + E + end + result = topological_sort_edges(just_multi_sum) + @test is_correct_length(just_multi_sum, result) + @test retrieve_name(just_multi_sum, only(result)) == :+ + + op_combo = @decapode begin + F == h(d(A) + f(g(B) * C) + D) + end + result = topological_sort_edges(op_combo) + @test is_correct_length(op_combo, result) +end diff --git a/test/language.jl b/test/language.jl index 3e3ab4e..25f817e 100644 --- a/test/language.jl +++ b/test/language.jl @@ -1,11 +1,5 @@ using Test using Catlab -using Catlab.Theories -using Catlab.CategoricalAlgebra -using Catlab.WiringDiagrams -using Catlab.WiringDiagrams.DirectedWiringDiagrams -using Catlab.Graphics -using Catlab.Programs using LinearAlgebra using MLStyle using Base.Iterators @@ -356,13 +350,14 @@ end @test issetequal([:V,:X,:k], infer_state_names(oscillator)) end -import DiagrammaticEquations: ALL_TYPES, FORM_TYPES, PRIMALFORM_TYPES, DUALFORM_TYPES, - NONFORM_TYPES, USER_TYPES, NUMBER_TYPES, INFER_TYPES, NONINFERABLE_TYPES +import DiagrammaticEquations: ALL_TYPES, FORM_TYPES, PRIMALFORM_TYPES, + DUALFORM_TYPES, VECTORFIELD_TYPES, NON_EC_TYPES, USER_TYPES, NUMBER_TYPES, + INFER_TYPES, NONINFERABLE_TYPES @testset "Type Retrival" begin type_groups = [ALL_TYPES, FORM_TYPES, PRIMALFORM_TYPES, DUALFORM_TYPES, - NONFORM_TYPES, USER_TYPES, NUMBER_TYPES, INFER_TYPES, NONINFERABLE_TYPES] + NON_EC_TYPES, USER_TYPES, NUMBER_TYPES, INFER_TYPES, NONINFERABLE_TYPES] # No repeated types @@ -374,12 +369,12 @@ import DiagrammaticEquations: ALL_TYPES, FORM_TYPES, PRIMALFORM_TYPES, DUALFORM_ no_overlaps(types_1, types_2) = isempty(intersect(types_1, types_2)) # Collections of these types should be the same - @test equal_types(ALL_TYPES, vcat(FORM_TYPES, NONFORM_TYPES)) - @test equal_types(FORM_TYPES, vcat(PRIMALFORM_TYPES, DUALFORM_TYPES)) - @test equal_types(NONINFERABLE_TYPES, vcat(USER_TYPES, NUMBER_TYPES)) + @test equal_types(ALL_TYPES, FORM_TYPES ∪ VECTORFIELD_TYPES ∪ NON_EC_TYPES) + @test equal_types(FORM_TYPES, PRIMALFORM_TYPES ∪ DUALFORM_TYPES) + @test equal_types(NONINFERABLE_TYPES, USER_TYPES ∪ NUMBER_TYPES) # Proper seperation of types - @test no_overlaps(FORM_TYPES, NONFORM_TYPES) + @test no_overlaps(FORM_TYPES ∪ VECTORFIELD_TYPES, NON_EC_TYPES) @test no_overlaps(PRIMALFORM_TYPES, DUALFORM_TYPES) @test no_overlaps(NONINFERABLE_TYPES, FORM_TYPES) @test INFER_TYPES == [:infer] @@ -400,9 +395,9 @@ end import DiagrammaticEquations: safe_modifytype @testset "Safe Type Modification" begin - all_types = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :Literal, :Constant, :Parameter, :infer] + all_types = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :Literal, :Constant, :Parameter, :PVF, :DVF, :infer] bad_sources = [:Literal, :Constant, :Parameter] - good_sources = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :infer] + good_sources = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :PVF, :DVF, :infer] for tgt in all_types for src in bad_sources @@ -425,13 +420,13 @@ import DiagrammaticEquations: safe_modifytype end end -import DiagrammaticEquations: filterfor_forms +import DiagrammaticEquations: filterfor_ec_types @testset "Form Type Retrieval" begin - all_types = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :Literal, :Constant, :Parameter, :infer] - @test filterfor_forms(all_types) == [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2] - @test isempty(filterfor_forms(Symbol[])) - @test isempty(filterfor_forms([:Literal, :Constant, :Parameter, :infer])) + all_types = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :Literal, :Constant, :Parameter, :PVF, :DVF, :infer] + @test filterfor_ec_types(all_types) == [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :PVF, :DVF] + @test isempty(filterfor_ec_types(Symbol[])) + @test isempty(filterfor_ec_types([:Literal, :Constant, :Parameter, :infer])) end @testset "Type Inference" begin @@ -831,6 +826,38 @@ end @test_throws "Type mismatch in summation" infer_types!(d) end + # Test #25: Infer between flattened and sharpened vector fields. + let + d = @decapode begin + A::Form1 + B::DualForm1 + C::PVF + D::DVF + + A == ♭(E) + B == ♭(F) + C == ♯(G) + D == ♯(H) + + I::Form1 + J::DualForm1 + K::PVF + L::DVF + + M == ♯(I) + N == ♯(J) + O == ♭(K) + P == ♭(L) + end + infer_types!(d) + + # TODO: Update this as more sharps and flats are released. + names_types_expected = Set([(:A, :Form1), (:B, :DualForm1), (:C, :PVF), (:D, :DVF), + (:E, :DVF), (:F, :infer), (:G, :Form1), (:H, :DualForm1), + (:I, :Form1), (:J, :DualForm1), (:K, :PVF), (:L, :DVF), + (:M, :PVF), (:N, :DVF), (:O, :infer), (:P, :Form1)]) + @test test_nametype_equality(d, names_types_expected) + end end @testset "Overloading Resolution" begin @@ -1048,6 +1075,42 @@ end op2s_hx = HeatXfer[:op2] op2s_expected_hx = [:*, :/, :/, :L₀, :/, :L₁, :*, :/, :*, :i₁, :/, :*, :*, :L₀] @test op2s_hx == op2s_expected_hx + + # Infer types and resolve overloads for the Halfar equation. + let + d = @decapode begin + h::Form0 + Γ::Form1 + n::Constant + + ∂ₜ(h) == ∘(⋆, d, ⋆)(Γ * d(h) ∧ (mag(♯(d(h)))^(n-1)) ∧ (h^(n+2))) + end + d = expand_operators(d) + infer_types!(d) + resolve_overloads!(d) + @test d == @acset SummationDecapode{Any, Any, Symbol} begin + Var = 19 + TVar = 1 + Op1 = 8 + Op2 = 6 + Σ = 1 + Summand = 2 + src = [1, 1, 1, 13, 12, 6, 18, 19] + tgt = [4, 9, 13, 12, 11, 18, 19, 4] + proj1 = [2, 3, 11, 8, 1, 7] + proj2 = [9, 15, 14, 10, 5, 16] + res = [8, 14, 10, 7, 16, 6] + incl = [4] + summand = [3, 17] + summation = [1, 1] + sum = [5] + op1 = [:∂ₜ, :d₀, :d₀, :♯ᵖᵖ, :mag, :⋆₁, :dual_d₁, :⋆₀⁻¹] + op2 = [:*, :-, :^, :∧₁₀, :^, :∧₁₀] + type = [:Form0, :Form1, :Constant, :Form0, :infer, :Form1, :Form1, :Form1, :Form1, :Form0, :Form0, :PVF, :Form1, :infer, :Literal, :Form0, :Literal, :DualForm1, :DualForm2] + name = [:h, :Γ, :n, :ḣ, :sum_1, Symbol("•2"), Symbol("•3"), Symbol("•4"), Symbol("•5"), Symbol("•6"), Symbol("•7"), Symbol("•8"), Symbol("•9"), Symbol("•10"), Symbol("1"), Symbol("•11"), Symbol("2"), Symbol("•_6_1"), Symbol("•_6_2")] + end + end + end @testset "Compilation Transformation" begin diff --git a/test/runtests.jl b/test/runtests.jl index dd92531..c68bded 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -38,4 +38,13 @@ end include("openoperators.jl") end +@testset "Symbolic Rewriting" begin + include("graph_traversal.jl") + include("acset2symbolic.jl") +end + +@testset "ThDEC Symbolics" begin + include("decasymbolic.jl") +end + include("aqua.jl")