diff --git a/Makefile b/Makefile index d9629f1c..49f2926a 100644 --- a/Makefile +++ b/Makefile @@ -10,6 +10,8 @@ examples: go run examples/simple/main.go go run examples/appender/main.go go run examples/scalar_udf/main.go + go run examples/table_udf/main.go + go run examples/table_udf_parallel/main.go .PHONY: test test: diff --git a/errors.go b/errors.go index efad36cf..6cdc2ef5 100644 --- a/errors.go +++ b/errors.go @@ -73,6 +73,7 @@ const ( ) var ( + errInternal = errors.New("internal error: please file a bug report at go-duckdb") errAPI = errors.New("API error") errVectorSize = errors.New("data chunks cannot exceed duckdb's internal vector size") @@ -105,6 +106,12 @@ var ( errScalarUDFCreateSet = fmt.Errorf("could not create scalar UDF set") errScalarUDFAddToSet = fmt.Errorf("%w: could not add the function to the set", errScalarUDFCreateSet) + errTableUDFCreate = errors.New("could not create table UDF") + errTableUDFNoName = fmt.Errorf("%w: missing name", errTableUDFCreate) + errTableUDFMissingBindArgs = fmt.Errorf("%w: missing bind arguments", errTableUDFCreate) + errTableUDFArgumentIsNil = fmt.Errorf("%w: argument is nil", errTableUDFCreate) + errTableUDFColumnTypeIsNil = fmt.Errorf("%w: column type is nil", errTableUDFCreate) + errProfilingInfoEmpty = errors.New("no profiling information available for this connection") // Errors not covered in tests. diff --git a/examples/table_udf/main.go b/examples/table_udf/main.go new file mode 100644 index 00000000..36a3f865 --- /dev/null +++ b/examples/table_udf/main.go @@ -0,0 +1,100 @@ +package main + +import ( + "context" + "database/sql" + "fmt" + + "github.com/marcboeker/go-duckdb" +) + +type incrementTableUDF struct { + tableSize int64 + currentRow int64 +} + +func bindTableUDF(namedArgs map[string]any, args ...interface{}) (duckdb.RowTableSource, error) { + return &incrementTableUDF{ + currentRow: 0, + tableSize: args[0].(int64), + }, nil +} + +func (udf *incrementTableUDF) ColumnInfos() []duckdb.ColumnInfo { + t, err := duckdb.NewTypeInfo(duckdb.TYPE_BIGINT) + check(err) + return []duckdb.ColumnInfo{{Name: "result", T: t}} +} + +func (udf *incrementTableUDF) Init() {} + +func (udf *incrementTableUDF) FillRow(row duckdb.Row) (bool, error) { + if udf.currentRow+1 > udf.tableSize { + return false, nil + } + udf.currentRow++ + err := duckdb.SetRowValue(row, 0, udf.currentRow) + return true, err +} + +func (udf *incrementTableUDF) Cardinality() *duckdb.CardinalityInfo { + return &duckdb.CardinalityInfo{ + Cardinality: uint(udf.tableSize), + Exact: true, + } +} + +func main() { + db, err := sql.Open("duckdb", "?access_mode=READ_WRITE") + check(err) + + conn, err := db.Conn(context.Background()) + check(err) + + t, err := duckdb.NewTypeInfo(duckdb.TYPE_BIGINT) + check(err) + udf := duckdb.RowTableFunction{ + Config: duckdb.TableFunctionConfig{ + Arguments: []duckdb.TypeInfo{t}, + }, + BindArguments: bindTableUDF, + } + + err = duckdb.RegisterTableUDF(conn, "increment", udf) + check(err) + + rows, err := db.QueryContext(context.Background(), `SELECT * FROM increment(100)`) + check(err) + + // Get the column names. + columns, err := rows.Columns() + check(err) + + values := make([]interface{}, len(columns)) + args := make([]interface{}, len(values)) + for i := range values { + args[i] = &values[i] + } + + rowSum := int64(0) + for rows.Next() { + err = rows.Scan(args...) + check(err) + for _, value := range values { + // Keep in mind that the value can be nil for NULL values. + // This never happens in this example, so we don't check for it. + rowSum += value.(int64) + } + } + fmt.Printf("row sum: %d", rowSum) + + check(rows.Close()) + check(conn.Close()) + check(db.Close()) +} + +func check(err interface{}) { + if err != nil { + panic(err) + } +} diff --git a/examples/table_udf_parallel/main.go b/examples/table_udf_parallel/main.go new file mode 100644 index 00000000..e99ec685 --- /dev/null +++ b/examples/table_udf_parallel/main.go @@ -0,0 +1,156 @@ +package main + +import ( + "context" + "database/sql" + "fmt" + "sync" + + "github.com/marcboeker/go-duckdb" +) + +type ( + parallelIncrementTableUDF struct { + lock *sync.Mutex + claimed int64 + n int64 + } + + localTableState struct { + start int64 + end int64 + } +) + +func bindParallelTableUDF(namedArgs map[string]any, args ...interface{}) (duckdb.ParallelRowTableSource, error) { + return ¶llelIncrementTableUDF{ + lock: &sync.Mutex{}, + claimed: 0, + n: args[0].(int64), + }, nil +} + +func (udf *parallelIncrementTableUDF) ColumnInfos() []duckdb.ColumnInfo { + t, err := duckdb.NewTypeInfo(duckdb.TYPE_BIGINT) + check(err) + return []duckdb.ColumnInfo{{Name: "result", T: t}} +} + +func (udf *parallelIncrementTableUDF) Init() duckdb.ParallelTableSourceInfo { + return duckdb.ParallelTableSourceInfo{ + MaxThreads: 8, + } +} + +func (udf *parallelIncrementTableUDF) NewLocalState() any { + return &localTableState{ + start: 0, + end: -1, + } +} + +func (udf *parallelIncrementTableUDF) FillRow(localState any, row duckdb.Row) (bool, error) { + state := localState.(*localTableState) + + if state.start >= state.end { + // Claim a new work unit. + udf.lock.Lock() + remaining := udf.n - udf.claimed + + if remaining <= 0 { + // No more work. + udf.lock.Unlock() + return false, nil + } else if remaining >= 2024 { + remaining = 2024 + } + + state.start = udf.claimed + udf.claimed += remaining + state.end = udf.claimed + udf.lock.Unlock() + } + + state.start++ + err := duckdb.SetRowValue(row, 0, state.start) + return true, err +} + +func (udf *parallelIncrementTableUDF) GetValue(r, c int) any { + return int64(r + 1) +} + +func (udf *parallelIncrementTableUDF) GetTypes() []any { + return []any{0} +} + +func (udf *parallelIncrementTableUDF) Cardinality() *duckdb.CardinalityInfo { + return nil +} + +func (udf *parallelIncrementTableUDF) GetFunction() duckdb.ParallelRowTableFunction { + t, err := duckdb.NewTypeInfo(duckdb.TYPE_BIGINT) + check(err) + + return duckdb.ParallelRowTableFunction{ + Config: duckdb.TableFunctionConfig{ + Arguments: []duckdb.TypeInfo{t}, + }, + BindArguments: bindParallelTableUDF, + } +} + +func main() { + db, err := sql.Open("duckdb", "?access_mode=READ_WRITE") + check(err) + + conn, err := db.Conn(context.Background()) + check(err) + + t, err := duckdb.NewTypeInfo(duckdb.TYPE_BIGINT) + check(err) + udf := duckdb.ParallelRowTableFunction{ + Config: duckdb.TableFunctionConfig{ + Arguments: []duckdb.TypeInfo{t}, + }, + BindArguments: bindParallelTableUDF, + } + + err = duckdb.RegisterTableUDF(conn, "increment", udf) + check(err) + + rows, err := db.QueryContext(context.Background(), `SELECT * FROM increment(10000)`) + check(err) + + // Get the column names. + columns, err := rows.Columns() + check(err) + + values := make([]interface{}, len(columns)) + args := make([]interface{}, len(values)) + for i := range values { + args[i] = &values[i] + } + + rowSum := int64(0) + for rows.Next() { + err = rows.Scan(args...) + check(err) + for _, value := range values { + // Keep in mind that the value could be nil for NULL values. + // This never happens in our case, so we don't check for it. + rowSum += value.(int64) + } + } + fmt.Printf("row sum: %d", rowSum) + + check(rows.Close()) + check(conn.Close()) + check(db.Close()) +} + +func check(err interface{}) { + if err != nil { + panic(err) + } +} diff --git a/replacement_scan.go b/replacement_scan.go index 6a1e55d1..1271fcb5 100644 --- a/replacement_scan.go +++ b/replacement_scan.go @@ -33,8 +33,8 @@ func replacement_scan_cb(info C.duckdb_replacement_scan_info, table_name *C.ccha tFunc, params, err := scanner(C.GoString(table_name)) if err != nil { errStr := C.CString(err.Error()) + defer C.duckdb_free(unsafe.Pointer(errStr)) C.duckdb_replacement_scan_set_error(info, errStr) - C.duckdb_free(unsafe.Pointer(errStr)) return } diff --git a/row.go b/row.go new file mode 100644 index 00000000..0205889a --- /dev/null +++ b/row.go @@ -0,0 +1,37 @@ +package duckdb + +/* +#include +*/ +import "C" + +// Row represents one row in duckdb. It references the internal vectors. +type Row struct { + chunk *DataChunk + r C.idx_t + projection []int +} + +// IsProjected returns whether the column is projected. +func (r Row) IsProjected(colIdx int) bool { + return r.projection[colIdx] != -1 +} + +// SetRowValue sets the value at colIdx to val. +// Returns an error on failure, and nil for non-projected columns. +func SetRowValue[T any](row Row, colIdx int, val T) error { + projectedIdx := row.projection[colIdx] + if projectedIdx < 0 || projectedIdx >= len(row.chunk.columns) { + return nil + } + vec := row.chunk.columns[projectedIdx] + return setVectorVal(&vec, row.r, val) +} + +// SetRowValue sets the value at colIdx to val. Returns an error on failure. +func (r Row) SetRowValue(colIdx int, val any) error { + if !r.IsProjected(colIdx) { + return nil + } + return r.chunk.SetValue(colIdx, int(r.r), val) +} diff --git a/scalar_udf.go b/scalarUDF.go similarity index 95% rename from scalar_udf.go rename to scalarUDF.go index 2f055368..d1addfa7 100644 --- a/scalar_udf.go +++ b/scalarUDF.go @@ -4,7 +4,7 @@ package duckdb #include void scalar_udf_callback(duckdb_function_info, duckdb_data_chunk, duckdb_vector); -void scalar_udf_delete_callback(void *); +void udf_delete_callback(void *); typedef void (*scalar_udf_callback_t)(duckdb_function_info, duckdb_data_chunk, duckdb_vector); */ @@ -89,8 +89,8 @@ func RegisterScalarUDF(c *sql.Conn, name string, f ScalarFunc) error { // functions contains all ScalarFunc functions of the scalar function set. func RegisterScalarUDFSet(c *sql.Conn, name string, functions ...ScalarFunc) error { cName := C.CString(name) + defer C.duckdb_free(unsafe.Pointer(cName)) set := C.duckdb_create_scalar_function_set(cName) - C.duckdb_free(unsafe.Pointer(cName)) // Create each function and add it to the set. for i, f := range functions { @@ -125,10 +125,7 @@ func RegisterScalarUDFSet(c *sql.Conn, name string, functions ...ScalarFunc) err //export scalar_udf_callback func scalar_udf_callback(function_info C.duckdb_function_info, input C.duckdb_data_chunk, output C.duckdb_vector) { extraInfo := C.duckdb_scalar_function_get_extra_info(function_info) - - // extraInfo is a void* pointer to our pinned ScalarFunc f. - h := *(*cgo.Handle)(unsafe.Pointer(extraInfo)) - function := h.Value().(pinnedValue[ScalarFunc]).value + function := getPinned[ScalarFunc](extraInfo) // Initialize the input chunk. var inputChunk DataChunk @@ -191,11 +188,6 @@ func scalar_udf_callback(function_info C.duckdb_function_info, input C.duckdb_da } } -//export scalar_udf_delete_callback -func scalar_udf_delete_callback(info unsafe.Pointer) { - udf_delete_callback(info) -} - func registerInputParams(config ScalarFuncConfig, f C.duckdb_scalar_function) error { // Set variadic input parameters. if config.VariadicTypeInfo != nil { @@ -252,8 +244,8 @@ func createScalarFunc(name string, f ScalarFunc) (C.duckdb_scalar_function, erro // Set the name. cName := C.CString(name) + defer C.duckdb_free(unsafe.Pointer(cName)) C.duckdb_scalar_function_set_name(function, cName) - C.duckdb_free(unsafe.Pointer(cName)) // Configure the scalar function. config := f.Config() @@ -285,7 +277,7 @@ func createScalarFunc(name string, f ScalarFunc) (C.duckdb_scalar_function, erro C.duckdb_scalar_function_set_extra_info( function, unsafe.Pointer(&h), - C.duckdb_delete_callback_t(C.scalar_udf_delete_callback)) + C.duckdb_delete_callback_t(C.udf_delete_callback)) return function, nil } diff --git a/scalar_udf_test.go b/scalarUDF_test.go similarity index 100% rename from scalar_udf_test.go rename to scalarUDF_test.go diff --git a/tableUDF.go b/tableUDF.go new file mode 100644 index 00000000..9d6316df --- /dev/null +++ b/tableUDF.go @@ -0,0 +1,484 @@ +package duckdb + +/* +#include + +void table_udf_bind_row(duckdb_bind_info info); +void table_udf_bind_chunk(duckdb_bind_info info); +void table_udf_bind_parallel_row(duckdb_bind_info info); +void table_udf_bind_parallel_chunk(duckdb_bind_info info); + +void table_udf_init(duckdb_init_info info); +void table_udf_init_parallel(duckdb_init_info info); +void table_udf_local_init(duckdb_init_info info); + +// See https://golang.org/issue/19837 +void table_udf_row_callback(duckdb_function_info, duckdb_data_chunk); +void table_udf_chunk_callback(duckdb_function_info, duckdb_data_chunk); + +// See https://golang.org/issue/19835. +typedef void (*init)(duckdb_function_info); +typedef void (*bind)(duckdb_function_info); +typedef void (*callback)(duckdb_function_info, duckdb_data_chunk); + +void udf_delete_callback(void *); +*/ +import "C" + +import ( + "database/sql" + "runtime" + "runtime/cgo" + "unsafe" +) + +type ( + // ColumnInfo contains the metadata of a column. + ColumnInfo struct { + // The column Name. + Name string + // The column type T. + T TypeInfo + } + + // CardinalityInfo contains the cardinality of a (table) function. + // If it is impossible or difficult to determine the exact cardinality, an approximate cardinality may be used. + CardinalityInfo struct { + // The absolute Cardinality. + Cardinality uint + // IsExact indicates whether the cardinality is exact. + Exact bool + } + + // ParallelTableSourceInfo contains information for initializing a parallelism-aware table source. + ParallelTableSourceInfo struct { + // MaxThreads is the maximum number of threads on which to run the table source function. + // If set to 0, it uses DuckDB's default thread configuration. + MaxThreads int + } + + tableFunctionData struct { + fun any + projection []int + } + + tableSource interface { + // ColumnInfos returns column information for each column of the table function. + ColumnInfos() []ColumnInfo + // Cardinality returns the cardinality information of the table function. + // Optionally, if no cardinality exists, it may return nil. + Cardinality() *CardinalityInfo + } + + parallelTableSource interface { + tableSource + // Init the table source. + // Additionally, it returns information for the parallelism-aware table source. + Init() ParallelTableSourceInfo + // NewLocalState returns a thread-local execution state. + // It must return a pointer or a reference type for correct state updates. + // go-duckdb does not prevent non-reference values. + NewLocalState() any + } + + sequentialTableSource interface { + tableSource + // Init the table source. + Init() + } + + // A RowTableSource represents anything that produces rows in a non-vectorised way. + // The cardinality is requested before function initialization. + // After initializing the RowTableSource, go-duckdb requests the rows. + // It sequentially calls the FillRow method with a single thread. + RowTableSource interface { + sequentialTableSource + // FillRow takes a Row and fills it with values. + // It returns true, if there are more rows to fill. + FillRow(Row) (bool, error) + } + + // A ParallelRowTableSource represents anything that produces rows in a non-vectorised way. + // The cardinality is requested before function initialization. + // After initializing the ParallelRowTableSource, go-duckdb requests the rows. + // It simultaneously calls the FillRow method with multiple threads. + // If ParallelTableSourceInfo.MaxThreads is greater than one, FillRow must use synchronisation + // primitives to avoid race conditions. + ParallelRowTableSource interface { + parallelTableSource + // FillRow takes a Row and fills it with values. + // It returns true, if there are more rows to fill. + FillRow(any, Row) (bool, error) + } + + // A ChunkTableSource represents anything that produces rows in a vectorised way. + // The cardinality is requested before function initialization. + // After initializing the ChunkTableSource, go-duckdb requests the rows. + // It sequentially calls the FillChunk method with a single thread. + ChunkTableSource interface { + sequentialTableSource + // FillChunk takes a Chunk and fills it with values. + // It returns true, if there are more chunks to fill. + FillChunk(DataChunk) error + } + + // A ParallelChunkTableSource represents anything that produces rows in a vectorised way. + // The cardinality is requested before function initialization. + // After initializing the ParallelChunkTableSource, go-duckdb requests the rows. + // It simultaneously calls the FillChunk method with multiple threads. + // If ParallelTableSourceInfo.MaxThreads is greater than one, FillChunk must use synchronization + // primitives to avoid race conditions. + ParallelChunkTableSource interface { + parallelTableSource + // FillChunk takes a Chunk and fills it with values. + // It returns true, if there are more chunks to fill. + FillChunk(any, DataChunk) error + } + + // TableFunctionConfig contains any information passed to DuckDB when registering the table function. + TableFunctionConfig struct { + // The Arguments of the table function. + Arguments []TypeInfo + // The NamedArguments of the table function. + NamedArguments map[string]TypeInfo + } + + // TableFunction implements different table function types: + // RowTableFunction, ParallelRowTableFunction, ChunkTableFunction, and ParallelChunkTableFunction. + TableFunction interface { + RowTableFunction | ParallelRowTableFunction | ChunkTableFunction | ParallelChunkTableFunction + } + + // A RowTableFunction is a type which can be bound to return a RowTableSource. + RowTableFunction = tableFunction[RowTableSource] + // A ParallelRowTableFunction is a type which can be bound to return a ParallelRowTableSource. + ParallelRowTableFunction = tableFunction[ParallelRowTableSource] + // A ChunkTableFunction is a type which can be bound to return a ChunkTableSource. + ChunkTableFunction = tableFunction[ChunkTableSource] + // A ParallelChunkTableFunction is a type which can be bound to return a ParallelChunkTableSource. + ParallelChunkTableFunction = tableFunction[ParallelChunkTableSource] + + tableFunction[T any] struct { + // Config returns the table function configuration, including the function arguments. + Config TableFunctionConfig + // BindArguments binds the arguments and returns a TableSource. + BindArguments func(named map[string]any, args ...any) (T, error) + } +) + +func (tfd *tableFunctionData) setColumnCount(info C.duckdb_init_info) { + count := C.duckdb_init_get_column_count(info) + for i := 0; i < int(count); i++ { + srcPos := C.duckdb_init_get_column_index(info, C.idx_t(i)) + tfd.projection[int(srcPos)] = i + } +} + +//export table_udf_bind_row +func table_udf_bind_row(info C.duckdb_bind_info) { + udfBindTyped[RowTableSource](info) +} + +//export table_udf_bind_chunk +func table_udf_bind_chunk(info C.duckdb_bind_info) { + udfBindTyped[ChunkTableSource](info) +} + +//export table_udf_bind_parallel_row +func table_udf_bind_parallel_row(info C.duckdb_bind_info) { + udfBindTyped[ParallelRowTableSource](info) +} + +//export table_udf_bind_parallel_chunk +func table_udf_bind_parallel_chunk(info C.duckdb_bind_info) { + udfBindTyped[ParallelChunkTableSource](info) +} + +func udfBindTyped[T tableSource](info C.duckdb_bind_info) { + f := getPinned[tableFunction[T]](C.duckdb_bind_get_extra_info(info)) + config := f.Config + + argCount := len(config.Arguments) + args := make([]any, argCount) + namedArgs := make(map[string]any) + + for i, t := range config.Arguments { + value := C.duckdb_bind_get_parameter(info, C.idx_t(i)) + var err error + args[i], err = getValue(t, value) + C.duckdb_destroy_value(&value) + + if err != nil { + setBindError(info, err.Error()) + return + } + } + + for name, t := range config.NamedArguments { + argName := C.CString(name) + value := C.duckdb_bind_get_named_parameter(info, argName) + C.duckdb_free(unsafe.Pointer(argName)) + + var err error + namedArgs[name], err = getValue(t, value) + C.duckdb_destroy_value(&value) + + if err != nil { + setBindError(info, err.Error()) + return + } + } + + instance, err := f.BindArguments(namedArgs, args...) + if err != nil { + setBindError(info, err.Error()) + return + } + + columnInfos := instance.ColumnInfos() + instanceData := tableFunctionData{ + fun: instance, + projection: make([]int, len(columnInfos)), + } + + for i, v := range columnInfos { + if v.T == nil { + setBindError(info, errTableUDFColumnTypeIsNil.Error()) + return + } + + logicalType := v.T.logicalType() + name := C.CString(v.Name) + C.duckdb_bind_add_result_column(info, name, logicalType) + C.duckdb_destroy_logical_type(&logicalType) + C.duckdb_free(unsafe.Pointer(name)) + + instanceData.projection[i] = -1 + } + + cardinality := instance.Cardinality() + if cardinality != nil { + C.duckdb_bind_set_cardinality(info, C.idx_t(cardinality.Cardinality), C.bool(cardinality.Exact)) + } + + pinnedInstanceData := pinnedValue[tableFunctionData]{ + pinner: &runtime.Pinner{}, + value: instanceData, + } + + h := cgo.NewHandle(pinnedInstanceData) + pinnedInstanceData.pinner.Pin(&h) + C.duckdb_bind_set_bind_data(info, unsafe.Pointer(&h), C.duckdb_delete_callback_t(C.udf_delete_callback)) +} + +//export table_udf_init +func table_udf_init(info C.duckdb_init_info) { + instance := getPinned[tableFunctionData](C.duckdb_init_get_bind_data(info)) + instance.setColumnCount(info) + instance.fun.(sequentialTableSource).Init() +} + +//export table_udf_init_parallel +func table_udf_init_parallel(info C.duckdb_init_info) { + instance := getPinned[tableFunctionData](C.duckdb_init_get_bind_data(info)) + instance.setColumnCount(info) + initData := instance.fun.(parallelTableSource).Init() + maxThreads := C.idx_t(initData.MaxThreads) + C.duckdb_init_set_max_threads(info, maxThreads) +} + +//export table_udf_local_init +func table_udf_local_init(info C.duckdb_init_info) { + instance := getPinned[tableFunctionData](C.duckdb_init_get_bind_data(info)) + localState := pinnedValue[any]{ + pinner: &runtime.Pinner{}, + value: instance.fun.(parallelTableSource).NewLocalState(), + } + h := cgo.NewHandle(localState) + localState.pinner.Pin(&h) + C.duckdb_init_set_init_data(info, unsafe.Pointer(&h), C.duckdb_delete_callback_t(C.udf_delete_callback)) +} + +//export table_udf_row_callback +func table_udf_row_callback(info C.duckdb_function_info, output C.duckdb_data_chunk) { + instance := getPinned[tableFunctionData](C.duckdb_function_get_bind_data(info)) + + var chunk DataChunk + err := chunk.initFromDuckDataChunk(output, true) + if err != nil { + setFuncError(info, err.Error()) + return + } + + row := Row{ + chunk: &chunk, + projection: instance.projection, + } + maxSize := C.duckdb_vector_size() + + switch fun := instance.fun.(type) { + case RowTableSource: + // At the end of the loop row.r must be the index of the last row. + for row.r = 0; row.r < maxSize; row.r++ { + next, errRow := fun.FillRow(row) + if errRow != nil { + setFuncError(info, errRow.Error()) + break + } + if !next { + break + } + } + case ParallelRowTableSource: + // At the end of the loop row.r must be the index of the last row. + localState := getPinned[any](C.duckdb_function_get_local_init_data(info)) + for row.r = 0; row.r < maxSize; row.r++ { + next, errRow := fun.FillRow(localState, row) + if errRow != nil { + setFuncError(info, errRow.Error()) + break + } + if !next { + break + } + } + } + C.duckdb_data_chunk_set_size(output, row.r) +} + +//export table_udf_chunk_callback +func table_udf_chunk_callback(info C.duckdb_function_info, output C.duckdb_data_chunk) { + instance := getPinned[tableFunctionData](C.duckdb_function_get_bind_data(info)) + + var chunk DataChunk + err := chunk.initFromDuckDataChunk(output, true) + if err != nil { + setFuncError(info, err.Error()) + return + } + + switch fun := instance.fun.(type) { + case ChunkTableSource: + err = fun.FillChunk(chunk) + case ParallelChunkTableSource: + localState := getPinned[*any](C.duckdb_function_get_local_init_data(info)) + err = fun.FillChunk(localState, chunk) + } + if err != nil { + setFuncError(info, err.Error()) + } +} + +// RegisterTableUDF registers a user-defined table function. +// Projection pushdown is enabled by default. +func RegisterTableUDF[TFT TableFunction](c *sql.Conn, name string, f TFT) error { + if name == "" { + return getError(errAPI, errTableUDFNoName) + } + function := C.duckdb_create_table_function() + + // Set the name. + cName := C.CString(name) + defer C.duckdb_free(unsafe.Pointer(cName)) + C.duckdb_table_function_set_name(function, cName) + + var config TableFunctionConfig + + // Pin the table function f. + value := pinnedValue[TFT]{ + pinner: &runtime.Pinner{}, + value: f, + } + h := cgo.NewHandle(value) + value.pinner.Pin(&h) + + // Set the execution data, which is the table function f. + C.duckdb_table_function_set_extra_info( + function, + unsafe.Pointer(&h), + C.duckdb_delete_callback_t(C.udf_delete_callback)) + C.duckdb_table_function_supports_projection_pushdown(function, C.bool(true)) + + // Set the config. + var x any = f + switch tableFunc := x.(type) { + case RowTableFunction: + C.duckdb_table_function_set_init(function, C.init(C.table_udf_init)) + C.duckdb_table_function_set_bind(function, C.bind(C.table_udf_bind_row)) + C.duckdb_table_function_set_function(function, C.callback(C.table_udf_row_callback)) + + config = tableFunc.Config + if tableFunc.BindArguments == nil { + return getError(errAPI, errTableUDFMissingBindArgs) + } + + case ChunkTableFunction: + C.duckdb_table_function_set_init(function, C.init(C.table_udf_init)) + C.duckdb_table_function_set_bind(function, C.bind(C.table_udf_bind_chunk)) + C.duckdb_table_function_set_function(function, C.callback(C.table_udf_chunk_callback)) + + config = tableFunc.Config + if tableFunc.BindArguments == nil { + return getError(errAPI, errTableUDFMissingBindArgs) + } + + case ParallelRowTableFunction: + C.duckdb_table_function_set_init(function, C.init(C.table_udf_init_parallel)) + C.duckdb_table_function_set_bind(function, C.bind(C.table_udf_bind_parallel_row)) + C.duckdb_table_function_set_function(function, C.callback(C.table_udf_row_callback)) + C.duckdb_table_function_set_local_init(function, C.init(C.table_udf_local_init)) + + config = tableFunc.Config + if tableFunc.BindArguments == nil { + return getError(errAPI, errTableUDFMissingBindArgs) + } + + case ParallelChunkTableFunction: + C.duckdb_table_function_set_init(function, C.init(C.table_udf_init_parallel)) + C.duckdb_table_function_set_bind(function, C.bind(C.table_udf_bind_parallel_chunk)) + C.duckdb_table_function_set_function(function, C.callback(C.table_udf_chunk_callback)) + C.duckdb_table_function_set_local_init(function, C.init(C.table_udf_local_init)) + + config = tableFunc.Config + if tableFunc.BindArguments == nil { + return getError(errAPI, errTableUDFMissingBindArgs) + } + + default: + return getError(errInternal, nil) + } + + // Set the arguments. + for _, t := range config.Arguments { + if t == nil { + return getError(errAPI, errTableUDFArgumentIsNil) + } + logicalType := t.logicalType() + C.duckdb_table_function_add_parameter(function, logicalType) + C.duckdb_destroy_logical_type(&logicalType) + } + + // Set the named arguments. + for arg, t := range config.NamedArguments { + if t == nil { + return getError(errAPI, errTableUDFArgumentIsNil) + } + logicalType := t.logicalType() + cArg := C.CString(arg) + C.duckdb_table_function_add_named_parameter(function, cArg, logicalType) + C.duckdb_destroy_logical_type(&logicalType) + C.duckdb_free(unsafe.Pointer(cArg)) + } + + // Register the function on the underlying driver connection exposed by c.Raw. + err := c.Raw(func(driverConn any) error { + con := driverConn.(*conn) + state := C.duckdb_register_table_function(con.duckdbCon, function) + C.duckdb_destroy_table_function(&function) + if state == C.DuckDBError { + return getError(errAPI, errTableUDFCreate) + } + return nil + }) + return err +} diff --git a/tableUDF_test.go b/tableUDF_test.go new file mode 100644 index 00000000..78d3904d --- /dev/null +++ b/tableUDF_test.go @@ -0,0 +1,733 @@ +package duckdb + +import ( + "context" + "database/sql" + "fmt" + "math/big" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type ( + testTableFunction[T TableFunction] interface { + GetFunction() T + GetValue(r, c int) any + GetTypes() []any + } + + tableUDFTest[T TableFunction] struct { + udf testTableFunction[T] + name string + query string + resultCount int + } + + // Row-based table UDF tests. + + incTableUDF struct { + n int64 + count int64 + } + + structTableUDF struct { + n int64 + count int64 + } + + otherStructTableUDF struct { + I int64 + } + + pushdownTableUDF struct { + n int64 + count int64 + } + + incTableNamedUDF struct { + n int64 + count int64 + } + + constTableUDF[T any] struct { + count int64 + value T + t Type + } + + // Parallel row-based table UDF tests. + + parallelIncTableUDF struct { + lock *sync.Mutex + claimed int64 + n int64 + } + + parallelIncTableLocalState struct { + start int64 + end int64 + } + + // Chunk-based table UDF tests. + + chunkIncTableUDF struct { + n int64 + count int64 + } +) + +var ( + rowTableUDFs = []tableUDFTest[RowTableFunction]{ + { + udf: &incTableUDF{}, + name: "incTableUDF_non_full_vector", + query: `SELECT * FROM %s(2047)`, + resultCount: 2047, + }, + { + udf: &incTableUDF{}, + name: "incTableUDF", + query: `SELECT * FROM %s(10000)`, + resultCount: 10000, + }, + { + udf: &structTableUDF{}, + name: "structTableUDF", + query: `SELECT * FROM %s(2048)`, + resultCount: 2048, + }, + { + udf: &pushdownTableUDF{}, + name: "pushdownTableUDF", + query: `SELECT result2 FROM %s(2048)`, + resultCount: 2048, + }, + { + udf: &incTableNamedUDF{}, + name: "incTableNamedUDF", + query: `SELECT * FROM %s(ARG=2048)`, + resultCount: 2048, + }, + { + udf: &constTableUDF[bool]{value: false, t: TYPE_BOOLEAN}, + name: "constTableUDF_bool", + query: `SELECT * FROM %s(false)`, + resultCount: 1, + }, + { + udf: &constTableUDF[int8]{value: -8, t: TYPE_TINYINT}, + name: "constTableUDF_int8", + query: `SELECT * FROM %s(CAST(-8 AS TINYINT))`, + resultCount: 1, + }, + { + udf: &constTableUDF[int16]{value: -16, t: TYPE_SMALLINT}, + name: "constTableUDF_int16", + query: `SELECT * FROM %s(CAST(-16 AS SMALLINT))`, + resultCount: 1, + }, + { + udf: &constTableUDF[int32]{value: -32, t: TYPE_INTEGER}, + name: "constTableUDF_int32", + query: `SELECT * FROM %s(-32)`, + resultCount: 1, + }, + { + udf: &constTableUDF[int64]{value: -64, t: TYPE_BIGINT}, + name: "constTableUDF_int64", + query: `SELECT * FROM %s(-64)`, + resultCount: 1, + }, + { + udf: &constTableUDF[uint8]{value: 8, t: TYPE_UTINYINT}, + name: "constTableUDF_uint8", + query: `SELECT * FROM %s(CAST(8 AS UTINYINT))`, + resultCount: 1, + }, + { + udf: &constTableUDF[uint16]{value: 16, t: TYPE_USMALLINT}, + name: "constTableUDF_uint16", + query: `SELECT * FROM %s(CAST(16 AS USMALLINT))`, + resultCount: 1, + }, + { + udf: &constTableUDF[uint32]{value: 32, t: TYPE_UINTEGER}, + name: "constTableUDF_uint32", + query: `SELECT * FROM %s(CAST(32 AS UINTEGER))`, + resultCount: 1, + }, + { + udf: &constTableUDF[uint64]{value: 64, t: TYPE_UBIGINT}, + name: "constTableUDF_uint64", + query: `SELECT * FROM %s(CAST(64 AS UBIGINT))`, + resultCount: 1, + }, + { + udf: &constTableUDF[float32]{value: 32, t: TYPE_FLOAT}, + name: "constTableUDF_float32", + query: `SELECT * FROM %s(32)`, + resultCount: 1, + }, + { + udf: &constTableUDF[float64]{value: 64, t: TYPE_DOUBLE}, + name: "constTableUDF_float64", + query: `SELECT * FROM %s(64)`, + resultCount: 1, + }, + { + udf: &constTableUDF[time.Time]{value: time.Date(2006, 7, 8, 12, 34, 59, 123456000, time.UTC), t: TYPE_TIMESTAMP}, + name: "constTableUDF_timestamp", + query: `SELECT * FROM %s(CAST('2006-07-08 12:34:59.123456789' AS TIMESTAMP))`, + resultCount: 1, + }, + { + udf: &constTableUDF[time.Time]{value: time.Date(2006, 7, 8, 0, 0, 0, 0, time.UTC), t: TYPE_DATE}, + name: "constTableUDF_date", + query: `SELECT * FROM %s(CAST('2006-07-08 12:34:59.123456789' AS DATE))`, + resultCount: 1, + }, + { + udf: &constTableUDF[time.Time]{value: time.Date(1970, 1, 1, 12, 34, 59, 123456000, time.UTC), t: TYPE_TIME}, + name: "constTableUDF_time", + query: `SELECT * FROM %s(CAST('2006-07-08 12:34:59.123456789' AS TIME))`, + resultCount: 1, + }, + { + udf: &constTableUDF[Interval]{value: Interval{Months: 16, Days: 10, Micros: 172800000000}, t: TYPE_INTERVAL}, + name: "constTableUDF_interval", + query: `SELECT * FROM %s('16 months 10 days 48:00:00'::INTERVAL)`, + resultCount: 1, + }, + { + udf: &constTableUDF[*big.Int]{value: big.NewInt(10000000000000000), t: TYPE_HUGEINT}, + name: "constTableUDF_bigint", + query: `SELECT * FROM %s(10000000000000000)`, + resultCount: 1, + }, + { + udf: &constTableUDF[string]{value: "my_lovely_string", t: TYPE_VARCHAR}, + name: "constTableUDF_string", + query: `SELECT * FROM %s('my_lovely_string')`, + resultCount: 1, + }, + { + udf: &constTableUDF[time.Time]{value: time.Date(2006, 7, 8, 12, 34, 59, 0, time.UTC), t: TYPE_TIMESTAMP_S}, + name: "constTableUDF_timestamp_s", + query: `SELECT * FROM %s(CAST('2006-07-08 12:34:59.123456789' AS TIMESTAMP_S))`, + resultCount: 1, + }, + { + udf: &constTableUDF[time.Time]{value: time.Date(2006, 7, 8, 12, 34, 59, 123000000, time.UTC), t: TYPE_TIMESTAMP_MS}, + name: "constTableUDF_timestamp_ms", + query: `SELECT * FROM %s(CAST('2006-07-08 12:34:59.123456789' AS TIMESTAMP_MS))`, + resultCount: 1, + }, + { + udf: &constTableUDF[time.Time]{value: time.Date(2006, 7, 8, 12, 34, 59, 123456000, time.UTC), t: TYPE_TIMESTAMP_NS}, + name: "constTableUDF_timestamp_ns", + query: `SELECT * FROM %s(CAST('2006-07-08 12:34:59.123456789' AS TIMESTAMP_NS))`, + resultCount: 1, + }, + { + udf: &constTableUDF[time.Time]{value: time.Date(2006, 7, 8, 12, 34, 59, 123456000, time.UTC), t: TYPE_TIMESTAMP_TZ}, + name: "constTableUDF_timestamp_tz", + query: `SELECT * FROM %s(CAST('2006-07-08 12:34:59.123456789' AS TIMESTAMPTZ))`, + resultCount: 1, + }, + } + parallelTableUDFs = []tableUDFTest[ParallelRowTableFunction]{ + { + udf: ¶llelIncTableUDF{}, + name: "parallelIncTableUDF", + query: `SELECT * FROM %s(2048) ORDER BY result`, + resultCount: 2048, + }, + } + chunkTableUDFs = []tableUDFTest[ChunkTableFunction]{ + { + udf: &chunkIncTableUDF{}, + name: "chunkIncTableUDF", + query: `SELECT * FROM %s(2048)`, + resultCount: 2048, + }, + } +) + +var ( + typeBigintTableUDF, _ = NewTypeInfo(TYPE_BIGINT) + typeStructTableUDF = makeStructTableUDF() +) + +func makeStructTableUDF() TypeInfo { + entry, _ := NewStructEntry(typeBigintTableUDF, "I") + info, _ := NewStructInfo(entry) + return info +} + +func (udf *incTableUDF) GetFunction() RowTableFunction { + return RowTableFunction{ + Config: TableFunctionConfig{ + Arguments: []TypeInfo{typeBigintTableUDF}, + }, + BindArguments: bindIncTableUDF, + } +} + +func bindIncTableUDF(namedArgs map[string]any, args ...interface{}) (RowTableSource, error) { + return &incTableUDF{ + count: 0, + n: args[0].(int64), + }, nil +} + +func (udf *incTableUDF) ColumnInfos() []ColumnInfo { + return []ColumnInfo{{Name: "result", T: typeBigintTableUDF}} +} + +func (udf *incTableUDF) Init() {} + +func (udf *incTableUDF) FillRow(row Row) (bool, error) { + if udf.count >= udf.n { + return false, nil + } + udf.count++ + err := SetRowValue(row, 0, udf.count) + return true, err +} + +func (udf *incTableUDF) GetValue(r, c int) any { + return int64(r + 1) +} + +func (udf *incTableUDF) GetTypes() []any { + return []any{0} +} + +func (udf *incTableUDF) Cardinality() *CardinalityInfo { + return nil +} + +func bindParallelIncTableUDF(namedArgs map[string]any, args ...interface{}) (ParallelRowTableSource, error) { + return ¶llelIncTableUDF{ + lock: &sync.Mutex{}, + claimed: 0, + n: args[0].(int64), + }, nil +} + +func (udf *parallelIncTableUDF) ColumnInfos() []ColumnInfo { + return []ColumnInfo{{Name: "result", T: typeBigintTableUDF}} +} + +func (udf *parallelIncTableUDF) Init() ParallelTableSourceInfo { + return ParallelTableSourceInfo{MaxThreads: 8} +} + +func (udf *parallelIncTableUDF) NewLocalState() any { + return ¶llelIncTableLocalState{ + start: 0, + end: -1, + } +} + +func (udf *parallelIncTableUDF) FillRow(localState any, row Row) (bool, error) { + state := localState.(*parallelIncTableLocalState) + + if state.start >= state.end { + // Claim a new work unit. + udf.lock.Lock() + remaining := udf.n - udf.claimed + + if remaining <= 0 { + // No more work. + udf.lock.Unlock() + return false, nil + } else if remaining >= 2024 { + remaining = 2024 + } + + state.start = udf.claimed + udf.claimed += remaining + state.end = udf.claimed + udf.lock.Unlock() + } + + state.start++ + err := SetRowValue(row, 0, state.start) + return true, err +} + +func (udf *parallelIncTableUDF) GetValue(r, c int) any { + return int64(r + 1) +} + +func (udf *parallelIncTableUDF) GetTypes() []any { + return []any{0} +} + +func (udf *parallelIncTableUDF) Cardinality() *CardinalityInfo { + return nil +} + +func (udf *parallelIncTableUDF) GetFunction() ParallelRowTableFunction { + return ParallelRowTableFunction{ + Config: TableFunctionConfig{ + Arguments: []TypeInfo{typeBigintTableUDF}, + }, + BindArguments: bindParallelIncTableUDF, + } +} + +func (udf *structTableUDF) GetFunction() RowTableFunction { + return RowTableFunction{ + Config: TableFunctionConfig{ + Arguments: []TypeInfo{typeBigintTableUDF}, + }, + BindArguments: bindStructTableUDF, + } +} + +func bindStructTableUDF(namedArgs map[string]any, args ...interface{}) (RowTableSource, error) { + return &structTableUDF{ + count: 0, + n: args[0].(int64), + }, nil +} + +func (udf *structTableUDF) ColumnInfos() []ColumnInfo { + return []ColumnInfo{{Name: "result", T: typeStructTableUDF}} +} + +func (udf *structTableUDF) Init() {} + +func (udf *structTableUDF) FillRow(row Row) (bool, error) { + if udf.count >= udf.n { + return false, nil + } + udf.count++ + err := SetRowValue(row, 0, otherStructTableUDF{I: udf.count}) + return true, err +} + +func (udf *structTableUDF) GetTypes() []any { + return []any{0} +} + +func (udf *structTableUDF) GetValue(r, c int) any { + return map[string]any{"I": int64(r + 1)} +} + +func (udf *structTableUDF) Cardinality() *CardinalityInfo { + return nil +} + +func (udf *pushdownTableUDF) GetFunction() RowTableFunction { + return RowTableFunction{ + Config: TableFunctionConfig{ + Arguments: []TypeInfo{typeBigintTableUDF}, + }, + BindArguments: bindPushdownTableUDF, + } +} + +func bindPushdownTableUDF(namedArgs map[string]any, args ...interface{}) (RowTableSource, error) { + return &pushdownTableUDF{ + count: 0, + n: args[0].(int64), + }, nil +} + +func (udf *pushdownTableUDF) ColumnInfos() []ColumnInfo { + return []ColumnInfo{ + {Name: "result", T: typeBigintTableUDF}, + {Name: "result2", T: typeBigintTableUDF}, + } +} + +func (udf *pushdownTableUDF) Init() {} + +func (udf *pushdownTableUDF) FillRow(row Row) (bool, error) { + if udf.count >= udf.n { + return false, nil + } + + if row.IsProjected(0) { + err := fmt.Errorf("column 0 is projected while it should not be") + return false, err + } + + udf.count++ + if err := SetRowValue(row, 0, udf.count); err != nil { + return false, err + } + if err := SetRowValue(row, 1, udf.count); err != nil { + return false, err + } + return true, nil +} + +func (udf *pushdownTableUDF) GetName() string { + return "pushdownTableUDF" +} + +func (udf *pushdownTableUDF) GetTypes() []any { + return []any{int64(0)} +} + +func (udf *pushdownTableUDF) GetValue(r, c int) any { + return int64(r + 1) +} + +func (udf *pushdownTableUDF) Cardinality() *CardinalityInfo { + return nil +} + +func (udf *incTableNamedUDF) GetFunction() RowTableFunction { + return RowTableFunction{ + Config: TableFunctionConfig{ + NamedArguments: map[string]TypeInfo{"ARG": typeBigintTableUDF}, + }, + BindArguments: bindIncTableNamedUDF, + } +} + +func bindIncTableNamedUDF(namedArgs map[string]any, args ...interface{}) (RowTableSource, error) { + return &incTableNamedUDF{ + count: 0, + n: namedArgs["ARG"].(int64), + }, nil +} + +func (udf *incTableNamedUDF) ColumnInfos() []ColumnInfo { + return []ColumnInfo{{Name: "result", T: typeBigintTableUDF}} +} + +func (udf *incTableNamedUDF) Init() {} + +func (udf *incTableNamedUDF) FillRow(row Row) (bool, error) { + if udf.count >= udf.n { + return false, nil + } + udf.count++ + err := SetRowValue(row, 0, udf.count) + return true, err +} + +func (udf *incTableNamedUDF) GetValue(r, c int) any { + return int64(r + 1) +} + +func (udf *incTableNamedUDF) GetTypes() []any { + return []any{0} +} + +func (udf *incTableNamedUDF) Cardinality() *CardinalityInfo { + return nil +} + +func (udf *constTableUDF[T]) GetFunction() RowTableFunction { + info, _ := NewTypeInfo(udf.t) + return RowTableFunction{ + Config: TableFunctionConfig{ + Arguments: []TypeInfo{info}, + }, + BindArguments: bindConstTableUDF(udf.value, udf.t), + } +} + +func bindConstTableUDF[T any](val T, t Type) func(namedArgs map[string]any, args ...interface{}) (RowTableSource, error) { + return func(namedArgs map[string]any, args ...interface{}) (RowTableSource, error) { + return &constTableUDF[T]{ + count: 0, + value: args[0].(T), + t: t, + }, nil + } +} + +func (udf *constTableUDF[T]) ColumnInfos() []ColumnInfo { + info, _ := NewTypeInfo(udf.t) + return []ColumnInfo{{Name: "result", T: info}} +} + +func (udf *constTableUDF[T]) Init() {} + +func (udf *constTableUDF[T]) FillRow(row Row) (bool, error) { + if udf.count >= 1 { + return false, nil + } + udf.count++ + err := SetRowValue(row, 0, udf.value) + return true, err +} + +func (udf *constTableUDF[T]) GetValue(r, c int) any { + return udf.value +} + +func (udf *constTableUDF[T]) GetTypes() []any { + return []any{udf.value} +} + +func (udf *constTableUDF[T]) Cardinality() *CardinalityInfo { + return nil +} + +func (udf *chunkIncTableUDF) GetFunction() ChunkTableFunction { + return ChunkTableFunction{ + Config: TableFunctionConfig{ + Arguments: []TypeInfo{typeBigintTableUDF}, + }, + BindArguments: bindChunkIncTableUDF, + } +} + +func bindChunkIncTableUDF(namedArgs map[string]any, args ...interface{}) (ChunkTableSource, error) { + return &chunkIncTableUDF{ + count: 0, + n: args[0].(int64), + }, nil +} + +func (udf *chunkIncTableUDF) ColumnInfos() []ColumnInfo { + return []ColumnInfo{{Name: "result", T: typeBigintTableUDF}} +} + +func (udf *chunkIncTableUDF) Init() {} + +func (udf *chunkIncTableUDF) FillChunk(chunk DataChunk) error { + size := 2048 + i := 0 + + for ; i < size; i++ { + if udf.count >= udf.n { + err := chunk.SetSize(i) + return err + } + udf.count++ + err := chunk.SetValue(0, i, udf.count) + if err != nil { + return err + } + } + + err := chunk.SetSize(i) + return err +} + +func (udf *chunkIncTableUDF) GetValue(r, c int) any { + return int64(r + 1) +} + +func (udf *chunkIncTableUDF) GetTypes() []any { + return []any{0} +} + +func (udf *chunkIncTableUDF) Cardinality() *CardinalityInfo { + return nil +} + +func TestTableUDF(t *testing.T) { + for _, udf := range rowTableUDFs { + t.Run(udf.name, func(t *testing.T) { + singleTableUDF(t, udf) + }) + } + + for _, udf := range parallelTableUDFs { + t.Run(udf.name, func(t *testing.T) { + singleTableUDF(t, udf) + }) + } + + for _, udf := range chunkTableUDFs { + t.Run(udf.name, func(t *testing.T) { + singleTableUDF(t, udf) + }) + } +} + +func singleTableUDF[T TableFunction](t *testing.T, fun tableUDFTest[T]) { + db, err := sql.Open("duckdb", "?access_mode=READ_WRITE") + require.NoError(t, err) + + con, err := db.Conn(context.Background()) + require.NoError(t, err) + + err = RegisterTableUDF(con, fun.name, fun.udf.GetFunction()) + require.NoError(t, err) + + res, err := db.QueryContext(context.Background(), fmt.Sprintf(fun.query, fun.name)) + require.NoError(t, err) + + values := fun.udf.GetTypes() + args := make([]interface{}, len(values)) + for i := range values { + args[i] = &values[i] + } + + count := 0 + for r := 0; res.Next(); r++ { + require.NoError(t, res.Scan(args...)) + + for i, value := range values { + expected := fun.udf.GetValue(r, i) + require.Equal(t, expected, value, "incorrect value") + } + count++ + } + + require.Equal(t, count, fun.resultCount, "result count did not match the expected count") + require.NoError(t, res.Close()) + require.NoError(t, con.Close()) + require.NoError(t, db.Close()) +} + +func BenchmarkRowTableUDF(b *testing.B) { + b.StopTimer() + db, err := sql.Open("duckdb", "?access_mode=READ_WRITE") + require.NoError(b, err) + + con, err := db.Conn(context.Background()) + require.NoError(b, err) + + var fun incTableUDF + err = RegisterTableUDF(con, "whoo", fun.GetFunction()) + require.NoError(b, err) + + b.StartTimer() + for n := 0; n < b.N; n++ { + res, errQuery := db.QueryContext(context.Background(), "SELECT * FROM whoo(2048*64)") + require.NoError(b, errQuery) + require.NoError(b, res.Close()) + } + + require.NoError(b, con.Close()) + require.NoError(b, db.Close()) +} + +func BenchmarkChunkTableUDF(b *testing.B) { + b.StopTimer() + db, err := sql.Open("duckdb", "?access_mode=READ_WRITE") + require.NoError(b, err) + + con, err := db.Conn(context.Background()) + require.NoError(b, err) + + var fun chunkIncTableUDF + err = RegisterTableUDF(con, "whoo", fun.GetFunction()) + require.NoError(b, err) + + b.StartTimer() + for n := 0; n < b.N; n++ { + res, errQuery := db.QueryContext(context.Background(), "SELECT * FROM whoo(2048*64)") + require.NoError(b, errQuery) + require.NoError(b, res.Close()) + } + + require.NoError(b, con.Close()) + require.NoError(b, db.Close()) +} diff --git a/udf_utils.go b/udf_utils.go index 91ef9541..b1ec92b9 100644 --- a/udf_utils.go +++ b/udf_utils.go @@ -11,6 +11,8 @@ import ( "unsafe" ) +// Helpers for passing values to C and back. + type pinnedValue[T any] struct { pinner *runtime.Pinner value T @@ -24,12 +26,28 @@ func (v pinnedValue[T]) unpin() { v.pinner.Unpin() } +func getPinned[T any](handle unsafe.Pointer) T { + h := *(*cgo.Handle)(handle) + return h.Value().(pinnedValue[T]).value +} + +// Set error helpers. + +func setBindError(info C.duckdb_bind_info, msg string) { + err := C.CString(msg) + defer C.duckdb_free(unsafe.Pointer(err)) + C.duckdb_bind_set_error(info, err) +} + func setFuncError(function_info C.duckdb_function_info, msg string) { err := C.CString(msg) + defer C.duckdb_free(unsafe.Pointer(err)) C.duckdb_scalar_function_set_error(function_info, err) - C.duckdb_free(unsafe.Pointer(err)) } +// Data deletion handlers. + +//export udf_delete_callback func udf_delete_callback(info unsafe.Pointer) { h := (*cgo.Handle)(info) h.Value().(unpinner).unpin() diff --git a/value.go b/value.go new file mode 100644 index 00000000..e5271fe4 --- /dev/null +++ b/value.go @@ -0,0 +1,77 @@ +package duckdb + +/* +#include +*/ +import "C" + +import ( + "time" + "unsafe" +) + +func getValue(t TypeInfo, v C.duckdb_value) (any, error) { + switch t.(*typeInfo).Type { + case TYPE_BOOLEAN: + return bool(C.duckdb_get_bool(v)), nil + case TYPE_TINYINT: + return int8(C.duckdb_get_int8(v)), nil + case TYPE_SMALLINT: + return int16(C.duckdb_get_int16(v)), nil + case TYPE_INTEGER: + return int32(C.duckdb_get_int32(v)), nil + case TYPE_BIGINT: + return int64(C.duckdb_get_int64(v)), nil + case TYPE_UTINYINT: + return uint8(C.duckdb_get_uint8(v)), nil + case TYPE_USMALLINT: + return uint16(C.duckdb_get_uint16(v)), nil + case TYPE_UINTEGER: + return uint32(C.duckdb_get_uint32(v)), nil + case TYPE_UBIGINT: + return uint64(C.duckdb_get_uint64(v)), nil + case TYPE_FLOAT: + return float32(C.duckdb_get_float(v)), nil + case TYPE_DOUBLE: + return float64(C.duckdb_get_double(v)), nil + case TYPE_TIMESTAMP: + val := C.duckdb_get_timestamp(v) + return time.UnixMicro(int64(val.micros)).UTC(), nil + case TYPE_DATE: + primitiveDate := C.duckdb_get_date(v) + date := C.duckdb_from_date(primitiveDate) + return time.Date(int(date.year), time.Month(date.month), int(date.day), 0, 0, 0, 0, time.UTC), nil + case TYPE_TIME: + val := C.duckdb_get_time(v) + return time.UnixMicro(int64(val.micros)).UTC(), nil + case TYPE_INTERVAL: + interval := C.duckdb_get_interval(v) + return Interval{ + Days: int32(interval.days), + Months: int32(interval.months), + Micros: int64(interval.micros), + }, nil + case TYPE_HUGEINT: + hugeint := C.duckdb_get_hugeint(v) + return hugeIntToNative(hugeint), nil + case TYPE_VARCHAR: + str := C.duckdb_get_varchar(v) + ret := C.GoString(str) + C.duckdb_free(unsafe.Pointer(str)) + return ret, nil + case TYPE_TIMESTAMP_S: + val := C.duckdb_get_timestamp(v) + return time.UnixMicro(int64(val.micros)).UTC(), nil + case TYPE_TIMESTAMP_MS: + val := C.duckdb_get_timestamp(v) + return time.UnixMicro(int64(val.micros)).UTC(), nil + case TYPE_TIMESTAMP_NS: + val := C.duckdb_get_timestamp(v) + return time.UnixMicro(int64(val.micros)).UTC(), nil + case TYPE_TIMESTAMP_TZ: + val := C.duckdb_get_timestamp(v) + return time.UnixMicro(int64(val.micros)).UTC(), nil + default: + return nil, unsupportedTypeError(typeToStringMap[t.InternalType()]) + } +} diff --git a/vector_getters.go b/vector_getters.go index 3f82f3a7..031764f1 100644 --- a/vector_getters.go +++ b/vector_getters.go @@ -36,6 +36,7 @@ func (vec *vector) getTS(t Type, rowIdx C.idx_t) time.Time { val := getPrimitive[C.duckdb_timestamp](vec, rowIdx) micros := val.micros + // FIXME: Unify this code path with the value.go code path. switch t { case TYPE_TIMESTAMP: return time.UnixMicro(int64(micros)).UTC()