From 4ae69e67e26b67d2df15d34800a7d08ccfc7b534 Mon Sep 17 00:00:00 2001 From: taniabogatsch <44262898+taniabogatsch@users.noreply.github.com> Date: Thu, 25 Jan 2024 13:50:16 +0100 Subject: [PATCH] more refactoring and leak fixes --- duckdb.go | 95 ++++++++++++++++++++++++++++---------------------- duckdb_test.go | 11 ++++++ 2 files changed, 64 insertions(+), 42 deletions(-) diff --git a/duckdb.go b/duckdb.go index f4f04e39..84769c39 100644 --- a/duckdb.go +++ b/duckdb.go @@ -26,8 +26,8 @@ func init() { type Driver struct{} -func (d Driver) Open(dataSourceName string) (driver.Conn, error) { - connector, err := d.OpenConnector(dataSourceName) +func (d Driver) Open(dsn string) (driver.Conn, error) { + connector, err := d.OpenConnector(dsn) if err != nil { return nil, err } @@ -35,47 +35,47 @@ func (d Driver) Open(dataSourceName string) (driver.Conn, error) { } func (Driver) OpenConnector(dsn string) (driver.Connector, error) { - return NewConnector(dsn, func(execerContext driver.ExecerContext) error { return nil }) + return NewConnector(dsn, func(execerContext driver.ExecerContext) error { + return nil + }) } -// NewConnector opens a new Connector for the DuckDB database. -// It's user's responsibility to close the returned Connector in case it's not passed to the sql.OpenDB function. -// sql.DB will close the Connector when sql.DB.Close() is called. +// NewConnector opens a new Connector for a DuckDB database. +// The user must close the returned Connector, if it is not passed to the sql.OpenDB function. +// Otherwise, sql.DB closes the Connector when calling sql.DB.Close(). func NewConnector(dsn string, connInitFn func(execer driver.ExecerContext) error) (*Connector, error) { + var db C.duckdb_database parsedDSN, err := url.Parse(dsn) if err != nil { - return nil, fmt.Errorf("%w: %s", errParseConfig, err.Error()) + return nil, fmt.Errorf("%w: %s", errParseDSN, err.Error()) } config, err := prepareConfig(parsedDSN) if err != nil { return nil, err } + defer C.duckdb_destroy_config(&config) - connectionString := C.CString(extractConnectionString(dsn)) - defer C.free(unsafe.Pointer(connectionString)) - - var errMsg *C.char - defer C.duckdb_free(unsafe.Pointer(errMsg)) + connStr := C.CString(extractConnectionString(dsn)) + defer C.free(unsafe.Pointer(connStr)) - if state := C.duckdb_open_ext(connectionString, &db, config, &errMsg); state == C.DuckDBError { - C.duckdb_destroy_config(&config) + var errOpenMsg *C.char + defer C.duckdb_free(unsafe.Pointer(errOpenMsg)) - return nil, fmt.Errorf("%w: %s", errOpen, C.GoString(errMsg)) + if state := C.duckdb_open_ext(connStr, &db, config, &errOpenMsg); state == C.DuckDBError { + return nil, fmt.Errorf("%w: %s", errOpen, C.GoString(errOpenMsg)) } return &Connector{ - db: &db, + db: db, connInitFn: connInitFn, - config: config, }, nil } type Connector struct { - db *C.duckdb_database - config C.duckdb_config + db C.duckdb_database connInitFn func(execer driver.ExecerContext) error } @@ -84,8 +84,9 @@ func (c *Connector) Driver() driver.Driver { } func (c *Connector) Connect(context.Context) (driver.Conn, error) { + var con C.duckdb_connection - if state := C.duckdb_connect(*c.db, &con); state == C.DuckDBError { + if state := C.duckdb_connect(c.db, &con); state == C.DuckDBError { return nil, errOpen } @@ -102,50 +103,60 @@ func (c *Connector) Connect(context.Context) (driver.Conn, error) { } func (c *Connector) Close() error { - C.duckdb_close(c.db) + C.duckdb_close(&c.db) c.db = nil - - C.duckdb_destroy_config(&c.config) - c.config = nil - return nil } -func extractConnectionString(dataSourceName string) string { - var queryIndex = strings.Index(dataSourceName, "?") +func extractConnectionString(dsn string) string { + var queryIndex = strings.Index(dsn, "?") if queryIndex < 0 { - queryIndex = len(dataSourceName) + queryIndex = len(dsn) } - return dataSourceName[0:queryIndex] + return dsn[0:queryIndex] } func prepareConfig(parsedDSN *url.URL) (C.duckdb_config, error) { + var config C.duckdb_config if state := C.duckdb_create_config(&config); state == C.DuckDBError { + C.duckdb_destroy_config(&config) return nil, errCreateConfig } - if state := C.duckdb_set_config(config, C.CString("duckdb_api"), C.CString("go")); state == C.DuckDBError { - return nil, fmt.Errorf("%w: failed to set duckdb_api", errSetConfig) + + if err := setConfig(config, "duckdb_api", "go"); err != nil { + return nil, err } - if len(parsedDSN.RawQuery) > 0 { - for k, v := range parsedDSN.Query() { - if len(v) > 0 { - if err := setConfig(config, k, v[0]); err != nil { - C.duckdb_destroy_config(&config) + // early-out + if len(parsedDSN.RawQuery) == 0 { + return config, nil + } - return nil, err - } - } + for k, v := range parsedDSN.Query() { + if len(v) == 0 { + continue + } + if err := setConfig(config, k, v[0]); err != nil { + return nil, err } } return config, nil } -func setConfig(config C.duckdb_config, name, option string) error { - if state := C.duckdb_set_config(config, C.CString(name), C.CString(option)); state == C.DuckDBError { +func setConfig(config C.duckdb_config, name string, option string) error { + + cName := C.CString(name) + defer C.free(unsafe.Pointer(cName)) + + cOption := C.CString(option) + defer C.free(unsafe.Pointer(cOption)) + + state := C.duckdb_set_config(config, cName, cOption) + if state == C.DuckDBError { + C.duckdb_destroy_config(&config) return fmt.Errorf("%w: affected config option %s=%s", errSetConfig, name, option) } @@ -154,7 +165,7 @@ func setConfig(config C.duckdb_config, name, option string) error { var ( errOpen = errors.New("could not open database") - errParseConfig = errors.New("could not parse config for database") + errParseDSN = errors.New("could not parse DSN for database") errCreateConfig = errors.New("could not create config for database") errSetConfig = errors.New("could not set config for database") ) diff --git a/duckdb_test.go b/duckdb_test.go index 9172b2c4..2fd2bd67 100644 --- a/duckdb_test.go +++ b/duckdb_test.go @@ -62,6 +62,17 @@ func TestOpen(t *testing.T) { }) } +func TestConnector_Close(t *testing.T) { + t.Parallel() + + connector, err := NewConnector("", nil) + require.NoError(t, err) + + // check that multiple close calls don't cause panics or errors + require.NoError(t, connector.Close()) + require.NoError(t, connector.Close()) +} + func TestConnPool(t *testing.T) { db := openDB(t) db.SetMaxOpenConns(2) // set connection pool size greater than 1