diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
index 4bb63e501..1682295f7 100644
--- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
+++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -19,11 +19,16 @@
 
 package org.apache.comet.vector
 
+import java.io.OutputStream
+
+import scala.collection.JavaConverters._
 import scala.collection.mutable
 
 import org.apache.arrow.c.{ArrowArray, ArrowImporter, ArrowSchema, CDataDictionaryProvider, Data}
 import org.apache.arrow.memory.RootAllocator
 import org.apache.arrow.vector._
+import org.apache.arrow.vector.dictionary.DictionaryProvider
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
 import org.apache.spark.SparkException
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
@@ -32,6 +37,71 @@ class NativeUtil {
   private val dictionaryProvider: CDataDictionaryProvider = new CDataDictionaryProvider
   private val importer = new ArrowImporter(allocator)
 
+  /**
+   * Serializes an iterator of `ColumnarBatch` into an output stream, returning the row count.
+   *
+   * @param batches
+   *   the batches to serialize; each batch is a list of Arrow vectors wrapped in `CometVector`
+   * @param out
+   *   the output stream
+   */
+  def serializeBatches(batches: Iterator[ColumnarBatch], out: OutputStream): Long = {
+    var schemaRoot: Option[VectorSchemaRoot] = None
+    var writer: Option[ArrowStreamWriter] = None
+    var rowCount = 0
+
+    batches.foreach { batch =>
+      val (fieldVectors, batchProviderOpt) = getBatchFieldVectors(batch)
+      val root = schemaRoot.getOrElse(new VectorSchemaRoot(fieldVectors.asJava))
+      val provider = batchProviderOpt.getOrElse(dictionaryProvider)
+
+      if (writer.isEmpty) {
+        writer = Some(new ArrowStreamWriter(root, provider, out))
+        writer.get.start()
+      }
+      writer.get.writeBatch()
+
+      root.clear()
+      schemaRoot = Some(root)
+
+      rowCount += batch.numRows()
+    }
+
+    writer.map(_.end())
+    schemaRoot.map(_.close())
+
+    rowCount
+  }
+
+  def getBatchFieldVectors(
+      batch: ColumnarBatch): (Seq[FieldVector], Option[DictionaryProvider]) = {
+    var provider: Option[DictionaryProvider] = None
+    val fieldVectors = (0 until batch.numCols()).map { index =>
+      batch.column(index) match {
+        case a: CometVector =>
+          val valueVector = a.getValueVector
+          if (valueVector.getField.getDictionary != null) {
+            if (provider.isEmpty) {
+              provider = Some(a.getDictionaryProvider)
+            } else {
+              if (provider.get != a.getDictionaryProvider) {
+                throw new SparkException(
+                  "Comet execution only takes Arrow Arrays with the same dictionary provider")
+              }
+            }
+          }
+
+          getFieldVector(valueVector)
+
+        case c =>
+          throw new SparkException(
+            "Comet execution only takes Arrow Arrays, but got " +
+              s"${c.getClass}")
+      }
+    }
+    (fieldVectors, provider)
+  }
+
   /**
    * Exports a Comet `ColumnarBatch` into a list of memory addresses that can be consumed by the
    * native execution.
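Reviewer note, not part of the patch: `serializeBatches` writes a plain Arrow IPC stream, so its output can be checked independently with Arrow's `ArrowStreamReader`. A minimal sketch of that round trip follows; the object and method names are illustrative, and `batches` is assumed to be an iterator of Comet-produced `ColumnarBatch`es.

// Hypothetical verification helper, not part of the patch.
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.ipc.ArrowStreamReader
import org.apache.spark.sql.vectorized.ColumnarBatch

import org.apache.comet.vector.NativeUtil

object SerializeBatchesRoundTrip {
  def roundTripRowCount(batches: Iterator[ColumnarBatch]): Long = {
    val out = new ByteArrayOutputStream()
    // Write the batches as an Arrow IPC stream and remember the reported row count.
    val written = new NativeUtil().serializeBatches(batches, out)

    // Read the same stream back with Arrow's stock IPC reader and re-count the rows.
    val allocator = new RootAllocator(Long.MaxValue)
    val reader = new ArrowStreamReader(new ByteArrayInputStream(out.toByteArray), allocator)
    var read = 0L
    while (reader.loadNextBatch()) {
      read += reader.getVectorSchemaRoot.getRowCount
    }
    reader.close()
    allocator.close()
    assert(read == written)
    written
  }
}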
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 4d8011e08..12f086c76 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -19,25 +19,31 @@
 
 package org.apache.spark.sql.comet
 
-import java.io.ByteArrayOutputStream
+import java.io.{ByteArrayOutputStream, DataInputStream, DataOutputStream}
+import java.nio.ByteBuffer
+import java.nio.channels.Channels
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.TaskContext
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.io.CompressionCodec
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, NamedExpression, SortOrder}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode}
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.comet.execution.shuffle.ArrowReaderIterator
 import org.apache.spark.sql.execution.{ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
 
 import com.google.common.base.Objects
 
 import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException}
 import org.apache.comet.serde.OperatorOuterClass.Operator
+import org.apache.comet.vector.NativeUtil
 
 /**
  * A Comet physical operator
@@ -61,6 +67,48 @@ abstract class CometExec extends CometPlan {
   override def outputOrdering: Seq[SortOrder] = originalPlan.outputOrdering
 
   override def outputPartitioning: Partitioning = originalPlan.outputPartitioning
+
+  /**
+   * Executes this Comet operator and serializes the output ColumnarBatches into bytes.
+   */
+  private def getByteArrayRdd(): RDD[(Long, ChunkedByteBuffer)] = {
+    executeColumnar().mapPartitionsInternal { iter =>
+      val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
+      val cbbos = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate)
+      val out = new DataOutputStream(codec.compressedOutputStream(cbbos))
+
+      val count = new NativeUtil().serializeBatches(iter, out)
+
+      out.flush()
+      out.close()
+      Iterator((count, cbbos.toChunkedByteBuffer))
+    }
+  }
+
+  /**
+   * Decodes the serialized bytes back into an iterator of ColumnarBatches.
+   */
+  private def decodeBatches(bytes: ChunkedByteBuffer): Iterator[ColumnarBatch] = {
+    if (bytes.size == 0) {
+      return Iterator.empty
+    }
+
+    val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
+    val cbbis = bytes.toInputStream()
+    val ins = new DataInputStream(codec.compressedInputStream(cbbis))
+
+    new ArrowReaderIterator(Channels.newChannel(ins))
+  }
+
+  /**
+   * Executes the Comet operator and returns the result as an iterator of ColumnarBatch.
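Reviewer note, not part of the patch: `getByteArrayRdd` and `decodeBatches` combine a Spark `CompressionCodec` with a `ChunkedByteBufferOutputStream`. The sketch below exercises just that compress/decompress round trip; it assumes a fresh `SparkConf` in place of `SparkEnv.get.conf`, and since these helpers are `private[spark]` it is placed under the same `org.apache.spark.sql.comet` package as operators.scala. The object name and payload are illustrative only.

package org.apache.spark.sql.comet

import java.io.{DataInputStream, DataOutputStream}
import java.nio.ByteBuffer

import org.apache.spark.SparkConf
import org.apache.spark.io.CompressionCodec
import org.apache.spark.util.io.ChunkedByteBufferOutputStream

object CompressedChunkedBufferRoundTrip {
  def main(args: Array[String]): Unit = {
    // Same codec-selection call as getByteArrayRdd, but driven by a fresh SparkConf (assumption).
    val codec = CompressionCodec.createCodec(new SparkConf())

    // Compress a small payload into 1 MiB heap ByteBuffer chunks, mirroring getByteArrayRdd.
    val cbbos = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate)
    val out = new DataOutputStream(codec.compressedOutputStream(cbbos))
    out.writeUTF("stand-in for the Arrow IPC bytes")
    out.close()

    // Decompress from the resulting ChunkedByteBuffer, mirroring decodeBatches.
    val in = new DataInputStream(
      codec.compressedInputStream(cbbos.toChunkedByteBuffer.toInputStream()))
    assert(in.readUTF() == "stand-in for the Arrow IPC bytes")
    in.close()
  }
}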
+   */
+  def executeColumnarCollectIterator(): (Long, Iterator[ColumnarBatch]) = {
+    val countsAndBytes = getByteArrayRdd().collect()
+    val total = countsAndBytes.map(_._1).sum
+    val rows = countsAndBytes.iterator.flatMap(countAndBytes => decodeBatches(countAndBytes._2))
+    (total, rows)
+  }
 }
 
 object CometExec {
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 1334f2f77..e1f864249 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -19,6 +19,8 @@
 
 package org.apache.comet.exec
 
+import scala.collection.JavaConverters._
+import scala.collection.mutable
 import scala.util.Random
 
 import org.scalactic.source.Position
@@ -53,6 +55,38 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
+  test("CometExec.executeColumnarCollectIterator can collect ColumnarBatch results") {
+    withSQLConf(
+      CometConf.COMET_EXEC_ENABLED.key -> "true",
+      CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true") {
+      withParquetTable((0 until 50).map(i => (i, i + 1)), "tbl") {
+        val df = sql("SELECT _1 + 1, _2 + 2 FROM tbl WHERE _1 > 3")
+
+        val nativeProject = find(df.queryExecution.executedPlan) {
+          case _: CometProjectExec => true
+          case _ => false
+        }.get.asInstanceOf[CometProjectExec]
+
+        val (rows, batches) = nativeProject.executeColumnarCollectIterator()
+        assert(rows == 46)
+
+        val column1 = mutable.ArrayBuffer.empty[Int]
+        val column2 = mutable.ArrayBuffer.empty[Int]
+
+        batches.foreach(batch => {
+          batch.rowIterator().asScala.foreach { row =>
+            assert(row.numFields == 2)
+            column1 += row.getInt(0)
+            column2 += row.getInt(1)
+          }
+        })
+
+        assert(column1.toArray.sorted === (4 until 50).map(_ + 1).toArray)
+        assert(column2.toArray.sorted === (5 until 51).map(_ + 2).toArray)
+      }
+    }
+  }
+
   test("scalar subquery") {
     val dataTypes = Seq(
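Reviewer note, not part of the patch: beyond the row-level assertions in the new test, the count returned by `executeColumnarCollectIterator` can be cross-checked against the batches themselves. A short hedged sketch that would sit inside the test above, after `nativeProject` is obtained:

// Drain the returned batch iterator and compare against the reported row count (illustrative only).
val (reportedRows, batchIter) = nativeProject.executeColumnarCollectIterator()
var seen = 0L
batchIter.foreach(batch => seen += batch.numRows())
assert(seen == reportedRows)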