From 700050800262a0bb880595cf99ad1759a203f49e Mon Sep 17 00:00:00 2001 From: Kris Brown Date: Thu, 5 Oct 2023 14:46:09 -0700 Subject: [PATCH] model migration revived --- src/models/ModelInterface.jl | 277 ++++++++++++++++--------------- src/stdlib/models/module.jl | 4 +- src/stdlib/module.jl | 4 +- src/syntax/TheoryMaps.jl | 2 +- src/syntax/gats/algorithms.jl | 17 +- src/syntax/gats/ast.jl | 7 +- test/stdlib/models/Arithmetic.jl | 62 +++---- test/stdlib/models/GATs.jl | 2 +- test/stdlib/models/Op.jl | 22 +-- test/stdlib/models/tests.jl | 2 +- 10 files changed, 205 insertions(+), 194 deletions(-) diff --git a/src/models/ModelInterface.jl b/src/models/ModelInterface.jl index a570d971..5dcd0b9a 100644 --- a/src/models/ModelInterface.jl +++ b/src/models/ModelInterface.jl @@ -258,7 +258,7 @@ function default_accessor_impl(x::Ident, theory::GAT, jltype_by_sort::Dict{AlgSo end julia_signature(theory::GAT, x::Ident, jltype_by_sort::Dict{AlgSort}) = - julia_signature(theory, x, getvalue(theory[x]), jltype_by_sort) + julia_signature(getvalue(theory[x]), jltype_by_sort; X=x) function julia_signature( termcon::AlgTermConstructor, @@ -502,140 +502,145 @@ macro withmodel(model, subsexpr, body) end -# """ -# Given a Theory Morphism T->U and a type Mᵤ (whose values are models of U), -# obtain a type Mₜ which has one parameter (of type Mᵤ) and is a model of T. - -# E.g. given NatIsMonoid: ThMonoid->ThNatPlus and IntPlus <: Model{Tuple{Int}} -# and IntPlus implements ThNatPlus: - -# ``` -# @migrate IntPlusMonoid = NatIsMonoid(IntPlus){Int} -# ``` - -# Yields: - -# ``` -# struct IntPlusMonoid <: Model{Tuple{Int}} -# model::IntPlus -# end - -# @instance ThMonoid{Int} [model::IntPlusMonoid] begin ... end -# ``` - -# Future work: There is some subtlety in how accessor functions should be handled. -# TODO: The new instance methods do not yet handle the `context` keyword argument. -# """ -# macro migrate(head) -# # Parse -# (name, mapname, modelname) = @match head begin -# Expr(:(=), name, Expr(:call, mapname, modelname)) => -# (name, mapname, modelname) -# _ => error("could not parse head of @theory: $head") -# end -# codom_types = :(only(supertype($(esc(modelname))).parameters).types) -# # Unpack -# tmap = macroexpand(__module__, :($mapname.@map)) -# dom_module = macroexpand(__module__, :($mapname.@dom)) -# codom_module = macroexpand(__module__, :($mapname.@codom)) -# dom_theory, codom_theory = TheoryMaps.dom(tmap), TheoryMaps.codom(tmap) - -# codom_jltype_by_sort = Dict{Ident,Expr0}(map(enumerate(sorts(codom_theory))) do (i,v) -# v.ref => Expr(:ref, codom_types, i) -# end) -# _x = gensym("val") - -# dom_types = map(sorts(dom_theory)) do s -# codom_jltype_by_sort[typemap(tmap)[s.ref].trm.head] -# end -# jltype_by_sort = Dict(zip(sorts(dom_theory), dom_types)) - -# # TypeCons for @instance macro -# funs = map(collect(typemap(tmap))) do (x, fx) -# xname = nameof(x) -# fxname = nameof(fx.trm.head) -# tc = getvalue(dom_theory[x]) -# jltype_by_sort[AlgSort(fx.trm.head)] = jltype_by_sort[AlgSort(x)] -# sig = julia_signature(dom_theory, x, jltype_by_sort) - -# argnames = [_x, nameof.(argsof(tc))...] -# args = [:($k::$v) for (k, v) in zip(argnames, sig.types)] - -# impls = to_call_impl.(fx.trm.args, Ref(termcons(codom_theory)), Ref(codom_module)) -# impl = Expr(:call, Expr(:ref, :($codom_module.$fxname), :(model.model)), _x, impls...) -# JuliaFunction(;name=xname, args=args, return_type=sig.types[1], impl=impl) -# end - -# # TermCons for @instance macro -# funs2 = map(collect(termmap(tmap))) do (x, fx) -# tc = getvalue(dom_theory[x]) - -# sig = julia_signature(dom_theory, x, jltype_by_sort) -# argnames = nameof.(argsof(tc)) -# ret_type = jltype_by_sort[AlgSort(typemap(tmap)[tc.type.head].trm.head)] - -# args = [:($k::$v) for (k, v) in zip(argnames, sig.types)] - -# impl = to_call_impl(fx.trm, termcons(codom_theory), codom_module) - -# JuliaFunction(;name=nameof(x), args=args, return_type=ret_type, impl=impl) -# end - -# funs3 = [] # accessors -# for (x, fx) in pairs(typemap(tmap)) -# tc = getvalue(dom_theory[x]) -# eq = equations(codom_theory, fx) -# args = [:($_x::$(jltype_by_sort[AlgSort(fx.trm.head)]))] -# scopedict = Dict{ScopeTag,ScopeTag}(gettag(tc.localcontext)=>gettag(fx.ctx)) -# for accessor in idents(tc.localcontext; lid=tc.args) -# accessor = retag(scopedict, accessor) -# a = nameof(accessor) -# # If we have a default means of computing the accessor... -# if !isempty(eq[accessor]) -# rtype = tc.localcontext[ident(tc.localcontext; name=a)] -# ret_type = jltype_by_sort[AlgSort(getvalue(rtype))] -# impl = to_call_impl(first(eq[accessor]), _x, codom_module) -# jf = JuliaFunction(;name=a, args=args, return_type=ret_type, impl=impl) -# push!(funs3, jf) -# end -# end -# end - -# model_expr = Expr( -# :curly, -# GlobalRef(Syntax.TheoryInterface, :Model), -# Expr(:curly, :Tuple, dom_types...) -# ) - -# quote -# struct $(esc(name)) <: $model_expr -# model :: $(esc(modelname)) -# end - -# @instance $dom_module [model :: $(esc(name))] begin -# $(generate_function.([funs...,funs2..., funs3...])...) -# end -# end -# end - -# """ -# Compile an AlgTerm into a Julia call Expr where termcons (e.g. `f`) are -# interpreted as `mod.f[model.model](...)`. -# """ -# function to_call_impl(t::AlgTerm, termcons, mod::Module) -# args = to_call_impl.(t.args, Ref(termcons), Ref(mod)) -# name = nameof(headof(t)) -# if t.head in termcons -# Expr(:call, Expr(:ref, :($mod.$name), :(model.model)), args...) -# else -# isempty(args) || error("Bad term $t (termcons=$termcons)") -# name -# end -# end - -# function to_call_impl(t::GATs.AccessorApplication, x::Symbol, mod::Module) -# rest = t.to isa Ident ? x : to_call_impl(t.to, x, mod) -# Expr(:call, Expr(:ref, :($mod.$(nameof(t.accessor))), :(model.model)), rest) -# end +""" +Given a Theory Morphism T->U and a type Mᵤ (whose values are models of U), +obtain a type Mₜ which has one parameter (of type Mᵤ) and is a model of T. + +E.g. given NatIsMonoid: ThMonoid->ThNatPlus and IntPlus <: Model{Tuple{Int}} +and IntPlus implements ThNatPlus: + +``` +@migrate IntPlusMonoid = NatIsMonoid(IntPlus){Int} +``` + +Yields: + +``` +struct IntPlusMonoid <: Model{Tuple{Int}} + model::IntPlus +end + +@instance ThMonoid{Int} [model::IntPlusMonoid] begin ... end +``` + +Future work: There is some subtlety in how accessor functions should be handled. +TODO: The new instance methods do not yet handle the `context` keyword argument. +""" +macro migrate(head) + # Parse + (name, mapname, modelname) = @match head begin + Expr(:(=), name, Expr(:call, mapname, modelname)) => + (name, mapname, modelname) + _ => error("could not parse head of @theory: $head") + end + codom_types = :(only(supertype($(esc(modelname))).parameters).types) + # Unpack + tmap = macroexpand(__module__, :($mapname.@map)) + dom_module = macroexpand(__module__, :($mapname.@dom)) + codom_module = macroexpand(__module__, :($mapname.@codom)) + dom_theory, codom_theory = TheoryMaps.dom(tmap), TheoryMaps.codom(tmap) + + codom_jltype_by_sort = Dict{Ident,Expr0}(map(enumerate(sorts(codom_theory))) do (i,v) + v.method => Expr(:ref, codom_types, i) + end) + _x = gensym("val") + dom_types = map(methodof.(sorts(dom_theory))) do s + codom_jltype_by_sort[typemap(tmap)[s].val.body.method] + end + jltype_by_sort = Dict(zip(sorts(dom_theory), dom_types)) + + # TypeCons for @instance macro + funs = map(collect(typemap(tmap))) do (x, fx) + tcon = getvalue(dom_theory[x]) + fxbody = bodyof(fx.val) + fxdecl, fxmethod = headof(fxbody), methodof(fxbody) + fxname = nameof(fxdecl) + xdecl = tcon.declaration + xname = nameof(xdecl) + jltype_by_sort[AlgSort(fxdecl, fxmethod)] = jltype_by_sort[AlgSort(xdecl, x)] + sig = julia_signature(dom_theory, x, jltype_by_sort) + argnames = [_x, nameof.(argsof(tcon))...] + args = [:($k::$v) for (k, v) in zip(argnames, sig.types)] + impls = to_call_impl.(fxbody.args, Ref(termcons(codom_theory)), Ref(codom_module)) + impl = Expr(:call, Expr(:ref, :($codom_module.$fxname), :(model.model)), _x, impls...) + JuliaFunction(;name=xname, args=args, return_type=sig.types[1], impl=impl) + end + + # TermCons for @instance macro + funs2 = map(collect(termmap(tmap))) do (x, fx) + tcon = getvalue(dom_theory[x]) + xname = nameof(tcon.declaration) + sig = julia_signature(dom_theory, x, jltype_by_sort) + argnames = nameof.(argsof(tcon)) + ftype = typemap(tmap)[tcon.type.body.method].val.body + ret_type = jltype_by_sort[AlgSort(headof(ftype), methodof(ftype))] + + args = [:($k::$v) for (k, v) in zip(argnames, sig.types)] + + impl = to_call_impl(fx.val, first.(termcons(codom_theory)), codom_module) + + JuliaFunction(;name=xname, args=args, return_type=ret_type, impl=impl) + end + + funs3 = [] # accessors + for (x, fx) in pairs(typemap(tmap)) + tc = getvalue(dom_theory[x]) + eq = equations(codom_theory, fx) + args = [:($_x::$(jltype_by_sort[AlgSort(fx.val)]))] + scopedict = Dict{ScopeTag,ScopeTag}(gettag(tc.localcontext)=>gettag(fx.ctx)) + for accessor in idents(tc.localcontext; lid=tc.args) + accessor = retag(scopedict, accessor) + a = nameof(accessor) + # If we have a default means of computing the accessor... + if !isempty(eq[accessor]) + rtype = tc.localcontext[ident(tc.localcontext; name=a)] + ret_type = jltype_by_sort[AlgSort(getvalue(rtype))] + impl = to_call_accessor(first(eq[accessor]), _x, codom_module) + jf = JuliaFunction(;name=a, args=args, return_type=ret_type, impl=impl) + push!(funs3, jf) + end + end + end + + model_expr = Expr( + :curly, + GlobalRef(Syntax.TheoryInterface, :Model), + Expr(:curly, :Tuple, dom_types...) + ) + + quote + struct $(esc(name)) <: $model_expr + model :: $(esc(modelname)) + end + + @instance $dom_module [model :: $(esc(name))] begin + $(generate_function.([funs...,funs2..., funs3...])...) + end + end +end + +""" +Compile an AlgTerm into a Julia call Expr where termcons (e.g. `f`) are +interpreted as `mod.f[model.model](...)`. +""" +function to_call_impl(t::AlgTerm, termcons, mod::Module) + b = bodyof(t) + if GATs.isvariable(t) + nameof(b) + else + args = to_call_impl.(argsof(b), Ref(termcons), Ref(mod)) + name = nameof(headof(b)) + b.head in termcons || error("t $t termcons $termcons") + Expr(:call, Expr(:ref, :($mod.$name), :(model.model)), args...) + end +end + +function to_call_accessor(t::AlgTerm, x::Symbol, mod::Module) + b = bodyof(t) + arg = only(b.args) + rest = GATs.isvariable(arg) ? x : to_call_accessor(arg, x, mod) + Expr(:call, Expr(:ref, :($mod.$(nameof(headof(b)))), :(model.model)), rest) +end + end # module diff --git a/src/stdlib/models/module.jl b/src/stdlib/models/module.jl index dd69e9e0..569fb3ed 100644 --- a/src/stdlib/models/module.jl +++ b/src/stdlib/models/module.jl @@ -8,7 +8,7 @@ include("FinMatrices.jl") include("SliceCategories.jl") include("Op.jl") include("Nothings.jl") -# include("GATs.jl") +include("GATs.jl") @reexport using .FinSets @reexport using .Arithmetic @@ -16,6 +16,6 @@ include("Nothings.jl") @reexport using .SliceCategories @reexport using .Op @reexport using .Nothings -# @reexport using .GATs +@reexport using .GATs end diff --git a/src/stdlib/module.jl b/src/stdlib/module.jl index 411ada00..742306b3 100644 --- a/src/stdlib/module.jl +++ b/src/stdlib/module.jl @@ -5,11 +5,11 @@ using Reexport include("theories/module.jl") include("models/module.jl") include("theorymaps/module.jl") -# include("derivedmodels/module.jl") +include("derivedmodels/module.jl") @reexport using .StdTheories @reexport using .StdModels @reexport using .StdTheoryMaps -# @reexport using .StdDerivedModels +@reexport using .StdDerivedModels end diff --git a/src/syntax/TheoryMaps.jl b/src/syntax/TheoryMaps.jl index 7b477f01..9bec2558 100644 --- a/src/syntax/TheoryMaps.jl +++ b/src/syntax/TheoryMaps.jl @@ -175,7 +175,7 @@ bind_localctx(ctx::GATContext, t::InCtx) = bind_localctx(GATContext(ctx.theory, AppendContext(ctx.context, t.ctx)), t.val) function bind_localctx(ctx::GATContext, t::AlgAST) - m = GATs.methodof(t.body) + m = methodof(t.body) tc = getvalue(ctx[m]) eqs = equations(ctx.theory, m) typed_terms = Dict{Ident, Pair{AlgTerm,AlgType}}() diff --git a/src/syntax/gats/algorithms.jl b/src/syntax/gats/algorithms.jl index 63b3a2b1..7e69c9c0 100644 --- a/src/syntax/gats/algorithms.jl +++ b/src/syntax/gats/algorithms.jl @@ -94,15 +94,16 @@ function equations(c::GATContext, args::AbstractVector{Ident}; init=nothing) end function equations(theory::GAT, t::TypeInCtx) - tc = getvalue(theory[headof(t.val)]) - extended = ScopeList([t.ctx, Scope([Binding{AlgType, Nothing}(nothing, t.val)])]) - lastx = last(getidents(extended)) - accessor_args = zip(idents(tc.localcontext; lid=tc.args), t.val.args) - init = Dict{Ident, AlgTerm}(map(accessor_args) do (accessor, arg) - hasident(t.ctx, headof(arg)) || error("Case not yet handled") - headof(arg) => AlgType(headof(t.val), accessor, lastx) + b = bodyof(t.val) + m = methodof(b) + newscope = Scope([Binding{AlgType}(nothing, t.val)]) + newterm = AlgTerm(only(getidents(newscope))) + extended = ScopeList([t.ctx, newscope]) + init = Dict{Ident, AlgTerm}(map(collect(theory.accessors[m])) do (i, acc) + algacc = getvalue(theory[acc]) + bodyof(b.args[i]) => AlgTerm(algacc.declaration, acc, [newterm]) end) - equations(extended, Ident[], theory; init=init) + equations(GATContext(theory, extended), Ident[]; init=init) end """Get equations for a term or type constructor""" diff --git a/src/syntax/gats/ast.jl b/src/syntax/gats/ast.jl index 18e1867e..a9512568 100644 --- a/src/syntax/gats/ast.jl +++ b/src/syntax/gats/ast.jl @@ -133,7 +133,6 @@ end `AlgSort` A *sort*, which is essentially a type constructor without arguments -`ref` must be reference to a `AlgTypeConstructor` """ @struct_hash_equal struct AlgSort head::Ident @@ -142,6 +141,9 @@ end AlgSort(t::AlgType) = AlgSort(t.body.head, t.body.method) +headof(a::AlgSort) = a.head +methodof(a::AlgSort) = a.method + function AlgSort(c::Context, t::AlgTerm) if isconstant(t) AlgSort(t.body.type) @@ -179,3 +181,6 @@ end const TermInCtx = InCtx{AlgTerm} const TypeInCtx = InCtx{AlgType} + +Scopes.getvalue(i::InCtx) = i.val +Scopes.getcontext(i::InCtx) = i.ctx diff --git a/test/stdlib/models/Arithmetic.jl b/test/stdlib/models/Arithmetic.jl index d1c7291f..07bb6a39 100644 --- a/test/stdlib/models/Arithmetic.jl +++ b/test/stdlib/models/Arithmetic.jl @@ -11,36 +11,36 @@ using .ThNatPlus @test S(S(Z())) + Z() == 2 end -# # IntMonoid = NatPlusMonoid(IntNatPlus) -# #-------------------------------------- -# using .ThMonoid - -# IM = IntMonoid(IntNatPlus()) -# @withmodel IM (e) begin -# @test e() == 0 -# @test (ThMonoid.:(⋅)[IM])(3, 4) == 7 -# end - -# # Integers as preorder -# #--------------------- -# using .ThPreorder - -# @withmodel IntPreorder() (Leq, refl, trans) begin -# @test trans((1,3), (3,5)) == (1,5) -# @test_throws TypeCheckFail Leq((5,3), 5, 3) -# @test refl(2) == (2,2) -# end - -# # Now using category interface - -# using .ThCategory -# M = IntPreorderCat(IntPreorder()) - -# @withmodel M (Hom, id, compose) begin -# @test compose((1,3), (3,5)) == (1,5) -# @test_throws TypeCheckFail Hom((5,3), 5, 3) -# @test_throws ErrorException compose((1,2), (3,5)) -# @test id(2) == (2,2) -# end +# IntMonoid = NatPlusMonoid(IntNatPlus) +#-------------------------------------- +using .ThMonoid + +IM = IntMonoid(IntNatPlus()) +@withmodel IM (e) begin + @test e() == 0 + @test (ThMonoid.:(⋅)[IM])(3, 4) == 7 +end + +# Integers as preorder +#--------------------- +using .ThPreorder + +@withmodel IntPreorder() (Leq, refl, trans) begin + @test trans((1,3), (3,5)) == (1,5) + @test_throws TypeCheckFail Leq((5,3), 5, 3) + @test refl(2) == (2,2) +end + +# Now using category interface + +using .ThCategory +M = IntPreorderCat(IntPreorder()) + +@withmodel M (Hom, id, compose) begin + @test compose((1,3), (3,5)) == (1,5) + @test_throws TypeCheckFail Hom((5,3), 5, 3) + @test_throws ErrorException compose((1,2), (3,5)) + @test id(2) == (2,2) +end end # module diff --git a/test/stdlib/models/GATs.jl b/test/stdlib/models/GATs.jl index 510298bc..04bb366a 100644 --- a/test/stdlib/models/GATs.jl +++ b/test/stdlib/models/GATs.jl @@ -8,7 +8,7 @@ using .ThCategory expected = @theorymap ThMonoid => ThNatPlus begin default => ℕ x ⋅ y ⊣ [x, y] => y + x - e => Z + e() => Z() end @withmodel GATC() (Ob, Hom, id, compose, dom, codom) begin diff --git a/test/stdlib/models/Op.jl b/test/stdlib/models/Op.jl index a9e3f4f7..01882147 100644 --- a/test/stdlib/models/Op.jl +++ b/test/stdlib/models/Op.jl @@ -24,17 +24,17 @@ end # Theory-morphism Op #------------------- -# M = OpFinSetC(FinSetC()) -# @withmodel M (Ob, Hom, id, compose, dom, codom) begin -# @test Ob(0) == 0 -# @test_throws TypeCheckFail Ob(-1) -# @test_throws TypeCheckFail Hom([1,5,2], 4, 3) -# @test Hom(Int[], 4, 0) == Int[] - -# @test id(2) == [1,2] -# @test compose([1,1,1,3,2], [5]) == [2] -# @test codom([5]) == 1 -# end +M = OpFinSetC(FinSetC()) +@withmodel M (Ob, Hom, id, compose, dom, codom) begin + @test Ob(0) == 0 + @test_throws TypeCheckFail Ob(-1) + @test_throws TypeCheckFail Hom([1,5,2], 4, 3) + @test Hom(Int[], 4, 0) == Int[] + + @test id(2) == [1,2] + @test compose([1,1,1,3,2], [5]) == [2] + @test codom([5]) == 1 +end end # module diff --git a/test/stdlib/models/tests.jl b/test/stdlib/models/tests.jl index 06d4e3ac..7893f362 100644 --- a/test/stdlib/models/tests.jl +++ b/test/stdlib/models/tests.jl @@ -6,6 +6,6 @@ include("FinMatrices.jl") include("SliceCategories.jl") include("Op.jl") include("Nothings.jl") -# include("GATs.jl") +include("GATs.jl") end