feat: Support Count(Distinct) and similar aggregation functions #42

Merged 1 commit on Feb 20, 2024
@@ -292,14 +292,18 @@ class CometSparkSessionExtensions
newOp match {
case Some(nativeOp) =>
val modes = aggExprs.map(_.mode).distinct
assert(modes.length == 1)
// aggExprs can be empty. For example, if the aggregation only has distinct
// aggregate functions or only a group-by, aggExprs is empty and so is modes.
// If aggExprs is non-empty, we verify that all the aggregates share the
// same mode.
assert(modes.length == 1 || modes.length == 0)
CometHashAggregateExec(
nativeOp,
op,
groupingExprs,
aggExprs,
child.output,
modes.head,
if (modes.nonEmpty) Some(modes.head) else None,
child)
case None =>
op
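
The relaxed assertion covers plans where a hash aggregate carries grouping keys but no aggregate functions, as produced by Spark's distinct rewrite. A minimal standalone sketch of the new extraction logic (real catalyst types, but a hand-built empty expression list rather than a planned query):

import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode}

// Inner stage of a COUNT(DISTINCT ...) plan: HashAggregate(keys=[...], functions=[])
val aggExprs: Seq[AggregateExpression] = Seq.empty
val modes = aggExprs.map(_.mode).distinct
assert(modes.length == 1 || modes.length == 0) // passes: modes == Nil
// CometHashAggregateExec now receives an Option instead of calling modes.head:
val mode: Option[AggregateMode] = if (modes.nonEmpty) Some(modes.head) else None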
123 changes: 80 additions & 43 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition}
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
import org.apache.spark.sql.comet.{CometHashAggregateExec, CometSinkPlaceHolder, DecimalPrecision}
import org.apache.spark.sql.comet.{CometHashAggregateExec, CometPlan, CometSinkPlaceHolder, DecimalPrecision}
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
@@ -1632,60 +1632,97 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
_,
groupingExpressions,
aggregateExpressions,
_,
aggregateAttributes,
_,
resultExpressions,
child) if isCometOperatorEnabled(op.conf, "aggregate") =>
val modes = aggregateExpressions.map(_.mode).distinct

if (modes.size != 1) {
// This shouldn't happen as all aggregation expressions should share the same mode.
// Fall back to Spark nevertheless.
if (groupingExpressions.isEmpty && aggregateExpressions.isEmpty) {
return None
}

val mode = modes.head match {
case Partial => CometAggregateMode.Partial
case Final => CometAggregateMode.Final
case _ => return None
}

val output = mode match {
case CometAggregateMode.Partial => child.output
case CometAggregateMode.Final =>
// Assuming `Final` always follows `Partial` aggregation, this finds the first
// `Partial` aggregation and gets the input attributes from it.
child.collectFirst { case CometHashAggregateExec(_, _, _, _, input, Partial, _) =>
input
} match {
case Some(input) => input
case _ => return None
}
case _ => return None
}

val aggExprs = aggregateExpressions.map(aggExprToProto(_, output))
val groupingExprs = groupingExpressions.map(exprToProto(_, child.output))

if (childOp.nonEmpty && groupingExprs.forall(_.isDefined) &&
aggExprs.forall(_.isDefined)) {
// In some cases, aggregateExpressions can be empty. For example, when the
// aggregation only has a group-by, or only has distinct aggregate functions:
//
// SELECT COUNT(distinct col2), col1 FROM test group by col1
// +- HashAggregate (keys =[col1# 6], functions =[count (distinct col2#7)] )
// +- Exchange hashpartitioning (col1#6, 10), ENSURE_REQUIREMENTS, [plan_id = 36]
// +- HashAggregate (keys =[col1#6], functions =[partial_count (distinct col2#7)] )
// +- HashAggregate (keys =[col1#6, col2#7], functions =[] )
// +- Exchange hashpartitioning (col1#6, col2#7, 10), ENSURE_REQUIREMENTS, ...
// +- HashAggregate (keys =[col1#6, col2#7], functions =[] )
// +- FileScan parquet spark_catalog.default.test[col1#6, col2#7] ......
// If aggregateExpressions is empty, we only build groupingExpressions and
// skip processing of the aggregate expressions.
if (aggregateExpressions.isEmpty) {
val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder()
hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava)
hashAggBuilder.addAllAggExprs(aggExprs.map(_.get).asJava)
if (mode == CometAggregateMode.Final) {
val attributes = groupingExpressions.map(_.toAttribute) ++
aggregateExpressions.map(_.resultAttribute)
val resultExprs = resultExpressions.map(exprToProto(_, attributes))
if (resultExprs.exists(_.isEmpty)) {
emitWarning(s"Unsupported result expressions found in: ${resultExpressions}")
return None
}
hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
}
hashAggBuilder.setModeValue(mode.getNumber)
Some(result.setHashAgg(hashAggBuilder).build())
} else {
None
val modes = aggregateExpressions.map(_.mode).distinct

if (modes.size != 1) {
// This shouldn't happen as all aggregation expressions should share the same mode.
// Fall back to Spark nevertheless.
return None
}

val mode = modes.head match {
case Partial => CometAggregateMode.Partial
case Final => CometAggregateMode.Final
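// Other modes (PartialMerge, Complete) are not supported yet; fall back to Spark.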
case _ => return None
}

val output = mode match {
case CometAggregateMode.Partial => child.output
case CometAggregateMode.Final =>
// Assuming `Final` always follows `Partial` aggregation, this finds the first
// `Partial` aggregation and gets the input attributes from it.
// While searching for the partial aggregation, we must ensure that all
// traversed operators are native. If not, we fall back to Spark.
var seenNonNativeOp = false
var partialAggInput: Option[Seq[Attribute]] = None
child.transformDown {
case op if !op.isInstanceOf[CometPlan] =>
seenNonNativeOp = true
op
case op @ CometHashAggregateExec(_, _, _, _, input, Some(Partial), _) =>
if (!seenNonNativeOp && partialAggInput.isEmpty) {
partialAggInput = Some(input)
}
op
}

if (partialAggInput.isDefined) {
partialAggInput.get
} else {
return None
}
case _ => return None
}

val aggExprs = aggregateExpressions.map(aggExprToProto(_, output))
if (childOp.nonEmpty && groupingExprs.forall(_.isDefined) &&
aggExprs.forall(_.isDefined)) {
val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder()
hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava)
hashAggBuilder.addAllAggExprs(aggExprs.map(_.get).asJava)
if (mode == CometAggregateMode.Final) {
val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes
val resultExprs = resultExpressions.map(exprToProto(_, attributes))
if (resultExprs.exists(_.isEmpty)) {
emitWarning(s"Unsupported result expressions found in: ${resultExpressions}")
return None
}
hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
}
hashAggBuilder.setModeValue(mode.getNumber)
Some(result.setHashAgg(hashAggBuilder).build())
} else {
None
}
}

case op if isCometSink(op) =>
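
The `transformDown` above is a pre-order traversal, so every ancestor of the partial aggregate is visited before the aggregate itself; that is what makes the `seenNonNativeOp` flag a sound guard. A self-contained sketch of the same idea, using hypothetical stand-in types rather than the real Spark/Comet plan classes:

sealed trait MiniPlan { def children: Seq[MiniPlan] }
case class Agg(input: Seq[String], partial: Boolean, children: Seq[MiniPlan]) extends MiniPlan
case class Native(children: Seq[MiniPlan]) extends MiniPlan
case class NonNative(children: Seq[MiniPlan]) extends MiniPlan

// Pre-order walk: record whether any non-native operator was seen before
// the first partial aggregate; if so, the lookup fails and we fall back.
def findPartialAggInput(plan: MiniPlan): Option[Seq[String]] = {
  var seenNonNative = false
  var found: Option[Seq[String]] = None
  def visit(p: MiniPlan): Unit = {
    p match {
      case _: NonNative => seenNonNative = true
      case Agg(input, true, _) if !seenNonNative && found.isEmpty => found = Some(input)
      case _ => ()
    }
    p.children.foreach(visit)
  }
  visit(plan)
  found
}

// A non-native operator above the partial aggregate disables the rewrite:
assert(findPartialAggInput(NonNative(Seq(Agg(Seq("col1"), partial = true, Nil)))).isEmpty)
assert(findPartialAggInput(Native(Seq(Agg(Seq("col1"), partial = true, Nil)))).contains(Seq("col1")))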
@@ -421,7 +421,7 @@ case class CometHashAggregateExec(
groupingExpressions: Seq[NamedExpression],
aggregateExpressions: Seq[AggregateExpression],
input: Seq[Attribute],
mode: AggregateMode,
mode: Option[AggregateMode], // None when the aggregate has no aggregate functions
child: SparkPlan)
extends CometUnaryExec {
override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
@@ -381,7 +381,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
"tbl",
dictionaryEnabled) {
checkSparkAnswer(
"SELECT _2, SUM(_1), MIN(_1), MAX(_1), COUNT(_1), AVG(_1) FROM tbl GROUP BY _2")
"SELECT _2, SUM(_1), SUM(DISTINCT _1), MIN(_1), MAX(_1), COUNT(_1), COUNT(DISTINCT _1), AVG(_1) FROM tbl GROUP BY _2")
}
}
}
@@ -402,6 +402,8 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
withParquetTable(path.toUri.toString, "tbl") {
checkSparkAnswer("SELECT _g1, _g2, SUM(_3) FROM tbl GROUP BY _g1, _g2")
checkSparkAnswer("SELECT _g1, _g2, COUNT(_3) FROM tbl GROUP BY _g1, _g2")
checkSparkAnswer("SELECT _g1, _g2, SUM(DISTINCT _3) FROM tbl GROUP BY _g1, _g2")
checkSparkAnswer("SELECT _g1, _g2, COUNT(DISTINCT _3) FROM tbl GROUP BY _g1, _g2")
checkSparkAnswer("SELECT _g1, _g2, MIN(_3), MAX(_3) FROM tbl GROUP BY _g1, _g2")
checkSparkAnswer("SELECT _g1, _g2, AVG(_3) FROM tbl GROUP BY _g1, _g2")
}
@@ -432,8 +434,12 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
makeParquetFile(path, numValues, numGroups, dictionaryEnabled)
withParquetTable(path.toUri.toString, "tbl") {
checkSparkAnswer("SELECT _g3, _g4, SUM(_3), SUM(_4) FROM tbl GROUP BY _g3, _g4")
checkSparkAnswer(
"SELECT _g3, _g4, SUM(DISTINCT _3), SUM(DISTINCT _4) FROM tbl GROUP BY _g3, _g4")
checkSparkAnswer(
"SELECT _g3, _g4, COUNT(_3), COUNT(_4) FROM tbl GROUP BY _g3, _g4")
checkSparkAnswer(
"SELECT _g3, _g4, COUNT(DISTINCT _3), COUNT(DISTINCT _4) FROM tbl GROUP BY _g3, _g4")
checkSparkAnswer(
"SELECT _g3, _g4, MIN(_3), MAX(_3), MIN(_4), MAX(_4) FROM tbl GROUP BY _g3, _g4")
checkSparkAnswer("SELECT _g3, _g4, AVG(_3), AVG(_4) FROM tbl GROUP BY _g3, _g4")
@@ -461,7 +467,11 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
(1 to 4).foreach { col =>
(1 to 14).foreach { gCol =>
checkSparkAnswer(s"SELECT _g$gCol, SUM(_$col) FROM tbl GROUP BY _g$gCol")
checkSparkAnswer(
s"SELECT _g$gCol, SUM(DISTINCT _$col) FROM tbl GROUP BY _g$gCol")
checkSparkAnswer(s"SELECT _g$gCol, COUNT(_$col) FROM tbl GROUP BY _g$gCol")
checkSparkAnswer(
s"SELECT _g$gCol, COUNT(DISTINCT _$col) FROM tbl GROUP BY _g$gCol")
checkSparkAnswer(
s"SELECT _g$gCol, MIN(_$col), MAX(_$col) FROM tbl GROUP BY _g$gCol")
checkSparkAnswer(s"SELECT _g$gCol, AVG(_$col) FROM tbl GROUP BY _g$gCol")
@@ -701,6 +711,61 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}

test("distinct") {
withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
Seq(true, false).foreach { bosonColumnShuffleEnabled =>
withSQLConf(
CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> bosonColumnShuffleEnabled.toString) {
Seq(true, false).foreach { dictionary =>
withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
val table = "test"
withTable(table) {
sql(s"create table $table(col1 int, col2 int, col3 int) using parquet")
sql(
s"insert into $table values(1, 1, 1), (1, 1, 1), (1, 3, 1), (1, 4, 2), (5, 3, 2)")

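// SELECT DISTINCT plans as two back-to-back hash aggregates over the grouping
// keys (partial + final, with functions=[]), both of which run natively.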
var expectedNumOfBosonAggregates = 2

checkSparkAnswerAndNumOfAggregates(
s"SELECT DISTINCT(col2) FROM $table",
expectedNumOfBosonAggregates)

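// COUNT/SUM(DISTINCT ...) triggers Spark's distinct rewrite (see the plan tree
// in QueryPlanSerde above), producing four hash aggregate stages, all native.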
expectedNumOfBosonAggregates = 4

checkSparkAnswerAndNumOfAggregates(
s"SELECT COUNT(distinct col2) FROM $table",
expectedNumOfBosonAggregates)

checkSparkAnswerAndNumOfAggregates(
s"SELECT COUNT(distinct col2), col1 FROM $table group by col1",
expectedNumOfBosonAggregates)

checkSparkAnswerAndNumOfAggregates(
s"SELECT SUM(distinct col2) FROM $table",
expectedNumOfBosonAggregates)

checkSparkAnswerAndNumOfAggregates(
s"SELECT SUM(distinct col2), col1 FROM $table group by col1",
expectedNumOfBosonAggregates)

checkSparkAnswerAndNumOfAggregates(
"SELECT COUNT(distinct col2), SUM(distinct col2), col1, COUNT(distinct col2)," +
s" SUM(distinct col2) FROM $table group by col1",
expectedNumOfBosonAggregates)

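// Mixing distinct and non-distinct aggregates introduces PartialMerge stages,
// which are not mapped to a Comet mode (QueryPlanSerde falls back for them),
// so only the first partial aggregate runs natively.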
expectedNumOfBosonAggregates = 1
checkSparkAnswerAndNumOfAggregates(
"SELECT COUNT(col2), MIN(col2), COUNT(DISTINCT col2), SUM(col2)," +
s" SUM(DISTINCT col2), COUNT(DISTINCT col2), col1 FROM $table group by col1",
expectedNumOfBosonAggregates)
}
}
}
}
}
}
}

protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = {
val df = sql(query)
checkSparkAnswer(df)
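
The helper body is truncated at this point in the diff. A plausible sketch of the remainder, assuming it counts CometHashAggregateExec nodes via AdaptiveSparkPlanHelper's collect (the actual implementation is not shown here):

// Hypothetical remainder of checkSparkAnswerAndNumOfAggregates:
val actualNumAggregates = collect(df.queryExecution.executedPlan) {
  case agg: CometHashAggregateExec => agg
}.length
assert(
  actualNumAggregates == numAggregates,
  s"Expected $numAggregates native aggregates but found $actualNumAggregates")
}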