From c7d68650084567da5e2b29b1378b312fc422827d Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Sun, 10 Mar 2024 23:02:06 -0700
Subject: [PATCH] refactor: Skipping slicing on shuffle arrays

---
 .../shuffle/ArrowReaderIterator.scala         | 37 +++++--------------------------------
 1 file changed, 5 insertions(+), 32 deletions(-)

diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
index c17c5bce9..e8dba93e7 100644
--- a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
+++ b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
@@ -21,22 +21,14 @@ package org.apache.spark.sql.comet.execution.shuffle
 
 import java.nio.channels.ReadableByteChannel
 
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
-import org.apache.comet.CometConf
-import org.apache.comet.vector.{NativeUtil, StreamReader}
+import org.apache.comet.vector.StreamReader
 
 class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[ColumnarBatch] {
 
-  private val nativeUtil = new NativeUtil
-
-  private val maxBatchSize = CometConf.COMET_BATCH_SIZE.get(SQLConf.get)
-
   private val reader = StreamReader(channel)
-  private var currentIdx = -1
   private var batch = nextBatch()
-  private var previousBatch: ColumnarBatch = null
   private var currentBatch: ColumnarBatch = null
 
   override def hasNext: Boolean = {
@@ -57,40 +49,20 @@ class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[Columna
     }
 
     val nextBatch = batch.get
-    val batchRows = nextBatch.numRows()
-    val numRows = Math.min(batchRows - currentIdx, maxBatchSize)
 
-    // Release the previous sliced batch.
+    // Release the previous batch.
     // If it is not released, when closing the reader, arrow library will complain about
     // memory leak.
     if (currentBatch != null) {
-      // Close plain arrays in the previous sliced batch.
-      // The dictionary arrays will be closed when closing the entire batch.
       currentBatch.close()
     }
 
-    currentBatch = nativeUtil.takeRows(nextBatch, currentIdx, numRows)
-    currentIdx += numRows
-
-    if (currentIdx == batchRows) {
-      // We cannot close the batch here, because if there is dictionary array in the batch,
-      // the dictionary array will be closed immediately, and the returned sliced batch will
-      // be invalid.
-      previousBatch = batch.get
-
-      batch = None
-      currentIdx = -1
-    }
-
+    currentBatch = nextBatch
+    batch = None
     currentBatch
   }
 
   private def nextBatch(): Option[ColumnarBatch] = {
-    if (previousBatch != null) {
-      previousBatch.close()
-      previousBatch = null
-    }
-    currentIdx = 0
     reader.nextBatch()
   }
@@ -98,6 +70,7 @@ class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[Columna
     synchronized {
       if (currentBatch != null) {
         currentBatch.close()
+        currentBatch = null
       }
       reader.close()
     }
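
For context, here is a minimal consumer sketch (not part of the patch) illustrating the ownership contract the refactored iterator keeps: each call to next() closes the batch returned by the previous call, so a caller must be done with one batch before advancing, and close() releases the last outstanding batch along with the stream reader. The input path is a hypothetical Arrow IPC stream file; the row tally is illustrative only.

import java.nio.channels.{Channels, ReadableByteChannel}
import java.nio.file.{Files, Paths}

import org.apache.spark.sql.comet.execution.shuffle.ArrowReaderIterator

object ArrowReaderIteratorExample {
  def main(args: Array[String]): Unit = {
    // Hypothetical input: an Arrow IPC stream produced by the shuffle writer.
    val channel: ReadableByteChannel =
      Channels.newChannel(Files.newInputStream(Paths.get("/tmp/shuffle-output.arrow")))

    val iter = new ArrowReaderIterator(channel)
    try {
      var totalRows = 0L
      while (iter.hasNext) {
        // The batch is valid only until the next call to next(): the iterator
        // closes the previously returned batch before handing out a new one.
        val batch = iter.next()
        totalRows += batch.numRows()
      }
      println(s"read $totalRows rows")
    } finally {
      // Closes the last outstanding batch and the underlying stream reader.
      iter.close()
    }
  }
}

Note the simplification this patch buys: because batches are no longer sliced to COMET_BATCH_SIZE rows via NativeUtil.takeRows, there is no sliced view aliasing a parent batch, so the dictionary-array lifetime bookkeeping (previousBatch, currentIdx) disappears and each batch can simply be closed before the next is returned.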