diff --git a/Project.toml b/Project.toml index ab55d8e..69f41e4 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,9 @@ ACSets = "227ef7b5-1206-438b-ac65-934d6da304b8" Catlab = "134e5e36-593f-5add-ad60-77f754baafbe" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" +Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +StructEquality = "6ec83bb0-ed9f-11e9-3b4c-2b04cb4e219c" +SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" [compat] @@ -16,5 +19,8 @@ ACSets = "0.2" Catlab = "0.15, 0.16" DataStructures = "0.18.13" MLStyle = "0.4.17" +Reexport = "1.2.2" +StructEquality = "2.1.0" +SymbolicUtils = "3.1.2" Unicode = "1.6" julia = "1.6" diff --git a/src/DiagrammaticEquations.jl b/src/DiagrammaticEquations.jl index ef5f2d6..558a839 100644 --- a/src/DiagrammaticEquations.jl +++ b/src/DiagrammaticEquations.jl @@ -38,11 +38,13 @@ using Catlab.Theories import Catlab.Theories: otimes, oplus, compose, ⊗, ⊕, ⋅, associate, associate_unit, Ob, Hom, dom, codom using Catlab.Programs using Catlab.CategoricalAlgebra +import Catlab.CategoricalAlgebra: ∧ using Catlab.WiringDiagrams using Catlab.WiringDiagrams.DirectedWiringDiagrams using Catlab.ACSetInterface using MLStyle import Unicode +using Reexport ## TODO: ## generate schema from a _theory_ @@ -62,9 +64,15 @@ include("rewrite.jl") include("pretty.jl") include("colanguage.jl") include("openoperators.jl") +include("symbolictheoryutils.jl") +include("graph_traversal.jl") include("deca/Deca.jl") include("learn/Learn.jl") +include("SymbolicUtilsInterop.jl") -using .Deca +@reexport using .Deca +@reexport using .SymbolicUtilsInterop + +include("acset2symbolic.jl") end diff --git a/src/SymbolicUtilsInterop.jl b/src/SymbolicUtilsInterop.jl new file mode 100644 index 0000000..bee2291 --- /dev/null +++ b/src/SymbolicUtilsInterop.jl @@ -0,0 +1,157 @@ +module SymbolicUtilsInterop + +using ACSets +using ..DiagrammaticEquations: AbstractDecapode, Quantity, DerivOp +using ..DiagrammaticEquations: recognize_types, fill_names!, make_sum_mult_unique! +import ..DiagrammaticEquations: eval_eq!, SummationDecapode +using ..decapodes +using ..Deca + +using MLStyle +using SymbolicUtils +using SymbolicUtils: Symbolic, BasicSymbolic, FnType, Sym, symtype + +# name collision with decapodes.Equation +struct SymbolicEquation{E} + lhs::E + rhs::E +end +export SymbolicEquation + +Base.show(io::IO, e::SymbolicEquation) = print(io, "$(e.lhs) == $(e.rhs)") + +## a struct carry the symbolic variables and their equations +struct SymbolicContext + vars::Vector{Symbolic} + equations::Vector{SymbolicEquation{Symbolic}} +end +export SymbolicContext + +Base.show(io::IO, d::SymbolicContext) = begin + println(io, "SymbolicContext(") + println(io, " Variables: [$(join(d.vars, ", "))]") + println(io, " Equations: [") + eqns = map(d.equations) do op + " $(op)" + end + println(io, "$(join(eqns,",\n"))])") + end + +## BasicSymbolic -> DecaExpr +function decapodes.Term(t::SymbolicUtils.BasicSymbolic) + if SymbolicUtils.issym(t) + decapodes.Var(nameof(t)) + else + op = SymbolicUtils.head(t) + args = SymbolicUtils.arguments(t) + termargs = Term.(args) + if op == + + decapodes.Plus(termargs) + elseif op == * + decapodes.Mult(termargs) + elseif op ∈ [DerivOp, ∂ₜ] + decapodes.Tan(only(termargs)) + elseif length(args) == 1 + decapodes.App1(nameof(op, symtype.(args)...), termargs...) + elseif length(args) == 2 + decapodes.App2(nameof(op, symtype.(args)...), termargs...) + else + error("was unable to convert $t into a Term") + end + end +end +# TODO subtraction is not parsed as such. e.g., +# a, b = @syms a::Scalar b::Scalar +# Term(a-b) = Plus(Term[Var(:a), Mult(Term[Lit(Symbol("-1")), Var(:b)])) + +decapodes.Term(x::Real) = decapodes.Lit(Symbol(x)) + +function decapodes.DecaExpr(d::SymbolicContext) + context = map(d.vars) do var + decapodes.Judgement(nameof(var), nameof(symtype(var)), :I) + end + equations = map(d.equations) do eq + decapodes.Eq(decapodes.Term(eq.lhs), decapodes.Term(eq.rhs)) + end + decapodes.DecaExpr(context, equations) +end + +""" +Retrieve the SymbolicUtils expression of a DecaExpr term `t` from a context of variables in ThDEC + +Example: +``` + a = @syms a::Real + context = Dict(:a => Scalar(), :u => PrimalForm(0)) + SymbolicUtils.BasicSymbolic(context, Term(a)) +``` +""" +function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,DataType}, t::decapodes.Term) + # user must import symbols into scope + ! = (f -> getfield(@__MODULE__, f)) + @match t begin + Var(name) => SymbolicUtils.Sym{context[name]}(name) + Lit(v) => Meta.parse(string(v)) + # see heat_eq test: eqs had AppCirc1, but this returns + # App1(f, App1(...) + AppCirc1(fs, arg) => foldr( + # panics with constants like :k + # see test/language.jl + (f, x) -> (!(f))(x), + fs; + init=BasicSymbolic(context, arg) + ) + App1(f, x) => (!(f))(BasicSymbolic(context, x)) + App2(f, x, y) => (!(f))(BasicSymbolic(context, x), BasicSymbolic(context, y)) + Plus(xs) => +(BasicSymbolic.(Ref(context), xs)...) + Mult(xs) => *(BasicSymbolic.(Ref(context), xs)...) + Tan(x) => (!(DerivOp))(BasicSymbolic(context, x)) + end +end + +function SymbolicContext(d::decapodes.DecaExpr) + # associates each var to its sort... + context = map(d.context) do j + j.var => symtype(Deca.DECQuantity, j.dim, j.space) + end + # ... which we then produce a vector of symbolic vars + vars = map(context) do (v, s) + SymbolicUtils.Sym{s}(v) + end + context = Dict{Symbol,DataType}(context) + eqs = map(d.equations) do eq + SymbolicEquation{Symbolic}(BasicSymbolic.(Ref(context), [eq.lhs, eq.rhs])...) + end + SymbolicContext(vars, eqs) +end + +function eval_eq!(eq::SymbolicEquation, d::AbstractDecapode, syms::Dict{Symbol, Int}, deletions::Vector{Int}) + eval_eq!(Eq(Term(eq.lhs), Term(eq.rhs)), d, syms, deletions) +end + +""" function SummationDecapode(e::SymbolicContext) """ +function SummationDecapode(e::SymbolicContext) + d = SummationDecapode{Any, Any, Symbol}() + symbol_table = Dict{Symbol, Int}() + + foreach(e.vars) do var + # convert Sort(var)::PrimalForm0 --> :Form0 + var_id = add_part!(d, :Var, name=var.name, type=nameof(symtype(var))) + symbol_table[var.name] = var_id + end + + deletions = Vector{Int}() + foreach(e.equations) do eq + eval_eq!(eq, d, symbol_table, deletions) + end + rem_parts!(d, :Var, sort(deletions)) + + recognize_types(d) + + fill_names!(d) + d[:name] = normalize_unicode.(d[:name]) + make_sum_mult_unique!(d) + return d +end + +end diff --git a/src/acset.jl b/src/acset.jl index 34d1d3c..4cfdaac 100644 --- a/src/acset.jl +++ b/src/acset.jl @@ -187,6 +187,7 @@ function recognize_types(d::AbstractNamedDecapode) isempty(unrecognized_types) || error("Types $unrecognized_types are not recognized. CHECK: $types") end +export recognize_types """ is_expanded(d::AbstractNamedDecapode) diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl new file mode 100644 index 0000000..a32988b --- /dev/null +++ b/src/acset2symbolic.jl @@ -0,0 +1,83 @@ +using DiagrammaticEquations +using ACSets +using SymbolicUtils +using SymbolicUtils: BasicSymbolic, Symbolic + +export symbolic_rewriting + +const EQUALITY = (==) +const SymEqSym = SymbolicEquation{Symbolic} + +function symbolics_lookup(d::SummationDecapode) + Dict{Symbol, BasicSymbolic}(map(d[:name],d[:type]) do name,type + (name, decavar_to_symbolics(name, type)) + end) +end + +function decavar_to_symbolics(var_name::Symbol, var_type::Symbol, space = :I) + new_type = SymbolicUtils.symtype(Deca.DECQuantity, var_type, space) + SymbolicUtils.Sym{new_type}(var_name) +end + +function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_idx::Int, op_type::Symbol) + input_syms = getindex.(Ref(symvar_lookup), d[edge_inputs(d,op_idx,Val(op_type)), :name]) + output_sym = getindex.(Ref(symvar_lookup), d[edge_output(d,op_idx,Val(op_type)), :name]) + op_sym = getfield(@__MODULE__, edge_function(d,op_idx,Val(op_type))) + + S = promote_symtype(op_sym, input_syms...) + SymEqSym(output_sym, SymbolicUtils.Term{S}(op_sym, input_syms)) +end + +function to_symbolics(d::SummationDecapode) + symvar_lookup = symbolics_lookup(d) + map(e -> to_symbolics(d, symvar_lookup, e.index, e.name), topological_sort_edges(d)) +end + +function symbolic_rewriting(d::SummationDecapode, rewriter=identity) + d′ = infer_types!(deepcopy(d)) + eqns = merge_equations(d′) + to_acset(d′, map(rewriter, eqns)) +end + +# XXX SymbolicUtils.substitute swaps the order of multiplication. +# e.g. ∂ₜ(G) == κ*u becomes ∂ₜ(G) == u*κ +function merge_equations(d::SummationDecapode) + eqn_lookup, terminal_eqns = Dict(), SymEqSym[] + deriv_op_tgts = d[incident(d, DerivOp, :op1), [:tgt, :name]] # Patches over issue #77 + terminal_vars = Set{Symbol}(vcat(infer_terminal_names(d), deriv_op_tgts)) + + foreach(to_symbolics(d)) do x + sub = SymbolicUtils.substitute(x.rhs, eqn_lookup) + push!(eqn_lookup, (x.lhs => sub)) + if x.lhs.name in terminal_vars + push!(terminal_eqns, SymEqSym(x.lhs, sub)) + end + end + + map(terminal_eqns) do eqn + SymbolicUtils.Term{Number}(EQUALITY, [eqn.lhs, eqn.rhs]) + end +end + +function to_acset(d::SummationDecapode, sym_exprs) + literals = incident(d, :Literal, :type) + + outer_types = map([infer_states(d)..., infer_terminals(d)..., literals...]) do i + :($(d[i, :name])::$(d[i, :type])) + end + + #TODO: This step is breaking up summations + final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs) + reify!(exprs) = foreach(exprs) do x + if typeof(x) == Expr && x.head == :call + x.args[1] = nameof(x.args[1]) + reify!(x.args[2:end]) + end + end + reify!(final_exprs) + + deca_block = quote end + deca_block.args = [outer_types..., final_exprs...] + + ∘(infer_types!, SummationDecapode, parse_decapode)(deca_block) +end diff --git a/src/deca/Deca.jl b/src/deca/Deca.jl index 8202704..3187189 100644 --- a/src/deca/Deca.jl +++ b/src/deca/Deca.jl @@ -4,12 +4,17 @@ using DataStructures using ..DiagrammaticEquations using Catlab +using Reexport + import ..infer_types!, ..resolve_overloads! export normalize_unicode, varname, infer_types!, resolve_overloads!, typename, spacename, recursive_delete_parents, recursive_delete_parents!, unicode!, op1_res_rules_1D, op2_res_rules_1D, op1_res_rules_2D, op2_res_rules_2D, op1_inf_rules_1D, op2_inf_rules_1D, op1_inf_rules_2D, op2_inf_rules_2D, vec_to_dec! include("deca_acset.jl") include("deca_visualization.jl") +include("ThDEC.jl") + +@reexport using .ThDEC """ function recursive_delete_parents!(d::SummationDecapode, to_delete::Vector{Int64}) diff --git a/src/deca/ThDEC.jl b/src/deca/ThDEC.jl new file mode 100644 index 0000000..3fc40ae --- /dev/null +++ b/src/deca/ThDEC.jl @@ -0,0 +1,331 @@ +module ThDEC + +using ..DiagrammaticEquations: @operator, @alias, Quantity +import ..DiagrammaticEquations: rules + +using MLStyle +using StructEquality +using SymbolicUtils +using SymbolicUtils: Symbolic, BasicSymbolic, FnType, Sym, Term +import SymbolicUtils: symtype, promote_symtype + +import Base: +, -, * +import Catlab: Δ, ∧ + +# ########################## +# DECQuantity +# +# Type necessary for symbolic utils +# ########################## + +abstract type DECQuantity <: Quantity end +export DECQuantity + +# this ensures symtype doesn't recurse endlessly +SymbolicUtils.symtype(::Type{S}) where S<:DECQuantity = S + +abstract type AbstractScalar <: DECQuantity end + +struct InferredType <: DECQuantity end +export InferredType + +struct Scalar <: AbstractScalar end +struct Parameter <: AbstractScalar end +struct Const <: AbstractScalar end +struct Literal <: AbstractScalar end +export Scalar, Parameter, Const, Literal + +struct FormParams + dim::Int + duality::Bool + space::Symbol + spacedim::Int +end + +dim(fp::FormParams) = fp.dim +duality(fp::FormParams) = fp.duality +space(fp::FormParams) = fp.space +spacedim(fp::FormParams) = fp.spacedim + +""" +i: dimension: 0,1,2, etc. +d: duality: true = dual, false = primal +s: name of the space (a symbol) +n: dimension of the space +""" +struct Form{i,d,s,n} <: DECQuantity end +export Form + +# parameter accessors +dim(::Type{<:Form{i,d,s,n}}) where {i,d,s,n} = i +isdual(::Type{<:Form{i,d,s,n}}) where {i,d,s,n} = d +space(::Type{<:Form{i,d,s,n}}) where {i,d,s,n} = s +spacedim(::Type{<:Form{i,d,s,n}}) where {i,d,s,n} = n + +export dim, isdual, space, spacedim + +# convert form to form params +FormParams(::Type{<:Form{i,d,s,n}}) where {i,s,d,n} = FormParams(i,d,s,n) + +struct VField{d,s,n} <: DECQuantity end +export VField + +# parameter accessors +isdual(::Type{<:VField{d,s,n}}) where {d,s,n} = d +space(::Type{VField{d,s,n}}) where {d,s,n} = s +spacedim(::Type{VField{d,s,n}}) where {d,s,n} = n + +# convenience functions +const PrimalForm{i,s,n} = Form{i,false,s,n} +export PrimalForm + +const DualForm{i,s,n} = Form{i,true,s,n} +export DualForm + +const PrimalVF{s,n} = VField{false,s,n} +export PrimalVF + +const DualVF{s,n} = VField{true,s,n} +export DualVF + +Base.nameof(u::Type{<:PrimalForm}) = Symbol("Form$(dim(u))") +Base.nameof(u::Type{<:DualForm}) = Symbol("DualForm$(dim(u))") + +Base.show(io::IO, ω::Type{<:Form}) = print(io, "$(Base.nameof(ω)) on $(space(ω))") + +# ACTIVE PATTERNS + +@active PatInferredType(T) begin + if T <: InferredType + Some(InferredType) + end +end + +@active PatInferredTypes(T) begin + if any(S->S<:InferredType, T) + Some(InferredType) + end +end + +@active PatForm(T) begin + if T <: Form + Some(T) + end +end +export PatForm + +@active PatFormParams(T) begin + if T <: Form + Some([T.parameters...]) + end +end +export PatFormParams + +@active PatFormDim(T) begin + if T <: Form + Some(dim(T)) + end +end +export PatFormDim + +@active PatScalar(T) begin + if T <: AbstractScalar + Some(T) + end +end +export PatScalar + +@active PatVFParams(T) begin + if T <: VField + Some([T.parameters...]) + end +end +export PatVFParams + +isForm(x) = @match symtype(x) begin + PatFormParams([_,_,_,_]) => true + _ => false +end + +isDualForm(x) = @match symtype(x) begin + PatFormParams([_,d,_,_]) => d + _ => false +end + +# TODO parameterize? +isForm0(x) = @match symtype(x) begin + PatFormParams([0,_,_,_]) => true + _ => false +end + +isForm1(x) = @match symtype(x) begin + PatFormParams([1,_,_,_]) => true + _ => false +end + +isForm2(x) = @match symtype(x) begin + PatFormParams([2,_,_,_]) => true + _ => false +end + +export isDualForm, isForm0, isForm1, isForm2 + +# ############################### +# OPERATORS +# ############################### + +@operator -(S)::DECQuantity begin S end + +@operator ∂ₜ(S)::DECQuantity begin S end + +@operator d(S)::DECQuantity begin + @match S begin + PatInferredType(_) => InferredType + PatFormParams([i,d,s,n]) => Form{i+1,d,s,n} + _ => throw(OperatorError("take the exterior derivative", S)) + end +end + +@alias (d₀, d₁) => d + +@operator ★(S)::DECQuantity begin + @match S begin + PatInferredType(_) => InferredType + PatFormParams([i,d,s,n]) => Form{n-i,d,s,n} + _ => throw(OperatorError("take the hodge star", S)) + end +end + +# TODO in orthodox Decapodes, these are type-specific. +@alias (★₀, ★₁, ★₂, ★₀⁻¹, ★₁⁻¹, ★₂⁻¹) => ★ + +@operator Δ(S)::DECQuantity begin + @match S begin + PatInferredType(_) => InferredType + PatForm(_) => promote_symtype(★ ∘ d ∘ ★ ∘ d, S) + _ => throw(OperatorError("take the Laplacian", S)) + end + @rule Δ(~x::isForm0) => ★(d(★(d(~x)))) + @rule Δ(~x::isForm1) => ★(d(★(d(~x)))) + d(★(d(★(~x)))) + @rule Δ(~x::isForm2) => d(★(d(★(~x)))) +end + +@alias (Δ₀, Δ₁, Δ₂) => Δ + +# TODO: Determine what we need to do for .+ +@operator +(S1, S2)::DECQuantity begin + @match (S1, S2) begin + PatInferredTypes(_) => InferredType + (PatScalar(_), PatScalar(_)) => Scalar + (PatScalar(_), PatFormParams([i,d,s,n])) || (PatFormParams([i,d,s,n]), PatScalar(_)) => Form{i, d, s, n} # commutativity + (PatFormParams([i1,d1,s1,n1]), PatFormParams([i2,d2,s2,n2])) => begin + if (i1 == i2) && (d1 == d2) && (s1 == s2) && (n1 == n2) + Form{i1, d1, s1, n1} + else + throw(OperatorError("sum", [S1, S2])) + end + end + _ => throw(OperatorError("add", [S1, S2])) + end +end + +@operator -(S1, S2)::DECQuantity begin + promote_symtype(+, S1, S2) +end + +@operator *(S1, S2)::DECQuantity begin + @match (S1, S2) begin + PatInferredTypes(_) => InferredType + (PatScalar(_), PatScalar(_)) => Scalar + (PatScalar(_), PatFormParams([i,d,s,n])) || (PatFormParams([i,d,s,n]), PatScalar(_)) => Form{i,d,s,n} + _ => throw(OperatorError("multiply", [S1, S2])) + end +end + +@alias (∧₀₀, ∧₀₁, ∧₁₀, ∧₁₁, ∧₀₂, ∧₂₀) => ∧ + +@operator ∧(S1, S2)::DECQuantity begin + @match (S1, S2) begin + PatInferredTypes(_) => InferredType + (PatFormParams([i1,d1,s1,n1]), PatFormParams([i2,d2,s2,n2])) => begin + (d1 == d2) && (s1 == s2) && (n1 == n2) || throw(OperatorError("take the wedge product", [S1, S2])) + if i1 + i2 <= n1 + Form{i1 + i2, d1, s1, n1} + else + throw(OperatorError("take the wedge product", [S1, S2], "The dimensions of the form are bounded by $n1")) + end + end + _ => throw(OperatorError("take the wedge product", [S1, S2])) + end +end + +abstract type SortError <: Exception end + +Base.nameof(s::Union{Literal,Type{Literal}}) = :Literal +Base.nameof(s::Union{Const, Type{Const}}) = :Constant +Base.nameof(s::Union{Parameter, Type{Parameter}}) = :Parameter +Base.nameof(s::Union{Scalar, Type{Scalar}}) = :Scalar +Base.nameof(s::Union{InferredType, Type{InferredType}}) = :infer + + +Base.nameof(::typeof(-), args...) = Symbol("-") + +const SUBSCRIPT_DIGIT_0 = '₀' + +as_sub(n::Int) = join(map(d -> SUBSCRIPT_DIGIT_0 + d, digits(n))) +sub_dim(args...) = all(isForm.(args)) ? join(as_sub.(dim.(args))) : "" + +Base.nameof(::typeof(∧), s1, s2) = Symbol("∧$(sub_dim(s1, s2))") + +Base.nameof(::typeof(∂ₜ), s) = Symbol("∂ₜ") + +function Base.nameof(::typeof(d), s) + dual = isdual(s) ? "dual_" : "" + Symbol("$(dual)d$(sub_dim(s))") +end + +#TODO: Add naming for dual +function Base.nameof(::typeof(Δ), s) + Symbol("Δ$(sub_dim(s))") +end + +function Base.nameof(::typeof(★), s) + inv = isdual(s) ? "⁻¹" : "" + hdg_dim = isdual(s) ? spacedim(s) - dim(s) : dim(s) + Symbol("★$(as_sub(hdg_dim))$(inv)") +end + +# TODO: Check that form type is no larger than the ambient dimension +function SymbolicUtils.symtype(::Type{<:Quantity}, qty::Symbol, space::Symbol, dim::Int = 2) + @match qty begin + :Scalar => Scalar + :Constant => Const + :Parameter => Parameter + :Literal => Literal + :Form0 => PrimalForm{0, space, dim} + :Form1 => PrimalForm{1, space, dim} + :Form2 => PrimalForm{2, space, dim} + :DualForm0 => DualForm{0, space, dim} + :DualForm1 => DualForm{1, space, dim} + :DualForm2 => DualForm{2, space, dim} + :infer => InferredType + _ => error("Received $qty which is not a valid type for Decapodes") + end +end + +struct OperatorError <: SortError + verb::String + sorts::Vector{DataType} + othermsg::String + function OperatorError(verb::String, sorts::Vector{DataType}, othermsg::String="") + new(verb, sorts, othermsg) + end + function OperatorError(verb::String, sort::DataType, othermsg::String="") + new(verb, [sort], othermsg) + end +end +export OperatorError + +Base.showerror(io::IO, e::OperatorError) = print(io, "Cannot take the $(e.verb) of $(join(e.sorts, " and ")). $(e.othermsg)") + +end diff --git a/src/deca/deca_acset.jl b/src/deca/deca_acset.jl index 4334c4b..e0f5580 100644 --- a/src/deca/deca_acset.jl +++ b/src/deca/deca_acset.jl @@ -5,7 +5,7 @@ using ..DiagrammaticEquations These are the default rules used to do type inference in the 1D exterior calculus. """ op1_inf_rules_1D = [ - # Rules for ∂ₜ + # Rules for ∂ₜ (src_type = :Form0, tgt_type = :Form0, op_names = [:∂ₜ,:dt]), (src_type = :Form1, tgt_type = :Form1, op_names = [:∂ₜ,:dt]), @@ -14,10 +14,10 @@ op1_inf_rules_1D = [ (src_type = :DualForm0, tgt_type = :DualForm1, op_names = [:d, :dual_d₀, :d̃₀]), # Rules for ⋆ - (src_type = :Form0, tgt_type = :DualForm1, op_names = [:⋆, :⋆₀, :star]), - (src_type = :Form1, tgt_type = :DualForm0, op_names = [:⋆, :⋆₁, :star]), - (src_type = :DualForm1, tgt_type = :Form0, op_names = [:⋆, :⋆₀⁻¹, :star_inv]), - (src_type = :DualForm0, tgt_type = :Form1, op_names = [:⋆, :⋆₁⁻¹, :star_inv]), + (src_type = :Form0, tgt_type = :DualForm1, op_names = [:★, :⋆, :⋆₀, :star]), + (src_type = :Form1, tgt_type = :DualForm0, op_names = [:★, :⋆, :⋆₁, :star]), + (src_type = :DualForm1, tgt_type = :Form0, op_names = [:★, :⋆, :⋆₀⁻¹, :star_inv]), + (src_type = :DualForm0, tgt_type = :Form1, op_names = [:★, :⋆, :⋆₁⁻¹, :star_inv]), # Rules for Δ (src_type = :Form0, tgt_type = :Form0, op_names = [:Δ, :Δ₀, :lapl]), @@ -29,7 +29,7 @@ op1_inf_rules_1D = [ # Rules for negation (src_type = :Form0, tgt_type = :Form0, op_names = [:neg, :(-)]), (src_type = :Form1, tgt_type = :Form1, op_names = [:neg, :(-)]), - + # Rules for the averaging operator (src_type = :Form0, tgt_type = :Form1, op_names = [:avg₀₁, :avg_01]), @@ -39,7 +39,7 @@ op1_inf_rules_1D = [ # Rules for ♭. (src_type = :DVF, tgt_type = :Form1, op_names = [:♭, :♭ᵈᵖ]), - + # Rules for magnitude/ norm (src_type = :PVF, tgt_type = :Form0, op_names = [:mag, :norm]), (src_type = :DVF, tgt_type = :DualForm0, op_names = [:mag, :norm])] @@ -52,7 +52,7 @@ op2_inf_rules_1D = [ # Rules for L₀, L₁ (proj1_type = :Form1, proj2_type = :DualForm0, res_type = :DualForm0, op_names = [:L, :L₀]), - (proj1_type = :Form1, proj2_type = :DualForm1, res_type = :DualForm1, op_names = [:L, :L₁]), + (proj1_type = :Form1, proj2_type = :DualForm1, res_type = :DualForm1, op_names = [:L, :L₁]), # Rules for i₁ (proj1_type = :Form1, proj2_type = :DualForm1, res_type = :DualForm0, op_names = [:i, :i₁]), @@ -69,10 +69,10 @@ op2_inf_rules_1D = [ (proj1_type = :Form0, proj2_type = :Parameter, res_type = :Form0, op_names = [:/, :./, :*, :.*]), (proj1_type = :Form1, proj2_type = :Parameter, res_type = :Form1, op_names = [:/, :./, :*, :.*]), (proj1_type = :Form2, proj2_type = :Parameter, res_type = :Form2, op_names = [:/, :./, :*, :.*]),=# - + (proj1_type = :Form0, proj2_type = :Literal, res_type = :Form0, op_names = [:/, :./, :*, :.*, :^, :.^]), (proj1_type = :Form1, proj2_type = :Literal, res_type = :Form1, op_names = [:/, :./, :*, :.*, :^, :.^]), - + (proj1_type = :DualForm0, proj2_type = :Literal, res_type = :DualForm0, op_names = [:/, :./, :*, :.*, :^, :.^]), (proj1_type = :DualForm1, proj2_type = :Literal, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^]), @@ -86,7 +86,7 @@ op2_inf_rules_1D = [ (proj1_type = :Constant, proj2_type = :Form1, res_type = :Form1, op_names = [:/, :./, :*, :.*, :^, :.^]), (proj1_type = :Form0, proj2_type = :Constant, res_type = :Form0, op_names = [:/, :./, :*, :.*, :^, :.^]), (proj1_type = :Form1, proj2_type = :Constant, res_type = :Form1, op_names = [:/, :./, :*, :.*, :^, :.^]), - + (proj1_type = :Constant, proj2_type = :DualForm0, res_type = :DualForm0, op_names = [:/, :./, :*, :.*, :^, :.^]), (proj1_type = :Constant, proj2_type = :DualForm1, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^]), (proj1_type = :DualForm0, proj2_type = :Constant, res_type = :DualForm0, op_names = [:/, :./, :*, :.*, :^, :.^]), @@ -112,13 +112,13 @@ op1_inf_rules_2D = [ (src_type = :DualForm1, tgt_type = :DualForm2, op_names = [:d, :dual_d₁, :d̃₁]), # Rules for ⋆ - (src_type = :Form0, tgt_type = :DualForm2, op_names = [:⋆, :⋆₀, :star]), - (src_type = :Form1, tgt_type = :DualForm1, op_names = [:⋆, :⋆₁, :star]), - (src_type = :Form2, tgt_type = :DualForm0, op_names = [:⋆, :⋆₂, :star]), + (src_type = :Form0, tgt_type = :DualForm2, op_names = [:★, :⋆, :⋆₀, :star]), + (src_type = :Form1, tgt_type = :DualForm1, op_names = [:★, :⋆, :⋆₁, :star]), + (src_type = :Form2, tgt_type = :DualForm0, op_names = [:★, :⋆, :⋆₂, :star]), - (src_type = :DualForm2, tgt_type = :Form0, op_names = [:⋆, :⋆₀⁻¹, :star_inv]), - (src_type = :DualForm1, tgt_type = :Form1, op_names = [:⋆, :⋆₁⁻¹, :star_inv]), - (src_type = :DualForm0, tgt_type = :Form2, op_names = [:⋆, :⋆₂⁻¹, :star_inv]), + (src_type = :DualForm2, tgt_type = :Form0, op_names = [:★, :⋆, :⋆₀⁻¹, :star_inv]), + (src_type = :DualForm1, tgt_type = :Form1, op_names = [:★, :⋆, :⋆₁⁻¹, :star_inv]), + (src_type = :DualForm0, tgt_type = :Form2, op_names = [:★, :⋆, :⋆₂⁻¹, :star_inv]), # Rules for Δ (src_type = :Form0, tgt_type = :Form0, op_names = [:Δ, :Δ₀, :lapl]), @@ -154,7 +154,7 @@ op1_inf_rules_2D = [ # Rules for magnitude/ norm (src_type = :PVF, tgt_type = :Form0, op_names = [:norm, :mag]), (src_type = :DVF, tgt_type = :DualForm0, op_names = [:norm, :mag])] - + op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ # Rules for ∧₁₁, ∧₂₀, ∧₀₂ (proj1_type = :Form1, proj2_type = :Form1, res_type = :Form2, op_names = [:∧, :∧₁₁, :wedge]), @@ -162,7 +162,7 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ (proj1_type = :Form0, proj2_type = :Form2, res_type = :Form2, op_names = [:∧, :∧₀₂, :wedge]), # Rules for L₂ - (proj1_type = :Form1, proj2_type = :DualForm2, res_type = :DualForm2, op_names = [:L, :L₂]), + (proj1_type = :Form1, proj2_type = :DualForm2, res_type = :DualForm2, op_names = [:L, :L₂]), # Rules for i₁ (proj1_type = :Form1, proj2_type = :DualForm2, res_type = :DualForm1, op_names = [:i, :i₂]), @@ -173,7 +173,7 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ # Rules for ι (proj1_type = :DualForm1, proj2_type = :DualForm1, res_type = :DualForm0, op_names = [:ι₁₁]), (proj1_type = :DualForm1, proj2_type = :DualForm2, res_type = :DualForm1, op_names = [:ι₁₂]), - + # Rules for subtraction (proj1_type = :Form0, proj2_type = :Form0, res_type = :Form0, op_names = [:-, :.-]), (proj1_type = :Form1, proj2_type = :Form1, res_type = :Form1, op_names = [:-, :.-]), @@ -212,7 +212,7 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ # Rules for Δ (src_type = :Form0, tgt_type = :Form0, resolved_name = :Δ₀, op = :Δ), (src_type = :Form1, tgt_type = :Form1, resolved_name = :Δ₁, op = :Δ)] - + # We merge 1D and 2D rules since it seems op2 rules are metric-free. If # this assumption is false, this needs to change. op2_res_rules_1D = [ @@ -228,8 +228,8 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ (proj1_type = :Form1, proj2_type = :DualForm1, res_type = :DualForm1, resolved_name = :L₁, op = :L), # Rules for i. (proj1_type = :Form1, proj2_type = :DualForm1, res_type = :DualForm0, resolved_name = :i₁, op = :i)] - - + + """ These are the default rules used to do function resolution in the 2D exterior calculus. """ @@ -274,7 +274,7 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ (src_type = :Form0, tgt_type = :Form0, resolved_name = :Δ₀, op = :lapl), (src_type = :Form1, tgt_type = :Form1, resolved_name = :Δ₁, op = :lapl), (src_type = :Form1, tgt_type = :Form1, resolved_name = :Δ₂, op = :lapl)] - + # We merge 1D and 2D rules directly here since it seems op2 rules # are metric-free. If this assumption is false, this needs to change. op2_res_rules_2D = vcat(op2_res_rules_1D, [ @@ -290,7 +290,7 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ # (proj1_type = :Form1, proj2_type = :DualForm2, res_type = :DualForm2, resolved_name = :L₂ᵈ, op = :L), # Rules for i. (proj1_type = :Form1, proj2_type = :DualForm2, res_type = :DualForm1, resolved_name = :i₂, op = :i)]) - + # TODO: When SummationDecapodes are annotated with the degree of their space, # use dispatch to choose the correct set of rules. infer_types!(d::SummationDecapode) = @@ -358,4 +358,3 @@ Resolve function overloads based on types of src and tgt. """ resolve_overloads!(d::SummationDecapode) = resolve_overloads!(d, op1_res_rules_2D, op2_res_rules_2D) - diff --git a/src/graph_traversal.jl b/src/graph_traversal.jl new file mode 100644 index 0000000..e41048e --- /dev/null +++ b/src/graph_traversal.jl @@ -0,0 +1,73 @@ +using DiagrammaticEquations +using ACSets + +export TraversalNode, topological_sort_edges, n_ops, retrieve_name, start_nodes, edge_inputs, edge_output, edge_function + +struct TraversalNode{T} + index::Int + name::T +end + +edge_inputs(d::SummationDecapode, idx::Int, ::Val{:Op1}) = + [d[idx,:src]] +edge_inputs(d::SummationDecapode, idx::Int, ::Val{:Op2}) = + [d[idx,:proj1], d[idx,:proj2]] +edge_inputs(d::SummationDecapode, idx::Int, ::Val{:Σ}) = + d[incident(d, idx, :summation), :summand] + +edge_output(d::SummationDecapode, idx::Int, ::Val{:Op1}) = + d[idx,:tgt] +edge_output(d::SummationDecapode, idx::Int, ::Val{:Op2}) = + d[idx,:res] +edge_output(d::SummationDecapode, idx::Int, ::Val{:Σ}) = + d[idx, :sum] + +edge_function(d::SummationDecapode, idx::Int, ::Val{:Op1}) = + d[idx,:op1] +edge_function(d::SummationDecapode, idx::Int, ::Val{:Op2}) = + d[idx,:op2] +edge_function(d::SummationDecapode, idx::Int, ::Val{:Σ}) = + :+ + +#XXX: This topological sort is O(n^2). +function topological_sort_edges(d::SummationDecapode) + visited_Var = falses(nparts(d, :Var)) + visited_Var[start_nodes(d)] .= true + visited = Dict(:Op1 => falses(nparts(d, :Op1)), + :Op2 => falses(nparts(d, :Op2)), :Σ => falses(nparts(d, :Σ))) + + op_order = TraversalNode{Symbol}[] + + function visit(op, op_type) + if !visited[op_type][op] && all(visited_Var[edge_inputs(d,op,Val(op_type))]) + visited[op_type][op] = true + visited_Var[edge_output(d,op,Val(op_type))] = true + push!(op_order, TraversalNode(op, op_type)) + end + end + + for _ in 1:n_ops(d) + visit.(parts(d,:Op1), :Op1) + visit.(parts(d,:Op2), :Op2) + visit.(parts(d,:Σ), :Σ) + end + + @assert length(op_order) == n_ops(d) + op_order +end + +n_ops(d::SummationDecapode) = + nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, :Σ) + +start_nodes(d::SummationDecapode) = + vcat(infer_states(d), incident(d, :Literal, :type)) + +function retrieve_name(d::SummationDecapode, tsr::TraversalNode) + @match tsr.name begin + :Op1 => d[tsr.index, :op1] + :Op2 => d[tsr.index, :op2] + :Σ => :+ + _ => error("$(tsr.name) is a table without names") + end +end + diff --git a/src/language.jl b/src/language.jl index 5a06081..e0f7d71 100644 --- a/src/language.jl +++ b/src/language.jl @@ -51,8 +51,8 @@ function parse_decapode(expr::Expr) end DecaExpr(judges, eqns) end + # to_Decapode helper functions -### TODO - Matt: we need to generalize this reduce_term!(t::Term, d::AbstractDecapode, syms::Dict{Symbol, Int}) = let ! = reduce_term! @match t begin diff --git a/src/symbolictheoryutils.jl b/src/symbolictheoryutils.jl new file mode 100644 index 0000000..7746559 --- /dev/null +++ b/src/symbolictheoryutils.jl @@ -0,0 +1,173 @@ +using MLStyle +using SymbolicUtils +using SymbolicUtils: Symbolic, BasicSymbolic, FnType, Sym, symtype +import SymbolicUtils: promote_symtype + +function rules end +export rules + +# TODO: Probable piracy here +function promote_symtype(f::ComposedFunction, args) + promote_symtype(f.outer, promote_symtype(f.inner, args)) +end + +@active PatBlock(e) begin + @match e begin + Expr(:macrocall, name, args...) && if name == Symbol("@", "match") end => Some(e) + Expr(:block, args...) => Some(e) + Expr(:let, args...) => Some(e) + _ => nothing + end +end +export PatBlock + +@active PatRule(e) begin + @match e begin + Expr(:macrocall, head, args...) && if head == Symbol("@", "rule") end => Some(e) + _ => nothing + end +end +export PatRule + +""" DECQuantities in DiagrammaticEquations must be subtypes of Number to integrate with SymbolicUtils. An intermediary type, Quantity, makes it clearer that terms in the theory are "symbolic quantities" which behave like numbers. In the context of SymbolicUtils, a Number is any type that you can do arithmetic operations on. +""" +abstract type Quantity <: Number end +export Quantity + +""" +Creates an operator `foo` with arguments which are types in a given Theory. This entails creating (1) a function which performs type construction and (2) a function which consumes BasicSymbolic variables and returns Terms. + +``` +@operator foo(S1, S2, ...)::Theory begin + (body of function) + (@rule expr1) + ... + (@rule exprN) +end +``` +builds +``` +promote_symtype(::typeof{f}, ::Type{S1}, ::Type{S2}, ...) where {S1<:DECQuantity, S2<:DECQuantity, ...} + (body of function) +end +``` +as well as +``` +foo(S1, S2, ...) where {S1<:DECQuantity, ...} + s = promote_symtype(f, S1, S2, ...) + SymbolicUtils.Term{s}(foo, [S1, S2, ...]) +end +``` + +Example: +``` +@operator Δ(s)::DECQuantity begin + @match s begin + ::Scalar => error("Invalid") + ::VField => error("Invalid") + ::Form => ⋆(d(⋆(d(s)))) + end + @rule ~s --> ⋆(d(⋆(d(~s)))) +end +``` +""" +macro operator(head, body) + + # parse head + ph = @λ begin + Expr(:call, foo, Expr(:(::), vars..., theory)) => (foo, vars, theory) + Expr(:(::), Expr(:call, foo, vars...), theory) => (foo, vars, theory) + _ => error("$head") + end + (f, types, Theory) = ph(head) + + # Passing types to functions requires that we type the signature with ::Type{T}. + # This means that the user would have to write `my_op(::Type{T1}, ::Type{T2}, ...)` + # As a convenience to the user, we allow them to specify the signature using just the types themselves: + # `my_op(T1, T2, ...)` + sort_types = [:(::Type{$S}) for S in types] + sort_constraints = [:($S<:$Theory) for S in types] + arity = length(sort_types) + + # Parse the body for @rule calls. + block, rulecalls = @match Base.remove_linenums!(body) begin + Expr(:block, block, rules...) => (block, rules) + s => nothing + end + + # initialize the result + result = quote end + + # construct the function on basic symbolics + argnames = [gensym(:x) for _ in 1:arity] + argclaus = [:($a::Symbolic) for a in argnames] + push!(result.args, quote + @nospecialize + function $f($(argclaus...)) + s = promote_symtype($f, $(argnames...)) + SymbolicUtils.Term{s}($f, Any[$(argnames...)]) + end + + export $f + + # Base.show(io::IO, ::typeof($f)) = print(io, $f) + end) + + # if there are rewriting rules, add a method which accepts the function symbol and its arity + # (to prevent shadowing on operators like `-`) + if !isempty(rulecalls) + push!(result.args, quote + function rules(::typeof($f), ::Val{$arity}) + [($(rulecalls...))] + end + + rules(::typeof($f)) = rules($f, Val{1}) + end) + end + + # we want to feed symtype the generics + push!(result.args, quote + function SymbolicUtils.promote_symtype(::typeof($f), + $(sort_types...)) where {$(sort_constraints...)} + $block + end + function SymbolicUtils.promote_symtype(::typeof($f), args::Vararg{Symbolic, $arity}) + promote_symtype($f, symtype.(args)...) + end + end) + + push!(result.args, Expr(:tuple, rulecalls...)) + + return esc(result) +end +export @operator + +""" +Given a tuple of symbols ("aliases") and their canonical name (or "rep"), produces +for each alias typechecking and nameof methods which call those for their rep. +Example: +@alias (d₀, d₁) => d +""" +macro alias(body) + (rep, aliases) = @match body begin + Expr(:call, :(=>), Expr(:tuple, aliases...), rep) => (rep, aliases) + _ => error("parse error") + end + result = quote end + foreach(aliases) do alias + push!(result.args, + quote + function $alias(s...) + $rep(s...) + end + export $alias + + Base.nameof(::typeof($alias), s) = Symbol("$alias") + end) + end + return esc(result) +end +export @alias + +alias(x) = error("$x has no aliases") +export alias diff --git a/test/Project.toml b/test/Project.toml index 97b4819..4f3a02a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,4 +6,5 @@ CombinatorialSpaces = "b1c52339-7909-45ad-8b6a-6e388f7c67f2" DiagrammaticEquations = "6f00c28b-6bed-4403-80fa-30e0dc12f317" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" +SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/acset2symbolic.jl b/test/acset2symbolic.jl new file mode 100644 index 0000000..e01ca9f --- /dev/null +++ b/test/acset2symbolic.jl @@ -0,0 +1,286 @@ +using Test +using DiagrammaticEquations +using SymbolicUtils: Fixpoint, Prewalk, Postwalk, Chain, @rule +using Catlab + +(≃) = is_isomorphic + +@testset "Basic Roundtrip" begin + op1_only = @decapode begin + A::Form0 + B::Form1 + B == d(A) + end + + @test op1_only == symbolic_rewriting(op1_only) + + + op2_only = @decapode begin + A::Constant + B::Form0 + C::Form0 + + C == A * B + end + + @test op2_only == symbolic_rewriting(op2_only) + + sum_only = @decapode begin + A::Form2 + B::Form2 + C::Form2 + + C == A + B + end + + @test sum_only == symbolic_rewriting(sum_only) + + + multi_sum = @decapode begin + A::Form2 + B::Form2 + C::Form2 + D::Form2 + + D == (A + B) + C + end + infer_types!(multi_sum) + + # TODO: This is correct but the symbolics is splitting up the sum + @test multi_sum == symbolic_rewriting(multi_sum) + + + all_ops = @decapode begin + A::Constant + B::Form0 + C::Form1 + D::Form1 + E::Form1 + F::Form1 + + C == d(B) + D == A * C + F == D + E + end + + # This loses the intermediate names C and D + all_ops_res = symbolic_rewriting(all_ops) + all_ops_res[5, :name] = :D + all_ops_res[6, :name] = :C + @test all_ops ≃ all_ops_res + + with_deriv = @decapode begin + A::Form0 + Ȧ::Form0 + + ∂ₜ(A) == Ȧ + Ȧ == Δ(A) + end + + @test with_deriv == symbolic_rewriting(with_deriv) + + repeated_vars = @decapode begin + A::Form0 + B::Form1 + C::Form1 + + C == d(A) + C == Δ(B) + C == d(A) + end + + @test repeated_vars == symbolic_rewriting(repeated_vars) + + # TODO: This is broken because of the terminals issue in #77 + self_changing = @decapode begin + c_exp == ∂ₜ(c_exp) + end + + @test self_changing == symbolic_rewriting(self_changing) + + literal = @decapode begin + A::Form0 + B::Form0 + + B == A * 2 + end + + @test literal == symbolic_rewriting(literal) + + parameter = @decapode begin + A::Form0 + P::Parameter + B::Form0 + + B == A * P + end + + @test parameter == symbolic_rewriting(parameter) + + constant = @decapode begin + A::Form0 + C::Constant + B::Form0 + + B == A * C + end + + @test constant == symbolic_rewriting(constant) +end + +function expr_rewriter(rules::Vector) + return Fixpoint(Prewalk(Fixpoint(Chain(rules)))) +end + +@testset "Basic Rewriting" begin + op1s = @decapode begin + A::Form0 + B::Form2 + C::Form2 + + C == B + B + d(d(A)) + end + + dd_0 = @rule d(d(~x)) => 0 + + op1s_rewritten = symbolic_rewriting(op1s, expr_rewriter([dd_0])) + + op1s_equiv = @decapode begin + A::Form0 + B::Form2 + C::Form2 + + C == 2 * B + end + + @test op1s_equiv == op1s_rewritten + + + op2s = @decapode begin + A::Form0 + B::Form0 + C::Form0 + D::Form0 + + + D == ∧(∧(A, B), C) + end + + wdg_assoc = @rule ∧(∧(~x, ~y), ~z) => ∧(~x, ∧(~y, ~z)) + + op2s_rewritten = symbolic_rewriting(op2s, expr_rewriter([wdg_assoc])) + + op2s_equiv = @decapode begin + A::Form0 + B::Form0 + C::Form0 + D::Form0 + + + D == ∧(A, ∧(B, C)) + end + infer_types!(op2s_equiv) + + @test op2s_equiv == op2s_rewritten + + + distr_d = @decapode begin + A::Form0 + B::Form1 + C::Form2 + + C == d(∧(A, B)) + end + infer_types!(distr_d) + + leibniz = @rule d(∧(~x, ~y)) => ∧(d(~x), ~y) + ∧(~x, d(~y)) + + distr_d_rewritten = symbolic_rewriting(distr_d, expr_rewriter([leibniz])) + + distr_d_res = @decapode begin + A::Form0 + B::Form1 + C::Form2 + + C == ∧(d(A), B) + ∧(A, d(B)) + end + infer_types!(distr_d_res) + + @test distr_d_res == distr_d_rewritten +end + +@testset "Heat" begin + Heat = @decapode begin + C::Form0 + G::Form0 + D::Constant + ∂ₜ(G) == D*Δ(C) + end + infer_types!(Heat) + + # Same up to re-naming + Heat[5, :name] = Symbol("•1") + @test Heat == symbolic_rewriting(Heat) + + Heat_open = @decapode begin + C::Form0 + G::Form0 + D::Constant + ∂ₜ(G) == D*★(d(★(d(C)))) + end + infer_types!(Heat_open) + + Heat_open[8, :name] = Symbol("•1") + Heat_open[5, :name] = Symbol("•2") + Heat_open[6, :name] = Symbol("•3") + Heat_open[7, :name] = Symbol("•4") + + @test Heat_open ≃ symbolic_rewriting(Heat, expr_rewriter(rules(Δ, Val(1)))) +end + +@testset "Phytodynamics" begin + Phytodynamics = @decapode begin + (n,w)::Form0 + m::Constant + ∂ₜ(n) == w + m*n + Δ(n) + end + infer_types!(Phytodynamics) + test_phy = symbolic_rewriting(Phytodynamics) +end + +@testset "Literals" begin + Heat = parse_decapode(quote + C::Form0 + G::Form0 + ∂ₜ(G) == 3*Δ(C) + end) + context = SymbolicContext(Heat) + SummationDecapode(context) + +end + +@testset "Parameters" begin + + Heat = @decapode begin + u::Form0 + G::Form0 + κ::Parameter + ∂ₜ(G) == Δ(u)*κ + end + infer_types!(Heat) + + Heat_open = @decapode begin + u::Form0 + G::Form0 + κ::Parameter + ∂ₜ(G) == ★(d(★(d(u))))*κ + end + infer_types!(Heat_open) + + Heat_open[7, :name] = Symbol("•4") + Heat_open[8, :name] = Symbol("•1") + + z = symbolic_rewriting(Heat, expr_rewriter(rules(Δ, Val(1)))) + @test Heat_open ≃ z + +end diff --git a/test/aqua.jl b/test/aqua.jl index 555f584..38129f7 100644 --- a/test/aqua.jl +++ b/test/aqua.jl @@ -1,5 +1,5 @@ using Aqua, DiagrammaticEquations @testset "Code quality (Aqua.jl)" begin # TODO: fix ambiguities - Aqua.test_all(DiagrammaticEquations, ambiguities=false) + Aqua.test_all(DiagrammaticEquations, ambiguities=false, undefined_exports=false, piracies=false) end diff --git a/test/decasymbolic.jl b/test/decasymbolic.jl new file mode 100644 index 0000000..436ceb1 --- /dev/null +++ b/test/decasymbolic.jl @@ -0,0 +1,336 @@ +using Test +using DiagrammaticEquations +using DiagrammaticEquations.Deca.ThDEC +using DiagrammaticEquations.decapodes +using SymbolicUtils +using SymbolicUtils: symtype, promote_symtype, Symbolic +using MLStyle + +import DiagrammaticEquations: rules + +# load up some variable variables and expressions +ϐ, = @syms ϐ::InferredType # \varbeta +ℓ, = @syms ℓ::Literal +c, t = @syms c::Const t::Parameter +a, b = @syms a::Scalar b::Scalar +u, du = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} +ω, η = @syms ω::PrimalForm{1, :X, 2} η::DualForm{2, :X, 2} +h, = @syms h::PrimalForm{2, :X, 2} +ϕ, ψ = @syms ϕ::PrimalVF{:X, 2} ψ::DualVF{:X, 2} +# TODO would be nice to pass the space globally to avoid duplication + +u2, = @syms u2::PrimalForm{0, :Y, 2} +u3, = @syms u3::PrimalForm{0, :X, 3} + +@testset "Symtypes" begin + + @test symtype(ϐ) == InferredType + @test symtype(ℓ) == Literal + @test symtype(c) == Const + @test symtype(t) == Parameter + @test symtype(a) == Scalar + + @test symtype(u) == PrimalForm{0, :X, 2} + @test symtype(ω) == PrimalForm{1, :X, 2} + @test symtype(η) == DualForm{2, :X, 2} + @test symtype(ϕ) == PrimalVF{:X, 2} + @test symtype(ψ) == DualVF{:X, 2} + + @test symtype(c + t) == Scalar + @test symtype(t + t) == Scalar + @test symtype(c + c) == Scalar + @test symtype(t + ϐ) == InferredType + + @test symtype(u ∧ ω) == PrimalForm{1, :X, 2} + @test symtype(ω ∧ ω) == PrimalForm{2, :X, 2} + @test symtype(u ∧ ϐ) == InferredType + + # @test_throws ThDEC.SortError ThDEC.♯(u) + @test symtype(Δ(u) + Δ(u)) == PrimalForm{0, :X, 2} + + @test symtype(-(ϐ)) == InferredType + @test symtype(d(ϐ)) == InferredType + @test symtype(★(ϐ)) == InferredType + @test symtype(Δ(ϐ)) == InferredType + + @test symtype(ϐ + ϐ) == InferredType + @test symtype(ϐ + u) == InferredType + @test_throws OperatorError symtype(u + du + ϐ) + # The order of the addition counts when type checking, probably applies + # to other operations as well + @test_broken (try symtype(ϐ + u + du) catch e; e; end) isa Exception + + @test symtype(ϐ - u) == InferredType + @test symtype(ω * ϐ) == InferredType + @test symtype(ϐ ∧ ω) == InferredType +end + +@testset "Type Information" begin + + @test dim(symtype(η)) == 2 + + @test isdual(symtype(η)) + @test !isdual(symtype(u)) + + @test space(symtype(η)) == :X + @test spacedim(symtype(η)) == 2 + + @test isForm0(u) + @test !isForm0(du) + @test !isForm0(a) + + @test isForm1(ω) + @test !isForm1(u) + @test !isForm1(a) + + @test isForm2(η) + @test !isForm2(u) + @test !isForm2(a) + + @test isDualForm(η) + @test !isDualForm(u) + @test !isDualForm(a) +end + +@testset "Nameof" begin + @test nameof(symtype(c)) == :Constant + @test nameof(symtype(t)) == :Parameter + @test nameof(symtype(a)) == :Scalar + @test nameof(symtype(ℓ)) == :Literal + + @test nameof(symtype(u)) == :Form0 + @test nameof(symtype(ω)) == :Form1 + @test nameof(symtype(η)) == :DualForm2 + + @test nameof(-, symtype(u)) == Symbol("-") + @test nameof(-, symtype(u), symtype(u)) == Symbol("-") + @test nameof(-, symtype(a), symtype(b)) == Symbol("-") + + @test nameof(∧, symtype(u), symtype(u)) == Symbol("∧₀₀") + @test nameof(∧, symtype(u), symtype(ω)) == Symbol("∧₀₁") + @test nameof(∧, symtype(ω), symtype(u)) == Symbol("∧₁₀") + + # TODO: Do we need a special designation for wedges with duals in them? + @test nameof(∧, symtype(ω), symtype(η)) == Symbol("∧₁₂") + + @test nameof(∂ₜ, symtype(u)) == Symbol("∂ₜ") + @test nameof(∂ₜ, symtype(d(u))) == Symbol("∂ₜ") + + @test nameof(d, symtype(u)) == Symbol("d₀") + @test nameof(d, symtype(η)) == Symbol("dual_d₂") + + @test nameof(Δ, symtype(u)) == Symbol("Δ₀") + @test nameof(Δ, symtype(ω)) == Symbol("Δ₁") + + @test nameof(★, symtype(u)) == Symbol("★₀") + @test nameof(★, symtype(ω)) == Symbol("★₁") + @test nameof(★, symtype(η)) == Symbol("★₀⁻¹") +end + +@testset "Symtype Promotion" begin + # test promoting types + @test promote_symtype(d, u) == PrimalForm{1, :X, 2} + @test promote_symtype(+, a, b) == Scalar + @test promote_symtype(∧, u, u) == PrimalForm{0, :X, 2} + @test promote_symtype(∧, u, ω) == PrimalForm{1, :X, 2} + @test promote_symtype(-, a) == Scalar + @test promote_symtype(-, u, u) == PrimalForm{0, :X, 2} + + # test composition + @test promote_symtype(d ∘ d, u) == PrimalForm{2, :X, 2} +end + +@testset "Term Construction" begin + + # test unary operator conversion to decaexpr + @test Term(1) == Lit(Symbol("1")) + @test Term(a) == Var(:a) + @test Term(c) == Var(:c) + @test Term(t) == Var(:t) + @test Term(∂ₜ(u)) == Tan(Var(:u)) + @test_broken Term(∂ₜ(u)) == Term(DerivOp(u)) + + @test Term(★(ω)) == App1(:★₁, Var(:ω)) + @test Term(★(η)) == App1(:★₀⁻¹, Var(:η)) + + # The symbolics no longer captures ∘ + @test Term(∘(d, d)(u)) == App1(:d₁, App1(:d₀, Var(:u))) + + # test binary operator conversion to decaexpr + @test Term(a + b) == Plus(Term[Var(:a), Var(:b)]) + + # TODO: Currently parses as addition + @test_broken Term(a - b) == App2(:-, Var(:a), Var(:b)) + @test Term(a * b) == Mult(Term[Var(:a), Var(:b)]) + @test Term(ω ∧ du) == App2(:∧₁₁, Var(:ω), Var(:du)) + + @test Term(ω + du + d(u)) == Plus(Term[App1(:d₀, Var(:u)), Var(:du), Var(:ω)]) + + let + @syms f(x, y, z) + @test_throws "was unable to convert" Term(f(a, b, u)) + end + +end + + +# this is not nabla but "bizarro Δ" +del_expand_0, del_expand_1 = @operator ∇(S)::DECQuantity begin + @match S begin + PatScalar(_) => error("Argument of type $S is invalid") + PatForm(_) => promote_symtype(★ ∘ d ∘ ★ ∘ d, S) + end + @rule ∇(~x::isForm0) => ★(d(★(d(~x)))) + @rule ∇(~x::isForm1) => ★(d(★(d(~x)))) + d(★(d(★(~x)))) +end; + +# we will test is new operator +(r0, r1, r2) = @operator ρ(S)::DECQuantity begin + S <: Form ? Scalar : Form + @rule ρ(~x::isForm0) => 0 + @rule ρ(~x::isForm1) => 1 + @rule ρ(~x::isForm2) => 2 +end + +R, = @operator φ(S1, S2, S3)::DECQuantity begin + let T1=S1, T2=S2, T3=S3 + Scalar + end + @rule φ(2(~x::isForm0), 2(~y::isForm0), 2(~z::isForm0)) => 2*φ(~x,~y,~z) +end + +@alias (φ′,) => φ + +@testset "Operator definition" begin + + # ∇ + @test_throws Exception ∇(b) + @test symtype(∇(u)) == PrimalForm{0, :X ,2} + @test promote_symtype(∇, u) == PrimalForm{0, :X, 2} + @test isequal(del_expand_0(∇(u)), ★(d(★(d(u))))) + + # ρ + @test symtype(ρ(u)) == Scalar + + # R + @test isequal(R(φ(2u,2u,2u)), R(φ′(2u,2u,2u))) + # TODO we need to alias rewriting rules + +end + +@testset "Errors" begin + + # addition + @test_throws OperatorError u + du # mismatched grade + @test_throws OperatorError h + η # primal and dual + @test_throws OperatorError u + u2 # differing spaces + @test_throws OperatorError u + u3 # differing codimension + + # subtraction + @test_throws OperatorError u - du # mismatched grade + @test_throws OperatorError h - η # primal and dual + @test_throws OperatorError u - u2 # differing spaces + @test_throws OperatorError u - u3 # differing spatial dimension + + # exterior derivative + @test_throws OperatorError d(a) + @test_throws OperatorError d(ϕ) + + # hodge star + @test_throws OperatorError ★(a) + @test_throws OperatorError ★(ϕ) + + # Laplacian + @test_throws OperatorError Δ(a) + @test_throws OperatorError Δ(ϕ) + + # multiplication + @test_throws OperatorError u * du + @test_throws OperatorError ϕ * u + + @test_throws OperatorError du ∧ h # checks if spaces exceed dimension + @test_throws OperatorError a ∧ a # cannot take wedge of scalars + @test_throws OperatorError u ∧ ϕ # cannot take wedge of vector fields + +end + +@testset "Conversion" begin + + roundtrip(d::SummationDecapode) = SummationDecapode(DecaExpr(SymbolicContext(Term(d)))) + + just_vars = @decapode begin + u::Form0 + v::Form0 + end + @test just_vars == roundtrip(just_vars) + + with_tan = @decapode begin + u::Form0 + v::Form0 + ∂ₜ(v) == u + end + @test with_tan == roundtrip(with_tan) + + with_op2 = @decapode begin + u::Form0 + v::Form1 + w::Form1 + + w == ∧₀₁(u, v) + end + @test with_op2 == roundtrip(with_op2) + + with_mult = @decapode begin + u::Form1 + v::Form1 + + v == u * 2 + end + @test with_mult == roundtrip(with_mult) + + with_circ = @decapode begin + u::Form0 + v::Form2 + v == ∘(d₁, d₀)(u) + end + with_circ_expanded = @decapode begin + u::Form0 + v::Form2 + v == d₁(d₀(u)) + end + @test with_circ_expanded == roundtrip(with_circ) + + with_infers = @decapode begin + v::Form1 + + w == ∧(v, u) + end + # Base.nameof doesn't yet support taking InferredTypes + @test with_infers == roundtrip(with_infers) + + Heat = @decapode begin + u::Form0 + v::Form0 + κ::Constant + ∂ₜ(v) == Δ₀(u)*κ + end + infer_types!(Heat) + @test Heat == roundtrip(Heat) + + TumorInvasion = @decapode begin + (C,fC)::Form0 + (Dif,Kd,Cmax)::Constant + ∂ₜ(C) == Dif * Δ(C) + fC - C * Kd + end + infer_types!(TumorInvasion) + context = SymbolicContext(Term(TumorInvasion)) + TumorInvasion′ = SummationDecapode(DecaExpr(context)) + + # new terms introduced because Symbolics converts subtraction expressions + # e.g., a - b => +(a, -b) + @test_broken TumorInvasion == TumorInvasion′ + # TI' has (11, Literal, -1) and (12, infer, mult_1) + # Op1 (2, 1, 4, 7) should be (2, 4, 1, 7) + # Sum is (1, 6), (2, 10) + +end diff --git a/test/graph_traversal.jl b/test/graph_traversal.jl new file mode 100644 index 0000000..b09fd24 --- /dev/null +++ b/test/graph_traversal.jl @@ -0,0 +1,64 @@ +using DiagrammaticEquations +using ACSets +using MLStyle +using Test + +function is_correct_length(d::SummationDecapode, result) + return length(result) == n_ops(d) +end + +@testset "Topological Sort on Edges" begin + no_edge = @decapode begin + F == S + end + @test isempty(topological_sort_edges(no_edge)) + + one_op1_deca = @decapode begin + F == f(S) + end + result = topological_sort_edges(one_op1_deca) + @test is_correct_length(one_op1_deca, result) + @test retrieve_name(one_op1_deca, only(result)) == :f + + multi_op1_deca = @decapode begin + F == c(b(a(S))) + end + result = topological_sort_edges(multi_op1_deca) + @test is_correct_length(multi_op1_deca, result) + for (edge, test_name) in zip(result, [:a, :b, :c]) + @test retrieve_name(multi_op1_deca, edge) == test_name + end + + cyclic = @decapode begin + B == g(A) + A == f(B) + end + @test_throws AssertionError topological_sort_edges(cyclic) + + just_op2 = @decapode begin + C == A * B + end + result = topological_sort_edges(just_op2) + @test is_correct_length(just_op2, result) + @test retrieve_name(just_op2, only(result)) == :* + + just_simple_sum = @decapode begin + C == A + B + end + result = topological_sort_edges(just_simple_sum) + @test is_correct_length(just_simple_sum, result) + @test retrieve_name(just_simple_sum, only(result)) == :+ + + just_multi_sum = @decapode begin + F == A + B + C + D + E + end + result = topological_sort_edges(just_multi_sum) + @test is_correct_length(just_multi_sum, result) + @test retrieve_name(just_multi_sum, only(result)) == :+ + + op_combo = @decapode begin + F == h(d(A) + f(g(B) * C) + D) + end + result = topological_sort_edges(op_combo) + @test is_correct_length(op_combo, result) +end diff --git a/test/klausmeier.jl b/test/klausmeier.jl new file mode 100644 index 0000000..814cf73 --- /dev/null +++ b/test/klausmeier.jl @@ -0,0 +1,68 @@ +using DiagrammaticEquations +using DiagrammaticEquations.SymbolicUtilsInterop +# +using Test +using MLStyle +using SymbolicUtils +using SymbolicUtils: BasicSymbolic, symtype + +# See Klausmeier Equation 2.a +Hydrodynamics = @decapode begin + (n,w)::DualForm0 + dX::Form1 + (a,ν)::Constant + # + ∂ₜ(w) == a - w - w * n^2 + ν * L(dX, w) +end + +# See Klausmeier Equation 2.b +Phytodynamics = @decapode begin + (n,w)::DualForm0 + m::Constant + # + ∂ₜ(n) == w * n^2 - m*n + Δ(n) +end + +Hydrodynamics = parse_decapode(quote + (n,w)::DualForm0 + dX::Form1 + (a,ν)::Constant + # + ∂ₜ(w) == a - w - w + ν * L(dX, w) +end) + +# See Klausmeier Equation 2.b +Phytodynamics = parse_decapode(quote + (n,w)::Form0 + m::Constant + ∂ₜ(n) == w - m*n + Δ(n) +end) + +ps = SymbolicContext(Phytodynamics) +dexpr = DecaExpr(ps) +ps′ = SymbolicContext(dexpr) +# TODO variables are the same but the equations don't match + +n = ps.vars[1] +SymbolicUtils.symtype(n) +Δ(n) + +r, _ = rules(Δ, Val(1)); + +t2 = r(Δ(n)) +t2 |> dump + +using SymbolicUtils.Rewriters +using SymbolicUtils: promote_symtype +r = @rule ★(★(~n)) => ~n + +nested_star_cancel = Postwalk(Chain([r])) +nested_star_cancel(d(★(★(n)))) +nsc = nested_star_cancel + +@test isequal(nsc(★(★(d(n)))), d(n)) +dump(nsc(★(★(d(n))))) +dump(d(n)) +★(★(d(★(★(n))))) +nsc(★(★(d(★(★(n)))))) +nsc(nsc(★(★(d(★(★(n))))))) diff --git a/test/runtests.jl b/test/runtests.jl index 0cb0f30..c68bded 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,8 +2,6 @@ using Test include("pretty.jl") -include("aqua.jl") - @testset "Core" begin include("core.jl") end @@ -39,3 +37,14 @@ end @testset "Open Operators" begin include("openoperators.jl") end + +@testset "Symbolic Rewriting" begin + include("graph_traversal.jl") + include("acset2symbolic.jl") +end + +@testset "ThDEC Symbolics" begin + include("decasymbolic.jl") +end + +include("aqua.jl") diff --git a/todo.md b/todo.md deleted file mode 100644 index 351cc88..0000000 --- a/todo.md +++ /dev/null @@ -1,3 +0,0 @@ -# Refactoring Notes -Decapodey things still in DEq: -* DerivOp is referenced in src/language