feat: Add Comet window function support
comphead committed Mar 12, 2024
1 parent c2f2d9b commit 068c958
Showing 5 changed files with 437 additions and 12 deletions.
160 changes: 154 additions & 6 deletions core/src/execution/datafusion/planner.rs
@@ -23,7 +23,10 @@ use arrow_schema::{DataType, Field, Schema, TimeUnit};
use datafusion::{
arrow::{compute::SortOptions, datatypes::SchemaRef},
common::DataFusionError,
logical_expr::{BuiltinScalarFunction, Operator as DataFusionOperator},
logical_expr::{
expr::find_df_window_func, BuiltinScalarFunction, Operator as DataFusionOperator,
WindowFrame, WindowFrameBound, WindowFrameUnits,
},
physical_expr::{
expressions::{BinaryExpr, Column, IsNotNullExpr, Literal as DataFusionLiteral},
functions::create_physical_expr,
@@ -35,7 +38,8 @@ use datafusion::{
limit::LocalLimitExec,
projection::ProjectionExec,
sorts::sort::SortExec,
ExecutionPlan, Partitioning,
windows::BoundedWindowAggExec,
ExecutionPlan, InputOrderMode, Partitioning, WindowExpr,
},
};
use datafusion_common::ScalarValue;
@@ -74,12 +78,14 @@ use crate::{
},
operators::{CopyExec, ExecutionError, ScanExec},
serde::to_arrow_datatype,
spark_expression,
spark_expression::{
agg_expr::ExprStruct as AggExprStruct, expr::ExprStruct, literal::Value, AggExpr, Expr,
ScalarFunc,
self, agg_expr::ExprStruct as AggExprStruct, expr::ExprStruct, literal::Value, AggExpr,
Expr, ScalarFunc,
},
spark_operator::{
lower_window_frame_bound::LowerFrameBoundStruct, operator::OpStruct,
upper_window_frame_bound::UpperFrameBoundStruct, Operator, WindowFrameType,
},
spark_operator::{operator::OpStruct, Operator},
spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning},
},
};
@@ -797,6 +803,50 @@ impl PhysicalPlanner {
)?),
))
}
OpStruct::Window(wnd) => {
    let (scans, child) = self.create_plan(&children[0], inputs)?;
    let input_schema = child.schema();
let sort_exprs: Result<Vec<PhysicalSortExpr>, ExecutionError> = wnd
.order_by_list
.iter()
.map(|expr| self.create_sort_expr(expr, input_schema.clone()))
.collect();

let partition_exprs: Result<Vec<Arc<dyn PhysicalExpr>>, ExecutionError> = wnd
.partition_by_list
.iter()
.map(|expr| self.create_expr(expr, input_schema.clone()))
.collect();

let sort_exprs = &sort_exprs?;
let partition_exprs = &partition_exprs?;

let window_expr: Result<Vec<Arc<dyn WindowExpr>>, ExecutionError> = wnd
.window_expr
.iter()
.map(|expr| {
self.create_window_expr(
expr,
input_schema.clone(),
partition_exprs,
sort_exprs,
)
})
.collect();

Ok((
scans,
Arc::new(BoundedWindowAggExec::try_new(
window_expr?,
child,
partition_exprs.to_vec(),
InputOrderMode::Sorted,
)?),
))
}
OpStruct::Expand(expand) => {
assert!(children.len() == 1);
let (scans, child) = self.create_plan(&children[0], inputs)?;
@@ -934,6 +984,104 @@ impl PhysicalPlanner {
}
}

/// Create a DataFusion window physical expression from a Spark window expression
fn create_window_expr<'a>(
&'a self,
spark_expr: &'a crate::execution::spark_operator::WindowExpr,
input_schema: SchemaRef,
partition_by: &[Arc<dyn PhysicalExpr>],
sort_exprs: &[PhysicalSortExpr],
) -> Result<Arc<dyn WindowExpr>, ExecutionError> {
let (window_func_name, window_func_args) =
match &spark_expr.func.as_ref().unwrap().expr_struct.as_ref() {
Some(ExprStruct::ScalarFunc(f)) => (f.func.clone(), f.args.clone()),
other => {
return Err(ExecutionError::GeneralError(format!(
"{other:?} not supported for window function"
)))
}
};

let window_func = match find_df_window_func(&window_func_name) {
Some(f) => f,
_ => {
return Err(ExecutionError::GeneralError(format!(
"{window_func_name} not supported for window function"
)))
}
};

let window_args = window_func_args
.iter()
.map(|expr| self.create_expr(expr, input_schema.clone()))
.collect::<Result<Vec<_>, ExecutionError>>()?;

let spark_window_frame = match spark_expr
.spec
.as_ref()
.and_then(|inner| inner.frame_specification.as_ref())
{
Some(frame) => frame,
_ => {
return Err(ExecutionError::DeserializeError(
"Cannot deserialize window frame".to_string(),
))
}
};

let units = match spark_window_frame.frame_type() {
WindowFrameType::Rows => WindowFrameUnits::Rows,
WindowFrameType::Range => WindowFrameUnits::Range,
};

let lower_bound: WindowFrameBound = match spark_window_frame
.lower_bound
.as_ref()
.and_then(|inner| inner.lower_frame_bound_struct.as_ref())
{
Some(l) => match l {
LowerFrameBoundStruct::UnboundedPreceding(_) => {
WindowFrameBound::Preceding(ScalarValue::Null)
}
LowerFrameBoundStruct::Preceding(offset) => {
WindowFrameBound::Preceding(ScalarValue::Int32(Some(offset.offset)))
}
LowerFrameBoundStruct::CurrentRow(_) => WindowFrameBound::CurrentRow,
},
None => WindowFrameBound::Preceding(ScalarValue::Null),
};

let upper_bound: WindowFrameBound = match spark_window_frame
.upper_bound
.as_ref()
.and_then(|inner| inner.upper_frame_bound_struct.as_ref())
{
Some(u) => match u {
UpperFrameBoundStruct::UnboundedFollowing(_) => {
    WindowFrameBound::Following(ScalarValue::Null)
}
UpperFrameBoundStruct::Following(offset) => {
    WindowFrameBound::Following(ScalarValue::Int32(Some(offset.offset)))
}
UpperFrameBoundStruct::CurrentRow(_) => WindowFrameBound::CurrentRow,
},
None => WindowFrameBound::Following(ScalarValue::Null),
};

let window_frame = WindowFrame::new_bounds(units, lower_bound, upper_bound);

datafusion::physical_plan::windows::create_window_expr(
&window_func,
window_func_name,
&window_args,
partition_by,
sort_exprs,
window_frame.into(),
&input_schema,
)
.map_err(|e| ExecutionError::DataFusionError(e.to_string()))
}

/// Create a DataFusion physical partitioning from Spark physical partitioning
fn create_partitioning(
&self,
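As an aside on the mapping implemented in `create_window_expr` above: each Spark frame bound from the protobuf is translated into a DataFusion `WindowFrameBound`, with `ScalarValue::Null` standing in for an unbounded offset. The sketch below restates that translation with minimal, hypothetical stand-in types (no DataFusion dependency) so it can be compiled and checked in isolation; names like `Bound`, `SparkLower`, and `SparkUpper` are illustrative only.

```rust
// Hypothetical, simplified mirror of the mapping in create_window_expr.
// `None` stands in for ScalarValue::Null (i.e. an UNBOUNDED bound).
#[derive(Debug, PartialEq)]
enum Bound {
    Preceding(Option<i32>),
    CurrentRow,
    Following(Option<i32>),
}

enum SparkLower {
    UnboundedPreceding,
    Preceding(i32),
    CurrentRow,
}

enum SparkUpper {
    UnboundedFollowing,
    Following(i32),
    CurrentRow,
}

fn map_lower(b: &SparkLower) -> Bound {
    match b {
        SparkLower::UnboundedPreceding => Bound::Preceding(None),
        SparkLower::Preceding(n) => Bound::Preceding(Some(*n)),
        SparkLower::CurrentRow => Bound::CurrentRow,
    }
}

fn map_upper(b: &SparkUpper) -> Bound {
    match b {
        SparkUpper::UnboundedFollowing => Bound::Following(None),
        SparkUpper::Following(n) => Bound::Following(Some(*n)),
        SparkUpper::CurrentRow => Bound::CurrentRow,
    }
}

fn main() {
    // ROWS BETWEEN 1 PRECEDING AND CURRENT ROW
    assert_eq!(map_lower(&SparkLower::Preceding(1)), Bound::Preceding(Some(1)));
    assert_eq!(map_upper(&SparkUpper::CurrentRow), Bound::CurrentRow);
    // ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
    assert_eq!(map_lower(&SparkLower::UnboundedPreceding), Bound::Preceding(None));
    assert_eq!(map_upper(&SparkUpper::UnboundedFollowing), Bound::Following(None));
    println!("frame-bound mapping behaves as expected");
}
```

In particular, an upper bound of `UNBOUNDED FOLLOWING` or `n FOLLOWING` must land on a `Following` bound, which is what the match arms in `create_window_expr` produce.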
58 changes: 58 additions & 0 deletions core/src/execution/proto/operator.proto
@@ -40,6 +40,7 @@ message Operator {
Limit limit = 105;
ShuffleWriter shuffle_writer = 106;
Expand expand = 107;
Window window = 108;
}
}

@@ -87,3 +88,60 @@ message Expand {
repeated spark.spark_expression.Expr project_list = 1;
int32 num_expr_per_project = 3;
}

message WindowExpr {
spark.spark_expression.Expr func = 1;
WindowSpecDefinition spec = 2;
}

enum WindowFrameType {
Rows = 0;
Range = 1;
}

message WindowFrame {
WindowFrameType frame_type = 1;
LowerWindowFrameBound lower_bound = 2;
UpperWindowFrameBound upper_bound = 3;
}

message LowerWindowFrameBound {
oneof lower_frame_bound_struct {
UnboundedPreceding unboundedPreceding = 1;
Preceding preceding = 2;
CurrentRow currentRow = 3;
}
}

message UpperWindowFrameBound {
oneof upper_frame_bound_struct {
UnboundedFollowing unboundedFollowing = 1;
Following following = 2;
CurrentRow currentRow = 3;
}
}

message Preceding {
int32 offset = 1;
}

message Following {
int32 offset = 1;
}

message UnboundedPreceding {}
message UnboundedFollowing {}
message CurrentRow {}

message WindowSpecDefinition {
repeated spark.spark_expression.Expr partitionSpec = 1;
repeated spark.spark_expression.Expr orderSpec = 2;
WindowFrame frameSpecification = 3;
}

message Window {
repeated WindowExpr window_expr = 1;
repeated spark.spark_expression.Expr order_by_list = 2;
repeated spark.spark_expression.Expr partition_by_list = 3;
Operator child = 4;
}
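One note on how these messages appear on the Rust side: prost turns each `oneof` into an `Option`-wrapped enum inside a module named after the message, which is why `planner.rs` matches on `lower_frame_bound_struct` / `upper_frame_bound_struct`. The following sketch uses hand-written stand-ins for the generated types (the real ones live in the `spark_operator` module produced at build time); field and variant names follow prost's usual conventions but should be read as an assumption, not the literal generated code.

```rust
// Hand-written stand-ins for the prost-generated types of
// LowerWindowFrameBound; shown only to illustrate the Option<oneof-enum>
// shape the planner destructures.
#[derive(Debug)]
struct Preceding {
    offset: i32,
}

#[derive(Debug)]
struct UnboundedPreceding {}

#[derive(Debug)]
struct CurrentRow {}

#[derive(Debug)]
enum LowerFrameBoundStruct {
    UnboundedPreceding(UnboundedPreceding),
    Preceding(Preceding),
    CurrentRow(CurrentRow),
}

#[derive(Debug)]
struct LowerWindowFrameBound {
    // proto3 `oneof` => Option of the generated enum
    lower_frame_bound_struct: Option<LowerFrameBoundStruct>,
}

fn main() {
    // Encoding of a "1 PRECEDING" lower bound.
    let lb = LowerWindowFrameBound {
        lower_frame_bound_struct: Some(LowerFrameBoundStruct::Preceding(Preceding { offset: 1 })),
    };
    // The same match shape used in create_window_expr.
    match lb.lower_frame_bound_struct.as_ref() {
        Some(LowerFrameBoundStruct::Preceding(p)) => println!("{} PRECEDING", p.offset),
        Some(LowerFrameBoundStruct::UnboundedPreceding(_)) => println!("UNBOUNDED PRECEDING"),
        Some(LowerFrameBoundStruct::CurrentRow(_)) | None => println!("CURRENT ROW"),
    }
}
```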
spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -20,7 +20,6 @@
package org.apache.comet

import java.nio.ByteOrder

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.ByteUnit
@@ -40,13 +39,13 @@ import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

import org.apache.comet.CometConf._
import org.apache.comet.CometSparkSessionExtensions.{isANSIEnabled, isCometBroadCastEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported}
import org.apache.comet.parquet.{CometParquetScan, SupportsComet}
import org.apache.comet.serde.OperatorOuterClass.Operator
import org.apache.comet.serde.QueryPlanSerde
import org.apache.comet.shims.ShimCometSparkSessionExtensions
import org.apache.spark.sql.execution.window.WindowExec

class CometSparkSessionExtensions
extends (SparkSessionExtensions => Unit)
@@ -357,6 +356,16 @@ class CometSparkSessionExtensions
s
}

case w: WindowExec =>
QueryPlanSerde.operator2Proto(w) match {
case Some(nativeOp) =>
val cometOp =
    CometWindowExec(w, w.windowExpression, w.partitionSpec, w.orderSpec, w.child)
CometSinkPlaceHolder(nativeOp, w, cometOp)
case None =>
w
}

case u: UnionExec
if isCometOperatorEnabled(conf, "union") &&
u.children.forall(isCometNative) =>