Skip to content

Commit

Permalink
Naïve compilation of refining patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
chengluyu committed Nov 28, 2024
1 parent efb66c1 commit e8fb2cd
Show file tree
Hide file tree
Showing 20 changed files with 1,014 additions and 18 deletions.
3 changes: 2 additions & 1 deletion hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ class JSBuilder(using Elaborator.State) extends CodeBuilder:
}"
case Instantiate(cls, as) =>
doc"new ${result(cls)}(${as.map(result).mkDocument(", ")})"
case Value.Arr(es) if es.isEmpty => doc"[]"
case Value.Arr(es) =>
doc"[ #{ # ${es.map(result).mkDocument(doc", # ")} #} # ]"
def returningTerm(t: Block)(using Raise, Scope): Document = t match
Expand Down Expand Up @@ -190,7 +191,7 @@ class JSBuilder(using Elaborator.State) extends CodeBuilder:
} + ")""""
}; }"""
} #} # }"
if (clsDefn.kind is syntax.Mod) || (clsDefn.kind is syntax.Obj) then
if ((clsDefn.kind is syntax.Mod) || (clsDefn.kind is syntax.Obj)) || (clsDefn.kind is syntax.Pat) then
val clsTmp = summon[Scope].allocateName(new semantics.TempSymbol(N, sym.nme+"$"+"class"))
clsDefn.owner match
case S(owner) =>
Expand Down
18 changes: 5 additions & 13 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Desugarer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import hkmc2.syntax.Literal
import Keyword.{as, and, `else`, is, let, `then`}
import collection.mutable.HashMap
import Elaborator.{ctx, Ctxl}
import ucs.DesugaringBase

object Desugarer:
extension (op: Keyword.Infix)
Expand All @@ -29,8 +30,8 @@ object Desugarer:
val tupleLast: HashMap[Int, BlockLocalSymbol] = HashMap.empty
end Desugarer

class Desugarer(tl: TraceLogger, elaborator: Elaborator)
(using raise: Raise, state: Elaborator.State, c: Elaborator.Ctx):
class Desugarer(tl: TraceLogger, val elaborator: Elaborator)
(using raise: Raise, state: Elaborator.State, c: Elaborator.Ctx) extends DesugaringBase:
import Desugarer.*
import Elaborator.Ctx
import elaborator.term
Expand Down Expand Up @@ -359,12 +360,6 @@ class Desugarer(tl: TraceLogger, elaborator: Elaborator)
raise(ErrorReport(msg"Unrecognized pattern split." -> tree.toLoc :: Nil))
_ => _ => Split.default(Term.Error)

private lazy val tupleSlice =
term(Sel(Sel(Ident("globalThis"), Ident("Predef")), Ident("tupleSlice")))

private lazy val tupleGet =
term(Sel(Sel(Ident("globalThis"), Ident("Predef")), Ident("tupleGet")))

/** Elaborate a single match (a scrutinee and a pattern) and forms a split
* with an innermost split as the sequel of the match.
* @param scrutSymbol the symbol representing the scrutinee
Expand Down Expand Up @@ -395,6 +390,7 @@ class Desugarer(tl: TraceLogger, elaborator: Elaborator)
Branch(ref, Pattern.ClassLike(cls, clsTrm, N, false)(ctor), sequel(ctx)) ~: fallback
case S(cls: ModuleSymbol) =>
Branch(ref, Pattern.ClassLike(cls, clsTrm, N, false)(ctor), sequel(ctx)) ~: fallback
case S(psym: PatternSymbol) => makeUnapplyBranch(ref, psym, sequel(ctx))(fallback)
case N =>
// Raise an error and discard `sequel`. Use `fallback` instead.
raise(ErrorReport(msg"Cannot use this ${ctor.describe} as a pattern" -> ctor.toLoc :: Nil))
Expand All @@ -412,10 +408,6 @@ class Desugarer(tl: TraceLogger, elaborator: Elaborator)
case ((lead, N), pat) => (lead :+ pat, N)
case ((lead, S((rest, last))), pat) => (lead, S((rest, last :+ pat)))
// Some helper functions. TODO: deduplicate
def int(i: Int) = Term.Lit(IntLit(BigInt(i)))
def fld(t: Term) = Fld(FldFlags.empty, t, N)
def tup(xs: Fld*) = Term.Tup(xs.toList)(Tup(Nil))
def app(lhs: Term, rhs: Term, sym: FlowSymbol) = Term.App(lhs, rhs)(Tree.App(Tree.Empty(), Tree.Empty()), sym)
def getLast(i: Int) = TempSymbol(N, s"last$i")
// `wrap`: add let bindings for tuple elements
// `matches`: pairs of patterns and symbols to be elaborated
Expand All @@ -426,7 +418,7 @@ class Desugarer(tl: TraceLogger, elaborator: Elaborator)
case ((wrapInner, matches), (pat, lastIndex)) =>
val sym = scrutSymbol.getTupleLastSubScrutinee(lastIndex)
val wrap = (split: Split) =>
Split.Let(sym, app(tupleGet, tup(fld(ref), fld(int(-1 - lastIndex))), sym), wrapInner(split))
Split.Let(sym, callTupleGet(ref, -1 - lastIndex, sym), wrapInner(split))
(wrap, (sym, pat) :: matches)
val lastMatches = reversedLastMatches.reverse
rest match
Expand Down
25 changes: 23 additions & 2 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,18 @@ extends Importer:
def mkLetBinding(sym: LocalSymbol, rhs: Term): Ls[Statement] =
LetDecl(sym) :: DefineVar(sym, rhs) :: Nil

def resolveField(srcTree: Tree, base: Opt[Symbol], nme: Ident): Opt[Symbol] =
def resolveField(srcTree: Tree, base: Opt[Symbol], nme: Ident): Ctxl[Opt[Symbol]] =
base match
// Look up symbols qualified by `globalThis.`.
case S(tsym: TopLevelSymbol) =>
// Locate the nearest context with top-level symbols.
def find(ctx: Ctx): Ctx =
ctx.outer match
case S(sym: TopLevelSymbol) => ctx
case _ => ctx.parent match
case S(pctx) => find(pctx)
case N => ctx
find(ctx).get(nme.name).flatMap(_.symbol)
case S(psym: BlockMemberSymbol) =>
psym.modTree match
case S(cls) =>
Expand Down Expand Up @@ -606,7 +616,7 @@ extends Importer:
raise(d)
go(sts, acc)
case (td @ TypeDef(k, head, extension, body)) :: sts =>
assert((k is Als) || (k is Cls) || (k is Mod) || (k is Obj), k)
assert((k is Als) || (k is Cls) || (k is Mod) || (k is Obj) || (k is Pat), k)
val nme = td.name match
case R(id) => id
case L(d) =>
Expand Down Expand Up @@ -658,6 +668,17 @@ extends Importer:
semantics.TypeDef(alsSym, tps, extension.map(term), N)
alsSym.defn = S(d)
d
case Pat =>
val patSym = td.symbol.asInstanceOf[PatternSymbol] // TODO improve `asInstanceOf`
val owner = ctx.outer
newCtx.nest(S(patSym)).givenIn:
assert(body.isEmpty)
log(s"pattern body is ${td.extension}")
val compose = new ucs.Translator(tl, this)
val bod = compose(ps.getOrElse(Nil), td.extension.getOrElse(die))
val pd = PatternDef(owner, patSym, tps, ps, ObjBody(Term.Blk(bod, Term.Lit(UnitLit(true)))))
patSym.defn = S(pd)
pd
case k: (Mod.type | Obj.type) =>
val clsSym = td.symbol.asInstanceOf[ModuleSymbol] // TODO: improve `asInstanceOf`
val owner = ctx.outer
Expand Down
16 changes: 15 additions & 1 deletion hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,19 @@ abstract class Symbol(using State) extends Located:
case mem: BlockMemberSymbol =>
mem.modTree.map(_.symbol.asInstanceOf[ModuleSymbol])
case _ => N
def asPat: Opt[PatternSymbol] = this match
case pat: PatternSymbol => S(pat)
case mem: BlockMemberSymbol =>
mem.patTree.map(_.symbol.asInstanceOf[PatternSymbol])
case _ => N
def asAls: Opt[TypeAliasSymbol] = this match
case cls: TypeAliasSymbol => S(cls)
case mem: BlockMemberSymbol =>
mem.alsTree.map(_.symbol.asInstanceOf[TypeAliasSymbol])
case _ => N

def asClsLike: Opt[ClassSymbol | ModuleSymbol] = asCls orElse asMod
def asClsLike: Opt[ClassSymbol | ModuleSymbol | PatternSymbol] =
(asCls: Opt[ClassSymbol | ModuleSymbol | PatternSymbol]) orElse asMod orElse asPat
def asTpe: Opt[TypeSymbol] = asCls orElse asAls

override def equals(x: Any): Bool = x match
Expand Down Expand Up @@ -102,6 +108,8 @@ class BlockMemberSymbol(val nme: Str, val trees: Ls[Tree])(using State)
case t: Tree.TypeDef if (t.k is Mod) || (t.k is Obj) => t
def alsTree: Opt[Tree.TypeDef] = trees.collectFirst:
case t: Tree.TypeDef if t.k is Als => t
def patTree: Opt[Tree.TypeDef] = trees.collectFirst:
case t: Tree.TypeDef if t.k is Pat => t
def trmTree: Opt[Tree.TermDef] = trees.collectFirst:
case t: Tree.TermDef /* if t.k is */ => t
def trmImplTree: Opt[Tree.TermDef] = trees.collectFirst:
Expand Down Expand Up @@ -172,6 +180,12 @@ class TypeAliasSymbol(val id: Tree.Ident)(using State) extends MemberSymbol[Type
def toLoc: Option[Loc] = id.toLoc // TODO track source tree of type alias here
override def toString: Str = s"module:${id.name}"

class PatternSymbol(val id: Tree.Ident)(using State)
extends MemberSymbol[PatternDef] with CtorSymbol with InnerSymbol:
def nme = id.name
def toLoc: Option[Loc] = id.toLoc // TODO track source tree of pattern here
override def toString: Str = s"pattern:${id.name}"

class TopLevelSymbol(blockNme: Str)(using State)
extends MemberSymbol[ModuleDef] with InnerSymbol:
def nme = blockNme
Expand Down
12 changes: 12 additions & 0 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ sealed trait Statement extends AutoLocated:
mod.paramsOpt.toList.flatMap(_.flatMap(_.subTerms)) ::: mod.body.blk :: Nil
case td: TypeDef =>
td.rhs.toList
case pat: PatternDef =>
pat.paramsOpt.toList.flatMap(_.flatMap(_.subTerms)) ::: pat.body.blk :: Nil
case Import(sym, pth) => Nil
case Try(body, finallyDo) => body :: finallyDo :: Nil
case Handle(lhs, rhs, defs) => rhs :: defs._1 :: Nil
Expand Down Expand Up @@ -247,6 +249,16 @@ case class ModuleDef(
body: ObjBody,
) extends ClassLikeDef with Companion

case class PatternDef(
owner: Opt[InnerSymbol],
sym: PatternSymbol,
tparams: Ls[TyParam],
paramsOpt: Opt[Ls[Param]],
body: ObjBody
) extends ClassLikeDef:
self =>
val kind: ClsLikeKind = Pat


sealed abstract class ClassDef extends ClassLikeDef:
val kind: ClsLikeKind
Expand Down
122 changes: 122 additions & 0 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/DesugaringBase.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package hkmc2
package semantics
package ucs

import mlscript.utils.*, shorthands.*
import syntax.Tree.*, Elaborator.Ctxl

/** Contains some helpers that makes UCS desugaring easier. */
trait DesugaringBase(using Elaborator.State):
val elaborator: Elaborator

import elaborator.{term, cls}

protected transparent inline def int(i: Int) = Term.Lit(IntLit(BigInt(i)))
protected transparent inline def fld(t: Term) = Fld(FldFlags.empty, t, N)
protected transparent inline def tup(xs: Fld*) = Term.Tup(xs.toList)(Tup(Nil))
protected transparent inline def app(lhs: Term, rhs: Term, sym: FlowSymbol) =
Term.App(lhs, rhs)(App(Empty(), Empty()), sym)

protected lazy val matchResultClass =
Sel(Sel(Ident("globalThis"), Ident("Predef")), Ident("MatchResult"))

protected lazy val matchResultFailure =
Sel(Sel(Ident("globalThis"), Ident("Predef")), Ident("MatchFailure"))

protected lazy val tupleSlice: Ctxl[Term] =
term(Sel(Sel(Ident("globalThis"), Ident("Predef")), Ident("tupleSlice")))

protected lazy val tupleGet: Ctxl[Term] =
term(Sel(Sel(Ident("globalThis"), Ident("Predef")), Ident("tupleGet")))

protected lazy val stringStartsWith: Ctxl[Term] =
term(Sel(Sel(Ident("globalThis"), Ident("Predef")), Ident("stringStartsWith")))

protected lazy val stringGet: Ctxl[Term] =
term(Sel(Sel(Ident("globalThis"), Ident("Predef")), Ident("stringGet")))

protected lazy val stringDrop: Ctxl[Term] =
term(Sel(Sel(Ident("globalThis"), Ident("Predef")), Ident("stringDrop")))

protected def callTupleGet(t: Term, i: Int, s: FlowSymbol): Ctxl[Term] =
app(tupleGet, tup(fld(t), fld(int(i))), s)

protected def callStringStartsWith(t: Term, prefix: Term, s: FlowSymbol): Ctxl[Term] =
app(stringStartsWith, tup(fld(t), fld(prefix)), s)

protected def callStringGet(t: Term, i: Int, s: FlowSymbol): Ctxl[Term] =
app(stringGet, tup(fld(t), fld(int(i))), s)

protected def callStringDrop(t: Term, n: Int, s: FlowSymbol): Ctxl[Term] =
app(stringDrop, tup(fld(t), fld(int(n))), s)

protected transparent inline def tempLet(term: Term)(inner: TempSymbol => Split): Split =
val s = TempSymbol(N, "temp")
Split.Let(s, term, inner(s))

protected transparent inline def plainTest(cond: Term, dbgName: Str = "cond")(inner: => Split): Split =
val s = TempSymbol(N, dbgName)
Split.Let(s, cond, Branch(s.ref(), inner) ~: Split.End)

/** Make a `Branch` that calls `Pattern` symbols' `unapply` functions. */
def makeUnapplyBranch(
scrut: => Term.Ref,
psym: PatternSymbol,
inner: => Split,
method: Str = "unapply"
)(fallback: Split): Ctxl[Split] =
val matchResultClassTerm = cls(matchResultClass)
matchResultClassTerm.symbol match
case S(matchResultClassSymbol: ClassSymbol) =>
// def makeUnapplyBranch(scrut: => Term.Ref, pat)
val resultIdent = Ident("matchResult"): Ident
val resultSymbol = TempSymbol(N, "matchResult")
val globalThis = term(Ident("globalThis"))
val unapply = Term.Sel(
prefix = Term.Sel(globalThis, Ident(psym.nme))(S(psym)),
nme = Ident(method)
)(N)
val arguments = Term.Tup(Fld(FldFlags.empty, scrut, N) :: Nil)(Tup(Nil))
val call = Term.App(unapply, arguments)(App(Empty(), Empty()), FlowSymbol(s"result of $method"))
Split.Let(resultSymbol, call,
Branch(
resultSymbol.ref(),
Pattern.ClassLike(matchResultClassSymbol, matchResultClassTerm, N, false)(matchResultClass),
inner
) ~: fallback)
case S(_) | N => lastWords("Cannot locate `MatchResult` class in the global scope.")

/** Make a `Branch` that calls `Pattern` symbols' `unapplyStringPrefix` functions. */
def makeUnapplyStringPrefixBranch(
scrut: => Term.Ref,
psym: PatternSymbol,
inner: (scrut: TempSymbol) => Split,
method: Str = "unapplyStringPrefix"
)(fallback: Split): Ctxl[Split] =
val matchResultClassTerm = cls(matchResultClass)
matchResultClassTerm.symbol match
case S(matchResultClassSymbol: ClassSymbol) =>
// def makeUnapplyBranch(scrut: => Term.Ref, pat)
val resultIdent = Ident("matchResult"): Ident
val resultSymbol = TempSymbol(N, "matchResult")
val globalThis = term(Ident("globalThis"))
val unapply = Term.Sel(
prefix = Term.Sel(globalThis, Ident(psym.nme))(S(psym)),
nme = Ident(method)
)(N)
val arguments = Term.Tup(Fld(FldFlags.empty, scrut, N) :: Nil)(Tup(Nil))
val call = Term.App(unapply, arguments)(App(Empty(), Empty()), FlowSymbol(s"result of $method"))
tempLet(call): resultSymbol =>
val argSym = TempSymbol(N, "arg")
val tupleGetRes = FlowSymbol("postfix")
Branch(
resultSymbol.ref(),
Pattern.ClassLike(
matchResultClassSymbol,
matchResultClassTerm,
S(argSym :: Nil),
false
)(matchResultClass),
tempLet(callTupleGet(argSym.ref(), 0, tupleGetRes))(inner)
) ~: fallback
case S(_) | N => lastWords("Cannot locate `MatchResult` class in the global scope.")
Loading

0 comments on commit e8fb2cd

Please sign in to comment.