Fix channel arbitrator lingering goroutine #9264

Draft · wants to merge 2 commits into base: master
201 changes: 148 additions & 53 deletions contractcourt/chain_arbitrator.go
@@ -1,6 +1,7 @@
package contractcourt

import (
"context"
"errors"
"fmt"
"sync"
@@ -27,6 +28,12 @@ import (
// ErrChainArbExiting signals that the chain arbitrator is shutting down.
var ErrChainArbExiting = errors.New("ChainArbitrator exiting")

const (
// chainArbTimeout is the timeout within which the chain arbitrator
// expects all channel arbitrators to start.
chainArbTimeout = 5 * time.Minute
)

// ResolutionMsg is a message sent by resolvers to outside sub-systems once an
// outgoing contract has been fully resolved. For multi-hop contracts, if we
// resolve the outgoing contract, we'll also need to ensure that the incoming
@@ -244,7 +251,7 @@ type ChainArbitrator struct {
started int32 // To be used atomically.
stopped int32 // To be used atomically.

- sync.Mutex
+ sync.RWMutex

// activeChannels is a map of all the active contracts that are still
// open, and not fully resolved.
@@ -258,6 +265,10 @@ type ChainArbitrator struct {
// methods and interface it needs to operate.
cfg ChainArbitratorConfig

// resolveChanArb is a channel used to signal that the resources of a
// channel arbitrator should be cleaned up.
resolveChanArb chan wire.OutPoint

// chanSource will be used by the ChainArbitrator to fetch all the
// active channels that it must still watch over.
chanSource *channeldb.DB
@@ -276,6 +287,7 @@ func NewChainArbitrator(cfg ChainArbitratorConfig,
cfg: cfg,
activeChannels: make(map[wire.OutPoint]*ChannelArbitrator),
activeWatchers: make(map[wire.OutPoint]*chainWatcher),
resolveChanArb: make(chan wire.OutPoint),
chanSource: db,
quit: make(chan struct{}),
}
@@ -497,6 +509,9 @@ func (c *ChainArbitrator) getArbChannel(
// ResolveContract marks a contract as fully resolved within the database.
// This is only to be done once all contracts which were live on the channel
// before hitting the chain have been resolved.
//
// NOTE: This function must be called without the chain arbitrator lock because
// it acquires the lock itself.
func (c *ChainArbitrator) ResolveContract(chanPoint wire.OutPoint) error {
log.Infof("Marking ChannelPoint(%v) fully resolved", chanPoint)

@@ -509,44 +524,24 @@ func (c *ChainArbitrator) ResolveContract(chanPoint wire.OutPoint) error {
return err
}

- // Now that the channel has been marked as fully closed, we'll stop
- // both the channel arbitrator and chain watcher for this channel if
- // they're still active.
- var arbLog ArbitratorLog
- c.Lock()
- chainArb := c.activeChannels[chanPoint]
- delete(c.activeChannels, chanPoint)
-
- chainWatcher := c.activeWatchers[chanPoint]
- delete(c.activeWatchers, chanPoint)
- c.Unlock()
-
- if chainArb != nil {
- arbLog = chainArb.log
-
- if err := chainArb.Stop(); err != nil {
- log.Warnf("unable to stop ChannelArbitrator(%v): %v",
- chanPoint, err)
- }
- }
- if chainWatcher != nil {
- if err := chainWatcher.Stop(); err != nil {
- log.Warnf("unable to stop ChainWatcher(%v): %v",
- chanPoint, err)
- }
- }

// Once this has been marked as resolved, we'll wipe the log that the
// channel arbitrator was using to store its persistent state. We do
// this after marking the channel resolved, as otherwise, the
// arbitrator would be re-created, and think it was starting from the
// default state.
- if arbLog != nil {
+ c.RLock()
+ chainArb, ok := c.activeChannels[chanPoint]
+ c.RUnlock()
+ if ok && chainArb.log != nil {
+ arbLog := chainArb.log
if err := arbLog.WipeHistory(); err != nil {
return err
}
}

+ // Make sure all the resources of the channel arbitrator are cleaned up.
+ fn.SendOrQuit(c.resolveChanArb, chanPoint, c.quit)

return nil
}
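The fn.SendOrQuit call above is what keeps ResolveContract from blocking forever when the arbitrator is already shutting down: the send races against the quit channel. A minimal sketch of that idiom, using only the standard library (the helper name and signature here are illustrative, not lnd's actual fn API):

```go
package sketch

// sendOrQuit blocks until msg is accepted on c or until quit closes,
// and reports whether the send went through. Generic over the payload
// type, so it works for wire.OutPoint just as well as for string.
func sendOrQuit[T any](c chan<- T, msg T, quit <-chan struct{}) bool {
	select {
	case c <- msg:
		return true
	case <-quit:
		return false
	}
}
```

Because the receiving collector also exits on c.quit, both sides of the hand-off agree on shutdown and neither goroutine is left stranded.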

@@ -599,6 +594,8 @@ func (c *ChainArbitrator) Start() error {
return err
}

// We don't need to lock here because this is the only goroutine
// that will be accessing this map at this point in time.
c.activeWatchers[chanPoint] = chainWatcher
channelArb, err := newActiveChannelArbitrator(
channel, c, chainWatcher.SubscribeChannelEvents(),
@@ -765,22 +762,76 @@ func (c *ChainArbitrator) Start() error {
return err
}

// Launch the cleanup collector to clean up the channel arbitrator
// resources as soon as a channel is fully resolved onchain.
c.wg.Add(1)
go func() {
defer c.wg.Done()
c.cleanupCollector()
}()

// Launch all the goroutines for each arbitrator so they can carry out
// their duties.
// Set a timeout for the group of goroutines.
ctx, cancel := context.WithTimeout(
context.Background(), chainArbTimeout,
)

channelArbErrs := make(chan error, len(c.activeChannels))
var wgChannelArb sync.WaitGroup

for _, arbitrator := range c.activeChannels {
startState, ok := startStates[arbitrator.cfg.ChanPoint]
if !ok {
stopAndLog()

// In case we encounter an error, we need to cancel the context to
// ensure all goroutines are cleaned up.
cancel()
return fmt.Errorf("arbitrator: %v has no start state",
arbitrator.cfg.ChanPoint)
}

- if err := arbitrator.Start(startState); err != nil {
- stopAndLog()
- return err
- }
+ wgChannelArb.Add(1)
+ go func(arb *ChannelArbitrator) {
+ defer wgChannelArb.Done()
+
+ select {
+ case channelArbErrs <- arb.Start(startState):
+
+ case <-ctx.Done():
+ channelArbErrs <- ctx.Err()
+
+ case <-c.quit:
+ channelArbErrs <- ErrChainArbExiting
+ }
+ }(arbitrator)
}

// Wait for all arbitrators to start in a separate goroutine. We don't
// block the chain arbitrator's startup on this, because other
// subsystems may stall the startup while fetching resolution
// information (e.g. custom channels).
//
// NOTE: We do not add this collector goroutine to the waitgroup
// because we want the chain arbitrator to be stopped if any channel
// arbitrator fails to start.
go func() {
defer cancel()

wgChannelArb.Wait()
close(channelArbErrs)

for err := range channelArbErrs {
if err != nil {
log.Criticalf("ChainArbitrator failed to "+
"all channel arbitrators with: %v", err)

// We initiated a shutdown so we exit early.
return
}
}
}()

// Subscribe to a single stream of block epoch notifications that we
// will dispatch to all active arbitrators.
blockEpoch, err := c.cfg.Notifier.RegisterBlockEpochNtfn(nil)
@@ -800,6 +851,50 @@ func (c *ChainArbitrator) Start() error {
return nil
}
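Taken together, the loop above is a fan-out/fan-in startup: one goroutine per arbitrator, an error channel buffered to the number of senders so none of them can leak, and chainArbTimeout bounding how long the collector waits. A condensed, self-contained sketch of the same shape (startAll and its names are illustrative, not lnd APIs; unlike the PR, it waits for the group synchronously):

```go
package sketch

import (
	"fmt"
	"sync"
	"time"
)

// startAll launches every task concurrently, then waits for the group
// with a deadline. The error channel is buffered to len(tasks) so no
// task goroutine can ever block, and therefore leak, on its send.
func startAll(tasks []func() error, timeout time.Duration) error {
	errs := make(chan error, len(tasks))
	var wg sync.WaitGroup
	for _, task := range tasks {
		wg.Add(1)
		go func(start func() error) {
			defer wg.Done()
			errs <- start()
		}(task)
	}

	// Close errs once every task has reported so the drain loop
	// below can terminate.
	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(errs)
		close(done)
	}()

	select {
	case <-done:
	case <-time.After(timeout):
		return fmt.Errorf("startup timed out after %v", timeout)
	}

	for err := range errs {
		if err != nil {
			return err
		}
	}
	return nil
}
```

On timeout the tasks keep running, but because the channel is buffered they can still report and exit instead of leaking, which mirrors the role of the buffered channelArbErrs above.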

// cleanupCollector cleans up the channel arbitrator resources as soon as a
// channel is fully resolved onchain.
//
// NOTE: This function must be run as a goroutine.
func (c *ChainArbitrator) cleanupCollector() {
for {
select {
case chanPoint := <-c.resolveChanArb:
log.Debugf("ChannelArbitrator(%v) fully resolved, "+
"removing from active sets", chanPoint)

// Now that the channel has been marked as fully closed,
// we'll stop both the channel arbitrator and chain
// watcher for this channel if they're still active.
c.Lock()
channelArb := c.activeChannels[chanPoint]
delete(c.activeChannels, chanPoint)

chainWatcher := c.activeWatchers[chanPoint]
delete(c.activeWatchers, chanPoint)
c.Unlock()

if channelArb != nil {
if err := channelArb.Stop(); err != nil {
log.Warnf("unable to stop "+
"ChannelArbitrator(%v): %v",
chanPoint, err)
}
}
if chainWatcher != nil {
if err := chainWatcher.Stop(); err != nil {
log.Warnf("unable to stop "+
"ChainWatcher(%v): %v",
chanPoint, err)
}
}

// Exit if the chain arbitrator is shutting down.
case <-c.quit:
return
}
}
}
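Routing every teardown through this single collector sidesteps a classic self-stop hazard, which is one plausible reading of the "lingering goroutine" in the PR title: if Stop is ever reached from a component's own goroutine (ResolveContract can be reached from an arbitrator's own call path), Stop waits on a goroutine that is itself blocked inside Stop. A minimal illustration of the hazard the hand-off avoids (hypothetical worker type, not lnd code):

```go
package sketch

import "sync"

// worker mimics the ChannelArbitrator lifecycle: stop closes quit and
// then waits for the worker's main goroutine to wind down.
type worker struct {
	wg   sync.WaitGroup
	quit chan struct{}
}

func (w *worker) stop() {
	close(w.quit)
	// If the worker's own goroutine ends up here (e.g. while handling
	// its final piece of work), this Wait can never return: the
	// goroutine it waits for is the one blocked on it. Handing the
	// request to a separate collector goroutine breaks that cycle.
	w.wg.Wait()
}
```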

// blockRecipient contains the information we need to dispatch a block to a
// channel arbitrator.
type blockRecipient struct {
@@ -824,7 +919,7 @@ func (c *ChainArbitrator) dispatchBlocks(
// lock and returns a set of block recipients which can be used to
// dispatch blocks.
getRecipients := func() []blockRecipient {
- c.Lock()
+ c.RLock()
blocks := make([]blockRecipient, 0, len(c.activeChannels))
for _, channel := range c.activeChannels {
blocks = append(blocks, blockRecipient{
Expand All @@ -833,7 +928,7 @@ func (c *ChainArbitrator) dispatchBlocks(
quit: channel.quit,
})
}
- c.Unlock()
+ c.RUnlock()

return blocks
}
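getRecipients shows the read-path payoff of the Mutex→RWMutex switch: take the read lock, copy out the hot data, release, and only then do any blocking work (here, channel sends). A generic sketch of that snapshot-under-read-lock discipline, with illustrative types:

```go
package sketch

import "sync"

type registry struct {
	sync.RWMutex
	active map[string]chan struct{}
}

// snapshot copies the values out under the read lock, so concurrent
// readers never serialize behind each other and no channel send ever
// happens while the lock is held.
func (r *registry) snapshot() []chan struct{} {
	r.RLock()
	defer r.RUnlock()

	out := make([]chan struct{}, 0, len(r.active))
	for _, ch := range r.active {
		out = append(out, ch)
	}
	return out
}
```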
@@ -1066,9 +1161,9 @@ func (c *ChainArbitrator) UpdateContractSignals(chanPoint wire.OutPoint,
log.Infof("Attempting to update ContractSignals for ChannelPoint(%v)",
chanPoint)

- c.Lock()
+ c.RLock()
arbitrator, ok := c.activeChannels[chanPoint]
- c.Unlock()
+ c.RUnlock()
if !ok {
return fmt.Errorf("unable to find arbitrator")
}
@@ -1087,9 +1182,9 @@ func (c *ChainArbitrator) NotifyContractUpdate(chanPoint wire.OutPoint,
func (c *ChainArbitrator) NotifyContractUpdate(chanPoint wire.OutPoint,
update *ContractUpdate) error {

- c.Lock()
+ c.RLock()
arbitrator, ok := c.activeChannels[chanPoint]
- c.Unlock()
+ c.RUnlock()
if !ok {
return fmt.Errorf("can't find arbitrator for %v", chanPoint)
}
@@ -1103,9 +1198,9 @@ func (c *ChainArbitrator) GetChannelArbitrator(chanPoint wire.OutPoint) (
func (c *ChainArbitrator) GetChannelArbitrator(chanPoint wire.OutPoint) (
*ChannelArbitrator, error) {

- c.Lock()
+ c.RLock()
arbitrator, ok := c.activeChannels[chanPoint]
- c.Unlock()
+ c.RUnlock()
if !ok {
return nil, fmt.Errorf("unable to find arbitrator")
}
@@ -1135,9 +1230,9 @@ type forceCloseReq struct {
//
// TODO(roasbeef): just return the summary itself?
func (c *ChainArbitrator) ForceCloseContract(chanPoint wire.OutPoint) (*wire.MsgTx, error) {
- c.Lock()
+ c.RLock()
arbitrator, ok := c.activeChannels[chanPoint]
- c.Unlock()
+ c.RUnlock()
if !ok {
return nil, fmt.Errorf("unable to find arbitrator")
}
@@ -1192,17 +1287,17 @@ func (c *ChainArbitrator) ForceCloseContract(chanPoint wire.OutPoint) (*wire.Msg
// channel has finished its final funding flow, it should be registered with
// the ChainArbitrator so we can properly react to any on-chain events.
func (c *ChainArbitrator) WatchNewChannel(newChan *channeldb.OpenChannel) error {
- c.Lock()
- defer c.Unlock()
-
chanPoint := newChan.FundingOutpoint

log.Infof("Creating new ChannelArbitrator for ChannelPoint(%v)",
chanPoint)

// If we're already watching this channel, then we'll ignore this
// request.
- if _, ok := c.activeChannels[chanPoint]; ok {
+ c.RLock()
+ _, ok := c.activeChannels[chanPoint]
+ c.RUnlock()
+ if ok {
return nil
}

@@ -1230,8 +1325,6 @@ func (c *ChainArbitrator) WatchNewChannel(newChan *channeldb.OpenChannel) error
return err
}

- c.activeWatchers[chanPoint] = chainWatcher
-
// We'll also create a new channel arbitrator instance using this new
// channel, and our internal state.
channelArb, err := newActiveChannelArbitrator(
@@ -1241,9 +1334,11 @@ func (c *ChainArbitrator) WatchNewChannel(newChan *channeldb.OpenChannel) error
return err
}

- // With the arbitrator created, we'll add it to our set of active
- // arbitrators, then launch it.
+ // Make sure we hold the lock for the shortest period of time.
+ c.Lock()
+ c.activeWatchers[chanPoint] = chainWatcher
c.activeChannels[chanPoint] = channelArb
+ c.Unlock()

if err := channelArb.Start(nil); err != nil {
return err
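WatchNewChannel now checks for an existing entry under the read lock and takes the write lock only around the two map inserts; the chain-watcher construction in between runs with no lock held. The trade-off is a small check-then-act window between the RLock and the Lock, which is fine so long as callers don't race on the same chanPoint. A compact sketch of the lock-narrowing idiom, with illustrative types:

```go
package sketch

import "sync"

type table struct {
	sync.RWMutex
	entries map[string]int
}

// insert narrows the write lock to the map mutation itself; the
// expensive setup runs outside any lock. Callers must tolerate the
// small window between the read-locked check and the write.
func (t *table) insert(key string, build func() int) {
	t.RLock()
	_, exists := t.entries[key]
	t.RUnlock()
	if exists {
		return
	}

	value := build() // slow work, no lock held

	t.Lock()
	t.entries[key] = value
	t.Unlock()
}
```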
@@ -1261,9 +1356,9 @@ func (c *ChainArbitrator) SubscribeChannelEvents(

// First, we'll attempt to look up the active watcher for this channel.
// If we can't find it, then we'll return an error back to the caller.
- c.Lock()
+ c.RLock()
watcher, ok := c.activeWatchers[chanPoint]
- c.Unlock()
+ c.RUnlock()

if !ok {
return nil, fmt.Errorf("unable to find watcher for: %v",