diff --git a/x/ccv/provider/keeper/migration.go b/x/ccv/provider/keeper/migration.go new file mode 100644 index 0000000000..3f04a49ef6 --- /dev/null +++ b/x/ccv/provider/keeper/migration.go @@ -0,0 +1,138 @@ +package keeper + +import ( + "fmt" + + sdktypes "github.com/cosmos/cosmos-sdk/types" + paramtypes "github.com/cosmos/cosmos-sdk/x/params/types" + + providertypes "github.com/cosmos/interchain-security/v3/x/ccv/provider/types" + ccvtypes "github.com/cosmos/interchain-security/v3/x/ccv/types" +) + +// Migrator is a struct for handling in-place store migrations. +type Migrator struct { + providerKeeper Keeper + paramSpace paramtypes.Subspace +} + +// NewMigrator returns a new Migrator. +func NewMigrator(providerKeeper Keeper, paramSpace paramtypes.Subspace) Migrator { + return Migrator{providerKeeper: providerKeeper, paramSpace: paramSpace} +} + +// Migrate2to3 migrates x/ccvprovider state from consensus version 2 to 3. +func (m Migrator) Migrate2to3(ctx sdktypes.Context) error { + return m.providerKeeper.MigrateQueuedPackets(ctx) +} + +func (k Keeper) MigrateQueuedPackets(ctx sdktypes.Context) error { + for _, consumer := range k.GetAllConsumerChains(ctx) { + slashData, vscmData := k.GetAllThrottledPacketData(ctx, consumer.ChainId) + if len(slashData) > 0 { + ctx.Logger().Error(fmt.Sprintf("slash data being dropped: %v", slashData)) + } + for _, data := range vscmData { + k.HandleVSCMaturedPacket(ctx, consumer.ChainId, data) + } + k.DeleteThrottledPacketDataForConsumer(ctx, consumer.ChainId) + } + return nil +} + +// Pending packet data type enum, used to encode the type of packet data stored at each entry in the mutual queue. +// Note this type is copy/pasted from throttle v1 code. +const ( + slashPacketData byte = iota + vscMaturedPacketData +) + +// GetAllThrottledPacketData returns all throttled packet data for a given consumer chain, only used for 2 -> 3 migration. +// Note this method is adapted from throttle v1 code. +func (k Keeper) GetAllThrottledPacketData(ctx sdktypes.Context, consumerChainID string) ( + slashData []ccvtypes.SlashPacketData, vscMaturedData []ccvtypes.VSCMaturedPacketData, +) { + slashData = []ccvtypes.SlashPacketData{} + vscMaturedData = []ccvtypes.VSCMaturedPacketData{} + + store := ctx.KVStore(k.storeKey) + iteratorPrefix := providertypes.ChainIdWithLenKey(providertypes.ThrottledPacketDataBytePrefix, consumerChainID) + iterator := sdktypes.KVStorePrefixIterator(store, iteratorPrefix) + defer iterator.Close() + + for ; iterator.Valid(); iterator.Next() { + bz := iterator.Value() + switch bz[0] { + case slashPacketData: + d := ccvtypes.SlashPacketData{} + if err := d.Unmarshal(bz[1:]); err != nil { + k.Logger(ctx).Error(fmt.Sprintf("failed to unmarshal slash packet data: %v", err)) + continue + } + slashData = append(slashData, d) + case vscMaturedPacketData: + d := ccvtypes.VSCMaturedPacketData{} + if err := d.Unmarshal(bz[1:]); err != nil { + k.Logger(ctx).Error(fmt.Sprintf("failed to unmarshal vsc matured packet data: %v", err)) + continue + } + vscMaturedData = append(vscMaturedData, d) + default: + k.Logger(ctx).Error(fmt.Sprintf("invalid packet data type: %v", bz[0])) + continue + } + } + + return slashData, vscMaturedData +} + +// Note this method is copy/pasted from throttle v1 code. +func (k Keeper) DeleteThrottledPacketDataForConsumer(ctx sdktypes.Context, consumerChainID string) { + store := ctx.KVStore(k.storeKey) + iteratorPrefix := providertypes.ChainIdWithLenKey(providertypes.ThrottledPacketDataBytePrefix, consumerChainID) + iterator := sdktypes.KVStorePrefixIterator(store, iteratorPrefix) + defer iterator.Close() + + keysToDel := [][]byte{} + for ; iterator.Valid(); iterator.Next() { + keysToDel = append(keysToDel, iterator.Key()) + } + // Delete data for this consumer + for _, key := range keysToDel { + store.Delete(key) + } + + // Delete size of data queue for this consumer + store.Delete(providertypes.ThrottledPacketDataSizeKey(consumerChainID)) +} + +// Note this method is adapted from throttle v1 code. +func (k Keeper) QueueThrottledPacketDataOnlyForTesting( + ctx sdktypes.Context, consumerChainID string, ibcSeqNum uint64, packetData interface{}, +) error { + store := ctx.KVStore(k.storeKey) + + var bz []byte + var err error + switch data := packetData.(type) { + case ccvtypes.SlashPacketData: + bz, err = data.Marshal() + if err != nil { + return fmt.Errorf("failed to marshal slash packet data: %v", err) + } + bz = append([]byte{slashPacketData}, bz...) + case ccvtypes.VSCMaturedPacketData: + bz, err = data.Marshal() + if err != nil { + return fmt.Errorf("failed to marshal vsc matured packet data: %v", err) + } + bz = append([]byte{vscMaturedPacketData}, bz...) + default: + // Indicates a developer error, this method should only be called + // by tests, QueueThrottledSlashPacketData, or QueueThrottledVSCMaturedPacketData. + panic(fmt.Sprintf("unexpected packet data type: %T", data)) + } + + store.Set(providertypes.ThrottledPacketDataKey(consumerChainID, ibcSeqNum), bz) + return nil +} diff --git a/x/ccv/provider/keeper/migration_test.go b/x/ccv/provider/keeper/migration_test.go new file mode 100644 index 0000000000..a710e3979c --- /dev/null +++ b/x/ccv/provider/keeper/migration_test.go @@ -0,0 +1,117 @@ +package keeper_test + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + + testutil "github.com/cosmos/interchain-security/v3/testutil/keeper" +) + +func TestMigrate2To3(t *testing.T) { + providerKeeper, ctx, ctrl, _ := testutil.GetProviderKeeperAndCtx(t, testutil.NewInMemKeeperParams(t)) + defer ctrl.Finish() + + // Set consumer client ids to mock consumers being connected to provider + providerKeeper.SetConsumerClientId(ctx, "chain-1", "client-1") + providerKeeper.SetConsumerClientId(ctx, "chain-2", "client-2") + providerKeeper.SetConsumerClientId(ctx, "chain-3", "client-3") + + // Queue some data for chain-1 + providerKeeper.QueueThrottledPacketDataOnlyForTesting( + ctx, "chain-1", 66, testutil.GetNewSlashPacketData()) + providerKeeper.QueueThrottledPacketDataOnlyForTesting( + ctx, "chain-1", 67, testutil.GetNewVSCMaturedPacketData()) + providerKeeper.QueueThrottledPacketDataOnlyForTesting( + ctx, "chain-1", 68, testutil.GetNewSlashPacketData()) + providerKeeper.QueueThrottledPacketDataOnlyForTesting( + ctx, "chain-1", 69, testutil.GetNewVSCMaturedPacketData()) + + // Queue some data for chain-2 + providerKeeper.QueueThrottledPacketDataOnlyForTesting( + ctx, "chain-2", 789, testutil.GetNewVSCMaturedPacketData()) + providerKeeper.QueueThrottledPacketDataOnlyForTesting( + ctx, "chain-2", 790, testutil.GetNewSlashPacketData()) + providerKeeper.QueueThrottledPacketDataOnlyForTesting( + ctx, "chain-2", 791, testutil.GetNewVSCMaturedPacketData()) + + // Queue some data for chain-3 + providerKeeper.QueueThrottledPacketDataOnlyForTesting( + ctx, "chain-3", 123, testutil.GetNewSlashPacketData()) + providerKeeper.QueueThrottledPacketDataOnlyForTesting( + ctx, "chain-3", 124, testutil.GetNewVSCMaturedPacketData()) + providerKeeper.QueueThrottledPacketDataOnlyForTesting( + ctx, "chain-3", 125, testutil.GetNewVSCMaturedPacketData()) + + // Confirm getter methods return expected values + slash1, vscm1 := providerKeeper.GetAllThrottledPacketData(ctx, "chain-1") + require.Len(t, slash1, 2) + require.Len(t, vscm1, 2) + + slash2, vscm2 := providerKeeper.GetAllThrottledPacketData(ctx, "chain-2") + require.Len(t, slash2, 1) + require.Len(t, vscm2, 2) + + slash3, vscm3 := providerKeeper.GetAllThrottledPacketData(ctx, "chain-3") + require.Len(t, slash3, 1) + require.Len(t, vscm3, 2) + + // Set vsc send timestamp for every queued vsc matured packet, + // as a way to assert that the vsc matured packets are handled in the migration. + // + // That is, timestamp should exist before a vsc matured packet is handled, + // and deleted after handling. + for _, data := range vscm1 { + providerKeeper.SetVscSendTimestamp(ctx, "chain-1", data.ValsetUpdateId, time.Now()) + } + for _, data := range vscm2 { + providerKeeper.SetVscSendTimestamp(ctx, "chain-2", data.ValsetUpdateId, time.Now()) + } + for _, data := range vscm3 { + providerKeeper.SetVscSendTimestamp(ctx, "chain-3", data.ValsetUpdateId, time.Now()) + } + + // Confirm timestamps are set + for _, data := range vscm1 { + _, found := providerKeeper.GetVscSendTimestamp(ctx, "chain-1", data.ValsetUpdateId) + require.True(t, found) + } + for _, data := range vscm2 { + _, found := providerKeeper.GetVscSendTimestamp(ctx, "chain-2", data.ValsetUpdateId) + require.True(t, found) + } + for _, data := range vscm3 { + _, found := providerKeeper.GetVscSendTimestamp(ctx, "chain-3", data.ValsetUpdateId) + require.True(t, found) + } + + // Run migration + err := providerKeeper.MigrateQueuedPackets(ctx) + require.NoError(t, err) + + // Confirm throttled data is now deleted + slash1, vscm1 = providerKeeper.GetAllThrottledPacketData(ctx, "chain-1") + require.Empty(t, slash1) + require.Empty(t, vscm1) + slash2, vscm2 = providerKeeper.GetAllThrottledPacketData(ctx, "chain-2") + require.Empty(t, slash2) + require.Empty(t, vscm2) + slash3, vscm3 = providerKeeper.GetAllThrottledPacketData(ctx, "chain-3") + require.Empty(t, slash3) + require.Empty(t, vscm3) + + // Confirm timestamps are deleted, meaning vsc matured packets were handled + for _, data := range vscm1 { + _, found := providerKeeper.GetVscSendTimestamp(ctx, "chain-1", data.ValsetUpdateId) + require.False(t, found) + } + for _, data := range vscm2 { + _, found := providerKeeper.GetVscSendTimestamp(ctx, "chain-2", data.ValsetUpdateId) + require.False(t, found) + } + for _, data := range vscm3 { + _, found := providerKeeper.GetVscSendTimestamp(ctx, "chain-3", data.ValsetUpdateId) + require.False(t, found) + } +} diff --git a/x/ccv/provider/module.go b/x/ccv/provider/module.go index 5aa743ec2d..c70fdeafa7 100644 --- a/x/ccv/provider/module.go +++ b/x/ccv/provider/module.go @@ -107,6 +107,10 @@ func (AppModule) RegisterInvariants(ir sdk.InvariantRegistry) { func (am AppModule) RegisterServices(cfg module.Configurator) { providertypes.RegisterMsgServer(cfg.MsgServer(), keeper.NewMsgServerImpl(am.keeper)) providertypes.RegisterQueryServer(cfg.QueryServer(), am.keeper) + m := keeper.NewMigrator(*am.keeper, am.paramSpace) + if err := cfg.RegisterMigration(providertypes.ModuleName, 2, m.Migrate2to3); err != nil { + panic(fmt.Sprintf("failed to register migrator for %s: %s", providertypes.ModuleName, err)) + } } // InitGenesis performs genesis initialization for the provider module. It returns no validator updates. @@ -129,7 +133,7 @@ func (am AppModule) ExportGenesis(ctx sdk.Context, cdc codec.JSONCodec) json.Raw } // ConsensusVersion implements AppModule/ConsensusVersion. -func (AppModule) ConsensusVersion() uint64 { return 2 } +func (AppModule) ConsensusVersion() uint64 { return 3 } // BeginBlock implements the AppModule interface func (am AppModule) BeginBlock(ctx sdk.Context, req abci.RequestBeginBlock) {