Skip to content

Commit

Permalink
adding promote_symtype and addressing some of the code review comment…
Browse files Browse the repository at this point in the history
…s given
  • Loading branch information
quffaro committed Aug 30, 2024
1 parent 77770e5 commit 1fa477b
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 70 deletions.
7 changes: 4 additions & 3 deletions src/SymbolicUtilsInterop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using ..Deca

using MLStyle
using SymbolicUtils
using SymbolicUtils: Symbolic, BasicSymbolic, FnType, Sym
using SymbolicUtils: Symbolic, BasicSymbolic, FnType, Sym, symtype

# name collision with decapodes.Equation
struct SymbolicEquation{E}
Expand Down Expand Up @@ -107,14 +107,15 @@ end

function SymbolicContext(d::decapodes.DecaExpr, __module__=@__MODULE__)
# associates each var to its sort...
@info d.context
context = map(d.context) do j
@info j.var
j.var => j.var
j.var => symtype(ThDEC, j.dim, j.space)
end
# ... which we then produce a vector of symbolic vars
vars = map(context) do (v, s)
SymbolicUtils.Sym{s}(v)
end
@info context
context = Dict{Symbol,Quantity}(context)
eqs = map(d.equations) do eq
SymbolicEquation{Symbolic}(BasicSymbolic.(Ref(context), [eq.lhs, eq.rhs], Ref(__module__))...)
Expand Down
2 changes: 1 addition & 1 deletion src/deca/Deca.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ include("deca_acset.jl")
include("deca_visualization.jl")
include("ThDEC.jl")

@reexport using .TheoryDEC
@reexport using .ThDEC

""" function recursive_delete_parents!(d::SummationDecapode, to_delete::Vector{Int64})
Expand Down
56 changes: 35 additions & 21 deletions src/deca/ThDEC.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module TheoryDEC
module ThDEC

using ..DiagrammaticEquations: @register, @alias, Quantity
using ..DiagrammaticEquations: @operator, @alias, Quantity

using MLStyle
using StructEquality
Expand All @@ -11,14 +11,14 @@ import Base: +, -, *
import Catlab: Δ,

# ##########################
# ThDEC
# DECQuantity
#
# Type necessary for symbolic utils
# ##########################

abstract type ThDEC <: Quantity end
abstract type DECQuantity <: Quantity end

struct Scalar <: ThDEC end
struct Scalar <: DECQuantity end
export Scalar

struct FormParams
Expand All @@ -28,18 +28,18 @@ struct FormParams
spacedim::Int
end

dim(fp::FormParams) = getproperty(fp, :dim)
duality(fp::FormParams) = getproperty(fp, :duality)
space(fp::FormParams) = getproperty(fp, :space)
spacedim(fp::FormParams) = getproperty(fp, :spacedim)
dim(fp::FormParams) = fp.dim
duality(fp::FormParams) = fp.duality
space(fp::FormParams) = fp.space
spacedim(fp::FormParams) = fp.spacedim

"""
i: dimension: 0,1,2, etc.
d: duality: true = dual, false = primal
s: name of the space (a symbol)
n: dimension of the space
"""
struct Form{i,d,s,n} <: ThDEC end
struct Form{i,d,s,n} <: DECQuantity end
export Form

# parameter accessors
Expand All @@ -53,15 +53,14 @@ export dim, isdual, space, spacedim
# convert form to form params
FormParams(::Type{<:Form{i,d,s,n}}) where {i,s,d,n} = FormParams(i,d,s,n)

struct VField{d,s,n} <: ThDEC end
struct VField{d,s,n} <: DECQuantity end
export VField

# parameter accessors
isdual(::Type{<:VField{d,s,n}}) where {d,s,n} = d
space(::Type{VField{d,s,n}}) where {d,s,n} = s
spacedim(::Type{VField{d,s,n}}) where {d,s,n} = n


# convenience functions
const PrimalForm{i,s,n} = Form{i,false,s,n}
export PrimalForm
Expand Down Expand Up @@ -112,11 +111,11 @@ end
# for every unary operator in our theory, take a BasicSymbolic type, convert its type parameter to a Sort in our theory, and return a term
unops = [:♯, :♭]

@register -(S)::ThDEC begin S end
@operator -(S)::DECQuantity begin S end

@register ∂ₜ(S)::ThDEC begin S end
@operator ∂ₜ(S)::DECQuantity begin S end

@register d(S)::ThDEC begin
@operator d(S)::DECQuantity begin
@match S begin
ActFormParams([i,d,s,n]) => Form{i+1,d,s,n}
_ => throw(SortError("Cannot apply the exterior derivative to $S"))
Expand All @@ -125,7 +124,7 @@ end

@alias (d₀, d₁) => d

@register (S)::ThDEC begin
@operator (S)::DECQuantity begin
@match S begin
ActFormParams([i,d,s,n]) => Form{n-i,d,s,n}
_ => throw(SortError("Cannot take the hodge star of $S"))
Expand All @@ -134,14 +133,14 @@ end

@alias (₀, ₁, ₂, ₀⁻¹, ₁⁻¹, ₂⁻¹) =>

@register Δ(S)::ThDEC begin
@operator Δ(S)::DECQuantity begin
@match S begin
ActForm(x) => (d((d(x))))
_ => throw(SortError("Cannot take the Laplacian of $S"))
end
end

@register +(S1, S2)::ThDEC begin
@operator +(S1, S2)::DECQuantity begin
@match (S1, S2) begin
(ActScalar, ActScalar) => Scalar
(ActScalar, ActFormParams([i,d,s,n])) || (ActFormParams([i,d,s,n]), ActScalar) => S1 # commutativity
Expand All @@ -159,17 +158,17 @@ end
end
end

@register -(S1, S2)::ThDEC begin +(S1, S2) end
@operator -(S1, S2)::DECQuantity begin +(S1, S2) end

@register *(S1, S2)::ThDEC begin
@operator *(S1, S2)::DECQuantity begin
@match (S1, S2) begin
(Scalar, Scalar) => Scalar
(Scalar, ActFormParams([i,d,s,n])) || (ActFormParams([i,d,s,n]), Scalar) => Form{i,d,s,n}
_ => throw(SortError("Cannot multiple $S1 and $S2"))
end
end

@register (S1, S2)::ThDEC begin
@operator (S1, S2)::DECQuantity begin
@match (S1, S2) begin
(ActFormParams([i1,d1,s1,n1]), ActFormParams([i2,d2,s2,n2])) => begin
(d1 == d2) && (s1 == s2) && (n1 == n2) || throw(SortError("Can only take a wedge product of two forms of the same duality on the same space"))
Expand All @@ -186,6 +185,8 @@ struct SortError <: Exception
message::String
end

# struct WedgeDimError <: SortError end

Base.nameof(s::Scalar) = :Constant

function Base.nameof(f::Form; with_dim_parameter=false)
Expand Down Expand Up @@ -226,4 +227,17 @@ function Base.nameof(::typeof(⋆), s)
Symbol("$(as_sub(isdual(s) ? dim(space(s)) - dim(s) : dim(s)))$(inv)")
end

function SymbolicUtils.symtype(::Quantity, qty::Symbol, space::Symbol)
@match qty begin
:Scalar => Scalar
:Form0 => PrimalForm{0, space, 1}
:Form1 => PrimalForm{1, space, 1}
:Form2 => PrimalForm{2, space, 1}
:DualForm0 => DualForm{0, space, 1}
:DualForm1 => DualForm{1, space, 1}
:DualForm2 => DualForm{2, space, 1}
_ => error("$qty")
end
end

end
61 changes: 21 additions & 40 deletions src/symbolictheoryutils.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
using MLStyle
using SymbolicUtils
using SymbolicUtils: Symbolic, BasicSymbolic, FnType, Sym
using SymbolicUtils: Symbolic, BasicSymbolic, FnType, Sym, symtype

""" ThDEC in DiagrammaticEquations must be subtyped by Number to integrate with SymbolicUtils. An intermediary type, Quantity, makes it clearer that terms in the theory are "symbolic quantities" which behave like numbers
"""
abstract type Quantity <: Number end
export Quantity

"""
Registers a new function
Creates an operator `foo` with arguments which are types in a given Theory. This entails creating (1) a function which performs type construction and (2) a function which consumes BasicSymbolic variables and returns Terms.
```
@register foo(S1, S2, ...)::ThDEC begin
@operator foo(S1, S2, ...)::Theory begin
(body of function)
end
```
Expand All @@ -28,7 +30,7 @@ end
```
```
@register Δ(s::ThDEC) begin
@operator Δ(s::ThDEC) begin
@match s begin
::Scalar => error("Invalid")
::VField => error("Invalid")
Expand All @@ -43,7 +45,7 @@ end
will create an additional method for Δ for operating on BasicSymbolic
"""
macro register(head, body)
macro operator(head, body)

# parse body
ph = begin
Expand Down Expand Up @@ -81,24 +83,33 @@ macro register(head, body)
end

# binding type bindings to the basicsymbolics
basicsym_args = [:($var::$basicsym_generic) for (var, basicsym_generic) in basicsym_bindings]
bs_arg_exprs = [:($var::$basicsym_generic) for (var, basicsym_generic) in basicsym_bindings]

# build constraints
constraints_expr = [:($T<:$Theory) for T in getindex.(generic_vars, 2)]
constraint_exprs = [:($T<:$Theory) for T in getindex.(generic_vars, 2)]

push!(result.args,
esc(quote
@nospecialize
function $f($(basicsym_args...)) where {$(constraints_expr...)}
function $f($(bs_arg_exprs...)) where {$(constraint_exprs...)}
s = $f($(getindex.(generic_vars, 2)...))
SymbolicUtils.Term{s}($f ,[$(getindex.(basicsym_bindings, 1)...)])
end
export $f
end))
end))

push!(result.args,
esc(quote
# we want to feed symtype the generics
function SymbolicUtils.promote_symtype(::typeof($f),
$(bs_arg_exprs...)) where {$(constraint_exprs...)}
$f($(getindex.(generic_vars, 2)...))
end
end))

return result
end
export @register
export @operator

function alias(x)
error("$x has no aliases")
Expand Down Expand Up @@ -129,33 +140,3 @@ macro alias(body)
result
end
export alias

macro see(body)
ph = begin
Expr(:(=), Expr(:where, Expr(:call, foo, typebindings), params...),
Expr(:block, body...)) => (foo, ph(typebindings), params, body)
Expr(:(::), vars...) => ph.(vars)
Expr(:curly, :Type, Expr(:<:, Expr(:curly, type, params...))) => (type, params)
s => s
end
ph(body)
quote
$foo(arg, s1::B1, s2::B1) where {S1,S2,B1<:BasicSymbolic{S1},B2<:BasicSymbolic{S2}}

end
end

@see dim(::Type{<:Form{i,d,s,n}}) where {i,d,s,n} = i

function Base.nameof(::typeof(), s1::B1, s2::B2) where {S1,S2,B1<:BasicSymbolic{S1}, B2<:BasicSymbolic{S2}}
Symbol("$(as_sub(dim(symtype(s1))))$(as_sub(dim(symtype(s2))))")
end


Expr(:=,
Expr(:where
[Expr(:call
foo,
Expr(:(::), e...)),
params...]),
Expr(:block, body...))
14 changes: 9 additions & 5 deletions test/decasymbolic.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
using Test
using DiagrammaticEquations.Deca.TheoryDEC
using DiagrammaticEquations.Deca.ThDEC
using DiagrammaticEquations.decapodes
using SymbolicUtils
using SymbolicUtils: symtype
using SymbolicUtils: symtype, promote_symtype

# load up some variable variables and expressions
a, b = @syms a::Scalar b::Scalar
Expand All @@ -11,10 +11,10 @@ u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2}
ϕ, ψ = @syms ϕ::PrimalVF{:X, 2} ψ::DualVF{:X, 2}
# TODO would be nice to pass the space globally to avoid duplication


@testset "Term Construction" begin

# TODO implement symtype
# test conversion to underlying type
@test symtype(a) == Scalar
@test symtype(u) == PrimalForm{0, :X, 2}
@test symtype(ω) == PrimalForm{1, :X, 2}
Expand All @@ -30,7 +30,7 @@ u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2}
@test Term(1) == Lit(Symbol("1"))
@test Term(a) == Var(:a)
@test Term(∂ₜ(u)) == Tan(Var(:u))
@test Term((ω)) == App1(:₁, Var())
@test_broken Term((ω)) == App1(:₁, Var())
# @test_broken Term(ThDEC.♭(ψ)) == App1(:♭s, Var(:ψ))
# @test Term(DiagrammaticEquations.ThDEC.♯(du))

Expand All @@ -39,7 +39,11 @@ u, v = @syms u::PrimalForm{0, :X, 2} du::PrimalForm{1, :X, 2}
# test binary operator conversion to decaexpr
@test Term(a + b) == Plus(Term[Var(:a), Var(:b)])
@test Term(a * b) == Mult(Term[Var(:a), Var(:b)])
@test Term du) == App2(:₁₁, Var(), Var(:du))
@test Term du) == App2(:₁₁, Var(), Var(:du))

@test promote_symtype(+, a, b) == Scalar
@test promote_symtype(, u, u) == PrimalForm{0, :X, 2}
@test promote_symtype(, u, ω) == PrimalForm{1, :X, 2}

end

Expand Down

0 comments on commit 1fa477b

Please sign in to comment.