From 1fff61ec16c45c65b4f9bcc685e0fb19b3e3d692 Mon Sep 17 00:00:00 2001 From: Anshul Khandelwal <12948312+k-anshul@users.noreply.github.com> Date: Mon, 23 Dec 2024 22:25:31 +0530 Subject: [PATCH] fix : Ensure rows, conn and db are closed on all error paths (#6321) * ensure rows, conn and db are closed on all error paths * revert otelsql change --- runtime/drivers/snowflake/warehouse.go | 50 +++++++++++++------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/runtime/drivers/snowflake/warehouse.go b/runtime/drivers/snowflake/warehouse.go index 4f1bbfc20aa..52286ea8bd5 100644 --- a/runtime/drivers/snowflake/warehouse.go +++ b/runtime/drivers/snowflake/warehouse.go @@ -10,7 +10,6 @@ import ( "sync" "time" - "github.com/XSAM/otelsql" "github.com/apache/arrow/go/v15/arrow" "github.com/apache/arrow/go/v15/arrow/memory" "github.com/apache/arrow/go/v15/parquet" @@ -32,7 +31,7 @@ const rowGroupBufferSize = int64(datasize.MB) * 512 // Fetches query result in arrow batches. // As an alternative (or in case of memory issues) consider utilizing Snowflake "COPY INTO " feature, // see https://docs.snowflake.com/en/sql-reference/sql/copy-into-location -func (c *connection) QueryAsFiles(ctx context.Context, props map[string]any) (drivers.FileIterator, error) { +func (c *connection) QueryAsFiles(ctx context.Context, props map[string]any) (iter drivers.FileIterator, resErr error) { srcProps, err := parseSourceProperties(props) if err != nil { return nil, err @@ -52,29 +51,41 @@ func (c *connection) QueryAsFiles(ctx context.Context, props map[string]any) (dr parallelFetchLimit = c.configProperties.ParallelFetchLimit } - db, err := otelsql.Open("snowflake", dsn) + db, err := sql.Open("snowflake", dsn) if err != nil { return nil, err } + defer func() { + if resErr != nil { + db.Close() + } + }() ctx = sf.WithArrowAllocator(sf.WithArrowBatches(ctx), memory.DefaultAllocator) conn, err := db.Conn(ctx) if err != nil { - db.Close() return nil, err } + defer func() { + if resErr != nil { + conn.Close() + } + }() var rows sqld.Rows - err = rawConn(conn, func(x sqld.Conn) error { + err = conn.Raw(func(x interface{}) error { rows, err = x.(sqld.QueryerContext).QueryContext(ctx, srcProps.SQL, nil) return err }) if err != nil { - conn.Close() - db.Close() return nil, err } + defer func() { + if resErr != nil { + rows.Close() + } + }() batches, err := rows.(sf.SnowflakeRows).GetArrowBatches() if err != nil { @@ -119,6 +130,11 @@ type fileIterator struct { // Close implements drivers.FileIterator. func (f *fileIterator) Close() error { + if f.rows != nil { + f.rows.Close() + f.conn.Close() + f.db.Close() + } return os.RemoveAll(f.tempDir) } @@ -134,6 +150,8 @@ func (f *fileIterator) Next() ([]string, error) { f.rows.Close() f.conn.Close() f.db.Close() + // mark rows as nil to prevent double close + f.rows = nil }() f.logger.Debug("downloading results in parquet file", observability.ZapCtx(f.ctx)) @@ -278,21 +296,3 @@ func parseSourceProperties(props map[string]any) (*sourceProperties, error) { } return conf, err } - -// rawConn is similar to *sql.Conn.Raw, but additionally unwraps otelsql (which we use for instrumentation). -func rawConn(conn *sql.Conn, f func(sqld.Conn) error) error { - return conn.Raw(func(raw any) error { - // For details, see: https://github.com/XSAM/otelsql/issues/98 - if c, ok := raw.(interface{ Raw() sqld.Conn }); ok { - raw = c.Raw() - } - - // This is currently guaranteed, but adding check to be safe - driverConn, ok := raw.(sqld.Conn) - if !ok { - return fmt.Errorf("internal: did not obtain a driver.Conn") - } - - return f(driverConn) - }) -}