diff --git a/appender_vector.go b/appender_vector.go index 7724a39f..5f2b960e 100644 --- a/appender_vector.go +++ b/appender_vector.go @@ -7,6 +7,7 @@ package duckdb import "C" import ( + "fmt" "reflect" "strconv" "time" @@ -76,6 +77,85 @@ func (vec *vector) tryCast(val any) (any, error) { return nil, getError(errDriver, nil) } +func tryCastInteger[S any, R numericType](val S) (R, error) { + switch v := any(val).(type) { + case uint8: + return convertNumericType[uint8, R](v), nil + case int8: + return convertNumericType[int8, R](v), nil + case uint16: + return convertNumericType[uint16, R](v), nil + case int16: + return convertNumericType[int16, R](v), nil + case uint32: + return convertNumericType[uint32, R](v), nil + case int32: + return convertNumericType[int32, R](v), nil + case uint64: + return convertNumericType[uint64, R](v), nil + case int64: + return convertNumericType[int64, R](v), nil + case uint: + return convertNumericType[uint, R](v), nil + case int: + return convertNumericType[int, R](v), nil + default: + return 0, nil + } + +} + +/* + func tryCast[T, R any](val T) (R, error) { + var x R + switch any(x).(type) { + case uint8: + r, err := tryCastInteger[T, R](val) + return R(r), err + case int8: + return convertNumericType[T, int8] + case uint16: + return convertNumericType[T, uint16] + case int16: + return convertNumericType[T, int16] + case uint32: + return convertNumericType[T, uint32] + case int32: + return convertNumericType[T, int32] + case uint64: + return convertNumericType[T, uint64] + case int64: + return convertNumericType[T, int64] + case uint: + return convertNumericType[T, uint] + case int: + return convertNumericType[T, int] + case C.DUCKDB_TYPE_BIGINT: + return tryNumericCast[int64](val, reflect.Int64.String()) + case C.DUCKDB_TYPE_FLOAT: + return tryNumericCast[float32](val, reflect.Float32.String()) + case C.DUCKDB_TYPE_DOUBLE: + return tryNumericCast[float64](val, reflect.Float64.String()) + case C.DUCKDB_TYPE_BOOLEAN: + return tryPrimitiveCast[bool](val, reflect.Bool.String()) + case C.DUCKDB_TYPE_VARCHAR: + return tryPrimitiveCast[string](val, reflect.String.String()) + case C.DUCKDB_TYPE_BLOB: + return tryPrimitiveCast[[]byte](val, reflect.TypeOf([]byte{}).String()) + case C.DUCKDB_TYPE_TIMESTAMP, C.DUCKDB_TYPE_TIMESTAMP_S, C.DUCKDB_TYPE_TIMESTAMP_MS, + C.DUCKDB_TYPE_TIMESTAMP_NS, C.DUCKDB_TYPE_TIMESTAMP_TZ: + return tryPrimitiveCast[time.Time](val, reflect.TypeOf(time.Time{}).String()) + case C.DUCKDB_TYPE_UUID: + return tryPrimitiveCast[UUID](val, reflect.TypeOf(UUID{}).String()) + case C.DUCKDB_TYPE_LIST: + return vec.tryCastList(val) + case C.DUCKDB_TYPE_STRUCT: + return vec.tryCastStruct(val) + } + + return nil, getError(errDriver, nil) + } +*/ func (*vector) canNil(val reflect.Value) bool { switch val.Kind() { case reflect.Chan, reflect.Func, reflect.Map, reflect.Pointer, @@ -432,3 +512,127 @@ func (vec *vector) initStruct(logicalType C.duckdb_logical_type) error { return nil } + +func _setVectorNumeric[S any, T numericType](vec *vector, rowIdx C.idx_t, val S) error { + var fv T + switch v := any(val).(type) { + case uint8: + fv = T(v) + case int8: + fv = T(v) + case uint16: + fv = T(v) + case int16: + fv = T(v) + case uint32: + fv = T(v) + case int32: + fv = T(v) + case uint64: + fv = T(v) + case int64: + fv = T(v) + case uint: + fv = T(v) + case int: + fv = T(v) + case float32: + fv = T(v) + case float64: + fv = T(v) + case bool: + if v { + fv = 1 + } else { + fv = 0 + } + default: + return fmt.Errorf("wrong input type") + } + + ptr := C.duckdb_vector_get_data(vec.duckdbVector) + xs := (*[1 << 31]T)(ptr) + xs[rowIdx] = fv + return nil +} + +func _setVectorBool[S any](vec *vector, rowIdx C.idx_t, val S) error { + var fv bool + switch v := any(val).(type) { + case uint8: + fv = v == 0 + case int8: + fv = v == 0 + case uint16: + fv = v == 0 + case int16: + fv = v == 0 + case uint32: + fv = v == 0 + case int32: + fv = v == 0 + case uint64: + fv = v == 0 + case int64: + fv = v == 0 + case uint: + fv = v == 0 + case int: + fv = v == 0 + case float32: + fv = v == 0 + case float64: + fv = v == 0 + case bool: + fv = v + default: + return fmt.Errorf("wrong input type") + } + + ptr := C.duckdb_vector_get_data(vec.duckdbVector) + xs := (*[1 << 31]bool)(ptr) + xs[rowIdx] = fv + return nil +} + +func setVectorVal[S any](vec *vector, rowIdx C.idx_t, val S) error { + switch vec.duckdbType { + case C.DUCKDB_TYPE_UTINYINT: + return _setVectorNumeric[S, uint8](vec, rowIdx, val) + case C.DUCKDB_TYPE_TINYINT: + return _setVectorNumeric[S, int8](vec, rowIdx, val) + case C.DUCKDB_TYPE_USMALLINT: + return _setVectorNumeric[S, uint16](vec, rowIdx, val) + case C.DUCKDB_TYPE_SMALLINT: + return _setVectorNumeric[S, int16](vec, rowIdx, val) + case C.DUCKDB_TYPE_UINTEGER: + return _setVectorNumeric[S, uint32](vec, rowIdx, val) + case C.DUCKDB_TYPE_INTEGER: + return _setVectorNumeric[S, int32](vec, rowIdx, val) + case C.DUCKDB_TYPE_UBIGINT: + return _setVectorNumeric[S, uint64](vec, rowIdx, val) + case C.DUCKDB_TYPE_BIGINT: + return _setVectorNumeric[S, int64](vec, rowIdx, val) + case C.DUCKDB_TYPE_FLOAT: + return _setVectorNumeric[S, float32](vec, rowIdx, val) + case C.DUCKDB_TYPE_DOUBLE: + return _setVectorNumeric[S, float64](vec, rowIdx, val) + case C.DUCKDB_TYPE_BOOLEAN: + return _setVectorBool[S](vec, rowIdx, val) + /* case C.DUCKDB_TYPE_VARCHAR: + return tryPrimitiveCast[string](val, reflect.String.String()) + case C.DUCKDB_TYPE_BLOB: + return tryPrimitiveCast[[]byte](val, reflect.TypeOf([]byte{}).String()) + case C.DUCKDB_TYPE_TIMESTAMP, C.DUCKDB_TYPE_TIMESTAMP_S, C.DUCKDB_TYPE_TIMESTAMP_MS, + C.DUCKDB_TYPE_TIMESTAMP_NS, C.DUCKDB_TYPE_TIMESTAMP_TZ: + return tryPrimitiveCast[time.Time](val, reflect.TypeOf(time.Time{}).String()) + case C.DUCKDB_TYPE_UUID: + return tryPrimitiveCast[UUID](val, reflect.TypeOf(UUID{}).String()) + case C.DUCKDB_TYPE_LIST: + return vec.tryCastList(val) + case C.DUCKDB_TYPE_STRUCT: + return vec.tryCastStruct(val)*/ + } + // TODO: error + return nil +} diff --git a/examples/udf/udf.go b/examples/udf/udf.go index ae5d50f9..cfa10caf 100644 --- a/examples/udf/udf.go +++ b/examples/udf/udf.go @@ -36,18 +36,21 @@ func (d *tableUDF) BindArguments(args ...interface{}) []duckdb.ColumnName { d.count = 0 d.n = args[0].(int64) return []duckdb.ColumnName{ - {"result", int64(0)}, + {Name: "result", V: int64(0)}, } } -func (d *tableUDF) GetRow() []interface{} { +func (d *tableUDF) FillRow(row duckdb.Row) bool{ + fmt.Println(d.count, d.n) if d.count > d.n { - return nil + return false } d.count++ - return []interface{}{int64(d.count)} + duckdb.SetRowValue[int64](row, 0, d.count) + return true } + func main() { var err error db, err = sql.Open("duckdb", "?access_mode=READ_WRITE") @@ -66,7 +69,7 @@ func main() { check(setting.Scan(&am)) log.Printf("DB opened with access mode %s", am) - rows, err := db.QueryContext(context.Background(), "SELECT * FROM whoo(0)") + rows, err := db.QueryContext(context.Background(), "SELECT * FROM whoo(100)") check(err) defer rows.Close() diff --git a/udf.go b/udf.go index fa8e4569..659d1463 100644 --- a/udf.go +++ b/udf.go @@ -17,28 +17,50 @@ typedef void (*callback)(duckdb_function_info, duckdb_data_chunk); // https://g import "C" import ( - "unsafe" - "reflect" "database/sql" + "reflect" + "unsafe" ) type ( + Row struct { + vectors []vector + r int + info C.duckdb_function_info + } ColumnName struct { Name string - V any + V any } TableFunction interface { GetArguments() []any BindArguments(args ...interface{}) []ColumnName - GetRow() []any + FillRow(Row) bool } ) -//var tableBinds = []func(info C.duckdb_function_info, output C.duckdb_data_chunk)([]struct{name string, type duckdb_logical_type}){} +func SetRowValue[T any](row Row, c int, val T) { + vec := row.vectors[c] + setVectorVal[T](&vec, C.ulong(row.r), val) +} -var tableFuncs = []TableFunction{} +func (row Row) SetRowValue(c int, val any) { + vec := row.vectors[c] + + // Ensure the types match before adding to the vector + v, err := vec.tryCast(val) + if err != nil { + cerr := columnError(err, c+1) + errstr := C.CString(cerr.Error()) + defer C.free(unsafe.Pointer(errstr)) + C.duckdb_function_set_error(row.info, errstr) + } + + vec.fn(&vec, C.ulong(row.r), v) +} +var tableFuncs = []TableFunction{} //export udf_bind func udf_bind(info C.duckdb_bind_info) { @@ -81,16 +103,17 @@ func udf_callback(info C.duckdb_function_info, output C.duckdb_data_chunk) { tfunc := tableFuncs[*(*int)(extra_info)] columnCount := int(C.duckdb_data_chunk_get_column_count(output)) - vectors := make([]vector, columnCount) + var row Row + row.vectors = make([]vector, columnCount) var err error for i := 0; i < columnCount; i++ { duckdbVector := C.duckdb_data_chunk_get_vector(output, C.ulong(i)) t := C.duckdb_vector_get_column_type(duckdbVector) - if err = vectors[i].init(t, i); err != nil{ + if err = row.vectors[i].init(t, i); err != nil { break } - vectors[i].duckdbVector = duckdbVector - vectors[i].getChildVectors(duckdbVector) + row.vectors[i].duckdbVector = duckdbVector + row.vectors[i].getChildVectors(duckdbVector) } if err != nil { errstr := C.CString(err.Error()) @@ -98,37 +121,26 @@ func udf_callback(info C.duckdb_function_info, output C.duckdb_data_chunk) { C.duckdb_function_set_error(info, errstr) return } - + maxSize := int(C.duckdb_vector_size()) - for i := 0; i < maxSize; i++ { - nextResults := tfunc.GetRow() - if nextResults == nil { + // At the end of the loop row.r must be the index one past the last added row + for row.r = 0; row.r < maxSize; row.r++ { + nextResults := tfunc.FillRow(row) + if !nextResults { break } - for j, val := range nextResults { - vec := vectors[j] - - // Ensure the types match before adding to the vector - v, err := vec.tryCast(val) - if err != nil { - cerr := columnError(err, j+1) - errstr := C.CString(cerr.Error()) - defer C.free(unsafe.Pointer(errstr)) - C.duckdb_function_set_error(info, errstr) - } - vec.fn(&vec, C.ulong(i), v) - C.duckdb_data_chunk_set_size(output, C.ulong(i+1)) - } } + // since row.r points to one past the last value, it is also the size + C.duckdb_data_chunk_set_size(output, C.ulong(row.r)) } func RegisterTableUDF(c *sql.Conn, function TableFunction) error { - err := c.Raw(func(dconn any) error{ + err := c.Raw(func(dconn any) error { ddconn := dconn.(*conn) name := C.CString("whoo") defer C.free(unsafe.Pointer(name)) - extra_info := C.malloc(C.ulong(1*unsafe.Sizeof(int(0)))) + extra_info := C.malloc(C.ulong(1 * unsafe.Sizeof(int(0)))) *(*int)(extra_info) = len(tableFuncs) tableFuncs = append(tableFuncs, function) @@ -138,7 +150,7 @@ func RegisterTableUDF(c *sql.Conn, function TableFunction) error { C.duckdb_table_function_set_init(tableFunction, C.init(C.udf_init)) C.duckdb_table_function_set_function(tableFunction, C.callback(C.udf_callback)) C.duckdb_table_function_set_extra_info(tableFunction, extra_info, C.duckdb_delete_callback_t(C.free)) - + argumentvalues := function.GetArguments() for _, v := range argumentvalues { @@ -164,13 +176,13 @@ func getDuckdbType[T any]() (C.duckdb_logical_type, error) { } func getDuckdbTypeFromValue(v any) (C.duckdb_logical_type, error) { - switch v.(type){ + switch v.(type) { case int64: return C.duckdb_create_logical_type(C.DUCKDB_TYPE_BIGINT), nil case string: return C.duckdb_create_logical_type(C.DUCKDB_TYPE_VARCHAR), nil default: - return C.duckdb_logical_type(nil), unsupportedTypeError(reflect.TypeOf(v).String()) + return C.duckdb_logical_type(nil), unsupportedTypeError(reflect.TypeOf(v).String()) } } @@ -184,6 +196,6 @@ func getValue(t C.duckdb_type, v C.duckdb_value) (any, error) { C.duckdb_free(unsafe.Pointer(str)) return ret, nil default: - return nil, unsupportedTypeError(reflect.TypeOf(v).String()) + return nil, unsupportedTypeError(reflect.TypeOf(v).String()) } }