Skip to content

Commit

Permalink
Remove reliance on global state to deal with symbol UIDs
Browse files Browse the repository at this point in the history
  • Loading branch information
LPTK committed Nov 22, 2024
1 parent d08424f commit ac7c3d1
Show file tree
Hide file tree
Showing 15 changed files with 75 additions and 54 deletions.
8 changes: 5 additions & 3 deletions hkmc2/jvm/src/test/scala/hkmc2/MLsDiffMaker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ abstract class MLsDiffMaker extends DiffMaker:
if doTrace then super.trace(pre, post)(thunk)
else thunk

var curCtx = Elaborator.Ctx.init.nest(N)
var curCtx = Elaborator.State.init.nest(N)


override def run(): Unit =
Expand Down Expand Up @@ -103,7 +103,8 @@ abstract class MLsDiffMaker extends DiffMaker:
if showParse.isSet || dbgParsing.isSet then
output(syntax.Lexer.printTokens(tokens))

val p = new syntax.Parser(origin, tokens, raise, dbg = dbgParsing.isSet):
val rules = syntax.ParseRules()
val p = new syntax.Parser(origin, tokens, rules, raise, dbg = dbgParsing.isSet):
def doPrintDbg(msg: => Str): Unit = if dbg then output(msg)
val res = p.parseAll(p.block(allowNewlines = true))
val imprtSymbol =
Expand Down Expand Up @@ -136,7 +137,8 @@ abstract class MLsDiffMaker extends DiffMaker:
if showParse.isSet || dbgParsing.isSet then
output(syntax.Lexer.printTokens(tokens))

val p = new syntax.Parser(origin, tokens, raise, dbg = dbgParsing.isSet):
val rules = syntax.ParseRules()
val p = new syntax.Parser(origin, tokens, rules, raise, dbg = dbgParsing.isSet):
def doPrintDbg(msg: => Str): Unit = if dbg then output(msg)
val res = p.parseAll(p.block(allowNewlines = true))

Expand Down
1 change: 0 additions & 1 deletion hkmc2/jvm/src/test/scala/hkmc2/Watcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ class Watcher(dir: File):
val relativeName = basePath.map(_ + "/").mkString + path.baseName
val preludePath = os.pwd/os.up/"shared"/"src"/"test"/"mlscript"/"decls"/"Prelude.mls"
val predefPath = os.pwd/os.up/"shared"/"src"/"test"/"mlscript-compile"/"Predef.mls"
semantics.suid.reset // FIXME hack
val isModuleFile = path.segments.contains("mlscript-compile")
if isModuleFile
then
Expand Down
8 changes: 5 additions & 3 deletions hkmc2/shared/src/main/scala/hkmc2/MLsCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import utils.*
import hkmc2.semantics.MemberSymbol
import hkmc2.semantics.Elaborator
import hkmc2.syntax.Keyword.`override`
import semantics.Elaborator.State


class MLsCompiler(preludeFile: os.Path):
Expand Down Expand Up @@ -38,13 +39,14 @@ class MLsCompiler(preludeFile: os.Path):
// if showParse.isSet || dbgParsing.isSet then
// output(syntax.Lexer.printTokens(tokens))

val p = new syntax.Parser(origin, tokens, raise, dbg = dbgParsing):
given Elaborator.State = new Elaborator.State
val rules = syntax.ParseRules()
val p = new syntax.Parser(origin, tokens, rules, raise, dbg = dbgParsing):
def doPrintDbg(msg: => Str): Unit =
// if dbg then output(msg)
if dbg then println(msg)
val res = p.parseAll(p.block(allowNewlines = true))
given Elaborator.State = new Elaborator.State
given Elaborator.Ctx = Elaborator.Ctx.init.nest(N)
given Elaborator.Ctx = State.init.nest(N)
val wd = file / os.up
val elab = Elaborator(etl, wd)
val resBlk = new syntax.Tree.Block(res)
Expand Down
3 changes: 2 additions & 1 deletion hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import hkmc2.Message.MessageContext

import hkmc2.{semantics => sem}
import hkmc2.semantics.{Term => st}
import semantics.Elaborator.State

import syntax.{Literal, Tree}
import semantics.*
Expand Down Expand Up @@ -260,7 +261,7 @@ class Lowering(using TL, Raise, Elaborator.State):
else End()
)
case Split.End =>
Throw(Instantiate(Select(Value.Ref(Elaborator.Ctx.globalThisSymbol), Tree.Ident("Error")),
Throw(Instantiate(Select(Value.Ref(State.globalThisSymbol), Tree.Ident("Error")),
Value.Lit(syntax.Tree.StrLit("match error")) :: Nil)) // TODO add failed-match scrutinee info

if k.isInstanceOf[TailOp] && isIf then go(iftrm.normalized, topLevel = true)
Expand Down
6 changes: 4 additions & 2 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ abstract class CodeBuilder:
type Context


class JSBuilder extends CodeBuilder:
class JSBuilder(using Elaborator.State) extends CodeBuilder:

val builtinOpsBase: Ls[Str] = Ls(
"+", "-", "*", "/", "%",
Expand Down Expand Up @@ -425,7 +425,9 @@ object JSBuilder:
end JSBuilder


trait JSBuilderSanityChecks(instrument: Bool) extends JSBuilder:
trait JSBuilderSanityChecks
(instrument: Bool)(using Elaborator.State)
extends JSBuilder:

val functionParamVarargSymbol = semantics.TempSymbol(N, "args")

Expand Down
11 changes: 7 additions & 4 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/js/Scope.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,23 @@ import hkmc2.semantics.InnerSymbol
import hkmc2.semantics.VarSymbol
import hkmc2.semantics.Elaborator
import hkmc2.semantics.TopLevelSymbol
import semantics.Elaborator.State


/** When `curThis`, it means this scope does not rebind `this`.
* When `curThis` is Some(None), it means the scope rebinds `this`
* to something unknown, following JavaScript's inane `this` handling in `function`s.
* When `curThis` is Some(Some(sym)), it means the scope rebinds `this`
* to an inner symbol (e.g., class or module). */
class Scope(val parent: Opt[Scope], val curThis: Opt[Opt[InnerSymbol]], val bindings: MutMap[Local, Str]):
class Scope
(val parent: Opt[Scope], val curThis: Opt[Opt[InnerSymbol]], val bindings: MutMap[Local, Str])
(using State):

private var thisProxyAccessed = false
lazy val thisProxy =
curThis match
case N | S(N) => die
case S(S(Elaborator.Ctx.globalThisSymbol)) => "globalThis"
case S(S(State.globalThisSymbol)) => "globalThis"
case S(S(thisSym)) =>
thisProxyAccessed = true
allocateName(thisSym, "this$")
Expand Down Expand Up @@ -112,8 +115,8 @@ object Scope:

def scope(using scp: Scope): Scope = scp

def empty: Scope =
Scope(N, S(S(Elaborator.Ctx.globalThisSymbol)), MutMap.empty)
def empty(using State): Scope =
Scope(N, S(S(State.globalThisSymbol)), MutMap.empty)

def replaceTicks(str: Str): Str = str.replace('\'', '$')

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import syntax.Tree.*
import hkmc2.syntax.TypeOrTermDef


trait BlockImpl:
trait BlockImpl(using Elaborator.State):
self: Block =>

val desugStmts = stmts.map(_.desugared)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import hkmc2.syntax.Literal
import Keyword.{as, and, `else`, is, let, `then`}
import collection.mutable.HashMap
import Elaborator.{ctx, Ctxl}
import hkmc2.semantics.Elaborator.Ctx.globalThisSymbol

object Desugarer:
extension (op: Keyword.Infix)
Expand Down
12 changes: 7 additions & 5 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,17 @@ object Elaborator:
def symbol = symOpt
given Conversion[Symbol, Elem] = RefElem(_)
val empty: Ctx = Ctx(N, N, Map.empty)
// val globalThisSymbol = TermSymbol(ImmutVal, N, Ident("globalThis"))
type Ctxl[A] = Ctx ?=> A
def ctx: Ctxl[Ctx] = summon
class State:
given State = this
val suid = new Uid.Symbol.State
val globalThisSymbol = TopLevelSymbol("globalThis")
val seqSymbol = TermSymbol(ImmutVal, N, Ident(";"))
def init(using State): Ctx = empty.copy(env = Map(
def init(using State): Ctx = Ctx.empty.copy(env = Map(
"globalThis" -> globalThisSymbol,
))
type Ctxl[A] = Ctx ?=> A
def ctx: Ctxl[Ctx] = summon
class State
transparent inline def State(using state: State): State = state
import Elaborator.*

class Elaborator(val tl: TraceLogger, val wd: os.Path)
Expand Down
3 changes: 2 additions & 1 deletion hkmc2/shared/src/main/scala/hkmc2/semantics/Importer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ class Importer:

val lexer = new syntax.Lexer(origin, dbg = tl.doTrace)
val tokens = lexer.bracketedTokens
val p = new syntax.Parser(origin, tokens, raise, dbg = tl.doTrace):
val rules = syntax.ParseRules()
val p = new syntax.Parser(origin, tokens, rules, raise, dbg = tl.doTrace):
def doPrintDbg(msg: => Str): Unit =
// if dbg then output(msg)
if dbg then tl.log(msg)
Expand Down
43 changes: 22 additions & 21 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,15 @@ import scala.collection.mutable.{Set => MutSet}
import mlscript.utils.*, shorthands.*
import syntax.*

import Elaborator.State
import Tree.Ident


// TODO refactor: don't rely on global state!
val suid = new Uid.Symbol.State


abstract class Symbol extends Located:
abstract class Symbol(using State) extends Located:

def nme: Str
val uid: Uid[Symbol] = suid.nextUid

val uid: Uid[Symbol] = State.suid.nextUid

val directRefs: mutable.Buffer[Term.Ref] = mutable.Buffer.empty
def ref(id: Tree.Ident =
Expand Down Expand Up @@ -55,7 +53,7 @@ abstract class Symbol extends Located:
end Symbol


class FlowSymbol(label: Str) extends Symbol:
class FlowSymbol(label: Str)(using State) extends Symbol:
def nme: Str = label
def toLoc: Option[Loc] = N // TODO track source trees of flows
import typing.*
Expand All @@ -72,26 +70,29 @@ sealed trait NamedSymbol extends Symbol:
def name: Str
def id: Ident

abstract class BlockLocalSymbol(name: Str) extends FlowSymbol(name) with LocalSymbol:
abstract class BlockLocalSymbol(name: Str)(using State) extends FlowSymbol(name) with LocalSymbol:
var decl: Opt[Declaration] = N

class TempSymbol(val trm: Opt[Term], dbgNme: Str = "tmp") extends BlockLocalSymbol(dbgNme):
class TempSymbol(val trm: Opt[Term], dbgNme: Str = "tmp")(using State) extends BlockLocalSymbol(dbgNme):
val nameHints: MutSet[Str] = MutSet.empty
override def toLoc: Option[Loc] = trm.flatMap(_.toLoc)
override def toString: Str = s"$$${super.toString}"

class VarSymbol(val id: Ident) extends BlockLocalSymbol(id.name) with NamedSymbol:
class VarSymbol(val id: Ident)(using State) extends BlockLocalSymbol(id.name) with NamedSymbol:
val name: Str = id.name
// override def toString: Str = s"$name@$uid"

class BuiltinSymbol(val nme: Str, val binary: Bool, val unary: Bool, val nullary: Bool) extends Symbol:
class BuiltinSymbol
(val nme: Str, val binary: Bool, val unary: Bool, val nullary: Bool)(using State)
extends Symbol:
def toLoc: Option[Loc] = N
override def toString: Str = s"builtin:$nme"


/** This is the outside-facing symbol associated to a possibly-overloaded
* definition living in a block – e.g., a module or class. */
class BlockMemberSymbol(val nme: Str, val trees: Ls[Tree]) extends MemberSymbol[Definition]:
class BlockMemberSymbol(val nme: Str, val trees: Ls[Tree])(using State)
extends MemberSymbol[Definition]:

def toLoc: Option[Loc] = Loc(trees)

Expand All @@ -114,12 +115,12 @@ class BlockMemberSymbol(val nme: Str, val trees: Ls[Tree]) extends MemberSymbol[
end BlockMemberSymbol


sealed abstract class MemberSymbol[Defn <: Definition] extends Symbol:
sealed abstract class MemberSymbol[Defn <: Definition](using State) extends Symbol:
def nme: Str
var defn: Opt[Defn] = N


class TermSymbol(val k: TermDefKind, val owner: Opt[InnerSymbol], val id: Tree.Ident)
class TermSymbol(val k: TermDefKind, val owner: Opt[InnerSymbol], val id: Tree.Ident)(using State)
extends MemberSymbol[Definition] with LocalSymbol with NamedSymbol:
def nme: Str = id.name
def name: Str = nme
Expand All @@ -129,16 +130,16 @@ class TermSymbol(val k: TermDefKind, val owner: Opt[InnerSymbol], val id: Tree.I

sealed trait CtorSymbol extends Symbol

case class Extr(isTop: Bool) extends CtorSymbol:
case class Extr(isTop: Bool)(using State) extends CtorSymbol:
def nme: Str = if isTop then "Top" else "Bot"
def toLoc: Option[Loc] = N
override def toString: Str = nme

case class LitSymbol(lit: Literal) extends CtorSymbol:
case class LitSymbol(lit: Literal)(using State) extends CtorSymbol:
def nme: Str = lit.toString
def toLoc: Option[Loc] = lit.toLoc
override def toString: Str = s"lit:$lit"
case class TupSymbol(arity: Opt[Int]) extends CtorSymbol:
case class TupSymbol(arity: Opt[Int])(using State) extends CtorSymbol:
def nme: Str = s"Tuple#$arity"
def toLoc: Option[Loc] = N
override def toString: Str = s"tup:$arity"
Expand All @@ -152,26 +153,26 @@ type TypeSymbol = ClassSymbol | TypeAliasSymbol
* A `Ref(_: InnerSymbol)` represents a `this`-like reference to the current object. */
sealed trait InnerSymbol extends Symbol

class ClassSymbol(val tree: Tree.TypeDef, val id: Tree.Ident)
class ClassSymbol(val tree: Tree.TypeDef, val id: Tree.Ident)(using State)
extends MemberSymbol[ClassDef] with CtorSymbol with InnerSymbol:
def nme = id.name
def toLoc: Option[Loc] = id.toLoc // TODO track source tree of classe here
override def toString: Str = s"class:$nme"
/** Compute the arity. */
def arity: Int = tree.paramLists.headOption.fold(0)(_.fields.length)

class ModuleSymbol(val tree: Tree.TypeDef, val id: Tree.Ident)
class ModuleSymbol(val tree: Tree.TypeDef, val id: Tree.Ident)(using State)
extends MemberSymbol[ModuleDef] with CtorSymbol with InnerSymbol:
def nme = id.name
def toLoc: Option[Loc] = id.toLoc // TODO track source tree of module here
override def toString: Str = s"module:${id.name}"

class TypeAliasSymbol(val id: Tree.Ident) extends MemberSymbol[TypeDef]:
class TypeAliasSymbol(val id: Tree.Ident)(using State) extends MemberSymbol[TypeDef]:
def nme = id.name
def toLoc: Option[Loc] = id.toLoc // TODO track source tree of type alias here
override def toString: Str = s"module:${id.name}"

class TopLevelSymbol(blockNme: Str)
class TopLevelSymbol(blockNme: Str)(using State)
extends MemberSymbol[ModuleDef] with InnerSymbol:
def nme = blockNme
def toLoc: Option[Loc] = N
Expand Down
7 changes: 6 additions & 1 deletion hkmc2/shared/src/main/scala/hkmc2/syntax/ParseRule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import sourcecode.{Name, Line}
import mlscript.utils.*, shorthands.*
import hkmc2.Message._
import BracketKind._
import semantics.Elaborator.State


// * TODO: add lookahead to Expr as a PartialFunction[Ls[Token], Bool]
Expand Down Expand Up @@ -49,7 +50,10 @@ class ParseRule[+A](val name: Str, val omitAltsStr: Bool = false)(val alts: Alt[
case str1 :: str2 :: Nil => s"$str1 or $str2"
case strs => strs.init.mkString(", ") + ", or " + strs.last

object ParseRule:
end ParseRule


class ParseRules(using State):
import Keyword.*
import Alt.*
import Tree.*
Expand Down Expand Up @@ -352,4 +356,5 @@ object ParseRule:
genInfixRule(`restricts`, (rhs, _: Unit) => lhs => InfixApp(lhs, `restricts`, rhs)),
)

end ParseRules

13 changes: 8 additions & 5 deletions hkmc2/shared/src/main/scala/hkmc2/syntax/Parser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ import Parser.*
import scala.annotation.tailrec

import Keyword.`let`
import hkmc2.syntax.ParseRule.prefixRules
import hkmc2.syntax.ParseRule.infixRules
import hkmc2.syntax.Keyword.Ellipsis

import semantics.Elaborator.State


object Parser:

Expand Down Expand Up @@ -107,12 +107,15 @@ import Parser._
abstract class Parser(
origin: Origin,
tokens: Ls[TokLoc],
rules: ParseRules,
raiseFun: Diagnostic => Unit,
val dbg: Bool,
// fallbackLoc: Opt[Loc], description: Str = "input",
):
)(using State):
outer =>

import rules.*

protected def doPrintDbg(msg: => Str): Unit
protected def printDbg(msg: => Any): Unit =
doPrintDbg("" * this.indent + msg)
Expand All @@ -134,7 +137,7 @@ abstract class Parser(
res

final def rec(tokens: Ls[Stroken -> Loc], fallbackLoc: Opt[Loc], description: Str): Parser =
new Parser(origin, tokens, raiseFun, dbg
new Parser(origin, tokens, rules, raiseFun, dbg
// , fallbackLoc, description
):
def doPrintDbg(msg: => Str): Unit = outer.printDbg("> " + msg)
Expand Down Expand Up @@ -458,7 +461,7 @@ abstract class Parser(
// TODO: rm `allowIndentedBlock`? Seems it can always be `true`
def expr(prec: Int, allowIndentedBlock: Bool = true)(using Line): Tree =
parseRule(prec,
if allowIndentedBlock then ParseRule.prefixRulesAllowIndentedBlock else prefixRules
if allowIndentedBlock then prefixRulesAllowIndentedBlock else prefixRules
).getOrElse(errExpr) // * a `None` result means an alread-reported error

def simpleExpr(prec: Int)(using Line): Tree = wrap(prec)(simpleExprImpl(prec))
Expand Down
Loading

0 comments on commit ac7c3d1

Please sign in to comment.