Skip to content

Commit

Permalink
support more data types (Int8, Int16, Float32)
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove committed Dec 30, 2024
1 parent db055e1 commit 74b6f0e
Showing 1 changed file with 77 additions and 9 deletions.
86 changes: 77 additions & 9 deletions native/core/src/execution/shuffle/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ use arrow_array::cast::AsArray;
use arrow_array::types::Int32Type;
use arrow_array::{
Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Decimal128Array, DictionaryArray,
Float64Array, Int32Array, Int64Array, RecordBatch, RecordBatchOptions, StringArray,
Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, RecordBatch,
RecordBatchOptions, StringArray,
};
use arrow_buffer::{BooleanBuffer, Buffer, NullBuffer, OffsetBuffer, ScalarBuffer};
use arrow_schema::{DataType, Field, Schema};
Expand All @@ -30,8 +31,11 @@ use std::sync::Arc;
pub fn fast_codec_supports_type(data_type: &DataType) -> bool {
match data_type {
DataType::Boolean
| DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Float32
| DataType::Float64
| DataType::Date32
| DataType::Utf8
Expand All @@ -47,14 +51,18 @@ pub fn fast_codec_supports_type(data_type: &DataType) -> bool {

enum DataTypeId {
Boolean = 0,
Float64 = 1,
Int32 = 2,
Int64 = 3,
Date32 = 4,
Decimal128 = 5,
Utf8 = 6,
Dictionary = 7,
Binary = 8,
Int8,
Int16,
Int32,
Int64,
Date32,
// TODO Timestamp(Microsecond) with and without timezone
Decimal128,
Float32,
Float64,
Utf8,
Binary,
Dictionary,
}

pub struct BatchWriter<W: Write> {
Expand Down Expand Up @@ -85,12 +93,21 @@ impl<W: Write> BatchWriter<W> {
DataType::Boolean => {
self.inner.write_all(&[DataTypeId::Boolean as u8])?;
}
DataType::Int8 => {
self.inner.write_all(&[DataTypeId::Int8 as u8])?;
}
DataType::Int16 => {
self.inner.write_all(&[DataTypeId::Int16 as u8])?;
}
DataType::Int32 => {
self.inner.write_all(&[DataTypeId::Int32 as u8])?;
}
DataType::Int64 => {
self.inner.write_all(&[DataTypeId::Int64 as u8])?;
}
DataType::Float32 => {
self.inner.write_all(&[DataTypeId::Float32 as u8])?;
}
DataType::Float64 => {
self.inner.write_all(&[DataTypeId::Float64 as u8])?;
}
Expand Down Expand Up @@ -152,6 +169,26 @@ impl<W: Write> BatchWriter<W> {

self.write_null_buffer(arr.nulls())?;
}
DataType::Int8 => {
let arr = col.as_any().downcast_ref::<Int8Array>().unwrap();

// write data buffer
let buffer = arr.values();
let buffer = buffer.inner();
self.write_buffer(buffer)?;

self.write_null_buffer(arr.nulls())?;
}
DataType::Int16 => {
let arr = col.as_any().downcast_ref::<Int16Array>().unwrap();

// write data buffer
let buffer = arr.values();
let buffer = buffer.inner();
self.write_buffer(buffer)?;

self.write_null_buffer(arr.nulls())?;
}
DataType::Int32 => {
let arr = col.as_any().downcast_ref::<Int32Array>().unwrap();

Expand All @@ -172,6 +209,16 @@ impl<W: Write> BatchWriter<W> {

self.write_null_buffer(arr.nulls())?;
}
DataType::Float32 => {
let arr = col.as_any().downcast_ref::<Float32Array>().unwrap();

// write data buffer
let buffer = arr.values();
let buffer = buffer.inner();
self.write_buffer(buffer)?;

self.write_null_buffer(arr.nulls())?;
}
DataType::Float64 => {
let arr = col.as_any().downcast_ref::<Float64Array>().unwrap();

Expand Down Expand Up @@ -345,6 +392,18 @@ impl<'a> BatchReader<'a> {
let null_buffer = self.read_null_buffer();
Arc::new(BooleanArray::new(data_buffer, null_buffer))
}
DataType::Int8 => {
let buffer = self.read_buffer();
let data_buffer = ScalarBuffer::<i8>::from(buffer);
let null_buffer = self.read_null_buffer();
Arc::new(Int8Array::try_new(data_buffer, null_buffer)?)
}
DataType::Int16 => {
let buffer = self.read_buffer();
let data_buffer = ScalarBuffer::<i16>::from(buffer);
let null_buffer = self.read_null_buffer();
Arc::new(Int16Array::try_new(data_buffer, null_buffer)?)
}
DataType::Int32 => {
let buffer = self.read_buffer();
let data_buffer = ScalarBuffer::<i32>::from(buffer);
Expand All @@ -357,6 +416,12 @@ impl<'a> BatchReader<'a> {
let null_buffer = self.read_null_buffer();
Arc::new(Int64Array::try_new(data_buffer, null_buffer)?)
}
DataType::Float32 => {
let buffer = self.read_buffer();
let data_buffer = ScalarBuffer::<f32>::from(buffer);
let null_buffer = self.read_null_buffer();
Arc::new(Float32Array::try_new(data_buffer, null_buffer)?)
}
DataType::Float64 => {
let buffer = self.read_buffer();
let data_buffer = ScalarBuffer::<f64>::from(buffer);
Expand Down Expand Up @@ -410,8 +475,11 @@ impl<'a> BatchReader<'a> {
let type_id = self.input[self.offset] as i32;
let data_type = match type_id {
x if x == DataTypeId::Boolean as i32 => DataType::Boolean,
x if x == DataTypeId::Int8 as i32 => DataType::Int8,
x if x == DataTypeId::Int16 as i32 => DataType::Int16,
x if x == DataTypeId::Int32 as i32 => DataType::Int32,
x if x == DataTypeId::Int64 as i32 => DataType::Int64,
x if x == DataTypeId::Float32 as i32 => DataType::Float32,
x if x == DataTypeId::Float64 as i32 => DataType::Float64,
x if x == DataTypeId::Date32 as i32 => DataType::Date32,
x if x == DataTypeId::Utf8 as i32 => DataType::Utf8,
Expand Down

0 comments on commit 74b6f0e

Please sign in to comment.