diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index 9afb0ebb88..ddaee5ff6e 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -20,6 +20,8 @@ use std::{collections::HashMap, sync::Arc}; use arrow_schema::{DataType, Field, Schema, TimeUnit}; +use datafusion::physical_plan::windows::BoundedWindowAggExec; +use datafusion::physical_plan::InputOrderMode; use datafusion::{ arrow::{compute::SortOptions, datatypes::SchemaRef}, common::DataFusionError, @@ -46,20 +48,23 @@ use datafusion::{ }, prelude::SessionContext, }; -use datafusion::physical_plan::InputOrderMode; -use datafusion::physical_plan::windows::BoundedWindowAggExec; use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter}, JoinType as DFJoinType, ScalarValue, }; use datafusion_expr::expr::find_df_window_func; -use datafusion_expr::{ScalarUDF, WindowFrame, WindowFrameBound, WindowFrameUnits}; +use datafusion_expr::{ + ScalarUDF, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, +}; use datafusion_physical_expr::window::WindowExpr; use datafusion_physical_expr_common::aggregate::create_aggregate_expr; use itertools::Itertools; use jni::objects::GlobalRef; use num::{BigInt, ToPrimitive}; +use crate::execution::spark_operator::lower_window_frame_bound::LowerFrameBoundStruct; +use crate::execution::spark_operator::upper_window_frame_bound::UpperFrameBoundStruct; +use crate::execution::spark_operator::WindowFrameType; use crate::{ errors::ExpressionError, execution::{ @@ -100,9 +105,6 @@ use crate::{ spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning}, }, }; -use crate::execution::spark_operator::lower_window_frame_bound::LowerFrameBoundStruct; -use crate::execution::spark_operator::upper_window_frame_bound::UpperFrameBoundStruct; -use crate::execution::spark_operator::WindowFrameType; use super::expressions::{abs::CometAbsFunc, EvalMode}; @@ -1384,7 +1386,7 @@ impl PhysicalPlanner { Some(ExprStruct::ScalarFunc(f)) => { window_func_name = f.func.clone(); window_func_args = f.args.clone(); - }, + } other => { return Err(ExecutionError::GeneralError(format!( "{other:?} not supported for window function" @@ -1396,8 +1398,9 @@ impl PhysicalPlanner { window_func_name = result.0; window_func_args = result.1; } else { - // Handle the case where neither func nor agg_func is set. - return Err(ExecutionError::GeneralError("Both func and agg_func are not set".to_string())); + return Err(ExecutionError::GeneralError( + "Both func and agg_func are not set".to_string(), + )); } let window_func = match find_df_window_func(&window_func_name) { @@ -1484,15 +1487,16 @@ impl PhysicalPlanner { &input_schema, false, // TODO: Ignore nulls ) - .map_err(|e| ExecutionError::DataFusionError(e.to_string())) + .map_err(|e| ExecutionError::DataFusionError(e.to_string())) } - fn process_agg_func( - agg_func: &AggExpr - ) -> Result<(String, Vec), ExecutionError> { + fn process_agg_func(agg_func: &AggExpr) -> Result<(String, Vec), ExecutionError> { fn optional_expr_to_vec(expr_option: &Option) -> Vec { - expr_option.as_ref().cloned().map_or_else(Vec::new, |e| vec![e]) + expr_option + .as_ref() + .cloned() + .map_or_else(Vec::new, |e| vec![e]) } fn int_to_stats_type(value: i32) -> Option { @@ -1507,13 +1511,13 @@ impl PhysicalPlanner { Some(AggExprStruct::Count(expr)) => { let args = &expr.children; Ok(("count".to_string(), args.to_vec())) - }, + } Some(AggExprStruct::Min(expr)) => { Ok(("min".to_string(), optional_expr_to_vec(&expr.child))) - }, + } Some(AggExprStruct::Max(expr)) => { Ok(("max".to_string(), optional_expr_to_vec(&expr.child))) - }, + } other => { return Err(ExecutionError::GeneralError(format!( "{other:?} not supported for window function" diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 45109d2e4c..29a9a1ff34 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode, Average, BitAndAgg, BitOrAgg, BitXorAgg, Complete, Corr, Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Complete, Corr, Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, NormalizeNaNAndZero} import org.apache.spark.sql.catalyst.plans._