From 719113df35e57128a505b96a789aa677ba3f772c Mon Sep 17 00:00:00 2001 From: Lionel Parreaux Date: Fri, 22 Nov 2024 23:33:43 +0800 Subject: [PATCH] Add support for term spreads --- .../src/main/scala/hkmc2/bbml/bbML.scala | 26 +++---- .../src/main/scala/hkmc2/codegen/Block.scala | 6 +- .../main/scala/hkmc2/codegen/Lowering.scala | 20 +++--- .../scala/hkmc2/codegen/js/JSBuilder.scala | 3 + .../scala/hkmc2/semantics/Elaborator.scala | 69 +++++++++++-------- .../src/main/scala/hkmc2/semantics/Term.scala | 19 +++-- .../src/main/scala/hkmc2/syntax/Tree.scala | 4 ++ .../main/scala/hkmc2/typing/TypeChecker.scala | 11 +-- .../src/test/mlscript/codegen/BasicTerms.mls | 4 +- .../src/test/mlscript/codegen/Spreads.mls | 25 +++++++ .../mlscript/ucs/papers/OperatorSplit.mls | 22 +++--- .../mlscript/ucs/syntax/NestedOpSplits.mls | 8 +-- 12 files changed, 142 insertions(+), 75 deletions(-) create mode 100644 hkmc2/shared/src/test/mlscript/codegen/Spreads.mls diff --git a/hkmc2/shared/src/main/scala/hkmc2/bbml/bbML.scala b/hkmc2/shared/src/main/scala/hkmc2/bbml/bbML.scala index e51e68359..d2fc0a387 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/bbml/bbML.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/bbml/bbML.scala @@ -211,10 +211,10 @@ class BBTyper(using elState: Elaborator.State, tl: TL): (FunType(bds.map(_._1), bodyTy, Bot), res, eff) case Term.App(lhs, Term.Tup(rhs)) => val (lhsTy, lhsCtx, lhsEff) = typeCode(lhs) - val (rhsTy, rhsCtx, rhsEff) = rhs.foldLeft[(Ls[Type], Type, Type)]((Nil, Bot, Bot))((res, p) => - val (ty, ctx, eff) = typeCode(p.value) - (ty :: res._1, res._2 | ctx, res._3 | eff) - ) + val (rhsTy, rhsCtx, rhsEff) = rhs.foldLeft[(Ls[Type], Type, Type)]((Nil, Bot, Bot)): + case (res, p: Fld) => + val (ty, ctx, eff) = typeCode(p.term) + (ty :: res._1, res._2 | ctx, res._3 | eff) val resTy = freshVar constrain(lhsTy, FunType(rhsTy.reverse, resTy, Bot)) // TODO: right (resTy, lhsCtx | rhsCtx, lhsEff | rhsEff) @@ -338,7 +338,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL): (rhs, eff) // TODO: t -> loc when toLoc is implemented - private def app(lhs: (GeneralType, Type), rhs: Ls[Fld], t: Term) + private def app(lhs: (GeneralType, Type), rhs: Ls[Elem], t: Term) (using ctx: BbCtx)(using CCtx) : (GeneralType, Type) = lhs match @@ -348,17 +348,19 @@ class BBTyper(using elState: Elaborator.State, tl: TL): then (error(msg"Incorrect number of arguments" -> t.toLoc :: Nil), Bot) else var resEff: Type = lhsEff | eff - rhs.lazyZip(params).foreach: (f, t) => - val (ty, ef) = ascribe(f.value, t) - resEff |= ef + rhs.lazyZip(params).foreach: + case (f: Fld, t) => + val (ty, ef) = ascribe(f.term, t) + resEff |= ef (ret, resEff) case (FunType(params, ret, eff), lhsEff) => app((PolyFunType(params, ret, eff), lhsEff), rhs, t) case (ty: PolyType, eff) => app((instantiate(ty), eff), rhs, t) case (funTy, lhsEff) => - val (argTy, argEff) = rhs.flatMap(f => - val (ty, eff) = typeCheck(f.value) - Left(ty) :: Right(eff) :: Nil - ).partitionMap(x => x) + val (argTy, argEff) = rhs.flatMap: + case f: Fld => + val (ty, eff) = typeCheck(f.term) + Left(ty) :: Right(eff) :: Nil + .partitionMap(x => x) val effVar = freshVar val retVar = freshVar constrain(tryMkMono(funTy, t), FunType(argTy.map((tryMkMono(_, t))), retVar, effVar)) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index 7d4384560..f8d5ccb01 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -129,7 +129,7 @@ sealed abstract class Result // type Local = LocalSymbol type Local = Symbol -case class Call(fun: Path, args: Ls[Path]) extends Result +case class Call(fun: Path, args: Ls[Arg]) extends Result case class Instantiate(cls: Path, args: Ls[Path]) extends Result @@ -142,5 +142,7 @@ enum Value extends Path: case This(sym: InnerSymbol) // TODO rm – just use Ref case Lit(lit: Literal) case Lam(params: Ls[Param], body: Block) - case Arr(elems: Ls[Path]) + case Arr(elems: Ls[Arg]) + +case class Arg(spread: Bool, value: Path) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index 8ee60e22a..634581ff0 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -59,26 +59,30 @@ class Lowering(using TL, Raise, Elaborator.State): case st.Asc(lhs, rhs) => term(lhs)(k) case st.Tup(fs) => - fs.foldRight[Ls[Path] => Block](args => k(Value.Arr(args.reverse)))((a, acc) => - args => subTerm(a.value)(r => acc(r :: args)) - )(Nil) + fs.foldRight[Ls[Arg] => Block](args => k(Value.Arr(args.reverse))){ + case (a: Fld, acc) => + args => subTerm(a.term)(r => acc(Arg(false, r) :: args)) + case (s: Spd, acc) => + args => subTerm(s.term)(r => acc(Arg(true, r) :: args)) + }(Nil) case st.Ref(sym) => k(subst(Value.Ref(sym))) case st.App(f, arg) => arg match case Tup(fs) => val as = fs.map: - case sem.Fld(sem.FldFlags.empty, value, N) => value - case sem.Fld(sem.FldFlags(false, false, false, true), value, N) => value + case sem.Fld(sem.FldFlags.empty, value, N) => false -> value + case sem.Fld(sem.FldFlags(false, false, false, true), value, N) => false -> value case sem.Fld(flags, value, asc) => TODO("Other argument forms") + case spd: Spd => true -> spd.term val l = new TempSymbol(S(t)) subTerm(f): fr => - def rec(as: Ls[st], asr: Ls[Path]): Block = as match + def rec(as: Ls[Bool -> st], asr: Ls[Arg]): Block = as match case Nil => k(Call(fr, asr.reverse)) - case a :: as => + case (spd, a) :: as => subTerm(a): ar => - rec(as, ar :: asr) + rec(as, Arg(spd, ar) :: asr) rec(as, Nil) case _ => TODO("Other argument list forms") diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala index 5f456d841..a0fb99651 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala @@ -68,6 +68,9 @@ class JSBuilder(using Elaborator.State) extends CodeBuilder: summon[Scope].findThis_!(ts) case _ => summon[Scope].lookup_!(l) + def result(a: Arg)(using Raise, Scope): Document = + if a.spread then doc"...${result(a.value)}" else result(a.value) + def result(r: Result)(using Raise, Scope): Document = r match case Value.This(sym) => summon[Scope].findThis_!(sym) case Value.Lit(Tree.StrLit(value)) => JSBuilder.makeStringLiteral(value) diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala index 16c16726e..eb7175bd5 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala @@ -264,25 +264,32 @@ extends Importer: val args = rt match case Term.Tup(fields) => S(fields) case _ => N - val params = lt.symbol - .collect: - case sym: BlockMemberSymbol => sym.trmTree - .flatten - .collect: - case td: TermDef => td.paramLists.headOption - .flatten - for - (args, params) <- (args zip params) - (arg, param) <- (args zip params.fields) - do - val argMod = arg.flags.mod - val paramMod = param match - case Tree.TypeDef(Mod, _, N, N) => true - case _ => false - if argMod && !paramMod then raise: - ErrorReport: - msg"Only module parameters may receive module arguments (values)." -> - arg.toLoc :: Nil + if args.exists: + _.exists: + case spd: Spd => false + case fld: Fld => fld.flags.mod + then + val params = lt.symbol + .collect: + case sym: BlockMemberSymbol => sym.trmTree + .flatten + .collect: + case td: TermDef => td.paramLists.headOption + .flatten + for + (args, params) <- (args zip params) + (arg, param) <- (args zip params.fields) + do + arg match + case spd: Spd => + TODO(spd) + case arg: Fld => + val argMod = arg.flags.mod + val paramMod = param.isModuleModifier + if argMod && !paramMod then raise: + ErrorReport: + msg"Only module parameters may receive module arguments (values)." -> + arg.toLoc :: Nil Term.App(lt, rt)(tree, sym) case Sel(pre, nme) => @@ -375,18 +382,25 @@ extends Importer: case Open(body) => raise(ErrorReport(msg"Illegal position for 'open' statement." -> tree.toLoc :: Nil)) Term.Error + case Spread(kw, kwLoc, body) => + raise(ErrorReport(msg"Illegal position for '${kw.name}' spread operator." -> tree.toLoc :: Nil)) + Term.Error // case _ => // ??? - def fld(tree: Tree): Ctxl[Fld] = tree match + def fld(tree: Tree): Ctxl[Elem] = tree match case InfixApp(lhs, Keyword.`:`, rhs) => Fld(FldFlags.empty, term(lhs), S(term(rhs))) + case Spread(Keyword.`..`, _, S(trm)) => + Spd(false, term(trm)) + case Spread(Keyword.`...`, _, S(trm)) => + Spd(true, term(trm)) case _ => val t = term(tree) - val flags = FldFlags.empty - if ModuleChecker.evalsToModule(t) - then Fld(flags.copy(mod = true), t, N) - else Fld(flags, t, N) + var flags = FldFlags.empty + if ModuleChecker.evalsToModule(t) + then flags = flags.copy(mod = true) + Fld(flags, t, N) def unit: Term.Lit = Term.Lit(UnitLit(true)) @@ -880,9 +894,10 @@ extends Importer: // fields.foreach(f => traverseType(pol)(f.value)) fields.foreach(traverseType(pol)) // case _ => ??? - def traverseType(pol: Pol)(f: Fld): Unit = - traverseType(pol)(f.value) - f.asc.foreach(traverseType(pol)) + def traverseType(pol: Pol)(f: Elem): Unit = f match + case f: Fld => + traverseType(pol)(f.term) + f.asc.foreach(traverseType(pol)) def traverseType(pol: Pol)(f: Param): Unit = f.sign.foreach(traverseType(pol)) end Elaborator diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala index bf0ec2727..36e4bcbbc 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala @@ -16,7 +16,7 @@ enum Term extends Statement: case App(lhs: Term, rhs: Term)(val tree: Tree.App, val resSym: FlowSymbol) case TyApp(lhs: Term, targs: Ls[Term]) case Sel(prefix: Term, nme: Tree.Ident)(val sym: Opt[Symbol]) - case Tup(fields: Ls[Fld])(val tree: Tree.Tup) + case Tup(fields: Ls[Elem])(val tree: Tree.Tup) case IfLike(kw: Keyword.`if`.type | Keyword.`while`.type, desugared: Split)(val normalized: Split) case Lam(params: Ls[Param], body: Term) case FunTy(lhs: Term, rhs: Term, eff: Opt[Term]) @@ -85,7 +85,7 @@ sealed trait Statement extends AutoLocated: case FunTy(lhs, rhs, eff) => lhs :: rhs :: eff.toList case TyApp(pre, tarsg) => pre :: tarsg case Sel(pre, _) => pre :: Nil - case Tup(fields) => fields.map(_.value) + case Tup(fields) => fields.flatMap(_.subTerms) case IfLike(_, body) => body.subTerms case Lam(params, body) => body :: Nil case Blk(stats, res) => stats.flatMap(_.subTerms) ::: res :: Nil @@ -305,7 +305,14 @@ final case class FldFlags(mut: Bool, spec: Bool, genGetter: Bool, mod: Bool): flags.mkString(" ") override def toString: String = "‹" + showDbg + "›" -final case class Fld(flags: FldFlags, value: Term, asc: Opt[Term]) extends FldImpl +sealed abstract class Elem: + def subTerms: Ls[Term] = this match + case Fld(_, term, asc) => term :: asc.toList + case Spd(_, term) => term :: Nil + def showDbg: Str +final case class Fld(flags: FldFlags, term: Term, asc: Opt[Term]) extends Elem with FldImpl +final case class Spd(eager: Bool, term: Term) extends Elem: + def showDbg: Str = (if eager then "..." else "..") + term.showDbg final case class TyParam(flags: FldFlags, vce: Opt[Bool], sym: VarSymbol) extends Declaration: @@ -341,11 +348,11 @@ final case class ParamList(flags: ParamListFlags, params: Ls[Param]): trait FldImpl extends AutoLocated: self: Fld => - def children: Ls[Located] = self.value :: self.asc.toList ::: Nil - def showDbg: Str = flags.showDbg + self.value.showDbg + def children: Ls[Located] = self.term :: self.asc.toList ::: Nil + def showDbg: Str = flags.showDbg + self.term.showDbg def describe: Str = (if self.flags.spec then "specialized " else "") + (if self.flags.mut then "mutable " else "") + - self.value.describe + self.term.describe diff --git a/hkmc2/shared/src/main/scala/hkmc2/syntax/Tree.scala b/hkmc2/shared/src/main/scala/hkmc2/syntax/Tree.scala index a9ea7aa9d..13dda8664 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/syntax/Tree.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/syntax/Tree.scala @@ -161,6 +161,10 @@ enum Tree extends AutoLocated: case InfixApp(lhs: Ident, Keyword.`:`, rhs) => (lhs, S(rhs)) :: Nil case App(Ident(","), Tup(ps)) => ps.flatMap(_.param) case TermDef(ImmutVal, inner, _) => inner.param + + def isModuleModifier: Bool = this match + case Tree.TypeDef(Mod, _, N, N) => true + case _ => false object Tree: object Block: diff --git a/hkmc2/shared/src/main/scala/hkmc2/typing/TypeChecker.scala b/hkmc2/shared/src/main/scala/hkmc2/typing/TypeChecker.scala index af3aae45c..73f968523 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/typing/TypeChecker.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/typing/TypeChecker.scala @@ -71,8 +71,9 @@ class TypeChecker(using Raise, Elaborator.State): msg"Expected ${ps.size.toString} arguments, but got ${ args.size.toString}" -> t.toLoc :: Nil)) // val p1 = ps.zip(args).map: (p, a) => - val p1 = ps.zip(args).foreach: (p, a) => - constrain(P.enterIf(typeProd(a.value), ts, r.refNum, rc), C.Flow(p.sym.asInstanceOf/*FIXME*/)) + val p1 = ps.zip(args).foreach: + case (p, a: Fld) => + constrain(P.enterIf(typeProd(a.term), ts, r.refNum, rc), C.Flow(p.sym.asInstanceOf/*FIXME*/)) constrain(P.Flow(td.resSym), C.Flow(app.resSym)) // P.Flow(td.resSym) P.Flow(app.resSym) @@ -84,7 +85,8 @@ class TypeChecker(using Raise, Elaborator.State): // case Ref(ClassSymbol(Ident("true"))) => // P.Ctor(LitSymbol(Tree.UnitLit(true)), Nil) case Tup(fields) => - P.Ctor(TupSymbol(S(fields.size)), fields.map(f => typeProd(f.value))) + P.Ctor(TupSymbol(S(fields.size)), fields.map: + case f: Fld => typeProd(f.term)) case Error => P.Ctor(Extr(false), Nil) case _ => P.Flow(FlowSymbol("TODO")) // TODO @@ -98,7 +100,8 @@ class TypeChecker(using Raise, Elaborator.State): case Ref(cls: ClassSymbol) => C.Ctor(cls, Nil) case Ref(ts: TermSymbol) => ??? case Tup(fields) => - C.Ctor(TupSymbol(S(fields.size)), fields.map(f => typeCons(f.value))) + C.Ctor(TupSymbol(S(fields.size)), fields.map: + case f: Fld => typeCons(f.term)) // case _ => TODO(t) case class CCtx(path: Ls[L]) diff --git a/hkmc2/shared/src/test/mlscript/codegen/BasicTerms.mls b/hkmc2/shared/src/test/mlscript/codegen/BasicTerms.mls index a5a5252a4..a58c8a926 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/BasicTerms.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/BasicTerms.mls @@ -47,7 +47,9 @@ log("Hi") //│ qual = Ref of globalThis:import#Prelude //│ name = Ident of "log" //│ args = Ls of -//│ Lit of StrLit of "Hi" +//│ Arg: +//│ spread = false +//│ value = Lit of StrLit of "Hi" //│ rest = Return: \ //│ res = Lit of IntLit of 2 //│ implct = true diff --git a/hkmc2/shared/src/test/mlscript/codegen/Spreads.mls b/hkmc2/shared/src/test/mlscript/codegen/Spreads.mls new file mode 100644 index 000000000..565e0c1f6 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/codegen/Spreads.mls @@ -0,0 +1,25 @@ +:js + + +let a = [1, 2, 3] +//│ a = [ 1, 2, 3 ] + +let b = [0, ...a] +//│ b = [ 0, 1, 2, 3 ] + + +fun foo(w, x, y, z) = [w, x, y, z] + +foo(...b) +//│ = [ 0, 1, 2, 3 ] + +foo(0, ...a) +//│ = [ 0, 1, 2, 3 ] + +:sjs +foo(1, ...[2, 3], 4) +//│ JS: +//│ this.foo(1, ...[ 2, 3 ], 4) +//│ = [ 1, 2, 3, 4 ] + + diff --git a/hkmc2/shared/src/test/mlscript/ucs/papers/OperatorSplit.mls b/hkmc2/shared/src/test/mlscript/ucs/papers/OperatorSplit.mls index 55e4d09dd..f2c93fb0b 100644 --- a/hkmc2/shared/src/test/mlscript/ucs/papers/OperatorSplit.mls +++ b/hkmc2/shared/src/test/mlscript/ucs/papers/OperatorSplit.mls @@ -92,7 +92,7 @@ fun example(args) = //│ rhs = Tup of Ls of //│ Fld: //│ flags = () -//│ value = Ref of args +//│ term = Ref of args //│ asc = N //│ tail = Let: \ //│ sym = $scrut @@ -101,11 +101,11 @@ fun example(args) = //│ rhs = Tup of Ls of //│ Fld: //│ flags = () -//│ value = Ref of $scrut +//│ term = Ref of $scrut //│ asc = N //│ Fld: //│ flags = () -//│ value = Lit of IntLit of 0 +//│ term = Lit of IntLit of 0 //│ asc = N //│ tail = Cons: \ //│ head = Branch: @@ -119,11 +119,11 @@ fun example(args) = //│ rhs = Tup of Ls of //│ Fld: //│ flags = () -//│ value = Ref of $scrut +//│ term = Ref of $scrut //│ asc = N //│ Fld: //│ flags = () -//│ value = Lit of IntLit of 0 +//│ term = Lit of IntLit of 0 //│ asc = N //│ tail = Cons: \ //│ head = Branch: @@ -139,11 +139,11 @@ fun example(args) = //│ rhs = Tup of Ls of //│ Fld: //│ flags = () -//│ value = Ref of $scrut +//│ term = Ref of $scrut //│ asc = N //│ Fld: //│ flags = () -//│ value = Sel: +//│ term = Sel: //│ prefix = Ref of globalThis:block#1 //│ nme = Ident of "abs" //│ asc = N @@ -154,11 +154,11 @@ fun example(args) = //│ rhs = Tup of Ls of //│ Fld: //│ flags = () -//│ value = Ref of $scrut +//│ term = Ref of $scrut //│ asc = N //│ Fld: //│ flags = () -//│ value = Lit of IntLit of 100 +//│ term = Lit of IntLit of 100 //│ asc = N //│ tail = Cons: \ //│ head = Branch: @@ -172,11 +172,11 @@ fun example(args) = //│ rhs = Tup of Ls of //│ Fld: //│ flags = () -//│ value = Ref of $scrut +//│ term = Ref of $scrut //│ asc = N //│ Fld: //│ flags = () -//│ value = Lit of IntLit of 10 +//│ term = Lit of IntLit of 10 //│ asc = N //│ tail = Cons: \ //│ head = Branch: diff --git a/hkmc2/shared/src/test/mlscript/ucs/syntax/NestedOpSplits.mls b/hkmc2/shared/src/test/mlscript/ucs/syntax/NestedOpSplits.mls index a831a7335..0acd65038 100644 --- a/hkmc2/shared/src/test/mlscript/ucs/syntax/NestedOpSplits.mls +++ b/hkmc2/shared/src/test/mlscript/ucs/syntax/NestedOpSplits.mls @@ -32,11 +32,11 @@ fun f(x) = //│ rhs = Tup of Ls of //│ Fld: //│ flags = () -//│ value = Ref of x +//│ term = Ref of x //│ asc = N //│ Fld: //│ flags = () -//│ value = Lit of IntLit of 1 +//│ term = Lit of IntLit of 1 //│ asc = N //│ tail = Let: \ //│ sym = $scrut @@ -45,11 +45,11 @@ fun f(x) = //│ rhs = Tup of Ls of //│ Fld: //│ flags = () -//│ value = Ref of $scrut +//│ term = Ref of $scrut //│ asc = N //│ Fld: //│ flags = () -//│ value = Lit of IntLit of 2 +//│ term = Lit of IntLit of 2 //│ asc = N //│ tail = Cons: \ //│ head = Branch: