diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index afdebce328..a969c3baad 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -20,6 +20,8 @@ use std::{collections::HashMap, sync::Arc};
 
 use arrow_schema::{DataType, Field, Schema, TimeUnit};
+use datafusion::physical_plan::windows::BoundedWindowAggExec;
+use datafusion::physical_plan::InputOrderMode;
 use datafusion::{
     arrow::{compute::SortOptions, datatypes::SchemaRef},
     common::DataFusionError,
@@ -50,12 +52,17 @@ use datafusion_common::{
     tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter},
     JoinType as DFJoinType, ScalarValue,
 };
-use datafusion_expr::ScalarUDF;
+use datafusion_expr::expr::find_df_window_func;
+use datafusion_expr::{ScalarUDF, WindowFrame, WindowFrameBound, WindowFrameUnits};
+use datafusion_physical_expr::window::WindowExpr;
 use datafusion_physical_expr_common::aggregate::create_aggregate_expr;
 use itertools::Itertools;
 use jni::objects::GlobalRef;
 use num::{BigInt, ToPrimitive};
 
+use crate::execution::spark_operator::lower_window_frame_bound::LowerFrameBoundStruct;
+use crate::execution::spark_operator::upper_window_frame_bound::UpperFrameBoundStruct;
+use crate::execution::spark_operator::WindowFrameType;
 use crate::{
     errors::ExpressionError,
     execution::{
@@ -980,6 +987,47 @@ impl PhysicalPlanner {
 
                 Ok((scans, hash_join))
             }
+            OpStruct::Window(wnd) => {
+                let (scans, child) = self.create_plan(&children[0], inputs)?;
+                let input_schema = child.schema();
+                let sort_exprs: Result<Vec<PhysicalSortExpr>, ExecutionError> = wnd
+                    .order_by_list
+                    .iter()
+                    .map(|expr| self.create_sort_expr(expr, input_schema.clone()))
+                    .collect();
+
+                let partition_exprs: Result<Vec<Arc<dyn PhysicalExpr>>, ExecutionError> = wnd
+                    .partition_by_list
+                    .iter()
+                    .map(|expr| self.create_expr(expr, input_schema.clone()))
+                    .collect();
+
+                let sort_exprs = &sort_exprs?;
+                let partition_exprs = &partition_exprs?;
+
+                let window_expr: Result<Vec<Arc<dyn WindowExpr>>, ExecutionError> = wnd
+                    .window_expr
+                    .iter()
+                    .map(|expr| {
+                        self.create_window_expr(
+                            expr,
+                            input_schema.clone(),
+                            partition_exprs,
+                            sort_exprs,
+                        )
+                    })
+                    .collect();
+
+                Ok((
+                    scans,
+                    Arc::new(BoundedWindowAggExec::try_new(
+                        window_expr?,
+                        child,
+                        partition_exprs.to_vec(),
+                        InputOrderMode::Sorted,
+                    )?),
+                ))
+            }
         }
     }
 
@@ -1322,6 +1370,152 @@ impl PhysicalPlanner {
         }
     }
 
+    /// Create a DataFusion window physical expression from a Spark physical expression
+    fn create_window_expr<'a>(
+        &'a self,
+        spark_expr: &'a crate::execution::spark_operator::WindowExpr,
+        input_schema: SchemaRef,
+        partition_by: &[Arc<dyn PhysicalExpr>],
+        sort_exprs: &[PhysicalSortExpr],
+    ) -> Result<Arc<dyn WindowExpr>, ExecutionError> {
+        let (mut window_func_name, mut window_func_args) = (String::new(), Vec::new());
+        if let Some(func) = &spark_expr.built_in_window_function {
+            match &func.expr_struct {
+                Some(ExprStruct::ScalarFunc(f)) => {
+                    window_func_name.clone_from(&f.func);
+                    window_func_args.clone_from(&f.args);
+                }
+                other => {
+                    return Err(ExecutionError::GeneralError(format!(
+                        "{other:?} not supported for window function"
+                    )))
+                }
+            };
+        } else if let Some(agg_func) = &spark_expr.agg_func {
+            let result = Self::process_agg_func(agg_func)?;
+            window_func_name = result.0;
+            window_func_args = result.1;
+        } else {
+            return Err(ExecutionError::GeneralError(
+                "Both func and agg_func are not set".to_string(),
+            ));
+        }
+
+        let window_func = match find_df_window_func(&window_func_name) {
+            Some(f) => f,
+            _ => {
+                return Err(ExecutionError::GeneralError(format!(
+ "{window_func_name} not supported for window function" + ))) + } + }; + + let window_args = window_func_args + .iter() + .map(|expr| self.create_expr(expr, input_schema.clone())) + .collect::, ExecutionError>>()?; + + let spark_window_frame = match spark_expr + .spec + .as_ref() + .and_then(|inner| inner.frame_specification.as_ref()) + { + Some(frame) => frame, + _ => { + return Err(ExecutionError::DeserializeError( + "Cannot deserialize window frame".to_string(), + )) + } + }; + + let units = match spark_window_frame.frame_type() { + WindowFrameType::Rows => WindowFrameUnits::Rows, + WindowFrameType::Range => WindowFrameUnits::Range, + }; + + let lower_bound: WindowFrameBound = match spark_window_frame + .lower_bound + .as_ref() + .and_then(|inner| inner.lower_frame_bound_struct.as_ref()) + { + Some(l) => match l { + LowerFrameBoundStruct::UnboundedPreceding(_) => { + WindowFrameBound::Preceding(ScalarValue::UInt64(None)) + } + LowerFrameBoundStruct::Preceding(offset) => { + let offset_value = offset.offset.unsigned_abs() as u64; + WindowFrameBound::Preceding(ScalarValue::UInt64(Some(offset_value))) + } + LowerFrameBoundStruct::CurrentRow(_) => WindowFrameBound::CurrentRow, + }, + None => WindowFrameBound::Preceding(ScalarValue::UInt64(None)), + }; + + let upper_bound: WindowFrameBound = match spark_window_frame + .upper_bound + .as_ref() + .and_then(|inner| inner.upper_frame_bound_struct.as_ref()) + { + Some(u) => match u { + UpperFrameBoundStruct::UnboundedFollowing(_) => { + WindowFrameBound::Following(ScalarValue::UInt64(None)) + } + UpperFrameBoundStruct::Following(offset) => { + WindowFrameBound::Following(ScalarValue::UInt64(Some(offset.offset as u64))) + } + UpperFrameBoundStruct::CurrentRow(_) => WindowFrameBound::CurrentRow, + }, + None => WindowFrameBound::Following(ScalarValue::UInt64(None)), + }; + + let window_frame = WindowFrame::new_bounds(units, lower_bound, upper_bound); + + datafusion::physical_plan::windows::create_window_expr( + &window_func, + window_func_name, + &window_args, + partition_by, + sort_exprs, + window_frame.into(), + &input_schema, + false, // TODO: Ignore nulls + ) + .map_err(|e| ExecutionError::DataFusionError(e.to_string())) + } + + fn process_agg_func(agg_func: &AggExpr) -> Result<(String, Vec), ExecutionError> { + fn optional_expr_to_vec(expr_option: &Option) -> Vec { + expr_option + .as_ref() + .cloned() + .map_or_else(Vec::new, |e| vec![e]) + } + + fn int_to_stats_type(value: i32) -> Option { + match value { + 0 => Some(StatsType::Sample), + 1 => Some(StatsType::Population), + _ => None, + } + } + + match &agg_func.expr_struct { + Some(AggExprStruct::Count(expr)) => { + let args = &expr.children; + Ok(("count".to_string(), args.to_vec())) + } + Some(AggExprStruct::Min(expr)) => { + Ok(("min".to_string(), optional_expr_to_vec(&expr.child))) + } + Some(AggExprStruct::Max(expr)) => { + Ok(("max".to_string(), optional_expr_to_vec(&expr.child))) + } + other => Err(ExecutionError::GeneralError(format!( + "{other:?} not supported for window function" + ))), + } + } + /// Create a DataFusion physical partitioning from Spark physical partitioning fn create_partitioning( &self, diff --git a/core/src/execution/proto/operator.proto b/core/src/execution/proto/operator.proto index de25f94dae..335d425966 100644 --- a/core/src/execution/proto/operator.proto +++ b/core/src/execution/proto/operator.proto @@ -42,6 +42,7 @@ message Operator { Expand expand = 107; SortMergeJoin sort_merge_join = 108; HashJoin hash_join = 109; + Window window = 110; } } @@ -120,3 
   BuildLeft = 0;
   BuildRight = 1;
 }
+
+message WindowExpr {
+  spark.spark_expression.Expr built_in_window_function = 1;
+  spark.spark_expression.AggExpr agg_func = 2;
+  WindowSpecDefinition spec = 3;
+}
+
+enum WindowFrameType {
+  Rows = 0;
+  Range = 1;
+}
+
+message WindowFrame {
+  WindowFrameType frame_type = 1;
+  LowerWindowFrameBound lower_bound = 2;
+  UpperWindowFrameBound upper_bound = 3;
+}
+
+message LowerWindowFrameBound {
+  oneof lower_frame_bound_struct {
+    UnboundedPreceding unboundedPreceding = 1;
+    Preceding preceding = 2;
+    CurrentRow currentRow = 3;
+  }
+}
+
+message UpperWindowFrameBound {
+  oneof upper_frame_bound_struct {
+    UnboundedFollowing unboundedFollowing = 1;
+    Following following = 2;
+    CurrentRow currentRow = 3;
+  }
+}
+
+message Preceding {
+  int32 offset = 1;
+}
+
+message Following {
+  int32 offset = 1;
+}
+
+message UnboundedPreceding {}
+message UnboundedFollowing {}
+message CurrentRow {}
+
+message WindowSpecDefinition {
+  repeated spark.spark_expression.Expr partitionSpec = 1;
+  repeated spark.spark_expression.Expr orderSpec = 2;
+  WindowFrame frameSpecification = 3;
+}
+
+message Window {
+  repeated WindowExpr window_expr = 1;
+  repeated spark.spark_expression.Expr order_by_list = 2;
+  repeated spark.spark_expression.Expr partition_by_list = 3;
+  Operator child = 4;
+}
diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 37ca55e27e..ffb7de83b5 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -40,6 +40,7 @@ import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.execution.window.WindowExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -541,6 +542,17 @@ class CometSparkSessionExtensions
            withInfo(s, Seq(info1, info2).flatten.mkString(","))
            s
 
+        case w: WindowExec =>
+          val newOp = transform1(w)
+          newOp match {
+            case Some(nativeOp) =>
+              val cometOp =
+                CometWindowExec(w, w.windowExpression, w.partitionSpec, w.orderSpec, w.child)
+              CometSinkPlaceHolder(nativeOp, w, cometOp)
+            case None =>
+              w
+          }
+
         case u: UnionExec
            if isCometOperatorEnabled(conf, "union") &&
              u.children.forall(isCometNative) =>
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 6838e0237b..3465315d18 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Corr, Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Complete, Corr, Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, NormalizeNaNAndZero}
 import org.apache.spark.sql.catalyst.plans._
@@ -37,6 +37,7 @@ import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, Shuffle
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.execution.window.WindowExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -199,6 +200,129 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
     }
   }
 
+  def windowExprToProto(
+      windowExpr: WindowExpression,
+      output: Seq[Attribute]): Option[OperatorOuterClass.WindowExpr] = {
+
+    val aggregateExpressions: Array[AggregateExpression] = windowExpr.flatMap { expr =>
+      expr match {
+        case agg: AggregateExpression =>
+          agg.aggregateFunction match {
+            case _: Min | _: Max | _: Count =>
+              Some(agg)
+            case _ =>
+              withInfo(windowExpr, "Unsupported aggregate", expr)
+              None
+          }
+        case _ =>
+          None
+      }
+    }.toArray
+
+    val (aggExpr, builtinFunc) = if (aggregateExpressions.nonEmpty) {
+      val modes = aggregateExpressions.map(_.mode).distinct
+      assert(modes.size == 1 && modes.head == Complete)
+      (aggExprToProto(aggregateExpressions.head, output, true), None)
+    } else {
+      (None, exprToProto(windowExpr.windowFunction, output))
+    }
+
+    val f = windowExpr.windowSpec.frameSpecification
+
+    val (frameType, lowerBound, upperBound) = f match {
+      case SpecifiedWindowFrame(frameType, lBound, uBound) =>
+        val frameProto = frameType match {
+          case RowFrame => OperatorOuterClass.WindowFrameType.Rows
+          case RangeFrame => OperatorOuterClass.WindowFrameType.Range
+        }
+
+        val lBoundProto = lBound match {
+          case UnboundedPreceding =>
+            OperatorOuterClass.LowerWindowFrameBound
+              .newBuilder()
+              .setUnboundedPreceding(OperatorOuterClass.UnboundedPreceding.newBuilder().build())
+              .build()
+          case CurrentRow =>
+            OperatorOuterClass.LowerWindowFrameBound
+              .newBuilder()
+              .setCurrentRow(OperatorOuterClass.CurrentRow.newBuilder().build())
+              .build()
+          case e =>
+            OperatorOuterClass.LowerWindowFrameBound
+              .newBuilder()
+              .setPreceding(
+                OperatorOuterClass.Preceding
+                  .newBuilder()
+                  .setOffset(e.eval().asInstanceOf[Int])
+                  .build())
+              .build()
+        }
+
+        val uBoundProto = uBound match {
+          case UnboundedFollowing =>
+            OperatorOuterClass.UpperWindowFrameBound
+              .newBuilder()
+              .setUnboundedFollowing(OperatorOuterClass.UnboundedFollowing.newBuilder().build())
+              .build()
+          case CurrentRow =>
+            OperatorOuterClass.UpperWindowFrameBound
+              .newBuilder()
+              .setCurrentRow(OperatorOuterClass.CurrentRow.newBuilder().build())
+              .build()
+          case e =>
+            OperatorOuterClass.UpperWindowFrameBound
+              .newBuilder()
+              .setFollowing(
+                OperatorOuterClass.Following
+                  .newBuilder()
+                  .setOffset(e.eval().asInstanceOf[Int])
+                  .build())
+              .build()
+        }
+
+        (frameProto, lBoundProto, uBoundProto)
+      case _ =>
+        (
+          OperatorOuterClass.WindowFrameType.Rows,
+          OperatorOuterClass.LowerWindowFrameBound
+            .newBuilder()
+            .setUnboundedPreceding(OperatorOuterClass.UnboundedPreceding.newBuilder().build())
+            .build(),
+          OperatorOuterClass.UpperWindowFrameBound
+            .newBuilder()
+            .setUnboundedFollowing(OperatorOuterClass.UnboundedFollowing.newBuilder().build())
+            .build())
+    }
+
+    val frame = OperatorOuterClass.WindowFrame
+      .newBuilder()
+      .setFrameType(frameType)
+      .setLowerBound(lowerBound)
+      .setUpperBound(upperBound)
+      .build()
+
+    val spec =
+      OperatorOuterClass.WindowSpecDefinition.newBuilder().setFrameSpecification(frame).build()
+
+    if (builtinFunc.isDefined) {
+      Some(
+        OperatorOuterClass.WindowExpr
+          .newBuilder()
+          .setBuiltInWindowFunction(builtinFunc.get)
+          .setSpec(spec)
+          .build())
+    } else if (aggExpr.isDefined) {
+      Some(
+        OperatorOuterClass.WindowExpr
+          .newBuilder()
+          .setAggFunc(aggExpr.get)
+          .setSpec(spec)
+          .build())
+    } else {
+      None
+    }
+  }
+
   def aggExprToProto(
       aggExpr: AggregateExpression,
       inputs: Seq[Attribute],
@@ -2352,6 +2476,41 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           None
       }
 
+      case WindowExec(windowExpression, partitionSpec, orderSpec, child)
+          if isCometOperatorEnabled(op.conf, "window") =>
+        val output = child.output
+
+        val winExprs: Array[WindowExpression] = windowExpression.flatMap { expr =>
+          expr match {
+            case alias: Alias =>
+              alias.child match {
+                case winExpr: WindowExpression =>
+                  Some(winExpr)
+                case _ =>
+                  None
+              }
+            case _ =>
+              None
+          }
+        }.toArray
+
+        val windowExprProto = winExprs.map(windowExprToProto(_, output))
+
+        val partitionExprs = partitionSpec.map(exprToProto(_, child.output))
+
+        val sortOrders = orderSpec.map(exprToProto(_, child.output))
+
+        if (windowExprProto.forall(_.isDefined) && partitionExprs.forall(_.isDefined)
+          && sortOrders.forall(_.isDefined)) {
+          val windowBuilder = OperatorOuterClass.Window.newBuilder()
+          windowBuilder.addAllWindowExpr(windowExprProto.map(_.get).toIterable.asJava)
+          windowBuilder.addAllPartitionByList(partitionExprs.map(_.get).asJava)
+          windowBuilder.addAllOrderByList(sortOrders.map(_.get).asJava)
+          Some(result.setWindow(windowBuilder).build())
+        } else {
+          None
+        }
+
       case HashAggregateExec(
             _,
             _,
@@ -2652,6 +2811,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
       case _: TakeOrderedAndProjectExec => true
       case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true
       case _: BroadcastExchangeExec => true
+      case _: WindowExec => true
       case _ => false
     }
   }
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala
new file mode 100644
index 0000000000..9a1232f0cd
--- /dev/null
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet
+
+import scala.collection.JavaConverters.asJavaIterableConverter
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, NamedExpression, SortOrder, WindowExpression}
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.comet.CometWindowExec.getNativePlan
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+import org.apache.comet.serde.OperatorOuterClass
+import org.apache.comet.serde.OperatorOuterClass.Operator
+import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType, windowExprToProto}
+
+/**
+ * Comet physical plan node for Spark `WindowExec`.
+ *
+ * It executes a Spark `WindowExec` operator natively: the window expressions, partition spec and
+ * order spec are serialized into a Comet native `Window` operator that is evaluated over the
+ * child's columnar batches.
+ */
+case class CometWindowExec(
+    override val originalPlan: SparkPlan,
+    windowExpression: Seq[NamedExpression],
+    partitionSpec: Seq[Expression],
+    orderSpec: Seq[SortOrder],
+    child: SparkPlan)
+    extends CometExec
+    with UnaryExecNode {
+
+  override def nodeName: String = "CometWindowExec"
+
+  override def output: Seq[Attribute] = child.output ++ windowExpression.map(_.toAttribute)
+
+  private lazy val writeMetrics =
+    SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
+  private lazy val readMetrics =
+    SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
+  override lazy val metrics: Map[String, SQLMetric] = Map(
+    "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
+    "shuffleReadElapsedCompute" ->
+      SQLMetrics.createNanoTimingMetric(sparkContext, "shuffle read elapsed compute at native"),
+    "numPartitions" -> SQLMetrics.createMetric(
+      sparkContext,
+      "number of partitions")) ++ readMetrics ++ writeMetrics
+
+  override def supportsColumnar: Boolean = true
+
+  protected override def doExecuteColumnar(): RDD[ColumnarBatch] = {
+    val childRDD = child.executeColumnar()
+
+    childRDD.mapPartitionsInternal { iter =>
+      CometExec.getCometIterator(
+        Seq(iter),
+        getNativePlan(output, windowExpression, partitionSpec, orderSpec, child).get)
+    }
+  }
+
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    this.copy(child = newChild)
+
+}
+
+object CometWindowExec {
+  def getNativePlan(
+      outputAttributes: Seq[Attribute],
+      windowExpression: Seq[NamedExpression],
+      partitionSpec: Seq[Expression],
+      orderSpec: Seq[SortOrder],
+      child: SparkPlan): Option[Operator] = {
+
+    val orderSpecs = orderSpec.map(exprToProto(_, child.output))
+    val partitionSpecs = partitionSpec.map(exprToProto(_, child.output))
+    val scanBuilder = OperatorOuterClass.Scan.newBuilder()
+    val scanOpBuilder = OperatorOuterClass.Operator.newBuilder()
+
+    val scanTypes = outputAttributes.flatten { attr =>
+      serializeDataType(attr.dataType)
+    }
+
+    val windowExprs = windowExpression.map(w =>
+      windowExprToProto(
+        w.asInstanceOf[Alias].child.asInstanceOf[WindowExpression],
+        outputAttributes))
+
+    val windowBuilder = OperatorOuterClass.Window
+      .newBuilder()
+
+    if (windowExprs.forall(_.isDefined)) {
+      windowBuilder
+        .addAllWindowExpr(windowExprs.map(_.get).asJava)
+
+      if (orderSpecs.forall(_.isDefined)) {
+        windowBuilder.addAllOrderByList(orderSpecs.map(_.get).asJava)
+      }
+
+      if (partitionSpecs.forall(_.isDefined)) {
+        windowBuilder.addAllPartitionByList(partitionSpecs.map(_.get).asJava)
+      }
+
+      scanBuilder.addAllFields(scanTypes.asJava)
+
+      val opBuilder = OperatorOuterClass.Operator
+        .newBuilder()
+        .addChildren(scanOpBuilder.setScan(scanBuilder))
+
+      Some(opBuilder.setWindow(windowBuilder).build())
+    } else None
+  }
+}
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 0dc8510c33..6c821965dd 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -1432,6 +1432,33 @@ class CometExecSuite extends CometTestBase {
         }
       })
   }
+
+  test("Windows support") {
+    Seq("true", "false").foreach(aqeEnabled =>
+      withSQLConf(
+        CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
+        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqeEnabled) {
+        withParquetTable((0 until 10).map(i => (i, 10 - i)), "t1") { // TODO: test nulls
+          val aggregateFunctions =
+            List("COUNT(_1)", "MAX(_1)", "MIN(_1)") // TODO: Test all the aggregates
+
+          aggregateFunctions.foreach { function =>
+            val queries = Seq(
+              s"SELECT $function OVER() FROM t1",
+              s"SELECT $function OVER(order by _2) FROM t1",
+              s"SELECT $function OVER(order by _2 desc) FROM t1",
+              s"SELECT $function OVER(partition by _2 order by _2) FROM t1",
+              s"SELECT $function OVER(rows between 1 preceding and 1 following) FROM t1",
+              s"SELECT $function OVER(order by _2 rows between 1 preceding and current row) FROM t1",
+              s"SELECT $function OVER(order by _2 rows between current row and 1 following) FROM t1")
+
+            queries.foreach { query =>
+              checkSparkAnswerAndOperator(query)
+            }
+          }
+        }
+      })
+  }
 }
 
 case class BucketedTableTestSpec(
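Note on the frame mapping: create_window_expr in planner.rs funnels every Spark frame into DataFusion's WindowFrame via WindowFrame::new_bounds, carrying offsets as ScalarValue::UInt64 and using a NULL offset (None) for unbounded bounds. The following is a minimal, self-contained sketch of that representation, not part of the patch; it assumes only the datafusion_expr and datafusion_common crates already used above, and the main wrapper exists purely for illustration.

    use datafusion_common::ScalarValue;
    use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits};

    fn main() {
        // ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: the Preceding/Following protobuf
        // bounds are mapped to concrete UInt64 offsets.
        let bounded = WindowFrame::new_bounds(
            WindowFrameUnits::Rows,
            WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))),
            WindowFrameBound::Following(ScalarValue::UInt64(Some(1))),
        );

        // The None branches in create_window_expr fall back to an unbounded frame:
        // ScalarValue::UInt64(None) stands in for UNBOUNDED PRECEDING / FOLLOWING.
        let unbounded = WindowFrame::new_bounds(
            WindowFrameUnits::Rows,
            WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
            WindowFrameBound::Following(ScalarValue::UInt64(None)),
        );

        println!("{bounded:?}");
        println!("{unbounded:?}");
    }

These two shapes correspond to the frames exercised by the new CometExecSuite queries, for example SELECT COUNT(_1) OVER(rows between 1 preceding and 1 following) FROM t1 and SELECT COUNT(_1) OVER() FROM t1.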