Skip to content

Commit

Permalink
generalize typemap to be valued in TypeInCtx
Browse files Browse the repository at this point in the history
remove reference
  • Loading branch information
Kris Brown authored and olynch committed Sep 21, 2023
1 parent d06ab56 commit 06be9f0
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 67 deletions.
11 changes: 10 additions & 1 deletion src/stdlib/theorymaps/Maps.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module Maps

export SwapMonoid, NatPlusMonoid, PreorderCat
export SwapMonoid, NatPlusMonoid, PreorderCat, OpCat

using ...StdTheories
using ....Syntax
Expand All @@ -19,6 +19,15 @@ NatPlusMonoid = @theorymap ThMonoid => ThNatPlus begin
(x y) [x, y] => x+y [(x, y)::ℕ]
end


OpCat = @theorymap ThCategory => ThCategory begin
Ob => Ob
Hom => Hom(codom,dom) [dom::Ob, codom::Ob]
compose(f, g) [a::Ob, b::Ob, c::Ob, f::(a → b), g::(b → c)] =>
compose(g, f) [a::Ob, b::Ob, c::Ob, f::(b → a), g::(c → b)]
id(a) [a::Ob] => id(a) [a::Ob]
end

"""Preorders are categories"""
PreorderCat = @theorymap ThCategory => ThPreorder begin
Ob => default
Expand Down
61 changes: 42 additions & 19 deletions src/syntax/GATs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ export Constant, AlgTerm, AlgType,
AlgTermConstructor, AlgTypeConstructor, AlgAxiom, sortsignature,
JudgmentBinding, GATSegment, GAT, sortcheck, allnames, sorts, sortname,
termcons, typecons, accessors, equations, build_infer_expr, compile,
TermInCtx, headof, argsof, argcontext
InCtx, TermInCtx, TypeInCtx, headof, argsof, argcontext

using ..Scopes
using ..ExprInterop
Expand Down Expand Up @@ -173,17 +173,19 @@ permitted.
const SortScope = Scope{AlgSort, Nothing}

"""
A term with an accompanying type scope, e.g.
A type or term with an accompanying type scope, e.g.
(a,b)::R
-----------
a*(a+b)
(a,b)::R (a,b)::Ob
----------- or ----------
a*(a+b) Hom(a,b)
"""
@struct_hash_equal struct TermInCtx
@struct_hash_equal struct InCtx{T<:TrmTyp}
ctx::TypeScope
trm::AlgTerm
trm::T
end

const TermInCtx = InCtx{AlgTerm}
const TypeInCtx = InCtx{AlgType}

"""
`sortcheck(ctx::Context, t::AlgTerm)`
Expand Down Expand Up @@ -499,14 +501,29 @@ function compile(expr_lookup::Dict{Ident}, term::AlgTerm; theorymodule=nothing)
end
end

"""Get the canonical term associated with a term constructor"""
function TermInCtx(g::GAT, k::Ident)
InCtx(g::GAT, k::Ident) =
(getvalue(g[k]) isa AlgTermConstructor ? TermInCtx : TypeInCtx)(g, k)

"""
Get the canonical term + ctx associated with a term constructor.
"""
function InCtx{AlgTerm}(g::GAT, k::Ident)
tcon = getvalue(g[k])
lc = argcontext(tcon)
ids = reverse(reverse(idents(lc))[1:(length(tcon.args))])
TermInCtx(lc, AlgTerm(k, AlgTerm.(ids)))
end

"""
Get the canonical type + ctx associated with a type constructor.
"""
function InCtx{AlgType}(g::GAT, k::Ident)
tcon = getvalue(g[k])
lc = argcontext(tcon)
TypeInCtx(lc, AlgType(k, AlgTerm.(idents(lc))))
end


"""
Infer the type of the term of a term. If it is not in context, recurse on its
arguments. The term constructor's output type yields the resulting type once
Expand All @@ -533,6 +550,7 @@ function infer_type(theory::GAT, t::TermInCtx)
if hasident(t.ctx, head)
getvalue(t.ctx[head]) # base case
else
#println("Inferring type of $t w/ head $head")
tc = getvalue(theory[head])
eqs = equations(theory, head)
typed_terms = Dict{Ident, Pair{AlgTerm,AlgType}}()
Expand Down Expand Up @@ -560,14 +578,16 @@ Take a term constructor and determine terms of its local context.
This function is mutually recursive with `infer_type`.
"""
function bind_localctx(theory::GAT, t::TermInCtx)
function bind_localctx(theory::GAT, t::InCtx{T}) where T
head = headof(t.trm)

tc = getvalue(theory[head])
eqs = equations(theory, head)

typed_terms = Dict{Ident, Pair{AlgTerm,AlgType}}()
#println("BINDING LOCAL CONTEXT $t")
for (i,a) in zip(tc.args, t.trm.args)
#println("tc.arg $i => t.trm arg $a")
tt = (a => infer_type(theory, TermInCtx(t.ctx, a)))
typed_terms[ident(tc.args, name=nameof(i))] = tt
end
Expand All @@ -585,13 +605,13 @@ function bind_localctx(theory::GAT, t::TermInCtx)
end

""" Replace idents with AlgTerms. """
function substitute_term(t::TrmTyp, dic::Dict{Ident,AlgTerm})
function substitute_term(t::T, dic::Dict{Ident,AlgTerm}) where T<:TrmTyp
iden = headof(t)
if haskey(dic, iden)
isempty(t.args) || error("Cannot substitute a term with arguments")
dic[iden]
else
AlgTerm(headof(t), substitute_term.(argsof(t), Ref(dic)))
T(headof(t), substitute_term.(argsof(t), Ref(dic)))
end
end

Expand Down Expand Up @@ -716,23 +736,26 @@ function parseaxiom(c::Context, localcontext, type_expr, e; name=nothing)
end
end

function ExprInterop.fromexpr(c::Context, e, ::Type{TermInCtx})
function ExprInterop.fromexpr(c::Context, e, ::Type{InCtx{T}}) where T
(binding, localcontext) = @match normalize_decl(e) begin
Expr(:call, :(), binding, Expr(:vect, args...)) => (binding, parsetypescope(c, args))
e => (e, TypeScope())
end
c′ = AppendScope(c, localcontext)
bound = Dict([nameof(b) => getvalue(b) for b in getbindings(localcontext)])
t = ExprInterop.fromexpr(c′, binding, AlgTerm; bound=bound)
TermInCtx(localcontext, t)
t = ExprInterop.fromexpr(c′, binding, T; bound=bound)
InCtx{T}(localcontext, t)
end

ExprInterop.toexpr(c::Context, tic::TermInCtx) = let c′=AppendScope(c,tic.ctx);
Expr(:call, :(), ExprInterop.toexpr(c′, tic.trm), ExprInterop.toexpr(c′, tic.ctx))
function ExprInterop.toexpr(c::Context, tic::InCtx; kw...)
c′=AppendScope(c,tic.ctx)
etrm = ExprInterop.toexpr(c′, tic.trm; kw...)
ectx = ExprInterop.toexpr(c′, tic.ctx; kw...)
Expr(:call, :(), etrm, ectx)
end

ExprInterop.toexpr(c::Context, ts::TypeScope) =
Expr(:vect,[Expr(:(::), nameof(b), toexpr(c, getvalue(b))) for b in ts]...)
ExprInterop.toexpr(c::Context, ts::TypeScope; kw...) =
Expr(:vect,[Expr(:(::), nameof(b), toexpr(c, getvalue(b); kw...)) for b in ts]...)

ExprInterop.toexpr(c::Context, at::Binding{AlgType, Nothing}) =
Expr(:(::), nameof(at), ExprInterop.toexpr(c, getvalue(at)))
Expand Down
21 changes: 18 additions & 3 deletions src/syntax/Scopes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,23 @@ disambiguated by a signature in `Sig` in the case of overloading.
end
end

function Base.show(io::IO, b::Binding)
print(io, isnothing(nameof(b)) ? "_" : nameof(b))
"""Type for printing out bindings with colored keys"""
@struct_hash_equal struct ScopedBinding
scope::ScopeTag
binding::Binding
end

Base.show(io::IO, b::ScopedBinding) =
Base.show(io, b.binding; crayon=tagcrayon(b.scope))

function Base.show(io::IO, b::Binding; crayon=nothing)
bname = isnothing(nameof(b)) ? "_" : nameof(b)
if isnothing(crayon) || !(get(io, :color, true))
print(io, bname)
else
print(io, crayon, bname)
print(io, inv(crayon))
end
if length(getaliases(b)) > 1
print(io, "(aliases=")
join(io, filter(x -> x != nameof(b), getaliases(b)), ", ")
Expand Down Expand Up @@ -407,7 +422,7 @@ end
function Base.show(io::IO, x::Scope)
print(io, "{")
for (i, b) in enumerate(x.bindings)
print(io, b)
print(io, ScopedBinding(gettag(x), b))
if i < length(x.bindings)
print(io, ", ")
end
Expand Down
94 changes: 52 additions & 42 deletions src/syntax/TheoryMaps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ export IdTheoryMap, TheoryIncl, AbsTheoryMap, TheoryMap, @theorymap,

using ..GATs, ..Scopes, ..ExprInterop
using ..Scopes: unsafe_pushbinding!
using ..GATs: bindingexprs, bind_localctx, substitute_term
using ..GATs: InCtx, TrmTyp, bindingexprs, bind_localctx, substitute_term

import ..ExprInterop: toexpr, fromexpr

Expand All @@ -23,13 +23,19 @@ thought of in many ways.
The main methods for an AbsTheoryMap to implement are:
- dom, codom,
- typemap: A dictionary of Ident (of AlgTypeConstructor in domain) to AlgSort
(This must be a AlgSort of the same arity.)
- typemap: A dictionary of Ident (of AlgTypeConstructor in domain) to
TypeInCtx (The TypeScope of the TypeInCtx must be structurally
identical to the localcontext of the type constructor
associated with the key).
- termmap: A dictionary of Ident (of AlgTermConstructor in domain) to
TermInCtx. (The TypeScope of the TrmInCtx must be structurally
identical to the localcontext + args of the term constructor
associated with the key.)
The requirement that the values of `typemap` and `termmap` be structurally
identical to the contexts in the domain can eventually be relaxed (to allow
reordering, additional derived elements, dropping unused arguments), but for now
we require this for simplicity.
"""
abstract type AbsTheoryMap end

Expand All @@ -40,7 +46,7 @@ dom(f::AbsTheoryMap)::GAT = f.dom # assume this exists by default
codom(f::AbsTheoryMap)::GAT = f.codom # assume this exists by default

function compose(f::AbsTheoryMap, g::AbsTheoryMap)
typmap = Dict(k => typemap(g)[v.ref] for (k,v) in pairs(typemap(f)))
typmap = Dict(k => g(v) for (k, v) in pairs(typemap(f)))
trmmap = Dict(k => g(v) for (k, v) in pairs(termmap(f)))
TheoryMap(dom(f), codom(g), typmap, trmmap)
end
Expand All @@ -60,26 +66,24 @@ end

"""Map a context in the domain theory into a context of the codomain theory"""
function (f::AbsTheoryMap)(ctx::TypeScope)
scope = TypeScope()
cache = Dict{Symbol, AlgTerm}()
for b in ctx
argnames = nameof.(headof.(b.value.args))
val = AlgType(f(b.value.head).ref, AlgTerm[cache[a] for a in argnames])
fctx = TypeScope()
for i in 1:length(ctx)
b = ctx[LID(i)]
partial_scope = Scope(getbindings(ctx)[1:i-1]; tag=gettag(ctx))
val = f(partial_scope, getvalue(b), fctx)
new_binding = Binding{AlgType, Nothing}(b.primary, b.aliases, val, b.sig)
cache[nameof(b)] = AlgTerm(Ident(gettag(scope), LID(length(scope)+1),
nameof(new_binding)))
unsafe_pushbinding!(scope, new_binding)
unsafe_pushbinding!(fctx, new_binding)
end
scope
fctx
end

function (f::AbsTheoryMap)(t::TermInCtx)
function (f::AbsTheoryMap)(t::InCtx{T}) where T
fctx = f(t.ctx)
TermInCtx(fctx, f(t.ctx, t.trm, fctx))
InCtx{T}(fctx, f(t.ctx, t.trm, fctx))
end

""" Map a term `t` in context `c` along `f`. """
function (f::AbsTheoryMap)(ctx::TypeScope, t::AlgTerm, fctx=nothing)::AlgTerm
""" Map a term (or type) `t` in context `c` along `f`. """
function (f::AbsTheoryMap)(ctx::TypeScope, t::T, fctx=nothing) where {T<:TrmTyp}
fctx = isnothing(fctx) ? f(ctx) : fctx
head = headof(t)
if hasident(ctx, head)
Expand All @@ -93,19 +97,23 @@ function (f::AbsTheoryMap)(ctx::TypeScope, t::AlgTerm, fctx=nothing)::AlgTerm
for x in [termcon.args, termcon.localcontext])

# new_term has same context as termcon, so recursively map over components
lc = bind_localctx(f.dom, TermInCtx(ctx,t))
flc = Dict(retag(rt_dict, k) => f(ctx, v, fctx) for (k, v) in pairs(lc))

lc = bind_localctx(f.dom, InCtx{T}(ctx, t))
flc = Dict{Ident, AlgTerm}(map(collect(pairs(lc))) do (k, v)
if hasident(termcon.args, k) # offset when squashing localcontext and args
k = Ident(gettag(k), LID(getlid(k).val+length(termcon.localcontext)), nameof(k))
end
retag(rt_dict, k) => f(ctx, v, fctx)
end)
substitute_term(new_term.trm, flc)
end
end

function toexpr(m::AbsTheoryMap)
typs = map(collect(typemap(m))) do (k, v)
Expr(:call, :(=>), toexpr(dom(m), k), toexpr(codom(m), v))
end
Expr(:call, :(=>), toexpr(dom(m), k), toexpr(codom(m), v))
end
trms = map(collect(termmap(m))) do (k,v)
domterm = toexpr(dom(m), TermInCtx(dom(m), k))
domterm = toexpr(dom(m), InCtx(dom(m), k))
Expr(:call, :(=>), domterm, toexpr(codom(m), v))
end
Expr(:block, typs...,trms...)
Expand Down Expand Up @@ -146,7 +154,7 @@ A theory inclusion has a subset of scopes
end

typemap::Union{IdTheoryMap,TheoryIncl}) =
Dict(k => AlgSort(k) for k in typecons(dom(ι)))
Dict(k => TypeInCtx(dom(ι), k) for k in typecons(dom(ι)))

termmap::Union{IdTheoryMap,TheoryIncl}) =
Dict(k=>TermInCtx(dom(ι), k) for k in termcons(dom(ι)))
Expand All @@ -170,28 +178,27 @@ TODO: check that it is well-formed, axioms are preserved.
@struct_hash_equal struct TheoryMap <: AbsTheoryMap
dom::GAT
codom::GAT
typemap::Dict{Ident,AlgSort}
typemap::Dict{Ident,TypeInCtx}
termmap::Dict{Ident,TermInCtx}
function TheoryMap(dom, codom, typmap, trmmap)
missing_types = setdiff(Set(keys(typmap)), Set(typecons(dom)))
missing_terms = setdiff(Set(keys(trmmap)), Set(termcons(dom)))
isempty(missing_types) || error("Missing types $missing_types")
isempty(missing_terms) || error("Missing types $missing_terms")
f = new(dom, codom, typmap, trmmap)

# Check type constructors are coherent
for (k, v) in pairs(typmap)
f_args = f(argcontext(getvalue(dom[k])))
arg_fs = argcontext(getvalue(codom[v.ref]))
err = "Bad type map $k => $v ($f_args != $arg_fs)"
Scopes.equiv(f_args, arg_fs) || error(err)

tymap′, trmap′ = map([typmap, trmmap]) do tmap
Dict(k => v isa Ident ? InCtx(codom, v) : v for (k,v) in pairs(tmap))
end
# Check term constructors are coherent
for (k, v) in pairs(trmmap)
f_args = f(argcontext(getvalue(dom[k])))
arg_fs = v.ctx
err = "Bad term map $k => $v ($f_args != $arg_fs)"
Scopes.equiv(f_args, arg_fs) || error(err)

f = new(dom, codom, tymap′, trmap′)
# Check type/term constructors are coherent
for (typtrm, tmap) in ["type"=>tymap′, "term"=>trmap′]
for (k, v) in pairs(tmap)
f_args = f(argcontext(getvalue(dom[k])))
arg_fs = v.ctx
err = "Bad $typtrm map $k => $v ($f_args != $arg_fs)"
Scopes.equiv(f_args, arg_fs) || error(err)
end
end

f
Expand All @@ -209,15 +216,18 @@ TODO: we currently ignore LineNumberNodes. TheoryMap data structure could
TODO: handle more ambiguity via type inference
"""
function fromexpr(dom::GAT, codom::GAT, e, ::Type{TheoryMap})
tyms, trms = Dict{Ident, AlgSort}(), Dict{Ident, TermInCtx}()
tyms = Dict{Ident, Union{Ident, TypeInCtx}}()
trms = Dict{Ident, Union{Ident, TermInCtx}}()
exprs = @match e begin
Expr(:block, e1::Expr, es...) => [e1,es...]
Expr(:block, ::LineNumberNode, es...) => filter(x->!(x isa LineNumberNode), es)
end
for expr in exprs
e1, e2 = @match expr begin Expr(:call, :(=>), e1, e2) => (e1,e2) end

if e1 nameof.(typecons(dom))
tyms[fromexpr(dom, e1, Ident)] = fromexpr(codom, e2, AlgSort)
val = e2 isa Symbol ? fromexpr(codom, e2, Ident) : fromexpr(codom, e2, TypeInCtx)
tyms[fromexpr(dom, e1, Ident)] = val
else
val = fromexpr(codom, e2, TermInCtx)
key = fromexpr(dom, e1, TermInCtx)
Expand Down
Loading

0 comments on commit 06be9f0

Please sign in to comment.