Skip to content

Commit

Permalink
test all types in scalar UDFs
Browse files Browse the repository at this point in the history
  • Loading branch information
taniabogatsch committed Sep 12, 2024
1 parent 49f7b7a commit 1833f7d
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 22 deletions.
5 changes: 3 additions & 2 deletions scalar_udf.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ func scalar_udf_callback(info C.duckdb_function_info, input C.duckdb_data_chunk,
}

//export scalar_udf_delete_callback
func scalar_udf_delete_callback(data unsafe.Pointer) {
h := cgo.Handle(data)
func scalar_udf_delete_callback(extraInfo unsafe.Pointer) {
h := cgo.Handle(extraInfo)
h.Delete()
}

Expand Down Expand Up @@ -158,6 +158,7 @@ func RegisterScalarUDF(c *sql.Conn, name string, function ScalarFunction) error

// Register the function.
state := C.duckdb_register_scalar_function(con.duckdbCon, scalarFunction)
// TODO: we crash here if the state is DuckDBError (e.g., when registering the same function twice)
C.duckdb_destroy_scalar_function(&scalarFunction)
if state == C.DuckDBError {
return getError(errAPI, errScalarUDFCreate)
Expand Down
47 changes: 45 additions & 2 deletions scalar_udf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"testing"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -47,14 +48,56 @@ func TestSimpleScalarUDF(t *testing.T) {
require.NoError(t, db.Close())
}

func TestAllTypesInScalarUDF(t *testing.T) {
type allTypesScalarUDF struct{}

var currentType TypeInfo

// Config reports the UDF signature: a single input and a result, both typed
// as the package-level currentType that the test sets before registration.
func (udf *allTypesScalarUDF) Config() (ScalarFunctionConfig, error) {
	return ScalarFunctionConfig{
		InputTypeInfos: []TypeInfo{currentType},
		ResultTypeInfo: currentType,
	}, nil
}

// TODO: test other primitive data types
// ExecuteRow implements the identity function: it hands back the first
// argument of the row unchanged and never reports an error.
func (udf *allTypesScalarUDF) ExecuteRow(args []driver.Value) (any, error) {
	first := args[0]
	return first, nil
}

// TestAllTypesScalarUDF registers an identity scalar UDF once per type returned
// by getTypeInfos and checks that the type's SQL input round-trips through the
// UDF, comparing the VARCHAR-cast result against the expected output string.
func TestAllTypesScalarUDF(t *testing.T) {
	typeInfos := getTypeInfos(t)
	for _, info := range typeInfos {
		// allTypesScalarUDF.Config reads the type under test from currentType.
		currentType = info.TypeInfo

		db, err := sql.Open("duckdb", "")
		require.NoError(t, err)

		c, err := db.Conn(context.Background())
		require.NoError(t, err)

		// The enum test input is written as 'hello'::greeting, so the type must exist.
		_, err = c.ExecContext(context.Background(), `CREATE TYPE greeting AS ENUM ('hello', 'world')`)
		require.NoError(t, err)

		var udf allTypesScalarUDF
		err = RegisterScalarUDF(c, "my_identity", &udf)
		require.NoError(t, err)

		var msg string
		row := db.QueryRow(fmt.Sprintf(`SELECT my_identity(%s)::VARCHAR AS msg`, info.input))
		require.NoError(t, row.Scan(&msg))
		if info.TypeInfo.t != TYPE_UUID {
			require.Equal(t, info.output, msg, fmt.Sprintf(`output does not match expected output, input: %s`, info.input))
		} else {
			// The UUID input is produced by uuid(), so there is no fixed expected
			// string; only verify that a value came back.
			require.NotEqual(t, "", msg, "uuid empty")
		}

		// A connection obtained via db.Conn must be returned to the pool with
		// Close; db.Close does not release connections that are still checked
		// out, so without this every iteration would leak one.
		require.NoError(t, c.Close())
		require.NoError(t, db.Close())
	}
}

// TestScalarUDFErrors is a placeholder for scalar UDF error-path tests; it
// currently contains no assertions.
func TestScalarUDFErrors(t *testing.T) {
// TODO: trigger all possible errors and move to errors_test.go
// TODO: in particular, test registering the same function name twice
}

func TestScalarUDFNested(t *testing.T) {
Expand Down
130 changes: 116 additions & 14 deletions type_info_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,43 @@ import (
"github.com/stretchr/testify/require"
)

func TestTypeInterface(t *testing.T) {
// testTypeValues pairs a SQL input expression with the text expected back
// after the value is cast to VARCHAR.
type testTypeValues struct {
	// input is the SQL expression spliced into the test query;
	// output is the string the query result is compared against.
	input, output string
}

// testTypeInfo bundles a TypeInfo with the SQL test values used to exercise it.
// Both members are embedded, so input and output are reachable directly on a
// testTypeInfo value.
type testTypeInfo struct {
TypeInfo
testTypeValues
}

// testPrimitiveSQLValues maps each primitive Type to a SQL input expression
// and the string expected for that value after a cast to VARCHAR.
// The temporal inputs carry nine fractional digits; the expected outputs keep
// only as many as listed here for each type.
// TYPE_UUID has an empty expected output because its input is the uuid()
// function call rather than a fixed literal.
var testPrimitiveSQLValues = map[Type]testTypeValues{
TYPE_BOOLEAN: {input: `true::BOOLEAN`, output: `true`},
TYPE_TINYINT: {input: `42::TINYINT`, output: `42`},
TYPE_SMALLINT: {input: `42::SMALLINT`, output: `42`},
TYPE_INTEGER: {input: `42::INTEGER`, output: `42`},
TYPE_BIGINT: {input: `42::BIGINT`, output: `42`},
TYPE_UTINYINT: {input: `43::UTINYINT`, output: `43`},
TYPE_USMALLINT: {input: `43::USMALLINT`, output: `43`},
TYPE_UINTEGER: {input: `43::UINTEGER`, output: `43`},
TYPE_UBIGINT: {input: `43::UBIGINT`, output: `43`},
TYPE_FLOAT: {input: `1.7::FLOAT`, output: `1.7`},
TYPE_DOUBLE: {input: `1.7::DOUBLE`, output: `1.7`},
TYPE_TIMESTAMP: {input: `TIMESTAMP '1992-09-20 11:30:00.123456789'`, output: `1992-09-20 11:30:00.123456`},
TYPE_DATE: {input: `DATE '1992-09-20 11:30:00.123456789'`, output: `1992-09-20`},
TYPE_TIME: {input: `TIME '1992-09-20 11:30:00.123456789'`, output: `11:30:00.123456`},
TYPE_INTERVAL: {input: `INTERVAL 1 YEAR`, output: `1 year`},
TYPE_HUGEINT: {input: `44::HUGEINT`, output: `44`},
TYPE_VARCHAR: {input: `'hello world'::VARCHAR`, output: `hello world`},
TYPE_BLOB: {input: `'\xAA'::BLOB`, output: `\xAA`},
TYPE_TIMESTAMP_S: {input: `TIMESTAMP_S '1992-09-20 11:30:00.123456789'`, output: `1992-09-20 11:30:00`},
TYPE_TIMESTAMP_MS: {input: `TIMESTAMP_MS '1992-09-20 11:30:00.123456789'`, output: `1992-09-20 11:30:00.123`},
TYPE_TIMESTAMP_NS: {input: `TIMESTAMP_NS '1992-09-20 11:30:00.123456789'`, output: `1992-09-20 11:30:00.123456789`},
TYPE_UUID: {input: `uuid()`, output: ``},
TYPE_TIMESTAMP_TZ: {input: `TIMESTAMPTZ '1992-09-20 11:30:00.123456789'`, output: `1992-09-20 11:30:00.123456+00`},
}

func getTypeInfos(t *testing.T) []testTypeInfo {
var primitiveTypes []Type
for k := range typeToStringMap {
_, inMap := unsupportedTypeToStringMap[k]
Expand All @@ -21,41 +57,107 @@ func TestTypeInterface(t *testing.T) {
}

// Create each primitive type information.
var typeInfos []TypeInfo
var typeInfos []testTypeInfo
for _, primitive := range primitiveTypes {
typeInfo, err := PrimitiveTypeInfo(primitive)
require.NoError(t, err)
typeInfos = append(typeInfos, typeInfo)
info := testTypeInfo{
TypeInfo: typeInfo,
testTypeValues: testPrimitiveSQLValues[typeInfo.t],
}
typeInfos = append(typeInfos, info)
}

// Create nested types.
decimalInfo := DecimalTypeInfo(3, 2)
decimalInfo := testTypeInfo{
TypeInfo: DecimalTypeInfo(3, 2),
testTypeValues: testTypeValues{
input: `4::DECIMAL(3, 2)`,
output: `4.00`,
},
}

names := []string{"hello", "world"}
enumInfo, err := EnumTypeInfo(names)
info, err := EnumTypeInfo(names)
enumInfo := testTypeInfo{
TypeInfo: info,
testTypeValues: testTypeValues{
input: `'hello'::greeting`,
output: `hello`,
},
}
require.NoError(t, err)

listInfo, err := ListTypeInfo(decimalInfo)
info, err = ListTypeInfo(decimalInfo.TypeInfo)
listInfo := testTypeInfo{
TypeInfo: info,
testTypeValues: testTypeValues{
input: `[4::DECIMAL(3, 2)]`,
output: `[4.00]`,
},
}
require.NoError(t, err)
nestedListInfo, err := ListTypeInfo(listInfo)

info, err = ListTypeInfo(listInfo.TypeInfo)
nestedListInfo := testTypeInfo{
TypeInfo: info,
testTypeValues: testTypeValues{
input: `[[4::DECIMAL(3, 2)]]`,
output: `[[4.00]]`,
},
}
require.NoError(t, err)

childTypeInfos := []TypeInfo{enumInfo, nestedListInfo}
structTypeInfo, err := StructTypeInfo(childTypeInfos, names)
childTypeInfos := []TypeInfo{enumInfo.TypeInfo, nestedListInfo.TypeInfo}
info, err = StructTypeInfo(childTypeInfos, names)
structTypeInfo := testTypeInfo{
TypeInfo: info,
testTypeValues: testTypeValues{
input: `{'hello': 'hello'::greeting, 'world': [[4::DECIMAL(3, 2)]]}`,
output: `{'hello': hello, 'world': [[4.00]]}`,
},
}
require.NoError(t, err)

nestedChildTypeInfos := []TypeInfo{structTypeInfo, listInfo}
nestedStructTypeInfo, err := StructTypeInfo(nestedChildTypeInfos, names)
nestedChildTypeInfos := []TypeInfo{structTypeInfo.TypeInfo, listInfo.TypeInfo}
info, err = StructTypeInfo(nestedChildTypeInfos, names)
nestedStructTypeInfo := testTypeInfo{
TypeInfo: info,
testTypeValues: testTypeValues{
input: `{
'hello': {'hello': 'hello'::greeting, 'world': [[4::DECIMAL(3, 2)]]},
'world': [4::DECIMAL(3, 2)]
}`,
output: `{'hello': {'hello': hello, 'world': [[4.00]]}, 'world': [4.00]}`,
},
}
require.NoError(t, err)

mapTypeInfo, err := MapTypeInfo(nestedStructTypeInfo, nestedListInfo)
info, err = MapTypeInfo(decimalInfo.TypeInfo, nestedStructTypeInfo.TypeInfo)
mapTypeInfo := testTypeInfo{
TypeInfo: info,
testTypeValues: testTypeValues{
input: `MAP {
4::DECIMAL(3, 2) : {
'hello': {'hello': 'hello'::greeting, 'world': [[4::DECIMAL(3, 2)]]},
'world': [4::DECIMAL(3, 2)]
}
}`,
output: `{4.00={'hello': {'hello': hello, 'world': [[4.00]]}, 'world': [4.00]}}`,
},
}
require.NoError(t, err)

typeInfos = append(typeInfos, decimalInfo, enumInfo, listInfo, nestedListInfo, structTypeInfo, nestedStructTypeInfo, mapTypeInfo)
return typeInfos
}

func TestTypeInterface(t *testing.T) {
typeInfos := getTypeInfos(t)

// Use each type as a child and to create the respective logical type.
// Use each type as a child.
for _, info := range typeInfos {
_, err = ListTypeInfo(info)
_, err := ListTypeInfo(info.TypeInfo)
require.NoError(t, err)
}
}
8 changes: 5 additions & 3 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@ func convertNumericType[srcT numericType, destT numericType](val srcT) destT {
return destT(val)
}

type UUID [16]byte
const UUIDLength = 16

type UUID [UUIDLength]byte

func (u *UUID) Scan(v any) error {
if n := copy(u[:], v.([]byte)); n != 16 {
if n := copy(u[:], v.([]byte)); n != UUIDLength {
return fmt.Errorf("invalid UUID length: %d", n)
}
return nil
Expand All @@ -34,7 +36,7 @@ func (u *UUID) Scan(v any) error {
// The value is computed as: upper * 2^64 + lower

func hugeIntToUUID(hi C.duckdb_hugeint) []byte {
var uuid [16]byte
var uuid [UUIDLength]byte
// We need to flip the sign bit of the signed hugeint to transform it to UUID bytes
binary.BigEndian.PutUint64(uuid[:8], uint64(hi.upper)^1<<63)
binary.BigEndian.PutUint64(uuid[8:], uint64(hi.lower))
Expand Down
22 changes: 21 additions & 1 deletion vector.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func (vec *vector) tryCast(val any) (any, error) {
case TYPE_MAP:
return tryPrimitiveCast[Map](val, reflect.TypeOf(Map{}).String())
case TYPE_UUID:
return tryPrimitiveCast[UUID](val, reflect.TypeOf(UUID{}).String())
return vec.tryCastUUID(val)
}
return nil, unsupportedTypeError(unknownTypeErrMsg)
}
Expand Down Expand Up @@ -228,6 +228,26 @@ func (vec *vector) tryCastStruct(val any) (map[string]any, error) {
return m, nil
}

// tryCastUUID converts val into a UUID.
//
// It accepts either a UUID value or a byte slice of exactly UUIDLength bytes.
// Any other input, including nil and byte slices of a different length,
// yields the zero UUID and a cast error naming the offending Go type.
func (vec *vector) tryCastUUID(val any) (UUID, error) {
	var uuid UUID

	switch v := val.(type) {
	case UUID:
		return v, nil
	case []byte:
		if len(v) == UUIDLength {
			copy(uuid[:], v)
			return uuid, nil
		}
	}

	// reflect.TypeOf returns a nil Type for a nil interface value, and calling
	// String on it would panic, so report nil explicitly instead.
	goType := "<nil>"
	if t := reflect.TypeOf(val); t != nil {
		goType = t.String()
	}
	return uuid, castError(goType, reflect.TypeOf(UUID{}).String())
}

func (vec *vector) init(logicalType C.duckdb_logical_type, colIdx int) error {
t := Type(C.duckdb_get_type_id(logicalType))

Expand Down

0 comments on commit 1833f7d

Please sign in to comment.