From a7d785ccf0197581d094541bc4cfcb031aa175a7 Mon Sep 17 00:00:00 2001 From: Kris Brown Date: Sun, 8 Dec 2024 18:41:28 -0800 Subject: [PATCH] Explicit dispatch --- src/models/ModelInterface.jl | 145 +++++++++++++++++++++++++++++++--- src/models/SymbolicModels.jl | 18 ++++- src/syntax/TheoryInterface.jl | 18 ++++- src/syntax/gats/gat.jl | 2 + test/models/ModelInterface.jl | 17 +++- test/models/SymbolicModels.jl | 9 +++ 6 files changed, 191 insertions(+), 18 deletions(-) diff --git a/src/models/ModelInterface.jl b/src/models/ModelInterface.jl index d9df72a1..8d453c57 100644 --- a/src/models/ModelInterface.jl +++ b/src/models/ModelInterface.jl @@ -45,7 +45,8 @@ in the theory. module ModelInterface export implements, impl_type, TypeCheckFail, SignatureMismatchError, - @model, @instance, @withmodel, @fail, migrate_model + @model, @instance, @withmodel, @fail, migrate_model, @default_model, + Dispatch using ...Syntax using ...Util.MetaUtils @@ -53,7 +54,7 @@ using ...Util.MetaUtils: JuliaFunctionSigNoWhere import ...Syntax.TheoryMaps: migrator using ...Syntax.TheoryMaps: dom, codom -using ...Syntax.TheoryInterface: GAT_MODULE_LOOKUP +using ...Syntax.TheoryInterface: GAT_MODULE_LOOKUP, Dispatch using MLStyle using DataStructures: DefaultDict, OrderedDict @@ -80,15 +81,18 @@ implements(m, ::Type{Val{tag}}) where {tag} = nothing implements(m, tag::ScopeTag) = implements(m, Val{tag}) -impl_type(m, tag::ScopeTag) = impl_type(m, Val{tag}) - -impl_type(m, mod::Module, name::Symbol) = - impl_type(m, gettag(ident(mod.Meta.theory; name))) - - implements(m, theory_module::Module) = all(!isnothing(implements(m, gettag(scope))) for scope in theory_module.Meta.theory.segments.scopes) +""" +If `m` implements a GAT with a type constructor (identified by ident `id`), +mapped to a Julia type, this function returns that Julia type. +""" +impl_type(m, id::Ident) = impl_type(m, Val{gettag(id)}, Val{getlid(id)}) + +impl_type(m, mod::Module, name::Symbol) = + impl_type(m, ident(mod.Meta.theory; name)) + struct TypeCheckFail <: Exception model::Any theory::GAT @@ -600,9 +604,9 @@ end function impl_type_declaration(model_type, whereparams, sort, jltype) quote if !hasmethod($(GlobalRef(ModelInterface, :impl_type)), - ($(model_type) where {$(whereparams...)}, Type{Val{$(gettag(getdecl(sort)))}})) + ($(model_type) where {$(whereparams...)}, Type{Val{$(gettag(getdecl(sort)))}}, Type{Val{$(getlid(getdecl(sort)))}})) $(GlobalRef(ModelInterface, :impl_type))( - ::$(model_type), ::Type{Val{$(gettag(getdecl(sort)))}} + ::$(model_type), ::Type{Val{$(gettag(getdecl(sort)))}}, ::Type{Val{$(getlid(getdecl(sort)))}} ) where {$(whereparams...)} = $(jltype) end end @@ -610,7 +614,7 @@ end function implements_declaration(model_type, scope, whereparams) notes = ImplementationNotes(nothing) - m = only(methods(implements, (Any, Type{Val{1}}))) + _, m = methods(implements, (Any, Type{Val{1}})) # first method is for Dispatch quote if $m == only(methods($(GlobalRef(ModelInterface, :implements)), ($(model_type) where {$(whereparams...)}, Type{Val{$(gettag(scope))}}))) @@ -679,7 +683,7 @@ function migrate_model(F::AbsTheoryMap, m::Any, new_model_name::Union{Nothing,Sy # Expressions which evaluate to the correct Julia type jltype_by_sort = Dict(map(sorts(dom_theory)) do v - v => :(impl_type($m, gettag(getdecl(AlgSort($F($v.method).val))))) + v => :(impl_type($m, getdecl(AlgSort($F($v.method).val)))) end) _x = gensym("val") @@ -802,4 +806,121 @@ function to_call_accessor(t::AlgTerm, x::Symbol, mod::Module) Expr(:call, Expr(:ref, :($mod.$(nameof(headof(b)))), :(model.model)), rest) end + +# Default models + +""" +Create an @instance for a model `M` whose methods are determined by type +dispatch, e.g.: + +``` +@instance ThCategory{O,H} [model::M] begin + id(x::O) = id(x) + compose(f::H, g::H)::H = compose(f, g) +end +``` + +Use this with caution! For example, using this with two different models of +the same theory with the same types would cause a conflict. +""" +function default_instance(theorymodule, theory, jltype_by_sort, model) + acc = Iterators.flatten(values.(values(theory.accessors))) + + termcon_funs = map(last.(termcons(theory)) ∪ acc) do x + generate_function(use_dispatch_method_impl(x, theory, jltype_by_sort)) + end + generate_instance( + theory, theorymodule, jltype_by_sort, model, [], + Expr(:block, termcon_funs...); typecheck=true, escape=true) +end + +""" +Create an @instance for a model `M` whose methods are determined by the +implementation of another model, `M2`, e.g.: + +``` +@instance ThCategory{O,H} [model::M] begin + id(x::O) = id[M2](x) + compose(f::H, g::H)::H = compose[M2](f, g) +end +``` +""" +function default_model(theorymodule, theory, jltype_by_sort, model) + acc = Iterators.flatten(values.(values(theory.accessors))) + termcon_funs = map(last.(termcons(theory)) ∪ acc) do x + generate_function(use_model_method_impl(x, theory, jltype_by_sort, model)) + end + generate_instance( + theory, theorymodule, jltype_by_sort, nothing, [], + Expr(:block, termcon_funs...); typecheck=true, escape=true) +end + +macro default_model(head, model) + # Parse the head of @instance to get theory and instance types + (theory_module, instance_types) = @match head begin + :($ThX{$(Ts...)}) => (ThX, Ts) + _ => error("invalid syntax for head of @instance macro: $head") + end + + # Get the underlying theory + theory = macroexpand(__module__, :($theory_module.Meta.@theory)) + + # A dictionary to look up the Julia type of a type constructor from its name (an ident) + jltype_by_sort = Dict{AlgSort,Expr0}([ + zip(primitive_sorts(theory), instance_types)..., + [s => nameof(headof(s)) for s in struct_sorts(theory)]... + ]) + + # Get the model type that we are overloading for, or nothing if this is the + # default instance for `instance_types` + m = parse_model_param(model)[1] + + # Create the actual instance + default_model(theory_module, theory, jltype_by_sort, m) +end + +""" +A canonical implementation that just calls the method with the implementation +of another model, `m`. +""" +function use_model_method_impl(x::Ident, theory::GAT, + jltype_by_sort::Dict{AlgSort}, m::Expr0) + op = getvalue(theory[x]) + name = nameof(getdecl(op)) + return_type = op isa AlgAccessor ? nothing : jltype_by_sort[AlgSort(op.type)] + args = args_from_sorts(sortsignature(op), jltype_by_sort) + impl = :(return $(name)[$m()]($(args...))) + JuliaFunction(name=name, args=args, return_type=return_type, impl=impl) +end + +""" +A canonical implementation that just calls the method with type dispatch. +""" +function use_dispatch_method_impl(x::Ident, theory::GAT, + jltype_by_sort::Dict{AlgSort}) + op = getvalue(theory[x]) + name = nameof(getdecl(op)) + return_type = op isa AlgAccessor ? nothing : jltype_by_sort[AlgSort(op.type)] + args = args_from_sorts(sortsignature(op), jltype_by_sort) + impl = :(return $(name)($(args...))) + JuliaFunction(name=name, args=args, return_type=return_type, impl=impl) +end + +# Special model for any theory which uses dispatch +################################################## + +""" +Check whether a dispatch model implements a particular scope of a theory. +This could be more rigorous, like actually checking whether certain methods +exist, but for now users will be assuming that the dispatch methods exist when +using a Dispatch model. +""" +function implements(d::Dispatch, ::T) where {X,T <: Type{Val{X}}} + hastag(d.t, X) ? true : nothing +end + +function impl_type(d::Dispatch, x::Ident) + d.types[AlgSort(x, only(d.t.resolvers[x])[2])] +end + end # module diff --git a/src/models/SymbolicModels.jl b/src/models/SymbolicModels.jl index 26f33c38..d9a180f8 100644 --- a/src/models/SymbolicModels.jl +++ b/src/models/SymbolicModels.jl @@ -8,6 +8,7 @@ using ...Util using ...Syntax import ...Syntax: invoke_term using ..ModelInterface +using ..ModelInterface: args_from_sorts, default_instance using Base.Meta: ParseError using MLStyle @@ -262,6 +263,10 @@ macro symbolic_model(decl, theoryname, body) module_structs = symbolic_structs(theory, abstract_types, __module__) + # Part 1.5: Generate a model + imp = Expr(:import, Expr(:(:), Expr(:., :., :., name), [ + Expr(:., x) for x in [name, theoryname, name, nameof.(sorts(theory))...]]...)) + # Part 2: Generating internal methods module_methods = [internal_accessors(theory)..., @@ -271,15 +276,20 @@ macro symbolic_model(decl, theoryname, body) export $(nameof.(sorts(theory))...) using ..($(nameof(__module__))) import ..($(nameof(__module__))): $theoryname + + $(module_structs...) + + $(generate_function.(module_methods)...) + module $(esc(:Meta)) - import ..($(name)): $theoryname + $imp const $(esc(:theory_module)) = $(esc(theoryname)) const $(esc(:theory)) = $(theory) const $(esc(:theory_type)) = $(esc(theoryname)).Meta.T + # Canonical symbolic model + $(esc(:M)) = Dispatch($(esc(:theory)), [$(esc.(nameof.(theory.sorts))...)]) end - $(module_structs...) - $(generate_function.(module_methods)...) end) # Part 3: Generating instance of theory @@ -731,4 +741,4 @@ function show_latex_script(io::IO, expr::GATExpr, head::String; print(io, "}") end -end +end # module diff --git a/src/syntax/TheoryInterface.jl b/src/syntax/TheoryInterface.jl index 6ceea774..8fc1d7de 100644 --- a/src/syntax/TheoryInterface.jl +++ b/src/syntax/TheoryInterface.jl @@ -184,11 +184,27 @@ function theory_impl(head, body, __module__) ) end +""" +The Dispatch type is a model of every theory. +""" +@struct_hash_equal struct Dispatch + t::GAT + types::Dict{AlgSort,Type} +end + +Dispatch(t::GAT, v::AbstractVector{<:Type}) = + Dispatch(t, Dict(zip(sorts(t), v))) + +# WARNING: if any other package play with indexing methodnames with their own +# structs, then this code could be broken because it assumes we are the only +# ones to use this trick. function juliadeclaration(name::Symbol) quote function $name end + # we expect just one method because of Dispatch type + if isempty(Base.methods(Base.getindex, [typeof($name), Any])) + Base.getindex(f::typeof($name), ::$(GlobalRef(TheoryInterface, :Dispatch))) = f - if Base.isempty(Base.methods(Base.getindex, [typeof($name), Any])) function Base.getindex(::typeof($name), m::Any) (args...; context=nothing) -> $name($(GlobalRef(TheoryInterface, :WithModel))(m), args...; context) end diff --git a/src/syntax/gats/gat.jl b/src/syntax/gats/gat.jl index ed5eee0b..4dc8ed92 100644 --- a/src/syntax/gats/gat.jl +++ b/src/syntax/gats/gat.jl @@ -36,6 +36,8 @@ fancier. bysignature::Dict{AlgSorts, Ident} end +Base.iterate(m::MethodResolver, i...) = iterate(m.bysignature, i...) + function MethodResolver() MethodResolver(Dict{AlgSorts, Ident}()) end diff --git a/test/models/ModelInterface.jl b/test/models/ModelInterface.jl index e0d70eba..185575eb 100644 --- a/test/models/ModelInterface.jl +++ b/test/models/ModelInterface.jl @@ -56,7 +56,7 @@ end @test !implements(FinSetC(), ThNatPlus) # Todo: get things working where Ob and Hom are the same type (i.e. binding dict not monic) -struct TypedFinSetC +@struct_hash_equal struct TypedFinSetC ntypes::Int end @@ -214,4 +214,19 @@ end # this will fail unless WithModel accepts subtypes @test ThSet.default[MyVect([1,2,3])](1) == 1 +# Test default model + dispatch model +##################################### +@test_throws MethodError id(2) + +@default_model ThCategory{Int, Vector{Int}} [model::FinSetC] + +d = Dispatch(ThCategory.Meta.theory, [Int, Vector{Int}]) +@test implements(d, ThCategory) +@test !implements(d, ThNatPlus) +@test impl_type(d, ThCategory, :Ob) == Int +@test impl_type(d, ThCategory, :Hom) == Vector{Int} + +@test id(2) == [1,2] == ThCategory.id[d](2) +@test compose([1,2,3], [2,1,3]) == [2,1,3] + end # module diff --git a/test/models/SymbolicModels.jl b/test/models/SymbolicModels.jl index 61271e32..b37684cf 100644 --- a/test/models/SymbolicModels.jl +++ b/test/models/SymbolicModels.jl @@ -19,6 +19,15 @@ f = FreeCategory.Hom{:generator}([:f], [x, y]) @test ThCategory.id(x) isa HomExpr{:id} @test ThCategory.compose(ThCategory.id(x), f) == f +M = FreeCategory.Meta.M +@test M isa Dispatch +@test implements(M, ThCategory) +@test !implements(M, ThNatPlus) +@test impl_type(M, ThCategory, :Ob) == FreeCategory.Ob +@test impl_type(M, ThCategory, :Hom) == FreeCategory.Hom +@test ThCategory.id[M](x) isa HomExpr{:id} +@test ThCategory.compose[M](ThCategory.id(x), f) == f + # Monoid ########