Skip to content

Commit

Permalink
added unit tests for query pending cctxs within rate limit
Browse files Browse the repository at this point in the history
  • Loading branch information
ws4charlie committed Apr 20, 2024
1 parent 874d5f3 commit 16955b8
Show file tree
Hide file tree
Showing 2 changed files with 203 additions and 42 deletions.
48 changes: 33 additions & 15 deletions x/crosschain/keeper/grpc_query_cctx_rate_limit.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package keeper

import (
"context"
"sort"
"strings"

sdk "github.com/cosmos/cosmos-sdk/types"
Expand Down Expand Up @@ -90,6 +91,12 @@ func (k Keeper) ListPendingCctxWithinRateLimit(c context.Context, req *types.Que
return uint32(len(cctxs)) >= limit
}

// if a cctx falls within the rate limiter window
isCctxInWindow := func(cctx *types.CrossChainTx) bool {
// #nosec G701 checked positive
return cctx.InboundTxParams.InboundTxObservedExternalHeight >= uint64(leftWindowBoundary)
}

// query pending nonces for each foreign chain
// Note: The pending nonces could change during the RPC call, so query them beforehand
pendingNoncesMap := make(map[int64]*observertypes.PendingNonces)
Expand All @@ -114,23 +121,29 @@ LoopBackwards:
endNonce = 0
}

// add the pending nonces to the total pending
// Note: the `totalPending` may not be accurate only if the rate limiter triggers early exit
// `totalPending` is now used for metrics only and it's okay to trade off accuracy for performance
// #nosec G701 always in range
totalPending += uint64(pendingNonces.NonceHigh - pendingNonces.NonceLow)

// query cctx by nonce backwards to the left boundary of the rate limit sliding window
for nonce := startNonce; nonce >= 0; nonce-- {
cctx, err := getCctxByChainIDAndNonce(k, ctx, tss.TssPubkey, chain.ChainId, nonce)
if err != nil {
return nil, err
}
inWindow := isCctxInWindow(cctx)

// We should at least go backwards by 1000 nonces to pick up missed pending cctxs
// We might go even further back if rate limiter is enabled and the endNonce hasn't hit the left window boundary yet
// There are two criteria to stop scanning backwards:
// criteria #1: we'll stop at the left window boundary if the `endNonce` hasn't hit it yet
// #nosec G701 always positive
if nonce < endNonce && cctx.InboundTxParams.InboundTxObservedExternalHeight < uint64(leftWindowBoundary) {
if nonce < endNonce && !inWindow {
break
}
// criteria #2: we should finish the RPC call if the rate limit is exceeded
if rateLimitExceeded(chain.ChainId, cctx, gasCoinRates, erc20CoinRates, foreignCoinMap, &totalCctxValueInZeta, rateLimitInZeta) {
if inWindow && rateLimitExceeded(chain.ChainId, cctx, gasCoinRates, erc20CoinRates, foreignCoinMap, &totalCctxValueInZeta, rateLimitInZeta) {
limitExceeded = true
break LoopBackwards
}
Expand All @@ -143,12 +156,6 @@ LoopBackwards:
}
}
}

// add the pending nonces to the total pending
// Note: the `totalPending` may not be accurate only if the rate limiter triggers early exit
// `totalPending` is now used for metrics only and it's okay to trade off accuracy for performance
// #nosec G701 always in range
totalPending += uint64(pendingNonces.NonceHigh - pendingNonces.NonceLow)
}

// query forwards for pending cctxs for each foreign chain
Expand All @@ -161,20 +168,29 @@ LoopForwards:
if err != nil {
return nil, err
}
inWindow := isCctxInWindow(cctx)

// only take a `limit` number of pending cctxs as result
if maxCCTXsReached() {
break LoopForwards
}
// criteria #2: we should finish the RPC call if the rate limit is exceeded
if rateLimitExceeded(chain.ChainId, cctx, gasCoinRates, erc20CoinRates, foreignCoinMap, &totalCctxValueInZeta, rateLimitInZeta) {
if inWindow && rateLimitExceeded(chain.ChainId, cctx, gasCoinRates, erc20CoinRates, foreignCoinMap, &totalCctxValueInZeta, rateLimitInZeta) {
limitExceeded = true
break LoopForwards
}
cctxs = append(cctxs, cctx)
}
}

// sort the cctxs by chain ID and nonce (lower nonce holds higher priority for scheduling)
sort.Slice(cctxs, func(i, j int) bool {
if cctxs[i].GetCurrentOutTxParam().ReceiverChainId == cctxs[j].GetCurrentOutTxParam().ReceiverChainId {
return cctxs[i].GetCurrentOutTxParam().OutboundTxTssNonce < cctxs[j].GetCurrentOutTxParam().OutboundTxTssNonce
}
return cctxs[i].GetCurrentOutTxParam().ReceiverChainId < cctxs[j].GetCurrentOutTxParam().ReceiverChainId
})

return &types.QueryListPendingCctxWithinRateLimitResponse{
CrossChainTx: cctxs,
TotalPending: totalPending,
Expand Down Expand Up @@ -229,15 +245,17 @@ func ConvertCctxValue(
}
decimals = uint64(fCoin.Decimals)

// the reciprocal of `rate` is the amount of zrc20 needed to buy 1 ZETA
// for example, given rate = 0.8, the reciprocal is 1.25, which means 1.25 ZRC20 can buy 1 ZETA
// given decimals = 6, the `oneZeta` amount will be 1.25 * 10^6 = 1250000
// given decimals = 6, the `oneZrc20` amount will be 10^6 = 1000000
oneZrc20 := sdk.NewDec(10).Power(decimals)
oneZeta := oneZrc20.Quo(rate)

// convert asset amount into ZETA
// step 1: convert the amount into ZRC20 integer amount
// step 2: convert the ZRC20 integer amount into decimal amount
// given amountCctx = 2000000, rate = 0.8, decimals = 6
// the amountZrc20 = 2000000 * 0.8 = 1600000, the amountZeta = 1600000 / 1000000 = 1.6
amountCctx := sdk.NewDecFromBigInt(cctx.GetCurrentOutTxParam().Amount.BigInt())
amountZeta := amountCctx.Quo(oneZeta)
amountZrc20 := amountCctx.Mul(rate)
amountZeta := amountZrc20.Quo(oneZrc20)
return amountZeta
}

Expand Down
197 changes: 170 additions & 27 deletions x/crosschain/keeper/grpc_query_cctx_rate_limit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package keeper_test

import (
"fmt"
"sort"
"strings"
"testing"

Expand All @@ -28,8 +27,8 @@ func createTestRateLimiterFlags(
) types.RateLimiterFlags {
var rateLimiterFlags = types.RateLimiterFlags{
Enabled: true,
Window: 100, // 100 zeta blocks, 10 minutes
Rate: math.NewUint(1000), // 1000 ZETA
Window: 500, // 500 zeta blocks, 50 minutes
Rate: math.NewUint(5000), // 5000 ZETA
Conversions: []types.Conversion{
// ETH
{
Expand All @@ -51,11 +50,11 @@ func createTestRateLimiterFlags(
return rateLimiterFlags
}

// createCctxWithCopyTypeAndBlockRange
// createCctxsWithCoinTypeAndHeightRange
// - create 1 cctx per block from lowBlock to highBlock (inclusive)
//
// return created cctxs
func createCctxWithCopyTypeAndHeightRange(
func createCctxsWithCoinTypeAndHeightRange(
t *testing.T,
ctx sdk.Context,
k keeper.Keeper,
Expand All @@ -78,6 +77,7 @@ func createCctxWithCopyTypeAndHeightRange(
cctx.InboundTxParams.CoinType = coinType
cctx.InboundTxParams.Asset = asset
cctx.InboundTxParams.InboundTxObservedExternalHeight = i
cctx.GetCurrentOutTxParam().ReceiverChainId = chainID
cctx.GetCurrentOutTxParam().Amount = sdk.NewUint(amount)
cctx.GetCurrentOutTxParam().OutboundTxTssNonce = nonce
k.SetCrossChainTx(ctx, *cctx)
Expand Down Expand Up @@ -124,6 +124,59 @@ func setupForeignCoins(
}
}

// createKeeperForRateLimiterTest creates a keeper filled with cctxs for rate limiter test
func createKeeperForRateLimiterTest(t *testing.T) (k *keeper.Keeper, ctx sdk.Context, cctxsETH, cctxsBTC []*types.CrossChainTx, rateLimiterFlags types.RateLimiterFlags) {
// chain IDs
ethChainID := getValidEthChainID()
btcChainID := getValidBtcChainID()

// zrc20 addresses for ETH, BTC, USDT and asset for USDT
zrc20ETH := sample.EthAddress().Hex()
zrc20BTC := sample.EthAddress().Hex()
zrc20USDT := sample.EthAddress().Hex()
assetUSDT := sample.EthAddress().Hex()

// create test rate limiter flags
rateLimiterFlags = createTestRateLimiterFlags(zrc20ETH, zrc20BTC, zrc20USDT, "2500", "50000", "0.8")

// define cctx status
statusPending := types.CctxStatus_PendingOutbound
statusMined := types.CctxStatus_OutboundMined

// create test keepers
k, ctx, _, zk := keepertest.CrosschainKeeper(t)

// Set TSS
tss := sample.Tss()
zk.ObserverKeeper.SetTSS(ctx, tss)

// Set foreign coins
setupForeignCoins(t, ctx, zk, zrc20ETH, zrc20BTC, zrc20USDT, assetUSDT)

// Set rate limiter flags
k.SetRateLimiterFlags(ctx, rateLimiterFlags)

// Create cctxs [0~999] and [1000~1199] for Eth chain, 0.001 ETH (2.5 ZETA) per cctx
createCctxsWithCoinTypeAndHeightRange(t, ctx, *k, zk, tss, 1, 1000, ethChainID, coin.CoinType_Gas, "", uint64(1e15), statusMined)
cctxsETH = createCctxsWithCoinTypeAndHeightRange(t, ctx, *k, zk, tss, 1001, 1200, ethChainID, coin.CoinType_Gas, "", uint64(1e15), statusPending)

// Set Eth chain pending nonces, [1000~1099] are missed cctxs
setPendingNonces(ctx, zk, ethChainID, 1100, 1200, tss.TssPubkey)

// Create cctxs [0~999] and [1000~1199] for Btc chain, 0.00001 BTC (0.5 ZETA) per cctx
createCctxsWithCoinTypeAndHeightRange(t, ctx, *k, zk, tss, 1, 1000, btcChainID, coin.CoinType_Gas, "", 1000, statusMined)
cctxsBTC = createCctxsWithCoinTypeAndHeightRange(t, ctx, *k, zk, tss, 1001, 1200, btcChainID, coin.CoinType_Gas, "", 1000, statusPending)
require.NotNil(t, cctxsBTC)

// Set Btc chain pending nonces, [1000~1099] are missed cctxs
setPendingNonces(ctx, zk, btcChainID, 1100, 1200, tss.TssPubkey)

// Set current block height to 1201, the window is now [701, 1201], the nonces [700~1200] fall into the window
ctx = ctx.WithBlockHeight(1201)

return k, ctx, cctxsETH, cctxsBTC, rateLimiterFlags
}

func Test_ConvertCctxValue(t *testing.T) {
// chain IDs
ethChainID := getValidEthChainID()
Expand Down Expand Up @@ -265,18 +318,46 @@ func Test_ConvertCctxValue(t *testing.T) {
}

func TestKeeper_ListPendingCctxWithinRateLimit(t *testing.T) {
// chain IDs
ethChainID := getValidEthChainID()

// define cctx status
statusPending := types.CctxStatus_PendingOutbound
statusMined := types.CctxStatus_OutboundMined

t.Run("should fail for empty req", func(t *testing.T) {
k, ctx, _, _ := keepertest.CrosschainKeeper(t)
_, err := k.ListPendingCctxWithinRateLimit(ctx, nil)
require.ErrorContains(t, err, "invalid request")
})
t.Run("height out of range", func(t *testing.T) {
k, ctx, _, _ := keepertest.CrosschainKeeper(t)

// Set rate limiter flags as disabled
rFlags := sample.RateLimiterFlags()
k.SetRateLimiterFlags(ctx, rFlags)

ctx = ctx.WithBlockHeight(0)
_, err := k.ListPendingCctxWithinRateLimit(ctx, &types.QueryListPendingCctxWithinRateLimitRequest{})
require.ErrorContains(t, err, "height out of range")
})
t.Run("tss not found", func(t *testing.T) {
k, ctx, _, _ := keepertest.CrosschainKeeper(t)

// Set rate limiter flags as disabled
rFlags := sample.RateLimiterFlags()
k.SetRateLimiterFlags(ctx, rFlags)

_, err := k.ListPendingCctxWithinRateLimit(ctx, &types.QueryListPendingCctxWithinRateLimitRequest{})
require.ErrorContains(t, err, "tss not found")
})
t.Run("pending nonces not found", func(t *testing.T) {
k, ctx, _, zk := keepertest.CrosschainKeeper(t)

// Set rate limiter flags as disabled
rFlags := sample.RateLimiterFlags()
k.SetRateLimiterFlags(ctx, rFlags)

// Set TSS
tss := sample.Tss()
zk.ObserverKeeper.SetTSS(ctx, tss)

_, err := k.ListPendingCctxWithinRateLimit(ctx, &types.QueryListPendingCctxWithinRateLimitRequest{})
require.ErrorContains(t, err, "pending nonces not found")
})
t.Run("should use fallback query", func(t *testing.T) {
k, ctx, _, zk := keepertest.CrosschainKeeper(t)

Expand All @@ -285,13 +366,14 @@ func TestKeeper_ListPendingCctxWithinRateLimit(t *testing.T) {
zk.ObserverKeeper.SetTSS(ctx, tss)

// Set rate limiter flags as disabled
rateLimiterFlags := sample.RateLimiterFlags()
rateLimiterFlags.Enabled = false
k.SetRateLimiterFlags(ctx, rateLimiterFlags)
rFlags := sample.RateLimiterFlags()
rFlags.Enabled = false
k.SetRateLimiterFlags(ctx, rFlags)

// Create cctxs [0~999] and [1000~1199] for Eth chain, 0.001 ETH per cctx
_ = createCctxWithCopyTypeAndHeightRange(t, ctx, *k, zk, tss, 1, 1000, ethChainID, coin.CoinType_Gas, "", uint64(1e15), statusMined)
cctxETH := createCctxWithCopyTypeAndHeightRange(t, ctx, *k, zk, tss, 1001, 1200, ethChainID, coin.CoinType_Gas, "", uint64(1e15), statusPending)
ethChainID := getValidEthChainID()
_ = createCctxsWithCoinTypeAndHeightRange(t, ctx, *k, zk, tss, 1, 1000, ethChainID, coin.CoinType_Gas, "", uint64(1e15), types.CctxStatus_OutboundMined)
cctxsETH := createCctxsWithCoinTypeAndHeightRange(t, ctx, *k, zk, tss, 1001, 1200, ethChainID, coin.CoinType_Gas, "", uint64(1e15), types.CctxStatus_PendingOutbound)

// Set Eth chain pending nonces which contains 100 missing cctxs
setPendingNonces(ctx, zk, ethChainID, 1100, 1200, tss.TssPubkey)
Expand All @@ -301,23 +383,84 @@ func TestKeeper_ListPendingCctxWithinRateLimit(t *testing.T) {
require.NoError(t, err)
require.Equal(t, 100, len(res.CrossChainTx))

// sort res.CrossChainTx by outbound nonce ascending so that we can compare with cctxETH
sort.Slice(res.CrossChainTx, func(i, j int) bool {
return res.CrossChainTx[i].GetCurrentOutTxParam().OutboundTxTssNonce < res.CrossChainTx[j].GetCurrentOutTxParam().OutboundTxTssNonce
})
require.EqualValues(t, cctxETH[:100], res.CrossChainTx)
// check response
require.EqualValues(t, cctxsETH[:100], res.CrossChainTx)
require.EqualValues(t, uint64(200), res.TotalPending)

// Query pending cctxs use max limit
res, err = k.ListPendingCctxWithinRateLimit(ctx, &types.QueryListPendingCctxWithinRateLimitRequest{Limit: keeper.MaxPendingCctxs})
require.NoError(t, err)
require.Equal(t, 200, len(res.CrossChainTx))

// sort res.CrossChainTx by outbound nonce ascending so that we can compare with cctxETH
sort.Slice(res.CrossChainTx, func(i, j int) bool {
return res.CrossChainTx[i].GetCurrentOutTxParam().OutboundTxTssNonce < res.CrossChainTx[j].GetCurrentOutTxParam().OutboundTxTssNonce
})
require.EqualValues(t, cctxETH, res.CrossChainTx)
// check response
require.EqualValues(t, cctxsETH, res.CrossChainTx)
require.EqualValues(t, uint64(200), res.TotalPending)
})
t.Run("can retrieve pending cctx in range without exceeding rate limit", func(t *testing.T) {
k, ctx, cctxsETH, cctxsBTC, _ := createKeeperForRateLimiterTest(t)

res, err := k.ListPendingCctxWithinRateLimit(ctx, &types.QueryListPendingCctxWithinRateLimitRequest{})
require.NoError(t, err)
require.Equal(t, 400, len(res.CrossChainTx))
require.EqualValues(t, cctxsETH, res.CrossChainTx[0:200])
require.EqualValues(t, cctxsBTC, res.CrossChainTx[200:400])
require.EqualValues(t, uint64(400), res.TotalPending)
require.False(t, res.RateLimitExceeded)
})
t.Run("Set rate to a lower value (< 1200) to early break the LoopBackwards with criteria #2", func(t *testing.T) {
k, ctx, cctxsETH, cctxsBTC, rlFlags := createKeeperForRateLimiterTest(t)

rlFlags.Rate = math.NewUint(1000) // 1000 ZETA
k.SetRateLimiterFlags(ctx, rlFlags)

res, err := k.ListPendingCctxWithinRateLimit(ctx, &types.QueryListPendingCctxWithinRateLimitRequest{})
require.NoError(t, err)
require.Equal(t, 200, len(res.CrossChainTx))
require.EqualValues(t, cctxsETH[:100], res.CrossChainTx[0:100])
require.EqualValues(t, cctxsBTC[:100], res.CrossChainTx[100:200])
require.EqualValues(t, uint64(400), res.TotalPending)
require.True(t, res.RateLimitExceeded)
})
t.Run("Set high rate and big window to early to break inner loop with the criteria #1", func(t *testing.T) {
k, ctx, cctxsETH, cctxsBTC, rlFlags := createKeeperForRateLimiterTest(t)

// The left boundary will be 51 (1201-1150), less than the endNonce 100 (1100 - 10000)
rlFlags.Rate = math.NewUint(10000)
rlFlags.Window = 1150
k.SetRateLimiterFlags(ctx, rlFlags)

res, err := k.ListPendingCctxWithinRateLimit(ctx, &types.QueryListPendingCctxWithinRateLimitRequest{})
require.NoError(t, err)
require.Equal(t, 400, len(res.CrossChainTx))
require.EqualValues(t, cctxsETH, res.CrossChainTx[0:200])
require.EqualValues(t, cctxsBTC, res.CrossChainTx[200:400])
require.EqualValues(t, uint64(400), res.TotalPending)
require.False(t, res.RateLimitExceeded)
})
t.Run("Set lower request limit to early break the LoopForwards loop", func(t *testing.T) {
k, ctx, cctxsETH, cctxsBTC, _ := createKeeperForRateLimiterTest(t)

res, err := k.ListPendingCctxWithinRateLimit(ctx, &types.QueryListPendingCctxWithinRateLimitRequest{Limit: 300})
require.NoError(t, err)
require.Equal(t, 300, len(res.CrossChainTx))
require.EqualValues(t, cctxsETH[:100], res.CrossChainTx[0:100])
require.EqualValues(t, cctxsBTC, res.CrossChainTx[100:300])
require.EqualValues(t, uint64(400), res.TotalPending)
require.False(t, res.RateLimitExceeded)
})
t.Run("Set rate to middle value (1200 < rate < 1500) to early break the LoopForwards loop with criteria #2", func(t *testing.T) {
k, ctx, cctxsETH, cctxsBTC, rlFlags := createKeeperForRateLimiterTest(t)

rlFlags.Window = 500
rlFlags.Rate = math.NewUint(1300) // 1300 ZETA
k.SetRateLimiterFlags(ctx, rlFlags)

res, err := k.ListPendingCctxWithinRateLimit(ctx, &types.QueryListPendingCctxWithinRateLimitRequest{})
require.NoError(t, err)
require.Equal(t, 320, len(res.CrossChainTx)) // 120 ETH cctx + 200 BTC cctx
require.EqualValues(t, cctxsETH[:120], res.CrossChainTx[0:120])
require.EqualValues(t, cctxsBTC, res.CrossChainTx[120:320])
require.EqualValues(t, uint64(400), res.TotalPending)
require.True(t, res.RateLimitExceeded)
})
}

0 comments on commit 16955b8

Please sign in to comment.