From 2bc6c97c2a57ab599c7cc5c9d3af85e026660de2 Mon Sep 17 00:00:00 2001 From: Parth Chandra Date: Fri, 9 Feb 2024 09:56:13 -0800 Subject: [PATCH] feat: date and timestamp trunc with format array (#1438) feat: date and timestamp trunc with format array --- .../datafusion/expressions/temporal.rs | 28 +- core/src/execution/kernels/temporal.rs | 715 +++++++++++++++++- .../apache/comet/CometExpressionSuite.scala | 39 + .../org/apache/spark/sql/CometTestBase.scala | 75 ++ 4 files changed, 849 insertions(+), 8 deletions(-) diff --git a/core/src/execution/datafusion/expressions/temporal.rs b/core/src/execution/datafusion/expressions/temporal.rs index 3654a4ed9..5bdb533d0 100644 --- a/core/src/execution/datafusion/expressions/temporal.rs +++ b/core/src/execution/datafusion/expressions/temporal.rs @@ -33,7 +33,10 @@ use datafusion_physical_expr::PhysicalExpr; use crate::execution::{ datafusion::expressions::utils::{array_with_timezone, down_cast_any_ref}, - kernels::temporal::{date_trunc_dyn, timestamp_trunc_dyn}, + kernels::temporal::{ + date_trunc_array_fmt_dyn, date_trunc_dyn, timestamp_trunc_array_fmt_dyn, + timestamp_trunc_dyn, + }, }; #[derive(Debug, Hash)] @@ -372,9 +375,13 @@ impl PhysicalExpr for DateTruncExec { let result = date_trunc_dyn(&date, format)?; Ok(ColumnarValue::Array(result)) } + (ColumnarValue::Array(date), ColumnarValue::Array(formats)) => { + let result = date_trunc_array_fmt_dyn(&date, &formats)?; + Ok(ColumnarValue::Array(result)) + } _ => Err(DataFusionError::Execution( - "Invalid input to function DateTrunc. Expected (PrimitiveArray, Scalar)" - .to_string(), + "Invalid input to function DateTrunc. Expected (PrimitiveArray, Scalar) or \ + (PrimitiveArray, StringArray)".to_string(), )), } } @@ -486,9 +493,20 @@ impl PhysicalExpr for TimestampTruncExec { let result = timestamp_trunc_dyn(&ts, format)?; Ok(ColumnarValue::Array(result)) } + (ColumnarValue::Array(ts), ColumnarValue::Array(formats)) => { + let ts = array_with_timezone( + ts, + tz.clone(), + Some(&DataType::Timestamp(Microsecond, Some(tz.into()))), + ); + let result = timestamp_trunc_array_fmt_dyn(&ts, &formats)?; + Ok(ColumnarValue::Array(result)) + } _ => Err(DataFusionError::Execution( - "Invalid input to function TimestampTrunc. ".to_owned() - + "Expected (PrimitiveArray, Scalar, String)", + "Invalid input to function TimestampTrunc. \ + Expected (PrimitiveArray, Scalar, String) or \ + (PrimitiveArray, StringArray, String)" + .to_string(), )), } } diff --git a/core/src/execution/kernels/temporal.rs b/core/src/execution/kernels/temporal.rs index ec7f2be7e..8e454fc9e 100644 --- a/core/src/execution/kernels/temporal.rs +++ b/core/src/execution/kernels/temporal.rs @@ -26,7 +26,7 @@ use arrow_array::{ downcast_dictionary_array, downcast_temporal_array, temporal_conversions::*, timezone::Tz, - types::{ArrowTemporalType, Date32Type, TimestampMicrosecondType}, + types::{ArrowDictionaryKeyType, ArrowTemporalType, Date32Type, TimestampMicrosecondType}, ArrowNumericType, }; @@ -76,6 +76,24 @@ where builder.finish() } +#[inline] +fn as_datetime_with_op_single( + value: Option, + builder: &mut PrimitiveBuilder, + op: F, +) where + F: Fn(NaiveDateTime) -> i32, +{ + if let Some(value) = value { + match as_datetime::(i64::from(value)) { + Some(dt) => builder.append_value(op(dt)), + None => builder.append_null(), + } + } else { + builder.append_null(); + } +} + // Based on arrow_arith/temporal.rs:extract_component_from_datetime_array // Transforms an array of DateTime to an arrayOf TimeStampMicrosecond after applying an // operation @@ -106,6 +124,30 @@ where Ok(builder.finish()) } +fn as_timestamp_tz_with_op_single( + value: Option, + builder: &mut PrimitiveBuilder, + tz: &Tz, + op: F, +) -> Result<(), ExpressionError> +where + F: Fn(DateTime) -> i64, + i64: From, +{ + match value { + Some(value) => match as_datetime_with_timezone::(value.into(), *tz) { + Some(time) => builder.append_value(op(time)), + _ => { + return Err(ExpressionError::ArrowError( + "Unable to read value as datetime".to_string(), + )); + } + }, + None => builder.append_null(), + } + Ok(()) +} + #[inline] fn as_days_from_unix_epoch(dt: Option) -> i32 { dt.unwrap().num_days_from_ce() - DAYS_TO_UNIX_EPOCH @@ -207,6 +249,13 @@ fn trunc_date_to_microsec(dt: T) -> Option { Some(dt).and_then(|d| d.with_nanosecond(1_000 * (d.nanosecond() / 1_000))) } +/// +/// Implements the spark [TRUNC](https://spark.apache.org/docs/latest/api/sql/index.html#trunc) +/// function where the specified format is a scalar value +/// +/// array is an array of Date32 values. The array may be a dictionary array. +/// +/// format is a scalar string specifying the format to apply to the timestamp value. pub fn date_trunc_dyn(array: &dyn Array, format: String) -> Result { match array.data_type().clone() { DataType::Dictionary(_, _) => { @@ -274,6 +323,186 @@ where } } +/// +/// Implements the spark [TRUNC](https://spark.apache.org/docs/latest/api/sql/index.html#trunc) +/// function where the specified format may be an array +/// +/// array is an array of Date32 values. The array may be a dictionary array. +/// +/// format is an array of strings specifying the format to apply to the corresponding date value. +/// The array may be a dictionary array. +pub fn date_trunc_array_fmt_dyn( + array: &dyn Array, + formats: &dyn Array, +) -> Result { + match (array.data_type().clone(), formats.data_type().clone()) { + (DataType::Dictionary(_, v), DataType::Dictionary(_, f)) => { + if !matches!(*v, DataType::Date32) { + return_compute_error_with!("date_trunc does not support", v) + } + if !matches!(*f, DataType::Utf8) { + return_compute_error_with!("date_trunc does not support format type ", f) + } + downcast_dictionary_array!( + formats => { + downcast_dictionary_array!( + array => { + date_trunc_array_fmt_dict_dict( + &array.downcast_dict::().unwrap(), + &formats.downcast_dict::().unwrap()) + .map(|a| Arc::new(a) as ArrayRef) + } + dt => return_compute_error_with!("date_trunc does not support", dt) + ) + } + fmt => return_compute_error_with!("date_trunc does not support format type", fmt), + ) + } + (DataType::Dictionary(_, v), DataType::Utf8) => { + if !matches!(*v, DataType::Date32) { + return_compute_error_with!("date_trunc does not support", v) + } + downcast_dictionary_array!( + array => { + date_trunc_array_fmt_dict_plain( + &array.downcast_dict::().unwrap(), + formats.as_any().downcast_ref::() + .expect("Unexpected value type in formats")) + .map(|a| Arc::new(a) as ArrayRef) + } + dt => return_compute_error_with!("date_trunc does not support", dt), + ) + } + (DataType::Date32, DataType::Dictionary(_, f)) => { + if !matches!(*f, DataType::Utf8) { + return_compute_error_with!("date_trunc does not support format type ", f) + } + downcast_dictionary_array!( + formats => { + downcast_temporal_array!(array => { + date_trunc_array_fmt_plain_dict( + array.as_any().downcast_ref::() + .expect("Unexpected error in casting date array"), + &formats.downcast_dict::().unwrap()) + .map(|a| Arc::new(a) as ArrayRef) + } + dt => return_compute_error_with!("date_trunc does not support", dt), + ) + } + fmt => return_compute_error_with!("date_trunc does not support format type", fmt), + ) + } + (DataType::Date32, DataType::Utf8) => date_trunc_array_fmt_plain_plain( + array + .as_any() + .downcast_ref::() + .expect("Unexpected error in casting date array"), + formats + .as_any() + .downcast_ref::() + .expect("Unexpected value type in formats"), + ) + .map(|a| Arc::new(a) as ArrayRef), + (dt, fmt) => Err(ExpressionError::ArrowError(format!( + "Unsupported datatype: {:}, format: {:?} for function 'date_trunc'", + dt, fmt + ))), + } +} + +macro_rules! date_trunc_array_fmt_helper { + ($array: ident, $formats: ident, $datatype: ident) => {{ + let mut builder = Date32Builder::with_capacity($array.len()); + let iter = $array.into_iter(); + match $datatype { + DataType::Date32 => { + for (index, val) in iter.enumerate() { + let op_result = match $formats.value(index).to_uppercase().as_str() { + "YEAR" | "YYYY" | "YY" => { + Ok(as_datetime_with_op_single(val, &mut builder, |dt| { + as_days_from_unix_epoch(trunc_date_to_year(dt)) + })) + } + "QUARTER" => Ok(as_datetime_with_op_single(val, &mut builder, |dt| { + as_days_from_unix_epoch(trunc_date_to_quarter(dt)) + })), + "MONTH" | "MON" | "MM" => { + Ok(as_datetime_with_op_single(val, &mut builder, |dt| { + as_days_from_unix_epoch(trunc_date_to_month(dt)) + })) + } + "WEEK" => Ok(as_datetime_with_op_single(val, &mut builder, |dt| { + as_days_from_unix_epoch(trunc_date_to_week(dt)) + })), + _ => Err(ExpressionError::ArrowError(format!( + "Unsupported format: {:?} for function 'date_trunc'", + $formats.value(index) + ))), + }; + op_result? + } + Ok(builder.finish()) + } + dt => return_compute_error_with!( + "Unsupported input type '{:?}' for function 'date_trunc'", + dt + ), + } + }}; +} + +fn date_trunc_array_fmt_plain_plain( + array: &Date32Array, + formats: &StringArray, +) -> Result +where +{ + let data_type = array.data_type(); + date_trunc_array_fmt_helper!(array, formats, data_type) +} + +fn date_trunc_array_fmt_plain_dict( + array: &Date32Array, + formats: &TypedDictionaryArray, +) -> Result +where + K: ArrowDictionaryKeyType, +{ + let data_type = array.data_type(); + date_trunc_array_fmt_helper!(array, formats, data_type) +} + +fn date_trunc_array_fmt_dict_plain( + array: &TypedDictionaryArray, + formats: &StringArray, +) -> Result +where + K: ArrowDictionaryKeyType, +{ + let data_type = array.values().data_type(); + date_trunc_array_fmt_helper!(array, formats, data_type) +} + +fn date_trunc_array_fmt_dict_dict( + array: &TypedDictionaryArray, + formats: &TypedDictionaryArray, +) -> Result +where + K: ArrowDictionaryKeyType, + F: ArrowDictionaryKeyType, +{ + let data_type = array.values().data_type(); + date_trunc_array_fmt_helper!(array, formats, data_type) +} + +/// +/// Implements the spark [DATE_TRUNC](https://spark.apache.org/docs/latest/api/sql/index.html#date_trunc) +/// function where the specified format is a scalar value +/// +/// array is an array of Timestamp(Microsecond) values. Timestamp values must have a valid +/// timezone or no timezone. The array may be a dictionary array. +/// +/// format is a scalar string specifying the format to apply to the timestamp value. pub fn timestamp_trunc_dyn(array: &dyn Array, format: String) -> Result { match array.data_type().clone() { DataType::Dictionary(_, _) => { @@ -373,10 +602,226 @@ where } } +/// +/// Implements the spark [DATE_TRUNC](https://spark.apache.org/docs/latest/api/sql/index.html#date_trunc) +/// function where the specified format may be an array +/// +/// array is an array of Timestamp(Microsecond) values. Timestamp values must have a valid +/// timezone or no timezone. The array may be a dictionary array. +/// +/// format is an array of strings specifying the format to apply to the corresponding timestamp +/// value. The array may be a dictionary array. +pub fn timestamp_trunc_array_fmt_dyn( + array: &dyn Array, + formats: &dyn Array, +) -> Result { + match (array.data_type().clone(), formats.data_type().clone()) { + (DataType::Dictionary(_, _), DataType::Dictionary(_, _)) => { + downcast_dictionary_array!( + formats => { + downcast_dictionary_array!( + array => { + timestamp_trunc_array_fmt_dict_dict( + &array.downcast_dict::().unwrap(), + &formats.downcast_dict::().unwrap()) + .map(|a| Arc::new(a) as ArrayRef) + } + dt => return_compute_error_with!("timestamp_trunc does not support", dt) + ) + } + fmt => return_compute_error_with!("timestamp_trunc does not support format type", fmt), + ) + } + (DataType::Dictionary(_, _), DataType::Utf8) => { + downcast_dictionary_array!( + array => { + timestamp_trunc_array_fmt_dict_plain( + &array.downcast_dict::>().unwrap(), + formats.as_any().downcast_ref::() + .expect("Unexpected value type in formats")) + .map(|a| Arc::new(a) as ArrayRef) + } + dt => return_compute_error_with!("timestamp_trunc does not support", dt), + ) + } + (DataType::Timestamp(TimeUnit::Microsecond, _), DataType::Dictionary(_, _)) => { + downcast_dictionary_array!( + formats => { + downcast_temporal_array!(array => { + timestamp_trunc_array_fmt_plain_dict( + array, + &formats.downcast_dict::().unwrap()) + .map(|a| Arc::new(a) as ArrayRef) + } + dt => return_compute_error_with!("timestamp_trunc does not support", dt), + ) + } + fmt => return_compute_error_with!("timestamp_trunc does not support format type", fmt), + ) + } + (DataType::Timestamp(TimeUnit::Microsecond, _), DataType::Utf8) => { + downcast_temporal_array!( + array => { + timestamp_trunc_array_fmt_plain_plain(array, + formats.as_any().downcast_ref::().expect("Unexpected value type in formats")) + .map(|a| Arc::new(a) as ArrayRef) + }, + dt => return_compute_error_with!("timestamp_trunc does not support", dt), + ) + } + (dt, fmt) => Err(ExpressionError::ArrowError(format!( + "Unsupported datatype: {:}, format: {:?} for function 'timestamp_trunc'", + dt, fmt + ))), + } +} + +macro_rules! timestamp_trunc_array_fmt_helper { + ($array: ident, $formats: ident, $datatype: ident) => {{ + let mut builder = TimestampMicrosecondBuilder::with_capacity($array.len()); + let iter = $array.into_iter(); + assert_eq!( + $array.len(), + $formats.len(), + "lengths of values array and format array must be the same" + ); + match $datatype { + DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => { + let tz: Tz = tz.parse()?; + for (index, val) in iter.enumerate() { + let op_result = match $formats.value(index).to_uppercase().as_str() { + "YEAR" | "YYYY" | "YY" => { + as_timestamp_tz_with_op_single::(val, &mut builder, &tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_year(dt)) + }) + } + "QUARTER" => { + as_timestamp_tz_with_op_single::(val, &mut builder, &tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_quarter(dt)) + }) + } + "MONTH" | "MON" | "MM" => { + as_timestamp_tz_with_op_single::(val, &mut builder, &tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_month(dt)) + }) + } + "WEEK" => { + as_timestamp_tz_with_op_single::(val, &mut builder, &tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_week(dt)) + }) + } + "DAY" | "DD" => { + as_timestamp_tz_with_op_single::(val, &mut builder, &tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_day(dt)) + }) + } + "HOUR" => { + as_timestamp_tz_with_op_single::(val, &mut builder, &tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_hour(dt)) + }) + } + "MINUTE" => { + as_timestamp_tz_with_op_single::(val, &mut builder, &tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_minute(dt)) + }) + } + "SECOND" => { + as_timestamp_tz_with_op_single::(val, &mut builder, &tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_second(dt)) + }) + } + "MILLISECOND" => { + as_timestamp_tz_with_op_single::(val, &mut builder, &tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_ms(dt)) + }) + } + "MICROSECOND" => { + as_timestamp_tz_with_op_single::(val, &mut builder, &tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_microsec(dt)) + }) + } + _ => Err(ExpressionError::ArrowError(format!( + "Unsupported format: {:?} for function 'timestamp_trunc'", + $formats.value(index) + ))), + }; + op_result? + } + Ok(builder.finish()) + } + dt => { + return_compute_error_with!( + "Unsupported input type '{:?}' for function 'timestamp_trunc'", + dt + ) + } + } + }}; +} + +fn timestamp_trunc_array_fmt_plain_plain( + array: &PrimitiveArray, + formats: &StringArray, +) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + let data_type = array.data_type(); + timestamp_trunc_array_fmt_helper!(array, formats, data_type) +} +fn timestamp_trunc_array_fmt_plain_dict( + array: &PrimitiveArray, + formats: &TypedDictionaryArray, +) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, + K: ArrowDictionaryKeyType, +{ + let data_type = array.data_type(); + timestamp_trunc_array_fmt_helper!(array, formats, data_type) +} + +fn timestamp_trunc_array_fmt_dict_plain( + array: &TypedDictionaryArray>, + formats: &StringArray, +) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, + K: ArrowDictionaryKeyType, +{ + let data_type = array.values().data_type(); + timestamp_trunc_array_fmt_helper!(array, formats, data_type) +} + +fn timestamp_trunc_array_fmt_dict_dict( + array: &TypedDictionaryArray>, + formats: &TypedDictionaryArray, +) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, + K: ArrowDictionaryKeyType, + F: ArrowDictionaryKeyType, +{ + let data_type = array.values().data_type(); + timestamp_trunc_array_fmt_helper!(array, formats, data_type) +} + #[cfg(test)] mod tests { - use crate::execution::kernels::temporal::{date_trunc, timestamp_trunc}; - use arrow_array::{Date32Array, TimestampMicrosecondArray}; + use crate::execution::kernels::temporal::{ + date_trunc, date_trunc_array_fmt_dyn, timestamp_trunc, timestamp_trunc_array_fmt_dyn, + }; + use arrow_array::{ + builder::{PrimitiveDictionaryBuilder, StringDictionaryBuilder}, + iterator::ArrayIter, + types::{Date32Type, Int32Type, TimestampMicrosecondType}, + Array, Date32Array, PrimitiveArray, StringArray, TimestampMicrosecondArray, + }; + use std::sync::Arc; #[test] fn test_date_trunc() { @@ -400,6 +845,122 @@ mod tests { } } + #[test] + // This test only verifies that the various input array types work. Actually correctness to + // ensure this produces the same results as spark is verified in the JVM tests + fn test_date_trunc_array_fmt_dyn() { + let size = 10; + let formats = [ + "YEAR", "YYYY", "YY", "QUARTER", "MONTH", "MON", "MM", "WEEK", + ]; + let mut vec: Vec = Vec::with_capacity(size * formats.len()); + let mut fmt_vec: Vec<&str> = Vec::with_capacity(size * formats.len()); + for i in 0..size { + for j in 0..formats.len() { + vec.push(i as i32 * 1_000_001); + fmt_vec.push(formats[j]); + } + } + + // timestamp array + let array = Date32Array::from(vec); + + // formats array + let fmt_array = StringArray::from(fmt_vec); + + // timestamp dictionary array + let mut date_dict_builder = PrimitiveDictionaryBuilder::::new(); + for v in array.iter() { + date_dict_builder + .append(v.unwrap()) + .expect("Error in building timestamp array"); + } + let mut array_dict = date_dict_builder.finish(); + // apply timezone + array_dict = array_dict.with_values(Arc::new( + array_dict + .values() + .as_any() + .downcast_ref::() + .unwrap() + .clone(), + )); + + // formats dictionary array + let mut formats_dict_builder = StringDictionaryBuilder::::new(); + for v in fmt_array.iter() { + formats_dict_builder + .append(v.unwrap()) + .expect("Error in building formats array"); + } + let fmt_dict = formats_dict_builder.finish(); + + // verify input arrays + let iter = ArrayIter::new(&array); + let mut dict_iter = array_dict + .downcast_dict::>() + .unwrap() + .into_iter(); + for val in iter { + assert_eq!( + dict_iter + .next() + .expect("array and dictionary array do not match"), + val + ) + } + + // verify input format arrays + let fmt_iter = ArrayIter::new(&fmt_array); + let mut fmt_dict_iter = fmt_dict.downcast_dict::().unwrap().into_iter(); + for val in fmt_iter { + assert_eq!( + fmt_dict_iter + .next() + .expect("formats and dictionary formats do not match"), + val + ) + } + + // test cases + if let Ok(a) = date_trunc_array_fmt_dyn(&array, &fmt_array) { + for i in 0..array.len() { + assert!( + array.value(i) >= a.as_any().downcast_ref::().unwrap().value(i) + ) + } + } else { + assert!(false) + } + if let Ok(a) = date_trunc_array_fmt_dyn(&array_dict, &fmt_array) { + for i in 0..array.len() { + assert!( + array.value(i) >= a.as_any().downcast_ref::().unwrap().value(i) + ) + } + } else { + assert!(false) + } + if let Ok(a) = date_trunc_array_fmt_dyn(&array, &fmt_dict) { + for i in 0..array.len() { + assert!( + array.value(i) >= a.as_any().downcast_ref::().unwrap().value(i) + ) + } + } else { + assert!(false) + } + if let Ok(a) = date_trunc_array_fmt_dyn(&array_dict, &fmt_dict) { + for i in 0..array.len() { + assert!( + array.value(i) >= a.as_any().downcast_ref::().unwrap().value(i) + ) + } + } else { + assert!(false) + } + } + #[test] fn test_timestamp_trunc() { let size = 1000; @@ -435,4 +996,152 @@ mod tests { } } } + + #[test] + // This test only verifies that the various input array types work. Actually correctness to + // ensure this produces the same results as spark is verified in the JVM tests + fn test_timestamp_trunc_array_fmt_dyn() { + let size = 10; + let formats = [ + "YEAR", + "YYYY", + "YY", + "QUARTER", + "MONTH", + "MON", + "MM", + "WEEK", + "DAY", + "DD", + "HOUR", + "MINUTE", + "SECOND", + "MILLISECOND", + "MICROSECOND", + ]; + let mut vec: Vec = Vec::with_capacity(size * formats.len()); + let mut fmt_vec: Vec<&str> = Vec::with_capacity(size * formats.len()); + for i in 0..size { + for j in 0..formats.len() { + vec.push(i as i64 * 1_000_000_001); + fmt_vec.push(formats[j]); + } + } + + // timestamp array + let array = TimestampMicrosecondArray::from(vec).with_timezone_utc(); + + // formats array + let fmt_array = StringArray::from(fmt_vec); + + // timestamp dictionary array + let mut timestamp_dict_builder = + PrimitiveDictionaryBuilder::::new(); + for v in array.iter() { + timestamp_dict_builder + .append(v.unwrap()) + .expect("Error in building timestamp array"); + } + let mut array_dict = timestamp_dict_builder.finish(); + // apply timezone + array_dict = array_dict.with_values(Arc::new( + array_dict + .values() + .as_any() + .downcast_ref::() + .unwrap() + .clone() + .with_timezone_utc(), + )); + + // formats dictionary array + let mut formats_dict_builder = StringDictionaryBuilder::::new(); + for v in fmt_array.iter() { + formats_dict_builder + .append(v.unwrap()) + .expect("Error in building formats array"); + } + let fmt_dict = formats_dict_builder.finish(); + + // verify input arrays + let iter = ArrayIter::new(&array); + let mut dict_iter = array_dict + .downcast_dict::>() + .unwrap() + .into_iter(); + for val in iter { + assert_eq!( + dict_iter + .next() + .expect("array and dictionary array do not match"), + val + ) + } + + // verify input format arrays + let fmt_iter = ArrayIter::new(&fmt_array); + let mut fmt_dict_iter = fmt_dict.downcast_dict::().unwrap().into_iter(); + for val in fmt_iter { + assert_eq!( + fmt_dict_iter + .next() + .expect("formats and dictionary formats do not match"), + val + ) + } + + // test cases + if let Ok(a) = timestamp_trunc_array_fmt_dyn(&array, &fmt_array) { + for i in 0..array.len() { + assert!( + array.value(i) + >= a.as_any() + .downcast_ref::() + .unwrap() + .value(i) + ) + } + } else { + assert!(false) + } + if let Ok(a) = timestamp_trunc_array_fmt_dyn(&array_dict, &fmt_array) { + for i in 0..array.len() { + assert!( + array.value(i) + >= a.as_any() + .downcast_ref::() + .unwrap() + .value(i) + ) + } + } else { + assert!(false) + } + if let Ok(a) = timestamp_trunc_array_fmt_dyn(&array, &fmt_dict) { + for i in 0..array.len() { + assert!( + array.value(i) + >= a.as_any() + .downcast_ref::() + .unwrap() + .value(i) + ) + } + } else { + assert!(false) + } + if let Ok(a) = timestamp_trunc_array_fmt_dyn(&array_dict, &fmt_dict) { + for i in 0..array.len() { + assert!( + array.value(i) + >= a.as_any() + .downcast_ref::() + .unwrap() + .value(i) + ) + } + } else { + assert!(false) + } + } } diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index df8bc7c7d..66ee2752e 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -286,6 +286,23 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("trunc with format array") { + val numRows = 1000 + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "date_trunc_with_format.parquet") + makeDateTimeWithFormatTable(path, dictionaryEnabled = dictionaryEnabled, numRows) + withParquetTable(path.toString, "dateformattbl") { + checkSparkAnswerAndOperator( + "SELECT " + + "dateformat, _7, " + + "trunc(_7, dateformat) " + + " from dateformattbl ") + } + } + } + } + test("date_trunc") { Seq(true, false).foreach { dictionaryEnabled => withTempDir { dir => @@ -355,6 +372,28 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("date_trunc with format array") { + val numRows = 1000 + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "timestamp_trunc_with_format.parquet") + makeDateTimeWithFormatTable(path, dictionaryEnabled = dictionaryEnabled, numRows) + withParquetTable(path.toString, "timeformattbl") { + checkSparkAnswerAndOperator( + "SELECT " + + "format, _0, _1, _2, _3, _4, _5, " + + "date_trunc(format, _0), " + + "date_trunc(format, _1), " + + "date_trunc(format, _2), " + + "date_trunc(format, _3), " + + "date_trunc(format, _4), " + + "date_trunc(format, _5) " + + " from timeformattbl ") + } + } + } + } + test("date_trunc on int96 timestamp column") { import testImplicits._ diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index 4f2838cfb..8e3e1d947 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -586,6 +586,81 @@ abstract class CometTestBase expected } + protected def makeDateTimeWithFormatTable( + path: Path, + dictionaryEnabled: Boolean, + n: Int, + rowGroupSize: Long = 1024 * 1024L): Seq[Option[Long]] = { + val schemaStr = + """ + |message root { + | optional int64 _0(TIMESTAMP_MILLIS); + | optional int64 _1(TIMESTAMP_MICROS); + | optional int64 _2(TIMESTAMP(MILLIS,true)); + | optional int64 _3(TIMESTAMP(MILLIS,false)); + | optional int64 _4(TIMESTAMP(MICROS,true)); + | optional int64 _5(TIMESTAMP(MICROS,false)); + | optional int64 _6(INT_64); + | optional int32 _7(DATE); + | optional binary format(UTF8); + | optional binary dateFormat(UTF8); + | } + """.stripMargin + + val schema = MessageTypeParser.parseMessageType(schemaStr) + val writer = createParquetWriter( + schema, + path, + dictionaryEnabled = dictionaryEnabled, + rowGroupSize = rowGroupSize) + val div = if (dictionaryEnabled) 10 else n // maps value to a small range for dict to kick in + + val expected = (0 until n).map { i => + Some(getValue(i, div)) + } + expected.foreach { opt => + val timestampFormats = List( + "YEAR", + "YYYY", + "YY", + "MON", + "MONTH", + "MM", + "QUARTER", + "WEEK", + "DAY", + "DD", + "HOUR", + "MINUTE", + "SECOND", + "MILLISECOND", + "MICROSECOND") + val dateFormats = List("YEAR", "YYYY", "YY", "MON", "MONTH", "MM", "QUARTER", "WEEK") + val formats = timestampFormats.zipAll(dateFormats, "NONE", "YEAR") + + formats.foreach { format => + val record = new SimpleGroup(schema) + opt match { + case Some(i) => + record.add(0, i) + record.add(1, i * 1000) // convert millis to micros, same below + record.add(2, i) + record.add(3, i) + record.add(4, i * 1000) + record.add(5, i * 1000) + record.add(6, i * 1000) + record.add(7, i.toInt) + record.add(8, format._1) + record.add(9, format._2) + case _ => + } + writer.write(record) + } + } + writer.close() + expected + } + def makeDecimalRDD(num: Int, decimal: DecimalType, useDictionary: Boolean): DataFrame = { val div = if (useDictionary) 5 else num // narrow the space to make it dictionary encoded spark