diff --git a/src/models/ModelInterface.jl b/src/models/ModelInterface.jl index 2269ee9f..a16c256d 100644 --- a/src/models/ModelInterface.jl +++ b/src/models/ModelInterface.jl @@ -135,7 +135,6 @@ macro instance(head, model, body) # TODO: should we allow instance types to be nothing? Is this in Catlab? (theory_module, instance_types) = @match head begin :($ThX{$(Ts...)}) => (ThX, Ts) - :($ThX) => (ThX, nothing) _ => error("invalid syntax for head of @instance macro: $head") end @@ -143,34 +142,57 @@ macro instance(head, model, body) theory = macroexpand(__module__, :($theory_module.Meta.@theory)) # A dictionary to look up the Julia type of a type constructor from its name (an ident) - jltype_by_sort = isnothing(instance_types) ? nothing : Dict(zip(sorts(theory), instance_types)) # for type checking + jltype_by_sort = Dict(zip(sorts(theory), instance_types)) # for type checking # Get the model type that we are overloading for, or nothing if this is the # default instance for `instance_types` model_type, whereparams = parse_model_param(model) - # Parse the body into functions defined here and functions defined elsewhere - functions, ext_functions = parse_instance_body(body, theory) + # Create the actual instance + generate_instance(theory, theory_module, jltype_by_sort, model_type, whereparams, body) +end + +function generate_instance( + theory::GAT, + theory_module::Union{Expr0, Module}, + jltype_by_sort::Dict{AlgSort}, + model_type::Union{Expr0, Nothing}, + whereparams::AbstractVector, + body::Expr; + typecheck=true, + escape=true +) # The old (Catlab) style of instance, where there is no explicit model - oldinstance = isnothing(model) + oldinstance = isnothing(model_type) + + # Parse the body into functions defined here and functions defined elsewhere + functions, ext_functions = parse_instance_body(body, theory) # Checks that all the functions are defined with the correct types. Adds default # methods for type constructors and type argument accessors if these methods # are missing - typechecked_functions = if !isnothing(jltype_by_sort) + typechecked_functions = if typecheck typecheck_instance(theory, functions, ext_functions, jltype_by_sort; oldinstance) else [functions..., ext_functions...] # skip typechecking and expand_fail end + # Adds keyword arguments to the functions, and qualifies them by # `theory_module`, i.e. changes # `Ob(x) = blah` # to # `ThCategory.Ob(m::WithModel{M}, x; context=nothing) = let model = m.model in blah end` - qualified_functions = + qualified_functions = map(fun -> qualify_function(fun, theory_module, model_type, whereparams), typechecked_functions) + append!( + qualified_functions, + make_alias_definitions(theory, theory_module, jltype_by_sort, model_type, whereparams, ext_functions) + ) + + # Add overloads for the alias methods + # Declare that this model implements the theory implements_declarations = if !isnothing(model_type) @@ -181,10 +203,16 @@ macro instance(head, model, body) [] end - esc(Expr(:block, + docsink = gensym(:docsink) + + code = Expr(:block, [generate_function(f) for f in qualified_functions]..., - implements_declarations... - )) + implements_declarations..., + :(function $docsink end), + :(Core.@__doc__ $docsink) + ) + + escape ? esc(code) : code end macro instance(head, body) @@ -432,15 +460,59 @@ function expand_fail(theory::GAT, x::Ident, f::JuliaFunction) ) end +function make_alias_definitions(theory, theory_module, jltype_by_sort, model_type, whereparams, ext_functions) + lines = [] + oldinstance = isnothing(model_type) + for segment in theory.segments.scopes + for binding in segment + alias = getvalue(binding) + name = nameof(binding) + if alias isa Alias && name ∉ ext_functions + for (argsorts, method) in allmethods(theory.resolvers[alias.ref]) + args = [(gensym(), jltype_by_sort[sort]) for sort in argsorts] + args = if oldinstance + if length(args) == 0 + termcon = getvalue(theory[method]) + retsort = AlgSort(termcon.type) + [(gensym(), Expr(:curly, Type, jltype_by_sort[retsort]))] + else + args + end + else + [(gensym(:m), :($(TheoryInterface.WithModel){$model_type})); args] + end + argexprs = [Expr(:(::), p...) for p in args] + overload = JuliaFunction(; + name = :($theory_module.$name), + args = argexprs, + kwargs = [Expr(:(...), :kwargs)], + whereparams, + impl = :($theory_module.$(nameof(alias.ref))($(first.(args)...); kwargs...)) + ) + push!(lines, overload) + end + end + end + end + lines +end + """ -Add `model` kwarg (it shouldn't have it already) +Add `WithModel` param first, if this is not an old instance (it shouldn't have it already) Qualify method name to be in theory module Add `context` kwargs if not already present - -TODO: throw error if there's junk kwargs present already? """ function qualify_function(fun::JuliaFunction, theory_module, model_type::Union{Expr0, Nothing}, whereparams) - kwargs = Expr0[Expr(:kw, :context, nothing)] + kwargs = filter(fun.kwargs) do kwarg + @match kwarg begin + Expr(:kw, :context, _) => false + :context => false + Expr(:(::), :context, _) => false + Expr(:kw, Expr(:(::), :context, _), _) => false + _ => true + end + end + kwargs = Expr0[Expr(:kw, :context, nothing); kwargs] (args, impl) = if !isnothing(model_type) m = gensym(:m) @@ -472,6 +544,7 @@ function implements_declaration(model_type, scope, whereparams) end end + macro withmodel(model, subsexpr, body) modelvar = gensym("model") @@ -535,7 +608,7 @@ macro migrate(head) (name, mapname, modelname) _ => error("could not parse head of @theory: $head") end - codom_types = :(only(supertype($(esc(modelname))).parameters).types) + codom_types = :(only(supertype($modelname).parameters).types) # Unpack tmap = macroexpand(__module__, :($mapname.@map)) dom_module = macroexpand(__module__, :($mapname.@dom)) @@ -607,7 +680,17 @@ macro migrate(head) model_expr = Expr( :curly, GlobalRef(Syntax.TheoryInterface, :Model), - Expr(:curly, :Tuple, dom_types...) + Expr(:curly, :Tuple, esc.(dom_types)...) + ) + + instance_code = generate_instance( + dom_theory, + dom_module, + jltype_by_sort, + name, + [], + Expr(:block, generate_function.([funs...,funs2..., funs3...])...); + typecheck=false ) quote @@ -615,9 +698,7 @@ macro migrate(head) model :: $(esc(modelname)) end - @instance $dom_module [model :: $(esc(name))] begin - $(generate_function.([funs...,funs2..., funs3...])...) - end + $instance_code end end diff --git a/src/models/SymbolicModels.jl b/src/models/SymbolicModels.jl index 80e86fb9..6d5f6511 100644 --- a/src/models/SymbolicModels.jl +++ b/src/models/SymbolicModels.jl @@ -7,6 +7,7 @@ export GATExpr, @symbolic_model, SyntaxDomainError, head, args, gat_typeof, gat_ using ...Util using ...Syntax import ...Syntax: invoke_term +using ..ModelInterface using Base.Meta: ParseError using MLStyle @@ -260,7 +261,7 @@ macro symbolic_model(decl, theoryname, body) # Part 2: Generating internal methods module_methods = [internal_accessors(theory)..., - internal_constructors(theory)...] + internal_constructors(theory, theorymodule)...] module_decl = :(module $(esc(name)) export $(nameof.(sorts(theory))...) @@ -275,16 +276,20 @@ macro symbolic_model(decl, theoryname, body) # Part 3: Generating instance of theory theory_overloads = symbolic_instance_methods(theory, theoryname, name, overrides) - # Part 4: Generating generators. - - generator_overloads = [] - # generator_overloads = symbolic_generators(theorymodule, theory) + alias_overloads = ModelInterface.make_alias_definitions( + theory, + theoryname, + Dict(sort => :($name.$(nameof(sort))) for sort in sorts(theory)), + nothing, + [], + [] + ) Expr( :toplevel, module_decl, :(Core.@__doc__ $(esc(name))), - esc.(generate_function.([theory_overloads; generator_overloads]))..., + esc.(generate_function.([theory_overloads; alias_overloads]))..., ) end @@ -328,7 +333,7 @@ function internal_accessors(theory::GAT) end |> Iterators.flatten end -function internal_constructors(theory::GAT)::Vector{JuliaFunction} +function internal_constructors(theory::GAT, theorymodule)::Vector{JuliaFunction} map(termcons(theory)) do (decl, method) name = nameof(decl) termcon = getvalue(theory, method) @@ -365,17 +370,18 @@ function internal_constructors(theory::GAT)::Vector{JuliaFunction} end... ) + instance_types = Dict(sort => esc(nameof(sort)) for sort in sorts(theory)) check_or_error = Expr(:(||), :(!strict), check_expr, throw_expr) context_xs = getidents(termcon.localcontext) expr_lookup = Dict{Ident, Any}(map(context_xs) do x - x => compile(arg_expr_lookup, first(eqs[x])) + x => compile(arg_expr_lookup, first(eqs[x]); theorymodule, theory, instance_types) end) build = Expr( :call, Expr(:curly, typename(theory, termcon.type), Expr(:quote, name)), Expr(:vect, nameof.(argsof(termcon))...), - Expr(:ref, GATExpr, compile.(Ref(expr_lookup), termcon.type.body.args)...) + Expr(:ref, GATExpr, compile.(Ref(expr_lookup), termcon.type.body.args; theorymodule, theory, instance_types)...) ) JuliaFunction( diff --git a/src/stdlib/models/Arithmetic.jl b/src/stdlib/models/arithmetic.jl similarity index 88% rename from src/stdlib/models/Arithmetic.jl rename to src/stdlib/models/arithmetic.jl index f706cf81..81c591c1 100644 --- a/src/stdlib/models/Arithmetic.jl +++ b/src/stdlib/models/arithmetic.jl @@ -1,9 +1,7 @@ -module Arithmetic - export IntNatPlus, IntPreorder -using ....Models -using ...StdTheories +using ...Models +using ..StdTheories struct IntNatPlus <: Model{Tuple{Int}} end @@ -24,6 +22,3 @@ struct IntPreorder <: Model{Tuple{Int, Tuple{Int,Int}}} end error("Cannot compose $ab and $bc") end end - - -end # module diff --git a/src/stdlib/models/FinMatrices.jl b/src/stdlib/models/finmatrices.jl similarity index 88% rename from src/stdlib/models/FinMatrices.jl rename to src/stdlib/models/finmatrices.jl index 5d6a6b48..35051d07 100644 --- a/src/stdlib/models/FinMatrices.jl +++ b/src/stdlib/models/finmatrices.jl @@ -1,8 +1,7 @@ -module FinMatrices export FinMatC -using ....Models -using ...StdTheories +using ...Models +using ..StdTheories struct FinMatC{T<:Number} <: Model{Tuple{T}} end @@ -18,5 +17,3 @@ end dom(A::Matrix{T}) = size(A)[1] codom(A::Matrix{T}) = size(A)[2] end - -end diff --git a/src/stdlib/models/FinSets.jl b/src/stdlib/models/finsets.jl similarity index 91% rename from src/stdlib/models/FinSets.jl rename to src/stdlib/models/finsets.jl index 8bf8d303..2820eb92 100644 --- a/src/stdlib/models/FinSets.jl +++ b/src/stdlib/models/finsets.jl @@ -1,8 +1,7 @@ -module FinSets export FinSetC -using ....Models -using ...StdTheories +using ...Models +using ..StdTheories struct FinSetC <: Model{Tuple{Int, Vector{Int}}} end @@ -29,5 +28,3 @@ end dom(f::Vector{Int}) = length(f) codom(::Vector{Int}; context) = context[:codom] end - -end diff --git a/src/stdlib/models/GATs.jl b/src/stdlib/models/gats.jl similarity index 81% rename from src/stdlib/models/GATs.jl rename to src/stdlib/models/gats.jl index 2d082f8e..8d6cc87e 100644 --- a/src/stdlib/models/GATs.jl +++ b/src/stdlib/models/gats.jl @@ -1,9 +1,8 @@ -module GATs export GATC -using ....Models -using ....Syntax -using ...StdTheories +using ...Models +using ...Syntax +using ..StdTheories using GATlab, GATlab.Models @@ -16,5 +15,3 @@ end dom(f::AbsTheoryMap) = TheoryMaps.dom(f) codom(f::AbsTheoryMap) = TheoryMaps.codom(f) end - -end # module diff --git a/src/stdlib/models/module.jl b/src/stdlib/models/module.jl index 569fb3ed..e3d07e75 100644 --- a/src/stdlib/models/module.jl +++ b/src/stdlib/models/module.jl @@ -2,20 +2,12 @@ module StdModels using Reexport -include("FinSets.jl") -include("Arithmetic.jl") -include("FinMatrices.jl") -include("SliceCategories.jl") -include("Op.jl") -include("Nothings.jl") -include("GATs.jl") - -@reexport using .FinSets -@reexport using .Arithmetic -@reexport using .FinMatrices -@reexport using .SliceCategories -@reexport using .Op -@reexport using .Nothings -@reexport using .GATs +include("finsets.jl") +include("arithmetic.jl") +include("finmatrices.jl") +include("slicecategories.jl") +include("op.jl") +include("nothings.jl") +include("gats.jl") end diff --git a/src/stdlib/models/Nothings.jl b/src/stdlib/models/nothings.jl similarity index 86% rename from src/stdlib/models/Nothings.jl rename to src/stdlib/models/nothings.jl index a2150225..3a7fac67 100644 --- a/src/stdlib/models/Nothings.jl +++ b/src/stdlib/models/nothings.jl @@ -1,7 +1,6 @@ -module Nothings export NothingC -using ....Models, ...StdTheories +using ...Models, ..StdTheories struct NothingC <: Model{Tuple{Nothing, Nothing}} end @@ -16,5 +15,3 @@ end compose(::Nothing, ::Nothing) = nothing id(::Nothing) = nothing end - -end diff --git a/src/stdlib/models/Op.jl b/src/stdlib/models/op.jl similarity index 94% rename from src/stdlib/models/Op.jl rename to src/stdlib/models/op.jl index 8df58c94..e0e346f9 100644 --- a/src/stdlib/models/Op.jl +++ b/src/stdlib/models/op.jl @@ -2,13 +2,12 @@ Explicit Op model. Alternatively, see DerivedModels.jl (`OpFinSetC`) for theory-morphism-derived Op models. """ -module Op export OpC, op -using ....Models -using ...StdTheories +using ...Models +using ..StdTheories using StructEquality @struct_hash_equal struct OpC{ObT, HomT, C<:Model{Tuple{ObT, HomT}}} <: Model{Tuple{ObT, HomT}} @@ -36,5 +35,3 @@ rename(nt::NamedTuple, d::Dict{Symbol,Symbol}) = compose[model.cat](g, f; context=rename(context, Dict(:a=>:c, :c=>:a, :b=>:b))) end - -end # module diff --git a/src/stdlib/models/SliceCategories.jl b/src/stdlib/models/slicecategories.jl similarity index 94% rename from src/stdlib/models/SliceCategories.jl rename to src/stdlib/models/slicecategories.jl index f2ea45dc..0d5324b9 100644 --- a/src/stdlib/models/SliceCategories.jl +++ b/src/stdlib/models/slicecategories.jl @@ -1,8 +1,7 @@ -module SliceCategories export SliceC, SliceOb -using ....Models -using ...StdTheories +using ...Models +using ..StdTheories using StructEquality @struct_hash_equal struct SliceOb{ObT, HomT} @@ -50,5 +49,3 @@ using .ThCategory compose(f::HomT, g::HomT; context=nothing) = compose[model.cat](f, g; context=isnothing(context) ? nothing : map(x -> x.ob, context)) end - -end diff --git a/src/stdlib/theories/categories.jl b/src/stdlib/theories/categories.jl index 465aee16..5606a5ee 100644 --- a/src/stdlib/theories/categories.jl +++ b/src/stdlib/theories/categories.jl @@ -1,6 +1,6 @@ export ThClass, ThGraph, ThLawlessCat, ThAscCat, ThCategory, ThThinCategory -import AlgebraicInterfaces: dom, codom, compose, id +import AlgebraicInterfaces: dom, codom, compose, id, Ob, Hom # Category theory diff --git a/src/syntax/TheoryInterface.jl b/src/syntax/TheoryInterface.jl index ede49040..543dee0c 100644 --- a/src/syntax/TheoryInterface.jl +++ b/src/syntax/TheoryInterface.jl @@ -83,7 +83,7 @@ function theory_impl(head, body, __module__) judgment = getvalue(binding) bname = nameof(binding) if judgment isa Union{AlgDeclaration, Alias} - push!(lines, juliadeclaration(bname, judgment)) + push!(lines, juliadeclaration(bname)) push!(newnames, bname) end end @@ -130,10 +130,9 @@ function theory_impl(head, body, __module__) ) end -function juliadeclaration(name::Symbol, ::AlgDeclaration) - decl = :(function $name end) +function juliadeclaration(name::Symbol) quote - $decl + function $name end if Base.isempty(Base.methods(Base.getindex, [typeof($name), $(GlobalRef(TheoryInterface, :Model))])) function Base.getindex(::typeof($name), m::$(GlobalRef(TheoryInterface, :Model))) @@ -143,10 +142,6 @@ function juliadeclaration(name::Symbol, ::AlgDeclaration) end end -function juliadeclaration(name::Symbol, alias::Alias) - :(const $name = $(nameof(alias.ref))) -end - function invoke_term(theory_module, types, name, args; model=nothing) theory = theory_module.Meta.theory method = getproperty(theory_module, name) diff --git a/src/syntax/gats/algorithms.jl b/src/syntax/gats/algorithms.jl index 7e69c9c0..2076c702 100644 --- a/src/syntax/gats/algorithms.jl +++ b/src/syntax/gats/algorithms.jl @@ -112,7 +112,8 @@ function equations(theory::GAT, x::Ident) equations(GATContext(theory, judgment), idents(judgment.localcontext; lid=judgment.args)) end -function compile(expr_lookup::Dict{Ident}, term::AlgTerm; theorymodule=nothing) +function compile(expr_lookup::Dict{Ident}, term::AlgTerm; + theorymodule=nothing, instance_types=nothing, theory=nothing) if isapp(term) name = nameof(term.body.head) fun = if !isnothing(theorymodule) @@ -120,7 +121,15 @@ function compile(expr_lookup::Dict{Ident}, term::AlgTerm; theorymodule=nothing) else esc(name) end - Expr(:call, fun, [compile(expr_lookup, arg; theorymodule) for arg in term.body.args]...) + # In the case that we have an old-style instance we need to pass in the + # return type in order to dispatch a nullary term constructor + args = if !isnothing(instance_types) && isempty(term.body.args) + termcon = getvalue(theory[term.body.method]) + [instance_types[AlgSort(termcon.type)]] + else + [compile(expr_lookup, arg; theorymodule, instance_types, theory) for arg in term.body.args] + end + Expr(:call, fun, args...) elseif isvariable(term) expr_lookup[term.body] elseif isconstant(term) diff --git a/src/util/MetaUtils.jl b/src/util/MetaUtils.jl index e6708a0f..fa611e90 100644 --- a/src/util/MetaUtils.jl +++ b/src/util/MetaUtils.jl @@ -4,7 +4,7 @@ module MetaUtils export JuliaFunction, setimpl, setname, JuliaFunctionSig, parse_docstring, parse_function, parse_function_sig, generate_docstring, generate_function, - replace_symbols, strip_lines, + append_expr!, concat_expr, replace_symbols, strip_lines, Expr0 using Base.Meta: ParseError @@ -159,6 +159,28 @@ end # Operations on Julia expressions ################################# +""" Append a Julia expression to a block expression. +""" +function append_expr!(block::Expr, expr)::Expr + @assert block.head == :block + @match expr begin + Expr(:block, args...) => append!(block.args, args) + _ => push!(block.args, expr) + end + block +end + +""" Concatenate two Julia expressions into a block expression. +""" +function concat_expr(expr1::Expr, expr2::Expr)::Expr + @match (expr1, expr2) begin + (Expr(:block, a1...), Expr(:block, a2...)) => Expr(:block, a1..., a2...) + (Expr(:block, a1...), _) => Expr(:block, a1..., expr2) + (_, Expr(:block, a2...)) => Expr(:block, expr1, a2...) + _ => Expr(:block, expr1, expr2) + end +end + """ Replace symbols occurring anywhere in a Julia function (except the name). """ function replace_symbols(bindings::AbstractDict, f::JuliaFunction)::JuliaFunction diff --git a/test/syntax/TheoryInterface.jl b/test/syntax/TheoryInterface.jl index fa461a3c..8586a280 100644 --- a/test/syntax/TheoryInterface.jl +++ b/test/syntax/TheoryInterface.jl @@ -16,7 +16,7 @@ end using .ThCategoryTypes @test dom isa Function -@test Hom == → +@test Hom != → @test parentmodule(dom) == TestTheoryInterface @theory ThLawlessCategory <: ThCategoryTypes begin @@ -28,7 +28,7 @@ end using .ThLawlessCategory @test compose isa Function -@test compose == (⋅) +@test compose != (⋅) @test parentmodule(id) == TestTheoryInterface @test Set(allnames(ThLawlessCategory.Meta.theory)) == Set([:Ob, :Hom, :dom, :codom, :compose, :id]) @test nameof(ThLawlessCategory.Meta.theory) == :ThLawlessCategory