Skip to content

Commit

Permalink
more refactoring and leak fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
taniabogatsch committed Jan 25, 2024
1 parent 9dfac4e commit 4ae69e6
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 42 deletions.
95 changes: 53 additions & 42 deletions duckdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,56 +26,56 @@ func init() {

type Driver struct{}

func (d Driver) Open(dataSourceName string) (driver.Conn, error) {
connector, err := d.OpenConnector(dataSourceName)
func (d Driver) Open(dsn string) (driver.Conn, error) {
connector, err := d.OpenConnector(dsn)
if err != nil {
return nil, err
}
return connector.Connect(context.Background())
}

func (Driver) OpenConnector(dsn string) (driver.Connector, error) {
return NewConnector(dsn, func(execerContext driver.ExecerContext) error { return nil })
return NewConnector(dsn, func(execerContext driver.ExecerContext) error {
return nil
})
}

// NewConnector opens a new Connector for the DuckDB database.
// It's user's responsibility to close the returned Connector in case it's not passed to the sql.OpenDB function.
// sql.DB will close the Connector when sql.DB.Close() is called.
// NewConnector opens a new Connector for a DuckDB database.
// The user must close the returned Connector, if it is not passed to the sql.OpenDB function.
// Otherwise, sql.DB closes the Connector when calling sql.DB.Close().
func NewConnector(dsn string, connInitFn func(execer driver.ExecerContext) error) (*Connector, error) {

var db C.duckdb_database

parsedDSN, err := url.Parse(dsn)
if err != nil {
return nil, fmt.Errorf("%w: %s", errParseConfig, err.Error())
return nil, fmt.Errorf("%w: %s", errParseDSN, err.Error())
}

config, err := prepareConfig(parsedDSN)
if err != nil {
return nil, err
}
defer C.duckdb_destroy_config(&config)

connectionString := C.CString(extractConnectionString(dsn))
defer C.free(unsafe.Pointer(connectionString))

var errMsg *C.char
defer C.duckdb_free(unsafe.Pointer(errMsg))
connStr := C.CString(extractConnectionString(dsn))
defer C.free(unsafe.Pointer(connStr))

if state := C.duckdb_open_ext(connectionString, &db, config, &errMsg); state == C.DuckDBError {
C.duckdb_destroy_config(&config)
var errOpenMsg *C.char
defer C.duckdb_free(unsafe.Pointer(errOpenMsg))

return nil, fmt.Errorf("%w: %s", errOpen, C.GoString(errMsg))
if state := C.duckdb_open_ext(connStr, &db, config, &errOpenMsg); state == C.DuckDBError {
return nil, fmt.Errorf("%w: %s", errOpen, C.GoString(errOpenMsg))
}

return &Connector{
db: &db,
db: db,
connInitFn: connInitFn,
config: config,
}, nil
}

type Connector struct {
db *C.duckdb_database
config C.duckdb_config
db C.duckdb_database
connInitFn func(execer driver.ExecerContext) error
}

Expand All @@ -84,8 +84,9 @@ func (c *Connector) Driver() driver.Driver {
}

func (c *Connector) Connect(context.Context) (driver.Conn, error) {

var con C.duckdb_connection
if state := C.duckdb_connect(*c.db, &con); state == C.DuckDBError {
if state := C.duckdb_connect(c.db, &con); state == C.DuckDBError {
return nil, errOpen
}

Expand All @@ -102,50 +103,60 @@ func (c *Connector) Connect(context.Context) (driver.Conn, error) {
}

func (c *Connector) Close() error {
C.duckdb_close(c.db)
C.duckdb_close(&c.db)
c.db = nil

C.duckdb_destroy_config(&c.config)
c.config = nil

return nil
}

func extractConnectionString(dataSourceName string) string {
var queryIndex = strings.Index(dataSourceName, "?")
func extractConnectionString(dsn string) string {
var queryIndex = strings.Index(dsn, "?")
if queryIndex < 0 {
queryIndex = len(dataSourceName)
queryIndex = len(dsn)
}

return dataSourceName[0:queryIndex]
return dsn[0:queryIndex]
}

func prepareConfig(parsedDSN *url.URL) (C.duckdb_config, error) {

var config C.duckdb_config
if state := C.duckdb_create_config(&config); state == C.DuckDBError {
C.duckdb_destroy_config(&config)
return nil, errCreateConfig
}
if state := C.duckdb_set_config(config, C.CString("duckdb_api"), C.CString("go")); state == C.DuckDBError {
return nil, fmt.Errorf("%w: failed to set duckdb_api", errSetConfig)

if err := setConfig(config, "duckdb_api", "go"); err != nil {
return nil, err
}

if len(parsedDSN.RawQuery) > 0 {
for k, v := range parsedDSN.Query() {
if len(v) > 0 {
if err := setConfig(config, k, v[0]); err != nil {
C.duckdb_destroy_config(&config)
// early-out
if len(parsedDSN.RawQuery) == 0 {
return config, nil
}

return nil, err
}
}
for k, v := range parsedDSN.Query() {
if len(v) == 0 {
continue
}
if err := setConfig(config, k, v[0]); err != nil {
return nil, err
}
}

return config, nil
}

func setConfig(config C.duckdb_config, name, option string) error {
if state := C.duckdb_set_config(config, C.CString(name), C.CString(option)); state == C.DuckDBError {
func setConfig(config C.duckdb_config, name string, option string) error {

cName := C.CString(name)
defer C.free(unsafe.Pointer(cName))

cOption := C.CString(option)
defer C.free(unsafe.Pointer(cOption))

state := C.duckdb_set_config(config, cName, cOption)
if state == C.DuckDBError {
C.duckdb_destroy_config(&config)
return fmt.Errorf("%w: affected config option %s=%s", errSetConfig, name, option)
}

Expand All @@ -154,7 +165,7 @@ func setConfig(config C.duckdb_config, name, option string) error {

var (
errOpen = errors.New("could not open database")
errParseConfig = errors.New("could not parse config for database")
errParseDSN = errors.New("could not parse DSN for database")
errCreateConfig = errors.New("could not create config for database")
errSetConfig = errors.New("could not set config for database")
)
11 changes: 11 additions & 0 deletions duckdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,17 @@ func TestOpen(t *testing.T) {
})
}

func TestConnector_Close(t *testing.T) {
t.Parallel()

connector, err := NewConnector("", nil)
require.NoError(t, err)

// check that multiple close calls don't cause panics or errors
require.NoError(t, connector.Close())
require.NoError(t, connector.Close())
}

func TestConnPool(t *testing.T) {
db := openDB(t)
db.SetMaxOpenConns(2) // set connection pool size greater than 1
Expand Down

0 comments on commit 4ae69e6

Please sign in to comment.