Skip to content

Commit

Permalink
Patches for BbML (#240)
Browse files Browse the repository at this point in the history
* Fix extrusion cache

* Fix boolean if type check

* Fix quotes

* Fix run function type

* Fix context envs

* Slightly improve error messages

* Fix missing operators

* Cache skolem extrusion

* Rename prelude file
  • Loading branch information
NeilKleistGao authored Nov 25, 2024
1 parent 719113d commit efb66c1
Show file tree
Hide file tree
Showing 29 changed files with 566 additions and 494 deletions.
6 changes: 3 additions & 3 deletions hkmc2/jvm/src/test/scala/hkmc2/BbmlDiffMaker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ import hkmc2.bbml.*

abstract class BbmlDiffMaker extends JSBackendDiffMaker:

val bbPredefFile = file / os.up / os.RelPath("bbPredef.mls")
val bbPreludeFile = file / os.up / os.RelPath("bbPrelude.mls")

val bbmlOpt = new NullaryCommand("bbml"):
override def onSet(): Unit =
super.onSet()
if isGlobal then typeCheck.disable.isGlobal = true
typeCheck.disable.setCurrentValue(())
if file =/= bbPredefFile then
importFile(bbPredefFile, verbose = false)
if file =/= bbPreludeFile then
importFile(bbPreludeFile, verbose = false)


lazy val bbCtx =
Expand Down
19 changes: 12 additions & 7 deletions hkmc2/shared/src/main/scala/hkmc2/bbml/ConstraintSolver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ class ConstraintSolver(infVarState: InfVarUid.State, tl: TraceLogger):

import hkmc2.bbml.NormalForm.*

private def freshXVar(lvl: Int): InfVar = InfVar(lvl, infVarState.nextUid, new VarState(), false)
private def freshXVar(lvl: Int, hint: Option[Str]): InfVar = InfVar(lvl, infVarState.nextUid, new VarState(), false)(hint)

def extrude(ty: Type)(using lvl: Int, pol: Bool, cache: ExtrudeCache): Type =
def extrude(ty: Type)(using lvl: Int, pol: Bool, cache: ExtrudeCache, bbctx: BbCtx, cctx: CCtx, tl: TL): Type =
trace[Type](s"Extruding[${printPol(pol)}] $ty", r => s"~> $r"):
if ty.lvl <= lvl then ty else ty.toBasic/*TODO improve extrude directly*/ match
case ClassLikeType(sym, targs) =>
Expand All @@ -49,13 +49,18 @@ class ConstraintSolver(infVarState: InfVarUid.State, tl: TraceLogger):
case t: Type => Wildcard(extrude(t)(using lvl, !pol), extrude(t))
})
case v @ InfVar(_, uid, state, true) => // * skolem
if pol then
state.upperBounds.foldLeft[Type](Top)(_ & _)
else
state.lowerBounds.foldLeft[Type](Bot)(_ | _)
cache.getOrElse(uid -> pol, {
val nv = freshXVar(lvl, v.hint)
cache += uid -> pol -> nv
if pol then
constrainImpl(state.upperBounds.foldLeft[Type](Top)(_ & _), nv)
else
constrainImpl(nv, state.lowerBounds.foldLeft[Type](Bot)(_ | _))
nv
})
case v @ InfVar(_, uid, _, false) =>
cache.getOrElse(uid -> pol, {
val nv = freshXVar(lvl)
val nv = freshXVar(lvl, v.hint)
cache += uid -> pol -> nv
if pol then
v.state.upperBounds ::= nv
Expand Down
4 changes: 2 additions & 2 deletions hkmc2/shared/src/main/scala/hkmc2/bbml/NormalForm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ extends NormalForm with CachedBasicType:
}.foldLeft[Opt[Ls[(InfVar, Bool)]]](S(Nil))((res, p) => (res, p) match { // * None -> bot
case (N, _) => N
case (S(Nil), p) => S(p :: Nil)
case (S((InfVar(v, uid1, s, k), p1) :: tail), (InfVar(_, uid2, _, _), p2)) if uid1 === uid2 =>
if p1 === p2 then S((InfVar(v, uid1, s, k), p1) :: tail) else N
case (S((lhs @ InfVar(v, uid1, s, k), p1) :: tail), (InfVar(_, uid2, _, _), p2)) if uid1 === uid2 =>
if p1 === p2 then S((InfVar(v, uid1, s, k)(lhs.hint), p1) :: tail) else N
case (S(head :: tail), p) => S(p :: head :: tail)
})
vars match
Expand Down
105 changes: 58 additions & 47 deletions hkmc2/shared/src/main/scala/hkmc2/bbml/bbML.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,10 @@ final case class BbCtx(
ctx: Ctx,
parent: Option[BbCtx],
lvl: Int,
clsDefs: HashMap[Str, ClassDef],
env: HashMap[Uid[Symbol], GeneralType],
quoteSkolemEnv: HashMap[Uid[Symbol], InfVar], // * SkolemTag for variables in quasiquotes
env: HashMap[Uid[Symbol], GeneralType]
):
def +=(p: Symbol -> GeneralType): Unit = env += p._1.uid -> p._2
def get(sym: Symbol): Option[GeneralType] = env.get(sym.uid) orElse parent.dlof(_.get(sym))(None)
def *=(cls: ClassDef): Unit = clsDefs += cls.sym.id.name -> cls
def getCls(name: Str): Option[TypeSymbol] =
for
elem <- ctx.get(name)
Expand All @@ -38,10 +35,8 @@ final case class BbCtx(
yield cls
def &=(p: (Symbol, Type, InfVar)): Unit =
env += p._1.uid -> BbCtx.varTy(p._2, p._3)(using this)
quoteSkolemEnv += p._1.uid -> p._3
def getSk(sym: Symbol): Option[Type] = quoteSkolemEnv.get(sym.uid) orElse parent.dlof(_.getSk(sym))(None)
def nest: BbCtx = copy(parent = Some(this))
def nextLevel: BbCtx = copy(lvl = lvl + 1)
def nest: BbCtx = copy(parent = Some(this), env = HashMap.empty)
def nextLevel: BbCtx = copy(parent = Some(this), lvl = lvl + 1, env = HashMap.empty)

given (using ctx: BbCtx): Raise = ctx.raise

Expand All @@ -62,7 +57,7 @@ object BbCtx:
def refTy(ct: Type, sk: Type)(using ctx: BbCtx): Type =
ClassLikeType(ctx.getCls("Ref").get, Wildcard(ct, ct) :: Wildcard.out(sk) :: Nil)
def init(raise: Raise)(using Elaborator.State, Elaborator.Ctx): BbCtx =
new BbCtx(raise, summon, None, 1, HashMap.empty, HashMap.empty, HashMap.empty)
new BbCtx(raise, summon, None, 1, HashMap.empty)
end BbCtx


Expand All @@ -72,13 +67,13 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
private val infVarState = new InfVarUid.State()
private val solver = new ConstraintSolver(infVarState, tl)

private def freshSkolem(using ctx: BbCtx): InfVar =
InfVar(ctx.lvl, infVarState.nextUid, new VarState(), true)
private def freshVar(using ctx: BbCtx): InfVar =
InfVar(ctx.lvl, infVarState.nextUid, new VarState(), false)
private def freshWildcard(using ctx: BbCtx) =
val in = freshVar
val out = freshVar
private def freshSkolem(hint: Option[Str])(using ctx: BbCtx): InfVar =
InfVar(ctx.lvl, infVarState.nextUid, new VarState(), true)(hint)
private def freshVar(hint: Option[Str])(using ctx: BbCtx): InfVar =
InfVar(ctx.lvl, infVarState.nextUid, new VarState(), false)(hint)
private def freshWildcard(hint: Option[Str])(using ctx: BbCtx) =
val in = freshVar(hint)
val out = freshVar(hint)
// in.state.upperBounds ::= out // * Not needed for soundness; complicates inferred types
Wildcard(in, out)

Expand Down Expand Up @@ -157,7 +152,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
private def genPolyType(tvs: Ls[QuantVar], body: => GeneralType)(using ctx: BbCtx, cctx: CCtx) =
val bds = tvs.map:
case qv @ QuantVar(sym, ub, lb) =>
val tv = freshVar
val tv = freshVar(S(sym.name))
ctx += sym -> tv // TODO: a type var symbol may be better...
tv -> qv
bds.foreach:
Expand All @@ -176,7 +171,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL):

private def instantiate(ty: PolyType)(using ctx: BbCtx): GeneralType = ty.instantiate(infVarState.nextUid, ctx.lvl)(tl)

private def extrude(ty: GeneralType)(using ctx: BbCtx, pol: Bool): GeneralType = ty match
private def extrude(ty: GeneralType)(using ctx: BbCtx, pol: Bool, cctx: CCtx): GeneralType = ty match
case ty: Type => solver.extrude(ty)(using ctx.lvl, pol, HashMap.empty)
case PolyType(tvs, body) => PolyType(tvs, extrude(body))
case PolyFunType(args, ret, eff) =>
Expand All @@ -185,7 +180,6 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
private def constrain(lhs: Type, rhs: Type)(using ctx: BbCtx, cctx: CCtx): Unit =
solver.constrain(lhs, rhs)

// TODO: content type
private def typeCode(code: Term)(using ctx: BbCtx): (Type, Type, Type) =
given CCtx = CCtx.init(code, N)
code match
Expand All @@ -201,12 +195,12 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
given BbCtx = nestCtx
val bds = params.map:
case Param(_, sym, _) =>
val tv = freshVar
val sk = freshSkolem
val tv = freshVar(S(sym.name))
val sk = freshSkolem(S(sym.name))
nestCtx &= (sym, tv, sk)
(tv, sk)
val (bodyTy, ctxTy, eff) = typeCode(body)
val res = freshVar(using ctx)
val res = freshVar(N)(using ctx)
constrain(ctxTy, bds.foldLeft[Type](res)((res, bd) => res | bd._2))
(FunType(bds.map(_._1), bodyTy, Bot), res, eff)
case Term.App(lhs, Term.Tup(rhs)) =>
Expand All @@ -215,26 +209,29 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
case (res, p: Fld) =>
val (ty, ctx, eff) = typeCode(p.term)
(ty :: res._1, res._2 | ctx, res._3 | eff)
val resTy = freshVar
val resTy = freshVar(N)
constrain(lhsTy, FunType(rhsTy.reverse, resTy, Bot)) // TODO: right
(resTy, lhsCtx | rhsCtx, lhsEff | rhsEff)
case sel @ Term.Sel(Term.Ref(_: TopLevelSymbol), _) if sel.symbol.isDefined =>
val (opTy, eff) = typeCheck(Ref(sel.symbol.get)(sel.nme, 666)) // FIXME 666
(tryMkMono(opTy, sel), Bot, eff)
case Term.Unquoted(body) =>
val (ty, eff) = typeCheck(body)
val tv = freshVar
val cr = freshVar
val tv = freshVar(N)
val cr = freshVar(N)
constrain(tryMkMono(ty, body), BbCtx.codeTy(tv, cr))
(tv, cr, eff)
case Term.Blk(LetDecl(sym) :: DefineVar(sym2, rhs) :: Nil, body) if sym2 is sym => // TODO: more than one!!
val (rhsTy, rhsCtx, rhsEff) = typeCode(rhs)(using ctx)
val nestCtx = ctx.nextLevel
given BbCtx = nestCtx
val sk = freshSkolem
val sk = freshSkolem(S(sym.nme))
nestCtx &= (sym, rhsTy, sk)
val (bodyTy, bodyCtx, bodyEff) = typeCode(body)
val res = freshVar(using ctx)
val res = freshVar(N)(using ctx)
constrain(bodyCtx, sk | res)
(bodyTy, rhsCtx | res, rhsEff | bodyEff)
case Term.IfLike(Keyword.`if`, Split.Cons(Branch(cond, Pattern.Lit(BoolLit(true)), Split.Else(cons)), Split.Else(alts))) =>
case Term.IfLike(Keyword.`if`, Split.Let(_, cond, Split.Cons(Branch(_, Pattern.Lit(BoolLit(true)), Split.Else(cons)), Split.Else(alts)))) =>
val (condTy, condCtx, condEff) = typeCode(cond)
val (consTy, consCtx, consEff) = typeCode(cons)
val (altsTy, altsCtx, altsEff) = typeCode(alts)
Expand All @@ -252,7 +249,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
()
case N =>
given BbCtx = ctx.nextLevel
val funTyV = freshVar
val funTyV = freshVar(S(sym.nme))
pctx += sym -> funTyV // for recursive functions
val (res, _) = typeCheck(lam)
val funTy = tryMkMono(res, lam)
Expand All @@ -266,10 +263,10 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
: (GeneralType, Type) =
split match
case Split.Cons(Branch(scrutinee, Pattern.ClassLike(sym, _, _, _), cons), alts) =>
// * Pattern matching
// * Pattern matching for classes
val (clsTy, tv, emptyTy) = ctx.getCls(sym.nme).flatMap(_.defn) match
case S(cls) =>
(ClassLikeType(sym, cls.tparams.map(_ => freshWildcard)), freshVar, ClassLikeType(sym, cls.tparams.map(_ => Wildcard.empty)))
(ClassLikeType(sym, cls.tparams.map(_ => freshWildcard(N))), (freshVar(N)), ClassLikeType(sym, cls.tparams.map(_ => Wildcard.empty)))
case _ =>
error(msg"Cannot match ${scrutinee.toString} as ${sym.toString}" -> split.toLoc :: Nil)
(Bot, Bot, Bot)
Expand All @@ -286,6 +283,23 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
val (altsTy, altsEff) = typeSplit(alts, sign)(using nestCtx2)
val allEff = scrutineeEff | (consEff | altsEff)
(sign.getOrElse(tryMkMono(consTy, cons) | tryMkMono(altsTy, alts)), allEff)
// * Pattern matching for literals
case Split.Cons(Branch(scrutinee, Pattern.Lit(lit), cons), alts) =>
val (scrutineeTy, scrutineeEff) = typeCheck(scrutinee)
val litTy = lit match
case _: Tree.BoolLit => BbCtx.boolTy
case _: Tree.IntLit => BbCtx.intTy
case _: Tree.DecLit => BbCtx.numTy
case _: Tree.StrLit => BbCtx.strTy
case _: Tree.UnitLit => Top

constrain(tryMkMono(scrutineeTy, scrutinee), litTy)
val nestCtx1 = ctx.nest
val nestCtx2 = ctx.nest
val (consTy, consEff) = typeSplit(cons, sign)(using nestCtx1)
val (altsTy, altsEff) = typeSplit(alts, sign)(using nestCtx2)
val allEff = scrutineeEff | (consEff | altsEff)
(sign.getOrElse(tryMkMono(consTy, cons) | tryMkMono(altsTy, alts)), allEff)
case Split.Let(name, term, tail) =>
val nestCtx = ctx.nest
given BbCtx = nestCtx
Expand Down Expand Up @@ -361,8 +375,8 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
val (ty, eff) = typeCheck(f.term)
Left(ty) :: Right(eff) :: Nil
.partitionMap(x => x)
val effVar = freshVar
val retVar = freshVar
val effVar = freshVar(N)
val retVar = freshVar(N)
constrain(tryMkMono(funTy, t), FunType(argTy.map((tryMkMono(_, t))), retVar, effVar))
(retVar, argEff.foldLeft[Type](effVar | lhsEff)((res, e) => res | e))

Expand Down Expand Up @@ -394,8 +408,6 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
(error(msg"Variable not found: ${sym.nme}"
-> t.toLoc :: Nil), Bot)
case Blk(stats, res) =>
val nestCtx = ctx.nest
given BbCtx = nestCtx
val effBuff = ListBuffer.empty[Type]
def goStats(stats: Ls[Statement]): Unit = stats match
case Nil => ()
Expand All @@ -406,7 +418,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
require(sym2 is sym)
val (rhsTy, eff) = typeCheck(rhs)
effBuff += eff
nestCtx += sym -> rhsTy
ctx += sym -> rhsTy
goStats(stats)
case TermDefinition(_, Fun, sym, ParamList(_, ps) :: Nil, sig, Some(body), _, _) :: stats =>
typeFunDef(sym, Term.Lam(ps, body), sig, ctx)
Expand All @@ -418,7 +430,6 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
ctx += sym -> typeType(sig)
goStats(stats)
case (clsDef: ClassDef) :: stats =>
ctx *= clsDef
goStats(stats)
goStats(stats)
val (ty, eff) = typeCheck(res)
Expand All @@ -434,7 +445,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
given BbCtx = nestCtx
val tvs = params.map:
case Param(_, sym, sign) =>
val ty = sign.map(s => typeType(s)(using nestCtx)).getOrElse(freshVar)
val ty = sign.map(s => typeType(s)(using nestCtx)).getOrElse(freshVar(S(sym.nme)))
nestCtx += sym -> ty
ty
val (bodyTy, eff) = typeCheck(body)
Expand All @@ -446,7 +457,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
val map = HashMap[Uid[Symbol], TypeArg]()
val targs = clsDfn.tparams.map {
case TyParam(_, _, targ) =>
val ty = freshWildcard
val ty = freshWildcard(N)
map += targ.uid -> ty
ty
}
Expand All @@ -471,12 +482,12 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
val map = HashMap[Uid[Symbol], TypeArg]()
val targs = clsDfn.tparams.map {
case TyParam(_, S(_), targ) =>
val ty = freshVar
val ty = freshVar(N)
map += targ.uid -> ty
ty
case TyParam(_, N, targ) =>
// val ty = freshWildcard // FIXME probably not correct
val ty = freshVar
val ty = freshVar(N)
map += targ.uid -> ty
ty
}
Expand All @@ -498,28 +509,28 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
case Term.Region(sym, body) =>
val nestCtx = ctx.nextLevel
given BbCtx = nestCtx
val sk = freshSkolem
val sk = freshSkolem(S(sym.nme))
nestCtx += sym -> BbCtx.regionTy(sk)
val (res, eff) = typeCheck(body)
val tv = freshVar(using ctx)
val tv = freshVar(N)(using ctx)
constrain(eff, tv | sk)
(extrude(res)(using ctx, true), tv | allocType)
case Term.RegRef(reg, value) =>
val (regTy, regEff) = typeCheck(reg)
val (valTy, valEff) = typeCheck(value)
val sk = freshVar
val sk = freshVar(N)
constrain(tryMkMono(regTy, reg), BbCtx.regionTy(sk))
(BbCtx.refTy(tryMkMono(valTy, value), sk), sk | (regEff | valEff))
case Term.Assgn(lhs, rhs) =>
val (lhsTy, lhsEff) = typeCheck(lhs)
val (rhsTy, rhsEff) = typeCheck(rhs)
val sk = freshVar
val sk = freshVar(N)
constrain(tryMkMono(lhsTy, lhs), BbCtx.refTy(tryMkMono(rhsTy, rhs), sk))
(tryMkMono(rhsTy, rhs), sk | (lhsEff | rhsEff))
case Term.Deref(ref) =>
val (refTy, refEff) = typeCheck(ref)
val sk = freshVar
val ctnt = freshVar
val sk = freshVar(N)
val ctnt = freshVar(N)
constrain(tryMkMono(refTy, ref), BbCtx.refTy(ctnt, sk))
(ctnt, sk | refEff)
case Term.Quoted(body) =>
Expand Down
Loading

0 comments on commit efb66c1

Please sign in to comment.