From c2d64b580f8a478e0e37bbdd5966e587120dbaac Mon Sep 17 00:00:00 2001 From: Owen Lynch Date: Wed, 14 Aug 2024 16:46:13 -0700 Subject: [PATCH 01/30] symbolicutils interop --- Project.toml | 2 + src/DiagrammaticEquations.jl | 2 + src/ThDEC.jl | 270 +++++++++++++++++++++++++++++++++++ src/decasymbolic.jl | 194 +++++++++++++++++++++++++ 4 files changed, 468 insertions(+) create mode 100644 src/ThDEC.jl create mode 100644 src/decasymbolic.jl diff --git a/Project.toml b/Project.toml index 21b3563..8ad3aef 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ ACSets = "227ef7b5-1206-438b-ac65-934d6da304b8" Catlab = "134e5e36-593f-5add-ad60-77f754baafbe" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" +SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" [compat] @@ -16,5 +17,6 @@ ACSets = "0.2" Catlab = "0.15, 0.16" DataStructures = "0.18.13" MLStyle = "0.4.17" +SymbolicUtils = "3.1.2" Unicode = "1.6" julia = "1.6" diff --git a/src/DiagrammaticEquations.jl b/src/DiagrammaticEquations.jl index 0d0de25..d226390 100644 --- a/src/DiagrammaticEquations.jl +++ b/src/DiagrammaticEquations.jl @@ -61,6 +61,8 @@ include("colanguage.jl") include("openoperators.jl") include("deca/Deca.jl") include("learn/Learn.jl") +include("ThDEC.jl") +include("decasymbolic.jl") using .Deca diff --git a/src/ThDEC.jl b/src/ThDEC.jl new file mode 100644 index 0000000..52f5198 --- /dev/null +++ b/src/ThDEC.jl @@ -0,0 +1,270 @@ +module ThDEC +using MLStyle + +import Base: +, -, * + +struct SortError <: Exception + message::String +end + +@data Sort begin + Scalar() + Form(dim::Int, isdual::Bool) + VField(isdual::Bool) +end +export Sort, Scalar, Form, VField + +const SORT_LOOKUP = Dict( + :Form0 => Form(0, false), + :Form1 => Form(1, false), + :Form2 => Form(2, false), + :DualForm0 => Form(0, true), + :DualForm1 => Form(1, true), + :DualForm2 => Form(2, true), + :Constant => Scalar() +) + +function Base.nameof(s::Scalar) + :Constant +end + +function Base.nameof(f::Form) + dual = isdual(f) ? "Dual" : "" + Symbol("$(dual)Form$(dim(f))") +end + +const VF = VField + +dim(ω::Form) = ω.dim +isdual(ω::Form) = ω.isdual + +isdual(v::VField) = v.isdual + +# convenience functions +PrimalForm(i::Int) = Form(i, false) +export PrimalForm + +DualForm(i::Int) = Form(i, true) +export DualForm + +PrimalVF() = VF(false) +export PrimalVF + +DualVF() = VF(true) +export DualVF + +# show methods +show_duality(ω::Form) = isdual(ω) ? "dual" : "primal" + +function Base.show(io::IO, ω::Form) + print(io, isdual(ω) ? "DualForm($(dim(ω)))" : "PrimalForm($(dim(ω)))") +end + +@nospecialize +function +(s1::Sort, s2::Sort) + @match (s1, s2) begin + (Scalar(), Scalar()) => Scalar() + (Scalar(), Form(i, isdual)) || + (Form(i, isdual), Scalar()) => Form(i, isdual) + (Form(i1, isdual1), Form(i2, isdual2)) => + if (i1 == i2) && (isdual1 == isdual2) + Form(i1, isdual1) + else + throw(SortError("Cannot add two forms of different dimensions/dualities: $((i1,isdual1)) and $((i2,isdual2))")) + end + end +end + +# Type-checking inverse of addition follows addition +-(s1::Sort, s2::Sort) = +(s1, s2) + +# TODO error for Forms + +# Negation is always valid +-(s::Sort) = s + +@nospecialize +function *(s1::Sort, s2::Sort) + @match (s1, s2) begin + (Scalar(), Scalar()) => Scalar() + (Scalar(), Form(i, isdual)) || + (Form(i, isdual), Scalar()) => Form(i, isdual) + (Form(_, _), Form(_, _)) => throw(SortError("Cannot scalar multiply a form with a form. Maybe try `∧`??")) + end +end + +const SUBSCRIPT_DIGIT_0 = '₀' + +function as_sub(n::Int) + join(map(d -> SUBSCRIPT_DIGIT_0 + d, digits(n))) +end + +@nospecialize +function ∧(s1::Sort, s2::Sort) + @match (s1, s2) begin + (Form(i, isdual), Scalar()) || (Scalar(), Form(i, isdual)) => Form(i, isdual) + (Form(i1, isdual), Form(i2, isdual)) => + if i1 + i2 <= 2 + Form(i1 + i2, isdual) + else + throw(SortError("Can only take a wedge product when the dimensions of the forms add to less than 2: tried to wedge product $i1 and $i2")) + end + _ => throw(SortError("Can only take a wedge product of two forms of the same duality")) + end +end + +function Base.nameof(::typeof(∧), s1, s2) + Symbol("∧$(as_sub(dim(s1)))$(as_sub(dim(s2)))") +end + +@nospecialize +∂ₜ(s::Sort) = s + +@nospecialize +function d(s::Sort) + @match s begin + Scalar() => throw(SortError("Cannot take exterior derivative of a scalar")) + Form(i, isdual) => + if i <= 1 + Form(i + 1, isdual) + else + throw(SortError("Cannot take exterior derivative of a n-form for n >= 1")) + end + end +end + +function Base.nameof(::typeof(d), s) + Symbol("d$(as_sub(dim(s)))") +end + +@nospecialize +function ★(s::Sort) + @match s begin + Scalar() => throw(SortError("Cannot take Hodge star of a scalar")) + Form(i, isdual) => Form(2 - i, !isdual) + end +end + +function Base.nameof(::typeof(★), s) + inv = isdual(s) ? "⁻¹" : "" + Symbol("★$(as_sub(isdual(s) ? 2 - dim(s) : dim(s)))$(inv)") +end + +@nospecialize +function ι(s1::Sort, s2::Sort) + @match (s1, s2) begin + (VF(true), Form(i, true)) => PrimalForm() # wrong + (VF(true), Form(i, false)) => DualForm() + _ => throw(SortError("Can only define the discrete interior product on: + PrimalVF, DualForm(i) + DualVF(), PrimalForm(i) + .")) + end +end + +# in practice, a scalar may be treated as a constant 0-form. +function ♯(s::Sort) + @match s begin + Scalar() => PrimalVF() + Form(1, isdual) => VF(isdual) + _ => throw(SortError("Can only take ♯ to 1-forms")) + end +end +# musical isos may be defined for any combination of (primal/dual) form -> (primal/dual) vf. + +function ♭(s::Sort) + @match s begin + VF(true) => PrimalForm(1) + _ => throw(SortError("Can only apply ♭ to dual vector fields")) + end +end + +# OTHER + +function ♭♯(s::Sort) + @match s begin + Form(i, isdual) => Form(i, !isdual) + _ => throw(SortError("♭♯ is only defined on forms.")) + end +end + +# Δ = ★d⋆d, but we check signature here to throw a more helpful error +function Δ(s::Sort) + @match s begin + Form(0, isdual) => Form(0, isdual) + _ => throw(SortError("Δ is not defined for $s")) + end +end + +const OPERATOR_LOOKUP = Dict( + :⋆₀ => ★, + :⋆₁ => ★, + :⋆₂ => ★, + + # Inverse Hodge Stars + :⋆₀⁻¹ => ★, + :⋆₁⁻¹ => ★, + :⋆₂⁻¹ => ★, + + # Differentials + :d₀ => d, + :d₁ => d, + + # Dual Differentials + :dual_d₀ => d, + :d̃₀ => d, + :dual_d₁ => d, + :d̃₁ => d, + + # Wedge Products + :∧₀₁ => ∧, + :∧₁₀ => ∧, + :∧₀₂ => ∧, + :∧₂₀ => ∧, + :∧₁₁ => ∧, + + # Primal-Dual Wedge Products + :∧ᵖᵈ₁₁ => ∧, + :∧ᵖᵈ₀₁ => ∧, + :∧ᵈᵖ₁₁ => ∧, + :∧ᵈᵖ₁₀ => ∧, + + # Dual-Dual Wedge Products + :∧ᵈᵈ₁₁ => ∧, + :∧ᵈᵈ₁₀ => ∧, + :∧ᵈᵈ₀₁ => ∧, + + # Dual-Dual Interior Products + :ι₁₁ => ι, + :ι₁₂ => ι, + + # Dual-Dual Lie Derivatives + # :ℒ₁ => ℒ, + + # Dual Laplacians + # :Δᵈ₀ => Δ, + # :Δᵈ₁ => Δ, + + # Musical Isomorphisms + :♯ => ♯, + :♯ᵈ => ♯, :♭ => ♭, + + # Averaging Operator + # :avg₀₁ => avg, + + # Negatives + :neg => -, + + # Basics + + :- => -, + :+ => +, + :* => *, + :/ => /, + :.- => .-, + :.+ => .+, + :.* => .*, + :./ => ./, +) + +end diff --git a/src/decasymbolic.jl b/src/decasymbolic.jl new file mode 100644 index 0000000..3b86fa6 --- /dev/null +++ b/src/decasymbolic.jl @@ -0,0 +1,194 @@ +module SymbolicUtilInterop + +using ..ThDEC +using MLStyle +import ..ThDEC: Sort, dim, isdual +using ..decapodes +using SymbolicUtils +using SymbolicUtils: Symbolic, BasicSymbolic + +abstract type DECType <: Number end + +""" +i: dimension: 0,1,2, etc. +d: duality: true = dual, false = primal +""" +struct FormT{i,d} <: DECType +end + +struct VFieldT{d} <: DECType +end + +dim(::Type{<:FormT{d}}) where {d} = d +isdual(::Type{FormT{i,d}}) where {i,d} = d + +# convenience functions +const PrimalFormT{i} = FormT{i,false} +export PrimalFormT + +const DualFormT{i} = FormT{i,true} +export DualFormT + +const PrimalVFT = VFieldT{false} +export PrimalVFT + +const DualVFT = VFieldT{true} +export DualVFT + +function Sort(::Type{FormT{i,d}}) where {i,d} + Form(i, d) +end + +function Number(f::Form) + FormT{dim(f),isdual(f)} +end + +function Sort(::Type{VFieldT{d}}) where {d} + VField(d) +end + +function Number(v::VField) + VFieldT{isdual(v)} +end + +function Sort(::Type{<:Real}) + Scalar() +end + +function Number(s::Scalar) + Real +end + +function Sort(::BasicSymbolic{T}) where {T} + Sort(T) +end + +function Sort(::Real) + Scalar() +end + +unop_dec = [:∂ₜ, :d, :★, :♯, :♭, :-] +for unop in unop_dec + @eval begin + @nospecialize + function ThDEC.$unop( + v::BasicSymbolic{T} + ) where {T<:DECType} + s = ThDEC.$unop(Sort(T)) + SymbolicUtils.Term{Number(s)}(ThDEC.$unop, [v]) + end + end +end + +binop_dec = [:+, :-, :*, :∧] +for binop in binop_dec + @eval begin + @nospecialize + function ThDEC.$binop( + v::BasicSymbolic{T1}, + w::BasicSymbolic{T2} + ) where {T1<:DECType,T2<:DECType} + s = ThDEC.$binop(Sort(T1), Sort(T2)) + SymbolicUtils.Term{Number(s)}(ThDEC.$binop, [v, w]) + end + + @nospecialize + function ThDEC.$binop( + v::BasicSymbolic{T1}, + w::BasicSymbolic{T2} + ) where {T1<:DECType,T2<:Real} + s = ThDEC.$binop(Sort(T1), Sort(T2)) + SymbolicUtils.Term{Number(s)}(ThDEC.$binop, [v, w]) + end + + @nospecialize + function ThDEC.$binop( + v::BasicSymbolic{T1}, + w::BasicSymbolic{T2} + ) where {T1<:Real,T2<:DECType} + s = ThDEC.$binop(Sort(T1), Sort(T2)) + SymbolicUtils.Term{Number(s)}(ThDEC.$binop, [v, w]) + end + end +end + +struct Equation{E} + lhs::E + rhs::E +end + +struct DecaSymbolic + vars::Vector{Symbolic} + equations::Vector{Equation{Symbolic}} +end + +function decapodes.Term(t::SymbolicUtils.BasicSymbolic) + if SymbolicUtils.issym(t) + decapodes.Var(nameof(t)) + else + op = SymbolicUtils.head(t) + args = SymbolicUtils.arguments(t) + termargs = Term.(args) + sorts = ThDEC.Sort.(args) + if op == + + decapodes.Plus(termargs) + elseif op == * + decapodes.Mult(termargs) + elseif op == ThDEC.∂ₜ + decapodes.Tan(only(termargs)) + elseif length(args) == 1 + decapodes.App1(nameof(op, sorts...), termargs...) + elseif length(args) == 2 + decapodes.App2(nameof(op, sorts...), termargs...) + else + error("was unable to convert $t into a Term") + end + end +end + +function decapodes.Term(x::Real) + decapodes.Lit(Symbol(x)) +end + +function decapodes.DecaExpr(d::DecaSymbolic) + context = map(d.vars) do var + decapodes.Judgement(nameof(var), nameof(Sort(var)), :I) + end + equations = map(d.equations) do eq + decapodes.Eq(decapodes.Term(eq.lhs), decapodes.Term(eq.rhs)) + end + decapodes.DecaExpr(context, equations) +end + +function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,Sort}, t::decapodes.Term) + @match t begin + Var(name) => SymbolicUtils.Sym{Number(context[name])}(name) + Lit(v) => Meta.parse(string(v)) # YOLO + AppCirc1(fs, arg) => foldr( + (f, x) -> ThDEC.OPERATOR_LOOKUP[f](x), + fs; + init=BasicSymbolic(context, arg) + ) + App1(f, x) => ThDEC.OPERATOR_LOOKUP[f](BasicSymbolic(context, x)) + App2(f, x, y) => ThDEC.OPERATOR_LOOKUP[f](BasicSymbolic(context, x), BasicSymbolic(context, y)) + Plus(xs) => +(BasicSymbolic.(Ref(context), xs)...) + Mult(xs) => *(BasicSymbolic.(Ref(context), xs)...) + Tan(x) => ThDEC.∂ₜ(BasicSymbolic(context, x)) + end +end + +function DecaSymbolic(d::decapodes.DecaExpr) + context = map(d.context) do j + j.var => ThDEC.SORT_LOOKUP[j.dim] + end + vars = map(context) do (v, s) + SymbolicUtils.Sym{Number(s)}(v) + end + context = Dict{Symbol,Sort}(context) + eqs = map(d.equations) do eq + Equation{Symbolic}(BasicSymbolic(context, eq.lhs), BasicSymbolic(context, eq.rhs)) + end + DecaSymbolic(vars, eqs) +end + +end From c11d4dda9dca29fef83b718ee58918b94cf275b6 Mon Sep 17 00:00:00 2001 From: Owen Lynch Date: Fri, 16 Aug 2024 16:51:54 -0700 Subject: [PATCH 02/30] added spaces to the sorts of forms and vector fields --- src/ThDEC.jl | 103 ++++++++++++++++++++++++++++++++++----------------- 1 file changed, 70 insertions(+), 33 deletions(-) diff --git a/src/ThDEC.jl b/src/ThDEC.jl index 52f5198..d416cbf 100644 --- a/src/ThDEC.jl +++ b/src/ThDEC.jl @@ -7,22 +7,45 @@ struct SortError <: Exception message::String end +@struct_hash_equal struct Space + name::Symbol + dim::Int +end + +dim(s::Space) = s.dim +nameof(s::Space) = s.name + +struct SpaceLookup + default::Space + named::Dict{Symbol,Space} +end + @data Sort begin Scalar() - Form(dim::Int, isdual::Bool) - VField(isdual::Bool) + Form(dim::Int, isdual::Bool, space::Space) + VField(isdual::Bool, space::Space) end export Sort, Scalar, Form, VField -const SORT_LOOKUP = Dict( - :Form0 => Form(0, false), - :Form1 => Form(1, false), - :Form2 => Form(2, false), - :DualForm0 => Form(0, true), - :DualForm1 => Form(1, true), - :DualForm2 => Form(2, true), - :Constant => Scalar() -) +function fromexpr(lookup::SpaceLookup, e, ::Type{Sort}) + (name, spacename) = @match e begin + name::Symbol => (name, nothing) + :(Form0{$}) + end + space = @match spacename begin + ::Nothing => lookup.default + name::Symbol => lookup.named[name] + end + @match name begin + :Form0 => Form(0, false, space), + :Form1 => Form(1, false, space), + :Form2 => Form(2, false, space), + :DualForm0 => Form(0, true, space), + :DualForm1 => Form(1, true, space), + :DualForm2 => Form(2, true, space), + :Constant => Scalar() + end +end function Base.nameof(s::Scalar) :Constant @@ -30,15 +53,18 @@ end function Base.nameof(f::Form) dual = isdual(f) ? "Dual" : "" - Symbol("$(dual)Form$(dim(f))") + formname = Symbol("$(dual)Form$(dim(f))") + Expr(:curly, formname, dim(space(f))) end const VF = VField dim(ω::Form) = ω.dim isdual(ω::Form) = ω.isdual +space(ω::Form) = ω.space isdual(v::VField) = v.isdual +space(v::VField) = v.space # convenience functions PrimalForm(i::Int) = Form(i, false) @@ -60,17 +86,23 @@ function Base.show(io::IO, ω::Form) print(io, isdual(ω) ? "DualForm($(dim(ω)))" : "PrimalForm($(dim(ω)))") end +# TODO: VField @nospecialize function +(s1::Sort, s2::Sort) @match (s1, s2) begin (Scalar(), Scalar()) => Scalar() - (Scalar(), Form(i, isdual)) || - (Form(i, isdual), Scalar()) => Form(i, isdual) - (Form(i1, isdual1), Form(i2, isdual2)) => - if (i1 == i2) && (isdual1 == isdual2) + (Scalar(), Form(i, isdual, space)) || + (Form(i, isdual, space), Scalar()) => Form(i, isdual, space) + (Form(i1, isdual1, space2), Form(i2, isdual2, space2)) => + if (i1 == i2) && (isdual1 == isdual2) && (space1 == space2) Form(i1, isdual1) else - throw(SortError("Cannot add two forms of different dimensions/dualities: $((i1,isdual1)) and $((i2,isdual2))")) + throw(SortError( + """ + Can not add two forms of different dimensions/dualities/spaces: + $((i1,isdual1,space1)) and $((i2,isdual2,space2)) + """) + ) end end end @@ -83,13 +115,15 @@ end # Negation is always valid -(s::Sort) = s +# TODO: VField @nospecialize function *(s1::Sort, s2::Sort) @match (s1, s2) begin (Scalar(), Scalar()) => Scalar() - (Scalar(), Form(i, isdual)) || - (Form(i, isdual), Scalar()) => Form(i, isdual) - (Form(_, _), Form(_, _)) => throw(SortError("Cannot scalar multiply a form with a form. Maybe try `∧`??")) + (Scalar(), Form(i, isdual, space)) || + (Form(i, isdual, space), Scalar()) => Form(i, isdual) + (Form(_, _, _), Form(_, _, _)) => + throw(SortError("Cannot scalar multiply a form with a form. Maybe try `∧`??")) end end @@ -99,17 +133,20 @@ function as_sub(n::Int) join(map(d -> SUBSCRIPT_DIGIT_0 + d, digits(n))) end +# TODO: VField @nospecialize function ∧(s1::Sort, s2::Sort) @match (s1, s2) begin - (Form(i, isdual), Scalar()) || (Scalar(), Form(i, isdual)) => Form(i, isdual) - (Form(i1, isdual), Form(i2, isdual)) => - if i1 + i2 <= 2 - Form(i1 + i2, isdual) + (Form(i, isdual, space), Scalar()) || (Scalar(), Form(i, isdual, space)) => + Form(i, isdual, space) + (Form(i1, isdual, space), Form(i2, isdual, space)) => begin + if i1 + i2 <= dim(space) + Form(i1 + i2, isdual, space) else throw(SortError("Can only take a wedge product when the dimensions of the forms add to less than 2: tried to wedge product $i1 and $i2")) end - _ => throw(SortError("Can only take a wedge product of two forms of the same duality")) + end + _ => throw(SortError("Can only take a wedge product of two forms of the same duality on the same space")) end end @@ -124,11 +161,11 @@ end function d(s::Sort) @match s begin Scalar() => throw(SortError("Cannot take exterior derivative of a scalar")) - Form(i, isdual) => - if i <= 1 - Form(i + 1, isdual) + Form(i, isdual, space) => + if i < dim(space) + Form(i + 1, isdual, space) else - throw(SortError("Cannot take exterior derivative of a n-form for n >= 1")) + throw(SortError("Cannot take exterior derivative of a k-form for k >= n, where n = $(dim(space)) is the dimension of its ambient space")) end end end @@ -141,13 +178,13 @@ end function ★(s::Sort) @match s begin Scalar() => throw(SortError("Cannot take Hodge star of a scalar")) - Form(i, isdual) => Form(2 - i, !isdual) + Form(i, isdual, space) => Form(dim(space) - i, !isdual, space) end end function Base.nameof(::typeof(★), s) inv = isdual(s) ? "⁻¹" : "" - Symbol("★$(as_sub(isdual(s) ? 2 - dim(s) : dim(s)))$(inv)") + Symbol("★$(as_sub(isdual(s) ? dim(space(s)) - dim(s) : dim(s)))$(inv)") end @nospecialize @@ -183,7 +220,7 @@ end function ♭♯(s::Sort) @match s begin - Form(i, isdual) => Form(i, !isdual) + Form(i, isdual, space) => Form(i, !isdual, space) _ => throw(SortError("♭♯ is only defined on forms.")) end end @@ -191,7 +228,7 @@ end # Δ = ★d⋆d, but we check signature here to throw a more helpful error function Δ(s::Sort) @match s begin - Form(0, isdual) => Form(0, isdual) + Form(0, isdual, space) => Form(0, isdual, space) _ => throw(SortError("Δ is not defined for $s")) end end From 7de10e81403e54a8c93abcfc0e216d9e38f6b202 Mon Sep 17 00:00:00 2001 From: Owen Lynch Date: Fri, 16 Aug 2024 17:27:11 -0700 Subject: [PATCH 03/30] integrated changes to ThDEC with DecaSymbolic --- Project.toml | 2 ++ src/ThDEC.jl | 37 ++++++++++++++++++++----------------- src/decasymbolic.jl | 28 ++++++++++++++++------------ 3 files changed, 38 insertions(+), 29 deletions(-) diff --git a/Project.toml b/Project.toml index 8ad3aef..d2d57ce 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ ACSets = "227ef7b5-1206-438b-ac65-934d6da304b8" Catlab = "134e5e36-593f-5add-ad60-77f754baafbe" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" +StructEquality = "6ec83bb0-ed9f-11e9-3b4c-2b04cb4e219c" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" @@ -17,6 +18,7 @@ ACSets = "0.2" Catlab = "0.15, 0.16" DataStructures = "0.18.13" MLStyle = "0.4.17" +StructEquality = "2.1.0" SymbolicUtils = "3.1.2" Unicode = "1.6" julia = "1.6" diff --git a/src/ThDEC.jl b/src/ThDEC.jl index d416cbf..9e5f3e2 100644 --- a/src/ThDEC.jl +++ b/src/ThDEC.jl @@ -1,5 +1,6 @@ module ThDEC using MLStyle +using StructEquality import Base: +, -, * @@ -11,9 +12,10 @@ end name::Symbol dim::Int end +export Space dim(s::Space) = s.dim -nameof(s::Space) = s.name +Base.nameof(s::Space) = s.name struct SpaceLookup default::Space @@ -30,19 +32,19 @@ export Sort, Scalar, Form, VField function fromexpr(lookup::SpaceLookup, e, ::Type{Sort}) (name, spacename) = @match e begin name::Symbol => (name, nothing) - :(Form0{$}) + :($name{$spacename}) => (name, spacename) end space = @match spacename begin ::Nothing => lookup.default name::Symbol => lookup.named[name] end @match name begin - :Form0 => Form(0, false, space), - :Form1 => Form(1, false, space), - :Form2 => Form(2, false, space), - :DualForm0 => Form(0, true, space), - :DualForm1 => Form(1, true, space), - :DualForm2 => Form(2, true, space), + :Form0 => Form(0, false, space) + :Form1 => Form(1, false, space) + :Form2 => Form(2, false, space) + :DualForm0 => Form(0, true, space) + :DualForm1 => Form(1, true, space) + :DualForm2 => Form(2, true, space) :Constant => Scalar() end end @@ -62,6 +64,7 @@ const VF = VField dim(ω::Form) = ω.dim isdual(ω::Form) = ω.isdual space(ω::Form) = ω.space +export space isdual(v::VField) = v.isdual space(v::VField) = v.space @@ -83,7 +86,7 @@ export DualVF show_duality(ω::Form) = isdual(ω) ? "dual" : "primal" function Base.show(io::IO, ω::Form) - print(io, isdual(ω) ? "DualForm($(dim(ω)))" : "PrimalForm($(dim(ω)))") + print(io, isdual(ω) ? "DualForm($(dim(ω))) on $(space(ω))" : "PrimalForm($(dim(ω))) on $(space(ω))") end # TODO: VField @@ -93,7 +96,7 @@ function +(s1::Sort, s2::Sort) (Scalar(), Scalar()) => Scalar() (Scalar(), Form(i, isdual, space)) || (Form(i, isdual, space), Scalar()) => Form(i, isdual, space) - (Form(i1, isdual1, space2), Form(i2, isdual2, space2)) => + (Form(i1, isdual1, space1), Form(i2, isdual2, space2)) => if (i1 == i2) && (isdual1 == isdual2) && (space1 == space2) Form(i1, isdual1) else @@ -139,14 +142,14 @@ function ∧(s1::Sort, s2::Sort) @match (s1, s2) begin (Form(i, isdual, space), Scalar()) || (Scalar(), Form(i, isdual, space)) => Form(i, isdual, space) - (Form(i1, isdual, space), Form(i2, isdual, space)) => begin + (Form(i1, isdual1, space1), Form(i2, isdual2, space2)) => begin + (isdual1 == isdual2) && (space1 == space2) || throw(SortError("Can only take a wedge product of two forms of the same duality on the same space")) if i1 + i2 <= dim(space) Form(i1 + i2, isdual, space) else - throw(SortError("Can only take a wedge product when the dimensions of the forms add to less than 2: tried to wedge product $i1 and $i2")) + throw(SortError("Can only take a wedge product when the dimensions of the forms add to less than n, where n = $(dim(space)) is the dimension of the ambient space: tried to wedge product $i1 and $i2")) end end - _ => throw(SortError("Can only take a wedge product of two forms of the same duality on the same space")) end end @@ -190,8 +193,8 @@ end @nospecialize function ι(s1::Sort, s2::Sort) @match (s1, s2) begin - (VF(true), Form(i, true)) => PrimalForm() # wrong - (VF(true), Form(i, false)) => DualForm() + (VF(true, space), Form(i, true, space)) => PrimalForm() # wrong + (VF(true, space), Form(i, false, space)) => DualForm() _ => throw(SortError("Can only define the discrete interior product on: PrimalVF, DualForm(i) DualVF(), PrimalForm(i) @@ -203,7 +206,7 @@ end function ♯(s::Sort) @match s begin Scalar() => PrimalVF() - Form(1, isdual) => VF(isdual) + Form(1, isdual, space) => VF(isdual, space) _ => throw(SortError("Can only take ♯ to 1-forms")) end end @@ -211,7 +214,7 @@ end function ♭(s::Sort) @match s begin - VF(true) => PrimalForm(1) + VF(true, space) => PrimalForm(1, false, space) _ => throw(SortError("Can only apply ♭ to dual vector fields")) end end diff --git a/src/decasymbolic.jl b/src/decasymbolic.jl index 3b86fa6..5ccd25b 100644 --- a/src/decasymbolic.jl +++ b/src/decasymbolic.jl @@ -12,43 +12,47 @@ abstract type DECType <: Number end """ i: dimension: 0,1,2, etc. d: duality: true = dual, false = primal +s: name of the space (a symbol) +n: dimension of the space """ -struct FormT{i,d} <: DECType +struct FormT{i,d,s,n} <: DECType end +export FormT -struct VFieldT{d} <: DECType +struct VFieldT{d,s,n} <: DECType end +export VFieldT dim(::Type{<:FormT{d}}) where {d} = d isdual(::Type{FormT{i,d}}) where {i,d} = d # convenience functions -const PrimalFormT{i} = FormT{i,false} +const PrimalFormT{i,s,n} = FormT{i,false,s,n} export PrimalFormT -const DualFormT{i} = FormT{i,true} +const DualFormT{i,s,n} = FormT{i,true,s,n} export DualFormT -const PrimalVFT = VFieldT{false} +const PrimalVFT{s,n} = VFieldT{false,s,n} export PrimalVFT -const DualVFT = VFieldT{true} +const DualVFT{s,n} = VFieldT{true,s,n} export DualVFT -function Sort(::Type{FormT{i,d}}) where {i,d} - Form(i, d) +function Sort(::Type{FormT{i,d,s,n}}) where {i,d,s,n} + Form(i, d, Space(s, n)) end function Number(f::Form) - FormT{dim(f),isdual(f)} + FormT{dim(f),isdual(f), nameof(space(f)), dim(space(f))} end -function Sort(::Type{VFieldT{d}}) where {d} - VField(d) +function Sort(::Type{VFieldT{d,s,n}}) where {d,s,n} + VField(d, Space(s, n)) end function Number(v::VField) - VFieldT{isdual(v)} + VFieldT{isdual(v), nameof(space(v)), dim(space(v))} end function Sort(::Type{<:Real}) From 5928cc78474d77254dfa73f5cae5e96247795d32 Mon Sep 17 00:00:00 2001 From: Matt Date: Sun, 18 Aug 2024 12:56:22 -0400 Subject: [PATCH 04/30] added tests for decasymbolic, but needs wrinkles ironed out. musical isos also given placeholder nameof methods --- src/DiagrammaticEquations.jl | 37 ++++---------- src/ThDEC.jl | 13 +++++ src/decasymbolic.jl | 99 ++++++++++++++++++++++-------------- test/decasymbolic.jl | 85 +++++++++++++++++++++++++++++++ 4 files changed, 170 insertions(+), 64 deletions(-) create mode 100644 test/decasymbolic.jl diff --git a/src/DiagrammaticEquations.jl b/src/DiagrammaticEquations.jl index d226390..cb483b9 100644 --- a/src/DiagrammaticEquations.jl +++ b/src/DiagrammaticEquations.jl @@ -3,43 +3,23 @@ module DiagrammaticEquations export -DerivOp, append_dot, normalize_unicode, infer_states, infer_types!, -# Deca -op1_res_rules_1D, op2_res_rules_1D, op1_res_rules_2D, op2_res_rules_2D, -op1_inf_rules_1D, op2_inf_rules_1D, op1_inf_rules_2D, op2_inf_rules_2D, -recursive_delete_parents, spacename, varname, unicode!, vec_to_dec!, -## collages -Collage, collate, -## composition -oapply, unique_by, unique_by!, OpenSummationDecapodeOb, OpenSummationDecapode, Open, default_composition_diagram, -## acset -SchDecapode, SchNamedDecapode, AbstractDecapode, AbstractNamedDecapode, NamedDecapode, SummationDecapode, -contract_operators!, contract_operators, add_constant!, add_parameter, fill_names!, dot_rename!, is_expanded, expand_operators, infer_state_names, infer_terminal_names, recognize_types, -resolve_overloads!, replace_names!, -apply_inference_rule_op1!, apply_inference_rule_op2!, -transfer_parents!, transfer_children!, -unique_lits!, -## language -@decapode, Term, parse_decapode, term, Eq, DecaExpr, -# ~~~~~ -Plus, AppCirc1, Var, Tan, App1, App2, -## visualization -to_graphviz_property_graph, typename, draw_composition, -## rewrite -average_rewrite, -## openoperators -transfer_parents!, transfer_children!, replace_op1!, replace_op2!, replace_all_op1s!, replace_all_op2s! +DerivOp, append_dot, normalize_unicode, + +## intertypes +SchDecapode, SchNamedDecapode, AbstractDecapode, AbstractNamedDecapode, NamedDecapode, DecaExpr, Plus, AppCirc1, Var, Tan, App1, App2, Eq using Catlab using Catlab.Theories import Catlab.Theories: otimes, oplus, compose, ⊗, ⊕, ⋅, associate, associate_unit, Ob, Hom, dom, codom using Catlab.Programs using Catlab.CategoricalAlgebra +import Catlab.CategoricalAlgebra: ∧ using Catlab.WiringDiagrams using Catlab.WiringDiagrams.DirectedWiringDiagrams using Catlab.ACSetInterface using MLStyle import Unicode +using Reexport ## TODO: ## generate schema from a _theory_ @@ -64,6 +44,9 @@ include("learn/Learn.jl") include("ThDEC.jl") include("decasymbolic.jl") -using .Deca +@reexport using .ThDEC +@reexport using .SymbolicUtilsInterop +@reexport using .Deca + end diff --git a/src/ThDEC.jl b/src/ThDEC.jl index 52f5198..6b16e41 100644 --- a/src/ThDEC.jl +++ b/src/ThDEC.jl @@ -1,4 +1,5 @@ module ThDEC + using MLStyle import Base: +, -, * @@ -142,6 +143,7 @@ function ★(s::Sort) @match s begin Scalar() => throw(SortError("Cannot take Hodge star of a scalar")) Form(i, isdual) => Form(2 - i, !isdual) + VF(isdual) => throw(SortError("Cannot take the Hodge star of a vector field")) end end @@ -172,6 +174,12 @@ function ♯(s::Sort) end # musical isos may be defined for any combination of (primal/dual) form -> (primal/dual) vf. +# TODO +function Base.nameof(::typeof(♯), s) + Symbol("♯s") +end + + function ♭(s::Sort) @match s begin VF(true) => PrimalForm(1) @@ -179,6 +187,11 @@ function ♭(s::Sort) end end +# TODO +function Base.nameof(::typeof(♭), s) + Symbol("♭s") +end + # OTHER function ♭♯(s::Sort) diff --git a/src/decasymbolic.jl b/src/decasymbolic.jl index 3b86fa6..75e8b62 100644 --- a/src/decasymbolic.jl +++ b/src/decasymbolic.jl @@ -1,23 +1,29 @@ -module SymbolicUtilInterop +module SymbolicUtilsInterop using ..ThDEC using MLStyle import ..ThDEC: Sort, dim, isdual using ..decapodes using SymbolicUtils + using SymbolicUtils: Symbolic, BasicSymbolic +# ########################## +# DECType +# +# Type necessary for symbolic utils +# ########################## + +# define DECType as a Number. Necessary for SymbolicUtils abstract type DECType <: Number end """ i: dimension: 0,1,2, etc. d: duality: true = dual, false = primal """ -struct FormT{i,d} <: DECType -end +struct FormT{i,d} <: DECType end -struct VFieldT{d} <: DECType -end +struct VFieldT{d} <: DECType end dim(::Type{<:FormT{d}}) where {d} = d isdual(::Type{FormT{i,d}}) where {i,d} = d @@ -35,38 +41,30 @@ export PrimalVFT const DualVFT = VFieldT{true} export DualVFT -function Sort(::Type{FormT{i,d}}) where {i,d} - Form(i, d) -end +# convert Real to DecType +Sort(::Type{<:Real}) = Scalar() -function Number(f::Form) - FormT{dim(f),isdual(f)} -end +# convert Real to ThDEC +Sort(::Real) = Scalar() -function Sort(::Type{VFieldT{d}}) where {d} - VField(d) -end +# convert DECType to ThDEC +Sort(::Type{FormT{i,d}}) where {i,d} = Form(i, d) -function Number(v::VField) - VFieldT{isdual(v)} -end +# convert DECType to ThDEC +Sort(::Type{VFieldT{d}}) where {d} = VField(d) -function Sort(::Type{<:Real}) - Scalar() -end +Sort(::BasicSymbolic{T}) where {T} = Sort(T) -function Number(s::Scalar) - Real -end +# convert Form to DECType +Number(f::Form) = FormT{dim(f), isdual(f)} -function Sort(::BasicSymbolic{T}) where {T} - Sort(T) -end +# convert VField to DECType +Number(v::VField) = VFieldT{isdual(v)} -function Sort(::Real) - Scalar() -end +# convert number to real +Number(s::Scalar) = Real +# for every unary operator in our theory, take a BasicSymbolic type, convert its type parameter to a Sort in our theory, and return a term unop_dec = [:∂ₜ, :d, :★, :♯, :♭, :-] for unop in unop_dec @eval begin @@ -74,7 +72,10 @@ for unop in unop_dec function ThDEC.$unop( v::BasicSymbolic{T} ) where {T<:DECType} + # convert the DECType to ThDEC to type check s = ThDEC.$unop(Sort(T)) + # the resulting type is converted back to DECType + # the resulting term has the operation has its head and `v` as its args. SymbolicUtils.Term{Number(s)}(ThDEC.$unop, [v]) end end @@ -91,6 +92,7 @@ for binop in binop_dec s = ThDEC.$binop(Sort(T1), Sort(T2)) SymbolicUtils.Term{Number(s)}(ThDEC.$binop, [v, w]) end + export $binop @nospecialize function ThDEC.$binop( @@ -100,6 +102,7 @@ for binop in binop_dec s = ThDEC.$binop(Sort(T1), Sort(T2)) SymbolicUtils.Term{Number(s)}(ThDEC.$binop, [v, w]) end + export $binop @nospecialize function ThDEC.$binop( @@ -109,19 +112,25 @@ for binop in binop_dec s = ThDEC.$binop(Sort(T1), Sort(T2)) SymbolicUtils.Term{Number(s)}(ThDEC.$binop, [v, w]) end + export $binop end end -struct Equation{E} +# name collision with decapodes.Equation +struct DecaEquation{E} lhs::E rhs::E end +export DecaEquation +# a struct carry the symbolic variables and their equations struct DecaSymbolic vars::Vector{Symbolic} - equations::Vector{Equation{Symbolic}} + equations::Vector{DecaEquation{Symbolic}} end +export DecaSymbolic +# BasicSymbolic -> DecaExpr function decapodes.Term(t::SymbolicUtils.BasicSymbolic) if SymbolicUtils.issym(t) decapodes.Var(nameof(t)) @@ -146,13 +155,13 @@ function decapodes.Term(t::SymbolicUtils.BasicSymbolic) end end -function decapodes.Term(x::Real) - decapodes.Lit(Symbol(x)) -end +decapodes.Term(x::Real) = decapodes.Lit(Symbol(x)) function decapodes.DecaExpr(d::DecaSymbolic) context = map(d.vars) do var - decapodes.Judgement(nameof(var), nameof(Sort(var)), :I) + # TODO changed :I to :X to make tests pass, but discussion + # needed on handling spaces + decapodes.Judgement(nameof(var), nameof(Sort(var)), :X) end equations = map(d.equations) do eq decapodes.Eq(decapodes.Term(eq.lhs), decapodes.Term(eq.rhs)) @@ -160,11 +169,25 @@ function decapodes.DecaExpr(d::DecaSymbolic) decapodes.DecaExpr(context, equations) end +""" +Retrieve the SymbolicUtils expression of a DecaExpr term `t` from a context of variables in ThDEC + +Example: +``` + a = @syms a::Real + context = Dict(:a => Scalar(), :u => PrimalForm(0)) + SymbolicUtils.BasicSymbolic(context, Term(a)) +``` +""" function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,Sort}, t::decapodes.Term) @match t begin Var(name) => SymbolicUtils.Sym{Number(context[name])}(name) - Lit(v) => Meta.parse(string(v)) # YOLO + Lit(v) => Meta.parse(string(v)) # TODO no YOLO + # see heat_eq test: eqs had AppCirc1, but this returns + # App1(f, App1(...) AppCirc1(fs, arg) => foldr( + # panics with constants like :k + # see test/language.jl (f, x) -> ThDEC.OPERATOR_LOOKUP[f](x), fs; init=BasicSymbolic(context, arg) @@ -178,15 +201,17 @@ function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,Sort}, t::decapodes.Te end function DecaSymbolic(d::decapodes.DecaExpr) + # associates each var to its sort... context = map(d.context) do j j.var => ThDEC.SORT_LOOKUP[j.dim] end + # ... which we then produce a vector of symbolic vars vars = map(context) do (v, s) SymbolicUtils.Sym{Number(s)}(v) end context = Dict{Symbol,Sort}(context) eqs = map(d.equations) do eq - Equation{Symbolic}(BasicSymbolic(context, eq.lhs), BasicSymbolic(context, eq.rhs)) + DecaEquation{Symbolic}(BasicSymbolic.(Ref(context), [eq.lhs, eq.rhs])...) end DecaSymbolic(vars, eqs) end diff --git a/test/decasymbolic.jl b/test/decasymbolic.jl new file mode 100644 index 0000000..0e0f2b6 --- /dev/null +++ b/test/decasymbolic.jl @@ -0,0 +1,85 @@ +using Test +# +using DiagrammaticEquations +using DiagrammaticEquations.ThDEC +using DiagrammaticEquations.decapodes +# +using SymbolicUtils + +@testset "ThDEC Signature checking" begin + @test Scalar() + Scalar() == Scalar() +end + +# load up some variable variables and expressions +a, b = @syms a::Real b::Real +u, v = @syms u::PrimalFormT{0} du::PrimalFormT{1} +ω, η = @syms ω::PrimalFormT{1} η::DualFormT{2} +ϕ, ψ = @syms ϕ::PrimalVFT ψ::DualVFT + +expr_scalar_addition = a + b +expr_primal_wedge = ThDEC.:∧(ω, du) + +@testset "Term Construction" begin + + # test conversion to underlying type + @test Sort(a) == Scalar() + @test Sort(u) == PrimalForm(0) + @test Sort(ω) == PrimalForm(1) + @test Sort(η) == DualForm(2) + @test Sort(ϕ) == PrimalVF() + @test Sort(ψ) == DualVF() + + @test_throws ThDEC.SortError ThDEC.♯(u) + + # test unary operator conversion to decaexpr + @test Term(1) == DiagrammaticEquations.decapodes.Lit(Symbol("1")) + @test Term(a) == Var(:a) + @test Term(ThDEC.∂ₜ(u)) == Tan(Var(:u)) + @test Term(ThDEC.★(ω)) == App1(:★₁, Var(:ω)) + @test Term(ThDEC.♭(ψ)) == App1(:♭s, Var(:ψ)) + # @test Term(DiagrammaticEquations.ThDEC.♯(du)) + + @test_throws ThDEC.SortError ThDEC.★(ϕ) + + # test binary operator conversion to decaexpr + @test Term(a + b) == Plus(Term[Var(:a), Var(:b)]) + @test Term(a * b) == DiagrammaticEquations.decapodes.Mult(Term[Var(:a), Var(:b)]) + @test Term(ThDEC.:∧(ω, du)) == App2(:∧₁₁, Var(:ω), Var(:du)) + +end + +@testset "Moving between DecaExpr and DecaSymbolic" begin end + +context = Dict(:a => Scalar(), :b => Scalar() + ,:u => PrimalForm(0), :du => PrimalForm(1)) + +js = [Judgement(:u, :Form0, :X) + ,Judgement(:∂ₜu, :Form0, :X) + ,Judgement(:Δu, :Form0, :X)] +eqs = [Eq(Var(:∂ₜu), AppCirc1([:⋆₂⁻¹, :d₁, :⋆₁, :d₀], Var(:u))) + ,Eq(Tan(Var(:u)), Var(:∂ₜu))] +heat_eq = DecaExpr(js, eqs) + + +symb_heat_eq = DecaSymbolic(heat_eq) +deca_expr = DecaExpr(symb_heat_eq) + +@test js == deca_expr.context + +# eqs in the left has AppCirc1[vector, term] +# deca_expr.equations on the right has nested App1 +# expected behavior is that nested AppCirc1 is preserved +@test_broken eqs == deca_expr.equations + +# copied from test/language + js = [Judgement(:C, :Form0, :X), + Judgement(:Ċ₁, :Form0, :X), + Judgement(:Ċ₂, :Form0, :X) + ] + # TODO: Do we need to handle the fact that all the functions are parameterized by a space? + eqs = [Eq(Var(:Ċ₁), AppCirc1([:⋆₀⁻¹, :dual_d₁, :⋆₁, :k, :d₀], Var(:C))), + Eq(Var(:Ċ₂), AppCirc1([:⋆₀⁻¹, :dual_d₁, :⋆₁, :d₀], Var(:C))), + Eq(Tan(Var(:C)), Plus([Var(:Ċ₁), Var(:Ċ₂)])) + ] + diffusion_d = DecaExpr(js, eqs) + From d7d6a5b7ee2dfa69fb486edaf47d1f1a91a94182 Mon Sep 17 00:00:00 2001 From: Matt Date: Sun, 18 Aug 2024 23:25:50 -0400 Subject: [PATCH 05/30] reverted src/DiagX to restore exports and adding Project.toml --- Project.toml | 2 ++ src/DiagrammaticEquations.jl | 31 ++++++++++++++++++++++++++----- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 8ad3aef..a17f94a 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ ACSets = "227ef7b5-1206-438b-ac65-934d6da304b8" Catlab = "134e5e36-593f-5add-ad60-77f754baafbe" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" +Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" @@ -19,4 +20,5 @@ DataStructures = "0.18.13" MLStyle = "0.4.17" SymbolicUtils = "3.1.2" Unicode = "1.6" +Reexport = "1.2.2" julia = "1.6" diff --git a/src/DiagrammaticEquations.jl b/src/DiagrammaticEquations.jl index cb483b9..d0ef187 100644 --- a/src/DiagrammaticEquations.jl +++ b/src/DiagrammaticEquations.jl @@ -3,10 +3,32 @@ module DiagrammaticEquations export -DerivOp, append_dot, normalize_unicode, - -## intertypes -SchDecapode, SchNamedDecapode, AbstractDecapode, AbstractNamedDecapode, NamedDecapode, DecaExpr, Plus, AppCirc1, Var, Tan, App1, App2, Eq +DerivOp, append_dot, normalize_unicode, infer_states, infer_types!, +# Deca +op1_res_rules_1D, op2_res_rules_1D, op1_res_rules_2D, op2_res_rules_2D, +op1_inf_rules_1D, op2_inf_rules_1D, op1_inf_rules_2D, op2_inf_rules_2D, +recursive_delete_parents, spacename, varname, unicode!, vec_to_dec!, +## collages +Collage, collate, +## composition +oapply, unique_by, unique_by!, OpenSummationDecapodeOb, OpenSummationDecapode, Open, default_composition_diagram, +## acset +SchDecapode, SchNamedDecapode, AbstractDecapode, AbstractNamedDecapode, NamedDecapode, SummationDecapode, +contract_operators!, contract_operators, add_constant!, add_parameter, fill_names!, dot_rename!, is_expanded, expand_operators, infer_state_names, infer_terminal_names, recognize_types, +resolve_overloads!, replace_names!, +apply_inference_rule_op1!, apply_inference_rule_op2!, +transfer_parents!, transfer_children!, +unique_lits!, +## language +@decapode, Term, parse_decapode, term, Eq, DecaExpr, +# ~~~~~ +Plus, AppCirc1, Var, Tan, App1, App2, +## visualization +to_graphviz_property_graph, typename, draw_composition, +## rewrite +average_rewrite, +## openoperators +transfer_parents!, transfer_children!, replace_op1!, replace_op2!, replace_all_op1s!, replace_all_op2s! using Catlab using Catlab.Theories @@ -48,5 +70,4 @@ include("decasymbolic.jl") @reexport using .SymbolicUtilsInterop @reexport using .Deca - end From a438192cb4452035586830f13c7a50bdb52d89d6 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 19 Aug 2024 15:54:04 -0400 Subject: [PATCH 06/30] updated code and tests after merge from space-sorts --- src/ThDEC.jl | 35 +++++++++++--------- src/decasymbolic.jl | 22 ++++++------ test/decasymbolic.jl | 79 ++++++++++++++++++++------------------------ 3 files changed, 65 insertions(+), 71 deletions(-) diff --git a/src/ThDEC.jl b/src/ThDEC.jl index c913cd8..d3e1c3b 100644 --- a/src/ThDEC.jl +++ b/src/ThDEC.jl @@ -22,6 +22,7 @@ struct SpaceLookup default::Space named::Dict{Symbol,Space} end +export SpaceLookup @data Sort begin Scalar() @@ -50,14 +51,16 @@ function fromexpr(lookup::SpaceLookup, e, ::Type{Sort}) end end -function Base.nameof(s::Scalar) - :Constant -end +Base.nameof(s::Scalar) = :Constant -function Base.nameof(f::Form) +function Base.nameof(f::Form; with_dim_parameter=false) dual = isdual(f) ? "Dual" : "" formname = Symbol("$(dual)Form$(dim(f))") - Expr(:curly, formname, dim(space(f))) + if with_dim_parameter + return Expr(:curly, formname, dim(space(f))) + else + return formname + end end const VF = VField @@ -71,16 +74,16 @@ isdual(v::VField) = v.isdual space(v::VField) = v.space # convenience functions -PrimalForm(i::Int) = Form(i, false) +PrimalForm(i::Int, space::Space) = Form(i, false, space) export PrimalForm -DualForm(i::Int) = Form(i, true) +DualForm(i::Int, space::Space) = Form(i, true, space) export DualForm -PrimalVF() = VF(false) +PrimalVF(space::Space) = VF(false, space) export PrimalVF -DualVF() = VF(true) +DualVF(space::Space) = VF(true, space) export DualVF # show methods @@ -145,10 +148,10 @@ function ∧(s1::Sort, s2::Sort) Form(i, isdual, space) (Form(i1, isdual1, space1), Form(i2, isdual2, space2)) => begin (isdual1 == isdual2) && (space1 == space2) || throw(SortError("Can only take a wedge product of two forms of the same duality on the same space")) - if i1 + i2 <= dim(space) - Form(i1 + i2, isdual, space) + if i1 + i2 <= dim(space1) + Form(i1 + i2, isdual1, space1) else - throw(SortError("Can only take a wedge product when the dimensions of the forms add to less than n, where n = $(dim(space)) is the dimension of the ambient space: tried to wedge product $i1 and $i2")) + throw(SortError("Can only take a wedge product when the dimensions of the forms add to less than n, where n = $(dim(space1)) is the dimension of the ambient space: tried to wedge product $i1 and $i2")) end end end @@ -182,7 +185,7 @@ end function ★(s::Sort) @match s begin Scalar() => throw(SortError("Cannot take Hodge star of a scalar")) - VF(isdual) => throw(SortError("Cannot take the Hodge star of a vector field")) + VF(isdual, space) => throw(SortError("Cannot take the Hodge star of a vector field")) Form(i, isdual, space) => Form(dim(space) - i, !isdual, space) end end @@ -195,8 +198,8 @@ end @nospecialize function ι(s1::Sort, s2::Sort) @match (s1, s2) begin - (VF(true, space), Form(i, true, space)) => PrimalForm() # wrong - (VF(true, space), Form(i, false, space)) => DualForm() + (VF(true, space), Form(i, true, space)) => Form(i, false, space) # wrong + (VF(true, space), Form(i, false, space)) => DualForm(i, true, space) _ => throw(SortError("Can only define the discrete interior product on: PrimalVF, DualForm(i) DualVF(), PrimalForm(i) @@ -222,7 +225,7 @@ end function ♭(s::Sort) @match s begin - VF(true, space) => PrimalForm(1, false, space) + VF(true, space) => Form(1, false, space) _ => throw(SortError("Can only apply ♭ to dual vector fields")) end end diff --git a/src/decasymbolic.jl b/src/decasymbolic.jl index de291ac..e73ecd9 100644 --- a/src/decasymbolic.jl +++ b/src/decasymbolic.jl @@ -29,8 +29,8 @@ export FormT struct VFieldT{d,s,n} <: DECType end export VFieldT -dim(::Type{<:FormT{d}}) where {d} = d -isdual(::Type{FormT{i,d}}) where {i,d} = d +dim(::Type{<:FormT{i,d,s,n}}) where {i,d,s,n} = d +isdual(::Type{FormT{i,d,s,n}}) where {i,d,s,n} = d # convenience functions const PrimalFormT{i,s,n} = FormT{i,false,s,n} @@ -45,14 +45,14 @@ export PrimalVFT const DualVFT{s,n} = VFieldT{true,s,n} export DualVFT -# ## -# -# ## +""" +converts ThDEC Sorts into DECType +""" +function Sort end Sort(::Type{<:Real}) = Scalar() Sort(::Real) = Scalar() -# convert Real to DecType function Sort(::Type{FormT{i,d,s,n}}) where {i,d,s,n} Form(i, d, Space(s, n)) end @@ -63,7 +63,9 @@ end Sort(::BasicSymbolic{T}) where {T} = Sort(T) -# convert number to real +""" +converts ThDEC Sorts into DecaSymbolic types +""" Number(s::Scalar) = Real Number(f::Form) = FormT{dim(f),isdual(f), nameof(space(f)), dim(space(f))} @@ -165,8 +167,6 @@ decapodes.Term(x::Real) = decapodes.Lit(Symbol(x)) function decapodes.DecaExpr(d::DecaSymbolic) context = map(d.vars) do var - # TODO changed :I to :X to make tests pass, but discussion - # needed on handling spaces decapodes.Judgement(nameof(var), nameof(Sort(var)), :X) end equations = map(d.equations) do eq @@ -206,10 +206,10 @@ function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,Sort}, t::decapodes.Te end end -function DecaSymbolic(d::decapodes.DecaExpr) +function DecaSymbolic(lookup::SpaceLookup, d::decapodes.DecaExpr) # associates each var to its sort... context = map(d.context) do j - j.var => ThDEC.SORT_LOOKUP[j.dim] + j.var => ThDEC.fromexpr(lookup, j.dim, Sort) end # ... which we then produce a vector of symbolic vars vars = map(context) do (v, s) diff --git a/test/decasymbolic.jl b/test/decasymbolic.jl index 0e0f2b6..20cf7ac 100644 --- a/test/decasymbolic.jl +++ b/test/decasymbolic.jl @@ -1,33 +1,32 @@ using Test -# + using DiagrammaticEquations using DiagrammaticEquations.ThDEC using DiagrammaticEquations.decapodes -# + using SymbolicUtils -@testset "ThDEC Signature checking" begin - @test Scalar() + Scalar() == Scalar() -end +# what space are we working in? +X = Space(:X, 2) + +lookup = SpaceLookup(X, Dict(:X => X)) # load up some variable variables and expressions a, b = @syms a::Real b::Real -u, v = @syms u::PrimalFormT{0} du::PrimalFormT{1} -ω, η = @syms ω::PrimalFormT{1} η::DualFormT{2} -ϕ, ψ = @syms ϕ::PrimalVFT ψ::DualVFT - -expr_scalar_addition = a + b -expr_primal_wedge = ThDEC.:∧(ω, du) +u, v = @syms u::PrimalFormT{0, :X, 2} du::PrimalFormT{1, :X, 2} +ω, η = @syms ω::PrimalFormT{1, :X, 2} η::DualFormT{2, :X, 2} +ϕ, ψ = @syms ϕ::PrimalVFT{:X, 2} ψ::DualVFT{:X, 2} +# TODO would be nice to pass the space globally to avoid duplication @testset "Term Construction" begin # test conversion to underlying type @test Sort(a) == Scalar() - @test Sort(u) == PrimalForm(0) - @test Sort(ω) == PrimalForm(1) - @test Sort(η) == DualForm(2) - @test Sort(ϕ) == PrimalVF() - @test Sort(ψ) == DualVF() + @test Sort(u) == PrimalForm(0, X) + @test Sort(ω) == PrimalForm(1, X) + @test Sort(η) == DualForm(2, X) + @test Sort(ϕ) == PrimalVF(X) + @test Sort(ψ) == DualVF(X) @test_throws ThDEC.SortError ThDEC.♯(u) @@ -48,38 +47,30 @@ expr_primal_wedge = ThDEC.:∧(ω, du) end -@testset "Moving between DecaExpr and DecaSymbolic" begin end - -context = Dict(:a => Scalar(), :b => Scalar() - ,:u => PrimalForm(0), :du => PrimalForm(1)) +@testset "Moving between DecaExpr and DecaSymbolic" begin -js = [Judgement(:u, :Form0, :X) - ,Judgement(:∂ₜu, :Form0, :X) - ,Judgement(:Δu, :Form0, :X)] -eqs = [Eq(Var(:∂ₜu), AppCirc1([:⋆₂⁻¹, :d₁, :⋆₁, :d₀], Var(:u))) - ,Eq(Tan(Var(:u)), Var(:∂ₜu))] -heat_eq = DecaExpr(js, eqs) + context = Dict(:a => Scalar() + , :b => Scalar() + , :u => PrimalForm(0, X) + , :du => PrimalForm(1, X)) + js = [Judgement(:u, :Form0, :X) + ,Judgement(:∂ₜu, :Form0, :X) + ,Judgement(:Δu, :Form0, :X)] + eqs = [Eq(Var(:∂ₜu) + , AppCirc1([:⋆₂⁻¹, :d₁, :⋆₁, :d₀], Var(:u))) + , Eq(Tan(Var(:u)), Var(:∂ₜu))] + heat_eq = DecaExpr(js, eqs) -symb_heat_eq = DecaSymbolic(heat_eq) -deca_expr = DecaExpr(symb_heat_eq) -@test js == deca_expr.context + symb_heat_eq = DecaSymbolic(lookup, heat_eq) + deca_expr = DecaExpr(symb_heat_eq) -# eqs in the left has AppCirc1[vector, term] -# deca_expr.equations on the right has nested App1 -# expected behavior is that nested AppCirc1 is preserved -@test_broken eqs == deca_expr.equations + @test js == deca_expr.context -# copied from test/language - js = [Judgement(:C, :Form0, :X), - Judgement(:Ċ₁, :Form0, :X), - Judgement(:Ċ₂, :Form0, :X) - ] - # TODO: Do we need to handle the fact that all the functions are parameterized by a space? - eqs = [Eq(Var(:Ċ₁), AppCirc1([:⋆₀⁻¹, :dual_d₁, :⋆₁, :k, :d₀], Var(:C))), - Eq(Var(:Ċ₂), AppCirc1([:⋆₀⁻¹, :dual_d₁, :⋆₁, :d₀], Var(:C))), - Eq(Tan(Var(:C)), Plus([Var(:Ċ₁), Var(:Ċ₂)])) - ] - diffusion_d = DecaExpr(js, eqs) + # eqs in the left has AppCirc1[vector, term] + # deca_expr.equations on the right has nested App1 + # expected behavior is that nested AppCirc1 is preserved + @test_broken eqs == deca_expr.equations +end From 18a1585fabf2ba9a6435705f1ef43613e9ba36b0 Mon Sep 17 00:00:00 2001 From: James Fairbanks Date: Wed, 21 Aug 2024 12:21:21 -0400 Subject: [PATCH 07/30] review changes: 1. add show methods for DecaEquation and DecaSymbolic 2. remove spurious exports in code gen 3. fix binary form constructors from pre-space code 4. remove aqua export check because of code gen 5. add Klausmeier tests --- src/ThDEC.jl | 9 +++++++-- src/decasymbolic.jl | 20 +++++++++++++++++--- test/aqua.jl | 2 +- test/klausmeier.jl | 42 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 67 insertions(+), 6 deletions(-) create mode 100644 test/klausmeier.jl diff --git a/src/ThDEC.jl b/src/ThDEC.jl index d3e1c3b..8c89de0 100644 --- a/src/ThDEC.jl +++ b/src/ThDEC.jl @@ -102,7 +102,7 @@ function +(s1::Sort, s2::Sort) (Form(i, isdual, space), Scalar()) => Form(i, isdual, space) (Form(i1, isdual1, space1), Form(i2, isdual2, space2)) => if (i1 == i2) && (isdual1 == isdual2) && (space1 == space2) - Form(i1, isdual1) + Form(i1, isdual1, space1) else throw(SortError( """ @@ -128,7 +128,7 @@ function *(s1::Sort, s2::Sort) @match (s1, s2) begin (Scalar(), Scalar()) => Scalar() (Scalar(), Form(i, isdual, space)) || - (Form(i, isdual, space), Scalar()) => Form(i, isdual) + (Form(i, isdual, space), Scalar()) => Form(i, isdual, space) (Form(_, _, _), Form(_, _, _)) => throw(SortError("Cannot scalar multiply a form with a form. Maybe try `∧`??")) end @@ -252,6 +252,10 @@ function Δ(s::Sort) end end +function Base.nameof(::typeof(Δ), s) + Symbol("Δ") +end + const OPERATOR_LOOKUP = Dict( :⋆₀ => ★, :⋆₁ => ★, @@ -296,6 +300,7 @@ const OPERATOR_LOOKUP = Dict( # Dual-Dual Lie Derivatives # :ℒ₁ => ℒ, + # :L => ℒ, # Dual Laplacians # :Δᵈ₀ => Δ, diff --git a/src/decasymbolic.jl b/src/decasymbolic.jl index e73ecd9..96e3e0b 100644 --- a/src/decasymbolic.jl +++ b/src/decasymbolic.jl @@ -90,6 +90,8 @@ for unop in unop_dec end binop_dec = [:+, :-, :*, :∧] +export +,-,*,∧ + for binop in binop_dec @eval begin @nospecialize @@ -100,7 +102,6 @@ for binop in binop_dec s = ThDEC.$binop(Sort(T1), Sort(T2)) SymbolicUtils.Term{Number(s)}(ThDEC.$binop, [v, w]) end - export $binop @nospecialize function ThDEC.$binop( @@ -110,7 +111,6 @@ for binop in binop_dec s = ThDEC.$binop(Sort(T1), Sort(T2)) SymbolicUtils.Term{Number(s)}(ThDEC.$binop, [v, w]) end - export $binop @nospecialize function ThDEC.$binop( @@ -120,7 +120,6 @@ for binop in binop_dec s = ThDEC.$binop(Sort(T1), Sort(T2)) SymbolicUtils.Term{Number(s)}(ThDEC.$binop, [v, w]) end - export $binop end end @@ -130,6 +129,11 @@ struct DecaEquation{E} rhs::E end export DecaEquation +Base.show(io::IO, e::DecaEquation) = begin + print(io, e.lhs) + print(io, " == ") + print(io, e.rhs) +end # a struct carry the symbolic variables and their equations struct DecaSymbolic @@ -138,6 +142,16 @@ struct DecaSymbolic end export DecaSymbolic +Base.show(io::IO, d::DecaSymbolic) = begin + println(io, "DecaSymbolic(") + println(io, " Variables: [$(join(d.vars, ", "))]") + println(io, " Equations: [") + eqns = map(d.equations) do op + " $(op)" + end + println(io, "$(join(eqns,",\n"))])") + end + # BasicSymbolic -> DecaExpr function decapodes.Term(t::SymbolicUtils.BasicSymbolic) if SymbolicUtils.issym(t) diff --git a/test/aqua.jl b/test/aqua.jl index 555f584..824a58c 100644 --- a/test/aqua.jl +++ b/test/aqua.jl @@ -1,5 +1,5 @@ using Aqua, DiagrammaticEquations @testset "Code quality (Aqua.jl)" begin # TODO: fix ambiguities - Aqua.test_all(DiagrammaticEquations, ambiguities=false) + Aqua.test_all(DiagrammaticEquations, ambiguities=false, undefined_exports=false) end diff --git a/test/klausmeier.jl b/test/klausmeier.jl new file mode 100644 index 0000000..b44d754 --- /dev/null +++ b/test/klausmeier.jl @@ -0,0 +1,42 @@ +using DiagrammaticEquations +using DiagrammaticEquations.SymbolicUtilsInterop +# See Klausmeier Equation 2.a +Hydrodynamics = @decapode begin + (n,w)::DualForm0 + dX::Form1 + (a,ν)::Constant + + ∂ₜ(w) == a - w - w * n^2 + ν * L(dX, w) +end + +# See Klausmeier Equation 2.b +Phytodynamics = @decapode begin + (n,w)::DualForm0 + m::Constant + + ∂ₜ(n) == w * n^2 - m*n + Δ(n) +end + +Hydrodynamics = parse_decapode(quote + (n,w)::DualForm0 + dX::Form1 + (a,ν)::Constant + + ∂ₜ(w) == a - w - w + ν * L(dX, w) + # ∂ₜ(w) == a - w - w * (n ∧ n) + ν * L(dX, w) +end) + +X = Space(:X, 1) +lookup = SpaceLookup(X, Dict(:X => X)) +# DecaSymbolic(lookup, Hydrodynamics) + +# See Klausmeier Equation 2.b +Phytodynamics = parse_decapode(quote + (n,w)::Form0 + m::Constant + + ∂ₜ(n) == w - m*n + Δ(n) + # ∂ₜ(n) == w * n*n - m*n + Δ(n) +end) + +DecaSymbolic(lookup, Phytodynamics) \ No newline at end of file From 090ddfeb73492adda9c3ccf3d89674c1e3c325f6 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 22 Aug 2024 17:52:25 -0400 Subject: [PATCH 08/30] resolving some comments from code review. --- src/ThDEC.jl | 13 +++++++---- src/decasymbolic.jl | 54 +++++++++++++++++++++++++++++++++++++++----- src/language.jl | 2 +- test/decasymbolic.jl | 45 ++++++++++++++++++++++-------------- 4 files changed, 85 insertions(+), 29 deletions(-) diff --git a/src/ThDEC.jl b/src/ThDEC.jl index 8c89de0..91dcbe4 100644 --- a/src/ThDEC.jl +++ b/src/ThDEC.jl @@ -24,6 +24,8 @@ struct SpaceLookup end export SpaceLookup +SpaceLookup(default::Space) = SpaceLookup(default, Dict{Symbol, Space}(nameof(default) => default)) + @data Sort begin Scalar() Form(dim::Int, isdual::Bool, space::Space) @@ -154,6 +156,7 @@ function ∧(s1::Sort, s2::Sort) throw(SortError("Can only take a wedge product when the dimensions of the forms add to less than n, where n = $(dim(space1)) is the dimension of the ambient space: tried to wedge product $i1 and $i2")) end end + (VF(isdual, _), _) || (_, VF(isdual, _)) => throw(SortError("Can only take a wedge product of forms. Flatten (♭) your vector field before applying")) end end @@ -219,7 +222,7 @@ end # TODO function Base.nameof(::typeof(♯), s) - Symbol("♯s") + Symbol("♯$s") end @@ -232,7 +235,7 @@ end # TODO function Base.nameof(::typeof(♭), s) - Symbol("♭s") + Symbol("♭$s") end # OTHER @@ -247,14 +250,14 @@ end # Δ = ★d⋆d, but we check signature here to throw a more helpful error function Δ(s::Sort) @match s begin + Scalar() => Scalar() Form(0, isdual, space) => Form(0, isdual, space) + Form(1, isdual, space) => Form(1, isdual, space) _ => throw(SortError("Δ is not defined for $s")) end end -function Base.nameof(::typeof(Δ), s) - Symbol("Δ") -end +Base.nameof(::typeof(Δ), s) = Symbol("Δ") const OPERATOR_LOOKUP = Dict( :⋆₀ => ★, diff --git a/src/decasymbolic.jl b/src/decasymbolic.jl index 96e3e0b..d3a9f86 100644 --- a/src/decasymbolic.jl +++ b/src/decasymbolic.jl @@ -1,5 +1,7 @@ module SymbolicUtilsInterop +using ..DiagrammaticEquations: AbstractDecapode +import ..DiagrammaticEquations: eval_eq!, SummationDecapode using ..ThDEC using MLStyle import ..ThDEC: Sort, dim, isdual @@ -63,6 +65,17 @@ end Sort(::BasicSymbolic{T}) where {T} = Sort(T) +# converts a sort to its Julia symbol +function to_symb(sort::Sort) + @match sort begin + Scalar() => :Constant + Form(i, isdual, X) => + Symbol("$(isdual ? "Dual" : "")Form$i") + VField(isdual, X) => + Symbol("$(isdual ? "Dual" : "")VF") + end +end + """ converts ThDEC Sorts into DecaSymbolic types """ @@ -80,17 +93,14 @@ for unop in unop_dec function ThDEC.$unop( v::BasicSymbolic{T} ) where {T<:DECType} - # convert the DECType to ThDEC to type check s = ThDEC.$unop(Sort(T)) - # the resulting type is converted back to DECType - # the resulting term has the operation has its head and `v` as its args. SymbolicUtils.Term{Number(s)}(ThDEC.$unop, [v]) end end end -binop_dec = [:+, :-, :*, :∧] -export +,-,*,∧ +binop_dec = [:+, :-, :*, :∧, :^] +export +,-,*,∧,^ for binop in binop_dec @eval begin @@ -129,6 +139,7 @@ struct DecaEquation{E} rhs::E end export DecaEquation + Base.show(io::IO, e::DecaEquation) = begin print(io, e.lhs) print(io, " == ") @@ -202,7 +213,7 @@ Example: function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,Sort}, t::decapodes.Term) @match t begin Var(name) => SymbolicUtils.Sym{Number(context[name])}(name) - Lit(v) => Meta.parse(string(v)) # TODO no YOLO + Lit(v) => Meta.parse(string(v)) # see heat_eq test: eqs had AppCirc1, but this returns # App1(f, App1(...) AppCirc1(fs, arg) => foldr( @@ -236,4 +247,35 @@ function DecaSymbolic(lookup::SpaceLookup, d::decapodes.DecaExpr) DecaSymbolic(vars, eqs) end +function eval_eq!(eq::DecaEquation, d::AbstractDecapode, syms::Dict{Symbol, Int}, deletions::Vector{Int}) + eval_eq!(Equation(Term(eq.lhs), Term(eq.rhs)), d, syms, deletions) +end + +""" function SummationDecapode(e::DecaSymbolic) """ +function SummationDecapode(e::DecaSymbolic) + d = SummationDecapode{Any, Any, Symbol}() + symbol_table = Dict{Symbol, Int}() + + foreach(e.vars) do var + # convert Sort(var)::PrimalForm0 --> :Form0 + var_id = add_part!(d, :Var, name=var.name, type=to_symb(Sort(var))) + symbol_table[var.name] = var_id + end + + deletions = Vector{Int}() + foreach(e.equations) do eq + eval_eq!(eq, d, symbol_table, deletions) + end + rem_parts!(d, :Var, sort(deletions)) + + recognize_types(d) + + fill_names!(d) + d[:name] = normalize_unicode.(d[:name]) + make_sum_mult_unique!(d) + return d +end + + + end diff --git a/src/language.jl b/src/language.jl index 5a06081..e0f7d71 100644 --- a/src/language.jl +++ b/src/language.jl @@ -51,8 +51,8 @@ function parse_decapode(expr::Expr) end DecaExpr(judges, eqns) end + # to_Decapode helper functions -### TODO - Matt: we need to generalize this reduce_term!(t::Term, d::AbstractDecapode, syms::Dict{Symbol, Int}) = let ! = reduce_term! @match t begin diff --git a/test/decasymbolic.jl b/test/decasymbolic.jl index 20cf7ac..f016f33 100644 --- a/test/decasymbolic.jl +++ b/test/decasymbolic.jl @@ -9,7 +9,7 @@ using SymbolicUtils # what space are we working in? X = Space(:X, 2) -lookup = SpaceLookup(X, Dict(:X => X)) +lookup = SpaceLookup(X) # load up some variable variables and expressions a, b = @syms a::Real b::Real @@ -47,30 +47,41 @@ u, v = @syms u::PrimalFormT{0, :X, 2} du::PrimalFormT{1, :X, 2} end -@testset "Moving between DecaExpr and DecaSymbolic" begin - - context = Dict(:a => Scalar() - , :b => Scalar() - , :u => PrimalForm(0, X) - , :du => PrimalForm(1, X)) - - js = [Judgement(:u, :Form0, :X) - ,Judgement(:∂ₜu, :Form0, :X) - ,Judgement(:Δu, :Form0, :X)] - eqs = [Eq(Var(:∂ₜu) - , AppCirc1([:⋆₂⁻¹, :d₁, :⋆₁, :d₀], Var(:u))) - , Eq(Tan(Var(:u)), Var(:∂ₜu))] - heat_eq = DecaExpr(js, eqs) +context = Dict(:a => Scalar() + ,:b => Scalar() + ,:u => PrimalForm(0, X) + ,:du => PrimalForm(1, X)) +js = [Judgement(:u, :Form0, :X) + ,Judgement(:∂ₜu, :Form0, :X) + ,Judgement(:Δu, :Form0, :X)] +eqs = [Eq(Var(:∂ₜu) + , AppCirc1([:⋆₂⁻¹, :d₁, :⋆₁, :d₀], Var(:u))) + , Eq(Tan(Var(:u)), Var(:∂ₜu))] +heat_eq = DecaExpr(js, eqs) - symb_heat_eq = DecaSymbolic(lookup, heat_eq) - deca_expr = DecaExpr(symb_heat_eq) +symb_heat_eq = DecaSymbolic(lookup, heat_eq) +deca_expr = DecaExpr(symb_heat_eq) +@testset "Moving between DecaExpr and DecaSymbolic" begin + @test js == deca_expr.context # eqs in the left has AppCirc1[vector, term] # deca_expr.equations on the right has nested App1 # expected behavior is that nested AppCirc1 is preserved @test_broken eqs == deca_expr.equations + # use expand_operators to get rid of parentheses + # infer_types and resolve_overloads + +end + +# convert both into ACSets then is_iso them +@testset "" begin + + Σ = DiagrammaticEquations.SummationDecapode(deca_expr) + Δ = DiagrammaticEquations.SummationDecapode(symb_heat_eq) + @test Σ == Δ + end From d8be4ae31b323a66e7fdc91d8ea777de45a25176 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 27 Aug 2024 11:33:53 -0400 Subject: [PATCH 09/30] adding @alias and @register macros to make DecaSymbolic function work in test/klausmeier --- Project.toml | 3 +- src/ThDEC.jl | 129 +++++++++++++------------------------------ src/decasymbolic.jl | 97 +++++++++++++++++++++++++------- test/decasymbolic.jl | 29 +++++----- test/klausmeier.jl | 30 ++++++++-- 5 files changed, 156 insertions(+), 132 deletions(-) diff --git a/Project.toml b/Project.toml index c07a9a0..19239ef 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ Catlab = "134e5e36-593f-5add-ad60-77f754baafbe" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" StructEquality = "6ec83bb0-ed9f-11e9-3b4c-2b04cb4e219c" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" @@ -19,8 +20,8 @@ ACSets = "0.2" Catlab = "0.15, 0.16" DataStructures = "0.18.13" MLStyle = "0.4.17" +Reexport = "1.2.2" StructEquality = "2.1.0" SymbolicUtils = "3.1.2" Unicode = "1.6" -Reexport = "1.2.2" julia = "1.6" diff --git a/src/ThDEC.jl b/src/ThDEC.jl index 91dcbe4..6cb25e0 100644 --- a/src/ThDEC.jl +++ b/src/ThDEC.jl @@ -5,6 +5,31 @@ using StructEquality import Base: +, -, * +""" +Given a tuple of symbols ("aliases") and their canonical name (or "rep"), produces +for each alias typechecking and nameof methods which call those for their rep. +Example: +@alias d₀, d₁, += d +""" +macro alias(body) + (rep, aliases) = @match body begin + Expr(:tuple, rep, Expr(:tuple, aliases...)) => (rep, aliases) + _ => nothing + end + result = quote end + foreach(aliases) do alias + push!(result.args, + quote + function $(esc(alias))(s...) + $(esc(rep))(s...) + end + export $(esc(alias)) + Base.nameof(::typeof($alias), s) = nameof($rep, s) + end) + end + result +end + struct SortError <: Exception message::String end @@ -24,7 +49,9 @@ struct SpaceLookup end export SpaceLookup -SpaceLookup(default::Space) = SpaceLookup(default, Dict{Symbol, Space}(nameof(default) => default)) +function SpaceLookup(default::Space) + SpaceLookup(default, Dict{Symbol, Space}(nameof(default) => default)) +end @data Sort begin Scalar() @@ -138,9 +165,7 @@ end const SUBSCRIPT_DIGIT_0 = '₀' -function as_sub(n::Int) - join(map(d -> SUBSCRIPT_DIGIT_0 + d, digits(n))) -end +as_sub(n::Int) = join(map(d -> SUBSCRIPT_DIGIT_0 + d, digits(n))) # TODO: VField @nospecialize @@ -180,20 +205,23 @@ function d(s::Sort) end end -function Base.nameof(::typeof(d), s) - Symbol("d$(as_sub(dim(s)))") -end +@alias d, (d₀, d₁) + +Base.nameof(::typeof(d), s) = Symbol("d$(as_sub(dim(s)))") @nospecialize -function ★(s::Sort) +function ⋆(s::Sort) @match s begin Scalar() => throw(SortError("Cannot take Hodge star of a scalar")) VF(isdual, space) => throw(SortError("Cannot take the Hodge star of a vector field")) Form(i, isdual, space) => Form(dim(space) - i, !isdual, space) end end +export ⋆ + +@alias ⋆, (⋆₀, ⋆₁, ⋆₂, ⋆₀⁻¹, ⋆₁⁻¹, ⋆₂⁻¹) -function Base.nameof(::typeof(★), s) +function Base.nameof(::typeof(⋆), s) inv = isdual(s) ? "⁻¹" : "" Symbol("★$(as_sub(isdual(s) ? dim(space(s)) - dim(s) : dim(s)))$(inv)") end @@ -220,11 +248,7 @@ function ♯(s::Sort) end # musical isos may be defined for any combination of (primal/dual) form -> (primal/dual) vf. -# TODO -function Base.nameof(::typeof(♯), s) - Symbol("♯$s") -end - +Base.nameof(::typeof(♯), s) = Symbol("♯$s") function ♭(s::Sort) @match s begin @@ -233,10 +257,7 @@ function ♭(s::Sort) end end -# TODO -function Base.nameof(::typeof(♭), s) - Symbol("♭$s") -end +Base.nameof(::typeof(♭), s) = Symbol("♭$s") # OTHER @@ -259,76 +280,4 @@ end Base.nameof(::typeof(Δ), s) = Symbol("Δ") -const OPERATOR_LOOKUP = Dict( - :⋆₀ => ★, - :⋆₁ => ★, - :⋆₂ => ★, - - # Inverse Hodge Stars - :⋆₀⁻¹ => ★, - :⋆₁⁻¹ => ★, - :⋆₂⁻¹ => ★, - - # Differentials - :d₀ => d, - :d₁ => d, - - # Dual Differentials - :dual_d₀ => d, - :d̃₀ => d, - :dual_d₁ => d, - :d̃₁ => d, - - # Wedge Products - :∧₀₁ => ∧, - :∧₁₀ => ∧, - :∧₀₂ => ∧, - :∧₂₀ => ∧, - :∧₁₁ => ∧, - - # Primal-Dual Wedge Products - :∧ᵖᵈ₁₁ => ∧, - :∧ᵖᵈ₀₁ => ∧, - :∧ᵈᵖ₁₁ => ∧, - :∧ᵈᵖ₁₀ => ∧, - - # Dual-Dual Wedge Products - :∧ᵈᵈ₁₁ => ∧, - :∧ᵈᵈ₁₀ => ∧, - :∧ᵈᵈ₀₁ => ∧, - - # Dual-Dual Interior Products - :ι₁₁ => ι, - :ι₁₂ => ι, - - # Dual-Dual Lie Derivatives - # :ℒ₁ => ℒ, - # :L => ℒ, - - # Dual Laplacians - # :Δᵈ₀ => Δ, - # :Δᵈ₁ => Δ, - - # Musical Isomorphisms - :♯ => ♯, - :♯ᵈ => ♯, :♭ => ♭, - - # Averaging Operator - # :avg₀₁ => avg, - - # Negatives - :neg => -, - - # Basics - - :- => -, - :+ => +, - :* => *, - :/ => /, - :.- => .-, - :.+ => .+, - :.* => .*, - :./ => ./, -) - end diff --git a/src/decasymbolic.jl b/src/decasymbolic.jl index d3a9f86..8933ae4 100644 --- a/src/decasymbolic.jl +++ b/src/decasymbolic.jl @@ -3,12 +3,12 @@ module SymbolicUtilsInterop using ..DiagrammaticEquations: AbstractDecapode import ..DiagrammaticEquations: eval_eq!, SummationDecapode using ..ThDEC -using MLStyle import ..ThDEC: Sort, dim, isdual using ..decapodes -using SymbolicUtils -using SymbolicUtils: Symbolic, BasicSymbolic +using MLStyle +using SymbolicUtils +using SymbolicUtils: Symbolic, BasicSymbolic, FnType, Sym # ########################## # DECType @@ -65,17 +65,6 @@ end Sort(::BasicSymbolic{T}) where {T} = Sort(T) -# converts a sort to its Julia symbol -function to_symb(sort::Sort) - @match sort begin - Scalar() => :Constant - Form(i, isdual, X) => - Symbol("$(isdual ? "Dual" : "")Form$i") - VField(isdual, X) => - Symbol("$(isdual ? "Dual" : "")VF") - end -end - """ converts ThDEC Sorts into DecaSymbolic types """ @@ -85,8 +74,13 @@ Number(f::Form) = FormT{dim(f),isdual(f), nameof(space(f)), dim(space(f))} Number(v::VField) = VFieldT{isdual(v), nameof(space(v)), dim(space(v))} +# HERE WE DEFINE THE SYMBOLICUTILS + # for every unary operator in our theory, take a BasicSymbolic type, convert its type parameter to a Sort in our theory, and return a term -unop_dec = [:∂ₜ, :d, :★, :♯, :♭, :-] +unop_dec = [:∂ₜ, :d, :d₀, :d₁ + , :⋆, :⋆₀, :⋆₁, :⋆₂, :⋆₀⁻¹, :⋆₁⁻¹, :⋆₂⁻¹ + , :♯, :♭, :-] + for unop in unop_dec @eval begin @nospecialize @@ -99,6 +93,8 @@ for unop in unop_dec end end +# BasicSymbolic{FnType{Tuple{PrimalFormT{0}}}, PrimalFormT{0}} + binop_dec = [:+, :-, :*, :∧, :^] export +,-,*,∧,^ @@ -211,6 +207,8 @@ Example: ``` """ function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,Sort}, t::decapodes.Term) + # user must import symbols into scope + ! = (f -> getfield(Main, f)) @match t begin Var(name) => SymbolicUtils.Sym{Number(context[name])}(name) Lit(v) => Meta.parse(string(v)) @@ -219,12 +217,13 @@ function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,Sort}, t::decapodes.Te AppCirc1(fs, arg) => foldr( # panics with constants like :k # see test/language.jl - (f, x) -> ThDEC.OPERATOR_LOOKUP[f](x), + (f, x) -> (!(f))(x), fs; init=BasicSymbolic(context, arg) ) - App1(f, x) => ThDEC.OPERATOR_LOOKUP[f](BasicSymbolic(context, x)) - App2(f, x, y) => ThDEC.OPERATOR_LOOKUP[f](BasicSymbolic(context, x), BasicSymbolic(context, y)) + # getfield(Main, + App1(f, x) => (!(f))(BasicSymbolic(context, x)) + App2(f, x, y) => (!(f))(BasicSymbolic(context, x), BasicSymbolic(context, y)) Plus(xs) => +(BasicSymbolic.(Ref(context), xs)...) Mult(xs) => *(BasicSymbolic.(Ref(context), xs)...) Tan(x) => ThDEC.∂ₜ(BasicSymbolic(context, x)) @@ -258,7 +257,7 @@ function SummationDecapode(e::DecaSymbolic) foreach(e.vars) do var # convert Sort(var)::PrimalForm0 --> :Form0 - var_id = add_part!(d, :Var, name=var.name, type=to_symb(Sort(var))) + var_id = add_part!(d, :Var, name=var.name, type=nameof(Sort(var))) symbol_table[var.name] = var_id end @@ -276,6 +275,66 @@ function SummationDecapode(e::DecaSymbolic) return d end +""" +Registers a new function + +``` +@register Δ(s::Sort) begin + @match s begin + ::Scalar => error("Invalid") + ::VField => error("Invalid") + ::Form => ⋆(d(⋆(d(s)))) + end +end +``` +will create an additional method for Δ for operating on BasicSymbolic +""" +macro register(head, body) + # parse head + parsehead = @λ begin + Expr(:call, f, types...) => (f, parsehead.(types)) + Expr(:(::), var, type) => (var, type) + s => s + end + (f, args) = parsehead(head) + matchargs = [:($(x[1])::$(x[2])) for x in args] + + result = quote end + push!(result.args, + esc(quote + function $f($(matchargs...)) + $body + end + end)) + + # e.g., given [(:x, :Scalar), (:ω, :Form)]... + vs = enumerate(unique(getindex.(args, 2))) + theargs = + Dict{Symbol,Symbol}( + [v => Symbol("T$k") for (k,v) in vs] + ) + # ...[(Scalar=>:T1, :Form=>:T2)] + + # reassociate vars with their BasicSymbolic Generic Types + binding = map(args) do (var, type) + (var, :(BasicSymbolic{$(theargs[type])})) + end + newargs = [:($(x[1])::$(x[2])) for x in binding] + constraints = [:($T<:DECType) for T in values(theargs)] + innerargs = [:(Sort($T)) for T in values(theargs)] + + push!(result.args, + quote + @nospecialize + function $(esc(f))($(newargs...)) where $(constraints...) + s = $(esc(f))($(innerargs...)) + SymbolicUtils.Term{Number(s)}($(esc(f)), [$(getindex.(binding, 1)...)]) + end + end) + + return result +end +export @register end diff --git a/test/decasymbolic.jl b/test/decasymbolic.jl index f016f33..fbbfc15 100644 --- a/test/decasymbolic.jl +++ b/test/decasymbolic.jl @@ -1,14 +1,10 @@ using Test - using DiagrammaticEquations using DiagrammaticEquations.ThDEC using DiagrammaticEquations.decapodes - using SymbolicUtils - # what space are we working in? X = Space(:X, 2) - lookup = SpaceLookup(X) # load up some variable variables and expressions @@ -34,11 +30,11 @@ u, v = @syms u::PrimalFormT{0, :X, 2} du::PrimalFormT{1, :X, 2} @test Term(1) == DiagrammaticEquations.decapodes.Lit(Symbol("1")) @test Term(a) == Var(:a) @test Term(ThDEC.∂ₜ(u)) == Tan(Var(:u)) - @test Term(ThDEC.★(ω)) == App1(:★₁, Var(:ω)) - @test Term(ThDEC.♭(ψ)) == App1(:♭s, Var(:ψ)) + @test Term(ThDEC.⋆(ω)) == App1(:⋆₁, Var(:ω)) + @test_broken Term(ThDEC.♭(ψ)) == App1(:♭s, Var(:ψ)) # @test Term(DiagrammaticEquations.ThDEC.♯(du)) - @test_throws ThDEC.SortError ThDEC.★(ϕ) + @test_throws ThDEC.SortError ThDEC.⋆(ϕ) # test binary operator conversion to decaexpr @test Term(a + b) == Plus(Term[Var(:a), Var(:b)]) @@ -47,21 +43,22 @@ u, v = @syms u::PrimalFormT{0, :X, 2} du::PrimalFormT{1, :X, 2} end -context = Dict(:a => Scalar() - ,:b => Scalar() - ,:u => PrimalForm(0, X) - ,:du => PrimalForm(1, X)) +@testset "Conversion" begin -js = [Judgement(:u, :Form0, :X) + context = Dict(:a => Scalar(),:b => Scalar() + ,:u => PrimalForm(0, X),:du => PrimalForm(1, X)) + js = [Judgement(:u, :Form0, :X) ,Judgement(:∂ₜu, :Form0, :X) ,Judgement(:Δu, :Form0, :X)] -eqs = [Eq(Var(:∂ₜu) + eqs = [Eq(Var(:∂ₜu) , AppCirc1([:⋆₂⁻¹, :d₁, :⋆₁, :d₀], Var(:u))) , Eq(Tan(Var(:u)), Var(:∂ₜu))] -heat_eq = DecaExpr(js, eqs) + heat_eq = DecaExpr(js, eqs) + + symb_heat_eq = DecaSymbolic(lookup, heat_eq) + deca_expr = DecaExpr(symb_heat_eq) -symb_heat_eq = DecaSymbolic(lookup, heat_eq) -deca_expr = DecaExpr(symb_heat_eq) +end @testset "Moving between DecaExpr and DecaSymbolic" begin diff --git a/test/klausmeier.jl b/test/klausmeier.jl index b44d754..9c88e46 100644 --- a/test/klausmeier.jl +++ b/test/klausmeier.jl @@ -1,5 +1,11 @@ using DiagrammaticEquations using DiagrammaticEquations.SymbolicUtilsInterop + +using Test +using MLStyle +using SymbolicUtils +using SymbolicUtils: BasicSymbolic + # See Klausmeier Equation 2.a Hydrodynamics = @decapode begin (n,w)::DualForm0 @@ -23,20 +29,32 @@ Hydrodynamics = parse_decapode(quote (a,ν)::Constant ∂ₜ(w) == a - w - w + ν * L(dX, w) - # ∂ₜ(w) == a - w - w * (n ∧ n) + ν * L(dX, w) end) -X = Space(:X, 1) -lookup = SpaceLookup(X, Dict(:X => X)) +X = Space(:X, 2) +lookup = SpaceLookup(X) # DecaSymbolic(lookup, Hydrodynamics) # See Klausmeier Equation 2.b Phytodynamics = parse_decapode(quote (n,w)::Form0 m::Constant - ∂ₜ(n) == w - m*n + Δ(n) - # ∂ₜ(n) == w * n*n - m*n + Δ(n) end) -DecaSymbolic(lookup, Phytodynamics) \ No newline at end of file +import .ThDEC: d, ⋆, SortError + +@register Δ(s::Sort) begin + @match s begin + ::Scalar => throw(SortError("Scalar")) + ::VField => throw(SortError("Nay!")) + ::Form => ⋆(d(⋆(d(s)))) + end +end + +ω, = @syms ω::PrimalFormT{1, :X, 2} + +@test Δ(PrimalForm(1, X)) == PrimalForm(1, X) +@test Δ(ω) |> typeof == BasicSymbolic{PrimalFormT{1, :X, 2}} + +DecaSymbolic(lookup, Phytodynamics) From 77770e5b17a625fbe03cb3b58aff307a3d697cca Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 27 Aug 2024 21:12:18 -0400 Subject: [PATCH 10/30] experimenting with a type-driven approach --- src/DiagrammaticEquations.jl | 7 +- src/SymbolicUtilsInterop.jl | 154 ++++++++++++++++ src/ThDEC.jl | 283 ----------------------------- src/deca/Deca.jl | 5 + src/deca/ThDEC.jl | 229 +++++++++++++++++++++++ src/decasymbolic.jl | 340 ----------------------------------- src/symbolictheoryutils.jl | 161 +++++++++++++++++ test/decasymbolic.jl | 48 ++--- test/klausmeier.jl | 15 +- 9 files changed, 587 insertions(+), 655 deletions(-) create mode 100644 src/SymbolicUtilsInterop.jl delete mode 100644 src/ThDEC.jl create mode 100644 src/deca/ThDEC.jl delete mode 100644 src/decasymbolic.jl create mode 100644 src/symbolictheoryutils.jl diff --git a/src/DiagrammaticEquations.jl b/src/DiagrammaticEquations.jl index d0ef187..be9cace 100644 --- a/src/DiagrammaticEquations.jl +++ b/src/DiagrammaticEquations.jl @@ -61,13 +61,12 @@ include("rewrite.jl") include("pretty.jl") include("colanguage.jl") include("openoperators.jl") +include("symbolictheoryutils.jl") include("deca/Deca.jl") include("learn/Learn.jl") -include("ThDEC.jl") -include("decasymbolic.jl") +include("SymbolicUtilsInterop.jl") -@reexport using .ThDEC -@reexport using .SymbolicUtilsInterop @reexport using .Deca +@reexport using .SymbolicUtilsInterop end diff --git a/src/SymbolicUtilsInterop.jl b/src/SymbolicUtilsInterop.jl new file mode 100644 index 0000000..c3c9b5a --- /dev/null +++ b/src/SymbolicUtilsInterop.jl @@ -0,0 +1,154 @@ +module SymbolicUtilsInterop + +using ..DiagrammaticEquations: AbstractDecapode, Quantity +import ..DiagrammaticEquations: eval_eq!, SummationDecapode +using ..decapodes +using ..Deca + +using MLStyle +using SymbolicUtils +using SymbolicUtils: Symbolic, BasicSymbolic, FnType, Sym + +# name collision with decapodes.Equation +struct SymbolicEquation{E} + lhs::E + rhs::E +end + +Base.show(io::IO, e::SymbolicEquation) = begin + print(io, e.lhs); print(io, " == "); print(io, e.rhs) +end + +## a struct carry the symbolic variables and their equations +struct SymbolicContext + vars::Vector{Symbolic} + equations::Vector{SymbolicEquation{Symbolic}} +end +export SymbolicContext + +Base.show(io::IO, d::SymbolicContext) = begin + println(io, "SymbolicContext(") + println(io, " Variables: [$(join(d.vars, ", "))]") + println(io, " Equations: [") + eqns = map(d.equations) do op + " $(op)" + end + println(io, "$(join(eqns,",\n"))])") + end + +## BasicSymbolic -> DecaExpr +function decapodes.Term(t::SymbolicUtils.BasicSymbolic) + if SymbolicUtils.issym(t) + decapodes.Var(nameof(t)) + else + op = SymbolicUtils.head(t) + args = SymbolicUtils.arguments(t) + termargs = Term.(args) + if op == + + decapodes.Plus(termargs) + elseif op == * + decapodes.Mult(termargs) + elseif op == ∂ₜ + decapodes.Tan(only(termargs)) + elseif length(args) == 1 + decapodes.App1(nameof(op, args...), termargs...) + elseif length(args) == 2 + decapodes.App2(nameof(op, args...), termargs...) + else + error("was unable to convert $t into a Term") + end + end +end + +decapodes.Term(x::Real) = decapodes.Lit(Symbol(x)) + +function decapodes.DecaExpr(d::SymbolicContext) + context = map(d.vars) do var + decapodes.Judgement(nameof(var), nameof(Sort(var)), :X) + end + equations = map(d.equations) do eq + decapodes.Eq(decapodes.Term(eq.lhs), decapodes.Term(eq.rhs)) + end + decapodes.DecaExpr(context, equations) +end + +""" +Retrieve the SymbolicUtils expression of a DecaExpr term `t` from a context of variables in ThDEC + +Example: +``` + a = @syms a::Real + context = Dict(:a => Scalar(), :u => PrimalForm(0)) + SymbolicUtils.BasicSymbolic(context, Term(a)) +``` +""" +function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,Quantity}, t::decapodes.Term, __module__=@__MODULE__) + # user must import symbols into scope + ! = (f -> getfield(__module__, f)) + @match t begin + Var(name) => SymbolicUtils.Sym{context[name]}(name) + Lit(v) => Meta.parse(string(v)) + # see heat_eq test: eqs had AppCirc1, but this returns + # App1(f, App1(...) + AppCirc1(fs, arg) => foldr( + # panics with constants like :k + # see test/language.jl + (f, x) -> (!(f))(x), + fs; + init=BasicSymbolic(context, arg, __module__) + ) + App1(f, x) => (!(f))(BasicSymbolic(context, x, __module__)) + App2(f, x, y) => (!(f))(BasicSymbolic(context, x, __module__), BasicSymbolic(context, y, __module__)) + Plus(xs) => +(BasicSymbolic.(Ref(context), xs, Ref(__module__))...) + Mult(xs) => *(BasicSymbolic.(Ref(context), xs, Ref(__module__))...) + Tan(x) => ∂ₜ(BasicSymbolic(context, x, __module__)) + end +end + +function SymbolicContext(d::decapodes.DecaExpr, __module__=@__MODULE__) + # associates each var to its sort... + context = map(d.context) do j + @info j.var + j.var => j.var + end + # ... which we then produce a vector of symbolic vars + vars = map(context) do (v, s) + SymbolicUtils.Sym{s}(v) + end + context = Dict{Symbol,Quantity}(context) + eqs = map(d.equations) do eq + SymbolicEquation{Symbolic}(BasicSymbolic.(Ref(context), [eq.lhs, eq.rhs], Ref(__module__))...) + end + SymbolicContext(vars, eqs) +end + +function eval_eq!(eq::SymbolicEquation, d::AbstractDecapode, syms::Dict{Symbol, Int}, deletions::Vector{Int}) + eval_eq!(Equation(Term(eq.lhs), Term(eq.rhs)), d, syms, deletions) +end + +""" function SummationDecapode(e::SymbolicContext) """ +function SummationDecapode(e::SymbolicContext) + d = SummationDecapode{Any, Any, Symbol}() + symbol_table = Dict{Symbol, Int}() + + foreach(e.vars) do var + # convert Sort(var)::PrimalForm0 --> :Form0 + var_id = add_part!(d, :Var, name=var.name, type=nameof(Sort(var))) + symbol_table[var.name] = var_id + end + + deletions = Vector{Int}() + foreach(e.equations) do eq + eval_eq!(eq, d, symbol_table, deletions) + end + rem_parts!(d, :Var, sort(deletions)) + + recognize_types(d) + + fill_names!(d) + d[:name] = normalize_unicode.(d[:name]) + make_sum_mult_unique!(d) + return d +end + +end diff --git a/src/ThDEC.jl b/src/ThDEC.jl deleted file mode 100644 index 6cb25e0..0000000 --- a/src/ThDEC.jl +++ /dev/null @@ -1,283 +0,0 @@ -module ThDEC - -using MLStyle -using StructEquality - -import Base: +, -, * - -""" -Given a tuple of symbols ("aliases") and their canonical name (or "rep"), produces -for each alias typechecking and nameof methods which call those for their rep. -Example: -@alias d₀, d₁, += d -""" -macro alias(body) - (rep, aliases) = @match body begin - Expr(:tuple, rep, Expr(:tuple, aliases...)) => (rep, aliases) - _ => nothing - end - result = quote end - foreach(aliases) do alias - push!(result.args, - quote - function $(esc(alias))(s...) - $(esc(rep))(s...) - end - export $(esc(alias)) - Base.nameof(::typeof($alias), s) = nameof($rep, s) - end) - end - result -end - -struct SortError <: Exception - message::String -end - -@struct_hash_equal struct Space - name::Symbol - dim::Int -end -export Space - -dim(s::Space) = s.dim -Base.nameof(s::Space) = s.name - -struct SpaceLookup - default::Space - named::Dict{Symbol,Space} -end -export SpaceLookup - -function SpaceLookup(default::Space) - SpaceLookup(default, Dict{Symbol, Space}(nameof(default) => default)) -end - -@data Sort begin - Scalar() - Form(dim::Int, isdual::Bool, space::Space) - VField(isdual::Bool, space::Space) -end -export Sort, Scalar, Form, VField - -function fromexpr(lookup::SpaceLookup, e, ::Type{Sort}) - (name, spacename) = @match e begin - name::Symbol => (name, nothing) - :($name{$spacename}) => (name, spacename) - end - space = @match spacename begin - ::Nothing => lookup.default - name::Symbol => lookup.named[name] - end - @match name begin - :Form0 => Form(0, false, space) - :Form1 => Form(1, false, space) - :Form2 => Form(2, false, space) - :DualForm0 => Form(0, true, space) - :DualForm1 => Form(1, true, space) - :DualForm2 => Form(2, true, space) - :Constant => Scalar() - end -end - -Base.nameof(s::Scalar) = :Constant - -function Base.nameof(f::Form; with_dim_parameter=false) - dual = isdual(f) ? "Dual" : "" - formname = Symbol("$(dual)Form$(dim(f))") - if with_dim_parameter - return Expr(:curly, formname, dim(space(f))) - else - return formname - end -end - -const VF = VField - -dim(ω::Form) = ω.dim -isdual(ω::Form) = ω.isdual -space(ω::Form) = ω.space -export space - -isdual(v::VField) = v.isdual -space(v::VField) = v.space - -# convenience functions -PrimalForm(i::Int, space::Space) = Form(i, false, space) -export PrimalForm - -DualForm(i::Int, space::Space) = Form(i, true, space) -export DualForm - -PrimalVF(space::Space) = VF(false, space) -export PrimalVF - -DualVF(space::Space) = VF(true, space) -export DualVF - -# show methods -show_duality(ω::Form) = isdual(ω) ? "dual" : "primal" - -function Base.show(io::IO, ω::Form) - print(io, isdual(ω) ? "DualForm($(dim(ω))) on $(space(ω))" : "PrimalForm($(dim(ω))) on $(space(ω))") -end - -# TODO: VField -@nospecialize -function +(s1::Sort, s2::Sort) - @match (s1, s2) begin - (Scalar(), Scalar()) => Scalar() - (Scalar(), Form(i, isdual, space)) || - (Form(i, isdual, space), Scalar()) => Form(i, isdual, space) - (Form(i1, isdual1, space1), Form(i2, isdual2, space2)) => - if (i1 == i2) && (isdual1 == isdual2) && (space1 == space2) - Form(i1, isdual1, space1) - else - throw(SortError( - """ - Can not add two forms of different dimensions/dualities/spaces: - $((i1,isdual1,space1)) and $((i2,isdual2,space2)) - """) - ) - end - end -end - -# Type-checking inverse of addition follows addition --(s1::Sort, s2::Sort) = +(s1, s2) - -# TODO error for Forms - -# Negation is always valid --(s::Sort) = s - -# TODO: VField -@nospecialize -function *(s1::Sort, s2::Sort) - @match (s1, s2) begin - (Scalar(), Scalar()) => Scalar() - (Scalar(), Form(i, isdual, space)) || - (Form(i, isdual, space), Scalar()) => Form(i, isdual, space) - (Form(_, _, _), Form(_, _, _)) => - throw(SortError("Cannot scalar multiply a form with a form. Maybe try `∧`??")) - end -end - -const SUBSCRIPT_DIGIT_0 = '₀' - -as_sub(n::Int) = join(map(d -> SUBSCRIPT_DIGIT_0 + d, digits(n))) - -# TODO: VField -@nospecialize -function ∧(s1::Sort, s2::Sort) - @match (s1, s2) begin - (Form(i, isdual, space), Scalar()) || (Scalar(), Form(i, isdual, space)) => - Form(i, isdual, space) - (Form(i1, isdual1, space1), Form(i2, isdual2, space2)) => begin - (isdual1 == isdual2) && (space1 == space2) || throw(SortError("Can only take a wedge product of two forms of the same duality on the same space")) - if i1 + i2 <= dim(space1) - Form(i1 + i2, isdual1, space1) - else - throw(SortError("Can only take a wedge product when the dimensions of the forms add to less than n, where n = $(dim(space1)) is the dimension of the ambient space: tried to wedge product $i1 and $i2")) - end - end - (VF(isdual, _), _) || (_, VF(isdual, _)) => throw(SortError("Can only take a wedge product of forms. Flatten (♭) your vector field before applying")) - end -end - -function Base.nameof(::typeof(∧), s1, s2) - Symbol("∧$(as_sub(dim(s1)))$(as_sub(dim(s2)))") -end - -@nospecialize -∂ₜ(s::Sort) = s - -@nospecialize -function d(s::Sort) - @match s begin - Scalar() => throw(SortError("Cannot take exterior derivative of a scalar")) - Form(i, isdual, space) => - if i < dim(space) - Form(i + 1, isdual, space) - else - throw(SortError("Cannot take exterior derivative of a k-form for k >= n, where n = $(dim(space)) is the dimension of its ambient space")) - end - end -end - -@alias d, (d₀, d₁) - -Base.nameof(::typeof(d), s) = Symbol("d$(as_sub(dim(s)))") - -@nospecialize -function ⋆(s::Sort) - @match s begin - Scalar() => throw(SortError("Cannot take Hodge star of a scalar")) - VF(isdual, space) => throw(SortError("Cannot take the Hodge star of a vector field")) - Form(i, isdual, space) => Form(dim(space) - i, !isdual, space) - end -end -export ⋆ - -@alias ⋆, (⋆₀, ⋆₁, ⋆₂, ⋆₀⁻¹, ⋆₁⁻¹, ⋆₂⁻¹) - -function Base.nameof(::typeof(⋆), s) - inv = isdual(s) ? "⁻¹" : "" - Symbol("★$(as_sub(isdual(s) ? dim(space(s)) - dim(s) : dim(s)))$(inv)") -end - -@nospecialize -function ι(s1::Sort, s2::Sort) - @match (s1, s2) begin - (VF(true, space), Form(i, true, space)) => Form(i, false, space) # wrong - (VF(true, space), Form(i, false, space)) => DualForm(i, true, space) - _ => throw(SortError("Can only define the discrete interior product on: - PrimalVF, DualForm(i) - DualVF(), PrimalForm(i) - .")) - end -end - -# in practice, a scalar may be treated as a constant 0-form. -function ♯(s::Sort) - @match s begin - Scalar() => PrimalVF() - Form(1, isdual, space) => VF(isdual, space) - _ => throw(SortError("Can only take ♯ to 1-forms")) - end -end -# musical isos may be defined for any combination of (primal/dual) form -> (primal/dual) vf. - -Base.nameof(::typeof(♯), s) = Symbol("♯$s") - -function ♭(s::Sort) - @match s begin - VF(true, space) => Form(1, false, space) - _ => throw(SortError("Can only apply ♭ to dual vector fields")) - end -end - -Base.nameof(::typeof(♭), s) = Symbol("♭$s") - -# OTHER - -function ♭♯(s::Sort) - @match s begin - Form(i, isdual, space) => Form(i, !isdual, space) - _ => throw(SortError("♭♯ is only defined on forms.")) - end -end - -# Δ = ★d⋆d, but we check signature here to throw a more helpful error -function Δ(s::Sort) - @match s begin - Scalar() => Scalar() - Form(0, isdual, space) => Form(0, isdual, space) - Form(1, isdual, space) => Form(1, isdual, space) - _ => throw(SortError("Δ is not defined for $s")) - end -end - -Base.nameof(::typeof(Δ), s) = Symbol("Δ") - -end diff --git a/src/deca/Deca.jl b/src/deca/Deca.jl index 8202704..48d5c8b 100644 --- a/src/deca/Deca.jl +++ b/src/deca/Deca.jl @@ -4,12 +4,17 @@ using DataStructures using ..DiagrammaticEquations using Catlab +using Reexport + import ..infer_types!, ..resolve_overloads! export normalize_unicode, varname, infer_types!, resolve_overloads!, typename, spacename, recursive_delete_parents, recursive_delete_parents!, unicode!, op1_res_rules_1D, op2_res_rules_1D, op1_res_rules_2D, op2_res_rules_2D, op1_inf_rules_1D, op2_inf_rules_1D, op1_inf_rules_2D, op2_inf_rules_2D, vec_to_dec! include("deca_acset.jl") include("deca_visualization.jl") +include("ThDEC.jl") + +@reexport using .TheoryDEC """ function recursive_delete_parents!(d::SummationDecapode, to_delete::Vector{Int64}) diff --git a/src/deca/ThDEC.jl b/src/deca/ThDEC.jl new file mode 100644 index 0000000..4fe93c7 --- /dev/null +++ b/src/deca/ThDEC.jl @@ -0,0 +1,229 @@ +module TheoryDEC + +using ..DiagrammaticEquations: @register, @alias, Quantity + +using MLStyle +using StructEquality +using SymbolicUtils +using SymbolicUtils: Symbolic, BasicSymbolic, FnType, Sym, Term, symtype + +import Base: +, -, * +import Catlab: Δ, ∧ + +# ########################## +# ThDEC +# +# Type necessary for symbolic utils +# ########################## + +abstract type ThDEC <: Quantity end + +struct Scalar <: ThDEC end +export Scalar + +struct FormParams + dim::Int + duality::Bool + space::Symbol + spacedim::Int +end + +dim(fp::FormParams) = getproperty(fp, :dim) +duality(fp::FormParams) = getproperty(fp, :duality) +space(fp::FormParams) = getproperty(fp, :space) +spacedim(fp::FormParams) = getproperty(fp, :spacedim) + +""" +i: dimension: 0,1,2, etc. +d: duality: true = dual, false = primal +s: name of the space (a symbol) +n: dimension of the space +""" +struct Form{i,d,s,n} <: ThDEC end +export Form + +# parameter accessors +dim(::Type{<:Form{i,d,s,n}}) where {i,d,s,n} = i +isdual(::Type{<:Form{i,d,s,n}}) where {i,d,s,n} = d +space(::Type{<:Form{i,d,s,n}}) where {i,d,s,n} = s +spacedim(::Type{<:Form{i,d,s,n}}) where {i,d,s,n} = n + +export dim, isdual, space, spacedim + +# convert form to form params +FormParams(::Type{<:Form{i,d,s,n}}) where {i,s,d,n} = FormParams(i,d,s,n) + +struct VField{d,s,n} <: ThDEC end +export VField + +# parameter accessors +isdual(::Type{<:VField{d,s,n}}) where {d,s,n} = d +space(::Type{VField{d,s,n}}) where {d,s,n} = s +spacedim(::Type{VField{d,s,n}}) where {d,s,n} = n + + +# convenience functions +const PrimalForm{i,s,n} = Form{i,false,s,n} +export PrimalForm + +const DualForm{i,s,n} = Form{i,true,s,n} +export DualForm + +const PrimalVF{s,n} = VField{false,s,n} +export PrimalVF + +const DualVF{s,n} = VField{true,s,n} +export DualVF + +# ACTIVE PATTERNS + +@active ActForm(T) begin + if T <: Form + Some(T) + end +end + +@active ActFormParams(T) begin + if T <: Form + Some([T.parameters...]) + end +end + +@active ActFormDim(T) begin + if T <: Form + Some(dim(T)) + end +end + +@active ActScalar(T) begin + if T <: Scalar + Some(T) + end +end + +@active ActVFParams(T) begin + if T <: VField + Some([T.parameters...]) + end +end + +# HERE WE DEFINE THE SYMBOLICUTILS + +# for every unary operator in our theory, take a BasicSymbolic type, convert its type parameter to a Sort in our theory, and return a term +unops = [:♯, :♭] + +@register -(S)::ThDEC begin S end + +@register ∂ₜ(S)::ThDEC begin S end + +@register d(S)::ThDEC begin + @match S begin + ActFormParams([i,d,s,n]) => Form{i+1,d,s,n} + _ => throw(SortError("Cannot apply the exterior derivative to $S")) + end +end + +@alias (d₀, d₁) => d + +@register ⋆(S)::ThDEC begin + @match S begin + ActFormParams([i,d,s,n]) => Form{n-i,d,s,n} + _ => throw(SortError("Cannot take the hodge star of $S")) + end +end + +@alias (⋆₀, ⋆₁, ⋆₂, ⋆₀⁻¹, ⋆₁⁻¹, ⋆₂⁻¹) => ⋆ + +@register Δ(S)::ThDEC begin + @match S begin + ActForm(x) => ⋆(d(⋆(d(x)))) + _ => throw(SortError("Cannot take the Laplacian of $S")) + end +end + +@register +(S1, S2)::ThDEC begin + @match (S1, S2) begin + (ActScalar, ActScalar) => Scalar + (ActScalar, ActFormParams([i,d,s,n])) || (ActFormParams([i,d,s,n]), ActScalar) => S1 # commutativity + (ActFormParams([i1,d1,s1,n1]), ActFormParams([i2,d2,s2,n2])) => begin + if (i1 == i2) && (d1 == d2) && (s1 == s2) && (n1 == n2) + Form{i1, d1, s1, n1} + else + throw(SortError(""" + Can not add two forms of different dimensions/dualities/spaces: + $((i1,d1,s1)) and $((i2,d2,s2)) + """)) + end + end + _ => error("Nay!") + end +end + +@register -(S1, S2)::ThDEC begin +(S1, S2) end + +@register *(S1, S2)::ThDEC begin + @match (S1, S2) begin + (Scalar, Scalar) => Scalar + (Scalar, ActFormParams([i,d,s,n])) || (ActFormParams([i,d,s,n]), Scalar) => Form{i,d,s,n} + _ => throw(SortError("Cannot multiple $S1 and $S2")) + end +end + +@register ∧(S1, S2)::ThDEC begin + @match (S1, S2) begin + (ActFormParams([i1,d1,s1,n1]), ActFormParams([i2,d2,s2,n2])) => begin + (d1 == d2) && (s1 == s2) && (n1 == n2) || throw(SortError("Can only take a wedge product of two forms of the same duality on the same space")) + if i1 + i2 <= n1 + Form{i1 + i2, d1, s1, n1} + else + throw(SortError("Can only take a wedge product when the dimensions of the forms add to less than n, where n = $n1 is the dimension of the ambient space: tried to wedge product $i1 and $i2")) + end + end + end +end + +struct SortError <: Exception + message::String +end + +Base.nameof(s::Scalar) = :Constant + +function Base.nameof(f::Form; with_dim_parameter=false) + dual = isdual(f) ? "Dual" : "" + formname = Symbol("$(dual)Form$(dim(f))") + if with_dim_parameter + return Expr(:curly, formname, dim(space(f))) + else + return formname + end +end + +# show methods +show_duality(ω::Form) = isdual(ω) ? "dual" : "primal" + +function Base.show(io::IO, ω::Form) + print(io, isdual(ω) ? "DualForm($(dim(ω))) on $(space(ω))" : "PrimalForm($(dim(ω))) on $(space(ω))") +end + +Base.nameof(::typeof(-), s1, s2) = Symbol("$(as_sub(dim(s1)))-$(as_sub(dim(s2)))") + +const SUBSCRIPT_DIGIT_0 = '₀' + +as_sub(n::Int) = join(map(d -> SUBSCRIPT_DIGIT_0 + d, digits(n))) + +function Base.nameof(::typeof(∧), s1::B1, s2::B2) where {S1,S2,B1<:BasicSymbolic{S1}, B2<:BasicSymbolic{S2}} + Symbol("∧$(as_sub(dim(symtype(s1))))$(as_sub(dim(symtype(s2))))") +end + +function Base.nameof(::typeof(∧), s1, s2) + Symbol("∧$(as_sub(dim(s1)))$(as_sub(dim(s2)))") +end + +Base.nameof(::typeof(d), s) = Symbol("d$(as_sub(dim(s)))") + +function Base.nameof(::typeof(⋆), s) + inv = isdual(s) ? "⁻¹" : "" + Symbol("★$(as_sub(isdual(s) ? dim(space(s)) - dim(s) : dim(s)))$(inv)") +end + +end diff --git a/src/decasymbolic.jl b/src/decasymbolic.jl deleted file mode 100644 index 8933ae4..0000000 --- a/src/decasymbolic.jl +++ /dev/null @@ -1,340 +0,0 @@ -module SymbolicUtilsInterop - -using ..DiagrammaticEquations: AbstractDecapode -import ..DiagrammaticEquations: eval_eq!, SummationDecapode -using ..ThDEC -import ..ThDEC: Sort, dim, isdual -using ..decapodes - -using MLStyle -using SymbolicUtils -using SymbolicUtils: Symbolic, BasicSymbolic, FnType, Sym - -# ########################## -# DECType -# -# Type necessary for symbolic utils -# ########################## - -# define DECType as a Number. Necessary for SymbolicUtils -abstract type DECType <: Number end - -""" -i: dimension: 0,1,2, etc. -d: duality: true = dual, false = primal -s: name of the space (a symbol) -n: dimension of the space -""" -struct FormT{i,d,s,n} <: DECType end -export FormT - -struct VFieldT{d,s,n} <: DECType end -export VFieldT - -dim(::Type{<:FormT{i,d,s,n}}) where {i,d,s,n} = d -isdual(::Type{FormT{i,d,s,n}}) where {i,d,s,n} = d - -# convenience functions -const PrimalFormT{i,s,n} = FormT{i,false,s,n} -export PrimalFormT - -const DualFormT{i,s,n} = FormT{i,true,s,n} -export DualFormT - -const PrimalVFT{s,n} = VFieldT{false,s,n} -export PrimalVFT - -const DualVFT{s,n} = VFieldT{true,s,n} -export DualVFT - -""" -converts ThDEC Sorts into DECType -""" -function Sort end - -Sort(::Type{<:Real}) = Scalar() -Sort(::Real) = Scalar() - -function Sort(::Type{FormT{i,d,s,n}}) where {i,d,s,n} - Form(i, d, Space(s, n)) -end - -function Sort(::Type{VFieldT{d,s,n}}) where {d,s,n} - VField(d, Space(s, n)) -end - -Sort(::BasicSymbolic{T}) where {T} = Sort(T) - -""" -converts ThDEC Sorts into DecaSymbolic types -""" -Number(s::Scalar) = Real - -Number(f::Form) = FormT{dim(f),isdual(f), nameof(space(f)), dim(space(f))} - -Number(v::VField) = VFieldT{isdual(v), nameof(space(v)), dim(space(v))} - -# HERE WE DEFINE THE SYMBOLICUTILS - -# for every unary operator in our theory, take a BasicSymbolic type, convert its type parameter to a Sort in our theory, and return a term -unop_dec = [:∂ₜ, :d, :d₀, :d₁ - , :⋆, :⋆₀, :⋆₁, :⋆₂, :⋆₀⁻¹, :⋆₁⁻¹, :⋆₂⁻¹ - , :♯, :♭, :-] - -for unop in unop_dec - @eval begin - @nospecialize - function ThDEC.$unop( - v::BasicSymbolic{T} - ) where {T<:DECType} - s = ThDEC.$unop(Sort(T)) - SymbolicUtils.Term{Number(s)}(ThDEC.$unop, [v]) - end - end -end - -# BasicSymbolic{FnType{Tuple{PrimalFormT{0}}}, PrimalFormT{0}} - -binop_dec = [:+, :-, :*, :∧, :^] -export +,-,*,∧,^ - -for binop in binop_dec - @eval begin - @nospecialize - function ThDEC.$binop( - v::BasicSymbolic{T1}, - w::BasicSymbolic{T2} - ) where {T1<:DECType,T2<:DECType} - s = ThDEC.$binop(Sort(T1), Sort(T2)) - SymbolicUtils.Term{Number(s)}(ThDEC.$binop, [v, w]) - end - - @nospecialize - function ThDEC.$binop( - v::BasicSymbolic{T1}, - w::BasicSymbolic{T2} - ) where {T1<:DECType,T2<:Real} - s = ThDEC.$binop(Sort(T1), Sort(T2)) - SymbolicUtils.Term{Number(s)}(ThDEC.$binop, [v, w]) - end - - @nospecialize - function ThDEC.$binop( - v::BasicSymbolic{T1}, - w::BasicSymbolic{T2} - ) where {T1<:Real,T2<:DECType} - s = ThDEC.$binop(Sort(T1), Sort(T2)) - SymbolicUtils.Term{Number(s)}(ThDEC.$binop, [v, w]) - end - end -end - -# name collision with decapodes.Equation -struct DecaEquation{E} - lhs::E - rhs::E -end -export DecaEquation - -Base.show(io::IO, e::DecaEquation) = begin - print(io, e.lhs) - print(io, " == ") - print(io, e.rhs) -end - -# a struct carry the symbolic variables and their equations -struct DecaSymbolic - vars::Vector{Symbolic} - equations::Vector{DecaEquation{Symbolic}} -end -export DecaSymbolic - -Base.show(io::IO, d::DecaSymbolic) = begin - println(io, "DecaSymbolic(") - println(io, " Variables: [$(join(d.vars, ", "))]") - println(io, " Equations: [") - eqns = map(d.equations) do op - " $(op)" - end - println(io, "$(join(eqns,",\n"))])") - end - -# BasicSymbolic -> DecaExpr -function decapodes.Term(t::SymbolicUtils.BasicSymbolic) - if SymbolicUtils.issym(t) - decapodes.Var(nameof(t)) - else - op = SymbolicUtils.head(t) - args = SymbolicUtils.arguments(t) - termargs = Term.(args) - sorts = ThDEC.Sort.(args) - if op == + - decapodes.Plus(termargs) - elseif op == * - decapodes.Mult(termargs) - elseif op == ThDEC.∂ₜ - decapodes.Tan(only(termargs)) - elseif length(args) == 1 - decapodes.App1(nameof(op, sorts...), termargs...) - elseif length(args) == 2 - decapodes.App2(nameof(op, sorts...), termargs...) - else - error("was unable to convert $t into a Term") - end - end -end - -decapodes.Term(x::Real) = decapodes.Lit(Symbol(x)) - -function decapodes.DecaExpr(d::DecaSymbolic) - context = map(d.vars) do var - decapodes.Judgement(nameof(var), nameof(Sort(var)), :X) - end - equations = map(d.equations) do eq - decapodes.Eq(decapodes.Term(eq.lhs), decapodes.Term(eq.rhs)) - end - decapodes.DecaExpr(context, equations) -end - -""" -Retrieve the SymbolicUtils expression of a DecaExpr term `t` from a context of variables in ThDEC - -Example: -``` - a = @syms a::Real - context = Dict(:a => Scalar(), :u => PrimalForm(0)) - SymbolicUtils.BasicSymbolic(context, Term(a)) -``` -""" -function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,Sort}, t::decapodes.Term) - # user must import symbols into scope - ! = (f -> getfield(Main, f)) - @match t begin - Var(name) => SymbolicUtils.Sym{Number(context[name])}(name) - Lit(v) => Meta.parse(string(v)) - # see heat_eq test: eqs had AppCirc1, but this returns - # App1(f, App1(...) - AppCirc1(fs, arg) => foldr( - # panics with constants like :k - # see test/language.jl - (f, x) -> (!(f))(x), - fs; - init=BasicSymbolic(context, arg) - ) - # getfield(Main, - App1(f, x) => (!(f))(BasicSymbolic(context, x)) - App2(f, x, y) => (!(f))(BasicSymbolic(context, x), BasicSymbolic(context, y)) - Plus(xs) => +(BasicSymbolic.(Ref(context), xs)...) - Mult(xs) => *(BasicSymbolic.(Ref(context), xs)...) - Tan(x) => ThDEC.∂ₜ(BasicSymbolic(context, x)) - end -end - -function DecaSymbolic(lookup::SpaceLookup, d::decapodes.DecaExpr) - # associates each var to its sort... - context = map(d.context) do j - j.var => ThDEC.fromexpr(lookup, j.dim, Sort) - end - # ... which we then produce a vector of symbolic vars - vars = map(context) do (v, s) - SymbolicUtils.Sym{Number(s)}(v) - end - context = Dict{Symbol,Sort}(context) - eqs = map(d.equations) do eq - DecaEquation{Symbolic}(BasicSymbolic.(Ref(context), [eq.lhs, eq.rhs])...) - end - DecaSymbolic(vars, eqs) -end - -function eval_eq!(eq::DecaEquation, d::AbstractDecapode, syms::Dict{Symbol, Int}, deletions::Vector{Int}) - eval_eq!(Equation(Term(eq.lhs), Term(eq.rhs)), d, syms, deletions) -end - -""" function SummationDecapode(e::DecaSymbolic) """ -function SummationDecapode(e::DecaSymbolic) - d = SummationDecapode{Any, Any, Symbol}() - symbol_table = Dict{Symbol, Int}() - - foreach(e.vars) do var - # convert Sort(var)::PrimalForm0 --> :Form0 - var_id = add_part!(d, :Var, name=var.name, type=nameof(Sort(var))) - symbol_table[var.name] = var_id - end - - deletions = Vector{Int}() - foreach(e.equations) do eq - eval_eq!(eq, d, symbol_table, deletions) - end - rem_parts!(d, :Var, sort(deletions)) - - recognize_types(d) - - fill_names!(d) - d[:name] = normalize_unicode.(d[:name]) - make_sum_mult_unique!(d) - return d -end - -""" -Registers a new function - -``` -@register Δ(s::Sort) begin - @match s begin - ::Scalar => error("Invalid") - ::VField => error("Invalid") - ::Form => ⋆(d(⋆(d(s)))) - end -end -``` - -will create an additional method for Δ for operating on BasicSymbolic -""" -macro register(head, body) - # parse head - parsehead = @λ begin - Expr(:call, f, types...) => (f, parsehead.(types)) - Expr(:(::), var, type) => (var, type) - s => s - end - (f, args) = parsehead(head) - matchargs = [:($(x[1])::$(x[2])) for x in args] - - result = quote end - push!(result.args, - esc(quote - function $f($(matchargs...)) - $body - end - end)) - - # e.g., given [(:x, :Scalar), (:ω, :Form)]... - vs = enumerate(unique(getindex.(args, 2))) - theargs = - Dict{Symbol,Symbol}( - [v => Symbol("T$k") for (k,v) in vs] - ) - # ...[(Scalar=>:T1, :Form=>:T2)] - - # reassociate vars with their BasicSymbolic Generic Types - binding = map(args) do (var, type) - (var, :(BasicSymbolic{$(theargs[type])})) - end - newargs = [:($(x[1])::$(x[2])) for x in binding] - constraints = [:($T<:DECType) for T in values(theargs)] - innerargs = [:(Sort($T)) for T in values(theargs)] - - push!(result.args, - quote - @nospecialize - function $(esc(f))($(newargs...)) where $(constraints...) - s = $(esc(f))($(innerargs...)) - SymbolicUtils.Term{Number(s)}($(esc(f)), [$(getindex.(binding, 1)...)]) - end - end) - - return result -end -export @register - -end diff --git a/src/symbolictheoryutils.jl b/src/symbolictheoryutils.jl new file mode 100644 index 0000000..b6448eb --- /dev/null +++ b/src/symbolictheoryutils.jl @@ -0,0 +1,161 @@ +using MLStyle +using SymbolicUtils +using SymbolicUtils: Symbolic, BasicSymbolic, FnType, Sym + +abstract type Quantity <: Number end +export Quantity + +""" +Registers a new function + +``` +@register foo(S1, S2, ...)::ThDEC begin + (body of function) +end +``` +builds +``` +foo(::Type{S1}, ::Type{S2}, ...) where {S1<:ThDEC, S2<:ThDEC, ...} + (body of function) +end +``` +as well as +``` +foo(S1::BasicSymbolic{T1}, S2::BasicSymbolic{T2}, ...) where {T1<:ThDEC, ...} + s = foo(T1, T2, ...) + SymbolicUtils.Term{s}(foo, [S1, S2, ...]) +end +``` + +``` +@register Δ(s::ThDEC) begin + @match s begin + ::Scalar => error("Invalid") + ::VField => error("Invalid") + ::Form => ⋆(d(⋆(d(s)))) + end +end +``` + +Δ(S1, S2) begin + @match (S1, S2) + + +will create an additional method for Δ for operating on BasicSymbolic +""" +macro register(head, body) + + # parse body + ph = @λ begin + Expr(:call, foo, Expr(:(::), vars..., theory)) => (foo, vars, theory) + Expr(:(::), Expr(:call, foo, vars...), theory) => (foo, vars, theory) + _ => error("$head") + end + (f, vars, Theory) = ph(head) + + symbolic_args = [:(::Type{$S}) for S in vars] + symbolic_constraints = [:($S<:$Theory) for S in vars] + + # initialize the result + result = quote end + + # DEFINE TYPE INFERENCE IN THE ThDEC SYSTEM + + # TODO this just accepts whatever the body is + push!(result.args, + esc(quote + function $f($(symbolic_args...)) where {$(symbolic_constraints...)} + $body + end + end)) + + # CONSTRUCT THE FUNCTION ON BASIC SYMBOLICS + + # ...associate each var (S1) to a generic. this will be used in the + # type constraint of the new function. + generic_vars = [(v, Symbol("T$k")) for (k,v) in enumerate(vars)] + + # reassociate vars with their BasicSymbolic Generic Types + basicsym_bindings = map(generic_vars) do (var, T) + (var, :(BasicSymbolic{$T})) + end + + # binding type bindings to the basicsymbolics + basicsym_args = [:($var::$basicsym_generic) for (var, basicsym_generic) in basicsym_bindings] + + # build constraints + constraints_expr = [:($T<:$Theory) for T in getindex.(generic_vars, 2)] + + push!(result.args, + esc(quote + @nospecialize + function $f($(basicsym_args...)) where {$(constraints_expr...)} + s = $f($(getindex.(generic_vars, 2)...)) + SymbolicUtils.Term{s}($f ,[$(getindex.(basicsym_bindings, 1)...)]) + end + export $f + end)) + + return result +end +export @register + +function alias(x) + error("$x has no aliases") +end + +""" +Given a tuple of symbols ("aliases") and their canonical name (or "rep"), produces +for each alias typechecking and nameof methods which call those for their rep. +Example: +@alias (d₀, d₁) => d +""" +macro alias(body) + (rep, aliases) = @match body begin + Expr(:call, :(=>), Expr(:tuple, aliases...), rep) => (rep, aliases) + _ => error("parse error") + end + result = quote end + foreach(aliases) do alias + push!(result.args, + esc(quote + function $alias(s...) + $rep(s...) + end + export $alias + Base.nameof(::typeof($alias), s) = Symbol("$alias") + end)) + end + result +end +export alias + +macro see(body) + ph = @λ begin + Expr(:(=), Expr(:where, Expr(:call, foo, typebindings), params...), + Expr(:block, body...)) => (foo, ph(typebindings), params, body) + Expr(:(::), vars...) => ph.(vars) + Expr(:curly, :Type, Expr(:<:, Expr(:curly, type, params...))) => (type, params) + s => s + end + ph(body) + quote + $foo(arg, s1::B1, s2::B1) where {S1,S2,B1<:BasicSymbolic{S1},B2<:BasicSymbolic{S2}} + + end +end + +@see dim(::Type{<:Form{i,d,s,n}}) where {i,d,s,n} = i + +function Base.nameof(::typeof(∧), s1::B1, s2::B2) where {S1,S2,B1<:BasicSymbolic{S1}, B2<:BasicSymbolic{S2}} + Symbol("∧$(as_sub(dim(symtype(s1))))$(as_sub(dim(symtype(s2))))") +end + + +Expr(:=, + Expr(:where + [Expr(:call + foo, + Expr(:(::), e...)), + params...]), + Expr(:block, body...)) diff --git a/test/decasymbolic.jl b/test/decasymbolic.jl index fbbfc15..15cb4f3 100644 --- a/test/decasymbolic.jl +++ b/test/decasymbolic.jl @@ -1,45 +1,45 @@ using Test -using DiagrammaticEquations -using DiagrammaticEquations.ThDEC +using DiagrammaticEquations.Deca.TheoryDEC using DiagrammaticEquations.decapodes using SymbolicUtils -# what space are we working in? -X = Space(:X, 2) -lookup = SpaceLookup(X) +using SymbolicUtils: symtype # load up some variable variables and expressions -a, b = @syms a::Real b::Real -u, v = @syms u::PrimalFormT{0, :X, 2} du::PrimalFormT{1, :X, 2} -ω, η = @syms ω::PrimalFormT{1, :X, 2} η::DualFormT{2, :X, 2} -ϕ, ψ = @syms ϕ::PrimalVFT{:X, 2} ψ::DualVFT{:X, 2} +a, b = @syms a::Scalar b::Scalar +u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} +ω, η = @syms ω::PrimalForm{1, :X, 2} η::DualForm{2, :X, 2} +ϕ, ψ = @syms ϕ::PrimalVF{:X, 2} ψ::DualVF{:X, 2} # TODO would be nice to pass the space globally to avoid duplication @testset "Term Construction" begin - + + # TODO implement symtype # test conversion to underlying type - @test Sort(a) == Scalar() - @test Sort(u) == PrimalForm(0, X) - @test Sort(ω) == PrimalForm(1, X) - @test Sort(η) == DualForm(2, X) - @test Sort(ϕ) == PrimalVF(X) - @test Sort(ψ) == DualVF(X) + @test symtype(a) == Scalar + @test symtype(u) == PrimalForm{0, :X, 2} + @test symtype(ω) == PrimalForm{1, :X, 2} + @test symtype(η) == DualForm{2, :X, 2} + @test symtype(ϕ) == PrimalVF{:X, 2} + @test symtype(ψ) == DualVF{:X, 2} - @test_throws ThDEC.SortError ThDEC.♯(u) + @test symtype(u ∧ ω) == PrimalForm{1, :X, 2} + @test symtype(ω ∧ ω) == PrimalForm{2, :X, 2} + # @test_throws ThDEC.SortError ThDEC.♯(u) # test unary operator conversion to decaexpr - @test Term(1) == DiagrammaticEquations.decapodes.Lit(Symbol("1")) + @test Term(1) == Lit(Symbol("1")) @test Term(a) == Var(:a) - @test Term(ThDEC.∂ₜ(u)) == Tan(Var(:u)) - @test Term(ThDEC.⋆(ω)) == App1(:⋆₁, Var(:ω)) - @test_broken Term(ThDEC.♭(ψ)) == App1(:♭s, Var(:ψ)) + @test Term(∂ₜ(u)) == Tan(Var(:u)) + @test Term(⋆(ω)) == App1(:⋆₁, Var(:ω)) + # @test_broken Term(ThDEC.♭(ψ)) == App1(:♭s, Var(:ψ)) # @test Term(DiagrammaticEquations.ThDEC.♯(du)) - @test_throws ThDEC.SortError ThDEC.⋆(ϕ) + # @test_throws ThDEC.SortError ThDEC.⋆(ϕ) # test binary operator conversion to decaexpr @test Term(a + b) == Plus(Term[Var(:a), Var(:b)]) - @test Term(a * b) == DiagrammaticEquations.decapodes.Mult(Term[Var(:a), Var(:b)]) - @test Term(ThDEC.:∧(ω, du)) == App2(:∧₁₁, Var(:ω), Var(:du)) + @test Term(a * b) == Mult(Term[Var(:a), Var(:b)]) + @test Term(ω ∧ du) == App2(:∧₁₁, Var(:ω), Var(:du)) end diff --git a/test/klausmeier.jl b/test/klausmeier.jl index 9c88e46..54fdaee 100644 --- a/test/klausmeier.jl +++ b/test/klausmeier.jl @@ -1,10 +1,10 @@ using DiagrammaticEquations using DiagrammaticEquations.SymbolicUtilsInterop - +# using Test using MLStyle using SymbolicUtils -using SymbolicUtils: BasicSymbolic +using SymbolicUtils: BasicSymbolic, symtype # See Klausmeier Equation 2.a Hydrodynamics = @decapode begin @@ -42,6 +42,8 @@ Phytodynamics = parse_decapode(quote ∂ₜ(n) == w - m*n + Δ(n) end) +@test_broken DecaSymbolic(lookup, Phytodynamics) + import .ThDEC: d, ⋆, SortError @register Δ(s::Sort) begin @@ -55,6 +57,11 @@ end ω, = @syms ω::PrimalFormT{1, :X, 2} @test Δ(PrimalForm(1, X)) == PrimalForm(1, X) -@test Δ(ω) |> typeof == BasicSymbolic{PrimalFormT{1, :X, 2}} +@test symtype(Δ(ω)) == PrimalFormT{1, :X, 2} + +# TODO propagating module information is suited for a macro +symbmodel = DecaSymbolic(lookup, Phytodynamics, Main) + +DecaExpr(symbmodel) + -DecaSymbolic(lookup, Phytodynamics) From 1fa477b61ba2edcbbb7da77d7fcebbea220f218c Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 30 Aug 2024 13:58:56 -0400 Subject: [PATCH 11/30] adding promote_symtype and addressing some of the code review comments given --- src/SymbolicUtilsInterop.jl | 7 +++-- src/deca/Deca.jl | 2 +- src/deca/ThDEC.jl | 56 +++++++++++++++++++++------------- src/symbolictheoryutils.jl | 61 +++++++++++++------------------------ test/decasymbolic.jl | 14 ++++++--- 5 files changed, 70 insertions(+), 70 deletions(-) diff --git a/src/SymbolicUtilsInterop.jl b/src/SymbolicUtilsInterop.jl index c3c9b5a..f6a538b 100644 --- a/src/SymbolicUtilsInterop.jl +++ b/src/SymbolicUtilsInterop.jl @@ -7,7 +7,7 @@ using ..Deca using MLStyle using SymbolicUtils -using SymbolicUtils: Symbolic, BasicSymbolic, FnType, Sym +using SymbolicUtils: Symbolic, BasicSymbolic, FnType, Sym, symtype # name collision with decapodes.Equation struct SymbolicEquation{E} @@ -107,14 +107,15 @@ end function SymbolicContext(d::decapodes.DecaExpr, __module__=@__MODULE__) # associates each var to its sort... + @info d.context context = map(d.context) do j - @info j.var - j.var => j.var + j.var => symtype(ThDEC, j.dim, j.space) end # ... which we then produce a vector of symbolic vars vars = map(context) do (v, s) SymbolicUtils.Sym{s}(v) end + @info context context = Dict{Symbol,Quantity}(context) eqs = map(d.equations) do eq SymbolicEquation{Symbolic}(BasicSymbolic.(Ref(context), [eq.lhs, eq.rhs], Ref(__module__))...) diff --git a/src/deca/Deca.jl b/src/deca/Deca.jl index 48d5c8b..3187189 100644 --- a/src/deca/Deca.jl +++ b/src/deca/Deca.jl @@ -14,7 +14,7 @@ include("deca_acset.jl") include("deca_visualization.jl") include("ThDEC.jl") -@reexport using .TheoryDEC +@reexport using .ThDEC """ function recursive_delete_parents!(d::SummationDecapode, to_delete::Vector{Int64}) diff --git a/src/deca/ThDEC.jl b/src/deca/ThDEC.jl index 4fe93c7..bf0618f 100644 --- a/src/deca/ThDEC.jl +++ b/src/deca/ThDEC.jl @@ -1,6 +1,6 @@ -module TheoryDEC +module ThDEC -using ..DiagrammaticEquations: @register, @alias, Quantity +using ..DiagrammaticEquations: @operator, @alias, Quantity using MLStyle using StructEquality @@ -11,14 +11,14 @@ import Base: +, -, * import Catlab: Δ, ∧ # ########################## -# ThDEC +# DECQuantity # # Type necessary for symbolic utils # ########################## -abstract type ThDEC <: Quantity end +abstract type DECQuantity <: Quantity end -struct Scalar <: ThDEC end +struct Scalar <: DECQuantity end export Scalar struct FormParams @@ -28,10 +28,10 @@ struct FormParams spacedim::Int end -dim(fp::FormParams) = getproperty(fp, :dim) -duality(fp::FormParams) = getproperty(fp, :duality) -space(fp::FormParams) = getproperty(fp, :space) -spacedim(fp::FormParams) = getproperty(fp, :spacedim) +dim(fp::FormParams) = fp.dim +duality(fp::FormParams) = fp.duality +space(fp::FormParams) = fp.space +spacedim(fp::FormParams) = fp.spacedim """ i: dimension: 0,1,2, etc. @@ -39,7 +39,7 @@ d: duality: true = dual, false = primal s: name of the space (a symbol) n: dimension of the space """ -struct Form{i,d,s,n} <: ThDEC end +struct Form{i,d,s,n} <: DECQuantity end export Form # parameter accessors @@ -53,7 +53,7 @@ export dim, isdual, space, spacedim # convert form to form params FormParams(::Type{<:Form{i,d,s,n}}) where {i,s,d,n} = FormParams(i,d,s,n) -struct VField{d,s,n} <: ThDEC end +struct VField{d,s,n} <: DECQuantity end export VField # parameter accessors @@ -61,7 +61,6 @@ isdual(::Type{<:VField{d,s,n}}) where {d,s,n} = d space(::Type{VField{d,s,n}}) where {d,s,n} = s spacedim(::Type{VField{d,s,n}}) where {d,s,n} = n - # convenience functions const PrimalForm{i,s,n} = Form{i,false,s,n} export PrimalForm @@ -112,11 +111,11 @@ end # for every unary operator in our theory, take a BasicSymbolic type, convert its type parameter to a Sort in our theory, and return a term unops = [:♯, :♭] -@register -(S)::ThDEC begin S end +@operator -(S)::DECQuantity begin S end -@register ∂ₜ(S)::ThDEC begin S end +@operator ∂ₜ(S)::DECQuantity begin S end -@register d(S)::ThDEC begin +@operator d(S)::DECQuantity begin @match S begin ActFormParams([i,d,s,n]) => Form{i+1,d,s,n} _ => throw(SortError("Cannot apply the exterior derivative to $S")) @@ -125,7 +124,7 @@ end @alias (d₀, d₁) => d -@register ⋆(S)::ThDEC begin +@operator ⋆(S)::DECQuantity begin @match S begin ActFormParams([i,d,s,n]) => Form{n-i,d,s,n} _ => throw(SortError("Cannot take the hodge star of $S")) @@ -134,14 +133,14 @@ end @alias (⋆₀, ⋆₁, ⋆₂, ⋆₀⁻¹, ⋆₁⁻¹, ⋆₂⁻¹) => ⋆ -@register Δ(S)::ThDEC begin +@operator Δ(S)::DECQuantity begin @match S begin ActForm(x) => ⋆(d(⋆(d(x)))) _ => throw(SortError("Cannot take the Laplacian of $S")) end end -@register +(S1, S2)::ThDEC begin +@operator +(S1, S2)::DECQuantity begin @match (S1, S2) begin (ActScalar, ActScalar) => Scalar (ActScalar, ActFormParams([i,d,s,n])) || (ActFormParams([i,d,s,n]), ActScalar) => S1 # commutativity @@ -159,9 +158,9 @@ end end end -@register -(S1, S2)::ThDEC begin +(S1, S2) end +@operator -(S1, S2)::DECQuantity begin +(S1, S2) end -@register *(S1, S2)::ThDEC begin +@operator *(S1, S2)::DECQuantity begin @match (S1, S2) begin (Scalar, Scalar) => Scalar (Scalar, ActFormParams([i,d,s,n])) || (ActFormParams([i,d,s,n]), Scalar) => Form{i,d,s,n} @@ -169,7 +168,7 @@ end end end -@register ∧(S1, S2)::ThDEC begin +@operator ∧(S1, S2)::DECQuantity begin @match (S1, S2) begin (ActFormParams([i1,d1,s1,n1]), ActFormParams([i2,d2,s2,n2])) => begin (d1 == d2) && (s1 == s2) && (n1 == n2) || throw(SortError("Can only take a wedge product of two forms of the same duality on the same space")) @@ -186,6 +185,8 @@ struct SortError <: Exception message::String end +# struct WedgeDimError <: SortError end + Base.nameof(s::Scalar) = :Constant function Base.nameof(f::Form; with_dim_parameter=false) @@ -226,4 +227,17 @@ function Base.nameof(::typeof(⋆), s) Symbol("★$(as_sub(isdual(s) ? dim(space(s)) - dim(s) : dim(s)))$(inv)") end +function SymbolicUtils.symtype(::Quantity, qty::Symbol, space::Symbol) + @match qty begin + :Scalar => Scalar + :Form0 => PrimalForm{0, space, 1} + :Form1 => PrimalForm{1, space, 1} + :Form2 => PrimalForm{2, space, 1} + :DualForm0 => DualForm{0, space, 1} + :DualForm1 => DualForm{1, space, 1} + :DualForm2 => DualForm{2, space, 1} + _ => error("$qty") + end +end + end diff --git a/src/symbolictheoryutils.jl b/src/symbolictheoryutils.jl index b6448eb..7d2bae0 100644 --- a/src/symbolictheoryutils.jl +++ b/src/symbolictheoryutils.jl @@ -1,15 +1,17 @@ using MLStyle using SymbolicUtils -using SymbolicUtils: Symbolic, BasicSymbolic, FnType, Sym +using SymbolicUtils: Symbolic, BasicSymbolic, FnType, Sym, symtype +""" ThDEC in DiagrammaticEquations must be subtyped by Number to integrate with SymbolicUtils. An intermediary type, Quantity, makes it clearer that terms in the theory are "symbolic quantities" which behave like numbers +""" abstract type Quantity <: Number end export Quantity """ -Registers a new function +Creates an operator `foo` with arguments which are types in a given Theory. This entails creating (1) a function which performs type construction and (2) a function which consumes BasicSymbolic variables and returns Terms. ``` -@register foo(S1, S2, ...)::ThDEC begin +@operator foo(S1, S2, ...)::Theory begin (body of function) end ``` @@ -28,7 +30,7 @@ end ``` ``` -@register Δ(s::ThDEC) begin +@operator Δ(s::ThDEC) begin @match s begin ::Scalar => error("Invalid") ::VField => error("Invalid") @@ -43,7 +45,7 @@ end will create an additional method for Δ for operating on BasicSymbolic """ -macro register(head, body) +macro operator(head, body) # parse body ph = @λ begin @@ -81,24 +83,33 @@ macro register(head, body) end # binding type bindings to the basicsymbolics - basicsym_args = [:($var::$basicsym_generic) for (var, basicsym_generic) in basicsym_bindings] + bs_arg_exprs = [:($var::$basicsym_generic) for (var, basicsym_generic) in basicsym_bindings] # build constraints - constraints_expr = [:($T<:$Theory) for T in getindex.(generic_vars, 2)] + constraint_exprs = [:($T<:$Theory) for T in getindex.(generic_vars, 2)] push!(result.args, esc(quote @nospecialize - function $f($(basicsym_args...)) where {$(constraints_expr...)} + function $f($(bs_arg_exprs...)) where {$(constraint_exprs...)} s = $f($(getindex.(generic_vars, 2)...)) SymbolicUtils.Term{s}($f ,[$(getindex.(basicsym_bindings, 1)...)]) end export $f - end)) + end)) + + push!(result.args, + esc(quote + # we want to feed symtype the generics + function SymbolicUtils.promote_symtype(::typeof($f), + $(bs_arg_exprs...)) where {$(constraint_exprs...)} + $f($(getindex.(generic_vars, 2)...)) + end + end)) return result end -export @register +export @operator function alias(x) error("$x has no aliases") @@ -129,33 +140,3 @@ macro alias(body) result end export alias - -macro see(body) - ph = @λ begin - Expr(:(=), Expr(:where, Expr(:call, foo, typebindings), params...), - Expr(:block, body...)) => (foo, ph(typebindings), params, body) - Expr(:(::), vars...) => ph.(vars) - Expr(:curly, :Type, Expr(:<:, Expr(:curly, type, params...))) => (type, params) - s => s - end - ph(body) - quote - $foo(arg, s1::B1, s2::B1) where {S1,S2,B1<:BasicSymbolic{S1},B2<:BasicSymbolic{S2}} - - end -end - -@see dim(::Type{<:Form{i,d,s,n}}) where {i,d,s,n} = i - -function Base.nameof(::typeof(∧), s1::B1, s2::B2) where {S1,S2,B1<:BasicSymbolic{S1}, B2<:BasicSymbolic{S2}} - Symbol("∧$(as_sub(dim(symtype(s1))))$(as_sub(dim(symtype(s2))))") -end - - -Expr(:=, - Expr(:where - [Expr(:call - foo, - Expr(:(::), e...)), - params...]), - Expr(:block, body...)) diff --git a/test/decasymbolic.jl b/test/decasymbolic.jl index 15cb4f3..d603d9f 100644 --- a/test/decasymbolic.jl +++ b/test/decasymbolic.jl @@ -1,8 +1,8 @@ using Test -using DiagrammaticEquations.Deca.TheoryDEC +using DiagrammaticEquations.Deca.ThDEC using DiagrammaticEquations.decapodes using SymbolicUtils -using SymbolicUtils: symtype +using SymbolicUtils: symtype, promote_symtype # load up some variable variables and expressions a, b = @syms a::Scalar b::Scalar @@ -11,10 +11,10 @@ u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} ϕ, ψ = @syms ϕ::PrimalVF{:X, 2} ψ::DualVF{:X, 2} # TODO would be nice to pass the space globally to avoid duplication + @testset "Term Construction" begin # TODO implement symtype - # test conversion to underlying type @test symtype(a) == Scalar @test symtype(u) == PrimalForm{0, :X, 2} @test symtype(ω) == PrimalForm{1, :X, 2} @@ -30,7 +30,7 @@ u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} @test Term(1) == Lit(Symbol("1")) @test Term(a) == Var(:a) @test Term(∂ₜ(u)) == Tan(Var(:u)) - @test Term(⋆(ω)) == App1(:⋆₁, Var(:ω)) + @test_broken Term(⋆(ω)) == App1(:⋆₁, Var(:ω)) # @test_broken Term(ThDEC.♭(ψ)) == App1(:♭s, Var(:ψ)) # @test Term(DiagrammaticEquations.ThDEC.♯(du)) @@ -39,7 +39,11 @@ u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} # test binary operator conversion to decaexpr @test Term(a + b) == Plus(Term[Var(:a), Var(:b)]) @test Term(a * b) == Mult(Term[Var(:a), Var(:b)]) - @test Term(ω ∧ du) == App2(:∧₁₁, Var(:ω), Var(:du)) + @test Term(ω ∧ du) == App2(:∧₁₁, Var(:ω), Var(:du)) + + @test promote_symtype(+, a, b) == Scalar + @test promote_symtype(∧, u, u) == PrimalForm{0, :X, 2} + @test promote_symtype(∧, u, ω) == PrimalForm{1, :X, 2} end From 3b265ac18ea65bcfd68aa1f45b0163e7c5e60a52 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 4 Sep 2024 12:49:26 -0400 Subject: [PATCH 12/30] refactoring @operator to integrate with promote_symtype --- src/SymbolicUtilsInterop.jl | 2 - src/deca/ThDEC.jl | 123 +++++++++++++++++++++++++----------- src/symbolictheoryutils.jl | 98 +++++++++++++++------------- test/decasymbolic.jl | 23 ++++++- 4 files changed, 163 insertions(+), 83 deletions(-) diff --git a/src/SymbolicUtilsInterop.jl b/src/SymbolicUtilsInterop.jl index f6a538b..6de6cb5 100644 --- a/src/SymbolicUtilsInterop.jl +++ b/src/SymbolicUtilsInterop.jl @@ -107,7 +107,6 @@ end function SymbolicContext(d::decapodes.DecaExpr, __module__=@__MODULE__) # associates each var to its sort... - @info d.context context = map(d.context) do j j.var => symtype(ThDEC, j.dim, j.space) end @@ -115,7 +114,6 @@ function SymbolicContext(d::decapodes.DecaExpr, __module__=@__MODULE__) vars = map(context) do (v, s) SymbolicUtils.Sym{s}(v) end - @info context context = Dict{Symbol,Quantity}(context) eqs = map(d.equations) do eq SymbolicEquation{Symbolic}(BasicSymbolic.(Ref(context), [eq.lhs, eq.rhs], Ref(__module__))...) diff --git a/src/deca/ThDEC.jl b/src/deca/ThDEC.jl index bf0618f..74ff1ee 100644 --- a/src/deca/ThDEC.jl +++ b/src/deca/ThDEC.jl @@ -5,7 +5,8 @@ using ..DiagrammaticEquations: @operator, @alias, Quantity using MLStyle using StructEquality using SymbolicUtils -using SymbolicUtils: Symbolic, BasicSymbolic, FnType, Sym, Term, symtype +using SymbolicUtils: Symbolic, BasicSymbolic, FnType, Sym, Term +import SymbolicUtils: symtype, promote_symtype import Base: +, -, * import Catlab: Δ, ∧ @@ -17,6 +18,10 @@ import Catlab: Δ, ∧ # ########################## abstract type DECQuantity <: Quantity end +export DECQuantity + +# this ensures symtype doesn't recurse endlessly +SymbolicUtils.symtype(::Type{S}) where S<:DECQuantity = S struct Scalar <: DECQuantity end export Scalar @@ -76,40 +81,40 @@ export DualVF # ACTIVE PATTERNS -@active ActForm(T) begin +@active PatForm(T) begin if T <: Form Some(T) end end +export PatForm -@active ActFormParams(T) begin +@active PatFormParams(T) begin if T <: Form Some([T.parameters...]) end end +export PatFormParams -@active ActFormDim(T) begin +@active PatFormDim(T) begin if T <: Form Some(dim(T)) end end +export PatFormDim -@active ActScalar(T) begin +@active PatScalar(T) begin if T <: Scalar Some(T) end end +export PatScalar -@active ActVFParams(T) begin +@active PatVFParams(T) begin if T <: VField Some([T.parameters...]) end end - -# HERE WE DEFINE THE SYMBOLICUTILS - -# for every unary operator in our theory, take a BasicSymbolic type, convert its type parameter to a Sort in our theory, and return a term -unops = [:♯, :♭] +export PatVFParams @operator -(S)::DECQuantity begin S end @@ -117,44 +122,41 @@ unops = [:♯, :♭] @operator d(S)::DECQuantity begin @match S begin - ActFormParams([i,d,s,n]) => Form{i+1,d,s,n} - _ => throw(SortError("Cannot apply the exterior derivative to $S")) + PatFormParams([i,d,s,n]) => Form{i+1,d,s,n} + _ => throw(ExteriorDerivativeError(S)) end end @alias (d₀, d₁) => d -@operator ⋆(S)::DECQuantity begin +@operator ★(S)::DECQuantity begin @match S begin - ActFormParams([i,d,s,n]) => Form{n-i,d,s,n} - _ => throw(SortError("Cannot take the hodge star of $S")) + PatFormParams([i,d,s,n]) => Form{n-i,d,s,n} + _ => throw(HodgeStarError(S)) end end -@alias (⋆₀, ⋆₁, ⋆₂, ⋆₀⁻¹, ⋆₁⁻¹, ⋆₂⁻¹) => ⋆ +@alias (★₀, ★₁, ★₂, ★₀⁻¹, ★₁⁻¹, ★₂⁻¹) => ★ @operator Δ(S)::DECQuantity begin @match S begin - ActForm(x) => ⋆(d(⋆(d(x)))) - _ => throw(SortError("Cannot take the Laplacian of $S")) + PatForm(_) => promote_symtype(★ ∘ d ∘ ★ ∘ d, S) + _ => throw(LaplacianError(S)) end end @operator +(S1, S2)::DECQuantity begin @match (S1, S2) begin - (ActScalar, ActScalar) => Scalar - (ActScalar, ActFormParams([i,d,s,n])) || (ActFormParams([i,d,s,n]), ActScalar) => S1 # commutativity - (ActFormParams([i1,d1,s1,n1]), ActFormParams([i2,d2,s2,n2])) => begin + (PatScalar, PatScalar) => Scalar + (PatScalar, PatFormParams([i,d,s,n])) || (PatFormParams([i,d,s,n]), PatScalar) => S1 # commutativity + (PatFormParams([i1,d1,s1,n1]), PatFormParams([i2,d2,s2,n2])) => begin if (i1 == i2) && (d1 == d2) && (s1 == s2) && (n1 == n2) Form{i1, d1, s1, n1} else - throw(SortError(""" - Can not add two forms of different dimensions/dualities/spaces: - $((i1,d1,s1)) and $((i2,d2,s2)) - """)) + throw(AdditionDimensionalError(S1, S2)) end end - _ => error("Nay!") + _ => throw(BinaryOpError("add", S1, S2)) end end @@ -163,27 +165,26 @@ end @operator *(S1, S2)::DECQuantity begin @match (S1, S2) begin (Scalar, Scalar) => Scalar - (Scalar, ActFormParams([i,d,s,n])) || (ActFormParams([i,d,s,n]), Scalar) => Form{i,d,s,n} - _ => throw(SortError("Cannot multiple $S1 and $S2")) + (Scalar, PatFormParams([i,d,s,n])) || (PatFormParams([i,d,s,n]), Scalar) => Form{i,d,s,n} + _ => throw(BinaryOpError("multiply", S1, S2)) end end @operator ∧(S1, S2)::DECQuantity begin @match (S1, S2) begin - (ActFormParams([i1,d1,s1,n1]), ActFormParams([i2,d2,s2,n2])) => begin - (d1 == d2) && (s1 == s2) && (n1 == n2) || throw(SortError("Can only take a wedge product of two forms of the same duality on the same space")) + (PatFormParams([i1,d1,s1,n1]), PatFormParams([i2,d2,s2,n2])) => begin + (d1 == d2) && (s1 == s2) && (n1 == n2) || throw(WedgeOpError(S1, S2)) if i1 + i2 <= n1 Form{i1 + i2, d1, s1, n1} else - throw(SortError("Can only take a wedge product when the dimensions of the forms add to less than n, where n = $n1 is the dimension of the ambient space: tried to wedge product $i1 and $i2")) + throw(WedgeDimError(S1, S2)) end end + _ => throw(BinaryOpError("take the wedge product of", S1, S2)) end end -struct SortError <: Exception - message::String -end +abstract type SortError <: Exception end # struct WedgeDimError <: SortError end @@ -222,7 +223,7 @@ end Base.nameof(::typeof(d), s) = Symbol("d$(as_sub(dim(s)))") -function Base.nameof(::typeof(⋆), s) +function Base.nameof(::typeof(★), s) inv = isdual(s) ? "⁻¹" : "" Symbol("★$(as_sub(isdual(s) ? dim(space(s)) - dim(s) : dim(s)))$(inv)") end @@ -240,4 +241,54 @@ function SymbolicUtils.symtype(::Quantity, qty::Symbol, space::Symbol) end end +struct ExteriorDerivativeError <: SortError + sort::DECQuantity +end + +Base.showerror(io, e::ExteriorDerivativeError) = print(io, "Cannot apply the exterior derivative to $(e.sort)") + +struct HodgeStarError <: SortError + sort::DECQuantity +end + +Base.showerror(io, e::HodgeStarError) = print(io, "Cannot take the hodge star of $(e.sort)") + +struct LaplacianError <: SortError + sort::DECQuantity +end + +Base.showerror(io, e::LaplacianError) = print(io, "Cannot take the Laplacian of $(e.sort)") + +struct AdditionDimensionalError <: SortError + sort1::DECQuantity + sort2::DECQuantity +end + +Base.showerror(io, e::AdditionDimensionalError) = print(io, """ + Can not add two forms of different dimensions/dualities/spaces: + $(e.sort1) and $(e.sort2) + """) + +struct BinaryOpError <: SortError + verb::String + sort1::DECQuantity + sort2::DECQuantity +end + +Base.showerror(io, e::BinaryOpError) = print(io, "Cannot $(e.verb) $(e.sort1) and $(e.sort2)") + +struct WedgeOpError <: SortError + sort1::DECQuantity + sort2::DECQuantity +end + +Base.showerror(io, e::WedgeOpError) = print(io, "Can only take a wedge product of two forms of the same duality on the same space. Received $(e.sort1) and $(e.sort2)") + +struct WedgeOpDimError <: SortError + sort1::DECQuantity + sort2::DECQuantity +end + +Base.showerror(io, e::WedgeOpDimError) = print(io, "Can only take a wedge product when the dimensions of the forms add to less than n, where n = $(e.sort.dim) is the dimension of the ambient space: tried to wedge product $(e.sort1) and $(e.sort2)") + end diff --git a/src/symbolictheoryutils.jl b/src/symbolictheoryutils.jl index 7d2bae0..115d8f1 100644 --- a/src/symbolictheoryutils.jl +++ b/src/symbolictheoryutils.jl @@ -1,8 +1,13 @@ using MLStyle using SymbolicUtils using SymbolicUtils: Symbolic, BasicSymbolic, FnType, Sym, symtype +import SymbolicUtils: promote_symtype -""" ThDEC in DiagrammaticEquations must be subtyped by Number to integrate with SymbolicUtils. An intermediary type, Quantity, makes it clearer that terms in the theory are "symbolic quantities" which behave like numbers +function promote_symtype(f::ComposedFunction, args) + promote_symtype(f.outer, promote_symtype(f.inner, args)) +end + +""" DECQuantities in DiagrammaticEquations must be subtypes of Number to integrate with SymbolicUtils. An intermediary type, Quantity, makes it clearer that terms in the theory are "symbolic quantities" which behave like numbers. In the context of SymbolicUtils, a Number is any type that you can do arithmetic operations on. """ abstract type Quantity <: Number end export Quantity @@ -30,17 +35,24 @@ end ``` ``` -@operator Δ(s::ThDEC) begin +@operator Δ(s)::DECQuantity begin @match s begin ::Scalar => error("Invalid") ::VField => error("Invalid") ::Form => ⋆(d(⋆(d(s)))) end + # @rule ~x --> ⋆(d(⋆(d(s)))) end ``` -Δ(S1, S2) begin - @match (S1, S2) +# relationship between +# - type-rewriting (via Metatheory) +# - pattern-matching (via MLStyle, e.g. active pattners) +@operator Δ(s)::ThDEC begin + @rule Δ(s::PrimalForm{0, X, 1}) --> ⋆(d(⋆(d(s)))) + @rule Δ(s::PrimalForm{1, X, 1}) --> ⋆(d(⋆(d(s)))) + _ => nothing +end will create an additional method for Δ for operating on BasicSymbolic @@ -53,61 +65,61 @@ macro operator(head, body) Expr(:(::), Expr(:call, foo, vars...), theory) => (foo, vars, theory) _ => error("$head") end - (f, vars, Theory) = ph(head) + (f, types, Theory) = ph(head) - symbolic_args = [:(::Type{$S}) for S in vars] - symbolic_constraints = [:($S<:$Theory) for S in vars] + sort_types = [:(::Type{$S}) for S in types] + sort_constraints = [:($S<:$Theory) for S in types] # initialize the result result = quote end # DEFINE TYPE INFERENCE IN THE ThDEC SYSTEM - # TODO this just accepts whatever the body is - push!(result.args, - esc(quote - function $f($(symbolic_args...)) where {$(symbolic_constraints...)} - $body - end - end)) + push!(result.args, quote + function $f end; export $f + end) + + arity = length(sort_types) + + # we want to feed symtype the generics + push!(result.args, quote + function SymbolicUtils.promote_symtype(::typeof($f), $(sort_types...)) where {$(sort_constraints...)} + $body + end + function SymbolicUtils.promote_symtype(::typeof($f), args::Vararg{Symbolic, $arity}) + promote_symtype($f, symtype.(args)...) + end + end) # CONSTRUCT THE FUNCTION ON BASIC SYMBOLICS # ...associate each var (S1) to a generic. this will be used in the # type constraint of the new function. - generic_vars = [(v, Symbol("T$k")) for (k,v) in enumerate(vars)] + generic_types = [(v, Symbol("T$k")) for (k,v) in enumerate(types)] - # reassociate vars with their BasicSymbolic Generic Types - basicsym_bindings = map(generic_vars) do (var, T) - (var, :(BasicSymbolic{$T})) + # reassociate types with their BasicSymbolic Generic Types + basicsym_bindings = map(generic_types) do (var, T) + (var, :Symbolic) end # binding type bindings to the basicsymbolics - bs_arg_exprs = [:($var::$basicsym_generic) for (var, basicsym_generic) in basicsym_bindings] - - # build constraints - constraint_exprs = [:($T<:$Theory) for T in getindex.(generic_vars, 2)] - - push!(result.args, - esc(quote - @nospecialize - function $f($(bs_arg_exprs...)) where {$(constraint_exprs...)} - s = $f($(getindex.(generic_vars, 2)...)) - SymbolicUtils.Term{s}($f ,[$(getindex.(basicsym_bindings, 1)...)]) - end - export $f - end)) - - push!(result.args, - esc(quote - # we want to feed symtype the generics - function SymbolicUtils.promote_symtype(::typeof($f), - $(bs_arg_exprs...)) where {$(constraint_exprs...)} - $f($(getindex.(generic_vars, 2)...)) - end - end)) - - return result + bs_arg_exprs = map(basicsym_bindings) do (var, bs_gen) + [:($var::$bs_gen)] + end + constraint_exprs = [:($T<:$Theory) for T in types] + + # Δ(x::BasicSymbolic(T1}) where T1<:DECQuantity + # should be args subtype of symbolic, not basicsymbolic, not were-clause + push!(result.args, quote + @nospecialize + function $f(args...) + s = promote_symtype($f, args...) + SymbolicUtils.Term{s}($f, [args...]) + end + export $f + end) + + return esc(result) end export @operator diff --git a/test/decasymbolic.jl b/test/decasymbolic.jl index d603d9f..ae817d2 100644 --- a/test/decasymbolic.jl +++ b/test/decasymbolic.jl @@ -1,8 +1,10 @@ using Test +using DiagrammaticEquations using DiagrammaticEquations.Deca.ThDEC using DiagrammaticEquations.decapodes using SymbolicUtils using SymbolicUtils: symtype, promote_symtype +using MLStyle # load up some variable variables and expressions a, b = @syms a::Scalar b::Scalar @@ -11,10 +13,8 @@ u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} ϕ, ψ = @syms ϕ::PrimalVF{:X, 2} ψ::DualVF{:X, 2} # TODO would be nice to pass the space globally to avoid duplication - @testset "Term Construction" begin - # TODO implement symtype @test symtype(a) == Scalar @test symtype(u) == PrimalForm{0, :X, 2} @test symtype(ω) == PrimalForm{1, :X, 2} @@ -41,10 +41,29 @@ u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} @test Term(a * b) == Mult(Term[Var(:a), Var(:b)]) @test Term(ω ∧ du) == App2(:∧₁₁, Var(:ω), Var(:du)) + @test promote_symtype(d, u) == PrimalForm{1, :X, 2} @test promote_symtype(+, a, b) == Scalar @test promote_symtype(∧, u, u) == PrimalForm{0, :X, 2} @test promote_symtype(∧, u, ω) == PrimalForm{1, :X, 2} + @test promote_symtype(d ∘ d, u) == PrimalForm{2, :X, 2} +end + +@testset "Operator definition" begin + + # this is not nabla but "bizarro Δ" + @operator ∇(S)::DECQuantity begin + @match S begin + PatScalar(_) => error("Argument of type $S is invalid") + PatForm(_) => promote_symtype(★ ∘ d ∘ ★ ∘ d, S) + end + end + # @rule ~x --> ⋆(d(⋆(d(s)))) + + @test_throws Exception ∇(b) + @test symtype(∇(u)) == PrimalForm{0, :X ,2} + + @test_broken promote_symtype(Δ, [u,v]) end @testset "Conversion" begin From 7f8597a03ebd747d821c1a65b9d308d1834a53df Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 5 Sep 2024 18:38:09 -0400 Subject: [PATCH 13/30] operator macro parses @rule but need to write tests and iron out related wrinkles --- src/deca/ThDEC.jl | 27 ++++++++++++++++++++ src/symbolictheoryutils.jl | 52 ++++++++++++++++++++++---------------- test/decasymbolic.jl | 12 ++++++--- todo.md | 3 --- 4 files changed, 65 insertions(+), 29 deletions(-) delete mode 100644 todo.md diff --git a/src/deca/ThDEC.jl b/src/deca/ThDEC.jl index 74ff1ee..e5fee67 100644 --- a/src/deca/ThDEC.jl +++ b/src/deca/ThDEC.jl @@ -116,6 +116,33 @@ export PatScalar end export PatVFParams +isDualForm(x) = @match symtype(x) begin + PatFormParams([_,d,_,_]) => d + _ => false +end + +# TODO parameterize? +isForm0(x) = @match symtype(x) begin + PatFormParams([0,_,_,_]) => true + _ => false +end + +isForm1(x) = @match symtype(x) begin + PatFormParams([1,_,_,_]) => true + _ => false +end + +isForm2(x) = @match symtype(x) begin + PatFormParams([2,_,_,_]) => true + _ => false +end + +export isDualForm, isForm0, isForm1, isForm2 + +# ############################### +# OPERATORS +# ############################### + @operator -(S)::DECQuantity begin S end @operator ∂ₜ(S)::DECQuantity begin S end diff --git a/src/symbolictheoryutils.jl b/src/symbolictheoryutils.jl index 115d8f1..0ddc31b 100644 --- a/src/symbolictheoryutils.jl +++ b/src/symbolictheoryutils.jl @@ -7,6 +7,22 @@ function promote_symtype(f::ComposedFunction, args) promote_symtype(f.outer, promote_symtype(f.inner, args)) end +@active PatMatch(e) begin + @match e begin + Expr(:macrocall, head, args...) && if head == Symbol("@", "match") end => Some(e) + _ => nothing + end +end +export PatMatch + +@active PatRule(e) begin + @match e begin + Expr(:macrocall, head, args...) && if head == Symbol("@", "rule") end => Some(e) + _ => nothing + end +end +export PatRule + """ DECQuantities in DiagrammaticEquations must be subtypes of Number to integrate with SymbolicUtils. An intermediary type, Quantity, makes it clearer that terms in the theory are "symbolic quantities" which behave like numbers. In the context of SymbolicUtils, a Number is any type that you can do arithmetic operations on. """ abstract type Quantity <: Number end @@ -69,22 +85,30 @@ macro operator(head, body) sort_types = [:(::Type{$S}) for S in types] sort_constraints = [:($S<:$Theory) for S in types] - + arity = length(sort_types) + + match_calls = []; rule_calls = []; + pb = @λ begin + Expr(:block, args...) => pb.(args) + PatMatch(e) => push!(match_calls, e) + PatRule(e) => push!(rule_calls, e) + s => nothing + end + pb(body); + # initialize the result result = quote end # DEFINE TYPE INFERENCE IN THE ThDEC SYSTEM - push!(result.args, quote function $f end; export $f end) - arity = length(sort_types) # we want to feed symtype the generics push!(result.args, quote function SymbolicUtils.promote_symtype(::typeof($f), $(sort_types...)) where {$(sort_constraints...)} - $body + $(match_calls...) end function SymbolicUtils.promote_symtype(::typeof($f), args::Vararg{Symbolic, $arity}) promote_symtype($f, symtype.(args)...) @@ -92,24 +116,6 @@ macro operator(head, body) end) # CONSTRUCT THE FUNCTION ON BASIC SYMBOLICS - - # ...associate each var (S1) to a generic. this will be used in the - # type constraint of the new function. - generic_types = [(v, Symbol("T$k")) for (k,v) in enumerate(types)] - - # reassociate types with their BasicSymbolic Generic Types - basicsym_bindings = map(generic_types) do (var, T) - (var, :Symbolic) - end - - # binding type bindings to the basicsymbolics - bs_arg_exprs = map(basicsym_bindings) do (var, bs_gen) - [:($var::$bs_gen)] - end - constraint_exprs = [:($T<:$Theory) for T in types] - - # Δ(x::BasicSymbolic(T1}) where T1<:DECQuantity - # should be args subtype of symbolic, not basicsymbolic, not were-clause push!(result.args, quote @nospecialize function $f(args...) @@ -119,6 +125,8 @@ macro operator(head, body) export $f end) + push!(result.args, quote $rule_calls end) + return esc(result) end export @operator diff --git a/test/decasymbolic.jl b/test/decasymbolic.jl index ae817d2..138f757 100644 --- a/test/decasymbolic.jl +++ b/test/decasymbolic.jl @@ -3,7 +3,7 @@ using DiagrammaticEquations using DiagrammaticEquations.Deca.ThDEC using DiagrammaticEquations.decapodes using SymbolicUtils -using SymbolicUtils: symtype, promote_symtype +using SymbolicUtils: symtype, promote_symtype, Symbolic using MLStyle # load up some variable variables and expressions @@ -52,14 +52,18 @@ end @testset "Operator definition" begin # this is not nabla but "bizarro Δ" + del_expand_0, del_expand_1 = @operator ∇(S)::DECQuantity begin @match S begin PatScalar(_) => error("Argument of type $S is invalid") PatForm(_) => promote_symtype(★ ∘ d ∘ ★ ∘ d, S) end - end - # @rule ~x --> ⋆(d(⋆(d(s)))) - + @rule ~~x::isForm0 => ★(d(★(d(x)))) + @rule ~~x::isForm1 => ★(d(★(d(x)))) + d(★(d(★(x)))) + end; + # TODO rewriting not working atm + # del_expand = Chain(del_expand0, del_expand1) + @test_throws Exception ∇(b) @test symtype(∇(u)) == PrimalForm{0, :X ,2} diff --git a/todo.md b/todo.md deleted file mode 100644 index 351cc88..0000000 --- a/todo.md +++ /dev/null @@ -1,3 +0,0 @@ -# Refactoring Notes -Decapodey things still in DEq: -* DerivOp is referenced in src/language From 154e51f84474ef1374f8e664d7a57148a7d15cfd Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 6 Sep 2024 11:40:17 -0400 Subject: [PATCH 14/30] rewriting just needs tests --- src/symbolictheoryutils.jl | 34 +++++++++++++++++----------------- test/decasymbolic.jl | 9 +++++++-- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/src/symbolictheoryutils.jl b/src/symbolictheoryutils.jl index 0ddc31b..909baa1 100644 --- a/src/symbolictheoryutils.jl +++ b/src/symbolictheoryutils.jl @@ -83,28 +83,38 @@ macro operator(head, body) end (f, types, Theory) = ph(head) + # Passing types to functions requires that we type the signature with ::Type{T}. + # This means that the user would have to write + # my_op(::Type{T1}, ::Type{T2}, ...) + # As a convenience to the user, we allow them to specify the signature using just the types themselves. + # my_op(T1, T2, ...) sort_types = [:(::Type{$S}) for S in types] sort_constraints = [:($S<:$Theory) for S in types] arity = length(sort_types) + # Parse the body for @match or @rule calls. The @match statement parsing is unsophisticated; multiple + # @match statements will be added, and there is currently no validation. match_calls = []; rule_calls = []; pb = @λ begin Expr(:block, args...) => pb.(args) PatMatch(e) => push!(match_calls, e) PatRule(e) => push!(rule_calls, e) s => nothing - end - pb(body); + end; pb(body); # initialize the result result = quote end - # DEFINE TYPE INFERENCE IN THE ThDEC SYSTEM - push!(result.args, quote - function $f end; export $f + # construct the function on basic symbolics + push!(result.args, quote + @nospecialize + function $f(args...) + s = promote_symtype($f, args...) + SymbolicUtils.Term{s}($f, [args...]) + end + export $f end) - # we want to feed symtype the generics push!(result.args, quote function SymbolicUtils.promote_symtype(::typeof($f), $(sort_types...)) where {$(sort_constraints...)} @@ -115,17 +125,7 @@ macro operator(head, body) end end) - # CONSTRUCT THE FUNCTION ON BASIC SYMBOLICS - push!(result.args, quote - @nospecialize - function $f(args...) - s = promote_symtype($f, args...) - SymbolicUtils.Term{s}($f, [args...]) - end - export $f - end) - - push!(result.args, quote $rule_calls end) + push!(result.args, Expr(:tuple, rule_calls...)) return esc(result) end diff --git a/test/decasymbolic.jl b/test/decasymbolic.jl index 138f757..cd0962c 100644 --- a/test/decasymbolic.jl +++ b/test/decasymbolic.jl @@ -41,12 +41,15 @@ u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} @test Term(a * b) == Mult(Term[Var(:a), Var(:b)]) @test Term(ω ∧ du) == App2(:∧₁₁, Var(:ω), Var(:du)) + # test promoting types @test promote_symtype(d, u) == PrimalForm{1, :X, 2} @test promote_symtype(+, a, b) == Scalar @test promote_symtype(∧, u, u) == PrimalForm{0, :X, 2} @test promote_symtype(∧, u, ω) == PrimalForm{1, :X, 2} + # test composition @test promote_symtype(d ∘ d, u) == PrimalForm{2, :X, 2} + end @testset "Operator definition" begin @@ -58,8 +61,8 @@ end PatScalar(_) => error("Argument of type $S is invalid") PatForm(_) => promote_symtype(★ ∘ d ∘ ★ ∘ d, S) end - @rule ~~x::isForm0 => ★(d(★(d(x)))) - @rule ~~x::isForm1 => ★(d(★(d(x)))) + d(★(d(★(x)))) + @rule ~x::isForm0 => ★(d(★(d(~x)))) + @rule ~x::isForm1 => ★(d(★(d(~x)))) + d(★(d(★(~x)))) end; # TODO rewriting not working atm # del_expand = Chain(del_expand0, del_expand1) @@ -68,6 +71,8 @@ end @test symtype(∇(u)) == PrimalForm{0, :X ,2} @test_broken promote_symtype(Δ, [u,v]) + + @test del_expand_0(u) == ★(d(★(d(u)))) end @testset "Conversion" begin From 18bb71ff1406d4597b1b862fdc91867685af9f77 Mon Sep 17 00:00:00 2001 From: James Fairbanks Date: Mon, 9 Sep 2024 16:04:30 -0400 Subject: [PATCH 15/30] TST: add some klausmeier rewrites --- test/klausmeier.jl | 34 +++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/test/klausmeier.jl b/test/klausmeier.jl index 54fdaee..58fca6b 100644 --- a/test/klausmeier.jl +++ b/test/klausmeier.jl @@ -46,22 +46,38 @@ end) import .ThDEC: d, ⋆, SortError -@register Δ(s::Sort) begin - @match s begin - ::Scalar => throw(SortError("Scalar")) - ::VField => throw(SortError("Nay!")) - ::Form => ⋆(d(⋆(d(s)))) - end -end - ω, = @syms ω::PrimalFormT{1, :X, 2} @test Δ(PrimalForm(1, X)) == PrimalForm(1, X) @test symtype(Δ(ω)) == PrimalFormT{1, :X, 2} # TODO propagating module information is suited for a macro -symbmodel = DecaSymbolic(lookup, Phytodynamics, Main) +symbmodel = ps = DecaSymbolic(lookup, Phytodynamics, Main) DecaExpr(symbmodel) + +n = ps.vars[1] +SymbolicUtils.symtype(n) +Δ(n) + +r = @rule Δ(~n) => ⋆(d(⋆(d(~n)))) + +t2 = r(Δ(n)) +t2 |> dump + + +using SymbolicUtils.Rewriters +using SymbolicUtils: promote_symtype +r = @rule ⋆(⋆(~n)) => ~n +nested_star_cancel = Postwalk(Chain([r])) +nested_star_cancel(d(⋆(⋆(n)))) +nsc = nested_star_cancel + +isequal(nsc(⋆(⋆(d(n)))), d(n)) +dump(nsc(⋆(⋆(d(n))))) +dump(d(n)) +⋆(⋆(d(⋆(⋆(n))))) +nsc(⋆(⋆(d(⋆(⋆(n)))))) +nsc(nsc(⋆(⋆(d(⋆(⋆(n))))))) \ No newline at end of file From 4ba6dce70d815b90cc9056943d0820f55f4a0816 Mon Sep 17 00:00:00 2001 From: James Fairbanks Date: Mon, 9 Sep 2024 16:05:37 -0400 Subject: [PATCH 16/30] BUG: fix method shadowing for existing operators like +/- --- src/symbolictheoryutils.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/symbolictheoryutils.jl b/src/symbolictheoryutils.jl index 909baa1..7e0078d 100644 --- a/src/symbolictheoryutils.jl +++ b/src/symbolictheoryutils.jl @@ -106,11 +106,13 @@ macro operator(head, body) result = quote end # construct the function on basic symbolics + argnames = [gensym(:x) for _ in 1:arity] + argclaus = [:($a::Symbolic) for a in argnames] push!(result.args, quote @nospecialize - function $f(args...) - s = promote_symtype($f, args...) - SymbolicUtils.Term{s}($f, [args...]) + function $f($(argclaus...)) + s = promote_symtype($f, $(argnames...)) + SymbolicUtils.Term{s}($f, Any[$(argnames...)]) end export $f end) From 82f65cec28337c18e58761727f4f2dfd6f3744de Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 10 Sep 2024 07:39:57 -0400 Subject: [PATCH 17/30] added more tests, @operator macro is more flexible --- src/symbolictheoryutils.jl | 67 +++++++++++++++++--------------------- test/decasymbolic.jl | 36 ++++++++++++++++---- test/runtests.jl | 4 +-- 3 files changed, 61 insertions(+), 46 deletions(-) diff --git a/src/symbolictheoryutils.jl b/src/symbolictheoryutils.jl index 909baa1..441e301 100644 --- a/src/symbolictheoryutils.jl +++ b/src/symbolictheoryutils.jl @@ -7,13 +7,15 @@ function promote_symtype(f::ComposedFunction, args) promote_symtype(f.outer, promote_symtype(f.inner, args)) end -@active PatMatch(e) begin +@active PatBlock(e) begin @match e begin - Expr(:macrocall, head, args...) && if head == Symbol("@", "match") end => Some(e) + Expr(:macrocall, name, args...) && if name == Symbol("@", "match") end => Some(e) + Expr(:block, args...) => Some(e) + Expr(:let, args...) => Some(e) _ => nothing end end -export PatMatch +export PatBlock @active PatRule(e) begin @match e begin @@ -34,22 +36,26 @@ Creates an operator `foo` with arguments which are types in a given Theory. This ``` @operator foo(S1, S2, ...)::Theory begin (body of function) + (@rule expr1) + ... + (@rule exprN) end ``` builds ``` -foo(::Type{S1}, ::Type{S2}, ...) where {S1<:ThDEC, S2<:ThDEC, ...} +promote_symtype(::typeof{f}, ::Type{S1}, ::Type{S2}, ...) where {S1<:DECQuantity, S2<:DECQuantity, ...} (body of function) end ``` as well as ``` -foo(S1::BasicSymbolic{T1}, S2::BasicSymbolic{T2}, ...) where {T1<:ThDEC, ...} - s = foo(T1, T2, ...) +foo(S1, S2, ...) where {T1<:ThDEC, ...} + s = promote_symtype(f, S1, S2, ...) SymbolicUtils.Term{s}(foo, [S1, S2, ...]) end ``` +Example: ``` @operator Δ(s)::DECQuantity begin @match s begin @@ -57,21 +63,9 @@ end ::VField => error("Invalid") ::Form => ⋆(d(⋆(d(s)))) end - # @rule ~x --> ⋆(d(⋆(d(s)))) + @rule ~s --> ⋆(d(⋆(d(~s)))) end ``` - -# relationship between -# - type-rewriting (via Metatheory) -# - pattern-matching (via MLStyle, e.g. active pattners) -@operator Δ(s)::ThDEC begin - @rule Δ(s::PrimalForm{0, X, 1}) --> ⋆(d(⋆(d(s)))) - @rule Δ(s::PrimalForm{1, X, 1}) --> ⋆(d(⋆(d(s)))) - _ => nothing -end - - -will create an additional method for Δ for operating on BasicSymbolic """ macro operator(head, body) @@ -84,24 +78,19 @@ macro operator(head, body) (f, types, Theory) = ph(head) # Passing types to functions requires that we type the signature with ::Type{T}. - # This means that the user would have to write - # my_op(::Type{T1}, ::Type{T2}, ...) - # As a convenience to the user, we allow them to specify the signature using just the types themselves. - # my_op(T1, T2, ...) + # This means that the user would have to write `my_op(::Type{T1}, ::Type{T2}, ...)` + # As a convenience to the user, we allow them to specify the signature using just the types themselves: + # `my_op(T1, T2, ...)` sort_types = [:(::Type{$S}) for S in types] sort_constraints = [:($S<:$Theory) for S in types] arity = length(sort_types) - # Parse the body for @match or @rule calls. The @match statement parsing is unsophisticated; multiple - # @match statements will be added, and there is currently no validation. - match_calls = []; rule_calls = []; - pb = @λ begin - Expr(:block, args...) => pb.(args) - PatMatch(e) => push!(match_calls, e) - PatRule(e) => push!(rule_calls, e) + # Parse the body for @rule calls. + block, rulecalls = @match Base.remove_linenums!(body) begin + Expr(:block, block, rules...) => (block, rules) s => nothing - end; pb(body); - + end + # initialize the result result = quote end @@ -118,23 +107,19 @@ macro operator(head, body) # we want to feed symtype the generics push!(result.args, quote function SymbolicUtils.promote_symtype(::typeof($f), $(sort_types...)) where {$(sort_constraints...)} - $(match_calls...) + $block end function SymbolicUtils.promote_symtype(::typeof($f), args::Vararg{Symbolic, $arity}) promote_symtype($f, symtype.(args)...) end end) - push!(result.args, Expr(:tuple, rule_calls...)) + push!(result.args, Expr(:tuple, rulecalls...)) return esc(result) end export @operator -function alias(x) - error("$x has no aliases") -end - """ Given a tuple of symbols ("aliases") and their canonical name (or "rep"), produces for each alias typechecking and nameof methods which call those for their rep. @@ -159,4 +144,10 @@ macro alias(body) end result end +export @alias + +function alias(x) + error("$x has no aliases") +end export alias + diff --git a/test/decasymbolic.jl b/test/decasymbolic.jl index cd0962c..d00e31c 100644 --- a/test/decasymbolic.jl +++ b/test/decasymbolic.jl @@ -61,18 +61,42 @@ end PatScalar(_) => error("Argument of type $S is invalid") PatForm(_) => promote_symtype(★ ∘ d ∘ ★ ∘ d, S) end - @rule ~x::isForm0 => ★(d(★(d(~x)))) - @rule ~x::isForm1 => ★(d(★(d(~x)))) + d(★(d(★(~x)))) + @rule ∇(~x::isForm0) => ★(d(★(d(~x)))) + @rule ∇(~x::isForm1) => ★(d(★(d(~x)))) + d(★(d(★(~x)))) end; - # TODO rewriting not working atm - # del_expand = Chain(del_expand0, del_expand1) @test_throws Exception ∇(b) @test symtype(∇(u)) == PrimalForm{0, :X ,2} + @test promote_symtype(∇, u) == PrimalForm{0, :X, 2} - @test_broken promote_symtype(Δ, [u,v]) + @test isequal(del_expand_0(∇(u)), ★(d(★(d(u))))) + + # we will test is new operator + (r0, r1, r2) = @operator ρ(S)::DECQuantity begin + if S <: Form + Scalar + else + Form + end + @rule ρ(~x::isForm0) => 0 + @rule ρ(~x::isForm1) => 1 + @rule ρ(~x::isForm2) => 2 + end + + @test symtype(ρ(u)) == Scalar + + R, = @operator φ(S1, S2, S3)::DECQuantity begin + let T1=S1, T2=S2, T3=S3 + Scalar + end + @rule φ(2(~x::isForm0), 2(~y::isForm0), 2(~z::isForm0)) => 2*φ(~x,~y,~z) + end + + # TODO we need to alias rewriting rules + @alias (φ′,) => φ + + @test isequal(R(φ(2u,2u,2u)), R(φ′(2u,2u,2u))) - @test del_expand_0(u) == ★(d(★(d(u)))) end @testset "Conversion" begin diff --git a/test/runtests.jl b/test/runtests.jl index 0cb0f30..dd92531 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,8 +2,6 @@ using Test include("pretty.jl") -include("aqua.jl") - @testset "Core" begin include("core.jl") end @@ -39,3 +37,5 @@ end @testset "Open Operators" begin include("openoperators.jl") end + +include("aqua.jl") From e82e8262094952e3914cd336aac2ec59576f54de Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 10 Sep 2024 12:35:03 -0400 Subject: [PATCH 18/30] almost round-tripping in klausmeier. equations are not currently passing isequal --- src/SymbolicUtilsInterop.jl | 8 +++---- src/deca/ThDEC.jl | 47 ++++++++++++++++++++----------------- src/symbolictheoryutils.jl | 4 +--- test/klausmeier.jl | 37 ++++++----------------------- 4 files changed, 38 insertions(+), 58 deletions(-) diff --git a/src/SymbolicUtilsInterop.jl b/src/SymbolicUtilsInterop.jl index 6de6cb5..4ffdf84 100644 --- a/src/SymbolicUtilsInterop.jl +++ b/src/SymbolicUtilsInterop.jl @@ -64,7 +64,7 @@ decapodes.Term(x::Real) = decapodes.Lit(Symbol(x)) function decapodes.DecaExpr(d::SymbolicContext) context = map(d.vars) do var - decapodes.Judgement(nameof(var), nameof(Sort(var)), :X) + decapodes.Judgement(nameof(var), nameof(symtype(var)), :I) end equations = map(d.equations) do eq decapodes.Eq(decapodes.Term(eq.lhs), decapodes.Term(eq.rhs)) @@ -82,7 +82,7 @@ Example: SymbolicUtils.BasicSymbolic(context, Term(a)) ``` """ -function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,Quantity}, t::decapodes.Term, __module__=@__MODULE__) +function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,DataType}, t::decapodes.Term, __module__=@__MODULE__) # user must import symbols into scope ! = (f -> getfield(__module__, f)) @match t begin @@ -108,13 +108,13 @@ end function SymbolicContext(d::decapodes.DecaExpr, __module__=@__MODULE__) # associates each var to its sort... context = map(d.context) do j - j.var => symtype(ThDEC, j.dim, j.space) + j.var => symtype(Deca.DECQuantity, j.dim, j.space) end # ... which we then produce a vector of symbolic vars vars = map(context) do (v, s) SymbolicUtils.Sym{s}(v) end - context = Dict{Symbol,Quantity}(context) + context = Dict{Symbol,DataType}(context) eqs = map(d.equations) do eq SymbolicEquation{Symbolic}(BasicSymbolic.(Ref(context), [eq.lhs, eq.rhs], Ref(__module__))...) end diff --git a/src/deca/ThDEC.jl b/src/deca/ThDEC.jl index e5fee67..51852b5 100644 --- a/src/deca/ThDEC.jl +++ b/src/deca/ThDEC.jl @@ -79,6 +79,9 @@ export PrimalVF const DualVF{s,n} = VField{true,s,n} export DualVF +Base.nameof(u::Type{<:PrimalForm}) = Symbol("Form"*"$(dim(u))") +Base.nameof(u::Type{<:DualForm}) = Symbol("DualForm"*"$(dim(u))") + # ACTIVE PATTERNS @active PatForm(T) begin @@ -250,72 +253,74 @@ end Base.nameof(::typeof(d), s) = Symbol("d$(as_sub(dim(s)))") +Base.nameof(::typeof(Δ), s) = :Δ + function Base.nameof(::typeof(★), s) inv = isdual(s) ? "⁻¹" : "" Symbol("★$(as_sub(isdual(s) ? dim(space(s)) - dim(s) : dim(s)))$(inv)") end -function SymbolicUtils.symtype(::Quantity, qty::Symbol, space::Symbol) +function SymbolicUtils.symtype(::Type{<:Quantity}, qty::Symbol, space::Symbol) @match qty begin - :Scalar => Scalar + :Scalar || :Constant => Scalar :Form0 => PrimalForm{0, space, 1} :Form1 => PrimalForm{1, space, 1} :Form2 => PrimalForm{2, space, 1} :DualForm0 => DualForm{0, space, 1} :DualForm1 => DualForm{1, space, 1} :DualForm2 => DualForm{2, space, 1} - _ => error("$qty") + _ => error("Received $qty") end end struct ExteriorDerivativeError <: SortError - sort::DECQuantity + sort::DataType end -Base.showerror(io, e::ExteriorDerivativeError) = print(io, "Cannot apply the exterior derivative to $(e.sort)") +Base.showerror(io::IO, e::ExteriorDerivativeError) = print(io, "Cannot apply the exterior derivative to $(e.sort)") struct HodgeStarError <: SortError - sort::DECQuantity + sort::DataType end -Base.showerror(io, e::HodgeStarError) = print(io, "Cannot take the hodge star of $(e.sort)") +Base.showerror(io::IO, e::HodgeStarError) = print(io, "Cannot take the hodge star of $(e.sort)") struct LaplacianError <: SortError - sort::DECQuantity + sort::DataType end -Base.showerror(io, e::LaplacianError) = print(io, "Cannot take the Laplacian of $(e.sort)") +Base.showerror(io::IO, e::LaplacianError) = print(io, "Cannot take the Laplacian of $(e.sort)") struct AdditionDimensionalError <: SortError - sort1::DECQuantity - sort2::DECQuantity + sort1::DataType + sort2::DataType end -Base.showerror(io, e::AdditionDimensionalError) = print(io, """ +Base.showerror(io::IO, e::AdditionDimensionalError) = print(io, """ Can not add two forms of different dimensions/dualities/spaces: $(e.sort1) and $(e.sort2) """) struct BinaryOpError <: SortError verb::String - sort1::DECQuantity - sort2::DECQuantity + sort1::DataType + sort2::DataType end -Base.showerror(io, e::BinaryOpError) = print(io, "Cannot $(e.verb) $(e.sort1) and $(e.sort2)") +Base.showerror(io::IO, e::BinaryOpError) = print(io, "Cannot $(e.verb) $(e.sort1) and $(e.sort2)") struct WedgeOpError <: SortError - sort1::DECQuantity - sort2::DECQuantity + sort1::DataType + sort2::DataType end -Base.showerror(io, e::WedgeOpError) = print(io, "Can only take a wedge product of two forms of the same duality on the same space. Received $(e.sort1) and $(e.sort2)") +Base.showerror(io::IO, e::WedgeOpError) = print(io, "Can only take a wedge product of two forms of the same duality on the same space. Received $(e.sort1) and $(e.sort2)") struct WedgeOpDimError <: SortError - sort1::DECQuantity - sort2::DECQuantity + sort1::DataType + sort2::DataType end -Base.showerror(io, e::WedgeOpDimError) = print(io, "Can only take a wedge product when the dimensions of the forms add to less than n, where n = $(e.sort.dim) is the dimension of the ambient space: tried to wedge product $(e.sort1) and $(e.sort2)") +Base.showerror(io::IO, e::WedgeOpDimError) = print(io, "Can only take a wedge product when the dimensions of the forms add to less than n, where n = $(e.sort.dim) is the dimension of the ambient space: tried to wedge product $(e.sort1) and $(e.sort2)") end diff --git a/src/symbolictheoryutils.jl b/src/symbolictheoryutils.jl index 441e301..d94f67d 100644 --- a/src/symbolictheoryutils.jl +++ b/src/symbolictheoryutils.jl @@ -146,8 +146,6 @@ macro alias(body) end export @alias -function alias(x) - error("$x has no aliases") -end +alias(x) = error("$x has no aliases") export alias diff --git a/test/klausmeier.jl b/test/klausmeier.jl index 54fdaee..2ed4ac2 100644 --- a/test/klausmeier.jl +++ b/test/klausmeier.jl @@ -11,7 +11,7 @@ Hydrodynamics = @decapode begin (n,w)::DualForm0 dX::Form1 (a,ν)::Constant - + # ∂ₜ(w) == a - w - w * n^2 + ν * L(dX, w) end @@ -19,7 +19,7 @@ end Phytodynamics = @decapode begin (n,w)::DualForm0 m::Constant - + # ∂ₜ(n) == w * n^2 - m*n + Δ(n) end @@ -27,14 +27,10 @@ Hydrodynamics = parse_decapode(quote (n,w)::DualForm0 dX::Form1 (a,ν)::Constant - + # ∂ₜ(w) == a - w - w + ν * L(dX, w) end) -X = Space(:X, 2) -lookup = SpaceLookup(X) -# DecaSymbolic(lookup, Hydrodynamics) - # See Klausmeier Equation 2.b Phytodynamics = parse_decapode(quote (n,w)::Form0 @@ -42,26 +38,7 @@ Phytodynamics = parse_decapode(quote ∂ₜ(n) == w - m*n + Δ(n) end) -@test_broken DecaSymbolic(lookup, Phytodynamics) - -import .ThDEC: d, ⋆, SortError - -@register Δ(s::Sort) begin - @match s begin - ::Scalar => throw(SortError("Scalar")) - ::VField => throw(SortError("Nay!")) - ::Form => ⋆(d(⋆(d(s)))) - end -end - -ω, = @syms ω::PrimalFormT{1, :X, 2} - -@test Δ(PrimalForm(1, X)) == PrimalForm(1, X) -@test symtype(Δ(ω)) == PrimalFormT{1, :X, 2} - -# TODO propagating module information is suited for a macro -symbmodel = DecaSymbolic(lookup, Phytodynamics, Main) - -DecaExpr(symbmodel) - - +symbmodel = SymbolicContext(Phytodynamics) +dexpr = DecaExpr(symbmodel) +symbmodel′ = SymbolicContext(dexpr) +# TODO variables are the same but the equations don't match From ecfa9319f0aeccf937680971c3aa98ab953ffe0a Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 10 Sep 2024 13:52:42 -0400 Subject: [PATCH 19/30] added rules function which dispatches on function symbol and the Val(arity) --- src/deca/ThDEC.jl | 3 +++ src/symbolictheoryutils.jl | 14 +++++++++++++- test/klausmeier.jl | 24 ++++++++++++------------ 3 files changed, 28 insertions(+), 13 deletions(-) diff --git a/src/deca/ThDEC.jl b/src/deca/ThDEC.jl index 51852b5..7ba9b11 100644 --- a/src/deca/ThDEC.jl +++ b/src/deca/ThDEC.jl @@ -1,6 +1,7 @@ module ThDEC using ..DiagrammaticEquations: @operator, @alias, Quantity +import ..DiagrammaticEquations: rules using MLStyle using StructEquality @@ -173,6 +174,8 @@ end PatForm(_) => promote_symtype(★ ∘ d ∘ ★ ∘ d, S) _ => throw(LaplacianError(S)) end + @rule Δ(~x::isForm0) => ★(d(★(d(~x)))) + @rule Δ(~x::isForm1) => ★(d(★(d(~x)))) + d(★(d(★(~x)))) end @operator +(S1, S2)::DECQuantity begin diff --git a/src/symbolictheoryutils.jl b/src/symbolictheoryutils.jl index 57cb7a5..cff2761 100644 --- a/src/symbolictheoryutils.jl +++ b/src/symbolictheoryutils.jl @@ -3,6 +3,9 @@ using SymbolicUtils using SymbolicUtils: Symbolic, BasicSymbolic, FnType, Sym, symtype import SymbolicUtils: promote_symtype +function rules end +export rules + function promote_symtype(f::ComposedFunction, args) promote_symtype(f.outer, promote_symtype(f.inner, args)) end @@ -106,9 +109,18 @@ macro operator(head, body) export $f end) + if !isempty(rulecalls) + push!(result.args, quote + function rules(::typeof($f), ::Val{$arity}) + [($(rulecalls...))] + end + end) + end + # we want to feed symtype the generics push!(result.args, quote - function SymbolicUtils.promote_symtype(::typeof($f), $(sort_types...)) where {$(sort_constraints...)} + function SymbolicUtils.promote_symtype(::typeof($f), + $(sort_types...)) where {$(sort_constraints...)} $block end function SymbolicUtils.promote_symtype(::typeof($f), args::Vararg{Symbolic, $arity}) diff --git a/test/klausmeier.jl b/test/klausmeier.jl index 8c36f9e..814cf73 100644 --- a/test/klausmeier.jl +++ b/test/klausmeier.jl @@ -38,31 +38,31 @@ Phytodynamics = parse_decapode(quote ∂ₜ(n) == w - m*n + Δ(n) end) -symbmodel = SymbolicContext(Phytodynamics) -dexpr = DecaExpr(symbmodel) -symbmodel′ = SymbolicContext(dexpr) +ps = SymbolicContext(Phytodynamics) +dexpr = DecaExpr(ps) +ps′ = SymbolicContext(dexpr) # TODO variables are the same but the equations don't match n = ps.vars[1] SymbolicUtils.symtype(n) Δ(n) -r = @rule Δ(~n) => ⋆(d(⋆(d(~n)))) +r, _ = rules(Δ, Val(1)); t2 = r(Δ(n)) t2 |> dump - using SymbolicUtils.Rewriters using SymbolicUtils: promote_symtype -r = @rule ⋆(⋆(~n)) => ~n +r = @rule ★(★(~n)) => ~n + nested_star_cancel = Postwalk(Chain([r])) -nested_star_cancel(d(⋆(⋆(n)))) +nested_star_cancel(d(★(★(n)))) nsc = nested_star_cancel -isequal(nsc(⋆(⋆(d(n)))), d(n)) -dump(nsc(⋆(⋆(d(n))))) +@test isequal(nsc(★(★(d(n)))), d(n)) +dump(nsc(★(★(d(n))))) dump(d(n)) -⋆(⋆(d(⋆(⋆(n))))) -nsc(⋆(⋆(d(⋆(⋆(n)))))) -nsc(nsc(⋆(⋆(d(⋆(⋆(n))))))) +★(★(d(★(★(n))))) +nsc(★(★(d(★(★(n)))))) +nsc(nsc(★(★(d(★(★(n))))))) From 4603ed973e64d6483f02dd631aad80e903f77c54 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 11 Sep 2024 09:13:45 -0400 Subject: [PATCH 20/30] fixed docs for @operator, fixed Term --- src/SymbolicUtilsInterop.jl | 4 ++-- src/symbolictheoryutils.jl | 3 ++- test/decasymbolic.jl | 6 +----- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/SymbolicUtilsInterop.jl b/src/SymbolicUtilsInterop.jl index 4ffdf84..61502a8 100644 --- a/src/SymbolicUtilsInterop.jl +++ b/src/SymbolicUtilsInterop.jl @@ -51,9 +51,9 @@ function decapodes.Term(t::SymbolicUtils.BasicSymbolic) elseif op == ∂ₜ decapodes.Tan(only(termargs)) elseif length(args) == 1 - decapodes.App1(nameof(op, args...), termargs...) + decapodes.App1(nameof(op, symtype.(args)...), termargs...) elseif length(args) == 2 - decapodes.App2(nameof(op, args...), termargs...) + decapodes.App2(nameof(op, symtype.(args)...), termargs...) else error("was unable to convert $t into a Term") end diff --git a/src/symbolictheoryutils.jl b/src/symbolictheoryutils.jl index cff2761..41a741b 100644 --- a/src/symbolictheoryutils.jl +++ b/src/symbolictheoryutils.jl @@ -52,7 +52,7 @@ end ``` as well as ``` -foo(S1, S2, ...) where {T1<:ThDEC, ...} +foo(S1, S2, ...) where {S1<:DECQuantity, ...} s = promote_symtype(f, S1, S2, ...) SymbolicUtils.Term{s}(foo, [S1, S2, ...]) end @@ -109,6 +109,7 @@ macro operator(head, body) export $f end) + # if there are rewriting rules, add a method which accepts the function symbol and its arity (to prevent shadowing on operators like `-`) if !isempty(rulecalls) push!(result.args, quote function rules(::typeof($f), ::Val{$arity}) diff --git a/test/decasymbolic.jl b/test/decasymbolic.jl index d00e31c..bbd9102 100644 --- a/test/decasymbolic.jl +++ b/test/decasymbolic.jl @@ -30,11 +30,7 @@ u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} @test Term(1) == Lit(Symbol("1")) @test Term(a) == Var(:a) @test Term(∂ₜ(u)) == Tan(Var(:u)) - @test_broken Term(⋆(ω)) == App1(:⋆₁, Var(:ω)) - # @test_broken Term(ThDEC.♭(ψ)) == App1(:♭s, Var(:ψ)) - # @test Term(DiagrammaticEquations.ThDEC.♯(du)) - - # @test_throws ThDEC.SortError ThDEC.⋆(ϕ) + @test Term(★(ω)) == App1(:★₁, Var(:ω)) # test binary operator conversion to decaexpr @test Term(a + b) == Plus(Term[Var(:a), Var(:b)]) From 5324de3b397a1069b941941022fbab00c96cb354 Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Thu, 12 Sep 2024 14:54:10 -0400 Subject: [PATCH 21/30] Loosened aqua tests This just removes the aqua check for type piracy, it doesn't fix it. I've tagged the function that it probably is reacting to. I've also removed revise as a dep. --- Project.toml | 1 - src/symbolictheoryutils.jl | 1 + test/aqua.jl | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 19239ef..b40b43a 100644 --- a/Project.toml +++ b/Project.toml @@ -10,7 +10,6 @@ Catlab = "134e5e36-593f-5add-ad60-77f754baafbe" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" -Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" StructEquality = "6ec83bb0-ed9f-11e9-3b4c-2b04cb4e219c" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" diff --git a/src/symbolictheoryutils.jl b/src/symbolictheoryutils.jl index 41a741b..0c03a0a 100644 --- a/src/symbolictheoryutils.jl +++ b/src/symbolictheoryutils.jl @@ -6,6 +6,7 @@ import SymbolicUtils: promote_symtype function rules end export rules +# TODO: Probable piracy here function promote_symtype(f::ComposedFunction, args) promote_symtype(f.outer, promote_symtype(f.inner, args)) end diff --git a/test/aqua.jl b/test/aqua.jl index 824a58c..38129f7 100644 --- a/test/aqua.jl +++ b/test/aqua.jl @@ -1,5 +1,5 @@ using Aqua, DiagrammaticEquations @testset "Code quality (Aqua.jl)" begin # TODO: fix ambiguities - Aqua.test_all(DiagrammaticEquations, ambiguities=false, undefined_exports=false) + Aqua.test_all(DiagrammaticEquations, ambiguities=false, undefined_exports=false, piracies=false) end From 0a314a8ee583d22108917f231daf5c43ac8e846c Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 18 Sep 2024 08:27:14 -0400 Subject: [PATCH 22/30] fixing bug where (+) always returns Scalar --- src/deca/ThDEC.jl | 8 ++++---- test/decasymbolic.jl | 4 +++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/deca/ThDEC.jl b/src/deca/ThDEC.jl index 7ba9b11..383274e 100644 --- a/src/deca/ThDEC.jl +++ b/src/deca/ThDEC.jl @@ -180,8 +180,8 @@ end @operator +(S1, S2)::DECQuantity begin @match (S1, S2) begin - (PatScalar, PatScalar) => Scalar - (PatScalar, PatFormParams([i,d,s,n])) || (PatFormParams([i,d,s,n]), PatScalar) => S1 # commutativity + (PatScalar(_), PatScalar(_)) => Scalar + (PatScalar(_), PatFormParams([i,d,s,n])) || (PatFormParams([i,d,s,n]), PatScalar(_)) => S1 # commutativity (PatFormParams([i1,d1,s1,n1]), PatFormParams([i2,d2,s2,n2])) => begin if (i1 == i2) && (d1 == d2) && (s1 == s2) && (n1 == n2) Form{i1, d1, s1, n1} @@ -197,8 +197,8 @@ end @operator *(S1, S2)::DECQuantity begin @match (S1, S2) begin - (Scalar, Scalar) => Scalar - (Scalar, PatFormParams([i,d,s,n])) || (PatFormParams([i,d,s,n]), Scalar) => Form{i,d,s,n} + (PatScalar(_), PatScalar(_)) => Scalar + (PatScalar(_), PatFormParams([i,d,s,n])) || (PatFormParams([i,d,s,n]), PatScalar(_)) => Form{i,d,s,n} _ => throw(BinaryOpError("multiply", S1, S2)) end end diff --git a/test/decasymbolic.jl b/test/decasymbolic.jl index bbd9102..0b5b7e9 100644 --- a/test/decasymbolic.jl +++ b/test/decasymbolic.jl @@ -13,6 +13,7 @@ u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} ϕ, ψ = @syms ϕ::PrimalVF{:X, 2} ψ::DualVF{:X, 2} # TODO would be nice to pass the space globally to avoid duplication + @testset "Term Construction" begin @test symtype(a) == Scalar @@ -25,6 +26,7 @@ u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} @test symtype(u ∧ ω) == PrimalForm{1, :X, 2} @test symtype(ω ∧ ω) == PrimalForm{2, :X, 2} # @test_throws ThDEC.SortError ThDEC.♯(u) + @test symtype(Δ(u) + Δ(u)) == PrimalForm{0, :X, 2} # test unary operator conversion to decaexpr @test Term(1) == Lit(Symbol("1")) @@ -36,7 +38,7 @@ u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} @test Term(a + b) == Plus(Term[Var(:a), Var(:b)]) @test Term(a * b) == Mult(Term[Var(:a), Var(:b)]) @test Term(ω ∧ du) == App2(:∧₁₁, Var(:ω), Var(:du)) - + # test promoting types @test promote_symtype(d, u) == PrimalForm{1, :X, 2} @test promote_symtype(+, a, b) == Scalar From 5d5c25d522b7cc67a7a28ff826659dad34faf12e Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Thu, 3 Oct 2024 11:15:39 -0400 Subject: [PATCH 23/30] Expression level rewriting (#69) * Set version to 0.1.7 * Added more exports (#44) Added `apex` and `@relation`, `to_graphviz` from Catlab Co-authored-by: James * Add type rules for vectorfields * Add musical overload resolution * Take advantage of :infer in type rules * Initial attempt at rewriting Converts ACSet to a series of Symbolic terms that can be rewritten with a provided rewriter * Added proof of concept Added a short script showcasing how rewriting could be done with the `Sort` types and a reference ACSet. * Added ability to do through-op rewrites This now supports the ability for ACSet intermediate expressions to be merged into one single expression upon which rewriting rules (like dd=0) may be performed. * Added Space import * Completed full pipeline Can take ACSets to Symbolics back to ACSets * Remove metadata usage This needs to switch to use the new type system * Added DECQuantity types Also switched to using SymbolicsUtils' `substitute`. Still needs tests and code needs to be cleaned up. * Completed pipeline again Addition now works as well but rewriting seems to be janky, unrelated to this pipeline specifically I believe. * fixed bug where type-checking subtraction uses +(S1,S2), which is obsolete * George and I debugged rewriting. Incorrect type passed to resulting term meant typed rewriting would fail * Cleaning up pipeline This black boxes the intermediate symbolic expressions to the user. The user will simply submit a rewriter that will then be applied * Fixed order of inclusions * adding support for Parameters and Constants * Added tests for acset2symbolic * etc * Literals testing * parameters test passing after some debugging. * supporting Infer, better Base.nameof, better tests * Clean out-of-order vector constructions * Convert to symbolics inside merge_equations * Reduce cases of topological sort * Reify via recursive function, not lambda case * Further improvement of acset2symbolics Remove special DerivOp handling, fixed bug where multiple equations with the same variable result were being dropped, more tests to cover these cases and further clean up. * Remove extraneous tangents * Remove redundant helper functions * Pass indexed names and types directly * Removed extraneous d arg * fixing work on tumor invasion * macros which create export stmts will fail inside @testset due to JuliaLang issue #51325 * removed ghost emoji and added convenience function for rules. aqua's failing persistent tasks. * Added more tests for acset2symbolics * Fixed persistence issue Also set default form dim to 2 and allowed it to vary. * Final touches * Remove unused fuctionality --------- Co-authored-by: AlgebraicJulia Bot <129184742+algebraicjuliabot@users.noreply.github.com> Co-authored-by: James Co-authored-by: Luke Morris Co-authored-by: Matt --- Project.toml | 2 +- src/DiagrammaticEquations.jl | 8 +- src/SymbolicUtilsInterop.jl | 34 +++-- src/acset.jl | 47 +++--- src/acset2symbolic.jl | 81 ++++++++++ src/deca/ThDEC.jl | 72 +++++++-- src/deca/deca_acset.jl | 88 ++++++----- src/graph_traversal.jl | 73 +++++++++ src/symbolictheoryutils.jl | 30 ++-- test/Project.toml | 1 + test/acset2symbolic.jl | 286 +++++++++++++++++++++++++++++++++++ test/composition.jl | 5 +- test/decasymbolic.jl | 164 +++++++++++--------- test/graph_traversal.jl | 64 ++++++++ test/language.jl | 103 ++++++++++--- test/runtests.jl | 9 ++ 16 files changed, 867 insertions(+), 200 deletions(-) create mode 100644 src/acset2symbolic.jl create mode 100644 src/graph_traversal.jl create mode 100644 test/acset2symbolic.jl create mode 100644 test/graph_traversal.jl diff --git a/Project.toml b/Project.toml index b40b43a..69f41e4 100644 --- a/Project.toml +++ b/Project.toml @@ -2,7 +2,7 @@ name = "DiagrammaticEquations" uuid = "6f00c28b-6bed-4403-80fa-30e0dc12f317" license = "MIT" authors = ["James Fairbanks", "Andrew Baas", "Evan Patterson", "Luke Morris", "George Rauta"] -version = "0.1.6" +version = "0.1.7" [deps] ACSets = "227ef7b5-1206-438b-ac65-934d6da304b8" diff --git a/src/DiagrammaticEquations.jl b/src/DiagrammaticEquations.jl index be9cace..558a839 100644 --- a/src/DiagrammaticEquations.jl +++ b/src/DiagrammaticEquations.jl @@ -2,6 +2,8 @@ """ module DiagrammaticEquations +using Catlab + export DerivOp, append_dot, normalize_unicode, infer_states, infer_types!, # Deca @@ -12,6 +14,7 @@ recursive_delete_parents, spacename, varname, unicode!, vec_to_dec!, Collage, collate, ## composition oapply, unique_by, unique_by!, OpenSummationDecapodeOb, OpenSummationDecapode, Open, default_composition_diagram, +apex, @relation, # Re-exported from Catlab ## acset SchDecapode, SchNamedDecapode, AbstractDecapode, AbstractNamedDecapode, NamedDecapode, SummationDecapode, contract_operators!, contract_operators, add_constant!, add_parameter, fill_names!, dot_rename!, is_expanded, expand_operators, infer_state_names, infer_terminal_names, recognize_types, @@ -25,12 +28,12 @@ unique_lits!, Plus, AppCirc1, Var, Tan, App1, App2, ## visualization to_graphviz_property_graph, typename, draw_composition, +to_graphviz, # Re-exported from Catlab ## rewrite average_rewrite, ## openoperators transfer_parents!, transfer_children!, replace_op1!, replace_op2!, replace_all_op1s!, replace_all_op2s! -using Catlab using Catlab.Theories import Catlab.Theories: otimes, oplus, compose, ⊗, ⊕, ⋅, associate, associate_unit, Ob, Hom, dom, codom using Catlab.Programs @@ -62,6 +65,7 @@ include("pretty.jl") include("colanguage.jl") include("openoperators.jl") include("symbolictheoryutils.jl") +include("graph_traversal.jl") include("deca/Deca.jl") include("learn/Learn.jl") include("SymbolicUtilsInterop.jl") @@ -69,4 +73,6 @@ include("SymbolicUtilsInterop.jl") @reexport using .Deca @reexport using .SymbolicUtilsInterop +include("acset2symbolic.jl") + end diff --git a/src/SymbolicUtilsInterop.jl b/src/SymbolicUtilsInterop.jl index 61502a8..c95fb54 100644 --- a/src/SymbolicUtilsInterop.jl +++ b/src/SymbolicUtilsInterop.jl @@ -1,6 +1,8 @@ module SymbolicUtilsInterop -using ..DiagrammaticEquations: AbstractDecapode, Quantity +using ACSets +using ..DiagrammaticEquations: AbstractDecapode, Quantity, DerivOp +using ..DiagrammaticEquations: recognize_types, fill_names!, make_sum_mult_unique! import ..DiagrammaticEquations: eval_eq!, SummationDecapode using ..decapodes using ..Deca @@ -14,6 +16,7 @@ struct SymbolicEquation{E} lhs::E rhs::E end +export SymbolicEquation Base.show(io::IO, e::SymbolicEquation) = begin print(io, e.lhs); print(io, " == "); print(io, e.rhs) @@ -48,7 +51,7 @@ function decapodes.Term(t::SymbolicUtils.BasicSymbolic) decapodes.Plus(termargs) elseif op == * decapodes.Mult(termargs) - elseif op == ∂ₜ + elseif op ∈ [DerivOp, ∂ₜ] decapodes.Tan(only(termargs)) elseif length(args) == 1 decapodes.App1(nameof(op, symtype.(args)...), termargs...) @@ -59,6 +62,9 @@ function decapodes.Term(t::SymbolicUtils.BasicSymbolic) end end end +# TODO subtraction is not parsed as such. e.g., +# a, b = @syms a::Scalar b::Scalar +# Term(a-b) = Plus(Term[Var(:a), Mult(Term[Lit(Symbol("-1")), Var(:b)])) decapodes.Term(x::Real) = decapodes.Lit(Symbol(x)) @@ -82,9 +88,9 @@ Example: SymbolicUtils.BasicSymbolic(context, Term(a)) ``` """ -function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,DataType}, t::decapodes.Term, __module__=@__MODULE__) +function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,DataType}, t::decapodes.Term) # user must import symbols into scope - ! = (f -> getfield(__module__, f)) + ! = (f -> getfield(@__MODULE__, f)) @match t begin Var(name) => SymbolicUtils.Sym{context[name]}(name) Lit(v) => Meta.parse(string(v)) @@ -95,17 +101,17 @@ function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,DataType}, t::decapode # see test/language.jl (f, x) -> (!(f))(x), fs; - init=BasicSymbolic(context, arg, __module__) + init=BasicSymbolic(context, arg) ) - App1(f, x) => (!(f))(BasicSymbolic(context, x, __module__)) - App2(f, x, y) => (!(f))(BasicSymbolic(context, x, __module__), BasicSymbolic(context, y, __module__)) - Plus(xs) => +(BasicSymbolic.(Ref(context), xs, Ref(__module__))...) - Mult(xs) => *(BasicSymbolic.(Ref(context), xs, Ref(__module__))...) - Tan(x) => ∂ₜ(BasicSymbolic(context, x, __module__)) + App1(f, x) => (!(f))(BasicSymbolic(context, x)) + App2(f, x, y) => (!(f))(BasicSymbolic(context, x), BasicSymbolic(context, y)) + Plus(xs) => +(BasicSymbolic.(Ref(context), xs)...) + Mult(xs) => *(BasicSymbolic.(Ref(context), xs)...) + Tan(x) => (!(DerivOp))(BasicSymbolic(context, x)) end end -function SymbolicContext(d::decapodes.DecaExpr, __module__=@__MODULE__) +function SymbolicContext(d::decapodes.DecaExpr) # associates each var to its sort... context = map(d.context) do j j.var => symtype(Deca.DECQuantity, j.dim, j.space) @@ -116,13 +122,13 @@ function SymbolicContext(d::decapodes.DecaExpr, __module__=@__MODULE__) end context = Dict{Symbol,DataType}(context) eqs = map(d.equations) do eq - SymbolicEquation{Symbolic}(BasicSymbolic.(Ref(context), [eq.lhs, eq.rhs], Ref(__module__))...) + SymbolicEquation{Symbolic}(BasicSymbolic.(Ref(context), [eq.lhs, eq.rhs])...) end SymbolicContext(vars, eqs) end function eval_eq!(eq::SymbolicEquation, d::AbstractDecapode, syms::Dict{Symbol, Int}, deletions::Vector{Int}) - eval_eq!(Equation(Term(eq.lhs), Term(eq.rhs)), d, syms, deletions) + eval_eq!(Eq(Term(eq.lhs), Term(eq.rhs)), d, syms, deletions) end """ function SummationDecapode(e::SymbolicContext) """ @@ -132,7 +138,7 @@ function SummationDecapode(e::SymbolicContext) foreach(e.vars) do var # convert Sort(var)::PrimalForm0 --> :Form0 - var_id = add_part!(d, :Var, name=var.name, type=nameof(Sort(var))) + var_id = add_part!(d, :Var, name=var.name, type=nameof(symtype(var))) symbol_table[var.name] = var_id end diff --git a/src/acset.jl b/src/acset.jl index 9665afd..4cfdaac 100644 --- a/src/acset.jl +++ b/src/acset.jl @@ -158,13 +158,16 @@ end # A collection of DecaType getters # TODO: This should be replaced by using a type hierarchy const ALL_TYPES = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, - :Literal, :Parameter, :Constant, :infer] + :PVF, :DVF, + :Literal, :Parameter, :Constant, :infer] const FORM_TYPES = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2] const PRIMALFORM_TYPES = [:Form0, :Form1, :Form2] const DUALFORM_TYPES = [:DualForm0, :DualForm1, :DualForm2] -const NONFORM_TYPES = [:Constant, :Parameter, :Literal, :infer] +const VECTORFIELD_TYPES = [:PVF, :DVF] + +const NON_EC_TYPES = [:Constant, :Parameter, :Literal, :infer] const USER_TYPES = [:Constant, :Parameter] const NUMBER_TYPES = [:Literal] const INFER_TYPES = [:infer] @@ -184,6 +187,7 @@ function recognize_types(d::AbstractNamedDecapode) isempty(unrecognized_types) || error("Types $unrecognized_types are not recognized. CHECK: $types") end +export recognize_types """ is_expanded(d::AbstractNamedDecapode) @@ -427,12 +431,12 @@ function safe_modifytype!(d::SummationDecapode, var_idx::Int, new_type::Symbol) end """ - filterfor_forms(types::AbstractVector{Symbol}) + filterfor_ec_types(types::AbstractVector{Symbol}) -Return any form type symbols. +Return any form or vector-field type symbols. """ -function filterfor_forms(types::AbstractVector{Symbol}) - conditions = x -> !(x in NONFORM_TYPES) +function filterfor_ec_types(types::AbstractVector{Symbol}) + conditions = x -> !(x in NON_EC_TYPES) filter(conditions, types) end @@ -447,29 +451,26 @@ function infer_sum_types!(d::SummationDecapode, Σ_idx::Int) types = d[idxs, :type] all(t != :infer for t in types) && return applied # We need not infer - forms = unique(filterfor_forms(types)) + ec_types = unique(filterfor_ec_types(types)) - form = @match length(forms) begin + ec_type = @match length(ec_types) begin 0 => return applied # We can not infer - 1 => only(forms) - _ => error("Type mismatch in summation $Σ_idx, all the following forms appear: $forms") + 1 => only(ec_types) + _ => error("Type mismatch in summation $Σ_idx, all the following forms appear: $ec_types") end for idx in idxs - applied |= safe_modifytype!(d, idx, form) + applied |= safe_modifytype!(d, idx, ec_type) end return applied end function apply_inference_rule_op1!(d::SummationDecapode, op1_id, rule) - type_src = d[d[op1_id, :src], :type] - type_tgt = d[d[op1_id, :tgt], :type] + score_src = (rule.src_type == d[d[op1_id, :src], :type]) + score_tgt = (rule.tgt_type == d[d[op1_id, :tgt], :type]) - score_src = (rule.src_type == type_src) - score_tgt = (rule.tgt_type == type_tgt) check_op = (d[op1_id, :op1] in rule.op_names) - if(check_op && (score_src + score_tgt == 1)) mod_src = safe_modifytype!(d, d[op1_id, :src], rule.src_type) mod_tgt = safe_modifytype!(d, d[op1_id, :tgt], rule.tgt_type) @@ -480,19 +481,15 @@ function apply_inference_rule_op1!(d::SummationDecapode, op1_id, rule) end function apply_inference_rule_op2!(d::SummationDecapode, op2_id, rule) - type_proj1 = d[d[op2_id, :proj1], :type] - type_proj2 = d[d[op2_id, :proj2], :type] - type_res = d[d[op2_id, :res], :type] + score_proj1 = (rule.proj1_type == d[d[op2_id, :proj1], :type]) + score_proj2 = (rule.proj2_type == d[d[op2_id, :proj2], :type]) + score_res = (rule.res_type == d[d[op2_id, :res], :type]) - score_proj1 = (rule.proj1_type == type_proj1) - score_proj2 = (rule.proj2_type == type_proj2) - score_res = (rule.res_type == type_res) check_op = (d[op2_id, :op2] in rule.op_names) - - if(check_op && (score_proj1 + score_proj2 + score_res == 2)) + if check_op && (score_proj1 + score_proj2 + score_res == 2) mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], rule.proj1_type) mod_proj2 = safe_modifytype!(d, d[op2_id, :proj2], rule.proj2_type) - mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type) + mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type) return mod_proj1 || mod_proj2 || mod_res end diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl new file mode 100644 index 0000000..330fa76 --- /dev/null +++ b/src/acset2symbolic.jl @@ -0,0 +1,81 @@ +using DiagrammaticEquations +using ACSets +using SymbolicUtils +using SymbolicUtils: BasicSymbolic, Symbolic + +export symbolic_rewriting + +const EQUALITY = (==) +const SymEqSym = SymbolicEquation{Symbolic} + +function symbolics_lookup(d::SummationDecapode) + Dict{Symbol, BasicSymbolic}(map(d[:name],d[:type]) do name,type + (name, decavar_to_symbolics(name, type)) + end) +end + +function decavar_to_symbolics(var_name::Symbol, var_type::Symbol, space = :I) + new_type = SymbolicUtils.symtype(Deca.DECQuantity, var_type, space) + SymbolicUtils.Sym{new_type}(var_name) +end + +function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_idx::Int, op_type::Symbol) + input_syms = getindex.(Ref(symvar_lookup), d[edge_inputs(d,op_idx,Val(op_type)), :name]) + output_sym = getindex.(Ref(symvar_lookup), d[edge_output(d,op_idx,Val(op_type)), :name]) + op_sym = getfield(@__MODULE__, edge_function(d,op_idx,Val(op_type))) + + S = promote_symtype(op_sym, input_syms...) + SymEqSym(output_sym, SymbolicUtils.Term{S}(op_sym, input_syms)) +end + +function to_symbolics(d::SummationDecapode) + symvar_lookup = symbolics_lookup(d) + map(e -> to_symbolics(d, symvar_lookup, e.index, e.name), topological_sort_edges(d)) +end + +function symbolic_rewriting(d::SummationDecapode, rewriter=identity) + d′ = infer_types!(deepcopy(d)) + eqns = merge_equations(d′) + to_acset(d′, map(rewriter, eqns)) +end + +# XXX SymbolicUtils.substitute swaps the order of multiplication. +# e.g. ∂ₜ(G) == κ*u becomes ∂ₜ(G) == u*κ +function merge_equations(d::SummationDecapode) + eqn_lookup, terminal_eqns = Dict(), SymEqSym[] + + foreach(to_symbolics(d)) do x + sub = SymbolicUtils.substitute(x.rhs, eqn_lookup) + push!(eqn_lookup, (x.lhs => sub)) + if x.lhs.name in infer_terminal_names(d) + push!(terminal_eqns, SymEqSym(x.lhs, sub)) + end + end + + map(terminal_eqns) do eqn + SymbolicUtils.Term{Number}(EQUALITY, [eqn.lhs, eqn.rhs]) + end +end + +function to_acset(d::SummationDecapode, sym_exprs) + literals = incident(d, :Literal, :type) + + outer_types = map([infer_states(d)..., infer_terminals(d)..., literals...]) do i + :($(d[i, :name])::$(d[i, :type])) + end + + #TODO: This step is breaking up summations + final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs) + reify!(exprs) = foreach(exprs) do x + if typeof(x) == Expr && x.head == :call + x.args[1] = nameof(x.args[1]) + reify!(x.args[2:end]) + end + end + reify!(final_exprs) + + deca_block = quote end + deca_block.args = [outer_types..., final_exprs...] + + ∘(infer_types!, SummationDecapode, parse_decapode)(deca_block) +end diff --git a/src/deca/ThDEC.jl b/src/deca/ThDEC.jl index 383274e..4fa4695 100644 --- a/src/deca/ThDEC.jl +++ b/src/deca/ThDEC.jl @@ -24,8 +24,16 @@ export DECQuantity # this ensures symtype doesn't recurse endlessly SymbolicUtils.symtype(::Type{S}) where S<:DECQuantity = S -struct Scalar <: DECQuantity end -export Scalar +abstract type AbstractScalar <: DECQuantity end + +struct InferredType <: DECQuantity end +export InferredType + +struct Scalar <: AbstractScalar end +struct Parameter <: AbstractScalar end +struct Const <: AbstractScalar end +struct Literal <: AbstractScalar end +export Scalar, Parameter, Const, Literal struct FormParams dim::Int @@ -85,6 +93,18 @@ Base.nameof(u::Type{<:DualForm}) = Symbol("DualForm"*"$(dim(u))") # ACTIVE PATTERNS +@active PatInferredType(T) begin + if T <: InferredType + Some(InferredType) + end +end + +@active PatInferredTypes(T) begin + if any(S->S<:InferredType, T) + Some(InferredType) + end +end + @active PatForm(T) begin if T <: Form Some(T) @@ -107,7 +127,7 @@ end export PatFormDim @active PatScalar(T) begin - if T <: Scalar + if T <: AbstractScalar Some(T) end end @@ -153,6 +173,7 @@ export isDualForm, isForm0, isForm1, isForm2 @operator d(S)::DECQuantity begin @match S begin + PatInferredType(_) => InferredType PatFormParams([i,d,s,n]) => Form{i+1,d,s,n} _ => throw(ExteriorDerivativeError(S)) end @@ -162,6 +183,7 @@ end @operator ★(S)::DECQuantity begin @match S begin + PatInferredType(_) => InferredType PatFormParams([i,d,s,n]) => Form{n-i,d,s,n} _ => throw(HodgeStarError(S)) end @@ -171,17 +193,23 @@ end @operator Δ(S)::DECQuantity begin @match S begin + PatInferredType(_) => InferredType PatForm(_) => promote_symtype(★ ∘ d ∘ ★ ∘ d, S) _ => throw(LaplacianError(S)) end @rule Δ(~x::isForm0) => ★(d(★(d(~x)))) @rule Δ(~x::isForm1) => ★(d(★(d(~x)))) + d(★(d(★(~x)))) + @rule Δ(~x::isForm2) => d(★(d(★(~x)))) end +@alias (Δ₀, Δ₁, Δ₂) => Δ + +# TODO: Determine what we need to do for .+ @operator +(S1, S2)::DECQuantity begin @match (S1, S2) begin + PatInferredTypes(_) => InferredType (PatScalar(_), PatScalar(_)) => Scalar - (PatScalar(_), PatFormParams([i,d,s,n])) || (PatFormParams([i,d,s,n]), PatScalar(_)) => S1 # commutativity + (PatScalar(_), PatFormParams([i,d,s,n])) || (PatFormParams([i,d,s,n]), PatScalar(_)) => Form{i, d, s, n} # commutativity (PatFormParams([i1,d1,s1,n1]), PatFormParams([i2,d2,s2,n2])) => begin if (i1 == i2) && (d1 == d2) && (s1 == s2) && (n1 == n2) Form{i1, d1, s1, n1} @@ -193,10 +221,13 @@ end end end -@operator -(S1, S2)::DECQuantity begin +(S1, S2) end +@operator -(S1, S2)::DECQuantity begin + promote_symtype(+, S1, S2) +end @operator *(S1, S2)::DECQuantity begin @match (S1, S2) begin + PatInferredTypes(_) => InferredType (PatScalar(_), PatScalar(_)) => Scalar (PatScalar(_), PatFormParams([i,d,s,n])) || (PatFormParams([i,d,s,n]), PatScalar(_)) => Form{i,d,s,n} _ => throw(BinaryOpError("multiply", S1, S2)) @@ -205,6 +236,7 @@ end @operator ∧(S1, S2)::DECQuantity begin @match (S1, S2) begin + PatInferredTypes(_) => InferredType (PatFormParams([i1,d1,s1,n1]), PatFormParams([i2,d2,s2,n2])) => begin (d1 == d2) && (s1 == s2) && (n1 == n2) || throw(WedgeOpError(S1, S2)) if i1 + i2 <= n1 @@ -219,9 +251,10 @@ end abstract type SortError <: Exception end -# struct WedgeDimError <: SortError end - -Base.nameof(s::Scalar) = :Constant +Base.nameof(s::Union{Literal,Type{Literal}}) = :Literal +Base.nameof(s::Union{Const, Type{Const}}) = :Constant +Base.nameof(s::Union{Parameter, Type{Parameter}}) = :Parameter +Base.nameof(s::Union{Scalar, Type{Scalar}}) = :Scalar function Base.nameof(f::Form; with_dim_parameter=false) dual = isdual(f) ? "Dual" : "" @@ -254,6 +287,8 @@ function Base.nameof(::typeof(∧), s1, s2) Symbol("∧$(as_sub(dim(s1)))$(as_sub(dim(s2)))") end +Base.nameof(::typeof(∂ₜ), s) = Symbol("∂ₜ($(nameof(s)))") + Base.nameof(::typeof(d), s) = Symbol("d$(as_sub(dim(s)))") Base.nameof(::typeof(Δ), s) = :Δ @@ -263,15 +298,20 @@ function Base.nameof(::typeof(★), s) Symbol("★$(as_sub(isdual(s) ? dim(space(s)) - dim(s) : dim(s)))$(inv)") end -function SymbolicUtils.symtype(::Type{<:Quantity}, qty::Symbol, space::Symbol) +# TODO: Check that form type is no larger than the ambient dimension +function SymbolicUtils.symtype(::Type{<:Quantity}, qty::Symbol, space::Symbol, dim::Int = 2) @match qty begin - :Scalar || :Constant => Scalar - :Form0 => PrimalForm{0, space, 1} - :Form1 => PrimalForm{1, space, 1} - :Form2 => PrimalForm{2, space, 1} - :DualForm0 => DualForm{0, space, 1} - :DualForm1 => DualForm{1, space, 1} - :DualForm2 => DualForm{2, space, 1} + :Scalar => Scalar + :Constant => Const + :Parameter => Parameter + :Literal => Literal + :Form0 => PrimalForm{0, space, dim} + :Form1 => PrimalForm{1, space, dim} + :Form2 => PrimalForm{2, space, dim} + :DualForm0 => DualForm{0, space, dim} + :DualForm1 => DualForm{1, space, dim} + :DualForm2 => DualForm{2, space, dim} + :Infer => InferredType _ => error("Received $qty") end end diff --git a/src/deca/deca_acset.jl b/src/deca/deca_acset.jl index 55520e3..e0f5580 100644 --- a/src/deca/deca_acset.jl +++ b/src/deca/deca_acset.jl @@ -5,7 +5,7 @@ using ..DiagrammaticEquations These are the default rules used to do type inference in the 1D exterior calculus. """ op1_inf_rules_1D = [ - # Rules for ∂ₜ + # Rules for ∂ₜ (src_type = :Form0, tgt_type = :Form0, op_names = [:∂ₜ,:dt]), (src_type = :Form1, tgt_type = :Form1, op_names = [:∂ₜ,:dt]), @@ -14,10 +14,10 @@ op1_inf_rules_1D = [ (src_type = :DualForm0, tgt_type = :DualForm1, op_names = [:d, :dual_d₀, :d̃₀]), # Rules for ⋆ - (src_type = :Form0, tgt_type = :DualForm1, op_names = [:⋆, :⋆₀, :star]), - (src_type = :Form1, tgt_type = :DualForm0, op_names = [:⋆, :⋆₁, :star]), - (src_type = :DualForm1, tgt_type = :Form0, op_names = [:⋆, :⋆₀⁻¹, :star_inv]), - (src_type = :DualForm0, tgt_type = :Form1, op_names = [:⋆, :⋆₁⁻¹, :star_inv]), + (src_type = :Form0, tgt_type = :DualForm1, op_names = [:★, :⋆, :⋆₀, :star]), + (src_type = :Form1, tgt_type = :DualForm0, op_names = [:★, :⋆, :⋆₁, :star]), + (src_type = :DualForm1, tgt_type = :Form0, op_names = [:★, :⋆, :⋆₀⁻¹, :star_inv]), + (src_type = :DualForm0, tgt_type = :Form1, op_names = [:★, :⋆, :⋆₁⁻¹, :star_inv]), # Rules for Δ (src_type = :Form0, tgt_type = :Form0, op_names = [:Δ, :Δ₀, :lapl]), @@ -29,13 +29,20 @@ op1_inf_rules_1D = [ # Rules for negation (src_type = :Form0, tgt_type = :Form0, op_names = [:neg, :(-)]), (src_type = :Form1, tgt_type = :Form1, op_names = [:neg, :(-)]), - + # Rules for the averaging operator (src_type = :Form0, tgt_type = :Form1, op_names = [:avg₀₁, :avg_01]), - + + # Rules for ♯. + (src_type = :Form1, tgt_type = :PVF, op_names = [:♯, :♯ᵖᵖ]), + (src_type = :DualForm1, tgt_type = :DVF, op_names = [:♯, :♯ᵈᵈ]), + + # Rules for ♭. + (src_type = :DVF, tgt_type = :Form1, op_names = [:♭, :♭ᵈᵖ]), + # Rules for magnitude/ norm - (src_type = :Form0, tgt_type = :Form0, op_names = [:mag, :norm]), - (src_type = :Form1, tgt_type = :Form1, op_names = [:mag, :norm])] + (src_type = :PVF, tgt_type = :Form0, op_names = [:mag, :norm]), + (src_type = :DVF, tgt_type = :DualForm0, op_names = [:mag, :norm])] op2_inf_rules_1D = [ # Rules for ∧₀₀, ∧₁₀, ∧₀₁ @@ -45,7 +52,7 @@ op2_inf_rules_1D = [ # Rules for L₀, L₁ (proj1_type = :Form1, proj2_type = :DualForm0, res_type = :DualForm0, op_names = [:L, :L₀]), - (proj1_type = :Form1, proj2_type = :DualForm1, res_type = :DualForm1, op_names = [:L, :L₁]), + (proj1_type = :Form1, proj2_type = :DualForm1, res_type = :DualForm1, op_names = [:L, :L₁]), # Rules for i₁ (proj1_type = :Form1, proj2_type = :DualForm1, res_type = :DualForm0, op_names = [:i, :i₁]), @@ -62,10 +69,10 @@ op2_inf_rules_1D = [ (proj1_type = :Form0, proj2_type = :Parameter, res_type = :Form0, op_names = [:/, :./, :*, :.*]), (proj1_type = :Form1, proj2_type = :Parameter, res_type = :Form1, op_names = [:/, :./, :*, :.*]), (proj1_type = :Form2, proj2_type = :Parameter, res_type = :Form2, op_names = [:/, :./, :*, :.*]),=# - + (proj1_type = :Form0, proj2_type = :Literal, res_type = :Form0, op_names = [:/, :./, :*, :.*, :^, :.^]), (proj1_type = :Form1, proj2_type = :Literal, res_type = :Form1, op_names = [:/, :./, :*, :.*, :^, :.^]), - + (proj1_type = :DualForm0, proj2_type = :Literal, res_type = :DualForm0, op_names = [:/, :./, :*, :.*, :^, :.^]), (proj1_type = :DualForm1, proj2_type = :Literal, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^]), @@ -79,11 +86,15 @@ op2_inf_rules_1D = [ (proj1_type = :Constant, proj2_type = :Form1, res_type = :Form1, op_names = [:/, :./, :*, :.*, :^, :.^]), (proj1_type = :Form0, proj2_type = :Constant, res_type = :Form0, op_names = [:/, :./, :*, :.*, :^, :.^]), (proj1_type = :Form1, proj2_type = :Constant, res_type = :Form1, op_names = [:/, :./, :*, :.*, :^, :.^]), - + (proj1_type = :Constant, proj2_type = :DualForm0, res_type = :DualForm0, op_names = [:/, :./, :*, :.*, :^, :.^]), (proj1_type = :Constant, proj2_type = :DualForm1, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^]), (proj1_type = :DualForm0, proj2_type = :Constant, res_type = :DualForm0, op_names = [:/, :./, :*, :.*, :^, :.^]), - (proj1_type = :DualForm1, proj2_type = :Constant, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^])] + (proj1_type = :DualForm1, proj2_type = :Constant, res_type = :DualForm1, op_names = [:/, :./, :*, :.*, :^, :.^]), + + # These rules contain infer: + (proj1_type = :Form0, proj2_type = :infer, res_type = :Form0, op_names = [:^]), + (proj1_type = :DualForm0, proj2_type = :infer, res_type = :DualForm0, op_names = [:^])] """ These are the default rules used to do type inference in the 2D exterior calculus. @@ -101,13 +112,13 @@ op1_inf_rules_2D = [ (src_type = :DualForm1, tgt_type = :DualForm2, op_names = [:d, :dual_d₁, :d̃₁]), # Rules for ⋆ - (src_type = :Form0, tgt_type = :DualForm2, op_names = [:⋆, :⋆₀, :star]), - (src_type = :Form1, tgt_type = :DualForm1, op_names = [:⋆, :⋆₁, :star]), - (src_type = :Form2, tgt_type = :DualForm0, op_names = [:⋆, :⋆₂, :star]), + (src_type = :Form0, tgt_type = :DualForm2, op_names = [:★, :⋆, :⋆₀, :star]), + (src_type = :Form1, tgt_type = :DualForm1, op_names = [:★, :⋆, :⋆₁, :star]), + (src_type = :Form2, tgt_type = :DualForm0, op_names = [:★, :⋆, :⋆₂, :star]), - (src_type = :DualForm2, tgt_type = :Form0, op_names = [:⋆, :⋆₀⁻¹, :star_inv]), - (src_type = :DualForm1, tgt_type = :Form1, op_names = [:⋆, :⋆₁⁻¹, :star_inv]), - (src_type = :DualForm0, tgt_type = :Form2, op_names = [:⋆, :⋆₂⁻¹, :star_inv]), + (src_type = :DualForm2, tgt_type = :Form0, op_names = [:★, :⋆, :⋆₀⁻¹, :star_inv]), + (src_type = :DualForm1, tgt_type = :Form1, op_names = [:★, :⋆, :⋆₁⁻¹, :star_inv]), + (src_type = :DualForm0, tgt_type = :Form2, op_names = [:★, :⋆, :⋆₂⁻¹, :star_inv]), # Rules for Δ (src_type = :Form0, tgt_type = :Form0, op_names = [:Δ, :Δ₀, :lapl]), @@ -133,14 +144,17 @@ op1_inf_rules_2D = [ (src_type = :DualForm1, tgt_type = :DualForm1, op_names = [:neg, :(-)]), (src_type = :DualForm2, tgt_type = :DualForm2, op_names = [:neg, :(-)]), + # Rules for ♯. + (src_type = :Form1, tgt_type = :PVF, op_names = [:♯, :♯ᵖᵖ]), + (src_type = :DualForm1, tgt_type = :DVF, op_names = [:♯, :♯ᵈᵈ]), + + # Rules for ♭. + (src_type = :DVF, tgt_type = :Form1, op_names = [:♭, :♭ᵈᵖ]), + # Rules for magnitude/ norm - (src_type = :Form0, tgt_type = :Form0, op_names = [:norm, :mag]), - (src_type = :Form1, tgt_type = :Form1, op_names = [:norm, :mag]), - (src_type = :Form2, tgt_type = :Form2, op_names = [:norm, :mag]), - (src_type = :DualForm0, tgt_type = :DualForm0, op_names = [:norm, :mag]), - (src_type = :DualForm1, tgt_type = :DualForm1, op_names = [:norm, :mag]), - (src_type = :DualForm2, tgt_type = :DualForm2, op_names = [:norm, :mag])] - + (src_type = :PVF, tgt_type = :Form0, op_names = [:norm, :mag]), + (src_type = :DVF, tgt_type = :DualForm0, op_names = [:norm, :mag])] + op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ # Rules for ∧₁₁, ∧₂₀, ∧₀₂ (proj1_type = :Form1, proj2_type = :Form1, res_type = :Form2, op_names = [:∧, :∧₁₁, :wedge]), @@ -148,7 +162,7 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ (proj1_type = :Form0, proj2_type = :Form2, res_type = :Form2, op_names = [:∧, :∧₀₂, :wedge]), # Rules for L₂ - (proj1_type = :Form1, proj2_type = :DualForm2, res_type = :DualForm2, op_names = [:L, :L₂]), + (proj1_type = :Form1, proj2_type = :DualForm2, res_type = :DualForm2, op_names = [:L, :L₂]), # Rules for i₁ (proj1_type = :Form1, proj2_type = :DualForm2, res_type = :DualForm1, op_names = [:i, :i₂]), @@ -159,7 +173,7 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ # Rules for ι (proj1_type = :DualForm1, proj2_type = :DualForm1, res_type = :DualForm0, op_names = [:ι₁₁]), (proj1_type = :DualForm1, proj2_type = :DualForm2, res_type = :DualForm1, op_names = [:ι₁₂]), - + # Rules for subtraction (proj1_type = :Form0, proj2_type = :Form0, res_type = :Form0, op_names = [:-, :.-]), (proj1_type = :Form1, proj2_type = :Form1, res_type = :Form1, op_names = [:-, :.-]), @@ -198,7 +212,7 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ # Rules for Δ (src_type = :Form0, tgt_type = :Form0, resolved_name = :Δ₀, op = :Δ), (src_type = :Form1, tgt_type = :Form1, resolved_name = :Δ₁, op = :Δ)] - + # We merge 1D and 2D rules since it seems op2 rules are metric-free. If # this assumption is false, this needs to change. op2_res_rules_1D = [ @@ -214,8 +228,8 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ (proj1_type = :Form1, proj2_type = :DualForm1, res_type = :DualForm1, resolved_name = :L₁, op = :L), # Rules for i. (proj1_type = :Form1, proj2_type = :DualForm1, res_type = :DualForm0, resolved_name = :i₁, op = :i)] - - + + """ These are the default rules used to do function resolution in the 2D exterior calculus. """ @@ -243,6 +257,11 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ (src_type = :Form1, tgt_type = :Form0, resolved_name = :δ₁, op = :δ), (src_type = :Form2, tgt_type = :Form1, resolved_name = :δ₂, op = :codif), (src_type = :Form1, tgt_type = :Form0, resolved_name = :δ₁, op = :codif), + # Rules for ♯. + (src_type = :Form1, tgt_type = :PVF, resolved_name = :♯ᵖᵖ, op = :♯), + (src_type = :DualForm1, tgt_type = :DVF, resolved_name = :♯ᵈᵈ, op = :♯), + # Rules for ♭. + (src_type = :DVF, tgt_type = :Form1, resolved_name = :♭ᵈᵖ, op = :♭), # Rules for ∇². # TODO: Call this :nabla2 in ASCII? (src_type = :Form0, tgt_type = :Form0, resolved_name = :∇²₀, op = :∇²), @@ -255,7 +274,7 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ (src_type = :Form0, tgt_type = :Form0, resolved_name = :Δ₀, op = :lapl), (src_type = :Form1, tgt_type = :Form1, resolved_name = :Δ₁, op = :lapl), (src_type = :Form1, tgt_type = :Form1, resolved_name = :Δ₂, op = :lapl)] - + # We merge 1D and 2D rules directly here since it seems op2 rules # are metric-free. If this assumption is false, this needs to change. op2_res_rules_2D = vcat(op2_res_rules_1D, [ @@ -271,7 +290,7 @@ op2_inf_rules_2D = vcat(op2_inf_rules_1D, [ # (proj1_type = :Form1, proj2_type = :DualForm2, res_type = :DualForm2, resolved_name = :L₂ᵈ, op = :L), # Rules for i. (proj1_type = :Form1, proj2_type = :DualForm2, res_type = :DualForm1, resolved_name = :i₂, op = :i)]) - + # TODO: When SummationDecapodes are annotated with the degree of their space, # use dispatch to choose the correct set of rules. infer_types!(d::SummationDecapode) = @@ -339,4 +358,3 @@ Resolve function overloads based on types of src and tgt. """ resolve_overloads!(d::SummationDecapode) = resolve_overloads!(d, op1_res_rules_2D, op2_res_rules_2D) - diff --git a/src/graph_traversal.jl b/src/graph_traversal.jl new file mode 100644 index 0000000..e41048e --- /dev/null +++ b/src/graph_traversal.jl @@ -0,0 +1,73 @@ +using DiagrammaticEquations +using ACSets + +export TraversalNode, topological_sort_edges, n_ops, retrieve_name, start_nodes, edge_inputs, edge_output, edge_function + +struct TraversalNode{T} + index::Int + name::T +end + +edge_inputs(d::SummationDecapode, idx::Int, ::Val{:Op1}) = + [d[idx,:src]] +edge_inputs(d::SummationDecapode, idx::Int, ::Val{:Op2}) = + [d[idx,:proj1], d[idx,:proj2]] +edge_inputs(d::SummationDecapode, idx::Int, ::Val{:Σ}) = + d[incident(d, idx, :summation), :summand] + +edge_output(d::SummationDecapode, idx::Int, ::Val{:Op1}) = + d[idx,:tgt] +edge_output(d::SummationDecapode, idx::Int, ::Val{:Op2}) = + d[idx,:res] +edge_output(d::SummationDecapode, idx::Int, ::Val{:Σ}) = + d[idx, :sum] + +edge_function(d::SummationDecapode, idx::Int, ::Val{:Op1}) = + d[idx,:op1] +edge_function(d::SummationDecapode, idx::Int, ::Val{:Op2}) = + d[idx,:op2] +edge_function(d::SummationDecapode, idx::Int, ::Val{:Σ}) = + :+ + +#XXX: This topological sort is O(n^2). +function topological_sort_edges(d::SummationDecapode) + visited_Var = falses(nparts(d, :Var)) + visited_Var[start_nodes(d)] .= true + visited = Dict(:Op1 => falses(nparts(d, :Op1)), + :Op2 => falses(nparts(d, :Op2)), :Σ => falses(nparts(d, :Σ))) + + op_order = TraversalNode{Symbol}[] + + function visit(op, op_type) + if !visited[op_type][op] && all(visited_Var[edge_inputs(d,op,Val(op_type))]) + visited[op_type][op] = true + visited_Var[edge_output(d,op,Val(op_type))] = true + push!(op_order, TraversalNode(op, op_type)) + end + end + + for _ in 1:n_ops(d) + visit.(parts(d,:Op1), :Op1) + visit.(parts(d,:Op2), :Op2) + visit.(parts(d,:Σ), :Σ) + end + + @assert length(op_order) == n_ops(d) + op_order +end + +n_ops(d::SummationDecapode) = + nparts(d, :Op1) + nparts(d, :Op2) + nparts(d, :Σ) + +start_nodes(d::SummationDecapode) = + vcat(infer_states(d), incident(d, :Literal, :type)) + +function retrieve_name(d::SummationDecapode, tsr::TraversalNode) + @match tsr.name begin + :Op1 => d[tsr.index, :op1] + :Op2 => d[tsr.index, :op2] + :Σ => :+ + _ => error("$(tsr.name) is a table without names") + end +end + diff --git a/src/symbolictheoryutils.jl b/src/symbolictheoryutils.jl index 0c03a0a..7746559 100644 --- a/src/symbolictheoryutils.jl +++ b/src/symbolictheoryutils.jl @@ -43,7 +43,7 @@ Creates an operator `foo` with arguments which are types in a given Theory. This (@rule expr1) ... (@rule exprN) -end +end ``` builds ``` @@ -73,7 +73,7 @@ end """ macro operator(head, body) - # parse body + # parse head ph = @λ begin Expr(:call, foo, Expr(:(::), vars..., theory)) => (foo, vars, theory) Expr(:(::), Expr(:call, foo, vars...), theory) => (foo, vars, theory) @@ -81,7 +81,7 @@ macro operator(head, body) end (f, types, Theory) = ph(head) - # Passing types to functions requires that we type the signature with ::Type{T}. + # Passing types to functions requires that we type the signature with ::Type{T}. # This means that the user would have to write `my_op(::Type{T1}, ::Type{T2}, ...)` # As a convenience to the user, we allow them to specify the signature using just the types themselves: # `my_op(T1, T2, ...)` @@ -89,15 +89,15 @@ macro operator(head, body) sort_constraints = [:($S<:$Theory) for S in types] arity = length(sort_types) - # Parse the body for @rule calls. + # Parse the body for @rule calls. block, rulecalls = @match Base.remove_linenums!(body) begin Expr(:block, block, rules...) => (block, rules) s => nothing end - + # initialize the result result = quote end - + # construct the function on basic symbolics argnames = [gensym(:x) for _ in 1:arity] argclaus = [:($a::Symbolic) for a in argnames] @@ -107,15 +107,21 @@ macro operator(head, body) s = promote_symtype($f, $(argnames...)) SymbolicUtils.Term{s}($f, Any[$(argnames...)]) end + export $f + + # Base.show(io::IO, ::typeof($f)) = print(io, $f) end) - # if there are rewriting rules, add a method which accepts the function symbol and its arity (to prevent shadowing on operators like `-`) + # if there are rewriting rules, add a method which accepts the function symbol and its arity + # (to prevent shadowing on operators like `-`) if !isempty(rulecalls) push!(result.args, quote function rules(::typeof($f), ::Val{$arity}) [($(rulecalls...))] end + + rules(::typeof($f)) = rules($f, Val{1}) end) end @@ -150,18 +156,18 @@ macro alias(body) result = quote end foreach(aliases) do alias push!(result.args, - esc(quote + quote function $alias(s...) - $rep(s...) + $rep(s...) end export $alias + Base.nameof(::typeof($alias), s) = Symbol("$alias") - end)) + end) end - result + return esc(result) end export @alias alias(x) = error("$x has no aliases") export alias - diff --git a/test/Project.toml b/test/Project.toml index 97b4819..4f3a02a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,4 +6,5 @@ CombinatorialSpaces = "b1c52339-7909-45ad-8b6a-6e388f7c67f2" DiagrammaticEquations = "6f00c28b-6bed-4403-80fa-30e0dc12f317" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" +SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/acset2symbolic.jl b/test/acset2symbolic.jl new file mode 100644 index 0000000..0212a7e --- /dev/null +++ b/test/acset2symbolic.jl @@ -0,0 +1,286 @@ +using Test +using DiagrammaticEquations +using SymbolicUtils: Fixpoint, Prewalk, Postwalk, Chain, @rule +using Catlab + +(≃) = is_isomorphic + +@testset "Basic Roundtrip" begin + op1_only = @decapode begin + A::Form0 + B::Form1 + B == d(A) + end + + @test op1_only == symbolic_rewriting(op1_only) + + + op2_only = @decapode begin + A::Constant + B::Form0 + C::Form0 + + C == A * B + end + + @test op2_only == symbolic_rewriting(op2_only) + + sum_only = @decapode begin + A::Form2 + B::Form2 + C::Form2 + + C == A + B + end + + @test sum_only == symbolic_rewriting(sum_only) + + + multi_sum = @decapode begin + A::Form2 + B::Form2 + C::Form2 + D::Form2 + + D == (A + B) + C + end + infer_types!(multi_sum) + + # TODO: This is correct but the symbolics is splitting up the sum + @test multi_sum == symbolic_rewriting(multi_sum) + + + all_ops = @decapode begin + A::Constant + B::Form0 + C::Form1 + D::Form1 + E::Form1 + F::Form1 + + C == d(B) + D == A * C + F == D + E + end + + # This loses the intermediate names C and D + all_ops_res = symbolic_rewriting(all_ops) + all_ops_res[5, :name] = :D + all_ops_res[6, :name] = :C + @test all_ops ≃ all_ops_res + + with_deriv = @decapode begin + A::Form0 + Ȧ::Form0 + + ∂ₜ(A) == Ȧ + Ȧ == Δ(A) + end + + @test with_deriv == symbolic_rewriting(with_deriv) + + repeated_vars = @decapode begin + A::Form0 + B::Form1 + C::Form1 + + C == d(A) + C == Δ(B) + C == d(A) + end + + @test repeated_vars == symbolic_rewriting(repeated_vars) + + # TODO: This is broken because of the terminals issue in #77 + self_changing = @decapode begin + c_exp == ∂ₜ(c_exp) + end + + @test_broken repeated_vars == symbolic_rewriting(self_changing) + + literal = @decapode begin + A::Form0 + B::Form0 + + B == A * 2 + end + + @test literal == symbolic_rewriting(literal) + + parameter = @decapode begin + A::Form0 + P::Parameter + B::Form0 + + B == A * P + end + + @test parameter == symbolic_rewriting(parameter) + + constant = @decapode begin + A::Form0 + C::Constant + B::Form0 + + B == A * C + end + + @test constant == symbolic_rewriting(constant) +end + +function expr_rewriter(rules::Vector) + return Fixpoint(Prewalk(Fixpoint(Chain(rules)))) +end + +@testset "Basic Rewriting" begin + op1s = @decapode begin + A::Form0 + B::Form2 + C::Form2 + + C == B + B + d(d(A)) + end + + dd_0 = @rule d(d(~x)) => 0 + + op1s_rewritten = symbolic_rewriting(op1s, expr_rewriter([dd_0])) + + op1s_equiv = @decapode begin + A::Form0 + B::Form2 + C::Form2 + + C == 2 * B + end + + @test op1s_equiv == op1s_rewritten + + + op2s = @decapode begin + A::Form0 + B::Form0 + C::Form0 + D::Form0 + + + D == ∧(∧(A, B), C) + end + + wdg_assoc = @rule ∧(∧(~x, ~y), ~z) => ∧(~x, ∧(~y, ~z)) + + op2s_rewritten = symbolic_rewriting(op2s, expr_rewriter([wdg_assoc])) + + op2s_equiv = @decapode begin + A::Form0 + B::Form0 + C::Form0 + D::Form0 + + + D == ∧(A, ∧(B, C)) + end + infer_types!(op2s_equiv) + + @test op2s_equiv == op2s_rewritten + + + distr_d = @decapode begin + A::Form0 + B::Form1 + C::Form2 + + C == d(∧(A, B)) + end + infer_types!(distr_d) + + leibniz = @rule d(∧(~x, ~y)) => ∧(d(~x), ~y) + ∧(~x, d(~y)) + + distr_d_rewritten = symbolic_rewriting(distr_d, expr_rewriter([leibniz])) + + distr_d_res = @decapode begin + A::Form0 + B::Form1 + C::Form2 + + C == ∧(d(A), B) + ∧(A, d(B)) + end + infer_types!(distr_d_res) + + @test distr_d_res == distr_d_rewritten +end + +@testset "Heat" begin + Heat = @decapode begin + C::Form0 + G::Form0 + D::Constant + ∂ₜ(G) == D*Δ(C) + end + infer_types!(Heat) + + # Same up to re-naming + Heat[5, :name] = Symbol("•1") + @test Heat == symbolic_rewriting(Heat) + + Heat_open = @decapode begin + C::Form0 + G::Form0 + D::Constant + ∂ₜ(G) == D*★(d(★(d(C)))) + end + infer_types!(Heat_open) + + Heat_open[8, :name] = Symbol("•1") + Heat_open[5, :name] = Symbol("•2") + Heat_open[6, :name] = Symbol("•3") + Heat_open[7, :name] = Symbol("•4") + + @test Heat_open ≃ symbolic_rewriting(Heat, expr_rewriter(rules(Δ, Val(1)))) +end + +@testset "Phytodynamics" begin + Phytodynamics = @decapode begin + (n,w)::Form0 + m::Constant + ∂ₜ(n) == w + m*n + Δ(n) + end + infer_types!(Phytodynamics) + test_phy = symbolic_rewriting(Phytodynamics) +end + +@testset "Literals" begin + Heat = parse_decapode(quote + C::Form0 + G::Form0 + ∂ₜ(G) == 3*Δ(C) + end) + context = SymbolicContext(Heat) + SummationDecapode(context) + +end + +@testset "Parameters" begin + + Heat = @decapode begin + u::Form0 + G::Form0 + κ::Parameter + ∂ₜ(G) == Δ(u)*κ + end + infer_types!(Heat) + + Heat_open = @decapode begin + u::Form0 + G::Form0 + κ::Parameter + ∂ₜ(G) == ★(d(★(d(u))))*κ + end + infer_types!(Heat_open) + + Heat_open[7, :name] = Symbol("•4") + Heat_open[8, :name] = Symbol("•1") + + z = symbolic_rewriting(Heat, expr_rewriter(rules(Δ, Val(1)))) + @test Heat_open ≃ z + +end diff --git a/test/composition.jl b/test/composition.jl index f0c5c02..408cf37 100644 --- a/test/composition.jl +++ b/test/composition.jl @@ -2,15 +2,12 @@ using Test using DiagrammaticEquations using DiagrammaticEquations.Deca using Catlab -using Catlab.WiringDiagrams -using Catlab.Programs -using Catlab.CategoricalAlgebra # import DiagrammaticEquations: OpenSummationDecapode, Open, oapply, oapply_rename # @testset "Composition" begin # Simplest possible decapode relation. -Trivial = @decapode begin +Trivial = @decapode begin H::Form0{X} end diff --git a/test/decasymbolic.jl b/test/decasymbolic.jl index 0b5b7e9..e9bb0df 100644 --- a/test/decasymbolic.jl +++ b/test/decasymbolic.jl @@ -6,31 +6,49 @@ using SymbolicUtils using SymbolicUtils: symtype, promote_symtype, Symbolic using MLStyle +import DiagrammaticEquations: rules + # load up some variable variables and expressions -a, b = @syms a::Scalar b::Scalar -u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} -ω, η = @syms ω::PrimalForm{1, :X, 2} η::DualForm{2, :X, 2} -ϕ, ψ = @syms ϕ::PrimalVF{:X, 2} ψ::DualVF{:X, 2} +ϐ, = @syms ϐ::InferredType # \varbeta +ℓ, = @syms ℓ::Literal +c, t = @syms c::Const t::Parameter +a, b = @syms a::Scalar b::Scalar +u, du = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} +ω, η = @syms ω::PrimalForm{1, :X, 2} η::DualForm{2, :X, 2} +ϕ, ψ = @syms ϕ::PrimalVF{:X, 2} ψ::DualVF{:X, 2} # TODO would be nice to pass the space globally to avoid duplication - @testset "Term Construction" begin - + + @test symtype(ϐ) == InferredType + @test symtype(ℓ) == Literal + @test symtype(c) == Const + @test symtype(t) == Parameter @test symtype(a) == Scalar + @test symtype(u) == PrimalForm{0, :X, 2} @test symtype(ω) == PrimalForm{1, :X, 2} @test symtype(η) == DualForm{2, :X, 2} @test symtype(ϕ) == PrimalVF{:X, 2} @test symtype(ψ) == DualVF{:X, 2} + @test symtype(c + t) == Scalar + @test symtype(t + t) == Scalar + @test symtype(c + c) == Scalar + @test symtype(t + ϐ) == InferredType + @test symtype(u ∧ ω) == PrimalForm{1, :X, 2} @test symtype(ω ∧ ω) == PrimalForm{2, :X, 2} + @test symtype(u ∧ ϐ) == InferredType + # @test_throws ThDEC.SortError ThDEC.♯(u) @test symtype(Δ(u) + Δ(u)) == PrimalForm{0, :X, 2} # test unary operator conversion to decaexpr @test Term(1) == Lit(Symbol("1")) @test Term(a) == Var(:a) + @test Term(c) == Var(:c) + @test Term(t) == Var(:t) @test Term(∂ₜ(u)) == Tan(Var(:u)) @test Term(★(ω)) == App1(:★₁, Var(:ω)) @@ -44,95 +62,97 @@ u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} @test promote_symtype(+, a, b) == Scalar @test promote_symtype(∧, u, u) == PrimalForm{0, :X, 2} @test promote_symtype(∧, u, ω) == PrimalForm{1, :X, 2} + @test promote_symtype(-, a) == Scalar + @test promote_symtype(-, u, u) == PrimalForm{0, :X, 2} # test composition @test promote_symtype(d ∘ d, u) == PrimalForm{2, :X, 2} end -@testset "Operator definition" begin +# this is not nabla but "bizarro Δ" +del_expand_0, del_expand_1 = @operator ∇(S)::DECQuantity begin + @match S begin + PatScalar(_) => error("Argument of type $S is invalid") + PatForm(_) => promote_symtype(★ ∘ d ∘ ★ ∘ d, S) + end + @rule ∇(~x::isForm0) => ★(d(★(d(~x)))) + @rule ∇(~x::isForm1) => ★(d(★(d(~x)))) + d(★(d(★(~x)))) +end; + +# we will test is new operator +(r0, r1, r2) = @operator ρ(S)::DECQuantity begin + S <: Form ? Scalar : Form + @rule ρ(~x::isForm0) => 0 + @rule ρ(~x::isForm1) => 1 + @rule ρ(~x::isForm2) => 2 +end + +R, = @operator φ(S1, S2, S3)::DECQuantity begin + let T1=S1, T2=S2, T3=S3 + Scalar + end + @rule φ(2(~x::isForm0), 2(~y::isForm0), 2(~z::isForm0)) => 2*φ(~x,~y,~z) +end - # this is not nabla but "bizarro Δ" - del_expand_0, del_expand_1 = - @operator ∇(S)::DECQuantity begin - @match S begin - PatScalar(_) => error("Argument of type $S is invalid") - PatForm(_) => promote_symtype(★ ∘ d ∘ ★ ∘ d, S) - end - @rule ∇(~x::isForm0) => ★(d(★(d(~x)))) - @rule ∇(~x::isForm1) => ★(d(★(d(~x)))) + d(★(d(★(~x)))) - end; +@alias (φ′,) => φ +@testset "Operator definition" begin + + # ∇ @test_throws Exception ∇(b) @test symtype(∇(u)) == PrimalForm{0, :X ,2} @test promote_symtype(∇, u) == PrimalForm{0, :X, 2} - @test isequal(del_expand_0(∇(u)), ★(d(★(d(u))))) - - # we will test is new operator - (r0, r1, r2) = @operator ρ(S)::DECQuantity begin - if S <: Form - Scalar - else - Form - end - @rule ρ(~x::isForm0) => 0 - @rule ρ(~x::isForm1) => 1 - @rule ρ(~x::isForm2) => 2 - end - + + # ρ @test symtype(ρ(u)) == Scalar - R, = @operator φ(S1, S2, S3)::DECQuantity begin - let T1=S1, T2=S2, T3=S3 - Scalar - end - @rule φ(2(~x::isForm0), 2(~y::isForm0), 2(~z::isForm0)) => 2*φ(~x,~y,~z) - end - - # TODO we need to alias rewriting rules - @alias (φ′,) => φ - + # R @test isequal(R(φ(2u,2u,2u)), R(φ′(2u,2u,2u))) + # TODO we need to alias rewriting rules end @testset "Conversion" begin - context = Dict(:a => Scalar(),:b => Scalar() - ,:u => PrimalForm(0, X),:du => PrimalForm(1, X)) - js = [Judgement(:u, :Form0, :X) - ,Judgement(:∂ₜu, :Form0, :X) - ,Judgement(:Δu, :Form0, :X)] - eqs = [Eq(Var(:∂ₜu) - , AppCirc1([:⋆₂⁻¹, :d₁, :⋆₁, :d₀], Var(:u))) - , Eq(Tan(Var(:u)), Var(:∂ₜu))] - heat_eq = DecaExpr(js, eqs) - - symb_heat_eq = DecaSymbolic(lookup, heat_eq) - deca_expr = DecaExpr(symb_heat_eq) - -end + Exp = @decapode begin + u::Form0 + v::Form0 + ∂ₜ(v) == u + end + context = SymbolicContext(Term(Exp)) + Exp′ = SummationDecapode(DecaExpr(context)) -@testset "Moving between DecaExpr and DecaSymbolic" begin - - @test js == deca_expr.context - - # eqs in the left has AppCirc1[vector, term] - # deca_expr.equations on the right has nested App1 - # expected behavior is that nested AppCirc1 is preserved - @test_broken eqs == deca_expr.equations - # use expand_operators to get rid of parentheses - # infer_types and resolve_overloads - -end + # does roundtripping work + @test Exp == Exp′ -# convert both into ACSets then is_iso them -@testset "" begin + Heat = @decapode begin + u::Form0 + v::Form0 + κ::Constant + ∂ₜ(v) == Δ(u)*κ + end + infer_types!(Heat) + context = SymbolicContext(Term(Heat)) + Heat′ = SummationDecapode(DecaExpr(context)) - Σ = DiagrammaticEquations.SummationDecapode(deca_expr) - Δ = DiagrammaticEquations.SummationDecapode(symb_heat_eq) - @test Σ == Δ + @test Heat == Heat′ + TumorInvasion = @decapode begin + (C,fC)::Form0 + (Dif,Kd,Cmax)::Constant + ∂ₜ(C) == Dif * Δ(C) + fC - C * Kd + end + infer_types!(TumorInvasion) + context = SymbolicContext(Term(TumorInvasion)) + TumorInvasion′ = SummationDecapode(DecaExpr(context)) + + # new terms introduced because Symbolics converts subtraction expressions + # e.g., a - b => +(a, -b) + @test_broken TumorInvasion == TumorInvasion′ + # TI' has (11, Literal, -1) and (12, infer, mult_1) + # Op1 (2, 1, 4, 7) should be (2, 4, 1, 7) + # Sum is (1, 6), (2, 10) end diff --git a/test/graph_traversal.jl b/test/graph_traversal.jl new file mode 100644 index 0000000..b09fd24 --- /dev/null +++ b/test/graph_traversal.jl @@ -0,0 +1,64 @@ +using DiagrammaticEquations +using ACSets +using MLStyle +using Test + +function is_correct_length(d::SummationDecapode, result) + return length(result) == n_ops(d) +end + +@testset "Topological Sort on Edges" begin + no_edge = @decapode begin + F == S + end + @test isempty(topological_sort_edges(no_edge)) + + one_op1_deca = @decapode begin + F == f(S) + end + result = topological_sort_edges(one_op1_deca) + @test is_correct_length(one_op1_deca, result) + @test retrieve_name(one_op1_deca, only(result)) == :f + + multi_op1_deca = @decapode begin + F == c(b(a(S))) + end + result = topological_sort_edges(multi_op1_deca) + @test is_correct_length(multi_op1_deca, result) + for (edge, test_name) in zip(result, [:a, :b, :c]) + @test retrieve_name(multi_op1_deca, edge) == test_name + end + + cyclic = @decapode begin + B == g(A) + A == f(B) + end + @test_throws AssertionError topological_sort_edges(cyclic) + + just_op2 = @decapode begin + C == A * B + end + result = topological_sort_edges(just_op2) + @test is_correct_length(just_op2, result) + @test retrieve_name(just_op2, only(result)) == :* + + just_simple_sum = @decapode begin + C == A + B + end + result = topological_sort_edges(just_simple_sum) + @test is_correct_length(just_simple_sum, result) + @test retrieve_name(just_simple_sum, only(result)) == :+ + + just_multi_sum = @decapode begin + F == A + B + C + D + E + end + result = topological_sort_edges(just_multi_sum) + @test is_correct_length(just_multi_sum, result) + @test retrieve_name(just_multi_sum, only(result)) == :+ + + op_combo = @decapode begin + F == h(d(A) + f(g(B) * C) + D) + end + result = topological_sort_edges(op_combo) + @test is_correct_length(op_combo, result) +end diff --git a/test/language.jl b/test/language.jl index 3e3ab4e..25f817e 100644 --- a/test/language.jl +++ b/test/language.jl @@ -1,11 +1,5 @@ using Test using Catlab -using Catlab.Theories -using Catlab.CategoricalAlgebra -using Catlab.WiringDiagrams -using Catlab.WiringDiagrams.DirectedWiringDiagrams -using Catlab.Graphics -using Catlab.Programs using LinearAlgebra using MLStyle using Base.Iterators @@ -356,13 +350,14 @@ end @test issetequal([:V,:X,:k], infer_state_names(oscillator)) end -import DiagrammaticEquations: ALL_TYPES, FORM_TYPES, PRIMALFORM_TYPES, DUALFORM_TYPES, - NONFORM_TYPES, USER_TYPES, NUMBER_TYPES, INFER_TYPES, NONINFERABLE_TYPES +import DiagrammaticEquations: ALL_TYPES, FORM_TYPES, PRIMALFORM_TYPES, + DUALFORM_TYPES, VECTORFIELD_TYPES, NON_EC_TYPES, USER_TYPES, NUMBER_TYPES, + INFER_TYPES, NONINFERABLE_TYPES @testset "Type Retrival" begin type_groups = [ALL_TYPES, FORM_TYPES, PRIMALFORM_TYPES, DUALFORM_TYPES, - NONFORM_TYPES, USER_TYPES, NUMBER_TYPES, INFER_TYPES, NONINFERABLE_TYPES] + NON_EC_TYPES, USER_TYPES, NUMBER_TYPES, INFER_TYPES, NONINFERABLE_TYPES] # No repeated types @@ -374,12 +369,12 @@ import DiagrammaticEquations: ALL_TYPES, FORM_TYPES, PRIMALFORM_TYPES, DUALFORM_ no_overlaps(types_1, types_2) = isempty(intersect(types_1, types_2)) # Collections of these types should be the same - @test equal_types(ALL_TYPES, vcat(FORM_TYPES, NONFORM_TYPES)) - @test equal_types(FORM_TYPES, vcat(PRIMALFORM_TYPES, DUALFORM_TYPES)) - @test equal_types(NONINFERABLE_TYPES, vcat(USER_TYPES, NUMBER_TYPES)) + @test equal_types(ALL_TYPES, FORM_TYPES ∪ VECTORFIELD_TYPES ∪ NON_EC_TYPES) + @test equal_types(FORM_TYPES, PRIMALFORM_TYPES ∪ DUALFORM_TYPES) + @test equal_types(NONINFERABLE_TYPES, USER_TYPES ∪ NUMBER_TYPES) # Proper seperation of types - @test no_overlaps(FORM_TYPES, NONFORM_TYPES) + @test no_overlaps(FORM_TYPES ∪ VECTORFIELD_TYPES, NON_EC_TYPES) @test no_overlaps(PRIMALFORM_TYPES, DUALFORM_TYPES) @test no_overlaps(NONINFERABLE_TYPES, FORM_TYPES) @test INFER_TYPES == [:infer] @@ -400,9 +395,9 @@ end import DiagrammaticEquations: safe_modifytype @testset "Safe Type Modification" begin - all_types = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :Literal, :Constant, :Parameter, :infer] + all_types = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :Literal, :Constant, :Parameter, :PVF, :DVF, :infer] bad_sources = [:Literal, :Constant, :Parameter] - good_sources = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :infer] + good_sources = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :PVF, :DVF, :infer] for tgt in all_types for src in bad_sources @@ -425,13 +420,13 @@ import DiagrammaticEquations: safe_modifytype end end -import DiagrammaticEquations: filterfor_forms +import DiagrammaticEquations: filterfor_ec_types @testset "Form Type Retrieval" begin - all_types = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :Literal, :Constant, :Parameter, :infer] - @test filterfor_forms(all_types) == [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2] - @test isempty(filterfor_forms(Symbol[])) - @test isempty(filterfor_forms([:Literal, :Constant, :Parameter, :infer])) + all_types = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :Literal, :Constant, :Parameter, :PVF, :DVF, :infer] + @test filterfor_ec_types(all_types) == [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2, :PVF, :DVF] + @test isempty(filterfor_ec_types(Symbol[])) + @test isempty(filterfor_ec_types([:Literal, :Constant, :Parameter, :infer])) end @testset "Type Inference" begin @@ -831,6 +826,38 @@ end @test_throws "Type mismatch in summation" infer_types!(d) end + # Test #25: Infer between flattened and sharpened vector fields. + let + d = @decapode begin + A::Form1 + B::DualForm1 + C::PVF + D::DVF + + A == ♭(E) + B == ♭(F) + C == ♯(G) + D == ♯(H) + + I::Form1 + J::DualForm1 + K::PVF + L::DVF + + M == ♯(I) + N == ♯(J) + O == ♭(K) + P == ♭(L) + end + infer_types!(d) + + # TODO: Update this as more sharps and flats are released. + names_types_expected = Set([(:A, :Form1), (:B, :DualForm1), (:C, :PVF), (:D, :DVF), + (:E, :DVF), (:F, :infer), (:G, :Form1), (:H, :DualForm1), + (:I, :Form1), (:J, :DualForm1), (:K, :PVF), (:L, :DVF), + (:M, :PVF), (:N, :DVF), (:O, :infer), (:P, :Form1)]) + @test test_nametype_equality(d, names_types_expected) + end end @testset "Overloading Resolution" begin @@ -1048,6 +1075,42 @@ end op2s_hx = HeatXfer[:op2] op2s_expected_hx = [:*, :/, :/, :L₀, :/, :L₁, :*, :/, :*, :i₁, :/, :*, :*, :L₀] @test op2s_hx == op2s_expected_hx + + # Infer types and resolve overloads for the Halfar equation. + let + d = @decapode begin + h::Form0 + Γ::Form1 + n::Constant + + ∂ₜ(h) == ∘(⋆, d, ⋆)(Γ * d(h) ∧ (mag(♯(d(h)))^(n-1)) ∧ (h^(n+2))) + end + d = expand_operators(d) + infer_types!(d) + resolve_overloads!(d) + @test d == @acset SummationDecapode{Any, Any, Symbol} begin + Var = 19 + TVar = 1 + Op1 = 8 + Op2 = 6 + Σ = 1 + Summand = 2 + src = [1, 1, 1, 13, 12, 6, 18, 19] + tgt = [4, 9, 13, 12, 11, 18, 19, 4] + proj1 = [2, 3, 11, 8, 1, 7] + proj2 = [9, 15, 14, 10, 5, 16] + res = [8, 14, 10, 7, 16, 6] + incl = [4] + summand = [3, 17] + summation = [1, 1] + sum = [5] + op1 = [:∂ₜ, :d₀, :d₀, :♯ᵖᵖ, :mag, :⋆₁, :dual_d₁, :⋆₀⁻¹] + op2 = [:*, :-, :^, :∧₁₀, :^, :∧₁₀] + type = [:Form0, :Form1, :Constant, :Form0, :infer, :Form1, :Form1, :Form1, :Form1, :Form0, :Form0, :PVF, :Form1, :infer, :Literal, :Form0, :Literal, :DualForm1, :DualForm2] + name = [:h, :Γ, :n, :ḣ, :sum_1, Symbol("•2"), Symbol("•3"), Symbol("•4"), Symbol("•5"), Symbol("•6"), Symbol("•7"), Symbol("•8"), Symbol("•9"), Symbol("•10"), Symbol("1"), Symbol("•11"), Symbol("2"), Symbol("•_6_1"), Symbol("•_6_2")] + end + end + end @testset "Compilation Transformation" begin diff --git a/test/runtests.jl b/test/runtests.jl index dd92531..c68bded 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -38,4 +38,13 @@ end include("openoperators.jl") end +@testset "Symbolic Rewriting" begin + include("graph_traversal.jl") + include("acset2symbolic.jl") +end + +@testset "ThDEC Symbolics" begin + include("decasymbolic.jl") +end + include("aqua.jl") From 70a420d930cd63622b3ea5245d07a60b7f15c633 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 3 Oct 2024 15:53:42 -0400 Subject: [PATCH 24/30] added tests for errors and consolidated errors with George --- src/deca/ThDEC.jl | 76 ++++++++++++-------------------------------- test/decasymbolic.jl | 40 +++++++++++++++++++++++ 2 files changed, 61 insertions(+), 55 deletions(-) diff --git a/src/deca/ThDEC.jl b/src/deca/ThDEC.jl index 4fa4695..1de6f5e 100644 --- a/src/deca/ThDEC.jl +++ b/src/deca/ThDEC.jl @@ -175,7 +175,7 @@ export isDualForm, isForm0, isForm1, isForm2 @match S begin PatInferredType(_) => InferredType PatFormParams([i,d,s,n]) => Form{i+1,d,s,n} - _ => throw(ExteriorDerivativeError(S)) + _ => throw(OperatorError("take the exterior derivative", S)) end end @@ -185,17 +185,18 @@ end @match S begin PatInferredType(_) => InferredType PatFormParams([i,d,s,n]) => Form{n-i,d,s,n} - _ => throw(HodgeStarError(S)) + _ => throw(OperatorError("take the hodge star", S)) end end +# TODO in orthodox Decapodes, these are type-specific. @alias (★₀, ★₁, ★₂, ★₀⁻¹, ★₁⁻¹, ★₂⁻¹) => ★ @operator Δ(S)::DECQuantity begin @match S begin PatInferredType(_) => InferredType PatForm(_) => promote_symtype(★ ∘ d ∘ ★ ∘ d, S) - _ => throw(LaplacianError(S)) + _ => throw(OperatorError("take the Laplacian", S)) end @rule Δ(~x::isForm0) => ★(d(★(d(~x)))) @rule Δ(~x::isForm1) => ★(d(★(d(~x)))) + d(★(d(★(~x)))) @@ -214,10 +215,10 @@ end if (i1 == i2) && (d1 == d2) && (s1 == s2) && (n1 == n2) Form{i1, d1, s1, n1} else - throw(AdditionDimensionalError(S1, S2)) + throw(OperatorError("sum", [S1, S2])) end end - _ => throw(BinaryOpError("add", S1, S2)) + _ => throw(OperatorError("add", [S1, S2])) end end @@ -230,7 +231,7 @@ end PatInferredTypes(_) => InferredType (PatScalar(_), PatScalar(_)) => Scalar (PatScalar(_), PatFormParams([i,d,s,n])) || (PatFormParams([i,d,s,n]), PatScalar(_)) => Form{i,d,s,n} - _ => throw(BinaryOpError("multiply", S1, S2)) + _ => throw(OperatorError("multiply", [S1, S2])) end end @@ -238,14 +239,14 @@ end @match (S1, S2) begin PatInferredTypes(_) => InferredType (PatFormParams([i1,d1,s1,n1]), PatFormParams([i2,d2,s2,n2])) => begin - (d1 == d2) && (s1 == s2) && (n1 == n2) || throw(WedgeOpError(S1, S2)) + (d1 == d2) && (s1 == s2) && (n1 == n2) || throw(OperatorError("take the wedge product", [S1, S2])) if i1 + i2 <= n1 Form{i1 + i2, d1, s1, n1} else - throw(WedgeDimError(S1, S2)) + throw(OperatorError("take the wedge product", [S1, S2], "The dimensions of the form are bounded by $n1")) end end - _ => throw(BinaryOpError("take the wedge product of", S1, S2)) + _ => throw(OperatorError("take the wedge product", [S1, S2])) end end @@ -316,54 +317,19 @@ function SymbolicUtils.symtype(::Type{<:Quantity}, qty::Symbol, space::Symbol, d end end -struct ExteriorDerivativeError <: SortError - sort::DataType -end - -Base.showerror(io::IO, e::ExteriorDerivativeError) = print(io, "Cannot apply the exterior derivative to $(e.sort)") - -struct HodgeStarError <: SortError - sort::DataType -end - -Base.showerror(io::IO, e::HodgeStarError) = print(io, "Cannot take the hodge star of $(e.sort)") - -struct LaplacianError <: SortError - sort::DataType -end - -Base.showerror(io::IO, e::LaplacianError) = print(io, "Cannot take the Laplacian of $(e.sort)") - -struct AdditionDimensionalError <: SortError - sort1::DataType - sort2::DataType -end - -Base.showerror(io::IO, e::AdditionDimensionalError) = print(io, """ - Can not add two forms of different dimensions/dualities/spaces: - $(e.sort1) and $(e.sort2) - """) - -struct BinaryOpError <: SortError +struct OperatorError <: SortError verb::String - sort1::DataType - sort2::DataType -end - -Base.showerror(io::IO, e::BinaryOpError) = print(io, "Cannot $(e.verb) $(e.sort1) and $(e.sort2)") - -struct WedgeOpError <: SortError - sort1::DataType - sort2::DataType -end - -Base.showerror(io::IO, e::WedgeOpError) = print(io, "Can only take a wedge product of two forms of the same duality on the same space. Received $(e.sort1) and $(e.sort2)") - -struct WedgeOpDimError <: SortError - sort1::DataType - sort2::DataType + sorts::Vector{DataType} + othermsg::String + function OperatorError(verb::String, sorts::Vector{DataType}, othermsg::String="") + new(verb, sorts, othermsg) + end + function OperatorError(verb::String, sort::DataType, othermsg::String="") + new(verb, [sort], othermsg) + end end +export OperatorError -Base.showerror(io::IO, e::WedgeOpDimError) = print(io, "Can only take a wedge product when the dimensions of the forms add to less than n, where n = $(e.sort.dim) is the dimension of the ambient space: tried to wedge product $(e.sort1) and $(e.sort2)") +Base.showerror(io::IO, e::OperatorError) = print(io, "Cannot take the $(e.verb) of $(join(e.sorts, " and ")). $(e.othermsg)") end diff --git a/test/decasymbolic.jl b/test/decasymbolic.jl index e9bb0df..638895d 100644 --- a/test/decasymbolic.jl +++ b/test/decasymbolic.jl @@ -15,9 +15,13 @@ c, t = @syms c::Const t::Parameter a, b = @syms a::Scalar b::Scalar u, du = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} ω, η = @syms ω::PrimalForm{1, :X, 2} η::DualForm{2, :X, 2} +h, = @syms h::PrimalForm{2, :X, 2} ϕ, ψ = @syms ϕ::PrimalVF{:X, 2} ψ::DualVF{:X, 2} # TODO would be nice to pass the space globally to avoid duplication +u2, = @syms u2::PrimalForm{0, :Y, 2} +u3, = @syms u3::PrimalForm{0, :X, 3} + @testset "Term Construction" begin @test symtype(ϐ) == InferredType @@ -114,6 +118,42 @@ end end +@testset "Errors" begin + + # addition + @test_throws OperatorError u + du # mismatched grade + @test_throws OperatorError h + η # primal and dual + @test_throws OperatorError u + u2 # differing spaces + @test_throws OperatorError u + u3 # differing codimension + + # subtraction + @test_throws OperatorError u - du # mismatched grade + @test_throws OperatorError h - η # primal and dual + @test_throws OperatorError u - u2 # differing spaces + @test_throws OperatorError u - u3 # differing spatial dimension + + # exterior derivative + @test_throws OperatorError d(a) + @test_throws OperatorError d(ϕ) + + # hodge star + @test_throws OperatorError ★(a) + @test_throws OperatorError ★(ϕ) + + # Laplacian + @test_throws OperatorError Δ(a) + @test_throws OperatorError Δ(ϕ) + + # multiplication + @test_throws OperatorError u * du + @test_throws OperatorError ϕ * u + + @test_throws OperatorError du ∧ h # checks if spaces exceed dimension + @test_throws OperatorError a ∧ a # cannot take wedge of scalars + @test_throws OperatorError u ∧ ϕ # cannot take wedge of vector fields + +end + @testset "Conversion" begin Exp = @decapode begin From 85572a4a564337c6437fbc98fc8c834f1c4c8346 Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Fri, 4 Oct 2024 14:23:24 -0400 Subject: [PATCH 25/30] Fixed hodge nameof and more tests --- src/SymbolicUtilsInterop.jl | 4 +--- src/deca/ThDEC.jl | 21 +++++++++++++++------ test/decasymbolic.jl | 20 ++++++++++++++++---- 3 files changed, 32 insertions(+), 13 deletions(-) diff --git a/src/SymbolicUtilsInterop.jl b/src/SymbolicUtilsInterop.jl index c95fb54..bee2291 100644 --- a/src/SymbolicUtilsInterop.jl +++ b/src/SymbolicUtilsInterop.jl @@ -18,9 +18,7 @@ struct SymbolicEquation{E} end export SymbolicEquation -Base.show(io::IO, e::SymbolicEquation) = begin - print(io, e.lhs); print(io, " == "); print(io, e.rhs) -end +Base.show(io::IO, e::SymbolicEquation) = print(io, "$(e.lhs) == $(e.rhs)") ## a struct carry the symbolic variables and their equations struct SymbolicContext diff --git a/src/deca/ThDEC.jl b/src/deca/ThDEC.jl index 1de6f5e..26f5caf 100644 --- a/src/deca/ThDEC.jl +++ b/src/deca/ThDEC.jl @@ -271,32 +271,41 @@ end show_duality(ω::Form) = isdual(ω) ? "dual" : "primal" function Base.show(io::IO, ω::Form) - print(io, isdual(ω) ? "DualForm($(dim(ω))) on $(space(ω))" : "PrimalForm($(dim(ω))) on $(space(ω))") + if isdual(ω) + print(io, "DualForm($(dim(ω))) on $(space(ω))") + else + print(io, "PrimalForm($(dim(ω))) on $(space(ω))") + end end -Base.nameof(::typeof(-), s1, s2) = Symbol("$(as_sub(dim(s1)))-$(as_sub(dim(s2)))") +sub_dim(s) = as_sub(dim(s)) + +Base.nameof(::typeof(-), s1, s2) = Symbol("$(sub_dim(s1))-$(sub_dim(s2))") const SUBSCRIPT_DIGIT_0 = '₀' as_sub(n::Int) = join(map(d -> SUBSCRIPT_DIGIT_0 + d, digits(n))) + function Base.nameof(::typeof(∧), s1::B1, s2::B2) where {S1,S2,B1<:BasicSymbolic{S1}, B2<:BasicSymbolic{S2}} - Symbol("∧$(as_sub(dim(symtype(s1))))$(as_sub(dim(symtype(s2))))") + Symbol("∧$(sub_dim(symtype(s1)))$(sub_dim(symtype(s2)))") end function Base.nameof(::typeof(∧), s1, s2) - Symbol("∧$(as_sub(dim(s1)))$(as_sub(dim(s2)))") + Symbol("∧$(sub_dim(s1))$(sub_dim(s2))") end Base.nameof(::typeof(∂ₜ), s) = Symbol("∂ₜ($(nameof(s)))") -Base.nameof(::typeof(d), s) = Symbol("d$(as_sub(dim(s)))") +Base.nameof(::typeof(d), s) = Symbol("d$(sub_dim(s))") +# Base.nameof(::typeof(Δ), s) = Symbol("Δ$(sub_dim(s))") Base.nameof(::typeof(Δ), s) = :Δ function Base.nameof(::typeof(★), s) inv = isdual(s) ? "⁻¹" : "" - Symbol("★$(as_sub(isdual(s) ? dim(space(s)) - dim(s) : dim(s)))$(inv)") + hdg_dim = isdual(s) ? spacedim(s) - dim(s) : dim(s) + Symbol("★$(as_sub(hdg_dim))$(inv)") end # TODO: Check that form type is no larger than the ambient dimension diff --git a/test/decasymbolic.jl b/test/decasymbolic.jl index 638895d..084378e 100644 --- a/test/decasymbolic.jl +++ b/test/decasymbolic.jl @@ -29,7 +29,7 @@ u3, = @syms u3::PrimalForm{0, :X, 3} @test symtype(c) == Const @test symtype(t) == Parameter @test symtype(a) == Scalar - + @test symtype(u) == PrimalForm{0, :X, 2} @test symtype(ω) == PrimalForm{1, :X, 2} @test symtype(η) == DualForm{2, :X, 2} @@ -54,13 +54,25 @@ u3, = @syms u3::PrimalForm{0, :X, 3} @test Term(c) == Var(:c) @test Term(t) == Var(:t) @test Term(∂ₜ(u)) == Tan(Var(:u)) + @test_broken Term(∂ₜ(u)) == Term(DerivOp(u)) + @test Term(★(ω)) == App1(:★₁, Var(:ω)) - + @test Term(★(η)) == App1(:★₀⁻¹, Var(:η)) + # test binary operator conversion to decaexpr @test Term(a + b) == Plus(Term[Var(:a), Var(:b)]) + # TODO: Currently parses as addition + @test_broken Term(a - b) == App2(:-, Var(:a), Var(:b)) @test Term(a * b) == Mult(Term[Var(:a), Var(:b)]) @test Term(ω ∧ du) == App2(:∧₁₁, Var(:ω), Var(:du)) - + + @test Term(ω + du + d(u)) == Plus(Term[App1(:d₀, Var(:u)), Var(:du), Var(:ω)]) + + let + @syms f(x, y, z) + @test_throws "was unable to convert" Term(f(a, b, u)) + end + # test promoting types @test promote_symtype(d, u) == PrimalForm{1, :X, 2} @test promote_symtype(+, a, b) == Scalar @@ -108,7 +120,7 @@ end @test symtype(∇(u)) == PrimalForm{0, :X ,2} @test promote_symtype(∇, u) == PrimalForm{0, :X, 2} @test isequal(del_expand_0(∇(u)), ★(d(★(d(u))))) - + # ρ @test symtype(ρ(u)) == Scalar From df41b18cda29cb4aabe1e5d0d98751dfdd5f1fec Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Fri, 4 Oct 2024 14:55:59 -0400 Subject: [PATCH 26/30] Many more tests These tests attempt to cover all of the basic functionality of ThDEC. It also points out some strange or desired behavior. --- src/deca/ThDEC.jl | 4 +- test/decasymbolic.jl | 121 ++++++++++++++++++++++++++++++++++--------- 2 files changed, 99 insertions(+), 26 deletions(-) diff --git a/src/deca/ThDEC.jl b/src/deca/ThDEC.jl index 26f5caf..a95a1aa 100644 --- a/src/deca/ThDEC.jl +++ b/src/deca/ThDEC.jl @@ -268,6 +268,7 @@ function Base.nameof(f::Form; with_dim_parameter=false) end # show methods +# TODO: Remove me? Not being used anywhere show_duality(ω::Form) = isdual(ω) ? "dual" : "primal" function Base.show(io::IO, ω::Form) @@ -286,7 +287,7 @@ const SUBSCRIPT_DIGIT_0 = '₀' as_sub(n::Int) = join(map(d -> SUBSCRIPT_DIGIT_0 + d, digits(n))) - +# TODO: Do we want both nameof's for wedges? This one belows expects different args function Base.nameof(::typeof(∧), s1::B1, s2::B2) where {S1,S2,B1<:BasicSymbolic{S1}, B2<:BasicSymbolic{S2}} Symbol("∧$(sub_dim(symtype(s1)))$(sub_dim(symtype(s2)))") end @@ -297,6 +298,7 @@ end Base.nameof(::typeof(∂ₜ), s) = Symbol("∂ₜ($(nameof(s)))") +#TODO: Add an option to output d for dual forms, typically d₀ -> dual_d₀ Base.nameof(::typeof(d), s) = Symbol("d$(sub_dim(s))") # Base.nameof(::typeof(Δ), s) = Symbol("Δ$(sub_dim(s))") diff --git a/test/decasymbolic.jl b/test/decasymbolic.jl index 084378e..11afe40 100644 --- a/test/decasymbolic.jl +++ b/test/decasymbolic.jl @@ -22,7 +22,7 @@ h, = @syms h::PrimalForm{2, :X, 2} u2, = @syms u2::PrimalForm{0, :Y, 2} u3, = @syms u3::PrimalForm{0, :X, 3} -@testset "Term Construction" begin +@testset "Symtypes" begin @test symtype(ϐ) == InferredType @test symtype(ℓ) == Literal @@ -47,32 +47,73 @@ u3, = @syms u3::PrimalForm{0, :X, 3} # @test_throws ThDEC.SortError ThDEC.♯(u) @test symtype(Δ(u) + Δ(u)) == PrimalForm{0, :X, 2} +end - # test unary operator conversion to decaexpr - @test Term(1) == Lit(Symbol("1")) - @test Term(a) == Var(:a) - @test Term(c) == Var(:c) - @test Term(t) == Var(:t) - @test Term(∂ₜ(u)) == Tan(Var(:u)) - @test_broken Term(∂ₜ(u)) == Term(DerivOp(u)) - - @test Term(★(ω)) == App1(:★₁, Var(:ω)) - @test Term(★(η)) == App1(:★₀⁻¹, Var(:η)) - - # test binary operator conversion to decaexpr - @test Term(a + b) == Plus(Term[Var(:a), Var(:b)]) - # TODO: Currently parses as addition - @test_broken Term(a - b) == App2(:-, Var(:a), Var(:b)) - @test Term(a * b) == Mult(Term[Var(:a), Var(:b)]) - @test Term(ω ∧ du) == App2(:∧₁₁, Var(:ω), Var(:du)) - - @test Term(ω + du + d(u)) == Plus(Term[App1(:d₀, Var(:u)), Var(:du), Var(:ω)]) - - let - @syms f(x, y, z) - @test_throws "was unable to convert" Term(f(a, b, u)) - end +@testset "Type Information" begin + + @test dim(symtype(η)) == 2 + + @test isdual(symtype(η)) + @test !isdual(symtype(u)) + + @test space(symtype(η)) == :X + @test spacedim(symtype(η)) == 2 + + @test isForm0(u) + @test !isForm0(du) + @test !isForm0(a) + + @test isForm1(ω) + @test !isForm1(u) + @test !isForm1(a) + + @test isForm2(η) + @test !isForm2(u) + @test !isForm2(a) + + @test isDualForm(η) + @test !isDualForm(u) + @test !isDualForm(a) +end + +@testset "Nameof" begin + @test nameof(symtype(c)) == :Constant + @test nameof(symtype(t)) == :Parameter + @test nameof(symtype(a)) == :Scalar + @test nameof(symtype(ℓ)) == :Literal + + @test nameof(symtype(u)) == :Form0 + @test nameof(symtype(ω)) == :Form1 + @test nameof(symtype(η)) == :DualForm2 + + # TODO: Do we want this style of typed subtraction? + @test nameof(-, symtype(u), symtype(u)) == Symbol("₀-₀") + # TODO: This breaks since this expects the types to have a `dim` function + @test_broken nameof(-, symtype(a), symtype(b)) == Symbol("-") + + @test nameof(∧, symtype(u), symtype(u)) == Symbol("∧₀₀") + @test nameof(∧, symtype(u), symtype(ω)) == Symbol("∧₀₁") + @test nameof(∧, symtype(ω), symtype(u)) == Symbol("∧₁₀") + + # TODO: Do we need a special designation for wedges with duals in them? + @test nameof(∧, symtype(ω), symtype(η)) == Symbol("∧₁₂") + # TODO: Why is this being named as such? + @test nameof(∂ₜ, symtype(u)) == Symbol("∂ₜ(Form0)") + @test nameof(∂ₜ, symtype(d(u))) == Symbol("∂ₜ(Form1)") + + @test nameof(d, symtype(u)) == Symbol("d₀") + @test_broken nameof(d, symtype(η)) == Symbol("dual_d₂") + + @test_broken nameof(Δ, symtype(u)) == Symbol("Δ₀") + @test_broken nameof(Δ, symtype(ω)) == Symbol("Δ₁") + + @test nameof(★, symtype(u)) == Symbol("★₀") + @test nameof(★, symtype(ω)) == Symbol("★₁") + @test nameof(★, symtype(η)) == Symbol("★₀⁻¹") +end + +@testset "Symtype Promotion" begin # test promoting types @test promote_symtype(d, u) == PrimalForm{1, :X, 2} @test promote_symtype(+, a, b) == Scalar @@ -83,9 +124,39 @@ u3, = @syms u3::PrimalForm{0, :X, 3} # test composition @test promote_symtype(d ∘ d, u) == PrimalForm{2, :X, 2} +end + +@testset "Term Construction" begin + + # test unary operator conversion to decaexpr + @test Term(1) == Lit(Symbol("1")) + @test Term(a) == Var(:a) + @test Term(c) == Var(:c) + @test Term(t) == Var(:t) + @test Term(∂ₜ(u)) == Tan(Var(:u)) + @test_broken Term(∂ₜ(u)) == Term(DerivOp(u)) + + @test Term(★(ω)) == App1(:★₁, Var(:ω)) + @test Term(★(η)) == App1(:★₀⁻¹, Var(:η)) + + # test binary operator conversion to decaexpr + @test Term(a + b) == Plus(Term[Var(:a), Var(:b)]) + + # TODO: Currently parses as addition + @test_broken Term(a - b) == App2(:-, Var(:a), Var(:b)) + @test Term(a * b) == Mult(Term[Var(:a), Var(:b)]) + @test Term(ω ∧ du) == App2(:∧₁₁, Var(:ω), Var(:du)) + + @test Term(ω + du + d(u)) == Plus(Term[App1(:d₀, Var(:u)), Var(:du), Var(:ω)]) + + let + @syms f(x, y, z) + @test_throws "was unable to convert" Term(f(a, b, u)) + end end + # this is not nabla but "bizarro Δ" del_expand_0, del_expand_1 = @operator ∇(S)::DECQuantity begin @match S begin From ccf0a797167581b0e70bdd4657278a88a20dd218 Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Sat, 5 Oct 2024 16:43:49 -0400 Subject: [PATCH 27/30] More tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Discovered that the `Base.nameof` functions require logic to support InferredTypes. If any of the args is an InferredType, then we should simply return the generic operators, e.g. d or Δ. --- src/deca/ThDEC.jl | 46 +++++++++++++-------------- test/decasymbolic.jl | 75 ++++++++++++++++++++++++++++++++++++++------ 2 files changed, 88 insertions(+), 33 deletions(-) diff --git a/src/deca/ThDEC.jl b/src/deca/ThDEC.jl index a95a1aa..8944569 100644 --- a/src/deca/ThDEC.jl +++ b/src/deca/ThDEC.jl @@ -88,8 +88,10 @@ export PrimalVF const DualVF{s,n} = VField{true,s,n} export DualVF -Base.nameof(u::Type{<:PrimalForm}) = Symbol("Form"*"$(dim(u))") -Base.nameof(u::Type{<:DualForm}) = Symbol("DualForm"*"$(dim(u))") +Base.nameof(u::Type{<:PrimalForm}) = Symbol("Form$(dim(u))") +Base.nameof(u::Type{<:DualForm}) = Symbol("DualForm$(dim(u))") + +Base.show(io::IO, ω::Type{<:Form}) = print(io, "$(Base.nameof(ω)) on $(space(ω))") # ACTIVE PATTERNS @@ -235,6 +237,8 @@ end end end +@alias (∧₀₀, ∧₀₁, ∧₁₀, ∧₁₁, ∧₀₂, ∧₂₀) => ∧ + @operator ∧(S1, S2)::DECQuantity begin @match (S1, S2) begin PatInferredTypes(_) => InferredType @@ -256,28 +260,21 @@ Base.nameof(s::Union{Literal,Type{Literal}}) = :Literal Base.nameof(s::Union{Const, Type{Const}}) = :Constant Base.nameof(s::Union{Parameter, Type{Parameter}}) = :Parameter Base.nameof(s::Union{Scalar, Type{Scalar}}) = :Scalar +Base.nameof(s::Union{InferredType, Type{InferredType}}) = :infer + +# TODO: Remove me? Not sure if we ever fall into this case +# function Base.nameof(f::Form; with_dim_parameter=false) +# dual = isdual(f) ? "Dual" : "" +# formname = Symbol("$(dual)Form$(dim(f))") +# if with_dim_parameter +# return Expr(:curly, formname, dim(space(f))) +# else +# return formname +# end +# end -function Base.nameof(f::Form; with_dim_parameter=false) - dual = isdual(f) ? "Dual" : "" - formname = Symbol("$(dual)Form$(dim(f))") - if with_dim_parameter - return Expr(:curly, formname, dim(space(f))) - else - return formname - end -end - -# show methods # TODO: Remove me? Not being used anywhere -show_duality(ω::Form) = isdual(ω) ? "dual" : "primal" - -function Base.show(io::IO, ω::Form) - if isdual(ω) - print(io, "DualForm($(dim(ω))) on $(space(ω))") - else - print(io, "PrimalForm($(dim(ω))) on $(space(ω))") - end -end +# show_duality(ω::Form) = isdual(ω) ? "dual" : "primal" sub_dim(s) = as_sub(dim(s)) @@ -301,6 +298,7 @@ Base.nameof(::typeof(∂ₜ), s) = Symbol("∂ₜ($(nameof(s)))") #TODO: Add an option to output d for dual forms, typically d₀ -> dual_d₀ Base.nameof(::typeof(d), s) = Symbol("d$(sub_dim(s))") +#TODO: Add subtypes for the Laplacian # Base.nameof(::typeof(Δ), s) = Symbol("Δ$(sub_dim(s))") Base.nameof(::typeof(Δ), s) = :Δ @@ -323,8 +321,8 @@ function SymbolicUtils.symtype(::Type{<:Quantity}, qty::Symbol, space::Symbol, d :DualForm0 => DualForm{0, space, dim} :DualForm1 => DualForm{1, space, dim} :DualForm2 => DualForm{2, space, dim} - :Infer => InferredType - _ => error("Received $qty") + :infer => InferredType + _ => error("Received $qty which is not a valid type for Decapodes") end end diff --git a/test/decasymbolic.jl b/test/decasymbolic.jl index 11afe40..0753065 100644 --- a/test/decasymbolic.jl +++ b/test/decasymbolic.jl @@ -47,6 +47,22 @@ u3, = @syms u3::PrimalForm{0, :X, 3} # @test_throws ThDEC.SortError ThDEC.♯(u) @test symtype(Δ(u) + Δ(u)) == PrimalForm{0, :X, 2} + + @test symtype(-(ϐ)) == InferredType + @test symtype(d(ϐ)) == InferredType + @test symtype(★(ϐ)) == InferredType + @test symtype(Δ(ϐ)) == InferredType + + @test symtype(ϐ + ϐ) == InferredType + @test symtype(ϐ + u) == InferredType + @test_throws OperatorError symtype(u + du + ϐ) + # The order of the addition counts when type checking, probably applies + # to other operations as well + @test_broken (try symtype(ϐ + u + du) catch e; e; end) isa Exception + + @test symtype(ϐ - u) == InferredType + @test symtype(ω * ϐ) == InferredType + @test symtype(ϐ ∧ ω) == InferredType end @testset "Type Information" begin @@ -139,6 +155,9 @@ end @test Term(★(ω)) == App1(:★₁, Var(:ω)) @test Term(★(η)) == App1(:★₀⁻¹, Var(:η)) + # The symbolics no longer captures ∘ + @test Term(∘(d, d)(u)) == App1(:d₁, App1(:d₀, Var(:u))) + # test binary operator conversion to decaexpr @test Term(a + b) == Plus(Term[Var(:a), Var(:b)]) @@ -239,16 +258,57 @@ end @testset "Conversion" begin - Exp = @decapode begin + roundtrip(d::SummationDecapode) = SummationDecapode(DecaExpr(SymbolicContext(Term(d)))) + + just_vars = @decapode begin + u::Form0 + v::Form0 + end + @test just_vars == roundtrip(just_vars) + + with_tan = @decapode begin u::Form0 v::Form0 ∂ₜ(v) == u end - context = SymbolicContext(Term(Exp)) - Exp′ = SummationDecapode(DecaExpr(context)) + @test with_tan == roundtrip(with_tan) - # does roundtripping work - @test Exp == Exp′ + with_op2 = @decapode begin + u::Form0 + v::Form1 + w::Form1 + + w == ∧₀₁(u, v) + end + @test with_op2 == roundtrip(with_op2) + + with_mult = @decapode begin + u::Form1 + v::Form1 + + v == u * 2 + end + @test with_mult == roundtrip(with_mult) + + with_circ = @decapode begin + u::Form0 + v::Form2 + v == ∘(d₁, d₀)(u) + end + with_circ_expanded = @decapode begin + u::Form0 + v::Form2 + v == d₁(d₀(u)) + end + @test with_circ_expanded == roundtrip(with_circ) + + with_infers = @decapode begin + v::Form1 + + w == ∧(v, u) + end + # Base.nameof doesn't yet support taking InferredTypes + @test_broken with_infers == roundtrip(with_infers) Heat = @decapode begin u::Form0 @@ -257,10 +317,7 @@ end ∂ₜ(v) == Δ(u)*κ end infer_types!(Heat) - context = SymbolicContext(Term(Heat)) - Heat′ = SummationDecapode(DecaExpr(context)) - - @test Heat == Heat′ + @test Heat == roundtrip(Heat) TumorInvasion = @decapode begin (C,fC)::Form0 From 467a93c75c88383c5f21b4efef85893b127ceb30 Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Wed, 9 Oct 2024 15:37:42 -0400 Subject: [PATCH 28/30] Remove unused functions --- src/deca/ThDEC.jl | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/src/deca/ThDEC.jl b/src/deca/ThDEC.jl index 8944569..3afdad4 100644 --- a/src/deca/ThDEC.jl +++ b/src/deca/ThDEC.jl @@ -262,20 +262,6 @@ Base.nameof(s::Union{Parameter, Type{Parameter}}) = :Parameter Base.nameof(s::Union{Scalar, Type{Scalar}}) = :Scalar Base.nameof(s::Union{InferredType, Type{InferredType}}) = :infer -# TODO: Remove me? Not sure if we ever fall into this case -# function Base.nameof(f::Form; with_dim_parameter=false) -# dual = isdual(f) ? "Dual" : "" -# formname = Symbol("$(dual)Form$(dim(f))") -# if with_dim_parameter -# return Expr(:curly, formname, dim(space(f))) -# else -# return formname -# end -# end - -# TODO: Remove me? Not being used anywhere -# show_duality(ω::Form) = isdual(ω) ? "dual" : "primal" - sub_dim(s) = as_sub(dim(s)) Base.nameof(::typeof(-), s1, s2) = Symbol("$(sub_dim(s1))-$(sub_dim(s2))") @@ -284,11 +270,6 @@ const SUBSCRIPT_DIGIT_0 = '₀' as_sub(n::Int) = join(map(d -> SUBSCRIPT_DIGIT_0 + d, digits(n))) -# TODO: Do we want both nameof's for wedges? This one belows expects different args -function Base.nameof(::typeof(∧), s1::B1, s2::B2) where {S1,S2,B1<:BasicSymbolic{S1}, B2<:BasicSymbolic{S2}} - Symbol("∧$(sub_dim(symtype(s1)))$(sub_dim(symtype(s2)))") -end - function Base.nameof(::typeof(∧), s1, s2) Symbol("∧$(sub_dim(s1))$(sub_dim(s2))") end From 7c5155d46c58634d6e1be284da7f68bda2900390 Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Wed, 9 Oct 2024 18:59:56 -0400 Subject: [PATCH 29/30] Fixed test in acset2symbolics Added a fix to the acset2symbolics code that helps fix the issue caused there by #77. A full fix to this issue can restore the code to just use infer_terminal_names --- src/acset2symbolic.jl | 4 +++- test/acset2symbolic.jl | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index 330fa76..a32988b 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -43,11 +43,13 @@ end # e.g. ∂ₜ(G) == κ*u becomes ∂ₜ(G) == u*κ function merge_equations(d::SummationDecapode) eqn_lookup, terminal_eqns = Dict(), SymEqSym[] + deriv_op_tgts = d[incident(d, DerivOp, :op1), [:tgt, :name]] # Patches over issue #77 + terminal_vars = Set{Symbol}(vcat(infer_terminal_names(d), deriv_op_tgts)) foreach(to_symbolics(d)) do x sub = SymbolicUtils.substitute(x.rhs, eqn_lookup) push!(eqn_lookup, (x.lhs => sub)) - if x.lhs.name in infer_terminal_names(d) + if x.lhs.name in terminal_vars push!(terminal_eqns, SymEqSym(x.lhs, sub)) end end diff --git a/test/acset2symbolic.jl b/test/acset2symbolic.jl index 0212a7e..e01ca9f 100644 --- a/test/acset2symbolic.jl +++ b/test/acset2symbolic.jl @@ -96,7 +96,7 @@ using Catlab c_exp == ∂ₜ(c_exp) end - @test_broken repeated_vars == symbolic_rewriting(self_changing) + @test self_changing == symbolic_rewriting(self_changing) literal = @decapode begin A::Form0 From 4f1f3273b0d59f60d27fde4f685ecd8866d32060 Mon Sep 17 00:00:00 2001 From: GeorgeR227 <78235421+GeorgeR227@users.noreply.github.com> Date: Thu, 10 Oct 2024 15:36:39 -0400 Subject: [PATCH 30/30] Added some improvements to naming We now better deal with being passed infer types and added support for dual exterior derivatives. --- src/deca/ThDEC.jl | 28 +++++++++++++++++----------- test/decasymbolic.jl | 22 ++++++++++------------ 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/src/deca/ThDEC.jl b/src/deca/ThDEC.jl index 3afdad4..3fc40ae 100644 --- a/src/deca/ThDEC.jl +++ b/src/deca/ThDEC.jl @@ -142,6 +142,11 @@ export PatScalar end export PatVFParams +isForm(x) = @match symtype(x) begin + PatFormParams([_,_,_,_]) => true + _ => false +end + isDualForm(x) = @match symtype(x) begin PatFormParams([_,d,_,_]) => d _ => false @@ -262,26 +267,27 @@ Base.nameof(s::Union{Parameter, Type{Parameter}}) = :Parameter Base.nameof(s::Union{Scalar, Type{Scalar}}) = :Scalar Base.nameof(s::Union{InferredType, Type{InferredType}}) = :infer -sub_dim(s) = as_sub(dim(s)) -Base.nameof(::typeof(-), s1, s2) = Symbol("$(sub_dim(s1))-$(sub_dim(s2))") +Base.nameof(::typeof(-), args...) = Symbol("-") const SUBSCRIPT_DIGIT_0 = '₀' as_sub(n::Int) = join(map(d -> SUBSCRIPT_DIGIT_0 + d, digits(n))) +sub_dim(args...) = all(isForm.(args)) ? join(as_sub.(dim.(args))) : "" -function Base.nameof(::typeof(∧), s1, s2) - Symbol("∧$(sub_dim(s1))$(sub_dim(s2))") -end +Base.nameof(::typeof(∧), s1, s2) = Symbol("∧$(sub_dim(s1, s2))") -Base.nameof(::typeof(∂ₜ), s) = Symbol("∂ₜ($(nameof(s)))") +Base.nameof(::typeof(∂ₜ), s) = Symbol("∂ₜ") -#TODO: Add an option to output d for dual forms, typically d₀ -> dual_d₀ -Base.nameof(::typeof(d), s) = Symbol("d$(sub_dim(s))") +function Base.nameof(::typeof(d), s) + dual = isdual(s) ? "dual_" : "" + Symbol("$(dual)d$(sub_dim(s))") +end -#TODO: Add subtypes for the Laplacian -# Base.nameof(::typeof(Δ), s) = Symbol("Δ$(sub_dim(s))") -Base.nameof(::typeof(Δ), s) = :Δ +#TODO: Add naming for dual +function Base.nameof(::typeof(Δ), s) + Symbol("Δ$(sub_dim(s))") +end function Base.nameof(::typeof(★), s) inv = isdual(s) ? "⁻¹" : "" diff --git a/test/decasymbolic.jl b/test/decasymbolic.jl index 0753065..436ceb1 100644 --- a/test/decasymbolic.jl +++ b/test/decasymbolic.jl @@ -102,10 +102,9 @@ end @test nameof(symtype(ω)) == :Form1 @test nameof(symtype(η)) == :DualForm2 - # TODO: Do we want this style of typed subtraction? - @test nameof(-, symtype(u), symtype(u)) == Symbol("₀-₀") - # TODO: This breaks since this expects the types to have a `dim` function - @test_broken nameof(-, symtype(a), symtype(b)) == Symbol("-") + @test nameof(-, symtype(u)) == Symbol("-") + @test nameof(-, symtype(u), symtype(u)) == Symbol("-") + @test nameof(-, symtype(a), symtype(b)) == Symbol("-") @test nameof(∧, symtype(u), symtype(u)) == Symbol("∧₀₀") @test nameof(∧, symtype(u), symtype(ω)) == Symbol("∧₀₁") @@ -114,15 +113,14 @@ end # TODO: Do we need a special designation for wedges with duals in them? @test nameof(∧, symtype(ω), symtype(η)) == Symbol("∧₁₂") - # TODO: Why is this being named as such? - @test nameof(∂ₜ, symtype(u)) == Symbol("∂ₜ(Form0)") - @test nameof(∂ₜ, symtype(d(u))) == Symbol("∂ₜ(Form1)") + @test nameof(∂ₜ, symtype(u)) == Symbol("∂ₜ") + @test nameof(∂ₜ, symtype(d(u))) == Symbol("∂ₜ") @test nameof(d, symtype(u)) == Symbol("d₀") - @test_broken nameof(d, symtype(η)) == Symbol("dual_d₂") + @test nameof(d, symtype(η)) == Symbol("dual_d₂") - @test_broken nameof(Δ, symtype(u)) == Symbol("Δ₀") - @test_broken nameof(Δ, symtype(ω)) == Symbol("Δ₁") + @test nameof(Δ, symtype(u)) == Symbol("Δ₀") + @test nameof(Δ, symtype(ω)) == Symbol("Δ₁") @test nameof(★, symtype(u)) == Symbol("★₀") @test nameof(★, symtype(ω)) == Symbol("★₁") @@ -308,13 +306,13 @@ end w == ∧(v, u) end # Base.nameof doesn't yet support taking InferredTypes - @test_broken with_infers == roundtrip(with_infers) + @test with_infers == roundtrip(with_infers) Heat = @decapode begin u::Form0 v::Form0 κ::Constant - ∂ₜ(v) == Δ(u)*κ + ∂ₜ(v) == Δ₀(u)*κ end infer_types!(Heat) @test Heat == roundtrip(Heat)