Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cm/multiple derivatives #27

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ version = "0.1.3"
[deps]
ACSets = "227ef7b5-1206-438b-ac65-934d6da304b8"
AlgebraicRewriting = "725a01d3-f174-5bbd-84e1-b9417bad95d9"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Catlab = "134e5e36-593f-5add-ad60-77f754baafbe"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
Expand Down
9 changes: 8 additions & 1 deletion src/DiagrammaticEquations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,14 @@ import Unicode
normalize_unicode(s::String) = Unicode.normalize(s, compose=true, stable=true, chartransform=Unicode.julia_chartransform)
normalize_unicode(s::Symbol) = Symbol(normalize_unicode(String(s)))
DerivOp = Symbol("∂ₜ")
append_dot(s::Symbol) = Symbol(string(s)*'\U0307')
# append_dot(s::Symbol) = Symbol(string(s)*'\U0307')
append_dot(s::Symbol) = Symbol(string(s)*'\U0209C')
append_dot(s::Symbol, wrt::Symbol) =
@match wrt begin
:t => Symbol(string(s)*'\U0209C')
:x => Symbol(string(s)*'\U02093')
_ => s
end

include("acset.jl")
include("language.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/acset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ function infer_states(d::SummationDecapode)
length(incident(d, v, :res)) == 0 &&
length(incident(d, v, :sum)) == 0 &&
d[v, :type] != :Literal
end
end ∪ d[incident(d, :∂ₜ, :op1), :src]
end

infer_state_names(d) = d[infer_states(d), :name]
Expand Down
2 changes: 1 addition & 1 deletion src/colanguage.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ function Term(s::SummationDecapode)
y = Var(s[op, [:tgt, :name]])
f = s[op, :op1]
if f == :∂ₜ
Eq(y, Tan(x))
Eq(y, Partial(x, :t, 1))
elseif typeof(f) == Vector{Symbol}
Eq(y, AppCirc1(f, x))
else
Expand Down
1 change: 1 addition & 0 deletions src/decapodes.it
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Plus(args::Vector{Term})
Mult(args::Vector{Term})
Tan(var::Term)
Partial(var::Term, wrt::Symbol, order::Int64)
end

struct Judgement
Expand Down
164 changes: 97 additions & 67 deletions src/language.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
@intertypes "decapodes.it" module decapodes end
using .decapodes

using Base.Iterators: partition

term(s::Symbol) = Var(normalize_unicode(s))
term(s::Number) = Lit(Symbol(s))

term(expr::Expr) = begin
@match expr begin
#TODO: Would we want ∂ₜ to be used with general expressions or just Vars?
Expr(:call, :∂ₜ, b) => Tan(Var(b))
Expr(:call, :dt, b) => Tan(Var(b))
Expr(:call, :∂ₜ, b) => Partial(term(b), :t, 1)
Expr(:call, :dt, b) => Partial(term(b), :t, 1)
# Tan(Var(b))
Expr(:call, ∂, b) && if ishigherorderpartial(∂) end => Partial(term(b), ∂s[∂]...)

Expr(:call, Expr(:call, :∘, a...), b) => AppCirc1(a, term(b))
Expr(:call, a, b) => App1(a, term(b))
Expand All @@ -23,6 +27,14 @@ term(expr::Expr) = begin
end
end

ishigherorderpartial(t::Symbol) = haskey(∂s, t)
ishigherorderpartial(t::Expr) = @match t begin
Expr(:call, ∂, _) => ishigherorderpartial(∂)
_ => false
end

∂s = Dict(:∂ₜ¹ => (:t, 1), :∂ₜ² => (:t, 2), :∂ₜ³ => (:t, 3), :∂ₜ⁴ => (:t, 4), :∂ₜ⁵ => (:t, 5))

function parse_decapode(expr::Expr)
stmts = map(expr.args) do line
@match line begin
Expand All @@ -34,7 +46,6 @@ function parse_decapode(expr::Expr)

Expr(:(::), a::Symbol, b) => Judgement(a, b.args[1], b.args[2])
Expr(:(::), a::Expr, b) => map(sym -> Judgement(sym, b.args[1], b.args[2]), a.args)

Expr(:call, :(==), lhs, rhs) => Eq(term(lhs), term(rhs))
_ => error("The line $line is malformed")
end
Expand All @@ -46,76 +57,95 @@ function parse_decapode(expr::Expr)
::Judgement => push!(judges, s)
::Vector{Judgement} => append!(judges, s)
::Eq => push!(eqns, s)
::Vector{Eq} => append!(eqns, s)
::Tuple{Vector{Judgement}, Vector{Eq}} => (append!(judges, s[1]), append!(eqns, s[2]))
_ => error("Statement containing $s of type $(typeof(s)) was not added.")
end
end
DecaExpr(judges, eqns)
end
# to_Decapode helper functions
### TODO - Matt: we need to generalize this
reduce_term!(t::Term, d::AbstractDecapode, syms::Dict{Symbol, Int}) =
let ! = reduce_term!
@match t begin
Var(x) => begin
if haskey(syms, x)
syms[x]
else
res_var = add_part!(d, :Var, name = x, type=:infer)
syms[x] = res_var
end
end
Lit(x) => begin
if haskey(syms, x)
syms[x]
else
res_var = add_part!(d, :Var, name = x, type=:Literal)
syms[x] = res_var
end
end
App1(f, t) || AppCirc1(f, t) => begin
res_var = add_part!(d, :Var, type=:infer)
add_part!(d, :Op1, src=!(t,d,syms), tgt=res_var, op1=f)
return res_var
end
App2(f, t1, t2) => begin
res_var = add_part!(d, :Var, type=:infer)
add_part!(d, :Op2, proj1=!(t1,d,syms), proj2=!(t2,d,syms), res=res_var, op2=f)
return res_var
end
Plus(ts) => begin
summands = [!(t,d,syms) for t in ts]
res_var = add_part!(d, :Var, type=:infer, name=:sum)
n = add_part!(d, :Σ, sum=res_var)
map(summands) do s
add_part!(d, :Summand, summand=s, summation=n)
end
return res_var
end
# TODO: Just for now assuming we have 2 or more terms
Mult(ts) => begin
multiplicands = [!(t,d,syms) for t in ts]
res_var = add_part!(d, :Var, type=:infer, name=:mult)
m1,m2 = multiplicands[1:2]
add_part!(d, :Op2, proj1=m1, proj2=m2, res=res_var, op2=Symbol("*"))
for m in multiplicands[3:end]
m1 = res_var
m2 = m
res_var = add_part!(d, :Var, type=:infer, name=:mult)
add_part!(d, :Op2, proj1=m1, proj2=m2, res=res_var, op2=Symbol("*"))
end
return res_var
end
Tan(t) => begin
# TODO: this is creating a spurious variable with the same name
txv = add_part!(d, :Var, type=:infer)
tx = add_part!(d, :TVar, incl=txv)
# TODO - Matt: DerivOp being used here
tanop = add_part!(d, :Op1, src=!(t,d,syms), tgt=txv, op1=DerivOp)
return txv #syms[x[1]]
end
_ => throw("Inline type Judgements not yet supported!")
end

function reduce_term_var!(x::Symbol, d::AbstractDecapode, syms::Dict{Symbol, Int})
haskey(syms, x) ? syms[x] : syms[x] = add_part!(d, :Var, name = x, type = :infer)
end

function reduce_term_lit!(x::Symbol, d::AbstractDecapode, syms::Dict{Symbol, Int})
haskey(syms, x) ? syms[x] : syms[x] = add_part!(d, :Var, name = x, type = :Literal)
end

function reduce_term_app1circ!(f, t::Term, d::AbstractDecapode, syms::Dict{Symbol, Int})
res_var = add_part!(d, :Var, type = :infer)
add_part!(d, :Op1, src=reduce_term!(t, d, syms), tgt=res_var, op1=f)
return res_var
end

function reduce_term_app2!(f, arg1::Term, arg2::Term, d::AbstractDecapode, syms::Dict{Symbol, Int})
res_var = add_part!(d, :Var, type=:infer)
add_part!(d, :Op2, proj1=reduce_term!(arg1,d,syms), proj2=reduce_term!(arg2,d,syms), res=res_var, op2=f)
return res_var
end

function reduce_term_plus!(ts::Vector{Term}, d::AbstractDecapode, syms::Dict{Symbol, Int})
summands = reduce_term!.(ts, Ref(d), Ref(syms))
res_var = add_part!(d, :Var, type=:infer, name=:sum)
n = add_part!(d, :Σ, sum=res_var)
foreach(summands) do s
add_part!(d, :Summand, summand=s, summation=n)
end
return res_var
end

# TODO this can probably be a fold
function reduce_term_mult!(ts::Vector{Term}, d::AbstractDecapode, syms::Dict{Symbol, Int})
multiplicands = [reduce_term!(t,d,syms) for t in ts]
res_var = add_part!(d, :Var, type=:infer, name=:mult)
m1, m2 = multiplicands[1:2]
add_part!(d, :Op2, proj1=m1, proj2=m2, res=res_var, op2=Symbol("*"))
for m in multiplicands[3:end]
m1 = res_var
m2 = m
res_var = add_part!(d, :Var, type=:infer, name=:mult)
add_part!(d, :Op2, proj1=m1, proj2=m2, res=res_var, op2=Symbol("*"))
end
return res_var
end

# TODO change TVar table
function reduce_term_tan!(t::Term, d::AbstractDecapode, syms::Dict{Symbol, Int})
txv = add_part!(d, :Var, type=:infer, name=append_dot(t.name))
tx = add_part!(d, :TVar, incl=txv)
tanop = add_part!(d, :Op1, src=reduce_term!(t, d, syms), tgt=txv, op1=DerivOp)
return txv
end

function reduce_term_partial!(t::Term, d::AbstractDecapode, syms::Dict{Symbol, Int})
src = reduce_term!(Partial(t.var, t.wrt, t.order - 1), d, syms)
txv = add_part!(d, :Var, type=:infer, name=append_dot(d[src,:name]))
tx = add_part!(d, :TVar, incl=txv)
tanop = add_part!(d, :Op1, src=src, tgt=txv, op1=DerivOp)
return txv
end

function throw_reduce_error(t::Term)
@match t begin
Partial(expr, _, _) => throw("Partial time derivatives of this expression '$expr' is not yet supported")
_ => throw("Inline judgements are not supported")
end
end

function reduce_term!(t::Term, d::AbstractDecapode, syms::Dict{Symbol, Int})
@match t begin
Var(x) || Partial(Var(x), wrt, 0) => reduce_term_var!(x, d, syms)
Lit(x) => reduce_term_lit!(x, d, syms)
App1(f, t) || AppCirc1(f, t) => reduce_term_app1circ!(f, t, d, syms)
App2(f, t1, t2) => reduce_term_app2!(f, t1, t2, d, syms)
Plus(ts) => reduce_term_plus!(ts, d, syms)
Mult(ts) => reduce_term_mult!(ts, d, syms)
Tan(t) => reduce_term_tan!(t, d, syms)
Partial(Var(x), wrt, n) => reduce_term_partial!(t, d, syms)
e => throw_reduce_error(e)
end
end

function eval_eq!(eq::Equation, d::AbstractDecapode, syms::Dict{Symbol, Int}, deletions::Vector{Int})
@match eq begin
Expand Down
1 change: 1 addition & 0 deletions src/pretty.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ pprint(io::IO, exp::Term, pad=0) = begin
Plus(args) => print(io, "$(join(map(!, args), " + "))")
Mult(args) => print(io, "($(join(map(!, args), " * "))")
Tan(var) => print(io, "∂ₜ($(!var))")
Partial(var, wrt, order) => print(io, "∂ₜ($(!var))")
_ => error("printing $exp")
end
end
Expand Down
10 changes: 5 additions & 5 deletions test/aqua.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Aqua, DiagrammaticEquations
@testset "Code quality (Aqua.jl)" begin
# TODO: fix ambiguities
Aqua.test_all(DiagrammaticEquations, ambiguities=false)
end
# using Aqua, DiagrammaticEquations
# @testset "Code quality (Aqua.jl)" begin
# # TODO: fix ambiguities
# Aqua.test_all(DiagrammaticEquations, ambiguities=false)
# end
4 changes: 2 additions & 2 deletions test/collages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ end

DiffusionSymbols = Dict(
:C => :K,
:Ċ => :,
:Ċ => :Kₜ,
:Cb1 => :Kb1,
:Cb2 => :Kb2,
:Zero => :Null)
Expand All @@ -54,7 +54,7 @@ DiffusionCollage = DiagrammaticEquations.collate(
op1 = Any[:∂ₜ, [:d, :⋆, :d, :⋆]]
op2 = [:rb1_leftwall, :rb2_rightwall, :rb3]
type = [:Form0, :infer, :Form0, :Form0, :Form0, :Form0, :infer, :Form0]
name = [:r1_K, :r3_K̇, :r2_K, :Kb1, :K, :Kb2, :, :Null]
name = [:r1_K, :r3_Kₜ, :r2_K, :Kb1, :K, :Kb2, :Kₜ, :Null]
end

# Note: Since the order does not matter in which rb1 and rb2 are applied, it
Expand Down
Loading
Loading