diff --git a/benchmarks_test.go b/benchmarks_test.go
index d3aaf04f..2ceb7dab 100644
--- a/benchmarks_test.go
+++ b/benchmarks_test.go
@@ -19,9 +19,9 @@ import (
 	bitswap "github.com/ipfs/go-bitswap"
 	bssession "github.com/ipfs/go-bitswap/internal/session"
+	bsnet "github.com/ipfs/go-bitswap/network"
 	testinstance "github.com/ipfs/go-bitswap/testinstance"
 	tn "github.com/ipfs/go-bitswap/testnet"
-	bsnet "github.com/ipfs/go-bitswap/network"
 	cid "github.com/ipfs/go-cid"
 	delay "github.com/ipfs/go-ipfs-delay"
 	mockrouting "github.com/ipfs/go-ipfs-routing/mock"
@@ -99,6 +99,8 @@ var benches = []bench{
 	bench{"10Nodes-OnePeerPerBlock-BigBatch", 10, 100, onePeerPerBlock, batchFetchAll},
 	// - request 1, then 10, then 89 blocks (similar to how IPFS would fetch a file)
 	bench{"10Nodes-OnePeerPerBlock-UnixfsFetch", 10, 100, onePeerPerBlock, unixfsFileFetch},
+	// - request 1, then 10, then 89 blocks using StreamBlocks
+	bench{"10Nodes-OnePeerPerBlock-UnixfsStream", 10, 100, onePeerPerBlock, unixfsStreamFetch},
 
 	// Fetch from 199 seed nodes, all nodes have all blocks, fetch all 20 blocks with a single GetBlocks() call
 	bench{"200Nodes-AllToAll-BigBatch", 200, 20, allToAll, batchFetchAll},
@@ -572,6 +574,28 @@ func unixfsFileFetch(b *testing.B, bs *bitswap.Bitswap, ks []cid.Cid) {
 	}
 }
 
+// simulates the fetch pattern of trying to sync a unixfs file graph as fast as possible
+// using StreamBlocks()
+func unixfsStreamFetch(b *testing.B, bs *bitswap.Bitswap, ks []cid.Cid) {
+	ses := bs.NewSession(context.Background())
+	ksch := make(chan []cid.Cid)
+	out, err := ses.StreamBlocks(context.Background(), ksch)
+	if err != nil {
+		b.Fatal(err)
+	}
+
+	ksch <- ks[:1]
+	<-out
+	ksch <- ks[1:11]
+	for i := 0; i < 10; i++ {
+		<-out
+	}
+	ksch <- ks[11:]
+	close(ksch)
+	for i := 0; i < 89; i++ {
+		<-out
+	}
+}
+
 func unixfsFileFetchLarge(b *testing.B, bs *bitswap.Bitswap, ks []cid.Cid) {
 	ses := bs.NewSession(context.Background())
 	_, err := ses.GetBlock(context.Background(), ks[0])
diff --git a/bitswap.go b/bitswap.go
index f2217b85..b9c0aa17 100644
--- a/bitswap.go
+++ b/bitswap.go
@@ -38,7 +38,7 @@ import (
 
 var log = logging.Logger("bitswap")
 
-var _ exchange.SessionExchange = (*Bitswap)(nil)
+// var _ exchange.SessionExchange = (*Bitswap)(nil)
 
 const (
 	// these requests take at _least_ two minutes at the moment.
@@ -144,7 +144,7 @@ func New(parent context.Context, network bsnet.BitSwapNetwork,
 		sim *bssim.SessionInterestManager,
 		pm bssession.PeerManager,
 		bpm *bsbpm.BlockPresenceManager,
-		notif notifications.PubSub,
+		notif *notifications.PubSub,
 		provSearchDelay time.Duration,
 		rebroadcastDelay delay.D,
 		self peer.ID) bssm.Session {
@@ -197,7 +197,6 @@ func New(parent context.Context, network bsnet.BitSwapNetwork,
 	go func() {
 		<-px.Closing() // process closes first
 		cancelFunc()
-		notif.Shutdown()
 	}()
 	procctx.CloseAfterContext(px, ctx) // parent cancelled first
 
@@ -225,7 +224,7 @@ type Bitswap struct {
 	blockstore blockstore.Blockstore
 
 	// manages channels of outgoing blocks for sessions
-	notif notifications.PubSub
+	notif *notifications.PubSub
 
 	// newBlocks is a channel for newly added blocks to be provided to the
 	// network. blocks pushed down this channel get buffered and fed to the
@@ -294,9 +293,13 @@ func (bs *Bitswap) LedgerForPeer(p peer.ID) *decision.Receipt {
 	return bs.engine.LedgerForPeer(p)
 }
 
-// GetBlocks returns a channel where the caller may receive blocks that
-// correspond to the provided |keys|. Returns an error if BitSwap is unable to
-// begin this request within the deadline enforced by the context.
+// GetBlocks returns a stream of blocks, given a list of CIDs. It will
+// return blocks in any order.
+//
+// The returned channel is closed once the fetcher is done retrieving
+// blocks. A closed channel does not necessarily mean that _all_ of the
+// requested blocks were retrieved; it only means that the fetcher has
+// stopped.
 //
 // NB: Your request remains open until the context expires. To conserve
 // resources, provide a context with a reasonably short deadline (ie. not one
 // that lasts throughout the lifetime of the server)
@@ -306,6 +309,22 @@ func (bs *Bitswap) GetBlocks(ctx context.Context, keys []cid.Cid) (<-chan blocks
 	return session.GetBlocks(ctx, keys)
 }
 
+// StreamBlocks returns a stream of blocks, given a stream of CIDs. It will
+// return blocks in any order.
+//
+// To wait for all remaining blocks, close the CID channel and wait for
+// the blocks channel to be closed. A closed channel does not mean that
+// _all_ blocks were retrieved, it just means that the fetcher is done
+// retrieving blocks.
+//
+// NB: Your request remains open until the context expires. To conserve
+// resources, provide a context with a reasonably short deadline (ie. not one
+// that lasts throughout the lifetime of the server)
+func (bs *Bitswap) StreamBlocks(ctx context.Context, keys <-chan []cid.Cid) (<-chan blocks.Block, error) {
+	session := bs.sm.NewSession(ctx, bs.provSearchDelay, bs.rebroadcastDelay)
+	return session.StreamBlocks(ctx, keys)
+}
+
 // HasBlock announces the existence of a block to this bitswap service. The
 // service will potentially notify its peers.
 func (bs *Bitswap) HasBlock(blk blocks.Block) error {
@@ -530,6 +549,6 @@ func (bs *Bitswap) IsOnline() bool {
 // method, but the session will use the fact that the requests are related to
 // be more efficient in its requests to peers. If you are using a session
 // from go-blockservice, it will create a bitswap session automatically.
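+//
+// A rough usage sketch of a session driving the new StreamBlocks call
+// (caller-side names such as keyBatches, firstLayer, nextLayer and handle
+// are illustrative only, not part of this change):
+//
+//	ses := bs.NewSession(ctx)
+//	keyBatches := make(chan []cid.Cid)
+//	blks, err := ses.StreamBlocks(ctx, keyBatches)
+//	if err != nil {
+//		return err
+//	}
+//	keyBatches <- firstLayer // send wants as they are discovered
+//	keyBatches <- nextLayer
+//	close(keyBatches)        // signal that no more wants are coming
+//	for blk := range blks {  // channel closes once the fetcher is done
+//		handle(blk)
+//	}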
-func (bs *Bitswap) NewSession(ctx context.Context) exchange.Fetcher {
+func (bs *Bitswap) NewSession(ctx context.Context) bssm.Fetcher {
 	return bs.sm.NewSession(ctx, bs.provSearchDelay, bs.rebroadcastDelay)
 }
diff --git a/bitswap_with_sessions_test.go b/bitswap_with_sessions_test.go
index 9551938c..13418b12 100644
--- a/bitswap_with_sessions_test.go
+++ b/bitswap_with_sessions_test.go
@@ -123,6 +123,64 @@ func TestSessionBetweenPeers(t *testing.T) {
 	}
 }
 
+func TestSessionBetweenPeersStream(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	vnet := getVirtualNetwork()
+	ig := testinstance.NewTestInstanceGenerator(vnet, nil, nil)
+	defer ig.Close()
+	bgen := blocksutil.NewBlockGenerator()
+
+	inst := ig.Instances(10)
+
+	// Add 101 blocks to Peer A
+	blks := bgen.Blocks(101)
+	if err := inst[0].Blockstore().PutMany(blks); err != nil {
+		t.Fatal(err)
+	}
+
+	var cids []cid.Cid
+	for _, blk := range blks {
+		cids = append(cids, blk.Cid())
+	}
+
+	// Create a session on Peer B
+	ses := inst[1].Exchange.NewSession(ctx)
+	if _, err := ses.GetBlock(ctx, cids[0]); err != nil {
+		t.Fatal(err)
+	}
+	blks = blks[1:]
+	cids = cids[1:]
+
+	// Fetch blocks with the session, 10 at a time
+	ksch := make(chan []cid.Cid)
+	ch, err := ses.StreamBlocks(ctx, ksch)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	for i := 0; i < 10; i++ {
+		ksch <- cids[i*10 : (i+1)*10]
+
+		var got []blocks.Block
+		for j := 0; j < 10; j++ {
+			got = append(got, <-ch)
+		}
+		if err := assertBlockLists(got, blks[i*10:(i+1)*10]); err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	for _, is := range inst[2:] {
+		stat, err := is.Exchange.Stat()
+		if err != nil {
+			t.Fatal(err)
+		}
+		if stat.MessagesReceived > 2 {
+			t.Fatal("uninvolved nodes should only receive two messages", stat.MessagesReceived)
+		}
+	}
+}
+
 func TestSessionSplitFetch(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
diff --git a/internal/getter/getter.go b/internal/getter/getter.go
index 02e3b54b..3a3a0cb9 100644
--- a/internal/getter/getter.go
+++ b/internal/getter/getter.go
@@ -58,75 +58,94 @@ func SyncGetBlock(p context.Context, k cid.Cid, gb GetBlocksFunc) (blocks.Block,
 }
 
 // WantFunc is any function that can express a want for set of blocks.
-type WantFunc func(context.Context, []cid.Cid)
-
-// AsyncGetBlocks take a set of block cids, a pubsub channel for incoming
-// blocks, a want function, and a close function, and returns a channel of
-// incoming blocks.
-func AsyncGetBlocks(ctx context.Context, sessctx context.Context, keys []cid.Cid, notif notifications.PubSub,
-	want WantFunc, cwants func([]cid.Cid)) (<-chan blocks.Block, error) {
-
-	// If there are no keys supplied, just return a closed channel
-	if len(keys) == 0 {
-		out := make(chan blocks.Block)
-		close(out)
-		return out, nil
-	}
-
-	// Use a PubSub notifier to listen for incoming blocks for each key
-	remaining := cid.NewSet()
-	promise := notif.Subscribe(ctx, keys...)
-	for _, k := range keys {
-		log.Debugw("Bitswap.GetBlockRequest.Start", "cid", k)
-		remaining.Add(k)
-	}
-
-	// Send the want request for the keys to the network
-	want(ctx, keys)
-
+type WantFunc func([]cid.Cid)
+
+// AsyncGetBlocks listens for the blocks corresponding to the requested wants,
+// and outputs them on the returned channel.
+// If the wants channel is closed and all wanted blocks are received, closes
+// the returned channel.
+// If the session context or request context are cancelled, calls cancelWants +// with all pending wants and closes the returned channel. +func AsyncGetBlocks(ctx context.Context, sessctx context.Context, wants <-chan []cid.Cid, notif *notifications.PubSub, + want WantFunc, cancelWants func([]cid.Cid)) (<-chan blocks.Block, error) { + + // Channel of blocks to return to the client out := make(chan blocks.Block) - go handleIncoming(ctx, sessctx, remaining, promise, out, cwants) - return out, nil -} -// Listens for incoming blocks, passing them to the out channel. -// If the context is cancelled or the incoming channel closes, calls cfun with -// any keys corresponding to blocks that were never received. -func handleIncoming(ctx context.Context, sessctx context.Context, remaining *cid.Set, - in <-chan blocks.Block, out chan blocks.Block, cfun func([]cid.Cid)) { - - ctx, cancel := context.WithCancel(ctx) - - // Clean up before exiting this function, and call the cancel function on - // any remaining keys - defer func() { - cancel() - close(out) - // can't just defer this call on its own, arguments are resolved *when* the defer is created - cfun(remaining.Keys()) - }() + // Keep track of which wants we haven't yet received a block + pending := cid.NewSet() - for { - select { - case blk, ok := <-in: - // If the channel is closed, we're done (note that PubSub closes - // the channel once all the keys have been received) - if !ok { - return + // Use a PubSub notifier to listen for incoming blocks for each key + sub := notif.NewSubscription() + + go func() { + // Before exiting + defer func() { + // Close the client's channel of blocks + close(out) + // Close the subscription + sub.Close() + + // Cancel any pending wants + if pending.Len() > 0 { + cancelWants(pending.Keys()) } + }() - remaining.Remove(blk.Cid()) + blksCh := sub.Blocks() + for { select { - case out <- blk: + + // For each wanted key + case ks, ok := <-wants: + // Stop receiving from the channel if it's closed + if !ok { + wants = nil + if pending.Len() == 0 { + return + } + } else { + for _, k := range ks { + // Record that the want is pending + log.Debugw("Bitswap.GetBlockRequest.Start", "cid", k) + pending.Add(k) + } + + // Add the keys to the subscriber so that we'll be notified + // if the corresponding block arrives + sub.Add(ks...) + + // Send the want request for the keys to the network + want(ks) + } + + // For each received block + case blk := <-blksCh: + // Remove the want from the pending set + pending.Remove(blk.Cid()) + + // Send the block to the client + select { + case out <- blk: + case <-ctx.Done(): + return + case <-sessctx.Done(): + return + } + + // If the wants channel has been closed, and we're not + // expecting any more blocks, exit + if wants == nil && pending.Len() == 0 { + return + } + case <-ctx.Done(): return case <-sessctx.Done(): return } - case <-ctx.Done(): - return - case <-sessctx.Done(): - return } - } + }() + + return out, nil } diff --git a/internal/notifications/notifications.go b/internal/notifications/notifications.go index 7defea73..a28d6489 100644 --- a/internal/notifications/notifications.go +++ b/internal/notifications/notifications.go @@ -1,137 +1,115 @@ package notifications import ( - "context" "sync" - pubsub "github.com/cskr/pubsub" blocks "github.com/ipfs/go-block-format" cid "github.com/ipfs/go-cid" ) -const bufferSize = 16 +// PubSub is used to allow sessions to subscribe to notifications about +// incoming blocks, where multiple sessions may be interested in the same +// block. 
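+//
+// A minimal usage sketch (k, blk and process are illustrative only):
+//
+//	ps := New()
+//	sub := ps.NewSubscription()
+//	sub.Add(k)
+//	go func() {
+//		process(<-sub.Blocks()) // receives blk once it is published
+//		sub.Close()
+//	}()
+//	ps.Publish(blk) // delivered to every subscription that added blk.Cid()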
+type PubSub struct { + lk sync.RWMutex + subs map[cid.Cid]map[*Subscription]struct{} +} -// PubSub is a simple interface for publishing blocks and being able to subscribe -// for cids. It's used internally by bitswap to decouple receiving blocks -// and actually providing them back to the GetBlocks caller. -type PubSub interface { - Publish(block blocks.Block) - Subscribe(ctx context.Context, keys ...cid.Cid) <-chan blocks.Block - Shutdown() +// Subscription is a subscription to notifications about blocks +type Subscription struct { + lk sync.RWMutex + ps *PubSub + blks chan blocks.Block + closed chan struct{} } -// New generates a new PubSub interface. -func New() PubSub { - return &impl{ - wrapped: *pubsub.New(bufferSize), - closed: make(chan struct{}), +// New creates a new PubSub +func New() *PubSub { + return &PubSub{ + subs: make(map[cid.Cid]map[*Subscription]struct{}), } } -type impl struct { - lk sync.RWMutex - wrapped pubsub.PubSub +// Listen for keys +func (s *Subscription) Add(ks ...cid.Cid) { + s.ps.addSubscriptionKeys(s, ks) +} - closed chan struct{} +// Channel on which to receive incoming blocks. The channel should be +// closed with Close() +func (s *Subscription) Blocks() <-chan blocks.Block { + return s.blks } -func (ps *impl) Publish(block blocks.Block) { - ps.lk.RLock() - defer ps.lk.RUnlock() +// Receive a block +func (s *Subscription) receive(blk blocks.Block) { + s.lk.Lock() + defer s.lk.Unlock() + select { - case <-ps.closed: - return - default: + case s.blks <- blk: + case <-s.closed: } +} + +// Stop listening and close the associated blocks channel +func (s *Subscription) Close() { + close(s.closed) - ps.wrapped.Pub(block, block.Cid().KeyString()) + s.lk.Lock() + defer s.lk.Unlock() + + s.ps.removeSubscription(s) + + close(s.blks) } -func (ps *impl) Shutdown() { - ps.lk.Lock() - defer ps.lk.Unlock() - select { - case <-ps.closed: - return - default: +// Create a new subscription to PubSub notifications +func (ps *PubSub) NewSubscription() *Subscription { + return &Subscription{ + ps: ps, + blks: make(chan blocks.Block), + closed: make(chan struct{}), } - close(ps.closed) - ps.wrapped.Shutdown() } -// Subscribe returns a channel of blocks for the given |keys|. |blockChannel| -// is closed if the |ctx| times out or is cancelled, or after receiving the blocks -// corresponding to |keys|. -func (ps *impl) Subscribe(ctx context.Context, keys ...cid.Cid) <-chan blocks.Block { +// Publish a block to listeners +func (ps *PubSub) Publish(blk blocks.Block) { + ps.lk.Lock() + defer ps.lk.Unlock() - blocksCh := make(chan blocks.Block, len(keys)) - valuesCh := make(chan interface{}, len(keys)) // provide our own channel to control buffer, prevent blocking - if len(keys) == 0 { - close(blocksCh) - return blocksCh + k := blk.Cid() + for s := range ps.subs[k] { + s.receive(blk) } + delete(ps.subs, k) +} - // prevent shutdown - ps.lk.RLock() - defer ps.lk.RUnlock() - - select { - case <-ps.closed: - close(blocksCh) - return blocksCh - default: - } +// Add keys to the subscription +func (ps *PubSub) addSubscriptionKeys(s *Subscription, ks []cid.Cid) { + ps.lk.Lock() + defer ps.lk.Unlock() - // AddSubOnceEach listens for each key in the list, and closes the channel - // once all keys have been received - ps.wrapped.AddSubOnceEach(valuesCh, toStrings(keys)...) - go func() { - defer func() { - close(blocksCh) - - ps.lk.RLock() - defer ps.lk.RUnlock() - // Don't touch the pubsub instance if we're - // already closed. 
- select { - case <-ps.closed: - return - default: - } - - ps.wrapped.Unsub(valuesCh) - }() - - for { - select { - case <-ctx.Done(): - return - case <-ps.closed: - case val, ok := <-valuesCh: - if !ok { - return - } - block, ok := val.(blocks.Block) - if !ok { - return - } - select { - case <-ctx.Done(): - return - case blocksCh <- block: // continue - case <-ps.closed: - } - } + for _, k := range ks { + subs, ok := ps.subs[k] + if !ok { + subs = make(map[*Subscription]struct{}) + ps.subs[k] = subs } - }() - - return blocksCh + subs[s] = struct{}{} + } } -func toStrings(keys []cid.Cid) []string { - strs := make([]string, 0, len(keys)) - for _, key := range keys { - strs = append(strs, key.KeyString()) +// Remove the subscription from PubSub +func (ps *PubSub) removeSubscription(s *Subscription) { + ps.lk.Lock() + defer ps.lk.Unlock() + + for k := range ps.subs { + ksubs := ps.subs[k] + delete(ksubs, s) + if len(ksubs) == 0 { + delete(ps.subs, k) + } } - return strs } diff --git a/internal/notifications/notifications_test.go b/internal/notifications/notifications_test.go index 4e59ae9b..827d6c19 100644 --- a/internal/notifications/notifications_test.go +++ b/internal/notifications/notifications_test.go @@ -2,13 +2,10 @@ package notifications import ( "bytes" - "context" "testing" "time" blocks "github.com/ipfs/go-block-format" - cid "github.com/ipfs/go-cid" - blocksutil "github.com/ipfs/go-ipfs-blocksutil" ) func TestDuplicates(t *testing.T) { @@ -16,10 +13,11 @@ func TestDuplicates(t *testing.T) { b2 := blocks.NewBlock([]byte("2")) n := New() - defer n.Shutdown() - ch := n.Subscribe(context.Background(), b1.Cid(), b2.Cid()) + s := n.NewSubscription() + ch := s.Blocks() + s.Add(b1.Cid(), b2.Cid()) - n.Publish(b1) + go n.Publish(b1) blockRecvd, ok := <-ch if !ok { t.Fail() @@ -28,7 +26,7 @@ func TestDuplicates(t *testing.T) { n.Publish(b1) // ignored duplicate - n.Publish(b2) + go n.Publish(b2) blockRecvd, ok = <-ch if !ok { t.Fail() @@ -40,17 +38,15 @@ func TestPublishSubscribe(t *testing.T) { blockSent := blocks.NewBlock([]byte("Greetings from The Interval")) n := New() - defer n.Shutdown() - ch := n.Subscribe(context.Background(), blockSent.Cid()) + s := n.NewSubscription() + s.Add(blockSent.Cid()) - n.Publish(blockSent) - blockRecvd, ok := <-ch + go n.Publish(blockSent) + blockRecvd, ok := <-s.Blocks() if !ok { t.Fail() } - assertBlocksEqual(t, blockRecvd, blockSent) - } func TestSubscribeMany(t *testing.T) { @@ -58,17 +54,18 @@ func TestSubscribeMany(t *testing.T) { e2 := blocks.NewBlock([]byte("2")) n := New() - defer n.Shutdown() - ch := n.Subscribe(context.Background(), e1.Cid(), e2.Cid()) + s := n.NewSubscription() + ch := s.Blocks() + s.Add(e1.Cid(), e2.Cid()) - n.Publish(e1) + go n.Publish(e1) r1, ok := <-ch if !ok { t.Fatal("didn't receive first expected block") } assertBlocksEqual(t, e1, r1) - n.Publish(e2) + go n.Publish(e2) r2, ok := <-ch if !ok { t.Fatal("didn't receive second expected block") @@ -82,11 +79,15 @@ func TestDuplicateSubscribe(t *testing.T) { e1 := blocks.NewBlock([]byte("1")) n := New() - defer n.Shutdown() - ch1 := n.Subscribe(context.Background(), e1.Cid()) - ch2 := n.Subscribe(context.Background(), e1.Cid()) + s1 := n.NewSubscription() + ch1 := s1.Blocks() + s2 := n.NewSubscription() + ch2 := s2.Blocks() + + s1.Add(e1.Cid()) + s2.Add(e1.Cid()) - n.Publish(e1) + go n.Publish(e1) r1, ok := <-ch1 if !ok { t.Fatal("didn't receive first expected block") @@ -100,14 +101,15 @@ func TestDuplicateSubscribe(t *testing.T) { assertBlocksEqual(t, e1, r2) } -func 
TestShutdownBeforeUnsubscribe(t *testing.T) { +func TestCloseBeforeUnsubscribe(t *testing.T) { e1 := blocks.NewBlock([]byte("1")) n := New() - ctx, cancel := context.WithCancel(context.Background()) - ch := n.Subscribe(ctx, e1.Cid()) // no keys provided - n.Shutdown() - cancel() + s := n.NewSubscription() + ch := s.Blocks() + s.Add(e1.Cid()) + + s.Close() select { case _, ok := <-ch: @@ -119,61 +121,36 @@ func TestShutdownBeforeUnsubscribe(t *testing.T) { } } -func TestSubscribeIsANoopWhenCalledWithNoKeys(t *testing.T) { - n := New() - defer n.Shutdown() - ch := n.Subscribe(context.Background()) // no keys provided - if _, ok := <-ch; ok { - t.Fatal("should be closed if no keys provided") - } -} - -func TestCarryOnWhenDeadlineExpires(t *testing.T) { - - impossibleDeadline := time.Nanosecond - fastExpiringCtx, cancel := context.WithTimeout(context.Background(), impossibleDeadline) - defer cancel() +func TestPublishAfterClose(t *testing.T) { + e1 := blocks.NewBlock([]byte("1")) + e2 := blocks.NewBlock([]byte("2")) n := New() - defer n.Shutdown() - block := blocks.NewBlock([]byte("A Missed Connection")) - blockChannel := n.Subscribe(fastExpiringCtx, block.Cid()) + s := n.NewSubscription() + ch := s.Blocks() + s.Add(e1.Cid(), e2.Cid()) - assertBlockChannelNil(t, blockChannel) -} - -func TestDoesNotDeadLockIfContextCancelledBeforePublish(t *testing.T) { + go n.Publish(e1) - g := blocksutil.NewBlockGenerator() - ctx, cancel := context.WithCancel(context.Background()) - n := New() - defer n.Shutdown() - - t.Log("generate a large number of blocks. exceed default buffer") - bs := g.Blocks(1000) - ks := func() []cid.Cid { - var keys []cid.Cid - for _, b := range bs { - keys = append(keys, b.Cid()) - } - return keys - }() + r1, ok := <-ch + if !ok { + t.Fatal("didn't receive first expected block") + } + assertBlocksEqual(t, e1, r1) - _ = n.Subscribe(ctx, ks...) 
// ignore received channel
 
+	s.Close()
 
-	t.Log("cancel context before any blocks published")
-	cancel()
-	for _, b := range bs {
-		n.Publish(b)
-	}
+	go n.Publish(e1)
 
-	t.Log("publishing the large number of blocks to the ignored channel must not deadlock")
-}
+	time.Sleep(10 * time.Millisecond)
 
-func assertBlockChannelNil(t *testing.T, blockChannel <-chan blocks.Block) {
-	_, ok := <-blockChannel
-	if ok {
-		t.Fail()
+	select {
+	case _, ok := <-ch:
+		if ok {
+			t.Fatal("channel should have been closed")
+		}
+	case <-time.After(5 * time.Second):
+		t.Fatal("channel should have been closed")
 	}
 }
diff --git a/internal/session/session.go b/internal/session/session.go
index 34a7375c..9805798f 100644
--- a/internal/session/session.go
+++ b/internal/session/session.go
@@ -110,6 +110,7 @@ type Session struct {
 	// channels
 	incoming      chan op
+	wants         chan op
 	tickDelayReqs chan time.Duration
 
 	// do not touch outside run loop
@@ -120,7 +121,7 @@ type Session struct {
 	initialSearchDelay  time.Duration
 	periodicSearchDelay delay.D
 	// identifiers
-	notif notifications.PubSub
+	notif *notifications.PubSub
 	uuid  logging.Loggable
 	id    uint64
@@ -137,7 +138,7 @@ func New(ctx context.Context,
 	sim *bssim.SessionInterestManager,
 	pm PeerManager,
 	bpm *bsbpm.BlockPresenceManager,
-	notif notifications.PubSub,
+	notif *notifications.PubSub,
 	initialSearchDelay time.Duration,
 	periodicSearchDelay delay.D,
 	self peer.ID) *Session {
@@ -150,6 +151,7 @@ func New(ctx context.Context,
 		providerFinder: providerFinder,
 		sim:            sim,
 		incoming:       make(chan op, 128),
+		wants:          make(chan op, 128),
 		latencyTrkr:    latencyTracker{},
 		notif:          notif,
 		uuid:           loggables.Uuid("GetBlockRequest"),
@@ -221,17 +223,41 @@ func (s *Session) GetBlock(parent context.Context, k cid.Cid) (blocks.Block, err
 // returns a channel that found blocks will be returned on. No order is
 // guaranteed on the returned blocks.
 func (s *Session) GetBlocks(ctx context.Context, keys []cid.Cid) (<-chan blocks.Block, error) {
+	// If there are no keys supplied, just return a closed channel
+	if len(keys) == 0 {
+		out := make(chan blocks.Block)
+		close(out)
+		return out, nil
+	}
+
+	keysCh := make(chan []cid.Cid, 1)
+	keysCh <- keys
+	close(keysCh)
+	return s.StreamBlocks(ctx, keysCh)
+}
+
+// StreamBlocks fetches a set of blocks within the context of this session and
+// returns a channel that found blocks will be returned on. No order is
+// guaranteed on the returned blocks.
+func (s *Session) StreamBlocks(ctx context.Context, keys <-chan []cid.Cid) (<-chan blocks.Block, error) {
 	ctx = logging.ContextWithLoggable(ctx, s.uuid)
 
+	// Listen for blocks for each want in the channel of wanted keys
 	return bsgetter.AsyncGetBlocks(ctx, s.ctx, keys, s.notif,
-		func(ctx context.Context, keys []cid.Cid) {
+		// Called when the listener has been set up for the keys
+		func(keys []cid.Cid) {
+			// Tell the session to request the keys
 			select {
-			case s.incoming <- op{op: opWant, keys: keys}:
+			case s.wants <- op{op: opWant, keys: keys}:
 			case <-ctx.Done():
 			case <-s.ctx.Done():
 			}
 		},
+
+		// Called when the request context or session context is cancelled,
+		// where keys are the remaining pending wants
 		func(keys []cid.Cid) {
+			// Tell the session to cancel the keys
 			select {
 			case s.incoming <- op{op: opCancel, keys: keys}:
 			case <-s.ctx.Done():
@@ -276,6 +302,26 @@ func (s *Session) nonBlockingEnqueue(o op) {
 	}
 }
 
+// Pop all queued wants and return their keys.
+// This is just to make it more efficient to process consecutive wants.
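+// For example, if the run loop pops a want for [a b] while wants for [c] and
+// [d e] are still queued, a single wantBlocks call covers [a b c d e]
+// (a ... e here are placeholder CIDs used only for illustration).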
+func (s *Session) getQueuedWants(first []cid.Cid) []cid.Cid { + if len(s.wants) == 0 { + return first + } + + ks := make([]cid.Cid, 0, len(first)+len(s.wants)) + ks = append(ks, first...) + + for { + select { + case op := <-s.wants: + ks = append(ks, op.keys...) + default: + return ks + } + } +} + // Session run loop -- everything in this function should not be called // outside of this loop func (s *Session) run(ctx context.Context) { @@ -285,35 +331,45 @@ func (s *Session) run(ctx context.Context) { s.periodicSearchTimer = time.NewTimer(s.periodicSearchDelay.NextWaitTime()) for { select { + case want := <-s.wants: + // Client wants blocks + ks := s.getQueuedWants(want.keys) + s.wantBlocks(ctx, ks) + case oper := <-s.incoming: switch oper.op { case opReceive: // Received blocks s.handleReceive(oper.keys) - case opWant: - // Client wants blocks - s.wantBlocks(ctx, oper.keys) + case opCancel: // Wants were cancelled s.sw.CancelPending(oper.keys) + case opWantsSent: // Wants were sent to a peer s.sw.WantsSent(oper.keys) + case opBroadcast: // Broadcast want-haves to all peers s.broadcastWantHaves(ctx, oper.keys) + default: panic("unhandled operation") } + case <-s.idleTick.C: // The session hasn't received blocks for a while, broadcast s.broadcastWantHaves(ctx, nil) + case <-s.periodicSearchTimer.C: // Periodically search for a random live want s.handlePeriodicSearch(ctx) + case baseTickDelay := <-s.tickDelayReqs: // Set the base tick delay s.baseTickDelay = baseTickDelay + case <-ctx.Done(): // Shutdown s.handleShutdown() diff --git a/internal/session/session_test.go b/internal/session/session_test.go index d6f89e2d..0e9f5fe9 100644 --- a/internal/session/session_test.go +++ b/internal/session/session_test.go @@ -11,6 +11,7 @@ import ( bssim "github.com/ipfs/go-bitswap/internal/sessioninterestmanager" bsspm "github.com/ipfs/go-bitswap/internal/sessionpeermanager" "github.com/ipfs/go-bitswap/internal/testutil" + blocks "github.com/ipfs/go-block-format" cid "github.com/ipfs/go-cid" blocksutil "github.com/ipfs/go-ipfs-blocksutil" delay "github.com/ipfs/go-ipfs-delay" @@ -37,6 +38,23 @@ func (fwm *fakeWantManager) BroadcastWantHaves(ctx context.Context, sesid uint64 case <-ctx.Done(): } } + +func (fwm *fakeWantManager) waitForBroadcastReqs(t *testing.T, ctx context.Context, cnt int, postWait time.Duration) []cid.Cid { + var broadcast []cid.Cid + for len(broadcast) < cnt { + select { + case receivedWantReq := <-fwm.wantReqs: + broadcast = append(broadcast, receivedWantReq.cids...) 
+ case <-ctx.Done(): + t.Fatal("Context done") + } + } + + time.Sleep(postWait) + + return broadcast +} + func (fwm *fakeWantManager) RemoveSession(context.Context, uint64) {} func newFakeSessionPeerManager() *bsspm.SessionPeerManager { @@ -98,7 +116,6 @@ func TestSessionGetBlocks(t *testing.T) { sim := bssim.New() bpm := bsbpm.New() notif := notifications.New() - defer notif.Shutdown() id := testutil.GenerateSessionID() session := New(ctx, id, fwm, fpm, fpf, sim, newFakePeerManager(), bpm, notif, time.Second, delay.Fixed(time.Minute), "") blockGenerator := blocksutil.NewBlockGenerator() @@ -114,24 +131,24 @@ func TestSessionGetBlocks(t *testing.T) { t.Fatal("error getting blocks") } - // Wait for initial want request - receivedWantReq := <-fwm.wantReqs + // Wait for initial broadcast want-haves + broadcast := fwm.waitForBroadcastReqs(t, ctx, broadcastLiveWantsLimit, 10*time.Millisecond) - // Should have registered session's interest in blocks + // Should have registered session's interest in all wanted blocks intSes := sim.FilterSessionInterested(id, cids) if !testutil.MatchKeysIgnoreOrder(intSes[0], cids) { t.Fatal("did not register session interest in blocks") } // Should have sent out broadcast request for wants - if len(receivedWantReq.cids) != broadcastLiveWantsLimit { + if len(broadcast) != broadcastLiveWantsLimit { t.Fatal("did not enqueue correct initial number of wants") } // Simulate receiving HAVEs from several peers peers := testutil.GeneratePeers(5) for i, p := range peers { - blk := blks[testutil.IndexOf(blks, receivedWantReq.cids[i])] + blk := blks[testutil.IndexOf(blks, broadcast[i])] session.ReceiveFrom(p, []cid.Cid{}, []cid.Cid{blk.Cid()}, []cid.Cid{}) } @@ -174,6 +191,262 @@ func TestSessionGetBlocks(t *testing.T) { } } +func TestSessionStreamBlocks(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + fwm := newFakeWantManager() + fpm := newFakeSessionPeerManager() + fpf := newFakeProviderFinder() + sim := bssim.New() + bpm := bsbpm.New() + notif := notifications.New() + id := testutil.GenerateSessionID() + session := New(ctx, id, fwm, fpm, fpf, sim, newFakePeerManager(), bpm, notif, time.Second, delay.Fixed(time.Minute), "") + blockGenerator := blocksutil.NewBlockGenerator() + blks := blockGenerator.Blocks(broadcastLiveWantsLimit * 2) + var cids []cid.Cid + for _, block := range blks { + cids = append(cids, block.Cid()) + } + + ch := make(chan []cid.Cid) + blksCh, err := session.StreamBlocks(ctx, ch) + if err != nil { + t.Fatal("error streaming blocks") + } + + // Want a block + ch <- cids[:1] + + time.Sleep(10 * time.Millisecond) + + // Simulate receiving block for the CID + p := testutil.GeneratePeers(1)[0] + session.ReceiveFrom(p, cids[:1], []cid.Cid{}, []cid.Cid{}) + + // Verify that published blocks are returned by the session + notif.Publish(blks[0]) + r := <-blksCh + if !r.Cid().Equals(blks[0].Cid()) { + t.Fatal("wrong block") + } + + time.Sleep(10 * time.Millisecond) + + // Verify session no longer wants received block + wanted, _ := sim.SplitWantedUnwanted(blks[:1]) + if len(wanted) != 0 { + t.Fatal("session wants block that has already been received") + } + + // Want more blocks + ch <- cids[1:5] + ch <- cids[5:10] + ch <- cids[10:15] + + time.Sleep(10 * time.Millisecond) + + // Simulate receiving blocks for the CID + rcvBlks := []blocks.Block{blks[2], blks[7], blks[12]} + var rcvCids []cid.Cid + for _, b := range rcvBlks { + rcvCids = append(rcvCids, b.Cid()) + } + session.ReceiveFrom(p, rcvCids, 
[]cid.Cid{}, []cid.Cid{}) + + // Verify that published blocks are returned by the session + for _, b := range rcvBlks { + notif.Publish(b) + r := <-blksCh + if !r.Cid().Equals(b.Cid()) { + t.Fatal("wrong block") + } + } + + time.Sleep(10 * time.Millisecond) + + // Verify session no longer wants received blocks + wanted, _ = sim.SplitWantedUnwanted(blks[:15]) + if len(wanted) != 11 { // Received 4 of 15 blocks + t.Fatal("session wanted incorrect blocks") + } +} + +func TestSessionStreamBlocksCloseKeysChannel(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + fwm := newFakeWantManager() + fpm := newFakeSessionPeerManager() + fpf := newFakeProviderFinder() + sim := bssim.New() + bpm := bsbpm.New() + notif := notifications.New() + id := testutil.GenerateSessionID() + session := New(ctx, id, fwm, fpm, fpf, sim, newFakePeerManager(), bpm, notif, time.Second, delay.Fixed(time.Minute), "") + blockGenerator := blocksutil.NewBlockGenerator() + blk := blockGenerator.Blocks(1)[0] + + ch := make(chan []cid.Cid) + blksCh, err := session.StreamBlocks(ctx, ch) + if err != nil { + t.Fatal("error streaming blocks") + } + + // Want a block + cids := []cid.Cid{blk.Cid()} + ch <- cids + + time.Sleep(10 * time.Millisecond) + + // Simulate receiving block for the CID + p := testutil.GeneratePeers(1)[0] + session.ReceiveFrom(p, cids, []cid.Cid{}, []cid.Cid{}) + + notif.Publish(blk) + <-blksCh + + // Verify that blocks channel is closed when client closes key channel + // and there are no pending blocks + close(ch) + _, ok := <-blksCh + if ok { + t.Fatal("expected blocks channel to be closed") + } +} + +func TestSessionStreamBlocksCloseOnReceiveAllBlocks(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + fwm := newFakeWantManager() + fpm := newFakeSessionPeerManager() + fpf := newFakeProviderFinder() + sim := bssim.New() + bpm := bsbpm.New() + notif := notifications.New() + id := testutil.GenerateSessionID() + session := New(ctx, id, fwm, fpm, fpf, sim, newFakePeerManager(), bpm, notif, time.Second, delay.Fixed(time.Minute), "") + blockGenerator := blocksutil.NewBlockGenerator() + blks := blockGenerator.Blocks(2) + var cids []cid.Cid + for _, block := range blks { + cids = append(cids, block.Cid()) + } + + ch := make(chan []cid.Cid) + blksCh, err := session.StreamBlocks(ctx, ch) + if err != nil { + t.Fatal("error streaming blocks") + } + + // Want blocks + ch <- cids[:1] + ch <- cids[1:] + + time.Sleep(10 * time.Millisecond) + + // Close the channel + close(ch) + + select { + case <-blksCh: + t.Fatal("should not be closed yet") + default: + } + + // Simulate receiving first block for the CID + p := testutil.GeneratePeers(1)[0] + session.ReceiveFrom(p, cids[:1], []cid.Cid{}, []cid.Cid{}) + notif.Publish(blks[0]) + <-blksCh + + select { + case <-blksCh: + t.Fatal("should not be closed yet") + default: + } + + // Simulate receiving second block for the CID + session.ReceiveFrom(p, cids[1:], []cid.Cid{}, []cid.Cid{}) + notif.Publish(blks[1]) + <-blksCh + + // Verify that blocks channel is now closed + _, ok := <-blksCh + if ok { + t.Fatal("expected blocks channel to be closed") + } +} + +func TestSessionStreamBlocksCloseOnContextCancel(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + fwm := newFakeWantManager() + fpm := newFakeSessionPeerManager() + fpf := newFakeProviderFinder() + sim := bssim.New() + bpm := bsbpm.New() + 
notif := notifications.New() + id := testutil.GenerateSessionID() + session := New(ctx, id, fwm, fpm, fpf, sim, newFakePeerManager(), bpm, notif, time.Second, delay.Fixed(time.Minute), "") + blockGenerator := blocksutil.NewBlockGenerator() + blk := blockGenerator.Blocks(1)[0] + + ch := make(chan []cid.Cid) + ctx, rqCancel := context.WithCancel(context.Background()) + blksCh, err := session.StreamBlocks(ctx, ch) + if err != nil { + t.Fatal("error streaming blocks") + } + + // Want a block + cids := []cid.Cid{blk.Cid()} + ch <- cids + + time.Sleep(10 * time.Millisecond) + + // Verify that blocks channel is closed when request context cancelled + rqCancel() + _, ok := <-blksCh + if ok { + t.Fatal("expected blocks channel to be closed") + } +} + +func TestSessionStreamBlocksCloseOnSessionContextCancel(t *testing.T) { + ctx, sessCancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + fwm := newFakeWantManager() + fpm := newFakeSessionPeerManager() + fpf := newFakeProviderFinder() + sim := bssim.New() + bpm := bsbpm.New() + notif := notifications.New() + id := testutil.GenerateSessionID() + session := New(ctx, id, fwm, fpm, fpf, sim, newFakePeerManager(), bpm, notif, time.Second, delay.Fixed(time.Minute), "") + blockGenerator := blocksutil.NewBlockGenerator() + blk := blockGenerator.Blocks(1)[0] + + ch := make(chan []cid.Cid) + ctx, rqCancel := context.WithCancel(context.Background()) + defer rqCancel() + blksCh, err := session.StreamBlocks(ctx, ch) + if err != nil { + t.Fatal("error streaming blocks") + } + + // Want a block + cids := []cid.Cid{blk.Cid()} + ch <- cids + + time.Sleep(10 * time.Millisecond) + + // Verify that blocks channel is closed when session context cancelled + sessCancel() + _, ok := <-blksCh + if ok { + t.Fatal("expected blocks channel to be closed") + } +} + func TestSessionFindMorePeers(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 900*time.Millisecond) defer cancel() @@ -183,7 +456,6 @@ func TestSessionFindMorePeers(t *testing.T) { sim := bssim.New() bpm := bsbpm.New() notif := notifications.New() - defer notif.Shutdown() id := testutil.GenerateSessionID() session := New(ctx, id, fwm, fpm, fpf, sim, newFakePeerManager(), bpm, notif, time.Second, delay.Fixed(time.Minute), "") session.SetBaseTickDelay(200 * time.Microsecond) @@ -199,11 +471,7 @@ func TestSessionFindMorePeers(t *testing.T) { } // The session should initially broadcast want-haves - select { - case <-fwm.wantReqs: - case <-ctx.Done(): - t.Fatal("Did not make first want request ") - } + fwm.waitForBroadcastReqs(t, ctx, broadcastLiveWantsLimit, 10*time.Millisecond) // receive a block to trigger a tick reset time.Sleep(20 * time.Millisecond) // need to make sure some latency registers @@ -216,27 +484,17 @@ func TestSessionFindMorePeers(t *testing.T) { // The session should now time out waiting for a response and broadcast // want-haves again - select { - case <-fwm.wantReqs: - case <-ctx.Done(): - t.Fatal("Did not make second want request ") - } + fwm.waitForBroadcastReqs(t, ctx, broadcastLiveWantsLimit, 10*time.Millisecond) // The session should keep broadcasting periodically until it receives a response - select { - case receivedWantReq := <-fwm.wantReqs: - if len(receivedWantReq.cids) != broadcastLiveWantsLimit { - t.Fatal("did not rebroadcast whole live list") - } - // Make sure the first block is not included because it has already - // been received - for _, c := range receivedWantReq.cids { - if c.Equals(cids[0]) { - t.Fatal("should not braodcast block that 
was already received") - } + broadcast := fwm.waitForBroadcastReqs(t, ctx, broadcastLiveWantsLimit, 10*time.Millisecond) + + // Make sure the first block is not included because it has already + // been received + for _, c := range broadcast { + if c.Equals(cids[0]) { + t.Fatal("should not broadcast block that was already received") } - case <-ctx.Done(): - t.Fatal("Never rebroadcast want list") } // The session should eventually try to find more peers @@ -248,7 +506,7 @@ func TestSessionFindMorePeers(t *testing.T) { } func TestSessionOnPeersExhausted(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() fwm := newFakeWantManager() fpm := newFakeSessionPeerManager() @@ -257,7 +515,6 @@ func TestSessionOnPeersExhausted(t *testing.T) { sim := bssim.New() bpm := bsbpm.New() notif := notifications.New() - defer notif.Shutdown() id := testutil.GenerateSessionID() session := New(ctx, id, fwm, fpm, fpf, sim, newFakePeerManager(), bpm, notif, time.Second, delay.Fixed(time.Minute), "") blockGenerator := blocksutil.NewBlockGenerator() @@ -273,10 +530,10 @@ func TestSessionOnPeersExhausted(t *testing.T) { } // Wait for initial want request - receivedWantReq := <-fwm.wantReqs + broadcast := fwm.waitForBroadcastReqs(t, ctx, broadcastLiveWantsLimit, 10*time.Millisecond) // Should have sent out broadcast request for wants - if len(receivedWantReq.cids) != broadcastLiveWantsLimit { + if len(broadcast) != broadcastLiveWantsLimit { t.Fatal("did not enqueue correct initial number of wants") } @@ -284,10 +541,10 @@ func TestSessionOnPeersExhausted(t *testing.T) { session.onPeersExhausted(cids[len(cids)-2:]) // Wait for want request - receivedWantReq = <-fwm.wantReqs + broadcast = fwm.waitForBroadcastReqs(t, ctx, 2, 10*time.Millisecond) // Should have sent out broadcast request for wants - if len(receivedWantReq.cids) != 2 { + if len(broadcast) != 2 { t.Fatal("did not enqueue correct initial number of wants") } } @@ -301,7 +558,6 @@ func TestSessionFailingToGetFirstBlock(t *testing.T) { sim := bssim.New() bpm := bsbpm.New() notif := notifications.New() - defer notif.Shutdown() id := testutil.GenerateSessionID() session := New(ctx, id, fwm, fpm, fpf, sim, newFakePeerManager(), bpm, notif, 10*time.Millisecond, delay.Fixed(100*time.Millisecond), "") blockGenerator := blocksutil.NewBlockGenerator() @@ -317,21 +573,10 @@ func TestSessionFailingToGetFirstBlock(t *testing.T) { } // The session should initially broadcast want-haves - select { - case <-fwm.wantReqs: - case <-ctx.Done(): - t.Fatal("Did not make first want request ") - } + fwm.waitForBroadcastReqs(t, ctx, len(cids), time.Duration(0)) - // Verify a broadcast was made - select { - case receivedWantReq := <-fwm.wantReqs: - if len(receivedWantReq.cids) < len(cids) { - t.Fatal("did not rebroadcast whole live list") - } - case <-ctx.Done(): - t.Fatal("Never rebroadcast want list") - } + // Verify a rebroadcast is made + fwm.waitForBroadcastReqs(t, ctx, len(cids), time.Duration(0)) // Wait for a request to find more peers to occur select { @@ -345,25 +590,11 @@ func TestSessionFailingToGetFirstBlock(t *testing.T) { firstTickLength := time.Since(startTick) // Wait for another broadcast to occur - select { - case receivedWantReq := <-fwm.wantReqs: - if len(receivedWantReq.cids) < len(cids) { - t.Fatal("did not rebroadcast whole live list") - } - case <-ctx.Done(): - t.Fatal("Never rebroadcast want list") - } + 
fwm.waitForBroadcastReqs(t, ctx, len(cids), time.Duration(0)) // Wait for another broadcast to occur startTick = time.Now() - select { - case receivedWantReq := <-fwm.wantReqs: - if len(receivedWantReq.cids) < len(cids) { - t.Fatal("did not rebroadcast whole live list") - } - case <-ctx.Done(): - t.Fatal("Never rebroadcast want list") - } + fwm.waitForBroadcastReqs(t, ctx, len(cids), time.Duration(0)) // Tick should take longer consecutiveTickLength := time.Since(startTick) @@ -373,14 +604,7 @@ func TestSessionFailingToGetFirstBlock(t *testing.T) { // Wait for another broadcast to occur startTick = time.Now() - select { - case receivedWantReq := <-fwm.wantReqs: - if len(receivedWantReq.cids) < len(cids) { - t.Fatal("did not rebroadcast whole live list") - } - case <-ctx.Done(): - t.Fatal("Never rebroadcast want list") - } + fwm.waitForBroadcastReqs(t, ctx, len(cids), time.Duration(0)) // Tick should take longer secondConsecutiveTickLength := time.Since(startTick) @@ -413,7 +637,6 @@ func TestSessionCtxCancelClosesGetBlocksChannel(t *testing.T) { sim := bssim.New() bpm := bsbpm.New() notif := notifications.New() - defer notif.Shutdown() id := testutil.GenerateSessionID() // Create a new session with its own context @@ -457,7 +680,6 @@ func TestSessionReceiveMessageAfterShutdown(t *testing.T) { sim := bssim.New() bpm := bsbpm.New() notif := notifications.New() - defer notif.Shutdown() id := testutil.GenerateSessionID() session := New(ctx, id, fwm, fpm, fpf, sim, newFakePeerManager(), bpm, notif, time.Second, delay.Fixed(time.Minute), "") blockGenerator := blocksutil.NewBlockGenerator() diff --git a/internal/sessionmanager/sessionmanager.go b/internal/sessionmanager/sessionmanager.go index f7382fad..747fd954 100644 --- a/internal/sessionmanager/sessionmanager.go +++ b/internal/sessionmanager/sessionmanager.go @@ -5,6 +5,7 @@ import ( "sync" "time" + blocks "github.com/ipfs/go-block-format" cid "github.com/ipfs/go-cid" delay "github.com/ipfs/go-ipfs-delay" @@ -12,19 +13,25 @@ import ( notifications "github.com/ipfs/go-bitswap/internal/notifications" bssession "github.com/ipfs/go-bitswap/internal/session" bssim "github.com/ipfs/go-bitswap/internal/sessioninterestmanager" - exchange "github.com/ipfs/go-ipfs-exchange-interface" peer "github.com/libp2p/go-libp2p-core/peer" ) +type Fetcher interface { + // GetBlock returns the block associated with a given key. + GetBlock(context.Context, cid.Cid) (blocks.Block, error) + GetBlocks(context.Context, []cid.Cid) (<-chan blocks.Block, error) + StreamBlocks(context.Context, <-chan []cid.Cid) (<-chan blocks.Block, error) +} + // Session is a session that is managed by the session manager type Session interface { - exchange.Fetcher + Fetcher ID() uint64 ReceiveFrom(peer.ID, []cid.Cid, []cid.Cid, []cid.Cid) } // SessionFactory generates a new session for the SessionManager to track. -type SessionFactory func(ctx context.Context, id uint64, sprm bssession.SessionPeerManager, sim *bssim.SessionInterestManager, pm bssession.PeerManager, bpm *bsbpm.BlockPresenceManager, notif notifications.PubSub, provSearchDelay time.Duration, rebroadcastDelay delay.D, self peer.ID) Session +type SessionFactory func(ctx context.Context, id uint64, sprm bssession.SessionPeerManager, sim *bssim.SessionInterestManager, pm bssession.PeerManager, bpm *bsbpm.BlockPresenceManager, notif *notifications.PubSub, provSearchDelay time.Duration, rebroadcastDelay delay.D, self peer.ID) Session // PeerManagerFactory generates a new peer manager for a session. 
type PeerManagerFactory func(ctx context.Context, id uint64) bssession.SessionPeerManager @@ -38,7 +45,7 @@ type SessionManager struct { peerManagerFactory PeerManagerFactory blockPresenceManager *bsbpm.BlockPresenceManager peerManager bssession.PeerManager - notif notifications.PubSub + notif *notifications.PubSub // Sessions sessLk sync.RWMutex @@ -53,7 +60,7 @@ type SessionManager struct { // New creates a new SessionManager. func New(ctx context.Context, sessionFactory SessionFactory, sessionInterestManager *bssim.SessionInterestManager, peerManagerFactory PeerManagerFactory, - blockPresenceManager *bsbpm.BlockPresenceManager, peerManager bssession.PeerManager, notif notifications.PubSub, self peer.ID) *SessionManager { + blockPresenceManager *bsbpm.BlockPresenceManager, peerManager bssession.PeerManager, notif *notifications.PubSub, self peer.ID) *SessionManager { return &SessionManager{ ctx: ctx, sessionFactory: sessionFactory, @@ -71,7 +78,7 @@ func New(ctx context.Context, sessionFactory SessionFactory, sessionInterestMana // session manager. func (sm *SessionManager) NewSession(ctx context.Context, provSearchDelay time.Duration, - rebroadcastDelay delay.D) exchange.Fetcher { + rebroadcastDelay delay.D) Fetcher { id := sm.GetNextSessionID() sessionctx, cancel := context.WithCancel(ctx) diff --git a/internal/sessionmanager/sessionmanager_test.go b/internal/sessionmanager/sessionmanager_test.go index 4e0152bb..f4f01b53 100644 --- a/internal/sessionmanager/sessionmanager_test.go +++ b/internal/sessionmanager/sessionmanager_test.go @@ -24,7 +24,7 @@ type fakeSession struct { wantHaves []cid.Cid id uint64 pm *fakeSesPeerManager - notif notifications.PubSub + notif *notifications.PubSub } func (*fakeSession) GetBlock(context.Context, cid.Cid) (blocks.Block, error) { @@ -33,6 +33,9 @@ func (*fakeSession) GetBlock(context.Context, cid.Cid) (blocks.Block, error) { func (*fakeSession) GetBlocks(context.Context, []cid.Cid) (<-chan blocks.Block, error) { return nil, nil } +func (*fakeSession) StreamBlocks(context.Context, <-chan []cid.Cid) (<-chan blocks.Block, error) { + return nil, nil +} func (fs *fakeSession) ID() uint64 { return fs.id } @@ -65,7 +68,7 @@ func sessionFactory(ctx context.Context, sim *bssim.SessionInterestManager, pm bssession.PeerManager, bpm *bsbpm.BlockPresenceManager, - notif notifications.PubSub, + notif *notifications.PubSub, provSearchDelay time.Duration, rebroadcastDelay delay.D, self peer.ID) Session { @@ -85,7 +88,6 @@ func TestReceiveFrom(t *testing.T) { ctx, cancel := context.WithCancel(ctx) defer cancel() notif := notifications.New() - defer notif.Shutdown() sim := bssim.New() bpm := bsbpm.New() pm := &fakePeerManager{} @@ -128,7 +130,6 @@ func TestReceiveBlocksWhenManagerContextCancelled(t *testing.T) { ctx, cancel := context.WithCancel(ctx) defer cancel() notif := notifications.New() - defer notif.Shutdown() sim := bssim.New() bpm := bsbpm.New() pm := &fakePeerManager{} @@ -163,7 +164,6 @@ func TestReceiveBlocksWhenSessionContextCancelled(t *testing.T) { ctx, cancel := context.WithCancel(ctx) defer cancel() notif := notifications.New() - defer notif.Shutdown() sim := bssim.New() bpm := bsbpm.New() pm := &fakePeerManager{}