Skip to content

Commit

Permalink
adding ANY
Browse files Browse the repository at this point in the history
  • Loading branch information
taniabogatsch committed Sep 17, 2024
1 parent ab9b9fc commit cd7dc56
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 4 deletions.
1 change: 1 addition & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ var (
errScalarUDFEmptyInputTypes = fmt.Errorf("%w: empty input types", errScalarUDFCreate)
errScalarUDFInputTypeIsNil = fmt.Errorf("%w: input type is nil", errScalarUDFCreate)
errScalarUDFResultTypeIsNil = fmt.Errorf("%w: result type is nil", errScalarUDFCreate)
errScalarUDFResultTypeIsANY = fmt.Errorf("%w: result type is ANY", errScalarUDFCreate)

// Errors not covered in tests.
errConnect = errors.New("could not connect to database")
Expand Down
16 changes: 12 additions & 4 deletions scalar_udf.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ type ScalarFunctionConfig interface {
InputTypeInfos() []TypeInfo
ResultTypeInfo() TypeInfo
VariadicTypeInfo() TypeInfo
Volatile() bool
SpecialNullHandling() bool
}

type ScalarFunction interface {
Expand Down Expand Up @@ -126,6 +128,9 @@ func registerResultParameters(config ScalarFunctionConfig, scalarFunction C.duck
if config.ResultTypeInfo() == nil {
return errScalarUDFResultTypeIsNil
}
if config.ResultTypeInfo().InternalType() == TYPE_ANY {
return errScalarUDFResultTypeIsANY
}
logicalType := config.ResultTypeInfo().logicalType()
C.duckdb_scalar_function_set_return_type(scalarFunction, logicalType)
C.duckdb_destroy_logical_type(&logicalType)
Expand Down Expand Up @@ -153,17 +158,20 @@ func RegisterScalarUDF(c *sql.Conn, name string, f ScalarFunction) error {
scalarFunction := C.duckdb_create_scalar_function()
C.duckdb_scalar_function_set_name(scalarFunction, functionName)

// Get the configuration.
// Configure the scalar function.
config := f.Config()

// Register the input parameters.
if err := registerInputParameters(config, scalarFunction); err != nil {
return getError(errAPI, err)
}
// Register the result parameters.
if err := registerResultParameters(config, scalarFunction); err != nil {
return getError(errAPI, err)
}
if config.SpecialNullHandling() {
C.duckdb_scalar_function_set_special_handling(scalarFunction)
}
if config.Volatile() {
C.duckdb_scalar_function_set_volatile(scalarFunction)
}

// Set the function callback.
C.duckdb_scalar_function_set_function(scalarFunction, C.scalar_udf_callback_t(C.scalar_udf_callback))
Expand Down
182 changes: 182 additions & 0 deletions scalar_udf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ func (*simpleScalarUDFConfig) VariadicTypeInfo() TypeInfo {
return nil
}

func (*simpleScalarUDFConfig) Volatile() bool {
return false
}

func (*simpleScalarUDFConfig) SpecialNullHandling() bool {
return false
}

func (*simpleScalarUDF) Config() ScalarFunctionConfig {
return &simpleScalarUDFConfig{}
}
Expand Down Expand Up @@ -88,6 +96,14 @@ func (*allTypesScalarUDFConfig) VariadicTypeInfo() TypeInfo {
return nil
}

func (*allTypesScalarUDFConfig) Volatile() bool {
return false
}

func (*allTypesScalarUDFConfig) SpecialNullHandling() bool {
return false
}

func (*allTypesScalarUDF) Config() ScalarFunctionConfig {
return &allTypesScalarUDFConfig{}
}
Expand Down Expand Up @@ -144,6 +160,14 @@ func (*variadicScalarUDFConfig) VariadicTypeInfo() TypeInfo {
return currentInfo
}

func (*variadicScalarUDFConfig) Volatile() bool {
return true
}

func (*variadicScalarUDFConfig) SpecialNullHandling() bool {
return true
}

func (*variadicScalarUDF) Config() ScalarFunctionConfig {
return &variadicScalarUDFConfig{}
}
Expand Down Expand Up @@ -198,6 +222,90 @@ func TestVariadicScalarUDF(t *testing.T) {
require.NoError(t, db.Close())
}

type anyScalarUDF struct{}

type anyScalarUDFConfig struct{}

func (*anyScalarUDFConfig) InputTypeInfos() []TypeInfo {
return nil
}

func (*anyScalarUDFConfig) ResultTypeInfo() TypeInfo {
return currentInfo
}

func (*anyScalarUDFConfig) VariadicTypeInfo() TypeInfo {
info, err := NewTypeInfo(TYPE_ANY)
if err != nil {
panic(err)
}
return info
}

func (*anyScalarUDFConfig) Volatile() bool {
return true
}

func (*anyScalarUDFConfig) SpecialNullHandling() bool {
return true
}

func (*anyScalarUDF) Config() ScalarFunctionConfig {
return &anyScalarUDFConfig{}
}

func (*anyScalarUDF) ExecuteRow(args []driver.Value) (any, error) {
count := int32(0)
for _, val := range args {
if val == nil {
count++
}
}
return count, nil
}

func TestANYScalarUDF(t *testing.T) {
// TODO: once DuckDB has SQLNull type.
return

db, err := sql.Open("duckdb", "")
require.NoError(t, err)

c, err := db.Conn(context.Background())
require.NoError(t, err)

currentInfo, err = NewTypeInfo(TYPE_INTEGER)
require.NoError(t, err)

var udf *anyScalarUDF
err = RegisterScalarUDF(c, "my_null_count", udf)
require.NoError(t, err)

var count int
row := db.QueryRow(`SELECT my_null_count(10, 42, 2, 2, 2) AS msg`)
require.NoError(t, row.Scan(&count))
require.Equal(t, 0, count)

row = db.QueryRow(`SELECT my_null_count(10, NULL, NULL) AS msg`)
require.NoError(t, row.Scan(&count))
require.Equal(t, 2, count)

row = db.QueryRow(`SELECT my_null_count(10) AS msg`)
require.NoError(t, row.Scan(&count))
require.Equal(t, 0, count)

row = db.QueryRow(`SELECT my_null_count(NULL) AS msg`)
require.NoError(t, row.Scan(&count))
require.Equal(t, 1, count)

row = db.QueryRow(`SELECT my_null_count() AS msg`)
require.NoError(t, row.Scan(&count))
require.Equal(t, 0, count)

require.NoError(t, c.Close())
require.NoError(t, db.Close())
}

type errNilInputScalarUDF struct{}

type errNilInputScalarUDFConfig struct{}
Expand All @@ -214,6 +322,14 @@ func (*errNilInputScalarUDFConfig) VariadicTypeInfo() TypeInfo {
return nil
}

func (*errNilInputScalarUDFConfig) Volatile() bool {
return false
}

func (*errNilInputScalarUDFConfig) SpecialNullHandling() bool {
return false
}

func (*errNilInputScalarUDF) Config() ScalarFunctionConfig {
return &errNilInputScalarUDFConfig{}
}
Expand All @@ -238,6 +354,14 @@ func (*errEmptyInputScalarUDFConfig) VariadicTypeInfo() TypeInfo {
return nil
}

func (*errEmptyInputScalarUDFConfig) Volatile() bool {
return false
}

func (*errEmptyInputScalarUDFConfig) SpecialNullHandling() bool {
return false
}

func (*errEmptyInputScalarUDF) Config() ScalarFunctionConfig {
return &errEmptyInputScalarUDFConfig{}
}
Expand All @@ -262,6 +386,14 @@ func (*errInputIsNilScalarUDFConfig) VariadicTypeInfo() TypeInfo {
return nil
}

func (*errInputIsNilScalarUDFConfig) Volatile() bool {
return false
}

func (*errInputIsNilScalarUDFConfig) SpecialNullHandling() bool {
return false
}

func (*errInputIsNilScalarUDF) Config() ScalarFunctionConfig {
return &errInputIsNilScalarUDFConfig{}
}
Expand All @@ -286,6 +418,14 @@ func (*errResultIsNilScalarUDFConfig) VariadicTypeInfo() TypeInfo {
return nil
}

func (*errResultIsNilScalarUDFConfig) Volatile() bool {
return false
}

func (*errResultIsNilScalarUDFConfig) SpecialNullHandling() bool {
return false
}

func (*errResultIsNilScalarUDF) Config() ScalarFunctionConfig {
return &errResultIsNilScalarUDFConfig{}
}
Expand All @@ -294,6 +434,42 @@ func (*errResultIsNilScalarUDF) ExecuteRow([]driver.Value) (any, error) {
return nil, nil
}

type errResultIsANYScalarUDF struct{}

type errResultIsANYScalarUDFConfig struct{}

func (*errResultIsANYScalarUDFConfig) InputTypeInfos() []TypeInfo {
return []TypeInfo{currentInfo}
}

func (*errResultIsANYScalarUDFConfig) ResultTypeInfo() TypeInfo {
info, err := NewTypeInfo(TYPE_ANY)
if err != nil {
panic(err)
}
return info
}

func (*errResultIsANYScalarUDFConfig) VariadicTypeInfo() TypeInfo {
return nil
}

func (*errResultIsANYScalarUDFConfig) Volatile() bool {
return false
}

func (*errResultIsANYScalarUDFConfig) SpecialNullHandling() bool {
return false
}

func (*errResultIsANYScalarUDF) Config() ScalarFunctionConfig {
return &errResultIsANYScalarUDFConfig{}
}

func (*errResultIsANYScalarUDF) ExecuteRow([]driver.Value) (any, error) {
return nil, nil
}

type errExecScalarUDF struct{}

func (*errExecScalarUDF) Config() ScalarFunctionConfig {
Expand Down Expand Up @@ -336,10 +512,16 @@ func TestScalarUDFErrors(t *testing.T) {
err = RegisterScalarUDF(c, "err_input_type_is_nil", errInputIsNilUDF)
testError(t, err, errAPI.Error(), errScalarUDFCreate.Error(), errScalarUDFInputTypeIsNil.Error())

// Invalid result parameters.

var errResultIsNil *errResultIsNilScalarUDF
err = RegisterScalarUDF(c, "err_result_type_is_nil", errResultIsNil)
testError(t, err, errAPI.Error(), errScalarUDFCreate.Error(), errScalarUDFResultTypeIsNil.Error())

var errResultIsANY *errResultIsANYScalarUDF
err = RegisterScalarUDF(c, "err_result_type_is_any", errResultIsANY)
testError(t, err, errAPI.Error(), errScalarUDFCreate.Error(), errScalarUDFResultTypeIsANY.Error())

// Error during execution.
var errExecUDF *errExecScalarUDF
err = RegisterScalarUDF(c, "err_exec", errExecUDF)
Expand Down

0 comments on commit cd7dc56

Please sign in to comment.