diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index 33cf636a2e..e4767b4e1e 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -23,7 +23,10 @@ use arrow_schema::{DataType, Field, Schema, TimeUnit};
 use datafusion::{
     arrow::{compute::SortOptions, datatypes::SchemaRef},
     common::DataFusionError,
-    logical_expr::{BuiltinScalarFunction, Operator as DataFusionOperator},
+    logical_expr::{
+        expr::find_df_window_func, BuiltinScalarFunction, Operator as DataFusionOperator,
+        WindowFrame, WindowFrameBound, WindowFrameUnits,
+    },
     physical_expr::{
         expressions::{BinaryExpr, Column, IsNotNullExpr, Literal as DataFusionLiteral},
         functions::create_physical_expr,
@@ -35,7 +38,8 @@ use datafusion::{
         limit::LocalLimitExec,
         projection::ProjectionExec,
         sorts::sort::SortExec,
-        ExecutionPlan, Partitioning,
+        windows::BoundedWindowAggExec,
+        ExecutionPlan, InputOrderMode, Partitioning, WindowExpr,
     },
 };
 use datafusion_common::ScalarValue;
@@ -74,12 +78,14 @@ use crate::{
         },
         operators::{CopyExec, ExecutionError, ScanExec},
         serde::to_arrow_datatype,
-        spark_expression,
         spark_expression::{
-            agg_expr::ExprStruct as AggExprStruct, expr::ExprStruct, literal::Value, AggExpr, Expr,
-            ScalarFunc,
+            self, agg_expr::ExprStruct as AggExprStruct, expr::ExprStruct, literal::Value, AggExpr,
+            Expr, ScalarFunc,
+        },
+        spark_operator::{
+            lower_window_frame_bound::LowerFrameBoundStruct, operator::OpStruct,
+            upper_window_frame_bound::UpperFrameBoundStruct, Operator, WindowFrameType,
         },
-        spark_operator::{operator::OpStruct, Operator},
         spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning},
     },
 };
@@ -797,6 +803,47 @@ impl PhysicalPlanner {
                     )?),
                 ))
             }
+            OpStruct::Window(wnd) => {
+                let (scans, child) = self.create_plan(&children[0], inputs)?;
+                let input_schema = child.schema();
+                let sort_exprs: Result<Vec<PhysicalSortExpr>, ExecutionError> = wnd
+                    .order_by_list
+                    .iter()
+                    .map(|expr| self.create_sort_expr(expr, input_schema.clone()))
+                    .collect();
+
+                let partition_exprs: Result<Vec<Arc<dyn PhysicalExpr>>, ExecutionError> = wnd
+                    .partition_by_list
+                    .iter()
+                    .map(|expr| self.create_expr(expr, input_schema.clone()))
+                    .collect();
+
+                let sort_exprs = &sort_exprs?;
+                let partition_exprs = &partition_exprs?;
+
+                let window_expr: Result<Vec<Arc<dyn WindowExpr>>, ExecutionError> = wnd
+                    .window_expr
+                    .iter()
+                    .map(|expr| {
+                        self.create_window_expr(
+                            expr,
+                            input_schema.clone(),
+                            partition_exprs,
+                            sort_exprs,
+                        )
+                    })
+                    .collect();
+
+                Ok((
+                    scans,
+                    Arc::new(BoundedWindowAggExec::try_new(
+                        window_expr?,
+                        child,
+                        partition_exprs.to_vec(),
+                        InputOrderMode::Sorted,
+                    )?),
+                ))
+            }
             OpStruct::Expand(expand) => {
                 assert!(children.len() == 1);
                 let (scans, child) = self.create_plan(&children[0], inputs)?;
@@ -934,6 +981,105 @@ impl PhysicalPlanner {
         }
     }
 
+    /// Create a DataFusion window physical expression from a Spark physical expression
+    fn create_window_expr<'a>(
+        &'a self,
+        spark_expr: &'a crate::execution::spark_operator::WindowExpr,
+        input_schema: SchemaRef,
+        partition_by: &[Arc<dyn PhysicalExpr>],
+        sort_exprs: &[PhysicalSortExpr],
+    ) -> Result<Arc<dyn WindowExpr>, ExecutionError> {
+        let (window_func_name, window_func_args) =
+            match &spark_expr.func.as_ref().unwrap().expr_struct.as_ref() {
+                Some(ExprStruct::ScalarFunc(f)) => (f.func.clone(), f.args.clone()),
+                other => {
+                    return Err(ExecutionError::GeneralError(format!(
+                        "{other:?} not supported for window function"
+                    )))
+                }
+            };
+
+        let window_func = match find_df_window_func(&window_func_name) {
+            Some(f) => f,
+            _ => {
+                return Err(ExecutionError::GeneralError(format!(
+                    "{window_func_name} not supported for window function"
+                )))
+            }
+        };
+
+        let window_args = window_func_args
+            .iter()
+            .map(|expr| self.create_expr(expr, input_schema.clone()))
+            .collect::<Result<Vec<_>, ExecutionError>>()?;
+
+        let spark_window_frame = match spark_expr
+            .spec
+            .as_ref()
+            .and_then(|inner| inner.frame_specification.as_ref())
+        {
+            Some(frame) => frame,
+            _ => {
+                return Err(ExecutionError::DeserializeError(
+                    "Cannot deserialize window frame".to_string(),
+                ))
+            }
+        };
+
+        let units = match spark_window_frame.frame_type() {
+            WindowFrameType::Rows => WindowFrameUnits::Rows,
+            WindowFrameType::Range => WindowFrameUnits::Range,
+        };
+
+        let lower_bound: WindowFrameBound = match spark_window_frame
+            .lower_bound
+            .as_ref()
+            .and_then(|inner| inner.lower_frame_bound_struct.as_ref())
+        {
+            Some(l) => match l {
+                LowerFrameBoundStruct::UnboundedPreceding(_) => {
+                    WindowFrameBound::Preceding(ScalarValue::Null)
+                }
+                LowerFrameBoundStruct::Preceding(offset) => {
+                    WindowFrameBound::Preceding(ScalarValue::Int32(Some(offset.offset)))
+                }
+                LowerFrameBoundStruct::CurrentRow(_) => WindowFrameBound::CurrentRow,
+            },
+            None => WindowFrameBound::Preceding(ScalarValue::Null),
+        };
+
+        // UNBOUNDED FOLLOWING and <n> FOLLOWING must map to Following bounds, not Preceding ones.
+        let upper_bound: WindowFrameBound = match spark_window_frame
+            .upper_bound
+            .as_ref()
+            .and_then(|inner| inner.upper_frame_bound_struct.as_ref())
+        {
+            Some(u) => match u {
+                UpperFrameBoundStruct::UnboundedFollowing(_) => {
+                    WindowFrameBound::Following(ScalarValue::Null)
+                }
+                UpperFrameBoundStruct::Following(offset) => {
+                    WindowFrameBound::Following(ScalarValue::Int32(Some(offset.offset)))
+                }
+                UpperFrameBoundStruct::CurrentRow(_) => WindowFrameBound::CurrentRow,
+            },
+            None => WindowFrameBound::Following(ScalarValue::Null),
+        };
+
+        let window_frame = WindowFrame::new_bounds(units, lower_bound, upper_bound);
+
+        datafusion::physical_plan::windows::create_window_expr(
+            &window_func,
+            window_func_name,
+            &window_args,
+            partition_by,
+            sort_exprs,
+            window_frame.into(),
+            &input_schema,
+        )
+        .map_err(|e| ExecutionError::DataFusionError(e.to_string()))
+    }
+
     /// Create a DataFusion physical partitioning from Spark physical partitioning
     fn create_partitioning(
         &self,
diff --git a/core/src/execution/proto/operator.proto b/core/src/execution/proto/operator.proto
index 5b07cb30b1..85ff1e366b 100644
--- a/core/src/execution/proto/operator.proto
+++ b/core/src/execution/proto/operator.proto
@@ -40,6 +40,7 @@ message Operator {
     Limit limit = 105;
     ShuffleWriter shuffle_writer = 106;
     Expand expand = 107;
+    Window window = 108;
   }
 }
@@ -87,3 +88,60 @@ message Expand {
   repeated spark.spark_expression.Expr project_list = 1;
   int32 num_expr_per_project = 3;
 }
+
+message WindowExpr {
+  spark.spark_expression.Expr func = 1;
+  WindowSpecDefinition spec = 2;
+}
+
+enum WindowFrameType {
+  Rows = 0;
+  Range = 1;
+}
+
+message WindowFrame {
+  WindowFrameType frame_type = 1;
+  LowerWindowFrameBound lower_bound = 2;
+  UpperWindowFrameBound upper_bound = 3;
+}
+
+message LowerWindowFrameBound {
+  oneof lower_frame_bound_struct {
+    UnboundedPreceding unboundedPreceding = 1;
+    Preceding preceding = 2;
+    CurrentRow currentRow = 3;
+  }
+}
+
+message UpperWindowFrameBound {
+  oneof upper_frame_bound_struct {
+    UnboundedFollowing unboundedFollowing = 1;
+    Following following = 2;
+    CurrentRow currentRow = 3;
+  }
+}
+
+message Preceding {
+  int32 offset = 1;
+}
+
+message Following {
+  int32 offset = 1;
+}
+
+message UnboundedPreceding {}
+message UnboundedFollowing {}
+message CurrentRow {}
+
+message WindowSpecDefinition {
+  repeated spark.spark_expression.Expr partitionSpec = 1;
+  repeated spark.spark_expression.Expr orderSpec = 2;
+  WindowFrame frameSpecification = 3;
+}
+
+message Window {
+  repeated WindowExpr window_expr = 1;
+  repeated spark.spark_expression.Expr order_by_list = 2;
+  repeated spark.spark_expression.Expr partition_by_list = 3;
+  Operator child = 4;
+}
\ No newline at end of file
diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 87c2265fcb..de18b9d3c6 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -20,7 +20,6 @@ package org.apache.comet
 
 import java.nio.ByteOrder
-
 import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
 import org.apache.spark.network.util.ByteUnit
@@ -40,13 +39,13 @@ import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-
 import org.apache.comet.CometConf._
 import org.apache.comet.CometSparkSessionExtensions.{isANSIEnabled, isCometBroadCastEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported}
 import org.apache.comet.parquet.{CometParquetScan, SupportsComet}
 import org.apache.comet.serde.OperatorOuterClass.Operator
 import org.apache.comet.serde.QueryPlanSerde
 import org.apache.comet.shims.ShimCometSparkSessionExtensions
+import org.apache.spark.sql.execution.window.WindowExec
 
 class CometSparkSessionExtensions
     extends (SparkSessionExtensions => Unit)
@@ -357,6 +356,16 @@ class CometSparkSessionExtensions
             s
           }
 
+        case w: WindowExec =>
+          QueryPlanSerde.operator2Proto(w) match {
+            case Some(nativeOp) =>
+              val cometOp =
+                CometWindowExec(w, w.windowExpression, w.partitionSpec, w.orderSpec, w.child)
+              CometSinkPlaceHolder(nativeOp, w, cometOp)
+            case None =>
+              w
+          }
+
         case u: UnionExec
             if isCometOperatorEnabled(conf, "union") &&
               u.children.forall(isCometNative) =>
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 8b4eaa6092..036c726cc0 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -20,7 +20,6 @@ package org.apache.comet.serde
 
 import scala.collection.JavaConverters._
-
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, Count, Final, First, Last, Max, Min, Partial, Sum}
@@ -36,12 +35,12 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
-
 import org.apache.comet.CometSparkSessionExtensions.{isCometOperatorEnabled, isCometScan, isSpark32, isSpark34Plus}
-import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc}
+import org.apache.comet.serde.ExprOuterClass.{AggExpr, Expr, ScalarFunc, DataType => ProtoDataType}
 import org.apache.comet.serde.ExprOuterClass.DataType.{DataTypeInfo, DecimalInfo, ListInfo, MapInfo, StructInfo}
-import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, Operator}
+import org.apache.comet.serde.OperatorOuterClass.{Operator, AggregateMode => CometAggregateMode}
 import org.apache.comet.shims.ShimQueryPlanSerde
+import org.apache.spark.sql.execution.window.WindowExec
 
 /**
  * An utility object for query plan and expression serialization.
@@ -186,6 +185,91 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
     }
   }
 
+  def windowExprToProto(
+      windowExpr: WindowExpression,
+      inputs: Seq[Attribute]): Option[OperatorOuterClass.WindowExpr] = {
+    val func = exprToProto(windowExpr.windowFunction, inputs).getOrElse(return None)
+
+    val f = windowExpr.windowSpec.frameSpecification
+
+    val (frameType, lowerBound, upperBound) = f match {
+      case SpecifiedWindowFrame(frameType, lBound, uBound) =>
+        val frameProto = frameType match {
+          case RowFrame => OperatorOuterClass.WindowFrameType.Rows
+          case RangeFrame => OperatorOuterClass.WindowFrameType.Range
+        }
+
+        val lBoundProto = lBound match {
+          case UnboundedPreceding =>
+            OperatorOuterClass.LowerWindowFrameBound
+              .newBuilder()
+              .setUnboundedPreceding(OperatorOuterClass.UnboundedPreceding.newBuilder().build())
+              .build()
+          case CurrentRow =>
+            OperatorOuterClass.LowerWindowFrameBound
+              .newBuilder()
+              .setCurrentRow(OperatorOuterClass.CurrentRow.newBuilder().build())
+              .build()
+          case e =>
+            OperatorOuterClass.LowerWindowFrameBound
+              .newBuilder()
+              .setPreceding(
+                OperatorOuterClass.Preceding
+                  .newBuilder()
+                  .setOffset(e.eval().asInstanceOf[Int])
+                  .build())
+              .build()
+        }
+
+        val uBoundProto = uBound match {
+          case UnboundedFollowing =>
+            OperatorOuterClass.UpperWindowFrameBound
+              .newBuilder()
+              .setUnboundedFollowing(OperatorOuterClass.UnboundedFollowing.newBuilder().build())
+              .build()
+          case CurrentRow =>
+            OperatorOuterClass.UpperWindowFrameBound
+              .newBuilder()
+              .setCurrentRow(OperatorOuterClass.CurrentRow.newBuilder().build())
+              .build()
+          case e =>
+            OperatorOuterClass.UpperWindowFrameBound
+              .newBuilder()
+              .setFollowing(
+                OperatorOuterClass.Following
+                  .newBuilder()
+                  .setOffset(e.eval().asInstanceOf[Int])
+                  .build())
+              .build()
+        }
+
+        (frameProto, lBoundProto, uBoundProto)
+      case _ =>
+        (
+          OperatorOuterClass.WindowFrameType.Rows,
+          OperatorOuterClass.LowerWindowFrameBound
+            .newBuilder()
+            .setUnboundedPreceding(OperatorOuterClass.UnboundedPreceding.newBuilder().build())
+            .build(),
+          OperatorOuterClass.UpperWindowFrameBound
+            .newBuilder()
+            .setUnboundedFollowing(OperatorOuterClass.UnboundedFollowing.newBuilder().build())
+            .build())
+    }
+
+    val frame = OperatorOuterClass.WindowFrame
+      .newBuilder()
+      .setFrameType(frameType)
+      .setLowerBound(lowerBound)
+      .setUpperBound(upperBound)
+      .build()
+
+    val spec =
+      OperatorOuterClass.WindowSpecDefinition.newBuilder().setFrameSpecification(frame).build()
+
+    Some(OperatorOuterClass.WindowExpr.newBuilder().setFunc(func).setSpec(spec).build())
+  }
+
   def aggExprToProto(
       aggExpr: AggregateExpression,
       inputs: Seq[Attribute],
@@ -1505,6 +1589,18 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
        // the fix: https://github.com/apache/arrow-datafusion/pull/9459
         castToProto(None, a.dataType, childExpr)
 
+      case r @ Rank(_) =>
+        val exprChildren = r.children.map(exprToProtoInternal(_, inputs))
+        scalarExprToProto("rank", exprChildren: _*)
+
+      case r @ RowNumber() =>
+        val exprChildren = r.children.map(exprToProtoInternal(_, inputs))
+        scalarExprToProto("row_number", exprChildren: _*)
+
+      case l @ Lag(_, _, _, _) =>
+        val exprChildren = l.children.map(exprToProtoInternal(_, inputs))
+        scalarExprToProto("lag", exprChildren: _*)
+
       // With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for
       // char types. Use rpad to achieve the behavior.
       // See https://github.com/apache/spark/pull/38151
@@ -1908,6 +2004,7 @@
       case _: ShuffleExchangeExec => true
       case _: TakeOrderedAndProjectExec => true
       case _: BroadcastExchangeExec => true
+      case _: WindowExec => true
       case _ => false
     }
   }
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala
new file mode 100644
index 0000000000..ea455f1903
--- /dev/null
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala
@@ -0,0 +1,113 @@
+package org.apache.spark.sql.comet
+
+import org.apache.comet.serde.OperatorOuterClass
+import org.apache.comet.serde.OperatorOuterClass.Operator
+import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType, windowExprToProto}
+
+import scala.collection.JavaConverters.asJavaIterableConverter
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, NamedExpression, SortOrder, WindowExpression}
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.comet.CometWindowExec.getNativePlan
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.metric.{SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+/**
+ * Comet physical plan node for Spark `WindowExec`.
+ *
+ * It is used to execute a `WindowExec` physical operator by using the Comet native engine. Unlike
+ * most other Comet operators it is not serialized into a larger native plan; it builds its own
+ * native Window plan from the window, partition and order specifications of the Spark operator.
+ */
+case class CometWindowExec(
+    override val originalPlan: SparkPlan,
+    windowExpression: Seq[NamedExpression],
+    partitionSpec: Seq[Expression],
+    orderSpec: Seq[SortOrder],
+    child: SparkPlan)
+    extends CometExec
+    with UnaryExecNode {
+
+  override def nodeName: String = "CometWindowExec"
+
+  override def output: Seq[Attribute] = child.output ++ windowExpression.map(_.toAttribute)
+
+  private lazy val writeMetrics =
+    SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
+  private lazy val readMetrics =
+    SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
+  override lazy val metrics = Map(
+    "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
+    "shuffleReadElapsedCompute" ->
+      SQLMetrics.createNanoTimingMetric(sparkContext, "shuffle read elapsed compute at native"),
+    "numPartitions" -> SQLMetrics.createMetric(
+      sparkContext,
+      "number of partitions")) ++ readMetrics ++ writeMetrics
+
+  override def supportsColumnar: Boolean = true
+
+  protected override def doExecuteColumnar(): RDD[ColumnarBatch] = {
+    val childRDD = child.executeColumnar()
+
+    childRDD.mapPartitionsInternal { iter =>
+      CometExec.getCometIterator(
+        Seq(iter),
+        getNativePlan(output, windowExpression, partitionSpec, orderSpec, child).get)
+    }
+  }
+
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    this.copy(child = newChild)
+
+}
+
+object CometWindowExec {
+  def getNativePlan(
+      outputAttributes: Seq[Attribute],
+      windowExpression: Seq[NamedExpression],
+      partitionSpec: Seq[Expression],
+      orderSpec: Seq[SortOrder],
+      child: SparkPlan): Option[Operator] = {
+
+    val orderSpecs = orderSpec.map(exprToProto(_, child.output))
+    val partitionSpecs = partitionSpec.map(exprToProto(_, child.output))
+    val scanBuilder = OperatorOuterClass.Scan.newBuilder()
+    val scanOpBuilder = OperatorOuterClass.Operator.newBuilder()
+
+    val scanTypes = outputAttributes.flatten { attr =>
+      serializeDataType(attr.dataType)
+    }
+
+    val windowExprs = windowExpression.map(w =>
+      windowExprToProto(w.asInstanceOf[Alias].child.asInstanceOf[WindowExpression], child.output))
+
+    val windowBuilder = OperatorOuterClass.Window
+      .newBuilder()
+
+    if (windowExprs.forall(_.isDefined)) {
+      windowBuilder
+        .addAllWindowExpr(windowExprs.map(_.get).asJava)
+
+      if (orderSpecs.forall(_.isDefined)) {
+        windowBuilder.addAllOrderByList(orderSpecs.map(_.get).asJava)
+      }
+
+      if (partitionSpecs.forall(_.isDefined)) {
+        windowBuilder.addAllPartitionByList(partitionSpecs.map(_.get).asJava)
+      }
+
+      scanBuilder.addAllFields(scanTypes.asJava)
+
+      val opBuilder = OperatorOuterClass.Operator
+        .newBuilder()
+        .addChildren(scanOpBuilder.setScan(scanBuilder))
+
+      Some(opBuilder.setWindow(windowBuilder).build())
+    } else None
+  }
+}
\ No newline at end of file
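
Usage sketch (illustrative, not part of the patch): a spark-shell style run of the query shape this change is meant to execute natively. Only the CometSparkSessionExtensions class and the rank()/row_number() window functions come from the patch itself; the config keys spark.comet.enabled and spark.comet.exec.enabled, the local master, and the sample data are assumptions for illustration.

    // Minimal sketch, assuming Comet is on the driver/executor classpath.
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .master("local[2]")
      .config("spark.sql.extensions", "org.apache.comet.CometSparkSessionExtensions")
      .config("spark.comet.enabled", "true")       // assumed config key
      .config("spark.comet.exec.enabled", "true")  // assumed config key
      .getOrCreate()
    import spark.implicits._

    // Small partitioned, ordered data set to exercise a ROWS window frame.
    Seq((1, "a", 10), (1, "b", 20), (2, "c", 5))
      .toDF("grp", "name", "value")
      .createOrReplaceTempView("t")

    // rank() and row_number() are among the window functions serialized by this change.
    val df = spark.sql(
      """SELECT grp, name,
        |       rank()       OVER (PARTITION BY grp ORDER BY value) AS r,
        |       row_number() OVER (PARTITION BY grp ORDER BY value) AS rn
        |FROM t""".stripMargin)
    df.explain()  // look for CometWindowExec / CometSinkPlaceHolder in the physical plan
    df.show()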