Skip to content

Commit

Permalink
Support matching on builtins and improve builtin symbols
Browse files Browse the repository at this point in the history
  • Loading branch information
LPTK committed Nov 29, 2024
1 parent 11d66ad commit 7158626
Show file tree
Hide file tree
Showing 19 changed files with 243 additions and 78 deletions.
2 changes: 1 addition & 1 deletion hkmc2/jvm/src/test/scala/hkmc2/DiffTestRunner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ abstract class CompileTestRunner(state: DiffTestRunner.State)

println(s"Compiling: $relativeName")

val preludePath = dir/"decls"/"Prelude.mls"
val preludePath = dir/"mlscript"/"decls"/"Prelude.mls"

MLsCompiler(preludePath).compileModule(file)

Expand Down
9 changes: 6 additions & 3 deletions hkmc2/jvm/src/test/scala/hkmc2/JSBackendDiffMaker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import scala.collection.mutable
import mlscript.utils.*, shorthands.*
import utils.*

import semantics.*
import codegen.*
import codegen.js.{JSBuilder, JSBuilderArgNumSanityChecks, JSBuilderSelSanityChecks}
import document.*
import codegen.Block
Expand Down Expand Up @@ -49,9 +51,10 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker:
if js.isSet then
val low = ltl.givenIn:
new codegen.Lowering with codegen.LoweringSelSanityChecks(noSanityCheck.isUnset)
val jsb = new JSBuilder with JSBuilderArgNumSanityChecks(noSanityCheck.isUnset) with JSBuilderSelSanityChecks(noSanityCheck.isUnset)
import semantics.*
import codegen.*
given Elaborator.Ctx = curCtx
val jsb = new JSBuilder
with JSBuilderArgNumSanityChecks(noSanityCheck.isUnset)
with JSBuilderSelSanityChecks(noSanityCheck.isUnset)
val le = low.program(blk)
if showLoweredTree.isSet then
output(s"Lowered:")
Expand Down
9 changes: 7 additions & 2 deletions hkmc2/jvm/src/test/scala/hkmc2/MLsDiffMaker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ abstract class MLsDiffMaker extends DiffMaker:
val showUCS = Command("ucs"): ln =>
ln.split(" ").iterator.map(x => "ucs:" + x.trim).toSet

given Elaborator.State = new Elaborator.State
given Elaborator.State = new Elaborator.State:
override def dbg: Bool =
dbgParsing.isSet
|| dbgElab.isSet
|| debug.isSet

val etl = new TraceLogger:
override def doTrace = dbgElab.isSet || scope.exists:
Expand All @@ -63,11 +67,12 @@ abstract class MLsDiffMaker extends DiffMaker:
if doTrace then super.trace(pre, post)(thunk)
else thunk

var curCtx = Elaborator.State.init.nest(N)
var curCtx = Elaborator.State.init


override def run(): Unit =
if file =/= preludeFile then importFile(preludeFile, verbose = false)
curCtx = curCtx.nest(N)
super.run()


Expand Down
77 changes: 47 additions & 30 deletions hkmc2/shared/src/main/scala/hkmc2/MLsCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,30 @@ import hkmc2.syntax.Keyword.`override`
import semantics.Elaborator.State


class ParserSetup(file: os.Path, dbgParsing: Bool)(using Elaborator.State, Raise):

val block = os.read(file)
val fph = new FastParseHelpers(block)
val origin = Origin(file.toString, 0, fph)

val lexer = new syntax.Lexer(origin, dbg = dbgParsing)
val tokens = lexer.bracketedTokens

// if showParse.isSet || dbgParsing.isSet then
// output(syntax.Lexer.printTokens(tokens))

val rules = syntax.ParseRules()
val parser = new syntax.Parser(origin, tokens, rules, raise, dbg = dbgParsing):
def doPrintDbg(msg: => Str): Unit =
// if dbg then output(msg)
if dbg then println(msg)

val result = parser.parseAll(parser.block(allowNewlines = true))

val resultBlk = new syntax.Tree.Block(result)



class MLsCompiler(preludeFile: os.Path):


Expand All @@ -29,40 +53,33 @@ class MLsCompiler(preludeFile: os.Path):

def compileModule(file: os.Path): Unit =

val block = os.read(file)
val fph = new FastParseHelpers(block)
val origin = Origin(file.toString, 0, fph)

val lexer = new syntax.Lexer(origin, dbg = dbgParsing)
val tokens = lexer.bracketedTokens
given Elaborator.State = new Elaborator.State

// if showParse.isSet || dbgParsing.isSet then
// output(syntax.Lexer.printTokens(tokens))
val preludeParse = ParserSetup(preludeFile, dbgParsing)
val mainParse = ParserSetup(file, dbgParsing)

given Elaborator.State = new Elaborator.State
val rules = syntax.ParseRules()
val p = new syntax.Parser(origin, tokens, rules, raise, dbg = dbgParsing):
def doPrintDbg(msg: => Str): Unit =
// if dbg then output(msg)
if dbg then println(msg)
val res = p.parseAll(p.block(allowNewlines = true))
given Elaborator.Ctx = State.init.nest(N)
val wd = file / os.up
val elab = Elaborator(etl, wd)
val resBlk = new syntax.Tree.Block(res)
val (blk, newCtx) = elab.importFrom(resBlk)
val low = ltl.givenIn:
codegen.Lowering()
val jsb = codegen.js.JSBuilder()
val le = low.program(blk)
val baseScp: codegen.js.Scope =
codegen.js.Scope.empty
val nestedScp = baseScp.nest
val je = nestedScp.givenIn:
jsb.program(le, S(file.baseName), wd)
val jsStr = je.stripBreaks.mkString(100)
val out = file / os.up / (file.baseName + ".mjs")
os.write.over(out, jsStr)

val initState = State.init.nest(N)

val (pblk, newCtx) = elab.importFrom(preludeParse.resultBlk)(using initState)

newCtx.nest(N).givenIn:

val (blk, newCtx) = elab.importFrom(mainParse.resultBlk)
val low = ltl.givenIn:
codegen.Lowering()
val jsb = codegen.js.JSBuilder()
val le = low.program(blk)
val baseScp: codegen.js.Scope =
codegen.js.Scope.empty
val nestedScp = baseScp.nest
val je = nestedScp.givenIn:
jsb.program(le, S(file.baseName), wd)
val jsStr = je.stripBreaks.mkString(100)
val out = file / os.up / (file.baseName + ".mjs")
os.write.over(out, jsStr)


end MLsCompiler
Expand Down
12 changes: 0 additions & 12 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,7 @@ case class Define(defn: Defn, rest: Block) extends Block with ProductWithTail
sealed abstract class Defn:
val sym: MemberSymbol[?]

// final case class TermDefn(
// k: syntax.TermDefKind,
// // sym: TermSymbol,
// sym: BlockMemberSymbol,
// params: Ls[ParamList],
// body: Block,
// ) extends Defn
final case class FunDefn(
// k: syntax.TermDefKind,
// sym: TermSymbol,
sym: BlockMemberSymbol,
params: Ls[ParamList],
body: Block,
Expand All @@ -101,13 +92,10 @@ final case class ValDefn(
owner: Opt[InnerSymbol],
k: syntax.Val,
sym: BlockMemberSymbol,
// params: Ls[ParamList],
rhs: Path,
) extends Defn

final case class ClsLikeDefn(
// sym: ClassSymbol,
// sym: MemberSymbol[ClassLikeDef],
sym: MemberSymbol[? <: ClassLikeDef],
k: syntax.ClsLikeKind,
methods: Ls[FunDefn],
Expand Down
10 changes: 7 additions & 3 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ abstract class CodeBuilder:
type Context


class JSBuilder(using Elaborator.State) extends CodeBuilder:
class JSBuilder(using Elaborator.State, Elaborator.Ctx) extends CodeBuilder:

val builtinOpsBase: Ls[Str] = Ls(
"+", "-", "*", "/", "%",
Expand Down Expand Up @@ -238,10 +238,14 @@ class JSBuilder(using Elaborator.State) extends CodeBuilder:
case N => doc""
t :: e :: returningTerm(rest)
case Match(scrut, Case.Cls(cls, pth) -> trm :: Nil, els, rest) =>
val sd = result(scrut)
val test = cls match
// case _: semantics.ModuleSymbol => doc"=== ${result(pth)}"
case _ => doc"instanceof ${result(pth)}"
val t = doc" # if (${ result(scrut) } $test) { #{ ${
case Elaborator.ctx.Builtins.Str => doc"typeof $sd === 'string'"
case Elaborator.ctx.Builtins.Num => doc"typeof $sd === 'number'"
case Elaborator.ctx.Builtins.Int => doc"globalThis.Number.isInteger($sd)"
case _ => doc"$sd instanceof ${result(pth)}"
val t = doc" # if ($test) { #{ ${
returningTerm(trm)
} #} # }"
val e = els match
Expand Down
37 changes: 32 additions & 5 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,13 @@ object Elaborator:
val reservedNames = binaryOps.toSet ++ aliasOps.keySet + "NaN" + "Infinity"

case class Ctx(outer: Opt[InnerSymbol], parent: Opt[Ctx], env: Map[Str, Ctx.Elem]):

def +(local: Str -> Symbol): Ctx = copy(outer, env = env + local.mapSecond(Ctx.RefElem(_)))
def ++(locals: IterableOnce[Str -> Symbol]): Ctx =
copy(outer, env = env ++ locals.mapValues(Ctx.RefElem(_)))
def elem_++(locals: IterableOnce[Str -> Ctx.Elem]): Ctx =
copy(outer, env = env ++ locals)

def withMembers(members: Iterable[Str -> MemberSymbol[?]], out: Opt[Symbol] = N): Ctx =
copy(env = env ++ members.map:
case (nme, sym) => nme -> (
Expand All @@ -47,14 +49,29 @@ object Elaborator:
case N => sym: Ctx.Elem
)
)

def nest(outer: Opt[InnerSymbol]): Ctx = Ctx(outer, Some(this), Map.empty)

def get(name: Str): Opt[Ctx.Elem] =
env.get(name).orElse(parent.flatMap(_.get(name)))
def getOuter: Opt[InnerSymbol] = outer.orElse(parent.flatMap(_.getOuter))
lazy val allMembers: Map[Str, Symbol] =
parent.fold(Map.empty)(_.allMembers) ++ env.flatMap:
case (n, re: Ctx.RefElem) => (n, re.sym) :: Nil
case _ => Nil // FIXME?

// * Invariant: We expect that the top-level context only contain hard-coded symbols like `globalThis`
// * and that built-in symbols like Int and Str be imported into another nested context on top of it.
// * It should not be possible to shadow these built-in symbols, so user code should always be compiled
// * in further nested contexts.
// * Method `getBuiltin` is used to look up built-in symbols in the context of builtin symbols.
def getBuiltin(nme: Str): Opt[Ctx.Elem] =
parent.filter(_.parent.nonEmpty).fold(env.get(nme))(_.getBuiltin(nme))
object Builtins:
private def assumeBuiltinCls(nme: Str): ClassSymbol =
getBuiltin(nme)
.getOrElse(throw new NoSuchElementException(s"builtin $nme ${env.keySet} $parent"))
.symbol.getOrElse(throw new NoSuchElementException(s"builtin symbol $nme"))
.asCls.getOrElse(throw new NoSuchElementException(s"builtin class symbol $nme"))
val Int = assumeBuiltinCls("Int")
val Num = assumeBuiltinCls("Num")
val Str = assumeBuiltinCls("Str")

object Ctx:
abstract class Elem:
Expand All @@ -76,8 +93,11 @@ object Elaborator:
def symbol = symOpt
given Conversion[Symbol, Elem] = RefElem(_)
val empty: Ctx = Ctx(N, N, Map.empty)

type Ctxl[A] = Ctx ?=> A
def ctx: Ctxl[Ctx] = summon

transparent inline def ctx(using Ctx): Ctx = summon

class State:
given State = this
val suid = new Uid.Symbol.State
Expand All @@ -86,9 +106,16 @@ object Elaborator:
def init(using State): Ctx = Ctx.empty.copy(env = Map(
"globalThis" -> globalThisSymbol,
))
def dbg: Bool = false
def dbgUid(uid: Uid[Symbol]): Str = if dbg then s"$uid" else ""
transparent inline def State(using state: State): State = state

end Elaborator


import Elaborator.*


class Elaborator(val tl: TraceLogger, val wd: os.Path)
(using val raise: Raise, val state: State)
extends Importer:
Expand Down
16 changes: 8 additions & 8 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ class FlowSymbol(label: Str)(using State) extends Symbol:
val outFlows2: mutable.Buffer[Consumer] = mutable.Buffer.empty
val inFlows: mutable.Buffer[ConcreteProd] = mutable.Buffer.empty
override def toString: Str =
label
// s"$label@$uid"
label + State.dbgUid(uid)


sealed trait LocalSymbol extends Symbol
Expand All @@ -86,7 +85,7 @@ class BuiltinSymbol
(val nme: Str, val binary: Bool, val unary: Bool, val nullary: Bool)(using State)
extends Symbol:
def toLoc: Option[Loc] = N
override def toString: Str = s"builtin:$nme"
override def toString: Str = s"builtin:$nme${State.dbgUid(uid)}"


/** This is the outside-facing symbol associated to a possibly-overloaded
Expand All @@ -110,7 +109,8 @@ class BlockMemberSymbol(val nme: Str, val trees: Ls[Tree])(using State)
lazy val hasLiftedClass: Bool =
modTree.isDefined || trmTree.isDefined || clsTree.exists(_.paramLists.nonEmpty)

override def toString: Str = s"member:$nme"
override def toString: Str =
s"member:$nme${State.dbgUid(uid)}"

end BlockMemberSymbol

Expand Down Expand Up @@ -157,25 +157,25 @@ class ClassSymbol(val tree: Tree.TypeDef, val id: Tree.Ident)(using State)
extends MemberSymbol[ClassDef] with CtorSymbol with InnerSymbol:
def nme = id.name
def toLoc: Option[Loc] = id.toLoc // TODO track source tree of classe here
override def toString: Str = s"class:$nme"
override def toString: Str = s"class:$nme${State.dbgUid(uid)}"
/** Compute the arity. */
def arity: Int = tree.paramLists.headOption.fold(0)(_.fields.length)

class ModuleSymbol(val tree: Tree.TypeDef, val id: Tree.Ident)(using State)
extends MemberSymbol[ModuleDef] with CtorSymbol with InnerSymbol:
def nme = id.name
def toLoc: Option[Loc] = id.toLoc // TODO track source tree of module here
override def toString: Str = s"module:${id.name}"
override def toString: Str = s"module:${id.name}${State.dbgUid(uid)}"

class TypeAliasSymbol(val id: Tree.Ident)(using State) extends MemberSymbol[TypeDef]:
def nme = id.name
def toLoc: Option[Loc] = id.toLoc // TODO track source tree of type alias here
override def toString: Str = s"module:${id.name}"
override def toString: Str = s"module:${id.name}${State.dbgUid(uid)}"

class TopLevelSymbol(blockNme: Str)(using State)
extends MemberSymbol[ModuleDef] with InnerSymbol:
def nme = blockNme
def toLoc: Option[Loc] = N
override def toString: Str = s"globalThis:$blockNme"
override def toString: Str = s"globalThis:$blockNme${State.dbgUid(uid)}"


2 changes: 1 addition & 1 deletion hkmc2/shared/src/main/scala/hkmc2/syntax/Tree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ enum Tree extends AutoLocated:
case Jux(lhs, rhs) => "juxtaposition"
case SynthSel(prefix, name) => "synthetic selection"
case Sel(prefix, name) => "selection"
case InfixApp(lhs, kw, rhs) => "infix application"
case InfixApp(lhs, kw, rhs) => "infix operation"
case New(body) => "new"
case IfLike(Keyword.`if`, split) => "if expression"
case IfLike(Keyword.`while`, split) => "while expression"
Expand Down
15 changes: 15 additions & 0 deletions hkmc2/shared/src/test/mlscript-compile/Example.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,21 @@ const Example$class = class Example {
}
inc(x) {
return x + 1;
}
test(x1) {
if (globalThis.Number.isInteger(x1)) {
return "int";
} else {
if (typeof x1 === 'number') {
return "num";
} else {
if (typeof x1 === 'string') {
return "str";
} else {
return "other";
}
}
}
}
toString() { return "Example"; }
}; const Example = new Example$class;
Expand Down
6 changes: 6 additions & 0 deletions hkmc2/shared/src/test/mlscript-compile/Example.mls
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,10 @@ fun (/) funnySlash(f, arg) = f(arg)

fun inc(x) = x + 1

fun test(x) = if x is
Int then "int"
Num then "num"
Str then "str"
else "other"


3 changes: 0 additions & 3 deletions hkmc2/shared/src/test/mlscript-compile/Predef.mls
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@

declare val String
declare val console

module Predef with ...

fun id(x) = x
Expand Down
Loading

0 comments on commit 7158626

Please sign in to comment.