Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support bitwise aggregate functions #197

Merged
merged 3 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions core/src/execution/datafusion/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ use datafusion::{
physical_expr::{
execution_props::ExecutionProps,
expressions::{
in_list, BinaryExpr, CaseExpr, CastExpr, Column, Count, FirstValue, InListExpr,
IsNotNullExpr, IsNullExpr, LastValue, Literal as DataFusionLiteral, Max, Min,
NegativeExpr, NotExpr, Sum, UnKnownColumn,
in_list, BinaryExpr, BitAnd, BitOr, BitXor, CaseExpr, CastExpr, Column, Count,
FirstValue, InListExpr, IsNotNullExpr, IsNullExpr, LastValue,
Literal as DataFusionLiteral, Max, Min, NegativeExpr, NotExpr, Sum, UnKnownColumn,
},
functions::create_physical_expr,
AggregateExpr, PhysicalExpr, PhysicalSortExpr, ScalarFunctionExpr,
Expand Down Expand Up @@ -940,6 +940,21 @@ impl PhysicalPlanner {
vec![],
)))
}
AggExprStruct::BitAndAgg(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), schema)?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
Ok(Arc::new(BitAnd::new(child, "bit_and", datatype)))
}
AggExprStruct::BitOrAgg(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), schema)?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
Ok(Arc::new(BitOr::new(child, "bit_or", datatype)))
}
AggExprStruct::BitXorAgg(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), schema)?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
Ok(Arc::new(BitXor::new(child, "bit_xor", datatype)))
}
}
}

Expand Down
18 changes: 18 additions & 0 deletions core/src/execution/proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ message AggExpr {
Avg avg = 6;
First first = 7;
Last last = 8;
BitAndAgg bitAndAgg = 9;
BitOrAgg bitOrAgg = 10;
BitXorAgg bitXorAgg = 11;
}
}

Expand Down Expand Up @@ -130,6 +133,21 @@ message Last {
bool ignore_nulls = 3;
}

// Bitwise AND aggregate (Spark's BIT_AND). Spark only accepts INTEGRAL
// input types for this aggregate, so `datatype` is an integral type.
message BitAndAgg {
// Input expression to aggregate.
Expr child = 1;
// Result (and input) data type of the aggregate.
DataType datatype = 2;
}

// Bitwise OR aggregate (Spark's BIT_OR). Spark only accepts INTEGRAL
// input types for this aggregate, so `datatype` is an integral type.
message BitOrAgg {
// Input expression to aggregate.
Expr child = 1;
// Result (and input) data type of the aggregate.
DataType datatype = 2;
}

// Bitwise XOR aggregate (Spark's BIT_XOR). Spark only accepts INTEGRAL
// input types for this aggregate, so `datatype` is an integral type.
message BitXorAgg {
// Input expression to aggregate.
Expr child = 1;
// Result (and input) data type of the aggregate.
DataType datatype = 2;
}

message Literal {
oneof value {
bool bool_val = 1;
Expand Down
60 changes: 59 additions & 1 deletion spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, Count, Final, First, Last, Max, Min, Partial, Sum}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, Final, First, Last, Max, Min, Partial, Sum}
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition}
Expand Down Expand Up @@ -186,6 +186,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
}
}

/**
 * Returns true if a bitwise aggregate (BIT_AND / BIT_OR / BIT_XOR) over the
 * given data type can be delegated to the native engine.
 *
 * Spark itself restricts these aggregates to INTEGRAL inputs (byte, short,
 * int, long) — other types fail analysis with a data-type-mismatch error —
 * so only those four types are accepted here.
 */
private def bitwiseAggTypeSupported(dt: DataType): Boolean = {
  dt match {
    // All four numeric integral types are singleton objects, so match them
    // uniformly as identifier patterns. The previous form mixed a typed
    // pattern (`_: IntegerType`) with identifier alternatives, which is
    // inconsistent and easy to misread as a single type ascription.
    case ByteType | ShortType | IntegerType | LongType => true
    case _ => false
  }
}
Comment on lines +189 to +194
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the limitation due to DataFusion?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If Spark bit aggregate function supports more types than DataFusion, can you add a comment here and open a ticket at DataFusion?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Spark bitwise aggregates only support the INTEGRAL type. If I try other types, Spark throws an exception:

Cannot resolve "bit_and(col1)" due to data type mismatch: Parameter 1 requires the "INTEGRAL" type, however "col1" has the type ...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see.


def aggExprToProto(
aggExpr: AggregateExpression,
inputs: Seq[Attribute],
Expand Down Expand Up @@ -326,6 +333,57 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
} else {
None
}
case bitAnd @ BitAndAgg(child) if bitwiseAggTypeSupported(bitAnd.dataType) =>
val childExpr = exprToProto(child, inputs, binding)
val dataType = serializeDataType(bitAnd.dataType)

if (childExpr.isDefined && dataType.isDefined) {
val bitAndBuilder = ExprOuterClass.BitAndAgg.newBuilder()
bitAndBuilder.setChild(childExpr.get)
bitAndBuilder.setDatatype(dataType.get)

Some(
ExprOuterClass.AggExpr
.newBuilder()
.setBitAndAgg(bitAndBuilder)
.build())
} else {
None
}
case bitOr @ BitOrAgg(child) if bitwiseAggTypeSupported(bitOr.dataType) =>
val childExpr = exprToProto(child, inputs, binding)
val dataType = serializeDataType(bitOr.dataType)

if (childExpr.isDefined && dataType.isDefined) {
val bitOrBuilder = ExprOuterClass.BitOrAgg.newBuilder()
bitOrBuilder.setChild(childExpr.get)
bitOrBuilder.setDatatype(dataType.get)

Some(
ExprOuterClass.AggExpr
.newBuilder()
.setBitOrAgg(bitOrBuilder)
.build())
} else {
None
}
case bitXor @ BitXorAgg(child) if bitwiseAggTypeSupported(bitXor.dataType) =>
val childExpr = exprToProto(child, inputs, binding)
val dataType = serializeDataType(bitXor.dataType)

if (childExpr.isDefined && dataType.isDefined) {
val bitXorBuilder = ExprOuterClass.BitXorAgg.newBuilder()
bitXorBuilder.setChild(childExpr.get)
bitXorBuilder.setDatatype(dataType.get)

Some(
ExprOuterClass.AggExpr
.newBuilder()
.setBitXorAgg(bitXorBuilder)
.build())
} else {
None
}

case fn =>
emitWarning(s"unsupported Spark aggregate function: $fn")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -947,6 +947,56 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}

test("bitwise aggregate") {
  // Enable both Comet shuffle modes so the whole plan (partial + final
  // aggregation) runs natively; we expect exactly two Comet aggregates.
  withSQLConf(
    CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
    CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
    // Exercise both dictionary-encoded and plain Parquet encodings.
    for (dictionaryEnabled <- Seq(true, false)) {
      withSQLConf("parquet.enable.dictionary" -> dictionaryEnabled.toString) {
        val tableName = "test"
        withTable(tableName) {
          // One column per supported integral type; col1 includes a NULL.
          sql(
            s"create table $tableName(col1 long, col2 int, col3 short, col4 byte) using parquet")
          sql(
            s"insert into $tableName values(4, 1, 1, 3), (4, 1, 1, 3), (3, 3, 1, 4)," +
              " (2, 4, 2, 5), (1, 3, 2, 6), (null, 1, 1, 7)")
          val expectedNumOfCometAggregates = 2

          // Bitwise aggregates alone, over every supported column type.
          checkSparkAnswerAndNumOfAggregates(
            "SELECT BIT_AND(col1), BIT_OR(col1), BIT_XOR(col1)," +
              " BIT_AND(col2), BIT_OR(col2), BIT_XOR(col2)," +
              " BIT_AND(col3), BIT_OR(col3), BIT_XOR(col3)," +
              " BIT_AND(col4), BIT_OR(col4), BIT_XOR(col4) FROM test",
            expectedNumOfCometAggregates)

          // Make sure the combination of BITWISE aggregates and other aggregates work OK
          checkSparkAnswerAndNumOfAggregates(
            "SELECT BIT_AND(col1), BIT_OR(col1), BIT_XOR(col1)," +
              " BIT_AND(col2), BIT_OR(col2), BIT_XOR(col2)," +
              " BIT_AND(col3), BIT_OR(col3), BIT_XOR(col3)," +
              " BIT_AND(col4), BIT_OR(col4), BIT_XOR(col4), MIN(col1), COUNT(col1) FROM test",
            expectedNumOfCometAggregates)

          // Bitwise aggregates with a GROUP BY clause.
          checkSparkAnswerAndNumOfAggregates(
            "SELECT BIT_AND(col1), BIT_OR(col1), BIT_XOR(col1)," +
              " BIT_AND(col2), BIT_OR(col2), BIT_XOR(col2)," +
              " BIT_AND(col3), BIT_OR(col3), BIT_XOR(col3)," +
              " BIT_AND(col4), BIT_OR(col4), BIT_XOR(col4), col3 FROM test GROUP BY col3",
            expectedNumOfCometAggregates)

          // Make sure the combination of BITWISE aggregates and other aggregates work OK
          // with group by
          checkSparkAnswerAndNumOfAggregates(
            "SELECT BIT_AND(col1), BIT_OR(col1), BIT_XOR(col1)," +
              " BIT_AND(col2), BIT_OR(col2), BIT_XOR(col2)," +
              " BIT_AND(col3), BIT_OR(col3), BIT_XOR(col3)," +
              " BIT_AND(col4), BIT_OR(col4), BIT_XOR(col4)," +
              " MIN(col1), COUNT(col1), col3 FROM test GROUP BY col3",
            expectedNumOfCometAggregates)
        }
      }
    }
  }
}

protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = {
val df = sql(query)
checkSparkAnswer(df)
Expand Down
Loading