diff --git a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala index e57565365a..fec6197d6e 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.Tag import org.apache.hadoop.fs.Path import org.apache.spark.{Partitioner, SparkConf} -import org.apache.spark.sql.{CometTestBase, Row} +import org.apache.spark.sql.{CometTestBase, DataFrame, Row} import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleDependency, CometShuffleExchangeExec, CometShuffleManager} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.functions.col @@ -61,25 +61,6 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar import testImplicits._ - test("Native shuffle with dictionary of binary") { - Seq("true", "false").foreach { dictionaryEnabled => - withSQLConf( - CometConf.COMET_EXEC_ENABLED.key -> "true", - CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", - CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") { - withParquetTable( - (0 until 1000).map(i => (i % 5, (i % 5).toString.getBytes())), - "tbl", - dictionaryEnabled.toBoolean) { - val shuffled = sql("SELECT * FROM tbl").repartition(2, $"_2") - - checkCometExchange(shuffled, 1, true) - checkSparkAnswer(shuffled) - } - } - } - } - test("columnar shuffle on nested struct including nulls") { Seq(10, 201).foreach { numPartitions => Seq("1.0", "10.0").foreach { ratio => @@ -93,8 +74,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar .repartition(numPartitions, $"_1", $"_2", $"_3") .sortWithinPartitions($"_1") - checkSparkAnswer(df) - checkCometExchange(df, 1, false) + checkShuffleAnswer(df, 1) } } } @@ -113,8 +93,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar .repartition(numPartitions, $"_1", $"_2") .sortWithinPartitions($"_1") - checkSparkAnswer(df) - checkCometExchange(df, 1, false) + checkShuffleAnswer(df, 1) } } } @@ -134,9 +113,8 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar .repartition(numPartitions, $"_1", $"_2") .sortWithinPartitions($"_2") - checkSparkAnswer(df) // Array map key array element fallback to Spark shuffle for now - checkCometExchange(df, 0, false) + checkShuffleAnswer(df, 0) } withParquetTable((0 until 50).map(i => (Map(i -> Seq(i, i + 1)), i + 1)), "tbl") { @@ -145,9 +123,8 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar .repartition(numPartitions, $"_1", $"_2") .sortWithinPartitions($"_2") - checkSparkAnswer(df) // Array map value array element fallback to Spark shuffle for now - checkCometExchange(df, 0, false) + checkShuffleAnswer(df, 0) } withParquetTable((0 until 50).map(i => (Map((i, i.toString) -> i), i + 1)), "tbl") { @@ -156,9 +133,8 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar .repartition(numPartitions, $"_1", $"_2") .sortWithinPartitions($"_2") - checkSparkAnswer(df) // Struct map key array element fallback to Spark shuffle for now - checkCometExchange(df, 0, false) + checkShuffleAnswer(df, 0) } withParquetTable((0 until 50).map(i => (Map(i -> (i, i.toString)), i + 1)), "tbl") { @@ -167,9 +143,8 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar .repartition(numPartitions, $"_1", $"_2") .sortWithinPartitions($"_2") - checkSparkAnswer(df) // Struct map value array element fallback to Spark shuffle for now - checkCometExchange(df, 0, false) + checkShuffleAnswer(df, 0) } } } @@ -192,9 +167,8 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar .repartition(numPartitions, $"_1", $"_2") .sortWithinPartitions($"_2") - checkSparkAnswer(df) // Map array element fallback to Spark shuffle for now - checkCometExchange(df, 0, false) + checkShuffleAnswer(df, 0) } } } @@ -285,8 +259,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar $"_13") .sortWithinPartitions($"_1") - checkSparkAnswer(df) - checkCometExchange(df, 1, false) + checkShuffleAnswer(df, 1) } // Byte key @@ -309,8 +282,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar $"_13") .sortWithinPartitions($"_1") - checkSparkAnswer(df) - checkCometExchange(df, 1, false) + checkShuffleAnswer(df, 1) } // Short key @@ -333,8 +305,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar $"_13") .sortWithinPartitions($"_1") - checkSparkAnswer(df) - checkCometExchange(df, 1, false) + checkShuffleAnswer(df, 1) } // Int key @@ -357,8 +328,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar $"_13") .sortWithinPartitions($"_1") - checkSparkAnswer(df) - checkCometExchange(df, 1, false) + checkShuffleAnswer(df, 1) } // Long key @@ -381,8 +351,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar $"_13") .sortWithinPartitions($"_1") - checkSparkAnswer(df) - checkCometExchange(df, 1, false) + checkShuffleAnswer(df, 1) } // Float key @@ -405,8 +374,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar $"_13") .sortWithinPartitions($"_1") - checkSparkAnswer(df) - checkCometExchange(df, 1, false) + checkShuffleAnswer(df, 1) } // Double key @@ -429,8 +397,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar $"_13") .sortWithinPartitions($"_1") - checkSparkAnswer(df) - checkCometExchange(df, 1, false) + checkShuffleAnswer(df, 1) } // Date key @@ -455,8 +422,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar $"_13") .sortWithinPartitions($"_1") - checkSparkAnswer(df) - checkCometExchange(df, 1, false) + checkShuffleAnswer(df, 1) } // Timestamp key @@ -483,8 +449,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar $"_13") .sortWithinPartitions($"_1") - checkSparkAnswer(df) - checkCometExchange(df, 1, false) + checkShuffleAnswer(df, 1) } // Decimal key @@ -511,8 +476,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar $"_13") .sortWithinPartitions($"_1") - checkSparkAnswer(df) - checkCometExchange(df, 1, false) + checkShuffleAnswer(df, 1) } // String key @@ -535,8 +499,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar $"_13") .sortWithinPartitions($"_1") - checkSparkAnswer(df) - checkCometExchange(df, 1, false) + checkShuffleAnswer(df, 1) } // Binary key @@ -561,8 +524,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar $"_13") .sortWithinPartitions($"_1") - checkSparkAnswer(df) - checkCometExchange(df, 1, false) + checkShuffleAnswer(df, 1) } } } @@ -593,8 +555,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar .repartition(numPartitions, $"_1", $"_2", $"_3", $"_4", $"_5") .sortWithinPartitions($"_1") - checkSparkAnswer(df) - checkCometExchange(df, 1, false) + checkShuffleAnswer(df, 1) } } } @@ -614,9 +575,8 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar .repartition(numPartitions, $"_1", $"_2") .sortWithinPartitions($"_1") - checkSparkAnswer(df) // Nested array fallback to Spark shuffle for now - checkCometExchange(df, 0, false) + checkShuffleAnswer(df, 0) } } } @@ -637,8 +597,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar .repartition(numPartitions, $"_1", $"_2") .sortWithinPartitions($"_1") - checkSparkAnswer(df) - checkCometExchange(df, 1, false) + checkShuffleAnswer(df, 1) } } } @@ -663,7 +622,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar val df = sql( "select a, b, count(distinct h) from tbl_a, tbl_b " + "where c = e and b = '2222222' and a not like '2' group by a, b") - checkSparkAnswer(df) + checkShuffleAnswer(df, 4) } } } @@ -681,7 +640,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar "SELECT * FROM tbl") .count()) val shuffled = sql("SELECT * FROM tbl").repartition(numPartitions, $"_1") - checkSparkAnswer(shuffled) + checkShuffleAnswer(shuffled, 1) } } } @@ -696,7 +655,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar "SELECT * FROM tbl") .count()) val shuffled = sql("SELECT * FROM tbl").select($"_1").repartition(numPartitions, $"_1") - checkSparkAnswer(shuffled) + checkShuffleAnswer(shuffled, 1) } } } @@ -711,7 +670,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar "SELECT * FROM tbl") .count()) val shuffled = sql("SELECT * FROM tbl").select($"_1").repartition(numPartitions, $"_1") - checkSparkAnswer(shuffled) + checkShuffleAnswer(shuffled, 1) } } } @@ -750,7 +709,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar $"_20").foreach { col => readParquetFile(path.toString) { df => val shuffled = df.select(col).repartition(numPartitions, col) - checkSparkAnswer(shuffled) + checkShuffleAnswer(shuffled, 1) } } } @@ -766,7 +725,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar readParquetFile(dir.getCanonicalPath) { df => { val shuffled = df.repartition(numPartitions, $"dec") - checkSparkAnswer(shuffled) + checkShuffleAnswer(shuffled, 1) } } } @@ -782,7 +741,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar data.write.parquet(dir.getCanonicalPath) readParquetFile(dir.getCanonicalPath) { df => val shuffled = df.repartition(numPartitions, $"dec") - checkSparkAnswer(shuffled) + checkShuffleAnswer(shuffled, 1) } } } @@ -847,24 +806,21 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar df.repartitionByRange(numPartitions, $"_2").limit(2).repartition(numPartitions, $"_1") // 3 exchanges are expected: 1) shuffle to repartition by range, 2) shuffle to global limit, 3) hash shuffle - checkCometExchange(shuffled1, 3, false) - checkSparkAnswer(shuffled1) + checkShuffleAnswer(shuffled1, 3) val shuffled2 = df .repartitionByRange(numPartitions, $"_2") .limit(2) .repartition(numPartitions, $"_1", $"_2") - checkCometExchange(shuffled2, 3, false) - checkSparkAnswer(shuffled2) + checkShuffleAnswer(shuffled2, 3) val shuffled3 = df .repartitionByRange(numPartitions, $"_2") .limit(2) .repartition(numPartitions, $"_2", $"_1") - checkCometExchange(shuffled3, 3, false) - checkSparkAnswer(shuffled3) + checkShuffleAnswer(shuffled3, 3) } } } @@ -906,8 +862,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar .filter($"_1" > 1) // 2 Comet shuffle exchanges are expected - checkCometExchange(shuffled1, 2, false) - checkSparkAnswer(shuffled1) + checkShuffleAnswer(shuffled1, 2) val shuffled2 = df .repartitionByRange(10, $"_2") @@ -916,8 +871,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar .filter($"_1" > 1) // 2 Comet shuffle exchanges are expected, if columnar shuffle is enabled - checkCometExchange(shuffled2, 2, false) - checkSparkAnswer(shuffled2) + checkShuffleAnswer(shuffled2, 2) } } @@ -927,8 +881,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar val shuffled = df.repartition(1) - checkCometExchange(shuffled, 1, false) - checkSparkAnswer(shuffled) + checkShuffleAnswer(shuffled, 1) } } @@ -937,8 +890,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar val df = sql("SELECT * FROM tbl").sortWithinPartitions($"_1".desc) val shuffled = df.repartition(201, $"_1") - checkCometExchange(shuffled, 1, false) - checkSparkAnswer(shuffled) + checkShuffleAnswer(shuffled, 1) // Materialize the shuffled data shuffled.collect() @@ -957,6 +909,15 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar assert(metrics("shuffleWriteTime").value > 0) } } + + /** + * Checks that `df` produces the same answer as Spark does, and has the `expectedNum` Comet + * exchange operators. + */ + private def checkShuffleAnswer(df: DataFrame, expectedNum: Int): Unit = { + checkCometExchange(df, expectedNum, false) + checkSparkAnswer(df) + } } class CometAsyncShuffleSuite extends CometColumnarShuffleSuite { diff --git a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala index 9e2b17e4ed..906fc99570 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala @@ -23,7 +23,7 @@ import org.scalactic.source.Position import org.scalatest.Tag import org.apache.hadoop.fs.Path -import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.{CometTestBase, DataFrame} import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.functions.col @@ -46,8 +46,8 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper import testImplicits._ - // TODO: this test takes ~5mins to run, we should reduce the test time. - test("fix: Too many task completion listener of ArrowReaderIterator causes OOM") { + // TODO: this test takes a long to run, we should reduce the test time. + ignore("fix: Too many task completion listener of ArrowReaderIterator causes OOM") { withSQLConf(CometConf.COMET_BATCH_SIZE.key -> "1") { withParquetTable((0 until 100000).map(i => (1, (i + 1).toLong)), "tbl") { assert( @@ -58,7 +58,7 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper } test("native shuffle: different data type") { - Seq(false).foreach { execEnabled => + Seq(true, false).foreach { execEnabled => Seq(true, false).foreach { dictionaryEnabled => withTempDir { dir => val path = new Path(dir.toURI.toString, "test.parquet") @@ -75,12 +75,7 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper val shuffled = df .select($"_1") .repartition(10, col(c)) - checkCometExchange(shuffled, 1, true) - if (execEnabled) { - checkSparkAnswerAndOperator(shuffled) - } else { - checkSparkAnswer(shuffled) - } + checkShuffleAnswer(shuffled, 1, checkNativeOperators = execEnabled) } } } @@ -93,30 +88,34 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper withParquetTable((0 until 5).map(i => (i, (i + 1).toLong)), "tbl") { val df = sql("SELECT * FROM tbl").sortWithinPartitions($"_1".desc) val shuffled1 = df.repartition(10, $"_1") - - checkCometExchange(shuffled1, 1, true) - checkSparkAnswer(shuffled1) + checkShuffleAnswer(shuffled1, 1) val shuffled2 = df.repartition(10, $"_1", $"_2") - - checkCometExchange(shuffled2, 1, true) - checkSparkAnswer(shuffled2) + checkShuffleAnswer(shuffled2, 1) val shuffled3 = df.repartition(10, $"_2", $"_1") - - checkCometExchange(shuffled3, 1, true) - checkSparkAnswer(shuffled3) + checkShuffleAnswer(shuffled3, 1) } } - test("columnar shuffle: single partition") { + test("native shuffle: single partition") { withParquetTable((0 until 5).map(i => (i, (i + 1).toLong)), "tbl") { val df = sql("SELECT * FROM tbl").sortWithinPartitions($"_1".desc) val shuffled = df.repartition(1) + checkShuffleAnswer(shuffled, 1) + } + } - checkCometExchange(shuffled, 1, true) - checkSparkAnswer(shuffled) + test("native shuffle with dictionary of binary") { + Seq("true", "false").foreach { dictionaryEnabled => + withParquetTable( + (0 until 1000).map(i => (i % 5, (i % 5).toString.getBytes())), + "tbl", + dictionaryEnabled.toBoolean) { + val shuffled = sql("SELECT * FROM tbl").repartition(2, $"_2") + checkShuffleAnswer(shuffled, 1) + } } } @@ -131,8 +130,7 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper .filter($"_1" > 1) // 2 Comet shuffle exchanges are expected - checkCometExchange(shuffled1, 2, true) - checkSparkAnswer(shuffled1) + checkShuffleAnswer(shuffled1, 2) val shuffled2 = df .repartitionByRange(10, $"_2") @@ -143,16 +141,14 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper // Because the first exchange from the bottom is range exchange which native shuffle // doesn't support. So Comet exec operators stop before the first exchange and thus // there is no Comet exchange. - checkCometExchange(shuffled2, 0, true) - checkSparkAnswer(shuffled2) + checkShuffleAnswer(shuffled2, 0) } } test("grouped aggregate: native shuffle") { withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl") { val df = sql("SELECT count(_2), sum(_2) FROM tbl GROUP BY _1") - checkCometExchange(df, 1, true) - checkSparkAnswerAndOperator(df) + checkShuffleAnswer(df, 1, checkNativeOperators = true) } } @@ -161,8 +157,7 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper val df = sql("SELECT * FROM tbl").sortWithinPartitions($"_1".desc) val shuffled = df.repartition(10, $"_1") - checkCometExchange(shuffled, 1, true) - checkSparkAnswer(shuffled) + checkShuffleAnswer(shuffled, 1) // Materialize the shuffled data shuffled.collect() @@ -193,14 +188,29 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper } } - test("fix: comet native shuffle with binary data") { + test("fix: Comet native shuffle with binary data") { withParquetTable((0 until 5).map(i => (i, (i + 1).toLong)), "tbl") { val df = sql("SELECT cast(cast(_1 as STRING) as BINARY) as binary, _2 FROM tbl") val shuffled = df.repartition(1, $"binary") + checkShuffleAnswer(shuffled, 1) + } + } - checkCometExchange(shuffled, 1, true) - checkSparkAnswer(shuffled) + /** + * Checks that `df` produces the same answer as Spark does, and has the `expectedNum` Comet + * exchange operators. When `checkNativeOperators` is true, this also checks that all operators + * used by `df` are Comet native operators. + */ + private def checkShuffleAnswer( + df: DataFrame, + expectedNum: Int, + checkNativeOperators: Boolean = false): Unit = { + checkCometExchange(df, expectedNum, true) + if (checkNativeOperators) { + checkSparkAnswerAndOperator(df) + } else { + checkSparkAnswer(df) } } }