feat: Support Count(Distinct) and similar aggregation functions (#42)
Co-authored-by: Huaxin Gao <[email protected]>
huaxingao and Huaxin Gao authored Feb 20, 2024
1 parent f7b88e9 commit 2820327
Showing 4 changed files with 153 additions and 47 deletions.
@@ -292,14 +292,18 @@ class CometSparkSessionExtensions
newOp match {
case Some(nativeOp) =>
val modes = aggExprs.map(_.mode).distinct
assert(modes.length == 1)
// The aggExprs can be empty. For example, if the aggregation has only
// distinct aggregate functions or only a group by, aggExprs is empty and
// modes is empty too. If aggExprs is non-empty, we need to verify that all
// the aggregates have the same mode.
assert(modes.length == 1 || modes.length == 0)
CometHashAggregateExec(
nativeOp,
op,
groupingExprs,
aggExprs,
child.output,
modes.head,
if (modes.nonEmpty) Some(modes.head) else None,
child)
case None =>
op
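The relaxed assertion above tolerates an empty `modes`. A minimal sketch of the Option-based mode extraction this enables, using the same Spark types (the helper name `extractMode` is illustrative, not part of the commit):

import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode}

// Distill the mode of an aggregation stage; None marks a grouping-only
// stage, as produced by Spark's rewrite of COUNT(DISTINCT ...).
def extractMode(aggExprs: Seq[AggregateExpression]): Option[AggregateMode] = {
  val modes = aggExprs.map(_.mode).distinct
  assert(modes.length <= 1) // non-empty stages must agree on a single mode
  modes.headOption // same as: if (modes.nonEmpty) Some(modes.head) else None
}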
123 changes: 80 additions & 43 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition}
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
import org.apache.spark.sql.comet.{CometHashAggregateExec, CometSinkPlaceHolder, DecimalPrecision}
import org.apache.spark.sql.comet.{CometHashAggregateExec, CometPlan, CometSinkPlaceHolder, DecimalPrecision}
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
@@ -1653,60 +1653,97 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
_,
groupingExpressions,
aggregateExpressions,
_,
aggregateAttributes,
_,
resultExpressions,
child) if isCometOperatorEnabled(op.conf, "aggregate") =>
val modes = aggregateExpressions.map(_.mode).distinct

if (modes.size != 1) {
// This shouldn't happen, as all aggregation expressions should share the same mode.
// Fall back to Spark nevertheless.
if (groupingExpressions.isEmpty && aggregateExpressions.isEmpty) {
return None
}

val mode = modes.head match {
case Partial => CometAggregateMode.Partial
case Final => CometAggregateMode.Final
case _ => return None
}

val output = mode match {
case CometAggregateMode.Partial => child.output
case CometAggregateMode.Final =>
// Assuming `Final` always follows `Partial` aggregation, this finds the first
// `Partial` aggregation and gets the input attributes from it.
child.collectFirst { case CometHashAggregateExec(_, _, _, _, input, Partial, _) =>
input
} match {
case Some(input) => input
case _ => return None
}
case _ => return None
}

val aggExprs = aggregateExpressions.map(aggExprToProto(_, output))
val groupingExprs = groupingExpressions.map(exprToProto(_, child.output))

if (childOp.nonEmpty && groupingExprs.forall(_.isDefined) &&
aggExprs.forall(_.isDefined)) {
// In some cases, aggregateExpressions can be empty, for example when the
// query has only a group by, or when the only aggregate functions are
// distinct aggregates:
//
// SELECT COUNT(distinct col2), col1 FROM test GROUP BY col1
// +- HashAggregate(keys=[col1#6], functions=[count(distinct col2#7)])
//    +- Exchange hashpartitioning(col1#6, 10), ENSURE_REQUIREMENTS, [plan_id=36]
//       +- HashAggregate(keys=[col1#6], functions=[partial_count(distinct col2#7)])
//          +- HashAggregate(keys=[col1#6, col2#7], functions=[])
//             +- Exchange hashpartitioning(col1#6, col2#7, 10), ENSURE_REQUIREMENTS, ...
//                +- HashAggregate(keys=[col1#6, col2#7], functions=[])
//                   +- FileScan parquet spark_catalog.default.test[col1#6, col2#7] ...
// If aggregateExpressions is empty, we only build the groupingExpressions
// and skip processing of aggregateExpressions.
if (aggregateExpressions.isEmpty) {
val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder()
hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava)
hashAggBuilder.addAllAggExprs(aggExprs.map(_.get).asJava)
if (mode == CometAggregateMode.Final) {
val attributes = groupingExpressions.map(_.toAttribute) ++
aggregateExpressions.map(_.resultAttribute)
val resultExprs = resultExpressions.map(exprToProto(_, attributes))
if (resultExprs.exists(_.isEmpty)) {
emitWarning(s"Unsupported result expressions found in: ${resultExpressions}")
return None
}
hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
}
hashAggBuilder.setModeValue(mode.getNumber)
Some(result.setHashAgg(hashAggBuilder).build())
} else {
None
val modes = aggregateExpressions.map(_.mode).distinct

if (modes.size != 1) {
// This shouldn't happen, as all aggregation expressions should share the same mode.
// Fall back to Spark nevertheless.
return None
}

val mode = modes.head match {
case Partial => CometAggregateMode.Partial
case Final => CometAggregateMode.Final
case _ => return None
}

val output = mode match {
case CometAggregateMode.Partial => child.output
case CometAggregateMode.Final =>
// Assuming `Final` always follows `Partial` aggregation, this finds the first
// `Partial` aggregation and gets the input attributes from it.
// While searching for the partial aggregation, we must ensure that all
// traversed operators are native; if not, we fall back to Spark.
var seenNonNativeOp = false
var partialAggInput: Option[Seq[Attribute]] = None
child.transformDown {
case op if !op.isInstanceOf[CometPlan] =>
seenNonNativeOp = true
op
case op @ CometHashAggregateExec(_, _, _, _, input, Some(Partial), _) =>
if (!seenNonNativeOp && partialAggInput.isEmpty) {
partialAggInput = Some(input)
}
op
}

if (partialAggInput.isDefined) {
partialAggInput.get
} else {
return None
}
case _ => return None
}

val aggExprs = aggregateExpressions.map(aggExprToProto(_, output))
if (childOp.nonEmpty && groupingExprs.forall(_.isDefined) &&
aggExprs.forall(_.isDefined)) {
val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder()
hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava)
hashAggBuilder.addAllAggExprs(aggExprs.map(_.get).asJava)
if (mode == CometAggregateMode.Final) {
val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes
val resultExprs = resultExpressions.map(exprToProto(_, attributes))
if (resultExprs.exists(_.isEmpty)) {
emitWarning(s"Unsupported result expressions found in: ${resultExpressions}")
return None
}
hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
}
hashAggBuilder.setModeValue(mode.getNumber)
Some(result.setHashAgg(hashAggBuilder).build())
} else {
None
}
}

case op if isCometSink(op) =>
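The commit drives `transformDown` purely for its pre-order walk, collecting results through mutable flags. A simplified read-only sketch of the same guard, assuming the case-class shape shown in this diff (`findPartialInput` is an illustrative name, not code from the commit):

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.aggregate.Partial
import org.apache.spark.sql.comet.{CometHashAggregateExec, CometPlan}
import org.apache.spark.sql.execution.SparkPlan

// Pre-order walk that remembers whether any non-native operator was
// crossed before the first native partial aggregate is found.
def findPartialInput(plan: SparkPlan): Option[Seq[Attribute]] = {
  var seenNonNativeOp = false
  var partialAggInput: Option[Seq[Attribute]] = None
  plan.foreach {
    case op if !op.isInstanceOf[CometPlan] =>
      seenNonNativeOp = true
    case CometHashAggregateExec(_, _, _, _, input, Some(Partial), _) =>
      if (!seenNonNativeOp && partialAggInput.isEmpty) {
        partialAggInput = Some(input)
      }
    case _ =>
  }
  partialAggInput
}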
@@ -421,7 +421,7 @@ case class CometHashAggregateExec(
groupingExpressions: Seq[NamedExpression],
aggregateExpressions: Seq[AggregateExpression],
input: Seq[Attribute],
mode: AggregateMode,
mode: Option[AggregateMode],
child: SparkPlan)
extends CometUnaryExec {
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
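With `mode` now optional, a grouping-only stage is represented directly in the type rather than by a dummy mode. A hedged sketch of how a consumer of the field might branch (illustrative; not code from this commit):

import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Final, Partial}

// The states a CometHashAggregateExec stage can be in once `mode` is optional.
def describeStage(mode: Option[AggregateMode]): String = mode match {
  case Some(Partial) => "partial aggregation: emit intermediate aggregate buffers"
  case Some(Final) => "final aggregation: merge buffers and emit results"
  case Some(other) => s"mode $other is unsupported; the planner falls back to Spark"
  case None => "grouping-only stage, e.g. the dedup step of a distinct aggregate"
}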
@@ -402,7 +402,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
"tbl",
dictionaryEnabled) {
checkSparkAnswer(
"SELECT _2, SUM(_1), MIN(_1), MAX(_1), COUNT(_1), AVG(_1) FROM tbl GROUP BY _2")
"SELECT _2, SUM(_1), SUM(DISTINCT _1), MIN(_1), MAX(_1), COUNT(_1), COUNT(DISTINCT _1), AVG(_1) FROM tbl GROUP BY _2")
}
}
}
@@ -423,6 +423,8 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
withParquetTable(path.toUri.toString, "tbl") {
checkSparkAnswer("SELECT _g1, _g2, SUM(_3) FROM tbl GROUP BY _g1, _g2")
checkSparkAnswer("SELECT _g1, _g2, COUNT(_3) FROM tbl GROUP BY _g1, _g2")
checkSparkAnswer("SELECT _g1, _g2, SUM(DISTINCT _3) FROM tbl GROUP BY _g1, _g2")
checkSparkAnswer("SELECT _g1, _g2, COUNT(DISTINCT _3) FROM tbl GROUP BY _g1, _g2")
checkSparkAnswer("SELECT _g1, _g2, MIN(_3), MAX(_3) FROM tbl GROUP BY _g1, _g2")
checkSparkAnswer("SELECT _g1, _g2, AVG(_3) FROM tbl GROUP BY _g1, _g2")
}
@@ -453,8 +455,12 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
makeParquetFile(path, numValues, numGroups, dictionaryEnabled)
withParquetTable(path.toUri.toString, "tbl") {
checkSparkAnswer("SELECT _g3, _g4, SUM(_3), SUM(_4) FROM tbl GROUP BY _g3, _g4")
checkSparkAnswer(
"SELECT _g3, _g4, SUM(DISTINCT _3), SUM(DISTINCT _4) FROM tbl GROUP BY _g3, _g4")
checkSparkAnswer(
"SELECT _g3, _g4, COUNT(_3), COUNT(_4) FROM tbl GROUP BY _g3, _g4")
checkSparkAnswer(
"SELECT _g3, _g4, COUNT(DISTINCT _3), COUNT(DISTINCT _4) FROM tbl GROUP BY _g3, _g4")
checkSparkAnswer(
"SELECT _g3, _g4, MIN(_3), MAX(_3), MIN(_4), MAX(_4) FROM tbl GROUP BY _g3, _g4")
checkSparkAnswer("SELECT _g3, _g4, AVG(_3), AVG(_4) FROM tbl GROUP BY _g3, _g4")
@@ -482,7 +488,11 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
(1 to 4).foreach { col =>
(1 to 14).foreach { gCol =>
checkSparkAnswer(s"SELECT _g$gCol, SUM(_$col) FROM tbl GROUP BY _g$gCol")
checkSparkAnswer(
s"SELECT _g$gCol, SUM(DISTINCT _$col) FROM tbl GROUP BY _g$gCol")
checkSparkAnswer(s"SELECT _g$gCol, COUNT(_$col) FROM tbl GROUP BY _g$gCol")
checkSparkAnswer(
s"SELECT _g$gCol, COUNT(DISTINCT _$col) FROM tbl GROUP BY _g$gCol")
checkSparkAnswer(
s"SELECT _g$gCol, MIN(_$col), MAX(_$col) FROM tbl GROUP BY _g$gCol")
checkSparkAnswer(s"SELECT _g$gCol, AVG(_$col) FROM tbl GROUP BY _g$gCol")
@@ -722,6 +732,61 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}

test("distinct") {
withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
Seq(true, false).foreach { cometColumnShuffleEnabled =>
withSQLConf(
CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> cometColumnShuffleEnabled.toString) {
Seq(true, false).foreach { dictionary =>
withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
val table = "test"
withTable(table) {
sql(s"create table $table(col1 int, col2 int, col3 int) using parquet")
sql(
s"insert into $table values(1, 1, 1), (1, 1, 1), (1, 3, 1), (1, 4, 2), (5, 3, 2)")

var expectedNumOfCometAggregates = 2

checkSparkAnswerAndNumOfAggregates(
s"SELECT DISTINCT(col2) FROM $table",
expectedNumOfCometAggregates)

expectedNumOfCometAggregates = 4

checkSparkAnswerAndNumOfAggregates(
s"SELECT COUNT(distinct col2) FROM $table",
expectedNumOfCometAggregates)

checkSparkAnswerAndNumOfAggregates(
s"SELECT COUNT(distinct col2), col1 FROM $table group by col1",
expectedNumOfCometAggregates)

checkSparkAnswerAndNumOfAggregates(
s"SELECT SUM(distinct col2) FROM $table",
expectedNumOfCometAggregates)

checkSparkAnswerAndNumOfAggregates(
s"SELECT SUM(distinct col2), col1 FROM $table group by col1",
expectedNumOfCometAggregates)

checkSparkAnswerAndNumOfAggregates(
"SELECT COUNT(distinct col2), SUM(distinct col2), col1, COUNT(distinct col2)," +
s" SUM(distinct col2) FROM $table group by col1",
expectedNumOfCometAggregates)

expectedNumOfCometAggregates = 1
checkSparkAnswerAndNumOfAggregates(
"SELECT COUNT(col2), MIN(col2), COUNT(DISTINCT col2), SUM(col2)," +
s" SUM(DISTINCT col2), COUNT(DISTINCT col2), col1 FROM $table group by col1",
expectedNumOfCometAggregates)
}
}
}
}
}
}
}

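A note on the expected counts, inferred from the plan shapes rather than stated by the commit: Spark plans a single distinct aggregate as four back-to-back hash aggregates (two grouping-only dedup stages plus a partial/final pair), so a fully native plan yields 4; a plain SELECT DISTINCT is one partial/final grouping pair, hence 2; and mixing distinct with non-distinct aggregates threads the non-distinct functions through intermediate PartialMerge stages, a mode the serde above does not convert (only Partial and Final), so only the bottom partial stage stays native, hence 1. The assumed four-stage shape, Exchanges omitted:

// HashAggregate(keys=[col1], functions=[count(distinct col2)])            -- Final
// +- HashAggregate(keys=[col1], functions=[partial_count(distinct col2)]) -- Partial
//    +- HashAggregate(keys=[col1, col2], functions=[])                    -- dedup
//       +- HashAggregate(keys=[col1, col2], functions=[])                 -- dedup
//          +- FileScan parquet ...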
protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = {
val df = sql(query)
checkSparkAnswer(df)
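The helper's body is truncated in this view; a plausible completion, assuming the counting uses the suite's AdaptiveSparkPlanHelper mix-in (an assumption, not the commit's code):

// Hypothetical continuation: count Comet hash aggregates in the executed
// plan; `collect` is provided by AdaptiveSparkPlanHelper.
val numOfCometAggregates = collect(df.queryExecution.executedPlan) {
  case agg: CometHashAggregateExec => agg
}.length
assert(numOfCometAggregates == numAggregates)
}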
