Skip to content

Commit

Permalink
introduce a macro
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove committed Apr 23, 2024
1 parent 2ea4442 commit d4fd8ff
Showing 1 changed file with 62 additions and 138 deletions.
200 changes: 62 additions & 138 deletions core/src/execution/datafusion/expressions/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,25 @@ pub struct Cast {
pub timezone: String,
}

/// Builds a `PrimitiveArray<$array_type>` by parsing every element of a string array.
///
/// Arguments:
/// * `$string_array` - expression yielding the string array to read; `len`, `is_null`
///   and `value` are called on it. It is evaluated exactly once.
/// * `$eval_mode` - evaluation mode forwarded unchanged to `$cast_method` for every
///   element. It is evaluated exactly once.
/// * `$array_type` - primitive type parameter of the output array.
/// * `$cast_method` - function invoked as `$cast_method(&str, eval_mode)`; its result is
///   unwrapped with `?` and must then be an `Option` of the value to append.
///
/// Each non-null input is trimmed before being handed to `$cast_method`. Null inputs,
/// and inputs for which `$cast_method` yields `None`, become nulls in the output.
///
/// The expansion evaluates to a `CometResult<ArrayRef>`. Note that an error from
/// `$cast_method` is propagated with `?` *inside* the expansion, i.e. it returns from
/// the enclosing function, so the macro can only be used where that is valid.
macro_rules! spark_cast_utf8_to_integral {
    ($string_array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{
        // Bind the arguments once so the argument expressions are not re-evaluated
        // on every loop iteration.
        let strings = $string_array;
        let eval_mode = $eval_mode;
        let row_count = strings.len();

        let mut builder = PrimitiveArray::<$array_type>::builder(row_count);
        for row in 0..row_count {
            if strings.is_null(row) {
                builder.append_null();
            } else {
                // `None` from the cast method is recorded as a null entry.
                builder.append_option($cast_method(strings.value(row).trim(), eval_mode)?);
            }
        }

        // The annotation pins the error type of the resulting `Result`.
        let result: CometResult<ArrayRef> = Ok(Arc::new(builder.finish()) as ArrayRef);
        result
    }};
}

impl Cast {
pub fn new(
child: Arc<dyn PhysicalExpr>,
Expand Down Expand Up @@ -110,57 +129,19 @@ impl Cast {
(
DataType::Utf8,
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
) => {
let string_array = array
.as_any()
.downcast_ref::<GenericStringArray<i32>>()
.expect("spark_cast_utf8_to_i8 expected a string array");

match to_type {
DataType::Int8 => {
Self::spark_cast_utf8_to_i8::<i32>(string_array, self.eval_mode)?
}
DataType::Int16 => {
Self::spark_cast_utf8_to_i16::<i32>(string_array, self.eval_mode)?
}
DataType::Int32 => {
Self::spark_cast_utf8_to_i32::<i32>(string_array, self.eval_mode)?
}
DataType::Int64 => {
Self::spark_cast_utf8_to_i64::<i32>(string_array, self.eval_mode)?
}
_ => unreachable!("invalid integral type in cast from string"),
}
}
) => Self::spark_cast_string_to_integral(to_type, &array, self.eval_mode)?,
(
DataType::Dictionary(key_type, value_type),
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
) if key_type.as_ref() == &DataType::Int32
&& value_type.as_ref() == &DataType::Utf8 =>
{
// TODO file follow on issue for optimizing this to avoid unpacking first
// Note that we are unpacking a dictionary-encoded array and then performing
// the cast. We could potentially improve performance here by casting the
// dictionary values directly without unpacking the array first, although this
// would add more complexity to the code
let unpacked_array = Self::unpack_dict_string_array::<Int32Type>(&array)?;
let string_array = unpacked_array
.as_any()
.downcast_ref::<GenericStringArray<i32>>()
.expect("spark_cast_utf8_to_i8 expected a string array");
match to_type {
DataType::Int8 => {
Self::spark_cast_utf8_to_i8::<i32>(&string_array, self.eval_mode)?
}
DataType::Int16 => {
Self::spark_cast_utf8_to_i16::<i32>(&string_array, self.eval_mode)?
}
DataType::Int32 => {
Self::spark_cast_utf8_to_i32::<i32>(&string_array, self.eval_mode)?
}
DataType::Int64 => {
Self::spark_cast_utf8_to_i64::<i32>(&string_array, self.eval_mode)?
}
_ => {
unreachable!("invalid integral type in cast from dictionary-encoded string")
}
}
Self::spark_cast_string_to_integral(to_type, &unpacked_array, self.eval_mode)?
}
_ => {
// when we have no Spark-specific casting we delegate to DataFusion
Expand All @@ -170,6 +151,43 @@ impl Cast {
Ok(spark_cast(cast_result, from_type, to_type))
}

/// Casts a string array to the integral Arrow type named by `to_type`.
///
/// `array` is downcast to `GenericStringArray<i32>` and each element is parsed with
/// the `cast_string_to_*` function matching `to_type`, using `eval_mode`.
///
/// # Panics
///
/// Panics if `array` is not a `GenericStringArray<i32>`, or if `to_type` is not one
/// of `Int8`, `Int16`, `Int32` or `Int64`.
///
/// # Errors
///
/// Propagates any error returned by the per-element cast function.
fn spark_cast_string_to_integral(
    to_type: &DataType,
    array: &ArrayRef,
    eval_mode: EvalMode,
) -> CometResult<ArrayRef> {
    let strings = array
        .as_any()
        .downcast_ref::<GenericStringArray<i32>>()
        .expect("spark_cast_string_to_integral expected a string array");

    // Every arm already evaluates to `CometResult<ArrayRef>`, so the match is the
    // function's result.
    match to_type {
        DataType::Int8 => {
            spark_cast_utf8_to_integral!(strings, eval_mode, Int8Type, cast_string_to_i8)
        }
        DataType::Int16 => {
            spark_cast_utf8_to_integral!(strings, eval_mode, Int16Type, cast_string_to_i16)
        }
        DataType::Int32 => {
            spark_cast_utf8_to_integral!(strings, eval_mode, Int32Type, cast_string_to_i32)
        }
        DataType::Int64 => {
            spark_cast_utf8_to_integral!(strings, eval_mode, Int64Type, cast_string_to_i64)
        }
        _ => unreachable!("invalid integral type in cast from string"),
    }
}

fn unpack_dict_string_array<T: ArrowDictionaryKeyType>(
array: &ArrayRef,
) -> DataFusionResult<ArrayRef> {
Expand Down Expand Up @@ -213,100 +231,6 @@ impl Cast {

Ok(Arc::new(output_array))
}

// TODO reduce code duplication

/// Parses every element of `string_array` into an `Int8Type` array.
///
/// Each non-null value is trimmed and passed to `cast_string_to_i8` together with
/// `eval_mode`. Null inputs, and inputs for which the cast yields `None`, become
/// nulls in the output; an error from the cast is returned immediately.
fn spark_cast_utf8_to_i8<OffsetSize>(
    string_array: &GenericStringArray<OffsetSize>,
    eval_mode: EvalMode,
) -> CometResult<ArrayRef>
where
    OffsetSize: OffsetSizeTrait,
{
    let row_count = string_array.len();
    let mut builder = PrimitiveArray::<Int8Type>::builder(row_count);
    for row in 0..row_count {
        if string_array.is_null(row) {
            builder.append_null();
            continue;
        }
        match cast_string_to_i8(string_array.value(row).trim(), eval_mode)? {
            Some(parsed) => builder.append_value(parsed),
            None => builder.append_null(),
        }
    }
    let finished: ArrayRef = Arc::new(builder.finish());
    Ok(finished)
}

/// Parses every element of `string_array` into an `Int16Type` array.
///
/// Each non-null value is trimmed and passed to `cast_string_to_i16` together with
/// `eval_mode`. Null inputs, and inputs for which the cast yields `None`, become
/// nulls in the output; an error from the cast is returned immediately.
fn spark_cast_utf8_to_i16<OffsetSize>(
    string_array: &GenericStringArray<OffsetSize>,
    eval_mode: EvalMode,
) -> CometResult<ArrayRef>
where
    OffsetSize: OffsetSizeTrait,
{
    let row_count = string_array.len();
    let mut builder = PrimitiveArray::<Int16Type>::builder(row_count);
    for row in 0..row_count {
        if string_array.is_null(row) {
            builder.append_null();
            continue;
        }
        match cast_string_to_i16(string_array.value(row).trim(), eval_mode)? {
            Some(parsed) => builder.append_value(parsed),
            None => builder.append_null(),
        }
    }
    let finished: ArrayRef = Arc::new(builder.finish());
    Ok(finished)
}

/// Parses every element of `string_array` into an `Int32Type` array.
///
/// Each non-null value is trimmed and passed to `cast_string_to_i32` together with
/// `eval_mode`. Null inputs, and inputs for which the cast yields `None`, become
/// nulls in the output; an error from the cast is returned immediately.
fn spark_cast_utf8_to_i32<OffsetSize>(
    string_array: &GenericStringArray<OffsetSize>,
    eval_mode: EvalMode,
) -> CometResult<ArrayRef>
where
    OffsetSize: OffsetSizeTrait,
{
    let row_count = string_array.len();
    let mut builder = PrimitiveArray::<Int32Type>::builder(row_count);
    for row in 0..row_count {
        if string_array.is_null(row) {
            builder.append_null();
            continue;
        }
        match cast_string_to_i32(string_array.value(row).trim(), eval_mode)? {
            Some(parsed) => builder.append_value(parsed),
            None => builder.append_null(),
        }
    }
    let finished: ArrayRef = Arc::new(builder.finish());
    Ok(finished)
}

/// Parses every element of `string_array` into an `Int64Type` array.
///
/// Each non-null value is trimmed and passed to `cast_string_to_i64` together with
/// `eval_mode`. Null inputs, and inputs for which the cast yields `None`, become
/// nulls in the output; an error from the cast is returned immediately.
fn spark_cast_utf8_to_i64<OffsetSize>(
    string_array: &GenericStringArray<OffsetSize>,
    eval_mode: EvalMode,
) -> CometResult<ArrayRef>
where
    OffsetSize: OffsetSizeTrait,
{
    let row_count = string_array.len();
    let mut builder = PrimitiveArray::<Int64Type>::builder(row_count);
    for row in 0..row_count {
        if string_array.is_null(row) {
            builder.append_null();
            continue;
        }
        match cast_string_to_i64(string_array.value(row).trim(), eval_mode)? {
            Some(parsed) => builder.append_value(parsed),
            None => builder.append_null(),
        }
    }
    let finished: ArrayRef = Arc::new(builder.finish());
    Ok(finished)
}
}

fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> CometResult<Option<i8>> {
Expand Down

0 comments on commit d4fd8ff

Please sign in to comment.