Skip to content

Commit

Permalink
Make small tweaks and fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
LPTK committed Dec 13, 2024
1 parent 3127ccf commit 9a10c0a
Show file tree
Hide file tree
Showing 10 changed files with 114 additions and 35 deletions.
11 changes: 11 additions & 0 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,17 @@ sealed abstract class Block extends Product with AutoLocated:
case TryBlock(sub, fin, rst) => sub.definedVars ++ fin.definedVars ++ rst.definedVars
case Label(lbl, bod, rst) => bod.definedVars ++ rst.definedVars

lazy val size: Int = this match
case _: Return | _: Throw | _: End | _: Break | _: Continue => 1
case Begin(sub, rst) => sub.size + rst.size
case Assign(_, _, rst) => 1 + rst.size
case AssignField(_, _, _, rst) => 1 + rst.size
case Match(_, arms, dflt, rst) =>
1 + arms.map(_._2.size).sum + dflt.map(_.size).getOrElse(0) + rst.size
case Define(_, rst) => 1 + rst.size
case TryBlock(sub, fin, rst) => 1 + sub.size + fin.size + rst.size
case Label(_, bod, rst) => 1 + bod.size + rst.size

// TODO conserve if no changes
def mapTail(f: BlockTail => BlockTail): Block = this match
case b: BlockTail => f(b)
Expand Down
6 changes: 5 additions & 1 deletion hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,11 @@ class Lowering(using TL, Raise, Elaborator.State):

case st.Lam(params, body) =>
val (paramLists, bodyBlock) = setupFunctionDef(params :: Nil, body, N)
k(Value.Lam(paramLists.head, bodyBlock))
if k.isInstanceOf[TailOp] || bodyBlock.size <= 5
then k(Value.Lam(paramLists.head, bodyBlock))
else
val l = new TempSymbol(N)
Assign(l, Value.Lam(paramLists.head, bodyBlock), k(l |> Value.Ref.apply))

/*
case t @ st.If(Split.Let(sym, trm, tail)) =>
Expand Down
25 changes: 15 additions & 10 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,16 @@ class JSBuilder(using Elaborator.State, Elaborator.Ctx) extends CodeBuilder:
summon[Scope].findThis_!(ts)
case _ => summon[Scope].lookup_!(l)

def result(a: Arg)(using Raise, Scope): Document =
def argument(a: Arg)(using Raise, Scope): Document =
if a.spread then doc"...${result(a.value)}" else result(a.value)

def operand(a: Arg)(using Raise, Scope): Document =
if a.spread then die else subexpression(a.value)

def subexpression(r: Result)(using Raise, Scope): Document = r match
case _: Value.Lam => doc"(${result(r)})"
case _ => result(r)

def result(r: Result)(using Raise, Scope): Document = r match
case Value.This(sym) => summon[Scope].findThis_!(sym)
case Value.Lit(Tree.StrLit(value)) => JSBuilder.makeStringLiteral(value)
Expand All @@ -82,12 +89,12 @@ class JSBuilder(using Elaborator.State, Elaborator.Ctx) extends CodeBuilder:

case Call(Value.Ref(l: BuiltinSymbol), lhs :: rhs :: Nil) =>
if l.binary then
val res = doc"${result(lhs)} ${l.nme} ${result(rhs)}"
val res = doc"${operand(lhs)} ${l.nme} ${operand(rhs)}"
if needsParens(l.nme) then doc"(${res})" else res
else err(msg"Cannot call non-binary builtin symbol '${l.nme}'")
case Call(Value.Ref(l: BuiltinSymbol), rhs :: Nil) =>
if l.unary then
val res = doc"${l.nme} ${result(rhs)}"
val res = doc"${l.nme} ${operand(rhs)}"
if needsParens(l.nme) then doc"(${res})" else res
else err(msg"Cannot call non-unary builtin symbol '${l.nme}'")
case Call(Value.Ref(l: BuiltinSymbol), args) =>
Expand All @@ -96,14 +103,12 @@ class JSBuilder(using Elaborator.State, Elaborator.Ctx) extends CodeBuilder:
case Call(s @ Select(_, id), lhs :: rhs :: Nil) =>
Elaborator.ctx.Builtins.getBuiltinOp(id.name) match
case S(jsOp) =>
val res = doc"${result(lhs)} ${jsOp} ${result(rhs)}"
val res = doc"${operand(lhs)} ${jsOp} ${operand(rhs)}"
if needsParens(jsOp) then doc"(${res})" else res
case N => doc"${result(s)}(${(result(lhs) :: result(rhs) :: Nil).mkDocument(", ")})"
case N => doc"${result(s)}(${(argument(lhs) :: argument(rhs) :: Nil).mkDocument(", ")})"
case c @ Call(fun, args) =>
val base = fun match
case _: Value.Lam => doc"(${result(fun)})"
case _ => result(fun)
val argsDoc = args.map(result).mkDocument(", ")
val base = subexpression(fun)
val argsDoc = args.map(argument).mkDocument(", ")
if c.isMlsFun then doc"${base}(${argsDoc})" else doc"${base}(${argsDoc}) ?? null"
case Value.Lam(ps, bod) => scope.nest givenIn:
val (params, bodyDoc) = setupFunction(none, ps, bod)
Expand All @@ -123,7 +128,7 @@ class JSBuilder(using Elaborator.State, Elaborator.Ctx) extends CodeBuilder:
doc"new ${result(cls)}(${as.map(result).mkDocument(", ")})"
case Value.Arr(es) if es.isEmpty => doc"[]"
case Value.Arr(es) =>
doc"[ #{ # ${es.map(result).mkDocument(doc", # ")} #} # ]"
doc"[ #{ # ${es.map(argument).mkDocument(doc", # ")} #} # ]"
def returningTerm(t: Block)(using Raise, Scope): Document = t match
case Assign(l, r, rst) =>
doc" # ${getVar(l)} = ${result(r)};${returningTerm(rst)}"
Expand Down
4 changes: 2 additions & 2 deletions hkmc2/shared/src/test/mlscript-compile/Stack.mls
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ open Predef

module Stack with ...

class (::) Cons[A](head: A, tail)
class (::) Cons[A](head: A, tail: Stack[A])
object Nil

fun isEmpty(xs) = xs is Nil
Expand Down Expand Up @@ -46,7 +46,7 @@ fun zip(...xss) =
Nil then go(heads, tails)(t)
Nil and heads is
Nil then assert(tails is Nil); Nil
else heads toReverseArray() :: go(Nil, Nil)(tails reverse())
else heads toReverseArray() :: go(Nil, Nil) of tails reverse()
go(Nil, Nil) of fromArray(xss)


7 changes: 7 additions & 0 deletions hkmc2/shared/src/test/mlscript/basics/MultiParamLists.mls
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,17 @@ f(42)
//│ this.f(42)
//│ = 42

// TODO compile this to
// function f(n1) { return (n2) => f$(n1, n2); }
// function f$(n1, n2) { let tmp; tmp = 10 * n1; return tmp + n2; }

fun f(n1: Int)(n2: Int): Int = (10 * n1 + n2)
//│ JS (unsanitized):
//│ function f(n1) { return (n2) => { let tmp; tmp = 10 * n1; return tmp + n2; }; } null

// TODO compile this to
// this.f$(4, 2)

f(4)(2)
//│ JS (unsanitized):
//│ let tmp; tmp = this.f(4); tmp(2) ?? null
Expand Down
6 changes: 4 additions & 2 deletions hkmc2/shared/src/test/mlscript/codegen/CaseShorthand.mls
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ val isDefined = case
Some then true
None then false
//│ JS (unsanitized):
//│ this.isDefined = (caseScrut) => {
//│ let tmp;
//│ tmp = (caseScrut) => {
//│ if (caseScrut instanceof this.Some.class) {
//│ return true;
//│ } else {
Expand All @@ -61,6 +62,7 @@ val isDefined = case
//│ }
//│ }
//│ };
//│ this.isDefined = tmp;
//│ null
//│ isDefined = [Function (anonymous)]
//│ isDefined = [Function: tmp]

2 changes: 1 addition & 1 deletion hkmc2/shared/src/test/mlscript/codegen/Do.mls
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ val f = case
//│ > let $doTemp = globalThis:import#Prelude#666(.)console‹member:console›(.)log("non-null")
//│ > caseScrut is 1 then "unit"
//│ > else "other"
//│ f = [Function (anonymous)]
//│ f = [Function: tmp]

f(0)
//│ = 'null'
Expand Down
34 changes: 19 additions & 15 deletions hkmc2/shared/src/test/mlscript/codegen/IfThenElse.mls
Original file line number Diff line number Diff line change
Expand Up @@ -25,38 +25,42 @@ f(false)
:sjs
let f = x => log((if x then "ok" else "ko") + "!")
//│ JS (unsanitized):
//│ this.f = (x) => {
//│ let tmp, tmp1;
//│ let tmp;
//│ tmp = (x) => {
//│ let tmp1, tmp2;
//│ if (x) {
//│ tmp = "ok";
//│ tmp1 = "ok";
//│ } else {
//│ tmp = "ko";
//│ tmp1 = "ko";
//│ }
//│ tmp1 = tmp + "!";
//│ return this.log(tmp1);
//│ tmp2 = tmp1 + "!";
//│ return this.log(tmp2);
//│ };
//│ this.f = tmp;
//│ null
//│ f = [Function (anonymous)]
//│ f = [Function: tmp]

:sjs
let f = x => log((if x and x then "ok" else "ko") + "!")
//│ JS (unsanitized):
//│ this.f = (x) => {
//│ let tmp, tmp1;
//│ let tmp;
//│ tmp = (x) => {
//│ let tmp1, tmp2;
//│ if (x) {
//│ if (x) {
//│ tmp = "ok";
//│ tmp1 = "ok";
//│ } else {
//│ tmp = "ko";
//│ tmp1 = "ko";
//│ }
//│ } else {
//│ tmp = "ko";
//│ tmp1 = "ko";
//│ }
//│ tmp1 = tmp + "!";
//│ return this.log(tmp1);
//│ tmp2 = tmp1 + "!";
//│ return this.log(tmp2);
//│ };
//│ this.f = tmp;
//│ null
//│ f = [Function (anonymous)]
//│ f = [Function: tmp]
// --- TODO: What we want ---
// this.f = (x) => {
// let tmp, tmp1, flag;
Expand Down
44 changes: 44 additions & 0 deletions hkmc2/shared/src/test/mlscript/codegen/InlineLambdas.mls
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
:js

:sjs
(x => x + 1 + 1 + 1 + 1 + 1)(1)
//│ JS (unsanitized):
//│ ((x) => {
//│ let tmp, tmp1, tmp2, tmp3;
//│ tmp = x + 1;
//│ tmp1 = tmp + 1;
//│ tmp2 = tmp1 + 1;
//│ tmp3 = tmp2 + 1;
//│ return tmp3 + 1;
//│ })(1)
//│ = 6

:sjs
(x => x + 1 + 1 + 1 + 1 + 1 + 1)(1)
//│ JS (unsanitized):
//│ let tmp;
//│ tmp = (x) => {
//│ let tmp1, tmp2, tmp3, tmp4, tmp5;
//│ tmp1 = x + 1;
//│ tmp2 = tmp1 + 1;
//│ tmp3 = tmp2 + 1;
//│ tmp4 = tmp3 + 1;
//│ tmp5 = tmp4 + 1;
//│ return tmp5 + 1;
//│ };
//│ tmp(1)
//│ = 7

:sjs
(x => x) + 1
//│ JS (unsanitized):
//│ ((x) => { return x; }) + 1
//│ = '(...args) => { globalThis.Predef.checkArgs("", 1, true, args.length); let x = args[0]; return x; }1'

:sjs
1 + (x => x)
//│ JS (unsanitized):
//│ 1 + ((x) => { return x; })
//│ = '1(...args) => { globalThis.Predef.checkArgs("", 1, true, args.length); let x = args[0]; return x; }'


10 changes: 6 additions & 4 deletions hkmc2/shared/src/test/mlscript/codegen/OptMatch.mls
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ val isDefined = case
Some(_) then true
None then false
//│ JS (unsanitized):
//│ this.isDefined = (caseScrut) => {
//│ let tmp;
//│ tmp = (caseScrut) => {
//│ let param0;
//│ if (caseScrut instanceof this.Some.class) {
//│ param0 = caseScrut.value;
Expand All @@ -48,8 +49,9 @@ val isDefined = case
//│ }
//│ }
//│ };
//│ this.isDefined = tmp;
//│ null
//│ isDefined = [Function (anonymous)]
//│ isDefined = [Function: tmp]

isDefined(Some(1))
//│ = true
Expand All @@ -61,7 +63,7 @@ isDefined(None)
val isDefined = x => if x is
Some(_) then true
None then false
//│ isDefined = [Function (anonymous)]
//│ isDefined = [Function: tmp]

isDefined(Some(1))
//│ = true
Expand All @@ -76,7 +78,7 @@ module Foo with
val isOther = x => if x is
Foo.Other(_) then true
None then false
//│ isOther = [Function (anonymous)]
//│ isOther = [Function: tmp]


fun keepIfGreaterThan(x, y) =
Expand Down

0 comments on commit 9a10c0a

Please sign in to comment.