Skip to content

Commit

Permalink
make refund address mandatory
Browse files Browse the repository at this point in the history
  • Loading branch information
kingpinXD committed Feb 12, 2024
1 parent 2cdd699 commit 8efd398
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 102 deletions.
25 changes: 10 additions & 15 deletions x/crosschain/keeper/msg_server_refund_aborted_tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (k msgServer) RefundAbortedCCTX(goCtx context.Context, msg *types.MsgRefund
}
}

refundAddress, err := GetRefundAddress(cctx, msg.RefundAddress)
refundAddress, err := GetRefundAddress(msg.RefundAddress)
if err != nil {
return nil, errorsmod.Wrap(types.ErrInvalidAddress, err.Error())
}
Expand All @@ -81,24 +81,19 @@ func (k msgServer) RefundAbortedCCTX(goCtx context.Context, msg *types.MsgRefund
// For EVM chain with coin type Zeta the refund address is the tx origin, but can be overridden by the RefundAddress field in the message.
// For EVM chain with coin type Gas the refund address is the tx origin, but can be overridden by the RefundAddress field in the message.

func GetRefundAddress(cctx types.CrossChainTx, optionalRefundAddress string) (ethcommon.Address, error) {
func GetRefundAddress(refundAddress string) (ethcommon.Address, error) {
// make sure a separate refund address is provided for a bitcoin chain as we cannot refund to tx origin or sender in this case
if common.IsBitcoinChain(cctx.InboundTxParams.SenderChainId) && optionalRefundAddress == "" {
return ethcommon.Address{}, errorsmod.Wrap(types.ErrInvalidAddress, "refund address is required for bitcoin chain")
if refundAddress == "" {
return ethcommon.Address{}, errorsmod.Wrap(types.ErrInvalidAddress, "refund address is required")
}
refundAddress := ethcommon.HexToAddress(cctx.InboundTxParams.TxOrigin)
if cctx.InboundTxParams.CoinType == common.CoinType_ERC20 {
refundAddress = ethcommon.HexToAddress(cctx.InboundTxParams.Sender)
}
if optionalRefundAddress != "" {
if !ethcommon.IsHexAddress(optionalRefundAddress) {
return ethcommon.Address{}, errorsmod.Wrap(types.ErrInvalidAddress, "invalid refund address provided")
}
refundAddress = ethcommon.HexToAddress(optionalRefundAddress)
if !ethcommon.IsHexAddress(refundAddress) {
return ethcommon.Address{}, errorsmod.Wrap(types.ErrInvalidAddress, "invalid refund address provided")
}
ethRefundAddress := ethcommon.HexToAddress(refundAddress)
// Double check to make sure the refund address is valid
if refundAddress == (ethcommon.Address{}) {
if ethRefundAddress == (ethcommon.Address{}) {
return ethcommon.Address{}, errorsmod.Wrap(types.ErrInvalidAddress, "invalid refund address")
}
return refundAddress, nil
return ethRefundAddress, nil

}
105 changes: 18 additions & 87 deletions x/crosschain/keeper/msg_server_refund_aborted_tx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

sdk "github.com/cosmos/cosmos-sdk/types"
ethcommon "github.com/ethereum/go-ethereum/common"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zeta-chain/zetacore/cmd/zetacored/config"
"github.com/zeta-chain/zetacore/common"
Expand All @@ -19,89 +20,19 @@ import (
func Test_GetRefundAddress(t *testing.T) {
t.Run("should return refund address if provided coin-type gas", func(t *testing.T) {
validEthAddress := sample.EthAddress()
address, err := keeper.GetRefundAddress(crosschaintypes.CrossChainTx{
InboundTxParams: &crosschaintypes.InboundTxParams{
TxOrigin: validEthAddress.String(),
CoinType: common.CoinType_Gas,
SenderChainId: getValidEthChainID(t),
}},
"")
address, err := keeper.GetRefundAddress(validEthAddress.String())
require.NoError(t, err)
require.Equal(t, validEthAddress, address)
})
t.Run("should return refund address if provided coin-type zeta", func(t *testing.T) {
validEthAddress := sample.EthAddress()
address, err := keeper.GetRefundAddress(crosschaintypes.CrossChainTx{
InboundTxParams: &crosschaintypes.InboundTxParams{
TxOrigin: validEthAddress.String(),
CoinType: common.CoinType_Zeta,
SenderChainId: getValidEthChainID(t),
}},
"")
require.NoError(t, err)
require.Equal(t, validEthAddress, address)
t.Run("should fail if refund address is empty", func(t *testing.T) {
address, err := keeper.GetRefundAddress("")
require.ErrorIs(t, crosschaintypes.ErrInvalidAddress, err)
assert.Equal(t, ethcommon.Address{}, address)
})
t.Run("should return refund address if provided coin-type erc20", func(t *testing.T) {
validEthAddress := sample.EthAddress()
address, err := keeper.GetRefundAddress(crosschaintypes.CrossChainTx{
InboundTxParams: &crosschaintypes.InboundTxParams{
Sender: validEthAddress.String(),
CoinType: common.CoinType_ERC20,
SenderChainId: getValidEthChainID(t),
}},
"")
require.NoError(t, err)
require.Equal(t, validEthAddress, address)
})
t.Run("should return refund address if provided coin-type gas for btc chain", func(t *testing.T) {
validEthAddress := sample.EthAddress()
address, err := keeper.GetRefundAddress(crosschaintypes.CrossChainTx{
InboundTxParams: &crosschaintypes.InboundTxParams{
CoinType: common.CoinType_Gas,
SenderChainId: getValidBtcChainID(),
}},
validEthAddress.String())
require.NoError(t, err)
require.Equal(t, validEthAddress, address)
})
t.Run("fail if refund address is not provided for btc chain", func(t *testing.T) {
_, err := keeper.GetRefundAddress(crosschaintypes.CrossChainTx{
InboundTxParams: &crosschaintypes.InboundTxParams{
CoinType: common.CoinType_Gas,
SenderChainId: getValidBtcChainID(),
}},
"")
require.ErrorContains(t, err, "refund address is required for bitcoin chain")
})
t.Run("address overridden if optional address is provided", func(t *testing.T) {
validEthAddress := sample.EthAddress()
address, err := keeper.GetRefundAddress(crosschaintypes.CrossChainTx{
InboundTxParams: &crosschaintypes.InboundTxParams{
Sender: sample.EthAddress().String(),
CoinType: common.CoinType_ERC20,
SenderChainId: getValidEthChainID(t),
}},
validEthAddress.String())
require.NoError(t, err)
require.Equal(t, validEthAddress, address)
address, err = keeper.GetRefundAddress(crosschaintypes.CrossChainTx{
InboundTxParams: &crosschaintypes.InboundTxParams{
Sender: sample.EthAddress().String(),
CoinType: common.CoinType_Zeta,
SenderChainId: getValidEthChainID(t),
}},
validEthAddress.String())
require.NoError(t, err)
require.Equal(t, validEthAddress, address)
address, err = keeper.GetRefundAddress(crosschaintypes.CrossChainTx{
InboundTxParams: &crosschaintypes.InboundTxParams{
Sender: sample.EthAddress().String(),
CoinType: common.CoinType_Gas,
SenderChainId: getValidEthChainID(t),
}},
validEthAddress.String())
require.NoError(t, err)
require.Equal(t, validEthAddress, address)
t.Run("should fail if refund address is invalid", func(t *testing.T) {
address, err := keeper.GetRefundAddress("invalid-address")
require.ErrorIs(t, crosschaintypes.ErrInvalidAddress, err)
assert.Equal(t, ethcommon.Address{}, address)
})

}
Expand All @@ -127,7 +58,7 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) {
_, err := msgServer.RefundAbortedCCTX(ctx, &crosschaintypes.MsgRefundAbortedCCTX{
Creator: admin,
CctxIndex: cctx.Index,
RefundAddress: "",
RefundAddress: cctx.InboundTxParams.Sender,
})
require.NoError(t, err)

Expand Down Expand Up @@ -160,7 +91,7 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) {
_, err := msgServer.RefundAbortedCCTX(ctx, &crosschaintypes.MsgRefundAbortedCCTX{
Creator: admin,
CctxIndex: cctx.Index,
RefundAddress: "",
RefundAddress: cctx.InboundTxParams.Sender,
})
require.NoError(t, err)

Expand Down Expand Up @@ -194,7 +125,7 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) {
_, err := msgServer.RefundAbortedCCTX(ctx, &crosschaintypes.MsgRefundAbortedCCTX{
Creator: admin,
CctxIndex: cctx.Index,
RefundAddress: "",
RefundAddress: cctx.InboundTxParams.Sender,
})
require.NoError(t, err)

Expand Down Expand Up @@ -271,7 +202,7 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) {
_, err := msgServer.RefundAbortedCCTX(ctx, &crosschaintypes.MsgRefundAbortedCCTX{
Creator: admin,
CctxIndex: cctx.Index,
RefundAddress: "",
RefundAddress: cctx.InboundTxParams.Sender,
})
require.NoError(t, err)

Expand Down Expand Up @@ -416,7 +347,7 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) {
})
require.ErrorContains(t, err, "cannot find cctx")
})
t.Run("fail refund if refund address not provided for BTC chain", func(t *testing.T) {
t.Run("fail refund if refund address not provided", func(t *testing.T) {
k, ctx, sdkk, zk := keepertest.CrosschainKeeper(t)
admin := sample.AccAddress()
chainID := getValidBtcChainID()
Expand All @@ -439,7 +370,7 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) {
CctxIndex: cctx.Index,
RefundAddress: "",
})
require.ErrorContains(t, err, "refund address is required for bitcoin chain")
require.ErrorContains(t, err, "refund address is required")
})
t.Run("fail refund tx for coin-type Zeta if zeta accounting object is not present", func(t *testing.T) {
k, ctx, sdkk, zk := keepertest.CrosschainKeeper(t)
Expand All @@ -461,7 +392,7 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) {
_, err := msgServer.RefundAbortedCCTX(ctx, &crosschaintypes.MsgRefundAbortedCCTX{
Creator: admin,
CctxIndex: cctx.Index,
RefundAddress: "",
RefundAddress: cctx.InboundTxParams.Sender,
})
require.ErrorContains(t, err, "unable to find zeta accounting")
})
Expand All @@ -486,7 +417,7 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) {
_, err := msgServer.RefundAbortedCCTX(ctx, &crosschaintypes.MsgRefundAbortedCCTX{
Creator: sample.AccAddress(),
CctxIndex: cctx.Index,
RefundAddress: "",
RefundAddress: cctx.InboundTxParams.Sender,
})
require.ErrorIs(t, err, observertypes.ErrNotAuthorized)
})
Expand Down

0 comments on commit 8efd398

Please sign in to comment.