Skip to content

Commit

Permalink
Adapt Elaborator and JSBuilder for multiple parameter lists
Browse files Browse the repository at this point in the history
  • Loading branch information
FlandiaYingman committed Oct 30, 2024
1 parent 883abd9 commit c3071ac
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 38 deletions.
10 changes: 5 additions & 5 deletions hkmc2/shared/src/main/scala/hkmc2/bbml/bbML.scala
Original file line number Diff line number Diff line change
Expand Up @@ -448,11 +448,11 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
effBuff += eff
nestCtx += sym -> rhsTy
goStats(stats)
case TermDefinition(Fun, sym, params, sig, Some(body), _) :: stats =>
typeFunDef(sym, params match {
case S(params) => Term.Lam(params, body)
case _ => body // * may be a case expressions
}, sig, ctx)
case TermDefinition(Fun, sym, ParamList(_, ps) :: Nil, sig, Some(body), _) :: stats =>
typeFunDef(sym, Term.Lam(ps, body), sig, ctx)
goStats(stats)
case TermDefinition(Fun, sym, Nil, sig, Some(body), _) :: stats =>
typeFunDef(sym, body, sig, ctx) // * may be a case expressions
goStats(stats)
case TermDefinition(Fun, sym, _, S(sig), None, _) :: stats =>
ctx += sym -> typeType(sig)
Expand Down
2 changes: 1 addition & 1 deletion hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ sealed abstract class Defn:
final case class TermDefn(
k: syntax.TermDefKind,
sym: TermSymbol,
params: Opt[Ls[Param]],
params: Ls[ParamList],
body: Block,
) extends Defn

Expand Down
26 changes: 17 additions & 9 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@ import Scope.scope
import hkmc2.syntax.ImmutVal
import hkmc2.semantics.Elaborator
import hkmc2.syntax.Tree
import hkmc2.syntax.TermDefKind
import hkmc2.semantics.TopLevelSymbol
import hkmc2.semantics.MemberSymbol
import hkmc2.semantics.ParamList
import hkmc2.codegen.Value.Lam


// TODO factor some logic for other codegen backends
Expand Down Expand Up @@ -99,11 +102,15 @@ class JSBuilder extends CodeBuilder:
result(Value.This(sym))
val (thisProxy, res) = scope.nestRebindThis(defn.sym):
val defnJS = defn match
case TermDefn(syntax.Fun, sym, N, body) =>
case TermDefn(syntax.Fun, sym, Nil, body) =>
TODO("getters")
case TermDefn(syntax.Fun, sym, S(ps), bod) =>
val vars = ps.map(p => scope.allocateName(p.sym)).mkDocument(", ")
doc"function ${sym.nme}($vars) { #{ # ${body(bod)} #} # }"
case TermDefn(syntax.Fun, sym, ParamList(_, ps) :: pss, bod) =>
val paramList = ps.map(p => scope.allocateName(p.sym)).mkDocument(", ")
val result = pss.foldRight(bod)({
case (ParamList(_, ps), block) =>
Return(Lam(ps, block), false)
})
doc"function ${sym.nme}(${paramList}) { #{ # ${body(result)} #} # }"
case ClsDefn(sym, syntax.Cls, mtds, flds, ctor) =>
val clsDefn = sym.defn.getOrElse(die)
val clsParams = clsDefn.paramsOpt.getOrElse(Nil)
Expand All @@ -118,11 +125,12 @@ class JSBuilder extends CodeBuilder:
}) { #{ # ${
ctorCode.stripBreaks
} #} # }${
mtds.map: td =>
val vars = td.params.getOrElse(Nil).map(p => scope.allocateName(p.sym)).mkDocument(", ")
doc" # ${td.sym.nme}($vars) { #{ # ${
body(td.body)
} #} # }"
mtds.map:
case td @ TermDefn(_, _, ParamList(_, ps) :: Nil, _) =>
val vars = ps.map(p => scope.allocateName(p.sym)).mkDocument(", ")
doc" # ${td.sym.nme}($vars) { #{ # ${
body(td.body)
} #} # }"
.mkDocument(" ")
}${
if mtds.exists(_.sym.nme == "toString")
Expand Down
15 changes: 8 additions & 7 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -428,13 +428,14 @@ extends Importer:
case S(t) => typeParams(t)
case N => (N, ctx)
// Add parameters to context
val (ps, newCtx) = td.paramLists.foldLeft((Ls[Param](), newCtx1)):
case ((ps, ctx), t) => params(t)(using ctx).mapFirst(ps ++ _)
.mapFirst(some)
val (pss, newCtx) =
td.paramLists.foldLeft(Ls[ParamList](), newCtx1)({case ((pss, ctx), ps) =>
val (qs, newCtx) = params(ps)(using ctx)
(pss :+ ParamList(ParamListFlags.empty, qs), newCtx)
})
val b = rhs.map(term(_)(using newCtx))
val r = FlowSymbol(s"‹result of ${sym}", nextUid)
val tdf = TermDefinition(k, sym, ps,
td.signature.orElse(newSignatureTrees.get(id.name)).map(term), b, r)
val tdf = TermDefinition(k, sym, pss, td.signature.orElse(newSignatureTrees.get(id.name)).map(term), b, r)
sym.defn = S(tdf)
tdf
go(sts, tdf :: acc)
Expand Down Expand Up @@ -592,8 +593,8 @@ extends Importer:
def computeVariances(s: Statement): Unit =
val trav = VarianceTraverser()
def go(s: Statement): Unit = s match
case TermDefinition(k, sym, ps, sign, body, r) =>
ps.foreach(_.foreach(trav.traverseType(S(false))))
case TermDefinition(k, sym, pss, sign, body, r) =>
pss.foreach(ps => ps.params.foreach(trav.traverseType(S(false))))
sign.foreach(trav.traverseType(S(true)))
body match
case S(b) =>
Expand Down
17 changes: 14 additions & 3 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ sealed trait Statement extends AutoLocated:
case Assgn(lhs, rhs) => lhs :: rhs :: Nil
case Deref(term) => term :: Nil
case TermDefinition(k, _, ps, sign, body, res) =>
ps.toList.flatMap(_.flatMap(_.subTerms)) ::: sign.toList ::: body.toList
ps.toList.flatMap(_.subTerms) ::: sign.toList ::: body.toList
case cls: ClassDef =>
cls.paramsOpt.toList.flatMap(_.flatMap(_.subTerms)) ::: cls.body.blk :: Nil
case td: TypeDef =>
Expand Down Expand Up @@ -175,7 +175,7 @@ sealed trait Statement extends AutoLocated:
case Error => "<error>"
case Tup(fields) => fields.map(_.showDbg).mkString("[", ", ", "]")
case TermDefinition(k, sym, ps, sign, body, res) => s"${k.str} ${sym}${
ps.fold("")(_.map(_.showDbg).mkString("(", ", ", ")"))
ps.map(_.showDbg).mkString("")
}${sign.fold("")(": "+_.showDbg)}${
body match
case S(x) => " = " + x.showDbg
Expand All @@ -194,7 +194,7 @@ final case class DefineVar(sym: LocalSymbol, rhs: Term) extends Statement
final case class TermDefinition(
k: TermDefKind,
sym: TermSymbol,
params: Opt[Ls[Param]],
params: Ls[ParamList],
sign: Opt[Term],
body: Opt[Term],
resSym: FlowSymbol,
Expand Down Expand Up @@ -272,6 +272,17 @@ final case class Param(flags: FldFlags, sym: LocalSymbol & NamedSymbol, sign: Op

object FldFlags { val empty: FldFlags = FldFlags(false, false, false) }

final case class ParamListFlags(ctx: Bool):
def showDbg: Str = (if ctx then "ctx " else "")
override def toString: String = "" + showDbg + ""

object ParamListFlags:
val empty = ParamListFlags(false)

final case class ParamList(flags: ParamListFlags, params: Ls[Param]):
def subTerms: Ls[Term] = params.flatMap(_.subTerms)
def showDbg: Str = flags.showDbg + params.mkString("(", ", ", ")")

trait FldImpl extends AutoLocated:
self: Fld =>
def children: Ls[Located] = self.value :: self.asc.toList ::: Nil
Expand Down
10 changes: 6 additions & 4 deletions hkmc2/shared/src/main/scala/hkmc2/typing/TypeChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class TypeChecker(using raise: Raise):
ts.defn match
case S(td: TermDefinition) =>
td.params match
case N => P.Flow(td.resSym)
case Nil => P.Flow(td.resSym)
case Blk(stats, res) =>
// val p1 = stats.map(typeStat)
// val p2 = typeProd(res)
Expand All @@ -40,7 +40,7 @@ class TypeChecker(using raise: Raise):
stats.foreach:
case t: TermDefinition =>
t.sign.map(typeProd)
t.params.map(typeParams)
t.params.map(_.params).map(typeParams)
t.body.map(typeProd)
P.Ctor(LitSymbol(Tree.UnitLit(true)), Nil)
case t: Term =>
Expand All @@ -57,10 +57,12 @@ class TypeChecker(using raise: Raise):
ts.defn match
case S(td: TermDefinition) =>
td.params match
case N =>
case Nil =>
val f = typeProd(r)
constrain(P.exitIf(f, ts, r.refNum, rc), C.Fun(typeProd(tup), C.Flow(app.resSym)))
case S(ps) =>
case ParamList(_, ps) :: Nil =>
// App applies to the leftmost parameter list
// TODO: how to recursively check the subsequent Apps (if any)?
if ps.size != args.size then
raise(ErrorReport(
msg"Expected ${ps.size.toString} arguments, but got ${
Expand Down
62 changes: 53 additions & 9 deletions hkmc2/shared/src/test/mlscript/basics/MultiParamLists.mls
Original file line number Diff line number Diff line change
@@ -1,16 +1,60 @@
:js


// FIXME elbaoration is currently wrong

:sjs
fun foo(x)(y) = x * y

fun f(n1: Int): Int = n1
//│ JS:
//│ function f(n1) { return n1 }; undefined
f(42)
//│ JS:
//│ function foo(x, y) { return x * y }; undefined
//│ this.f(42)
//│ = 42

:sjs
fun foo(x)(y)(z) = x * y + z
fun f(n1: Int)(n2: Int): Int = (10 * n1 + n2)
//│ JS:
//│ function foo(x, y, z) { let tmp; tmp = x * y; return tmp + z }; undefined
//│ function f(n1) { return (n2) => { let tmp; tmp = 10 * n1; return tmp + n2 } }; undefined
f(4)(2)
//│ JS:
//│ let tmp; tmp = this.f(4); tmp(2)
//│ = 42

fun f(n1: Int)(n2: Int)(n3: Int): Int = 10 * (10 * n1 + n2) + n3
//│ JS:
//│ function f(n1) {
//│ return (n2) => {
//│ return (n3) => {
//│ let tmp, tmp1, tmp2;
//│ tmp = 10 * n1;
//│ tmp1 = tmp + n2;
//│ tmp2 = 10 * tmp1;
//│ return tmp2 + n3
//│ }
//│ }
//│ };
//│ undefined
f(4)(2)(9)
//│ JS:
//│ let tmp, tmp1; tmp = this.f(4); tmp1 = tmp(2); tmp1(9)
//│ = 429

fun f(n1: Int)(n2: Int)(n3: Int)(n4: Int): Int = 10 * (10 * (10 * n1 + n2) + n3) + n4
//│ JS:
//│ function f(n1) {
//│ return (n2) => {
//│ return (n3) => {
//│ return (n4) => {
//│ let tmp, tmp1, tmp2, tmp3, tmp4;
//│ tmp = 10 * n1;
//│ tmp1 = tmp + n2;
//│ tmp2 = 10 * tmp1;
//│ tmp3 = tmp2 + n3;
//│ tmp4 = 10 * tmp3;
//│ return tmp4 + n4
//│ }
//│ }
//│ }
//│ };
//│ undefined
f(3)(0)(3)(1)
//│ JS:
//│ let tmp, tmp1, tmp2; tmp = this.f(3); tmp1 = tmp(0); tmp2 = tmp1(3); tmp2(1)
//│ = 3031

0 comments on commit c3071ac

Please sign in to comment.