Skip to content

Commit

Permalink
make executor extensible
Browse files Browse the repository at this point in the history
  • Loading branch information
taniabogatsch committed Sep 19, 2024
1 parent 9f8546d commit 354897a
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 26 deletions.
1 change: 1 addition & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ var (
errScalarUDFCreate = errors.New("could not create scalar UDF")
errScalarUDFNoName = fmt.Errorf("%w: missing name", errScalarUDFCreate)
errScalarUDFIsNil = fmt.Errorf("%w: function is nil", errScalarUDFCreate)
errScalarUDFNoExecutor = fmt.Errorf("%w: executor is nil", errScalarUDFCreate)
errScalarUDFNilInputTypes = fmt.Errorf("%w: input types are nil", errScalarUDFCreate)
errScalarUDFEmptyInputTypes = fmt.Errorf("%w: empty input types", errScalarUDFCreate)
errScalarUDFInputTypeIsNil = fmt.Errorf("%w: input type is nil", errScalarUDFCreate)
Expand Down
34 changes: 22 additions & 12 deletions scalar_udf.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import (
"unsafe"
)

type rowFn func(args []driver.Value) (any, error)

type ScalarFuncConfig struct {
InputTypeInfos []TypeInfo
ResultTypeInfo TypeInfo
Expand All @@ -28,9 +30,13 @@ type ScalarFuncConfig struct {
SpecialNullHandling bool
}

type ScalarFuncExecutor struct {
RowExecutor rowFn
}

type ScalarFunc interface {
Config() ScalarFuncConfig
ExecuteRow(args []driver.Value) (any, error)
Executor() ScalarFuncExecutor
}

func setFuncError(function_info C.duckdb_function_info, msg string) {
Expand Down Expand Up @@ -62,12 +68,12 @@ func scalar_udf_callback(function_info C.duckdb_function_info, input C.duckdb_da
}

// Execute the user-defined scalar function for each row.
executor := function.Executor()
values := make([]driver.Value, len(inputChunk.columns))
rowCount := inputChunk.GetSize()
columnCount := len(values)
var err error

for rowIdx := 0; rowIdx < rowCount; rowIdx++ {
var err error
for rowIdx := 0; rowIdx < inputChunk.GetSize(); rowIdx++ {
// Set the values for each row.
for colIdx := 0; colIdx < columnCount; colIdx++ {
if values[colIdx], err = inputChunk.GetValue(colIdx, rowIdx); err != nil {
Expand All @@ -78,7 +84,7 @@ func scalar_udf_callback(function_info C.duckdb_function_info, input C.duckdb_da

// Execute the function and write the result to the output vector.
var val any
if val, err = function.ExecuteRow(values); err != nil {
if val, err = executor.RowExecutor(values); err != nil {
break
}
if err = outputChunk.SetValue(0, rowIdx, val); err != nil {
Expand Down Expand Up @@ -145,6 +151,10 @@ func createScalarFunc(name string, f ScalarFunc) (C.duckdb_scalar_function, erro
if f == nil {
return nil, errScalarUDFIsNil
}
if f.Executor().RowExecutor == nil {
return nil, errScalarUDFNoExecutor
}

function := C.duckdb_create_scalar_function()

// Set the name.
Expand Down Expand Up @@ -183,16 +193,16 @@ func createScalarFunc(name string, f ScalarFunc) (C.duckdb_scalar_function, erro
// RegisterScalarUDF registers a scalar user-defined function.
// The function takes ownership of f, so you must pass it as a pointer.
func RegisterScalarUDF(c *sql.Conn, name string, f ScalarFunc) error {
scalarFunc, err := createScalarFunc(name, f)
function, err := createScalarFunc(name, f)
if err != nil {
return getError(errAPI, err)
}

// Register the function on the underlying driver connection exposed by c.Raw.
err = c.Raw(func(driverConn any) error {
con := driverConn.(*conn)
state := C.duckdb_register_scalar_function(con.duckdbCon, scalarFunc)
C.duckdb_destroy_scalar_function(&scalarFunc)
state := C.duckdb_register_scalar_function(con.duckdbCon, function)
C.duckdb_destroy_scalar_function(&function)
if state == C.DuckDBError {
return getError(errAPI, errScalarUDFCreate)
}
Expand All @@ -208,15 +218,15 @@ func RegisterScalarUDFSet(c *sql.Conn, name string, functions ...ScalarFunc) err

// Create each function and add it to the set.
for i, f := range functions {
scalarFunction, err := createScalarFunc(name, f)
function, err := createScalarFunc(name, f)
if err != nil {
C.duckdb_destroy_scalar_function(&scalarFunction)
C.duckdb_destroy_scalar_function(&function)
C.duckdb_destroy_scalar_function_set(&set)
return getError(errAPI, err)
}

state := C.duckdb_add_scalar_function_to_set(set, scalarFunction)
C.duckdb_destroy_scalar_function(&scalarFunction)
state := C.duckdb_add_scalar_function_to_set(set, function)
C.duckdb_destroy_scalar_function(&function)
if state == C.DuckDBError {
C.duckdb_destroy_scalar_function_set(&set)
return getError(errAPI, addIndexToError(errScalarUDFAddToSet, i))
Expand Down
90 changes: 76 additions & 14 deletions scalar_udf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,20 @@ func (*simpleSUDF) Config() ScalarFuncConfig {
}
}

func (*simpleSUDF) ExecuteRow(args []driver.Value) (any, error) {
func simpleSum(args []driver.Value) (any, error) {
if args[0] == nil || args[1] == nil {
return nil, nil
}
val := args[0].(int32) + args[1].(int32)
return val, nil
}

func (*simpleSUDF) Executor() ScalarFuncExecutor {
return ScalarFuncExecutor{
RowExecutor: simpleSum,
}
}

func TestSimpleScalarUDF(t *testing.T) {
db, err := sql.Open("duckdb", "")
require.NoError(t, err)
Expand Down Expand Up @@ -70,10 +76,16 @@ func (*typesSUDF) Config() ScalarFuncConfig {
}
}

func (*typesSUDF) ExecuteRow(args []driver.Value) (any, error) {
func identity(args []driver.Value) (any, error) {
return args[0], nil
}

func (*typesSUDF) Executor() ScalarFuncExecutor {
return ScalarFuncExecutor{
RowExecutor: identity,
}
}

func TestAllTypesScalarUDF(t *testing.T) {
typeInfos := getTypeInfos(t, false)
for _, info := range typeInfos {
Expand Down Expand Up @@ -145,7 +157,7 @@ func (*variadicSUDF) Config() ScalarFuncConfig {
}
}

func (*variadicSUDF) ExecuteRow(args []driver.Value) (any, error) {
func variadicSum(args []driver.Value) (any, error) {
sum := int32(0)
for _, val := range args {
if val == nil {
Expand All @@ -156,6 +168,12 @@ func (*variadicSUDF) ExecuteRow(args []driver.Value) (any, error) {
return sum, nil
}

func (*variadicSUDF) Executor() ScalarFuncExecutor {
return ScalarFuncExecutor{
RowExecutor: variadicSum,
}
}

func TestVariadicScalarUDF(t *testing.T) {
db, err := sql.Open("duckdb", "")
require.NoError(t, err)
Expand Down Expand Up @@ -210,7 +228,7 @@ func (*anyTypeSUDF) Config() ScalarFuncConfig {
}
}

func (*anyTypeSUDF) ExecuteRow(args []driver.Value) (any, error) {
func nilCount(args []driver.Value) (any, error) {
count := int32(0)
for _, val := range args {
if val == nil {
Expand All @@ -220,6 +238,12 @@ func (*anyTypeSUDF) ExecuteRow(args []driver.Value) (any, error) {
return count, nil
}

func (*anyTypeSUDF) Executor() ScalarFuncExecutor {
return ScalarFuncExecutor{
RowExecutor: nilCount,
}
}

func TestANYScalarUDF(t *testing.T) {
db, err := sql.Open("duckdb", "")
require.NoError(t, err)
Expand Down Expand Up @@ -259,6 +283,19 @@ func TestANYScalarUDF(t *testing.T) {
require.NoError(t, db.Close())
}

type errExecutorSUDF struct{}

func (*errExecutorSUDF) Config() ScalarFuncConfig {
scalarUDF := simpleSUDF{}
return scalarUDF.Config()
}

func (*errExecutorSUDF) Executor() ScalarFuncExecutor {
return ScalarFuncExecutor{
RowExecutor: nil,
}
}

type errInputSUDF struct{}

func (*errInputSUDF) Config() ScalarFuncConfig {
Expand All @@ -267,10 +304,16 @@ func (*errInputSUDF) Config() ScalarFuncConfig {
}
}

func (*errInputSUDF) ExecuteRow([]driver.Value) (any, error) {
func constantNil([]driver.Value) (any, error) {
return nil, nil
}

func (*errInputSUDF) Executor() ScalarFuncExecutor {
return ScalarFuncExecutor{
RowExecutor: constantNil,
}
}

type errEmptyInputSUDF struct{}

func (*errEmptyInputSUDF) Config() ScalarFuncConfig {
Expand All @@ -280,8 +323,10 @@ func (*errEmptyInputSUDF) Config() ScalarFuncConfig {
}
}

func (*errEmptyInputSUDF) ExecuteRow([]driver.Value) (any, error) {
return nil, nil
func (*errEmptyInputSUDF) Executor() ScalarFuncExecutor {
return ScalarFuncExecutor{
RowExecutor: constantNil,
}
}

type errInputNilSUDF struct{}
Expand All @@ -293,8 +338,10 @@ func (*errInputNilSUDF) Config() ScalarFuncConfig {
}
}

func (*errInputNilSUDF) ExecuteRow([]driver.Value) (any, error) {
return nil, nil
func (*errInputNilSUDF) Executor() ScalarFuncExecutor {
return ScalarFuncExecutor{
RowExecutor: constantNil,
}
}

type errResultNilSUDF struct{}
Expand All @@ -306,8 +353,10 @@ func (*errResultNilSUDF) Config() ScalarFuncConfig {
}
}

func (*errResultNilSUDF) ExecuteRow([]driver.Value) (any, error) {
return nil, nil
func (*errResultNilSUDF) Executor() ScalarFuncExecutor {
return ScalarFuncExecutor{
RowExecutor: constantNil,
}
}

type errResultAnySUDF struct{}
Expand All @@ -324,8 +373,10 @@ func (*errResultAnySUDF) Config() ScalarFuncConfig {
}
}

func (*errResultAnySUDF) ExecuteRow([]driver.Value) (any, error) {
return nil, nil
func (*errResultAnySUDF) Executor() ScalarFuncExecutor {
return ScalarFuncExecutor{
RowExecutor: constantNil,
}
}

type errExecSUDF struct{}
Expand All @@ -335,10 +386,16 @@ func (*errExecSUDF) Config() ScalarFuncConfig {
return scalarUDF.Config()
}

func (*errExecSUDF) ExecuteRow([]driver.Value) (any, error) {
func constantError([]driver.Value) (any, error) {
return nil, errors.New("test invalid execution")
}

func (*errExecSUDF) Executor() ScalarFuncExecutor {
return ScalarFuncExecutor{
RowExecutor: constantError,
}
}

func TestScalarUDFErrors(t *testing.T) {
t.Parallel()

Expand All @@ -356,6 +413,11 @@ func TestScalarUDFErrors(t *testing.T) {
err = RegisterScalarUDF(c, "", emptyNameUDF)
testError(t, err, errAPI.Error(), errScalarUDFCreate.Error(), errScalarUDFNoName.Error())

// Invalid executor.
var errExecutorUDF *errExecutorSUDF
err = RegisterScalarUDF(c, "err_executor_is_nil", errExecutorUDF)
testError(t, err, errAPI.Error(), errScalarUDFCreate.Error(), errScalarUDFNoExecutor.Error())

// Invalid input parameters.

var errInputUDF *errInputSUDF
Expand Down

0 comments on commit 354897a

Please sign in to comment.