From f02cc5253a5fbc6f0402c41ad0905837de3eb2c9 Mon Sep 17 00:00:00 2001 From: Charlie Chen Date: Thu, 28 Mar 2024 16:41:43 -0500 Subject: [PATCH] improved function and added comments --- common/bitcoin/address_taproot.go | 1 + zetaclient/bitcoin/bitcoin_client.go | 36 ++++++++++------------- zetaclient/bitcoin/bitcoin_client_test.go | 31 +++++++++---------- 3 files changed, 32 insertions(+), 36 deletions(-) diff --git a/common/bitcoin/address_taproot.go b/common/bitcoin/address_taproot.go index e178a8e5a3..fa9968dcfe 100644 --- a/common/bitcoin/address_taproot.go +++ b/common/bitcoin/address_taproot.go @@ -193,6 +193,7 @@ func (a *AddressSegWit) String() string { return a.EncodeAddress() } +// DecodeTaprootAddress decodes taproot address only and returns error on non-taproot address func DecodeTaprootAddress(addr string) (*AddressTaproot, error) { hrp, version, program, err := decodeSegWitAddress(addr) if err != nil { diff --git a/zetaclient/bitcoin/bitcoin_client.go b/zetaclient/bitcoin/bitcoin_client.go index 2a3afbb9f7..327db5a941 100644 --- a/zetaclient/bitcoin/bitcoin_client.go +++ b/zetaclient/bitcoin/bitcoin_client.go @@ -726,7 +726,7 @@ func (ob *BTCChainClient) IsInTxRestricted(inTx *BTCInTxEvnet) bool { return false } -// GetBtcEvent either returns a valid BTCInTxEvnet or nil +// GetBtcEvent either returns a valid BTCInTxEvent or nil // Note: the caller should retry the tx on error (e.g., GetSenderAddressByVin failed) func GetBtcEvent( rpcClient interfaces.BTCRPCClient, @@ -741,7 +741,7 @@ func GetBtcEvent( var value float64 var memo []byte if len(tx.Vout) >= 2 { - // 1st vout must to addressed to the tssAddress with p2wpkh scriptPubKey + // 1st vout must have tss address as receiver with p2wpkh scriptPubKey vout0 := tx.Vout[0] script := vout0.ScriptPubKey.Hex if len(script) == 44 && script[:4] == "0014" { // P2WPKH output: 0x00 + 20 bytes of pubkey hash @@ -1252,12 +1252,12 @@ func (ob *BTCChainClient) checkTssOutTxResult(cctx *types.CrossChainTx, hash *ch // differentiate between normal and restricted cctx if compliance.IsCctxRestricted(cctx) { - err = ob.checkTSSVoutCancelled(params, rawResult.Vout, ob.chain) + err = ob.checkTSSVoutCancelled(params, rawResult.Vout) if err != nil { return errors.Wrapf(err, "checkTssOutTxResult: invalid TSS Vout in cancelled outTx %s nonce %d", hash, nonce) } } else { - err = ob.checkTSSVout(params, rawResult.Vout, ob.chain) + err = ob.checkTSSVout(params, rawResult.Vout) if err != nil { return errors.Wrapf(err, "checkTssOutTxResult: invalid TSS Vout in outTx %s nonce %d", hash, nonce) } @@ -1343,7 +1343,7 @@ func (ob *BTCChainClient) checkTSSVin(vins []btcjson.Vin, nonce uint64) error { // - The first output is the nonce-mark // - The second output is the correct payment to recipient // - The third output is the change to TSS (optional) -func (ob *BTCChainClient) checkTSSVout(params *types.OutboundTxParams, vouts []btcjson.Vout, chain common.Chain) error { +func (ob *BTCChainClient) checkTSSVout(params *types.OutboundTxParams, vouts []btcjson.Vout) error { // vouts: [nonce-mark, payment to recipient, change to TSS (optional)] if !(len(vouts) == 2 || len(vouts) == 3) { return fmt.Errorf("checkTSSVout: invalid number of vouts: %d", len(vouts)) @@ -1358,21 +1358,19 @@ func (ob *BTCChainClient) checkTSSVout(params *types.OutboundTxParams, vouts []b // the 2nd output is the payment to recipient receiverExpected = params.Receiver } - receiverVout, amount, err := DecodeTSSVout(vout, receiverExpected, chain) + receiverVout, amount, err := DecodeTSSVout(vout, receiverExpected, ob.chain) if err != nil { return err } - // 1st vout: nonce-mark - if vout.N == 0 { + switch vout.N { + case 0: // 1st vout: nonce-mark if receiverVout != tssAddress { return fmt.Errorf("checkTSSVout: nonce-mark address %s not match TSS address %s", receiverVout, tssAddress) } if amount != common.NonceMarkAmount(nonce) { return fmt.Errorf("checkTSSVout: nonce-mark amount %d not match nonce-mark amount %d", amount, common.NonceMarkAmount(nonce)) } - } - // 2nd vout: payment to recipient - if vout.N == 1 { + case 1: // 2nd vout: payment to recipient if receiverVout != params.Receiver { return fmt.Errorf("checkTSSVout: output address %s not match params receiver %s", receiverVout, params.Receiver) } @@ -1380,9 +1378,7 @@ func (ob *BTCChainClient) checkTSSVout(params *types.OutboundTxParams, vouts []b if uint64(amount) != params.Amount.Uint64() { return fmt.Errorf("checkTSSVout: output amount %d not match params amount %d", amount, params.Amount) } - } - // 3rd vout: change to TSS (optional) - if vout.N == 2 { + case 2: // 3rd vout: change to TSS (optional) if receiverVout != tssAddress { return fmt.Errorf("checkTSSVout: change address %s not match TSS address %s", receiverVout, tssAddress) } @@ -1394,7 +1390,7 @@ func (ob *BTCChainClient) checkTSSVout(params *types.OutboundTxParams, vouts []b // checkTSSVoutCancelled vout is valid if: // - The first output is the nonce-mark // - The second output is the change to TSS (optional) -func (ob *BTCChainClient) checkTSSVoutCancelled(params *types.OutboundTxParams, vouts []btcjson.Vout, chain common.Chain) error { +func (ob *BTCChainClient) checkTSSVoutCancelled(params *types.OutboundTxParams, vouts []btcjson.Vout) error { // vouts: [nonce-mark, change to TSS (optional)] if !(len(vouts) == 1 || len(vouts) == 2) { return fmt.Errorf("checkTSSVoutCancelled: invalid number of vouts: %d", len(vouts)) @@ -1404,21 +1400,19 @@ func (ob *BTCChainClient) checkTSSVoutCancelled(params *types.OutboundTxParams, tssAddress := ob.Tss.BTCAddress() for _, vout := range vouts { // decode receiver and amount from vout - receiverVout, amount, err := DecodeTSSVout(vout, tssAddress, chain) + receiverVout, amount, err := DecodeTSSVout(vout, tssAddress, ob.chain) if err != nil { return errors.Wrap(err, "checkTSSVoutCancelled: error decoding P2WPKH vout") } - // 1st vout: nonce-mark - if vout.N == 0 { + switch vout.N { + case 0: // 1st vout: nonce-mark if receiverVout != tssAddress { return fmt.Errorf("checkTSSVoutCancelled: nonce-mark address %s not match TSS address %s", receiverVout, tssAddress) } if amount != common.NonceMarkAmount(nonce) { return fmt.Errorf("checkTSSVoutCancelled: nonce-mark amount %d not match nonce-mark amount %d", amount, common.NonceMarkAmount(nonce)) } - } - // 2nd vout: change to TSS (optional) - if vout.N == 2 { + case 1: // 2nd vout: change to TSS (optional) if receiverVout != tssAddress { return fmt.Errorf("checkTSSVoutCancelled: change address %s not match TSS address %s", receiverVout, tssAddress) } diff --git a/zetaclient/bitcoin/bitcoin_client_test.go b/zetaclient/bitcoin/bitcoin_client_test.go index 21573a7cb5..8eb7a62c14 100644 --- a/zetaclient/bitcoin/bitcoin_client_test.go +++ b/zetaclient/bitcoin/bitcoin_client_test.go @@ -236,17 +236,17 @@ func TestCheckTSSVout(t *testing.T) { t.Run("valid TSS vout should pass", func(t *testing.T) { rawResult, cctx := testutils.LoadBTCTxRawResultNCctx(t, chainID, nonce) params := cctx.GetCurrentOutTxParam() - err := btcClient.checkTSSVout(params, rawResult.Vout, chain) + err := btcClient.checkTSSVout(params, rawResult.Vout) require.NoError(t, err) }) t.Run("should fail if vout length < 2 or > 3", func(t *testing.T) { _, cctx := testutils.LoadBTCTxRawResultNCctx(t, chainID, nonce) params := cctx.GetCurrentOutTxParam() - err := btcClient.checkTSSVout(params, []btcjson.Vout{{}}, chain) + err := btcClient.checkTSSVout(params, []btcjson.Vout{{}}) require.ErrorContains(t, err, "invalid number of vouts") - err = btcClient.checkTSSVout(params, []btcjson.Vout{{}, {}, {}, {}}, chain) + err = btcClient.checkTSSVout(params, []btcjson.Vout{{}, {}, {}, {}}) require.ErrorContains(t, err, "invalid number of vouts") }) t.Run("should fail on invalid TSS vout", func(t *testing.T) { @@ -255,7 +255,7 @@ func TestCheckTSSVout(t *testing.T) { // invalid TSS vout rawResult.Vout[0].ScriptPubKey.Hex = "invalid script" - err := btcClient.checkTSSVout(params, rawResult.Vout, chain) + err := btcClient.checkTSSVout(params, rawResult.Vout) require.Error(t, err) }) t.Run("should fail if vout 0 is not to the TSS address", func(t *testing.T) { @@ -264,7 +264,7 @@ func TestCheckTSSVout(t *testing.T) { // not TSS address, bc1qh297vdt8xq6df5xae9z8gzd4jsu9a392mp0dus rawResult.Vout[0].ScriptPubKey.Hex = "0014ba8be635673034d4d0ddc9447409b594385ec4aa" - err := btcClient.checkTSSVout(params, rawResult.Vout, chain) + err := btcClient.checkTSSVout(params, rawResult.Vout) require.ErrorContains(t, err, "not match TSS address") }) t.Run("should fail if vout 0 not match nonce mark", func(t *testing.T) { @@ -273,7 +273,7 @@ func TestCheckTSSVout(t *testing.T) { // not match nonce mark rawResult.Vout[0].Value = 0.00000147 - err := btcClient.checkTSSVout(params, rawResult.Vout, chain) + err := btcClient.checkTSSVout(params, rawResult.Vout) require.ErrorContains(t, err, "not match nonce-mark amount") }) t.Run("should fail if vout 1 is not to the receiver address", func(t *testing.T) { @@ -282,7 +282,7 @@ func TestCheckTSSVout(t *testing.T) { // not receiver address, bc1qh297vdt8xq6df5xae9z8gzd4jsu9a392mp0dus rawResult.Vout[1].ScriptPubKey.Hex = "0014ba8be635673034d4d0ddc9447409b594385ec4aa" - err := btcClient.checkTSSVout(params, rawResult.Vout, chain) + err := btcClient.checkTSSVout(params, rawResult.Vout) require.ErrorContains(t, err, "not match params receiver") }) t.Run("should fail if vout 1 not match payment amount", func(t *testing.T) { @@ -291,7 +291,7 @@ func TestCheckTSSVout(t *testing.T) { // not match payment amount rawResult.Vout[1].Value = 0.00011000 - err := btcClient.checkTSSVout(params, rawResult.Vout, chain) + err := btcClient.checkTSSVout(params, rawResult.Vout) require.ErrorContains(t, err, "not match params amount") }) t.Run("should fail if vout 2 is not to the TSS address", func(t *testing.T) { @@ -300,7 +300,7 @@ func TestCheckTSSVout(t *testing.T) { // not TSS address, bc1qh297vdt8xq6df5xae9z8gzd4jsu9a392mp0dus rawResult.Vout[2].ScriptPubKey.Hex = "0014ba8be635673034d4d0ddc9447409b594385ec4aa" - err := btcClient.checkTSSVout(params, rawResult.Vout, chain) + err := btcClient.checkTSSVout(params, rawResult.Vout) require.ErrorContains(t, err, "not match TSS address") }) } @@ -322,17 +322,17 @@ func TestCheckTSSVoutCancelled(t *testing.T) { rawResult.Vout = rawResult.Vout[:2] params := cctx.GetCurrentOutTxParam() - err := btcClient.checkTSSVoutCancelled(params, rawResult.Vout, chain) + err := btcClient.checkTSSVoutCancelled(params, rawResult.Vout) require.NoError(t, err) }) t.Run("should fail if vout length < 1 or > 2", func(t *testing.T) { _, cctx := testutils.LoadBTCTxRawResultNCctx(t, chainID, nonce) params := cctx.GetCurrentOutTxParam() - err := btcClient.checkTSSVoutCancelled(params, []btcjson.Vout{}, chain) + err := btcClient.checkTSSVoutCancelled(params, []btcjson.Vout{}) require.ErrorContains(t, err, "invalid number of vouts") - err = btcClient.checkTSSVoutCancelled(params, []btcjson.Vout{{}, {}, {}}, chain) + err = btcClient.checkTSSVoutCancelled(params, []btcjson.Vout{{}, {}, {}}) require.ErrorContains(t, err, "invalid number of vouts") }) t.Run("should fail if vout 0 is not to the TSS address", func(t *testing.T) { @@ -344,7 +344,7 @@ func TestCheckTSSVoutCancelled(t *testing.T) { // not TSS address, bc1qh297vdt8xq6df5xae9z8gzd4jsu9a392mp0dus rawResult.Vout[0].ScriptPubKey.Hex = "0014ba8be635673034d4d0ddc9447409b594385ec4aa" - err := btcClient.checkTSSVoutCancelled(params, rawResult.Vout, chain) + err := btcClient.checkTSSVoutCancelled(params, rawResult.Vout) require.ErrorContains(t, err, "not match TSS address") }) t.Run("should fail if vout 0 not match nonce mark", func(t *testing.T) { @@ -356,19 +356,20 @@ func TestCheckTSSVoutCancelled(t *testing.T) { // not match nonce mark rawResult.Vout[0].Value = 0.00000147 - err := btcClient.checkTSSVoutCancelled(params, rawResult.Vout, chain) + err := btcClient.checkTSSVoutCancelled(params, rawResult.Vout) require.ErrorContains(t, err, "not match nonce-mark amount") }) t.Run("should fail if vout 1 is not to the TSS address", func(t *testing.T) { // remove change vout to simulate cancelled tx rawResult, cctx := testutils.LoadBTCTxRawResultNCctx(t, chainID, nonce) rawResult.Vout[1] = rawResult.Vout[2] + rawResult.Vout[1].N = 1 // swap vout index rawResult.Vout = rawResult.Vout[:2] params := cctx.GetCurrentOutTxParam() // not TSS address, bc1qh297vdt8xq6df5xae9z8gzd4jsu9a392mp0dus rawResult.Vout[1].ScriptPubKey.Hex = "0014ba8be635673034d4d0ddc9447409b594385ec4aa" - err := btcClient.checkTSSVoutCancelled(params, rawResult.Vout, chain) + err := btcClient.checkTSSVoutCancelled(params, rawResult.Vout) require.ErrorContains(t, err, "not match TSS address") }) }