diff --git a/roe/Project.toml b/roe/Project.toml new file mode 100644 index 0000000..3fb743a --- /dev/null +++ b/roe/Project.toml @@ -0,0 +1,46 @@ +name = "DEC" +uuid = "e670b126-1168-4583-bd0f-29deacb39f6a" +authors = ["Owen Lynch "] +version = "0.1.0" + +[deps] +AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +AssociatedTests = "e00e7eca-deca-4415-9dc4-9b3efe792c16" +CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" +Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" +CombinatorialSpaces = "b1c52339-7909-45ad-8b6a-6e388f7c67f2" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +Crayons = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" +Decapodes = "679ab3ea-c928-4fe6-8d59-fd451142d391" +Dtries = "fb203528-72e9-47b6-b8ee-6a26b9f77273" +GeometryBasics = "5c1252a2-5f33-56bf-86c9-59e7332b4326" +MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" +Metatheory = "e9d8d322-4543-424a-9be4-0cc815abe26c" +Moshi = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d" +OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" +StructEquality = "6ec83bb0-ed9f-11e9-3b4c-2b04cb4e219c" +SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" +TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[compat] +AbstractTrees = "0.4.5" +AssociatedTests = "0.1.0" +CairoMakie = "0.12.5" +Colors = "0.12.11" +CombinatorialSpaces = "0.6.7" +Crayons = "4.1.1" +Decapodes = "0.5.5" +GeometryBasics = "0.4.11" +MLStyle = "0.4.17" +Moshi = "0.3.1" +OrderedCollections = "1.6.3" +OrdinaryDiffEq = "6.86.0" +Reexport = "1.2.2" +StructEquality = "2.1.0" +SymbolicUtils = "1.5.1" +Test = "1.11.0" diff --git a/roe/README.md b/roe/README.md new file mode 100644 index 0000000..4207320 --- /dev/null +++ b/roe/README.md @@ -0,0 +1,133 @@ +# Roe + +This is a refactor of the core of diagrammatic equations, attempting to achieve a "more Julionic" approach to the problem of typed computer algebra via direct use of multiple dispatch. + +## Signature + +The "signature" of the DEC is encoded in a module `ThDEC`, in the following way. + +First, we make a type `Sort`, elements of which represent types in the discrete exterior calculus, for instance scalars, dual/primal forms of any degree, and vector fields. + +Then, we make a Julia function for each sort in the DEC. We define these Julia functions to *act on sorts*. So for instance, for the wedge product, we write a function definition like + +```julia +function ∧(s1::Sort, s2::Sort) + @match (s1, s2) begin + (Form(i, isdual), Scalar()) || (Scalar(), Form(i, isdual)) => Form(i, isdual) + (Form(i1, isdual), Form(i2, isdual)) => + if i1 + i2 <= 2 + Form(i1 + i2, isdual) + else + throw(SortError("Can only take a wedge product when the dimensions of the forms add to less than 2: tried to wedge product $i1 and $i2")) + end + _ => throw(SortError("Can only take a wedge product of two forms of the same duality")) + end +end +``` + +The advantage of encoding the signature in this way is twofold. + +1. We can give high-quality, context-specific errors when types fail to match. +2. It doesn't depend on any external libraries (except for MLStyle for convenience); it is just Plain Old Julia. + +## Using the signature to wrap symbolic algebra + +We can then wrap various symbolic frameworks by including the sorts as type parameters/metadata. For instance, for Metatheory, we create a wrapper struct `Var{s::Sort}` which wraps a `Metatheory.Id` (note that `s::Sort` rather than `s<:Sort`) We then define our methods on this struct as + +```julia +unop_dec = [:∂ₜ, :d, :★, :-, :♯, :♭] +for unop in unop_dec + @eval begin + @nospecialize + function $unop(v::Var{s}) where s + s′ = $unop(s) + Var{s′}(roe(v), addcall!(graph(v), $unop, (id(v),))) + end + + export $unop + end +end +``` + +SymbolicUtils.jl is totally analogous: + +```julia +unop_dec = [:∂ₜ, :d, :★, :♯, :♭, :-] +for unop in unop_dec + @eval begin + @nospecialize + function $unop(v::SymbolicUtils.BasicSymbolic{DECVar{s}}) where s + s′ = $unop(s) + SymbolicUtils.Term{DECVar{s′}}($unop, [v]) + end + + export $unop + end +end +``` + +SymbolicUtils gets confused when the type parameter to `BasicSymbolic` is not a type: we work around this by passing in `DECVar{s}` (name subject to change), which is a zero-field struct that wraps a `Sort` as a type parameter. + +## Models and namespacing + +Models are then plain old Julia functions that accept as their first argument a "roe." The point of the "Roe" is to record variables that are created and equations that are asserted by the function. It looks something like: + +```julia +struct Roe{T} + vars::Dtry{T} + eqs::Vector{Tuple{T, T}} +end +``` + +The type parameter `T` could be instantiated with `BasicSymbolic{<:DECVar}` or `Var`, depending on whether we are working with SymbolicUtils or Metatheory. + +So, for instance, the Klausmeier model might look like + +```julia +function klausmeier(roe::Roe) + @vars roe n::DualForm0 w::DualForm0 dX::Form1 a::Constant{DualForm0} ν::Constant{DualForm0} + @vars roe m::Number + # The equations for our model + @eq roe (∂ₜ(w) == a + w + (w * (n^2)) + ν * L(dX,w)) + @eq roe (∂ₜ(n) == w * n^2 - m*n + Δ(n)) +end +``` + +Namespacing is achieved by moving the roe into a namespace before passing it into submodels. So, for instance, to make a model with two Klausmeier submodels that share the same `m`, we could do: + +```julia +function double_klausmeier(roe::Roe) + klausmeier(namespaced(roe, :k1)) + klausmeier(namespaced(roe, :k2)) + + @eq roe (roe.k1.m == roe.k2.m) +end +``` + +The implementation of `namespace` would be something like + +```julia +function namespace(roe::Roe{T}, name::Symbol) where {T} + Roe(get(roe.vars, name, Dtry{T}()), roe.eqs) +end +``` + +Here, `get` either gets a pre-existing subnamespace at `name`, or creates a new subnamespace and inserts it at `name`. An alternative implementation would have a `namespace::Vec{Symbol}` field on `Roe` which is used to prefix every newly created variable. + +An alternative to adding the equation `roe.k1.m == roe.k2.m` would be to have `klausmeier` take `m` as a parameter, which would look like + +```julia +function klausmeier(roe::Roe, m) + @vars roe n::DualForm0 w::DualForm0 dX::Form1 a::Constant{DualForm0} ν::Constant{DualForm0} + # The equations for our model + @eq roe (∂ₜ(w) == a + w + (w * (n^2)) + ν * L(dX,w)) + @eq roe (∂ₜ(n) == w * n^2 - m*n + Δ(n)) +end + +function double_klausemeier(roe::Roe) + @vars roe m::Number + + klausmeier(namespaced(roe, :k1), m) + klausmeier(namespaced(roe, :k2), m) +end +``` diff --git a/roe/docs/literate/egraphs.jl b/roe/docs/literate/egraphs.jl new file mode 100644 index 0000000..b3b11cf --- /dev/null +++ b/roe/docs/literate/egraphs.jl @@ -0,0 +1,22 @@ +# This lesson covers the internals of Metatheory.jl-style E-Graphs. Let's reuse the heat_equation model on a new roe. +roe = Roe(DEC.ThDEC.Sort); +function heat_equation(roe) + u = fresh!(roe, PrimalForm(0), :u) + + ∂ₜ(u) ≐ Δ(u) + + ([u], []) +end + +# We apply the model to the roe and collect its state variables. +(state, _) = heat_equation(roe) + +# Recall from the Introduction that an E-Graph is a bipartite graph of ENodes and EClasses. Let's look at the EClasses: +classes = roe.graph.classes +# The keys are Metatheory Id types which store an Id. The values are EClasses, which are implementations of equivalence classes. Nodes which share the same EClass are considered equivalent. + + + +# The constants in Roe are a dictionary of hashes of functions and constants. Let's extract just the values again: +vals = collect(values(e.graph.constants)) +# The `u` is ::RootVar{Form} and ∂ₜ, ★, d are all functions defined in ThDEC/signature.jl file. diff --git a/roe/docs/literate/heatequation.jl b/roe/docs/literate/heatequation.jl new file mode 100644 index 0000000..03b1789 --- /dev/null +++ b/roe/docs/literate/heatequation.jl @@ -0,0 +1,74 @@ +# Load AlgebraicJulia dependencies +using DEC +import DEC.ThDEC: Δ # conflicts with CombinatorialSpaces + +# load other dependencies +using ComponentArrays +using CombinatorialSpaces +using GeometryBasics +using OrdinaryDiffEq +Point2D = Point2{Float64} +Point3D = Point3{Float64} +using CairoMakie + +## Here we define the 1D heat equation model with one state variable and no parameters. That is, given an e-graph "roe," we define `u` to be a primal 0-form. The root variable carries a reference to the e-graph which it resides in. We then assert that the time derivative of the state is just its Laplacian. We return the state variable. +function heat_equation(roe) + u = fresh!(roe, PrimalForm(0), :u) + + ∂ₜ(u) ≐ Δ(u) + + ([u], []) +end + +# Since this is a model in the DEC, we need to initialize the primal and dual meshes. +rect = triangulated_grid(100, 100, 1, 1, Point3D); +d_rect = EmbeddedDeltaDualComplex2D{Bool, Float64, Point3D}(rect); +subdivide_duals!(d_rect, Circumcenter()); + +# Now that we have a dual mesh, we can associate operators in our theory with precomputed matrices from Decapodes.jl. +op_lookup = ThDEC.precompute_matrices(d_rect, DiagonalHodge()) + +# Now we produce a "vector field" function which, given a model and operators in a theory, returns a function to be passed to the ODESolver. In stages, this function +# +# 1) extracts the Root Variables (state or parameter term) and runs the extractor along the e-graph, +# 2) extracts the derivative terms from the model into an SSA +# 3) yields a function accepting derivative terms, state terms, and parameter terms, whose body is both the lines, and derivatives. +vf = vfield(heat_equation, op_lookup) + +# Let's initialize the +U = first.(d_rect[:point]); + +# TODO component arrays +constants_and_parameters = () + +# We will run this for 500 timesteps. +t0 = 500.0 + +@info("Precompiling Solver") +prob = ODEProblem(vf, U, (0, t0), constants_and_parameters); +soln = solve(prob, Tsit5()); + +## 1-D HEAT EQUATION WITH DIFFUSIVITY + +function heat_equation_with_constants(roe) + u = fresh!(roe, PrimalForm(0), :u) + k = fresh!(roe, Scalar(), :k) + ℓ = fresh!(roe, Scalar(), :ℓ) + + ∂ₜ(u) ≐ k * Δ(u) + + ([u], [k]) +end + +# we can reuse the mesh and operator lookup +vf = vfield(heat_equation_with_constants, operator_lookup) + +# we can reuse the initial condition U but are specifying diffusivity constants. +constants_and_parameters = ComponentArray(k=0.25,); +t0 = 500 + +@info("Precompiling solver") +prob = ODEProblem(vf, U, (0, t0), constants_and_parameters); +soln = solve(prob, Tsit5()); + + diff --git a/roe/docs/literate/tutorial.jl b/roe/docs/literate/tutorial.jl new file mode 100644 index 0000000..31d1885 --- /dev/null +++ b/roe/docs/literate/tutorial.jl @@ -0,0 +1,52 @@ +# This tutorial is a slower-paced introduction into the design. Here, we will construct a simple exponential model. +using DEC +using Test +using Metatheory.EGraphs +using ComponentArrays +using GeometryBasics +using OrdinaryDiffEq +Point2D = Point2{Float64} +Point3D = Point3{Float64} + +using CairoMakie + +# We define our model of exponential growth. This model is a function which accepts a Roe and returns a tuple of State and Parameter variables. Let's break it down: +# +# 1. Function adds root variables (::RootVar) to the Roe. The root variables have no child nodes. +# 2. Our model makes claims about what terms equal one another. The "≐" operator is an infix of "equate!" which claims unites the ids of the left and right VecExprs. +# 3. The State and Parameter variables are returned. Each variable points to the same parent Roe. +# +# +# Each variable points to the same Roe. +function exp_growth(roe) + u = fresh!(roe, PrimalForm(0), :u) + k = fresh!(roe, Scalar(), :k) + + ∂ₜ(u) ≐ k * u + + ([u], [k]) +end + +# We now need to initialize the primal and dual meshes we'll need to compute with. +rect = triangulated_grid(100, 100, 1, 1, Point3D); +d_rect = EmbeddedDeltaDualComplex2D{Bool, Float64, Point3D}(rect); +subdivide_duals!(d_rect, Circumcenter()); + +# For the theory of the DEC, we will need to associate each operator to the precomputed matrix specific to our dual mesh. +operator_lookup = ThDEC.create_dynamic_model(d_rect, DiagonalHodge()) + +# We now need to convert our model to an ODEProblem. In our case, ``vfield`` produces +vf = vfield(exp_growth, operator_lookup) + +U = first.(d_rect[:point]); + +constants_and_parameters = ComponentArray(k=-0.5,) + +t0 = 50.0 + +@info("Precompiling Solver") +prob = ODEProblem(vf, U, (0, t0), constants_and_parameters); +soln = solve(prob, Tsit5()); + +save_dynamics(soln, "decay.gif") + diff --git a/roe/docs/make.jl b/roe/docs/make.jl new file mode 100644 index 0000000..7fa424c --- /dev/null +++ b/roe/docs/make.jl @@ -0,0 +1,57 @@ +using Documenter +using Literate +using Distributed + +using DEC + +using CairoMakie + +# Set Literate.jl config if not being compiled on recognized service. +# config = Dict{String,String}() +# if !(haskey(ENV, "GITHUB_ACTIONS") || haskey(ENV, "GITLAB_CI")) +# config["nbviewer_root_url"] = "https://nbviewer.jupyter.org/github/AlgebraicJulia/DEC.jl/blob/gh-pages/dev" +# config["repo_root_url"] = "https://github.com/AlgebraicJulia/Decapodes.jl/blob/main/docs" +# end + +const literate_dir = joinpath(@__DIR__, "..", "examples") +const generated_dir = joinpath(@__DIR__, "src", "examples") + +@info "Building literate files" +for (root, dirs, files) in walkdir(literate_dir) + out_dir = joinpath(generated_dir, relpath(root, literate_dir)) + pmap(files) do file + f,l = splitext(file) + if l == ".jl" && !startswith(f, "_") + Literate.markdown(joinpath(root, file), out_dir; + config=config, documenter=true, credit=false) + Literate.notebook(joinpath(root, file), out_dir; + execute=true, documenter=true, credit=false) + end + end +end +@info "Completed literate" + +pages = Any[] +push!(pages, "DEC.jl" => "index.md") +push!(pages, "Library Reference" => "api.md") + +@info "Building Documenter.jl docs" +makedocs( + modules = [DEC], + format = Documenter.HTML( + assets = ["assets/analytics.js"], + ), + remotes = nothing, + sitename = "DEC.jl", + doctest = false, + checkdocs = :none, + pages = pages) + + +@info "Deploying docs" +deploydocs( + target = "build", + repo = "github.com/AlgebraicJulia/DEC.jl.git", + branch = "gh-pages", + devbranch = "main" +) diff --git a/roe/scratch/tc.jl b/roe/scratch/tc.jl new file mode 100644 index 0000000..ee490ce --- /dev/null +++ b/roe/scratch/tc.jl @@ -0,0 +1,240 @@ +using Metatheory +using Metatheory: OptBuffer +using Metatheory.Library +using Metatheory.Rewriters +using MLStyle +using Test +using Metatheory.Plotting + +b = OptBuffer{UInt128}(10) + +@testset "Predicate Assertions" begin + r = @rule ~a::iseven --> true + Base.iseven(g, ec::EClass) = + any(ec.nodes) do n + h = v_head(n) + if has_constant(g, h) + c = get_constant(g, h) + return c isa Number && iseven(c) + end + false + end + # + g = EGraph(:(f(2, 1))) + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0 + # + g = EGraph(:2) + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + # + g = EGraph(:3) + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0 + # + new_id = addexpr!(g, :f) + union!(g, g.root, new_id) + # + new_id = addexpr!(g, 4) + union!(g, g.root, new_id) + # + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 +end + +### + +abstract type AbstractSort end + +@data Sort <: AbstractSort begin + Scalar + Form(dim::Int) +end + +@testset "Check form" begin + + function isform(g, ec::EClass) + any(ec.nodes) do n + h = v_head(n) + if has_constant(g, h) + c = get_constant(g, h) + @info "$c, $(typeof(c))" + return c isa Form + end + false + end + end + + r = @rule ~a::isform --> true + + t = @theory a begin + ~a::isform + ~b::isform --> 0 + end + + ## initialize and sanity-check + a1=Form(1); a2=Form(2) + a3=Form(1); a4=Form(1) + @assert a1 isa Form; @assert a3 isa Form + @assert a3 isa Form; @assert a4 isa Form + + g = EGraph(:($a1 + $a2)) + saturate!(g, t) + extract!(g, astsize) + + g = EGraph(:a1) + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + + g = EGraph(:a3) + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0 + + + g = EGraph(:(a + b )) + saturate!(g, t) + extract!(g, astsize) + + +end + + +abstract type AbstractSort end + +@data Sort <: AbstractSort begin + Scalar + Form(dim::Int) +end + +t = @theory a b begin + a::Form(1) ∧ b::Form(2) --> 0 +end +# breaks! makevar @ MT/src/Syntax.jl:57 (<- makepattern, ibid:151) +# expects Symbol, not Expr + +function isform(g, ec::EClass) + any(ec.nodes) do n + h = v_head(n) + if has_constant(g, h) + c = get_constant(g, h) + @info "$c, $(typeof(c)), $(c isa Form)" + return c isa Form + end + end +end + +r = @rule ~a::isform --> true + +g = EGraph(:(f(2, 1))) +@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0 + +g = EGraph(:2) +@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + +g = EGraph(:3) +@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0 + +new_id = addexpr!(g, :f) +union!(g, g.root, new_id) + +new_id = addexpr!(g, 4) +union!(g, g.root, new_id) + +@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + + + + + +a=Form(1) +b=Form(2) +c=Form(1) +d=Form(1) +@assert a isa Form +@assert b isa Form +@assert c isa Form +@assert d isa Form + +ex = :(a + b + c + d) +g = EGraph(ex) +saturate!(g, _T) +extract!(g, astsize) + + +@data Sort1 <: AbstractSort begin + AnyForm +end + +_t = @theory a b begin + a::AnyForm ∧ b::AnyForm --> 0 +end +# PatVar error + +__t = @theory a b begin + a::var"AnyForm's constructor" + 0 --> a +end + +d0 = AnyForm +d1 = AnyForm +d2 = AnyForm +ex = :(d0 + 0 + (d1 + 0) + d2) + +g = EGraph(ex) +saturate!(g, __t) +extract!(g, astsize) + + +rwth = Fixpoint(Prewalk(Chain(__t))) +rwth(ex) + +g = EGraph(ex) +saturate!(g, _t) +extract!(g, astsize) +# returns (a ∧ b), as + +rwth = Fixpoint(Prewalk(Chain(_t))) +rwth(ex) + +a = Form(0) +ex = :(a ∧ b) + + +Derivative = @theory f g a begin + f::Function * d(g::Function) + d(f::Function) * g::Function --> d(f * g) + d(f::Function) + d(g::Function) --> d(f + g) + d(a::Number * f::Function) --> a * d(f) +end + +_Derivative = @theory f g a begin + f * d(g) + d(f) * g --> d(f * g) + d(f) + d(g) --> d(f + g) + d(a * f) --> a * d(f) +end + +ex = :(f * d(g) + d(f) * g) + +rwth = Fixpoint(Prewalk(Chain(_Derivative))) +rwth(ex) + +foo(x) = x + 1 +goo(x) = x + 3 + +rwth = Fixpoint(Prewalk(Chain(Derivative))) +rwth(ex) + +g = EGraph(ex) +saturate!(g, Derivative); +extract!(g, astsize) + + +rwth(ex) + + +types = (U=Form(0), k=Scalar(),); + +tc(x) = @match x begin + s::Symbol => types[s] + ::Expr => true +end + +cond = x -> begin + @info "$x, $(type(x))" + tc(x) +end + +orw = Fixpoint(Prewalk(If(cond, Chain(rewrite_theory)))) + +orw(expr) diff --git a/roe/src/DEC.jl b/roe/src/DEC.jl new file mode 100644 index 0000000..fe02946 --- /dev/null +++ b/roe/src/DEC.jl @@ -0,0 +1,107 @@ +module DEC + +using Reexport + +using MLStyle +using Reexport +using StructEquality +# import Metatheory +# using Metatheory: EGraph, EGraphs, Id, astsize +# using Metatheory: VECEXPR_FLAG_ISCALL, VECEXPR_FLAG_ISTREE, VECEXPR_META_LENGTH +# import Metatheory: extract! + +import Base: +, -, * + +include("util/module.jl") # Pretty-printing +include("roe.jl") # Checking signature for DEC operations +# include("SSAs.jl") # manipulating SSAs +# include("vfield.jl") # producing a vector field function + +# currently this only holds the DEC +include("theories/module.jl") + +# @reexport using .Util +# @reexport using .SSAs +@reexport using .Theories + +# function vfield(model, operator_lookup::Dict{TA, Any}) +# roe = Roe(DEC.ThDEC.Sort) + +# (state_vars, param_vars) = model(roe) +# length(state_vars) >= 1 || error("need at least one state variable in order to create vector field") +# state_rootvars = map(state_vars) do x +# rv = extract!(x) +# rv isa RootVar ? rv : error("all state variables must be RootVars") +# end +# param_rootvars = map(param_vars) do p +# rv = extract!(p) +# rv isa RootVar ? rv : error("all param variables must be RootVars") +# end + +# u = :u +# p = :p +# du = :du + +# rootvar_lookup = +# Dict{RootVar, Tuple{Union{Expr, Symbol}, Bool}}( +# [ +# [rv => (:($(u)), false) for rv in state_rootvars]; +# [rv => (:($(p)), true) for rv in param_rootvars] +# ] +# ) + +# cost = derivative_cost(Set([state_rootvars; param_rootvars])) + +# extractor = EGraphs.Extractor(roe.graph, cost, Float64) + +# function term_select(id) +# EGraphs.find_best_node(extractor, id) +# end + +# ssa = SSA() + +# # TODO overload extract! to index by graph +# derivative_vars = map(state_vars) do v +# extract!(roe.graph, ssa, (∂ₜ(v)).id, term_select) +# end + +# toexpr(v::DEC.SSAs.Var) = Symbol("tmp%$(v.idx)") + +# function toexpr(expr::Term) +# @match expr.head begin +# ::RootVar => @match rootvar_lookup[expr.head] begin +# (v, false) => v +# # evaluates in DEC.k, and this gets the index +# (v, true) => Expr(:ref, v, expr.head.name) +# end +# ::Number => expr.head +# _ => begin +# op = get(operator_lookup, TA(expr.head, first.(expr.args))) +# # Decapode operators return a tuple of functions. We choose the first of these. +# if op isa Tuple +# op = op[1] +# end +# Expr(:call, *, op, toexpr.(last.(expr.args))...) +# end +# end +# end + +# ssalines = map(enumerate(ssa.statements)) do (i, expr) +# :($(toexpr(SSAs.Var(i))) = $(toexpr(expr))) +# end + +# set_derivative_stmts = map(enumerate(derivative_vars)) do (i, v) +# :($(du) .= $(toexpr(v))) +# end + +# # yield function +# eval(quote +# f(du, u, p, _) = begin +# $(ssalines...) +# $(set_derivative_stmts...) +# end +# end) +# end +# export vfield + +end diff --git a/roe/src/SSAs.jl b/roe/src/SSAs.jl new file mode 100644 index 0000000..5d97f64 --- /dev/null +++ b/roe/src/SSAs.jl @@ -0,0 +1,141 @@ +module SSAs + +using ..DEC: AbstractSort, TypedApplication, TA, Roe, RootVar + +# other dependencies +using MLStyle +using StructEquality +using Metatheory: VecExpr +using Metatheory.EGraphs +import Metatheory: extract! + +""" Var + +A wrapper for the index of a Var +""" +@struct_hash_equal struct Var + idx::Int +end +export Var + +function Base.show(io::IO, v::Var) + print(io, "%", v.idx) +end + +""" Term + +A wrapper for a function (::Any) and its args (::Vector{Tuple{Sort, Var}}). + +Example: the equation +``` + a = 1 + b +``` +may have an SSA dictionary +``` + %1 => a + %2 => +(%1, %3) + %3 => b +``` +and so `+` would have +``` +Term(+, [(Scalar(), Var(1)), (Scalar(), Var(2))]) +``` +""" +@struct_hash_equal struct Term + head::Any + args::Vector{Tuple{AbstractSort, Var}} +end +export Term + +function Base.show(io::IO, e::Term) + print(io, e.head) + if length(e.args) > 0 + print(io, Expr(:tuple, (Expr(:(::), v, sort) for (sort, v) in e.args)...)) + end +end + +""" + +Struct defining Static Single-Assignment information for a given roe. + +Advantages of SSA form: + +1. We can preallocate each matrix +2. We can run a register-allocation algorithm to minimize the number of matrices that we have to preallocate +""" +struct SSA + assignment_lookup::Dict{Id, Var} + statements::Vector{Term} + function SSA() + new(Dict{Id, Var}(), Term[]) + end +end +export SSA + +# accessors +statements(ssa::SSA) = ssa.statements +export statements + +# show methods + +function Base.show(io::IO, ssa::SSA) + println(io, "SSA: ") + for (i, expr) in enumerate(statements(ssa)) + println(io, " ", Var(i), " = ", expr) + end +end + +""" add_stmt!(ssa::SSA, id::Id, expr::Term)::Var + +Low-level function which, given an SSA, adds a Term onto the assignment_lookup. Users should use `extract!` instead. + +""" +function add_stmt!(ssa::SSA, id::Id, expr::Term) + push!(ssa.statements, expr) + v = Var(length(ssa.statements)) + ssa.assignment_lookup[id] = v + v +end + +Base.contains(ssa::SSA, id::Id) = haskey(ssa.assignment_lookup, id) +export contains + +Base.getindex(ssa::SSA, id::Id) = ssa.assignment_lookup[id] +export getindex + +""" + extract!(g::EGraph, ssa::SSA, id::Id, term_select, make_expr)::Var + +This function adds (recursively) the necessary lines to the SSA in order to +compute a value for `id`, and then returns the Var that the value for `id` +will be assigned to. + +The closure parameters control the behavior of this function. + + term_select(id::Id)::VecExpr + +This closure selects, given an id in an EGraph, the term that we want to use in +order to compute a value for that id + +""" +function extract!(g::EGraph, ssa::SSA, id::Id, term_select) + if contains(ssa, id) + return getindex(ssa, id) + end + term = term_select(id) + args = map(EGraphs.v_children(term)) do arg + (g[arg].data, extract!(g, ssa, arg, term_select)) + end + add_stmt!(ssa, id, Term(EGraphs.get_constant(g, EGraphs.v_head(term)), args)) +end +export extract! + +function extract!(g::EGraph, id::Id; ssa::SSA=SSA(), term_select::Function=best_term) + extract!(g, ssa, id, term_select) +end + +function extract!(roe::Roe{S}, id::Id; ssa::SSA=SSA(), term_select::Function=best_term) where S + extract!(roe, ssa, id, term_select) +end + +end diff --git a/roe/src/egraph/roe.jl b/roe/src/egraph/roe.jl new file mode 100644 index 0000000..dc72c2c --- /dev/null +++ b/roe/src/egraph/roe.jl @@ -0,0 +1,215 @@ +""" RootVar + +A childless node on an e-graph. + +""" +@struct_hash_equal struct RootVar{Sort<:AbstractSort} + name::Symbol + idx::Int + sort::Sort + + function RootVar(name::Symbol, idx::Int, sort::Sort) where Sort + new{Sort}(name, idx, sort) + end +end +export RootVar + +""" Roe + +Struct for storing an EGraph and its variables. + +Roe is the name for lobster eggs. "Egg" is the name of a Rust implementation of e-graphs, by which Metatheory.jl is inspired by. Lobsters are part of the family Decapodes, which is also the name of the AlgebraicJulia package which motivated this package. Hence, Roe. +""" +struct Roe{Sort<:AbstractSort} + variables::Vector{RootVar} + graph::EGraph{Expr, Sort} + + function Roe(Sort::DataType) + new{Sort}(RootVar[], EGraph{Expr, Sort}()) + end +end +export Roe + +# accessors +variables(roe::Roe{S}) where S = roe.variables +graph(roe::Roe{S}) where S = roe.graph +param(roe::Roe{S}) where S = S +export variables, graph, param + +""" + +A struct containing a Roe and the Id of a variable in that EGraph. The type parameter for this struct is the variable it represents. + +""" +struct Var{S} + roe::Roe + id::Id +end + +# accessors +roe(v::Var{S}) where S = v.roe +graph(v::Var{S}) where S = roe(v).graph +id(v::Var{S}) where S = v.id +export roe, graph, id + +# MAKE AND JOIN + +function EGraphs.make(g::EGraph{Expr, Sort}, n::VecExpr) where {Sort<:AbstractSort} + op = EGraphs.get_constant(g, Metatheory.v_head(n)) + @match op begin + ::RootVar => op.sort + ::Number => Scalar() + _ => op((g[arg].data for arg in Metatheory.v_children(n))...) + end +end + +function EGraphs.join(s1::S1, s2::S2) where {S1<:AbstractSort,S2<:AbstractSort} + s1 == s2 ? s1 : throw(JoinError(s1, s2)) +end + +# EXTRACT + +function extract!(v::Var, f=EGraphs.astsize) + extract!(v.roe.graph, f, v.id) +end + +function rootvarcrayon(v::RootVar) + lightnessrange = (50., 100.) + HashColor.hashcrayon(v.idx; lightnessrange, chromarange=(50., 100.)) +end + +function Base.show(io::IO, v::RootVar) + if get(io, :color, true) + crayon = rootvarcrayon(v) + print(io, crayon, "$(v.name)") + print(io, inv(crayon)) + else + print(io, "$(v.name)#$(v.idx)") + end +end + +""" fix_functions(e)::Union{Symbol, Expr} + +Used in the show method for Vars. Traverses the AST of an expression, replacing the head of :call expressions to its name, a Symbol. +""" +function fix_functions(e) + @match e begin + s::Symbol => s + Expr(:call, f::Function, args...) => Expr(:call, nameof(f), fix_functions.(args)...) + Expr(head, args...) => Expr(head, fix_functions.(args)...) + _ => e + end +end + +""" getexpr(v::Var)::Union{Symbol, Expr} + +Extracts an expression (::Var) from its Roe. + +""" +function getexpr(v::Var) + e = EGraphs.extract!(v.roe.graph, Metatheory.astsize, v.id) + fix_functions(e) +end +export getexpr + +function Base.show(io::IO, v::Var) + print(io, getexpr(v)) +end + +""" fresh!(roe::Roe, sort::AbstractSort, name::Symbol)::Var{sort} + +Creates a new variable in a Roe. Specifically, it appends a new RootVar with a given a sort and name to the Roe, adds that RootVar to the e-graph, and returns a Var wrapper around the new e-graph Id, with type parameter given by the sort. + +Example: +``` +fresh!(roe, Form(0), :Temp) +``` +""" +function fresh!(roe::Roe, sort::Sort, name::Symbol) where {Sort<:AbstractSort} + v = RootVar(name, length(roe.variables), sort) + push!(roe.variables, v) + n = Metatheory.v_new(0) + Metatheory.v_set_head!(n, EGraphs.add_constant!(roe.graph, v)) + Metatheory.v_set_signature!(n, hash(Metatheory.maybe_quote_operation(v), hash(0))) + id = EGraphs.add!(roe.graph, n, false) + Var{sort}(roe, id) +end +export fresh! + + +@nospecialize +""" inject_number!(roe::Roe, x::Number)::Var{Scalar()} + +Adds a number to the Roe as a EGraph constant. + +""" +function inject_number!(roe::Roe, x::Number) + x = Float64(x) + n = Metatheory.v_new(0) + Metatheory.v_set_head!(n, EGraphs.add_constant!(roe.graph, x)) + Metatheory.v_set_signature!(n, hash(Metatheory.maybe_quote_operation(x), hash(0))) + Var{Scalar()}(roe, EGraphs.add!(roe.graph, n, false)) +end +export inject_number! + +@nospecialize +""" addcall!(g::EGraph, head, args):: + +Adds a call to an EGraph. + +""" +function addcall!(g::EGraph, head, args) + ar = length(args) + n = Metatheory.v_new(ar) + Metatheory.v_set_flag!(n, VECEXPR_FLAG_ISTREE) + Metatheory.v_set_flag!(n, VECEXPR_FLAG_ISCALL) + Metatheory.v_set_head!(n, EGraphs.add_constant!(g, head)) + Metatheory.v_set_signature!(n, hash(Metatheory.maybe_quote_operation(head), hash(ar))) + for i in Metatheory.v_children_range(n) + @inbounds n[i] = args[i - VECEXPR_META_LENGTH] + end + EGraphs.add!(g, n, false) +end +export addcall! + +""" equate!(v1::Var{s1}, v2::Var{s2})::EGraph + +Asserts that two variables of the same e-graph are the same. This is done by returning the union of the variable ids with the e-graph. +""" +function equate!(v1::Var{s1}, v2::Var{s2}) where {s1, s2} + (s1 == s2) || throw(JoinError(s1, s2)) + v1.roe === v2.roe || throw(EquateError()) + union!(v1.roe.graph, v1.id, v2.id) +end +export equate! + +""" +Infix synonym for `equate!` +""" +≐(v1::Var, v2::Var) = equate!(v1, v2) +export ≐ + +@nospecialize +""" derivative_cost(allowed_roots)::Function + +Returns a function `cost(n::Metatheory.VecExpr, op, costs)` which sets the cost of operations to Inf if they are either ∂ₜ or forbidden RootVars. Otherwise it computes the astsize. + +""" +function derivative_cost(allowed_roots) + function cost(n::VecExpr, op, costs) + if op == ∂ₜ || (op isa RootVar && op ∉ allowed_roots) + Inf + else + astsize(n, op, costs) + end + end +end +export derivative_cost + +# EXCEPTIONS + +struct JoinError <: Exception; s1::AbstractSort; s2::AbstractSort end +Base.showerror(io::IO, e::JoinError) = print(io, "Cannot equate two nodes with different sorts: attempted to equate $(e.s1) with $(e.s2)") + +struct EquateError <: Exception end +Base.showerror(io::IO, e::EquateError) = print(io, "Cannot equate variables from different graphs") diff --git a/roe/src/roe.jl b/roe/src/roe.jl new file mode 100644 index 0000000..3882900 --- /dev/null +++ b/roe/src/roe.jl @@ -0,0 +1,215 @@ +using ..Util.HashColor + +using Base: Workqueue +using StructEquality +using ComponentArrays +using MLStyle +using Dtries +using SymbolicUtils +using SymbolicUtils: BasicSymbolic, Symbolic +using Reexport +using AssociatedTests +using Test + +""" +Sorts in each theory are subtypes of this abstract type. +""" +abstract type AbstractSort end +export AbstractSort + +""" TypedApplication + +Struct containing a Function and the vector of Sorts it requires. +""" +@struct_hash_equal struct TypedApplication{Sort<:AbstractSort} + head::Function + sorts::Vector{Sort} + + function TypedApplication(head::Function, sorts::Vector{Sort}) where {Sort} + new{Sort}(head, sorts) + end +end +export TypedApplication + +const TA = TypedApplication +export TA + +Base.show(io::IO, ta::TA) = print(io, Expr(:call, nameof(ta.head), ta.sorts...)) + +struct SortError <: Exception + message::String +end +export SortError + +Base.get(lookup::Dict{TA,Any}, key::TA) = lookup[key] +export get + +# DECVar{s} +struct OfSort{s} <: Number end +export OfSort + +struct SortMetadata end + +struct Equation{E} + lhs::E + rhs::E +end + +function Base.:(==)(eq1::Equation, eq2::Equation) + isequal(eq1.lhs, eq2.lhs) && isequal(eq1.rhs, eq2.rhs) +end + +""" Roe +""" +struct Roe + working_path::Path + vars::Dtry{Symbolic} + eqs::Vector{Equation{Symbolic}} +end +export Roe + +function Roe() + Roe(Path([]), Dtry{Symbolic}(), Equation{Symbolic}[]) +end + +working_path(roe::Roe) = getfield(roe, :working_path) +vars(roe::Roe) = getfield(roe, :vars) +eqs(roe::Roe) = getfield(roe, :eqs) +export vars, eqs + +@tests Roe begin + roe = Roe() + @test roe isa Roe + @test working_path(roe) == Path([]) + @test vars(roe) == Dtries.Empty{Symbolic}() + @test eqs(roe) == Equation{Symbolic}[] +end + +function Base.getindex(roe::Roe, path::Path)::Symbolic + path = working_path(roe) * path + vars(roe)[path] +end + +function Base.setindex!(roe::Roe, v::Symbolic, path::Path) + setindex!(vars(roe), v, working_path(roe) * path) +end + +@tests Tuple{getindex,setindex!} begin + roe = Roe() + @syms a + + roe[Path([:a, :b])] = a + @test roe[Path([:a, :b])] === a + roe[Path([:a, :c])] = a + @test roe[Path([:a, :c])] === a + + @test_throws Exception (roe[Path([:a])] = a) +end + +function fresh!(roe::Roe, name::Symbol, sort) + v = SymbolicUtils.Sym{OfSort{sort}}(name) + roe[Path([name])] = v + v +end +export fresh! + +@tests fresh! begin + roe = Roe() + x = fresh!(roe, :x, 1) + + @test x isa SymbolicUtils.Sym{OfSort{1}} + @test roe[Path([:x])] === x +end + +function Base.getproperty(roe::Roe, x::Symbol) + p = working_path(roe) * Path([x]) + @match Dtries.lookup(vars(roe), p) begin + Some(v) => v + Nothing => Roe(p, vars(roe), eqs(roe)) + end +end + +@tests getproperty begin + roe = Roe() + b = fresh!(roe.a, :b, 1) + + @test roe.a.b === b +end + +""" @vars +Example: @vars roe u::Form(0) + +```julia +@vars roe u::Form(0) + +-> + +u = fresh!(roe, :u, Form(0)) +``` +""" +macro vars(roe, vars...) + vars = parse_var.(vars) + stmts = map(vars) do (n, t) + :($n = $(fresh!)($roe, $(QuoteNode(n)), $t)) + end + esc(Expr(:block, stmts...)) +end +export @vars + +parse_var = @λ begin + s::Symbol => (s, Number) + Expr(:(::), name, type) => (name, type) + err => error("$err is not a valid expression for a symbolic variable") +end + +@tests var"@vars" begin + roe = Roe() + + @vars roe a b + + @test roe.a === a + @test roe.b === b + + @vars roe.x a b + + @test roe.x.a === a + @test roe.x.b === b +end + +function equate(roe::Roe, lhs::Symbolic, rhs::Symbolic) + push!(eqs(roe), Equation{Symbolic}(lhs, rhs)) +end + +macro eq(roe, eq_expr) + (lhs, rhs) = parse_eq(eq_expr) + esc(quote + $(equate)($roe, $lhs, $rhs) + end) +end + +parse_eq = @λ begin + Expr(:call, :(==), lhs, rhs) => (lhs, rhs) + err => error("$err is not a valid expression for an equation") +end + +@tests var"@eq" begin + roe = Roe() + + @vars roe a b + + @eq roe a == b + + @eq roe (a + b == b + a) + + @test eqs(roe) == [ + Equation{Symbolic}(a, b), + Equation{Symbolic}(a + b, b + a) + ] +end + +function instantiate(f) + roe = Roe() + f(roe) + roe +end +export instantiate diff --git a/roe/src/theories/ThDEC/ThDEC.jl b/roe/src/theories/ThDEC/ThDEC.jl new file mode 100644 index 0000000..4e5dc49 --- /dev/null +++ b/roe/src/theories/ThDEC/ThDEC.jl @@ -0,0 +1,15 @@ +module ThDEC + +# using ...DEC: TypedApplication, TA, Roe, RootVar +# using ...DEC.SSAs + +# using Metatheory: VecExpr +# using Metatheory.EGraphs + +include("signature.jl") # verify operations type-check +include("symbolicutils_overloads.jl") # overload DEC operations to act on roe (egraphs) +# include("semantics.jl") # represent operations as matrices + +# include("rewriting.jl") + +end diff --git a/roe/src/theories/ThDEC/egraphs.jl b/roe/src/theories/ThDEC/egraphs.jl new file mode 100644 index 0000000..a6e89ac --- /dev/null +++ b/roe/src/theories/ThDEC/egraphs.jl @@ -0,0 +1,15 @@ +## SIGNATURE + +# Predicates +function isForm(g, ec::EClass) + any(ec.nodes) do n + h = v_head(n) + if has_constant(g, h) + c = get_constant(g, h) + return c isa Form + end + false + end +end + + diff --git a/roe/src/theories/ThDEC/rewriting.jl b/roe/src/theories/ThDEC/rewriting.jl new file mode 100644 index 0000000..81182e5 --- /dev/null +++ b/roe/src/theories/ThDEC/rewriting.jl @@ -0,0 +1,56 @@ +using Metatheory +using Metatheory.Library +using Metatheory.Rewriters +using MLStyle + +buf = OptBuffer{UInt128}(10) + +function isForm(g, ec::EClass) + any(ec.nodes) do n + h = v_head(n) + if has_constant(g, h) + c = get_constant(g, h) + return c isa Form + end + false + end +end + +t = @theory a b begin + ~a::isForm + ~b::isForm --> 0 +end + + +ThMultiplicativeMonoid = @commutative_monoid (*) 1 +ThAdditiveGroup = @commutative_group (+) 0 (-) +Distributivity = @distrib (*) (+) +ThRing = ThMultiplicativeMonoid ∪ ThAdditiveGroup ∪ Distributivity + +Derivative = @theory (f, g)::Function, a::Scalar begin + f * d(g) + d(f) * g --> d(f * g) + d(f) + d(g) --> d(f + g) + d(a * f) --> a * d(f) +end + + +# e = :(f * d(g) + d(f) * g) +# g = EGraph(e) +# saturate!(g, product) +# extract!(g, astsize) + +zero = @theory f begin + f * 0 --> 0 + f + 0 --> f + 0 + f --> f +end + +square_zero = @theory ω begin + d(d(ω)) --> 0 +end + +linearity = @theory f g a begin + Δ(f + g) == Δ(f) + Δ(g) + Δ(a * f) == a * Δ(f) +end +export linearity + diff --git a/roe/src/theories/ThDEC/roe_overloads.jl b/roe/src/theories/ThDEC/roe_overloads.jl new file mode 100644 index 0000000..837f75d --- /dev/null +++ b/roe/src/theories/ThDEC/roe_overloads.jl @@ -0,0 +1,57 @@ +using ...DEC: Var, addcall!, inject_number! +using ...DEC: roe, graph, id # Var{S} accessors + +import Base: +, -, * + +# These operations create calls on a common egraph. We validate the signature by dispatching the operation on the types using methods we defined in Signature. + +## UNARY OPERATIONS + +unop_dec = [:∂ₜ, :d, :★, :-, :♯, :♭] +for unop in unop_dec + @eval begin + @nospecialize + function $unop(v::Var{s}) where s + s′ = $unop(s) + Var{s′}(roe(v), addcall!(graph(v), $unop, (id(v),))) + end + + export $unop + end +end + +# Δ is a composite of Hodge star and d +Δ(v::Var{PrimalForm(0)}) = ★(d(★(d(v)))) +export Δ +# TODO this could be a rewriting rule instead? + +♭♯(v::Var{DualVF()}) = ♯(♭(v)) +export ♭♯ + +## BINARY OPERATIONS + +binop_dec = [:+, :-, :*, :∧] +for binop in binop_dec + @eval begin + @nospecialize + function $binop(v1::Var{s1}, v2::Var{s2}) where {s1, s2} + roe(v1) === roe(v2) || throw(BinopError($binop)) + s = $binop(s1, s2) + Var{s}(v1.roe, addcall!(graph(v1), $binop, (id(v1), id(v2)))) + end + + @nospecialize + $binop(v::Var, x::Number) = $binop(v, inject_number!(roe(v), x)) + + @nospecialize + $binop(x::Number, v::Var) = $binop(inject_number!(roe(v)), x) + + export $binop + end +end + +struct BinopError <: Exception + binop::Symbol +end + +Base.showerror(io::IO, e::BinopError) = print(io, "Cannot use '$binop' on variables from different graphs.") diff --git a/roe/src/theories/ThDEC/semantics.jl b/roe/src/theories/ThDEC/semantics.jl new file mode 100644 index 0000000..f0d031f --- /dev/null +++ b/roe/src/theories/ThDEC/semantics.jl @@ -0,0 +1,73 @@ +import Decapodes +using StructEquality + + +""" create_dynamic_model(sd, hodge)::Dict{TypedApplication, Any} + +Given a matrix and a hodge star (DiagonalHodge() or GeometricHodge()), this returns a lookup dictionary between operators (as TypedApplications) and their corresponding matrices. + +""" +function create_dynamic_model(sd, hodge)::Dict{TypedApplication, Any} + Dict{TypedApplication, Any}( + TA(*, Sort[Scalar(), Scalar()]) => 1, + TA(*, Sort[Scalar(), PrimalForm(0)]) => 1, + # Regular Hodge Stars + TA(★, Sort[PrimalForm(0)]) => Decapodes.dec_mat_hodge(0, sd, hodge), + TA(★, Sort[PrimalForm(1)]) => Decapodes.dec_mat_hodge(1, sd, hodge), + TA(★, Sort[PrimalForm(2)]) => Decapodes.dec_mat_hodge(2, sd, hodge), + + # Inverse Hodge Stars + TA(★, Sort[DualForm(0)]) => Decapodes.dec_mat_inverse_hodge(0, sd, hodge), + # TODO verify ^ why is this 1??? + TA(★, Sort[DualForm(1)]) => Decapodes.dec_pair_inv_hodge(Val{1}, sd, hodge), + # Special since Geo is a solver + TA(★, Sort[DualForm(2)]) => Decapodes.dec_mat_inverse_hodge(0, sd, hodge), + + # Differentials + TA(d, Sort[PrimalForm(0)]) => Decapodes.dec_mat_differential(0, sd), + TA(d, Sort[PrimalForm(1)]) => Decapodes.dec_mat_differential(1, sd), + + # Dual Differentials + TA(d, Sort[DualForm(0)]) => Decapodes.dec_mat_dual_differential(0, sd), + TA(d, Sort[DualForm(1)]) => Decapodes.dec_mat_dual_differential(1, sd), + + # Wedge Products + TA(∧, Sort[PrimalForm(0), PrimalForm(1)]) => Decapodes.dec_pair_wedge_product(Tuple{0,1}, sd), + TA(∧, Sort[PrimalForm(1), PrimalForm(0)]) => Decapodes.dec_pair_wedge_product(Tuple{1,0}, sd), + TA(∧, Sort[PrimalForm(0), PrimalForm(2)]) => Decapodes.dec_pair_wedge_product(Tuple{0,2}, sd), + TA(∧, Sort[PrimalForm(2), PrimalForm(0)]) => Decapodes.dec_pair_wedge_product(Tuple{2,0}, sd), + TA(∧, Sort[PrimalForm(1), PrimalForm(1)]) => Decapodes.dec_pair_wedge_product(Tuple{1,1}, sd), + + # Primal-Dual Wedge Products + TA(∧, Sort[PrimalForm(1), DualForm(1)]) => Decapodes.dec_wedge_product_pd(Tuple{1,1}, sd), + TA(∧, Sort[PrimalForm(0), DualForm(1)]) => Decapodes.dec_wedge_product_pd(Tuple{0,1}, sd), + TA(∧, Sort[PrimalForm(1), DualForm(1)]) => Decapodes.dec_wedge_product_dp(Tuple{1,1}, sd), + TA(∧, Sort[PrimalForm(1), DualForm(0)]) => Decapodes.dec_wedge_product_dp(Tuple{1,0}, sd), + + # Dual-Dual Wedge Products + # TA(∧, Sort[DualForm(1), DualForm(1)]) => Decapodes.dec_wedge_product_dd(Tuple{1,1}, sd), + TA(∧, Sort[DualForm(1), DualForm(0)]) => Decapodes.dec_wedge_product_dd(Tuple{1,0}, sd), + TA(∧, Sort[DualForm(0), DualForm(1)]) => Decapodes.dec_wedge_product_dd(Tuple{0,1}, sd), + + # # Dual-Dual Interior Products + TA(ι, Sort[DualForm(1), DualForm(1)]) => Decapodes.interior_product_dd(Tuple{1,1}, sd), + TA(ι, Sort[DualForm(1), DualForm(2)]) => Decapodes.interior_product_dd(Tuple{1,2}, sd), + + # # Dual-Dual Lie Derivatives + # :ℒ₁ => ℒ_dd(Tuple{1,1}, sd) + + # # Dual Laplacians + # :Δᵈ₀ => Δᵈ(Val{0},sd) + # :Δᵈ₁ => Δᵈ(Val{1},sd) + + # # Musical Isomorphisms + TA(♯, Sort[PrimalForm(1)]) => Decapodes.dec_♯_p(sd), # Primal(1) -> PVField + TA(♯, Sort[DualForm(1)]) => Decapodes.dec_♯_d(sd), # Dual(1) -> DVField + + TA(♭, Sort[DualVF()]) => Decapodes.dec_♭(sd), # DVField -> Primal(1) + + # # Averaging Operator + # :avg₀₁ => Decapodes.dec_avg₀₁(sd) + ) +end +# TODO can we use OrderedDict to retain our nice presentation? diff --git a/roe/src/theories/ThDEC/signature.jl b/roe/src/theories/ThDEC/signature.jl new file mode 100644 index 0000000..475883c --- /dev/null +++ b/roe/src/theories/ThDEC/signature.jl @@ -0,0 +1,182 @@ +using ...DEC: AbstractSort, SortError + +using MLStyle + +import Base: +, -, * + +# Define the sorts in your theory. +# For the DEC, we work with Scalars and Forms, graded objects which can also be primal or dual. +@data Sort begin + Scalar() + Form(dim::Int, isdual::Bool) + # Vector Field + VF(isdual::Bool) +end +export Scalar, Form + +# accessors +dim(ω::Form) = ω.dim +isdual(ω::Form) = ω.isdual + +# convenience functions +PrimalForm(i::Int) = Form(i, false) +export PrimalForm + +DualForm(i::Int) = Form(i, true) +export DualForm + +PrimalVF() = VF(false) +export PrimalVF + +DualVF() = VF(true) +export DualVF + +# show methods +show_duality(ω::Form) = isdual(ω) ? "dual" : "primal" + +function Base.show(io::IO, ω::Form) + print(io, isdual(ω) ? "DualForm($(dim(ω)))" : "PrimalForm($(dim(ω)))") +end + +## Predicates +# function isForm(g, ec::EClass) +# any(ec.nodes) do n +# h = v_head(n) +# if has_constant(g, h) +# c = get_constant(g, h) +# return c isa Form +# end +# false +# end +# end + + +# function isForm(g, ec::EClass) +# any(ec.nodes) do n +# h = v_head(n) +# if has_constant(g, h) +# c = get_constant(g, h) +# return c isa Form +# end +# false +# end +# end + +## OPERATIONS + +@nospecialize +function +(s1::Sort, s2::Sort) + @match (s1, s2) begin + (Scalar(), Scalar()) => Scalar() + (Scalar(), Form(i, isdual)) || + (Form(i, isdual), Scalar()) => Form(i, isdual) + (Form(i1, isdual1), Form(i2, isdual2)) => + if (i1 == i2) && (isdual1 == isdual2) + Form(i1, isdual1) + else + throw(SortError("Cannot add two forms of different dimensions/dualities: $((i1,isdual1)) and $((i2,isdual2))")) + end + end +end + +# Type-checking inverse of addition follows addition +-(s1::Sort, s2::Sort) = +(s1, s2) + +# TODO error for Forms + +# Negation is valid +-(s::Sort) = s + +@nospecialize +function *(s1::Sort, s2::Sort) + @match (s1, s2) begin + (Scalar(), Scalar()) => Scalar() + (Scalar(), Form(i, isdual)) || + (Form(i, isdual), Scalar()) => Form(i, isdual) + (Form(_, _), Form(_, _)) => throw(SortError("Cannot scalar multiply a form with a form. Maybe try `∧`??")) + end +end + +@nospecialize +function ∧(s1::Sort, s2::Sort) + @match (s1, s2) begin + (Form(i, isdual), Scalar()) || (Scalar(), Form(i, isdual)) => Form(i, isdual) + (Form(i1, isdual), Form(i2, isdual)) => + if i1 + i2 <= 2 + Form(i1 + i2, isdual) + else + throw(SortError("Can only take a wedge product when the dimensions of the forms add to less than 2: tried to wedge product $i1 and $i2")) + end + _ => throw(SortError("Can only take a wedge product of two forms of the same duality")) + end +end + +@nospecialize +∂ₜ(s::Sort) = s + +@nospecialize +function d(s::Sort) + @match s begin + Scalar() => throw(SortError("Cannot take exterior derivative of a scalar")) + Form(i, isdual) => + if i <= 1 + Form(i + 1, isdual) + else + throw(SortError("Cannot take exterior derivative of a n-form for n >= 1")) + end + end +end + +@nospecialize +function ★(s::Sort) + @match s begin + Scalar() => throw(SortError("Cannot take Hodge star of a scalar")) + Form(i, isdual) => Form(2 - i, !isdual) + end +end + +@nospecialize +function ι(s1::Sort, s2::Sort) + @match (s1, s2) begin + (VF(true), Form(i, true)) => PrimalForm() # wrong + (VF(true), Form(i, false)) => DualForm() + _ => throw(SortError("Can only define the discrete interior product on: + PrimalVF, DualForm(i) + DualVF(), PrimalForm(i) + .")) + end +end + +# in practice, a scalar may be treated as a constant 0-form. +function ♯(s::Sort) + @match s begin + Scalar() => PrimalVF() + Form(1, isdual) => VF(isdual) + _ => throw(SortError("Can only take ♯ to 1-forms")) + end +end +# musical isos may be defined for any combination of (primal/dual) form -> (primal/dual) vf. + +function ♭(s::Sort) + @match s begin + VF(true) => PrimalForm(1) + _ => throw(SortError("Can only apply ♭ to dual vector fields")) + end +end + +# OTHER + +function ♭♯(s::Sort) + @match s begin + Form(i, isdual) => Form(i, !isdual) + _ => throw(SortError("♭♯ is only defined on forms.")) + end +end + +# Δ = ★d⋆d, but we check signature here to throw a more helpful error +function Δ(s::Sort) + @match s begin + Form(0, isdual) => Form(0, isdual) + _ => throw(SortError("Δ is not defined for $s")) + end +end diff --git a/roe/src/theories/ThDEC/symbolicutils_overloads.jl b/roe/src/theories/ThDEC/symbolicutils_overloads.jl new file mode 100644 index 0000000..7129f01 --- /dev/null +++ b/roe/src/theories/ThDEC/symbolicutils_overloads.jl @@ -0,0 +1,33 @@ +using SymbolicUtils +using ...DEC + +unop_dec = [:∂ₜ, :d, :★, :♯, :♭, :-] +for unop in unop_dec + @eval begin + @nospecialize + function $unop( + v::SymbolicUtils.BasicSymbolic{OfSort{s}} + ) where {s} + s′ = $unop(s) + SymbolicUtils.Term{OfSort{s′}}($unop, [v]) + end + + export $unop + end +end + +binop_dec = [:+, :-, :*, :∧] +for binop in binop_dec + @eval begin + @nospecialize + function $binop( + v::SymbolicUtils.BasicSymbolic{OfSort{s1}}, + w::SymbolicUtils.BasicSymbolic{OfSort{s2}} + ) where {s1,s2} + s′ = $binop(s1, s2) + SymbolicUtils.Term{OfSort{s′}}($binop, [v, w]) + end + + export $binop + end +end diff --git a/roe/src/theories/module.jl b/roe/src/theories/module.jl new file mode 100644 index 0000000..31e01a9 --- /dev/null +++ b/roe/src/theories/module.jl @@ -0,0 +1,9 @@ +module Theories + +using Reexport + +include("ThDEC/ThDEC.jl") # the theory of the DEC + +@reexport using .ThDEC + +end diff --git a/roe/src/util/HashColor.jl b/roe/src/util/HashColor.jl new file mode 100644 index 0000000..dd4a20b --- /dev/null +++ b/roe/src/util/HashColor.jl @@ -0,0 +1,32 @@ +module HashColor +export hashcolor, hashcrayon + +using Colors +using Random +using Crayons + +randunif(rng, range) = rand(rng) * (range[2] - range[1]) + range[1] + +""" +Returns a Color generated randomly by hashing `x`. + +This uses the LCHuv space to sample uniformly across the human visual spectrum. +""" +function hashcolor(x; lightnessrange=(0.,100.), chromarange=(0.,100.), huerange=(0.,360.)) + h = hash(x) + rng = MersenneTwister(h) + l = randunif(rng, lightnessrange) + c = randunif(rng, chromarange) + h = randunif(rng, huerange) + LCHuv{Float64}(l, c, h) +end + +""" +Returns a Crayon with foreground color given by `hashcolor` +""" +function hashcrayon(x; lightnessrange=(0.,100.), chromarange=(0.,100.), huerange=(0.,360.)) + c = hashcolor(x; lightnessrange, chromarange, huerange) + Crayon(foreground=RGB24(RGB{Float64}(c)).color) +end + +end diff --git a/roe/src/util/Plotting.jl b/roe/src/util/Plotting.jl new file mode 100644 index 0000000..82ea369 --- /dev/null +++ b/roe/src/util/Plotting.jl @@ -0,0 +1,20 @@ +module Plotting + +using CairoMakie + +function save_dynamics(soln, timespan, save_file_name) + time = Observable(0.0) + h = @lift(soln($time)) + f = Figure() + ax = CairoMakie.Axis(f[1,1], title = @lift("Heat at time $($time)")) + gmsh = mesh!(ax, rect, color=h, colormap=:jet, + colorrange=extrema(soln(timespan))) + Colorbar(f[1,2], gmsh) + timestamps = range(0, timespan, step=5.0) + record(f, save_file_name, timestamps; framerate = 15) do t + time[] = t + end +end +export save_dynamics + +end diff --git a/roe/src/util/module.jl b/roe/src/util/module.jl new file mode 100644 index 0000000..bde14db --- /dev/null +++ b/roe/src/util/module.jl @@ -0,0 +1,11 @@ +module Util + +using Reexport + +include("HashColor.jl") +include("Plotting.jl") + +@reexport using .HashColor +@reexport using .Plotting + +end diff --git a/roe/src/vfield.jl b/roe/src/vfield.jl new file mode 100644 index 0000000..eb10877 --- /dev/null +++ b/roe/src/vfield.jl @@ -0,0 +1,156 @@ +using .SSAs +using MLStyle + +""" vfield :: (Decaroe -> (StateVars, ParamVars)) -> VectorFieldFunction + +Short for "vector field." Obtains tuple of root vars from a model, where the first component are state variables and the second are parameter variables. + +Example: given a diffusivity constant a, the heat equation can be written as: +``` + ∂ₜ u = a * Δ(u) +``` +would return (u, a). + +A limitation of this function can be demonstrated here: given the model + ``` + ∂ₜ a = a + b + ∂ₜ b = a + b + ``` + we would return ([a, b],). Suppose we wanted to extract terms of the form "a + b." Since the expression "a + b" is not a RootVar, the extractor would bypass it completely. +""" +function vfield(model, op_lookup::Dict{TA, Any}) + + # ::Roe + # inttialize the Roe (e-graph) + roe = Roe(DEC.ThDEC.Sort) + + # ::Tuple{Vector{Var}, Vector{Var}} + # Pass the roe into the model function, which contributes the variables (via `fresh!`) and equations (via `equate!`). Retrieve the state and parameter variables in the model. + (state_vars, param_vars) = model(roe) + + # A model is inadmissible if there is no state variables. + length(state_vars) >= 1 || throw(VFieldError()) + + # ::Vector{RootVar} + # iterate `extract!` through the state and parameter variables. + state_rootvars = extract_rootvars!(state_vars); + param_rootvars = extract_rootvars!(param_vars); + + # TODO This is currently fixed + u = :u; p = :p; du = :du; + + # ::Dict{RootVar, Tuple{Union{Expr, Symbol}, Bool}} + rv_lookup = make_rv_lookup(state_rootvars, param_rootvars, u, p); + + # ::Function + # Return a cost function whose allowed roots are the set union of the model's rootvars. + cost = derivative_cost(Set([state_rootvars; param_rootvars])) + + # ::Extractor + # Pass the Roe's E-Graph into a Metatheory Extractor. + extractor = EGraphs.Extractor(roe.graph, cost, Float64) + + # ::SSA + ssa = SSA() + + # ::Function + term_select(id) = EGraphs.find_best_node(extractor, id); + + # ::Vector{Var} + d_vars = extract_derivative_vars!(roe, ssa, state_vars, term_select); + + # ::Tuple{Vector{Expr}, Vector{Expr}} + # convert the SSA statements and derivative variables into Julia Exprs + (ssalines, derivative_stmts) = build_result_body(ssa, d_vars, du, op_lookup, rv_lookup) + + # yield the function that will be passed to a solver + eval(quote + f(du, u, p, _) = begin + $(ssalines...) + $(derivative_stmts...) + end + end) +end +export vfield + +# Build the body of the function by returning the lines of the ssas and the derivative statments. +function build_result_body(ssa, derivative_vars, du, op_lookup, rv_lookup) + + _toexpr(term) = toexpr(term, op_lookup, rv_lookup) + + ssalines = map(enumerate(ssa.statements)) do (i, stmt) + :($(_toexpr(SSAs.Var(i))) = $(_toexpr(stmt))) + end + + derivative_stmts = map(enumerate(derivative_vars)) do (i, stmt) + :($(du) .= $(_toexpr(stmt))) + end + + return (ssalines, derivative_stmts) +end + +# For normalization purposes, I factored `toexpr` out of `vfield`. However, this means the two lookup variables were no longer in scope for `toexpr`. +# +# It is possible to thread the lookups into the arguments of the `toexpr`s, +# +# ``` +# :($(toexpr(SSAs.Var(i), lookup1, lookup2)) = $(toexpr(stmt, lookup1, lookup2))) +# ``` +# but you would also need to pass the lookup arguments for the `::Var` dispatch for `toexpr`, where the variables would not be used. +# +# Then, you could simplify this but uniting the two functions and using a conditional or @match expression. Since we are traversing a Term, we could just call the function recusively, or define one @λ. +# +# but I felt this was visually too noisy in `build_result_body`. +function toexpr(expr::Union{Term, DEC.SSAs.Var}, op_lookup, rv_lookup) + λtoexpr = @λ begin + var::DEC.SSAs.Var => Symbol("tmp%$(var.idx)") + term::Term && if term.head isa Number end => term.head + # if the head of a term is a RootVar, we'll need to ensure that we can retrieve the value from a named tuple. + # if the boolean value is false, the rootvar is a state_var, otherwise it is a parameter and assumed to be + # accessed by a named tuple. + term::Term && if term.head isa RootVar end => @match rv_lookup[term.head] begin + (rv, false) => rv + (rv, true) => Expr(:ref, rv, term.head.name) + end + # This default case is Decapodes-specific. Decapode operators return a tuple of functions, so we choose the first. + term => begin + op = get(op_lookup, TA(term.head, first.(term.args))) + if op isa Tuple; op = op[1] end + Expr(:call, *, op, λtoexpr.(last.(term.args))...) + end + end + λtoexpr(expr) +end + +# map over the state_vars to apply `extract!` +function extract_derivative_vars!(roe::Roe, ssa::SSA, state_vars, term_select::Function) + map(state_vars) do v + extract!(roe.graph, ssa, (∂ₜ(v)).id, term_select) + end +end + +# given root variables, and produce a dictionary +function make_rv_lookup(state_rvs, param_rvs, state, param) + Dict{RootVar, Tuple{Union{Expr, Symbol}, Bool}}( + [ + [rv => (:($(state)), false) for rv in state_rvs]; + [rv => (:($(param)), true) for rv in param_rvs] + ] + ) +end + +# map over vars +function extract_rootvars!(vars) + map(vars) do x + rv = extract!(x) + rv isa RootVar ? rv : throw(RootVarError("All variables must be RootVars")) + end +end + +struct VFieldError <: Exception end + +Base.showerror(io::IO, e::VFieldError) = println(io, "Need at least one state variable in order to create a vector field") + +struct RootVarError <: Exception; msg::String end + +Base.showerror(io::IO, e::RootVarError) = println(io, e.msg) diff --git a/roe/test/SSAExtract.jl b/roe/test/SSAExtract.jl new file mode 100644 index 0000000..12b437b --- /dev/null +++ b/roe/test/SSAExtract.jl @@ -0,0 +1,70 @@ +# TODO under construction +module TestSSAExtract + +# AlgebraicJulia dependencies +using DEC: AbstractSort +import DEC.ThDEC + +# other dependencies +using Test +using LinearAlgebra +using Metatheory + +# Test question: SSA + +function lotka_volterra(roe) + α = fresh!(roe, Scalar(), :α) + β = fresh!(roe, Scalar(), :β) + γ = fresh!(roe, Scalar(), :γ) + + w = fresh!(roe, Scalar(), :w) + s = fresh!(roe, Scalar(), :s) + + ∂ₜ(s) ≐ α * s - β * w * s + ∂ₜ(w) ≐ - γ * w - β * w * s + + ([w, s], [α, β, γ]) +end + +# a model ThRing => GL(ℝ) is necessary here +ops = Dict{TA, Any}( + TA(-, AbstractSort[Scalar()]) => -I, + TA(-, AbstractSort[Scalar(), Scalar()]) => I, + TA(*, AbstractSort[Scalar(), Scalar()]) => I,) + +(ssa, derivative_vars) = vfield(lotka_volterra, ops) + +basicprinted(x; color=false) = sprint(show, x; context=(:color=>color)) + +@test basicprinted(ssa) == """ +SSA: + %1 = γ#2 + %2 = -(%1::Scalar(),) + %3 = w#3 + %4 = *(%2::Scalar(), %3::Scalar()) + %5 = β#1 + %6 = *(%5::Scalar(), %3::Scalar()) + %7 = s#4 + %8 = *(%6::Scalar(), %7::Scalar()) + %9 = -(%4::Scalar(), %8::Scalar()) + %10 = α#0 + %11 = *(%10::Scalar(), %7::Scalar()) + %12 = -(%11::Scalar(), %8::Scalar()) +""" + +roe = Roe() + +a = fresh!(roe, Scalar(), :a) +b = fresh!(roe, Scalar(), :b) + +ssa = SSAExtract.SSA() + +function term_select(g::EGraph, id::Id) + g[id].nodes[1] +end + +extract!(roe.graph, ssa, (a + b).id, term_select) + +ssa + +end diff --git a/roe/test/ThDEC/ThDEC.jl b/roe/test/ThDEC/ThDEC.jl new file mode 100644 index 0000000..aa69e10 --- /dev/null +++ b/roe/test/ThDEC/ThDEC.jl @@ -0,0 +1,19 @@ +# AlgebraicJulia dependencies +using DEC +import DEC.ThDEC: ∧, Δ # conflicts with CombinatorialSpaces + +# preliminary dependencies for testing +using Test +using Metatheory.EGraphs + +# test the signature +include("signature.jl") + +# test the roe_overloads +include("roe_overloads.jl") + +# test the semantics +include("semantics.jl") + +# test modeling +include("model.jl") diff --git a/roe/test/ThDEC/model.jl b/roe/test/ThDEC/model.jl new file mode 100644 index 0000000..08e998a --- /dev/null +++ b/roe/test/ThDEC/model.jl @@ -0,0 +1,76 @@ +using DEC +import DEC: Δ, ∧ + +# load other dependencies +using ComponentArrays +using CombinatorialSpaces +using GeometryBasics +using OrdinaryDiffEq +Point2D = Point2{Float64} +Point3D = Point3{Float64} + +# plotting +using CairoMakie + +## 1-D HEAT EQUATION + +# initialize the model +function heat_equation(roe) + u = fresh!(roe, PrimalForm(0), :u) + + ∂ₜ(u) ≐ Δ(u) + + ([u], []) +end + +# initialize primal and dual meshes. +rect = triangulated_grid(100, 100, 1, 1, Point3D); +d_rect = EmbeddedDeltaDualComplex2D{Bool, Float64, Point3D}(rect); +subdivide_duals!(d_rect, Circumcenter()); + +# precompute matrices from operators in the DEC theory. +op_lookup = ThDEC.create_dynamic_model(d_rect, DiagonalHodge()) + +# produce a vector field. +vf = vfield(heat_equation, op_lookup) + +U = first.(d_rect[:point]); +constants_and_parameters = () +t0 = 50.0 + +@info("Precompiling Solver") +prob = ODEProblem(vf, U, (0, t0), constants_and_parameters); +soln = solve(prob, Tsit5()); + +save_dynamics(soln, t0, "heat-1D.gif") + +## 1-D HEAT EQUATION WITH DIFFUSIVITY + +function new_heat_equation(roe) + u = fresh!(roe, PrimalForm(0), :u) + k = fresh!(roe, Scalar(), :k) + ℓ = fresh!(roe, Scalar(), :ℓ) + + ∂ₜ(u) ≐ ℓ * k * Δ(u) + + ([u], [k, ℓ]) +end + +# we can reuse the mesh and operator lookup +vf = vfield(new_heat_equation, op_lookup) + +# we can reuse the initial condition U. However we need to specify the diffusivity constant +constants_and_parameters = ComponentArray(k=0.25,ℓ=2,); + +# this is a shim +DEC.k = :k; DEC.ℓ = :ℓ; + +# Let's set the time +t0 = 500 + +@info("Precompiling solver") +prob = ODEProblem(vf, U, (0, t0), constants_and_parameters); + +soln = solve(prob, Tsit5()); + +save_dynamics(soln, t0, "heat-1D-scalar.gif") diff --git a/roe/test/ThDEC/roe.jl b/roe/test/ThDEC/roe.jl new file mode 100644 index 0000000..f1a04a4 --- /dev/null +++ b/roe/test/ThDEC/roe.jl @@ -0,0 +1,60 @@ +module TestRoe + +using Test +using Metatheory.EGraphs + +# Test question: are function calls in our theory both idempotent and correctly typing expressions? + +# Instantiate a new Roe with two variables of type Var{Scalar} +roe = Roe() +a = fresh!(roe, Scalar(), :a) +b = fresh!(roe, Scalar(), :b) + +# Write the same expresison twice but with different variable bindings. We expect that each `+` dispatches its Var{S} method defined in Roe/RoeFunctions.jl, which adds a new call to the egraph. +x = a + b +y = a + b + +# We expect that `+` is idempotent; addcall! checks if the + call is already present in the egraph with the two ids for `a` and `b`. +@test x == y + +# We also check that the type of (a+b) is a Scalar. +@test roe.graph[(a+b).id].data == Scalar() + +# Test question: + +# Now we define two primal forms. +ω = fresh!(roe, PrimalForm(1), :ω) +η = fresh!(roe, PrimalForm(0), :η) + +# Is the wedge product of a 0-form and 1-form a 1-form? +@test ω ∧ η isa DEC.Var{PrimalForm(1)} + +# Is the addcall! function idempotent? +@test ω ∧ η == ω ∧ η + +@test_throws SortError x ≐ ω + +# Assert that ω is the same as the expression ω∧η +ω ≐ (ω ∧ η) + +# Test question: can we extract a term from the e-graph? + +# Assert to the egraph that ∂ₜ(a) is 3*a + 5 +∂ₜ(a) ≐ 3 * a + 5 + +EGraphs.extract!(∂ₜ(a), DEC.derivative_cost([DEC.extract!(a)])) + +# Test question: given a model with a partial derivative defined by two expressions with the same astsize, which expression is extracted? + +function transitivity(roe) + w = fresh!(roe, Scalar(), :w) + ∂ₜ(w) ≐ 1 * w + ∂ₜ(w) ≐ 2 * w + w +end +_w = transitivity(roe) +# picks whichever expression it happens to visit first +EGraphs.extract!((∂ₜ(_w)), DEC.derivative_cost([DEC.extract!(_w)])) + + +end diff --git a/roe/test/ThDEC/semantics.jl b/roe/test/ThDEC/semantics.jl new file mode 100644 index 0000000..e69de29 diff --git a/roe/test/ThDEC/signature.jl b/roe/test/ThDEC/signature.jl new file mode 100644 index 0000000..8cb9183 --- /dev/null +++ b/roe/test/ThDEC/signature.jl @@ -0,0 +1,40 @@ +# ## SIGNATURE TESTS + +# Addition +@test Scalar() + Scalar() == Scalar() +@test Scalar() + PrimalForm(1) == PrimalForm(1) +@test PrimalForm(2) + Scalar() == PrimalForm(2) +@test_throws SortError PrimalForm(1) + PrimalForm(2) + +# Negation and Subtraction +@test -Scalar() == Scalar() +@test Scalar() - Scalar() == Scalar() + +# Scalar Multiplication +@test Scalar() * Scalar() == Scalar() +@test Scalar() * PrimalForm(1) == PrimalForm(1) +@test PrimalForm(2) * Scalar() == PrimalForm(2) +@test_throws SortError PrimalForm(2) * PrimalForm(1) + +# Exterior Product +@test PrimalForm(1) ∧ PrimalForm(1) == PrimalForm(2) +@test PrimalForm(1) ∧ Scalar() == PrimalForm(1) + +@test_throws SortError PrimalForm(1) ∧ DualForm(1) +@test_throws SortError PrimalForm(2) ∧ PrimalForm(1) + +# Time derivative +@test ∂ₜ(Scalar()) == Scalar() +@test ∂ₜ(PrimalForm(1)) == PrimalForm(1) +@test ∂ₜ(DualForm(0)) == DualForm(0) + +# Derivative +@test_throws SortError d(Scalar()) +@test d(PrimalForm(1)) == PrimalForm(2) +@test d(DualForm(0)) == DualForm(1) + +# Hodge star +@test_throws SortError ★(Scalar()) +@test ★(PrimalForm(1)) == DualForm(1) +@test ★(DualForm(0)) == PrimalForm(2) + diff --git a/roe/test/runtests.jl b/roe/test/runtests.jl new file mode 100644 index 0000000..d5ad61d --- /dev/null +++ b/roe/test/runtests.jl @@ -0,0 +1,17 @@ +using Test + +@testset "Signature" begin + include("Signature.jl") +end + +@testset "SSA Extraction" begin + include("SSAExtract.jl") +end + +@testset "Roe Utilities" begin + include("Roe.jl") +end + +@testset "ThDEC" begin + include("ThDEC/tests.jl") +end