diff --git a/config_builder.go b/config_builder.go index 42650bb68b..ddfc97dc2c 100644 --- a/config_builder.go +++ b/config_builder.go @@ -36,6 +36,7 @@ import ( "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/funding" graphdb "github.com/lightningnetwork/lnd/graph/db" + "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" @@ -47,7 +48,6 @@ import ( "github.com/lightningnetwork/lnd/lnwallet/rpcwallet" "github.com/lightningnetwork/lnd/macaroons" "github.com/lightningnetwork/lnd/msgmux" - "github.com/lightningnetwork/lnd/routing" "github.com/lightningnetwork/lnd/rpcperms" "github.com/lightningnetwork/lnd/signal" "github.com/lightningnetwork/lnd/sqldb" @@ -166,7 +166,7 @@ type AuxComponents struct { // TrafficShaper is an optional traffic shaper that can be used to // control the outgoing channel of a payment. - TrafficShaper fn.Option[routing.TlvTrafficShaper] + TrafficShaper fn.Option[htlcswitch.AuxTrafficShaper] // MsgRouter is an optional message router that if set will be used in // place of a new blank default message router. diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index d8f55afc69..caf2abf1ae 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -206,6 +206,11 @@ const ( Outgoing LinkDirection = true ) +// OptionalBandwidth is a type alias for the result of a bandwidth query that +// may return a bandwidth value or fn.None if the bandwidth is not available or +// not applicable. +type OptionalBandwidth = fn.Option[lnwire.MilliSatoshi] + // ChannelLink is an interface which represents the subsystem for managing the // incoming htlc requests, applying the changes to the channel, and also // propagating/forwarding it to htlc switch. @@ -267,10 +272,10 @@ type ChannelLink interface { // in order to signal to the source of the HTLC, the policy consistency // issue. CheckHtlcForward(payHash [32]byte, incomingAmt lnwire.MilliSatoshi, - amtToForward lnwire.MilliSatoshi, - incomingTimeout, outgoingTimeout uint32, - inboundFee models.InboundFee, - heightNow uint32, scid lnwire.ShortChannelID) *LinkError + amtToForward lnwire.MilliSatoshi, incomingTimeout, + outgoingTimeout uint32, inboundFee models.InboundFee, + heightNow uint32, scid lnwire.ShortChannelID, + customRecords lnwire.CustomRecords) *LinkError // CheckHtlcTransit should return a nil error if the passed HTLC details // satisfy the current channel policy. Otherwise, a LinkError with a @@ -278,14 +283,15 @@ type ChannelLink interface { // the violation. This call is intended to be used for locally initiated // payments for which there is no corresponding incoming htlc. CheckHtlcTransit(payHash [32]byte, amt lnwire.MilliSatoshi, - timeout uint32, heightNow uint32) *LinkError + timeout uint32, heightNow uint32, + customRecords lnwire.CustomRecords) *LinkError // Stats return the statistics of channel link. Number of updates, // total sent/received milli-satoshis. Stats() (uint64, lnwire.MilliSatoshi, lnwire.MilliSatoshi) - // Peer returns the serialized public key of remote peer with which we - // have the channel link opened. + // PeerPubKey returns the serialized public key of remote peer with + // which we have the channel link opened. PeerPubKey() [33]byte // AttachMailBox delivers an active MailBox to the link. The MailBox may @@ -302,9 +308,18 @@ type ChannelLink interface { // commitment of the channel that this link is associated with. CommitmentCustomBlob() fn.Option[tlv.Blob] - // Start/Stop are used to initiate the start/stop of the channel link - // functioning. + // AuxBandwidth returns the bandwidth that can be used for a channel, + // expressed in milli-satoshi. This might be different from the regular + // BTC bandwidth for custom channels. This will always return fn.None() + // for a regular (non-custom) channel. + AuxBandwidth(amount lnwire.MilliSatoshi, cid lnwire.ShortChannelID, + htlcBlob fn.Option[tlv.Blob], + ts AuxTrafficShaper) fn.Result[OptionalBandwidth] + + // Start starts the channel link. Start() error + + // Stop requests the channel link to be shut down. Stop() } @@ -440,7 +455,7 @@ type htlcNotifier interface { NotifyForwardingEvent(key HtlcKey, info HtlcInfo, eventType HtlcEventType) - // NotifyIncomingLinkFailEvent notifies that a htlc has failed on our + // NotifyLinkFailEvent notifies that a htlc has failed on our // incoming link. It takes an isReceive bool to differentiate between // our node's receives and forwards. NotifyLinkFailEvent(key HtlcKey, info HtlcInfo, @@ -461,3 +476,36 @@ type htlcNotifier interface { NotifyFinalHtlcEvent(key models.CircuitKey, info channeldb.FinalHtlcInfo) } + +// AuxHtlcModifier is an interface that allows the sender to modify the outgoing +// HTLC of a payment by changing the amount or the wire message tlv records. +type AuxHtlcModifier interface { + // ProduceHtlcExtraData is a function that, based on the previous extra + // data blob of an HTLC, may produce a different blob or modify the + // amount of bitcoin this htlc should carry. + ProduceHtlcExtraData(totalAmount lnwire.MilliSatoshi, + htlcCustomRecords lnwire.CustomRecords) (lnwire.MilliSatoshi, + lnwire.CustomRecords, error) +} + +// AuxTrafficShaper is an interface that allows the sender to determine if a +// payment should be carried by a channel based on the TLV records that may be +// present in the `update_add_htlc` message or the channel commitment itself. +type AuxTrafficShaper interface { + AuxHtlcModifier + + // ShouldHandleTraffic is called in order to check if the channel + // identified by the provided channel ID may have external mechanisms + // that would allow it to carry out the payment. + ShouldHandleTraffic(cid lnwire.ShortChannelID, + fundingBlob fn.Option[tlv.Blob]) (bool, error) + + // PaymentBandwidth returns the available bandwidth for a custom channel + // decided by the given channel aux blob and HTLC blob. A return value + // of 0 means there is no bandwidth available. To find out if a channel + // is a custom channel that should be handled by the traffic shaper, the + // ShouldHandleTraffic method should be called first. + PaymentBandwidth(htlcBlob, commitmentBlob fn.Option[tlv.Blob], + linkBandwidth, + htlcAmt lnwire.MilliSatoshi) (lnwire.MilliSatoshi, error) +} diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 60062862ef..6a577e5694 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -293,6 +293,10 @@ type ChannelLinkConfig struct { // ShouldFwdExpEndorsement is a closure that indicates whether the link // should forward experimental endorsement signals. ShouldFwdExpEndorsement func() bool + + // AuxTrafficShaper is an optional auxiliary traffic shaper that can be + // used to manage the bandwidth of the link. + AuxTrafficShaper fn.Option[AuxTrafficShaper] } // channelLink is the service which drives a channel's commitment update @@ -3233,11 +3237,11 @@ func (l *channelLink) UpdateForwardingPolicy( // issue. // // NOTE: Part of the ChannelLink interface. -func (l *channelLink) CheckHtlcForward(payHash [32]byte, - incomingHtlcAmt, amtToForward lnwire.MilliSatoshi, - incomingTimeout, outgoingTimeout uint32, - inboundFee models.InboundFee, - heightNow uint32, originalScid lnwire.ShortChannelID) *LinkError { +func (l *channelLink) CheckHtlcForward(payHash [32]byte, incomingHtlcAmt, + amtToForward lnwire.MilliSatoshi, incomingTimeout, + outgoingTimeout uint32, inboundFee models.InboundFee, + heightNow uint32, originalScid lnwire.ShortChannelID, + customRecords lnwire.CustomRecords) *LinkError { l.RLock() policy := l.cfg.FwrdingPolicy @@ -3286,7 +3290,7 @@ func (l *channelLink) CheckHtlcForward(payHash [32]byte, // Check whether the outgoing htlc satisfies the channel policy. err := l.canSendHtlc( policy, payHash, amtToForward, outgoingTimeout, heightNow, - originalScid, + originalScid, customRecords, ) if err != nil { return err @@ -3322,8 +3326,8 @@ func (l *channelLink) CheckHtlcForward(payHash [32]byte, // the violation. This call is intended to be used for locally initiated // payments for which there is no corresponding incoming htlc. func (l *channelLink) CheckHtlcTransit(payHash [32]byte, - amt lnwire.MilliSatoshi, timeout uint32, - heightNow uint32) *LinkError { + amt lnwire.MilliSatoshi, timeout uint32, heightNow uint32, + customRecords lnwire.CustomRecords) *LinkError { l.RLock() policy := l.cfg.FwrdingPolicy @@ -3334,6 +3338,7 @@ func (l *channelLink) CheckHtlcTransit(payHash [32]byte, // to occur. return l.canSendHtlc( policy, payHash, amt, timeout, heightNow, hop.Source, + customRecords, ) } @@ -3341,7 +3346,8 @@ func (l *channelLink) CheckHtlcTransit(payHash [32]byte, // the channel's amount and time lock constraints. func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy, payHash [32]byte, amt lnwire.MilliSatoshi, timeout uint32, - heightNow uint32, originalScid lnwire.ShortChannelID) *LinkError { + heightNow uint32, originalScid lnwire.ShortChannelID, + customRecords lnwire.CustomRecords) *LinkError { // As our first sanity check, we'll ensure that the passed HTLC isn't // too small for the next hop. If so, then we'll cancel the HTLC @@ -3399,8 +3405,38 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy, return NewLinkError(&lnwire.FailExpiryTooFar{}) } + // We now check the available bandwidth to see if this HTLC can be + // forwarded. + availableBandwidth := l.Bandwidth() + auxBandwidth, err := fn.MapOptionZ( + l.cfg.AuxTrafficShaper, + func(ts AuxTrafficShaper) fn.Result[OptionalBandwidth] { + var htlcBlob fn.Option[tlv.Blob] + blob, err := customRecords.Serialize() + if err != nil { + return fn.Err[OptionalBandwidth]( + fmt.Errorf("unable to serialize "+ + "custom records: %w", err)) + } + + if len(blob) > 0 { + htlcBlob = fn.Some(blob) + } + + return l.AuxBandwidth(amt, originalScid, htlcBlob, ts) + }, + ).Unpack() + if err != nil { + l.log.Errorf("Unable to determine aux bandwidth: %v", err) + return NewLinkError(&lnwire.FailTemporaryNodeFailure{}) + } + + auxBandwidth.WhenSome(func(bandwidth lnwire.MilliSatoshi) { + availableBandwidth = bandwidth + }) + // Check to see if there is enough balance in this channel. - if amt > l.Bandwidth() { + if amt > availableBandwidth { l.log.Warnf("insufficient bandwidth to route htlc: %v is "+ "larger than %v", amt, l.Bandwidth()) cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { @@ -3415,6 +3451,48 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy, return nil } +// AuxBandwidth returns the bandwidth that can be used for a channel, expressed +// in milli-satoshi. This might be different from the regular BTC bandwidth for +// custom channels. This will always return fn.None() for a regular (non-custom) +// channel. +func (l *channelLink) AuxBandwidth(amount lnwire.MilliSatoshi, + cid lnwire.ShortChannelID, htlcBlob fn.Option[tlv.Blob], + ts AuxTrafficShaper) fn.Result[OptionalBandwidth] { + + unknownBandwidth := fn.None[lnwire.MilliSatoshi]() + + fundingBlob := l.FundingCustomBlob() + shouldHandle, err := ts.ShouldHandleTraffic(cid, fundingBlob) + if err != nil { + return fn.Err[OptionalBandwidth](fmt.Errorf("traffic shaper "+ + "failed to decide whether to handle traffic: %w", err)) + } + + log.Debugf("ShortChannelID=%v: aux traffic shaper is handling "+ + "traffic: %v", cid, shouldHandle) + + // If this channel isn't handled by the aux traffic shaper, we'll return + // early. + if !shouldHandle { + return fn.Ok(unknownBandwidth) + } + + // Ask for a specific bandwidth to be used for the channel. + commitmentBlob := l.CommitmentCustomBlob() + auxBandwidth, err := ts.PaymentBandwidth( + htlcBlob, commitmentBlob, l.Bandwidth(), amount, + ) + if err != nil { + return fn.Err[OptionalBandwidth](fmt.Errorf("failed to get "+ + "bandwidth from external traffic shaper: %w", err)) + } + + log.Debugf("ShortChannelID=%v: aux traffic shaper reported available "+ + "bandwidth: %v", cid, auxBandwidth) + + return fn.Ok(fn.Some(auxBandwidth)) +} + // Stats returns the statistics of channel link. // // NOTE: Part of the ChannelLink interface. diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 80632b07e9..938dc2e8a5 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -6243,9 +6243,9 @@ func TestCheckHtlcForward(t *testing.T) { var hash [32]byte t.Run("satisfied", func(t *testing.T) { - result := link.CheckHtlcForward(hash, 1500, 1000, - 200, 150, models.InboundFee{}, 0, - lnwire.ShortChannelID{}, + result := link.CheckHtlcForward( + hash, 1500, 1000, 200, 150, models.InboundFee{}, 0, + lnwire.ShortChannelID{}, nil, ) if result != nil { t.Fatalf("expected policy to be satisfied") @@ -6253,9 +6253,9 @@ func TestCheckHtlcForward(t *testing.T) { }) t.Run("below minhtlc", func(t *testing.T) { - result := link.CheckHtlcForward(hash, 100, 50, - 200, 150, models.InboundFee{}, 0, - lnwire.ShortChannelID{}, + result := link.CheckHtlcForward( + hash, 100, 50, 200, 150, models.InboundFee{}, 0, + lnwire.ShortChannelID{}, nil, ) if _, ok := result.WireMessage().(*lnwire.FailAmountBelowMinimum); !ok { t.Fatalf("expected FailAmountBelowMinimum failure code") @@ -6263,9 +6263,9 @@ func TestCheckHtlcForward(t *testing.T) { }) t.Run("above maxhtlc", func(t *testing.T) { - result := link.CheckHtlcForward(hash, 1500, 1200, - 200, 150, models.InboundFee{}, 0, - lnwire.ShortChannelID{}, + result := link.CheckHtlcForward( + hash, 1500, 1200, 200, 150, models.InboundFee{}, 0, + lnwire.ShortChannelID{}, nil, ) if _, ok := result.WireMessage().(*lnwire.FailTemporaryChannelFailure); !ok { t.Fatalf("expected FailTemporaryChannelFailure failure code") @@ -6273,9 +6273,9 @@ func TestCheckHtlcForward(t *testing.T) { }) t.Run("insufficient fee", func(t *testing.T) { - result := link.CheckHtlcForward(hash, 1005, 1000, - 200, 150, models.InboundFee{}, 0, - lnwire.ShortChannelID{}, + result := link.CheckHtlcForward( + hash, 1005, 1000, 200, 150, models.InboundFee{}, 0, + lnwire.ShortChannelID{}, nil, ) if _, ok := result.WireMessage().(*lnwire.FailFeeInsufficient); !ok { t.Fatalf("expected FailFeeInsufficient failure code") @@ -6288,17 +6288,17 @@ func TestCheckHtlcForward(t *testing.T) { t.Parallel() result := link.CheckHtlcForward( - hash, 100005, 100000, 200, - 150, models.InboundFee{}, 0, lnwire.ShortChannelID{}, + hash, 100005, 100000, 200, 150, models.InboundFee{}, 0, + lnwire.ShortChannelID{}, nil, ) _, ok := result.WireMessage().(*lnwire.FailFeeInsufficient) require.True(t, ok, "expected FailFeeInsufficient failure code") }) t.Run("expiry too soon", func(t *testing.T) { - result := link.CheckHtlcForward(hash, 1500, 1000, - 200, 150, models.InboundFee{}, 190, - lnwire.ShortChannelID{}, + result := link.CheckHtlcForward( + hash, 1500, 1000, 200, 150, models.InboundFee{}, 190, + lnwire.ShortChannelID{}, nil, ) if _, ok := result.WireMessage().(*lnwire.FailExpiryTooSoon); !ok { t.Fatalf("expected FailExpiryTooSoon failure code") @@ -6306,9 +6306,9 @@ func TestCheckHtlcForward(t *testing.T) { }) t.Run("incorrect cltv expiry", func(t *testing.T) { - result := link.CheckHtlcForward(hash, 1500, 1000, - 200, 190, models.InboundFee{}, 0, - lnwire.ShortChannelID{}, + result := link.CheckHtlcForward( + hash, 1500, 1000, 200, 190, models.InboundFee{}, 0, + lnwire.ShortChannelID{}, nil, ) if _, ok := result.WireMessage().(*lnwire.FailIncorrectCltvExpiry); !ok { t.Fatalf("expected FailIncorrectCltvExpiry failure code") @@ -6318,9 +6318,9 @@ func TestCheckHtlcForward(t *testing.T) { t.Run("cltv expiry too far in the future", func(t *testing.T) { // Check that expiry isn't too far in the future. - result := link.CheckHtlcForward(hash, 1500, 1000, - 10200, 10100, models.InboundFee{}, 0, - lnwire.ShortChannelID{}, + result := link.CheckHtlcForward( + hash, 1500, 1000, 10200, 10100, models.InboundFee{}, 0, + lnwire.ShortChannelID{}, nil, ) if _, ok := result.WireMessage().(*lnwire.FailExpiryTooFar); !ok { t.Fatalf("expected FailExpiryTooFar failure code") @@ -6330,9 +6330,11 @@ func TestCheckHtlcForward(t *testing.T) { t.Run("inbound fee satisfied", func(t *testing.T) { t.Parallel() - result := link.CheckHtlcForward(hash, 1000+10-2-1, 1000, - 200, 150, models.InboundFee{Base: -2, Rate: -1_000}, - 0, lnwire.ShortChannelID{}) + result := link.CheckHtlcForward( + hash, 1000+10-2-1, 1000, 200, 150, + models.InboundFee{Base: -2, Rate: -1_000}, + 0, lnwire.ShortChannelID{}, nil, + ) if result != nil { t.Fatalf("expected policy to be satisfied") } @@ -6341,9 +6343,11 @@ func TestCheckHtlcForward(t *testing.T) { t.Run("inbound fee insufficient", func(t *testing.T) { t.Parallel() - result := link.CheckHtlcForward(hash, 1000+10-10-101-1, 1000, + result := link.CheckHtlcForward( + hash, 1000+10-10-101-1, 1000, 200, 150, models.InboundFee{Base: -10, Rate: -100_000}, - 0, lnwire.ShortChannelID{}) + 0, lnwire.ShortChannelID{}, nil, + ) msg := result.WireMessage() if _, ok := msg.(*lnwire.FailFeeInsufficient); !ok { diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index ce791bef32..3c201cd701 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -846,14 +846,14 @@ func (f *mockChannelLink) UpdateForwardingPolicy(_ models.ForwardingPolicy) { } func (f *mockChannelLink) CheckHtlcForward([32]byte, lnwire.MilliSatoshi, lnwire.MilliSatoshi, uint32, uint32, models.InboundFee, uint32, - lnwire.ShortChannelID) *LinkError { + lnwire.ShortChannelID, lnwire.CustomRecords) *LinkError { return f.checkHtlcForwardResult } func (f *mockChannelLink) CheckHtlcTransit(payHash [32]byte, amt lnwire.MilliSatoshi, timeout uint32, - heightNow uint32) *LinkError { + heightNow uint32, _ lnwire.CustomRecords) *LinkError { return f.checkHtlcTransitResult } @@ -968,6 +968,17 @@ func (f *mockChannelLink) CommitmentCustomBlob() fn.Option[tlv.Blob] { return fn.None[tlv.Blob]() } +// AuxBandwidth returns the bandwidth that can be used for a channel, +// expressed in milli-satoshi. This might be different from the regular +// BTC bandwidth for custom channels. This will always return fn.None() +// for a regular (non-custom) channel. +func (f *mockChannelLink) AuxBandwidth(lnwire.MilliSatoshi, + lnwire.ShortChannelID, + fn.Option[tlv.Blob], AuxTrafficShaper) fn.Result[OptionalBandwidth] { + + return fn.Ok(fn.None[lnwire.MilliSatoshi]()) +} + var _ ChannelLink = (*mockChannelLink)(nil) const testInvoiceCltvExpiry = 6 diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index 1a08275ec9..b2c699b140 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -917,6 +917,7 @@ func (s *Switch) getLocalLink(pkt *htlcPacket, htlc *lnwire.UpdateAddHTLC) ( currentHeight := atomic.LoadUint32(&s.bestHeight) htlcErr := link.CheckHtlcTransit( htlc.PaymentHash, htlc.Amount, htlc.Expiry, currentHeight, + htlc.CustomRecords, ) if htlcErr != nil { log.Errorf("Link %v policy for local forward not "+ @@ -2887,10 +2888,9 @@ func (s *Switch) handlePacketAdd(packet *htlcPacket, failure = link.CheckHtlcForward( htlc.PaymentHash, packet.incomingAmount, packet.amount, packet.incomingTimeout, - packet.outgoingTimeout, - packet.inboundFee, - currentHeight, - packet.originalOutgoingChanID, + packet.outgoingTimeout, packet.inboundFee, + currentHeight, packet.originalOutgoingChanID, + htlc.CustomRecords, ) } diff --git a/peer/brontide.go b/peer/brontide.go index 6bc49445ee..7074b10071 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -400,6 +400,10 @@ type Config struct { // way contracts are resolved. AuxResolver fn.Option[lnwallet.AuxContractResolver] + // AuxTrafficShaper is an optional auxiliary traffic shaper that can be + // used to manage the bandwidth of peer links. + AuxTrafficShaper fn.Option[htlcswitch.AuxTrafficShaper] + // PongBuf is a slice we'll reuse instead of allocating memory on the // heap. Since only reads will occur and no writes, there is no need // for any synchronization primitives. As a result, it's safe to share @@ -1330,6 +1334,7 @@ func (p *Brontide) addLink(chanPoint *wire.OutPoint, ShouldFwdExpEndorsement: p.cfg.ShouldFwdExpEndorsement, DisallowQuiescence: p.cfg.DisallowQuiescence || !p.remoteFeatures.HasFeature(lnwire.QuiescenceOptional), + AuxTrafficShaper: p.cfg.AuxTrafficShaper, } // Before adding our new link, purge the switch of any pending or live diff --git a/routing/bandwidth.go b/routing/bandwidth.go index 12e82131dc..eabd66cf82 100644 --- a/routing/bandwidth.go +++ b/routing/bandwidth.go @@ -29,39 +29,6 @@ type bandwidthHints interface { firstHopCustomBlob() fn.Option[tlv.Blob] } -// TlvTrafficShaper is an interface that allows the sender to determine if a -// payment should be carried by a channel based on the TLV records that may be -// present in the `update_add_htlc` message or the channel commitment itself. -type TlvTrafficShaper interface { - AuxHtlcModifier - - // ShouldHandleTraffic is called in order to check if the channel - // identified by the provided channel ID may have external mechanisms - // that would allow it to carry out the payment. - ShouldHandleTraffic(cid lnwire.ShortChannelID, - fundingBlob fn.Option[tlv.Blob]) (bool, error) - - // PaymentBandwidth returns the available bandwidth for a custom channel - // decided by the given channel aux blob and HTLC blob. A return value - // of 0 means there is no bandwidth available. To find out if a channel - // is a custom channel that should be handled by the traffic shaper, the - // HandleTraffic method should be called first. - PaymentBandwidth(htlcBlob, commitmentBlob fn.Option[tlv.Blob], - linkBandwidth, - htlcAmt lnwire.MilliSatoshi) (lnwire.MilliSatoshi, error) -} - -// AuxHtlcModifier is an interface that allows the sender to modify the outgoing -// HTLC of a payment by changing the amount or the wire message tlv records. -type AuxHtlcModifier interface { - // ProduceHtlcExtraData is a function that, based on the previous extra - // data blob of an HTLC, may produce a different blob or modify the - // amount of bitcoin this htlc should carry. - ProduceHtlcExtraData(totalAmount lnwire.MilliSatoshi, - htlcCustomRecords lnwire.CustomRecords) (lnwire.MilliSatoshi, - lnwire.CustomRecords, error) -} - // getLinkQuery is the function signature used to lookup a link. type getLinkQuery func(lnwire.ShortChannelID) ( htlcswitch.ChannelLink, error) @@ -73,7 +40,7 @@ type bandwidthManager struct { getLink getLinkQuery localChans map[lnwire.ShortChannelID]struct{} firstHopBlob fn.Option[tlv.Blob] - trafficShaper fn.Option[TlvTrafficShaper] + trafficShaper fn.Option[htlcswitch.AuxTrafficShaper] } // newBandwidthManager creates a bandwidth manager for the source node provided @@ -84,13 +51,14 @@ type bandwidthManager struct { // that are inactive, or just don't have enough bandwidth to carry the payment. func newBandwidthManager(graph Graph, sourceNode route.Vertex, linkQuery getLinkQuery, firstHopBlob fn.Option[tlv.Blob], - trafficShaper fn.Option[TlvTrafficShaper]) (*bandwidthManager, error) { + ts fn.Option[htlcswitch.AuxTrafficShaper]) (*bandwidthManager, + error) { manager := &bandwidthManager{ getLink: linkQuery, localChans: make(map[lnwire.ShortChannelID]struct{}), firstHopBlob: firstHopBlob, - trafficShaper: trafficShaper, + trafficShaper: ts, } // First, we'll collect the set of outbound edges from the target @@ -166,44 +134,15 @@ func (b *bandwidthManager) getBandwidth(cid lnwire.ShortChannelID, result, err := fn.MapOptionZ( b.trafficShaper, - func(ts TlvTrafficShaper) fn.Result[bandwidthResult] { - fundingBlob := link.FundingCustomBlob() - shouldHandle, err := ts.ShouldHandleTraffic( - cid, fundingBlob, - ) - if err != nil { - return bandwidthErr(fmt.Errorf("traffic "+ - "shaper failed to decide whether to "+ - "handle traffic: %w", err)) - } - - log.Debugf("ShortChannelID=%v: external traffic "+ - "shaper is handling traffic: %v", cid, - shouldHandle) - - // If this channel isn't handled by the external traffic - // shaper, we'll return early. - if !shouldHandle { - return fn.Ok(bandwidthResult{}) - } - - // Ask for a specific bandwidth to be used for the - // channel. - commitmentBlob := link.CommitmentCustomBlob() - auxBandwidth, err := ts.PaymentBandwidth( - b.firstHopBlob, commitmentBlob, linkBandwidth, - amount, - ) + func(s htlcswitch.AuxTrafficShaper) fn.Result[bandwidthResult] { + auxBandwidth, err := link.AuxBandwidth( + amount, cid, b.firstHopBlob, s, + ).Unpack() if err != nil { return bandwidthErr(fmt.Errorf("failed to get "+ - "bandwidth from external traffic "+ - "shaper: %w", err)) + "auxiliary bandwidth: %w", err)) } - log.Debugf("ShortChannelID=%v: external traffic "+ - "shaper reported available bandwidth: %v", cid, - auxBandwidth) - // We don't know the actual HTLC amount that will be // sent using the custom channel. But we'll still want // to make sure we can add another HTLC, using the @@ -213,7 +152,7 @@ func (b *bandwidthManager) getBandwidth(cid lnwire.ShortChannelID, // the max number of HTLCs on the channel. A proper // balance check is done elsewhere. return fn.Ok(bandwidthResult{ - bandwidth: fn.Some(auxBandwidth), + bandwidth: auxBandwidth, htlcAmount: fn.Some[lnwire.MilliSatoshi](0), }) }, diff --git a/routing/bandwidth_test.go b/routing/bandwidth_test.go index 4872b5a7ec..28b1dfb1ab 100644 --- a/routing/bandwidth_test.go +++ b/routing/bandwidth_test.go @@ -118,7 +118,9 @@ func TestBandwidthManager(t *testing.T) { m, err := newBandwidthManager( g, sourceNode.pubkey, testCase.linkQuery, fn.None[[]byte](), - fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), + fn.Some[htlcswitch.AuxTrafficShaper]( + &mockTrafficShaper{}, + ), ) require.NoError(t, err) diff --git a/routing/mock_test.go b/routing/mock_test.go index 3cdb5ebaf2..3f3f5ea040 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -107,7 +107,7 @@ var _ PaymentSessionSource = (*mockPaymentSessionSourceOld)(nil) func (m *mockPaymentSessionSourceOld) NewPaymentSession( _ *LightningPayment, _ fn.Option[tlv.Blob], - _ fn.Option[TlvTrafficShaper]) (PaymentSession, error) { + _ fn.Option[htlcswitch.AuxTrafficShaper]) (PaymentSession, error) { return &mockPaymentSessionOld{ routes: m.routes, @@ -635,7 +635,8 @@ var _ PaymentSessionSource = (*mockPaymentSessionSource)(nil) func (m *mockPaymentSessionSource) NewPaymentSession( payment *LightningPayment, firstHopBlob fn.Option[tlv.Blob], - tlvShaper fn.Option[TlvTrafficShaper]) (PaymentSession, error) { + tlvShaper fn.Option[htlcswitch.AuxTrafficShaper]) (PaymentSession, + error) { args := m.Called(payment, firstHopBlob, tlvShaper) return args.Get(0).(PaymentSession), args.Error(1) @@ -895,6 +896,19 @@ func (m *mockLink) Bandwidth() lnwire.MilliSatoshi { return m.bandwidth } +// AuxBandwidth returns the bandwidth that can be used for a channel, +// expressed in milli-satoshi. This might be different from the regular +// BTC bandwidth for custom channels. This will always return fn.None() +// for a regular (non-custom) channel. +func (m *mockLink) AuxBandwidth(lnwire.MilliSatoshi, lnwire.ShortChannelID, + fn.Option[tlv.Blob], + htlcswitch.AuxTrafficShaper) fn.Result[htlcswitch.OptionalBandwidth] { + + return fn.Ok[htlcswitch.OptionalBandwidth]( + fn.None[lnwire.MilliSatoshi](), + ) +} + // EligibleToForward returns the mock's configured eligibility. func (m *mockLink) EligibleToForward() bool { return !m.ineligible diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 267ce3965d..6f7034ea6a 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -761,7 +761,8 @@ func (p *paymentLifecycle) amendFirstHopData(rt *route.Route) error { // and apply its side effects to the UpdateAddHTLC message. result, err := fn.MapOptionZ( p.router.cfg.TrafficShaper, - func(ts TlvTrafficShaper) fn.Result[extraDataRequest] { + //nolint:ll + func(ts htlcswitch.AuxTrafficShaper) fn.Result[extraDataRequest] { newAmt, newRecords, err := ts.ProduceHtlcExtraData( rt.TotalAmount, p.firstHopCustomRecords, ) @@ -774,7 +775,7 @@ func (p *paymentLifecycle) amendFirstHopData(rt *route.Route) error { return fn.Err[extraDataRequest](err) } - log.Debugf("TLV traffic shaper returned custom "+ + log.Debugf("Aux traffic shaper returned custom "+ "records %v and amount %d msat for HTLC", spew.Sdump(newRecords), newAmt) diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index 315c1bad58..d566eb9413 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -30,7 +30,7 @@ func createTestPaymentLifecycle() *paymentLifecycle { quitChan := make(chan struct{}) rt := &ChannelRouter{ cfg: &Config{ - TrafficShaper: fn.Some[TlvTrafficShaper]( + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( &mockTrafficShaper{}, ), }, @@ -83,7 +83,7 @@ func newTestPaymentLifecycle(t *testing.T) (*paymentLifecycle, *mockers) { Payer: mockPayer, Clock: mockClock, MissionControl: mockMissionControl, - TrafficShaper: fn.Some[TlvTrafficShaper]( + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( &mockTrafficShaper{}, ), }, diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index d5f1a6af41..daaf7743b5 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -4,6 +4,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/graph/db/models" + "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/tlv" @@ -52,7 +53,8 @@ type SessionSource struct { // payment's destination. func (m *SessionSource) NewPaymentSession(p *LightningPayment, firstHopBlob fn.Option[tlv.Blob], - trafficShaper fn.Option[TlvTrafficShaper]) (PaymentSession, error) { + trafficShaper fn.Option[htlcswitch.AuxTrafficShaper]) (PaymentSession, + error) { getBandwidthHints := func(graph Graph) (bandwidthHints, error) { return newBandwidthManager( diff --git a/routing/router.go b/routing/router.go index 9eabe0b2ae..3405354124 100644 --- a/routing/router.go +++ b/routing/router.go @@ -157,7 +157,7 @@ type PaymentSessionSource interface { // finding a path to the payment's destination. NewPaymentSession(p *LightningPayment, firstHopBlob fn.Option[tlv.Blob], - trafficShaper fn.Option[TlvTrafficShaper]) (PaymentSession, + ts fn.Option[htlcswitch.AuxTrafficShaper]) (PaymentSession, error) // NewPaymentSessionEmpty creates a new paymentSession instance that is @@ -297,7 +297,7 @@ type Config struct { // TrafficShaper is an optional traffic shaper that can be used to // control the outgoing channel of a payment. - TrafficShaper fn.Option[TlvTrafficShaper] + TrafficShaper fn.Option[htlcswitch.AuxTrafficShaper] } // EdgeLocator is a struct used to identify a specific edge. diff --git a/routing/router_test.go b/routing/router_test.go index 2923f1fb90..a69b746f14 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -170,7 +170,7 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, Clock: clock.NewTestClock(time.Unix(1, 0)), ApplyChannelUpdate: graphBuilder.ApplyChannelUpdate, ClosedSCIDs: mockClosedSCIDs, - TrafficShaper: fn.Some[TlvTrafficShaper]( + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( &mockTrafficShaper{}, ), }) @@ -2206,8 +2206,10 @@ func TestSendToRouteSkipTempErrSuccess(t *testing.T) { NextPaymentID: func() (uint64, error) { return 0, nil }, - ClosedSCIDs: mockClosedSCIDs, - TrafficShaper: fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), + ClosedSCIDs: mockClosedSCIDs, + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( + &mockTrafficShaper{}, + ), }} // Register mockers with the expected method calls. @@ -2291,8 +2293,10 @@ func TestSendToRouteSkipTempErrNonMPP(t *testing.T) { NextPaymentID: func() (uint64, error) { return 0, nil }, - ClosedSCIDs: mockClosedSCIDs, - TrafficShaper: fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), + ClosedSCIDs: mockClosedSCIDs, + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( + &mockTrafficShaper{}, + ), }} // Expect an error to be returned. @@ -2347,8 +2351,10 @@ func TestSendToRouteSkipTempErrTempFailure(t *testing.T) { NextPaymentID: func() (uint64, error) { return 0, nil }, - ClosedSCIDs: mockClosedSCIDs, - TrafficShaper: fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), + ClosedSCIDs: mockClosedSCIDs, + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( + &mockTrafficShaper{}, + ), }} // Create the error to be returned. @@ -2431,8 +2437,10 @@ func TestSendToRouteSkipTempErrPermanentFailure(t *testing.T) { NextPaymentID: func() (uint64, error) { return 0, nil }, - ClosedSCIDs: mockClosedSCIDs, - TrafficShaper: fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), + ClosedSCIDs: mockClosedSCIDs, + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( + &mockTrafficShaper{}, + ), }} // Create the error to be returned. @@ -2519,8 +2527,10 @@ func TestSendToRouteTempFailure(t *testing.T) { NextPaymentID: func() (uint64, error) { return 0, nil }, - ClosedSCIDs: mockClosedSCIDs, - TrafficShaper: fn.Some[TlvTrafficShaper](&mockTrafficShaper{}), + ClosedSCIDs: mockClosedSCIDs, + TrafficShaper: fn.Some[htlcswitch.AuxTrafficShaper]( + &mockTrafficShaper{}, + ), }} // Create the error to be returned. diff --git a/server.go b/server.go index f8f8239ed6..7f725553a3 100644 --- a/server.go +++ b/server.go @@ -4222,6 +4222,7 @@ func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq, MsgRouter: s.implCfg.MsgRouter, AuxChanCloser: s.implCfg.AuxChanCloser, AuxResolver: s.implCfg.AuxContractResolver, + AuxTrafficShaper: s.implCfg.TrafficShaper, ShouldFwdExpEndorsement: func() bool { if s.cfg.ProtocolOptions.NoExperimentalEndorsement() { return false