From 1ca51413c300c803a0cd8e4615335bb5f63c6a1b Mon Sep 17 00:00:00 2001 From: Marc Boeker Date: Mon, 15 Jan 2024 17:24:33 +0100 Subject: [PATCH] Refactor to use driver.Conn instead of any. --- README.md | 162 ++++++++++++++--------------------------------- appender.go | 4 +- appender_test.go | 74 ++++++++-------------- arrow.go | 101 +++++++++++++++++++++++++---- arrow.h | 43 ------------- arrow_test.go | 119 ++++++++++++++++------------------ duckdb.go | 36 +++-------- duckdb_test.go | 20 +++++- statement.go | 66 ++++--------------- 9 files changed, 260 insertions(+), 365 deletions(-) delete mode 100644 arrow.h diff --git a/README.md b/README.md index f677da21..ab733d06 100644 --- a/README.md +++ b/README.md @@ -70,77 +70,33 @@ Please refer to the [database/sql](https://godoc.org/database/sql) GoDoc for fur If you want to use the [DuckDB Appender API](https://duckdb.org/docs/data/appender.html), you can obtain a new Appender by supplying a DuckDB connection to `NewAppenderFromConn()`. ```go -package main - -import ( - "context" - "database/sql" - - "github.com/marcboeker/go-duckdb" -) - -func main() { - connector, err := duckdb.OpenConnector("test.db", nil) - if err != nil { - panic(err) - } - driverConn, err := connector.Connect(context.Background()) - if err != nil { - panic(err) - } - defer driverConn.Close() - - // Retrieve appender from connection (note that you have to create the table 'test' beforehand). - appender, err := duckdb.NewAppenderFromConn(driverConn, "", "test") - if err != nil { - panic(err) - } - defer appender.Close() - - err = appender.AppendRow(1, "a") - if err != nil { - panic(err) - } - - // Optional, if you want to access the appended rows immediately. - err = appender.Flush() - if err != nil { - panic(err) - } - - // Alternatively, you can use Raw method of sql.Conn to obtain the driver connection. - db, err := sql.Open("duckdb", "") - if err != nil { - panic(err) - } - defer db.Close() - - conn, err := db.Conn(context.Background()) - if err != nil { - panic(err) - } - defer conn.Close() - - err = conn.Raw(func(driverConn any) error { - // Notice usage of driverConn - appender, err := duckdb.NewAppenderFromConn(driverConn, "", "test") - if err != nil { - panic(err) - } - defer appender.Close() - - err = appender.AppendRow(...) - if err != nil { - panic(err) - } - - return nil - }) - if err != nil { - panic(err) - } +connector, err := duckdb.NewConnector("test.db", nil) +if err != nil { + ... +} +conn, err := connector.Connect(context.Background()) +if err != nil { + ... +} +defer conn.Close() + +// Retrieve appender from connection (note that you have to create the table 'test' beforehand). +appender, err := NewAppenderFromConn(conn, "", "test") +if err != nil { + ... } +defer appender.Close() +err = appender.AppendRow(...) +if err != nil { + ... +} + +// Optional, if you want to access the appended rows immediately. +err = appender.Flush() +if err != nil { + ... +} ``` ## DuckDB Apache Arrow Interface @@ -148,53 +104,31 @@ func main() { If you want to use the [DuckDB Arrow Interface](https://duckdb.org/docs/api/c/api#arrow-interface), you can obtain a new Arrow by supplying a DuckDB connection to `NewArrowFromConn()`. ```go -package main - -import ( - "context" - "database/sql" - - "github.com/marcboeker/go-duckdb" -) - -func main() { - db, err := sql.Open("duckdb", "") - if err != nil { - panic(err) - } - defer db.Close() - - conn, err := db.Conn(context.Background()) - if err != nil { - panic(err) - } - defer conn.Close() - - // Use Raw method of sql.Conn to obtain the driver connection. - err = conn.Raw(func(driverConn any) error { - // Create Arrow interface from the connection. - ar, err := duckdb.NewArrowFromConn(driverConn) - if err != nil { - panic(err) - } - - rdr, err := ar.Query(context.Background(), "SELECT * FROM generate_series(1, 10)") - if err != nil { - panic(err) - } - defer rdr.Release() - - for rdr.Next() { - // Process records. - } - - return nil - }) - if err != nil { - panic(err) - } +connector, err := duckdb.NewConnector("", nil) +if err != nil { + ... +} +conn, err := connector.Connect(context.Background()) +if err != nil { + ... +} +defer conn.Close() + +// Retrieve Arrow from connection. +ar, err := duckdb.NewArrowFromConn(conn) +if err != nil { + ... } +rdr, err := ar.QueryContext(context.Background(), "SELECT * FROM generate_series(1, 10)") +if err != nil { + ... +} +defer rdr.Release() + +for rdr.Next() { + // Process records. +} ``` ## Linking DuckDB diff --git a/appender.go b/appender.go index 48114985..babd2fac 100644 --- a/appender.go +++ b/appender.go @@ -30,7 +30,7 @@ type Appender struct { } // NewAppenderFromConn returns a new Appender from a DuckDB driver connection. -func NewAppenderFromConn(driverConn any, schema, table string) (*Appender, error) { +func NewAppenderFromConn(driverConn driver.Conn, schema string, table string) (*Appender, error) { dbConn, ok := driverConn.(*conn) if !ok { return nil, fmt.Errorf("not a duckdb driver connection") @@ -40,7 +40,7 @@ func NewAppenderFromConn(driverConn any, schema, table string) (*Appender, error panic("database/sql/driver: misuse of duckdb driver: Appender after Close") } - var schemastr *C.char + var schemastr *(C.char) if schema != "" { schemastr = C.CString(schema) defer C.free(unsafe.Pointer(schemastr)) diff --git a/appender_test.go b/appender_test.go index 70530cbc..ac7a06c8 100644 --- a/appender_test.go +++ b/appender_test.go @@ -57,27 +57,13 @@ func randString(n int) string { } func TestAppender(t *testing.T) { - connector, err := OpenConnector("", nil) + c, err := NewConnector("", nil) require.NoError(t, err) - defer connector.Close() - db := sql.OpenDB(connector) + db := sql.OpenDB(c) createAppenderTable(db, t) defer db.Close() - // Test that appender can be opened from the connector directly - driverConn, err := connector.Connect(context.Background()) - require.NoError(t, err) - - appender, err := NewAppenderFromConn(driverConn, "", "test") - require.NoError(t, err) - - err = appender.Close() - require.NoError(t, err) - - err = driverConn.Close() - require.NoError(t, err) - type dataRow struct { ID int UInt8 uint8 @@ -117,45 +103,39 @@ func TestAppender(t *testing.T) { Bool: randBool(), } } - var rows []dataRow + rows := []dataRow{} for i := 0; i < numAppenderTestRows; i++ { rows = append(rows, randRow(i)) } - conn, err := db.Conn(context.Background()) + conn, err := c.Connect(context.Background()) require.NoError(t, err) defer conn.Close() - err = conn.Raw(func(driverConn any) error { - appender, err := NewAppenderFromConn(driverConn, "", "test") - require.NoError(t, err) - defer appender.Close() - - for _, row := range rows { - err := appender.AppendRow( - row.ID, - row.UInt8, - row.Int8, - row.UInt16, - row.Int16, - row.UInt32, - row.Int32, - row.UInt64, - row.Int64, - row.Timestamp, - row.Float, - row.Double, - row.String, - row.Bool, - ) - require.NoError(t, err) - } - - err = appender.Flush() + appender, err := NewAppenderFromConn(conn, "", "test") + require.NoError(t, err) + defer appender.Close() + + for _, row := range rows { + err := appender.AppendRow( + row.ID, + row.UInt8, + row.Int8, + row.UInt16, + row.Int16, + row.UInt32, + row.Int32, + row.UInt64, + row.Int64, + row.Timestamp, + row.Float, + row.Double, + row.String, + row.Bool, + ) require.NoError(t, err) - - return nil - }) + } + err = appender.Flush() require.NoError(t, err) res, err := db.QueryContext( diff --git a/arrow.go b/arrow.go index 8ecc892a..b1c3f3bc 100644 --- a/arrow.go +++ b/arrow.go @@ -2,12 +2,55 @@ package duckdb /* #include -#include +#include + +#ifndef ARROW_C_DATA_INTERFACE +#define ARROW_C_DATA_INTERFACE + +#define ARROW_FLAG_DICTIONARY_ORDERED 1 +#define ARROW_FLAG_NULLABLE 2 +#define ARROW_FLAG_MAP_KEYS_SORTED 4 + +struct ArrowSchema { + // Array type description + const char* format; + const char* name; + const char* metadata; + int64_t flags; + int64_t n_children; + struct ArrowSchema** children; + struct ArrowSchema* dictionary; + + // Release callback + void (*release)(struct ArrowSchema*); + // Opaque producer-specific data + void* private_data; +}; + +struct ArrowArray { + // Array data description + int64_t length; + int64_t null_count; + int64_t offset; + int64_t n_buffers; + int64_t n_children; + const void** buffers; + struct ArrowArray** children; + struct ArrowArray* dictionary; + + // Release callback + void (*release)(struct ArrowArray*); + // Opaque producer-specific data + void* private_data; +}; + +#endif // ARROW_C_DATA_INTERFACE */ import "C" import ( "context" + "database/sql/driver" "errors" "fmt" "unsafe" @@ -24,7 +67,7 @@ type Arrow struct { } // NewArrowFromConn returns a new Arrow from a DuckDB driver connection. -func NewArrowFromConn(driverConn any) (*Arrow, error) { +func NewArrowFromConn(driverConn driver.Conn) (*Arrow, error) { dbConn, ok := driverConn.(*conn) if !ok { return nil, fmt.Errorf("not a duckdb driver connection") @@ -37,9 +80,14 @@ func NewArrowFromConn(driverConn any) (*Arrow, error) { return &Arrow{c: dbConn}, nil } +// Deprecated: Use QueryContext instead. +func (a *Arrow) Query(query string, args ...any) (array.RecordReader, error) { + return a.QueryContext(context.Background(), query, args) +} + // Query prepares statements, executes them, returns Apache Arrow array.RecordReader as a result of the last // executed statement. Arguments are bound to the last statement. -func (a *Arrow) Query(ctx context.Context, query string, args ...any) (array.RecordReader, error) { +func (a *Arrow) QueryContext(ctx context.Context, query string, args ...any) (array.RecordReader, error) { if a.c.closed { panic("database/sql/driver: misuse of duckdb driver: Arrow.Query after Close") } @@ -71,14 +119,13 @@ func (a *Arrow) Query(ctx context.Context, query string, args ...any) (array.Rec } defer stmt.Close() - res, err := stmt.executeArrow(args...) + res, err := a.execute(stmt, a.anyArgsToNamedArgs(args)) if err != nil { return nil, err } - defer C.duckdb_destroy_arrow(res) - sc, err := queryArrowSchema(res) + sc, err := a.queryArrowSchema(res) if err != nil { return nil, err } @@ -92,8 +139,7 @@ func (a *Arrow) Query(ctx context.Context, query string, args ...any) (array.Rec rowCount := uint64(C.duckdb_arrow_row_count(*res)) - var retrievedRows uint64 - + var retrievedRows uint64 = 0 for retrievedRows < rowCount { select { case <-ctx.Done(): @@ -101,13 +147,12 @@ func (a *Arrow) Query(ctx context.Context, query string, args ...any) (array.Rec default: } - rec, err := queryArrowArray(res, sc) + rec, err := a.queryArrowArray(res, sc) if err != nil { return nil, err } recs = append(recs, rec) - retrievedRows += uint64(rec.NumRows()) } @@ -115,7 +160,7 @@ func (a *Arrow) Query(ctx context.Context, query string, args ...any) (array.Rec } // queryArrowSchema fetches the internal arrow schema from the arrow result. -func queryArrowSchema(res *C.duckdb_arrow) (*arrow.Schema, error) { +func (a *Arrow) queryArrowSchema(res *C.duckdb_arrow) (*arrow.Schema, error) { cdSchema := (*cdata.CArrowSchema)(unsafe.Pointer(C.calloc(1, C.sizeof_struct_ArrowSchema))) defer func() { cdata.ReleaseCArrowSchema(cdSchema) @@ -141,7 +186,7 @@ func queryArrowSchema(res *C.duckdb_arrow) (*arrow.Schema, error) { // // This function can be called multiple time to get next chunks, // which will free the previous out_array. -func queryArrowArray(res *C.duckdb_arrow, sc *arrow.Schema) (arrow.Record, error) { +func (a *Arrow) queryArrowArray(res *C.duckdb_arrow, sc *arrow.Schema) (arrow.Record, error) { cdArr := (*cdata.CArrowArray)(unsafe.Pointer(C.calloc(1, C.sizeof_struct_ArrowArray))) defer func() { cdata.ReleaseCArrowArray(cdArr) @@ -162,3 +207,35 @@ func queryArrowArray(res *C.duckdb_arrow, sc *arrow.Schema) (arrow.Record, error return rec, nil } + +func (a *Arrow) execute(s *stmt, args []driver.NamedValue) (*C.duckdb_arrow, error) { + if s.closed { + panic("database/sql/driver: misuse of duckdb driver: executeArrow after Close") + } + + if err := s.start(args); err != nil { + return nil, err + } + + var res C.duckdb_arrow + if state := C.duckdb_execute_prepared_arrow(*s.stmt, &res); state == C.DuckDBError { + dbErr := C.GoString(C.duckdb_query_arrow_error(res)) + C.duckdb_destroy_arrow(&res) + return nil, fmt.Errorf("duckdb_execute_prepared_arrow: %v", dbErr) + } + + return &res, nil +} + +func (a *Arrow) anyArgsToNamedArgs(args []any) []driver.NamedValue { + if len(args) == 0 { + return nil + } + + values := make([]driver.Value, len(args)) + for i, arg := range args { + values[i] = arg + } + + return argsToNamedArgs(values) +} diff --git a/arrow.h b/arrow.h deleted file mode 100644 index 60fe1f40..00000000 --- a/arrow.h +++ /dev/null @@ -1,43 +0,0 @@ -#include - -#ifndef ARROW_C_DATA_INTERFACE -#define ARROW_C_DATA_INTERFACE - -#define ARROW_FLAG_DICTIONARY_ORDERED 1 -#define ARROW_FLAG_NULLABLE 2 -#define ARROW_FLAG_MAP_KEYS_SORTED 4 - -struct ArrowSchema { - // Array type description - const char* format; - const char* name; - const char* metadata; - int64_t flags; - int64_t n_children; - struct ArrowSchema** children; - struct ArrowSchema* dictionary; - - // Release callback - void (*release)(struct ArrowSchema*); - // Opaque producer-specific data - void* private_data; -}; - -struct ArrowArray { - // Array data description - int64_t length; - int64_t null_count; - int64_t offset; - int64_t n_buffers; - int64_t n_children; - const void** buffers; - struct ArrowArray** children; - struct ArrowArray* dictionary; - - // Release callback - void (*release)(struct ArrowArray*); - // Opaque producer-specific data - void* private_data; -}; - -#endif // ARROW_C_DATA_INTERFACE diff --git a/arrow_test.go b/arrow_test.go index 4a43be16..d1ebb4f5 100644 --- a/arrow_test.go +++ b/arrow_test.go @@ -2,6 +2,7 @@ package duckdb import ( "context" + "database/sql/driver" "testing" "github.com/stretchr/testify/require" @@ -9,111 +10,101 @@ import ( func TestArrow(t *testing.T) { t.Parallel() - db := openDB(t) - db.SetMaxOpenConns(2) // set connection pool size greater than 1 defer db.Close() - t.Run("select_series", func(t *testing.T) { - conn, err := db.Conn(context.Background()) - require.NoError(t, err) - defer conn.Close() - - err = conn.Raw(func(driverConn any) error { - ar, err := NewArrowFromConn(driverConn) - require.NoError(t, err) + createTable(db, t) - rdr, err := ar.Query(context.Background(), "SELECT * FROM generate_series(1, 10)") - require.NoError(t, err, "should query arrow") - defer rdr.Release() + conn, err := db.Conn(context.Background()) + require.NoError(t, err) + defer conn.Close() - for rdr.Next() { - rec := rdr.Record() - require.Equal(t, int64(10), rec.NumRows()) - bs, err := rec.MarshalJSON() - require.NoError(t, err) - - t.Log(string(bs)) - } + t.Run("select series", func(t *testing.T) { + c, err := NewConnector("", nil) + require.NoError(t, err) - require.NoError(t, rdr.Err()) + conn, err := c.Connect(context.Background()) + require.NoError(t, err) + defer conn.Close() - return nil - }) + ar, err := NewArrowFromConn(conn) require.NoError(t, err) - }) - t.Run("select_long_series", func(t *testing.T) { - conn, err := db.Conn(context.Background()) + rdr, err := ar.QueryContext(context.Background(), "SELECT * FROM generate_series(1, 10)") require.NoError(t, err) - defer conn.Close() + defer rdr.Release() - err = conn.Raw(func(driverConn any) error { - ar, err := NewArrowFromConn(driverConn) + for rdr.Next() { + rec := rdr.Record() + require.Equal(t, int64(10), rec.NumRows()) require.NoError(t, err) + } - rdr, err := ar.Query(context.Background(), "SELECT * FROM generate_series(1, 10000)") - require.NoError(t, err, "should query arrow") - defer rdr.Release() + require.NoError(t, rdr.Err()) + }) - var totalRows int64 - for rdr.Next() { - rec := rdr.Record() - totalRows += rec.NumRows() - } + t.Run("select long series", func(t *testing.T) { + c, err := NewConnector("", nil) + require.NoError(t, err) - require.Equal(t, int64(10000), totalRows) + conn, err := c.Connect(context.Background()) + require.NoError(t, err) + defer conn.Close() - require.NoError(t, rdr.Err()) + ar, err := NewArrowFromConn(conn) + require.NoError(t, err) - return nil - }) + rdr, err := ar.QueryContext(context.Background(), "SELECT * FROM generate_series(1, 10000)") require.NoError(t, err) - }) + defer rdr.Release() - createTable(db, t) + var totalRows int64 + for rdr.Next() { + rec := rdr.Record() + totalRows += rec.NumRows() + } - t.Run("query_table_and_filter_results", func(t *testing.T) { - conn, err := db.Conn(context.Background()) - require.NoError(t, err) - defer conn.Close() + require.Equal(t, int64(10000), totalRows) + require.NoError(t, rdr.Err()) + }) + + t.Run("query table and filter results", func(t *testing.T) { err = conn.Raw(func(driverConn any) error { - ar, err := NewArrowFromConn(driverConn) + conn, ok := driverConn.(driver.Conn) + require.True(t, ok) + + ar, err := NewArrowFromConn(conn) require.NoError(t, err) - rdr, err := ar.Query(context.Background(), "SELECT bar, baz FROM foo WHERE baz > ?", 12344) - require.NoError(t, err, "should query arrow") - defer rdr.Release() + reader, err := ar.QueryContext(context.Background(), "SELECT bar, baz FROM foo WHERE baz > ?", 12344) + require.NoError(t, err) + defer reader.Release() - for rdr.Next() { - rec := rdr.Record() + for reader.Next() { + rec := reader.Record() require.Equal(t, int64(1), rec.NumRows()) bs, err := rec.MarshalJSON() require.NoError(t, err) t.Log(string(bs)) } - - require.NoError(t, rdr.Err()) - + require.NoError(t, reader.Err()) return nil }) require.NoError(t, err) }) t.Run("query error", func(t *testing.T) { - conn, err := db.Conn(context.Background()) - require.NoError(t, err) - defer conn.Close() + err := conn.Raw(func(driverConn any) error { + conn, ok := driverConn.(driver.Conn) + require.True(t, ok) - err = conn.Raw(func(driverConn any) error { - ar, err := NewArrowFromConn(driverConn) + ar, err := NewArrowFromConn(conn) require.NoError(t, err) - _, err = ar.Query(context.Background(), "select bar") + _, err = ar.QueryContext(context.Background(), "SELECT bar") require.Error(t, err) - return nil }) require.NoError(t, err) diff --git a/duckdb.go b/duckdb.go index 06da1b00..f9b2514f 100644 --- a/duckdb.go +++ b/duckdb.go @@ -15,7 +15,6 @@ import ( "database/sql/driver" "errors" "fmt" - "io" "net/url" "strings" "unsafe" @@ -36,20 +35,15 @@ func (d Driver) Open(dataSourceName string) (driver.Conn, error) { } func (Driver) OpenConnector(dataSourceName string) (driver.Connector, error) { - return openConnector(dataSourceName, func(execerContext driver.ExecerContext) error { return nil }) + return createConnector(dataSourceName, func(execerContext driver.ExecerContext) error { return nil }) } -type ConnectorCloser interface { - driver.Connector - io.Closer +// NewConnector creates a new Connector for the DuckDB database. +func NewConnector(dsn string, connInitFn func(execer driver.ExecerContext) error) (driver.Connector, error) { + return createConnector(dsn, connInitFn) } -// OpenConnector opens a new connector for the DuckDB database. -func OpenConnector(dsn string, connInitFn func(execer driver.ExecerContext) error) (ConnectorCloser, error) { - return openConnector(dsn, connInitFn) -} - -func openConnector(dataSourceName string, connInitFn func(execer driver.ExecerContext) error) (*connector, error) { +func createConnector(dataSourceName string, connInitFn func(execer driver.ExecerContext) error) (driver.Connector, error) { var db C.duckdb_database parsedDSN, err := url.Parse(dataSourceName) @@ -99,7 +93,6 @@ func (c *connector) Connect(context.Context) (driver.Conn, error) { if state := C.duckdb_connect(*c.db, &con); state == C.DuckDBError { return nil, errOpen } - conn := &conn{con: &con} // Call the connection init function if defined @@ -133,8 +126,9 @@ func prepareConfig(options map[string][]string) (C.duckdb_config, error) { for k, v := range options { if len(v) > 0 { - if err := setConfig(config, k, v[0]); err != nil { - return nil, err + state := C.duckdb_set_config(config, C.CString(k), C.CString(v[0])) + if state == C.DuckDBError { + return nil, fmt.Errorf("%w: affected config option %s=%s", errPrepareConfig, k, v[0]) } } } @@ -142,20 +136,6 @@ func prepareConfig(options map[string][]string) (C.duckdb_config, error) { return config, nil } -func setConfig(config C.duckdb_config, name, option string) error { - cName := C.CString(name) - defer C.free(unsafe.Pointer(cName)) - - cOption := C.CString(option) - defer C.free(unsafe.Pointer(cOption)) - - if state := C.duckdb_set_config(config, cName, cOption); state == C.DuckDBError { - return fmt.Errorf("%w: affected config option %s=%s", errPrepareConfig, name, option) - } - - return nil -} - var ( errOpen = errors.New("could not open database") errParseConfig = errors.New("could not parse config for database") diff --git a/duckdb_test.go b/duckdb_test.go index 8fc828a8..d9303973 100644 --- a/duckdb_test.go +++ b/duckdb_test.go @@ -99,7 +99,7 @@ func TestConnPool(t *testing.T) { } func TestConnInit(t *testing.T) { - connector, err := OpenConnector("", func(execer driver.ExecerContext) error { + connector, err := NewConnector("", func(execer driver.ExecerContext) error { bootQueries := []string{ "INSTALL 'json'", "LOAD 'json'", @@ -114,8 +114,6 @@ func TestConnInit(t *testing.T) { return nil }) require.NoError(t, err) - defer connector.Close() - db := sql.OpenDB(connector) db.SetMaxOpenConns(2) // set connection pool size greater than 1 defer db.Close() @@ -1102,6 +1100,22 @@ func TestParquetExtension(t *testing.T) { require.NoError(t, err) } +func TestQueryTimeout(t *testing.T) { + db := openDB(t) + defer db.Close() + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*250) + defer cancel() + + now := time.Now() + _, err := db.ExecContext(ctx, "CREATE TABLE test AS SELECT * FROM range(10000000) t1, range(1000000) t2;") + require.ErrorIs(t, err, context.DeadlineExceeded) + + // a very defensive time check, but should be good enough + // the query takes much longer than 10 seconds + require.Less(t, time.Since(now), 10*time.Second) +} + func openDB(t *testing.T) *sql.DB { db, err := sql.Open("duckdb", "") require.NoError(t, err) diff --git a/statement.go b/statement.go index 0e035043..69df0317 100644 --- a/statement.go +++ b/statement.go @@ -202,73 +202,35 @@ func (s *stmt) execute(ctx context.Context, args []driver.NamedValue) (*C.duckdb } defer C.duckdb_destroy_pending(&pendingRes) - for { + done := make(chan bool) + defer close(done) + + go func() { select { - // if context is cancelled or deadline exceeded, don't execute further case <-ctx.Done(): - return nil, ctx.Err() - default: - // continue - } - state := C.duckdb_pending_execute_task(pendingRes) - if state == C.DUCKDB_PENDING_ERROR { - dbErr := C.GoString(C.duckdb_pending_error(pendingRes)) - return nil, errors.New(dbErr) + // also need to interrupt to cancel the query + C.duckdb_interrupt(*s.c.con) + return + case <-done: + return } - if C.duckdb_pending_execution_is_finished(state) { - break - } - } + }() var res C.duckdb_result if state := C.duckdb_execute_pending(pendingRes, &res); state == C.DuckDBError { + if ctx.Err() != nil { + return nil, ctx.Err() + } + dbErr := C.GoString(C.duckdb_result_error(&res)) C.duckdb_destroy_result(&res) return nil, errors.New(dbErr) } - return &res, nil -} - -// executeArrow executes the prepared statement with the given bound parameters, and returns an arrow query result. -// If the query fails to execute, returns error from DuckDB by calling duckdb_query_arrow_error. -// Note that after running queryArrow, C.duckdb_destroy_arrow must be called on the result object if there is no error. -func (s *stmt) executeArrow(args ...any) (*C.duckdb_arrow, error) { - if s.closed { - panic("database/sql/driver: misuse of duckdb driver: executeArrow after Close") - } - - if err := s.start(anyArgsToNamedArgs(args)); err != nil { - return nil, err - } - - var res C.duckdb_arrow - if state := C.duckdb_execute_prepared_arrow(*s.stmt, &res); state == C.DuckDBError { - dbErr := C.GoString(C.duckdb_query_arrow_error(res)) - C.duckdb_destroy_arrow(&res) - return nil, fmt.Errorf("duckdb_execute_prepared_arrow: %v", dbErr) - } return &res, nil } -func anyArgsToNamedArgs(args []any) []driver.NamedValue { - if len(args) == 0 { - return nil - } - - values := make([]driver.Value, len(args)) - for i, arg := range args { - values[i] = arg - } - - return argsToNamedArgs(values) -} - func argsToNamedArgs(values []driver.Value) []driver.NamedValue { - if len(values) == 0 { - return nil - } - args := make([]driver.NamedValue, len(values)) for n, param := range values { args[n].Value = param