Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Table function UDFs #201

Merged
merged 20 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ examples:
go run examples/simple/main.go
go run examples/appender/main.go
go run examples/scalar_udf/main.go
go run examples/table_udf_basic/main.go
go run examples/table_udf_parallel/main.go

.PHONY: test
test:
Expand Down
9 changes: 7 additions & 2 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ var (
errParseDSN = errors.New("could not parse DSN for database")
errOpen = errors.New("could not open database")
errSetConfig = errors.New("could not set invalid or local option for global database config")
errMalformedType = errors.New("Used a malformed TypeInfo to indicate a type")

errUnsupportedMapKeyType = errors.New("MAP key type not supported")

Expand Down Expand Up @@ -111,8 +112,12 @@ var (
errSetSQLNULLValue = errors.New("cannot write to a NULL column")

// Errors not covered in tests.
errConnect = errors.New("could not connect to database")
errCreateConfig = errors.New("could not create config for database")
errConnect = errors.New("could not connect to database")
errCreateConfig = errors.New("could not create config for database")
errTableUDFCreate = errors.New("could not create table UDF")
errTableUDFMissingBindags = errors.New("could not create table UDF, missing bind arguments")
errTableUDFNoName = errors.New("could not create table UDF, name cannot be empty")
errTableUDFNillFunction = errors.New("could not create table UDF, no function provided")
)

type ErrorType int
Expand Down
108 changes: 108 additions & 0 deletions examples/table_udf_basic/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package main

import (
"context"
"database/sql"
"fmt"
"log"
"reflect"

"github.com/marcboeker/go-duckdb"
)

var db *sql.DB

// incrementTableUDF is the row source behind the example table function. It
// produces the integers 1 through tableSize, one value per row.
type incrementTableUDF struct {
	// currentRow is the value written to the most recently produced row.
	currentRow int64
	// tableSize is the total number of rows to produce.
	tableSize int64
}

// BindTableUDF binds the arguments of a call to the table function and returns
// the row source that produces its result. The first positional argument is the
// number of rows to generate and must be an int64; namedArgs is not used.
//
// An error is returned, rather than panicking, when the argument is missing or
// has the wrong type.
func BindTableUDF(namedArgs map[string]any, args ...any) (duckdb.RowTableSource, error) {
	if len(args) == 0 {
		return nil, fmt.Errorf("table UDF: missing row count argument")
	}
	size, ok := args[0].(int64)
	if !ok {
		return nil, fmt.Errorf("table UDF: row count must be an int64, got %T", args[0])
	}
	return &incrementTableUDF{
		currentRow: 0,
		tableSize:  size,
	}, nil
}

// Columns describes the result of the table function: a single BIGINT column
// named "result". It panics, through check, if the type info cannot be created.
func (d *incrementTableUDF) Columns() []duckdb.ColumnInfo {
	bigint, err := duckdb.NewTypeInfo(duckdb.TYPE_BIGINT)
	check(err)

	result := duckdb.ColumnInfo{Name: "result", T: bigint}
	return []duckdb.ColumnInfo{result}
}

// Init is intentionally empty: this source performs no setup work here.
func (d *incrementTableUDF) Init() {
}

// FillRow writes the next value of the sequence into column 0 of row.
//
// It returns false once tableSize rows have been produced, and true otherwise.
// Any error reported while writing the value is returned to the caller instead
// of being dropped.
func (d *incrementTableUDF) FillRow(row duckdb.Row) (bool, error) {
	if d.currentRow+1 > d.tableSize {
		return false, nil
	}
	d.currentRow++
	err := duckdb.SetRowValue(row, 0, d.currentRow)
	return true, err
}

// Cardinality reports tableSize as the exact number of rows this source
// produces.
func (d *incrementTableUDF) Cardinality() *duckdb.CardinalityInfo {
	info := new(duckdb.CardinalityInfo)
	info.Cardinality = uint(d.tableSize)
	info.Exact = true
	return info
}

func main() {
var err error
db, err = sql.Open("duckdb", "?access_mode=READ_WRITE")
check(err)
if err != nil {
log.Fatal(err)
}
JAicewizard marked this conversation as resolved.
Show resolved Hide resolved
defer db.Close()
conn, err := db.Conn(context.Background())
check(err)

t, err := duckdb.NewTypeInfo(duckdb.TYPE_BIGINT)
check(err)
fun := duckdb.RowTableFunction{
Config: duckdb.TableFunctionConfig{
Arguments: []duckdb.TypeInfo{t},
},
BindArguments: BindTableUDF,
}

duckdb.RegisterTableUDF(conn, "inc", fun)

rows, err := db.QueryContext(context.Background(), "SELECT * FROM inc(100)")
check(err)
defer rows.Close()

// Get column names
columns, err := rows.Columns()
check(err)

values := make([]interface{}, len(columns))
scanArgs := make([]interface{}, len(values))
for i := range values {
scanArgs[i] = &values[i]
}

// Fetch rows
for rows.Next() {
err = rows.Scan(scanArgs...)
check(err)
for i, value := range values {
// Keep in mind that the value could be nil for NULL values.
// This never happens in our case, so we don't check for it.
fmt.Printf("%s: %v\n", columns[i], value)
fmt.Printf("Type: %s\n", reflect.TypeOf(value))
}
fmt.Println("-----------------------------------")
}
}

// check panics with err when err is non-nil and does nothing otherwise.
func check(err interface{}) {
	if err == nil {
		return
	}
	panic(err)
}
160 changes: 160 additions & 0 deletions examples/table_udf_parallel/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package main

import (
"context"
"database/sql"
"fmt"
"log"
"reflect"
"sync"

"github.com/marcboeker/go-duckdb"
)

var db *sql.DB

// parallelIncTableUDF is the state shared by all threads that fill rows for
// one call of the table function.
type parallelIncTableUDF struct {
	// lock is held while claimed is read and advanced.
	lock *sync.Mutex
	// claimed is the number of rows already handed out to threads.
	claimed int64
	// n is the total number of rows to produce.
	n int64
}

// parallelIncTableLocal is the per-thread state: the range of values claimed
// by one thread.
type parallelIncTableLocal struct {
	// start is the last value this thread has written.
	start int64
	// end is the last value of the range this thread has claimed.
	end int64
}

// BindParallelIncTableUDF binds the arguments of a call to the table function
// and returns the shared source that produces its result. The first positional
// argument is the number of rows to generate and must be an int64; namedArgs is
// not used.
//
// An error is returned, rather than panicking, when the argument is missing or
// has the wrong type.
func BindParallelIncTableUDF(namedArgs map[string]any, args ...any) (duckdb.ThreadedRowTableSource, error) {
	if len(args) == 0 {
		return nil, fmt.Errorf("table UDF: missing row count argument")
	}
	n, ok := args[0].(int64)
	if !ok {
		return nil, fmt.Errorf("table UDF: row count must be an int64, got %T", args[0])
	}
	return &parallelIncTableUDF{
		lock:    &sync.Mutex{},
		claimed: 0,
		n:       n,
	}, nil
}

// Columns describes the result of the table function: a single BIGINT column
// named "result". It panics, through check, if the type info cannot be created.
func (d *parallelIncTableUDF) Columns() []duckdb.ColumnInfo {
	bigint, err := duckdb.NewTypeInfo(duckdb.TYPE_BIGINT)
	check(err)

	result := duckdb.ColumnInfo{Name: "result", T: bigint}
	return []duckdb.ColumnInfo{result}
}

// Init returns the init data for a scan, with MaxThreads set to 8.
func (d *parallelIncTableUDF) Init() duckdb.ThreadedTableSourceInitData {
	var initData duckdb.ThreadedTableSourceInitData
	initData.MaxThreads = 8
	return initData
}

// NewLocalState returns a fresh *parallelIncTableLocal. Its range is empty
// (start 0, end -1), so the first FillRow call using it claims new work.
func (d *parallelIncTableUDF) NewLocalState() any {
	local := new(parallelIncTableLocal)
	local.start, local.end = 0, -1
	return local
}

// FillRow writes the next value of the calling thread's claimed range into
// column 0 of row. When the range is used up, a new one is claimed from the
// shared counter first. It returns false once no rows are left to claim.
func (d *parallelIncTableUDF) FillRow(localState any, row duckdb.Row) (bool, error) {
	local := localState.(*parallelIncTableLocal)
	if local.start >= local.end && !d.claimRange(local) {
		// no more work to be done :(
		return false, nil
	}

	local.start++
	return true, duckdb.SetRowValue(row, 0, local.start)
}

// claimRange assigns the next unclaimed range of at most 2024 rows to local.
// It reports false when every row has already been claimed.
func (d *parallelIncTableUDF) claimRange(local *parallelIncTableLocal) bool {
	d.lock.Lock()
	defer d.lock.Unlock()

	batch := d.n - d.claimed
	if batch <= 0 {
		return false
	}
	if batch >= 2024 {
		batch = 2024
	}
	local.start = d.claimed
	d.claimed += batch
	local.end = d.claimed
	return true
}

// GetValue returns r+1 as an int64. The column index c is ignored.
func (d *parallelIncTableUDF) GetValue(r, c int) any {
	next := r + 1
	return int64(next)
}

// GetTypes returns a one-element slice holding the int value 0.
func (d *parallelIncTableUDF) GetTypes() []any {
	types := make([]any, 0, 1)
	types = append(types, int(0))
	return types
}

// Cardinality returns nil: this source supplies no cardinality information.
func (d *parallelIncTableUDF) Cardinality() *duckdb.CardinalityInfo {
	var none *duckdb.CardinalityInfo
	return none
}

// GetFunction builds the threaded table function definition: it takes one
// BIGINT argument and binds through BindParallelIncTableUDF. It panics,
// through check, if the type info cannot be created.
func (d *parallelIncTableUDF) GetFunction() duckdb.ThreadedRowTableFunction {
	bigint, err := duckdb.NewTypeInfo(duckdb.TYPE_BIGINT)
	check(err)

	var fn duckdb.ThreadedRowTableFunction
	fn.Config = duckdb.TableFunctionConfig{Arguments: []duckdb.TypeInfo{bigint}}
	fn.BindArguments = BindParallelIncTableUDF
	return fn
}

func main() {
var err error
db, err = sql.Open("duckdb", "?access_mode=READ_WRITE")
check(err)
if err != nil {
log.Fatal(err)
}
JAicewizard marked this conversation as resolved.
Show resolved Hide resolved
defer db.Close()
conn, err := db.Conn(context.Background())
check(err)

t, err := duckdb.NewTypeInfo(duckdb.TYPE_BIGINT)
check(err)
fun := duckdb.ThreadedRowTableFunction{
Config: duckdb.TableFunctionConfig{
Arguments: []duckdb.TypeInfo{t},
},
BindArguments: BindParallelIncTableUDF,
}

duckdb.RegisterTableUDF(conn, "inc", fun)

rows, err := db.QueryContext(context.Background(), "SELECT * FROM inc(2048)")
check(err)
defer rows.Close()

// Get column names
columns, err := rows.Columns()
check(err)

values := make([]interface{}, len(columns))
scanArgs := make([]interface{}, len(values))
for i := range values {
scanArgs[i] = &values[i]
}

// Fetch rows
for rows.Next() {
err = rows.Scan(scanArgs...)
check(err)
for i, value := range values {
// Keep in mind that the value could be nil for NULL values.
// This never happens in our case, so we don't check for it.
fmt.Printf("%s: %v\n", columns[i], value)
fmt.Printf("Type: %s\n", reflect.TypeOf(value))
}
fmt.Println("-----------------------------------")
}
}

// check panics with err when err is non-nil and does nothing otherwise.
func check(err interface{}) {
	if err == nil {
		return
	}
	panic(err)
}
45 changes: 45 additions & 0 deletions row.go
JAicewizard marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package duckdb

/*
#include <duckdb.h>
*/
import "C"

type (
	// Row represents one row in duckdb. It references the vectors underneath
	// instead of holding its own copy of the values.
	Row struct {
		// chunk is the data chunk whose columns this row writes to.
		chunk *DataChunk
		// r is the index of this row, passed along when a value is set.
		r C.idx_t
		// projection maps a column index to its index in chunk.columns.
		// An entry of -1 marks a column that is not projected.
		projection []int
	}
)
JAicewizard marked this conversation as resolved.
Show resolved Hide resolved

// IsProjected reports whether the column at colIdx is projected, that is,
// whether its projection entry is anything other than -1.
func (r Row) IsProjected(colIdx int) bool {
	projected := r.projection[colIdx] != -1
	return projected
}

// SetRowValue sets the value of column colIdx in row to val and returns an
// error when setting the value fails.
// If the column is not projected, nothing is written and nil is returned, no
// matter the type of val.
func SetRowValue[T any](row Row, colIdx int, val T) error {
	target := row.projection[colIdx]
	if target >= 0 && target < len(row.chunk.columns) {
		column := row.chunk.columns[target]
		return setVectorVal(&column, row.r, val)
	}
	// Setting a column that is not projected is allowed on purpose; it is
	// simply a no-op.
	return nil
}

// SetRowValue sets the value of column colIdx in this row to val and returns an
// error when setting the value fails.
// If the column is not projected, nothing is written and nil is returned.
func (row Row) SetRowValue(colIdx int, val any) error {
	if !row.IsProjected(colIdx) {
		// we want to allow setting to columns that are not projected,
		// it should just be a nop.
		return nil
	}
	// The chunk only holds the projected columns, so the column has to be
	// addressed by its projected index, as the generic SetRowValue does.
	return row.chunk.SetValue(row.projection[colIdx], int(row.r), val)
}
14 changes: 3 additions & 11 deletions scalar_udf.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ package duckdb
#include <duckdb.h>

void scalar_udf_callback(duckdb_function_info, duckdb_data_chunk, duckdb_vector);
void scalar_udf_delete_callback(void *);
void udf_delete_callback(void *);

typedef void (*scalar_udf_callback_t)(duckdb_function_info, duckdb_data_chunk, duckdb_vector);
*/
Expand Down Expand Up @@ -125,10 +125,7 @@ func RegisterScalarUDFSet(c *sql.Conn, name string, functions ...ScalarFunc) err
//export scalar_udf_callback
func scalar_udf_callback(function_info C.duckdb_function_info, input C.duckdb_data_chunk, output C.duckdb_vector) {
extraInfo := C.duckdb_scalar_function_get_extra_info(function_info)

// extraInfo is a void* pointer to our pinned ScalarFunc f.
h := *(*cgo.Handle)(unsafe.Pointer(extraInfo))
function := h.Value().(pinnedValue[ScalarFunc]).value
function := getPinned[ScalarFunc](extraInfo)

// Initialize the input chunk.
var inputChunk DataChunk
Expand Down Expand Up @@ -191,11 +188,6 @@ func scalar_udf_callback(function_info C.duckdb_function_info, input C.duckdb_da
}
}

//export scalar_udf_delete_callback
func scalar_udf_delete_callback(info unsafe.Pointer) {
udf_delete_callback(info)
}

func registerInputParams(config ScalarFuncConfig, f C.duckdb_scalar_function) error {
// Set variadic input parameters.
if config.VariadicTypeInfo != nil {
Expand Down Expand Up @@ -285,7 +277,7 @@ func createScalarFunc(name string, f ScalarFunc) (C.duckdb_scalar_function, erro
C.duckdb_scalar_function_set_extra_info(
function,
unsafe.Pointer(&h),
C.duckdb_delete_callback_t(C.scalar_udf_delete_callback))
C.duckdb_delete_callback_t(C.udf_delete_callback))

return function, nil
}
1 change: 1 addition & 0 deletions type.go
taniabogatsch marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,4 @@ var typeToStringMap = map[Type]string{
TYPE_ANY: "ANY",
TYPE_VARINT: "VARINT",
}

Loading