diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/aux/ValueGenerator.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/aux/ValueGenerator.scala index 327d495e16..23784c3d16 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/aux/ValueGenerator.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/aux/ValueGenerator.scala @@ -151,19 +151,25 @@ class ValueGenerator(rewriter: SymbStateRewriter, bound: Int) { } private def genSet(state: SymbState, elemType: TlaType1): SymbState = { - var nextState = state - var elems: List[ArenaCell] = Nil - for (_ <- 1 to bound) { - nextState = gen(nextState, elemType) - elems = nextState.asCell :: elems - } val setType = CellT.fromType1(SetT1(elemType)) - nextState = nextState.updateArena(a => a.appendCellOld(setType)) - val setCell = nextState.arena.topCell - nextState = nextState.updateArena(a => a.appendHas(setCell, elems.map { FixedElemPtr }: _*)) + val stateWithSetCell = state.updateArena(a => a.appendCellOld(setType)) + val setCell = stateWithSetCell.arena.topCell + + val (stateWithGenElems, elemPtrs) = + 1.to(bound).foldLeft((stateWithSetCell, List.empty[ElemPtr])) { case ((s, ptrs), _) => + val nextState = gen(s, elemType) + val stateAsCell = nextState.asCell + // For Gen, not all elements necessarily belong to the set, so we should not use FixedElemPtr + // instead, the pointers should have unconstrained SMT constants as conditions. + // We can just use the edge predicate directly for those. + val ptr = SmtExprElemPtr(stateAsCell, tla.selectInSet(stateAsCell.toBuilder, setCell.toBuilder)) + (nextState, ptr :: ptrs) + } + + var nextState = stateWithGenElems.updateArena(a => a.appendHas(setCell, elemPtrs: _*)) // In the arrays encoding, set membership constraints are not generated in appendHas, so we add them below if (rewriter.solverContext.config.smtEncoding == SMTEncoding.Arrays) { - for (elem <- elems) { + for (elem <- elemPtrs.map(_.elem)) { nextState = nextState.updateArena(_.appendCell(BoolT1)) val pred = nextState.arena.topCell.toNameEx // TODO: when #1916 is closed, remove tlaLegacy and use tla directly @@ -174,7 +180,7 @@ class ValueGenerator(rewriter: SymbStateRewriter, bound: Int) { rewriter.solverContext.assertGroundExpr(ite) } } - nextState.setRex(setCell.toNameEx.withTag(Typed(SetT1(elemType)))) + nextState.setRex(setCell.toNameEx) } private def genSeq(state: SymbState, elemType: TlaType1): SymbState = { diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateDecoder.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateDecoder.scala index d2ea3536be..b3de34187d 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateDecoder.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/TestSymbStateDecoder.scala @@ -2,9 +2,11 @@ package at.forsyte.apalache.tla.bmcmt import at.forsyte.apalache.infra.passes.options.SMTEncoding import at.forsyte.apalache.tla.lir._ +import at.forsyte.apalache.tla.lir.oper.TlaSetOper import at.forsyte.apalache.tla.typecomp.TBuilderInstruction -import at.forsyte.apalache.tla.types.tla._ import at.forsyte.apalache.tla.types.parser.DefaultType1Parser +import at.forsyte.apalache.tla.types.tla +import at.forsyte.apalache.tla.types.tla._ trait TestSymbStateDecoder extends RewriterBase { private val parser = DefaultType1Parser @@ -267,4 +269,62 @@ trait TestSymbStateDecoder extends RewriterBase { val decodedEx = decoder.decodeCellToTlaEx(nextState.arena, cell) assertBuildEqual(vrt1, decodedEx) } + + test("decode gen: Regression #1 for #2702") { rewriterType: SMTEncoding => + val valName = tla.name("gen", SetT1(IntT1)) + val genEx = tla.gen(1, SetT1(IntT1)) + val x = tla.name("x", IntT1) + val cond = tla.forall(x, valName, tla.eql(x, tla.int(0))) + + val ex = tla.and(tla.eql(valName, genEx), cond) + + val arenaWithGenCell = arena.appendCell(SetT1(IntT1)) + val genCell = arenaWithGenCell.topCell + + val rewriter = create(rewriterType) + val state = new SymbState(ex, arenaWithGenCell, Binding("gen" -> genCell)) + val rewrittenState = rewriter.rewriteUntilDone(state) + assert(solverContext.sat()) + + val decoder = new SymbStateDecoder(solverContext, rewriter) + + val decodedVal: TlaEx = decoder.decodeCellToTlaEx(rewrittenState.arena, genCell) + + assert( + decodedVal match { + case OperEx(TlaSetOper.enumSet, args @ _*) => + args.forall { _ == tla.int(0).build } + case _ => false + } + ) + } + + test("decode gen: Regression #2 for #2702") { rewriterType: SMTEncoding => + val valName = tla.name("gen", SetT1(IntT1)) + val genEx = tla.gen(1, SetT1(IntT1)) + val x = tla.name("x", IntT1) + val cond = tla.forall(x, valName, tla.neql(x, tla.int(42))) + + val ex = tla.and(tla.eql(valName, genEx), cond) + + val arenaWithGenCell = arena.appendCell(SetT1(IntT1)) + val genCell = arenaWithGenCell.topCell + + val rewriter = create(rewriterType) + val state = new SymbState(ex, arenaWithGenCell, Binding("gen" -> genCell)) + val rewrittenState = rewriter.rewriteUntilDone(state) + assert(solverContext.sat()) + + val decoder = new SymbStateDecoder(solverContext, rewriter) + + val decodedVal: TlaEx = decoder.decodeCellToTlaEx(rewrittenState.arena, genCell) + + assert( + decodedVal match { + case OperEx(TlaSetOper.enumSet, args @ _*) => !args.contains(tla.int(42).build) + case _ => false + } + ) + } + }