diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 463de90c2..5aee02f11 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -19,6 +19,7 @@ package org.apache.comet +import java.util.Locale import java.util.concurrent.TimeUnit import scala.collection.mutable.ListBuffer @@ -131,14 +132,17 @@ object CometConf { .booleanConf .createWithDefault(false) - val COMET_COLUMNAR_SHUFFLE_ENABLED: ConfigEntry[Boolean] = - conf("spark.comet.columnar.shuffle.enabled") - .doc( - "Whether to enable Arrow-based columnar shuffle for Comet and Spark regular operators. " + - "If this is enabled, Comet prefers columnar shuffle than native shuffle. " + - "By default, this config is true.") - .booleanConf - .createWithDefault(true) + val COMET_SHUFFLE_MODE: ConfigEntry[String] = conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.mode") + .doc("The mode of Comet shuffle. This config is only effective if Comet shuffle " + + "is enabled. Available modes are 'native', 'jvm', and 'auto'. " + + "'native' is for native shuffle which has best performance in general. " + + "'jvm' is for jvm-based columnar shuffle which has higher coverage than native shuffle. " + + "'auto' is for Comet to choose the best shuffle mode based on the query plan. " + + "By default, this config is 'jvm'.") + .stringConf + .transform(_.toLowerCase(Locale.ROOT)) + .checkValues(Set("native", "jvm", "auto")) + .createWithDefault("jvm") val COMET_SHUFFLE_ENFORCE_MODE_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.shuffle.enforceMode.enabled") diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index f68732fb1..fd1f9166d 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -22,7 +22,6 @@ use std::{ sync::Arc, }; -use crate::errors::{CometError, CometResult}; use arrow::{ compute::{cast_with_options, CastOptions}, datatypes::{ @@ -33,23 +32,27 @@ use arrow::{ util::display::FormatOptions, }; use arrow_array::{ - types::{Int16Type, Int32Type, Int64Type, Int8Type}, + types::{Date32Type, Int16Type, Int32Type, Int64Type, Int8Type}, Array, ArrayRef, BooleanArray, Decimal128Array, Float32Array, Float64Array, GenericStringArray, Int16Array, Int32Array, Int64Array, Int8Array, OffsetSizeTrait, PrimitiveArray, }; use arrow_schema::{DataType, Schema}; -use chrono::{TimeZone, Timelike}; +use chrono::{NaiveDate, NaiveDateTime, TimeZone, Timelike}; use datafusion::logical_expr::ColumnarValue; use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue}; use datafusion_physical_expr::PhysicalExpr; use num::{cast::AsPrimitive, traits::CheckedNeg, CheckedSub, Integer, Num, ToPrimitive}; use regex::Regex; -use crate::execution::datafusion::expressions::utils::{ - array_with_timezone, down_cast_any_ref, spark_cast, +use crate::{ + errors::{CometError, CometResult}, + execution::datafusion::expressions::utils::{ + array_with_timezone, down_cast_any_ref, spark_cast, + }, }; static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f"); + static CAST_OPTIONS: CastOptions = CastOptions { safe: true, format_options: FormatOptions::new() @@ -511,6 +514,31 @@ impl Cast { (DataType::Utf8, DataType::Timestamp(_, _)) => { Self::cast_string_to_timestamp(&array, to_type, self.eval_mode)? 
} + (DataType::Utf8, DataType::Date32) => { + Self::cast_string_to_date(&array, to_type, self.eval_mode)? + } + (DataType::Dictionary(key_type, value_type), DataType::Date32) + if key_type.as_ref() == &DataType::Int32 + && (value_type.as_ref() == &DataType::Utf8 + || value_type.as_ref() == &DataType::LargeUtf8) => + { + match value_type.as_ref() { + DataType::Utf8 => { + let unpacked_array = + cast_with_options(&array, &DataType::Utf8, &CAST_OPTIONS)?; + Self::cast_string_to_date(&unpacked_array, to_type, self.eval_mode)? + } + DataType::LargeUtf8 => { + let unpacked_array = + cast_with_options(&array, &DataType::LargeUtf8, &CAST_OPTIONS)?; + Self::cast_string_to_date(&unpacked_array, to_type, self.eval_mode)? + } + dt => unreachable!( + "{}", + format!("invalid value type {dt} for dictionary-encoded string array") + ), + } + } (DataType::Int64, DataType::Int32) | (DataType::Int64, DataType::Int16) | (DataType::Int64, DataType::Int8) @@ -635,6 +663,38 @@ impl Cast { Ok(cast_array) } + fn cast_string_to_date( + array: &ArrayRef, + to_type: &DataType, + eval_mode: EvalMode, + ) -> CometResult { + let string_array = array + .as_any() + .downcast_ref::>() + .expect("Expected a string array"); + + let cast_array: ArrayRef = match to_type { + DataType::Date32 => { + let len = string_array.len(); + let mut cast_array = PrimitiveArray::::builder(len); + for i in 0..len { + if !string_array.is_null(i) { + match date_parser(string_array.value(i), eval_mode) { + Ok(Some(cast_value)) => cast_array.append_value(cast_value), + Ok(None) => cast_array.append_null(), + Err(e) => return Err(e), + } + } else { + cast_array.append_null() + } + } + Arc::new(cast_array.finish()) as ArrayRef + } + _ => unreachable!("Invalid data type {:?} in cast from string", to_type), + }; + Ok(cast_array) + } + fn cast_string_to_timestamp( array: &ArrayRef, to_type: &DataType, @@ -858,7 +918,7 @@ impl Cast { i32, "FLOAT", "INT", - std::i32::MAX, + i32::MAX, "{:e}" ), (DataType::Float32, DataType::Int64) => cast_float_to_int32_up!( @@ -870,7 +930,7 @@ impl Cast { i64, "FLOAT", "BIGINT", - std::i64::MAX, + i64::MAX, "{:e}" ), (DataType::Float64, DataType::Int8) => cast_float_to_int16_down!( @@ -904,7 +964,7 @@ impl Cast { i32, "DOUBLE", "INT", - std::i32::MAX, + i32::MAX, "{:e}D" ), (DataType::Float64, DataType::Int64) => cast_float_to_int32_up!( @@ -916,7 +976,7 @@ impl Cast { i64, "DOUBLE", "BIGINT", - std::i64::MAX, + i64::MAX, "{:e}D" ), (DataType::Decimal128(precision, scale), DataType::Int8) => { @@ -936,7 +996,7 @@ impl Cast { Int32Array, i32, "INT", - std::i32::MAX, + i32::MAX, *precision, *scale ) @@ -948,7 +1008,7 @@ impl Cast { Int64Array, i64, "BIGINT", - std::i64::MAX, + i64::MAX, *precision, *scale ) @@ -1264,15 +1324,15 @@ fn timestamp_parser(value: &str, eval_mode: EvalMode) -> CometResult } if timestamp.is_none() { - if eval_mode == EvalMode::Ansi { - return Err(CometError::CastInvalidValue { + return if eval_mode == EvalMode::Ansi { + Err(CometError::CastInvalidValue { value: value.to_string(), from_type: "STRING".to_string(), to_type: "TIMESTAMP".to_string(), - }); + }) } else { - return Ok(None); - } + Ok(None) + }; } match timestamp { @@ -1409,13 +1469,136 @@ fn parse_str_to_time_only_timestamp(value: &str) -> CometResult> { Ok(Some(timestamp)) } +//a string to date parser - port of spark's SparkDateTimeUtils#stringToDate. 
+fn date_parser(date_str: &str, eval_mode: EvalMode) -> CometResult> { + // local functions + fn get_trimmed_start(bytes: &[u8]) -> usize { + let mut start = 0; + while start < bytes.len() && is_whitespace_or_iso_control(bytes[start]) { + start += 1; + } + start + } + + fn get_trimmed_end(start: usize, bytes: &[u8]) -> usize { + let mut end = bytes.len() - 1; + while end > start && is_whitespace_or_iso_control(bytes[end]) { + end -= 1; + } + end + 1 + } + + fn is_whitespace_or_iso_control(byte: u8) -> bool { + byte.is_ascii_whitespace() || byte.is_ascii_control() + } + + fn is_valid_digits(segment: i32, digits: usize) -> bool { + // An integer is able to represent a date within [+-]5 million years. + let max_digits_year = 7; + //year (segment 0) can be between 4 to 7 digits, + //month and day (segment 1 and 2) can be between 1 to 2 digits + (segment == 0 && digits >= 4 && digits <= max_digits_year) + || (segment != 0 && digits > 0 && digits <= 2) + } + + fn return_result(date_str: &str, eval_mode: EvalMode) -> CometResult> { + if eval_mode == EvalMode::Ansi { + Err(CometError::CastInvalidValue { + value: date_str.to_string(), + from_type: "STRING".to_string(), + to_type: "DATE".to_string(), + }) + } else { + Ok(None) + } + } + // end local functions + + if date_str.is_empty() { + return return_result(date_str, eval_mode); + } + + //values of date segments year, month and day defaulting to 1 + let mut date_segments = [1, 1, 1]; + let mut sign = 1; + let mut current_segment = 0; + let mut current_segment_value = 0; + let mut current_segment_digits = 0; + let bytes = date_str.as_bytes(); + + let mut j = get_trimmed_start(bytes); + let str_end_trimmed = get_trimmed_end(j, bytes); + + if j == str_end_trimmed { + return return_result(date_str, eval_mode); + } + + //assign a sign to the date + if bytes[j] == b'-' || bytes[j] == b'+' { + sign = if bytes[j] == b'-' { -1 } else { 1 }; + j += 1; + } + + //loop to the end of string until we have processed 3 segments, + //exit loop on encountering any space ' ' or 'T' after the 3rd segment + while j < str_end_trimmed && (current_segment < 3 && !(bytes[j] == b' ' || bytes[j] == b'T')) { + let b = bytes[j]; + if current_segment < 2 && b == b'-' { + //check for validity of year and month segments if current byte is separator + if !is_valid_digits(current_segment, current_segment_digits) { + return return_result(date_str, eval_mode); + } + //if valid update corresponding segment with the current segment value. + date_segments[current_segment as usize] = current_segment_value; + current_segment_value = 0; + current_segment_digits = 0; + current_segment += 1; + } else if !b.is_ascii_digit() { + return return_result(date_str, eval_mode); + } else { + //increment value of current segment by the next digit + let parsed_value = (b - b'0') as i32; + current_segment_value = current_segment_value * 10 + parsed_value; + current_segment_digits += 1; + } + j += 1; + } + + //check for validity of last segment + if !is_valid_digits(current_segment, current_segment_digits) { + return return_result(date_str, eval_mode); + } + + if current_segment < 2 && j < str_end_trimmed { + // For the `yyyy` and `yyyy-[m]m` formats, entire input must be consumed. 
+ return return_result(date_str, eval_mode); + } + + date_segments[current_segment as usize] = current_segment_value; + + match NaiveDate::from_ymd_opt( + sign * date_segments[0], + date_segments[1] as u32, + date_segments[2] as u32, + ) { + Some(date) => { + let duration_since_epoch = date + .signed_duration_since(NaiveDateTime::UNIX_EPOCH.date()) + .num_days(); + Ok(Some(duration_since_epoch.to_i32().unwrap())) + } + None => Ok(None), + } +} + #[cfg(test)] mod tests { - use super::*; use arrow::datatypes::TimestampMicrosecondType; use arrow_array::StringArray; use arrow_schema::TimeUnit; + use super::*; + #[test] fn timestamp_parser_test() { // write for all formats @@ -1480,6 +1663,168 @@ mod tests { assert_eq!(result.len(), 2); } + #[test] + fn date_parser_test() { + for date in &[ + "2020", + "2020-01", + "2020-01-01", + "02020-01-01", + "002020-01-01", + "0002020-01-01", + "2020-1-1", + "2020-01-01 ", + "2020-01-01T", + ] { + for eval_mode in &[EvalMode::Legacy, EvalMode::Ansi, EvalMode::Try] { + assert_eq!(date_parser(*date, *eval_mode).unwrap(), Some(18262)); + } + } + + //dates in invalid formats + for date in &[ + "abc", + "", + "not_a_date", + "3/", + "3/12", + "3/12/2020", + "3/12/2002 T", + "202", + "2020-010-01", + "2020-10-010", + "2020-10-010T", + "--262143-12-31", + "--262143-12-31 ", + ] { + for eval_mode in &[EvalMode::Legacy, EvalMode::Try] { + assert_eq!(date_parser(*date, *eval_mode).unwrap(), None); + } + assert!(date_parser(*date, EvalMode::Ansi).is_err()); + } + + for date in &["-3638-5"] { + for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] { + assert_eq!(date_parser(*date, *eval_mode).unwrap(), Some(-2048160)); + } + } + + //Naive Date only supports years 262142 AD to 262143 BC + //returns None for dates out of range supported by Naive Date. 
+ for date in &[ + "-262144-1-1", + "262143-01-1", + "262143-1-1", + "262143-01-1 ", + "262143-01-01T ", + "262143-1-01T 1234", + "-0973250", + ] { + for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] { + assert_eq!(date_parser(*date, *eval_mode).unwrap(), None); + } + } + } + + #[test] + fn test_cast_string_to_date() { + let array: ArrayRef = Arc::new(StringArray::from(vec![ + Some("2020"), + Some("2020-01"), + Some("2020-01-01"), + Some("2020-01-01T"), + ])); + + let result = + Cast::cast_string_to_date(&array, &DataType::Date32, EvalMode::Legacy).unwrap(); + + let date32_array = result + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(date32_array.len(), 4); + date32_array + .iter() + .for_each(|v| assert_eq!(v.unwrap(), 18262)); + } + + #[test] + fn test_cast_string_array_with_valid_dates() { + let array_with_invalid_date: ArrayRef = Arc::new(StringArray::from(vec![ + Some("-262143-12-31"), + Some("\n -262143-12-31 "), + Some("-262143-12-31T \t\n"), + Some("\n\t-262143-12-31T\r"), + Some("-262143-12-31T 123123123"), + Some("\r\n-262143-12-31T \r123123123"), + Some("\n -262143-12-31T \n\t"), + ])); + + for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] { + let result = + Cast::cast_string_to_date(&array_with_invalid_date, &DataType::Date32, *eval_mode) + .unwrap(); + + let date32_array = result + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(result.len(), 7); + date32_array + .iter() + .for_each(|v| assert_eq!(v.unwrap(), -96464928)); + } + } + + #[test] + fn test_cast_string_array_with_invalid_dates() { + let array_with_invalid_date: ArrayRef = Arc::new(StringArray::from(vec![ + Some("2020"), + Some("2020-01"), + Some("2020-01-01"), + //4 invalid dates + Some("2020-010-01T"), + Some("202"), + Some(" 202 "), + Some("\n 2020-\r8 "), + Some("2020-01-01T"), + ])); + + for eval_mode in &[EvalMode::Legacy, EvalMode::Try] { + let result = + Cast::cast_string_to_date(&array_with_invalid_date, &DataType::Date32, *eval_mode) + .unwrap(); + + let date32_array = result + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + date32_array.iter().collect::>(), + vec![ + Some(18262), + Some(18262), + Some(18262), + None, + None, + None, + None, + Some(18262) + ] + ); + } + + let result = + Cast::cast_string_to_date(&array_with_invalid_date, &DataType::Date32, EvalMode::Ansi); + match result { + Err(e) => assert!( + e.to_string().contains( + "[CAST_INVALID_INPUT] The value '2020-010-01T' of the type \"STRING\" cannot be cast to \"DATE\" because it is malformed") + ), + _ => panic!("Expected error"), + } + } + #[test] fn test_cast_string_as_i8() { // basic diff --git a/core/src/execution/datafusion/expressions/correlation.rs b/core/src/execution/datafusion/expressions/correlation.rs new file mode 100644 index 000000000..c83341e54 --- /dev/null +++ b/core/src/execution/datafusion/expressions/correlation.rs @@ -0,0 +1,256 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::compute::{and, filter, is_not_null}; + +use std::{any::Any, sync::Arc}; + +use crate::execution::datafusion::expressions::{ + covariance::CovarianceAccumulator, stats::StatsType, stddev::StddevAccumulator, + utils::down_cast_any_ref, +}; +use arrow::{ + array::ArrayRef, + datatypes::{DataType, Field}, +}; +use datafusion::logical_expr::Accumulator; +use datafusion_common::{Result, ScalarValue}; +use datafusion_physical_expr::{expressions::format_state_name, AggregateExpr, PhysicalExpr}; + +/// CORR aggregate expression +/// The implementation mostly is the same as the DataFusion's implementation. The reason +/// we have our own implementation is that DataFusion has UInt64 for state_field `count`, +/// while Spark has Double for count. Also we have added `null_on_divide_by_zero` +/// to be consistent with Spark's implementation. +#[derive(Debug)] +pub struct Correlation { + name: String, + expr1: Arc, + expr2: Arc, + null_on_divide_by_zero: bool, +} + +impl Correlation { + pub fn new( + expr1: Arc, + expr2: Arc, + name: impl Into, + data_type: DataType, + null_on_divide_by_zero: bool, + ) -> Self { + // the result of correlation just support FLOAT64 data type. + assert!(matches!(data_type, DataType::Float64)); + Self { + name: name.into(), + expr1, + expr2, + null_on_divide_by_zero, + } + } +} + +impl AggregateExpr for Correlation { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, DataType::Float64, true)) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(CorrelationAccumulator::try_new( + self.null_on_divide_by_zero, + )?)) + } + + fn state_fields(&self) -> Result> { + Ok(vec![ + Field::new( + format_state_name(&self.name, "count"), + DataType::Float64, + true, + ), + Field::new( + format_state_name(&self.name, "mean1"), + DataType::Float64, + true, + ), + Field::new( + format_state_name(&self.name, "mean2"), + DataType::Float64, + true, + ), + Field::new( + format_state_name(&self.name, "algo_const"), + DataType::Float64, + true, + ), + Field::new( + format_state_name(&self.name, "m2_1"), + DataType::Float64, + true, + ), + Field::new( + format_state_name(&self.name, "m2_2"), + DataType::Float64, + true, + ), + ]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr1.clone(), self.expr2.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} + +impl PartialEq for Correlation { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| { + self.name == x.name + && self.expr1.eq(&x.expr1) + && self.expr2.eq(&x.expr2) + && self.null_on_divide_by_zero == x.null_on_divide_by_zero + }) + .unwrap_or(false) + } +} + +/// An accumulator to compute correlation +#[derive(Debug)] +pub struct CorrelationAccumulator { + covar: CovarianceAccumulator, + stddev1: StddevAccumulator, + stddev2: StddevAccumulator, + null_on_divide_by_zero: bool, +} + +impl CorrelationAccumulator { + /// Creates a new `CorrelationAccumulator` + pub fn try_new(null_on_divide_by_zero: bool) -> Result { + Ok(Self { + 
covar: CovarianceAccumulator::try_new(StatsType::Population)?, + stddev1: StddevAccumulator::try_new(StatsType::Population, null_on_divide_by_zero)?, + stddev2: StddevAccumulator::try_new(StatsType::Population, null_on_divide_by_zero)?, + null_on_divide_by_zero, + }) + } +} + +impl Accumulator for CorrelationAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![ + ScalarValue::from(self.covar.get_count()), + ScalarValue::from(self.covar.get_mean1()), + ScalarValue::from(self.covar.get_mean2()), + ScalarValue::from(self.covar.get_algo_const()), + ScalarValue::from(self.stddev1.get_m2()), + ScalarValue::from(self.stddev2.get_m2()), + ]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = if values[0].null_count() != 0 || values[1].null_count() != 0 { + let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?; + let values1 = filter(&values[0], &mask)?; + let values2 = filter(&values[1], &mask)?; + + vec![values1, values2] + } else { + values.to_vec() + }; + + if !values[0].is_empty() && !values[1].is_empty() { + self.covar.update_batch(&values)?; + self.stddev1.update_batch(&values[0..1])?; + self.stddev2.update_batch(&values[1..2])?; + } + + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = if values[0].null_count() != 0 || values[1].null_count() != 0 { + let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?; + let values1 = filter(&values[0], &mask)?; + let values2 = filter(&values[1], &mask)?; + + vec![values1, values2] + } else { + values.to_vec() + }; + + self.covar.retract_batch(&values)?; + self.stddev1.retract_batch(&values[0..1])?; + self.stddev2.retract_batch(&values[1..2])?; + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let states_c = [ + states[0].clone(), + states[1].clone(), + states[2].clone(), + states[3].clone(), + ]; + let states_s1 = [states[0].clone(), states[1].clone(), states[4].clone()]; + let states_s2 = [states[0].clone(), states[2].clone(), states[5].clone()]; + + if states[0].len() > 0 && states[1].len() > 0 && states[2].len() > 0 { + self.covar.merge_batch(&states_c)?; + self.stddev1.merge_batch(&states_s1)?; + self.stddev2.merge_batch(&states_s2)?; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let covar = self.covar.evaluate()?; + let stddev1 = self.stddev1.evaluate()?; + let stddev2 = self.stddev2.evaluate()?; + + match (covar, stddev1, stddev2) { + ( + ScalarValue::Float64(Some(c)), + ScalarValue::Float64(Some(s1)), + ScalarValue::Float64(Some(s2)), + ) if s1 != 0.0 && s2 != 0.0 => Ok(ScalarValue::Float64(Some(c / (s1 * s2)))), + _ if self.null_on_divide_by_zero => Ok(ScalarValue::Float64(None)), + _ => { + if self.covar.get_count() == 1.0 { + return Ok(ScalarValue::Float64(Some(f64::NAN))); + } + Ok(ScalarValue::Float64(None)) + } + } + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) - std::mem::size_of_val(&self.covar) + self.covar.size() + - std::mem::size_of_val(&self.stddev1) + + self.stddev1.size() + - std::mem::size_of_val(&self.stddev2) + + self.stddev2.size() + } +} diff --git a/core/src/execution/datafusion/expressions/mod.rs b/core/src/execution/datafusion/expressions/mod.rs index 10cac1696..9db4b65b3 100644 --- a/core/src/execution/datafusion/expressions/mod.rs +++ b/core/src/execution/datafusion/expressions/mod.rs @@ -27,6 +27,7 @@ pub use normalize_nan::NormalizeNaNAndZero; pub mod avg; pub mod avg_decimal; pub mod bloom_filter_might_contain; +pub mod correlation; pub mod 
covariance; pub mod stats; pub mod stddev; diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index 59818857e..01d892381 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -67,6 +67,7 @@ use crate::{ bloom_filter_might_contain::BloomFilterMightContain, cast::{Cast, EvalMode}, checkoverflow::CheckOverflow, + correlation::Correlation, covariance::Covariance, if_expr::IfExpr, scalar_funcs::create_comet_physical_fun, @@ -1310,6 +1311,18 @@ impl PhysicalPlanner { ))), } } + AggExprStruct::Correlation(expr) => { + let child1 = self.create_expr(expr.child1.as_ref().unwrap(), schema.clone())?; + let child2 = self.create_expr(expr.child2.as_ref().unwrap(), schema.clone())?; + let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); + Ok(Arc::new(Correlation::new( + child1, + child2, + "correlation", + datatype, + expr.null_on_divide_by_zero, + ))) + } } } diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto index ee3de865a..be85e8a92 100644 --- a/core/src/execution/proto/expr.proto +++ b/core/src/execution/proto/expr.proto @@ -96,6 +96,7 @@ message AggExpr { CovPopulation covPopulation = 13; Variance variance = 14; Stddev stddev = 15; + Correlation correlation = 16; } } @@ -186,6 +187,13 @@ message Stddev { StatisticsType stats_type = 4; } +message Correlation { + Expr child1 = 1; + Expr child2 = 2; + bool null_on_divide_by_zero = 3; + DataType datatype = 4; +} + message Literal { oneof value { bool bool_val = 1; diff --git a/docs/source/user-guide/compatibility.md b/docs/source/user-guide/compatibility.md index a4ed9289f..a16fd1b21 100644 --- a/docs/source/user-guide/compatibility.md +++ b/docs/source/user-guide/compatibility.md @@ -115,6 +115,7 @@ The following cast operations are generally compatible with Spark except for the | string | integer | | | string | long | | | string | binary | | +| string | date | Only supports years between 262143 BC and 262142 AD | | date | string | | | timestamp | long | | | timestamp | decimal | | diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 0204b0c54..eb349b349 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -29,7 +29,6 @@ Comet provides the following configuration settings. | spark.comet.columnar.shuffle.async.enabled | Whether to enable asynchronous shuffle for Arrow-based shuffle. By default, this config is false. | false | | spark.comet.columnar.shuffle.async.max.thread.num | Maximum number of threads on an executor used for Comet async columnar shuffle. By default, this config is 100. This is the upper bound of total number of shuffle threads per executor. In other words, if the number of cores * the number of shuffle threads per task `spark.comet.columnar.shuffle.async.thread.num` is larger than this config. Comet will use this config as the number of shuffle threads per executor instead. | 100 | | spark.comet.columnar.shuffle.async.thread.num | Number of threads used for Comet async columnar shuffle per shuffle task. By default, this config is 3. Note that more threads means more memory requirement to buffer shuffle data before flushing to disk. Also, more threads may not always improve performance, and should be set based on the number of cores available. | 3 | -| spark.comet.columnar.shuffle.enabled | Whether to enable Arrow-based columnar shuffle for Comet and Spark regular operators. 
If this is enabled, Comet prefers columnar shuffle than native shuffle. By default, this config is true. | true | | spark.comet.columnar.shuffle.memory.factor | Fraction of Comet memory to be allocated per executor process for Comet shuffle. Comet memory size is specified by `spark.comet.memoryOverhead` or calculated by `spark.comet.memory.overhead.factor` * `spark.executor.memory`. By default, this config is 1.0. | 1.0 | | spark.comet.debug.enabled | Whether to enable debug mode for Comet. By default, this config is false. When enabled, Comet will do additional checks for debugging purpose. For example, validating array when importing arrays from JVM at native side. Note that these checks may be expensive in performance and should only be enabled for debugging purpose. | false | | spark.comet.enabled | Whether to enable Comet extension for Spark. When this is turned on, Spark will use Comet to read Parquet data source. Note that to enable native vectorized execution, both this config and 'spark.comet.exec.enabled' need to be enabled. By default, this config is the value of the env var `ENABLE_COMET` if set, or true otherwise. | true | @@ -39,6 +38,7 @@ Comet provides the following configuration settings. | spark.comet.exec.memoryFraction | The fraction of memory from Comet memory overhead that the native memory manager can use for execution. The purpose of this config is to set aside memory for untracked data structures, as well as imprecise size estimation during memory acquisition. Default value is 0.7. | 0.7 | | spark.comet.exec.shuffle.codec | The codec of Comet native shuffle used to compress shuffle data. Only zstd is supported. | zstd | | spark.comet.exec.shuffle.enabled | Whether to enable Comet native shuffle. By default, this config is false. Note that this requires setting 'spark.shuffle.manager' to 'org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager'. 'spark.shuffle.manager' must be set before starting the Spark application and cannot be changed during the application. | false | +| spark.comet.exec.shuffle.mode | The mode of Comet shuffle. This config is only effective if Comet shuffle is enabled. Available modes are 'native', 'jvm', and 'auto'. 'native' is for native shuffle which has best performance in general. 'jvm' is for jvm-based columnar shuffle which has higher coverage than native shuffle. 'auto' is for Comet to choose the best shuffle mode based on the query plan. By default, this config is 'jvm'. | jvm | | spark.comet.explainFallback.enabled | When this setting is enabled, Comet will provide logging explaining the reason(s) why a query stage cannot be executed natively. | false | | spark.comet.memory.overhead.factor | Fraction of executor memory to be allocated as additional non-heap memory per executor process for Comet. Default value is 0.2. | 0.2 | | spark.comet.memory.overhead.min | Minimum amount of additional memory to be allocated per executor process for Comet, in MiB. 
| 402653184b | diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 38c86c727..521699d34 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -109,3 +109,4 @@ The following Spark expressions are currently available: - VarianceSamp - StddevPop - StddevSamp + - Corr diff --git a/docs/source/user-guide/tuning.md b/docs/source/user-guide/tuning.md index 01fa7bdbe..5a3100bd0 100644 --- a/docs/source/user-guide/tuning.md +++ b/docs/source/user-guide/tuning.md @@ -39,22 +39,26 @@ It must be set before the Spark context is created. You can enable or disable Comet shuffle at runtime by setting `spark.comet.exec.shuffle.enabled` to `true` or `false`. Once it is disabled, Comet will fallback to the default Spark shuffle manager. -### Columnar Shuffle +### Shuffle Mode -By default, once `spark.comet.exec.shuffle.enabled` is enabled, Comet uses columnar shuffle +Comet provides three shuffle modes: Columnar Shuffle, Native Shuffle and Auto Mode. + +#### Columnar Shuffle + +By default, once `spark.comet.exec.shuffle.enabled` is enabled, Comet uses JVM-based columnar shuffle to improve the performance of shuffle operations. Columnar shuffle supports HashPartitioning, -RoundRobinPartitioning, RangePartitioning and SinglePartitioning. +RoundRobinPartitioning, RangePartitioning and SinglePartitioning. This mode has the highest +query coverage. -Columnar shuffle can be disabled by setting `spark.comet.columnar.shuffle.enabled` to `false`. +Columnar shuffle can be enabled by setting `spark.comet.exec.shuffle.mode` to `jvm`. -### Native Shuffle +#### Native Shuffle Comet also provides a fully native shuffle implementation that can be used to improve the performance. -To enable native shuffle, just disable `spark.comet.columnar.shuffle.enabled`. +To enable native shuffle, just set `spark.comet.exec.shuffle.mode` to `native`. Native shuffle only supports HashPartitioning and SinglePartitioning. +#### Auto Mode - - - +Setting `spark.comet.exec.shuffle.mode` to `auto` will let Comet choose the best shuffle mode based on the query plan.
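For reference, here is a minimal sketch (not part of the patch above) of how the new shuffle settings might be wired together when building a Spark session. It assumes the standard `SparkSession` builder API; the application name and master are placeholders, and only the config keys mentioned in this change are used.

```scala
import org.apache.spark.sql.SparkSession

// Sketch: enabling Comet shuffle and choosing a shuffle mode.
// `spark.shuffle.manager` must be set before the Spark context is created.
val spark = SparkSession
  .builder()
  .appName("comet-shuffle-mode-example") // placeholder
  .master("local[*]")                    // placeholder
  .config(
    "spark.shuffle.manager",
    "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager")
  .config("spark.comet.enabled", "true")
  .config("spark.comet.exec.enabled", "true")
  .config("spark.comet.exec.shuffle.enabled", "true")
  // 'jvm' (default), 'native', or 'auto'
  .config("spark.comet.exec.shuffle.mode", "auto")
  .getOrCreate()
```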
diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 1a2699781..d6ec85f5b 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -45,7 +45,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.comet.CometConf._ -import org.apache.comet.CometSparkSessionExtensions.{createMessage, getCometShuffleNotEnabledReason, isANSIEnabled, isCometBroadCastForceEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported, isSpark34Plus, shouldApplyRowToColumnar, withInfo, withInfos} +import org.apache.comet.CometSparkSessionExtensions.{createMessage, getCometShuffleNotEnabledReason, isANSIEnabled, isCometBroadCastForceEnabled, isCometEnabled, isCometExecEnabled, isCometJVMShuffleMode, isCometNativeShuffleMode, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported, isSpark34Plus, shouldApplyRowToColumnar, withInfo, withInfos} import org.apache.comet.parquet.{CometParquetScan, SupportsComet} import org.apache.comet.serde.OperatorOuterClass.Operator import org.apache.comet.serde.QueryPlanSerde @@ -197,7 +197,7 @@ class CometSparkSessionExtensions private def applyCometShuffle(plan: SparkPlan): SparkPlan = { plan.transformUp { case s: ShuffleExchangeExec - if isCometPlan(s.child) && !isCometColumnarShuffleEnabled(conf) && + if isCometPlan(s.child) && isCometNativeShuffleMode(conf) && QueryPlanSerde.supportPartitioning(s.child.output, s.outputPartitioning)._1 => logInfo("Comet extension enabled for Native Shuffle") @@ -209,8 +209,8 @@ class CometSparkSessionExtensions // Columnar shuffle for regular Spark operators (not Comet) and Comet operators // (if configured) case s: ShuffleExchangeExec - if (!s.child.supportsColumnar || isCometPlan( - s.child)) && isCometColumnarShuffleEnabled(conf) && + if (!s.child.supportsColumnar || isCometPlan(s.child)) && isCometJVMShuffleMode( + conf) && QueryPlanSerde.supportPartitioningTypes(s.child.output)._1 && !isShuffleOperator(s.child) => logInfo("Comet extension enabled for JVM Columnar Shuffle") @@ -641,7 +641,7 @@ class CometSparkSessionExtensions // Native shuffle for Comet operators case s: ShuffleExchangeExec if isCometShuffleEnabled(conf) && - !isCometColumnarShuffleEnabled(conf) && + isCometNativeShuffleMode(conf) && QueryPlanSerde.supportPartitioning(s.child.output, s.outputPartitioning)._1 => logInfo("Comet extension enabled for Native Shuffle") @@ -662,7 +662,7 @@ class CometSparkSessionExtensions // If the child of ShuffleExchangeExec is also a ShuffleExchangeExec, we should not // convert it to CometColumnarShuffle, case s: ShuffleExchangeExec - if isCometShuffleEnabled(conf) && isCometColumnarShuffleEnabled(conf) && + if isCometShuffleEnabled(conf) && isCometJVMShuffleMode(conf) && QueryPlanSerde.supportPartitioningTypes(s.child.output)._1 && !isShuffleOperator(s.child) => logInfo("Comet extension enabled for JVM Columnar Shuffle") @@ -684,19 +684,19 @@ class CometSparkSessionExtensions case s: ShuffleExchangeExec => val isShuffleEnabled = isCometShuffleEnabled(conf) val reason = getCometShuffleNotEnabledReason(conf).getOrElse("no reason available") - val msg1 = createMessage(!isShuffleEnabled, s"Native shuffle is not enabled: $reason") 
- val columnarShuffleEnabled = isCometColumnarShuffleEnabled(conf) + val msg1 = createMessage(!isShuffleEnabled, s"Comet shuffle is not enabled: $reason") + val columnarShuffleEnabled = isCometJVMShuffleMode(conf) val msg2 = createMessage( isShuffleEnabled && !columnarShuffleEnabled && !QueryPlanSerde .supportPartitioning(s.child.output, s.outputPartitioning) ._1, - "Shuffle: " + + "Native shuffle: " + s"${QueryPlanSerde.supportPartitioning(s.child.output, s.outputPartitioning)._2}") val msg3 = createMessage( isShuffleEnabled && columnarShuffleEnabled && !QueryPlanSerde .supportPartitioningTypes(s.child.output) ._1, - s"Columnar shuffle: ${QueryPlanSerde.supportPartitioningTypes(s.child.output)._2}") + s"JVM shuffle: ${QueryPlanSerde.supportPartitioningTypes(s.child.output)._2}") withInfo(s, Seq(msg1, msg2, msg3).flatten.mkString(",")) s @@ -966,8 +966,20 @@ object CometSparkSessionExtensions extends Logging { COMET_EXEC_ENABLED.get(conf) } - private[comet] def isCometColumnarShuffleEnabled(conf: SQLConf): Boolean = { - COMET_COLUMNAR_SHUFFLE_ENABLED.get(conf) + private[comet] def isCometNativeShuffleMode(conf: SQLConf): Boolean = { + COMET_SHUFFLE_MODE.get(conf) match { + case "native" => true + case "auto" => true + case _ => false + } + } + + private[comet] def isCometJVMShuffleMode(conf: SQLConf): Boolean = { + COMET_SHUFFLE_MODE.get(conf) match { + case "jvm" => true + case "auto" => true + case _ => false + } } private[comet] def isCometAllOperatorEnabled(conf: SQLConf): Boolean = { diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala index 795bdb428..11c5a53cc 100644 --- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala +++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala @@ -119,7 +119,7 @@ object CometCast { Unsupported case DataTypes.DateType => // https://github.com/apache/datafusion-comet/issues/327 - Unsupported + Compatible(Some("Only supports years between 262143 BC and 262142 AD")) case DataTypes.TimestampType if timeZoneId.exists(tz => tz != "UTC") => Incompatible(Some(s"Cast will use UTC instead of $timeZoneId")) case DataTypes.TimestampType if evalMode == "ANSI" => diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index c7d0f59a5..878983947 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Corr, Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.optimizer.{BuildRight, NormalizeNaNAndZero} import org.apache.spark.sql.catalyst.plans._ @@ -547,6 +547,27 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim withInfo(aggExpr, child) 
None } + case corr @ Corr(child1, child2, nullOnDivideByZero) => + val child1Expr = exprToProto(child1, inputs, binding) + val child2Expr = exprToProto(child2, inputs, binding) + val dataType = serializeDataType(corr.dataType) + + if (child1Expr.isDefined && child2Expr.isDefined && dataType.isDefined) { + val corrBuilder = ExprOuterClass.Correlation.newBuilder() + corrBuilder.setChild1(child1Expr.get) + corrBuilder.setChild2(child2Expr.get) + corrBuilder.setNullOnDivideByZero(nullOnDivideByZero) + corrBuilder.setDatatype(dataType.get) + + Some( + ExprOuterClass.AggExpr + .newBuilder() + .setCorrelation(corrBuilder) + .build()) + } else { + withInfo(aggExpr, child1, child2) + None + } case fn => val msg = s"unsupported Spark aggregate function: ${fn.prettyName}" emitWarning(msg) diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index 21f974fbf..4d73596a7 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -22,6 +22,7 @@ package org.apache.comet import java.io.File import scala.util.Random +import scala.util.matching.Regex import org.apache.spark.sql.{CometTestBase, DataFrame, SaveMode} import org.apache.spark.sql.catalyst.expressions.Cast @@ -33,6 +34,7 @@ import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType} import org.apache.comet.expressions.{CometCast, Compatible} class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { + import testImplicits._ /** Create a data generator using a fixed seed so that tests are reproducible */ @@ -53,6 +55,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { private val numericPattern = "0123456789deEf+-." + whitespaceChars private val datePattern = "0123456789/" + whitespaceChars + private val timestampPattern = "0123456789/:T" + whitespaceChars test("all valid cast combinations covered") { @@ -567,9 +570,68 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.BinaryType) } - ignore("cast StringType to DateType") { - // https://github.com/apache/datafusion-comet/issues/327 - castTest(gen.generateStrings(dataSize, datePattern, 8).toDF("a"), DataTypes.DateType) + test("cast StringType to DateType") { + // error message for invalid dates in Spark 3.2 not supported by Comet see below issue. 
+ // https://github.com/apache/datafusion-comet/issues/440 + assume(CometSparkSessionExtensions.isSpark33Plus) + val validDates = Seq( + "262142-01-01", + "262142-01-01 ", + "262142-01-01T ", + "262142-01-01T 123123123", + "-262143-12-31", + "-262143-12-31 ", + "-262143-12-31T", + "-262143-12-31T ", + "-262143-12-31T 123123123", + "2020", + "2020-1", + "2020-1-1", + "2020-01", + "2020-01-01", + "2020-1-01 ", + "2020-01-1", + "02020-01-01", + "2020-01-01T", + "2020-10-01T 1221213", + "002020-01-01 ", + "0002020-01-01 123344", + "-3638-5") + val invalidDates = Seq( + "0", + "202", + "3/", + "3/3/", + "3/3/2020", + "3#3#2020", + "2020-010-01", + "2020-10-010", + "2020-10-010T", + "--262143-12-31", + "--262143-12-31T 1234 ", + "abc-def-ghi", + "abc-def-ghi jkl", + "2020-mar-20", + "not_a_date", + "T2", + "\t\n3938\n8", + "8701\t", + "\n8757", + "7593\t\t\t", + "\t9374 \n ", + "\n 9850 \t", + "\r\n\t9840", + "\t9629\n", + "\r\n 9629 \r\n", + "\r\n 962 \r\n", + "\r\n 62 \r\n") + + // due to limitations of NaiveDate we only support years between 262143 BC and 262142 AD" + val unsupportedYearPattern: Regex = "^\\s*[0-9]{5,}".r + val fuzzDates = gen + .generateStrings(dataSize, datePattern, 8) + .filterNot(str => unsupportedYearPattern.findFirstMatchIn(str).isDefined) + castTest((validDates ++ invalidDates ++ fuzzDates).toDF("a"), DataTypes.DateType) } test("cast StringType to TimestampType disabled by default") { diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 98a2bad02..6ca4baf60 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -1421,7 +1421,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { ( s"SELECT sum(c0), sum(c2) from $table group by c1", Set( - "Native shuffle is not enabled: spark.comet.exec.shuffle.enabled is not enabled", + "Comet shuffle is not enabled: spark.comet.exec.shuffle.enabled is not enabled", "AQEShuffleRead is not supported")), ( "SELECT A.c1, A.sum_c0, A.sum_c2, B.casted from " @@ -1429,7 +1429,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { + s"(SELECT c1, cast(make_interval(c0, c1, c0, c1, c0, c0, c2) as string) as casted from $table) as B " + "where A.c1 = B.c1 ", Set( - "Native shuffle is not enabled: spark.comet.exec.shuffle.enabled is not enabled", + "Comet shuffle is not enabled: spark.comet.exec.shuffle.enabled is not enabled", "AQEShuffleRead is not supported", "make_interval is not supported", "BroadcastExchange is not supported", diff --git a/spark/src/test/scala/org/apache/comet/DataGenerator.scala b/spark/src/test/scala/org/apache/comet/DataGenerator.scala index 691a371b5..80e7c2288 100644 --- a/spark/src/test/scala/org/apache/comet/DataGenerator.scala +++ b/spark/src/test/scala/org/apache/comet/DataGenerator.scala @@ -21,14 +21,20 @@ package org.apache.comet import scala.util.Random +import org.apache.spark.sql.{RandomDataGenerator, Row} +import org.apache.spark.sql.types.{StringType, StructType} + object DataGenerator { // note that we use `def` rather than `val` intentionally here so that // each test suite starts with a fresh data generator to help ensure // that tests are deterministic def DEFAULT = new DataGenerator(new Random(42)) + // matches the probability of nulls in Spark's RandomDataGenerator + private val PROBABILITY_OF_NULL: Float = 0.1f } class DataGenerator(r: Random) { + 
import DataGenerator._ /** Generate a random string using the specified characters */ def generateString(chars: String, maxLen: Int): String = { @@ -95,4 +101,39 @@ class DataGenerator(r: Random) { Range(0, n).map(_ => r.nextLong()) } + // Generate a random row according to the schema, the string filed in the struct could be + // configured to generate strings by passing a stringGen function. Other types are delegated + // to Spark's RandomDataGenerator. + def generateRow(schema: StructType, stringGen: Option[() => String] = None): Row = { + val fields = schema.fields.map { f => + f.dataType match { + case StructType(children) => + generateRow(StructType(children), stringGen) + case StringType if stringGen.isDefined => + val gen = stringGen.get + val data = if (f.nullable && r.nextFloat() <= PROBABILITY_OF_NULL) { + null + } else { + gen() + } + data + case _ => + val gen = RandomDataGenerator.forType(f.dataType, f.nullable, r) match { + case Some(g) => g + case None => + throw new IllegalStateException(s"No RandomDataGenerator for type ${f.dataType}") + } + gen() + } + }.toSeq + Row.fromSeq(fields) + } + + def generateRows( + num: Int, + schema: StructType, + stringGen: Option[() => String] = None): Seq[Row] = { + Range(0, num).map(_ => generateRow(schema, stringGen)) + } + } diff --git a/spark/src/test/scala/org/apache/comet/DataGeneratorSuite.scala b/spark/src/test/scala/org/apache/comet/DataGeneratorSuite.scala new file mode 100644 index 000000000..02dfb9d7b --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/DataGeneratorSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet + +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.types.StructType + +class DataGeneratorSuite extends CometTestBase { + + test("test configurable stringGen in row generator") { + val gen = DataGenerator.DEFAULT + val chars = "abcde" + val maxLen = 10 + val stringGen = () => gen.generateString(chars, maxLen) + val numRows = 100 + val schema = new StructType().add("a", "string") + var numNulls = 0 + gen + .generateRows(numRows, schema, Some(stringGen)) + .foreach(row => { + if (row.getString(0) != null) { + assert(row.getString(0).forall(chars.toSeq.contains)) + assert(row.getString(0).length <= maxLen) + } else { + numNulls += 1 + } + }) + // 0.1 null probability + assert(numNulls >= 0.05 * numRows && numNulls <= 0.15 * numRows) + } + +} diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index 310a24ee3..ca7bc7df0 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -44,7 +44,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { withSQLConf( CometConf.COMET_ENABLED.key -> "true", CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", - CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") { + CometConf.COMET_SHUFFLE_MODE.key -> "jvm") { val df1 = sql("SELECT count(DISTINCT 2), count(DISTINCT 2,3)") checkSparkAnswer(df1) @@ -57,7 +57,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { withSQLConf( CometConf.COMET_ENABLED.key -> "true", CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", - CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") { + CometConf.COMET_SHUFFLE_MODE.key -> "jvm") { checkSparkAnswer(sql(""" |SELECT | lag(123, 100, 321) OVER (ORDER BY id) as lag, @@ -78,7 +78,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { withSQLConf( CometConf.COMET_ENABLED.key -> "true", CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", - CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") { + CometConf.COMET_SHUFFLE_MODE.key -> "jvm") { val df1 = Seq( ("a", "b", "c"), ("a", "b", "c"), @@ -99,7 +99,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { withSQLConf( CometConf.COMET_ENABLED.key -> "true", CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", - CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") { + CometConf.COMET_SHUFFLE_MODE.key -> "jvm") { val df = sql("SELECT LAST(n) FROM lowerCaseData") checkSparkAnswer(df) } @@ -114,7 +114,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { withSQLConf( CometConf.COMET_ENABLED.key -> "true", CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", - CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") { + CometConf.COMET_SHUFFLE_MODE.key -> "jvm") { val df = sql("select sum(a), avg(a) from allNulls") checkSparkAnswer(df) } @@ -125,7 +125,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { withSQLConf( CometConf.COMET_ENABLED.key -> "true", CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", - CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") { + CometConf.COMET_SHUFFLE_MODE.key -> "jvm") { withTempDir { dir => val path = new Path(dir.toURI.toString, "test") makeParquetFile(path, 10000, 10, false) @@ -141,7 +141,7 @@ class CometAggregateSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
    withSQLConf(
      CometConf.COMET_ENABLED.key -> "true",
      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
-      CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+      CometConf.COMET_SHUFFLE_MODE.key -> "jvm") {
      Seq(true, false).foreach { dictionaryEnabled =>
        withTempDir { dir =>
          val path = new Path(dir.toURI.toString, "test")
@@ -160,7 +160,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
    withSQLConf(
      CometConf.COMET_ENABLED.key -> "true",
      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
-      CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+      CometConf.COMET_SHUFFLE_MODE.key -> "jvm") {
      sql(
        "CREATE TABLE lineitem(l_extendedprice DOUBLE, l_quantity DOUBLE, l_partkey STRING) USING PARQUET")
@@ -197,7 +197,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
      SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> EliminateSorts.ruleName,
      CometConf.COMET_ENABLED.key -> "true",
      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
-      CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+      CometConf.COMET_SHUFFLE_MODE.key -> "jvm") {
      Seq(true, false).foreach { dictionaryEnabled =>
        withTempDir { dir =>
          val path = new Path(dir.toURI.toString, "test")
@@ -216,7 +216,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
    withSQLConf(
      CometConf.COMET_ENABLED.key -> "true",
      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "false",
-      CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") {
+      CometConf.COMET_SHUFFLE_MODE.key -> "native") {
      withTable(table) {
        sql(s"CREATE TABLE $table(col DECIMAL(5, 2)) USING PARQUET")
        sql(s"INSERT INTO TABLE $table VALUES (CAST(12345.01 AS DECIMAL(5, 2)))")
@@ -316,7 +316,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
    Seq(true, false).foreach { dictionaryEnabled =>
      withSQLConf(
        CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> nativeShuffleEnabled.toString,
-        CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") {
+        CometConf.COMET_SHUFFLE_MODE.key -> "native") {
        withParquetTable(
          (0 until 100).map(i => (i, (i % 10).toString)),
          "tbl",
@@ -497,7 +497,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
    Seq(true, false).foreach { nativeShuffleEnabled =>
      withSQLConf(
        CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> nativeShuffleEnabled.toString,
-        CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") {
+        CometConf.COMET_SHUFFLE_MODE.key -> "native") {
        withTempDir { dir =>
          val path = new Path(dir.toURI.toString, "test")
          makeParquetFile(path, 1000, 20, dictionaryEnabled)
@@ -686,7 +686,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
  test("test final count") {
    withSQLConf(
      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
-      CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") {
+      CometConf.COMET_SHUFFLE_MODE.key -> "native") {
      Seq(false, true).foreach { dictionaryEnabled =>
        withParquetTable((0 until 5).map(i => (i, i % 2)), "tbl", dictionaryEnabled) {
          checkSparkAnswerAndNumOfAggregates("SELECT _2, COUNT(_1) FROM tbl GROUP BY _2", 2)
@@ -703,7 +703,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
  test("test final min/max") {
    withSQLConf(
      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
-      CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") {
+      CometConf.COMET_SHUFFLE_MODE.key -> "native") {
      Seq(true, false).foreach { dictionaryEnabled =>
        withParquetTable((0 until 5).map(i => (i, i % 2)), "tbl", dictionaryEnabled) {
          checkSparkAnswerAndNumOfAggregates(
@@ -724,7 +724,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
  test("test final min/max/count with result expressions") {
    withSQLConf(
      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
-      CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") {
+      CometConf.COMET_SHUFFLE_MODE.key -> "native") {
      Seq(true, false).foreach { dictionaryEnabled =>
        withParquetTable((0 until 5).map(i => (i, i % 2)), "tbl", dictionaryEnabled) {
          checkSparkAnswerAndNumOfAggregates(
@@ -759,7 +759,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
  test("test final sum") {
    withSQLConf(
      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
-      CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") {
+      CometConf.COMET_SHUFFLE_MODE.key -> "native") {
      Seq(false, true).foreach { dictionaryEnabled =>
        withParquetTable((0L until 5L).map(i => (i, i % 2)), "tbl", dictionaryEnabled) {
          checkSparkAnswerAndNumOfAggregates(
@@ -780,7 +780,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
  test("test final avg") {
    withSQLConf(
      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
-      CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") {
+      CometConf.COMET_SHUFFLE_MODE.key -> "native") {
      Seq(true, false).foreach { dictionaryEnabled =>
        withParquetTable(
          (0 until 5).map(i => (i.toDouble, i.toDouble % 2)),
@@ -805,7 +805,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
    withSQLConf(
      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
-      CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") {
+      CometConf.COMET_SHUFFLE_MODE.key -> "native") {
      Seq(true, false).foreach { dictionaryEnabled =>
        withSQLConf("parquet.enable.dictionary" -> dictionaryEnabled.toString) {
          val table = "t1"
@@ -850,7 +850,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
  test("avg null handling") {
    withSQLConf(
      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
-      CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") {
+      CometConf.COMET_SHUFFLE_MODE.key -> "native") {
      val table = "t1"
      withTable(table) {
        sql(s"create table $table(a double, b double) using parquet")
@@ -872,7 +872,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
    Seq(true, false).foreach { nativeShuffleEnabled =>
      withSQLConf(
        CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> nativeShuffleEnabled.toString,
-        CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false",
+        CometConf.COMET_SHUFFLE_MODE.key -> "native",
        CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") {
        withTempDir { dir =>
          val path = new Path(dir.toURI.toString, "test")
@@ -912,11 +912,11 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
  test("distinct") {
    withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
-      Seq(true, false).foreach { cometColumnShuffleEnabled =>
-        withSQLConf(
-          CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> cometColumnShuffleEnabled.toString) {
+      Seq("native", "jvm").foreach { cometShuffleMode =>
+        withSQLConf(CometConf.COMET_SHUFFLE_MODE.key -> cometShuffleMode) {
          Seq(true, false).foreach { dictionary =>
            withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+              val cometColumnShuffleEnabled = cometShuffleMode == "jvm"
              val table = "test"
              withTable(table) {
                sql(s"create table $table(col1 int, col2 int, col3 int) using parquet")
@@ -970,7 +970,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
      SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
      CometConf.COMET_SHUFFLE_ENFORCE_MODE_ENABLED.key -> "true",
-      CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+      CometConf.COMET_SHUFFLE_MODE.key -> "jvm") {
      Seq(true, false).foreach { dictionary =>
        withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
          val table = "test"
@@ -1016,9 +1016,8 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
  test("test bool_and/bool_or") {
    withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
-      Seq(true, false).foreach { cometColumnShuffleEnabled =>
-        withSQLConf(
-          CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> cometColumnShuffleEnabled.toString) {
+      Seq("native", "jvm").foreach { cometShuffleMode =>
+        withSQLConf(CometConf.COMET_SHUFFLE_MODE.key -> cometShuffleMode) {
          Seq(true, false).foreach { dictionary =>
            withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
              val table = "test"
@@ -1043,7 +1042,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
  test("bitwise aggregate") {
    withSQLConf(
      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
-      CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+      CometConf.COMET_SHUFFLE_MODE.key -> "jvm") {
      Seq(true, false).foreach { dictionary =>
        withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
          val table = "test"
@@ -1092,9 +1091,8 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
  test("covar_pop and covar_samp") {
    withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
-      Seq(true, false).foreach { cometColumnShuffleEnabled =>
-        withSQLConf(
-          CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> cometColumnShuffleEnabled.toString) {
+      Seq("native", "jvm").foreach { cometShuffleMode =>
+        withSQLConf(CometConf.COMET_SHUFFLE_MODE.key -> cometShuffleMode) {
          Seq(true, false).foreach { dictionary =>
            withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
              val table = "test"
@@ -1131,9 +1129,8 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
  test("var_pop and var_samp") {
    withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
-      Seq(true, false).foreach { cometColumnShuffleEnabled =>
-        withSQLConf(
-          CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> cometColumnShuffleEnabled.toString) {
+      Seq("native", "jvm").foreach { cometShuffleMode =>
+        withSQLConf(CometConf.COMET_SHUFFLE_MODE.key -> cometShuffleMode) {
          Seq(true, false).foreach { dictionary =>
            withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
              Seq(true, false).foreach { nullOnDivideByZero =>
@@ -1171,9 +1168,8 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
  test("stddev_pop and stddev_samp") {
    withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
-      Seq(true, false).foreach { cometColumnShuffleEnabled =>
-        withSQLConf(
-          CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> cometColumnShuffleEnabled.toString) {
+      Seq("native", "jvm").foreach { cometShuffleMode =>
+        withSQLConf(CometConf.COMET_SHUFFLE_MODE.key -> cometShuffleMode) {
          Seq(true, false).foreach { dictionary =>
            withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
              Seq(true, false).foreach { nullOnDivideByZero =>
@@ -1212,6 +1208,156 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
    }
  }

+  test("correlation") {
+    withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
+      Seq("jvm", "native").foreach { cometShuffleMode =>
+        withSQLConf(CometConf.COMET_SHUFFLE_MODE.key -> cometShuffleMode) {
+          Seq(true, false).foreach { dictionary =>
+            withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+              Seq(true, false).foreach { nullOnDivideByZero =>
+                withSQLConf(
+                  "spark.sql.legacy.statisticalAggregate" -> nullOnDivideByZero.toString) {
+                  val table = "test"
+                  withTable(table) {
+                    sql(
+                      s"create table $table(col1 double, col2 double, col3 double) using parquet")
+                    sql(s"insert into $table values(1, 4, 1), (2, 5, 1), (3, 6, 2)")
+                    val expectedNumOfCometAggregates = 2
+
+                    checkSparkAnswerAndNumOfAggregates(
+                      s"SELECT corr(col1, col2) FROM $table",
+                      expectedNumOfCometAggregates)
+
+                    checkSparkAnswerAndNumOfAggregates(
+                      s"SELECT corr(col1, col2) FROM $table GROUP BY col3",
+                      expectedNumOfCometAggregates)
+                  }
+
+                  withTable(table) {
+                    sql(
+                      s"create table $table(col1 double, col2 double, col3 double) using parquet")
+                    sql(s"insert into $table values(1, 4, 3), (2, -5, 3), (3, 6, 1)")
+                    val expectedNumOfCometAggregates = 2
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      s"SELECT corr(col1, col2) FROM $table",
+                      expectedNumOfCometAggregates)
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      s"SELECT corr(col1, col2) FROM $table GROUP BY col3",
+                      expectedNumOfCometAggregates)
+                  }
+
+                  withTable(table) {
+                    sql(
+                      s"create table $table(col1 double, col2 double, col3 double) using parquet")
+                    sql(s"insert into $table values(1.1, 4.1, 2.3), (2, 5, 1.5), (3, 6, 2.3)")
+                    val expectedNumOfCometAggregates = 2
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      s"SELECT corr(col1, col2) FROM $table",
+                      expectedNumOfCometAggregates)
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      s"SELECT corr(col1, col2) FROM $table GROUP BY col3",
+                      expectedNumOfCometAggregates)
+                  }
+
+                  withTable(table) {
+                    sql(
+                      s"create table $table(col1 double, col2 double, col3 double) using parquet")
+                    sql(s"insert into $table values(1, 4, 1), (2, 5, 2), (3, 6, 3), (1.1, 4.4, 1), (2.2, 5.5, 2), (3.3, 6.6, 3)")
+                    val expectedNumOfCometAggregates = 2
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      s"SELECT corr(col1, col2) FROM $table",
+                      expectedNumOfCometAggregates)
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      s"SELECT corr(col1, col2) FROM $table GROUP BY col3",
+                      expectedNumOfCometAggregates)
+                  }
+
+                  withTable(table) {
+                    sql(s"create table $table(col1 int, col2 int, col3 int) using parquet")
+                    sql(s"insert into $table values(1, 4, 1), (2, 5, 2), (3, 6, 3)")
+                    val expectedNumOfCometAggregates = 2
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      s"SELECT corr(col1, col2) FROM $table",
+                      expectedNumOfCometAggregates)
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      s"SELECT corr(col1, col2) FROM $table GROUP BY col3",
+                      expectedNumOfCometAggregates)
+                  }
+
+                  withTable(table) {
+                    sql(s"create table $table(col1 int, col2 int, col3 int) using parquet")
+                    sql(
+                      s"insert into $table values(1, 4, 2), (null, null, 2), (3, 6, 1), (3, 3, 1)")
+                    val expectedNumOfCometAggregates = 2
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      s"SELECT corr(col1, col2) FROM $table",
+                      expectedNumOfCometAggregates)
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      s"SELECT corr(col1, col2) FROM $table GROUP BY col3",
+                      expectedNumOfCometAggregates)
+                  }
+
+                  withTable(table) {
+                    sql(s"create table $table(col1 int, col2 int, col3 int) using parquet")
+                    sql(s"insert into $table values(1, 4, 1), (null, 5, 1), (2, 5, 2), (9, null, 2), (3, 6, 2)")
+                    val expectedNumOfCometAggregates = 2
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      s"SELECT corr(col1, col2) FROM $table",
+                      expectedNumOfCometAggregates)
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      s"SELECT corr(col1, col2) FROM $table GROUP BY col3",
+                      expectedNumOfCometAggregates)
+                  }
+
+                  withTable(table) {
+                    sql(s"create table $table(col1 int, col2 int, col3 int) using parquet")
+                    sql(s"insert into $table values(null, null, 1), (1, 2, 1), (null, null, 2)")
+                    val expectedNumOfCometAggregates = 2
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      s"SELECT corr(col1, col2) FROM $table",
+                      expectedNumOfCometAggregates)
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      s"SELECT corr(col1, col2) FROM $table GROUP BY col3",
+                      expectedNumOfCometAggregates)
+                  }
+
+                  withTable(table) {
+                    sql(s"create table $table(col1 int, col2 int, col3 int) using parquet")
+                    sql(
+                      s"insert into $table values(null, null, 1), (null, null, 1), (null, null, 2)")
+                    val expectedNumOfCometAggregates = 2
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      s"SELECT corr(col1, col2) FROM $table",
+                      expectedNumOfCometAggregates)
+
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      s"SELECT corr(col1, col2) FROM $table GROUP BY col3",
+                      expectedNumOfCometAggregates)
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
  protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = {
    val df = sql(query)
    checkSparkAnswer(df)
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
index 600f9c44f..c38be7c4a 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
@@ -54,7 +54,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
      CometConf.COMET_COLUMNAR_SHUFFLE_ASYNC_ENABLED.key -> asyncShuffleEnable.toString,
      CometConf.COMET_COLUMNAR_SHUFFLE_SPILL_THRESHOLD.key -> numElementsForceSpillThreshold.toString,
      CometConf.COMET_EXEC_ENABLED.key -> "false",
-      CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true",
+      CometConf.COMET_SHUFFLE_MODE.key -> "jvm",
      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
      CometConf.COMET_COLUMNAR_SHUFFLE_MEMORY_SIZE.key -> "1536m") {
      testFun
@@ -963,7 +963,7 @@ class CometShuffleSuite extends CometColumnarShuffleSuite {
      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
      CometConf.COMET_EXEC_ENABLED.key -> "true",
      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
-      CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+      CometConf.COMET_SHUFFLE_MODE.key -> "jvm") {
      withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
        val df = sql("SELECT * FROM tbl_a")
        val shuffled = df
@@ -983,7 +983,7 @@ class CometShuffleSuite extends CometColumnarShuffleSuite {
      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
      CometConf.COMET_EXEC_ENABLED.key -> "true",
      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
-      CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+      CometConf.COMET_SHUFFLE_MODE.key -> "jvm") {
      withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
        withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") {
          val df = sql("SELECT * FROM tbl_a")
@@ -1016,7 +1016,7 @@ class DisableAQECometShuffleSuite extends CometColumnarShuffleSuite {
      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
      CometConf.COMET_EXEC_ENABLED.key -> "true",
      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
-      CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+      CometConf.COMET_SHUFFLE_MODE.key -> "jvm") {
      withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
        withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") {
          val df = sql("SELECT * FROM tbl_a")
@@ -1061,7 +1061,7 @@ class CometShuffleEncryptionSuite extends CometTestBase {
        withSQLConf(
          CometConf.COMET_EXEC_ENABLED.key -> "false",
          CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
-          CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true",
+          CometConf.COMET_SHUFFLE_MODE.key -> "jvm",
          CometConf.COMET_COLUMNAR_SHUFFLE_ASYNC_ENABLED.key -> asyncEnabled.toString) {
          readParquetFile(path.toString) { df =>
            val shuffled = df
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index d91383adb..a0afe6b0c 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -134,14 +134,15 @@ class CometExecSuite extends CometTestBase {
      .toDF("c1", "c2")
      .createOrReplaceTempView("v")
-    Seq(true, false).foreach { columnarShuffle =>
+    Seq("native", "jvm").foreach { columnarShuffleMode =>
      withSQLConf(
        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
-        CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> columnarShuffle.toString) {
+        CometConf.COMET_SHUFFLE_MODE.key -> columnarShuffleMode) {
        val df = sql("SELECT * FROM v where c1 = 1 order by c1, c2")
        val shuffle = find(df.queryExecution.executedPlan) {
-          case _: CometShuffleExchangeExec if columnarShuffle => true
-          case _: ShuffleExchangeExec if !columnarShuffle => true
+          case _: CometShuffleExchangeExec if columnarShuffleMode.equalsIgnoreCase("jvm") =>
+            true
+          case _: ShuffleExchangeExec if !columnarShuffleMode.equalsIgnoreCase("jvm") => true
          case _ => false
        }.get
        assert(shuffle.logicalLink.isEmpty)
@@ -179,7 +180,7 @@ class CometExecSuite extends CometTestBase {
        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqeEnabled,
        // `REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION` is a new config in Spark 3.3+.
        "spark.sql.requireAllClusterKeysForDistribution" -> "true",
-        CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+        CometConf.COMET_SHUFFLE_MODE.key -> "jvm") {
        val df = Seq(("a", 1, 1), ("a", 2, 2), ("b", 1, 3), ("b", 1, 4)).toDF("key1", "key2", "value")
        val windowSpec = Window.partitionBy("key1", "key2").orderBy("value")
@@ -318,7 +319,7 @@ class CometExecSuite extends CometTestBase {
      dataTypes.map { subqueryType =>
        withSQLConf(
          CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
-          CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true",
+          CometConf.COMET_SHUFFLE_MODE.key -> "jvm",
          CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") {
          withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl") {
            var column1 = s"CAST(max(_1) AS $subqueryType)"
@@ -499,7 +500,7 @@ class CometExecSuite extends CometTestBase {
      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
-      CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+      CometConf.COMET_SHUFFLE_MODE.key -> "jvm") {
      withTable(tableName, dim) {
        sql(
@@ -716,7 +717,7 @@ class CometExecSuite extends CometTestBase {
    withSQLConf(
      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
-      CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+      CometConf.COMET_SHUFFLE_MODE.key -> "jvm") {
      withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl") {
        val df = sql("SELECT * FROM tbl").sort($"_1".desc)
        checkSparkAnswerAndOperator(df)
@@ -764,10 +765,10 @@ class CometExecSuite extends CometTestBase {
  }

  test("limit") {
-    Seq("true", "false").foreach { columnarShuffle =>
+    Seq("native", "jvm").foreach { columnarShuffleMode =>
      withSQLConf(
        CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
-        CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> columnarShuffle) {
+        CometConf.COMET_SHUFFLE_MODE.key -> columnarShuffleMode) {
        withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl_a") {
          val df = sql("SELECT * FROM tbl_a")
            .repartition(10, $"_1")
@@ -1415,7 +1416,7 @@ class CometExecSuite extends CometTestBase {
    Seq("true", "false").foreach(aqe => {
      withSQLConf(
        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqe,
-        CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true",
+        CometConf.COMET_SHUFFLE_MODE.key -> "jvm",
        SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> "false") {
        spark
          .range(1000)
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala
index d48ba1839..d17e4abf4 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala
@@ -37,7 +37,7 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper
    super.test(testName, testTags: _*) {
      withSQLConf(
        CometConf.COMET_EXEC_ENABLED.key -> "true",
-        CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false",
+        CometConf.COMET_SHUFFLE_MODE.key -> "native",
        CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
        testFun
      }
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala b/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala
index 32b60b087..ec87f19e9 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTPCHQuerySuite.scala
@@ -91,7 +91,7 @@ class CometTPCHQuerySuite extends QueryTest with CometTPCBase with ShimCometTPCH
    conf.set(CometConf.COMET_EXEC_ENABLED.key, "true")
    conf.set(CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key, "true")
    conf.set(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key, "true")
-    conf.set(CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key, "true")
+    conf.set(CometConf.COMET_SHUFFLE_MODE.key, "jvm")
    conf.set(MEMORY_OFFHEAP_ENABLED.key, "true")
    conf.set(MEMORY_OFFHEAP_SIZE.key, "2g")
  }
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometExecBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometExecBenchmark.scala
index d6020ac69..bf4bfdbee 100644
--- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometExecBenchmark.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometExecBenchmark.scala
@@ -131,7 +131,7 @@ object CometExecBenchmark extends CometBenchmarkBase {
        CometConf.COMET_EXEC_ENABLED.key -> "true",
        CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true",
        CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
-        CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+        CometConf.COMET_SHUFFLE_MODE.key -> "jvm") {
        spark.sql(
          "SELECT (SELECT max(col1) AS parquetV1Table FROM parquetV1Table) AS a, " +
            "col2, col3 FROM parquetV1Table")
diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometShuffleBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometShuffleBenchmark.scala
index 865572811..30a2823cf 100644
--- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometShuffleBenchmark.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometShuffleBenchmark.scala
@@ -106,7 +106,7 @@ object CometShuffleBenchmark extends CometBenchmarkBase {
        CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
        CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true",
        CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> "1.0",
-        CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true",
+        CometConf.COMET_SHUFFLE_MODE.key -> "jvm",
        CometConf.COMET_COLUMNAR_SHUFFLE_ASYNC_ENABLED.key -> "false") {
        spark
          .sql(
@@ -165,7 +165,7 @@ object CometShuffleBenchmark extends CometBenchmarkBase {
        CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
        CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true",
        CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> "1.0",
-        CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true",
+        CometConf.COMET_SHUFFLE_MODE.key -> "jvm",
        CometConf.COMET_COLUMNAR_SHUFFLE_ASYNC_ENABLED.key -> "false") {
        spark
          .sql(
@@ -222,7 +222,7 @@ object CometShuffleBenchmark extends CometBenchmarkBase {
        CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
        CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true",
        CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> "1.0",
-        CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true",
+        CometConf.COMET_SHUFFLE_MODE.key -> "jvm",
        CometConf.COMET_COLUMNAR_SHUFFLE_ASYNC_ENABLED.key -> "false") {
        spark
          .sql("select c1 from parquetV1Table")
@@ -238,7 +238,7 @@ object CometShuffleBenchmark extends CometBenchmarkBase {
        CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
        CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true",
        CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> "2.0",
-        CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true",
+        CometConf.COMET_SHUFFLE_MODE.key -> "jvm",
        CometConf.COMET_COLUMNAR_SHUFFLE_ASYNC_ENABLED.key -> "false") {
        spark
          .sql("select c1 from parquetV1Table")
@@ -254,7 +254,7 @@ object CometShuffleBenchmark extends CometBenchmarkBase {
        CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
        CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true",
        CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> "1000000000.0",
-        CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true",
+        CometConf.COMET_SHUFFLE_MODE.key -> "jvm",
        CometConf.COMET_COLUMNAR_SHUFFLE_ASYNC_ENABLED.key -> "false") {
        spark
          .sql("select c1 from parquetV1Table")
@@ -321,7 +321,7 @@ object CometShuffleBenchmark extends CometBenchmarkBase {
        CometConf.COMET_EXEC_ENABLED.key -> "true",
        CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
        CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true",
-        CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true",
+        CometConf.COMET_SHUFFLE_MODE.key -> "jvm",
        CometConf.COMET_COLUMNAR_SHUFFLE_ASYNC_ENABLED.key -> "false") {
        spark
          .sql("select c1 from parquetV1Table")
@@ -336,7 +336,7 @@ object CometShuffleBenchmark extends CometBenchmarkBase {
        CometConf.COMET_EXEC_ENABLED.key -> "true",
        CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
        CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true",
-        CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true",
+        CometConf.COMET_SHUFFLE_MODE.key -> "jvm",
        CometConf.COMET_COLUMNAR_SHUFFLE_ASYNC_ENABLED.key -> "true") {
        spark
          .sql("select c1 from parquetV1Table")
@@ -409,7 +409,7 @@ object CometShuffleBenchmark extends CometBenchmarkBase {
        CometConf.COMET_EXEC_ENABLED.key -> "true",
        CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
        CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true",
-        CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+        CometConf.COMET_SHUFFLE_MODE.key -> "jvm") {
        spark
          .sql(s"select $columns from parquetV1Table")
          .repartition(partitionNum, Column("c1"))
@@ -422,7 +422,8 @@ object CometShuffleBenchmark extends CometBenchmarkBase {
        CometConf.COMET_ENABLED.key -> "true",
        CometConf.COMET_EXEC_ENABLED.key -> "true",
        CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key -> "true",
-        CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
+        CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
+        CometConf.COMET_SHUFFLE_MODE.key -> "native") {
        spark
          .sql(s"select $columns from parquetV1Table")
          .repartition(partitionNum, Column("c1"))
diff --git a/spark/src/test/spark-3.4-plus/org/apache/comet/exec/CometExec3_4PlusSuite.scala b/spark/src/test/spark-3.4-plus/org/apache/comet/exec/CometExec3_4PlusSuite.scala
index 75765bd9b..764f7b18d 100644
--- a/spark/src/test/spark-3.4-plus/org/apache/comet/exec/CometExec3_4PlusSuite.scala
+++ b/spark/src/test/spark-3.4-plus/org/apache/comet/exec/CometExec3_4PlusSuite.scala
@@ -42,7 +42,7 @@ class CometExec3_4PlusSuite extends CometTestBase {
  // The syntax is only supported by Spark 3.4+.
  test("subquery limit: limit with offset should return correct results") {
-    withSQLConf(CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+    withSQLConf(CometConf.COMET_SHUFFLE_MODE.key -> "jvm") {
      withTable("t1", "t2") {
        val table1 =
          """create temporary view t1 as select * from values
@@ -94,7 +94,7 @@ class CometExec3_4PlusSuite extends CometTestBase {
  // Dataset.offset API is not available before Spark 3.4
  test("offset") {
-    withSQLConf(CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+    withSQLConf(CometConf.COMET_SHUFFLE_MODE.key -> "jvm") {
      checkSparkAnswer(testData.offset(90))
      checkSparkAnswer(arrayData.toDF().offset(99))
      checkSparkAnswer(mapData.toDF().offset(99))