From 806a6f58cf9b3433166b572d4b18cd97739b59ef Mon Sep 17 00:00:00 2001 From: Tanmay Date: Fri, 8 Mar 2024 01:00:23 -0500 Subject: [PATCH] add validate function for cctx --- .../keeper/msg_server_vote_inbound_tx.go | 4 + x/crosschain/types/cctx_utils.go | 120 +++++++++++++----- x/crosschain/types/errors.go | 2 +- x/crosschain/types/message_refund_aborted.go | 2 +- 4 files changed, 96 insertions(+), 32 deletions(-) diff --git a/x/crosschain/keeper/msg_server_vote_inbound_tx.go b/x/crosschain/keeper/msg_server_vote_inbound_tx.go index c515d541f0..b59b538517 100644 --- a/x/crosschain/keeper/msg_server_vote_inbound_tx.go +++ b/x/crosschain/keeper/msg_server_vote_inbound_tx.go @@ -90,6 +90,10 @@ func (k msgServer) VoteOnObservedInboundTx(goCtx context.Context, msg *types.Msg } inboundCctx := k.GetInbound(ctx, msg) + err = inboundCctx.Validate() + if err != nil { + return nil, err + } k.ProcessInbound(ctx, &inboundCctx) k.SaveInbound(ctx, &inboundCctx) return &types.MsgVoteOnObservedInboundTxResponse{}, nil diff --git a/x/crosschain/types/cctx_utils.go b/x/crosschain/types/cctx_utils.go index 3b9b824f9f..f87430a4e0 100644 --- a/x/crosschain/types/cctx_utils.go +++ b/x/crosschain/types/cctx_utils.go @@ -4,8 +4,11 @@ import ( "fmt" "strconv" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcutil" sdk "github.com/cosmos/cosmos-sdk/types" ethcommon "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" "github.com/zeta-chain/zetacore/common" observertypes "github.com/zeta-chain/zetacore/x/observer/types" ) @@ -61,8 +64,21 @@ func (m CrossChainTx) Validate() error { if len(m.OutboundTxParams) > 2 { return fmt.Errorf("outbound tx params cannot be more than 2") } - if len(m.Index) != 66 { - return ErrInvalidCCTXIndex + if m.Index != "" { + err := ValidateZetaIndex(m.Index) + if err != nil { + return err + } + } + err := m.InboundTxParams.Validate() + if err != nil { + return err + } + for _, outboundTxParam := range m.OutboundTxParams { + err = outboundTxParam.Validate() + if err != nil { + return err + } } return nil @@ -72,54 +88,98 @@ func (m InboundTxParams) Validate() error { if m.Sender == "" { return fmt.Errorf("sender cannot be empty") } - if m.InboundTxObservedHash == "" { - return fmt.Errorf("inbound tx observed hash cannot be empty") + err := ValidateAddressForChain(m.Sender, m.SenderChainId) + if err != nil { + return err } - if len(m.InboundTxBallotIndex) != 66 { - return fmt.Errorf("inbound tx ballot index must be 66 characters") + if common.GetChainFromChainID(m.SenderChainId) == nil { + return fmt.Errorf("invalid sender chain id %d", m.SenderChainId) } - if common.IsEthereumChain(m.SenderChainId) { - if !ethcommon.IsHexAddress(m.Sender) { - return fmt.Errorf("sender a valid ethereum address") + if m.TxOrigin != "" { + errTxOrigin := ValidateAddressForChain(m.TxOrigin, m.SenderChainId) + if errTxOrigin != nil { + return errTxOrigin } } if m.Amount.IsNil() { return fmt.Errorf("amount cannot be nil") } - if common.IsBitcoinChain(m.SenderChainId) { - //if _, err := common.BitcoinAddressToPubKeyHash(m.Sender); err != nil { - // return fmt.Errorf("sender must be a valid bitcoin address") - //} + err = ValidateHashForChain(m.InboundTxObservedHash, m.SenderChainId) + if err != nil { + return err + } + if m.InboundTxBallotIndex != "" { + err = ValidateZetaIndex(m.InboundTxBallotIndex) + if err != nil { + return err + } + } + return nil +} + +func ValidateZetaIndex(index string) error { + if len(index) != 66 { + return ErrInvalidIndexValue } return nil } +func ValidateHashForChain(hash string, chainID int64) error { + if common.IsEthereumChain(chainID) { + _, err := hexutil.Decode(hash) + if err != nil { + return fmt.Errorf("hash must be a valid ethereum hash") + } + } + if common.IsBitcoinChain(chainID) { + _, err := chainhash.NewHashFromStr(hash) + if err != nil { + return fmt.Errorf("hash must be a valid bitcoin hash") + } + } + return fmt.Errorf("invalid chain id %d", chainID) +} + +func ValidateAddressForChain(address string, chainID int64) error { + if common.IsEthereumChain(chainID) { + if !ethcommon.IsHexAddress(address) { + return fmt.Errorf("sender a valid ethereum address") + } + return nil + } + if common.IsBitcoinChain(chainID) { + addr, err := common.DecodeBtcAddress(address, chainID) + if err != nil { + return fmt.Errorf("invalid address %s: %s", address, err) + } + _, ok := addr.(*btcutil.AddressWitnessPubKeyHash) + if !ok { + return fmt.Errorf(" invalid address %s (not P2WPKH address)", address) + } + return nil + } + return fmt.Errorf("invalid chain id %d", chainID) +} func (m OutboundTxParams) Validate() error { if m.Receiver == "" { return fmt.Errorf("receiver cannot be empty") } - if m.Amount.IsNil() { - return fmt.Errorf("amount cannot be nil") - } - if m.OutboundTxGasPrice == "" { - return fmt.Errorf("outbound tx gas price cannot be empty") + err := ValidateAddressForChain(m.Receiver, m.ReceiverChainId) + if err != nil { + return err } - if m.GasLimit == 0 { - return fmt.Errorf("gas limit cannot be 0") + if common.GetChainFromChainID(m.ReceiverChainId) == nil { + return fmt.Errorf("invalid receiver chain id %d", m.ReceiverChainId) } - if m.ReceiverChainId == 0 { - return fmt.Errorf("receiver chain id cannot be 0") + if m.Amount.IsNil() { + return fmt.Errorf("amount cannot be nil") } - if common.IsEthereumChain(m.ReceiverChainId) { - if !ethcommon.IsHexAddress(m.Receiver) { - return fmt.Errorf("receiver must be a valid ethereum address") + if m.OutboundTxBallotIndex != "" { + err = ValidateZetaIndex(m.OutboundTxBallotIndex) + if err != nil { + return err } } - if common.IsBitcoinChain(m.ReceiverChainId) { - //if _, err := common.BitcoinAddressToPubKeyHash(m.Receiver); err != nil { - // return fmt.Errorf("receiver must be a valid bitcoin address") - //} - } return nil } diff --git a/x/crosschain/types/errors.go b/x/crosschain/types/errors.go index 2c5a595eb5..62720aaf66 100644 --- a/x/crosschain/types/errors.go +++ b/x/crosschain/types/errors.go @@ -36,7 +36,7 @@ var ( ErrUnsupportedStatus = errorsmod.Register(ModuleName, 1143, "unsupported status") ErrObservedTxAlreadyFinalized = errorsmod.Register(ModuleName, 1144, "observed tx already finalized") ErrInsufficientFundsTssMigration = errorsmod.Register(ModuleName, 1145, "insufficient funds for TSS migration") - ErrInvalidCCTXIndex = errorsmod.Register(ModuleName, 1146, "invalid cctx index") + ErrInvalidIndexValue = errorsmod.Register(ModuleName, 1146, "invalid index hash") ErrInvalidStatus = errorsmod.Register(ModuleName, 1147, "invalid cctx status") ErrUnableProcessRefund = errorsmod.Register(ModuleName, 1148, "unable to process refund") ErrUnableToFindZetaAccounting = errorsmod.Register(ModuleName, 1149, "unable to find zeta accounting") diff --git a/x/crosschain/types/message_refund_aborted.go b/x/crosschain/types/message_refund_aborted.go index 698d499b09..20115aec21 100644 --- a/x/crosschain/types/message_refund_aborted.go +++ b/x/crosschain/types/message_refund_aborted.go @@ -46,7 +46,7 @@ func (msg *MsgRefundAbortedCCTX) ValidateBasic() error { return errorsmod.Wrapf(sdkerrors.ErrInvalidAddress, "invalid creator address (%s)", err) } if len(msg.CctxIndex) != 66 { - return ErrInvalidCCTXIndex + return ErrInvalidIndexValue } if msg.RefundAddress != "" && !ethcommon.IsHexAddress(msg.RefundAddress) { return ErrInvalidAddress