Skip to content

Commit

Permalink
Merge pull request #64 from AlgebraicJulia/symbolicutilsinterop
Browse files Browse the repository at this point in the history
Interop with SymbolicUtils
  • Loading branch information
jpfairbanks authored Oct 11, 2024
2 parents 7c6feda + 4f1f327 commit 2f82754
Show file tree
Hide file tree
Showing 19 changed files with 1,631 additions and 34 deletions.
6 changes: 6 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,18 @@ ACSets = "227ef7b5-1206-438b-ac65-934d6da304b8"
Catlab = "134e5e36-593f-5add-ad60-77f754baafbe"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
StructEquality = "6ec83bb0-ed9f-11e9-3b4c-2b04cb4e219c"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[compat]
ACSets = "0.2"
Catlab = "0.15, 0.16"
DataStructures = "0.18.13"
MLStyle = "0.4.17"
Reexport = "1.2.2"
StructEquality = "2.1.0"
SymbolicUtils = "3.1.2"
Unicode = "1.6"
julia = "1.6"
10 changes: 9 additions & 1 deletion src/DiagrammaticEquations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@ using Catlab.Theories
import Catlab.Theories: otimes, oplus, compose, , , , associate, associate_unit, Ob, Hom, dom, codom
using Catlab.Programs
using Catlab.CategoricalAlgebra
import Catlab.CategoricalAlgebra:
using Catlab.WiringDiagrams
using Catlab.WiringDiagrams.DirectedWiringDiagrams
using Catlab.ACSetInterface
using MLStyle
import Unicode
using Reexport

## TODO:
## generate schema from a _theory_
Expand All @@ -62,9 +64,15 @@ include("rewrite.jl")
include("pretty.jl")
include("colanguage.jl")
include("openoperators.jl")
include("symbolictheoryutils.jl")
include("graph_traversal.jl")
include("deca/Deca.jl")
include("learn/Learn.jl")
include("SymbolicUtilsInterop.jl")

using .Deca
@reexport using .Deca
@reexport using .SymbolicUtilsInterop

include("acset2symbolic.jl")

end
157 changes: 157 additions & 0 deletions src/SymbolicUtilsInterop.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
module SymbolicUtilsInterop

using ACSets
using ..DiagrammaticEquations: AbstractDecapode, Quantity, DerivOp
using ..DiagrammaticEquations: recognize_types, fill_names!, make_sum_mult_unique!
import ..DiagrammaticEquations: eval_eq!, SummationDecapode
using ..decapodes
using ..Deca

using MLStyle
using SymbolicUtils
using SymbolicUtils: Symbolic, BasicSymbolic, FnType, Sym, symtype

# name collision with decapodes.Equation
struct SymbolicEquation{E}
lhs::E
rhs::E
end
export SymbolicEquation

Base.show(io::IO, e::SymbolicEquation) = print(io, "$(e.lhs) == $(e.rhs)")

## a struct carry the symbolic variables and their equations
struct SymbolicContext
vars::Vector{Symbolic}
equations::Vector{SymbolicEquation{Symbolic}}
end
export SymbolicContext

Base.show(io::IO, d::SymbolicContext) = begin
println(io, "SymbolicContext(")
println(io, " Variables: [$(join(d.vars, ", "))]")
println(io, " Equations: [")
eqns = map(d.equations) do op
" $(op)"
end
println(io, "$(join(eqns,",\n"))])")
end

## BasicSymbolic -> DecaExpr
function decapodes.Term(t::SymbolicUtils.BasicSymbolic)
if SymbolicUtils.issym(t)
decapodes.Var(nameof(t))
else
op = SymbolicUtils.head(t)
args = SymbolicUtils.arguments(t)
termargs = Term.(args)
if op == +
decapodes.Plus(termargs)
elseif op == *
decapodes.Mult(termargs)
elseif op [DerivOp, ∂ₜ]
decapodes.Tan(only(termargs))
elseif length(args) == 1
decapodes.App1(nameof(op, symtype.(args)...), termargs...)
elseif length(args) == 2
decapodes.App2(nameof(op, symtype.(args)...), termargs...)
else
error("was unable to convert $t into a Term")
end
end
end
# TODO subtraction is not parsed as such. e.g.,
# a, b = @syms a::Scalar b::Scalar
# Term(a-b) = Plus(Term[Var(:a), Mult(Term[Lit(Symbol("-1")), Var(:b)]))

decapodes.Term(x::Real) = decapodes.Lit(Symbol(x))

function decapodes.DecaExpr(d::SymbolicContext)
context = map(d.vars) do var
decapodes.Judgement(nameof(var), nameof(symtype(var)), :I)
end
equations = map(d.equations) do eq
decapodes.Eq(decapodes.Term(eq.lhs), decapodes.Term(eq.rhs))
end
decapodes.DecaExpr(context, equations)
end

"""
Retrieve the SymbolicUtils expression of a DecaExpr term `t` from a context of variables in ThDEC
Example:
```
a = @syms a::Real
context = Dict(:a => Scalar(), :u => PrimalForm(0))
SymbolicUtils.BasicSymbolic(context, Term(a))
```
"""
function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,DataType}, t::decapodes.Term)
# user must import symbols into scope
! = (f -> getfield(@__MODULE__, f))
@match t begin
Var(name) => SymbolicUtils.Sym{context[name]}(name)
Lit(v) => Meta.parse(string(v))
# see heat_eq test: eqs had AppCirc1, but this returns
# App1(f, App1(...)
AppCirc1(fs, arg) => foldr(
# panics with constants like :k
# see test/language.jl
(f, x) -> (!(f))(x),
fs;
init=BasicSymbolic(context, arg)
)
App1(f, x) => (!(f))(BasicSymbolic(context, x))
App2(f, x, y) => (!(f))(BasicSymbolic(context, x), BasicSymbolic(context, y))
Plus(xs) => +(BasicSymbolic.(Ref(context), xs)...)
Mult(xs) => *(BasicSymbolic.(Ref(context), xs)...)
Tan(x) => (!(DerivOp))(BasicSymbolic(context, x))
end
end

function SymbolicContext(d::decapodes.DecaExpr)
# associates each var to its sort...
context = map(d.context) do j
j.var => symtype(Deca.DECQuantity, j.dim, j.space)
end
# ... which we then produce a vector of symbolic vars
vars = map(context) do (v, s)
SymbolicUtils.Sym{s}(v)
end
context = Dict{Symbol,DataType}(context)
eqs = map(d.equations) do eq
SymbolicEquation{Symbolic}(BasicSymbolic.(Ref(context), [eq.lhs, eq.rhs])...)
end
SymbolicContext(vars, eqs)
end

function eval_eq!(eq::SymbolicEquation, d::AbstractDecapode, syms::Dict{Symbol, Int}, deletions::Vector{Int})
eval_eq!(Eq(Term(eq.lhs), Term(eq.rhs)), d, syms, deletions)
end

""" function SummationDecapode(e::SymbolicContext) """
function SummationDecapode(e::SymbolicContext)
d = SummationDecapode{Any, Any, Symbol}()
symbol_table = Dict{Symbol, Int}()

foreach(e.vars) do var
# convert Sort(var)::PrimalForm0 --> :Form0
var_id = add_part!(d, :Var, name=var.name, type=nameof(symtype(var)))
symbol_table[var.name] = var_id
end

deletions = Vector{Int}()
foreach(e.equations) do eq
eval_eq!(eq, d, symbol_table, deletions)
end
rem_parts!(d, :Var, sort(deletions))

recognize_types(d)

fill_names!(d)
d[:name] = normalize_unicode.(d[:name])
make_sum_mult_unique!(d)
return d
end

end
1 change: 1 addition & 0 deletions src/acset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ function recognize_types(d::AbstractNamedDecapode)
isempty(unrecognized_types) ||
error("Types $unrecognized_types are not recognized. CHECK: $types")
end
export recognize_types

""" is_expanded(d::AbstractNamedDecapode)
Expand Down
83 changes: 83 additions & 0 deletions src/acset2symbolic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
using DiagrammaticEquations
using ACSets
using SymbolicUtils
using SymbolicUtils: BasicSymbolic, Symbolic

export symbolic_rewriting

const EQUALITY = (==)
const SymEqSym = SymbolicEquation{Symbolic}

function symbolics_lookup(d::SummationDecapode)
Dict{Symbol, BasicSymbolic}(map(d[:name],d[:type]) do name,type
(name, decavar_to_symbolics(name, type))
end)
end

function decavar_to_symbolics(var_name::Symbol, var_type::Symbol, space = :I)
new_type = SymbolicUtils.symtype(Deca.DECQuantity, var_type, space)
SymbolicUtils.Sym{new_type}(var_name)
end

function to_symbolics(d::SummationDecapode, symvar_lookup::Dict{Symbol, BasicSymbolic}, op_idx::Int, op_type::Symbol)
input_syms = getindex.(Ref(symvar_lookup), d[edge_inputs(d,op_idx,Val(op_type)), :name])
output_sym = getindex.(Ref(symvar_lookup), d[edge_output(d,op_idx,Val(op_type)), :name])
op_sym = getfield(@__MODULE__, edge_function(d,op_idx,Val(op_type)))

S = promote_symtype(op_sym, input_syms...)
SymEqSym(output_sym, SymbolicUtils.Term{S}(op_sym, input_syms))
end

function to_symbolics(d::SummationDecapode)
symvar_lookup = symbolics_lookup(d)
map(e -> to_symbolics(d, symvar_lookup, e.index, e.name), topological_sort_edges(d))
end

function symbolic_rewriting(d::SummationDecapode, rewriter=identity)
d′ = infer_types!(deepcopy(d))
eqns = merge_equations(d′)
to_acset(d′, map(rewriter, eqns))
end

# XXX SymbolicUtils.substitute swaps the order of multiplication.
# e.g. ∂ₜ(G) == κ*u becomes ∂ₜ(G) == u*κ
function merge_equations(d::SummationDecapode)
eqn_lookup, terminal_eqns = Dict(), SymEqSym[]
deriv_op_tgts = d[incident(d, DerivOp, :op1), [:tgt, :name]] # Patches over issue #77
terminal_vars = Set{Symbol}(vcat(infer_terminal_names(d), deriv_op_tgts))

foreach(to_symbolics(d)) do x
sub = SymbolicUtils.substitute(x.rhs, eqn_lookup)
push!(eqn_lookup, (x.lhs => sub))
if x.lhs.name in terminal_vars
push!(terminal_eqns, SymEqSym(x.lhs, sub))
end
end

map(terminal_eqns) do eqn
SymbolicUtils.Term{Number}(EQUALITY, [eqn.lhs, eqn.rhs])
end
end

function to_acset(d::SummationDecapode, sym_exprs)
literals = incident(d, :Literal, :type)

outer_types = map([infer_states(d)..., infer_terminals(d)..., literals...]) do i
:($(d[i, :name])::$(d[i, :type]))
end

#TODO: This step is breaking up summations
final_exprs = SymbolicUtils.Code.toexpr.(sym_exprs)
reify!(exprs) = foreach(exprs) do x
if typeof(x) == Expr && x.head == :call
x.args[1] = nameof(x.args[1])
reify!(x.args[2:end])
end
end
reify!(final_exprs)

deca_block = quote end
deca_block.args = [outer_types..., final_exprs...]

(infer_types!, SummationDecapode, parse_decapode)(deca_block)
end
5 changes: 5 additions & 0 deletions src/deca/Deca.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,17 @@ using DataStructures
using ..DiagrammaticEquations
using Catlab

using Reexport

import ..infer_types!, ..resolve_overloads!

export normalize_unicode, varname, infer_types!, resolve_overloads!, typename, spacename, recursive_delete_parents, recursive_delete_parents!, unicode!, op1_res_rules_1D, op2_res_rules_1D, op1_res_rules_2D, op2_res_rules_2D, op1_inf_rules_1D, op2_inf_rules_1D, op1_inf_rules_2D, op2_inf_rules_2D, vec_to_dec!

include("deca_acset.jl")
include("deca_visualization.jl")
include("ThDEC.jl")

@reexport using .ThDEC

""" function recursive_delete_parents!(d::SummationDecapode, to_delete::Vector{Int64})
Expand Down
Loading

0 comments on commit 2f82754

Please sign in to comment.