diff --git a/errors.go b/errors.go index b41e2dc2..651e9295 100644 --- a/errors.go +++ b/errors.go @@ -105,6 +105,7 @@ var ( errScalarUDFEmptyInputTypes = fmt.Errorf("%w: empty input types", errScalarUDFCreate) errScalarUDFInputTypeIsNil = fmt.Errorf("%w: input type is nil", errScalarUDFCreate) errScalarUDFResultTypeIsNil = fmt.Errorf("%w: result type is nil", errScalarUDFCreate) + errScalarUDFResultTypeIsANY = fmt.Errorf("%w: result type is ANY", errScalarUDFCreate) // Errors not covered in tests. errConnect = errors.New("could not connect to database") diff --git a/scalar_udf.go b/scalar_udf.go index 9ccd6b6a..e5832001 100644 --- a/scalar_udf.go +++ b/scalar_udf.go @@ -23,6 +23,8 @@ type ScalarFunctionConfig interface { InputTypeInfos() []TypeInfo ResultTypeInfo() TypeInfo VariadicTypeInfo() TypeInfo + Volatile() bool + SpecialNullHandling() bool } type ScalarFunction interface { @@ -126,6 +128,9 @@ func registerResultParameters(config ScalarFunctionConfig, scalarFunction C.duck if config.ResultTypeInfo() == nil { return errScalarUDFResultTypeIsNil } + if config.ResultTypeInfo().InternalType() == TYPE_ANY { + return errScalarUDFResultTypeIsANY + } logicalType := config.ResultTypeInfo().logicalType() C.duckdb_scalar_function_set_return_type(scalarFunction, logicalType) C.duckdb_destroy_logical_type(&logicalType) @@ -153,17 +158,20 @@ func RegisterScalarUDF(c *sql.Conn, name string, f ScalarFunction) error { scalarFunction := C.duckdb_create_scalar_function() C.duckdb_scalar_function_set_name(scalarFunction, functionName) - // Get the configuration. + // Configure the scalar function. config := f.Config() - - // Register the input parameters. if err := registerInputParameters(config, scalarFunction); err != nil { return getError(errAPI, err) } - // Register the result parameters. if err := registerResultParameters(config, scalarFunction); err != nil { return getError(errAPI, err) } + if config.SpecialNullHandling() { + C.duckdb_scalar_function_set_special_handling(scalarFunction) + } + if config.Volatile() { + C.duckdb_scalar_function_set_volatile(scalarFunction) + } // Set the function callback. C.duckdb_scalar_function_set_function(scalarFunction, C.scalar_udf_callback_t(C.scalar_udf_callback)) diff --git a/scalar_udf_test.go b/scalar_udf_test.go index a54d9bde..eebb1c71 100644 --- a/scalar_udf_test.go +++ b/scalar_udf_test.go @@ -29,6 +29,14 @@ func (*simpleScalarUDFConfig) VariadicTypeInfo() TypeInfo { return nil } +func (*simpleScalarUDFConfig) Volatile() bool { + return false +} + +func (*simpleScalarUDFConfig) SpecialNullHandling() bool { + return false +} + func (*simpleScalarUDF) Config() ScalarFunctionConfig { return &simpleScalarUDFConfig{} } @@ -88,6 +96,14 @@ func (*allTypesScalarUDFConfig) VariadicTypeInfo() TypeInfo { return nil } +func (*allTypesScalarUDFConfig) Volatile() bool { + return false +} + +func (*allTypesScalarUDFConfig) SpecialNullHandling() bool { + return false +} + func (*allTypesScalarUDF) Config() ScalarFunctionConfig { return &allTypesScalarUDFConfig{} } @@ -144,6 +160,14 @@ func (*variadicScalarUDFConfig) VariadicTypeInfo() TypeInfo { return currentInfo } +func (*variadicScalarUDFConfig) Volatile() bool { + return true +} + +func (*variadicScalarUDFConfig) SpecialNullHandling() bool { + return true +} + func (*variadicScalarUDF) Config() ScalarFunctionConfig { return &variadicScalarUDFConfig{} } @@ -198,6 +222,90 @@ func TestVariadicScalarUDF(t *testing.T) { require.NoError(t, db.Close()) } +type anyScalarUDF struct{} + +type anyScalarUDFConfig struct{} + +func (*anyScalarUDFConfig) InputTypeInfos() []TypeInfo { + return nil +} + +func (*anyScalarUDFConfig) ResultTypeInfo() TypeInfo { + return currentInfo +} + +func (*anyScalarUDFConfig) VariadicTypeInfo() TypeInfo { + info, err := NewTypeInfo(TYPE_ANY) + if err != nil { + panic(err) + } + return info +} + +func (*anyScalarUDFConfig) Volatile() bool { + return true +} + +func (*anyScalarUDFConfig) SpecialNullHandling() bool { + return true +} + +func (*anyScalarUDF) Config() ScalarFunctionConfig { + return &anyScalarUDFConfig{} +} + +func (*anyScalarUDF) ExecuteRow(args []driver.Value) (any, error) { + count := int32(0) + for _, val := range args { + if val == nil { + count++ + } + } + return count, nil +} + +func TestANYScalarUDF(t *testing.T) { + // TODO: once DuckDB has SQLNull type. + return + + db, err := sql.Open("duckdb", "") + require.NoError(t, err) + + c, err := db.Conn(context.Background()) + require.NoError(t, err) + + currentInfo, err = NewTypeInfo(TYPE_INTEGER) + require.NoError(t, err) + + var udf *anyScalarUDF + err = RegisterScalarUDF(c, "my_null_count", udf) + require.NoError(t, err) + + var count int + row := db.QueryRow(`SELECT my_null_count(10, 42, 2, 2, 2) AS msg`) + require.NoError(t, row.Scan(&count)) + require.Equal(t, 0, count) + + row = db.QueryRow(`SELECT my_null_count(10, NULL, NULL) AS msg`) + require.NoError(t, row.Scan(&count)) + require.Equal(t, 2, count) + + row = db.QueryRow(`SELECT my_null_count(10) AS msg`) + require.NoError(t, row.Scan(&count)) + require.Equal(t, 0, count) + + row = db.QueryRow(`SELECT my_null_count(NULL) AS msg`) + require.NoError(t, row.Scan(&count)) + require.Equal(t, 1, count) + + row = db.QueryRow(`SELECT my_null_count() AS msg`) + require.NoError(t, row.Scan(&count)) + require.Equal(t, 0, count) + + require.NoError(t, c.Close()) + require.NoError(t, db.Close()) +} + type errNilInputScalarUDF struct{} type errNilInputScalarUDFConfig struct{} @@ -214,6 +322,14 @@ func (*errNilInputScalarUDFConfig) VariadicTypeInfo() TypeInfo { return nil } +func (*errNilInputScalarUDFConfig) Volatile() bool { + return false +} + +func (*errNilInputScalarUDFConfig) SpecialNullHandling() bool { + return false +} + func (*errNilInputScalarUDF) Config() ScalarFunctionConfig { return &errNilInputScalarUDFConfig{} } @@ -238,6 +354,14 @@ func (*errEmptyInputScalarUDFConfig) VariadicTypeInfo() TypeInfo { return nil } +func (*errEmptyInputScalarUDFConfig) Volatile() bool { + return false +} + +func (*errEmptyInputScalarUDFConfig) SpecialNullHandling() bool { + return false +} + func (*errEmptyInputScalarUDF) Config() ScalarFunctionConfig { return &errEmptyInputScalarUDFConfig{} } @@ -262,6 +386,14 @@ func (*errInputIsNilScalarUDFConfig) VariadicTypeInfo() TypeInfo { return nil } +func (*errInputIsNilScalarUDFConfig) Volatile() bool { + return false +} + +func (*errInputIsNilScalarUDFConfig) SpecialNullHandling() bool { + return false +} + func (*errInputIsNilScalarUDF) Config() ScalarFunctionConfig { return &errInputIsNilScalarUDFConfig{} } @@ -286,6 +418,14 @@ func (*errResultIsNilScalarUDFConfig) VariadicTypeInfo() TypeInfo { return nil } +func (*errResultIsNilScalarUDFConfig) Volatile() bool { + return false +} + +func (*errResultIsNilScalarUDFConfig) SpecialNullHandling() bool { + return false +} + func (*errResultIsNilScalarUDF) Config() ScalarFunctionConfig { return &errResultIsNilScalarUDFConfig{} } @@ -294,6 +434,42 @@ func (*errResultIsNilScalarUDF) ExecuteRow([]driver.Value) (any, error) { return nil, nil } +type errResultIsANYScalarUDF struct{} + +type errResultIsANYScalarUDFConfig struct{} + +func (*errResultIsANYScalarUDFConfig) InputTypeInfos() []TypeInfo { + return []TypeInfo{currentInfo} +} + +func (*errResultIsANYScalarUDFConfig) ResultTypeInfo() TypeInfo { + info, err := NewTypeInfo(TYPE_ANY) + if err != nil { + panic(err) + } + return info +} + +func (*errResultIsANYScalarUDFConfig) VariadicTypeInfo() TypeInfo { + return nil +} + +func (*errResultIsANYScalarUDFConfig) Volatile() bool { + return false +} + +func (*errResultIsANYScalarUDFConfig) SpecialNullHandling() bool { + return false +} + +func (*errResultIsANYScalarUDF) Config() ScalarFunctionConfig { + return &errResultIsANYScalarUDFConfig{} +} + +func (*errResultIsANYScalarUDF) ExecuteRow([]driver.Value) (any, error) { + return nil, nil +} + type errExecScalarUDF struct{} func (*errExecScalarUDF) Config() ScalarFunctionConfig { @@ -336,10 +512,16 @@ func TestScalarUDFErrors(t *testing.T) { err = RegisterScalarUDF(c, "err_input_type_is_nil", errInputIsNilUDF) testError(t, err, errAPI.Error(), errScalarUDFCreate.Error(), errScalarUDFInputTypeIsNil.Error()) + // Invalid result parameters. + var errResultIsNil *errResultIsNilScalarUDF err = RegisterScalarUDF(c, "err_result_type_is_nil", errResultIsNil) testError(t, err, errAPI.Error(), errScalarUDFCreate.Error(), errScalarUDFResultTypeIsNil.Error()) + var errResultIsANY *errResultIsANYScalarUDF + err = RegisterScalarUDF(c, "err_result_type_is_any", errResultIsANY) + testError(t, err, errAPI.Error(), errScalarUDFCreate.Error(), errScalarUDFResultTypeIsANY.Error()) + // Error during execution. var errExecUDF *errExecScalarUDF err = RegisterScalarUDF(c, "err_exec", errExecUDF)