refactor: Skipping slicing on shuffle arrays in shuffle reader (#189)
* refactor: Skipping slicing on shuffle arrays

* Add note for columnar shuffle batch size.
viirya authored Mar 11, 2024
1 parent 7ba69d8 commit 4fec40e
Showing 2 changed files with 8 additions and 33 deletions.
4 changes: 3 additions & 1 deletion common/src/main/scala/org/apache/comet/CometConf.scala
@@ -227,7 +227,9 @@ object CometConf {
   val COMET_COLUMNAR_SHUFFLE_BATCH_SIZE: ConfigEntry[Int] =
     conf("spark.comet.columnar.shuffle.batch.size")
       .internal()
-      .doc("Batch size when writing out sorted spill files on the native side.")
+      .doc("Batch size when writing out sorted spill files on the native side. Note that " +
+        "this should not be larger than batch size (i.e., `spark.comet.batchSize`). Otherwise " +
+        "it will produce larger batches than expected in the native operator after shuffle.")
       .intConf
       .createWithDefault(8192)

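For illustration, a minimal sketch (not part of this commit) of how the two settings interact: the columnar shuffle writer emits batches of `spark.comet.columnar.shuffle.batch.size` rows, and after this change the reader forwards them whole instead of re-slicing, so the shuffle batch size should stay at or below `spark.comet.batchSize`. The session setup and values below are hypothetical.

import org.apache.spark.sql.SparkSession

// Hypothetical session setup; both values are illustrative.
val spark = SparkSession
  .builder()
  .appName("comet-shuffle-batch-size-example")
  // General Comet batch size used by native operators.
  .config("spark.comet.batchSize", "8192")
  // Columnar shuffle batch size; keep it <= spark.comet.batchSize so that
  // batches arriving after shuffle are no larger than operators expect.
  .config("spark.comet.columnar.shuffle.batch.size", "8192")
  .getOrCreate()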
37 changes: 5 additions & 32 deletions …/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
@@ -21,22 +21,14 @@ package org.apache.spark.sql.comet.execution.shuffle

 import java.nio.channels.ReadableByteChannel
 
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
-import org.apache.comet.CometConf
-import org.apache.comet.vector.{NativeUtil, StreamReader}
+import org.apache.comet.vector.StreamReader
 
 class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[ColumnarBatch] {
 
-  private val nativeUtil = new NativeUtil
-
-  private val maxBatchSize = CometConf.COMET_BATCH_SIZE.get(SQLConf.get)
-
   private val reader = StreamReader(channel)
-  private var currentIdx = -1
   private var batch = nextBatch()
-  private var previousBatch: ColumnarBatch = null
   private var currentBatch: ColumnarBatch = null
 
   override def hasNext: Boolean = {
@@ -57,47 +49,28 @@ class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[ColumnarBatch] {
     }
 
     val nextBatch = batch.get
-    val batchRows = nextBatch.numRows()
-    val numRows = Math.min(batchRows - currentIdx, maxBatchSize)
 
-    // Release the previous sliced batch.
+    // Release the previous batch.
     // If it is not released, when closing the reader, arrow library will complain about
     // memory leak.
     if (currentBatch != null) {
-      // Close plain arrays in the previous sliced batch.
-      // The dictionary arrays will be closed when closing the entire batch.
       currentBatch.close()
     }
 
-    currentBatch = nativeUtil.takeRows(nextBatch, currentIdx, numRows)
-    currentIdx += numRows
-
-    if (currentIdx == batchRows) {
-      // We cannot close the batch here, because if there is dictionary array in the batch,
-      // the dictionary array will be closed immediately, and the returned sliced batch will
-      // be invalid.
-      previousBatch = batch.get
-
-      batch = None
-      currentIdx = -1
-    }
+    currentBatch = nextBatch
+    batch = None
     currentBatch
   }
 
   private def nextBatch(): Option[ColumnarBatch] = {
-    if (previousBatch != null) {
-      previousBatch.close()
-      previousBatch = null
-    }
-    currentIdx = 0
     reader.nextBatch()
   }
 
   def close(): Unit =
     synchronized {
       if (currentBatch != null) {
         currentBatch.close()
         currentBatch = null
       }
       reader.close()
     }
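For context, a minimal sketch (not part of this commit) of how the simplified ArrowReaderIterator is consumed; the input path and row counting are hypothetical. The ownership rules follow the diff above: each next() call closes the previously returned batch, and close() releases the last batch along with the underlying stream reader, so callers must not close returned batches themselves.

import java.io.FileInputStream
import java.nio.channels.Channels

import org.apache.spark.sql.comet.execution.shuffle.ArrowReaderIterator
import org.apache.spark.sql.vectorized.ColumnarBatch

// Hypothetical shuffle output file containing an Arrow IPC stream.
val channel = Channels.newChannel(new FileInputStream("/tmp/shuffle-output.arrow"))

val iter = new ArrowReaderIterator(channel)
try {
  var totalRows = 0L
  while (iter.hasNext) {
    val batch: ColumnarBatch = iter.next()
    totalRows += batch.numRows()
    // Do not close `batch` here: the iterator closes the previous batch
    // on the next call to next().
  }
  println(s"Read $totalRows rows")
} finally {
  iter.close() // releases the last batch and the underlying reader
}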
