From c4ad931bd19ebd93773960b38edd215ee9a64ef8 Mon Sep 17 00:00:00 2001 From: taniabogatsch <44262898+taniabogatsch@users.noreply.github.com> Date: Thu, 31 Oct 2024 11:45:13 +0100 Subject: [PATCH] add PrepareContext support --- connection.go | 145 +++++++++++++++++++++++----------------------- duckdb.go | 2 +- duckdb_test.go | 2 +- errors.go | 24 +++++--- errors_test.go | 6 +- statement_test.go | 50 +++++++++++++++- 6 files changed, 143 insertions(+), 86 deletions(-) diff --git a/connection.go b/connection.go index f5f9cb3f..c2d6340e 100644 --- a/connection.go +++ b/connection.go @@ -28,39 +28,6 @@ func (c *conn) CheckNamedValue(nv *driver.NamedValue) error { return driver.ErrSkip } -func (c *conn) prepareStmts(ctx context.Context, query string) (*stmt, error) { - if c.closed { - return nil, getError(errClosedCon, nil) - } - - stmts, count, errExtract := c.extractStmts(query) - defer C.duckdb_destroy_extracted(&stmts) - if errExtract != nil { - return nil, errExtract - } - - for i := C.idx_t(0); i < count-1; i++ { - prepared, err := c.prepareExtractedStmt(stmts, i) - if err != nil { - return nil, err - } - - // Execute the statement without any arguments and ignore the result. - if _, err = prepared.ExecContext(ctx, nil); err != nil { - return nil, err - } - if err = prepared.Close(); err != nil { - return nil, err - } - } - - prepared, err := c.prepareExtractedStmt(stmts, count-1) - if err != nil { - return nil, err - } - return prepared, nil -} - func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { prepared, err := c.prepareStmts(ctx, query) if err != nil { @@ -70,10 +37,15 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name res, err := prepared.ExecContext(ctx, args) errClose := prepared.Close() if err != nil { - err = errors.Join(err, errClose) + if errClose != nil { + return nil, errors.Join(err, errClose) + } return nil, err } - return res, errClose + if errClose != nil { + return nil, errClose + } + return res, nil } func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { @@ -85,43 +57,59 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam r, err := prepared.QueryContext(ctx, args) if err != nil { errClose := prepared.Close() - err = errors.Join(err, errClose) + if errClose != nil { + return nil, errors.Join(err, errClose) + } return nil, err } // We must close the prepared statement after closing the rows r. prepared.closeOnRowsClose = true - return r, err + return r, nil +} + +func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + return c.prepareStmts(ctx, query) } func (c *conn) Prepare(query string) (driver.Stmt, error) { if c.closed { - return nil, getError(errClosedCon, nil) + return nil, errors.Join(errPrepare, errClosedCon) + } + + stmts, count, err := c.extractStmts(query) + if err != nil { + return nil, err } - return c.prepareStmt(query) + defer C.duckdb_destroy_extracted(&stmts) + + if count != 1 { + return nil, errors.Join(errPrepare, errMissingPrepareContext) + } + return c.prepareExtractedStmt(stmts, 0) } -// Deprecated: Use BeginTx instead. +// Begin is deprecated: Use BeginTx instead. func (c *conn) Begin() (driver.Tx, error) { return c.BeginTx(context.Background(), driver.TxOptions{}) } func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { if c.tx { - panic("database/sql/driver: misuse of duckdb driver: multiple Tx") + return nil, errors.Join(errBeginTx, errMultipleTx) } if opts.ReadOnly { - return nil, errors.New("read-only transactions are not supported") + return nil, errors.Join(errBeginTx, errReadOnlyTxNotSupported) } switch sql.IsolationLevel(opts.Isolation) { case sql.LevelDefault: default: - return nil, errors.New("isolation levels other than default are not supported") + return nil, errors.Join(errBeginTx, errIsolationLevelNotSupported) } - if _, err := c.ExecContext(ctx, "BEGIN TRANSACTION", nil); err != nil { + if _, err := c.ExecContext(ctx, `BEGIN TRANSACTION`, nil); err != nil { return nil, err } @@ -134,51 +122,66 @@ func (c *conn) Close() error { return errClosedCon } c.closed = true - C.duckdb_disconnect(&c.duckdbCon) - return nil } -func (c *conn) prepareStmt(cmd string) (*stmt, error) { - cmdStr := C.CString(cmd) - defer C.duckdb_free(unsafe.Pointer(cmdStr)) - - var s C.duckdb_prepared_statement - if state := C.duckdb_prepare(c.duckdbCon, cmdStr, &s); state == C.DuckDBError { - dbErr := getDuckDBError(C.GoString(C.duckdb_prepare_error(s))) - C.duckdb_destroy_prepare(&s) - return nil, dbErr - } - - return &stmt{c: c, stmt: &s}, nil -} - func (c *conn) extractStmts(query string) (C.duckdb_extracted_statements, C.idx_t, error) { cQuery := C.CString(query) defer C.duckdb_free(unsafe.Pointer(cQuery)) var stmts C.duckdb_extracted_statements - stmtsCount := C.duckdb_extract_statements(c.duckdbCon, cQuery, &stmts) - if stmtsCount == 0 { - err := C.GoString(C.duckdb_extract_statements_error(stmts)) + count := C.duckdb_extract_statements(c.duckdbCon, cQuery, &stmts) + + if count == 0 { + errMsg := C.GoString(C.duckdb_extract_statements_error(stmts)) C.duckdb_destroy_extracted(&stmts) - if err != "" { - return nil, 0, getDuckDBError(err) + if errMsg != "" { + return nil, 0, getDuckDBError(errMsg) } - return nil, 0, errors.New("no statements found") + return nil, 0, errEmptyQuery } - return stmts, stmtsCount, nil + return stmts, count, nil } -func (c *conn) prepareExtractedStmt(extractedStmts C.duckdb_extracted_statements, index C.idx_t) (*stmt, error) { +func (c *conn) prepareExtractedStmt(stmts C.duckdb_extracted_statements, i C.idx_t) (*stmt, error) { var s C.duckdb_prepared_statement - if state := C.duckdb_prepare_extracted_statement(c.duckdbCon, extractedStmts, index, &s); state == C.DuckDBError { - dbErr := getDuckDBError(C.GoString(C.duckdb_prepare_error(s))) + state := C.duckdb_prepare_extracted_statement(c.duckdbCon, stmts, i, &s) + + if state == C.DuckDBError { + err := getDuckDBError(C.GoString(C.duckdb_prepare_error(s))) C.duckdb_destroy_prepare(&s) - return nil, dbErr + return nil, err } return &stmt{c: c, stmt: &s}, nil } + +func (c *conn) prepareStmts(ctx context.Context, query string) (*stmt, error) { + if c.closed { + return nil, errClosedCon + } + + stmts, count, errExtract := c.extractStmts(query) + if errExtract != nil { + return nil, errExtract + } + defer C.duckdb_destroy_extracted(&stmts) + + for i := C.idx_t(0); i < count-1; i++ { + prepared, err := c.prepareExtractedStmt(stmts, i) + if err != nil { + return nil, err + } + + // Execute the statement without any arguments and ignore the result. + if _, err = prepared.ExecContext(ctx, nil); err != nil { + return nil, err + } + if err = prepared.Close(); err != nil { + return nil, err + } + } + return c.prepareExtractedStmt(stmts, count-1) +} diff --git a/duckdb.go b/duckdb.go index 050fdc62..fdf67082 100644 --- a/duckdb.go +++ b/duckdb.go @@ -63,7 +63,7 @@ func NewConnector(dsn string, connInitFn func(execer driver.ExecerContext) error defer C.duckdb_free(unsafe.Pointer(outError)) if state := C.duckdb_open_ext(connStr, &db, config, &outError); state == C.DuckDBError { - return nil, getError(errOpen, duckdbError(outError)) + return nil, getError(errConnect, duckdbError(outError)) } return &Connector{ diff --git a/duckdb_test.go b/duckdb_test.go index e0315cb4..663d9926 100644 --- a/duckdb_test.go +++ b/duckdb_test.go @@ -516,7 +516,7 @@ func TestMultipleStatements(t *testing.T) { // test empty query _, err := db.Exec("") require.Error(t, err) - require.Contains(t, err.Error(), "no statements found") + require.Contains(t, err.Error(), errEmptyQuery.Error()) // test invalid query _, err = db.Exec("abc;") diff --git a/errors.go b/errors.go index 6cdc2ef5..75f06725 100644 --- a/errors.go +++ b/errors.go @@ -77,12 +77,22 @@ var ( errAPI = errors.New("API error") errVectorSize = errors.New("data chunks cannot exceed duckdb's internal vector size") - errParseDSN = errors.New("could not parse DSN for database") - errOpen = errors.New("could not open database") - errSetConfig = errors.New("could not set invalid or local option for global database config") + errConnect = errors.New("could not connect to database") + errParseDSN = errors.New("could not parse DSN for database") + errSetConfig = errors.New("could not set invalid or local option for global database config") + errCreateConfig = errors.New("could not create config for database") + errInvalidCon = errors.New("not a DuckDB driver connection") errClosedCon = errors.New("closed connection") + errPrepare = errors.New("could not prepare query") + errMissingPrepareContext = errors.New("missing context for multi-statement query: try using PrepareContext") + errEmptyQuery = errors.New("empty query") + errBeginTx = errors.New("could not begin transaction") + errMultipleTx = errors.New("multiple transactions") + errReadOnlyTxNotSupported = errors.New("read-only transactions are not supported") + errIsolationLevelNotSupported = errors.New("isolation level not supported: go-duckdb only supports the default isolation level") + errAppenderCreation = errors.New("could not create appender") errAppenderClose = errors.New("could not close appender") errAppenderDoubleClose = fmt.Errorf("%w: already closed", errAppenderClose) @@ -113,10 +123,6 @@ var ( errTableUDFColumnTypeIsNil = fmt.Errorf("%w: column type is nil", errTableUDFCreate) errProfilingInfoEmpty = errors.New("no profiling information available for this connection") - - // Errors not covered in tests. - errConnect = errors.New("could not connect to database") - errCreateConfig = errors.New("could not create config for database") ) type ErrorType int @@ -231,12 +237,14 @@ func (e *Error) Is(err error) bool { func getDuckDBError(errMsg string) error { errType := ErrorTypeInvalid - // find the end of the prefix (" Error: ") + + // Find the end of the prefix (" Error: "). if idx := strings.Index(errMsg, ": "); idx != -1 { if typ, ok := errorPrefixMap[errMsg[:idx]]; ok { errType = typ } } + return &Error{ Type: errType, Msg: errMsg, diff --git a/errors_test.go b/errors_test.go index c7df352d..a7af1d4b 100644 --- a/errors_test.go +++ b/errors_test.go @@ -24,15 +24,15 @@ func testError(t *testing.T, actual error, contains ...string) { testErrorInternal(t, actual, contains) } -func TestErrOpen(t *testing.T) { +func TestErrConnect(t *testing.T) { t.Run(errParseDSN.Error(), func(t *testing.T) { _, err := sql.Open("duckdb", ":mem ory:") testError(t, err, errParseDSN.Error()) }) - t.Run(errOpen.Error(), func(t *testing.T) { + t.Run(errConnect.Error(), func(t *testing.T) { _, err := sql.Open("duckdb", "?readonly") - testError(t, err, errOpen.Error(), duckdbErrMsg) + testError(t, err, errConnect.Error(), duckdbErrMsg) }) t.Run(errSetConfig.Error(), func(t *testing.T) { diff --git a/statement_test.go b/statement_test.go index 09db669e..17e5c59f 100644 --- a/statement_test.go +++ b/statement_test.go @@ -66,8 +66,7 @@ func TestPrepareQueryNamed(t *testing.T) { var foo, bar, foo2 int var baz string row := prepared.QueryRow(sql.Named("baz", "x"), sql.Named("foo", 1), sql.Named("bar", 2)) - err = row.Scan(&foo, &bar, &baz, &foo2) - require.NoError(t, err) + require.NoError(t, row.Scan(&foo, &bar, &baz, &foo2)) require.Equal(t, 1, foo) require.Equal(t, 2, bar) require.Equal(t, "x", baz) @@ -112,3 +111,50 @@ func TestPrepareWithError(t *testing.T) { } require.NoError(t, db.Close()) } + +func TestPreparePivot(t *testing.T) { + db := openDB(t) + ctx := context.Background() + createTable(db, t, `CREATE OR REPLACE TABLE cities(country VARCHAR, name VARCHAR, year INT, population INT)`) + _, err := db.ExecContext(ctx, `INSERT INTO cities VALUES ('NL', 'Netherlands', '2020', '42')`) + require.NoError(t, err) + + prepared, err := db.Prepare(`PIVOT cities ON year USING SUM(population)`) + require.NoError(t, err) + + var country, name string + var population int + row := prepared.QueryRow() + require.NoError(t, row.Scan(&country, &name, &population)) + require.Equal(t, "NL", country) + require.Equal(t, "Netherlands", name) + require.Equal(t, 42, population) + require.NoError(t, prepared.Close()) + + prepared, err = db.PrepareContext(ctx, `PIVOT cities ON year USING SUM(population)`) + require.NoError(t, err) + + row = prepared.QueryRow() + require.NoError(t, row.Scan(&country, &name, &population)) + require.Equal(t, "NL", country) + require.Equal(t, "Netherlands", name) + require.Equal(t, 42, population) + require.NoError(t, prepared.Close()) + + // Prepare on a connection. + c, err := db.Conn(ctx) + require.NoError(t, err) + + prepared, err = c.PrepareContext(ctx, `PIVOT cities ON year USING SUM(population)`) + require.NoError(t, err) + + row = prepared.QueryRow() + require.NoError(t, row.Scan(&country, &name, &population)) + require.Equal(t, "NL", country) + require.Equal(t, "Netherlands", name) + require.Equal(t, 42, population) + require.NoError(t, prepared.Close()) + + require.NoError(t, c.Close()) + require.NoError(t, db.Close()) +}