diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go
index c29178b438..2b313f4cfc 100644
--- a/contractcourt/chain_arbitrator.go
+++ b/contractcourt/chain_arbitrator.go
@@ -1,6 +1,7 @@
 package contractcourt
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"sync"
@@ -27,6 +28,12 @@ import (
 // ErrChainArbExiting signals that the chain arbitrator is shutting down.
 var ErrChainArbExiting = errors.New("ChainArbitrator exiting")
 
+const (
+	// chainArbTimeout is the timeout for the chain arbitrator to start
+	// the channel arbitrators for each channel.
+	chainArbTimeout = 5 * time.Minute
+)
+
 // ResolutionMsg is a message sent by resolvers to outside sub-systems once an
 // outgoing contract has been fully resolved. For multi-hop contracts, if we
 // resolve the outgoing contract, we'll also need to ensure that the incoming
@@ -244,7 +251,7 @@ type ChainArbitrator struct {
 	started int32 // To be used atomically.
 	stopped int32 // To be used atomically.
 
-	sync.Mutex
+	sync.RWMutex
 
 	// activeChannels is a map of all the active contracts that are still
 	// open, and not fully resolved.
@@ -258,6 +265,10 @@ type ChainArbitrator struct {
 	// methods and interface it needs to operate.
 	cfg ChainArbitratorConfig
 
+	// resolveChanArb is a channel which is used to signal the cleanup of
+	// the channel arbitrator resources.
+	resolveChanArb chan wire.OutPoint
+
 	// chanSource will be used by the ChainArbitrator to fetch all the
 	// active channels that it must still watch over.
 	chanSource *channeldb.DB
@@ -276,6 +287,7 @@ func NewChainArbitrator(cfg ChainArbitratorConfig,
 		cfg:            cfg,
 		activeChannels: make(map[wire.OutPoint]*ChannelArbitrator),
 		activeWatchers: make(map[wire.OutPoint]*chainWatcher),
+		resolveChanArb: make(chan wire.OutPoint),
 		chanSource:     db,
 		quit:           make(chan struct{}),
 	}
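The new resolveChanArb field turns contract resolution into a signal: instead of
tearing down state inline, ResolveContract hands the channel point to a collector
goroutine via fn.SendOrQuit. For readers unfamiliar with that helper, the sketch
below is an illustrative, generic re-implementation of the send-or-quit idiom;
the helper name and signature here are assumptions, not lnd's actual fn API.

package main

import "fmt"

// sendOrQuit tries to deliver msg on c, but gives up as soon as quit is
// closed. It reports whether the message was delivered. This mirrors the
// shape of fn.SendOrQuit; the real signature may differ.
func sendOrQuit[T any](c chan<- T, msg T, quit <-chan struct{}) bool {
	select {
	case c <- msg:
		return true
	case <-quit:
		return false
	}
}

func main() {
	events := make(chan string)
	quit := make(chan struct{})
	done := make(chan struct{})

	// A collector drains one event, much like cleanupCollector does.
	go func() {
		defer close(done)
		fmt.Println("got:", <-events)
	}()

	// Blocks only until the collector receives or quit closes.
	fmt.Println("delivered:", sendOrQuit(events, "chan-1 resolved", quit))
	<-done

	// After shutdown, senders fall through instead of blocking forever.
	close(quit)
	fmt.Println("delivered:", sendOrQuit(events, "late event", quit))
}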
@@ -497,6 +509,9 @@ func (c *ChainArbitrator) getArbChannel(
 // ResolveContract marks a contract as fully resolved within the database.
 // This is only to be done once all contracts which were live on the channel
 // before hitting the chain have been resolved.
+//
+// NOTE: This function must be called without the chain arbitrator lock because
+// it acquires the lock itself.
 func (c *ChainArbitrator) ResolveContract(chanPoint wire.OutPoint) error {
 
 	log.Infof("Marking ChannelPoint(%v) fully resolved", chanPoint)
@@ -509,44 +524,24 @@
 		return err
 	}
 
-	// Now that the channel has been marked as fully closed, we'll stop
-	// both the channel arbitrator and chain watcher for this channel if
-	// they're still active.
-	var arbLog ArbitratorLog
-	c.Lock()
-	chainArb := c.activeChannels[chanPoint]
-	delete(c.activeChannels, chanPoint)
-
-	chainWatcher := c.activeWatchers[chanPoint]
-	delete(c.activeWatchers, chanPoint)
-	c.Unlock()
-
-	if chainArb != nil {
-		arbLog = chainArb.log
-
-		if err := chainArb.Stop(); err != nil {
-			log.Warnf("unable to stop ChannelArbitrator(%v): %v",
-				chanPoint, err)
-		}
-	}
-	if chainWatcher != nil {
-		if err := chainWatcher.Stop(); err != nil {
-			log.Warnf("unable to stop ChainWatcher(%v): %v",
-				chanPoint, err)
-		}
-	}
-
 	// Once this has been marked as resolved, we'll wipe the log that the
 	// channel arbitrator was using to store its persistent state. We do
 	// this after marking the channel resolved, as otherwise, the
 	// arbitrator would be re-created, and think it was starting from the
 	// default state.
-	if arbLog != nil {
+	c.RLock()
+	chainArb, ok := c.activeChannels[chanPoint]
+	c.RUnlock()
+	if ok && chainArb.log != nil {
+		arbLog := chainArb.log
 		if err := arbLog.WipeHistory(); err != nil {
 			return err
 		}
 	}
 
+	// Make sure all the resources of the channel arbitrator are cleaned up.
+	fn.SendOrQuit(c.resolveChanArb, chanPoint, c.quit)
+
 	return nil
 }
 
@@ -599,6 +594,8 @@ func (c *ChainArbitrator) Start() error {
 			return err
 		}
 
+		// We don't need to lock here because this is the only goroutine
+		// that will be accessing this map at this point in time.
 		c.activeWatchers[chanPoint] = chainWatcher
 		channelArb, err := newActiveChannelArbitrator(
 			channel, c, chainWatcher.SubscribeChannelEvents(),
@@ -765,22 +762,76 @@ func (c *ChainArbitrator) Start() error {
 		return err
 	}
 
+	// Launch the cleanup collector to clean up the channel arbitrator
+	// resources as soon as a channel is fully resolved on-chain.
+	c.wg.Add(1)
+	go func() {
+		defer c.wg.Done()
+		c.cleanupCollector()
+	}()
+
 	// Launch all the goroutines for each arbitrator so they can carry out
 	// their duties.
+	// Set a timeout for the group of goroutines.
+	ctx, cancel := context.WithTimeout(
		context.Background(), chainArbTimeout,
+	)
+
+	channelArbErrs := make(chan error, len(c.activeChannels))
+	var wgChannelArb sync.WaitGroup
+
 	for _, arbitrator := range c.activeChannels {
 		startState, ok := startStates[arbitrator.cfg.ChanPoint]
 		if !ok {
 			stopAndLog()
+
+			// In case we encounter an error we need to cancel the
+			// context to ensure all goroutines are cleaned up.
+			cancel()
 			return fmt.Errorf("arbitrator: %v has no start state",
 				arbitrator.cfg.ChanPoint)
 		}
 
-		if err := arbitrator.Start(startState); err != nil {
-			stopAndLog()
-			return err
-		}
+		wgChannelArb.Add(1)
+		go func(arb *ChannelArbitrator) {
+			defer wgChannelArb.Done()
+
+			select {
+			case channelArbErrs <- arb.Start(startState):
+
+			case <-ctx.Done():
+				channelArbErrs <- ctx.Err()
+
+			case <-c.quit:
+				channelArbErrs <- ErrChainArbExiting
+			}
+		}(arbitrator)
 	}
 
+	// Wait for all arbitrators to start in a separate goroutine. We don't
+	// block the chain arbitrator's start-up on this, because other
+	// subsystems may stall while fetching resolve information (e.g.
+	// custom channels).
+	//
+	// NOTE: We do not add this collector to the waitGroup because we want
+	// to stop the chain arbitrator if an error occurs.
+	go func() {
+		defer cancel()
+
+		wgChannelArb.Wait()
+		close(channelArbErrs)
+
+		for err := range channelArbErrs {
+			if err != nil {
+				log.Criticalf("ChainArbitrator failed to "+
+					"start all channel arbitrators: %v",
+					err)
+
+				// We initiated a shutdown so we exit early.
+				return
+			}
+		}
+	}()
+
 	// Subscribe to a single stream of block epoch notifications that we
 	// will dispatch to all active arbitrators.
 	blockEpoch, err := c.cfg.Notifier.RegisterBlockEpochNtfn(nil)
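The start-up rewrite is a standard fan-out/collect pattern: one goroutine per
arbitrator, an error channel buffered to the worker count so no sender blocks,
a shared timeout context, and a detached collector that only reacts on failure.
A minimal, self-contained sketch of the same shape, with a hypothetical
startWorker standing in for arbitrator.Start:

package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// startWorker is a stand-in for arbitrator.Start: it either finishes or
// reports why it could not.
func startWorker(ctx context.Context, id int) error {
	select {
	case <-time.After(time.Duration(10*id) * time.Millisecond):
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

func main() {
	const numWorkers = 5

	// Bound the whole start-up phase, mirroring chainArbTimeout.
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)

	// Buffered so no worker blocks on reporting its result.
	errs := make(chan error, numWorkers)
	var wg sync.WaitGroup

	for i := 0; i < numWorkers; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			errs <- startWorker(ctx, id)
		}(i)
	}

	// Collect asynchronously so the caller is not blocked on start-up.
	go func() {
		defer cancel()
		wg.Wait()
		close(errs)
		for err := range errs {
			if err != nil {
				fmt.Println("start-up failed:", err)
				return
			}
		}
		fmt.Println("all workers started")
	}()

	time.Sleep(2 * time.Second) // Give the collector time to finish.
}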
@@ -800,6 +851,50 @@
 	return nil
 }
 
+// cleanupCollector cleans up the channel arbitrator resources as soon as a
+// channel is fully resolved on-chain.
+//
+// NOTE: This function must be run as a goroutine.
+func (c *ChainArbitrator) cleanupCollector() {
+	for {
+		select {
+		case chanPoint := <-c.resolveChanArb:
+			log.Debugf("ChannelArbitrator(%v) fully resolved, "+
+				"removing from active sets", chanPoint)
+
+			// Now that the channel has been marked as fully closed,
+			// we'll stop both the channel arbitrator and chain
+			// watcher for this channel if they're still active.
+			c.Lock()
+			channelArb := c.activeChannels[chanPoint]
+			delete(c.activeChannels, chanPoint)
+
+			chainWatcher := c.activeWatchers[chanPoint]
+			delete(c.activeWatchers, chanPoint)
+			c.Unlock()
+
+			if channelArb != nil {
+				if err := channelArb.Stop(); err != nil {
+					log.Warnf("unable to stop "+
+						"ChannelArbitrator(%v): %v",
+						chanPoint, err)
+				}
+			}
+			if chainWatcher != nil {
+				if err := chainWatcher.Stop(); err != nil {
+					log.Warnf("unable to stop "+
+						"ChainWatcher(%v): %v",
+						chanPoint, err)
+				}
+			}
+
+		// Exit if the chain arbitrator is shutting down.
+		case <-c.quit:
+			return
+		}
+	}
+}
+
 // blockRecipient contains the information we need to dispatch a block to a
 // channel arbitrator.
 type blockRecipient struct {
@@ -824,7 +919,7 @@ func (c *ChainArbitrator) dispatchBlocks(
 	// lock and returns a set of block recipients which can be used to
 	// dispatch blocks.
 	getRecipients := func() []blockRecipient {
-		c.Lock()
+		c.RLock()
 		blocks := make([]blockRecipient, 0, len(c.activeChannels))
 		for _, channel := range c.activeChannels {
 			blocks = append(blocks, blockRecipient{
@@ -833,7 +928,7 @@
 				quit: channel.quit,
 			})
 		}
-		c.Unlock()
+		c.RUnlock()
 
 		return blocks
 	}
@@ -1066,9 +1161,9 @@ func (c *ChainArbitrator) UpdateContractSignals(chanPoint wire.OutPoint,
 	log.Infof("Attempting to update ContractSignals for ChannelPoint(%v)",
 		chanPoint)
 
-	c.Lock()
+	c.RLock()
 	arbitrator, ok := c.activeChannels[chanPoint]
-	c.Unlock()
+	c.RUnlock()
 	if !ok {
 		return fmt.Errorf("unable to find arbitrator")
 	}
@@ -1087,9 +1182,9 @@ func (c *ChainArbitrator) NotifyContractUpdate(chanPoint wire.OutPoint,
 	update *ContractUpdate) error {
 
-	c.Lock()
+	c.RLock()
 	arbitrator, ok := c.activeChannels[chanPoint]
-	c.Unlock()
+	c.RUnlock()
 	if !ok {
 		return fmt.Errorf("can't find arbitrator for %v", chanPoint)
 	}
@@ -1103,9 +1198,9 @@ func (c *ChainArbitrator) GetChannelArbitrator(chanPoint wire.OutPoint) (
 	*ChannelArbitrator, error) {
 
-	c.Lock()
+	c.RLock()
 	arbitrator, ok := c.activeChannels[chanPoint]
-	c.Unlock()
+	c.RUnlock()
 	if !ok {
 		return nil, fmt.Errorf("unable to find arbitrator")
 	}
@@ -1135,9 +1230,9 @@ type forceCloseReq struct {
 //
 // TODO(roasbeef): just return the summary itself?
 func (c *ChainArbitrator) ForceCloseContract(chanPoint wire.OutPoint) (*wire.MsgTx, error) {
-	c.Lock()
+	c.RLock()
 	arbitrator, ok := c.activeChannels[chanPoint]
-	c.Unlock()
+	c.RUnlock()
 	if !ok {
 		return nil, fmt.Errorf("unable to find arbitrator")
 	}
@@ -1192,9 +1287,6 @@ func (c *ChainArbitrator) ForceCloseContract(chanPoint wire.OutPoint) (*wire.Msg
 // channel has finished its final funding flow, it should be registered with
 // the ChainArbitrator so we can properly react to any on-chain events.
 func (c *ChainArbitrator) WatchNewChannel(newChan *channeldb.OpenChannel) error {
-	c.Lock()
-	defer c.Unlock()
-
 	chanPoint := newChan.FundingOutpoint
 
 	log.Infof("Creating new ChannelArbitrator for ChannelPoint(%v)",
@@ -1202,7 +1294,10 @@ func (c *ChainArbitrator) WatchNewChannel(newChan *channeldb.OpenChannel) error
 
 	// If we're already watching this channel, then we'll ignore this
 	// request.
-	if _, ok := c.activeChannels[chanPoint]; ok {
+	c.RLock()
+	_, ok := c.activeChannels[chanPoint]
+	c.RUnlock()
+	if ok {
 		return nil
 	}
 
@@ -1230,8 +1325,6 @@ func (c *ChainArbitrator) WatchNewChannel(newChan *channeldb.OpenChannel) error
 		return err
 	}
 
-	c.activeWatchers[chanPoint] = chainWatcher
-
 	// We'll also create a new channel arbitrator instance using this new
 	// channel, and our internal state.
 	channelArb, err := newActiveChannelArbitrator(
@@ -1241,9 +1334,11 @@ func (c *ChainArbitrator) WatchNewChannel(newChan *channeldb.OpenChannel) error
 		return err
 	}
 
-	// With the arbitrator created, we'll add it to our set of active
-	// arbitrators, then launch it.
+	// Make sure we hold the lock for the shortest period of time.
+	c.Lock()
+	c.activeWatchers[chanPoint] = chainWatcher
 	c.activeChannels[chanPoint] = channelArb
+	c.Unlock()
 
 	if err := channelArb.Start(nil); err != nil {
 		return err
@@ -1261,9 +1356,9 @@ func (c *ChainArbitrator) SubscribeChannelEvents(
 
 	// First, we'll attempt to look up the active watcher for this channel.
 	// If we can't find it, then we'll return an error back to the caller.
-	c.Lock()
+	c.RLock()
 	watcher, ok := c.activeWatchers[chanPoint]
-	c.Unlock()
+	c.RUnlock()
 
 	if !ok {
 		return nil, fmt.Errorf("unable to find watcher for: %v",
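Taken together, the chain_arbitrator.go changes enforce one locking rule: every
lookup of activeChannels/activeWatchers takes the read lock and releases it
before doing any further work, while mutation is confined to cleanupCollector
and WatchNewChannel under the write lock. A minimal sketch of that RWMutex
discipline over a guarded map (the registry type and its fields are
illustrative, not lnd code):

package main

import (
	"fmt"
	"sync"
)

// registry guards a map with an RWMutex: many concurrent readers, one
// writer. Lookups copy the value out under RLock and release the lock
// before doing anything else, so a reader never holds the lock while
// calling back into code that needs it.
type registry struct {
	sync.RWMutex
	items map[string]int
}

func (r *registry) lookup(key string) (int, bool) {
	r.RLock()
	defer r.RUnlock()
	v, ok := r.items[key]
	return v, ok
}

func (r *registry) remove(key string) {
	r.Lock()
	defer r.Unlock()
	delete(r.items, key)
}

func main() {
	r := &registry{items: map[string]int{"chan-1": 42}}

	if v, ok := r.lookup("chan-1"); ok {
		fmt.Println("found:", v)
	}

	// Only the writer path takes the exclusive lock.
	r.remove("chan-1")
	_, ok := r.lookup("chan-1")
	fmt.Println("still present:", ok)
}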
diff --git a/contractcourt/chain_arbitrator_test.go b/contractcourt/chain_arbitrator_test.go
index abaca5c2ba..db79f8baa3 100644
--- a/contractcourt/chain_arbitrator_test.go
+++ b/contractcourt/chain_arbitrator_test.go
@@ -1,6 +1,7 @@
 package contractcourt
 
 import (
+	"fmt"
 	"net"
 	"testing"
 
@@ -11,6 +12,7 @@ import (
 	"github.com/lightningnetwork/lnd/channeldb/models"
 	"github.com/lightningnetwork/lnd/clock"
 	"github.com/lightningnetwork/lnd/lntest/mock"
+	"github.com/lightningnetwork/lnd/lntest/wait"
 	"github.com/lightningnetwork/lnd/lntypes"
 	"github.com/lightningnetwork/lnd/lnwallet"
 	"github.com/stretchr/testify/require"
@@ -192,7 +194,9 @@
 		require.NoError(t, chainArb.Stop())
 	})
 
+	chainArb.RLock()
 	channelArb := chainArb.activeChannels[channel.FundingOutpoint]
+	chainArb.RUnlock()
 
 	// While the resolver are active, we'll now remove the channel from the
 	// database (mark is as closed).
@@ -207,14 +211,22 @@
 
 	// The shouldn't be an active chain watcher or channel arb for this
 	// channel.
-	if len(chainArb.activeChannels) != 0 {
-		t.Fatalf("expected zero active channels, instead have %v",
-			len(chainArb.activeChannels))
-	}
-	if len(chainArb.activeWatchers) != 0 {
-		t.Fatalf("expected zero active watchers, instead have %v",
-			len(chainArb.activeWatchers))
-	}
+	waitErr := wait.NoError(func() error {
+		chainArb.RLock()
+		defer chainArb.RUnlock()
+
+		if len(chainArb.activeChannels) != 0 {
+			return fmt.Errorf("expected zero active channels, "+
+				"instead have %v", len(chainArb.activeChannels))
+		}
+		if len(chainArb.activeWatchers) != 0 {
+			return fmt.Errorf("expected zero active watchers, "+
+				"instead have %v", len(chainArb.activeWatchers))
+		}
+
+		return nil
+	}, defaultTimeout)
+	require.NoError(t, waitErr, "timeout waiting for result")
 
 	// At this point, the channel's arbitrator log should also be empty as
 	// well.
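Since cleanup is now asynchronous, the test polls the maps instead of asserting
immediately after ResolveContract returns. lnd's lntest/wait.NoError provides
this; the sketch below shows the underlying retry-until-deadline idea with a
hypothetical pollNoError helper (the 100ms interval is an assumption, not
necessarily lnd's actual tick):

package main

import (
	"errors"
	"fmt"
	"time"
)

// pollNoError retries f until it returns nil or the timeout elapses,
// returning the last error on timeout. This mirrors the shape of
// lntest/wait.NoError, but is not its implementation.
func pollNoError(f func() error, timeout time.Duration) error {
	deadline := time.Now().Add(timeout)
	var lastErr error
	for time.Now().Before(deadline) {
		if lastErr = f(); lastErr == nil {
			return nil
		}
		time.Sleep(100 * time.Millisecond)
	}
	return fmt.Errorf("timeout: %w", lastErr)
}

func main() {
	start := time.Now()
	err := pollNoError(func() error {
		// Simulate a condition that only becomes true after the
		// background collector has had time to run.
		if time.Since(start) < 300*time.Millisecond {
			return errors.New("not yet resolved")
		}
		return nil
	}, time.Second)
	fmt.Println("result:", err)
}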
diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go
index cc1ee69589..278693d59b 100644
--- a/contractcourt/channel_arbitrator.go
+++ b/contractcourt/channel_arbitrator.go
@@ -157,6 +157,9 @@ type ChannelArbitratorConfig struct {
 	// fully resolved once all active contracts have individually been
 	// fully resolved.
 	//
+	// NOTE: This function must be called without the chain arbitrator lock
+	// because it acquires the lock itself.
+	//
 	// TODO(roasbeef): need RPC's to combine for pendingchannels RPC
 	MarkChannelResolved func() error
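The NOTE added to MarkChannelResolved (and to ResolveContract above) guards
against re-entrancy: Go's sync.RWMutex is not reentrant, so a callback that
takes the chain arbitrator's lock must never be invoked while the caller
already holds that lock. A small sketch of the safe calling pattern, with
illustrative types that are not lnd code:

package main

import (
	"fmt"
	"sync"
)

type arbiter struct {
	mu       sync.RWMutex
	resolved map[string]bool
}

// markResolved takes the lock itself, so callers must NOT hold mu when
// invoking it: re-acquiring a sync.RWMutex from the same goroutine
// blocks forever.
func (a *arbiter) markResolved(id string) {
	a.mu.Lock()
	defer a.mu.Unlock()
	a.resolved[id] = true
}

func (a *arbiter) process(id string) {
	a.mu.RLock()
	_, seen := a.resolved[id]
	// Release before calling back into a lock-acquiring method.
	a.mu.RUnlock()

	if !seen {
		// Safe: we no longer hold the lock here.
		a.markResolved(id)
	}

	// Calling a.markResolved(id) while still holding a.mu.RLock()
	// would deadlock: Lock() waits for all readers to drain,
	// including this goroutine's own read lock.
}

func main() {
	a := &arbiter{resolved: make(map[string]bool)}
	a.process("chan-1")
	fmt.Println("resolved:", a.resolved["chan-1"])
}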