Skip to content

Commit

Permalink
Make an effort to ensure that write() does not return Ok(0)
Browse files Browse the repository at this point in the history
  • Loading branch information
rmja committed Jul 1, 2024
1 parent 1c27698 commit 9cdc76d
Showing 1 changed file with 85 additions and 7 deletions.
92 changes: 85 additions & 7 deletions src/body_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,20 +149,20 @@ where

if self.header_pos + EMPTY_CHUNK.len() > self.buf.len() {
// There is not enough space to fit the empty chunk in the buffer
self.emit_finished_chunks().await?;
self.emit_buffered().await?;
}
}

self.buf[self.header_pos..self.header_pos + EMPTY_CHUNK.len()].copy_from_slice(EMPTY_CHUNK);
self.header_pos += EMPTY_CHUNK.len();
self.allocated_header = 0;
self.pos = self.header_pos + self.allocated_header;
self.emit_finished_chunks().await
self.emit_buffered().await
}

/// Append to the buffer
fn append_current_chunk(&mut self, buf: &[u8]) -> usize {
let buffered = usize::min(buf.len(), self.buf.len() - NEWLINE.len() - self.pos);
let buffered = usize::min(buf.len(), self.buf.len().saturating_sub(NEWLINE.len() + self.pos));
if buffered > 0 {
self.buf[self.pos..self.pos + buffered].copy_from_slice(&buf[..buffered]);
self.pos += buffered;
Expand Down Expand Up @@ -196,7 +196,7 @@ where
self.pos = self.header_pos + self.allocated_header;
}

async fn emit_finished_chunks(&mut self) -> Result<(), C::Error> {
async fn emit_buffered(&mut self) -> Result<(), C::Error> {
self.conn.write_all(&self.buf[..self.header_pos]).await?;
self.header_pos = 0;
self.allocated_header = get_max_chunk_header_size(self.buf.len());
Expand All @@ -217,10 +217,20 @@ where
C: Write,
{
async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
let written = self.append_current_chunk(buf);
if buf.is_empty() {
return Ok(0);
}

let mut written = self.append_current_chunk(buf);
if written == 0 {
// Unable to append any data to the buffer
// This can happen if the writer was pre-loaded with data
self.emit_buffered().await.map_err(|e| e.kind())?;
written = self.append_current_chunk(buf);
}
if written < buf.len() {
self.finish_current_chunk();
self.emit_finished_chunks().await.map_err(|e| e.kind())?;
self.emit_buffered().await.map_err(|e| e.kind())?;
}
Ok(written)
}
Expand All @@ -229,7 +239,10 @@ where
if self.pos > self.header_pos + self.allocated_header {
// There are bytes written in the current chunk
self.finish_current_chunk();
self.emit_finished_chunks().await.map_err(|e| e.kind())?;
self.emit_buffered().await.map_err(|e| e.kind())?;
} else if self.header_pos > 0 {
// There are pre-written bytes in the buffer
self.emit_buffered().await.map_err(|e| e.kind())?;
}
self.conn.flush().await.map_err(|e| e.kind())
}
Expand Down Expand Up @@ -337,6 +350,71 @@ mod tests {
assert_eq!(b"HELLO4\r\nBODY\r\n0\r\n\r\n", conn.as_slice());
}

#[tokio::test]
async fn write_when_entire_buffer_is_prewritten() {
// Given
let mut conn = Vec::new();
let mut buf = [0; 10];
buf.copy_from_slice(b"HELLOHELLO");

// When
let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 10);
writer.write_all(b"BODY").await.unwrap(); // Cannot fit
writer.terminate().await.unwrap();

// Then
print!("{:?}", conn.as_slice());
assert_eq!(b"HELLOHELLO4\r\nBODY\r\n0\r\n\r\n", conn.as_slice());
}

#[tokio::test]
async fn flush_when_entire_buffer_is_prewritten() {
// Given
let mut conn = Vec::new();
let mut buf = [0; 10];
buf.copy_from_slice(b"HELLOHELLO");

// When
let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 10);
writer.flush().await.unwrap();

// Then
print!("{:?}", conn.as_slice());
assert_eq!(b"HELLOHELLO", conn.as_slice());
}

#[tokio::test]
async fn flush_when_entire_buffer_is_nearly_prewritten() {
// Given
let mut conn = Vec::new();
let mut buf = [0; 11];
buf[..10].copy_from_slice(b"HELLOHELLO");

// When
let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 10);
writer.flush().await.unwrap();

// Then
print!("{:?}", conn.as_slice());
assert_eq!(b"HELLOHELLO", conn.as_slice());
}

#[tokio::test]
async fn flushes_already_written_bytes_if_first_cannot_fit() {
// Given
let mut conn = Vec::new();
let mut buf = [0; 10];
buf[..5].copy_from_slice(b"HELLO");

// When
let mut writer = BufferingChunkedBodyWriter::new_with_data(&mut conn, &mut buf, 5);
writer.write_all(b"BODY").await.unwrap(); // Cannot fit
writer.terminate().await.unwrap(); // Can fit

// Then
assert_eq!(b"HELLO4\r\nBODY\r\n0\r\n\r\n", conn.as_slice());
}

#[tokio::test]
async fn current_chunk_is_emitted_before_empty_chunk_is_emitted() {
// Given
Expand Down

0 comments on commit 9cdc76d

Please sign in to comment.