From d260ad98922b81fc509530bd7ee10df95e003529 Mon Sep 17 00:00:00 2001 From: Owen Lynch Date: Fri, 19 Jul 2024 16:13:41 -0700 Subject: [PATCH] ssa_extract and analysis working --- DEC/Project.toml | 2 + DEC/src/DEC.jl | 100 +++++++++++++++++++++++++++---------- DEC/src/OperatorStorage.jl | 4 ++ DEC/src/Roe.jl | 0 DEC/src/SSAExtract.jl | 94 ++++++++++++++++++++++++++++++++++ DEC/tests/DEC.jl | 39 +++++++++------ DEC/tests/SSAExtract.jl | 28 +++++++++++ 7 files changed, 225 insertions(+), 42 deletions(-) create mode 100644 DEC/src/OperatorStorage.jl create mode 100644 DEC/src/Roe.jl create mode 100644 DEC/src/SSAExtract.jl create mode 100644 DEC/tests/SSAExtract.jl diff --git a/DEC/Project.toml b/DEC/Project.toml index 311c7c35..63e3d732 100644 --- a/DEC/Project.toml +++ b/DEC/Project.toml @@ -9,6 +9,7 @@ Crayons = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" Metatheory = "e9d8d322-4543-424a-9be4-0cc815abe26c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Reexport = "189a3867-3050-52da-a836-e630ba90ab69" StructEquality = "6ec83bb0-ed9f-11e9-3b4c-2b04cb4e219c" [compat] @@ -16,4 +17,5 @@ Colors = "0.12.11" Crayons = "4.1.1" MLStyle = "0.4.17" Random = "1.11.0" +Reexport = "1.2.2" StructEquality = "2.1.0" diff --git a/DEC/src/DEC.jl b/DEC/src/DEC.jl index 984bab30..496437a7 100644 --- a/DEC/src/DEC.jl +++ b/DEC/src/DEC.jl @@ -1,5 +1,6 @@ module DEC using MLStyle +using Reexport using StructEquality import Metatheory using Metatheory: EGraph, EGraphs, Id, VECEXPR_FLAG_ISCALL, VECEXPR_FLAG_ISTREE, VECEXPR_META_LENGTH @@ -9,12 +10,23 @@ import Base: +, - import Base: * include("HashColor.jl") +include("SSAExtract.jl") + +@reexport using .SSAExtract @data Sort begin Scalar() Form(dim::Int, isdual::Bool) end -export Scalar, Form, DualForm +export Scalar, Form + +duality(f::Form) = f.isdual ? "dual" : "primal" + +PrimalForm(i::Int) = Form(i, false) +export PrimalForm + +DualForm(i::Int) = Form(i, true) +export DualForm struct SortError <: Exception message::String @@ -24,12 +36,12 @@ end function +(s1::Sort, s2::Sort) @match (s1, s2) begin (Scalar(), Scalar()) => Scalar() - (Scalar(), Form(d, isdual)) || (Form(d, isdual), Scalar()) => Form(d, isdual) - (Form(d, ϖ), Form(d′, ϖ′)) => - if (d == d′) && (ϖ == ϖ′) - Form(d, ϖ) + (Scalar(), Form(i, isdual)) || (Form(i, isdual), Scalar()) => Form(i, isdual) + (Form(i1, isdual1), Form(i2, isdual2)) => + if (i1 == i2) && (isdual1 == isdual2) + Form(i1, isdual1) else - throw(SortError("Cannot add two forms of different dimensions/dualities $(d,ϖ) and $(d′,ϖ′)")) + throw(SortError("Cannot add two forms of different dimensions/dualities: $((i1,isdual1)) and $((i2,isdual2))")) end end end @@ -41,9 +53,9 @@ end @nospecialize function *(s1::Sort, s2::Sort) @match (s1, s2) begin - (Scalar(), Scalar()) => Scalar() - (Scalar(), Form(d, ϖ)) || (Form(d, ϖ), Scalar()) => Form(d) - (Form(_, _), Form(_, _)) => throw(SortError("Cannot multiply two forms. Maybe try `∧`??")) + (Scalar(), Scalar()) => Scalar() + (Scalar(), Form(i, isdual)) || (Form(i, isdual), Scalar()) => Form(i, isdual) + (Form(_, _), Form(_, _)) => throw(SortError("Cannot scalar multiply a form with a form. Maybe try `∧`??")) end end @@ -51,11 +63,15 @@ end function ∧(s1::Sort, s2::Sort) @match (s1, s2) begin (_, Scalar()) || (Scalar(), _) => throw(SortError("Cannot take a wedge product with a scalar")) - (Form(d, ϖ), Form(d′, ϖ)) => - if d + d′ <= 2l - Form(d + d′) + (Form(i1, isdual1), Form(i2, isdual2)) => + if isdual1 == isdual2 + if i1 + i2 <= 2 + Form(i1 + i2, isdual1) + else + throw(SortError("Can only take a wedge product when the dimensions of the forms add to less than 2: tried to wedge product $i1 and $i2")) + end else - throw(SortError("Can only take a wedge product when the dimensions of the forms add to less than 2: tried to wedge product $d and $(d′)")) + throw(SortError("Cannot wedge two forms of different dualities: attempted to wedge $(duality(s1)) and $(duality(s2))")) end end end @@ -67,15 +83,22 @@ end function d(s::Sort) @match s begin Scalar() => throw(SortError("Cannot take exterior derivative of a scalar")) - Form(d) => - if d <= 1 - Form(d + 1) + Form(i, isdual) => + if i <= 1 + Form(i + 1, isdual) else throw(SortError("Cannot take exterior derivative of a n-form for n >= 1")) end end end +function ★(s::Sort) + @match s begin + Scalar() => throw(SortError("Cannot take Hodge star of a scalar")) + Form(i, isdual) => Form(2 - i, !isdual) + end +end + @struct_hash_equal struct RootVar name::Symbol idx::Int @@ -105,6 +128,15 @@ struct Decapode end end +function EGraphs.make(g::EGraph{Expr, Sort}, n::Metatheory.VecExpr) + op = EGraphs.get_constant(g, Metatheory.v_head(n)) + if op isa RootVar + op.sort + else + op((g[arg].data for arg in Metatheory.v_children(n))...) + end +end + struct Var{S} pode::Decapode id::Id @@ -134,22 +166,22 @@ function Base.show(io::IO, v::Var) print(io, getexpr(v)) end -function fresh!(d::Decapode, sort::Sort, name::Symbol) - v = RootVar(name, length(d.variables), sort) - push!(d.variables, v) +function fresh!(pode::Decapode, sort::Sort, name::Symbol) + v = RootVar(name, length(pode.variables), sort) + push!(pode.variables, v) n = Metatheory.v_new(0) - Metatheory.v_set_head!(n, EGraphs.add_constant!(d.graph, v)) + Metatheory.v_set_head!(n, EGraphs.add_constant!(pode.graph, v)) Metatheory.v_set_signature!(n, hash(Metatheory.maybe_quote_operation(v), hash(0))) - Var{sort}(d, EGraphs.add!(d.graph, n, false)) + Var{sort}(pode, EGraphs.add!(pode.graph, n, false)) end @nospecialize -function inject_number!(d::Decapode, x::Number) +function inject_number!(pode::Decapode, x::Number) x = Float64(x) n = Metatheory.v_new(0) - Metatheory.v_set_head!(n, EGraphs.add_constant!(d.graph, x)) + Metatheory.v_set_head!(n, EGraphs.add_constant!(pode.graph, x)) Metatheory.v_set_signature!(n, hash(Metatheory.maybe_quote_operation(x), hash(0))) - Var{Scalar()}(d, EGraphs.add!(d.graph, n, false)) + Var{Scalar()}(pode, EGraphs.add!(pode.graph, n, false)) end @nospecialize @@ -226,6 +258,15 @@ function d(v::Var{s}) where {s} Var{s′}(v.pode, addcall!(v.pode.graph, d, (v.id,))) end + +@nospecialize +function ★(v::Var{s}) where {s} + s′ = ★(s) + Var{s′}(v.pode, addcall!(v.pode.graph, ★, (v.id,))) +end + +Δ(v::Var{PrimalForm(0)}) = ★(d(★(d(v)))) + function equate!(v1::Var{s1}, v2::Var{s2}) where {s1, s2} (s1 == s2) || throw(SortError("Cannot equate variables of a different sort: attempted to equate $s1 with $s2")) v1.pode === v2.pode || error("Cannot equate variables from different graphs") @@ -245,6 +286,11 @@ function derivative_cost(allowed_roots) end end +struct TypedApplication + fn::Function + sorts::Vector{Sort} +end + """ vfield :: (Decapode -> (StateVars, ParamVars)) -> VectorFieldFunction Short for "vector field." Obtains tuple of root vars from a model, where the first component are state variables and the second are parameter variables. @@ -263,7 +309,7 @@ A limitation of this function can be demonstrated here: given the model we would return ([a, b],). Suppose we wanted to extract terms of the form "a + b." Since the expression "a + b" is not a RootVar, the extractor would bypass it completely. """ -function vfield(model) +function vfield(model, matrices::Dict{TypedApplication, Any}) pode = Decapode() (state_vars, param_vars) = model(pode) length(state_vars) >= 1 || error("need at least one state variable in order to create vector field") @@ -274,7 +320,7 @@ function vfield(model) end param_rootvars = map(param_vars) do p rv = extract!(p) - rv isa RootVar || error("all state variables must be RootVars") + rv isa RootVar || error("all param variables must be RootVars") rv end cost = derivative_cost(Set([state_rootvars; param_rootvars])) @@ -296,6 +342,8 @@ function vfield(model) :(@inbounds $(du)[$i] = $(replace_rootvars(e, rootvar_lookup))) end + + eval( quote ($du, $u, $p, _) -> begin diff --git a/DEC/src/OperatorStorage.jl b/DEC/src/OperatorStorage.jl new file mode 100644 index 00000000..eeb9000b --- /dev/null +++ b/DEC/src/OperatorStorage.jl @@ -0,0 +1,4 @@ +struct OperatorStorage + hodge::Tuple{} +end + diff --git a/DEC/src/Roe.jl b/DEC/src/Roe.jl new file mode 100644 index 00000000..e69de29b diff --git a/DEC/src/SSAExtract.jl b/DEC/src/SSAExtract.jl new file mode 100644 index 00000000..eb78ac0b --- /dev/null +++ b/DEC/src/SSAExtract.jl @@ -0,0 +1,94 @@ +module SSAExtract + +using MLStyle +using Metatheory.EGraphs + +struct SSAVar + idx::Int +end + +function Base.show(io::IO, v::SSAVar) + print(io, "\$", v.idx) +end + +@data SSAExpr begin + Constant(x::Any) + App(fn::Any, args::Vector{SSAVar}) +end + +function Base.show(io::IO, e::SSAExpr) + @match e begin + Constant(x) => show(io, x) + App(fn, args) => begin + print(io, fn) + print(io, Expr(:tuple, args...)) + end + end +end + +struct SSA + assignment_lookup::Dict{Id, SSAVar} + statements::Vector{SSAExpr} + function SSA() + new(Dict{Id, SSAVar}(), SSAExpr[]) + end +end + +function Base.show(io::IO, ssa::SSA) + println(io, "SSA: ") + for (i, expr) in enumerate(ssa.statements) + println(io, " ", SSAVar(i), " = ", expr) + end +end + +function add_stmt!(ssa::SSA, id::Id, expr::SSAExpr) + push!(ssa.statements, expr) + v = SSAVar(length(ssa.statements)) + ssa.assignment_lookup[id] = v + v +end + +function hasid(ssa::SSA, id::Id) + haskey(ssa.assignment_lookup, id) +end + +function getvar(ssa::SSA, id::Id) + ssa.assignment_lookup[id] +end + +""" + extract_ssa!(g::EGraph, ssa::SSA, id::Id, term_select, make_expr)::SSAVar + +This function adds (recursively) the necessary lines to the SSA in order to +compute a value for `id`, and then returns the SSAVar that the value for `id` +will be assigned to. + +The closure parameters control the behavior of this function. + + term_select(g::EGraph, id::Id)::VecExpr + +This closure selects, given an id in an EGraph, the term that we want to use in +order to compute a value for that id + + make_expr(head::Any, args::Vector{Tuple{Sort, SSAVar}})::SSAExpr + +This closure produces an SSAExpr by selecting a head based on the head of the +term in the e-graph and the sorts of the arguments. +""" +function extract_ssa!(g::EGraph, ssa::SSA, id::Id, term_select, make_expr)::SSAVar + if hasid(ssa, id) + return getvar(ssa, id) + end + term = term_select(g, id) + args = map(EGraphs.v_children(term)) do arg + (g[arg].data, extract_ssa!(g, ssa, arg, term_select, make_expr)) + end + add_stmt!(ssa, id, make_expr(EGraphs.get_constant(g, EGraphs.v_head(term)), args)) +end +export extract_ssa! + +function extract_ssa!(g::EGraph, id::Id; ssa::SSA=SSA(), term_select::Function=best_term) + extract_ssa!(g, ssa, id, term_select, make_expr) +end + +end \ No newline at end of file diff --git a/DEC/tests/DEC.jl b/DEC/tests/DEC.jl index c39b05a3..7d130511 100644 --- a/DEC/tests/DEC.jl +++ b/DEC/tests/DEC.jl @@ -1,23 +1,23 @@ module TestDEC using DEC -using DEC: Decapode, SortError, d, fresh!, ∂ₜ, ∧, ≐ +using DEC: Decapode, SortError, d, fresh!, ∂ₜ, ∧, Δ, ≐ using Test using Metatheory.EGraphs @test Scalar() + Scalar() == Scalar() -@test Scalar() + Form(1) == Form(1) -@test Form(2) + Scalar() == Form(2) -@test_throws SortError Form(1) + Form(2) +@test Scalar() + PrimalForm(1) == PrimalForm(1) +@test PrimalForm(2) + Scalar() == PrimalForm(2) +@test_throws SortError PrimalForm(1) + PrimalForm(2) # Scalar Multiplication @test Scalar() * Scalar() == Scalar() -@test Scalar() * Form(1) == Form(1) -@test Form(2) * Scalar() == Form(2) -@test_throws SortError Form(2) * Form(1) +@test Scalar() * PrimalForm(1) == PrimalForm(1) +@test PrimalForm(2) * Scalar() == PrimalForm(2) +@test_throws SortError PrimalForm(2) * PrimalForm(1) # Exterior Product -@test Form(1) ∧ Form(1) == Form(2) +@test PrimalForm(1) ∧ PrimalForm(1) == PrimalForm(2) pode = Decapode() @@ -29,10 +29,10 @@ y = a + b @test x == y -ω = fresh!(pode, Form(1), :ω) -η = fresh!(pode, Form(0), :η) +ω = fresh!(pode, PrimalForm(1), :ω) +η = fresh!(pode, PrimalForm(0), :η) -@test ω ∧ η isa DEC.Var{Form(1)} +@test ω ∧ η isa DEC.Var{PrimalForm(1)} @test ω ∧ η == ω ∧ η @test_throws SortError x ≐ ω @@ -41,7 +41,7 @@ y = a + b ∂ₜ(a) ≐ 3 * a + 5 -EGraphs.extract!(pode.graph, DEC.noderivcost, (∂ₜ(a)).id) +EGraphs.extract!(∂ₜ(a), DEC.derivative_cost([DEC.extract!(a)])) function lotka_volterra(pode) α = fresh!(pode, Scalar(), :α) @@ -59,8 +59,6 @@ end f = DEC.vfield(lotka_volterra) -EGraphs.extract!(pode.graph, DEC.noderivcost, (∂ₜ(w)).id) - function transitivity(pode) w = fresh!(pode, Scalar(), :w) ∂ₜ(w) ≐ 1 * w @@ -69,8 +67,17 @@ function transitivity(pode) end _w = transitivity(pode) # picks whichever expression it happens to visit first -EGraphs.extract!(pode.graph, DEC.noderivcost, (∂ₜ(_w)).id) +EGraphs.extract!((∂ₜ(_w)), DEC.derivative_cost([DEC.extract!(_w)])) + +function heat_equation(pode) + u = fresh!(pode, PrimalForm(0), :u) + + ∂ₜ(u) ≐ Δ(u) + + ([u], []) +end + +f = DEC.vfield(heat_equation) -EGraphs.extract!(pode.graph, EGraphs.astsize, w.id) end \ No newline at end of file diff --git a/DEC/tests/SSAExtract.jl b/DEC/tests/SSAExtract.jl new file mode 100644 index 00000000..f777881f --- /dev/null +++ b/DEC/tests/SSAExtract.jl @@ -0,0 +1,28 @@ +module TestSSAExtract + +using Test +using Metatheory +using DEC + +pode = Decapode() + +a = fresh!(pode, Scalar(), :a) +b = fresh!(pode, Scalar(), :b) + +ssa = SSAExtract.SSA() + +function term_select(g::EGraph, id::Id) + g[id].nodes[1] +end + +function make_expr(head, args) + if length(args) == 0 + SSAExtract.Constant(head) + else + SSAExtract.App(head, last.(args)) + end +end + +extract_ssa!(pode.graph, ssa, (a + b).id, term_select, make_expr) + +end \ No newline at end of file