-
-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add cache + tests * Apply suggestions from code review Co-authored-by: Shon Feder <[email protected]> * PR comments * Pr comments * test fix * Update tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/caches/UninterpretedLiteralCache.scala Co-authored-by: Thomas Pani <[email protected]> --------- Co-authored-by: Shon Feder <[email protected]> Co-authored-by: Thomas Pani <[email protected]>
- Loading branch information
1 parent
b374361
commit d9a9e90
Showing
2 changed files
with
227 additions
and
0 deletions.
There are no files selected for viewing
80 changes: 80 additions & 0 deletions
80
.../at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/caches/UninterpretedLiteralCache.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
package at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux.caches | ||
|
||
import at.forsyte.apalache.tla.bmcmt.smt.SolverContext | ||
import at.forsyte.apalache.tla.bmcmt.types.{CellT, CellTFrom} | ||
import at.forsyte.apalache.tla.bmcmt.{ArenaCell, PureArena} | ||
import at.forsyte.apalache.tla.lir.{ConstT1, StrT1, TlaType1} | ||
import at.forsyte.apalache.tla.types.tla | ||
|
||
/** | ||
* A cache for uninterpreted literals, that are translated to uninterpreted SMT constants, with a unique sort per | ||
* uninterpreted type. Since two values are equal iff they are literally the same literal, we force inequality between | ||
* all the respective SMT constants. | ||
* | ||
* Note that Strings are just a special kind of uninterpreted type. | ||
* | ||
* @author | ||
* Jure Kukovec | ||
*/ | ||
class UninterpretedLiteralCache extends Cache[PureArena, (TlaType1, String), ArenaCell] { | ||
|
||
/** | ||
* Given a pair `(utype,idx)`, where `utype` represents an uninterpreted type name (possibly "Str") and `idx` some | ||
* unique index within that type, returns an extension of `arena`, containing a cell, which represents "idx_OF_utype" | ||
* (or "idx", if utype = "Str"), and said cell. | ||
* | ||
* Note that two values are equal (and get cached to the same cell) iff they have the same type and the same index, so | ||
* e.g. "1_OF_A" and "1_OF_B" (passed here as ("A", "1") and ("B", "1")) get cached to different, incomparable cells, | ||
* despite having the same index "1". | ||
*/ | ||
protected def create( | ||
arena: PureArena, | ||
typeAndIndex: (TlaType1, String)): (PureArena, ArenaCell) = { | ||
val (utype, _) = typeAndIndex | ||
require(utype == StrT1 || utype.isInstanceOf[ConstT1], "Type must be Str, or an uninterpreted type.") | ||
// introduce a new cell | ||
val newArena = arena.appendCell(CellT.fromType1(utype)) | ||
(newArena, newArena.topCell) | ||
} | ||
|
||
/** | ||
* The UninterpretedLiteralCache maintains that a cell cache for a value `idx` of type `tp` is distinct from all other | ||
* values of type `tp` (defined so far). | ||
* | ||
* Whenever possible, try to use [[addAllConstraints]] instead of this method, for performance reasons instead: | ||
* | ||
* If we consider a naive implementation of `distinct(a1,..., an)` as `a1 != a2 /\ a1 != a3 /\ ... /\ a{n-1} != an`, a | ||
* `distinct` with `n` elements is equivalent to `dn = n(n-1)/2` disequalities. Suppose we end up with a collection of | ||
* `N` cache values (of a given type). If we called `addConstaintsForElem` after each addition, we'd end up with `d1 + | ||
* d2 + ... + dN` disequalities, i.e. {{{\sum_{n=1}^N n(n-1)/2 = N(N^2 -1)/6}}} In contrast, `addAllConstraints` | ||
* produces `dN = N(N-1)/2` disequalities, which is `O(N^2)`, instead of `O(N^3)`. | ||
*/ | ||
override def addConstraintsForElem(ctx: SolverContext): (((TlaType1, String), ArenaCell)) => Unit = { | ||
case ((utype, _), v) => | ||
require(utype == StrT1 || utype.isInstanceOf[ConstT1], "Type must be Str, or an uninterpreted type.") | ||
val others = values().withFilter { c => c.cellType == CellTFrom(utype) && c != v }.map(_.toBuilder).toSeq | ||
// The cell should differ from the previously created cells. | ||
// We use the SMT constraint (distinct ...). | ||
ctx.assertGroundExpr(tla.distinct(v.toBuilder +: others: _*)) | ||
} | ||
|
||
/** | ||
* A more efficient implementation, compared to the default one, as it introduces exactly one SMT `distinct` for each | ||
* uninterpreted type instead of one `distinct` per cell. | ||
*/ | ||
override def addAllConstraints(ctx: SolverContext): Unit = { | ||
val utypes = cache.keySet.map { _._1 } | ||
|
||
val initMap = utypes.map { _ -> Set.empty[ArenaCell] }.toMap | ||
|
||
val cellsByUtype = cache.foldLeft(initMap) { case (map, ((utype, _), (cell, _))) => | ||
map + (utype -> (map(utype) + cell)) | ||
} | ||
|
||
// For each utype, all cells of that type are distinct | ||
cellsByUtype.foreach { case (_, cells) => | ||
ctx.assertGroundExpr(tla.distinct(cells.toSeq.map { _.toBuilder }: _*)) | ||
} | ||
|
||
} | ||
} |
147 changes: 147 additions & 0 deletions
147
...ala/at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/UninterpretedLiteralCacheTest.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
package at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux | ||
|
||
import at.forsyte.apalache.tla.bmcmt.PureArena | ||
import at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux.caches.UninterpretedLiteralCache | ||
import at.forsyte.apalache.tla.lir.{StrT1, TlaType1} | ||
import at.forsyte.apalache.tla.types.{tla, ModelValueHandler} | ||
import org.junit.runner.RunWith | ||
import org.scalatest.BeforeAndAfterEach | ||
import org.scalatest.funsuite.AnyFunSuite | ||
import org.scalatestplus.junit.JUnitRunner | ||
|
||
@RunWith(classOf[JUnitRunner]) | ||
class UninterpretedLiteralCacheTest extends AnyFunSuite with BeforeAndAfterEach { | ||
|
||
var cache: UninterpretedLiteralCache = new UninterpretedLiteralCache | ||
|
||
def tpAndIdx(s: String): (TlaType1, String) = { | ||
val (utype, idx) = ModelValueHandler.typeAndIndex(s).getOrElse((StrT1, s)) | ||
(utype, idx) | ||
} | ||
|
||
override def beforeEach(): Unit = { | ||
cache = new UninterpretedLiteralCache | ||
} | ||
|
||
test("Cache returns stored values after the first call to getOrCreate") { | ||
val str: String = "idx" | ||
|
||
val utypeAndIdx = tpAndIdx(str) | ||
|
||
val arena = PureArena.empty | ||
|
||
// No cached value for the pair | ||
assert(cache.get(utypeAndIdx).isEmpty) | ||
|
||
val (newArena, iCell) = cache.getOrCreate(arena, utypeAndIdx) | ||
|
||
// pair now cached, arena has changed | ||
assert(cache.get(utypeAndIdx).nonEmpty && newArena != arena) | ||
|
||
val (newArena2, iCell2) = cache.getOrCreate(newArena, utypeAndIdx) | ||
|
||
// 2nd call returns the _same_ arena and the previously computed cell | ||
assert(newArena == newArena2 && iCell == iCell2) | ||
} | ||
|
||
test("Same index of different types is cached separately") { | ||
val str1: String = "idx" | ||
val str2: String = "idx_OF_A" | ||
val str3: String = "idx_OF_B" | ||
|
||
val pa1 = tpAndIdx(str1) | ||
val pa2 = tpAndIdx(str2) | ||
val pa3 = tpAndIdx(str3) | ||
|
||
val arena = PureArena.empty | ||
|
||
val (newArena1, cell1) = cache.getOrCreate(arena, pa1) | ||
|
||
assert(arena != newArena1) | ||
|
||
val (newArena2, cell2) = cache.getOrCreate(newArena1, pa2) | ||
|
||
assert(newArena2 != newArena1 && cell2 != cell1) | ||
|
||
val (newArena3, cell3) = cache.getOrCreate(newArena2, pa3) | ||
|
||
assert(newArena3 != newArena2 && cell3 != cell2) | ||
} | ||
|
||
test("Constraints are only added when addAllConstraints is explicitly called, and only once per value") { | ||
val mockCtx: MockZ3SolverContext = new MockZ3SolverContext | ||
|
||
val str1: String = "1_OF_A" | ||
val str2: String = "2_OF_A" | ||
val str3: String = "3_OF_A" | ||
|
||
val pa1 = tpAndIdx(str1) | ||
val pa2 = tpAndIdx(str2) | ||
val pa3 = tpAndIdx(str3) | ||
|
||
val a0 = PureArena.empty | ||
val (a1, c1) = cache.getOrCreate(a0, pa1) | ||
// Some extra calls, which shouldn't affect constraint generation | ||
cache.getOrCreate(a0, pa1) | ||
cache.getOrCreate(a0, pa1) | ||
val (a2, c2) = cache.getOrCreate(a1, pa2) | ||
// Some extra calls, which shouldn't affect constraint generation | ||
cache.getOrCreate(a1, pa2) | ||
cache.getOrCreate(a1, pa2) | ||
val (_, c3) = cache.getOrCreate(a2, pa3) | ||
// Some extra calls, which shouldn't affect constraint generation | ||
cache.getOrCreate(a2, pa3) | ||
cache.getOrCreate(a2, pa3) | ||
|
||
assert(mockCtx.constraints.isEmpty) | ||
|
||
cache.addAllConstraints(mockCtx) | ||
|
||
// Due to the optimized `addAllConstraints` override, we only have 1 "distinct" | ||
assert(mockCtx.constraints == Seq( | ||
tla.distinct(c3.toBuilder, c2.toBuilder, c1.toBuilder).build | ||
)) | ||
} | ||
|
||
test("Constraints are only added when addConstraintsForElem is explicitly called, and only once per value") { | ||
val mockCtx: MockZ3SolverContext = new MockZ3SolverContext | ||
|
||
val str1: String = "1_OF_A" | ||
val str2: String = "2_OF_A" | ||
val str3: String = "3_OF_A" | ||
|
||
val pa1 = tpAndIdx(str1) | ||
val pa2 = tpAndIdx(str2) | ||
val pa3 = tpAndIdx(str3) | ||
|
||
val a0 = PureArena.empty | ||
val (a1, c1) = cache.getOrCreate(a0, pa1) | ||
// Some extra calls, which shouldn't affect constraint generation | ||
cache.getOrCreate(a0, pa1) | ||
cache.getOrCreate(a0, pa1) | ||
|
||
cache.addConstraintsForElem(mockCtx)(pa1, c1) | ||
|
||
val (a2, c2) = cache.getOrCreate(a1, pa2) | ||
// Some extra calls, which shouldn't affect constraint generation | ||
cache.getOrCreate(a1, pa2) | ||
cache.getOrCreate(a1, pa2) | ||
|
||
cache.addConstraintsForElem(mockCtx)(pa2, c2) | ||
|
||
val (_, c3) = cache.getOrCreate(a2, pa3) | ||
// Some extra calls, which shouldn't affect constraint generation | ||
cache.getOrCreate(a2, pa3) | ||
cache.getOrCreate(a2, pa3) | ||
|
||
cache.addConstraintsForElem(mockCtx)(pa3, c3) | ||
|
||
// -ForElem creates 3 "distinct" constraints | ||
assert(mockCtx.constraints == Seq( | ||
tla.distinct(c1.toBuilder).build, | ||
tla.distinct(c2.toBuilder, c1.toBuilder).build, | ||
tla.distinct(c3.toBuilder, c1.toBuilder, c2.toBuilder).build, | ||
)) | ||
} | ||
|
||
} |