Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Naïve compilation of refining patterns #242

Merged
merged 13 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ class JSBuilder(using Elaborator.State, Elaborator.Ctx) extends CodeBuilder:
}"
case Instantiate(cls, as) =>
doc"new ${result(cls)}(${as.map(result).mkDocument(", ")})"
case Value.Arr(es) if es.isEmpty => doc"[]"
case Value.Arr(es) =>
doc"[ #{ # ${es.map(result).mkDocument(doc", # ")} #} # ]"
def returningTerm(t: Block)(using Raise, Scope): Document = t match
Expand Down Expand Up @@ -191,7 +192,7 @@ class JSBuilder(using Elaborator.State, Elaborator.Ctx) extends CodeBuilder:
} + ")""""
}; }"""
} #} # }"
if (clsDefn.kind is syntax.Mod) || (clsDefn.kind is syntax.Obj) then
if (clsDefn.kind is syntax.Mod) || (clsDefn.kind is syntax.Obj) || (clsDefn.kind is syntax.Pat) then
val clsTmp = summon[Scope].allocateName(new semantics.TempSymbol(N, sym.nme+"$"+"class"))
clsDefn.owner match
case S(owner) =>
Expand Down
20 changes: 5 additions & 15 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Desugarer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import hkmc2.syntax.Literal
import Keyword.{as, and, `else`, is, let, `then`}
import collection.mutable.HashMap
import Elaborator.{ctx, Ctxl}
import ucs.DesugaringBase

object Desugarer:
extension (op: Keyword.Infix)
Expand All @@ -29,8 +30,8 @@ object Desugarer:
val tupleLast: HashMap[Int, BlockLocalSymbol] = HashMap.empty
end Desugarer

class Desugarer(tl: TraceLogger, elaborator: Elaborator)
(using raise: Raise, state: Elaborator.State, c: Elaborator.Ctx):
class Desugarer(tl: TraceLogger, val elaborator: Elaborator)
(using raise: Raise, state: Elaborator.State, c: Elaborator.Ctx) extends DesugaringBase:
import Desugarer.*
import Elaborator.Ctx
import elaborator.term
Expand Down Expand Up @@ -359,12 +360,6 @@ class Desugarer(tl: TraceLogger, elaborator: Elaborator)
raise(ErrorReport(msg"Unrecognized pattern split." -> tree.toLoc :: Nil))
_ => _ => Split.default(Term.Error)

private lazy val tupleSlice =
term(SynthSel(SynthSel(Ident("globalThis"), Ident("Predef")), Ident("tupleSlice")))

private lazy val tupleGet =
term(SynthSel(SynthSel(Ident("globalThis"), Ident("Predef")), Ident("tupleGet")))

/** Elaborate a single match (a scrutinee and a pattern) and forms a split
* with an innermost split as the sequel of the match.
* @param scrutSymbol the symbol representing the scrutinee
Expand Down Expand Up @@ -395,6 +390,7 @@ class Desugarer(tl: TraceLogger, elaborator: Elaborator)
Branch(ref, Pattern.ClassLike(cls, clsTrm, N, false)(ctor), sequel(ctx)) ~: fallback
case S(cls: ModuleSymbol) =>
Branch(ref, Pattern.ClassLike(cls, clsTrm, N, false)(ctor), sequel(ctx)) ~: fallback
case S(psym: PatternSymbol) => makeUnapplyBranch(ref, clsTrm, sequel(ctx))(fallback)
case N =>
// Raise an error and discard `sequel`. Use `fallback` instead.
raise(ErrorReport(msg"Cannot use this ${ctor.describe} as a pattern" -> ctor.toLoc :: Nil))
Expand All @@ -411,12 +407,6 @@ class Desugarer(tl: TraceLogger, elaborator: Elaborator)
case ((lead, N), Spread(_, _, patOpt)) => (lead, S((patOpt, Nil)))
case ((lead, N), pat) => (lead :+ pat, N)
case ((lead, S((rest, last))), pat) => (lead, S((rest, last :+ pat)))
// Some helper functions. TODO: deduplicate
def int(i: Int) = Term.Lit(IntLit(BigInt(i)))
def fld(t: Term) = Fld(FldFlags.empty, t, N)
def tup(xs: Fld*) = Term.Tup(xs.toList)(Tup(Nil))
def app(lhs: Term, rhs: Term, sym: FlowSymbol) = Term.App(lhs, rhs)(Tree.App(Tree.Empty(), Tree.Empty()), sym)
def getLast(i: Int) = TempSymbol(N, s"last$i")
// `wrap`: add let bindings for tuple elements
// `matches`: pairs of patterns and symbols to be elaborated
val (wrapRest, restMatches) = rest match
Expand All @@ -426,7 +416,7 @@ class Desugarer(tl: TraceLogger, elaborator: Elaborator)
case ((wrapInner, matches), (pat, lastIndex)) =>
val sym = scrutSymbol.getTupleLastSubScrutinee(lastIndex)
val wrap = (split: Split) =>
Split.Let(sym, app(tupleGet, tup(fld(ref), fld(int(-1 - lastIndex))), sym), wrapInner(split))
Split.Let(sym, callTupleGet(ref, -1 - lastIndex, sym), wrapInner(split))
(wrap, (sym, pat) :: matches)
val lastMatches = reversedLastMatches.reverse
rest match
Expand Down
28 changes: 19 additions & 9 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ object Elaborator:
given State = this
val suid = new Uid.Symbol.State
val globalThisSymbol = TopLevelSymbol("globalThis")
val builtinOpsMap =
val baseBuiltins = binaryOps.map: op =>
op -> BuiltinSymbol(op, binary = true, unary = unaryOps(op), nullary = false)
.toMap
baseBuiltins ++ aliasOps.map:
case (alias, base) => alias -> baseBuiltins(base)
val seqSymbol = TermSymbol(ImmutVal, N, Ident(";"))
def init(using State): Ctx = Ctx.empty.copy(env = Map(
"globalThis" -> globalThisSymbol,
Expand All @@ -131,13 +137,6 @@ extends Importer:
private val allocSkolemSym = VarSymbol(Ident("Alloc"))
private val allocSkolemDef = TyParam(FldFlags.empty, N, allocSkolemSym)
allocSkolemSym.decl = S(allocSkolemDef)

private val builtinOpsMap =
val baseBuiltins = binaryOps.map: op =>
op -> BuiltinSymbol(op, binary = true, unary = unaryOps(op), nullary = false)
.toMap
baseBuiltins ++ aliasOps.map:
case (alias, base) => alias -> baseBuiltins(base)

def mkLetBinding(sym: LocalSymbol, rhs: Term): Ls[Statement] =
LetDecl(sym) :: DefineVar(sym, rhs) :: Nil
Expand Down Expand Up @@ -226,7 +225,7 @@ extends Importer:
ctx.get(name) match
case S(sym) => sym.ref(id)
case N =>
builtinOpsMap.get(name) match
state.builtinOpsMap.get(name) match
case S(bi) => bi.ref(id)
case N =>
raise(ErrorReport(msg"Name not found: $name" -> tree.toLoc :: Nil))
Expand Down Expand Up @@ -651,7 +650,7 @@ extends Importer:
raise(d)
go(sts, acc)
case (td @ TypeDef(k, head, extension, body)) :: sts =>
assert((k is Als) || (k is Cls) || (k is Mod) || (k is Obj), k)
assert((k is Als) || (k is Cls) || (k is Mod) || (k is Obj) || (k is Pat), k)
val nme = td.name match
case R(id) => id
case L(d) =>
Expand Down Expand Up @@ -703,6 +702,17 @@ extends Importer:
semantics.TypeDef(alsSym, tps, extension.map(term(_)), N)
alsSym.defn = S(d)
d
case Pat =>
val patSym = td.symbol.asInstanceOf[PatternSymbol] // TODO improve `asInstanceOf`
val owner = ctx.outer
newCtx.nest(S(patSym)).givenIn:
assert(body.isEmpty)
log(s"pattern body is ${td.extension}")
val translate = new ucs.Translator(this)
val bod = translate(ps.map(_.params).getOrElse(Nil), td.extension.getOrElse(die))
val pd = PatternDef(owner, patSym, tps, ps, ObjBody(Term.Blk(bod, Term.Lit(UnitLit(true)))))
patSym.defn = S(pd)
pd
case k: (Mod.type | Obj.type) =>
val clsSym = td.symbol.asInstanceOf[ModuleSymbol] // TODO: improve `asInstanceOf`
val owner = ctx.outer
Expand Down
24 changes: 17 additions & 7 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,23 @@ abstract class Symbol(using State) extends Located:

def asCls: Opt[ClassSymbol] = this match
case cls: ClassSymbol => S(cls)
case mem: BlockMemberSymbol =>
mem.clsTree.map(_.symbol.asInstanceOf[ClassSymbol])
case mem: BlockMemberSymbol => mem.clsTree.flatMap(_.symbol.asCls)
case _ => N
def asMod: Opt[ModuleSymbol] = this match
case cls: ModuleSymbol => S(cls)
case mem: BlockMemberSymbol =>
mem.modTree.map(_.symbol.asInstanceOf[ModuleSymbol])
case mem: BlockMemberSymbol => mem.modTree.flatMap(_.symbol.asMod)
case _ => N
def asPat: Opt[PatternSymbol] = this match
case pat: PatternSymbol => S(pat)
case mem: BlockMemberSymbol => mem.patTree.flatMap(_.symbol.asPat)
case _ => N
def asAls: Opt[TypeAliasSymbol] = this match
case cls: TypeAliasSymbol => S(cls)
case mem: BlockMemberSymbol =>
mem.alsTree.map(_.symbol.asInstanceOf[TypeAliasSymbol])
case mem: BlockMemberSymbol => mem.alsTree.flatMap(_.symbol.asAls)
case _ => N

def asClsLike: Opt[ClassSymbol | ModuleSymbol] = asCls orElse asMod
def asClsLike: Opt[ClassSymbol | ModuleSymbol | PatternSymbol] =
(asCls: Opt[ClassSymbol | ModuleSymbol | PatternSymbol]) orElse asMod orElse asPat
def asTpe: Opt[TypeSymbol] = asCls orElse asAls

override def equals(x: Any): Bool = x match
Expand Down Expand Up @@ -101,6 +103,8 @@ class BlockMemberSymbol(val nme: Str, val trees: Ls[Tree])(using State)
case t: Tree.TypeDef if (t.k is Mod) || (t.k is Obj) => t
def alsTree: Opt[Tree.TypeDef] = trees.collectFirst:
case t: Tree.TypeDef if t.k is Als => t
def patTree: Opt[Tree.TypeDef] = trees.collectFirst:
case t: Tree.TypeDef if t.k is Pat => t
def trmTree: Opt[Tree.TermDef] = trees.collectFirst:
case t: Tree.TermDef /* if t.k is */ => t
def trmImplTree: Opt[Tree.TermDef] = trees.collectFirst:
Expand Down Expand Up @@ -174,6 +178,12 @@ class TypeAliasSymbol(val id: Tree.Ident)(using State) extends MemberSymbol[Type
def toLoc: Option[Loc] = id.toLoc // TODO track source tree of type alias here
override def toString: Str = s"module:${id.name}${State.dbgUid(uid)}"

class PatternSymbol(val id: Tree.Ident)(using State)
extends MemberSymbol[PatternDef] with CtorSymbol with InnerSymbol:
def nme = id.name
def toLoc: Option[Loc] = id.toLoc // TODO track source tree of pattern here
override def toString: Str = s"pattern:${id.name}"

class TopLevelSymbol(blockNme: Str)(using State)
extends MemberSymbol[ModuleDef] with InnerSymbol:
def nme = blockNme
Expand Down
12 changes: 12 additions & 0 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ sealed trait Statement extends AutoLocated with ProductWithExtraInfo:
mod.paramsOpt.toList.flatMap(_.subTerms) ::: mod.body.blk :: Nil
case td: TypeDef =>
td.rhs.toList
case pat: PatternDef =>
pat.paramsOpt.toList.flatMap(_.subTerms) ::: pat.body.blk :: Nil
case Import(sym, pth) => Nil
case Try(body, finallyDo) => body :: finallyDo :: Nil
case Handle(lhs, rhs, defs) => rhs :: defs._1 :: Nil
Expand Down Expand Up @@ -256,6 +258,16 @@ case class ModuleDef(
body: ObjBody,
) extends ClassLikeDef with Companion

case class PatternDef(
owner: Opt[InnerSymbol],
sym: PatternSymbol,
tparams: Ls[TyParam],
paramsOpt: Opt[ParamList],
body: ObjBody
) extends ClassLikeDef:
self =>
val kind: ClsLikeKind = Pat


sealed abstract class ClassDef extends ClassLikeDef:
val kind: ClsLikeKind
Expand Down
114 changes: 114 additions & 0 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/DesugaringBase.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package hkmc2
package semantics
package ucs

import mlscript.utils.*, shorthands.*
import syntax.Tree.*, Elaborator.{Ctxl, ctx}

/** Contains some helpers that makes UCS desugaring easier. */
trait DesugaringBase(using state: Elaborator.State):
val elaborator: Elaborator

import elaborator.tl.*, state.globalThisSymbol

protected final def sel(p: Term, k: Ident): Term.SynthSel = Term.SynthSel(p, k)(N)
protected final def sel(p: Term, k: Ident, s: FieldSymbol): Term.SynthSel = Term.SynthSel(p, k)(S(s))
protected final def sel(p: Term, k: Str): Term.SynthSel = sel(p, Ident(k): Ident)
protected final def sel(p: Term, k: Str, s: FieldSymbol): Term.SynthSel = sel(p, Ident(k): Ident, s)
protected final def int(i: Int) = Term.Lit(IntLit(BigInt(i)))
protected final def str(s: Str) = Term.Lit(StrLit(s))
protected final def fld(t: Term) = Fld(FldFlags.empty, t, N)
protected final def tup(xs: Fld*): Term.Tup = Term.Tup(xs.toList)(Tup(Nil))
protected final def app(l: Term, r: Term, label: Str): Term.App = app(l, r, FlowSymbol(label))
protected final def app(l: Term, r: Term, s: FlowSymbol): Term.App = Term.App(l, r)(App(Empty(), Empty()), s)

/** Get the class symbol defined in the `Predef` module. */
protected def resolvePredefMember(name: Str): Ctxl[(Term.SynthSel, ClassSymbol)] =
val predefSymbol = ctx.Builtins.Predef
val innerSel = sel(globalThisSymbol.ref(), "Predef", predefSymbol)
val memberSymbol = predefSymbol.tree.definedSymbols.get(name).flatMap(_.asCls).getOrElse:
lastWords(s"Cannot resolve `$name` in `Predef`.")
(sel(innerSel, name, memberSymbol), memberSymbol)

/** Make a term looks like `globalThis.Predef.MatchResult` with its symbol. */
protected lazy val matchResultClass: Ctxl[(Term.SynthSel, ClassSymbol)] = resolvePredefMember("MatchResult")

/** Make a pattern looks like `globalThis.Predef.MatchResult.class`. */
protected def matchResultPattern(parameters: Opt[List[BlockLocalSymbol]]): Ctxl[Pattern.ClassLike] =
val (classRef, classSym) = matchResultClass
val classSel = Term.SynthSel(matchResultClass._1, Ident("class"))(S(classSym))
Pattern.ClassLike(classSym, classSel, parameters, false)(Empty())

/** Make a term looks like `globalThis.Predef.MatchFailure` with its symbol. */
protected lazy val matchFailureClass: Ctxl[(Term.SynthSel, ClassSymbol)] = resolvePredefMember("MatchFailure")

/** Make a pattern looks like `globalThis.Predef.MatchFailure.class`. */
protected def matchFailurePattern(parameters: Opt[List[BlockLocalSymbol]]): Ctxl[Pattern.ClassLike] =
val (classRef, classSym) = matchResultClass
val classSel = Term.SynthSel(matchResultClass._1, Ident("class"))(S(classSym))
Pattern.ClassLike(classSym, classSel, parameters, false)(Empty())

/** Create a term that selects a method in the `Predef` module. */
protected final def selectPredefMethod =
sel(sel(globalThisSymbol.ref(), "Predef"), _: Str)

protected lazy val tupleSlice = selectPredefMethod("tupleSlice")
protected lazy val tupleGet = selectPredefMethod("tupleGet")
protected lazy val stringStartsWith = selectPredefMethod("stringStartsWith")
protected lazy val stringGet = selectPredefMethod("stringGet")
protected lazy val stringDrop = selectPredefMethod("stringDrop")

/** Make a term that looks like `tupleGet(t, i)`. */
protected final def callTupleGet(t: Term, i: Int, label: Str): Ctxl[Term] =
callTupleGet(t, i, FlowSymbol(label))

/** Make a term that looks like `tupleGet(t, i)`. */
protected final def callTupleGet(t: Term, i: Int, s: FlowSymbol): Ctxl[Term] =
app(tupleGet, tup(fld(t), fld(int(i))), s)

/** Make a term that looks like `stringStartsWith(t, p)`. */
protected final def callStringStartsWith(t: Term, p: Term, label: Str): Ctxl[Term] =
app(stringStartsWith, tup(fld(t), fld(p)), FlowSymbol(label))

/** Make a term that looks like `stringStartsWith(t, i)`. */
protected final def callStringGet(t: Term, i: Int, label: Str): Ctxl[Term] =
app(stringGet, tup(fld(t), fld(int(i))), FlowSymbol(label))

/** Make a term that looks like `stringStartsWith(t, n)`. */
protected final def callStringDrop(t: Term, n: Int, label: Str): Ctxl[Term] =
app(stringDrop, tup(fld(t), fld(int(n))), FlowSymbol(label))

protected final def tempLet(dbgName: Str, term: Term)(inner: TempSymbol => Split): Split =
val s = TempSymbol(N, dbgName)
Split.Let(s, term, inner(s))

protected final def plainTest(cond: Term, dbgName: Str = "cond")(inner: => Split): Split =
val s = TempSymbol(N, dbgName)
Split.Let(s, cond, Branch(s.ref(), inner) ~: Split.End)

/** Make a `Branch` that calls `Pattern` symbols' `unapply` functions. */
def makeUnapplyBranch(
scrut: => Term.Ref,
clsTerm: Term,
inner: => Split,
method: Str = "unapply"
)(fallback: Split): Ctxl[Split] =
val call = app(sel(clsTerm, method), tup(fld(scrut)), FlowSymbol(s"result of $method"))
tempLet("matchResult", call): resultSymbol =>
Branch(resultSymbol.ref(), matchResultPattern(N), inner) ~: fallback

/** Make a `Branch` that calls `Pattern` symbols' `unapplyStringPrefix` functions. */
def makeUnapplyStringPrefixBranch(
scrut: => Term.Ref,
clsTerm: Term,
inner: TempSymbol => Split,
method: Str = "unapplyStringPrefix"
)(fallback: Split): Ctxl[Split] =
val call = app(sel(clsTerm, method), tup(fld(scrut)), FlowSymbol(s"result of $method"))
tempLet("matchResult", call): resultSymbol =>
val argSym = TempSymbol(N, "arg")
Branch(
resultSymbol.ref(),
matchResultPattern(S(argSym :: Nil)),
tempLet("postfix", callTupleGet(argSym.ref(), 0, "postfix"))(inner)
) ~: fallback
Loading
Loading