Skip to content

Commit

Permalink
Add flush method to writer
Browse files Browse the repository at this point in the history
  • Loading branch information
isaacd9 committed Mar 7, 2024
1 parent b2b17ac commit 987e665
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 12 deletions.
53 changes: 47 additions & 6 deletions writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,41 @@ func (w *Writer) spawn(f func()) {
}()
}

// Flush writes all currently buffered messages to the kafka cluster. This will
// block until all messages in the batch has been written to kafka, or until the
// context is canceled.
func (w *Writer) Flush(ctx context.Context) error {
w.mutex.Lock()

var wg sync.WaitGroup

// flush all writers
for _, writer := range w.writers {
w := writer
wg.Add(1)
go func() {
b := w.flush()
<-b.done
wg.Done()
}()
}

w.mutex.Unlock()
done := make(chan struct{})

go func() {
wg.Wait()
close(done)
}()

select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
}

// Close flushes pending writes, and waits for all writes to complete before
// returning. Calling Close also prevents new writes from being submitted to
// the writer, further calls to WriteMessages and the like will fail with
Expand Down Expand Up @@ -1184,17 +1219,23 @@ func (ptw *partitionWriter) writeBatch(batch *writeBatch) {
batch.complete(err)
}

func (ptw *partitionWriter) close() {
func (ptw *partitionWriter) flush() *writeBatch {
ptw.mutex.Lock()
defer ptw.mutex.Unlock()

if ptw.currBatch != nil {
batch := ptw.currBatch
ptw.queue.Put(batch)
ptw.currBatch = nil
batch.trigger()
if ptw.currBatch == nil {
return nil
}

batch := ptw.currBatch
ptw.queue.Put(batch)
ptw.currBatch = nil
batch.trigger()
return batch
}

func (ptw *partitionWriter) close() {
ptw.flush()
ptw.queue.Close()
}

Expand Down
76 changes: 70 additions & 6 deletions writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ func TestWriter(t *testing.T) {
scenario: "closing a writer right after creating it returns promptly with no error",
function: testWriterClose,
},

{
scenario: "writing 1 message through a writer using round-robin balancing produces 1 message to the first partition",
function: testWriterRoundRobin1,
Expand All @@ -130,6 +129,10 @@ func TestWriter(t *testing.T) {
scenario: "writing a batch of messages",
function: testWriterBatchSize,
},
{
scenario: "writing and flushing a batch of messages",
function: testsWriterFlush,
},

{
scenario: "writing messages with a small batch byte size",
Expand Down Expand Up @@ -450,7 +453,7 @@ func readPartition(topic string, partition int, offset int64) (msgs []Message, e
}
}

func testWriterBatchBytes(t *testing.T) {
func testsWriterFlush(t *testing.T) {
topic := makeTopic()
createTopic(t, topic, 1)
defer deleteTopic(t, topic)
Expand All @@ -461,10 +464,13 @@ func testWriterBatchBytes(t *testing.T) {
}

w := newTestWriter(WriterConfig{
Topic: topic,
BatchBytes: 50,
BatchTimeout: math.MaxInt32 * time.Second,
Topic: topic,
// Set the batch timeout to a large value to avoid the timeout
BatchSize: 1000,
BatchBytes: 1000000,
BatchTimeout: 1000 * time.Second,
Balancer: &RoundRobin{},
Async: true,
})
defer w.Close()

Expand All @@ -480,7 +486,65 @@ func testWriterBatchBytes(t *testing.T) {
return
}

if w.Stats().Writes != 2 {
if err := w.Flush(ctx); err != nil {
t.Errorf("flush error %v", err)
return
}

if w.Stats().Writes != 1 {
t.Error("didn't create expected batches")
return
}
msgs, err := readPartition(topic, 0, offset)
if err != nil {
t.Error("error reading partition", err)
return
}

if len(msgs) != 4 {
t.Error("bad messages in partition", msgs)
return
}

for i, m := range msgs {
if string(m.Value) == "M"+strconv.Itoa(i) {
continue
}
t.Error("bad messages in partition", string(m.Value))
}
}

func testWriterBatchBytes(t *testing.T) {
topic := makeTopic()
createTopic(t, topic, 1)
defer deleteTopic(t, topic)

offset, err := readOffset(topic, 0)
if err != nil {
t.Fatal(err)
}

w := newTestWriter(WriterConfig{
Topic: topic,
BatchBytes: 50,
BatchTimeout: math.MaxInt32 * time.Second,
Balancer: &RoundRobin{},
})
defer w.Close()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := w.WriteMessages(ctx, []Message{
{Value: []byte("M0")},
{Value: []byte("M1")},
{Value: []byte("M2")},
{Value: []byte("M3")},
}...); err != nil {
t.Error(err)
return
}

if w.Stats().Writes != 1 {
t.Error("didn't create expected batches")
return
}
Expand Down

0 comments on commit 987e665

Please sign in to comment.