Skip to content

Commit

Permalink
Merge pull request #201 from JAicewizard/main
Browse files Browse the repository at this point in the history
Table function UDFs
  • Loading branch information
taniabogatsch authored Oct 16, 2024
2 parents ba0eeb8 + b646fbd commit 7c41608
Show file tree
Hide file tree
Showing 13 changed files with 1,622 additions and 15 deletions.
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ examples:
go run examples/simple/main.go
go run examples/appender/main.go
go run examples/scalar_udf/main.go
go run examples/table_udf/main.go
go run examples/table_udf_parallel/main.go

.PHONY: test
test:
Expand Down
7 changes: 7 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ const (
)

var (
errInternal = errors.New("internal error: please file a bug report at go-duckdb")
errAPI = errors.New("API error")
errVectorSize = errors.New("data chunks cannot exceed duckdb's internal vector size")

Expand Down Expand Up @@ -105,6 +106,12 @@ var (
errScalarUDFCreateSet = fmt.Errorf("could not create scalar UDF set")
errScalarUDFAddToSet = fmt.Errorf("%w: could not add the function to the set", errScalarUDFCreateSet)

errTableUDFCreate = errors.New("could not create table UDF")
errTableUDFNoName = fmt.Errorf("%w: missing name", errTableUDFCreate)
errTableUDFMissingBindArgs = fmt.Errorf("%w: missing bind arguments", errTableUDFCreate)
errTableUDFArgumentIsNil = fmt.Errorf("%w: argument is nil", errTableUDFCreate)
errTableUDFColumnTypeIsNil = fmt.Errorf("%w: column type is nil", errTableUDFCreate)

errProfilingInfoEmpty = errors.New("no profiling information available for this connection")

// Errors not covered in tests.
Expand Down
100 changes: 100 additions & 0 deletions examples/table_udf/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package main

import (
"context"
"database/sql"
"fmt"

"github.com/marcboeker/go-duckdb"
)

type incrementTableUDF struct {
tableSize int64
currentRow int64
}

func bindTableUDF(namedArgs map[string]any, args ...interface{}) (duckdb.RowTableSource, error) {
return &incrementTableUDF{
currentRow: 0,
tableSize: args[0].(int64),
}, nil
}

func (udf *incrementTableUDF) ColumnInfos() []duckdb.ColumnInfo {
t, err := duckdb.NewTypeInfo(duckdb.TYPE_BIGINT)
check(err)
return []duckdb.ColumnInfo{{Name: "result", T: t}}
}

func (udf *incrementTableUDF) Init() {}

func (udf *incrementTableUDF) FillRow(row duckdb.Row) (bool, error) {
if udf.currentRow+1 > udf.tableSize {
return false, nil
}
udf.currentRow++
err := duckdb.SetRowValue(row, 0, udf.currentRow)
return true, err
}

func (udf *incrementTableUDF) Cardinality() *duckdb.CardinalityInfo {
return &duckdb.CardinalityInfo{
Cardinality: uint(udf.tableSize),
Exact: true,
}
}

func main() {
db, err := sql.Open("duckdb", "?access_mode=READ_WRITE")
check(err)

conn, err := db.Conn(context.Background())
check(err)

t, err := duckdb.NewTypeInfo(duckdb.TYPE_BIGINT)
check(err)
udf := duckdb.RowTableFunction{
Config: duckdb.TableFunctionConfig{
Arguments: []duckdb.TypeInfo{t},
},
BindArguments: bindTableUDF,
}

err = duckdb.RegisterTableUDF(conn, "increment", udf)
check(err)

rows, err := db.QueryContext(context.Background(), `SELECT * FROM increment(100)`)
check(err)

// Get the column names.
columns, err := rows.Columns()
check(err)

values := make([]interface{}, len(columns))
args := make([]interface{}, len(values))
for i := range values {
args[i] = &values[i]
}

rowSum := int64(0)
for rows.Next() {
err = rows.Scan(args...)
check(err)
for _, value := range values {
// Keep in mind that the value can be nil for NULL values.
// This never happens in this example, so we don't check for it.
rowSum += value.(int64)
}
}
fmt.Printf("row sum: %d", rowSum)

check(rows.Close())
check(conn.Close())
check(db.Close())
}

func check(err interface{}) {
if err != nil {
panic(err)
}
}
156 changes: 156 additions & 0 deletions examples/table_udf_parallel/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
package main

import (
"context"
"database/sql"
"fmt"
"sync"

"github.com/marcboeker/go-duckdb"
)

type (
parallelIncrementTableUDF struct {
lock *sync.Mutex
claimed int64
n int64
}

localTableState struct {
start int64
end int64
}
)

func bindParallelTableUDF(namedArgs map[string]any, args ...interface{}) (duckdb.ParallelRowTableSource, error) {
return &parallelIncrementTableUDF{
lock: &sync.Mutex{},
claimed: 0,
n: args[0].(int64),
}, nil
}

func (udf *parallelIncrementTableUDF) ColumnInfos() []duckdb.ColumnInfo {
t, err := duckdb.NewTypeInfo(duckdb.TYPE_BIGINT)
check(err)
return []duckdb.ColumnInfo{{Name: "result", T: t}}
}

func (udf *parallelIncrementTableUDF) Init() duckdb.ParallelTableSourceInfo {
return duckdb.ParallelTableSourceInfo{
MaxThreads: 8,
}
}

func (udf *parallelIncrementTableUDF) NewLocalState() any {
return &localTableState{
start: 0,
end: -1,
}
}

func (udf *parallelIncrementTableUDF) FillRow(localState any, row duckdb.Row) (bool, error) {
state := localState.(*localTableState)

if state.start >= state.end {
// Claim a new work unit.
udf.lock.Lock()
remaining := udf.n - udf.claimed

if remaining <= 0 {
// No more work.
udf.lock.Unlock()
return false, nil
} else if remaining >= 2024 {
remaining = 2024
}

state.start = udf.claimed
udf.claimed += remaining
state.end = udf.claimed
udf.lock.Unlock()
}

state.start++
err := duckdb.SetRowValue(row, 0, state.start)
return true, err
}

func (udf *parallelIncrementTableUDF) GetValue(r, c int) any {
return int64(r + 1)
}

func (udf *parallelIncrementTableUDF) GetTypes() []any {
return []any{0}
}

func (udf *parallelIncrementTableUDF) Cardinality() *duckdb.CardinalityInfo {
return nil
}

func (udf *parallelIncrementTableUDF) GetFunction() duckdb.ParallelRowTableFunction {
t, err := duckdb.NewTypeInfo(duckdb.TYPE_BIGINT)
check(err)

return duckdb.ParallelRowTableFunction{
Config: duckdb.TableFunctionConfig{
Arguments: []duckdb.TypeInfo{t},
},
BindArguments: bindParallelTableUDF,
}
}

func main() {
db, err := sql.Open("duckdb", "?access_mode=READ_WRITE")
check(err)

conn, err := db.Conn(context.Background())
check(err)

t, err := duckdb.NewTypeInfo(duckdb.TYPE_BIGINT)
check(err)
udf := duckdb.ParallelRowTableFunction{
Config: duckdb.TableFunctionConfig{
Arguments: []duckdb.TypeInfo{t},
},
BindArguments: bindParallelTableUDF,
}

err = duckdb.RegisterTableUDF(conn, "increment", udf)
check(err)

rows, err := db.QueryContext(context.Background(), `SELECT * FROM increment(10000)`)
check(err)

// Get the column names.
columns, err := rows.Columns()
check(err)

values := make([]interface{}, len(columns))
args := make([]interface{}, len(values))
for i := range values {
args[i] = &values[i]
}

rowSum := int64(0)
for rows.Next() {
err = rows.Scan(args...)
check(err)
for _, value := range values {
// Keep in mind that the value could be nil for NULL values.
// This never happens in our case, so we don't check for it.
rowSum += value.(int64)
}
}
fmt.Printf("row sum: %d", rowSum)

check(rows.Close())
check(conn.Close())
check(db.Close())
}

func check(err interface{}) {
if err != nil {
panic(err)
}
}
2 changes: 1 addition & 1 deletion replacement_scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ func replacement_scan_cb(info C.duckdb_replacement_scan_info, table_name *C.ccha
tFunc, params, err := scanner(C.GoString(table_name))
if err != nil {
errStr := C.CString(err.Error())
defer C.duckdb_free(unsafe.Pointer(errStr))
C.duckdb_replacement_scan_set_error(info, errStr)
C.duckdb_free(unsafe.Pointer(errStr))
return
}

Expand Down
37 changes: 37 additions & 0 deletions row.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package duckdb

/*
#include <duckdb.h>
*/
import "C"

// Row represents one row in duckdb. It references the internal vectors.
type Row struct {
chunk *DataChunk
r C.idx_t
projection []int
}

// IsProjected returns whether the column is projected.
func (r Row) IsProjected(colIdx int) bool {
return r.projection[colIdx] != -1
}

// SetRowValue sets the value at colIdx to val.
// Returns an error on failure, and nil for non-projected columns.
func SetRowValue[T any](row Row, colIdx int, val T) error {
projectedIdx := row.projection[colIdx]
if projectedIdx < 0 || projectedIdx >= len(row.chunk.columns) {
return nil
}
vec := row.chunk.columns[projectedIdx]
return setVectorVal(&vec, row.r, val)
}

// SetRowValue sets the value at colIdx to val. Returns an error on failure.
func (r Row) SetRowValue(colIdx int, val any) error {
if !r.IsProjected(colIdx) {
return nil
}
return r.chunk.SetValue(colIdx, int(r.r), val)
}
18 changes: 5 additions & 13 deletions scalar_udf.go → scalarUDF.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ package duckdb
#include <duckdb.h>
void scalar_udf_callback(duckdb_function_info, duckdb_data_chunk, duckdb_vector);
void scalar_udf_delete_callback(void *);
void udf_delete_callback(void *);
typedef void (*scalar_udf_callback_t)(duckdb_function_info, duckdb_data_chunk, duckdb_vector);
*/
Expand Down Expand Up @@ -89,8 +89,8 @@ func RegisterScalarUDF(c *sql.Conn, name string, f ScalarFunc) error {
// functions contains all ScalarFunc functions of the scalar function set.
func RegisterScalarUDFSet(c *sql.Conn, name string, functions ...ScalarFunc) error {
cName := C.CString(name)
defer C.duckdb_free(unsafe.Pointer(cName))
set := C.duckdb_create_scalar_function_set(cName)
C.duckdb_free(unsafe.Pointer(cName))

// Create each function and add it to the set.
for i, f := range functions {
Expand Down Expand Up @@ -125,10 +125,7 @@ func RegisterScalarUDFSet(c *sql.Conn, name string, functions ...ScalarFunc) err
//export scalar_udf_callback
func scalar_udf_callback(function_info C.duckdb_function_info, input C.duckdb_data_chunk, output C.duckdb_vector) {
extraInfo := C.duckdb_scalar_function_get_extra_info(function_info)

// extraInfo is a void* pointer to our pinned ScalarFunc f.
h := *(*cgo.Handle)(unsafe.Pointer(extraInfo))
function := h.Value().(pinnedValue[ScalarFunc]).value
function := getPinned[ScalarFunc](extraInfo)

// Initialize the input chunk.
var inputChunk DataChunk
Expand Down Expand Up @@ -191,11 +188,6 @@ func scalar_udf_callback(function_info C.duckdb_function_info, input C.duckdb_da
}
}

//export scalar_udf_delete_callback
func scalar_udf_delete_callback(info unsafe.Pointer) {
udf_delete_callback(info)
}

func registerInputParams(config ScalarFuncConfig, f C.duckdb_scalar_function) error {
// Set variadic input parameters.
if config.VariadicTypeInfo != nil {
Expand Down Expand Up @@ -252,8 +244,8 @@ func createScalarFunc(name string, f ScalarFunc) (C.duckdb_scalar_function, erro

// Set the name.
cName := C.CString(name)
defer C.duckdb_free(unsafe.Pointer(cName))
C.duckdb_scalar_function_set_name(function, cName)
C.duckdb_free(unsafe.Pointer(cName))

// Configure the scalar function.
config := f.Config()
Expand Down Expand Up @@ -285,7 +277,7 @@ func createScalarFunc(name string, f ScalarFunc) (C.duckdb_scalar_function, erro
C.duckdb_scalar_function_set_extra_info(
function,
unsafe.Pointer(&h),
C.duckdb_delete_callback_t(C.scalar_udf_delete_callback))
C.duckdb_delete_callback_t(C.udf_delete_callback))

return function, nil
}
File renamed without changes.
Loading

0 comments on commit 7c41608

Please sign in to comment.