diff --git a/common/src/main/scala/org/apache/comet/vector/StreamReader.scala b/common/src/main/scala/org/apache/comet/vector/StreamReader.scala index 61d800bfb5..4a08f05213 100644 --- a/common/src/main/scala/org/apache/comet/vector/StreamReader.scala +++ b/common/src/main/scala/org/apache/comet/vector/StreamReader.scala @@ -30,8 +30,9 @@ import org.apache.spark.sql.vectorized.ColumnarBatch /** * A reader that consumes Arrow data from an input channel, and produces Comet batches. */ -case class StreamReader(channel: ReadableByteChannel) extends AutoCloseable { +case class StreamReader(channel: ReadableByteChannel, source: String) extends AutoCloseable { private var allocator = new RootAllocator(Long.MaxValue) + .newChildAllocator(s"${this.getClass.getSimpleName}/$source", 0, Long.MaxValue) private val channelReader = new MessageChannelReader(new ReadChannel(channel), allocator) private var arrowReader = new ArrowStreamReader(channelReader, allocator) private var root = arrowReader.getVectorSchemaRoot diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala index 304c3ce779..3c0fa153d8 100644 --- a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala +++ b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala @@ -25,9 +25,10 @@ import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.comet.vector._ -class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[ColumnarBatch] { +class ArrowReaderIterator(channel: ReadableByteChannel, source: String) + extends Iterator[ColumnarBatch] { - private val reader = StreamReader(channel) + private val reader = StreamReader(channel, source) private var batch = nextBatch() private var currentBatch: ColumnarBatch = null diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala index b461b53f54..2026dfc5c5 100644 --- a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala +++ b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala @@ -108,7 +108,7 @@ class CometBlockStoreShuffleReader[K, C]( // Closes previous read iterator. currentReadIterator.close() } - currentReadIterator = new ArrowReaderIterator(channel) + currentReadIterator = new ArrowReaderIterator(channel, "shuffle reader") currentReadIterator.map((0, _)) // use 0 as key since it's not used } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala index a8322be7d9..f0d235800c 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala @@ -284,7 +284,7 @@ class CometBatchRDD( override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { val partition = split.asInstanceOf[CometBatchPartition] - partition.value.value.toIterator.flatMap(CometExec.decodeBatches) + partition.value.value.toIterator.flatMap(CometExec.decodeBatches(_, "broadcast")) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 1065367c27..03be652297 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -78,7 +78,7 @@ abstract class CometExec extends CometPlan { val countsAndBytes = CometExec.getByteArrayRdd(this).collect() val total = countsAndBytes.map(_._1).sum val rows = countsAndBytes.iterator - .flatMap(countAndBytes => CometExec.decodeBatches(countAndBytes._2)) + .flatMap(countAndBytes => CometExec.decodeBatches(countAndBytes._2, "collect")) (total, rows) } } @@ -118,7 +118,7 @@ object CometExec { /** * Decodes the byte arrays back to ColumnarBatchs and put them into buffer. */ - def decodeBatches(bytes: ChunkedByteBuffer): Iterator[ColumnarBatch] = { + def decodeBatches(bytes: ChunkedByteBuffer, source: String): Iterator[ColumnarBatch] = { if (bytes.size == 0) { return Iterator.empty } @@ -127,7 +127,7 @@ object CometExec { val cbbis = bytes.toInputStream() val ins = new DataInputStream(codec.compressedInputStream(cbbis)) - new ArrowReaderIterator(Channels.newChannel(ins)) + new ArrowReaderIterator(Channels.newChannel(ins), source) } }