From a6709d8055c7c976504fe35b939289b041196a5a Mon Sep 17 00:00:00 2001 From: Owen Lynch Date: Wed, 8 Nov 2023 17:47:03 -0800 Subject: [PATCH] the egraph matching machine lives again --- src/syntax/EGraphs.jl | 1 + src/syntax/egraphs/east.jl | 6 +++ src/syntax/egraphs/ematching.jl | 84 ++++++++++++++++----------------- test/syntax/EGraphs.jl | 35 +++++++------- 4 files changed, 66 insertions(+), 60 deletions(-) diff --git a/src/syntax/EGraphs.jl b/src/syntax/EGraphs.jl index e1e402e9..b1fe45b6 100644 --- a/src/syntax/EGraphs.jl +++ b/src/syntax/EGraphs.jl @@ -4,5 +4,6 @@ export EGraph, ETerm, EClass, EType, EConstant, EId, add!, merge!, rebuild! include("egraphs/east.jl") +include("egraphs/ematching.jl") end # module diff --git a/src/syntax/egraphs/east.jl b/src/syntax/egraphs/east.jl index e0038941..76fa2f05 100644 --- a/src/syntax/egraphs/east.jl +++ b/src/syntax/egraphs/east.jl @@ -21,6 +21,8 @@ to. For each parameter there is an e-term. body::MethodApp{EId} end +@as_record EType + EType(a::Ident,b::Ident,c::Vector{EId}) = EType(MethodApp(a,b,c)) @struct_hash_equal struct EConstant @@ -28,6 +30,8 @@ EType(a::Ident,b::Ident,c::Vector{EId}) = EType(MethodApp(a,b,c)) type::EType end +@as_record EConstant + """ ETerms in are interpreted in a GATContext. In the case of a MethodApp, the head/method refer to term constructors or accessors of the theory. @@ -36,6 +40,8 @@ head/method refer to term constructors or accessors of the theory. body::Union{Ident, MethodApp{EId}, EConstant} end +@as_record ETerm + ETerm(a::Ident,b::Ident,c::Vector{EId}) = ETerm(MethodApp(a,b,c)) const Parents = Dict{ETerm, EId} diff --git a/src/syntax/egraphs/ematching.jl b/src/syntax/egraphs/ematching.jl index c8e0e75b..d2a941b3 100644 --- a/src/syntax/egraphs/ematching.jl +++ b/src/syntax/egraphs/ematching.jl @@ -12,6 +12,7 @@ # Note that not all variables in the context are referenced directly in the # term; i.e. `b` is never referenced. Thus, ematching must take into account both # terms and types. +export Reg, Scan, Bind, Compare, Lookup, Machine, run! using ..EGraphs using ...Syntax @@ -32,19 +33,19 @@ struct Machine matches::Vector{Vector{EId}} limit::Union{Some{Int}, Nothing} function Machine(;limit=nothing) - new(Id[], Id[], Vector{Id}[], limit) + new(EId[], EId[], Vector{EId}[], limit) end end Base.getindex(m::Machine, r::Reg) = m.reg[r.idx] -Base.setindex!(m::Machine, r::Reg, i::Id) = (m.reg[r.idx] = i) +Base.setindex!(m::Machine, r::Reg, i::EId) = (m.reg[r.idx] = i) struct FinishedMatching <: Exception end function submit_match!(m::Machine, subst::Vector{Reg}) - match = Id[m[r] for r in subst] + match = EId[m[r] for r in subst] push!(m.matches, match) if !isnothing(m.limit) && length(m.matches) >= m.limit.value throw(FinishedMatching()) @@ -62,16 +63,16 @@ end # Check if the eclass bound to `i` is the same as the eclass bound to `j` Compare(i::Reg, j::Reg) - # Each element of `term` is either a register or an ETrm where the ids + # Each element of `term` is either a register or an ETerm where the ids # refer to earlier elements of `term`. Fill out a lookup vector of ids the same # length as `term` by: # - For each Reg, just look up the id in the EGraph - # - For each ETerm, look up its arguments in the lookup vector (the arguments - # to the ETerm are indices into the lookup vector, not eids), and then look up - # the completed ETrm in the EGraph + # - For each MethodApp, look up its arguments in the lookup vector (the arguments + # to the MethodApp are indices into the lookup vector, not eids), and then look up + # the completed ETerm in the EGraph # # At the end, put the last id in the lookup vector into `reg`. - Lookup(term::Vector{Union{Reg, ETerm}}, reg::Reg) + Lookup(term::Vector{Union{Reg, MethodApp{EId}}}, reg::Reg) # Iterate through every eclass in the egraph. # @@ -79,7 +80,7 @@ end # and run the rest of the instructions. # # Note: we can probably get better performance by only iterating through eclasses - # with a certain ETyp, or that come from a certain constructor. + # with a certain ETyp. Scan(out::Reg) end @@ -97,11 +98,11 @@ function run!(m::Machine, eg::EGraph, instructions::AbstractVector{Instruction}, eclass = eg.eclasses[find!(eg, m[i])] remaining = @view instructions[idx+1:end] for etrm in eclass.reps - if etrm.head != trmcon + if !(etrm.body isa MethodApp && etrm.body.method == trmcon) continue end resize!(m.reg, out.idx - 1) - append!(m.reg, etrm.args) + append!(m.reg, etrm.body.args) run!(m, eg, remaining, subst) end return @@ -116,8 +117,8 @@ function run!(m::Machine, eg::EGraph, instructions::AbstractVector{Instruction}, for x in trm @match x begin Reg(_) => push!(m.lookup, find!(eg, m[x])) - ETrm(head, args) => begin - etrm = ETrm(head, Id[m.lookup[i] for i in args]) + MethodApp(decl, method, args) => begin + etrm = ETerm(MethodApp(decl, method, EId[m.lookup[i] for i in args])) @match get(eg.hashcons, etrm, nothing) begin nothing => return id => push!(m.lookup, id) @@ -125,9 +126,7 @@ function run!(m::Machine, eg::EGraph, instructions::AbstractVector{Instruction}, end end end - if lookup[end] != find!(eg, m[reg]) - return - end + lookup[end] = find!(eg, m[reg]) end Scan(out) => begin remaining = @view instructions[idx+1:end] @@ -144,19 +143,19 @@ function run!(m::Machine, eg::EGraph, instructions::AbstractVector{Instruction}, end struct Compiler - v2r::Dict{Lvl, Reg} - free_vars::Vector{Set{Lvl}} + v2r::Dict{Ident, Reg} + free_vars::Vector{Set{Ident}} subtree_size::Vector{Int} - todo_nodes::Dict{Tuple{Int, Reg}, ETrm} + todo_nodes::Dict{Tuple{Ident, Reg}, ETerm} instructions::Vector{Instruction} next_reg::Reg - theory::Theory - function Compiler(theory::Theory) + theory::GAT + function Compiler(theory::GAT) new( - Dict{Lvl, Reg}(), - Set{Lvl}[], + Dict{Ident, Reg}(), + Set{Ident}[], Int[], - Dict{Tuple{Int, Reg}, ETrm}(), + Dict{Tuple{Int, Reg}, ETerm}(), Instruction[], Reg(1), theory @@ -165,27 +164,27 @@ struct Compiler end struct Pattern - trm::Trm + trm::AlgTerm ctx::Context end -function trm_to_vec!(trm::Trm, vec::Vector{ETrm}) - ids = Vector{Id}(trm_to_vec!.(trm.args, Ref(vec))) - push!(vec, ETrm(trm.head, ids)) +function trm_to_vec!(trm::AlgTerm, vec::Vector{ETerm}) + ids = Vector{EId}(trm_to_vec!.(trm.args, Ref(vec))) + push!(vec, ETerm(trm.head, ids)) length(vec) end -function vec_to_trm(vec::Vector{ETrm}, id::Id) +function vec_to_trm(vec::Vector{ETerm}, id::EId) etrm = vec[id] - args = Vector{Trm}(vec_to_term(Ref(vec), etrm.args)) - Trm(etrm.head, args) + args = Vector{AlgTerm}(vec_to_term(Ref(vec), etrm.args)) + AlgTerm(etrm.head, args) end -function load_pattern!(c::Compiler, patvec::Vector{ETrm}) +function load_pattern!(c::Compiler, patvec::Vector{ETerm}) n = length(patvec) for etrm in patvec - free = Set{Lvl}() + free = Set{Ident}() size = 0 hd = etrm.head if is_context(hd) @@ -202,8 +201,7 @@ function load_pattern!(c::Compiler, patvec::Vector{ETrm}) end end - -function add_todo!(c::Compiler, patvec::Vector{ETrm}, id::Id, reg::Reg) +function add_todo!(c::Compiler, patvec::Vector{ETerm}, id::EId, reg::Reg) etrm = patvec[id] hd = etrm.head if is_context(hd) @@ -250,7 +248,7 @@ end # # Why? Idk, this is how it works in egg function next!(c::Compiler) - function key(idreg::Tuple{Id, Reg}) + function key(idreg::Tuple{EId, Reg}) id = idreg[1] n_bound = length(filter(v -> v in keys(c.v2r), c.free_vars[id])) n_free = length(c.free_vars[id]) - n_bound @@ -266,21 +264,21 @@ function next!(c::Compiler) (k,v) end -is_ground_now(c::Compiler, id::Id) = all(v ∈ keys(c.v2r) for v in c.free_vars[id]) +is_ground_now(c::Compiler, id::EId) = all(v ∈ keys(c.v2r) for v in c.free_vars[id]) -function extract(patvec::Vector{ETrm}, i::Id) +function extract(patvec::Vector{ETerm}, i::EId) trm = vec_to_trm(patvec, i) - vec = ETrm[] + vec = ETerm[] trm_to_vec!(trm, vec) vec end # Returns a Program -function compile(T::Type{<:AbstractTheory}, pat::Pattern) - patvec = ETrm[] +function compile(mod::Module, pat::Pattern) + patvec = ETerm[] trm_to_vec!(pat.trm, patvec) - c = Compiler(gettheory(T)) + c = Compiler(mod.Meta.theory) load_pattern!(c, patvec) @@ -297,7 +295,7 @@ function compile(T::Type{<:AbstractTheory}, pat::Pattern) push!( c.instructions, Lookup( - Union{ETrm, Reg}[is_context(t.head) ? c.v2r[t.head] : t for t in extracted], + Union{ETerm, Reg}[is_context(t.head) ? c.v2r[t.head] : t for t in extracted], reg ) ) diff --git a/test/syntax/EGraphs.jl b/test/syntax/EGraphs.jl index 6ea6f23a..dc828cd7 100644 --- a/test/syntax/EGraphs.jl +++ b/test/syntax/EGraphs.jl @@ -31,7 +31,8 @@ i1 = add!(eg, :(f ⋅ g)) EGraphs.etypeof(eg, i1) -EGraphs.extract(eg, EGraphs.etypeof(eg, i1); chooser=first) +type_fg = EGraphs.extract(eg, EGraphs.etypeof(eg, i1); chooser=first) +@test type_fg == fromexpr(C, :(Hom(x,z)), AlgType) @test_throws Exception add!(eg, :(g ⋅ f)) @@ -42,29 +43,29 @@ i2 = add!(eg, :(g ⋅ f)) # E-Matching ############ -# @theory C <: ThCategory begin -# x::Ob -# y::Ob -# z::Ob -# f::Hom(x,y) -# g::Hom(x,x) -# h::Hom(y,y) -# end +@gatcontext C(ThCategory) begin + x::Ob + y::Ob + z::Ob + f::Hom(x,y) + g::Hom(x,x) + h::Hom(y,y) +end -# eg = EGraph(C.T) +eg = EGraph(C) -# id = add!(eg, @term C (f ⋅ id(y))) +id = add!(eg, :(f ⋅ id(y))) -# compose_lvl = (@term C (f ⋅ h)).head -# id_lvl = (@term C id(x)).head +compose_method = fromexpr(C, :(f ⋅ h), AlgTerm).body.method +id_method = fromexpr(C, :(id(x)), AlgTerm).body.method -# instructions = [Scan(Reg(1)), Bind(compose_lvl, Reg(1), Reg(2)), Bind(id_lvl, Reg(3), Reg(4))] +instructions = [Scan(Reg(1)), Bind(compose_method, Reg(1), Reg(2)), Bind(id_method, Reg(3), Reg(4))] -# m = Machine() +m = Machine() -# run!(m, eg, instructions, [Reg(4), Reg(2)]) +run!(m, eg, instructions, [Reg(4), Reg(2)]) -# @test m.matches[1] == [add!(eg, @term C y), add!(eg, @term C f)] +@test m.matches[1] == [add!(eg, :y), add!(eg, :f)] # Γ = @context ThCategory [a::Ob, b::Ob, α::Hom(a,b)] # t = @term ThCategory Γ (α ⋅ id(b))