Skip to content

Commit

Permalink
refactor: Skipping slicing on shuffle arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Mar 11, 2024
1 parent 488c523 commit c7d6865
Showing 1 changed file with 5 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,14 @@ package org.apache.spark.sql.comet.execution.shuffle

import java.nio.channels.ReadableByteChannel

import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.vectorized.ColumnarBatch

import org.apache.comet.CometConf
import org.apache.comet.vector.{NativeUtil, StreamReader}
import org.apache.comet.vector.StreamReader

class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[ColumnarBatch] {

private val nativeUtil = new NativeUtil

private val maxBatchSize = CometConf.COMET_BATCH_SIZE.get(SQLConf.get)

private val reader = StreamReader(channel)
private var currentIdx = -1
private var batch = nextBatch()
private var previousBatch: ColumnarBatch = null
private var currentBatch: ColumnarBatch = null

override def hasNext: Boolean = {
Expand All @@ -57,47 +49,28 @@ class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[Columna
}

val nextBatch = batch.get
val batchRows = nextBatch.numRows()
val numRows = Math.min(batchRows - currentIdx, maxBatchSize)

// Release the previous sliced batch.
// Release the previous batch.
// If it is not released, the Arrow library will complain about a memory leak
// when the reader is closed.
if (currentBatch != null) {
// Close plain arrays in the previous sliced batch.
// The dictionary arrays will be closed when closing the entire batch.
currentBatch.close()
}

currentBatch = nativeUtil.takeRows(nextBatch, currentIdx, numRows)
currentIdx += numRows

if (currentIdx == batchRows) {
// We cannot close the batch here, because if there is dictionary array in the batch,
// the dictionary array will be closed immediately, and the returned sliced batch will
// be invalid.
previousBatch = batch.get

batch = None
currentIdx = -1
}

currentBatch = nextBatch
batch = None
currentBatch
}

// Pulls the next batch from the underlying stream reader and returns whatever
// `reader.nextBatch()` yields (an Option; presumably None at end of stream — TODO confirm).
// Before reading, it closes any batch still held in `previousBatch` and resets
// `currentIdx` to 0.
// NOTE(review): this text comes from a diff view without +/- markers; the `previousBatch`
// and `currentIdx` statements look like the lines the commit removes, leaving only the
// `reader.nextBatch()` call — confirm against the actual file.
private def nextBatch(): Option[ColumnarBatch] = {
// Close the batch that was kept alive after its last slice was handed out, if any.
if (previousBatch != null) {
previousBatch.close()
previousBatch = null
}
// Restart the row offset for the batch about to be read.
currentIdx = 0
reader.nextBatch()
}

// Releases the resources held by this iterator while holding this object's monitor:
// first the batch referenced by `currentBatch` (when one is set), then the stream reader.
def close(): Unit = synchronized {
  // Option(...) turns a null reference into None, so nothing happens when no batch is held.
  Option(currentBatch).foreach { held =>
    held.close()
    currentBatch = null
  }
  // The reader is closed only after the held batch has been released.
  reader.close()
}
Expand Down

0 comments on commit c7d6865

Please sign in to comment.