diff --git a/errors.go b/errors.go index 813cf3ec..5aefe309 100644 --- a/errors.go +++ b/errors.go @@ -98,15 +98,15 @@ var ( errInvalidDecimalWidth = fmt.Errorf("the DECIMAL with must be between 1 and %d", MAX_DECIMAL_WIDTH) errInvalidDecimalScale = errors.New("the DECIMAL scale must be less than or equal to the width") - errScalarUDFCreate = errors.New("could not create scalar UDF") - errScalarUDFNoName = fmt.Errorf("%w: missing name", errScalarUDFCreate) - errScalarUDFIsNil = fmt.Errorf("%w: function is nil", errScalarUDFCreate) - errScalarUDFNoExecutor = fmt.Errorf("%w: executor is nil", errScalarUDFCreate) - errScalarUDFNilInputTypes = fmt.Errorf("%w: input types are nil", errScalarUDFCreate) - errScalarUDFEmptyInputTypes = fmt.Errorf("%w: empty input types", errScalarUDFCreate) + errScalarUDFCreate = errors.New("could not create scalar UDF") + errScalarUDFNoName = fmt.Errorf("%w: missing name", errScalarUDFCreate) + errScalarUDFIsNil = fmt.Errorf("%w: function is nil", errScalarUDFCreate) + errScalarUDFNoExecutor = fmt.Errorf("%w: executor is nil", errScalarUDFCreate) + //errScalarUDFNilInputTypes = fmt.Errorf("%w: input types are nil", errScalarUDFCreate) + //errScalarUDFEmptyInputTypes = fmt.Errorf("%w: empty input types", errScalarUDFCreate) errScalarUDFInputTypeIsNil = fmt.Errorf("%w: input type is nil", errScalarUDFCreate) errScalarUDFResultTypeIsNil = fmt.Errorf("%w: result type is nil", errScalarUDFCreate) - errScalarUDFResultTypeIsANY = fmt.Errorf("%w: result type is ANY", errScalarUDFCreate) + errScalarUDFResultTypeIsANY = fmt.Errorf("%w: result type is ANY, which is not supported", errScalarUDFCreate) errScalarUDFCreateSet = fmt.Errorf("could not create scalar UDF set") errScalarUDFAddToSet = fmt.Errorf("%w: cannot add to set", errScalarUDFCreateSet) diff --git a/examples/scalar_udf/main.go b/examples/scalar_udf/main.go new file mode 100644 index 00000000..17b287c9 --- /dev/null +++ b/examples/scalar_udf/main.go @@ -0,0 +1,5 @@ +package main + +// TODO: overload the my_len function with LIST(ANY) and VARCHAR + +// TODO: mix of variadic and configurable parameters diff --git a/scalar_udf.go b/scalar_udf.go index da2dfb0c..0dff3215 100644 --- a/scalar_udf.go +++ b/scalar_udf.go @@ -19,26 +19,111 @@ import ( "unsafe" ) -type rowFn func(args []driver.Value) (any, error) - +// ScalarFuncConfig contains the fields to configure a user-defined scalar function. type ScalarFuncConfig struct { + // InputTypeInfos contains Type information for each input parameter of the scalar function. InputTypeInfos []TypeInfo + // ResultTypeInfo holds the Type information of the scalar function's result type. ResultTypeInfo TypeInfo - VariadicTypeInfo *TypeInfo - Volatile bool + // VariadicTypeInfo configures the number of input parameters. + // If this field is nil, then the input parameters match InputTypeInfos. + // Otherwise, the scalar function's input parameters are set to variadic, allowing any number of input parameters. + // The Type of the first len(InputTypeInfos) parameters is configured by InputTypeInfos, and all + // remaining parameters must match the variadic Type. To configure different variadic parameter Type's, + // you must set the VariadicTypeInfo's Type to TYPE_ANY. + VariadicTypeInfo TypeInfo + // Volatile sets the stability of the scalar function to volatile, if true. + // Volatile scalar functions might create a different result per row. + // E.g., RANDOM() is a volatile scalar function. + Volatile bool + // SpecialNullHandling disables the default NULL handling of scalar functions, if true. + // The default NULL handling is NULL in, NULL out. I.e., if any input parameter is NULL, then the result is NULL. SpecialNullHandling bool } +// ScalarFuncExecutor contains the callback function to execute a user-defined scalar function. +// Currently, its only field is a row-based executor. type ScalarFuncExecutor struct { - RowExecutor rowFn + // RowExecutor accepts a row-based execution function. + // args contains the input values, and it returns the row execution result, or error. + RowExecutor func(args []driver.Value) (any, error) } +// ScalarFunc is the user-defined scalar function interface. +// Any scalar function must implement a Config function, and an Executor function. type ScalarFunc interface { + // Config returns ScalarFuncConfig to configure the scalar function. Config() ScalarFuncConfig + // Executor returns ScalarFuncExecutor to execute the scalar function. Executor() ScalarFuncExecutor } +// RegisterScalarUDF registers a user-defined scalar function. +// c is the SQL connection on which to register the scalar function. +// name is the function name, and f is the scalar function's interface. +// RegisterScalarUDF takes ownership of f, so you must pass it as a pointer. +func RegisterScalarUDF(c *sql.Conn, name string, f ScalarFunc) error { + function, err := createScalarFunc(name, f) + if err != nil { + return getError(errAPI, err) + } + + // Register the function on the underlying driver connection exposed by c.Raw. + err = c.Raw(func(driverConn any) error { + con := driverConn.(*conn) + state := C.duckdb_register_scalar_function(con.duckdbCon, function) + C.duckdb_destroy_scalar_function(&function) + if state == C.DuckDBError { + return getError(errAPI, errScalarUDFCreate) + } + return nil + }) + return err +} + +// RegisterScalarUDFSet registers a set of user-defined scalar functions with the same name. +// This allows overloading of scalar functions. +// E.g., it allows overloading the function my_length() with different implementations +// like my_length(LIST(ANY)) and my_length(VARCHAR). +// c is the SQL connection on which to register the scalar function set. +// name is the function name of each function in the set. +// functions contains all functions of the scalar function set. +func RegisterScalarUDFSet(c *sql.Conn, name string, functions ...ScalarFunc) error { + cName := C.CString(name) + set := C.duckdb_create_scalar_function_set(cName) + C.duckdb_free(unsafe.Pointer(cName)) + + // Create each function and add it to the set. + for i, f := range functions { + function, err := createScalarFunc(name, f) + if err != nil { + C.duckdb_destroy_scalar_function(&function) + C.duckdb_destroy_scalar_function_set(&set) + return getError(errAPI, err) + } + + state := C.duckdb_add_scalar_function_to_set(set, function) + C.duckdb_destroy_scalar_function(&function) + if state == C.DuckDBError { + C.duckdb_destroy_scalar_function_set(&set) + return getError(errAPI, addIndexToError(errScalarUDFAddToSet, i)) + } + } + + // Register the function set on the underlying driver connection exposed by c.Raw. + err := c.Raw(func(driverConn any) error { + con := driverConn.(*conn) + state := C.duckdb_register_scalar_function_set(con.duckdbCon, set) + C.duckdb_destroy_scalar_function_set(&set) + if state == C.DuckDBError { + return getError(errAPI, errScalarUDFCreateSet) + } + return nil + }) + return err +} + func setFuncError(function_info C.duckdb_function_info, msg string) { err := C.CString(msg) C.duckdb_scalar_function_set_error(function_info, err) @@ -69,32 +154,50 @@ func scalar_udf_callback(function_info C.duckdb_function_info, input C.duckdb_da // Execute the user-defined scalar function for each row. executor := function.Executor() + nullInNullOut := !function.Config().SpecialNullHandling values := make([]driver.Value, len(inputChunk.columns)) columnCount := len(values) + rowCount := inputChunk.GetSize() + // Set the values for each row by invoking the callback function. var err error - for rowIdx := 0; rowIdx < inputChunk.GetSize(); rowIdx++ { - // Set the values for each row. + for rowIdx := 0; rowIdx < rowCount; rowIdx++ { + nullRow := false + + // Get each column value. for colIdx := 0; colIdx < columnCount; colIdx++ { if values[colIdx], err = inputChunk.GetValue(colIdx, rowIdx); err != nil { setFuncError(function_info, getError(errAPI, err).Error()) return } + + // NULL handling. + if nullInNullOut && values[colIdx] == nil { + if err = outputChunk.SetValue(0, rowIdx, nil); err != nil { + setFuncError(function_info, getError(errAPI, err).Error()) + return + } + nullRow = true + break + } + } + if nullRow { + continue } - // Execute the function and write the result to the output vector. + // Execute the function. var val any if val, err = executor.RowExecutor(values); err != nil { - break + setFuncError(function_info, getError(errAPI, err).Error()) + return } + + // Write the result to the output chunk. if err = outputChunk.SetValue(0, rowIdx, val); err != nil { - break + setFuncError(function_info, getError(errAPI, err).Error()) + return } } - - if err != nil { - setFuncError(function_info, getError(errAPI, err).Error()) - } } //export scalar_udf_delete_callback @@ -106,18 +209,17 @@ func scalar_udf_delete_callback(extraInfo unsafe.Pointer) { func registerInputParams(config ScalarFuncConfig, f C.duckdb_scalar_function) error { // Set variadic input parameters. if config.VariadicTypeInfo != nil { - t := (*config.VariadicTypeInfo).logicalType() + t := config.VariadicTypeInfo.logicalType() C.duckdb_scalar_function_set_varargs(f, t) C.duckdb_destroy_logical_type(&t) - return nil } // Set normal input parameters. if config.InputTypeInfos == nil { - return errScalarUDFNilInputTypes + return nil } if len(config.InputTypeInfos) == 0 { - return errScalarUDFEmptyInputTypes + return nil } for i, info := range config.InputTypeInfos { @@ -189,59 +291,3 @@ func createScalarFunc(name string, f ScalarFunc) (C.duckdb_scalar_function, erro return function, nil } - -// RegisterScalarUDF registers a scalar user-defined function. -// The function takes ownership of f, so you must pass it as a pointer. -func RegisterScalarUDF(c *sql.Conn, name string, f ScalarFunc) error { - function, err := createScalarFunc(name, f) - if err != nil { - return getError(errAPI, err) - } - - // Register the function on the underlying driver connection exposed by c.Raw. - err = c.Raw(func(driverConn any) error { - con := driverConn.(*conn) - state := C.duckdb_register_scalar_function(con.duckdbCon, function) - C.duckdb_destroy_scalar_function(&function) - if state == C.DuckDBError { - return getError(errAPI, errScalarUDFCreate) - } - return nil - }) - return err -} - -func RegisterScalarUDFSet(c *sql.Conn, name string, functions ...ScalarFunc) error { - cName := C.CString(name) - set := C.duckdb_create_scalar_function_set(cName) - C.duckdb_free(unsafe.Pointer(cName)) - - // Create each function and add it to the set. - for i, f := range functions { - function, err := createScalarFunc(name, f) - if err != nil { - C.duckdb_destroy_scalar_function(&function) - C.duckdb_destroy_scalar_function_set(&set) - return getError(errAPI, err) - } - - state := C.duckdb_add_scalar_function_to_set(set, function) - C.duckdb_destroy_scalar_function(&function) - if state == C.DuckDBError { - C.duckdb_destroy_scalar_function_set(&set) - return getError(errAPI, addIndexToError(errScalarUDFAddToSet, i)) - } - } - - // Register the function set on the underlying driver connection exposed by c.Raw. - err := c.Raw(func(driverConn any) error { - con := driverConn.(*conn) - state := C.duckdb_register_scalar_function_set(con.duckdbCon, set) - C.duckdb_destroy_scalar_function_set(&set) - if state == C.DuckDBError { - return getError(errAPI, errScalarUDFCreateSet) - } - return nil - }) - return err -} diff --git a/scalar_udf_test.go b/scalar_udf_test.go index afc8e0e9..bb061397 100644 --- a/scalar_udf_test.go +++ b/scalar_udf_test.go @@ -50,18 +50,81 @@ func TestSimpleScalarUDF(t *testing.T) { err = RegisterScalarUDF(c, "my_sum", udf) require.NoError(t, err) - var msg *int - row := db.QueryRow(`SELECT my_sum(10, 42) AS msg`) - require.NoError(t, row.Scan(&msg)) - require.Equal(t, 52, *msg) + var sum *int + row := db.QueryRow(`SELECT my_sum(10, 42) AS sum`) + require.NoError(t, row.Scan(&sum)) + require.Equal(t, 52, *sum) - row = db.QueryRow(`SELECT my_sum(NULL, 42) AS msg`) - require.NoError(t, row.Scan(&msg)) - require.Equal(t, (*int)(nil), msg) + row = db.QueryRow(`SELECT my_sum(NULL, 42) AS sum`) + require.NoError(t, row.Scan(&sum)) + require.Equal(t, (*int)(nil), sum) - row = db.QueryRow(`SELECT my_sum(42, NULL) AS msg`) - require.NoError(t, row.Scan(&msg)) - require.Equal(t, (*int)(nil), msg) + row = db.QueryRow(`SELECT my_sum(42, NULL) AS sum`) + require.NoError(t, row.Scan(&sum)) + require.Equal(t, (*int)(nil), sum) + + require.NoError(t, c.Close()) + require.NoError(t, db.Close()) +} + +type constantSUDF struct{} +type otherConstantSUDF struct{} + +func (*constantSUDF) Config() ScalarFuncConfig { + return ScalarFuncConfig{ + ResultTypeInfo: currentInfo, + } +} + +func (*otherConstantSUDF) Config() ScalarFuncConfig { + return ScalarFuncConfig{ + InputTypeInfos: []TypeInfo{}, + ResultTypeInfo: currentInfo, + } +} + +func constantOne([]driver.Value) (any, error) { + return int32(1), nil +} + +func (*constantSUDF) Executor() ScalarFuncExecutor { + return ScalarFuncExecutor{ + RowExecutor: constantOne, + } +} + +func (*otherConstantSUDF) Executor() ScalarFuncExecutor { + return ScalarFuncExecutor{ + RowExecutor: constantOne, + } +} + +func TestConstantScalarUDF(t *testing.T) { + db, err := sql.Open("duckdb", "") + require.NoError(t, err) + + c, err := db.Conn(context.Background()) + require.NoError(t, err) + + currentInfo, err = NewTypeInfo(TYPE_INTEGER) + require.NoError(t, err) + + var udf *constantSUDF + err = RegisterScalarUDF(c, "constant_one", udf) + require.NoError(t, err) + + var otherUDF *otherConstantSUDF + err = RegisterScalarUDF(c, "other_constant_one", otherUDF) + require.NoError(t, err) + + var one int + row := db.QueryRow(`SELECT constant_one() AS one`) + require.NoError(t, row.Scan(&one)) + require.Equal(t, 1, one) + + row = db.QueryRow(`SELECT other_constant_one() AS one`) + require.NoError(t, row.Scan(&one)) + require.Equal(t, 1, one) require.NoError(t, c.Close()) require.NoError(t, db.Close()) @@ -151,7 +214,7 @@ type variadicSUDF struct{} func (*variadicSUDF) Config() ScalarFuncConfig { return ScalarFuncConfig{ ResultTypeInfo: currentInfo, - VariadicTypeInfo: ¤tInfo, + VariadicTypeInfo: currentInfo, Volatile: true, SpecialNullHandling: true, } @@ -223,7 +286,7 @@ func (*anyTypeSUDF) Config() ScalarFuncConfig { return ScalarFuncConfig{ ResultTypeInfo: currentInfo, - VariadicTypeInfo: &info, + VariadicTypeInfo: info, SpecialNullHandling: true, } } @@ -296,39 +359,6 @@ func (*errExecutorSUDF) Executor() ScalarFuncExecutor { } } -type errInputSUDF struct{} - -func (*errInputSUDF) Config() ScalarFuncConfig { - return ScalarFuncConfig{ - ResultTypeInfo: currentInfo, - } -} - -func constantNil([]driver.Value) (any, error) { - return nil, nil -} - -func (*errInputSUDF) Executor() ScalarFuncExecutor { - return ScalarFuncExecutor{ - RowExecutor: constantNil, - } -} - -type errEmptyInputSUDF struct{} - -func (*errEmptyInputSUDF) Config() ScalarFuncConfig { - return ScalarFuncConfig{ - InputTypeInfos: []TypeInfo{}, - ResultTypeInfo: currentInfo, - } -} - -func (*errEmptyInputSUDF) Executor() ScalarFuncExecutor { - return ScalarFuncExecutor{ - RowExecutor: constantNil, - } -} - type errInputNilSUDF struct{} func (*errInputNilSUDF) Config() ScalarFuncConfig { @@ -340,7 +370,7 @@ func (*errInputNilSUDF) Config() ScalarFuncConfig { func (*errInputNilSUDF) Executor() ScalarFuncExecutor { return ScalarFuncExecutor{ - RowExecutor: constantNil, + RowExecutor: constantOne, } } @@ -355,7 +385,7 @@ func (*errResultNilSUDF) Config() ScalarFuncConfig { func (*errResultNilSUDF) Executor() ScalarFuncExecutor { return ScalarFuncExecutor{ - RowExecutor: constantNil, + RowExecutor: constantOne, } } @@ -375,7 +405,7 @@ func (*errResultAnySUDF) Config() ScalarFuncConfig { func (*errResultAnySUDF) Executor() ScalarFuncExecutor { return ScalarFuncExecutor{ - RowExecutor: constantNil, + RowExecutor: constantOne, } } @@ -418,26 +448,15 @@ func TestScalarUDFErrors(t *testing.T) { err = RegisterScalarUDF(c, "err_executor_is_nil", errExecutorUDF) testError(t, err, errAPI.Error(), errScalarUDFCreate.Error(), errScalarUDFNoExecutor.Error()) - // Invalid input parameters. - - var errInputUDF *errInputSUDF - err = RegisterScalarUDF(c, "err_input", errInputUDF) - testError(t, err, errAPI.Error(), errScalarUDFCreate.Error(), errScalarUDFNilInputTypes.Error()) - - var errEmptyInputUDF *errEmptyInputSUDF - err = RegisterScalarUDF(c, "err_empty_input", errEmptyInputUDF) - testError(t, err, errAPI.Error(), errScalarUDFCreate.Error(), errScalarUDFEmptyInputTypes.Error()) - + // Invalid input parameter. var errInputNilUDF *errInputNilSUDF err = RegisterScalarUDF(c, "err_input_type_is_nil", errInputNilUDF) testError(t, err, errAPI.Error(), errScalarUDFCreate.Error(), errScalarUDFInputTypeIsNil.Error()) // Invalid result parameters. - var errResultNil *errResultNilSUDF err = RegisterScalarUDF(c, "err_result_type_is_nil", errResultNil) testError(t, err, errAPI.Error(), errScalarUDFCreate.Error(), errScalarUDFResultTypeIsNil.Error()) - var errResultAny *errResultAnySUDF err = RegisterScalarUDF(c, "err_result_type_is_any", errResultAny) testError(t, err, errAPI.Error(), errScalarUDFCreate.Error(), errScalarUDFResultTypeIsANY.Error())