From a131c4406755f7bfba02865782ac11bfdbfcdb5c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 4 Mar 2024 17:03:56 -0800 Subject: [PATCH] fix: Final aggregation should not bind to the input of partial aggregation (#155) This patch adds the check of the index of bound reference. The aggregate expressions of final aggregation are not bound to the input of partial aggregation anymore but sent to native side as unbound expressions. --- core/src/execution/datafusion/planner.rs | 13 +- core/src/execution/proto/expr.proto | 6 + .../apache/comet/serde/QueryPlanSerde.scala | 2257 +++++++++-------- .../comet/exec/CometAggregateSuite.scala | 19 + 4 files changed, 1187 insertions(+), 1108 deletions(-) diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index f4a0cec79..33cf636a2 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -43,7 +43,7 @@ use datafusion_physical_expr::{ execution_props::ExecutionProps, expressions::{ CaseExpr, CastExpr, Count, FirstValue, InListExpr, IsNullExpr, LastValue, Max, Min, - NegativeExpr, NotExpr, Sum, + NegativeExpr, NotExpr, Sum, UnKnownColumn, }, AggregateExpr, ScalarFunctionExpr, }; @@ -201,9 +201,16 @@ impl PhysicalPlanner { } ExprStruct::Bound(bound) => { let idx = bound.index as usize; - let column_name = format!("col_{}", idx); - Ok(Arc::new(Column::new(&column_name, idx))) + if idx >= input_schema.fields().len() { + return Err(ExecutionError::GeneralError(format!( + "Column index {} is out of bound. Schema: {}", + idx, input_schema + ))); + } + let field = input_schema.field(idx); + Ok(Arc::new(Column::new(field.name().as_str(), idx))) } + ExprStruct::Unbound(unbound) => Ok(Arc::new(UnKnownColumn::new(unbound.name.as_str()))), ExprStruct::IsNotNull(is_notnull) => { let child = self.create_expr(is_notnull.child.as_ref().unwrap(), input_schema)?; Ok(Arc::new(IsNotNullExpr::new(child))) diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto index a80335c50..8aa81b767 100644 --- a/core/src/execution/proto/expr.proto +++ b/core/src/execution/proto/expr.proto @@ -75,6 +75,7 @@ message Expr { BitwiseNot bitwiseNot = 48; Abs abs = 49; Subquery subquery = 50; + UnboundReference unbound = 51; } } @@ -254,6 +255,11 @@ message BoundReference { DataType datatype = 2; } +message UnboundReference { + string name = 1; + DataType datatype = 2; +} + message SortOrder { Expr child = 1; SortDirection direction = 2; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index b154fb62f..8dc2b1596 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -186,10 +186,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } } - def aggExprToProto(aggExpr: AggregateExpression, inputs: Seq[Attribute]): Option[AggExpr] = { + def aggExprToProto( + aggExpr: AggregateExpression, + inputs: Seq[Attribute], + binding: Boolean): Option[AggExpr] = { aggExpr.aggregateFunction match { case s @ Sum(child, _) if sumDataTypeSupported(s.dataType) => - val childExpr = exprToProto(child, inputs) + val childExpr = exprToProto(child, inputs, binding) val dataType = serializeDataType(s.dataType) if (childExpr.isDefined && dataType.isDefined) { @@ -207,7 +210,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { None } case s @ Average(child, _) if avgDataTypeSupported(s.dataType) => - val childExpr = exprToProto(child, inputs) + val childExpr = exprToProto(child, inputs, binding) val dataType = serializeDataType(s.dataType) val sumDataType = if (child.dataType.isInstanceOf[DecimalType]) { @@ -239,7 +242,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { None } case Count(children) => - val exprChildren = children.map(exprToProto(_, inputs)) + val exprChildren = children.map(exprToProto(_, inputs, binding)) if (exprChildren.forall(_.isDefined)) { val countBuilder = ExprOuterClass.Count.newBuilder() @@ -254,7 +257,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { None } case min @ Min(child) if minMaxDataTypeSupported(min.dataType) => - val childExpr = exprToProto(child, inputs) + val childExpr = exprToProto(child, inputs, binding) val dataType = serializeDataType(min.dataType) if (childExpr.isDefined && dataType.isDefined) { @@ -271,7 +274,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { None } case max @ Max(child) if minMaxDataTypeSupported(max.dataType) => - val childExpr = exprToProto(child, inputs) + val childExpr = exprToProto(child, inputs, binding) val dataType = serializeDataType(max.dataType) if (childExpr.isDefined && dataType.isDefined) { @@ -289,7 +292,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case first @ First(child, ignoreNulls) if !ignoreNulls => // DataFusion doesn't support ignoreNulls true - val childExpr = exprToProto(child, inputs) + val childExpr = exprToProto(child, inputs, binding) val dataType = serializeDataType(first.dataType) if (childExpr.isDefined && dataType.isDefined) { @@ -307,7 +310,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } case last @ Last(child, ignoreNulls) if !ignoreNulls => // DataFusion doesn't support ignoreNulls true - val childExpr = exprToProto(child, inputs) + val childExpr = exprToProto(child, inputs, binding) val dataType = serializeDataType(last.dataType) if (childExpr.isDefined && dataType.isDefined) { @@ -330,1243 +333,1277 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } } - def exprToProto(expr: Expression, input: Seq[Attribute]): Option[Expr] = { - val conf = SQLConf.get - val newExpr = - DecimalPrecision.promote(conf.decimalOperationsAllowPrecisionLoss, expr, !conf.ansiEnabled) - exprToProtoInternal(newExpr, input) - } - - def exprToProtoInternal(expr: Expression, inputs: Seq[Attribute]): Option[Expr] = { - SQLConf.get - expr match { - case a @ Alias(_, _) => - exprToProtoInternal(a.child, inputs) - - case cast @ Cast(_: Literal, dataType, _, _) => - // This can happen after promoting decimal precisions - val value = cast.eval() - exprToProtoInternal(Literal(value, dataType), inputs) - - case Cast(child, dt, timeZoneId, _) => - val childExpr = exprToProtoInternal(child, inputs) - val dataType = serializeDataType(dt) + /** + * Convert a Spark expression to protobuf. + * + * @param expr + * The input expression + * @param inputs + * The input attributes + * @param binding + * Whether to bind the expression to the input attributes + * @return + * The protobuf representation of the expression, or None if the expression is not supported + */ + def exprToProto( + expr: Expression, + input: Seq[Attribute], + binding: Boolean = true): Option[Expr] = { + + def exprToProtoInternal(expr: Expression, inputs: Seq[Attribute]): Option[Expr] = { + SQLConf.get + expr match { + case a @ Alias(_, _) => + exprToProtoInternal(a.child, inputs) + + case cast @ Cast(_: Literal, dataType, _, _) => + // This can happen after promoting decimal precisions + val value = cast.eval() + exprToProtoInternal(Literal(value, dataType), inputs) + + case Cast(child, dt, timeZoneId, _) => + val childExpr = exprToProtoInternal(child, inputs) + val dataType = serializeDataType(dt) - if (childExpr.isDefined && dataType.isDefined) { - val castBuilder = ExprOuterClass.Cast.newBuilder() - castBuilder.setChild(childExpr.get) - castBuilder.setDatatype(dataType.get) + if (childExpr.isDefined && dataType.isDefined) { + val castBuilder = ExprOuterClass.Cast.newBuilder() + castBuilder.setChild(childExpr.get) + castBuilder.setDatatype(dataType.get) - val timeZone = timeZoneId.getOrElse("UTC") - castBuilder.setTimezone(timeZone) + val timeZone = timeZoneId.getOrElse("UTC") + castBuilder.setTimezone(timeZone) - Some( - ExprOuterClass.Expr - .newBuilder() - .setCast(castBuilder) - .build()) - } else { - None - } - - case add @ Add(left, right, _) if supportedDataType(left.dataType) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) - - if (leftExpr.isDefined && rightExpr.isDefined) { - val addBuilder = ExprOuterClass.Add.newBuilder() - addBuilder.setLeft(leftExpr.get) - addBuilder.setRight(rightExpr.get) - addBuilder.setFailOnError(getFailOnError(add)) - serializeDataType(add.dataType).foreach { t => - addBuilder.setReturnType(t) + Some( + ExprOuterClass.Expr + .newBuilder() + .setCast(castBuilder) + .build()) + } else { + None } - Some( - ExprOuterClass.Expr - .newBuilder() - .setAdd(addBuilder) - .build()) - } else { - None - } + case add @ Add(left, right, _) if supportedDataType(left.dataType) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) + + if (leftExpr.isDefined && rightExpr.isDefined) { + val addBuilder = ExprOuterClass.Add.newBuilder() + addBuilder.setLeft(leftExpr.get) + addBuilder.setRight(rightExpr.get) + addBuilder.setFailOnError(getFailOnError(add)) + serializeDataType(add.dataType).foreach { t => + addBuilder.setReturnType(t) + } - case sub @ Subtract(left, right, _) if supportedDataType(left.dataType) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) - - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.Subtract.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) - builder.setFailOnError(getFailOnError(sub)) - serializeDataType(sub.dataType).foreach { t => - builder.setReturnType(t) + Some( + ExprOuterClass.Expr + .newBuilder() + .setAdd(addBuilder) + .build()) + } else { + None } - Some( - ExprOuterClass.Expr - .newBuilder() - .setSubtract(builder) - .build()) - } else { - None - } + case sub @ Subtract(left, right, _) if supportedDataType(left.dataType) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) + + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.Subtract.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) + builder.setFailOnError(getFailOnError(sub)) + serializeDataType(sub.dataType).foreach { t => + builder.setReturnType(t) + } - case mul @ Multiply(left, right, _) - if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) - - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.Multiply.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) - builder.setFailOnError(getFailOnError(mul)) - serializeDataType(mul.dataType).foreach { t => - builder.setReturnType(t) + Some( + ExprOuterClass.Expr + .newBuilder() + .setSubtract(builder) + .build()) + } else { + None } - Some( - ExprOuterClass.Expr - .newBuilder() - .setMultiply(builder) - .build()) - } else { - None - } + case mul @ Multiply(left, right, _) + if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) + + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.Multiply.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) + builder.setFailOnError(getFailOnError(mul)) + serializeDataType(mul.dataType).foreach { t => + builder.setReturnType(t) + } - case div @ Divide(left, right, _) - if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => - val leftExpr = exprToProtoInternal(left, inputs) - // Datafusion now throws an exception for dividing by zero - // See https://github.com/apache/arrow-datafusion/pull/6792 - // For now, use NullIf to swap zeros with nulls. - val rightExpr = exprToProtoInternal(nullIfWhenPrimitive(right), inputs) - - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.Divide.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) - builder.setFailOnError(getFailOnError(div)) - serializeDataType(div.dataType).foreach { t => - builder.setReturnType(t) + Some( + ExprOuterClass.Expr + .newBuilder() + .setMultiply(builder) + .build()) + } else { + None } - Some( - ExprOuterClass.Expr - .newBuilder() - .setDivide(builder) - .build()) - } else { - None - } + case div @ Divide(left, right, _) + if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => + val leftExpr = exprToProtoInternal(left, inputs) + // Datafusion now throws an exception for dividing by zero + // See https://github.com/apache/arrow-datafusion/pull/6792 + // For now, use NullIf to swap zeros with nulls. + val rightExpr = exprToProtoInternal(nullIfWhenPrimitive(right), inputs) + + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.Divide.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) + builder.setFailOnError(getFailOnError(div)) + serializeDataType(div.dataType).foreach { t => + builder.setReturnType(t) + } - case rem @ Remainder(left, right, _) - if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(nullIfWhenPrimitive(right), inputs) - - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.Remainder.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) - builder.setFailOnError(getFailOnError(rem)) - serializeDataType(rem.dataType).foreach { t => - builder.setReturnType(t) + Some( + ExprOuterClass.Expr + .newBuilder() + .setDivide(builder) + .build()) + } else { + None } - Some( - ExprOuterClass.Expr - .newBuilder() - .setRemainder(builder) - .build()) - } else { - None - } - - case EqualTo(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) - - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.Equal.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) - - Some( - ExprOuterClass.Expr - .newBuilder() - .setEq(builder) - .build()) - } else { - None - } - - case Not(EqualTo(left, right)) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) - - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.NotEqual.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) - - Some( - ExprOuterClass.Expr - .newBuilder() - .setNeq(builder) - .build()) - } else { - None - } - - case EqualNullSafe(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) - - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.EqualNullSafe.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) - - Some( - ExprOuterClass.Expr - .newBuilder() - .setEqNullSafe(builder) - .build()) - } else { - None - } - - case Not(EqualNullSafe(left, right)) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) - - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.NotEqualNullSafe.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) - - Some( - ExprOuterClass.Expr - .newBuilder() - .setNeqNullSafe(builder) - .build()) - } else { - None - } - - case GreaterThan(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) - - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.GreaterThan.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) - - Some( - ExprOuterClass.Expr - .newBuilder() - .setGt(builder) - .build()) - } else { - None - } - - case GreaterThanOrEqual(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) - - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.GreaterThanEqual.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) - - Some( - ExprOuterClass.Expr - .newBuilder() - .setGtEq(builder) - .build()) - } else { - None - } - - case LessThan(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) - - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.LessThan.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) - - Some( - ExprOuterClass.Expr - .newBuilder() - .setLt(builder) - .build()) - } else { - None - } - - case LessThanOrEqual(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) - - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.LessThanEqual.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) - - Some( - ExprOuterClass.Expr - .newBuilder() - .setLtEq(builder) - .build()) - } else { - None - } - - case Literal(value, dataType) if supportedDataType(dataType) => - val exprBuilder = ExprOuterClass.Literal.newBuilder() + case rem @ Remainder(left, right, _) + if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(nullIfWhenPrimitive(right), inputs) + + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.Remainder.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) + builder.setFailOnError(getFailOnError(rem)) + serializeDataType(rem.dataType).foreach { t => + builder.setReturnType(t) + } - if (value == null) { - exprBuilder.setIsNull(true) - } else { - exprBuilder.setIsNull(false) - dataType match { - case _: BooleanType => exprBuilder.setBoolVal(value.asInstanceOf[Boolean]) - case _: ByteType => exprBuilder.setByteVal(value.asInstanceOf[Byte]) - case _: ShortType => exprBuilder.setShortVal(value.asInstanceOf[Short]) - case _: IntegerType => exprBuilder.setIntVal(value.asInstanceOf[Int]) - case _: LongType => exprBuilder.setLongVal(value.asInstanceOf[Long]) - case _: FloatType => exprBuilder.setFloatVal(value.asInstanceOf[Float]) - case _: DoubleType => exprBuilder.setDoubleVal(value.asInstanceOf[Double]) - case _: StringType => - exprBuilder.setStringVal(value.asInstanceOf[UTF8String].toString) - case _: TimestampType => exprBuilder.setLongVal(value.asInstanceOf[Long]) - case _: DecimalType => - // Pass decimal literal as bytes. - val unscaled = value.asInstanceOf[Decimal].toBigDecimal.underlying.unscaledValue - exprBuilder.setDecimalVal( - com.google.protobuf.ByteString.copyFrom(unscaled.toByteArray)) - case _: BinaryType => - val byteStr = - com.google.protobuf.ByteString.copyFrom(value.asInstanceOf[Array[Byte]]) - exprBuilder.setBytesVal(byteStr) - case _: DateType => exprBuilder.setIntVal(value.asInstanceOf[Int]) - case dt => - logWarning(s"Unexpected date type '$dt' for literal value '$value'") + Some( + ExprOuterClass.Expr + .newBuilder() + .setRemainder(builder) + .build()) + } else { + None } - } - - val dt = serializeDataType(dataType) - - if (dt.isDefined) { - exprBuilder.setDatatype(dt.get) - - Some( - ExprOuterClass.Expr - .newBuilder() - .setLiteral(exprBuilder) - .build()) - } else { - None - } - - case Substring(str, Literal(pos, _), Literal(len, _)) => - val strExpr = exprToProtoInternal(str, inputs) - if (strExpr.isDefined) { - val builder = ExprOuterClass.Substring.newBuilder() - builder.setChild(strExpr.get) - builder.setStart(pos.asInstanceOf[Int]) - builder.setLen(len.asInstanceOf[Int]) + case EqualTo(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) - Some( - ExprOuterClass.Expr - .newBuilder() - .setSubstring(builder) - .build()) - } else { - None - } - - case Like(left, right, _) => - // TODO escapeChar - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.Equal.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.Like.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) - - Some( - ExprOuterClass.Expr - .newBuilder() - .setLike(builder) - .build()) - } else { - None - } + Some( + ExprOuterClass.Expr + .newBuilder() + .setEq(builder) + .build()) + } else { + None + } - // TODO waiting for arrow-rs update -// case RLike(left, right) => -// val leftExpr = exprToProtoInternal(left, inputs) -// val rightExpr = exprToProtoInternal(right, inputs) -// -// if (leftExpr.isDefined && rightExpr.isDefined) { -// val builder = ExprOuterClass.RLike.newBuilder() -// builder.setLeft(leftExpr.get) -// builder.setRight(rightExpr.get) -// -// Some( -// ExprOuterClass.Expr -// .newBuilder() -// .setRlike(builder) -// .build()) -// } else { -// None -// } - - case StartsWith(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) - - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.StartsWith.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) + case Not(EqualTo(left, right)) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) - Some( - ExprOuterClass.Expr - .newBuilder() - .setStartsWith(builder) - .build()) - } else { - None - } + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.NotEqual.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - case EndsWith(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + Some( + ExprOuterClass.Expr + .newBuilder() + .setNeq(builder) + .build()) + } else { + None + } - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.EndsWith.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) + case EqualNullSafe(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) - Some( - ExprOuterClass.Expr - .newBuilder() - .setEndsWith(builder) - .build()) - } else { - None - } + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.EqualNullSafe.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - case Contains(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + Some( + ExprOuterClass.Expr + .newBuilder() + .setEqNullSafe(builder) + .build()) + } else { + None + } - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.Contains.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) + case Not(EqualNullSafe(left, right)) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) - Some( - ExprOuterClass.Expr - .newBuilder() - .setContains(builder) - .build()) - } else { - None - } + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.NotEqualNullSafe.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - case StringSpace(child) => - val childExpr = exprToProtoInternal(child, inputs) + Some( + ExprOuterClass.Expr + .newBuilder() + .setNeqNullSafe(builder) + .build()) + } else { + None + } - if (childExpr.isDefined) { - val builder = ExprOuterClass.StringSpace.newBuilder() - builder.setChild(childExpr.get) + case GreaterThan(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) - Some( - ExprOuterClass.Expr - .newBuilder() - .setStringSpace(builder) - .build()) - } else { - None - } + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.GreaterThan.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - case Hour(child, timeZoneId) => - val childExpr = exprToProtoInternal(child, inputs) + Some( + ExprOuterClass.Expr + .newBuilder() + .setGt(builder) + .build()) + } else { + None + } - if (childExpr.isDefined) { - val builder = ExprOuterClass.Hour.newBuilder() - builder.setChild(childExpr.get) + case GreaterThanOrEqual(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) - val timeZone = timeZoneId.getOrElse("UTC") - builder.setTimezone(timeZone) + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.GreaterThanEqual.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setHour(builder) - .build()) - } else { - None - } + Some( + ExprOuterClass.Expr + .newBuilder() + .setGtEq(builder) + .build()) + } else { + None + } - case Minute(child, timeZoneId) => - val childExpr = exprToProtoInternal(child, inputs) + case LessThan(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) - if (childExpr.isDefined) { - val builder = ExprOuterClass.Minute.newBuilder() - builder.setChild(childExpr.get) + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.LessThan.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - val timeZone = timeZoneId.getOrElse("UTC") - builder.setTimezone(timeZone) + Some( + ExprOuterClass.Expr + .newBuilder() + .setLt(builder) + .build()) + } else { + None + } - Some( - ExprOuterClass.Expr - .newBuilder() - .setMinute(builder) - .build()) - } else { - None - } + case LessThanOrEqual(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) - case TruncDate(child, format) => - val childExpr = exprToProtoInternal(child, inputs) - val formatExpr = exprToProtoInternal(format, inputs) + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.LessThanEqual.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - if (childExpr.isDefined && formatExpr.isDefined) { - val builder = ExprOuterClass.TruncDate.newBuilder() - builder.setChild(childExpr.get) - builder.setFormat(formatExpr.get) + Some( + ExprOuterClass.Expr + .newBuilder() + .setLtEq(builder) + .build()) + } else { + None + } - Some( - ExprOuterClass.Expr - .newBuilder() - .setTruncDate(builder) - .build()) - } else { - None - } + case Literal(value, dataType) if supportedDataType(dataType) => + val exprBuilder = ExprOuterClass.Literal.newBuilder() - case TruncTimestamp(format, child, timeZoneId) => - val childExpr = exprToProtoInternal(child, inputs) - val formatExpr = exprToProtoInternal(format, inputs) + if (value == null) { + exprBuilder.setIsNull(true) + } else { + exprBuilder.setIsNull(false) + dataType match { + case _: BooleanType => exprBuilder.setBoolVal(value.asInstanceOf[Boolean]) + case _: ByteType => exprBuilder.setByteVal(value.asInstanceOf[Byte]) + case _: ShortType => exprBuilder.setShortVal(value.asInstanceOf[Short]) + case _: IntegerType => exprBuilder.setIntVal(value.asInstanceOf[Int]) + case _: LongType => exprBuilder.setLongVal(value.asInstanceOf[Long]) + case _: FloatType => exprBuilder.setFloatVal(value.asInstanceOf[Float]) + case _: DoubleType => exprBuilder.setDoubleVal(value.asInstanceOf[Double]) + case _: StringType => + exprBuilder.setStringVal(value.asInstanceOf[UTF8String].toString) + case _: TimestampType => exprBuilder.setLongVal(value.asInstanceOf[Long]) + case _: DecimalType => + // Pass decimal literal as bytes. + val unscaled = value.asInstanceOf[Decimal].toBigDecimal.underlying.unscaledValue + exprBuilder.setDecimalVal( + com.google.protobuf.ByteString.copyFrom(unscaled.toByteArray)) + case _: BinaryType => + val byteStr = + com.google.protobuf.ByteString.copyFrom(value.asInstanceOf[Array[Byte]]) + exprBuilder.setBytesVal(byteStr) + case _: DateType => exprBuilder.setIntVal(value.asInstanceOf[Int]) + case dt => + logWarning(s"Unexpected date type '$dt' for literal value '$value'") + } + } - if (childExpr.isDefined && formatExpr.isDefined) { - val builder = ExprOuterClass.TruncTimestamp.newBuilder() - builder.setChild(childExpr.get) - builder.setFormat(formatExpr.get) + val dt = serializeDataType(dataType) - val timeZone = timeZoneId.getOrElse("UTC") - builder.setTimezone(timeZone) + if (dt.isDefined) { + exprBuilder.setDatatype(dt.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setTruncTimestamp(builder) - .build()) - } else { - None - } + Some( + ExprOuterClass.Expr + .newBuilder() + .setLiteral(exprBuilder) + .build()) + } else { + None + } - case Second(child, timeZoneId) => - val childExpr = exprToProtoInternal(child, inputs) + case Substring(str, Literal(pos, _), Literal(len, _)) => + val strExpr = exprToProtoInternal(str, inputs) - if (childExpr.isDefined) { - val builder = ExprOuterClass.Second.newBuilder() - builder.setChild(childExpr.get) + if (strExpr.isDefined) { + val builder = ExprOuterClass.Substring.newBuilder() + builder.setChild(strExpr.get) + builder.setStart(pos.asInstanceOf[Int]) + builder.setLen(len.asInstanceOf[Int]) - val timeZone = timeZoneId.getOrElse("UTC") - builder.setTimezone(timeZone) + Some( + ExprOuterClass.Expr + .newBuilder() + .setSubstring(builder) + .build()) + } else { + None + } - Some( - ExprOuterClass.Expr - .newBuilder() - .setSecond(builder) - .build()) - } else { - None - } + case Like(left, right, _) => + // TODO escapeChar + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) + + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.Like.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) + + Some( + ExprOuterClass.Expr + .newBuilder() + .setLike(builder) + .build()) + } else { + None + } - case Year(child) => - val periodType = exprToProtoInternal(Literal("year"), inputs) - val childExpr = exprToProtoInternal(child, inputs) - scalarExprToProto("datepart", Seq(periodType, childExpr): _*) - .map(e => { - Expr - .newBuilder() - .setCast( - ExprOuterClass.Cast - .newBuilder() - .setChild(e) - .setDatatype(serializeDataType(IntegerType).get) - .build()) - .build() - }) + // TODO waiting for arrow-rs update + // case RLike(left, right) => + // val leftExpr = exprToProtoInternal(left, inputs) + // val rightExpr = exprToProtoInternal(right, inputs) + // + // if (leftExpr.isDefined && rightExpr.isDefined) { + // val builder = ExprOuterClass.RLike.newBuilder() + // builder.setLeft(leftExpr.get) + // builder.setRight(rightExpr.get) + // + // Some( + // ExprOuterClass.Expr + // .newBuilder() + // .setRlike(builder) + // .build()) + // } else { + // None + // } + + case StartsWith(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) + + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.StartsWith.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) + + Some( + ExprOuterClass.Expr + .newBuilder() + .setStartsWith(builder) + .build()) + } else { + None + } - case IsNull(child) => - val childExpr = exprToProtoInternal(child, inputs) + case EndsWith(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) - if (childExpr.isDefined) { - val castBuilder = ExprOuterClass.IsNull.newBuilder() - castBuilder.setChild(childExpr.get) + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.EndsWith.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setIsNull(castBuilder) - .build()) - } else { - None - } + Some( + ExprOuterClass.Expr + .newBuilder() + .setEndsWith(builder) + .build()) + } else { + None + } - case IsNotNull(child) => - val childExpr = exprToProtoInternal(child, inputs) + case Contains(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) - if (childExpr.isDefined) { - val castBuilder = ExprOuterClass.IsNotNull.newBuilder() - castBuilder.setChild(childExpr.get) + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.Contains.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setIsNotNull(castBuilder) - .build()) - } else { - None - } + Some( + ExprOuterClass.Expr + .newBuilder() + .setContains(builder) + .build()) + } else { + None + } - case SortOrder(child, direction, nullOrdering, _) => - val childExpr = exprToProtoInternal(child, inputs) + case StringSpace(child) => + val childExpr = exprToProtoInternal(child, inputs) - if (childExpr.isDefined) { - val sortOrderBuilder = ExprOuterClass.SortOrder.newBuilder() - sortOrderBuilder.setChild(childExpr.get) + if (childExpr.isDefined) { + val builder = ExprOuterClass.StringSpace.newBuilder() + builder.setChild(childExpr.get) - direction match { - case Ascending => sortOrderBuilder.setDirectionValue(0) - case Descending => sortOrderBuilder.setDirectionValue(1) + Some( + ExprOuterClass.Expr + .newBuilder() + .setStringSpace(builder) + .build()) + } else { + None } - nullOrdering match { - case NullsFirst => sortOrderBuilder.setNullOrderingValue(0) - case NullsLast => sortOrderBuilder.setNullOrderingValue(1) + case Hour(child, timeZoneId) => + val childExpr = exprToProtoInternal(child, inputs) + + if (childExpr.isDefined) { + val builder = ExprOuterClass.Hour.newBuilder() + builder.setChild(childExpr.get) + + val timeZone = timeZoneId.getOrElse("UTC") + builder.setTimezone(timeZone) + + Some( + ExprOuterClass.Expr + .newBuilder() + .setHour(builder) + .build()) + } else { + None } - Some( - ExprOuterClass.Expr - .newBuilder() - .setSortOrder(sortOrderBuilder) - .build()) - } else { - None - } + case Minute(child, timeZoneId) => + val childExpr = exprToProtoInternal(child, inputs) - case And(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + if (childExpr.isDefined) { + val builder = ExprOuterClass.Minute.newBuilder() + builder.setChild(childExpr.get) - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.And.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) + val timeZone = timeZoneId.getOrElse("UTC") + builder.setTimezone(timeZone) - Some( - ExprOuterClass.Expr - .newBuilder() - .setAnd(builder) - .build()) - } else { - None - } + Some( + ExprOuterClass.Expr + .newBuilder() + .setMinute(builder) + .build()) + } else { + None + } - case Or(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + case TruncDate(child, format) => + val childExpr = exprToProtoInternal(child, inputs) + val formatExpr = exprToProtoInternal(format, inputs) - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.Or.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) + if (childExpr.isDefined && formatExpr.isDefined) { + val builder = ExprOuterClass.TruncDate.newBuilder() + builder.setChild(childExpr.get) + builder.setFormat(formatExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setOr(builder) - .build()) - } else { - None - } + Some( + ExprOuterClass.Expr + .newBuilder() + .setTruncDate(builder) + .build()) + } else { + None + } - case UnaryExpression(child) if expr.prettyName == "promote_precision" => - // `UnaryExpression` includes `PromotePrecision` for Spark 3.2 & 3.3 - // `PromotePrecision` is just a wrapper, don't need to serialize it. - exprToProtoInternal(child, inputs) + case TruncTimestamp(format, child, timeZoneId) => + val childExpr = exprToProtoInternal(child, inputs) + val formatExpr = exprToProtoInternal(format, inputs) - case CheckOverflow(child, dt, nullOnOverflow) => - val childExpr = exprToProtoInternal(child, inputs) + if (childExpr.isDefined && formatExpr.isDefined) { + val builder = ExprOuterClass.TruncTimestamp.newBuilder() + builder.setChild(childExpr.get) + builder.setFormat(formatExpr.get) - if (childExpr.isDefined) { - val builder = ExprOuterClass.CheckOverflow.newBuilder() - builder.setChild(childExpr.get) - builder.setFailOnError(!nullOnOverflow) + val timeZone = timeZoneId.getOrElse("UTC") + builder.setTimezone(timeZone) - // `dataType` must be decimal type - val dataType = serializeDataType(dt) - builder.setDatatype(dataType.get) + Some( + ExprOuterClass.Expr + .newBuilder() + .setTruncTimestamp(builder) + .build()) + } else { + None + } - Some( - ExprOuterClass.Expr - .newBuilder() - .setCheckOverflow(builder) - .build()) - } else { - None - } + case Second(child, timeZoneId) => + val childExpr = exprToProtoInternal(child, inputs) - case attr: AttributeReference => - val dataType = serializeDataType(attr.dataType) + if (childExpr.isDefined) { + val builder = ExprOuterClass.Second.newBuilder() + builder.setChild(childExpr.get) - if (dataType.isDefined) { - val boundRef = BindReferences - .bindReference(attr, inputs, allowFailures = false) - .asInstanceOf[BoundReference] - val boundExpr = ExprOuterClass.BoundReference - .newBuilder() - .setIndex(boundRef.ordinal) - .setDatatype(dataType.get) - .build() + val timeZone = timeZoneId.getOrElse("UTC") + builder.setTimezone(timeZone) - Some( - ExprOuterClass.Expr - .newBuilder() - .setBound(boundExpr) - .build()) - } else { - None - } + Some( + ExprOuterClass.Expr + .newBuilder() + .setSecond(builder) + .build()) + } else { + None + } - case Abs(child, _) => - exprToProtoInternal(child, inputs).map(childExpr => { - val abs = - ExprOuterClass.Abs - .newBuilder() - .setChild(childExpr) - .build() - Expr.newBuilder().setAbs(abs).build() - }) - - case Acos(child) => - val childExpr = exprToProtoInternal(child, inputs) - scalarExprToProto("acos", childExpr) - - case Asin(child) => - val childExpr = exprToProtoInternal(child, inputs) - scalarExprToProto("asin", childExpr) - - case Atan(child) => - val childExpr = exprToProtoInternal(child, inputs) - scalarExprToProto("atan", childExpr) - - case Atan2(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) - scalarExprToProto("atan2", leftExpr, rightExpr) - - case e @ Ceil(child) => - val childExpr = exprToProtoInternal(child, inputs) - child.dataType match { - case t: DecimalType if t.scale == 0 => // zero scale is no-op - childExpr - case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252 + case Year(child) => + val periodType = exprToProtoInternal(Literal("year"), inputs) + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProto("datepart", Seq(periodType, childExpr): _*) + .map(e => { + Expr + .newBuilder() + .setCast( + ExprOuterClass.Cast + .newBuilder() + .setChild(e) + .setDatatype(serializeDataType(IntegerType).get) + .build()) + .build() + }) + + case IsNull(child) => + val childExpr = exprToProtoInternal(child, inputs) + + if (childExpr.isDefined) { + val castBuilder = ExprOuterClass.IsNull.newBuilder() + castBuilder.setChild(childExpr.get) + + Some( + ExprOuterClass.Expr + .newBuilder() + .setIsNull(castBuilder) + .build()) + } else { None - case _ => - scalarExprToProtoWithReturnType("ceil", e.dataType, childExpr) - } + } - case Cos(child) => - val childExpr = exprToProtoInternal(child, inputs) - scalarExprToProto("cos", childExpr) + case IsNotNull(child) => + val childExpr = exprToProtoInternal(child, inputs) - case Exp(child) => - val childExpr = exprToProtoInternal(child, inputs) - scalarExprToProto("exp", childExpr) + if (childExpr.isDefined) { + val castBuilder = ExprOuterClass.IsNotNull.newBuilder() + castBuilder.setChild(childExpr.get) - case e @ Floor(child) => - val childExpr = exprToProtoInternal(child, inputs) - child.dataType match { - case t: DecimalType if t.scale == 0 => // zero scale is no-op - childExpr - case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252 + Some( + ExprOuterClass.Expr + .newBuilder() + .setIsNotNull(castBuilder) + .build()) + } else { None - case _ => - scalarExprToProtoWithReturnType("floor", e.dataType, childExpr) - } + } - case Log(child) => - val childExpr = exprToProtoInternal(child, inputs) - scalarExprToProto("ln", childExpr) - - case Log10(child) => - val childExpr = exprToProtoInternal(child, inputs) - scalarExprToProto("log10", childExpr) - - case Log2(child) => - val childExpr = exprToProtoInternal(child, inputs) - scalarExprToProto("log2", childExpr) - - case Pow(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) - scalarExprToProto("pow", leftExpr, rightExpr) - - // round function for Spark 3.2 does not allow negative round target scale. In addition, - // it has different result precision/scale for decimals. Supporting only 3.3 and above. - case r: Round if !isSpark32 => - // _scale s a constant, copied from Spark's RoundBase because it is a protected val - val scaleV: Any = r.scale.eval(EmptyRow) - val _scale: Int = scaleV.asInstanceOf[Int] - - lazy val childExpr = exprToProtoInternal(r.child, inputs) - r.child.dataType match { - case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252 - None - case _ if scaleV == null => - exprToProtoInternal(Literal(null), inputs) - case _: ByteType | ShortType | IntegerType | LongType if _scale >= 0 => - childExpr // _scale(I.e. decimal place) >= 0 is a no-op for integer types in Spark - case _: FloatType | DoubleType => - // We cannot properly match with the Spark behavior for floating-point numbers. - // Spark uses BigDecimal for rounding float/double, and BigDecimal fist converts a - // double to string internally in order to create its own internal representation. - // The problem is BigDecimal uses java.lang.Double.toString() and it has complicated - // rounding algorithm. E.g. -5.81855622136895E8 is actually - // -581855622.13689494132995605468750. Note the 5th fractional digit is 4 instead of - // 5. Java(Scala)'s toString() rounds it up to -581855622.136895. This makes a - // difference when rounding at 5th digit, I.e. round(-5.81855622136895E8, 5) should be - // -5.818556221369E8, instead of -5.8185562213689E8. There is also an example that - // toString() does NOT round up. 6.1317116247283497E18 is 6131711624728349696. It can - // be rounded up to 6.13171162472835E18 that still represents the same double number. - // I.e. 6.13171162472835E18 == 6.1317116247283497E18. However, toString() does not. - // That results in round(6.1317116247283497E18, -5) == 6.1317116247282995E18 instead - // of 6.1317116247283999E18. - None - case _ => - // `scale` must be Int64 type in DataFusion - val scaleExpr = exprToProtoInternal(Literal(_scale.toLong, LongType), inputs) - scalarExprToProtoWithReturnType("round", r.dataType, childExpr, scaleExpr) - } + case SortOrder(child, direction, nullOrdering, _) => + val childExpr = exprToProtoInternal(child, inputs) - case Signum(child) => - val childExpr = exprToProtoInternal(child, inputs) - scalarExprToProto("signum", childExpr) - - case Sin(child) => - val childExpr = exprToProtoInternal(child, inputs) - scalarExprToProto("sin", childExpr) - - case Sqrt(child) => - val childExpr = exprToProtoInternal(child, inputs) - scalarExprToProto("sqrt", childExpr) - - case Tan(child) => - val childExpr = exprToProtoInternal(child, inputs) - scalarExprToProto("tan", childExpr) - - case Ascii(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) - scalarExprToProto("ascii", childExpr) - - case BitLength(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) - scalarExprToProto("bit_length", childExpr) - - case If(predicate, trueValue, falseValue) => - val predicateExpr = exprToProtoInternal(predicate, inputs) - val trueExpr = exprToProtoInternal(trueValue, inputs) - val falseExpr = exprToProtoInternal(falseValue, inputs) - if (predicateExpr.isDefined && trueExpr.isDefined && falseExpr.isDefined) { - val builder = ExprOuterClass.IfExpr.newBuilder() - builder.setIfExpr(predicateExpr.get) - builder.setTrueExpr(trueExpr.get) - builder.setFalseExpr(falseExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setIf(builder) - .build()) - } else { - None - } + if (childExpr.isDefined) { + val sortOrderBuilder = ExprOuterClass.SortOrder.newBuilder() + sortOrderBuilder.setChild(childExpr.get) - case CaseWhen(branches, elseValue) => - val whenSeq = branches.map(elements => exprToProtoInternal(elements._1, inputs)) - val thenSeq = branches.map(elements => exprToProtoInternal(elements._2, inputs)) - assert(whenSeq.length == thenSeq.length) - if (whenSeq.forall(_.isDefined) && thenSeq.forall(_.isDefined)) { - val builder = ExprOuterClass.CaseWhen.newBuilder() - builder.addAllWhen(whenSeq.map(_.get).asJava) - builder.addAllThen(thenSeq.map(_.get).asJava) - if (elseValue.isDefined) { - val elseValueExpr = exprToProtoInternal(elseValue.get, inputs) - if (elseValueExpr.isDefined) { - builder.setElseExpr(elseValueExpr.get) - } else { - return None + direction match { + case Ascending => sortOrderBuilder.setDirectionValue(0) + case Descending => sortOrderBuilder.setDirectionValue(1) + } + + nullOrdering match { + case NullsFirst => sortOrderBuilder.setNullOrderingValue(0) + case NullsLast => sortOrderBuilder.setNullOrderingValue(1) } + + Some( + ExprOuterClass.Expr + .newBuilder() + .setSortOrder(sortOrderBuilder) + .build()) + } else { + None } - Some( - ExprOuterClass.Expr - .newBuilder() - .setCaseWhen(builder) - .build()) - } else { - None - } - case ConcatWs(children) => - val exprs = children.map(e => exprToProtoInternal(Cast(e, StringType), inputs)) - scalarExprToProto("concat_ws", exprs: _*) + case And(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) - case Chr(child) => - val childExpr = exprToProtoInternal(child, inputs) - scalarExprToProto("chr", childExpr) + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.And.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) + + Some( + ExprOuterClass.Expr + .newBuilder() + .setAnd(builder) + .build()) + } else { + None + } - case InitCap(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) - scalarExprToProto("initcap", childExpr) + case Or(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) - case Length(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) - scalarExprToProto("length", childExpr) + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.Or.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - case Lower(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) - scalarExprToProto("lower", childExpr) + Some( + ExprOuterClass.Expr + .newBuilder() + .setOr(builder) + .build()) + } else { + None + } - case Md5(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) - scalarExprToProto("md5", childExpr) + case UnaryExpression(child) if expr.prettyName == "promote_precision" => + // `UnaryExpression` includes `PromotePrecision` for Spark 3.2 & 3.3 + // `PromotePrecision` is just a wrapper, don't need to serialize it. + exprToProtoInternal(child, inputs) - case OctetLength(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) - scalarExprToProto("octet_length", childExpr) + case CheckOverflow(child, dt, nullOnOverflow) => + val childExpr = exprToProtoInternal(child, inputs) - case Reverse(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) - scalarExprToProto("reverse", childExpr) + if (childExpr.isDefined) { + val builder = ExprOuterClass.CheckOverflow.newBuilder() + builder.setChild(childExpr.get) + builder.setFailOnError(!nullOnOverflow) - case StringInstr(str, substr) => - val leftExpr = exprToProtoInternal(Cast(str, StringType), inputs) - val rightExpr = exprToProtoInternal(Cast(substr, StringType), inputs) - scalarExprToProto("strpos", leftExpr, rightExpr) + // `dataType` must be decimal type + val dataType = serializeDataType(dt) + builder.setDatatype(dataType.get) - case StringRepeat(str, times) => - val leftExpr = exprToProtoInternal(Cast(str, StringType), inputs) - val rightExpr = exprToProtoInternal(Cast(times, LongType), inputs) - scalarExprToProto("repeat", leftExpr, rightExpr) + Some( + ExprOuterClass.Expr + .newBuilder() + .setCheckOverflow(builder) + .build()) + } else { + None + } - case StringReplace(src, search, replace) => - val srcExpr = exprToProtoInternal(Cast(src, StringType), inputs) - val searchExpr = exprToProtoInternal(Cast(search, StringType), inputs) - val replaceExpr = exprToProtoInternal(Cast(replace, StringType), inputs) - scalarExprToProto("replace", srcExpr, searchExpr, replaceExpr) + case attr: AttributeReference => + val dataType = serializeDataType(attr.dataType) + + if (dataType.isDefined) { + if (binding) { + val boundRef = BindReferences + .bindReference(attr, inputs, allowFailures = false) + .asInstanceOf[BoundReference] + val boundExpr = ExprOuterClass.BoundReference + .newBuilder() + .setIndex(boundRef.ordinal) + .setDatatype(dataType.get) + .build() + + Some( + ExprOuterClass.Expr + .newBuilder() + .setBound(boundExpr) + .build()) + } else { + val unboundRef = ExprOuterClass.UnboundReference + .newBuilder() + .setName(attr.name) + .setDatatype(dataType.get) + .build() + + Some( + ExprOuterClass.Expr + .newBuilder() + .setUnbound(unboundRef) + .build()) + } + } else { + None + } - case StringTranslate(src, matching, replace) => - val srcExpr = exprToProtoInternal(Cast(src, StringType), inputs) - val matchingExpr = exprToProtoInternal(Cast(matching, StringType), inputs) - val replaceExpr = exprToProtoInternal(Cast(replace, StringType), inputs) - scalarExprToProto("translate", srcExpr, matchingExpr, replaceExpr) + case Abs(child, _) => + exprToProtoInternal(child, inputs).map(childExpr => { + val abs = + ExprOuterClass.Abs + .newBuilder() + .setChild(childExpr) + .build() + Expr.newBuilder().setAbs(abs).build() + }) - case StringTrim(srcStr, trimStr) => - trim(srcStr, trimStr, inputs, "trim") + case Acos(child) => + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProto("acos", childExpr) + + case Asin(child) => + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProto("asin", childExpr) + + case Atan(child) => + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProto("atan", childExpr) + + case Atan2(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) + scalarExprToProto("atan2", leftExpr, rightExpr) + + case e @ Ceil(child) => + val childExpr = exprToProtoInternal(child, inputs) + child.dataType match { + case t: DecimalType if t.scale == 0 => // zero scale is no-op + childExpr + case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252 + None + case _ => + scalarExprToProtoWithReturnType("ceil", e.dataType, childExpr) + } - case StringTrimLeft(srcStr, trimStr) => - trim(srcStr, trimStr, inputs, "ltrim") + case Cos(child) => + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProto("cos", childExpr) + + case Exp(child) => + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProto("exp", childExpr) + + case e @ Floor(child) => + val childExpr = exprToProtoInternal(child, inputs) + child.dataType match { + case t: DecimalType if t.scale == 0 => // zero scale is no-op + childExpr + case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252 + None + case _ => + scalarExprToProtoWithReturnType("floor", e.dataType, childExpr) + } - case StringTrimRight(srcStr, trimStr) => - trim(srcStr, trimStr, inputs, "rtrim") + case Log(child) => + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProto("ln", childExpr) + + case Log10(child) => + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProto("log10", childExpr) + + case Log2(child) => + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProto("log2", childExpr) + + case Pow(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) + scalarExprToProto("pow", leftExpr, rightExpr) + + // round function for Spark 3.2 does not allow negative round target scale. In addition, + // it has different result precision/scale for decimals. Supporting only 3.3 and above. + case r: Round if !isSpark32 => + // _scale s a constant, copied from Spark's RoundBase because it is a protected val + val scaleV: Any = r.scale.eval(EmptyRow) + val _scale: Int = scaleV.asInstanceOf[Int] + + lazy val childExpr = exprToProtoInternal(r.child, inputs) + r.child.dataType match { + case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252 + None + case _ if scaleV == null => + exprToProtoInternal(Literal(null), inputs) + case _: ByteType | ShortType | IntegerType | LongType if _scale >= 0 => + childExpr // _scale(I.e. decimal place) >= 0 is a no-op for integer types in Spark + case _: FloatType | DoubleType => + // We cannot properly match with the Spark behavior for floating-point numbers. + // Spark uses BigDecimal for rounding float/double, and BigDecimal fist converts a + // double to string internally in order to create its own internal representation. + // The problem is BigDecimal uses java.lang.Double.toString() and it has complicated + // rounding algorithm. E.g. -5.81855622136895E8 is actually + // -581855622.13689494132995605468750. Note the 5th fractional digit is 4 instead of + // 5. Java(Scala)'s toString() rounds it up to -581855622.136895. This makes a + // difference when rounding at 5th digit, I.e. round(-5.81855622136895E8, 5) should be + // -5.818556221369E8, instead of -5.8185562213689E8. There is also an example that + // toString() does NOT round up. 6.1317116247283497E18 is 6131711624728349696. It can + // be rounded up to 6.13171162472835E18 that still represents the same double number. + // I.e. 6.13171162472835E18 == 6.1317116247283497E18. However, toString() does not. + // That results in round(6.1317116247283497E18, -5) == 6.1317116247282995E18 instead + // of 6.1317116247283999E18. + None + case _ => + // `scale` must be Int64 type in DataFusion + val scaleExpr = exprToProtoInternal(Literal(_scale.toLong, LongType), inputs) + scalarExprToProtoWithReturnType("round", r.dataType, childExpr, scaleExpr) + } - case StringTrimBoth(srcStr, trimStr, _) => - trim(srcStr, trimStr, inputs, "btrim") + case Signum(child) => + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProto("signum", childExpr) + + case Sin(child) => + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProto("sin", childExpr) + + case Sqrt(child) => + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProto("sqrt", childExpr) + + case Tan(child) => + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProto("tan", childExpr) + + case Ascii(child) => + val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) + scalarExprToProto("ascii", childExpr) + + case BitLength(child) => + val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) + scalarExprToProto("bit_length", childExpr) + + case If(predicate, trueValue, falseValue) => + val predicateExpr = exprToProtoInternal(predicate, inputs) + val trueExpr = exprToProtoInternal(trueValue, inputs) + val falseExpr = exprToProtoInternal(falseValue, inputs) + if (predicateExpr.isDefined && trueExpr.isDefined && falseExpr.isDefined) { + val builder = ExprOuterClass.IfExpr.newBuilder() + builder.setIfExpr(predicateExpr.get) + builder.setTrueExpr(trueExpr.get) + builder.setFalseExpr(falseExpr.get) + Some( + ExprOuterClass.Expr + .newBuilder() + .setIf(builder) + .build()) + } else { + None + } - case Upper(child) => - val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) - scalarExprToProto("upper", childExpr) + case CaseWhen(branches, elseValue) => + val whenSeq = branches.map(elements => exprToProtoInternal(elements._1, inputs)) + val thenSeq = branches.map(elements => exprToProtoInternal(elements._2, inputs)) + assert(whenSeq.length == thenSeq.length) + if (whenSeq.forall(_.isDefined) && thenSeq.forall(_.isDefined)) { + val builder = ExprOuterClass.CaseWhen.newBuilder() + builder.addAllWhen(whenSeq.map(_.get).asJava) + builder.addAllThen(thenSeq.map(_.get).asJava) + if (elseValue.isDefined) { + val elseValueExpr = exprToProtoInternal(elseValue.get, inputs) + if (elseValueExpr.isDefined) { + builder.setElseExpr(elseValueExpr.get) + } else { + return None + } + } + Some( + ExprOuterClass.Expr + .newBuilder() + .setCaseWhen(builder) + .build()) + } else { + None + } - case BitwiseAnd(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + case ConcatWs(children) => + val exprs = children.map(e => exprToProtoInternal(Cast(e, StringType), inputs)) + scalarExprToProto("concat_ws", exprs: _*) + + case Chr(child) => + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProto("chr", childExpr) + + case InitCap(child) => + val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) + scalarExprToProto("initcap", childExpr) + + case Length(child) => + val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) + scalarExprToProto("length", childExpr) + + case Lower(child) => + val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) + scalarExprToProto("lower", childExpr) + + case Md5(child) => + val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) + scalarExprToProto("md5", childExpr) + + case OctetLength(child) => + val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) + scalarExprToProto("octet_length", childExpr) + + case Reverse(child) => + val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) + scalarExprToProto("reverse", childExpr) + + case StringInstr(str, substr) => + val leftExpr = exprToProtoInternal(Cast(str, StringType), inputs) + val rightExpr = exprToProtoInternal(Cast(substr, StringType), inputs) + scalarExprToProto("strpos", leftExpr, rightExpr) + + case StringRepeat(str, times) => + val leftExpr = exprToProtoInternal(Cast(str, StringType), inputs) + val rightExpr = exprToProtoInternal(Cast(times, LongType), inputs) + scalarExprToProto("repeat", leftExpr, rightExpr) + + case StringReplace(src, search, replace) => + val srcExpr = exprToProtoInternal(Cast(src, StringType), inputs) + val searchExpr = exprToProtoInternal(Cast(search, StringType), inputs) + val replaceExpr = exprToProtoInternal(Cast(replace, StringType), inputs) + scalarExprToProto("replace", srcExpr, searchExpr, replaceExpr) + + case StringTranslate(src, matching, replace) => + val srcExpr = exprToProtoInternal(Cast(src, StringType), inputs) + val matchingExpr = exprToProtoInternal(Cast(matching, StringType), inputs) + val replaceExpr = exprToProtoInternal(Cast(replace, StringType), inputs) + scalarExprToProto("translate", srcExpr, matchingExpr, replaceExpr) + + case StringTrim(srcStr, trimStr) => + trim(srcStr, trimStr, inputs, "trim") + + case StringTrimLeft(srcStr, trimStr) => + trim(srcStr, trimStr, inputs, "ltrim") + + case StringTrimRight(srcStr, trimStr) => + trim(srcStr, trimStr, inputs, "rtrim") + + case StringTrimBoth(srcStr, trimStr, _) => + trim(srcStr, trimStr, inputs, "btrim") + + case Upper(child) => + val childExpr = exprToProtoInternal(Cast(child, StringType), inputs) + scalarExprToProto("upper", childExpr) + + case BitwiseAnd(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) + + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.BitwiseAnd.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) + + Some( + ExprOuterClass.Expr + .newBuilder() + .setBitwiseAnd(builder) + .build()) + } else { + None + } - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.BitwiseAnd.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) + case BitwiseNot(child) => + val childExpr = exprToProtoInternal(child, inputs) - Some( - ExprOuterClass.Expr - .newBuilder() - .setBitwiseAnd(builder) - .build()) - } else { - None - } + if (childExpr.isDefined) { + val builder = ExprOuterClass.BitwiseNot.newBuilder() + builder.setChild(childExpr.get) - case BitwiseNot(child) => - val childExpr = exprToProtoInternal(child, inputs) + Some( + ExprOuterClass.Expr + .newBuilder() + .setBitwiseNot(builder) + .build()) + } else { + None + } - if (childExpr.isDefined) { - val builder = ExprOuterClass.BitwiseNot.newBuilder() - builder.setChild(childExpr.get) + case BitwiseOr(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) - Some( - ExprOuterClass.Expr - .newBuilder() - .setBitwiseNot(builder) - .build()) - } else { - None - } + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.BitwiseOr.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - case BitwiseOr(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + Some( + ExprOuterClass.Expr + .newBuilder() + .setBitwiseOr(builder) + .build()) + } else { + None + } - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.BitwiseOr.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) + case BitwiseXor(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = exprToProtoInternal(right, inputs) - Some( - ExprOuterClass.Expr - .newBuilder() - .setBitwiseOr(builder) - .build()) - } else { - None - } + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.BitwiseXor.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - case BitwiseXor(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + Some( + ExprOuterClass.Expr + .newBuilder() + .setBitwiseXor(builder) + .build()) + } else { + None + } - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.BitwiseXor.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) + case ShiftRight(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = if (left.dataType == LongType) { + // DataFusion bitwise shift right expression requires + // same data type between left and right side + exprToProtoInternal(Cast(right, LongType), inputs) + } else { + exprToProtoInternal(right, inputs) + } - Some( - ExprOuterClass.Expr - .newBuilder() - .setBitwiseXor(builder) - .build()) - } else { - None - } + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.BitwiseShiftRight.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - case ShiftRight(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = if (left.dataType == LongType) { - // DataFusion bitwise shift right expression requires - // same data type between left and right side - exprToProtoInternal(Cast(right, LongType), inputs) - } else { - exprToProtoInternal(right, inputs) - } + Some( + ExprOuterClass.Expr + .newBuilder() + .setBitwiseShiftRight(builder) + .build()) + } else { + None + } - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.BitwiseShiftRight.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) + case ShiftLeft(left, right) => + val leftExpr = exprToProtoInternal(left, inputs) + val rightExpr = if (left.dataType == LongType) { + // DataFusion bitwise shift left expression requires + // same data type between left and right side + exprToProtoInternal(Cast(right, LongType), inputs) + } else { + exprToProtoInternal(right, inputs) + } - Some( - ExprOuterClass.Expr - .newBuilder() - .setBitwiseShiftRight(builder) - .build()) - } else { - None - } + if (leftExpr.isDefined && rightExpr.isDefined) { + val builder = ExprOuterClass.BitwiseShiftLeft.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) - case ShiftLeft(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = if (left.dataType == LongType) { - // DataFusion bitwise shift left expression requires - // same data type between left and right side - exprToProtoInternal(Cast(right, LongType), inputs) - } else { - exprToProtoInternal(right, inputs) - } + Some( + ExprOuterClass.Expr + .newBuilder() + .setBitwiseShiftLeft(builder) + .build()) + } else { + None + } - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.BitwiseShiftLeft.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) + case In(value, list) => + in(value, list, inputs, false) + + case InSet(value, hset) => + val valueDataType = value.dataType + val list = hset.map { setVal => + Literal(setVal, valueDataType) + }.toSeq + // Change `InSet` to `In` expression + // We do Spark `InSet` optimization in native (DataFusion) side. + in(value, list, inputs, false) + + case Not(In(value, list)) => + in(value, list, inputs, true) + + case Not(child) => + val childExpr = exprToProtoInternal(child, inputs) + if (childExpr.isDefined) { + val builder = ExprOuterClass.Not.newBuilder() + builder.setChild(childExpr.get) + Some( + ExprOuterClass.Expr + .newBuilder() + .setNot(builder) + .build()) + } else { + None + } - Some( - ExprOuterClass.Expr - .newBuilder() - .setBitwiseShiftLeft(builder) - .build()) - } else { - None - } + case UnaryMinus(child, _) => + val childExpr = exprToProtoInternal(child, inputs) + if (childExpr.isDefined) { + val builder = ExprOuterClass.Negative.newBuilder() + builder.setChild(childExpr.get) + Some( + ExprOuterClass.Expr + .newBuilder() + .setNegative(builder) + .build()) + } else { + None + } - case In(value, list) => - in(value, list, inputs, false) - - case InSet(value, hset) => - val valueDataType = value.dataType - val list = hset.map { setVal => - Literal(setVal, valueDataType) - }.toSeq - // Change `InSet` to `In` expression - // We do Spark `InSet` optimization in native (DataFusion) side. - in(value, list, inputs, false) - - case Not(In(value, list)) => - in(value, list, inputs, true) - - case Not(child) => - val childExpr = exprToProtoInternal(child, inputs) - if (childExpr.isDefined) { - val builder = ExprOuterClass.Not.newBuilder() - builder.setChild(childExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setNot(builder) - .build()) - } else { - None - } + case a @ Coalesce(_) => + val exprChildren = a.children.map(exprToProtoInternal(_, inputs)) + scalarExprToProto("coalesce", exprChildren: _*) + + // With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for + // char types. Use rpad to achieve the behavior. + // See https://github.com/apache/spark/pull/38151 + case StaticInvoke( + clz: Class[_], + _: StringType, + "readSidePadding", + arguments, + _, + true, + false, + true) if clz == classOf[CharVarcharCodegenUtils] && arguments.size == 2 => + val argsExpr = Seq( + exprToProtoInternal(Cast(arguments(0), StringType), inputs), + exprToProtoInternal(arguments(1), inputs)) + + if (argsExpr.forall(_.isDefined)) { + val builder = ExprOuterClass.ScalarFunc.newBuilder() + builder.setFunc("rpad") + argsExpr.foreach(arg => builder.addArgs(arg.get)) + + Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build()) + } else { + None + } - case UnaryMinus(child, _) => - val childExpr = exprToProtoInternal(child, inputs) - if (childExpr.isDefined) { - val builder = ExprOuterClass.Negative.newBuilder() - builder.setChild(childExpr.get) - Some( - ExprOuterClass.Expr + case KnownFloatingPointNormalized(NormalizeNaNAndZero(expr)) => + val dataType = serializeDataType(expr.dataType) + if (dataType.isEmpty) { + return None + } + exprToProtoInternal(expr, inputs).map { child => + val builder = ExprOuterClass.NormalizeNaNAndZero .newBuilder() - .setNegative(builder) - .build()) - } else { - None - } + .setChild(child) + .setDatatype(dataType.get) + ExprOuterClass.Expr.newBuilder().setNormalizeNanAndZero(builder).build() + } - case a @ Coalesce(_) => - val exprChildren = a.children.map(exprToProtoInternal(_, inputs)) - scalarExprToProto("coalesce", exprChildren: _*) - - // With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for char - // types. Use rpad to achieve the behavior. See https://github.com/apache/spark/pull/38151 - case StaticInvoke( - clz: Class[_], - _: StringType, - "readSidePadding", - arguments, - _, - true, - false, - true) if clz == classOf[CharVarcharCodegenUtils] && arguments.size == 2 => - val argsExpr = Seq( - exprToProtoInternal(Cast(arguments(0), StringType), inputs), - exprToProtoInternal(arguments(1), inputs)) - - if (argsExpr.forall(_.isDefined)) { - val builder = ExprOuterClass.ScalarFunc.newBuilder() - builder.setFunc("rpad") - argsExpr.foreach(arg => builder.addArgs(arg.get)) - - Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build()) - } else { - None - } + case s @ execution.ScalarSubquery(_, _) => + val dataType = serializeDataType(s.dataType) + if (dataType.isEmpty) { + return None + } - case KnownFloatingPointNormalized(NormalizeNaNAndZero(expr)) => - val dataType = serializeDataType(expr.dataType) - if (dataType.isEmpty) { - return None - } - exprToProtoInternal(expr, inputs).map { child => - val builder = ExprOuterClass.NormalizeNaNAndZero + val builder = ExprOuterClass.Subquery .newBuilder() - .setChild(child) + .setId(s.exprId.id) .setDatatype(dataType.get) - ExprOuterClass.Expr.newBuilder().setNormalizeNanAndZero(builder).build() - } - - case s @ execution.ScalarSubquery(_, _) => - val dataType = serializeDataType(s.dataType) - if (dataType.isEmpty) { - return None - } + Some(ExprOuterClass.Expr.newBuilder().setSubquery(builder).build()) - val builder = ExprOuterClass.Subquery - .newBuilder() - .setId(s.exprId.id) - .setDatatype(dataType.get) - Some(ExprOuterClass.Expr.newBuilder().setSubquery(builder).build()) + case UnscaledValue(child) => + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProtoWithReturnType("unscaled_value", LongType, childExpr) - case UnscaledValue(child) => - val childExpr = exprToProtoInternal(child, inputs) - scalarExprToProtoWithReturnType("unscaled_value", LongType, childExpr) + case MakeDecimal(child, precision, scale, true) => + val childExpr = exprToProtoInternal(child, inputs) + scalarExprToProtoWithReturnType( + "make_decimal", + DecimalType(precision, scale), + childExpr) - case MakeDecimal(child, precision, scale, true) => - val childExpr = exprToProtoInternal(child, inputs) - scalarExprToProtoWithReturnType("make_decimal", DecimalType(precision, scale), childExpr) - - case e => - emitWarning(s"unsupported Spark expression: '$e' of class '${e.getClass.getName}") - None + case e => + emitWarning(s"unsupported Spark expression: '$e' of class '${e.getClass.getName}") + None + } } - } - private def trim( - srcStr: Expression, - trimStr: Option[Expression], - inputs: Seq[Attribute], - trimType: String): Option[Expr] = { - val srcExpr = exprToProtoInternal(Cast(srcStr, StringType), inputs) - if (trimStr.isDefined) { - val trimExpr = exprToProtoInternal(Cast(trimStr.get, StringType), inputs) - scalarExprToProto(trimType, srcExpr, trimExpr) - } else { - scalarExprToProto(trimType, srcExpr) + def trim( + srcStr: Expression, + trimStr: Option[Expression], + inputs: Seq[Attribute], + trimType: String): Option[Expr] = { + val srcExpr = exprToProtoInternal(Cast(srcStr, StringType), inputs) + if (trimStr.isDefined) { + val trimExpr = exprToProtoInternal(Cast(trimStr.get, StringType), inputs) + scalarExprToProto(trimType, srcExpr, trimExpr) + } else { + scalarExprToProto(trimType, srcExpr) + } } - } - private def in( - value: Expression, - list: Seq[Expression], - inputs: Seq[Attribute], - negate: Boolean): Option[Expr] = { - val valueExpr = exprToProtoInternal(value, inputs) - val listExprs = list.map(exprToProtoInternal(_, inputs)) - if (valueExpr.isDefined && listExprs.forall(_.isDefined)) { - val builder = ExprOuterClass.In.newBuilder() - builder.setInValue(valueExpr.get) - builder.addAllLists(listExprs.map(_.get).asJava) - builder.setNegated(negate) - Some( - ExprOuterClass.Expr - .newBuilder() - .setIn(builder) - .build()) - } else { - None + def in( + value: Expression, + list: Seq[Expression], + inputs: Seq[Attribute], + negate: Boolean): Option[Expr] = { + val valueExpr = exprToProtoInternal(value, inputs) + val listExprs = list.map(exprToProtoInternal(_, inputs)) + if (valueExpr.isDefined && listExprs.forall(_.isDefined)) { + val builder = ExprOuterClass.In.newBuilder() + builder.setInValue(valueExpr.get) + builder.addAllLists(listExprs.map(_.get).asJava) + builder.setNegated(negate) + Some( + ExprOuterClass.Expr + .newBuilder() + .setIn(builder) + .build()) + } else { + None + } } + + val conf = SQLConf.get + val newExpr = + DecimalPrecision.promote(conf.decimalOperationsAllowPrecisionLoss, expr, !conf.ansiEnabled) + exprToProtoInternal(newExpr, input) } def scalarExprToProtoWithReturnType( @@ -1781,7 +1818,17 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { case _ => return None } - val aggExprs = aggregateExpressions.map(aggExprToProto(_, output)) + val binding = if (mode == CometAggregateMode.Final) { + // In final mode, the aggregate expressions are bound to the output of the + // child and partial aggregate expressions buffer attributes produced by partial + // aggregation. This is done in Spark `HashAggregateExec` internally. In Comet, + // we don't have to do this because we don't use the merging expression. + false + } else { + true + } + + val aggExprs = aggregateExpressions.map(aggExprToProto(_, output, binding)) if (childOp.nonEmpty && groupingExprs.forall(_.isDefined) && aggExprs.forall(_.isDefined)) { val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder() diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index c2faa6544..3b4fb1c99 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -38,6 +38,25 @@ import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus * Test suite dedicated to Comet native aggregate operator */ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { + import testImplicits._ + + test("Final aggregation should not bind to the input of partial aggregation") { + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test") + makeParquetFile(path, 10000, 10, dictionaryEnabled) + withParquetTable(path.toUri.toString, "tbl") { + val df = sql("SELECT * FROM tbl").groupBy("_g1").agg(sum($"_3" + $"_g3")) + checkSparkAnswer(df) + } + } + } + } + } test("Ensure traversed operators during finding first partial aggregation are all native") { withTable("lineitem", "part") {