Skip to content

Commit

Permalink
Merge pull request #2783 from informalsystems/quint/1034/sum-types
Browse files Browse the repository at this point in the history
Support quint sum types
  • Loading branch information
Shon Feder authored Dec 1, 2023
2 parents 2d9a8e2 + 2f52f8e commit cb437eb
Show file tree
Hide file tree
Showing 9 changed files with 163 additions and 126 deletions.
118 changes: 87 additions & 31 deletions tla-io/src/main/scala/at/forsyte/apalache/io/quint/Quint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ import scalaz._
import scalaz.std.list._
import scalaz.syntax.traverse._

import scala.util.Try
import scala.util.{Failure, Success, Try}
import at.forsyte.apalache.tla.lir.values.TlaStr

// Convert a QuintEx into a TlaEx
//
Expand Down Expand Up @@ -54,15 +55,15 @@ class Quint(quintOutput: QuintOutput) {
private type NullaryOpReader[A] = Reader[Set[String], A]

// Find the type for an id via the lookup table provided in the quint output
private def getTypeFromLookupTable(id: BigInt): QuintType = {
private def getTypeFromLookupTable(id: BigInt): Try[QuintType] = {
table.get(id) match {
case None => throw new QuintIRParseError(s"No entry found for id ${id} in lookup table")
case None => Failure(new QuintIRParseError(s"No entry found for id ${id} in lookup table"))
case Some(lookupEntry) =>
types.get(lookupEntry.id) match {
case None =>
throw new QuintIRParseError(
s"No type found for definition ${lookupEntry.name} (${lookupEntry.id}) associated with id ${id}")
case Some(t) => t.typ
Failure(new QuintIRParseError(
s"No type found for definition ${lookupEntry.name} (${lookupEntry.id}) associated with id ${id}"))
case Some(t) => Success(t.typ)
}
}
}
Expand Down Expand Up @@ -371,29 +372,75 @@ class Quint(quintOutput: QuintOutput) {
})

// Create a TLA record
def record(rowVar: Option[String]): Converter = {
case Seq() => throw new QuintUnsupportedError("Given empty record, but Apalache doesn't support empty records.")
case quintArgs =>
// The quint Rec operator takes its field and value arguments
// via a variadic operator requiring field names passed as strings to
// be alternated with values. E.g.,
//
// Rec("f1", 1, "f2", 2)
//
// So we first separate out the field names from the values, so we
// can make use of the existing combinator for variadic operators.
val (fieldNames, quintVals) = quintArgs
.grouped(2)
.foldRight((List[String](), List[QuintEx]())) {
case (Seq(QuintStr(_, f), v), (fields, values)) => ((f :: fields), v :: values)
case (invalidArgs, _) =>
throw new QuintIRParseError(s"Invalid argument given to Rec ${invalidArgs}")
}
variadicApp { tlaVals =>
val fieldsAndArgs = fieldNames.zip(tlaVals)
tla.rowRec(rowVar, fieldsAndArgs: _*)
}(quintVals)
def record(rowVar: Option[String]): Converter = { quintArgs =>
// The quint Rec operator takes its field and value arguments
// via a variadic operator requiring field names passed as strings to
// be alternated with values. E.g.,
//
// Rec("f1", 1, "f2", 2)
//
// So we first separate out the field names from the values, so we
// can make use of the existing combinator for variadic operators.
//
// Empty records are fine: those are the unit value.
val (fieldNames, quintVals) = quintArgs
.grouped(2)
.foldRight((List[String](), List[QuintEx]())) {
case (Seq(QuintStr(_, f), v), (fields, values)) => ((f :: fields), v :: values)
case (invalidArgs, _) =>
throw new QuintIRParseError(s"Invalid argument given to Rec ${invalidArgs}")
}
variadicApp { tlaVals =>
val fieldsAndArgs = fieldNames.zip(tlaVals)
tla.rowRec(rowVar, fieldsAndArgs: _*)
}(quintVals)
}

// Create a TLA variant
def variant(quintType: QuintType): Converter = {
val tlaType = typeConv.convert(quintType)
tlaType match {
case variantType: VariantT1 =>
binaryApp("variant",
(labelInstruction, expr) =>
// The builder requires a string literal, rather than a string expression
// so we have to build the converted TLA expression and extract its string value.
labelInstruction.flatMap {
case ValEx(TlaStr(label)) => tla.variant(label, expr, variantType)
case invalidLabel =>
throw new QuintIRParseError(s"Invalid label found in application of variant ${invalidLabel}")
})
case _ => throw new QuintIRParseError(s"Invalid type inferred for application of variant ${quintType}")
}
}

// the quint builtin operator representing match expressions looks like
//
// matchVariant(expr, "F1", elim_1, ..., "Fn", elim_n)
//
// Where each `elim_i` is an operator applying to value wrapped in field `Fi` of a variant.
//
// This is converted into the following Apalache expression, using Apalache's variant operators:
//
// CASE VariantTag(expr) = "F1" -> elim_1(VariantGetUnsafe("F1", expr))
// [] ...
// [] VariantTag(expr) = "Fn" -> elim_n(VariantGetUnsafe("Fn", expr))
//
// This ensures that we will apply the proper eliminator to the expected value
// associated with whatever tag is carried by the variant `expr`.
def matchVariant: Converter = variadicApp { case expr +: cases =>
val variantTagCondition = (caseTag) => tla.eql(tla.variantTag(expr), caseTag)
val casesInstructions: Seq[(T, T)] =
cases.grouped(2).toSeq.map { case Seq(label, elim) =>
val appliedElim = label.flatMap {
case ValEx(TlaStr(labelLit)) =>
tla.appOp(elim, tla.variantGetUnsafe(labelLit, expr))
case invalidLabel =>
throw new QuintIRParseError(s"Invalid label found in matchVariant case ${invalidLabel}")
}
variantTagCondition(label) -> appliedElim
}
tla.caseSplit(casesInstructions: _*)
}
}

Expand Down Expand Up @@ -513,6 +560,10 @@ class Quint(quintOutput: QuintOutput) {
case "fieldNames" => unaryApp(opName, tla.dom)
case "with" => ternaryApp(opName, tla.except)

// Sum types
case "variant" => MkTla.variant(types(id).typ)
case "matchVariant" => MkTla.matchVariant

// Maps (functions)
// Map is variadic on n tuples, so build a set of these tuple args
// before converting the resulting set of tuples to a function.
Expand Down Expand Up @@ -555,7 +606,12 @@ class Quint(quintOutput: QuintOutput) {

// Otherwise, the applied operator is defined, and not a builtin
case definedOpName => { args =>
val operType = typeConv.convert(getTypeFromLookupTable(id))
val quintType = getTypeFromLookupTable(id).recoverWith { case err: QuintIRParseError =>
Failure(new QuintIRParseError(
s"While converting operator application of defined operator '${definedOpName}' to arguments ${args}: ${err
.getMessage()}"))
}.get
val operType = typeConv.convert(quintType)
val oper = tla.name(definedOpName, operType)
args.toList.traverse(tlaExpression).map(tlaArgs => tla.appOp(oper, tlaArgs: _*))
}
Expand Down Expand Up @@ -639,10 +695,10 @@ class Quint(quintOutput: QuintOutput) {
case app: QuintApp => tlaApplication(app)
}

// `tlaDef(quintDef)` is a NullaryOpReader that can be run to obtain a value `Some((tlaDecl, maybeName))`
// `tlaDef(quintDef)` is a NullaryOpReader that can be run to obtain a value `Some((maybeName, tlaDecl))`
// where `tlaDecl` is derived from the given `quintDef`, and `maybeName` is `Some(n)` when the `quintDef`
// defines a nullary operator named `n`, or `None` when `quintDef` is not a nullary operator definition.
// If the `quintDef` is not convertable (e.g., a quint type definition), it the outer value is `None`.
// If the `quintDef` is not convertable (e.g., a quint type definition), then the outer value is `None`.
def tlaDef(quintDef: QuintDef): NullaryOpReader[Option[(Option[String], TlaDecl)]] = {
import QuintDef._
Reader(nullaryOps =>
Expand Down
17 changes: 9 additions & 8 deletions tla-io/src/main/scala/at/forsyte/apalache/io/quint/QuintIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ private[quint] object QuintType {
QuintOperT.rw,
QuintTupleT.rw,
QuintRecordT.rw,
QuintUnionT.rw,
QuintSumT.rw,
)

// NOTE: Contrary to quint values, for quint types, source IDs are optional.
Expand Down Expand Up @@ -500,13 +500,14 @@ private[quint] object QuintType {
}
}

case class UnionRecord(tagValue: String, fields: Row)
object UnionRecord {
implicit val rw: RW[UnionRecord] = macroRW
}
@key("sum") case class QuintSumT(fields: Row) extends QuintType
object QuintSumT {
implicit val rw: RW[QuintSumT] = macroRW

@key("union") case class QuintUnionT(tag: String, records: Seq[UnionRecord]) extends QuintType
object QuintUnionT {
implicit val rw: RW[QuintUnionT] = macroRW
// Helper for manually constructing record type
def ofVariantTypes(variantTypes: (String, QuintType)*): QuintSumT = {
val fields = variantTypes.map { case (f, t) => RecordField(f, t) }
QuintSumT(Row.Cell(fields, Row.Nil()))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,61 +78,19 @@ private class QuintTypeConverter extends LazyLogging {
}
}

// Convert a quint union to a TlaType1 row (which is used to represent variants)
//
// NOTE: Union types in quint aren't fully implemented and supported, so this
// corner of the transformation is likely to require update soon.
// See https://github.com/informalsystems/quint/issues/244
//
// In quint, unions are currently represented by a list of tagged rows.
// E.g., (abstracting rom the concrete type representation):
//
// ```
// type u =
// | ( "Foo", {a: Int, b: String })
// | ( "Bar", {c: Set[Int] })
// ```
//
// But Variant types in Apalache are represented by a single row, in which
// the row's keys are the tags, and it's values can be of any type, e.g.:
//
// ```
// type u = { "Foo": { a: Int, b: Str }
// , "Bar": Set[Int]
// }
// ```
//
// Which we parse and represent as
//
// ```
// @typeAlias: u = Foo({a: Int, b: Str}) | Bar(Int);
// ```
//
// As a result, our conversion from quint has to take a list of records of quint
// rows and convert them into a single TlaType1 record, for which all the values
// are themselves records, and the keys are given by the values of the `tag`
// field from quint rows.
private def unionToRowT1(variants: Seq[UnionRecord]): RowT1 = {
val fieldTypes = variants.map {
case UnionRecord(tag, row) => {
(tag, RecRowT1(rowToRowT1(row)))
}
}
RowT1(fieldTypes: _*)
}

val convert: QuintType => TlaType1 = {
case QuintBoolT() => BoolT1
case QuintIntT() => IntT1
case QuintStrT() => StrT1
case QuintConstT(name) => ConstT1(name)
case QuintVarT(name) => VarT1(getVarNo(name))
case QuintSetT(elem) => SetT1(convert(elem))
case QuintSeqT(elem) => SeqT1(convert(elem))
case QuintFunT(arg, res) => FunT1(convert(arg), convert(res))
case QuintOperT(args, res) => OperT1(args.map(convert), convert(res))
case QuintTupleT(row) => rowToTupleT1(row)
case QuintRecordT(row) => RecRowT1(rowToRowT1(row))
case QuintUnionT(_, variants) => VariantT1(unionToRowT1(variants))
case QuintBoolT() => BoolT1
case QuintIntT() => IntT1
case QuintStrT() => StrT1
case QuintConstT(name) =>
ConstT1(name) // TODO: Raise error if we hit this. See https://github.com/informalsystems/apalache/issues/2788
case QuintVarT(name) => VarT1(getVarNo(name))
case QuintSetT(elem) => SetT1(convert(elem))
case QuintSeqT(elem) => SeqT1(convert(elem))
case QuintFunT(arg, res) => FunT1(convert(arg), convert(res))
case QuintOperT(args, res) => OperT1(args.map(convert), convert(res))
case QuintTupleT(row) => rowToTupleT1(row)
case QuintRecordT(row) => RecRowT1(rowToRowT1(row))
case QuintSumT(row) => VariantT1(rowToRowT1(row))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ class TestQuintEx extends AnyFunSuite {
val _42 = e(QuintInt(uid, 42), QuintIntT())
val s = e(QuintStr(uid, "s"), QuintStrT())
val t = e(QuintStr(uid, "t"), QuintStrT())
val labelF1 = e(QuintStr(uid, "F1"), QuintStrT())
val labelF2 = e(QuintStr(uid, "F2"), QuintStrT())

// Names and parameters
val name = e(QuintName(uid, "n"), QuintIntT())
Expand Down Expand Up @@ -518,18 +520,16 @@ class TestQuintEx extends AnyFunSuite {
QuintSeqT(QuintIntT()))) == "Apalache!MkSeq(42 - 3, LET __QUINT_LAMBDA0(__quint_var0) ≜ (3 + __quint_var0) - 1 IN __QUINT_LAMBDA0)")
}

/// RECORDS

test("can convert builtin Rec operator application") {
val typ = QuintRecordT.ofFieldTypes(("s", QuintIntT()), ("t", QuintIntT()))
assert(convert(Q.app("Rec", Q.s, Q._1, Q.t, Q._2)(typ)) == """["s" ↦ 1, "t" ↦ 2]""")
}

test("converting builtin Rec operator constructing empty record fails") {
val exn = intercept[QuintUnsupportedError] {
val typ = QuintRecordT.ofFieldTypes()
convert(Q.app("Rec")(typ))
}
assert(exn.getMessage.contains(
"Unsupported quint input: Given empty record, but Apalache doesn't support empty records."))
test("can convert builtin Rec operator constructing an empty record -- the unit type") {
val typ = QuintRecordT.ofFieldTypes()
assert(convert(Q.app("Rec")(typ)) == "[]")
}

test("can convert row-polymorphic record") {
Expand Down Expand Up @@ -584,6 +584,8 @@ class TestQuintEx extends AnyFunSuite {
assert(tlaOpDef.typeTag == Typed(expectedTlaType))
}

/// TUPLES

test("can convert builtin Tup operator application") {
assert(convert(Q.app("Tup", Q._0, Q._1)(QuintTupleT.ofTypes(QuintIntT(), QuintIntT()))) == "<<0, 1>>")
}
Expand All @@ -601,6 +603,31 @@ class TestQuintEx extends AnyFunSuite {
assert(convert(Q.app("tuples", Q.intSet, Q.intSet, Q.intSet)(typ)) == "{1, 2, 3} × {1, 2, 3} × {1, 2, 3}")
}

/// SUM TYPES

test("can convert builtin variant operator application") {
val typ = QuintSumT.ofVariantTypes("F1" -> QuintIntT(), "F2" -> QuintIntT())
assert(convert(Q.app("variant", Q.labelF1, Q._42)(typ)) == """Variants!Variant("F1", 42)""")
}

test("can convert builtin matchVariant operator application") {
val typ = QuintSumT.ofVariantTypes("F1" -> QuintIntT(), "F2" -> QuintRecordT.ofFieldTypes())
val variant = Q.app("variant", Q.labelF1, Q._42)(typ)
val quintMatch = Q.app(
"matchVariant",
variant,
Q.labelF1,
Q.lam(Seq("x" -> QuintIntT()), Q._1, QuintIntT()),
Q.labelF2,
Q.lam(Seq("y" -> QuintRecordT.ofFieldTypes()), Q._2, QuintIntT()),
)(typ)
val expected =
"""|CASE (Variants!VariantTag(Variants!Variant("F1", 42)) = "F1") → LET __QUINT_LAMBDA0(x) ≜ 1 IN __QUINT_LAMBDA0(Variants!VariantGetUnsafe("F1", Variants!Variant("F1", 42)))
|☐ (Variants!VariantTag(Variants!Variant("F1", 42)) = "F2") → LET __QUINT_LAMBDA1(y) ≜ 2 IN __QUINT_LAMBDA1(Variants!VariantGetUnsafe("F2", Variants!Variant("F1", 42)))""".stripMargin
.replace('\n', ' ')
assert(convert(quintMatch) == expected)
}

test("can convert builtin assert operator") {
assert(convert(Q.app("assert", Q.nIsGreaterThan0)(QuintBoolT())) == "n > 0")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,21 +54,15 @@ class TestQuintTypes extends AnyFunSuite {
assert(translate(record) == RecRowT1(RowT1(("f1" -> IntT1), ("f2" -> StrT1))))
}

test("Quint unions are converted into TlaType1 variants") {
val opt1 =
// i.e.: {tag: "t1", f1: int}
UnionRecord("t1", Row.Cell(List(RecordField("f1", QuintIntT())), Row.Nil()))
val opt2 =
// i.e.: {tag: "t2", f2: string}
UnionRecord("t2", Row.Cell(List(RecordField("f2", QuintStrT())), Row.Nil()))
val variant =
// i.e.: | {tag: "t1", f1: int} | {tag: "t2", f2: string}
QuintUnionT("tag", List(opt1, opt2))
test("Quint sum types are converted into TlaType1 variants") {
val quintSumType =
// i.e.: F1(int) | F2(str)
QuintSumT(Row.Cell(List(RecordField("F1", QuintIntT()), RecordField("F2", QuintStrT())), Row.Nil()))

val expectedVariant =
// i.e.: t1({ f1: Int }) | t2({ f2: Str })
VariantT1(RowT1("t1" -> RecRowT1(RowT1(("f1" -> IntT1))), "t2" -> RecRowT1(RowT1(("f2" -> StrT1)))))
assert(translate(variant) == expectedVariant)
// i.e.: F1(Int) | F1(Str)
VariantT1(RowT1("F1" -> IntT1, "F2" -> StrT1))
assert(translate(quintSumType) == expectedVariant)
}

// tictactoe.json is located in tla-io/src/test/resources/tictactoe.json
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ object TlaFunOper {
* is why we call it "REC_CTOR".</p>
*/
object rec extends TlaFunOper {
override def arity: OperArity = AnyEvenArity() && MinimalArity(2)
override def arity: OperArity = AnyEvenArity()

override val name: String = "RECORD"
override val precedence: (Int, Int) = (16, 16) // as the function application
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class UnsafeFunBuilder extends ProtoBuilder {
private def mkTlaStr: String => TlaEx = strBuilder.str

private def formRecordFieldTypes(args: Seq[TlaEx]): SortedMap[String, TlaType1] = {
require(TlaFunOper.rec.arity.cond(args.size), s"Expected record args to have even, positive arity, found $args.")
require(TlaFunOper.rec.arity.cond(args.size), s"Expected record args to have even arity, found $args.")
// All keys must be ValEx(TlaStr(_))
val (keys, _) = TlaOper.deinterleave(args)
require(keys.forall {
Expand Down
Loading

0 comments on commit cb437eb

Please sign in to comment.