From 43c7d13306d94f10d0e3b3717067c78debb8c7b8 Mon Sep 17 00:00:00 2001 From: Fan Yang Date: Mon, 16 Dec 2024 19:17:14 +0800 Subject: [PATCH] Expose `Bind` and `(Query|Exec)Bound` on `Stmt` for advanced usage --- errors.go | 3 +++ statement.go | 66 ++++++++++++++++++++++++++++++++++++++++++++--- statement_test.go | 54 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 120 insertions(+), 3 deletions(-) diff --git a/errors.go b/errors.go index 2e03bc7d..57dcfd1e 100644 --- a/errors.go +++ b/errors.go @@ -106,6 +106,9 @@ var ( errPrepare = errors.New("could not prepare query") errMissingPrepareContext = errors.New("missing context for multi-statement query: try using PrepareContext") errEmptyQuery = errors.New("empty query") + errCouldNotBind = errors.New("could not bind parameter") + errActiveRows = errors.New("ExecContext or QueryContext with active Rows") + errNotBound = errors.New("parameters have not been bound") errBeginTx = errors.New("could not begin transaction") errMultipleTx = errors.New("multiple transactions") errReadOnlyTxNotSupported = errors.New("read-only transactions are not supported") diff --git a/statement.go b/statement.go index b17f71f8..641c85ad 100644 --- a/statement.go +++ b/statement.go @@ -53,6 +53,7 @@ type Stmt struct { c *Conn stmt *C.duckdb_prepared_statement closeOnRowsClose bool + bound bool closed bool rows bool } @@ -131,6 +132,18 @@ func (s *Stmt) StatementType() (StmtType, error) { return StmtType(C.duckdb_prepared_statement_type(*s.stmt)), nil } +// Bind binds the parameters to the statement. +// WARNING: This is a low-level API and should be used with caution. +func (s *Stmt) Bind(args []driver.NamedValue) error { + if s.closed { + return errors.Join(errCouldNotBind, errClosedStmt) + } + if s.stmt == nil { + return errors.Join(errCouldNotBind, errUninitializedStmt) + } + return s.bind(args) +} + func (s *Stmt) bind(args []driver.NamedValue) error { if s.NumInput() > len(args) { return fmt.Errorf("incorrect argument count for command: have %d want %d", len(args), s.NumInput()) @@ -258,6 +271,7 @@ func (s *Stmt) bind(args []driver.NamedValue) error { } } + s.bound = true return nil } @@ -279,6 +293,30 @@ func (s *Stmt) ExecContext(ctx context.Context, nargs []driver.NamedValue) (driv return &result{ra}, nil } +// ExecBound executes a bound query that doesn't return rows, such as an INSERT or UPDATE. +// It can only be used after Bind has been called. +// WARNING: This is a low-level API and should be used with caution. +func (s *Stmt) ExecBound(ctx context.Context) (driver.Result, error) { + if s.closed { + return nil, errClosedCon + } + if s.rows { + return nil, errActiveRows + } + if !s.bound { + return nil, errNotBound + } + + res, err := s.executeBound(ctx) + if err != nil { + return nil, err + } + defer C.duckdb_destroy_result(res) + + ra := int64(C.duckdb_value_int64(res, 0, 0)) + return &result{ra}, nil +} + // Deprecated: Use QueryContext instead. func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) { return s.QueryContext(context.Background(), argsToNamedArgs(args)) @@ -295,6 +333,28 @@ func (s *Stmt) QueryContext(ctx context.Context, nargs []driver.NamedValue) (dri return newRowsWithStmt(*res, s), nil } +// QueryBound executes a bound query that may return rows, such as a SELECT. +// It can only be used after Bind has been called. +// WARNING: This is a low-level API and should be used with caution. +func (s *Stmt) QueryBound(ctx context.Context) (driver.Rows, error) { + if s.closed { + return nil, errClosedCon + } + if s.rows { + return nil, errActiveRows + } + if !s.bound { + return nil, errNotBound + } + + res, err := s.executeBound(ctx) + if err != nil { + return nil, err + } + s.rows = true + return newRowsWithStmt(*res, s), nil +} + // This method executes the query in steps and checks if context is cancelled before executing each step. // It uses Pending Result Interface C APIs to achieve this. Reference - https://duckdb.org/docs/api/c/api#pending-result-interface func (s *Stmt) execute(ctx context.Context, args []driver.NamedValue) (*C.duckdb_result, error) { @@ -304,11 +364,13 @@ func (s *Stmt) execute(ctx context.Context, args []driver.NamedValue) (*C.duckdb if s.rows { panic("database/sql/driver: misuse of duckdb driver: ExecContext or QueryContext with active Rows") } - if err := s.bind(args); err != nil { return nil, err } + return s.executeBound(ctx) +} +func (s *Stmt) executeBound(ctx context.Context) (*C.duckdb_result, error) { var pendingRes C.duckdb_pending_result if state := C.duckdb_pending_prepared(*s.stmt, &pendingRes); state == C.DuckDBError { dbErr := getDuckDBError(C.GoString(C.duckdb_pending_error(pendingRes))) @@ -360,5 +422,3 @@ func argsToNamedArgs(values []driver.Value) []driver.NamedValue { } return args } - -var errCouldNotBind = errors.New("could not bind parameter") diff --git a/statement_test.go b/statement_test.go index 6f910fd6..bb5b0c60 100644 --- a/statement_test.go +++ b/statement_test.go @@ -3,6 +3,7 @@ package duckdb import ( "context" "database/sql" + "database/sql/driver" "errors" "testing" @@ -56,6 +57,18 @@ func TestPrepareQuery(t *testing.T) { require.ErrorContains(t, err, paramIndexErrMsg) require.Equal(t, TYPE_INVALID, paramType) + rows, err := stmt.QueryBound(context.Background()) + require.Nil(t, rows) + require.ErrorIs(t, err, errNotBound) + + err = stmt.Bind([]driver.NamedValue{{Ordinal: 1, Value: 0}}) + require.NoError(t, err) + + rows, err = stmt.QueryBound(context.Background()) + require.NoError(t, err) + require.NotNil(t, rows) + require.NoError(t, rows.Close()) + require.NoError(t, stmt.Close()) stmtType, err = stmt.StatementType() @@ -66,6 +79,10 @@ func TestPrepareQuery(t *testing.T) { require.ErrorIs(t, err, errClosedStmt) require.Equal(t, TYPE_INVALID, paramType) + err = stmt.Bind([]driver.NamedValue{{Ordinal: 1, Value: 0}}) + require.ErrorIs(t, err, errCouldNotBind) + require.ErrorIs(t, err, errClosedStmt) + return nil }) require.NoError(t, err) @@ -146,6 +163,17 @@ func TestPrepareQueryPositional(t *testing.T) { require.ErrorContains(t, err, paramIndexErrMsg) require.Equal(t, TYPE_INVALID, paramType) + result, err := stmt.ExecBound(context.Background()) + require.Nil(t, result) + require.ErrorIs(t, err, errNotBound) + + err = stmt.Bind([]driver.NamedValue{{Ordinal: 1, Value: 0}, {Ordinal: 2, Value: "hello"}}) + require.NoError(t, err) + + result, err = stmt.ExecBound(context.Background()) + require.NoError(t, err) + require.NotNil(t, result) + require.NoError(t, stmt.Close()) stmtType, err = stmt.StatementType() @@ -160,6 +188,10 @@ func TestPrepareQueryPositional(t *testing.T) { require.ErrorIs(t, err, errClosedStmt) require.Equal(t, TYPE_INVALID, paramType) + err = stmt.Bind([]driver.NamedValue{{Ordinal: 1, Value: 0}, {Ordinal: 2, Value: "hello"}}) + require.ErrorIs(t, err, errCouldNotBind) + require.ErrorIs(t, err, errClosedStmt) + return nil }) require.NoError(t, err) @@ -245,6 +277,17 @@ func TestPrepareQueryNamed(t *testing.T) { require.ErrorContains(t, err, paramIndexErrMsg) require.Equal(t, TYPE_INVALID, paramType) + result, err := stmt.ExecBound(context.Background()) + require.Nil(t, result) + require.ErrorIs(t, err, errNotBound) + + err = stmt.Bind([]driver.NamedValue{{Name: "bar", Value: "hello"}, {Name: "baz", Value: 0}}) + require.NoError(t, err) + + result, err = stmt.ExecBound(context.Background()) + require.NoError(t, err) + require.NotNil(t, result) + require.NoError(t, stmt.Close()) stmtType, err = stmt.StatementType() @@ -259,6 +302,10 @@ func TestPrepareQueryNamed(t *testing.T) { require.ErrorIs(t, err, errClosedStmt) require.Equal(t, TYPE_INVALID, paramType) + err = stmt.Bind([]driver.NamedValue{{Name: "bar", Value: "hello"}, {Name: "baz", Value: 0}}) + require.ErrorIs(t, err, errCouldNotBind) + require.ErrorIs(t, err, errClosedStmt) + return nil }) require.NoError(t, err) @@ -280,6 +327,13 @@ func TestUninitializedStmt(t *testing.T) { paramName, err := stmt.ParamName(1) require.ErrorIs(t, err, errUninitializedStmt) require.Equal(t, "", paramName) + + err = stmt.Bind([]driver.NamedValue{{Ordinal: 1, Value: 0}}) + require.ErrorIs(t, err, errCouldNotBind) + require.ErrorIs(t, err, errUninitializedStmt) + + _, err = stmt.ExecBound(context.Background()) + require.ErrorIs(t, err, errNotBound) } func TestPrepareWithError(t *testing.T) {