From 9eb8a887fc119a807a98f443f8bd6e895b201e25 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 12 Mar 2024 10:46:06 -0700 Subject: [PATCH] Support bitwise aggregate functions --- core/src/execution/datafusion/planner.rs | 17 +++++- core/src/execution/proto/expr.proto | 18 ++++++ .../apache/comet/serde/QueryPlanSerde.scala | 60 ++++++++++++++++++- .../comet/exec/CometAggregateSuite.scala | 34 +++++++++++ 4 files changed, 127 insertions(+), 2 deletions(-) diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index ef2787f83..ffc2c6fde 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -27,7 +27,7 @@ use datafusion::{ physical_expr::{ execution_props::ExecutionProps, expressions::{ - in_list, BinaryExpr, CaseExpr, CastExpr, Column, Count, FirstValue, InListExpr, + in_list, BinaryExpr, BitAnd, BitOr, BitXor, CaseExpr, CastExpr, Column, Count, FirstValue, InListExpr, IsNotNullExpr, IsNullExpr, LastValue, Literal as DataFusionLiteral, Max, Min, NegativeExpr, NotExpr, Sum, UnKnownColumn, }, @@ -940,6 +940,21 @@ impl PhysicalPlanner { vec![], ))) } + AggExprStruct::BitAndAgg(expr) => { + let child = self.create_expr(expr.child.as_ref().unwrap(), schema)?; + let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); + Ok(Arc::new(BitAnd::new(child, "bit_and", datatype))) + } + AggExprStruct::BitOrAgg(expr) => { + let child = self.create_expr(expr.child.as_ref().unwrap(), schema)?; + let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); + Ok(Arc::new(BitOr::new(child, "bit_or", datatype))) + } + AggExprStruct::BitXorAgg(expr) => { + let child = self.create_expr(expr.child.as_ref().unwrap(), schema)?; + let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); + Ok(Arc::new(BitXor::new(child, "bit_xor", datatype))) + } } } diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto index 8aa81b767..e8d35d19a 100644 --- a/core/src/execution/proto/expr.proto +++ b/core/src/execution/proto/expr.proto @@ -88,6 +88,9 @@ message AggExpr { Avg avg = 6; First first = 7; Last last = 8; + BitAndAgg bitAndAgg = 9; + BitOrAgg bitOrAgg = 10; + BitXorAgg bitXorAgg = 11; } } @@ -130,6 +133,21 @@ message Last { bool ignore_nulls = 3; } +message BitAndAgg { + Expr child = 1; + DataType datatype = 2; +} + +message BitOrAgg { + Expr child = 1; + DataType datatype = 2; +} + +message BitXorAgg { + Expr child = 1; + DataType datatype = 2; +} + message Literal { oneof value { bool bool_val = 1; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 08a499b06..463aa19a3 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, Count, Final, First, Last, Max, Min, Partial, Sum} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, Final, First, Last, Max, Min, Partial, Sum} import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition} @@ -186,6 +186,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } } + private def bitwiseAggTypeSupported(dt: DataType): Boolean = { + dt match { + case _: IntegerType => true + case _ => false + } + } + def aggExprToProto( aggExpr: AggregateExpression, inputs: Seq[Attribute], @@ -326,6 +333,57 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } else { None } + case bitAnd @ BitAndAgg(child) if bitwiseAggTypeSupported(bitAnd.dataType) => + val childExpr = exprToProto(child, inputs, binding) + val dataType = serializeDataType(bitAnd.dataType) + + if (childExpr.isDefined && dataType.isDefined) { + val bitAndBuilder = ExprOuterClass.BitAndAgg.newBuilder() + bitAndBuilder.setChild(childExpr.get) + bitAndBuilder.setDatatype(dataType.get) + + Some( + ExprOuterClass.AggExpr + .newBuilder() + .setBitAndAgg(bitAndBuilder) + .build()) + } else { + None + } + case bitOr @ BitOrAgg(child) if bitwiseAggTypeSupported(bitOr.dataType) => + val childExpr = exprToProto(child, inputs, binding) + val dataType = serializeDataType(bitOr.dataType) + + if (childExpr.isDefined && dataType.isDefined) { + val bitOrBuilder = ExprOuterClass.BitOrAgg.newBuilder() + bitOrBuilder.setChild(childExpr.get) + bitOrBuilder.setDatatype(dataType.get) + + Some( + ExprOuterClass.AggExpr + .newBuilder() + .setBitOrAgg(bitOrBuilder) + .build()) + } else { + None + } + case bitXor @ BitXorAgg(child) if bitwiseAggTypeSupported(bitXor.dataType) => + val childExpr = exprToProto(child, inputs, binding) + val dataType = serializeDataType(bitXor.dataType) + + if (childExpr.isDefined && dataType.isDefined) { + val bitXorBuilder = ExprOuterClass.BitXorAgg.newBuilder() + bitXorBuilder.setChild(childExpr.get) + bitXorBuilder.setDatatype(dataType.get) + + Some( + ExprOuterClass.AggExpr + .newBuilder() + .setBitXorAgg(bitXorBuilder) + .build()) + } else { + None + } case fn => emitWarning(s"unsupported Spark aggregate function: $fn") diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index 8a68a925e..2bd5556a5 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -947,6 +947,40 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("bitwise aggregate") { + withSQLConf( + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") { + Seq(true, false).foreach { dictionary => + withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { + val table = "test" + withTable(table) { + sql(s"create table $table(col1 int, col2 int, col3 int) using parquet") + sql( + s"insert into $table values(4, 1, 1), (4, 1, 1), (3, 3, 1)," + + " (2, 4, 2), (1, 3, 2), (null, 1, 1)") + val expectedNumOfCometAggregates = 2 + checkSparkAnswerAndNumOfAggregates( + "SELECT BIT_AND(col1), BIT_OR(col1), BIT_XOR(col1) FROM test", + expectedNumOfCometAggregates) + + checkSparkAnswerAndNumOfAggregates( + "SELECT BIT_AND(col1), BIT_OR(col1), BIT_XOR(col1), MIN(col1), COUNT(col1) FROM test", + expectedNumOfCometAggregates) + + checkSparkAnswerAndNumOfAggregates( + "SELECT BIT_AND(col1), BIT_OR(col1), BIT_XOR(col1), col3 FROM test GROUP BY col3", + expectedNumOfCometAggregates) + + checkSparkAnswerAndNumOfAggregates( + "SELECT BIT_AND(col1), BIT_OR(col1), BIT_XOR(col1), MIN(col1), COUNT(col1), col3 FROM test GROUP BY col3", + expectedNumOfCometAggregates) + } + } + } + } + } + protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = { val df = sql(query) checkSparkAnswer(df)