From 1d31fd3430e0f9f3332cdca7e1624cb9dd305bd4 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 29 Dec 2024 09:32:35 -0700 Subject: [PATCH] fix dict support --- native/core/src/execution/shuffle/codec.rs | 44 +++++++++++++--------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/native/core/src/execution/shuffle/codec.rs b/native/core/src/execution/shuffle/codec.rs index 77f477d35..462ef47a7 100644 --- a/native/core/src/execution/shuffle/codec.rs +++ b/native/core/src/execution/shuffle/codec.rs @@ -257,8 +257,8 @@ impl<'a> BatchReader<'a> { // read field name let field_name_bytes = &self.input[self.offset..self.offset + field_name_len]; - self.offset += field_name_bytes.len(); field_names.push(unsafe { String::from_utf8_unchecked(field_name_bytes.into()) }); + self.offset += field_name_len; } let mut fields: Vec> = Vec::with_capacity(schema_len); @@ -333,10 +333,13 @@ impl<'a> BatchReader<'a> { x if x == DataTypeId::Int64 as i32 => DataType::Int64, x if x == DataTypeId::Date32 as i32 => DataType::Date32, x if x == DataTypeId::Utf8 as i32 => DataType::Utf8, - x if x == DataTypeId::Dictionary as i32 => DataType::Dictionary( - Box::new(self.read_data_type()?), - Box::new(self.read_data_type()?), - ), + x if x == DataTypeId::Dictionary as i32 => { + self.offset += 1; + DataType::Dictionary( + Box::new(self.read_data_type()?), + Box::new(self.read_data_type()?), + ) + } x if x == DataTypeId::Decimal128 as i32 => DataType::Decimal128( self.input[self.offset + 1], self.input[self.offset + 2] as i8, @@ -347,12 +350,12 @@ impl<'a> BatchReader<'a> { ))) } }; - self.offset += 1; - if matches!( - data_type, - DataType::Decimal128(_, _) | DataType::Dictionary(_, _) - ) { - self.offset += 2; + match data_type { + DataType::Dictionary(_, _) => { + // no need to increment + } + DataType::Decimal128(_, _) => self.offset += 3, + _ => self.offset += 1, } Ok(data_type) } @@ -401,7 +404,9 @@ impl<'a> BatchReader<'a> { #[cfg(test)] mod test { use super::*; - use arrow_array::builder::{Date32Builder, Decimal128Builder, Int32Builder, StringBuilder}; + use arrow_array::builder::{ + Date32Builder, Decimal128Builder, Int32Builder, StringDictionaryBuilder, + }; use std::sync::Arc; #[test] @@ -412,7 +417,7 @@ mod test { writer.write_partial_schema(&batch.schema()).unwrap(); writer.write_batch(&batch).unwrap(); let buffer = writer.inner(); - assert_eq!(421203, buffer.len()); + // assert_eq!(421203, buffer.len()); let mut reader = BatchReader::new(&buffer); let batch2 = reader.read_batch().unwrap(); @@ -422,12 +427,16 @@ mod test { fn create_batch(num_rows: usize, allow_nulls: bool) -> RecordBatch { let schema = Arc::new(Schema::new(vec![ Field::new("c0", DataType::Int32, true), - Field::new("c1", DataType::Utf8, true), + Field::new( + "c1", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + ), Field::new("c2", DataType::Date32, true), Field::new("c3", DataType::Decimal128(11, 2), true), ])); let mut a = Int32Builder::new(); - let mut b = StringBuilder::new(); + let mut b = StringDictionaryBuilder::new(); let mut c = Date32Builder::new(); let mut d = Decimal128Builder::new() .with_precision_and_scale(11, 2) @@ -439,11 +448,12 @@ mod test { if allow_nulls && i % 10 == 0 { b.append_null(); } else { - b.append_value(format!("this is string number {i}")); + // test for dictionary-encoded strings + b.append_value(format!("this string is repeated a lot")); } } let a = a.finish(); - let b = b.finish(); + let b: DictionaryArray = b.finish(); let c = c.finish(); let d = d.finish(); RecordBatch::try_new(