Skip to content

Commit

Permalink
Change Batch API to be consistent with Query()
Browse files Browse the repository at this point in the history
Exec() method for batch was added & Query() method was refactored.
Batch for now behaves the same way as query.

patch by Oleksandr Luzhniy; reviewed by João Reis, Danylo Savchenko, Bohdan Siryk for CASSGO-7
  • Loading branch information
tengu-alt committed Nov 13, 2024
1 parent 7b7e6af commit 2c5b2be
Show file tree
Hide file tree
Showing 9 changed files with 63 additions and 34 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Don't restrict server authenticator unless PasswordAuthentictor.AllowedAuthenticators is provided (CASSGO-19)

- Change Batch API to be consistent with Query() (CASSGO-7)

### Fixed

## [1.7.0] - 2024-09-23
Expand Down
16 changes: 9 additions & 7 deletions batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ func TestBatch_Errors(t *testing.T) {
t.Fatal(err)
}

b := session.NewBatch(LoggedBatch)
b.Query("SELECT * FROM batch_errors WHERE id=2 AND val=?", nil)
if err := session.ExecuteBatch(b); err == nil {
b := session.Batch(LoggedBatch)
b = b.Query("SELECT * FROM gocql_test.batch_errors WHERE id=2 AND val=?", nil)
if err := b.Exec(); err == nil {
t.Fatal("expected to get error for invalid query in batch")
}
}
Expand All @@ -68,15 +68,17 @@ func TestBatch_WithTimestamp(t *testing.T) {

micros := time.Now().UnixNano()/1e3 - 1000

b := session.NewBatch(LoggedBatch)
b := session.Batch(LoggedBatch)
b.WithTimestamp(micros)
b.Query("INSERT INTO batch_ts (id, val) VALUES (?, ?)", 1, "val")
if err := session.ExecuteBatch(b); err != nil {
b = b.Query("INSERT INTO gocql_test.batch_ts (id, val) VALUES (?, ?)", 1, "val")
b = b.Query("INSERT INTO gocql_test.batch_ts (id, val) VALUES (?, ?)", 2, "val")

if err := b.Exec(); err != nil {
t.Fatal(err)
}

var storedTs int64
if err := session.Query(`SELECT writetime(val) FROM batch_ts WHERE id = ?`, 1).Scan(&storedTs); err != nil {
if err := session.Query(`SELECT writetime(val) FROM gocql_test.batch_ts WHERE id = ?`, 1).Scan(&storedTs); err != nil {
t.Fatal(err)
}

Expand Down
34 changes: 17 additions & 17 deletions cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ import (
"time"
"unicode"

inf "gopkg.in/inf.v0"
"gopkg.in/inf.v0"
)

func TestEmptyHosts(t *testing.T) {
Expand Down Expand Up @@ -453,15 +453,15 @@ func TestCAS(t *testing.T) {
t.Fatal("truncate:", err)
}

successBatch := session.NewBatch(LoggedBatch)
successBatch := session.Batch(LoggedBatch)
successBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title, revid, modified)
if applied, _, err := session.ExecuteBatchCAS(successBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil {
t.Fatal("insert:", err)
} else if !applied {
t.Fatalf("insert should have been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS)
}

successBatch = session.NewBatch(LoggedBatch)
successBatch = session.Batch(LoggedBatch)
successBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title+"_foo", revid, modified)
casMap := make(map[string]interface{})
if applied, _, err := session.MapExecuteBatchCAS(successBatch, casMap); err != nil {
Expand All @@ -470,22 +470,22 @@ func TestCAS(t *testing.T) {
t.Fatal("insert should have been applied")
}

failBatch := session.NewBatch(LoggedBatch)
failBatch := session.Batch(LoggedBatch)
failBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title, revid, modified)
if applied, _, err := session.ExecuteBatchCAS(successBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil {
t.Fatal("insert:", err)
} else if applied {
t.Fatalf("insert should have been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS)
}

insertBatch := session.NewBatch(LoggedBatch)
insertBatch := session.Batch(LoggedBatch)
insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 2c3af400-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))")
insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 3e4ad2f1-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))")
if err := session.ExecuteBatch(insertBatch); err != nil {
t.Fatal("insert:", err)
}

failBatch = session.NewBatch(LoggedBatch)
failBatch = session.Batch(LoggedBatch)
failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=2c3af400-73a4-11e5-9381-29463d90c3f0 IF last_modified=DATEOF(NOW());")
failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified=DATEOF(NOW());")
if applied, iter, err := session.ExecuteBatchCAS(failBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil {
Expand Down Expand Up @@ -610,7 +610,7 @@ func TestBatch(t *testing.T) {
t.Fatal("create table:", err)
}

batch := session.NewBatch(LoggedBatch)
batch := session.Batch(LoggedBatch)
for i := 0; i < 100; i++ {
batch.Query(`INSERT INTO batch_table (id) VALUES (?)`, i)
}
Expand Down Expand Up @@ -642,9 +642,9 @@ func TestUnpreparedBatch(t *testing.T) {

var batch *Batch
if session.cfg.ProtoVersion == 2 {
batch = session.NewBatch(CounterBatch)
batch = session.Batch(CounterBatch)
} else {
batch = session.NewBatch(UnloggedBatch)
batch = session.Batch(UnloggedBatch)
}

for i := 0; i < 100; i++ {
Expand Down Expand Up @@ -683,7 +683,7 @@ func TestBatchLimit(t *testing.T) {
t.Fatal("create table:", err)
}

batch := session.NewBatch(LoggedBatch)
batch := session.Batch(LoggedBatch)
for i := 0; i < 65537; i++ {
batch.Query(`INSERT INTO batch_table2 (id) VALUES (?)`, i)
}
Expand Down Expand Up @@ -737,7 +737,7 @@ func TestTooManyQueryArgs(t *testing.T) {
t.Fatal("'`SELECT * FROM too_many_query_args WHERE id = ?`, 1, 2' should return an error")
}

batch := session.NewBatch(UnloggedBatch)
batch := session.Batch(UnloggedBatch)
batch.Query("INSERT INTO too_many_query_args (id, value) VALUES (?, ?)", 1, 2, 3)
err = session.ExecuteBatch(batch)

Expand Down Expand Up @@ -769,7 +769,7 @@ func TestNotEnoughQueryArgs(t *testing.T) {
t.Fatal("'`SELECT * FROM not_enough_query_args WHERE id = ? and cluster = ?`, 1' should return an error")
}

batch := session.NewBatch(UnloggedBatch)
batch := session.Batch(UnloggedBatch)
batch.Query("INSERT INTO not_enough_query_args (id, cluster, value) VALUES (?, ?, ?)", 1, 2)
err = session.ExecuteBatch(batch)

Expand Down Expand Up @@ -1342,7 +1342,7 @@ func TestBatchQueryInfo(t *testing.T) {
return values, nil
}

batch := session.NewBatch(LoggedBatch)
batch := session.Batch(LoggedBatch)
batch.Bind("INSERT INTO batch_query_info (id, cluster, value) VALUES (?, ?,?)", write)

if err := session.ExecuteBatch(batch); err != nil {
Expand Down Expand Up @@ -1470,7 +1470,7 @@ func TestPrepare_ReprepareBatch(t *testing.T) {
}

stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement_batch")
batch := session.NewBatch(UnloggedBatch)
batch := session.Batch(UnloggedBatch)
batch.Query(stmt, "bar")
if err := conn.executeBatch(ctx, batch).Close(); err != nil {
t.Fatalf("Failed to execute query for reprepare statement: %v", err)
Expand Down Expand Up @@ -1854,7 +1854,7 @@ func TestBatchStats(t *testing.T) {
t.Fatalf("failed to create table with error '%v'", err)
}

b := session.NewBatch(LoggedBatch)
b := session.Batch(LoggedBatch)
b.Query("INSERT INTO batchStats (id) VALUES (?)", 1)
b.Query("INSERT INTO batchStats (id) VALUES (?)", 2)

Expand Down Expand Up @@ -1897,7 +1897,7 @@ func TestBatchObserve(t *testing.T) {

var observedBatch *observation

batch := session.NewBatch(LoggedBatch)
batch := session.Batch(LoggedBatch)
batch.Observer(funcBatchObserver(func(ctx context.Context, o ObservedBatch) {
if observedBatch != nil {
t.Fatal("batch observe called more than once")
Expand Down Expand Up @@ -3236,7 +3236,7 @@ func TestUnsetColBatch(t *testing.T) {
t.Fatalf("failed to create table with error '%v'", err)
}

b := session.NewBatch(LoggedBatch)
b := session.Batch(LoggedBatch)
b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 1, 1, UnsetValue)
b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 1, UnsetValue, "")
b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 2, 2, UnsetValue)
Expand Down
2 changes: 1 addition & 1 deletion doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@
// # Batches
//
// The CQL protocol supports sending batches of DML statements (INSERT/UPDATE/DELETE) and so does gocql.
// Use Session.NewBatch to create a new batch and then fill-in details of individual queries.
// Use Session.Batch to create a new batch and then fill-in details of individual queries.
// Then execute the batch with Session.ExecuteBatch.
//
// Logged batches ensure atomicity, either all or none of the operations in the batch will succeed, but they have
Expand Down
14 changes: 12 additions & 2 deletions example_batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (
"fmt"
"log"

gocql "github.com/gocql/gocql"
"github.com/gocql/gocql"
)

// Example_batch demonstrates how to execute a batch of statements.
Expand All @@ -49,7 +49,7 @@ func Example_batch() {

ctx := context.Background()

b := session.NewBatch(gocql.UnloggedBatch).WithContext(ctx)
b := session.Batch(gocql.UnloggedBatch).WithContext(ctx)
b.Entries = append(b.Entries, gocql.BatchEntry{
Stmt: "INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)",
Args: []interface{}{1, 2, "1.2"},
Expand All @@ -60,11 +60,19 @@ func Example_batch() {
Args: []interface{}{1, 3, "1.3"},
Idempotent: true,
})

err = session.ExecuteBatch(b)
if err != nil {
log.Fatal(err)
}

err = b.Query("INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)", 1, 4, "1.4").
Query("INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)", 1, 5, "1.5").
Exec()
if err != nil {
log.Fatal(err)
}

scanner := session.Query("SELECT pk, ck, description FROM example.batches").Iter().Scanner()
for scanner.Next() {
var pk, ck int32
Expand All @@ -77,4 +85,6 @@ func Example_batch() {
}
// 1 2 1.2
// 1 3 1.3
// 1 4 1.4
// 1 5 1.5
}
4 changes: 2 additions & 2 deletions example_lwt_batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (
"fmt"
"log"

gocql "github.com/gocql/gocql"
"github.com/gocql/gocql"
)

// ExampleSession_MapExecuteBatchCAS demonstrates how to execute a batch lightweight transaction.
Expand Down Expand Up @@ -62,7 +62,7 @@ func ExampleSession_MapExecuteBatchCAS() {
}

executeBatch := func(ck2Version int) {
b := session.NewBatch(gocql.LoggedBatch)
b := session.Batch(gocql.LoggedBatch)
b.Entries = append(b.Entries, gocql.BatchEntry{
Stmt: "UPDATE my_lwt_batch_table SET value=? WHERE pk=? AND ck=? IF version=?",
Args: []interface{}{"b", "pk1", "ck1", 1},
Expand Down
2 changes: 1 addition & 1 deletion integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ func TestCustomPayloadMessages(t *testing.T) {
iter.Close()

// Batch Message
b := session.NewBatch(LoggedBatch)
b := session.Batch(LoggedBatch)
b.CustomPayload = customPayload
b.Query("INSERT INTO testCustomPayloadMessages(id,value) VALUES(1, 1)")
if err := session.ExecuteBatch(b); err != nil {
Expand Down
17 changes: 16 additions & 1 deletion session.go
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,13 @@ func (b *Batch) execute(ctx context.Context, conn *Conn) *Iter {
return conn.executeBatch(ctx, b)
}

// Exec executes a batch operation and returns nil if successful
// otherwise an error is returned describing the failure.
func (b *Batch) Exec() error {
iter := b.session.executeBatch(b)
return iter.Close()
}

func (s *Session) executeBatch(batch *Batch) *Iter {
// fail fast
if s.Closed() {
Expand Down Expand Up @@ -1760,7 +1767,14 @@ func NewBatch(typ BatchType) *Batch {
}

// NewBatch creates a new batch operation using defaults defined in the cluster
//
// Deprecated: use session.Batch instead
func (s *Session) NewBatch(typ BatchType) *Batch {
return s.Batch(typ)
}

// Batch creates a new batch operation using defaults defined in the cluster
func (s *Session) Batch(typ BatchType) *Batch {
s.mu.RLock()
batch := &Batch{
Type: typ,
Expand Down Expand Up @@ -1860,8 +1874,9 @@ func (b *Batch) SpeculativeExecutionPolicy(sp SpeculativeExecutionPolicy) *Batch
}

// Query adds the query to the batch operation
func (b *Batch) Query(stmt string, args ...interface{}) {
func (b *Batch) Query(stmt string, args ...interface{}) *Batch {
b.Entries = append(b.Entries, BatchEntry{Stmt: stmt, Args: args})
return b
}

// Bind adds the query to the batch operation and correlates it with a binding callback
Expand Down
6 changes: 3 additions & 3 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func TestSessionAPI(t *testing.T) {
t.Fatalf("expected itr.err to be '%v', got '%v'", ErrNoConnections, itr.err)
}

testBatch := s.NewBatch(LoggedBatch)
testBatch := s.Batch(LoggedBatch)
testBatch.Query("test")
err := s.ExecuteBatch(testBatch)

Expand Down Expand Up @@ -219,15 +219,15 @@ func TestBatchBasicAPI(t *testing.T) {
s.pool = cfg.PoolConfig.buildPool(s)

// Test UnloggedBatch
b := s.NewBatch(UnloggedBatch)
b := s.Batch(UnloggedBatch)
if b.Type != UnloggedBatch {
t.Fatalf("expceted batch.Type to be '%v', got '%v'", UnloggedBatch, b.Type)
} else if b.rt != cfg.RetryPolicy {
t.Fatalf("expceted batch.RetryPolicy to be '%v', got '%v'", cfg.RetryPolicy, b.rt)
}

// Test LoggedBatch
b = s.NewBatch(LoggedBatch)
b = s.Batch(LoggedBatch)
if b.Type != LoggedBatch {
t.Fatalf("expected batch.Type to be '%v', got '%v'", LoggedBatch, b.Type)
}
Expand Down

0 comments on commit 2c5b2be

Please sign in to comment.