diff --git a/scalar_udf.go b/scalar_udf.go index c1d55fb0..3d0778da 100644 --- a/scalar_udf.go +++ b/scalar_udf.go @@ -19,65 +19,69 @@ import ( "unsafe" ) -type ScalarFunctionConfig interface { +type ScalarFuncConfig interface { InputTypeInfos() []TypeInfo ResultTypeInfo() TypeInfo +} + +type ScalarFuncExtraInfo interface { VariadicTypeInfo() TypeInfo Volatile() bool SpecialNullHandling() bool } -type ScalarFunction interface { - Config() ScalarFunctionConfig +type ScalarFunc interface { + Config() ScalarFuncConfig + ExtraInfo() ScalarFuncExtraInfo ExecuteRow(args []driver.Value) (any, error) } -func setFunctionError(info C.duckdb_function_info, msg string) { +func setFuncError(function_info C.duckdb_function_info, msg string) { err := C.CString(msg) - C.duckdb_scalar_function_set_error(info, err) + C.duckdb_scalar_function_set_error(function_info, err) C.duckdb_free(unsafe.Pointer(err)) } //export scalar_udf_callback -func scalar_udf_callback(info C.duckdb_function_info, input C.duckdb_data_chunk, output C.duckdb_vector) { - extraInfo := C.duckdb_scalar_function_get_extra_info(info) +func scalar_udf_callback(function_info C.duckdb_function_info, input C.duckdb_data_chunk, output C.duckdb_vector) { + extraInfo := C.duckdb_scalar_function_get_extra_info(function_info) - // extraInfo is a void* pointer to our ScalarFunction. + // extraInfo is a void* pointer to our ScalarFunc. h := *(*cgo.Handle)(unsafe.Pointer(extraInfo)) - scalarFunction := h.Value().(ScalarFunction) + function := h.Value().(ScalarFunc) // Initialize the input chunk. var inputChunk DataChunk if err := inputChunk.initFromDuckDataChunk(input, false); err != nil { - setFunctionError(info, getError(errAPI, err).Error()) + setFuncError(function_info, getError(errAPI, err).Error()) return } // Initialize the output chunk. var outputChunk DataChunk if err := outputChunk.initFromDuckVector(output, true); err != nil { - setFunctionError(info, getError(errAPI, err).Error()) + setFuncError(function_info, getError(errAPI, err).Error()) return } // Execute the user-defined scalar function for each row. - args := make([]driver.Value, len(inputChunk.columns)) + values := make([]driver.Value, len(inputChunk.columns)) rowCount := inputChunk.GetSize() - columnCount := len(args) + columnCount := len(values) var err error for rowIdx := 0; rowIdx < rowCount; rowIdx++ { - // Set the input arguments for each column of a row. + // Set the values for each row. for colIdx := 0; colIdx < columnCount; colIdx++ { - if args[colIdx], err = inputChunk.GetValue(colIdx, rowIdx); err != nil { - setFunctionError(info, getError(errAPI, err).Error()) + if values[colIdx], err = inputChunk.GetValue(colIdx, rowIdx); err != nil { + setFuncError(function_info, getError(errAPI, err).Error()) return } } // Execute the function and write the result to the output vector. var val any - if val, err = scalarFunction.ExecuteRow(args); err != nil { + if val, err = function.ExecuteRow(values); err != nil { break } if err = outputChunk.SetValue(0, rowIdx, val); err != nil { @@ -86,7 +90,7 @@ func scalar_udf_callback(info C.duckdb_function_info, input C.duckdb_data_chunk, } if err != nil { - setFunctionError(info, getError(errAPI, err).Error()) + setFuncError(function_info, getError(errAPI, err).Error()) } } @@ -96,16 +100,16 @@ func scalar_udf_delete_callback(extraInfo unsafe.Pointer) { h.Delete() } -func registerInputParameters(config ScalarFunctionConfig, scalarFunction C.duckdb_scalar_function) error { +func registerInputParams(config ScalarFuncConfig, extraInfo ScalarFuncExtraInfo, f C.duckdb_scalar_function) error { // Set variadic input parameters. - if config.VariadicTypeInfo() != nil { - logicalType := config.VariadicTypeInfo().logicalType() - C.duckdb_scalar_function_set_varargs(scalarFunction, logicalType) - C.duckdb_destroy_logical_type(&logicalType) + if extraInfo != nil && extraInfo.VariadicTypeInfo() != nil { + t := extraInfo.VariadicTypeInfo().logicalType() + C.duckdb_scalar_function_set_varargs(f, t) + C.duckdb_destroy_logical_type(&t) return nil } - // Set fixed input parameters. + // Set normal input parameters. if config.InputTypeInfos() == nil { return errScalarUDFNilInputTypes } @@ -113,76 +117,77 @@ func registerInputParameters(config ScalarFunctionConfig, scalarFunction C.duckd return errScalarUDFEmptyInputTypes } - for i, inputTypeInfo := range config.InputTypeInfos() { - if inputTypeInfo == nil { + for i, info := range config.InputTypeInfos() { + if info == nil { return addIndexToError(errScalarUDFInputTypeIsNil, i) } - logicalType := inputTypeInfo.logicalType() - C.duckdb_scalar_function_add_parameter(scalarFunction, logicalType) - C.duckdb_destroy_logical_type(&logicalType) + t := info.logicalType() + C.duckdb_scalar_function_add_parameter(f, t) + C.duckdb_destroy_logical_type(&t) } return nil } -func registerResultParameters(config ScalarFunctionConfig, scalarFunction C.duckdb_scalar_function) error { +func registerResultParams(config ScalarFuncConfig, f C.duckdb_scalar_function) error { if config.ResultTypeInfo() == nil { return errScalarUDFResultTypeIsNil } if config.ResultTypeInfo().InternalType() == TYPE_ANY { return errScalarUDFResultTypeIsANY } - logicalType := config.ResultTypeInfo().logicalType() - C.duckdb_scalar_function_set_return_type(scalarFunction, logicalType) - C.duckdb_destroy_logical_type(&logicalType) + t := config.ResultTypeInfo().logicalType() + C.duckdb_scalar_function_set_return_type(f, t) + C.duckdb_destroy_logical_type(&t) return nil } -func createScalarFunction(name string, f ScalarFunction) (C.duckdb_scalar_function, error) { +func createScalarFunc(name string, f ScalarFunc) (C.duckdb_scalar_function, error) { if name == "" { return nil, errScalarUDFNoName } if f == nil { return nil, errScalarUDFIsNil } - scalarFunction := C.duckdb_create_scalar_function() + function := C.duckdb_create_scalar_function() // Set the name. - functionName := C.CString(name) - C.duckdb_scalar_function_set_name(scalarFunction, functionName) - C.duckdb_free(unsafe.Pointer(functionName)) + cName := C.CString(name) + C.duckdb_scalar_function_set_name(function, cName) + C.duckdb_free(unsafe.Pointer(cName)) // Configure the scalar function. config := f.Config() - if err := registerInputParameters(config, scalarFunction); err != nil { + extraInfo := f.ExtraInfo() + if err := registerInputParams(config, extraInfo, function); err != nil { return nil, err } - if err := registerResultParameters(config, scalarFunction); err != nil { + if err := registerResultParams(config, function); err != nil { return nil, err } - if config.SpecialNullHandling() { - C.duckdb_scalar_function_set_special_handling(scalarFunction) + if extraInfo != nil && extraInfo.SpecialNullHandling() { + C.duckdb_scalar_function_set_special_handling(function) } - if config.Volatile() { - C.duckdb_scalar_function_set_volatile(scalarFunction) + if extraInfo != nil && extraInfo.Volatile() { + C.duckdb_scalar_function_set_volatile(function) } // Set the function callback. - C.duckdb_scalar_function_set_function(scalarFunction, C.scalar_udf_callback_t(C.scalar_udf_callback)) + C.duckdb_scalar_function_set_function(function, C.scalar_udf_callback_t(C.scalar_udf_callback)) // Set data available during execution. - extraInfoHandle := cgo.NewHandle(f) + h := cgo.NewHandle(f) C.duckdb_scalar_function_set_extra_info( - scalarFunction, - unsafe.Pointer(&extraInfoHandle), + function, + unsafe.Pointer(&h), C.duckdb_delete_callback_t(C.scalar_udf_delete_callback)) - return scalarFunction, nil + return function, nil } -// RegisterScalarUDF registers a scalar UDF. -// This function takes ownership of f, so you must pass it as a pointer. -func RegisterScalarUDF(c *sql.Conn, name string, f ScalarFunction) error { - scalarFunction, err := createScalarFunction(name, f) +// RegisterScalarUDF registers a scalar user-defined function. +// The function takes ownership of f, so you must pass it as a pointer. +func RegisterScalarUDF(c *sql.Conn, name string, f ScalarFunc) error { + scalarFunc, err := createScalarFunc(name, f) if err != nil { return getError(errAPI, err) } @@ -190,8 +195,8 @@ func RegisterScalarUDF(c *sql.Conn, name string, f ScalarFunction) error { // Register the function on the underlying driver connection exposed by c.Raw. err = c.Raw(func(driverConn any) error { con := driverConn.(*conn) - state := C.duckdb_register_scalar_function(con.duckdbCon, scalarFunction) - C.duckdb_destroy_scalar_function(&scalarFunction) + state := C.duckdb_register_scalar_function(con.duckdbCon, scalarFunc) + C.duckdb_destroy_scalar_function(&scalarFunc) if state == C.DuckDBError { return getError(errAPI, errScalarUDFCreate) } @@ -200,14 +205,14 @@ func RegisterScalarUDF(c *sql.Conn, name string, f ScalarFunction) error { return err } -func RegisterScalarUDFSet(c *sql.Conn, name string, functions ...ScalarFunction) error { - functionName := C.CString(name) - set := C.duckdb_create_scalar_function_set(functionName) - C.duckdb_free(unsafe.Pointer(functionName)) +func RegisterScalarUDFSet(c *sql.Conn, name string, functions ...ScalarFunc) error { + cName := C.CString(name) + set := C.duckdb_create_scalar_function_set(cName) + C.duckdb_free(unsafe.Pointer(cName)) // Create each function and add it to the set. for i, f := range functions { - scalarFunction, err := createScalarFunction(name, f) + scalarFunction, err := createScalarFunc(name, f) if err != nil { C.duckdb_destroy_scalar_function(&scalarFunction) C.duckdb_destroy_scalar_function_set(&set) diff --git a/scalar_udf_test.go b/scalar_udf_test.go index dc467c16..23c3f10a 100644 --- a/scalar_udf_test.go +++ b/scalar_udf_test.go @@ -13,35 +13,27 @@ import ( var currentInfo TypeInfo -type simpleScalarUDF struct{} +type simpleSUDF struct{} -type simpleScalarUDFConfig struct{} +type simpleSUDFConfig struct{} -func (*simpleScalarUDFConfig) InputTypeInfos() []TypeInfo { +func (*simpleSUDFConfig) InputTypeInfos() []TypeInfo { return []TypeInfo{currentInfo, currentInfo} } -func (*simpleScalarUDFConfig) ResultTypeInfo() TypeInfo { +func (*simpleSUDFConfig) ResultTypeInfo() TypeInfo { return currentInfo } -func (*simpleScalarUDFConfig) VariadicTypeInfo() TypeInfo { - return nil -} - -func (*simpleScalarUDFConfig) Volatile() bool { - return false -} - -func (*simpleScalarUDFConfig) SpecialNullHandling() bool { - return false +func (*simpleSUDF) Config() ScalarFuncConfig { + return &simpleSUDFConfig{} } -func (*simpleScalarUDF) Config() ScalarFunctionConfig { - return &simpleScalarUDFConfig{} +func (*simpleSUDF) ExtraInfo() ScalarFuncExtraInfo { + return nil } -func (*simpleScalarUDF) ExecuteRow(args []driver.Value) (any, error) { +func (*simpleSUDF) ExecuteRow(args []driver.Value) (any, error) { if args[0] == nil || args[1] == nil { return nil, nil } @@ -59,7 +51,7 @@ func TestSimpleScalarUDF(t *testing.T) { currentInfo, err = NewTypeInfo(TYPE_INTEGER) require.NoError(t, err) - var udf *simpleScalarUDF + var udf *simpleSUDF err = RegisterScalarUDF(c, "my_sum", udf) require.NoError(t, err) @@ -80,35 +72,27 @@ func TestSimpleScalarUDF(t *testing.T) { require.NoError(t, db.Close()) } -type allTypesScalarUDF struct{} +type typesSUDF struct{} -type allTypesScalarUDFConfig struct{} +type typesSUDFConfig struct{} -func (*allTypesScalarUDFConfig) InputTypeInfos() []TypeInfo { +func (*typesSUDFConfig) InputTypeInfos() []TypeInfo { return []TypeInfo{currentInfo} } -func (*allTypesScalarUDFConfig) ResultTypeInfo() TypeInfo { +func (*typesSUDFConfig) ResultTypeInfo() TypeInfo { return currentInfo } -func (*allTypesScalarUDFConfig) VariadicTypeInfo() TypeInfo { - return nil +func (*typesSUDF) Config() ScalarFuncConfig { + return &typesSUDFConfig{} } -func (*allTypesScalarUDFConfig) Volatile() bool { - return false -} - -func (*allTypesScalarUDFConfig) SpecialNullHandling() bool { - return false -} - -func (*allTypesScalarUDF) Config() ScalarFunctionConfig { - return &allTypesScalarUDFConfig{} +func (*typesSUDF) ExtraInfo() ScalarFuncExtraInfo { + return nil } -func (*allTypesScalarUDF) ExecuteRow(args []driver.Value) (any, error) { +func (*typesSUDF) ExecuteRow(args []driver.Value) (any, error) { return args[0], nil } @@ -126,7 +110,7 @@ func TestAllTypesScalarUDF(t *testing.T) { _, err = c.ExecContext(context.Background(), `CREATE TYPE greeting AS ENUM ('hello', 'world')`) require.NoError(t, err) - var udf *allTypesScalarUDF + var udf *typesSUDF err = RegisterScalarUDF(c, "my_identity", udf) require.NoError(t, err) @@ -144,35 +128,69 @@ func TestAllTypesScalarUDF(t *testing.T) { } } -type variadicScalarUDF struct{} +func TestScalarUDFSet(t *testing.T) { + db, err := sql.Open("duckdb", "") + require.NoError(t, err) -type variadicScalarUDFConfig struct{} + c, err := db.Conn(context.Background()) + require.NoError(t, err) + + currentInfo, err = NewTypeInfo(TYPE_INTEGER) + require.NoError(t, err) + + var udf1 *simpleSUDF + var udf2 *typesSUDF + err = RegisterScalarUDFSet(c, "my_addition", udf1, udf2) + require.NoError(t, err) -func (*variadicScalarUDFConfig) InputTypeInfos() []TypeInfo { + var sum int + row := db.QueryRow(`SELECT my_addition(10, 42) AS sum`) + require.NoError(t, row.Scan(&sum)) + require.Equal(t, 52, sum) + + row = db.QueryRow(`SELECT my_addition(42) AS sum`) + require.NoError(t, row.Scan(&sum)) + require.Equal(t, 42, sum) + + require.NoError(t, c.Close()) + require.NoError(t, db.Close()) +} + +type variadicSUDF struct{} + +type variadicSUDFConfig struct{} + +type variadicSUDFExtraInfo struct{} + +func (*variadicSUDFConfig) InputTypeInfos() []TypeInfo { return nil } -func (*variadicScalarUDFConfig) ResultTypeInfo() TypeInfo { +func (*variadicSUDFConfig) ResultTypeInfo() TypeInfo { return currentInfo } -func (*variadicScalarUDFConfig) VariadicTypeInfo() TypeInfo { +func (*variadicSUDFExtraInfo) VariadicTypeInfo() TypeInfo { return currentInfo } -func (*variadicScalarUDFConfig) Volatile() bool { +func (*variadicSUDFExtraInfo) Volatile() bool { return true } -func (*variadicScalarUDFConfig) SpecialNullHandling() bool { +func (*variadicSUDFExtraInfo) SpecialNullHandling() bool { return true } -func (*variadicScalarUDF) Config() ScalarFunctionConfig { - return &variadicScalarUDFConfig{} +func (*variadicSUDF) Config() ScalarFuncConfig { + return &variadicSUDFConfig{} } -func (*variadicScalarUDF) ExecuteRow(args []driver.Value) (any, error) { +func (*variadicSUDF) ExtraInfo() ScalarFuncExtraInfo { + return &variadicSUDFExtraInfo{} +} + +func (*variadicSUDF) ExecuteRow(args []driver.Value) (any, error) { sum := int32(0) for _, val := range args { if val == nil { @@ -183,34 +201,6 @@ func (*variadicScalarUDF) ExecuteRow(args []driver.Value) (any, error) { return sum, nil } -func TestScalarUDFSet(t *testing.T) { - db, err := sql.Open("duckdb", "") - require.NoError(t, err) - - c, err := db.Conn(context.Background()) - require.NoError(t, err) - - currentInfo, err = NewTypeInfo(TYPE_INTEGER) - require.NoError(t, err) - - var udf1 *simpleScalarUDF - var udf2 *allTypesScalarUDF - err = RegisterScalarUDFSet(c, "my_addition", udf1, udf2) - require.NoError(t, err) - - var sum int - row := db.QueryRow(`SELECT my_addition(10, 42) AS sum`) - require.NoError(t, row.Scan(&sum)) - require.Equal(t, 52, sum) - - row = db.QueryRow(`SELECT my_addition(42) AS sum`) - require.NoError(t, row.Scan(&sum)) - require.Equal(t, 42, sum) - - require.NoError(t, c.Close()) - require.NoError(t, db.Close()) -} - func TestVariadicScalarUDF(t *testing.T) { db, err := sql.Open("duckdb", "") require.NoError(t, err) @@ -221,7 +211,7 @@ func TestVariadicScalarUDF(t *testing.T) { currentInfo, err = NewTypeInfo(TYPE_INTEGER) require.NoError(t, err) - var udf *variadicScalarUDF + var udf *variadicSUDF err = RegisterScalarUDF(c, "my_variadic_sum", udf) require.NoError(t, err) @@ -250,19 +240,11 @@ func TestVariadicScalarUDF(t *testing.T) { require.NoError(t, db.Close()) } -type anyScalarUDF struct{} - -type anyScalarUDFConfig struct{} - -func (*anyScalarUDFConfig) InputTypeInfos() []TypeInfo { - return nil -} +type anyTypeSUDF struct{} -func (*anyScalarUDFConfig) ResultTypeInfo() TypeInfo { - return currentInfo -} +type anyTypeSUDFExtraInfo struct{} -func (*anyScalarUDFConfig) VariadicTypeInfo() TypeInfo { +func (*anyTypeSUDFExtraInfo) VariadicTypeInfo() TypeInfo { info, err := NewTypeInfo(TYPE_ANY) if err != nil { panic(err) @@ -270,19 +252,23 @@ func (*anyScalarUDFConfig) VariadicTypeInfo() TypeInfo { return info } -func (*anyScalarUDFConfig) Volatile() bool { +func (*anyTypeSUDFExtraInfo) Volatile() bool { return true } -func (*anyScalarUDFConfig) SpecialNullHandling() bool { +func (*anyTypeSUDFExtraInfo) SpecialNullHandling() bool { return true } -func (*anyScalarUDF) Config() ScalarFunctionConfig { - return &anyScalarUDFConfig{} +func (*anyTypeSUDF) Config() ScalarFuncConfig { + return &variadicSUDFConfig{} +} + +func (*anyTypeSUDF) ExtraInfo() ScalarFuncExtraInfo { + return &anyTypeSUDFExtraInfo{} } -func (*anyScalarUDF) ExecuteRow(args []driver.Value) (any, error) { +func (*anyTypeSUDF) ExecuteRow(args []driver.Value) (any, error) { count := int32(0) for _, val := range args { if val == nil { @@ -302,7 +288,7 @@ func TestANYScalarUDF(t *testing.T) { currentInfo, err = NewTypeInfo(TYPE_INTEGER) require.NoError(t, err) - var udf *anyScalarUDF + var udf *anyTypeSUDF err = RegisterScalarUDF(c, "my_null_count", udf) require.NoError(t, err) @@ -331,143 +317,101 @@ func TestANYScalarUDF(t *testing.T) { require.NoError(t, db.Close()) } -type errNilInputScalarUDF struct{} - -type errNilInputScalarUDFConfig struct{} - -func (*errNilInputScalarUDFConfig) InputTypeInfos() []TypeInfo { - return nil -} +type errInputSUDF struct{} -func (*errNilInputScalarUDFConfig) ResultTypeInfo() TypeInfo { - return currentInfo +func (*errInputSUDF) Config() ScalarFuncConfig { + return &variadicSUDFConfig{} } -func (*errNilInputScalarUDFConfig) VariadicTypeInfo() TypeInfo { +func (*errInputSUDF) ExtraInfo() ScalarFuncExtraInfo { return nil } -func (*errNilInputScalarUDFConfig) Volatile() bool { - return false -} - -func (*errNilInputScalarUDFConfig) SpecialNullHandling() bool { - return false -} - -func (*errNilInputScalarUDF) Config() ScalarFunctionConfig { - return &errNilInputScalarUDFConfig{} -} - -func (*errNilInputScalarUDF) ExecuteRow([]driver.Value) (any, error) { +func (*errInputSUDF) ExecuteRow([]driver.Value) (any, error) { return nil, nil } -type errEmptyInputScalarUDF struct{} +type errEmptyInputSUDF struct{} -type errEmptyInputScalarUDFConfig struct{} +type errEmptyInputSUDFConfig struct{} -func (*errEmptyInputScalarUDFConfig) InputTypeInfos() []TypeInfo { +func (*errEmptyInputSUDFConfig) InputTypeInfos() []TypeInfo { return []TypeInfo{} } -func (*errEmptyInputScalarUDFConfig) ResultTypeInfo() TypeInfo { +func (*errEmptyInputSUDFConfig) ResultTypeInfo() TypeInfo { return currentInfo } -func (*errEmptyInputScalarUDFConfig) VariadicTypeInfo() TypeInfo { - return nil -} - -func (*errEmptyInputScalarUDFConfig) Volatile() bool { - return false +func (*errEmptyInputSUDF) Config() ScalarFuncConfig { + return &errEmptyInputSUDFConfig{} } -func (*errEmptyInputScalarUDFConfig) SpecialNullHandling() bool { - return false -} - -func (*errEmptyInputScalarUDF) Config() ScalarFunctionConfig { - return &errEmptyInputScalarUDFConfig{} +func (*errEmptyInputSUDF) ExtraInfo() ScalarFuncExtraInfo { + return nil } -func (*errEmptyInputScalarUDF) ExecuteRow([]driver.Value) (any, error) { +func (*errEmptyInputSUDF) ExecuteRow([]driver.Value) (any, error) { return nil, nil } -type errInputIsNilScalarUDF struct{} +type errInputNilSUDF struct{} -type errInputIsNilScalarUDFConfig struct{} +type errInputNilSUDFConfig struct{} -func (*errInputIsNilScalarUDFConfig) InputTypeInfos() []TypeInfo { +func (*errInputNilSUDFConfig) InputTypeInfos() []TypeInfo { return []TypeInfo{nil} } -func (*errInputIsNilScalarUDFConfig) ResultTypeInfo() TypeInfo { +func (*errInputNilSUDFConfig) ResultTypeInfo() TypeInfo { return currentInfo } -func (*errInputIsNilScalarUDFConfig) VariadicTypeInfo() TypeInfo { - return nil -} - -func (*errInputIsNilScalarUDFConfig) Volatile() bool { - return false +func (*errInputNilSUDF) Config() ScalarFuncConfig { + return &errInputNilSUDFConfig{} } -func (*errInputIsNilScalarUDFConfig) SpecialNullHandling() bool { - return false -} - -func (*errInputIsNilScalarUDF) Config() ScalarFunctionConfig { - return &errInputIsNilScalarUDFConfig{} +func (*errInputNilSUDF) ExtraInfo() ScalarFuncExtraInfo { + return nil } -func (*errInputIsNilScalarUDF) ExecuteRow([]driver.Value) (any, error) { +func (*errInputNilSUDF) ExecuteRow([]driver.Value) (any, error) { return nil, nil } -type errResultIsNilScalarUDF struct{} +type errResultNilSUDF struct{} -type errResultIsNilScalarUDFConfig struct{} +type errResultNilSUDFConfig struct{} -func (*errResultIsNilScalarUDFConfig) InputTypeInfos() []TypeInfo { +func (*errResultNilSUDFConfig) InputTypeInfos() []TypeInfo { return []TypeInfo{currentInfo} } -func (*errResultIsNilScalarUDFConfig) ResultTypeInfo() TypeInfo { - return nil -} - -func (*errResultIsNilScalarUDFConfig) VariadicTypeInfo() TypeInfo { +func (*errResultNilSUDFConfig) ResultTypeInfo() TypeInfo { return nil } -func (*errResultIsNilScalarUDFConfig) Volatile() bool { - return false -} - -func (*errResultIsNilScalarUDFConfig) SpecialNullHandling() bool { - return false +func (*errResultNilSUDF) Config() ScalarFuncConfig { + return &errResultNilSUDFConfig{} } -func (*errResultIsNilScalarUDF) Config() ScalarFunctionConfig { - return &errResultIsNilScalarUDFConfig{} +func (*errResultNilSUDF) ExtraInfo() ScalarFuncExtraInfo { + return nil } -func (*errResultIsNilScalarUDF) ExecuteRow([]driver.Value) (any, error) { +func (*errResultNilSUDF) ExecuteRow([]driver.Value) (any, error) { return nil, nil } -type errResultIsANYScalarUDF struct{} +type errResultAnySUDF struct{} -type errResultIsANYScalarUDFConfig struct{} +type errResultAnySUDFConfig struct{} -func (*errResultIsANYScalarUDFConfig) InputTypeInfos() []TypeInfo { +func (*errResultAnySUDFConfig) InputTypeInfos() []TypeInfo { return []TypeInfo{currentInfo} } -func (*errResultIsANYScalarUDFConfig) ResultTypeInfo() TypeInfo { +func (*errResultAnySUDFConfig) ResultTypeInfo() TypeInfo { info, err := NewTypeInfo(TYPE_ANY) if err != nil { panic(err) @@ -475,34 +419,30 @@ func (*errResultIsANYScalarUDFConfig) ResultTypeInfo() TypeInfo { return info } -func (*errResultIsANYScalarUDFConfig) VariadicTypeInfo() TypeInfo { - return nil -} - -func (*errResultIsANYScalarUDFConfig) Volatile() bool { - return false -} - -func (*errResultIsANYScalarUDFConfig) SpecialNullHandling() bool { - return false +func (*errResultAnySUDF) Config() ScalarFuncConfig { + return &errResultAnySUDFConfig{} } -func (*errResultIsANYScalarUDF) Config() ScalarFunctionConfig { - return &errResultIsANYScalarUDFConfig{} +func (*errResultAnySUDF) ExtraInfo() ScalarFuncExtraInfo { + return nil } -func (*errResultIsANYScalarUDF) ExecuteRow([]driver.Value) (any, error) { +func (*errResultAnySUDF) ExecuteRow([]driver.Value) (any, error) { return nil, nil } -type errExecScalarUDF struct{} +type errExecSUDF struct{} -func (*errExecScalarUDF) Config() ScalarFunctionConfig { - scalarUDF := simpleScalarUDF{} +func (*errExecSUDF) Config() ScalarFuncConfig { + scalarUDF := simpleSUDF{} return scalarUDF.Config() } -func (*errExecScalarUDF) ExecuteRow([]driver.Value) (any, error) { +func (*errExecSUDF) ExtraInfo() ScalarFuncExtraInfo { + return nil +} + +func (*errExecSUDF) ExecuteRow([]driver.Value) (any, error) { return nil, errors.New("test invalid execution") } @@ -519,36 +459,36 @@ func TestScalarUDFErrors(t *testing.T) { require.NoError(t, err) // Empty name. - var emptyNameUDF *simpleScalarUDF + var emptyNameUDF *simpleSUDF err = RegisterScalarUDF(c, "", emptyNameUDF) testError(t, err, errAPI.Error(), errScalarUDFCreate.Error(), errScalarUDFNoName.Error()) // Invalid input parameters. - var errNilInputUDF *errNilInputScalarUDF - err = RegisterScalarUDF(c, "err_nil_input", errNilInputUDF) + var errInputUDF *errInputSUDF + err = RegisterScalarUDF(c, "err_input", errInputUDF) testError(t, err, errAPI.Error(), errScalarUDFCreate.Error(), errScalarUDFNilInputTypes.Error()) - var errEmptyInputUDF *errEmptyInputScalarUDF + var errEmptyInputUDF *errEmptyInputSUDF err = RegisterScalarUDF(c, "err_empty_input", errEmptyInputUDF) testError(t, err, errAPI.Error(), errScalarUDFCreate.Error(), errScalarUDFEmptyInputTypes.Error()) - var errInputIsNilUDF *errInputIsNilScalarUDF - err = RegisterScalarUDF(c, "err_input_type_is_nil", errInputIsNilUDF) + var errInputNilUDF *errInputNilSUDF + err = RegisterScalarUDF(c, "err_input_type_is_nil", errInputNilUDF) testError(t, err, errAPI.Error(), errScalarUDFCreate.Error(), errScalarUDFInputTypeIsNil.Error()) // Invalid result parameters. - var errResultIsNil *errResultIsNilScalarUDF - err = RegisterScalarUDF(c, "err_result_type_is_nil", errResultIsNil) + var errResultNil *errResultNilSUDF + err = RegisterScalarUDF(c, "err_result_type_is_nil", errResultNil) testError(t, err, errAPI.Error(), errScalarUDFCreate.Error(), errScalarUDFResultTypeIsNil.Error()) - var errResultIsANY *errResultIsANYScalarUDF - err = RegisterScalarUDF(c, "err_result_type_is_any", errResultIsANY) + var errResultAny *errResultAnySUDF + err = RegisterScalarUDF(c, "err_result_type_is_any", errResultAny) testError(t, err, errAPI.Error(), errScalarUDFCreate.Error(), errScalarUDFResultTypeIsANY.Error()) // Error during execution. - var errExecUDF *errExecScalarUDF + var errExecUDF *errExecSUDF err = RegisterScalarUDF(c, "err_exec", errExecUDF) require.NoError(t, err) row := db.QueryRow(`SELECT err_exec(10, 10) AS msg`) @@ -556,15 +496,15 @@ func TestScalarUDFErrors(t *testing.T) { // Register the same scalar function a second time. // Since RegisterScalarUDF takes ownership of udf, we are now passing nil. - var udf *simpleScalarUDF + var udf *simpleSUDF err = RegisterScalarUDF(c, "my_sum", udf) require.NoError(t, err) err = RegisterScalarUDF(c, "my_sum", udf) testError(t, err, errAPI.Error(), errScalarUDFCreate.Error()) // Register a scalar function whose name already exists. - var udfDuplicateName *simpleScalarUDF - err = RegisterScalarUDF(c, "my_sum", udfDuplicateName) + var errDuplicateUDF *simpleSUDF + err = RegisterScalarUDF(c, "my_sum", errDuplicateUDF) testError(t, err, errAPI.Error(), errScalarUDFCreate.Error()) // Register a scalar function that is nil. @@ -573,8 +513,8 @@ func TestScalarUDFErrors(t *testing.T) { require.NoError(t, c.Close()) // Test registering the scalar function on a closed connection. - var udfOnClosedCon *simpleScalarUDF - err = RegisterScalarUDF(c, "closed_con", udfOnClosedCon) + var errClosedConUDF *simpleSUDF + err = RegisterScalarUDF(c, "closed_con", errClosedConUDF) require.ErrorContains(t, err, sql.ErrConnDone.Error()) require.NoError(t, db.Close()) }