diff --git a/.unreleased/bug-fixes/fix-default-cases.md b/.unreleased/bug-fixes/fix-default-cases.md new file mode 100644 index 0000000000..8b221a853c --- /dev/null +++ b/.unreleased/bug-fixes/fix-default-cases.md @@ -0,0 +1 @@ +Fix missing support for default match cases in quint conversion (#2792) diff --git a/tla-io/src/main/scala/at/forsyte/apalache/io/quint/Quint.scala b/tla-io/src/main/scala/at/forsyte/apalache/io/quint/Quint.scala index fa43440e7e..9b6f2fab55 100644 --- a/tla-io/src/main/scala/at/forsyte/apalache/io/quint/Quint.scala +++ b/tla-io/src/main/scala/at/forsyte/apalache/io/quint/Quint.scala @@ -416,7 +416,7 @@ class Quint(quintOutput: QuintOutput) { // the quint builtin operator representing match expressions looks like // - // matchVariant(expr, "F1", elim_1, ..., "Fn", elim_n) + // matchVariant(expr, "F1", elim_1, ..., "Fn", elim_n[, "_", defaultElim]) // // Where each `elim_i` is an operator applying to value wrapped in field `Fi` of a variant. // @@ -425,13 +425,36 @@ class Quint(quintOutput: QuintOutput) { // CASE VariantTag(expr) = "F1" -> elim_1(VariantGetUnsafe("F1", expr)) // [] ... // [] VariantTag(expr) = "Fn" -> elim_n(VariantGetUnsafe("Fn", expr)) + // [] OTHER -> defaultElim([]) // // This ensures that we will apply the proper eliminator to the expected value // associated with whatever tag is carried by the variant `expr`. + // + // The final, default case may not be present, in which case no `OTHER` case is + // constructed. def matchVariant: Converter = variadicApp { case expr +: cases => val variantTagCondition = (caseTag) => tla.eql(tla.variantTag(expr), caseTag) + + // Check the last case to see if there is a default case, which will need special treatment + // If a valid quint match expression has a default case, it will always be the last case + // in a match. + val (matchCases, defaultCase) = cases.grouped(2).toSeq match { + case Seq() => + (Seq(), None) // A match expression with no cases is invalid: we let the builder handle the error + case allCases @ (cs :+ Seq(label, defaultElim)) => + build(label) match { + case ValEx(TlaStr("_")) => + // We have a default case, which is always paired with an eliminator that + // can be applied to the unit value (an empty record). + (cs, Some(tla.appOp(defaultElim, tla.rowRec(None)))) + case _ => + // All cases have match expressions + (allCases, None) + } + } + val casesInstructions: Seq[(T, T)] = - cases.grouped(2).toSeq.map { case Seq(label, elim) => + matchCases.map { case Seq(label, elim) => val appliedElim = label.flatMap { case ValEx(TlaStr(labelLit)) => tla.appOp(elim, tla.variantGetUnsafe(labelLit, expr)) @@ -440,7 +463,11 @@ class Quint(quintOutput: QuintOutput) { } variantTagCondition(label) -> appliedElim } - tla.caseSplit(casesInstructions: _*) + + defaultCase match { + case None => tla.caseSplit(casesInstructions: _*) + case Some(default) => tla.caseOther(default, casesInstructions: _*) + } } } diff --git a/tla-io/src/test/scala/at/forsyte/apalache/io/quint/TestQuintEx.scala b/tla-io/src/test/scala/at/forsyte/apalache/io/quint/TestQuintEx.scala index a7ac8f6644..c37be76287 100644 --- a/tla-io/src/test/scala/at/forsyte/apalache/io/quint/TestQuintEx.scala +++ b/tla-io/src/test/scala/at/forsyte/apalache/io/quint/TestQuintEx.scala @@ -115,6 +115,8 @@ class TestQuintEx extends AnyFunSuite { val t = e(QuintStr(uid, "t"), QuintStrT()) val labelF1 = e(QuintStr(uid, "F1"), QuintStrT()) val labelF2 = e(QuintStr(uid, "F2"), QuintStrT()) + val labelF3 = e(QuintStr(uid, "F3"), QuintStrT()) + val wildLabel = e(QuintStr(uid, "_"), QuintStrT()) // Names and parameters val name = e(QuintName(uid, "n"), QuintIntT()) @@ -611,7 +613,7 @@ class TestQuintEx extends AnyFunSuite { } test("can convert builtin matchVariant operator application") { - val typ = QuintSumT.ofVariantTypes("F1" -> QuintIntT(), "F2" -> QuintRecordT.ofFieldTypes()) + val typ = QuintSumT.ofVariantTypes("F1" -> QuintIntT(), "F2" -> QuintRecordT.ofFieldTypes(), "F3" -> QuintIntT()) val variant = Q.app("variant", Q.labelF1, Q._42)(typ) val quintMatch = Q.app( "matchVariant", @@ -620,10 +622,13 @@ class TestQuintEx extends AnyFunSuite { Q.lam(Seq("x" -> QuintIntT()), Q._1, QuintIntT()), Q.labelF2, Q.lam(Seq("y" -> QuintRecordT.ofFieldTypes()), Q._2, QuintIntT()), + Q.wildLabel, + Q.lam(Seq("_" -> QuintVarT("t")), Q._2, QuintIntT()), // Default case )(typ) val expected = """|CASE (Variants!VariantTag(Variants!Variant("F1", 42)) = "F1") → LET __QUINT_LAMBDA0(x) ≜ 1 IN __QUINT_LAMBDA0(Variants!VariantGetUnsafe("F1", Variants!Variant("F1", 42))) - |☐ (Variants!VariantTag(Variants!Variant("F1", 42)) = "F2") → LET __QUINT_LAMBDA1(y) ≜ 2 IN __QUINT_LAMBDA1(Variants!VariantGetUnsafe("F2", Variants!Variant("F1", 42)))""".stripMargin + |☐ (Variants!VariantTag(Variants!Variant("F1", 42)) = "F2") → LET __QUINT_LAMBDA1(y) ≜ 2 IN __QUINT_LAMBDA1(Variants!VariantGetUnsafe("F2", Variants!Variant("F1", 42))) + |☐ OTHER → LET __QUINT_LAMBDA2(_) ≜ 2 IN __QUINT_LAMBDA2([])""".stripMargin .replace('\n', ' ') assert(convert(quintMatch) == expected) }