Skip to content

Commit

Permalink
perf: fix decoder streams to make pooled connections reusable (#2484)
Browse files Browse the repository at this point in the history
When a response body is being decompressed, and the length wasn't known, but was using chunked transfer-encoding, the remaining `0\r\n\r\n` was not consumed. That would leave the connection in a state that could not be reused, and so the pool had to discard it.

This fix makes sure the remaining end chunk is consumed, improving the amount of pooled connections that can be reused.

Closes #2381
  • Loading branch information
Andrey36652 authored Dec 3, 2024
1 parent 4367d30 commit d36c0f5
Show file tree
Hide file tree
Showing 6 changed files with 1,033 additions and 28 deletions.
116 changes: 88 additions & 28 deletions src/async_impl/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@ use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

#[cfg(any(
feature = "gzip",
feature = "zstd",
feature = "brotli",
feature = "deflate"
))]
use futures_util::stream::Fuse;

#[cfg(feature = "gzip")]
use async_compression::tokio::bufread::GzipDecoder;

Expand Down Expand Up @@ -108,19 +116,19 @@ enum Inner {

/// A `Gzip` decoder will uncompress the gzipped response content before returning it.
#[cfg(feature = "gzip")]
Gzip(Pin<Box<FramedRead<GzipDecoder<PeekableIoStreamReader>, BytesCodec>>>),
Gzip(Pin<Box<Fuse<FramedRead<GzipDecoder<PeekableIoStreamReader>, BytesCodec>>>>),

/// A `Brotli` decoder will uncompress the brotlied response content before returning it.
#[cfg(feature = "brotli")]
Brotli(Pin<Box<FramedRead<BrotliDecoder<PeekableIoStreamReader>, BytesCodec>>>),
Brotli(Pin<Box<Fuse<FramedRead<BrotliDecoder<PeekableIoStreamReader>, BytesCodec>>>>),

/// A `Zstd` decoder will uncompress the zstd compressed response content before returning it.
#[cfg(feature = "zstd")]
Zstd(Pin<Box<FramedRead<ZstdDecoder<PeekableIoStreamReader>, BytesCodec>>>),
Zstd(Pin<Box<Fuse<FramedRead<ZstdDecoder<PeekableIoStreamReader>, BytesCodec>>>>),

/// A `Deflate` decoder will uncompress the deflated response content before returning it.
#[cfg(feature = "deflate")]
Deflate(Pin<Box<FramedRead<ZlibDecoder<PeekableIoStreamReader>, BytesCodec>>>),
Deflate(Pin<Box<Fuse<FramedRead<ZlibDecoder<PeekableIoStreamReader>, BytesCodec>>>>),

/// A decoder that doesn't have a value yet.
#[cfg(any(
Expand Down Expand Up @@ -365,34 +373,74 @@ impl HttpBody for Decoder {
}
#[cfg(feature = "gzip")]
Inner::Gzip(ref mut decoder) => {
match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
match futures_core::ready!(Pin::new(&mut *decoder).poll_next(cx)) {
Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))),
Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
None => Poll::Ready(None),
None => {
// poll inner connection until EOF after gzip stream is finished
let inner_stream = decoder.get_mut().get_mut().get_mut().get_mut();
match futures_core::ready!(Pin::new(inner_stream).poll_next(cx)) {
Some(Ok(_)) => Poll::Ready(Some(Err(crate::error::decode(
"there are extra bytes after body has been decompressed",
)))),
Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
None => Poll::Ready(None),
}
}
}
}
#[cfg(feature = "brotli")]
Inner::Brotli(ref mut decoder) => {
match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
match futures_core::ready!(Pin::new(&mut *decoder).poll_next(cx)) {
Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))),
Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
None => Poll::Ready(None),
None => {
// poll inner connection until EOF after brotli stream is finished
let inner_stream = decoder.get_mut().get_mut().get_mut().get_mut();
match futures_core::ready!(Pin::new(inner_stream).poll_next(cx)) {
Some(Ok(_)) => Poll::Ready(Some(Err(crate::error::decode(
"there are extra bytes after body has been decompressed",
)))),
Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
None => Poll::Ready(None),
}
}
}
}
#[cfg(feature = "zstd")]
Inner::Zstd(ref mut decoder) => {
match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
match futures_core::ready!(Pin::new(&mut *decoder).poll_next(cx)) {
Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))),
Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
None => Poll::Ready(None),
None => {
// poll inner connection until EOF after zstd stream is finished
let inner_stream = decoder.get_mut().get_mut().get_mut().get_mut();
match futures_core::ready!(Pin::new(inner_stream).poll_next(cx)) {
Some(Ok(_)) => Poll::Ready(Some(Err(crate::error::decode(
"there are extra bytes after body has been decompressed",
)))),
Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
None => Poll::Ready(None),
}
}
}
}
#[cfg(feature = "deflate")]
Inner::Deflate(ref mut decoder) => {
match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
match futures_core::ready!(Pin::new(&mut *decoder).poll_next(cx)) {
Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))),
Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
None => Poll::Ready(None),
None => {
// poll inner connection until EOF after deflate stream is finished
let inner_stream = decoder.get_mut().get_mut().get_mut().get_mut();
match futures_core::ready!(Pin::new(inner_stream).poll_next(cx)) {
Some(Ok(_)) => Poll::Ready(Some(Err(crate::error::decode(
"there are extra bytes after body has been decompressed",
)))),
Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
None => Poll::Ready(None),
}
}
}
}
}
Expand Down Expand Up @@ -456,25 +504,37 @@ impl Future for Pending {

match self.1 {
#[cfg(feature = "brotli")]
DecoderType::Brotli => Poll::Ready(Ok(Inner::Brotli(Box::pin(FramedRead::new(
BrotliDecoder::new(StreamReader::new(_body)),
BytesCodec::new(),
))))),
DecoderType::Brotli => Poll::Ready(Ok(Inner::Brotli(Box::pin(
FramedRead::new(
BrotliDecoder::new(StreamReader::new(_body)),
BytesCodec::new(),
)
.fuse(),
)))),
#[cfg(feature = "zstd")]
DecoderType::Zstd => Poll::Ready(Ok(Inner::Zstd(Box::pin(FramedRead::new(
ZstdDecoder::new(StreamReader::new(_body)),
BytesCodec::new(),
))))),
DecoderType::Zstd => Poll::Ready(Ok(Inner::Zstd(Box::pin(
FramedRead::new(
ZstdDecoder::new(StreamReader::new(_body)),
BytesCodec::new(),
)
.fuse(),
)))),
#[cfg(feature = "gzip")]
DecoderType::Gzip => Poll::Ready(Ok(Inner::Gzip(Box::pin(FramedRead::new(
GzipDecoder::new(StreamReader::new(_body)),
BytesCodec::new(),
))))),
DecoderType::Gzip => Poll::Ready(Ok(Inner::Gzip(Box::pin(
FramedRead::new(
GzipDecoder::new(StreamReader::new(_body)),
BytesCodec::new(),
)
.fuse(),
)))),
#[cfg(feature = "deflate")]
DecoderType::Deflate => Poll::Ready(Ok(Inner::Deflate(Box::pin(FramedRead::new(
ZlibDecoder::new(StreamReader::new(_body)),
BytesCodec::new(),
))))),
DecoderType::Deflate => Poll::Ready(Ok(Inner::Deflate(Box::pin(
FramedRead::new(
ZlibDecoder::new(StreamReader::new(_body)),
BytesCodec::new(),
)
.fuse(),
)))),
}
}
}
Expand Down
210 changes: 210 additions & 0 deletions tests/brotli.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod support;
use std::io::Read;
use support::server;
use tokio::io::AsyncWriteExt;

#[tokio::test]
async fn brotli_response() {
Expand Down Expand Up @@ -145,3 +146,212 @@ async fn brotli_case(response_size: usize, chunk_size: usize) {
let body = res.text().await.expect("text");
assert_eq!(body, content);
}

/// Status line + headers for a keep-alive, brotli-encoded plain-text response.
/// Deliberately ends after a single CRLF: each test appends its own framing
/// header (`Content-Length` or `Transfer-Encoding`) plus the blank line.
/// (The trailing `\` in the byte-string literal continues the string and
/// swallows the newline, so the value contains only the explicit `\x0d\x0a`s.)
const COMPRESSED_RESPONSE_HEADERS: &[u8] = b"HTTP/1.1 200 OK\x0d\x0a\
Content-Type: text/plain\x0d\x0a\
Connection: keep-alive\x0d\x0a\
Content-Encoding: br\x0d\x0a";

/// Plain-text payload the test server compresses and the client must get back.
const RESPONSE_CONTENT: &str = "some message here";

/// Brotli-compress `input` in one shot (4096-byte buffer, quality 5, lg_window 20).
fn brotli_compress(input: &[u8]) -> Vec<u8> {
    let mut compressed = Vec::new();
    brotli_crate::CompressorReader::new(input, 4096, 5, 20)
        .read_to_end(&mut compressed)
        .unwrap();
    compressed
}

/// A brotli body framed with `Content-Length` and delivered in a single write
/// should decompress to the original content.
#[tokio::test]
async fn test_non_chunked_non_fragmented_response() {
    let server = server::low_level_with_response(|_raw_request, client_socket| {
        Box::new(async move {
            let body = brotli_compress(RESPONSE_CONTENT.as_bytes());

            // Assemble headers + framing + compressed body into one buffer.
            let mut response = Vec::new();
            response.extend_from_slice(COMPRESSED_RESPONSE_HEADERS);
            response
                .extend_from_slice(format!("Content-Length: {}\r\n\r\n", body.len()).as_bytes());
            response.extend_from_slice(&body);

            client_socket
                .write_all(response.as_slice())
                .await
                .expect("response write_all failed");
            client_socket.flush().await.expect("response flush failed");
        })
    });

    let res = reqwest::Client::new()
        .get(&format!("http://{}/", server.addr()))
        .send()
        .await
        .expect("response");

    assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT);
}

/// Chunked brotli response split in two fragments: the chunk-size line plus the
/// chunk payload first, then — after a delay — the chunk's trailing CRLF and the
/// terminating `0\r\n\r\n`. Verifies the client keeps reading until the end
/// chunk arrives instead of finishing as soon as the decompressed stream ends.
#[tokio::test]
async fn test_chunked_fragmented_response_1() {
    // How long the server sleeps between the two socket writes.
    const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration =
        tokio::time::Duration::from_millis(1000);
    // Tolerance for the elapsed-time assertion at the bottom.
    const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50);

    let server = server::low_level_with_response(|_raw_request, client_socket| {
        Box::new(async move {
            let brotlied_content = brotli_compress(RESPONSE_CONTENT.as_bytes());
            // Headers + "<size-hex>\r\n" + chunk data; the chunk terminator is
            // deliberately withheld until the second write.
            let response_first_part = [
                COMPRESSED_RESPONSE_HEADERS,
                format!(
                    "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n",
                    brotlied_content.len()
                )
                .as_bytes(),
                &brotlied_content,
            ]
            .concat();
            // Chunk-data CRLF followed by the zero-length final chunk.
            let response_second_part = b"\r\n0\r\n\r\n";

            client_socket
                .write_all(response_first_part.as_slice())
                .await
                .expect("response_first_part write_all failed");
            client_socket
                .flush()
                .await
                .expect("response_first_part flush failed");

            tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await;

            client_socket
                .write_all(response_second_part)
                .await
                .expect("response_second_part write_all failed");
            client_socket
                .flush()
                .await
                .expect("response_second_part flush failed");
        })
    });

    let start = tokio::time::Instant::now();
    let res = reqwest::Client::new()
        .get(&format!("http://{}/", server.addr()))
        .send()
        .await
        .expect("response");

    assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT);
    // The body can only complete once the second fragment arrives, so at least
    // the inter-part delay (minus margin) must have elapsed.
    assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN);
}

/// Same as `test_chunked_fragmented_response_1` but with the split moved one
/// step later: the first fragment already includes the chunk's trailing CRLF,
/// and only the terminating `0\r\n\r\n` is delayed. Exercises a different
/// parser state at the moment the stream stalls.
#[tokio::test]
async fn test_chunked_fragmented_response_2() {
    // How long the server sleeps between the two socket writes.
    const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration =
        tokio::time::Duration::from_millis(1000);
    // Tolerance for the elapsed-time assertion at the bottom.
    const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50);

    let server = server::low_level_with_response(|_raw_request, client_socket| {
        Box::new(async move {
            let brotlied_content = brotli_compress(RESPONSE_CONTENT.as_bytes());
            // Headers + "<size-hex>\r\n" + chunk data + chunk-data CRLF; only
            // the zero-length final chunk is withheld.
            let response_first_part = [
                COMPRESSED_RESPONSE_HEADERS,
                format!(
                    "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n",
                    brotlied_content.len()
                )
                .as_bytes(),
                &brotlied_content,
                b"\r\n",
            ]
            .concat();
            // The zero-length final chunk that ends the chunked body.
            let response_second_part = b"0\r\n\r\n";

            client_socket
                .write_all(response_first_part.as_slice())
                .await
                .expect("response_first_part write_all failed");
            client_socket
                .flush()
                .await
                .expect("response_first_part flush failed");

            tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await;

            client_socket
                .write_all(response_second_part)
                .await
                .expect("response_second_part write_all failed");
            client_socket
                .flush()
                .await
                .expect("response_second_part flush failed");
        })
    });

    let start = tokio::time::Instant::now();
    let res = reqwest::Client::new()
        .get(&format!("http://{}/", server.addr()))
        .send()
        .await
        .expect("response");

    assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT);
    // The body can only complete once the final chunk arrives, so at least the
    // inter-part delay (minus margin) must have elapsed.
    assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN);
}

/// Chunked brotli response whose delayed second fragment announces an extra
/// chunk (`2ab\r\n`) after the compressed stream has already ended. The client
/// must surface this as a decode error ("extra bytes after body has been
/// decompressed") rather than returning the body successfully.
#[tokio::test]
async fn test_chunked_fragmented_response_with_extra_bytes() {
    // How long the server sleeps between the two socket writes.
    const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration =
        tokio::time::Duration::from_millis(1000);
    // Tolerance for the elapsed-time assertion at the bottom.
    const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50);

    let server = server::low_level_with_response(|_raw_request, client_socket| {
        Box::new(async move {
            let brotlied_content = brotli_compress(RESPONSE_CONTENT.as_bytes());
            // Headers + "<size-hex>\r\n" + chunk data; terminator withheld.
            let response_first_part = [
                COMPRESSED_RESPONSE_HEADERS,
                format!(
                    "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n",
                    brotlied_content.len()
                )
                .as_bytes(),
                &brotlied_content,
            ]
            .concat();
            // Chunk-data CRLF, then a bogus extra chunk header (0x2ab bytes
            // announced, never sent), then the zero-length final chunk.
            let response_second_part = b"\r\n2ab\r\n0\r\n\r\n";

            client_socket
                .write_all(response_first_part.as_slice())
                .await
                .expect("response_first_part write_all failed");
            client_socket
                .flush()
                .await
                .expect("response_first_part flush failed");

            tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await;

            client_socket
                .write_all(response_second_part)
                .await
                .expect("response_second_part write_all failed");
            client_socket
                .flush()
                .await
                .expect("response_second_part flush failed");
        })
    });

    let start = tokio::time::Instant::now();
    let res = reqwest::Client::new()
        .get(&format!("http://{}/", server.addr()))
        .send()
        .await
        .expect("response");

    // Reading the body must fail with a decode error, not succeed.
    let err = res.text().await.expect_err("there must be an error");
    assert!(err.is_decode());
    // The error can only be detected after the delayed second fragment.
    assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN);
}
Loading

0 comments on commit d36c0f5

Please sign in to comment.