add support for casting unsigned int to signed int
andygrove committed Dec 13, 2024
1 parent f4d76f1 commit 37bb93a
Showing 2 changed files with 75 additions and 55 deletions.
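
This commit adds an opt-in flag, allow_cast_unsigned_ints, on SparkCastOptions, and routes unsigned-to-signed integer casts to Arrow's cast kernel when the flag is set. A minimal standalone sketch of that kernel's behavior (not taken from this commit — it uses arrow-rs's default CastOptions, whereas the commit passes its internal CAST_OPTIONS constant, whose overflow handling is an assumption here):

use std::sync::Arc;

use arrow::array::{Array, ArrayRef, Int32Array, UInt32Array};
use arrow::compute::{cast_with_options, CastOptions};
use arrow::datatypes::DataType;

fn main() {
    // A UInt32 column, e.g. as produced by a Parquet reader.
    let unsigned: ArrayRef = Arc::new(UInt32Array::from(vec![1, 2, u32::MAX]));

    // The new match arm hands arrays like this to Arrow's cast kernel.
    let signed = cast_with_options(&unsigned, &DataType::Int32, &CastOptions::default())
        .expect("cast failed");
    let signed = signed.as_any().downcast_ref::<Int32Array>().unwrap();

    // In-range values survive unchanged.
    assert_eq!(signed.value(0), 1);
    assert_eq!(signed.value(1), 2);
    // With Arrow's default (safe) options, out-of-range values become null;
    // the commit's CAST_OPTIONS may instead error or wrap (assumption).
    assert!(signed.is_null(2));
}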
109 changes: 56 additions & 53 deletions native/spark-expr/src/cast.rs
@@ -15,6 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::timezone;
+use crate::utils::array_with_timezone;
+use crate::{EvalMode, SparkError, SparkResult};
 use arrow::{
     array::{
         cast::AsArray,
@@ -56,11 +59,6 @@ use std::{
     sync::Arc,
 };
 
-use crate::timezone;
-use crate::utils::array_with_timezone;
-
-use crate::{EvalMode, SparkError, SparkResult};
-
 static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f");
 
 const MICROS_PER_SECOND: i64 = 1000000;
@@ -166,6 +164,11 @@ pub fn cast_supported(
 
     match (from_type, to_type) {
         (Boolean, _) => can_cast_from_boolean(from_type, options),
+        (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
+            if options.allow_cast_unsigned_ints =>
+        {
+            true
+        }
         (Int8, _) => can_cast_from_byte(from_type, options),
         (Int16, _) => can_cast_from_short(from_type, options),
         (Int32, _) => can_cast_from_int(from_type, options),
@@ -783,6 +786,8 @@ pub struct SparkCastOptions
     pub timezone: String,
     /// Allow casts that are supported but not guaranteed to be 100% compatible
     pub allow_incompat: bool,
+    /// Support casting unsigned ints to signed ints (used by Parquet SchemaAdapter)
+    pub allow_cast_unsigned_ints: bool,
 }
 
 impl SparkCastOptions {
@@ -791,6 +796,7 @@
             eval_mode,
             timezone: timezone.to_string(),
             allow_incompat,
+            allow_cast_unsigned_ints: false,
         }
     }
 
@@ -799,6 +805,7 @@
             eval_mode,
             timezone: "".to_string(),
             allow_incompat,
+            allow_cast_unsigned_ints: false,
         }
     }
 }
@@ -834,14 +841,14 @@ fn cast_array(
     to_type: &DataType,
     cast_options: &SparkCastOptions,
 ) -> DataFusionResult<ArrayRef> {
+    use DataType::*;
     let array = array_with_timezone(array, cast_options.timezone.clone(), Some(to_type))?;
     let from_type = array.data_type().clone();
 
     let array = match &from_type {
-        DataType::Dictionary(key_type, value_type)
-            if key_type.as_ref() == &DataType::Int32
-                && (value_type.as_ref() == &DataType::Utf8
-                    || value_type.as_ref() == &DataType::LargeUtf8) =>
+        Dictionary(key_type, value_type)
+            if key_type.as_ref() == &Int32
+                && (value_type.as_ref() == &Utf8 || value_type.as_ref() == &LargeUtf8) =>
         {
             let dict_array = array
                 .as_any()
@@ -854,7 +861,7 @@
             );
 
             let casted_result = match to_type {
-                DataType::Dictionary(_, _) => Arc::new(casted_dictionary.clone()),
+                Dictionary(_, _) => Arc::new(casted_dictionary.clone()),
                 _ => take(casted_dictionary.values().as_ref(), dict_array.keys(), None)?,
             };
             return Ok(spark_cast_postprocess(casted_result, &from_type, to_type));
@@ -865,70 +872,66 @@
     let eval_mode = cast_options.eval_mode;
 
     let cast_result = match (from_type, to_type) {
-        (DataType::Utf8, DataType::Boolean) => spark_cast_utf8_to_boolean::<i32>(&array, eval_mode),
-        (DataType::LargeUtf8, DataType::Boolean) => {
-            spark_cast_utf8_to_boolean::<i64>(&array, eval_mode)
-        }
-        (DataType::Utf8, DataType::Timestamp(_, _)) => {
+        (Utf8, Boolean) => spark_cast_utf8_to_boolean::<i32>(&array, eval_mode),
+        (LargeUtf8, Boolean) => spark_cast_utf8_to_boolean::<i64>(&array, eval_mode),
+        (Utf8, Timestamp(_, _)) => {
             cast_string_to_timestamp(&array, to_type, eval_mode, &cast_options.timezone)
         }
-        (DataType::Utf8, DataType::Date32) => cast_string_to_date(&array, to_type, eval_mode),
-        (DataType::Int64, DataType::Int32)
-        | (DataType::Int64, DataType::Int16)
-        | (DataType::Int64, DataType::Int8)
-        | (DataType::Int32, DataType::Int16)
-        | (DataType::Int32, DataType::Int8)
-        | (DataType::Int16, DataType::Int8)
+        (Utf8, Date32) => cast_string_to_date(&array, to_type, eval_mode),
+        (Int64, Int32)
+        | (Int64, Int16)
+        | (Int64, Int8)
+        | (Int32, Int16)
+        | (Int32, Int8)
+        | (Int16, Int8)
             if eval_mode != EvalMode::Try =>
         {
             spark_cast_int_to_int(&array, eval_mode, from_type, to_type)
         }
-        (DataType::Utf8, DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64) => {
+        (Utf8, Int8 | Int16 | Int32 | Int64) => {
             cast_string_to_int::<i32>(to_type, &array, eval_mode)
         }
-        (
-            DataType::LargeUtf8,
-            DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
-        ) => cast_string_to_int::<i64>(to_type, &array, eval_mode),
-        (DataType::Float64, DataType::Utf8) => spark_cast_float64_to_utf8::<i32>(&array, eval_mode),
-        (DataType::Float64, DataType::LargeUtf8) => {
-            spark_cast_float64_to_utf8::<i64>(&array, eval_mode)
-        }
-        (DataType::Float32, DataType::Utf8) => spark_cast_float32_to_utf8::<i32>(&array, eval_mode),
-        (DataType::Float32, DataType::LargeUtf8) => {
-            spark_cast_float32_to_utf8::<i64>(&array, eval_mode)
-        }
-        (DataType::Float32, DataType::Decimal128(precision, scale)) => {
+        (LargeUtf8, Int8 | Int16 | Int32 | Int64) => {
+            cast_string_to_int::<i64>(to_type, &array, eval_mode)
+        }
+        (Float64, Utf8) => spark_cast_float64_to_utf8::<i32>(&array, eval_mode),
+        (Float64, LargeUtf8) => spark_cast_float64_to_utf8::<i64>(&array, eval_mode),
+        (Float32, Utf8) => spark_cast_float32_to_utf8::<i32>(&array, eval_mode),
+        (Float32, LargeUtf8) => spark_cast_float32_to_utf8::<i64>(&array, eval_mode),
+        (Float32, Decimal128(precision, scale)) => {
             cast_float32_to_decimal128(&array, *precision, *scale, eval_mode)
         }
-        (DataType::Float64, DataType::Decimal128(precision, scale)) => {
+        (Float64, Decimal128(precision, scale)) => {
             cast_float64_to_decimal128(&array, *precision, *scale, eval_mode)
         }
-        (DataType::Float32, DataType::Int8)
-        | (DataType::Float32, DataType::Int16)
-        | (DataType::Float32, DataType::Int32)
-        | (DataType::Float32, DataType::Int64)
-        | (DataType::Float64, DataType::Int8)
-        | (DataType::Float64, DataType::Int16)
-        | (DataType::Float64, DataType::Int32)
-        | (DataType::Float64, DataType::Int64)
-        | (DataType::Decimal128(_, _), DataType::Int8)
-        | (DataType::Decimal128(_, _), DataType::Int16)
-        | (DataType::Decimal128(_, _), DataType::Int32)
-        | (DataType::Decimal128(_, _), DataType::Int64)
+        (Float32, Int8)
+        | (Float32, Int16)
+        | (Float32, Int32)
+        | (Float32, Int64)
+        | (Float64, Int8)
+        | (Float64, Int16)
+        | (Float64, Int32)
+        | (Float64, Int64)
+        | (Decimal128(_, _), Int8)
+        | (Decimal128(_, _), Int16)
+        | (Decimal128(_, _), Int32)
+        | (Decimal128(_, _), Int64)
            if eval_mode != EvalMode::Try =>
         {
             spark_cast_nonintegral_numeric_to_integral(&array, eval_mode, from_type, to_type)
         }
-        (DataType::Struct(_), DataType::Utf8) => {
-            Ok(casts_struct_to_string(array.as_struct(), cast_options)?)
-        }
-        (DataType::Struct(_), DataType::Struct(_)) => Ok(cast_struct_to_struct(
+        (Struct(_), Utf8) => Ok(casts_struct_to_string(array.as_struct(), cast_options)?),
+        (Struct(_), Struct(_)) => Ok(cast_struct_to_struct(
             array.as_struct(),
             from_type,
             to_type,
             cast_options,
         )?),
+        (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
+            if cast_options.allow_cast_unsigned_ints =>
+        {
+            Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
+        }
         _ if is_datafusion_spark_compatible(from_type, to_type, cast_options.allow_incompat) => {
             // use DataFusion cast only when we know that it is compatible with Spark
             Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
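
Both constructors initialize the new flag to false, so existing callers are unaffected; a caller opts in by mutating the options after construction, which is exactly what the schema_adapter change below does. A minimal sketch, assuming SparkCastOptions and EvalMode are in scope:

// Opting in to unsigned-to-signed casts; the constructors default the
// flag to false, so the behavior change is explicit at the call site.
let mut opts = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
assert!(!opts.allow_cast_unsigned_ints); // default from new()
opts.allow_cast_unsigned_ints = true;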
21 changes: 19 additions & 2 deletions native/spark-expr/src/schema_adapter.rs
@@ -291,6 +291,7 @@ mod test {
     use arrow::array::{Int32Array, StringArray};
     use arrow::datatypes::{DataType, Field, Schema};
     use arrow::record_batch::RecordBatch;
+    use arrow_array::UInt32Array;
     use arrow_schema::SchemaRef;
     use datafusion::datasource::listing::PartitionedFile;
     use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec};
@@ -304,7 +305,7 @@
     use std::sync::Arc;
 
     #[tokio::test]
-    async fn parquet_roundtrip() -> Result<(), DataFusionError> {
+    async fn parquet_roundtrip_int_as_string() -> Result<(), DataFusionError> {
         let file_schema = Arc::new(Schema::new(vec![
             Field::new("id", DataType::Int32, false),
             Field::new("name", DataType::Utf8, false),
@@ -325,6 +326,20 @@
         Ok(())
     }
 
+    #[tokio::test]
+    async fn parquet_roundtrip_unsigned_int() -> Result<(), DataFusionError> {
+        let file_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::UInt32, false)]));
+
+        let ids = Arc::new(UInt32Array::from(vec![1, 2, 3])) as Arc<dyn arrow::array::Array>;
+        let batch = RecordBatch::try_new(Arc::clone(&file_schema), vec![ids])?;
+
+        let required_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
+
+        let _ = roundtrip(&batch, required_schema).await?;
+
+        Ok(())
+    }
+
     /// Create a Parquet file containing a single batch and then read the batch back using
     /// the specified required_schema. This will cause the SchemaAdapter code to be used.
     async fn roundtrip(
@@ -344,7 +359,9 @@
             filename.to_string(),
         )?]]);
 
-        let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
+        let mut spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
+        spark_cast_options.allow_cast_unsigned_ints = true;
+
         let parquet_exec = ParquetExec::builder(file_scan_config)
             .with_schema_adapter_factory(Arc::new(SparkSchemaAdapterFactory::new(
                 spark_cast_options,
