From 00ab5acde7a9f40fdef26c38abd573b2c49d60cd Mon Sep 17 00:00:00 2001 From: Charlie Chen Date: Fri, 19 Apr 2024 10:55:11 -0500 Subject: [PATCH] some minimum code refactor --- testutil/keeper/mocks/crosschain/observer.go | 20 ++++ x/crosschain/keeper/cctx_utils.go | 2 +- x/crosschain/keeper/cctx_utils_test.go | 2 +- x/crosschain/keeper/grpc_query_cctx.go | 98 ++++++++----------- .../keeper/msg_server_add_to_outtx_tracker.go | 2 +- x/crosschain/types/expected_keepers.go | 1 + x/observer/keeper/chain_params.go | 13 +++ zetaclient/evm/evm_signer.go | 2 +- 8 files changed, 78 insertions(+), 62 deletions(-) diff --git a/testutil/keeper/mocks/crosschain/observer.go b/testutil/keeper/mocks/crosschain/observer.go index aa05e13226..37b8ea2720 100644 --- a/testutil/keeper/mocks/crosschain/observer.go +++ b/testutil/keeper/mocks/crosschain/observer.go @@ -596,6 +596,26 @@ func (_m *CrosschainObserverKeeper) GetSupportedChains(ctx types.Context) []*cha return r0 } +// GetSupportedForeignChains provides a mock function with given fields: ctx +func (_m *CrosschainObserverKeeper) GetSupportedForeignChains(ctx types.Context) []*chains.Chain { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for GetSupportedForeignChains") + } + + var r0 []*chains.Chain + if rf, ok := ret.Get(0).(func(types.Context) []*chains.Chain); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*chains.Chain) + } + } + + return r0 +} + // GetTSS provides a mock function with given fields: ctx func (_m *CrosschainObserverKeeper) GetTSS(ctx types.Context) (observertypes.TSS, bool) { ret := _m.Called(ctx) diff --git a/x/crosschain/keeper/cctx_utils.go b/x/crosschain/keeper/cctx_utils.go index 6ae826f9ed..6f9651967a 100644 --- a/x/crosschain/keeper/cctx_utils.go +++ b/x/crosschain/keeper/cctx_utils.go @@ -86,7 +86,7 @@ func (k Keeper) GetRevertGasLimit(ctx sdk.Context, cctx types.CrossChainTx) (uin return 0, nil } -func IsPending(cctx types.CrossChainTx) bool { +func IsPending(cctx *types.CrossChainTx) bool { // pending inbound is not considered a "pending" state because it has not reached consensus yet return cctx.CctxStatus.Status == types.CctxStatus_PendingOutbound || cctx.CctxStatus.Status == types.CctxStatus_PendingRevert } diff --git a/x/crosschain/keeper/cctx_utils_test.go b/x/crosschain/keeper/cctx_utils_test.go index 1c807a97c1..e0dedbb41d 100644 --- a/x/crosschain/keeper/cctx_utils_test.go +++ b/x/crosschain/keeper/cctx_utils_test.go @@ -215,7 +215,7 @@ func Test_IsPending(t *testing.T) { } for _, tc := range tt { t.Run(fmt.Sprintf("status %s", tc.status), func(t *testing.T) { - require.Equal(t, tc.expected, crosschainkeeper.IsPending(types.CrossChainTx{CctxStatus: &types.Status{Status: tc.status}})) + require.Equal(t, tc.expected, crosschainkeeper.IsPending(&types.CrossChainTx{CctxStatus: &types.Status{Status: tc.status}})) }) } } diff --git a/x/crosschain/keeper/grpc_query_cctx.go b/x/crosschain/keeper/grpc_query_cctx.go index 1ab7078b1a..95a2336f31 100644 --- a/x/crosschain/keeper/grpc_query_cctx.go +++ b/x/crosschain/keeper/grpc_query_cctx.go @@ -82,16 +82,12 @@ func (k Keeper) CctxByNonce(c context.Context, req *types.QueryGetCctxByNonceReq return nil, status.Error(codes.Internal, "tss not found") } // #nosec G701 always in range - res, found := k.GetObserverKeeper().GetNonceToCctx(ctx, tss.TssPubkey, req.ChainID, int64(req.Nonce)) - if !found { - return nil, status.Error(codes.Internal, fmt.Sprintf("nonceToCctx not found: nonce %d, chainid %d", req.Nonce, req.ChainID)) - } - val, found := k.GetCrossChainTx(ctx, res.CctxIndex) - if !found { - return nil, status.Error(codes.Internal, fmt.Sprintf("cctx not found: index %s", res.CctxIndex)) + cctx, err := getCctxByChainIDAndNonce(k, ctx, tss.TssPubkey, req.ChainID, int64(req.Nonce)) + if err != nil { + return nil, err } - return &types.QueryGetCctxResponse{CrossChainTx: &val}, nil + return &types.QueryGetCctxResponse{CrossChainTx: cctx}, nil } // CctxListPending returns a list of pending cctxs and the total number of pending cctxs @@ -138,20 +134,16 @@ func (k Keeper) CctxListPending(c context.Context, req *types.QueryListCctxPendi startNonce = 0 } for i := startNonce; i < pendingNonces.NonceLow; i++ { - nonceToCctx, found := k.GetObserverKeeper().GetNonceToCctx(ctx, tss.TssPubkey, req.ChainId, i) - if !found { - return nil, status.Error(codes.Internal, fmt.Sprintf("nonceToCctx not found: nonce %d, chainid %d", i, req.ChainId)) - } - cctx, found := k.GetCrossChainTx(ctx, nonceToCctx.CctxIndex) - if !found { - return nil, status.Error(codes.Internal, fmt.Sprintf("cctx not found: index %s", nonceToCctx.CctxIndex)) + cctx, err := getCctxByChainIDAndNonce(k, ctx, tss.TssPubkey, req.ChainId, i) + if err != nil { + return nil, err } // only take a `limit` number of pending cctxs as result but still count the total pending cctxs if IsPending(cctx) { totalPending++ if !maxCCTXsReached() { - cctxs = append(cctxs, &cctx) + cctxs = append(cctxs, cctx) } } } @@ -162,15 +154,11 @@ func (k Keeper) CctxListPending(c context.Context, req *types.QueryListCctxPendi // now query the pending nonces that we know are pending for i := pendingNonces.NonceLow; i < pendingNonces.NonceHigh && !maxCCTXsReached(); i++ { - nonceToCctx, found := k.GetObserverKeeper().GetNonceToCctx(ctx, tss.TssPubkey, req.ChainId, i) - if !found { - return nil, status.Error(codes.Internal, "nonceToCctx not found") + cctx, err := getCctxByChainIDAndNonce(k, ctx, tss.TssPubkey, req.ChainId, i) + if err != nil { + return nil, err } - cctx, found := k.GetCrossChainTx(ctx, nonceToCctx.CctxIndex) - if !found { - return nil, status.Error(codes.Internal, "cctxIndex not found") - } - cctxs = append(cctxs, &cctx) + cctxs = append(cctxs, cctx) } return &types.QueryListCctxPendingResponse{ @@ -243,13 +231,10 @@ func (k Keeper) CctxListPendingWithinRateLimit(c context.Context, req *types.Que return uint32(len(cctxs)) >= limit } - // query pending nonces for each supported chain + // query pending nonces for each foreign chain // Note: The pending nonces could change during the RPC call, so query them beforehand - chains := k.zetaObserverKeeper.GetSupportedChains(ctx) + chains := k.zetaObserverKeeper.GetSupportedForeignChains(ctx) for _, chain := range chains { - if chain.IsZetaChain() { - continue - } pendingNonces, found := k.GetObserverKeeper().GetPendingNonces(ctx, tss.TssPubkey, chain.ChainId) if !found { return nil, status.Error(codes.Internal, "pending nonces not found") @@ -257,13 +242,9 @@ func (k Keeper) CctxListPendingWithinRateLimit(c context.Context, req *types.Que pendingNoncesMap[chain.ChainId] = &pendingNonces } - // query backwards for potential missed pending cctxs for each supported chain + // query backwards for potential missed pending cctxs for each foreign chain LoopBackwards: for _, chain := range chains { - if chain.IsZetaChain() { - continue - } - // we should at least query 1000 prior to find any pending cctx that we might have missed // this logic is needed because a confirmation of higher nonce will automatically update the p.NonceLow // therefore might mask some lower nonce cctx that is still pending. @@ -276,13 +257,9 @@ LoopBackwards: // query cctx by nonce backwards to the left boundary of the rate limit sliding window for nonce := startNonce; nonce >= 0; nonce-- { - nonceToCctx, found := k.GetObserverKeeper().GetNonceToCctx(ctx, tss.TssPubkey, chain.ChainId, nonce) - if !found { - return nil, status.Error(codes.Internal, fmt.Sprintf("nonceToCctx not found: chainid %d, nonce %d", chain.ChainId, nonce)) - } - cctx, found := k.GetCrossChainTx(ctx, nonceToCctx.CctxIndex) - if !found { - return nil, status.Error(codes.Internal, fmt.Sprintf("cctx not found: index %s", nonceToCctx.CctxIndex)) + cctx, err := getCctxByChainIDAndNonce(k, ctx, tss.TssPubkey, chain.ChainId, nonce) + if err != nil { + return nil, err } // We should at least go backwards by 1000 nonces to pick up missed pending cctxs @@ -299,7 +276,7 @@ LoopBackwards: break } // criteria #3: if rate limiter is enabled, we should finish the RPC call if the rate limit is exceeded - if rateLimitExceeded(chain.ChainId, &cctx, gasCoinRates, erc20CoinRates, erc20Coins, &totalCctxValueInZeta, rateLimitInZeta) { + if rateLimitExceeded(chain.ChainId, cctx, gasCoinRates, erc20CoinRates, erc20Coins, &totalCctxValueInZeta, rateLimitInZeta) { limitExceeded = true break LoopBackwards } @@ -309,7 +286,7 @@ LoopBackwards: if IsPending(cctx) { totalPending++ if !maxCCTXsReached() { - cctxs = append(cctxs, &cctx) + cctxs = append(cctxs, cctx) } } } @@ -321,23 +298,15 @@ LoopBackwards: totalPending += uint64(pendingNonces.NonceHigh - pendingNonces.NonceLow) } - // query forwards for pending cctxs for each supported chain + // query forwards for pending cctxs for each foreign chain LoopForwards: for _, chain := range chains { - if chain.IsZetaChain() { - continue - } - // query the pending cctxs in range [NonceLow, NonceHigh) pendingNonces := pendingNoncesMap[chain.ChainId] - for i := pendingNonces.NonceLow; i < pendingNonces.NonceHigh; i++ { - nonceToCctx, found := k.GetObserverKeeper().GetNonceToCctx(ctx, tss.TssPubkey, chain.ChainId, i) - if !found { - return nil, status.Error(codes.Internal, "nonceToCctx not found") - } - cctx, found := k.GetCrossChainTx(ctx, nonceToCctx.CctxIndex) - if !found { - return nil, status.Error(codes.Internal, "cctxIndex not found") + for nonce := pendingNonces.NonceLow; nonce < pendingNonces.NonceHigh; nonce++ { + cctx, err := getCctxByChainIDAndNonce(k, ctx, tss.TssPubkey, chain.ChainId, nonce) + if err != nil { + return nil, err } // only take a `limit` number of pending cctxs as result @@ -345,11 +314,11 @@ LoopForwards: break LoopForwards } // criteria #3: if rate limiter is enabled, we should finish the RPC call if the rate limit is exceeded - if applyLimit && rateLimitExceeded(chain.ChainId, &cctx, gasCoinRates, erc20CoinRates, erc20Coins, &totalCctxValueInZeta, rateLimitInZeta) { + if applyLimit && rateLimitExceeded(chain.ChainId, cctx, gasCoinRates, erc20CoinRates, erc20Coins, &totalCctxValueInZeta, rateLimitInZeta) { limitExceeded = true break LoopForwards } - cctxs = append(cctxs, &cctx) + cctxs = append(cctxs, cctx) } } @@ -360,6 +329,19 @@ LoopForwards: }, nil } +// getCctxByChainIDAndNonce returns the cctx by chainID and nonce +func getCctxByChainIDAndNonce(k Keeper, ctx sdk.Context, tssPubkey string, chainID int64, nonce int64) (*types.CrossChainTx, error) { + nonceToCctx, found := k.GetObserverKeeper().GetNonceToCctx(ctx, tssPubkey, chainID, nonce) + if !found { + return nil, status.Error(codes.Internal, fmt.Sprintf("nonceToCctx not found: chainid %d, nonce %d", chainID, nonce)) + } + cctx, found := k.GetCrossChainTx(ctx, nonceToCctx.CctxIndex) + if !found { + return nil, status.Error(codes.Internal, fmt.Sprintf("cctx not found: index %s", nonceToCctx.CctxIndex)) + } + return &cctx, nil +} + // convertCctxValue converts the value of the cctx in ZETA using given conversion rates func convertCctxValue( chainID int64, diff --git a/x/crosschain/keeper/msg_server_add_to_outtx_tracker.go b/x/crosschain/keeper/msg_server_add_to_outtx_tracker.go index 9031d0bfe5..4f149a901f 100644 --- a/x/crosschain/keeper/msg_server_add_to_outtx_tracker.go +++ b/x/crosschain/keeper/msg_server_add_to_outtx_tracker.go @@ -41,7 +41,7 @@ func (k msgServer) AddToOutTxTracker(goCtx context.Context, msg *types.MsgAddToO } // tracker submission is only allowed when the cctx is pending - if !IsPending(*cctx.CrossChainTx) { + if !IsPending(cctx.CrossChainTx) { // garbage tracker (for any reason) is harmful to outTx observation and should be removed if it exists // it if does not exist, RemoveOutTxTracker is a no-op k.RemoveOutTxTracker(ctx, msg.ChainId, msg.Nonce) diff --git a/x/crosschain/types/expected_keepers.go b/x/crosschain/types/expected_keepers.go index cbcaede349..3d524a91cb 100644 --- a/x/crosschain/types/expected_keepers.go +++ b/x/crosschain/types/expected_keepers.go @@ -90,6 +90,7 @@ type ObserverKeeper interface { ) (bool, bool, observertypes.Ballot, string, error) GetSupportedChainFromChainID(ctx sdk.Context, chainID int64) *chains.Chain GetSupportedChains(ctx sdk.Context) []*chains.Chain + GetSupportedForeignChains(ctx sdk.Context) []*chains.Chain } type FungibleKeeper interface { diff --git a/x/observer/keeper/chain_params.go b/x/observer/keeper/chain_params.go index 7b5e0a246e..8c4d7301fb 100644 --- a/x/observer/keeper/chain_params.go +++ b/x/observer/keeper/chain_params.go @@ -71,3 +71,16 @@ func (k Keeper) GetSupportedChains(ctx sdk.Context) []*chains.Chain { } return c } + +// GetSupportedForeignChains returns the list of supported foreign chains +func (k Keeper) GetSupportedForeignChains(ctx sdk.Context) []*chains.Chain { + allChains := k.GetSupportedChains(ctx) + + foreignChains := make([]*chains.Chain, 0) + for _, chain := range allChains { + if !chain.IsZetaChain() { + foreignChains = append(foreignChains, chain) + } + } + return foreignChains +} diff --git a/zetaclient/evm/evm_signer.go b/zetaclient/evm/evm_signer.go index e077d19fe2..cd038f2912 100644 --- a/zetaclient/evm/evm_signer.go +++ b/zetaclient/evm/evm_signer.go @@ -560,7 +560,7 @@ func (signer *Signer) reportToOutTxTracker(zetaBridge interfaces.ZetaCoreBridger cctx, err := zetaBridge.GetCctxByNonce(chainID, nonce) if err != nil { logger.Err(err).Msgf("reportToOutTxTracker: error getting cctx for chain %d nonce %d outTxHash %s", chainID, nonce, outTxHash) - } else if !crosschainkeeper.IsPending(*cctx) { + } else if !crosschainkeeper.IsPending(cctx) { logger.Info().Msgf("reportToOutTxTracker: cctx already finalized for chain %d nonce %d outTxHash %s", chainID, nonce, outTxHash) break }