Add Subreddit Comment Retrieval and Stream Concurrency #12

Open · wants to merge 7 commits into base: master
68 changes: 68 additions & 0 deletions reddit/helpers.go
@@ -0,0 +1,68 @@
package reddit

import (
	"sync"
)

// OrderedMaxSet tracks whether values have been seen, while expiring the
// oldest entries once MaxSize is exceeded. This keeps memory bounded in
// long-running streams, where old IDs are unlikely to be seen again.
type OrderedMaxSet struct {
	MaxSize int
	set     map[string]struct{}
	keys    []string
	mutex   *sync.Mutex
}

// NewOrderedMaxSet instantiates an OrderedMaxSet and returns it for downstream use.
func NewOrderedMaxSet(maxSize int) OrderedMaxSet {
	orderedMaxSet := OrderedMaxSet{
		MaxSize: maxSize,
		set:     map[string]struct{}{},
		keys:    []string{},
		mutex:   &sync.Mutex{},
	}

	return orderedMaxSet
}

// Add inserts a string into an OrderedMaxSet, evicting the oldest entries
// once the set grows beyond MaxSize.
func (s *OrderedMaxSet) Add(v string) {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	if _, ok := s.set[v]; !ok {
		s.keys = append(s.keys, v)
		s.set[v] = struct{}{}
	}
	if len(s.keys) > s.MaxSize {
		for _, id := range s.keys[:len(s.keys)-s.MaxSize] {
			delete(s.set, id)
		}
		s.keys = s.keys[len(s.keys)-s.MaxSize:]
	}
}

// Delete removes a string from an OrderedMaxSet. The value's slot in the
// insertion-order list is retained and simply ages out.
func (s *OrderedMaxSet) Delete(v string) {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	delete(s.set, v)
}

// Len returns the number of elements in an OrderedMaxSet.
func (s *OrderedMaxSet) Len() int {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	return len(s.set)
}

// Exists reports whether a string is present in an OrderedMaxSet.
func (s *OrderedMaxSet) Exists(v string) bool {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	_, ok := s.set[v]
	return ok
}
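A minimal usage sketch of the eviction behavior, for reviewers; the values and the module path are illustrative assumptions:

package main

import (
	"fmt"

	"github.com/vartanbeno/go-reddit/v2/reddit" // assumed module path
)

func main() {
	set := reddit.NewOrderedMaxSet(2)
	set.Add("t3_a")
	set.Add("t3_b")
	set.Add("t3_c") // exceeds MaxSize, so the oldest entry ("t3_a") is evicted

	fmt.Println(set.Exists("t3_a")) // false
	fmt.Println(set.Exists("t3_c")) // true
	fmt.Println(set.Len())          // 2
}

Because keys preserves insertion order, the entries dropped on overflow are always the oldest ones.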
34 changes: 34 additions & 0 deletions reddit/helpers_test.go
@@ -0,0 +1,34 @@
package reddit

import (
	"testing"

	"github.com/stretchr/testify/require"
)

func TestNewOrderedMaxSet(t *testing.T) {
	set := NewOrderedMaxSet(1)
	set.Add("foo")
	set.Add("bar")

	require.Equal(t, 1, set.Len())
}

func TestOrderedMaxSetCollision(t *testing.T) {
	set := NewOrderedMaxSet(2)
	set.Add("foo")
	set.Add("foo")

	require.Equal(t, 1, set.Len())
}

func TestOrderedMaxSet_Delete(t *testing.T) {
	set := NewOrderedMaxSet(1)
	set.Add("foo")

	require.Equal(t, 1, set.Len())

	set.Delete("foo")
	require.Equal(t, 0, set.Len())
	require.False(t, set.Exists("foo"))
}
19 changes: 18 additions & 1 deletion reddit/listings.go
@@ -31,10 +31,27 @@ func (s *ListingsService) Get(ctx context.Context, ids ...string) ([]*Post, []*C

// GetPosts returns posts from their IDs; the t3_ type prefix is added to
// each ID automatically.
func (s *ListingsService) GetPosts(ctx context.Context, ids ...string) ([]*Post, *Response, error) {
	convertedIDs := make([]string, 0, len(ids))
	for _, id := range ids {
		convertedIDs = append(convertedIDs, "t3_"+id)
	}
	path := fmt.Sprintf("by_id/%s", strings.Join(convertedIDs, ","))
	l, resp, err := s.client.getListing(ctx, path, nil)
	if err != nil {
		return nil, resp, err
	}
	return l.Posts(), resp, nil
}

// GetComments returns comments from their IDs; the t1_ type prefix is added
// to each ID automatically.
func (s *ListingsService) GetComments(ctx context.Context, ids ...string) ([]*Comment, *Response, error) {
	convertedIDs := make([]string, 0, len(ids))
	for _, id := range ids {
		convertedIDs = append(convertedIDs, "t1_"+id)
	}
	path := fmt.Sprintf("api/info?id=%s", strings.Join(convertedIDs, ","))
	l, resp, err := s.client.getListing(ctx, path, nil)
	if err != nil {
		return nil, resp, err
	}
	return l.Comments(), resp, nil
}
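A quick sketch of calling the new GetComments method, assuming the library's usual read-only client constructor; the IDs shown are hypothetical:

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/vartanbeno/go-reddit/v2/reddit" // assumed module path
)

func main() {
	client, err := reddit.NewReadonlyClient()
	if err != nil {
		log.Fatal(err)
	}

	// Short IDs, not fullnames: the t1_ prefix is added inside GetComments.
	comments, _, err := client.Listings.GetComments(context.Background(), "abc123", "def456")
	if err != nil {
		log.Fatal(err)
	}
	for _, comment := range comments {
		fmt.Println(comment.FullID, comment.Body)
	}
}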
203 changes: 153 additions & 50 deletions reddit/stream.go
@@ -31,85 +31,188 @@ func (s *StreamService) Posts(subreddit string, opts ...StreamOpt) (<-chan *Post
	ticker := time.NewTicker(streamConfig.Interval)
	postsCh := make(chan *Post)
	errsCh := make(chan error)
	ctx, cancel := context.WithCancel(context.Background())

	// originally used the "before" parameter, but if that post gets deleted, subsequent requests
	// would just return empty listings; easier to just keep track of all post ids encountered
	ids := NewOrderedMaxSet(2000)

	go func() {
		// deferred calls run LIFO: wg.Wait executes first, so the channels are
		// only closed once every in-flight request goroutine has finished
		defer close(postsCh)
		defer close(errsCh)
		defer cancel()
		var wg sync.WaitGroup
		defer wg.Wait()
		var mutex sync.Mutex

		var n int
		infinite := streamConfig.MaxRequests == 0

		for ; ; <-ticker.C {
			select {
			case <-ctx.Done():
				ticker.Stop()
				return
			default:
			}

			n++
			wg.Add(1)
			go s.getPosts(ctx, subreddit, func(posts []*Post, err error) {
				defer wg.Done()

				if err != nil {
					select {
					case <-ctx.Done():
					default:
						errsCh <- err
					}
					return
				}

				for _, post := range posts {
					id := post.FullID

					// if this post id is already part of the set, it means that it and the ones
					// after it in the list have already been streamed, so break out of the loop
					if ids.Exists(id) {
						break
					}
					ids.Add(id)

					// consume the DiscardInitial flag at most once, guarded by the
					// mutex since request goroutines may run concurrently
					if func() bool {
						mutex.Lock()
						toReturn := false
						if streamConfig.DiscardInitial {
							streamConfig.DiscardInitial = false
							toReturn = true
						}
						mutex.Unlock()
						return toReturn
					}() {
						break
					}

					select {
					case <-ctx.Done():
						return
					default:
						postsCh <- post
					}
				}
			})

			if !infinite && n >= streamConfig.MaxRequests {
				break
			}
		}
	}()

	return postsCh, errsCh, cancel
}
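For context, a consumption sketch for the reworked Posts stream; it assumes the library's existing StreamInterval and StreamMaxRequests option helpers and the same module path as above:

package main

import (
	"fmt"
	"log"
	"time"

	"github.com/vartanbeno/go-reddit/v2/reddit" // assumed module path
)

func main() {
	client, err := reddit.NewReadonlyClient()
	if err != nil {
		log.Fatal(err)
	}

	// Poll five times, one second apart, then let the producer shut down.
	posts, errs, stop := client.Stream.Posts("golang",
		reddit.StreamInterval(time.Second),
		reddit.StreamMaxRequests(5),
	)
	defer stop()

	for posts != nil || errs != nil {
		select {
		case post, ok := <-posts:
			if !ok {
				posts = nil
				continue
			}
			fmt.Println(post.FullID, post.Title)
		case err, ok := <-errs:
			if !ok {
				errs = nil
				continue
			}
			log.Println("stream error:", err)
		}
	}
}

Because the producer defers wg.Wait before closing both channels, this drain loop exits cleanly once the final request has been delivered.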

// Comments streams comments from the specified subreddit.
// It returns 2 channels and a function:
// - a channel into which new comments will be sent
// - a channel into which any errors will be sent
// - a function that the client can call once to stop the streaming and close the channels
// Because of the 100 result limit imposed by Reddit when fetching comments, some high-traffic
// streams might drop comments between API requests, such as when streaming r/all.
func (s *StreamService) Comments(subreddit string, opts ...StreamOpt) (<-chan *Comment, <-chan error, func()) {
	streamConfig := &streamConfig{
		Interval:       defaultStreamInterval,
		DiscardInitial: false,
		MaxRequests:    0,
	}
	for _, opt := range opts {
		opt(streamConfig)
	}

	ticker := time.NewTicker(streamConfig.Interval)
	commentsCh := make(chan *Comment)
	errsCh := make(chan error)
	ctx, cancel := context.WithCancel(context.Background())

	ids := NewOrderedMaxSet(2000)

	go func() {
		defer close(commentsCh)
		defer close(errsCh)
		defer cancel()
		var wg sync.WaitGroup
		defer wg.Wait()
		var mutex sync.Mutex

		var n int
		infinite := streamConfig.MaxRequests == 0

		for ; ; <-ticker.C {
			select {
			case <-ctx.Done():
				ticker.Stop()
				return
			default:
			}

			n++
			wg.Add(1)
			go s.getComments(ctx, subreddit, func(comments []*Comment, err error) {
				defer wg.Done()

				if err != nil {
					select {
					case <-ctx.Done():
					default:
						errsCh <- err
					}
					return
				}

				for _, comment := range comments {
					id := comment.FullID

					// certain comment streams are inconsistent about the completeness of
					// returned comments; it's not enough to check whether we've seen older
					// comments, so every comment is checked individually
					if ids.Exists(id) {
						continue
					}
					ids.Add(id)

					if func() bool {
						mutex.Lock()
						toReturn := false
						if streamConfig.DiscardInitial {
							streamConfig.DiscardInitial = false
							toReturn = true
						}
						mutex.Unlock()
						return toReturn
					}() {
						break
					}

					select {
					case <-ctx.Done():
						return
					default:
						commentsCh <- comment
					}
				}
			})

			if !infinite && n >= streamConfig.MaxRequests {
				break
			}
		}
	}()

	return commentsCh, errsCh, cancel
}

func (s *StreamService) getPosts(ctx context.Context, subreddit string, cb func([]*Post, error)) {
	posts, _, err := s.client.Subreddit.NewPosts(ctx, subreddit, &ListOptions{Limit: 100})
	cb(posts, err)
}

func (s *StreamService) getComments(ctx context.Context, subreddit string, cb func([]*Comment, error)) {
	comments, _, err := s.client.Subreddit.Comments(ctx, subreddit, &ListOptions{Limit: 100})
	cb(comments, err)
}
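Finally, a sketch of the new stop semantics for the Comments stream: the returned function is now a context cancel rather than a sync.Once wrapper, so polling halts, in-flight requests drain, and the producer closes both channels (client setup assumed as above):

package main

import (
	"fmt"
	"log"
	"time"

	"github.com/vartanbeno/go-reddit/v2/reddit" // assumed module path
)

func main() {
	client, err := reddit.NewReadonlyClient()
	if err != nil {
		log.Fatal(err)
	}

	comments, errs, stop := client.Stream.Comments("askreddit")
	time.AfterFunc(30*time.Second, stop) // stop streaming after 30 seconds

	go func() {
		for err := range errs {
			log.Println("stream error:", err)
		}
	}()

	// ranging until the channel closes keeps draining, which also unblocks any
	// worker goroutine still waiting to send when stop is called
	for comment := range comments {
		fmt.Println(comment.FullID, comment.Body)
	}
}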