Skip to content

Commit

Permalink
working pushouts
Browse files Browse the repository at this point in the history
  • Loading branch information
Kris Brown committed Oct 20, 2023
1 parent 5bbd879 commit d2cbdbd
Show file tree
Hide file tree
Showing 29 changed files with 643 additions and 112 deletions.
1 change: 1 addition & 0 deletions src/GATlab.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ include("util/module.jl")
include("syntax/module.jl")
include("models/module.jl")
include("stdlib/module.jl")
include("nonstdlib/module.jl") # don't reexport this

@reexport using .Util
@reexport using .Syntax
Expand Down
104 changes: 82 additions & 22 deletions src/models/ModelInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,18 @@ macro instance(head, model, body)
theory = macroexpand(__module__, :($theory_module.Meta.@theory))

# A dictionary to look up the Julia type of a type constructor from its name (an ident)
jltype_by_sort = Dict(zip(sorts(theory), instance_types)) # for type checking
i = 0
jltype_by_sort = Dict(map(sorts(theory)) do s
sorttype = getvalue(theory[methodof(s)])
s => if sorttype isa AlgTypeConstructor
i += 1
instance_types[i]
elseif sorttype isa AlgStruct
nameof(sorttype.declaration)
end
end)
i == length(instance_types) || error("Did not use all types ($i): $instance_types")


# Get the model type that we are overloading for, or nothing if this is the
# default instance for `instance_types`
Expand Down Expand Up @@ -176,7 +187,7 @@ function generate_instance(
# methods for type constructors and type argument accessors if these methods
# are missing
typechecked_functions = if typecheck
typecheck_instance(theory, functions, ext_functions, jltype_by_sort; oldinstance)
typecheck_instance(theory, functions, ext_functions, jltype_by_sort; oldinstance, theory_module)
else
[functions..., ext_functions...] # skip typechecking and expand_fail
end
Expand All @@ -187,15 +198,16 @@ function generate_instance(
# to
# `ThCategory.Ob(m::WithModel{M}, x; context=nothing) = let model = m.model in blah end`
qualified_functions =
map(fun -> qualify_function(fun, theory_module, model_type, whereparams), typechecked_functions)
map(fun -> qualify_function(fun, theory_module, model_type, whereparams,
Set(nameof.(structs(theory)))),
typechecked_functions)

append!(
qualified_functions,
make_alias_definitions(theory, theory_module, jltype_by_sort, model_type, whereparams, ext_functions)
make_alias_definitions(theory, theory_module, jltype_by_sort, model_type,
whereparams, ext_functions)
)

# Add overloads for the alias methods

# Declare that this model implements the theory

implements_declarations = if !isnothing(model_type)
Expand Down Expand Up @@ -302,7 +314,7 @@ function julia_signature(
args = if oldinstance && isempty(sortsig)
Expr0[Expr(:curly, :Type, jltype_by_sort[AlgSort(termcon.type)])]
else
Expr0[jltype_by_sort[sort] for sort in sortsignature(termcon)]
Expr0[jltype_by_sort[sort] for sort in sortsig if !sort.eq]
end
JuliaFunctionSig(
nameof(getdecl(termcon)),
Expand Down Expand Up @@ -332,6 +344,16 @@ function julia_signature(
JuliaFunctionSig(nameof(getdecl(acc)), [jlargtype])
end


function julia_signature(str::AlgFunction, jltype_by_sort::Dict{AlgSort}; kw...)
sortsig = sortsignature(str)
args = Expr0[jltype_by_sort[sort] for sort in sortsig]
JuliaFunctionSig(
nameof(getdecl(str)),
args
)
end

function ExprInterop.toexpr(sig::JuliaFunctionSig)
Expr(:call, sig.name, [Expr(:(::), type) for type in sig.types]...)
end
Expand Down Expand Up @@ -364,6 +386,7 @@ function typecheck_instance(
ext_functions::Vector{Symbol},
jltype_by_sort::Dict{AlgSort};
oldinstance=false,
theory_module=nothing,
)::Vector{JuliaFunction}
typechecked = JuliaFunction[]

Expand All @@ -378,14 +401,24 @@ function typecheck_instance(
continue
end
for (_, x) in allmethods(resolver)
if getvalue(theory[x]) isa AlgStruct
continue
end
sig = julia_signature(getvalue(theory[x]), jltype_by_sort; oldinstance, X=x) |> JuliaFunctionSigNoWhere
if haskey(undefined_signatures, sig)
error(overload_errormsg)
error(overload_errormsg * ": $x vs $(undefined_signatures[sig])")
end
undefined_signatures[sig] = (decl, x)
end
end

for x in getidents(theory)
v = getvalue(theory[x])
if v isa AlgFunction
push!(typechecked, mk_fun(v, theory, theory_module, jltype_by_sort))
end
end

expected_signatures = DefaultDict{Ident, Set{Expr0}}(()->Set{Expr0}())

for (sig, (decl, _)) in undefined_signatures
Expand Down Expand Up @@ -432,7 +465,7 @@ function typecheck_instance(
for (sig, (decl, method)) in undefined_signatures
judgment = getvalue(theory[method])
if judgment isa AlgTermConstructor
error("Failed to implement $(toexpr(sig))")
error("Failed to implement $decl: $(toexpr(sig))")
elseif judgment isa AlgTypeConstructor
push!(typechecked, default_typecon_impl(method, theory, jltype_by_sort))
elseif judgment isa AlgAccessor
Expand Down Expand Up @@ -464,14 +497,23 @@ function expand_fail(theory::GAT, x::Ident, f::JuliaFunction)
)
end

function mk_fun(f::AlgFunction, theory, mod, jltype_by_sort)
name = nameof(f.declaration)
args = map(zip(f.args, sortsignature(f))) do (i,s)
Expr(:(::),nameof(f[i]),jltype_by_sort[s])
end
impl = to_call_impl(f.value,theory, mod, false)
JuliaFunction(;name=name, args, impl)
end

function make_alias_definitions(theory, theory_module, jltype_by_sort, model_type, whereparams, ext_functions)
lines = []
oldinstance = isnothing(model_type)
for segment in theory.segments.scopes
for binding in segment
alias = getvalue(binding)
name = nameof(binding)
if alias isa Alias && name ext_functions
if alias isa Alias && nameof(alias.ref) ext_functions
for (argsorts, method) in allmethods(theory.resolvers[alias.ref])
args = [(gensym(), jltype_by_sort[sort]) for sort in argsorts]
args = if oldinstance
Expand Down Expand Up @@ -504,9 +546,10 @@ end
"""
Add `WithModel` param first, if this is not an old instance (it shouldn't have it already)
Qualify method name to be in theory module
Qualify args to struct types
Add `context` kwargs if not already present
"""
function qualify_function(fun::JuliaFunction, theory_module, model_type::Union{Expr0, Nothing}, whereparams)
function qualify_function(fun::JuliaFunction, theory_module, model_type::Union{Expr0, Nothing}, whereparams, structnames)
kwargs = filter(fun.kwargs) do kwarg
@match kwarg begin
Expr(:kw, :context, _) => false
Expand All @@ -517,11 +560,18 @@ function qualify_function(fun::JuliaFunction, theory_module, model_type::Union{E
end
end
kwargs = Expr0[Expr(:kw, :context, nothing); kwargs]

(args, impl) = if !isnothing(model_type)
args = map(fun.args) do arg
@match arg begin
Expr(:(::), argname, ty) => Expr(:(::), argname,
ty structnames ? Expr(:., theory_module, QuoteNode(ty)) : ty )
_ => arg
end
end

m = gensym(:m)
(
[Expr(:(::), m, Expr(:curly, TheoryInterface.WithModel, model_type)), fun.args...],
[Expr(:(::), m, Expr(:curly, TheoryInterface.WithModel, model_type)), args...],
Expr(:let, Expr(:(=), :model, :($m.model)), fun.impl)
)
else
Expand All @@ -542,9 +592,12 @@ end
function implements_declaration(model_type, scope, whereparams)
notes = ImplementationNotes(nothing)
quote
$(GlobalRef(ModelInterface, :implements))(
::$(model_type), ::Type{Val{$(gettag(scope))}}
) where {$(whereparams...)} = $notes
if !hasmethod($(GlobalRef(ModelInterface, :implements)),
($(model_type) where {$(whereparams...)}, Type{Val{$(gettag(scope))}}))
$(GlobalRef(ModelInterface, :implements))(
::$(model_type), ::Type{Val{$(gettag(scope))}}
) where {$(whereparams...)} = $notes
end
end
end

Expand Down Expand Up @@ -643,7 +696,7 @@ function migrator(tmap, dom_module, codom_module, dom_theory, codom_theory)

return_type = first(sig.types)

impls = to_call_impl.(codom_body.args, Ref(codom_module))
impls = to_call_impl.(codom_body.args, Ref(codom_theory), Ref(codom_module), true)
impl = Expr(:call, Expr(:ref, :($codom_module.$fxname),
:(model.model)), _x, impls...)

Expand All @@ -659,7 +712,7 @@ function migrator(tmap, dom_module, codom_module, dom_theory, codom_theory)
name = nameof(termcon.declaration)
return_type = jltype_by_sort[AlgSort(termcon.type)]
args = [:($k::$v) for (k, v) in zip(nameof.(argsof(termcon)), sig.types)]
impl = to_call_impl(fx.val, codom_module)
impl = to_call_impl(fx.val, codom_theory, codom_module, true)

JuliaFunction(;name, args, return_type, impl)
end
Expand Down Expand Up @@ -705,14 +758,21 @@ end
Compile an AlgTerm into a Julia call Expr where termcons (e.g. `f`) are
interpreted as `mod.f[model.model](...)`.
"""
function to_call_impl(t::AlgTerm, mod::Module)
function to_call_impl(t::AlgTerm, theory::GAT, mod::Union{Symbol,Module}, migrate::Bool)
b = bodyof(t)
if GATs.isvariable(t)
nameof(b)
else
args = to_call_impl.(argsof(b), Ref(mod))
elseif GATs.isdot(t)
Expr(:., to_call_impl(b.body, theory, mod, migrate), QuoteNode(b.head))
else
args = to_call_impl.(argsof(b), Ref(theory), Ref(mod), migrate)
name = nameof(headof(b))
Expr(:call, Expr(:ref, :($mod.$name), :(model.model)), args...)
newhead = if name nameof.(structs(theory))
Expr(:., :($mod), QuoteNode(name))

Check warning on line 771 in src/models/ModelInterface.jl

View check run for this annotation

Codecov / codecov/patch

src/models/ModelInterface.jl#L771

Added line #L771 was not covered by tests
else
Expr(:ref, :($mod.$name), migrate ? :(model.model) : :model)
end
Expr(:call, newhead, args...)
end
end

Expand Down
41 changes: 41 additions & 0 deletions src/nonstdlib/models/Pushouts.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
using DataStructures, StructEquality

export PushoutInt

using GATlab

"""Data required to specify a pushout. No fancy caching."""
@struct_hash_equal struct PushoutInt
ob::Int
i1::Vector{Int}
i2::Vector{Int}
end

using .ThPushout

@instance ThPushout{Int, Vector{Int}, PushoutInt} [model::FinSetC] begin
@import Ob, Hom, id, compose, dom, codom
function pushout(sp::Span; context)

Check warning on line 18 in src/nonstdlib/models/Pushouts.jl

View check run for this annotation

Codecov / codecov/patch

src/nonstdlib/models/Pushouts.jl#L18

Added line #L18 was not covered by tests
B, C = context[:d], context[:c]
d = DataStructures.IntDisjointSets(B+C)
[union!(d, sp.left[a], sp.right[a]+B) for a in 1:sp.apex]
roots = DataStructures.find_root!.(Ref(d), 1:length(d))
rootindex = sort(collect(Set(values(roots))))
toindex = [findfirst(==(r),rootindex) for r in roots]
PushoutInt(DataStructures.num_groups(d),
[toindex[roots[b]] for b in 1:B],
[toindex[roots[c+B]] for c in 1:C]
)
end
cospan(p::PushoutInt) = Cospan(p.ob, p.i1, p.i2)
function universal(p::PushoutInt, csp::Cospan; context)

Check warning on line 31 in src/nonstdlib/models/Pushouts.jl

View check run for this annotation

Codecov / codecov/patch

src/nonstdlib/models/Pushouts.jl#L31

Added line #L31 was not covered by tests
map(1:p.ob) do i
for (proj, csp_map) in [(p.i1, csp.left), (p.i2, csp.right)]
for (j, i′) in enumerate(proj)
if i′ == i return csp_map[j] end
end
end
error("Pushout is jointly surjective")

Check warning on line 38 in src/nonstdlib/models/Pushouts.jl

View check run for this annotation

Codecov / codecov/patch

src/nonstdlib/models/Pushouts.jl#L38

Added line #L38 was not covered by tests
end
end
end
10 changes: 10 additions & 0 deletions src/nonstdlib/models/module.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module NonStdModels

using ...Syntax
using ...Models
using ...Stdlib
using ..NonStdTheories

include("Pushouts.jl")

end
15 changes: 15 additions & 0 deletions src/nonstdlib/module.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
module NonStdlib

using Reexport

include("theories/module.jl")
include("models/module.jl")
# include("theorymaps/module.jl")
# include("derivedmodels/module.jl")

@reexport using .NonStdTheories
@reexport using .NonStdModels
# @reexport using .StdTheoryMaps
# @reexport using .StdDerivedModels

end
52 changes: 52 additions & 0 deletions src/nonstdlib/theories/Pushouts.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""
A theory of a category with pushouts highlights two features of
GATs which require syntactic sugar. The first is an operation
with a (dependent) record type as its output.
The second is the type of equality witnesses for any given
AlgSort. This allows us to define functions which are only
valid when certain *derived terms* from the arguments are
equal. (E.g. when to apply universal property of a pushout).
"""

export ThPushout

@theory ThSpanCospan <: ThCategory begin
struct Span(dom::Ob, codom::Ob)
apex::Ob
left::(apex→dom)
right::(apex→codom)
end

struct Cospan(dom::Ob, codom::Ob)
apex::Ob
left::(dom→apex)
right::(codom→apex)
end
end

"""A category with pushouts"""
@theory ThPushout <: ThSpanCospan begin
Pushout(s)::TYPE [(d,c)::Ob, s::Span(d,c)]
pushout(s)::Pushout(s) [(d,c)::Ob, s::Span(d,c)] # compute representative
cospan(p::Pushout(s))::Cospan(d,c) [(d,c)::Ob, s::Span(d,c)] # extract result
apex(p::Pushout(s)) := cospan(p).apex [(d,c)::Ob, s::Span(d,c)]
ι₁(p::Pushout(s)) := cospan(p).left [(d,c)::Ob, s::Span(d,c)]
ι₂(p::Pushout(s)) := cospan(p).right [(d,c)::Ob, s::Span(d,c)]

(pushout_commutes := (s.left)ι₁(p) == (s.right)ι₂(p)
[(d,c)::Ob, s::Span(d,c), p::Pushout(s)])

(universal(p, csp, eq) :: (apex(p) csp.apex)
[(d,c)::Ob, sp::Span(d,c), csp::Cospan(d,c), p::Pushout(sp),
eq::(sp.left⋅csp.left == sp.right⋅csp.right)])

((ι₁(p) universal(p, csp, eq) == csp.left)
[(d,c)::Ob, sp::Span(d,c), csp::Cospan(d,c), p::Pushout(sp),
eq::(sp.left⋅csp.left == sp.right⋅csp.right)])

((ι₂(p) universal(p, csp, eq) == csp.right)
[(d,c)::Ob, sp::Span(d,c), csp::Cospan(d,c), p::Pushout(sp),
eq::(sp.left⋅csp.left == sp.right⋅csp.right)])
end

10 changes: 10 additions & 0 deletions src/nonstdlib/theories/module.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module NonStdTheories

using ...Syntax
using ...Stdlib

using Reexport

include("Pushouts.jl")

end
2 changes: 1 addition & 1 deletion src/stdlib/derivedmodels/DerivedModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ IntMonoid = migrate(NatPlusMonoid, IntNatPlus())
# Interpret `id` as reflexivity and `compose` as transitivity.
IntPreorderCat = migrate(PreorderCat, IntPreorder())

end
end # module
Loading

0 comments on commit d2cbdbd

Please sign in to comment.