diff --git a/README.md b/README.md index ab733d06..3703e342 100644 --- a/README.md +++ b/README.md @@ -74,6 +74,8 @@ connector, err := duckdb.NewConnector("test.db", nil) if err != nil { ... } +defer connector.Close() + conn, err := connector.Connect(context.Background()) if err != nil { ... @@ -108,6 +110,8 @@ connector, err := duckdb.NewConnector("", nil) if err != nil { ... } +defer connector.Close() + conn, err := connector.Connect(context.Background()) if err != nil { ... diff --git a/arrow_test.go b/arrow_test.go index 7e6ee4cd..b659d925 100644 --- a/arrow_test.go +++ b/arrow_test.go @@ -22,6 +22,7 @@ func TestArrow(t *testing.T) { t.Run("select series", func(t *testing.T) { c, err := NewConnector("", nil) require.NoError(t, err) + defer c.Close() conn, err := c.Connect(context.Background()) require.NoError(t, err) @@ -46,6 +47,7 @@ func TestArrow(t *testing.T) { t.Run("select long series", func(t *testing.T) { c, err := NewConnector("", nil) require.NoError(t, err) + defer c.Close() conn, err := c.Connect(context.Background()) require.NoError(t, err) diff --git a/duckdb.go b/duckdb.go index f9b2514f..3b4ec620 100644 --- a/duckdb.go +++ b/duckdb.go @@ -34,24 +34,22 @@ func (d Driver) Open(dataSourceName string) (driver.Conn, error) { return connector.Connect(context.Background()) } -func (Driver) OpenConnector(dataSourceName string) (driver.Connector, error) { - return createConnector(dataSourceName, func(execerContext driver.ExecerContext) error { return nil }) +func (Driver) OpenConnector(dsn string) (driver.Connector, error) { + return NewConnector(dsn, func(execerContext driver.ExecerContext) error { return nil }) } -// NewConnector creates a new Connector for the DuckDB database. -func NewConnector(dsn string, connInitFn func(execer driver.ExecerContext) error) (driver.Connector, error) { - return createConnector(dsn, connInitFn) -} - -func createConnector(dataSourceName string, connInitFn func(execer driver.ExecerContext) error) (driver.Connector, error) { +// NewConnector opens a new Connector for the DuckDB database. +// It's user's responsibility to close the returned Connector in case it's not passed to the sql.OpenDB function. +// sql.DB will close the Connector when sql.DB.Close() is called. +func NewConnector(dsn string, connInitFn func(execer driver.ExecerContext) error) (*Connector, error) { var db C.duckdb_database - parsedDSN, err := url.Parse(dataSourceName) + parsedDSN, err := url.Parse(dsn) if err != nil { return nil, fmt.Errorf("%w: %s", errParseConfig, err.Error()) } - connectionString := C.CString(extractConnectionString(dataSourceName)) + connectionString := C.CString(extractConnectionString(dsn)) defer C.free(unsafe.Pointer(connectionString)) // Check for config options. @@ -76,23 +74,24 @@ func createConnector(dataSourceName string, connInitFn func(execer driver.Execer } } - return &connector{db: &db, connInitFn: connInitFn}, nil + return &Connector{db: &db, connInitFn: connInitFn}, nil } -type connector struct { +type Connector struct { db *C.duckdb_database connInitFn func(execer driver.ExecerContext) error } -func (c *connector) Driver() driver.Driver { +func (c *Connector) Driver() driver.Driver { return Driver{} } -func (c *connector) Connect(context.Context) (driver.Conn, error) { +func (c *Connector) Connect(context.Context) (driver.Conn, error) { var con C.duckdb_connection if state := C.duckdb_connect(*c.db, &con); state == C.DuckDBError { return nil, errOpen } + conn := &conn{con: &con} // Call the connection init function if defined @@ -101,10 +100,11 @@ func (c *connector) Connect(context.Context) (driver.Conn, error) { return nil, err } } + return conn, nil } -func (c *connector) Close() error { +func (c *Connector) Close() error { C.duckdb_close(c.db) c.db = nil return nil @@ -126,9 +126,8 @@ func prepareConfig(options map[string][]string) (C.duckdb_config, error) { for k, v := range options { if len(v) > 0 { - state := C.duckdb_set_config(config, C.CString(k), C.CString(v[0])) - if state == C.DuckDBError { - return nil, fmt.Errorf("%w: affected config option %s=%s", errPrepareConfig, k, v[0]) + if err := setConfig(config, k, v[0]); err != nil { + return nil, err } } } @@ -136,9 +135,23 @@ func prepareConfig(options map[string][]string) (C.duckdb_config, error) { return config, nil } +func setConfig(config C.duckdb_config, name, option string) error { + cName := C.CString(name) + defer C.free(unsafe.Pointer(cName)) + + cOption := C.CString(option) + defer C.free(unsafe.Pointer(cOption)) + + if state := C.duckdb_set_config(config, cName, cOption); state == C.DuckDBError { + return fmt.Errorf("%w: affected config option %s=%s", errSetConfig, name, option) + } + + return nil +} + var ( - errOpen = errors.New("could not open database") - errParseConfig = errors.New("could not parse config for database") - errCreateConfig = errors.New("could not create config for database") - errPrepareConfig = errors.New("could not set config for database") + errOpen = errors.New("could not open database") + errParseConfig = errors.New("could not parse config for database") + errCreateConfig = errors.New("could not create config for database") + errSetConfig = errors.New("could not set config for database") ) diff --git a/duckdb_test.go b/duckdb_test.go index d9303973..9172b2c4 100644 --- a/duckdb_test.go +++ b/duckdb_test.go @@ -56,7 +56,7 @@ func TestOpen(t *testing.T) { t.Run("with invalid config", func(t *testing.T) { _, err := sql.Open("duckdb", "?threads=NaN") - if !errors.Is(err, errPrepareConfig) { + if !errors.Is(err, errSetConfig) { t.Fatal("invalid config should not be accepted") } })