Skip to content

Commit

Permalink
the egraph matching machine lives again
Browse files Browse the repository at this point in the history
  • Loading branch information
olynch committed Nov 9, 2023
1 parent a52a407 commit a6709d8
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 60 deletions.
1 change: 1 addition & 0 deletions src/syntax/EGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@ export EGraph, ETerm, EClass, EType, EConstant, EId,
add!, merge!, rebuild!

include("egraphs/east.jl")
include("egraphs/ematching.jl")

end # module
6 changes: 6 additions & 0 deletions src/syntax/egraphs/east.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,17 @@ to. For each parameter there is an e-term.
body::MethodApp{EId}
end

@as_record EType

# Convenience constructor: bundle the two identifiers and the argument ids into
# a `MethodApp` and use that as the body of the new `EType`.
function EType(a::Ident, b::Ident, c::Vector{EId})
    return EType(MethodApp(a, b, c))
end

Check warning on line 26 in src/syntax/egraphs/east.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/east.jl#L26

Added line #L26 was not covered by tests

# A constant that can appear as the body of an `ETerm`: an arbitrary Julia
# `value` paired with the `EType` it is tagged with.
# NOTE(review): `@struct_hash_equal` presumably derives structural `==`/`hash`
# for this struct — confirm against the macro's definition.
@struct_hash_equal struct EConstant
value::Any
type::EType
end

@as_record EConstant

"""
ETerms are interpreted in a GATContext. In the case of a MethodApp, the
head/method refer to term constructors or accessors of the theory.
Expand All @@ -36,6 +40,8 @@ head/method refer to term constructors or accessors of the theory.
body::Union{Ident, MethodApp{EId}, EConstant}
end

@as_record ETerm

# Convenience constructor: bundle the two identifiers and the argument ids into
# a `MethodApp` and use that as the body of the new `ETerm`.
function ETerm(a::Ident, b::Ident, c::Vector{EId})
    return ETerm(MethodApp(a, b, c))
end

Check warning on line 45 in src/syntax/egraphs/east.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/east.jl#L45

Added line #L45 was not covered by tests

# Alias for a dictionary keyed by `ETerm` with `EId` values.
# NOTE(review): presumably records the e-terms that use an e-class as an
# argument, together with the class each one belongs to — confirm against uses.
const Parents = Dict{ETerm, EId}
Expand Down
84 changes: 41 additions & 43 deletions src/syntax/egraphs/ematching.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# Note that not all variables in the context are referenced directly in the
# term; i.e. `b` is never referenced. Thus, ematching must take into account both
# terms and types.
export Reg, Scan, Bind, Compare, Lookup, Machine, run!

using ..EGraphs
using ...Syntax
Expand All @@ -32,19 +33,19 @@ struct Machine
matches::Vector{Vector{EId}}
limit::Union{Some{Int}, Nothing}
function Machine(;limit=nothing)
new(Id[], Id[], Vector{Id}[], limit)
new(EId[], EId[], Vector{EId}[], limit)

Check warning on line 36 in src/syntax/egraphs/ematching.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/ematching.jl#L35-L36

Added lines #L35 - L36 were not covered by tests
end
end

# `m[r]`: read the id currently held in register `r` of the machine.
function Base.getindex(m::Machine, r::Reg)
    return m.reg[r.idx]
end

Check warning on line 40 in src/syntax/egraphs/ematching.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/ematching.jl#L40

Added line #L40 was not covered by tests

Base.setindex!(m::Machine, r::Reg, i::Id) = (m.reg[r.idx] = i)
# Store id `i` in register `r` of the machine and return `i`.
#
# Julia lowers `m[r] = i` to `setindex!(m, i, r)` (value first, then key), so
# the value-first method is the one that makes the indexing syntax work. The
# key-first method is kept so that existing explicit `setindex!(m, r, i)` calls
# continue to behave as before.
Base.setindex!(m::Machine, i::EId, r::Reg) = (m.reg[r.idx] = i)
Base.setindex!(m::Machine, r::Reg, i::EId) = setindex!(m, i, r)

Check warning on line 42 in src/syntax/egraphs/ematching.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/ematching.jl#L42

Added line #L42 was not covered by tests

# Field-less exception thrown by `submit_match!` once the machine has a `limit`
# set and the number of collected matches has reached it.
struct FinishedMatching <: Exception
end

function submit_match!(m::Machine, subst::Vector{Reg})
match = Id[m[r] for r in subst]
match = EId[m[r] for r in subst]
push!(m.matches, match)
if !isnothing(m.limit) && length(m.matches) >= m.limit.value
throw(FinishedMatching())

Check warning on line 51 in src/syntax/egraphs/ematching.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/ematching.jl#L47-L51

Added lines #L47 - L51 were not covered by tests
Expand All @@ -62,24 +63,24 @@ end
# Check if the eclass bound to `i` is the same as the eclass bound to `j`
Compare(i::Reg, j::Reg)

# Each element of `term` is either a register or an ETrm where the ids
# Each element of `term` is either a register or a MethodApp where the ids
# refer to earlier elements of `term`. Fill out a lookup vector of ids the same
# length as `term` by:
# - For each Reg, just look up the id in the EGraph
# - For each ETerm, look up its arguments in the lookup vector (the arguments
# to the ETerm are indices into the lookup vector, not eids), and then look up
# the completed ETrm in the EGraph
# - For each MethodApp, look up its arguments in the lookup vector (the arguments
# to the MethodApp are indices into the lookup vector, not eids), and then look up
# the completed ETerm in the EGraph
#
# At the end, put the last id in the lookup vector into `reg`.
Lookup(term::Vector{Union{Reg, ETerm}}, reg::Reg)
Lookup(term::Vector{Union{Reg, MethodApp{EId}}}, reg::Reg)

# Iterate through every eclass in the egraph.
#
# For each eclass, assign its id to `out`, truncate the list of registers,
# and run the rest of the instructions.
#
# Note: we can probably get better performance by only iterating through eclasses
# with a certain ETyp, or that come from a certain constructor.
# with a certain EType.
Scan(out::Reg)
end

Expand All @@ -97,11 +98,11 @@ function run!(m::Machine, eg::EGraph, instructions::AbstractVector{Instruction},
eclass = eg.eclasses[find!(eg, m[i])]
remaining = @view instructions[idx+1:end]
for etrm in eclass.reps
if etrm.head != trmcon
if !(etrm.body isa MethodApp && etrm.body.method == trmcon)
continue

Check warning on line 102 in src/syntax/egraphs/ematching.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/ematching.jl#L94-L102

Added lines #L94 - L102 were not covered by tests
end
resize!(m.reg, out.idx - 1)
append!(m.reg, etrm.args)
append!(m.reg, etrm.body.args)
run!(m, eg, remaining, subst)
end
return

Check warning on line 108 in src/syntax/egraphs/ematching.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/ematching.jl#L104-L108

Added lines #L104 - L108 were not covered by tests
Expand All @@ -116,18 +117,16 @@ function run!(m::Machine, eg::EGraph, instructions::AbstractVector{Instruction},
for x in trm
@match x begin
Reg(_) => push!(m.lookup, find!(eg, m[x]))
ETrm(head, args) => begin
etrm = ETrm(head, Id[m.lookup[i] for i in args])
MethodApp(decl, method, args) => begin
etrm = ETerm(MethodApp(decl, method, EId[m.lookup[i] for i in args]))
@match get(eg.hashcons, etrm, nothing) begin
nothing => return
id => push!(m.lookup, id)

Check warning on line 124 in src/syntax/egraphs/ematching.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/ematching.jl#L115-L124

Added lines #L115 - L124 were not covered by tests
end
end
end
end
if lookup[end] != find!(eg, m[reg])
return
end
lookup[end] = find!(eg, m[reg])

Check warning on line 129 in src/syntax/egraphs/ematching.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/ematching.jl#L128-L129

Added lines #L128 - L129 were not covered by tests
end
Scan(out) => begin
remaining = @view instructions[idx+1:end]
Expand All @@ -144,19 +143,19 @@ function run!(m::Machine, eg::EGraph, instructions::AbstractVector{Instruction},
end

struct Compiler
v2r::Dict{Lvl, Reg}
free_vars::Vector{Set{Lvl}}
v2r::Dict{Ident, Reg}
free_vars::Vector{Set{Ident}}
subtree_size::Vector{Int}
todo_nodes::Dict{Tuple{Int, Reg}, ETrm}
todo_nodes::Dict{Tuple{Ident, Reg}, ETerm}
instructions::Vector{Instruction}
next_reg::Reg
theory::Theory
function Compiler(theory::Theory)
theory::GAT
function Compiler(theory::GAT)
new(

Check warning on line 154 in src/syntax/egraphs/ematching.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/ematching.jl#L153-L154

Added lines #L153 - L154 were not covered by tests
Dict{Lvl, Reg}(),
Set{Lvl}[],
Dict{Ident, Reg}(),
Set{Ident}[],
Int[],
Dict{Tuple{Int, Reg}, ETrm}(),
Dict{Tuple{Int, Reg}, ETerm}(),
Instruction[],
Reg(1),
theory
Expand All @@ -165,27 +164,27 @@ struct Compiler
end

struct Pattern
trm::Trm
trm::AlgTerm
ctx::Context
end

function trm_to_vec!(trm::Trm, vec::Vector{ETrm})
ids = Vector{Id}(trm_to_vec!.(trm.args, Ref(vec)))
push!(vec, ETrm(trm.head, ids))
function trm_to_vec!(trm::AlgTerm, vec::Vector{ETerm})
ids = Vector{EId}(trm_to_vec!.(trm.args, Ref(vec)))
push!(vec, ETerm(trm.head, ids))
length(vec)

Check warning on line 174 in src/syntax/egraphs/ematching.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/ematching.jl#L171-L174

Added lines #L171 - L174 were not covered by tests
end

function vec_to_trm(vec::Vector{ETrm}, id::Id)
function vec_to_trm(vec::Vector{ETerm}, id::EId)
etrm = vec[id]
args = Vector{Trm}(vec_to_term(Ref(vec), etrm.args))
Trm(etrm.head, args)
args = Vector{AlgTerm}(vec_to_term(Ref(vec), etrm.args))
AlgTerm(etrm.head, args)

Check warning on line 180 in src/syntax/egraphs/ematching.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/ematching.jl#L177-L180

Added lines #L177 - L180 were not covered by tests
end

function load_pattern!(c::Compiler, patvec::Vector{ETrm})
function load_pattern!(c::Compiler, patvec::Vector{ETerm})
n = length(patvec)

Check warning on line 184 in src/syntax/egraphs/ematching.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/ematching.jl#L183-L184

Added lines #L183 - L184 were not covered by tests

for etrm in patvec
free = Set{Lvl}()
free = Set{Ident}()
size = 0
hd = etrm.head
if is_context(hd)
Expand All @@ -202,8 +201,7 @@ function load_pattern!(c::Compiler, patvec::Vector{ETrm})
end

Check warning on line 201 in src/syntax/egraphs/ematching.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/ematching.jl#L199-L201

Added lines #L199 - L201 were not covered by tests
end


function add_todo!(c::Compiler, patvec::Vector{ETrm}, id::Id, reg::Reg)
function add_todo!(c::Compiler, patvec::Vector{ETerm}, id::EId, reg::Reg)
etrm = patvec[id]
hd = etrm.head
if is_context(hd)
Expand Down Expand Up @@ -250,7 +248,7 @@ end
#
# Why? Idk, this is how it works in egg
function next!(c::Compiler)
function key(idreg::Tuple{Id, Reg})
function key(idreg::Tuple{EId, Reg})
id = idreg[1]
n_bound = length(filter(v -> v in keys(c.v2r), c.free_vars[id]))
n_free = length(c.free_vars[id]) - n_bound
Expand All @@ -266,21 +264,21 @@ function next!(c::Compiler)
(k,v)

Check warning on line 264 in src/syntax/egraphs/ematching.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/ematching.jl#L262-L264

Added lines #L262 - L264 were not covered by tests
end

is_ground_now(c::Compiler, id::Id) = all(v keys(c.v2r) for v in c.free_vars[id])
# Return `true` when every variable recorded in `c.free_vars[id]` already has a
# register assigned in `c.v2r` (vacuously `true` when there are none).
is_ground_now(c::Compiler, id::EId) = all(v -> haskey(c.v2r, v), c.free_vars[id])

Check warning on line 267 in src/syntax/egraphs/ematching.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/ematching.jl#L267

Added line #L267 was not covered by tests

function extract(patvec::Vector{ETrm}, i::Id)
function extract(patvec::Vector{ETerm}, i::EId)
trm = vec_to_trm(patvec, i)
vec = ETrm[]
vec = ETerm[]
trm_to_vec!(trm, vec)
vec

Check warning on line 273 in src/syntax/egraphs/ematching.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/ematching.jl#L269-L273

Added lines #L269 - L273 were not covered by tests
end

# Returns a Program
function compile(T::Type{<:AbstractTheory}, pat::Pattern)
patvec = ETrm[]
function compile(mod::Module, pat::Pattern)
patvec = ETerm[]
trm_to_vec!(pat.trm, patvec)

Check warning on line 279 in src/syntax/egraphs/ematching.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/ematching.jl#L277-L279

Added lines #L277 - L279 were not covered by tests

c = Compiler(gettheory(T))
c = Compiler(mod.Meta.theory)

Check warning on line 281 in src/syntax/egraphs/ematching.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/ematching.jl#L281

Added line #L281 was not covered by tests

load_pattern!(c, patvec)

Check warning on line 283 in src/syntax/egraphs/ematching.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/ematching.jl#L283

Added line #L283 was not covered by tests

Expand All @@ -297,7 +295,7 @@ function compile(T::Type{<:AbstractTheory}, pat::Pattern)
push!(

Check warning on line 295 in src/syntax/egraphs/ematching.jl

View check run for this annotation

Codecov / codecov/patch

src/syntax/egraphs/ematching.jl#L289-L295

Added lines #L289 - L295 were not covered by tests
c.instructions,
Lookup(
Union{ETrm, Reg}[is_context(t.head) ? c.v2r[t.head] : t for t in extracted],
Union{ETerm, Reg}[is_context(t.head) ? c.v2r[t.head] : t for t in extracted],
reg
)
)
Expand Down
35 changes: 18 additions & 17 deletions test/syntax/EGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ i1 = add!(eg, :(f ⋅ g))

EGraphs.etypeof(eg, i1)

EGraphs.extract(eg, EGraphs.etypeof(eg, i1); chooser=first)
type_fg = EGraphs.extract(eg, EGraphs.etypeof(eg, i1); chooser=first)
@test type_fg == fromexpr(C, :(Hom(x,z)), AlgType)

# `add!` is expected to reject the composite `g ⋅ f` at this point.
@test_throws Exception add!(eg, :(g ⋅ f))

Expand All @@ -42,29 +43,29 @@ i2 = add!(eg, :(g ⋅ f))
# E-Matching
############

# @theory C <: ThCategory begin
# x::Ob
# y::Ob
# z::Ob
# f::Hom(x,y)
# g::Hom(x,x)
# h::Hom(y,y)
# end
@gatcontext C(ThCategory) begin
x::Ob
y::Ob
z::Ob
f::Hom(x,y)
g::Hom(x,x)
h::Hom(y,y)
end

# eg = EGraph(C.T)
eg = EGraph(C)

# id = add!(eg, @term C (f ⋅ id(y)))
# Insert the composite of `f` with the identity on `y`; the matcher below is
# run against this term.
id = add!(eg, :(f ⋅ id(y)))

# compose_lvl = (@term C (f ⋅ h)).head
# id_lvl = (@term C id(x)).head
# Parse a sample composite only to obtain the method identifier of `⋅`.
compose_method = fromexpr(C, :(f ⋅ h), AlgTerm).body.method
id_method = fromexpr(C, :(id(x)), AlgTerm).body.method

# instructions = [Scan(Reg(1)), Bind(compose_lvl, Reg(1), Reg(2)), Bind(id_lvl, Reg(3), Reg(4))]
instructions = [Scan(Reg(1)), Bind(compose_method, Reg(1), Reg(2)), Bind(id_method, Reg(3), Reg(4))]

# m = Machine()
m = Machine()

# run!(m, eg, instructions, [Reg(4), Reg(2)])
run!(m, eg, instructions, [Reg(4), Reg(2)])

# @test m.matches[1] == [add!(eg, @term C y), add!(eg, @term C f)]
@test m.matches[1] == [add!(eg, :y), add!(eg, :f)]

# Γ = @context ThCategory [a::Ob, b::Ob, α::Hom(a,b)]
# t = @term ThCategory Γ (α ⋅ id(b))
Expand Down

0 comments on commit a6709d8

Please sign in to comment.