Skip to content

Commit

Permalink
chore: remove some unwraps from shuffle module (#601)
Browse files Browse the repository at this point in the history
* remove some unwraps from shuffle module

* simplifiy
  • Loading branch information
andygrove authored Jun 27, 2024
1 parent a1641ab commit 0d994d0
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 33 deletions.
26 changes: 13 additions & 13 deletions core/src/execution/shuffle/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,87 +192,87 @@ pub fn append_list_element<T: ArrayBuilder>(
list_builder
.as_any_mut()
.downcast_mut::<ListBuilder<BooleanBuilder>>()
.unwrap(),
.expect("ListBuilder<BooleanBuilder>"),
list,
idx,
),
DataType::Int8 => append_int8_element(
list_builder
.as_any_mut()
.downcast_mut::<ListBuilder<Int8Builder>>()
.unwrap(),
.expect("ListBuilder<Int8Builder>"),
list,
idx,
),
DataType::Int16 => append_int16_element(
list_builder
.as_any_mut()
.downcast_mut::<ListBuilder<Int16Builder>>()
.unwrap(),
.expect("ListBuilder<Int16Builder>"),
list,
idx,
),
DataType::Int32 => append_int32_element(
list_builder
.as_any_mut()
.downcast_mut::<ListBuilder<Int32Builder>>()
.unwrap(),
.expect("ListBuilder<Int32Builder>"),
list,
idx,
),
DataType::Int64 => append_int64_element(
list_builder
.as_any_mut()
.downcast_mut::<ListBuilder<Int64Builder>>()
.unwrap(),
.expect("ListBuilder<Int64Builder>"),
list,
idx,
),
DataType::Float32 => append_float32_element(
list_builder
.as_any_mut()
.downcast_mut::<ListBuilder<Float32Builder>>()
.unwrap(),
.expect("ListBuilder<Float32Builder>"),
list,
idx,
),
DataType::Float64 => append_float64_element(
list_builder
.as_any_mut()
.downcast_mut::<ListBuilder<Float64Builder>>()
.unwrap(),
.expect("ListBuilder<Float64Builder>"),
list,
idx,
),
DataType::Date32 => append_date32_element(
list_builder
.as_any_mut()
.downcast_mut::<ListBuilder<Date32Builder>>()
.unwrap(),
.expect("ListBuilder<Date32Builder>"),
list,
idx,
),
DataType::Timestamp(TimeUnit::Microsecond, _) => append_timestamp_element(
list_builder
.as_any_mut()
.downcast_mut::<ListBuilder<TimestampMicrosecondBuilder>>()
.unwrap(),
.expect("ListBuilder<TimestampMicrosecondBuilder>"),
list,
idx,
),
DataType::Binary => append_binary_element(
list_builder
.as_any_mut()
.downcast_mut::<ListBuilder<BinaryBuilder>>()
.unwrap(),
.expect("ListBuilder<BinaryBuilder>"),
list,
idx,
),
DataType::Utf8 => append_string_element(
list_builder
.as_any_mut()
.downcast_mut::<ListBuilder<StringBuilder>>()
.unwrap(),
.expect("ListBuilder<StringBuilder>"),
list,
idx,
),
Expand All @@ -281,7 +281,7 @@ pub fn append_list_element<T: ArrayBuilder>(
.values()
.as_any_mut()
.downcast_mut::<Decimal128Builder>()
.unwrap();
.expect("ListBuilder<Decimal128Builder>");
let is_null = list.is_null_at(idx);

if is_null {
Expand Down Expand Up @@ -319,7 +319,7 @@ pub fn append_list_element<T: ArrayBuilder>(
.values()
.as_any_mut()
.downcast_mut::<StructBuilder>()
.unwrap();
.expect("StructBuilder");
let is_null = list.is_null_at(idx);

let nested_row = if is_null {
Expand Down
43 changes: 23 additions & 20 deletions core/src/execution/shuffle/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use arrow_array::{
types::Int32Type,
Array, ArrayRef, RecordBatch, RecordBatchOptions,
};
use arrow_schema::{DataType, Field, Schema, TimeUnit};
use arrow_schema::{ArrowError, DataType, Field, Schema, TimeUnit};
use jni::sys::{jint, jlong};
use std::{
fs::OpenOptions,
Expand Down Expand Up @@ -275,7 +275,10 @@ impl SparkUnsafeRow {

macro_rules! downcast_builder {
($builder_type:ty, $builder:expr) => {
$builder.into_box_any().downcast::<$builder_type>().unwrap()
$builder
.into_box_any()
.downcast::<$builder_type>()
.expect(stringify!($builder_type))
};
}

Expand All @@ -284,7 +287,7 @@ macro_rules! downcast_builder_ref {
$builder
.as_any_mut()
.downcast_mut::<$builder_type>()
.unwrap()
.expect(stringify!($builder_type))
};
}

Expand Down Expand Up @@ -348,8 +351,7 @@ pub(crate) fn append_field(
$field,
field_builder,
&row.get_map(idx),
)
.unwrap();
)?;
}
}
}};
Expand Down Expand Up @@ -378,8 +380,7 @@ pub(crate) fn append_field(
$element_dt,
field_builder,
&row.get_array(idx),
)
.unwrap()
)?
}
}
}};
Expand Down Expand Up @@ -1057,7 +1058,7 @@ pub(crate) fn append_columns(
let element_builder = builder
.as_any_mut()
.downcast_mut::<$builder_type>()
.unwrap();
.expect(stringify!($builder_type));
let mut row = SparkUnsafeRow::new(schema);

for i in row_start..row_end {
Expand All @@ -1084,7 +1085,7 @@ pub(crate) fn append_columns(
let list_builder = builder
.as_any_mut()
.downcast_mut::<ListBuilder<$builder_type>>()
.unwrap();
.expect(stringify!($builder_type));
let mut row = SparkUnsafeRow::new(schema);

for i in row_start..row_end {
Expand All @@ -1103,8 +1104,7 @@ pub(crate) fn append_columns(
$element_dt,
list_builder,
&row.get_array(column_idx),
)
.unwrap()
)?
}
}
}};
Expand All @@ -1116,7 +1116,11 @@ pub(crate) fn append_columns(
let map_builder = builder
.as_any_mut()
.downcast_mut::<MapBuilder<$key_builder_type, $value_builder_type>>()
.unwrap();
.expect(&format!(
"MapBuilder<{},{}>",
stringify!($key_builder_type),
stringify!($value_builder_type)
));
let mut row = SparkUnsafeRow::new(schema);

for i in row_start..row_end {
Expand All @@ -1135,8 +1139,7 @@ pub(crate) fn append_columns(
$field,
map_builder,
&row.get_map(column_idx),
)
.unwrap()
)?
}
}
}};
Expand All @@ -1148,7 +1151,7 @@ pub(crate) fn append_columns(
let struct_builder = builder
.as_any_mut()
.downcast_mut::<StructBuilder>()
.unwrap();
.expect("StructBuilder");
let mut row = SparkUnsafeRow::new(schema);

for i in row_start..row_end {
Expand Down Expand Up @@ -3347,7 +3350,7 @@ pub fn process_sorted_row_partition(
.zip(schema.iter())
.map(|(builder, datatype)| builder_to_array(builder, datatype, prefer_dictionary_ratio))
.collect();
let batch = make_batch(array_refs?, n);
let batch = make_batch(array_refs?, n)?;

let mut frozen: Vec<u8> = vec![];
let mut cursor = Cursor::new(&mut frozen);
Expand Down Expand Up @@ -3382,7 +3385,7 @@ fn builder_to_array(
let builder = builder
.as_any_mut()
.downcast_mut::<StringDictionaryBuilder<Int32Type>>()
.unwrap();
.expect("StringDictionaryBuilder<Int32Type>");

let dict_array = builder.finish();
let num_keys = dict_array.keys().len();
Expand All @@ -3401,7 +3404,7 @@ fn builder_to_array(
let builder = builder
.as_any_mut()
.downcast_mut::<BinaryDictionaryBuilder<Int32Type>>()
.unwrap();
.expect("BinaryDictionaryBuilder<Int32Type>");

let dict_array = builder.finish();
let num_keys = dict_array.keys().len();
Expand All @@ -3420,7 +3423,7 @@ fn builder_to_array(
}
}

fn make_batch(arrays: Vec<ArrayRef>, row_count: usize) -> RecordBatch {
fn make_batch(arrays: Vec<ArrayRef>, row_count: usize) -> Result<RecordBatch, ArrowError> {
let mut dict_id = 0;
let fields = arrays
.iter()
Expand All @@ -3442,5 +3445,5 @@ fn make_batch(arrays: Vec<ArrayRef>, row_count: usize) -> RecordBatch {
.collect::<Vec<_>>();
let schema = Arc::new(Schema::new(fields));
let options = RecordBatchOptions::new().with_row_count(Option::from(row_count));
RecordBatch::try_new_with_options(schema, arrays, &options).unwrap()
RecordBatch::try_new_with_options(schema, arrays, &options)
}

0 comments on commit 0d994d0

Please sign in to comment.