From 375f11ced81b94f6a65ee941bb64b7a102fb72cb Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 3 Mar 2024 00:53:35 -0800 Subject: [PATCH] fix: Final aggregation should not bind to the input of partial aggregation --- core/src/execution/datafusion/planner.rs | 13 +- core/src/execution/proto/expr.proto | 6 + .../apache/comet/serde/QueryPlanSerde.scala | 375 ++++++++++-------- .../CometTakeOrderedAndProjectExec.scala | 4 +- .../comet/exec/CometAggregateSuite.scala | 19 + 5 files changed, 248 insertions(+), 169 deletions(-) diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index f4a0cec791..33cf636a2e 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -43,7 +43,7 @@ use datafusion_physical_expr::{ execution_props::ExecutionProps, expressions::{ CaseExpr, CastExpr, Count, FirstValue, InListExpr, IsNullExpr, LastValue, Max, Min, - NegativeExpr, NotExpr, Sum, + NegativeExpr, NotExpr, Sum, UnKnownColumn, }, AggregateExpr, ScalarFunctionExpr, }; @@ -201,9 +201,16 @@ impl PhysicalPlanner { } ExprStruct::Bound(bound) => { let idx = bound.index as usize; - let column_name = format!("col_{}", idx); - Ok(Arc::new(Column::new(&column_name, idx))) + if idx >= input_schema.fields().len() { + return Err(ExecutionError::GeneralError(format!( + "Column index {} is out of bound. Schema: {}", + idx, input_schema + ))); + } + let field = input_schema.field(idx); + Ok(Arc::new(Column::new(field.name().as_str(), idx))) } + ExprStruct::Unbound(unbound) => Ok(Arc::new(UnKnownColumn::new(unbound.name.as_str()))), ExprStruct::IsNotNull(is_notnull) => { let child = self.create_expr(is_notnull.child.as_ref().unwrap(), input_schema)?; Ok(Arc::new(IsNotNullExpr::new(child))) diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto index a80335c50a..8aa81b7672 100644 --- a/core/src/execution/proto/expr.proto +++ b/core/src/execution/proto/expr.proto @@ -75,6 +75,7 @@ message Expr { BitwiseNot bitwiseNot = 48; Abs abs = 49; Subquery subquery = 50; + UnboundReference unbound = 51; } } @@ -254,6 +255,11 @@ message BoundReference { DataType datatype = 2; } +message UnboundReference { + string name = 1; + DataType datatype = 2; +} + message SortOrder { Expr child = 1; SortDirection direction = 2; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 75a2ff9819..661b7f2a11 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -186,10 +186,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } } - def aggExprToProto(aggExpr: AggregateExpression, inputs: Seq[Attribute]): Option[AggExpr] = { + def aggExprToProto( + aggExpr: AggregateExpression, + inputs: Seq[Attribute], + binding: Boolean): Option[AggExpr] = { aggExpr.aggregateFunction match { case s @ Sum(child, _) if sumDataTypeSupported(s.dataType) => - val childExpr = exprToProto(child, inputs) + val childExpr = exprToProto(child, inputs, binding) val dataType = serializeDataType(s.dataType) if (childExpr.isDefined && dataType.isDefined) { @@ -207,7 +210,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { None } case s @ Average(child, _) if avgDataTypeSupported(s.dataType) => - val childExpr = exprToProto(child, inputs) + val childExpr = exprToProto(child, inputs, binding) val dataType = serializeDataType(s.dataType) val sumDataType = if (child.dataType.isInstanceOf[DecimalType]) { @@ -239,7 +242,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { None } case Count(children) => - val exprChildren = children.map(exprToProto(_, inputs)) + val exprChildren = children.map(exprToProto(_, inputs, binding)) if (exprChildren.forall(_.isDefined)) { val countBuilder = ExprOuterClass.Count.newBuilder() @@ -254,7 +257,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { None } case min @ Min(child) if minMaxDataTypeSupported(min.dataType) => - val childExpr = exprToProto(child, inputs) + val childExpr = exprToProto(child, inputs, binding) val dataType = serializeDataType(min.dataType) if (childExpr.isDefined && dataType.isDefined) { @@ -271,7 +274,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { None } case max @ Max(child) if minMaxDataTypeSupported(max.dataType) => - val childExpr = exprToProto(child, inputs) + val childExpr = exprToProto(child, inputs, binding) val dataType = serializeDataType(max.dataType) if (childExpr.isDefined && dataType.isDefined) { @@ -289,7 +292,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case first @ First(child, ignoreNulls) if !ignoreNulls => // DataFusion doesn't support ignoreNulls true - val childExpr = exprToProto(child, inputs) + val childExpr = exprToProto(child, inputs, binding) val dataType = serializeDataType(first.dataType) if (childExpr.isDefined && dataType.isDefined) { @@ -307,7 +310,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case last @ Last(child, ignoreNulls) if !ignoreNulls => // DataFusion doesn't support ignoreNulls true - val childExpr = exprToProto(child, inputs) + val childExpr = exprToProto(child, inputs, binding) val dataType = serializeDataType(last.dataType) if (childExpr.isDefined && dataType.isDefined) { @@ -330,26 +333,44 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } } - def exprToProto(expr: Expression, input: Seq[Attribute]): Option[Expr] = { + def exprToProto( + expr: Expression, + input: Seq[Attribute], + binding: Boolean = true): Option[Expr] = { val conf = SQLConf.get val newExpr = DecimalPrecision.promote(conf.decimalOperationsAllowPrecisionLoss, expr, !conf.ansiEnabled) - exprToProtoInternal(newExpr, input) + exprToProtoInternal(newExpr, input, binding) } - def exprToProtoInternal(expr: Expression, inputs: Seq[Attribute]): Option[Expr] = { + /** + * Convert a Spark expression to protobuf. + * + * @param expr + * The input expression + * @param inputs + * The input attributes + * @param binding + * Whether to bind the expression to the input attributes + * @return + * The protobuf representation of the expression, or None if the expression is not supported + */ + def exprToProtoInternal( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { SQLConf.get expr match { case a @ Alias(_, _) => - exprToProtoInternal(a.child, inputs) + exprToProtoInternal(a.child, inputs, binding) case cast @ Cast(_: Literal, dataType, _, _) => // This can happen after promoting decimal precisions val value = cast.eval() - exprToProtoInternal(Literal(value, dataType), inputs) + exprToProtoInternal(Literal(value, dataType), inputs, binding) case Cast(child, dt, timeZoneId, _) => - val childExpr = exprToProtoInternal(child, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) val dataType = serializeDataType(dt) if (childExpr.isDefined && dataType.isDefined) { @@ -370,8 +391,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case add @ Add(left, right, _) if supportedDataType(left.dataType) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) if (leftExpr.isDefined && rightExpr.isDefined) { val addBuilder = ExprOuterClass.Add.newBuilder() @@ -392,8 +413,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case sub @ Subtract(left, right, _) if supportedDataType(left.dataType) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.Subtract.newBuilder() @@ -415,8 +436,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { case mul @ Multiply(left, right, _) if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.Multiply.newBuilder() @@ -438,11 +459,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { case div @ Divide(left, right, _) if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => - val leftExpr = exprToProtoInternal(left, inputs) + val leftExpr = exprToProtoInternal(left, inputs, binding) // Datafusion now throws an exception for dividing by zero // See https://github.com/apache/arrow-datafusion/pull/6792 // For now, use NullIf to swap zeros with nulls. - val rightExpr = exprToProtoInternal(nullIfWhenPrimitive(right), inputs) + val rightExpr = exprToProtoInternal(nullIfWhenPrimitive(right), inputs, binding) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.Divide.newBuilder() @@ -464,8 +485,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { case rem @ Remainder(left, right, _) if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(nullIfWhenPrimitive(right), inputs) + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(nullIfWhenPrimitive(right), inputs, binding) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.Remainder.newBuilder() @@ -486,8 +507,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case EqualTo(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.Equal.newBuilder() @@ -504,8 +525,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case Not(EqualTo(left, right)) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.NotEqual.newBuilder() @@ -522,8 +543,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case EqualNullSafe(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.EqualNullSafe.newBuilder() @@ -540,8 +561,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case Not(EqualNullSafe(left, right)) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.NotEqualNullSafe.newBuilder() @@ -558,8 +579,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case GreaterThan(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.GreaterThan.newBuilder() @@ -576,8 +597,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case GreaterThanOrEqual(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.GreaterThanEqual.newBuilder() @@ -594,8 +615,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case LessThan(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.LessThan.newBuilder() @@ -612,8 +633,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case LessThanOrEqual(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.LessThanEqual.newBuilder() @@ -677,7 +698,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case Substring(str, Literal(pos, _), Literal(len, _)) => - val strExpr = exprToProtoInternal(str, inputs) + val strExpr = exprToProtoInternal(str, inputs, binding) if (strExpr.isDefined) { val builder = ExprOuterClass.Substring.newBuilder() @@ -696,8 +717,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { case Like(left, right, _) => // TODO escapeChar - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.Like.newBuilder() @@ -733,8 +754,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { // } case StartsWith(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.StartsWith.newBuilder() @@ -751,8 +772,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case EndsWith(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.EndsWith.newBuilder() @@ -769,8 +790,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case Contains(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.Contains.newBuilder() @@ -787,7 +808,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case StringSpace(child) => - val childExpr = exprToProtoInternal(child, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) if (childExpr.isDefined) { val builder = ExprOuterClass.StringSpace.newBuilder() @@ -803,7 +824,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case Hour(child, timeZoneId) => - val childExpr = exprToProtoInternal(child, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) if (childExpr.isDefined) { val builder = ExprOuterClass.Hour.newBuilder() @@ -822,7 +843,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case Minute(child, timeZoneId) => - val childExpr = exprToProtoInternal(child, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) if (childExpr.isDefined) { val builder = ExprOuterClass.Minute.newBuilder() @@ -841,8 +862,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case TruncDate(child, format) => - val childExpr = exprToProtoInternal(child, inputs) - val formatExpr = exprToProtoInternal(format, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) + val formatExpr = exprToProtoInternal(format, inputs, binding) if (childExpr.isDefined && formatExpr.isDefined) { val builder = ExprOuterClass.TruncDate.newBuilder() @@ -859,8 +880,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case TruncTimestamp(format, child, timeZoneId) => - val childExpr = exprToProtoInternal(child, inputs) - val formatExpr = exprToProtoInternal(format, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) + val formatExpr = exprToProtoInternal(format, inputs, binding) if (childExpr.isDefined && formatExpr.isDefined) { val builder = ExprOuterClass.TruncTimestamp.newBuilder() @@ -880,7 +901,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case Second(child, timeZoneId) => - val childExpr = exprToProtoInternal(child, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) if (childExpr.isDefined) { val builder = ExprOuterClass.Second.newBuilder() @@ -899,8 +920,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case Year(child) => - val periodType = exprToProtoInternal(Literal("year"), inputs) - val childExpr = exprToProtoInternal(child, inputs) + val periodType = exprToProtoInternal(Literal("year"), inputs, binding) + val childExpr = exprToProtoInternal(child, inputs, binding) scalarExprToProto("datepart", Seq(periodType, childExpr): _*) .map(e => { Expr @@ -915,7 +936,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { }) case IsNull(child) => - val childExpr = exprToProtoInternal(child, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) if (childExpr.isDefined) { val castBuilder = ExprOuterClass.IsNull.newBuilder() @@ -931,7 +952,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case IsNotNull(child) => - val childExpr = exprToProtoInternal(child, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) if (childExpr.isDefined) { val castBuilder = ExprOuterClass.IsNotNull.newBuilder() @@ -947,7 +968,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case SortOrder(child, direction, nullOrdering, _) => - val childExpr = exprToProtoInternal(child, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) if (childExpr.isDefined) { val sortOrderBuilder = ExprOuterClass.SortOrder.newBuilder() @@ -973,8 +994,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case And(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.And.newBuilder() @@ -991,8 +1012,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case Or(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.Or.newBuilder() @@ -1011,10 +1032,10 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { case UnaryExpression(child) if expr.prettyName == "promote_precision" => // `UnaryExpression` includes `PromotePrecision` for Spark 3.2 & 3.3 // `PromotePrecision` is just a wrapper, don't need to serialize it. - exprToProtoInternal(child, inputs) + exprToProtoInternal(child, inputs, binding) case CheckOverflow(child, dt, nullOnOverflow) => - val childExpr = exprToProtoInternal(child, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) if (childExpr.isDefined) { val builder = ExprOuterClass.CheckOverflow.newBuilder() @@ -1038,26 +1059,40 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { val dataType = serializeDataType(attr.dataType) if (dataType.isDefined) { - val boundRef = BindReferences - .bindReference(attr, inputs, allowFailures = false) - .asInstanceOf[BoundReference] - val boundExpr = ExprOuterClass.BoundReference - .newBuilder() - .setIndex(boundRef.ordinal) - .setDatatype(dataType.get) - .build() + if (binding) { + val boundRef = BindReferences + .bindReference(attr, inputs, allowFailures = false) + .asInstanceOf[BoundReference] + val boundExpr = ExprOuterClass.BoundReference + .newBuilder() + .setIndex(boundRef.ordinal) + .setDatatype(dataType.get) + .build() - Some( - ExprOuterClass.Expr + Some( + ExprOuterClass.Expr + .newBuilder() + .setBound(boundExpr) + .build()) + } else { + val unboundRef = ExprOuterClass.UnboundReference .newBuilder() - .setBound(boundExpr) - .build()) + .setName(attr.name) + .setDatatype(dataType.get) + .build() + + Some( + ExprOuterClass.Expr + .newBuilder() + .setUnbound(unboundRef) + .build()) + } } else { None } case Abs(child, _) => - exprToProtoInternal(child, inputs).map(childExpr => { + exprToProtoInternal(child, inputs, binding).map(childExpr => { val abs = ExprOuterClass.Abs .newBuilder() @@ -1067,24 +1102,24 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { }) case Acos(child) => - val childExpr = exprToProtoInternal(child, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) scalarExprToProto("acos", childExpr) case Asin(child) => - val childExpr = exprToProtoInternal(child, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) scalarExprToProto("asin", childExpr) case Atan(child) => - val childExpr = exprToProtoInternal(child, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) scalarExprToProto("atan", childExpr) case Atan2(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) scalarExprToProto("atan2", leftExpr, rightExpr) case e @ Ceil(child) => - val childExpr = exprToProtoInternal(child, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) child.dataType match { case t: DecimalType if t.scale == 0 => // zero scale is no-op childExpr @@ -1095,15 +1130,15 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case Cos(child) => - val childExpr = exprToProtoInternal(child, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) scalarExprToProto("cos", childExpr) case Exp(child) => - val childExpr = exprToProtoInternal(child, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) scalarExprToProto("exp", childExpr) case e @ Floor(child) => - val childExpr = exprToProtoInternal(child, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) child.dataType match { case t: DecimalType if t.scale == 0 => // zero scale is no-op childExpr @@ -1114,20 +1149,20 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case Log(child) => - val childExpr = exprToProtoInternal(child, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) scalarExprToProto("ln", childExpr) case Log10(child) => - val childExpr = exprToProtoInternal(child, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) scalarExprToProto("log10", childExpr) case Log2(child) => - val childExpr = exprToProtoInternal(child, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) scalarExprToProto("log2", childExpr) case Pow(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) scalarExprToProto("pow", leftExpr, rightExpr) // round function for Spark 3.2 does not allow negative round target scale. In addition, @@ -1137,12 +1172,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { val scaleV: Any = r.scale.eval(EmptyRow) val _scale: Int = scaleV.asInstanceOf[Int] - lazy val childExpr = exprToProtoInternal(r.child, inputs) + lazy val childExpr = exprToProtoInternal(r.child, inputs, binding) r.child.dataType match { case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252 None case _ if scaleV == null => - exprToProtoInternal(Literal(null), inputs) + exprToProtoInternal(Literal(null), inputs, binding) case _: ByteType | ShortType | IntegerType | LongType if _scale >= 0 => childExpr // _scale(I.e. decimal place) >= 0 is a no-op for integer types in Spark case _: FloatType | DoubleType => @@ -1163,38 +1198,38 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { None case _ => // `scale` must be Int64 type in DataFusion - val scaleExpr = exprToProtoInternal(Literal(_scale.toLong, LongType), inputs) + val scaleExpr = exprToProtoInternal(Literal(_scale.toLong, LongType), inputs, binding) scalarExprToProtoWithReturnType("round", r.dataType, childExpr, scaleExpr) } case Signum(child) => - val childExpr = exprToProtoInternal(child, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) scalarExprToProto("signum", childExpr) case Sin(child) => - val childExpr = exprToProtoInternal(child, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) scalarExprToProto("sin", childExpr) case Sqrt(child) => - val childExpr = exprToProtoInternal(child, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) scalarExprToProto("sqrt", childExpr) case Tan(child) => - val childExpr = exprToProtoInternal(child, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) scalarExprToProto("tan", childExpr) case Ascii(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) + val childExpr = exprToProtoInternal(Cast(child, StringType), inputs, binding) scalarExprToProto("ascii", childExpr) case BitLength(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) + val childExpr = exprToProtoInternal(Cast(child, StringType), inputs, binding) scalarExprToProto("bit_length", childExpr) case If(predicate, trueValue, falseValue) => - val predicateExpr = exprToProtoInternal(predicate, inputs) - val trueExpr = exprToProtoInternal(trueValue, inputs) - val falseExpr = exprToProtoInternal(falseValue, inputs) + val predicateExpr = exprToProtoInternal(predicate, inputs, binding) + val trueExpr = exprToProtoInternal(trueValue, inputs, binding) + val falseExpr = exprToProtoInternal(falseValue, inputs, binding) if (predicateExpr.isDefined && trueExpr.isDefined && falseExpr.isDefined) { val builder = ExprOuterClass.IfExpr.newBuilder() builder.setIfExpr(predicateExpr.get) @@ -1210,15 +1245,15 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case CaseWhen(branches, elseValue) => - val whenSeq = branches.map(elements => exprToProtoInternal(elements._1, inputs)) - val thenSeq = branches.map(elements => exprToProtoInternal(elements._2, inputs)) + val whenSeq = branches.map(elements => exprToProtoInternal(elements._1, inputs, binding)) + val thenSeq = branches.map(elements => exprToProtoInternal(elements._2, inputs, binding)) assert(whenSeq.length == thenSeq.length) if (whenSeq.forall(_.isDefined) && thenSeq.forall(_.isDefined)) { val builder = ExprOuterClass.CaseWhen.newBuilder() builder.addAllWhen(whenSeq.map(_.get).asJava) builder.addAllThen(thenSeq.map(_.get).asJava) if (elseValue.isDefined) { - val elseValueExpr = exprToProtoInternal(elseValue.get, inputs) + val elseValueExpr = exprToProtoInternal(elseValue.get, inputs, binding) if (elseValueExpr.isDefined) { builder.setElseExpr(elseValueExpr.get) } else { @@ -1235,78 +1270,78 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case ConcatWs(children) => - val exprs = children.map(e => exprToProtoInternal(Cast(e, StringType), inputs)) + val exprs = children.map(e => exprToProtoInternal(Cast(e, StringType), inputs, binding)) scalarExprToProto("concat_ws", exprs: _*) case Chr(child) => - val childExpr = exprToProtoInternal(child, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) scalarExprToProto("chr", childExpr) case InitCap(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) + val childExpr = exprToProtoInternal(Cast(child, StringType), inputs, binding) scalarExprToProto("initcap", childExpr) case Length(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) + val childExpr = exprToProtoInternal(Cast(child, StringType), inputs, binding) scalarExprToProto("length", childExpr) case Lower(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) + val childExpr = exprToProtoInternal(Cast(child, StringType), inputs, binding) scalarExprToProto("lower", childExpr) case Md5(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) + val childExpr = exprToProtoInternal(Cast(child, StringType), inputs, binding) scalarExprToProto("md5", childExpr) case OctetLength(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) + val childExpr = exprToProtoInternal(Cast(child, StringType), inputs, binding) scalarExprToProto("octet_length", childExpr) case Reverse(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) + val childExpr = exprToProtoInternal(Cast(child, StringType), inputs, binding) scalarExprToProto("reverse", childExpr) case StringInstr(str, substr) => - val leftExpr = exprToProtoInternal(Cast(str, StringType), inputs) - val rightExpr = exprToProtoInternal(Cast(substr, StringType), inputs) + val leftExpr = exprToProtoInternal(Cast(str, StringType), inputs, binding) + val rightExpr = exprToProtoInternal(Cast(substr, StringType), inputs, binding) scalarExprToProto("strpos", leftExpr, rightExpr) case StringRepeat(str, times) => - val leftExpr = exprToProtoInternal(Cast(str, StringType), inputs) - val rightExpr = exprToProtoInternal(Cast(times, LongType), inputs) + val leftExpr = exprToProtoInternal(Cast(str, StringType), inputs, binding) + val rightExpr = exprToProtoInternal(Cast(times, LongType), inputs, binding) scalarExprToProto("repeat", leftExpr, rightExpr) case StringReplace(src, search, replace) => - val srcExpr = exprToProtoInternal(Cast(src, StringType), inputs) - val searchExpr = exprToProtoInternal(Cast(search, StringType), inputs) - val replaceExpr = exprToProtoInternal(Cast(replace, StringType), inputs) + val srcExpr = exprToProtoInternal(Cast(src, StringType), inputs, binding) + val searchExpr = exprToProtoInternal(Cast(search, StringType), inputs, binding) + val replaceExpr = exprToProtoInternal(Cast(replace, StringType), inputs, binding) scalarExprToProto("replace", srcExpr, searchExpr, replaceExpr) case StringTranslate(src, matching, replace) => - val srcExpr = exprToProtoInternal(Cast(src, StringType), inputs) - val matchingExpr = exprToProtoInternal(Cast(matching, StringType), inputs) - val replaceExpr = exprToProtoInternal(Cast(replace, StringType), inputs) + val srcExpr = exprToProtoInternal(Cast(src, StringType), inputs, binding) + val matchingExpr = exprToProtoInternal(Cast(matching, StringType), inputs, binding) + val replaceExpr = exprToProtoInternal(Cast(replace, StringType), inputs, binding) scalarExprToProto("translate", srcExpr, matchingExpr, replaceExpr) case StringTrim(srcStr, trimStr) => - trim(srcStr, trimStr, inputs, "trim") + trim(srcStr, trimStr, inputs, "trim", binding) case StringTrimLeft(srcStr, trimStr) => - trim(srcStr, trimStr, inputs, "ltrim") + trim(srcStr, trimStr, inputs, "ltrim", binding) case StringTrimRight(srcStr, trimStr) => - trim(srcStr, trimStr, inputs, "rtrim") + trim(srcStr, trimStr, inputs, "rtrim", binding) case StringTrimBoth(srcStr, trimStr, _) => - trim(srcStr, trimStr, inputs, "btrim") + trim(srcStr, trimStr, inputs, "btrim", binding) case Upper(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) + val childExpr = exprToProtoInternal(Cast(child, StringType), inputs, binding) scalarExprToProto("upper", childExpr) case BitwiseAnd(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.BitwiseAnd.newBuilder() @@ -1323,7 +1358,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case BitwiseNot(child) => - val childExpr = exprToProtoInternal(child, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) if (childExpr.isDefined) { val builder = ExprOuterClass.BitwiseNot.newBuilder() @@ -1339,8 +1374,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case BitwiseOr(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.BitwiseOr.newBuilder() @@ -1357,8 +1392,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case BitwiseXor(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.BitwiseXor.newBuilder() @@ -1375,13 +1410,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case ShiftRight(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) + val leftExpr = exprToProtoInternal(left, inputs, binding) val rightExpr = if (left.dataType == LongType) { // DataFusion bitwise shift right expression requires // same data type between left and right side - exprToProtoInternal(Cast(right, LongType), inputs) + exprToProtoInternal(Cast(right, LongType), inputs, binding) } else { - exprToProtoInternal(right, inputs) + exprToProtoInternal(right, inputs, binding) } if (leftExpr.isDefined && rightExpr.isDefined) { @@ -1399,13 +1434,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case ShiftLeft(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) + val leftExpr = exprToProtoInternal(left, inputs, binding) val rightExpr = if (left.dataType == LongType) { // DataFusion bitwise shift left expression requires // same data type between left and right side - exprToProtoInternal(Cast(right, LongType), inputs) + exprToProtoInternal(Cast(right, LongType), inputs, binding) } else { - exprToProtoInternal(right, inputs) + exprToProtoInternal(right, inputs, binding) } if (leftExpr.isDefined && rightExpr.isDefined) { @@ -1423,7 +1458,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case In(value, list) => - in(value, list, inputs, false) + in(value, list, inputs, false, binding) case InSet(value, hset) => val valueDataType = value.dataType @@ -1432,13 +1467,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { }.toSeq // Change `InSet` to `In` expression // We do Spark `InSet` optimization in native (DataFusion) side. - in(value, list, inputs, false) + in(value, list, inputs, false, binding) case Not(In(value, list)) => - in(value, list, inputs, true) + in(value, list, inputs, true, binding) case Not(child) => - val childExpr = exprToProtoInternal(child, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) if (childExpr.isDefined) { val builder = ExprOuterClass.Not.newBuilder() builder.setChild(childExpr.get) @@ -1452,7 +1487,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case UnaryMinus(child, _) => - val childExpr = exprToProtoInternal(child, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) if (childExpr.isDefined) { val builder = ExprOuterClass.Negative.newBuilder() builder.setChild(childExpr.get) @@ -1466,7 +1501,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case a @ Coalesce(_) => - val exprChildren = a.children.map(exprToProtoInternal(_, inputs)) + val exprChildren = a.children.map(exprToProtoInternal(_, inputs, binding)) scalarExprToProto("coalesce", exprChildren: _*) // With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for char @@ -1481,8 +1516,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { false, true) if arguments.size == 2 => val argsExpr = Seq( - exprToProtoInternal(Cast(arguments(0), StringType), inputs), - exprToProtoInternal(arguments(1), inputs)) + exprToProtoInternal(Cast(arguments(0), StringType), inputs, binding), + exprToProtoInternal(arguments(1), inputs, binding)) if (argsExpr.forall(_.isDefined)) { val builder = ExprOuterClass.ScalarFunc.newBuilder() @@ -1499,7 +1534,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { if (dataType.isEmpty) { return None } - exprToProtoInternal(expr, inputs).map { child => + exprToProtoInternal(expr, inputs, binding).map { child => val builder = ExprOuterClass.NormalizeNaNAndZero .newBuilder() .setChild(child) @@ -1520,11 +1555,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { Some(ExprOuterClass.Expr.newBuilder().setSubquery(builder).build()) case UnscaledValue(child) => - val childExpr = exprToProtoInternal(child, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) scalarExprToProtoWithReturnType("unscaled_value", LongType, childExpr) case MakeDecimal(child, precision, scale, true) => - val childExpr = exprToProtoInternal(child, inputs) + val childExpr = exprToProtoInternal(child, inputs, binding) scalarExprToProtoWithReturnType("make_decimal", DecimalType(precision, scale), childExpr) case e => @@ -1537,10 +1572,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { srcStr: Expression, trimStr: Option[Expression], inputs: Seq[Attribute], - trimType: String): Option[Expr] = { - val srcExpr = exprToProtoInternal(Cast(srcStr, StringType), inputs) + trimType: String, + binding: Boolean): Option[Expr] = { + val srcExpr = exprToProtoInternal(Cast(srcStr, StringType), inputs, binding) if (trimStr.isDefined) { - val trimExpr = exprToProtoInternal(Cast(trimStr.get, StringType), inputs) + val trimExpr = exprToProtoInternal(Cast(trimStr.get, StringType), inputs, binding) scalarExprToProto(trimType, srcExpr, trimExpr) } else { scalarExprToProto(trimType, srcExpr) @@ -1551,9 +1587,10 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { value: Expression, list: Seq[Expression], inputs: Seq[Attribute], - negate: Boolean): Option[Expr] = { - val valueExpr = exprToProtoInternal(value, inputs) - val listExprs = list.map(exprToProtoInternal(_, inputs)) + negate: Boolean, + binding: Boolean): Option[Expr] = { + val valueExpr = exprToProtoInternal(value, inputs, binding) + val listExprs = list.map(exprToProtoInternal(_, inputs, binding)) if (valueExpr.isDefined && listExprs.forall(_.isDefined)) { val builder = ExprOuterClass.In.newBuilder() builder.setInValue(valueExpr.get) @@ -1781,7 +1818,17 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { case _ => return None } - val aggExprs = aggregateExpressions.map(aggExprToProto(_, output)) + val binding = if (mode == CometAggregateMode.Final) { + // In final mode, the aggregate expressions are bound to the output of the + // child and partial aggregate expressions buffer attributes produced by partial + // aggregation. This is done in Spark `HashAggregateExec` internally. In Comet, + // we don't have to do this because we don't use the merging expression. + false + } else { + true + } + + val aggExprs = aggregateExpressions.map(aggExprToProto(_, output, binding)) if (childOp.nonEmpty && groupingExprs.forall(_.isDefined) && aggExprs.forall(_.isDefined)) { val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder() diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala index 26ec401ed6..5259adb876 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala @@ -123,8 +123,8 @@ case class CometTakeOrderedAndProjectExec( object CometTakeOrderedAndProjectExec extends ShimCometTakeOrderedAndProjectExec { // TODO: support offset for Spark 3.4 def isSupported(plan: TakeOrderedAndProjectExec): Boolean = { - val exprs = plan.projectList.map(exprToProto(_, plan.child.output)) - val sortOrders = plan.sortOrder.map(exprToProto(_, plan.child.output)) + val exprs = plan.projectList.map(exprToProto(_, plan.child.output, true)) + val sortOrders = plan.sortOrder.map(exprToProto(_, plan.child.output, true)) exprs.forall(_.isDefined) && sortOrders.forall(_.isDefined) && getOffset(plan).getOrElse( 0) == 0 } diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index d64a3a3ae2..edd53b942c 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -38,6 +38,25 @@ import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus * Test suite dedicated to Comet native aggregate operator */ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { + import testImplicits._ + + test("Final aggregation should not bind to the input of partial aggregation") { + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test") + makeParquetFile(path, 10000, 10, dictionaryEnabled) + withParquetTable(path.toUri.toString, "tbl") { + val df = sql("SELECT * FROM tbl").groupBy("_g1").agg(sum($"_3" + $"_g3")) + checkSparkAnswer(df) + } + } + } + } + } test("Ensure traversed operators during finding first partial aggregation are all native") { withTable("lineitem", "part") {