diff --git a/src/models/ModelInterface.jl b/src/models/ModelInterface.jl index f532d2b0..b4f25eff 100644 --- a/src/models/ModelInterface.jl +++ b/src/models/ModelInterface.jl @@ -145,18 +145,10 @@ macro instance(head, model, body) theory = macroexpand(__module__, :($theory_module.Meta.@theory)) # A dictionary to look up the Julia type of a type constructor from its name (an ident) - i = 0 - jltype_by_sort = Dict(map(sorts(theory)) do s - sorttype = getvalue(theory[methodof(s)]) - s => if sorttype isa AlgTypeConstructor - i += 1 - instance_types[i] - elseif sorttype isa AlgStruct - nameof(sorttype.declaration) - end - end) - i == length(instance_types) || error("Did not use all types ($i): $instance_types") - + jltype_by_sort = Dict{AlgSort,Expr0}([ + zip(primitive_sorts(theory), instance_types)..., + [s => nameof(headof(s)) for s in struct_sorts(theory)]... + ]) # Get the model type that we are overloading for, or nothing if this is the # default instance for `instance_types` diff --git a/src/syntax/GATs.jl b/src/syntax/GATs.jl index c7db9c82..3d451eb2 100644 --- a/src/syntax/GATs.jl +++ b/src/syntax/GATs.jl @@ -5,8 +5,8 @@ export Constant, AlgTerm, AlgType, AlgAST, AlgTypeConstructor, AlgAccessor, AlgAxiom, AlgStruct, AlgDot, AlgFunction, typesortsignature, sortsignature, getdecl, GATSegment, GAT, GATContext, gettheory, gettypecontext, allmethods, - resolvemethod, resolvefield, - termcons,typecons, accessors, structs, + resolvemethod, + termcons,typecons, accessors, structs, primitive_sorts, struct_sorts, equations, build_infer_expr, compile, sortcheck, allnames, sorts, sortname, InCtx, TermInCtx, TypeInCtx, headof, argsof, methodof, bodyof, argcontext, infer_type diff --git a/src/syntax/gats/gat.jl b/src/syntax/gats/gat.jl index e62a0181..eba4b2cd 100644 --- a/src/syntax/gats/gat.jl +++ b/src/syntax/gats/gat.jl @@ -163,6 +163,12 @@ function allnames(theory::GAT; aliases=false) end sorts(theory::GAT) = theory.sorts +primitive_sorts(theory::GAT) = + filter(s->getvalue(theory[methodof(s)]) isa AlgTypeConstructor, sorts(theory)) + +# NOTE: AlgStruct is the only derived sort this returns. +struct_sorts(theory::GAT) = + filter(s->getvalue(theory[methodof(s)]) isa AlgStruct, sorts(theory)) function termcons(theory::GAT) xs = Tuple{Ident, Ident}[] @@ -231,20 +237,5 @@ else Scopes.unsafe_pushbinding!(theory, Binding{Judgment}(name, AlgDeclaration())) end -"""Get type associated with a field of a struct""" -function resolvefield(t::Context, method::Ident, field::Symbol) - str = getvalue(t[method]) - str.fields[ident(str.fields; name=field)] |> getvalue -end - """Get all structs in a theory""" -function structs(t::GAT) - res = AlgStruct[] - for s in sorts(t) - v = getvalue(t[methodof(s)]) - if v isa AlgStruct - push!(res, v) - end - end - res -end +structs(t::GAT) = AlgStruct[getvalue(t[methodof(s)]) for s in struct_sorts(t)] \ No newline at end of file diff --git a/src/syntax/gats/judgments.jl b/src/syntax/gats/judgments.jl index a3ad371c..6bc28b36 100644 --- a/src/syntax/gats/judgments.jl +++ b/src/syntax/gats/judgments.jl @@ -66,7 +66,7 @@ A declaration of a type constructor. args::Vector{LID} end -Scopes.getcontext(tc::AlgTypeConstructor) = tc.localcontext +Scopes.getcontext(tc::TrmTypConstructor) = tc.localcontext abstract type AccessorField <: Judgment end @@ -87,7 +87,6 @@ exist. arg::Int end -Scopes.getcontext(::AccessorField) = EmptyContext{AlgType}() getdecl(acc::AccessorField) = acc.declaration @@ -106,7 +105,6 @@ A declaration of a term constructor as a method of an `AlgFunction`. type::Union{TypeScope,AlgType} end -Scopes.getcontext(tc::AlgTermConstructor) = tc.localcontext sortsignature(tc::TrmTypConstructor) = AlgSort.(getvalue.(argsof(tc))) @@ -122,7 +120,6 @@ A declaration of an axiom equands::Vector{AlgTerm} end -Scopes.getcontext(t::AlgAxiom) = t.localcontext """ `AlgSorts` @@ -175,8 +172,6 @@ typesortsignature(tc::AlgStruct) = AlgSort.(getvalue.(typeargsof(tc))) argsof(t::AlgStruct) = getbindings(t.fields) -Scopes.getcontext(tc::AlgStruct) = tc.localcontext - """ A shorthand for a function, such as "square(x) := x * x". It is relevant for models but can be ignored by theory maps, as it is fully determined by other @@ -188,5 +183,3 @@ judgments in the theory. args::Vector{LID} value::AlgTerm end - -Scopes.getcontext(tc::AlgFunction) = tc.localcontext diff --git a/test/nonstdlib/Pushouts.jl b/test/nonstdlib/Pushouts.jl index 76597ef5..7b2b2a98 100644 --- a/test/nonstdlib/Pushouts.jl +++ b/test/nonstdlib/Pushouts.jl @@ -28,4 +28,5 @@ Universal input: Output @test cospan(res) == Cospan(4, [1,1,2,3], [1,2,4]) @test ι₁(res) == [1,1,2,3] @test universal(res, Cospan(4, [3,3,2,2],[3,2,1])) == [3,2,2,1] + @test_throws ErrorException universal(PushoutInt(4, [1,1,2,3],[1,2,3]), Cospan(4, [3,3,2,2],[3,2,1])) end diff --git a/test/syntax/GATs.jl b/test/syntax/GATs.jl index fe34251d..7640d69f 100644 --- a/test/syntax/GATs.jl +++ b/test/syntax/GATs.jl @@ -64,6 +64,7 @@ ob_decl = getvalue(thcat[O]) ObT = fromexpr(thcat, :Ob, AlgType) ObS = AlgSort(ObT) +@test headof(ObS) == O @test toexpr(GATContext(thcat), ObS) == :Ob