Skip to content

Commit

Permalink
fix : Ensure rows, conn and db are closed on all error paths (#6321)
Browse files Browse the repository at this point in the history
* ensure rows, conn and db are closed on all error paths

* revert otelsql change
  • Loading branch information
k-anshul committed Jan 6, 2025
1 parent 5b3de29 commit 1fff61e
Showing 1 changed file with 25 additions and 25 deletions.
50 changes: 25 additions & 25 deletions runtime/drivers/snowflake/warehouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"sync"
"time"

"github.com/XSAM/otelsql"
"github.com/apache/arrow/go/v15/arrow"
"github.com/apache/arrow/go/v15/arrow/memory"
"github.com/apache/arrow/go/v15/parquet"
Expand All @@ -32,7 +31,7 @@ const rowGroupBufferSize = int64(datasize.MB) * 512
// Fetches query result in arrow batches.
// As an alternative (or in case of memory issues) consider utilizing Snowflake "COPY INTO <location>" feature,
// see https://docs.snowflake.com/en/sql-reference/sql/copy-into-location
func (c *connection) QueryAsFiles(ctx context.Context, props map[string]any) (drivers.FileIterator, error) {
func (c *connection) QueryAsFiles(ctx context.Context, props map[string]any) (iter drivers.FileIterator, resErr error) {
srcProps, err := parseSourceProperties(props)
if err != nil {
return nil, err
Expand All @@ -52,29 +51,41 @@ func (c *connection) QueryAsFiles(ctx context.Context, props map[string]any) (dr
parallelFetchLimit = c.configProperties.ParallelFetchLimit
}

db, err := otelsql.Open("snowflake", dsn)
db, err := sql.Open("snowflake", dsn)
if err != nil {
return nil, err
}
defer func() {
if resErr != nil {
db.Close()
}
}()

ctx = sf.WithArrowAllocator(sf.WithArrowBatches(ctx), memory.DefaultAllocator)

conn, err := db.Conn(ctx)
if err != nil {
db.Close()
return nil, err
}
defer func() {
if resErr != nil {
conn.Close()
}
}()

var rows sqld.Rows
err = rawConn(conn, func(x sqld.Conn) error {
err = conn.Raw(func(x interface{}) error {
rows, err = x.(sqld.QueryerContext).QueryContext(ctx, srcProps.SQL, nil)
return err
})
if err != nil {
conn.Close()
db.Close()
return nil, err
}
defer func() {
if resErr != nil {
rows.Close()
}
}()

batches, err := rows.(sf.SnowflakeRows).GetArrowBatches()
if err != nil {
Expand Down Expand Up @@ -119,6 +130,11 @@ type fileIterator struct {

// Close implements drivers.FileIterator.
func (f *fileIterator) Close() error {
if f.rows != nil {
f.rows.Close()
f.conn.Close()
f.db.Close()
}
return os.RemoveAll(f.tempDir)
}

Expand All @@ -134,6 +150,8 @@ func (f *fileIterator) Next() ([]string, error) {
f.rows.Close()
f.conn.Close()
f.db.Close()
// mark rows as nil to prevent double close
f.rows = nil
}()

f.logger.Debug("downloading results in parquet file", observability.ZapCtx(f.ctx))
Expand Down Expand Up @@ -278,21 +296,3 @@ func parseSourceProperties(props map[string]any) (*sourceProperties, error) {
}
return conf, err
}

// rawConn is similar to *sql.Conn.Raw, but additionally unwraps otelsql (which we use for instrumentation).
func rawConn(conn *sql.Conn, f func(sqld.Conn) error) error {
return conn.Raw(func(raw any) error {
// For details, see: https://github.com/XSAM/otelsql/issues/98
if c, ok := raw.(interface{ Raw() sqld.Conn }); ok {
raw = c.Raw()
}

// This is currently guaranteed, but adding check to be safe
driverConn, ok := raw.(sqld.Conn)
if !ok {
return fmt.Errorf("internal: did not obtain a driver.Conn")
}

return f(driverConn)
})
}

0 comments on commit 1fff61e

Please sign in to comment.