Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interop with SymbolicUtils #64

Merged
merged 35 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
c2d64b5
symbolicutils interop
olynch Aug 14, 2024
c11d4dd
added spaces to the sorts of forms and vector fields
olynch Aug 16, 2024
7de10e8
integrated changes to ThDEC with DecaSymbolic
olynch Aug 17, 2024
5928cc7
added tests for decasymbolic, but needs wrinkles ironed out. musical …
quffaro Aug 18, 2024
d7d6a5b
reverted src/DiagX to restore exports and adding Project.toml
quffaro Aug 19, 2024
2f79b0c
Merge remote-tracking branch 'origin/space-sorts' into symbolicutilsi…
quffaro Aug 19, 2024
a438192
updated code and tests after merge from space-sorts
quffaro Aug 19, 2024
18a1585
review changes:
jpfairbanks Aug 21, 2024
090ddfe
resolving some comments from code review.
quffaro Aug 22, 2024
d8be4ae
adding @alias and @register macros to make DecaSymbolic function work…
quffaro Aug 27, 2024
77770e5
experimenting with a type-driven approach
quffaro Aug 28, 2024
1fa477b
adding promote_symtype and addressing some of the code review comment…
quffaro Aug 30, 2024
3b265ac
refactoring @operator to integrate with promote_symtype
quffaro Sep 4, 2024
7f8597a
operator macro parses @rule but need to write tests and iron out rela…
quffaro Sep 5, 2024
154e51f
rewriting just needs tests
quffaro Sep 6, 2024
18bb71f
TST: add some klausmeier rewrites
jpfairbanks Sep 9, 2024
4ba6dce
BUG: fix method shadowing for existing operators like +/-
jpfairbanks Sep 9, 2024
82f65ce
added more tests, @operator macro is more flexible
quffaro Sep 10, 2024
e23459f
Merge pull request #73 from AlgebraicJulia/jpf/symbolictypes
quffaro Sep 10, 2024
e82e826
almost round-tripping in klausmeier. equations are not currently pass…
quffaro Sep 10, 2024
f32b3c0
Merge branch 'cm/symbolictypes' of github.com:AlgebraicJulia/Diagramm…
quffaro Sep 10, 2024
ecfa931
added rules function which dispatches on function symbol and the Val(…
quffaro Sep 10, 2024
4603ed9
fixed docs for @operator, fixed Term
quffaro Sep 11, 2024
f2ef1d0
Merge pull request #71 from AlgebraicJulia/cm/symbolictypes
quffaro Sep 11, 2024
5324de3
Loosened aqua tests
GeorgeR227 Sep 12, 2024
0a314a8
fixing bug where (+) always returns Scalar
quffaro Sep 18, 2024
5d5c25d
Expression level rewriting (#69)
GeorgeR227 Oct 3, 2024
70a420d
added tests for errors and consolidated errors with George
quffaro Oct 3, 2024
85572a4
Fixed hodge nameof and more tests
GeorgeR227 Oct 4, 2024
df41b18
Many more tests
GeorgeR227 Oct 4, 2024
6d1edf0
Merge branch 'main' into symbolicutilsinterop
jpfairbanks Oct 4, 2024
ccf0a79
More tests
GeorgeR227 Oct 5, 2024
467a93c
Remove unused functions
GeorgeR227 Oct 9, 2024
7c5155d
Fixed test in acset2symbolics
GeorgeR227 Oct 9, 2024
4f1f327
Added some improvements to naming
GeorgeR227 Oct 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,18 @@ ACSets = "227ef7b5-1206-438b-ac65-934d6da304b8"
Catlab = "134e5e36-593f-5add-ad60-77f754baafbe"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
StructEquality = "6ec83bb0-ed9f-11e9-3b4c-2b04cb4e219c"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[compat]
ACSets = "0.2"
Catlab = "0.15, 0.16"
DataStructures = "0.18.13"
MLStyle = "0.4.17"
Reexport = "1.2.2"
StructEquality = "2.1.0"
SymbolicUtils = "3.1.2"
Unicode = "1.6"
julia = "1.6"
10 changes: 9 additions & 1 deletion src/DiagrammaticEquations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@ using Catlab.Theories
import Catlab.Theories: otimes, oplus, compose, ⊗, ⊕, ⋅, associate, associate_unit, Ob, Hom, dom, codom
using Catlab.Programs
using Catlab.CategoricalAlgebra
import Catlab.CategoricalAlgebra: ∧
using Catlab.WiringDiagrams
using Catlab.WiringDiagrams.DirectedWiringDiagrams
using Catlab.ACSetInterface
using MLStyle
import Unicode
using Reexport

## TODO:
## generate schema from a _theory_
Expand All @@ -62,9 +64,15 @@ include("rewrite.jl")
include("pretty.jl")
include("colanguage.jl")
include("openoperators.jl")
include("symbolictheoryutils.jl")
include("graph_traversal.jl")
include("deca/Deca.jl")
include("learn/Learn.jl")
include("SymbolicUtilsInterop.jl")

using .Deca
@reexport using .Deca
@reexport using .SymbolicUtilsInterop

include("acset2symbolic.jl")

end
157 changes: 157 additions & 0 deletions src/SymbolicUtilsInterop.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
module SymbolicUtilsInterop

using ACSets
using ..DiagrammaticEquations: AbstractDecapode, Quantity, DerivOp
using ..DiagrammaticEquations: recognize_types, fill_names!, make_sum_mult_unique!
import ..DiagrammaticEquations: eval_eq!, SummationDecapode
using ..decapodes
using ..Deca

using MLStyle
using SymbolicUtils
using SymbolicUtils: Symbolic, BasicSymbolic, FnType, Sym, symtype

# name collision with decapodes.Equation
struct SymbolicEquation{E}
lhs::E
rhs::E
end
export SymbolicEquation

Base.show(io::IO, e::SymbolicEquation) = print(io, "$(e.lhs) == $(e.rhs)")

## a struct carry the symbolic variables and their equations
struct SymbolicContext
vars::Vector{Symbolic}
equations::Vector{SymbolicEquation{Symbolic}}
end
export SymbolicContext

Base.show(io::IO, d::SymbolicContext) = begin
println(io, "SymbolicContext(")
println(io, " Variables: [$(join(d.vars, ", "))]")
println(io, " Equations: [")
eqns = map(d.equations) do op
" $(op)"
end
println(io, "$(join(eqns,",\n"))])")
end

## BasicSymbolic -> DecaExpr
function decapodes.Term(t::SymbolicUtils.BasicSymbolic)
if SymbolicUtils.issym(t)
decapodes.Var(nameof(t))
else
op = SymbolicUtils.head(t)
args = SymbolicUtils.arguments(t)
termargs = Term.(args)
if op == +
decapodes.Plus(termargs)
elseif op == *
decapodes.Mult(termargs)
elseif op ∈ [DerivOp, ∂ₜ]
decapodes.Tan(only(termargs))
elseif length(args) == 1
decapodes.App1(nameof(op, symtype.(args)...), termargs...)
elseif length(args) == 2
decapodes.App2(nameof(op, symtype.(args)...), termargs...)
else
error("was unable to convert $t into a Term")
end
end
end
# TODO subtraction is not parsed as such. e.g.,
# a, b = @syms a::Scalar b::Scalar
# Term(a-b) = Plus(Term[Var(:a), Mult(Term[Lit(Symbol("-1")), Var(:b)]))

decapodes.Term(x::Real) = decapodes.Lit(Symbol(x))

function decapodes.DecaExpr(d::SymbolicContext)
context = map(d.vars) do var
decapodes.Judgement(nameof(var), nameof(symtype(var)), :I)
end
equations = map(d.equations) do eq
decapodes.Eq(decapodes.Term(eq.lhs), decapodes.Term(eq.rhs))
end
decapodes.DecaExpr(context, equations)
end

"""
Retrieve the SymbolicUtils expression of a DecaExpr term `t` from a context of variables in ThDEC

Example:
```
a = @syms a::Real
context = Dict(:a => Scalar(), :u => PrimalForm(0))
SymbolicUtils.BasicSymbolic(context, Term(a))
```
"""
function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,DataType}, t::decapodes.Term)
# user must import symbols into scope
! = (f -> getfield(@__MODULE__, f))
@match t begin
Var(name) => SymbolicUtils.Sym{context[name]}(name)
Lit(v) => Meta.parse(string(v))
# see heat_eq test: eqs had AppCirc1, but this returns
# App1(f, App1(...)
AppCirc1(fs, arg) => foldr(
# panics with constants like :k
# see test/language.jl
(f, x) -> (!(f))(x),
fs;
init=BasicSymbolic(context, arg)
)
App1(f, x) => (!(f))(BasicSymbolic(context, x))
App2(f, x, y) => (!(f))(BasicSymbolic(context, x), BasicSymbolic(context, y))
Plus(xs) => +(BasicSymbolic.(Ref(context), xs)...)
Mult(xs) => *(BasicSymbolic.(Ref(context), xs)...)
Tan(x) => (!(DerivOp))(BasicSymbolic(context, x))
end
end

function SymbolicContext(d::decapodes.DecaExpr)
# associates each var to its sort...
context = map(d.context) do j
j.var => symtype(Deca.DECQuantity, j.dim, j.space)
end
# ... which we then produce a vector of symbolic vars
vars = map(context) do (v, s)
SymbolicUtils.Sym{s}(v)
end
context = Dict{Symbol,DataType}(context)
eqs = map(d.equations) do eq
SymbolicEquation{Symbolic}(BasicSymbolic.(Ref(context), [eq.lhs, eq.rhs])...)
end
SymbolicContext(vars, eqs)
end

function eval_eq!(eq::SymbolicEquation, d::AbstractDecapode, syms::Dict{Symbol, Int}, deletions::Vector{Int})
eval_eq!(Eq(Term(eq.lhs), Term(eq.rhs)), d, syms, deletions)
end

""" function SummationDecapode(e::SymbolicContext) """
function SummationDecapode(e::SymbolicContext)
d = SummationDecapode{Any, Any, Symbol}()
symbol_table = Dict{Symbol, Int}()

foreach(e.vars) do var
# convert Sort(var)::PrimalForm0 --> :Form0
var_id = add_part!(d, :Var, name=var.name, type=nameof(symtype(var)))
symbol_table[var.name] = var_id
end

deletions = Vector{Int}()
foreach(e.equations) do eq
eval_eq!(eq, d, symbol_table, deletions)
end
rem_parts!(d, :Var, sort(deletions))

recognize_types(d)

fill_names!(d)
d[:name] = normalize_unicode.(d[:name])
make_sum_mult_unique!(d)
return d
end

end
1 change: 1 addition & 0 deletions src/acset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ function recognize_types(d::AbstractNamedDecapode)
isempty(unrecognized_types) ||
error("Types $unrecognized_types are not recognized. CHECK: $types")
end
export recognize_types

""" is_expanded(d::AbstractNamedDecapode)

Expand Down
83 changes: 83 additions & 0 deletions src/acset2symbolic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
using DiagrammaticEquations
using ACSets
using SymbolicUtils
using SymbolicUtils: BasicSymbolic, Symbolic

export symbolic_rewriting

const EQUALITY = (==)
const SymEqSym = SymbolicEquation{Symbolic}

function symbolics_lookup(d::SummationDecapode)
Dict{Symbol, BasicSymbolic}(map(d[:name],d[:type]) do name,type
(name, decavar_to_symbolics(name, type))
end)
end

function decavar_to_symbolics(var_name::Symbol, var_type::Symbol, space = :I)
new_type = SymbolicUtils.symtype(Deca.DECQuantity, var_type, space)
SymbolicUtils.Sym{new_type}(var_name)
end

function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_idx::Int, op_type::Symbol)
input_syms = getindex.(Ref(symvar_lookup), d[edge_inputs(d,op_idx,Val(op_type)), :name])
output_sym = getindex.(Ref(symvar_lookup), d[edge_output(d,op_idx,Val(op_type)), :name])
op_sym = getfield(@__MODULE__, edge_function(d,op_idx,Val(op_type)))

S = promote_symtype(op_sym, input_syms...)
SymEqSym(output_sym, SymbolicUtils.Term{S}(op_sym, input_syms))
end

function to_symbolics(d::SummationDecapode)
symvar_lookup = symbolics_lookup(d)
map(e -> to_symbolics(d, symvar_lookup, e.index, e.name), topological_sort_edges(d))
end

function symbolic_rewriting(d::SummationDecapode, rewriter=identity)
d′ = infer_types!(deepcopy(d))
eqns = merge_equations(d′)
to_acset(d′, map(rewriter, eqns))
end

# XXX SymbolicUtils.substitute swaps the order of multiplication.
# e.g. ∂ₜ(G) == κ*u becomes ∂ₜ(G) == u*κ
function merge_equations(d::SummationDecapode)
eqn_lookup, terminal_eqns = Dict(), SymEqSym[]
deriv_op_tgts = d[incident(d, DerivOp, :op1), [:tgt, :name]] # Patches over issue #77
terminal_vars = Set{Symbol}(vcat(infer_terminal_names(d), deriv_op_tgts))

foreach(to_symbolics(d)) do x
sub = SymbolicUtils.substitute(x.rhs, eqn_lookup)
push!(eqn_lookup, (x.lhs => sub))
if x.lhs.name in terminal_vars
push!(terminal_eqns, SymEqSym(x.lhs, sub))
end
end

map(terminal_eqns) do eqn
SymbolicUtils.Term{Number}(EQUALITY, [eqn.lhs, eqn.rhs])
end
end

function to_acset(d::SummationDecapode, sym_exprs)
literals = incident(d, :Literal, :type)

outer_types = map([infer_states(d)..., infer_terminals(d)..., literals...]) do i
:($(d[i, :name])::$(d[i, :type]))
end

#TODO: This step is breaking up summations
final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs)
reify!(exprs) = foreach(exprs) do x
if typeof(x) == Expr && x.head == :call
x.args[1] = nameof(x.args[1])
reify!(x.args[2:end])
end
end
reify!(final_exprs)

deca_block = quote end
deca_block.args = [outer_types..., final_exprs...]

∘(infer_types!, SummationDecapode, parse_decapode)(deca_block)
end
5 changes: 5 additions & 0 deletions src/deca/Deca.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,17 @@ using DataStructures
using ..DiagrammaticEquations
using Catlab

using Reexport

import ..infer_types!, ..resolve_overloads!

export normalize_unicode, varname, infer_types!, resolve_overloads!, typename, spacename, recursive_delete_parents, recursive_delete_parents!, unicode!, op1_res_rules_1D, op2_res_rules_1D, op1_res_rules_2D, op2_res_rules_2D, op1_inf_rules_1D, op2_inf_rules_1D, op1_inf_rules_2D, op2_inf_rules_2D, vec_to_dec!

include("deca_acset.jl")
include("deca_visualization.jl")
include("ThDEC.jl")

@reexport using .ThDEC

""" function recursive_delete_parents!(d::SummationDecapode, to_delete::Vector{Int64})

Expand Down
Loading
Loading