Skip to content

Commit

Permalink
review feedback and documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
taniabogatsch committed Sep 19, 2024
1 parent 354897a commit 7fd79ca
Show file tree
Hide file tree
Showing 4 changed files with 211 additions and 141 deletions.
14 changes: 7 additions & 7 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,15 +98,15 @@ var (
errInvalidDecimalWidth = fmt.Errorf("the DECIMAL with must be between 1 and %d", MAX_DECIMAL_WIDTH)
errInvalidDecimalScale = errors.New("the DECIMAL scale must be less than or equal to the width")

errScalarUDFCreate = errors.New("could not create scalar UDF")
errScalarUDFNoName = fmt.Errorf("%w: missing name", errScalarUDFCreate)
errScalarUDFIsNil = fmt.Errorf("%w: function is nil", errScalarUDFCreate)
errScalarUDFNoExecutor = fmt.Errorf("%w: executor is nil", errScalarUDFCreate)
errScalarUDFNilInputTypes = fmt.Errorf("%w: input types are nil", errScalarUDFCreate)
errScalarUDFEmptyInputTypes = fmt.Errorf("%w: empty input types", errScalarUDFCreate)
errScalarUDFCreate = errors.New("could not create scalar UDF")
errScalarUDFNoName = fmt.Errorf("%w: missing name", errScalarUDFCreate)
errScalarUDFIsNil = fmt.Errorf("%w: function is nil", errScalarUDFCreate)
errScalarUDFNoExecutor = fmt.Errorf("%w: executor is nil", errScalarUDFCreate)
//errScalarUDFNilInputTypes = fmt.Errorf("%w: input types are nil", errScalarUDFCreate)
//errScalarUDFEmptyInputTypes = fmt.Errorf("%w: empty input types", errScalarUDFCreate)
errScalarUDFInputTypeIsNil = fmt.Errorf("%w: input type is nil", errScalarUDFCreate)
errScalarUDFResultTypeIsNil = fmt.Errorf("%w: result type is nil", errScalarUDFCreate)
errScalarUDFResultTypeIsANY = fmt.Errorf("%w: result type is ANY", errScalarUDFCreate)
errScalarUDFResultTypeIsANY = fmt.Errorf("%w: result type is ANY, which is not supported", errScalarUDFCreate)
errScalarUDFCreateSet = fmt.Errorf("could not create scalar UDF set")
errScalarUDFAddToSet = fmt.Errorf("%w: cannot add to set", errScalarUDFCreateSet)

Expand Down
5 changes: 5 additions & 0 deletions examples/scalar_udf/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package main

// TODO: overload the my_len function with LIST(ANY) and VARCHAR

// TODO: mix of variadic and configurable parameters
194 changes: 120 additions & 74 deletions scalar_udf.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,111 @@ import (
"unsafe"
)

// rowFn is the signature of a row-based execution callback.
// args holds the input values, and the callback returns the result value, or an error.
type rowFn func(args []driver.Value) (any, error)

// ScalarFuncConfig contains the fields to configure a user-defined scalar function.
type ScalarFuncConfig struct {
// InputTypeInfos contains the Type information for each input parameter of the scalar function.
InputTypeInfos []TypeInfo
// ResultTypeInfo holds the Type information of the scalar function's result type.
ResultTypeInfo TypeInfo

VariadicTypeInfo *TypeInfo
Volatile bool
// VariadicTypeInfo configures the number of input parameters.
// If this field is nil, then the input parameters match InputTypeInfos.
// Otherwise, the scalar function's input parameters are set to variadic, allowing any number of input parameters.
// The Type of the first len(InputTypeInfos) parameters is configured by InputTypeInfos, and all
// remaining parameters must match the variadic Type. To accept variadic parameters of different Types,
// set the VariadicTypeInfo's Type to TYPE_ANY.
VariadicTypeInfo TypeInfo
// Volatile sets the stability of the scalar function to volatile, if true.
// Volatile scalar functions might create a different result per row.
// E.g., RANDOM() is a volatile scalar function.
Volatile bool
// SpecialNullHandling disables the default NULL handling of scalar functions, if true.
// The default NULL handling is NULL in, NULL out. I.e., if any input parameter is NULL, then the result is NULL,
// and the executor is not invoked for that row.
// If true, NULL input values are passed to the executor as nil.
SpecialNullHandling bool
}

// ScalarFuncExecutor contains the callback function to execute a user-defined scalar function.
// Currently, its only field is a row-based executor.
type ScalarFuncExecutor struct {
RowExecutor rowFn
// RowExecutor accepts a row-based execution function.
// args contains the input values of a single row, one value per input column.
// It returns the row execution result, or an error.
RowExecutor func(args []driver.Value) (any, error)
}

// ScalarFunc is the user-defined scalar function interface.
// Any scalar function must implement a Config function, and an Executor function.
// RegisterScalarUDF and RegisterScalarUDFSet accept implementations of this interface.
type ScalarFunc interface {
// Config returns the ScalarFuncConfig to configure the scalar function.
Config() ScalarFuncConfig
// Executor returns the ScalarFuncExecutor to execute the scalar function.
Executor() ScalarFuncExecutor
}

// RegisterScalarUDF registers a user-defined scalar function.
// c is the SQL connection on which to register the scalar function.
// name is the function name, and f is the scalar function's interface.
// RegisterScalarUDF takes ownership of f, so you must pass it as a pointer.
// It returns an error if the function cannot be created, if the underlying
// driver connection cannot be accessed, or if the registration fails.
func RegisterScalarUDF(c *sql.Conn, name string, f ScalarFunc) error {
	function, err := createScalarFunc(name, f)
	if err != nil {
		return getError(errAPI, err)
	}
	// Destroy the function handle on every path. c.Raw can fail before it invokes
	// the callback (e.g., on a closed connection), so destroying the handle only
	// inside the callback would leak it.
	defer C.duckdb_destroy_scalar_function(&function)

	// Register the function on the underlying driver connection exposed by c.Raw.
	return c.Raw(func(driverConn any) error {
		con := driverConn.(*conn)
		if state := C.duckdb_register_scalar_function(con.duckdbCon, function); state == C.DuckDBError {
			return getError(errAPI, errScalarUDFCreate)
		}
		return nil
	})
}

// RegisterScalarUDFSet registers a set of user-defined scalar functions with the same name.
// This allows overloading of scalar functions.
// E.g., it allows overloading the function my_length() with different implementations
// like my_length(LIST(ANY)) and my_length(VARCHAR).
// c is the SQL connection on which to register the scalar function set.
// name is the function name of each function in the set.
// functions contains all functions of the scalar function set.
// It returns an error if any function cannot be created or added to the set, if the
// underlying driver connection cannot be accessed, or if the registration fails.
func RegisterScalarUDFSet(c *sql.Conn, name string, functions ...ScalarFunc) error {
	cName := C.CString(name)
	set := C.duckdb_create_scalar_function_set(cName)
	C.duckdb_free(unsafe.Pointer(cName))
	// Destroy the set on every path, including the case where c.Raw fails
	// before it invokes the callback (e.g., on a closed connection).
	defer C.duckdb_destroy_scalar_function_set(&set)

	// Create each function and add it to the set.
	for i, f := range functions {
		function, err := createScalarFunc(name, f)
		if err != nil {
			C.duckdb_destroy_scalar_function(&function)
			return getError(errAPI, err)
		}

		// The function handle is no longer needed after the add attempt,
		// so destroy it immediately instead of deferring inside the loop.
		state := C.duckdb_add_scalar_function_to_set(set, function)
		C.duckdb_destroy_scalar_function(&function)
		if state == C.DuckDBError {
			return getError(errAPI, addIndexToError(errScalarUDFAddToSet, i))
		}
	}

	// Register the function set on the underlying driver connection exposed by c.Raw.
	return c.Raw(func(driverConn any) error {
		con := driverConn.(*conn)
		if state := C.duckdb_register_scalar_function_set(con.duckdbCon, set); state == C.DuckDBError {
			return getError(errAPI, errScalarUDFCreateSet)
		}
		return nil
	})
}

func setFuncError(function_info C.duckdb_function_info, msg string) {
err := C.CString(msg)
C.duckdb_scalar_function_set_error(function_info, err)
Expand Down Expand Up @@ -69,32 +154,50 @@ func scalar_udf_callback(function_info C.duckdb_function_info, input C.duckdb_da

// Execute the user-defined scalar function for each row.
executor := function.Executor()
nullInNullOut := !function.Config().SpecialNullHandling
values := make([]driver.Value, len(inputChunk.columns))
columnCount := len(values)
rowCount := inputChunk.GetSize()

// Set the values for each row by invoking the callback function.
var err error
for rowIdx := 0; rowIdx < inputChunk.GetSize(); rowIdx++ {
// Set the values for each row.
for rowIdx := 0; rowIdx < rowCount; rowIdx++ {
nullRow := false

// Get each column value.
for colIdx := 0; colIdx < columnCount; colIdx++ {
if values[colIdx], err = inputChunk.GetValue(colIdx, rowIdx); err != nil {
setFuncError(function_info, getError(errAPI, err).Error())
return
}

// NULL handling.
if nullInNullOut && values[colIdx] == nil {
if err = outputChunk.SetValue(0, rowIdx, nil); err != nil {
setFuncError(function_info, getError(errAPI, err).Error())
return
}
nullRow = true
break
}
}
if nullRow {
continue
}

// Execute the function and write the result to the output vector.
// Execute the function.
var val any
if val, err = executor.RowExecutor(values); err != nil {
break
setFuncError(function_info, getError(errAPI, err).Error())
return
}

// Write the result to the output chunk.
if err = outputChunk.SetValue(0, rowIdx, val); err != nil {
break
setFuncError(function_info, getError(errAPI, err).Error())
return
}
}

if err != nil {
setFuncError(function_info, getError(errAPI, err).Error())
}
}

//export scalar_udf_delete_callback
Expand All @@ -106,18 +209,17 @@ func scalar_udf_delete_callback(extraInfo unsafe.Pointer) {
func registerInputParams(config ScalarFuncConfig, f C.duckdb_scalar_function) error {
// Set variadic input parameters.
if config.VariadicTypeInfo != nil {
t := (*config.VariadicTypeInfo).logicalType()
t := config.VariadicTypeInfo.logicalType()
C.duckdb_scalar_function_set_varargs(f, t)
C.duckdb_destroy_logical_type(&t)
return nil
}

// Set normal input parameters.
if config.InputTypeInfos == nil {
return errScalarUDFNilInputTypes
return nil
}
if len(config.InputTypeInfos) == 0 {
return errScalarUDFEmptyInputTypes
return nil
}

for i, info := range config.InputTypeInfos {
Expand Down Expand Up @@ -189,59 +291,3 @@ func createScalarFunc(name string, f ScalarFunc) (C.duckdb_scalar_function, erro

return function, nil
}

// RegisterScalarUDF registers a scalar user-defined function.
// c is the SQL connection on which to register the function, and name is the function name.
// The function takes ownership of f, so you must pass it as a pointer.
// It returns an error if the function cannot be created, if the underlying
// driver connection cannot be accessed, or if the registration fails.
func RegisterScalarUDF(c *sql.Conn, name string, f ScalarFunc) error {
	function, err := createScalarFunc(name, f)
	if err != nil {
		return getError(errAPI, err)
	}
	// Destroy the function handle on every path. c.Raw can fail before it invokes
	// the callback (e.g., on a closed connection), so destroying the handle only
	// inside the callback would leak it.
	defer C.duckdb_destroy_scalar_function(&function)

	// Register the function on the underlying driver connection exposed by c.Raw.
	return c.Raw(func(driverConn any) error {
		con := driverConn.(*conn)
		if state := C.duckdb_register_scalar_function(con.duckdbCon, function); state == C.DuckDBError {
			return getError(errAPI, errScalarUDFCreate)
		}
		return nil
	})
}

// RegisterScalarUDFSet registers a set of scalar user-defined functions that share one name.
// c is the SQL connection on which to register the set, name is the name of each
// function in the set, and functions contains the functions to add to the set.
// It returns an error if any function cannot be created or added to the set, if the
// underlying driver connection cannot be accessed, or if the registration fails.
func RegisterScalarUDFSet(c *sql.Conn, name string, functions ...ScalarFunc) error {
	cName := C.CString(name)
	set := C.duckdb_create_scalar_function_set(cName)
	C.duckdb_free(unsafe.Pointer(cName))
	// Destroy the set on every path, including the case where c.Raw fails
	// before it invokes the callback (e.g., on a closed connection).
	defer C.duckdb_destroy_scalar_function_set(&set)

	// Create each function and add it to the set.
	for i, f := range functions {
		function, err := createScalarFunc(name, f)
		if err != nil {
			C.duckdb_destroy_scalar_function(&function)
			return getError(errAPI, err)
		}

		// The function handle is no longer needed after the add attempt,
		// so destroy it immediately instead of deferring inside the loop.
		state := C.duckdb_add_scalar_function_to_set(set, function)
		C.duckdb_destroy_scalar_function(&function)
		if state == C.DuckDBError {
			return getError(errAPI, addIndexToError(errScalarUDFAddToSet, i))
		}
	}

	// Register the function set on the underlying driver connection exposed by c.Raw.
	return c.Raw(func(driverConn any) error {
		con := driverConn.(*conn)
		if state := C.duckdb_register_scalar_function_set(con.duckdbCon, set); state == C.DuckDBError {
			return getError(errAPI, errScalarUDFCreateSet)
		}
		return nil
	})
}
Loading

0 comments on commit 7fd79ca

Please sign in to comment.