Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interop with SymbolicUtils #64

Merged
merged 35 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
c2d64b5
symbolicutils interop
olynch Aug 14, 2024
c11d4dd
added spaces to the sorts of forms and vector fields
olynch Aug 16, 2024
7de10e8
integrated changes to ThDEC with DecaSymbolic
olynch Aug 17, 2024
5928cc7
added tests for decasymbolic, but needs wrinkles ironed out. musical …
quffaro Aug 18, 2024
d7d6a5b
reverted src/DiagX to restore exports and adding Project.toml
quffaro Aug 19, 2024
2f79b0c
Merge remote-tracking branch 'origin/space-sorts' into symbolicutilsi…
quffaro Aug 19, 2024
a438192
updated code and tests after merge from space-sorts
quffaro Aug 19, 2024
18a1585
review changes:
jpfairbanks Aug 21, 2024
090ddfe
resolving some comments from code review.
quffaro Aug 22, 2024
d8be4ae
adding @alias and @register macros to make DecaSymbolic function work…
quffaro Aug 27, 2024
77770e5
experimenting with a type-driven approach
quffaro Aug 28, 2024
1fa477b
adding promote_symtype and addressing some of the code review comment…
quffaro Aug 30, 2024
3b265ac
refactoring @operator to integrate with promote_symtype
quffaro Sep 4, 2024
7f8597a
operator macro parses @rule but need to write tests and iron out rela…
quffaro Sep 5, 2024
154e51f
rewriting just needs tests
quffaro Sep 6, 2024
18bb71f
TST: add some klausmeier rewrites
jpfairbanks Sep 9, 2024
4ba6dce
BUG: fix method shadowing for existing operators like +/-
jpfairbanks Sep 9, 2024
82f65ce
added more tests, @operator macro is more flexible
quffaro Sep 10, 2024
e23459f
Merge pull request #73 from AlgebraicJulia/jpf/symbolictypes
quffaro Sep 10, 2024
e82e826
almost round-tripping in klausmeier. equations are not currently pass…
quffaro Sep 10, 2024
f32b3c0
Merge branch 'cm/symbolictypes' of github.com:AlgebraicJulia/Diagramm…
quffaro Sep 10, 2024
ecfa931
added rules function which dispatches on function symbol and the Val(…
quffaro Sep 10, 2024
4603ed9
fixed docs for @operator, fixed Term
quffaro Sep 11, 2024
f2ef1d0
Merge pull request #71 from AlgebraicJulia/cm/symbolictypes
quffaro Sep 11, 2024
5324de3
Loosened aqua tests
GeorgeR227 Sep 12, 2024
0a314a8
fixing bug where (+) always returns Scalar
quffaro Sep 18, 2024
5d5c25d
Expression level rewriting (#69)
GeorgeR227 Oct 3, 2024
70a420d
added tests for errors and consolidated errors with George
quffaro Oct 3, 2024
85572a4
Fixed hodge nameof and more tests
GeorgeR227 Oct 4, 2024
df41b18
Many more tests
GeorgeR227 Oct 4, 2024
6d1edf0
Merge branch 'main' into symbolicutilsinterop
jpfairbanks Oct 4, 2024
ccf0a79
More tests
GeorgeR227 Oct 5, 2024
467a93c
Remove unused functions
GeorgeR227 Oct 9, 2024
7c5155d
Fixed test in acset2symbolics
GeorgeR227 Oct 9, 2024
4f1f327
Added some improvements to naming
GeorgeR227 Oct 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,18 @@ ACSets = "227ef7b5-1206-438b-ac65-934d6da304b8"
Catlab = "134e5e36-593f-5add-ad60-77f754baafbe"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
StructEquality = "6ec83bb0-ed9f-11e9-3b4c-2b04cb4e219c"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[compat]
ACSets = "0.2"
Catlab = "0.15, 0.16"
DataStructures = "0.18.13"
MLStyle = "0.4.17"
Reexport = "1.2.2"
StructEquality = "2.1.0"
SymbolicUtils = "3.1.2"
Unicode = "1.6"
julia = "1.6"
7 changes: 6 additions & 1 deletion src/DiagrammaticEquations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ using Catlab.Theories
import Catlab.Theories: otimes, oplus, compose, ⊗, ⊕, ⋅, associate, associate_unit, Ob, Hom, dom, codom
using Catlab.Programs
using Catlab.CategoricalAlgebra
import Catlab.CategoricalAlgebra: ∧
using Catlab.WiringDiagrams
using Catlab.WiringDiagrams.DirectedWiringDiagrams
using Catlab.ACSetInterface
using MLStyle
import Unicode
using Reexport

## TODO:
## generate schema from a _theory_
Expand All @@ -59,9 +61,12 @@ include("rewrite.jl")
include("pretty.jl")
include("colanguage.jl")
include("openoperators.jl")
include("symbolictheoryutils.jl")
include("deca/Deca.jl")
include("learn/Learn.jl")
include("SymbolicUtilsInterop.jl")

using .Deca
@reexport using .Deca
@reexport using .SymbolicUtilsInterop

end
153 changes: 153 additions & 0 deletions src/SymbolicUtilsInterop.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
module SymbolicUtilsInterop

using ..DiagrammaticEquations: AbstractDecapode, Quantity
import ..DiagrammaticEquations: eval_eq!, SummationDecapode
using ..decapodes
using ..Deca

using MLStyle
using SymbolicUtils
using SymbolicUtils: Symbolic, BasicSymbolic, FnType, Sym, symtype

# name collision with decapodes.Equation
struct SymbolicEquation{E}
lhs::E
rhs::E
end

Base.show(io::IO, e::SymbolicEquation) = begin
print(io, e.lhs); print(io, " == "); print(io, e.rhs)

Check warning on line 19 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L18-L19

Added lines #L18 - L19 were not covered by tests
end

## a struct carry the symbolic variables and their equations
struct SymbolicContext
vars::Vector{Symbolic}
equations::Vector{SymbolicEquation{Symbolic}}
end
export SymbolicContext

Base.show(io::IO, d::SymbolicContext) = begin
println(io, "SymbolicContext(")
println(io, " Variables: [$(join(d.vars, ", "))]")
println(io, " Equations: [")
eqns = map(d.equations) do op
" $(op)"

Check warning on line 34 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L29-L34

Added lines #L29 - L34 were not covered by tests
end
println(io, "$(join(eqns,",\n"))])")

Check warning on line 36 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L36

Added line #L36 was not covered by tests
end

## BasicSymbolic -> DecaExpr
function decapodes.Term(t::SymbolicUtils.BasicSymbolic)
if SymbolicUtils.issym(t)
decapodes.Var(nameof(t))

Check warning on line 42 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L40-L42

Added lines #L40 - L42 were not covered by tests
else
op = SymbolicUtils.head(t)
args = SymbolicUtils.arguments(t)
termargs = Term.(args)
if op == +
decapodes.Plus(termargs)
elseif op == *
decapodes.Mult(termargs)
elseif op == ∂ₜ
decapodes.Tan(only(termargs))
elseif length(args) == 1
decapodes.App1(nameof(op, symtype.(args)...), termargs...)
elseif length(args) == 2
decapodes.App2(nameof(op, symtype.(args)...), termargs...)

Check warning on line 56 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L44-L56

Added lines #L44 - L56 were not covered by tests
else
error("was unable to convert $t into a Term")

Check warning on line 58 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L58

Added line #L58 was not covered by tests
end
end
end

decapodes.Term(x::Real) = decapodes.Lit(Symbol(x))

Check warning on line 63 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L63

Added line #L63 was not covered by tests

function decapodes.DecaExpr(d::SymbolicContext)
context = map(d.vars) do var
decapodes.Judgement(nameof(var), nameof(symtype(var)), :I)

Check warning on line 67 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L65-L67

Added lines #L65 - L67 were not covered by tests
end
equations = map(d.equations) do eq
decapodes.Eq(decapodes.Term(eq.lhs), decapodes.Term(eq.rhs))

Check warning on line 70 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L69-L70

Added lines #L69 - L70 were not covered by tests
end
decapodes.DecaExpr(context, equations)

Check warning on line 72 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L72

Added line #L72 was not covered by tests
end

"""
Retrieve the SymbolicUtils expression of a DecaExpr term `t` from a context of variables in ThDEC

Example:
```
a = @syms a::Real
context = Dict(:a => Scalar(), :u => PrimalForm(0))
SymbolicUtils.BasicSymbolic(context, Term(a))
```
"""
function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,DataType}, t::decapodes.Term, __module__=@__MODULE__)

Check warning on line 85 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L85

Added line #L85 was not covered by tests
# user must import symbols into scope
! = (f -> getfield(__module__, f))
@match t begin
Var(name) => SymbolicUtils.Sym{context[name]}(name)
Lit(v) => Meta.parse(string(v))

Check warning on line 90 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L87-L90

Added lines #L87 - L90 were not covered by tests
# see heat_eq test: eqs had AppCirc1, but this returns
# App1(f, App1(...)
AppCirc1(fs, arg) => foldr(

Check warning on line 93 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L93

Added line #L93 was not covered by tests
# panics with constants like :k
# see test/language.jl
(f, x) -> (!(f))(x),

Check warning on line 96 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L96

Added line #L96 was not covered by tests
fs;
init=BasicSymbolic(context, arg, __module__)
)
App1(f, x) => (!(f))(BasicSymbolic(context, x, __module__))
App2(f, x, y) => (!(f))(BasicSymbolic(context, x, __module__), BasicSymbolic(context, y, __module__))
Plus(xs) => +(BasicSymbolic.(Ref(context), xs, Ref(__module__))...)
Mult(xs) => *(BasicSymbolic.(Ref(context), xs, Ref(__module__))...)
Tan(x) => ∂ₜ(BasicSymbolic(context, x, __module__))

Check warning on line 104 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L100-L104

Added lines #L100 - L104 were not covered by tests
end
end

function SymbolicContext(d::decapodes.DecaExpr, __module__=@__MODULE__)

Check warning on line 108 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L108

Added line #L108 was not covered by tests
# associates each var to its sort...
context = map(d.context) do j
j.var => symtype(Deca.DECQuantity, j.dim, j.space)

Check warning on line 111 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L110-L111

Added lines #L110 - L111 were not covered by tests
end
# ... which we then produce a vector of symbolic vars
vars = map(context) do (v, s)
SymbolicUtils.Sym{s}(v)

Check warning on line 115 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L114-L115

Added lines #L114 - L115 were not covered by tests
end
context = Dict{Symbol,DataType}(context)
eqs = map(d.equations) do eq
SymbolicEquation{Symbolic}(BasicSymbolic.(Ref(context), [eq.lhs, eq.rhs], Ref(__module__))...)

Check warning on line 119 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L117-L119

Added lines #L117 - L119 were not covered by tests
end
SymbolicContext(vars, eqs)

Check warning on line 121 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L121

Added line #L121 was not covered by tests
end

function eval_eq!(eq::SymbolicEquation, d::AbstractDecapode, syms::Dict{Symbol, Int}, deletions::Vector{Int})
eval_eq!(Equation(Term(eq.lhs), Term(eq.rhs)), d, syms, deletions)

Check warning on line 125 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L124-L125

Added lines #L124 - L125 were not covered by tests
end

""" function SummationDecapode(e::SymbolicContext) """
function SummationDecapode(e::SymbolicContext)
d = SummationDecapode{Any, Any, Symbol}()
symbol_table = Dict{Symbol, Int}()

Check warning on line 131 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L129-L131

Added lines #L129 - L131 were not covered by tests

foreach(e.vars) do var

Check warning on line 133 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L133

Added line #L133 was not covered by tests
# convert Sort(var)::PrimalForm0 --> :Form0
var_id = add_part!(d, :Var, name=var.name, type=nameof(Sort(var)))
symbol_table[var.name] = var_id

Check warning on line 136 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L135-L136

Added lines #L135 - L136 were not covered by tests
end

deletions = Vector{Int}()
foreach(e.equations) do eq
eval_eq!(eq, d, symbol_table, deletions)

Check warning on line 141 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L139-L141

Added lines #L139 - L141 were not covered by tests
end
rem_parts!(d, :Var, sort(deletions))

Check warning on line 143 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L143

Added line #L143 was not covered by tests

recognize_types(d)

Check warning on line 145 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L145

Added line #L145 was not covered by tests

fill_names!(d)
d[:name] = normalize_unicode.(d[:name])
make_sum_mult_unique!(d)
return d

Check warning on line 150 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L147-L150

Added lines #L147 - L150 were not covered by tests
end

end
5 changes: 5 additions & 0 deletions src/deca/Deca.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,17 @@ using DataStructures
using ..DiagrammaticEquations
using Catlab

using Reexport

import ..infer_types!, ..resolve_overloads!

export normalize_unicode, varname, infer_types!, resolve_overloads!, typename, spacename, recursive_delete_parents, recursive_delete_parents!, unicode!, op1_res_rules_1D, op2_res_rules_1D, op1_res_rules_2D, op2_res_rules_2D, op1_inf_rules_1D, op2_inf_rules_1D, op1_inf_rules_2D, op2_inf_rules_2D, vec_to_dec!

include("deca_acset.jl")
include("deca_visualization.jl")
include("ThDEC.jl")

@reexport using .ThDEC

""" function recursive_delete_parents!(d::SummationDecapode, to_delete::Vector{Int64})

Expand Down
Loading
Loading