Skip to content

Commit

Permalink
egraph
Browse files Browse the repository at this point in the history
  • Loading branch information
Kris Brown committed Oct 9, 2023
1 parent 9101271 commit 6b57b41
Show file tree
Hide file tree
Showing 9 changed files with 783 additions and 1 deletion.
8 changes: 8 additions & 0 deletions src/syntax/EGraphs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# E-graph support. This module only declares the exported names and pulls in
# the implementation from `egraphs/east.jl`.
module EGraphs

# Exported names: the e-graph data types and the mutating operations on them.
export EGraph, ETerm, EClass, EType, EConstant, EId,
add!, merge!, rebuild!

# The implementation file is included directly into this module's namespace.
include("egraphs/east.jl")

end # module
2 changes: 1 addition & 1 deletion src/syntax/GATs.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module GATs
export Constant, AlgTerm, AlgType, AlgAST,
TypeScope, TypeCtx, AlgSort, AlgSorts,
TypeScope, TypeCtx, AlgSort, AlgSorts, MethodApp,
AlgDeclaration, AlgTermConstructor, AlgTypeConstructor, AlgAccessor, AlgAxiom,
sortsignature, getdecl,
GATSegment, GAT, Presentation, gettheory, gettypecontext, allmethods, resolvemethod,
Expand Down
45 changes: 45 additions & 0 deletions src/syntax/egraphs/algebraic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""
This module contains facilities for working with *algebraic* theories, i.e.
theories where none of the type constructors have arguments.
Type inference and checking are much easier for such theories.
"""

using ...Syntax

"""
Returns whether a theory is algebraic
"""
function is_algebraic(theory::Theory)::Bool
    # Only judgments whose head is a type constructor are inspected; the theory
    # is algebraic exactly when each of those has an empty argument list.
    return all(theory.judgments) do judgment
        head = judgment.head
        !(head isa TypCon) || length(head.args) == 0
    end
end

"""
Infer the typ of a trm in an algebraic theory and context.
Throw an error if type cannot be inferred.
"""
function typ_infer(theory::Theory, t::Trm; context::Context = Context())
    # A head that points into the context is a variable: its typ is stored as
    # the second component of the matching context entry.
    if iscontext(t.head)
        return context.ctx[index(t.head)][2]
    end

    # Otherwise the head must name a term constructor judgment of the theory.
    j = theory.judgments[index(t.head)]
    j.head isa TrmCon || error("head of $t must be a term constructor")
    args = t.args
    length(args) == length(j.head.args) ||
        error("wrong number of args for top-level term constructor in $t")

    # Infer each argument individually. The recursion must receive the single
    # element `arg`; passing the whole `args` vector has no matching method.
    argtyps = [typ_infer(theory, arg; context) for arg in args]

    # The typs the constructor declares for its arguments, read from its context.
    expected_argtyps = Typ[j.ctx[i][2] for i in j.head.args]
    argtyps == expected_argtyps ||
        error("arguments to $t are wrong type: expected $expected_argtyps, got $argtyps")
    return j.head.typ
end

308 changes: 308 additions & 0 deletions src/syntax/egraphs/east.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,308 @@
using MLStyle
using DataStructures
using StructEquality

using ..Scopes
using ..GATs
using ..ExprInterop

"""
An index into the union-find data structure component of an E-graph. Each
e-class is associated to a set of EIds, including a canonical one. This set is
stored by the union-find.
"""
const EId = Int # TODO: make a struct which subtypes Integer, or roll our own IntDisjointSets

"""
An EType is in the context of a GAT, which the `head` of the `MethodApp` refers
to. For each parameter there is an e-term.
"""
@struct_hash_equal struct EType
    # A method application whose arguments are e-ids.
    body::MethodApp{EId}
end

"""
    EType(a::Ident, b::Ident, c::Vector{EId})

Convenience constructor: wrap `MethodApp(a, b, c)` in an `EType`.
"""
function EType(a::Ident, b::Ident, c::Vector{EId})
    return EType(MethodApp(a, b, c))
end

"""
A constant stored in an e-graph: an arbitrary Julia `value` together with the
`EType` attached to it.
"""
@struct_hash_equal struct EConstant
value::Any # untyped payload; any Julia value is accepted
type::EType
end

"""
ETerms are interpreted in a Presentation. In the case of a MethodApp, the
head/method refer to term constructors or accessors of the theory.
"""
@struct_hash_equal struct ETerm
    # Either an identifier, a method application over e-ids, or a constant.
    body::Union{Ident, MethodApp{EId}, EConstant}
end

"""
    ETerm(a::Ident, b::Ident, c::Vector{EId})

Convenience constructor: wrap `MethodApp(a, b, c)` in an `ETerm`.
"""
function ETerm(a::Ident, b::Ident, c::Vector{EId})
    return ETerm(MethodApp(a, b, c))
end

# Type of the `parents` cache of an `EClass`: keyed by e-term, valued by e-id.
const Parents = Dict{ETerm, EId}

"""
`reps`: the e-terms that make up this equivalence class.
`parents`: caches the e-terms that directly take this class as an argument, each
mapped to its e-id (as opposed to references from within a nested term).
"""
mutable struct EClass
    reps::Set{ETerm}
    type::EType
    parents::Parents
    # `parents` defaults to a fresh, empty cache.
    EClass(reps::Set{ETerm}, type::EType, parents::Parents=Parents()) =
        new(reps, type, parents)
end

"""
    add_parent!(ec::EClass, etrm::ETerm, i::EId)

Record in the parent cache of `ec` that e-term `etrm` has e-id `i`,
overwriting any previous entry for `etrm`. Returns `i`.
"""
function add_parent!(ec::EClass, etrm::ETerm, i::EId)
    cache = ec.parents
    cache[etrm] = i
    return i
end


"""
Stores a congruent partial equivalence relation on terms in the context of
`presentation`
"""
struct EGraph
    presentation::Presentation
    eqrel::IntDisjointSets{EId}
    eclasses::Dict{EId, EClass}
    hashcons::Dict{ETerm, EId}
    worklist::Vector{EId}
    isclean::Ref{Bool}
    # Start from an empty e-graph over `pres`: no ids, no classes, nothing
    # pending, and the clean flag set.
    function EGraph(pres::Presentation)
        eqrel = IntDisjointSets{EId}(0)
        eclasses = Dict{EId, EClass}()
        hashcons = Dict{ETerm, EId}()
        worklist = EId[]
        return new(pres, eqrel, eclasses, hashcons, worklist, Ref(true))
    end
end

# Theory without any further context
function EGraph(T::GAT)
    return EGraph(Presentation(T))
end

"""
Update e-term to refer to canonical e-ids
"""
function canonicalize!(eg::EGraph, etrm::ETerm)
    # An `ETerm` body is an `Ident`, a `MethodApp{EId}` or an `EConstant`.
    # Identifiers and constants are passed through unchanged; only the argument
    # ids of a method application are replaced by their union-find roots.
    # (The first arm must match `EConstant`, the type actually stored in an
    # `ETerm`; otherwise constant e-terms fall through every arm.)
    canonical = @match etrm.body begin
        x::Union{EConstant, Ident} => x
        MethodApp(head, method, args) =>
            MethodApp{EId}(head, method, find_root!.(Ref(eg.eqrel), args))
    end
    return ETerm(canonical)
end

"""
    etypeof(eg::EGraph, i::EId)

Return the `EType` stored on the e-class registered under id `i` in `eg`.
"""
function etypeof(eg::EGraph, i::EId)
    eclass = eg.eclasses[i]
    return eclass.type
end

"""
This computes the inferred context for an etrm.
For example, if `f` is an id with etyp `Hom(x,y)` and `g` is an id with etyp
`Hom(y,z)`, then context(eg, :(g ∘ f)) computes the context `[x,y,z,f,g]`.
The tricky thing comes from term formers like
weaken(x)::Term(n) ⊣ [n::Nat, x::Term(S(n))]
We get the ETyp for x from the e-graph, and then we have to ematch its argument
with `S(n)` to figure out what `n` is... The problem is that in general `S` will
not be injective, so this is ambiguous!
What we are going to do for now is say that types in the context of a term former
can't be nested. I.e., we only allow types of the form `Term(n)`, not `Term(S(n))`.
Fortunately, I don't think we care about any theories with this kind of context
former.
To fix this issue, you should instead declare term constructors like
```
weaken(n, x)::Term(n) ⊣ [n::Nat, x::Term(S(n))]
```
"""
function econtext(eg::EGraph, m::MethodApp{EId})
    termcon = getvalue(eg.presentation[m.method])
    # Both messages interpolate `m`, the application being analysed. Referring
    # to any other name here would raise `UndefVarError` instead of this error.
    termcon isa AlgTermConstructor ||
        error("head of $m is not a term constructor")
    length(m.args) == length(termcon.args) ||
        error("wrong number of args for term constructor in $m")

    # One slot per variable of the constructor's local context; 0 marks a
    # variable whose e-id has not been inferred yet.
    ectx = zeros(EId, length(termcon.localcontext))

    # initialize result with top-level arguments
    toexpand = Tuple{AlgType, EType}[]
    for (lid, eid) in zip(termcon.args, m.args)
        ectx[lid.val] = eid
        push!(toexpand, (getvalue(termcon[lid]), etypeof(eg, eid)))
    end

    # Walk each declared type alongside the matching e-type, so that the
    # variables used as type arguments are bound to the corresponding e-ids.
    while !isempty(toexpand)
        (algtype, etype) = pop!(toexpand)
        for (arg, id) in zip(algtype.body.args, etype.body.args)
            id = find_root!(eg.eqrel, id)
            @match arg.body begin
                _::Constant => nothing
                x::Ident => begin
                    i = getlid(x).val
                    if ectx[i] != 0
                        # Already inferred: the two inferences must agree.
                        ectx[i] == id ||
                            error("contradictory inference of context for $m; could not unify $(ectx[i]) and $id")
                    else
                        ectx[i] = id
                    end
                    push!(toexpand, (getvalue(termcon[x]), etypeof(eg, id)))
                end
                # Nested type arguments are deliberately unsupported.
                _::MethodApp => error("we don't do that kind of thing over here")
            end
        end
    end
    all(!=(0), ectx) || error("could not fully infer context")
    return ectx
end

"""
    compute_etype(eg::EGraph, eterm::ETerm)::EType

Compute the e-type of `eterm` in `eg`.

- For an `Ident`, the type recorded for it in the presentation is used, and
  each of its arguments is added to the e-graph.
- For an `EConstant`, the type stored on the constant is returned.
- For a `MethodApp`, the context is inferred with `econtext` and substituted
  into the arguments of the term constructor's declared type.

New e-terms may be added to `eg` as a side effect.
"""
function compute_etype(eg::EGraph, eterm::ETerm)::EType
    @match eterm.body begin
        x::Ident => begin
            algtype = getvalue(eg.presentation[x]).body
            # Typed comprehension: with no arguments, a broadcast would take its
            # element type from inference and might not yield the `Vector{EId}`
            # that the `EType` constructor requires.
            argids = EId[add!(eg, arg) for arg in argsof(algtype)]
            EType(algtype.head, algtype.method, argids)
        end
        c::EConstant => c.type
        m::MethodApp{EId} => begin
            ectx = econtext(eg, m)
            termcon = getvalue(eg.presentation[m.method])
            type_body = termcon.type.body
            EType(
                type_body.head,
                type_body.method,
                EId[subst!(eg, arg, ectx, gettag(termcon.localcontext)) for arg in type_body.args]
            )
        end
    end
end

"""
Returns the `EId` corresponding to the term resulting from the substitution
in `term` of the idents in the scope referred to by `tag` according to the
values in `ectx`.
Note: this is similar logic to `add!`: perhaps combine the two by making `ectx`
and `tag` optional?
"""
function subst!(eg::EGraph, term::AlgTerm, ectx::Vector{EId}, tag::ScopeTag)
    @match term.body begin
        # An identifier from the scope being substituted: look its e-id up.
        x::Ident && if gettag(x) == tag end => ectx[getlid(x).val]
        # Any other identifier, or a constant, is added as-is. `add!` is defined
        # on the enclosing `AlgTerm`, not on its bare body, so pass `term`.
        _::Union{Constant, Ident} => add!(eg, term)
        # Substitute into the arguments of the matched application `m`, then add
        # the resulting application.
        m::MethodApp => begin
            args = EId[subst!(eg, arg, ectx, tag) for arg in m.args]
            add!(eg, ETerm(m.head, m.method, args))
        end
    end
end

"""
Add eterm to an egraph.
"""
function add!(eg::EGraph, eterm::ETerm)
    canonical = canonicalize!(eg, eterm)
    # Already hash-consed: reuse the recorded id.
    haskey(eg.hashcons, canonical) && return eg.hashcons[canonical]

    # The e-type is computed before the fresh id is allocated.
    etype = compute_etype(eg, canonical)
    id = push!(eg.eqrel)
    body = canonical.body
    if body isa MethodApp
        # Register the new term as a parent of each of its argument classes.
        foreach(body.args) do argid
            add_parent!(eg.eclasses[argid], canonical, id)
        end
    end
    eg.hashcons[canonical] = id
    eg.eclasses[id] = EClass(Set([canonical]), etype)
    return id
end

"""
    add!(eg::EGraph, term::AlgTerm)

Add `term` to `eg`, first adding its subterms (and, for a constant, the
arguments of its type), and return the resulting `EId`.
"""
function add!(eg::EGraph, term::AlgTerm)
    # The argument ids are collected with typed comprehensions. With an empty
    # argument list, a broadcast would take its element type from inference and
    # might not produce the `Vector{EId}` the constructors below require.
    @match term.body begin
        x::Ident => add!(eg, ETerm(x))
        c::Constant => begin
            tb = c.type.body
            typeargs = EId[add!(eg, arg) for arg in tb.args]
            ec = EConstant(c.value, EType(tb.head, tb.method, typeargs))
            add!(eg, ETerm(ec))
        end
        m::MethodApp => begin
            argids = EId[add!(eg, arg) for arg in m.args]
            add!(eg, ETerm(m.head, m.method, argids))
        end
    end
end

"""
    add!(eg::EGraph, term::Union{Expr, Symbol})

Convert the Julia expression `term` to an `AlgTerm` with `fromexpr`, using the
presentation of `eg`, and add the result to the e-graph.
"""
function add!(eg::EGraph, term::Union{Expr, Symbol})
    parsed = fromexpr(eg.presentation, term, AlgTerm)
    return add!(eg, parsed)
end

"""
    find!(eg::EGraph, i::EId)

Return the root that the union-find of `eg` reports for `i`.
"""
function find!(eg::EGraph, i::EId)
    return find_root!(eg.eqrel, i)
end

"""
Merge the eclasses associated with two eIDs.
"""
function Base.merge!(eg::EGraph, id1::EId, id2::EId)
    # The graph is flagged as not clean on every call, including the no-op case.
    eg.isclean[] = false
    root1 = find!(eg, id1)
    root2 = find!(eg, id2)
    root1 == root2 && return root1

    merged = union!(eg.eqrel, root1, root2)
    # Whichever root the union-find kept survives; the other one is absorbed.
    absorbed, survivor = merged == root1 ? (root2, root1) : (root1, root2)
    push!(eg.worklist, merged)

    absorbed_class = eg.eclasses[absorbed]
    survivor_class = eg.eclasses[survivor]
    union!(survivor_class.reps, absorbed_class.reps)
    merge!(survivor_class.parents, absorbed_class.parents)
    # Only the surviving root keeps an entry in `eclasses`.
    delete!(eg.eclasses, absorbed)
    return merged
end

"""
Reinforces the e-graph invariants (i.e., ensures that the equivalence relation
is congruent).
"""
function rebuild!(eg::EGraph)
    # Repairs may push new ids onto the worklist, so loop until it stays empty.
    while !isempty(eg.worklist)
        pending = map(i -> find!(eg, i), eg.worklist)
        empty!(eg.worklist)
        foreach(i -> repair!(eg, i), pending)
    end
    eg.isclean[] = true
    return true
end

"""
    repair!(eg::EGraph, i::EId)

Re-canonicalize the parents of e-class `i`: refresh their hashcons entries and
merge any parents that have become identical after canonicalization.
"""
function repair!(eg::EGraph, i::EId)
    # Snapshot the parents: `merge!` below can mutate the parent dictionaries,
    # and a `Dict` must not be modified while it is being iterated.
    parents = collect(eg.eclasses[i].parents)

    # Re-key each parent term in the hashcons under its canonical form. The
    # entry must point at the parent's own e-class, not at the child class `i`.
    for (p_etrm, p_eclass) in parents
        delete!(eg.hashcons, p_etrm)
        p_etrm = canonicalize!(eg, p_etrm)
        eg.hashcons[p_etrm] = find!(eg, p_eclass)
    end

    new_parents = Parents()

    # Parents that canonicalize to the same term are congruent: merge them.
    for (p_etrm, p_eclass) in parents
        p_etrm = canonicalize!(eg, p_etrm)
        if haskey(new_parents, p_etrm)
            merge!(eg, p_eclass, new_parents[p_etrm])
        end
        new_parents[p_etrm] = find!(eg, p_eclass)
    end

    # NOTE(review): if one of the merges above absorbed class `i` itself, its
    # `eclasses` entry is gone and this lookup throws — confirm whether that
    # case can arise for the theories in use.
    eg.eclasses[i].parents = new_parents
    return new_parents
end

# Extraction

"""
    extract(eg::EGraph, t::EType; chooser=only)::AlgType

Build an `AlgType` from the e-type `t` by extracting a term for each of its
argument e-ids. `chooser` is forwarded unchanged to those extractions.
"""
function extract(eg::EGraph, t::EType; chooser=only)::AlgType
    app = t.body
    extracted_args = extract.(Ref(eg), app.args; chooser)
    return AlgType(app.head, app.method, extracted_args)
end

"""
    extract(eg::EGraph, t::ETerm; chooser=only)::AlgTerm

Build an `AlgTerm` from the e-term `t`, extracting a term for each argument
e-id of a method application. `chooser` is forwarded to nested extractions.
"""
function extract(eg::EGraph, t::ETerm; chooser=only)::AlgTerm
    @match t.body begin
        x::Ident => AlgTerm(x)
        # Wrap the constant in an `AlgTerm`, as the `Ident` branch does, so the
        # declared return type holds without an implicit conversion.
        c::EConstant => AlgTerm(Constant(c.value, extract(eg, c.type; chooser)))
        m::MethodApp => AlgTerm(m.head, m.method, extract.(Ref(eg), m.args; chooser))
    end
end

"""
    extract(eg::EGraph, id::EId; chooser=only)::AlgTerm

Extract a term for the e-class of `id`. `chooser` picks one e-term from the
class's set of representatives; the default, `only`, throws unless the class
has exactly one.
"""
function extract(eg::EGraph, id::EId; chooser=only)::AlgTerm
    # `merge!` deletes the `eclasses` entry of every non-root id, so look the
    # class up through its canonical id.
    eclass = eg.eclasses[find!(eg, id)]
    return extract(eg, chooser(eclass.reps); chooser)
end
Loading

0 comments on commit 6b57b41

Please sign in to comment.