From 27f167bad4ac2b90457ecce49682ef2932726c3b Mon Sep 17 00:00:00 2001
From: comphead
Date: Tue, 27 Feb 2024 16:22:54 -0800
Subject: [PATCH 1/3] doc: Add Quickstart Comet doc section (#125)

Co-authored-by: o_voievodin
---
 README.md | 46 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 46 insertions(+)

diff --git a/README.md b/README.md
index f0786099d..f48dfd932 100644
--- a/README.md
+++ b/README.md
@@ -58,3 +58,49 @@ Linux, Apple OSX (Intel and M1)
 - Apache Spark 3.2, 3.3, or 3.4
 - JDK 8 and up
 - GLIBC 2.17 (Centos 7) and up
+
+## Getting started
+
+Make sure the requirements above are met and the software is installed on your machine.
+
+### Clone repo
+```commandline
+git clone https://github.com/apache/arrow-datafusion-comet.git
+```
+
+### Specify the Spark version and build Comet
+Spark 3.4 is used in the example below.
+```
+cd arrow-datafusion-comet
+make release PROFILES="-Pspark-3.4"
+```
+
+### Run Spark with Comet enabled
+Make sure `SPARK_HOME` points to the same Spark version that Comet was built for.
+
+```
+$SPARK_HOME/bin/spark-shell --jars spark/target/comet-spark-spark3.4_2.12-0.1.0-SNAPSHOT.jar \
+--conf spark.sql.extensions=org.apache.comet.CometSparkSessionExtensions \
+--conf spark.comet.enabled=true \
+--conf spark.comet.exec.enabled=true \
+--conf spark.comet.exec.all.enabled=true
+```
+
+### Verify Comet is enabled for a Spark SQL query
+
+Create a test Parquet source:
+```scala
+scala> (0 until 10).toDF("a").write.mode("overwrite").parquet("/tmp/test")
+```
+
+Query the data from the test source and check that:
+- The INFO message shows that the native Comet library has been initialized.
+- The query plan shows Comet operators being used for this query instead of Spark ones.
+```scala
+scala> spark.read.parquet("/tmp/test").createOrReplaceTempView("t1"); spark.sql("select * from t1 where a > 5").explain
+INFO src/lib.rs: Comet native library initialized
+== Physical Plan ==
+ *(1) ColumnarToRow
+ +- CometFilter [a#14], (isnotnull(a#14) AND (a#14 > 5))
++- CometScan parquet [a#14] Batched: true, DataFilters: [isnotnull(a#14), (a#14 > 5)], Format: CometParquet, Location: InMemoryFileIndex(1 paths)[file:/tmp/test], PartitionFilters: [], PushedFilters: [IsNotNull(a), GreaterThan(a,5)], ReadSchema: struct
+```
\ No newline at end of file

From ee977c3d4277d57a657b97e05c077d66798e6457 Mon Sep 17 00:00:00 2001
From: comphead
Date: Tue, 27 Feb 2024 17:17:18 -0800
Subject: [PATCH 2/3] build: Upgrade DF to 36.0.0 and arrow-rs 50.0.0 (#66)

* Upgrade DF and arrow-rs

* fix benches

* fix merge

* fix merge

* Update core/src/execution/datafusion/expressions/scalar_funcs.rs

Co-authored-by: Liang-Chi Hsieh

* Update core/src/execution/datafusion/expressions/scalar_funcs.rs

Co-authored-by: Liang-Chi Hsieh

---------

Co-authored-by: o_voievodin
Co-authored-by: Liang-Chi Hsieh
---
 core/Cargo.lock | 93 +++++++++++++------
 core/Cargo.toml | 6 +-
 core/benches/common.rs | 2 +
 .../execution/datafusion/expressions/avg.rs | 12 +--
 .../datafusion/expressions/avg_decimal.rs | 10 +-
 .../datafusion/expressions/scalar_funcs.rs | 20 ++--
 .../datafusion/expressions/sum_decimal.rs | 10 +-
 core/src/execution/datafusion/planner.rs | 2 +
 core/src/execution/operators/copy.rs | 5 +-
 core/src/execution/operators/scan.rs | 4 +-
 10 files changed, 103 insertions(+), 61 deletions(-)

diff --git a/core/Cargo.lock b/core/Cargo.lock
index 0f262c03c..456d96966 100644
--- a/core/Cargo.lock
+++ b/core/Cargo.lock
@@ -492,16 +492,16 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
 
 [[package]]
 name = 
"chrono" -version = "0.4.31" +version = "0.4.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38" +checksum = "5bc015644b92d5890fab7489e49d21f879d5c990186827d42ec511919404f38b" dependencies = [ "android-tzdata", "iana-time-zone", "js-sys", "num-traits", "wasm-bindgen", - "windows-targets 0.48.5", + "windows-targets 0.52.0", ] [[package]] @@ -650,8 +650,8 @@ version = "7.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c64043d6c7b7a4c58e39e7efccfdea7b93d885a795d0c054a69dbbf4dd52686" dependencies = [ - "strum", - "strum_macros", + "strum 0.25.0", + "strum_macros 0.25.3", "unicode-width", ] @@ -833,9 +833,9 @@ dependencies = [ [[package]] name = "datafusion" -version = "35.0.0" +version = "36.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4328f5467f76d890fe3f924362dbc3a838c6a733f762b32d87f9e0b7bef5fb49" +checksum = "b2b360b692bf6c6d6e6b6dbaf41a3be0020daeceac0f406aed54c75331e50dbb" dependencies = [ "ahash", "arrow", @@ -849,6 +849,7 @@ dependencies = [ "datafusion-common", "datafusion-execution", "datafusion-expr", + "datafusion-functions", "datafusion-optimizer", "datafusion-physical-expr", "datafusion-physical-plan", @@ -874,9 +875,9 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "35.0.0" +version = "36.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d29a7752143b446db4a2cccd9a6517293c6b97e8c39e520ca43ccd07135a4f7e" +checksum = "37f343ccc298f440e25aa38ff82678291a7acc24061c7370ba6c0ff5cc811412" dependencies = [ "ahash", "arrow", @@ -893,9 +894,9 @@ dependencies = [ [[package]] name = "datafusion-execution" -version = "35.0.0" +version = "36.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d447650af16e138c31237f53ddaef6dd4f92f0e2d3f2f35d190e16c214ca496" +checksum = "3f9c93043081487e335399a21ebf8295626367a647ac5cb87d41d18afad7d0f7" dependencies = [ "arrow", "chrono", @@ -914,9 +915,9 @@ dependencies = [ [[package]] name = "datafusion-expr" -version = "35.0.0" +version = "36.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8d19598e48a498850fb79f97a9719b1f95e7deb64a7a06f93f313e8fa1d524b" +checksum = "e204d89909e678846b6a95f156aafc1ee5b36cb6c9e37ec2e1449b078a38c818" dependencies = [ "ahash", "arrow", @@ -924,15 +925,30 @@ dependencies = [ "datafusion-common", "paste", "sqlparser", - "strum", - "strum_macros", + "strum 0.26.1", + "strum_macros 0.26.1", +] + +[[package]] +name = "datafusion-functions" +version = "36.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "98f1c73f7801b2b8ba2297b3ad78ffcf6c1fc6b8171f502987eb9ad5cb244ee7" +dependencies = [ + "arrow", + "base64", + "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "hex", + "log", ] [[package]] name = "datafusion-optimizer" -version = "35.0.0" +version = "36.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b7feb0391f1fc75575acb95b74bfd276903dc37a5409fcebe160bc7ddff2010" +checksum = "5ae27e07bf1f04d327be5c2a293470879801ab5535204dc3b16b062fda195496" dependencies = [ "arrow", "async-trait", @@ -948,9 +964,9 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" -version = "35.0.0" +version = "36.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e911bca609c89a54e8f014777449d8290327414d3e10c57a3e3c2122e38878d0" +checksum = 
"dde620cd9ef76a3bca9c754fb68854bd2349c49f55baf97e08001f9e967f6d6b" dependencies = [ "ahash", "arrow", @@ -958,11 +974,13 @@ dependencies = [ "arrow-buffer", "arrow-ord", "arrow-schema", + "arrow-string", "base64", "blake2", "blake3", "chrono", "datafusion-common", + "datafusion-execution", "datafusion-expr", "half 2.1.0", "hashbrown 0.14.3", @@ -982,9 +1000,9 @@ dependencies = [ [[package]] name = "datafusion-physical-plan" -version = "35.0.0" +version = "36.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e96b546b8a02e9c2ab35ac6420d511f12a4701950c1eb2e568c122b4fefb0be3" +checksum = "9a4c75fba9ea99d64b2246cbd2fcae2e6fc973e6616b1015237a616036506dd4" dependencies = [ "ahash", "arrow", @@ -1013,9 +1031,9 @@ dependencies = [ [[package]] name = "datafusion-sql" -version = "35.0.0" +version = "36.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d18d36f260bbbd63aafdb55339213a23d540d3419810575850ef0a798a6b768" +checksum = "21474a95c3a62d113599d21b439fa15091b538bac06bd20be0bb2e7d22903c09" dependencies = [ "arrow", "arrow-schema", @@ -2516,9 +2534,9 @@ checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" [[package]] name = "sqlparser" -version = "0.41.0" +version = "0.43.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cc2c25a6c66789625ef164b4c7d2e548d627902280c13710d33da8222169964" +checksum = "f95c4bae5aba7cd30bd506f7140026ade63cff5afd778af8854026f9606bf5d4" dependencies = [ "log", "sqlparser_derive", @@ -2558,8 +2576,14 @@ name = "strum" version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" + +[[package]] +name = "strum" +version = "0.26.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "723b93e8addf9aa965ebe2d11da6d7540fa2283fcea14b3371ff055f7ba13f5f" dependencies = [ - "strum_macros", + "strum_macros 0.26.1", ] [[package]] @@ -2575,6 +2599,19 @@ dependencies = [ "syn 2.0.48", ] +[[package]] +name = "strum_macros" +version = "0.26.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a3417fc93d76740d974a01654a09777cb500428cc874ca9f45edfe0c4d4cd18" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.48", +] + [[package]] name = "subtle" version = "2.5.0" @@ -2740,9 +2777,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.35.1" +version = "1.36.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c89b4efa943be685f629b149f53829423f8f5531ea21249408e8e2f8671ec104" +checksum = "61285f6515fa018fb2d1e46eb21223fff441ee8db5d0f1435e8ab4f5cdb80931" dependencies = [ "backtrace", "bytes", diff --git a/core/Cargo.toml b/core/Cargo.toml index 14e271788..4dc5afe6f 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -66,9 +66,9 @@ itertools = "0.11.0" chrono = { version = "0.4", default-features = false, features = ["clock"] } chrono-tz = { version = "0.8" } paste = "1.0.14" -datafusion-common = { version = "35.0.0" } -datafusion = { default-features = false, version = "35.0.0", features = ["unicode_expressions"] } -datafusion-physical-expr = { version = "35.0.0", default-features = false , features = ["unicode_expressions"] } +datafusion-common = { version = "36.0.0" } +datafusion = { default-features = false, version = "36.0.0", features = ["unicode_expressions"] } 
+datafusion-physical-expr = { version = "36.0.0", default-features = false , features = ["unicode_expressions"] } unicode-segmentation = "^1.10.1" once_cell = "1.18.0" regex = "1.9.6" diff --git a/core/benches/common.rs b/core/benches/common.rs index 059721698..15952b83c 100644 --- a/core/benches/common.rs +++ b/core/benches/common.rs @@ -45,6 +45,7 @@ pub fn create_int64_array(size: usize, null_density: f32, min: i64, max: i64) -> .collect() } +#[allow(dead_code)] pub fn create_primitive_array(size: usize, null_density: f32) -> PrimitiveArray where T: ArrowPrimitiveType, @@ -64,6 +65,7 @@ where /// Creates a dictionary with random keys and values, with value type `T`. /// Note here the keys are the dictionary indices. +#[allow(dead_code)] pub fn create_dictionary_array( size: usize, value_size: usize, diff --git a/core/src/execution/datafusion/expressions/avg.rs b/core/src/execution/datafusion/expressions/avg.rs index dc2b34747..1e04ab0e9 100644 --- a/core/src/execution/datafusion/expressions/avg.rs +++ b/core/src/execution/datafusion/expressions/avg.rs @@ -24,11 +24,11 @@ use arrow_array::{ Array, ArrayRef, ArrowNumericType, Int64Array, PrimitiveArray, }; use arrow_schema::{DataType, Field}; -use datafusion::logical_expr::{type_coercion::aggregates::avg_return_type, Accumulator}; -use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; -use datafusion_physical_expr::{ - expressions::format_state_name, AggregateExpr, EmitTo, GroupsAccumulator, PhysicalExpr, +use datafusion::logical_expr::{ + type_coercion::aggregates::avg_return_type, Accumulator, EmitTo, GroupsAccumulator, }; +use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; +use datafusion_physical_expr::{expressions::format_state_name, AggregateExpr, PhysicalExpr}; use std::{any::Any, sync::Arc}; use arrow_array::ArrowNativeTypeOp; @@ -146,7 +146,7 @@ pub struct AvgAccumulator { } impl Accumulator for AvgAccumulator { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(vec![ ScalarValue::Float64(self.sum), ScalarValue::from(self.count), @@ -175,7 +175,7 @@ impl Accumulator for AvgAccumulator { Ok(()) } - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { Ok(ScalarValue::Float64( self.sum.map(|f| f / self.count as f64), )) diff --git a/core/src/execution/datafusion/expressions/avg_decimal.rs b/core/src/execution/datafusion/expressions/avg_decimal.rs index dc7bf1599..6fb558109 100644 --- a/core/src/execution/datafusion/expressions/avg_decimal.rs +++ b/core/src/execution/datafusion/expressions/avg_decimal.rs @@ -24,11 +24,9 @@ use arrow_array::{ Array, ArrayRef, Decimal128Array, Int64Array, PrimitiveArray, }; use arrow_schema::{DataType, Field}; -use datafusion::logical_expr::Accumulator; +use datafusion::logical_expr::{Accumulator, EmitTo, GroupsAccumulator}; use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; -use datafusion_physical_expr::{ - expressions::format_state_name, AggregateExpr, EmitTo, GroupsAccumulator, PhysicalExpr, -}; +use datafusion_physical_expr::{expressions::format_state_name, AggregateExpr, PhysicalExpr}; use std::{any::Any, sync::Arc}; use arrow_array::ArrowNativeTypeOp; @@ -214,7 +212,7 @@ impl AvgDecimalAccumulator { } impl Accumulator for AvgDecimalAccumulator { - fn state(&self) -> Result> { + fn state(&mut self) -> Result> { Ok(vec![ ScalarValue::Decimal128(self.sum, self.sum_precision, self.sum_scale), ScalarValue::from(self.count), @@ -266,7 +264,7 @@ impl Accumulator for AvgDecimalAccumulator { Ok(()) 
} - fn evaluate(&self) -> Result { + fn evaluate(&mut self) -> Result { fn make_decimal128(value: Option, precision: u8, scale: i8) -> ScalarValue { ScalarValue::Decimal128(value, precision, scale) } diff --git a/core/src/execution/datafusion/expressions/scalar_funcs.rs b/core/src/execution/datafusion/expressions/scalar_funcs.rs index 875956621..8ff13e125 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs.rs @@ -31,13 +31,11 @@ use datafusion::{ physical_plan::ColumnarValue, }; use datafusion_common::{ - cast::as_generic_string_array, internal_err, DataFusionError, Result as DataFusionResult, - ScalarValue, + cast::as_generic_string_array, exec_err, internal_err, DataFusionError, + Result as DataFusionResult, ScalarValue, }; use datafusion_physical_expr::{ - execution_props::ExecutionProps, - functions::{create_physical_fun, make_scalar_function}, - math_expressions, + execution_props::ExecutionProps, functions::create_physical_fun, math_expressions, }; use num::{BigInt, Signed, ToPrimitive}; use unicode_segmentation::UnicodeSegmentation; @@ -366,7 +364,12 @@ fn spark_round( let (precision, scale) = get_precision_scale(data_type); make_decimal_array(array, precision, scale, &f) } - _ => make_scalar_function(math_expressions::round)(args), + DataType::Float32 | DataType::Float64 => { + Ok(ColumnarValue::Array(math_expressions::round(&[ + array.clone() + ])?)) + } + dt => exec_err!("Not supported datatype for ROUND: {dt}"), }, ColumnarValue::Scalar(a) => match a { ScalarValue::Int64(a) if *point < 0 => { @@ -386,7 +389,10 @@ fn spark_round( let (precision, scale) = get_precision_scale(data_type); make_decimal_scalar(a, precision, scale, &f) } - _ => make_scalar_function(math_expressions::round)(args), + ScalarValue::Float32(_) | ScalarValue::Float64(_) => Ok(ColumnarValue::Scalar( + ScalarValue::try_from_array(&math_expressions::round(&[a.to_array()?])?, 0)?, + )), + dt => exec_err!("Not supported datatype for ROUND: {dt}"), }, } } diff --git a/core/src/execution/datafusion/expressions/sum_decimal.rs b/core/src/execution/datafusion/expressions/sum_decimal.rs index a6da5f579..2afbbf011 100644 --- a/core/src/execution/datafusion/expressions/sum_decimal.rs +++ b/core/src/execution/datafusion/expressions/sum_decimal.rs @@ -24,11 +24,9 @@ use arrow_array::{ }; use arrow_data::decimal::validate_decimal_precision; use arrow_schema::{DataType, Field}; -use datafusion::logical_expr::Accumulator; +use datafusion::logical_expr::{Accumulator, EmitTo, GroupsAccumulator}; use datafusion_common::{Result as DFResult, ScalarValue}; -use datafusion_physical_expr::{ - aggregate::utils::down_cast_any_ref, AggregateExpr, EmitTo, GroupsAccumulator, PhysicalExpr, -}; +use datafusion_physical_expr::{aggregate::utils::down_cast_any_ref, AggregateExpr, PhysicalExpr}; use std::{any::Any, ops::BitAnd, sync::Arc}; use crate::unlikely; @@ -204,7 +202,7 @@ impl Accumulator for SumDecimalAccumulator { Ok(()) } - fn evaluate(&self) -> DFResult { + fn evaluate(&mut self) -> DFResult { // For each group: // 1. if `is_empty` is true, it means either there is no value or all values for the group // are null, in this case we'll return null @@ -224,7 +222,7 @@ impl Accumulator for SumDecimalAccumulator { std::mem::size_of_val(self) } - fn state(&self) -> DFResult> { + fn state(&mut self) -> DFResult> { let sum = if self.is_not_null { ScalarValue::try_new_decimal128(self.sum, self.precision, self.scale)? 
} else { diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index 66a29cbb1..f4a0cec79 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -608,6 +608,7 @@ impl PhysicalPlanner { vec![left, right], data_type, None, + false, ))) } _ => Ok(Arc::new(BinaryExpr::new(left, op, right))), @@ -984,6 +985,7 @@ impl PhysicalPlanner { args.to_vec(), data_type, None, + args.is_empty(), )); Ok(scalar_expr) diff --git a/core/src/execution/operators/copy.rs b/core/src/execution/operators/copy.rs index c818d622d..996db2b47 100644 --- a/core/src/execution/operators/copy.rs +++ b/core/src/execution/operators/copy.rs @@ -28,7 +28,7 @@ use arrow_array::{ArrayRef, RecordBatch}; use arrow_schema::{DataType, Field, FieldRef, Schema, SchemaRef}; use datafusion::{execution::TaskContext, physical_expr::*, physical_plan::*}; -use datafusion_common::{DataFusionError, Result as DataFusionResult}; +use datafusion_common::{arrow_datafusion_err, DataFusionError, Result as DataFusionResult}; use super::copy_or_cast_array; @@ -141,8 +141,7 @@ impl CopyStream { .iter() .map(|v| copy_or_cast_array(v)) .collect::, _>>()?; - RecordBatch::try_new(self.schema.clone(), vectors) - .map_err(|err| DataFusionError::ArrowError(err, None)) + RecordBatch::try_new(self.schema.clone(), vectors).map_err(|e| arrow_datafusion_err!(e)) } } diff --git a/core/src/execution/operators/scan.rs b/core/src/execution/operators/scan.rs index 9f85de80f..e31230c58 100644 --- a/core/src/execution/operators/scan.rs +++ b/core/src/execution/operators/scan.rs @@ -43,7 +43,7 @@ use datafusion::{ physical_expr::*, physical_plan::{ExecutionPlan, *}, }; -use datafusion_common::{DataFusionError, Result as DataFusionResult}; +use datafusion_common::{arrow_datafusion_err, DataFusionError, Result as DataFusionResult}; use jni::{ objects::{GlobalRef, JLongArray, JObject, ReleaseMode}, sys::jlongArray, @@ -325,7 +325,7 @@ impl ScanStream { let options = RecordBatchOptions::new().with_row_count(Some(num_rows)); RecordBatch::try_new_with_options(self.schema.clone(), new_columns, &options) - .map_err(|err| DataFusionError::ArrowError(err, None)) + .map_err(|e| arrow_datafusion_err!(e)) } } From 4d103b88bf9d0165954e04a98f3eb928fdda2291 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 28 Feb 2024 00:53:15 -0800 Subject: [PATCH 3/3] fix: Fix corrupted AggregateMode when transforming plan parameters (#118) --- .../comet/CometSparkSessionExtensions.scala | 20 ++++++---- .../apache/spark/sql/comet/operators.scala | 40 ++++++++++++------- .../apache/comet/exec/CometExecSuite.scala | 16 +++++++- 3 files changed, 53 insertions(+), 23 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index f2aba74a0..10c332801 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -237,7 +237,13 @@ class CometSparkSessionExtensions val newOp = transform1(op) newOp match { case Some(nativeOp) => - CometProjectExec(nativeOp, op, op.projectList, op.output, op.child, None) + CometProjectExec( + nativeOp, + op, + op.projectList, + op.output, + op.child, + SerializedPlan(None)) case None => op } @@ -246,7 +252,7 @@ class CometSparkSessionExtensions val newOp = transform1(op) newOp match { case Some(nativeOp) => - CometFilterExec(nativeOp, op, op.condition, op.child, 
None) + CometFilterExec(nativeOp, op, op.condition, op.child, SerializedPlan(None)) case None => op } @@ -255,7 +261,7 @@ class CometSparkSessionExtensions val newOp = transform1(op) newOp match { case Some(nativeOp) => - CometSortExec(nativeOp, op, op.sortOrder, op.child, None) + CometSortExec(nativeOp, op, op.sortOrder, op.child, SerializedPlan(None)) case None => op } @@ -264,7 +270,7 @@ class CometSparkSessionExtensions val newOp = transform1(op) newOp match { case Some(nativeOp) => - CometLocalLimitExec(nativeOp, op, op.limit, op.child, None) + CometLocalLimitExec(nativeOp, op, op.limit, op.child, SerializedPlan(None)) case None => op } @@ -273,7 +279,7 @@ class CometSparkSessionExtensions val newOp = transform1(op) newOp match { case Some(nativeOp) => - CometGlobalLimitExec(nativeOp, op, op.limit, op.child, None) + CometGlobalLimitExec(nativeOp, op, op.limit, op.child, SerializedPlan(None)) case None => op } @@ -282,7 +288,7 @@ class CometSparkSessionExtensions val newOp = transform1(op) newOp match { case Some(nativeOp) => - CometExpandExec(nativeOp, op, op.projections, op.child, None) + CometExpandExec(nativeOp, op, op.projections, op.child, SerializedPlan(None)) case None => op } @@ -305,7 +311,7 @@ class CometSparkSessionExtensions child.output, if (modes.nonEmpty) Some(modes.head) else None, child, - None) + SerializedPlan(None)) case None => op } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 0298bc643..e75f9a4a5 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -150,7 +150,7 @@ abstract class CometNativeExec extends CometExec { * The serialized native query plan, optional. This is only defined when the current node is the * "boundary" node between native and Spark. */ - def serializedPlanOpt: Option[Array[Byte]] + def serializedPlanOpt: SerializedPlan /** The Comet native operator */ def nativeOp: Operator @@ -200,7 +200,7 @@ abstract class CometNativeExec extends CometExec { } override def doExecuteColumnar(): RDD[ColumnarBatch] = { - serializedPlanOpt match { + serializedPlanOpt.plan match { case None => // This is in the middle of a native execution, it should not be executed directly. throw new CometRuntimeException( @@ -282,11 +282,11 @@ abstract class CometNativeExec extends CometExec { */ def convertBlock(): CometNativeExec = { def transform(arg: Any): AnyRef = arg match { - case serializedPlan: Option[Array[Byte]] if serializedPlan.isEmpty => + case serializedPlan: SerializedPlan if serializedPlan.isEmpty => val out = new ByteArrayOutputStream() nativeOp.writeTo(out) out.close() - Some(out.toByteArray) + SerializedPlan(Some(out.toByteArray)) case other: AnyRef => other case null => null } @@ -300,8 +300,8 @@ abstract class CometNativeExec extends CometExec { */ def cleanBlock(): CometNativeExec = { def transform(arg: Any): AnyRef = arg match { - case serializedPlan: Option[Array[Byte]] if serializedPlan.isDefined => - None + case serializedPlan: SerializedPlan if serializedPlan.isDefined => + SerializedPlan(None) case other: AnyRef => other case null => null } @@ -323,13 +323,23 @@ abstract class CometNativeExec extends CometExec { abstract class CometUnaryExec extends CometNativeExec with UnaryExecNode +/** + * Represents the serialized plan of Comet native operators. 
Only the first operator in a block of + * continuous Comet native operators has defined plan bytes which contains the serialization of + * the plan tree of the block. + */ +case class SerializedPlan(plan: Option[Array[Byte]]) { + def isDefined: Boolean = plan.isDefined + def isEmpty: Boolean = plan.isEmpty +} + case class CometProjectExec( override val nativeOp: Operator, override val originalPlan: SparkPlan, projectList: Seq[NamedExpression], override val output: Seq[Attribute], child: SparkPlan, - override val serializedPlanOpt: Option[Array[Byte]]) + override val serializedPlanOpt: SerializedPlan) extends CometUnaryExec { override def producedAttributes: AttributeSet = outputSet override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = @@ -356,7 +366,7 @@ case class CometFilterExec( override val originalPlan: SparkPlan, condition: Expression, child: SparkPlan, - override val serializedPlanOpt: Option[Array[Byte]]) + override val serializedPlanOpt: SerializedPlan) extends CometUnaryExec { override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = this.copy(child = newChild) @@ -390,7 +400,7 @@ case class CometSortExec( override val originalPlan: SparkPlan, sortOrder: Seq[SortOrder], child: SparkPlan, - override val serializedPlanOpt: Option[Array[Byte]]) + override val serializedPlanOpt: SerializedPlan) extends CometUnaryExec { override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = this.copy(child = newChild) @@ -422,7 +432,7 @@ case class CometLocalLimitExec( override val originalPlan: SparkPlan, limit: Int, child: SparkPlan, - override val serializedPlanOpt: Option[Array[Byte]]) + override val serializedPlanOpt: SerializedPlan) extends CometUnaryExec { override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = this.copy(child = newChild) @@ -449,7 +459,7 @@ case class CometGlobalLimitExec( override val originalPlan: SparkPlan, limit: Int, child: SparkPlan, - override val serializedPlanOpt: Option[Array[Byte]]) + override val serializedPlanOpt: SerializedPlan) extends CometUnaryExec { override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = this.copy(child = newChild) @@ -474,7 +484,7 @@ case class CometExpandExec( override val originalPlan: SparkPlan, projections: Seq[Seq[Expression]], child: SparkPlan, - override val serializedPlanOpt: Option[Array[Byte]]) + override val serializedPlanOpt: SerializedPlan) extends CometUnaryExec { override def producedAttributes: AttributeSet = outputSet @@ -538,7 +548,7 @@ case class CometHashAggregateExec( input: Seq[Attribute], mode: Option[AggregateMode], child: SparkPlan, - override val serializedPlanOpt: Option[Array[Byte]]) + override val serializedPlanOpt: SerializedPlan) extends CometUnaryExec { override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = this.copy(child = newChild) @@ -576,7 +586,7 @@ case class CometHashAggregateExec( case class CometScanWrapper(override val nativeOp: Operator, override val originalPlan: SparkPlan) extends CometNativeExec with LeafExecNode { - override val serializedPlanOpt: Option[Array[Byte]] = None + override val serializedPlanOpt: SerializedPlan = SerializedPlan(None) override def stringArgs: Iterator[Any] = Iterator(originalPlan.output, originalPlan) } @@ -592,7 +602,7 @@ case class CometSinkPlaceHolder( override val originalPlan: SparkPlan, child: SparkPlan) extends CometUnaryExec { - override val serializedPlanOpt: Option[Array[Byte]] = None + override val serializedPlanOpt: 
SerializedPlan = SerializedPlan(None) override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = { this.copy(child = newChild) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 29b6e120a..05be34c10 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -31,13 +31,14 @@ import org.apache.spark.sql.{AnalysisException, Column, CometTestBase, DataFrame import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable} import org.apache.spark.sql.catalyst.expressions.Hex +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometFilterExec, CometHashAggregateExec, CometProjectExec, CometScanExec, CometTakeOrderedAndProjectExec} import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec} import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, UnionExec} import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec} import org.apache.spark.sql.execution.window.WindowExec -import org.apache.spark.sql.functions.{date_add, expr} +import org.apache.spark.sql.functions.{date_add, expr, sum} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE import org.apache.spark.unsafe.types.UTF8String @@ -57,6 +58,19 @@ class CometExecSuite extends CometTestBase { } } + test("Fix corrupted AggregateMode when transforming plan parameters") { + withParquetTable((0 until 5).map(i => (i, i + 1)), "table") { + val df = sql("SELECT * FROM table").groupBy($"_1").agg(sum("_2")) + val agg = stripAQEPlan(df.queryExecution.executedPlan).collectFirst { + case s: CometHashAggregateExec => s + }.get + + assert(agg.mode.isDefined && agg.mode.get.isInstanceOf[AggregateMode]) + val newAgg = agg.cleanBlock().asInstanceOf[CometHashAggregateExec] + assert(newAgg.mode.isDefined && newAgg.mode.get.isInstanceOf[AggregateMode]) + } + } + test("CometBroadcastExchangeExec") { withSQLConf(CometConf.COMET_EXEC_BROADCAST_ENABLED.key -> "true") { withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl_a") {
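
For context on PATCH 3/3: before the fix, `convertBlock()` and `cleanBlock()` rewrote operator parameters by pattern matching on `Option[Array[Byte]]`. On the JVM that type argument is erased, so at runtime the pattern matches any defined `Option`, including `CometHashAggregateExec`'s `mode: Option[AggregateMode]`, which is how the aggregate mode got corrupted. The sketch below is a minimal, standalone illustration of the pitfall and of why wrapping the bytes in a dedicated `SerializedPlan` type makes the match precise; the object, value, and method names are illustrative only and are not part of the Comet codebase.

```scala
// Minimal sketch (not Comet code): why matching on Option[Array[Byte]] can clear
// an unrelated Option parameter, and how a wrapper type avoids it.
object ErasureSketch extends App {
  // Stand-ins for two constructor parameters of a Comet native operator:
  // an aggregate mode and an optional serialized plan. Names are hypothetical.
  val mode: Option[String] = Some("Partial")
  val planBytes: Option[Array[Byte]] = Some(Array[Byte](1, 2, 3))

  // Pre-fix style: the type argument is erased at runtime, so this pattern
  // matches any defined Option (including `mode`) and replaces it with None.
  def cleanErased(arg: Any): Any = arg match {
    case p: Option[Array[Byte]] if p.isDefined => None
    case other => other
  }

  println(cleanErased(mode))      // None  <- the unintentionally cleared "mode"
  println(cleanErased(planBytes)) // None  <- the value we actually meant to clear

  // Post-fix style: a dedicated wrapper survives erasure, so the match is exact.
  case class SerializedPlan(plan: Option[Array[Byte]])

  def cleanWrapped(arg: Any): Any = arg match {
    case SerializedPlan(Some(_)) => SerializedPlan(None)
    case other => other
  }

  println(cleanWrapped(mode))                      // Some(Partial), left untouched
  println(cleanWrapped(SerializedPlan(planBytes))) // SerializedPlan(None)
}
```

The erased pattern also triggers an "unchecked" compiler warning, which is the usual hint that a transform over heterogeneous constructor parameters needs a dedicated wrapper type rather than a generic `Option` match; that is the approach the patch takes with `SerializedPlan`, and the regression test added to `CometExecSuite` checks exactly this by asserting that `mode` is still defined after `cleanBlock()`.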