From aec120526b5a3bd519c25ad2fb283dca1a605b11 Mon Sep 17 00:00:00 2001 From: Jaap Aarts Date: Sat, 4 May 2024 14:07:49 +0200 Subject: [PATCH] Add support for returning structs and lists from tableUDF --- appender_vector.go | 237 ++++++++++++++++++++++++++++++--------------- udf.go | 36 ++++++- udf_test.go | 144 +++++++++++++++++++++++---- 3 files changed, 318 insertions(+), 99 deletions(-) diff --git a/appender_vector.go b/appender_vector.go index 5f2b960e..0e235017 100644 --- a/appender_vector.go +++ b/appender_vector.go @@ -7,7 +7,6 @@ package duckdb import "C" import ( - "fmt" "reflect" "strconv" "time" @@ -105,57 +104,6 @@ func tryCastInteger[S any, R numericType](val S) (R, error) { } -/* - func tryCast[T, R any](val T) (R, error) { - var x R - switch any(x).(type) { - case uint8: - r, err := tryCastInteger[T, R](val) - return R(r), err - case int8: - return convertNumericType[T, int8] - case uint16: - return convertNumericType[T, uint16] - case int16: - return convertNumericType[T, int16] - case uint32: - return convertNumericType[T, uint32] - case int32: - return convertNumericType[T, int32] - case uint64: - return convertNumericType[T, uint64] - case int64: - return convertNumericType[T, int64] - case uint: - return convertNumericType[T, uint] - case int: - return convertNumericType[T, int] - case C.DUCKDB_TYPE_BIGINT: - return tryNumericCast[int64](val, reflect.Int64.String()) - case C.DUCKDB_TYPE_FLOAT: - return tryNumericCast[float32](val, reflect.Float32.String()) - case C.DUCKDB_TYPE_DOUBLE: - return tryNumericCast[float64](val, reflect.Float64.String()) - case C.DUCKDB_TYPE_BOOLEAN: - return tryPrimitiveCast[bool](val, reflect.Bool.String()) - case C.DUCKDB_TYPE_VARCHAR: - return tryPrimitiveCast[string](val, reflect.String.String()) - case C.DUCKDB_TYPE_BLOB: - return tryPrimitiveCast[[]byte](val, reflect.TypeOf([]byte{}).String()) - case C.DUCKDB_TYPE_TIMESTAMP, C.DUCKDB_TYPE_TIMESTAMP_S, C.DUCKDB_TYPE_TIMESTAMP_MS, - C.DUCKDB_TYPE_TIMESTAMP_NS, C.DUCKDB_TYPE_TIMESTAMP_TZ: - return tryPrimitiveCast[time.Time](val, reflect.TypeOf(time.Time{}).String()) - case C.DUCKDB_TYPE_UUID: - return tryPrimitiveCast[UUID](val, reflect.TypeOf(UUID{}).String()) - case C.DUCKDB_TYPE_LIST: - return vec.tryCastList(val) - case C.DUCKDB_TYPE_STRUCT: - return vec.tryCastStruct(val) - } - - return nil, getError(errDriver, nil) - } -*/ func (*vector) canNil(val reflect.Value) bool { switch val.Kind() { case reflect.Chan, reflect.Func, reflect.Map, reflect.Pointer, @@ -362,7 +310,6 @@ func (vec *vector) setCString(rowIdx C.idx_t, val any) { } else if vec.duckdbType == C.DUCKDB_TYPE_BLOB { str = string(val.([]byte)[:]) } - // This setter also writes BLOBs. cStr := C.CString(str) C.duckdb_vector_assign_string_element_len(vec.duckdbVector, rowIdx, cStr, C.idx_t(len(str))) @@ -513,6 +460,12 @@ func (vec *vector) initStruct(logicalType C.duckdb_logical_type) error { return nil } +func _setPrimitive[T any](vec *vector, rowIdx C.idx_t, val T) { + ptr := C.duckdb_vector_get_data(vec.duckdbVector) + xs := (*[1 << 31]T)(ptr) + xs[rowIdx] = val +} + func _setVectorNumeric[S any, T numericType](vec *vector, rowIdx C.idx_t, val S) error { var fv T switch v := any(val).(type) { @@ -547,12 +500,9 @@ func _setVectorNumeric[S any, T numericType](vec *vector, rowIdx C.idx_t, val S) fv = 0 } default: - return fmt.Errorf("wrong input type") + return castError(reflect.TypeOf(val).String(), reflect.TypeOf(fv).String()) } - - ptr := C.duckdb_vector_get_data(vec.duckdbVector) - xs := (*[1 << 31]T)(ptr) - xs[rowIdx] = fv + _setPrimitive(vec, rowIdx, fv) return nil } @@ -584,14 +534,149 @@ func _setVectorBool[S any](vec *vector, rowIdx C.idx_t, val S) error { case float64: fv = v == 0 case bool: - fv = v + fv = v default: - return fmt.Errorf("wrong input type") + return castError(reflect.TypeOf(val).String(), reflect.TypeOf(fv).String()) } + _setPrimitive(vec, rowIdx, fv) + return nil +} - ptr := C.duckdb_vector_get_data(vec.duckdbVector) - xs := (*[1 << 31]bool)(ptr) - xs[rowIdx] = fv +func _setVectorString[S any](vec *vector, rowIdx C.idx_t, val S) error { + var cStr *C.char + var length int + switch v := any(val).(type) { + case string: + cStr = C.CString(v) + defer C.free(unsafe.Pointer(cStr)) + length = len(v) + case []byte: + cStr = (*C.char)(C.CBytes(v)) + defer C.free(unsafe.Pointer(cStr)) + length = len(v) + default: + return castError(reflect.TypeOf(val).String(), reflect.TypeOf(cStr).String()) + } + + C.duckdb_vector_assign_string_element_len(vec.duckdbVector, rowIdx, (*C.char)(cStr), C.idx_t(length)) + return nil +} + +func _setVectorTS[S any](vec *vector, rowIdx C.idx_t, val S) error { + var t time.Time + switch v := any(val).(type) { + case time.Time: + t = v + default: + return castError(reflect.TypeOf(val).String(), reflect.TypeOf(t).String()) + } + var ticks int64 + switch vec.duckdbType { + case C.DUCKDB_TYPE_TIMESTAMP: + ticks = t.UTC().UnixMicro() + case C.DUCKDB_TYPE_TIMESTAMP_S: + ticks = t.UTC().Unix() + case C.DUCKDB_TYPE_TIMESTAMP_MS: + ticks = t.UTC().UnixMilli() + case C.DUCKDB_TYPE_TIMESTAMP_NS: + ticks = t.UTC().UnixNano() + case C.DUCKDB_TYPE_TIMESTAMP_TZ: + ticks = t.UTC().UnixMicro() + } + var ts C.duckdb_timestamp + ts.micros = C.int64_t(ticks) + _setPrimitive(vec, rowIdx, ts) + return nil +} + +func _setVectorUUID[S any](vec *vector, rowIdx C.idx_t, val S) error { + var uuid UUID + switch v := any(val).(type) { + case UUID: + uuid = v + default: + return castError(reflect.TypeOf(val).String(), reflect.TypeOf(uuid).String()) + } + hi := uuidToHugeInt(uuid) + _setPrimitive(vec, rowIdx, hi) + return nil +} + +func _setVectorList[S any](vec *vector, rowIdx C.idx_t, val S) error { + var list []any + switch v := any(val).(type) { + case []any: + list = v + default: + // Insert the values into the child vector. + rv := reflect.ValueOf(val) + list = make([]any, rv.Len()) + childVector := vec.childVectors[0] + + for i := 0; i < rv.Len(); i++ { + idx := rv.Index(i) + if vec.canNil(idx) && idx.IsNil() { + list[i] = nil + continue + } + + var err error + list[i], err = childVector.tryCast(idx.Interface()) + if err != nil { + return err + } + } + } + childVectorSize := C.duckdb_list_vector_get_size(vec.duckdbVector) + + // Set the offset and length of the list vector using the current size of the child vector. + listEntry := C.duckdb_list_entry{ + offset: C.idx_t(childVectorSize), + length: C.idx_t(len(list)), + } + _setPrimitive(vec, rowIdx, listEntry) + + newLength := C.idx_t(len(list)) + childVectorSize + C.duckdb_list_vector_set_size(vec.duckdbVector, newLength) + C.duckdb_list_vector_reserve(vec.duckdbVector, newLength) + + // Insert the values into the child vector. + childVector := vec.childVectors[0] + for i, e := range list { + offset := C.idx_t(i) + childVectorSize + childVector.fn(&childVector, offset, e) + } + return nil +} + +func _setVectorStruct[S any](vec *vector, rowIdx C.idx_t, val S) error { + //TODO: cast to map if possible + var m map[string]any + switch v := any(val).(type) { + case map[string]any: + m = v + default: + // Catch mismatching types. + goType := reflect.TypeOf(val) + if reflect.TypeOf(val).Kind() != reflect.Struct { + return castError(goType.String(), reflect.Struct.String()) + } + + m = make(map[string]any) + rv := reflect.ValueOf(val) + structType := rv.Type() + + for i := 0; i < structType.NumField(); i++ { + fieldName := structType.Field(i).Name + m[fieldName] = rv.Field(i).Interface() + } + } + + for i := 0; i < len(vec.childVectors); i++ { + childVector := vec.childVectors[i] + childName := vec.childNames[i] + childVector.fn(&childVector, rowIdx, m[childName]) + } return nil } @@ -619,19 +704,19 @@ func setVectorVal[S any](vec *vector, rowIdx C.idx_t, val S) error { return _setVectorNumeric[S, float64](vec, rowIdx, val) case C.DUCKDB_TYPE_BOOLEAN: return _setVectorBool[S](vec, rowIdx, val) - /* case C.DUCKDB_TYPE_VARCHAR: - return tryPrimitiveCast[string](val, reflect.String.String()) - case C.DUCKDB_TYPE_BLOB: - return tryPrimitiveCast[[]byte](val, reflect.TypeOf([]byte{}).String()) - case C.DUCKDB_TYPE_TIMESTAMP, C.DUCKDB_TYPE_TIMESTAMP_S, C.DUCKDB_TYPE_TIMESTAMP_MS, - C.DUCKDB_TYPE_TIMESTAMP_NS, C.DUCKDB_TYPE_TIMESTAMP_TZ: - return tryPrimitiveCast[time.Time](val, reflect.TypeOf(time.Time{}).String()) - case C.DUCKDB_TYPE_UUID: - return tryPrimitiveCast[UUID](val, reflect.TypeOf(UUID{}).String()) - case C.DUCKDB_TYPE_LIST: - return vec.tryCastList(val) - case C.DUCKDB_TYPE_STRUCT: - return vec.tryCastStruct(val)*/ + case C.DUCKDB_TYPE_VARCHAR: + return _setVectorString[S](vec, rowIdx, val) + case C.DUCKDB_TYPE_BLOB: + return _setVectorString[S](vec, rowIdx, val) + case C.DUCKDB_TYPE_TIMESTAMP, C.DUCKDB_TYPE_TIMESTAMP_S, C.DUCKDB_TYPE_TIMESTAMP_MS, + C.DUCKDB_TYPE_TIMESTAMP_NS, C.DUCKDB_TYPE_TIMESTAMP_TZ: + return _setVectorTS[S](vec, rowIdx, val) + case C.DUCKDB_TYPE_UUID: + return _setVectorUUID[S](vec, rowIdx, val) + case C.DUCKDB_TYPE_LIST: + return _setVectorList[S](vec, rowIdx, val) + case C.DUCKDB_TYPE_STRUCT: + return _setVectorStruct[S](vec, rowIdx, val) } // TODO: error return nil diff --git a/udf.go b/udf.go index 659d1463..b83637d8 100644 --- a/udf.go +++ b/udf.go @@ -44,7 +44,6 @@ func SetRowValue[T any](row Row, c int, val T) { setVectorVal[T](&vec, C.ulong(row.r), val) } - func (row Row) SetRowValue(c int, val any) { vec := row.vectors[c] @@ -56,7 +55,7 @@ func (row Row) SetRowValue(c int, val any) { defer C.free(unsafe.Pointer(errstr)) C.duckdb_function_set_error(row.info, errstr) } - + vec.fn(&vec, C.ulong(row.r), v) } @@ -130,7 +129,7 @@ func udf_callback(info C.duckdb_function_info, output C.duckdb_data_chunk) { break } } - // since row.r points to one past the last value, it is also the size + // since row.r points to one past the last value, it is also the size C.duckdb_data_chunk_set_size(output, C.ulong(row.r)) } @@ -182,6 +181,37 @@ func getDuckdbTypeFromValue(v any) (C.duckdb_logical_type, error) { case string: return C.duckdb_create_logical_type(C.DUCKDB_TYPE_VARCHAR), nil default: + rv := reflect.ValueOf(v) + rt := reflect.TypeOf(v) + + switch rt.Kind() { + case reflect.Struct: + var nfields int + for i := rt.NumField()-1; i >= 0; i-- { + if rv.Field(i).CanInterface(){ + nfields ++ + } + } + + types := (*[1 << 31]C.duckdb_logical_type)(C.malloc(C.ulong(uintptr(nfields) * unsafe.Sizeof(C.duckdb_logical_type(nil))))) + names := (*[1 << 31]*C.char)(C.malloc(C.ulong(uintptr(nfields) * unsafe.Sizeof((*C.char)(nil))))) + defer C.free(unsafe.Pointer(types)) + defer C.free(unsafe.Pointer(names)) + for i := 0; i < nfields; i++ { + if !rv.Field(i).CanInterface(){ + continue + } + var err error + (*types)[i], err = getDuckdbTypeFromValue(rv.Field(i).Interface()) + if err != nil { + return C.duckdb_logical_type(nil), err + } + (*names)[i] = C.CString(rt.Field(i).Name) + } + ctypes := (*C.duckdb_logical_type)(unsafe.Pointer(types)) + cnames := (**C.char)(unsafe.Pointer(names)) + return C.duckdb_create_struct_type(ctypes, cnames, C.ulong(nfields)), nil + } return C.duckdb_logical_type(nil), unsupportedTypeError(reflect.TypeOf(v).String()) } } diff --git a/udf_test.go b/udf_test.go index de33212c..33f60d2e 100644 --- a/udf_test.go +++ b/udf_test.go @@ -8,30 +8,49 @@ import ( "testing" ) -type wrongValueError struct { - rowIdx int - colIdx int - colName string - expected any - got any -} +type ( + wrongValueError struct { + rowIdx int + colIdx int + colName string + expected any + got any + } -type tableUDF struct { - n int64 - count int64 -} + testUDF interface { + TableFunction + GetValue(r, c int) any + } -func (wve wrongValueError) Error() string { - return fmt.Sprintf("Wrong value at row %d, column %d(%s): Expected %v of type %[4]T, found %v of type %[5]T", wve.rowIdx, wve.colIdx, wve.colName, wve.expected, wve.got) -} + incTableUDF struct { + n int64 + count int64 + } + + structTableUDF struct { + n int64 + count int64 + } -func (d *tableUDF) GetArguments() []interface{} { + structTableUDFT struct { + I int64 + } +) + +var ( + tudfs = []testUDF{ + &incTableUDF{}, + &structTableUDF{}, + } +) + +func (d *incTableUDF) GetArguments() []interface{} { return []interface{}{ int64(0), } } -func (d *tableUDF) BindArguments(args ...interface{}) []ColumnName { +func (d *incTableUDF) BindArguments(args ...interface{}) []ColumnName { d.count = 0 d.n = args[0].(int64) return []ColumnName{ @@ -39,7 +58,7 @@ func (d *tableUDF) BindArguments(args ...interface{}) []ColumnName { } } -func (d *tableUDF) FillRow(row Row) bool { +func (d *incTableUDF) FillRow(row Row) bool { if d.count > d.n { return false } @@ -48,10 +67,39 @@ func (d *tableUDF) FillRow(row Row) bool { return true } -func (d *tableUDF) GetValue(r, c int) any { +func (d *incTableUDF) GetValue(r, c int) any { return int64(r + 1) } +func (d *structTableUDF) GetArguments() []interface{} { + return []interface{}{ + int64(0), + } +} + +func (d *structTableUDF) BindArguments(args ...interface{}) []ColumnName { + d.count = 0 + d.n = args[0].(int64) + return []ColumnName{ + {Name: "result", V: structTableUDFT{I: 0}}, + } +} + +func (d *structTableUDF) FillRow(row Row) bool { + if d.count > d.n { + return false + } + d.count++ + SetRowValue[structTableUDFT](row, 0, structTableUDFT{I: d.count}) + return true +} + +func (d *structTableUDF) GetValue(r, c int) any { + return map[string]any{ + "I": int64(r + 1), + } +} + func BenchmarkTableUDF(b *testing.B) { b.StopTimer() var err error @@ -61,7 +109,7 @@ func BenchmarkTableUDF(b *testing.B) { } defer db.Close() conn, _ := db.Conn(context.Background()) - var fun tableUDF + var fun incTableUDF RegisterTableUDF(conn, &fun) b.StartTimer() for n := 0; n < b.N; n++ { @@ -81,7 +129,7 @@ func TestTableUDF(t *testing.T) { } defer db.Close() conn, _ := db.Conn(context.Background()) - var fun tableUDF + var fun incTableUDF RegisterTableUDF(conn, &fun) rows, err := db.QueryContext(context.Background(), "SELECT * FROM whoo(2048)") if err != nil { @@ -124,3 +172,59 @@ func TestTableUDF(t *testing.T) { r++ } } + +func TestTableUDF2(t *testing.T) { + var err error + db, err := sql.Open("duckdb", "?access_mode=READ_WRITE") + if err != nil { + t.Fatal(err) + } + defer db.Close() + conn, _ := db.Conn(context.Background()) + var fun structTableUDF + RegisterTableUDF(conn, &fun) + rows, err := db.QueryContext(context.Background(), "SELECT * FROM whoo(2048)") + if err != nil { + t.Fatal(err) + } + + //TODO: check column names + columns, err := rows.Columns() + if err != nil { + t.Fatal(err) + } + + values := make([]interface{}, len(columns)) + scanArgs := make([]interface{}, len(values)) + for i := range values { + scanArgs[i] = &values[i] + } + + // Fetch rows + var r int + for rows.Next() { + err = rows.Scan(scanArgs...) + if err != nil { + panic(err.Error()) + } + for i, value := range values { + expected := fun.GetValue(r, i) + if !reflect.DeepEqual(expected, value) { + err := wrongValueError{ + rowIdx: r, + colIdx: i, + colName: columns[i], + expected: expected, + got: value, + } + t.Log(err) + t.Fail() + } + } + r++ + } +} + +func (wve wrongValueError) Error() string { + return fmt.Sprintf("Wrong value at row %d, column %d(%s): Expected %v of type %[4]T, found %v of type %[5]T", wve.rowIdx, wve.colIdx, wve.colName, wve.expected, wve.got) +}