Skip to content

Commit

Permalink
Add support for term spreads
Browse files Browse the repository at this point in the history
  • Loading branch information
LPTK committed Nov 22, 2024
1 parent ac7c3d1 commit 719113d
Show file tree
Hide file tree
Showing 12 changed files with 142 additions and 75 deletions.
26 changes: 14 additions & 12 deletions hkmc2/shared/src/main/scala/hkmc2/bbml/bbML.scala
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,10 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
(FunType(bds.map(_._1), bodyTy, Bot), res, eff)
case Term.App(lhs, Term.Tup(rhs)) =>
val (lhsTy, lhsCtx, lhsEff) = typeCode(lhs)
val (rhsTy, rhsCtx, rhsEff) = rhs.foldLeft[(Ls[Type], Type, Type)]((Nil, Bot, Bot))((res, p) =>
val (ty, ctx, eff) = typeCode(p.value)
(ty :: res._1, res._2 | ctx, res._3 | eff)
)
val (rhsTy, rhsCtx, rhsEff) = rhs.foldLeft[(Ls[Type], Type, Type)]((Nil, Bot, Bot)):
case (res, p: Fld) =>
val (ty, ctx, eff) = typeCode(p.term)
(ty :: res._1, res._2 | ctx, res._3 | eff)
val resTy = freshVar
constrain(lhsTy, FunType(rhsTy.reverse, resTy, Bot)) // TODO: right
(resTy, lhsCtx | rhsCtx, lhsEff | rhsEff)
Expand Down Expand Up @@ -338,7 +338,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
(rhs, eff)

// TODO: t -> loc when toLoc is implemented
private def app(lhs: (GeneralType, Type), rhs: Ls[Fld], t: Term)
private def app(lhs: (GeneralType, Type), rhs: Ls[Elem], t: Term)
(using ctx: BbCtx)(using CCtx)
: (GeneralType, Type) =
lhs match
Expand All @@ -348,17 +348,19 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
then (error(msg"Incorrect number of arguments" -> t.toLoc :: Nil), Bot)
else
var resEff: Type = lhsEff | eff
rhs.lazyZip(params).foreach: (f, t) =>
val (ty, ef) = ascribe(f.value, t)
resEff |= ef
rhs.lazyZip(params).foreach:
case (f: Fld, t) =>
val (ty, ef) = ascribe(f.term, t)
resEff |= ef
(ret, resEff)
case (FunType(params, ret, eff), lhsEff) => app((PolyFunType(params, ret, eff), lhsEff), rhs, t)
case (ty: PolyType, eff) => app((instantiate(ty), eff), rhs, t)
case (funTy, lhsEff) =>
val (argTy, argEff) = rhs.flatMap(f =>
val (ty, eff) = typeCheck(f.value)
Left(ty) :: Right(eff) :: Nil
).partitionMap(x => x)
val (argTy, argEff) = rhs.flatMap:
case f: Fld =>
val (ty, eff) = typeCheck(f.term)
Left(ty) :: Right(eff) :: Nil
.partitionMap(x => x)
val effVar = freshVar
val retVar = freshVar
constrain(tryMkMono(funTy, t), FunType(argTy.map((tryMkMono(_, t))), retVar, effVar))
Expand Down
6 changes: 4 additions & 2 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ sealed abstract class Result
// type Local = LocalSymbol
type Local = Symbol

case class Call(fun: Path, args: Ls[Path]) extends Result
case class Call(fun: Path, args: Ls[Arg]) extends Result

case class Instantiate(cls: Path, args: Ls[Path]) extends Result

Expand All @@ -142,5 +142,7 @@ enum Value extends Path:
case This(sym: InnerSymbol) // TODO rm – just use Ref
case Lit(lit: Literal)
case Lam(params: Ls[Param], body: Block)
case Arr(elems: Ls[Path])
case Arr(elems: Ls[Arg])

case class Arg(spread: Bool, value: Path)

20 changes: 12 additions & 8 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala
Original file line number Diff line number Diff line change
Expand Up @@ -59,26 +59,30 @@ class Lowering(using TL, Raise, Elaborator.State):
case st.Asc(lhs, rhs) =>
term(lhs)(k)
case st.Tup(fs) =>
fs.foldRight[Ls[Path] => Block](args => k(Value.Arr(args.reverse)))((a, acc) =>
args => subTerm(a.value)(r => acc(r :: args))
)(Nil)
fs.foldRight[Ls[Arg] => Block](args => k(Value.Arr(args.reverse))){
case (a: Fld, acc) =>
args => subTerm(a.term)(r => acc(Arg(false, r) :: args))
case (s: Spd, acc) =>
args => subTerm(s.term)(r => acc(Arg(true, r) :: args))
}(Nil)
case st.Ref(sym) =>
k(subst(Value.Ref(sym)))
case st.App(f, arg) =>
arg match
case Tup(fs) =>
val as = fs.map:
case sem.Fld(sem.FldFlags.empty, value, N) => value
case sem.Fld(sem.FldFlags(false, false, false, true), value, N) => value
case sem.Fld(sem.FldFlags.empty, value, N) => false -> value
case sem.Fld(sem.FldFlags(false, false, false, true), value, N) => false -> value
case sem.Fld(flags, value, asc) =>
TODO("Other argument forms")
case spd: Spd => true -> spd.term
val l = new TempSymbol(S(t))
subTerm(f): fr =>
def rec(as: Ls[st], asr: Ls[Path]): Block = as match
def rec(as: Ls[Bool -> st], asr: Ls[Arg]): Block = as match
case Nil => k(Call(fr, asr.reverse))
case a :: as =>
case (spd, a) :: as =>
subTerm(a): ar =>
rec(as, ar :: asr)
rec(as, Arg(spd, ar) :: asr)
rec(as, Nil)
case _ =>
TODO("Other argument list forms")
Expand Down
3 changes: 3 additions & 0 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ class JSBuilder(using Elaborator.State) extends CodeBuilder:
summon[Scope].findThis_!(ts)
case _ => summon[Scope].lookup_!(l)

def result(a: Arg)(using Raise, Scope): Document =
if a.spread then doc"...${result(a.value)}" else result(a.value)

def result(r: Result)(using Raise, Scope): Document = r match
case Value.This(sym) => summon[Scope].findThis_!(sym)
case Value.Lit(Tree.StrLit(value)) => JSBuilder.makeStringLiteral(value)
Expand Down
69 changes: 42 additions & 27 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -264,25 +264,32 @@ extends Importer:
val args = rt match
case Term.Tup(fields) => S(fields)
case _ => N
val params = lt.symbol
.collect:
case sym: BlockMemberSymbol => sym.trmTree
.flatten
.collect:
case td: TermDef => td.paramLists.headOption
.flatten
for
(args, params) <- (args zip params)
(arg, param) <- (args zip params.fields)
do
val argMod = arg.flags.mod
val paramMod = param match
case Tree.TypeDef(Mod, _, N, N) => true
case _ => false
if argMod && !paramMod then raise:
ErrorReport:
msg"Only module parameters may receive module arguments (values)." ->
arg.toLoc :: Nil
if args.exists:
_.exists:
case spd: Spd => false
case fld: Fld => fld.flags.mod
then
val params = lt.symbol
.collect:
case sym: BlockMemberSymbol => sym.trmTree
.flatten
.collect:
case td: TermDef => td.paramLists.headOption
.flatten
for
(args, params) <- (args zip params)
(arg, param) <- (args zip params.fields)
do
arg match
case spd: Spd =>
TODO(spd)
case arg: Fld =>
val argMod = arg.flags.mod
val paramMod = param.isModuleModifier
if argMod && !paramMod then raise:
ErrorReport:
msg"Only module parameters may receive module arguments (values)." ->
arg.toLoc :: Nil

Term.App(lt, rt)(tree, sym)
case Sel(pre, nme) =>
Expand Down Expand Up @@ -375,18 +382,25 @@ extends Importer:
case Open(body) =>
raise(ErrorReport(msg"Illegal position for 'open' statement." -> tree.toLoc :: Nil))
Term.Error
case Spread(kw, kwLoc, body) =>
raise(ErrorReport(msg"Illegal position for '${kw.name}' spread operator." -> tree.toLoc :: Nil))
Term.Error
// case _ =>
// ???

def fld(tree: Tree): Ctxl[Fld] = tree match
def fld(tree: Tree): Ctxl[Elem] = tree match
case InfixApp(lhs, Keyword.`:`, rhs) =>
Fld(FldFlags.empty, term(lhs), S(term(rhs)))
case Spread(Keyword.`..`, _, S(trm)) =>
Spd(false, term(trm))
case Spread(Keyword.`...`, _, S(trm)) =>
Spd(true, term(trm))
case _ =>
val t = term(tree)
val flags = FldFlags.empty
if ModuleChecker.evalsToModule(t)
then Fld(flags.copy(mod = true), t, N)
else Fld(flags, t, N)
var flags = FldFlags.empty
if ModuleChecker.evalsToModule(t)
then flags = flags.copy(mod = true)
Fld(flags, t, N)

def unit: Term.Lit = Term.Lit(UnitLit(true))

Expand Down Expand Up @@ -880,9 +894,10 @@ extends Importer:
// fields.foreach(f => traverseType(pol)(f.value))
fields.foreach(traverseType(pol))
// case _ => ???
def traverseType(pol: Pol)(f: Fld): Unit =
traverseType(pol)(f.value)
f.asc.foreach(traverseType(pol))
def traverseType(pol: Pol)(f: Elem): Unit = f match
case f: Fld =>
traverseType(pol)(f.term)
f.asc.foreach(traverseType(pol))
def traverseType(pol: Pol)(f: Param): Unit =
f.sign.foreach(traverseType(pol))
end Elaborator
Expand Down
19 changes: 13 additions & 6 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ enum Term extends Statement:
case App(lhs: Term, rhs: Term)(val tree: Tree.App, val resSym: FlowSymbol)
case TyApp(lhs: Term, targs: Ls[Term])
case Sel(prefix: Term, nme: Tree.Ident)(val sym: Opt[Symbol])
case Tup(fields: Ls[Fld])(val tree: Tree.Tup)
case Tup(fields: Ls[Elem])(val tree: Tree.Tup)
case IfLike(kw: Keyword.`if`.type | Keyword.`while`.type, desugared: Split)(val normalized: Split)
case Lam(params: Ls[Param], body: Term)
case FunTy(lhs: Term, rhs: Term, eff: Opt[Term])
Expand Down Expand Up @@ -85,7 +85,7 @@ sealed trait Statement extends AutoLocated:
case FunTy(lhs, rhs, eff) => lhs :: rhs :: eff.toList
case TyApp(pre, tarsg) => pre :: tarsg
case Sel(pre, _) => pre :: Nil
case Tup(fields) => fields.map(_.value)
case Tup(fields) => fields.flatMap(_.subTerms)
case IfLike(_, body) => body.subTerms
case Lam(params, body) => body :: Nil
case Blk(stats, res) => stats.flatMap(_.subTerms) ::: res :: Nil
Expand Down Expand Up @@ -305,7 +305,14 @@ final case class FldFlags(mut: Bool, spec: Bool, genGetter: Bool, mod: Bool):
flags.mkString(" ")
override def toString: String = "" + showDbg + ""

final case class Fld(flags: FldFlags, value: Term, asc: Opt[Term]) extends FldImpl
sealed abstract class Elem:
def subTerms: Ls[Term] = this match
case Fld(_, term, asc) => term :: asc.toList
case Spd(_, term) => term :: Nil
def showDbg: Str
final case class Fld(flags: FldFlags, term: Term, asc: Opt[Term]) extends Elem with FldImpl
final case class Spd(eager: Bool, term: Term) extends Elem:
def showDbg: Str = (if eager then "..." else "..") + term.showDbg

final case class TyParam(flags: FldFlags, vce: Opt[Bool], sym: VarSymbol) extends Declaration:

Expand Down Expand Up @@ -341,11 +348,11 @@ final case class ParamList(flags: ParamListFlags, params: Ls[Param]):

trait FldImpl extends AutoLocated:
self: Fld =>
def children: Ls[Located] = self.value :: self.asc.toList ::: Nil
def showDbg: Str = flags.showDbg + self.value.showDbg
def children: Ls[Located] = self.term :: self.asc.toList ::: Nil
def showDbg: Str = flags.showDbg + self.term.showDbg
def describe: Str =
(if self.flags.spec then "specialized " else "") +
(if self.flags.mut then "mutable " else "") +
self.value.describe
self.term.describe


4 changes: 4 additions & 0 deletions hkmc2/shared/src/main/scala/hkmc2/syntax/Tree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ enum Tree extends AutoLocated:
case InfixApp(lhs: Ident, Keyword.`:`, rhs) => (lhs, S(rhs)) :: Nil
case App(Ident(","), Tup(ps)) => ps.flatMap(_.param)
case TermDef(ImmutVal, inner, _) => inner.param

def isModuleModifier: Bool = this match
case Tree.TypeDef(Mod, _, N, N) => true
case _ => false

object Tree:
object Block:
Expand Down
11 changes: 7 additions & 4 deletions hkmc2/shared/src/main/scala/hkmc2/typing/TypeChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,9 @@ class TypeChecker(using Raise, Elaborator.State):
msg"Expected ${ps.size.toString} arguments, but got ${
args.size.toString}" -> t.toLoc :: Nil))
// val p1 = ps.zip(args).map: (p, a) =>
val p1 = ps.zip(args).foreach: (p, a) =>
constrain(P.enterIf(typeProd(a.value), ts, r.refNum, rc), C.Flow(p.sym.asInstanceOf/*FIXME*/))
val p1 = ps.zip(args).foreach:
case (p, a: Fld) =>
constrain(P.enterIf(typeProd(a.term), ts, r.refNum, rc), C.Flow(p.sym.asInstanceOf/*FIXME*/))
constrain(P.Flow(td.resSym), C.Flow(app.resSym))
// P.Flow(td.resSym)
P.Flow(app.resSym)
Expand All @@ -84,7 +85,8 @@ class TypeChecker(using Raise, Elaborator.State):
// case Ref(ClassSymbol(Ident("true"))) =>
// P.Ctor(LitSymbol(Tree.UnitLit(true)), Nil)
case Tup(fields) =>
P.Ctor(TupSymbol(S(fields.size)), fields.map(f => typeProd(f.value)))
P.Ctor(TupSymbol(S(fields.size)), fields.map:
case f: Fld => typeProd(f.term))
case Error =>
P.Ctor(Extr(false), Nil)
case _ => P.Flow(FlowSymbol("TODO")) // TODO
Expand All @@ -98,7 +100,8 @@ class TypeChecker(using Raise, Elaborator.State):
case Ref(cls: ClassSymbol) => C.Ctor(cls, Nil)
case Ref(ts: TermSymbol) => ???
case Tup(fields) =>
C.Ctor(TupSymbol(S(fields.size)), fields.map(f => typeCons(f.value)))
C.Ctor(TupSymbol(S(fields.size)), fields.map:
case f: Fld => typeCons(f.term))
// case _ => TODO(t)

case class CCtx(path: Ls[L])
Expand Down
4 changes: 3 additions & 1 deletion hkmc2/shared/src/test/mlscript/codegen/BasicTerms.mls
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ log("Hi")
//│ qual = Ref of globalThis:import#Prelude
//│ name = Ident of "log"
//│ args = Ls of
//│ Lit of StrLit of "Hi"
//│ Arg:
//│ spread = false
//│ value = Lit of StrLit of "Hi"
//│ rest = Return: \
//│ res = Lit of IntLit of 2
//│ implct = true
Expand Down
25 changes: 25 additions & 0 deletions hkmc2/shared/src/test/mlscript/codegen/Spreads.mls
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
:js


let a = [1, 2, 3]
//│ a = [ 1, 2, 3 ]

let b = [0, ...a]
//│ b = [ 0, 1, 2, 3 ]


fun foo(w, x, y, z) = [w, x, y, z]

foo(...b)
//│ = [ 0, 1, 2, 3 ]

foo(0, ...a)
//│ = [ 0, 1, 2, 3 ]

:sjs
foo(1, ...[2, 3], 4)
//│ JS:
//│ this.foo(1, ...[ 2, 3 ], 4)
//│ = [ 1, 2, 3, 4 ]


Loading

0 comments on commit 719113d

Please sign in to comment.