Skip to content

Commit

Permalink
tidying things up
Browse files Browse the repository at this point in the history
  • Loading branch information
taniabogatsch committed Sep 18, 2024
1 parent 33e4627 commit bfef4ee
Show file tree
Hide file tree
Showing 2 changed files with 208 additions and 263 deletions.
125 changes: 65 additions & 60 deletions scalar_udf.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,65 +19,69 @@ import (
"unsafe"
)

type ScalarFunctionConfig interface {
type ScalarFuncConfig interface {
InputTypeInfos() []TypeInfo
ResultTypeInfo() TypeInfo
}

type ScalarFuncExtraInfo interface {
VariadicTypeInfo() TypeInfo
Volatile() bool
SpecialNullHandling() bool
}

type ScalarFunction interface {
Config() ScalarFunctionConfig
type ScalarFunc interface {
Config() ScalarFuncConfig
ExtraInfo() ScalarFuncExtraInfo
ExecuteRow(args []driver.Value) (any, error)
}

func setFunctionError(info C.duckdb_function_info, msg string) {
func setFuncError(function_info C.duckdb_function_info, msg string) {
err := C.CString(msg)
C.duckdb_scalar_function_set_error(info, err)
C.duckdb_scalar_function_set_error(function_info, err)
C.duckdb_free(unsafe.Pointer(err))
}

//export scalar_udf_callback
func scalar_udf_callback(info C.duckdb_function_info, input C.duckdb_data_chunk, output C.duckdb_vector) {
extraInfo := C.duckdb_scalar_function_get_extra_info(info)
func scalar_udf_callback(function_info C.duckdb_function_info, input C.duckdb_data_chunk, output C.duckdb_vector) {
extraInfo := C.duckdb_scalar_function_get_extra_info(function_info)

// extraInfo is a void* pointer to our ScalarFunction.
// extraInfo is a void* pointer to our ScalarFunc.
h := *(*cgo.Handle)(unsafe.Pointer(extraInfo))
scalarFunction := h.Value().(ScalarFunction)
function := h.Value().(ScalarFunc)

// Initialize the input chunk.
var inputChunk DataChunk
if err := inputChunk.initFromDuckDataChunk(input, false); err != nil {
setFunctionError(info, getError(errAPI, err).Error())
setFuncError(function_info, getError(errAPI, err).Error())
return
}

// Initialize the output chunk.
var outputChunk DataChunk
if err := outputChunk.initFromDuckVector(output, true); err != nil {
setFunctionError(info, getError(errAPI, err).Error())
setFuncError(function_info, getError(errAPI, err).Error())
return
}

// Execute the user-defined scalar function for each row.
args := make([]driver.Value, len(inputChunk.columns))
values := make([]driver.Value, len(inputChunk.columns))
rowCount := inputChunk.GetSize()
columnCount := len(args)
columnCount := len(values)
var err error

for rowIdx := 0; rowIdx < rowCount; rowIdx++ {
// Set the input arguments for each column of a row.
// Set the values for each row.
for colIdx := 0; colIdx < columnCount; colIdx++ {
if args[colIdx], err = inputChunk.GetValue(colIdx, rowIdx); err != nil {
setFunctionError(info, getError(errAPI, err).Error())
if values[colIdx], err = inputChunk.GetValue(colIdx, rowIdx); err != nil {
setFuncError(function_info, getError(errAPI, err).Error())
return
}
}

// Execute the function and write the result to the output vector.
var val any
if val, err = scalarFunction.ExecuteRow(args); err != nil {
if val, err = function.ExecuteRow(values); err != nil {
break
}
if err = outputChunk.SetValue(0, rowIdx, val); err != nil {
Expand All @@ -86,7 +90,7 @@ func scalar_udf_callback(info C.duckdb_function_info, input C.duckdb_data_chunk,
}

if err != nil {
setFunctionError(info, getError(errAPI, err).Error())
setFuncError(function_info, getError(errAPI, err).Error())
}
}

Expand All @@ -96,102 +100,103 @@ func scalar_udf_delete_callback(extraInfo unsafe.Pointer) {
h.Delete()
}

func registerInputParameters(config ScalarFunctionConfig, scalarFunction C.duckdb_scalar_function) error {
func registerInputParams(config ScalarFuncConfig, extraInfo ScalarFuncExtraInfo, f C.duckdb_scalar_function) error {
// Set variadic input parameters.
if config.VariadicTypeInfo() != nil {
logicalType := config.VariadicTypeInfo().logicalType()
C.duckdb_scalar_function_set_varargs(scalarFunction, logicalType)
C.duckdb_destroy_logical_type(&logicalType)
if extraInfo != nil && extraInfo.VariadicTypeInfo() != nil {
t := extraInfo.VariadicTypeInfo().logicalType()
C.duckdb_scalar_function_set_varargs(f, t)
C.duckdb_destroy_logical_type(&t)
return nil
}

// Set fixed input parameters.
// Set normal input parameters.
if config.InputTypeInfos() == nil {
return errScalarUDFNilInputTypes
}
if len(config.InputTypeInfos()) == 0 {
return errScalarUDFEmptyInputTypes
}

for i, inputTypeInfo := range config.InputTypeInfos() {
if inputTypeInfo == nil {
for i, info := range config.InputTypeInfos() {
if info == nil {
return addIndexToError(errScalarUDFInputTypeIsNil, i)
}
logicalType := inputTypeInfo.logicalType()
C.duckdb_scalar_function_add_parameter(scalarFunction, logicalType)
C.duckdb_destroy_logical_type(&logicalType)
t := info.logicalType()
C.duckdb_scalar_function_add_parameter(f, t)
C.duckdb_destroy_logical_type(&t)
}
return nil
}

func registerResultParameters(config ScalarFunctionConfig, scalarFunction C.duckdb_scalar_function) error {
func registerResultParams(config ScalarFuncConfig, f C.duckdb_scalar_function) error {
if config.ResultTypeInfo() == nil {
return errScalarUDFResultTypeIsNil
}
if config.ResultTypeInfo().InternalType() == TYPE_ANY {
return errScalarUDFResultTypeIsANY
}
logicalType := config.ResultTypeInfo().logicalType()
C.duckdb_scalar_function_set_return_type(scalarFunction, logicalType)
C.duckdb_destroy_logical_type(&logicalType)
t := config.ResultTypeInfo().logicalType()
C.duckdb_scalar_function_set_return_type(f, t)
C.duckdb_destroy_logical_type(&t)
return nil
}

func createScalarFunction(name string, f ScalarFunction) (C.duckdb_scalar_function, error) {
func createScalarFunc(name string, f ScalarFunc) (C.duckdb_scalar_function, error) {
if name == "" {
return nil, errScalarUDFNoName
}
if f == nil {
return nil, errScalarUDFIsNil
}
scalarFunction := C.duckdb_create_scalar_function()
function := C.duckdb_create_scalar_function()

// Set the name.
functionName := C.CString(name)
C.duckdb_scalar_function_set_name(scalarFunction, functionName)
C.duckdb_free(unsafe.Pointer(functionName))
cName := C.CString(name)
C.duckdb_scalar_function_set_name(function, cName)
C.duckdb_free(unsafe.Pointer(cName))

// Configure the scalar function.
config := f.Config()
if err := registerInputParameters(config, scalarFunction); err != nil {
extraInfo := f.ExtraInfo()
if err := registerInputParams(config, extraInfo, function); err != nil {
return nil, err
}
if err := registerResultParameters(config, scalarFunction); err != nil {
if err := registerResultParams(config, function); err != nil {
return nil, err
}
if config.SpecialNullHandling() {
C.duckdb_scalar_function_set_special_handling(scalarFunction)
if extraInfo != nil && extraInfo.SpecialNullHandling() {
C.duckdb_scalar_function_set_special_handling(function)
}
if config.Volatile() {
C.duckdb_scalar_function_set_volatile(scalarFunction)
if extraInfo != nil && extraInfo.Volatile() {
C.duckdb_scalar_function_set_volatile(function)
}

// Set the function callback.
C.duckdb_scalar_function_set_function(scalarFunction, C.scalar_udf_callback_t(C.scalar_udf_callback))
C.duckdb_scalar_function_set_function(function, C.scalar_udf_callback_t(C.scalar_udf_callback))

// Set data available during execution.
extraInfoHandle := cgo.NewHandle(f)
h := cgo.NewHandle(f)
C.duckdb_scalar_function_set_extra_info(
scalarFunction,
unsafe.Pointer(&extraInfoHandle),
function,
unsafe.Pointer(&h),
C.duckdb_delete_callback_t(C.scalar_udf_delete_callback))

return scalarFunction, nil
return function, nil
}

// RegisterScalarUDF registers a scalar UDF.
// This function takes ownership of f, so you must pass it as a pointer.
func RegisterScalarUDF(c *sql.Conn, name string, f ScalarFunction) error {
scalarFunction, err := createScalarFunction(name, f)
// RegisterScalarUDF registers a scalar user-defined function.
// The function takes ownership of f, so you must pass it as a pointer.
func RegisterScalarUDF(c *sql.Conn, name string, f ScalarFunc) error {
scalarFunc, err := createScalarFunc(name, f)
if err != nil {
return getError(errAPI, err)
}

// Register the function on the underlying driver connection exposed by c.Raw.
err = c.Raw(func(driverConn any) error {
con := driverConn.(*conn)
state := C.duckdb_register_scalar_function(con.duckdbCon, scalarFunction)
C.duckdb_destroy_scalar_function(&scalarFunction)
state := C.duckdb_register_scalar_function(con.duckdbCon, scalarFunc)
C.duckdb_destroy_scalar_function(&scalarFunc)
if state == C.DuckDBError {
return getError(errAPI, errScalarUDFCreate)
}
Expand All @@ -200,14 +205,14 @@ func RegisterScalarUDF(c *sql.Conn, name string, f ScalarFunction) error {
return err
}

func RegisterScalarUDFSet(c *sql.Conn, name string, functions ...ScalarFunction) error {
functionName := C.CString(name)
set := C.duckdb_create_scalar_function_set(functionName)
C.duckdb_free(unsafe.Pointer(functionName))
func RegisterScalarUDFSet(c *sql.Conn, name string, functions ...ScalarFunc) error {
cName := C.CString(name)
set := C.duckdb_create_scalar_function_set(cName)
C.duckdb_free(unsafe.Pointer(cName))

// Create each function and add it to the set.
for i, f := range functions {
scalarFunction, err := createScalarFunction(name, f)
scalarFunction, err := createScalarFunc(name, f)
if err != nil {
C.duckdb_destroy_scalar_function(&scalarFunction)
C.duckdb_destroy_scalar_function_set(&set)
Expand Down
Loading

0 comments on commit bfef4ee

Please sign in to comment.