Skip to content

Commit

Permalink
started fixing up ematching
Browse files Browse the repository at this point in the history
  • Loading branch information
olynch committed Oct 30, 2023
1 parent 9c8a665 commit a52a407
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 22 deletions.
20 changes: 10 additions & 10 deletions src/syntax/egraphs/east.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ EType(a::Ident,b::Ident,c::Vector{EId}) = EType(MethodApp(a,b,c))
end

"""
ETerms in are interpreted in a Presentation. In the case of a MethodApp, the
ETerms in are interpreted in a GATContext. In the case of a MethodApp, the
head/method refer to term constructors or accessors of the theory.
"""
@struct_hash_equal struct ETerm
Expand Down Expand Up @@ -61,21 +61,21 @@ end

"""
Stores a congruent partial equivalence relation on terms in the context of
`presentation`
`context`
"""
struct EGraph
presentation::Presentation
context::GATContext
eqrel::IntDisjointSets{EId}
eclasses::Dict{EId, EClass}
hashcons::Dict{ETerm, EId}
worklist::Vector{EId}
isclean::Ref{Bool}
function EGraph(pres::Presentation)
function EGraph(pres::GATContext)
new(pres, IntDisjointSets{EId}(0), Dict{EId, EClass}(), Dict{ETerm, EId}(), EId[], Ref(true))

Check warning on line 74 in src/syntax/egraphs/east.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/east.jl#L73-L74

Added lines #L73 - L74 were not covered by tests
end
end

EGraph(T::GAT) = EGraph(Presentation(T)) # Theory without any further context
EGraph(T::GAT) = EGraph(GATContext(T)) # Theory without any further context

Check warning on line 78 in src/syntax/egraphs/east.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/east.jl#L78

Added line #L78 was not covered by tests

"""
Update e-term to refer to canonical e-ids
Expand Down Expand Up @@ -119,7 +119,7 @@ weaken(n, x)::Term(n) ⊣ [n::Nat, x::Term(S(n))]
```
"""
function econtext(eg::EGraph, m::MethodApp{EId})
termcon = getvalue(eg.presentation[m.method])
termcon = getvalue(eg.context[m.method])
typeof(termcon) == AlgTermConstructor ||

Check warning on line 123 in src/syntax/egraphs/east.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/east.jl#L121-L123

Added lines #L121 - L123 were not covered by tests
error("head of $etrm is not a term constructor")
length(m.args) == length(termcon.args) ||

Check warning on line 125 in src/syntax/egraphs/east.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/east.jl#L125

Added line #L125 was not covered by tests
Expand Down Expand Up @@ -158,13 +158,13 @@ end
function compute_etype(eg::EGraph, eterm::ETerm)::EType
@match eterm.body begin
x::Ident => begin
algtype = getvalue(eg.presentation[x]).body
algtype = getvalue(eg.context[x]).body
EType(algtype.head, algtype.method, add!.(Ref(eg), argsof(algtype)))

Check warning on line 162 in src/syntax/egraphs/east.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/east.jl#L158-L162

Added lines #L158 - L162 were not covered by tests
end
c::EConstant => c.type
m::MethodApp{EId} => begin
ectx = econtext(eg, m)
termcon = getvalue(eg.presentation[m.method])
termcon = getvalue(eg.context[m.method])
type_body = termcon.type.body
EType(

Check warning on line 169 in src/syntax/egraphs/east.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/east.jl#L164-L169

Added lines #L164 - L169 were not covered by tests
type_body.head,
Expand Down Expand Up @@ -228,7 +228,7 @@ function add!(eg::EGraph, term::AlgTerm)
end

function add!(eg::EGraph, term::Union{Expr, Symbol})
add!(eg, fromexpr(eg.presentation, term, AlgTerm))
add!(eg, fromexpr(eg.context, term, AlgTerm))

Check warning on line 231 in src/syntax/egraphs/east.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/east.jl#L230-L231

Added lines #L230 - L231 were not covered by tests
end

find!(eg::EGraph, i::EId) = find_root!(eg.eqrel, i)

Check warning on line 234 in src/syntax/egraphs/east.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/east.jl#L234

Added line #L234 was not covered by tests
Expand Down Expand Up @@ -305,4 +305,4 @@ end

function extract(eg::EGraph, id::EId; chooser=only)::AlgTerm
extract(eg, chooser(eg.eclasses[id].reps))

Check warning on line 307 in src/syntax/egraphs/east.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/east.jl#L306-L307

Added lines #L306 - L307 were not covered by tests
end
end
17 changes: 9 additions & 8 deletions src/syntax/egraphs/ematching.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
#
# Here we follow a strategy similar to egg, but modified somewhat for our uses.
#
# We take a pattern, which is a Trm in a Context, and we attempt to find an
# assignment of an enode to each term in the context.
# We take a pattern, which is a AlgTerm in a TypeScope, and we attempt to find an
# assignment of an enode to each variable in the scope.
#
# For instance, we might look for the term `(a * b) * c` in the context
# `[a,b,c::U]` or for the term `f ∘ id(a)` in the context
Expand All @@ -27,9 +27,9 @@ end
Base.:+(r::Reg, i::Int) = Reg(r.idx + i)

Check warning on line 27 in src/syntax/egraphs/ematching.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/ematching.jl#L27

Added line #L27 was not covered by tests

struct Machine
reg::Vector{Id}
lookup::Vector{Id}
matches::Vector{Vector{Id}}
reg::Vector{EId}
lookup::Vector{EId}
matches::Vector{Vector{EId}}
limit::Union{Some{Int}, Nothing}
function Machine(;limit=nothing)
new(Id[], Id[], Vector{Id}[], limit)

Check warning on line 35 in src/syntax/egraphs/ematching.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/ematching.jl#L34-L35

Added lines #L34 - L35 were not covered by tests
Expand Down Expand Up @@ -57,7 +57,7 @@ end
#
# For each term, assign the registers past `out` to the arguments of that term,
# and run the rest of the instructions.
Bind(trmcon::Lvl, i::Reg, out::Reg)
Bind(trmcon::Ident, i::Reg, out::Reg)

# Check if the eclass bound to `i` is the same as the eclass bound to `j`
Compare(i::Reg, j::Reg)
Expand All @@ -66,11 +66,12 @@ end
# refer to earlier elements of `term`. Fill out a lookup vector of ids the same
# length as `term` by:
# - For each Reg, just look up the id in the EGraph
# - For each ETrm, look up its arguments in the lookup vector, and then lookup
# - For each ETerm, look up its arguments in the lookup vector (the arguments
# to the ETerm are indices into the lookup vector, not eids), and then look up
# the completed ETrm in the EGraph
#
# At the end, put the last id in the lookup vector into `reg`.
Lookup(term::Vector{Union{Reg, ETrm}}, reg::Reg)
Lookup(term::Vector{Union{Reg, ETerm}}, reg::Reg)

# Iterate through every eclass in the egraph.
#
Expand Down
8 changes: 4 additions & 4 deletions test/syntax/EGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using Test

using GATlab

@present P(ThMonoid) begin
@gatcontext P(ThMonoid) begin
(a,b,c)
end

Expand All @@ -19,7 +19,7 @@ i4 = add!(eg, :(c ⋅ (a ⋅ b)))

@test i3 == i4

@present C(ThCategory) begin
@gatcontext C(ThCategory) begin
(x,y,z) :: Ob
f :: Hom(x,y)
g :: Hom(y,z)
Expand All @@ -39,8 +39,8 @@ merge!(eg, add!(eg, :x), add!(eg, :z))

i2 = add!(eg, :(g f))

# # E Matching
# ############
# E-Matching
############

# @theory C <: ThCategory begin
# x::Ob
Expand Down

0 comments on commit a52a407

Please sign in to comment.