Skip to content

Commit

Permalink
support more types
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove committed Dec 29, 2024
1 parent a21ea6f commit 474a051
Showing 1 changed file with 115 additions and 43 deletions.
158 changes: 115 additions & 43 deletions native/core/src/execution/shuffle/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
use arrow_array::cast::AsArray;
use arrow_array::types::Int32Type;
use arrow_array::{
Array, ArrayRef, Date32Array, Decimal128Array, DictionaryArray, Int32Array, Int64Array,
RecordBatch, StringArray,
Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Decimal128Array, DictionaryArray,
Float64Array, Int32Array, Int64Array, RecordBatch, StringArray,
};
use arrow_buffer::{BooleanBuffer, Buffer, NullBuffer, OffsetBuffer, ScalarBuffer};
use arrow_schema::{DataType, Field, Schema};
Expand All @@ -29,22 +29,32 @@ use std::sync::Arc;

/// Returns `true` when the fast shuffle codec can encode a column of `data_type`.
///
/// Supported types are `Boolean`, `Int32`, `Int64`, `Float64`, `Date32`, `Utf8`,
/// `Binary`, `Decimal128` with a non-negative scale, and `Dictionary` with `Int32`
/// keys whose value type is itself supported. Everything else returns `false`.
///
/// This is a pure predicate: it performs no logging or other I/O, so callers
/// that want to report an unsupported type should do so themselves.
pub fn fast_codec_supports_type(data_type: &DataType) -> bool {
    match data_type {
        DataType::Boolean
        | DataType::Int32
        | DataType::Int64
        | DataType::Float64
        | DataType::Date32
        | DataType::Utf8
        | DataType::Binary => true,
        // Negative scales are rejected; the writer stores the scale as one
        // unsigned byte (`*s as u8`), which cannot represent them.
        DataType::Decimal128(_, s) if *s >= 0 => true,
        // Only Int32-keyed dictionaries are handled; support then depends on
        // the dictionary's value type.
        DataType::Dictionary(k, v) if **k == DataType::Int32 => fast_codec_supports_type(v),
        _ => false,
    }
}

/// Numeric tag identifying a column's data type in the encoded stream.
///
/// The writer emits the tag as a single byte (`DataTypeId::X as u8`) and the
/// reader compares the byte it reads against `DataTypeId::X as i32`, so the
/// discriminant values below are part of the encoded format.
enum DataTypeId {
// NOTE: ids 0 and 1 were earlier pencilled in (as commented-out variants) for
// Int8 and Int16; they are now assigned to Boolean and Float64 below.
Boolean = 0,
Float64 = 1,
Int32 = 2,
Int64 = 3,
Date32 = 4,
Decimal128 = 5,
Utf8 = 6,
Dictionary = 7,
Binary = 8,
}

pub struct BatchWriter<W: Write> {
Expand Down Expand Up @@ -72,18 +82,27 @@ impl<W: Write> BatchWriter<W> {

fn write_data_type(&mut self, data_type: &DataType) -> Result<(), DataFusionError> {
match data_type {
DataType::Boolean => {
self.inner.write_all(&[DataTypeId::Boolean as u8])?;
}
DataType::Int32 => {
self.inner.write_all(&[DataTypeId::Int32 as u8])?;
}
DataType::Int64 => {
self.inner.write_all(&[DataTypeId::Int64 as u8])?;
}
DataType::Float64 => {
self.inner.write_all(&[DataTypeId::Float64 as u8])?;
}
DataType::Date32 => {
self.inner.write_all(&[DataTypeId::Date32 as u8])?;
}
DataType::Utf8 => {
self.inner.write_all(&[DataTypeId::Utf8 as u8])?;
}
DataType::Binary => {
self.inner.write_all(&[DataTypeId::Binary as u8])?;
}
DataType::Decimal128(p, s) if *s >= 0 => {
self.inner
.write_all(&[DataTypeId::Decimal128 as u8, *p, *s as u8])?;
Expand Down Expand Up @@ -118,6 +137,16 @@ impl<W: Write> BatchWriter<W> {
self.write_data_type(col.data_type())?;
// array contents
match col.data_type() {
DataType::Boolean => {
let arr = col.as_any().downcast_ref::<BooleanArray>().unwrap();

// write data buffer
let buffer = arr.values();
let buffer = buffer.inner();
self.write_buffer(buffer)?;

self.write_null_buffer(arr.nulls())?;
}
DataType::Int32 => {
let arr = col.as_any().downcast_ref::<Int32Array>().unwrap();

Expand All @@ -126,13 +155,7 @@ impl<W: Write> BatchWriter<W> {
let buffer = buffer.inner();
self.write_buffer(buffer)?;

if let Some(nulls) = arr.nulls() {
let buffer = nulls.inner();
let buffer = buffer.inner();
self.write_buffer(buffer)?;
} else {
self.inner.write_all(&0_usize.to_le_bytes())?;
}
self.write_null_buffer(arr.nulls())?;
}
DataType::Int64 => {
let arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
Expand All @@ -142,13 +165,17 @@ impl<W: Write> BatchWriter<W> {
let buffer = buffer.inner();
self.write_buffer(buffer)?;

if let Some(nulls) = arr.nulls() {
let buffer = nulls.inner();
let buffer = buffer.inner();
self.write_buffer(buffer)?;
} else {
self.inner.write_all(&0_usize.to_le_bytes())?;
}
self.write_null_buffer(arr.nulls())?;
}
DataType::Float64 => {
let arr = col.as_any().downcast_ref::<Float64Array>().unwrap();

// write data buffer
let buffer = arr.values();
let buffer = buffer.inner();
self.write_buffer(buffer)?;

self.write_null_buffer(arr.nulls())?;
}
DataType::Date32 => {
let arr = col.as_any().downcast_ref::<Date32Array>().unwrap();
Expand All @@ -158,13 +185,7 @@ impl<W: Write> BatchWriter<W> {
let buffer = buffer.inner();
self.write_buffer(buffer)?;

if let Some(nulls) = arr.nulls() {
let buffer = nulls.inner();
let buffer = buffer.inner();
self.write_buffer(buffer)?;
} else {
self.inner.write_all(&0_usize.to_le_bytes())?;
}
self.write_null_buffer(arr.nulls())?;
}
DataType::Decimal128(_, _) => {
let arr = col.as_any().downcast_ref::<Decimal128Array>().unwrap();
Expand All @@ -174,13 +195,7 @@ impl<W: Write> BatchWriter<W> {
let buffer = buffer.inner();
self.write_buffer(buffer)?;

if let Some(nulls) = arr.nulls() {
let buffer = nulls.inner();
let buffer = buffer.inner();
self.write_buffer(buffer)?;
} else {
self.inner.write_all(&0_usize.to_le_bytes())?;
}
self.write_null_buffer(arr.nulls())?;
}
DataType::Utf8 => {
let arr = col.as_any().downcast_ref::<StringArray>().unwrap();
Expand All @@ -195,13 +210,22 @@ impl<W: Write> BatchWriter<W> {
let buffer = scalar_buffer.inner();
self.write_buffer(buffer)?;

if let Some(nulls) = arr.nulls() {
let buffer = nulls.inner();
let buffer = buffer.inner();
self.write_buffer(buffer)?;
} else {
self.inner.write_all(&0_usize.to_le_bytes())?;
}
self.write_null_buffer(arr.nulls())?;
}
DataType::Binary => {
let arr = col.as_any().downcast_ref::<BinaryArray>().unwrap();

// write data buffer
let buffer = arr.values();
self.write_buffer(buffer)?;

// write offset buffer
let offsets = arr.offsets();
let scalar_buffer = offsets.inner();
let buffer = scalar_buffer.inner();
self.write_buffer(buffer)?;

self.write_null_buffer(arr.nulls())?;
}
DataType::Dictionary(k, _) if **k == DataType::Int32 => {
let arr = col
Expand All @@ -220,6 +244,20 @@ impl<W: Write> BatchWriter<W> {
Ok(())
}

/// Writes the validity (null) information of a column to the underlying writer.
///
/// * `Some(nulls)` - the bitmap buffer backing the `NullBuffer` is emitted
///   through `write_buffer`.
/// * `None` - a `usize` zero is written as little-endian bytes instead
///   (presumably decoded by the reader as "no null buffer" — confirm against
///   the reader side).
///
/// # Errors
///
/// Propagates any error raised by `write_buffer` or by the wrapped writer.
fn write_null_buffer(
    &mut self,
    null_buffer: Option<&NullBuffer>,
) -> Result<(), DataFusionError> {
    match null_buffer {
        Some(validity) => {
            // NullBuffer -> BooleanBuffer -> raw byte Buffer
            let bitmap = validity.inner().inner();
            self.write_buffer(bitmap)?;
        }
        None => {
            let zero_len = 0_usize.to_le_bytes();
            self.inner.write_all(&zero_len)?;
        }
    }
    Ok(())
}

/// Consumes this writer and returns the wrapped `W` that the encoded bytes
/// were written to.
pub fn inner(self) -> W {
self.inner
}
Expand Down Expand Up @@ -277,6 +315,14 @@ impl<'a> BatchReader<'a> {
// read data type
let data_type = self.read_data_type()?;
Ok(match data_type {
// Boolean columns: a bit-packed values buffer followed by the null buffer.
DataType::Boolean => {
let buffer = self.read_buffer();
// TODO check length calculation
// NOTE(review): `buffer.len()` is presumably a byte count, so `length * 8`
// is always a multiple of 8 bits. If the encoded column's row count is not
// a multiple of 8, the decoded array would report extra trailing elements
// — confirm, and consider deriving the bit length from the row count.
let length = buffer.len();
let data_buffer = BooleanBuffer::new(buffer, 0, length * 8);
let null_buffer = self.read_null_buffer();
// NOTE(review): unlike the `try_new(..)?` calls in the other arms, this
// constructor returns no `Result`; verify that a values/nulls length
// mismatch cannot reach it (it is expected to panic in that case).
Arc::new(BooleanArray::new(data_buffer, null_buffer))
}
DataType::Int32 => {
let buffer = self.read_buffer();
let data_buffer = ScalarBuffer::<i32>::from(buffer);
Expand All @@ -289,6 +335,12 @@ impl<'a> BatchReader<'a> {
let null_buffer = self.read_null_buffer();
Arc::new(Int64Array::try_new(data_buffer, null_buffer)?)
}
DataType::Float64 => {
let buffer = self.read_buffer();
let data_buffer = ScalarBuffer::<f64>::from(buffer);
let null_buffer = self.read_null_buffer();
Arc::new(Float64Array::try_new(data_buffer, null_buffer)?)
}
DataType::Date32 => {
let buffer = self.read_buffer();
let data_buffer = ScalarBuffer::<i32>::from(buffer);
Expand All @@ -310,6 +362,12 @@ impl<'a> BatchReader<'a> {
let null_buffer = self.read_null_buffer();
Arc::new(StringArray::try_new(offset_buffer, buffer, null_buffer)?)
}
// Binary columns: values buffer, then i32 offsets, then the null buffer —
// the same order the writer emits them in.
DataType::Binary => {
    let buffer = self.read_buffer();
    let offset_buffer = self.read_offset_buffer();
    let null_buffer = self.read_null_buffer();
    // Must be a `BinaryArray`, not a `StringArray`: the payload is arbitrary
    // bytes (a string array would reject non-UTF-8 data) and the decoded
    // column has to report `DataType::Binary`, the type that was written.
    Arc::new(BinaryArray::try_new(offset_buffer, buffer, null_buffer)?)
}
DataType::Dictionary(_, _) => {
let k = self.read_array()?;
let v = self.read_array()?;
Expand All @@ -329,10 +387,13 @@ impl<'a> BatchReader<'a> {
fn read_data_type(&mut self) -> Result<DataType, DataFusionError> {
let type_id = self.input[self.offset] as i32;
let data_type = match type_id {
x if x == DataTypeId::Boolean as i32 => DataType::Boolean,
x if x == DataTypeId::Int32 as i32 => DataType::Int32,
x if x == DataTypeId::Int64 as i32 => DataType::Int64,
x if x == DataTypeId::Float64 as i32 => DataType::Float64,
x if x == DataTypeId::Date32 as i32 => DataType::Date32,
x if x == DataTypeId::Utf8 as i32 => DataType::Utf8,
x if x == DataTypeId::Binary as i32 => DataType::Binary,
x if x == DataTypeId::Dictionary as i32 => {
self.offset += 1;
DataType::Dictionary(
Expand Down Expand Up @@ -405,7 +466,7 @@ impl<'a> BatchReader<'a> {
mod test {
use super::*;
use arrow_array::builder::{
Date32Builder, Decimal128Builder, Int32Builder, StringDictionaryBuilder,
BooleanBuilder, Date32Builder, Decimal128Builder, Int32Builder, StringDictionaryBuilder,
};
use std::sync::Arc;

Expand Down Expand Up @@ -434,31 +495,42 @@ mod test {
),
Field::new("c2", DataType::Date32, true),
Field::new("c3", DataType::Decimal128(11, 2), true),
Field::new("c4", DataType::Boolean, true),
]));
let mut a = Int32Builder::new();
let mut b = StringDictionaryBuilder::new();
let mut c = Date32Builder::new();
let mut d = Decimal128Builder::new()
.with_precision_and_scale(11, 2)
.unwrap();
let mut c4_bool = BooleanBuilder::with_capacity(num_rows);
for i in 0..num_rows {
a.append_value(i as i32);
c.append_value(i as i32);
d.append_value((i * 1000000) as i128);
if allow_nulls && i % 10 == 0 {
b.append_null();
c4_bool.append_null();
} else {
// test for dictionary-encoded strings
b.append_value(format!("this string is repeated a lot"));
c4_bool.append_value(i % 2 == 0)
}
}
let a = a.finish();
let b: DictionaryArray<Int32Type> = b.finish();
let c = c.finish();
let d = d.finish();
let c4_bool = c4_bool.finish();
RecordBatch::try_new(
schema.clone(),
vec![Arc::new(a), Arc::new(b), Arc::new(c), Arc::new(d)],
vec![
Arc::new(a),
Arc::new(b),
Arc::new(c),
Arc::new(d),
Arc::new(c4_bool),
],
)
.unwrap()
}
Expand Down

0 comments on commit 474a051

Please sign in to comment.