diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala index 763ccff7f..eb731f9d0 100644 --- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala +++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala @@ -33,6 +33,7 @@ class NativeUtil { import Utils._ private val allocator = new RootAllocator(Long.MaxValue) + .newChildAllocator(this.getClass.getSimpleName, 0, Long.MaxValue) private val dictionaryProvider: CDataDictionaryProvider = new CDataDictionaryProvider private val importer = new ArrowImporter(allocator) diff --git a/common/src/main/scala/org/apache/comet/vector/StreamReader.scala b/common/src/main/scala/org/apache/comet/vector/StreamReader.scala index 61d800bfb..4a08f0521 100644 --- a/common/src/main/scala/org/apache/comet/vector/StreamReader.scala +++ b/common/src/main/scala/org/apache/comet/vector/StreamReader.scala @@ -30,8 +30,9 @@ import org.apache.spark.sql.vectorized.ColumnarBatch /** * A reader that consumes Arrow data from an input channel, and produces Comet batches. */ -case class StreamReader(channel: ReadableByteChannel) extends AutoCloseable { +case class StreamReader(channel: ReadableByteChannel, source: String) extends AutoCloseable { private var allocator = new RootAllocator(Long.MaxValue) + .newChildAllocator(s"${this.getClass.getSimpleName}/$source", 0, Long.MaxValue) private val channelReader = new MessageChannelReader(new ReadChannel(channel), allocator) private var arrowReader = new ArrowStreamReader(channelReader, allocator) private var root = arrowReader.getVectorSchemaRoot diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala index 304c3ce77..3c0fa153d 100644 --- a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala +++ b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala @@ -25,9 +25,10 @@ import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.comet.vector._ -class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[ColumnarBatch] { +class ArrowReaderIterator(channel: ReadableByteChannel, source: String) + extends Iterator[ColumnarBatch] { - private val reader = StreamReader(channel) + private val reader = StreamReader(channel, source) private var batch = nextBatch() private var currentBatch: ColumnarBatch = null diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala index b461b53f5..90e0bb14e 100644 --- a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala +++ b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala @@ -108,7 +108,7 @@ class CometBlockStoreShuffleReader[K, C]( // Closes previous read iterator. currentReadIterator.close() } - currentReadIterator = new ArrowReaderIterator(channel) + currentReadIterator = new ArrowReaderIterator(channel, this.getClass.getSimpleName) currentReadIterator.map((0, _)) // use 0 as key since it's not used } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala index a8322be7d..06c5898f7 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala @@ -284,7 +284,8 @@ class CometBatchRDD( override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { val partition = split.asInstanceOf[CometBatchPartition] - partition.value.value.toIterator.flatMap(CometExec.decodeBatches) + partition.value.value.toIterator + .flatMap(CometExec.decodeBatches(_, this.getClass.getSimpleName)) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index a8579757d..39ffef140 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -86,7 +86,8 @@ abstract class CometExec extends CometPlan { val countsAndBytes = CometExec.getByteArrayRdd(this).collect() val total = countsAndBytes.map(_._1).sum val rows = countsAndBytes.iterator - .flatMap(countAndBytes => CometExec.decodeBatches(countAndBytes._2)) + .flatMap(countAndBytes => + CometExec.decodeBatches(countAndBytes._2, this.getClass.getSimpleName)) (total, rows) } } @@ -126,7 +127,7 @@ object CometExec { /** * Decodes the byte arrays back to ColumnarBatchs and put them into buffer. */ - def decodeBatches(bytes: ChunkedByteBuffer): Iterator[ColumnarBatch] = { + def decodeBatches(bytes: ChunkedByteBuffer, source: String): Iterator[ColumnarBatch] = { if (bytes.size == 0) { return Iterator.empty } @@ -135,7 +136,7 @@ object CometExec { val cbbis = bytes.toInputStream() val ins = new DataInputStream(codec.compressedInputStream(cbbis)) - new ArrowReaderIterator(Channels.newChannel(ins)) + new ArrowReaderIterator(Channels.newChannel(ins), source) } }