Skip to content

Commit

Permalink
improved function and added comments
Browse files Browse the repository at this point in the history
  • Loading branch information
ws4charlie committed Mar 28, 2024
1 parent edc9be6 commit f02cc52
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 36 deletions.
1 change: 1 addition & 0 deletions common/bitcoin/address_taproot.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ func (a *AddressSegWit) String() string {
return a.EncodeAddress()
}

// DecodeTaprootAddress decodes taproot address only and returns error on non-taproot address
func DecodeTaprootAddress(addr string) (*AddressTaproot, error) {
hrp, version, program, err := decodeSegWitAddress(addr)
if err != nil {
Expand Down
36 changes: 15 additions & 21 deletions zetaclient/bitcoin/bitcoin_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ func (ob *BTCChainClient) IsInTxRestricted(inTx *BTCInTxEvnet) bool {
return false
}

// GetBtcEvent either returns a valid BTCInTxEvnet or nil
// GetBtcEvent either returns a valid BTCInTxEvent or nil
// Note: the caller should retry the tx on error (e.g., GetSenderAddressByVin failed)
func GetBtcEvent(
rpcClient interfaces.BTCRPCClient,
Expand All @@ -741,7 +741,7 @@ func GetBtcEvent(
var value float64
var memo []byte
if len(tx.Vout) >= 2 {
// 1st vout must to addressed to the tssAddress with p2wpkh scriptPubKey
// 1st vout must have tss address as receiver with p2wpkh scriptPubKey
vout0 := tx.Vout[0]
script := vout0.ScriptPubKey.Hex
if len(script) == 44 && script[:4] == "0014" { // P2WPKH output: 0x00 + 20 bytes of pubkey hash
Expand Down Expand Up @@ -1252,12 +1252,12 @@ func (ob *BTCChainClient) checkTssOutTxResult(cctx *types.CrossChainTx, hash *ch

// differentiate between normal and restricted cctx
if compliance.IsCctxRestricted(cctx) {
err = ob.checkTSSVoutCancelled(params, rawResult.Vout, ob.chain)
err = ob.checkTSSVoutCancelled(params, rawResult.Vout)
if err != nil {
return errors.Wrapf(err, "checkTssOutTxResult: invalid TSS Vout in cancelled outTx %s nonce %d", hash, nonce)
}
} else {
err = ob.checkTSSVout(params, rawResult.Vout, ob.chain)
err = ob.checkTSSVout(params, rawResult.Vout)
if err != nil {
return errors.Wrapf(err, "checkTssOutTxResult: invalid TSS Vout in outTx %s nonce %d", hash, nonce)
}
Expand Down Expand Up @@ -1343,7 +1343,7 @@ func (ob *BTCChainClient) checkTSSVin(vins []btcjson.Vin, nonce uint64) error {
// - The first output is the nonce-mark
// - The second output is the correct payment to recipient
// - The third output is the change to TSS (optional)
func (ob *BTCChainClient) checkTSSVout(params *types.OutboundTxParams, vouts []btcjson.Vout, chain common.Chain) error {
func (ob *BTCChainClient) checkTSSVout(params *types.OutboundTxParams, vouts []btcjson.Vout) error {
// vouts: [nonce-mark, payment to recipient, change to TSS (optional)]
if !(len(vouts) == 2 || len(vouts) == 3) {
return fmt.Errorf("checkTSSVout: invalid number of vouts: %d", len(vouts))
Expand All @@ -1358,31 +1358,27 @@ func (ob *BTCChainClient) checkTSSVout(params *types.OutboundTxParams, vouts []b
// the 2nd output is the payment to recipient
receiverExpected = params.Receiver
}
receiverVout, amount, err := DecodeTSSVout(vout, receiverExpected, chain)
receiverVout, amount, err := DecodeTSSVout(vout, receiverExpected, ob.chain)
if err != nil {
return err
}
// 1st vout: nonce-mark
if vout.N == 0 {
switch vout.N {
case 0: // 1st vout: nonce-mark
if receiverVout != tssAddress {
return fmt.Errorf("checkTSSVout: nonce-mark address %s not match TSS address %s", receiverVout, tssAddress)
}
if amount != common.NonceMarkAmount(nonce) {
return fmt.Errorf("checkTSSVout: nonce-mark amount %d not match nonce-mark amount %d", amount, common.NonceMarkAmount(nonce))
}
}
// 2nd vout: payment to recipient
if vout.N == 1 {
case 1: // 2nd vout: payment to recipient
if receiverVout != params.Receiver {
return fmt.Errorf("checkTSSVout: output address %s not match params receiver %s", receiverVout, params.Receiver)
}
// #nosec G701 always positive
if uint64(amount) != params.Amount.Uint64() {
return fmt.Errorf("checkTSSVout: output amount %d not match params amount %d", amount, params.Amount)
}
}
// 3rd vout: change to TSS (optional)
if vout.N == 2 {
case 2: // 3rd vout: change to TSS (optional)
if receiverVout != tssAddress {
return fmt.Errorf("checkTSSVout: change address %s not match TSS address %s", receiverVout, tssAddress)
}
Expand All @@ -1394,7 +1390,7 @@ func (ob *BTCChainClient) checkTSSVout(params *types.OutboundTxParams, vouts []b
// checkTSSVoutCancelled vout is valid if:
// - The first output is the nonce-mark
// - The second output is the change to TSS (optional)
func (ob *BTCChainClient) checkTSSVoutCancelled(params *types.OutboundTxParams, vouts []btcjson.Vout, chain common.Chain) error {
func (ob *BTCChainClient) checkTSSVoutCancelled(params *types.OutboundTxParams, vouts []btcjson.Vout) error {
// vouts: [nonce-mark, change to TSS (optional)]
if !(len(vouts) == 1 || len(vouts) == 2) {
return fmt.Errorf("checkTSSVoutCancelled: invalid number of vouts: %d", len(vouts))
Expand All @@ -1404,21 +1400,19 @@ func (ob *BTCChainClient) checkTSSVoutCancelled(params *types.OutboundTxParams,
tssAddress := ob.Tss.BTCAddress()
for _, vout := range vouts {
// decode receiver and amount from vout
receiverVout, amount, err := DecodeTSSVout(vout, tssAddress, chain)
receiverVout, amount, err := DecodeTSSVout(vout, tssAddress, ob.chain)
if err != nil {
return errors.Wrap(err, "checkTSSVoutCancelled: error decoding P2WPKH vout")
}
// 1st vout: nonce-mark
if vout.N == 0 {
switch vout.N {
case 0: // 1st vout: nonce-mark
if receiverVout != tssAddress {
return fmt.Errorf("checkTSSVoutCancelled: nonce-mark address %s not match TSS address %s", receiverVout, tssAddress)
}
if amount != common.NonceMarkAmount(nonce) {
return fmt.Errorf("checkTSSVoutCancelled: nonce-mark amount %d not match nonce-mark amount %d", amount, common.NonceMarkAmount(nonce))
}
}
// 2nd vout: change to TSS (optional)
if vout.N == 2 {
case 1: // 2nd vout: change to TSS (optional)
if receiverVout != tssAddress {
return fmt.Errorf("checkTSSVoutCancelled: change address %s not match TSS address %s", receiverVout, tssAddress)
}
Expand Down
31 changes: 16 additions & 15 deletions zetaclient/bitcoin/bitcoin_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,17 +236,17 @@ func TestCheckTSSVout(t *testing.T) {
t.Run("valid TSS vout should pass", func(t *testing.T) {
rawResult, cctx := testutils.LoadBTCTxRawResultNCctx(t, chainID, nonce)
params := cctx.GetCurrentOutTxParam()
err := btcClient.checkTSSVout(params, rawResult.Vout, chain)
err := btcClient.checkTSSVout(params, rawResult.Vout)
require.NoError(t, err)
})
t.Run("should fail if vout length < 2 or > 3", func(t *testing.T) {
_, cctx := testutils.LoadBTCTxRawResultNCctx(t, chainID, nonce)
params := cctx.GetCurrentOutTxParam()

err := btcClient.checkTSSVout(params, []btcjson.Vout{{}}, chain)
err := btcClient.checkTSSVout(params, []btcjson.Vout{{}})
require.ErrorContains(t, err, "invalid number of vouts")

err = btcClient.checkTSSVout(params, []btcjson.Vout{{}, {}, {}, {}}, chain)
err = btcClient.checkTSSVout(params, []btcjson.Vout{{}, {}, {}, {}})
require.ErrorContains(t, err, "invalid number of vouts")
})
t.Run("should fail on invalid TSS vout", func(t *testing.T) {
Expand All @@ -255,7 +255,7 @@ func TestCheckTSSVout(t *testing.T) {

// invalid TSS vout
rawResult.Vout[0].ScriptPubKey.Hex = "invalid script"
err := btcClient.checkTSSVout(params, rawResult.Vout, chain)
err := btcClient.checkTSSVout(params, rawResult.Vout)
require.Error(t, err)
})
t.Run("should fail if vout 0 is not to the TSS address", func(t *testing.T) {
Expand All @@ -264,7 +264,7 @@ func TestCheckTSSVout(t *testing.T) {

// not TSS address, bc1qh297vdt8xq6df5xae9z8gzd4jsu9a392mp0dus
rawResult.Vout[0].ScriptPubKey.Hex = "0014ba8be635673034d4d0ddc9447409b594385ec4aa"
err := btcClient.checkTSSVout(params, rawResult.Vout, chain)
err := btcClient.checkTSSVout(params, rawResult.Vout)
require.ErrorContains(t, err, "not match TSS address")
})
t.Run("should fail if vout 0 not match nonce mark", func(t *testing.T) {
Expand All @@ -273,7 +273,7 @@ func TestCheckTSSVout(t *testing.T) {

// not match nonce mark
rawResult.Vout[0].Value = 0.00000147
err := btcClient.checkTSSVout(params, rawResult.Vout, chain)
err := btcClient.checkTSSVout(params, rawResult.Vout)
require.ErrorContains(t, err, "not match nonce-mark amount")
})
t.Run("should fail if vout 1 is not to the receiver address", func(t *testing.T) {
Expand All @@ -282,7 +282,7 @@ func TestCheckTSSVout(t *testing.T) {

// not receiver address, bc1qh297vdt8xq6df5xae9z8gzd4jsu9a392mp0dus
rawResult.Vout[1].ScriptPubKey.Hex = "0014ba8be635673034d4d0ddc9447409b594385ec4aa"
err := btcClient.checkTSSVout(params, rawResult.Vout, chain)
err := btcClient.checkTSSVout(params, rawResult.Vout)
require.ErrorContains(t, err, "not match params receiver")
})
t.Run("should fail if vout 1 not match payment amount", func(t *testing.T) {
Expand All @@ -291,7 +291,7 @@ func TestCheckTSSVout(t *testing.T) {

// not match payment amount
rawResult.Vout[1].Value = 0.00011000
err := btcClient.checkTSSVout(params, rawResult.Vout, chain)
err := btcClient.checkTSSVout(params, rawResult.Vout)
require.ErrorContains(t, err, "not match params amount")
})
t.Run("should fail if vout 2 is not to the TSS address", func(t *testing.T) {
Expand All @@ -300,7 +300,7 @@ func TestCheckTSSVout(t *testing.T) {

// not TSS address, bc1qh297vdt8xq6df5xae9z8gzd4jsu9a392mp0dus
rawResult.Vout[2].ScriptPubKey.Hex = "0014ba8be635673034d4d0ddc9447409b594385ec4aa"
err := btcClient.checkTSSVout(params, rawResult.Vout, chain)
err := btcClient.checkTSSVout(params, rawResult.Vout)
require.ErrorContains(t, err, "not match TSS address")
})
}
Expand All @@ -322,17 +322,17 @@ func TestCheckTSSVoutCancelled(t *testing.T) {
rawResult.Vout = rawResult.Vout[:2]
params := cctx.GetCurrentOutTxParam()

err := btcClient.checkTSSVoutCancelled(params, rawResult.Vout, chain)
err := btcClient.checkTSSVoutCancelled(params, rawResult.Vout)
require.NoError(t, err)
})
t.Run("should fail if vout length < 1 or > 2", func(t *testing.T) {
_, cctx := testutils.LoadBTCTxRawResultNCctx(t, chainID, nonce)
params := cctx.GetCurrentOutTxParam()

err := btcClient.checkTSSVoutCancelled(params, []btcjson.Vout{}, chain)
err := btcClient.checkTSSVoutCancelled(params, []btcjson.Vout{})
require.ErrorContains(t, err, "invalid number of vouts")

err = btcClient.checkTSSVoutCancelled(params, []btcjson.Vout{{}, {}, {}}, chain)
err = btcClient.checkTSSVoutCancelled(params, []btcjson.Vout{{}, {}, {}})
require.ErrorContains(t, err, "invalid number of vouts")
})
t.Run("should fail if vout 0 is not to the TSS address", func(t *testing.T) {
Expand All @@ -344,7 +344,7 @@ func TestCheckTSSVoutCancelled(t *testing.T) {

// not TSS address, bc1qh297vdt8xq6df5xae9z8gzd4jsu9a392mp0dus
rawResult.Vout[0].ScriptPubKey.Hex = "0014ba8be635673034d4d0ddc9447409b594385ec4aa"
err := btcClient.checkTSSVoutCancelled(params, rawResult.Vout, chain)
err := btcClient.checkTSSVoutCancelled(params, rawResult.Vout)
require.ErrorContains(t, err, "not match TSS address")
})
t.Run("should fail if vout 0 not match nonce mark", func(t *testing.T) {
Expand All @@ -356,19 +356,20 @@ func TestCheckTSSVoutCancelled(t *testing.T) {

// not match nonce mark
rawResult.Vout[0].Value = 0.00000147
err := btcClient.checkTSSVoutCancelled(params, rawResult.Vout, chain)
err := btcClient.checkTSSVoutCancelled(params, rawResult.Vout)
require.ErrorContains(t, err, "not match nonce-mark amount")
})
t.Run("should fail if vout 1 is not to the TSS address", func(t *testing.T) {
// remove change vout to simulate cancelled tx
rawResult, cctx := testutils.LoadBTCTxRawResultNCctx(t, chainID, nonce)
rawResult.Vout[1] = rawResult.Vout[2]
rawResult.Vout[1].N = 1 // swap vout index
rawResult.Vout = rawResult.Vout[:2]
params := cctx.GetCurrentOutTxParam()

// not TSS address, bc1qh297vdt8xq6df5xae9z8gzd4jsu9a392mp0dus
rawResult.Vout[1].ScriptPubKey.Hex = "0014ba8be635673034d4d0ddc9447409b594385ec4aa"
err := btcClient.checkTSSVoutCancelled(params, rawResult.Vout, chain)
err := btcClient.checkTSSVoutCancelled(params, rawResult.Vout)
require.ErrorContains(t, err, "not match TSS address")
})
}
Expand Down

0 comments on commit f02cc52

Please sign in to comment.