diff --git a/reddit/helpers.go b/reddit/helpers.go new file mode 100644 index 0000000..ca28ea4 --- /dev/null +++ b/reddit/helpers.go @@ -0,0 +1,68 @@ +package reddit + +import ( + "sync" +) + +// OrderedMaxSet is intended to be able to check if things have been seen while +// expiring older entries that are unlikely to be seen again. +// This is to avoid memory issues in long-running streams. +type OrderedMaxSet struct { + MaxSize int + set map[string]struct{} + keys []string + mutex *sync.Mutex +} + +// NewOrderedMaxSet instantiates an OrderedMaxSet and returns it for downstream use. +func NewOrderedMaxSet(maxSize int) OrderedMaxSet { + var mutex = &sync.Mutex{} + orderedMaxSet := OrderedMaxSet{ + MaxSize: maxSize, + set: map[string]struct{}{}, + keys: []string{}, + mutex: mutex, + } + + return orderedMaxSet +} + +// Add accepts a string and inserts it into an OrderedMaxSet +func (s *OrderedMaxSet) Add(v string) { + s.mutex.Lock() + defer s.mutex.Unlock() + _, ok := s.set[v] + if !ok { + s.keys = append(s.keys, v) + s.set[v] = struct{}{} + } + if len(s.keys) > s.MaxSize { + for _, id := range s.keys[:len(s.keys)-s.MaxSize] { + delete(s.set, id) + } + s.keys = s.keys[(len(s.keys) - s.MaxSize):] + + } +} + +// Delete accepts a string and deletes it from OrderedMaxSet +func (s *OrderedMaxSet) Delete(v string) { + s.mutex.Lock() + defer s.mutex.Unlock() + delete(s.set, v) +} + +// Len returns the number of elements in OrderedMaxSet +func (s *OrderedMaxSet) Len() int { + s.mutex.Lock() + defer s.mutex.Unlock() + return len(s.set) +} + +// Exists accepts a string and determines if it is present in OrderedMaxSet +func (s *OrderedMaxSet) Exists(v string) bool { + s.mutex.Lock() + defer s.mutex.Unlock() + _, ok := s.set[v] + return ok +} diff --git a/reddit/helpers_test.go b/reddit/helpers_test.go new file mode 100644 index 0000000..babaf48 --- /dev/null +++ b/reddit/helpers_test.go @@ -0,0 
+1,34 @@ +package reddit + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewOrderedMaxSet(t *testing.T) { + set := NewOrderedMaxSet(1) + set.Add("foo") + set.Add("bar") + + require.Equal(t, set.Len(), 1) +} + +func TestOrderedMaxSetCollision(t *testing.T) { + set := NewOrderedMaxSet(2) + set.Add("foo") + set.Add("foo") + + require.Equal(t, set.Len(), 1) +} + +func TestOrderedMaxSet_Delete(t *testing.T) { + set := NewOrderedMaxSet(1) + set.Add("foo") + + require.Equal(t, set.Len(), 1) + + set.Delete("foo") + require.Equal(t, set.Len(), 0) + require.False(t, set.Exists("foo")) +} diff --git a/reddit/stream.go b/reddit/stream.go index 93f6ca8..951375e 100644 --- a/reddit/stream.go +++ b/reddit/stream.go @@ -43,43 +43,54 @@ func (s *StreamService) Posts(subreddit string, opts ...StreamOpt) (<-chan *Post // originally used the "before" parameter, but if that post gets deleted, subsequent requests // would just return empty listings; easier to just keep track of all post ids encountered - ids := set{} + ids := NewOrderedMaxSet(2000) go func() { defer stop() + var wg sync.WaitGroup + defer wg.Wait() + var mutex sync.Mutex var n int infinite := streamConfig.MaxRequests == 0 for ; ; <-ticker.C { n++ + wg.Add(1) + go s.getPosts(subreddit, func(posts []*Post, err error) { + defer wg.Done() - posts, err := s.getPosts(subreddit) - if err != nil { - errsCh <- err - if !infinite && n >= streamConfig.MaxRequests { - break + if err != nil { + errsCh <- err + return } - continue - } - for _, post := range posts { - id := post.FullID + for _, post := range posts { + id := post.FullID - // if this post id is already part of the set, it means that it and the ones - // after it in the list have already been streamed, so break out of the loop - if ids.Exists(id) { - break - } - ids.Add(id) + // if this post id is already part of the set, it means that it and the ones + // after it in the list have already been streamed, so break out of the loop + if 
ids.Exists(id) { + break + } + ids.Add(id) - if streamConfig.DiscardInitial { - streamConfig.DiscardInitial = false - break - } + if func() bool { + mutex.Lock() + toReturn := false + if streamConfig.DiscardInitial { + streamConfig.DiscardInitial = false + toReturn = true + } + mutex.Unlock() + return toReturn + }() { + break + } - postsCh <- post - } + postsCh <- post + } + }) if !infinite && n >= streamConfig.MaxRequests { break @@ -120,44 +130,54 @@ func (s *StreamService) Comments(subreddit string, opts ...StreamOpt) (<-chan *C }) } - ids := set{} + ids := NewOrderedMaxSet(2000) go func() { defer stop() + var wg sync.WaitGroup + defer wg.Wait() + var mutex sync.Mutex var n int infinite := streamConfig.MaxRequests == 0 for ; ; <-ticker.C { n++ + wg.Add(1) - comments, err := s.getComments(subreddit) - if err != nil { - errsCh <- err - if !infinite && n >= streamConfig.MaxRequests { - break + go s.getComments(subreddit, func(comments []*Comment, err error) { + defer wg.Done() + if err != nil { + errsCh <- err + return } - continue - } - - for _, comment := range comments { - id := comment.FullID - - // certain comment streams are inconsistent about the completeness of returned comments - // it's not enough to check if we've seen older comments, but we must check for every comment individually - if !ids.Exists(id) { - ids.Add(id) - if streamConfig.DiscardInitial { - streamConfig.DiscardInitial = false - break + for _, comment := range comments { + id := comment.FullID + + // certain comment streams are inconsistent about the completeness of returned comments + // it's not enough to check if we've seen older comments, but we must check for every comment individually + if !ids.Exists(id) { + ids.Add(id) + + if func() bool { + mutex.Lock() + toReturn := false + if streamConfig.DiscardInitial { + streamConfig.DiscardInitial = false + toReturn = true + } + mutex.Unlock() + return toReturn + }() { + break + } + + commentsCh <- comment } - commentsCh <- comment } - - } - 
+ }) if !infinite && n >= streamConfig.MaxRequests { break } @@ -167,31 +187,12 @@ func (s *StreamService) Comments(subreddit string, opts ...StreamOpt) (<-chan *C return commentsCh, errsCh, stop } -func (s *StreamService) getPosts(subreddit string) ([]*Post, error) { +func (s *StreamService) getPosts(subreddit string, cb func([]*Post, error)) { posts, _, err := s.client.Subreddit.NewPosts(context.Background(), subreddit, &ListOptions{Limit: 100}) - return posts, err + cb(posts, err) } -func (s *StreamService) getComments(subreddit string) ([]*Comment, error) { +func (s *StreamService) getComments(subreddit string, cb func([]*Comment, error)) { comments, _, err := s.client.Subreddit.Comments(context.Background(), subreddit, &ListOptions{Limit: 100}) - return comments, err -} - -type set map[string]struct{} - -func (s set) Add(v string) { - s[v] = struct{}{} -} - -func (s set) Delete(v string) { - delete(s, v) -} - -func (s set) Len() int { - return len(s) -} - -func (s set) Exists(v string) bool { - _, ok := s[v] - return ok + cb(comments, err) } diff --git a/reddit/stream_test.go b/reddit/stream_test.go index 7c1960d..5b4a6ca 100644 --- a/reddit/stream_test.go +++ b/reddit/stream_test.go @@ -133,8 +133,7 @@ func TestStreamService_Posts(t *testing.T) { } }) - posts, errs, stop := client.Stream.Posts("testsubreddit", StreamInterval(time.Millisecond*10), StreamMaxRequests(4)) - defer stop() + posts, errs, _ := client.Stream.Posts("testsubreddit", StreamInterval(time.Millisecond*10), StreamMaxRequests(4)) expectedPostIDs := []string{"t3_post1", "t3_post2", "t3_post3", "t3_post4", "t3_post5", "t3_post6", "t3_post7", "t3_post8", "t3_post9", "t3_post10", "t3_post11", "t3_post12"} var i int @@ -283,8 +282,7 @@ func TestStreamService_Posts_DiscardInitial(t *testing.T) { } }) - posts, errs, stop := client.Stream.Posts("testsubreddit", StreamInterval(time.Millisecond*10), StreamMaxRequests(4), StreamDiscardInitial) - defer stop() + posts, errs, _ := 
client.Stream.Posts("testsubreddit", StreamInterval(time.Millisecond*10), StreamMaxRequests(4), StreamDiscardInitial) expectedPostIDs := []string{"t3_post3", "t3_post4", "t3_post5", "t3_post6", "t3_post7", "t3_post8", "t3_post9", "t3_post10", "t3_post11", "t3_post12"} var i int @@ -433,8 +431,7 @@ func TestStreamService_Comments(t *testing.T) { } }) - comments, errs, stop := client.Stream.Comments("testsubreddit", StreamInterval(time.Millisecond*10), StreamMaxRequests(4)) - defer stop() + comments, errs, _ := client.Stream.Comments("testsubreddit", StreamInterval(time.Millisecond*10), StreamMaxRequests(4)) expectedCommentIds := []string{"t1_comment1", "t1_comment2", "t1_comment3", "t1_comment4", "t1_comment5", "t1_comment6", "t1_comment7", "t1_comment8", "t1_comment9", "t1_comment10", "t1_comment11", "t1_comment12"} var i int @@ -583,8 +580,7 @@ func TestStreamService_CommentsDiscardInitial(t *testing.T) { } }) - comments, errs, stop := client.Stream.Comments("testsubreddit", StreamInterval(time.Millisecond*10), StreamMaxRequests(4), StreamDiscardInitial) - defer stop() + comments, errs, _ := client.Stream.Comments("testsubreddit", StreamInterval(time.Millisecond*10), StreamMaxRequests(4), StreamDiscardInitial) expectedCommentIds := []string{"t1_comment3", "t1_comment4", "t1_comment5", "t1_comment6", "t1_comment7", "t1_comment8", "t1_comment9", "t1_comment10", "t1_comment11", "t1_comment12"} var i int