Skip to content

Commit

Permalink
Merge pull request #2792 from informalsystems/shon/fix-wildard-match
Browse files Browse the repository at this point in the history
Add support for default cases in quint match expressions
  • Loading branch information
Shon Feder authored Dec 1, 2023
2 parents 30fe7d4 + 0bca067 commit 67cb5ff
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 5 deletions.
1 change: 1 addition & 0 deletions .unreleased/bug-fixes/fix-default-cases.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix missing support for default match cases in quint conversion (#2792)
33 changes: 30 additions & 3 deletions tla-io/src/main/scala/at/forsyte/apalache/io/quint/Quint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ class Quint(quintOutput: QuintOutput) {

// the quint builtin operator representing match expressions looks like
//
// matchVariant(expr, "F1", elim_1, ..., "Fn", elim_n)
// matchVariant(expr, "F1", elim_1, ..., "Fn", elim_n[, "_", defaultElim])
//
// Where each `elim_i` is an operator applying to value wrapped in field `Fi` of a variant.
//
Expand All @@ -425,13 +425,36 @@ class Quint(quintOutput: QuintOutput) {
// CASE VariantTag(expr) = "F1" -> elim_1(VariantGetUnsafe("F1", expr))
// [] ...
// [] VariantTag(expr) = "Fn" -> elim_n(VariantGetUnsafe("Fn", expr))
// [] OTHER -> defaultElim([])
//
// This ensures that we will apply the proper eliminator to the expected value
// associated with whatever tag is carried by the variant `expr`.
//
// The final, default case may not be present, in which case no `OTHER` case is
// constructed.
def matchVariant: Converter = variadicApp { case expr +: cases =>
val variantTagCondition = (caseTag) => tla.eql(tla.variantTag(expr), caseTag)

// Check the last case to see if there is a default case, which will need special treatment
// If a valid quint match expression has a default case, it will always be the last case
// in a match.
val (matchCases, defaultCase) = cases.grouped(2).toSeq match {
case Seq() =>
(Seq(), None) // A match expression with no cases is invalid: we let the builder handle the error
case allCases @ (cs :+ Seq(label, defaultElim)) =>
build(label) match {
case ValEx(TlaStr("_")) =>
// We have a default case, which is always paired with an eliminator that
// can be applied to the unit value (an empty record).
(cs, Some(tla.appOp(defaultElim, tla.rowRec(None))))
case _ =>
// All cases have match expressions
(allCases, None)
}
}

val casesInstructions: Seq[(T, T)] =
cases.grouped(2).toSeq.map { case Seq(label, elim) =>
matchCases.map { case Seq(label, elim) =>
val appliedElim = label.flatMap {
case ValEx(TlaStr(labelLit)) =>
tla.appOp(elim, tla.variantGetUnsafe(labelLit, expr))
Expand All @@ -440,7 +463,11 @@ class Quint(quintOutput: QuintOutput) {
}
variantTagCondition(label) -> appliedElim
}
tla.caseSplit(casesInstructions: _*)

defaultCase match {
case None => tla.caseSplit(casesInstructions: _*)
case Some(default) => tla.caseOther(default, casesInstructions: _*)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ class TestQuintEx extends AnyFunSuite {
val t = e(QuintStr(uid, "t"), QuintStrT())
val labelF1 = e(QuintStr(uid, "F1"), QuintStrT())
val labelF2 = e(QuintStr(uid, "F2"), QuintStrT())
val labelF3 = e(QuintStr(uid, "F3"), QuintStrT())
val wildLabel = e(QuintStr(uid, "_"), QuintStrT())

// Names and parameters
val name = e(QuintName(uid, "n"), QuintIntT())
Expand Down Expand Up @@ -611,7 +613,7 @@ class TestQuintEx extends AnyFunSuite {
}

test("can convert builtin matchVariant operator application") {
val typ = QuintSumT.ofVariantTypes("F1" -> QuintIntT(), "F2" -> QuintRecordT.ofFieldTypes())
val typ = QuintSumT.ofVariantTypes("F1" -> QuintIntT(), "F2" -> QuintRecordT.ofFieldTypes(), "F3" -> QuintIntT())
val variant = Q.app("variant", Q.labelF1, Q._42)(typ)
val quintMatch = Q.app(
"matchVariant",
Expand All @@ -620,10 +622,13 @@ class TestQuintEx extends AnyFunSuite {
Q.lam(Seq("x" -> QuintIntT()), Q._1, QuintIntT()),
Q.labelF2,
Q.lam(Seq("y" -> QuintRecordT.ofFieldTypes()), Q._2, QuintIntT()),
Q.wildLabel,
Q.lam(Seq("_" -> QuintVarT("t")), Q._2, QuintIntT()), // Default case
)(typ)
val expected =
"""|CASE (Variants!VariantTag(Variants!Variant("F1", 42)) = "F1") → LET __QUINT_LAMBDA0(x) ≜ 1 IN __QUINT_LAMBDA0(Variants!VariantGetUnsafe("F1", Variants!Variant("F1", 42)))
|☐ (Variants!VariantTag(Variants!Variant("F1", 42)) = "F2") → LET __QUINT_LAMBDA1(y) ≜ 2 IN __QUINT_LAMBDA1(Variants!VariantGetUnsafe("F2", Variants!Variant("F1", 42)))""".stripMargin
|☐ (Variants!VariantTag(Variants!Variant("F1", 42)) = "F2") → LET __QUINT_LAMBDA1(y) ≜ 2 IN __QUINT_LAMBDA1(Variants!VariantGetUnsafe("F2", Variants!Variant("F1", 42)))
|☐ OTHER → LET __QUINT_LAMBDA2(_) ≜ 2 IN __QUINT_LAMBDA2([])""".stripMargin
.replace('\n', ' ')
assert(convert(quintMatch) == expected)
}
Expand Down

0 comments on commit 67cb5ff

Please sign in to comment.