From 175a9e6dae6b67a88d2cf9312a39d9339cc403df Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 2 Aug 2024 17:12:28 -0400 Subject: [PATCH] prototype of typed rewriting for theories and continued with little progress on docs --- DEC/Project.toml | 1 + DEC/docs/literate/egraphs.jl | 22 +++ DEC/scratch/tc.jl | 240 ++++++++++++++++++++++++++++ DEC/src/theories/ThDEC/rewriting.jl | 20 +++ DEC/src/theories/ThDEC/signature.jl | 24 +++ DEC/test/ThDEC/model.jl | 3 + 6 files changed, 310 insertions(+) create mode 100644 DEC/docs/literate/egraphs.jl create mode 100644 DEC/scratch/tc.jl diff --git a/DEC/Project.toml b/DEC/Project.toml index 8c492bb..263a2a0 100644 --- a/DEC/Project.toml +++ b/DEC/Project.toml @@ -17,6 +17,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" StructEquality = "6ec83bb0-ed9f-11e9-3b4c-2b04cb4e219c" +TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" [compat] CairoMakie = "0.12.5" diff --git a/DEC/docs/literate/egraphs.jl b/DEC/docs/literate/egraphs.jl new file mode 100644 index 0000000..b3b11cf --- /dev/null +++ b/DEC/docs/literate/egraphs.jl @@ -0,0 +1,22 @@ +# This lesson covers the internals of Metatheory.jl-style E-Graphs. Let's reuse the heat_equation model on a new roe. +roe = Roe(DEC.ThDEC.Sort); +function heat_equation(roe) + u = fresh!(roe, PrimalForm(0), :u) + + ∂ₜ(u) ≐ Δ(u) + + ([u], []) +end + +# We apply the model to the roe and collect its state variables. +(state, _) = heat_equation(roe) + +# Recall from the Introduction that an E-Graph is a bipartite graph of ENodes and EClasses. Let's look at the EClasses: +classes = roe.graph.classes +# The keys are Metatheory Id types which store an Id. The values are EClasses, which are implementations of equivalence classes. Nodes which share the same EClass are considered equivalent. + + + +# The constants in Roe are a dictionary of hashes of functions and constants. Let's extract just the values again: +vals = collect(values(e.graph.constants)) +# The `u` is ::RootVar{Form} and ∂ₜ, ★, d are all functions defined in ThDEC/signature.jl file. diff --git a/DEC/scratch/tc.jl b/DEC/scratch/tc.jl new file mode 100644 index 0000000..ee490ce --- /dev/null +++ b/DEC/scratch/tc.jl @@ -0,0 +1,240 @@ +using Metatheory +using Metatheory: OptBuffer +using Metatheory.Library +using Metatheory.Rewriters +using MLStyle +using Test +using Metatheory.Plotting + +b = OptBuffer{UInt128}(10) + +@testset "Predicate Assertions" begin + r = @rule ~a::iseven --> true + Base.iseven(g, ec::EClass) = + any(ec.nodes) do n + h = v_head(n) + if has_constant(g, h) + c = get_constant(g, h) + return c isa Number && iseven(c) + end + false + end + # + g = EGraph(:(f(2, 1))) + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0 + # + g = EGraph(:2) + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + # + g = EGraph(:3) + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0 + # + new_id = addexpr!(g, :f) + union!(g, g.root, new_id) + # + new_id = addexpr!(g, 4) + union!(g, g.root, new_id) + # + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 +end + +### + +abstract type AbstractSort end + +@data Sort <: AbstractSort begin + Scalar + Form(dim::Int) +end + +@testset "Check form" begin + + function isform(g, ec::EClass) + any(ec.nodes) do n + h = v_head(n) + if has_constant(g, h) + c = get_constant(g, h) + @info "$c, $(typeof(c))" + return c isa Form + end + false + end + end + + r = @rule ~a::isform --> true + + t = @theory a begin + ~a::isform + ~b::isform --> 0 + end + + ## initialize and sanity-check + a1=Form(1); a2=Form(2) + a3=Form(1); a4=Form(1) + @assert a1 isa Form; @assert a3 isa Form + @assert a3 isa Form; @assert a4 isa Form + + g = EGraph(:($a1 + $a2)) + saturate!(g, t) + extract!(g, astsize) + + g = EGraph(:a1) + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + + g = EGraph(:a3) + @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0 + + + g = EGraph(:(a + b )) + saturate!(g, t) + extract!(g, astsize) + + +end + + +abstract type AbstractSort end + +@data Sort <: AbstractSort begin + Scalar + Form(dim::Int) +end + +t = @theory a b begin + a::Form(1) ∧ b::Form(2) --> 0 +end +# breaks! makevar @ MT/src/Syntax.jl:57 (<- makepattern, ibid:151) +# expects Symbol, not Expr + +function isform(g, ec::EClass) + any(ec.nodes) do n + h = v_head(n) + if has_constant(g, h) + c = get_constant(g, h) + @info "$c, $(typeof(c)), $(c isa Form)" + return c isa Form + end + end +end + +r = @rule ~a::isform --> true + +g = EGraph(:(f(2, 1))) +@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0 + +g = EGraph(:2) +@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + +g = EGraph(:3) +@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0 + +new_id = addexpr!(g, :f) +union!(g, g.root, new_id) + +new_id = addexpr!(g, 4) +union!(g, g.root, new_id) + +@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + + + + + +a=Form(1) +b=Form(2) +c=Form(1) +d=Form(1) +@assert a isa Form +@assert b isa Form +@assert c isa Form +@assert d isa Form + +ex = :(a + b + c + d) +g = EGraph(ex) +saturate!(g, _T) +extract!(g, astsize) + + +@data Sort1 <: AbstractSort begin + AnyForm +end + +_t = @theory a b begin + a::AnyForm ∧ b::AnyForm --> 0 +end +# PatVar error + +__t = @theory a b begin + a::var"AnyForm's constructor" + 0 --> a +end + +d0 = AnyForm +d1 = AnyForm +d2 = AnyForm +ex = :(d0 + 0 + (d1 + 0) + d2) + +g = EGraph(ex) +saturate!(g, __t) +extract!(g, astsize) + + +rwth = Fixpoint(Prewalk(Chain(__t))) +rwth(ex) + +g = EGraph(ex) +saturate!(g, _t) +extract!(g, astsize) +# returns (a ∧ b), as + +rwth = Fixpoint(Prewalk(Chain(_t))) +rwth(ex) + +a = Form(0) +ex = :(a ∧ b) + + +Derivative = @theory f g a begin + f::Function * d(g::Function) + d(f::Function) * g::Function --> d(f * g) + d(f::Function) + d(g::Function) --> d(f + g) + d(a::Number * f::Function) --> a * d(f) +end + +_Derivative = @theory f g a begin + f * d(g) + d(f) * g --> d(f * g) + d(f) + d(g) --> d(f + g) + d(a * f) --> a * d(f) +end + +ex = :(f * d(g) + d(f) * g) + +rwth = Fixpoint(Prewalk(Chain(_Derivative))) +rwth(ex) + +foo(x) = x + 1 +goo(x) = x + 3 + +rwth = Fixpoint(Prewalk(Chain(Derivative))) +rwth(ex) + +g = EGraph(ex) +saturate!(g, Derivative); +extract!(g, astsize) + + +rwth(ex) + + +types = (U=Form(0), k=Scalar(),); + +tc(x) = @match x begin + s::Symbol => types[s] + ::Expr => true +end + +cond = x -> begin + @info "$x, $(type(x))" + tc(x) +end + +orw = Fixpoint(Prewalk(If(cond, Chain(rewrite_theory)))) + +orw(expr) diff --git a/DEC/src/theories/ThDEC/rewriting.jl b/DEC/src/theories/ThDEC/rewriting.jl index cb9bec7..81182e5 100644 --- a/DEC/src/theories/ThDEC/rewriting.jl +++ b/DEC/src/theories/ThDEC/rewriting.jl @@ -1,5 +1,25 @@ using Metatheory using Metatheory.Library +using Metatheory.Rewriters +using MLStyle + +buf = OptBuffer{UInt128}(10) + +function isForm(g, ec::EClass) + any(ec.nodes) do n + h = v_head(n) + if has_constant(g, h) + c = get_constant(g, h) + return c isa Form + end + false + end +end + +t = @theory a b begin + ~a::isForm + ~b::isForm --> 0 +end + ThMultiplicativeMonoid = @commutative_monoid (*) 1 ThAdditiveGroup = @commutative_group (+) 0 (-) diff --git a/DEC/src/theories/ThDEC/signature.jl b/DEC/src/theories/ThDEC/signature.jl index ef68938..ed20198 100644 --- a/DEC/src/theories/ThDEC/signature.jl +++ b/DEC/src/theories/ThDEC/signature.jl @@ -37,6 +37,30 @@ function Base.show(io::IO, ω::Form) print(io, isdual(ω) ? "DualForm($(dim(ω)))" : "PrimalForm($(dim(ω)))") end +## Predicates +function isForm(g, ec::EClass) + any(ec.nodes) do n + h = v_head(n) + if has_constant(g, h) + c = get_constant(g, h) + return c isa Form + end + false + end +end + + +function isForm(g, ec::EClass) + any(ec.nodes) do n + h = v_head(n) + if has_constant(g, h) + c = get_constant(g, h) + return c isa Form + end + false + end +end + ## OPERATIONS @nospecialize diff --git a/DEC/test/ThDEC/model.jl b/DEC/test/ThDEC/model.jl index 1cd2b75..08e998a 100644 --- a/DEC/test/ThDEC/model.jl +++ b/DEC/test/ThDEC/model.jl @@ -1,3 +1,6 @@ +using DEC +import DEC: Δ, ∧ + # load other dependencies using ComponentArrays using CombinatorialSpaces