diff --git a/src/syntax/TheoryInterface.jl b/src/syntax/TheoryInterface.jl index 1aaf62f..faaa846 100644 --- a/src/syntax/TheoryInterface.jl +++ b/src/syntax/TheoryInterface.jl @@ -276,21 +276,15 @@ function wrapper(name::Symbol, t::GAT, mod) quote $use macro wrapper(n) - x, y = if n isa Symbol - n, Any - elseif n.head == :<: - n.args - else - error("Unexpected input for wrapper $n") - end - + x, y = $(parse_wrapper_input)(n) esc(:($($(name)).Meta.@wrapper $x $y)) end - macro abs_wrapper(n) - n.head == :<: || error("Expected: StructName <: AbsType, got $n") - x, y = n.args - esc(:($($(name)).Meta.@wrapper $(x) $(y))) + + macro typed_wrapper(n) + x, y = $(parse_wrapper_input)(n) + esc(:($($(name)).Meta.@typed_wrapper $x $y)) end + macro wrapper(n, abs) doctarget = gensym() esc(quote @@ -328,8 +322,52 @@ function wrapper(name::Symbol, t::GAT, mod) nothing end) end + + macro typed_wrapper(n, abs) + doctarget = gensym() + Ts = nameof.($(sorts)($t)) + Xs = map(Ts) do s + :($(GlobalRef($(TheoryInterface), :impl_type))(x, $($(name)), $($(Meta.quot)(s)))) + end + esc(quote + # Catch any potential docs above the macro call + const $(doctarget) = nothing + Core.@__doc__ $(doctarget) + + # Declare the wrapper struct + struct $n{$(Ts...)} <: $abs + val::Any + function $n(x::Any) + $($(GlobalRef(TheoryInterface, :implements)))(x, $($name)) || error( + "Invalid $($($(name))) model: $x") + new{$(Xs...)}(x) + end + end + # Apply the caught documentation to the new struct + @doc $($(mdp))(@doc $doctarget) $n + + # Define == and hash + $(Expr(:macrocall, $(GlobalRef(StructEquality, Symbol("@struct_hash_equal"))), $(mod), $(:n))) + + # GlobalRef doesn't work: "invalid function name". + GATlab.getvalue(x::$n) = x.val + GATlab.impl_type(x::$n, o::Symbol) = GATlab.impl_type(x.val, $($name), o) + + # Dispatch on model value for all declarations in theory + $(map(filter(x->x[2] isa $AlgDeclaration, $(identvalues(t)))) do (x,j) + if j isa $(AlgDeclaration) + op = nameof(x) + :($($(name)).$op(x::$(($(:n))), args...; kw...) = + $($(name)).$op[x.val](args...; kw...)) + end + end...) + nothing + end) + end end end +parse_wrapper_input(n::Symbol) = n, Any +parse_wrapper_input(n::Expr) = n.head == :<: ? n.args : error("Bad input for wrapper") end # module diff --git a/test/models/ModelInterface.jl b/test/models/ModelInterface.jl index 51a8734..733ff91 100644 --- a/test/models/ModelInterface.jl +++ b/test/models/ModelInterface.jl @@ -253,7 +253,21 @@ end @test_throws MethodError id2(FinSetC()) abstract type MyAbsType end -ThCategory.Meta.@abs_wrapper Cat2 <: MyAbsType +ThCategory.Meta.@wrapper Cat2 <: MyAbsType @test Cat2 <: MyAbsType +# Typed wrappers +#---------------- +"""Typed Cat""" +ThCategory.Meta.@typed_wrapper TCat + +c = TCat(FinSetC()) +@test c isa TCat{Int, Vector{Int}} +@test id(c, 2) == [1,2] + +c2 = TCat(FinMatC{Int}()); +@test c2 isa TCat{Int, Matrix{Int}} + +@test id(c2, 2) == [1 0; 0 1] + end # module