Skip to content

Commit

Permalink
primitive vs struct sorts, more coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
Kris Brown committed Nov 10, 2023
1 parent 03b8e35 commit c53cd4b
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 38 deletions.
16 changes: 4 additions & 12 deletions src/models/ModelInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,18 +145,10 @@ macro instance(head, model, body)
theory = macroexpand(__module__, :($theory_module.Meta.@theory))

# A dictionary to look up the Julia type of a type constructor from its name (an ident)
i = 0
jltype_by_sort = Dict(map(sorts(theory)) do s
sorttype = getvalue(theory[methodof(s)])
s => if sorttype isa AlgTypeConstructor
i += 1
instance_types[i]
elseif sorttype isa AlgStruct
nameof(sorttype.declaration)
end
end)
i == length(instance_types) || error("Did not use all types ($i): $instance_types")

jltype_by_sort = Dict{AlgSort,Expr0}([
zip(primitive_sorts(theory), instance_types)...,
[s => nameof(headof(s)) for s in struct_sorts(theory)]...
])

# Get the model type that we are overloading for, or nothing if this is the
# default instance for `instance_types`
Expand Down
4 changes: 2 additions & 2 deletions src/syntax/GATs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ export Constant, AlgTerm, AlgType, AlgAST,
AlgTypeConstructor, AlgAccessor, AlgAxiom, AlgStruct, AlgDot, AlgFunction,
typesortsignature, sortsignature, getdecl,
GATSegment, GAT, GATContext, gettheory, gettypecontext, allmethods,
resolvemethod, resolvefield,
termcons,typecons, accessors, structs,
resolvemethod,
termcons,typecons, accessors, structs, primitive_sorts, struct_sorts,
equations, build_infer_expr, compile, sortcheck, allnames, sorts, sortname,
InCtx, TermInCtx, TypeInCtx, headof, argsof, methodof, bodyof, argcontext,
infer_type
Expand Down
23 changes: 7 additions & 16 deletions src/syntax/gats/gat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,12 @@ function allnames(theory::GAT; aliases=false)
end

sorts(theory::GAT) = theory.sorts
primitive_sorts(theory::GAT) =
filter(s->getvalue(theory[methodof(s)]) isa AlgTypeConstructor, sorts(theory))

# NOTE: AlgStruct is the only derived sort this returns.
struct_sorts(theory::GAT) =
filter(s->getvalue(theory[methodof(s)]) isa AlgStruct, sorts(theory))

function termcons(theory::GAT)
xs = Tuple{Ident, Ident}[]
Expand Down Expand Up @@ -231,20 +237,5 @@ else
Scopes.unsafe_pushbinding!(theory, Binding{Judgment}(name, AlgDeclaration()))
end

"""Get type associated with a field of a struct"""
function resolvefield(t::Context, method::Ident, field::Symbol)
str = getvalue(t[method])
str.fields[ident(str.fields; name=field)] |> getvalue
end

"""Get all structs in a theory"""
function structs(t::GAT)
res = AlgStruct[]
for s in sorts(t)
v = getvalue(t[methodof(s)])
if v isa AlgStruct
push!(res, v)
end
end
res
end
structs(t::GAT) = AlgStruct[getvalue(t[methodof(s)]) for s in struct_sorts(t)]
9 changes: 1 addition & 8 deletions src/syntax/gats/judgments.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ A declaration of a type constructor.
args::Vector{LID}
end

Scopes.getcontext(tc::AlgTypeConstructor) = tc.localcontext
Scopes.getcontext(tc::TrmTypConstructor) = tc.localcontext

abstract type AccessorField <: Judgment end

Expand All @@ -87,7 +87,6 @@ exist.
arg::Int
end

Scopes.getcontext(::AccessorField) = EmptyContext{AlgType}()

getdecl(acc::AccessorField) = acc.declaration

Expand All @@ -106,7 +105,6 @@ A declaration of a term constructor as a method of an `AlgFunction`.
type::Union{TypeScope,AlgType}
end

Scopes.getcontext(tc::AlgTermConstructor) = tc.localcontext

sortsignature(tc::TrmTypConstructor) =
AlgSort.(getvalue.(argsof(tc)))
Expand All @@ -122,7 +120,6 @@ A declaration of an axiom
equands::Vector{AlgTerm}
end

Scopes.getcontext(t::AlgAxiom) = t.localcontext

"""
`AlgSorts`
Expand Down Expand Up @@ -175,8 +172,6 @@ typesortsignature(tc::AlgStruct) =
AlgSort.(getvalue.(typeargsof(tc)))
argsof(t::AlgStruct) = getbindings(t.fields)

Scopes.getcontext(tc::AlgStruct) = tc.localcontext

"""
A shorthand for a function, such as "square(x) := x * x". It is relevant for
models but can be ignored by theory maps, as it is fully determined by other
Expand All @@ -188,5 +183,3 @@ judgments in the theory.
args::Vector{LID}
value::AlgTerm
end

Scopes.getcontext(tc::AlgFunction) = tc.localcontext
1 change: 1 addition & 0 deletions test/nonstdlib/Pushouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@ Universal input: Output
@test cospan(res) == Cospan(4, [1,1,2,3], [1,2,4])
@test ι₁(res) == [1,1,2,3]
@test universal(res, Cospan(4, [3,3,2,2],[3,2,1])) == [3,2,2,1]
@test_throws ErrorException universal(PushoutInt(4, [1,1,2,3],[1,2,3]), Cospan(4, [3,3,2,2],[3,2,1]))
end
1 change: 1 addition & 0 deletions test/syntax/GATs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ ob_decl = getvalue(thcat[O])

ObT = fromexpr(thcat, :Ob, AlgType)
ObS = AlgSort(ObT)
@test headof(ObS) == O
@test toexpr(GATContext(thcat), ObS) == :Ob


Expand Down

0 comments on commit c53cd4b

Please sign in to comment.