From 6cb63eda81ba9c7579c915083aa431757e7580dc Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 1 Apr 2024 00:00:01 -0700 Subject: [PATCH] Fix --- .../org/apache/comet/vector/NativeUtil.scala | 40 +------------------ .../apache/comet/vector/StreamReader.scala | 1 + .../comet/CometSparkSessionExtensions.scala | 2 +- .../comet/CometBroadcastExchangeExec.scala | 4 +- .../apache/spark/sql/comet/operators.scala | 37 +++++++++++++++-- .../spark/sql/CometTPCDSQuerySuite.scala | 6 +-- 6 files changed, 42 insertions(+), 48 deletions(-) diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala index cc726e3e8..2ec3b7973 100644 --- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala +++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala @@ -19,10 +19,6 @@ package org.apache.comet.vector -import java.io.OutputStream -import java.nio.channels.Channels - -import scala.collection.JavaConverters._ import scala.collection.mutable import org.apache.arrow.c.{ArrowArray, ArrowImporter, ArrowSchema, CDataDictionaryProvider, Data} @@ -32,46 +28,12 @@ import org.apache.arrow.vector.dictionary.DictionaryProvider import org.apache.spark.SparkException import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.comet.CometArrowStreamWriter - class NativeUtil { private val allocator = new RootAllocator(Long.MaxValue) private val dictionaryProvider: CDataDictionaryProvider = new CDataDictionaryProvider private val importer = new ArrowImporter(allocator) - /** - * Serializes a list of `ColumnarBatch` into an output stream. - * - * @param batches - * the output batches, each batch is a list of Arrow vectors wrapped in `CometVector` - * @param out - * the output stream - */ - def serializeBatches(batches: Iterator[ColumnarBatch], out: OutputStream): Long = { - var writer: Option[CometArrowStreamWriter] = None - var rowCount = 0 - - batches.foreach { batch => - val (fieldVectors, batchProviderOpt) = getBatchFieldVectors(batch) - val root = new VectorSchemaRoot(fieldVectors.asJava) - val provider = batchProviderOpt.getOrElse(dictionaryProvider) - - if (writer.isEmpty) { - writer = Some(new CometArrowStreamWriter(root, provider, Channels.newChannel(out))) - writer.get.start() - writer.get.writeBatch() - } else { - writer.get.writeMoreBatch(root) - } - - root.clear() - rowCount += batch.numRows() - } - - writer.map(_.end()) - - rowCount - } + def getDictionaryProvider: DictionaryProvider = dictionaryProvider def getBatchFieldVectors( batch: ColumnarBatch): (Seq[FieldVector], Option[DictionaryProvider]) = { diff --git a/common/src/main/scala/org/apache/comet/vector/StreamReader.scala b/common/src/main/scala/org/apache/comet/vector/StreamReader.scala index 9c4f99602..da72383e8 100644 --- a/common/src/main/scala/org/apache/comet/vector/StreamReader.scala +++ b/common/src/main/scala/org/apache/comet/vector/StreamReader.scala @@ -51,6 +51,7 @@ case class StreamReader(channel: ReadableByteChannel) extends AutoCloseable { // Native shuffle always uses decimal128. CometVector.getVector(vector, true, arrowReader).asInstanceOf[ColumnVector] }.toArray + val batch = new ColumnarBatch(columns) batch.setNumRows(root.getRowCount) batch diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 0d8aed743..77951943f 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -460,7 +460,7 @@ class CometSparkSessionExtensions case other => other } if (!newChildren.exists(_.isInstanceOf[BroadcastExchangeExec])) { - val newPlan = transform(plan.withNewChildren(newChildren)) + val newPlan = apply(plan.withNewChildren(newChildren)) if (isCometNative(newPlan) || isCometBroadCastForceEnabled(conf)) { newPlan } else { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala index 24f9f3279..4797907b9 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala @@ -95,6 +95,8 @@ case class CometBroadcastExchangeExec(originalPlan: SparkPlan, child: SparkPlan) @transient private lazy val maxBroadcastRows = 512000000 + private lazy val childRDD = child.asInstanceOf[CometExec].executeColumnar() + @transient override lazy val relationFuture: Future[broadcast.Broadcast[Any]] = { SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]]( @@ -191,7 +193,7 @@ case class CometBroadcastExchangeExec(originalPlan: SparkPlan, child: SparkPlan) override protected def doExecuteColumnar(): RDD[ColumnarBatch] = { val broadcasted = executeBroadcast[Array[ChunkedByteBuffer]]() - new CometBatchRDD(sparkContext, broadcasted.value.length, broadcasted) + new CometBatchRDD(sparkContext, childRDD.getNumPartitions, broadcasted) } override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 84734a175..a75a8e145 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -19,12 +19,15 @@ package org.apache.spark.sql.comet -import java.io.{ByteArrayOutputStream, DataInputStream, DataOutputStream} +import java.io.{ByteArrayOutputStream, DataInputStream, DataOutputStream, OutputStream} import java.nio.ByteBuffer import java.nio.channels.Channels +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.ipc.ArrowStreamWriter import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.io.CompressionCodec import org.apache.spark.rdd.RDD @@ -77,18 +80,44 @@ abstract class CometExec extends CometPlan { */ def getByteArrayRdd(): RDD[(Long, ChunkedByteBuffer)] = { executeColumnar().mapPartitionsInternal { iter => + serializeBatches(iter) + } + } + + /** + * Serializes a list of `ColumnarBatch` into an output stream. + * + * @param batches + * the output batches, each batch is a list of Arrow vectors wrapped in `CometVector` + * @param out + * the output stream + */ + def serializeBatches(batches: Iterator[ColumnarBatch]): Iterator[(Long, ChunkedByteBuffer)] = { + val nativeUtil = new NativeUtil() + + batches.map { batch => val codec = CompressionCodec.createCodec(SparkEnv.get.conf) val cbbos = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate) val out = new DataOutputStream(codec.compressedOutputStream(cbbos)) - val count = new NativeUtil().serializeBatches(iter, out) + val (fieldVectors, batchProviderOpt) = nativeUtil.getBatchFieldVectors(batch) + val root = new VectorSchemaRoot(fieldVectors.asJava) + val provider = batchProviderOpt.getOrElse(nativeUtil.getDictionaryProvider) + + val writer = new ArrowStreamWriter(root, provider, Channels.newChannel(out)) + writer.start() + writer.writeBatch() + + root.clear() + writer.end() out.flush() out.close() + if (out.size() > 0) { - Iterator((count, cbbos.toChunkedByteBuffer)) + (batch.numRows(), cbbos.toChunkedByteBuffer) } else { - Iterator((count, new ChunkedByteBuffer(Array.empty[ByteBuffer]))) + (batch.numRows(), new ChunkedByteBuffer(Array.empty[ByteBuffer])) } } } diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala b/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala index 265235ffe..a3b73dfa0 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala @@ -138,9 +138,9 @@ class CometTPCDSQuerySuite "q99") // TODO: enable the 3 queries after fixing the issues #1358. - override val tpcdsQueries: Seq[String] = - tpcdsAllQueries.filterNot(excludedTpcdsQueries.contains) - + override val tpcdsQueries: Seq[String] = Seq("q4") + // tpcdsAllQueries.filterNot(excludedTpcdsQueries.contains) + // Seq("q1", "q2", "q3", "q4") } with TPCDSQueryTestSuite { override def sparkConf: SparkConf = {