Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Poll inner connection until EOF after gzip decoder is complete #2484

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/async_impl/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,10 +365,19 @@ impl HttpBody for Decoder {
}
#[cfg(feature = "gzip")]
Inner::Gzip(ref mut decoder) => {
match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
match futures_core::ready!(Pin::new(&mut *decoder).poll_next(cx)) {
Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))),
Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
None => Poll::Ready(None),
None => { // poll inner connection until EOF after gzip stream is finished
let gzip_decoder = decoder.get_mut();
let stream_reader = gzip_decoder.get_mut();
let peekable_io_stream = stream_reader.get_mut();
match futures_core::ready!(Pin::new(peekable_io_stream).poll_next(cx)) {
Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes)))),
Andrey36652 marked this conversation as resolved.
Show resolved Hide resolved
Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
None => Poll::Ready(None),
}
}
}
}
#[cfg(feature = "brotli")]
Expand Down
171 changes: 171 additions & 0 deletions tests/gzip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ mod support;
use support::server;

use std::io::Write;
use tokio::io::AsyncWriteExt;
use tokio::time::Duration;

#[tokio::test]
async fn gzip_response() {
Expand Down Expand Up @@ -149,3 +151,172 @@ async fn gzip_case(response_size: usize, chunk_size: usize) {
let body = res.text().await.expect("text");
assert_eq!(body, content);
}

#[tokio::test]
async fn test_non_chunked_non_fragmented_response() {
let server = server::low_level_with_response(|_raw_request, client_socket| {
Box::new(async move {
let response = b"HTTP/1.1 200 OK\x0d\x0a\
Content-Type: text/plain\x0d\x0a\
Connection: keep-alive\x0d\x0a\
Content-Encoding: gzip\x0d\x0a\
Content-Length: 85\x0d\x0a\
\x0d\x0a\
\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\x03\xabV*\xae\xccM\xca\xcfQ\xb2Rr\x0aq\x0e\x0dv\x09Q\xd2Q\xca/H\xcd\xf3\xcc+I-J-.\x01J\x98\x1b\x18\x98\x9a\xe9\x99\x9a\x18\x03\xa5J2sS\x95\xac\x0c\xcd\x8d\x8cM\x8cLML\x0c---j\x01\xd7Gb;D\x00\x00\x00";

client_socket
.write_all(response)
.await
.expect("response write_all failed");
client_socket.flush().await.expect("response flush failed");
})
});

let client = reqwest::Client::builder()
.connection_verbose(true)
.timeout(Duration::from_secs(15))
.pool_idle_timeout(Some(std::time::Duration::from_secs(300)))
.pool_max_idle_per_host(5)
.build()
.expect("reqwest client init error");

let res = client
.get(&format!("http://{}/", server.addr()))
.send()
.await
.expect("response");

let body = res.text().await.expect("text");
assert_eq!(
body,
r#"{"symbol":"BTCUSDT","openInterest":"70056.543","time":1723425441998}"#
);
}

#[tokio::test]
async fn test_chunked_fragmented_response_1() {
const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration =
tokio::time::Duration::from_millis(1000);
const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50);

let server = server::low_level_with_response(|_raw_request, client_socket| {
Box::new(async move {
let response_first_part = b"HTTP/1.1 200 OK\x0d\x0a\
Content-Type: text/plain\x0d\x0a\
Transfer-Encoding: chunked\x0d\x0a\
Connection: keep-alive\x0d\x0a\
Content-Encoding: gzip\x0d\x0a\
\x0d\x0a\
55\x0d\x0a\
\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\x03\xabV*\xae\xccM\xca\xcfQ\xb2Rr\x0aq\x0e\x0dv\x09Q\xd2Q\xca/H\xcd\xf3\xcc+I-J-.\x01J\x98\x1b\x18\x98\x9a\xe9\x99\x9a\x18\x03\xa5J2sS\x95\xac\x0c\xcd\x8d\x8cM\x8cLML\x0c---j\x01\xd7Gb;D\x00\x00\x00";
let response_second_part = b"\x0d\x0a0\x0d\x0a\x0d\x0a";

client_socket
.write_all(response_first_part)
.await
.expect("response_first_part write_all failed");
client_socket
.flush()
.await
.expect("response_first_part flush failed");

tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await;

client_socket
.write_all(response_second_part)
.await
.expect("response_second_part write_all failed");
client_socket
.flush()
.await
.expect("response_second_part flush failed");
})
});

let start = tokio::time::Instant::now();

let client = reqwest::Client::builder()
.connection_verbose(true)
.timeout(Duration::from_secs(15))
.pool_idle_timeout(Some(std::time::Duration::from_secs(300)))
.pool_max_idle_per_host(5)
.build()
.expect("reqwest client init error");

let res = client
.get(&format!("http://{}/", server.addr()))
.send()
.await
.expect("response");

let body = res.text().await.expect("text");
assert_eq!(
body,
r#"{"symbol":"BTCUSDT","openInterest":"70056.543","time":1723425441998}"#
);
assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN);
}

#[tokio::test]
async fn test_chunked_fragmented_response_2() {
const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration =
tokio::time::Duration::from_millis(1000);
const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50);

let server = server::low_level_with_response(|_raw_request, client_socket| {
Box::new(async move {
let response_first_part = b"HTTP/1.1 200 OK\x0d\x0a\
Content-Type: text/plain\x0d\x0a\
Transfer-Encoding: chunked\x0d\x0a\
Connection: keep-alive\x0d\x0a\
Content-Encoding: gzip\x0d\x0a\
\x0d\x0a\
55\x0d\x0a\
\x1f\x8b\x08\x00\x00\x00\x00\x00\x00\x03\xabV*\xae\xccM\xca\xcfQ\xb2Rr\x0aq\x0e\x0dv\x09Q\xd2Q\xca/H\xcd\xf3\xcc+I-J-.\x01J\x98\x1b\x18\x98\x9a\xe9\x99\x9a\x18\x03\xa5J2sS\x95\xac\x0c\xcd\x8d\x8cM\x8cLML\x0c---j\x01\xd7Gb;D\x00\x00\x00\x0d\x0a";
let response_second_part = b"0\x0d\x0a\x0d\x0a";

client_socket
.write_all(response_first_part)
.await
.expect("response_first_part write_all failed");
client_socket
.flush()
.await
.expect("response_first_part flush failed");

tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await;

client_socket
.write_all(response_second_part)
.await
.expect("response_second_part write_all failed");
client_socket
.flush()
.await
.expect("response_second_part flush failed");
})
});

let start = tokio::time::Instant::now();

let client = reqwest::Client::builder()
.connection_verbose(true)
.timeout(Duration::from_secs(15))
.pool_idle_timeout(Some(std::time::Duration::from_secs(300)))
.pool_max_idle_per_host(5)
.build()
.expect("reqwest client init error");

let res = client
.get(&format!("http://{}/", server.addr()))
.send()
.await
.expect("response");

let body = res.text().await.expect("text");
assert_eq!(
body,
r#"{"symbol":"BTCUSDT","openInterest":"70056.543","time":1723425441998}"#
);
assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN);
}
102 changes: 102 additions & 0 deletions tests/support/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use std::sync::mpsc as std_mpsc;
use std::thread;
use std::time::Duration;

use tokio::io::AsyncReadExt;
use tokio::net::TcpStream;
use tokio::runtime;
use tokio::sync::oneshot;

Expand Down Expand Up @@ -240,3 +242,103 @@ where
.join()
.unwrap()
}

pub fn low_level_with_response<F>(do_response: F) -> Server
where
for <'c> F: Fn(&'c [u8], &'c mut TcpStream) -> Box<dyn Future<Output = ()> + Send + 'c> + Clone + Send + 'static,
{
// Spawn new runtime in thread to prevent reactor execution context conflict
let test_name = thread::current().name().unwrap_or("<unknown>").to_string();
thread::spawn(move || {
let rt = runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("new rt");
let listener = rt.block_on(async move {
tokio::net::TcpListener::bind(&std::net::SocketAddr::from(([127, 0, 0, 1], 0)))
.await
.unwrap()
});
let addr = listener.local_addr().unwrap();

let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
let (panic_tx, panic_rx) = std_mpsc::channel();
let (events_tx, events_rx) = std_mpsc::channel();
let tname = format!(
"test({})-support-server",
test_name,
);
thread::Builder::new()
.name(tname)
.spawn(move || {
rt.block_on(async move {
loop {
tokio::select! {
_ = &mut shutdown_rx => {
break;
}
accepted = listener.accept() => {
let (io, _) = accepted.expect("accepted");
let do_response = do_response.clone();
let events_tx = events_tx.clone();
tokio::spawn(async move {
low_level_server_client(io, do_response).await;
let _ = events_tx.send(Event::ConnectionClosed);
});
}
}
}
let _ = panic_tx.send(());
});
})
.expect("thread spawn");
Server {
addr,
panic_rx,
events_rx,
shutdown_tx: Some(shutdown_tx),
}
})
.join()
.unwrap()
}

async fn low_level_server_client<F>(mut client_socket: TcpStream, do_response: F)
where
for<'c> F: Fn(&'c [u8], &'c mut TcpStream) -> Box<dyn Future<Output = ()> + Send + 'c>
{
loop {
let request = low_level_read_http_request(&mut client_socket)
.await
.expect("read_http_request failed");
if request.is_empty() { // connection closed by client
break;
}

Box::into_pin(do_response(&request, &mut client_socket)).await;
}
}

async fn low_level_read_http_request(
client_socket: &mut TcpStream,
) -> core::result::Result<Vec<u8>, std::io::Error> {
let mut buf = Vec::new();

// Read until the delimiter "\r\n\r\n" is found
loop {
let mut temp_buffer = [0; 1024];
let n = client_socket.read(&mut temp_buffer).await?;

if n == 0 {
break;
}

buf.extend_from_slice(&temp_buffer[..n]);

if let Some(pos) = buf.windows(4).position(|window| window == b"\r\n\r\n") {
return Ok(buf.drain(..pos + 4).collect());
}
}

Ok(buf)
}
Loading