Skip to content

Commit

Permalink
fixes to get catlab working
Browse files Browse the repository at this point in the history
  • Loading branch information
olynch committed Oct 11, 2023
1 parent 74445d5 commit 3ae65e2
Show file tree
Hide file tree
Showing 15 changed files with 176 additions and 94 deletions.
119 changes: 100 additions & 19 deletions src/models/ModelInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,42 +135,64 @@ macro instance(head, model, body)
# TODO: should we allow instance types to be nothing? Is this in Catlab?
(theory_module, instance_types) = @match head begin
:($ThX{$(Ts...)}) => (ThX, Ts)
:($ThX) => (ThX, nothing)
_ => error("invalid syntax for head of @instance macro: $head")
end

# Get the underlying theory
theory = macroexpand(__module__, :($theory_module.Meta.@theory))

# A dictionary to look up the Julia type of a type constructor from its name (an ident)
jltype_by_sort = isnothing(instance_types) ? nothing : Dict(zip(sorts(theory), instance_types)) # for type checking
jltype_by_sort = Dict(zip(sorts(theory), instance_types)) # for type checking

# Get the model type that we are overloading for, or nothing if this is the
# default instance for `instance_types`
model_type, whereparams = parse_model_param(model)

# Parse the body into functions defined here and functions defined elsewhere
functions, ext_functions = parse_instance_body(body, theory)

# Create the actual instance
generate_instance(theory, theory_module, jltype_by_sort, model_type, whereparams, body)
end

function generate_instance(
theory::GAT,
theory_module::Union{Expr0, Module},
jltype_by_sort::Dict{AlgSort},
model_type::Union{Expr0, Nothing},
whereparams::AbstractVector,
body::Expr;
typecheck=true,
escape=true
)
# The old (Catlab) style of instance, where there is no explicit model
oldinstance = isnothing(model)
oldinstance = isnothing(model_type)

# Parse the body into functions defined here and functions defined elsewhere
functions, ext_functions = parse_instance_body(body, theory)

# Checks that all the functions are defined with the correct types. Adds default
# methods for type constructors and type argument accessors if these methods
# are missing
typechecked_functions = if !isnothing(jltype_by_sort)
typechecked_functions = if typecheck
typecheck_instance(theory, functions, ext_functions, jltype_by_sort; oldinstance)
else
[functions..., ext_functions...] # skip typechecking and expand_fail
end

# Adds keyword arguments to the functions, and qualifies them by
# `theory_module`, i.e. changes
# `Ob(x) = blah`
# to
# `ThCategory.Ob(m::WithModel{M}, x; context=nothing) = let model = m.model in blah end`
qualified_functions =
qualified_functions =
map(fun -> qualify_function(fun, theory_module, model_type, whereparams), typechecked_functions)

append!(
qualified_functions,
make_alias_definitions(theory, theory_module, jltype_by_sort, model_type, whereparams, ext_functions)
)

# Add overloads for the alias methods

# Declare that this model implements the theory

implements_declarations = if !isnothing(model_type)
Expand All @@ -181,10 +203,16 @@ macro instance(head, model, body)
[]
end

esc(Expr(:block,
docsink = gensym(:docsink)

code = Expr(:block,
[generate_function(f) for f in qualified_functions]...,
implements_declarations...
))
implements_declarations...,
:(function $docsink end),
:(Core.@__doc__ $docsink)
)

escape ? esc(code) : code
end

macro instance(head, body)
Expand Down Expand Up @@ -432,15 +460,59 @@ function expand_fail(theory::GAT, x::Ident, f::JuliaFunction)
)
end

function make_alias_definitions(theory, theory_module, jltype_by_sort, model_type, whereparams, ext_functions)
lines = []
oldinstance = isnothing(model_type)
for segment in theory.segments.scopes
for binding in segment
alias = getvalue(binding)
name = nameof(binding)
if alias isa Alias && name ext_functions
for (argsorts, method) in allmethods(theory.resolvers[alias.ref])
args = [(gensym(), jltype_by_sort[sort]) for sort in argsorts]
args = if oldinstance
if length(args) == 0
termcon = getvalue(theory[method])
retsort = AlgSort(termcon.type)
[(gensym(), Expr(:curly, Type, jltype_by_sort[retsort]))]
else
args
end
else
[(gensym(:m), :($(TheoryInterface.WithModel){$model_type})); args]
end
argexprs = [Expr(:(::), p...) for p in args]
overload = JuliaFunction(;
name = :($theory_module.$name),
args = argexprs,
kwargs = [Expr(:(...), :kwargs)],
whereparams,
impl = :($theory_module.$(nameof(alias.ref))($(first.(args)...); kwargs...))
)
push!(lines, overload)
end
end
end
end
lines
end

"""
Add `model` kwarg (it shouldn't have it already)
Add `WithModel` param first, if this is not an old instance (it shouldn't have it already)
Qualify method name to be in theory module
Add `context` kwargs if not already present
TODO: throw error if there's junk kwargs present already?
"""
function qualify_function(fun::JuliaFunction, theory_module, model_type::Union{Expr0, Nothing}, whereparams)
kwargs = Expr0[Expr(:kw, :context, nothing)]
kwargs = filter(fun.kwargs) do kwarg
@match kwarg begin
Expr(:kw, :context, _) => false
:context => false
Expr(:(::), :context, _) => false

Check warning on line 510 in src/models/ModelInterface.jl

View check run for this annotation

Codecov / codecov/patch

src/models/ModelInterface.jl#L510

Added line #L510 was not covered by tests
Expr(:kw, Expr(:(::), :context, _), _) => false
_ => true
end
end
kwargs = Expr0[Expr(:kw, :context, nothing); kwargs]

(args, impl) = if !isnothing(model_type)
m = gensym(:m)
Expand Down Expand Up @@ -472,6 +544,7 @@ function implements_declaration(model_type, scope, whereparams)
end
end


macro withmodel(model, subsexpr, body)
modelvar = gensym("model")

Expand Down Expand Up @@ -535,7 +608,7 @@ macro migrate(head)
(name, mapname, modelname)
_ => error("could not parse head of @theory: $head")
end
codom_types = :(only(supertype($(esc(modelname))).parameters).types)
codom_types = :(only(supertype($modelname).parameters).types)
# Unpack
tmap = macroexpand(__module__, :($mapname.@map))
dom_module = macroexpand(__module__, :($mapname.@dom))
Expand Down Expand Up @@ -607,17 +680,25 @@ macro migrate(head)
model_expr = Expr(
:curly,
GlobalRef(Syntax.TheoryInterface, :Model),
Expr(:curly, :Tuple, dom_types...)
Expr(:curly, :Tuple, esc.(dom_types)...)
)

instance_code = generate_instance(
dom_theory,
dom_module,
jltype_by_sort,
name,
[],
Expr(:block, generate_function.([funs...,funs2..., funs3...])...);
typecheck=false
)

quote
struct $(esc(name)) <: $model_expr
model :: $(esc(modelname))
end

@instance $dom_module [model :: $(esc(name))] begin
$(generate_function.([funs...,funs2..., funs3...])...)
end
$instance_code
end
end

Expand Down
24 changes: 15 additions & 9 deletions src/models/SymbolicModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ export GATExpr, @symbolic_model, SyntaxDomainError, head, args, gat_typeof, gat_
using ...Util
using ...Syntax
import ...Syntax: invoke_term
using ..ModelInterface

using Base.Meta: ParseError
using MLStyle
Expand Down Expand Up @@ -260,7 +261,7 @@ macro symbolic_model(decl, theoryname, body)
# Part 2: Generating internal methods

module_methods = [internal_accessors(theory)...,
internal_constructors(theory)...]
internal_constructors(theory, theorymodule)...]

module_decl = :(module $(esc(name))
export $(nameof.(sorts(theory))...)
Expand All @@ -275,16 +276,20 @@ macro symbolic_model(decl, theoryname, body)
# Part 3: Generating instance of theory
theory_overloads = symbolic_instance_methods(theory, theoryname, name, overrides)

# Part 4: Generating generators.

generator_overloads = []
# generator_overloads = symbolic_generators(theorymodule, theory)
alias_overloads = ModelInterface.make_alias_definitions(
theory,
theoryname,
Dict(sort => :($name.$(nameof(sort))) for sort in sorts(theory)),
nothing,
[],
[]
)

Expr(
:toplevel,
module_decl,
:(Core.@__doc__ $(esc(name))),
esc.(generate_function.([theory_overloads; generator_overloads]))...,
esc.(generate_function.([theory_overloads; alias_overloads]))...,
)
end

Expand Down Expand Up @@ -328,7 +333,7 @@ function internal_accessors(theory::GAT)
end |> Iterators.flatten
end

function internal_constructors(theory::GAT)::Vector{JuliaFunction}
function internal_constructors(theory::GAT, theorymodule)::Vector{JuliaFunction}
map(termcons(theory)) do (decl, method)
name = nameof(decl)
termcon = getvalue(theory, method)
Expand Down Expand Up @@ -365,17 +370,18 @@ function internal_constructors(theory::GAT)::Vector{JuliaFunction}
end...
)

instance_types = Dict(sort => esc(nameof(sort)) for sort in sorts(theory))
check_or_error = Expr(:(||), :(!strict), check_expr, throw_expr)
context_xs = getidents(termcon.localcontext)
expr_lookup = Dict{Ident, Any}(map(context_xs) do x
x => compile(arg_expr_lookup, first(eqs[x]))
x => compile(arg_expr_lookup, first(eqs[x]); theorymodule, theory, instance_types)
end)

build = Expr(
:call,
Expr(:curly, typename(theory, termcon.type), Expr(:quote, name)),
Expr(:vect, nameof.(argsof(termcon))...),
Expr(:ref, GATExpr, compile.(Ref(expr_lookup), termcon.type.body.args)...)
Expr(:ref, GATExpr, compile.(Ref(expr_lookup), termcon.type.body.args; theorymodule, theory, instance_types)...)
)

JuliaFunction(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
module Arithmetic

export IntNatPlus, IntPreorder

using ....Models
using ...StdTheories
using ...Models
using ..StdTheories

struct IntNatPlus <: Model{Tuple{Int}} end

Expand All @@ -24,6 +22,3 @@ struct IntPreorder <: Model{Tuple{Int, Tuple{Int,Int}}} end
error("Cannot compose $ab and $bc")
end
end


end # module
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
module FinMatrices
export FinMatC

using ....Models
using ...StdTheories
using ...Models
using ..StdTheories

struct FinMatC{T<:Number} <: Model{Tuple{T}}
end
Expand All @@ -18,5 +17,3 @@ end
dom(A::Matrix{T}) = size(A)[1]
codom(A::Matrix{T}) = size(A)[2]
end

end
7 changes: 2 additions & 5 deletions src/stdlib/models/FinSets.jl → src/stdlib/models/finsets.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
module FinSets
export FinSetC

using ....Models
using ...StdTheories
using ...Models
using ..StdTheories

struct FinSetC <: Model{Tuple{Int, Vector{Int}}}
end
Expand All @@ -29,5 +28,3 @@ end
dom(f::Vector{Int}) = length(f)
codom(::Vector{Int}; context) = context[:codom]
end

end
9 changes: 3 additions & 6 deletions src/stdlib/models/GATs.jl → src/stdlib/models/gats.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
module GATs
export GATC

using ....Models
using ....Syntax
using ...StdTheories
using ...Models
using ...Syntax
using ..StdTheories

using GATlab, GATlab.Models

Expand All @@ -16,5 +15,3 @@ end
dom(f::AbsTheoryMap) = TheoryMaps.dom(f)
codom(f::AbsTheoryMap) = TheoryMaps.codom(f)
end

end # module
22 changes: 7 additions & 15 deletions src/stdlib/models/module.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,12 @@ module StdModels

using Reexport

include("FinSets.jl")
include("Arithmetic.jl")
include("FinMatrices.jl")
include("SliceCategories.jl")
include("Op.jl")
include("Nothings.jl")
include("GATs.jl")

@reexport using .FinSets
@reexport using .Arithmetic
@reexport using .FinMatrices
@reexport using .SliceCategories
@reexport using .Op
@reexport using .Nothings
@reexport using .GATs
include("finsets.jl")
include("arithmetic.jl")
include("finmatrices.jl")
include("slicecategories.jl")
include("op.jl")
include("nothings.jl")
include("gats.jl")

end
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
module Nothings
export NothingC

using ....Models, ...StdTheories
using ...Models, ..StdTheories

struct NothingC <: Model{Tuple{Nothing, Nothing}}
end
Expand All @@ -16,5 +15,3 @@ end
compose(::Nothing, ::Nothing) = nothing
id(::Nothing) = nothing
end

end
Loading

0 comments on commit 3ae65e2

Please sign in to comment.