Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add code generation for bbml #245

Merged
merged 36 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
f654551
Map builtin operators
NeilKleistGao Nov 26, 2024
270ff3e
WIP: Fix num and sel code gen & add predef
NeilKleistGao Nov 26, 2024
6896564
WIP: Add throw
NeilKleistGao Nov 26, 2024
c25ebd5
WIP: Fix predef
NeilKleistGao Nov 26, 2024
b9e3782
Fix getter generation & split.end typing
NeilKleistGao Nov 27, 2024
dc58521
Generate code for region & ref
NeilKleistGao Nov 27, 2024
7c56b26
WIP: Add tests
NeilKleistGao Nov 29, 2024
a40e1cc
Merge from main branch
NeilKleistGao Nov 29, 2024
c2deec1
Fix class matching typing
NeilKleistGao Nov 30, 2024
ddeb04a
Merge from main branch
NeilKleistGao Nov 30, 2024
a81aae2
Fix num ops code gen
NeilKleistGao Nov 30, 2024
dae1d36
Fix getter generation
NeilKleistGao Dec 1, 2024
4afe46f
Clean
NeilKleistGao Dec 1, 2024
b0c3f38
Merge from mlscript
NeilKleistGao Dec 1, 2024
a683de6
Update bbml predef
NeilKleistGao Dec 1, 2024
af42e40
Minor
NeilKleistGao Dec 1, 2024
182237f
Reuse getBuiltin
NeilKleistGao Dec 4, 2024
5443f1f
Fix GetElem implementation
NeilKleistGao Dec 4, 2024
89f5c6c
Merge branch 'hkmc2' of https://github.com/hkust-taco/mlscript into b…
NeilKleistGao Dec 5, 2024
ebe9103
Rerun tests
NeilKleistGao Dec 5, 2024
8e63fae
Move defn check to symbol
NeilKleistGao Dec 6, 2024
0ff406f
Fix module getter generation and ctx.get use
NeilKleistGao Dec 6, 2024
ce6f2cf
Refactor getter typing
NeilKleistGao Dec 6, 2024
e0a22ad
Update hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala
NeilKleistGao Dec 8, 2024
313d680
Merge from main branch
NeilKleistGao Dec 8, 2024
0bb46f1
WIP: Fix shadowing
NeilKleistGao Dec 8, 2024
5557984
Forbid getters from being defined in non-module and non-function scopes
NeilKleistGao Dec 8, 2024
52e2bd2
Update hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala
NeilKleistGao Dec 12, 2024
ae7bfd5
Remove top-level getter selection check
NeilKleistGao Dec 12, 2024
bbc8f9f
Add empty lines
NeilKleistGao Dec 12, 2024
710831f
Fix diff
NeilKleistGao Dec 12, 2024
6e7a96a
Fix diff
NeilKleistGao Dec 12, 2024
7ca79b9
Fix getter logic
NeilKleistGao Dec 12, 2024
634748d
Merge from main branch
NeilKleistGao Dec 12, 2024
e0c2e5e
Minor changes
NeilKleistGao Dec 12, 2024
ea0c9c8
Remove the unused import and update the comment
NeilKleistGao Dec 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion hkmc2/jvm/src/test/scala/hkmc2/BbmlDiffMaker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import hkmc2.bbml.*
abstract class BbmlDiffMaker extends JSBackendDiffMaker:

val bbPreludeFile = file / os.up / os.RelPath("bbPrelude.mls")
val bbPredefFile = file / os.up / os.up / os.up /"mlscript-compile"/"bbml"/"Predef.mls"

val bbmlOpt = new NullaryCommand("bbml"):
override def onSet(): Unit =
Expand All @@ -18,7 +19,20 @@ abstract class BbmlDiffMaker extends JSBackendDiffMaker:
if file =/= bbPreludeFile then
importFile(bbPreludeFile, verbose = false)


override def init(): Unit =
if bbmlOpt.isSet then
import syntax.*
import Tree.*
import Keyword.*
given raise: Raise = d =>
output(s"Error: $d")
()
processTrees(
Modified(`import`, N, StrLit(bbPredefFile.toString))
:: Open(Ident("Predef"))
:: Nil)
super.init()

lazy val bbCtx =
given Elaborator.Ctx = curCtx
bbml.BbCtx.init(_ => die)
Expand Down
26 changes: 22 additions & 4 deletions hkmc2/shared/src/main/scala/hkmc2/bbml/bbML.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@ final case class BbCtx(
env: HashMap[Uid[Symbol], GeneralType]
):
def +=(p: Symbol -> GeneralType): Unit = env += p._1.uid -> p._2
def get(sym: Symbol): Option[GeneralType] = env.get(sym.uid) orElse parent.dlof(_.get(sym))(None)
def get(sym: Symbol): Option[GeneralType] =
if BbCtx.builtinOps(sym.nme) then ctx.get(s"#${sym.nme}") match
NeilKleistGao marked this conversation as resolved.
Show resolved Hide resolved
case S(Ctx.SelElem(_, _, symOpt)) => symOpt.flatMap(getImpl(_))
case _ => N
else getImpl(sym)
private def getImpl(sym: Symbol): Option[GeneralType] = env.get(sym.uid) orElse parent.dlof(_.getImpl(sym))(None)
def getCls(name: Str): Option[TypeSymbol] =
for
elem <- ctx.get(name)
Expand All @@ -46,6 +51,7 @@ object BbCtx:
def numTy(using ctx: BbCtx): Type = ClassLikeType(ctx.getCls("Num").get, Nil)
def strTy(using ctx: BbCtx): Type = ClassLikeType(ctx.getCls("Str").get, Nil)
def boolTy(using ctx: BbCtx): Type = ClassLikeType(ctx.getCls("Bool").get, Nil)
def errTy(using ctx: BbCtx): Type = ClassLikeType(ctx.getCls("Error").get, Nil)
private def codeBaseTy(ct: TypeArg, cr: TypeArg, isVar: TypeArg)(using ctx: BbCtx): Type =
ClassLikeType(ctx.getCls("CodeBase").get, ct :: cr :: isVar :: Nil)
def codeTy(ct: Type, cr: Type)(using ctx: BbCtx): Type =
Expand All @@ -58,6 +64,8 @@ object BbCtx:
ClassLikeType(ctx.getCls("Ref").get, Wildcard(ct, ct) :: Wildcard.out(sk) :: Nil)
def init(raise: Raise)(using Elaborator.State, Elaborator.Ctx): BbCtx =
new BbCtx(raise, summon, None, 1, HashMap.empty)

val builtinOps = Set("+", "-", "*", "/", "<", ">", "<=", ">=", "==", "!=", "&&", "||")
end BbCtx


Expand Down Expand Up @@ -190,6 +198,10 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
case _: UnitLit => Top
case _: BoolLit => BbCtx.boolTy), Bot, Bot)
case Ref(sym: Symbol) if sym.nme === "error" => (Bot, Bot, Bot)
case Ref(sym: Symbol) if BbCtx.builtinOps(sym.nme) => ctx.get(sym) match
case S(ty) => (tryMkMono(ty, code), Bot, Bot)
case N =>
(error(msg"Cannot quote operator ${sym.nme}" -> code.toLoc :: Nil), Bot, Bot)
case Lam(PlainParamList(params), body) =>
val nestCtx = ctx.nextLevel
given BbCtx = nestCtx
Expand Down Expand Up @@ -264,7 +276,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
split match
case Split.Cons(Branch(scrutinee, Pattern.ClassLike(sym, _, _, _), cons), alts) =>
// * Pattern matching for classes
val (clsTy, tv, emptyTy) = ctx.getCls(sym.nme).flatMap(_.defn) match
val (clsTy, tv, emptyTy) = sym.asCls.flatMap(_.defn) match
case S(cls) =>
(ClassLikeType(sym, cls.tparams.map(_ => freshWildcard(N))), (freshVar(N)), ClassLikeType(sym, cls.tparams.map(_ => Wildcard.empty)))
case _ =>
Expand Down Expand Up @@ -311,7 +323,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
case Split.Else(alts) => sign match
case S(sign) => ascribe(alts, sign)
case _ => typeCheck(alts)
case Split.End => ???
case Split.End => (Bot, Bot)

// * Note: currently, the returned type is not used or useful, but it could be in the future
private def ascribe(lhs: Term, rhs: GeneralType)(using ctx: BbCtx): (GeneralType, Type) =
Expand Down Expand Up @@ -431,6 +443,8 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
goStats(stats)
case (clsDef: ClassDef) :: stats =>
goStats(stats)
case Import(sym, pth) :: stats =>
goStats(stats) // TODO:
goStats(stats)
val (ty, eff) = typeCheck(res)
(ty, effBuff.foldLeft(eff)((res, e) => res | e))
Expand Down Expand Up @@ -524,7 +538,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
val sk = freshVar(N)
constrain(tryMkMono(regTy, reg), BbCtx.regionTy(sk))
(BbCtx.refTy(tryMkMono(valTy, value), sk), sk | (regEff | valEff))
case Term.Assgn(lhs, rhs) =>
case Term.SetRef(lhs, rhs) =>
val (lhsTy, lhsEff) = typeCheck(lhs)
val (rhsTy, rhsEff) = typeCheck(rhs)
val sk = freshVar(N)
Expand All @@ -543,6 +557,10 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
(BbCtx.codeTy(ty, ctxTy), eff)
case _: Term.Unquoted =>
(error(msg"Unquote should nest in quasiquote" -> t.toLoc :: Nil), Bot)
case Throw(e) =>
val (ty, eff) = typeCheck(e)
constrain(tryMkMono(ty, e), BbCtx.errTy)
(Bot, eff)
case Term.Error =>
(Bot, Bot) // TODO: error type?
case _ =>
Expand Down
27 changes: 20 additions & 7 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala
Original file line number Diff line number Diff line change
Expand Up @@ -271,15 +271,13 @@ class Lowering(using TL, Raise, Elaborator.State):
if usesResTmp then k(Value.Ref(l))
else k(Value.Lit(syntax.Tree.UnitLit(true))) // * it seems this currently never happens
)

NeilKleistGao marked this conversation as resolved.
Show resolved Hide resolved
case sel @ Sel(prefix, nme) =>
setupSelection(prefix, nme, sel.sym)(k)
case SelProj(prefix, _, proj) =>
setupSelection(prefix, proj, N)(k)
case sel @ SynthSel(prefix, nme) =>
subTerm(prefix): p =>
k(Select(p, nme)(sel.sym))

case sel @ Sel(prefix, nme) =>
setupSelection(prefix, nme, sel.sym)(k)


case New(cls, as) =>
subTerm(cls): sr =>
def rec(as: Ls[st], asr: Ls[Path]): Block = as match
Expand All @@ -296,7 +294,22 @@ class Lowering(using TL, Raise, Elaborator.State):
term(finallyDo)(_ => End()),
k(Value.Ref(l))
)

NeilKleistGao marked this conversation as resolved.
Show resolved Hide resolved
case Region(reg, body) =>
Assign(reg, Instantiate(Select(Value.Ref(State.globalThisSymbol), Tree.Ident("Region"))(N), Nil), term(body)(k))
case RegRef(reg, value) =>
def rec(as: Ls[st], asr: Ls[Path]): Block = as match
case Nil => k(Instantiate(Select(Value.Ref(State.globalThisSymbol), Tree.Ident("Ref"))(N), asr.reverse))
case a :: as =>
subTerm(a): ar =>
rec(as, ar :: asr)
rec(reg :: value :: Nil, Nil)
case Deref(ref) =>
subTerm(ref): r =>
k(Select(r, Tree.Ident("value"))(N))
case SetRef(lhs, rhs) =>
subTerm(lhs): ref =>
subTerm(rhs): value =>
AssignField(ref, Tree.Ident("value"), value, k(value))(N)
case Error => End("error")

// case _ =>
Expand Down
15 changes: 11 additions & 4 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ class JSBuilder(using Elaborator.State, Elaborator.Ctx) extends CodeBuilder:
case N =>
ts.id.name
case ts: semantics.BlockMemberSymbol => // this means it's a locally-defined member
ts.nme
ts.defn match
case S(semantics.TermDefinition(_, syntax.Fun, _, Nil, _, _, _, _)) => doc"${ts.nme}()"
case _ => doc"${ts.nme}"
// ts.trmTree
case ts: semantics.InnerSymbol =>
summon[Scope].findThis_!(ts)
Expand Down Expand Up @@ -92,7 +94,12 @@ class JSBuilder(using Elaborator.State, Elaborator.Ctx) extends CodeBuilder:
else err(msg"Cannot call non-unary builtin symbol '${l.nme}'")
case Call(Value.Ref(l: BuiltinSymbol), args) =>
err(msg"Illeal arity for builtin symbol '${l.nme}'")

NeilKleistGao marked this conversation as resolved.
Show resolved Hide resolved
case Call(s @ Select(_, id), lhs :: rhs :: Nil) =>
Elaborator.ctx.Builtins.tryMapOp(id.name) match
case S(jsOp) =>
val res = doc"${result(lhs)} ${jsOp} ${result(rhs)}"
if needsParens(jsOp) then doc"(${res})" else res
case N => setupCall(result(s), (result(lhs) :: result(rhs) :: Nil).mkDocument(", "))
case Call(fun, args) =>
val base = fun match
case _: Value.Lam => doc"(${result(fun)})"
Expand All @@ -107,7 +114,7 @@ class JSBuilder(using Elaborator.State, Elaborator.Ctx) extends CodeBuilder:
val name = id.name
doc"${result(qual)}${
if JSBuilder.isValidFieldName(name)
then doc".$name"
then doc".$name${if Elaborator.ctx.isGetter(name) then "()" else ""}"
else name.toIntOption match
case S(index) => s"[$index]"
case N => s"[${JSBuilder.makeStringLiteral(name)}]"
Expand Down Expand Up @@ -146,7 +153,7 @@ class JSBuilder(using Elaborator.State, Elaborator.Ctx) extends CodeBuilder:
S(defn.sym).collectFirst{ case s: InnerSymbol => s }):
defn match
case FunDefn(sym, Nil, body) =>
TODO("getters")
doc"function ${sym.nme}() { #{ # ${this.body(body)} #} # }"
case FunDefn(sym, ps :: pss, bod) =>
val result = pss.foldRight(bod):
case (ps, block) =>
Expand Down
34 changes: 26 additions & 8 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,30 @@ object Elaborator:

def withMembers(members: Iterable[Str -> MemberSymbol[?]], out: Opt[Symbol] = N): Ctx =
copy(env = env ++ members.map:
case (nme, sym) => nme -> (
out orElse outer match
case S(outer) => Ctx.SelElem(outer, sym.nme, S(sym))
case N => sym: Ctx.Elem
)
case (nme, sym) =>
val elem = out orElse outer match
case S(outer) => Ctx.SelElem(outer, sym.nme, S(sym))
case N => sym: Ctx.Elem
sym match
case sym: BlockMemberSymbol =>
sym.trees.headOption match // TODO: Support module overloading
case S(td @ TermDef(Fun, _, _)) if td.rhs.isDefined && td.paramLists.isEmpty =>
nme -> Ctx.GetElem(elem)
case _ => nme -> elem
case _ => sym.defn match
case S(TermDefinition(_, Fun, _, Nil, _, _, _, _)) => nme -> Ctx.GetElem(elem)
case _ => nme -> elem
)

def nest(outer: Opt[InnerSymbol]): Ctx = Ctx(outer, Some(this), Map.empty)

def get(name: Str): Opt[Ctx.Elem] =
env.get(name).orElse(parent.flatMap(_.get(name)))
private def getImpl(name: Str): Opt[Ctx.Elem] = env.get(name).orElse(parent.flatMap(_.getImpl(name)))
def isGetter(name: Str): Bool = getImpl(name) match
NeilKleistGao marked this conversation as resolved.
Show resolved Hide resolved
case S(_: Ctx.GetElem) => true
case _ => false
def get(name: Str): Opt[Ctx.Elem] = getImpl(name).map:
case Ctx.GetElem(base) => base
case elem => elem
def getOuter: Opt[InnerSymbol] = outer.orElse(parent.flatMap(_.getOuter))

// * Invariant: We expect that the top-level context only contain hard-coded symbols like `globalThis`
Expand All @@ -78,6 +91,7 @@ object Elaborator:
val Num = assumeBuiltinCls("Num")
val Str = assumeBuiltinCls("Str")
val Predef = assumeBuiltinMod("Predef")
def tryMapOp(op: Str): Opt[Str] = aliasOps.get(op)

object Ctx:
abstract class Elem:
Expand All @@ -97,6 +111,10 @@ object Elaborator:
Term.SynthSel(base.ref(Ident(base.nme)),
new Tree.Ident(nme).withLocOf(id))(symOpt)
def symbol = symOpt
final case class GetElem(val base: Elem) extends Elem:
def nme: Str = base.nme
def ref(id: Tree.Ident): Term = base.ref(id)
NeilKleistGao marked this conversation as resolved.
Show resolved Hide resolved
def symbol: Opt[Symbol] = base.symbol
given Conversion[Symbol, Elem] = RefElem(_)
val empty: Ctx = Ctx(N, N, Map.empty)

Expand Down Expand Up @@ -279,7 +297,7 @@ extends Importer:
case App(Ident("&"), Tree.Tup(lhs :: rhs :: Nil)) =>
Term.CompType(term(lhs), term(rhs), false)
case App(Ident(":="), Tree.Tup(lhs :: rhs :: Nil)) =>
Term.Assgn(term(lhs), term(rhs))
Term.SetRef(term(lhs), term(rhs))
case App(Ident("#"), Tree.Tup(SynthSel(pre, idn: Ident) :: (idp: Ident) :: Nil)) =>
Term.SelProj(term(pre), term(idn), idp)
case App(Ident("#"), Tree.Tup(SynthSel(pre, Ident(name)) :: App(Ident(proj), args) :: Nil)) =>
Expand Down
5 changes: 5 additions & 0 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ enum Term extends Statement:
case RegRef(reg: Term, value: Term)
case Assgn(lhs: Term, rhs: Term)
case Deref(ref: Term)
case SetRef(ref: Term, value: Term)
case Ret(result: Term)
case Throw(result: Term)
case Try(body: Term, finallyDo: Term)
Expand Down Expand Up @@ -71,7 +72,9 @@ enum Term extends Statement:
case Region(name, body) => "region expression"
case RegRef(reg, value) => "reference creation"
case Assgn(lhs, rhs) => "assignment"
case SetRef(ref, value) => "set"
NeilKleistGao marked this conversation as resolved.
Show resolved Hide resolved
case Deref(ref) => "dereference"
case Throw(e) => "throw"
end Term

import Term.*
Expand Down Expand Up @@ -111,6 +114,7 @@ sealed trait Statement extends AutoLocated with ProductWithExtraInfo:
case Region(_, body) => body :: Nil
case RegRef(reg, value) => reg :: value :: Nil
case Assgn(lhs, rhs) => lhs :: rhs :: Nil
case SetRef(lhs, rhs) => lhs :: rhs :: Nil
case Deref(term) => term :: Nil
case TermDefinition(_, k, _, ps, sign, body, res, _) =>
ps.toList.flatMap(_.subTerms) ::: sign.toList ::: body.toList
Expand Down Expand Up @@ -178,6 +182,7 @@ sealed trait Statement extends AutoLocated with ProductWithExtraInfo:
case Region(name, body) => s"region ${name.nme} in ${body.showDbg}"
case RegRef(reg, value) => s"(${reg.showDbg}).ref ${value.showDbg}"
case Assgn(lhs, rhs) => s"${lhs.showDbg} := ${rhs.showDbg}"
case SetRef(lhs, rhs) => s"${lhs.showDbg} := ${rhs.showDbg}"
case Deref(term) => s"!$term"
case CompType(lhs, rhs, pol) => s"${lhs.showDbg} ${if pol then "|" else "&"} ${rhs.showDbg}"
case Error => "<error>"
Expand Down
31 changes: 31 additions & 0 deletions hkmc2/shared/src/test/mlscript-compile/bbml/Predef.mjs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
const Predef$class = class Predef {
constructor() {

}
checkArgs(functionName, expected, got) {
let scrut, name, scrut1, tmp, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6;
scrut = got != expected;
if (scrut) {
scrut1 = functionName.length > 0;
if (scrut1) {
tmp = " '".concat(functionName);
tmp1 = tmp.concat("'");
} else {
tmp1 = "";
}
name = tmp1;
tmp2 = "Function".concat(name);
tmp3 = tmp2.concat(" expected ");
tmp4 = tmp3.concat(expected);
tmp5 = tmp4.concat(" arguments but got ");
tmp6 = tmp5.concat(got);
throw new Error.class(tmp6);
} else {
return undefined;
}
}
toString() { return "Predef"; }
}; const Predef = new Predef$class;
Predef.class = Predef$class;
undefined
export default Predef;
7 changes: 7 additions & 0 deletions hkmc2/shared/src/test/mlscript-compile/bbml/Predef.mls
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
module Predef with ...

fun checkArgs(functionName, expected, got) =
if got != expected then
let name = if functionName.Str#length > 0 then " '".Str#concat(functionName).Str#concat("'") else ""
throw new Error("Function".Str#concat(name).Str#concat(" expected ").Str#concat(expected).Str#concat(" arguments but got ").Str#concat(got))
else ()
1 change: 0 additions & 1 deletion hkmc2/shared/src/test/mlscript/basics/Getters.mls
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ get foo = 123
// * These two should be equivalent and should *not* be getters:

fun f = x => x
//│ /!!!\ Uncaught error: scala.NotImplementedError: getters (of class String)

fun f(x) = x

Expand Down
Loading
Loading