Skip to content

Commit

Permalink
Merge pull request #8512 from lightningnetwork/rbf-coop-fsm
Browse files Browse the repository at this point in the history
[3/4] - lnwallet/chancloser: add new protofsm based RBF chan closer
  • Loading branch information
Roasbeef authored Dec 10, 2024
2 parents 7a34015 + 3c5f96d commit d6eeaec
Show file tree
Hide file tree
Showing 17 changed files with 4,160 additions and 103 deletions.
12 changes: 2 additions & 10 deletions lnwallet/chancloser/chancloser.go
Original file line number Diff line number Diff line change
Expand Up @@ -539,8 +539,8 @@ func (c *ChanCloser) AuxOutputs() fn.Option[AuxCloseOutputs] {
// upfront script is set, we check whether it matches the script provided by
// our peer. If they do not match, we use the disconnect function provided to
// disconnect from the peer.
func validateShutdownScript(disconnect func() error, upfrontScript,
peerScript lnwire.DeliveryAddress, netParams *chaincfg.Params) error {
func validateShutdownScript(upfrontScript, peerScript lnwire.DeliveryAddress,
netParams *chaincfg.Params) error {

// Either way, we'll make sure that the script passed meets our
// standards. The upfrontScript should have already been checked at an
Expand Down Expand Up @@ -568,12 +568,6 @@ func validateShutdownScript(disconnect func() error, upfrontScript,
chancloserLog.Warnf("peer's script: %x does not match upfront "+
"shutdown script: %x", peerScript, upfrontScript)

// Disconnect from the peer because they have violated option upfront
// shutdown.
if err := disconnect(); err != nil {
return err
}

return ErrUpfrontShutdownScriptMismatch
}

Expand Down Expand Up @@ -630,7 +624,6 @@ func (c *ChanCloser) ReceiveShutdown(msg lnwire.Shutdown) (
// If the remote node opened the channel with option upfront
// shutdown script, check that the script they provided matches.
if err := validateShutdownScript(
c.cfg.Disconnect,
c.cfg.Channel.RemoteUpfrontShutdownScript(),
msg.Address, c.cfg.ChainParams,
); err != nil {
Expand Down Expand Up @@ -681,7 +674,6 @@ func (c *ChanCloser) ReceiveShutdown(msg lnwire.Shutdown) (
// If the remote node opened the channel with option upfront
// shutdown script, check that the script they provided matches.
if err := validateShutdownScript(
c.cfg.Disconnect,
c.cfg.Channel.RemoteUpfrontShutdownScript(),
msg.Address, c.cfg.ChainParams,
); err != nil {
Expand Down
6 changes: 3 additions & 3 deletions lnwallet/chancloser/chancloser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ func TestMaybeMatchScript(t *testing.T) {
t.Parallel()

err := validateShutdownScript(
func() error { return nil }, test.upfrontScript,
test.shutdownScript, &chaincfg.SimNetParams,
test.upfrontScript, test.shutdownScript,
&chaincfg.SimNetParams,
)

if err != test.expectedErr {
Expand Down Expand Up @@ -189,7 +189,7 @@ func (m *mockChannel) RemoteUpfrontShutdownScript() lnwire.DeliveryAddress {

func (m *mockChannel) CreateCloseProposal(fee btcutil.Amount,
localScript, remoteScript []byte,
_ ...lnwallet.ChanCloseOpt) (input.Signature, *chainhash.Hash,
_ ...lnwallet.ChanCloseOpt) (input.Signature, *wire.MsgTx,
btcutil.Amount, error) {

if m.chanType.IsTaproot() {
Expand Down
3 changes: 1 addition & 2 deletions lnwallet/chancloser/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package chancloser
import (
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn/v2"
Expand Down Expand Up @@ -100,7 +99,7 @@ type Channel interface { //nolint:interfacebloat
localDeliveryScript []byte, remoteDeliveryScript []byte,
closeOpt ...lnwallet.ChanCloseOpt,
) (
input.Signature, *chainhash.Hash, btcutil.Amount, error)
input.Signature, *wire.MsgTx, btcutil.Amount, error)

// CompleteCooperativeClose persistently "completes" the cooperative
// close by producing a fully signed co-op close transaction.
Expand Down
176 changes: 176 additions & 0 deletions lnwallet/chancloser/mock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
package chancloser

import (
"sync/atomic"

"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/stretchr/testify/mock"
)

type dummyAdapters struct {
mock.Mock

msgSent atomic.Bool

confChan chan *chainntnfs.TxConfirmation
spendChan chan *chainntnfs.SpendDetail
}

func newDaemonAdapters() *dummyAdapters {
return &dummyAdapters{
confChan: make(chan *chainntnfs.TxConfirmation, 1),
spendChan: make(chan *chainntnfs.SpendDetail, 1),
}
}

func (d *dummyAdapters) SendMessages(pub btcec.PublicKey,
msgs []lnwire.Message) error {

defer d.msgSent.Store(true)

args := d.Called(pub, msgs)

return args.Error(0)
}

func (d *dummyAdapters) BroadcastTransaction(tx *wire.MsgTx,
label string) error {

args := d.Called(tx, label)

return args.Error(0)
}

func (d *dummyAdapters) DisableChannel(op wire.OutPoint) error {
args := d.Called(op)

return args.Error(0)
}

func (d *dummyAdapters) RegisterConfirmationsNtfn(txid *chainhash.Hash,
pkScript []byte, numConfs, heightHint uint32,
opts ...chainntnfs.NotifierOption,
) (*chainntnfs.ConfirmationEvent, error) {

args := d.Called(txid, pkScript, numConfs)

err := args.Error(0)

return &chainntnfs.ConfirmationEvent{
Confirmed: d.confChan,
}, err
}

func (d *dummyAdapters) RegisterSpendNtfn(outpoint *wire.OutPoint,
pkScript []byte, heightHint uint32) (*chainntnfs.SpendEvent, error) {

args := d.Called(outpoint, pkScript, heightHint)

err := args.Error(0)

return &chainntnfs.SpendEvent{
Spend: d.spendChan,
}, err
}

type mockFeeEstimator struct {
mock.Mock
}

func (m *mockFeeEstimator) EstimateFee(chanType channeldb.ChannelType,
localTxOut, remoteTxOut *wire.TxOut,
idealFeeRate chainfee.SatPerKWeight) btcutil.Amount {

args := m.Called(chanType, localTxOut, remoteTxOut, idealFeeRate)
return args.Get(0).(btcutil.Amount)
}

type mockChanObserver struct {
mock.Mock
}

func (m *mockChanObserver) NoDanglingUpdates() bool {
args := m.Called()
return args.Bool(0)
}

func (m *mockChanObserver) DisableIncomingAdds() error {
args := m.Called()
return args.Error(0)
}

func (m *mockChanObserver) DisableOutgoingAdds() error {
args := m.Called()
return args.Error(0)
}

func (m *mockChanObserver) MarkCoopBroadcasted(txn *wire.MsgTx,
local bool) error {

args := m.Called(txn, local)
return args.Error(0)
}

func (m *mockChanObserver) MarkShutdownSent(deliveryAddr []byte,
isInitiator bool) error {

args := m.Called(deliveryAddr, isInitiator)
return args.Error(0)
}

func (m *mockChanObserver) FinalBalances() fn.Option[ShutdownBalances] {
args := m.Called()
return args.Get(0).(fn.Option[ShutdownBalances])
}

func (m *mockChanObserver) DisableChannel() error {
args := m.Called()
return args.Error(0)
}

type mockErrorReporter struct {
mock.Mock
}

func (m *mockErrorReporter) ReportError(err error) {
m.Called(err)
}

type mockCloseSigner struct {
mock.Mock
}

func (m *mockCloseSigner) CreateCloseProposal(fee btcutil.Amount,
localScript []byte, remoteScript []byte,
closeOpt ...lnwallet.ChanCloseOpt) (
input.Signature, *wire.MsgTx, btcutil.Amount, error) {

args := m.Called(fee, localScript, remoteScript, closeOpt)

return args.Get(0).(input.Signature), args.Get(1).(*wire.MsgTx),
args.Get(2).(btcutil.Amount), args.Error(3)
}

func (m *mockCloseSigner) CompleteCooperativeClose(localSig,
remoteSig input.Signature,
localScript, remoteScript []byte,
fee btcutil.Amount, closeOpt ...lnwallet.ChanCloseOpt,
) (*wire.MsgTx, btcutil.Amount, error) {

args := m.Called(
localSig, remoteSig, localScript, remoteScript, fee, closeOpt,
)

return args.Get(0).(*wire.MsgTx), args.Get(1).(btcutil.Amount),
args.Error(2)
}
77 changes: 77 additions & 0 deletions lnwallet/chancloser/rbf_coop_msg_mapper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package chancloser

import (
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/lnwire"
)

// RbfMsgMapper is a struct that implements the MsgMapper interface for the
// rbf-coop close state machine. This enables the state machine to be used with
// protofsm.
type RbfMsgMapper struct {
// blockHeight is the height of the block when the co-op close request
// was initiated. This is used to validate conditions related to the
// thaw height.
blockHeight uint32

// chanID is the channel ID of the channel being closed.
chanID lnwire.ChannelID
}

// NewRbfMsgMapper creates a new RbfMsgMapper instance given the current block
// height when the co-op close request was initiated.
func NewRbfMsgMapper(blockHeight uint32,
chanID lnwire.ChannelID) *RbfMsgMapper {

return &RbfMsgMapper{
blockHeight: blockHeight,
chanID: chanID,
}
}

// someEvent returns the target type as a protocol event option.
func someEvent[T ProtocolEvent](m T) fn.Option[ProtocolEvent] {
return fn.Some(ProtocolEvent(m))
}

// isExpectedChanID returns true if the channel ID of the message matches the
// bound instance.
func (r *RbfMsgMapper) isExpectedChanID(chanID lnwire.ChannelID) bool {
return r.chanID == chanID
}

// MapMsg maps a wire message into a FSM event. If the message is not mappable,
// then an error is returned.
func (r *RbfMsgMapper) MapMsg(wireMsg lnwire.Message) fn.Option[ProtocolEvent] {
switch msg := wireMsg.(type) {
case *lnwire.Shutdown:
if !r.isExpectedChanID(msg.ChannelID) {
return fn.None[ProtocolEvent]()
}

return someEvent(&ShutdownReceived{
BlockHeight: r.blockHeight,
ShutdownScript: msg.Address,
})

case *lnwire.ClosingComplete:
if !r.isExpectedChanID(msg.ChannelID) {
return fn.None[ProtocolEvent]()
}

return someEvent(&OfferReceivedEvent{
SigMsg: *msg,
})

case *lnwire.ClosingSig:
if !r.isExpectedChanID(msg.ChannelID) {
return fn.None[ProtocolEvent]()
}

return someEvent(&LocalSigReceived{
SigMsg: *msg,
})
}

return fn.None[ProtocolEvent]()
}
Loading

0 comments on commit d6eeaec

Please sign in to comment.