Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: Ports key assignment to the driver on the PSS feature branch #1628

Merged
merged 2 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion tests/mbt/driver/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
abcitypes "github.com/cometbft/cometbft/abci/types"
cmttypes "github.com/cometbft/cometbft/types"

"github.com/cometbft/cometbft/proto/tendermint/crypto"
appConsumer "github.com/cosmos/interchain-security/v4/app/consumer"
appProvider "github.com/cosmos/interchain-security/v4/app/provider"
simibc "github.com/cosmos/interchain-security/v4/testutil/simibc"
Expand Down Expand Up @@ -123,9 +124,13 @@ func (s *Driver) consumerPower(i int64, chain ChainId) (int64, error) {
return v.Power, nil
}

func (s *Driver) stakingValidator(i int64) (stakingtypes.Validator, bool) {
return s.providerStakingKeeper().GetValidator(s.ctx(PROVIDER), s.validator(i))
}

// providerPower returns the power(=number of bonded tokens) of the i-th validator on the provider.
func (s *Driver) providerPower(i int64) (int64, error) {
v, found := s.providerStakingKeeper().GetValidator(s.ctx(PROVIDER), s.validator(i))
v, found := s.stakingValidator(i)
if !found {
return 0, fmt.Errorf("validator with id %v not found on provider", i)
} else {
Expand Down Expand Up @@ -370,6 +375,14 @@ func (s *Driver) setTime(chain ChainId, newTime time.Time) {
testChain.App.BeginBlock(abcitypes.RequestBeginBlock{Header: testChain.CurrentHeader})
}

func (s *Driver) AssignKey(chain ChainId, valIndex int64, value crypto.PublicKey) error {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: rename value to pubKey or something similar.

stakingVal, found := s.stakingValidator(valIndex)
if !found {
return fmt.Errorf("validator with id %v not found on provider", valIndex)
}
return s.providerKeeper().AssignConsumerKey(s.providerCtx(), string(chain), stakingVal, value)
}

// DeliverPacketToConsumer delivers a packet from the provider to the given consumer recipient.
// It updates the client before delivering the packet.
// Since the channel is ordered, the packet that is delivered is the first packet in the outbox.
Expand Down
3 changes: 2 additions & 1 deletion tests/mbt/driver/generate_more_traces.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ go run ./... -modelPath=../model/ccv_boundeddrift.qnt -step stepBoundedDrift -in
echo "Generating synced traces with maturations"
go run ./... -modelPath=../model/ccv_sync.qnt -init initSync -step stepSync -invariant CanReceiveMaturations -traceFolder traces/sync_mat -numTraces 20 -numSteps 300 -numSamples 20
echo "Generating long synced traces without invariants"
go run ./... -modelPath=../model/ccv_sync.qnt -init initSync -step stepSync -traceFolder traces/sync_noinv -numTraces 20 -numSteps 500 -numSamples 1
go run ./... -modelPath=../model/ccv_sync.qnt -init initSync -step stepSync -traceFolder traces/sync_noinv -numTraces 20 -numSteps 500 -numSamples 1
go run ./... -modelPath=../model/ccv_boundeddrift.qnt --step stepBoundedDriftKeyAssignment --traceFolder traces/bound_key -numTraces 20 -numSteps 100 -numSamples 20
3 changes: 2 additions & 1 deletion tests/mbt/driver/generate_traces.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ go run ./... -modelPath=../model/ccv_boundeddrift.qnt -step stepBoundedDrift -in
echo "Generating synced traces with maturations"
go run ./... -modelPath=../model/ccv_sync.qnt -init initSync -step stepSync -invariant CanReceiveMaturations -traceFolder traces/sync_mat -numTraces 1 -numSteps 300 -numSamples 20
echo "Generating long synced traces without invariants"
go run ./... -modelPath=../model/ccv_sync.qnt -init initSync -step stepSync -traceFolder traces/sync_noinv -numTraces 1 -numSteps 500 -numSamples 1
go run ./... -modelPath=../model/ccv_sync.qnt -init initSync -step stepSync -traceFolder traces/sync_noinv -numTraces 1 -numSteps 500 -numSamples 1
go run ./... -modelPath=../model/ccv_boundeddrift.qnt --step stepBoundedDriftKeyAssignment --traceFolder traces/bound_key -numTraces 1 -numSteps 100 -numSamples 20
125 changes: 98 additions & 27 deletions tests/mbt/driver/mbt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@ import (
"github.com/kylelemons/godebug/pretty"
"github.com/stretchr/testify/require"

sdktypes "github.com/cosmos/cosmos-sdk/types"

cmttypes "github.com/cometbft/cometbft/types"

tmencoding "github.com/cometbft/cometbft/crypto/encoding"
"github.com/cosmos/interchain-security/v4/testutil/integration"

sdktypes "github.com/cosmos/cosmos-sdk/types"

providertypes "github.com/cosmos/interchain-security/v4/x/ccv/provider/types"
)

Expand Down Expand Up @@ -69,6 +72,7 @@ func TestMBT(t *testing.T) {
t.Logf("Number of sent packets: %v", stats.numSentPackets)
t.Logf("Number of blocks: %v", stats.numBlocks)
t.Logf("Number of transactions: %v", stats.numTxs)
t.Logf("Number of key assignments: %v", stats.numKeyAssignments)
t.Logf("Average summed block time delta passed per trace: %v", stats.totalBlockTimePassedPerTrace/time.Duration(numTraces))
}

Expand Down Expand Up @@ -117,6 +121,21 @@ func RunItfTrace(t *testing.T, path string) {

t.Log("Chains are: ", chains)

// generate keys that can be assigned on consumers, according to the ConsumerAddresses in the trace
consumerAddressesExpr := params["ConsumerAddresses"].Value.(itf.ListExprType)

_, _, consumerPrivVals, err := integration.CreateValidators(len(consumerAddressesExpr))
require.NoError(t, err, "Error creating consumer signers")

consumerAddrNamesToPrivVals := make(map[string]cmttypes.PrivValidator, len(consumerAddressesExpr))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a comment here on the differences between real address, model cons addresses, etc?

realAddrsToModelConsAddrs := make(map[string]string, len(consumerAddressesExpr))
i := 0
for address, privVal := range consumerPrivVals {
consumerAddrNamesToPrivVals[consumerAddressesExpr[i].Value.(string)] = privVal
realAddrsToModelConsAddrs[address] = consumerAddressesExpr[i].Value.(string)
i++
}

// create params struct
vscTimeout := time.Duration(params["VscTimeout"].Value.(int64)) * time.Second

Expand Down Expand Up @@ -145,6 +164,15 @@ func RunItfTrace(t *testing.T, path string) {
valSet, addressMap, signers, err := CreateValSet(initialValSet)
require.NoError(t, err, "Error creating validator set")

// get the set of signers for consumers: the validator signers, plus signers for the assignable addresses
consumerSigners := make(map[string]cmttypes.PrivValidator, 0)
for consAddr, consPrivVal := range consumerPrivVals {
consumerSigners[consAddr] = consPrivVal
}
for consAddr, signer := range signers {
consumerSigners[consAddr] = signer
}

// get a slice of validators in the right order
nodes := make([]*cmttypes.Validator, len(valNames))
for i, valName := range valNames {
Expand Down Expand Up @@ -211,6 +239,10 @@ func RunItfTrace(t *testing.T, path string) {
// and then increment the rest of the time
runningConsumersBefore := driver.runningConsumers()
driver.endAndBeginBlock("provider", 1*time.Nanosecond)
for _, consumer := range driver.runningConsumers() {
UpdateProviderClientOnConsumer(t, driver, consumer.ChainId)
}

driver.endAndBeginBlock("provider", time.Duration(timeAdvancement)*time.Second-1*time.Nanosecond)
runningConsumersAfter := driver.runningConsumers()

Expand Down Expand Up @@ -243,7 +275,7 @@ func RunItfTrace(t *testing.T, path string) {
consumer.Value.(string),
modelParams,
driver.providerChain().Vals,
signers,
consumerSigners,
nodes,
valNames,
driver.providerChain(),
Expand All @@ -268,11 +300,8 @@ func RunItfTrace(t *testing.T, path string) {
if len(consumersToStart) > 0 && consumer.ChainId == consumersToStart[len(consumersToStart)-1].Value.(string) {
continue
}
consumerChainId := consumer.ChainId

driver.path(ChainId(consumerChainId)).AddClientHeader(PROVIDER, driver.providerHeader())
err := driver.path(ChainId(consumerChainId)).UpdateClient(consumerChainId, false)
require.True(t, err == nil, "Error updating client from %v on provider: %v", consumerChainId, err)
UpdateProviderClientOnConsumer(t, driver, consumer.ChainId)
}

case "EndAndBeginBlockForConsumer":
Expand All @@ -286,13 +315,12 @@ func RunItfTrace(t *testing.T, path string) {
_ = headerBefore

driver.endAndBeginBlock(ChainId(consumerChain), 1*time.Nanosecond)
UpdateConsumerClientOnProvider(t, driver, consumerChain)

driver.endAndBeginBlock(ChainId(consumerChain), time.Duration(timeAdvancement)*time.Second-1*time.Nanosecond)

// update the client on the provider
consumerHeader := driver.chain(ChainId(consumerChain)).LastHeader
driver.path(ChainId(consumerChain)).AddClientHeader(consumerChain, consumerHeader)
err := driver.path(ChainId(consumerChain)).UpdateClient(PROVIDER, false)
require.True(t, err == nil, "Error updating client from %v on provider: %v", consumerChain, err)
UpdateConsumerClientOnProvider(t, driver, consumerChain)

case "DeliverVscPacket":
consumerChain := lastAction["consumerChain"].Value.(string)
Expand Down Expand Up @@ -328,8 +356,26 @@ func RunItfTrace(t *testing.T, path string) {
expectError = false
driver.DeliverPacketFromConsumer(ChainId(consumerChain), expectError)
}
default:
case "KeyAssignment":
consumerChain := lastAction["consumerChain"].Value.(string)
node := lastAction["validator"].Value.(string)
consumerAddr := lastAction["consumerAddr"].Value.(string)

t.Log("KeyAssignment", consumerChain, node, consumerAddr)
stats.numKeyAssignments++

valIndex := getIndexOfString(node, valNames)
assignedPrivVal := consumerAddrNamesToPrivVals[consumerAddr]
assignedKey, err := assignedPrivVal.GetPubKey()
require.NoError(t, err, "Error getting pubkey")

protoPubKey, err := tmencoding.PubKeyToProto(assignedKey)
require.NoError(t, err, "Error converting pubkey to proto")

error := driver.AssignKey(ChainId(consumerChain), int64(valIndex), protoPubKey)
require.NoError(t, error, "Error assigning key")

default:
log.Fatalf("Error loading trace file %s, step %v: do not know action type %s",
path, index, actionKind)
}
Expand Down Expand Up @@ -364,7 +410,7 @@ func RunItfTrace(t *testing.T, path string) {
require.Equal(t, modelRunningConsumers, actualRunningConsumers, "Running consumers do not match")

// check validator sets - provider current validator set should be the one from the staking keeper
CompareValidatorSets(t, driver, currentModelState, actualRunningConsumers)
CompareValidatorSets(t, driver, currentModelState, actualRunningConsumers, realAddrsToModelConsAddrs)

// check times - sanity check that the block times match the ones from the model
CompareTimes(driver, actualRunningConsumers, currentModelState, timeOffset)
Expand All @@ -383,7 +429,27 @@ func RunItfTrace(t *testing.T, path string) {
t.Log("🟢 Trace is ok!")
}

func CompareValidatorSets(t *testing.T, driver *Driver, currentModelState map[string]itf.Expr, consumers []string) {
func UpdateProviderClientOnConsumer(t *testing.T, driver *Driver, consumerChainId string) {
driver.path(ChainId(consumerChainId)).AddClientHeader(PROVIDER, driver.providerHeader())
err := driver.path(ChainId(consumerChainId)).UpdateClient(consumerChainId, false)
require.True(t, err == nil, "Error updating client from %v on provider: %v", consumerChainId, err)
}

func UpdateConsumerClientOnProvider(t *testing.T, driver *Driver, consumerChain string) {
consumerHeader := driver.chain(ChainId(consumerChain)).LastHeader
driver.path(ChainId(consumerChain)).AddClientHeader(consumerChain, consumerHeader)
err := driver.path(ChainId(consumerChain)).UpdateClient(PROVIDER, false)
require.True(t, err == nil, "Error updating client from %v on provider: %v", consumerChain, err)
}

func CompareValidatorSets(
t *testing.T,
driver *Driver,
currentModelState map[string]itf.Expr,
consumers []string,
// a map from real addresses to the names of those consumer addresses in the model
keyAddrsToModelConsAddrName map[string]string,
) {
t.Helper()
modelValSet := ValidatorSet(currentModelState, "provider")

Expand All @@ -407,23 +473,28 @@ func CompareValidatorSets(t *testing.T, driver *Driver, currentModelState map[st
pubkey, err := val.ConsPubKey()
require.NoError(t, err, "Error getting pubkey")

consAddr := providertypes.NewConsumerConsAddress(sdktypes.ConsAddress(pubkey.Address().Bytes()))
consAddrModelName, ok := keyAddrsToModelConsAddrName[pubkey.Address().String()]
if ok { // the node has a key assigned, use the name of the consumer address in the model
consumerCurValSet[consAddrModelName] = val.Power
} else { // the node doesn't have a key assigned yet, get the validator moniker
consAddr := providertypes.NewConsumerConsAddress(sdktypes.ConsAddress(pubkey.Address().Bytes()))

// the consumer vals right now are CrossChainValidators, for which we don't know their mnemonic
// so we need to find the mnemonic of the consumer val now to enter it by name in the map
// the consumer vals right now are CrossChainValidators, for which we don't know their mnemonic
// so we need to find the mnemonic of the consumer val now to enter it by name in the map

// get the address on the provider that corresponds to the consumer address
providerConsAddr, found := driver.providerKeeper().GetValidatorByConsumerAddr(driver.providerCtx(), consumer, consAddr)
if !found {
providerConsAddr = providertypes.NewProviderConsAddress(consAddr.Address)
}
// get the address on the provider that corresponds to the consumer address
providerConsAddr, found := driver.providerKeeper().GetValidatorByConsumerAddr(driver.providerCtx(), consumer, consAddr)
if !found {
providerConsAddr = providertypes.NewProviderConsAddress(consAddr.Address)
}

// get the validator for that address on the provider
providerVal, found := driver.providerStakingKeeper().GetValidatorByConsAddr(driver.providerCtx(), providerConsAddr.Address)
require.True(t, found, "Error getting provider validator")
// get the validator for that address on the provider
providerVal, found := driver.providerStakingKeeper().GetValidatorByConsAddr(driver.providerCtx(), providerConsAddr.Address)
require.True(t, found, "Error getting provider validator")

// use the moniker of that validator
consumerCurValSet[providerVal.GetMoniker()] = val.Power
// use the moniker of that validator
consumerCurValSet[providerVal.GetMoniker()] = val.Power
}
}
require.NoError(t, CompareValSet(modelValSet, consumerCurValSet), "Validator sets do not match for consumer %v", consumer)
}
Expand Down
2 changes: 2 additions & 0 deletions tests/mbt/driver/stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,6 @@ type Stats struct {
numTxs int

totalBlockTimePassedPerTrace time.Duration

numKeyAssignments int
}
Loading