From 1a24f7336d587480ed91c3bb65541a2fdecc01ae Mon Sep 17 00:00:00 2001 From: Tanmay Date: Thu, 25 Jan 2024 17:19:16 -0500 Subject: [PATCH] add additional checks for unit tests --- common/gas_limits.go | 19 +++++++ .../keeper/msg_server_migrate_tss_funds.go | 14 ++--- .../msg_server_migrate_tss_funds_test.go | 53 ++++++++++--------- x/crosschain/types/keys.go | 2 +- 4 files changed, 52 insertions(+), 36 deletions(-) diff --git a/common/gas_limits.go b/common/gas_limits.go index 0e85ea9f1c..bbd1f0a6c5 100644 --- a/common/gas_limits.go +++ b/common/gas_limits.go @@ -1,7 +1,26 @@ package common +import ( + sdkmath "cosmossdk.io/math" + sdk "github.com/cosmos/cosmos-sdk/types" +) + const ( + // EVMSend is the gas limit required to transfer tokens on an EVM based chain EVMSend = 21000 // TODO: Move gas limits from zeta-client to this file // https://github.com/zeta-chain/node/issues/1606 ) + +// MultiplyGasPrice multiplies the median gas price by the given multiplier and returns the truncated value +func MultiplyGasPrice(medianGasPrice sdkmath.Uint, multiplierString string) (sdkmath.Uint, error) { + multiplier, err := sdk.NewDecFromStr(multiplierString) + if err != nil { + return sdkmath.ZeroUint(), err + } + gasPrice, err := sdk.NewDecFromStr(medianGasPrice.String()) + if err != nil { + return sdkmath.ZeroUint(), err + } + return sdkmath.NewUintFromString(gasPrice.Mul(multiplier).TruncateInt().String()), nil +} diff --git a/x/crosschain/keeper/msg_server_migrate_tss_funds.go b/x/crosschain/keeper/msg_server_migrate_tss_funds.go index 7fb8293a0c..44ab8ee1dc 100644 --- a/x/crosschain/keeper/msg_server_migrate_tss_funds.go +++ b/x/crosschain/keeper/msg_server_migrate_tss_funds.go @@ -119,17 +119,9 @@ func (k Keeper) MigrateTSSFundsForChain(ctx sdk.Context, chainID int64, amount s // Tss migration is a send transaction, so the gas limit is set to 21000 cctx.GetCurrentOutTxParam().OutboundTxGasLimit = common.EVMSend // Multiple current gas price with standard multiplier to add some buffer - multiplier, err := sdk.NewDecFromStr(types.TssMigrationGasMultiplierEVM) - if err != nil { - return err - } - gasPrice, err := sdk.NewDecFromStr(medianGasPrice.String()) - if err != nil { - return err - } - newGasPrice := gasPrice.Mul(multiplier) - cctx.GetCurrentOutTxParam().OutboundTxGasPrice = newGasPrice.TruncateInt().String() - evmFee := sdkmath.NewUint(cctx.GetCurrentOutTxParam().OutboundTxGasLimit).Mul(medianGasPrice) + multipliedGasPrice, err := common.MultiplyGasPrice(medianGasPrice, types.TssMigrationGasMultiplierEVM) + cctx.GetCurrentOutTxParam().OutboundTxGasPrice = multipliedGasPrice.String() + evmFee := sdkmath.NewUint(cctx.GetCurrentOutTxParam().OutboundTxGasLimit).Mul(multipliedGasPrice) if evmFee.GT(amount) { return errorsmod.Wrap(types.ErrInsufficientFundsTssMigration, fmt.Sprintf("insufficient funds to pay for gas fee, amount: %s, gas fee: %s, chainid: %d", amount.String(), evmFee.String(), chainID)) } diff --git a/x/crosschain/keeper/msg_server_migrate_tss_funds_test.go b/x/crosschain/keeper/msg_server_migrate_tss_funds_test.go index 675d27ee64..4d909abdf0 100644 --- a/x/crosschain/keeper/msg_server_migrate_tss_funds_test.go +++ b/x/crosschain/keeper/msg_server_migrate_tss_funds_test.go @@ -6,7 +6,7 @@ import ( sdkmath "cosmossdk.io/math" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/ethereum/go-ethereum/crypto" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/zeta-chain/zetacore/common" keepertest "github.com/zeta-chain/zetacore/testutil/keeper" "github.com/zeta-chain/zetacore/testutil/sample" @@ -25,18 +25,20 @@ func TestKeeper_MigrateTSSFundsForChain(t *testing.T) { amount := sdkmath.NewUintFromString("10000000000000000000") indexString, _ := setupTssMigrationParams(zk, k, ctx, *chain, amount, true, true) gp, found := k.GetMedianGasPriceInUint(ctx, chain.ChainId) - assert.True(t, found) + require.True(t, found) _, err := msgServer.MigrateTssFunds(ctx, &crosschaintypes.MsgMigrateTssFunds{ Creator: admin, ChainId: chain.ChainId, Amount: amount, }) - assert.NoError(t, err) + require.NoError(t, err) hash := crypto.Keccak256Hash([]byte(indexString)) index := hash.Hex() - _, found = k.GetCrossChainTx(ctx, index) - assert.True(t, found) - assert.Equal(t, gp.MulUint64(crosschaintypes.TssMigrationGasMultiplierEVM), k.GetGasPrice(ctx, chain.ChainId).Prices[1]) + cctx, found := k.GetCrossChainTx(ctx, index) + require.True(t, found) + multipliedValue, err := common.MultiplyGasPrice(gp, crosschaintypes.TssMigrationGasMultiplierEVM) + require.NoError(t, err) + require.Equal(t, multipliedValue.String(), cctx.GetCurrentOutTxParam().OutboundTxGasPrice) }) } @@ -54,11 +56,14 @@ func TestMsgServer_MigrateTssFunds(t *testing.T) { ChainId: chain.ChainId, Amount: amount, }) - assert.NoError(t, err) + require.NoError(t, err) hash := crypto.Keccak256Hash([]byte(indexString)) index := hash.Hex() - _, found := k.GetCrossChainTx(ctx, index) - assert.True(t, found) + cctx, found := k.GetCrossChainTx(ctx, index) + require.True(t, found) + feeCalculated := sdk.NewUint(cctx.GetCurrentOutTxParam().OutboundTxGasLimit). + Mul(sdkmath.NewUintFromString(cctx.GetCurrentOutTxParam().OutboundTxGasPrice)) + require.Equal(t, cctx.GetCurrentOutTxParam().Amount.String(), amount.Sub(feeCalculated).String()) }) t.Run("not enough funds in tss address for migration", func(t *testing.T) { k, ctx, _, zk := keepertest.CrosschainKeeper(t) @@ -73,11 +78,11 @@ func TestMsgServer_MigrateTssFunds(t *testing.T) { ChainId: chain.ChainId, Amount: amount, }) - assert.ErrorContains(t, err, crosschaintypes.ErrCannotMigrateTssFunds.Error()) + require.ErrorContains(t, err, crosschaintypes.ErrCannotMigrateTssFunds.Error()) hash := crypto.Keccak256Hash([]byte(indexString)) index := hash.Hex() _, found := k.GetCrossChainTx(ctx, index) - assert.False(t, found) + require.False(t, found) }) t.Run("unable to migrate funds if new TSS is not created ", func(t *testing.T) { k, ctx, _, zk := keepertest.CrosschainKeeper(t) @@ -92,11 +97,11 @@ func TestMsgServer_MigrateTssFunds(t *testing.T) { ChainId: chain.ChainId, Amount: amount, }) - assert.ErrorContains(t, err, "no new tss address has been generated") + require.ErrorContains(t, err, "no new tss address has been generated") hash := crypto.Keccak256Hash([]byte(indexString)) index := hash.Hex() _, found := k.GetCrossChainTx(ctx, index) - assert.False(t, found) + require.False(t, found) }) t.Run("unable to migrate funds when nonce low does not match nonce high", func(t *testing.T) { k, ctx, _, zk := keepertest.CrosschainKeeper(t) @@ -117,12 +122,12 @@ func TestMsgServer_MigrateTssFunds(t *testing.T) { ChainId: chain.ChainId, Amount: amount, }) - assert.ErrorIs(t, err, crosschaintypes.ErrCannotMigrateTssFunds) - assert.ErrorContains(t, err, "cannot migrate funds when there are pending nonces") + require.ErrorIs(t, err, crosschaintypes.ErrCannotMigrateTssFunds) + require.ErrorContains(t, err, "cannot migrate funds when there are pending nonces") hash := crypto.Keccak256Hash([]byte(indexString)) index := hash.Hex() _, found := k.GetCrossChainTx(ctx, index) - assert.False(t, found) + require.False(t, found) }) t.Run("unable to migrate funds when a pending cctx is presnt in migration info", func(t *testing.T) { k, ctx, _, zk := keepertest.CrosschainKeeper(t) @@ -150,14 +155,14 @@ func TestMsgServer_MigrateTssFunds(t *testing.T) { ChainId: chain.ChainId, Amount: amount, }) - assert.ErrorIs(t, err, crosschaintypes.ErrCannotMigrateTssFunds) - assert.ErrorContains(t, err, "cannot migrate funds while there are pending migrations") + require.ErrorIs(t, err, crosschaintypes.ErrCannotMigrateTssFunds) + require.ErrorContains(t, err, "cannot migrate funds while there are pending migrations") hash := crypto.Keccak256Hash([]byte(indexString)) index := hash.Hex() _, found := k.GetCrossChainTx(ctx, index) - assert.False(t, found) + require.False(t, found) _, found = k.GetCrossChainTx(ctx, existingCctx.Index) - assert.True(t, found) + require.True(t, found) }) t.Run("unable to migrate funds if current TSS is not present in TSSHistory and no new TSS has been generated", func(t *testing.T) { @@ -169,7 +174,7 @@ func TestMsgServer_MigrateTssFunds(t *testing.T) { amount := sdkmath.NewUintFromString("10000000000000000000") indexString, _ := setupTssMigrationParams(zk, k, ctx, *chain, amount, false, false) currentTss, found := k.GetObserverKeeper().GetTSS(ctx) - assert.True(t, found) + require.True(t, found) newTss := sample.Tss() newTss.FinalizedZetaHeight = currentTss.FinalizedZetaHeight - 10 newTss.KeyGenZetaHeight = currentTss.KeyGenZetaHeight - 10 @@ -179,12 +184,12 @@ func TestMsgServer_MigrateTssFunds(t *testing.T) { ChainId: chain.ChainId, Amount: amount, }) - assert.ErrorIs(t, err, crosschaintypes.ErrCannotMigrateTssFunds) - assert.ErrorContains(t, err, "current tss is the latest") + require.ErrorIs(t, err, crosschaintypes.ErrCannotMigrateTssFunds) + require.ErrorContains(t, err, "current tss is the latest") hash := crypto.Keccak256Hash([]byte(indexString)) index := hash.Hex() _, found = k.GetCrossChainTx(ctx, index) - assert.False(t, found) + require.False(t, found) }) } func setupTssMigrationParams( diff --git a/x/crosschain/types/keys.go b/x/crosschain/types/keys.go index dbec6d5c9f..4adac4d73f 100644 --- a/x/crosschain/types/keys.go +++ b/x/crosschain/types/keys.go @@ -25,7 +25,7 @@ const ( MemStoreKey = "mem_metacore" ProtocolFee = 2000000000000000000 - + //TssMigrationGasMultiplierEVM is multiplied to the median gas price to get the gas price for the tss migration . This is done to avoid the tss migration tx getting stuck in the mempool TssMigrationGasMultiplierEVM = "2.5" )