Skip to content

Commit

Permalink
ssa_extract and analysis working
Browse files Browse the repository at this point in the history
  • Loading branch information
olynch committed Jul 19, 2024
1 parent 0891f7f commit d260ad9
Show file tree
Hide file tree
Showing 7 changed files with 225 additions and 42 deletions.
2 changes: 2 additions & 0 deletions DEC/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ Crayons = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
Metatheory = "e9d8d322-4543-424a-9be4-0cc815abe26c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
StructEquality = "6ec83bb0-ed9f-11e9-3b4c-2b04cb4e219c"

[compat]
Colors = "0.12.11"
Crayons = "4.1.1"
MLStyle = "0.4.17"
Random = "1.11.0"
Reexport = "1.2.2"
StructEquality = "2.1.0"
100 changes: 74 additions & 26 deletions DEC/src/DEC.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module DEC
using MLStyle
using Reexport
using StructEquality
import Metatheory
using Metatheory: EGraph, EGraphs, Id, VECEXPR_FLAG_ISCALL, VECEXPR_FLAG_ISTREE, VECEXPR_META_LENGTH
Expand All @@ -9,12 +10,23 @@ import Base: +, -
import Base: *

include("HashColor.jl")
include("SSAExtract.jl")

@reexport using .SSAExtract

@data Sort begin
Scalar()
Form(dim::Int, isdual::Bool)
end
export Scalar, Form, DualForm
export Scalar, Form

duality(f::Form) = f.isdual ? "dual" : "primal"

PrimalForm(i::Int) = Form(i, false)
export PrimalForm

DualForm(i::Int) = Form(i, true)
export DualForm

struct SortError <: Exception
message::String
Expand All @@ -24,12 +36,12 @@ end
function +(s1::Sort, s2::Sort)
@match (s1, s2) begin
(Scalar(), Scalar()) => Scalar()
(Scalar(), Form(d, isdual)) || (Form(d, isdual), Scalar()) => Form(d, isdual)
(Form(d, ϖ), Form(d′, ϖ′)) =>
if (d == d′) && (ϖ == ϖ′)
Form(d, ϖ)
(Scalar(), Form(i, isdual)) || (Form(i, isdual), Scalar()) => Form(i, isdual)
(Form(i1, isdual1), Form(i2, isdual2)) =>
if (i1 == i2) && (isdual1 == isdual2)
Form(i1, isdual1)
else
throw(SortError("Cannot add two forms of different dimensions/dualities $(d,ϖ) and $(d′,ϖ′)"))
throw(SortError("Cannot add two forms of different dimensions/dualities: $((i1,isdual1)) and $((i2,isdual2))"))
end
end
end
Expand All @@ -41,21 +53,25 @@ end
@nospecialize
function *(s1::Sort, s2::Sort)
@match (s1, s2) begin
(Scalar(), Scalar()) => Scalar()
(Scalar(), Form(d, ϖ)) || (Form(d, ϖ), Scalar()) => Form(d)
(Form(_, _), Form(_, _)) => throw(SortError("Cannot multiply two forms. Maybe try `∧`??"))
(Scalar(), Scalar()) => Scalar()
(Scalar(), Form(i, isdual)) || (Form(i, isdual), Scalar()) => Form(i, isdual)
(Form(_, _), Form(_, _)) => throw(SortError("Cannot scalar multiply a form with a form. Maybe try `∧`??"))
end
end

@nospecialize
function (s1::Sort, s2::Sort)
@match (s1, s2) begin
(_, Scalar()) || (Scalar(), _) => throw(SortError("Cannot take a wedge product with a scalar"))
(Form(d, ϖ), Form(d′, ϖ)) =>
if d + d′ <= 2l
Form(d + d′)
(Form(i1, isdual1), Form(i2, isdual2)) =>
if isdual1 == isdual2
if i1 + i2 <= 2
Form(i1 + i2, isdual1)
else
throw(SortError("Can only take a wedge product when the dimensions of the forms add to less than 2: tried to wedge product $i1 and $i2"))
end
else
throw(SortError("Can only take a wedge product when the dimensions of the forms add to less than 2: tried to wedge product $d and $(d′)"))
throw(SortError("Cannot wedge two forms of different dualities: attempted to wedge $(duality(s1)) and $(duality(s2))"))
end
end
end
Expand All @@ -67,15 +83,22 @@ end
function d(s::Sort)
@match s begin
Scalar() => throw(SortError("Cannot take exterior derivative of a scalar"))
Form(d) =>
if d <= 1
Form(d + 1)
Form(i, isdual) =>
if i <= 1
Form(i + 1, isdual)
else
throw(SortError("Cannot take exterior derivative of a n-form for n >= 1"))
end
end
end

function (s::Sort)
@match s begin
Scalar() => throw(SortError("Cannot take Hodge star of a scalar"))
Form(i, isdual) => Form(2 - i, !isdual)
end
end

@struct_hash_equal struct RootVar
name::Symbol
idx::Int
Expand Down Expand Up @@ -105,6 +128,15 @@ struct Decapode
end
end

function EGraphs.make(g::EGraph{Expr, Sort}, n::Metatheory.VecExpr)
op = EGraphs.get_constant(g, Metatheory.v_head(n))
if op isa RootVar
op.sort
else
op((g[arg].data for arg in Metatheory.v_children(n))...)
end
end

struct Var{S}
pode::Decapode
id::Id
Expand Down Expand Up @@ -134,22 +166,22 @@ function Base.show(io::IO, v::Var)
print(io, getexpr(v))
end

function fresh!(d::Decapode, sort::Sort, name::Symbol)
v = RootVar(name, length(d.variables), sort)
push!(d.variables, v)
function fresh!(pode::Decapode, sort::Sort, name::Symbol)
v = RootVar(name, length(pode.variables), sort)
push!(pode.variables, v)
n = Metatheory.v_new(0)
Metatheory.v_set_head!(n, EGraphs.add_constant!(d.graph, v))
Metatheory.v_set_head!(n, EGraphs.add_constant!(pode.graph, v))
Metatheory.v_set_signature!(n, hash(Metatheory.maybe_quote_operation(v), hash(0)))
Var{sort}(d, EGraphs.add!(d.graph, n, false))
Var{sort}(pode, EGraphs.add!(pode.graph, n, false))
end

@nospecialize
function inject_number!(d::Decapode, x::Number)
function inject_number!(pode::Decapode, x::Number)
x = Float64(x)
n = Metatheory.v_new(0)
Metatheory.v_set_head!(n, EGraphs.add_constant!(d.graph, x))
Metatheory.v_set_head!(n, EGraphs.add_constant!(pode.graph, x))
Metatheory.v_set_signature!(n, hash(Metatheory.maybe_quote_operation(x), hash(0)))
Var{Scalar()}(d, EGraphs.add!(d.graph, n, false))
Var{Scalar()}(pode, EGraphs.add!(pode.graph, n, false))
end

@nospecialize
Expand Down Expand Up @@ -226,6 +258,15 @@ function d(v::Var{s}) where {s}
Var{s′}(v.pode, addcall!(v.pode.graph, d, (v.id,)))
end


@nospecialize
function (v::Var{s}) where {s}
s′ = (s)
Var{s′}(v.pode, addcall!(v.pode.graph, ★, (v.id,)))
end

Δ(v::Var{PrimalForm(0)}) = (d((d(v))))

function equate!(v1::Var{s1}, v2::Var{s2}) where {s1, s2}
(s1 == s2) || throw(SortError("Cannot equate variables of a different sort: attempted to equate $s1 with $s2"))
v1.pode === v2.pode || error("Cannot equate variables from different graphs")
Expand All @@ -245,6 +286,11 @@ function derivative_cost(allowed_roots)
end
end

struct TypedApplication
fn::Function
sorts::Vector{Sort}
end

""" vfield :: (Decapode -> (StateVars, ParamVars)) -> VectorFieldFunction
Short for "vector field." Obtains tuple of root vars from a model, where the first component are state variables and the second are parameter variables.
Expand All @@ -263,7 +309,7 @@ A limitation of this function can be demonstrated here: given the model
we would return ([a, b],). Suppose we wanted to extract terms of the form "a + b." Since the expression "a + b" is not a RootVar,
the extractor would bypass it completely.
"""
function vfield(model)
function vfield(model, matrices::Dict{TypedApplication, Any})
pode = Decapode()
(state_vars, param_vars) = model(pode)
length(state_vars) >= 1 || error("need at least one state variable in order to create vector field")
Expand All @@ -274,7 +320,7 @@ function vfield(model)
end
param_rootvars = map(param_vars) do p
rv = extract!(p)
rv isa RootVar || error("all state variables must be RootVars")
rv isa RootVar || error("all param variables must be RootVars")
rv
end
cost = derivative_cost(Set([state_rootvars; param_rootvars]))
Expand All @@ -296,6 +342,8 @@ function vfield(model)
:(@inbounds $(du)[$i] = $(replace_rootvars(e, rootvar_lookup)))
end



eval(
quote
($du, $u, $p, _) -> begin
Expand Down
4 changes: 4 additions & 0 deletions DEC/src/OperatorStorage.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
struct OperatorStorage
hodge::Tuple{}
end

Empty file added DEC/src/Roe.jl
Empty file.
94 changes: 94 additions & 0 deletions DEC/src/SSAExtract.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
module SSAExtract

using MLStyle
using Metatheory.EGraphs

struct SSAVar
idx::Int
end

function Base.show(io::IO, v::SSAVar)
print(io, "\$", v.idx)
end

@data SSAExpr begin
Constant(x::Any)
App(fn::Any, args::Vector{SSAVar})
end

function Base.show(io::IO, e::SSAExpr)
@match e begin
Constant(x) => show(io, x)
App(fn, args) => begin
print(io, fn)
print(io, Expr(:tuple, args...))
end
end
end

struct SSA
assignment_lookup::Dict{Id, SSAVar}
statements::Vector{SSAExpr}
function SSA()
new(Dict{Id, SSAVar}(), SSAExpr[])
end
end

function Base.show(io::IO, ssa::SSA)
println(io, "SSA: ")
for (i, expr) in enumerate(ssa.statements)
println(io, " ", SSAVar(i), " = ", expr)
end
end

function add_stmt!(ssa::SSA, id::Id, expr::SSAExpr)
push!(ssa.statements, expr)
v = SSAVar(length(ssa.statements))
ssa.assignment_lookup[id] = v
v
end

function hasid(ssa::SSA, id::Id)
haskey(ssa.assignment_lookup, id)
end

function getvar(ssa::SSA, id::Id)
ssa.assignment_lookup[id]
end

"""
extract_ssa!(g::EGraph, ssa::SSA, id::Id, term_select, make_expr)::SSAVar
This function adds (recursively) the necessary lines to the SSA in order to
compute a value for `id`, and then returns the SSAVar that the value for `id`
will be assigned to.
The closure parameters control the behavior of this function.
term_select(g::EGraph, id::Id)::VecExpr
This closure selects, given an id in an EGraph, the term that we want to use in
order to compute a value for that id
make_expr(head::Any, args::Vector{Tuple{Sort, SSAVar}})::SSAExpr
This closure produces an SSAExpr by selecting a head based on the head of the
term in the e-graph and the sorts of the arguments.
"""
function extract_ssa!(g::EGraph, ssa::SSA, id::Id, term_select, make_expr)::SSAVar
if hasid(ssa, id)
return getvar(ssa, id)
end
term = term_select(g, id)
args = map(EGraphs.v_children(term)) do arg
(g[arg].data, extract_ssa!(g, ssa, arg, term_select, make_expr))
end
add_stmt!(ssa, id, make_expr(EGraphs.get_constant(g, EGraphs.v_head(term)), args))
end
export extract_ssa!

function extract_ssa!(g::EGraph, id::Id; ssa::SSA=SSA(), term_select::Function=best_term)
extract_ssa!(g, ssa, id, term_select, make_expr)
end

end
Loading

0 comments on commit d260ad9

Please sign in to comment.