From 15e7baa5d1f76a1a0720e4f848cd01a14c09e147 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 9 Jul 2024 15:07:01 -0600 Subject: [PATCH] feat: Upgrade to DataFusion 40.0.0-rc1 (#644) * Partial upgrade to DataFusion 40.0.0 * fix * implement more udaf * update bitwise agg * add func names * remove unused imports * remove arrow-string dep * fix copy and paste error * use 40.0.0-rc1 and temporarily ignore failing test * clippy * fall back to Spark for count windows aggregate * address feedback --- native/Cargo.lock | 128 ++++++++++-------- native/core/Cargo.toml | 24 ++-- .../execution/datafusion/expressions/abs.rs | 2 +- .../execution/datafusion/expressions/avg.rs | 2 +- .../execution/datafusion/operators/expand.rs | 4 + .../core/src/execution/datafusion/planner.rs | 119 ++++++++++++---- .../execution/datafusion/shuffle_writer.rs | 4 + native/core/src/execution/operators/copy.rs | 4 + native/core/src/execution/operators/scan.rs | 4 + .../apache/comet/serde/QueryPlanSerde.scala | 5 +- .../apache/comet/exec/CometExecSuite.scala | 2 +- 11 files changed, 202 insertions(+), 96 deletions(-) diff --git a/native/Cargo.lock b/native/Cargo.lock index 6136e0339..df1828ee0 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -114,8 +114,9 @@ checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" [[package]] name = "arrow" -version = "52.0.0" -source = "git+https://github.com/apache/arrow-rs.git?rev=0a4d8a1#0a4d8a14b58e45ef92e31541f0b51a5b25de5f10" +version = "52.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6127ea5e585a12ec9f742232442828ebaf264dfa5eefdd71282376c599562b77" dependencies = [ "arrow-arith", "arrow-array", @@ -134,8 +135,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "52.0.0" -source = "git+https://github.com/apache/arrow-rs.git?rev=0a4d8a1#0a4d8a14b58e45ef92e31541f0b51a5b25de5f10" +version = "52.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7add7f39210b7d726e2a8efc0083e7bf06e8f2d15bdb4896b564dce4410fbf5d" dependencies = [ "arrow-array", "arrow-buffer", @@ -148,8 +150,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "52.0.0" -source = "git+https://github.com/apache/arrow-rs.git?rev=0a4d8a1#0a4d8a14b58e45ef92e31541f0b51a5b25de5f10" +version = "52.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81c16ec702d3898c2f5cfdc148443c6cd7dbe5bac28399859eb0a3d38f072827" dependencies = [ "ahash", "arrow-buffer", @@ -164,8 +167,9 @@ dependencies = [ [[package]] name = "arrow-buffer" -version = "52.0.0" -source = "git+https://github.com/apache/arrow-rs.git?rev=0a4d8a1#0a4d8a14b58e45ef92e31541f0b51a5b25de5f10" +version = "52.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cae6970bab043c4fbc10aee1660ceb5b306d0c42c8cc5f6ae564efcd9759b663" dependencies = [ "bytes", "half", @@ -174,8 +178,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "52.0.0" -source = "git+https://github.com/apache/arrow-rs.git?rev=0a4d8a1#0a4d8a14b58e45ef92e31541f0b51a5b25de5f10" +version = "52.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c7ef44f26ef4f8edc392a048324ed5d757ad09135eff6d5509e6450d39e0398" dependencies = [ "arrow-array", "arrow-buffer", @@ -194,8 +199,9 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "52.0.0" -source = "git+https://github.com/apache/arrow-rs.git?rev=0a4d8a1#0a4d8a14b58e45ef92e31541f0b51a5b25de5f10" +version = "52.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f843490bd258c5182b66e888161bb6f198f49f3792f7c7f98198b924ae0f564" dependencies = [ "arrow-array", "arrow-buffer", @@ -212,8 +218,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "52.0.0" -source = "git+https://github.com/apache/arrow-rs.git?rev=0a4d8a1#0a4d8a14b58e45ef92e31541f0b51a5b25de5f10" +version = "52.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a769666ffac256dd301006faca1ca553d0ae7cffcf4cd07095f73f95eb226514" dependencies = [ "arrow-buffer", "arrow-schema", @@ -223,8 +230,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "52.0.0" -source = "git+https://github.com/apache/arrow-rs.git?rev=0a4d8a1#0a4d8a14b58e45ef92e31541f0b51a5b25de5f10" +version = "52.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbf9c3fb57390a1af0b7bb3b5558c1ee1f63905f3eccf49ae7676a8d1e6e5a72" dependencies = [ "arrow-array", "arrow-buffer", @@ -237,8 +245,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "52.0.0" -source = "git+https://github.com/apache/arrow-rs.git?rev=0a4d8a1#0a4d8a14b58e45ef92e31541f0b51a5b25de5f10" +version = "52.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "654e7f3724176b66ddfacba31af397c48e106fbe4d281c8144e7d237df5acfd7" dependencies = [ "arrow-array", "arrow-buffer", @@ -256,8 +265,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "52.0.0" -source = "git+https://github.com/apache/arrow-rs.git?rev=0a4d8a1#0a4d8a14b58e45ef92e31541f0b51a5b25de5f10" +version = "52.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8008370e624e8e3c68174faaf793540287106cfda8ad1da862fdc53d8e096b4" dependencies = [ "arrow-array", "arrow-buffer", @@ -270,8 +280,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "52.0.0" -source = "git+https://github.com/apache/arrow-rs.git?rev=0a4d8a1#0a4d8a14b58e45ef92e31541f0b51a5b25de5f10" +version = "52.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca5e3a6b7fda8d9fe03f3b18a2d946354ea7f3c8e4076dbdb502ad50d9d44824" dependencies = [ "ahash", "arrow-array", @@ -284,16 +295,18 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "52.0.0" -source = "git+https://github.com/apache/arrow-rs.git?rev=0a4d8a1#0a4d8a14b58e45ef92e31541f0b51a5b25de5f10" +version = "52.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dab1c12b40e29d9f3b699e0203c2a73ba558444c05e388a4377208f8f9c97eee" dependencies = [ "bitflags 2.6.0", ] [[package]] name = "arrow-select" -version = "52.0.0" -source = "git+https://github.com/apache/arrow-rs.git?rev=0a4d8a1#0a4d8a14b58e45ef92e31541f0b51a5b25de5f10" +version = "52.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e80159088ffe8c48965cb9b1a7c968b2729f29f37363df7eca177fc3281fe7c3" dependencies = [ "ahash", "arrow-array", @@ -305,8 +318,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "52.0.0" -source = "git+https://github.com/apache/arrow-rs.git?rev=0a4d8a1#0a4d8a14b58e45ef92e31541f0b51a5b25de5f10" +version = "52.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fd04a6ea7de183648edbcb7a6dd925bbd04c210895f6384c780e27a9b54afcd" dependencies = [ "arrow-array", "arrow-buffer", @@ -790,8 +804,8 @@ dependencies = [ [[package]] name = "datafusion" -version = "39.0.0" -source = "git+https://github.com/viirya/arrow-datafusion.git?rev=17446b1#17446b1886d2872be482efa4225d2b35e5d96569" +version = "40.0.0" +source = "git+https://github.com/apache/datafusion.git?rev=40.0.0-rc1#4cae81363e29f011c6602a7a7a54e1aaee841046" dependencies = [ "ahash", "arrow", @@ -890,8 +904,8 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "39.0.0" -source = "git+https://github.com/viirya/arrow-datafusion.git?rev=17446b1#17446b1886d2872be482efa4225d2b35e5d96569" +version = "40.0.0" +source = "git+https://github.com/apache/datafusion.git?rev=40.0.0-rc1#4cae81363e29f011c6602a7a7a54e1aaee841046" dependencies = [ "ahash", "arrow", @@ -910,16 +924,16 @@ dependencies = [ [[package]] name = "datafusion-common-runtime" -version = "39.0.0" -source = "git+https://github.com/viirya/arrow-datafusion.git?rev=17446b1#17446b1886d2872be482efa4225d2b35e5d96569" +version = "40.0.0" +source = "git+https://github.com/apache/datafusion.git?rev=40.0.0-rc1#4cae81363e29f011c6602a7a7a54e1aaee841046" dependencies = [ "tokio", ] [[package]] name = "datafusion-execution" -version = "39.0.0" -source = "git+https://github.com/viirya/arrow-datafusion.git?rev=17446b1#17446b1886d2872be482efa4225d2b35e5d96569" +version = "40.0.0" +source = "git+https://github.com/apache/datafusion.git?rev=40.0.0-rc1#4cae81363e29f011c6602a7a7a54e1aaee841046" dependencies = [ "arrow", "chrono", @@ -938,8 +952,8 @@ dependencies = [ [[package]] name = "datafusion-expr" -version = "39.0.0" -source = "git+https://github.com/viirya/arrow-datafusion.git?rev=17446b1#17446b1886d2872be482efa4225d2b35e5d96569" +version = "40.0.0" +source = "git+https://github.com/apache/datafusion.git?rev=40.0.0-rc1#4cae81363e29f011c6602a7a7a54e1aaee841046" dependencies = [ "ahash", "arrow", @@ -956,8 +970,8 @@ dependencies = [ [[package]] name = "datafusion-functions" -version = "39.0.0" -source = "git+https://github.com/viirya/arrow-datafusion.git?rev=17446b1#17446b1886d2872be482efa4225d2b35e5d96569" +version = "40.0.0" +source = "git+https://github.com/apache/datafusion.git?rev=40.0.0-rc1#4cae81363e29f011c6602a7a7a54e1aaee841046" dependencies = [ "arrow", "base64", @@ -967,7 +981,6 @@ dependencies = [ "datafusion-common", "datafusion-execution", "datafusion-expr", - "datafusion-physical-expr", "hashbrown", "hex", "itertools 0.12.1", @@ -982,8 +995,8 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate" -version = "39.0.0" -source = "git+https://github.com/viirya/arrow-datafusion.git?rev=17446b1#17446b1886d2872be482efa4225d2b35e5d96569" +version = "40.0.0" +source = "git+https://github.com/apache/datafusion.git?rev=40.0.0-rc1#4cae81363e29f011c6602a7a7a54e1aaee841046" dependencies = [ "ahash", "arrow", @@ -999,8 +1012,8 @@ dependencies = [ [[package]] name = "datafusion-optimizer" -version = "39.0.0" -source = "git+https://github.com/viirya/arrow-datafusion.git?rev=17446b1#17446b1886d2872be482efa4225d2b35e5d96569" +version = "40.0.0" +source = "git+https://github.com/apache/datafusion.git?rev=40.0.0-rc1#4cae81363e29f011c6602a7a7a54e1aaee841046" dependencies = [ "arrow", "async-trait", @@ -1012,13 +1025,14 @@ dependencies = [ "indexmap", "itertools 0.12.1", "log", + "paste", "regex-syntax", ] [[package]] name = "datafusion-physical-expr" -version = "39.0.0" -source = "git+https://github.com/viirya/arrow-datafusion.git?rev=17446b1#17446b1886d2872be482efa4225d2b35e5d96569" +version = "40.0.0" +source = "git+https://github.com/apache/datafusion.git?rev=40.0.0-rc1#4cae81363e29f011c6602a7a7a54e1aaee841046" dependencies = [ "ahash", "arrow", @@ -1032,7 +1046,6 @@ dependencies = [ "datafusion-common", "datafusion-execution", "datafusion-expr", - "datafusion-functions-aggregate", "datafusion-physical-expr-common", "half", "hashbrown", @@ -1047,19 +1060,21 @@ dependencies = [ [[package]] name = "datafusion-physical-expr-common" -version = "39.0.0" -source = "git+https://github.com/viirya/arrow-datafusion.git?rev=17446b1#17446b1886d2872be482efa4225d2b35e5d96569" +version = "40.0.0" +source = "git+https://github.com/apache/datafusion.git?rev=40.0.0-rc1#4cae81363e29f011c6602a7a7a54e1aaee841046" dependencies = [ + "ahash", "arrow", "datafusion-common", "datafusion-expr", + "hashbrown", "rand", ] [[package]] name = "datafusion-physical-plan" -version = "39.0.0" -source = "git+https://github.com/viirya/arrow-datafusion.git?rev=17446b1#17446b1886d2872be482efa4225d2b35e5d96569" +version = "40.0.0" +source = "git+https://github.com/apache/datafusion.git?rev=40.0.0-rc1#4cae81363e29f011c6602a7a7a54e1aaee841046" dependencies = [ "ahash", "arrow", @@ -1091,8 +1106,8 @@ dependencies = [ [[package]] name = "datafusion-sql" -version = "39.0.0" -source = "git+https://github.com/viirya/arrow-datafusion.git?rev=17446b1#17446b1886d2872be482efa4225d2b35e5d96569" +version = "40.0.0" +source = "git+https://github.com/apache/datafusion.git?rev=40.0.0-rc1#4cae81363e29f011c6602a7a7a54e1aaee841046" dependencies = [ "arrow", "arrow-array", @@ -2015,8 +2030,9 @@ dependencies = [ [[package]] name = "parquet" -version = "52.0.0" -source = "git+https://github.com/apache/arrow-rs.git?rev=0a4d8a1#0a4d8a14b58e45ef92e31541f0b51a5b25de5f10" +version = "52.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f22ba0d95db56dde8685e3fadcb915cdaadda31ab8abbe3ff7f0ad1ef333267" dependencies = [ "ahash", "bytes", diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index bd0a3d5e4..160db2949 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -32,12 +32,12 @@ include = [ ] [dependencies] -arrow = { git = "https://github.com/apache/arrow-rs.git", rev = "0a4d8a1", features = ["prettyprint", "ffi", "chrono-tz"] } -arrow-array = { git = "https://github.com/apache/arrow-rs.git", rev = "0a4d8a1" } -arrow-buffer = { git = "https://github.com/apache/arrow-rs.git", rev = "0a4d8a1" } -arrow-data = { git = "https://github.com/apache/arrow-rs.git", rev = "0a4d8a1" } -arrow-schema = { git = "https://github.com/apache/arrow-rs.git", rev = "0a4d8a1" } -parquet = { git = "https://github.com/apache/arrow-rs.git", rev = "0a4d8a1", default-features = false, features = ["experimental"] } +arrow = { version = "52.1.0", features = ["prettyprint", "ffi", "chrono-tz"] } +arrow-array = { version = "52.1.0" } +arrow-buffer = { version = "52.1.0" } +arrow-data = { version = "52.1.0" } +arrow-schema = { version = "52.1.0" } +parquet = { version = "52.1.0", default-features = false, features = ["experimental"] } half = { version = "2.4.1", default-features = false } futures = "0.3.28" mimalloc = { version = "*", default-features = false, optional = true } @@ -64,12 +64,12 @@ itertools = "0.11.0" chrono = { version = "0.4", default-features = false, features = ["clock"] } chrono-tz = { version = "0.8" } paste = "1.0.14" -datafusion-common = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "17446b1" } -datafusion = { default-features = false, git = "https://github.com/viirya/arrow-datafusion.git", rev = "17446b1", features = ["unicode_expressions", "crypto_expressions"] } -datafusion-functions = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "17446b1", features = ["crypto_expressions"] } -datafusion-expr = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "17446b1", default-features = false } -datafusion-physical-expr-common = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "17446b1", default-features = false } -datafusion-physical-expr = { git = "https://github.com/viirya/arrow-datafusion.git", rev = "17446b1", default-features = false } +datafusion-common = { git = "https://github.com/apache/datafusion.git", rev = "40.0.0-rc1" } +datafusion = { default-features = false, git = "https://github.com/apache/datafusion.git", rev = "40.0.0-rc1", features = ["unicode_expressions", "crypto_expressions"] } +datafusion-functions = { git = "https://github.com/apache/datafusion.git", rev = "40.0.0-rc1", features = ["crypto_expressions"] } +datafusion-expr = { git = "https://github.com/apache/datafusion.git", rev = "40.0.0-rc1", default-features = false } +datafusion-physical-expr-common = { git = "https://github.com/apache/datafusion.git", rev = "40.0.0-rc1", default-features = false } +datafusion-physical-expr = { git = "https://github.com/apache/datafusion.git", rev = "40.0.0-rc1", default-features = false } unicode-segmentation = "^1.10.1" once_cell = "1.18.0" regex = "1.9.6" diff --git a/native/core/src/execution/datafusion/expressions/abs.rs b/native/core/src/execution/datafusion/expressions/abs.rs index 4eb8c7c1e..a037e5cbc 100644 --- a/native/core/src/execution/datafusion/expressions/abs.rs +++ b/native/core/src/execution/datafusion/expressions/abs.rs @@ -37,7 +37,7 @@ impl CometAbsFunc { pub fn new(eval_mode: EvalMode, data_type_name: String) -> Result { if let EvalMode::Legacy | EvalMode::Ansi = eval_mode { Ok(Self { - inner_abs_func: math::abs().inner(), + inner_abs_func: math::abs().inner().clone(), eval_mode, data_type_name, }) diff --git a/native/core/src/execution/datafusion/expressions/avg.rs b/native/core/src/execution/datafusion/expressions/avg.rs index 1ff276e5d..3c8865bd1 100644 --- a/native/core/src/execution/datafusion/expressions/avg.rs +++ b/native/core/src/execution/datafusion/expressions/avg.rs @@ -47,7 +47,7 @@ pub struct Avg { impl Avg { /// Create a new AVG aggregate function pub fn new(expr: Arc, name: impl Into, data_type: DataType) -> Self { - let result_data_type = avg_return_type(&data_type).unwrap(); + let result_data_type = avg_return_type("avg", &data_type).unwrap(); Self { name: name.into(), diff --git a/native/core/src/execution/datafusion/operators/expand.rs b/native/core/src/execution/datafusion/operators/expand.rs index 5285dfb46..67171212f 100644 --- a/native/core/src/execution/datafusion/operators/expand.rs +++ b/native/core/src/execution/datafusion/operators/expand.rs @@ -126,6 +126,10 @@ impl ExecutionPlan for CometExpandExec { fn properties(&self) -> &PlanProperties { &self.cache } + + fn name(&self) -> &str { + "CometExpandExec" + } } pub struct ExpandStream { diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index 40515c0c4..360380400 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -20,19 +20,22 @@ use std::{collections::HashMap, sync::Arc}; use arrow_schema::{DataType, Field, Schema, TimeUnit}; +use datafusion::functions_aggregate::bit_and_or_xor::{bit_and_udaf, bit_or_udaf, bit_xor_udaf}; +use datafusion::functions_aggregate::count::count_udaf; +use datafusion::functions_aggregate::sum::sum_udaf; use datafusion::physical_plan::windows::BoundedWindowAggExec; use datafusion::physical_plan::InputOrderMode; use datafusion::{ arrow::{compute::SortOptions, datatypes::SchemaRef}, common::DataFusionError, execution::FunctionRegistry, + functions_aggregate::first_last::{FirstValue, LastValue}, logical_expr::Operator as DataFusionOperator, physical_expr::{ execution_props::ExecutionProps, expressions::{ - in_list, BinaryExpr, BitAnd, BitOr, BitXor, CaseExpr, CastExpr, Column, Count, - FirstValue, IsNotNullExpr, IsNullExpr, LastValue, Literal as DataFusionLiteral, Max, - Min, NotExpr, Sum, + in_list, BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, + Literal as DataFusionLiteral, Max, Min, NotExpr, }, AggregateExpr, PhysicalExpr, PhysicalSortExpr, ScalarFunctionExpr, }, @@ -647,7 +650,7 @@ impl PhysicalPlanner { let left = self.create_expr(left, input_schema.clone())?; let right = self.create_expr(right, input_schema.clone())?; match ( - op, + &op, left.data_type(&input_schema), right.data_type(&input_schema), ) { @@ -1208,11 +1211,19 @@ impl PhysicalPlanner { .iter() .map(|child| self.create_expr(child, schema.clone())) .collect::, _>>()?; - Ok(Arc::new(Count::new_with_multiple_exprs( - children, + + create_aggregate_expr( + &count_udaf(), + &children, + &[], + &[], + &[], + schema.as_ref(), "count", - DataType::Int64, - ))) + false, + false, + ) + .map_err(|e| ExecutionError::DataFusionError(e.to_string())) } AggExprStruct::Min(expr) => { let child = self.create_expr(expr.child.as_ref().unwrap(), schema)?; @@ -1236,7 +1247,18 @@ impl PhysicalPlanner { // cast to the result data type of SUM if necessary, we should not expect // a cast failure since it should have already been checked at Spark side let child = Arc::new(CastExpr::new(child, datatype.clone(), None)); - Ok(Arc::new(Sum::new(child, "sum", datatype))) + create_aggregate_expr( + &sum_udaf(), + &[child], + &[], + &[], + &[], + schema.as_ref(), + "sum", + false, + false, + ) + .map_err(|e| ExecutionError::DataFusionError(e.to_string())) } } } @@ -1263,31 +1285,79 @@ impl PhysicalPlanner { AggExprStruct::First(expr) => { let child = self.create_expr(expr.child.as_ref().unwrap(), schema.clone())?; let func = datafusion_expr::AggregateUDF::new_from_impl(FirstValue::new()); - - create_aggregate_expr(&func, &[child], &[], &[], &schema, "first", false, false) - .map_err(|e| e.into()) + create_aggregate_expr( + &func, + &[child], + &[], + &[], + &[], + &schema, + "first", + false, + false, + ) + .map_err(|e| e.into()) } AggExprStruct::Last(expr) => { let child = self.create_expr(expr.child.as_ref().unwrap(), schema.clone())?; let func = datafusion_expr::AggregateUDF::new_from_impl(LastValue::new()); - - create_aggregate_expr(&func, &[child], &[], &[], &schema, "last", false, false) - .map_err(|e| e.into()) + create_aggregate_expr( + &func, + &[child], + &[], + &[], + &[], + &schema, + "last", + false, + false, + ) + .map_err(|e| e.into()) } AggExprStruct::BitAndAgg(expr) => { - let child = self.create_expr(expr.child.as_ref().unwrap(), schema)?; - let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); - Ok(Arc::new(BitAnd::new(child, "bit_and", datatype))) + let child = self.create_expr(expr.child.as_ref().unwrap(), schema.clone())?; + create_aggregate_expr( + &bit_and_udaf(), + &[child], + &[], + &[], + &[], + &schema, + "bit_and", + false, + false, + ) + .map_err(|e| e.into()) } AggExprStruct::BitOrAgg(expr) => { - let child = self.create_expr(expr.child.as_ref().unwrap(), schema)?; - let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); - Ok(Arc::new(BitOr::new(child, "bit_or", datatype))) + let child = self.create_expr(expr.child.as_ref().unwrap(), schema.clone())?; + create_aggregate_expr( + &bit_or_udaf(), + &[child], + &[], + &[], + &[], + &schema, + "bit_or", + false, + false, + ) + .map_err(|e| e.into()) } AggExprStruct::BitXorAgg(expr) => { - let child = self.create_expr(expr.child.as_ref().unwrap(), schema)?; - let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); - Ok(Arc::new(BitXor::new(child, "bit_xor", datatype))) + let child = self.create_expr(expr.child.as_ref().unwrap(), schema.clone())?; + create_aggregate_expr( + &bit_xor_udaf(), + &[child], + &[], + &[], + &[], + &schema, + "bit_xor", + false, + false, + ) + .map_err(|e| e.into()) } AggExprStruct::Covariance(expr) => { let child1 = self.create_expr(expr.child1.as_ref().unwrap(), schema.clone())?; @@ -1483,6 +1553,7 @@ impl PhysicalPlanner { &window_func, window_func_name, &window_args, + &[], partition_by, sort_exprs, window_frame.into(), diff --git a/native/core/src/execution/datafusion/shuffle_writer.rs b/native/core/src/execution/datafusion/shuffle_writer.rs index 5afc9a53e..6e59ce53a 100644 --- a/native/core/src/execution/datafusion/shuffle_writer.rs +++ b/native/core/src/execution/datafusion/shuffle_writer.rs @@ -160,6 +160,10 @@ impl ExecutionPlan for ShuffleWriterExec { fn properties(&self) -> &PlanProperties { &self.cache } + + fn name(&self) -> &str { + "ShuffleWriterExec" + } } impl ShuffleWriterExec { diff --git a/native/core/src/execution/operators/copy.rs b/native/core/src/execution/operators/copy.rs index d011b3cb2..68c91aafc 100644 --- a/native/core/src/execution/operators/copy.rs +++ b/native/core/src/execution/operators/copy.rs @@ -126,6 +126,10 @@ impl ExecutionPlan for CopyExec { fn properties(&self) -> &PlanProperties { &self.cache } + + fn name(&self) -> &str { + "CopyExec" + } } struct CopyStream { diff --git a/native/core/src/execution/operators/scan.rs b/native/core/src/execution/operators/scan.rs index de5328210..68dd773cf 100644 --- a/native/core/src/execution/operators/scan.rs +++ b/native/core/src/execution/operators/scan.rs @@ -270,6 +270,10 @@ impl ExecutionPlan for ScanExec { fn properties(&self) -> &PlanProperties { &self.cache } + + fn name(&self) -> &str { + "ScanExec" + } } impl DisplayAs for ScanExec { diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 65de37c83..da534b02c 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -208,7 +208,10 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim expr match { case agg: AggregateExpression => agg.aggregateFunction match { - case _: Min | _: Max | _: Count => + // TODO add support for Count (this was removed when upgrading + // to DataFusion 40 because it is no longer a built-in window function) + // https://github.com/apache/datafusion-comet/issues/645 + case _: Min | _: Max => Some(agg) case _ => withInfo(windowExpr, "Unsupported aggregate", expr) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 9cc4e7f78..e657af9b9 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -1438,7 +1438,7 @@ class CometExecSuite extends CometTestBase { SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqeEnabled) { withParquetTable((0 until 10).map(i => (i, 10 - i)), "t1") { // TODO: test nulls val aggregateFunctions = - List("COUNT(_1)", "MAX(_1)", "MIN(_1)") // TODO: Test all the aggregates + List("MAX(_1)", "MIN(_1)") // TODO: Test all the aggregates aggregateFunctions.foreach { function => val queries = Seq(