Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interop with SymbolicUtils #64

Merged
merged 35 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
c2d64b5
symbolicutils interop
olynch Aug 14, 2024
c11d4dd
added spaces to the sorts of forms and vector fields
olynch Aug 16, 2024
7de10e8
integrated changes to ThDEC with DecaSymbolic
olynch Aug 17, 2024
5928cc7
added tests for decasymbolic, but needs wrinkles ironed out. musical …
quffaro Aug 18, 2024
d7d6a5b
reverted src/DiagX to restore exports and adding Project.toml
quffaro Aug 19, 2024
2f79b0c
Merge remote-tracking branch 'origin/space-sorts' into symbolicutilsi…
quffaro Aug 19, 2024
a438192
updated code and tests after merge from space-sorts
quffaro Aug 19, 2024
18a1585
review changes:
jpfairbanks Aug 21, 2024
090ddfe
resolving some comments from code review.
quffaro Aug 22, 2024
d8be4ae
adding @alias and @register macros to make DecaSymbolic function work…
quffaro Aug 27, 2024
77770e5
experimenting with a type-driven approach
quffaro Aug 28, 2024
1fa477b
adding promote_symtype and addressing some of the code review comment…
quffaro Aug 30, 2024
3b265ac
refactoring @operator to integrate with promote_symtype
quffaro Sep 4, 2024
7f8597a
operator macro parses @rule but need to write tests and iron out rela…
quffaro Sep 5, 2024
154e51f
rewriting just needs tests
quffaro Sep 6, 2024
18bb71f
TST: add some klausmeier rewrites
jpfairbanks Sep 9, 2024
4ba6dce
BUG: fix method shadowing for existing operators like +/-
jpfairbanks Sep 9, 2024
82f65ce
added more tests, @operator macro is more flexible
quffaro Sep 10, 2024
e23459f
Merge pull request #73 from AlgebraicJulia/jpf/symbolictypes
quffaro Sep 10, 2024
e82e826
almost round-tripping in klausmeier. equations are not currently pass…
quffaro Sep 10, 2024
f32b3c0
Merge branch 'cm/symbolictypes' of github.com:AlgebraicJulia/Diagramm…
quffaro Sep 10, 2024
ecfa931
added rules function which dispatches on function symbol and the Val(…
quffaro Sep 10, 2024
4603ed9
fixed docs for @operator, fixed Term
quffaro Sep 11, 2024
f2ef1d0
Merge pull request #71 from AlgebraicJulia/cm/symbolictypes
quffaro Sep 11, 2024
5324de3
Loosened aqua tests
GeorgeR227 Sep 12, 2024
0a314a8
fixing bug where (+) always returns Scalar
quffaro Sep 18, 2024
5d5c25d
Expression level rewriting (#69)
GeorgeR227 Oct 3, 2024
70a420d
added tests for errors and consolidated errors with George
quffaro Oct 3, 2024
85572a4
Fixed hodge nameof and more tests
GeorgeR227 Oct 4, 2024
df41b18
Many more tests
GeorgeR227 Oct 4, 2024
6d1edf0
Merge branch 'main' into symbolicutilsinterop
jpfairbanks Oct 4, 2024
ccf0a79
More tests
GeorgeR227 Oct 5, 2024
467a93c
Remove unused functions
GeorgeR227 Oct 9, 2024
7c5155d
Fixed test in acset2symbolics
GeorgeR227 Oct 9, 2024
4f1f327
Added some improvements to naming
GeorgeR227 Oct 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,25 @@ name = "DiagrammaticEquations"
uuid = "6f00c28b-6bed-4403-80fa-30e0dc12f317"
license = "MIT"
authors = ["James Fairbanks", "Andrew Baas", "Evan Patterson", "Luke Morris", "George Rauta"]
version = "0.1.6"
version = "0.1.7"

[deps]
ACSets = "227ef7b5-1206-438b-ac65-934d6da304b8"
Catlab = "134e5e36-593f-5add-ad60-77f754baafbe"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
StructEquality = "6ec83bb0-ed9f-11e9-3b4c-2b04cb4e219c"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[compat]
ACSets = "0.2"
Catlab = "0.15, 0.16"
DataStructures = "0.18.13"
MLStyle = "0.4.17"
Reexport = "1.2.2"
StructEquality = "2.1.0"
SymbolicUtils = "3.1.2"
Unicode = "1.6"
julia = "1.6"
15 changes: 13 additions & 2 deletions src/DiagrammaticEquations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
"""
module DiagrammaticEquations

using Catlab

export
DerivOp, append_dot, normalize_unicode, infer_states, infer_types!,
# Deca
Expand All @@ -12,6 +14,7 @@ recursive_delete_parents, spacename, varname, unicode!, vec_to_dec!,
Collage, collate,
## composition
oapply, unique_by, unique_by!, OpenSummationDecapodeOb, OpenSummationDecapode, Open, default_composition_diagram,
apex, @relation, # Re-exported from Catlab
## acset
SchDecapode, SchNamedDecapode, AbstractDecapode, AbstractNamedDecapode, NamedDecapode, SummationDecapode,
contract_operators!, contract_operators, add_constant!, add_parameter, fill_names!, dot_rename!, is_expanded, expand_operators, infer_state_names, infer_terminal_names, recognize_types,
Expand All @@ -25,21 +28,23 @@ unique_lits!,
Plus, AppCirc1, Var, Tan, App1, App2,
## visualization
to_graphviz_property_graph, typename, draw_composition,
to_graphviz, # Re-exported from Catlab
## rewrite
average_rewrite,
## openoperators
transfer_parents!, transfer_children!, replace_op1!, replace_op2!, replace_all_op1s!, replace_all_op2s!

using Catlab
using Catlab.Theories
import Catlab.Theories: otimes, oplus, compose, ⊗, ⊕, ⋅, associate, associate_unit, Ob, Hom, dom, codom
using Catlab.Programs
using Catlab.CategoricalAlgebra
import Catlab.CategoricalAlgebra: ∧
using Catlab.WiringDiagrams
using Catlab.WiringDiagrams.DirectedWiringDiagrams
using Catlab.ACSetInterface
using MLStyle
import Unicode
using Reexport

## TODO:
## generate schema from a _theory_
Expand All @@ -59,9 +64,15 @@ include("rewrite.jl")
include("pretty.jl")
include("colanguage.jl")
include("openoperators.jl")
include("symbolictheoryutils.jl")
include("graph_traversal.jl")
include("deca/Deca.jl")
include("learn/Learn.jl")
include("SymbolicUtilsInterop.jl")

@reexport using .Deca
@reexport using .SymbolicUtilsInterop

using .Deca
include("acset2symbolic.jl")

end
157 changes: 157 additions & 0 deletions src/SymbolicUtilsInterop.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
module SymbolicUtilsInterop

using ACSets
using ..DiagrammaticEquations: AbstractDecapode, Quantity, DerivOp
using ..DiagrammaticEquations: recognize_types, fill_names!, make_sum_mult_unique!
import ..DiagrammaticEquations: eval_eq!, SummationDecapode
using ..decapodes
using ..Deca

using MLStyle
using SymbolicUtils
using SymbolicUtils: Symbolic, BasicSymbolic, FnType, Sym, symtype

# name collision with decapodes.Equation
struct SymbolicEquation{E}
lhs::E
rhs::E
end
export SymbolicEquation

Base.show(io::IO, e::SymbolicEquation) = print(io, "$(e.lhs) == $(e.rhs)")

Check warning on line 21 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L21

Added line #L21 was not covered by tests

## a struct carry the symbolic variables and their equations
struct SymbolicContext
vars::Vector{Symbolic}
equations::Vector{SymbolicEquation{Symbolic}}
end
export SymbolicContext

Base.show(io::IO, d::SymbolicContext) = begin
println(io, "SymbolicContext(")
println(io, " Variables: [$(join(d.vars, ", "))]")
println(io, " Equations: [")
eqns = map(d.equations) do op
" $(op)"

Check warning on line 35 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L30-L35

Added lines #L30 - L35 were not covered by tests
end
println(io, "$(join(eqns,",\n"))])")

Check warning on line 37 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L37

Added line #L37 was not covered by tests
end

## BasicSymbolic -> DecaExpr
function decapodes.Term(t::SymbolicUtils.BasicSymbolic)
if SymbolicUtils.issym(t)
decapodes.Var(nameof(t))
else
op = SymbolicUtils.head(t)
args = SymbolicUtils.arguments(t)
termargs = Term.(args)
if op == +
decapodes.Plus(termargs)
elseif op == *
decapodes.Mult(termargs)
elseif op ∈ [DerivOp, ∂ₜ]
decapodes.Tan(only(termargs))
elseif length(args) == 1
decapodes.App1(nameof(op, symtype.(args)...), termargs...)
elseif length(args) == 2
decapodes.App2(nameof(op, symtype.(args)...), termargs...)
else
error("was unable to convert $t into a Term")
end
end
end
# TODO subtraction is not parsed as such. e.g.,
# a, b = @syms a::Scalar b::Scalar
# Term(a-b) = Plus(Term[Var(:a), Mult(Term[Lit(Symbol("-1")), Var(:b)]))

decapodes.Term(x::Real) = decapodes.Lit(Symbol(x))

function decapodes.DecaExpr(d::SymbolicContext)
context = map(d.vars) do var
decapodes.Judgement(nameof(var), nameof(symtype(var)), :I)
end
equations = map(d.equations) do eq
decapodes.Eq(decapodes.Term(eq.lhs), decapodes.Term(eq.rhs))
end
decapodes.DecaExpr(context, equations)
end

"""
Retrieve the SymbolicUtils expression of a DecaExpr term `t` from a context of variables in ThDEC

Example:
```
a = @syms a::Real
context = Dict(:a => Scalar(), :u => PrimalForm(0))
SymbolicUtils.BasicSymbolic(context, Term(a))
```
"""
function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,DataType}, t::decapodes.Term)
# user must import symbols into scope
! = (f -> getfield(@__MODULE__, f))
@match t begin
Var(name) => SymbolicUtils.Sym{context[name]}(name)
Lit(v) => Meta.parse(string(v))
# see heat_eq test: eqs had AppCirc1, but this returns
# App1(f, App1(...)
AppCirc1(fs, arg) => foldr(

Check warning on line 97 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L97

Added line #L97 was not covered by tests
# panics with constants like :k
# see test/language.jl
(f, x) -> (!(f))(x),
fs;
init=BasicSymbolic(context, arg)
)
App1(f, x) => (!(f))(BasicSymbolic(context, x))
App2(f, x, y) => (!(f))(BasicSymbolic(context, x), BasicSymbolic(context, y))
Plus(xs) => +(BasicSymbolic.(Ref(context), xs)...)
Mult(xs) => *(BasicSymbolic.(Ref(context), xs)...)

Check warning on line 107 in src/SymbolicUtilsInterop.jl

View check run for this annotation

Codecov / codecov/patch

src/SymbolicUtilsInterop.jl#L107

Added line #L107 was not covered by tests
Tan(x) => (!(DerivOp))(BasicSymbolic(context, x))
end
end

function SymbolicContext(d::decapodes.DecaExpr)
# associates each var to its sort...
context = map(d.context) do j
j.var => symtype(Deca.DECQuantity, j.dim, j.space)
end
# ... which we then produce a vector of symbolic vars
vars = map(context) do (v, s)
SymbolicUtils.Sym{s}(v)
end
context = Dict{Symbol,DataType}(context)
eqs = map(d.equations) do eq
SymbolicEquation{Symbolic}(BasicSymbolic.(Ref(context), [eq.lhs, eq.rhs])...)
end
SymbolicContext(vars, eqs)
end

function eval_eq!(eq::SymbolicEquation, d::AbstractDecapode, syms::Dict{Symbol, Int}, deletions::Vector{Int})
eval_eq!(Eq(Term(eq.lhs), Term(eq.rhs)), d, syms, deletions)
end

""" function SummationDecapode(e::SymbolicContext) """
function SummationDecapode(e::SymbolicContext)
d = SummationDecapode{Any, Any, Symbol}()
symbol_table = Dict{Symbol, Int}()

foreach(e.vars) do var
# convert Sort(var)::PrimalForm0 --> :Form0
var_id = add_part!(d, :Var, name=var.name, type=nameof(symtype(var)))
symbol_table[var.name] = var_id
end

deletions = Vector{Int}()
foreach(e.equations) do eq
eval_eq!(eq, d, symbol_table, deletions)
end
rem_parts!(d, :Var, sort(deletions))

recognize_types(d)

fill_names!(d)
d[:name] = normalize_unicode.(d[:name])
make_sum_mult_unique!(d)
return d
end

end
47 changes: 22 additions & 25 deletions src/acset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,16 @@ end
# A collection of DecaType getters
# TODO: This should be replaced by using a type hierarchy
const ALL_TYPES = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2,
:Literal, :Parameter, :Constant, :infer]
:PVF, :DVF,
:Literal, :Parameter, :Constant, :infer]

const FORM_TYPES = [:Form0, :Form1, :Form2, :DualForm0, :DualForm1, :DualForm2]
const PRIMALFORM_TYPES = [:Form0, :Form1, :Form2]
const DUALFORM_TYPES = [:DualForm0, :DualForm1, :DualForm2]

const NONFORM_TYPES = [:Constant, :Parameter, :Literal, :infer]
const VECTORFIELD_TYPES = [:PVF, :DVF]

const NON_EC_TYPES = [:Constant, :Parameter, :Literal, :infer]
const USER_TYPES = [:Constant, :Parameter]
const NUMBER_TYPES = [:Literal]
const INFER_TYPES = [:infer]
Expand All @@ -184,6 +187,7 @@ function recognize_types(d::AbstractNamedDecapode)
isempty(unrecognized_types) ||
error("Types $unrecognized_types are not recognized. CHECK: $types")
end
export recognize_types

""" is_expanded(d::AbstractNamedDecapode)

Expand Down Expand Up @@ -427,12 +431,12 @@ function safe_modifytype!(d::SummationDecapode, var_idx::Int, new_type::Symbol)
end

"""
filterfor_forms(types::AbstractVector{Symbol})
filterfor_ec_types(types::AbstractVector{Symbol})

Return any form type symbols.
Return any form or vector-field type symbols.
"""
function filterfor_forms(types::AbstractVector{Symbol})
conditions = x -> !(x in NONFORM_TYPES)
function filterfor_ec_types(types::AbstractVector{Symbol})
conditions = x -> !(x in NON_EC_TYPES)
filter(conditions, types)
end

Expand All @@ -447,29 +451,26 @@ function infer_sum_types!(d::SummationDecapode, Σ_idx::Int)
types = d[idxs, :type]
all(t != :infer for t in types) && return applied # We need not infer

forms = unique(filterfor_forms(types))
ec_types = unique(filterfor_ec_types(types))

form = @match length(forms) begin
ec_type = @match length(ec_types) begin
0 => return applied # We can not infer
1 => only(forms)
_ => error("Type mismatch in summation $Σ_idx, all the following forms appear: $forms")
1 => only(ec_types)
_ => error("Type mismatch in summation $Σ_idx, all the following forms appear: $ec_types")
end

for idx in idxs
applied |= safe_modifytype!(d, idx, form)
applied |= safe_modifytype!(d, idx, ec_type)
end

return applied
end

function apply_inference_rule_op1!(d::SummationDecapode, op1_id, rule)
type_src = d[d[op1_id, :src], :type]
type_tgt = d[d[op1_id, :tgt], :type]
score_src = (rule.src_type == d[d[op1_id, :src], :type])
score_tgt = (rule.tgt_type == d[d[op1_id, :tgt], :type])

score_src = (rule.src_type == type_src)
score_tgt = (rule.tgt_type == type_tgt)
check_op = (d[op1_id, :op1] in rule.op_names)

if(check_op && (score_src + score_tgt == 1))
mod_src = safe_modifytype!(d, d[op1_id, :src], rule.src_type)
mod_tgt = safe_modifytype!(d, d[op1_id, :tgt], rule.tgt_type)
Expand All @@ -480,19 +481,15 @@ function apply_inference_rule_op1!(d::SummationDecapode, op1_id, rule)
end

function apply_inference_rule_op2!(d::SummationDecapode, op2_id, rule)
type_proj1 = d[d[op2_id, :proj1], :type]
type_proj2 = d[d[op2_id, :proj2], :type]
type_res = d[d[op2_id, :res], :type]
score_proj1 = (rule.proj1_type == d[d[op2_id, :proj1], :type])
score_proj2 = (rule.proj2_type == d[d[op2_id, :proj2], :type])
score_res = (rule.res_type == d[d[op2_id, :res], :type])

score_proj1 = (rule.proj1_type == type_proj1)
score_proj2 = (rule.proj2_type == type_proj2)
score_res = (rule.res_type == type_res)
check_op = (d[op2_id, :op2] in rule.op_names)

if(check_op && (score_proj1 + score_proj2 + score_res == 2))
if check_op && (score_proj1 + score_proj2 + score_res == 2)
mod_proj1 = safe_modifytype!(d, d[op2_id, :proj1], rule.proj1_type)
mod_proj2 = safe_modifytype!(d, d[op2_id, :proj2], rule.proj2_type)
mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type)
mod_res = safe_modifytype!(d, d[op2_id, :res], rule.res_type)
return mod_proj1 || mod_proj2 || mod_res
end

Expand Down
Loading
Loading