From c455ab27a0c064bdc903f4927534b88e9392ba53 Mon Sep 17 00:00:00 2001 From: taniabogatsch <44262898+taniabogatsch@users.noreply.github.com> Date: Wed, 30 Oct 2024 17:02:31 +0100 Subject: [PATCH] some refactoring and more error handling --- connection.go | 82 +++++++++++++++++++++-------------------------- errors_test.go | 2 +- statement_test.go | 72 ++++++++++++++++++++++++++--------------- 3 files changed, 83 insertions(+), 73 deletions(-) diff --git a/connection.go b/connection.go index 9b127ddc..f5f9cb3f 100644 --- a/connection.go +++ b/connection.go @@ -28,87 +28,77 @@ func (c *conn) CheckNamedValue(nv *driver.NamedValue) error { return driver.ErrSkip } -func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { +func (c *conn) prepareStmts(ctx context.Context, query string) (*stmt, error) { if c.closed { - panic("database/sql/driver: misuse of duckdb driver: ExecContext after Close") + return nil, getError(errClosedCon, nil) } - stmts, size, err := c.extractStmts(query) - if err != nil { - return nil, err - } + stmts, count, errExtract := c.extractStmts(query) defer C.duckdb_destroy_extracted(&stmts) + if errExtract != nil { + return nil, errExtract + } - // execute all statements without args, except the last one - for i := C.idx_t(0); i < size-1; i++ { - stmt, err := c.prepareExtractedStmt(stmts, i) + for i := C.idx_t(0); i < count-1; i++ { + prepared, err := c.prepareExtractedStmt(stmts, i) if err != nil { return nil, err } - // send nil args to execute statement and ignore result - _, err = stmt.ExecContext(ctx, nil) - stmt.Close() - if err != nil { + + // Execute the statement without any arguments and ignore the result. + if _, err = prepared.ExecContext(ctx, nil); err != nil { + return nil, err + } + if err = prepared.Close(); err != nil { return nil, err } } - // prepare and execute last statement with args and return result - stmt, err := c.prepareExtractedStmt(stmts, size-1) + prepared, err := c.prepareExtractedStmt(stmts, count-1) if err != nil { return nil, err } - defer stmt.Close() - return stmt.ExecContext(ctx, args) + return prepared, nil } -func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - if c.closed { - panic("database/sql/driver: misuse of duckdb driver: QueryContext after Close") - } - - stmts, size, err := c.extractStmts(query) +func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + prepared, err := c.prepareStmts(ctx, query) if err != nil { return nil, err } - defer C.duckdb_destroy_extracted(&stmts) - // execute all statements without args, except the last one - for i := C.idx_t(0); i < size-1; i++ { - stmt, err := c.prepareExtractedStmt(stmts, i) - if err != nil { - return nil, err - } - // send nil args to execute statement and ignore result (using ExecContext since we're ignoring the result anyway) - _, err = stmt.ExecContext(ctx, nil) - stmt.Close() - if err != nil { - return nil, err - } + res, err := prepared.ExecContext(ctx, args) + errClose := prepared.Close() + if err != nil { + err = errors.Join(err, errClose) + return nil, err } + return res, errClose +} - // prepare and execute last statement with args and return result - stmt, err := c.prepareExtractedStmt(stmts, size-1) +func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + prepared, err := c.prepareStmts(ctx, query) if err != nil { return nil, err } - rows, err := stmt.QueryContext(ctx, args) + r, err := prepared.QueryContext(ctx, args) if err != nil { - stmt.Close() + errClose := prepared.Close() + err = errors.Join(err, errClose) return nil, err } - // we can't close the statement before the query result rows are closed - stmt.closeOnRowsClose = true - return rows, err + // We must close the prepared statement after closing the rows r. + prepared.closeOnRowsClose = true + return r, err } -func (c *conn) Prepare(cmd string) (driver.Stmt, error) { +func (c *conn) Prepare(query string) (driver.Stmt, error) { if c.closed { - panic("database/sql/driver: misuse of duckdb driver: Prepare after Close") + return nil, getError(errClosedCon, nil) } - return c.prepareStmt(cmd) + return c.prepareStmt(query) } // Deprecated: Use BeginTx instead. diff --git a/errors_test.go b/errors_test.go index 9d7ad5d7..c7df352d 100644 --- a/errors_test.go +++ b/errors_test.go @@ -51,7 +51,7 @@ func TestErrNestedMap(t *testing.T) { db := openDB(t) var m Map - err := db.QueryRow("SELECT MAP([MAP([1], [1]), MAP([2], [2])], ['a', 'e'])").Scan(&m) + err := db.QueryRow(`SELECT MAP([MAP([1], [1]), MAP([2], [2])], ['a', 'e'])`).Scan(&m) testError(t, err, errUnsupportedMapKeyType.Error()) require.NoError(t, db.Close()) } diff --git a/statement_test.go b/statement_test.go index 5ebca8a1..09db669e 100644 --- a/statement_test.go +++ b/statement_test.go @@ -3,92 +3,112 @@ package duckdb import ( "context" "database/sql" + "errors" "testing" "github.com/stretchr/testify/require" ) -func TestPrepareQueryAutoIncrement(t *testing.T) { +func TestPrepareQuery(t *testing.T) { db := openDB(t) - defer db.Close() createFooTable(db, t) - stmt, err := db.Prepare("SELECT * FROM foo WHERE baz=?") + prepared, err := db.Prepare(`SELECT * FROM foo WHERE baz = ?`) + require.NoError(t, err) + res, err := prepared.Query(0) + require.NoError(t, err) + + require.NoError(t, res.Close()) + require.NoError(t, prepared.Close()) + + // Prepare on a connection. + c, err := db.Conn(context.Background()) require.NoError(t, err) - defer stmt.Close() - rows, err := stmt.Query(0) + prepared, err = c.PrepareContext(context.Background(), `SELECT * FROM foo WHERE baz = ?`) + require.NoError(t, err) + res, err = prepared.Query(0) require.NoError(t, err) - defer rows.Close() + + require.NoError(t, res.Close()) + require.NoError(t, prepared.Close()) + require.NoError(t, c.Close()) + require.NoError(t, db.Close()) } func TestPrepareQueryPositional(t *testing.T) { db := openDB(t) - defer db.Close() createFooTable(db, t) - stmt, err := db.Prepare("SELECT $1, $2 as foo WHERE foo=$2") + prepared, err := db.Prepare(`SELECT $1, $2 AS foo WHERE foo = $2`) require.NoError(t, err) - defer stmt.Close() var foo, bar int - row := stmt.QueryRow(1, 2) + row := prepared.QueryRow(1, 2) require.NoError(t, err) err = row.Scan(&foo, &bar) require.NoError(t, err) require.Equal(t, 1, foo) require.Equal(t, 2, bar) + + require.NoError(t, prepared.Close()) + require.NoError(t, db.Close()) } func TestPrepareQueryNamed(t *testing.T) { db := openDB(t) - defer db.Close() createFooTable(db, t) - stmt, err := db.PrepareContext(context.Background(), "SELECT $foo, $bar, $baz, $foo") + prepared, err := db.PrepareContext(context.Background(), `SELECT $foo, $bar, $baz, $foo`) require.NoError(t, err) - defer stmt.Close() + var foo, bar, foo2 int var baz string - err = stmt.QueryRow(sql.Named("baz", "x"), sql.Named("foo", 1), sql.Named("bar", 2)).Scan(&foo, &bar, &baz, &foo2) + row := prepared.QueryRow(sql.Named("baz", "x"), sql.Named("foo", 1), sql.Named("bar", 2)) + err = row.Scan(&foo, &bar, &baz, &foo2) require.NoError(t, err) - if foo != 1 || bar != 2 || baz != "x" || foo2 != 1 { - require.Fail(t, "bad values: %d %d %s %d", foo, bar, baz, foo2) - } + require.Equal(t, 1, foo) + require.Equal(t, 2, bar) + require.Equal(t, "x", baz) + require.Equal(t, 1, foo2) + + require.NoError(t, prepared.Close()) + require.NoError(t, db.Close()) } func TestPrepareWithError(t *testing.T) { db := openDB(t) - defer db.Close() createFooTable(db, t) testCases := []struct { - tpl string + sql string err string }{ { - tpl: "SELECT * FROM tbl WHERE baz=?", - err: "Table with name tbl does not exist", + sql: `SELECT * FROM tbl WHERE baz = ?`, + err: `Table with name tbl does not exist`, }, { - tpl: "SELECT * FROM foo WHERE col=?", + sql: `SELECT * FROM foo WHERE col = ?`, err: `Referenced column "col" not found in FROM clause`, }, { - tpl: "SELECT * FROM foo col=?", + sql: `SELECT * FROM foo col = ?`, err: `syntax error at or near "="`, }, } for _, tc := range testCases { - stmt, err := db.Prepare(tc.tpl) + prepared, err := db.Prepare(tc.sql) if err != nil { - if _, ok := err.(*Error); !ok { + var dbErr *Error + if !errors.As(err, &dbErr) { require.Fail(t, "error type is not (*duckdb.Error)") } require.ErrorContains(t, err, tc.err) continue } - defer stmt.Close() + require.NoError(t, prepared.Close()) } + require.NoError(t, db.Close()) }