Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Table function UDFs #201

Merged
merged 20 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ examples:
go run examples/simple/main.go
go run examples/appender/main.go
go run examples/scalar_udf/main.go
go run examples/table_udf/main.go
go run examples/table_udf_parallel/main.go

.PHONY: test
test:
Expand Down
7 changes: 7 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ const (
)

var (
errInternal = errors.New("internal error: please file a bug report at go-duckdb")
errAPI = errors.New("API error")
errVectorSize = errors.New("data chunks cannot exceed duckdb's internal vector size")

Expand Down Expand Up @@ -105,6 +106,12 @@ var (
errScalarUDFCreateSet = fmt.Errorf("could not create scalar UDF set")
errScalarUDFAddToSet = fmt.Errorf("%w: could not add the function to the set", errScalarUDFCreateSet)

errTableUDFCreate = errors.New("could not create table UDF")
errTableUDFNoName = fmt.Errorf("%w: missing name", errTableUDFCreate)
errTableUDFMissingBindArgs = fmt.Errorf("%w: missing bind arguments", errTableUDFCreate)
errTableUDFArgumentIsNil = fmt.Errorf("%w: argument is nil", errTableUDFCreate)
errTableUDFColumnTypeIsNil = fmt.Errorf("%w: column type is nil", errTableUDFCreate)

errProfilingInfoEmpty = errors.New("no profiling information available for this connection")

// Errors not covered in tests.
Expand Down
100 changes: 100 additions & 0 deletions examples/table_udf/main.go
JAicewizard marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package main

import (
"context"
"database/sql"
"fmt"

"github.com/marcboeker/go-duckdb"
)

type incrementTableUDF struct {
tableSize int64
currentRow int64
}

func bindTableUDF(namedArgs map[string]any, args ...interface{}) (duckdb.RowTableSource, error) {
return &incrementTableUDF{
currentRow: 0,
tableSize: args[0].(int64),
}, nil
}

func (udf *incrementTableUDF) ColumnInfos() []duckdb.ColumnInfo {
t, err := duckdb.NewTypeInfo(duckdb.TYPE_BIGINT)
check(err)
return []duckdb.ColumnInfo{{Name: "result", T: t}}
}

func (udf *incrementTableUDF) Init() {}

func (udf *incrementTableUDF) FillRow(row duckdb.Row) (bool, error) {
if udf.currentRow+1 > udf.tableSize {
return false, nil
}
udf.currentRow++
err := duckdb.SetRowValue(row, 0, udf.currentRow)
return true, err
}

func (udf *incrementTableUDF) Cardinality() *duckdb.CardinalityInfo {
return &duckdb.CardinalityInfo{
Cardinality: uint(udf.tableSize),
Exact: true,
}
}

func main() {
db, err := sql.Open("duckdb", "?access_mode=READ_WRITE")
check(err)

conn, err := db.Conn(context.Background())
check(err)

t, err := duckdb.NewTypeInfo(duckdb.TYPE_BIGINT)
check(err)
udf := duckdb.RowTableFunction{
Config: duckdb.TableFunctionConfig{
Arguments: []duckdb.TypeInfo{t},
},
BindArguments: bindTableUDF,
}

err = duckdb.RegisterTableUDF(conn, "increment", udf)
check(err)

rows, err := db.QueryContext(context.Background(), `SELECT * FROM increment(100)`)
check(err)

// Get the column names.
columns, err := rows.Columns()
check(err)

values := make([]interface{}, len(columns))
args := make([]interface{}, len(values))
for i := range values {
args[i] = &values[i]
}

rowSum := int64(0)
for rows.Next() {
err = rows.Scan(args...)
check(err)
for _, value := range values {
// Keep in mind that the value can be nil for NULL values.
// This never happens in this example, so we don't check for it.
rowSum += value.(int64)
}
}
fmt.Printf("row sum: %d", rowSum)

check(rows.Close())
check(conn.Close())
check(db.Close())
}

func check(err interface{}) {
if err != nil {
panic(err)
}
}
156 changes: 156 additions & 0 deletions examples/table_udf_parallel/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
package main

import (
"context"
"database/sql"
"fmt"
"sync"

"github.com/marcboeker/go-duckdb"
)

type (
parallelIncrementTableUDF struct {
lock *sync.Mutex
claimed int64
n int64
}

localTableState struct {
start int64
end int64
}
)

func bindParallelTableUDF(namedArgs map[string]any, args ...interface{}) (duckdb.ParallelRowTableSource, error) {
return &parallelIncrementTableUDF{
lock: &sync.Mutex{},
claimed: 0,
n: args[0].(int64),
}, nil
}

func (udf *parallelIncrementTableUDF) ColumnInfos() []duckdb.ColumnInfo {
t, err := duckdb.NewTypeInfo(duckdb.TYPE_BIGINT)
check(err)
return []duckdb.ColumnInfo{{Name: "result", T: t}}
}

func (udf *parallelIncrementTableUDF) Init() duckdb.ParallelTableSourceInfo {
return duckdb.ParallelTableSourceInfo{
MaxThreads: 8,
}
}

func (udf *parallelIncrementTableUDF) NewLocalState() any {
return &localTableState{
start: 0,
end: -1,
}
}

func (udf *parallelIncrementTableUDF) FillRow(localState any, row duckdb.Row) (bool, error) {
state := localState.(*localTableState)

if state.start >= state.end {
// Claim a new work unit.
udf.lock.Lock()
remaining := udf.n - udf.claimed

if remaining <= 0 {
// No more work.
udf.lock.Unlock()
return false, nil
} else if remaining >= 2024 {
remaining = 2024
}

state.start = udf.claimed
udf.claimed += remaining
state.end = udf.claimed
udf.lock.Unlock()
}

state.start++
err := duckdb.SetRowValue(row, 0, state.start)
return true, err
}

func (udf *parallelIncrementTableUDF) GetValue(r, c int) any {
return int64(r + 1)
}

func (udf *parallelIncrementTableUDF) GetTypes() []any {
return []any{0}
}

func (udf *parallelIncrementTableUDF) Cardinality() *duckdb.CardinalityInfo {
return nil
}

func (udf *parallelIncrementTableUDF) GetFunction() duckdb.ParallelRowTableFunction {
t, err := duckdb.NewTypeInfo(duckdb.TYPE_BIGINT)
check(err)

return duckdb.ParallelRowTableFunction{
Config: duckdb.TableFunctionConfig{
Arguments: []duckdb.TypeInfo{t},
},
BindArguments: bindParallelTableUDF,
}
}

func main() {
db, err := sql.Open("duckdb", "?access_mode=READ_WRITE")
check(err)

conn, err := db.Conn(context.Background())
check(err)

t, err := duckdb.NewTypeInfo(duckdb.TYPE_BIGINT)
check(err)
udf := duckdb.ParallelRowTableFunction{
Config: duckdb.TableFunctionConfig{
Arguments: []duckdb.TypeInfo{t},
},
BindArguments: bindParallelTableUDF,
}

err = duckdb.RegisterTableUDF(conn, "increment", udf)
check(err)

rows, err := db.QueryContext(context.Background(), `SELECT * FROM increment(10000)`)
check(err)

// Get the column names.
columns, err := rows.Columns()
check(err)

values := make([]interface{}, len(columns))
args := make([]interface{}, len(values))
for i := range values {
args[i] = &values[i]
}

rowSum := int64(0)
for rows.Next() {
err = rows.Scan(args...)
check(err)
for _, value := range values {
// Keep in mind that the value could be nil for NULL values.
// This never happens in our case, so we don't check for it.
rowSum += value.(int64)
}
}
fmt.Printf("row sum: %d", rowSum)

check(rows.Close())
check(conn.Close())
check(db.Close())
}

func check(err interface{}) {
if err != nil {
panic(err)
}
}
2 changes: 1 addition & 1 deletion replacement_scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ func replacement_scan_cb(info C.duckdb_replacement_scan_info, table_name *C.ccha
tFunc, params, err := scanner(C.GoString(table_name))
if err != nil {
errStr := C.CString(err.Error())
defer C.duckdb_free(unsafe.Pointer(errStr))
C.duckdb_replacement_scan_set_error(info, errStr)
C.duckdb_free(unsafe.Pointer(errStr))
return
}

Expand Down
37 changes: 37 additions & 0 deletions row.go
JAicewizard marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package duckdb

/*
#include <duckdb.h>
*/
import "C"

// Row represents one row in duckdb. It references the internal vectors.
type Row struct {
chunk *DataChunk
r C.idx_t
projection []int
}

// IsProjected returns whether the column is projected.
func (r Row) IsProjected(colIdx int) bool {
return r.projection[colIdx] != -1
}

// SetRowValue sets the value at colIdx to val.
// Returns an error on failure, and nil for non-projected columns.
func SetRowValue[T any](row Row, colIdx int, val T) error {
projectedIdx := row.projection[colIdx]
if projectedIdx < 0 || projectedIdx >= len(row.chunk.columns) {
return nil
}
vec := row.chunk.columns[projectedIdx]
return setVectorVal(&vec, row.r, val)
}

// SetRowValue sets the value at colIdx to val. Returns an error on failure.
func (r Row) SetRowValue(colIdx int, val any) error {
if !r.IsProjected(colIdx) {
return nil
}
return r.chunk.SetValue(colIdx, int(r.r), val)
}
18 changes: 5 additions & 13 deletions scalar_udf.go → scalarUDF.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ package duckdb
#include <duckdb.h>

void scalar_udf_callback(duckdb_function_info, duckdb_data_chunk, duckdb_vector);
void scalar_udf_delete_callback(void *);
void udf_delete_callback(void *);

typedef void (*scalar_udf_callback_t)(duckdb_function_info, duckdb_data_chunk, duckdb_vector);
*/
Expand Down Expand Up @@ -89,8 +89,8 @@ func RegisterScalarUDF(c *sql.Conn, name string, f ScalarFunc) error {
// functions contains all ScalarFunc functions of the scalar function set.
func RegisterScalarUDFSet(c *sql.Conn, name string, functions ...ScalarFunc) error {
cName := C.CString(name)
defer C.duckdb_free(unsafe.Pointer(cName))
set := C.duckdb_create_scalar_function_set(cName)
C.duckdb_free(unsafe.Pointer(cName))

// Create each function and add it to the set.
for i, f := range functions {
Expand Down Expand Up @@ -125,10 +125,7 @@ func RegisterScalarUDFSet(c *sql.Conn, name string, functions ...ScalarFunc) err
//export scalar_udf_callback
func scalar_udf_callback(function_info C.duckdb_function_info, input C.duckdb_data_chunk, output C.duckdb_vector) {
extraInfo := C.duckdb_scalar_function_get_extra_info(function_info)

// extraInfo is a void* pointer to our pinned ScalarFunc f.
h := *(*cgo.Handle)(unsafe.Pointer(extraInfo))
function := h.Value().(pinnedValue[ScalarFunc]).value
function := getPinned[ScalarFunc](extraInfo)

// Initialize the input chunk.
var inputChunk DataChunk
Expand Down Expand Up @@ -191,11 +188,6 @@ func scalar_udf_callback(function_info C.duckdb_function_info, input C.duckdb_da
}
}

//export scalar_udf_delete_callback
func scalar_udf_delete_callback(info unsafe.Pointer) {
udf_delete_callback(info)
}

func registerInputParams(config ScalarFuncConfig, f C.duckdb_scalar_function) error {
// Set variadic input parameters.
if config.VariadicTypeInfo != nil {
Expand Down Expand Up @@ -252,8 +244,8 @@ func createScalarFunc(name string, f ScalarFunc) (C.duckdb_scalar_function, erro

// Set the name.
cName := C.CString(name)
defer C.duckdb_free(unsafe.Pointer(cName))
C.duckdb_scalar_function_set_name(function, cName)
C.duckdb_free(unsafe.Pointer(cName))

// Configure the scalar function.
config := f.Config()
Expand Down Expand Up @@ -285,7 +277,7 @@ func createScalarFunc(name string, f ScalarFunc) (C.duckdb_scalar_function, erro
C.duckdb_scalar_function_set_extra_info(
function,
unsafe.Pointer(&h),
C.duckdb_delete_callback_t(C.scalar_udf_delete_callback))
C.duckdb_delete_callback_t(C.udf_delete_callback))

return function, nil
}
File renamed without changes.
Loading
Loading