Skip to content

Commit

Permalink
fix simple example
Browse files Browse the repository at this point in the history
  • Loading branch information
taniabogatsch committed Sep 12, 2024
1 parent f1d079a commit 49f7b7a
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 40 deletions.
16 changes: 8 additions & 8 deletions data_chunk.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,20 +129,20 @@ func (chunk *DataChunk) initFromDuckDataChunk(data C.duckdb_data_chunk, writable
func (chunk *DataChunk) initFromDuckVector(duckdbVector C.duckdb_vector, writable bool) error {
columnCount := 1
chunk.columns = make([]vector, columnCount)
chunk.columns[0].duckdbVector = duckdbVector
chunk.columns[0].getChildVectors(duckdbVector, writable)

// Initialize the callback function to read and write values.
// Initialize the callback functions to read and write values.
logicalType := C.duckdb_vector_get_column_type(duckdbVector)
err := chunk.columns[0].init(logicalType, 0)
C.duckdb_destroy_logical_type(&logicalType)
return err
if err != nil {
return err
}

// Initialize the vector and its child vectors.
chunk.columns[0].initVectors(duckdbVector, writable)
return nil
}

func (chunk *DataChunk) close() {
C.duckdb_destroy_data_chunk(&chunk.data)
}

// TODO: GetMetaData, see table UDF PR.
// TODO: Add all templated functions.
// TODO: Projection pushdown.
31 changes: 18 additions & 13 deletions scalar_udf.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,41 +27,46 @@ type ScalarFunctionConfig struct {
type ScalarFunction interface {
Config() (ScalarFunctionConfig, error)
ExecuteRow(args []driver.Value) (any, error)
SetError(err error)
}

func setFunctionError(info C.duckdb_function_info, msg string) {
err := C.CString(msg)
C.duckdb_scalar_function_set_error(info, err)
C.duckdb_free(unsafe.Pointer(err))
}

//export scalar_udf_callback
func scalar_udf_callback(info C.duckdb_function_info, input C.duckdb_data_chunk, output C.duckdb_vector) {
// info is a void* pointer to our ScalarFunction.
h := *(*cgo.Handle)(unsafe.Pointer(info))
// If we hardcoded h = 1 here, it no longer segfaults.
scalarFunction := h.Value().(ScalarFunction)
extraInfo := C.duckdb_scalar_function_get_extra_info(info)

var err error
// extraInfo is a void* pointer to our ScalarFunction.
h := *(*cgo.Handle)(unsafe.Pointer(extraInfo))
scalarFunction := h.Value().(ScalarFunction)

// Initialize the input chunk.
var inputChunk DataChunk
if err = inputChunk.initFromDuckDataChunk(input, false); err != nil {
scalarFunction.SetError(getError(errAPI, err))
if err := inputChunk.initFromDuckDataChunk(input, false); err != nil {
setFunctionError(info, getError(errAPI, err).Error())
return
}
// Initialize the output chunk.
var outputChunk DataChunk
if err = outputChunk.initFromDuckVector(output, true); err != nil {
scalarFunction.SetError(getError(errAPI, err))
if err := outputChunk.initFromDuckVector(output, true); err != nil {
setFunctionError(info, getError(errAPI, err).Error())
return
}

// Execute the user-defined scalar function for each row.
args := make([]driver.Value, len(inputChunk.columns))
rowCount := inputChunk.GetSize()
columnCount := len(args)
for rowIdx := 0; rowIdx < rowCount; rowIdx++ {
var err error

for rowIdx := 0; rowIdx < rowCount; rowIdx++ {
// Set the input arguments for each column of a row.
for colIdx := 0; colIdx < columnCount; colIdx++ {
if args[colIdx], err = inputChunk.GetValue(colIdx, rowIdx); err != nil {
scalarFunction.SetError(getError(errAPI, err))
setFunctionError(info, getError(errAPI, err).Error())
return
}
}
Expand All @@ -77,7 +82,7 @@ func scalar_udf_callback(info C.duckdb_function_info, input C.duckdb_data_chunk,
}

if err != nil {
scalarFunction.SetError(getError(errAPI, err))
setFunctionError(info, getError(errAPI, err).Error())
}
}

Expand Down
32 changes: 13 additions & 19 deletions scalar_udf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,49 +4,39 @@ import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"testing"

"github.com/stretchr/testify/require"
)

type scalarUDF struct {
err error
}
type simpleScalarUDF struct{}

func (udf *scalarUDF) Config() (ScalarFunctionConfig, error) {
func (udf *simpleScalarUDF) Config() (ScalarFunctionConfig, error) {
var config ScalarFunctionConfig

intInfo, err := PrimitiveTypeInfo(TYPE_INTEGER)
intTypeInp, err := PrimitiveTypeInfo(TYPE_INTEGER)
if err != nil {
return config, err
}

config.InputTypeInfos = []TypeInfo{intInfo, intInfo}
config.ResultTypeInfo = intInfo
config.InputTypeInfos = []TypeInfo{intTypeInp, intTypeInp}
config.ResultTypeInfo = intTypeInp
return config, nil
}

func (udf *scalarUDF) ExecuteRow(args []driver.Value) (any, error) {
if len(args) != 2 {
return nil, errors.New("error executing row: expected two input values")
}
func (udf *simpleScalarUDF) ExecuteRow(args []driver.Value) (any, error) {
val := args[0].(int32) + args[1].(int32)
return val, nil
}

func (udf *scalarUDF) SetError(err error) {
udf.err = err
}

func TestScalarUDFPrimitive(t *testing.T) {
func TestSimpleScalarUDF(t *testing.T) {
db, err := sql.Open("duckdb", "")
require.NoError(t, err)

c, err := db.Conn(context.Background())
require.NoError(t, err)

var udf scalarUDF
var udf simpleScalarUDF
err = RegisterScalarUDF(c, "my_sum", &udf)
require.NoError(t, err)

Expand All @@ -55,10 +45,14 @@ func TestScalarUDFPrimitive(t *testing.T) {
require.NoError(t, row.Scan(&msg))
require.Equal(t, 52, msg)
require.NoError(t, db.Close())
}

func TestAllTypesInScalarUDF(t *testing.T) {

// TODO: test other primitive data types
}

// TODO: test other primitive data types

func TestScalarUDFErrors(t *testing.T) {
// TODO: trigger all possible errors and move to errors_test.go
}
Expand Down

0 comments on commit 49f7b7a

Please sign in to comment.