add timestamp support
andygrove committed Dec 30, 2024
1 parent e16d24a commit 34dd489
Showing 1 changed file with 70 additions and 15 deletions.
native/core/src/execution/shuffle/codec.rs
@@ -20,10 +20,10 @@ use arrow_array::types::Int32Type;
 use arrow_array::{
     Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Decimal128Array, DictionaryArray,
     Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, RecordBatch,
-    RecordBatchOptions, StringArray,
+    RecordBatchOptions, StringArray, TimestampMicrosecondArray,
 };
 use arrow_buffer::{BooleanBuffer, Buffer, NullBuffer, OffsetBuffer, ScalarBuffer};
-use arrow_schema::{DataType, Field, Schema};
+use arrow_schema::{DataType, Field, Schema, TimeUnit};
 use datafusion_common::DataFusionError;
 use std::io::Write;
 use std::sync::Arc;
@@ -38,12 +38,14 @@ pub fn fast_codec_supports_type(data_type: &DataType) -> bool {
         | DataType::Float32
         | DataType::Float64
         | DataType::Date32
+        | DataType::Timestamp(TimeUnit::Microsecond, _)
         | DataType::Utf8
         | DataType::Binary => true,
         DataType::Decimal128(_, s) if *s >= 0 => true,
         DataType::Dictionary(k, v) if **k == DataType::Int32 => fast_codec_supports_type(v),
         _ => {
-            println!("UNSUPPORTED: {data_type:?}");
+            // TODO remove this temp debug logging before merging
+            println!("Native shuffle fast codec does not support data type: {data_type:?}");
             false
         }
     }
@@ -56,7 +58,8 @@ enum DataTypeId {
     Int32,
     Int64,
     Date32,
-    // TODO Timestamp(Microsecond) with and without timezone
+    Timestamp,
+    TimestampNtz,
     Decimal128,
     Float32,
     Float64,
@@ -114,6 +117,15 @@ impl<W: Write> BatchWriter<W> {
             DataType::Date32 => {
                 self.inner.write_all(&[DataTypeId::Date32 as u8])?;
             }
+            DataType::Timestamp(TimeUnit::Microsecond, None) => {
+                self.inner.write_all(&[DataTypeId::TimestampNtz as u8])?;
+            }
+            DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => {
+                self.inner.write_all(&[DataTypeId::Timestamp as u8])?;
+                let tz_bytes = tz.as_bytes();
+                self.inner.write_all(&tz_bytes.len().to_le_bytes())?;
+                self.inner.write_all(tz_bytes)?;
+            }
             DataType::Utf8 => {
                 self.inner.write_all(&[DataTypeId::Utf8 as u8])?;
             }
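
For reference, the hunk above implies this on-the-wire layout for a timezone-aware timestamp's schema entry: a one-byte type id, an 8-byte little-endian length, then the UTF-8 timezone string. A minimal standalone sketch follows; the helper name and hardcoded discriminant are illustrative, not part of the commit:

    // Illustrative mirror of the write path above. The discriminant value
    // here is assumed; the real one comes from the DataTypeId enum.
    const TIMESTAMP_TZ_ID: u8 = 7;

    fn encode_timestamp_tz_entry(tz: &str) -> Vec<u8> {
        let mut out = vec![TIMESTAMP_TZ_ID]; // 1-byte type id
        out.extend_from_slice(&tz.len().to_le_bytes()); // 8-byte LE length
        out.extend_from_slice(tz.as_bytes()); // UTF-8 timezone, e.g. "UTC"
        out
    }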
@@ -239,6 +251,19 @@ impl<W: Write> BatchWriter<W> {

                 self.write_null_buffer(arr.nulls())?;
             }
+            DataType::Timestamp(TimeUnit::Microsecond, _) => {
+                let arr = col
+                    .as_any()
+                    .downcast_ref::<TimestampMicrosecondArray>()
+                    .unwrap();
+
+                // write data buffer
+                let buffer = arr.values();
+                let buffer = buffer.inner();
+                self.write_buffer(buffer)?;
+
+                self.write_null_buffer(arr.nulls())?;
+            }
             DataType::Decimal128(_, _) => {
                 let arr = col.as_any().downcast_ref::<Decimal128Array>().unwrap();
@@ -330,11 +355,17 @@ impl<W: Write> BatchWriter<W> {
 pub struct BatchReader<'a> {
     input: &'a [u8],
     offset: usize,
+    /// buffer for reading usize
+    length: [u8; 8],
 }

 impl<'a> BatchReader<'a> {
     pub fn new(input: &'a [u8]) -> Self {
-        Self { input, offset: 0 }
+        Self {
+            input,
+            offset: 0,
+            length: [0; 8],
+        }
     }

     pub fn read_batch(&mut self) -> Result<RecordBatch, DataFusionError> {
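
The new length field is reusable scratch space for the 8-byte little-endian length prefixes used throughout the format. A minimal sketch of the decode pattern it supports (illustrative free function, not part of the commit):

    // Decode an 8-byte little-endian length prefix starting at `offset`.
    fn read_len(input: &[u8], offset: usize) -> usize {
        let mut len = [0u8; 8];
        len.copy_from_slice(&input[offset..offset + 8]);
        usize::from_le_bytes(len)
    }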
@@ -345,15 +376,7 @@ impl<'a> BatchReader<'a> {

         let mut field_names: Vec<String> = Vec::with_capacity(schema_len);
         for _ in 0..schema_len {
-            // read field name length
-            length.copy_from_slice(&self.input[self.offset..self.offset + 8]);
-            let field_name_len = usize::from_le_bytes(length);
-            self.offset += 8;
-
-            // read field name
-            let field_name_bytes = &self.input[self.offset..self.offset + field_name_len];
-            field_names.push(unsafe { String::from_utf8_unchecked(field_name_bytes.into()) });
-            self.offset += field_name_len;
+            field_names.push(self.read_string());
         }

         length.copy_from_slice(&self.input[self.offset..self.offset + 8]);
@@ -434,6 +457,15 @@ impl<'a> BatchReader<'a> {
                     let null_buffer = self.read_null_buffer();
                     Arc::new(Date32Array::try_new(data_buffer, null_buffer)?)
                 }
+                DataType::Timestamp(TimeUnit::Microsecond, _) => {
+                    let buffer = self.read_buffer();
+                    let data_buffer = ScalarBuffer::<i64>::from(buffer);
+                    let null_buffer = self.read_null_buffer();
+                    Arc::new(TimestampMicrosecondArray::try_new(
+                        data_buffer,
+                        null_buffer,
+                    )?)
+                }
                 DataType::Decimal128(p, s) => {
                     let buffer = self.read_buffer();
                     let data_buffer = ScalarBuffer::<i128>::from(buffer);
@@ -482,6 +514,15 @@ impl<'a> BatchReader<'a> {
             x if x == DataTypeId::Float32 as i32 => DataType::Float32,
             x if x == DataTypeId::Float64 as i32 => DataType::Float64,
             x if x == DataTypeId::Date32 as i32 => DataType::Date32,
+            x if x == DataTypeId::TimestampNtz as i32 => {
+                DataType::Timestamp(TimeUnit::Microsecond, None)
+            }
+            x if x == DataTypeId::Timestamp as i32 => {
+                self.offset += 1;
+                let tz = self.read_string();
+                let tz: Arc<str> = Arc::from(tz.into_boxed_str());
+                DataType::Timestamp(TimeUnit::Microsecond, Some(tz))
+            }
             x if x == DataTypeId::Utf8 as i32 => DataType::Utf8,
             x if x == DataTypeId::Binary as i32 => DataType::Binary,
             x if x == DataTypeId::Dictionary as i32 => {
@@ -502,7 +543,7 @@ impl<'a> BatchReader<'a> {
             }
         };
         match data_type {
-            DataType::Dictionary(_, _) => {
+            DataType::Dictionary(_, _) | DataType::Timestamp(_, _) => {
                 // no need to increment
             }
             DataType::Decimal128(_, _) => self.offset += 3,
Expand All @@ -511,6 +552,20 @@ impl<'a> BatchReader<'a> {
Ok(data_type)
}

fn read_string(&mut self) -> String {
// read field name length
self.length
.copy_from_slice(&self.input[self.offset..self.offset + 8]);
let field_name_len = usize::from_le_bytes(self.length);
self.offset += 8;

// read field name
let field_name_bytes = &self.input[self.offset..self.offset + field_name_len];
let str = unsafe { String::from_utf8_unchecked(field_name_bytes.into()) };
self.offset += field_name_len;
str
}

fn read_offset_buffer(&mut self) -> OffsetBuffer<i32> {
let offset_buffer = self.read_buffer();
let offset_buffer: ScalarBuffer<i32> = ScalarBuffer::from(offset_buffer);
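
Taken together, the reader and writer now round-trip microsecond timestamps. A quick sketch of what the updated support check accepts (assumed assertions, not part of the commit):

    use arrow_schema::{DataType, TimeUnit};

    // Microsecond precision is supported with or without a timezone...
    assert!(fast_codec_supports_type(&DataType::Timestamp(TimeUnit::Microsecond, None)));
    assert!(fast_codec_supports_type(&DataType::Timestamp(
        TimeUnit::Microsecond,
        Some("America/Denver".into()),
    )));
    // ...while other time units still hit the unsupported fallback.
    assert!(!fast_codec_supports_type(&DataType::Timestamp(TimeUnit::Nanosecond, None)));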
