feat: Add executeColumnarCollectIterator to CometExec to collect Comet operator result #71

Merged: 1 commit, merged on Feb 21, 2024
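
This patch adds an executeColumnarCollectIterator method to CometExec: each partition serializes its output batches into a compressed Arrow stream, the bytes are collected to the driver, and they are decoded back into ColumnarBatches there. A minimal usage sketch follows (not taken from this PR; the helper object, its name, and the DataFrame setup are illustrative and mirror the test added at the bottom of the diff):

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.comet.CometProjectExec
import org.apache.spark.sql.vectorized.ColumnarBatch

object CometCollectExample {
  // Hypothetical helper: find the first Comet project operator in the executed
  // plan and collect its results on the driver with the API added by this PR.
  def collectCometProject(df: DataFrame): (Long, Iterator[ColumnarBatch]) = {
    val cometProject = df.queryExecution.executedPlan
      .collectFirst { case p: CometProjectExec => p }
      .getOrElse(throw new IllegalStateException("No CometProjectExec in the plan"))

    // Returns the total row count together with an iterator of decoded batches.
    cometProject.executeColumnarCollectIterator()
  }
}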
70 changes: 70 additions & 0 deletions common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -19,11 +19,16 @@

package org.apache.comet.vector

import java.io.OutputStream

import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.arrow.c.{ArrowArray, ArrowImporter, ArrowSchema, CDataDictionaryProvider, Data}
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector._
import org.apache.arrow.vector.dictionary.DictionaryProvider
import org.apache.arrow.vector.ipc.ArrowStreamWriter
import org.apache.spark.SparkException
import org.apache.spark.sql.vectorized.ColumnarBatch

@@ -32,6 +37,71 @@ class NativeUtil {
private val dictionaryProvider: CDataDictionaryProvider = new CDataDictionaryProvider
private val importer = new ArrowImporter(allocator)

  /**
   * Serializes a list of `ColumnarBatch` into an output stream.
   *
   * @param batches
   *   the batches to serialize; each batch is a list of Arrow vectors wrapped in `CometVector`
   * @param out
   *   the output stream
   * @return
   *   the total number of rows written across all batches
   */
  def serializeBatches(batches: Iterator[ColumnarBatch], out: OutputStream): Long = {
    var schemaRoot: Option[VectorSchemaRoot] = None
    var writer: Option[ArrowStreamWriter] = None
    var rowCount = 0

    batches.foreach { batch =>
      val (fieldVectors, batchProviderOpt) = getBatchFieldVectors(batch)
      val root = schemaRoot.getOrElse(new VectorSchemaRoot(fieldVectors.asJava))
      val provider = batchProviderOpt.getOrElse(dictionaryProvider)

      if (writer.isEmpty) {
        writer = Some(new ArrowStreamWriter(root, provider, out))
        writer.get.start()
      }
      writer.get.writeBatch()

      root.clear()
      schemaRoot = Some(root)

      rowCount += batch.numRows()
    }

    writer.map(_.end())
    schemaRoot.map(_.close())

    rowCount
  }

  def getBatchFieldVectors(
      batch: ColumnarBatch): (Seq[FieldVector], Option[DictionaryProvider]) = {
    var provider: Option[DictionaryProvider] = None
    val fieldVectors = (0 until batch.numCols()).map { index =>
      batch.column(index) match {
        case a: CometVector =>
          val valueVector = a.getValueVector
          if (valueVector.getField.getDictionary != null) {
            if (provider.isEmpty) {
              provider = Some(a.getDictionaryProvider)
            } else {
              if (provider.get != a.getDictionaryProvider) {
                throw new SparkException(
                  "Comet execution only takes Arrow Arrays with the same dictionary provider")
              }
            }
          }

          getFieldVector(valueVector)

        case c =>
          throw new SparkException(
            "Comet execution only takes Arrow Arrays, but got " +
              s"${c.getClass}")
      }
    }
    (fieldVectors, provider)
  }

/**
* Exports a Comet `ColumnarBatch` into a list of memory addresses that can be consumed by the
* native execution.
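For context, a minimal sketch (not part of this PR) of what the stream produced by serializeBatches looks like from the consumer side. ArrowStreamWriter emits Arrow's IPC streaming format, so a plain ArrowStreamReader can walk the batches back; in the PR itself the decoding side is ArrowReaderIterator, used in operators.scala below. The helper object and its name are illustrative:

import java.io.ByteArrayInputStream

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.ipc.ArrowStreamReader

object ArrowStreamReadExample {
  // Counts the rows in a serialized Arrow IPC stream, batch by batch.
  def countRows(bytes: Array[Byte]): Long = {
    val allocator = new RootAllocator(Long.MaxValue)
    val reader = new ArrowStreamReader(new ByteArrayInputStream(bytes), allocator)
    var rows = 0L
    try {
      while (reader.loadNextBatch()) {
        rows += reader.getVectorSchemaRoot.getRowCount
      }
    } finally {
      reader.close()
      allocator.close()
    }
    rows
  }
}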
52 changes: 50 additions & 2 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -19,25 +19,31 @@

package org.apache.spark.sql.comet

import java.io.ByteArrayOutputStream
import java.io.{ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.nio.ByteBuffer
import java.nio.channels.Channels

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.TaskContext
import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.comet.execution.shuffle.ArrowReaderIterator
import org.apache.spark.sql.execution.{ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}

import com.google.common.base.Objects

import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException}
import org.apache.comet.serde.OperatorOuterClass.Operator
import org.apache.comet.vector.NativeUtil

/**
* A Comet physical operator
@@ -61,6 +67,48 @@ abstract class CometExec extends CometPlan {
override def outputOrdering: Seq[SortOrder] = originalPlan.outputOrdering

override def outputPartitioning: Partitioning = originalPlan.outputPartitioning

  /**
   * Executes this Comet operator and serializes the output ColumnarBatches into compressed bytes,
   * one chunk of bytes per partition.
   */
  private def getByteArrayRdd(): RDD[(Long, ChunkedByteBuffer)] = {
    executeColumnar().mapPartitionsInternal { iter =>
      val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
      val cbbos = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate)
      val out = new DataOutputStream(codec.compressedOutputStream(cbbos))

      val count = new NativeUtil().serializeBatches(iter, out)

      out.flush()
      out.close()
      Iterator((count, cbbos.toChunkedByteBuffer))
    }
  }

  /**
   * Decodes the serialized, compressed bytes back into ColumnarBatches and returns them as an
   * iterator.
   */
  private def decodeBatches(bytes: ChunkedByteBuffer): Iterator[ColumnarBatch] = {
    if (bytes.size == 0) {
      return Iterator.empty
    }

    val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
    val cbbis = bytes.toInputStream()
    val ins = new DataInputStream(codec.compressedInputStream(cbbis))

    new ArrowReaderIterator(Channels.newChannel(ins))
  }

  /**
   * Executes this Comet operator and returns the total row count together with an iterator over
   * the collected ColumnarBatch results.
   */
  def executeColumnarCollectIterator(): (Long, Iterator[ColumnarBatch]) = {
    val countsAndBytes = getByteArrayRdd().collect()
    val total = countsAndBytes.map(_._1).sum
    val rows = countsAndBytes.iterator.flatMap(countAndBytes => decodeBatches(countAndBytes._2))
    (total, rows)
  }
}

object CometExec {
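As a side note, a minimal sketch (not from this PR) of the compression round trip that getByteArrayRdd and decodeBatches wrap around the Arrow stream. It is declared inside an org.apache.spark package because the CompressionCodec factory is Spark-internal; the package, object, and method names are illustrative:

package org.apache.spark.sql.comet.example

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import org.apache.spark.SparkConf
import org.apache.spark.io.CompressionCodec

object CodecRoundTripExample {
  // Compresses a payload the way the executor side does before collect(), then
  // decompresses it the way the driver side does before decoding Arrow batches.
  def roundTrip(payload: Array[Byte], conf: SparkConf): Array[Byte] = {
    val codec = CompressionCodec.createCodec(conf)

    val bos = new ByteArrayOutputStream()
    val out = codec.compressedOutputStream(bos)
    out.write(payload)
    out.close()

    val in = codec.compressedInputStream(new ByteArrayInputStream(bos.toByteArray))
    val decompressed = new ByteArrayOutputStream()
    val buf = new Array[Byte](4096)
    var n = in.read(buf)
    while (n != -1) {
      decompressed.write(buf, 0, n)
      n = in.read(buf)
    }
    in.close()
    decompressed.toByteArray
  }
}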
34 changes: 34 additions & 0 deletions spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -19,6 +19,8 @@

package org.apache.comet.exec

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.Random

import org.scalactic.source.Position
@@ -53,6 +55,38 @@ class CometExecSuite extends CometTestBase {
}
}

test("CometExec.executeColumnarCollectIterator can collect ColumnarBatch results") {
withSQLConf(
CometConf.COMET_EXEC_ENABLED.key -> "true",
CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true") {
withParquetTable((0 until 50).map(i => (i, i + 1)), "tbl") {
val df = sql("SELECT _1 + 1, _2 + 2 FROM tbl WHERE _1 > 3")

val nativeProject = find(df.queryExecution.executedPlan) {
case _: CometProjectExec => true
case _ => false
}.get.asInstanceOf[CometProjectExec]

val (rows, batches) = nativeProject.executeColumnarCollectIterator()
assert(rows == 46)

val column1 = mutable.ArrayBuffer.empty[Int]
val column2 = mutable.ArrayBuffer.empty[Int]

batches.foreach(batch => {
batch.rowIterator().asScala.foreach { row =>
assert(row.numFields == 2)
column1 += row.getInt(0)
column2 += row.getInt(1)
}
})

assert(column1.toArray.sorted === (4 until 50).map(_ + 1).toArray)
assert(column2.toArray.sorted === (5 until 51).map(_ + 2).toArray)
}
}
}

test("scalar subquery") {
val dataTypes =
Seq(
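A quick sanity check on the expected values asserted in the test above (a standalone sketch, not part of the PR): rows 0 through 49 filtered by _1 > 3 leave 46 rows, the first projected column is _1 + 1 over 4..49, and the second is _2 + 2 over 5..50.

object ExpectedValuesCheck extends App {
  // Reproduce the test's expectations with plain Scala collections.
  val input = (0 until 50).map(i => (i, i + 1))
  val kept = input.filter(_._1 > 3)

  assert(kept.size == 46)
  assert(kept.map(_._1 + 1) == (4 until 50).map(_ + 1)) // 5 .. 50
  assert(kept.map(_._2 + 2) == (5 until 51).map(_ + 2)) // 7 .. 52
}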