diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs index f68732fb15..fd1f9166d8 100644 --- a/core/src/execution/datafusion/expressions/cast.rs +++ b/core/src/execution/datafusion/expressions/cast.rs @@ -22,7 +22,6 @@ use std::{ sync::Arc, }; -use crate::errors::{CometError, CometResult}; use arrow::{ compute::{cast_with_options, CastOptions}, datatypes::{ @@ -33,23 +32,27 @@ use arrow::{ util::display::FormatOptions, }; use arrow_array::{ - types::{Int16Type, Int32Type, Int64Type, Int8Type}, + types::{Date32Type, Int16Type, Int32Type, Int64Type, Int8Type}, Array, ArrayRef, BooleanArray, Decimal128Array, Float32Array, Float64Array, GenericStringArray, Int16Array, Int32Array, Int64Array, Int8Array, OffsetSizeTrait, PrimitiveArray, }; use arrow_schema::{DataType, Schema}; -use chrono::{TimeZone, Timelike}; +use chrono::{NaiveDate, NaiveDateTime, TimeZone, Timelike}; use datafusion::logical_expr::ColumnarValue; use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue}; use datafusion_physical_expr::PhysicalExpr; use num::{cast::AsPrimitive, traits::CheckedNeg, CheckedSub, Integer, Num, ToPrimitive}; use regex::Regex; -use crate::execution::datafusion::expressions::utils::{ - array_with_timezone, down_cast_any_ref, spark_cast, +use crate::{ + errors::{CometError, CometResult}, + execution::datafusion::expressions::utils::{ + array_with_timezone, down_cast_any_ref, spark_cast, + }, }; static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f"); + static CAST_OPTIONS: CastOptions = CastOptions { safe: true, format_options: FormatOptions::new() @@ -511,6 +514,31 @@ impl Cast { (DataType::Utf8, DataType::Timestamp(_, _)) => { Self::cast_string_to_timestamp(&array, to_type, self.eval_mode)? } + (DataType::Utf8, DataType::Date32) => { + Self::cast_string_to_date(&array, to_type, self.eval_mode)? 
+ } + (DataType::Dictionary(key_type, value_type), DataType::Date32) + if key_type.as_ref() == &DataType::Int32 + && (value_type.as_ref() == &DataType::Utf8 + || value_type.as_ref() == &DataType::LargeUtf8) => + { + match value_type.as_ref() { + DataType::Utf8 => { + let unpacked_array = + cast_with_options(&array, &DataType::Utf8, &CAST_OPTIONS)?; + Self::cast_string_to_date(&unpacked_array, to_type, self.eval_mode)? + } + DataType::LargeUtf8 => { + let unpacked_array = + cast_with_options(&array, &DataType::LargeUtf8, &CAST_OPTIONS)?; + Self::cast_string_to_date(&unpacked_array, to_type, self.eval_mode)? + } + dt => unreachable!( + "{}", + format!("invalid value type {dt} for dictionary-encoded string array") + ), + } + } (DataType::Int64, DataType::Int32) | (DataType::Int64, DataType::Int16) | (DataType::Int64, DataType::Int8) @@ -635,6 +663,38 @@ impl Cast { Ok(cast_array) } + fn cast_string_to_date( + array: &ArrayRef, + to_type: &DataType, + eval_mode: EvalMode, + ) -> CometResult { + let string_array = array + .as_any() + .downcast_ref::>() + .expect("Expected a string array"); + + let cast_array: ArrayRef = match to_type { + DataType::Date32 => { + let len = string_array.len(); + let mut cast_array = PrimitiveArray::::builder(len); + for i in 0..len { + if !string_array.is_null(i) { + match date_parser(string_array.value(i), eval_mode) { + Ok(Some(cast_value)) => cast_array.append_value(cast_value), + Ok(None) => cast_array.append_null(), + Err(e) => return Err(e), + } + } else { + cast_array.append_null() + } + } + Arc::new(cast_array.finish()) as ArrayRef + } + _ => unreachable!("Invalid data type {:?} in cast from string", to_type), + }; + Ok(cast_array) + } + fn cast_string_to_timestamp( array: &ArrayRef, to_type: &DataType, @@ -858,7 +918,7 @@ impl Cast { i32, "FLOAT", "INT", - std::i32::MAX, + i32::MAX, "{:e}" ), (DataType::Float32, DataType::Int64) => cast_float_to_int32_up!( @@ -870,7 +930,7 @@ impl Cast { i64, "FLOAT", "BIGINT", - std::i64::MAX, + 
i64::MAX, "{:e}" ), (DataType::Float64, DataType::Int8) => cast_float_to_int16_down!( @@ -904,7 +964,7 @@ impl Cast { i32, "DOUBLE", "INT", - std::i32::MAX, + i32::MAX, "{:e}D" ), (DataType::Float64, DataType::Int64) => cast_float_to_int32_up!( @@ -916,7 +976,7 @@ impl Cast { i64, "DOUBLE", "BIGINT", - std::i64::MAX, + i64::MAX, "{:e}D" ), (DataType::Decimal128(precision, scale), DataType::Int8) => { @@ -936,7 +996,7 @@ impl Cast { Int32Array, i32, "INT", - std::i32::MAX, + i32::MAX, *precision, *scale ) @@ -948,7 +1008,7 @@ impl Cast { Int64Array, i64, "BIGINT", - std::i64::MAX, + i64::MAX, *precision, *scale ) @@ -1264,15 +1324,15 @@ fn timestamp_parser(value: &str, eval_mode: EvalMode) -> CometResult } if timestamp.is_none() { - if eval_mode == EvalMode::Ansi { - return Err(CometError::CastInvalidValue { + return if eval_mode == EvalMode::Ansi { + Err(CometError::CastInvalidValue { value: value.to_string(), from_type: "STRING".to_string(), to_type: "TIMESTAMP".to_string(), - }); + }) } else { - return Ok(None); - } + Ok(None) + }; } match timestamp { @@ -1409,13 +1469,136 @@ fn parse_str_to_time_only_timestamp(value: &str) -> CometResult> { Ok(Some(timestamp)) } +//a string to date parser - port of spark's SparkDateTimeUtils#stringToDate. +fn date_parser(date_str: &str, eval_mode: EvalMode) -> CometResult> { + // local functions + fn get_trimmed_start(bytes: &[u8]) -> usize { + let mut start = 0; + while start < bytes.len() && is_whitespace_or_iso_control(bytes[start]) { + start += 1; + } + start + } + + fn get_trimmed_end(start: usize, bytes: &[u8]) -> usize { + let mut end = bytes.len() - 1; + while end > start && is_whitespace_or_iso_control(bytes[end]) { + end -= 1; + } + end + 1 + } + + fn is_whitespace_or_iso_control(byte: u8) -> bool { + byte.is_ascii_whitespace() || byte.is_ascii_control() + } + + fn is_valid_digits(segment: i32, digits: usize) -> bool { + // An integer is able to represent a date within [+-]5 million years. 
+ let max_digits_year = 7; + //year (segment 0) can be between 4 to 7 digits, + //month and day (segment 1 and 2) can be between 1 to 2 digits + (segment == 0 && digits >= 4 && digits <= max_digits_year) + || (segment != 0 && digits > 0 && digits <= 2) + } + + fn return_result(date_str: &str, eval_mode: EvalMode) -> CometResult> { + if eval_mode == EvalMode::Ansi { + Err(CometError::CastInvalidValue { + value: date_str.to_string(), + from_type: "STRING".to_string(), + to_type: "DATE".to_string(), + }) + } else { + Ok(None) + } + } + // end local functions + + if date_str.is_empty() { + return return_result(date_str, eval_mode); + } + + //values of date segments year, month and day defaulting to 1 + let mut date_segments = [1, 1, 1]; + let mut sign = 1; + let mut current_segment = 0; + let mut current_segment_value = 0; + let mut current_segment_digits = 0; + let bytes = date_str.as_bytes(); + + let mut j = get_trimmed_start(bytes); + let str_end_trimmed = get_trimmed_end(j, bytes); + + if j == str_end_trimmed { + return return_result(date_str, eval_mode); + } + + //assign a sign to the date + if bytes[j] == b'-' || bytes[j] == b'+' { + sign = if bytes[j] == b'-' { -1 } else { 1 }; + j += 1; + } + + //loop to the end of string until we have processed 3 segments, + //exit loop on encountering any space ' ' or 'T' after the 3rd segment + while j < str_end_trimmed && (current_segment < 3 && !(bytes[j] == b' ' || bytes[j] == b'T')) { + let b = bytes[j]; + if current_segment < 2 && b == b'-' { + //check for validity of year and month segments if current byte is separator + if !is_valid_digits(current_segment, current_segment_digits) { + return return_result(date_str, eval_mode); + } + //if valid update corresponding segment with the current segment value. 
+ date_segments[current_segment as usize] = current_segment_value; + current_segment_value = 0; + current_segment_digits = 0; + current_segment += 1; + } else if !b.is_ascii_digit() { + return return_result(date_str, eval_mode); + } else { + //increment value of current segment by the next digit + let parsed_value = (b - b'0') as i32; + current_segment_value = current_segment_value * 10 + parsed_value; + current_segment_digits += 1; + } + j += 1; + } + + //check for validity of last segment + if !is_valid_digits(current_segment, current_segment_digits) { + return return_result(date_str, eval_mode); + } + + if current_segment < 2 && j < str_end_trimmed { + // For the `yyyy` and `yyyy-[m]m` formats, entire input must be consumed. + return return_result(date_str, eval_mode); + } + + date_segments[current_segment as usize] = current_segment_value; + + match NaiveDate::from_ymd_opt( + sign * date_segments[0], + date_segments[1] as u32, + date_segments[2] as u32, + ) { + Some(date) => { + let duration_since_epoch = date + .signed_duration_since(NaiveDateTime::UNIX_EPOCH.date()) + .num_days(); + Ok(Some(duration_since_epoch.to_i32().unwrap())) + } + None => Ok(None), + } +} + #[cfg(test)] mod tests { - use super::*; use arrow::datatypes::TimestampMicrosecondType; use arrow_array::StringArray; use arrow_schema::TimeUnit; + use super::*; + #[test] fn timestamp_parser_test() { // write for all formats @@ -1480,6 +1663,168 @@ mod tests { assert_eq!(result.len(), 2); } + #[test] + fn date_parser_test() { + for date in &[ + "2020", + "2020-01", + "2020-01-01", + "02020-01-01", + "002020-01-01", + "0002020-01-01", + "2020-1-1", + "2020-01-01 ", + "2020-01-01T", + ] { + for eval_mode in &[EvalMode::Legacy, EvalMode::Ansi, EvalMode::Try] { + assert_eq!(date_parser(*date, *eval_mode).unwrap(), Some(18262)); + } + } + + //dates in invalid formats + for date in &[ + "abc", + "", + "not_a_date", + "3/", + "3/12", + "3/12/2020", + "3/12/2002 T", + "202", + "2020-010-01", + 
"2020-10-010", + "2020-10-010T", + "--262143-12-31", + "--262143-12-31 ", + ] { + for eval_mode in &[EvalMode::Legacy, EvalMode::Try] { + assert_eq!(date_parser(*date, *eval_mode).unwrap(), None); + } + assert!(date_parser(*date, EvalMode::Ansi).is_err()); + } + + for date in &["-3638-5"] { + for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] { + assert_eq!(date_parser(*date, *eval_mode).unwrap(), Some(-2048160)); + } + } + + //Naive Date only supports years 262142 AD to 262143 BC + //returns None for dates out of range supported by Naive Date. + for date in &[ + "-262144-1-1", + "262143-01-1", + "262143-1-1", + "262143-01-1 ", + "262143-01-01T ", + "262143-1-01T 1234", + "-0973250", + ] { + for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] { + assert_eq!(date_parser(*date, *eval_mode).unwrap(), None); + } + } + } + + #[test] + fn test_cast_string_to_date() { + let array: ArrayRef = Arc::new(StringArray::from(vec![ + Some("2020"), + Some("2020-01"), + Some("2020-01-01"), + Some("2020-01-01T"), + ])); + + let result = + Cast::cast_string_to_date(&array, &DataType::Date32, EvalMode::Legacy).unwrap(); + + let date32_array = result + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(date32_array.len(), 4); + date32_array + .iter() + .for_each(|v| assert_eq!(v.unwrap(), 18262)); + } + + #[test] + fn test_cast_string_array_with_valid_dates() { + let array_with_invalid_date: ArrayRef = Arc::new(StringArray::from(vec![ + Some("-262143-12-31"), + Some("\n -262143-12-31 "), + Some("-262143-12-31T \t\n"), + Some("\n\t-262143-12-31T\r"), + Some("-262143-12-31T 123123123"), + Some("\r\n-262143-12-31T \r123123123"), + Some("\n -262143-12-31T \n\t"), + ])); + + for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] { + let result = + Cast::cast_string_to_date(&array_with_invalid_date, &DataType::Date32, *eval_mode) + .unwrap(); + + let date32_array = result + .as_any() + .downcast_ref::() + .unwrap(); + 
assert_eq!(result.len(), 7); + date32_array + .iter() + .for_each(|v| assert_eq!(v.unwrap(), -96464928)); + } + } + + #[test] + fn test_cast_string_array_with_invalid_dates() { + let array_with_invalid_date: ArrayRef = Arc::new(StringArray::from(vec![ + Some("2020"), + Some("2020-01"), + Some("2020-01-01"), + //4 invalid dates + Some("2020-010-01T"), + Some("202"), + Some(" 202 "), + Some("\n 2020-\r8 "), + Some("2020-01-01T"), + ])); + + for eval_mode in &[EvalMode::Legacy, EvalMode::Try] { + let result = + Cast::cast_string_to_date(&array_with_invalid_date, &DataType::Date32, *eval_mode) + .unwrap(); + + let date32_array = result + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + date32_array.iter().collect::>(), + vec![ + Some(18262), + Some(18262), + Some(18262), + None, + None, + None, + None, + Some(18262) + ] + ); + } + + let result = + Cast::cast_string_to_date(&array_with_invalid_date, &DataType::Date32, EvalMode::Ansi); + match result { + Err(e) => assert!( + e.to_string().contains( + "[CAST_INVALID_INPUT] The value '2020-010-01T' of the type \"STRING\" cannot be cast to \"DATE\" because it is malformed") + ), + _ => panic!("Expected error"), + } + } + #[test] fn test_cast_string_as_i8() { // basic diff --git a/docs/source/user-guide/compatibility.md b/docs/source/user-guide/compatibility.md index a4ed9289f7..a16fd1b21a 100644 --- a/docs/source/user-guide/compatibility.md +++ b/docs/source/user-guide/compatibility.md @@ -115,6 +115,7 @@ The following cast operations are generally compatible with Spark except for the | string | integer | | | string | long | | | string | binary | | +| string | date | Only supports years between 262143 BC and 262142 AD | | date | string | | | timestamp | long | | | timestamp | decimal | | diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala index 795bdb428a..11c5a53cc0 100644 --- 
a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala +++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala @@ -119,7 +119,7 @@ object CometCast { Unsupported case DataTypes.DateType => // https://github.com/apache/datafusion-comet/issues/327 - Unsupported + Compatible(Some("Only supports years between 262143 BC and 262142 AD")) case DataTypes.TimestampType if timeZoneId.exists(tz => tz != "UTC") => Incompatible(Some(s"Cast will use UTC instead of $timeZoneId")) case DataTypes.TimestampType if evalMode == "ANSI" => diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index 8caba14c61..1710090e2a 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -22,6 +22,7 @@ package org.apache.comet import java.io.File import scala.util.Random +import scala.util.matching.Regex import org.apache.spark.sql.{CometTestBase, DataFrame, SaveMode} import org.apache.spark.sql.catalyst.expressions.Cast @@ -33,6 +34,7 @@ import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType} import org.apache.comet.expressions.{CometCast, Compatible} class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { + import testImplicits._ /** Create a data generator using a fixed seed so that tests are reproducible */ @@ -53,6 +55,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { private val numericPattern = "0123456789deEf+-." 
test("cast StringType to DateType") {
  // Comet does not reproduce the error message Spark 3.2 emits for invalid dates, so this
  // test only runs on Spark 3.3 and later. See
  // https://github.com/apache/datafusion-comet/issues/440
  assume(CometSparkSessionExtensions.isSpark33Plus)

  // Hand-picked inputs in the `[+-]yyyy*[-[m]m[-[d]d]]` shapes, including the extreme years
  // Comet supports, optional trailing space / 'T', and ignored trailing text.
  val expectedValid = Seq(
    "262142-01-01",
    "262142-01-01 ",
    "262142-01-01T ",
    "262142-01-01T 123123123",
    "-262143-12-31",
    "-262143-12-31 ",
    "-262143-12-31T",
    "-262143-12-31T ",
    "-262143-12-31T 123123123",
    "2020",
    "2020-1",
    "2020-1-1",
    "2020-01",
    "2020-01-01",
    "2020-1-01 ",
    "2020-01-1",
    "02020-01-01",
    "2020-01-01T",
    "2020-10-01T 1221213",
    "002020-01-01 ",
    "0002020-01-01 123344",
    "-3638-5")

  // Hand-picked edge cases: wrong separators, wrong digit counts, doubled signs, letters,
  // and values wrapped in or split by whitespace/control characters.
  val expectedInvalid = Seq(
    "0",
    "202",
    "3/",
    "3/3/",
    "3/3/2020",
    "3#3#2020",
    "2020-010-01",
    "2020-10-010",
    "2020-10-010T",
    "--262143-12-31",
    "--262143-12-31T 1234 ",
    "abc-def-ghi",
    "abc-def-ghi jkl",
    "2020-mar-20",
    "not_a_date",
    "T2",
    "\t\n3938\n8",
    "8701\t",
    "\n8757",
    "7593\t\t\t",
    "\t9374 \n ",
    "\n 9850 \t",
    "\r\n\t9840",
    "\t9629\n",
    "\r\n 9629 \r\n",
    "\r\n 962 \r\n",
    "\r\n 62 \r\n")

  // Due to limitations of NaiveDate, Comet only supports years between 262143 BC and
  // 262142 AD, so generated strings that start (after optional whitespace) with five or
  // more digits are dropped from the fuzz input.
  val yearOutOfRange: Regex = "^\\s*[0-9]{5,}".r
  val generated = gen
    .generateStrings(dataSize, datePattern, 8)
    .filter(candidate => yearOutOfRange.findFirstIn(candidate).isEmpty)

  val inputs = expectedValid ++ expectedInvalid ++ generated
  castTest(inputs.toDF("a"), DataTypes.DateType)
}