diff --git a/Makefile b/Makefile index b6c4e07c..4467333f 100644 --- a/Makefile +++ b/Makefile @@ -9,6 +9,7 @@ install: examples: go run examples/simple/main.go go run examples/appender/main.go + go run examples/scalar_udf/main.go .PHONY: test test: diff --git a/data_chunk.go b/data_chunk.go index 65fcd058..b46d8a2c 100644 --- a/data_chunk.go +++ b/data_chunk.go @@ -47,6 +47,9 @@ func (chunk *DataChunk) GetValue(colIdx int, rowIdx int) (any, error) { return nil, getError(errAPI, columnCountError(colIdx, len(chunk.columns))) } column := &chunk.columns[colIdx] + if column.isSQLNull { + return nil, nil + } return column.getFn(column, C.idx_t(rowIdx)), nil } @@ -57,7 +60,11 @@ func (chunk *DataChunk) SetValue(colIdx int, rowIdx int, val any) error { if colIdx >= len(chunk.columns) { return getError(errAPI, columnCountError(colIdx, len(chunk.columns))) } + column := &chunk.columns[colIdx] + if column.isSQLNull { + return getError(errAPI, errSetSQLNULLValue) + } // Set the value. return column.setFn(column, C.idx_t(rowIdx), val) @@ -127,6 +134,23 @@ func (chunk *DataChunk) initFromDuckDataChunk(data C.duckdb_data_chunk, writable return err } +func (chunk *DataChunk) initFromDuckVector(duckdbVector C.duckdb_vector, writable bool) error { + columnCount := 1 + chunk.columns = make([]vector, columnCount) + + // Initialize the callback functions to read and write values. + logicalType := C.duckdb_vector_get_column_type(duckdbVector) + err := chunk.columns[0].init(logicalType, 0) + C.duckdb_destroy_logical_type(&logicalType) + if err != nil { + return err + } + + // Initialize the vector and its child vectors. 
+ chunk.columns[0].initVectors(duckdbVector, writable) + return nil +} + func (chunk *DataChunk) close() { C.duckdb_destroy_data_chunk(&chunk.data) } diff --git a/errors.go b/errors.go index 1de2e763..02b88a2c 100644 --- a/errors.go +++ b/errors.go @@ -82,19 +82,34 @@ var ( errUnsupportedMapKeyType = errors.New("MAP key type not supported") - errAppenderInvalidCon = errors.New("could not create appender: not a DuckDB driver connection") - errAppenderClosedCon = errors.New("could not create appender: appender creation on a closed connection") - errAppenderCreation = errors.New("could not create appender") - errAppenderDoubleClose = errors.New("could not close appender: already closed") + errAppenderCreation = errors.New("could not create appender") + errAppenderInvalidCon = fmt.Errorf("%w: not a DuckDB driver connection", errAppenderCreation) + errAppenderClosedCon = fmt.Errorf("%w: appender creation on a closed connection", errAppenderCreation) + + errAppenderClose = errors.New("could not close appender") + errAppenderDoubleClose = fmt.Errorf("%w: already closed", errAppenderClose) + errAppenderAppendRow = errors.New("could not append row") - errAppenderAppendAfterClose = errors.New("could not append row: appender already closed") - errAppenderClose = errors.New("could not close appender") - errAppenderFlush = errors.New("could not flush appender") + errAppenderAppendAfterClose = fmt.Errorf("%w: appender already closed", errAppenderAppendRow) + + errAppenderFlush = errors.New("could not flush appender") errEmptyName = errors.New("empty name") errInvalidDecimalWidth = fmt.Errorf("the DECIMAL with must be between 1 and %d", MAX_DECIMAL_WIDTH) errInvalidDecimalScale = errors.New("the DECIMAL scale must be less than or equal to the width") + errScalarUDFCreate = errors.New("could not create scalar UDF") + errScalarUDFNoName = fmt.Errorf("%w: missing name", errScalarUDFCreate) + errScalarUDFIsNil = fmt.Errorf("%w: function is nil", errScalarUDFCreate) + 
errScalarUDFNoExecutor = fmt.Errorf("%w: executor is nil", errScalarUDFCreate) + errScalarUDFInputTypeIsNil = fmt.Errorf("%w: input type is nil", errScalarUDFCreate) + errScalarUDFResultTypeIsNil = fmt.Errorf("%w: result type is nil", errScalarUDFCreate) + errScalarUDFResultTypeIsANY = fmt.Errorf("%w: result type is ANY, which is not supported", errScalarUDFCreate) + errScalarUDFCreateSet = fmt.Errorf("could not create scalar UDF set") + errScalarUDFAddToSet = fmt.Errorf("%w: could not add the function to the set", errScalarUDFCreateSet) + + errSetSQLNULLValue = errors.New("cannot write to a NULL column") + // Errors not covered in tests. errConnect = errors.New("could not connect to database") errCreateConfig = errors.New("could not create config for database") diff --git a/examples/scalar_udf/main.go b/examples/scalar_udf/main.go new file mode 100644 index 00000000..0afe79e7 --- /dev/null +++ b/examples/scalar_udf/main.go @@ -0,0 +1,166 @@ +package main + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + "fmt" + + "github.com/marcboeker/go-duckdb" +) + +// Overload my_length with two user-defined scalar functions. +// varcharLen takes a VARCHAR as its input parameter. +// listLen takes a LIST(ANY) as its input parameter. 
+ +type ( + varcharLen struct{} + listLen struct{} +) + +func varcharLenFn(values []driver.Value) (any, error) { + str := values[0].(string) + return int32(len(str)), nil +} + +func (*varcharLen) Config() duckdb.ScalarFuncConfig { + inputTypeInfo, err := duckdb.NewTypeInfo(duckdb.TYPE_VARCHAR) + check(err) + resultTypeInfo, err := duckdb.NewTypeInfo(duckdb.TYPE_INTEGER) + check(err) + + return duckdb.ScalarFuncConfig{ + InputTypeInfos: []duckdb.TypeInfo{inputTypeInfo}, + ResultTypeInfo: resultTypeInfo, + } +} + +func (*varcharLen) Executor() duckdb.ScalarFuncExecutor { + return duckdb.ScalarFuncExecutor{RowExecutor: varcharLenFn} +} + +func listLenFn(values []driver.Value) (any, error) { + list := values[0].([]any) + return int32(len(list)), nil +} + +func (*listLen) Config() duckdb.ScalarFuncConfig { + anyTypeInfo, err := duckdb.NewTypeInfo(duckdb.TYPE_ANY) + check(err) + inputTypeInfo, err := duckdb.NewListInfo(anyTypeInfo) + check(err) + resultTypeInfo, err := duckdb.NewTypeInfo(duckdb.TYPE_INTEGER) + check(err) + + return duckdb.ScalarFuncConfig{ + InputTypeInfos: []duckdb.TypeInfo{inputTypeInfo}, + ResultTypeInfo: resultTypeInfo, + } +} + +func (*listLen) Executor() duckdb.ScalarFuncExecutor { + return duckdb.ScalarFuncExecutor{RowExecutor: listLenFn} +} + +func myLengthScalarUDFSet() { + db, err := sql.Open("duckdb", "?access_mode=READ_WRITE") + check(err) + + c, err := db.Conn(context.Background()) + check(err) + + var varcharUDF *varcharLen + var listUDF *listLen + err = duckdb.RegisterScalarUDFSet(c, "my_length", varcharUDF, listUDF) + check(err) + + var length int32 + row := db.QueryRow(`SELECT my_length('hello world') AS sum`) + check(row.Scan(&length)) + if length != 11 { + panic(errors.New("incorrect length")) + } + + row = db.QueryRow(`SELECT my_length([1, 2, NULL, 4, NULL]) AS sum`) + check(row.Scan(&length)) + if length != 5 { + panic(errors.New("incorrect length")) + } + + check(c.Close()) + check(db.Close()) +} + +// wrapSum takes a VARCHAR 
prefix, a VARCHAR suffix, and a variadic number of integer values. +// It computes the sum of the integer values. Then, it emits a VARCHAR by concatenating prefix || sum || suffix. + +type wrapSum struct{} + +func wrapSumFn(values []driver.Value) (any, error) { + sum := int32(0) + for i := 2; i < len(values); i++ { + sum += values[i].(int32) + } + strSum := fmt.Sprintf("%d", sum) + prefix := values[0].(string) + suffix := values[1].(string) + return prefix + strSum + suffix, nil +} + +func (*wrapSum) Config() duckdb.ScalarFuncConfig { + varcharTypeInfo, err := duckdb.NewTypeInfo(duckdb.TYPE_VARCHAR) + check(err) + intTypeInfo, err := duckdb.NewTypeInfo(duckdb.TYPE_INTEGER) + check(err) + + return duckdb.ScalarFuncConfig{ + InputTypeInfos: []duckdb.TypeInfo{varcharTypeInfo, varcharTypeInfo}, + ResultTypeInfo: varcharTypeInfo, + VariadicTypeInfo: intTypeInfo, + } +} + +func (*wrapSum) Executor() duckdb.ScalarFuncExecutor { + return duckdb.ScalarFuncExecutor{RowExecutor: wrapSumFn} +} + +func wrapSumScalarUDF() { + db, err := sql.Open("duckdb", "?access_mode=READ_WRITE") + check(err) + + c, err := db.Conn(context.Background()) + check(err) + + var wrapSumUDF *wrapSum + err = duckdb.RegisterScalarUDF(c, "wrap_sum", wrapSumUDF) + check(err) + + var res string + row := db.QueryRow(`SELECT wrap_sum('hello', ' world', 1, 2, 3, 4) AS sum`) + check(row.Scan(&res)) + if res != "hello10 world" { + panic(errors.New("incorrect result")) + } + + row = db.QueryRow(`SELECT wrap_sum('hello', ' world') AS sum`) + check(row.Scan(&res)) + if res != "hello0 world" { + panic(errors.New("incorrect result")) + } + + check(c.Close()) + check(db.Close()) +} + +func main() { + myLengthScalarUDFSet() + wrapSumScalarUDF() +} + +func check(args ...interface{}) { + err := args[len(args)-1] + if err != nil { + panic(err) + } +} diff --git a/scalar_udf.go b/scalar_udf.go new file mode 100644 index 00000000..2f055368 --- /dev/null +++ b/scalar_udf.go @@ -0,0 +1,291 @@ +package duckdb + +/* +#include 
+ +void scalar_udf_callback(duckdb_function_info, duckdb_data_chunk, duckdb_vector); +void scalar_udf_delete_callback(void *); + +typedef void (*scalar_udf_callback_t)(duckdb_function_info, duckdb_data_chunk, duckdb_vector); +*/ +import "C" + +import ( + "database/sql" + "database/sql/driver" + "runtime" + "runtime/cgo" + "unsafe" +) + +// ScalarFuncConfig contains the fields to configure a user-defined scalar function. +type ScalarFuncConfig struct { + // InputTypeInfos contains Type information for each input parameter of the scalar function. + InputTypeInfos []TypeInfo + // ResultTypeInfo holds the Type information of the scalar function's result type. + ResultTypeInfo TypeInfo + + // VariadicTypeInfo configures the number of input parameters. + // If this field is nil, then the input parameters match InputTypeInfos. + // Otherwise, the scalar function's input parameters are set to variadic, allowing any number of input parameters. + // The Type of the first len(InputTypeInfos) parameters is configured by InputTypeInfos, and all + // remaining parameters must match the variadic Type. To configure different variadic parameter types, + // you must set the VariadicTypeInfo's Type to TYPE_ANY. + VariadicTypeInfo TypeInfo + // Volatile sets the stability of the scalar function to volatile, if true. + // Volatile scalar functions might create a different result per row. + // E.g., random() is a volatile scalar function. + Volatile bool + // SpecialNullHandling disables the default NULL handling of scalar functions, if true. + // The default NULL handling is: NULL in, NULL out. I.e., if any input parameter is NULL, then the result is NULL. + SpecialNullHandling bool +} + +// ScalarFuncExecutor contains the callback function to execute a user-defined scalar function. +// Currently, its only field is a row-based executor. +type ScalarFuncExecutor struct { + // RowExecutor accepts a row-based execution function. 
+ // []driver.Value contains the row values, and it returns the row execution result, or error. + RowExecutor func(values []driver.Value) (any, error) +} + +// ScalarFunc is the user-defined scalar function interface. +// Any scalar function must implement a Config function, and an Executor function. +type ScalarFunc interface { + // Config returns ScalarFuncConfig to configure the scalar function. + Config() ScalarFuncConfig + // Executor returns ScalarFuncExecutor to execute the scalar function. + Executor() ScalarFuncExecutor +} + +// RegisterScalarUDF registers a user-defined scalar function. +// *sql.Conn is the SQL connection on which to register the scalar function. +// name is the function name, and f is the scalar function's interface ScalarFunc. +// RegisterScalarUDF takes ownership of f, so you must pass it as a pointer. +func RegisterScalarUDF(c *sql.Conn, name string, f ScalarFunc) error { + function, err := createScalarFunc(name, f) + if err != nil { + return getError(errAPI, err) + } + + // Register the function on the underlying driver connection exposed by c.Raw. + err = c.Raw(func(driverConn any) error { + con := driverConn.(*conn) + state := C.duckdb_register_scalar_function(con.duckdbCon, function) + C.duckdb_destroy_scalar_function(&function) + if state == C.DuckDBError { + return getError(errAPI, errScalarUDFCreate) + } + return nil + }) + return err +} + +// RegisterScalarUDFSet registers a set of user-defined scalar functions with the same name. +// This enables overloading of scalar functions. +// E.g., the function my_length() can have implementations like my_length(LIST(ANY)) and my_length(VARCHAR). +// *sql.Conn is the SQL connection on which to register the scalar function set. +// name is the function name of each function in the set. +// functions contains all ScalarFunc functions of the scalar function set. 
+func RegisterScalarUDFSet(c *sql.Conn, name string, functions ...ScalarFunc) error { + cName := C.CString(name) + set := C.duckdb_create_scalar_function_set(cName) + C.duckdb_free(unsafe.Pointer(cName)) + + // Create each function and add it to the set. + for i, f := range functions { + function, err := createScalarFunc(name, f) + if err != nil { + C.duckdb_destroy_scalar_function(&function) + C.duckdb_destroy_scalar_function_set(&set) + return getError(errAPI, err) + } + + state := C.duckdb_add_scalar_function_to_set(set, function) + C.duckdb_destroy_scalar_function(&function) + if state == C.DuckDBError { + C.duckdb_destroy_scalar_function_set(&set) + return getError(errAPI, addIndexToError(errScalarUDFAddToSet, i)) + } + } + + // Register the function set on the underlying driver connection exposed by c.Raw. + err := c.Raw(func(driverConn any) error { + con := driverConn.(*conn) + state := C.duckdb_register_scalar_function_set(con.duckdbCon, set) + C.duckdb_destroy_scalar_function_set(&set) + if state == C.DuckDBError { + return getError(errAPI, errScalarUDFCreateSet) + } + return nil + }) + return err +} + +//export scalar_udf_callback +func scalar_udf_callback(function_info C.duckdb_function_info, input C.duckdb_data_chunk, output C.duckdb_vector) { + extraInfo := C.duckdb_scalar_function_get_extra_info(function_info) + + // extraInfo is a void* pointer to our pinned ScalarFunc f. + h := *(*cgo.Handle)(unsafe.Pointer(extraInfo)) + function := h.Value().(pinnedValue[ScalarFunc]).value + + // Initialize the input chunk. + var inputChunk DataChunk + if err := inputChunk.initFromDuckDataChunk(input, false); err != nil { + setFuncError(function_info, getError(errAPI, err).Error()) + return + } + + // Initialize the output chunk. 
+ var outputChunk DataChunk + if err := outputChunk.initFromDuckVector(output, true); err != nil { + setFuncError(function_info, getError(errAPI, err).Error()) + return + } + + executor := function.Executor() + nullInNullOut := !function.Config().SpecialNullHandling + values := make([]driver.Value, len(inputChunk.columns)) + columnCount := len(values) + rowCount := inputChunk.GetSize() + + // Execute the user-defined scalar function for each row. + var err error + for rowIdx := 0; rowIdx < rowCount; rowIdx++ { + nullRow := false + + // Get each column value. + for colIdx := 0; colIdx < columnCount; colIdx++ { + if values[colIdx], err = inputChunk.GetValue(colIdx, rowIdx); err != nil { + setFuncError(function_info, getError(errAPI, err).Error()) + return + } + + // NULL handling. + if nullInNullOut && values[colIdx] == nil { + if err = outputChunk.SetValue(0, rowIdx, nil); err != nil { + setFuncError(function_info, getError(errAPI, err).Error()) + return + } + nullRow = true + break + } + } + if nullRow { + continue + } + + // Execute the function. + var val any + if val, err = executor.RowExecutor(values); err != nil { + setFuncError(function_info, getError(errAPI, err).Error()) + return + } + + // Write the result to the output chunk. + if err = outputChunk.SetValue(0, rowIdx, val); err != nil { + setFuncError(function_info, getError(errAPI, err).Error()) + return + } + } +} + +//export scalar_udf_delete_callback +func scalar_udf_delete_callback(info unsafe.Pointer) { + udf_delete_callback(info) +} + +func registerInputParams(config ScalarFuncConfig, f C.duckdb_scalar_function) error { + // Set variadic input parameters. + if config.VariadicTypeInfo != nil { + t := config.VariadicTypeInfo.logicalType() + C.duckdb_scalar_function_set_varargs(f, t) + C.duckdb_destroy_logical_type(&t) + } + + // Early-out, if the function does not take any (non-variadic) parameters. 
+ if config.InputTypeInfos == nil { + return nil + } + if len(config.InputTypeInfos) == 0 { + return nil + } + + // Set non-variadic input parameters. + for i, info := range config.InputTypeInfos { + if info == nil { + return addIndexToError(errScalarUDFInputTypeIsNil, i) + } + t := info.logicalType() + C.duckdb_scalar_function_add_parameter(f, t) + C.duckdb_destroy_logical_type(&t) + } + return nil +} + +func registerResultParams(config ScalarFuncConfig, f C.duckdb_scalar_function) error { + if config.ResultTypeInfo == nil { + return errScalarUDFResultTypeIsNil + } + if config.ResultTypeInfo.InternalType() == TYPE_ANY { + return errScalarUDFResultTypeIsANY + } + t := config.ResultTypeInfo.logicalType() + C.duckdb_scalar_function_set_return_type(f, t) + C.duckdb_destroy_logical_type(&t) + return nil +} + +func createScalarFunc(name string, f ScalarFunc) (C.duckdb_scalar_function, error) { + if name == "" { + return nil, errScalarUDFNoName + } + if f == nil { + return nil, errScalarUDFIsNil + } + if f.Executor().RowExecutor == nil { + return nil, errScalarUDFNoExecutor + } + + function := C.duckdb_create_scalar_function() + + // Set the name. + cName := C.CString(name) + C.duckdb_scalar_function_set_name(function, cName) + C.duckdb_free(unsafe.Pointer(cName)) + + // Configure the scalar function. + config := f.Config() + if err := registerInputParams(config, function); err != nil { + return nil, err + } + if err := registerResultParams(config, function); err != nil { + return nil, err + } + if config.SpecialNullHandling { + C.duckdb_scalar_function_set_special_handling(function) + } + if config.Volatile { + C.duckdb_scalar_function_set_volatile(function) + } + + // Set the function callback. + C.duckdb_scalar_function_set_function(function, C.scalar_udf_callback_t(C.scalar_udf_callback)) + + // Pin the ScalarFunc f. 
+ value := pinnedValue[ScalarFunc]{ + pinner: &runtime.Pinner{}, + value: f, + } + h := cgo.NewHandle(value) + value.pinner.Pin(&h) + + // Set the execution data, which is the ScalarFunc f. + C.duckdb_scalar_function_set_extra_info( + function, + unsafe.Pointer(&h), + C.duckdb_delete_callback_t(C.scalar_udf_delete_callback)) + + return function, nil +} diff --git a/scalar_udf_test.go b/scalar_udf_test.go new file mode 100644 index 00000000..feb939fb --- /dev/null +++ b/scalar_udf_test.go @@ -0,0 +1,436 @@ +package duckdb + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +var currentInfo TypeInfo + +type ( + simpleSUDF struct{} + constantSUDF struct{} + otherConstantSUDF struct{} + typesSUDF struct{} + variadicSUDF struct{} + anyTypeSUDF struct{} + errExecutorSUDF struct{} + errInputNilSUDF struct{} + errResultNilSUDF struct{} + errResultAnySUDF struct{} + errExecSUDF struct{} +) + +func simpleSum(values []driver.Value) (any, error) { + if values[0] == nil || values[1] == nil { + return nil, nil + } + val := values[0].(int32) + values[1].(int32) + return val, nil +} + +func constantOne([]driver.Value) (any, error) { + return int32(1), nil +} + +func identity(values []driver.Value) (any, error) { + return values[0], nil +} + +func variadicSum(values []driver.Value) (any, error) { + sum := int32(0) + for _, val := range values { + if val == nil { + return nil, nil + } + sum += val.(int32) + } + return sum, nil +} + +func nilCount(values []driver.Value) (any, error) { + count := int32(0) + for _, val := range values { + if val == nil { + count++ + } + } + return count, nil +} + +func constantError([]driver.Value) (any, error) { + return nil, errors.New("test invalid execution") +} + +func (*simpleSUDF) Config() ScalarFuncConfig { + return ScalarFuncConfig{[]TypeInfo{currentInfo, currentInfo}, currentInfo, nil, false, false} +} + +func (*simpleSUDF) Executor() ScalarFuncExecutor 
{ + return ScalarFuncExecutor{simpleSum} +} + +func (*constantSUDF) Config() ScalarFuncConfig { + return ScalarFuncConfig{ResultTypeInfo: currentInfo} +} + +func (*constantSUDF) Executor() ScalarFuncExecutor { + return ScalarFuncExecutor{constantOne} +} + +func (*otherConstantSUDF) Config() ScalarFuncConfig { + return ScalarFuncConfig{[]TypeInfo{}, currentInfo, nil, false, false} +} + +func (*otherConstantSUDF) Executor() ScalarFuncExecutor { + return ScalarFuncExecutor{constantOne} +} + +func (*typesSUDF) Config() ScalarFuncConfig { + return ScalarFuncConfig{[]TypeInfo{currentInfo}, currentInfo, nil, false, false} +} + +func (*typesSUDF) Executor() ScalarFuncExecutor { + return ScalarFuncExecutor{identity} +} + +func (*variadicSUDF) Config() ScalarFuncConfig { + return ScalarFuncConfig{nil, currentInfo, currentInfo, true, true} +} + +func (*variadicSUDF) Executor() ScalarFuncExecutor { + return ScalarFuncExecutor{variadicSum} +} + +func (*anyTypeSUDF) Config() ScalarFuncConfig { + info, err := NewTypeInfo(TYPE_ANY) + if err != nil { + panic(err) + } + + return ScalarFuncConfig{nil, currentInfo, info, false, true} +} + +func (*anyTypeSUDF) Executor() ScalarFuncExecutor { + return ScalarFuncExecutor{nilCount} +} + +func (*errExecutorSUDF) Config() ScalarFuncConfig { + scalarUDF := simpleSUDF{} + return scalarUDF.Config() +} + +func (*errExecutorSUDF) Executor() ScalarFuncExecutor { + return ScalarFuncExecutor{nil} +} + +func (*errInputNilSUDF) Config() ScalarFuncConfig { + return ScalarFuncConfig{[]TypeInfo{nil}, currentInfo, nil, false, false} +} + +func (*errInputNilSUDF) Executor() ScalarFuncExecutor { + return ScalarFuncExecutor{constantOne} +} + +func (*errResultNilSUDF) Config() ScalarFuncConfig { + return ScalarFuncConfig{[]TypeInfo{currentInfo}, nil, nil, false, false} +} + +func (*errResultNilSUDF) Executor() ScalarFuncExecutor { + return ScalarFuncExecutor{constantOne} +} + +func (*errResultAnySUDF) Config() ScalarFuncConfig { + info, err := 
NewTypeInfo(TYPE_ANY) + if err != nil { + panic(err) + } + + return ScalarFuncConfig{[]TypeInfo{currentInfo}, info, nil, false, false} +} + +func (*errResultAnySUDF) Executor() ScalarFuncExecutor { + return ScalarFuncExecutor{constantOne} +} + +func (*errExecSUDF) Config() ScalarFuncConfig { + scalarUDF := simpleSUDF{} + return scalarUDF.Config() +} + +func (*errExecSUDF) Executor() ScalarFuncExecutor { + return ScalarFuncExecutor{constantError} +} + +func TestSimpleScalarUDF(t *testing.T) { + db, err := sql.Open("duckdb", "") + require.NoError(t, err) + + c, err := db.Conn(context.Background()) + require.NoError(t, err) + + currentInfo, err = NewTypeInfo(TYPE_INTEGER) + require.NoError(t, err) + + var udf *simpleSUDF + err = RegisterScalarUDF(c, "my_sum", udf) + require.NoError(t, err) + + var sum *int + row := db.QueryRow(`SELECT my_sum(10, 42) AS sum`) + require.NoError(t, row.Scan(&sum)) + require.Equal(t, 52, *sum) + + row = db.QueryRow(`SELECT my_sum(NULL, 42) AS sum`) + require.NoError(t, row.Scan(&sum)) + require.Equal(t, (*int)(nil), sum) + + row = db.QueryRow(`SELECT my_sum(42, NULL) AS sum`) + require.NoError(t, row.Scan(&sum)) + require.Equal(t, (*int)(nil), sum) + + require.NoError(t, c.Close()) + require.NoError(t, db.Close()) +} + +func TestConstantScalarUDF(t *testing.T) { + db, err := sql.Open("duckdb", "") + require.NoError(t, err) + + c, err := db.Conn(context.Background()) + require.NoError(t, err) + + currentInfo, err = NewTypeInfo(TYPE_INTEGER) + require.NoError(t, err) + + var udf *constantSUDF + err = RegisterScalarUDF(c, "constant_one", udf) + require.NoError(t, err) + + var otherUDF *otherConstantSUDF + err = RegisterScalarUDF(c, "other_constant_one", otherUDF) + require.NoError(t, err) + + var one int + row := db.QueryRow(`SELECT constant_one() AS one`) + require.NoError(t, row.Scan(&one)) + require.Equal(t, 1, one) + + row = db.QueryRow(`SELECT other_constant_one() AS one`) + require.NoError(t, row.Scan(&one)) + require.Equal(t, 1, one) 
+ + require.NoError(t, c.Close()) + require.NoError(t, db.Close()) +} + +func TestAllTypesScalarUDF(t *testing.T) { + typeInfos := getTypeInfos(t, false) + for _, info := range typeInfos { + currentInfo = info.TypeInfo + + db, err := sql.Open("duckdb", "") + require.NoError(t, err) + + c, err := db.Conn(context.Background()) + require.NoError(t, err) + + _, err = c.ExecContext(context.Background(), `CREATE TYPE greeting AS ENUM ('hello', 'world')`) + require.NoError(t, err) + + var udf *typesSUDF + err = RegisterScalarUDF(c, "my_identity", udf) + require.NoError(t, err) + + var res string + row := db.QueryRow(fmt.Sprintf(`SELECT my_identity(%s)::VARCHAR AS res`, info.input)) + require.NoError(t, row.Scan(&res)) + if info.TypeInfo.InternalType() != TYPE_UUID { + require.Equal(t, info.output, res, fmt.Sprintf(`output does not match expected output, input: %s`, info.input)) + } else { + require.NotEqual(t, "", res, "uuid empty") + } + + require.NoError(t, c.Close()) + require.NoError(t, db.Close()) + } +} + +func TestScalarUDFSet(t *testing.T) { + db, err := sql.Open("duckdb", "") + require.NoError(t, err) + + c, err := db.Conn(context.Background()) + require.NoError(t, err) + + currentInfo, err = NewTypeInfo(TYPE_INTEGER) + require.NoError(t, err) + + var udf1 *simpleSUDF + var udf2 *typesSUDF + err = RegisterScalarUDFSet(c, "my_addition", udf1, udf2) + require.NoError(t, err) + + var sum int + row := db.QueryRow(`SELECT my_addition(10, 42) AS sum`) + require.NoError(t, row.Scan(&sum)) + require.Equal(t, 52, sum) + + row = db.QueryRow(`SELECT my_addition(42) AS sum`) + require.NoError(t, row.Scan(&sum)) + require.Equal(t, 42, sum) + + require.NoError(t, c.Close()) + require.NoError(t, db.Close()) +} + +func TestVariadicScalarUDF(t *testing.T) { + db, err := sql.Open("duckdb", "") + require.NoError(t, err) + + c, err := db.Conn(context.Background()) + require.NoError(t, err) + + currentInfo, err = NewTypeInfo(TYPE_INTEGER) + require.NoError(t, err) + + var udf 
*variadicSUDF + err = RegisterScalarUDF(c, "my_variadic_sum", udf) + require.NoError(t, err) + + var sum *int + row := db.QueryRow(`SELECT my_variadic_sum(10, NULL, NULL) AS msg`) + require.NoError(t, row.Scan(&sum)) + require.Equal(t, (*int)(nil), sum) + + row = db.QueryRow(`SELECT my_variadic_sum(10, 42, 2, 2, 2) AS msg`) + require.NoError(t, row.Scan(&sum)) + require.Equal(t, 58, *sum) + + row = db.QueryRow(`SELECT my_variadic_sum(10) AS msg`) + require.NoError(t, row.Scan(&sum)) + require.Equal(t, 10, *sum) + + row = db.QueryRow(`SELECT my_variadic_sum(NULL) AS msg`) + require.NoError(t, row.Scan(&sum)) + require.Equal(t, (*int)(nil), sum) + + row = db.QueryRow(`SELECT my_variadic_sum() AS msg`) + require.NoError(t, row.Scan(&sum)) + require.Equal(t, 0, *sum) + + require.NoError(t, c.Close()) + require.NoError(t, db.Close()) +} + +func TestANYScalarUDF(t *testing.T) { + db, err := sql.Open("duckdb", "") + require.NoError(t, err) + + c, err := db.Conn(context.Background()) + require.NoError(t, err) + + currentInfo, err = NewTypeInfo(TYPE_INTEGER) + require.NoError(t, err) + + var udf *anyTypeSUDF + err = RegisterScalarUDF(c, "my_null_count", udf) + require.NoError(t, err) + + var count int + row := db.QueryRow(`SELECT my_null_count(10, 'hello', 2, [2], 2) AS msg`) + require.NoError(t, row.Scan(&count)) + require.Equal(t, 0, count) + + row = db.QueryRow(`SELECT my_null_count(10, NULL, NULL, [NULL], {'hello': NULL}) AS msg`) + require.NoError(t, row.Scan(&count)) + require.Equal(t, 2, count) + + row = db.QueryRow(`SELECT my_null_count(10, True) AS msg`) + require.NoError(t, row.Scan(&count)) + require.Equal(t, 0, count) + + row = db.QueryRow(`SELECT my_null_count(NULL) AS msg`) + require.NoError(t, row.Scan(&count)) + require.Equal(t, 1, count) + + row = db.QueryRow(`SELECT my_null_count() AS msg`) + require.NoError(t, row.Scan(&count)) + require.Equal(t, 0, count) + + require.NoError(t, c.Close()) + require.NoError(t, db.Close()) +} + +func TestScalarUDFErrors(t 
*testing.T) { + t.Parallel() + + db, err := sql.Open("duckdb", "") + require.NoError(t, err) + + c, err := db.Conn(context.Background()) + require.NoError(t, err) + + currentInfo, err = NewTypeInfo(TYPE_INTEGER) + require.NoError(t, err) + + // Empty name. + var emptyNameUDF *simpleSUDF + err = RegisterScalarUDF(c, "", emptyNameUDF) + testError(t, err, errAPI.Error(), errScalarUDFCreate.Error(), errScalarUDFNoName.Error()) + + // Invalid executor. + var errExecutorUDF *errExecutorSUDF + err = RegisterScalarUDF(c, "err_executor_is_nil", errExecutorUDF) + testError(t, err, errAPI.Error(), errScalarUDFCreate.Error(), errScalarUDFNoExecutor.Error()) + + // Invalid input parameter. + var errInputNilUDF *errInputNilSUDF + err = RegisterScalarUDF(c, "err_input_type_is_nil", errInputNilUDF) + testError(t, err, errAPI.Error(), errScalarUDFCreate.Error(), errScalarUDFInputTypeIsNil.Error()) + + // Invalid result parameters. + var errResultNil *errResultNilSUDF + err = RegisterScalarUDF(c, "err_result_type_is_nil", errResultNil) + testError(t, err, errAPI.Error(), errScalarUDFCreate.Error(), errScalarUDFResultTypeIsNil.Error()) + var errResultAny *errResultAnySUDF + err = RegisterScalarUDF(c, "err_result_type_is_any", errResultAny) + testError(t, err, errAPI.Error(), errScalarUDFCreate.Error(), errScalarUDFResultTypeIsANY.Error()) + + // Error during execution. + var errExecUDF *errExecSUDF + err = RegisterScalarUDF(c, "err_exec", errExecUDF) + require.NoError(t, err) + row := db.QueryRow(`SELECT err_exec(10, 10) AS res`) + testError(t, row.Err(), errAPI.Error()) + + // Register the same scalar function a second time. + // Since RegisterScalarUDF takes ownership of udf, we are now passing nil. + var udf *simpleSUDF + err = RegisterScalarUDF(c, "my_sum", udf) + require.NoError(t, err) + err = RegisterScalarUDF(c, "my_sum", udf) + testError(t, err, errAPI.Error(), errScalarUDFCreate.Error()) + + // Register a scalar function whose name already exists. 
+ var errDuplicateUDF *simpleSUDF + err = RegisterScalarUDF(c, "my_sum", errDuplicateUDF) + testError(t, err, errAPI.Error(), errScalarUDFCreate.Error()) + + // Register a scalar function that is nil. + err = RegisterScalarUDF(c, "my_sum", nil) + testError(t, err, errAPI.Error(), errScalarUDFIsNil.Error()) + require.NoError(t, c.Close()) + + // Test registering the scalar function on a closed connection. + var errClosedConUDF *simpleSUDF + err = RegisterScalarUDF(c, "closed_con", errClosedConUDF) + require.ErrorContains(t, err, sql.ErrConnDone.Error()) + require.NoError(t, db.Close()) +} diff --git a/type_info.go b/type_info.go index 8259e810..bc2269bc 100644 --- a/type_info.go +++ b/type_info.go @@ -69,9 +69,15 @@ type typeInfo struct { // TypeInfo is an interface for a DuckDB type. type TypeInfo interface { + // InternalType returns the Type. + InternalType() Type logicalType() C.duckdb_logical_type } +func (info *typeInfo) InternalType() Type { + return info.Type +} + // NewTypeInfo returns type information for DuckDB's primitive types. // It returns the TypeInfo, if the Type parameter is a valid primitive type. // Else, it returns nil, and an error. 
diff --git a/type_info_test.go b/type_info_test.go index f8f5a0b6..c6682f7c 100644 --- a/type_info_test.go +++ b/type_info_test.go @@ -6,13 +6,52 @@ import ( "github.com/stretchr/testify/require" ) -func TestTypeInfo(t *testing.T) { +type testTypeValues struct { + input string + output string +} + +type testTypeInfo struct { + TypeInfo + testTypeValues +} + +var testPrimitiveSQLValues = map[Type]testTypeValues{ + TYPE_BOOLEAN: {input: `true::BOOLEAN`, output: `true`}, + TYPE_TINYINT: {input: `42::TINYINT`, output: `42`}, + TYPE_SMALLINT: {input: `42::SMALLINT`, output: `42`}, + TYPE_INTEGER: {input: `42::INTEGER`, output: `42`}, + TYPE_BIGINT: {input: `42::BIGINT`, output: `42`}, + TYPE_UTINYINT: {input: `43::UTINYINT`, output: `43`}, + TYPE_USMALLINT: {input: `43::USMALLINT`, output: `43`}, + TYPE_UINTEGER: {input: `43::UINTEGER`, output: `43`}, + TYPE_UBIGINT: {input: `43::UBIGINT`, output: `43`}, + TYPE_FLOAT: {input: `1.7::FLOAT`, output: `1.7`}, + TYPE_DOUBLE: {input: `1.7::DOUBLE`, output: `1.7`}, + TYPE_TIMESTAMP: {input: `TIMESTAMP '1992-09-20 11:30:00.123456789'`, output: `1992-09-20 11:30:00.123456`}, + TYPE_DATE: {input: `DATE '1992-09-20 11:30:00.123456789'`, output: `1992-09-20`}, + TYPE_TIME: {input: `TIME '1992-09-20 11:30:00.123456789'`, output: `11:30:00.123456`}, + TYPE_INTERVAL: {input: `INTERVAL 1 YEAR`, output: `1 year`}, + TYPE_HUGEINT: {input: `44::HUGEINT`, output: `44`}, + TYPE_VARCHAR: {input: `'hello world'::VARCHAR`, output: `hello world`}, + TYPE_BLOB: {input: `'\xAA'::BLOB`, output: `\xAA`}, + TYPE_TIMESTAMP_S: {input: `TIMESTAMP_S '1992-09-20 11:30:00.123456789'`, output: `1992-09-20 11:30:00`}, + TYPE_TIMESTAMP_MS: {input: `TIMESTAMP_MS '1992-09-20 11:30:00.123456789'`, output: `1992-09-20 11:30:00.123`}, + TYPE_TIMESTAMP_NS: {input: `TIMESTAMP_NS '1992-09-20 11:30:00.123456789'`, output: `1992-09-20 11:30:00.123456789`}, + TYPE_UUID: {input: `uuid()`, output: ``}, + TYPE_TIMESTAMP_TZ: {input: `TIMESTAMPTZ '1992-09-20 
11:30:00.123456789'`, output: `1992-09-20 11:30:00.123456+00`}, +} + +func getTypeInfos(t *testing.T, useAny bool) []testTypeInfo { var primitiveTypes []Type for k := range typeToStringMap { _, inMap := unsupportedTypeToStringMap[k] if inMap && k != TYPE_ANY { continue } + if k == TYPE_ANY && !useAny { + continue + } switch k { case TYPE_DECIMAL, TYPE_ENUM, TYPE_LIST, TYPE_STRUCT, TYPE_MAP: continue @@ -21,45 +60,115 @@ func TestTypeInfo(t *testing.T) { } // Create each primitive type information. - var typeInfos []TypeInfo + var testTypeInfos []testTypeInfo for _, primitive := range primitiveTypes { info, err := NewTypeInfo(primitive) require.NoError(t, err) - typeInfos = append(typeInfos, info) + testInfo := testTypeInfo{ + TypeInfo: info, + testTypeValues: testPrimitiveSQLValues[primitive], + } + testTypeInfos = append(testTypeInfos, testInfo) } - // Create nested types. - decimalInfo, err := NewDecimalInfo(3, 2) + // Create nested type information. + + info, err := NewDecimalInfo(3, 2) require.NoError(t, err) - enumInfo, err := NewEnumInfo("hello", "world", "!") + decimalTypeInfo := testTypeInfo{ + TypeInfo: info, + testTypeValues: testTypeValues{ + input: `4::DECIMAL(3, 2)`, + output: `4.00`, + }, + } + + info, err = NewEnumInfo("hello", "world", "!") require.NoError(t, err) - listInfo, err := NewListInfo(decimalInfo) + enumTypeInfo := testTypeInfo{ + TypeInfo: info, + testTypeValues: testTypeValues{ + input: `'hello'::greeting`, + output: `hello`, + }, + } + + info, err = NewListInfo(decimalTypeInfo) require.NoError(t, err) - nestedListInfo, err := NewListInfo(listInfo) + listTypeInfo := testTypeInfo{ + TypeInfo: info, + testTypeValues: testTypeValues{ + input: `[4::DECIMAL(3, 2)]`, + output: `[4.00]`, + }, + } + + info, err = NewListInfo(listTypeInfo) require.NoError(t, err) + nestedListTypeInfo := testTypeInfo{ + TypeInfo: info, + testTypeValues: testTypeValues{ + input: `[[4::DECIMAL(3, 2)]]`, + output: `[[4.00]]`, + }, + } - firstEntry, err := 
NewStructEntry(enumInfo, "hello") + firstEntry, err := NewStructEntry(enumTypeInfo, "hello") require.NoError(t, err) - secondEntry, err := NewStructEntry(nestedListInfo, "world") + secondEntry, err := NewStructEntry(nestedListTypeInfo, "world") require.NoError(t, err) - structInfo, err := NewStructInfo(firstEntry, secondEntry) + info, err = NewStructInfo(firstEntry, secondEntry) require.NoError(t, err) + structTypeInfo := testTypeInfo{ + TypeInfo: info, + testTypeValues: testTypeValues{ + input: `{'hello': 'hello'::greeting, 'world': [[4::DECIMAL(3, 2)]]}`, + output: `{'hello': hello, 'world': [[4.00]]}`, + }, + } - firstEntry, err = NewStructEntry(structInfo, "hello") + firstEntry, err = NewStructEntry(structTypeInfo, "hello") require.NoError(t, err) - secondEntry, err = NewStructEntry(listInfo, "world") + secondEntry, err = NewStructEntry(listTypeInfo, "world") require.NoError(t, err) - nestedStructInfo, err := NewStructInfo(firstEntry, secondEntry) + info, err = NewStructInfo(firstEntry, secondEntry) require.NoError(t, err) + nestedStructTypeInfo := testTypeInfo{ + TypeInfo: info, + testTypeValues: testTypeValues{ + input: `{ + 'hello': {'hello': 'hello'::greeting, 'world': [[4::DECIMAL(3, 2)]]}, + 'world': [4::DECIMAL(3, 2)] + }`, + output: `{'hello': {'hello': hello, 'world': [[4.00]]}, 'world': [4.00]}`, + }, + } - mapInfo, err := NewMapInfo(nestedStructInfo, nestedListInfo) + info, err = NewMapInfo(decimalTypeInfo, nestedStructTypeInfo) require.NoError(t, err) + mapTypeInfo := testTypeInfo{ + TypeInfo: info, + testTypeValues: testTypeValues{ + input: `MAP { + 4::DECIMAL(3, 2) : { + 'hello': {'hello': 'hello'::greeting, 'world': [[4::DECIMAL(3, 2)]]}, + 'world': [4::DECIMAL(3, 2)] + } + }`, + output: `{4.00={'hello': {'hello': hello, 'world': [[4.00]]}, 'world': [4.00]}}`, + }, + } + + testTypeInfos = append(testTypeInfos, decimalTypeInfo, enumTypeInfo, listTypeInfo, nestedListTypeInfo, structTypeInfo, nestedStructTypeInfo, mapTypeInfo) + return testTypeInfos 
+}
-	typeInfos = append(typeInfos, decimalInfo, enumInfo, listInfo, nestedListInfo, structInfo, nestedStructInfo, mapInfo)
+func TestTypeInterface(t *testing.T) {
+	testTypeInfos := getTypeInfos(t, true)
 
 	// Use each type as a child.
-	for _, info := range typeInfos {
-		_, err = NewListInfo(info)
+	for _, info := range testTypeInfos {
+		_, err := NewListInfo(info.TypeInfo)
 		require.NoError(t, err)
 	}
 }
diff --git a/types.go b/types.go
index bc21c29d..95617da5 100644
--- a/types.go
+++ b/types.go
@@ -21,10 +21,12 @@ func convertNumericType[srcT numericType, destT numericType](val srcT) destT {
 	return destT(val)
 }
 
-type UUID [16]byte
+const UUIDLength = 16
+
+type UUID [UUIDLength]byte
 
 func (u *UUID) Scan(v any) error {
-	if n := copy(u[:], v.([]byte)); n != 16 {
+	if n := copy(u[:], v.([]byte)); n != UUIDLength {
 		return fmt.Errorf("invalid UUID length: %d", n)
 	}
 	return nil
@@ -34,7 +36,7 @@ func (u *UUID) Scan(v any) error {
 // The value is computed as: upper * 2^64 + lower
 func hugeIntToUUID(hi C.duckdb_hugeint) []byte {
-	var uuid [16]byte
+	var uuid [UUIDLength]byte
 	// We need to flip the sign bit of the signed hugeint to transform it to UUID bytes
 	binary.BigEndian.PutUint64(uuid[:8], uint64(hi.upper)^1<<63)
 	binary.BigEndian.PutUint64(uuid[8:], uint64(hi.lower))
diff --git a/udf_utils.go b/udf_utils.go
new file mode 100644
index 00000000..91ef9541
--- /dev/null
+++ b/udf_utils.go
@@ -0,0 +1,37 @@
+package duckdb
+
+/*
+#include <duckdb.h>
+*/
+import "C"
+
+import (
+	"runtime"
+	"runtime/cgo"
+	"unsafe"
+)
+
+type pinnedValue[T any] struct {
+	pinner *runtime.Pinner
+	value  T
+}
+
+type unpinner interface {
+	unpin()
+}
+
+func (v pinnedValue[T]) unpin() {
+	v.pinner.Unpin()
+}
+
+func setFuncError(function_info C.duckdb_function_info, msg string) {
+	err := C.CString(msg)
+	C.duckdb_scalar_function_set_error(function_info, err)
+	C.duckdb_free(unsafe.Pointer(err))
+}
+
+func udf_delete_callback(info unsafe.Pointer) {
+	h := (*cgo.Handle)(info)
+	h.Value().(unpinner).unpin()
+ h.Delete() +} diff --git a/vector.go b/vector.go index d4366400..0eea79cf 100644 --- a/vector.go +++ b/vector.go @@ -25,6 +25,10 @@ type vector struct { // The child vectors of nested data types. childVectors []vector + // FIXME: This is a workaround until the C API exposes SQLNULL. + // FIXME: Then, SQLNULL becomes another Type value (C.DUCKDB_TYPE_SQLNULL). + isSQLNull bool + // The vector's type information. vectorTypeInfo } @@ -34,13 +38,19 @@ func (*vector) canNil(val reflect.Value) bool { case reflect.Chan, reflect.Func, reflect.Map, reflect.Pointer, reflect.UnsafePointer, reflect.Interface, reflect.Slice: return true + default: + return false } - return false } func (vec *vector) init(logicalType C.duckdb_logical_type, colIdx int) error { t := Type(C.duckdb_get_type_id(logicalType)) + if t == TYPE_INVALID { + vec.isSQLNull = true + return nil + } + name, inMap := unsupportedTypeToStringMap[t] if inMap { return addIndexToError(unsupportedTypeError(name), colIdx) @@ -113,6 +123,10 @@ func (vec *vector) resetChildData() { } func (vec *vector) initVectors(v C.duckdb_vector, writable bool) { + if vec.isSQLNull { + return + } + vec.duckdbVector = v vec.ptr = C.duckdb_vector_get_data(v) if writable { diff --git a/vector_getters.go b/vector_getters.go index 3f82f3a7..0bb447fd 100644 --- a/vector_getters.go +++ b/vector_getters.go @@ -148,8 +148,12 @@ func (vec *vector) getList(rowIdx C.idx_t) []any { // Fill the slice with all child values. 
for i := C.idx_t(0); i < entry.length; i++ { - val := child.getFn(child, i+entry.offset) - slice = append(slice, val) + if child.isSQLNull { + slice = append(slice, nil) + } else { + val := child.getFn(child, i+entry.offset) + slice = append(slice, val) + } } return slice } @@ -158,8 +162,12 @@ func (vec *vector) getStruct(rowIdx C.idx_t) map[string]any { m := map[string]any{} for i := 0; i < len(vec.childVectors); i++ { child := &vec.childVectors[i] - val := child.getFn(child, rowIdx) - m[vec.structEntries[i].Name()] = val + if child.isSQLNull { + m[vec.structEntries[i].Name()] = nil + } else { + val := child.getFn(child, rowIdx) + m[vec.structEntries[i].Name()] = val + } } return m } diff --git a/vector_setters.go b/vector_setters.go index 77ec88ba..271679c7 100644 --- a/vector_setters.go +++ b/vector_setters.go @@ -394,6 +394,13 @@ func setUUID[S any](vec *vector, rowIdx C.idx_t, val S) error { switch v := any(val).(type) { case UUID: uuid = v + case []uint8: + if len(v) != UUIDLength { + return castError(reflect.TypeOf(val).String(), reflect.TypeOf(uuid).String()) + } + for i := 0; i < UUIDLength; i++ { + uuid[i] = v[i] + } default: return castError(reflect.TypeOf(val).String(), reflect.TypeOf(uuid).String()) }