diff --git a/Makefile b/Makefile index 1b3fa87e7f..af1d9da73a 100644 --- a/Makefile +++ b/Makefile @@ -140,8 +140,11 @@ test-trace: # Note: this is *not* using the Quint models to test the system, # this tests/verifies the Quint models *themselves*. verify-models: - cd tests/mbt/model;\ - ../run_invariants.sh + quint test tests/mbt/model/ccv_test.qnt;\ + quint test tests/mbt/model/ccv_model.qnt;\ + quint test tests/mbt/model/ccv_pss_test.qnt;\ + quint run --invariant "all{ValidatorUpdatesArePropagatedInv,ValidatorSetHasExistedInv,SameVscPacketsInv,MatureOnTimeInv,EventuallyMatureOnProviderInv}" tests/mbt/model/ccv_model.qnt --max-steps 200 --max-samples 200;\ + quint run --invariant "all{ValidatorUpdatesArePropagatedKeyAssignmentInv,ValidatorSetHasExistedKeyAssignmentInv,SameVscPacketsKeyAssignmentInv,MatureOnTimeInv,EventuallyMatureOnProviderInv,KeyAssignmentRulesInv}" tests/mbt/model/ccv_model.qnt --step stepKeyAssignment --max-steps 200 --max-samples 200 diff --git a/tests/mbt/model/README.md b/tests/mbt/model/README.md index 57f5b2eda9..df994b7f84 100644 --- a/tests/mbt/model/README.md +++ b/tests/mbt/model/README.md @@ -51,6 +51,12 @@ To run with key assignment, specify the step flag: `--step stepKeyAssignment`. KeyAssignment also needs some different invariants, see below. +#### Partial Set Security + +To run with Partial Set Security, specify the step flag `--step stepBoundedDriftKeyAndPSS`. +This runs both PSS and Key Assignment. +It also requires running with `ccv_boundeddrift.qnt`, see below. + #### ccv_boundeddrift.qnt This state machine layer is more restricted to generate more interesting traces: * It never allows consumer chains to drift more than `MaxDrift` time apart from each other. @@ -75,10 +81,7 @@ traces are not very useful for testing. To run unit tests, run ``` -quint test ccv_test.qnt -``` -and -``` +quint test ccv_test.qnt; quint test ccv_model.qnt ``` @@ -131,4 +134,8 @@ The available sanity checks are: - CanSendVscPackets - CanSendVscMaturedPackets - CanAssignConsumerKey (only with `--step stepKeyAssignment`) -- CanHaveConsumerAddresses (only with `--step stepKeyAssignment`) \ No newline at end of file +- CanHaveConsumerAddresses (only with `--step stepKeyAssignment`) +- CanOptIn (only with `--step stepBoundedDriftKeyAndPSS` on `ccv_boundeddrift.qnt`) +- CanOptOut (only with `--step stepBoundedDriftKeyAndPSS` on `ccv_boundeddrift.qnt`) +- CanFailOptOut (only with `--step stepBoundedDriftKeyAndPSS` on `ccv_boundeddrift.qnt`) +- CanHaveOptIn (only with `--step stepBoundedDriftKeyAndPSS` on `ccv_boundeddrift.qnt`) diff --git a/tests/mbt/model/ccv.qnt b/tests/mbt/model/ccv.qnt index 2a81fe3a2d..68d08e88a3 100644 --- a/tests/mbt/model/ccv.qnt +++ b/tests/mbt/model/ccv.qnt @@ -89,6 +89,12 @@ module ccv_types { // Stores VscPackets which have been sent but where the provider has *not received a response yet*. sentVscPacketsToConsumer: Chain -> List[VscPacket], + // stores whether, in this block, the validator set has changed. + // this is needed because the validator set might be considered to have changed, even though + // it is still technically identical at our level of abstraction, e.g. a validator power change on the provider + // might leave the validator set the same because a delegation and undelegation cancel each other out. + providerValidatorSetChangedInThisBlock: bool, + // stores for which consumer chains, in this epoch, the validator set is considered to have changed and we thus need to send a VscPacket to the consumer chains. consumersWithPowerChangesInThisEpoch: Set[Chain], @@ -120,7 +126,18 @@ module ccv_types { consumerAddrsToPrune: Chain -> VscId -> List[ConsumerAddr], // For every sent VSCPacket, stores the key assignments that were applied to send it. - keyAssignmentsForVSCPackets: VscId -> (Chain -> (Node -> ConsumerAddr)) + keyAssignmentsForVSCPackets: VscId -> (Chain -> (Node -> ConsumerAddr)), + + // For each consumer chain, + // stores the set of validators that are opted into running the chain. + optedInVals: Chain -> Set[Node], + + // for each consumer, stores the top N for that consumer. + // The top N% of the validator set by voting power + // is obliged to run a topN chain. + // If the chain is a pure opt-in chain (where noone is forced to run it), + // this is 0. + topNByConsumer: Chain -> int, } // utility function: returns a provider state that is initialized minimally. @@ -130,6 +147,7 @@ module ccv_types { outstandingPacketsToConsumer: Map(), receivedMaturations: Set(), sentVscPacketsToConsumer: Map(), + providerValidatorSetChangedInThisBlock: false, consumersWithPowerChangesInThisEpoch: Set(), consumerStatus: Map(), runningVscId: 0, @@ -138,7 +156,9 @@ module ccv_types { consumerAddrToValidator: Map(), consumerAddrsToPrune: Map(), keyAssignmentsForVSCPackets: Map(), - consumersWithAddrAssignmentChangesInThisEpoch: Set() + consumersWithAddrAssignmentChangesInThisEpoch: Set(), + optedInVals: Map(), + topNByConsumer: Map(), } @@ -226,6 +246,33 @@ module ccv_types { // given as a pure val so that we can switch cases based on // whether a chain is the provider or not pure val PROVIDER_CHAIN = "provider" + + // A record that keeps the information needed to add a new consumer. + // In particular, holds: + // the chain name/identifier, + // and the top N factor for the chain. + type ConsumerAdditionMsg = { + chain: Chain, + topN: int + } + + // Creates a new ConsumerAdditionMsg with a given top N. + pure def NewTopNConsumer(chain: Chain, topN: int): ConsumerAdditionMsg = { + { + chain: chain, + topN: topN + } + } + + // Creates a new ConsumerAdditionMsg with topN = 0. + pure def NewOptInConsumer(chain: Chain): ConsumerAdditionMsg = { + NewTopNConsumer(chain, 0) + } + + // Creates a new ConsumerAdditionMsg with top N = 100%. + pure def NewFullConsumer(chain: Chain): ConsumerAdditionMsg = { + NewTopNConsumer(chain, 100) + } } module ccv { @@ -246,6 +293,7 @@ module ccv { import Time.* from "./libraries/Time" import extraSpells.* from "./libraries/extraSpells" import ccv_types.* + import ccv_pss.* from "./ccv_pss" import ccv_utils.* from "./ccv_utils" @@ -417,13 +465,14 @@ module ccv { // i.e. the timestamp for the next block is oldTimestamp + timeAdvancement timeAdvancement: Time, // a set of consumers that were not consumers before, but should be set to running now. - consumersToStart: Set[Chain], + consumersToStart: Set[ConsumerAdditionMsg], // a set of consumers that were running before, but should be set to stopped now. // This argument only needs to contain "voluntary" stops - // forced stops, e.g. because a consumer timed out, // will be added automatically. consumersToStop: Set[Chain]): Result = { val currentProviderState = currentState.providerState + val curValSet = currentProviderState.chainState.currentValidatorSet // check for vsc timeouts val timedOutConsumers = getRunningConsumers(currentProviderState).filter( @@ -431,27 +480,11 @@ module ccv { val res = TimeoutDueToVscTimeout(currentState, consumer, VscTimeout) res._1 ) - - // for each consumer chain, apply the key assignment to the current validator set - val currentValSets = ConsumerChains.mapBy( - (consumer) => - currentProviderState.applyKeyAssignmentToValSet( - consumer, - currentProviderState.chainState.currentValidatorSet - ) - ) - // store the current validator set with the key assignments applied in the history - val newKeyAssignedValSetHistory = currentValSets.keys().mapBy( - (consumer) => - currentProviderState.keyAssignedValSetHistory - .getOrElse(consumer, List()) // get the existing history (empty list if no history yet) - .prepend(currentValSets.get(consumer)) // prepend the current validator set with key assignments applied - ) // run the shared core chainState logic val newChainState = currentProviderState.chainState.endAndBeginBlockShared(timeAdvancement) val providerStateAfterTimeAdvancement = - {...currentProviderState, chainState: newChainState, keyAssignedValSetHistory: newKeyAssignedValSetHistory} + {...currentProviderState, chainState: newChainState} val tmpState = currentState.with( "providerState", providerStateAfterTimeAdvancement ) @@ -472,41 +505,62 @@ module ccv { // start/stop chains - val res = providerStateAfterSending.consumerStatus.StartStopConsumers( + val res = providerStateAfterSending.StartStopConsumers( consumersToStart, consumersToStop, timedOutConsumers ) - val newConsumerStatus = res._1 + val providerStateAfterConsumerAdvancement = res._1.with("providerValidatorSetChangedInThisBlock", false) val err = res._2 - val providerStateAfterConsumerAdvancement = providerStateAfterSending.with( - "consumerStatus", newConsumerStatus - ) + + val consumerAdditions = consumersToStart.map(consumer => consumer.chain) + + // for each running consumer chain, opt in validators that are in the top N + val providerStateAfterPSS = providerStateAfterConsumerAdvancement.endBlockPSS() if (err != "") { Err(err) } else { - // for each consumer we just set to running, set its initial validator set to be the current one on the provider... - val valSet = providerStateAfterConsumerAdvancement.chainState.currentValidatorSet + // for each consumer chain, apply the key assignment to the current validator set + val currentValSets = getRunningConsumers(providerStateAfterPSS).mapBy( + (consumer) => + providerStateAfterPSS.applyKeyAssignmentToValSet( + consumer, + // get the validator set after partial set security has been applied + GetPSSValidatorSet(providerStateAfterPSS, curValSet, consumer) + ) + ) + + // store the current validator set with the key assignments applied in the history + val newKeyAssignedValSetHistory = currentValSets.keys().mapBy( + (consumer) => + providerStateAfterPSS.keyAssignedValSetHistory + .getOrElse(consumer, List()) // get the existing history (empty list if no history yet) + .prepend(currentValSets.get(consumer)) // prepend the current validator set with key assignments applied + ) + + val providerStateAfterStoringValSets = providerStateAfterPSS.with( + "keyAssignedValSetHistory", newKeyAssignedValSetHistory + ) + val newConsumerStateMap = tmpState.consumerStates.keys().mapBy( (consumer) => - if (consumersToStart.contains(consumer)) { - // ...modified by the key assignments for the consumer - val consValSet = applyKeyAssignmentToValSet(providerStateAfterConsumerAdvancement, consumer, valSet) + if (consumerAdditions.contains(consumer)) { val currentConsumerState: ConsumerState = tmpState.consumerStates.get(consumer) + // correctly set the state for the new consumer val newConsumerState: ConsumerState = currentConsumerState.with( "chainState", currentConsumerState.chainState.with( - "currentValidatorSet", consValSet + "currentValidatorSet", currentValSets.get(consumer) ).with( "votingPowerHistory", - List(consValSet) + List(currentValSets.get(consumer)) ).with( "lastTimestamp", - providerStateAfterConsumerAdvancement.chainState.lastTimestamp + providerStateAfterStoringValSets.chainState.lastTimestamp ).with( "runningTimestamp", - providerStateAfterConsumerAdvancement.chainState.runningTimestamp + providerStateAfterStoringValSets.chainState.runningTimestamp ) ) newConsumerState @@ -515,7 +569,7 @@ module ccv { } ) val newState = tmpState.with( - "providerState", providerStateAfterConsumerAdvancement + "providerState", providerStateAfterStoringValSets ).with( "consumerStates", newConsumerStateMap ) diff --git a/tests/mbt/model/ccv_boundeddrift.qnt b/tests/mbt/model/ccv_boundeddrift.qnt index 24db1d3eaf..b067ca79b2 100644 --- a/tests/mbt/model/ccv_boundeddrift.qnt +++ b/tests/mbt/model/ccv_boundeddrift.qnt @@ -5,6 +5,7 @@ module ccv_boundeddrift { import ccv from "ccv" import Time.* from "./libraries/Time" import extraSpells.* from "./libraries/extraSpells" + import ccv_pss_model.* from "ccv_pss_model" // The boundeddrift module has its own step function. @@ -60,19 +61,22 @@ module ccv_boundeddrift { stepCommon, // allow actions that do not influence time // advance a block for a consumer - all { - runningConsumers.size() > 0, // ensure there is a running consumer, otherwise this action does not make sense - nondet chain = runningConsumers.oneOf() - val maxAdv = findMaxTimeAdvancement(GetChainState(chain), GetOtherChainStates(chain), maxDrift) - val possibleAdvancements = timeAdvancements.filter(t => t <= maxAdv) - all { - possibleAdvancements.size() > 0, // ensure there is a possible advancement, otherwise this action does not make sense - nondet timeAdvancement = possibleAdvancements.oneOf() - EndAndBeginBlockForConsumer(chain, timeAdvancement), - } - }, + stepBoundedDriftConsumer, // advance a block for the provider + stepBoundedDriftProvider + } + + action stepBoundedDriftProvider: bool = { + stepBoundedDriftProvider_helper(allFullConsumers) + } + + action stepBoundedDriftProviderPSS: bool = { + stepBoundedDriftProvider_helper(variousPossibleTopN) + } + + // As an argument, takes a function that, when invoked, gives a top N value to use for a new consumer chain. + action stepBoundedDriftProvider_helper(topNOracle: Set[int]): bool = { val maxAdv = findMaxTimeAdvancement(GetChainState(Ccvt::PROVIDER_CHAIN), GetOtherChainStates(Ccvt::PROVIDER_CHAIN), maxDrift) val possibleAdvancements = timeAdvancements.filter(t => t <= maxAdv) all { @@ -80,14 +84,37 @@ module ccv_boundeddrift { // advance a block for the provider val consumerStatus = currentState.providerState.consumerStatus nondet consumersToStart = oneOf(nonConsumers.powerset()) + nondet topN = oneOf(topNOracle) + nondet consumerAdditions = consumersToStart.map(c => Ccvt::NewTopNConsumer(c, topN)) // make it so we stop consumers only with small likelihood: nondet stopConsumersRand = oneOf(1.to(100)) nondet consumersToStop = if (stopConsumersRand <= consumerStopChance) oneOf(runningConsumers.powerset()) else Set() nondet timeAdvancement = oneOf(possibleAdvancements) - EndAndBeginBlockForProvider(timeAdvancement, consumersToStart, consumersToStop), + EndAndBeginBlockForProvider(timeAdvancement, consumerAdditions, consumersToStop), } } + action stepBoundedDriftConsumer = all { + runningConsumers.size() > 0, // ensure there is a running consumer, otherwise this action does not make sense + nondet chain = runningConsumers.oneOf() + val maxAdv = findMaxTimeAdvancement(GetChainState(chain), GetOtherChainStates(chain), maxDrift) + val possibleAdvancements = timeAdvancements.filter(t => t <= maxAdv) + all { + possibleAdvancements.size() > 0, // ensure there is a possible advancement, otherwise this action does not make sense + nondet timeAdvancement = possibleAdvancements.oneOf() + EndAndBeginBlockForConsumer(chain, timeAdvancement), + } + } + + action stepBoundedDriftKeyAndPSS = any { + stepCommon, + stepBoundedDriftProviderPSS, + stepBoundedDriftConsumer, + nondetKeyAssignment, + StepOptIn, + StepOptOut, + } + action stepBoundedDriftKeyAssignment = any { stepBoundedDrift, nondetKeyAssignment, diff --git a/tests/mbt/model/ccv_model.qnt b/tests/mbt/model/ccv_model.qnt index 1509ad439f..b3b487d113 100644 --- a/tests/mbt/model/ccv_model.qnt +++ b/tests/mbt/model/ccv_model.qnt @@ -49,7 +49,13 @@ module ccv_model { var currentState: ProtocolState - // a type storing the parameters used in actions. + // a type storing the parameters used in actions, + // as well as return values that are not visible from the state, + // i.e. errors. + // Note that whether an error is returned, + // or whether the action is simply not possible when an error occurs, is + // a design choice that is different for each action, + // or can depend on the type of error. // this is used in the trace to store // the name of the last action, plus the parameters we passed to it. // Note: This type holds ALL parameters that are used in ANY action, @@ -59,11 +65,12 @@ module ccv_model { kind: str, consumerChain: Chain, timeAdvancement: Time, - consumersToStart: Set[Chain], + consumersToStart: Set[ConsumerAdditionMsg], consumersToStop: Set[Chain], validator: Node, changeAmount: int, consumerAddr: ConsumerAddr, + expectedError: str, // if the action returns an error, it goes here. } @@ -88,6 +95,7 @@ module ccv_model { validator: "", changeAmount: 0, consumerAddr: "", + expectedError: "", } @@ -101,7 +109,7 @@ module ccv_model { val consumerStates = ConsumerChains.mapBy(chain => GetEmptyConsumerState) val providerStateWithConsumers = providerState.with( "consumerStatus", - ConsumerChains.mapBy(chain => NOT_CONSUMER) + ConsumerChains.mapBy(chain => NOT_CONSUMER) ).with( "outstandingPacketsToConsumer", ConsumerChains.mapBy(chain => List()) @@ -171,7 +179,7 @@ module ccv_model { action EndAndBeginBlockForProvider( timeAdvancement: Time, - consumersToStart: Set[Chain], + consumersToStart: Set[ConsumerAdditionMsg], consumersToStop: Set[Chain]): bool = val result = endAndBeginBlockForProvider(currentState, timeAdvancement, consumersToStart, consumersToStop) all { @@ -232,9 +240,10 @@ module ccv_model { val consumerStatus = currentState.providerState.consumerStatus nondet consumersToStart = oneOf(nonConsumers.powerset()) + val consumerAdditions = consumersToStart.map(chain => NewFullConsumer(chain)) nondet consumersToStop = oneOf(runningConsumers.powerset()) nondet timeAdvancement = oneOf(timeAdvancements) - EndAndBeginBlockForProvider(timeAdvancement, consumersToStart, consumersToStop), + EndAndBeginBlockForProvider(timeAdvancement, consumerAdditions, consumersToStop), stepCommon } @@ -262,17 +271,6 @@ module ccv_model { // UTILITY FUNCTIONS // ================== - pure def removeZeroPowers(valSet: ValidatorSet): ValidatorSet = - valSet.keys().fold( - Map(), - (acc, node) => - if (valSet.get(node) == 0) { - acc - } else { - acc.put(node, valSet.get(node)) - } - ) - pure def oldest(packets: Set[VscPacket]): VscPacket = val newestPossiblePacket: VscPacket = { id: 0, @@ -531,7 +529,7 @@ module ccv_model { // the validator set has changed assert(currentState.providerState.chainState.currentValidatorSet == InitialValidatorSet.put("node1", 150)), // start consumer1 - EndAndBeginBlockForProvider(1 * Second, Set("consumer1"), Set()) + EndAndBeginBlockForProvider(1 * Second, Set(NewFullConsumer("consumer1")), Set()) }) .then( all { @@ -609,7 +607,7 @@ module ccv_model { run SameVscPacketsManualTest = init.then( // start all consumers except for consumer3 - EndAndBeginBlockForProvider(1 * Second, Set("consumer1", "consumer2"), Set()) + EndAndBeginBlockForProvider(1 * Second, Set(NewFullConsumer("consumer1"), NewFullConsumer("consumer2")), Set()) ).then( // change voting power VotingPowerChange("node1", 50) @@ -624,7 +622,7 @@ module ccv_model { DeliverVscPacket("consumer2") ).then( // start consumer3 - EndAndBeginBlockForProvider(1 * Second, Set("consumer3"), Set()) + EndAndBeginBlockForProvider(1 * Second, Set(NewFullConsumer("consumer3")), Set()) ).then( // do another voting power change VotingPowerChange("node2", 50) @@ -655,7 +653,7 @@ module ccv_model { init .then( // start all consumer chains - EndAndBeginBlockForProvider(1 * Second, ConsumerChains, Set()) + EndAndBeginBlockForProvider(1 * Second, ConsumerChains.map(c => NewFullConsumer(c)), Set()) ) .then( // change voting power @@ -868,7 +866,7 @@ module ccv_model { init .then( // start all consumer chains - EndAndBeginBlockForProvider(1 * Second, consumerChains, Set()) + EndAndBeginBlockForProvider(1 * Second, consumerChains.map(c => NewFullConsumer(c)), Set()) ) .then( // node 1 assigns a key on consumer1 @@ -891,7 +889,7 @@ module ccv_model { // the key should be present in the valset on the consumer, and the node itself should not assert(currentState.consumerStates.get("consumer1").chainState.currentValidatorSet.getOrElse("node1", 0) == 0), assert(currentState.consumerStates.get("consumer1").chainState.currentValidatorSet.get("consAddr1") == 100), - // try some key assignments that should fail/succeed without committing to state + // try some key assignments that should fail/succeed without comitting to state val res = assignConsumerKey(currentState, "consumer1", "node1", "consAddr1") // fail - key already assigned (even if it is the same node) assert(hasError(res)), @@ -940,7 +938,7 @@ module ccv_model { VotingPowerChange("node1", 50) ) .then( - EndAndBeginBlockForProvider(1 * Second, Set("consumer1", "consumer2"), Set()) + EndAndBeginBlockForProvider(1 * Second, Set(NewFullConsumer("consumer1"), NewFullConsumer("consumer2")), Set()) ).then( all { ValidatorSetHasExistedKeyAssignmentInv, diff --git a/tests/mbt/model/ccv_pss.qnt b/tests/mbt/model/ccv_pss.qnt new file mode 100644 index 0000000000..4759fe6b34 --- /dev/null +++ b/tests/mbt/model/ccv_pss.qnt @@ -0,0 +1,157 @@ +// This module contains logic for PSS (Partial Set Security). +// PSS is a variant/extension of CCV that +// allows for only a subset of the validator set +// to secure a consumer chain. +// Not all logic related to PSS is inside this module, as some logic is +// too tightly coupled with the core CCV logic, +// which is instead found in ccv.qnt +module ccv_pss { + import ccv_types.* from "./ccv" + import extraSpells.* from "./libraries/extraSpells" + import ccv_utils.* from "./ccv_utils" + + // Given a base validator set, an N for a top N chain, and a set of validators that have opted in to the chain, + // returns the validator set that should be sent to the chain. + // Assumes that the value for N is valid. + pure def GetPSSValidatorSet(providerState: ProviderState, origValSet: ValidatorSet, consumer: Chain): ValidatorSet = { + pure val optedInVals = providerState.optedInVals.getOrElse(consumer, Set()) + GetPSSValidatorSet_helper(origValSet, optedInVals) + } + + pure def GetPSSValidatorSet_helper(origValSet: ValidatorSet, optedInVals: Set[Node]): ValidatorSet = { + origValSet.mapFilter(v => optedInVals.contains(v)) + } + + // Given a validator set and N, returns the top N% of validators by power. + // Note that in the edge case of multiple validators having the same power, + // this will always include all validators with the same power as the lowest top N validator. + pure def GetTopNVals(origValSet: ValidatorSet, N: int): Set[Node] = { + // == sort validators by power == + // define a comparator that compares validators by power + pure def powerCompare(a: Node, b: Node): Ordering = { + pure val powA = origValSet.get(a) + pure val powB = origValSet.get(b) + intCompare(powB, powA) + } + // get a sorted list of validators by power + pure val sortedVals = origValSet.keys().toSortedList(powerCompare) + + // == compute the threshold of how much power the top N have == + pure val totalPower = origValSet.mapValuesSum() + pure val topNPower = totalPower * N / 100 + + // == construct the validator set by going through the sorted vals == + pure val res = sortedVals.foldl( + // accumulator carries 4 values: + // * set of vals in top N (starts with empty set) + // * total power added so far (starts with 0) + // * whether we should add the next validator if it has the same power as the previous one, + // regardless of total power (starts with false) + // * the power of the last validator added (starts with 0) + (Set(), 0, false, 0), + (acc, validator) => + pure val curValSet = acc._1 + pure val accPower = acc._2 + pure val shouldAddSamePow = acc._3 + pure val lastPow = acc._4 + + pure val validatorPower = origValSet.get(validator) + if (validatorPower == lastPow and shouldAddSamePow) { + // we should add the validator because it has the same power as the previous one, + // and we add regardless of total power because we need to include all + // vals with the same power if we include one of them + pure val newAccPower = accPower + validatorPower + (curValSet.union(Set(validator)), newAccPower, true, validatorPower) + } else if (validatorPower > 0 and accPower < topNPower) { + // if we don't have enough power yet, add the validator to the set + pure val newAccPower = accPower + validatorPower + (curValSet.union(Set(validator)), newAccPower, true, validatorPower) + } else { + // if we have enough power and we also are done adding + // all validators with the same power as the lowest top N validator, + // don't add them + (curValSet, accPower, false, 0) + } + ) + res._1 + } + + // Opts a validator in for a consumer chain the provider. + // Possible before the consumer chain starts running, + // and will then be applied when the consumer chain starts running. + pure def OptIn(currentState: ProtocolState, consumer: Chain, validator: Node): Result = { + pure val optedInVals = currentState.providerState.optedInVals.get(consumer) + pure val newOptedInVals = optedInVals.union(Set(validator)) + Ok({ + ...currentState, + providerState: { + ...currentState.providerState, + optedInVals: currentState.providerState.optedInVals.put(consumer, newOptedInVals) + } + }) + } + + // Returns true if the given validator is in the top N for the given consumer chain, + // and false otherwise. + pure def IsTopN(currentState: ProtocolState, validator: Node, consumer: Chain): bool = { + val proviValSet = currentState.providerState.chainState.currentValidatorSet + val N = currentState.providerState.topNByConsumer.get(consumer) + + val topNValSet = GetTopNVals(proviValSet, N) + + topNValSet.contains(validator) + } + + // Returns true if the given validator has opted in to the given consumer chain, + pure def IsOptedIn(currentState: ProtocolState, validator: Node, consumer: Chain): bool = { + currentState.providerState.optedInVals.getOrElse(consumer, Set()).contains(validator) + } + + // Opts a validator out. Safe to call before the consumer chain even runs. + // Will not stop the validator set from being forced to validate when in the top N. + // Validators that are in the top N will not be able to opt out, and + // an error will be returned. + // Similarly, if the validator is not opted in, an error will be returned. + pure def OptOut(currentState: ProtocolState, consumer: Chain, validator: Node): Result = { + if (currentState.IsTopN(validator, consumer)) { + Err("Cannot opt out a validator that is in the top N") + } else if (not(currentState.IsOptedIn(validator, consumer))) { + Err("Cannot opt out a validator that is not opted in") + } else { + pure val optedInVals = currentState.providerState.optedInVals.get(consumer) + pure val newOptedInVals = optedInVals.exclude(Set(validator)) + Ok({ + ...currentState, + providerState: { + ...currentState.providerState, + optedInVals: currentState.providerState.optedInVals.put(consumer, newOptedInVals) + } + }) + } + } + + // Runs the PSS logic that needs to run on endblock. + // Concretely, this will forcefully opt in all validators that are in the top N + // for each chain. + pure def endBlockPSS(providerState: ProviderState): ProviderState = { + val runningConsumers = providerState.getRunningConsumers() + runningConsumers.fold( + providerState, + (acc, consumer) => endBlockPSS_helper(acc, consumer) + ) + } + + // Runs the PSS logic for a single consumer. + // Should only be run for running chains. + pure def endBlockPSS_helper(providerState: ProviderState, consumer: Chain): ProviderState = { + val proviValSet = providerState.chainState.currentValidatorSet + val topNVals = GetTopNVals(proviValSet, providerState.topNByConsumer.get(consumer)) + val prevOptedInVals = providerState.optedInVals.getOrElse(consumer, Set()) + // opt in all the top N validators, i.e. union the top N vals with the previous opted in vals + val newOptedInVals = providerState.optedInVals.put(consumer, prevOptedInVals.union(topNVals)) + { + ...providerState, + optedInVals: newOptedInVals + } + } +} \ No newline at end of file diff --git a/tests/mbt/model/ccv_pss_model.qnt b/tests/mbt/model/ccv_pss_model.qnt new file mode 100644 index 0000000000..76c4873b43 --- /dev/null +++ b/tests/mbt/model/ccv_pss_model.qnt @@ -0,0 +1,113 @@ +module ccv_pss_model { + import ccv_types.* from "./ccv" + import ccv_model.* from "./ccv_model" + import ccv_pss.* from "./ccv_pss" + import extraSpells.* from "./libraries/extraSpells" + + action StepOptIn(): bool = { + all { + runningConsumers.size() > 0, + nondet consumer = oneOf(runningConsumers) + nondet validator = oneOf(nodes) + OptIn_Deterministic(consumer, validator) + } + } + + action OptIn_Deterministic(consumer: Chain, validator: Node): bool = { + val res = OptIn(currentState, consumer, validator) + all { + currentState' = res.newState, + trace' = trace.append( + { + ...emptyAction, + kind: "OptIn", + consumerChain: consumer, + validator: validator, + expectedError: res.error + } + ), + params' = params, + } + } + + action StepOptOut(): bool = { + all { + runningConsumers.size() > 0, + nondet consumer = oneOf(runningConsumers) + nondet validator = oneOf(nodes) + OptOut_Deterministic(consumer, validator) + } + } + + action OptOut_Deterministic(consumer: Chain, validator: Node): bool = { + val res = OptOut(currentState, consumer, validator) + all { + currentState' = res.newState, + trace' = trace.append( + { + ...emptyAction, + kind: "OptOut", + consumerChain: consumer, + validator: validator, + expectedError: res.error + } + ), + params' = params, + } + } + + // Different sets of possible values for the topN parameter. + val allFullConsumers: Set[int] = Set(100) + val allOptIn: Set[int] = Set(0) + // only choose a few values for top N here to not make the "edge cases" of 0 and 100 too unlikely + val variousPossibleTopN: Set[int] = Set(0, 50, 70, 80, 90, 100) + + // INVARIANTS + + // For a consumer chain with a given top N value, + // the total VP on the consumer is at least N% of the total VP of some historical val set on the provider. + val AtLeastTopNPower: bool = + runningConsumers.forall(consumer => { + val topN = currentState.providerState.topNByConsumer.get(consumer) + val totalPowerConsu = currentState.consumerStates.get(consumer).chainState.currentValidatorSet.mapValuesSum() + currentState.providerState.chainState.votingPowerHistory.toSet().exists( + valSet => { + val totalPowerProvi = valSet.mapValuesSum() + + totalPowerConsu >= totalPowerProvi * topN / 100 + } + ) + }) + + // SANITY CHECKS + + val CanOptIn = { + not( + trace[length(trace)-1].kind == "OptIn" + and + trace[length(trace)-1].expectedError == "" + ) + } + + val CanOptOut = { + not( + trace[length(trace)-1].kind == "OptOut" + and + trace[length(trace)-1].expectedError == "" + ) + } + + val CanFailOptOut = { + not( + trace[length(trace)-1].kind == "OptOut" + and + trace[length(trace)-1].expectedError != "" + ) + } + + val CanHaveOptIn = { + currentState.providerState.topNByConsumer.keys().exists(consumer => { + currentState.providerState.topNByConsumer.get(consumer) != 100 + }) + } +} \ No newline at end of file diff --git a/tests/mbt/model/ccv_pss_test.qnt b/tests/mbt/model/ccv_pss_test.qnt new file mode 100644 index 0000000000..2a84f6dbd4 --- /dev/null +++ b/tests/mbt/model/ccv_pss_test.qnt @@ -0,0 +1,34 @@ +// This module contains logic for PSS (Partial Set Security). +// PSS is a variant/extension of CCV that +// allows for only a subset of the validator set +// to secure a consumer chain. +// Not all logic related to PSS is inside this module, as some logic is +// too tightly coupled with the core CCV logic, +// which is instead found in ccv.qnt +module ccv_pss_test { + import ccv_types.* from "./ccv" + import extraSpells.* from "./libraries/extraSpells" + import ccv_utils.* from "./ccv_utils" + import ccv_pss.* from "./ccv_pss" + + run TopNTest = + val valSet = + Map("d" -> 25, "c1" -> 15, "c" -> 15, "b2" -> 10, "b1" -> 10, "b" -> 10, "a2" -> 5, "a1" -> 5, "a" -> 5) + // total power: 5*3 + 10*3 + 15*2 + 25 = 100 + all + { + assert(GetTopNVals(valSet, 0) == Set()), + assert(GetTopNVals(valSet, 1) == Set("d")), + assert(GetTopNVals(valSet, 10) == Set("d")), + assert(GetTopNVals(valSet, 25) == Set("d")), + // if one validator with a power is included, all validators with that power need to be included + assert(GetTopNVals(valSet, 26) == Set("d", "c1", "c")), + assert(GetTopNVals(valSet, 45) == Set("d", "c1", "c")), + assert(GetTopNVals(valSet, 55) == Set("d", "c1", "c")), + assert(GetTopNVals(valSet, 56) == Set("d", "c1", "c", "b2", "b1", "b")), + assert(GetTopNVals(valSet, 85) == Set("d", "c1", "c", "b2", "b1", "b")), + assert(GetTopNVals(valSet, 86) == valSet.keys()), + assert(GetTopNVals(valSet, 95) == valSet.keys()), + assert(GetTopNVals(valSet, 100) == valSet.keys()), + } +} \ No newline at end of file diff --git a/tests/mbt/model/ccv_sync.qnt b/tests/mbt/model/ccv_sync.qnt index d828b88a76..8af76693b8 100644 --- a/tests/mbt/model/ccv_sync.qnt +++ b/tests/mbt/model/ccv_sync.qnt @@ -25,7 +25,7 @@ module ccv_sync { action initSync = all { init.then( - EndAndBeginBlockForProvider(1 * Second, consumerChains, Set()) + EndAndBeginBlockForProvider(1 * Second, consumerChains.map(c => ccvt::NewFullConsumer(c)), Set()) ), QueuedChainsToEndBlock' = consumerChainList.foldl( List(), diff --git a/tests/mbt/model/ccv_test.qnt b/tests/mbt/model/ccv_test.qnt index ab47df50b7..af49ecfa14 100644 --- a/tests/mbt/model/ccv_test.qnt +++ b/tests/mbt/model/ccv_test.qnt @@ -247,15 +247,15 @@ module ccv_test { "chain3" -> STOPPED ) val res = StartStopConsumers( - currentConsumerStatusMap, - Set("chain1"), + GetEmptyProviderState.with("consumerStatus", currentConsumerStatusMap), + Set(NewOptInConsumer("chain1")), Set("chain2"), Set() ) res._2 == "" and - res._1.get("chain1") == RUNNING and - res._1.get("chain2") == STOPPED and - res._1.get("chain3") == STOPPED + res._1.consumerStatus.get("chain1") == RUNNING and + res._1.consumerStatus.get("chain2") == STOPPED and + res._1.consumerStatus.get("chain3") == STOPPED } run ConsumerStatusMapAlreadyRunningTest = @@ -266,8 +266,8 @@ module ccv_test { "chain3" -> STOPPED ) val res = StartStopConsumers( - currentConsumerStatusMap, - Set("chain2"), + GetEmptyProviderState.with("consumerStatus", currentConsumerStatusMap), + Set(NewOptInConsumer("chain2")), Set("chain3"), Set() ) @@ -282,8 +282,8 @@ module ccv_test { "chain3" -> STOPPED ) val res = StartStopConsumers( - currentConsumerStatusMap, - Set("chain1"), + GetEmptyProviderState.with("consumerStatus", currentConsumerStatusMap), + Set(NewOptInConsumer("chain1")), Set("chain3"), Set() ) @@ -298,8 +298,8 @@ module ccv_test { "chain3" -> STOPPED ) val res = StartStopConsumers( - currentConsumerStatusMap, - Set("chain1"), + GetEmptyProviderState.with("consumerStatus", currentConsumerStatusMap), + Set(NewOptInConsumer("chain1")), Set("chain1"), Set() ) diff --git a/tests/mbt/model/ccv_utils.qnt b/tests/mbt/model/ccv_utils.qnt index 476ef923dc..4e4582fd06 100644 --- a/tests/mbt/model/ccv_utils.qnt +++ b/tests/mbt/model/ccv_utils.qnt @@ -141,26 +141,33 @@ module ccv_utils { } pure def StartStopConsumers( - currentConsumerStatusMap: Chain -> str, - consumersToStart: Set[Chain], + currentProviderState: ProviderState, + consumersToStart: Set[ConsumerAdditionMsg], consumersToStop: Set[Chain], consumersToTimeout: Set[Chain] - ): (Chain -> str, str) = { + ): (ProviderState, str) = { + val consumerAdditions = consumersToStart.map( + msg => msg.chain + ) // check if any consumer is both started and stopped - if (consumersToStart.intersect(consumersToStop).size() > 0) { - (currentConsumerStatusMap, "Cannot start and stop a consumer at the same time") + if (consumerAdditions.intersect(consumersToStop).size() > 0) { + (currentProviderState, "Cannot start and stop a consumer at the same time") } else { - val res1 = currentConsumerStatusMap.startConsumers(consumersToStart) + val res1 = currentProviderState.consumerStatus.startConsumers(consumerAdditions) val newConsumerStatus = res1._1 val err1 = res1._2 val res2 = newConsumerStatus.stopConsumers(consumersToStop, consumersToTimeout) val err2 = res2._2 + // set the top N values in the provider correctly + if (err1 != "") { - (currentConsumerStatusMap, err1) + (currentProviderState, err1) } else if (err2 != "") { - (currentConsumerStatusMap, err2) + (currentProviderState, err2) } else { - (res2._1, "") + (currentProviderState + .with("consumerStatus", res2._1) + .SetTopNValues(consumersToStart), "") } } } @@ -170,7 +177,7 @@ module ccv_utils { pure def enterCurValSetIntoBlock(chainState: ChainState): ChainState = { chainState.with( "votingPowerHistory", chainState.votingPowerHistory.prepend( - chainState.currentValidatorSet + removeZeroPowers(chainState.currentValidatorSet) ) ) } @@ -460,4 +467,28 @@ module ccv_utils { (false, "") } } + + // Sets the top N values on the provider chain for the given consumer chains, + // taken as consumerAdditionMsgs = chains with the top N values. + // If a chain in the set is already present, the old value will be overwritten. + pure def SetTopNValues(providerState: ProviderState, consumers: Set[ConsumerAdditionMsg]): ProviderState = + providerState.with( + "topNByConsumer", + consumers.fold( + providerState.topNByConsumer, + (acc, consumer) => acc.put(consumer.chain, consumer.topN) + ) + ) + + // From a validator set, removes all validators with zero power. + pure def removeZeroPowers(valSet: ValidatorSet): ValidatorSet = + valSet.keys().fold( + Map(), + (acc, node) => + if (valSet.get(node) == 0) { + acc + } else { + acc.put(node, valSet.get(node)) + } + ) } \ No newline at end of file diff --git a/tests/mbt/model/libraries/extraSpells.qnt b/tests/mbt/model/libraries/extraSpells.qnt index 9167bb4bcb..6ab4063ad3 100644 --- a/tests/mbt/model/libraries/extraSpells.qnt +++ b/tests/mbt/model/libraries/extraSpells.qnt @@ -139,6 +139,72 @@ module extraSpells { __set.fold(List(), (__l, __e) => __l.append(__e)) } + /// The type of orderings between comparable things + // Follows https://hackage.haskell.org/package/base-4.19.0.0/docs/Data-Ord.html#t:Ordering + // and we think there are likely benefits to using 3 constant values rather than the more + // common integer range in Apalache. + type Ordering = + | EQ + | LT + | GT + + /// Comparison of integers + pure def intCompare(__a: int, __b:int): Ordering = { + if (__a < __b) + { LT } + else if (__a > __b) + { GT } + else + { EQ } + } + + /// Assuming `__l` is sorted according to `__cmp`, returns a list with the element `__x` + /// inserted in order. + /// + /// If `__l` is not sorted, `__x` will be inserted after the first element less than + /// or equal to it. + /// + /// - @param __l a sorted list + /// - @param __x an element to be inserted + /// - @param __cmp an operator defining an `Ordering` of the elemnts of type `a` + /// - @returns a sorted list that includes `__x` + pure def sortedListInsert(__l: List[a], __x: a, __cmp: (a, a) => Ordering): List[a] = { + // We need to track whether __x has been inserted, and the accumulator for the new list + val __init = { is_inserted: false, acc: List() } + + val __result = __l.foldl(__init, (__state, __y) => + if (__state.is_inserted) + { ...__state, acc: __state.acc.append(__y) } + else + match __cmp(__x, __y) { + | GT => { ...__state, acc: __state.acc.append(__y) } + | _ => { is_inserted: true, acc: __state.acc.append(__x).append(__y) } + }) + + if (not(__result.is_inserted)) + // If __x was not inserted, it was GT than every other element, so it goes at the end + __result.acc.append(__x) + else + __result.acc + } + + run sortedListInsertTest = all { + assert(List().sortedListInsert(3, intCompare) == List(3)), + assert(List(1,2,4).sortedListInsert(3, intCompare) == List(1,2,3,4)), + assert(List(4,1,2).sortedListInsert(3, intCompare) == List(3,4,1,2)), + assert(List(1,2,3).sortedListInsert(4, intCompare) == List(1,2,3,4)), + } + + //// Returns a list of all elements of a set. + //// The ordering will be arbitrary. + //// + //// - @param __set a set + //// - @param __cmp an operator defining an `Ordering` of the elemnts of type `a` + //// - @returns a sorted list of all elements of __set + pure def toSortedList(__set: Set[a], __cmp: (a, a) => Ordering): List[a] = { + __set.fold(List(), (__l, __e) => __l.sortedListInsert(__e, __cmp)) + } + //// Returns a set of the elements in the list. //// //// - @param __list a list @@ -207,4 +273,59 @@ module extraSpells { assert(not(listForAll(List(1, 2, 3), __x => __x > 1))), assert(listForAll(List(), __x => __x > 0)), } + + /// Compute the sum of the values over all entries in a map. + /// + /// - @param myMap a map from keys to integers + /// - @returns the sum; when the map is empty, the sum is 0. + pure def mapValuesSum(myMap: a -> int): int = { + myMap.keys().fold(0, ((sum, i) => sum + myMap.get(i))) + } + + run mapValuesSumTest = all { + assert(Map().mapValuesSum() == 0), + assert(2.to(5).mapBy(i => i * 2).mapValuesSum() == 28), + assert(Map(2 -> -4, 4 -> 2).mapValuesSum() == -2), + } + + /// Returns a map of a subset of keys and values from a map, + // where only those keys are included for which the given __f + // returns true. + pure def mapFilter(__map: a -> b, __f: a => bool): a -> b = { + __map.keys().filter(e => __f(e)).mapBy(__k => __map.get(__k)) + } + + /// Compute the maximum of two integers. + /// + /// - @param __i first integer + /// - @param __j second integer + /// - @returns the maximum of __i and __j + pure def max(__i: int, __j: int): int = { + if (__i > __j) __i else __j + } + + run maxTest = all { + assert(max(3, 4) == 4), + assert(max(6, 3) == 6), + assert(max(10, 10) == 10), + assert(max(-3, -5) == -3), + assert(max(-5, -3) == -3), + } + + /// Compute the minimum of two integers. + /// + /// - @param __i first integer + /// - @param __j second integer + /// - @returns the minimum of __i and __j + pure def min(__i: int, __j: int): int = { + if (__i < __j) __i else __j + } + + run minTest = all { + assert(min(3, 4) == 3), + assert(min(6, 3) == 3), + assert(min(10, 10) == 10), + assert(min(-3, -5) == -5), + assert(min(-5, -3) == -5), + } } \ No newline at end of file diff --git a/tests/mbt/run_invariants.sh b/tests/mbt/model/run_invariants.sh similarity index 100% rename from tests/mbt/run_invariants.sh rename to tests/mbt/model/run_invariants.sh