Skip to content

Commit

Permalink
some refactoring and more error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
taniabogatsch committed Oct 30, 2024
1 parent 156a1b4 commit c455ab2
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 73 deletions.
82 changes: 36 additions & 46 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,87 +28,77 @@ func (c *conn) CheckNamedValue(nv *driver.NamedValue) error {
return driver.ErrSkip
}

func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
func (c *conn) prepareStmts(ctx context.Context, query string) (*stmt, error) {
if c.closed {
panic("database/sql/driver: misuse of duckdb driver: ExecContext after Close")
return nil, getError(errClosedCon, nil)
}

stmts, size, err := c.extractStmts(query)
if err != nil {
return nil, err
}
stmts, count, errExtract := c.extractStmts(query)
defer C.duckdb_destroy_extracted(&stmts)
if errExtract != nil {
return nil, errExtract
}

// execute all statements without args, except the last one
for i := C.idx_t(0); i < size-1; i++ {
stmt, err := c.prepareExtractedStmt(stmts, i)
for i := C.idx_t(0); i < count-1; i++ {
prepared, err := c.prepareExtractedStmt(stmts, i)
if err != nil {
return nil, err
}
// send nil args to execute statement and ignore result
_, err = stmt.ExecContext(ctx, nil)
stmt.Close()
if err != nil {

// Execute the statement without any arguments and ignore the result.
if _, err = prepared.ExecContext(ctx, nil); err != nil {
return nil, err
}
if err = prepared.Close(); err != nil {
return nil, err
}
}

// prepare and execute last statement with args and return result
stmt, err := c.prepareExtractedStmt(stmts, size-1)
prepared, err := c.prepareExtractedStmt(stmts, count-1)
if err != nil {
return nil, err
}
defer stmt.Close()
return stmt.ExecContext(ctx, args)
return prepared, nil
}

func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
if c.closed {
panic("database/sql/driver: misuse of duckdb driver: QueryContext after Close")
}

stmts, size, err := c.extractStmts(query)
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
prepared, err := c.prepareStmts(ctx, query)
if err != nil {
return nil, err
}
defer C.duckdb_destroy_extracted(&stmts)

// execute all statements without args, except the last one
for i := C.idx_t(0); i < size-1; i++ {
stmt, err := c.prepareExtractedStmt(stmts, i)
if err != nil {
return nil, err
}
// send nil args to execute statement and ignore result (using ExecContext since we're ignoring the result anyway)
_, err = stmt.ExecContext(ctx, nil)
stmt.Close()
if err != nil {
return nil, err
}
res, err := prepared.ExecContext(ctx, args)
errClose := prepared.Close()
if err != nil {
err = errors.Join(err, errClose)
return nil, err
}
return res, errClose
}

// prepare and execute last statement with args and return result
stmt, err := c.prepareExtractedStmt(stmts, size-1)
func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
prepared, err := c.prepareStmts(ctx, query)
if err != nil {
return nil, err
}

rows, err := stmt.QueryContext(ctx, args)
r, err := prepared.QueryContext(ctx, args)
if err != nil {
stmt.Close()
errClose := prepared.Close()
err = errors.Join(err, errClose)
return nil, err
}

// we can't close the statement before the query result rows are closed
stmt.closeOnRowsClose = true
return rows, err
// We must close the prepared statement after closing the rows r.
prepared.closeOnRowsClose = true
return r, err
}

func (c *conn) Prepare(cmd string) (driver.Stmt, error) {
func (c *conn) Prepare(query string) (driver.Stmt, error) {
if c.closed {
panic("database/sql/driver: misuse of duckdb driver: Prepare after Close")
return nil, getError(errClosedCon, nil)
}
return c.prepareStmt(cmd)
return c.prepareStmt(query)
}

// Deprecated: Use BeginTx instead.
Expand Down
2 changes: 1 addition & 1 deletion errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func TestErrNestedMap(t *testing.T) {
db := openDB(t)

var m Map
err := db.QueryRow("SELECT MAP([MAP([1], [1]), MAP([2], [2])], ['a', 'e'])").Scan(&m)
err := db.QueryRow(`SELECT MAP([MAP([1], [1]), MAP([2], [2])], ['a', 'e'])`).Scan(&m)
testError(t, err, errUnsupportedMapKeyType.Error())
require.NoError(t, db.Close())
}
Expand Down
72 changes: 46 additions & 26 deletions statement_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,92 +3,112 @@ package duckdb
import (
"context"
"database/sql"
"errors"
"testing"

"github.com/stretchr/testify/require"
)

func TestPrepareQueryAutoIncrement(t *testing.T) {
func TestPrepareQuery(t *testing.T) {
db := openDB(t)
defer db.Close()
createFooTable(db, t)

stmt, err := db.Prepare("SELECT * FROM foo WHERE baz=?")
prepared, err := db.Prepare(`SELECT * FROM foo WHERE baz = ?`)
require.NoError(t, err)
res, err := prepared.Query(0)
require.NoError(t, err)

require.NoError(t, res.Close())
require.NoError(t, prepared.Close())

// Prepare on a connection.
c, err := db.Conn(context.Background())
require.NoError(t, err)
defer stmt.Close()

rows, err := stmt.Query(0)
prepared, err = c.PrepareContext(context.Background(), `SELECT * FROM foo WHERE baz = ?`)
require.NoError(t, err)
res, err = prepared.Query(0)
require.NoError(t, err)
defer rows.Close()

require.NoError(t, res.Close())
require.NoError(t, prepared.Close())
require.NoError(t, c.Close())
require.NoError(t, db.Close())
}

func TestPrepareQueryPositional(t *testing.T) {
db := openDB(t)
defer db.Close()
createFooTable(db, t)

stmt, err := db.Prepare("SELECT $1, $2 as foo WHERE foo=$2")
prepared, err := db.Prepare(`SELECT $1, $2 AS foo WHERE foo = $2`)
require.NoError(t, err)
defer stmt.Close()

var foo, bar int
row := stmt.QueryRow(1, 2)
row := prepared.QueryRow(1, 2)
require.NoError(t, err)

err = row.Scan(&foo, &bar)
require.NoError(t, err)
require.Equal(t, 1, foo)
require.Equal(t, 2, bar)

require.NoError(t, prepared.Close())
require.NoError(t, db.Close())
}

func TestPrepareQueryNamed(t *testing.T) {
db := openDB(t)
defer db.Close()
createFooTable(db, t)

stmt, err := db.PrepareContext(context.Background(), "SELECT $foo, $bar, $baz, $foo")
prepared, err := db.PrepareContext(context.Background(), `SELECT $foo, $bar, $baz, $foo`)
require.NoError(t, err)
defer stmt.Close()

var foo, bar, foo2 int
var baz string
err = stmt.QueryRow(sql.Named("baz", "x"), sql.Named("foo", 1), sql.Named("bar", 2)).Scan(&foo, &bar, &baz, &foo2)
row := prepared.QueryRow(sql.Named("baz", "x"), sql.Named("foo", 1), sql.Named("bar", 2))
err = row.Scan(&foo, &bar, &baz, &foo2)
require.NoError(t, err)
if foo != 1 || bar != 2 || baz != "x" || foo2 != 1 {
require.Fail(t, "bad values: %d %d %s %d", foo, bar, baz, foo2)
}
require.Equal(t, 1, foo)
require.Equal(t, 2, bar)
require.Equal(t, "x", baz)
require.Equal(t, 1, foo2)

require.NoError(t, prepared.Close())
require.NoError(t, db.Close())
}

func TestPrepareWithError(t *testing.T) {
db := openDB(t)
defer db.Close()
createFooTable(db, t)

testCases := []struct {
tpl string
sql string
err string
}{
{
tpl: "SELECT * FROM tbl WHERE baz=?",
err: "Table with name tbl does not exist",
sql: `SELECT * FROM tbl WHERE baz = ?`,
err: `Table with name tbl does not exist`,
},
{
tpl: "SELECT * FROM foo WHERE col=?",
sql: `SELECT * FROM foo WHERE col = ?`,
err: `Referenced column "col" not found in FROM clause`,
},
{
tpl: "SELECT * FROM foo col=?",
sql: `SELECT * FROM foo col = ?`,
err: `syntax error at or near "="`,
},
}
for _, tc := range testCases {
stmt, err := db.Prepare(tc.tpl)
prepared, err := db.Prepare(tc.sql)
if err != nil {
if _, ok := err.(*Error); !ok {
var dbErr *Error
if !errors.As(err, &dbErr) {
require.Fail(t, "error type is not (*duckdb.Error)")
}
require.ErrorContains(t, err, tc.err)
continue
}
defer stmt.Close()
require.NoError(t, prepared.Close())
}
require.NoError(t, db.Close())
}

0 comments on commit c455ab2

Please sign in to comment.