Skip to content

Commit

Permalink
refactor Connector to expose Close method
Browse files Browse the repository at this point in the history
  • Loading branch information
levakin committed Jan 23, 2024
1 parent badd00c commit fa320f4
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 24 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ connector, err := duckdb.NewConnector("test.db", nil)
if err != nil {
...
}
defer connector.Close()

conn, err := connector.Connect(context.Background())
if err != nil {
...
Expand Down Expand Up @@ -108,6 +110,8 @@ connector, err := duckdb.NewConnector("", nil)
if err != nil {
...
}
defer connector.Close()

conn, err := connector.Connect(context.Background())
if err != nil {
...
Expand Down
2 changes: 2 additions & 0 deletions arrow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ func TestArrow(t *testing.T) {
t.Run("select series", func(t *testing.T) {
c, err := NewConnector("", nil)
require.NoError(t, err)
defer c.Close()

conn, err := c.Connect(context.Background())
require.NoError(t, err)
Expand All @@ -46,6 +47,7 @@ func TestArrow(t *testing.T) {
t.Run("select long series", func(t *testing.T) {
c, err := NewConnector("", nil)
require.NoError(t, err)
defer c.Close()

conn, err := c.Connect(context.Background())
require.NoError(t, err)
Expand Down
53 changes: 30 additions & 23 deletions duckdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,17 @@ func (d Driver) Open(dataSourceName string) (driver.Conn, error) {
return connector.Connect(context.Background())
}

func (Driver) OpenConnector(dataSourceName string) (driver.Connector, error) {
return createConnector(dataSourceName, func(execerContext driver.ExecerContext) error { return nil })
func (Driver) OpenConnector(dsn string) (driver.Connector, error) {
return NewConnector(dsn, func(execerContext driver.ExecerContext) error { return nil })
}

// NewConnector creates a new Connector for the DuckDB database.
func NewConnector(dsn string, connInitFn func(execer driver.ExecerContext) error) (driver.Connector, error) {
return createConnector(dsn, connInitFn)
}

func createConnector(dataSourceName string, connInitFn func(execer driver.ExecerContext) error) (driver.Connector, error) {
// NewConnector opens a new Connector for the DuckDB database.
// It's user's responsibility to close the returned Connector in case it's not passed to the sql.OpenDB function.
// sql.DB will close the Connector when sql.DB.Close() is called.
func NewConnector(dsn string, connInitFn func(execer driver.ExecerContext) error) (*Connector, error) {
var db C.duckdb_database

parsedDSN, err := url.Parse(dataSourceName)
parsedDSN, err := url.Parse(dsn)
if err != nil {
return nil, fmt.Errorf("%w: %s", errParseConfig, err.Error())
}
Expand All @@ -57,7 +55,7 @@ func createConnector(dataSourceName string, connInitFn func(execer driver.Execer
}
defer C.duckdb_destroy_config(&config)

connectionString := C.CString(extractConnectionString(dataSourceName))
connectionString := C.CString(extractConnectionString(dsn))
defer C.free(unsafe.Pointer(connectionString))

var errMsg *C.char
Expand All @@ -67,23 +65,24 @@ func createConnector(dataSourceName string, connInitFn func(execer driver.Execer
return nil, fmt.Errorf("%w: %s", errOpen, C.GoString(errMsg))
}

return &connector{db: &db, connInitFn: connInitFn}, nil
return &Connector{db: &db, connInitFn: connInitFn}, nil
}

type connector struct {
type Connector struct {
db *C.duckdb_database
connInitFn func(execer driver.ExecerContext) error
}

func (c *connector) Driver() driver.Driver {
func (c *Connector) Driver() driver.Driver {
return Driver{}
}

func (c *connector) Connect(context.Context) (driver.Conn, error) {
func (c *Connector) Connect(context.Context) (driver.Conn, error) {
var con C.duckdb_connection
if state := C.duckdb_connect(*c.db, &con); state == C.DuckDBError {
return nil, errOpen
}

conn := &conn{con: &con}

// Call the connection init function if defined
Expand All @@ -92,10 +91,11 @@ func (c *connector) Connect(context.Context) (driver.Conn, error) {
return nil, err
}
}

return conn, nil
}

func (c *connector) Close() error {
func (c *Connector) Close() error {
C.duckdb_close(c.db)
c.db = nil
return nil
Expand All @@ -115,15 +115,14 @@ func prepareConfig(parsedDSN *url.URL) (C.duckdb_config, error) {
return nil, errCreateConfig
}
if state := C.duckdb_set_config(config, C.CString("duckdb_api"), C.CString("go")); state == C.DuckDBError {
return nil, fmt.Errorf("%w: failed to set duckdb_api", errPrepareConfig)
return nil, fmt.Errorf("%w: failed to set duckdb_api", errSetConfig)
}

if len(parsedDSN.RawQuery) > 0 {
for k, v := range parsedDSN.Query() {
if len(v) > 0 {
state := C.duckdb_set_config(config, C.CString(k), C.CString(v[0]))
if state == C.DuckDBError {
return nil, fmt.Errorf("%w: affected config option %s=%s", errPrepareConfig, k, v[0])
if err := setConfig(config, k, v[0]); err != nil {
return nil, err
}
}
}
Expand All @@ -132,9 +131,17 @@ func prepareConfig(parsedDSN *url.URL) (C.duckdb_config, error) {
return config, nil
}

func setConfig(config C.duckdb_config, name, option string) error {
if state := C.duckdb_set_config(config, C.CString(name), C.CString(option)); state == C.DuckDBError {
return fmt.Errorf("%w: affected config option %s=%s", errSetConfig, name, option)
}

return nil
}

var (
errOpen = errors.New("could not open database")
errParseConfig = errors.New("could not parse config for database")
errCreateConfig = errors.New("could not create config for database")
errPrepareConfig = errors.New("could not set config for database")
errOpen = errors.New("could not open database")
errParseConfig = errors.New("could not parse config for database")
errCreateConfig = errors.New("could not create config for database")
errSetConfig = errors.New("could not set config for database")
)
2 changes: 1 addition & 1 deletion duckdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func TestOpen(t *testing.T) {
t.Run("with invalid config", func(t *testing.T) {
_, err := sql.Open("duckdb", "?threads=NaN")

if !errors.Is(err, errPrepareConfig) {
if !errors.Is(err, errSetConfig) {
t.Fatal("invalid config should not be accepted")
}
})
Expand Down

0 comments on commit fa320f4

Please sign in to comment.