Skip to content

Commit

Permalink
fix dict support
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove committed Dec 29, 2024
1 parent 3835f5d commit 1d31fd3
Showing 1 changed file with 27 additions and 17 deletions.
44 changes: 27 additions & 17 deletions native/core/src/execution/shuffle/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,8 @@ impl<'a> BatchReader<'a> {

// read field name
let field_name_bytes = &self.input[self.offset..self.offset + field_name_len];
self.offset += field_name_bytes.len();
field_names.push(unsafe { String::from_utf8_unchecked(field_name_bytes.into()) });
self.offset += field_name_len;
}

let mut fields: Vec<Arc<Field>> = Vec::with_capacity(schema_len);
Expand Down Expand Up @@ -333,10 +333,13 @@ impl<'a> BatchReader<'a> {
x if x == DataTypeId::Int64 as i32 => DataType::Int64,
x if x == DataTypeId::Date32 as i32 => DataType::Date32,
x if x == DataTypeId::Utf8 as i32 => DataType::Utf8,
x if x == DataTypeId::Dictionary as i32 => DataType::Dictionary(
Box::new(self.read_data_type()?),
Box::new(self.read_data_type()?),
),
x if x == DataTypeId::Dictionary as i32 => {
self.offset += 1;
DataType::Dictionary(
Box::new(self.read_data_type()?),
Box::new(self.read_data_type()?),
)
}
x if x == DataTypeId::Decimal128 as i32 => DataType::Decimal128(
self.input[self.offset + 1],
self.input[self.offset + 2] as i8,
Expand All @@ -347,12 +350,12 @@ impl<'a> BatchReader<'a> {
)))
}
};
self.offset += 1;
if matches!(
data_type,
DataType::Decimal128(_, _) | DataType::Dictionary(_, _)
) {
self.offset += 2;
match data_type {
DataType::Dictionary(_, _) => {
// no need to increment
}
DataType::Decimal128(_, _) => self.offset += 3,
_ => self.offset += 1,
}
Ok(data_type)
}
Expand Down Expand Up @@ -401,7 +404,9 @@ impl<'a> BatchReader<'a> {
#[cfg(test)]
mod test {
use super::*;
use arrow_array::builder::{Date32Builder, Decimal128Builder, Int32Builder, StringBuilder};
use arrow_array::builder::{
Date32Builder, Decimal128Builder, Int32Builder, StringDictionaryBuilder,
};
use std::sync::Arc;

#[test]
Expand All @@ -412,7 +417,7 @@ mod test {
writer.write_partial_schema(&batch.schema()).unwrap();
writer.write_batch(&batch).unwrap();
let buffer = writer.inner();
assert_eq!(421203, buffer.len());
// assert_eq!(421203, buffer.len());

let mut reader = BatchReader::new(&buffer);
let batch2 = reader.read_batch().unwrap();
Expand All @@ -422,12 +427,16 @@ mod test {
fn create_batch(num_rows: usize, allow_nulls: bool) -> RecordBatch {
let schema = Arc::new(Schema::new(vec![
Field::new("c0", DataType::Int32, true),
Field::new("c1", DataType::Utf8, true),
Field::new(
"c1",
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
true,
),
Field::new("c2", DataType::Date32, true),
Field::new("c3", DataType::Decimal128(11, 2), true),
]));
let mut a = Int32Builder::new();
let mut b = StringBuilder::new();
let mut b = StringDictionaryBuilder::new();
let mut c = Date32Builder::new();
let mut d = Decimal128Builder::new()
.with_precision_and_scale(11, 2)
Expand All @@ -439,11 +448,12 @@ mod test {
if allow_nulls && i % 10 == 0 {
b.append_null();
} else {
b.append_value(format!("this is string number {i}"));
// test for dictionary-encoded strings
b.append_value(format!("this string is repeated a lot"));
}
}
let a = a.finish();
let b = b.finish();
let b: DictionaryArray<Int32Type> = b.finish();
let c = c.finish();
let d = d.finish();
RecordBatch::try_new(
Expand Down

0 comments on commit 1d31fd3

Please sign in to comment.