Skip to content

Commit

Permalink
Merge pull request #272 from taniabogatsch/type-interface
Browse files Browse the repository at this point in the history
[Feature] Type interface
  • Loading branch information
taniabogatsch authored Sep 17, 2024
2 parents 61ee2b4 + 647811c commit 7accec1
Show file tree
Hide file tree
Showing 17 changed files with 888 additions and 415 deletions.
13 changes: 6 additions & 7 deletions appender.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package duckdb

/*
#include <stdlib.h>
#include <duckdb.h>
*/
import "C"
Expand Down Expand Up @@ -43,11 +42,11 @@ func NewAppenderFromConn(driverConn driver.Conn, schema, table string) (*Appende
var cSchema *C.char
if schema != "" {
cSchema = C.CString(schema)
defer C.free(unsafe.Pointer(cSchema))
defer C.duckdb_free(unsafe.Pointer(cSchema))
}

cTable := C.CString(table)
defer C.free(unsafe.Pointer(cTable))
defer C.duckdb_free(unsafe.Pointer(cTable))

var duckdbAppender C.duckdb_appender
state := C.duckdb_appender_create(con.duckdbCon, cSchema, cTable, &duckdbAppender)
Expand All @@ -74,10 +73,10 @@ func NewAppenderFromConn(driverConn driver.Conn, schema, table string) (*Appende
a.types[i] = C.duckdb_appender_column_type(duckdbAppender, C.idx_t(i))

// Ensure that we only create an appender for supported column types.
duckdbType := C.duckdb_get_type_id(a.types[i])
name, found := unsupportedTypeMap[duckdbType]
t := Type(C.duckdb_get_type_id(a.types[i]))
name, found := unsupportedTypeToStringMap[t]
if found {
err := columnError(unsupportedTypeError(name), i+1)
err := addIndexToError(unsupportedTypeError(name), i+1)
destroyTypeSlice(a.ptr, a.types)
C.duckdb_appender_destroy(&duckdbAppender)
return nil, getError(errAppenderCreation, err)
Expand Down Expand Up @@ -229,5 +228,5 @@ func destroyTypeSlice(ptr unsafe.Pointer, slice []C.duckdb_logical_type) {
for _, t := range slice {
C.duckdb_destroy_logical_type(&t)
}
C.free(ptr)
C.duckdb_free(ptr)
}
13 changes: 6 additions & 7 deletions connection.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package duckdb

/*
#include <stdlib.h>
#include <duckdb.h>
*/
import "C"
Expand Down Expand Up @@ -152,11 +151,11 @@ func (c *conn) Close() error {
}

func (c *conn) prepareStmt(cmd string) (*stmt, error) {
cmdstr := C.CString(cmd)
defer C.free(unsafe.Pointer(cmdstr))
cmdStr := C.CString(cmd)
defer C.duckdb_free(unsafe.Pointer(cmdStr))

var s C.duckdb_prepared_statement
if state := C.duckdb_prepare(c.duckdbCon, cmdstr, &s); state == C.DuckDBError {
if state := C.duckdb_prepare(c.duckdbCon, cmdStr, &s); state == C.DuckDBError {
dbErr := getDuckDBError(C.GoString(C.duckdb_prepare_error(s)))
C.duckdb_destroy_prepare(&s)
return nil, dbErr
Expand All @@ -166,11 +165,11 @@ func (c *conn) prepareStmt(cmd string) (*stmt, error) {
}

func (c *conn) extractStmts(query string) (C.duckdb_extracted_statements, C.idx_t, error) {
cquery := C.CString(query)
defer C.free(unsafe.Pointer(cquery))
cQuery := C.CString(query)
defer C.duckdb_free(unsafe.Pointer(cQuery))

var stmts C.duckdb_extracted_statements
stmtsCount := C.duckdb_extract_statements(c.duckdbCon, cquery, &stmts)
stmtsCount := C.duckdb_extract_statements(c.duckdbCon, cQuery, &stmts)
if stmtsCount == 0 {
err := C.GoString(C.duckdb_extract_statements_error(stmts))
C.duckdb_destroy_extracted(&stmts)
Expand Down
3 changes: 1 addition & 2 deletions data_chunk.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package duckdb

/*
#include <stdlib.h>
#include <duckdb.h>
*/
import "C"
Expand Down Expand Up @@ -66,7 +65,7 @@ func (chunk *DataChunk) SetValue(colIdx int, rowIdx int, val any) error {
// FIXME: Maybe we can make columnar insertions unsafe, i.e., we always assume a correct type.
v, err := column.tryCast(val)
if err != nil {
return columnError(err, colIdx)
return addIndexToError(err, colIdx)
}

// Set the value.
Expand Down
7 changes: 3 additions & 4 deletions duckdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package duckdb

/*
#include <stdlib.h>
#include <duckdb.h>
*/
import "C"
Expand Down Expand Up @@ -58,7 +57,7 @@ func NewConnector(dsn string, connInitFn func(execer driver.ExecerContext) error
defer C.duckdb_destroy_config(&config)

connStr := C.CString(getConnString(dsn))
defer C.free(unsafe.Pointer(connStr))
defer C.duckdb_free(unsafe.Pointer(connStr))

var outError *C.char
defer C.duckdb_free(unsafe.Pointer(outError))
Expand Down Expand Up @@ -143,10 +142,10 @@ func prepareConfig(parsedDSN *url.URL) (C.duckdb_config, error) {

func setConfigOption(config C.duckdb_config, name string, option string) error {
cName := C.CString(name)
defer C.free(unsafe.Pointer(cName))
defer C.duckdb_free(unsafe.Pointer(cName))

cOption := C.CString(option)
defer C.free(unsafe.Pointer(cOption))
defer C.duckdb_free(unsafe.Pointer(cOption))

state := C.duckdb_set_config(config, cName, cOption)
if state == C.DuckDBError {
Expand Down
19 changes: 9 additions & 10 deletions duckdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -256,7 +255,7 @@ func TestQuery(t *testing.T) {
for rows.Next() {
var i int
require.NoError(t, rows.Scan(&i))
assert.Equal(t, expected, i)
require.Equal(t, expected, i)
expected++
}
})
Expand All @@ -281,10 +280,8 @@ func TestQuery(t *testing.T) {
func TestJSON(t *testing.T) {
t.Parallel()
db := openDB(t)
defer db.Close()

loadJSONExt(t, db)

var data string

t.Run("select empty JSON", func(t *testing.T) {
Expand All @@ -311,18 +308,20 @@ func TestJSON(t *testing.T) {
require.Equal(t, len(items), 2)
require.Equal(t, items, []string{"foo", "bar"})
})

require.NoError(t, db.Close())
}

func TestEmpty(t *testing.T) {
t.Parallel()
db := openDB(t)
defer db.Close()

rows, err := db.Query(`SELECT 1 WHERE 1 = 0`)
require.NoError(t, err)
defer rows.Close()
require.False(t, rows.Next())
require.NoError(t, rows.Err())
require.NoError(t, db.Close())
}

func TestTypeNamesAndScanTypes(t *testing.T) {
Expand Down Expand Up @@ -519,12 +518,12 @@ func TestTypeNamesAndScanTypes(t *testing.T) {
require.Equal(t, rows.Next(), false)
})
}
require.NoError(t, db.Close())
}

// Running multiple statements in a single query. All statements except the last one are executed and if no error then last statement is executed with args and result returned.
func TestMultipleStatements(t *testing.T) {
db := openDB(t)
defer db.Close()

// test empty query
_, err := db.Exec("")
Expand Down Expand Up @@ -617,13 +616,12 @@ func TestMultipleStatements(t *testing.T) {
err = rows.Close()
require.NoError(t, err)

err = conn.Close()
require.NoError(t, err)
require.NoError(t, conn.Close())
require.NoError(t, db.Close())
}

func TestParquetExtension(t *testing.T) {
db := openDB(t)
defer db.Close()

_, err := db.Exec("CREATE TABLE users (id int, name varchar, age int);")
require.NoError(t, err)
Expand All @@ -647,11 +645,11 @@ func TestParquetExtension(t *testing.T) {

err = os.Remove("./users.parquet")
require.NoError(t, err)
require.NoError(t, db.Close())
}

func TestQueryTimeout(t *testing.T) {
db := openDB(t)
defer db.Close()

ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*250)
defer cancel()
Expand All @@ -663,6 +661,7 @@ func TestQueryTimeout(t *testing.T) {
// a very defensive time check, but should be good enough
// the query takes much longer than 10 seconds
require.Less(t, time.Since(now), 10*time.Second)
require.NoError(t, db.Close())
}

func openDB(t *testing.T) *sql.DB {
Expand Down
30 changes: 25 additions & 5 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,6 @@ func structFieldError(actual string, expected string) error {
return fmt.Errorf("%s: expected %s, got %s", structFieldErrMsg, expected, actual)
}

func columnError(err error, colIdx int) error {
return fmt.Errorf("%w: %s: %d", err, columnErrMsg, colIdx)
}

func columnCountError(actual int, expected int) error {
return fmt.Errorf("%s: expected %d, got %d", columnCountErrMsg, expected, actual)
}
Expand All @@ -45,15 +41,35 @@ func invalidatedAppenderError(err error) error {
return fmt.Errorf("%w: %s", err, invalidatedAppenderMsg)
}

func tryOtherFuncError(hint string) error {
return fmt.Errorf("%s: %s", tryOtherFuncErrMsg, hint)
}

func addIndexToError(err error, idx int) error {
return fmt.Errorf("%w: %s: %d", err, indexErrMsg, idx)
}

func interfaceIsNilError(interfaceName string) error {
return fmt.Errorf("%s: %s", interfaceIsNilErrMsg, interfaceName)
}

func duplicateNameError(name string) error {
return fmt.Errorf("%s: %s", duplicateNameErrMsg, name)
}

const (
driverErrMsg = "database/sql/driver"
duckdbErrMsg = "duckdb error"
castErrMsg = "cast error"
structFieldErrMsg = "invalid STRUCT field"
columnErrMsg = "column index"
columnCountErrMsg = "invalid column count"
unsupportedTypeErrMsg = "unsupported data type"
invalidatedAppenderMsg = "appended data has been invalidated due to corrupt row"
tryOtherFuncErrMsg = "please try this function instead"
indexErrMsg = "index"
unknownTypeErrMsg = "unknown type"
interfaceIsNilErrMsg = "interface is nil"
duplicateNameErrMsg = "duplicate name"
)

var (
Expand All @@ -75,6 +91,10 @@ var (
errAppenderClose = errors.New("could not close appender")
errAppenderFlush = errors.New("could not flush appender")

errEmptyName = errors.New("empty name")
errInvalidDecimalWidth = fmt.Errorf("the DECIMAL with must be between 1 and %d", MAX_DECIMAL_WIDTH)
errInvalidDecimalScale = errors.New("the DECIMAL scale must be less than or equal to the width")

// Errors not covered in tests.
errConnect = errors.New("could not connect to database")
errCreateConfig = errors.New("could not create config for database")
Expand Down
Loading

0 comments on commit 7accec1

Please sign in to comment.