diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala index ccf218cf6..6bc519ab9 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.comet +import java.io.DataInputStream +import java.nio.channels.Channels import java.util.UUID import java.util.concurrent.{Future, TimeoutException, TimeUnit} @@ -26,13 +28,15 @@ import scala.concurrent.{ExecutionContext, Promise} import scala.concurrent.duration.NANOSECONDS import scala.util.control.NonFatal -import org.apache.spark.{broadcast, Partition, SparkContext, TaskContext} +import org.apache.spark.{broadcast, Partition, SparkContext, SparkEnv, TaskContext} import org.apache.spark.comet.shims.ShimCometBroadcastExchangeExec +import org.apache.spark.io.CompressionCodec import org.apache.spark.launcher.SparkLauncher import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.comet.execution.shuffle.ArrowReaderIterator import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, SQLExecution} import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec} @@ -299,7 +303,23 @@ class CometBatchRDD( override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { val partition = split.asInstanceOf[CometBatchPartition] partition.value.value.toIterator - .flatMap(CometExec.decodeBatches(_, this.getClass.getSimpleName)) + .flatMap(decodeBatches(_, this.getClass.getSimpleName)) + } + + /** + * Decodes the byte arrays back to ColumnarBatchs and put them into buffer. + */ + private def decodeBatches(bytes: ChunkedByteBuffer, source: String): Iterator[ColumnarBatch] = { + if (bytes.size == 0) { + return Iterator.empty + } + + // use Spark's compression codec (LZ4 by default) and not Comet's compression + val codec = CompressionCodec.createCodec(SparkEnv.get.conf) + val cbbis = bytes.toInputStream() + val ins = new DataInputStream(codec.compressedInputStream(cbbis)) + // batches are in Arrow IPC format + new ArrowReaderIterator(Channels.newChannel(ins), source) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 77188312e..c70f7464e 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -19,14 +19,12 @@ package org.apache.spark.sql.comet -import java.io.{ByteArrayOutputStream, DataInputStream} -import java.nio.channels.Channels +import java.io.ByteArrayOutputStream import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.io.CompressionCodec +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, NamedExpression, SortOrder} @@ -34,7 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, PartitioningCollection, UnknownPartitioning} -import org.apache.spark.sql.comet.execution.shuffle.{ArrowReaderIterator, CometShuffleExchangeExec} +import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.comet.plans.PartitioningPreservingUnaryExecNode import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode} @@ -78,18 +76,6 @@ abstract class CometExec extends CometPlan { // outputPartitioning of SparkPlan, e.g., AQEShuffleReadExec. override def outputPartitioning: Partitioning = originalPlan.outputPartitioning - /** - * Executes the Comet operator and returns the result as an iterator of ColumnarBatch. - */ - def executeColumnarCollectIterator(): (Long, Iterator[ColumnarBatch]) = { - val countsAndBytes = CometExec.getByteArrayRdd(this).collect() - val total = countsAndBytes.map(_._1).sum - val rows = countsAndBytes.iterator - .flatMap(countAndBytes => - CometExec.decodeBatches(countAndBytes._2, this.getClass.getSimpleName)) - (total, rows) - } - protected def setSubqueries(planId: Long, sparkPlan: SparkPlan): Unit = { sparkPlan.children.foreach(setSubqueries(planId, _)) @@ -161,21 +147,6 @@ object CometExec { Utils.serializeBatches(iter) } } - - /** - * Decodes the byte arrays back to ColumnarBatchs and put them into buffer. - */ - def decodeBatches(bytes: ChunkedByteBuffer, source: String): Iterator[ColumnarBatch] = { - if (bytes.size == 0) { - return Iterator.empty - } - - val codec = CompressionCodec.createCodec(SparkEnv.get.conf) - val cbbis = bytes.toInputStream() - val ins = new DataInputStream(codec.compressedInputStream(cbbis)) - - new ArrowReaderIterator(Channels.newChannel(ins), source) - } } /** diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 102769537..90c3221e5 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -22,8 +22,6 @@ package org.apache.comet.exec import java.sql.Date import java.time.{Duration, Period} -import scala.collection.JavaConverters._ -import scala.collection.mutable import scala.util.Random import org.scalactic.source.Position @@ -462,37 +460,6 @@ class CometExecSuite extends CometTestBase { } } - test("CometExec.executeColumnarCollectIterator can collect ColumnarBatch results") { - assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+") - withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "true") { - withParquetTable((0 until 50).map(i => (i, i + 1)), "tbl") { - val df = sql("SELECT _1 + 1, _2 + 2 FROM tbl WHERE _1 > 3") - - val nativeProject = find(df.queryExecution.executedPlan) { - case _: CometProjectExec => true - case _ => false - }.get.asInstanceOf[CometProjectExec] - - val (rows, batches) = nativeProject.executeColumnarCollectIterator() - assert(rows == 46) - - val column1 = mutable.ArrayBuffer.empty[Int] - val column2 = mutable.ArrayBuffer.empty[Int] - - batches.foreach(batch => { - batch.rowIterator().asScala.foreach { row => - assert(row.numFields == 2) - column1 += row.getInt(0) - column2 += row.getInt(1) - } - }) - - assert(column1.toArray.sorted === (4 until 50).map(_ + 1).toArray) - assert(column2.toArray.sorted === (5 until 51).map(_ + 2).toArray) - } - } - } - test("scalar subquery") { val dataTypes = Seq(