Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Scalar UDF support #222

Merged
merged 45 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
f5aaa94
use duckdb's feature branch
taniabogatsch May 24, 2024
6da3dbd
Re-build static libraries
taniabogatsch May 24, 2024
274e029
trigger tests
taniabogatsch May 24, 2024
0687448
trigger tests
taniabogatsch May 24, 2024
032239d
initial commit towards a data chunk api
taniabogatsch May 27, 2024
a231352
towards exposing an amazing DataChunk
taniabogatsch May 28, 2024
23522f6
Merge branch 'data-chunks' into scalar
taniabogatsch May 31, 2024
4bf7c62
initial scalar UDF support
taniabogatsch May 31, 2024
cc2a104
more initialisation and primitive getter
taniabogatsch May 31, 2024
8781142
Merge branch 'data-chunks' into scalar
taniabogatsch May 31, 2024
b63142c
remove capi example
taniabogatsch May 31, 2024
dc9ec28
Merge branch 'main' into scalar
taniabogatsch Sep 9, 2024
6c48a4a
merge fixes
taniabogatsch Sep 9, 2024
f2ee037
Re-build static libraries
taniabogatsch Sep 9, 2024
6924aae
Merge branch 'type-interface' into scalar
taniabogatsch Sep 11, 2024
8b51aba
update code to current code base
taniabogatsch Sep 11, 2024
905bc5e
changes related to type info
taniabogatsch Sep 11, 2024
f1d079a
trying to get the handle to work
taniabogatsch Sep 11, 2024
49f7b7a
fix simple example
taniabogatsch Sep 12, 2024
1833f7d
test all types in scalar UDFs
taniabogatsch Sep 12, 2024
f6a953f
add remaining tests
taniabogatsch Sep 12, 2024
6d4a794
nit
taniabogatsch Sep 13, 2024
58b5dde
Merge branch 'type-interface' into scalar
taniabogatsch Sep 13, 2024
4e58e16
Merge branch 'main' into scalar
taniabogatsch Sep 17, 2024
c1d0cf7
resolve merge and changes
taniabogatsch Sep 17, 2024
50887d7
remove error from config and add more tests
taniabogatsch Sep 17, 2024
0f3619b
variadic support
taniabogatsch Sep 17, 2024
ab9b9fc
variadic tests
taniabogatsch Sep 17, 2024
cd7dc56
adding ANY
taniabogatsch Sep 17, 2024
31ef67c
add workaround for SQLNULL
taniabogatsch Sep 17, 2024
97f620c
fix SQL NULL in nested types
taniabogatsch Sep 17, 2024
33e4627
support function sets
taniabogatsch Sep 18, 2024
bfef4ee
tidying things up
taniabogatsch Sep 18, 2024
9f8546d
started to add feedback
taniabogatsch Sep 18, 2024
354897a
make executor extensible
taniabogatsch Sep 19, 2024
7fd79ca
review feedback and documentation
taniabogatsch Sep 19, 2024
493f151
tidy tests, add pinner, other nits
taniabogatsch Sep 19, 2024
11cb4ac
nits
taniabogatsch Sep 19, 2024
172a22b
add tests
taniabogatsch Sep 19, 2024
147fdfa
Merge branch 'main' into scalar
taniabogatsch Sep 23, 2024
5e287cd
formatter
taniabogatsch Sep 23, 2024
5b389f0
uuid cast
taniabogatsch Sep 23, 2024
7906697
create udf utility file
taniabogatsch Sep 23, 2024
847098b
Merge branch 'main' into scalar
taniabogatsch Sep 23, 2024
649f3c9
nit
taniabogatsch Sep 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions data_chunk.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ func (chunk *DataChunk) GetValue(colIdx int, rowIdx int) (any, error) {
return nil, getError(errAPI, columnCountError(colIdx, len(chunk.columns)))
}
column := &chunk.columns[colIdx]
if column.isSQLNull {
return nil, nil
}
return column.getFn(column, C.idx_t(rowIdx)), nil
}

Expand All @@ -58,6 +61,10 @@ func (chunk *DataChunk) SetValue(colIdx int, rowIdx int, val any) error {
}
column := &chunk.columns[colIdx]

if column.isSQLNull {
return getError(errAPI, errSetSQLNULLValue)
}

// Ensure that the types match before attempting to set anything.
// This is done to prevent failures 'halfway through' writing column values,
// potentially corrupting data in that column.
Expand Down Expand Up @@ -126,6 +133,23 @@ func (chunk *DataChunk) initFromDuckDataChunk(data C.duckdb_data_chunk, writable
return err
}

// initFromDuckVector initializes the chunk as a single-column chunk backed by duckdbVector.
// The writable flag is forwarded to the vector initialization.
// It returns an error if the callback functions for the vector's logical type cannot be initialized.
func (chunk *DataChunk) initFromDuckVector(duckdbVector C.duckdb_vector, writable bool) error {
	// A lone vector is exposed as a chunk with exactly one column.
	chunk.columns = make([]vector, 1)
	column := &chunk.columns[0]

	// Initialize the callback functions to read and write values.
	// The logical type is only needed for this step, so it is destroyed
	// before inspecting the result of the initialization.
	logicalType := C.duckdb_vector_get_column_type(duckdbVector)
	initErr := column.init(logicalType, 0)
	C.duckdb_destroy_logical_type(&logicalType)
	if initErr != nil {
		return initErr
	}

	// Initialize the vector and its child vectors.
	column.initVectors(duckdbVector, writable)
	return nil
}

// close destroys the DuckDB data chunk referenced by chunk.data.
func (chunk *DataChunk) close() {
	C.duckdb_destroy_data_chunk(&chunk.data)
}
31 changes: 24 additions & 7 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,19 +82,36 @@ var (

errUnsupportedMapKeyType = errors.New("MAP key type not supported")

errAppenderInvalidCon = errors.New("could not create appender: not a DuckDB driver connection")
errAppenderClosedCon = errors.New("could not create appender: appender creation on a closed connection")
errAppenderCreation = errors.New("could not create appender")
errAppenderDoubleClose = errors.New("could not close appender: already closed")
errAppenderCreation = errors.New("could not create appender")
errAppenderInvalidCon = fmt.Errorf("%w: not a DuckDB driver connection", errAppenderCreation)
errAppenderClosedCon = fmt.Errorf("%w: appender creation on a closed connection", errAppenderCreation)

errAppenderClose = errors.New("could not close appender")
errAppenderDoubleClose = fmt.Errorf("%w: already closed", errAppenderClose)

errAppenderAppendRow = errors.New("could not append row")
errAppenderAppendAfterClose = errors.New("could not append row: appender already closed")
errAppenderClose = errors.New("could not close appender")
errAppenderFlush = errors.New("could not flush appender")
errAppenderAppendAfterClose = fmt.Errorf("%w: appender already closed", errAppenderAppendRow)

errAppenderFlush = errors.New("could not flush appender")

errEmptyName = errors.New("empty name")
errInvalidDecimalWidth = fmt.Errorf("the DECIMAL with must be between 1 and %d", MAX_DECIMAL_WIDTH)
errInvalidDecimalScale = errors.New("the DECIMAL scale must be less than or equal to the width")

errScalarUDFCreate = errors.New("could not create scalar UDF")
errScalarUDFNoName = fmt.Errorf("%w: missing name", errScalarUDFCreate)
errScalarUDFIsNil = fmt.Errorf("%w: function is nil", errScalarUDFCreate)
errScalarUDFNoExecutor = fmt.Errorf("%w: executor is nil", errScalarUDFCreate)
errScalarUDFNilInputTypes = fmt.Errorf("%w: input types are nil", errScalarUDFCreate)
errScalarUDFEmptyInputTypes = fmt.Errorf("%w: empty input types", errScalarUDFCreate)
errScalarUDFInputTypeIsNil = fmt.Errorf("%w: input type is nil", errScalarUDFCreate)
errScalarUDFResultTypeIsNil = fmt.Errorf("%w: result type is nil", errScalarUDFCreate)
errScalarUDFResultTypeIsANY = fmt.Errorf("%w: result type is ANY", errScalarUDFCreate)
errScalarUDFCreateSet = fmt.Errorf("could not create scalar UDF set")
errScalarUDFAddToSet = fmt.Errorf("%w: cannot add to set", errScalarUDFCreateSet)

errSetSQLNULLValue = errors.New("cannot write to a NULL column")

// Errors not covered in tests.
errConnect = errors.New("could not connect to database")
errCreateConfig = errors.New("could not create config for database")
Expand Down
247 changes: 247 additions & 0 deletions scalar_udf.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
package duckdb

// Related issues: https://golang.org/issue/19835, https://golang.org/issue/19837.

/*
#include <duckdb.h>

void scalar_udf_callback(duckdb_function_info, duckdb_data_chunk, duckdb_vector);
void scalar_udf_delete_callback(void *);

typedef void (*scalar_udf_callback_t)(duckdb_function_info, duckdb_data_chunk, duckdb_vector);
*/
import "C"

import (
"database/sql"
"database/sql/driver"
"runtime/cgo"
"unsafe"
)

// rowFn is the signature of a row-wise scalar UDF implementation.
// It receives one value per input column of the current row and returns the
// result value for that row, or an error to abort the execution.
type rowFn func(args []driver.Value) (any, error)

type ScalarFuncConfig struct {
InputTypeInfos []TypeInfo
ResultTypeInfo TypeInfo

VariadicTypeInfo *TypeInfo
taniabogatsch marked this conversation as resolved.
Show resolved Hide resolved
Volatile bool
SpecialNullHandling bool
}

// ScalarFuncExecutor holds the callbacks used to execute a user-defined scalar function.
type ScalarFuncExecutor struct {
	// RowExecutor is invoked once per input row.
	// It must not be nil, otherwise the scalar function cannot be created.
	RowExecutor rowFn
}

// ScalarFunc is the interface a user-defined scalar function must implement
// to be registered via RegisterScalarUDF or RegisterScalarUDFSet.
type ScalarFunc interface {
	// Config returns the configuration (input types, result type, and flags)
	// that is read when the function is created.
	Config() ScalarFuncConfig
	// Executor returns the callbacks that execute the function.
	// It is called when creating the function and on each execution callback.
	Executor() ScalarFuncExecutor
}

// setFuncError reports msg as the error of the scalar function execution
// identified by function_info. The temporary C copy of msg is released
// before returning.
func setFuncError(function_info C.duckdb_function_info, msg string) {
	cMsg := C.CString(msg)
	defer C.duckdb_free(unsafe.Pointer(cMsg))

	C.duckdb_scalar_function_set_error(function_info, cMsg)
}

//export scalar_udf_callback
func scalar_udf_callback(function_info C.duckdb_function_info, input C.duckdb_data_chunk, output C.duckdb_vector) {
extraInfo := C.duckdb_scalar_function_get_extra_info(function_info)

// extraInfo is a void* pointer to our ScalarFunc.
h := *(*cgo.Handle)(unsafe.Pointer(extraInfo))
function := h.Value().(ScalarFunc)

// Initialize the input chunk.
var inputChunk DataChunk
if err := inputChunk.initFromDuckDataChunk(input, false); err != nil {
setFuncError(function_info, getError(errAPI, err).Error())
return
}

// Initialize the output chunk.
var outputChunk DataChunk
if err := outputChunk.initFromDuckVector(output, true); err != nil {
setFuncError(function_info, getError(errAPI, err).Error())
return
}

// Execute the user-defined scalar function for each row.
executor := function.Executor()
values := make([]driver.Value, len(inputChunk.columns))
columnCount := len(values)

var err error
for rowIdx := 0; rowIdx < inputChunk.GetSize(); rowIdx++ {
taniabogatsch marked this conversation as resolved.
Show resolved Hide resolved
// Set the values for each row.
for colIdx := 0; colIdx < columnCount; colIdx++ {
if values[colIdx], err = inputChunk.GetValue(colIdx, rowIdx); err != nil {
setFuncError(function_info, getError(errAPI, err).Error())
return
}
}

// Execute the function and write the result to the output vector.
var val any
if val, err = executor.RowExecutor(values); err != nil {
break
}
if err = outputChunk.SetValue(0, rowIdx, val); err != nil {
break
}
}

if err != nil {
setFuncError(function_info, getError(errAPI, err).Error())
}
}

//export scalar_udf_delete_callback
func scalar_udf_delete_callback(extraInfo unsafe.Pointer) {
h := (*cgo.Handle)(extraInfo)
h.Delete()
}

taniabogatsch marked this conversation as resolved.
Show resolved Hide resolved
func registerInputParams(config ScalarFuncConfig, f C.duckdb_scalar_function) error {
// Set variadic input parameters.
if config.VariadicTypeInfo != nil {
t := (*config.VariadicTypeInfo).logicalType()
C.duckdb_scalar_function_set_varargs(f, t)
C.duckdb_destroy_logical_type(&t)
return nil
}
taniabogatsch marked this conversation as resolved.
Show resolved Hide resolved

// Set normal input parameters.
if config.InputTypeInfos == nil {
return errScalarUDFNilInputTypes
}
if len(config.InputTypeInfos) == 0 {
return errScalarUDFEmptyInputTypes
}

for i, info := range config.InputTypeInfos {
if info == nil {
return addIndexToError(errScalarUDFInputTypeIsNil, i)
}
t := info.logicalType()
C.duckdb_scalar_function_add_parameter(f, t)
C.duckdb_destroy_logical_type(&t)
}
return nil
}

func registerResultParams(config ScalarFuncConfig, f C.duckdb_scalar_function) error {
if config.ResultTypeInfo == nil {
return errScalarUDFResultTypeIsNil
}
if config.ResultTypeInfo.InternalType() == TYPE_ANY {
return errScalarUDFResultTypeIsANY
}
taniabogatsch marked this conversation as resolved.
Show resolved Hide resolved
t := config.ResultTypeInfo.logicalType()
C.duckdb_scalar_function_set_return_type(f, t)
C.duckdb_destroy_logical_type(&t)
return nil
}

func createScalarFunc(name string, f ScalarFunc) (C.duckdb_scalar_function, error) {
if name == "" {
return nil, errScalarUDFNoName
}
if f == nil {
return nil, errScalarUDFIsNil
}
if f.Executor().RowExecutor == nil {
return nil, errScalarUDFNoExecutor
}

function := C.duckdb_create_scalar_function()

// Set the name.
cName := C.CString(name)
C.duckdb_scalar_function_set_name(function, cName)
C.duckdb_free(unsafe.Pointer(cName))

// Configure the scalar function.
config := f.Config()
if err := registerInputParams(config, function); err != nil {
return nil, err
}
if err := registerResultParams(config, function); err != nil {
return nil, err
}
if config.SpecialNullHandling {
C.duckdb_scalar_function_set_special_handling(function)
}
if config.Volatile {
C.duckdb_scalar_function_set_volatile(function)
}

// Set the function callback.
C.duckdb_scalar_function_set_function(function, C.scalar_udf_callback_t(C.scalar_udf_callback))

// Set data available during execution.
h := cgo.NewHandle(f)
C.duckdb_scalar_function_set_extra_info(
function,
unsafe.Pointer(&h),
C.duckdb_delete_callback_t(C.scalar_udf_delete_callback))

return function, nil
}

// RegisterScalarUDF registers a scalar user-defined function called name on
// the connection c.
// The function takes ownership of f, so you must pass it as a pointer.
// It returns an error if the function cannot be created from f, if c does not
// expose a DuckDB driver connection, or if DuckDB rejects the registration.
func RegisterScalarUDF(c *sql.Conn, name string, f ScalarFunc) error {
	function, err := createScalarFunc(name, f)
	if err != nil {
		return getError(errAPI, err)
	}
	// Release our reference on every path, including the case in which
	// c.Raw fails before invoking the callback below.
	defer C.duckdb_destroy_scalar_function(&function)

	// Register the function on the underlying driver connection exposed by c.Raw.
	return c.Raw(func(driverConn any) error {
		con, ok := driverConn.(*conn)
		if !ok {
			// Not a DuckDB driver connection.
			return getError(errAPI, errScalarUDFCreate)
		}
		if C.duckdb_register_scalar_function(con.duckdbCon, function) == C.DuckDBError {
			return getError(errAPI, errScalarUDFCreate)
		}
		return nil
	})
}

func RegisterScalarUDFSet(c *sql.Conn, name string, functions ...ScalarFunc) error {
taniabogatsch marked this conversation as resolved.
Show resolved Hide resolved
cName := C.CString(name)
set := C.duckdb_create_scalar_function_set(cName)
C.duckdb_free(unsafe.Pointer(cName))

// Create each function and add it to the set.
for i, f := range functions {
function, err := createScalarFunc(name, f)
if err != nil {
C.duckdb_destroy_scalar_function(&function)
C.duckdb_destroy_scalar_function_set(&set)
return getError(errAPI, err)
}

state := C.duckdb_add_scalar_function_to_set(set, function)
C.duckdb_destroy_scalar_function(&function)
if state == C.DuckDBError {
C.duckdb_destroy_scalar_function_set(&set)
return getError(errAPI, addIndexToError(errScalarUDFAddToSet, i))
}
}

// Register the function set on the underlying driver connection exposed by c.Raw.
err := c.Raw(func(driverConn any) error {
con := driverConn.(*conn)
state := C.duckdb_register_scalar_function_set(con.duckdbCon, set)
C.duckdb_destroy_scalar_function_set(&set)
if state == C.DuckDBError {
return getError(errAPI, errScalarUDFCreateSet)
}
return nil
})
return err
}
Loading