feat: Support multiple input sources for CometNativeExec (#87)
CometNativeExec currently limits the number of input sources, which blocks operators with multiple inputs such as joins. This patch generalizes input source handling to remove that limitation.
viirya authored Feb 22, 2024
1 parent 0cca52e commit 0ed183a
Showing 11 changed files with 279 additions and 77 deletions.
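
At a high level, the change moves CometNativeExec from a single child input to a sequence of input sources whose partitions are zipped together, using the new ZippedPartitionsRDD added in this commit. Below is a minimal sketch of the idea; only ZippedPartitionsRDD and ColumnarBatch come from the commit, while zipInputs and runNative are illustrative names, not the actual Comet API:

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.comet.ZippedPartitionsRDD
import org.apache.spark.sql.vectorized.ColumnarBatch

// Hypothetical sketch: instead of mapping over one child RDD, zip the
// partitions of all input sources and hand one iterator per source to the
// native execution callback.
def zipInputs(
    sc: SparkContext,
    inputs: Seq[RDD[ColumnarBatch]],
    runNative: Seq[Iterator[ColumnarBatch]] => Iterator[ColumnarBatch]): RDD[ColumnarBatch] =
  ZippedPartitionsRDD(sc, inputs) { iters =>
    // Each task sees the i-th partition of every input, co-indexed.
    runNative(iters)
  }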
CometSparkSessionExtensions.scala
@@ -237,7 +237,7 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
- CometProjectExec(nativeOp, op, op.projectList, op.output, op.child)
+ CometProjectExec(nativeOp, op, op.projectList, op.output, op.child, None)
case None =>
op
}
@@ -246,7 +246,7 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
- CometFilterExec(nativeOp, op, op.condition, op.child)
+ CometFilterExec(nativeOp, op, op.condition, op.child, None)
case None =>
op
}
@@ -255,7 +255,7 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
- CometSortExec(nativeOp, op, op.sortOrder, op.child)
+ CometSortExec(nativeOp, op, op.sortOrder, op.child, None)
case None =>
op
}
@@ -264,7 +264,7 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
- CometLocalLimitExec(nativeOp, op, op.limit, op.child)
+ CometLocalLimitExec(nativeOp, op, op.limit, op.child, None)
case None =>
op
}
@@ -273,7 +273,7 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
- CometGlobalLimitExec(nativeOp, op, op.limit, op.child)
+ CometGlobalLimitExec(nativeOp, op, op.limit, op.child, None)
case None =>
op
}
@@ -282,7 +282,7 @@ class CometSparkSessionExtensions
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
- CometExpandExec(nativeOp, op, op.projections, op.child)
+ CometExpandExec(nativeOp, op, op.projections, op.child, None)
case None =>
op
}
@@ -304,7 +304,8 @@ class CometSparkSessionExtensions
aggExprs,
child.output,
if (modes.nonEmpty) Some(modes.head) else None,
- child)
+ child,
+ None)
case None =>
op
}
@@ -425,10 +426,11 @@ class CometSparkSessionExtensions
newPlan.transformDown {
case op: CometNativeExec =>
if (firstNativeOp) {
- op.convertBlock()
firstNativeOp = false
+ op.convertBlock()
+ } else {
+ op
}
- op
case op =>
firstNativeOp = true
op
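
For context on the hunk above: `transformDown` rebuilds the plan tree from the values its partial function returns, so in the new version `op.convertBlock()` is the expression the case returns rather than a discarded statement. A hedged sketch of the resulting shape, annotated; the names come from the surrounding code and the comments are interpretive:

var firstNativeOp = true
val converted = newPlan.transformDown {
  case op: CometNativeExec =>
    if (firstNativeOp) {
      firstNativeOp = false
      op.convertBlock() // the returned node replaces op in the rebuilt tree
    } else {
      op // already inside a converted block, keep as-is
    }
  case op =>
    firstNativeOp = true // a non-native node ends the current block
    op
}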
QueryPlanSerde.scala
@@ -1718,7 +1718,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
case op if !op.isInstanceOf[CometPlan] =>
seenNonNativeOp = true
op
- case op @ CometHashAggregateExec(_, _, _, _, input, Some(Partial), _) =>
+ case op @ CometHashAggregateExec(_, _, _, _, input, Some(Partial), _, _) =>
if (!seenNonNativeOp && partialAggInput.isEmpty) {
partialAggInput = Some(input)
}
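
The extra `_` above is forced by the new constructor parameter: a case class's extractor arity grows with its fields, so every pattern must cover one more position. A toy illustration with invented names, not Comet's:

// Adding a field to a case class changes the arity of its extractor,
// so existing patterns need one more wildcard.
case class HashAgg(input: String, mode: Option[String], plan: Option[Array[Byte]])

def inputOfPartialAgg(op: HashAgg): Option[String] = op match {
  case HashAgg(input, Some("Partial"), _) => Some(input) // `_` covers the new field
  case _ => None
}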
ZippedPartitionsRDD.scala (new file)
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.spark.sql.comet

import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.{RDD, RDDOperationScope, ZippedPartitionsBaseRDD, ZippedPartitionsPartition}
import org.apache.spark.sql.vectorized.ColumnarBatch

/**
 * Similar to Spark's `ZippedPartitionsRDD[1-4]` classes, this class zips partitions of multiple
 * RDDs into a single RDD. The Spark `ZippedPartitionsRDD[1-4]` classes support at most 4 RDDs,
 * while this class supports an arbitrary number. It is used to zip the input sources of a Comet
 * physical plan, so it only zips partitions of `ColumnarBatch`.
 */
private[spark] class ZippedPartitionsRDD(
    sc: SparkContext,
    var f: (Seq[Iterator[ColumnarBatch]]) => Iterator[ColumnarBatch],
    var zipRdds: Seq[RDD[ColumnarBatch]],
    preservesPartitioning: Boolean = false)
    extends ZippedPartitionsBaseRDD[ColumnarBatch](sc, zipRdds, preservesPartitioning) {

  override def compute(s: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
    // A ZippedPartitionsPartition holds one partition per parent RDD, all with
    // the same index. Build an iterator over each parent's partition and hand
    // the whole sequence to the user function.
    val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
    val iterators =
      zipRdds.zipWithIndex.map(pair => pair._1.iterator(partitions(pair._2), context))
    f(iterators)
  }

  override def clearDependencies(): Unit = {
    // Null out parent references so the RDD does not retain them after
    // checkpointing, following the pattern of Spark's zipped RDDs.
    super.clearDependencies()
    zipRdds = null
    f = null
  }
}

object ZippedPartitionsRDD {
  def apply(sc: SparkContext, rdds: Seq[RDD[ColumnarBatch]])(
      f: (Seq[Iterator[ColumnarBatch]]) => Iterator[ColumnarBatch]): RDD[ColumnarBatch] =
    withScope(sc) {
      new ZippedPartitionsRDD(sc, f, rdds)
    }

  private[spark] def withScope[U](sc: SparkContext)(body: => U): U =
    RDDOperationScope.withScope[U](sc)(body)
}
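
A hedged usage sketch of the new helper; the two-input, join-like setup and all names below are illustrative, not from the commit:

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.comet.ZippedPartitionsRDD
import org.apache.spark.sql.vectorized.ColumnarBatch

// Hypothetical example: zip two co-partitioned ColumnarBatch RDDs, e.g. the
// two sides of a join, and receive one iterator per input in each task.
def zipTwoSides(
    sc: SparkContext,
    left: RDD[ColumnarBatch],
    right: RDD[ColumnarBatch]): RDD[ColumnarBatch] =
  ZippedPartitionsRDD(sc, Seq(left, right)) { case Seq(leftIter, rightIter) =>
    // A real operator would feed both iterators to native code; concatenating
    // here just shows the zipped shape.
    leftIter ++ rightIter
  }

As with Spark's built-in zipPartitions, every input RDD must have the same number of partitions; ZippedPartitionsBaseRDD enforces this at construction time.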