diff --git a/native/core/src/execution/shuffle/codec.rs b/native/core/src/execution/shuffle/codec.rs index e2a212999..68c12dd24 100644 --- a/native/core/src/execution/shuffle/codec.rs +++ b/native/core/src/execution/shuffle/codec.rs @@ -15,8 +15,11 @@ // specific language governing permissions and limitations // under the License. +use arrow_array::cast::AsArray; +use arrow_array::types::Int32Type; use arrow_array::{ - Array, ArrayRef, Date32Array, Decimal128Array, Int32Array, Int64Array, RecordBatch, StringArray, + Array, ArrayRef, Date32Array, Decimal128Array, DictionaryArray, Int32Array, Int64Array, + RecordBatch, StringArray, }; use arrow_buffer::{BooleanBuffer, Buffer, NullBuffer, OffsetBuffer, ScalarBuffer}; use arrow_schema::{DataType, Field, Schema}; @@ -28,6 +31,9 @@ pub fn fast_codec_supports_type(data_type: &DataType) -> bool { match data_type { DataType::Int32 | DataType::Int64 | DataType::Date32 | DataType::Utf8 => true, DataType::Decimal128(_, s) if *s >= 0 => true, + DataType::Dictionary(k, v) if **k == DataType::Int32 => { + fast_codec_supports_type(k) && fast_codec_supports_type(v) + } _ => false, } } @@ -40,6 +46,7 @@ enum DataTypeId { Date32 = 4, Decimal128 = 5, Utf8 = 6, + Dictionary = 7, } pub struct BatchWriter { @@ -61,135 +68,156 @@ impl BatchWriter { self.inner.write_all(&field_name.len().to_le_bytes())?; self.inner.write_all(field_name.as_str().as_bytes())?; // data type - match field.data_type() { - DataType::Int32 => { - self.inner.write_all(&[DataTypeId::Int32 as u8])?; - } - DataType::Int64 => { - self.inner.write_all(&[DataTypeId::Int64 as u8])?; - } - DataType::Date32 => { - self.inner.write_all(&[DataTypeId::Date32 as u8])?; - } - DataType::Utf8 => { - self.inner.write_all(&[DataTypeId::Utf8 as u8])?; - } - DataType::Decimal128(p, s) if *s >= 0 => { - self.inner - .write_all(&[DataTypeId::Decimal128 as u8, *p, *s as u8])?; - } - other => { - return Err(DataFusionError::Internal(format!( - "unsupported type {other}" - ))) - } - } + self.write_data_type(field.data_type())?; // TODO nullable - assume all nullable for now } Ok(()) } + fn write_data_type(&mut self, data_type: &DataType) -> Result<(), DataFusionError> { + match data_type { + DataType::Int32 => { + self.inner.write_all(&[DataTypeId::Int32 as u8])?; + } + DataType::Int64 => { + self.inner.write_all(&[DataTypeId::Int64 as u8])?; + } + DataType::Date32 => { + self.inner.write_all(&[DataTypeId::Date32 as u8])?; + } + DataType::Utf8 => { + self.inner.write_all(&[DataTypeId::Utf8 as u8])?; + } + DataType::Decimal128(p, s) if *s >= 0 => { + self.inner + .write_all(&[DataTypeId::Decimal128 as u8, *p, *s as u8])?; + } + DataType::Dictionary(k, v) => { + self.inner.write_all(&[DataTypeId::Dictionary as u8])?; + self.write_data_type(&k)?; + self.write_data_type(&v)?; + } + other => { + return Err(DataFusionError::Internal(format!( + "unsupported type {other}" + ))) + } + } + Ok(()) + } + pub fn write_all(&mut self, bytes: &[u8]) -> std::io::Result<()> { self.inner.write_all(bytes) } pub fn write_batch(&mut self, batch: &RecordBatch) -> Result<(), DataFusionError> { for i in 0..batch.num_columns() { - let col = batch.column(i); - match col.data_type() { - DataType::Int32 => { - let arr = col.as_any().downcast_ref::().unwrap(); + self.write_array(batch.column(i))?; + } + Ok(()) + } + + fn write_array(&mut self, col: &dyn Array) -> Result<(), DataFusionError> { + match col.data_type() { + DataType::Int32 => { + let arr = col.as_any().downcast_ref::().unwrap(); - // write data buffer - let buffer = arr.values(); + // write data buffer + let buffer = arr.values(); + let buffer = buffer.inner(); + self.write_buffer(buffer)?; + + if let Some(nulls) = arr.nulls() { + let buffer = nulls.inner(); let buffer = buffer.inner(); self.write_buffer(buffer)?; - - if let Some(nulls) = arr.nulls() { - let buffer = nulls.inner(); - let buffer = buffer.inner(); - self.write_buffer(buffer)?; - } else { - self.inner.write_all(&0_usize.to_le_bytes())?; - } + } else { + self.inner.write_all(&0_usize.to_le_bytes())?; } - DataType::Int64 => { - let arr = col.as_any().downcast_ref::().unwrap(); + } + DataType::Int64 => { + let arr = col.as_any().downcast_ref::().unwrap(); + + // write data buffer + let buffer = arr.values(); + let buffer = buffer.inner(); + self.write_buffer(buffer)?; - // write data buffer - let buffer = arr.values(); + if let Some(nulls) = arr.nulls() { + let buffer = nulls.inner(); let buffer = buffer.inner(); self.write_buffer(buffer)?; - - if let Some(nulls) = arr.nulls() { - let buffer = nulls.inner(); - let buffer = buffer.inner(); - self.write_buffer(buffer)?; - } else { - self.inner.write_all(&0_usize.to_le_bytes())?; - } + } else { + self.inner.write_all(&0_usize.to_le_bytes())?; } - DataType::Date32 => { - let arr = col.as_any().downcast_ref::().unwrap(); + } + DataType::Date32 => { + let arr = col.as_any().downcast_ref::().unwrap(); - // write data buffer - let buffer = arr.values(); + // write data buffer + let buffer = arr.values(); + let buffer = buffer.inner(); + self.write_buffer(buffer)?; + + if let Some(nulls) = arr.nulls() { + let buffer = nulls.inner(); let buffer = buffer.inner(); self.write_buffer(buffer)?; - - if let Some(nulls) = arr.nulls() { - let buffer = nulls.inner(); - let buffer = buffer.inner(); - self.write_buffer(buffer)?; - } else { - self.inner.write_all(&0_usize.to_le_bytes())?; - } + } else { + self.inner.write_all(&0_usize.to_le_bytes())?; } - DataType::Decimal128(_, _) => { - let arr = col.as_any().downcast_ref::().unwrap(); + } + DataType::Decimal128(_, _) => { + let arr = col.as_any().downcast_ref::().unwrap(); + + // write data buffer + let buffer = arr.values(); + let buffer = buffer.inner(); + self.write_buffer(buffer)?; - // write data buffer - let buffer = arr.values(); + if let Some(nulls) = arr.nulls() { + let buffer = nulls.inner(); let buffer = buffer.inner(); self.write_buffer(buffer)?; - - if let Some(nulls) = arr.nulls() { - let buffer = nulls.inner(); - let buffer = buffer.inner(); - self.write_buffer(buffer)?; - } else { - self.inner.write_all(&0_usize.to_le_bytes())?; - } + } else { + self.inner.write_all(&0_usize.to_le_bytes())?; } - DataType::Utf8 => { - let arr = col.as_any().downcast_ref::().unwrap(); + } + DataType::Utf8 => { + let arr = col.as_any().downcast_ref::().unwrap(); - // write data buffer - let buffer = arr.values(); - self.write_buffer(buffer)?; + // write data buffer + let buffer = arr.values(); + self.write_buffer(buffer)?; - // write offset buffer - let offsets = arr.offsets(); - let scalar_buffer = offsets.inner(); - let buffer = scalar_buffer.inner(); - self.write_buffer(buffer)?; + // write offset buffer + let offsets = arr.offsets(); + let scalar_buffer = offsets.inner(); + let buffer = scalar_buffer.inner(); + self.write_buffer(buffer)?; - if let Some(nulls) = arr.nulls() { - let buffer = nulls.inner(); - let buffer = buffer.inner(); - self.write_buffer(buffer)?; - } else { - self.inner.write_all(&0_usize.to_le_bytes())?; - } - } - other => { - return Err(DataFusionError::Internal(format!( - "unsupported type {other}" - ))) + if let Some(nulls) = arr.nulls() { + let buffer = nulls.inner(); + let buffer = buffer.inner(); + self.write_buffer(buffer)?; + } else { + self.inner.write_all(&0_usize.to_le_bytes())?; } } + DataType::Dictionary(k, _) if **k == DataType::Int32 => { + let arr = col + .as_any() + .downcast_ref::>() + .unwrap(); + self.write_array(arr.keys())?; + self.write_array(arr.values())?; + } + other => { + return Err(DataFusionError::Internal(format!( + "unsupported type {other}" + ))) + } } - Ok(()) } @@ -235,26 +263,7 @@ impl<'a> BatchReader<'a> { let field_name = unsafe { String::from_utf8_unchecked(field_name_bytes.into()) }; // read data type - let type_id = self.input[self.offset] as i32; - let data_type = match type_id { - x if x == DataTypeId::Int32 as i32 => DataType::Int32, - x if x == DataTypeId::Int64 as i32 => DataType::Int64, - x if x == DataTypeId::Date32 as i32 => DataType::Date32, - x if x == DataTypeId::Utf8 as i32 => DataType::Utf8, - x if x == DataTypeId::Decimal128 as i32 => DataType::Decimal128( - self.input[self.offset + 1], - self.input[self.offset + 2] as i8, - ), - other => { - return Err(DataFusionError::Internal(format!( - "unsupported type {other}" - ))) - } - }; - self.offset += 1; - if matches!(data_type, DataType::Decimal128(_, _)) { - self.offset += 2; - } + let data_type = self.read_data_type()?; // create field let field = Arc::new(Field::new(field_name, data_type, true)); @@ -264,48 +273,92 @@ impl<'a> BatchReader<'a> { let mut arrays = Vec::with_capacity(schema.fields().len()); for i in 0..schema.fields().len() { - let buffer = self.read_buffer(); - let array: ArrayRef = match schema.field(i).data_type() { - DataType::Int32 => { - let data_buffer = ScalarBuffer::::from(buffer); - let null_buffer = self.read_null_buffer(); - Arc::new(Int32Array::try_new(data_buffer, null_buffer)?) - } - DataType::Int64 => { - let data_buffer = ScalarBuffer::::from(buffer); - let null_buffer = self.read_null_buffer(); - Arc::new(Int64Array::try_new(data_buffer, null_buffer)?) - } - DataType::Date32 => { - let data_buffer = ScalarBuffer::::from(buffer); - let null_buffer = self.read_null_buffer(); - Arc::new(Date32Array::try_new(data_buffer, null_buffer)?) - } - DataType::Decimal128(p, s) => { - let data_buffer = ScalarBuffer::::from(buffer); - let null_buffer = self.read_null_buffer(); - Arc::new( - Decimal128Array::try_new(data_buffer, null_buffer)? - .with_precision_and_scale(*p, *s)?, - ) - } - DataType::Utf8 => { - let offset_buffer = self.read_offset_buffer(); - let null_buffer = self.read_null_buffer(); - Arc::new(StringArray::try_new(offset_buffer, buffer, null_buffer)?) - } - other => { - return Err(DataFusionError::Internal(format!( - "unsupported type {other}" - ))) - } - }; + let array = self.read_array(schema.field(i).data_type())?; arrays.push(array); } Ok(RecordBatch::try_new(schema, arrays).unwrap()) } + fn read_array(&mut self, data_type: &DataType) -> Result { + Ok(match data_type { + DataType::Int32 => { + let buffer = self.read_buffer(); + let data_buffer = ScalarBuffer::::from(buffer); + let null_buffer = self.read_null_buffer(); + Arc::new(Int32Array::try_new(data_buffer, null_buffer)?) + } + DataType::Int64 => { + let buffer = self.read_buffer(); + let data_buffer = ScalarBuffer::::from(buffer); + let null_buffer = self.read_null_buffer(); + Arc::new(Int64Array::try_new(data_buffer, null_buffer)?) + } + DataType::Date32 => { + let buffer = self.read_buffer(); + let data_buffer = ScalarBuffer::::from(buffer); + let null_buffer = self.read_null_buffer(); + Arc::new(Date32Array::try_new(data_buffer, null_buffer)?) + } + DataType::Decimal128(p, s) => { + let buffer = self.read_buffer(); + let data_buffer = ScalarBuffer::::from(buffer); + let null_buffer = self.read_null_buffer(); + Arc::new( + Decimal128Array::try_new(data_buffer, null_buffer)? + .with_precision_and_scale(*p, *s)?, + ) + } + DataType::Utf8 => { + let buffer = self.read_buffer(); + let offset_buffer = self.read_offset_buffer(); + let null_buffer = self.read_null_buffer(); + Arc::new(StringArray::try_new(offset_buffer, buffer, null_buffer)?) + } + DataType::Dictionary(k, v) => { + let k = self.read_array(&k)?; + let v = self.read_array(&v)?; + Arc::new(DictionaryArray::try_new( + k.as_primitive::().to_owned(), + v, + )?) + } + other => { + return Err(DataFusionError::Internal(format!( + "unsupported type {other}" + ))) + } + }) + } + + fn read_data_type(&mut self) -> Result { + let type_id = self.input[self.offset] as i32; + let data_type = match type_id { + x if x == DataTypeId::Int32 as i32 => DataType::Int32, + x if x == DataTypeId::Int64 as i32 => DataType::Int64, + x if x == DataTypeId::Date32 as i32 => DataType::Date32, + x if x == DataTypeId::Utf8 as i32 => DataType::Utf8, + x if x == DataTypeId::Dictionary as i32 => DataType::Dictionary( + Box::new(self.read_data_type()?), + Box::new(self.read_data_type()?), + ), + x if x == DataTypeId::Decimal128 as i32 => DataType::Decimal128( + self.input[self.offset + 1], + self.input[self.offset + 2] as i8, + ), + other => { + return Err(DataFusionError::Internal(format!( + "unsupported type {other}" + ))) + } + }; + self.offset += 1; + if matches!(data_type, DataType::Decimal128(_, _)) { + self.offset += 2; + } + Ok(data_type) + } + fn read_offset_buffer(&mut self) -> OffsetBuffer { let offset_buffer = self.read_buffer(); let offset_buffer: ScalarBuffer = ScalarBuffer::from(offset_buffer);