Skip to content

Commit

Permalink
Improve implementation of builtin operators
Browse files Browse the repository at this point in the history
  • Loading branch information
LPTK committed Nov 13, 2024
1 parent ff2c7a7 commit f3c4ccb
Show file tree
Hide file tree
Showing 34 changed files with 319 additions and 291 deletions.
2 changes: 1 addition & 1 deletion hkmc2/jvm/src/test/scala/hkmc2/BbmlDiffMaker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ abstract class BbmlDiffMaker extends JSBackendDiffMaker:

lazy val bbCtx =
given Elaborator.Ctx = curCtx
bbml.BbCtx.init(_ => die, curCtx.allMembers)
bbml.BbCtx.init(_ => die)


var bbmlTyper: Opt[BBTyper] = None
Expand Down
3 changes: 2 additions & 1 deletion hkmc2/jvm/src/test/scala/hkmc2/DiffMaker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ abstract class DiffMaker:
val expectTypeErrors = NullaryCommand("e")
val expectRuntimeErrors = NullaryCommand("re")
val expectCodeGenErrors = NullaryCommand("ge")
def expectRuntimeOrCodeGenErrors = expectRuntimeErrors.isSet || expectCodeGenErrors.isSet
val allowRuntimeErrors = NullaryCommand("allowRuntimeErrors")
val expectWarnings = NullaryCommand("w")
val showRelativeLineNums = NullaryCommand("showRelativeLineNums")
Expand Down Expand Up @@ -174,7 +175,7 @@ abstract class DiffMaker:
unexpected("runtime error", blockLineNum)
case Diagnostic.Source.Runtime =>
runtimeErrors += 1
if expectRuntimeErrors.isUnset && !tolerateErrors then
if !expectRuntimeOrCodeGenErrors && !tolerateErrors then
failures += globalStartLineNum
unexpected("runtime error", blockLineNum)
case Diagnostic.Kind.Warning =>
Expand Down
2 changes: 1 addition & 1 deletion hkmc2/jvm/src/test/scala/hkmc2/JSBackendDiffMaker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker:
output(jsStr)
def mkQuery(prefix: Str, jsStr: Str) =
val queryStr = jsStr.replaceAll("\n", " ")
val (reply, stderr) = host.query(queryStr, expectRuntimeErrors.isUnset && fixme.isUnset && todo.isUnset)
val (reply, stderr) = host.query(queryStr, !expectRuntimeOrCodeGenErrors && fixme.isUnset && todo.isUnset)
reply match
case ReplHost.Result(content, stdout) =>
if silent.isUnset then
Expand Down
2 changes: 1 addition & 1 deletion hkmc2/shared/src/main/scala/hkmc2/bbml/bbML.scala
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ object BbCtx:
ClassLikeType(ctx.getCls("Region").get, Wildcard(sk, sk) :: Nil)
def refTy(ct: Type, sk: Type)(using ctx: BbCtx): Type =
ClassLikeType(ctx.getCls("Ref").get, Wildcard(ct, ct) :: Wildcard.out(sk) :: Nil)
def init(raise: Raise, predefs: Map[Str, Symbol])(using Elaborator.State, Elaborator.Ctx): BbCtx =
def init(raise: Raise)(using Elaborator.State, Elaborator.Ctx): BbCtx =
new BbCtx(raise, summon, None, 1, HashMap.empty, HashMap.empty, HashMap.empty)
end BbCtx

Expand Down
31 changes: 1 addition & 30 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,35 +58,7 @@ class Lowering(using TL, Raise, Elaborator.State):
args => subTerm(a.value)(r => acc(r :: args))
)(Nil)
case st.Ref(sym) =>
sym match
case sym: BlockMemberSymbol =>
// k(subst(Value.Ref(sym.modTree.get.symbol)))
k(subst(Value.Ref(sym)))
case sym: LocalSymbol =>
k(subst(Value.Ref(sym)))
case sym: ClassSymbol =>
k(subst(Value.Ref(sym)))
case sym: ModuleSymbol =>
k(subst(Value.Ref(sym)))
case sym: TopLevelSymbol =>
k(subst(Value.Ref(sym)))
/* // * Old logic that auto-lifted `C` ~> `(...) => new C(...)`
case sym: ClassSymbol => // TODO rm
// k(subst(Value.Ref(sym)))
sym.defn match
case N => End("error: class has no declaration") // TODO report?
case S(clsDefn) =>
if clsDefn.kind is syntax.Mod then
k(Value.Ref(sym))
else
val ps = clsDefn.paramsOpt.getOrElse(Nil)
val psSyms = ps.map(p =>
p.copy(sym = new VarSymbol(p.sym.id, summon[Elaborator.State].nextUid)))
k(Value.Lam(psSyms,
Return(Instantiate(Value.Ref(sym),
psSyms.map(p => Value.Ref(p.sym))), false)))
*/
// * Perhaps this `new` insertion should also be removed...?
k(subst(Value.Ref(sym)))
case st.App(f, arg) =>
arg match
case Tup(fs) =>
Expand All @@ -104,7 +76,6 @@ class Lowering(using TL, Raise, Elaborator.State):
rec(as, Nil)
case _ =>
TODO("Other argument list forms")

case st.Blk(Nil, res) => term(res)(k)
case st.Blk(Lit(Tree.UnitLit(true)) :: stats, res) =>
subTerm(st.Blk(stats, res))(k)
Expand Down
48 changes: 23 additions & 25 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import hkmc2.semantics.InnerSymbol
import hkmc2.semantics.ParamList
import hkmc2.codegen.Value.Lam
import hkmc2.semantics.BlockMemberSymbol
import hkmc2.semantics.BuiltinSymbol
import hkmc2.Message.MessageContext


// TODO factor some logic for other codegen backends
Expand Down Expand Up @@ -43,6 +45,11 @@ class JSBuilder extends CodeBuilder:
case Argument
case Operand(prec: Int)

def err(errMsg: Message)(using Raise, Scope): Document =
raise(ErrorReport(errMsg -> N :: Nil,
source = Diagnostic.Source.Compilation))
doc"(()=>{throw globalThis.Error(${result(Value.Lit(syntax.Tree.StrLit(errMsg.show)))})})()"

def getVar(l: Local)(using Raise, Scope): Document = l match
case ts: semantics.TermSymbol =>
ts.owner match
Expand All @@ -63,34 +70,25 @@ class JSBuilder extends CodeBuilder:

def result(r: Result)(using Raise, Scope): Document = r match
case Value.This(sym) => summon[Scope].findThis_!(sym)
case Value.Ref(l) => getVar(l)
case Value.Lit(Tree.StrLit(value)) => JSBuilder.makeStringLiteral(value)
case Value.Lit(lit) => lit.idStr
case Value.Ref(l: BuiltinSymbol) =>
if l.nullary then l.nme
else err(msg"Illegal reference to builtin symbol '${l.nme}'")
case Value.Ref(l) => getVar(l)


// * FIXME: this should be done in the Elaborator

// case Call(Value.Ref(l: semantics.InnerSymbol), lhs :: rhs :: Nil) if builtinOpsMap contains l.nme =>
// case Call(Value.Ref(l), lhs :: rhs :: Nil) if builtinOpsMap contains l.nme =>

case Call(Select(Value.Ref(_: TopLevelSymbol), Tree.Ident(nme)), lhs :: rhs :: Nil) if builtinOpsMap contains nme =>
val op = builtinOpsMap(nme)
val res = doc"${result(lhs)} ${op} ${result(rhs)}"
if needsParens(op) then doc"(${res})" else res
case Call(Select(Value.Ref(_: TopLevelSymbol), Tree.Ident(nme)), lhs :: Nil) if builtinOpsMap contains nme =>
val op = builtinOpsMap(nme)
val res = doc"${op} ${result(lhs)}"
if needsParens(op) then doc"(${res})" else res

case Call(Value.Ref(sym: BlockMemberSymbol), lhs :: rhs :: Nil) if builtinOpsMap contains sym.nme =>
val op = builtinOpsMap(sym.nme)
val res = doc"${result(lhs)} ${op} ${result(rhs)}"
if needsParens(op) then doc"(${res})" else res
case Call(Value.Ref(sym: BlockMemberSymbol), lhs :: Nil) if builtinOpsMap contains sym.nme =>
val op = builtinOpsMap(sym.nme)
val res = doc"${op} ${result(lhs)}"
if needsParens(op) then doc"(${res})" else res

case Call(Value.Ref(l: BuiltinSymbol), lhs :: rhs :: Nil) =>
if l.binary then
val res = doc"${result(lhs)} ${l.nme} ${result(rhs)}"
if needsParens(l.nme) then doc"(${res})" else res
else err(msg"Cannot call non-binary builtin symbol '${l.nme}'")
case Call(Value.Ref(l: BuiltinSymbol), rhs :: Nil) =>
if l.unary then
val res = doc"${l.nme} ${result(rhs)}"
if needsParens(l.nme) then doc"(${res})" else res
else err(msg"Cannot call non-unary builtin symbol '${l.nme}'")
case Call(Value.Ref(l: BuiltinSymbol), args) =>
err(msg"Illeal arity for builtin symbol '${l.nme}'")

case Call(fun, args) =>
val base = fun match
Expand Down
8 changes: 1 addition & 7 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Desugarer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -202,13 +202,7 @@ class Desugarer(tl: TraceLogger, elaborator: Elaborator)
post = (res: Split) => s"termSplit: after op >>> $res"
):
// Resolve the operator.
val opRef =
ctx.get(opName) match
case S(sym) => sym.ref(opIdent)
case N =>
raise(ErrorReport(msg"Name not found: $opName" -> tree.toLoc :: Nil))
Term.Error
.withLocOf(opIdent)
val opRef = term(opIdent)
// Elaborate and finish the LHS. Nominate the LHS if necessary.
nominate(ctx, finish(term(lhs)(using ctx))): lhsSymbol =>
// Compose a function that takes the RHS and finishes the application.
Expand Down
25 changes: 23 additions & 2 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,24 @@ import Keyword.{`let`, `set`}

object Elaborator:

val builtinOpsMap: Map[Str, BuiltinSymbol] =
val binOps: Ls[Str] = Ls(
",",
"+", "-", "*", "/", "%",
"==", "!=", "<", "<=", ">", ">=",
"===",
"&&", "||")
val isUnary: Str => Bool = Set("-", "+", "!", "~").contains
val baseBuiltins = binOps.map: op =>
op -> BuiltinSymbol(op, binary = true, unary = isUnary(op), nullary = false)
.toMap
baseBuiltins
+ (";" -> baseBuiltins(","))
+ ("+." -> baseBuiltins("+"))
+ ("-." -> baseBuiltins("-"))
+ ("*." -> baseBuiltins("*"))
val reservedNames = builtinOpsMap.keySet + "NaN" + "Infinity"

case class Ctx(outer: Opt[InnerSymbol], parent: Opt[Ctx], env: Map[Str, Ctx.Elem]):
def +(local: Str -> Symbol): Ctx = copy(outer, env = env + local.mapSecond(Ctx.RefElem(_)))
def ++(locals: IterableOnce[Str -> Symbol]): Ctx =
Expand Down Expand Up @@ -171,8 +189,11 @@ extends Importer:
ctx.get(name) match
case S(sym) => sym.ref(id)
case N =>
raise(ErrorReport(msg"Name not found: $name" -> tree.toLoc :: Nil))
Term.Error
builtinOpsMap.get(name) match
case S(bi) => bi.ref(id)
case N =>
raise(ErrorReport(msg"Name not found: $name" -> tree.toLoc :: Nil))
Term.Error
case TyApp(lhs, targs) =>
Term.TyApp(term(lhs), targs.map {
case Modified(Keyword.`in`, inLoc, arg) => Term.WildcardTy(S(term(arg)), N)
Expand Down
4 changes: 4 additions & 0 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ class VarSymbol(val id: Ident, uid: Int) extends BlockLocalSymbol(id.name, uid)
val name: Str = id.name
// override def toString: Str = s"$name@$uid"

class BuiltinSymbol(val nme: Str, val binary: Bool, val unary: Bool, val nullary: Bool) extends Symbol:
def toLoc: Option[Loc] = N
override def toString: Str = s"builtin:$nme"


/** This is the outside-facing symbol associated to a possibly-overloaded
* definition living in a block – e.g., a module or class. */
Expand Down
3 changes: 2 additions & 1 deletion hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ final case class QuantVar(sym: VarSymbol, ub: Opt[Term], lb: Opt[Term])
enum Term extends Statement:
case Error
case Lit(lit: Literal)
case Builtin(id: Tree.Ident, nme: Str)
case Ref(sym: Symbol)(val tree: Tree.Ident, val refNum: Int)
case App(lhs: Term, rhs: Term)(val tree: Tree.App, val resSym: FlowSymbol)
case TyApp(lhs: Term, targs: Ls[Term])
Expand Down Expand Up @@ -77,7 +78,7 @@ sealed trait Statement extends AutoLocated:
case Blk(stats, res) => stats ::: res :: Nil
case _ => subTerms
def subTerms: Ls[Term] = this match
case Error | _: Lit | _: Ref => Nil
case Error | _: Lit | _: Ref | _: Builtin => Nil
case App(lhs, rhs) => lhs :: rhs :: Nil
case FunTy(lhs, rhs, eff) => lhs :: rhs :: eff.toList
case TyApp(pre, tarsg) => pre :: tarsg
Expand Down
9 changes: 3 additions & 6 deletions hkmc2/shared/src/test/mlscript-compile/Example.mjs
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import "./Predef.mjs";
class Int {
constructor() {

}
toString() { return "Int"; }
};
const Example$class = class Example {
constructor() {

}
funnySlash(f, arg) {
return f(arg)
}
inc(x) {
return x + 1
}
Expand Down
7 changes: 2 additions & 5 deletions hkmc2/shared/src/test/mlscript-compile/Example.mls
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,11 @@ import "./Predef.mls"
open Predef


class Int

fun (+): (Int, Int) -> Int


module Example with ...


fun (/) funnySlash(f, arg) = f(arg)

fun inc(x) = x + 1


8 changes: 4 additions & 4 deletions hkmc2/shared/src/test/mlscript/bbml/bbBasics.mls
Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,11 @@ class Pair[A, B](fst: A, snd: B)

:fixme
if 1 < 2 then 1 else 0
//│ /!!!\ Uncaught error: scala.MatchError: Cons(Branch(Ref($scrut@130),LitPat(BoolLit(true)),Else(Lit(IntLit(1)))),Else(Lit(IntLit(0)))) (of class hkmc2.semantics.Split$Cons)
//│ /!!!\ Uncaught error: scala.MatchError: Cons(Branch(Ref($scrut@112),LitPat(BoolLit(true)),Else(Lit(IntLit(1)))),Else(Lit(IntLit(0)))) (of class hkmc2.semantics.Split$Cons)

:fixme
if false then 1 else "1"
//│ /!!!\ Uncaught error: scala.MatchError: Cons(Branch(Ref($scrut@132),LitPat(BoolLit(true)),Else(Lit(IntLit(1)))),Else(Lit(StrLit(1)))) (of class hkmc2.semantics.Split$Cons)
//│ /!!!\ Uncaught error: scala.MatchError: Cons(Branch(Ref($scrut@114),LitPat(BoolLit(true)),Else(Lit(IntLit(1)))),Else(Lit(StrLit(1)))) (of class hkmc2.semantics.Split$Cons)


if 1 is Int then 1 else 0
Expand Down Expand Up @@ -240,7 +240,7 @@ test("1")
:fixme
fun fact(n) =
if n > 1 then n * fact(n - 1) else 1
//│ /!!!\ Uncaught error: scala.MatchError: Cons(Branch(Ref($scrut@146),LitPat(BoolLit(true)),Else(App(Sel(Ref(globalThis:import#bbPredef),Ident(*)),Tup(List(Fld(‹›,Ref(n@144),None), Fld(‹›,App(Sel(Ref(globalThis:block#49),Ident(fact)),Tup(List(Fld(‹›,App(Sel(Ref(globalThis:import#bbPredef),Ident(-)),Tup(List(Fld(‹›,Ref(n@144),None), Fld(‹›,Lit(IntLit(1)),None)))),None)))),None)))))),Else(Lit(IntLit(1)))) (of class hkmc2.semantics.Split$Cons)
//│ /!!!\ Uncaught error: scala.MatchError: Cons(Branch(Ref($scrut@128),LitPat(BoolLit(true)),Else(App(Sel(Ref(globalThis:import#bbPredef),Ident(*)),Tup(List(Fld(‹›,Ref(n@126),None), Fld(‹›,App(Sel(Ref(globalThis:block#49),Ident(fact)),Tup(List(Fld(‹›,App(Sel(Ref(globalThis:import#bbPredef),Ident(-)),Tup(List(Fld(‹›,Ref(n@126),None), Fld(‹›,Lit(IntLit(1)),None)))),None)))),None)))))),Else(Lit(IntLit(1)))) (of class hkmc2.semantics.Split$Cons)

fact
//│ Type: ⊥
Expand All @@ -253,7 +253,7 @@ fact(1)
fun fact2 = case
0 then 1
n then n * fact2(n - 1)
//│ /!!!\ Uncaught error: scala.MatchError: Cons(Branch(Ref(caseScrut@155),LitPat(IntLit(0)),Else(Lit(IntLit(1)))),Let(n@156,Ref(caseScrut@155),Else(App(Sel(Ref(globalThis:import#bbPredef),Ident(*)),Tup(List(Fld(‹›,Ref(n@156),None), Fld(‹›,App(Sel(Ref(globalThis:block#52),Ident(fact2)),Tup(List(Fld(‹›,App(Sel(Ref(globalThis:import#bbPredef),Ident(-)),Tup(List(Fld(‹›,Ref(n@156),None), Fld(‹›,Lit(IntLit(1)),None)))),None)))),None))))))) (of class hkmc2.semantics.Split$Cons)
//│ /!!!\ Uncaught error: scala.MatchError: Cons(Branch(Ref(caseScrut@137),LitPat(IntLit(0)),Else(Lit(IntLit(1)))),Let(n@138,Ref(caseScrut@137),Else(App(Sel(Ref(globalThis:import#bbPredef),Ident(*)),Tup(List(Fld(‹›,Ref(n@138),None), Fld(‹›,App(Sel(Ref(globalThis:block#52),Ident(fact2)),Tup(List(Fld(‹›,App(Sel(Ref(globalThis:import#bbPredef),Ident(-)),Tup(List(Fld(‹›,Ref(n@138),None), Fld(‹›,Lit(IntLit(1)),None)))),None)))),None))))))) (of class hkmc2.semantics.Split$Cons)

fact2
//│ Type: ⊥
Expand Down
2 changes: 1 addition & 1 deletion hkmc2/shared/src/test/mlscript/bbml/bbBorrowing.mls
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ letreg of r =>
123
if next(it) > 0 then () => 0 else () => clear(b)
k()
//│ /!!!\ Uncaught error: scala.MatchError: Cons(Branch(Ref($scrut@120),LitPat(BoolLit(true)),Else(Lam(List(),Lit(IntLit(0))))),Else(Lam(List(),App(Sel(Ref(globalThis:block#5),Ident(clear)),Tup(List(Fld(‹›,Ref(b@109),None))))))) (of class hkmc2.semantics.Split$Cons)
//│ /!!!\ Uncaught error: scala.MatchError: Cons(Branch(Ref($scrut@102),LitPat(BoolLit(true)),Else(Lam(List(),Lit(IntLit(0))))),Else(Lam(List(),App(Sel(Ref(globalThis:block#5),Ident(clear)),Tup(List(Fld(‹›,Ref(b@91),None))))))) (of class hkmc2.semantics.Split$Cons)

:e
letreg of r =>
Expand Down
2 changes: 1 addition & 1 deletion hkmc2/shared/src/test/mlscript/bbml/bbCheck.mls
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ high(x => x + 1)

:fixme
(if false then x => x else y => y): [A] -> A -> A
//│ /!!!\ Uncaught error: scala.MatchError: Cons(Branch(Ref($scrut@105),LitPat(BoolLit(true)),Else(Lam(List(Param(‹›,x@106,None)),Ref(x@106)))),Else(Lam(List(Param(‹›,y@104,None)),Ref(y@104)))) (of class hkmc2.semantics.Split$Cons)
//│ /!!!\ Uncaught error: scala.MatchError: Cons(Branch(Ref($scrut@87),LitPat(BoolLit(true)),Else(Lam(List(Param(‹›,x@88,None)),Ref(x@88)))),Else(Lam(List(Param(‹›,y@86,None)),Ref(y@86)))) (of class hkmc2.semantics.Split$Cons)


fun baz: Int -> (([A] -> A -> A), Int) -> Int
Expand Down
Loading

0 comments on commit f3c4ccb

Please sign in to comment.