diff --git a/.github/workflows/pr_build.yml b/.github/workflows/pr_build.yml index 71eb02a9e..1e347250f 100644 --- a/.github/workflows/pr_build.yml +++ b/.github/workflows/pr_build.yml @@ -47,13 +47,14 @@ jobs: java_version: [8, 11, 17] test-target: [rust, java] spark-version: ['3.4'] + scala-version: ['2.12', '2.13'] is_push_event: - ${{ github.event_name == 'push' }} exclude: # exclude java 11 for pull_request event - java_version: 11 is_push_event: false fail-fast: false - name: ${{ matrix.os }}/java ${{ matrix.java_version }}-spark-${{matrix.spark-version}}/${{ matrix.test-target }} + name: ${{ matrix.os }}/java ${{ matrix.java_version }}-spark-${{matrix.spark-version}}-scala-${{matrix.scala-version}}/${{ matrix.test-target }} runs-on: ${{ matrix.os }} container: image: amd64/rust @@ -71,7 +72,7 @@ jobs: name: Java test steps uses: ./.github/actions/java-test with: - maven_opts: -Pspark-${{ matrix.spark-version }} + maven_opts: -Pspark-${{ matrix.spark-version }},scala-${{ matrix.scala-version }} # upload test reports only for java 17 upload-test-reports: ${{ matrix.java_version == '17' }} @@ -82,13 +83,16 @@ jobs: java_version: [8, 11, 17] test-target: [java] spark-version: ['3.2', '3.3'] + scala-version: ['2.12', '2.13'] exclude: - java_version: 17 spark-version: '3.2' - java_version: 11 spark-version: '3.2' + - spark-version: '3.2' + scala-version: '2.13' fail-fast: false - name: ${{ matrix.os }}/java ${{ matrix.java_version }}-spark-${{matrix.spark-version}}/${{ matrix.test-target }} + name: ${{ matrix.os }}/java ${{ matrix.java_version }}-spark-${{matrix.spark-version}}-scala-${{matrix.scala-version}}/${{ matrix.test-target }} runs-on: ${{ matrix.os }} container: image: amd64/rust @@ -102,7 +106,7 @@ jobs: - name: Java test steps uses: ./.github/actions/java-test with: - maven_opts: -Pspark-${{ matrix.spark-version }} + maven_opts: -Pspark-${{ matrix.spark-version }},scala-${{ matrix.scala-version }} macos-test: strategy: @@ -111,9 +115,10 @@ jobs: java_version: [8, 11, 17] test-target: [rust, java] spark-version: ['3.4'] + scala-version: ['2.12', '2.13'] fail-fast: false if: github.event_name == 'push' - name: ${{ matrix.os }}/java ${{ matrix.java_version }}-spark-${{matrix.spark-version}}/${{ matrix.test-target }} + name: ${{ matrix.os }}/java ${{ matrix.java_version }}-spark-${{matrix.spark-version}}-scala-${{matrix.scala-version}}/${{ matrix.test-target }} runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 @@ -129,7 +134,7 @@ jobs: name: Java test steps uses: ./.github/actions/java-test with: - maven_opts: -Pspark-${{ matrix.spark-version }} + maven_opts: -Pspark-${{ matrix.spark-version }},scala-${{ matrix.scala-version }} macos-aarch64-test: strategy: @@ -137,13 +142,14 @@ jobs: java_version: [8, 11, 17] test-target: [rust, java] spark-version: ['3.4'] + scala-version: ['2.12', '2.13'] is_push_event: - ${{ github.event_name == 'push' }} exclude: # exclude java 11 for pull_request event - java_version: 11 is_push_event: false fail-fast: false - name: macos-14(Silicon)/java ${{ matrix.java_version }}-spark-${{matrix.spark-version}}/${{ matrix.test-target }} + name: macos-14(Silicon)/java ${{ matrix.java_version }}-spark-${{matrix.spark-version}}-scala-${{matrix.scala-version}}/${{ matrix.test-target }} runs-on: macos-14 steps: - uses: actions/checkout@v4 @@ -161,7 +167,7 @@ jobs: name: Java test steps uses: ./.github/actions/java-test with: - maven_opts: -Pspark-${{ matrix.spark-version }} + maven_opts: -Pspark-${{ matrix.spark-version }},scala-${{ 
matrix.scala-version }} macos-aarch64-test-with-old-spark: strategy: @@ -169,13 +175,16 @@ jobs: java_version: [8, 17] test-target: [java] spark-version: ['3.2', '3.3'] + scala-version: ['2.12', '2.13'] exclude: - java_version: 17 spark-version: '3.2' - java_version: 8 spark-version: '3.3' + - spark-version: '3.2' + scala-version: '2.13' fail-fast: false - name: macos-14(Silicon)/java ${{ matrix.java_version }}-spark-${{matrix.spark-version}}/${{ matrix.test-target }} + name: macos-14(Silicon)/java ${{ matrix.java_version }}-spark-${{matrix.spark-version}}-scala-${{matrix.scala-version}}/${{ matrix.test-target }} runs-on: macos-14 steps: - uses: actions/checkout@v4 @@ -190,5 +199,5 @@ jobs: name: Java test steps uses: ./.github/actions/java-test with: - maven_opts: -Pspark-${{ matrix.spark-version }} + maven_opts: -Pspark-${{ matrix.spark-version }},scala-${{ matrix.scala-version }} diff --git a/common/pom.xml b/common/pom.xml index 540101d71..cc1f44481 100644 --- a/common/pom.xml +++ b/common/pom.xml @@ -179,7 +179,8 @@ under the License. - src/main/${shims.source} + src/main/${shims.majorVerSrc} + src/main/${shims.minorVerSrc} diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index f2a793b31..ba8301ded 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -140,14 +140,14 @@ object CometConf { .booleanConf .createWithDefault(false) - val COMET_COLUMNAR_SHUFFLE_ENABLED: ConfigEntry[Boolean] = conf( - "spark.comet.columnar.shuffle.enabled") - .doc( - "Force Comet to only use columnar shuffle for CometScan and Spark regular operators. " + - "If this is enabled, Comet native shuffle will not be enabled but only Arrow shuffle. " + - "By default, this config is false.") - .booleanConf - .createWithDefault(false) + val COMET_COLUMNAR_SHUFFLE_ENABLED: ConfigEntry[Boolean] = + conf("spark.comet.columnar.shuffle.enabled") + .doc( + "Whether to enable Arrow-based columnar shuffle for Comet and Spark regular operators. " + + "If this is enabled, Comet prefers columnar shuffle than native shuffle. 
" + + "By default, this config is true.") + .booleanConf + .createWithDefault(true) val COMET_SHUFFLE_ENFORCE_MODE_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.shuffle.enforceMode.enabled") diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala index eb731f9d0..595c0a427 100644 --- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala +++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala @@ -66,7 +66,7 @@ class NativeUtil { val arrowArray = ArrowArray.allocateNew(allocator) Data.exportVector( allocator, - getFieldVector(valueVector), + getFieldVector(valueVector, "export"), provider, arrowArray, arrowSchema) diff --git a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala index 7d920e1be..2300e109a 100644 --- a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala +++ b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala @@ -242,7 +242,7 @@ object Utils { } } - getFieldVector(valueVector) + getFieldVector(valueVector, "serialize") case c => throw new SparkException( @@ -253,14 +253,15 @@ object Utils { (fieldVectors, provider) } - def getFieldVector(valueVector: ValueVector): FieldVector = { + def getFieldVector(valueVector: ValueVector, reason: String): FieldVector = { valueVector match { case v @ (_: BitVector | _: TinyIntVector | _: SmallIntVector | _: IntVector | _: BigIntVector | _: Float4Vector | _: Float8Vector | _: VarCharVector | _: DecimalVector | _: DateDayVector | _: TimeStampMicroTZVector | _: VarBinaryVector | _: FixedSizeBinaryVector | _: TimeStampMicroVector) => v.asInstanceOf[FieldVector] - case _ => throw new SparkException(s"Unsupported Arrow Vector: ${valueVector.getClass}") + case _ => + throw new SparkException(s"Unsupported Arrow Vector for $reason: ${valueVector.getClass}") } } } diff --git a/core/Cargo.lock b/core/Cargo.lock index 3fb7b5f62..52f105591 100644 --- a/core/Cargo.lock +++ b/core/Cargo.lock @@ -57,9 +57,9 @@ dependencies = [ [[package]] name = "allocator-api2" -version = "0.2.16" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" +checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" [[package]] name = "android-tzdata" @@ -90,9 +90,9 @@ checksum = "8901269c6307e8d93993578286ac0edf7f195079ffff5ebdeea6a59ffb7e36bc" [[package]] name = "anyhow" -version = "1.0.81" +version = "1.0.82" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0952808a6c2afd1aa8947271f3a60f1a6763c7b912d210184c5149b5cf147247" +checksum = "f538837af36e6f6a9be0faa67f9a314f8119e4e4b5867c6ab40ed60360142519" [[package]] name = "arc-swap" @@ -327,13 +327,13 @@ checksum = "0c24e9d990669fbd16806bff449e4ac644fd9b1fca014760087732fe4102f131" [[package]] name = "async-trait" -version = "0.1.79" +version = "0.1.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a507401cad91ec6a857ed5513a2073c82a9b9048762b885bb98655b306964681" +checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" dependencies = [ "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.59", ] [[package]] @@ -438,9 +438,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.15.4" +version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"7ff69b9dd49fd426c69a0db9fc04dd934cdb6645ff000864d98f7e2af8830eaa" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bytemuck" @@ -468,9 +468,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.0.90" +version = "1.0.94" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cd6604a82acf3039f1144f54b8eb34e91ffba622051189e71b781822d5ee1f5" +checksum = "17f6e324229dc011159fcc089755d1e2e216a90d43a7dea6853ca740b84f35e7" dependencies = [ "jobserver", "libc", @@ -490,14 +490,14 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.37" +version = "0.4.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a0d04d43504c61aa6c7531f1871dd0d418d91130162063b789da00fd7057a5e" +checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" dependencies = [ "android-tzdata", "iana-time-zone", "num-traits", - "windows-targets 0.52.4", + "windows-targets 0.52.5", ] [[package]] @@ -576,9 +576,9 @@ checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce" [[package]] name = "combine" -version = "4.6.6" +version = "4.6.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35ed6e9d84f0b51a7f52daf1c7d71dd136fd7a3f41a8462b8cdb8c78d920fad4" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" dependencies = [ "bytes", "memchr", @@ -625,7 +625,7 @@ dependencies = [ "parquet-format", "paste", "pprof", - "prost 0.12.3", + "prost 0.12.4", "prost-build", "rand", "regex", @@ -643,12 +643,12 @@ dependencies = [ [[package]] name = "comfy-table" -version = "7.1.0" +version = "7.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c64043d6c7b7a4c58e39e7efccfdea7b93d885a795d0c054a69dbbf4dd52686" +checksum = "b34115915337defe99b2aff5c2ce6771e5fbc4079f4b506301f5cf394c8452f7" dependencies = [ - "strum 0.25.0", - "strum_macros 0.25.3", + "strum", + "strum_macros", "unicode-width", ] @@ -923,8 +923,8 @@ dependencies = [ "datafusion-common", "paste", "sqlparser", - "strum 0.26.2", - "strum_macros 0.26.2", + "strum", + "strum_macros", ] [[package]] @@ -1044,7 +1044,7 @@ dependencies = [ "datafusion-expr", "log", "sqlparser", - "strum 0.26.2", + "strum", ] [[package]] @@ -1092,9 +1092,9 @@ checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" [[package]] name = "either" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a" +checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2" [[package]] name = "equivalent" @@ -1227,7 +1227,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.59", ] [[package]] @@ -1272,9 +1272,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.12" +version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" +checksum = "94b22e06ecb0110981051723910cbf0b5f5e09a2062dd7663334ee79a9d1286c" dependencies = [ "cfg-if", "libc", @@ -1526,9 +1526,9 @@ checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" [[package]] name = "jobserver" -version = "0.1.28" 
+version = "0.1.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab46a6e9526ddef3ae7f787c06f0f2600639ba80ea3eade3d8e670a2230f51d6" +checksum = "685a7d121ee3f65ae4fddd72b25a04bb36b6af81bc0828f7d5434c0fe60fa3a2" dependencies = [ "libc", ] @@ -1794,9 +1794,9 @@ dependencies = [ [[package]] name = "num" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b05180d69e3da0e530ba2a1dae5110317e49e3b7f3d41be227dc5f92e49ee7af" +checksum = "3135b08af27d103b0a51f2ae0f8632117b7b185ccf931445affa8df530576a41" dependencies = [ "num-bigint", "num-complex", @@ -2142,9 +2142,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.79" +version = "1.0.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e" +checksum = "a56dea16b0a29e94408b9aa5e2940a4eedbd128a1ba20e8f7ae60fd3d465af0e" dependencies = [ "unicode-ident", ] @@ -2161,12 +2161,12 @@ dependencies = [ [[package]] name = "prost" -version = "0.12.3" +version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "146c289cda302b98a28d40c8b3b90498d6e526dd24ac2ecea73e4e491685b94a" +checksum = "d0f5d036824e4761737860779c906171497f6d55681139d8312388f8fe398922" dependencies = [ "bytes", - "prost-derive 0.12.3", + "prost-derive 0.12.4", ] [[package]] @@ -2204,15 +2204,15 @@ dependencies = [ [[package]] name = "prost-derive" -version = "0.12.3" +version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "efb6c9a1dd1def8e2124d17e83a20af56f1570d6c2d2bd9e266ccb768df3840e" +checksum = "19de2de2a00075bf566bee3bd4db014b11587e84184d3f7a791bc17f1a8e9e48" dependencies = [ "anyhow", - "itertools 0.11.0", + "itertools 0.12.1", "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.59", ] [[package]] @@ -2236,9 +2236,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.35" +version = "1.0.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" +checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" dependencies = [ "proc-macro2", ] @@ -2370,9 +2370,9 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.14" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" +checksum = "80af6f9131f277a45a3fba6ce8e2258037bb0477a67e610d3c1fe046ab31de47" [[package]] name = "ryu" @@ -2434,14 +2434,14 @@ checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.59", ] [[package]] name = "serde_json" -version = "1.0.115" +version = "1.0.116" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12dc5c46daa8e9fdf4f5e71b6cf9a53f2487da0e86e55808e2d35539666497dd" +checksum = "3e17db7126d17feb94eb3fad46bf1a96b034e8aacbc2e775fe81505f8b0b2813" dependencies = [ "itoa", "ryu", @@ -2545,7 +2545,7 @@ checksum = "01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554" dependencies = [ "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.59", ] [[package]] @@ -2566,32 +2566,13 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"9091b6114800a5f2141aee1d1b9d6ca3592ac062dc5decb3764ec5895a47b4eb" -[[package]] -name = "strum" -version = "0.25.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" - [[package]] name = "strum" version = "0.26.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d8cec3501a5194c432b2b7976db6b7d10ec95c253208b45f83f7136aa985e29" dependencies = [ - "strum_macros 0.26.2", -] - -[[package]] -name = "strum_macros" -version = "0.25.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0" -dependencies = [ - "heck 0.4.1", - "proc-macro2", - "quote", - "rustversion", - "syn 2.0.57", + "strum_macros", ] [[package]] @@ -2604,7 +2585,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.57", + "syn 2.0.59", ] [[package]] @@ -2649,9 +2630,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.57" +version = "2.0.59" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11a6ae1e52eb25aab8f3fb9fca13be982a373b8f1157ca14b897a825ba4a2d35" +checksum = "4a6531ffc7b071655e4ce2e04bd464c4830bb585a61cabb96cf808f05172615a" dependencies = [ "proc-macro2", "quote", @@ -2687,7 +2668,7 @@ checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.59", ] [[package]] @@ -2790,7 +2771,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.59", ] [[package]] @@ -2823,7 +2804,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.59", ] [[package]] @@ -2971,7 +2952,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.59", "wasm-bindgen-shared", ] @@ -2993,7 +2974,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.57", + "syn 2.0.59", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3063,7 +3044,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.52.4", + "windows-targets 0.52.5", ] [[package]] @@ -3081,7 +3062,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.4", + "windows-targets 0.52.5", ] [[package]] @@ -3116,17 +3097,18 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7dd37b7e5ab9018759f893a1952c9420d060016fc19a472b4bb20d1bdd694d1b" +checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" dependencies = [ - "windows_aarch64_gnullvm 0.52.4", - "windows_aarch64_msvc 0.52.4", - "windows_i686_gnu 0.52.4", - "windows_i686_msvc 0.52.4", - "windows_x86_64_gnu 0.52.4", - "windows_x86_64_gnullvm 0.52.4", - "windows_x86_64_msvc 0.52.4", + "windows_aarch64_gnullvm 0.52.5", + "windows_aarch64_msvc 0.52.5", + "windows_i686_gnu 0.52.5", + "windows_i686_gnullvm", + "windows_i686_msvc 0.52.5", + "windows_x86_64_gnu 0.52.5", + 
"windows_x86_64_gnullvm 0.52.5", + "windows_x86_64_msvc 0.52.5", ] [[package]] @@ -3143,9 +3125,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bcf46cf4c365c6f2d1cc93ce535f2c8b244591df96ceee75d8e83deb70a9cac9" +checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" [[package]] name = "windows_aarch64_msvc" @@ -3161,9 +3143,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da9f259dd3bcf6990b55bffd094c4f7235817ba4ceebde8e6d11cd0c5633b675" +checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" [[package]] name = "windows_i686_gnu" @@ -3179,9 +3161,15 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.4" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b474d8268f99e0995f25b9f095bc7434632601028cf86590aea5c8a5cb7801d3" +checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" [[package]] name = "windows_i686_msvc" @@ -3197,9 +3185,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1515e9a29e5bed743cb4415a9ecf5dfca648ce85ee42e15873c3cd8610ff8e02" +checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" [[package]] name = "windows_x86_64_gnu" @@ -3215,9 +3203,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5eee091590e89cc02ad514ffe3ead9eb6b660aedca2183455434b93546371a03" +checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" [[package]] name = "windows_x86_64_gnullvm" @@ -3233,9 +3221,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77ca79f2451b49fa9e2af39f0747fe999fcda4f5e241b2898624dca97a1f2177" +checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" [[package]] name = "windows_x86_64_msvc" @@ -3251,9 +3239,9 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.4" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" +checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" [[package]] name = "zerocopy" @@ -3272,7 +3260,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.57", + 
"syn 2.0.59", ] [[package]] diff --git a/core/Cargo.toml b/core/Cargo.toml index cbca7f629..ac565680a 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -119,5 +119,9 @@ name = "row_columnar" harness = false [[bench]] -name = "cast" +name = "cast_from_string" +harness = false + +[[bench]] +name = "cast_numeric" harness = false diff --git a/core/benches/cast.rs b/core/benches/cast_from_string.rs similarity index 93% rename from core/benches/cast.rs rename to core/benches/cast_from_string.rs index 281fe82e2..5bfaebf34 100644 --- a/core/benches/cast.rs +++ b/core/benches/cast_from_string.rs @@ -23,19 +23,7 @@ use datafusion_physical_expr::{expressions::Column, PhysicalExpr}; use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { - let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)])); - let mut b = StringBuilder::new(); - for i in 0..1000 { - if i % 10 == 0 { - b.append_null(); - } else if i % 2 == 0 { - b.append_value(format!("{}", rand::random::())); - } else { - b.append_value(format!("{}", rand::random::())); - } - } - let array = b.finish(); - let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap(); + let batch = create_utf8_batch(); let expr = Arc::new(Column::new("a", 0)); let timezone = "".to_string(); let cast_string_to_i8 = Cast::new( @@ -58,7 +46,7 @@ fn criterion_benchmark(c: &mut Criterion) { ); let cast_string_to_i64 = Cast::new(expr, DataType::Int64, EvalMode::Legacy, timezone); - let mut group = c.benchmark_group("cast"); + let mut group = c.benchmark_group("cast_string_to_int"); group.bench_function("cast_string_to_i8", |b| { b.iter(|| cast_string_to_i8.evaluate(&batch).unwrap()); }); @@ -73,6 +61,24 @@ fn criterion_benchmark(c: &mut Criterion) { }); } +// Create UTF8 batch with strings representing ints, floats, nulls +fn create_utf8_batch() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)])); + let mut b = StringBuilder::new(); + for i in 0..1000 { + if i % 10 == 0 { + b.append_null(); + } else if i % 2 == 0 { + b.append_value(format!("{}", rand::random::())); + } else { + b.append_value(format!("{}", rand::random::())); + } + } + let array = b.finish(); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap(); + batch +} + fn config() -> Criterion { Criterion::default() } diff --git a/core/benches/cast_numeric.rs b/core/benches/cast_numeric.rs new file mode 100644 index 000000000..398be6946 --- /dev/null +++ b/core/benches/cast_numeric.rs @@ -0,0 +1,79 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow_array::{builder::Int32Builder, RecordBatch}; +use arrow_schema::{DataType, Field, Schema}; +use comet::execution::datafusion::expressions::cast::{Cast, EvalMode}; +use criterion::{criterion_group, criterion_main, Criterion}; +use datafusion_physical_expr::{expressions::Column, PhysicalExpr}; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let batch = create_int32_batch(); + let expr = Arc::new(Column::new("a", 0)); + let timezone = "".to_string(); + let cast_i32_to_i8 = Cast::new( + expr.clone(), + DataType::Int8, + EvalMode::Legacy, + timezone.clone(), + ); + let cast_i32_to_i16 = Cast::new( + expr.clone(), + DataType::Int16, + EvalMode::Legacy, + timezone.clone(), + ); + let cast_i32_to_i64 = Cast::new(expr, DataType::Int64, EvalMode::Legacy, timezone); + + let mut group = c.benchmark_group("cast_int_to_int"); + group.bench_function("cast_i32_to_i8", |b| { + b.iter(|| cast_i32_to_i8.evaluate(&batch).unwrap()); + }); + group.bench_function("cast_i32_to_i16", |b| { + b.iter(|| cast_i32_to_i16.evaluate(&batch).unwrap()); + }); + group.bench_function("cast_i32_to_i64", |b| { + b.iter(|| cast_i32_to_i64.evaluate(&batch).unwrap()); + }); +} + +fn create_int32_batch() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + let mut b = Int32Builder::new(); + for i in 0..1000 { + if i % 10 == 0 { + b.append_null(); + } else { + b.append_value(rand::random::()); + } + } + let array = b.finish(); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap(); + batch +} + +fn config() -> Criterion { + Criterion::default() +} + +criterion_group! { + name = benches; + config = config(); + targets = criterion_benchmark +} +criterion_main!(benches); diff --git a/core/src/errors.rs b/core/src/errors.rs index a06c613ad..04a1629d5 100644 --- a/core/src/errors.rs +++ b/core/src/errors.rs @@ -72,6 +72,13 @@ pub enum CometError { to_type: String, }, + #[error("[NUMERIC_VALUE_OUT_OF_RANGE] {value} cannot be represented as Decimal({precision}, {scale}). If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error, and return NULL instead.")] + NumericValueOutOfRange { + value: String, + precision: u8, + scale: i8, + }, + #[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \ set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] @@ -208,6 +215,10 @@ impl jni::errors::ToException for CometError { class: "org/apache/spark/SparkException".to_string(), msg: self.to_string(), }, + CometError::NumericValueOutOfRange { .. 
} => Exception { + class: "org/apache/spark/SparkException".to_string(), + msg: self.to_string(), + }, CometError::NumberIntFormat { source: s } => Exception { class: "java/lang/NumberFormatException".to_string(), msg: s.to_string(), diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index a6e3adaca..35ab23a76 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -25,21 +25,24 @@ use std::{ use crate::errors::{CometError, CometResult}; use arrow::{ compute::{cast_with_options, CastOptions}, - datatypes::TimestampMicrosecondType, + datatypes::{ + ArrowPrimitiveType, Decimal128Type, DecimalType, Float32Type, Float64Type, + TimestampMicrosecondType, + }, record_batch::RecordBatch, util::display::FormatOptions, }; use arrow_array::{ types::{Int16Type, Int32Type, Int64Type, Int8Type}, - Array, ArrayRef, BooleanArray, Float32Array, Float64Array, GenericStringArray, OffsetSizeTrait, - PrimitiveArray, + Array, ArrayRef, BooleanArray, Decimal128Array, Float32Array, Float64Array, GenericStringArray, + Int16Array, Int32Array, Int64Array, Int8Array, OffsetSizeTrait, PrimitiveArray, }; use arrow_schema::{DataType, Schema}; use chrono::{TimeZone, Timelike}; use datafusion::logical_expr::ColumnarValue; use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue}; use datafusion_physical_expr::PhysicalExpr; -use num::{traits::CheckedNeg, CheckedSub, Integer, Num}; +use num::{cast::AsPrimitive, traits::CheckedNeg, CheckedSub, Integer, Num, ToPrimitive}; use regex::Regex; use crate::execution::datafusion::expressions::utils::{ @@ -214,11 +217,11 @@ macro_rules! cast_int_to_int_macro { Some(value) => { let res = <$to_native_type>::try_from(value); if res.is_err() { - Err(CometError::CastOverFlow { - value: value.to_string() + spark_int_literal_suffix, - from_type: $spark_from_data_type_name.to_string(), - to_type: $spark_to_data_type_name.to_string(), - }) + Err(cast_overflow( + &(value.to_string() + spark_int_literal_suffix), + $spark_from_data_type_name, + $spark_to_data_type_name, + )) } else { Ok::, CometError>(Some(res.unwrap())) } @@ -232,6 +235,240 @@ macro_rules! cast_int_to_int_macro { }}; } +// When Spark casts to Byte/Short Types, it does not cast directly to Byte/Short. +// It casts to Int first and then to Byte/Short. Because of potential overflows in the Int cast, +// this can cause unexpected Short/Byte cast results. Replicate this behavior. +macro_rules! 
cast_float_to_int16_down { + ( + $array:expr, + $eval_mode:expr, + $src_array_type:ty, + $dest_array_type:ty, + $rust_src_type:ty, + $rust_dest_type:ty, + $src_type_str:expr, + $dest_type_str:expr, + $format_str:expr + ) => {{ + let cast_array = $array + .as_any() + .downcast_ref::<$src_array_type>() + .expect(concat!("Expected a ", stringify!($src_array_type))); + + let output_array = match $eval_mode { + EvalMode::Ansi => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let is_overflow = value.is_nan() || value.abs() as i32 == std::i32::MAX; + if is_overflow { + return Err(cast_overflow( + &format!($format_str, value).replace("e", "E"), + $src_type_str, + $dest_type_str, + )); + } + let i32_value = value as i32; + <$rust_dest_type>::try_from(i32_value) + .map_err(|_| { + cast_overflow( + &format!($format_str, value).replace("e", "E"), + $src_type_str, + $dest_type_str, + ) + }) + .map(Some) + } + None => Ok(None), + }) + .collect::>()?, + _ => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let i32_value = value as i32; + Ok::, CometError>(Some( + i32_value as $rust_dest_type, + )) + } + None => Ok(None), + }) + .collect::>()?, + }; + Ok(Arc::new(output_array) as ArrayRef) + }}; +} + +macro_rules! cast_float_to_int32_up { + ( + $array:expr, + $eval_mode:expr, + $src_array_type:ty, + $dest_array_type:ty, + $rust_src_type:ty, + $rust_dest_type:ty, + $src_type_str:expr, + $dest_type_str:expr, + $max_dest_val:expr, + $format_str:expr + ) => {{ + let cast_array = $array + .as_any() + .downcast_ref::<$src_array_type>() + .expect(concat!("Expected a ", stringify!($src_array_type))); + + let output_array = match $eval_mode { + EvalMode::Ansi => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let is_overflow = + value.is_nan() || value.abs() as $rust_dest_type == $max_dest_val; + if is_overflow { + return Err(cast_overflow( + &format!($format_str, value).replace("e", "E"), + $src_type_str, + $dest_type_str, + )); + } + Ok(Some(value as $rust_dest_type)) + } + None => Ok(None), + }) + .collect::>()?, + _ => cast_array + .iter() + .map(|value| match value { + Some(value) => { + Ok::, CometError>(Some(value as $rust_dest_type)) + } + None => Ok(None), + }) + .collect::>()?, + }; + Ok(Arc::new(output_array) as ArrayRef) + }}; +} + +// When Spark casts to Byte/Short Types, it does not cast directly to Byte/Short. +// It casts to Int first and then to Byte/Short. Because of potential overflows in the Int cast, +// this can cause unexpected Short/Byte cast results. Replicate this behavior. +macro_rules! 
cast_decimal_to_int16_down { + ( + $array:expr, + $eval_mode:expr, + $dest_array_type:ty, + $rust_dest_type:ty, + $dest_type_str:expr, + $precision:expr, + $scale:expr + ) => {{ + let cast_array = $array + .as_any() + .downcast_ref::() + .expect(concat!("Expected a Decimal128ArrayType")); + + let output_array = match $eval_mode { + EvalMode::Ansi => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let divisor = 10_i128.pow($scale as u32); + let (truncated, decimal) = (value / divisor, (value % divisor).abs()); + let is_overflow = truncated.abs() > std::i32::MAX.into(); + if is_overflow { + return Err(cast_overflow( + &format!("{}.{}BD", truncated, decimal), + &format!("DECIMAL({},{})", $precision, $scale), + $dest_type_str, + )); + } + let i32_value = truncated as i32; + <$rust_dest_type>::try_from(i32_value) + .map_err(|_| { + cast_overflow( + &format!("{}.{}BD", truncated, decimal), + &format!("DECIMAL({},{})", $precision, $scale), + $dest_type_str, + ) + }) + .map(Some) + } + None => Ok(None), + }) + .collect::>()?, + _ => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let divisor = 10_i128.pow($scale as u32); + let i32_value = (value / divisor) as i32; + Ok::, CometError>(Some( + i32_value as $rust_dest_type, + )) + } + None => Ok(None), + }) + .collect::>()?, + }; + Ok(Arc::new(output_array) as ArrayRef) + }}; +} + +macro_rules! cast_decimal_to_int32_up { + ( + $array:expr, + $eval_mode:expr, + $dest_array_type:ty, + $rust_dest_type:ty, + $dest_type_str:expr, + $max_dest_val:expr, + $precision:expr, + $scale:expr + ) => {{ + let cast_array = $array + .as_any() + .downcast_ref::() + .expect(concat!("Expected a Decimal128ArrayType")); + + let output_array = match $eval_mode { + EvalMode::Ansi => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let divisor = 10_i128.pow($scale as u32); + let (truncated, decimal) = (value / divisor, (value % divisor).abs()); + let is_overflow = truncated.abs() > $max_dest_val.into(); + if is_overflow { + return Err(cast_overflow( + &format!("{}.{}BD", truncated, decimal), + &format!("DECIMAL({},{})", $precision, $scale), + $dest_type_str, + )); + } + Ok(Some(truncated as $rust_dest_type)) + } + None => Ok(None), + }) + .collect::>()?, + _ => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let divisor = 10_i128.pow($scale as u32); + let truncated = value / divisor; + Ok::, CometError>(Some( + truncated as $rust_dest_type, + )) + } + None => Ok(None), + }) + .collect::>()?, + }; + Ok(Arc::new(output_array) as ArrayRef) + }}; +} + impl Cast { pub fn new( child: Arc, @@ -332,6 +569,33 @@ impl Cast { (DataType::Float32, DataType::LargeUtf8) => { Self::spark_cast_float32_to_utf8::(&array, self.eval_mode)? } + (DataType::Float32, DataType::Decimal128(precision, scale)) => { + Self::cast_float32_to_decimal128(&array, *precision, *scale, self.eval_mode)? + } + (DataType::Float64, DataType::Decimal128(precision, scale)) => { + Self::cast_float64_to_decimal128(&array, *precision, *scale, self.eval_mode)? 
+ } + (DataType::Float32, DataType::Int8) + | (DataType::Float32, DataType::Int16) + | (DataType::Float32, DataType::Int32) + | (DataType::Float32, DataType::Int64) + | (DataType::Float64, DataType::Int8) + | (DataType::Float64, DataType::Int16) + | (DataType::Float64, DataType::Int32) + | (DataType::Float64, DataType::Int64) + | (DataType::Decimal128(_, _), DataType::Int8) + | (DataType::Decimal128(_, _), DataType::Int16) + | (DataType::Decimal128(_, _), DataType::Int32) + | (DataType::Decimal128(_, _), DataType::Int64) + if self.eval_mode != EvalMode::Try => + { + Self::spark_cast_nonintegral_numeric_to_integral( + &array, + self.eval_mode, + from_type, + to_type, + )? + } _ => { // when we have no Spark-specific casting we delegate to DataFusion cast_with_options(&array, to_type, &CAST_OPTIONS)? @@ -395,6 +659,83 @@ impl Cast { Ok(cast_array) } + fn cast_float64_to_decimal128( + array: &dyn Array, + precision: u8, + scale: i8, + eval_mode: EvalMode, + ) -> CometResult { + Self::cast_floating_point_to_decimal128::(array, precision, scale, eval_mode) + } + + fn cast_float32_to_decimal128( + array: &dyn Array, + precision: u8, + scale: i8, + eval_mode: EvalMode, + ) -> CometResult { + Self::cast_floating_point_to_decimal128::(array, precision, scale, eval_mode) + } + + fn cast_floating_point_to_decimal128( + array: &dyn Array, + precision: u8, + scale: i8, + eval_mode: EvalMode, + ) -> CometResult + where + ::Native: AsPrimitive, + { + let input = array.as_any().downcast_ref::>().unwrap(); + let mut cast_array = PrimitiveArray::::builder(input.len()); + + let mul = 10_f64.powi(scale as i32); + + for i in 0..input.len() { + if input.is_null(i) { + cast_array.append_null(); + } else { + let input_value = input.value(i).as_(); + let value = (input_value * mul).round().to_i128(); + + match value { + Some(v) => { + if Decimal128Type::validate_decimal_precision(v, precision).is_err() { + if eval_mode == EvalMode::Ansi { + return Err(CometError::NumericValueOutOfRange { + value: input_value.to_string(), + precision, + scale, + }); + } else { + cast_array.append_null(); + } + } + cast_array.append_value(v); + } + None => { + if eval_mode == EvalMode::Ansi { + return Err(CometError::NumericValueOutOfRange { + value: input_value.to_string(), + precision, + scale, + }); + } else { + cast_array.append_null(); + } + } + } + } + } + + let res = Arc::new( + cast_array + .with_precision_and_scale(precision, scale)? 
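+                // Descriptive note (assumption based on the arrow-rs builder API used here):
+                // `with_precision_and_scale` stamps the target Decimal128(precision, scale)
+                // type onto the builder, erroring on an invalid precision/scale pair, before
+                // the array is finished.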
+ .finish(), + ) as ArrayRef; + Ok(res) + } + fn spark_cast_float64_to_utf8( from: &dyn Array, _eval_mode: EvalMode, @@ -478,6 +819,146 @@ impl Cast { Ok(Arc::new(output_array)) } + + fn spark_cast_nonintegral_numeric_to_integral( + array: &dyn Array, + eval_mode: EvalMode, + from_type: &DataType, + to_type: &DataType, + ) -> CometResult { + match (from_type, to_type) { + (DataType::Float32, DataType::Int8) => cast_float_to_int16_down!( + array, + eval_mode, + Float32Array, + Int8Array, + f32, + i8, + "FLOAT", + "TINYINT", + "{:e}" + ), + (DataType::Float32, DataType::Int16) => cast_float_to_int16_down!( + array, + eval_mode, + Float32Array, + Int16Array, + f32, + i16, + "FLOAT", + "SMALLINT", + "{:e}" + ), + (DataType::Float32, DataType::Int32) => cast_float_to_int32_up!( + array, + eval_mode, + Float32Array, + Int32Array, + f32, + i32, + "FLOAT", + "INT", + std::i32::MAX, + "{:e}" + ), + (DataType::Float32, DataType::Int64) => cast_float_to_int32_up!( + array, + eval_mode, + Float32Array, + Int64Array, + f32, + i64, + "FLOAT", + "BIGINT", + std::i64::MAX, + "{:e}" + ), + (DataType::Float64, DataType::Int8) => cast_float_to_int16_down!( + array, + eval_mode, + Float64Array, + Int8Array, + f64, + i8, + "DOUBLE", + "TINYINT", + "{:e}D" + ), + (DataType::Float64, DataType::Int16) => cast_float_to_int16_down!( + array, + eval_mode, + Float64Array, + Int16Array, + f64, + i16, + "DOUBLE", + "SMALLINT", + "{:e}D" + ), + (DataType::Float64, DataType::Int32) => cast_float_to_int32_up!( + array, + eval_mode, + Float64Array, + Int32Array, + f64, + i32, + "DOUBLE", + "INT", + std::i32::MAX, + "{:e}D" + ), + (DataType::Float64, DataType::Int64) => cast_float_to_int32_up!( + array, + eval_mode, + Float64Array, + Int64Array, + f64, + i64, + "DOUBLE", + "BIGINT", + std::i64::MAX, + "{:e}D" + ), + (DataType::Decimal128(precision, scale), DataType::Int8) => { + cast_decimal_to_int16_down!( + array, eval_mode, Int8Array, i8, "TINYINT", precision, *scale + ) + } + (DataType::Decimal128(precision, scale), DataType::Int16) => { + cast_decimal_to_int16_down!( + array, eval_mode, Int16Array, i16, "SMALLINT", precision, *scale + ) + } + (DataType::Decimal128(precision, scale), DataType::Int32) => { + cast_decimal_to_int32_up!( + array, + eval_mode, + Int32Array, + i32, + "INT", + std::i32::MAX, + *precision, + *scale + ) + } + (DataType::Decimal128(precision, scale), DataType::Int64) => { + cast_decimal_to_int32_up!( + array, + eval_mode, + Int64Array, + i64, + "BIGINT", + std::i64::MAX, + *precision, + *scale + ) + } + _ => unreachable!( + "{}", + format!("invalid cast from non-integral numeric type: {from_type} to integral numeric type: {to_type}") + ), + } + } } /// Equivalent to org.apache.spark.unsafe.types.UTF8String.toByte @@ -676,6 +1157,15 @@ fn invalid_value(value: &str, from_type: &str, to_type: &str) -> CometError { } } +#[inline] +fn cast_overflow(value: &str, from_type: &str, to_type: &str) -> CometError { + CometError::CastOverFlow { + value: value.to_string(), + from_type: from_type.to_string(), + to_type: to_type.to_string(), + } +} + impl Display for Cast { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!( diff --git a/core/src/execution/datafusion/expressions/scalar_funcs.rs b/core/src/execution/datafusion/expressions/scalar_funcs.rs index 2895937ca..8c5e1f391 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs.rs @@ -52,6 +52,9 @@ use num::{ }; use unicode_segmentation::UnicodeSegmentation; +mod 
unhex; +use unhex::spark_unhex; + macro_rules! make_comet_scalar_udf { ($name:expr, $func:ident, $data_type:ident) => {{ let scalar_func = CometScalarFunction::new( @@ -105,6 +108,10 @@ pub fn create_comet_physical_fun( "make_decimal" => { make_comet_scalar_udf!("make_decimal", spark_make_decimal, data_type) } + "unhex" => { + let func = Arc::new(spark_unhex); + make_comet_scalar_udf!("unhex", func, without data_type) + } "decimal_div" => { make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type) } @@ -123,11 +130,10 @@ pub fn create_comet_physical_fun( make_comet_scalar_udf!(spark_func_name, wrapped_func, without data_type) } _ => { - let fun = BuiltinScalarFunction::from_str(fun_name); - if fun.is_err() { - Ok(ScalarFunctionDefinition::UDF(registry.udf(fun_name)?)) + if let Ok(fun) = BuiltinScalarFunction::from_str(fun_name) { + Ok(ScalarFunctionDefinition::BuiltIn(fun)) } else { - Ok(ScalarFunctionDefinition::BuiltIn(fun?)) + Ok(ScalarFunctionDefinition::UDF(registry.udf(fun_name)?)) } } } diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs b/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs new file mode 100644 index 000000000..38d5c0478 --- /dev/null +++ b/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs @@ -0,0 +1,257 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow_array::OffsetSizeTrait; +use arrow_schema::DataType; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{cast::as_generic_string_array, exec_err, DataFusionError, ScalarValue}; + +/// Helper function to convert a hex digit to a binary value. +fn unhex_digit(c: u8) -> Result { + match c { + b'0'..=b'9' => Ok(c - b'0'), + b'A'..=b'F' => Ok(10 + c - b'A'), + b'a'..=b'f' => Ok(10 + c - b'a'), + _ => Err(DataFusionError::Execution( + "Input to unhex_digit is not a valid hex digit".to_string(), + )), + } +} + +/// Convert a hex string to binary and store the result in `result`. Returns an error if the input +/// is not a valid hex string. 
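+/// Odd-length input is accepted: the leading character is decoded as a single low
+/// nibble, so "A1B" yields the same bytes as "0A1B" (see `test_odd_length` below).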
+fn unhex(hex_str: &str, result: &mut Vec) -> Result<(), DataFusionError> { + let bytes = hex_str.as_bytes(); + + let mut i = 0; + + if (bytes.len() & 0x01) != 0 { + let v = unhex_digit(bytes[0])?; + + result.push(v); + i += 1; + } + + while i < bytes.len() { + let first = unhex_digit(bytes[i])?; + let second = unhex_digit(bytes[i + 1])?; + result.push((first << 4) | second); + + i += 2; + } + + Ok(()) +} + +fn spark_unhex_inner( + array: &ColumnarValue, + fail_on_error: bool, +) -> Result { + match array { + ColumnarValue::Array(array) => { + let string_array = as_generic_string_array::(array)?; + + let mut encoded = Vec::new(); + let mut builder = arrow::array::BinaryBuilder::new(); + + for item in string_array.iter() { + if let Some(s) = item { + if unhex(s, &mut encoded).is_ok() { + builder.append_value(encoded.as_slice()); + } else if fail_on_error { + return exec_err!("Input to unhex is not a valid hex string: {s}"); + } else { + builder.append_null(); + } + encoded.clear(); + } else { + builder.append_null(); + } + } + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) + } + ColumnarValue::Scalar(ScalarValue::Utf8(Some(string))) => { + let mut encoded = Vec::new(); + + if unhex(string, &mut encoded).is_ok() { + Ok(ColumnarValue::Scalar(ScalarValue::Binary(Some(encoded)))) + } else if fail_on_error { + exec_err!("Input to unhex is not a valid hex string: {string}") + } else { + Ok(ColumnarValue::Scalar(ScalarValue::Binary(None))) + } + } + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Binary(None))) + } + _ => { + exec_err!( + "The first argument must be a string scalar or array, but got: {:?}", + array + ) + } + } +} + +pub(super) fn spark_unhex(args: &[ColumnarValue]) -> Result { + if args.len() > 2 { + return exec_err!("unhex takes at most 2 arguments, but got: {}", args.len()); + } + + let val_to_unhex = &args[0]; + let fail_on_error = if args.len() == 2 { + match &args[1] { + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))) => *fail_on_error, + _ => { + return exec_err!( + "The second argument must be boolean scalar, but got: {:?}", + args[1] + ); + } + } + } else { + false + }; + + match val_to_unhex.data_type() { + DataType::Utf8 => spark_unhex_inner::(val_to_unhex, fail_on_error), + DataType::LargeUtf8 => spark_unhex_inner::(val_to_unhex, fail_on_error), + other => exec_err!( + "The first argument must be a Utf8 or LargeUtf8: {:?}", + other + ), + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::array::{BinaryBuilder, StringBuilder}; + use arrow_array::make_array; + use arrow_data::ArrayData; + use datafusion::logical_expr::ColumnarValue; + use datafusion_common::ScalarValue; + + use super::unhex; + + #[test] + fn test_spark_unhex_null() -> Result<(), Box> { + let input = ArrayData::new_null(&arrow_schema::DataType::Utf8, 2); + let output = ArrayData::new_null(&arrow_schema::DataType::Binary, 2); + + let input = ColumnarValue::Array(Arc::new(make_array(input))); + let expected = ColumnarValue::Array(Arc::new(make_array(output))); + + let result = super::spark_unhex(&[input])?; + + match (result, expected) { + (ColumnarValue::Array(result), ColumnarValue::Array(expected)) => { + assert_eq!(*result, *expected); + Ok(()) + } + _ => Err("Unexpected result type".into()), + } + } + + #[test] + fn test_partial_error() -> Result<(), Box> { + let mut input = StringBuilder::new(); + + input.append_value("1CGG"); // 1C is ok, but GG is invalid + input.append_value("537061726B2053514C"); // followed by 
valid + + let input = ColumnarValue::Array(Arc::new(input.finish())); + let fail_on_error = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))); + + let result = super::spark_unhex(&[input, fail_on_error])?; + + let mut expected = BinaryBuilder::new(); + expected.append_null(); + expected.append_value("Spark SQL".as_bytes()); + + match (result, ColumnarValue::Array(Arc::new(expected.finish()))) { + (ColumnarValue::Array(result), ColumnarValue::Array(expected)) => { + assert_eq!(*result, *expected); + + Ok(()) + } + _ => Err("Unexpected result type".into()), + } + } + + #[test] + fn test_unhex_valid() -> Result<(), Box> { + let mut result = Vec::new(); + + unhex("537061726B2053514C", &mut result)?; + let result_str = std::str::from_utf8(&result)?; + assert_eq!(result_str, "Spark SQL"); + result.clear(); + + unhex("1C", &mut result)?; + assert_eq!(result, vec![28]); + result.clear(); + + unhex("737472696E67", &mut result)?; + assert_eq!(result, "string".as_bytes()); + result.clear(); + + unhex("1", &mut result)?; + assert_eq!(result, vec![1]); + result.clear(); + + Ok(()) + } + + #[test] + fn test_odd_length() -> Result<(), Box> { + let mut result = Vec::new(); + + unhex("A1B", &mut result)?; + assert_eq!(result, vec![10, 27]); + result.clear(); + + unhex("0A1B", &mut result)?; + assert_eq!(result, vec![10, 27]); + result.clear(); + + Ok(()) + } + + #[test] + fn test_unhex_empty() { + let mut result = Vec::new(); + + // Empty hex string + unhex("", &mut result).unwrap(); + assert!(result.is_empty()); + } + + #[test] + fn test_unhex_invalid() { + let mut result = Vec::new(); + + // Invalid hex strings + assert!(unhex("##", &mut result).is_err()); + assert!(unhex("G123", &mut result).is_err()); + assert!(unhex("hello", &mut result).is_err()); + assert!(unhex("\0", &mut result).is_err()); + } +} diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index 6a050eb8b..59818857e 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -1040,6 +1040,23 @@ impl PhysicalPlanner { .collect(); let full_schema = Arc::new(Schema::new(all_fields)); + // Because we cast dictionary array to array in scan operator, + // we need to change dictionary type to data type for join filter expression. + let fields: Vec<_> = full_schema + .fields() + .iter() + .map(|f| match f.data_type() { + DataType::Dictionary(_, val_type) => Arc::new(Field::new( + f.name(), + val_type.as_ref().clone(), + f.is_nullable(), + )), + _ => f.clone(), + }) + .collect(); + + let full_schema = Arc::new(Schema::new(fields)); + let physical_expr = self.create_expr(expr, full_schema)?; let (left_field_indices, right_field_indices) = expr_to_columns(&physical_expr, left_fields.len(), right_fields.len())?; @@ -1058,6 +1075,14 @@ impl PhysicalPlanner { .into_iter() .map(|i| right.schema().field(i).clone()), ) + // Because we cast dictionary array to array in scan operator, + // we need to change dictionary type to data type for join filter expression. 
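+      // For example, a column carried as Dictionary(Int32, Utf8) by the scan is exposed
+      // to the join filter schema as plain Utf8, matching the unpacked arrays the scan
+      // operator actually produces.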
+ .map(|f| match f.data_type() { + DataType::Dictionary(_, val_type) => { + Field::new(f.name(), val_type.as_ref().clone(), f.is_nullable()) + } + _ => f.clone(), + }) .collect_vec(); let filter_schema = Schema::new_with_metadata(filter_fields, HashMap::new()); @@ -1326,6 +1351,7 @@ impl PhysicalPlanner { .iter() .map(|x| x.data_type(input_schema.as_ref())) .collect::, _>>()?; + let data_type = match expr.return_type.as_ref().map(to_arrow_datatype) { Some(t) => t, None => { @@ -1333,17 +1359,18 @@ impl PhysicalPlanner { // scalar function // Note this assumes the `fun_name` is a defined function in DF. Otherwise, it'll // throw error. - let fun = BuiltinScalarFunction::from_str(fun_name); - if fun.is_err() { + + if let Ok(fun) = BuiltinScalarFunction::from_str(fun_name) { + fun.return_type(&input_expr_types)? + } else { self.session_ctx .udf(fun_name)? .inner() .return_type(&input_expr_types)? - } else { - fun?.return_type(&input_expr_types)? } } }; + let fun_expr = create_comet_physical_fun(fun_name, data_type.clone(), &self.session_ctx.state())?; diff --git a/dev/diffs/3.4.2.diff b/dev/diffs/3.4.2.diff index 4154a705d..19bf6dd41 100644 --- a/dev/diffs/3.4.2.diff +++ b/dev/diffs/3.4.2.diff @@ -210,6 +210,51 @@ index 0efe0877e9b..423d3b3d76d 100644 -- -- SELECT_HAVING -- https://github.com/postgres/postgres/blob/REL_12_BETA2/src/test/regress/sql/select_having.sql +diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +index cf40e944c09..bdd5be4f462 100644 +--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants + import org.apache.spark.sql.execution.{ColumnarToRowExec, ExecSubqueryExpression, RDDScanExec, SparkPlan} + import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper + import org.apache.spark.sql.execution.columnar._ +-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec ++import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike + import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate + import org.apache.spark.sql.functions._ + import org.apache.spark.sql.internal.SQLConf +@@ -516,7 +516,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils + */ + private def verifyNumExchanges(df: DataFrame, expected: Int): Unit = { + assert( +- collect(df.queryExecution.executedPlan) { case e: ShuffleExchangeExec => e }.size == expected) ++ collect(df.queryExecution.executedPlan) { ++ case _: ShuffleExchangeLike => 1 }.size == expected) + } + + test("A cached table preserves the partitioning and ordering of its cached SparkPlan") { +diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +index ea5e47ede55..814b92d090f 100644 +--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +@@ -27,7 +27,7 @@ import org.apache.spark.SparkException + import org.apache.spark.sql.execution.WholeStageCodegenExec + import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper + import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} +-import 
org.apache.spark.sql.execution.exchange.ShuffleExchangeExec ++import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike + import org.apache.spark.sql.expressions.Window + import org.apache.spark.sql.functions._ + import org.apache.spark.sql.internal.SQLConf +@@ -755,7 +755,7 @@ class DataFrameAggregateSuite extends QueryTest + assert(objHashAggPlans.nonEmpty) + + val exchangePlans = collect(aggPlan) { +- case shuffle: ShuffleExchangeExec => shuffle ++ case shuffle: ShuffleExchangeLike => shuffle + } + assert(exchangePlans.length == 1) + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 56e9520fdab..917932336df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -226,9 +271,54 @@ index 56e9520fdab..917932336df 100644 spark.range(100).write.saveAsTable(s"$dbName.$table2Name") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala -index 9ddb4abe98b..1bebe99f1cc 100644 +index 9ddb4abe98b..1b9269acef1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +@@ -43,7 +43,7 @@ import org.apache.spark.sql.connector.FakeV2Provider + import org.apache.spark.sql.execution.{FilterExec, LogicalRDD, QueryExecution, SortExec, WholeStageCodegenExec} + import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper + import org.apache.spark.sql.execution.aggregate.HashAggregateExec +-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike} ++import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeLike} + import org.apache.spark.sql.expressions.{Aggregator, Window} + import org.apache.spark.sql.functions._ + import org.apache.spark.sql.internal.SQLConf +@@ -1981,7 +1981,7 @@ class DataFrameSuite extends QueryTest + fail("Should not have back to back Aggregates") + } + atFirstAgg = true +- case e: ShuffleExchangeExec => atFirstAgg = false ++ case e: ShuffleExchangeLike => atFirstAgg = false + case _ => + } + } +@@ -2291,7 +2291,7 @@ class DataFrameSuite extends QueryTest + checkAnswer(join, df) + assert( + collect(join.queryExecution.executedPlan) { +- case e: ShuffleExchangeExec => true }.size === 1) ++ case _: ShuffleExchangeLike => true }.size === 1) + assert( + collect(join.queryExecution.executedPlan) { case e: ReusedExchangeExec => true }.size === 1) + val broadcasted = broadcast(join) +@@ -2299,7 +2299,7 @@ class DataFrameSuite extends QueryTest + checkAnswer(join2, df) + assert( + collect(join2.queryExecution.executedPlan) { +- case e: ShuffleExchangeExec => true }.size == 1) ++ case _: ShuffleExchangeLike => true }.size == 1) + assert( + collect(join2.queryExecution.executedPlan) { + case e: BroadcastExchangeExec => true }.size === 1) +@@ -2862,7 +2862,7 @@ class DataFrameSuite extends QueryTest + + // Assert that no extra shuffle introduced by cogroup. 
+ val exchanges = collect(df3.queryExecution.executedPlan) { +- case h: ShuffleExchangeExec => h ++ case h: ShuffleExchangeLike => h + } + assert(exchanges.size == 2) + } @@ -3311,7 +3311,8 @@ class DataFrameSuite extends QueryTest assert(df2.isLocal) } @@ -239,8 +329,30 @@ index 9ddb4abe98b..1bebe99f1cc 100644 withTable("tbl") { sql( """ +diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +index 7dec558f8df..840dda15033 100644 +--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +@@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} + import org.apache.spark.sql.catalyst.util.sideBySide + import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SQLExecution} + import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} ++import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike} + import org.apache.spark.sql.execution.streaming.MemoryStream + import org.apache.spark.sql.expressions.UserDefinedFunction + import org.apache.spark.sql.functions._ +@@ -2254,7 +2254,7 @@ class DatasetSuite extends QueryTest + + // Assert that no extra shuffle introduced by cogroup. + val exchanges = collect(df3.queryExecution.executedPlan) { +- case h: ShuffleExchangeExec => h ++ case h: ShuffleExchangeLike => h + } + assert(exchanges.size == 2) + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala -index f33432ddb6f..6160c8d241a 100644 +index f33432ddb6f..060f874ea72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala @@ -22,6 +22,7 @@ import org.scalatest.GivenWhenThen @@ -261,7 +373,17 @@ index f33432ddb6f..6160c8d241a 100644 case _ => Nil } } -@@ -1238,7 +1242,8 @@ abstract class DynamicPartitionPruningSuiteBase +@@ -1187,7 +1191,8 @@ abstract class DynamicPartitionPruningSuiteBase + } + } + +- test("Make sure dynamic pruning works on uncorrelated queries") { ++ test("Make sure dynamic pruning works on uncorrelated queries", ++ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { + val df = sql( + """ +@@ -1238,7 +1243,8 @@ abstract class DynamicPartitionPruningSuiteBase } } @@ -271,7 +393,7 @@ index f33432ddb6f..6160c8d241a 100644 Given("dynamic pruning filter on the build side") withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { val df = sql( -@@ -1485,7 +1490,7 @@ abstract class DynamicPartitionPruningSuiteBase +@@ -1485,7 +1491,7 @@ abstract class DynamicPartitionPruningSuiteBase } test("SPARK-38148: Do not add dynamic partition pruning if there exists static partition " + @@ -280,7 +402,7 @@ index f33432ddb6f..6160c8d241a 100644 withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") { Seq( "f.store_id = 1" -> false, -@@ -1729,6 +1734,8 @@ abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningDat +@@ -1729,6 +1735,8 @@ abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningDat case s: BatchScanExec 
=> // we use f1 col for v2 tables due to schema pruning s.output.exists(_.exists(_.argString(maxFields = 100).contains("f1"))) @@ -290,7 +412,7 @@ index f33432ddb6f..6160c8d241a 100644 } assert(scanOption.isDefined) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala -index a6b295578d6..a5cb616945a 100644 +index a6b295578d6..91acca4306f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala @@ -463,7 +463,8 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite @@ -303,19 +425,38 @@ index a6b295578d6..a5cb616945a 100644 withTempDir { dir => Seq("parquet", "orc", "csv", "json").foreach { fmt => val basePath = dir.getCanonicalPath + "/" + fmt +@@ -541,7 +542,9 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite + } + } + +-class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuite { ++// Ignored when Comet is enabled. Comet changes expected query plans. ++class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuite ++ with IgnoreCometSuite { + import testImplicits._ + + test("SPARK-35884: Explain Formatted") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala -index 2796b1cf154..94591f83c84 100644 +index 2796b1cf154..be7078b38f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.TestingUDT.{IntervalUDT, NullData, NullUDT} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterThan, Literal} import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils.{negativeInt, positiveInt} import org.apache.spark.sql.catalyst.plans.logical.Filter -+import org.apache.spark.sql.comet.{CometBatchScanExec, CometScanExec} ++import org.apache.spark.sql.comet.{CometBatchScanExec, CometScanExec, CometSortMergeJoinExec} import org.apache.spark.sql.execution.{FileSourceScanLike, SimpleMode} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.datasources.FilePartition -@@ -875,6 +876,7 @@ class FileBasedDataSourceSuite extends QueryTest +@@ -815,6 +816,7 @@ class FileBasedDataSourceSuite extends QueryTest + assert(bJoinExec.isEmpty) + val smJoinExec = collect(joinedDF.queryExecution.executedPlan) { + case smJoin: SortMergeJoinExec => smJoin ++ case smJoin: CometSortMergeJoinExec => smJoin + } + assert(smJoinExec.nonEmpty) + } +@@ -875,6 +877,7 @@ class FileBasedDataSourceSuite extends QueryTest val fileScan = df.queryExecution.executedPlan collectFirst { case BatchScanExec(_, f: FileScan, _, _, _, _, _, _, _) => f @@ -323,7 +464,7 @@ index 2796b1cf154..94591f83c84 100644 } assert(fileScan.nonEmpty) assert(fileScan.get.partitionFilters.nonEmpty) -@@ -916,6 +918,7 @@ class FileBasedDataSourceSuite extends QueryTest +@@ -916,6 +919,7 @@ class FileBasedDataSourceSuite extends QueryTest val fileScan = df.queryExecution.executedPlan collectFirst { case BatchScanExec(_, f: FileScan, _, _, _, _, _, _, _) => f @@ -331,7 +472,7 @@ index 2796b1cf154..94591f83c84 100644 } assert(fileScan.nonEmpty) assert(fileScan.get.partitionFilters.isEmpty) -@@ -1100,6 +1103,8 @@ class 
FileBasedDataSourceSuite extends QueryTest +@@ -1100,6 +1104,8 @@ class FileBasedDataSourceSuite extends QueryTest val filters = df.queryExecution.executedPlan.collect { case f: FileSourceScanLike => f.dataFilters case b: BatchScanExec => b.scan.asInstanceOf[FileScan].dataFilters @@ -388,8 +529,36 @@ index 00000000000..4b31bea33de + } + } +} +diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala +index 1792b4c32eb..1616e6f39bd 100644 +--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala +@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide + import org.apache.spark.sql.catalyst.plans.PlanTest + import org.apache.spark.sql.catalyst.plans.logical._ + import org.apache.spark.sql.catalyst.rules.RuleExecutor ++import org.apache.spark.sql.comet.{CometHashJoinExec, CometSortMergeJoinExec} + import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper + import org.apache.spark.sql.execution.joins._ + import org.apache.spark.sql.internal.SQLConf +@@ -362,6 +363,7 @@ class JoinHintSuite extends PlanTest with SharedSparkSession with AdaptiveSparkP + val executedPlan = df.queryExecution.executedPlan + val shuffleHashJoins = collect(executedPlan) { + case s: ShuffledHashJoinExec => s ++ case c: CometHashJoinExec => c.originalPlan.asInstanceOf[ShuffledHashJoinExec] + } + assert(shuffleHashJoins.size == 1) + assert(shuffleHashJoins.head.buildSide == buildSide) +@@ -371,6 +373,7 @@ class JoinHintSuite extends PlanTest with SharedSparkSession with AdaptiveSparkP + val executedPlan = df.queryExecution.executedPlan + val shuffleMergeJoins = collect(executedPlan) { + case s: SortMergeJoinExec => s ++ case c: CometSortMergeJoinExec => c + } + assert(shuffleMergeJoins.size == 1) + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala -index 5125708be32..a1f1ae90796 100644 +index 5125708be32..210ab4f3ce1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier @@ -400,7 +569,87 @@ index 5125708be32..a1f1ae90796 100644 import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike} -@@ -1369,9 +1370,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan +@@ -739,7 +740,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan + } + } + +- test("test SortMergeJoin (with spill)") { ++ test("test SortMergeJoin (with spill)", ++ IgnoreComet("TODO: Comet SMJ doesn't support spill yet")) { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1", + SQLConf.SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "0", + SQLConf.SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD.key -> "1") { +@@ -1114,9 +1116,11 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan + val plan = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2", joinType) + .groupBy($"k1").count() + .queryExecution.executedPlan +- assert(collect(plan) { case _: ShuffledHashJoinExec => 
true }.size === 1) ++ assert(collect(plan) { ++ case _: ShuffledHashJoinExec | _: CometHashJoinExec => true }.size === 1) + // No extra shuffle before aggregate +- assert(collect(plan) { case _: ShuffleExchangeExec => true }.size === 2) ++ assert(collect(plan) { ++ case _: ShuffleExchangeLike => true }.size === 2) + }) + } + +@@ -1133,10 +1137,11 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan + .join(df4.hint("SHUFFLE_MERGE"), $"k1" === $"k4", joinType) + .queryExecution + .executedPlan +- assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 2) ++ assert(collect(plan) { ++ case _: SortMergeJoinExec | _: CometSortMergeJoinExec => true }.size === 2) + assert(collect(plan) { case _: BroadcastHashJoinExec => true }.size === 1) + // No extra sort before last sort merge join +- assert(collect(plan) { case _: SortExec => true }.size === 3) ++ assert(collect(plan) { case _: SortExec | _: CometSortExec => true }.size === 3) + }) + + // Test shuffled hash join +@@ -1146,10 +1151,13 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan + .join(df4.hint("SHUFFLE_MERGE"), $"k1" === $"k4", joinType) + .queryExecution + .executedPlan +- assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 2) +- assert(collect(plan) { case _: ShuffledHashJoinExec => true }.size === 1) ++ assert(collect(plan) { ++ case _: SortMergeJoinExec | _: CometSortMergeJoinExec => true }.size === 2) ++ assert(collect(plan) { ++ case _: ShuffledHashJoinExec | _: CometHashJoinExec => true }.size === 1) + // No extra sort before last sort merge join +- assert(collect(plan) { case _: SortExec => true }.size === 3) ++ assert(collect(plan) { ++ case _: SortExec | _: CometSortExec => true }.size === 3) + }) + } + +@@ -1240,12 +1248,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan + inputDFs.foreach { case (df1, df2, joinExprs) => + val smjDF = df1.join(df2.hint("SHUFFLE_MERGE"), joinExprs, "full") + assert(collect(smjDF.queryExecution.executedPlan) { +- case _: SortMergeJoinExec => true }.size === 1) ++ case _: SortMergeJoinExec | _: CometSortMergeJoinExec => true }.size === 1) + val smjResult = smjDF.collect() + + val shjDF = df1.join(df2.hint("SHUFFLE_HASH"), joinExprs, "full") + assert(collect(shjDF.queryExecution.executedPlan) { +- case _: ShuffledHashJoinExec => true }.size === 1) ++ case _: ShuffledHashJoinExec | _: CometHashJoinExec => true }.size === 1) + // Same result between shuffled hash join and sort merge join + checkAnswer(shjDF, smjResult) + } +@@ -1340,7 +1348,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan + val plan = sql(getAggQuery(selectExpr, joinType)).queryExecution.executedPlan + assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1) + // Have shuffle before aggregation +- assert(collect(plan) { case _: ShuffleExchangeExec => true }.size === 1) ++ assert(collect(plan) { ++ case _: ShuffleExchangeLike => true }.size === 1) + } + + def getJoinQuery(selectExpr: String, joinType: String): String = { +@@ -1369,9 +1378,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan } val plan = sql(getJoinQuery(selectExpr, joinType)).queryExecution.executedPlan assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1) @@ -415,7 +664,7 @@ index 5125708be32..a1f1ae90796 100644 } // Test output ordering is not preserved -@@ -1380,9 +1384,12 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan +@@ -1380,9 +1392,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan val selectExpr = "/*+ BROADCAST(left_t) */ k1 as k0" val plan = sql(getJoinQuery(selectExpr, joinType)).queryExecution.executedPlan assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1) @@ -430,6 +679,16 @@ index 5125708be32..a1f1ae90796 100644 } // Test singe partition +@@ -1392,7 +1407,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan + |FROM range(0, 10, 1, 1) t1 FULL OUTER JOIN range(0, 10, 1, 1) t2 + |""".stripMargin) + val plan = fullJoinDF.queryExecution.executedPlan +- assert(collect(plan) { case _: ShuffleExchangeExec => true}.size == 1) ++ assert(collect(plan) { ++ case _: ShuffleExchangeLike => true}.size == 1) + checkAnswer(fullJoinDF, Row(100)) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala index b5b34922694..a72403780c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala @@ -443,20 +702,38 @@ index b5b34922694..a72403780c4 100644 protected val baseResourcePath = { // use the same way as `SQLQueryTestSuite` to get the resource path +diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +index 525d97e4998..8a3e7457618 100644 +--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +@@ -1508,7 +1508,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark + checkAnswer(sql("select -0.001"), Row(BigDecimal("-0.001"))) + } + +- test("external sorting updates peak execution memory") { ++ test("external sorting updates peak execution memory", ++ IgnoreComet("TODO: native CometSort does not update peak execution memory")) { + AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") { + sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC").collect() + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala -index 3cfda19134a..278bb1060c4 100644 +index 3cfda19134a..7590b808def 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala -@@ -21,6 +21,8 @@ import scala.collection.mutable.ArrayBuffer +@@ -21,10 +21,11 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Join, LogicalPlan, Project, Sort, Union} +import org.apache.spark.sql.comet.CometScanExec -+import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecution} import org.apache.spark.sql.execution.datasources.FileScanRDD -@@ -1543,6 +1545,12 @@ class SubquerySuite extends QueryTest +-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec ++import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike + import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec} + import org.apache.spark.sql.internal.SQLConf + import 
org.apache.spark.sql.test.SharedSparkSession +@@ -1543,6 +1544,12 @@ class SubquerySuite extends QueryTest fs.inputRDDs().forall( _.asInstanceOf[FileScanRDD].filePartitions.forall( _.files.forall(_.urlEncodedPath.contains("p=0")))) @@ -469,14 +746,78 @@ index 3cfda19134a..278bb1060c4 100644 case _ => false }) } -@@ -2109,6 +2117,7 @@ class SubquerySuite extends QueryTest +@@ -2108,7 +2115,7 @@ class SubquerySuite extends QueryTest + df.collect() val exchanges = collect(df.queryExecution.executedPlan) { - case s: ShuffleExchangeExec => s -+ case s: CometShuffleExchangeExec => s +- case s: ShuffleExchangeExec => s ++ case s: ShuffleExchangeLike => s } assert(exchanges.size === 1) } +diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +index 02990a7a40d..bddf5e1ccc2 100644 +--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +@@ -24,6 +24,7 @@ import test.org.apache.spark.sql.connector._ + + import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} + import org.apache.spark.sql.catalyst.InternalRow ++import org.apache.spark.sql.comet.CometSortExec + import org.apache.spark.sql.connector.catalog.{PartitionInternalRow, SupportsRead, Table, TableCapability, TableProvider} + import org.apache.spark.sql.connector.catalog.TableCapability._ + import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, Literal, NamedReference, NullOrdering, SortDirection, SortOrder, Transform} +@@ -33,7 +34,7 @@ import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, + import org.apache.spark.sql.execution.SortExec + import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper + import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, DataSourceV2ScanRelation} +-import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} ++import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec, ShuffleExchangeLike} + import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector + import org.apache.spark.sql.expressions.Window + import org.apache.spark.sql.functions._ +@@ -268,13 +269,13 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS + val groupByColJ = df.groupBy($"j").agg(sum($"i")) + checkAnswer(groupByColJ, Seq(Row(2, 8), Row(4, 2), Row(6, 5))) + assert(collectFirst(groupByColJ.queryExecution.executedPlan) { +- case e: ShuffleExchangeExec => e ++ case e: ShuffleExchangeLike => e + }.isDefined) + + val groupByIPlusJ = df.groupBy($"i" + $"j").agg(count("*")) + checkAnswer(groupByIPlusJ, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1))) + assert(collectFirst(groupByIPlusJ.queryExecution.executedPlan) { +- case e: ShuffleExchangeExec => e ++ case e: ShuffleExchangeLike => e + }.isDefined) + } + } +@@ -334,10 +335,11 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS + + val (shuffleExpected, sortExpected) = groupByExpects + assert(collectFirst(groupBy.queryExecution.executedPlan) { +- case e: ShuffleExchangeExec => e ++ case e: ShuffleExchangeLike => e + }.isDefined === shuffleExpected) + assert(collectFirst(groupBy.queryExecution.executedPlan) { + case e: SortExec => e ++ case c: CometSortExec => c + }.isDefined === sortExpected) + } + +@@ -352,10 +354,11 @@ class 
DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS + + val (shuffleExpected, sortExpected) = windowFuncExpects + assert(collectFirst(windowPartByColIOrderByColJ.queryExecution.executedPlan) { +- case e: ShuffleExchangeExec => e ++ case e: ShuffleExchangeLike => e + }.isDefined === shuffleExpected) + assert(collectFirst(windowPartByColIOrderByColJ.queryExecution.executedPlan) { + case e: SortExec => e ++ case c: CometSortExec => c + }.isDefined === sortExpected) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala index cfc8b2cc845..c6fcfd7bd08 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala @@ -502,6 +843,44 @@ index cfc8b2cc845..c6fcfd7bd08 100644 } } finally { spark.listenerManager.unregister(listener) +diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +index cf76f6ca32c..f454128af06 100644 +--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +@@ -22,6 +22,7 @@ import org.apache.spark.sql.{DataFrame, Row} + import org.apache.spark.sql.catalyst.InternalRow + import org.apache.spark.sql.catalyst.expressions.{Literal, TransformExpression} + import org.apache.spark.sql.catalyst.plans.physical ++import org.apache.spark.sql.comet.CometSortMergeJoinExec + import org.apache.spark.sql.connector.catalog.Identifier + import org.apache.spark.sql.connector.catalog.InMemoryTableCatalog + import org.apache.spark.sql.connector.catalog.functions._ +@@ -31,7 +32,7 @@ import org.apache.spark.sql.connector.expressions.Expressions._ + import org.apache.spark.sql.execution.SparkPlan + import org.apache.spark.sql.execution.datasources.v2.BatchScanExec + import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation +-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec ++import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike + import org.apache.spark.sql.execution.joins.SortMergeJoinExec + import org.apache.spark.sql.internal.SQLConf + import org.apache.spark.sql.internal.SQLConf._ +@@ -279,13 +280,14 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { + Row("bbb", 20, 250.0), Row("bbb", 20, 350.0), Row("ccc", 30, 400.50))) + } + +- private def collectShuffles(plan: SparkPlan): Seq[ShuffleExchangeExec] = { ++ private def collectShuffles(plan: SparkPlan): Seq[ShuffleExchangeLike] = { + // here we skip collecting shuffle operators that are not associated with SMJ + collect(plan) { + case s: SortMergeJoinExec => s ++ case c: CometSortMergeJoinExec => c.originalPlan + }.flatMap(smj => + collect(smj) { +- case s: ShuffleExchangeExec => s ++ case s: ShuffleExchangeLike => s + }) + } + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala index c0ec8a58bd5..4e8bc6ed3c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala @@ -547,6 +926,45 @@ index 418ca3430bb..eb8267192f8 100644 Seq("json", "orc", "parquet").foreach { format => withTempPath { path => val dir = 
path.getCanonicalPath +diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala +index 743ec41dbe7..9f30d6c8e04 100644 +--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala +@@ -53,6 +53,10 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite with DisableAdaptiv + case ColumnarToRowExec(i: InputAdapter) => isScanPlanTree(i.child) + case p: ProjectExec => isScanPlanTree(p.child) + case f: FilterExec => isScanPlanTree(f.child) ++ // Comet produces scan plan tree like: ++ // ColumnarToRow ++ // +- ReusedExchange ++ case _: ReusedExchangeExec => false + case _: LeafExecNode => true + case _ => false + } +diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +index 4b3d3a4b805..56e1e0e6f16 100644 +--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +@@ -18,7 +18,7 @@ + package org.apache.spark.sql.execution + + import org.apache.spark.rdd.RDD +-import org.apache.spark.sql.{execution, DataFrame, Row} ++import org.apache.spark.sql.{execution, DataFrame, IgnoreCometSuite, Row} + import org.apache.spark.sql.catalyst.InternalRow + import org.apache.spark.sql.catalyst.expressions._ + import org.apache.spark.sql.catalyst.plans._ +@@ -35,7 +35,9 @@ import org.apache.spark.sql.internal.SQLConf + import org.apache.spark.sql.test.SharedSparkSession + import org.apache.spark.sql.types._ + +-class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper { ++// Ignore this suite when Comet is enabled. This suite tests the Spark planner and Comet planner ++// comes out with too many difference. Simply ignoring this suite for now. 
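The IgnoreComet tag and IgnoreCometSuite marker used here and throughout these diffs are introduced by this patch in the shared Spark test utilities (see the SQLTestUtils changes further down). As an illustration only, since the exact definition is not shown in this hunk, a per-test tag can be as small as the following; the tag name string is an assumption, not taken from the patch.

import org.scalatest.Tag

// Hypothetical sketch of a ScalaTest tag carrying the reason a test is skipped
// when Comet is enabled; used as test("name", IgnoreComet("reason")) { ... }.
// The tag name "DisableComet" is an assumption.
case class IgnoreComet(reason: String) extends Tag("DisableComet")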
++class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper with IgnoreCometSuite { + import testImplicits._ + + setupTestData() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala index 9e9d717db3b..91a4f9a38d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala @@ -571,11 +989,108 @@ index 9e9d717db3b..91a4f9a38d5 100644 assert(actual == expected) } } +diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala +index 30ce940b032..0d3f6c6c934 100644 +--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala +@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution + + import org.apache.spark.sql.{DataFrame, QueryTest} + import org.apache.spark.sql.catalyst.plans.physical.{RangePartitioning, UnknownPartitioning} ++import org.apache.spark.sql.comet.CometSortExec + import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} + import org.apache.spark.sql.execution.joins.ShuffledJoin + import org.apache.spark.sql.internal.SQLConf +@@ -33,7 +34,7 @@ abstract class RemoveRedundantSortsSuiteBase + + private def checkNumSorts(df: DataFrame, count: Int): Unit = { + val plan = df.queryExecution.executedPlan +- assert(collectWithSubqueries(plan) { case s: SortExec => s }.length == count) ++ assert(collectWithSubqueries(plan) { case _: SortExec | _: CometSortExec => 1 }.length == count) + } + + private def checkSorts(query: String, enabledCount: Int, disabledCount: Int): Unit = { +diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala +index 47679ed7865..9ffbaecb98e 100644 +--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReplaceHashWithSortAggSuite.scala +@@ -18,6 +18,7 @@ + package org.apache.spark.sql.execution + + import org.apache.spark.sql.{DataFrame, QueryTest} ++import org.apache.spark.sql.comet.CometHashAggregateExec + import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} + import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} + import org.apache.spark.sql.internal.SQLConf +@@ -31,7 +32,7 @@ abstract class ReplaceHashWithSortAggSuiteBase + private def checkNumAggs(df: DataFrame, hashAggCount: Int, sortAggCount: Int): Unit = { + val plan = df.queryExecution.executedPlan + assert(collectWithSubqueries(plan) { +- case s @ (_: HashAggregateExec | _: ObjectHashAggregateExec) => s ++ case s @ (_: HashAggregateExec | _: ObjectHashAggregateExec | _: CometHashAggregateExec ) => s + }.length == hashAggCount) + assert(collectWithSubqueries(plan) { case s: SortAggregateExec => s }.length == sortAggCount) + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala -index 
ac710c32296..37746bd470d 100644 +index ac710c32296..e163c1a6a76 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala -@@ -616,7 +616,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession +@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution + + import org.apache.spark.sql.{Dataset, QueryTest, Row, SaveMode} + import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats, CodeAndComment, CodeGenerator} ++import org.apache.spark.sql.comet.CometSortMergeJoinExec + import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite + import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} + import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec +@@ -224,6 +225,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession + assert(twoJoinsDF.queryExecution.executedPlan.collect { + case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true + case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true ++ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true + }.size === 2) + checkAnswer(twoJoinsDF, + Seq(Row(0, 0, 0), Row(1, 1, null), Row(2, 2, 2), Row(3, 3, null), Row(4, 4, null), +@@ -258,6 +260,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession + .join(df1.hint("SHUFFLE_MERGE"), $"k3" === $"k1", "right_outer") + assert(twoJoinsDF.queryExecution.executedPlan.collect { + case WholeStageCodegenExec(_ : SortMergeJoinExec) => true ++ case _: CometSortMergeJoinExec => true + }.size === 2) + checkAnswer(twoJoinsDF, + Seq(Row(0, 0, 0), Row(1, 1, 1), Row(2, 2, 2), Row(3, 3, 3), Row(4, null, 4), Row(5, null, 5), +@@ -280,8 +283,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession + val twoJoinsDF = df3.join(df2.hint("SHUFFLE_MERGE"), $"k3" === $"k2", "left_semi") + .join(df1.hint("SHUFFLE_MERGE"), $"k3" === $"k1", "left_semi") + assert(twoJoinsDF.queryExecution.executedPlan.collect { +- case WholeStageCodegenExec(ProjectExec(_, _ : SortMergeJoinExec)) | +- WholeStageCodegenExec(_ : SortMergeJoinExec) => true ++ case _: SortMergeJoinExec => true + }.size === 2) + checkAnswer(twoJoinsDF, Seq(Row(0), Row(1), Row(2), Row(3))) + } +@@ -302,8 +304,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession + val twoJoinsDF = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2", "left_anti") + .join(df3.hint("SHUFFLE_MERGE"), $"k1" === $"k3", "left_anti") + assert(twoJoinsDF.queryExecution.executedPlan.collect { +- case WholeStageCodegenExec(ProjectExec(_, _ : SortMergeJoinExec)) | +- WholeStageCodegenExec(_ : SortMergeJoinExec) => true ++ case _: SortMergeJoinExec => true + }.size === 2) + checkAnswer(twoJoinsDF, Seq(Row(6), Row(7), Row(8), Row(9))) + } +@@ -436,7 +437,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession + val plan = df.queryExecution.executedPlan + assert(plan.exists(p => + p.isInstanceOf[WholeStageCodegenExec] && +- p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortExec])) ++ p.asInstanceOf[WholeStageCodegenExec].collect { ++ case _: SortExec => true ++ }.nonEmpty)) + assert(df.collect() === Array(Row(1), Row(2), Row(3))) + } + +@@ -616,7 +619,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession .write.mode(SaveMode.Overwrite).parquet(path) 
withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "255", @@ -587,18 +1102,31 @@ index ac710c32296..37746bd470d 100644 val df = spark.read.parquet(path).selectExpr(projection: _*) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala -index 593bd7bb4ba..be1b82d0030 100644 +index 593bd7bb4ba..7ad55e3ab20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala -@@ -29,6 +29,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListe - import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy} +@@ -26,9 +26,11 @@ import org.scalatest.time.SpanSugar._ + + import org.apache.spark.SparkException + import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart} +-import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy} ++import org.apache.spark.sql.{Dataset, IgnoreComet, QueryTest, Row, SparkSession, Strategy} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} +import org.apache.spark.sql.comet._ ++import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, SparkPlanInfo, UnionExec} import org.apache.spark.sql.execution.aggregate.BaseAggregateExec import org.apache.spark.sql.execution.command.DataWritingCommandExec -@@ -116,6 +117,9 @@ class AdaptiveQueryExecSuite +@@ -104,6 +106,7 @@ class AdaptiveQueryExecSuite + private def findTopLevelBroadcastHashJoin(plan: SparkPlan): Seq[BroadcastHashJoinExec] = { + collect(plan) { + case j: BroadcastHashJoinExec => j ++ case j: CometBroadcastHashJoinExec => j.originalPlan.asInstanceOf[BroadcastHashJoinExec] + } + } + +@@ -116,30 +119,38 @@ class AdaptiveQueryExecSuite private def findTopLevelSortMergeJoin(plan: SparkPlan): Seq[SortMergeJoinExec] = { collect(plan) { case j: SortMergeJoinExec => j @@ -608,6 +1136,331 @@ index 593bd7bb4ba..be1b82d0030 100644 } } + private def findTopLevelShuffledHashJoin(plan: SparkPlan): Seq[ShuffledHashJoinExec] = { + collect(plan) { + case j: ShuffledHashJoinExec => j ++ case j: CometHashJoinExec => j.originalPlan.asInstanceOf[ShuffledHashJoinExec] + } + } + + private def findTopLevelBaseJoin(plan: SparkPlan): Seq[BaseJoinExec] = { + collect(plan) { + case j: BaseJoinExec => j ++ case c: CometHashJoinExec => c.originalPlan.asInstanceOf[BaseJoinExec] ++ case c: CometSortMergeJoinExec => c.originalPlan.asInstanceOf[BaseJoinExec] + } + } + + private def findTopLevelSort(plan: SparkPlan): Seq[SortExec] = { + collect(plan) { + case s: SortExec => s ++ case s: CometSortExec => s.originalPlan.asInstanceOf[SortExec] + } + } + + private def findTopLevelAggregate(plan: SparkPlan): Seq[BaseAggregateExec] = { + collect(plan) { + case agg: BaseAggregateExec => agg ++ case agg: CometHashAggregateExec => agg.originalPlan.asInstanceOf[BaseAggregateExec] + } + } + +@@ -176,6 +187,7 @@ class AdaptiveQueryExecSuite + val parts = rdd.partitions + assert(parts.forall(rdd.preferredLocations(_).nonEmpty)) + } ++ + assert(numShuffles === (numLocalReads.length + 
numShufflesWithoutLocalRead)) + } + +@@ -184,7 +196,7 @@ class AdaptiveQueryExecSuite + val plan = df.queryExecution.executedPlan + assert(plan.isInstanceOf[AdaptiveSparkPlanExec]) + val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect { +- case s: ShuffleExchangeExec => s ++ case s: ShuffleExchangeLike => s + } + assert(shuffle.size == 1) + assert(shuffle(0).outputPartitioning.numPartitions == numPartition) +@@ -200,7 +212,8 @@ class AdaptiveQueryExecSuite + assert(smj.size == 1) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) +- checkNumLocalShuffleReads(adaptivePlan) ++ // Comet shuffle changes shuffle metrics ++ // checkNumLocalShuffleReads(adaptivePlan) + } + } + +@@ -227,7 +240,8 @@ class AdaptiveQueryExecSuite + } + } + +- test("Reuse the parallelism of coalesced shuffle in local shuffle read") { ++ test("Reuse the parallelism of coalesced shuffle in local shuffle read", ++ IgnoreComet("Comet shuffle changes shuffle partition size")) { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", +@@ -259,7 +273,8 @@ class AdaptiveQueryExecSuite + } + } + +- test("Reuse the default parallelism in local shuffle read") { ++ test("Reuse the default parallelism in local shuffle read", ++ IgnoreComet("Comet shuffle changes shuffle partition size")) { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", +@@ -273,7 +288,8 @@ class AdaptiveQueryExecSuite + val localReads = collect(adaptivePlan) { + case read: AQEShuffleReadExec if read.isLocalRead => read + } +- assert(localReads.length == 2) ++ // Comet shuffle changes shuffle metrics ++ assert(localReads.length == 1) + val localShuffleRDD0 = localReads(0).execute().asInstanceOf[ShuffledRowRDD] + val localShuffleRDD1 = localReads(1).execute().asInstanceOf[ShuffledRowRDD] + // the final parallelism is math.max(1, numReduces / numMappers): math.max(1, 5/2) = 2 +@@ -322,7 +338,7 @@ class AdaptiveQueryExecSuite + } + } + +- test("Scalar subquery") { ++ test("Scalar subquery", IgnoreComet("Comet shuffle changes shuffle metrics")) { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { +@@ -337,7 +353,7 @@ class AdaptiveQueryExecSuite + } + } + +- test("Scalar subquery in later stages") { ++ test("Scalar subquery in later stages", IgnoreComet("Comet shuffle changes shuffle metrics")) { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { +@@ -353,7 +369,7 @@ class AdaptiveQueryExecSuite + } + } + +- test("multiple joins") { ++ test("multiple joins", IgnoreComet("Comet shuffle changes shuffle metrics")) { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { +@@ -398,7 +414,7 @@ class AdaptiveQueryExecSuite + } + } + +- test("multiple joins with aggregate") { ++ test("multiple joins with aggregate", IgnoreComet("Comet shuffle changes shuffle metrics")) { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { +@@ -443,7 +459,7 @@ class AdaptiveQueryExecSuite + } + } + +- test("multiple joins with aggregate 2") { ++ test("multiple joins with aggregate 2", IgnoreComet("Comet shuffle changes shuffle metrics")) { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + 
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") { +@@ -508,7 +524,7 @@ class AdaptiveQueryExecSuite + } + } + +- test("Exchange reuse with subqueries") { ++ test("Exchange reuse with subqueries", IgnoreComet("Comet shuffle changes shuffle metrics")) { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { +@@ -539,7 +555,9 @@ class AdaptiveQueryExecSuite + assert(smj.size == 1) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) +- checkNumLocalShuffleReads(adaptivePlan) ++ // Comet shuffle changes shuffle metrics, ++ // so we can't check the number of local shuffle reads. ++ // checkNumLocalShuffleReads(adaptivePlan) + // Even with local shuffle read, the query stage reuse can also work. + val ex = findReusedExchange(adaptivePlan) + assert(ex.nonEmpty) +@@ -560,7 +578,9 @@ class AdaptiveQueryExecSuite + assert(smj.size == 1) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) +- checkNumLocalShuffleReads(adaptivePlan) ++ // Comet shuffle changes shuffle metrics, ++ // so we can't check the number of local shuffle reads. ++ // checkNumLocalShuffleReads(adaptivePlan) + // Even with local shuffle read, the query stage reuse can also work. + val ex = findReusedExchange(adaptivePlan) + assert(ex.isEmpty) +@@ -569,7 +589,8 @@ class AdaptiveQueryExecSuite + } + } + +- test("Broadcast exchange reuse across subqueries") { ++ test("Broadcast exchange reuse across subqueries", ++ IgnoreComet("Comet shuffle changes shuffle metrics")) { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "20000000", +@@ -664,7 +685,8 @@ class AdaptiveQueryExecSuite + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) + // There is still a SMJ, and its two shuffles can't apply local read. 
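Several helpers in AdaptiveQueryExecSuite above unwrap Comet operators through originalPlan so that existing assertions on Spark's join nodes (build side, partitioning, and so on) keep working. A minimal sketch of that unwrapping pattern, assuming only what the diff itself shows — that each Comet exec keeps the Spark plan it replaced in an originalPlan field:

import org.apache.spark.sql.comet.CometHashJoinExec
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec

// Collects shuffled hash joins whether or not Comet replaced them, falling
// back to the wrapped Spark operator so callers can still inspect buildSide.
def findShuffledHashJoins(plan: SparkPlan): Seq[ShuffledHashJoinExec] =
  plan.collect {
    case j: ShuffledHashJoinExec => j
    case c: CometHashJoinExec => c.originalPlan.asInstanceOf[ShuffledHashJoinExec]
  }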
+- checkNumLocalShuffleReads(adaptivePlan, 2) ++ // Comet shuffle changes shuffle metrics ++ // checkNumLocalShuffleReads(adaptivePlan, 2) + } + } + +@@ -786,7 +808,8 @@ class AdaptiveQueryExecSuite + } + } + +- test("SPARK-29544: adaptive skew join with different join types") { ++ test("SPARK-29544: adaptive skew join with different join types", ++ IgnoreComet("Comet shuffle has different partition metrics")) { + Seq("SHUFFLE_MERGE", "SHUFFLE_HASH").foreach { joinHint => + def getJoinNode(plan: SparkPlan): Seq[ShuffledJoin] = if (joinHint == "SHUFFLE_MERGE") { + findTopLevelSortMergeJoin(plan) +@@ -1004,7 +1027,8 @@ class AdaptiveQueryExecSuite + } + } + +- test("metrics of the shuffle read") { ++ test("metrics of the shuffle read", ++ IgnoreComet("Comet shuffle changes the metrics")) { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + val (_, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT key FROM testData GROUP BY key") +@@ -1599,7 +1623,7 @@ class AdaptiveQueryExecSuite + val (_, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT id FROM v1 GROUP BY id DISTRIBUTE BY id") + assert(collect(adaptivePlan) { +- case s: ShuffleExchangeExec => s ++ case s: ShuffleExchangeLike => s + }.length == 1) + } + } +@@ -1679,7 +1703,8 @@ class AdaptiveQueryExecSuite + } + } + +- test("SPARK-33551: Do not use AQE shuffle read for repartition") { ++ test("SPARK-33551: Do not use AQE shuffle read for repartition", ++ IgnoreComet("Comet shuffle changes partition size")) { + def hasRepartitionShuffle(plan: SparkPlan): Boolean = { + find(plan) { + case s: ShuffleExchangeLike => +@@ -1864,6 +1889,9 @@ class AdaptiveQueryExecSuite + def checkNoCoalescePartitions(ds: Dataset[Row], origin: ShuffleOrigin): Unit = { + assert(collect(ds.queryExecution.executedPlan) { + case s: ShuffleExchangeExec if s.shuffleOrigin == origin && s.numPartitions == 2 => s ++ case c: CometShuffleExchangeExec ++ if c.originalPlan.shuffleOrigin == origin && ++ c.originalPlan.numPartitions == 2 => c + }.size == 1) + ds.collect() + val plan = ds.queryExecution.executedPlan +@@ -1872,6 +1900,9 @@ class AdaptiveQueryExecSuite + }.isEmpty) + assert(collect(plan) { + case s: ShuffleExchangeExec if s.shuffleOrigin == origin && s.numPartitions == 2 => s ++ case c: CometShuffleExchangeExec ++ if c.originalPlan.shuffleOrigin == origin && ++ c.originalPlan.numPartitions == 2 => c + }.size == 1) + checkAnswer(ds, testData) + } +@@ -2028,7 +2059,8 @@ class AdaptiveQueryExecSuite + } + } + +- test("SPARK-35264: Support AQE side shuffled hash join formula") { ++ test("SPARK-35264: Support AQE side shuffled hash join formula", ++ IgnoreComet("Comet shuffle changes the partition size")) { + withTempView("t1", "t2") { + def checkJoinStrategy(shouldShuffleHashJoin: Boolean): Unit = { + Seq("100", "100000").foreach { size => +@@ -2114,7 +2146,8 @@ class AdaptiveQueryExecSuite + } + } + +- test("SPARK-35725: Support optimize skewed partitions in RebalancePartitions") { ++ test("SPARK-35725: Support optimize skewed partitions in RebalancePartitions", ++ IgnoreComet("Comet shuffle changes shuffle metrics")) { + withTempView("v") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", +@@ -2213,7 +2246,7 @@ class AdaptiveQueryExecSuite + runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM skewData1 " + + s"JOIN skewData2 ON key1 = key2 GROUP BY key1") + val shuffles1 = collect(adaptive1) { +- case s: ShuffleExchangeExec => s ++ case s: ShuffleExchangeLike => s + } + assert(shuffles1.size == 3) + // shuffles1.head 
is the top-level shuffle under the Aggregate operator +@@ -2226,7 +2259,7 @@ class AdaptiveQueryExecSuite + runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM skewData1 " + + s"JOIN skewData2 ON key1 = key2") + val shuffles2 = collect(adaptive2) { +- case s: ShuffleExchangeExec => s ++ case s: ShuffleExchangeLike => s + } + if (hasRequiredDistribution) { + assert(shuffles2.size == 3) +@@ -2260,7 +2293,8 @@ class AdaptiveQueryExecSuite + } + } + +- test("SPARK-35794: Allow custom plugin for cost evaluator") { ++ test("SPARK-35794: Allow custom plugin for cost evaluator", ++ IgnoreComet("Comet shuffle changes shuffle metrics")) { + CostEvaluator.instantiate( + classOf[SimpleShuffleSortCostEvaluator].getCanonicalName, spark.sparkContext.getConf) + intercept[IllegalArgumentException] { +@@ -2404,6 +2438,7 @@ class AdaptiveQueryExecSuite + val (_, adaptive) = runAdaptiveAndVerifyResult(query) + assert(adaptive.collect { + case sort: SortExec => sort ++ case sort: CometSortExec => sort + }.size == 1) + val read = collect(adaptive) { + case read: AQEShuffleReadExec => read +@@ -2421,7 +2456,8 @@ class AdaptiveQueryExecSuite + } + } + +- test("SPARK-37357: Add small partition factor for rebalance partitions") { ++ test("SPARK-37357: Add small partition factor for rebalance partitions", ++ IgnoreComet("Comet shuffle changes shuffle metrics")) { + withTempView("v") { + withSQLConf( + SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key -> "true", +@@ -2533,7 +2569,7 @@ class AdaptiveQueryExecSuite + runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " + + "JOIN skewData3 ON value2 = value3") + val shuffles1 = collect(adaptive1) { +- case s: ShuffleExchangeExec => s ++ case s: ShuffleExchangeLike => s + } + assert(shuffles1.size == 4) + val smj1 = findTopLevelSortMergeJoin(adaptive1) +@@ -2544,7 +2580,7 @@ class AdaptiveQueryExecSuite + runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " + + "JOIN skewData3 ON value1 = value3") + val shuffles2 = collect(adaptive2) { +- case s: ShuffleExchangeExec => s ++ case s: ShuffleExchangeLike => s + } + assert(shuffles2.size == 4) + val smj2 = findTopLevelSortMergeJoin(adaptive2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala index bd9c79e5b96..ab7584e768e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala @@ -628,6 +1481,29 @@ index bd9c79e5b96..ab7584e768e 100644 } assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size, s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " + +diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala +index ce43edb79c1..c414b19eda7 100644 +--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala +@@ -17,7 +17,7 @@ + + package org.apache.spark.sql.execution.datasources + +-import org.apache.spark.sql.{QueryTest, Row} ++import org.apache.spark.sql.{IgnoreComet, QueryTest, Row} + import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, NullsFirst, SortOrder} + import 
org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sort} + import org.apache.spark.sql.execution.{QueryExecution, SortExec} +@@ -305,7 +305,8 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write + } + } + +- test("v1 write with AQE changing SMJ to BHJ") { ++ test("v1 write with AQE changing SMJ to BHJ", ++ IgnoreComet("TODO: Comet SMJ to BHJ by AQE")) { + withPlannedWrite { enabled => + withTable("t") { + sql( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala index 1d2e467c94c..3ea82cd1a3f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala @@ -917,10 +1793,22 @@ index 3a0bd35cb70..b28f06a757f 100644 val workDirPath = workDir.getAbsolutePath val input = spark.range(5).toDF("id") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala -index 26e61c6b58d..cde10983c68 100644 +index 26e61c6b58d..cb09d7e116a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala -@@ -737,7 +737,8 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils +@@ -45,8 +45,10 @@ import org.apache.spark.sql.util.QueryExecutionListener + import org.apache.spark.util.{AccumulatorContext, JsonProtocol} + + // Disable AQE because metric info is different with AQE on/off ++// This test suite runs tests against the metrics of physical operators. ++// Disabling it for Comet because the metrics are different with Comet enabled. 
+ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils +- with DisableAdaptiveExecutionSuite { ++ with DisableAdaptiveExecutionSuite with IgnoreCometSuite { + import testImplicits._ + + /** +@@ -737,7 +739,8 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils } } @@ -1002,21 +1890,24 @@ index d083cac48ff..3c11bcde807 100644 import testImplicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala -index 266bb343526..b33bb677f0d 100644 +index 266bb343526..a426d8396be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala -@@ -24,7 +24,9 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec +@@ -24,10 +24,11 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning -import org.apache.spark.sql.execution.{FileSourceScanExec, SortExec, SparkPlan} +import org.apache.spark.sql.comet._ -+import org.apache.spark.sql.comet.execution.shuffle._ +import org.apache.spark.sql.execution.{ColumnarToRowExec, FileSourceScanExec, SortExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, DisableAdaptiveExecution} import org.apache.spark.sql.execution.datasources.BucketingUtils - import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec -@@ -101,12 +103,20 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti +-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec ++import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike} + import org.apache.spark.sql.execution.joins.SortMergeJoinExec + import org.apache.spark.sql.functions._ + import org.apache.spark.sql.internal.SQLConf +@@ -101,12 +102,20 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti } } @@ -1039,7 +1930,7 @@ index 266bb343526..b33bb677f0d 100644 // To verify if the bucket pruning works, this function checks two conditions: // 1) Check if the pruned buckets (before filtering) are empty. 
// 2) Verify the final result is the same as the expected one -@@ -155,7 +165,8 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti +@@ -155,7 +164,8 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti val planWithoutBucketedScan = bucketedDataFrame.filter(filterCondition) .queryExecution.executedPlan val fileScan = getFileScan(planWithoutBucketedScan) @@ -1049,7 +1940,7 @@ index 266bb343526..b33bb677f0d 100644 val bucketColumnType = bucketedDataFrame.schema.apply(bucketColumnIndex).dataType val rowsWithInvalidBuckets = fileScan.execute().filter(row => { -@@ -451,28 +462,46 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti +@@ -451,28 +461,44 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti val joinOperator = if (joined.sqlContext.conf.adaptiveExecutionEnabled) { val executedPlan = joined.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan @@ -1082,13 +1973,11 @@ index 266bb343526..b33bb677f0d 100644 // check existence of shuffle assert( - joinOperator.left.exists(_.isInstanceOf[ShuffleExchangeExec]) == shuffleLeft, -+ joinOperator.left.exists(op => op.isInstanceOf[ShuffleExchangeExec] || -+ op.isInstanceOf[CometShuffleExchangeExec]) == shuffleLeft, ++ joinOperator.left.exists(op => op.isInstanceOf[ShuffleExchangeLike]) == shuffleLeft, s"expected shuffle in plan to be $shuffleLeft but found\n${joinOperator.left}") assert( - joinOperator.right.exists(_.isInstanceOf[ShuffleExchangeExec]) == shuffleRight, -+ joinOperator.right.exists(op => op.isInstanceOf[ShuffleExchangeExec] || -+ op.isInstanceOf[CometShuffleExchangeExec]) == shuffleRight, ++ joinOperator.right.exists(op => op.isInstanceOf[ShuffleExchangeLike]) == shuffleRight, s"expected shuffle in plan to be $shuffleRight but found\n${joinOperator.right}") // check existence of sort @@ -1104,7 +1993,7 @@ index 266bb343526..b33bb677f0d 100644 s"expected sort in the right child to be $sortRight but found\n${joinOperator.right}") // check the output partitioning -@@ -835,11 +864,11 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti +@@ -835,11 +861,11 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table") val scanDF = spark.table("bucketed_table").select("j") @@ -1118,14 +2007,13 @@ index 266bb343526..b33bb677f0d 100644 checkAnswer(aggDF, df1.groupBy("j").agg(max("k"))) } } -@@ -1026,15 +1055,24 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti +@@ -1026,15 +1052,23 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti expectedNumShuffles: Int, expectedCoalescedNumBuckets: Option[Int]): Unit = { val plan = sql(query).queryExecution.executedPlan - val shuffles = plan.collect { case s: ShuffleExchangeExec => s } + val shuffles = plan.collect { -+ case s: ShuffleExchangeExec => s -+ case s: CometShuffleExchangeExec => s ++ case s: ShuffleExchangeLike => s + } assert(shuffles.length == expectedNumShuffles) @@ -1303,6 +2191,120 @@ index 2a2a83d35e1..e3b7b290b3e 100644 val initialStateDS = Seq(("keyInStateAndData", new RunningCount(1))).toDS() val initialState: KeyValueGroupedDataset[String, RunningCount] = initialStateDS.groupByKey(_._1).mapValues(_._2) +diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +index ef5b8a769fe..84fe1bfabc9 100644 +--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +@@ -37,6 +37,7 @@ import org.apache.spark.sql._ + import org.apache.spark.sql.catalyst.plans.logical.{Range, RepartitionByExpression} + import org.apache.spark.sql.catalyst.streaming.{InternalOutputModes, StreamingRelationV2} + import org.apache.spark.sql.catalyst.util.DateTimeUtils ++import org.apache.spark.sql.comet.CometLocalLimitExec + import org.apache.spark.sql.execution.{LocalLimitExec, SimpleMode, SparkPlan} + import org.apache.spark.sql.execution.command.ExplainCommand + import org.apache.spark.sql.execution.streaming._ +@@ -1103,11 +1104,12 @@ class StreamSuite extends StreamTest { + val localLimits = execPlan.collect { + case l: LocalLimitExec => l + case l: StreamingLocalLimitExec => l ++ case l: CometLocalLimitExec => l + } + + require( + localLimits.size == 1, +- s"Cant verify local limit optimization with this plan:\n$execPlan") ++ s"Cant verify local limit optimization ${localLimits.size} with this plan:\n$execPlan") + + if (expectStreamingLimit) { + assert( +@@ -1115,7 +1117,8 @@ class StreamSuite extends StreamTest { + s"Local limit was not StreamingLocalLimitExec:\n$execPlan") + } else { + assert( +- localLimits.head.isInstanceOf[LocalLimitExec], ++ localLimits.head.isInstanceOf[LocalLimitExec] || ++ localLimits.head.isInstanceOf[CometLocalLimitExec], + s"Local limit was not LocalLimitExec:\n$execPlan") + } + } +diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala +index b4c4ec7acbf..20579284856 100644 +--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationDistributionSuite.scala +@@ -23,6 +23,7 @@ import org.apache.commons.io.FileUtils + import org.scalatest.Assertions + + import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution ++import org.apache.spark.sql.comet.CometHashAggregateExec + import org.apache.spark.sql.execution.aggregate.BaseAggregateExec + import org.apache.spark.sql.execution.streaming.{MemoryStream, StateStoreRestoreExec, StateStoreSaveExec} + import org.apache.spark.sql.functions.count +@@ -67,6 +68,7 @@ class StreamingAggregationDistributionSuite extends StreamTest + // verify aggregations in between, except partial aggregation + val allAggregateExecs = query.lastExecution.executedPlan.collect { + case a: BaseAggregateExec => a ++ case c: CometHashAggregateExec => c.originalPlan + } + + val aggregateExecsWithoutPartialAgg = allAggregateExecs.filter { +@@ -201,6 +203,7 @@ class StreamingAggregationDistributionSuite extends StreamTest + // verify aggregations in between, except partial aggregation + val allAggregateExecs = executedPlan.collect { + case a: BaseAggregateExec => a ++ case c: CometHashAggregateExec => c.originalPlan + } + + val aggregateExecsWithoutPartialAgg = allAggregateExecs.filter { +diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +index 4d92e270539..33f1c2eb75e 100644 +--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +@@ -31,7 +31,7 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation + import org.apache.spark.sql.{DataFrame, Row, SparkSession} + import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} + import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning +-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec ++import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike + import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinExec, StreamingSymmetricHashJoinHelper} + import org.apache.spark.sql.execution.streaming.state.{RocksDBStateStoreProvider, StateStore, StateStoreProviderId} + import org.apache.spark.sql.functions._ +@@ -619,14 +619,28 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { + + val numPartitions = spark.sqlContext.conf.getConf(SQLConf.SHUFFLE_PARTITIONS) + +- assert(query.lastExecution.executedPlan.collect { +- case j @ StreamingSymmetricHashJoinExec(_, _, _, _, _, _, _, _, _, +- ShuffleExchangeExec(opA: HashPartitioning, _, _), +- ShuffleExchangeExec(opB: HashPartitioning, _, _)) +- if partitionExpressionsColumns(opA.expressions) === Seq("a", "b") +- && partitionExpressionsColumns(opB.expressions) === Seq("a", "b") +- && opA.numPartitions == numPartitions && opB.numPartitions == numPartitions => j +- }.size == 1) ++ val join = query.lastExecution.executedPlan.collect { ++ case j: StreamingSymmetricHashJoinExec => j ++ }.head ++ val opA = join.left.collect { ++ case s: ShuffleExchangeLike ++ if s.outputPartitioning.isInstanceOf[HashPartitioning] && ++ partitionExpressionsColumns( ++ s.outputPartitioning ++ .asInstanceOf[HashPartitioning].expressions) === Seq("a", "b") => ++ s.outputPartitioning ++ .asInstanceOf[HashPartitioning] ++ }.head ++ val opB = join.right.collect { ++ case s: ShuffleExchangeLike ++ if s.outputPartitioning.isInstanceOf[HashPartitioning] && ++ partitionExpressionsColumns( ++ s.outputPartitioning ++ .asInstanceOf[HashPartitioning].expressions) === Seq("a", "b") => ++ s.outputPartitioning ++ .asInstanceOf[HashPartitioning] ++ }.head ++ assert(opA.numPartitions == numPartitions && opB.numPartitions == numPartitions) + }) + } + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala index abe606ad9c1..2d930b64cca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala @@ -1327,7 +2329,7 @@ index abe606ad9c1..2d930b64cca 100644 val tblTargetName = "tbl_target" val tblSourceQualified = s"default.$tblSourceName" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala -index dd55fcfe42c..b4776c50e49 100644 +index dd55fcfe42c..293e9dc2986 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTest @@ -1364,20 +2366,20 @@ index dd55fcfe42c..b4776c50e49 100644 + } + + /** -+ * Whether Spark should only apply Comet scan optimization. 
This is only effective when ++ * Whether to enable ansi mode This is only effective when + * [[isCometEnabled]] returns true. + */ -+ protected def isCometScanOnly: Boolean = { -+ val v = System.getenv("ENABLE_COMET_SCAN_ONLY") ++ protected def enableCometAnsiMode: Boolean = { ++ val v = System.getenv("ENABLE_COMET_ANSI_MODE") + v != null && v.toBoolean + } + + /** -+ * Whether to enable ansi mode This is only effective when ++ * Whether Spark should only apply Comet scan optimization. This is only effective when + * [[isCometEnabled]] returns true. + */ -+ protected def enableCometAnsiMode: Boolean = { -+ val v = System.getenv("ENABLE_COMET_ANSI_MODE") ++ protected def isCometScanOnly: Boolean = { ++ val v = System.getenv("ENABLE_COMET_SCAN_ONLY") + v != null && v.toBoolean + } + @@ -1394,7 +2396,7 @@ index dd55fcfe42c..b4776c50e49 100644 spark.internalCreateDataFrame(withoutFilters.execute(), schema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala -index ed2e309fa07..f64cc283903 100644 +index ed2e309fa07..e071fc44960 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -74,6 +74,28 @@ trait SharedSparkSessionBase @@ -1414,6 +2416,7 @@ index ed2e309fa07..f64cc283903 100644 + .set("spark.shuffle.manager", + "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager") + .set("spark.comet.exec.shuffle.enabled", "true") ++ .set("spark.comet.memoryOverhead", "10g") + } + + if (enableCometAnsiMode) { @@ -1421,11 +2424,23 @@ index ed2e309fa07..f64cc283903 100644 + .set("spark.sql.ansi.enabled", "true") + .set("spark.comet.ansi.enabled", "true") + } -+ + } conf.set( StaticSQLConf.WAREHOUSE_PATH, conf.get(StaticSQLConf.WAREHOUSE_PATH) + "/" + getClass.getCanonicalName) +diff --git a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala +index 1510e8957f9..7618419d8ff 100644 +--- a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala ++++ b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceWithActualMetricsSuite.scala +@@ -43,7 +43,7 @@ class SqlResourceWithActualMetricsSuite + import testImplicits._ + + // Exclude nodes which may not have the metrics +- val excludedNodes = List("WholeStageCodegen", "Project", "SerializeFromObject") ++ val excludedNodes = List("WholeStageCodegen", "Project", "SerializeFromObject", "RowToColumnar") + + implicit val formats = new DefaultFormats { + override def dateFormatter = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/DynamicPartitionPruningHiveScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/DynamicPartitionPruningHiveScanSuite.scala index 52abd248f3a..7a199931a08 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/DynamicPartitionPruningHiveScanSuite.scala @@ -1463,10 +2478,10 @@ index 1966e1e64fd..cde97a0aafe 100644 spark.sql( """ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala -index 07361cfdce9..1763168a808 100644 +index 07361cfdce9..25b0dc3ef7e 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala -@@ -55,25 +55,54 @@ object TestHive +@@ -55,25 +55,53 @@ object TestHive new SparkContext( System.getProperty("spark.sql.test.master", "local[1]"), "TestSQLContext", @@ -1530,9 +2545,8 @@ index 07361cfdce9..1763168a808 100644 + .set("spark.sql.ansi.enabled", "true") + .set("spark.comet.ansi.enabled", "true") + } - + } -+ + + conf + } + )) diff --git a/docs/source/contributor-guide/development.md b/docs/source/contributor-guide/development.md index 356c81b33..b0d247c0b 100644 --- a/docs/source/contributor-guide/development.md +++ b/docs/source/contributor-guide/development.md @@ -73,6 +73,18 @@ After that you can open the project in CLion. The IDE should automatically detec Like other Maven projects, you can run tests in IntelliJ IDEA by right-clicking on the test class or test method and selecting "Run" or "Debug". However if the tests is related to the native side. Please make sure to run `make core` or `cd core && cargo build` before running the tests in IDEA. +### Running Tests from command line + +It is possible to specify which ScalaTest suites you want to run from the CLI using the `suites` +argument, for example if you only want to execute the test cases that contains *valid* +in their name in `org.apache.comet.CometCastSuite` you can use + +```sh +mvn test -Dsuites="org.apache.comet.CometCastSuite valid" -Dskip.surefire.tests=true +``` + +Other options for selecting specific suites are described in the [ScalaTest Maven Plugin documentation](https://www.scalatest.org/user_guide/using_the_scalatest_maven_plugin) + ## Benchmark There's a `make` command to run micro benchmarks in the repo. For diff --git a/docs/source/user-guide/compatibility.md b/docs/source/user-guide/compatibility.md index 57a4271f4..a4ed9289f 100644 --- a/docs/source/user-guide/compatibility.md +++ b/docs/source/user-guide/compatibility.md @@ -88,11 +88,25 @@ The following cast operations are generally compatible with Spark except for the | long | double | | | long | string | | | float | boolean | | +| float | byte | | +| float | short | | +| float | integer | | +| float | long | | | float | double | | +| float | decimal | | | float | string | There can be differences in precision. For example, the input "1.4E-45" will produce 1.0E-45 instead of 1.4E-45 | | double | boolean | | +| double | byte | | +| double | short | | +| double | integer | | +| double | long | | | double | float | | +| double | decimal | | | double | string | There can be differences in precision. For example, the input "1.4E-45" will produce 1.0E-45 instead of 1.4E-45 | +| decimal | byte | | +| decimal | short | | +| decimal | integer | | +| decimal | long | | | decimal | float | | | decimal | double | | | string | boolean | | @@ -115,8 +129,6 @@ The following cast operations are not compatible with Spark for all inputs and a |-|-|-| | integer | decimal | No overflow check | | long | decimal | No overflow check | -| float | decimal | No overflow check | -| double | decimal | No overflow check | | string | timestamp | Not all valid formats are supported | | binary | string | Only works for binary data representing valid UTF-8 strings | diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 5a4c56dee..6932db0c1 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -29,7 +29,7 @@ Comet provides the following configuration settings. 
| spark.comet.columnar.shuffle.async.enabled | Whether to enable asynchronous shuffle for Arrow-based shuffle. By default, this config is false. | false | | spark.comet.columnar.shuffle.async.max.thread.num | Maximum number of threads on an executor used for Comet async columnar shuffle. By default, this config is 100. This is the upper bound of total number of shuffle threads per executor. In other words, if the number of cores * the number of shuffle threads per task `spark.comet.columnar.shuffle.async.thread.num` is larger than this config. Comet will use this config as the number of shuffle threads per executor instead. | 100 | | spark.comet.columnar.shuffle.async.thread.num | Number of threads used for Comet async columnar shuffle per shuffle task. By default, this config is 3. Note that more threads means more memory requirement to buffer shuffle data before flushing to disk. Also, more threads may not always improve performance, and should be set based on the number of cores available. | 3 | -| spark.comet.columnar.shuffle.enabled | Force Comet to only use columnar shuffle for CometScan and Spark regular operators. If this is enabled, Comet native shuffle will not be enabled but only Arrow shuffle. By default, this config is false. | false | +| spark.comet.columnar.shuffle.enabled | Whether to enable Arrow-based columnar shuffle for Comet and Spark regular operators. If this is enabled, Comet prefers columnar shuffle than native shuffle. By default, this config is true. | true | | spark.comet.columnar.shuffle.memory.factor | Fraction of Comet memory to be allocated per executor process for Comet shuffle. Comet memory size is specified by `spark.comet.memoryOverhead` or calculated by `spark.comet.memory.overhead.factor` * `spark.executor.memory`. By default, this config is 1.0. | 1.0 | | spark.comet.debug.enabled | Whether to enable debug mode for Comet. By default, this config is false. When enabled, Comet will do additional checks for debugging purpose. For example, validating array when importing arrays from JVM at native side. Note that these checks may be expensive in performance and should only be enabled for debugging purpose. | false | | spark.comet.enabled | Whether to enable Comet extension for Spark. When this is turned on, Spark will use Comet to read Parquet data source. Note that to enable native vectorized execution, both this config and 'spark.comet.exec.enabled' need to be enabled. By default, this config is the value of the env var `ENABLE_COMET` if set, or true otherwise. | true | diff --git a/pom.xml b/pom.xml index d47953fac..5b054f001 100644 --- a/pom.xml +++ b/pom.xml @@ -88,7 +88,8 @@ under the License. -ea -Xmx4g -Xss4m ${extraJavaTestArgs} spark-3.3-plus spark-3.4 - spark-3.x + spark-3.x + spark-3.4 @@ -500,6 +501,7 @@ under the License. not-needed-yet not-needed-yet + spark-3.2 @@ -512,6 +514,7 @@ under the License. 1.12.0 spark-3.3-plus not-needed-yet + spark-3.3 @@ -523,6 +526,7 @@ under the License. 1.13.1 spark-3.3-plus spark-3.4 + spark-3.4 @@ -725,6 +729,7 @@ under the License. file:src/test/resources/log4j2.properties + ${skip.surefire.tests} diff --git a/spark/pom.xml b/spark/pom.xml index 7d3d3d758..21fa09fc2 100644 --- a/spark/pom.xml +++ b/spark/pom.xml @@ -263,7 +263,8 @@ under the License. 
- src/main/${shims.source} + src/main/${shims.majorVerSrc} + src/main/${shims.minorVerSrc} diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala index 57e07b8cd..795bdb428 100644 --- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala +++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala @@ -226,19 +226,25 @@ object CometCast { } private def canCastFromFloat(toType: DataType): SupportLevel = toType match { - case DataTypes.BooleanType | DataTypes.DoubleType => Compatible() - case _: DecimalType => Incompatible(Some("No overflow check")) + case DataTypes.BooleanType | DataTypes.DoubleType | DataTypes.ByteType | DataTypes.ShortType | + DataTypes.IntegerType | DataTypes.LongType => + Compatible() + case _: DecimalType => Compatible() case _ => Unsupported } private def canCastFromDouble(toType: DataType): SupportLevel = toType match { - case DataTypes.BooleanType | DataTypes.FloatType => Compatible() - case _: DecimalType => Incompatible(Some("No overflow check")) + case DataTypes.BooleanType | DataTypes.FloatType | DataTypes.ByteType | DataTypes.ShortType | + DataTypes.IntegerType | DataTypes.LongType => + Compatible() + case _: DecimalType => Compatible() case _ => Unsupported } private def canCastFromDecimal(toType: DataType): SupportLevel = toType match { - case DataTypes.FloatType | DataTypes.DoubleType => Compatible() + case DataTypes.FloatType | DataTypes.DoubleType | DataTypes.ByteType | DataTypes.ShortType | + DataTypes.IntegerType | DataTypes.LongType => + Compatible() case _ => Unsupported } diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 86e9f10b9..7238990ad 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -47,12 +47,13 @@ import org.apache.comet.expressions.{CometCast, Compatible, Incompatible, Unsupp import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc} import org.apache.comet.serde.ExprOuterClass.DataType.{DataTypeInfo, DecimalInfo, ListInfo, MapInfo, StructInfo} import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, JoinType, Operator} +import org.apache.comet.shims.CometExprShim import org.apache.comet.shims.ShimQueryPlanSerde /** * An utility object for query plan and expression serialization. 
*/ -object QueryPlanSerde extends Logging with ShimQueryPlanSerde { +object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim { def emitWarning(reason: String): Unit = { logWarning(s"Comet native execution is disabled due to: $reason") } @@ -1467,6 +1468,16 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { val optExpr = scalarExprToProto("atan2", leftExpr, rightExpr) optExprWithInfo(optExpr, expr, left, right) + case e: Unhex if !isSpark32 => + val unHex = unhexSerde(e) + + val childExpr = exprToProtoInternal(unHex._1, inputs) + val failOnErrorExpr = exprToProtoInternal(unHex._2, inputs) + + val optExpr = + scalarExprToProtoWithReturnType("unhex", e.dataType, childExpr, failOnErrorExpr) + optExprWithInfo(optExpr, expr, unHex._1) + case e @ Ceil(child) => val childExpr = exprToProtoInternal(child, inputs) child.dataType match { @@ -2519,7 +2530,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { withInfo(join, "SortMergeJoin is not enabled") None - case op if isCometSink(op) => + case op if isCometSink(op) && op.output.forall(a => supportedDataType(a.dataType)) => // These operators are source of Comet native execution chain val scanBuilder = OperatorOuterClass.Scan.newBuilder() diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index fb2f2a209..49c263f3f 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -334,7 +334,7 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec { // seed by hashing will help. Refer to SPARK-21782 for more details. val partitionId = TaskContext.get().partitionId() var position = new XORShiftRandom(partitionId).nextInt(numPartitions) - (row: InternalRow) => { + (_: InternalRow) => { // The HashPartitioner will handle the `mod` by the number of partitions position += 1 position diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/plans/AliasAwareOutputExpression.scala b/spark/src/main/scala/org/apache/spark/sql/comet/plans/AliasAwareOutputExpression.scala index 6e5b44c8f..996526e55 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/plans/AliasAwareOutputExpression.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/plans/AliasAwareOutputExpression.scala @@ -31,7 +31,8 @@ import org.apache.spark.sql.catalyst.trees.CurrentOrigin trait AliasAwareOutputExpression extends SQLConfHelper { // `SQLConf.EXPRESSION_PROJECTION_CANDIDATE_LIMIT` is Spark 3.4+ only. // Use a default value for now. - protected val aliasCandidateLimit = 100 + protected val aliasCandidateLimit: Int = + conf.getConfString("spark.sql.optimizer.expressionProjectionCandidateLimit", "100").toInt protected def outputExpressions: Seq[NamedExpression] /** diff --git a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala new file mode 100644 index 000000000..0c45a9c2c --- /dev/null +++ b/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.expressions._ + +/** + * `CometExprShim` acts as a shim for for parsing expressions from different Spark versions. + */ +trait CometExprShim { + /** + * Returns a tuple of expressions for the `unhex` function. + */ + def unhexSerde(unhex: Unhex): (Expression, Expression) = { + (unhex.child, Literal(false)) + } +} diff --git a/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala new file mode 100644 index 000000000..0c45a9c2c --- /dev/null +++ b/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.expressions._ + +/** + * `CometExprShim` acts as a shim for for parsing expressions from different Spark versions. + */ +trait CometExprShim { + /** + * Returns a tuple of expressions for the `unhex` function. + */ + def unhexSerde(unhex: Unhex): (Expression, Expression) = { + (unhex.child, Literal(false)) + } +} diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala new file mode 100644 index 000000000..409e1c94b --- /dev/null +++ b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.expressions._ + +/** + * `CometExprShim` acts as a shim for for parsing expressions from different Spark versions. + */ +trait CometExprShim { + /** + * Returns a tuple of expressions for the `unhex` function. + */ + def unhexSerde(unhex: Unhex): (Expression, Expression) = { + (unhex.child, Literal(unhex.failOnError)) + } +} diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index 1d698a49a..1881c561c 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.functions.col import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DataType, DataTypes} +import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType} import org.apache.comet.expressions.{CometCast, Compatible} @@ -320,23 +320,19 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { castTest(generateFloats(), DataTypes.BooleanType) } - ignore("cast FloatType to ByteType") { - // https://github.com/apache/datafusion-comet/issues/350 + test("cast FloatType to ByteType") { castTest(generateFloats(), DataTypes.ByteType) } - ignore("cast FloatType to ShortType") { - // https://github.com/apache/datafusion-comet/issues/350 + test("cast FloatType to ShortType") { castTest(generateFloats(), DataTypes.ShortType) } - ignore("cast FloatType to IntegerType") { - // https://github.com/apache/datafusion-comet/issues/350 + test("cast FloatType to IntegerType") { castTest(generateFloats(), DataTypes.IntegerType) } - ignore("cast FloatType to LongType") { - // https://github.com/apache/datafusion-comet/issues/350 + test("cast FloatType to LongType") { castTest(generateFloats(), DataTypes.LongType) } @@ -344,8 +340,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { castTest(generateFloats(), DataTypes.DoubleType) } - ignore("cast FloatType to DecimalType(10,2)") { - // Comet should have failed with [NUMERIC_VALUE_OUT_OF_RANGE] + test("cast FloatType to DecimalType(10,2)") { castTest(generateFloats(), DataTypes.createDecimalType(10, 2)) } @@ -378,23 +373,19 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { castTest(generateDoubles(), DataTypes.BooleanType) } - ignore("cast DoubleType to ByteType") { - // https://github.com/apache/datafusion-comet/issues/350 + test("cast DoubleType to ByteType") { castTest(generateDoubles(), DataTypes.ByteType) } - ignore("cast DoubleType to ShortType") { - // https://github.com/apache/datafusion-comet/issues/350 + test("cast DoubleType to ShortType") { castTest(generateDoubles(), DataTypes.ShortType) } - ignore("cast DoubleType to IntegerType") { - // https://github.com/apache/datafusion-comet/issues/350 + test("cast DoubleType to IntegerType") { castTest(generateDoubles(), DataTypes.IntegerType) } - ignore("cast DoubleType to LongType") { - // https://github.com/apache/datafusion-comet/issues/350 + test("cast DoubleType to LongType") { castTest(generateDoubles(), DataTypes.LongType) } @@ -402,8 +393,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { 
castTest(generateDoubles(), DataTypes.FloatType) } - ignore("cast DoubleType to DecimalType(10,2)") { - // Comet should have failed with [NUMERIC_VALUE_OUT_OF_RANGE] + test("cast DoubleType to DecimalType(10,2)") { castTest(generateDoubles(), DataTypes.createDecimalType(10, 2)) } @@ -430,45 +420,57 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { ignore("cast DecimalType(10,2) to BooleanType") { // Arrow error: Cast error: Casting from Decimal128(38, 18) to Boolean not supported - castTest(generateDecimals(), DataTypes.BooleanType) + castTest(generateDecimalsPrecision10Scale2(), DataTypes.BooleanType) } - ignore("cast DecimalType(10,2) to ByteType") { - // https://github.com/apache/datafusion-comet/issues/350 - castTest(generateDecimals(), DataTypes.ByteType) + test("cast DecimalType(10,2) to ByteType") { + castTest(generateDecimalsPrecision10Scale2(), DataTypes.ByteType) } - ignore("cast DecimalType(10,2) to ShortType") { - // https://github.com/apache/datafusion-comet/issues/350 - castTest(generateDecimals(), DataTypes.ShortType) + test("cast DecimalType(10,2) to ShortType") { + castTest(generateDecimalsPrecision10Scale2(), DataTypes.ShortType) } - ignore("cast DecimalType(10,2) to IntegerType") { - // https://github.com/apache/datafusion-comet/issues/350 - castTest(generateDecimals(), DataTypes.IntegerType) + test("cast DecimalType(10,2) to IntegerType") { + castTest(generateDecimalsPrecision10Scale2(), DataTypes.IntegerType) } - ignore("cast DecimalType(10,2) to LongType") { - // https://github.com/apache/datafusion-comet/issues/350 - castTest(generateDecimals(), DataTypes.LongType) + test("cast DecimalType(10,2) to LongType") { + castTest(generateDecimalsPrecision10Scale2(), DataTypes.LongType) } test("cast DecimalType(10,2) to FloatType") { - castTest(generateDecimals(), DataTypes.FloatType) + castTest(generateDecimalsPrecision10Scale2(), DataTypes.FloatType) } test("cast DecimalType(10,2) to DoubleType") { - castTest(generateDecimals(), DataTypes.DoubleType) + castTest(generateDecimalsPrecision10Scale2(), DataTypes.DoubleType) + } + + test("cast DecimalType(38,18) to ByteType") { + castTest(generateDecimalsPrecision38Scale18(), DataTypes.ByteType) + } + + test("cast DecimalType(38,18) to ShortType") { + castTest(generateDecimalsPrecision38Scale18(), DataTypes.ShortType) + } + + test("cast DecimalType(38,18) to IntegerType") { + castTest(generateDecimalsPrecision38Scale18(), DataTypes.IntegerType) + } + + test("cast DecimalType(38,18) to LongType") { + castTest(generateDecimalsPrecision38Scale18(), DataTypes.LongType) } ignore("cast DecimalType(10,2) to StringType") { // input: 0E-18, expected: 0E-18, actual: 0.000000000000000000 - castTest(generateDecimals(), DataTypes.StringType) + castTest(generateDecimalsPrecision10Scale2(), DataTypes.StringType) } ignore("cast DecimalType(10,2) to TimestampType") { // input: -123456.789000000000000000, expected: 1969-12-30 05:42:23.211, actual: 1969-12-31 15:59:59.876544 - castTest(generateDecimals(), DataTypes.TimestampType) + castTest(generateDecimalsPrecision10Scale2(), DataTypes.TimestampType) } // CAST from StringType @@ -800,9 +802,47 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { withNulls(values).toDF("a") } - private def generateDecimals(): DataFrame = { - // TODO improve this - val values = Seq(BigDecimal("123456.789"), BigDecimal("-123456.789"), BigDecimal("0.0")) + private def generateDecimalsPrecision10Scale2(): DataFrame = { + val values = Seq( + BigDecimal("-99999999.999"), 
+ BigDecimal("-123456.789"), + BigDecimal("-32768.678"), + // Short Min + BigDecimal("-32767.123"), + BigDecimal("-128.12312"), + // Byte Min + BigDecimal("-127.123"), + BigDecimal("0.0"), + // Byte Max + BigDecimal("127.123"), + BigDecimal("128.12312"), + BigDecimal("32767.122"), + // Short Max + BigDecimal("32768.678"), + BigDecimal("123456.789"), + BigDecimal("99999999.999")) + withNulls(values).toDF("b").withColumn("a", col("b").cast(DecimalType(10, 2))).drop("b") + } + + private def generateDecimalsPrecision38Scale18(): DataFrame = { + val values = Seq( + BigDecimal("-99999999999999999999.999999999999"), + BigDecimal("-9223372036854775808.234567"), + // Long Min + BigDecimal("-9223372036854775807.123123"), + BigDecimal("-2147483648.123123123"), + // Int Min + BigDecimal("-2147483647.123123123"), + BigDecimal("-123456.789"), + BigDecimal("0.00000000000"), + BigDecimal("123456.789"), + // Int Max + BigDecimal("2147483647.123123123"), + BigDecimal("2147483648.123123123"), + BigDecimal("9223372036854775807.123123"), + // Long Max + BigDecimal("9223372036854775808.234567"), + BigDecimal("99999999999999999999.999999999999")) withNulls(values).toDF("a") } @@ -867,26 +907,27 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - private def castFallbackTestTimezone( - input: DataFrame, - toType: DataType, - expectedMessage: String): Unit = { - withTempPath { dir => - val data = roundtripParquet(input, dir).coalesce(1) - data.createOrReplaceTempView("t") - - withSQLConf( - (SQLConf.ANSI_ENABLED.key, "false"), - (CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key, "true"), - (SQLConf.SESSION_LOCAL_TIMEZONE.key, "America/Los_Angeles")) { - val df = data.withColumn("converted", col("a").cast(toType)) - df.collect() - val str = - new ExtendedExplainInfo().generateExtendedInfo(df.queryExecution.executedPlan) - assert(str.contains(expectedMessage)) - } - } - } + // TODO Commented out to work around scalafix since this is currently unused. 
+ // private def castFallbackTestTimezone( + // input: DataFrame, + // toType: DataType, + // expectedMessage: String): Unit = { + // withTempPath { dir => + // val data = roundtripParquet(input, dir).coalesce(1) + // data.createOrReplaceTempView("t") + // + // withSQLConf( + // (SQLConf.ANSI_ENABLED.key, "false"), + // (CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key, "true"), + // (SQLConf.SESSION_LOCAL_TIMEZONE.key, "America/Los_Angeles")) { + // val df = data.withColumn("converted", col("a").cast(toType)) + // df.collect() + // val str = + // new ExtendedExplainInfo().generateExtendedInfo(df.queryExecution.executedPlan) + // assert(str.contains(expectedMessage)) + // } + // } + // } private def castTimestampTest(input: DataFrame, toType: DataType) = { withTempPath { dir => @@ -960,11 +1001,19 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { val cometMessageModified = cometMessage .replace("[CAST_INVALID_INPUT] ", "") .replace("[CAST_OVERFLOW] ", "") - assert(cometMessageModified == sparkMessage) + .replace("[NUMERIC_VALUE_OUT_OF_RANGE] ", "") + + if (sparkMessage.contains("cannot be represented as")) { + assert(cometMessage.contains("cannot be represented as")) + } else { + assert(cometMessageModified == sparkMessage) + } } else { // for Spark 3.2 we just make sure we are seeing a similar type of error if (sparkMessage.contains("causes overflow")) { assert(cometMessage.contains("due to an overflow")) + } else if (sparkMessage.contains("cannot be represented as")) { + assert(cometMessage.contains("cannot be represented as")) } else { // assume that this is an invalid input message in the form: // `invalid input syntax for type numeric: -9223372036854775809` diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index eb4429dc6..28027c5cb 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -1036,7 +1036,30 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } } + test("unhex") { + // When running against Spark 3.2, we include a bug fix for https://issues.apache.org/jira/browse/SPARK-40924 that + // was added in Spark 3.3, so although Comet's behavior is more correct when running against Spark 3.2, it is not + // the same (and this only applies to edge cases with hex inputs with lengths that are not divisible by 2) + assume(!isSpark32, "unhex function has incorrect behavior in 3.2") + val table = "unhex_table" + withTable(table) { + sql(s"create table $table(col string) using parquet") + + sql(s"""INSERT INTO $table VALUES + |('537061726B2053514C'), + |('737472696E67'), + |('\\0'), + |(''), + |('###'), + |('G123'), + |('hello'), + |('A1B'), + |('0A1B')""".stripMargin) + + checkSparkAnswerAndOperator(s"SELECT unhex(col) FROM $table") + } + } test("length, reverse, instr, replace, translate") { Seq(false, true).foreach { dictionary => withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { diff --git a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala index bfbd16749..114351fd1 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala @@ -587,7 +587,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with 
AdaptiveSpar } test("columnar shuffle on nested array") { - Seq("false", "true").foreach { execEnabled => + Seq("false", "true").foreach { _ => Seq(10, 201).foreach { numPartitions => Seq("1.0", "10.0").foreach { ratio => withSQLConf(CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) { diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala b/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala index 3342d750c..eb27dd36c 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala @@ -157,8 +157,9 @@ class CometTPCDSQuerySuite conf.set(CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key, "true") conf.set(CometConf.COMET_EXEC_ALL_EXPR_ENABLED.key, "true") conf.set(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, "true") + conf.set(CometConf.COMET_MEMORY_OVERHEAD.key, "20g") conf.set(MEMORY_OFFHEAP_ENABLED.key, "true") - conf.set(MEMORY_OFFHEAP_SIZE.key, "2g") + conf.set(MEMORY_OFFHEAP_SIZE.key, "20g") conf }
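
A note on the recurring pattern in the Spark test changes above: instead of listing Spark's `ShuffleExchangeExec` and Comet's `CometShuffleExchangeExec` as separate cases, the patched suites match on the `ShuffleExchangeLike` trait that both operators implement, so the assertions hold no matter which exchange the planner picks. The sketch below is only an illustration of that pattern; the `ShuffleMatching` helper and its method names are assumptions for this example and are not part of this patch.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike

// Illustrative helper (not part of this patch): collect shuffles by the
// ShuffleExchangeLike trait so the same assertion passes whether the plan
// contains Spark's ShuffleExchangeExec or Comet's CometShuffleExchangeExec.
object ShuffleMatching {

  def collectShuffles(plan: SparkPlan): Seq[ShuffleExchangeLike] = {
    // Unwrap AQE when it is enabled, mirroring the suites patched above.
    val executed = plan match {
      case aqe: AdaptiveSparkPlanExec => aqe.executedPlan
      case other => other
    }
    executed.collect { case s: ShuffleExchangeLike => s }
  }

  def assertNumShuffles(spark: SparkSession, query: String, expected: Int): Unit = {
    val plan = spark.sql(query).queryExecution.executedPlan
    val shuffles = collectShuffles(plan)
    assert(
      shuffles.length == expected,
      s"expected $expected shuffle(s) but found ${shuffles.length} in:\n$plan")
  }
}
```

Matching on the shared trait keeps the Spark test diff small and should cover future Comet exchange operators without further test edits, which appears to be the same motivation behind mapping `CometHashAggregateExec` back to its `originalPlan` in the streaming aggregation suite above.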