diff --git a/DEC/Project.toml b/DEC/Project.toml index 63e3d73..f8e2621 100644 --- a/DEC/Project.toml +++ b/DEC/Project.toml @@ -4,18 +4,28 @@ authors = ["Owen Lynch "] version = "0.1.0" [deps] +CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" +CombinatorialSpaces = "b1c52339-7909-45ad-8b6a-6e388f7c67f2" Crayons = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" +Decapodes = "679ab3ea-c928-4fe6-8d59-fd451142d391" +GeometryBasics = "5c1252a2-5f33-56bf-86c9-59e7332b4326" MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" Metatheory = "e9d8d322-4543-424a-9be4-0cc815abe26c" +OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" StructEquality = "6ec83bb0-ed9f-11e9-3b4c-2b04cb4e219c" [compat] +CairoMakie = "0.12.5" Colors = "0.12.11" +CombinatorialSpaces = "0.6.7" Crayons = "4.1.1" +Decapodes = "0.5.5" +GeometryBasics = "0.4.11" MLStyle = "0.4.17" +OrdinaryDiffEq = "6.86.0" Random = "1.11.0" Reexport = "1.2.2" StructEquality = "2.1.0" diff --git a/DEC/src/DEC.jl b/DEC/src/DEC.jl index 496437a..432a258 100644 --- a/DEC/src/DEC.jl +++ b/DEC/src/DEC.jl @@ -10,356 +10,11 @@ import Base: +, - import Base: * include("HashColor.jl") +include("Signature.jl") +include("Roe.jl") include("SSAExtract.jl") +include("Luke.jl") @reexport using .SSAExtract -@data Sort begin - Scalar() - Form(dim::Int, isdual::Bool) -end -export Scalar, Form - -duality(f::Form) = f.isdual ? "dual" : "primal" - -PrimalForm(i::Int) = Form(i, false) -export PrimalForm - -DualForm(i::Int) = Form(i, true) -export DualForm - -struct SortError <: Exception - message::String -end - -@nospecialize -function +(s1::Sort, s2::Sort) - @match (s1, s2) begin - (Scalar(), Scalar()) => Scalar() - (Scalar(), Form(i, isdual)) || (Form(i, isdual), Scalar()) => Form(i, isdual) - (Form(i1, isdual1), Form(i2, isdual2)) => - if (i1 == i2) && (isdual1 == isdual2) - Form(i1, isdual1) - else - throw(SortError("Cannot add two forms of different dimensions/dualities: $((i1,isdual1)) and $((i2,isdual2))")) - end - end -end - --(s1::Sort, s2::Sort) = +(s1, s2) - --(s::Sort) = s - -@nospecialize -function *(s1::Sort, s2::Sort) - @match (s1, s2) begin - (Scalar(), Scalar()) => Scalar() - (Scalar(), Form(i, isdual)) || (Form(i, isdual), Scalar()) => Form(i, isdual) - (Form(_, _), Form(_, _)) => throw(SortError("Cannot scalar multiply a form with a form. Maybe try `∧`??")) - end -end - -@nospecialize -function ∧(s1::Sort, s2::Sort) - @match (s1, s2) begin - (_, Scalar()) || (Scalar(), _) => throw(SortError("Cannot take a wedge product with a scalar")) - (Form(i1, isdual1), Form(i2, isdual2)) => - if isdual1 == isdual2 - if i1 + i2 <= 2 - Form(i1 + i2, isdual1) - else - throw(SortError("Can only take a wedge product when the dimensions of the forms add to less than 2: tried to wedge product $i1 and $i2")) - end - else - throw(SortError("Cannot wedge two forms of different dualities: attempted to wedge $(duality(s1)) and $(duality(s2))")) - end - end -end - -@nospecialize -∂ₜ(s::Sort) = s - -@nospecialize -function d(s::Sort) - @match s begin - Scalar() => throw(SortError("Cannot take exterior derivative of a scalar")) - Form(i, isdual) => - if i <= 1 - Form(i + 1, isdual) - else - throw(SortError("Cannot take exterior derivative of a n-form for n >= 1")) - end - end -end - -function ★(s::Sort) - @match s begin - Scalar() => throw(SortError("Cannot take Hodge star of a scalar")) - Form(i, isdual) => Form(2 - i, !isdual) - end -end - -@struct_hash_equal struct RootVar - name::Symbol - idx::Int - sort::Sort -end - -function rootvarcrayon(v::RootVar) - lightnessrange = (50., 100.) - HashColor.hashcrayon(v.idx; lightnessrange, chromarange=(50., 100.)) -end - -function Base.show(io::IO, v::RootVar) - if get(io, :color, true) - crayon = rootvarcrayon(v) - print(io, crayon, "$(v.name)") - print(io, inv(crayon)) - else - print(io, "$(v.name)%$(v.idx)") - end -end - -struct Decapode - variables::Vector{RootVar} - graph::EGraph{Expr, Sort} - function Decapode() - new(RootVar[], EGraph{Expr, Sort}()) - end -end - -function EGraphs.make(g::EGraph{Expr, Sort}, n::Metatheory.VecExpr) - op = EGraphs.get_constant(g, Metatheory.v_head(n)) - if op isa RootVar - op.sort - else - op((g[arg].data for arg in Metatheory.v_children(n))...) - end -end - -struct Var{S} - pode::Decapode - id::Id -end - -function extract!(v::Var, f=EGraphs.astsize) - extract!(v.pode.graph, f, v.id) -end - -function fix_functions(e) - @match e begin - s::Symbol => s - Expr(:call, f::Function, args...) => - Expr(:call, nameof(f), fix_functions.(args)...) - Expr(head, args...) => - Expr(head, fix_functions.(args)...) - _ => e - end -end - -function getexpr(v::Var) - e = EGraphs.extract!(v.pode.graph, Metatheory.astsize, v.id) - fix_functions(e) -end - -function Base.show(io::IO, v::Var) - print(io, getexpr(v)) -end - -function fresh!(pode::Decapode, sort::Sort, name::Symbol) - v = RootVar(name, length(pode.variables), sort) - push!(pode.variables, v) - n = Metatheory.v_new(0) - Metatheory.v_set_head!(n, EGraphs.add_constant!(pode.graph, v)) - Metatheory.v_set_signature!(n, hash(Metatheory.maybe_quote_operation(v), hash(0))) - Var{sort}(pode, EGraphs.add!(pode.graph, n, false)) -end - -@nospecialize -function inject_number!(pode::Decapode, x::Number) - x = Float64(x) - n = Metatheory.v_new(0) - Metatheory.v_set_head!(n, EGraphs.add_constant!(pode.graph, x)) - Metatheory.v_set_signature!(n, hash(Metatheory.maybe_quote_operation(x), hash(0))) - Var{Scalar()}(pode, EGraphs.add!(pode.graph, n, false)) -end - -@nospecialize -function addcall!(g::EGraph, head, args) - ar = length(args) - n = Metatheory.v_new(ar) - Metatheory.v_set_flag!(n, VECEXPR_FLAG_ISTREE) - Metatheory.v_set_flag!(n, VECEXPR_FLAG_ISCALL) - Metatheory.v_set_head!(n, EGraphs.add_constant!(g, head)) - Metatheory.v_set_signature!(n, hash(Metatheory.maybe_quote_operation(head), hash(ar))) - for i in Metatheory.v_children_range(n) - @inbounds n[i] = args[i - VECEXPR_META_LENGTH] - end - EGraphs.add!(g, n, false) -end - -@nospecialize -function +(v1::Var{s1}, v2::Var{s2}) where {s1, s2} - v1.pode === v2.pode || error("Cannot add variables from different graphs") - s = s1 + s2 - Var{s}(v1.pode, addcall!(v1.pode.graph, +, (v1.id, v2.id))) -end - -@nospecialize -+(v::Var, x::Number) = +(v, inject_number!(v.pode, x)) - -@nospecialize -+(x::Number, v::Var) = +(inject_number!(v.pode, x), v) - -@nospecialize -function -(v1::Var{s1}, v2::Var{s2}) where {s1, s2} - v1.pode == v2.pode || error("Cannot subtract variables from different graphs") - s = s1 - s2 - Var{s}(v1.pode, addcall!(v1.pode.graph, -, (v1.id, v2.id))) -end - -@nospecialize --(v::Var{s}) where {s} = Var{s}(v.pode, addcall!(v.pode.graph, -, (v.id,))) - -@nospecialize --(v::Var, x::Number) = -(v, inject_number!(v.pode, x)) - -@nospecialize --(x::Number, v::Var) = -(inject_number!(v.pode, x), v) - -@nospecialize -function *(v1::Var{s1}, v2::Var{s2}) where {s1, s2} - v1.pode === v2.pode || error("Cannot multiply variables from different graphs") - s = s1 * s2 - Var{s}(v1.pode, addcall!(v1.pode.graph, *, (v1.id, v2.id))) -end - -@nospecialize -*(v::Var, x::Number) = *(v, inject_number!(v.pode, x)) - -@nospecialize -*(x::Number, v::Var) = *(inject_number!(v.pode, x), v) - -@nospecialize -function ∧(v1::Var{s1}, v2::Var{s2}) where {s1, s2} - v1.pode === v2.pode || error("Cannot wedge variables from different graphs") - s = s1 ∧ s2 - Var{s}(v1.pode, addcall!(v1.pode.graph, ∧, (v1.id, v2.id))) -end - -@nospecialize -function ∂ₜ(v::Var{s}) where {s} - Var{s}(v.pode, addcall!(v.pode.graph, ∂ₜ, (v.id,))) -end - -@nospecialize -function d(v::Var{s}) where {s} - s′ = d(s) - Var{s′}(v.pode, addcall!(v.pode.graph, d, (v.id,))) -end - - -@nospecialize -function ★(v::Var{s}) where {s} - s′ = ★(s) - Var{s′}(v.pode, addcall!(v.pode.graph, ★, (v.id,))) -end - -Δ(v::Var{PrimalForm(0)}) = ★(d(★(d(v)))) - -function equate!(v1::Var{s1}, v2::Var{s2}) where {s1, s2} - (s1 == s2) || throw(SortError("Cannot equate variables of a different sort: attempted to equate $s1 with $s2")) - v1.pode === v2.pode || error("Cannot equate variables from different graphs") - union!(v1.pode.graph, v1.id, v2.id) -end - -≐(v1::Var, v2::Var) = equate!(v1, v2) - -@nospecialize -function derivative_cost(allowed_roots) - function cost(n::Metatheory.VecExpr, op, costs) - if op == ∂ₜ || (op isa RootVar && op ∉ allowed_roots) - Inf - else - Metatheory.astsize(n, op, costs) - end - end -end - -struct TypedApplication - fn::Function - sorts::Vector{Sort} -end - -""" vfield :: (Decapode -> (StateVars, ParamVars)) -> VectorFieldFunction - -Short for "vector field." Obtains tuple of root vars from a model, where the first component are state variables and the second are parameter variables. - -Example: given a diffusivity constant a, the heat equation can be written as: -``` - ∂ₜ u = a * Laplacian(u) -``` -would return (u, a). - -A limitation of this function can be demonstrated here: given the model - ``` - ∂ₜ a = a + b - ∂ₜ b = a + b - ``` - we would return ([a, b],). Suppose we wanted to extract terms of the form "a + b." Since the expression "a + b" is not a RootVar, - the extractor would bypass it completely. -""" -function vfield(model, matrices::Dict{TypedApplication, Any}) - pode = Decapode() - (state_vars, param_vars) = model(pode) - length(state_vars) >= 1 || error("need at least one state variable in order to create vector field") - state_rootvars = map(state_vars) do x - rv = extract!(x) - rv isa RootVar || error("all state variables must be RootVars") - rv - end - param_rootvars = map(param_vars) do p - rv = extract!(p) - rv isa RootVar || error("all param variables must be RootVars") - rv - end - cost = derivative_cost(Set([state_rootvars; param_rootvars])) - - u = :u - p = :p - du = :du - - rootvar_lookup = - Dict{RootVar, Expr}( - [ - [rv => :(@inbounds $(u)[$i]) for (i, rv) in enumerate(state_rootvars)]; - [rv => :(@inbounds $(p)[$i]) for (i, rv) in enumerate(param_rootvars)] - ] - ) - - derivative_exprs = map(enumerate(state_vars)) do (i, v) - e = extract!(∂ₜ(v), cost) - :(@inbounds $(du)[$i] = $(replace_rootvars(e, rootvar_lookup))) - end - - - - eval( - quote - ($du, $u, $p, _) -> begin - $(derivative_exprs...) - $du - end - end - ) -end - -function replace_rootvars(e, rootvar_lookup::Dict{RootVar, Expr}) - @match e begin - (rv::RootVar) => rootvar_lookup[rv] - Expr(head, args...) => Expr(head, replace_rootvars.(args, Ref(rootvar_lookup))...) - _ => e - end -end - end # module DEC diff --git a/DEC/src/Luke.jl b/DEC/src/Luke.jl new file mode 100644 index 0000000..31a14e9 --- /dev/null +++ b/DEC/src/Luke.jl @@ -0,0 +1,75 @@ +import Decapodes +using StructEquality + +@struct_hash_equal struct TypedApplication + fn::Function + sorts::Vector{Sort} +end + +const TA = TypedApplication + +function Base.show(io::IOBuffer, ta::TA) + print(io, Expr(:call, ta.fn, ta.sorts...)) +end + +function precompute_matrices(sd, hodge)::Dict{TypedApplication, Any} + Dict{TypedApplication, Any}( + # Regular Hodge Stars + TA(★, Sort[PrimalForm(0)]) => Decapodes.dec_mat_hodge(0, sd, hodge), + TA(★, Sort[PrimalForm(1)]) => Decapodes.dec_mat_hodge(1, sd, hodge), + TA(★, Sort[PrimalForm(2)]) => Decapodes.dec_mat_hodge(2, sd, hodge), + + # Inverse Hodge Stars + TA(★, Sort[DualForm(0)]) => Decapodes.dec_mat_inverse_hodge(1, sd, hodge), # why is this 1??? + TA(★, Sort[DualForm(1)]) => Decapodes.dec_pair_inv_hodge(Val{1}, sd, hodge), # Special since Geo is a solver + TA(★, Sort[DualForm(2)]) => Decapodes.dec_mat_inverse_hodge(0, sd, hodge), + + # Differentials + TA(d, Sort[PrimalForm(0)]) => Decapodes.dec_mat_differential(0, sd), + TA(d, Sort[PrimalForm(1)]) => Decapodes.dec_mat_differential(1, sd), + + # Dual Differentials + TA(d, Sort[DualForm(0)]) => Decapodes.dec_mat_dual_differential(0, sd), + TA(d, Sort[DualForm(1)]) => Decapodes.dec_mat_dual_differential(1, sd), + + # Wedge Products + TA(∧, Sort[PrimalForm(0), PrimalForm(1)]) => Decapodes.dec_pair_wedge_product(Tuple{0,1}, sd), + TA(∧, Sort[PrimalForm(1), PrimalForm(0)]) => Decapodes.dec_pair_wedge_product(Tuple{1,0}, sd), + TA(∧, Sort[PrimalForm(0), PrimalForm(2)]) => Decapodes.dec_pair_wedge_product(Tuple{0,2}, sd), + TA(∧, Sort[PrimalForm(2), PrimalForm(0)]) => Decapodes.dec_pair_wedge_product(Tuple{2,0}, sd), + TA(∧, Sort[PrimalForm(1), PrimalForm(1)]) => Decapodes.dec_pair_wedge_product(Tuple{1,1}, sd), + + # Primal-Dual Wedge Products + TA(∧, Sort[PrimalForm(1), DualForm(1)]) => Decapodes.dec_wedge_product_pd(Tuple{1,1}, sd), + TA(∧, Sort[PrimalForm(0), DualForm(1)]) => Decapodes.dec_wedge_product_pd(Tuple{0,1}, sd), + TA(∧, Sort[PrimalForm(1), DualForm(1)]) => Decapodes.dec_wedge_product_dp(Tuple{1,1}, sd), + TA(∧, Sort[PrimalForm(1), DualForm(0)]) => Decapodes.dec_wedge_product_dp(Tuple{1,0}, sd), + + # Dual-Dual Wedge Products + # TA(∧, Sort[DualForm(1), DualForm(1)]) => Decapodes.dec_wedge_product_dd(Tuple{1,1}, sd), + TA(∧, Sort[DualForm(1), DualForm(0)]) => Decapodes.dec_wedge_product_dd(Tuple{1,0}, sd), + TA(∧, Sort[DualForm(0), DualForm(1)]) => Decapodes.dec_wedge_product_dd(Tuple{0,1}, sd), + + # # Dual-Dual Interior Products + # :ι₁₁ => interior_product_dd(Tuple{1,1}, sd) + # :ι₁₂ => interior_product_dd(Tuple{1,2}, sd) + + # # Dual-Dual Lie Derivatives + # :ℒ₁ => ℒ_dd(Tuple{1,1}, sd) + + # # Dual Laplacians + # :Δᵈ₀ => Δᵈ(Val{0},sd) + # :Δᵈ₁ => Δᵈ(Val{1},sd) + + # # Musical Isomorphisms + # :♯ => Decapodes.dec_♯_p(sd) + # :♯ᵈ => Decapodes.dec_♯_d(sd) + + # :♭ => Decapodes.dec_♭(sd) + + # # Averaging Operator + # :avg₀₁ => Decapodes.dec_avg₀₁(sd) + + # :neg => x -> -1 .* x + ) +end \ No newline at end of file diff --git a/DEC/src/Roe.jl b/DEC/src/Roe.jl index e69de29..8869b81 100644 --- a/DEC/src/Roe.jl +++ b/DEC/src/Roe.jl @@ -0,0 +1,292 @@ +@struct_hash_equal struct RootVar + name::Symbol + idx::Int + sort::Sort +end + +struct Roe + variables::Vector{RootVar} + graph::EGraph{Expr, Sort} + function Roe() + new(RootVar[], EGraph{Expr, Sort}()) + end +end + +struct Var{S} + roe::Roe + id::Id +end + +function EGraphs.make(g::EGraph{Expr, Sort}, n::Metatheory.VecExpr) + op = EGraphs.get_constant(g,Metatheory.v_head(n)) + if op isa RootVar + op.sort + elseif op isa Number + Scalar() + else + op((g[arg].data for arg in Metatheory.v_children(n))...) + end +end + +function EGraphs.join(s1::Sort, s2::Sort) + if s1 == s2 + s1 + else + error("Cannot equate two nodes with different sorts") + end +end + +function extract!(v::Var, f=EGraphs.astsize) + extract!(v.roe.graph, f, v.id) +end + +function rootvarcrayon(v::RootVar) + lightnessrange = (50., 100.) + HashColor.hashcrayon(v.idx; lightnessrange, chromarange=(50., 100.)) +end + +function Base.show(io::IO, v::RootVar) + if get(io, :color, true) + crayon = rootvarcrayon(v) + print(io, crayon, "$(v.name)") + print(io, inv(crayon)) + else + print(io, "$(v.name)#$(v.idx)") + end +end + +function fix_functions(e) + @match e begin + s::Symbol => s + Expr(:call, f::Function, args...) => + Expr(:call, nameof(f), fix_functions.(args)...) + Expr(head, args...) => + Expr(head, fix_functions.(args)...) + _ => e + end +end + +function getexpr(v::Var) + e = EGraphs.extract!(v.roe.graph, Metatheory.astsize, v.id) + fix_functions(e) +end + +function Base.show(io::IO, v::Var) + print(io, getexpr(v)) +end + +function fresh!(roe::Roe, sort::Sort, name::Symbol) + v = RootVar(name, length(roe.variables), sort) + push!(roe.variables, v) + n = Metatheory.v_new(0) + Metatheory.v_set_head!(n, EGraphs.add_constant!(roe.graph, v)) + Metatheory.v_set_signature!(n, hash(Metatheory.maybe_quote_operation(v), hash(0))) + Var{sort}(roe, EGraphs.add!(roe.graph, n, false)) +end + +@nospecialize +function inject_number!(roe::Roe, x::Number) + x = Float64(x) + n = Metatheory.v_new(0) + Metatheory.v_set_head!(n, EGraphs.add_constant!(roe.graph, x)) + Metatheory.v_set_signature!(n, hash(Metatheory.maybe_quote_operation(x), hash(0))) + Var{Scalar()}(roe, EGraphs.add!(roe.graph, n, false)) +end + +@nospecialize +function addcall!(g::EGraph, head, args) + ar = length(args) + n = Metatheory.v_new(ar) + Metatheory.v_set_flag!(n, VECEXPR_FLAG_ISTREE) + Metatheory.v_set_flag!(n, VECEXPR_FLAG_ISCALL) + Metatheory.v_set_head!(n, EGraphs.add_constant!(g, head)) + Metatheory.v_set_signature!(n, hash(Metatheory.maybe_quote_operation(head), hash(ar))) + for i in Metatheory.v_children_range(n) + @inbounds n[i] = args[i - VECEXPR_META_LENGTH] + end + EGraphs.add!(g, n, false) +end + +function equate!(v1::Var{s1}, v2::Var{s2}) where {s1, s2} + (s1 == s2) || throw(SortError("Cannot equate variables of a different sort: attempted to equate $s1 with $s2")) + v1.roe === v2.roe || error("Cannot equate variables from different graphs") + union!(v1.roe.graph, v1.id, v2.id) +end + +≐(v1::Var, v2::Var) = equate!(v1, v2) + +@nospecialize +function derivative_cost(allowed_roots) + function cost(n::Metatheory.VecExpr, op, costs) + if op == ∂ₜ || (op isa RootVar && op ∉ allowed_roots) + Inf + else + Metatheory.astsize(n, op, costs) + end + end +end + + +@nospecialize +function +(v1::Var{s1}, v2::Var{s2}) where {s1, s2} + v1.roe === v2.roe || error("Cannot add variables from different graphs") + s = s1 + s2 + Var{s}(v1.roe, addcall!(v1.roe.graph, +, (v1.id, v2.id))) +end + +@nospecialize ++(v::Var, x::Number) = +(v, inject_number!(v.roe, x)) + +@nospecialize ++(x::Number, v::Var) = +(inject_number!(v.roe, x), v) + +@nospecialize +function -(v1::Var{s1}, v2::Var{s2}) where {s1, s2} + v1.roe == v2.roe || error("Cannot subtract variables from different graphs") + s = s1 - s2 + Var{s}(v1.roe, addcall!(v1.roe.graph, -, (v1.id, v2.id))) +end + +@nospecialize +-(v::Var{s}) where {s} = Var{s}(v.roe, addcall!(v.roe.graph, -, (v.id,))) + +@nospecialize +-(v::Var, x::Number) = -(v, inject_number!(v.roe, x)) + +@nospecialize +-(x::Number, v::Var) = -(inject_number!(v.roe, x), v) + +@nospecialize +function *(v1::Var{s1}, v2::Var{s2}) where {s1, s2} + v1.roe === v2.roe || error("Cannot multiply variables from different graphs") + s = s1 * s2 + Var{s}(v1.roe, addcall!(v1.roe.graph, *, (v1.id, v2.id))) +end + +@nospecialize +*(v::Var, x::Number) = *(v, inject_number!(v.roe, x)) + +@nospecialize +*(x::Number, v::Var) = *(inject_number!(v.roe, x), v) + +@nospecialize +function ∧(v1::Var{s1}, v2::Var{s2}) where {s1, s2} + v1.roe === v2.roe || error("Cannot wedge variables from different graphs") + s = s1 ∧ s2 + Var{s}(v1.roe, addcall!(v1.roe.graph, ∧, (v1.id, v2.id))) +end + +@nospecialize +function ∂ₜ(v::Var{s}) where {s} + Var{s}(v.roe, addcall!(v.roe.graph, ∂ₜ, (v.id,))) +end + +@nospecialize +function d(v::Var{s}) where {s} + s′ = d(s) + Var{s′}(v.roe, addcall!(v.roe.graph, d, (v.id,))) +end + + +@nospecialize +function ★(v::Var{s}) where {s} + s′ = ★(s) + Var{s′}(v.roe, addcall!(v.roe.graph, ★, (v.id,))) +end + +Δ(v::Var{PrimalForm(0)}) = ★(d(★(d(v)))) + + +""" vfield :: (Decaroe -> (StateVars, ParamVars)) -> VectorFieldFunction + +Short for "vector field." Obtains tuple of root vars from a model, where the first component are state variables and the second are parameter variables. + +Example: given a diffusivity constant a, the heat equation can be written as: +``` + ∂ₜ u = a * Laplacian(u) +``` +would return (u, a). + +A limitation of this function can be demonstrated here: given the model + ``` + ∂ₜ a = a + b + ∂ₜ b = a + b + ``` + we would return ([a, b],). Suppose we wanted to extract terms of the form "a + b." Since the expression "a + b" is not a RootVar, + the extractor would bypass it completely. +""" +function vfield(model, operator_lookup::Dict{TypedApplication, Any}) + roe = Roe() + (state_vars, param_vars) = model(roe) + length(state_vars) >= 1 || error("need at least one state variable in order to create vector field") + state_rootvars = map(state_vars) do x + rv = extract!(x) + rv isa RootVar || error("all state variables must be RootVars") + rv + end + param_rootvars = map(param_vars) do p + rv = extract!(p) + rv isa RootVar || error("all param variables must be RootVars") + rv + end + + u = :u + p = :p + du = :du + + rootvar_lookup = + Dict{RootVar, Union{Expr, Symbol}}( + [ + [rv => :($(u)) for (i, rv) in enumerate(state_rootvars)]; + [rv => :($(p)) for (i, rv) in enumerate(param_rootvars)] + ] + ) + + cost = derivative_cost(Set([state_rootvars; param_rootvars])) + + extractor = EGraphs.Extractor(roe.graph, cost, Float64) + + function term_select(id) + EGraphs.find_best_node(extractor, id) + end + + + ssa = SSAExtract.SSA() + + derivative_vars = map(state_vars) do v + SSAExtract.extract_ssa!(roe.graph, ssa, (∂ₜ(v)).id, term_select) + end + + toexpr(v::SSAExtract.SSAVar) = Symbol("tmp%$(v.idx)") + + function toexpr(expr::SSAExtract.SSAExpr) + if expr.fn isa RootVar + rootvar_lookup[expr.fn] + elseif expr.fn isa Number + expr.fn + else + op = operator_lookup[TypedApplication(expr.fn, first.(expr.args))] + if op isa Tuple + op = op[1] + end + Expr(:call, *, op, toexpr.(last.(expr.args))...) + end + end + + ssalines = map(enumerate(ssa.statements)) do (i, expr) + :($(toexpr(SSAExtract.SSAVar(i))) = $(toexpr(expr))) + end + + set_derivative_stmts = map(enumerate(derivative_vars)) do (i, v) + :($(du) .= $(toexpr(v))) + end + + eval( + quote + ($du, $u, $p, _) -> begin + $(ssalines...) + $(set_derivative_stmts...) + end + end + ) +end \ No newline at end of file diff --git a/DEC/src/SSAExtract.jl b/DEC/src/SSAExtract.jl index eb78ac0..ea6126f 100644 --- a/DEC/src/SSAExtract.jl +++ b/DEC/src/SSAExtract.jl @@ -2,30 +2,35 @@ module SSAExtract using MLStyle using Metatheory.EGraphs +using ..DEC: Sort +using StructEquality -struct SSAVar +@struct_hash_equal struct SSAVar idx::Int end function Base.show(io::IO, v::SSAVar) - print(io, "\$", v.idx) + print(io, "%", v.idx) end -@data SSAExpr begin - Constant(x::Any) - App(fn::Any, args::Vector{SSAVar}) +@struct_hash_equal struct SSAExpr + fn::Any + args::Vector{Tuple{Sort, SSAVar}} end function Base.show(io::IO, e::SSAExpr) - @match e begin - Constant(x) => show(io, x) - App(fn, args) => begin - print(io, fn) - print(io, Expr(:tuple, args...)) - end + print(io, e.fn) + if length(e.args) > 0 + print(io, Expr(:tuple, (Expr(:(::), v, sort) for (sort, v) in e.args)...)) end end +""" +Advantages of SSA form: + +1. We can preallocate each matrix +2. We can run a register-allocation algorithm to minimize the number of matrices that we have to preallocate +""" struct SSA assignment_lookup::Dict{Id, SSAVar} statements::Vector{SSAExpr} @@ -65,30 +70,25 @@ will be assigned to. The closure parameters control the behavior of this function. - term_select(g::EGraph, id::Id)::VecExpr + term_select(id::Id)::VecExpr This closure selects, given an id in an EGraph, the term that we want to use in order to compute a value for that id - - make_expr(head::Any, args::Vector{Tuple{Sort, SSAVar}})::SSAExpr - -This closure produces an SSAExpr by selecting a head based on the head of the -term in the e-graph and the sorts of the arguments. """ -function extract_ssa!(g::EGraph, ssa::SSA, id::Id, term_select, make_expr)::SSAVar +function extract_ssa!(g::EGraph, ssa::SSA, id::Id, term_select)::SSAVar if hasid(ssa, id) return getvar(ssa, id) end - term = term_select(g, id) + term = term_select(id) args = map(EGraphs.v_children(term)) do arg - (g[arg].data, extract_ssa!(g, ssa, arg, term_select, make_expr)) + (g[arg].data, extract_ssa!(g, ssa, arg, term_select)) end - add_stmt!(ssa, id, make_expr(EGraphs.get_constant(g, EGraphs.v_head(term)), args)) + add_stmt!(ssa, id, SSAExpr(EGraphs.get_constant(g, EGraphs.v_head(term)), args)) end export extract_ssa! function extract_ssa!(g::EGraph, id::Id; ssa::SSA=SSA(), term_select::Function=best_term) - extract_ssa!(g, ssa, id, term_select, make_expr) + extract_ssa!(g, ssa, id, term_select) end end \ No newline at end of file diff --git a/DEC/src/Signature.jl b/DEC/src/Signature.jl new file mode 100644 index 0000000..89328c9 --- /dev/null +++ b/DEC/src/Signature.jl @@ -0,0 +1,84 @@ +@data Sort begin + Scalar() + Form(dim::Int, isdual::Bool) +end +export Scalar, Form + +duality(f::Form) = f.isdual ? "dual" : "primal" + +PrimalForm(i::Int) = Form(i, false) +export PrimalForm + +DualForm(i::Int) = Form(i, true) +export DualForm + +struct SortError <: Exception + message::String +end + +@nospecialize +function +(s1::Sort, s2::Sort) + @match (s1, s2) begin + (Scalar(), Scalar()) => Scalar() + (Scalar(), Form(i, isdual)) || (Form(i, isdual), Scalar()) => Form(i, isdual) + (Form(i1, isdual1), Form(i2, isdual2)) => + if (i1 == i2) && (isdual1 == isdual2) + Form(i1, isdual1) + else + throw(SortError("Cannot add two forms of different dimensions/dualities: $((i1,isdual1)) and $((i2,isdual2))")) + end + end +end + +-(s1::Sort, s2::Sort) = +(s1, s2) + +-(s::Sort) = s + +@nospecialize +function *(s1::Sort, s2::Sort) + @match (s1, s2) begin + (Scalar(), Scalar()) => Scalar() + (Scalar(), Form(i, isdual)) || (Form(i, isdual), Scalar()) => Form(i, isdual) + (Form(_, _), Form(_, _)) => throw(SortError("Cannot scalar multiply a form with a form. Maybe try `∧`??")) + end +end + +@nospecialize +function ∧(s1::Sort, s2::Sort) + @match (s1, s2) begin + (_, Scalar()) || (Scalar(), _) => throw(SortError("Cannot take a wedge product with a scalar")) + (Form(i1, isdual1), Form(i2, isdual2)) => + if isdual1 == isdual2 + if i1 + i2 <= 2 + Form(i1 + i2, isdual1) + else + throw(SortError("Can only take a wedge product when the dimensions of the forms add to less than 2: tried to wedge product $i1 and $i2")) + end + else + throw(SortError("Cannot wedge two forms of different dualities: attempted to wedge $(duality(s1)) and $(duality(s2))")) + end + end +end + +@nospecialize +∂ₜ(s::Sort) = s + +@nospecialize +function d(s::Sort) + @match s begin + Scalar() => throw(SortError("Cannot take exterior derivative of a scalar")) + Form(i, isdual) => + if i <= 1 + Form(i + 1, isdual) + else + throw(SortError("Cannot take exterior derivative of a n-form for n >= 1")) + end + end +end + +function ★(s::Sort) + @match s begin + Scalar() => throw(SortError("Cannot take Hodge star of a scalar")) + Form(i, isdual) => Form(2 - i, !isdual) + end +end diff --git a/DEC/tests/DEC.jl b/DEC/tests/DEC.jl index 7d13051..01cf16f 100644 --- a/DEC/tests/DEC.jl +++ b/DEC/tests/DEC.jl @@ -1,7 +1,7 @@ module TestDEC using DEC -using DEC: Decapode, SortError, d, fresh!, ∂ₜ, ∧, Δ, ≐ +using DEC: Roe, SortError, d, fresh!, ∂ₜ, ∧, Δ, ≐, ★ using Test using Metatheory.EGraphs @@ -19,18 +19,19 @@ using Metatheory.EGraphs # Exterior Product @test PrimalForm(1) ∧ PrimalForm(1) == PrimalForm(2) -pode = Decapode() +roe = Roe() -a = fresh!(pode, Scalar(), :a) -b = fresh!(pode, Scalar(), :b) +a = fresh!(roe, Scalar(), :a) +b = fresh!(roe, Scalar(), :b) x = a + b y = a + b @test x == y +@test roe.graph[(a+b).id].data == Scalar() -ω = fresh!(pode, PrimalForm(1), :ω) -η = fresh!(pode, PrimalForm(0), :η) +ω = fresh!(roe, PrimalForm(1), :ω) +η = fresh!(roe, PrimalForm(0), :η) @test ω ∧ η isa DEC.Var{PrimalForm(1)} @test ω ∧ η == ω ∧ η @@ -57,7 +58,25 @@ function lotka_volterra(pode) ([w, s], [α, β, γ]) end -f = DEC.vfield(lotka_volterra) +(ssa, derivative_vars) = DEC.vfield(lotka_volterra) + +basicprinted(x; color=false) = sprint(show, x; context=(:color=>color)) + +@test basicprinted(ssa) == """ +SSA: + %1 = γ#2 + %2 = -(%1::Scalar(),) + %3 = w#3 + %4 = *(%2::Scalar(), %3::Scalar()) + %5 = β#1 + %6 = *(%5::Scalar(), %3::Scalar()) + %7 = s#4 + %8 = *(%6::Scalar(), %7::Scalar()) + %9 = -(%4::Scalar(), %8::Scalar()) + %10 = α#0 + %11 = *(%10::Scalar(), %7::Scalar()) + %12 = -(%11::Scalar(), %8::Scalar()) +""" function transitivity(pode) w = fresh!(pode, Scalar(), :w) @@ -65,10 +84,12 @@ function transitivity(pode) ∂ₜ(w) ≐ 2 * w w end -_w = transitivity(pode) +_w = transitivity(roe) # picks whichever expression it happens to visit first EGraphs.extract!((∂ₜ(_w)), DEC.derivative_cost([DEC.extract!(_w)])) +## HEAT EQUATION + function heat_equation(pode) u = fresh!(pode, PrimalForm(0), :u) @@ -77,7 +98,43 @@ function heat_equation(pode) ([u], []) end -f = DEC.vfield(heat_equation) +using CombinatorialSpaces +using GeometryBasics +using OrdinaryDiffEq +Point2D = Point2{Float64} +Point3D = Point3{Float64} + +rect = triangulated_grid(100, 100, 1, 1, Point3D) +d_rect = EmbeddedDeltaDualComplex2D{Bool, Float64, Point3D}(rect) +subdivide_duals!(d_rect, Circumcenter()) + +operator_lookup = DEC.precompute_matrices(d_rect, DiagonalHodge()) + +vf = DEC.vfield(heat_equation, operator_lookup) + +U = first.(d_rect[:point]) +constants_and_parameters = () + +tₑ = 500.0 + +@info("Precompiling Solver") +prob = ODEProblem(vf, U, (0, tₑ), constants_and_parameters) +soln = solve(prob, Tsit5()) + +function save_dynamics(save_file_name) + time = Observable(0.0) + h = @lift(soln($time)) + f = Figure() + ax = CairoMakie.Axis(f[1,1], title = @lift("Heat at time $($time)")) + gmsh = mesh!(ax, rect, color=h, colormap=:jet, + colorrange=extrema(soln(tₑ))) + #Colorbar(f[1,2], gmsh, limits=extrema(soln(tₑ).h)) + Colorbar(f[1,2], gmsh) + timestamps = range(0, tₑ, step=5.0) + record(f, save_file_name, timestamps; framerate = 15) do t + time[] = t + end +end end \ No newline at end of file diff --git a/DEC/tests/SSAExtract.jl b/DEC/tests/SSAExtract.jl index f777881..5a12576 100644 --- a/DEC/tests/SSAExtract.jl +++ b/DEC/tests/SSAExtract.jl @@ -15,14 +15,8 @@ function term_select(g::EGraph, id::Id) g[id].nodes[1] end -function make_expr(head, args) - if length(args) == 0 - SSAExtract.Constant(head) - else - SSAExtract.App(head, last.(args)) - end -end +extract_ssa!(pode.graph, ssa, (a + b).id, term_select) -extract_ssa!(pode.graph, ssa, (a + b).id, term_select, make_expr) +ssa end \ No newline at end of file