Skip to content

Commit

Permalink
Migrate models differently
Browse files Browse the repository at this point in the history
  • Loading branch information
Kris Brown committed Oct 16, 2023
1 parent fb12e9a commit 96ddfd4
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 97 deletions.
160 changes: 81 additions & 79 deletions src/models/ModelInterface.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
module ModelInterface

export Model, implements, TypeCheckFail, SignatureMismatchError,
@model, @instance, @withmodel, @fail, @migrate
@model, @instance, @withmodel, @fail, migrate

using ...Syntax
using ...Util.MetaUtils
using ...Util.MetaUtils: JuliaFunctionSigNoWhere

import ...Syntax.TheoryMaps: migrator

using MLStyle
using DataStructures: DefaultDict
using DataStructures: DefaultDict, OrderedDict

"""
`Model{Tup <: Tuple}`
Expand All @@ -16,7 +19,7 @@ A Julia value with type `Model{Tuple{Ts...}}` represents a model of some
part of the theory hierarchy, which uses the types in `Ts...` to implement
the sorts.
A model `m::Model{Tup}` is marked as implementing a `seg::GATSegmant` iff
A model `m::Model{Tup}` is marked as implementing a `seg::GATSegment` iff
`implements(m, ::Type{Val{gettag(seg)}}) == true`
Expand Down Expand Up @@ -332,6 +335,8 @@ end
function ExprInterop.toexpr(sig::JuliaFunctionSig)
Expr(:call, sig.name, [Expr(:(::), type) for type in sig.types]...)
end
ExprInterop.toexpr(sig::JuliaFunctionSigNoWhere) =
ExprInterop.toexpr(sig |> JuliaFunctionSig)

struct SignatureMismatchError <: Exception
name::Symbol
Expand Down Expand Up @@ -363,7 +368,7 @@ function typecheck_instance(
typechecked = JuliaFunction[]

# The overloads that we have to provide
undefined_signatures = Dict{JuliaFunctionSig, Tuple{Ident, Ident}}()
undefined_signatures = Dict{JuliaFunctionSigNoWhere, Tuple{Ident, Ident}}()

overload_errormsg =
"the types for this model declaration do not permit Julia overloading to distinguish between GAT overloads"
Expand All @@ -373,7 +378,7 @@ function typecheck_instance(
continue
end
for (_, x) in allmethods(resolver)
sig = julia_signature(getvalue(theory[x]), jltype_by_sort; oldinstance, X=x)
sig = julia_signature(getvalue(theory[x]), jltype_by_sort; oldinstance, X=x) |> JuliaFunctionSigNoWhere
if haskey(undefined_signatures, sig)
error(overload_errormsg)
end
Expand All @@ -388,15 +393,14 @@ function typecheck_instance(
end

for f in functions
sig = parse_function_sig(f)

sig = parse_function_sig(f) |> JuliaFunctionSigNoWhere
if haskey(undefined_signatures, sig)
(decl, method) = undefined_signatures[sig]

judgment = getvalue(theory, method)

if judgment isa AlgTypeConstructor
f = expand_fail(theory, decl, f)
f = expand_fail(theory, decl, f)
end

delete!(undefined_signatures, sig)
Expand Down Expand Up @@ -601,104 +605,101 @@ end
Future work: There is some subtlety in how accessor functions should be handled.
TODO: The new instance methods do not yet handle the `context` keyword argument.
"""
macro migrate(head)
# Parse
(name, mapname, modelname) = @match head begin
Expr(:(=), name, Expr(:call, mapname, modelname)) =>
(name, mapname, modelname)
_ => error("could not parse head of @theory: $head")
end
codom_types = :(only(supertype($modelname).parameters).types)
# Unpack
tmap = macroexpand(__module__, :($mapname.@map))
dom_module = macroexpand(__module__, :($mapname.@dom))
codom_module = macroexpand(__module__, :($mapname.@codom))
dom_theory, codom_theory = TheoryMaps.dom(tmap), TheoryMaps.codom(tmap)

codom_jltype_by_sort = Dict{Ident,Expr0}(map(enumerate(sorts(codom_theory))) do (i,v)
v.method => Expr(:ref, codom_types, i)
end)
function migrator(tmap, dom_module, codom_module, dom_theory, codom_theory)

# Map CODOM sorts to whereparam symbols
whereparamdict = OrderedDict(s=>gensym(s.head.name) for s in sorts(codom_theory))
whereparams = collect(values(whereparamdict))
name = :Migrator #gensym("migrator")
_x = gensym("val")
dom_types = map(methodof.(sorts(dom_theory))) do s
codom_jltype_by_sort[typemap(tmap)[s].val.body.method]
end
jltype_by_sort = Dict(zip(sorts(dom_theory), dom_types))

jltype_by_sort = Dict(map(sorts(dom_theory)) do v
v => whereparamdict[AlgSort(tmap(v.method).val)]
end)

accessor_funs = [] # accessors

# TypeCons for @instance macro
funs = map(collect(typemap(tmap))) do (x, fx)
tcon = getvalue(dom_theory[x])
fxbody = bodyof(fx.val)
fxdecl, fxmethod = headof(fxbody), methodof(fxbody)
fxname = nameof(fxdecl)
xdecl = tcon.declaration
xname = nameof(xdecl)
jltype_by_sort[AlgSort(fxdecl, fxmethod)] = jltype_by_sort[AlgSort(xdecl, x)]
typecon_funs = map(collect(typemap(tmap))) do (x, fx)
typecon = getvalue(dom_theory[x])
# Accessors
eq = equations(codom_theory, fx)
args = [:($_x::$(jltype_by_sort[AlgSort(typecon.declaration, x)]))]
scopedict = Dict{ScopeTag,ScopeTag}(gettag(typecon.localcontext)=>gettag(fx.ctx))
for accessor in idents(typecon.localcontext; lid=typecon.args)
accessor = retag(scopedict, accessor)
a = nameof(accessor)
# If we have a default means of computing the accessor...
if !isempty(eq[accessor])
rtype = typecon.localcontext[ident(typecon.localcontext; name=a)]
ret_type = jltype_by_sort[AlgSort(getvalue(rtype))]
impl = to_call_accessor(first(eq[accessor]), _x, codom_module)
jf = JuliaFunction(;name=a, args=args, return_type=ret_type, impl=impl)
push!(accessor_funs, jf)
end
end

# Typecon function
codom_body = bodyof(fx.val)
fxname = nameof(headof(codom_body))
xname = nameof(typecon.declaration)
sig = julia_signature(dom_theory, x, jltype_by_sort)
argnames = [_x, nameof.(argsof(tcon))...]
args = [:($k::$v) for (k, v) in zip(argnames, sig.types)]
impls = to_call_impl.(fxbody.args, Ref(termcons(codom_theory)), Ref(codom_module))
argnames = [_x, nameof.(argsof(typecon))...]
args = [:($k::$(v)) for (k, v) in zip(argnames, sig.types)]
impls = to_call_impl.(codom_body.args, Ref(termcons(codom_theory)), Ref(codom_module))
impl = Expr(:call, Expr(:ref, :($codom_module.$fxname), :(model.model)), _x, impls...)
JuliaFunction(;name=xname, args=args, return_type=sig.types[1], impl=impl)
end

# TermCons for @instance macro
funs2 = map(collect(termmap(tmap))) do (x, fx)
tcon = getvalue(dom_theory[x])
xname = nameof(tcon.declaration)
termcon_funs = map(collect(termmap(tmap))) do (x, fx)
termcon = getvalue(dom_theory[x])
func_name = nameof(termcon.declaration)
sig = julia_signature(dom_theory, x, jltype_by_sort)
argnames = nameof.(argsof(tcon))
ftype = typemap(tmap)[tcon.type.body.method].val.body
ret_type = jltype_by_sort[AlgSort(headof(ftype), methodof(ftype))]
argnames = nameof.(argsof(termcon))
ret_type = jltype_by_sort[AlgSort(termcon.type)]

args = [:($k::$v) for (k, v) in zip(argnames, sig.types)]

impl = to_call_impl(fx.val, first.(termcons(codom_theory)), codom_module)

JuliaFunction(;name=xname, args=args, return_type=ret_type, impl=impl)
end

funs3 = [] # accessors
for (x, fx) in pairs(typemap(tmap))
tc = getvalue(dom_theory[x])
eq = equations(codom_theory, fx)
args = [:($_x::$(jltype_by_sort[AlgSort(fx.val)]))]
scopedict = Dict{ScopeTag,ScopeTag}(gettag(tc.localcontext)=>gettag(fx.ctx))
for accessor in idents(tc.localcontext; lid=tc.args)
accessor = retag(scopedict, accessor)
a = nameof(accessor)
# If we have a default means of computing the accessor...
if !isempty(eq[accessor])
rtype = tc.localcontext[ident(tc.localcontext; name=a)]
ret_type = jltype_by_sort[AlgSort(getvalue(rtype))]
impl = to_call_accessor(first(eq[accessor]), _x, codom_module)
jf = JuliaFunction(;name=a, args=args, return_type=ret_type, impl=impl)
push!(funs3, jf)
end
end
JuliaFunction(;name=func_name, args=args, return_type=ret_type, impl=impl)
end

model_expr = Expr(
:curly,
GlobalRef(Syntax.TheoryInterface, :Model),
Expr(:curly, :Tuple, esc.(dom_types)...)
)

# Generate instance code
instance_code = generate_instance(
dom_theory,
dom_module,
jltype_by_sort,
name,
[],
Expr(:block, generate_function.([funs...,funs2..., funs3...])...);
typecheck=false
Expr(:curly, name, whereparams...),
whereparams,
Expr(:block, generate_function.([typecon_funs...,
termcon_funs...,
accessor_funs...
])...);
typecheck=true, escape=false
)

tup_params = Expr(:curly, :Tuple, whereparams...)

model_expr = Expr(
:curly,
GlobalRef(Syntax.TheoryInterface, :Model),
tup_params
)

# The second whereparams needs to be reordered by the sorts of the DOM theory
quote
struct $(esc(name)) <: $model_expr
model :: $(esc(modelname))
struct Migrator{$(whereparams...)} <: $model_expr
model :: $(GlobalRef(ModelInterface, :Model)){$tup_params}
function Migrator(model:: $(GlobalRef(ModelInterface, :Model)){$tup_params}) where {$(whereparams...)}
$(GlobalRef(ModelInterface, :implements))(model, $codom_module) || error("Cannot migrate model $model")
new{$(whereparams...)}(model)
end
end

$instance_code
$(instance_code.args...)
end
end

Expand All @@ -725,5 +726,6 @@ function to_call_accessor(t::AlgTerm, x::Symbol, mod::Module)
Expr(:call, Expr(:ref, :($mod.$(nameof(headof(b)))), :(model.model)), rest)
end

migrate(theorymap::Module, m::Model) = theorymap.Migrator(m)

end # module
6 changes: 3 additions & 3 deletions src/stdlib/derivedmodels/DerivedModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ using ...StdTheoryMaps
using ...StdModels

# Given a model of a category C, we can derive a model of Cᵒᵖ.
@migrate OpFinSetC = OpCat(FinSetC)
OpFinSetC = migrate(OpCat, FinSetC())

# Interpret `e` as `0` and `⋅` as `+`.
@migrate IntMonoid = NatPlusMonoid(IntNatPlus)
IntMonoid = migrate(NatPlusMonoid, IntNatPlus())

# Interpret `id` as reflexivity and `compose` as transitivity.
@migrate IntPreorderCat = PreorderCat(IntPreorder)
IntPreorderCat = migrate(PreorderCat, IntPreorder())

end
7 changes: 7 additions & 0 deletions src/syntax/TheoryMaps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,8 @@ macro theorymap(head, body)
codom = macroexpand(__module__, :($codomname.Meta.@theory))
tmap = fromexpr(dom, codom, body, TheoryMap)

mig = migrator(tmap, dommod, codommod, dom, codom)

esc(
Expr(
:toplevel,
Expand All @@ -411,11 +413,16 @@ macro theorymap(head, body)
macro map() $tmap end
macro dom() $dommod end
macro codom() $codommod end

$mig
end
),
:(Core.@__doc__ $(name)),
)
)
end

function migrator end


end # module
27 changes: 20 additions & 7 deletions src/util/MetaUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,16 @@ setname(f::JuliaFunction, name) =
end
end

"""For comparing JuliaFunctionSigs modulo the where parameters"""
@struct_hash_equal struct JuliaFunctionSigNoWhere
name::Expr0
types::Vector{Expr0}
end
JuliaFunctionSigNoWhere(f::JuliaFunctionSig) =
JuliaFunctionSigNoWhere(f.name, f.types)

JuliaFunctionSig(f::JuliaFunctionSigNoWhere) = JuliaFunctionSig(f.name, f.types)

# Parsing Julia functions
#########################

Expand Down Expand Up @@ -81,12 +91,12 @@ function parse_function(expr::Expr)::JuliaFunction
Expr(:where, fun_head, whereparams...) => (fun_head, whereparams)
_ => (fun_expr, Expr0[])
end
(call_expr, return_type, impl, doc) = @match fun_expr begin
(call_expr, return_type, impl, doc) = @match fun_head begin
Expr(:(::), Expr(:call, args...), return_type) =>
(Expr(:call, args...), return_type, impl, doc)
Expr(:call, args...) =>
(Expr(:call, args...), nothing, impl, doc)
_ => throw(ParseError("Ill-formed function header $fun_expr"))
_ => throw(ParseError("Ill-formed function header $fun_head"))
end
(name, args, kwargs) = @match call_expr begin
Expr(:call, name, Expr(:parameters, kwargs...), args...) => (name, args, kwargs)
Expand Down Expand Up @@ -137,14 +147,17 @@ function generate_function(fun::JuliaFunction; rename=n->n)::Expr
[]
end
call_expr = Expr(:call, rename(fun.name), kwargsblock..., fun.args...)
if !isempty(fun.whereparams)
call_expr = Expr(:where, call_expr, fun.whereparams...)

if !isnothing(fun.return_type)
call_expr = Expr(:(::), call_expr, fun.return_type)
end
head = if isnothing(fun.return_type)

head = if isempty(fun.whereparams)
call_expr
else
Expr(:(::), call_expr, fun.return_type)
else
Expr(:where, call_expr, fun.whereparams...)
end

body = if isnothing(fun.impl)
Expr(:block)
else
Expand Down
10 changes: 4 additions & 6 deletions test/stdlib/models/Arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,26 @@ end
#--------------------------------------
using .ThMonoid

IM = IntMonoid(IntNatPlus())
@withmodel IM (e) begin
@withmodel IntMonoid (e) begin
@test e() == 0
@test (ThMonoid.:()[IM])(3, 4) == 7
@test (ThMonoid.:()[IntMonoid])(3, 4) == 7
end

# Integers as preorder
#---------------------
using .ThPreorder

@withmodel IntPreorder() (Leq, refl, trans) begin
@test trans((1,3), (3,5)) == (1,5)
@test trans((1,3), (3,5)) == (1,5)
@test_throws TypeCheckFail Leq((5,3), 5, 3)
@test refl(2) == (2,2)
end

# Now using category interface

using .ThCategory
M = IntPreorderCat(IntPreorder())

@withmodel M (Hom, id, compose) begin
@withmodel IntPreorderCat (Hom, id, compose) begin
@test compose((1,3), (3,5)) == (1,5)
@test_throws TypeCheckFail Hom((5,3), 5, 3)
@test_throws ErrorException compose((1,2), (3,5))
Expand Down
3 changes: 1 addition & 2 deletions test/stdlib/models/Op.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ end
# Theory-morphism Op
#-------------------

M = OpFinSetC(FinSetC())
@withmodel M (Ob, Hom, id, compose, dom, codom) begin
@withmodel OpFinSetC (Ob, Hom, id, compose, dom, codom) begin
@test Ob(0) == 0
@test_throws TypeCheckFail Ob(-1)
@test_throws TypeCheckFail Hom([1,5,2], 4, 3)
Expand Down

0 comments on commit 96ddfd4

Please sign in to comment.