diff --git a/src/async_impl/decoder.rs b/src/async_impl/decoder.rs index d742e6d35..96a27ac45 100644 --- a/src/async_impl/decoder.rs +++ b/src/async_impl/decoder.rs @@ -9,6 +9,14 @@ use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; +#[cfg(any( + feature = "gzip", + feature = "zstd", + feature = "brotli", + feature = "deflate" +))] +use futures_util::stream::Fuse; + #[cfg(feature = "gzip")] use async_compression::tokio::bufread::GzipDecoder; @@ -108,19 +116,19 @@ enum Inner { /// A `Gzip` decoder will uncompress the gzipped response content before returning it. #[cfg(feature = "gzip")] - Gzip(Pin, BytesCodec>>>), + Gzip(Pin, BytesCodec>>>>), /// A `Brotli` decoder will uncompress the brotlied response content before returning it. #[cfg(feature = "brotli")] - Brotli(Pin, BytesCodec>>>), + Brotli(Pin, BytesCodec>>>>), /// A `Zstd` decoder will uncompress the zstd compressed response content before returning it. #[cfg(feature = "zstd")] - Zstd(Pin, BytesCodec>>>), + Zstd(Pin, BytesCodec>>>>), /// A `Deflate` decoder will uncompress the deflated response content before returning it. #[cfg(feature = "deflate")] - Deflate(Pin, BytesCodec>>>), + Deflate(Pin, BytesCodec>>>>), /// A decoder that doesn't have a value yet. #[cfg(any( @@ -365,34 +373,74 @@ impl HttpBody for Decoder { } #[cfg(feature = "gzip")] Inner::Gzip(ref mut decoder) => { - match futures_core::ready!(Pin::new(decoder).poll_next(cx)) { + match futures_core::ready!(Pin::new(&mut *decoder).poll_next(cx)) { Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))), Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), - None => Poll::Ready(None), + None => { + // poll inner connection until EOF after gzip stream is finished + let inner_stream = decoder.get_mut().get_mut().get_mut().get_mut(); + match futures_core::ready!(Pin::new(inner_stream).poll_next(cx)) { + Some(Ok(_)) => Poll::Ready(Some(Err(crate::error::decode( + "there are extra bytes after body has been decompressed", + )))), + Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), + None => Poll::Ready(None), + } + } } } #[cfg(feature = "brotli")] Inner::Brotli(ref mut decoder) => { - match futures_core::ready!(Pin::new(decoder).poll_next(cx)) { + match futures_core::ready!(Pin::new(&mut *decoder).poll_next(cx)) { Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))), Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), - None => Poll::Ready(None), + None => { + // poll inner connection until EOF after brotli stream is finished + let inner_stream = decoder.get_mut().get_mut().get_mut().get_mut(); + match futures_core::ready!(Pin::new(inner_stream).poll_next(cx)) { + Some(Ok(_)) => Poll::Ready(Some(Err(crate::error::decode( + "there are extra bytes after body has been decompressed", + )))), + Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), + None => Poll::Ready(None), + } + } } } #[cfg(feature = "zstd")] Inner::Zstd(ref mut decoder) => { - match futures_core::ready!(Pin::new(decoder).poll_next(cx)) { + match futures_core::ready!(Pin::new(&mut *decoder).poll_next(cx)) { Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))), Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), - None => Poll::Ready(None), + None => { + // poll inner connection until EOF after zstd stream is finished + let inner_stream = decoder.get_mut().get_mut().get_mut().get_mut(); + match futures_core::ready!(Pin::new(inner_stream).poll_next(cx)) { + Some(Ok(_)) => Poll::Ready(Some(Err(crate::error::decode( + "there are extra bytes after body has been decompressed", + )))), + Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), + None => Poll::Ready(None), + } + } } } #[cfg(feature = "deflate")] Inner::Deflate(ref mut decoder) => { - match futures_core::ready!(Pin::new(decoder).poll_next(cx)) { + match futures_core::ready!(Pin::new(&mut *decoder).poll_next(cx)) { Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))), Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), - None => Poll::Ready(None), + None => { + // poll inner connection until EOF after deflate stream is finished + let inner_stream = decoder.get_mut().get_mut().get_mut().get_mut(); + match futures_core::ready!(Pin::new(inner_stream).poll_next(cx)) { + Some(Ok(_)) => Poll::Ready(Some(Err(crate::error::decode( + "there are extra bytes after body has been decompressed", + )))), + Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), + None => Poll::Ready(None), + } + } } } } @@ -456,25 +504,37 @@ impl Future for Pending { match self.1 { #[cfg(feature = "brotli")] - DecoderType::Brotli => Poll::Ready(Ok(Inner::Brotli(Box::pin(FramedRead::new( - BrotliDecoder::new(StreamReader::new(_body)), - BytesCodec::new(), - ))))), + DecoderType::Brotli => Poll::Ready(Ok(Inner::Brotli(Box::pin( + FramedRead::new( + BrotliDecoder::new(StreamReader::new(_body)), + BytesCodec::new(), + ) + .fuse(), + )))), #[cfg(feature = "zstd")] - DecoderType::Zstd => Poll::Ready(Ok(Inner::Zstd(Box::pin(FramedRead::new( - ZstdDecoder::new(StreamReader::new(_body)), - BytesCodec::new(), - ))))), + DecoderType::Zstd => Poll::Ready(Ok(Inner::Zstd(Box::pin( + FramedRead::new( + ZstdDecoder::new(StreamReader::new(_body)), + BytesCodec::new(), + ) + .fuse(), + )))), #[cfg(feature = "gzip")] - DecoderType::Gzip => Poll::Ready(Ok(Inner::Gzip(Box::pin(FramedRead::new( - GzipDecoder::new(StreamReader::new(_body)), - BytesCodec::new(), - ))))), + DecoderType::Gzip => Poll::Ready(Ok(Inner::Gzip(Box::pin( + FramedRead::new( + GzipDecoder::new(StreamReader::new(_body)), + BytesCodec::new(), + ) + .fuse(), + )))), #[cfg(feature = "deflate")] - DecoderType::Deflate => Poll::Ready(Ok(Inner::Deflate(Box::pin(FramedRead::new( - ZlibDecoder::new(StreamReader::new(_body)), - BytesCodec::new(), - ))))), + DecoderType::Deflate => Poll::Ready(Ok(Inner::Deflate(Box::pin( + FramedRead::new( + ZlibDecoder::new(StreamReader::new(_body)), + BytesCodec::new(), + ) + .fuse(), + )))), } } } diff --git a/tests/brotli.rs b/tests/brotli.rs index 5c2b01849..ba116ed92 100644 --- a/tests/brotli.rs +++ b/tests/brotli.rs @@ -1,6 +1,7 @@ mod support; use std::io::Read; use support::server; +use tokio::io::AsyncWriteExt; #[tokio::test] async fn brotli_response() { @@ -145,3 +146,212 @@ async fn brotli_case(response_size: usize, chunk_size: usize) { let body = res.text().await.expect("text"); assert_eq!(body, content); } + +const COMPRESSED_RESPONSE_HEADERS: &[u8] = b"HTTP/1.1 200 OK\x0d\x0a\ + Content-Type: text/plain\x0d\x0a\ + Connection: keep-alive\x0d\x0a\ + Content-Encoding: br\x0d\x0a"; + +const RESPONSE_CONTENT: &str = "some message here"; + +fn brotli_compress(input: &[u8]) -> Vec { + let mut encoder = brotli_crate::CompressorReader::new(input, 4096, 5, 20); + let mut brotlied_content = Vec::new(); + encoder.read_to_end(&mut brotlied_content).unwrap(); + brotlied_content +} + +#[tokio::test] +async fn test_non_chunked_non_fragmented_response() { + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let brotlied_content = brotli_compress(RESPONSE_CONTENT.as_bytes()); + let content_length_header = + format!("Content-Length: {}\r\n\r\n", brotlied_content.len()).into_bytes(); + let response = [ + COMPRESSED_RESPONSE_HEADERS, + &content_length_header, + &brotlied_content, + ] + .concat(); + + client_socket + .write_all(response.as_slice()) + .await + .expect("response write_all failed"); + client_socket.flush().await.expect("response flush failed"); + }) + }); + + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_1() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let brotlied_content = brotli_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + brotlied_content.len() + ) + .as_bytes(), + &brotlied_content, + ] + .concat(); + let response_second_part = b"\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_2() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let brotlied_content = brotli_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + brotlied_content.len() + ) + .as_bytes(), + &brotlied_content, + b"\r\n", + ] + .concat(); + let response_second_part = b"0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_with_extra_bytes() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let brotlied_content = brotli_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + brotlied_content.len() + ) + .as_bytes(), + &brotlied_content, + ] + .concat(); + let response_second_part = b"\r\n2ab\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + let err = res.text().await.expect_err("there must be an error"); + assert!(err.is_decode()); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} diff --git a/tests/deflate.rs b/tests/deflate.rs index ec27ba180..55331afc5 100644 --- a/tests/deflate.rs +++ b/tests/deflate.rs @@ -1,6 +1,7 @@ mod support; use std::io::Write; use support::server; +use tokio::io::AsyncWriteExt; #[tokio::test] async fn deflate_response() { @@ -148,3 +149,214 @@ async fn deflate_case(response_size: usize, chunk_size: usize) { let body = res.text().await.expect("text"); assert_eq!(body, content); } + +const COMPRESSED_RESPONSE_HEADERS: &[u8] = b"HTTP/1.1 200 OK\x0d\x0a\ + Content-Type: text/plain\x0d\x0a\ + Connection: keep-alive\x0d\x0a\ + Content-Encoding: deflate\x0d\x0a"; + +const RESPONSE_CONTENT: &str = "some message here"; + +fn deflate_compress(input: &[u8]) -> Vec { + let mut encoder = libflate::zlib::Encoder::new(Vec::new()).unwrap(); + match encoder.write(input) { + Ok(n) => assert!(n > 0, "Failed to write to encoder."), + _ => panic!("Failed to deflate encode string."), + }; + encoder.finish().into_result().unwrap() +} + +#[tokio::test] +async fn test_non_chunked_non_fragmented_response() { + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let deflated_content = deflate_compress(RESPONSE_CONTENT.as_bytes()); + let content_length_header = + format!("Content-Length: {}\r\n\r\n", deflated_content.len()).into_bytes(); + let response = [ + COMPRESSED_RESPONSE_HEADERS, + &content_length_header, + &deflated_content, + ] + .concat(); + + client_socket + .write_all(response.as_slice()) + .await + .expect("response write_all failed"); + client_socket.flush().await.expect("response flush failed"); + }) + }); + + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_1() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let deflated_content = deflate_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + deflated_content.len() + ) + .as_bytes(), + &deflated_content, + ] + .concat(); + let response_second_part = b"\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_2() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let deflated_content = deflate_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + deflated_content.len() + ) + .as_bytes(), + &deflated_content, + b"\r\n", + ] + .concat(); + let response_second_part = b"0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_with_extra_bytes() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let deflated_content = deflate_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + deflated_content.len() + ) + .as_bytes(), + &deflated_content, + ] + .concat(); + let response_second_part = b"\r\n2ab\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + let err = res.text().await.expect_err("there must be an error"); + assert!(err.is_decode()); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} diff --git a/tests/gzip.rs b/tests/gzip.rs index 57189e0ac..74ead8783 100644 --- a/tests/gzip.rs +++ b/tests/gzip.rs @@ -2,6 +2,8 @@ mod support; use support::server; use std::io::Write; +use tokio::io::AsyncWriteExt; +use tokio::time::Duration; #[tokio::test] async fn gzip_response() { @@ -149,3 +151,214 @@ async fn gzip_case(response_size: usize, chunk_size: usize) { let body = res.text().await.expect("text"); assert_eq!(body, content); } + +const COMPRESSED_RESPONSE_HEADERS: &[u8] = b"HTTP/1.1 200 OK\x0d\x0a\ + Content-Type: text/plain\x0d\x0a\ + Connection: keep-alive\x0d\x0a\ + Content-Encoding: gzip\x0d\x0a"; + +const RESPONSE_CONTENT: &str = "some message here"; + +fn gzip_compress(input: &[u8]) -> Vec { + let mut encoder = libflate::gzip::Encoder::new(Vec::new()).unwrap(); + match encoder.write(input) { + Ok(n) => assert!(n > 0, "Failed to write to encoder."), + _ => panic!("Failed to gzip encode string."), + }; + encoder.finish().into_result().unwrap() +} + +#[tokio::test] +async fn test_non_chunked_non_fragmented_response() { + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let gzipped_content = gzip_compress(RESPONSE_CONTENT.as_bytes()); + let content_length_header = + format!("Content-Length: {}\r\n\r\n", gzipped_content.len()).into_bytes(); + let response = [ + COMPRESSED_RESPONSE_HEADERS, + &content_length_header, + &gzipped_content, + ] + .concat(); + + client_socket + .write_all(response.as_slice()) + .await + .expect("response write_all failed"); + client_socket.flush().await.expect("response flush failed"); + }) + }); + + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_1() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let gzipped_content = gzip_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + gzipped_content.len() + ) + .as_bytes(), + &gzipped_content, + ] + .concat(); + let response_second_part = b"\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_2() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let gzipped_content = gzip_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + gzipped_content.len() + ) + .as_bytes(), + &gzipped_content, + b"\r\n", + ] + .concat(); + let response_second_part = b"0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_with_extra_bytes() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let gzipped_content = gzip_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + gzipped_content.len() + ) + .as_bytes(), + &gzipped_content, + ] + .concat(); + let response_second_part = b"\r\n2ab\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + let err = res.text().await.expect_err("there must be an error"); + assert!(err.is_decode()); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} diff --git a/tests/support/server.rs b/tests/support/server.rs index 29835ead1..79ebd2d8f 100644 --- a/tests/support/server.rs +++ b/tests/support/server.rs @@ -6,6 +6,8 @@ use std::sync::mpsc as std_mpsc; use std::thread; use std::time::Duration; +use tokio::io::AsyncReadExt; +use tokio::net::TcpStream; use tokio::runtime; use tokio::sync::oneshot; @@ -240,3 +242,104 @@ where .join() .unwrap() } + +pub fn low_level_with_response(do_response: F) -> Server +where + for<'c> F: Fn(&'c [u8], &'c mut TcpStream) -> Box + Send + 'c> + + Clone + + Send + + 'static, +{ + // Spawn new runtime in thread to prevent reactor execution context conflict + let test_name = thread::current().name().unwrap_or("").to_string(); + thread::spawn(move || { + let rt = runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("new rt"); + let listener = rt.block_on(async move { + tokio::net::TcpListener::bind(&std::net::SocketAddr::from(([127, 0, 0, 1], 0))) + .await + .unwrap() + }); + let addr = listener.local_addr().unwrap(); + + let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); + let (panic_tx, panic_rx) = std_mpsc::channel(); + let (events_tx, events_rx) = std_mpsc::channel(); + let tname = format!("test({})-support-server", test_name,); + thread::Builder::new() + .name(tname) + .spawn(move || { + rt.block_on(async move { + loop { + tokio::select! { + _ = &mut shutdown_rx => { + break; + } + accepted = listener.accept() => { + let (io, _) = accepted.expect("accepted"); + let do_response = do_response.clone(); + let events_tx = events_tx.clone(); + tokio::spawn(async move { + low_level_server_client(io, do_response).await; + let _ = events_tx.send(Event::ConnectionClosed); + }); + } + } + } + let _ = panic_tx.send(()); + }); + }) + .expect("thread spawn"); + Server { + addr, + panic_rx, + events_rx, + shutdown_tx: Some(shutdown_tx), + } + }) + .join() + .unwrap() +} + +async fn low_level_server_client(mut client_socket: TcpStream, do_response: F) +where + for<'c> F: Fn(&'c [u8], &'c mut TcpStream) -> Box + Send + 'c>, +{ + loop { + let request = low_level_read_http_request(&mut client_socket) + .await + .expect("read_http_request failed"); + if request.is_empty() { + // connection closed by client + break; + } + + Box::into_pin(do_response(&request, &mut client_socket)).await; + } +} + +async fn low_level_read_http_request( + client_socket: &mut TcpStream, +) -> core::result::Result, std::io::Error> { + let mut buf = Vec::new(); + + // Read until the delimiter "\r\n\r\n" is found + loop { + let mut temp_buffer = [0; 1024]; + let n = client_socket.read(&mut temp_buffer).await?; + + if n == 0 { + break; + } + + buf.extend_from_slice(&temp_buffer[..n]); + + if let Some(pos) = buf.windows(4).position(|window| window == b"\r\n\r\n") { + return Ok(buf.drain(..pos + 4).collect()); + } + } + + Ok(buf) +} diff --git a/tests/zstd.rs b/tests/zstd.rs index d1886ee49..ed3914e79 100644 --- a/tests/zstd.rs +++ b/tests/zstd.rs @@ -1,5 +1,6 @@ mod support; use support::server; +use tokio::io::AsyncWriteExt; #[tokio::test] async fn zstd_response() { @@ -142,3 +143,209 @@ async fn zstd_case(response_size: usize, chunk_size: usize) { let body = res.text().await.expect("text"); assert_eq!(body, content); } + +const COMPRESSED_RESPONSE_HEADERS: &[u8] = b"HTTP/1.1 200 OK\x0d\x0a\ + Content-Type: text/plain\x0d\x0a\ + Connection: keep-alive\x0d\x0a\ + Content-Encoding: zstd\x0d\x0a"; + +const RESPONSE_CONTENT: &str = "some message here"; + +fn zstd_compress(input: &[u8]) -> Vec { + zstd_crate::encode_all(input, 3).unwrap() +} + +#[tokio::test] +async fn test_non_chunked_non_fragmented_response() { + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let zstded_content = zstd_compress(RESPONSE_CONTENT.as_bytes()); + let content_length_header = + format!("Content-Length: {}\r\n\r\n", zstded_content.len()).into_bytes(); + let response = [ + COMPRESSED_RESPONSE_HEADERS, + &content_length_header, + &zstded_content, + ] + .concat(); + + client_socket + .write_all(response.as_slice()) + .await + .expect("response write_all failed"); + client_socket.flush().await.expect("response flush failed"); + }) + }); + + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_1() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let zstded_content = zstd_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + zstded_content.len() + ) + .as_bytes(), + &zstded_content, + ] + .concat(); + let response_second_part = b"\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_2() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let zstded_content = zstd_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + zstded_content.len() + ) + .as_bytes(), + &zstded_content, + b"\r\n", + ] + .concat(); + let response_second_part = b"0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_with_extra_bytes() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let zstded_content = zstd_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + zstded_content.len() + ) + .as_bytes(), + &zstded_content, + ] + .concat(); + let response_second_part = b"\r\n2ab\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + let err = res.text().await.expect_err("there must be an error"); + assert!(err.is_decode()); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +}