Skip to content

Commit

Permalink
Merge pull request #300 from taniabogatsch/main
Browse files Browse the repository at this point in the history
PrepareContext support
  • Loading branch information
taniabogatsch authored Oct 31, 2024
2 parents ad693b3 + c4ad931 commit 26097d1
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 86 deletions.
145 changes: 74 additions & 71 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,39 +28,6 @@ func (c *conn) CheckNamedValue(nv *driver.NamedValue) error {
return driver.ErrSkip
}

func (c *conn) prepareStmts(ctx context.Context, query string) (*stmt, error) {
if c.closed {
return nil, getError(errClosedCon, nil)
}

stmts, count, errExtract := c.extractStmts(query)
defer C.duckdb_destroy_extracted(&stmts)
if errExtract != nil {
return nil, errExtract
}

for i := C.idx_t(0); i < count-1; i++ {
prepared, err := c.prepareExtractedStmt(stmts, i)
if err != nil {
return nil, err
}

// Execute the statement without any arguments and ignore the result.
if _, err = prepared.ExecContext(ctx, nil); err != nil {
return nil, err
}
if err = prepared.Close(); err != nil {
return nil, err
}
}

prepared, err := c.prepareExtractedStmt(stmts, count-1)
if err != nil {
return nil, err
}
return prepared, nil
}

func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
prepared, err := c.prepareStmts(ctx, query)
if err != nil {
Expand All @@ -70,10 +37,15 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
res, err := prepared.ExecContext(ctx, args)
errClose := prepared.Close()
if err != nil {
err = errors.Join(err, errClose)
if errClose != nil {
return nil, errors.Join(err, errClose)
}
return nil, err
}
return res, errClose
if errClose != nil {
return nil, errClose
}
return res, nil
}

func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
Expand All @@ -85,43 +57,59 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
r, err := prepared.QueryContext(ctx, args)
if err != nil {
errClose := prepared.Close()
err = errors.Join(err, errClose)
if errClose != nil {
return nil, errors.Join(err, errClose)
}
return nil, err
}

// We must close the prepared statement after closing the rows r.
prepared.closeOnRowsClose = true
return r, err
return r, nil
}

func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
return c.prepareStmts(ctx, query)
}

func (c *conn) Prepare(query string) (driver.Stmt, error) {
if c.closed {
return nil, getError(errClosedCon, nil)
return nil, errors.Join(errPrepare, errClosedCon)
}

stmts, count, err := c.extractStmts(query)
if err != nil {
return nil, err
}
return c.prepareStmt(query)
defer C.duckdb_destroy_extracted(&stmts)

if count != 1 {
return nil, errors.Join(errPrepare, errMissingPrepareContext)
}
return c.prepareExtractedStmt(stmts, 0)
}

// Deprecated: Use BeginTx instead.
// Begin is deprecated: Use BeginTx instead.
func (c *conn) Begin() (driver.Tx, error) {
return c.BeginTx(context.Background(), driver.TxOptions{})
}

func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
if c.tx {
panic("database/sql/driver: misuse of duckdb driver: multiple Tx")
return nil, errors.Join(errBeginTx, errMultipleTx)
}

if opts.ReadOnly {
return nil, errors.New("read-only transactions are not supported")
return nil, errors.Join(errBeginTx, errReadOnlyTxNotSupported)
}

switch sql.IsolationLevel(opts.Isolation) {
case sql.LevelDefault:
default:
return nil, errors.New("isolation levels other than default are not supported")
return nil, errors.Join(errBeginTx, errIsolationLevelNotSupported)
}

if _, err := c.ExecContext(ctx, "BEGIN TRANSACTION", nil); err != nil {
if _, err := c.ExecContext(ctx, `BEGIN TRANSACTION`, nil); err != nil {
return nil, err
}

Expand All @@ -134,51 +122,66 @@ func (c *conn) Close() error {
return errClosedCon
}
c.closed = true

C.duckdb_disconnect(&c.duckdbCon)

return nil
}

func (c *conn) prepareStmt(cmd string) (*stmt, error) {
cmdStr := C.CString(cmd)
defer C.duckdb_free(unsafe.Pointer(cmdStr))

var s C.duckdb_prepared_statement
if state := C.duckdb_prepare(c.duckdbCon, cmdStr, &s); state == C.DuckDBError {
dbErr := getDuckDBError(C.GoString(C.duckdb_prepare_error(s)))
C.duckdb_destroy_prepare(&s)
return nil, dbErr
}

return &stmt{c: c, stmt: &s}, nil
}

func (c *conn) extractStmts(query string) (C.duckdb_extracted_statements, C.idx_t, error) {
cQuery := C.CString(query)
defer C.duckdb_free(unsafe.Pointer(cQuery))

var stmts C.duckdb_extracted_statements
stmtsCount := C.duckdb_extract_statements(c.duckdbCon, cQuery, &stmts)
if stmtsCount == 0 {
err := C.GoString(C.duckdb_extract_statements_error(stmts))
count := C.duckdb_extract_statements(c.duckdbCon, cQuery, &stmts)

if count == 0 {
errMsg := C.GoString(C.duckdb_extract_statements_error(stmts))
C.duckdb_destroy_extracted(&stmts)
if err != "" {
return nil, 0, getDuckDBError(err)
if errMsg != "" {
return nil, 0, getDuckDBError(errMsg)
}
return nil, 0, errors.New("no statements found")
return nil, 0, errEmptyQuery
}

return stmts, stmtsCount, nil
return stmts, count, nil
}

func (c *conn) prepareExtractedStmt(extractedStmts C.duckdb_extracted_statements, index C.idx_t) (*stmt, error) {
func (c *conn) prepareExtractedStmt(stmts C.duckdb_extracted_statements, i C.idx_t) (*stmt, error) {
var s C.duckdb_prepared_statement
if state := C.duckdb_prepare_extracted_statement(c.duckdbCon, extractedStmts, index, &s); state == C.DuckDBError {
dbErr := getDuckDBError(C.GoString(C.duckdb_prepare_error(s)))
state := C.duckdb_prepare_extracted_statement(c.duckdbCon, stmts, i, &s)

if state == C.DuckDBError {
err := getDuckDBError(C.GoString(C.duckdb_prepare_error(s)))
C.duckdb_destroy_prepare(&s)
return nil, dbErr
return nil, err
}

return &stmt{c: c, stmt: &s}, nil
}

func (c *conn) prepareStmts(ctx context.Context, query string) (*stmt, error) {
if c.closed {
return nil, errClosedCon
}

stmts, count, errExtract := c.extractStmts(query)
if errExtract != nil {
return nil, errExtract
}
defer C.duckdb_destroy_extracted(&stmts)

for i := C.idx_t(0); i < count-1; i++ {
prepared, err := c.prepareExtractedStmt(stmts, i)
if err != nil {
return nil, err
}

// Execute the statement without any arguments and ignore the result.
if _, err = prepared.ExecContext(ctx, nil); err != nil {
return nil, err
}
if err = prepared.Close(); err != nil {
return nil, err
}
}
return c.prepareExtractedStmt(stmts, count-1)
}
2 changes: 1 addition & 1 deletion duckdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func NewConnector(dsn string, connInitFn func(execer driver.ExecerContext) error
defer C.duckdb_free(unsafe.Pointer(outError))

if state := C.duckdb_open_ext(connStr, &db, config, &outError); state == C.DuckDBError {
return nil, getError(errOpen, duckdbError(outError))
return nil, getError(errConnect, duckdbError(outError))
}

return &Connector{
Expand Down
2 changes: 1 addition & 1 deletion duckdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ func TestMultipleStatements(t *testing.T) {
// test empty query
_, err := db.Exec("")
require.Error(t, err)
require.Contains(t, err.Error(), "no statements found")
require.Contains(t, err.Error(), errEmptyQuery.Error())

// test invalid query
_, err = db.Exec("abc;")
Expand Down
24 changes: 16 additions & 8 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,22 @@ var (
errAPI = errors.New("API error")
errVectorSize = errors.New("data chunks cannot exceed duckdb's internal vector size")

errParseDSN = errors.New("could not parse DSN for database")
errOpen = errors.New("could not open database")
errSetConfig = errors.New("could not set invalid or local option for global database config")
errConnect = errors.New("could not connect to database")
errParseDSN = errors.New("could not parse DSN for database")
errSetConfig = errors.New("could not set invalid or local option for global database config")
errCreateConfig = errors.New("could not create config for database")

errInvalidCon = errors.New("not a DuckDB driver connection")
errClosedCon = errors.New("closed connection")

errPrepare = errors.New("could not prepare query")
errMissingPrepareContext = errors.New("missing context for multi-statement query: try using PrepareContext")
errEmptyQuery = errors.New("empty query")
errBeginTx = errors.New("could not begin transaction")
errMultipleTx = errors.New("multiple transactions")
errReadOnlyTxNotSupported = errors.New("read-only transactions are not supported")
errIsolationLevelNotSupported = errors.New("isolation level not supported: go-duckdb only supports the default isolation level")

errAppenderCreation = errors.New("could not create appender")
errAppenderClose = errors.New("could not close appender")
errAppenderDoubleClose = fmt.Errorf("%w: already closed", errAppenderClose)
Expand Down Expand Up @@ -113,10 +123,6 @@ var (
errTableUDFColumnTypeIsNil = fmt.Errorf("%w: column type is nil", errTableUDFCreate)

errProfilingInfoEmpty = errors.New("no profiling information available for this connection")

// Errors not covered in tests.
errConnect = errors.New("could not connect to database")
errCreateConfig = errors.New("could not create config for database")
)

type ErrorType int
Expand Down Expand Up @@ -231,12 +237,14 @@ func (e *Error) Is(err error) bool {

func getDuckDBError(errMsg string) error {
errType := ErrorTypeInvalid
// find the end of the prefix ("<error-type> Error: ")

// Find the end of the prefix ("<error-type> Error: ").
if idx := strings.Index(errMsg, ": "); idx != -1 {
if typ, ok := errorPrefixMap[errMsg[:idx]]; ok {
errType = typ
}
}

return &Error{
Type: errType,
Msg: errMsg,
Expand Down
6 changes: 3 additions & 3 deletions errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ func testError(t *testing.T, actual error, contains ...string) {
testErrorInternal(t, actual, contains)
}

func TestErrOpen(t *testing.T) {
func TestErrConnect(t *testing.T) {
t.Run(errParseDSN.Error(), func(t *testing.T) {
_, err := sql.Open("duckdb", ":mem ory:")
testError(t, err, errParseDSN.Error())
})

t.Run(errOpen.Error(), func(t *testing.T) {
t.Run(errConnect.Error(), func(t *testing.T) {
_, err := sql.Open("duckdb", "?readonly")
testError(t, err, errOpen.Error(), duckdbErrMsg)
testError(t, err, errConnect.Error(), duckdbErrMsg)
})

t.Run(errSetConfig.Error(), func(t *testing.T) {
Expand Down
50 changes: 48 additions & 2 deletions statement_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@ func TestPrepareQueryNamed(t *testing.T) {
var foo, bar, foo2 int
var baz string
row := prepared.QueryRow(sql.Named("baz", "x"), sql.Named("foo", 1), sql.Named("bar", 2))
err = row.Scan(&foo, &bar, &baz, &foo2)
require.NoError(t, err)
require.NoError(t, row.Scan(&foo, &bar, &baz, &foo2))
require.Equal(t, 1, foo)
require.Equal(t, 2, bar)
require.Equal(t, "x", baz)
Expand Down Expand Up @@ -112,3 +111,50 @@ func TestPrepareWithError(t *testing.T) {
}
require.NoError(t, db.Close())
}

func TestPreparePivot(t *testing.T) {
db := openDB(t)
ctx := context.Background()
createTable(db, t, `CREATE OR REPLACE TABLE cities(country VARCHAR, name VARCHAR, year INT, population INT)`)
_, err := db.ExecContext(ctx, `INSERT INTO cities VALUES ('NL', 'Netherlands', '2020', '42')`)
require.NoError(t, err)

prepared, err := db.Prepare(`PIVOT cities ON year USING SUM(population)`)
require.NoError(t, err)

var country, name string
var population int
row := prepared.QueryRow()
require.NoError(t, row.Scan(&country, &name, &population))
require.Equal(t, "NL", country)
require.Equal(t, "Netherlands", name)
require.Equal(t, 42, population)
require.NoError(t, prepared.Close())

prepared, err = db.PrepareContext(ctx, `PIVOT cities ON year USING SUM(population)`)
require.NoError(t, err)

row = prepared.QueryRow()
require.NoError(t, row.Scan(&country, &name, &population))
require.Equal(t, "NL", country)
require.Equal(t, "Netherlands", name)
require.Equal(t, 42, population)
require.NoError(t, prepared.Close())

// Prepare on a connection.
c, err := db.Conn(ctx)
require.NoError(t, err)

prepared, err = c.PrepareContext(ctx, `PIVOT cities ON year USING SUM(population)`)
require.NoError(t, err)

row = prepared.QueryRow()
require.NoError(t, row.Scan(&country, &name, &population))
require.Equal(t, "NL", country)
require.Equal(t, "Netherlands", name)
require.Equal(t, 42, population)
require.NoError(t, prepared.Close())

require.NoError(t, c.Close())
require.NoError(t, db.Close())
}

0 comments on commit 26097d1

Please sign in to comment.