diff --git a/errors.go b/errors.go index 8e122f1a..cb2843b6 100644 --- a/errors.go +++ b/errors.go @@ -106,6 +106,8 @@ var ( errScalarUDFInputTypeIsNil = fmt.Errorf("%w: input type is nil", errScalarUDFCreate) errScalarUDFResultTypeIsNil = fmt.Errorf("%w: result type is nil", errScalarUDFCreate) errScalarUDFResultTypeIsANY = fmt.Errorf("%w: result type is ANY", errScalarUDFCreate) + errScalarUDFCreateSet = fmt.Errorf("could not create scalar UDF set") + errScalarUDFAddToSet = fmt.Errorf("%w: cannot add to set", errScalarUDFCreateSet) errSetSQLNULLValue = errors.New("cannot write to a NULL column") diff --git a/scalar_udf.go b/scalar_udf.go index e5832001..c1d55fb0 100644 --- a/scalar_udf.go +++ b/scalar_udf.go @@ -137,52 +137,59 @@ func registerResultParameters(config ScalarFunctionConfig, scalarFunction C.duck return nil } -// RegisterScalarUDF registers a scalar UDF. -// This function takes ownership of f, so you must pass it as a pointer. -func RegisterScalarUDF(c *sql.Conn, name string, f ScalarFunction) error { +func createScalarFunction(name string, f ScalarFunction) (C.duckdb_scalar_function, error) { if name == "" { - return getError(errAPI, errScalarUDFNoName) + return nil, errScalarUDFNoName } if f == nil { - return getError(errAPI, errScalarUDFIsNil) + return nil, errScalarUDFIsNil } + scalarFunction := C.duckdb_create_scalar_function() - // c.Raw exposes the underlying driver connection. - err := c.Raw(func(driverConn any) error { - con := driverConn.(*conn) - functionName := C.CString(name) - defer C.duckdb_free(unsafe.Pointer(functionName)) + // Set the name. + functionName := C.CString(name) + C.duckdb_scalar_function_set_name(scalarFunction, functionName) + C.duckdb_free(unsafe.Pointer(functionName)) - extraInfoHandle := cgo.NewHandle(f) + // Configure the scalar function. + config := f.Config() + if err := registerInputParameters(config, scalarFunction); err != nil { + return nil, err + } + if err := registerResultParameters(config, scalarFunction); err != nil { + return nil, err + } + if config.SpecialNullHandling() { + C.duckdb_scalar_function_set_special_handling(scalarFunction) + } + if config.Volatile() { + C.duckdb_scalar_function_set_volatile(scalarFunction) + } - scalarFunction := C.duckdb_create_scalar_function() - C.duckdb_scalar_function_set_name(scalarFunction, functionName) + // Set the function callback. + C.duckdb_scalar_function_set_function(scalarFunction, C.scalar_udf_callback_t(C.scalar_udf_callback)) - // Configure the scalar function. - config := f.Config() - if err := registerInputParameters(config, scalarFunction); err != nil { - return getError(errAPI, err) - } - if err := registerResultParameters(config, scalarFunction); err != nil { - return getError(errAPI, err) - } - if config.SpecialNullHandling() { - C.duckdb_scalar_function_set_special_handling(scalarFunction) - } - if config.Volatile() { - C.duckdb_scalar_function_set_volatile(scalarFunction) - } + // Set data available during execution. + extraInfoHandle := cgo.NewHandle(f) + C.duckdb_scalar_function_set_extra_info( + scalarFunction, + unsafe.Pointer(&extraInfoHandle), + C.duckdb_delete_callback_t(C.scalar_udf_delete_callback)) - // Set the function callback. - C.duckdb_scalar_function_set_function(scalarFunction, C.scalar_udf_callback_t(C.scalar_udf_callback)) + return scalarFunction, nil +} - // Set data available during execution. - C.duckdb_scalar_function_set_extra_info( - scalarFunction, - unsafe.Pointer(&extraInfoHandle), - C.duckdb_delete_callback_t(C.scalar_udf_delete_callback)) +// RegisterScalarUDF registers a scalar UDF. +// This function takes ownership of f, so you must pass it as a pointer. +func RegisterScalarUDF(c *sql.Conn, name string, f ScalarFunction) error { + scalarFunction, err := createScalarFunction(name, f) + if err != nil { + return getError(errAPI, err) + } - // Register the function. + // Register the function on the underlying driver connection exposed by c.Raw. + err = c.Raw(func(driverConn any) error { + con := driverConn.(*conn) state := C.duckdb_register_scalar_function(con.duckdbCon, scalarFunction) C.duckdb_destroy_scalar_function(&scalarFunction) if state == C.DuckDBError { @@ -192,3 +199,38 @@ func RegisterScalarUDF(c *sql.Conn, name string, f ScalarFunction) error { }) return err } + +func RegisterScalarUDFSet(c *sql.Conn, name string, functions ...ScalarFunction) error { + functionName := C.CString(name) + set := C.duckdb_create_scalar_function_set(functionName) + C.duckdb_free(unsafe.Pointer(functionName)) + + // Create each function and add it to the set. + for i, f := range functions { + scalarFunction, err := createScalarFunction(name, f) + if err != nil { + C.duckdb_destroy_scalar_function(&scalarFunction) + C.duckdb_destroy_scalar_function_set(&set) + return getError(errAPI, err) + } + + state := C.duckdb_add_scalar_function_to_set(set, scalarFunction) + C.duckdb_destroy_scalar_function(&scalarFunction) + if state == C.DuckDBError { + C.duckdb_destroy_scalar_function_set(&set) + return getError(errAPI, addIndexToError(errScalarUDFAddToSet, i)) + } + } + + // Register the function set on the underlying driver connection exposed by c.Raw. + err := c.Raw(func(driverConn any) error { + con := driverConn.(*conn) + state := C.duckdb_register_scalar_function_set(con.duckdbCon, set) + C.duckdb_destroy_scalar_function_set(&set) + if state == C.DuckDBError { + return getError(errAPI, errScalarUDFCreateSet) + } + return nil + }) + return err +} diff --git a/scalar_udf_test.go b/scalar_udf_test.go index 0687352b..dc467c16 100644 --- a/scalar_udf_test.go +++ b/scalar_udf_test.go @@ -183,6 +183,34 @@ func (*variadicScalarUDF) ExecuteRow(args []driver.Value) (any, error) { return sum, nil } +func TestScalarUDFSet(t *testing.T) { + db, err := sql.Open("duckdb", "") + require.NoError(t, err) + + c, err := db.Conn(context.Background()) + require.NoError(t, err) + + currentInfo, err = NewTypeInfo(TYPE_INTEGER) + require.NoError(t, err) + + var udf1 *simpleScalarUDF + var udf2 *allTypesScalarUDF + err = RegisterScalarUDFSet(c, "my_addition", udf1, udf2) + require.NoError(t, err) + + var sum int + row := db.QueryRow(`SELECT my_addition(10, 42) AS sum`) + require.NoError(t, row.Scan(&sum)) + require.Equal(t, 52, sum) + + row = db.QueryRow(`SELECT my_addition(42) AS sum`) + require.NoError(t, row.Scan(&sum)) + require.Equal(t, 42, sum) + + require.NoError(t, c.Close()) + require.NoError(t, db.Close()) +} + func TestVariadicScalarUDF(t *testing.T) { db, err := sql.Open("duckdb", "") require.NoError(t, err)