feat: Initial support for Window function #599

Merged · 6 commits · merged Jul 2, 2024 · changes shown from 5 commits
196 changes: 195 additions & 1 deletion core/src/execution/datafusion/planner.rs
@@ -20,6 +20,8 @@
use std::{collections::HashMap, sync::Arc};

use arrow_schema::{DataType, Field, Schema, TimeUnit};
use datafusion::physical_plan::windows::BoundedWindowAggExec;
use datafusion::physical_plan::InputOrderMode;
use datafusion::{
arrow::{compute::SortOptions, datatypes::SchemaRef},
common::DataFusionError,
@@ -50,12 +52,17 @@
use datafusion_common::{
tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter},
JoinType as DFJoinType, ScalarValue,
};
use datafusion_expr::expr::find_df_window_func;
use datafusion_expr::{ScalarUDF, WindowFrame, WindowFrameBound, WindowFrameUnits};
use datafusion_physical_expr::window::WindowExpr;
use datafusion_physical_expr_common::aggregate::create_aggregate_expr;
use itertools::Itertools;
use jni::objects::GlobalRef;
use num::{BigInt, ToPrimitive};

use crate::execution::spark_operator::lower_window_frame_bound::LowerFrameBoundStruct;
use crate::execution::spark_operator::upper_window_frame_bound::UpperFrameBoundStruct;
use crate::execution::spark_operator::WindowFrameType;
use crate::{
errors::ExpressionError,
execution::{
@@ -980,6 +987,47 @@ impl PhysicalPlanner {

Ok((scans, hash_join))
}
OpStruct::Window(wnd) => {
let (scans, child) = self.create_plan(&children[0], inputs)?;
let input_schema = child.schema();
let sort_exprs: Result<Vec<PhysicalSortExpr>, ExecutionError> = wnd
.order_by_list
.iter()
.map(|expr| self.create_sort_expr(expr, input_schema.clone()))
.collect();

let partition_exprs: Result<Vec<Arc<dyn PhysicalExpr>>, ExecutionError> = wnd
.partition_by_list
.iter()
.map(|expr| self.create_expr(expr, input_schema.clone()))
.collect();

let sort_exprs = &sort_exprs?;
let partition_exprs = &partition_exprs?;

let window_expr: Result<Vec<Arc<dyn WindowExpr>>, ExecutionError> = wnd
.window_expr
.iter()
.map(|expr| {
self.create_window_expr(
expr,
input_schema.clone(),
partition_exprs,
sort_exprs,
)
})
.collect();

Ok((
scans,
Arc::new(BoundedWindowAggExec::try_new(
window_expr?,
child,
partition_exprs.to_vec(),
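// Spark plans a sort on the partition and order keys below WindowExec,
// so the translated child is assumed to already be sorted as
// `InputOrderMode::Sorted` requires.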
InputOrderMode::Sorted,
)?),
))
}
}
}

@@ -1322,6 +1370,152 @@ impl PhysicalPlanner {
}
}

/// Create a DataFusion window physical expression from a Spark physical expression
fn create_window_expr<'a>(
&'a self,
spark_expr: &'a crate::execution::spark_operator::WindowExpr,
input_schema: SchemaRef,
partition_by: &[Arc<dyn PhysicalExpr>],
sort_exprs: &[PhysicalSortExpr],
) -> Result<Arc<dyn WindowExpr>, ExecutionError> {
let (mut window_func_name, mut window_func_args) = (String::new(), Vec::new());
if let Some(func) = &spark_expr.built_in_window_function {
match &func.expr_struct {
Some(ExprStruct::ScalarFunc(f)) => {
window_func_name.clone_from(&f.func);
window_func_args.clone_from(&f.args);
}
other => {
return Err(ExecutionError::GeneralError(format!(
"{other:?} not supported for window function"
)))
}
};
} else if let Some(agg_func) = &spark_expr.agg_func {
let result = Self::process_agg_func(agg_func)?;
window_func_name = result.0;
window_func_args = result.1;
} else {
return Err(ExecutionError::GeneralError(
"Both func and agg_func are not set".to_string(),
));
}

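// Resolve the function name against DataFusion's built-in window and
// aggregate functions.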
let window_func = match find_df_window_func(&window_func_name) {
Some(f) => f,
_ => {
return Err(ExecutionError::GeneralError(format!(
"{window_func_name} not supported for window function"
)))
}
};

let window_args = window_func_args
.iter()
.map(|expr| self.create_expr(expr, input_schema.clone()))
.collect::<Result<Vec<_>, ExecutionError>>()?;

let spark_window_frame = match spark_expr
.spec
.as_ref()
.and_then(|inner| inner.frame_specification.as_ref())
{
Some(frame) => frame,
_ => {
return Err(ExecutionError::DeserializeError(
"Cannot deserialize window frame".to_string(),
))
}
};

let units = match spark_window_frame.frame_type() {
WindowFrameType::Rows => WindowFrameUnits::Rows,
WindowFrameType::Range => WindowFrameUnits::Range,
};

let lower_bound: WindowFrameBound = match spark_window_frame
.lower_bound
.as_ref()
.and_then(|inner| inner.lower_frame_bound_struct.as_ref())
{
Some(l) => match l {
LowerFrameBoundStruct::UnboundedPreceding(_) => {
WindowFrameBound::Preceding(ScalarValue::UInt64(None))
}
LowerFrameBoundStruct::Preceding(offset) => {
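// Spark encodes `<n> PRECEDING` as a negative offset, so take the
// absolute value for DataFusion's unsigned representation.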
let offset_value = offset.offset.unsigned_abs() as u64;
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(offset_value)))
}
LowerFrameBoundStruct::CurrentRow(_) => WindowFrameBound::CurrentRow,
},
None => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
};

let upper_bound: WindowFrameBound = match spark_window_frame
.upper_bound
.as_ref()
.and_then(|inner| inner.upper_frame_bound_struct.as_ref())
{
Some(u) => match u {
UpperFrameBoundStruct::UnboundedFollowing(_) => {
WindowFrameBound::Following(ScalarValue::UInt64(None))
}
UpperFrameBoundStruct::Following(offset) => {
WindowFrameBound::Following(ScalarValue::UInt64(Some(offset.offset as u64)))
}
UpperFrameBoundStruct::CurrentRow(_) => WindowFrameBound::CurrentRow,
},
None => WindowFrameBound::Following(ScalarValue::UInt64(None)),
};

let window_frame = WindowFrame::new_bounds(units, lower_bound, upper_bound);
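// For example, `ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW` becomes
// WindowFrame::new_bounds(WindowFrameUnits::Rows,
//     WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
//     WindowFrameBound::CurrentRow).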

datafusion::physical_plan::windows::create_window_expr(
&window_func,
window_func_name,
&window_args,
partition_by,
sort_exprs,
window_frame.into(),
&input_schema,
false, // TODO: Ignore nulls
Review comment (Contributor): I'll create a follow up PR for this.
)
.map_err(|e| ExecutionError::DataFusionError(e.to_string()))
}

fn process_agg_func(agg_func: &AggExpr) -> Result<(String, Vec<Expr>), ExecutionError> {
fn optional_expr_to_vec(expr_option: &Option<Expr>) -> Vec<Expr> {
expr_option
.as_ref()
.cloned()
.map_or_else(Vec::new, |e| vec![e])
}

fn int_to_stats_type(value: i32) -> Option<StatsType> {
match value {
0 => Some(StatsType::Sample),
1 => Some(StatsType::Population),
_ => None,
}
}

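// Map the Spark aggregate onto a (DataFusion function name, argument list)
// pair; only COUNT, MIN and MAX are supported initially.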
match &agg_func.expr_struct {
Some(AggExprStruct::Count(expr)) => {
let args = &expr.children;
Ok(("count".to_string(), args.to_vec()))
}
Some(AggExprStruct::Min(expr)) => {
Ok(("min".to_string(), optional_expr_to_vec(&expr.child)))
}
Some(AggExprStruct::Max(expr)) => {
Ok(("max".to_string(), optional_expr_to_vec(&expr.child)))
}
other => Err(ExecutionError::GeneralError(format!(
"{other:?} not supported for window function"
))),
}
}

/// Create a DataFusion physical partitioning from Spark physical partitioning
fn create_partitioning(
&self,
59 changes: 59 additions & 0 deletions core/src/execution/proto/operator.proto
@@ -42,6 +42,7 @@ message Operator {
Expand expand = 107;
SortMergeJoin sort_merge_join = 108;
HashJoin hash_join = 109;
Window window = 110;
}
}

@@ -120,3 +121,61 @@ enum BuildSide {
BuildLeft = 0;
BuildRight = 1;
}

message WindowExpr {
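// Exactly one of built_in_window_function and agg_func is expected to be
// set; the planner checks built_in_window_function first.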
spark.spark_expression.Expr built_in_window_function = 1;
spark.spark_expression.AggExpr agg_func = 2;
WindowSpecDefinition spec = 3;
}

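// Mirrors Spark's RowFrame / RangeFrame frame types.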
enum WindowFrameType {
Rows = 0;
Range = 1;
}

message WindowFrame {
WindowFrameType frame_type = 1;
LowerWindowFrameBound lower_bound = 2;
UpperWindowFrameBound upper_bound = 3;
}

message LowerWindowFrameBound {
oneof lower_frame_bound_struct {
UnboundedPreceding unboundedPreceding = 1;
Preceding preceding = 2;
CurrentRow currentRow = 3;
}
}

message UpperWindowFrameBound {
oneof upper_frame_bound_struct {
UnboundedFollowing unboundedFollowing = 1;
Following following = 2;
CurrentRow currentRow = 3;
}
}

message Preceding {
int32 offset = 1;
}

message Following {
int32 offset = 1;
}

message UnboundedPreceding {}
message UnboundedFollowing {}
message CurrentRow {}

message WindowSpecDefinition {
repeated spark.spark_expression.Expr partitionSpec = 1;
repeated spark.spark_expression.Expr orderSpec = 2;
WindowFrame frameSpecification = 3;
}

message Window {
repeated WindowExpr window_expr = 1;
repeated spark.spark_expression.Expr order_by_list = 2;
repeated spark.spark_expression.Expr partition_by_list = 3;
Operator child = 4;
}
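
To make the wire format concrete, here is a minimal sketch of building the frame for `ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW` with the prost-generated Rust types, assuming the generated module layout matches the `spark_operator` paths already used in planner.rs above:

use crate::execution::spark_operator::{
    lower_window_frame_bound::LowerFrameBoundStruct,
    upper_window_frame_bound::UpperFrameBoundStruct,
    CurrentRow, LowerWindowFrameBound, UnboundedPreceding, UpperWindowFrameBound,
    WindowFrame, WindowFrameType,
};

// `ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`
let frame = WindowFrame {
    frame_type: WindowFrameType::Rows as i32, // prost stores enum fields as i32
    lower_bound: Some(LowerWindowFrameBound {
        lower_frame_bound_struct: Some(LowerFrameBoundStruct::UnboundedPreceding(
            UnboundedPreceding {},
        )),
    }),
    upper_bound: Some(UpperWindowFrameBound {
        upper_frame_bound_struct: Some(UpperFrameBoundStruct::CurrentRow(CurrentRow {})),
    }),
};

On the native side, create_window_expr above deserializes this into WindowFrame::new_bounds(WindowFrameUnits::Rows, Preceding(UInt64(None)), CurrentRow).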
CometSparkSessionExtensions.scala
@@ -40,6 +40,7 @@ import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

@@ -541,6 +542,17 @@ class CometSparkSessionExtensions
withInfo(s, Seq(info1, info2).flatten.mkString(","))
s

case w: WindowExec =>
Review comment (Contributor): Should we add isCometOperatorEnabled() to enable/disable the feature?

Reply (Contributor Author): Yes. Some of the isCometOperatorEnabled checks are done in CometSparkSessionExtensions, while others are in QueryPlanSerde; I followed HashAggregateExec and added the check in QueryPlanSerde.

val newOp = transform1(w)
newOp match {
case Some(nativeOp) =>
val cometOp =
CometWindowExec(w, w.windowExpression, w.partitionSpec, w.orderSpec, w.child)
CometSinkPlaceHolder(nativeOp, w, cometOp)
case None =>
w
}

case u: UnionExec
if isCometOperatorEnabled(conf, "union") &&
u.children.forall(isCometNative) =>