Skip to content

Commit

Permalink
support casting DateType in comet
Browse files Browse the repository at this point in the history
  • Loading branch information
vidyasankarv committed May 5, 2024
1 parent b39ed88 commit 52e6bcf
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 54 deletions.
233 changes: 188 additions & 45 deletions core/src/execution/datafusion/expressions/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,27 @@ use std::{
sync::Arc,
};

use crate::errors::{CometError, CometResult};
use arrow::{
compute::{cast_with_options, CastOptions},
datatypes::TimestampMicrosecondType,
record_batch::RecordBatch,
util::display::FormatOptions,
};
use arrow_array::{
Array,
ArrayRef, BooleanArray, Float32Array, Float64Array, GenericStringArray, OffsetSizeTrait, PrimitiveArray,
types::{Int16Type, Int32Type, Int64Type, Int8Type},
Array, ArrayRef, BooleanArray, Float32Array, Float64Array, GenericStringArray, OffsetSizeTrait,
PrimitiveArray,
};
use arrow_array::types::Date32Type;
use arrow_schema::{DataType, Schema};
use chrono::{TimeZone, Timelike};
use chrono::{Datelike, NaiveDate, Timelike, TimeZone};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue};
use datafusion_physical_expr::PhysicalExpr;
use num::{traits::CheckedNeg, CheckedSub, Integer, Num};
use num::{CheckedSub, Integer, Num, traits::CheckedNeg};
use regex::Regex;

use crate::errors::{CometError, CometResult};
use crate::execution::datafusion::expressions::utils::{
array_with_timezone, down_cast_any_ref, spark_cast,
};
Expand Down Expand Up @@ -107,7 +108,23 @@ macro_rules! cast_utf8_to_timestamp {
result
}};
}

/// Builds a primitive array of `$array_type` from a string array by running
/// `$date_parser` over every row.
///
/// Each non-null string is trimmed and handed to `$date_parser` together with
/// `$eval_mode`. A row becomes a value only when the parser returns
/// `Ok(Some(_))`; null inputs, `Ok(None)` and any parser error all produce a
/// null slot. The macro evaluates to an `ArrayRef`.
macro_rules! cast_utf8_to_date {
    ($array:expr, $eval_mode:expr, $array_type:ty, $date_parser:ident) => {{
        let row_count = $array.len();
        let mut builder = PrimitiveArray::<$array_type>::builder(row_count);
        for row in 0..row_count {
            // Collapse "null input", "parser said no" and "parser failed" into `None`.
            let parsed = if $array.is_null(row) {
                None
            } else {
                $date_parser($array.value(row).trim(), $eval_mode)
                    .ok()
                    .flatten()
            };
            match parsed {
                Some(days) => builder.append_value(days),
                None => builder.append_null(),
            }
        }
        let finished: ArrayRef = Arc::new(builder.finish()) as ArrayRef;
        finished
    }};
}
macro_rules! cast_float_to_string {
($from:expr, $eval_mode:expr, $type:ty, $output_type:ty, $offset_type:ty) => {{

Expand Down Expand Up @@ -274,16 +291,20 @@ impl Cast {
(DataType::Utf8, DataType::Timestamp(_, _)) => {
Self::cast_string_to_timestamp(&array, to_type, self.eval_mode)?
}
(DataType::Utf8, DataType::Date32)
| (DataType::Utf8, DataType::Date64) => {
Self::cast_string_to_date(&array, to_type, self.eval_mode)?
}
(DataType::Int64, DataType::Int32)
| (DataType::Int64, DataType::Int16)
| (DataType::Int64, DataType::Int8)
| (DataType::Int32, DataType::Int16)
| (DataType::Int32, DataType::Int8)
| (DataType::Int16, DataType::Int8)
if self.eval_mode != EvalMode::Try =>
{
Self::spark_cast_int_to_int(&array, self.eval_mode, from_type, to_type)?
}
if self.eval_mode != EvalMode::Try =>
{
Self::spark_cast_int_to_int(&array, self.eval_mode, from_type, to_type)?
}
(
DataType::Utf8,
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
Expand All @@ -297,29 +318,29 @@ impl Cast {
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
) if key_type.as_ref() == &DataType::Int32
&& (value_type.as_ref() == &DataType::Utf8
|| value_type.as_ref() == &DataType::LargeUtf8) =>
{
// TODO: we are unpacking a dictionary-encoded array and then performing
// the cast. We could potentially improve performance here by casting the
// dictionary values directly without unpacking the array first, although this
// would add more complexity to the code
match value_type.as_ref() {
DataType::Utf8 => {
let unpacked_array =
cast_with_options(&array, &DataType::Utf8, &CAST_OPTIONS)?;
Self::cast_string_to_int::<i32>(to_type, &unpacked_array, self.eval_mode)?
}
DataType::LargeUtf8 => {
let unpacked_array =
cast_with_options(&array, &DataType::LargeUtf8, &CAST_OPTIONS)?;
Self::cast_string_to_int::<i64>(to_type, &unpacked_array, self.eval_mode)?
|| value_type.as_ref() == &DataType::LargeUtf8) =>
{
// TODO: we are unpacking a dictionary-encoded array and then performing
// the cast. We could potentially improve performance here by casting the
// dictionary values directly without unpacking the array first, although this
// would add more complexity to the code
match value_type.as_ref() {
DataType::Utf8 => {
let unpacked_array =
cast_with_options(&array, &DataType::Utf8, &CAST_OPTIONS)?;
Self::cast_string_to_int::<i32>(to_type, &unpacked_array, self.eval_mode)?
}
DataType::LargeUtf8 => {
let unpacked_array =
cast_with_options(&array, &DataType::LargeUtf8, &CAST_OPTIONS)?;
Self::cast_string_to_int::<i64>(to_type, &unpacked_array, self.eval_mode)?
}
dt => unreachable!(
"{}",
format!("invalid value type {dt} for dictionary-encoded string array")
),
}
dt => unreachable!(
"{}",
format!("invalid value type {dt} for dictionary-encoded string array")
),
}
}
(DataType::Float64, DataType::Utf8) => {
Self::spark_cast_float64_to_utf8::<i32>(&array, self.eval_mode)?
}
Expand Down Expand Up @@ -371,6 +392,30 @@ impl Cast {
Ok(cast_array)
}

fn cast_string_to_date(
array: &ArrayRef,
to_type: &DataType,
eval_mode: EvalMode,
) -> CometResult<ArrayRef> {
let string_array = array
.as_any()
.downcast_ref::<GenericStringArray<i32>>()
.expect("Expected a string array");

let cast_array: ArrayRef = match to_type {
DataType::Date32 | DataType::Date64 => {
cast_utf8_to_date!(
string_array,
eval_mode,
Date32Type,
date_parser
)
}
_ => unreachable!("Invalid data type {:?} in cast from string", to_type),
};
Ok(cast_array)
}

fn cast_string_to_timestamp(
array: &ArrayRef,
to_type: &DataType,
Expand Down Expand Up @@ -399,8 +444,8 @@ impl Cast {
from: &dyn Array,
_eval_mode: EvalMode,
) -> CometResult<ArrayRef>
where
OffsetSize: OffsetSizeTrait,
where
OffsetSize: OffsetSizeTrait,
{
cast_float_to_string!(from, _eval_mode, f64, Float64Array, OffsetSize)
}
Expand All @@ -409,8 +454,8 @@ impl Cast {
from: &dyn Array,
_eval_mode: EvalMode,
) -> CometResult<ArrayRef>
where
OffsetSize: OffsetSizeTrait,
where
OffsetSize: OffsetSizeTrait,
{
cast_float_to_string!(from, _eval_mode, f32, Float32Array, OffsetSize)
}
Expand Down Expand Up @@ -451,8 +496,8 @@ impl Cast {
from: &dyn Array,
eval_mode: EvalMode,
) -> CometResult<ArrayRef>
where
OffsetSize: OffsetSizeTrait,
where
OffsetSize: OffsetSizeTrait,
{
let array = from
.as_any()
Expand Down Expand Up @@ -489,7 +534,7 @@ fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> CometResult<Option<i8>>
i8::MIN as i32,
i8::MAX as i32,
)?
.map(|v| v as i8))
.map(|v| v as i8))
}

/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toShort
Expand All @@ -501,7 +546,7 @@ fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> CometResult<Option<i16>
i16::MIN as i32,
i16::MAX as i32,
)?
.map(|v| v as i16))
.map(|v| v as i16))
}

/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper)
Expand Down Expand Up @@ -809,15 +854,15 @@ fn timestamp_parser(value: &str, eval_mode: EvalMode) -> CometResult<Option<i64>
}

if timestamp.is_none() {
if eval_mode == EvalMode::Ansi {
return Err(CometError::CastInvalidValue {
return if eval_mode == EvalMode::Ansi {
Err(CometError::CastInvalidValue {
value: value.to_string(),
from_type: "STRING".to_string(),
to_type: "TIMESTAMP".to_string(),
});
})
} else {
return Ok(None);
}
Ok(None)
};
}

match timestamp {
Expand Down Expand Up @@ -954,13 +999,82 @@ fn parse_str_to_time_only_timestamp(value: &str) -> CometResult<Option<i64>> {
Ok(Some(timestamp))
}


fn date_parser(value: &str, eval_mode: EvalMode) -> CometResult<Option<i32>> {
let value = value.trim();
if value.is_empty() {
return Ok(None);
}

// Define regex patterns and corresponding parsing functions
let patterns = &[
(Regex::new(r"^\d{4}$").unwrap(), parse_year as fn(&str) -> CometResult<Option<i32>>),
(Regex::new(r"^\d{4}-\d{2}$").unwrap(), parse_year_month),
(Regex::new(r"^\d{4}-\d{2}-\d{2}T?$").unwrap(), parse_year_month_day),
];

let mut date = None;

// Iterate through patterns and try matching
for (pattern, parse_func) in patterns {
if pattern.is_match(value) {
date = parse_func(value)?;
break;
}
}

if date.is_none() && eval_mode == EvalMode::Ansi {
return Err(CometError::CastInvalidValue {
value: value.to_string(),
from_type: "STRING".to_string(),
to_type: "DATE".to_string(),
});
}

Ok(date)
}

fn parse_year(value: &str) -> CometResult<Option<i32>> {
let year: i32 = value.parse()?;
let date = NaiveDate::from_ymd_opt(year, 1, 1);
match date {
Some(date) => Ok(Some(date.num_days_from_ce())),
None => Err(CometError::Internal(
"Failed to parse date".to_string(),
)),
}
}

fn parse_year_month(value: &str) -> CometResult<Option<i32>> {
let date = NaiveDate::parse_from_str(value, "%Y-%m");
match date {
Ok(date) => Ok(Some(date.num_days_from_ce())),
Err(_) => Err(CometError::Internal(
"Failed to parse date".to_string(),
)),
}
}

fn parse_year_month_day(value: &str) -> CometResult<Option<i32>> {
let value = value.trim_end_matches('T');
let date = NaiveDate::parse_from_str(value, "%Y-%m-%d");
match date {
Ok(date) => Ok(Some(date.num_days_from_ce())),
Err(_) => Err(CometError::Internal(
"Failed to parse date".to_string(),
)),
}
}

#[cfg(test)]
mod tests {
use super::*;
use arrow::datatypes::Date32Type;
use arrow::datatypes::TimestampMicrosecondType;
use arrow_array::StringArray;
use arrow_schema::TimeUnit;

use super::*;

#[test]
fn timestamp_parser_test() {
// write for all formats
Expand Down Expand Up @@ -1025,6 +1139,35 @@ mod tests {
assert_eq!(result.len(), 2);
}

/// Runs `cast_utf8_to_date!` over one string per supported date shape and
/// checks the type and length of the resulting array.
#[test]
fn test_cast_string_as_date() {
    // One input per shape: year, year-month, full date, full date with trailing 'T'.
    let inputs = vec![
        Some("2020"),
        Some("2020-01"),
        Some("2020-01-01"),
        Some("2020-01-01T"),
    ];
    let array: ArrayRef = Arc::new(StringArray::from(inputs));
    let string_array = array
        .as_any()
        .downcast_ref::<GenericStringArray<i32>>()
        .expect("Expected a string array");

    let result = cast_utf8_to_date!(&string_array, EvalMode::Legacy, Date32Type, date_parser);

    assert_eq!(result.data_type(), &DataType::Date32);
    assert_eq!(result.len(), 4);
}

#[test]
fn test_cast_string_as_i8() {
// basic
Expand Down
Loading

0 comments on commit 52e6bcf

Please sign in to comment.