Skip to content

Commit

Permalink
use outbound amount for refunds instead of inbound
Browse files Browse the repository at this point in the history
  • Loading branch information
kingpinXD committed Feb 9, 2024
1 parent 7931648 commit d1ff261
Show file tree
Hide file tree
Showing 7 changed files with 213 additions and 41 deletions.
2 changes: 1 addition & 1 deletion x/crosschain/keeper/cctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (k Keeper) SetCctxAndNonceToCctxAndInTxHashToCctx(ctx sdk.Context, cctx typ
})
}
if cctx.CctxStatus.Status == types.CctxStatus_Aborted && cctx.GetCurrentOutTxParam().CoinType == common.CoinType_Zeta {
k.AddZetaAbortedAmount(ctx, cctx.GetCurrentOutTxParam().Amount)
k.AddZetaAbortedAmount(ctx, GetAbortedAmount(cctx))
}
}

Expand Down
16 changes: 16 additions & 0 deletions x/crosschain/keeper/cctx_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"

cosmoserrors "cosmossdk.io/errors"
sdkmath "cosmossdk.io/math"
"github.com/pkg/errors"

sdk "github.com/cosmos/cosmos-sdk/types"
Expand Down Expand Up @@ -89,3 +90,18 @@ func IsPending(cctx types.CrossChainTx) bool {
// pending inbound is not considered a "pending" state because it has not reached consensus yet
return cctx.CctxStatus.Status == types.CctxStatus_PendingOutbound || cctx.CctxStatus.Status == types.CctxStatus_PendingRevert
}

// GetAbortedAmount returns the amount to refund for a given CCTX .
// If the CCTX has an outbound transaction, it returns the amount of the outbound transaction.
// If OutTxParams is nil or the amount is zero, it returns the amount of the inbound transaction.
// This is because there might be a case where the transaction is set to be aborted before paying gas or creating an outbound transaction.In such a situation we can refund the entire amount that has been locked in connector or TSS
func GetAbortedAmount(cctx types.CrossChainTx) sdkmath.Uint {
if cctx.OutboundTxParams != nil && !cctx.GetCurrentOutTxParam().Amount.IsZero() {
return cctx.GetCurrentOutTxParam().Amount
}
if cctx.InboundTxParams != nil {
return cctx.InboundTxParams.Amount
}

return sdkmath.ZeroUint()
}
41 changes: 41 additions & 0 deletions x/crosschain/keeper/cctx_utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ import (
"math/big"
"testing"

sdkmath "cosmossdk.io/math"
"github.com/stretchr/testify/require"
"github.com/zeta-chain/zetacore/common"
keepertest "github.com/zeta-chain/zetacore/testutil/keeper"
"github.com/zeta-chain/zetacore/testutil/sample"
crosschainkeeper "github.com/zeta-chain/zetacore/x/crosschain/keeper"
"github.com/zeta-chain/zetacore/x/crosschain/types"
fungibletypes "github.com/zeta-chain/zetacore/x/fungible/types"
)
Expand Down Expand Up @@ -150,3 +152,42 @@ func TestGetRevertGasLimit(t *testing.T) {
require.ErrorIs(t, err, fungibletypes.ErrContractCall)
})
}

func TestGetAbortedAmount(t *testing.T) {
amount := sdkmath.NewUint(100)
t.Run("should return the inbound amount if outbound not present", func(t *testing.T) {
cctx := types.CrossChainTx{
InboundTxParams: &types.InboundTxParams{
Amount: amount,
},
}
a := crosschainkeeper.GetAbortedAmount(cctx)
require.Equal(t, amount, a)
})
t.Run("should return the amount outbound amount", func(t *testing.T) {
cctx := types.CrossChainTx{
InboundTxParams: &types.InboundTxParams{
Amount: sdkmath.ZeroUint(),
},
OutboundTxParams: []*types.OutboundTxParams{
{Amount: amount},
},
}
a := crosschainkeeper.GetAbortedAmount(cctx)
require.Equal(t, amount, a)
})
t.Run("should return the zero if outbound amount is not present and inbound is 0", func(t *testing.T) {
cctx := types.CrossChainTx{
InboundTxParams: &types.InboundTxParams{
Amount: sdkmath.ZeroUint(),
},
}
a := crosschainkeeper.GetAbortedAmount(cctx)
require.Equal(t, sdkmath.ZeroUint(), a)
})
t.Run("should return the zero if no amounts are present", func(t *testing.T) {
cctx := types.CrossChainTx{}
a := crosschainkeeper.GetAbortedAmount(cctx)
require.Equal(t, sdkmath.ZeroUint(), a)
})
}
24 changes: 16 additions & 8 deletions x/crosschain/keeper/msg_server_refund_aborted_tx.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package keeper

import (
"errors"

errorsmod "cosmossdk.io/errors"
sdk "github.com/cosmos/cosmos-sdk/types"
ethcommon "github.com/ethereum/go-ethereum/common"
Expand Down Expand Up @@ -37,10 +39,16 @@ func (k msgServer) RefundAbortedCCTX(goCtx context.Context, msg *types.MsgRefund
// Check if aborted amount is available to maintain zeta accounting
// NOTE: Need to verify if this check works / is required in athens 3
if cctx.InboundTxParams.CoinType == common.CoinType_Zeta {
err := k.RemoveZetaAbortedAmount(ctx, cctx.InboundTxParams.Amount)
if err != nil {
err := k.RemoveZetaAbortedAmount(ctx, GetAbortedAmount(cctx))
// if the zeta accounting is not found, it means the zeta accounting is not set yet and the refund should not be processed
if errors.Is(err, types.ErrUnableToFindZetaAccounting) {
return nil, errorsmod.Wrap(types.ErrUnableProcessRefund, err.Error())
}
// if the zeta accounting is found but the amount is insufficient, it means the refund can be processed but the zeta accounting is not maintained properly
// aborted amounts for zeta accounting would need to be updated in the envionment via a migration script
if errors.Is(err, types.ErrInsufficientZetaAmount) {
ctx.Logger().Error("Zeta Accounting Error: ", err)
}
}

refundAddress, err := GetRefundAddress(cctx, msg.RefundAddress)
Expand All @@ -64,17 +72,17 @@ func (k msgServer) RefundAbortedCCTX(goCtx context.Context, msg *types.MsgRefund
return &types.MsgRefundAbortedCCTXResponse{}, nil
}

// Set the proper refund address.
// For BTC sender chain the refund address is the one provided in the message in the RefundAddress field.
// For EVM chain with coin type ERC20 the refund address is the sender , but can be overridden by the RefundAddress field in the message.
// For EVM chain with coin type Zeta the refund address is the tx origin, but can be overridden by the RefundAddress field in the message.
// For EVM chain with coin type Gas the refund address is the tx origin, but can be overridden by the RefundAddress field in the message.

func GetRefundAddress(cctx types.CrossChainTx, optionalRefundAddress string) (ethcommon.Address, error) {
// make sure a separate refund address is provided for a bitcoin chain as we cannot refund to tx origin or sender in this case
if common.IsBitcoinChain(cctx.InboundTxParams.SenderChainId) && optionalRefundAddress == "" {
return ethcommon.Address{}, errorsmod.Wrap(types.ErrInvalidAddress, "refund address is required for bitcoin chain")
}
// Set the proper refund address.
// For BTC sender chain the refund address is the one provided in the message in the RefundAddress field.
// For EVM chain with coin type ERC20 the refund address is the sender , but can be overridden by the RefundAddress field in the message.
// For EVM chain with coin type Zeta the refund address is the tx origin, but can be overridden by the RefundAddress field in the message.
// For EVM chain with coin type Gas the refund address is the tx origin, but can be overridden by the RefundAddress field in the message.

refundAddress := ethcommon.HexToAddress(cctx.InboundTxParams.TxOrigin)
if cctx.InboundTxParams.CoinType == common.CoinType_ERC20 {
refundAddress = ethcommon.HexToAddress(cctx.InboundTxParams.Sender)
Expand Down
78 changes: 56 additions & 22 deletions x/crosschain/keeper/msg_server_refund_aborted_tx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package keeper_test
import (
"testing"

sdkmath "cosmossdk.io/math"
sdk "github.com/cosmos/cosmos-sdk/types"
ethcommon "github.com/ethereum/go-ethereum/common"
"github.com/stretchr/testify/require"
Expand All @@ -14,6 +13,7 @@ import (
"github.com/zeta-chain/zetacore/x/crosschain/keeper"
crosschaintypes "github.com/zeta-chain/zetacore/x/crosschain/types"
fungibletypes "github.com/zeta-chain/zetacore/x/fungible/types"
observertypes "github.com/zeta-chain/zetacore/x/observer/types"
)

func Test_GetRefundAddress(t *testing.T) {
Expand Down Expand Up @@ -106,7 +106,7 @@ func Test_GetRefundAddress(t *testing.T) {

}
func TestMsgServer_RefundAbortedCCTX(t *testing.T) {
t.Run("Successfully refund tx for coin-type Gas", func(t *testing.T) {
t.Run("successfully refund tx for coin-type gas", func(t *testing.T) {
k, ctx, sdkk, zk := keepertest.CrosschainKeeper(t)
admin := sample.AccAddress()
chainID := getValidEthChainID(t)
Expand Down Expand Up @@ -134,12 +134,12 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) {
refundAddress := ethcommon.HexToAddress(cctx.InboundTxParams.TxOrigin)
balance, err := zk.FungibleKeeper.BalanceOfZRC4(ctx, zrc20, refundAddress)
require.NoError(t, err)
require.Equal(t, cctx.InboundTxParams.Amount.Uint64(), balance.Uint64())
require.Equal(t, cctx.GetCurrentOutTxParam().Amount.Uint64(), balance.Uint64())
c, found := k.GetCrossChainTx(ctx, cctx.Index)
require.True(t, found)
require.True(t, c.CctxStatus.IsAbortRefunded)
})
t.Run("Successfully refund tx for coin-type Zeta", func(t *testing.T) {
t.Run("successfully refund tx for coin-type zeta", func(t *testing.T) {
k, ctx, sdkk, zk := keepertest.CrosschainKeeper(t)
admin := sample.AccAddress()
chainID := getValidEthChainID(t)
Expand All @@ -154,7 +154,41 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) {
cctx.InboundTxParams.SenderChainId = chainID
cctx.InboundTxParams.CoinType = common.CoinType_Zeta
k.SetCrossChainTx(ctx, *cctx)
k.SetZetaAccounting(ctx, crosschaintypes.ZetaAccounting{AbortedZetaAmount: cctx.InboundTxParams.Amount})
k.SetZetaAccounting(ctx, crosschaintypes.ZetaAccounting{AbortedZetaAmount: cctx.GetCurrentOutTxParam().Amount})
deploySystemContracts(t, ctx, zk.FungibleKeeper, sdkk.EvmKeeper)

_, err := msgServer.RefundAbortedCCTX(ctx, &crosschaintypes.MsgRefundAbortedCCTX{
Creator: admin,
CctxIndex: cctx.Index,
RefundAddress: "",
})
require.NoError(t, err)

refundAddress := ethcommon.HexToAddress(cctx.InboundTxParams.TxOrigin)
refundAddressCosmos := sdk.AccAddress(refundAddress.Bytes())
balance := sdkk.BankKeeper.GetBalance(ctx, refundAddressCosmos, config.BaseDenom)
require.Equal(t, cctx.GetCurrentOutTxParam().Amount.Uint64(), balance.Amount.Uint64())
c, found := k.GetCrossChainTx(ctx, cctx.Index)
require.True(t, found)
require.True(t, c.CctxStatus.IsAbortRefunded)
})
t.Run("successfully refund tx to inbound amount if outbound is not found for coin-type zeta", func(t *testing.T) {
k, ctx, sdkk, zk := keepertest.CrosschainKeeper(t)
admin := sample.AccAddress()
chainID := getValidEthChainID(t)
setAdminPolicies(ctx, zk, admin)
msgServer := keeper.NewMsgServerImpl(*k)
k.GetAuthKeeper().GetModuleAccount(ctx, fungibletypes.ModuleName)

cctx := sample.CrossChainTx(t, "sample-index")
cctx.CctxStatus.Status = crosschaintypes.CctxStatus_Aborted
cctx.CctxStatus.IsAbortRefunded = false
cctx.InboundTxParams.TxOrigin = cctx.InboundTxParams.Sender
cctx.InboundTxParams.SenderChainId = chainID
cctx.InboundTxParams.CoinType = common.CoinType_Zeta
cctx.OutboundTxParams = nil
k.SetCrossChainTx(ctx, *cctx)
k.SetZetaAccounting(ctx, crosschaintypes.ZetaAccounting{AbortedZetaAmount: cctx.GetCurrentOutTxParam().Amount})
deploySystemContracts(t, ctx, zk.FungibleKeeper, sdkk.EvmKeeper)

_, err := msgServer.RefundAbortedCCTX(ctx, &crosschaintypes.MsgRefundAbortedCCTX{
Expand All @@ -172,7 +206,7 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) {
require.True(t, found)
require.True(t, c.CctxStatus.IsAbortRefunded)
})
t.Run("Successfully refund to optional refund address if provided", func(t *testing.T) {
t.Run("successfully refund to optional refund address if provided", func(t *testing.T) {
k, ctx, sdkk, zk := keepertest.CrosschainKeeper(t)
admin := sample.AccAddress()
chainID := getValidEthChainID(t)
Expand Down Expand Up @@ -200,12 +234,12 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) {

refundAddressCosmos := sdk.AccAddress(refundAddress.Bytes())
balance := sdkk.BankKeeper.GetBalance(ctx, refundAddressCosmos, config.BaseDenom)
require.Equal(t, cctx.InboundTxParams.Amount.Uint64(), balance.Amount.Uint64())
require.Equal(t, cctx.GetCurrentOutTxParam().Amount.Uint64(), balance.Amount.Uint64())
c, found := k.GetCrossChainTx(ctx, cctx.Index)
require.True(t, found)
require.True(t, c.CctxStatus.IsAbortRefunded)
})
t.Run("Successfully refund tx for coin-type ERC20", func(t *testing.T) {
t.Run("successfully refund tx for coin-type ERC20", func(t *testing.T) {
k, ctx, sdkk, zk := keepertest.CrosschainKeeper(t)
admin := sample.AccAddress()
chainID := getValidEthChainID(t)
Expand Down Expand Up @@ -244,12 +278,12 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) {
refundAddress := ethcommon.HexToAddress(cctx.InboundTxParams.Sender)
balance, err := zk.FungibleKeeper.BalanceOfZRC4(ctx, zrc20Addr, refundAddress)
require.NoError(t, err)
require.Equal(t, cctx.InboundTxParams.Amount.Uint64(), balance.Uint64())
require.Equal(t, cctx.GetCurrentOutTxParam().Amount.Uint64(), balance.Uint64())
c, found := k.GetCrossChainTx(ctx, cctx.Index)
require.True(t, found)
require.True(t, c.CctxStatus.IsAbortRefunded)
})
t.Run("Successfully refund tx for coin-type Gas with BTC sender", func(t *testing.T) {
t.Run("successfully refund tx for coin-type Gas with BTC sender", func(t *testing.T) {
k, ctx, sdkk, zk := keepertest.CrosschainKeeper(t)
admin := sample.AccAddress()
chainID := getValidBtcChainID()
Expand Down Expand Up @@ -277,12 +311,12 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) {
refundAddress := ethcommon.HexToAddress(cctx.InboundTxParams.TxOrigin)
balance, err := zk.FungibleKeeper.BalanceOfZRC4(ctx, zrc20, refundAddress)
require.NoError(t, err)
require.Equal(t, cctx.InboundTxParams.Amount.Uint64(), balance.Uint64())
require.Equal(t, cctx.GetCurrentOutTxParam().Amount.Uint64(), balance.Uint64())
c, found := k.GetCrossChainTx(ctx, cctx.Index)
require.True(t, found)
require.True(t, c.CctxStatus.IsAbortRefunded)
})
t.Run("Fail refund if address provided is invalid", func(t *testing.T) {
t.Run("fail refund if address provided is invalid", func(t *testing.T) {
k, ctx, sdkk, zk := keepertest.CrosschainKeeper(t)
admin := sample.AccAddress()
chainID := getValidEthChainID(t)
Expand All @@ -307,7 +341,7 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) {
})
require.ErrorContains(t, err, "invalid refund address")
})
t.Run("Fail refund if address provided is invalid 2 ", func(t *testing.T) {
t.Run("fail refund if address provided is null ", func(t *testing.T) {
k, ctx, sdkk, zk := keepertest.CrosschainKeeper(t)
admin := sample.AccAddress()
chainID := getValidEthChainID(t)
Expand All @@ -332,7 +366,7 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) {
})
require.ErrorContains(t, err, "invalid refund address")
})
t.Run("Fail refund if status is not aborted", func(t *testing.T) {
t.Run("fail refund if status is not aborted", func(t *testing.T) {
k, ctx, sdkk, zk := keepertest.CrosschainKeeper(t)
admin := sample.AccAddress()
chainID := getValidEthChainID(t)
Expand All @@ -359,7 +393,7 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) {
require.True(t, found)
require.False(t, c.CctxStatus.IsAbortRefunded)
})
t.Run("Fail refund if status cctx not found", func(t *testing.T) {
t.Run("fail refund if status cctx not found", func(t *testing.T) {
k, ctx, sdkk, zk := keepertest.CrosschainKeeper(t)
admin := sample.AccAddress()
chainID := getValidEthChainID(t)
Expand All @@ -382,7 +416,7 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) {
})
require.ErrorContains(t, err, "cannot find cctx")
})
t.Run("Fail refund if refund address not provided for BTC chain", func(t *testing.T) {
t.Run("fail refund if refund address not provided for BTC chain", func(t *testing.T) {
k, ctx, sdkk, zk := keepertest.CrosschainKeeper(t)
admin := sample.AccAddress()
chainID := getValidBtcChainID()
Expand All @@ -407,7 +441,7 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) {
})
require.ErrorContains(t, err, "refund address is required for bitcoin chain")
})
t.Run("Fail refund tx for coin-type Zeta if zeta accounting object is not present", func(t *testing.T) {
t.Run("fail refund tx for coin-type Zeta if zeta accounting object is not present", func(t *testing.T) {
k, ctx, sdkk, zk := keepertest.CrosschainKeeper(t)
admin := sample.AccAddress()
chainID := getValidEthChainID(t)
Expand All @@ -431,7 +465,7 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) {
})
require.ErrorContains(t, err, "unable to find zeta accounting")
})
t.Run("Fail refund tx for coin-type Zeta if zeta accounting does not have enough aborted amount", func(t *testing.T) {
t.Run("fail refund if non admin account is the creator", func(t *testing.T) {
k, ctx, sdkk, zk := keepertest.CrosschainKeeper(t)
admin := sample.AccAddress()
chainID := getValidEthChainID(t)
Expand All @@ -444,16 +478,16 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) {
cctx.CctxStatus.IsAbortRefunded = false
cctx.InboundTxParams.TxOrigin = cctx.InboundTxParams.Sender
cctx.InboundTxParams.SenderChainId = chainID
cctx.InboundTxParams.CoinType = common.CoinType_Zeta
cctx.InboundTxParams.CoinType = common.CoinType_Gas
k.SetCrossChainTx(ctx, *cctx)
k.SetZetaAccounting(ctx, crosschaintypes.ZetaAccounting{AbortedZetaAmount: sdkmath.OneUint()})
deploySystemContracts(t, ctx, zk.FungibleKeeper, sdkk.EvmKeeper)
_ = setupGasCoin(t, ctx, zk.FungibleKeeper, sdkk.EvmKeeper, cctx.InboundTxParams.SenderChainId, "foobar", "foobar")

_, err := msgServer.RefundAbortedCCTX(ctx, &crosschaintypes.MsgRefundAbortedCCTX{
Creator: admin,
Creator: sample.AccAddress(),
CctxIndex: cctx.Index,
RefundAddress: "",
})
require.ErrorContains(t, err, "insufficient zeta amount")
require.ErrorIs(t, err, observertypes.ErrNotAuthorized)
})
}
Loading

0 comments on commit d1ff261

Please sign in to comment.