diff --git a/appender.go b/appender.go index fa611903..31c264a6 100644 --- a/appender.go +++ b/appender.go @@ -1,7 +1,6 @@ package duckdb /* -#include #include */ import "C" @@ -43,11 +42,11 @@ func NewAppenderFromConn(driverConn driver.Conn, schema, table string) (*Appende var cSchema *C.char if schema != "" { cSchema = C.CString(schema) - defer C.free(unsafe.Pointer(cSchema)) + defer C.duckdb_free(unsafe.Pointer(cSchema)) } cTable := C.CString(table) - defer C.free(unsafe.Pointer(cTable)) + defer C.duckdb_free(unsafe.Pointer(cTable)) var duckdbAppender C.duckdb_appender state := C.duckdb_appender_create(con.duckdbCon, cSchema, cTable, &duckdbAppender) @@ -74,10 +73,10 @@ func NewAppenderFromConn(driverConn driver.Conn, schema, table string) (*Appende a.types[i] = C.duckdb_appender_column_type(duckdbAppender, C.idx_t(i)) // Ensure that we only create an appender for supported column types. - duckdbType := C.duckdb_get_type_id(a.types[i]) - name, found := unsupportedTypeMap[duckdbType] + t := Type(C.duckdb_get_type_id(a.types[i])) + name, found := unsupportedTypeToStringMap[t] if found { - err := columnError(unsupportedTypeError(name), i+1) + err := addIndexToError(unsupportedTypeError(name), i+1) destroyTypeSlice(a.ptr, a.types) C.duckdb_appender_destroy(&duckdbAppender) return nil, getError(errAppenderCreation, err) @@ -229,5 +228,5 @@ func destroyTypeSlice(ptr unsafe.Pointer, slice []C.duckdb_logical_type) { for _, t := range slice { C.duckdb_destroy_logical_type(&t) } - C.free(ptr) + C.duckdb_free(ptr) } diff --git a/connection.go b/connection.go index a5e250f1..0f5eea3f 100644 --- a/connection.go +++ b/connection.go @@ -1,7 +1,6 @@ package duckdb /* -#include #include */ import "C" @@ -152,11 +151,11 @@ func (c *conn) Close() error { } func (c *conn) prepareStmt(cmd string) (*stmt, error) { - cmdstr := C.CString(cmd) - defer C.free(unsafe.Pointer(cmdstr)) + cmdStr := C.CString(cmd) + defer C.duckdb_free(unsafe.Pointer(cmdStr)) var s C.duckdb_prepared_statement 
- if state := C.duckdb_prepare(c.duckdbCon, cmdstr, &s); state == C.DuckDBError { + if state := C.duckdb_prepare(c.duckdbCon, cmdStr, &s); state == C.DuckDBError { dbErr := getDuckDBError(C.GoString(C.duckdb_prepare_error(s))) C.duckdb_destroy_prepare(&s) return nil, dbErr @@ -166,11 +165,11 @@ func (c *conn) prepareStmt(cmd string) (*stmt, error) { } func (c *conn) extractStmts(query string) (C.duckdb_extracted_statements, C.idx_t, error) { - cquery := C.CString(query) - defer C.free(unsafe.Pointer(cquery)) + cQuery := C.CString(query) + defer C.duckdb_free(unsafe.Pointer(cQuery)) var stmts C.duckdb_extracted_statements - stmtsCount := C.duckdb_extract_statements(c.duckdbCon, cquery, &stmts) + stmtsCount := C.duckdb_extract_statements(c.duckdbCon, cQuery, &stmts) if stmtsCount == 0 { err := C.GoString(C.duckdb_extract_statements_error(stmts)) C.duckdb_destroy_extracted(&stmts) diff --git a/data_chunk.go b/data_chunk.go index 03d3cdc3..e7d8082f 100644 --- a/data_chunk.go +++ b/data_chunk.go @@ -1,7 +1,6 @@ package duckdb /* -#include #include */ import "C" @@ -66,7 +65,7 @@ func (chunk *DataChunk) SetValue(colIdx int, rowIdx int, val any) error { // FIXME: Maybe we can make columnar insertions unsafe, i.e., we always assume a correct type. v, err := column.tryCast(val) if err != nil { - return columnError(err, colIdx) + return addIndexToError(err, colIdx) } // Set the value. 
diff --git a/duckdb.go b/duckdb.go index 6ca46ce8..050fdc62 100644 --- a/duckdb.go +++ b/duckdb.go @@ -5,7 +5,6 @@ package duckdb /* -#include #include */ import "C" @@ -58,7 +57,7 @@ func NewConnector(dsn string, connInitFn func(execer driver.ExecerContext) error defer C.duckdb_destroy_config(&config) connStr := C.CString(getConnString(dsn)) - defer C.free(unsafe.Pointer(connStr)) + defer C.duckdb_free(unsafe.Pointer(connStr)) var outError *C.char defer C.duckdb_free(unsafe.Pointer(outError)) @@ -143,10 +142,10 @@ func prepareConfig(parsedDSN *url.URL) (C.duckdb_config, error) { func setConfigOption(config C.duckdb_config, name string, option string) error { cName := C.CString(name) - defer C.free(unsafe.Pointer(cName)) + defer C.duckdb_free(unsafe.Pointer(cName)) cOption := C.CString(option) - defer C.free(unsafe.Pointer(cOption)) + defer C.duckdb_free(unsafe.Pointer(cOption)) state := C.duckdb_set_config(config, cName, cOption) if state == C.DuckDBError { diff --git a/duckdb_test.go b/duckdb_test.go index 5f677fb6..44d72662 100644 --- a/duckdb_test.go +++ b/duckdb_test.go @@ -11,7 +11,6 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -256,7 +255,7 @@ func TestQuery(t *testing.T) { for rows.Next() { var i int require.NoError(t, rows.Scan(&i)) - assert.Equal(t, expected, i) + require.Equal(t, expected, i) expected++ } }) @@ -281,10 +280,8 @@ func TestQuery(t *testing.T) { func TestJSON(t *testing.T) { t.Parallel() db := openDB(t) - defer db.Close() loadJSONExt(t, db) - var data string t.Run("select empty JSON", func(t *testing.T) { @@ -311,18 +308,20 @@ func TestJSON(t *testing.T) { require.Equal(t, len(items), 2) require.Equal(t, items, []string{"foo", "bar"}) }) + + require.NoError(t, db.Close()) } func TestEmpty(t *testing.T) { t.Parallel() db := openDB(t) - defer db.Close() rows, err := db.Query(`SELECT 1 WHERE 1 = 0`) require.NoError(t, err) defer rows.Close() require.False(t, rows.Next()) 
require.NoError(t, rows.Err()) + require.NoError(t, db.Close()) } func TestTypeNamesAndScanTypes(t *testing.T) { @@ -519,12 +518,12 @@ func TestTypeNamesAndScanTypes(t *testing.T) { require.Equal(t, rows.Next(), false) }) } + require.NoError(t, db.Close()) } // Running multiple statements in a single query. All statements except the last one are executed and if no error then last statement is executed with args and result returned. func TestMultipleStatements(t *testing.T) { db := openDB(t) - defer db.Close() // test empty query _, err := db.Exec("") @@ -617,13 +616,12 @@ func TestMultipleStatements(t *testing.T) { err = rows.Close() require.NoError(t, err) - err = conn.Close() - require.NoError(t, err) + require.NoError(t, conn.Close()) + require.NoError(t, db.Close()) } func TestParquetExtension(t *testing.T) { db := openDB(t) - defer db.Close() _, err := db.Exec("CREATE TABLE users (id int, name varchar, age int);") require.NoError(t, err) @@ -647,11 +645,11 @@ func TestParquetExtension(t *testing.T) { err = os.Remove("./users.parquet") require.NoError(t, err) + require.NoError(t, db.Close()) } func TestQueryTimeout(t *testing.T) { db := openDB(t) - defer db.Close() ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*250) defer cancel() @@ -663,6 +661,7 @@ func TestQueryTimeout(t *testing.T) { // a very defensive time check, but should be good enough // the query takes much longer than 10 seconds require.Less(t, time.Since(now), 10*time.Second) + require.NoError(t, db.Close()) } func openDB(t *testing.T) *sql.DB { diff --git a/errors.go b/errors.go index 126b18fa..ea4f9721 100644 --- a/errors.go +++ b/errors.go @@ -26,10 +26,6 @@ func structFieldError(actual string, expected string) error { return fmt.Errorf("%s: expected %s, got %s", structFieldErrMsg, expected, actual) } -func columnError(err error, colIdx int) error { - return fmt.Errorf("%w: %s: %d", err, columnErrMsg, colIdx) -} - func columnCountError(actual int, expected int) error { 
return fmt.Errorf("%s: expected %d, got %d", columnCountErrMsg, expected, actual) } @@ -45,15 +41,35 @@ func invalidatedAppenderError(err error) error { return fmt.Errorf("%w: %s", err, invalidatedAppenderMsg) } +func tryOtherFuncError(hint string) error { + return fmt.Errorf("%s: %s", tryOtherFuncErrMsg, hint) +} + +func addIndexToError(err error, idx int) error { + return fmt.Errorf("%w: %s: %d", err, indexErrMsg, idx) +} + +func interfaceIsNilError(interfaceName string) error { + return fmt.Errorf("%s: %s", interfaceIsNilErrMsg, interfaceName) +} + +func duplicateNameError(name string) error { + return fmt.Errorf("%s: %s", duplicateNameErrMsg, name) +} + const ( driverErrMsg = "database/sql/driver" duckdbErrMsg = "duckdb error" castErrMsg = "cast error" structFieldErrMsg = "invalid STRUCT field" - columnErrMsg = "column index" columnCountErrMsg = "invalid column count" unsupportedTypeErrMsg = "unsupported data type" invalidatedAppenderMsg = "appended data has been invalidated due to corrupt row" + tryOtherFuncErrMsg = "please try this function instead" + indexErrMsg = "index" + unknownTypeErrMsg = "unknown type" + interfaceIsNilErrMsg = "interface is nil" + duplicateNameErrMsg = "duplicate name" ) var ( @@ -75,6 +91,10 @@ var ( errAppenderClose = errors.New("could not close appender") errAppenderFlush = errors.New("could not flush appender") + errEmptyName = errors.New("empty name") + errInvalidDecimalWidth = fmt.Errorf("the DECIMAL width must be between 1 and %d", MAX_DECIMAL_WIDTH) + errInvalidDecimalScale = errors.New("the DECIMAL scale must be less than or equal to the width") + errProfilingInfoEmpty = errors.New("no profiling information available for this connection") // Errors not covered in tests. 
diff --git a/errors_test.go b/errors_test.go index 430e6503..a4b2d88b 100644 --- a/errors_test.go +++ b/errors_test.go @@ -297,6 +297,8 @@ func TestErrAppendNestedList(t *testing.T) { } func TestErrAPISetValue(t *testing.T) { + t.Parallel() + var chunk DataChunk err := chunk.SetValue(1, 42, "hello") testError(t, err, errAPI.Error(), columnCountErrMsg) @@ -341,9 +343,8 @@ func TestErrProfiling(t *testing.T) { func TestDuckDBErrors(t *testing.T) { db := openDB(t) - defer db.Close() - createTable(db, t, `CREATE TABLE duckdberror_test(bar VARCHAR UNIQUE, baz INT32, u_1 UNION("string" VARCHAR))`) - _, err := db.Exec("INSERT INTO duckdberror_test(bar, baz) VALUES('bar', 0)") + createTable(db, t, `CREATE TABLE duckdb_error_test(bar VARCHAR UNIQUE, baz INT32, u_1 UNION("string" VARCHAR))`) + _, err := db.Exec(`INSERT INTO duckdb_error_test(bar, baz) VALUES ('bar', 0)`) require.NoError(t, err) testCases := []struct { @@ -351,62 +352,64 @@ func TestDuckDBErrors(t *testing.T) { errTyp ErrorType }{ { - tpl: "SELECT * FROM not_exist WHERE baz=0", + tpl: `SELECT * FROM not_exist WHERE baz=0`, errTyp: ErrorTypeCatalog, }, { - tpl: "SELECT * FROM duckdberror_test WHERE col=?", + tpl: `SELECT * FROM duckdb_error_test WHERE col=?`, errTyp: ErrorTypeBinder, }, { - tpl: "SELEC * FROM duckdberror_test baz=0", + tpl: `SELEC * FROM duckdb_error_test baz=0`, errTyp: ErrorTypeParser, }, { - tpl: "INSERT INTO duckdberror_test(bar, baz) VALUES('bar', 1)", + tpl: `INSERT INTO duckdb_error_test(bar, baz) VALUES ('bar', 1)`, errTyp: ErrorTypeConstraint, }, { - tpl: "INSERT INTO duckdberror_test(bar, baz) VALUES('foo', 18446744073709551615)", + tpl: `INSERT INTO duckdb_error_test(bar, baz) VALUES ('foo', 18446744073709551615)`, errTyp: ErrorTypeConversion, }, { - tpl: "INSTALL not_exist", + tpl: `INSTALL not_exist`, errTyp: ErrorTypeHTTP, }, { - tpl: "LOAD not_exist", + tpl: `LOAD not_exist`, errTyp: ErrorTypeIO, }, { - tpl: "SELECT array_length(array_value(array_value(1, 2, 2), array_value(3, 
4, 3)), 3)", + tpl: `SELECT array_length(array_value(array_value(1, 2, 2), array_value(3, 4, 3)), 3)`, errTyp: ErrorTypeOutOfRange, }, { - tpl: "SELECT '010110'::BIT & '11000'::BIT", + tpl: `SELECT '010110'::BIT & '11000'::BIT`, errTyp: ErrorTypeInvalidInput, }, { - tpl: "SET external_threads=-1", + tpl: `SET external_threads=-1`, errTyp: ErrorTypeSyntax, }, { - tpl: "CREATE UNIQUE INDEX idx ON duckdberror_test(u_1)", + tpl: `CREATE UNIQUE INDEX idx ON duckdb_error_test(u_1)`, errTyp: ErrorTypeInvalidType, }, } for _, tc := range testCases { - _, err := db.Exec(tc.tpl) - de, ok := err.(*Error) + _, err = db.Exec(tc.tpl) + var de *Error + ok := errors.As(err, &de) if !ok { require.Fail(t, "error type is not (*duckdb.Error)", "tql: %s\ngot: %#v", tc.tpl, err) } require.Equal(t, de.Type, tc.errTyp, "tpl: %s\nactual error msg: %s", tc.tpl, de.Msg) } + + require.NoError(t, db.Close()) } -func TestGetDuckDBError(t *testing.T) { - // only for the corner cases +func TestDuckDBErrorsCornerCases(t *testing.T) { testCases := []*Error{ { Msg: "", @@ -420,7 +423,7 @@ func TestGetDuckDBError(t *testing.T) { Msg: "Error: xxx", Type: ErrorTypeUnknownType, }, - // next 3 cases for the prefix testing + // Prefix testing. 
{ Msg: "Invalid Error: xxx", Type: ErrorTypeInvalid, @@ -436,7 +439,8 @@ func TestGetDuckDBError(t *testing.T) { } for _, tc := range testCases { - err := getDuckDBError(tc.Msg).(*Error) + var err *Error + errors.As(getDuckDBError(tc.Msg), &err) require.Equal(t, tc, err) } } diff --git a/replacement_scan.go b/replacement_scan.go index cd960237..6a1e55d1 100644 --- a/replacement_scan.go +++ b/replacement_scan.go @@ -1,7 +1,6 @@ package duckdb /* - #include #include void replacement_scan_cb(duckdb_replacement_scan_info info, const char *table_name, void *data); typedef const char cchar_t; @@ -33,15 +32,15 @@ func replacement_scan_cb(info C.duckdb_replacement_scan_info, table_name *C.ccha scanner := h.Value().(ReplacementScanCallback) tFunc, params, err := scanner(C.GoString(table_name)) if err != nil { - errstr := C.CString(err.Error()) - C.duckdb_replacement_scan_set_error(info, errstr) - C.free(unsafe.Pointer(errstr)) + errStr := C.CString(err.Error()) + C.duckdb_replacement_scan_set_error(info, errStr) + C.duckdb_free(unsafe.Pointer(errStr)) return } fNameStr := C.CString(tFunc) C.duckdb_replacement_scan_set_function_name(info, fNameStr) - defer C.free(unsafe.Pointer(fNameStr)) + defer C.duckdb_free(unsafe.Pointer(fNameStr)) for _, v := range params { switch x := v.(type) { @@ -49,16 +48,16 @@ func replacement_scan_cb(info C.duckdb_replacement_scan_info, table_name *C.ccha str := C.CString(x) val := C.duckdb_create_varchar(str) C.duckdb_replacement_scan_add_parameter(info, val) - C.free(unsafe.Pointer(str)) + C.duckdb_free(unsafe.Pointer(str)) C.duckdb_destroy_value(&val) case int64: val := C.duckdb_create_int64(C.int64_t(x)) C.duckdb_replacement_scan_add_parameter(info, val) C.duckdb_destroy_value(&val) default: - errstr := C.CString("invalid type") - C.duckdb_replacement_scan_set_error(info, errstr) - C.free(unsafe.Pointer(errstr)) + errStr := C.CString("invalid type") + C.duckdb_replacement_scan_set_error(info, errStr) + C.duckdb_free(unsafe.Pointer(errStr)) 
return } } diff --git a/rows.go b/rows.go index a2b3d1a4..876fab81 100644 --- a/rows.go +++ b/rows.go @@ -81,93 +81,70 @@ func (r *rows) Next(dst []driver.Value) error { return nil } -// Implements driver.RowsColumnTypeScanType +// ColumnTypeScanType implements driver.RowsColumnTypeScanType. func (r *rows) ColumnTypeScanType(index int) reflect.Type { - colType := C.duckdb_column_type(&r.res, C.idx_t(index)) - switch colType { - case C.DUCKDB_TYPE_INVALID: + t := Type(C.duckdb_column_type(&r.res, C.idx_t(index))) + switch t { + case TYPE_INVALID: return nil - case C.DUCKDB_TYPE_BOOLEAN: + case TYPE_BOOLEAN: return reflect.TypeOf(true) - case C.DUCKDB_TYPE_TINYINT: + case TYPE_TINYINT: return reflect.TypeOf(int8(0)) - case C.DUCKDB_TYPE_SMALLINT: + case TYPE_SMALLINT: return reflect.TypeOf(int16(0)) - case C.DUCKDB_TYPE_INTEGER: + case TYPE_INTEGER: return reflect.TypeOf(int32(0)) - case C.DUCKDB_TYPE_BIGINT: + case TYPE_BIGINT: return reflect.TypeOf(int64(0)) - case C.DUCKDB_TYPE_UTINYINT: + case TYPE_UTINYINT: return reflect.TypeOf(uint8(0)) - case C.DUCKDB_TYPE_USMALLINT: + case TYPE_USMALLINT: return reflect.TypeOf(uint16(0)) - case C.DUCKDB_TYPE_UINTEGER: + case TYPE_UINTEGER: return reflect.TypeOf(uint32(0)) - case C.DUCKDB_TYPE_UBIGINT: + case TYPE_UBIGINT: return reflect.TypeOf(uint64(0)) - case C.DUCKDB_TYPE_FLOAT: + case TYPE_FLOAT: return reflect.TypeOf(float32(0)) - case C.DUCKDB_TYPE_DOUBLE: + case TYPE_DOUBLE: return reflect.TypeOf(float64(0)) - case C.DUCKDB_TYPE_TIMESTAMP: - return reflect.TypeOf(time.Time{}) - case C.DUCKDB_TYPE_DATE: - return reflect.TypeOf(time.Time{}) - case C.DUCKDB_TYPE_TIME: + case TYPE_TIMESTAMP, TYPE_TIMESTAMP_S, TYPE_TIMESTAMP_MS, TYPE_TIMESTAMP_NS, TYPE_DATE, TYPE_TIME, TYPE_TIMESTAMP_TZ: return reflect.TypeOf(time.Time{}) - case C.DUCKDB_TYPE_INTERVAL: + case TYPE_INTERVAL: return reflect.TypeOf(Interval{}) - case C.DUCKDB_TYPE_HUGEINT: + case TYPE_HUGEINT: return reflect.TypeOf(big.NewInt(0)) - case C.DUCKDB_TYPE_VARCHAR: 
+ case TYPE_VARCHAR, TYPE_ENUM: return reflect.TypeOf("") - case C.DUCKDB_TYPE_ENUM: - return reflect.TypeOf("") - case C.DUCKDB_TYPE_BLOB: + case TYPE_BLOB: return reflect.TypeOf([]byte{}) - case C.DUCKDB_TYPE_DECIMAL: + case TYPE_DECIMAL: return reflect.TypeOf(Decimal{}) - case C.DUCKDB_TYPE_TIMESTAMP_S: - return reflect.TypeOf(time.Time{}) - case C.DUCKDB_TYPE_TIMESTAMP_MS: - return reflect.TypeOf(time.Time{}) - case C.DUCKDB_TYPE_TIMESTAMP_NS: - return reflect.TypeOf(time.Time{}) - case C.DUCKDB_TYPE_LIST: + case TYPE_LIST: return reflect.TypeOf([]any{}) - case C.DUCKDB_TYPE_STRUCT: + case TYPE_STRUCT: return reflect.TypeOf(map[string]any{}) - case C.DUCKDB_TYPE_MAP: + case TYPE_MAP: return reflect.TypeOf(Map{}) - case C.DUCKDB_TYPE_UUID: + case TYPE_UUID: return reflect.TypeOf([]byte{}) - case C.DUCKDB_TYPE_TIMESTAMP_TZ: - return reflect.TypeOf(time.Time{}) default: return nil } } -// Implements driver.RowsColumnTypeScanType +// ColumnTypeDatabaseTypeName implements driver.RowsColumnTypeDatabaseTypeName. func (r *rows) ColumnTypeDatabaseTypeName(index int) string { - // Only allocate logical type if necessary - colType := C.duckdb_column_type(&r.res, C.idx_t(index)) - switch colType { - case C.DUCKDB_TYPE_DECIMAL: - fallthrough - case C.DUCKDB_TYPE_ENUM: - fallthrough - case C.DUCKDB_TYPE_LIST: - fallthrough - case C.DUCKDB_TYPE_STRUCT: - fallthrough - case C.DUCKDB_TYPE_MAP: - logColType := C.duckdb_column_logical_type(&r.res, C.idx_t(index)) - defer C.duckdb_destroy_logical_type(&logColType) - return logicalTypeName(logColType) + t := Type(C.duckdb_column_type(&r.res, C.idx_t(index))) + switch t { + case TYPE_DECIMAL, TYPE_ENUM, TYPE_LIST, TYPE_STRUCT, TYPE_MAP: + // Only allocate the logical type if necessary. 
+ logicalType := C.duckdb_column_logical_type(&r.res, C.idx_t(index)) + defer C.duckdb_destroy_logical_type(&logicalType) + return logicalTypeName(logicalType) default: - // Handle as primitive type - return duckdbTypeMap[colType] + return typeToStringMap[t] } } @@ -186,62 +163,61 @@ func (r *rows) Close() error { return err } -func logicalTypeName(lt C.duckdb_logical_type) string { - t := C.duckdb_get_type_id(lt) +func logicalTypeName(logicalType C.duckdb_logical_type) string { + t := Type(C.duckdb_get_type_id(logicalType)) switch t { - case C.DUCKDB_TYPE_DECIMAL: - width := C.duckdb_decimal_width(lt) - scale := C.duckdb_decimal_scale(lt) + case TYPE_DECIMAL: + width := C.duckdb_decimal_width(logicalType) + scale := C.duckdb_decimal_scale(logicalType) return fmt.Sprintf("DECIMAL(%d,%d)", width, scale) - case C.DUCKDB_TYPE_ENUM: - // C API does not currently expose enum name + case TYPE_ENUM: + // The C API does not expose ENUM names. return "ENUM" - case C.DUCKDB_TYPE_LIST: - clt := C.duckdb_list_type_child_type(lt) - defer C.duckdb_destroy_logical_type(&clt) - return logicalTypeName(clt) + "[]" - case C.DUCKDB_TYPE_STRUCT: - return logicalTypeNameStruct(lt) - case C.DUCKDB_TYPE_MAP: - return logicalTypeNameMap(lt) + case TYPE_LIST: + childType := C.duckdb_list_type_child_type(logicalType) + defer C.duckdb_destroy_logical_type(&childType) + return logicalTypeName(childType) + "[]" + case TYPE_STRUCT: + return logicalTypeNameStruct(logicalType) + case TYPE_MAP: + return logicalTypeNameMap(logicalType) default: - return duckdbTypeMap[t] + return typeToStringMap[t] } } -func logicalTypeNameStruct(lt C.duckdb_logical_type) string { - count := int(C.duckdb_struct_type_child_count(lt)) +func logicalTypeNameStruct(logicalType C.duckdb_logical_type) string { + count := int(C.duckdb_struct_type_child_count(logicalType)) name := "STRUCT(" + for i := 0; i < count; i++ { - ptrToChildName := C.duckdb_struct_type_child_name(lt, C.idx_t(i)) + ptrToChildName := 
C.duckdb_struct_type_child_name(logicalType, C.idx_t(i)) childName := C.GoString(ptrToChildName) - childLogicalType := C.duckdb_struct_type_child_type(lt, C.idx_t(i)) + childType := C.duckdb_struct_type_child_type(logicalType, C.idx_t(i)) - // Add comma if not at end of list - name += escapeStructFieldName(childName) + " " + logicalTypeName(childLogicalType) + // Add comma if not at the end of the list. + name += escapeStructFieldName(childName) + " " + logicalTypeName(childType) if i != count-1 { name += ", " } C.duckdb_free(unsafe.Pointer(ptrToChildName)) - C.duckdb_destroy_logical_type(&childLogicalType) + C.duckdb_destroy_logical_type(&childType) } return name + ")" } -func logicalTypeNameMap(lt C.duckdb_logical_type) string { - // Key logical type - klt := C.duckdb_map_type_key_type(lt) - defer C.duckdb_destroy_logical_type(&klt) +func logicalTypeNameMap(logicalType C.duckdb_logical_type) string { + keyType := C.duckdb_map_type_key_type(logicalType) + defer C.duckdb_destroy_logical_type(&keyType) - // Value logical type - vlt := C.duckdb_map_type_value_type(lt) - defer C.duckdb_destroy_logical_type(&vlt) + valueType := C.duckdb_map_type_value_type(logicalType) + defer C.duckdb_destroy_logical_type(&valueType) - return fmt.Sprintf("MAP(%s, %s)", logicalTypeName(klt), logicalTypeName(vlt)) + return fmt.Sprintf("MAP(%s, %s)", logicalTypeName(keyType), logicalTypeName(valueType)) } -// DuckDB escapes struct field names by doubling double quotes, then wrapping in double quotes. func escapeStructFieldName(s string) string { + // DuckDB escapes STRUCT field names by doubling double quotes, then wrapping in double quotes. 
return `"` + strings.ReplaceAll(s, `"`, `""`) + `"` } diff --git a/statement.go b/statement.go index 0de5636a..dce7d7ed 100644 --- a/statement.go +++ b/statement.go @@ -1,7 +1,6 @@ package duckdb /* -#include #include */ import "C" @@ -135,18 +134,18 @@ func (s *stmt) bind(args []driver.NamedValue) error { case string: val := C.CString(v) if rv := C.duckdb_bind_varchar(*s.stmt, C.idx_t(i+1), val); rv == C.DuckDBError { - C.free(unsafe.Pointer(val)) + C.duckdb_free(unsafe.Pointer(val)) return errCouldNotBind } - C.free(unsafe.Pointer(val)) + C.duckdb_free(unsafe.Pointer(val)) case []byte: val := C.CBytes(v) l := len(v) if rv := C.duckdb_bind_blob(*s.stmt, C.idx_t(i+1), val, C.uint64_t(l)); rv == C.DuckDBError { - C.free(unsafe.Pointer(val)) + C.duckdb_free(unsafe.Pointer(val)) return errCouldNotBind } - C.free(unsafe.Pointer(val)) + C.duckdb_free(unsafe.Pointer(val)) case time.Time: val := C.duckdb_timestamp{ micros: C.int64_t(v.UTC().UnixMicro()), diff --git a/type.go b/type.go new file mode 100644 index 00000000..89acdae6 --- /dev/null +++ b/type.go @@ -0,0 +1,99 @@ +package duckdb + +/* +#include +*/ +import "C" + +// Type wraps the corresponding DuckDB type enum. 
+type Type C.duckdb_type + +const ( + TYPE_INVALID Type = C.DUCKDB_TYPE_INVALID + TYPE_BOOLEAN Type = C.DUCKDB_TYPE_BOOLEAN + TYPE_TINYINT Type = C.DUCKDB_TYPE_TINYINT + TYPE_SMALLINT Type = C.DUCKDB_TYPE_SMALLINT + TYPE_INTEGER Type = C.DUCKDB_TYPE_INTEGER + TYPE_BIGINT Type = C.DUCKDB_TYPE_BIGINT + TYPE_UTINYINT Type = C.DUCKDB_TYPE_UTINYINT + TYPE_USMALLINT Type = C.DUCKDB_TYPE_USMALLINT + TYPE_UINTEGER Type = C.DUCKDB_TYPE_UINTEGER + TYPE_UBIGINT Type = C.DUCKDB_TYPE_UBIGINT + TYPE_FLOAT Type = C.DUCKDB_TYPE_FLOAT + TYPE_DOUBLE Type = C.DUCKDB_TYPE_DOUBLE + TYPE_TIMESTAMP Type = C.DUCKDB_TYPE_TIMESTAMP + TYPE_DATE Type = C.DUCKDB_TYPE_DATE + TYPE_TIME Type = C.DUCKDB_TYPE_TIME + TYPE_INTERVAL Type = C.DUCKDB_TYPE_INTERVAL + TYPE_HUGEINT Type = C.DUCKDB_TYPE_HUGEINT + TYPE_UHUGEINT Type = C.DUCKDB_TYPE_UHUGEINT + TYPE_VARCHAR Type = C.DUCKDB_TYPE_VARCHAR + TYPE_BLOB Type = C.DUCKDB_TYPE_BLOB + TYPE_DECIMAL Type = C.DUCKDB_TYPE_DECIMAL + TYPE_TIMESTAMP_S Type = C.DUCKDB_TYPE_TIMESTAMP_S + TYPE_TIMESTAMP_MS Type = C.DUCKDB_TYPE_TIMESTAMP_MS + TYPE_TIMESTAMP_NS Type = C.DUCKDB_TYPE_TIMESTAMP_NS + TYPE_ENUM Type = C.DUCKDB_TYPE_ENUM + TYPE_LIST Type = C.DUCKDB_TYPE_LIST + TYPE_STRUCT Type = C.DUCKDB_TYPE_STRUCT + TYPE_MAP Type = C.DUCKDB_TYPE_MAP + TYPE_ARRAY Type = C.DUCKDB_TYPE_ARRAY + TYPE_UUID Type = C.DUCKDB_TYPE_UUID + TYPE_UNION Type = C.DUCKDB_TYPE_UNION + TYPE_BIT Type = C.DUCKDB_TYPE_BIT + TYPE_TIME_TZ Type = C.DUCKDB_TYPE_TIME_TZ + TYPE_TIMESTAMP_TZ Type = C.DUCKDB_TYPE_TIMESTAMP_TZ + TYPE_ANY Type = C.DUCKDB_TYPE_ANY + TYPE_VARINT Type = C.DUCKDB_TYPE_VARINT +) + +// FIXME: Implement support for these types. 
+var unsupportedTypeToStringMap = map[Type]string{ + TYPE_INVALID: "INVALID", + TYPE_UHUGEINT: "UHUGEINT", + TYPE_ARRAY: "ARRAY", + TYPE_UNION: "UNION", + TYPE_BIT: "BIT", + TYPE_TIME_TZ: "TIME_TZ", + TYPE_ANY: "ANY", + TYPE_VARINT: "VARINT", +} + +var typeToStringMap = map[Type]string{ + TYPE_INVALID: "INVALID", + TYPE_BOOLEAN: "BOOLEAN", + TYPE_TINYINT: "TINYINT", + TYPE_SMALLINT: "SMALLINT", + TYPE_INTEGER: "INTEGER", + TYPE_BIGINT: "BIGINT", + TYPE_UTINYINT: "UTINYINT", + TYPE_USMALLINT: "USMALLINT", + TYPE_UINTEGER: "UINTEGER", + TYPE_UBIGINT: "UBIGINT", + TYPE_FLOAT: "FLOAT", + TYPE_DOUBLE: "DOUBLE", + TYPE_TIMESTAMP: "TIMESTAMP", + TYPE_DATE: "DATE", + TYPE_TIME: "TIME", + TYPE_INTERVAL: "INTERVAL", + TYPE_HUGEINT: "HUGEINT", + TYPE_UHUGEINT: "UHUGEINT", + TYPE_VARCHAR: "VARCHAR", + TYPE_BLOB: "BLOB", + TYPE_DECIMAL: "DECIMAL", + TYPE_TIMESTAMP_S: "TIMESTAMP_S", + TYPE_TIMESTAMP_MS: "TIMESTAMP_MS", + TYPE_TIMESTAMP_NS: "TIMESTAMP_NS", + TYPE_ENUM: "ENUM", + TYPE_LIST: "LIST", + TYPE_STRUCT: "STRUCT", + TYPE_MAP: "MAP", + TYPE_ARRAY: "ARRAY", + TYPE_UUID: "UUID", + TYPE_UNION: "UNION", + TYPE_BIT: "BIT", + TYPE_TIME_TZ: "TIMETZ", + TYPE_TIMESTAMP_TZ: "TIMESTAMPTZ", + TYPE_ANY: "ANY", + TYPE_VARINT: "VARINT", +} diff --git a/type_info.go b/type_info.go new file mode 100644 index 00000000..1aff30ea --- /dev/null +++ b/type_info.go @@ -0,0 +1,310 @@ +package duckdb + +/* +#include +*/ +import "C" + +import ( + "reflect" + "runtime" + "unsafe" +) + +type structEntry struct { + TypeInfo + name string +} + +// StructEntry is an interface to provide STRUCT entry information. +type StructEntry interface { + // Info returns a STRUCT entry's type information. + Info() TypeInfo + // Name returns a STRUCT entry's name. + Name() string +} + +// NewStructEntry returns a STRUCT entry. +// info contains information about the entry's type, and name holds the entry's name. 
+func NewStructEntry(info TypeInfo, name string) (StructEntry, error) { + if name == "" { + return nil, getError(errAPI, errEmptyName) + } + + return &structEntry{ + TypeInfo: info, + name: name, + }, nil +} + +// Info returns a STRUCT entry's type information. +func (entry *structEntry) Info() TypeInfo { + return entry.TypeInfo +} + +// Name returns a STRUCT entry's name. +func (entry *structEntry) Name() string { + return entry.name +} + +type baseTypeInfo struct { + Type + structEntries []StructEntry + decimalWidth uint8 + decimalScale uint8 +} + +type vectorTypeInfo struct { + baseTypeInfo + dict map[string]uint32 +} + +type typeInfo struct { + baseTypeInfo + childTypes []TypeInfo + enumNames []string +} + +// TypeInfo is an interface for a DuckDB type. +type TypeInfo interface { + logicalType() C.duckdb_logical_type +} + +// NewTypeInfo returns type information for DuckDB's primitive types. +// It returns the TypeInfo, if the Type parameter is a valid primitive type. +// Else, it returns nil, and an error. +// Valid types are: +// TYPE_[BOOLEAN, TINYINT, SMALLINT, INTEGER, BIGINT, UTINYINT, USMALLINT, UINTEGER, +// UBIGINT, FLOAT, DOUBLE, TIMESTAMP, DATE, TIME, INTERVAL, HUGEINT, VARCHAR, BLOB, +// TIMESTAMP_S, TIMESTAMP_MS, TIMESTAMP_NS, UUID, TIMESTAMP_TZ, ANY]. 
+func NewTypeInfo(t Type) (TypeInfo, error) { + name, inMap := unsupportedTypeToStringMap[t] + if inMap && t != TYPE_ANY { + return nil, getError(errAPI, unsupportedTypeError(name)) + } + + switch t { + case TYPE_DECIMAL: + return nil, getError(errAPI, tryOtherFuncError(funcName(NewDecimalInfo))) + case TYPE_ENUM: + return nil, getError(errAPI, tryOtherFuncError(funcName(NewEnumInfo))) + case TYPE_LIST: + return nil, getError(errAPI, tryOtherFuncError(funcName(NewListInfo))) + case TYPE_STRUCT: + return nil, getError(errAPI, tryOtherFuncError(funcName(NewStructInfo))) + case TYPE_MAP: + return nil, getError(errAPI, tryOtherFuncError(funcName(NewMapInfo))) + } + + return &typeInfo{ + baseTypeInfo: baseTypeInfo{Type: t}, + }, nil +} + +// NewDecimalInfo returns DECIMAL type information. +// Its input parameters are the width and scale of the DECIMAL type. +func NewDecimalInfo(width uint8, scale uint8) (TypeInfo, error) { + if width < 1 || width > MAX_DECIMAL_WIDTH { + return nil, getError(errAPI, errInvalidDecimalWidth) + } + if scale > width { + return nil, getError(errAPI, errInvalidDecimalScale) + } + + return &typeInfo{ + baseTypeInfo: baseTypeInfo{ + Type: TYPE_DECIMAL, + decimalWidth: width, + decimalScale: scale, + }, + }, nil +} + +// NewEnumInfo returns ENUM type information. +// Its input parameters are the dictionary values. +func NewEnumInfo(first string, others ...string) (TypeInfo, error) { + // Check for duplicate names. + m := map[string]bool{} + m[first] = true + for _, name := range others { + _, inMap := m[name] + if inMap { + return nil, getError(errAPI, duplicateNameError(name)) + } + m[name] = true + } + + info := &typeInfo{ + baseTypeInfo: baseTypeInfo{ + Type: TYPE_ENUM, + }, + enumNames: make([]string, 0), + } + + info.enumNames = append(info.enumNames, first) + info.enumNames = append(info.enumNames, others...) + return info, nil +} + +// NewListInfo returns LIST type information. 
+// childInfo contains the type information of the LIST's elements. +func NewListInfo(childInfo TypeInfo) (TypeInfo, error) { + if childInfo == nil { + return nil, getError(errAPI, interfaceIsNilError("childInfo")) + } + + info := &typeInfo{ + baseTypeInfo: baseTypeInfo{Type: TYPE_LIST}, + childTypes: make([]TypeInfo, 1), + } + info.childTypes[0] = childInfo + return info, nil +} + +// NewStructInfo returns STRUCT type information. +// Its input parameters are the STRUCT entries. +func NewStructInfo(firstEntry StructEntry, others ...StructEntry) (TypeInfo, error) { + if firstEntry == nil { + return nil, getError(errAPI, interfaceIsNilError("firstEntry")) + } + if firstEntry.Info() == nil { + return nil, getError(errAPI, interfaceIsNilError("firstEntry.Info()")) + } + for i, entry := range others { + if entry == nil { + return nil, getError(errAPI, addIndexToError(interfaceIsNilError("entry"), i)) + } + if entry.Info() == nil { + return nil, getError(errAPI, addIndexToError(interfaceIsNilError("entry.Info()"), i)) + } + } + + // Check for duplicate names. + m := map[string]bool{} + m[firstEntry.Name()] = true + for _, entry := range others { + name := entry.Name() + _, inMap := m[name] + if inMap { + return nil, getError(errAPI, duplicateNameError(name)) + } + m[name] = true + } + + info := &typeInfo{ + baseTypeInfo: baseTypeInfo{ + Type: TYPE_STRUCT, + structEntries: make([]StructEntry, 0), + }, + } + info.structEntries = append(info.structEntries, firstEntry) + info.structEntries = append(info.structEntries, others...) + return info, nil +} + +// NewMapInfo returns MAP type information. +// keyInfo contains the type information of the MAP keys. +// valueInfo contains the type information of the MAP values. 
+func NewMapInfo(keyInfo TypeInfo, valueInfo TypeInfo) (TypeInfo, error) { + if keyInfo == nil { + return nil, getError(errAPI, interfaceIsNilError("keyInfo")) + } + if valueInfo == nil { + return nil, getError(errAPI, interfaceIsNilError("valueInfo")) + } + + info := &typeInfo{ + baseTypeInfo: baseTypeInfo{Type: TYPE_MAP}, + childTypes: make([]TypeInfo, 2), + } + info.childTypes[0] = keyInfo + info.childTypes[1] = valueInfo + return info, nil +} + +func (info *typeInfo) logicalType() C.duckdb_logical_type { + switch info.Type { + case TYPE_BOOLEAN, TYPE_TINYINT, TYPE_SMALLINT, TYPE_INTEGER, TYPE_BIGINT, TYPE_UTINYINT, TYPE_USMALLINT, + TYPE_UINTEGER, TYPE_UBIGINT, TYPE_FLOAT, TYPE_DOUBLE, TYPE_TIMESTAMP, TYPE_TIMESTAMP_S, TYPE_TIMESTAMP_MS, + TYPE_TIMESTAMP_NS, TYPE_TIMESTAMP_TZ, TYPE_DATE, TYPE_TIME, TYPE_INTERVAL, TYPE_HUGEINT, TYPE_VARCHAR, + TYPE_BLOB, TYPE_UUID, TYPE_ANY: + return C.duckdb_create_logical_type(C.duckdb_type(info.Type)) + + case TYPE_DECIMAL: + return C.duckdb_create_decimal_type(C.uint8_t(info.decimalWidth), C.uint8_t(info.decimalScale)) + case TYPE_ENUM: + return info.logicalEnumType() + case TYPE_LIST: + return info.logicalListType() + case TYPE_STRUCT: + return info.logicalStructType() + case TYPE_MAP: + return info.logicalMapType() + } + return nil +} + +func (info *typeInfo) logicalEnumType() C.duckdb_logical_type { + count := len(info.enumNames) + size := C.size_t(unsafe.Sizeof((*C.char)(nil))) + names := (*[1 << 31]*C.char)(C.malloc(C.size_t(count) * size)) + + for i, name := range info.enumNames { + (*names)[i] = C.CString(name) + } + cNames := (**C.char)(unsafe.Pointer(names)) + logicalType := C.duckdb_create_enum_type(cNames, C.idx_t(count)) + + for i := 0; i < count; i++ { + C.duckdb_free(unsafe.Pointer((*names)[i])) + } + C.duckdb_free(unsafe.Pointer(names)) + return logicalType +} + +func (info *typeInfo) logicalListType() C.duckdb_logical_type { + child := info.childTypes[0].logicalType() + logicalType := 
C.duckdb_create_list_type(child) + C.duckdb_destroy_logical_type(&child) + return logicalType +} + +func (info *typeInfo) logicalStructType() C.duckdb_logical_type { + count := len(info.structEntries) + size := C.size_t(unsafe.Sizeof(C.duckdb_logical_type(nil))) + types := (*[1 << 31]C.duckdb_logical_type)(C.malloc(C.size_t(count) * size)) + + size = C.size_t(unsafe.Sizeof((*C.char)(nil))) + names := (*[1 << 31]*C.char)(C.malloc(C.size_t(count) * size)) + + for i, entry := range info.structEntries { + (*types)[i] = entry.Info().logicalType() + (*names)[i] = C.CString(entry.Name()) + } + + cTypes := (*C.duckdb_logical_type)(unsafe.Pointer(types)) + cNames := (**C.char)(unsafe.Pointer(names)) + logicalType := C.duckdb_create_struct_type(cTypes, cNames, C.idx_t(count)) + + for i := 0; i < count; i++ { + C.duckdb_destroy_logical_type(&types[i]) + C.duckdb_free(unsafe.Pointer((*names)[i])) + } + C.duckdb_free(unsafe.Pointer(types)) + C.duckdb_free(unsafe.Pointer(names)) + return logicalType +} + +func (info *typeInfo) logicalMapType() C.duckdb_logical_type { + key := info.childTypes[0].logicalType() + value := info.childTypes[1].logicalType() + logicalType := C.duckdb_create_map_type(key, value) + + C.duckdb_destroy_logical_type(&key) + C.duckdb_destroy_logical_type(&value) + return logicalType +} + +func funcName(i interface{}) string { + return runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name() +} diff --git a/type_info_test.go b/type_info_test.go new file mode 100644 index 00000000..f8f5a0b6 --- /dev/null +++ b/type_info_test.go @@ -0,0 +1,139 @@ +package duckdb + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTypeInfo(t *testing.T) { + var primitiveTypes []Type + for k := range typeToStringMap { + _, inMap := unsupportedTypeToStringMap[k] + if inMap && k != TYPE_ANY { + continue + } + switch k { + case TYPE_DECIMAL, TYPE_ENUM, TYPE_LIST, TYPE_STRUCT, TYPE_MAP: + continue + } + primitiveTypes = append(primitiveTypes, k) + } + + // 
Create each primitive type information. + var typeInfos []TypeInfo + for _, primitive := range primitiveTypes { + info, err := NewTypeInfo(primitive) + require.NoError(t, err) + typeInfos = append(typeInfos, info) + } + + // Create nested types. + decimalInfo, err := NewDecimalInfo(3, 2) + require.NoError(t, err) + enumInfo, err := NewEnumInfo("hello", "world", "!") + require.NoError(t, err) + listInfo, err := NewListInfo(decimalInfo) + require.NoError(t, err) + nestedListInfo, err := NewListInfo(listInfo) + require.NoError(t, err) + + firstEntry, err := NewStructEntry(enumInfo, "hello") + require.NoError(t, err) + secondEntry, err := NewStructEntry(nestedListInfo, "world") + require.NoError(t, err) + structInfo, err := NewStructInfo(firstEntry, secondEntry) + require.NoError(t, err) + + firstEntry, err = NewStructEntry(structInfo, "hello") + require.NoError(t, err) + secondEntry, err = NewStructEntry(listInfo, "world") + require.NoError(t, err) + nestedStructInfo, err := NewStructInfo(firstEntry, secondEntry) + require.NoError(t, err) + + mapInfo, err := NewMapInfo(nestedStructInfo, nestedListInfo) + require.NoError(t, err) + + typeInfos = append(typeInfos, decimalInfo, enumInfo, listInfo, nestedListInfo, structInfo, nestedStructInfo, mapInfo) + + // Use each type as a child. 
+ for _, info := range typeInfos { + _, err = NewListInfo(info) + require.NoError(t, err) + } +} + +func TestErrTypeInfo(t *testing.T) { + t.Parallel() + + var incorrectTypes []Type + incorrectTypes = append(incorrectTypes, TYPE_DECIMAL, TYPE_ENUM, TYPE_LIST, TYPE_STRUCT, TYPE_MAP) + + for _, incorrect := range incorrectTypes { + _, err := NewTypeInfo(incorrect) + testError(t, err, errAPI.Error(), tryOtherFuncErrMsg) + } + + var unsupportedTypes []Type + for k := range unsupportedTypeToStringMap { + if k != TYPE_ANY { + unsupportedTypes = append(unsupportedTypes, k) + } + } + + for _, unsupported := range unsupportedTypes { + _, err := NewTypeInfo(unsupported) + testError(t, err, errAPI.Error(), unsupportedTypeErrMsg) + } + + // Invalid DECIMAL. + _, err := NewDecimalInfo(0, 0) + testError(t, err, errAPI.Error(), errInvalidDecimalWidth.Error()) + _, err = NewDecimalInfo(42, 20) + testError(t, err, errAPI.Error(), errInvalidDecimalWidth.Error()) + _, err = NewDecimalInfo(5, 6) + testError(t, err, errAPI.Error(), errInvalidDecimalScale.Error()) + + // Invalid ENUM. + _, err = NewEnumInfo("hello", "hello") + testError(t, err, errAPI.Error(), duplicateNameErrMsg) + _, err = NewEnumInfo("hello", "world", "hello") + testError(t, err, errAPI.Error(), duplicateNameErrMsg) + + validInfo, err := NewTypeInfo(TYPE_FLOAT) + require.NoError(t, err) + + // Invalid STRUCT entry. + _, err = NewStructEntry(validInfo, "") + testError(t, err, errAPI.Error(), errEmptyName.Error()) + + validStructEntry, err := NewStructEntry(validInfo, "hello") + require.NoError(t, err) + otherValidStructEntry, err := NewStructEntry(validInfo, "you") + require.NoError(t, err) + nilStructEntry, err := NewStructEntry(nil, "hello") + require.NoError(t, err) + + // Invalid interfaces. 
+ _, err = NewListInfo(nil) + testError(t, err, errAPI.Error(), interfaceIsNilErrMsg) + + _, err = NewStructInfo(nil) + testError(t, err, errAPI.Error(), interfaceIsNilErrMsg) + _, err = NewStructInfo(validStructEntry, nil) + testError(t, err, errAPI.Error(), interfaceIsNilErrMsg) + _, err = NewStructInfo(nilStructEntry, validStructEntry) + testError(t, err, errAPI.Error(), interfaceIsNilErrMsg) + _, err = NewStructInfo(validStructEntry, nilStructEntry) + testError(t, err, errAPI.Error(), interfaceIsNilErrMsg) + _, err = NewStructInfo(validStructEntry, validStructEntry) + testError(t, err, errAPI.Error(), duplicateNameErrMsg) + _, err = NewStructInfo(validStructEntry, otherValidStructEntry, validStructEntry) + testError(t, err, errAPI.Error(), duplicateNameErrMsg) + + _, err = NewMapInfo(nil, validInfo) + testError(t, err, errAPI.Error(), interfaceIsNilErrMsg) + _, err = NewMapInfo(validInfo, nil) + testError(t, err, errAPI.Error(), interfaceIsNilErrMsg) +} diff --git a/types.go b/types.go index b698d350..bc21c29d 100644 --- a/types.go +++ b/types.go @@ -13,53 +13,6 @@ import ( "github.com/mitchellh/mapstructure" ) -// FIXME: Implement support for these types. 
-var unsupportedTypeMap = map[C.duckdb_type]string{ - C.DUCKDB_TYPE_INVALID: "INVALID", - C.DUCKDB_TYPE_UHUGEINT: "UHUGEINT", - C.DUCKDB_TYPE_ARRAY: "ARRAY", - C.DUCKDB_TYPE_UNION: "UNION", - C.DUCKDB_TYPE_BIT: "BIT", - C.DUCKDB_TYPE_TIME_TZ: "TIME_TZ", -} - -var duckdbTypeMap = map[C.duckdb_type]string{ - C.DUCKDB_TYPE_INVALID: "INVALID", - C.DUCKDB_TYPE_BOOLEAN: "BOOLEAN", - C.DUCKDB_TYPE_TINYINT: "TINYINT", - C.DUCKDB_TYPE_SMALLINT: "SMALLINT", - C.DUCKDB_TYPE_INTEGER: "INTEGER", - C.DUCKDB_TYPE_BIGINT: "BIGINT", - C.DUCKDB_TYPE_UTINYINT: "UTINYINT", - C.DUCKDB_TYPE_USMALLINT: "USMALLINT", - C.DUCKDB_TYPE_UINTEGER: "UINTEGER", - C.DUCKDB_TYPE_UBIGINT: "UBIGINT", - C.DUCKDB_TYPE_FLOAT: "FLOAT", - C.DUCKDB_TYPE_DOUBLE: "DOUBLE", - C.DUCKDB_TYPE_TIMESTAMP: "TIMESTAMP", - C.DUCKDB_TYPE_DATE: "DATE", - C.DUCKDB_TYPE_TIME: "TIME", - C.DUCKDB_TYPE_INTERVAL: "INTERVAL", - C.DUCKDB_TYPE_HUGEINT: "HUGEINT", - C.DUCKDB_TYPE_UHUGEINT: "UHUGEINT", - C.DUCKDB_TYPE_VARCHAR: "VARCHAR", - C.DUCKDB_TYPE_BLOB: "BLOB", - C.DUCKDB_TYPE_DECIMAL: "DECIMAL", - C.DUCKDB_TYPE_TIMESTAMP_S: "TIMESTAMP_S", - C.DUCKDB_TYPE_TIMESTAMP_MS: "TIMESTAMP_MS", - C.DUCKDB_TYPE_TIMESTAMP_NS: "TIMESTAMP_NS", - C.DUCKDB_TYPE_ENUM: "ENUM", - C.DUCKDB_TYPE_LIST: "LIST", - C.DUCKDB_TYPE_STRUCT: "STRUCT", - C.DUCKDB_TYPE_MAP: "MAP", - C.DUCKDB_TYPE_ARRAY: "ARRAY", - C.DUCKDB_TYPE_UUID: "UUID", - C.DUCKDB_TYPE_UNION: "UNION", - C.DUCKDB_TYPE_BIT: "BIT", - C.DUCKDB_TYPE_TIME_TZ: "TIMETZ", - C.DUCKDB_TYPE_TIMESTAMP_TZ: "TIMESTAMPTZ", -} - type numericType interface { int | int8 | int16 | int32 | int64 | uint | uint8 | uint16 | uint32 | uint64 | float32 | float64 } @@ -162,6 +115,8 @@ func (s *Composite[T]) Scan(v any) error { return mapstructure.Decode(v, &s.t) } +const MAX_DECIMAL_WIDTH = 38 + type Decimal struct { Width uint8 Scale uint8 diff --git a/vector.go b/vector.go index b3499332..c01a2f93 100644 --- a/vector.go +++ b/vector.go @@ -1,7 +1,6 @@ package duckdb /* -#include #include */ import "C" @@ 
-26,19 +25,11 @@ type vector struct { getFn fnGetVectorValue // A callback function to write to this vector. setFn fnSetVectorValue - // The data type of the vector. - duckdbType C.duckdb_type // The child vectors of nested data types. childVectors []vector - // The child names of STRUCT vectors. - childNames []string - // The dictionary for ENUM types. - dict map[string]uint32 - // The width of DECIMAL types. - width uint8 - // The scale of DECIMAL types. - scale uint8 + // The vector's type information. + vectorTypeInfo } func (vec *vector) tryCast(val any) (any, error) { @@ -46,68 +37,59 @@ func (vec *vector) tryCast(val any) (any, error) { return val, nil } - switch vec.duckdbType { - case C.DUCKDB_TYPE_INVALID: - return nil, unsupportedTypeError(duckdbTypeMap[vec.duckdbType]) - case C.DUCKDB_TYPE_BOOLEAN: + name, inMap := unsupportedTypeToStringMap[vec.Type] + if inMap { + return nil, unsupportedTypeError(name) + } + + switch vec.Type { + case TYPE_BOOLEAN: return tryPrimitiveCast[bool](val, reflect.Bool.String()) - case C.DUCKDB_TYPE_TINYINT: + case TYPE_TINYINT: return tryNumericCast[int8](val, reflect.Int8.String()) - case C.DUCKDB_TYPE_SMALLINT: + case TYPE_SMALLINT: return tryNumericCast[int16](val, reflect.Int16.String()) - case C.DUCKDB_TYPE_INTEGER: + case TYPE_INTEGER: return tryNumericCast[int32](val, reflect.Int32.String()) - case C.DUCKDB_TYPE_BIGINT: + case TYPE_BIGINT: return tryNumericCast[int64](val, reflect.Int64.String()) - case C.DUCKDB_TYPE_UTINYINT: + case TYPE_UTINYINT: return tryNumericCast[uint8](val, reflect.Uint8.String()) - case C.DUCKDB_TYPE_USMALLINT: + case TYPE_USMALLINT: return tryNumericCast[uint16](val, reflect.Uint16.String()) - case C.DUCKDB_TYPE_UINTEGER: + case TYPE_UINTEGER: return tryNumericCast[uint32](val, reflect.Uint32.String()) - case C.DUCKDB_TYPE_UBIGINT: + case TYPE_UBIGINT: return tryNumericCast[uint64](val, reflect.Uint64.String()) - case C.DUCKDB_TYPE_FLOAT: + case TYPE_FLOAT: return 
tryNumericCast[float32](val, reflect.Float32.String()) - case C.DUCKDB_TYPE_DOUBLE: + case TYPE_DOUBLE: return tryNumericCast[float64](val, reflect.Float64.String()) - case C.DUCKDB_TYPE_TIMESTAMP, C.DUCKDB_TYPE_TIMESTAMP_S, C.DUCKDB_TYPE_TIMESTAMP_MS, - C.DUCKDB_TYPE_TIMESTAMP_NS, C.DUCKDB_TYPE_TIMESTAMP_TZ, C.DUCKDB_TYPE_DATE, C.DUCKDB_TYPE_TIME: + case TYPE_TIMESTAMP, TYPE_TIMESTAMP_S, TYPE_TIMESTAMP_MS, TYPE_TIMESTAMP_NS, TYPE_TIMESTAMP_TZ, + TYPE_DATE, TYPE_TIME: return tryPrimitiveCast[time.Time](val, reflect.TypeOf(time.Time{}).String()) - case C.DUCKDB_TYPE_INTERVAL: + case TYPE_INTERVAL: return tryPrimitiveCast[Interval](val, reflect.TypeOf(Interval{}).String()) - case C.DUCKDB_TYPE_HUGEINT: - // Note that this expects *big.Int. + case TYPE_HUGEINT: return tryPrimitiveCast[*big.Int](val, reflect.TypeOf(big.Int{}).String()) - case C.DUCKDB_TYPE_UHUGEINT: - return nil, unsupportedTypeError(duckdbTypeMap[vec.duckdbType]) - case C.DUCKDB_TYPE_VARCHAR: + case TYPE_VARCHAR: return tryPrimitiveCast[string](val, reflect.String.String()) - case C.DUCKDB_TYPE_BLOB: + case TYPE_BLOB: return tryPrimitiveCast[[]byte](val, reflect.TypeOf([]byte{}).String()) - case C.DUCKDB_TYPE_DECIMAL: + case TYPE_DECIMAL: return vec.tryCastDecimal(val) - case C.DUCKDB_TYPE_ENUM: + case TYPE_ENUM: return vec.tryCastEnum(val) - case C.DUCKDB_TYPE_LIST: + case TYPE_LIST: return vec.tryCastList(val) - case C.DUCKDB_TYPE_STRUCT: + case TYPE_STRUCT: return vec.tryCastStruct(val) - case C.DUCKDB_TYPE_MAP: + case TYPE_MAP: return tryPrimitiveCast[Map](val, reflect.TypeOf(Map{}).String()) - case C.DUCKDB_TYPE_ARRAY: - return nil, unsupportedTypeError(duckdbTypeMap[vec.duckdbType]) - case C.DUCKDB_TYPE_UUID: + case TYPE_UUID: return tryPrimitiveCast[UUID](val, reflect.TypeOf(UUID{}).String()) - case C.DUCKDB_TYPE_UNION: - return nil, unsupportedTypeError(duckdbTypeMap[vec.duckdbType]) - case C.DUCKDB_TYPE_BIT: - return nil, unsupportedTypeError(duckdbTypeMap[vec.duckdbType]) - case 
C.DUCKDB_TYPE_TIME_TZ: - return nil, unsupportedTypeError(duckdbTypeMap[vec.duckdbType]) - default: - return nil, unsupportedTypeError("unknown type") } + return nil, unsupportedTypeError(unknownTypeErrMsg) } func (*vector) canNil(val reflect.Value) bool { @@ -153,8 +135,8 @@ func (vec *vector) tryCastDecimal(val any) (Decimal, error) { return v, castError(goType.String(), reflect.TypeOf(Decimal{}).String()) } - if v.Width != vec.width || v.Scale != vec.scale { - d := Decimal{Width: vec.width, Scale: vec.scale} + if v.Width != vec.decimalWidth || v.Scale != vec.decimalScale { + d := Decimal{Width: vec.decimalWidth, Scale: vec.decimalScale} return v, castError(d.toString(), v.toString()) } return v, nil @@ -222,23 +204,24 @@ func (vec *vector) tryCastStruct(val any) (map[string]any, error) { } // Catch mismatching field count. - if len(m) != len(vec.childNames) { - return nil, structFieldError(strconv.Itoa(len(m)), strconv.Itoa(len(vec.childNames))) + count := len(vec.structEntries) + if len(m) != count { + return nil, structFieldError(strconv.Itoa(len(m)), strconv.Itoa(count)) } // Cast child entries and return the map. - for i := 0; i < len(vec.childVectors); i++ { + for i := 0; i < count; i++ { childVector := vec.childVectors[i] - childName := vec.childNames[i] - v, ok := m[childName] + name := vec.structEntries[i].Name() + v, ok := m[name] // Catch mismatching field names. 
if !ok { - return nil, structFieldError("missing field", childName) + return nil, structFieldError("missing field", name) } var err error - m[childName], err = childVector.tryCast(v) + m[name], err = childVector.tryCast(v) if err != nil { return nil, err } @@ -247,70 +230,62 @@ func (vec *vector) tryCastStruct(val any) (map[string]any, error) { } func (vec *vector) init(logicalType C.duckdb_logical_type, colIdx int) error { - duckdbType := C.duckdb_get_type_id(logicalType) - - switch duckdbType { - case C.DUCKDB_TYPE_INVALID: - return columnError(unsupportedTypeError(duckdbTypeMap[duckdbType]), colIdx) - case C.DUCKDB_TYPE_BOOLEAN: - initPrimitive[bool](vec, C.DUCKDB_TYPE_BOOLEAN) - case C.DUCKDB_TYPE_TINYINT: - initPrimitive[int8](vec, C.DUCKDB_TYPE_TINYINT) - case C.DUCKDB_TYPE_SMALLINT: - initPrimitive[int16](vec, C.DUCKDB_TYPE_SMALLINT) - case C.DUCKDB_TYPE_INTEGER: - initPrimitive[int32](vec, C.DUCKDB_TYPE_INTEGER) - case C.DUCKDB_TYPE_BIGINT: - initPrimitive[int64](vec, C.DUCKDB_TYPE_BIGINT) - case C.DUCKDB_TYPE_UTINYINT: - initPrimitive[uint8](vec, C.DUCKDB_TYPE_UTINYINT) - case C.DUCKDB_TYPE_USMALLINT: - initPrimitive[uint16](vec, C.DUCKDB_TYPE_USMALLINT) - case C.DUCKDB_TYPE_UINTEGER: - initPrimitive[uint32](vec, C.DUCKDB_TYPE_UINTEGER) - case C.DUCKDB_TYPE_UBIGINT: - initPrimitive[uint64](vec, C.DUCKDB_TYPE_UBIGINT) - case C.DUCKDB_TYPE_FLOAT: - initPrimitive[float32](vec, C.DUCKDB_TYPE_FLOAT) - case C.DUCKDB_TYPE_DOUBLE: - initPrimitive[float64](vec, C.DUCKDB_TYPE_DOUBLE) - case C.DUCKDB_TYPE_TIMESTAMP, C.DUCKDB_TYPE_TIMESTAMP_S, C.DUCKDB_TYPE_TIMESTAMP_MS, - C.DUCKDB_TYPE_TIMESTAMP_NS, C.DUCKDB_TYPE_TIMESTAMP_TZ: - vec.initTS(duckdbType) - case C.DUCKDB_TYPE_DATE: + t := Type(C.duckdb_get_type_id(logicalType)) + + name, inMap := unsupportedTypeToStringMap[t] + if inMap { + return addIndexToError(unsupportedTypeError(name), colIdx) + } + + switch t { + case TYPE_BOOLEAN: + initPrimitive[bool](vec, t) + case TYPE_TINYINT: + initPrimitive[int8](vec, t) + 
case TYPE_SMALLINT: + initPrimitive[int16](vec, t) + case TYPE_INTEGER: + initPrimitive[int32](vec, t) + case TYPE_BIGINT: + initPrimitive[int64](vec, t) + case TYPE_UTINYINT: + initPrimitive[uint8](vec, t) + case TYPE_USMALLINT: + initPrimitive[uint16](vec, t) + case TYPE_UINTEGER: + initPrimitive[uint32](vec, t) + case TYPE_UBIGINT: + initPrimitive[uint64](vec, t) + case TYPE_FLOAT: + initPrimitive[float32](vec, t) + case TYPE_DOUBLE: + initPrimitive[float64](vec, t) + case TYPE_TIMESTAMP, TYPE_TIMESTAMP_S, TYPE_TIMESTAMP_MS, TYPE_TIMESTAMP_NS, TYPE_TIMESTAMP_TZ: + vec.initTS(t) + case TYPE_DATE: vec.initDate() - case C.DUCKDB_TYPE_TIME: + case TYPE_TIME: vec.initTime() - case C.DUCKDB_TYPE_INTERVAL: + case TYPE_INTERVAL: vec.initInterval() - case C.DUCKDB_TYPE_HUGEINT: + case TYPE_HUGEINT: vec.initHugeint() - case C.DUCKDB_TYPE_UHUGEINT: - return columnError(unsupportedTypeError(duckdbTypeMap[duckdbType]), colIdx) - case C.DUCKDB_TYPE_VARCHAR, C.DUCKDB_TYPE_BLOB: - vec.initCString(duckdbType) - case C.DUCKDB_TYPE_DECIMAL: + case TYPE_VARCHAR, TYPE_BLOB: + vec.initCString(t) + case TYPE_DECIMAL: return vec.initDecimal(logicalType, colIdx) - case C.DUCKDB_TYPE_ENUM: + case TYPE_ENUM: return vec.initEnum(logicalType, colIdx) - case C.DUCKDB_TYPE_LIST: + case TYPE_LIST: return vec.initList(logicalType, colIdx) - case C.DUCKDB_TYPE_STRUCT: + case TYPE_STRUCT: return vec.initStruct(logicalType, colIdx) - case C.DUCKDB_TYPE_MAP: + case TYPE_MAP: return vec.initMap(logicalType, colIdx) - case C.DUCKDB_TYPE_ARRAY: - return columnError(unsupportedTypeError(duckdbTypeMap[duckdbType]), colIdx) - case C.DUCKDB_TYPE_UUID: + case TYPE_UUID: vec.initUUID() - case C.DUCKDB_TYPE_UNION: - return columnError(unsupportedTypeError(duckdbTypeMap[duckdbType]), colIdx) - case C.DUCKDB_TYPE_BIT: - return columnError(unsupportedTypeError(duckdbTypeMap[duckdbType]), colIdx) - case C.DUCKDB_TYPE_TIME_TZ: - return columnError(unsupportedTypeError(duckdbTypeMap[duckdbType]), colIdx) default: 
- return columnError(unsupportedTypeError("unknown type"), colIdx) + return addIndexToError(unsupportedTypeError(unknownTypeErrMsg), colIdx) } return nil } @@ -326,13 +301,13 @@ func (vec *vector) initVectors(v C.duckdb_vector, writable bool) { } func (vec *vector) getChildVectors(v C.duckdb_vector, writable bool) { - switch vec.duckdbType { + switch vec.Type { - case C.DUCKDB_TYPE_LIST, C.DUCKDB_TYPE_MAP: + case TYPE_LIST, TYPE_MAP: child := C.duckdb_list_vector_get_child(v) vec.childVectors[0].initVectors(child, writable) - case C.DUCKDB_TYPE_STRUCT: + case TYPE_STRUCT: for i := 0; i < len(vec.childVectors); i++ { child := C.duckdb_struct_vector_get_child(v, C.idx_t(i)) vec.childVectors[i].initVectors(child, writable) @@ -340,7 +315,7 @@ func (vec *vector) getChildVectors(v C.duckdb_vector, writable bool) { } } -func initPrimitive[T any](vec *vector, duckdbType C.duckdb_type) { +func initPrimitive[T any](vec *vector, t Type) { vec.getFn = func(vec *vector, rowIdx C.idx_t) any { if vec.getNull(rowIdx) { return nil @@ -354,24 +329,24 @@ func initPrimitive[T any](vec *vector, duckdbType C.duckdb_type) { } setPrimitive(vec, rowIdx, val.(T)) } - vec.duckdbType = duckdbType + vec.Type = t } -func (vec *vector) initTS(duckdbType C.duckdb_type) { +func (vec *vector) initTS(t Type) { vec.getFn = func(vec *vector, rowIdx C.idx_t) any { if vec.getNull(rowIdx) { return nil } - return vec.getTS(duckdbType, rowIdx) + return vec.getTS(t, rowIdx) } vec.setFn = func(vec *vector, rowIdx C.idx_t, val any) { if val == nil { vec.setNull(rowIdx) return } - vec.setTS(duckdbType, rowIdx, val) + vec.setTS(t, rowIdx, val) } - vec.duckdbType = duckdbType + vec.Type = t } func (vec *vector) initDate() { @@ -388,7 +363,7 @@ func (vec *vector) initDate() { } vec.setDate(rowIdx, val) } - vec.duckdbType = C.DUCKDB_TYPE_DATE + vec.Type = TYPE_DATE } func (vec *vector) initTime() { @@ -405,7 +380,7 @@ func (vec *vector) initTime() { } vec.setTime(rowIdx, val) } - vec.duckdbType = 
C.DUCKDB_TYPE_TIME + vec.Type = TYPE_TIME } func (vec *vector) initInterval() { @@ -422,7 +397,7 @@ func (vec *vector) initInterval() { } vec.setInterval(rowIdx, val) } - vec.duckdbType = C.DUCKDB_TYPE_INTERVAL + vec.Type = TYPE_INTERVAL } func (vec *vector) initHugeint() { @@ -439,10 +414,10 @@ func (vec *vector) initHugeint() { } vec.setHugeint(rowIdx, val) } - vec.duckdbType = C.DUCKDB_TYPE_HUGEINT + vec.Type = TYPE_HUGEINT } -func (vec *vector) initCString(duckdbType C.duckdb_type) { +func (vec *vector) initCString(t Type) { vec.getFn = func(vec *vector, rowIdx C.idx_t) any { if vec.getNull(rowIdx) { return nil @@ -456,34 +431,34 @@ func (vec *vector) initCString(duckdbType C.duckdb_type) { } vec.setCString(rowIdx, val) } - vec.duckdbType = duckdbType + vec.Type = t } func (vec *vector) initDecimal(logicalType C.duckdb_logical_type, colIdx int) error { - vec.width = uint8(C.duckdb_decimal_width(logicalType)) - vec.scale = uint8(C.duckdb_decimal_scale(logicalType)) + vec.decimalWidth = uint8(C.duckdb_decimal_width(logicalType)) + vec.decimalScale = uint8(C.duckdb_decimal_scale(logicalType)) - internalType := C.duckdb_decimal_internal_type(logicalType) - switch internalType { - case C.DUCKDB_TYPE_SMALLINT, C.DUCKDB_TYPE_INTEGER, C.DUCKDB_TYPE_BIGINT, C.DUCKDB_TYPE_HUGEINT: + t := Type(C.duckdb_decimal_internal_type(logicalType)) + switch t { + case TYPE_SMALLINT, TYPE_INTEGER, TYPE_BIGINT, TYPE_HUGEINT: vec.getFn = func(vec *vector, rowIdx C.idx_t) any { if vec.getNull(rowIdx) { return nil } - return vec.getDecimal(internalType, rowIdx) + return vec.getDecimal(t, rowIdx) } vec.setFn = func(vec *vector, rowIdx C.idx_t, val any) { if val == nil { vec.setNull(rowIdx) return } - vec.setDecimal(internalType, rowIdx, val) + vec.setDecimal(t, rowIdx, val) } default: - return columnError(unsupportedTypeError(duckdbTypeMap[internalType]), colIdx) + return addIndexToError(unsupportedTypeError(typeToStringMap[t]), colIdx) } - vec.duckdbType = C.DUCKDB_TYPE_DECIMAL + 
vec.Type = TYPE_DECIMAL return nil } @@ -498,27 +473,27 @@ func (vec *vector) initEnum(logicalType C.duckdb_logical_type, colIdx int) error C.duckdb_free(unsafe.Pointer(cStr)) } - internalType := C.duckdb_enum_internal_type(logicalType) - switch internalType { - case C.DUCKDB_TYPE_UTINYINT, C.DUCKDB_TYPE_USMALLINT, C.DUCKDB_TYPE_UINTEGER, C.DUCKDB_TYPE_UBIGINT: + t := Type(C.duckdb_enum_internal_type(logicalType)) + switch t { + case TYPE_UTINYINT, TYPE_USMALLINT, TYPE_UINTEGER, TYPE_UBIGINT: vec.getFn = func(vec *vector, rowIdx C.idx_t) any { if vec.getNull(rowIdx) { return nil } - return vec.getEnum(internalType, rowIdx) + return vec.getEnum(t, rowIdx) } vec.setFn = func(vec *vector, rowIdx C.idx_t, val any) { if val == nil { vec.setNull(rowIdx) return } - vec.setEnum(internalType, rowIdx, val) + vec.setEnum(t, rowIdx, val) } default: - return columnError(unsupportedTypeError(duckdbTypeMap[internalType]), colIdx) + return addIndexToError(unsupportedTypeError(typeToStringMap[t]), colIdx) } - vec.duckdbType = C.DUCKDB_TYPE_ENUM + vec.Type = TYPE_ENUM return nil } @@ -547,21 +522,25 @@ func (vec *vector) initList(logicalType C.duckdb_logical_type, colIdx int) error } vec.setList(rowIdx, val) } - vec.duckdbType = C.DUCKDB_TYPE_LIST + vec.Type = TYPE_LIST return nil } func (vec *vector) initStruct(logicalType C.duckdb_logical_type, colIdx int) error { childCount := int(C.duckdb_struct_type_child_count(logicalType)) - var childNames []string + var structEntries []StructEntry for i := 0; i < childCount; i++ { - childName := C.duckdb_struct_type_child_name(logicalType, C.idx_t(i)) - childNames = append(childNames, C.GoString(childName)) - C.free(unsafe.Pointer(childName)) + name := C.duckdb_struct_type_child_name(logicalType, C.idx_t(i)) + entry, err := NewStructEntry(nil, C.GoString(name)) + structEntries = append(structEntries, entry) + C.duckdb_free(unsafe.Pointer(name)) + if err != nil { + return err + } } vec.childVectors = make([]vector, childCount) - 
vec.childNames = childNames + vec.structEntries = structEntries // Recurse into the children. for i := 0; i < childCount; i++ { @@ -587,7 +566,7 @@ func (vec *vector) initStruct(logicalType C.duckdb_logical_type, colIdx int) err } vec.setStruct(rowIdx, val) } - vec.duckdbType = C.DUCKDB_TYPE_STRUCT + vec.Type = TYPE_STRUCT return nil } @@ -610,10 +589,10 @@ func (vec *vector) initMap(logicalType C.duckdb_logical_type, colIdx int) error keyType := C.duckdb_map_type_key_type(logicalType) defer C.duckdb_destroy_logical_type(&keyType) - duckdbKeyType := C.duckdb_get_type_id(keyType) - switch duckdbKeyType { - case C.DUCKDB_TYPE_LIST, C.DUCKDB_TYPE_STRUCT, C.DUCKDB_TYPE_MAP, C.DUCKDB_TYPE_ARRAY: - return columnError(errUnsupportedMapKeyType, colIdx) + t := Type(C.duckdb_get_type_id(keyType)) + switch t { + case TYPE_LIST, TYPE_STRUCT, TYPE_MAP, TYPE_ARRAY: + return addIndexToError(errUnsupportedMapKeyType, colIdx) } vec.getFn = func(vec *vector, rowIdx C.idx_t) any { @@ -629,7 +608,7 @@ func (vec *vector) initMap(logicalType C.duckdb_logical_type, colIdx int) error } vec.setMap(rowIdx, val) } - vec.duckdbType = C.DUCKDB_TYPE_MAP + vec.Type = TYPE_MAP return nil } @@ -648,5 +627,5 @@ func (vec *vector) initUUID() { } setPrimitive(vec, rowIdx, uuidToHugeInt(val.(UUID))) } - vec.duckdbType = C.DUCKDB_TYPE_UUID + vec.Type = TYPE_UUID } diff --git a/vector_getters.go b/vector_getters.go index 2a753f53..29757ab0 100644 --- a/vector_getters.go +++ b/vector_getters.go @@ -1,7 +1,6 @@ package duckdb /* -#include #include */ import "C" @@ -33,20 +32,20 @@ func getPrimitive[T any](vec *vector, rowIdx C.idx_t) T { return xs[rowIdx] } -func (vec *vector) getTS(duckdbType C.duckdb_type, rowIdx C.idx_t) time.Time { +func (vec *vector) getTS(t Type, rowIdx C.idx_t) time.Time { val := getPrimitive[C.duckdb_timestamp](vec, rowIdx) micros := val.micros - switch duckdbType { - case C.DUCKDB_TYPE_TIMESTAMP: + switch t { + case TYPE_TIMESTAMP: return time.UnixMicro(int64(micros)).UTC() - 
case C.DUCKDB_TYPE_TIMESTAMP_S: + case TYPE_TIMESTAMP_S: return time.Unix(int64(micros), 0).UTC() - case C.DUCKDB_TYPE_TIMESTAMP_MS: + case TYPE_TIMESTAMP_MS: return time.UnixMilli(int64(micros)).UTC() - case C.DUCKDB_TYPE_TIMESTAMP_NS: + case TYPE_TIMESTAMP_NS: return time.Unix(0, int64(micros)).UTC() - case C.DUCKDB_TYPE_TIMESTAMP_TZ: + case TYPE_TIMESTAMP_TZ: return time.UnixMicro(int64(micros)).UTC() } @@ -92,25 +91,25 @@ func (vec *vector) getCString(rowIdx C.idx_t) any { blob = C.GoBytes(unsafe.Pointer(cStr.ptr), C.int(cStr.length)) } - if vec.duckdbType == C.DUCKDB_TYPE_VARCHAR { + if vec.Type == TYPE_VARCHAR { return string(blob) } return blob } -func (vec *vector) getDecimal(internalType C.duckdb_type, rowIdx C.idx_t) Decimal { +func (vec *vector) getDecimal(t Type, rowIdx C.idx_t) Decimal { var val *big.Int - switch internalType { - case C.DUCKDB_TYPE_SMALLINT: + switch t { + case TYPE_SMALLINT: v := getPrimitive[int16](vec, rowIdx) val = big.NewInt(int64(v)) - case C.DUCKDB_TYPE_INTEGER: + case TYPE_INTEGER: v := getPrimitive[int32](vec, rowIdx) val = big.NewInt(int64(v)) - case C.DUCKDB_TYPE_BIGINT: + case TYPE_BIGINT: v := getPrimitive[int64](vec, rowIdx) val = big.NewInt(v) - case C.DUCKDB_TYPE_HUGEINT: + case TYPE_HUGEINT: v := getPrimitive[C.duckdb_hugeint](vec, rowIdx) val = hugeIntToNative(C.duckdb_hugeint{ lower: v.lower, @@ -118,19 +117,19 @@ func (vec *vector) getDecimal(internalType C.duckdb_type, rowIdx C.idx_t) Decima }) } - return Decimal{Width: vec.width, Scale: vec.scale, Value: val} + return Decimal{Width: vec.decimalWidth, Scale: vec.decimalScale, Value: val} } -func (vec *vector) getEnum(internalType C.duckdb_type, rowIdx C.idx_t) string { +func (vec *vector) getEnum(t Type, rowIdx C.idx_t) string { var idx uint64 - switch internalType { - case C.DUCKDB_TYPE_UTINYINT: + switch t { + case TYPE_UTINYINT: idx = uint64(getPrimitive[uint8](vec, rowIdx)) - case C.DUCKDB_TYPE_USMALLINT: + case TYPE_USMALLINT: idx = 
uint64(getPrimitive[uint16](vec, rowIdx)) - case C.DUCKDB_TYPE_UINTEGER: + case TYPE_UINTEGER: idx = uint64(getPrimitive[uint32](vec, rowIdx)) - case C.DUCKDB_TYPE_UBIGINT: + case TYPE_UBIGINT: idx = getPrimitive[uint64](vec, rowIdx) } @@ -145,11 +144,11 @@ func (vec *vector) getEnum(internalType C.duckdb_type, rowIdx C.idx_t) string { func (vec *vector) getList(rowIdx C.idx_t) []any { entry := getPrimitive[duckdb_list_entry_t](vec, rowIdx) slice := make([]any, 0, entry.length) - childVector := &vec.childVectors[0] + child := &vec.childVectors[0] // Fill the slice with all child values. for i := C.idx_t(0); i < entry.length; i++ { - val := childVector.getFn(childVector, i+entry.offset) + val := child.getFn(child, i+entry.offset) slice = append(slice, val) } return slice @@ -158,9 +157,9 @@ func (vec *vector) getList(rowIdx C.idx_t) []any { func (vec *vector) getStruct(rowIdx C.idx_t) map[string]any { m := map[string]any{} for i := 0; i < len(vec.childVectors); i++ { - childVector := &vec.childVectors[i] - val := childVector.getFn(childVector, rowIdx) - m[vec.childNames[i]] = val + child := &vec.childVectors[i] + val := child.getFn(child, rowIdx) + m[vec.structEntries[i].Name()] = val } return m } diff --git a/vector_setters.go b/vector_setters.go index cd090a32..5a9ba099 100644 --- a/vector_setters.go +++ b/vector_setters.go @@ -1,7 +1,6 @@ package duckdb /* -#include #include */ import "C" @@ -21,7 +20,7 @@ type fnSetVectorValue func(vec *vector, rowIdx C.idx_t, val any) func (vec *vector) setNull(rowIdx C.idx_t) { C.duckdb_validity_set_row_invalid(vec.mask, rowIdx) - if vec.duckdbType == C.DUCKDB_TYPE_STRUCT { + if vec.Type == TYPE_STRUCT { for i := 0; i < len(vec.childVectors); i++ { vec.childVectors[i].setNull(rowIdx) } @@ -33,19 +32,19 @@ func setPrimitive[T any](vec *vector, rowIdx C.idx_t, v T) { xs[rowIdx] = v } -func (vec *vector) setTS(duckdbType C.duckdb_type, rowIdx C.idx_t, val any) { +func (vec *vector) setTS(t Type, rowIdx C.idx_t, val any) { v := 
val.(time.Time) var ticks int64 - switch duckdbType { - case C.DUCKDB_TYPE_TIMESTAMP: + switch t { + case TYPE_TIMESTAMP: ticks = v.UTC().UnixMicro() - case C.DUCKDB_TYPE_TIMESTAMP_S: + case TYPE_TIMESTAMP_S: ticks = v.UTC().Unix() - case C.DUCKDB_TYPE_TIMESTAMP_MS: + case TYPE_TIMESTAMP_MS: ticks = v.UTC().UnixMilli() - case C.DUCKDB_TYPE_TIMESTAMP_NS: + case TYPE_TIMESTAMP_NS: ticks = v.UTC().UnixNano() - case C.DUCKDB_TYPE_TIMESTAMP_TZ: + case TYPE_TIMESTAMP_TZ: ticks = v.UTC().UnixMicro() } @@ -90,45 +89,45 @@ func (vec *vector) setHugeint(rowIdx C.idx_t, val any) { func (vec *vector) setCString(rowIdx C.idx_t, val any) { var str string - if vec.duckdbType == C.DUCKDB_TYPE_VARCHAR { + if vec.Type == TYPE_VARCHAR { str = val.(string) - } else if vec.duckdbType == C.DUCKDB_TYPE_BLOB { + } else if vec.Type == TYPE_BLOB { str = string(val.([]byte)[:]) } // This setter also writes BLOBs. cStr := C.CString(str) C.duckdb_vector_assign_string_element_len(vec.duckdbVector, rowIdx, cStr, C.idx_t(len(str))) - C.free(unsafe.Pointer(cStr)) + C.duckdb_free(unsafe.Pointer(cStr)) } -func (vec *vector) setDecimal(internalType C.duckdb_type, rowIdx C.idx_t, val any) { +func (vec *vector) setDecimal(t Type, rowIdx C.idx_t, val any) { v := val.(Decimal) - switch internalType { - case C.DUCKDB_TYPE_SMALLINT: + switch t { + case TYPE_SMALLINT: setPrimitive(vec, rowIdx, int16(v.Value.Int64())) - case C.DUCKDB_TYPE_INTEGER: + case TYPE_INTEGER: setPrimitive(vec, rowIdx, int32(v.Value.Int64())) - case C.DUCKDB_TYPE_BIGINT: + case TYPE_BIGINT: setPrimitive(vec, rowIdx, v.Value.Int64()) - case C.DUCKDB_TYPE_HUGEINT: + case TYPE_HUGEINT: value, _ := hugeIntFromNative(v.Value) setPrimitive(vec, rowIdx, value) } } -func (vec *vector) setEnum(internalType C.duckdb_type, rowIdx C.idx_t, val any) { +func (vec *vector) setEnum(t Type, rowIdx C.idx_t, val any) { v := vec.dict[val.(string)] - switch internalType { - case C.DUCKDB_TYPE_UTINYINT: + switch t { + case TYPE_UTINYINT: setPrimitive(vec, 
rowIdx, uint8(v)) - case C.DUCKDB_TYPE_USMALLINT: + case TYPE_USMALLINT: setPrimitive(vec, rowIdx, uint16(v)) - case C.DUCKDB_TYPE_UINTEGER: + case TYPE_UINTEGER: setPrimitive(vec, rowIdx, v) - case C.DUCKDB_TYPE_UBIGINT: + case TYPE_UBIGINT: setPrimitive(vec, rowIdx, uint64(v)) } } @@ -159,9 +158,9 @@ func (vec *vector) setList(rowIdx C.idx_t, val any) { func (vec *vector) setStruct(rowIdx C.idx_t, val any) { m := val.(map[string]any) for i := 0; i < len(vec.childVectors); i++ { - childVector := &vec.childVectors[i] - childName := vec.childNames[i] - childVector.setFn(childVector, rowIdx, m[childName]) + child := &vec.childVectors[i] + name := vec.structEntries[i].Name() + child.setFn(child, rowIdx, m[name]) } }