diff --git a/errors.go b/errors.go index cb2843b6..813cf3ec 100644 --- a/errors.go +++ b/errors.go @@ -101,6 +101,7 @@ var ( errScalarUDFCreate = errors.New("could not create scalar UDF") errScalarUDFNoName = fmt.Errorf("%w: missing name", errScalarUDFCreate) errScalarUDFIsNil = fmt.Errorf("%w: function is nil", errScalarUDFCreate) + errScalarUDFNoExecutor = fmt.Errorf("%w: executor is nil", errScalarUDFCreate) errScalarUDFNilInputTypes = fmt.Errorf("%w: input types are nil", errScalarUDFCreate) errScalarUDFEmptyInputTypes = fmt.Errorf("%w: empty input types", errScalarUDFCreate) errScalarUDFInputTypeIsNil = fmt.Errorf("%w: input type is nil", errScalarUDFCreate) diff --git a/scalar_udf.go b/scalar_udf.go index 06c64905..da2dfb0c 100644 --- a/scalar_udf.go +++ b/scalar_udf.go @@ -19,6 +19,8 @@ import ( "unsafe" ) +type rowFn func(args []driver.Value) (any, error) + type ScalarFuncConfig struct { InputTypeInfos []TypeInfo ResultTypeInfo TypeInfo @@ -28,9 +30,13 @@ type ScalarFuncConfig struct { SpecialNullHandling bool } +type ScalarFuncExecutor struct { + RowExecutor rowFn +} + type ScalarFunc interface { Config() ScalarFuncConfig - ExecuteRow(args []driver.Value) (any, error) + Executor() ScalarFuncExecutor } func setFuncError(function_info C.duckdb_function_info, msg string) { @@ -62,12 +68,12 @@ func scalar_udf_callback(function_info C.duckdb_function_info, input C.duckdb_da } // Execute the user-defined scalar function for each row. + executor := function.Executor() values := make([]driver.Value, len(inputChunk.columns)) - rowCount := inputChunk.GetSize() columnCount := len(values) - var err error - for rowIdx := 0; rowIdx < rowCount; rowIdx++ { + var err error + for rowIdx := 0; rowIdx < inputChunk.GetSize(); rowIdx++ { // Set the values for each row. for colIdx := 0; colIdx < columnCount; colIdx++ { if values[colIdx], err = inputChunk.GetValue(colIdx, rowIdx); err != nil { @@ -78,7 +84,7 @@ func scalar_udf_callback(function_info C.duckdb_function_info, input C.duckdb_da // Execute the function and write the result to the output vector. var val any - if val, err = function.ExecuteRow(values); err != nil { + if val, err = executor.RowExecutor(values); err != nil { break } if err = outputChunk.SetValue(0, rowIdx, val); err != nil { @@ -145,6 +151,10 @@ func createScalarFunc(name string, f ScalarFunc) (C.duckdb_scalar_function, erro if f == nil { return nil, errScalarUDFIsNil } + if f.Executor().RowExecutor == nil { + return nil, errScalarUDFNoExecutor + } + function := C.duckdb_create_scalar_function() // Set the name. @@ -183,7 +193,7 @@ func createScalarFunc(name string, f ScalarFunc) (C.duckdb_scalar_function, erro // RegisterScalarUDF registers a scalar user-defined function. // The function takes ownership of f, so you must pass it as a pointer. func RegisterScalarUDF(c *sql.Conn, name string, f ScalarFunc) error { - scalarFunc, err := createScalarFunc(name, f) + function, err := createScalarFunc(name, f) if err != nil { return getError(errAPI, err) } @@ -191,8 +201,8 @@ func RegisterScalarUDF(c *sql.Conn, name string, f ScalarFunc) error { // Register the function on the underlying driver connection exposed by c.Raw. err = c.Raw(func(driverConn any) error { con := driverConn.(*conn) - state := C.duckdb_register_scalar_function(con.duckdbCon, scalarFunc) - C.duckdb_destroy_scalar_function(&scalarFunc) + state := C.duckdb_register_scalar_function(con.duckdbCon, function) + C.duckdb_destroy_scalar_function(&function) if state == C.DuckDBError { return getError(errAPI, errScalarUDFCreate) } @@ -208,15 +218,15 @@ func RegisterScalarUDFSet(c *sql.Conn, name string, functions ...ScalarFunc) err // Create each function and add it to the set. for i, f := range functions { - scalarFunction, err := createScalarFunc(name, f) + function, err := createScalarFunc(name, f) if err != nil { - C.duckdb_destroy_scalar_function(&scalarFunction) + C.duckdb_destroy_scalar_function(&function) C.duckdb_destroy_scalar_function_set(&set) return getError(errAPI, err) } - state := C.duckdb_add_scalar_function_to_set(set, scalarFunction) - C.duckdb_destroy_scalar_function(&scalarFunction) + state := C.duckdb_add_scalar_function_to_set(set, function) + C.duckdb_destroy_scalar_function(&function) if state == C.DuckDBError { C.duckdb_destroy_scalar_function_set(&set) return getError(errAPI, addIndexToError(errScalarUDFAddToSet, i)) diff --git a/scalar_udf_test.go b/scalar_udf_test.go index 37c824da..afc8e0e9 100644 --- a/scalar_udf_test.go +++ b/scalar_udf_test.go @@ -22,7 +22,7 @@ func (*simpleSUDF) Config() ScalarFuncConfig { } } -func (*simpleSUDF) ExecuteRow(args []driver.Value) (any, error) { +func simpleSum(args []driver.Value) (any, error) { if args[0] == nil || args[1] == nil { return nil, nil } @@ -30,6 +30,12 @@ func (*simpleSUDF) ExecuteRow(args []driver.Value) (any, error) { return val, nil } +func (*simpleSUDF) Executor() ScalarFuncExecutor { + return ScalarFuncExecutor{ + RowExecutor: simpleSum, + } +} + func TestSimpleScalarUDF(t *testing.T) { db, err := sql.Open("duckdb", "") require.NoError(t, err) @@ -70,10 +76,16 @@ func (*typesSUDF) Config() ScalarFuncConfig { } } -func (*typesSUDF) ExecuteRow(args []driver.Value) (any, error) { +func identity(args []driver.Value) (any, error) { return args[0], nil } +func (*typesSUDF) Executor() ScalarFuncExecutor { + return ScalarFuncExecutor{ + RowExecutor: identity, + } +} + func TestAllTypesScalarUDF(t *testing.T) { typeInfos := getTypeInfos(t, false) for _, info := range typeInfos { @@ -145,7 +157,7 @@ func (*variadicSUDF) Config() ScalarFuncConfig { } } -func (*variadicSUDF) ExecuteRow(args []driver.Value) (any, error) { +func variadicSum(args []driver.Value) (any, error) { sum := int32(0) for _, val := range args { if val == nil { @@ -156,6 +168,12 @@ func (*variadicSUDF) ExecuteRow(args []driver.Value) (any, error) { return sum, nil } +func (*variadicSUDF) Executor() ScalarFuncExecutor { + return ScalarFuncExecutor{ + RowExecutor: variadicSum, + } +} + func TestVariadicScalarUDF(t *testing.T) { db, err := sql.Open("duckdb", "") require.NoError(t, err) @@ -210,7 +228,7 @@ func (*anyTypeSUDF) Config() ScalarFuncConfig { } } -func (*anyTypeSUDF) ExecuteRow(args []driver.Value) (any, error) { +func nilCount(args []driver.Value) (any, error) { count := int32(0) for _, val := range args { if val == nil { @@ -220,6 +238,12 @@ func (*anyTypeSUDF) ExecuteRow(args []driver.Value) (any, error) { return count, nil } +func (*anyTypeSUDF) Executor() ScalarFuncExecutor { + return ScalarFuncExecutor{ + RowExecutor: nilCount, + } +} + func TestANYScalarUDF(t *testing.T) { db, err := sql.Open("duckdb", "") require.NoError(t, err) @@ -259,6 +283,19 @@ func TestANYScalarUDF(t *testing.T) { require.NoError(t, db.Close()) } +type errExecutorSUDF struct{} + +func (*errExecutorSUDF) Config() ScalarFuncConfig { + scalarUDF := simpleSUDF{} + return scalarUDF.Config() +} + +func (*errExecutorSUDF) Executor() ScalarFuncExecutor { + return ScalarFuncExecutor{ + RowExecutor: nil, + } +} + type errInputSUDF struct{} func (*errInputSUDF) Config() ScalarFuncConfig { @@ -267,10 +304,16 @@ func (*errInputSUDF) Config() ScalarFuncConfig { } } -func (*errInputSUDF) ExecuteRow([]driver.Value) (any, error) { +func constantNil([]driver.Value) (any, error) { return nil, nil } +func (*errInputSUDF) Executor() ScalarFuncExecutor { + return ScalarFuncExecutor{ + RowExecutor: constantNil, + } +} + type errEmptyInputSUDF struct{} func (*errEmptyInputSUDF) Config() ScalarFuncConfig { @@ -280,8 +323,10 @@ func (*errEmptyInputSUDF) Config() ScalarFuncConfig { } } -func (*errEmptyInputSUDF) ExecuteRow([]driver.Value) (any, error) { - return nil, nil +func (*errEmptyInputSUDF) Executor() ScalarFuncExecutor { + return ScalarFuncExecutor{ + RowExecutor: constantNil, + } } type errInputNilSUDF struct{} @@ -293,8 +338,10 @@ func (*errInputNilSUDF) Config() ScalarFuncConfig { } } -func (*errInputNilSUDF) ExecuteRow([]driver.Value) (any, error) { - return nil, nil +func (*errInputNilSUDF) Executor() ScalarFuncExecutor { + return ScalarFuncExecutor{ + RowExecutor: constantNil, + } } type errResultNilSUDF struct{} @@ -306,8 +353,10 @@ func (*errResultNilSUDF) Config() ScalarFuncConfig { } } -func (*errResultNilSUDF) ExecuteRow([]driver.Value) (any, error) { - return nil, nil +func (*errResultNilSUDF) Executor() ScalarFuncExecutor { + return ScalarFuncExecutor{ + RowExecutor: constantNil, + } } type errResultAnySUDF struct{} @@ -324,8 +373,10 @@ func (*errResultAnySUDF) Config() ScalarFuncConfig { } } -func (*errResultAnySUDF) ExecuteRow([]driver.Value) (any, error) { - return nil, nil +func (*errResultAnySUDF) Executor() ScalarFuncExecutor { + return ScalarFuncExecutor{ + RowExecutor: constantNil, + } } type errExecSUDF struct{} @@ -335,10 +386,16 @@ func (*errExecSUDF) Config() ScalarFuncConfig { return scalarUDF.Config() } -func (*errExecSUDF) ExecuteRow([]driver.Value) (any, error) { +func constantError([]driver.Value) (any, error) { return nil, errors.New("test invalid execution") } +func (*errExecSUDF) Executor() ScalarFuncExecutor { + return ScalarFuncExecutor{ + RowExecutor: constantError, + } +} + func TestScalarUDFErrors(t *testing.T) { t.Parallel() @@ -356,6 +413,11 @@ func TestScalarUDFErrors(t *testing.T) { err = RegisterScalarUDF(c, "", emptyNameUDF) testError(t, err, errAPI.Error(), errScalarUDFCreate.Error(), errScalarUDFNoName.Error()) + // Invalid executor. + var errExecutorUDF *errExecutorSUDF + err = RegisterScalarUDF(c, "err_executor_is_nil", errExecutorUDF) + testError(t, err, errAPI.Error(), errScalarUDFCreate.Error(), errScalarUDFNoExecutor.Error()) + // Invalid input parameters. var errInputUDF *errInputSUDF