Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Default model #166

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 24 additions & 12 deletions src/models/ModelInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,9 @@ function mk_fun(f::AlgFunction, theory, mod, jltype_by_sort)
args = map(zip(f.args, sortsignature(f))) do (i,s)
Expr(:(::),nameof(f[i]),jltype_by_sort[s])
end
impl = to_call_impl(f.value,theory, mod, false)
argnames = Vector{Symbol}(undef, length(getcontext(f)))
setindex!.(Ref(argnames), [nameof(f[i]) for i in f.args], getvalue.(f.args))
impl = to_call_impl(f.value,theory, mod, argnames, false)
JuliaFunction(;name=name, args, impl)
end

Expand All @@ -517,7 +519,7 @@ function make_alias_definitions(theory, theory_module, jltype_by_sort, model_typ
args
end
else
[(gensym(:m), :($(TheoryInterface.WithModel){$model_type})); args]
[(gensym(:m), :($(TheoryInterface.WithModel){<:$model_type})); args]
end
argexprs = [Expr(:(::), p...) for p in args]
overload = JuliaFunction(;
Expand Down Expand Up @@ -563,7 +565,7 @@ function qualify_function(fun::JuliaFunction, theory_module, model_type::Union{E

m = gensym(:m)
(
[Expr(:(::), m, Expr(:curly, TheoryInterface.WithModel, model_type)), args...],
[Expr(:(::), m, Expr(:curly, TheoryInterface.WithModel, Expr(:<:, model_type))), args...],
Expr(:let, Expr(:(=), :model, :($m.model)), fun.impl)
)
else
Expand Down Expand Up @@ -652,6 +654,10 @@ function migrator(tmap, dom_module, codom_module, dom_theory, codom_theory)
v => whereparamdict[AlgSort(tmap(v.method).val)]
end)

whereparams2 = map(sorts(dom_theory)) do v
whereparamdict[AlgSort(tmap(v.method).val)]
end

# Create input for instance_code
################################
accessor_funs = JuliaFunction[] # added to during typecon_funs loop
Expand Down Expand Up @@ -691,11 +697,13 @@ function migrator(tmap, dom_module, codom_module, dom_theory, codom_theory)
args = [:($k::$(v)) for (k, v) in zip(argnames, sig.types)]

return_type = first(sig.types)
argnames′ = Array{Symbol}(undef, length(getcontext(typecon)))
setindex!.(Ref(argnames′), argnames[2:end], getvalue.(typecon.args))

impls = to_call_impl.(codom_body.args, Ref(codom_theory), Ref(codom_module), true)
impls = to_call_impl.(codom_body.args, Ref(codom_theory), Ref(codom_module),
Ref(argnames′), true)
impl = Expr(:call, Expr(:ref, :($codom_module.$fxname),
:(model.model)), _x, impls...)

JuliaFunction(;name, args, return_type, impl)
end

Expand All @@ -707,8 +715,11 @@ function migrator(tmap, dom_module, codom_module, dom_theory, codom_theory)

name = nameof(termcon.declaration)
return_type = jltype_by_sort[AlgSort(termcon.type)]
args = [:($k::$v) for (k, v) in zip(nameof.(argsof(termcon)), sig.types)]
impl = to_call_impl(fx.val, codom_theory, codom_module, true)
argnames = nameof.(argsof(termcon))
argnames′ = Array{Symbol}(undef, length(getcontext(termcon)))
setindex!.(Ref(argnames′), argnames, getvalue.(termcon.args))
args = [:($k::$v) for (k, v) in zip(argnames, sig.types)]
impl = to_call_impl(fx.val, codom_theory, codom_module, argnames′, true)

JuliaFunction(;name, args, return_type, impl)
end
Expand All @@ -729,11 +740,12 @@ function migrator(tmap, dom_module, codom_module, dom_theory, codom_theory)
)

tup_params = Expr(:curly, :Tuple, whereparams...)
tup_params2 = Expr(:curly, :Tuple, whereparams2...)

model_expr = Expr(
:curly,
GlobalRef(Syntax.TheoryInterface, :Model),
tup_params
tup_params2 # Types associated with *domain* sorts
)

# The second whereparams needs to be reordered by the sorts of the DOM theory
Expand All @@ -754,19 +766,19 @@ end
Compile an AlgTerm into a Julia call Expr where termcons (e.g. `f`) are
interpreted as `mod.f[model.model](...)`.
"""
function to_call_impl(t::AlgTerm, theory::GAT, mod::Union{Symbol,Module}, migrate::Bool)
function to_call_impl(t::AlgTerm, theory::GAT, mod::Union{Symbol,Module}, argnames::Vector{Symbol}, migrate::Bool)
b = bodyof(t)
if GATs.isvariable(t)
nameof(b)
argnames[getvalue(getlid(b))]
elseif GATs.isdot(t)
impl = to_call_impl(b.body, theory, mod, migrate)
impl = to_call_impl(b.body, theory, mod, argnames, migrate)
if isnamed(b.head)
Expr(:., impl, QuoteNode(nameof(b.head)))
else
Expr(:ref, impl, getlid(b.head).val)
end
else
args = to_call_impl.(argsof(b), Ref(theory), Ref(mod), migrate)
args = to_call_impl.(argsof(b), Ref(theory), Ref(mod), Ref(argnames), migrate)
name = nameof(headof(b))
newhead = if name ∈ nameof.(structs(theory))
Expr(:., :($mod), QuoteNode(name))
Expand Down
Loading