Skip to content

Commit

Permalink
Explicit dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
Kris Brown committed Dec 12, 2024
1 parent 3d4bf2c commit 69dc438
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 18 deletions.
145 changes: 133 additions & 12 deletions src/models/ModelInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,16 @@ in the theory.
module ModelInterface

export implements, impl_type, TypeCheckFail, SignatureMismatchError,
@model, @instance, @withmodel, @fail, migrate_model
@model, @instance, @withmodel, @fail, migrate_model, @default_model,
Dispatch

using ...Syntax
using ...Util.MetaUtils
using ...Util.MetaUtils: JuliaFunctionSigNoWhere

import ...Syntax.TheoryMaps: migrator
using ...Syntax.TheoryMaps: dom, codom
using ...Syntax.TheoryInterface: GAT_MODULE_LOOKUP
using ...Syntax.TheoryInterface: GAT_MODULE_LOOKUP, Dispatch

using MLStyle
using DataStructures: DefaultDict, OrderedDict
Expand All @@ -80,15 +81,18 @@ implements(m, ::Type{Val{tag}}) where {tag} = nothing

implements(m, tag::ScopeTag) = implements(m, Val{tag})

impl_type(m, tag::ScopeTag) = impl_type(m, Val{tag})

impl_type(m, mod::Module, name::Symbol) =
impl_type(m, gettag(ident(mod.Meta.theory; name)))


implements(m, theory_module::Module) =
all(!isnothing(implements(m, gettag(scope))) for scope in theory_module.Meta.theory.segments.scopes)

"""
If `m` implements a GAT with a type constructor (identified by ident `id`),
mapped to a Julia type, this function returns that Julia type.
"""
impl_type(m, id::Ident) = impl_type(m, Val{gettag(id)}, Val{getlid(id)})

impl_type(m, mod::Module, name::Symbol) =
impl_type(m, ident(mod.Meta.theory; name))

struct TypeCheckFail <: Exception
model::Any
theory::GAT
Expand Down Expand Up @@ -600,17 +604,17 @@ end
function impl_type_declaration(model_type, whereparams, sort, jltype)
quote
if !hasmethod($(GlobalRef(ModelInterface, :impl_type)),
($(model_type) where {$(whereparams...)}, Type{Val{$(gettag(getdecl(sort)))}}))
($(model_type) where {$(whereparams...)}, Type{Val{$(gettag(getdecl(sort)))}}, Type{Val{$(getlid(getdecl(sort)))}}))
$(GlobalRef(ModelInterface, :impl_type))(
::$(model_type), ::Type{Val{$(gettag(getdecl(sort)))}}
::$(model_type), ::Type{Val{$(gettag(getdecl(sort)))}}, ::Type{Val{$(getlid(getdecl(sort)))}}
) where {$(whereparams...)} = $(jltype)
end
end
end

function implements_declaration(model_type, scope, whereparams)
notes = ImplementationNotes(nothing)
m = only(methods(implements, (Any, Type{Val{1}})))
_, m = methods(implements, (Any, Type{Val{1}})) # first method is for Dispatch
quote
if $m == only(methods($(GlobalRef(ModelInterface, :implements)),
($(model_type) where {$(whereparams...)}, Type{Val{$(gettag(scope))}})))
Expand Down Expand Up @@ -680,7 +684,7 @@ function migrate_model(FM::Module, m::Any, new_model_name::Union{Nothing,Symbol}

# Expressions which evaluate to the correct Julia type
jltype_by_sort = Dict(map(sorts(dom_theory)) do v
v => :(impl_type($m, gettag(getdecl(AlgSort($F($v.method).val)))))
v => :(impl_type($m, getdecl(AlgSort($F($v.method).val))))
end)

_x = gensym("val")
Expand Down Expand Up @@ -803,4 +807,121 @@ function to_call_accessor(t::AlgTerm, x::Symbol, mod::Module)
Expr(:call, Expr(:ref, :($mod.$(nameof(headof(b)))), :(model.model)), rest)
end


# Default models

"""
Create an @instance for a model `M` whose methods are determined by type
dispatch, e.g.:
```
@instance ThCategory{O,H} [model::M] begin
id(x::O) = id(x)
compose(f::H, g::H)::H = compose(f, g)
end
```
Use this with caution! For example, using this with two different models of
the same theory with the same types would cause a conflict.
"""
function default_instance(theorymodule, theory, jltype_by_sort, model)
acc = Iterators.flatten(values.(values(theory.accessors)))

Check warning on line 828 in src/models/ModelInterface.jl

View check run for this annotation

Codecov / codecov/patch

src/models/ModelInterface.jl#L827-L828

Added lines #L827 - L828 were not covered by tests

termcon_funs = map(last.(termcons(theory)) acc) do x
generate_function(use_dispatch_method_impl(x, theory, jltype_by_sort))

Check warning on line 831 in src/models/ModelInterface.jl

View check run for this annotation

Codecov / codecov/patch

src/models/ModelInterface.jl#L830-L831

Added lines #L830 - L831 were not covered by tests
end
generate_instance(

Check warning on line 833 in src/models/ModelInterface.jl

View check run for this annotation

Codecov / codecov/patch

src/models/ModelInterface.jl#L833

Added line #L833 was not covered by tests
theory, theorymodule, jltype_by_sort, model, [],
Expr(:block, termcon_funs...); typecheck=true, escape=true)
end

"""
Create an @instance for a model `M` whose methods are determined by the
implementation of another model, `M2`, e.g.:
```
@instance ThCategory{O,H} [model::M] begin
id(x::O) = id[M2](x)
compose(f::H, g::H)::H = compose[M2](f, g)
end
```
"""
function default_model(theorymodule, theory, jltype_by_sort, model)
acc = Iterators.flatten(values.(values(theory.accessors)))
termcon_funs = map(last.(termcons(theory)) acc) do x
generate_function(use_model_method_impl(x, theory, jltype_by_sort, model))
end
generate_instance(
theory, theorymodule, jltype_by_sort, nothing, [],
Expr(:block, termcon_funs...); typecheck=true, escape=true)
end

macro default_model(head, model)
# Parse the head of @instance to get theory and instance types
(theory_module, instance_types) = @match head begin
:($ThX{$(Ts...)}) => (ThX, Ts)
_ => error("invalid syntax for head of @instance macro: $head")
end

# Get the underlying theory
theory = macroexpand(__module__, :($theory_module.Meta.@theory))

# A dictionary to look up the Julia type of a type constructor from its name (an ident)
jltype_by_sort = Dict{AlgSort,Expr0}([
zip(primitive_sorts(theory), instance_types)...,
[s => nameof(headof(s)) for s in struct_sorts(theory)]...
])

# Get the model type that we are overloading for, or nothing if this is the
# default instance for `instance_types`
m = parse_model_param(model)[1]

# Create the actual instance
default_model(theory_module, theory, jltype_by_sort, m)
end

"""
A canonical implementation that just calls the method with the implementation
of another model, `m`.
"""
function use_model_method_impl(x::Ident, theory::GAT,
jltype_by_sort::Dict{AlgSort}, m::Expr0)
op = getvalue(theory[x])
name = nameof(getdecl(op))
return_type = op isa AlgAccessor ? nothing : jltype_by_sort[AlgSort(op.type)]
args = args_from_sorts(sortsignature(op), jltype_by_sort)
impl = :(return $(name)[$m()]($(args...)))
JuliaFunction(name=name, args=args, return_type=return_type, impl=impl)
end

"""
A canonical implementation that just calls the method with type dispatch.
"""
function use_dispatch_method_impl(x::Ident, theory::GAT,

Check warning on line 900 in src/models/ModelInterface.jl

View check run for this annotation

Codecov / codecov/patch

src/models/ModelInterface.jl#L900

Added line #L900 was not covered by tests
jltype_by_sort::Dict{AlgSort})
op = getvalue(theory[x])
name = nameof(getdecl(op))
return_type = op isa AlgAccessor ? nothing : jltype_by_sort[AlgSort(op.type)]
args = args_from_sorts(sortsignature(op), jltype_by_sort)
impl = :(return $(name)($(args...)))
JuliaFunction(name=name, args=args, return_type=return_type, impl=impl)

Check warning on line 907 in src/models/ModelInterface.jl

View check run for this annotation

Codecov / codecov/patch

src/models/ModelInterface.jl#L902-L907

Added lines #L902 - L907 were not covered by tests
end

# Special model for any theory which uses dispatch
##################################################

"""
Check whether a dispatch model implements a particular scope of a theory.
This could be more rigorous, like actually checking whether certain methods
exist, but for now users will be assuming that the dispatch methods exist when
using a Dispatch model.
"""
function implements(d::Dispatch, ::T) where {X,T <: Type{Val{X}}}
hastag(d.t, X) ? true : nothing
end

function impl_type(d::Dispatch, x::Ident)
d.types[AlgSort(x, only(d.t.resolvers[x])[2])]
end

end # module
18 changes: 14 additions & 4 deletions src/models/SymbolicModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using ...Util
using ...Syntax
import ...Syntax: invoke_term
using ..ModelInterface
using ..ModelInterface: args_from_sorts, default_instance

using Base.Meta: ParseError
using MLStyle
Expand Down Expand Up @@ -262,6 +263,10 @@ macro symbolic_model(decl, theoryname, body)

module_structs = symbolic_structs(theory, abstract_types, __module__)

# Part 1.5: Generate a model
imp = Expr(:import, Expr(:(:), Expr(:., :., :., name), [
Expr(:., x) for x in [name, theoryname, name, nameof.(sorts(theory))...]]...))

# Part 2: Generating internal methods

module_methods = [internal_accessors(theory)...,
Expand All @@ -271,15 +276,20 @@ macro symbolic_model(decl, theoryname, body)
export $(nameof.(sorts(theory))...)
using ..($(nameof(__module__)))
import ..($(nameof(__module__))): $theoryname

$(module_structs...)

$(generate_function.(module_methods)...)

module $(esc(:Meta))
import ..($(name)): $theoryname
$imp
const $(esc(:theory_module)) = $(esc(theoryname))
const $(esc(:theory)) = $(theory)
const $(esc(:theory_type)) = $(esc(theoryname)).Meta.T
# Canonical symbolic model
$(esc(:M)) = Dispatch($(esc(:theory)), [$(esc.(nameof.(theory.sorts))...)])
end

$(module_structs...)
$(generate_function.(module_methods)...)
end)

# Part 3: Generating instance of theory
Expand Down Expand Up @@ -731,4 +741,4 @@ function show_latex_script(io::IO, expr::GATExpr, head::String;
print(io, "}")
end

end
end # module
18 changes: 17 additions & 1 deletion src/syntax/TheoryInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,27 @@ function theory_impl(head, body, __module__)
)
end

"""
The Dispatch type is a model of every theory.
"""
@struct_hash_equal struct Dispatch
t::GAT
types::Dict{AlgSort,Type}
end

Dispatch(t::GAT, v::AbstractVector{<:Type}) =
Dispatch(t, Dict(zip(sorts(t), v)))

# WARNING: if any other package play with indexing methodnames with their own
# structs, then this code could be broken because it assumes we are the only
# ones to use this trick.
function juliadeclaration(name::Symbol)
quote
function $name end
# we expect just one method because of Dispatch type
if isempty(Base.methods(Base.getindex, [typeof($name), Any]))
Base.getindex(f::typeof($name), ::$(GlobalRef(TheoryInterface, :Dispatch))) = f

if Base.isempty(Base.methods(Base.getindex, [typeof($name), Any]))
function Base.getindex(::typeof($name), m::Any)
(args...; context=nothing) -> $name($(GlobalRef(TheoryInterface, :WithModel))(m), args...; context)
end
Expand Down
2 changes: 2 additions & 0 deletions src/syntax/gats/gat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ fancier.
bysignature::Dict{AlgSorts, Ident}
end

Base.iterate(m::MethodResolver, i...) = iterate(m.bysignature, i...)

function MethodResolver()
MethodResolver(Dict{AlgSorts, Ident}())
end
Expand Down
17 changes: 16 additions & 1 deletion test/models/ModelInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ end
@test !implements(FinSetC(), ThNatPlus)

# Todo: get things working where Ob and Hom are the same type (i.e. binding dict not monic)
struct TypedFinSetC
@struct_hash_equal struct TypedFinSetC
ntypes::Int
end

Expand Down Expand Up @@ -214,4 +214,19 @@ end
# this will fail unless WithModel accepts subtypes
@test ThSet.default[MyVect([1,2,3])](1) == 1

# Test default model + dispatch model
#####################################
@test_throws MethodError id(2)

@default_model ThCategory{Int, Vector{Int}} [model::FinSetC]

d = Dispatch(ThCategory.Meta.theory, [Int, Vector{Int}])
@test implements(d, ThCategory)
@test !implements(d, ThNatPlus)
@test impl_type(d, ThCategory, :Ob) == Int
@test impl_type(d, ThCategory, :Hom) == Vector{Int}

@test id(2) == [1,2] == ThCategory.id[d](2)
@test compose([1,2,3], [2,1,3]) == [2,1,3]

end # module
9 changes: 9 additions & 0 deletions test/models/SymbolicModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@ f = FreeCategory.Hom{:generator}([:f], [x, y])
@test ThCategory.id(x) isa HomExpr{:id}
@test ThCategory.compose(ThCategory.id(x), f) == f

M = FreeCategory.Meta.M
@test M isa Dispatch
@test implements(M, ThCategory)
@test !implements(M, ThNatPlus)
@test impl_type(M, ThCategory, :Ob) == FreeCategory.Ob
@test impl_type(M, ThCategory, :Hom) == FreeCategory.Hom
@test ThCategory.id[M](x) isa HomExpr{:id}
@test ThCategory.compose[M](ThCategory.id(x), f) == f

# Monoid
########

Expand Down

0 comments on commit 69dc438

Please sign in to comment.