Merge pull request #222 from taniabogatsch/scalar
[Feature] Scalar UDF support
taniabogatsch authored Sep 23, 2024
2 parents a9134d6 + 649f3c9 commit aa90038
Showing 13 changed files with 1,149 additions and 33 deletions.
1 change: 1 addition & 0 deletions Makefile
@@ -9,6 +9,7 @@ install:
examples:
go run examples/simple/main.go
go run examples/appender/main.go
go run examples/scalar_udf/main.go

.PHONY: test
test:
24 changes: 24 additions & 0 deletions data_chunk.go
@@ -47,6 +47,9 @@ func (chunk *DataChunk) GetValue(colIdx int, rowIdx int) (any, error) {
return nil, getError(errAPI, columnCountError(colIdx, len(chunk.columns)))
}
column := &chunk.columns[colIdx]
if column.isSQLNull {
return nil, nil
}
return column.getFn(column, C.idx_t(rowIdx)), nil
}

@@ -57,7 +60,11 @@ func (chunk *DataChunk) SetValue(colIdx int, rowIdx int, val any) error {
if colIdx >= len(chunk.columns) {
return getError(errAPI, columnCountError(colIdx, len(chunk.columns)))
}

column := &chunk.columns[colIdx]
if column.isSQLNull {
return getError(errAPI, errSetSQLNULLValue)
}

// Set the value.
return column.setFn(column, C.idx_t(rowIdx), val)
@@ -127,6 +134,23 @@ func (chunk *DataChunk) initFromDuckDataChunk(data C.duckdb_data_chunk, writable
return err
}

func (chunk *DataChunk) initFromDuckVector(duckdbVector C.duckdb_vector, writable bool) error {
columnCount := 1
chunk.columns = make([]vector, columnCount)

// Initialize the callback functions to read and write values.
logicalType := C.duckdb_vector_get_column_type(duckdbVector)
err := chunk.columns[0].init(logicalType, 0)
C.duckdb_destroy_logical_type(&logicalType)
if err != nil {
return err
}

// Initialize the vector and its child vectors.
chunk.columns[0].initVectors(duckdbVector, writable)
return nil
}

func (chunk *DataChunk) close() {
C.duckdb_destroy_data_chunk(&chunk.data)
}
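
The NULL handling above deserves a note: a column initialized from a SQL NULL vector can be read (GetValue returns nil with no error) but not written (SetValue fails with errSetSQLNULLValue). Below is a minimal in-package sketch of those semantics; the helper name checkNULLColumnSemantics is illustrative, and it assumes a chunk whose first column was initialized from a SQL NULL vector.

package duckdb

import "errors"

// checkNULLColumnSemantics sketches the behavior added above for a chunk whose
// first column is a SQL NULL column (column.isSQLNull == true).
func checkNULLColumnSemantics(chunk *DataChunk) error {
	// Reading a NULL column yields a nil value and no error.
	v, err := chunk.GetValue(0, 0)
	if err != nil || v != nil {
		return errors.New("expected (nil, nil) when reading a NULL column")
	}
	// Writing to a NULL column is rejected.
	if err := chunk.SetValue(0, 0, int32(42)); err == nil {
		return errors.New("expected an error when writing to a NULL column")
	}
	return nil
}
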
29 changes: 22 additions & 7 deletions errors.go
@@ -82,19 +82,34 @@ var (

errUnsupportedMapKeyType = errors.New("MAP key type not supported")

errAppenderInvalidCon = errors.New("could not create appender: not a DuckDB driver connection")
errAppenderClosedCon = errors.New("could not create appender: appender creation on a closed connection")
errAppenderCreation = errors.New("could not create appender")
errAppenderDoubleClose = errors.New("could not close appender: already closed")
errAppenderCreation = errors.New("could not create appender")
errAppenderInvalidCon = fmt.Errorf("%w: not a DuckDB driver connection", errAppenderCreation)
errAppenderClosedCon = fmt.Errorf("%w: appender creation on a closed connection", errAppenderCreation)

errAppenderClose = errors.New("could not close appender")
errAppenderDoubleClose = fmt.Errorf("%w: already closed", errAppenderClose)

errAppenderAppendRow = errors.New("could not append row")
errAppenderAppendAfterClose = errors.New("could not append row: appender already closed")
errAppenderClose = errors.New("could not close appender")
errAppenderFlush = errors.New("could not flush appender")
errAppenderAppendAfterClose = fmt.Errorf("%w: appender already closed", errAppenderAppendRow)

errAppenderFlush = errors.New("could not flush appender")

errEmptyName = errors.New("empty name")
errInvalidDecimalWidth = fmt.Errorf("the DECIMAL width must be between 1 and %d", MAX_DECIMAL_WIDTH)
errInvalidDecimalScale = errors.New("the DECIMAL scale must be less than or equal to the width")

errScalarUDFCreate = errors.New("could not create scalar UDF")
errScalarUDFNoName = fmt.Errorf("%w: missing name", errScalarUDFCreate)
errScalarUDFIsNil = fmt.Errorf("%w: function is nil", errScalarUDFCreate)
errScalarUDFNoExecutor = fmt.Errorf("%w: executor is nil", errScalarUDFCreate)
errScalarUDFInputTypeIsNil = fmt.Errorf("%w: input type is nil", errScalarUDFCreate)
errScalarUDFResultTypeIsNil = fmt.Errorf("%w: result type is nil", errScalarUDFCreate)
errScalarUDFResultTypeIsANY = fmt.Errorf("%w: result type is ANY, which is not supported", errScalarUDFCreate)
errScalarUDFCreateSet = fmt.Errorf("could not create scalar UDF set")
errScalarUDFAddToSet = fmt.Errorf("%w: could not add the function to the set", errScalarUDFCreateSet)

errSetSQLNULLValue = errors.New("cannot write to a NULL column")

// Errors not covered in tests.
errConnect = errors.New("could not connect to database")
errCreateConfig = errors.New("could not create config for database")
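
Because the appender and scalar UDF errors are now wrapped with fmt.Errorf and the %w verb, the shared base errors can be matched with errors.Is. A minimal in-package sketch (the function name isScalarUDFCreationError is illustrative):

package duckdb

import "errors"

// isScalarUDFCreationError reports whether err stems from a failed scalar UDF
// creation. This works because errScalarUDFNoName, errScalarUDFIsNil, and the
// other creation errors all wrap errScalarUDFCreate via the %w verb.
func isScalarUDFCreationError(err error) bool {
	return errors.Is(err, errScalarUDFCreate)
}

// For example:
//   isScalarUDFCreationError(errScalarUDFNoName)        == true
//   errors.Is(errAppenderDoubleClose, errAppenderClose) == true
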
166 changes: 166 additions & 0 deletions examples/scalar_udf/main.go
@@ -0,0 +1,166 @@
package main

import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"

"github.com/marcboeker/go-duckdb"
)

// Overload my_length with two user-defined scalar functions.
// varcharLen takes a VARCHAR as its input parameter.
// listLen takes a LIST(ANY) as its input parameter.

type (
varcharLen struct{}
listLen struct{}
)

func varcharLenFn(values []driver.Value) (any, error) {
str := values[0].(string)
return int32(len(str)), nil
}

func (*varcharLen) Config() duckdb.ScalarFuncConfig {
inputTypeInfo, err := duckdb.NewTypeInfo(duckdb.TYPE_VARCHAR)
check(err)
resultTypeInfo, err := duckdb.NewTypeInfo(duckdb.TYPE_INTEGER)
check(err)

return duckdb.ScalarFuncConfig{
InputTypeInfos: []duckdb.TypeInfo{inputTypeInfo},
ResultTypeInfo: resultTypeInfo,
}
}

func (*varcharLen) Executor() duckdb.ScalarFuncExecutor {
return duckdb.ScalarFuncExecutor{RowExecutor: varcharLenFn}
}

func listLenFn(values []driver.Value) (any, error) {
list := values[0].([]any)
return int32(len(list)), nil
}

func (*listLen) Config() duckdb.ScalarFuncConfig {
anyTypeInfo, err := duckdb.NewTypeInfo(duckdb.TYPE_ANY)
check(err)
inputTypeInfo, err := duckdb.NewListInfo(anyTypeInfo)
check(err)
resultTypeInfo, err := duckdb.NewTypeInfo(duckdb.TYPE_INTEGER)
check(err)

return duckdb.ScalarFuncConfig{
InputTypeInfos: []duckdb.TypeInfo{inputTypeInfo},
ResultTypeInfo: resultTypeInfo,
}
}

func (*listLen) Executor() duckdb.ScalarFuncExecutor {
return duckdb.ScalarFuncExecutor{RowExecutor: listLenFn}
}

func myLengthScalarUDFSet() {
db, err := sql.Open("duckdb", "?access_mode=READ_WRITE")
check(err)

c, err := db.Conn(context.Background())
check(err)

var varcharUDF *varcharLen
var listUDF *listLen
err = duckdb.RegisterScalarUDFSet(c, "my_length", varcharUDF, listUDF)
check(err)

var length int32
row := db.QueryRow(`SELECT my_length('hello world') AS sum`)
check(row.Scan(&length))
if length != 11 {
panic(errors.New("incorrect length"))
}

row = db.QueryRow(`SELECT my_length([1, 2, NULL, 4, NULL]) AS sum`)
check(row.Scan(&length))
if length != 5 {
panic(errors.New("incorrect length"))
}

check(c.Close())
check(db.Close())
}

// wrapSum takes a VARCHAR prefix, a VARCHAR suffix, and a variadic number of integer values.
// It computes the sum of the integer values. Then, it emits a VARCHAR by concatenating prefix || sum || suffix.

type wrapSum struct{}

func wrapSumFn(values []driver.Value) (any, error) {
sum := int32(0)
for i := 2; i < len(values); i++ {
sum += values[i].(int32)
}
strSum := fmt.Sprintf("%d", sum)
prefix := values[0].(string)
suffix := values[1].(string)
return prefix + strSum + suffix, nil
}

func (*wrapSum) Config() duckdb.ScalarFuncConfig {
varcharTypeInfo, err := duckdb.NewTypeInfo(duckdb.TYPE_VARCHAR)
check(err)
intTypeInfo, err := duckdb.NewTypeInfo(duckdb.TYPE_INTEGER)
check(err)

return duckdb.ScalarFuncConfig{
InputTypeInfos: []duckdb.TypeInfo{varcharTypeInfo, varcharTypeInfo},
ResultTypeInfo: varcharTypeInfo,
VariadicTypeInfo: intTypeInfo,
}
}

func (*wrapSum) Executor() duckdb.ScalarFuncExecutor {
return duckdb.ScalarFuncExecutor{RowExecutor: wrapSumFn}
}

func wrapSumScalarUDF() {
db, err := sql.Open("duckdb", "?access_mode=READ_WRITE")
check(err)

c, err := db.Conn(context.Background())
check(err)

var wrapSumUDF *wrapSum
err = duckdb.RegisterScalarUDF(c, "wrap_sum", wrapSumUDF)
check(err)

var res string
row := db.QueryRow(`SELECT wrap_sum('hello', ' world', 1, 2, 3, 4) AS sum`)
check(row.Scan(&res))
if res != "hello10 world" {
panic(errors.New("incorrect result"))
}

row = db.QueryRow(`SELECT wrap_sum('hello', ' world') AS sum`)
check(row.Scan(&res))
if res != "hello0 world" {
panic(errors.New("incorrect result"))
}

check(c.Close())
check(db.Close())
}

func main() {
myLengthScalarUDFSet()
wrapSumScalarUDF()
}

func check(args ...interface{}) {
err := args[len(args)-1]
if err != nil {
panic(err)
}
}
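
For a third variation, here is a minimal sketch of a two-argument scalar UDF built with the same API surface as the examples above. The intMul type and the my_mul function name are illustrative, and the sketch assumes the same imports and the check helper from the example file.

type intMul struct{}

func intMulFn(values []driver.Value) (any, error) {
	// Both arguments arrive as int32 because the config declares TYPE_INTEGER.
	return values[0].(int32) * values[1].(int32), nil
}

func (*intMul) Config() duckdb.ScalarFuncConfig {
	intTypeInfo, err := duckdb.NewTypeInfo(duckdb.TYPE_INTEGER)
	check(err)

	return duckdb.ScalarFuncConfig{
		InputTypeInfos: []duckdb.TypeInfo{intTypeInfo, intTypeInfo},
		ResultTypeInfo: intTypeInfo,
	}
}

func (*intMul) Executor() duckdb.ScalarFuncExecutor {
	return duckdb.ScalarFuncExecutor{RowExecutor: intMulFn}
}

func intMulScalarUDF() {
	db, err := sql.Open("duckdb", "?access_mode=READ_WRITE")
	check(err)

	c, err := db.Conn(context.Background())
	check(err)

	var mulUDF *intMul
	check(duckdb.RegisterScalarUDF(c, "my_mul", mulUDF))

	var product int32
	row := db.QueryRow(`SELECT my_mul(6, 7) AS product`)
	check(row.Scan(&product))
	if product != 42 {
		panic(errors.New("incorrect product"))
	}

	check(c.Close())
	check(db.Close())
}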