diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go index 5ea4d8e4bd..b4226ff2cd 100644 --- a/sweep/fee_bumper.go +++ b/sweep/fee_bumper.go @@ -454,7 +454,7 @@ func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest, case err == nil: // The tx is valid, return the request ID. requestID := t.storeRecord( - sweepCtx.tx, req, f, sweepCtx.fee, + sweepCtx.tx, req, f, sweepCtx.fee, sweepCtx.outpointToTxIndex, ) log.Infof("Created tx %v for %v inputs: feerate=%v, "+ @@ -510,7 +510,8 @@ func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest, // storeRecord stores the given record in the records map. func (t *TxPublisher) storeRecord(tx *wire.MsgTx, req *BumpRequest, - f FeeFunction, fee btcutil.Amount) uint64 { + f FeeFunction, fee btcutil.Amount, + outpointToTxIndex map[wire.OutPoint]int) uint64 { // Increase the request counter. // @@ -520,10 +521,11 @@ func (t *TxPublisher) storeRecord(tx *wire.MsgTx, req *BumpRequest, // Register the record. t.records.Store(requestID, &monitorRecord{ - tx: tx, - req: req, - feeFunction: f, - fee: fee, + tx: tx, + req: req, + feeFunction: f, + fee: fee, + outpointToTxIndex: outpointToTxIndex, }) return requestID diff --git a/sweep/fee_bumper_test.go b/sweep/fee_bumper_test.go index 53b38607f7..688d5878fc 100644 --- a/sweep/fee_bumper_test.go +++ b/sweep/fee_bumper_test.go @@ -323,8 +323,16 @@ func TestStoreRecord(t *testing.T) { // Get the current counter and check it's increased later. initialCounter := tp.requestCounter.Load() + op := wire.OutPoint{ + Hash: chainhash.Hash{1}, + Index: 0, + } + utxoIndex := map[wire.OutPoint]int{ + op: 0, + } + // Call the method under test. - requestID := tp.storeRecord(tx, req, feeFunc, fee) + requestID := tp.storeRecord(tx, req, feeFunc, fee, utxoIndex) // Check the request ID is as expected. require.Equal(t, initialCounter+1, requestID) @@ -336,6 +344,7 @@ func TestStoreRecord(t *testing.T) { require.Equal(t, feeFunc, record.feeFunction) require.Equal(t, fee, record.fee) require.Equal(t, req, record.req) + require.Equal(t, utxoIndex, record.outpointToTxIndex) } // mockers wraps a list of mocked interfaces used inside tx publisher. @@ -665,9 +674,17 @@ func TestTxPublisherBroadcast(t *testing.T) { feerate := chainfee.SatPerKWeight(1000) m.feeFunc.On("FeeRate").Return(feerate) + op := wire.OutPoint{ + Hash: chainhash.Hash{1}, + Index: 0, + } + utxoIndex := map[wire.OutPoint]int{ + op: 0, + } + // Create a testing record and put it in the map. fee := btcutil.Amount(1000) - requestID := tp.storeRecord(tx, req, m.feeFunc, fee) + requestID := tp.storeRecord(tx, req, m.feeFunc, fee, utxoIndex) // Quickly check when the requestID cannot be found, an error is // returned. @@ -754,6 +771,14 @@ func TestRemoveResult(t *testing.T) { // Create a testing record and put it in the map. fee := btcutil.Amount(1000) + op := wire.OutPoint{ + Hash: chainhash.Hash{1}, + Index: 0, + } + utxoIndex := map[wire.OutPoint]int{ + op: 0, + } + testCases := []struct { name string setupRecord func() uint64 @@ -765,7 +790,7 @@ func TestRemoveResult(t *testing.T) { // removed. name: "remove on TxConfirmed", setupRecord: func() uint64 { - id := tp.storeRecord(tx, req, m.feeFunc, fee) + id := tp.storeRecord(tx, req, m.feeFunc, fee, utxoIndex) tp.subscriberChans.Store(id, nil) return id @@ -780,7 +805,7 @@ func TestRemoveResult(t *testing.T) { // When the tx is failed, the records will be removed. name: "remove on TxFailed", setupRecord: func() uint64 { - id := tp.storeRecord(tx, req, m.feeFunc, fee) + id := tp.storeRecord(tx, req, m.feeFunc, fee, utxoIndex) tp.subscriberChans.Store(id, nil) return id @@ -796,7 +821,7 @@ func TestRemoveResult(t *testing.T) { // Noop when the tx is neither confirmed or failed. name: "noop when tx is not confirmed or failed", setupRecord: func() uint64 { - id := tp.storeRecord(tx, req, m.feeFunc, fee) + id := tp.storeRecord(tx, req, m.feeFunc, fee, utxoIndex) tp.subscriberChans.Store(id, nil) return id @@ -844,9 +869,17 @@ func TestNotifyResult(t *testing.T) { // Create a test tx. tx := &wire.MsgTx{LockTime: 1} + op := wire.OutPoint{ + Hash: chainhash.Hash{1}, + Index: 0, + } + utxoIndex := map[wire.OutPoint]int{ + op: 0, + } + // Create a testing record and put it in the map. fee := btcutil.Amount(1000) - requestID := tp.storeRecord(tx, req, m.feeFunc, fee) + requestID := tp.storeRecord(tx, req, m.feeFunc, fee, utxoIndex) // Create a subscription to the event. subscriber := make(chan *BumpResult, 1) @@ -1201,9 +1234,17 @@ func TestHandleTxConfirmed(t *testing.T) { // Create a test tx. tx := &wire.MsgTx{LockTime: 1} + op := wire.OutPoint{ + Hash: chainhash.Hash{1}, + Index: 0, + } + utxoIndex := map[wire.OutPoint]int{ + op: 0, + } + // Create a testing record and put it in the map. fee := btcutil.Amount(1000) - requestID := tp.storeRecord(tx, req, m.feeFunc, fee) + requestID := tp.storeRecord(tx, req, m.feeFunc, fee, utxoIndex) record, ok := tp.records.Load(requestID) require.True(t, ok) @@ -1273,9 +1314,17 @@ func TestHandleFeeBumpTx(t *testing.T) { tx: tx, } + op := wire.OutPoint{ + Hash: chainhash.Hash{1}, + Index: 0, + } + utxoIndex := map[wire.OutPoint]int{ + op: 0, + } + // Create a testing record and put it in the map. fee := btcutil.Amount(1000) - requestID := tp.storeRecord(tx, req, m.feeFunc, fee) + requestID := tp.storeRecord(tx, req, m.feeFunc, fee, utxoIndex) // Create a subscription to the event. subscriber := make(chan *BumpResult, 1)