diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index f9caee9d9..1858f765c 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -140,14 +140,14 @@ object CometConf { .booleanConf .createWithDefault(false) - val COMET_COLUMNAR_SHUFFLE_ENABLED: ConfigEntry[Boolean] = conf( - "spark.comet.columnar.shuffle.enabled") - .doc( - "Force Comet to only use columnar shuffle for CometScan and Spark regular operators. " + - "If this is enabled, Comet native shuffle will not be enabled but only Arrow shuffle. " + - "By default, this config is false.") - .booleanConf - .createWithDefault(false) + val COMET_COLUMNAR_SHUFFLE_ENABLED: ConfigEntry[Boolean] = + conf("spark.comet.columnar.shuffle.enabled") + .doc( + "Whether to enable Arrow-based columnar shuffle for Comet and Spark regular operators. " + + "If this is enabled, Comet prefers columnar shuffle than native shuffle. " + + "By default, this config is true.") + .booleanConf + .createWithDefault(true) val COMET_SHUFFLE_ENFORCE_MODE_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.shuffle.enforceMode.enabled") diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala index eb731f9d0..595c0a427 100644 --- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala +++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala @@ -66,7 +66,7 @@ class NativeUtil { val arrowArray = ArrowArray.allocateNew(allocator) Data.exportVector( allocator, - getFieldVector(valueVector), + getFieldVector(valueVector, "export"), provider, arrowArray, arrowSchema) diff --git a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala index 7d920e1be..2300e109a 100644 --- a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala +++ b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala @@ -242,7 +242,7 @@ object Utils { } } - getFieldVector(valueVector) + getFieldVector(valueVector, "serialize") case c => throw new SparkException( @@ -253,14 +253,15 @@ object Utils { (fieldVectors, provider) } - def getFieldVector(valueVector: ValueVector): FieldVector = { + def getFieldVector(valueVector: ValueVector, reason: String): FieldVector = { valueVector match { case v @ (_: BitVector | _: TinyIntVector | _: SmallIntVector | _: IntVector | _: BigIntVector | _: Float4Vector | _: Float8Vector | _: VarCharVector | _: DecimalVector | _: DateDayVector | _: TimeStampMicroTZVector | _: VarBinaryVector | _: FixedSizeBinaryVector | _: TimeStampMicroVector) => v.asInstanceOf[FieldVector] - case _ => throw new SparkException(s"Unsupported Arrow Vector: ${valueVector.getClass}") + case _ => + throw new SparkException(s"Unsupported Arrow Vector for $reason: ${valueVector.getClass}") } } } diff --git a/core/Cargo.lock b/core/Cargo.lock index 3fb7b5f62..52f105591 100644 --- a/core/Cargo.lock +++ b/core/Cargo.lock @@ -57,9 +57,9 @@ dependencies = [ [[package]] name = "allocator-api2" -version = "0.2.16" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" +checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" [[package]] name = "android-tzdata" @@ -90,9 +90,9 @@ checksum = "8901269c6307e8d93993578286ac0edf7f195079ffff5ebdeea6a59ffb7e36bc" [[package]] name = "anyhow" -version = "1.0.81" +version = "1.0.82" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0952808a6c2afd1aa8947271f3a60f1a6763c7b912d210184c5149b5cf147247" +checksum = "f538837af36e6f6a9be0faa67f9a314f8119e4e4b5867c6ab40ed60360142519" [[package]] name = "arc-swap" @@ -327,13 +327,13 @@ checksum = "0c24e9d990669fbd16806bff449e4ac644fd9b1fca014760087732fe4102f131" [[package]] name = "async-trait" -version = "0.1.79" +version = "0.1.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a507401cad91ec6a857ed5513a2073c82a9b9048762b885bb98655b306964681" +checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" dependencies = [ "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.59", ] [[package]] @@ -438,9 +438,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.15.4" +version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ff69b9dd49fd426c69a0db9fc04dd934cdb6645ff000864d98f7e2af8830eaa" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bytemuck" @@ -468,9 +468,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.0.90" +version = "1.0.94" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cd6604a82acf3039f1144f54b8eb34e91ffba622051189e71b781822d5ee1f5" +checksum = "17f6e324229dc011159fcc089755d1e2e216a90d43a7dea6853ca740b84f35e7" dependencies = [ "jobserver", "libc", @@ -490,14 +490,14 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.37" +version = "0.4.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a0d04d43504c61aa6c7531f1871dd0d418d91130162063b789da00fd7057a5e" +checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" dependencies = [ "android-tzdata", "iana-time-zone", "num-traits", - "windows-targets 0.52.4", + "windows-targets 0.52.5", ] [[package]] @@ -576,9 +576,9 @@ checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce" [[package]] name = "combine" -version = "4.6.6" +version = "4.6.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35ed6e9d84f0b51a7f52daf1c7d71dd136fd7a3f41a8462b8cdb8c78d920fad4" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" dependencies = [ "bytes", "memchr", @@ -625,7 +625,7 @@ dependencies = [ "parquet-format", "paste", "pprof", - "prost 0.12.3", + "prost 0.12.4", "prost-build", "rand", "regex", @@ -643,12 +643,12 @@ dependencies = [ [[package]] name = "comfy-table" -version = "7.1.0" +version = "7.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c64043d6c7b7a4c58e39e7efccfdea7b93d885a795d0c054a69dbbf4dd52686" +checksum = "b34115915337defe99b2aff5c2ce6771e5fbc4079f4b506301f5cf394c8452f7" dependencies = [ - "strum 0.25.0", - "strum_macros 0.25.3", + "strum", + "strum_macros", "unicode-width", ] @@ -923,8 +923,8 @@ dependencies = [ "datafusion-common", "paste", "sqlparser", - "strum 0.26.2", - "strum_macros 0.26.2", + "strum", + "strum_macros", ] [[package]] @@ -1044,7 +1044,7 @@ dependencies = [ "datafusion-expr", "log", "sqlparser", - "strum 0.26.2", + "strum", ] [[package]] @@ -1092,9 +1092,9 @@ checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" [[package]] name = "either" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a" +checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2" [[package]] name = "equivalent" @@ -1227,7 +1227,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.59", ] [[package]] @@ -1272,9 +1272,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.12" +version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" +checksum = "94b22e06ecb0110981051723910cbf0b5f5e09a2062dd7663334ee79a9d1286c" dependencies = [ "cfg-if", "libc", @@ -1526,9 +1526,9 @@ checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" [[package]] name = "jobserver" -version = "0.1.28" +version = "0.1.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab46a6e9526ddef3ae7f787c06f0f2600639ba80ea3eade3d8e670a2230f51d6" +checksum = "685a7d121ee3f65ae4fddd72b25a04bb36b6af81bc0828f7d5434c0fe60fa3a2" dependencies = [ "libc", ] @@ -1794,9 +1794,9 @@ dependencies = [ [[package]] name = "num" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b05180d69e3da0e530ba2a1dae5110317e49e3b7f3d41be227dc5f92e49ee7af" +checksum = "3135b08af27d103b0a51f2ae0f8632117b7b185ccf931445affa8df530576a41" dependencies = [ "num-bigint", "num-complex", @@ -2142,9 +2142,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.79" +version = "1.0.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e" +checksum = "a56dea16b0a29e94408b9aa5e2940a4eedbd128a1ba20e8f7ae60fd3d465af0e" dependencies = [ "unicode-ident", ] @@ -2161,12 +2161,12 @@ dependencies = [ [[package]] name = "prost" -version = "0.12.3" +version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "146c289cda302b98a28d40c8b3b90498d6e526dd24ac2ecea73e4e491685b94a" +checksum = "d0f5d036824e4761737860779c906171497f6d55681139d8312388f8fe398922" dependencies = [ "bytes", - "prost-derive 0.12.3", + "prost-derive 0.12.4", ] [[package]] @@ -2204,15 +2204,15 @@ dependencies = [ [[package]] name = "prost-derive" -version = "0.12.3" +version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "efb6c9a1dd1def8e2124d17e83a20af56f1570d6c2d2bd9e266ccb768df3840e" +checksum = "19de2de2a00075bf566bee3bd4db014b11587e84184d3f7a791bc17f1a8e9e48" dependencies = [ "anyhow", - "itertools 0.11.0", + "itertools 0.12.1", "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.59", ] [[package]] @@ -2236,9 +2236,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.35" +version = "1.0.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" +checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" dependencies = [ "proc-macro2", ] @@ -2370,9 +2370,9 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.14" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" +checksum = "80af6f9131f277a45a3fba6ce8e2258037bb0477a67e610d3c1fe046ab31de47" [[package]] name = "ryu" @@ -2434,14 +2434,14 @@ checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.59", ] [[package]] name = "serde_json" -version = "1.0.115" +version = "1.0.116" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12dc5c46daa8e9fdf4f5e71b6cf9a53f2487da0e86e55808e2d35539666497dd" +checksum = "3e17db7126d17feb94eb3fad46bf1a96b034e8aacbc2e775fe81505f8b0b2813" dependencies = [ "itoa", "ryu", @@ -2545,7 +2545,7 @@ checksum = "01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554" dependencies = [ "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.59", ] [[package]] @@ -2566,32 +2566,13 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9091b6114800a5f2141aee1d1b9d6ca3592ac062dc5decb3764ec5895a47b4eb" -[[package]] -name = "strum" -version = "0.25.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" - [[package]] name = "strum" version = "0.26.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d8cec3501a5194c432b2b7976db6b7d10ec95c253208b45f83f7136aa985e29" dependencies = [ - "strum_macros 0.26.2", -] - -[[package]] -name = "strum_macros" -version = "0.25.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0" -dependencies = [ - "heck 0.4.1", - "proc-macro2", - "quote", - "rustversion", - "syn 2.0.57", + "strum_macros", ] [[package]] @@ -2604,7 +2585,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.57", + "syn 2.0.59", ] [[package]] @@ -2649,9 +2630,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.57" +version = "2.0.59" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11a6ae1e52eb25aab8f3fb9fca13be982a373b8f1157ca14b897a825ba4a2d35" +checksum = "4a6531ffc7b071655e4ce2e04bd464c4830bb585a61cabb96cf808f05172615a" dependencies = [ "proc-macro2", "quote", @@ -2687,7 +2668,7 @@ checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.59", ] [[package]] @@ -2790,7 +2771,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.59", ] [[package]] @@ -2823,7 +2804,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.59", ] [[package]] @@ -2971,7 +2952,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.59", "wasm-bindgen-shared", ] @@ -2993,7 +2974,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.59", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3063,7 +3044,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.52.4", + "windows-targets 0.52.5", ] [[package]] @@ -3081,7 +3062,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.4", + "windows-targets 0.52.5", ] [[package]] @@ -3116,17 +3097,18 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7dd37b7e5ab9018759f893a1952c9420d060016fc19a472b4bb20d1bdd694d1b" +checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" dependencies = [ - "windows_aarch64_gnullvm 0.52.4", - "windows_aarch64_msvc 0.52.4", - "windows_i686_gnu 0.52.4", - "windows_i686_msvc 0.52.4", - "windows_x86_64_gnu 0.52.4", - "windows_x86_64_gnullvm 0.52.4", - "windows_x86_64_msvc 0.52.4", + "windows_aarch64_gnullvm 0.52.5", + "windows_aarch64_msvc 0.52.5", + "windows_i686_gnu 0.52.5", + "windows_i686_gnullvm", + "windows_i686_msvc 0.52.5", + "windows_x86_64_gnu 0.52.5", + "windows_x86_64_gnullvm 0.52.5", + "windows_x86_64_msvc 0.52.5", ] [[package]] @@ -3143,9 +3125,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bcf46cf4c365c6f2d1cc93ce535f2c8b244591df96ceee75d8e83deb70a9cac9" +checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" [[package]] name = "windows_aarch64_msvc" @@ -3161,9 +3143,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da9f259dd3bcf6990b55bffd094c4f7235817ba4ceebde8e6d11cd0c5633b675" +checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" [[package]] name = "windows_i686_gnu" @@ -3179,9 +3161,15 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.4" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b474d8268f99e0995f25b9f095bc7434632601028cf86590aea5c8a5cb7801d3" +checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" [[package]] name = "windows_i686_msvc" @@ -3197,9 +3185,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1515e9a29e5bed743cb4415a9ecf5dfca648ce85ee42e15873c3cd8610ff8e02" +checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" [[package]] name = "windows_x86_64_gnu" @@ -3215,9 +3203,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5eee091590e89cc02ad514ffe3ead9eb6b660aedca2183455434b93546371a03" +checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" [[package]] name = "windows_x86_64_gnullvm" @@ -3233,9 +3221,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77ca79f2451b49fa9e2af39f0747fe999fcda4f5e241b2898624dca97a1f2177" +checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" [[package]] name = "windows_x86_64_msvc" @@ -3251,9 +3239,9 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" +checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" [[package]] name = "zerocopy" @@ -3272,7 +3260,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.59", ] [[package]] diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index 72174790b..d1fb252cc 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -1039,6 +1039,23 @@ impl PhysicalPlanner { .collect(); let full_schema = Arc::new(Schema::new(all_fields)); + // Because we cast dictionary array to array in scan operator, + // we need to change dictionary type to data type for join filter expression. + let fields: Vec<_> = full_schema + .fields() + .iter() + .map(|f| match f.data_type() { + DataType::Dictionary(_, val_type) => Arc::new(Field::new( + f.name(), + val_type.as_ref().clone(), + f.is_nullable(), + )), + _ => f.clone(), + }) + .collect(); + + let full_schema = Arc::new(Schema::new(fields)); + let physical_expr = self.create_expr(expr, full_schema)?; let (left_field_indices, right_field_indices) = expr_to_columns(&physical_expr, left_fields.len(), right_fields.len())?; @@ -1057,6 +1074,14 @@ impl PhysicalPlanner { .into_iter() .map(|i| right.schema().field(i).clone()), ) + // Because we cast dictionary array to array in scan operator, + // we need to change dictionary type to data type for join filter expression. + .map(|f| match f.data_type() { + DataType::Dictionary(_, val_type) => { + Field::new(f.name(), val_type.as_ref().clone(), f.is_nullable()) + } + _ => f.clone(), + }) .collect_vec(); let filter_schema = Schema::new_with_metadata(filter_fields, HashMap::new()); diff --git a/dev/diffs/3.4.2.diff b/dev/diffs/3.4.2.diff index 4154a705d..19bf6dd41 100644 --- a/dev/diffs/3.4.2.diff +++ b/dev/diffs/3.4.2.diff @@ -210,6 +210,51 @@ index 0efe0877e9b..423d3b3d76d 100644 -- -- SELECT_HAVING -- https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/select_having.sql +diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +index cf40e944c09..bdd5be4f462 100644 +--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants + import org.apache.spark.sql.execution.{ColumnarToRowExec, ExecSubqueryExpression, RDDScanExec, SparkPlan} + import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper + import org.apache.spark.sql.execution.columnar._ +-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec ++import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike + import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate + import org.apache.spark.sql.functions._ + import org.apache.spark.sql.internal.SQLConf +@@ -516,7 +516,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils + */ + private def verifyNumExchanges(df: DataFrame, expected: Int): Unit = { + assert( +- collect(df.queryExecution.executedPlan) { case e: ShuffleExchangeExec => e }.size == expected) ++ collect(df.queryExecution.executedPlan) { ++ case _: ShuffleExchangeLike => 1 }.size == expected) + } + + test("A cached table preserves the partitioning and ordering of its cached SparkPlan") { +diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +index ea5e47ede55..814b92d090f 100644 +--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +@@ -27,7 +27,7 @@ import org.apache.spark.SparkException + import org.apache.spark.sql.execution.WholeStageCodegenExec + import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper + import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} +-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec ++import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike + import org.apache.spark.sql.expressions.Window + import org.apache.spark.sql.functions._ + import org.apache.spark.sql.internal.SQLConf +@@ -755,7 +755,7 @@ class DataFrameAggregateSuite extends QueryTest + assert(objHashAggPlans.nonEmpty) + + val exchangePlans = collect(aggPlan) { +- case shuffle: ShuffleExchangeExec => shuffle ++ case shuffle: ShuffleExchangeLike => shuffle + } + assert(exchangePlans.length == 1) + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 56e9520fdab..917932336df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -226,9 +271,54 @@ index 56e9520fdab..917932336df 100644 spark.range(100).write.saveAsTable(s"$dbName.$table2Name") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala -index 9ddb4abe98b..1bebe99f1cc 100644 +index 9ddb4abe98b..1b9269acef1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +@@ -43,7 +43,7 @@ import org.apache.spark.sql.connector.FakeV2Provider + import org.apache.spark.sql.execution.{FilterExec, LogicalRDD, QueryExecution, SortExec, WholeStageCodegenExec} + import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper + import org.apache.spark.sql.execution.aggregate.HashAggregateExec +-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike} ++import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeLike} + import org.apache.spark.sql.expressions.{Aggregator, Window} + import org.apache.spark.sql.functions._ + import org.apache.spark.sql.internal.SQLConf +@@ -1981,7 +1981,7 @@ class DataFrameSuite extends QueryTest + fail("Should not have back to back Aggregates") + } + atFirstAgg = true +- case e: ShuffleExchangeExec => atFirstAgg = false ++ case e: ShuffleExchangeLike => atFirstAgg = false + case _ => + } + } +@@ -2291,7 +2291,7 @@ class DataFrameSuite extends QueryTest + checkAnswer(join, df) + assert( + collect(join.queryExecution.executedPlan) { +- case e: ShuffleExchangeExec => true }.size === 1) ++ case _: ShuffleExchangeLike => true }.size === 1) + assert( + collect(join.queryExecution.executedPlan) { case e: ReusedExchangeExec => true }.size === 1) + val broadcasted = broadcast(join) +@@ -2299,7 +2299,7 @@ class DataFrameSuite extends QueryTest + checkAnswer(join2, df) + assert( + collect(join2.queryExecution.executedPlan) { +- case e: ShuffleExchangeExec => true }.size == 1) ++ case _: ShuffleExchangeLike => true }.size == 1) + assert( + collect(join2.queryExecution.executedPlan) { + case e: BroadcastExchangeExec => true }.size === 1) +@@ -2862,7 +2862,7 @@ class DataFrameSuite extends QueryTest + + // Assert that no extra shuffle introduced by cogroup. + val exchanges = collect(df3.queryExecution.executedPlan) { +- case h: ShuffleExchangeExec => h ++ case h: ShuffleExchangeLike => h + } + assert(exchanges.size == 2) + } @@ -3311,7 +3311,8 @@ class DataFrameSuite extends QueryTest assert(df2.isLocal) } @@ -239,8 +329,30 @@ index 9ddb4abe98b..1bebe99f1cc 100644 withTable("tbl") { sql( """ +diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +index 7dec558f8df..840dda15033 100644 +--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +@@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} + import org.apache.spark.sql.catalyst.util.sideBySide + import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SQLExecution} + import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} ++import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike} + import org.apache.spark.sql.execution.streaming.MemoryStream + import org.apache.spark.sql.expressions.UserDefinedFunction + import org.apache.spark.sql.functions._ +@@ -2254,7 +2254,7 @@ class DatasetSuite extends QueryTest + + // Assert that no extra shuffle introduced by cogroup. + val exchanges = collect(df3.queryExecution.executedPlan) { +- case h: ShuffleExchangeExec => h ++ case h: ShuffleExchangeLike => h + } + assert(exchanges.size == 2) + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala -index f33432ddb6f..6160c8d241a 100644 +index f33432ddb6f..060f874ea72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala @@ -22,6 +22,7 @@ import org.scalatest.GivenWhenThen @@ -261,7 +373,17 @@ index f33432ddb6f..6160c8d241a 100644 case _ => Nil } } -@@ -1238,7 +1242,8 @@ abstract class DynamicPartitionPruningSuiteBase +@@ -1187,7 +1191,8 @@ abstract class DynamicPartitionPruningSuiteBase + } + } + +- test("Make sure dynamic pruning works on uncorrelated queries") { ++ test("Make sure dynamic pruning works on uncorrelated queries", ++ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { + val df = sql( + """ +@@ -1238,7 +1243,8 @@ abstract class DynamicPartitionPruningSuiteBase } } @@ -271,7 +393,7 @@ index f33432ddb6f..6160c8d241a 100644 Given("dynamic pruning filter on the build side") withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { val df = sql( -@@ -1485,7 +1490,7 @@ abstract class DynamicPartitionPruningSuiteBase +@@ -1485,7 +1491,7 @@ abstract class DynamicPartitionPruningSuiteBase } test("SPARK-38148: Do not add dynamic partition pruning if there exists static partition " + @@ -280,7 +402,7 @@ index f33432ddb6f..6160c8d241a 100644 withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") { Seq( "f.store_id = 1" -> false, -@@ -1729,6 +1734,8 @@ abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningDat +@@ -1729,6 +1735,8 @@ abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningDat case s: BatchScanExec => // we use f1 col for v2 tables due to schema pruning s.output.exists(_.exists(_.argString(maxFields = 100).contains("f1"))) @@ -290,7 +412,7 @@ index f33432ddb6f..6160c8d241a 100644 } assert(scanOption.isDefined) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala -index a6b295578d6..a5cb616945a 100644 +index a6b295578d6..91acca4306f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala @@ -463,7 +463,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite @@ -303,19 +425,38 @@ index a6b295578d6..a5cb616945a 100644 withTempDir { dir => Seq("parquet", "orc", "csv", "json").foreach { fmt => val basePath = dir.getCanonicalPath + "/" + fmt +@@ -541,7 +542,9 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite + } + } + +-class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuite { ++// Ignored when Comet is enabled. Comet changes expected query plans. ++class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuite ++ with IgnoreCometSuite { + import testImplicits._ + + test("SPARK-35884: Explain Formatted") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala -index 2796b1cf154..94591f83c84 100644 +index 2796b1cf154..be7078b38f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.TestingUDT.{IntervalUDT, NullData, NullUDT} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterThan, Literal} import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils.{negativeInt, positiveInt} import org.apache.spark.sql.catalyst.plans.logical.Filter -+import org.apache.spark.sql.comet.{CometBatchScanExec, CometScanExec} ++import org.apache.spark.sql.comet.{CometBatchScanExec, CometScanExec, CometSortMergeJoinExec} import org.apache.spark.sql.execution.{FileSourceScanLike, SimpleMode} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.datasources.FilePartition -@@ -875,6 +876,7 @@ class FileBasedDataSourceSuite extends QueryTest +@@ -815,6 +816,7 @@ class FileBasedDataSourceSuite extends QueryTest + assert(bJoinExec.isEmpty) + val smJoinExec = collect(joinedDF.queryExecution.executedPlan) { + case smJoin: SortMergeJoinExec => smJoin ++ case smJoin: CometSortMergeJoinExec => smJoin + } + assert(smJoinExec.nonEmpty) + } +@@ -875,6 +877,7 @@ class FileBasedDataSourceSuite extends QueryTest val fileScan = df.queryExecution.executedPlan collectFirst { case BatchScanExec(_, f: FileScan, _, _, _, _, _, _, _) => f @@ -323,7 +464,7 @@ index 2796b1cf154..94591f83c84 100644 } assert(fileScan.nonEmpty) assert(fileScan.get.partitionFilters.nonEmpty) -@@ -916,6 +918,7 @@ class FileBasedDataSourceSuite extends QueryTest +@@ -916,6 +919,7 @@ class FileBasedDataSourceSuite extends QueryTest val fileScan = df.queryExecution.executedPlan collectFirst { case BatchScanExec(_, f: FileScan, _, _, _, _, _, _, _) => f @@ -331,7 +472,7 @@ index 2796b1cf154..94591f83c84 100644 } assert(fileScan.nonEmpty) assert(fileScan.get.partitionFilters.isEmpty) -@@ -1100,6 +1103,8 @@ class FileBasedDataSourceSuite extends QueryTest +@@ -1100,6 +1104,8 @@ class FileBasedDataSourceSuite extends QueryTest val filters = df.queryExecution.executedPlan.collect { case f: FileSourceScanLike => f.dataFilters case b: BatchScanExec => b.scan.asInstanceOf[FileScan].dataFilters @@ -388,8 +529,36 @@ index 00000000000..4b31bea33de + } + } +} +diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala +index 1792b4c32eb..1616e6f39bd 100644 +--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala +@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide + import org.apache.spark.sql.catalyst.plans.PlanTest + import org.apache.spark.sql.catalyst.plans.logical._ + import org.apache.spark.sql.catalyst.rules.RuleExecutor ++import org.apache.spark.sql.comet.{CometHashJoinExec, CometSortMergeJoinExec} + import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper + import org.apache.spark.sql.execution.joins._ + import org.apache.spark.sql.internal.SQLConf +@@ -362,6 +363,7 @@ class JoinHintSuite extends PlanTest with SharedSparkSession with AdaptiveSparkP + val executedPlan = df.queryExecution.executedPlan + val shuffleHashJoins = collect(executedPlan) { + case s: ShuffledHashJoinExec => s ++ case c: CometHashJoinExec => c.originalPlan.asInstanceOf[ShuffledHashJoinExec] + } + assert(shuffleHashJoins.size == 1) + assert(shuffleHashJoins.head.buildSide == buildSide) +@@ -371,6 +373,7 @@ class JoinHintSuite extends PlanTest with SharedSparkSession with AdaptiveSparkP + val executedPlan = df.queryExecution.executedPlan + val shuffleMergeJoins = collect(executedPlan) { + case s: SortMergeJoinExec => s ++ case c: CometSortMergeJoinExec => c + } + assert(shuffleMergeJoins.size == 1) + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala -index 5125708be32..a1f1ae90796 100644 +index 5125708be32..210ab4f3ce1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier @@ -400,7 +569,87 @@ index 5125708be32..a1f1ae90796 100644 import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike} -@@ -1369,9 +1370,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan +@@ -739,7 +740,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan + } + } + +- test("test SortMergeJoin (with spill)") { ++ test("test SortMergeJoin (with spill)", ++ IgnoreComet("TODO: Comet SMJ doesn't support spill yet")) { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1", + SQLConf.SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "0", + SQLConf.SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD.key -> "1") { +@@ -1114,9 +1116,11 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan + val plan = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2", joinType) + .groupBy($"k1").count() + .queryExecution.executedPlan +- assert(collect(plan) { case _: ShuffledHashJoinExec => true }.size === 1) ++ assert(collect(plan) { ++ case _: ShuffledHashJoinExec | _: CometHashJoinExec => true }.size === 1) + // No extra shuffle before aggregate +- assert(collect(plan) { case _: ShuffleExchangeExec => true }.size === 2) ++ assert(collect(plan) { ++ case _: ShuffleExchangeLike => true }.size === 2) + }) + } + +@@ -1133,10 +1137,11 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan + .join(df4.hint("SHUFFLE_MERGE"), $"k1" === $"k4", joinType) + .queryExecution + .executedPlan +- assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 2) ++ assert(collect(plan) { ++ case _: SortMergeJoinExec | _: CometSortMergeJoinExec => true }.size === 2) + assert(collect(plan) { case _: BroadcastHashJoinExec => true }.size === 1) + // No extra sort before last sort merge join +- assert(collect(plan) { case _: SortExec => true }.size === 3) ++ assert(collect(plan) { case _: SortExec | _: CometSortExec => true }.size === 3) + }) + + // Test shuffled hash join +@@ -1146,10 +1151,13 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan + .join(df4.hint("SHUFFLE_MERGE"), $"k1" === $"k4", joinType) + .queryExecution + .executedPlan +- assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 2) +- assert(collect(plan) { case _: ShuffledHashJoinExec => true }.size === 1) ++ assert(collect(plan) { ++ case _: SortMergeJoinExec | _: CometSortMergeJoinExec => true }.size === 2) ++ assert(collect(plan) { ++ case _: ShuffledHashJoinExec | _: CometHashJoinExec => true }.size === 1) + // No extra sort before last sort merge join +- assert(collect(plan) { case _: SortExec => true }.size === 3) ++ assert(collect(plan) { ++ case _: SortExec | _: CometSortExec => true }.size === 3) + }) + } + +@@ -1240,12 +1248,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan + inputDFs.foreach { case (df1, df2, joinExprs) => + val smjDF = df1.join(df2.hint("SHUFFLE_MERGE"), joinExprs, "full") + assert(collect(smjDF.queryExecution.executedPlan) { +- case _: SortMergeJoinExec => true }.size === 1) ++ case _: SortMergeJoinExec | _: CometSortMergeJoinExec => true }.size === 1) + val smjResult = smjDF.collect() + + val shjDF = df1.join(df2.hint("SHUFFLE_HASH"), joinExprs, "full") + assert(collect(shjDF.queryExecution.executedPlan) { +- case _: ShuffledHashJoinExec => true }.size === 1) ++ case _: ShuffledHashJoinExec | _: CometHashJoinExec => true }.size === 1) + // Same result between shuffled hash join and sort merge join + checkAnswer(shjDF, smjResult) + } +@@ -1340,7 +1348,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan + val plan = sql(getAggQuery(selectExpr, joinType)).queryExecution.executedPlan + assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1) + // Have shuffle before aggregation +- assert(collect(plan) { case _: ShuffleExchangeExec => true }.size === 1) ++ assert(collect(plan) { ++ case _: ShuffleExchangeLike => true }.size === 1) + } + + def getJoinQuery(selectExpr: String, joinType: String): String = { +@@ -1369,9 +1378,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan } val plan = sql(getJoinQuery(selectExpr, joinType)).queryExecution.executedPlan assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1) @@ -415,7 +664,7 @@ index 5125708be32..a1f1ae90796 100644 } // Test output ordering is not preserved -@@ -1380,9 +1384,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan +@@ -1380,9 +1392,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan val selectExpr = "/*+ BROADCAST(left_t) */ k1 as k0" val plan = sql(getJoinQuery(selectExpr, joinType)).queryExecution.executedPlan assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1) @@ -430,6 +679,16 @@ index 5125708be32..a1f1ae90796 100644 } // Test singe partition +@@ -1392,7 +1407,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan + |FROM range(0, 10, 1, 1) t1 FULL OUTER JOIN range(0, 10, 1, 1) t2 + |""".stripMargin) + val plan = fullJoinDF.queryExecution.executedPlan +- assert(collect(plan) { case _: ShuffleExchangeExec => true}.size == 1) ++ assert(collect(plan) { ++ case _: ShuffleExchangeLike => true}.size == 1) + checkAnswer(fullJoinDF, Row(100)) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala index b5b34922694..a72403780c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala @@ -443,20 +702,38 @@ index b5b34922694..a72403780c4 100644 protected val baseResourcePath = { // use the same way as `SQLQueryTestSuite` to get the resource path +diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +index 525d97e4998..8a3e7457618 100644 +--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +@@ -1508,7 +1508,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark + checkAnswer(sql("select -0.001"), Row(BigDecimal("-0.001"))) + } + +- test("external sorting updates peak execution memory") { ++ test("external sorting updates peak execution memory", ++ IgnoreComet("TODO: native CometSort does not update peak execution memory")) { + AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") { + sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC").collect() + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala -index 3cfda19134a..278bb1060c4 100644 +index 3cfda19134a..7590b808def 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala -@@ -21,6 +21,8 @@ import scala.collection.mutable.ArrayBuffer +@@ -21,10 +21,11 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Join, LogicalPlan, Project, Sort, Union} +import org.apache.spark.sql.comet.CometScanExec -+import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecution} import org.apache.spark.sql.execution.datasources.FileScanRDD -@@ -1543,6 +1545,12 @@ class SubquerySuite extends QueryTest +-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec ++import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike + import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec} + import org.apache.spark.sql.internal.SQLConf + import org.apache.spark.sql.test.SharedSparkSession +@@ -1543,6 +1544,12 @@ class SubquerySuite extends QueryTest fs.inputRDDs().forall( _.asInstanceOf[FileScanRDD].filePartitions.forall( _.files.forall(_.urlEncodedPath.contains("p=0")))) @@ -469,14 +746,78 @@ index 3cfda19134a..278bb1060c4 100644 case _ => false }) } -@@ -2109,6 +2117,7 @@ class SubquerySuite extends QueryTest +@@ -2108,7 +2115,7 @@ class SubquerySuite extends QueryTest + df.collect() val exchanges = collect(df.queryExecution.executedPlan) { - case s: ShuffleExchangeExec => s -+ case s: CometShuffleExchangeExec => s +- case s: ShuffleExchangeExec => s ++ case s: ShuffleExchangeLike => s } assert(exchanges.size === 1) } +diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +index 02990a7a40d..bddf5e1ccc2 100644 +--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +@@ -24,6 +24,7 @@ import test.org.apache.spark.sql.connector._ + + import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} + import org.apache.spark.sql.catalyst.InternalRow ++import org.apache.spark.sql.comet.CometSortExec + import org.apache.spark.sql.connector.catalog.{PartitionInternalRow, SupportsRead, Table, TableCapability, TableProvider} + import org.apache.spark.sql.connector.catalog.TableCapability._ + import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, Literal, NamedReference, NullOrdering, SortDirection, SortOrder, Transform} +@@ -33,7 +34,7 @@ import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, + import org.apache.spark.sql.execution.SortExec + import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper + import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation} +-import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} ++import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec, ShuffleExchangeLike} + import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector + import org.apache.spark.sql.expressions.Window + import org.apache.spark.sql.functions._ +@@ -268,13 +269,13 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS + val groupByColJ = df.groupBy($"j").agg(sum($"i")) + checkAnswer(groupByColJ, Seq(Row(2, 8), Row(4, 2), Row(6, 5))) + assert(collectFirst(groupByColJ.queryExecution.executedPlan) { +- case e: ShuffleExchangeExec => e ++ case e: ShuffleExchangeLike => e + }.isDefined) + + val groupByIPlusJ = df.groupBy($"i" + $"j").agg(count("*")) + checkAnswer(groupByIPlusJ, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1))) + assert(collectFirst(groupByIPlusJ.queryExecution.executedPlan) { +- case e: ShuffleExchangeExec => e ++ case e: ShuffleExchangeLike => e + }.isDefined) + } + } +@@ -334,10 +335,11 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS + + val (shuffleExpected, sortExpected) = groupByExpects + assert(collectFirst(groupBy.queryExecution.executedPlan) { +- case e: ShuffleExchangeExec => e ++ case e: ShuffleExchangeLike => e + }.isDefined === shuffleExpected) + assert(collectFirst(groupBy.queryExecution.executedPlan) { + case e: SortExec => e ++ case c: CometSortExec => c + }.isDefined === sortExpected) + } + +@@ -352,10 +354,11 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS + + val (shuffleExpected, sortExpected) = windowFuncExpects + assert(collectFirst(windowPartByColIOrderByColJ.queryExecution.executedPlan) { +- case e: ShuffleExchangeExec => e ++ case e: ShuffleExchangeLike => e + }.isDefined === shuffleExpected) + assert(collectFirst(windowPartByColIOrderByColJ.queryExecution.executedPlan) { + case e: SortExec => e ++ case c: CometSortExec => c + }.isDefined === sortExpected) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala index cfc8b2cc845..c6fcfd7bd08 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala @@ -502,6 +843,44 @@ index cfc8b2cc845..c6fcfd7bd08 100644 } } finally { spark.listenerManager.unregister(listener) +diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +index cf76f6ca32c..f454128af06 100644 +--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +@@ -22,6 +22,7 @@ import org.apache.spark.sql.{DataFrame, Row} + import org.apache.spark.sql.catalyst.InternalRow + import org.apache.spark.sql.catalyst.expressions.{Literal, TransformExpression} + import org.apache.spark.sql.catalyst.plans.physical ++import org.apache.spark.sql.comet.CometSortMergeJoinExec + import org.apache.spark.sql.connector.catalog.Identifier + import org.apache.spark.sql.connector.catalog.InMemoryTableCatalog + import org.apache.spark.sql.connector.catalog.functions._ +@@ -31,7 +32,7 @@ import org.apache.spark.sql.connector.expressions.Expressions._ + import org.apache.spark.sql.execution.SparkPlan + import org.apache.spark.sql.execution.datasources.v2.BatchScanExec + import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation +-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec ++import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike + import org.apache.spark.sql.execution.joins.SortMergeJoinExec + import org.apache.spark.sql.internal.SQLConf + import org.apache.spark.sql.internal.SQLConf._ +@@ -279,13 +280,14 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { + Row("bbb", 20, 250.0), Row("bbb", 20, 350.0), Row("ccc", 30, 400.50))) + } + +- private def collectShuffles(plan: SparkPlan): Seq[ShuffleExchangeExec] = { ++ private def collectShuffles(plan: SparkPlan): Seq[ShuffleExchangeLike] = { + // here we skip collecting shuffle operators that are not associated with SMJ + collect(plan) { + case s: SortMergeJoinExec => s ++ case c: CometSortMergeJoinExec => c.originalPlan + }.flatMap(smj => + collect(smj) { +- case s: ShuffleExchangeExec => s ++ case s: ShuffleExchangeLike => s + }) + } + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala index c0ec8a58bd5..4e8bc6ed3c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala @@ -547,6 +926,45 @@ index 418ca3430bb..eb8267192f8 100644 Seq("json", "orc", "parquet").foreach { format => withTempPath { path => val dir = path.getCanonicalPath +diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala +index 743ec41dbe7..9f30d6c8e04 100644 +--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala +@@ -53,6 +53,10 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite with DisableAdaptiv + case ColumnarToRowExec(i: InputAdapter) => isScanPlanTree(i.child) + case p: ProjectExec => isScanPlanTree(p.child) + case f: FilterExec => isScanPlanTree(f.child) ++ // Comet produces scan plan tree like: ++ // ColumnarToRow ++ // +- ReusedExchange ++ case _: ReusedExchangeExec => false + case _: LeafExecNode => true + case _ => false + } +diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +index 4b3d3a4b805..56e1e0e6f16 100644 +--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +@@ -18,7 +18,7 @@ + package org.apache.spark.sql.execution + + import org.apache.spark.rdd.RDD +-import org.apache.spark.sql.{execution, DataFrame, Row} ++import org.apache.spark.sql.{execution, DataFrame, IgnoreCometSuite, Row} + import org.apache.spark.sql.catalyst.InternalRow + import org.apache.spark.sql.catalyst.expressions._ + import org.apache.spark.sql.catalyst.plans._ +@@ -35,7 +35,9 @@ import org.apache.spark.sql.internal.SQLConf + import org.apache.spark.sql.test.SharedSparkSession + import org.apache.spark.sql.types._ + +-class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper { ++// Ignore this suite when Comet is enabled. This suite tests the Spark planner and Comet planner ++// comes out with too many difference. Simply ignoring this suite for now. ++class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper with IgnoreCometSuite { + import testImplicits._ + + setupTestData() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala index 9e9d717db3b..91a4f9a38d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala @@ -571,11 +989,108 @@ index 9e9d717db3b..91a4f9a38d5 100644 assert(actual == expected) } } +diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala +index 30ce940b032..0d3f6c6c934 100644 +--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala +@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution + + import org.apache.spark.sql.{DataFrame, QueryTest} + import org.apache.spark.sql.catalyst.plans.physical.{RangePartitioning, UnknownPartitioning} ++import org.apache.spark.sql.comet.CometSortExec + import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} + import org.apache.spark.sql.execution.joins.ShuffledJoin + import org.apache.spark.sql.internal.SQLConf +@@ -33,7 +34,7 @@ abstract class RemoveRedundantSortsSuiteBase + + private def checkNumSorts(df: DataFrame, count: Int): Unit = { + val plan = df.queryExecution.executedPlan +- assert(collectWithSubqueries(plan) { case s: SortExec => s }.length == count) ++ assert(collectWithSubqueries(plan) { case _: SortExec | _: CometSortExec => 1 }.length == count) + } + + private def checkSorts(query: String, enabledCount: Int, disabledCount: Int): Unit = { +diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala +index 47679ed7865..9ffbaecb98e 100644 +--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala +@@ -18,6 +18,7 @@ + package org.apache.spark.sql.execution + + import org.apache.spark.sql.{DataFrame, QueryTest} ++import org.apache.spark.sql.comet.CometHashAggregateExec + import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} + import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} + import org.apache.spark.sql.internal.SQLConf +@@ -31,7 +32,7 @@ abstract class ReplaceHashWithSortAggSuiteBase + private def checkNumAggs(df: DataFrame, hashAggCount: Int, sortAggCount: Int): Unit = { + val plan = df.queryExecution.executedPlan + assert(collectWithSubqueries(plan) { +- case s @ (_: HashAggregateExec | _: ObjectHashAggregateExec) => s ++ case s @ (_: HashAggregateExec | _: ObjectHashAggregateExec | _: CometHashAggregateExec ) => s + }.length == hashAggCount) + assert(collectWithSubqueries(plan) { case s: SortAggregateExec => s }.length == sortAggCount) + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala -index ac710c32296..37746bd470d 100644 +index ac710c32296..e163c1a6a76 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala -@@ -616,7 +616,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession +@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution + + import org.apache.spark.sql.{Dataset, QueryTest, Row, SaveMode} + import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats, CodeAndComment, CodeGenerator} ++import org.apache.spark.sql.comet.CometSortMergeJoinExec + import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite + import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} + import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec +@@ -224,6 +225,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession + assert(twoJoinsDF.queryExecution.executedPlan.collect { + case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true + case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true ++ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true + }.size === 2) + checkAnswer(twoJoinsDF, + Seq(Row(0, 0, 0), Row(1, 1, null), Row(2, 2, 2), Row(3, 3, null), Row(4, 4, null), +@@ -258,6 +260,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession + .join(df1.hint("SHUFFLE_MERGE"), $"k3" === $"k1", "right_outer") + assert(twoJoinsDF.queryExecution.executedPlan.collect { + case WholeStageCodegenExec(_ : SortMergeJoinExec) => true ++ case _: CometSortMergeJoinExec => true + }.size === 2) + checkAnswer(twoJoinsDF, + Seq(Row(0, 0, 0), Row(1, 1, 1), Row(2, 2, 2), Row(3, 3, 3), Row(4, null, 4), Row(5, null, 5), +@@ -280,8 +283,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession + val twoJoinsDF = df3.join(df2.hint("SHUFFLE_MERGE"), $"k3" === $"k2", "left_semi") + .join(df1.hint("SHUFFLE_MERGE"), $"k3" === $"k1", "left_semi") + assert(twoJoinsDF.queryExecution.executedPlan.collect { +- case WholeStageCodegenExec(ProjectExec(_, _ : SortMergeJoinExec)) | +- WholeStageCodegenExec(_ : SortMergeJoinExec) => true ++ case _: SortMergeJoinExec => true + }.size === 2) + checkAnswer(twoJoinsDF, Seq(Row(0), Row(1), Row(2), Row(3))) + } +@@ -302,8 +304,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession + val twoJoinsDF = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2", "left_anti") + .join(df3.hint("SHUFFLE_MERGE"), $"k1" === $"k3", "left_anti") + assert(twoJoinsDF.queryExecution.executedPlan.collect { +- case WholeStageCodegenExec(ProjectExec(_, _ : SortMergeJoinExec)) | +- WholeStageCodegenExec(_ : SortMergeJoinExec) => true ++ case _: SortMergeJoinExec => true + }.size === 2) + checkAnswer(twoJoinsDF, Seq(Row(6), Row(7), Row(8), Row(9))) + } +@@ -436,7 +437,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession + val plan = df.queryExecution.executedPlan + assert(plan.exists(p => + p.isInstanceOf[WholeStageCodegenExec] && +- p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortExec])) ++ p.asInstanceOf[WholeStageCodegenExec].collect { ++ case _: SortExec => true ++ }.nonEmpty)) + assert(df.collect() === Array(Row(1), Row(2), Row(3))) + } + +@@ -616,7 +619,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession .write.mode(SaveMode.Overwrite).parquet(path) withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "255", @@ -587,18 +1102,31 @@ index ac710c32296..37746bd470d 100644 val df = spark.read.parquet(path).selectExpr(projection: _*) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala -index 593bd7bb4ba..be1b82d0030 100644 +index 593bd7bb4ba..7ad55e3ab20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala -@@ -29,6 +29,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListe - import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy} +@@ -26,9 +26,11 @@ import org.scalatest.time.SpanSugar._ + + import org.apache.spark.SparkException + import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart} +-import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy} ++import org.apache.spark.sql.{Dataset, IgnoreComet, QueryTest, Row, SparkSession, Strategy} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} +import org.apache.spark.sql.comet._ ++import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, SparkPlanInfo, UnionExec} import org.apache.spark.sql.execution.aggregate.BaseAggregateExec import org.apache.spark.sql.execution.command.DataWritingCommandExec -@@ -116,6 +117,9 @@ class AdaptiveQueryExecSuite +@@ -104,6 +106,7 @@ class AdaptiveQueryExecSuite + private def findTopLevelBroadcastHashJoin(plan: SparkPlan): Seq[BroadcastHashJoinExec] = { + collect(plan) { + case j: BroadcastHashJoinExec => j ++ case j: CometBroadcastHashJoinExec => j.originalPlan.asInstanceOf[BroadcastHashJoinExec] + } + } + +@@ -116,30 +119,38 @@ class AdaptiveQueryExecSuite private def findTopLevelSortMergeJoin(plan: SparkPlan): Seq[SortMergeJoinExec] = { collect(plan) { case j: SortMergeJoinExec => j @@ -608,6 +1136,331 @@ index 593bd7bb4ba..be1b82d0030 100644 } } + private def findTopLevelShuffledHashJoin(plan: SparkPlan): Seq[ShuffledHashJoinExec] = { + collect(plan) { + case j: ShuffledHashJoinExec => j ++ case j: CometHashJoinExec => j.originalPlan.asInstanceOf[ShuffledHashJoinExec] + } + } + + private def findTopLevelBaseJoin(plan: SparkPlan): Seq[BaseJoinExec] = { + collect(plan) { + case j: BaseJoinExec => j ++ case c: CometHashJoinExec => c.originalPlan.asInstanceOf[BaseJoinExec] ++ case c: CometSortMergeJoinExec => c.originalPlan.asInstanceOf[BaseJoinExec] + } + } + + private def findTopLevelSort(plan: SparkPlan): Seq[SortExec] = { + collect(plan) { + case s: SortExec => s ++ case s: CometSortExec => s.originalPlan.asInstanceOf[SortExec] + } + } + + private def findTopLevelAggregate(plan: SparkPlan): Seq[BaseAggregateExec] = { + collect(plan) { + case agg: BaseAggregateExec => agg ++ case agg: CometHashAggregateExec => agg.originalPlan.asInstanceOf[BaseAggregateExec] + } + } + +@@ -176,6 +187,7 @@ class AdaptiveQueryExecSuite + val parts = rdd.partitions + assert(parts.forall(rdd.preferredLocations(_).nonEmpty)) + } ++ + assert(numShuffles === (numLocalReads.length + numShufflesWithoutLocalRead)) + } + +@@ -184,7 +196,7 @@ class AdaptiveQueryExecSuite + val plan = df.queryExecution.executedPlan + assert(plan.isInstanceOf[AdaptiveSparkPlanExec]) + val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect { +- case s: ShuffleExchangeExec => s ++ case s: ShuffleExchangeLike => s + } + assert(shuffle.size == 1) + assert(shuffle(0).outputPartitioning.numPartitions == numPartition) +@@ -200,7 +212,8 @@ class AdaptiveQueryExecSuite + assert(smj.size == 1) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) +- checkNumLocalShuffleReads(adaptivePlan) ++ // Comet shuffle changes shuffle metrics ++ // checkNumLocalShuffleReads(adaptivePlan) + } + } + +@@ -227,7 +240,8 @@ class AdaptiveQueryExecSuite + } + } + +- test("Reuse the parallelism of coalesced shuffle in local shuffle read") { ++ test("Reuse the parallelism of coalesced shuffle in local shuffle read", ++ IgnoreComet("Comet shuffle changes shuffle partition size")) { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", +@@ -259,7 +273,8 @@ class AdaptiveQueryExecSuite + } + } + +- test("Reuse the default parallelism in local shuffle read") { ++ test("Reuse the default parallelism in local shuffle read", ++ IgnoreComet("Comet shuffle changes shuffle partition size")) { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", +@@ -273,7 +288,8 @@ class AdaptiveQueryExecSuite + val localReads = collect(adaptivePlan) { + case read: AQEShuffleReadExec if read.isLocalRead => read + } +- assert(localReads.length == 2) ++ // Comet shuffle changes shuffle metrics ++ assert(localReads.length == 1) + val localShuffleRDD0 = localReads(0).execute().asInstanceOf[ShuffledRowRDD] + val localShuffleRDD1 = localReads(1).execute().asInstanceOf[ShuffledRowRDD] + // the final parallelism is math.max(1, numReduces / numMappers): math.max(1, 5/2) = 2 +@@ -322,7 +338,7 @@ class AdaptiveQueryExecSuite + } + } + +- test("Scalar subquery") { ++ test("Scalar subquery", IgnoreComet("Comet shuffle changes shuffle metrics")) { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { +@@ -337,7 +353,7 @@ class AdaptiveQueryExecSuite + } + } + +- test("Scalar subquery in later stages") { ++ test("Scalar subquery in later stages", IgnoreComet("Comet shuffle changes shuffle metrics")) { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { +@@ -353,7 +369,7 @@ class AdaptiveQueryExecSuite + } + } + +- test("multiple joins") { ++ test("multiple joins", IgnoreComet("Comet shuffle changes shuffle metrics")) { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { +@@ -398,7 +414,7 @@ class AdaptiveQueryExecSuite + } + } + +- test("multiple joins with aggregate") { ++ test("multiple joins with aggregate", IgnoreComet("Comet shuffle changes shuffle metrics")) { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { +@@ -443,7 +459,7 @@ class AdaptiveQueryExecSuite + } + } + +- test("multiple joins with aggregate 2") { ++ test("multiple joins with aggregate 2", IgnoreComet("Comet shuffle changes shuffle metrics")) { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") { +@@ -508,7 +524,7 @@ class AdaptiveQueryExecSuite + } + } + +- test("Exchange reuse with subqueries") { ++ test("Exchange reuse with subqueries", IgnoreComet("Comet shuffle changes shuffle metrics")) { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { +@@ -539,7 +555,9 @@ class AdaptiveQueryExecSuite + assert(smj.size == 1) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) +- checkNumLocalShuffleReads(adaptivePlan) ++ // Comet shuffle changes shuffle metrics, ++ // so we can't check the number of local shuffle reads. ++ // checkNumLocalShuffleReads(adaptivePlan) + // Even with local shuffle read, the query stage reuse can also work. + val ex = findReusedExchange(adaptivePlan) + assert(ex.nonEmpty) +@@ -560,7 +578,9 @@ class AdaptiveQueryExecSuite + assert(smj.size == 1) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) +- checkNumLocalShuffleReads(adaptivePlan) ++ // Comet shuffle changes shuffle metrics, ++ // so we can't check the number of local shuffle reads. ++ // checkNumLocalShuffleReads(adaptivePlan) + // Even with local shuffle read, the query stage reuse can also work. + val ex = findReusedExchange(adaptivePlan) + assert(ex.isEmpty) +@@ -569,7 +589,8 @@ class AdaptiveQueryExecSuite + } + } + +- test("Broadcast exchange reuse across subqueries") { ++ test("Broadcast exchange reuse across subqueries", ++ IgnoreComet("Comet shuffle changes shuffle metrics")) { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "20000000", +@@ -664,7 +685,8 @@ class AdaptiveQueryExecSuite + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) + // There is still a SMJ, and its two shuffles can't apply local read. +- checkNumLocalShuffleReads(adaptivePlan, 2) ++ // Comet shuffle changes shuffle metrics ++ // checkNumLocalShuffleReads(adaptivePlan, 2) + } + } + +@@ -786,7 +808,8 @@ class AdaptiveQueryExecSuite + } + } + +- test("SPARK-29544: adaptive skew join with different join types") { ++ test("SPARK-29544: adaptive skew join with different join types", ++ IgnoreComet("Comet shuffle has different partition metrics")) { + Seq("SHUFFLE_MERGE", "SHUFFLE_HASH").foreach { joinHint => + def getJoinNode(plan: SparkPlan): Seq[ShuffledJoin] = if (joinHint == "SHUFFLE_MERGE") { + findTopLevelSortMergeJoin(plan) +@@ -1004,7 +1027,8 @@ class AdaptiveQueryExecSuite + } + } + +- test("metrics of the shuffle read") { ++ test("metrics of the shuffle read", ++ IgnoreComet("Comet shuffle changes the metrics")) { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + val (_, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT key FROM testData GROUP BY key") +@@ -1599,7 +1623,7 @@ class AdaptiveQueryExecSuite + val (_, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT id FROM v1 GROUP BY id DISTRIBUTE BY id") + assert(collect(adaptivePlan) { +- case s: ShuffleExchangeExec => s ++ case s: ShuffleExchangeLike => s + }.length == 1) + } + } +@@ -1679,7 +1703,8 @@ class AdaptiveQueryExecSuite + } + } + +- test("SPARK-33551: Do not use AQE shuffle read for repartition") { ++ test("SPARK-33551: Do not use AQE shuffle read for repartition", ++ IgnoreComet("Comet shuffle changes partition size")) { + def hasRepartitionShuffle(plan: SparkPlan): Boolean = { + find(plan) { + case s: ShuffleExchangeLike => +@@ -1864,6 +1889,9 @@ class AdaptiveQueryExecSuite + def checkNoCoalescePartitions(ds: Dataset[Row], origin: ShuffleOrigin): Unit = { + assert(collect(ds.queryExecution.executedPlan) { + case s: ShuffleExchangeExec if s.shuffleOrigin == origin && s.numPartitions == 2 => s ++ case c: CometShuffleExchangeExec ++ if c.originalPlan.shuffleOrigin == origin && ++ c.originalPlan.numPartitions == 2 => c + }.size == 1) + ds.collect() + val plan = ds.queryExecution.executedPlan +@@ -1872,6 +1900,9 @@ class AdaptiveQueryExecSuite + }.isEmpty) + assert(collect(plan) { + case s: ShuffleExchangeExec if s.shuffleOrigin == origin && s.numPartitions == 2 => s ++ case c: CometShuffleExchangeExec ++ if c.originalPlan.shuffleOrigin == origin && ++ c.originalPlan.numPartitions == 2 => c + }.size == 1) + checkAnswer(ds, testData) + } +@@ -2028,7 +2059,8 @@ class AdaptiveQueryExecSuite + } + } + +- test("SPARK-35264: Support AQE side shuffled hash join formula") { ++ test("SPARK-35264: Support AQE side shuffled hash join formula", ++ IgnoreComet("Comet shuffle changes the partition size")) { + withTempView("t1", "t2") { + def checkJoinStrategy(shouldShuffleHashJoin: Boolean): Unit = { + Seq("100", "100000").foreach { size => +@@ -2114,7 +2146,8 @@ class AdaptiveQueryExecSuite + } + } + +- test("SPARK-35725: Support optimize skewed partitions in RebalancePartitions") { ++ test("SPARK-35725: Support optimize skewed partitions in RebalancePartitions", ++ IgnoreComet("Comet shuffle changes shuffle metrics")) { + withTempView("v") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", +@@ -2213,7 +2246,7 @@ class AdaptiveQueryExecSuite + runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM skewData1 " + + s"JOIN skewData2 ON key1 = key2 GROUP BY key1") + val shuffles1 = collect(adaptive1) { +- case s: ShuffleExchangeExec => s ++ case s: ShuffleExchangeLike => s + } + assert(shuffles1.size == 3) + // shuffles1.head is the top-level shuffle under the Aggregate operator +@@ -2226,7 +2259,7 @@ class AdaptiveQueryExecSuite + runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM skewData1 " + + s"JOIN skewData2 ON key1 = key2") + val shuffles2 = collect(adaptive2) { +- case s: ShuffleExchangeExec => s ++ case s: ShuffleExchangeLike => s + } + if (hasRequiredDistribution) { + assert(shuffles2.size == 3) +@@ -2260,7 +2293,8 @@ class AdaptiveQueryExecSuite + } + } + +- test("SPARK-35794: Allow custom plugin for cost evaluator") { ++ test("SPARK-35794: Allow custom plugin for cost evaluator", ++ IgnoreComet("Comet shuffle changes shuffle metrics")) { + CostEvaluator.instantiate( + classOf[SimpleShuffleSortCostEvaluator].getCanonicalName, spark.sparkContext.getConf) + intercept[IllegalArgumentException] { +@@ -2404,6 +2438,7 @@ class AdaptiveQueryExecSuite + val (_, adaptive) = runAdaptiveAndVerifyResult(query) + assert(adaptive.collect { + case sort: SortExec => sort ++ case sort: CometSortExec => sort + }.size == 1) + val read = collect(adaptive) { + case read: AQEShuffleReadExec => read +@@ -2421,7 +2456,8 @@ class AdaptiveQueryExecSuite + } + } + +- test("SPARK-37357: Add small partition factor for rebalance partitions") { ++ test("SPARK-37357: Add small partition factor for rebalance partitions", ++ IgnoreComet("Comet shuffle changes shuffle metrics")) { + withTempView("v") { + withSQLConf( + SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key -> "true", +@@ -2533,7 +2569,7 @@ class AdaptiveQueryExecSuite + runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " + + "JOIN skewData3 ON value2 = value3") + val shuffles1 = collect(adaptive1) { +- case s: ShuffleExchangeExec => s ++ case s: ShuffleExchangeLike => s + } + assert(shuffles1.size == 4) + val smj1 = findTopLevelSortMergeJoin(adaptive1) +@@ -2544,7 +2580,7 @@ class AdaptiveQueryExecSuite + runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " + + "JOIN skewData3 ON value1 = value3") + val shuffles2 = collect(adaptive2) { +- case s: ShuffleExchangeExec => s ++ case s: ShuffleExchangeLike => s + } + assert(shuffles2.size == 4) + val smj2 = findTopLevelSortMergeJoin(adaptive2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala index bd9c79e5b96..ab7584e768e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala @@ -628,6 +1481,29 @@ index bd9c79e5b96..ab7584e768e 100644 } assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size, s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " + +diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala +index ce43edb79c1..c414b19eda7 100644 +--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala +@@ -17,7 +17,7 @@ + + package org.apache.spark.sql.execution.datasources + +-import org.apache.spark.sql.{QueryTest, Row} ++import org.apache.spark.sql.{IgnoreComet, QueryTest, Row} + import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, NullsFirst, SortOrder} + import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sort} + import org.apache.spark.sql.execution.{QueryExecution, SortExec} +@@ -305,7 +305,8 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write + } + } + +- test("v1 write with AQE changing SMJ to BHJ") { ++ test("v1 write with AQE changing SMJ to BHJ", ++ IgnoreComet("TODO: Comet SMJ to BHJ by AQE")) { + withPlannedWrite { enabled => + withTable("t") { + sql( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala index 1d2e467c94c..3ea82cd1a3f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala @@ -917,10 +1793,22 @@ index 3a0bd35cb70..b28f06a757f 100644 val workDirPath = workDir.getAbsolutePath val input = spark.range(5).toDF("id") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala -index 26e61c6b58d..cde10983c68 100644 +index 26e61c6b58d..cb09d7e116a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala -@@ -737,7 +737,8 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils +@@ -45,8 +45,10 @@ import org.apache.spark.sql.util.QueryExecutionListener + import org.apache.spark.util.{AccumulatorContext, JsonProtocol} + + // Disable AQE because metric info is different with AQE on/off ++// This test suite runs tests against the metrics of physical operators. ++// Disabling it for Comet because the metrics are different with Comet enabled. + class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils +- with DisableAdaptiveExecutionSuite { ++ with DisableAdaptiveExecutionSuite with IgnoreCometSuite { + import testImplicits._ + + /** +@@ -737,7 +739,8 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils } } @@ -1002,21 +1890,24 @@ index d083cac48ff..3c11bcde807 100644 import testImplicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala -index 266bb343526..b33bb677f0d 100644 +index 266bb343526..a426d8396be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala -@@ -24,7 +24,9 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec +@@ -24,10 +24,11 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning -import org.apache.spark.sql.execution.{FileSourceScanExec, SortExec, SparkPlan} +import org.apache.spark.sql.comet._ -+import org.apache.spark.sql.comet.execution.shuffle._ +import org.apache.spark.sql.execution.{ColumnarToRowExec, FileSourceScanExec, SortExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, DisableAdaptiveExecution} import org.apache.spark.sql.execution.datasources.BucketingUtils - import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec -@@ -101,12 +103,20 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti +-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec ++import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike} + import org.apache.spark.sql.execution.joins.SortMergeJoinExec + import org.apache.spark.sql.functions._ + import org.apache.spark.sql.internal.SQLConf +@@ -101,12 +102,20 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti } } @@ -1039,7 +1930,7 @@ index 266bb343526..b33bb677f0d 100644 // To verify if the bucket pruning works, this function checks two conditions: // 1) Check if the pruned buckets (before filtering) are empty. // 2) Verify the final result is the same as the expected one -@@ -155,7 +165,8 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti +@@ -155,7 +164,8 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti val planWithoutBucketedScan = bucketedDataFrame.filter(filterCondition) .queryExecution.executedPlan val fileScan = getFileScan(planWithoutBucketedScan) @@ -1049,7 +1940,7 @@ index 266bb343526..b33bb677f0d 100644 val bucketColumnType = bucketedDataFrame.schema.apply(bucketColumnIndex).dataType val rowsWithInvalidBuckets = fileScan.execute().filter(row => { -@@ -451,28 +462,46 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti +@@ -451,28 +461,44 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti val joinOperator = if (joined.sqlContext.conf.adaptiveExecutionEnabled) { val executedPlan = joined.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan @@ -1082,13 +1973,11 @@ index 266bb343526..b33bb677f0d 100644 // check existence of shuffle assert( - joinOperator.left.exists(_.isInstanceOf[ShuffleExchangeExec]) == shuffleLeft, -+ joinOperator.left.exists(op => op.isInstanceOf[ShuffleExchangeExec] || -+ op.isInstanceOf[CometShuffleExchangeExec]) == shuffleLeft, ++ joinOperator.left.exists(op => op.isInstanceOf[ShuffleExchangeLike]) == shuffleLeft, s"expected shuffle in plan to be $shuffleLeft but found\n${joinOperator.left}") assert( - joinOperator.right.exists(_.isInstanceOf[ShuffleExchangeExec]) == shuffleRight, -+ joinOperator.right.exists(op => op.isInstanceOf[ShuffleExchangeExec] || -+ op.isInstanceOf[CometShuffleExchangeExec]) == shuffleRight, ++ joinOperator.right.exists(op => op.isInstanceOf[ShuffleExchangeLike]) == shuffleRight, s"expected shuffle in plan to be $shuffleRight but found\n${joinOperator.right}") // check existence of sort @@ -1104,7 +1993,7 @@ index 266bb343526..b33bb677f0d 100644 s"expected sort in the right child to be $sortRight but found\n${joinOperator.right}") // check the output partitioning -@@ -835,11 +864,11 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti +@@ -835,11 +861,11 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table") val scanDF = spark.table("bucketed_table").select("j") @@ -1118,14 +2007,13 @@ index 266bb343526..b33bb677f0d 100644 checkAnswer(aggDF, df1.groupBy("j").agg(max("k"))) } } -@@ -1026,15 +1055,24 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti +@@ -1026,15 +1052,23 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti expectedNumShuffles: Int, expectedCoalescedNumBuckets: Option[Int]): Unit = { val plan = sql(query).queryExecution.executedPlan - val shuffles = plan.collect { case s: ShuffleExchangeExec => s } + val shuffles = plan.collect { -+ case s: ShuffleExchangeExec => s -+ case s: CometShuffleExchangeExec => s ++ case s: ShuffleExchangeLike => s + } assert(shuffles.length == expectedNumShuffles) @@ -1303,6 +2191,120 @@ index 2a2a83d35e1..e3b7b290b3e 100644 val initialStateDS = Seq(("keyInStateAndData", new RunningCount(1))).toDS() val initialState: KeyValueGroupedDataset[String, RunningCount] = initialStateDS.groupByKey(_._1).mapValues(_._2) +diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +index ef5b8a769fe..84fe1bfabc9 100644 +--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +@@ -37,6 +37,7 @@ import org.apache.spark.sql._ + import org.apache.spark.sql.catalyst.plans.logical.{Range, RepartitionByExpression} + import org.apache.spark.sql.catalyst.streaming.{InternalOutputModes, StreamingRelationV2} + import org.apache.spark.sql.catalyst.util.DateTimeUtils ++import org.apache.spark.sql.comet.CometLocalLimitExec + import org.apache.spark.sql.execution.{LocalLimitExec, SimpleMode, SparkPlan} + import org.apache.spark.sql.execution.command.ExplainCommand + import org.apache.spark.sql.execution.streaming._ +@@ -1103,11 +1104,12 @@ class StreamSuite extends StreamTest { + val localLimits = execPlan.collect { + case l: LocalLimitExec => l + case l: StreamingLocalLimitExec => l ++ case l: CometLocalLimitExec => l + } + + require( + localLimits.size == 1, +- s"Cant verify local limit optimization with this plan:\n$execPlan") ++ s"Cant verify local limit optimization ${localLimits.size} with this plan:\n$execPlan") + + if (expectStreamingLimit) { + assert( +@@ -1115,7 +1117,8 @@ class StreamSuite extends StreamTest { + s"Local limit was not StreamingLocalLimitExec:\n$execPlan") + } else { + assert( +- localLimits.head.isInstanceOf[LocalLimitExec], ++ localLimits.head.isInstanceOf[LocalLimitExec] || ++ localLimits.head.isInstanceOf[CometLocalLimitExec], + s"Local limit was not LocalLimitExec:\n$execPlan") + } + } +diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala +index b4c4ec7acbf..20579284856 100644 +--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala +@@ -23,6 +23,7 @@ import org.apache.commons.io.FileUtils + import org.scalatest.Assertions + + import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution ++import org.apache.spark.sql.comet.CometHashAggregateExec + import org.apache.spark.sql.execution.aggregate.BaseAggregateExec + import org.apache.spark.sql.execution.streaming.{MemoryStream, StateStoreRestoreExec, StateStoreSaveExec} + import org.apache.spark.sql.functions.count +@@ -67,6 +68,7 @@ class StreamingAggregationDistributionSuite extends StreamTest + // verify aggregations in between, except partial aggregation + val allAggregateExecs = query.lastExecution.executedPlan.collect { + case a: BaseAggregateExec => a ++ case c: CometHashAggregateExec => c.originalPlan + } + + val aggregateExecsWithoutPartialAgg = allAggregateExecs.filter { +@@ -201,6 +203,7 @@ class StreamingAggregationDistributionSuite extends StreamTest + // verify aggregations in between, except partial aggregation + val allAggregateExecs = executedPlan.collect { + case a: BaseAggregateExec => a ++ case c: CometHashAggregateExec => c.originalPlan + } + + val aggregateExecsWithoutPartialAgg = allAggregateExecs.filter { +diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +index 4d92e270539..33f1c2eb75e 100644 +--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +@@ -31,7 +31,7 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation + import org.apache.spark.sql.{DataFrame, Row, SparkSession} + import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} + import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning +-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec ++import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike + import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinExec, StreamingSymmetricHashJoinHelper} + import org.apache.spark.sql.execution.streaming.state.{RocksDBStateStoreProvider, StateStore, StateStoreProviderId} + import org.apache.spark.sql.functions._ +@@ -619,14 +619,28 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { + + val numPartitions = spark.sqlContext.conf.getConf(SQLConf.SHUFFLE_PARTITIONS) + +- assert(query.lastExecution.executedPlan.collect { +- case j @ StreamingSymmetricHashJoinExec(_, _, _, _, _, _, _, _, _, +- ShuffleExchangeExec(opA: HashPartitioning, _, _), +- ShuffleExchangeExec(opB: HashPartitioning, _, _)) +- if partitionExpressionsColumns(opA.expressions) === Seq("a", "b") +- && partitionExpressionsColumns(opB.expressions) === Seq("a", "b") +- && opA.numPartitions == numPartitions && opB.numPartitions == numPartitions => j +- }.size == 1) ++ val join = query.lastExecution.executedPlan.collect { ++ case j: StreamingSymmetricHashJoinExec => j ++ }.head ++ val opA = join.left.collect { ++ case s: ShuffleExchangeLike ++ if s.outputPartitioning.isInstanceOf[HashPartitioning] && ++ partitionExpressionsColumns( ++ s.outputPartitioning ++ .asInstanceOf[HashPartitioning].expressions) === Seq("a", "b") => ++ s.outputPartitioning ++ .asInstanceOf[HashPartitioning] ++ }.head ++ val opB = join.right.collect { ++ case s: ShuffleExchangeLike ++ if s.outputPartitioning.isInstanceOf[HashPartitioning] && ++ partitionExpressionsColumns( ++ s.outputPartitioning ++ .asInstanceOf[HashPartitioning].expressions) === Seq("a", "b") => ++ s.outputPartitioning ++ .asInstanceOf[HashPartitioning] ++ }.head ++ assert(opA.numPartitions == numPartitions && opB.numPartitions == numPartitions) + }) + } + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala index abe606ad9c1..2d930b64cca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala @@ -1327,7 +2329,7 @@ index abe606ad9c1..2d930b64cca 100644 val tblTargetName = "tbl_target" val tblSourceQualified = s"default.$tblSourceName" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala -index dd55fcfe42c..b4776c50e49 100644 +index dd55fcfe42c..293e9dc2986 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTest @@ -1364,20 +2366,20 @@ index dd55fcfe42c..b4776c50e49 100644 + } + + /** -+ * Whether Spark should only apply Comet scan optimization. This is only effective when ++ * Whether to enable ansi mode This is only effective when + * [[isCometEnabled]] returns true. + */ -+ protected def isCometScanOnly: Boolean = { -+ val v = System.getenv("ENABLE_COMET_SCAN_ONLY") ++ protected def enableCometAnsiMode: Boolean = { ++ val v = System.getenv("ENABLE_COMET_ANSI_MODE") + v != null && v.toBoolean + } + + /** -+ * Whether to enable ansi mode This is only effective when ++ * Whether Spark should only apply Comet scan optimization. This is only effective when + * [[isCometEnabled]] returns true. + */ -+ protected def enableCometAnsiMode: Boolean = { -+ val v = System.getenv("ENABLE_COMET_ANSI_MODE") ++ protected def isCometScanOnly: Boolean = { ++ val v = System.getenv("ENABLE_COMET_SCAN_ONLY") + v != null && v.toBoolean + } + @@ -1394,7 +2396,7 @@ index dd55fcfe42c..b4776c50e49 100644 spark.internalCreateDataFrame(withoutFilters.execute(), schema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala -index ed2e309fa07..f64cc283903 100644 +index ed2e309fa07..e071fc44960 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -74,6 +74,28 @@ trait SharedSparkSessionBase @@ -1414,6 +2416,7 @@ index ed2e309fa07..f64cc283903 100644 + .set("spark.shuffle.manager", + "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager") + .set("spark.comet.exec.shuffle.enabled", "true") ++ .set("spark.comet.memoryOverhead", "10g") + } + + if (enableCometAnsiMode) { @@ -1421,11 +2424,23 @@ index ed2e309fa07..f64cc283903 100644 + .set("spark.sql.ansi.enabled", "true") + .set("spark.comet.ansi.enabled", "true") + } -+ + } conf.set( StaticSQLConf.WAREHOUSE_PATH, conf.get(StaticSQLConf.WAREHOUSE_PATH) + "/" + getClass.getCanonicalName) +diff --git a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala +index 1510e8957f9..7618419d8ff 100644 +--- a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala +@@ -43,7 +43,7 @@ class SqlResourceWithActualMetricsSuite + import testImplicits._ + + // Exclude nodes which may not have the metrics +- val excludedNodes = List("WholeStageCodegen", "Project", "SerializeFromObject") ++ val excludedNodes = List("WholeStageCodegen", "Project", "SerializeFromObject", "RowToColumnar") + + implicit val formats = new DefaultFormats { + override def dateFormatter = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/DynamicPartitionPruningHiveScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/DynamicPartitionPruningHiveScanSuite.scala index 52abd248f3a..7a199931a08 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/DynamicPartitionPruningHiveScanSuite.scala @@ -1463,10 +2478,10 @@ index 1966e1e64fd..cde97a0aafe 100644 spark.sql( """ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala -index 07361cfdce9..1763168a808 100644 +index 07361cfdce9..25b0dc3ef7e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala -@@ -55,25 +55,54 @@ object TestHive +@@ -55,25 +55,53 @@ object TestHive new SparkContext( System.getProperty("spark.sql.test.master", "local[1]"), "TestSQLContext", @@ -1530,9 +2545,8 @@ index 07361cfdce9..1763168a808 100644 + .set("spark.sql.ansi.enabled", "true") + .set("spark.comet.ansi.enabled", "true") + } - + } -+ + + conf + } + )) diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 02ecbd693..22a7a0982 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -29,7 +29,7 @@ Comet provides the following configuration settings. | spark.comet.columnar.shuffle.async.enabled | Whether to enable asynchronous shuffle for Arrow-based shuffle. By default, this config is false. | false | | spark.comet.columnar.shuffle.async.max.thread.num | Maximum number of threads on an executor used for Comet async columnar shuffle. By default, this config is 100. This is the upper bound of total number of shuffle threads per executor. In other words, if the number of cores * the number of shuffle threads per task `spark.comet.columnar.shuffle.async.thread.num` is larger than this config. Comet will use this config as the number of shuffle threads per executor instead. | 100 | | spark.comet.columnar.shuffle.async.thread.num | Number of threads used for Comet async columnar shuffle per shuffle task. By default, this config is 3. Note that more threads means more memory requirement to buffer shuffle data before flushing to disk. Also, more threads may not always improve performance, and should be set based on the number of cores available. | 3 | -| spark.comet.columnar.shuffle.enabled | Force Comet to only use columnar shuffle for CometScan and Spark regular operators. If this is enabled, Comet native shuffle will not be enabled but only Arrow shuffle. By default, this config is false. | false | +| spark.comet.columnar.shuffle.enabled | Whether to enable Arrow-based columnar shuffle for Comet and Spark regular operators. If this is enabled, Comet prefers columnar shuffle than native shuffle. By default, this config is true. | true | | spark.comet.columnar.shuffle.memory.factor | Fraction of Comet memory to be allocated per executor process for Comet shuffle. Comet memory size is specified by `spark.comet.memoryOverhead` or calculated by `spark.comet.memory.overhead.factor` * `spark.executor.memory`. By default, this config is 1.0. | 1.0 | | spark.comet.debug.enabled | Whether to enable debug mode for Comet. By default, this config is false. When enabled, Comet will do additional checks for debugging purpose. For example, validating array when importing arrays from JVM at native side. Note that these checks may be expensive in performance and should only be enabled for debugging purpose. | false | | spark.comet.enabled | Whether to enable Comet extension for Spark. When this is turned on, Spark will use Comet to read Parquet data source. Note that to enable native vectorized execution, both this config and 'spark.comet.exec.enabled' need to be enabled. By default, this config is the value of the env var `ENABLE_COMET` if set, or true otherwise. | true | diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index e77adc9bb..63b23ba1e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2479,7 +2479,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { withInfo(join, "SortMergeJoin is not enabled") None - case op if isCometSink(op) => + case op if isCometSink(op) && op.output.forall(a => supportedDataType(a.dataType)) => // These operators are source of Comet native execution chain val scanBuilder = OperatorOuterClass.Scan.newBuilder() diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/plans/AliasAwareOutputExpression.scala b/spark/src/main/scala/org/apache/spark/sql/comet/plans/AliasAwareOutputExpression.scala index 6e5b44c8f..996526e55 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/plans/AliasAwareOutputExpression.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/plans/AliasAwareOutputExpression.scala @@ -31,7 +31,8 @@ import org.apache.spark.sql.catalyst.trees.CurrentOrigin trait AliasAwareOutputExpression extends SQLConfHelper { // `SQLConf.EXPRESSION_PROJECTION_CANDIDATE_LIMIT` is Spark 3.4+ only. // Use a default value for now. - protected val aliasCandidateLimit = 100 + protected val aliasCandidateLimit: Int = + conf.getConfString("spark.sql.optimizer.expressionProjectionCandidateLimit", "100").toInt protected def outputExpressions: Seq[NamedExpression] /** diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala b/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala index cdbd7194d..7a07f8629 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala @@ -152,8 +152,9 @@ class CometTPCDSQuerySuite conf.set(CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key, "true") conf.set(CometConf.COMET_EXEC_ALL_EXPR_ENABLED.key, "true") conf.set(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, "true") + conf.set(CometConf.COMET_MEMORY_OVERHEAD.key, "20g") conf.set(MEMORY_OFFHEAP_ENABLED.key, "true") - conf.set(MEMORY_OFFHEAP_SIZE.key, "2g") + conf.set(MEMORY_OFFHEAP_SIZE.key, "20g") conf }