feat: Add executeColumnarCollectIterator to CometExec to collect Comet operator result (#71)
viirya authored Feb 21, 2024
1 parent 1acb56f commit 15b139e
Showing 3 changed files with 154 additions and 2 deletions.
common/src/main/scala/org/apache/comet/vector/NativeUtil.scala (70 additions, 0 deletions)
@@ -19,11 +19,16 @@

package org.apache.comet.vector

import java.io.OutputStream

import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.arrow.c.{ArrowArray, ArrowImporter, ArrowSchema, CDataDictionaryProvider, Data}
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector._
import org.apache.arrow.vector.dictionary.DictionaryProvider
import org.apache.arrow.vector.ipc.ArrowStreamWriter
import org.apache.spark.SparkException
import org.apache.spark.sql.vectorized.ColumnarBatch

@@ -32,6 +37,71 @@ class NativeUtil {
private val dictionaryProvider: CDataDictionaryProvider = new CDataDictionaryProvider
private val importer = new ArrowImporter(allocator)

/**
* Serializes an iterator of `ColumnarBatch` into an output stream as an Arrow IPC stream.
*
* @param batches
* the batches to serialize; each batch is a list of Arrow vectors wrapped in `CometVector`
* @param out
* the output stream to write to
* @return
* the total number of rows written across all batches
*/
def serializeBatches(batches: Iterator[ColumnarBatch], out: OutputStream): Long = {
var schemaRoot: Option[VectorSchemaRoot] = None
var writer: Option[ArrowStreamWriter] = None
var rowCount = 0

batches.foreach { batch =>
val (fieldVectors, batchProviderOpt) = getBatchFieldVectors(batch)
val root = schemaRoot.getOrElse(new VectorSchemaRoot(fieldVectors.asJava))
val provider = batchProviderOpt.getOrElse(dictionaryProvider)

if (writer.isEmpty) {
writer = Some(new ArrowStreamWriter(root, provider, out))
writer.get.start()
}
writer.get.writeBatch()

root.clear()
schemaRoot = Some(root)

rowCount += batch.numRows()
}

writer.map(_.end())
schemaRoot.map(_.close())

rowCount
}

def getBatchFieldVectors(
batch: ColumnarBatch): (Seq[FieldVector], Option[DictionaryProvider]) = {
var provider: Option[DictionaryProvider] = None
val fieldVectors = (0 until batch.numCols()).map { index =>
batch.column(index) match {
case a: CometVector =>
val valueVector = a.getValueVector
if (valueVector.getField.getDictionary != null) {
if (provider.isEmpty) {
provider = Some(a.getDictionaryProvider)
} else {
if (provider.get != a.getDictionaryProvider) {
throw new SparkException(
"Comet execution only takes Arrow Arrays with the same dictionary provider")
}
}
}

getFieldVector(valueVector)

case c =>
throw new SparkException(
"Comet execution only takes Arrow Arrays, but got " +
s"${c.getClass}")
}
}
(fieldVectors, provider)
}

/**
* Exports a Comet `ColumnarBatch` into a list of memory addresses that can be consumed by the
* native execution.
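
For reference, the stream produced by serializeBatches above is written with ArrowStreamWriter, i.e. it is a standard Arrow IPC stream, so it can be read back with the stock ArrowStreamReader. Below is a minimal round-trip sketch, not part of this commit, assuming `batches` is an `Iterator[ColumnarBatch]` whose columns are `CometVector`s:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.ipc.ArrowStreamReader

// Serialize the batches into an in-memory Arrow IPC stream.
val out = new ByteArrayOutputStream()
val written = new NativeUtil().serializeBatches(batches, out)

// Read the stream back with the standard Arrow reader and count the rows.
val allocator = new RootAllocator(Long.MaxValue)
val reader = new ArrowStreamReader(new ByteArrayInputStream(out.toByteArray), allocator)
var read = 0L
while (reader.loadNextBatch()) {
  read += reader.getVectorSchemaRoot.getRowCount
}
reader.close()
assert(read == written)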
spark/src/main/scala/org/apache/spark/sql/comet/operators.scala (50 additions, 2 deletions)
@@ -19,25 +19,31 @@

package org.apache.spark.sql.comet

import java.io.ByteArrayOutputStream
import java.io.{ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.nio.ByteBuffer
import java.nio.channels.Channels

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.TaskContext
import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.comet.execution.shuffle.ArrowReaderIterator
import org.apache.spark.sql.execution.{ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}

import com.google.common.base.Objects

import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException}
import org.apache.comet.serde.OperatorOuterClass.Operator
import org.apache.comet.vector.NativeUtil

/**
* A Comet physical operator
@@ -61,6 +67,48 @@ abstract class CometExec extends CometPlan {
override def outputOrdering: Seq[SortOrder] = originalPlan.outputOrdering

override def outputPartitioning: Partitioning = originalPlan.outputPartitioning

/**
* Executes this Comet operator and serializes the output ColumnarBatches into compressed bytes, one (row count, bytes) pair per partition.
*/
private def getByteArrayRdd(): RDD[(Long, ChunkedByteBuffer)] = {
executeColumnar().mapPartitionsInternal { iter =>
val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
val cbbos = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate)
val out = new DataOutputStream(codec.compressedOutputStream(cbbos))

val count = new NativeUtil().serializeBatches(iter, out)

out.flush()
out.close()
Iterator((count, cbbos.toChunkedByteBuffer))
}
}

/**
* Decodes the serialized, compressed bytes back into an iterator of ColumnarBatches.
*/
private def decodeBatches(bytes: ChunkedByteBuffer): Iterator[ColumnarBatch] = {
if (bytes.size == 0) {
return Iterator.empty
}

val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
val cbbis = bytes.toInputStream()
val ins = new DataInputStream(codec.compressedInputStream(cbbis))

new ArrowReaderIterator(Channels.newChannel(ins))
}

/**
* Executes the Comet operator, collects the serialized result to the driver, and returns the total row count together with an iterator of ColumnarBatches.
*/
def executeColumnarCollectIterator(): (Long, Iterator[ColumnarBatch]) = {
val countsAndBytes = getByteArrayRdd().collect()
val total = countsAndBytes.map(_._1).sum
val rows = countsAndBytes.iterator.flatMap(countAndBytes => decodeBatches(countAndBytes._2))
(total, rows)
}
}
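
As a hypothetical driver-side usage sketch (not part of this commit; the real coverage is the CometExecSuite test further down), a caller holding a CometExec node could consume the collected result as below. The name `cometExec` and the assumption that the first output column holds Ints are illustrative only:

// Hypothetical usage: sum the first (assumed Int) column of the collected result.
val (totalRows, batches) = cometExec.executeColumnarCollectIterator()

var sum = 0L
var seen = 0L
batches.foreach { batch =>
  val col = batch.column(0)
  (0 until batch.numRows()).foreach { i =>
    if (!col.isNullAt(i)) sum += col.getInt(i)
  }
  seen += batch.numRows()
}
assert(seen == totalRows)

Note that although the compressed bytes are collected to the driver eagerly, they are decompressed and decoded into ColumnarBatches lazily as the iterator is advanced.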

object CometExec {
spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala (34 additions, 0 deletions)
@@ -19,6 +19,8 @@

package org.apache.comet.exec

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.Random

import org.scalactic.source.Position
@@ -53,6 +55,38 @@ class CometExecSuite extends CometTestBase {
}
}

test("CometExec.executeColumnarCollectIterator can collect ColumnarBatch results") {
withSQLConf(
CometConf.COMET_EXEC_ENABLED.key -> "true",
CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true") {
withParquetTable((0 until 50).map(i => (i, i + 1)), "tbl") {
val df = sql("SELECT _1 + 1, _2 + 2 FROM tbl WHERE _1 > 3")

val nativeProject = find(df.queryExecution.executedPlan) {
case _: CometProjectExec => true
case _ => false
}.get.asInstanceOf[CometProjectExec]

val (rows, batches) = nativeProject.executeColumnarCollectIterator()
assert(rows == 46)

val column1 = mutable.ArrayBuffer.empty[Int]
val column2 = mutable.ArrayBuffer.empty[Int]

batches.foreach(batch => {
batch.rowIterator().asScala.foreach { row =>
assert(row.numFields == 2)
column1 += row.getInt(0)
column2 += row.getInt(1)
}
})

assert(column1.toArray.sorted === (4 until 50).map(_ + 1).toArray)
assert(column2.toArray.sorted === (5 until 51).map(_ + 2).toArray)
}
}
}

test("scalar subquery") {
val dataTypes =
Seq(