Skip to content

Commit

Permalink
chore: Add allocation source to StreamReader
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Apr 26, 2024
1 parent 21717eb commit d05b76e
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
/**
* A reader that consumes Arrow data from an input channel, and produces Comet batches.
*/
case class StreamReader(channel: ReadableByteChannel) extends AutoCloseable {
case class StreamReader(channel: ReadableByteChannel, source: String) extends AutoCloseable {
private var allocator = new RootAllocator(Long.MaxValue)
.newChildAllocator(s"${this.getClass.getSimpleName}/$source", 0, Long.MaxValue)
private val channelReader = new MessageChannelReader(new ReadChannel(channel), allocator)
private var arrowReader = new ArrowStreamReader(channelReader, allocator)
private var root = arrowReader.getVectorSchemaRoot
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ import org.apache.spark.sql.vectorized.ColumnarBatch

import org.apache.comet.vector._

class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[ColumnarBatch] {
class ArrowReaderIterator(channel: ReadableByteChannel, source: String)
extends Iterator[ColumnarBatch] {

private val reader = StreamReader(channel)
private val reader = StreamReader(channel, source)
private var batch = nextBatch()
private var currentBatch: ColumnarBatch = null

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class CometBlockStoreShuffleReader[K, C](
// Closes previous read iterator.
currentReadIterator.close()
}
currentReadIterator = new ArrowReaderIterator(channel)
currentReadIterator = new ArrowReaderIterator(channel, "shuffle reader")
currentReadIterator.map((0, _)) // use 0 as key since it's not used
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ class CometBatchRDD(

override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
val partition = split.asInstanceOf[CometBatchPartition]
partition.value.value.toIterator.flatMap(CometExec.decodeBatches)
partition.value.value.toIterator.flatMap(CometExec.decodeBatches(_, "broadcast"))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ abstract class CometExec extends CometPlan {
val countsAndBytes = CometExec.getByteArrayRdd(this).collect()
val total = countsAndBytes.map(_._1).sum
val rows = countsAndBytes.iterator
.flatMap(countAndBytes => CometExec.decodeBatches(countAndBytes._2))
.flatMap(countAndBytes => CometExec.decodeBatches(countAndBytes._2, "collect"))
(total, rows)
}
}
Expand Down Expand Up @@ -118,7 +118,7 @@ object CometExec {
/**
* Decodes the byte arrays back to ColumnarBatchs and put them into buffer.
*/
def decodeBatches(bytes: ChunkedByteBuffer): Iterator[ColumnarBatch] = {
def decodeBatches(bytes: ChunkedByteBuffer, source: String): Iterator[ColumnarBatch] = {
if (bytes.size == 0) {
return Iterator.empty
}
Expand All @@ -127,7 +127,7 @@ object CometExec {
val cbbis = bytes.toInputStream()
val ins = new DataInputStream(codec.compressedInputStream(cbbis))

new ArrowReaderIterator(Channels.newChannel(ins))
new ArrowReaderIterator(Channels.newChannel(ins), source)
}
}

Expand Down

0 comments on commit d05b76e

Please sign in to comment.