Skip to content

Commit

Permalink
support function sets
Browse files Browse the repository at this point in the history
  • Loading branch information
taniabogatsch committed Sep 18, 2024
1 parent 97f620c commit 33e4627
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 35 deletions.
2 changes: 2 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ var (
errScalarUDFInputTypeIsNil = fmt.Errorf("%w: input type is nil", errScalarUDFCreate)
errScalarUDFResultTypeIsNil = fmt.Errorf("%w: result type is nil", errScalarUDFCreate)
errScalarUDFResultTypeIsANY = fmt.Errorf("%w: result type is ANY", errScalarUDFCreate)
errScalarUDFCreateSet = fmt.Errorf("could not create scalar UDF set")
errScalarUDFAddToSet = fmt.Errorf("%w: cannot add to set", errScalarUDFCreateSet)

errSetSQLNULLValue = errors.New("cannot write to a NULL column")

Expand Down
112 changes: 77 additions & 35 deletions scalar_udf.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,52 +137,59 @@ func registerResultParameters(config ScalarFunctionConfig, scalarFunction C.duck
return nil
}

// RegisterScalarUDF registers a scalar UDF.
// This function takes ownership of f, so you must pass it as a pointer.
func RegisterScalarUDF(c *sql.Conn, name string, f ScalarFunction) error {
func createScalarFunction(name string, f ScalarFunction) (C.duckdb_scalar_function, error) {
if name == "" {
return getError(errAPI, errScalarUDFNoName)
return nil, errScalarUDFNoName
}
if f == nil {
return getError(errAPI, errScalarUDFIsNil)
return nil, errScalarUDFIsNil
}
scalarFunction := C.duckdb_create_scalar_function()

// c.Raw exposes the underlying driver connection.
err := c.Raw(func(driverConn any) error {
con := driverConn.(*conn)
functionName := C.CString(name)
defer C.duckdb_free(unsafe.Pointer(functionName))
// Set the name.
functionName := C.CString(name)
C.duckdb_scalar_function_set_name(scalarFunction, functionName)
C.duckdb_free(unsafe.Pointer(functionName))

extraInfoHandle := cgo.NewHandle(f)
// Configure the scalar function.
config := f.Config()
if err := registerInputParameters(config, scalarFunction); err != nil {
return nil, err
}
if err := registerResultParameters(config, scalarFunction); err != nil {
return nil, err
}
if config.SpecialNullHandling() {
C.duckdb_scalar_function_set_special_handling(scalarFunction)
}
if config.Volatile() {
C.duckdb_scalar_function_set_volatile(scalarFunction)
}

scalarFunction := C.duckdb_create_scalar_function()
C.duckdb_scalar_function_set_name(scalarFunction, functionName)
// Set the function callback.
C.duckdb_scalar_function_set_function(scalarFunction, C.scalar_udf_callback_t(C.scalar_udf_callback))

// Configure the scalar function.
config := f.Config()
if err := registerInputParameters(config, scalarFunction); err != nil {
return getError(errAPI, err)
}
if err := registerResultParameters(config, scalarFunction); err != nil {
return getError(errAPI, err)
}
if config.SpecialNullHandling() {
C.duckdb_scalar_function_set_special_handling(scalarFunction)
}
if config.Volatile() {
C.duckdb_scalar_function_set_volatile(scalarFunction)
}
// Set data available during execution.
extraInfoHandle := cgo.NewHandle(f)
C.duckdb_scalar_function_set_extra_info(
scalarFunction,
unsafe.Pointer(&extraInfoHandle),
C.duckdb_delete_callback_t(C.scalar_udf_delete_callback))

// Set the function callback.
C.duckdb_scalar_function_set_function(scalarFunction, C.scalar_udf_callback_t(C.scalar_udf_callback))
return scalarFunction, nil
}

// Set data available during execution.
C.duckdb_scalar_function_set_extra_info(
scalarFunction,
unsafe.Pointer(&extraInfoHandle),
C.duckdb_delete_callback_t(C.scalar_udf_delete_callback))
// RegisterScalarUDF registers a scalar UDF.
// This function takes ownership of f, so you must pass it as a pointer.
func RegisterScalarUDF(c *sql.Conn, name string, f ScalarFunction) error {
scalarFunction, err := createScalarFunction(name, f)
if err != nil {
return getError(errAPI, err)
}

// Register the function.
// Register the function on the underlying driver connection exposed by c.Raw.
err = c.Raw(func(driverConn any) error {
con := driverConn.(*conn)
state := C.duckdb_register_scalar_function(con.duckdbCon, scalarFunction)
C.duckdb_destroy_scalar_function(&scalarFunction)
if state == C.DuckDBError {
Expand All @@ -192,3 +199,38 @@ func RegisterScalarUDF(c *sql.Conn, name string, f ScalarFunction) error {
})
return err
}

func RegisterScalarUDFSet(c *sql.Conn, name string, functions ...ScalarFunction) error {
functionName := C.CString(name)
set := C.duckdb_create_scalar_function_set(functionName)
C.duckdb_free(unsafe.Pointer(functionName))

// Create each function and add it to the set.
for i, f := range functions {
scalarFunction, err := createScalarFunction(name, f)
if err != nil {
C.duckdb_destroy_scalar_function(&scalarFunction)
C.duckdb_destroy_scalar_function_set(&set)
return getError(errAPI, err)
}

state := C.duckdb_add_scalar_function_to_set(set, scalarFunction)
C.duckdb_destroy_scalar_function(&scalarFunction)
if state == C.DuckDBError {
C.duckdb_destroy_scalar_function_set(&set)
return getError(errAPI, addIndexToError(errScalarUDFAddToSet, i))
}
}

// Register the function set on the underlying driver connection exposed by c.Raw.
err := c.Raw(func(driverConn any) error {
con := driverConn.(*conn)
state := C.duckdb_register_scalar_function_set(con.duckdbCon, set)
C.duckdb_destroy_scalar_function_set(&set)
if state == C.DuckDBError {
return getError(errAPI, errScalarUDFCreateSet)
}
return nil
})
return err
}
28 changes: 28 additions & 0 deletions scalar_udf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,34 @@ func (*variadicScalarUDF) ExecuteRow(args []driver.Value) (any, error) {
return sum, nil
}

func TestScalarUDFSet(t *testing.T) {
db, err := sql.Open("duckdb", "")
require.NoError(t, err)

c, err := db.Conn(context.Background())
require.NoError(t, err)

currentInfo, err = NewTypeInfo(TYPE_INTEGER)
require.NoError(t, err)

var udf1 *simpleScalarUDF
var udf2 *allTypesScalarUDF
err = RegisterScalarUDFSet(c, "my_addition", udf1, udf2)
require.NoError(t, err)

var sum int
row := db.QueryRow(`SELECT my_addition(10, 42) AS sum`)
require.NoError(t, row.Scan(&sum))
require.Equal(t, 52, sum)

row = db.QueryRow(`SELECT my_addition(42) AS sum`)
require.NoError(t, row.Scan(&sum))
require.Equal(t, 42, sum)

require.NoError(t, c.Close())
require.NoError(t, db.Close())
}

func TestVariadicScalarUDF(t *testing.T) {
db, err := sql.Open("duckdb", "")
require.NoError(t, err)
Expand Down

0 comments on commit 33e4627

Please sign in to comment.