feat: Implement Spark-compatible CAST from string to timestamp types #335

Merged
merged 30 commits on May 2, 2024
Changes from 22 commits
Commits
7619757
casting str to timestamp
vaibhawvipul Apr 27, 2024
0a85cde
fix format
vaibhawvipul Apr 27, 2024
fe1896c
fixing failed tests, using char as pattern
vaibhawvipul Apr 28, 2024
70f1e69
bug fixes
vaibhawvipul Apr 28, 2024
0217f13
hangling microsecond
vaibhawvipul Apr 28, 2024
36d5cc5
make format
vaibhawvipul Apr 28, 2024
8d0a0d9
bug fixes and core refactor
vaibhawvipul Apr 29, 2024
87b5e66
format code
vaibhawvipul Apr 29, 2024
b9966b7
resolving merge conflicts
vaibhawvipul Apr 29, 2024
30c442d
removing print statements
vaibhawvipul Apr 29, 2024
800e085
clippy error
vaibhawvipul Apr 29, 2024
de0cfc2
enabling cast timestamp test case
vaibhawvipul Apr 29, 2024
fe18d81
code refactor
vaibhawvipul Apr 30, 2024
938b0b3
comet spark test case
vaibhawvipul Apr 30, 2024
cc60cfa
adding all the supported format in test
vaibhawvipul Apr 30, 2024
d250817
merge conflict resolved
vaibhawvipul May 1, 2024
889f91b
resolve conflicts
vaibhawvipul May 1, 2024
9dad369
fallback spark when timestamp not utc
vaibhawvipul May 1, 2024
825fe5e
bug fix
vaibhawvipul May 1, 2024
5041192
bug fix
vaibhawvipul May 1, 2024
9a8dc1d
adding an explainer commit
vaibhawvipul May 1, 2024
2980176
fix test case
vaibhawvipul May 1, 2024
2ffea83
bug fix
vaibhawvipul May 2, 2024
6db8115
bug fix
vaibhawvipul May 2, 2024
a67ccf7
better error handling for unwrap in fn parse_str_to_time_only_timestamp
vaibhawvipul May 2, 2024
b7a3961
remove unwrap from macro
vaibhawvipul May 2, 2024
2f3ab08
improving error handling
vaibhawvipul May 2, 2024
7a39136
adding tests for invalid inputs
vaibhawvipul May 2, 2024
4743742
removed all unwraps from timestamp cast functions
vaibhawvipul May 2, 2024
8c4ad72
code format
vaibhawvipul May 2, 2024
289 changes: 287 additions & 2 deletions core/src/execution/datafusion/expressions/cast.rs
@@ -25,6 +25,7 @@ use std::{
use crate::errors::{CometError, CometResult};
use arrow::{
compute::{cast_with_options, CastOptions},
datatypes::TimestampMicrosecondType,
record_batch::RecordBatch,
util::display::FormatOptions,
};
@@ -33,10 +34,12 @@ use arrow_array::{
Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray,
};
use arrow_schema::{DataType, Schema};
use chrono::{TimeZone, Timelike};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue};
use datafusion_physical_expr::PhysicalExpr;
use num::{traits::CheckedNeg, CheckedSub, Integer, Num};
use regex::Regex;

use crate::execution::datafusion::expressions::utils::{
array_with_timezone, down_cast_any_ref, spark_cast,
@@ -86,6 +89,24 @@ macro_rules! cast_utf8_to_int {
}};
}

macro_rules! cast_utf8_to_timestamp {
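// Builds a timestamp array with UTC timezone from a string array; entries that
// fail to parse (or that the parser maps to None) become nulls.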
($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{
let len = $array.len();
let mut cast_array = PrimitiveArray::<$array_type>::builder(len).with_timezone("UTC");
for i in 0..len {
if $array.is_null(i) {
cast_array.append_null()
} else if let Ok(Some(cast_value)) = $cast_method($array.value(i).trim(), $eval_mode) {
cast_array.append_value(cast_value);
} else {
cast_array.append_null()
}
}
Arc::new(cast_array.finish()) as ArrayRef
}};
}

impl Cast {
pub fn new(
child: Arc<dyn PhysicalExpr>,
@@ -125,6 +146,9 @@ impl Cast {
(DataType::LargeUtf8, DataType::Boolean) => {
Self::spark_cast_utf8_to_boolean::<i64>(&array, self.eval_mode)?
}
(DataType::Utf8, DataType::Timestamp(_, _)) => {
Self::cast_string_to_timestamp(&array, to_type, self.eval_mode)?
}
(
DataType::Utf8,
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
@@ -200,6 +224,30 @@ impl Cast {
Ok(cast_array)
}

fn cast_string_to_timestamp(
array: &ArrayRef,
to_type: &DataType,
eval_mode: EvalMode,
) -> CometResult<ArrayRef> {
let string_array = array
.as_any()
.downcast_ref::<GenericStringArray<i32>>()
.expect("Expected a string array");

let cast_array: ArrayRef = match to_type {
DataType::Timestamp(_, _) => {
cast_utf8_to_timestamp!(
string_array,
eval_mode,
TimestampMicrosecondType,
timestamp_parser
)
}
_ => unreachable!("Invalid data type {:?} in cast from string", to_type),
};
Ok(cast_array)
}

fn spark_cast_utf8_to_boolean<OffsetSize>(
from: &dyn Array,
eval_mode: EvalMode,
@@ -510,9 +558,246 @@ impl PhysicalExpr for Cast {
}
}

fn timestamp_parser(value: &str, eval_mode: EvalMode) -> CometResult<Option<i64>> {
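// Try each supported pattern in turn; empty input yields None, while input that
// matches no pattern yields None in Legacy mode or CastInvalidValue in Ansi mode.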
let value = value.trim();
if value.is_empty() {
return Ok(None);
}

// Define regex patterns and corresponding parsing functions
Reviewer comment (Member):
The regex approach is a good way to quickly get support for all of the format variations and fix the correctness issue but could also be quite expensive. It would be good to add some criterion benchmarks so that we can understand what performance looks like (does not have to be part of this PR).
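
A minimal criterion sketch along the lines suggested in that comment — an editor's illustration, not part of this PR. It assumes criterion is added as a dev-dependency with a [[bench]] target (here called cast_timestamp), that timestamp_parser and EvalMode are made reachable from the bench (both are private to cast.rs in this diff), and that the comet::... import path is correct; all of these are assumptions.

// benches/cast_timestamp.rs -- hypothetical bench target, not part of this PR
use criterion::{black_box, criterion_group, criterion_main, Criterion};
// Assumed path; in this diff timestamp_parser and EvalMode are not public.
use comet::execution::datafusion::expressions::cast::{timestamp_parser, EvalMode};

fn bench_timestamp_parser(c: &mut Criterion) {
    // One input per supported pattern, mirroring timestamp_parser_test below.
    let inputs = [
        "2020",
        "2020-01",
        "2020-01-01",
        "2020-01-01T12",
        "2020-01-01T12:34",
        "2020-01-01T12:34:56",
        "2020-01-01T12:34:56.123456",
        "T2",
    ];
    c.bench_function("timestamp_parser (all formats)", |b| {
        b.iter(|| {
            for s in inputs.iter() {
                // black_box keeps the call from being optimized away
                let _ = timestamp_parser(black_box(*s), EvalMode::Legacy);
            }
        })
    });
}

criterion_group!(benches, bench_timestamp_parser);
criterion_main!(benches);

Because the regexes are compiled inside timestamp_parser on every call, a benchmark like this would also surface the compilation overhead the comment is concerned about.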

let patterns = &[
(
Regex::new(r"^\d{4}$").unwrap(),
parse_str_to_year_timestamp as fn(&str) -> CometResult<Option<i64>>,
),
(
Regex::new(r"^\d{4}-\d{2}$").unwrap(),
parse_str_to_month_timestamp,
),
(
Regex::new(r"^\d{4}-\d{2}-\d{2}$").unwrap(),
parse_str_to_day_timestamp,
),
(
Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{1,2}$").unwrap(),
parse_str_to_hour_timestamp,
),
(
Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}$").unwrap(),
parse_str_to_minute_timestamp,
),
(
Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$").unwrap(),
parse_str_to_second_timestamp,
),
(
Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6}$").unwrap(),
parse_str_to_microsecond_timestamp,
),
(
Regex::new(r"^T\d{1,2}$").unwrap(),
parse_str_to_time_only_timestamp,
),
];

let mut timestamp = None;

// Iterate through patterns and try matching
for (pattern, parse_func) in patterns {
if pattern.is_match(value) {
timestamp = parse_func(value)?;
break;
}
}

if timestamp.is_none() {
if eval_mode == EvalMode::Ansi {
return Err(CometError::CastInvalidValue {
value: value.to_string(),
from_type: "STRING".to_string(),
to_type: "TIMESTAMP".to_string(),
});
} else {
return Ok(None);
}
}
Ok(Some(timestamp.unwrap()))
}

fn parse_ymd_timestamp(year: i32, month: u32, day: u32) -> CometResult<Option<i64>> {
let datetime = chrono::Utc
.with_ymd_and_hms(year, month, day, 0, 0, 0)
.unwrap()
.with_timezone(&chrono::Utc);
Ok(Some(datetime.timestamp_micros()))
}

fn parse_hms_timestamp(
year: i32,
month: u32,
day: u32,
hour: u32,
minute: u32,
second: u32,
microsecond: u32,
) -> CometResult<Option<i64>> {
let datetime = chrono::Utc
.with_ymd_and_hms(year, month, day, hour, minute, second)
.unwrap()
.with_timezone(&chrono::Utc)
.with_nanosecond(microsecond * 1000);
Ok(Some(datetime.unwrap().timestamp_micros()))
}

fn get_timestamp_values(value: &str, timestamp_type: &str) -> CometResult<Option<i64>> {
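// Split on 'T', '-', ':' and '.' and default any missing fields
// (month/day default to 1, time fields to 0).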
let values: Vec<_> = value
.split(|c| c == 'T' || c == '-' || c == ':' || c == '.')
.collect();
let year = values[0].parse::<i32>().unwrap_or_default();
let month = values.get(1).map_or(1, |m| m.parse::<u32>().unwrap_or(1));
let day = values.get(2).map_or(1, |d| d.parse::<u32>().unwrap_or(1));
let hour = values.get(3).map_or(0, |h| h.parse::<u32>().unwrap_or(0));
let minute = values.get(4).map_or(0, |m| m.parse::<u32>().unwrap_or(0));
let second = values.get(5).map_or(0, |s| s.parse::<u32>().unwrap_or(0));
let microsecond = values.get(6).map_or(0, |ms| ms.parse::<u32>().unwrap_or(0));

match timestamp_type {
"year" => parse_ymd_timestamp(year, 1, 1),
"month" => parse_ymd_timestamp(year, month, 1),
"day" => parse_ymd_timestamp(year, month, day),
"hour" => parse_hms_timestamp(year, month, day, hour, 0, 0, 0),
"minute" => parse_hms_timestamp(year, month, day, hour, minute, 0, 0),
"second" => parse_hms_timestamp(year, month, day, hour, minute, second, 0),
"microsecond" => parse_hms_timestamp(year, month, day, hour, minute, second, microsecond),
_ => Err(CometError::CastInvalidValue {
value: value.to_string(),
from_type: "STRING".to_string(),
to_type: "TIMESTAMP".to_string(),
}),
}
}

fn parse_str_to_year_timestamp(value: &str) -> CometResult<Option<i64>> {
get_timestamp_values(value, "year")
}

fn parse_str_to_month_timestamp(value: &str) -> CometResult<Option<i64>> {
get_timestamp_values(value, "month")
}

fn parse_str_to_day_timestamp(value: &str) -> CometResult<Option<i64>> {
get_timestamp_values(value, "day")
}

fn parse_str_to_hour_timestamp(value: &str) -> CometResult<Option<i64>> {
get_timestamp_values(value, "hour")
}

fn parse_str_to_minute_timestamp(value: &str) -> CometResult<Option<i64>> {
get_timestamp_values(value, "minute")
}

fn parse_str_to_second_timestamp(value: &str) -> CometResult<Option<i64>> {
get_timestamp_values(value, "second")
}

fn parse_str_to_microsecond_timestamp(value: &str) -> CometResult<Option<i64>> {
get_timestamp_values(value, "microsecond")
}

fn parse_str_to_time_only_timestamp(value: &str) -> CometResult<Option<i64>> {
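// Time-only input such as "T2" is anchored to the current UTC date.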
let values: Vec<&str> = value.split('T').collect();
let time_values: Vec<u32> = values[1]
.split(':')
.map(|v| v.parse::<u32>().unwrap_or(0))
.collect();

let datetime = chrono::Utc::now();
let timestamp = datetime
.with_hour(time_values.first().copied().unwrap_or_default())
.unwrap()
.with_minute(*time_values.get(1).unwrap_or(&0))
.unwrap()
.with_second(*time_values.get(2).unwrap_or(&0))
.unwrap()
.with_nanosecond(*time_values.get(3).unwrap_or(&0) * 1_000)
.unwrap()
.to_utc()
.timestamp_micros();

Ok(Some(timestamp))
}

#[cfg(test)]
mod test {
use super::{cast_string_to_i8, EvalMode};
mod tests {
use super::*;
use arrow::datatypes::TimestampMicrosecondType;
use arrow_array::StringArray;
use arrow_schema::TimeUnit;

#[test]
fn timestamp_parser_test() {
// exercise all supported formats
assert_eq!(
timestamp_parser("2020", EvalMode::Legacy).unwrap(),
Some(1577836800000000) // microseconds for 2020-01-01T00:00:00 UTC
);
assert_eq!(
timestamp_parser("2020-01", EvalMode::Legacy).unwrap(),
Some(1577836800000000)
);
assert_eq!(
timestamp_parser("2020-01-01", EvalMode::Legacy).unwrap(),
Some(1577836800000000)
);
assert_eq!(
timestamp_parser("2020-01-01T12", EvalMode::Legacy).unwrap(),
Some(1577880000000000)
);
assert_eq!(
timestamp_parser("2020-01-01T12:34", EvalMode::Legacy).unwrap(),
Some(1577882040000000)
);
assert_eq!(
timestamp_parser("2020-01-01T12:34:56", EvalMode::Legacy).unwrap(),
Some(1577882096000000)
);
assert_eq!(
timestamp_parser("2020-01-01T12:34:56.123456", EvalMode::Legacy).unwrap(),
Some(1577882096123456)
);
// assert_eq!(
// timestamp_parser("T2", EvalMode::Legacy).unwrap(),
// Some(1714356000000000) // this value changes every day; a date-derived check is sketched below
// );
}
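
A date-derived alternative to the commented-out "T2" assertion above — an editor's sketch, not part of this PR; the test name is hypothetical:

#[test]
fn timestamp_parser_time_only_format() {
    // Editor's sketch, not part of this PR: "T2" resolves to 02:00:00 UTC on the
    // current date, so derive the expected value at runtime instead of hard-coding
    // a constant that goes stale (still racy if the run crosses midnight UTC).
    use chrono::TimeZone;
    let today_at_two = chrono::Utc::now()
        .date_naive()
        .and_hms_opt(2, 0, 0)
        .unwrap();
    let expected = chrono::Utc.from_utc_datetime(&today_at_two).timestamp_micros();
    assert_eq!(
        timestamp_parser("T2", EvalMode::Legacy).unwrap(),
        Some(expected)
    );
}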

#[test]
fn test_cast_string_to_timestamp() {
let array: ArrayRef = Arc::new(StringArray::from(vec![
Some("2020-01-01T12:34:56.123456"),
Some("T2"),
]));

let string_array = array
.as_any()
.downcast_ref::<GenericStringArray<i32>>()
.expect("Expected a string array");

let eval_mode = EvalMode::Legacy;
let result = cast_utf8_to_timestamp!(
&string_array,
eval_mode,
TimestampMicrosecondType,
timestamp_parser
);

assert_eq!(
result.data_type(),
&DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into()))
);
assert_eq!(result.len(), 2);
}

#[test]
fn test_cast_string_as_i8() {
@@ -593,7 +593,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
false
case _ => true
}

if (supportedCast) {
// if the timezone is not UTC, fall back to Spark
if (!timeZoneId.contains("UTC")) {
withInfo(expr, s"Unsupported timezone ${timeZoneId} for timestamp cast")
None
} else {
castToProto(timeZoneId, dt, childExpr, evalModeStr)
}
} else {
// no need to call withInfo here since it was called when determining