diff --git a/changelog.md b/changelog.md index b5c5744d19..c34e6dbca1 100644 --- a/changelog.md +++ b/changelog.md @@ -68,6 +68,7 @@ * [2515](https://github.com/zeta-chain/node/pull/2515) - replace chainName by chainID for ChainNonces indexing * [2541](https://github.com/zeta-chain/node/pull/2541) - deprecate ChainName field in Chain object * [2542](https://github.com/zeta-chain/node/pull/2542) - adjust permissions to be more restrictive +* [2556](https://github.com/zeta-chain/node/pull/2556) - refactor migrator length check to use consensus type ### Tests diff --git a/pkg/chains/chain_filters.go b/pkg/chains/chain_filters.go new file mode 100644 index 0000000000..3235fb98b1 --- /dev/null +++ b/pkg/chains/chain_filters.go @@ -0,0 +1,54 @@ +package chains + +// ChainFilter is a function that filters chains based on some criteria +type ChainFilter func(c Chain) bool + +// FilterExternalChains filters chains that are external +func FilterExternalChains(c Chain) bool { + return c.IsExternal +} + +// FilterByGateway filters chains by gateway +func FilterByGateway(gw CCTXGateway) ChainFilter { + return func(chain Chain) bool { return chain.CctxGateway == gw } +} + +// FilterByConsensus filters chains by consensus type +func FilterByConsensus(cs Consensus) ChainFilter { + return func(chain Chain) bool { return chain.Consensus == cs } +} + +// FilterChains applies a list of filters to a list of chains +func FilterChains(chainList []Chain, filters ...ChainFilter) []Chain { + // Apply each filter to the list of supported chains + for _, filter := range filters { + var filteredChains []Chain + for _, chain := range chainList { + if filter(chain) { + filteredChains = append(filteredChains, chain) + } + } + chainList = filteredChains + } + + // Return the filtered list of chains + return chainList +} + +// CombineFilterChains combines multiple lists of chains into a single list +func CombineFilterChains(chainLists ...[]Chain) []Chain { + chainMap := make(map[Chain]bool) + var combinedChains []Chain + + // Add chains from each slice to remove duplicates + for _, chains := range chainLists { + for _, chain := range chains { + if !chainMap[chain] { + chainMap[chain] = true + combinedChains = append(combinedChains, chain) + } + } + } + + return combinedChains +} diff --git a/pkg/chains/chain_filters_test.go b/pkg/chains/chain_filters_test.go new file mode 100644 index 0000000000..619d555882 --- /dev/null +++ b/pkg/chains/chain_filters_test.go @@ -0,0 +1,233 @@ +package chains_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/zeta-chain/zetacore/pkg/chains" +) + +func TestFilterChains(t *testing.T) { + tt := []struct { + name string + filters []chains.ChainFilter + expected func() []chains.Chain + }{ + { + name: "Filter external chains", + filters: []chains.ChainFilter{chains.FilterExternalChains}, + expected: func() []chains.Chain { + return chains.ExternalChainList([]chains.Chain{}) + }, + }, + { + name: "Filter gateway observer chains", + filters: []chains.ChainFilter{chains.FilterByGateway(chains.CCTXGateway_observers)}, + expected: func() []chains.Chain { + return chains.ChainListByGateway(chains.CCTXGateway_observers, []chains.Chain{}) + }, + }, + { + name: "Filter consensus ethereum chains", + filters: []chains.ChainFilter{chains.FilterByConsensus(chains.Consensus_ethereum)}, + expected: func() []chains.Chain { + return chains.ChainListByConsensus(chains.Consensus_ethereum, []chains.Chain{}) + }, + }, + { + name: "Filter consensus bitcoin chains", + filters: []chains.ChainFilter{chains.FilterByConsensus(chains.Consensus_bitcoin)}, + expected: func() []chains.Chain { + return chains.ChainListByConsensus(chains.Consensus_bitcoin, []chains.Chain{}) + }, + }, + { + name: "Filter consensus solana chains", + filters: []chains.ChainFilter{chains.FilterByConsensus(chains.Consensus_solana_consensus)}, + expected: func() []chains.Chain { + return chains.ChainListByConsensus(chains.Consensus_solana_consensus, []chains.Chain{}) + }, + }, + { + name: "Apply multiple filters external chains and gateway observer", + filters: []chains.ChainFilter{ + chains.FilterExternalChains, + chains.FilterByGateway(chains.CCTXGateway_observers), + }, + expected: func() []chains.Chain { + externalChains := chains.ExternalChainList([]chains.Chain{}) + var gatewayObserverChains []chains.Chain + for _, chain := range externalChains { + if chain.CctxGateway == chains.CCTXGateway_observers { + gatewayObserverChains = append(gatewayObserverChains, chain) + } + } + return gatewayObserverChains + }, + }, + { + name: "Apply multiple filters external chains with gateway observer and consensus ethereum", + filters: []chains.ChainFilter{ + chains.FilterExternalChains, + chains.FilterByGateway(chains.CCTXGateway_observers), + chains.FilterByConsensus(chains.Consensus_ethereum), + }, + expected: func() []chains.Chain { + externalChains := chains.ExternalChainList([]chains.Chain{}) + var filterMultipleChains []chains.Chain + for _, chain := range externalChains { + if chain.CctxGateway == chains.CCTXGateway_observers && + chain.Consensus == chains.Consensus_ethereum { + filterMultipleChains = append(filterMultipleChains, chain) + } + } + return filterMultipleChains + }, + }, + { + name: "Apply multiple filters external chains with gateway observer and consensus bitcoin", + filters: []chains.ChainFilter{ + chains.FilterExternalChains, + chains.FilterByGateway(chains.CCTXGateway_observers), + chains.FilterByConsensus(chains.Consensus_bitcoin), + }, + expected: func() []chains.Chain { + externalChains := chains.ExternalChainList([]chains.Chain{}) + var filterMultipleChains []chains.Chain + for _, chain := range externalChains { + if chain.CctxGateway == chains.CCTXGateway_observers && + chain.Consensus == chains.Consensus_bitcoin { + filterMultipleChains = append(filterMultipleChains, chain) + } + } + return filterMultipleChains + }, + }, + { + name: "test three same filters", + filters: []chains.ChainFilter{ + chains.FilterExternalChains, + chains.FilterExternalChains, + chains.FilterExternalChains, + }, + expected: func() []chains.Chain { + externalChains := chains.ExternalChainList([]chains.Chain{}) + return externalChains + }, + }, + { + name: "Test multiple filters in random order", + filters: []chains.ChainFilter{ + chains.FilterByGateway(chains.CCTXGateway_observers), + chains.FilterByConsensus(chains.Consensus_ethereum), + chains.FilterExternalChains, + }, + expected: func() []chains.Chain { + externalChains := chains.ExternalChainList([]chains.Chain{}) + var filterMultipleChains []chains.Chain + for _, chain := range externalChains { + if chain.CctxGateway == chains.CCTXGateway_observers && + chain.Consensus == chains.Consensus_ethereum { + filterMultipleChains = append(filterMultipleChains, chain) + } + } + return filterMultipleChains + }, + }, + } + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + chainList := chains.ExternalChainList([]chains.Chain{}) + filteredChains := chains.FilterChains(chainList, tc.filters...) + require.ElementsMatch(t, tc.expected(), filteredChains) + require.Len(t, filteredChains, len(tc.expected())) + }) + } +} + +func TestCombineFilterChains(t *testing.T) { + tt := []struct { + name string + chainLists func() [][]chains.Chain + expected func() []chains.Chain + }{ + { + name: "test support TSS migration filter", + chainLists: func() [][]chains.Chain { + return [][]chains.Chain{ + chains.FilterChains( + chains.ExternalChainList([]chains.Chain{}), + []chains.ChainFilter{ + chains.FilterExternalChains, + chains.FilterByGateway(chains.CCTXGateway_observers), + chains.FilterByConsensus(chains.Consensus_ethereum), + }...), + chains.FilterChains( + chains.ExternalChainList([]chains.Chain{}), + []chains.ChainFilter{ + chains.FilterExternalChains, + chains.FilterByGateway(chains.CCTXGateway_observers), + chains.FilterByConsensus(chains.Consensus_bitcoin), + }...), + } + }, + expected: func() []chains.Chain { + chainList := chains.ExternalChainList([]chains.Chain{}) + var filterMultipleChains []chains.Chain + for _, chain := range chainList { + if chain.CctxGateway == chains.CCTXGateway_observers && + (chain.Consensus == chains.Consensus_ethereum || chain.Consensus == chains.Consensus_bitcoin) { + filterMultipleChains = append(filterMultipleChains, chain) + } + } + return filterMultipleChains + }, + }, + { + name: "test support TSS migration filter with solana", + chainLists: func() [][]chains.Chain { + return [][]chains.Chain{ + chains.FilterChains( + chains.ExternalChainList([]chains.Chain{}), + []chains.ChainFilter{ + chains.FilterExternalChains, + chains.FilterByGateway(chains.CCTXGateway_observers), + chains.FilterByConsensus(chains.Consensus_ethereum), + }...), + chains.FilterChains( + chains.ExternalChainList([]chains.Chain{}), + []chains.ChainFilter{ + chains.FilterExternalChains, + chains.FilterByGateway(chains.CCTXGateway_observers), + chains.FilterByConsensus(chains.Consensus_bitcoin), + }...), + chains.FilterChains( + chains.ExternalChainList([]chains.Chain{}), + []chains.ChainFilter{ + chains.FilterExternalChains, + chains.FilterByGateway(chains.CCTXGateway_observers), + chains.FilterByConsensus(chains.Consensus_solana_consensus), + }...), + } + }, + expected: func() []chains.Chain { + chainList := chains.ExternalChainList([]chains.Chain{}) + var filterMultipleChains []chains.Chain + for _, chain := range chainList { + if chain.CctxGateway == chains.CCTXGateway_observers && + (chain.Consensus == chains.Consensus_ethereum || chain.Consensus == chains.Consensus_bitcoin || chain.Consensus == chains.Consensus_solana_consensus) { + filterMultipleChains = append(filterMultipleChains, chain) + } + } + return filterMultipleChains + }, + }, + } + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + chainLists := tc.chainLists() + combinedChains := chains.CombineFilterChains(chainLists...) + require.ElementsMatch(t, tc.expected(), combinedChains) + }) + } +} diff --git a/pkg/chains/chains.go b/pkg/chains/chains.go index fafd9f7fbf..da449af748 100644 --- a/pkg/chains/chains.go +++ b/pkg/chains/chains.go @@ -413,6 +413,16 @@ func ChainListByConsensus(consensus Consensus, additionalChains []Chain) []Chain return chainList } +func ChainListByGateway(gateway CCTXGateway, additionalChains []Chain) []Chain { + var chainList []Chain + for _, chain := range CombineDefaultChainsList(additionalChains) { + if chain.CctxGateway == gateway { + chainList = append(chainList, chain) + } + } + return chainList +} + // ChainListForHeaderSupport returns a list of chains that support headers func ChainListForHeaderSupport(additionalChains []Chain) []Chain { var chainList []Chain diff --git a/pkg/chains/chains_test.go b/pkg/chains/chains_test.go index a84d0e41d5..12c2f3d7bf 100644 --- a/pkg/chains/chains_test.go +++ b/pkg/chains/chains_test.go @@ -159,6 +159,56 @@ func TestDefaultChainList(t *testing.T) { }, chains.DefaultChainsList()) } +func TestChainListByGateway(t *testing.T) { + listTests := []struct { + name string + gateway chains.CCTXGateway + expected []chains.Chain + }{ + { + "observers", + chains.CCTXGateway_observers, + []chains.Chain{ + chains.BitcoinMainnet, + chains.BscMainnet, + chains.Ethereum, + chains.BitcoinTestnet, + chains.Mumbai, + chains.Amoy, + chains.BscTestnet, + chains.Goerli, + chains.Sepolia, + chains.BitcoinRegtest, + chains.GoerliLocalnet, + chains.Polygon, + chains.OptimismMainnet, + chains.OptimismSepolia, + chains.BaseMainnet, + chains.BaseSepolia, + chains.SolanaMainnet, + chains.SolanaDevnet, + chains.SolanaLocalnet, + }, + }, + { + "zevm", + chains.CCTXGateway_zevm, + []chains.Chain{ + chains.ZetaChainMainnet, + chains.ZetaChainTestnet, + chains.ZetaChainDevnet, + chains.ZetaChainPrivnet, + }, + }, + } + + for _, lt := range listTests { + t.Run(lt.name, func(t *testing.T) { + require.ElementsMatch(t, lt.expected, chains.ChainListByGateway(lt.gateway, []chains.Chain{})) + }) + } +} + func TestExternalChainList(t *testing.T) { require.ElementsMatch(t, []chains.Chain{ chains.BitcoinMainnet, diff --git a/testutil/keeper/mocks/crosschain/observer.go b/testutil/keeper/mocks/crosschain/observer.go index d71673848f..c90c15c3a6 100644 --- a/testutil/keeper/mocks/crosschain/observer.go +++ b/testutil/keeper/mocks/crosschain/observer.go @@ -604,26 +604,6 @@ func (_m *CrosschainObserverKeeper) GetSupportedChains(ctx types.Context) []chai return r0 } -// GetSupportedForeignChains provides a mock function with given fields: ctx -func (_m *CrosschainObserverKeeper) GetSupportedForeignChains(ctx types.Context) []chains.Chain { - ret := _m.Called(ctx) - - if len(ret) == 0 { - panic("no return value specified for GetSupportedForeignChains") - } - - var r0 []chains.Chain - if rf, ok := ret.Get(0).(func(types.Context) []chains.Chain); ok { - r0 = rf(ctx) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]chains.Chain) - } - } - - return r0 -} - // GetTSS provides a mock function with given fields: ctx func (_m *CrosschainObserverKeeper) GetTSS(ctx types.Context) (observertypes.TSS, bool) { ret := _m.Called(ctx) diff --git a/x/crosschain/keeper/grpc_query_cctx_rate_limit.go b/x/crosschain/keeper/grpc_query_cctx_rate_limit.go index f2406c7a13..fa97806ee2 100644 --- a/x/crosschain/keeper/grpc_query_cctx_rate_limit.go +++ b/x/crosschain/keeper/grpc_query_cctx_rate_limit.go @@ -74,7 +74,11 @@ func (k Keeper) RateLimiterInput( } // get foreign chains and conversion rates of foreign coins - chains := k.zetaObserverKeeper.GetSupportedForeignChains(ctx) + externalSupportedChains := chains.FilterChains( + k.GetObserverKeeper().GetSupportedChains(ctx), + chains.FilterExternalChains, + ) + _, assetRates, found := k.GetRateLimiterAssetRateList(ctx) if !found { return nil, status.Error(codes.Internal, "asset rates not found") @@ -84,7 +88,7 @@ func (k Keeper) RateLimiterInput( // query pending nonces of each foreign chain and get the lowest height of the pending cctxs lowestPendingCctxHeight := int64(0) pendingNoncesMap := make(map[int64]observertypes.PendingNonces) - for _, chain := range chains { + for _, chain := range externalSupportedChains { pendingNonces, found := k.GetObserverKeeper().GetPendingNonces(ctx, tss.TssPubkey, chain.ChainId) if !found { return nil, status.Error(codes.Internal, "pending nonces not found") @@ -113,7 +117,7 @@ func (k Keeper) RateLimiterInput( cctxsPending := make([]*types.CrossChainTx, 0) // query backwards for pending cctxs of each foreign chain - for _, chain := range chains { + for _, chain := range externalSupportedChains { // we should at least query 1000 prior to find any pending cctx that we might have missed // this logic is needed because a confirmation of higher nonce will automatically update the p.NonceLow // therefore might mask some lower nonce cctx that is still pending. @@ -205,7 +209,7 @@ func (k Keeper) ListPendingCctxWithinRateLimit( totalPending := uint64(0) totalWithdrawInAzeta := sdkmath.NewInt(0) cctxs := make([]*types.CrossChainTx, 0) - foreignChains := k.zetaObserverKeeper.GetSupportedForeignChains(ctx) + foreignChains := chains.FilterChains(k.zetaObserverKeeper.GetSupportedChains(ctx), chains.FilterExternalChains) // check rate limit flags to decide if we should apply rate limit applyLimit := true diff --git a/x/crosschain/keeper/msg_server_update_tss.go b/x/crosschain/keeper/msg_server_update_tss.go index 78c3b833d7..f68c7e39d6 100644 --- a/x/crosschain/keeper/msg_server_update_tss.go +++ b/x/crosschain/keeper/msg_server_update_tss.go @@ -6,6 +6,7 @@ import ( errorsmod "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/zeta-chain/zetacore/pkg/chains" authoritytypes "github.com/zeta-chain/zetacore/x/authority/types" "github.com/zeta-chain/zetacore/x/crosschain/types" ) @@ -37,8 +38,9 @@ func (k msgServer) UpdateTssAddress( } tssMigrators := k.zetaObserverKeeper.GetAllTssFundMigrators(ctx) + // Each connected chain should have its own tss migrator - if len(k.zetaObserverKeeper.GetSupportedForeignChains(ctx)) != len(tssMigrators) { + if len(k.GetChainsSupportingTSSMigration(ctx)) != len(tssMigrators) { return nil, errorsmod.Wrap( types.ErrUnableToUpdateTss, "cannot update tss address incorrect number of migrations have been created and completed", @@ -70,3 +72,24 @@ func (k msgServer) UpdateTssAddress( return &types.MsgUpdateTssAddressResponse{}, nil } + +// GetChainsSupportingTSSMigration returns the chains that support tss migration. +// Chains that support tss migration are chains that have the following properties: +// 1. External chains +// 2. Gateway observer +// 3. Consensus is bitcoin or ethereum (Other consensus types are not supported) +func (k *Keeper) GetChainsSupportingTSSMigration(ctx sdk.Context) []chains.Chain { + supportedChains := k.zetaObserverKeeper.GetSupportedChains(ctx) + return chains.CombineFilterChains([][]chains.Chain{ + chains.FilterChains(supportedChains, []chains.ChainFilter{ + chains.FilterExternalChains, + chains.FilterByGateway(chains.CCTXGateway_observers), + chains.FilterByConsensus(chains.Consensus_ethereum), + }...), + chains.FilterChains(supportedChains, []chains.ChainFilter{ + chains.FilterExternalChains, + chains.FilterByGateway(chains.CCTXGateway_observers), + chains.FilterByConsensus(chains.Consensus_bitcoin), + }...), + }...) +} diff --git a/x/crosschain/keeper/msg_server_update_tss_test.go b/x/crosschain/keeper/msg_server_update_tss_test.go index ffdbad3c04..6834d6bed1 100644 --- a/x/crosschain/keeper/msg_server_update_tss_test.go +++ b/x/crosschain/keeper/msg_server_update_tss_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/require" + "github.com/zeta-chain/zetacore/pkg/chains" keepertest "github.com/zeta-chain/zetacore/testutil/keeper" "github.com/zeta-chain/zetacore/testutil/sample" @@ -65,7 +66,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { k.GetObserverKeeper().SetTSSHistory(ctx, tssOld) k.GetObserverKeeper().SetTSSHistory(ctx, tssNew) k.GetObserverKeeper().SetTSS(ctx, tssOld) - for _, chain := range k.GetObserverKeeper().GetSupportedForeignChains(ctx) { + for _, chain := range k.GetChainsSupportingTSSMigration(ctx) { index := chain.Name + "_migration_tx_index" k.GetObserverKeeper().SetFundMigrator(ctx, types.TssFundMigratorInfo{ ChainId: chain.ChainId, @@ -78,7 +79,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { require.Equal( t, len(k.GetObserverKeeper().GetAllTssFundMigrators(ctx)), - len(k.GetObserverKeeper().GetSupportedForeignChains(ctx)), + len(k.GetChainsSupportingTSSMigration(ctx)), ) msg := crosschaintypes.MsgUpdateTssAddress{ @@ -109,7 +110,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { k.GetObserverKeeper().SetTSSHistory(ctx, tssOld) k.GetObserverKeeper().SetTSS(ctx, tssOld) - for _, chain := range k.GetObserverKeeper().GetSupportedChains(ctx) { + for _, chain := range k.GetChainsSupportingTSSMigration(ctx) { index := chain.Name + "_migration_tx_index" k.GetObserverKeeper().SetFundMigrator(ctx, types.TssFundMigratorInfo{ ChainId: chain.ChainId, @@ -122,7 +123,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { require.Equal( t, len(k.GetObserverKeeper().GetAllTssFundMigrators(ctx)), - len(k.GetObserverKeeper().GetSupportedChains(ctx)), + len(k.GetChainsSupportingTSSMigration(ctx)), ) msg := crosschaintypes.MsgUpdateTssAddress{ @@ -139,7 +140,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { require.Equal( t, len(k.GetObserverKeeper().GetAllTssFundMigrators(ctx)), - len(k.GetObserverKeeper().GetSupportedChains(ctx)), + len(k.GetChainsSupportingTSSMigration(ctx)), ) }) @@ -156,7 +157,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { k.GetObserverKeeper().SetTSSHistory(ctx, tssOld) k.GetObserverKeeper().SetTSS(ctx, tssOld) - for _, chain := range k.GetObserverKeeper().GetSupportedChains(ctx) { + for _, chain := range k.GetChainsSupportingTSSMigration(ctx) { index := chain.Name + "_migration_tx_index" k.GetObserverKeeper().SetFundMigrator(ctx, types.TssFundMigratorInfo{ ChainId: chain.ChainId, @@ -169,7 +170,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { require.Equal( t, len(k.GetObserverKeeper().GetAllTssFundMigrators(ctx)), - len(k.GetObserverKeeper().GetSupportedChains(ctx)), + len(k.GetChainsSupportingTSSMigration(ctx)), ) msg := crosschaintypes.MsgUpdateTssAddress{ @@ -186,7 +187,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { require.Equal( t, len(k.GetObserverKeeper().GetAllTssFundMigrators(ctx)), - len(k.GetObserverKeeper().GetSupportedChains(ctx)), + len(k.GetChainsSupportingTSSMigration(ctx)), ) }) @@ -207,7 +208,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { setSupportedChain(ctx, zk, getValidEthChainIDWithIndex(t, 0), getValidEthChainIDWithIndex(t, 1)) // set a single migrator while there are 2 supported chains - chain := k.GetObserverKeeper().GetSupportedChains(ctx)[0] + chain := k.GetChainsSupportingTSSMigration(ctx)[0] index := chain.Name + "_migration_tx_index" k.GetObserverKeeper().SetFundMigrator(ctx, types.TssFundMigratorInfo{ ChainId: chain.ChainId, @@ -254,7 +255,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { k.GetObserverKeeper().SetTSS(ctx, tssOld) setSupportedChain(ctx, zk, getValidEthChainIDWithIndex(t, 0), getValidEthChainIDWithIndex(t, 1)) - for _, chain := range k.GetObserverKeeper().GetSupportedChains(ctx) { + for _, chain := range k.GetChainsSupportingTSSMigration(ctx) { index := chain.Name + "_migration_tx_index" k.GetObserverKeeper().SetFundMigrator(ctx, types.TssFundMigratorInfo{ ChainId: chain.ChainId, @@ -301,7 +302,7 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { k.GetObserverKeeper().SetTSS(ctx, tssOld) setSupportedChain(ctx, zk, getValidEthChainIDWithIndex(t, 0), getValidEthChainIDWithIndex(t, 1)) - for _, chain := range k.GetObserverKeeper().GetSupportedChains(ctx) { + for _, chain := range k.GetChainsSupportingTSSMigration(ctx) { index := chain.Name + "_migration_tx_index" k.GetObserverKeeper().SetFundMigrator(ctx, types.TssFundMigratorInfo{ ChainId: chain.ChainId, @@ -329,3 +330,26 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { require.Equal(t, len(k.GetObserverKeeper().GetSupportedChains(ctx)), len(migrators)) }) } + +func TestKeeper_GetChainsSupportingTSSMigration(t *testing.T) { + t.Run("should return only ethereum and bitcoin chains", func(t *testing.T) { + k, ctx, _, zk := keepertest.CrosschainKeeperWithMocks(t, keepertest.CrosschainMockOptions{}) + chainList := chains.ExternalChainList([]chains.Chain{}) + var chainParamsList types.ChainParamsList + for _, chain := range chainList { + chainParamsList.ChainParams = append( + chainParamsList.ChainParams, + sample.ChainParamsSupported(chain.ChainId), + ) + } + zk.ObserverKeeper.SetChainParamsList(ctx, chainParamsList) + + chainsSupportingMigration := k.GetChainsSupportingTSSMigration(ctx) + for _, chain := range chainsSupportingMigration { + require.NotEqual(t, chain.Consensus, chains.Consensus_solana_consensus) + require.NotEqual(t, chain.Consensus, chains.Consensus_op_stack) + require.NotEqual(t, chain.Consensus, chains.Consensus_tendermint) + require.Equal(t, chain.IsExternal, true) + } + }) +} diff --git a/x/crosschain/types/expected_keepers.go b/x/crosschain/types/expected_keepers.go index b66ab1a12c..1422538b2f 100644 --- a/x/crosschain/types/expected_keepers.go +++ b/x/crosschain/types/expected_keepers.go @@ -103,7 +103,6 @@ type ObserverKeeper interface { ) (bool, bool, observertypes.Ballot, string, error) GetSupportedChainFromChainID(ctx sdk.Context, chainID int64) (chains.Chain, bool) GetSupportedChains(ctx sdk.Context) []chains.Chain - GetSupportedForeignChains(ctx sdk.Context) []chains.Chain } type FungibleKeeper interface { diff --git a/x/observer/keeper/chain_params.go b/x/observer/keeper/chain_params.go index 9277010c3c..9dc862dd98 100644 --- a/x/observer/keeper/chain_params.go +++ b/x/observer/keeper/chain_params.go @@ -75,16 +75,3 @@ func (k Keeper) GetSupportedChains(ctx sdk.Context) []chains.Chain { } return c } - -// GetSupportedForeignChains returns the list of supported foreign chains -func (k Keeper) GetSupportedForeignChains(ctx sdk.Context) []chains.Chain { - allChains := k.GetSupportedChains(ctx) - - foreignChains := make([]chains.Chain, 0) - for _, chain := range allChains { - if !chain.IsZetaChain() { - foreignChains = append(foreignChains, chain) - } - } - return foreignChains -}