diff --git a/src/SymbolicUtilsInterop.jl b/src/SymbolicUtilsInterop.jl index 9d4f6e0..2e50048 100644 --- a/src/SymbolicUtilsInterop.jl +++ b/src/SymbolicUtilsInterop.jl @@ -1,7 +1,7 @@ module SymbolicUtilsInterop using ACSets -using ..DiagrammaticEquations: AbstractDecapode, Quantity +using ..DiagrammaticEquations: AbstractDecapode, Quantity, DerivOp using ..DiagrammaticEquations: recognize_types, fill_names!, make_sum_mult_unique! import ..DiagrammaticEquations: eval_eq!, SummationDecapode using ..decapodes @@ -51,7 +51,7 @@ function decapodes.Term(t::SymbolicUtils.BasicSymbolic) decapodes.Plus(termargs) elseif op == * decapodes.Mult(termargs) - elseif op == ∂ₜ + elseif op ∈ [DerivOp, ∂ₜ] decapodes.Tan(only(termargs)) elseif length(args) == 1 decapodes.App1(nameof(op, symtype.(args)...), termargs...) @@ -85,9 +85,9 @@ Example: SymbolicUtils.BasicSymbolic(context, Term(a)) ``` """ -function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,DataType}, t::decapodes.Term, __module__=@__MODULE__) +function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,DataType}, t::decapodes.Term) # user must import symbols into scope - ! = (f -> getfield(__module__, f)) + ! = (f -> getfield(@__MODULE__, f)) @match t begin Var(name) => SymbolicUtils.Sym{context[name]}(name) Lit(v) => Meta.parse(string(v)) @@ -98,17 +98,17 @@ function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,DataType}, t::decapode # see test/language.jl (f, x) -> (!(f))(x), fs; - init=BasicSymbolic(context, arg, __module__) + init=BasicSymbolic(context, arg) ) - App1(f, x) => (!(f))(BasicSymbolic(context, x, __module__)) - App2(f, x, y) => (!(f))(BasicSymbolic(context, x, __module__), BasicSymbolic(context, y, __module__)) - Plus(xs) => +(BasicSymbolic.(Ref(context), xs, Ref(__module__))...) - Mult(xs) => *(BasicSymbolic.(Ref(context), xs, Ref(__module__))...) - Tan(x) => ∂ₜ(BasicSymbolic(context, x, __module__)) + App1(f, x) => (!(f))(BasicSymbolic(context, x)) + App2(f, x, y) => (!(f))(BasicSymbolic(context, x), BasicSymbolic(context, y)) + Plus(xs) => +(BasicSymbolic.(Ref(context), xs)...) + Mult(xs) => *(BasicSymbolic.(Ref(context), xs)...) + Tan(x) => (!(DerivOp))(BasicSymbolic(context, x)) end end -function SymbolicContext(d::decapodes.DecaExpr, __module__=@__MODULE__) +function SymbolicContext(d::decapodes.DecaExpr) # associates each var to its sort... context = map(d.context) do j j.var => symtype(Deca.DECQuantity, j.dim, j.space) @@ -119,7 +119,7 @@ function SymbolicContext(d::decapodes.DecaExpr, __module__=@__MODULE__) end context = Dict{Symbol,DataType}(context) eqs = map(d.equations) do eq - SymbolicEquation{Symbolic}(BasicSymbolic.(Ref(context), [eq.lhs, eq.rhs], Ref(__module__))...) + SymbolicEquation{Symbolic}(BasicSymbolic.(Ref(context), [eq.lhs, eq.rhs])...) end SymbolicContext(vars, eqs) end diff --git a/src/acset2symbolic.jl b/src/acset2symbolic.jl index ee1898f..667fbdc 100644 --- a/src/acset2symbolic.jl +++ b/src/acset2symbolic.jl @@ -79,7 +79,7 @@ end function extract_symexprs(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}) sym_list = SymbolicEquation{Symbolic}[] for node in topological_sort_edges(d) - retrieve_name(d, node) != DerivOp || continue # This is not part of ThDEC + # retrieve_name(d, node) != DerivOp || continue # This is not part of ThDEC push!(sym_list, to_symbolics(d, symvar_lookup, node)) end sym_list diff --git a/src/deca/ThDEC.jl b/src/deca/ThDEC.jl index 91ba946..e71701b 100644 --- a/src/deca/ThDEC.jl +++ b/src/deca/ThDEC.jl @@ -26,11 +26,14 @@ SymbolicUtils.symtype(::Type{S}) where S<:DECQuantity = S abstract type AbstractScalar <: DECQuantity end +struct InferredType <: DECQuantity end +export InferredType + struct Scalar <: AbstractScalar end struct Parameter <: AbstractScalar end -struct ConstScalar <: AbstractScalar end +struct Const <: AbstractScalar end struct Literal <: AbstractScalar end -export Scalar, Parameter, ConstScalar, Literal +export Scalar, Parameter, Const, Literal struct FormParams dim::Int @@ -90,6 +93,18 @@ Base.nameof(u::Type{<:DualForm}) = Symbol("DualForm"*"$(dim(u))") # ACTIVE PATTERNS +@active PatInferredType(T) begin + if T <: InferredType + Some(InferredType) + end +end + +@active PatInferredTypes(T) begin + if any(S->S<:InferredType, T) + Some(InferredType) + end +end + @active PatForm(T) begin if T <: Form Some(T) @@ -158,6 +173,7 @@ export isDualForm, isForm0, isForm1, isForm2 @operator d(S)::DECQuantity begin @match S begin + PatInferredType(_) => InferredType PatFormParams([i,d,s,n]) => Form{i+1,d,s,n} _ => throw(ExteriorDerivativeError(S)) end @@ -167,6 +183,7 @@ end @operator ★(S)::DECQuantity begin @match S begin + PatInferredType(_) => InferredType PatFormParams([i,d,s,n]) => Form{n-i,d,s,n} _ => throw(HodgeStarError(S)) end @@ -176,6 +193,7 @@ end @operator Δ(S)::DECQuantity begin @match S begin + PatInferredType(_) => InferredType PatForm(_) => promote_symtype(★ ∘ d ∘ ★ ∘ d, S) _ => throw(LaplacianError(S)) end @@ -185,8 +203,11 @@ end @alias (Δ₀, Δ₁, Δ₂) => Δ +# Base.show(io::IO, + @operator +(S1, S2)::DECQuantity begin @match (S1, S2) begin + PatInferredTypes(_) => InferredType (PatScalar(_), PatScalar(_)) => Scalar (PatScalar(_), PatFormParams([i,d,s,n])) || (PatFormParams([i,d,s,n]), PatScalar(_)) => S1 # commutativity (PatFormParams([i1,d1,s1,n1]), PatFormParams([i2,d2,s2,n2])) => begin @@ -206,6 +227,7 @@ end @operator *(S1, S2)::DECQuantity begin @match (S1, S2) begin + PatInferredTypes(_) => InferredType (PatScalar(_), PatScalar(_)) => Scalar (PatScalar(_), PatFormParams([i,d,s,n])) || (PatFormParams([i,d,s,n]), PatScalar(_)) => Form{i,d,s,n} _ => throw(BinaryOpError("multiply", S1, S2)) @@ -214,6 +236,7 @@ end @operator ∧(S1, S2)::DECQuantity begin @match (S1, S2) begin + PatInferredTypes(_) => InferredType (PatFormParams([i1,d1,s1,n1]), PatFormParams([i2,d2,s2,n2])) => begin (d1 == d2) && (s1 == s2) && (n1 == n2) || throw(WedgeOpError(S1, S2)) if i1 + i2 <= n1 @@ -228,10 +251,10 @@ end abstract type SortError <: Exception end -Base.nameof(s::Literal) = :Literal -Base.nameof(s::ConstScalar) = :ConstScalar -Base.nameof(s::Parameter) = :Parameter -Base.nameof(s::Scalar) = :Scalar +Base.nameof(s::Union{Literal,Type{Literal}}) = :Literal +Base.nameof(s::Union{Const, Type{Const}}) = :Constant +Base.nameof(s::Union{Parameter, Type{Parameter}}) = :Parameter +Base.nameof(s::Union{Scalar, Type{Scalar}}) = :Scalar function Base.nameof(f::Form; with_dim_parameter=false) dual = isdual(f) ? "Dual" : "" @@ -264,6 +287,8 @@ function Base.nameof(::typeof(∧), s1, s2) Symbol("∧$(as_sub(dim(s1)))$(as_sub(dim(s2)))") end +Base.nameof(::typeof(∂ₜ), s) = Symbol("∂ₜ($(nameof(s)))") + Base.nameof(::typeof(d), s) = Symbol("d$(as_sub(dim(s)))") Base.nameof(::typeof(Δ), s) = :Δ @@ -276,7 +301,7 @@ end function SymbolicUtils.symtype(::Type{<:Quantity}, qty::Symbol, space::Symbol) @match qty begin :Scalar => Scalar - :Constant => ConstScalar + :Constant => Const :Parameter => Parameter :Literal => Literal :Form0 => PrimalForm{0, space, 1} @@ -285,6 +310,7 @@ function SymbolicUtils.symtype(::Type{<:Quantity}, qty::Symbol, space::Symbol) :DualForm0 => DualForm{0, space, 1} :DualForm1 => DualForm{1, space, 1} :DualForm2 => DualForm{2, space, 1} + :Infer => InferredType _ => error("Received $qty") end end diff --git a/src/symbolictheoryutils.jl b/src/symbolictheoryutils.jl index f1693f3..38f4fcc 100644 --- a/src/symbolictheoryutils.jl +++ b/src/symbolictheoryutils.jl @@ -108,6 +108,8 @@ macro operator(head, body) SymbolicUtils.Term{s}($f, Any[$(argnames...)]) end export $f + + Base.show(io::IO, ::typeof($f)) = print(io, $f) end) # if there are rewriting rules, add a method which accepts the function symbol and its arity (to prevent shadowing on operators like `-`) diff --git a/test/acset2symbolic.jl b/test/acset2symbolic.jl index 89de387..c928901 100644 --- a/test/acset2symbolic.jl +++ b/test/acset2symbolic.jl @@ -202,3 +202,17 @@ end @test Heat_open ≃ z end + +x=@decapode begin + u::Form0 + ∂ₜ(u) == u +end +symbolic_rewriting(x) +# if the `for op1 in parts(og_d, :Op1)...` block is removed, this is annihilated because x has no terminals + +x=@decapode begin + u::Form0 + v::Form0 + ∂ₜ(v) == u +end +symbolic_rewriting(x) # fine diff --git a/test/decasymbolic.jl b/test/decasymbolic.jl index b3ff661..722157c 100644 --- a/test/decasymbolic.jl +++ b/test/decasymbolic.jl @@ -7,18 +7,20 @@ using SymbolicUtils: symtype, promote_symtype, Symbolic using MLStyle # load up some variable variables and expressions -ℓ, = @syms ℓ::Literal -c, t = @syms c::ConstScalar t::Parameter -a, b = @syms a::Scalar b::Scalar -u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} -ω, η = @syms ω::PrimalForm{1, :X, 2} η::DualForm{2, :X, 2} -ϕ, ψ = @syms ϕ::PrimalVF{:X, 2} ψ::DualVF{:X, 2} +👻, = @syms 👻::InferredType +ℓ, = @syms ℓ::Literal +c, t = @syms c::Const t::Parameter +a, b = @syms a::Scalar b::Scalar +u, du = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} +ω, η = @syms ω::PrimalForm{1, :X, 2} η::DualForm{2, :X, 2} +ϕ, ψ = @syms ϕ::PrimalVF{:X, 2} ψ::DualVF{:X, 2} # TODO would be nice to pass the space globally to avoid duplication @testset "Term Construction" begin + @test symtype(👻) == InferredType @test symtype(ℓ) == Literal - @test symtype(c) == ConstScalar + @test symtype(c) == Const @test symtype(t) == Parameter @test symtype(a) == Scalar @@ -31,9 +33,12 @@ u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2} @test symtype(c + t) == Scalar @test symtype(t + t) == Scalar @test symtype(c + c) == Scalar + @test symtype(t + 👻) == InferredType @test symtype(u ∧ ω) == PrimalForm{1, :X, 2} @test symtype(ω ∧ ω) == PrimalForm{2, :X, 2} + @test symtype(u ∧ 👻) == InferredType + # @test_throws ThDEC.SortError ThDEC.♯(u) @test symtype(Δ(u) + Δ(u)) == PrimalForm{0, :X, 2} @@ -112,40 +117,39 @@ end @testset "Conversion" begin - context = Dict(:a => Scalar(),:b => Scalar() - ,:u => PrimalForm(0, X),:du => PrimalForm(1, X)) - js = [Judgement(:u, :Form0, :X) - ,Judgement(:∂ₜu, :Form0, :X) - ,Judgement(:Δu, :Form0, :X)] - eqs = [Eq(Var(:∂ₜu) - , AppCirc1([:⋆₂⁻¹, :d₁, :⋆₁, :d₀], Var(:u))) - , Eq(Tan(Var(:u)), Var(:∂ₜu))] - heat_eq = DecaExpr(js, eqs) + Exp = @decapode begin + u::Form0 + v::Form0 + ∂ₜ(v) == u + end + context = SymbolicContext(Term(Exp)) + Exp′ = SummationDecapode(DecaExpr(context)) - symb_heat_eq = DecaSymbolic(lookup, heat_eq) - deca_expr = DecaExpr(symb_heat_eq) + # does roundtripping work + @test Exp == Exp′ -end - -@testset "Moving between DecaExpr and DecaSymbolic" begin - - @test js == deca_expr.context - - # eqs in the left has AppCirc1[vector, term] - # deca_expr.equations on the right has nested App1 - # expected behavior is that nested AppCirc1 is preserved - @test_broken eqs == deca_expr.equations - # use expand_operators to get rid of parentheses - # infer_types and resolve_overloads - -end + Heat = @decapode begin + u::Form0 + v::Form0 + κ::Constant + ∂ₜ(v) == Δ(u)*κ + end + infer_types!(Heat) + context = SymbolicContext(Term(Heat)) + Heat′ = SummationDecapode(DecaExpr(context)) -# convert both into ACSets then is_iso them -@testset "" begin + @test Heat == Heat′ - Σ = DiagrammaticEquations.SummationDecapode(deca_expr) - Δ = DiagrammaticEquations.SummationDecapode(symb_heat_eq) - @test Σ == Δ + TumorInvasion = @decapode begin + (C,fC)::Form0 + (Dif,Kd,Cmax)::Constant + ∂ₜ(C) == Dif * Δ(C) + fC - Kd * C + end + infer_types!(TumorInvasion) + context = SymbolicContext(Term(TumorInvasion)) + TumorInvasion′ = SummationDecapode(DecaExpr(context)) + # new terms introduced + @test_broken TumorInvasion == TumorInvasion′ end diff --git a/test/runtests.jl b/test/runtests.jl index 4a64095..ec5047b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -38,6 +38,10 @@ end include("openoperators.jl") end +@testset "ThDEC Symbolics" begin + include("decasymbolic.jl") +end + @testset "Symbolic Rewriting" begin include("graph_traversal.jl") include("acset2symbolic.jl")