Skip to content

Commit

Permalink
precompute batch header
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove committed Dec 29, 2024
1 parent 1808231 commit f0793f0
Showing 1 changed file with 112 additions and 76 deletions.
188 changes: 112 additions & 76 deletions native/core/src/execution/shuffle/shuffle_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1570,9 +1570,10 @@ pub enum CompressionCodec {
}

/// State that is computed once in the constructor and then reused for every
/// batch that is written.
///
/// NOTE(review): this text is a rendered diff, so removed and added lines are
/// interleaved without +/- markers. `enable_fast_encoding` and `fast_encoding`
/// look like the old and new names of the same flag — confirm against the
/// real file before relying on both being present.
pub struct ShuffleBlockWriter {
/// Fast-encoding flag under what appears to be its pre-change name (see the
/// review note on the struct).
enable_fast_encoding: bool,
/// True only when the caller requested fast encoding and every field of the
/// schema passes `fast_codec_supports_type`. When set, batches are written
/// with `BatchWriter`; otherwise the Arrow IPC `StreamWriter` is used.
fast_encoding: bool,
/// Compression codec. It selects the 4-byte codec tag placed in the header
/// (`SNAP`, `LZ4_`, `ZSTD` or `NONE`) and the encoder wrapped around the
/// output when a batch is written.
codec: CompressionCodec,
/// Bytes produced by `BatchWriter::write_partial_schema` in the constructor
/// when fast encoding is enabled; left empty otherwise. They are written
/// ahead of each batch on the fast-encoding path instead of re-encoding the
/// schema every time.
encoded_schema: Vec<u8>,
/// Header built once in the constructor and written at the start of each
/// batch: 8 zero bytes as a placeholder for the compressed message length
/// (overwritten after the batch has been written), the schema's field count
/// as `usize::to_le_bytes`, the 4-byte codec tag, and the 4-byte encoding
/// tag (`FAST` or `AIPC`).
header_bytes: Vec<u8>,
}

impl ShuffleBlockWriter {
Expand All @@ -1582,15 +1583,52 @@ impl ShuffleBlockWriter {
codec: CompressionCodec,
) -> Result<Self> {
let mut encoded_schema = vec![];

let enable_fast_encoding = enable_fast_encoding
&& schema
.fields()
.iter()
.all(|f| fast_codec_supports_type(f.data_type()));

// encode the schema once and then reuse the encoded bytes for each batch
if enable_fast_encoding {
// encode the schema once and then reuse the encoded bytes for each batch
let mut w = BatchWriter::new(&mut encoded_schema);
w.write_partial_schema(schema)?;
}

let header_bytes = Vec::with_capacity(24);
let mut cursor = Cursor::new(header_bytes);

// write placeholder for compressed message length
cursor.write_all(&[0u8; 8])?;

// write number of columns because JVM side needs to know how many addresses to allocate
let field_count = schema.fields().len();
cursor.write_all(&field_count.to_le_bytes())?;

// write compression codec to header
let codec_header = match &codec {
CompressionCodec::Snappy => b"SNAP",
CompressionCodec::Lz4Frame => b"LZ4_",
CompressionCodec::Zstd(_) => b"ZSTD",
CompressionCodec::None => b"NONE",
};
cursor.write_all(codec_header)?;

// write encoding scheme
if enable_fast_encoding {
cursor.write_all(b"FAST")?;
} else {
cursor.write_all(b"AIPC")?;
}

let header_bytes = cursor.into_inner();

Ok(Self {
enable_fast_encoding,
fast_encoding: enable_fast_encoding,
codec,
encoded_schema,
header_bytes,
})
}

Expand All @@ -1609,82 +1647,80 @@ impl ShuffleBlockWriter {
let mut timer = ipc_time.timer();
let start_pos = output.stream_position()?;

// write ipc_length placeholder
output.write_all(&[0u8; 8])?;
// write header
output.write_all(&self.header_bytes)?;

// write number of columns because JVM side needs to know how many addresses to allocate
let field_count = batch.schema().fields().len();
output.write_all(&field_count.to_le_bytes())?;

// write compression codec used
match &self.codec {
CompressionCodec::Snappy => output.write_all(b"SNAP")?,
CompressionCodec::Lz4Frame => output.write_all(b"LZ4_")?,
CompressionCodec::Zstd(_) => output.write_all(b"ZSTD")?,
CompressionCodec::None => output.write_all(b"NONE")?,
}

// write encoding method used
let fast_encoding = self.enable_fast_encoding
&& batch
.schema()
.fields()
.iter()
.all(|f| fast_codec_supports_type(f.data_type()));

if fast_encoding {
output.write_all(b"FAST")?;
let output = if self.fast_encoding {
match &self.codec {
CompressionCodec::None => {
let mut fast_writer = BatchWriter::new(&mut *output);
fast_writer.write_all(&self.encoded_schema)?;
fast_writer.write_batch(batch)?;
output
}
CompressionCodec::Lz4Frame => {
let mut wtr = lz4_flex::frame::FrameEncoder::new(output);
let mut fast_writer = BatchWriter::new(&mut wtr);
fast_writer.write_all(&self.encoded_schema)?;
fast_writer.write_batch(batch)?;
wtr.finish().map_err(|e| {
DataFusionError::Execution(format!("lz4 compression error: {}", e))
})?
}
CompressionCodec::Zstd(level) => {
let mut encoder = zstd::Encoder::new(output, *level)?;
let mut fast_writer = BatchWriter::new(&mut encoder);
fast_writer.write_all(&self.encoded_schema)?;
fast_writer.write_batch(batch)?;
encoder.finish()?
}
CompressionCodec::Snappy => {
let mut encoder = snap::write::FrameEncoder::new(output);
let mut fast_writer = BatchWriter::new(&mut encoder);
fast_writer.write_all(&self.encoded_schema)?;
fast_writer.write_batch(batch)?;
encoder.into_inner().map_err(|e| {
DataFusionError::Execution(format!("snappy compression error: {}", e))
})?
}
}
} else {
output.write_all(b"AIPC")?;
}
match &self.codec {
CompressionCodec::None => {
let mut arrow_writer = StreamWriter::try_new(output, &batch.schema())?;
arrow_writer.write(batch)?;
arrow_writer.finish()?;
arrow_writer.into_inner()?
}
CompressionCodec::Lz4Frame => {
let mut wtr = lz4_flex::frame::FrameEncoder::new(output);
let mut arrow_writer = StreamWriter::try_new(&mut wtr, &batch.schema())?;
arrow_writer.write(batch)?;
arrow_writer.finish()?;
wtr.finish().map_err(|e| {
DataFusionError::Execution(format!("lz4 compression error: {}", e))
})?
}

let output = match (fast_encoding, &self.codec) {
(false, CompressionCodec::None) => {
let mut arrow_writer = StreamWriter::try_new(output, &batch.schema())?;
arrow_writer.write(batch)?;
arrow_writer.finish()?;
arrow_writer.into_inner()?
}
(true, CompressionCodec::None) => {
let mut fast_writer = BatchWriter::new(&mut *output);
fast_writer.write_all(&self.encoded_schema)?;
fast_writer.write_batch(batch)?;
output
}
(false, CompressionCodec::Lz4Frame) => {
let mut wtr = lz4_flex::frame::FrameEncoder::new(output);
let mut arrow_writer = StreamWriter::try_new(&mut wtr, &batch.schema())?;
arrow_writer.write(batch)?;
arrow_writer.finish()?;
wtr.finish().map_err(|e| {
DataFusionError::Execution(format!("lz4 compression error: {}", e))
})?
}
(true, CompressionCodec::Lz4Frame) => {
let mut wtr = lz4_flex::frame::FrameEncoder::new(output);
let mut fast_writer = BatchWriter::new(&mut wtr);
fast_writer.write_all(&self.encoded_schema)?;
fast_writer.write_batch(batch)?;
wtr.finish().map_err(|e| {
DataFusionError::Execution(format!("lz4 compression error: {}", e))
})?
}
(false, CompressionCodec::Zstd(level)) => {
let encoder = zstd::Encoder::new(output, *level)?;
let mut arrow_writer = StreamWriter::try_new(encoder, &batch.schema())?;
arrow_writer.write(batch)?;
arrow_writer.finish()?;
let zstd_encoder = arrow_writer.into_inner()?;
zstd_encoder.finish()?
}
(true, CompressionCodec::Zstd(level)) => {
let mut encoder = zstd::Encoder::new(output, *level)?;
let mut fast_writer = BatchWriter::new(&mut encoder);
fast_writer.write_all(&self.encoded_schema)?;
fast_writer.write_batch(batch)?;
encoder.finish()?
CompressionCodec::Zstd(level) => {
let encoder = zstd::Encoder::new(output, *level)?;
let mut arrow_writer = StreamWriter::try_new(encoder, &batch.schema())?;
arrow_writer.write(batch)?;
arrow_writer.finish()?;
let zstd_encoder = arrow_writer.into_inner()?;
zstd_encoder.finish()?
}

CompressionCodec::Snappy => {
let mut wtr = snap::write::FrameEncoder::new(output);
let mut arrow_writer = StreamWriter::try_new(&mut wtr, &batch.schema())?;
arrow_writer.write(batch)?;
arrow_writer.finish()?;
wtr.into_inner().map_err(|e| {
DataFusionError::Execution(format!("snappy compression error: {}", e))
})?
}
}
_ => unreachable!(),
};

// fill ipc length
Expand All @@ -1693,7 +1729,7 @@ impl ShuffleBlockWriter {

// fill ipc length
output.seek(SeekFrom::Start(start_pos))?;
output.write_all(&ipc_length.to_le_bytes()[..])?;
output.write_all(&ipc_length.to_le_bytes())?;
output.seek(SeekFrom::Start(end_pos))?;

timer.stop();
Expand Down

0 comments on commit f0793f0

Please sign in to comment.