From a6088fcd9fe908f04e456623af00c63e6cd326f3 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 26 Apr 2024 14:55:05 -0700 Subject: [PATCH] chore: Add allocation source to StreamReader (#332) * chore: Add allocation source to StreamReader * Use simple name --- .../main/scala/org/apache/comet/vector/NativeUtil.scala | 1 + .../main/scala/org/apache/comet/vector/StreamReader.scala | 3 ++- .../sql/comet/execution/shuffle/ArrowReaderIterator.scala | 5 +++-- .../execution/shuffle/CometBlockStoreShuffleReader.scala | 2 +- .../spark/sql/comet/CometBroadcastExchangeExec.scala | 3 ++- .../main/scala/org/apache/spark/sql/comet/operators.scala | 7 ++++--- 6 files changed, 13 insertions(+), 8 deletions(-) diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala index 763ccff7fd..eb731f9d0c 100644 --- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala +++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala @@ -33,6 +33,7 @@ class NativeUtil { import Utils._ private val allocator = new RootAllocator(Long.MaxValue) + .newChildAllocator(this.getClass.getSimpleName, 0, Long.MaxValue) private val dictionaryProvider: CDataDictionaryProvider = new CDataDictionaryProvider private val importer = new ArrowImporter(allocator) diff --git a/common/src/main/scala/org/apache/comet/vector/StreamReader.scala b/common/src/main/scala/org/apache/comet/vector/StreamReader.scala index 61d800bfb5..4a08f05213 100644 --- a/common/src/main/scala/org/apache/comet/vector/StreamReader.scala +++ b/common/src/main/scala/org/apache/comet/vector/StreamReader.scala @@ -30,8 +30,9 @@ import org.apache.spark.sql.vectorized.ColumnarBatch /** * A reader that consumes Arrow data from an input channel, and produces Comet batches. 
*/ -case class StreamReader(channel: ReadableByteChannel) extends AutoCloseable { +case class StreamReader(channel: ReadableByteChannel, source: String) extends AutoCloseable { private var allocator = new RootAllocator(Long.MaxValue) + .newChildAllocator(s"${this.getClass.getSimpleName}/$source", 0, Long.MaxValue) private val channelReader = new MessageChannelReader(new ReadChannel(channel), allocator) private var arrowReader = new ArrowStreamReader(channelReader, allocator) private var root = arrowReader.getVectorSchemaRoot diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala index 304c3ce779..3c0fa153d8 100644 --- a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala +++ b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala @@ -25,9 +25,10 @@ import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.comet.vector._ -class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[ColumnarBatch] { +class ArrowReaderIterator(channel: ReadableByteChannel, source: String) + extends Iterator[ColumnarBatch] { - private val reader = StreamReader(channel) + private val reader = StreamReader(channel, source) private var batch = nextBatch() private var currentBatch: ColumnarBatch = null diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala index b461b53f54..90e0bb14ed 100644 --- a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala +++ b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala @@ -108,7 +108,7 @@ class CometBlockStoreShuffleReader[K, C]( // Closes previous read iterator. 
currentReadIterator.close() } - currentReadIterator = new ArrowReaderIterator(channel) + currentReadIterator = new ArrowReaderIterator(channel, this.getClass.getSimpleName) currentReadIterator.map((0, _)) // use 0 as key since it's not used } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala index a8322be7d9..06c5898f79 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala @@ -284,7 +284,8 @@ class CometBatchRDD( override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { val partition = split.asInstanceOf[CometBatchPartition] - partition.value.value.toIterator.flatMap(CometExec.decodeBatches) + partition.value.value.toIterator + .flatMap(CometExec.decodeBatches(_, this.getClass.getSimpleName)) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index a8579757d5..39ffef140f 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -86,7 +86,8 @@ abstract class CometExec extends CometPlan { val countsAndBytes = CometExec.getByteArrayRdd(this).collect() val total = countsAndBytes.map(_._1).sum val rows = countsAndBytes.iterator - .flatMap(countAndBytes => CometExec.decodeBatches(countAndBytes._2)) + .flatMap(countAndBytes => + CometExec.decodeBatches(countAndBytes._2, this.getClass.getSimpleName)) (total, rows) } } @@ -126,7 +127,7 @@ object CometExec { /** * Decodes the byte arrays back to ColumnarBatchs and put them into buffer. */ - def decodeBatches(bytes: ChunkedByteBuffer): Iterator[ColumnarBatch] = { + def decodeBatches(bytes: ChunkedByteBuffer, source: String): Iterator[ColumnarBatch] = { if (bytes.size == 0) { return Iterator.empty } @@ -135,7 +136,7 @@ object CometExec { val cbbis = bytes.toInputStream() val ins = new DataInputStream(codec.compressedInputStream(cbbis)) - new ArrowReaderIterator(Channels.newChannel(ins)) + new ArrowReaderIterator(Channels.newChannel(ins), source) } }
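
Note (not part of the patch): the change above names each Arrow child allocator after the class that owns it, and threads a `source` string from the call sites so the same `StreamReader` can be attributed to different consumers (shuffle reads, broadcast exchanges, collected results). A minimal, self-contained sketch of that naming pattern is below; the object name and the `source` value are illustrative, not Comet code.

    import org.apache.arrow.memory.{BufferAllocator, RootAllocator}

    // Sketch only: a named child allocator per reader lets Arrow's memory
    // accounting attribute outstanding buffers to the component that created them.
    object AllocatorNamingExample {
      def main(args: Array[String]): Unit = {
        val root = new RootAllocator(Long.MaxValue)

        // Mirrors the "<ClassName>/<source>" pattern used in this patch.
        val source = "CometBlockStoreShuffleReader" // hypothetical source label
        val child: BufferAllocator =
          root.newChildAllocator(s"StreamReader/$source", 0, Long.MaxValue)

        val buf = child.buffer(1024) // allocate 1 KiB against the named child

        // The child's name appears in allocation accounting and leak messages,
        // e.g. via child.getName.
        println(s"${child.getName} allocated ${child.getAllocatedMemory} bytes")

        buf.close()
        child.close()
        root.close()
      }
    }

Passing `this.getClass.getSimpleName` from each call site, rather than hard-coding a label inside `StreamReader`, keeps a single reader implementation while still distinguishing allocation sources in allocator dumps.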