diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs
index a6e3adaca9..2a36a87463 100644
--- a/core/src/execution/datafusion/expressions/cast.rs
+++ b/core/src/execution/datafusion/expressions/cast.rs
@@ -22,7 +22,6 @@ use std::{
     sync::Arc,
 };
 
-use crate::errors::{CometError, CometResult};
 use arrow::{
     compute::{cast_with_options, CastOptions},
     datatypes::TimestampMicrosecondType,
@@ -30,18 +29,20 @@ use arrow::{
     util::display::FormatOptions,
 };
 use arrow_array::{
+    Array,
+    ArrayRef, BooleanArray, Float32Array, Float64Array, GenericStringArray, OffsetSizeTrait, PrimitiveArray,
     types::{Int16Type, Int32Type, Int64Type, Int8Type},
-    Array, ArrayRef, BooleanArray, Float32Array, Float64Array, GenericStringArray, OffsetSizeTrait,
-    PrimitiveArray,
 };
+use arrow_array::types::Date32Type;
 use arrow_schema::{DataType, Schema};
-use chrono::{TimeZone, Timelike};
+use chrono::{Datelike, NaiveDate, Timelike, TimeZone};
 use datafusion::logical_expr::ColumnarValue;
 use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue};
 use datafusion_physical_expr::PhysicalExpr;
-use num::{traits::CheckedNeg, CheckedSub, Integer, Num};
+use num::{CheckedSub, Integer, Num, traits::CheckedNeg};
 use regex::Regex;
 
+use crate::errors::{CometError, CometResult};
 use crate::execution::datafusion::expressions::utils::{
     array_with_timezone, down_cast_any_ref, spark_cast,
 };
@@ -107,7 +108,23 @@ macro_rules! cast_utf8_to_timestamp {
         result
     }};
 }
-
+macro_rules! cast_utf8_to_date {
+    ($array:expr, $eval_mode:expr, $array_type:ty, $date_parser:ident) => {{
+        let len = $array.len();
+        let mut cast_array = PrimitiveArray::<$array_type>::builder(len);
+        for i in 0..len {
+            if $array.is_null(i) {
+                cast_array.append_null()
+            } else if let Ok(Some(cast_value)) = $date_parser($array.value(i).trim(), $eval_mode) {
+                cast_array.append_value(cast_value);
+            } else {
+                cast_array.append_null()
+            }
+        }
+        let result: ArrayRef = Arc::new(cast_array.finish()) as ArrayRef;
+        result
+    }};
+}
 
 macro_rules! cast_float_to_string {
     ($from:expr, $eval_mode:expr, $type:ty, $output_type:ty, $offset_type:ty) => {{
@@ -274,16 +291,20 @@ impl Cast {
             (DataType::Utf8, DataType::Timestamp(_, _)) => {
                 Self::cast_string_to_timestamp(&array, to_type, self.eval_mode)?
             }
+            (DataType::Utf8, DataType::Date32)
+            | (DataType::Utf8, DataType::Date64) => {
+                Self::cast_string_to_date(&array, to_type, self.eval_mode)?
+            }
             (DataType::Int64, DataType::Int32)
             | (DataType::Int64, DataType::Int16)
             | (DataType::Int64, DataType::Int8)
             | (DataType::Int32, DataType::Int16)
             | (DataType::Int32, DataType::Int8)
             | (DataType::Int16, DataType::Int8)
-                if self.eval_mode != EvalMode::Try =>
-            {
-                Self::spark_cast_int_to_int(&array, self.eval_mode, from_type, to_type)?
-            }
+            if self.eval_mode != EvalMode::Try =>
+                {
+                    Self::spark_cast_int_to_int(&array, self.eval_mode, from_type, to_type)?
+                }
             (
                 DataType::Utf8,
                 DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
             )
@@ -297,29 +318,29 @@ impl Cast {
                 DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
             ) if key_type.as_ref() == &DataType::Int32
                 && (value_type.as_ref() == &DataType::Utf8
-                    || value_type.as_ref() == &DataType::LargeUtf8) =>
-            {
-                // TODO: we are unpacking a dictionary-encoded array and then performing
-                // the cast. We could potentially improve performance here by casting the
-                // dictionary values directly without unpacking the array first, although this
-                // would add more complexity to the code
-                match value_type.as_ref() {
-                    DataType::Utf8 => {
-                        let unpacked_array =
-                            cast_with_options(&array, &DataType::Utf8, &CAST_OPTIONS)?;
-                        Self::cast_string_to_int::<i32>(to_type, &unpacked_array, self.eval_mode)?
-                    }
-                    DataType::LargeUtf8 => {
-                        let unpacked_array =
-                            cast_with_options(&array, &DataType::LargeUtf8, &CAST_OPTIONS)?;
-                        Self::cast_string_to_int::<i64>(to_type, &unpacked_array, self.eval_mode)?
-                    }
-                    dt => unreachable!(
-                        "{}",
-                        format!("invalid value type {dt} for dictionary-encoded string array")
-                    ),
-                }
-            }
+                || value_type.as_ref() == &DataType::LargeUtf8) =>
+                {
+                    // TODO: we are unpacking a dictionary-encoded array and then performing
+                    // the cast. We could potentially improve performance here by casting the
+                    // dictionary values directly without unpacking the array first, although this
+                    // would add more complexity to the code
+                    match value_type.as_ref() {
+                        DataType::Utf8 => {
+                            let unpacked_array =
+                                cast_with_options(&array, &DataType::Utf8, &CAST_OPTIONS)?;
+                            Self::cast_string_to_int::<i32>(to_type, &unpacked_array, self.eval_mode)?
+                        }
+                        DataType::LargeUtf8 => {
+                            let unpacked_array =
+                                cast_with_options(&array, &DataType::LargeUtf8, &CAST_OPTIONS)?;
+                            Self::cast_string_to_int::<i64>(to_type, &unpacked_array, self.eval_mode)?
+                        }
+                        dt => unreachable!(
+                            "{}",
+                            format!("invalid value type {dt} for dictionary-encoded string array")
+                        ),
+                    }
+                }
             (DataType::Float64, DataType::Utf8) => {
                 Self::spark_cast_float64_to_utf8::<i32>(&array, self.eval_mode)?
             }
@@ -371,6 +392,30 @@ impl Cast {
         Ok(cast_array)
     }
 
+    fn cast_string_to_date(
+        array: &ArrayRef,
+        to_type: &DataType,
+        eval_mode: EvalMode,
+    ) -> CometResult<ArrayRef> {
+        let string_array = array
+            .as_any()
+            .downcast_ref::<GenericStringArray<i32>>()
+            .expect("Expected a string array");
+
+        let cast_array: ArrayRef = match to_type {
+            DataType::Date32 | DataType::Date64 => {
+                cast_utf8_to_date!(
+                    string_array,
+                    eval_mode,
+                    Date32Type,
+                    date_parser
+                )
+            }
+            _ => unreachable!("Invalid data type {:?} in cast from string", to_type),
+        };
+        Ok(cast_array)
+    }
+
     fn cast_string_to_timestamp(
         array: &ArrayRef,
         to_type: &DataType,
         eval_mode: EvalMode,
     ) -> CometResult<ArrayRef> {
         let string_array = array
@@ -399,8 +444,8 @@
         from: &dyn Array,
         _eval_mode: EvalMode,
     ) -> CometResult<ArrayRef>
-    where
-        OffsetSize: OffsetSizeTrait,
+        where
+            OffsetSize: OffsetSizeTrait,
     {
         cast_float_to_string!(from, _eval_mode, f64, Float64Array, OffsetSize)
     }
@@ -409,8 +454,8 @@
         from: &dyn Array,
         _eval_mode: EvalMode,
     ) -> CometResult<ArrayRef>
-    where
-        OffsetSize: OffsetSizeTrait,
+        where
+            OffsetSize: OffsetSizeTrait,
     {
         cast_float_to_string!(from, _eval_mode, f32, Float32Array, OffsetSize)
     }
@@ -451,8 +496,8 @@
         from: &dyn Array,
         eval_mode: EvalMode,
     ) -> CometResult<ArrayRef>
-    where
-        OffsetSize: OffsetSizeTrait,
+        where
+            OffsetSize: OffsetSizeTrait,
     {
         let array = from
             .as_any()
@@ -489,7 +534,7 @@ fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> CometResult<Option<i8>>
         i8::MIN as i32,
         i8::MAX as i32,
     )?
-    .map(|v| v as i8))
+        .map(|v| v as i8))
 }
 
 /// Equivalent to org.apache.spark.unsafe.types.UTF8String.toShort
@@ -501,7 +546,7 @@ fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> CometResult<Option<i16>>
         i16::MIN as i32,
         i16::MAX as i32,
     )?
-    .map(|v| v as i16))
+        .map(|v| v as i16))
 }
 
 /// Equivalent to org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper)
@@ -809,15 +854,15 @@ fn timestamp_parser(value: &str, eval_mode: EvalMode) -> CometResult<Option<i64>>
     }
 
     if timestamp.is_none() {
-        if eval_mode == EvalMode::Ansi {
-            return Err(CometError::CastInvalidValue {
+        return if eval_mode == EvalMode::Ansi {
+            Err(CometError::CastInvalidValue {
                 value: value.to_string(),
                 from_type: "STRING".to_string(),
                 to_type: "TIMESTAMP".to_string(),
-            });
+            })
         } else {
-            return Ok(None);
-        }
+            Ok(None)
+        };
     }
 
     match timestamp {
@@ -954,13 +999,82 @@ fn parse_str_to_time_only_timestamp(value: &str) -> CometResult<Option<i64>> {
     Ok(Some(timestamp))
 }
+
+fn date_parser(value: &str, eval_mode: EvalMode) -> CometResult<Option<i32>> {
+    let value = value.trim();
+    if value.is_empty() {
+        return Ok(None);
+    }
+
+    // Define regex patterns and corresponding parsing functions
+    let patterns = &[
+        (Regex::new(r"^\d{4}$").unwrap(), parse_year as fn(&str) -> CometResult<Option<i32>>),
+        (Regex::new(r"^\d{4}-\d{2}$").unwrap(), parse_year_month),
+        (Regex::new(r"^\d{4}-\d{2}-\d{2}T?$").unwrap(), parse_year_month_day),
+    ];
+
+    let mut date = None;
+
+    // Iterate through patterns and try matching
+    for (pattern, parse_func) in patterns {
+        if pattern.is_match(value) {
+            date = parse_func(value)?;
+            break;
+        }
+    }
+
+    if date.is_none() && eval_mode == EvalMode::Ansi {
+        return Err(CometError::CastInvalidValue {
+            value: value.to_string(),
+            from_type: "STRING".to_string(),
+            to_type: "DATE".to_string(),
+        });
+    }
+
+    Ok(date)
+}
+
+fn parse_year(value: &str) -> CometResult<Option<i32>> {
+    let year: i32 = value.parse()?;
+    let date = NaiveDate::from_ymd_opt(year, 1, 1);
+    match date {
+        Some(date) => Ok(Some(date.num_days_from_ce())),
+        None => Err(CometError::Internal(
+            "Failed to parse date".to_string(),
+        )),
+    }
+}
+
+fn parse_year_month(value: &str) -> CometResult<Option<i32>> {
+    let date = NaiveDate::parse_from_str(value, "%Y-%m");
+    match date {
+        Ok(date) => Ok(Some(date.num_days_from_ce())),
+        Err(_) => Err(CometError::Internal(
+            "Failed to parse date".to_string(),
+        )),
+    }
+}
+
+fn parse_year_month_day(value: &str) -> CometResult<Option<i32>> {
+    let value = value.trim_end_matches('T');
+    let date = NaiveDate::parse_from_str(value, "%Y-%m-%d");
+    match date {
+        Ok(date) => Ok(Some(date.num_days_from_ce())),
+        Err(_) => Err(CometError::Internal(
+            "Failed to parse date".to_string(),
+        )),
+    }
+}
 
 #[cfg(test)]
 mod tests {
-    use super::*;
+    use arrow::datatypes::Date32Type;
     use arrow::datatypes::TimestampMicrosecondType;
     use arrow_array::StringArray;
     use arrow_schema::TimeUnit;
 
+    use super::*;
+
     #[test]
     fn timestamp_parser_test() {
         // write for all formats
@@ -1025,6 +1139,35 @@
         assert_eq!(result.len(), 2);
     }
 
+    #[test]
+    fn test_cast_string_as_date() {
+        let array: ArrayRef = Arc::new(StringArray::from(vec![
+            Some("2020"),
+            Some("2020-01"),
+            Some("2020-01-01"),
+            Some("2020-01-01T"),
+        ]));
+
+        let string_array = array
+            .as_any()
+            .downcast_ref::<GenericStringArray<i32>>()
+            .expect("Expected a string array");
+
+        let eval_mode = EvalMode::Legacy;
+        let result = cast_utf8_to_date!(
+            &string_array,
+            eval_mode,
+            Date32Type,
+            date_parser
+        );
+
+        assert_eq!(
+            result.data_type(),
+            &DataType::Date32
+        );
+        assert_eq!(result.len(), 4);
+    }
+
     #[test]
     fn test_cast_string_as_i8() {
         // basic
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 483301e02e..2c436e9b1b 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -603,6 +603,15 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("cast StringType to DateType") {
+    // test with hand-picked values
+    castTest(
+      Seq("2020-01-01", "2020-01", "2020-01-01", "2020-01-01T")
+        .toDF("a"),
+      DataTypes.DateType)
+    // fuzz test
+    castTest(generateStrings(datePattern, 10).toDF("a"), DataTypes.DateType)
+  }
   // CAST from BinaryType
 
   ignore("cast BinaryType to StringType") {
@@ -612,42 +621,42 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
   // CAST from DateType
 
-  ignore("cast DateType to BooleanType") {
+  test("cast DateType to BooleanType") {
     // Arrow error: Cast error: Casting from Date32 to Boolean not supported
     castTest(generateDates(), DataTypes.BooleanType)
   }
 
-  ignore("cast DateType to ByteType") {
+  test("cast DateType to ByteType") {
     // Arrow error: Cast error: Casting from Date32 to Int8 not supported
     castTest(generateDates(), DataTypes.ByteType)
   }
 
-  ignore("cast DateType to ShortType") {
+  test("cast DateType to ShortType") {
     // Arrow error: Cast error: Casting from Date32 to Int16 not supported
     castTest(generateDates(), DataTypes.ShortType)
   }
 
-  ignore("cast DateType to IntegerType") {
+  test("cast DateType to IntegerType") {
     // input: 2345-01-01, expected: null, actual: 3789391
     castTest(generateDates(), DataTypes.IntegerType)
   }
 
-  ignore("cast DateType to LongType") {
+  test("cast DateType to LongType") {
     // input: 2024-01-01, expected: null, actual: 19723
     castTest(generateDates(), DataTypes.LongType)
   }
 
-  ignore("cast DateType to FloatType") {
+  test("cast DateType to FloatType") {
     // Arrow error: Cast error: Casting from Date32 to Float32 not supported
     castTest(generateDates(), DataTypes.FloatType)
  }
 
-  ignore("cast DateType to DoubleType") {
+  test("cast DateType to DoubleType") {
     // Arrow error: Cast error: Casting from Date32 to Float64 not supported
     castTest(generateDates(), DataTypes.DoubleType)
   }
 
-  ignore("cast DateType to DecimalType(10,2)") {
+  test("cast DateType to DecimalType(10,2)") {
     // Arrow error: Cast error: Casting from Date32 to Decimal128(10, 2) not supported
     castTest(generateDates(), DataTypes.createDecimalType(10, 2))
   }
@@ -656,7 +665,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateDates(), DataTypes.StringType)
   }
 
-  ignore("cast DateType to TimestampType") {
+  test("cast DateType to TimestampType") {
     // Arrow error: Cast error: Casting from Date32 to Timestamp(Microsecond, Some("UTC")) not supported
     castTest(generateDates(), DataTypes.TimestampType)
   }