Skip to content

Commit

Permalink
chore: Refactor cast to use SparkCastOptions param (#1146)
Browse files Browse the repository at this point in the history
* Refactor cast to use SparkCastOptions param

* update tests

* update benches

* update benches

* update benches
  • Loading branch information
andygrove authored Dec 6, 2024
1 parent 1c6c7a9 commit 8d83cc1
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 158 deletions.
23 changes: 9 additions & 14 deletions native/core/src/execution/datafusion/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ use datafusion_comet_proto::{
};
use datafusion_comet_spark_expr::{
ArrayInsert, Cast, CreateNamedStruct, DateTruncExpr, GetArrayStructFields, GetStructField,
HourExpr, IfExpr, ListExtract, MinuteExpr, RLike, SecondExpr, TimestampTruncExpr, ToJson,
HourExpr, IfExpr, ListExtract, MinuteExpr, RLike, SecondExpr, SparkCastOptions,
TimestampTruncExpr, ToJson,
};
use datafusion_common::scalar::ScalarStructBuilder;
use datafusion_common::{
Expand Down Expand Up @@ -388,14 +389,11 @@ impl PhysicalPlanner {
ExprStruct::Cast(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let timezone = expr.timezone.clone();
let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
Ok(Arc::new(Cast::new(
child,
datatype,
eval_mode,
timezone,
expr.allow_incompat,
SparkCastOptions::new(eval_mode, &expr.timezone, expr.allow_incompat),
)))
}
ExprStruct::Hour(expr) => {
Expand Down Expand Up @@ -806,24 +804,21 @@ impl PhysicalPlanner {
let data_type = return_type.map(to_arrow_datatype).unwrap();
// For some Decimal128 operations, we need wider internal digits.
// Cast left and right to Decimal256 and cast the result back to Decimal128
let left = Arc::new(Cast::new_without_timezone(
let left = Arc::new(Cast::new(
left,
DataType::Decimal256(p1, s1),
EvalMode::Legacy,
false,
SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
));
let right = Arc::new(Cast::new_without_timezone(
let right = Arc::new(Cast::new(
right,
DataType::Decimal256(p2, s2),
EvalMode::Legacy,
false,
SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
));
let child = Arc::new(BinaryExpr::new(left, op, right));
Ok(Arc::new(Cast::new_without_timezone(
Ok(Arc::new(Cast::new(
child,
data_type,
EvalMode::Legacy,
false,
SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
)))
}
(
Expand Down
30 changes: 6 additions & 24 deletions native/spark-expr/benches/cast_from_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,36 +18,18 @@
use arrow_array::{builder::StringBuilder, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use criterion::{criterion_group, criterion_main, Criterion};
use datafusion_comet_spark_expr::{Cast, EvalMode};
use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions};
use datafusion_physical_expr::{expressions::Column, PhysicalExpr};
use std::sync::Arc;

fn criterion_benchmark(c: &mut Criterion) {
let batch = create_utf8_batch();
let expr = Arc::new(Column::new("a", 0));
let timezone = "".to_string();
let cast_string_to_i8 = Cast::new(
expr.clone(),
DataType::Int8,
EvalMode::Legacy,
timezone.clone(),
false,
);
let cast_string_to_i16 = Cast::new(
expr.clone(),
DataType::Int16,
EvalMode::Legacy,
timezone.clone(),
false,
);
let cast_string_to_i32 = Cast::new(
expr.clone(),
DataType::Int32,
EvalMode::Legacy,
timezone.clone(),
false,
);
let cast_string_to_i64 = Cast::new(expr, DataType::Int64, EvalMode::Legacy, timezone, false);
let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "", false);
let cast_string_to_i8 = Cast::new(expr.clone(), DataType::Int8, spark_cast_options.clone());
let cast_string_to_i16 = Cast::new(expr.clone(), DataType::Int16, spark_cast_options.clone());
let cast_string_to_i32 = Cast::new(expr.clone(), DataType::Int32, spark_cast_options.clone());
let cast_string_to_i64 = Cast::new(expr, DataType::Int64, spark_cast_options);

let mut group = c.benchmark_group("cast_string_to_int");
group.bench_function("cast_string_to_i8", |b| {
Expand Down
22 changes: 5 additions & 17 deletions native/spark-expr/benches/cast_numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,29 +18,17 @@
use arrow_array::{builder::Int32Builder, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use criterion::{criterion_group, criterion_main, Criterion};
use datafusion_comet_spark_expr::{Cast, EvalMode};
use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions};
use datafusion_physical_expr::{expressions::Column, PhysicalExpr};
use std::sync::Arc;

fn criterion_benchmark(c: &mut Criterion) {
let batch = create_int32_batch();
let expr = Arc::new(Column::new("a", 0));
let timezone = "".to_string();
let cast_i32_to_i8 = Cast::new(
expr.clone(),
DataType::Int8,
EvalMode::Legacy,
timezone.clone(),
false,
);
let cast_i32_to_i16 = Cast::new(
expr.clone(),
DataType::Int16,
EvalMode::Legacy,
timezone.clone(),
false,
);
let cast_i32_to_i64 = Cast::new(expr, DataType::Int64, EvalMode::Legacy, timezone, false);
let spark_cast_options = SparkCastOptions::new_without_timezone(EvalMode::Legacy, false);
let cast_i32_to_i8 = Cast::new(expr.clone(), DataType::Int8, spark_cast_options.clone());
let cast_i32_to_i16 = Cast::new(expr.clone(), DataType::Int16, spark_cast_options.clone());
let cast_i32_to_i64 = Cast::new(expr, DataType::Int64, spark_cast_options);

let mut group = c.benchmark_group("cast_int_to_int");
group.bench_function("cast_i32_to_i8", |b| {
Expand Down
Loading

0 comments on commit 8d83cc1

Please sign in to comment.