diff --git a/hkmc2/shared/src/main/scala/hkmc2/bbml/bbML.scala b/hkmc2/shared/src/main/scala/hkmc2/bbml/bbML.scala index 31ca23852..13e823ff9 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/bbml/bbML.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/bbml/bbML.scala @@ -448,11 +448,11 @@ class BBTyper(using elState: Elaborator.State, tl: TL): effBuff += eff nestCtx += sym -> rhsTy goStats(stats) - case TermDefinition(Fun, sym, params, sig, Some(body), _) :: stats => - typeFunDef(sym, params match { - case S(params) => Term.Lam(params, body) - case _ => body // * may be a case expressions - }, sig, ctx) + case TermDefinition(Fun, sym, ParamList(_, ps) :: Nil, sig, Some(body), _) :: stats => + typeFunDef(sym, Term.Lam(ps, body), sig, ctx) + goStats(stats) + case TermDefinition(Fun, sym, Nil, sig, Some(body), _) :: stats => + typeFunDef(sym, body, sig, ctx) // * may be a case expressions goStats(stats) case TermDefinition(Fun, sym, _, S(sig), None, _) :: stats => ctx += sym -> typeType(sig) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index 94706f7a2..9ee982c8c 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -74,7 +74,7 @@ sealed abstract class Defn: final case class TermDefn( k: syntax.TermDefKind, sym: TermSymbol, - params: Opt[Ls[Param]], + params: Ls[ParamList], body: Block, ) extends Defn diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala index 0fbe435ed..e4284b4ce 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala @@ -10,8 +10,11 @@ import Scope.scope import hkmc2.syntax.ImmutVal import hkmc2.semantics.Elaborator import hkmc2.syntax.Tree +import hkmc2.syntax.TermDefKind import hkmc2.semantics.TopLevelSymbol import hkmc2.semantics.MemberSymbol +import hkmc2.semantics.ParamList +import hkmc2.codegen.Value.Lam // TODO factor some logic for other codegen backends @@ -99,11 +102,15 @@ class JSBuilder extends CodeBuilder: result(Value.This(sym)) val (thisProxy, res) = scope.nestRebindThis(defn.sym): val defnJS = defn match - case TermDefn(syntax.Fun, sym, N, body) => + case TermDefn(syntax.Fun, sym, Nil, body) => TODO("getters") - case TermDefn(syntax.Fun, sym, S(ps), bod) => - val vars = ps.map(p => scope.allocateName(p.sym)).mkDocument(", ") - doc"function ${sym.nme}($vars) { #{ # ${body(bod)} #} # }" + case TermDefn(syntax.Fun, sym, ParamList(_, ps) :: pss, bod) => + val paramList = ps.map(p => scope.allocateName(p.sym)).mkDocument(", ") + val result = pss.foldRight(bod)({ + case (ParamList(_, ps), block) => + Return(Lam(ps, block), false) + }) + doc"function ${sym.nme}(${paramList}) { #{ # ${body(result)} #} # }" case ClsDefn(sym, syntax.Cls, mtds, flds, ctor) => val clsDefn = sym.defn.getOrElse(die) val clsParams = clsDefn.paramsOpt.getOrElse(Nil) @@ -118,11 +125,12 @@ class JSBuilder extends CodeBuilder: }) { #{ # ${ ctorCode.stripBreaks } #} # }${ - mtds.map: td => - val vars = td.params.getOrElse(Nil).map(p => scope.allocateName(p.sym)).mkDocument(", ") - doc" # ${td.sym.nme}($vars) { #{ # ${ - body(td.body) - } #} # }" + mtds.map: + case td @ TermDefn(_, _, ParamList(_, ps) :: Nil, _) => + val vars = ps.map(p => scope.allocateName(p.sym)).mkDocument(", ") + doc" # ${td.sym.nme}($vars) { #{ # ${ + body(td.body) + } #} # }" .mkDocument(" ") }${ if mtds.exists(_.sym.nme == "toString") diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala index 3238fa835..2926a912f 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala @@ -428,13 +428,14 @@ extends Importer: case S(t) => typeParams(t) case N => (N, ctx) // Add parameters to context - val (ps, newCtx) = td.paramLists.foldLeft((Ls[Param](), newCtx1)): - case ((ps, ctx), t) => params(t)(using ctx).mapFirst(ps ++ _) - .mapFirst(some) + val (pss, newCtx) = + td.paramLists.foldLeft(Ls[ParamList](), newCtx1)({case ((pss, ctx), ps) => + val (qs, newCtx) = params(ps)(using ctx) + (pss :+ ParamList(ParamListFlags.empty, qs), newCtx) + }) val b = rhs.map(term(_)(using newCtx)) val r = FlowSymbol(s"‹result of ${sym}›", nextUid) - val tdf = TermDefinition(k, sym, ps, - td.signature.orElse(newSignatureTrees.get(id.name)).map(term), b, r) + val tdf = TermDefinition(k, sym, pss, td.signature.orElse(newSignatureTrees.get(id.name)).map(term), b, r) sym.defn = S(tdf) tdf go(sts, tdf :: acc) @@ -592,8 +593,8 @@ extends Importer: def computeVariances(s: Statement): Unit = val trav = VarianceTraverser() def go(s: Statement): Unit = s match - case TermDefinition(k, sym, ps, sign, body, r) => - ps.foreach(_.foreach(trav.traverseType(S(false)))) + case TermDefinition(k, sym, pss, sign, body, r) => + pss.foreach(ps => ps.params.foreach(trav.traverseType(S(false)))) sign.foreach(trav.traverseType(S(true))) body match case S(b) => diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala index 8f9686754..bd5ed716f 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala @@ -113,7 +113,7 @@ sealed trait Statement extends AutoLocated: case Assgn(lhs, rhs) => lhs :: rhs :: Nil case Deref(term) => term :: Nil case TermDefinition(k, _, ps, sign, body, res) => - ps.toList.flatMap(_.flatMap(_.subTerms)) ::: sign.toList ::: body.toList + ps.toList.flatMap(_.subTerms) ::: sign.toList ::: body.toList case cls: ClassDef => cls.paramsOpt.toList.flatMap(_.flatMap(_.subTerms)) ::: cls.body.blk :: Nil case td: TypeDef => @@ -175,7 +175,7 @@ sealed trait Statement extends AutoLocated: case Error => "" case Tup(fields) => fields.map(_.showDbg).mkString("[", ", ", "]") case TermDefinition(k, sym, ps, sign, body, res) => s"${k.str} ${sym}${ - ps.fold("")(_.map(_.showDbg).mkString("(", ", ", ")")) + ps.map(_.showDbg).mkString("") }${sign.fold("")(": "+_.showDbg)}${ body match case S(x) => " = " + x.showDbg @@ -194,7 +194,7 @@ final case class DefineVar(sym: LocalSymbol, rhs: Term) extends Statement final case class TermDefinition( k: TermDefKind, sym: TermSymbol, - params: Opt[Ls[Param]], + params: Ls[ParamList], sign: Opt[Term], body: Opt[Term], resSym: FlowSymbol, @@ -272,6 +272,17 @@ final case class Param(flags: FldFlags, sym: LocalSymbol & NamedSymbol, sign: Op object FldFlags { val empty: FldFlags = FldFlags(false, false, false) } +final case class ParamListFlags(ctx: Bool): + def showDbg: Str = (if ctx then "ctx " else "") + override def toString: String = "‹" + showDbg + "›" + +object ParamListFlags: + val empty = ParamListFlags(false) + +final case class ParamList(flags: ParamListFlags, params: Ls[Param]): + def subTerms: Ls[Term] = params.flatMap(_.subTerms) + def showDbg: Str = flags.showDbg + params.mkString("(", ", ", ")") + trait FldImpl extends AutoLocated: self: Fld => def children: Ls[Located] = self.value :: self.asc.toList ::: Nil diff --git a/hkmc2/shared/src/main/scala/hkmc2/typing/TypeChecker.scala b/hkmc2/shared/src/main/scala/hkmc2/typing/TypeChecker.scala index ddd2d50ce..1b730b0e0 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/typing/TypeChecker.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/typing/TypeChecker.scala @@ -30,7 +30,7 @@ class TypeChecker(using raise: Raise): ts.defn match case S(td: TermDefinition) => td.params match - case N => P.Flow(td.resSym) + case Nil => P.Flow(td.resSym) case Blk(stats, res) => // val p1 = stats.map(typeStat) // val p2 = typeProd(res) @@ -40,7 +40,7 @@ class TypeChecker(using raise: Raise): stats.foreach: case t: TermDefinition => t.sign.map(typeProd) - t.params.map(typeParams) + t.params.map(_.params).map(typeParams) t.body.map(typeProd) P.Ctor(LitSymbol(Tree.UnitLit(true)), Nil) case t: Term => @@ -57,10 +57,12 @@ class TypeChecker(using raise: Raise): ts.defn match case S(td: TermDefinition) => td.params match - case N => + case Nil => val f = typeProd(r) constrain(P.exitIf(f, ts, r.refNum, rc), C.Fun(typeProd(tup), C.Flow(app.resSym))) - case S(ps) => + case ParamList(_, ps) :: Nil => + // App applies to the leftmost parameter list + // TODO: how to recursively check the subsequent Apps (if any)? if ps.size != args.size then raise(ErrorReport( msg"Expected ${ps.size.toString} arguments, but got ${ diff --git a/hkmc2/shared/src/test/mlscript/basics/MultiParamLists.mls b/hkmc2/shared/src/test/mlscript/basics/MultiParamLists.mls index 83b5abc15..79cb54974 100644 --- a/hkmc2/shared/src/test/mlscript/basics/MultiParamLists.mls +++ b/hkmc2/shared/src/test/mlscript/basics/MultiParamLists.mls @@ -1,16 +1,60 @@ :js - - -// FIXME elbaoration is currently wrong - :sjs -fun foo(x)(y) = x * y + +fun f(n1: Int): Int = n1 +//│ JS: +//│ function f(n1) { return n1 }; undefined +f(42) //│ JS: -//│ function foo(x, y) { return x * y }; undefined +//│ this.f(42) +//│ = 42 -:sjs -fun foo(x)(y)(z) = x * y + z +fun f(n1: Int)(n2: Int): Int = (10 * n1 + n2) //│ JS: -//│ function foo(x, y, z) { let tmp; tmp = x * y; return tmp + z }; undefined +//│ function f(n1) { return (n2) => { let tmp; tmp = 10 * n1; return tmp + n2 } }; undefined +f(4)(2) +//│ JS: +//│ let tmp; tmp = this.f(4); tmp(2) +//│ = 42 +fun f(n1: Int)(n2: Int)(n3: Int): Int = 10 * (10 * n1 + n2) + n3 +//│ JS: +//│ function f(n1) { +//│ return (n2) => { +//│ return (n3) => { +//│ let tmp, tmp1, tmp2; +//│ tmp = 10 * n1; +//│ tmp1 = tmp + n2; +//│ tmp2 = 10 * tmp1; +//│ return tmp2 + n3 +//│ } +//│ } +//│ }; +//│ undefined +f(4)(2)(9) +//│ JS: +//│ let tmp, tmp1; tmp = this.f(4); tmp1 = tmp(2); tmp1(9) +//│ = 429 +fun f(n1: Int)(n2: Int)(n3: Int)(n4: Int): Int = 10 * (10 * (10 * n1 + n2) + n3) + n4 +//│ JS: +//│ function f(n1) { +//│ return (n2) => { +//│ return (n3) => { +//│ return (n4) => { +//│ let tmp, tmp1, tmp2, tmp3, tmp4; +//│ tmp = 10 * n1; +//│ tmp1 = tmp + n2; +//│ tmp2 = 10 * tmp1; +//│ tmp3 = tmp2 + n3; +//│ tmp4 = 10 * tmp3; +//│ return tmp4 + n4 +//│ } +//│ } +//│ } +//│ }; +//│ undefined +f(3)(0)(3)(1) +//│ JS: +//│ let tmp, tmp1, tmp2; tmp = this.f(3); tmp1 = tmp(0); tmp2 = tmp1(3); tmp2(1) +//│ = 3031