From 49f7b7a5be6f2d3c7119fc3f61be001f14e99be2 Mon Sep 17 00:00:00 2001
From: taniabogatsch <44262898+taniabogatsch@users.noreply.github.com>
Date: Thu, 12 Sep 2024 10:26:56 +0200
Subject: [PATCH] fix simple example

---
 data_chunk.go      | 16 ++++++++--------
 scalar_udf.go      | 31 ++++++++++++++++++-------------
 scalar_udf_test.go | 32 +++++++++++++-------------------
 3 files changed, 39 insertions(+), 40 deletions(-)

diff --git a/data_chunk.go b/data_chunk.go
index d934c81c..d8176611 100644
--- a/data_chunk.go
+++ b/data_chunk.go
@@ -129,20 +129,20 @@ func (chunk *DataChunk) initFromDuckDataChunk(data C.duckdb_data_chunk, writable
 func (chunk *DataChunk) initFromDuckVector(duckdbVector C.duckdb_vector, writable bool) error {
 	columnCount := 1
 	chunk.columns = make([]vector, columnCount)
-	chunk.columns[0].duckdbVector = duckdbVector
-	chunk.columns[0].getChildVectors(duckdbVector, writable)
 
-	// Initialize the callback function to read and write values.
+	// Initialize the callback functions to read and write values.
 	logicalType := C.duckdb_vector_get_column_type(duckdbVector)
 	err := chunk.columns[0].init(logicalType, 0)
 	C.duckdb_destroy_logical_type(&logicalType)
-	return err
+	if err != nil {
+		return err
+	}
+
+	// Initialize the vector and its child vectors.
+	chunk.columns[0].initVectors(duckdbVector, writable)
+	return nil
 }
 
 func (chunk *DataChunk) close() {
 	C.duckdb_destroy_data_chunk(&chunk.data)
 }
-
-// TODO: GetMetaData, see table UDF PR.
-// TODO: Add all templated functions.
-// TODO: Projection pushdown.
diff --git a/scalar_udf.go b/scalar_udf.go
index a205b42c..2fab18b5 100644
--- a/scalar_udf.go
+++ b/scalar_udf.go
@@ -27,28 +27,32 @@ type ScalarFunctionConfig struct {
 type ScalarFunction interface {
 	Config() (ScalarFunctionConfig, error)
 	ExecuteRow(args []driver.Value) (any, error)
-	SetError(err error)
+}
+
+func setFunctionError(info C.duckdb_function_info, msg string) {
+	err := C.CString(msg)
+	C.duckdb_scalar_function_set_error(info, err)
+	C.duckdb_free(unsafe.Pointer(err))
 }
 
 //export scalar_udf_callback
 func scalar_udf_callback(info C.duckdb_function_info, input C.duckdb_data_chunk, output C.duckdb_vector) {
-	// info is a void* pointer to our ScalarFunction.
-	h := *(*cgo.Handle)(unsafe.Pointer(info))
-	// If we hardcoded h = 1 here, it no longer segfaults.
-	scalarFunction := h.Value().(ScalarFunction)
+	extraInfo := C.duckdb_scalar_function_get_extra_info(info)
 
-	var err error
+	// extraInfo is a void* pointer to our ScalarFunction.
+	h := *(*cgo.Handle)(unsafe.Pointer(extraInfo))
+	scalarFunction := h.Value().(ScalarFunction)
 
 	// Initialize the input chunk.
 	var inputChunk DataChunk
-	if err = inputChunk.initFromDuckDataChunk(input, false); err != nil {
-		scalarFunction.SetError(getError(errAPI, err))
+	if err := inputChunk.initFromDuckDataChunk(input, false); err != nil {
+		setFunctionError(info, getError(errAPI, err).Error())
 		return
 	}
 
 	// Initialize the output chunk.
 	var outputChunk DataChunk
-	if err = outputChunk.initFromDuckVector(output, true); err != nil {
-		scalarFunction.SetError(getError(errAPI, err))
+	if err := outputChunk.initFromDuckVector(output, true); err != nil {
+		setFunctionError(info, getError(errAPI, err).Error())
 		return
 	}
 
@@ -56,12 +60,13 @@ func scalar_udf_callback(info C.duckdb_function_info, input C.duckdb_data_chunk,
 	args := make([]driver.Value, len(inputChunk.columns))
 	rowCount := inputChunk.GetSize()
 	columnCount := len(args)
 
-	for rowIdx := 0; rowIdx < rowCount; rowIdx++ {
+	var err error
+	for rowIdx := 0; rowIdx < rowCount; rowIdx++ {
 		// Set the input arguments for each column of a row.
 		for colIdx := 0; colIdx < columnCount; colIdx++ {
 			if args[colIdx], err = inputChunk.GetValue(colIdx, rowIdx); err != nil {
-				scalarFunction.SetError(getError(errAPI, err))
+				setFunctionError(info, getError(errAPI, err).Error())
 				return
 			}
 		}
@@ -77,7 +82,7 @@ func scalar_udf_callback(info C.duckdb_function_info, input C.duckdb_data_chunk,
 	}
 
 	if err != nil {
-		scalarFunction.SetError(getError(errAPI, err))
+		setFunctionError(info, getError(errAPI, err).Error())
 	}
 }
 
diff --git a/scalar_udf_test.go b/scalar_udf_test.go
index fcef6438..32e83e76 100644
--- a/scalar_udf_test.go
+++ b/scalar_udf_test.go
@@ -4,49 +4,39 @@ import (
 	"context"
 	"database/sql"
 	"database/sql/driver"
-	"errors"
 	"testing"
 
 	"github.com/stretchr/testify/require"
 )
 
-type scalarUDF struct {
-	err error
-}
+type simpleScalarUDF struct{}
 
-func (udf *scalarUDF) Config() (ScalarFunctionConfig, error) {
+func (udf *simpleScalarUDF) Config() (ScalarFunctionConfig, error) {
 	var config ScalarFunctionConfig
 
-	intInfo, err := PrimitiveTypeInfo(TYPE_INTEGER)
+	intTypeInp, err := PrimitiveTypeInfo(TYPE_INTEGER)
 	if err != nil {
 		return config, err
 	}
 
-	config.InputTypeInfos = []TypeInfo{intInfo, intInfo}
-	config.ResultTypeInfo = intInfo
+	config.InputTypeInfos = []TypeInfo{intTypeInp, intTypeInp}
+	config.ResultTypeInfo = intTypeInp
 	return config, nil
 }
 
-func (udf *scalarUDF) ExecuteRow(args []driver.Value) (any, error) {
-	if len(args) != 2 {
-		return nil, errors.New("error executing row: expected two input values")
-	}
+func (udf *simpleScalarUDF) ExecuteRow(args []driver.Value) (any, error) {
 	val := args[0].(int32) + args[1].(int32)
 	return val, nil
 }
 
-func (udf *scalarUDF) SetError(err error) {
-	udf.err = err
-}
-
-func TestScalarUDFPrimitive(t *testing.T) {
+func TestSimpleScalarUDF(t *testing.T) {
 	db, err := sql.Open("duckdb", "")
 	require.NoError(t, err)
 
 	c, err := db.Conn(context.Background())
 	require.NoError(t, err)
 
-	var udf scalarUDF
+	var udf simpleScalarUDF
 	err = RegisterScalarUDF(c, "my_sum", &udf)
 	require.NoError(t, err)
 
@@ -55,10 +45,14 @@ func TestScalarUDFPrimitive(t *testing.T) {
 	require.NoError(t, row.Scan(&msg))
 	require.Equal(t, 52, msg)
 	require.NoError(t, db.Close())
+}
+
+func TestAllTypesInScalarUDF(t *testing.T) {
 
-	// TODO: test other primitive data types
+	// TODO: test other primitive data types
 }
+
 
 func TestScalarUDFErrors(t *testing.T) {
 	// TODO: trigger all possible errors and move to errors_test.go
 }
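
Note on the handle plumbing this fix relies on: judging by the callback, RegisterScalarUDF (not shown in this patch) wraps the user's ScalarFunction in a runtime/cgo.Handle and stores a pointer to that handle as the scalar function's extra info, which is why scalar_udf_callback now recovers it through duckdb_scalar_function_get_extra_info(info) and a *cgo.Handle cast instead of casting info itself (the approach the removed "If we hardcoded h = 1" comment was chasing). Below is a minimal pure-Go sketch of that round trip under those assumptions; udf, name, and main are illustrative only, and the real registration keeps the handle in C-allocated memory rather than taking the address of a Go variable.

package main

import (
	"fmt"
	"runtime/cgo"
	"unsafe"
)

// udf stands in for the user's ScalarFunction implementation.
type udf struct{ name string }

func main() {
	// Registration side: wrap the Go value in a cgo.Handle. The handle (or a
	// pointer to it), not the Go pointer itself, is what crosses the C
	// boundary as extra info.
	fn := &udf{name: "my_sum"}
	h := cgo.NewHandle(fn)
	extraInfo := unsafe.Pointer(&h) // illustrative; real code uses C-allocated memory

	// Callback side: dereference the extra info back into a cgo.Handle and
	// recover the original Go value, mirroring scalar_udf_callback above.
	recovered := (*(*cgo.Handle)(extraInfo)).Value().(*udf)
	fmt.Println(recovered.name)

	// Cleanup side: release the handle once the function is no longer needed.
	h.Delete()
}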