From 4367d30ee3c6903f34c7e8c1f39015cf5554d8b1 Mon Sep 17 00:00:00 2001 From: caojen <43436876+caojen@users.noreply.github.com> Date: Sat, 30 Nov 2024 02:30:24 +0800 Subject: [PATCH 1/6] ci: pin hashbrown for msrv job (#2488) * ci: pin hashbrown for msrv job * ci: specify hashbrown@0.15.2 --- .github/workflows/ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 408bccce0..9dd04008a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -298,6 +298,7 @@ jobs: cargo update -p tokio --precise 1.29.1 cargo update -p tokio-util --precise 0.7.11 cargo update -p idna_adapter --precise 1.1.0 + cargo update -p hashbrown@0.15.2 --precise 0.15.0 - uses: Swatinem/rust-cache@v2 From d36c0f5fd93f8190c9f39990ce4ec859c2b6d567 Mon Sep 17 00:00:00 2001 From: Andrey36652 <35865938+Andrey36652@users.noreply.github.com> Date: Tue, 3 Dec 2024 19:01:58 +0300 Subject: [PATCH 2/6] perf: fix decoder streams to make pooled connections reusable (#2484) When a response body is being decompressed, and the length wasn't known, but was using chunked transfer-encoding, the remaining `0\r\n\r\n` was not consumed. That would leave the connection in a state that could not be reused, and so the pool had to discard it. This fix makes sure the remaining end chunk is consumed, improving the number of pooled connections that can be reused. 
Closes #2381 --- src/async_impl/decoder.rs | 116 ++++++++++++++++----- tests/brotli.rs | 210 +++++++++++++++++++++++++++++++++++++ tests/deflate.rs | 212 +++++++++++++++++++++++++++++++++++++ tests/gzip.rs | 213 ++++++++++++++++++++++++++++++++++++++ tests/support/server.rs | 103 ++++++++++++++++++ tests/zstd.rs | 207 ++++++++++++++++++++++++++++++++++++ 6 files changed, 1033 insertions(+), 28 deletions(-) diff --git a/src/async_impl/decoder.rs b/src/async_impl/decoder.rs index d742e6d35..96a27ac45 100644 --- a/src/async_impl/decoder.rs +++ b/src/async_impl/decoder.rs @@ -9,6 +9,14 @@ use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; +#[cfg(any( + feature = "gzip", + feature = "zstd", + feature = "brotli", + feature = "deflate" +))] +use futures_util::stream::Fuse; + #[cfg(feature = "gzip")] use async_compression::tokio::bufread::GzipDecoder; @@ -108,19 +116,19 @@ enum Inner { /// A `Gzip` decoder will uncompress the gzipped response content before returning it. #[cfg(feature = "gzip")] - Gzip(Pin, BytesCodec>>>), + Gzip(Pin, BytesCodec>>>>), /// A `Brotli` decoder will uncompress the brotlied response content before returning it. #[cfg(feature = "brotli")] - Brotli(Pin, BytesCodec>>>), + Brotli(Pin, BytesCodec>>>>), /// A `Zstd` decoder will uncompress the zstd compressed response content before returning it. #[cfg(feature = "zstd")] - Zstd(Pin, BytesCodec>>>), + Zstd(Pin, BytesCodec>>>>), /// A `Deflate` decoder will uncompress the deflated response content before returning it. #[cfg(feature = "deflate")] - Deflate(Pin, BytesCodec>>>), + Deflate(Pin, BytesCodec>>>>), /// A decoder that doesn't have a value yet. 
#[cfg(any( @@ -365,34 +373,74 @@ impl HttpBody for Decoder { } #[cfg(feature = "gzip")] Inner::Gzip(ref mut decoder) => { - match futures_core::ready!(Pin::new(decoder).poll_next(cx)) { + match futures_core::ready!(Pin::new(&mut *decoder).poll_next(cx)) { Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))), Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), - None => Poll::Ready(None), + None => { + // poll inner connection until EOF after gzip stream is finished + let inner_stream = decoder.get_mut().get_mut().get_mut().get_mut(); + match futures_core::ready!(Pin::new(inner_stream).poll_next(cx)) { + Some(Ok(_)) => Poll::Ready(Some(Err(crate::error::decode( + "there are extra bytes after body has been decompressed", + )))), + Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), + None => Poll::Ready(None), + } + } } } #[cfg(feature = "brotli")] Inner::Brotli(ref mut decoder) => { - match futures_core::ready!(Pin::new(decoder).poll_next(cx)) { + match futures_core::ready!(Pin::new(&mut *decoder).poll_next(cx)) { Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))), Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), - None => Poll::Ready(None), + None => { + // poll inner connection until EOF after brotli stream is finished + let inner_stream = decoder.get_mut().get_mut().get_mut().get_mut(); + match futures_core::ready!(Pin::new(inner_stream).poll_next(cx)) { + Some(Ok(_)) => Poll::Ready(Some(Err(crate::error::decode( + "there are extra bytes after body has been decompressed", + )))), + Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), + None => Poll::Ready(None), + } + } } } #[cfg(feature = "zstd")] Inner::Zstd(ref mut decoder) => { - match futures_core::ready!(Pin::new(decoder).poll_next(cx)) { + match futures_core::ready!(Pin::new(&mut *decoder).poll_next(cx)) { Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))), Some(Err(err)) 
=> Poll::Ready(Some(Err(crate::error::decode_io(err)))), - None => Poll::Ready(None), + None => { + // poll inner connection until EOF after zstd stream is finished + let inner_stream = decoder.get_mut().get_mut().get_mut().get_mut(); + match futures_core::ready!(Pin::new(inner_stream).poll_next(cx)) { + Some(Ok(_)) => Poll::Ready(Some(Err(crate::error::decode( + "there are extra bytes after body has been decompressed", + )))), + Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), + None => Poll::Ready(None), + } + } } } #[cfg(feature = "deflate")] Inner::Deflate(ref mut decoder) => { - match futures_core::ready!(Pin::new(decoder).poll_next(cx)) { + match futures_core::ready!(Pin::new(&mut *decoder).poll_next(cx)) { Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))), Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), - None => Poll::Ready(None), + None => { + // poll inner connection until EOF after deflate stream is finished + let inner_stream = decoder.get_mut().get_mut().get_mut().get_mut(); + match futures_core::ready!(Pin::new(inner_stream).poll_next(cx)) { + Some(Ok(_)) => Poll::Ready(Some(Err(crate::error::decode( + "there are extra bytes after body has been decompressed", + )))), + Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), + None => Poll::Ready(None), + } + } } } } @@ -456,25 +504,37 @@ impl Future for Pending { match self.1 { #[cfg(feature = "brotli")] - DecoderType::Brotli => Poll::Ready(Ok(Inner::Brotli(Box::pin(FramedRead::new( - BrotliDecoder::new(StreamReader::new(_body)), - BytesCodec::new(), - ))))), + DecoderType::Brotli => Poll::Ready(Ok(Inner::Brotli(Box::pin( + FramedRead::new( + BrotliDecoder::new(StreamReader::new(_body)), + BytesCodec::new(), + ) + .fuse(), + )))), #[cfg(feature = "zstd")] - DecoderType::Zstd => Poll::Ready(Ok(Inner::Zstd(Box::pin(FramedRead::new( - ZstdDecoder::new(StreamReader::new(_body)), - BytesCodec::new(), - ))))), + 
DecoderType::Zstd => Poll::Ready(Ok(Inner::Zstd(Box::pin( + FramedRead::new( + ZstdDecoder::new(StreamReader::new(_body)), + BytesCodec::new(), + ) + .fuse(), + )))), #[cfg(feature = "gzip")] - DecoderType::Gzip => Poll::Ready(Ok(Inner::Gzip(Box::pin(FramedRead::new( - GzipDecoder::new(StreamReader::new(_body)), - BytesCodec::new(), - ))))), + DecoderType::Gzip => Poll::Ready(Ok(Inner::Gzip(Box::pin( + FramedRead::new( + GzipDecoder::new(StreamReader::new(_body)), + BytesCodec::new(), + ) + .fuse(), + )))), #[cfg(feature = "deflate")] - DecoderType::Deflate => Poll::Ready(Ok(Inner::Deflate(Box::pin(FramedRead::new( - ZlibDecoder::new(StreamReader::new(_body)), - BytesCodec::new(), - ))))), + DecoderType::Deflate => Poll::Ready(Ok(Inner::Deflate(Box::pin( + FramedRead::new( + ZlibDecoder::new(StreamReader::new(_body)), + BytesCodec::new(), + ) + .fuse(), + )))), } } } diff --git a/tests/brotli.rs b/tests/brotli.rs index 5c2b01849..ba116ed92 100644 --- a/tests/brotli.rs +++ b/tests/brotli.rs @@ -1,6 +1,7 @@ mod support; use std::io::Read; use support::server; +use tokio::io::AsyncWriteExt; #[tokio::test] async fn brotli_response() { @@ -145,3 +146,212 @@ async fn brotli_case(response_size: usize, chunk_size: usize) { let body = res.text().await.expect("text"); assert_eq!(body, content); } + +const COMPRESSED_RESPONSE_HEADERS: &[u8] = b"HTTP/1.1 200 OK\x0d\x0a\ + Content-Type: text/plain\x0d\x0a\ + Connection: keep-alive\x0d\x0a\ + Content-Encoding: br\x0d\x0a"; + +const RESPONSE_CONTENT: &str = "some message here"; + +fn brotli_compress(input: &[u8]) -> Vec { + let mut encoder = brotli_crate::CompressorReader::new(input, 4096, 5, 20); + let mut brotlied_content = Vec::new(); + encoder.read_to_end(&mut brotlied_content).unwrap(); + brotlied_content +} + +#[tokio::test] +async fn test_non_chunked_non_fragmented_response() { + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let brotlied_content = 
brotli_compress(RESPONSE_CONTENT.as_bytes()); + let content_length_header = + format!("Content-Length: {}\r\n\r\n", brotlied_content.len()).into_bytes(); + let response = [ + COMPRESSED_RESPONSE_HEADERS, + &content_length_header, + &brotlied_content, + ] + .concat(); + + client_socket + .write_all(response.as_slice()) + .await + .expect("response write_all failed"); + client_socket.flush().await.expect("response flush failed"); + }) + }); + + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_1() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let brotlied_content = brotli_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + brotlied_content.len() + ) + .as_bytes(), + &brotlied_content, + ] + .concat(); + let response_second_part = b"\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + 
assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_2() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let brotlied_content = brotli_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + brotlied_content.len() + ) + .as_bytes(), + &brotlied_content, + b"\r\n", + ] + .concat(); + let response_second_part = b"0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_with_extra_bytes() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { 
+ let brotlied_content = brotli_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + brotlied_content.len() + ) + .as_bytes(), + &brotlied_content, + ] + .concat(); + let response_second_part = b"\r\n2ab\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + let err = res.text().await.expect_err("there must be an error"); + assert!(err.is_decode()); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} diff --git a/tests/deflate.rs b/tests/deflate.rs index ec27ba180..55331afc5 100644 --- a/tests/deflate.rs +++ b/tests/deflate.rs @@ -1,6 +1,7 @@ mod support; use std::io::Write; use support::server; +use tokio::io::AsyncWriteExt; #[tokio::test] async fn deflate_response() { @@ -148,3 +149,214 @@ async fn deflate_case(response_size: usize, chunk_size: usize) { let body = res.text().await.expect("text"); assert_eq!(body, content); } + +const COMPRESSED_RESPONSE_HEADERS: &[u8] = b"HTTP/1.1 200 OK\x0d\x0a\ + Content-Type: text/plain\x0d\x0a\ + Connection: keep-alive\x0d\x0a\ + Content-Encoding: deflate\x0d\x0a"; + +const RESPONSE_CONTENT: &str = "some message here"; + +fn deflate_compress(input: &[u8]) -> Vec { + let mut encoder = libflate::zlib::Encoder::new(Vec::new()).unwrap(); + match encoder.write(input) { + Ok(n) => assert!(n > 0, "Failed 
to write to encoder."), + _ => panic!("Failed to deflate encode string."), + }; + encoder.finish().into_result().unwrap() +} + +#[tokio::test] +async fn test_non_chunked_non_fragmented_response() { + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let deflated_content = deflate_compress(RESPONSE_CONTENT.as_bytes()); + let content_length_header = + format!("Content-Length: {}\r\n\r\n", deflated_content.len()).into_bytes(); + let response = [ + COMPRESSED_RESPONSE_HEADERS, + &content_length_header, + &deflated_content, + ] + .concat(); + + client_socket + .write_all(response.as_slice()) + .await + .expect("response write_all failed"); + client_socket.flush().await.expect("response flush failed"); + }) + }); + + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_1() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let deflated_content = deflate_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + deflated_content.len() + ) + .as_bytes(), + &deflated_content, + ] + .concat(); + let response_second_part = b"\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + 
.expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_2() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let deflated_content = deflate_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + deflated_content.len() + ) + .as_bytes(), + &deflated_content, + b"\r\n", + ] + .concat(); + let response_second_part = b"0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn 
test_chunked_fragmented_response_with_extra_bytes() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let deflated_content = deflate_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + deflated_content.len() + ) + .as_bytes(), + &deflated_content, + ] + .concat(); + let response_second_part = b"\r\n2ab\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + let err = res.text().await.expect_err("there must be an error"); + assert!(err.is_decode()); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} diff --git a/tests/gzip.rs b/tests/gzip.rs index 57189e0ac..74ead8783 100644 --- a/tests/gzip.rs +++ b/tests/gzip.rs @@ -2,6 +2,8 @@ mod support; use support::server; use std::io::Write; +use tokio::io::AsyncWriteExt; +use tokio::time::Duration; #[tokio::test] async fn gzip_response() { @@ -149,3 +151,214 @@ async fn gzip_case(response_size: usize, chunk_size: usize) { let body = res.text().await.expect("text"); assert_eq!(body, content); } + +const COMPRESSED_RESPONSE_HEADERS: &[u8] = b"HTTP/1.1 
200 OK\x0d\x0a\ + Content-Type: text/plain\x0d\x0a\ + Connection: keep-alive\x0d\x0a\ + Content-Encoding: gzip\x0d\x0a"; + +const RESPONSE_CONTENT: &str = "some message here"; + +fn gzip_compress(input: &[u8]) -> Vec { + let mut encoder = libflate::gzip::Encoder::new(Vec::new()).unwrap(); + match encoder.write(input) { + Ok(n) => assert!(n > 0, "Failed to write to encoder."), + _ => panic!("Failed to gzip encode string."), + }; + encoder.finish().into_result().unwrap() +} + +#[tokio::test] +async fn test_non_chunked_non_fragmented_response() { + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let gzipped_content = gzip_compress(RESPONSE_CONTENT.as_bytes()); + let content_length_header = + format!("Content-Length: {}\r\n\r\n", gzipped_content.len()).into_bytes(); + let response = [ + COMPRESSED_RESPONSE_HEADERS, + &content_length_header, + &gzipped_content, + ] + .concat(); + + client_socket + .write_all(response.as_slice()) + .await + .expect("response write_all failed"); + client_socket.flush().await.expect("response flush failed"); + }) + }); + + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_1() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let gzipped_content = gzip_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + gzipped_content.len() + ) + .as_bytes(), + &gzipped_content, + ] + .concat(); + let response_second_part = b"\r\n0\r\n\r\n"; + + 
client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_2() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let gzipped_content = gzip_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + gzipped_content.len() + ) + .as_bytes(), + &gzipped_content, + b"\r\n", + ] + .concat(); + let response_second_part = b"0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res 
= reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_with_extra_bytes() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let gzipped_content = gzip_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + gzipped_content.len() + ) + .as_bytes(), + &gzipped_content, + ] + .concat(); + let response_second_part = b"\r\n2ab\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + let err = res.text().await.expect_err("there must be an error"); + assert!(err.is_decode()); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} diff --git a/tests/support/server.rs b/tests/support/server.rs index 29835ead1..79ebd2d8f 100644 --- a/tests/support/server.rs +++ b/tests/support/server.rs @@ -6,6 +6,8 @@ use std::sync::mpsc as std_mpsc; use std::thread; 
use std::time::Duration; +use tokio::io::AsyncReadExt; +use tokio::net::TcpStream; use tokio::runtime; use tokio::sync::oneshot; @@ -240,3 +242,104 @@ where .join() .unwrap() } + +pub fn low_level_with_response(do_response: F) -> Server +where + for<'c> F: Fn(&'c [u8], &'c mut TcpStream) -> Box + Send + 'c> + + Clone + + Send + + 'static, +{ + // Spawn new runtime in thread to prevent reactor execution context conflict + let test_name = thread::current().name().unwrap_or("").to_string(); + thread::spawn(move || { + let rt = runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("new rt"); + let listener = rt.block_on(async move { + tokio::net::TcpListener::bind(&std::net::SocketAddr::from(([127, 0, 0, 1], 0))) + .await + .unwrap() + }); + let addr = listener.local_addr().unwrap(); + + let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); + let (panic_tx, panic_rx) = std_mpsc::channel(); + let (events_tx, events_rx) = std_mpsc::channel(); + let tname = format!("test({})-support-server", test_name,); + thread::Builder::new() + .name(tname) + .spawn(move || { + rt.block_on(async move { + loop { + tokio::select! 
{ + _ = &mut shutdown_rx => { + break; + } + accepted = listener.accept() => { + let (io, _) = accepted.expect("accepted"); + let do_response = do_response.clone(); + let events_tx = events_tx.clone(); + tokio::spawn(async move { + low_level_server_client(io, do_response).await; + let _ = events_tx.send(Event::ConnectionClosed); + }); + } + } + } + let _ = panic_tx.send(()); + }); + }) + .expect("thread spawn"); + Server { + addr, + panic_rx, + events_rx, + shutdown_tx: Some(shutdown_tx), + } + }) + .join() + .unwrap() +} + +async fn low_level_server_client(mut client_socket: TcpStream, do_response: F) +where + for<'c> F: Fn(&'c [u8], &'c mut TcpStream) -> Box + Send + 'c>, +{ + loop { + let request = low_level_read_http_request(&mut client_socket) + .await + .expect("read_http_request failed"); + if request.is_empty() { + // connection closed by client + break; + } + + Box::into_pin(do_response(&request, &mut client_socket)).await; + } +} + +async fn low_level_read_http_request( + client_socket: &mut TcpStream, +) -> core::result::Result, std::io::Error> { + let mut buf = Vec::new(); + + // Read until the delimiter "\r\n\r\n" is found + loop { + let mut temp_buffer = [0; 1024]; + let n = client_socket.read(&mut temp_buffer).await?; + + if n == 0 { + break; + } + + buf.extend_from_slice(&temp_buffer[..n]); + + if let Some(pos) = buf.windows(4).position(|window| window == b"\r\n\r\n") { + return Ok(buf.drain(..pos + 4).collect()); + } + } + + Ok(buf) +} diff --git a/tests/zstd.rs b/tests/zstd.rs index d1886ee49..ed3914e79 100644 --- a/tests/zstd.rs +++ b/tests/zstd.rs @@ -1,5 +1,6 @@ mod support; use support::server; +use tokio::io::AsyncWriteExt; #[tokio::test] async fn zstd_response() { @@ -142,3 +143,209 @@ async fn zstd_case(response_size: usize, chunk_size: usize) { let body = res.text().await.expect("text"); assert_eq!(body, content); } + +const COMPRESSED_RESPONSE_HEADERS: &[u8] = b"HTTP/1.1 200 OK\x0d\x0a\ + Content-Type: text/plain\x0d\x0a\ + Connection: 
keep-alive\x0d\x0a\ + Content-Encoding: zstd\x0d\x0a"; + +const RESPONSE_CONTENT: &str = "some message here"; + +fn zstd_compress(input: &[u8]) -> Vec { + zstd_crate::encode_all(input, 3).unwrap() +} + +#[tokio::test] +async fn test_non_chunked_non_fragmented_response() { + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let zstded_content = zstd_compress(RESPONSE_CONTENT.as_bytes()); + let content_length_header = + format!("Content-Length: {}\r\n\r\n", zstded_content.len()).into_bytes(); + let response = [ + COMPRESSED_RESPONSE_HEADERS, + &content_length_header, + &zstded_content, + ] + .concat(); + + client_socket + .write_all(response.as_slice()) + .await + .expect("response write_all failed"); + client_socket.flush().await.expect("response flush failed"); + }) + }); + + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_1() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let zstded_content = zstd_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + zstded_content.len() + ) + .as_bytes(), + &zstded_content, + ] + .concat(); + let response_second_part = b"\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + 
.write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async fn test_chunked_fragmented_response_2() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let zstded_content = zstd_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + zstded_content.len() + ) + .as_bytes(), + &zstded_content, + b"\r\n", + ] + .concat(); + let response_second_part = b"0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} + +#[tokio::test] +async 
fn test_chunked_fragmented_response_with_extra_bytes() { + const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration = + tokio::time::Duration::from_millis(1000); + const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50); + + let server = server::low_level_with_response(|_raw_request, client_socket| { + Box::new(async move { + let zstded_content = zstd_compress(RESPONSE_CONTENT.as_bytes()); + let response_first_part = [ + COMPRESSED_RESPONSE_HEADERS, + format!( + "Transfer-Encoding: chunked\r\n\r\n{:x}\r\n", + zstded_content.len() + ) + .as_bytes(), + &zstded_content, + ] + .concat(); + let response_second_part = b"\r\n2ab\r\n0\r\n\r\n"; + + client_socket + .write_all(response_first_part.as_slice()) + .await + .expect("response_first_part write_all failed"); + client_socket + .flush() + .await + .expect("response_first_part flush failed"); + + tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await; + + client_socket + .write_all(response_second_part) + .await + .expect("response_second_part write_all failed"); + client_socket + .flush() + .await + .expect("response_second_part flush failed"); + }) + }); + + let start = tokio::time::Instant::now(); + let res = reqwest::Client::new() + .get(&format!("http://{}/", server.addr())) + .send() + .await + .expect("response"); + + let err = res.text().await.expect_err("there must be an error"); + assert!(err.is_decode()); + assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN); +} From 8a2174f8a4259691b5ca0a16778427127f61373a Mon Sep 17 00:00:00 2001 From: Jess Izen <44884346+jlizen@users.noreply.github.com> Date: Thu, 12 Dec 2024 13:18:41 -0800 Subject: [PATCH 3/6] chore: in README, update requirements to mention rustls along with vendored openssl (#2499) Co-authored-by: JESS IZEN --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 1f8dbcb1f..b0b0eb813 100644 --- a/README.md +++ b/README.md @@ -53,13 +53,14 @@ On 
Linux: - OpenSSL with headers. See https://docs.rs/openssl for supported versions and more details. Alternatively you can enable the `native-tls-vendored` - feature to compile a copy of OpenSSL. + feature to compile a copy of OpenSSL. Or, you can use [rustls](https://github.com/rustls/rustls) + via `rustls-tls` or other `rustls-tls-*` features. On Windows and macOS: - Nothing. -Reqwest uses [rust-native-tls](https://github.com/sfackler/rust-native-tls), +By default, Reqwest uses [rust-native-tls](https://github.com/sfackler/rust-native-tls), which will use the operating system TLS framework if available, meaning Windows and macOS. On Linux, it will use the available OpenSSL or fail to build if not found. From 2a7c1b61e0693f8b9924e2e89257aa131a30ee83 Mon Sep 17 00:00:00 2001 From: Jess Izen <44884346+jlizen@users.noreply.github.com> Date: Mon, 23 Dec 2024 06:45:01 -0800 Subject: [PATCH 4/6] feat: allow pluggable tower layers in connector service stack (#2496) Co-authored-by: Jess Izen --- Cargo.toml | 13 +- ...onnect_via_lower_priority_tokio_runtime.rs | 264 +++++++++++++ src/async_impl/client.rs | 103 +++-- src/blocking/client.rs | 38 +- src/connect.rs | 325 ++++++++++----- src/error.rs | 12 + tests/connector_layers.rs | 374 ++++++++++++++++++ tests/support/delay_layer.rs | 119 ++++++ tests/support/mod.rs | 1 + tests/timeouts.rs | 18 + 10 files changed, 1136 insertions(+), 131 deletions(-) create mode 100644 examples/connect_via_lower_priority_tokio_runtime.rs create mode 100644 tests/connector_layers.rs create mode 100644 tests/support/delay_layer.rs diff --git a/Cargo.toml b/Cargo.toml index 39ff48424..1a0c4abf6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ authors = ["Sean McArthur "] readme = "README.md" license = "MIT OR Apache-2.0" edition = "2021" -rust-version = "1.63.0" +rust-version = "1.64.0" autotests = true [package.metadata.docs.rs] @@ -105,6 +105,7 @@ url = "2.4" bytes = "1.0" serde = "1.0" serde_urlencoded = "0.7.1" +tower = { version = 
"0.5.2", default-features = false, features = ["timeout", "util"] } tower-service = "0.3" futures-core = { version = "0.3.28", default-features = false } futures-util = { version = "0.3.28", default-features = false } @@ -169,7 +170,6 @@ quinn = { version = "0.11.1", default-features = false, features = ["rustls", "r slab = { version = "0.4.9", optional = true } # just to get minimal versions working with quinn futures-channel = { version = "0.3", optional = true } - [target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] env_logger = "0.10" hyper = { version = "1.1.0", default-features = false, features = ["http1", "http2", "client", "server"] } @@ -222,6 +222,11 @@ features = [ wasm-bindgen = { version = "0.2.89", features = ["serde-serialize"] } wasm-bindgen-test = "0.3" +[dev-dependencies] +tower = { version = "0.5.2", default-features = false, features = ["limit"] } +num_cpus = "1.0" +libc = "0" + [lints.rust] unexpected_cfgs = { level = "warn", check-cfg = ['cfg(reqwest_unstable)'] } @@ -253,6 +258,10 @@ path = "examples/form.rs" name = "simple" path = "examples/simple.rs" +[[example]] +name = "connect_via_lower_priority_tokio_runtime" +path = "examples/connect_via_lower_priority_tokio_runtime.rs" + [[test]] name = "blocking" path = "tests/blocking.rs" diff --git a/examples/connect_via_lower_priority_tokio_runtime.rs b/examples/connect_via_lower_priority_tokio_runtime.rs new file mode 100644 index 000000000..33151d4a1 --- /dev/null +++ b/examples/connect_via_lower_priority_tokio_runtime.rs @@ -0,0 +1,264 @@ +#![deny(warnings)] +// This example demonstrates how to delegate the connect calls, which contain TLS handshakes, +// to a secondary tokio runtime of lower OS thread priority using a custom tower layer. +// This helps to ensure that long-running futures during handshake crypto operations don't block other I/O futures. 
+// +// This does introduce overhead of additional threads, channels, extra vtables, etc, +// so it is best suited to services with large numbers of incoming connections or that +// are otherwise very sensitive to any blocking futures. Or, you might want fewer threads +// and/or to use the current_thread runtime. +// +// This is using the `tokio` runtime and certain other dependencies: +// +// `tokio = { version = "1", features = ["full"] }` +// `num_cpus = "1.0"` +// `libc = "0"` +// `pin-project-lite = "0.2"` +// `tower = { version = "0.5", default-features = false}` + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::main] +async fn main() -> Result<(), reqwest::Error> { + background_threadpool::init_background_runtime(); + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + + let client = reqwest::Client::builder() + .connector_layer(background_threadpool::BackgroundProcessorLayer::new()) + .build() + .expect("should be able to build reqwest client"); + + let url = if let Some(url) = std::env::args().nth(1) { + url + } else { + println!("No CLI URL provided, using default."); + "https://hyper.rs".into() + }; + + eprintln!("Fetching {url:?}..."); + + let res = client.get(url).send().await?; + + eprintln!("Response: {:?} {}", res.version(), res.status()); + eprintln!("Headers: {:#?}\n", res.headers()); + + let body = res.text().await?; + + println!("{body}"); + + Ok(()) +} + +// separating out for convenience to avoid a million #[cfg(not(target_arch = "wasm32"))] +#[cfg(not(target_arch = "wasm32"))] +mod background_threadpool { + use std::{ + future::Future, + pin::Pin, + sync::OnceLock, + task::{Context, Poll}, + }; + + use futures_util::TryFutureExt; + use pin_project_lite::pin_project; + use tokio::{runtime::Handle, select, sync::mpsc::error::TrySendError}; + use tower::{BoxError, Layer, Service}; + + static CPU_HEAVY_THREAD_POOL: OnceLock< + tokio::sync::mpsc::Sender + Send + 'static>>>, + > = OnceLock::new(); + + pub(crate) fn 
init_background_runtime() { + std::thread::Builder::new() + .name("cpu-heavy-background-threadpool".to_string()) + .spawn(move || { + let rt = tokio::runtime::Builder::new_multi_thread() + .thread_name("cpu-heavy-background-pool-thread") + .worker_threads(num_cpus::get() as usize) + // ref: https://github.com/tokio-rs/tokio/issues/4941 + // consider uncommenting if seeing heavy task contention + // .disable_lifo_slot() + .on_thread_start(move || { + #[cfg(target_os = "linux")] + unsafe { + // Increase thread pool thread niceness, so they are lower priority + // than the foreground executor and don't interfere with I/O tasks + { + *libc::__errno_location() = 0; + if libc::nice(10) == -1 && *libc::__errno_location() != 0 { + let error = std::io::Error::last_os_error(); + log::error!("failed to set threadpool niceness: {}", error); + } + } + } + }) + .enable_all() + .build() + .unwrap_or_else(|e| panic!("cpu heavy runtime failed_to_initialize: {}", e)); + rt.block_on(async { + log::debug!("starting background cpu-heavy work"); + process_cpu_work().await; + }); + }) + .unwrap_or_else(|e| panic!("cpu heavy thread failed_to_initialize: {}", e)); + } + + #[cfg(not(target_arch = "wasm32"))] + async fn process_cpu_work() { + // we only use this channel for routing work, it should move pretty quick, it can be small + let (tx, mut rx) = tokio::sync::mpsc::channel(10); + // share the handle to the background channel globally + CPU_HEAVY_THREAD_POOL.set(tx).unwrap(); + + while let Some(work) = rx.recv().await { + tokio::task::spawn(work); + } + } + + // retrieve the sender to the background channel, and send the future over to it for execution + fn send_to_background_runtime(future: impl Future + Send + 'static) { + let tx = CPU_HEAVY_THREAD_POOL.get().expect( + "start up the secondary tokio runtime before sending to `CPU_HEAVY_THREAD_POOL`", + ); + + match tx.try_send(Box::pin(future)) { + Ok(_) => (), + Err(TrySendError::Closed(_)) => { + panic!("background cpu heavy runtime 
channel is closed") + } + Err(TrySendError::Full(msg)) => { + log::warn!( + "background cpu heavy runtime channel is full, task spawning loop delayed" + ); + let tx = tx.clone(); + Handle::current().spawn(async move { + tx.send(msg) + .await + .expect("background cpu heavy runtime channel is closed") + }); + } + } + } + + // This tower layer injects futures with a oneshot channel, and then sends them to the background runtime for processing. + // We don't use the Buffer service because that is intended to process sequentially on a single task, whereas we want to + // spawn a new task per call. + #[derive(Copy, Clone)] + pub struct BackgroundProcessorLayer {} + impl BackgroundProcessorLayer { + pub fn new() -> Self { + Self {} + } + } + impl Layer for BackgroundProcessorLayer { + type Service = BackgroundProcessor; + fn layer(&self, service: S) -> Self::Service { + BackgroundProcessor::new(service) + } + } + + impl std::fmt::Debug for BackgroundProcessorLayer { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("BackgroundProcessorLayer").finish() + } + } + + // This tower service injects futures with a oneshot channel, and then sends them to the background runtime for processing. 
+ #[derive(Debug, Clone)] + pub struct BackgroundProcessor { + inner: S, + } + + impl BackgroundProcessor { + pub fn new(inner: S) -> Self { + BackgroundProcessor { inner } + } + } + + impl Service for BackgroundProcessor + where + S: Service, + S::Response: Send + 'static, + S::Error: Into + Send, + S::Future: Send + 'static, + { + type Response = S::Response; + + type Error = BoxError; + + type Future = BackgroundResponseFuture; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.inner.poll_ready(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(r) => Poll::Ready(r.map_err(Into::into)), + } + } + + fn call(&mut self, req: Request) -> Self::Future { + let response = self.inner.call(req); + + // wrap our inner service's future with a future that writes to this oneshot channel + let (mut tx, rx) = tokio::sync::oneshot::channel(); + let future = async move { + select!( + _ = tx.closed() => { + // receiver already dropped, don't need to do anything + } + result = response.map_err(|err| Into::::into(err)) => { + // if this fails, the receiver already dropped, so we don't need to do anything + let _ = tx.send(result); + } + ) + }; + // send the wrapped future to the background + send_to_background_runtime(future); + + BackgroundResponseFuture::new(rx) + } + } + + // `BackgroundProcessor` response future + pin_project! 
{ + #[derive(Debug)] + pub struct BackgroundResponseFuture { + #[pin] + rx: tokio::sync::oneshot::Receiver>, + } + } + + impl BackgroundResponseFuture { + pub(crate) fn new(rx: tokio::sync::oneshot::Receiver>) -> Self { + BackgroundResponseFuture { rx } + } + } + + impl Future for BackgroundResponseFuture + where + S: Send + 'static, + { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + // now poll on the receiver end of the oneshot to get the result + match this.rx.poll(cx) { + Poll::Ready(v) => match v { + Ok(v) => Poll::Ready(v.map_err(Into::into)), + Err(err) => Poll::Ready(Err(Box::new(err) as BoxError)), + }, + Poll::Pending => Poll::Pending, + } + } + } +} + +// The [cfg(not(target_arch = "wasm32"))] above prevent building the tokio::main function +// for wasm32 target, because tokio isn't compatible with wasm32. +// If you aren't building for wasm32, you don't need that line. +// The two lines below avoid the "'main' function not found" error when building for wasm32 target. 
+#[cfg(any(target_arch = "wasm32"))] +fn main() {} diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index 579050041..354a23205 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -1,27 +1,14 @@ #[cfg(any(feature = "native-tls", feature = "__rustls",))] use std::any::Any; +use std::future::Future; use std::net::IpAddr; +use std::pin::Pin; use std::sync::Arc; +use std::task::{Context, Poll}; use std::time::Duration; use std::{collections::HashMap, convert::TryInto, net::SocketAddr}; use std::{fmt, str}; -use bytes::Bytes; -use http::header::{ - Entry, HeaderMap, HeaderValue, ACCEPT, ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, - CONTENT_TYPE, LOCATION, PROXY_AUTHORIZATION, RANGE, REFERER, TRANSFER_ENCODING, USER_AGENT, -}; -use http::uri::Scheme; -use http::Uri; -use hyper_util::client::legacy::connect::HttpConnector; -#[cfg(feature = "default-tls")] -use native_tls_crate::TlsConnector; -use pin_project_lite::pin_project; -use std::future::Future; -use std::pin::Pin; -use std::task::{Context, Poll}; -use tokio::time::Sleep; - use super::decoder::Accepts; use super::request::{Request, RequestBuilder}; use super::response::Response; @@ -30,13 +17,16 @@ use super::Body; use crate::async_impl::h3_client::connect::H3Connector; #[cfg(feature = "http3")] use crate::async_impl::h3_client::{H3Client, H3ResponseFuture}; -use crate::connect::Connector; +use crate::connect::{ + sealed::{Conn, Unnameable}, + BoxedConnectorLayer, BoxedConnectorService, Connector, ConnectorBuilder, +}; #[cfg(feature = "cookies")] use crate::cookie; #[cfg(feature = "hickory-dns")] use crate::dns::hickory::HickoryDnsResolver; use crate::dns::{gai::GaiResolver, DnsResolverWithOverrides, DynResolver, Resolve}; -use crate::error; +use crate::error::{self, BoxError}; use crate::into_url::try_uri; use crate::redirect::{self, remove_sensitive_headers}; #[cfg(feature = "__rustls")] @@ -48,11 +38,25 @@ use crate::Certificate; #[cfg(any(feature = "native-tls", feature 
= "__rustls"))] use crate::Identity; use crate::{IntoUrl, Method, Proxy, StatusCode, Url}; +use bytes::Bytes; +use http::header::{ + Entry, HeaderMap, HeaderValue, ACCEPT, ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, + CONTENT_TYPE, LOCATION, PROXY_AUTHORIZATION, RANGE, REFERER, TRANSFER_ENCODING, USER_AGENT, +}; +use http::uri::Scheme; +use http::Uri; +use hyper_util::client::legacy::connect::HttpConnector; use log::debug; +#[cfg(feature = "default-tls")] +use native_tls_crate::TlsConnector; +use pin_project_lite::pin_project; #[cfg(feature = "http3")] use quinn::TransportConfig; #[cfg(feature = "http3")] use quinn::VarInt; +use tokio::time::Sleep; +use tower::util::BoxCloneSyncServiceLayer; +use tower::{Layer, Service}; type HyperResponseFuture = hyper_util::client::legacy::ResponseFuture; @@ -130,6 +134,7 @@ struct Config { tls_info: bool, #[cfg(feature = "__tls")] tls: TlsBackend, + connector_layers: Vec, http_version_pref: HttpVersionPref, http09_responses: bool, http1_title_case_headers: bool, @@ -185,7 +190,7 @@ impl ClientBuilder { /// Constructs a new `ClientBuilder`. /// /// This is the same as `Client::builder()`. - pub fn new() -> ClientBuilder { + pub fn new() -> Self { let mut headers: HeaderMap = HeaderMap::with_capacity(2); headers.insert(ACCEPT, HeaderValue::from_static("*/*")); @@ -233,6 +238,7 @@ impl ClientBuilder { tls_info: false, #[cfg(feature = "__tls")] tls: TlsBackend::default(), + connector_layers: Vec::new(), http_version_pref: HttpVersionPref::All, http09_responses: false, http1_title_case_headers: false, @@ -278,7 +284,9 @@ impl ClientBuilder { }, } } +} +impl ClientBuilder { /// Returns a `Client` that uses this `ClientBuilder` configuration. 
/// /// # Errors @@ -302,7 +310,7 @@ impl ClientBuilder { #[cfg(feature = "http3")] let mut h3_connector = None; - let mut connector = { + let mut connector_builder = { #[cfg(feature = "__tls")] fn user_agent(headers: &HeaderMap) -> Option { headers.get(USER_AGENT).cloned() @@ -445,7 +453,7 @@ impl ClientBuilder { tls.max_protocol_version(Some(protocol)); } - Connector::new_default_tls( + ConnectorBuilder::new_default_tls( http, tls, proxies.clone(), @@ -462,7 +470,7 @@ impl ClientBuilder { )? } #[cfg(feature = "native-tls")] - TlsBackend::BuiltNativeTls(conn) => Connector::from_built_default_tls( + TlsBackend::BuiltNativeTls(conn) => ConnectorBuilder::from_built_default_tls( http, conn, proxies.clone(), @@ -489,7 +497,7 @@ impl ClientBuilder { )?; } - Connector::new_rustls_tls( + ConnectorBuilder::new_rustls_tls( http, conn, proxies.clone(), @@ -684,7 +692,7 @@ impl ClientBuilder { )?; } - Connector::new_rustls_tls( + ConnectorBuilder::new_rustls_tls( http, tls, proxies.clone(), @@ -709,7 +717,7 @@ impl ClientBuilder { } #[cfg(not(feature = "__tls"))] - Connector::new( + ConnectorBuilder::new( http, proxies.clone(), config.local_address, @@ -719,8 +727,9 @@ impl ClientBuilder { ) }; - connector.set_timeout(config.connect_timeout); - connector.set_verbose(config.connection_verbose); + connector_builder.set_timeout(config.connect_timeout); + connector_builder.set_verbose(config.connection_verbose); + connector_builder.set_keepalive(config.tcp_keepalive); let mut builder = hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new()); @@ -763,7 +772,6 @@ impl ClientBuilder { builder.pool_timer(hyper_util::rt::TokioTimer::new()); builder.pool_idle_timeout(config.pool_idle_timeout); builder.pool_max_idle_per_host(config.pool_max_idle_per_host); - connector.set_keepalive(config.tcp_keepalive); if config.http09_responses { builder.http09_responses(true); @@ -801,7 +809,7 @@ impl ClientBuilder { } None => None, }, - hyper: builder.build(connector), + 
 hyper: builder.build(connector_builder.build(config.connector_layers)), headers: config.headers, redirect_policy: config.redirect_policy, referer: config.referer, @@ -1953,6 +1961,43 @@ impl ClientBuilder { self.config.quic_send_window = Some(value); self } + + /// Adds a new Tower [`Layer`](https://docs.rs/tower/latest/tower/trait.Layer.html) to the + /// base connector [`Service`](https://docs.rs/tower/latest/tower/trait.Service.html) which + /// is responsible for connection establishment. + /// + /// Each subsequent invocation of this function will wrap previous layers. + /// + /// If configured, the `connect_timeout` will be the outermost layer. + /// + /// Example usage: + /// ``` + /// use std::time::Duration; + /// + /// # #[cfg(not(feature = "rustls-tls-no-provider"))] + /// let client = reqwest::Client::builder() + /// // resolved to outermost layer, meaning while we are waiting on concurrency limit + /// .connect_timeout(Duration::from_millis(200)) + /// // underneath the concurrency check, so only after concurrency limit lets us through + /// .connector_layer(tower::timeout::TimeoutLayer::new(Duration::from_millis(50))) + /// .connector_layer(tower::limit::concurrency::ConcurrencyLimitLayer::new(2)) + /// .build() + /// .unwrap(); + /// ``` + /// + pub fn connector_layer(mut self, layer: L) -> ClientBuilder + where + L: Layer + Clone + Send + Sync + 'static, + L::Service: + Service + Clone + Send + Sync + 'static, + >::Future: Send + 'static, + { + let layer = BoxCloneSyncServiceLayer::new(layer); + + self.config.connector_layers.push(layer); + + self + } } type HyperClient = hyper_util::client::legacy::Client; diff --git a/src/blocking/client.rs b/src/blocking/client.rs index 73f25208f..700ce57a9 100644 --- a/src/blocking/client.rs +++ b/src/blocking/client.rs @@ -12,11 +12,16 @@ use std::time::Duration; use http::header::HeaderValue; use log::{error, trace}; use tokio::sync::{mpsc, oneshot}; +use tower::Layer; +use tower::Service; use 
super::request::{Request, RequestBuilder}; use super::response::Response; use super::wait; +use crate::connect::sealed::{Conn, Unnameable}; +use crate::connect::BoxedConnectorService; use crate::dns::Resolve; +use crate::error::BoxError; #[cfg(feature = "__tls")] use crate::tls; #[cfg(feature = "__rustls")] @@ -84,13 +89,15 @@ impl ClientBuilder { /// Constructs a new `ClientBuilder`. /// /// This is the same as `Client::builder()`. - pub fn new() -> ClientBuilder { + pub fn new() -> Self { ClientBuilder { inner: async_impl::ClientBuilder::new(), timeout: Timeout::default(), } } +} +impl ClientBuilder { /// Returns a `Client` that uses this `ClientBuilder` configuration. /// /// # Errors @@ -968,6 +975,35 @@ impl ClientBuilder { self.with_inner(|inner| inner.dns_resolver(resolver)) } + /// Adds a new Tower [`Layer`](https://docs.rs/tower/latest/tower/trait.Layer.html) to the + /// base connector [`Service`](https://docs.rs/tower/latest/tower/trait.Service.html) which + /// is responsible for connection establishment. + /// + /// Each subsequent invocation of this function will wrap previous layers. 
+ /// + /// Example usage: + /// ``` + /// use std::time::Duration; + /// + /// let client = reqwest::blocking::Client::builder() + /// // resolved to outermost layer, meaning while we are waiting on concurrency limit + /// .connect_timeout(Duration::from_millis(200)) + /// // underneath the concurrency check, so only after concurrency limit lets us through + /// .connector_layer(tower::timeout::TimeoutLayer::new(Duration::from_millis(50))) + /// .connector_layer(tower::limit::concurrency::ConcurrencyLimitLayer::new(2)) + /// .build() + /// .unwrap(); + /// ``` + pub fn connector_layer(self, layer: L) -> ClientBuilder + where + L: Layer + Clone + Send + Sync + 'static, + L::Service: + Service + Clone + Send + Sync + 'static, + >::Future: Send + 'static, + { + self.with_inner(|inner| inner.connector_layer(layer)) + } + // private fn with_inner(mut self, func: F) -> ClientBuilder diff --git a/src/connect.rs b/src/connect.rs index ff86ba3c9..c366473cc 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -8,9 +8,11 @@ use hyper_util::client::legacy::connect::{Connected, Connection}; use hyper_util::rt::TokioIo; #[cfg(feature = "default-tls")] use native_tls_crate::{TlsConnector, TlsConnectorBuilder}; +use pin_project_lite::pin_project; +use tower::util::{BoxCloneSyncServiceLayer, MapRequestLayer}; +use tower::{timeout::TimeoutLayer, util::BoxCloneSyncService, ServiceBuilder}; use tower_service::Service; -use pin_project_lite::pin_project; use std::future::Future; use std::io::{self, IoSlice}; use std::net::IpAddr; @@ -24,13 +26,47 @@ use self::native_tls_conn::NativeTlsConn; #[cfg(feature = "__rustls")] use self::rustls_tls_conn::RustlsTlsConn; use crate::dns::DynResolver; -use crate::error::BoxError; +use crate::error::{cast_to_internal_error, BoxError}; use crate::proxy::{Proxy, ProxyScheme}; +use sealed::{Conn, Unnameable}; pub(crate) type HttpConnector = hyper_util::client::legacy::connect::HttpConnector; #[derive(Clone)] -pub(crate) struct Connector { +pub(crate) 
enum Connector { + // base service, with or without an embedded timeout + Simple(ConnectorService), + // at least one custom layer along with maybe an outer timeout layer + // from `builder.connect_timeout()` + WithLayers(BoxCloneSyncService), +} + +impl Service for Connector { + type Response = Conn; + type Error = BoxError; + type Future = Connecting; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + match self { + Connector::Simple(service) => service.poll_ready(cx), + Connector::WithLayers(service) => service.poll_ready(cx), + } + } + + fn call(&mut self, dst: Uri) -> Self::Future { + match self { + Connector::Simple(service) => service.call(dst), + Connector::WithLayers(service) => service.call(Unnameable(dst)), + } + } +} + +pub(crate) type BoxedConnectorService = BoxCloneSyncService; + +pub(crate) type BoxedConnectorLayer = + BoxCloneSyncServiceLayer; + +pub(crate) struct ConnectorBuilder { inner: Inner, proxies: Arc>, verbose: verbose::Wrapper, @@ -43,21 +79,70 @@ pub(crate) struct Connector { user_agent: Option, } -#[derive(Clone)] -enum Inner { - #[cfg(not(feature = "__tls"))] - Http(HttpConnector), - #[cfg(feature = "default-tls")] - DefaultTls(HttpConnector, TlsConnector), - #[cfg(feature = "__rustls")] - RustlsTls { - http: HttpConnector, - tls: Arc, - tls_proxy: Arc, - }, -} +impl ConnectorBuilder { + pub(crate) fn build(self, layers: Vec) -> Connector +where { + // construct the inner tower service + let mut base_service = ConnectorService { + inner: self.inner, + proxies: self.proxies, + verbose: self.verbose, + #[cfg(feature = "__tls")] + nodelay: self.nodelay, + #[cfg(feature = "__tls")] + tls_info: self.tls_info, + #[cfg(feature = "__tls")] + user_agent: self.user_agent, + simple_timeout: None, + }; + + if layers.is_empty() { + // we have no user-provided layers, only use concrete types + base_service.simple_timeout = self.timeout; + return Connector::Simple(base_service); + } + + // otherwise we have user provided layers + // so we 
need type erasure all the way through + // as well as mapping the unnameable type of the layers back to Uri for the inner service + let unnameable_service = ServiceBuilder::new() + .layer(MapRequestLayer::new(|request: Unnameable| request.0)) + .service(base_service); + let mut service = BoxCloneSyncService::new(unnameable_service); + + for layer in layers { + service = ServiceBuilder::new().layer(layer).service(service); + } + + // now we handle the concrete stuff - any `connect_timeout`, + // plus a final map_err layer we can use to cast default tower layer + // errors to internal errors + match self.timeout { + Some(timeout) => { + let service = ServiceBuilder::new() + .layer(TimeoutLayer::new(timeout)) + .service(service); + let service = ServiceBuilder::new() + .map_err(|error: BoxError| cast_to_internal_error(error)) + .service(service); + let service = BoxCloneSyncService::new(service); + + Connector::WithLayers(service) + } + None => { + // no timeout, but still map err + // no named timeout layer but we still map errors since + // we might have user-provided timeout layer + let service = ServiceBuilder::new().service(service); + let service = ServiceBuilder::new() + .map_err(|error: BoxError| cast_to_internal_error(error)) + .service(service); + let service = BoxCloneSyncService::new(service); + Connector::WithLayers(service) + } + } + } -impl Connector { #[cfg(not(feature = "__tls"))] pub(crate) fn new( mut http: HttpConnector, @@ -66,7 +151,7 @@ impl Connector { #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] interface: Option<&str>, nodelay: bool, - ) -> Connector + ) -> ConnectorBuilder where T: Into>, { @@ -77,10 +162,10 @@ impl Connector { } http.set_nodelay(nodelay); - Connector { + ConnectorBuilder { inner: Inner::Http(http), - verbose: verbose::OFF, proxies, + verbose: verbose::OFF, timeout: None, } } @@ -96,7 +181,7 @@ impl Connector { interface: Option<&str>, nodelay: bool, tls_info: bool, - ) -> crate::Result + ) 
-> crate::Result where T: Into>, { @@ -125,7 +210,7 @@ impl Connector { interface: Option<&str>, nodelay: bool, tls_info: bool, - ) -> Connector + ) -> ConnectorBuilder where T: Into>, { @@ -137,14 +222,14 @@ impl Connector { http.set_nodelay(nodelay); http.enforce_http(false); - Connector { + ConnectorBuilder { inner: Inner::DefaultTls(http, tls), proxies, verbose: verbose::OFF, - timeout: None, nodelay, tls_info, user_agent, + timeout: None, } } @@ -159,7 +244,7 @@ impl Connector { interface: Option<&str>, nodelay: bool, tls_info: bool, - ) -> Connector + ) -> ConnectorBuilder where T: Into>, { @@ -180,7 +265,7 @@ impl Connector { (Arc::new(tls), Arc::new(tls_proxy)) }; - Connector { + ConnectorBuilder { inner: Inner::RustlsTls { http, tls, @@ -188,10 +273,10 @@ impl Connector { }, proxies, verbose: verbose::OFF, - timeout: None, nodelay, tls_info, user_agent, + timeout: None, } } @@ -203,6 +288,52 @@ impl Connector { self.verbose.0 = enabled; } + pub(crate) fn set_keepalive(&mut self, dur: Option) { + match &mut self.inner { + #[cfg(feature = "default-tls")] + Inner::DefaultTls(http, _tls) => http.set_keepalive(dur), + #[cfg(feature = "__rustls")] + Inner::RustlsTls { http, .. } => http.set_keepalive(dur), + #[cfg(not(feature = "__tls"))] + Inner::Http(http) => http.set_keepalive(dur), + } + } +} + +#[allow(missing_debug_implementations)] +#[derive(Clone)] +pub(crate) struct ConnectorService { + inner: Inner, + proxies: Arc>, + verbose: verbose::Wrapper, + /// When there is a single timeout layer and no other layers, + /// we embed it directly inside our base Service::call(). 
+ /// This lets us avoid an extra `Box::pin` indirection layer + /// since `tokio::time::Timeout` is `Unpin` + simple_timeout: Option, + #[cfg(feature = "__tls")] + nodelay: bool, + #[cfg(feature = "__tls")] + tls_info: bool, + #[cfg(feature = "__tls")] + user_agent: Option, +} + +#[derive(Clone)] +enum Inner { + #[cfg(not(feature = "__tls"))] + Http(HttpConnector), + #[cfg(feature = "default-tls")] + DefaultTls(HttpConnector, TlsConnector), + #[cfg(feature = "__rustls")] + RustlsTls { + http: HttpConnector, + tls: Arc, + tls_proxy: Arc, + }, +} + +impl ConnectorService { #[cfg(feature = "socks")] async fn connect_socks(&self, dst: Uri, proxy: ProxyScheme) -> Result { let dns = match proxy { @@ -449,17 +580,6 @@ impl Connector { self.connect_with_maybe_proxy(proxy_dst, true).await } - - pub fn set_keepalive(&mut self, dur: Option) { - match &mut self.inner { - #[cfg(feature = "default-tls")] - Inner::DefaultTls(http, _tls) => http.set_keepalive(dur), - #[cfg(feature = "__rustls")] - Inner::RustlsTls { http, .. } => http.set_keepalive(dur), - #[cfg(not(feature = "__tls"))] - Inner::Http(http) => http.set_keepalive(dur), - } - } } fn into_uri(scheme: Scheme, host: Authority) -> Uri { @@ -487,7 +607,7 @@ where } } -impl Service for Connector { +impl Service for ConnectorService { type Response = Conn; type Error = BoxError; type Future = Connecting; @@ -498,7 +618,7 @@ impl Service for Connector { fn call(&mut self, dst: Uri) -> Self::Future { log::debug!("starting new connection: {dst:?}"); - let timeout = self.timeout; + let timeout = self.simple_timeout; for prox in self.proxies.iter() { if let Some(proxy_scheme) = prox.intercept(&dst) { return Box::pin(with_timeout( @@ -633,80 +753,87 @@ impl AsyncConnWithInfo for T {} type BoxConn = Box; -pin_project! { - /// Note: the `is_proxy` member means *is plain text HTTP proxy*. 
- /// This tells hyper whether the URI should be written in - /// * origin-form (`GET /just/a/path HTTP/1.1`), when `is_proxy == false`, or - /// * absolute-form (`GET http://foo.bar/and/a/path HTTP/1.1`), otherwise. - pub(crate) struct Conn { - #[pin] - inner: BoxConn, - is_proxy: bool, - // Only needed for __tls, but #[cfg()] on fields breaks pin_project! - tls_info: bool, +pub(crate) mod sealed { + use super::*; + #[derive(Debug, Clone)] + pub struct Unnameable(pub(super) Uri); + + pin_project! { + /// Note: the `is_proxy` member means *is plain text HTTP proxy*. + /// This tells hyper whether the URI should be written in + /// * origin-form (`GET /just/a/path HTTP/1.1`), when `is_proxy == false`, or + /// * absolute-form (`GET http://foo.bar/and/a/path HTTP/1.1`), otherwise. + #[allow(missing_debug_implementations)] + pub struct Conn { + #[pin] + pub(super)inner: BoxConn, + pub(super) is_proxy: bool, + // Only needed for __tls, but #[cfg()] on fields breaks pin_project! + pub(super) tls_info: bool, + } } -} -impl Connection for Conn { - fn connected(&self) -> Connected { - let connected = self.inner.connected().proxy(self.is_proxy); - #[cfg(feature = "__tls")] - if self.tls_info { - if let Some(tls_info) = self.inner.tls_info() { - connected.extra(tls_info) + impl Connection for Conn { + fn connected(&self) -> Connected { + let connected = self.inner.connected().proxy(self.is_proxy); + #[cfg(feature = "__tls")] + if self.tls_info { + if let Some(tls_info) = self.inner.tls_info() { + connected.extra(tls_info) + } else { + connected + } } else { connected } - } else { + #[cfg(not(feature = "__tls"))] connected } - #[cfg(not(feature = "__tls"))] - connected } -} -impl Read for Conn { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context, - buf: ReadBufCursor<'_>, - ) -> Poll> { - let this = self.project(); - Read::poll_read(this.inner, cx, buf) + impl Read for Conn { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context, + buf: ReadBufCursor<'_>, + ) -> 
Poll> { + let this = self.project(); + Read::poll_read(this.inner, cx, buf) + } } -} -impl Write for Conn { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context, - buf: &[u8], - ) -> Poll> { - let this = self.project(); - Write::poll_write(this.inner, cx, buf) - } + impl Write for Conn { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &[u8], + ) -> Poll> { + let this = self.project(); + Write::poll_write(this.inner, cx, buf) + } - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { - let this = self.project(); - Write::poll_write_vectored(this.inner, cx, bufs) - } + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + let this = self.project(); + Write::poll_write_vectored(this.inner, cx, bufs) + } - fn is_write_vectored(&self) -> bool { - self.inner.is_write_vectored() - } + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - let this = self.project(); - Write::poll_flush(this.inner, cx) - } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let this = self.project(); + Write::poll_flush(this.inner, cx) + } - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - let this = self.project(); - Write::poll_shutdown(this.inner, cx) + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let this = self.project(); + Write::poll_shutdown(this.inner, cx) + } } } diff --git a/src/error.rs b/src/error.rs index ca7413fd6..6a9f07e51 100644 --- a/src/error.rs +++ b/src/error.rs @@ -165,6 +165,18 @@ impl Error { } } +/// Converts from external types to reqwest's +/// internal equivalents. +/// +/// Currently only is used for `tower::timeout::error::Elapsed`. 
+pub(crate) fn cast_to_internal_error(error: BoxError) -> BoxError { + if error.is::() { + Box::new(crate::error::TimedOut) as BoxError + } else { + error + } +} + impl fmt::Debug for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let mut builder = f.debug_struct("reqwest::Error"); diff --git a/tests/connector_layers.rs b/tests/connector_layers.rs new file mode 100644 index 000000000..1be18aeb8 --- /dev/null +++ b/tests/connector_layers.rs @@ -0,0 +1,374 @@ +#![cfg(not(target_arch = "wasm32"))] +#![cfg(not(feature = "rustls-tls-manual-roots-no-provider"))] +mod support; + +use std::time::Duration; + +use futures_util::future::join_all; +use tower::layer::util::Identity; +use tower::limit::ConcurrencyLimitLayer; +use tower::timeout::TimeoutLayer; + +use support::{delay_layer::DelayLayer, server}; + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn non_op_layer() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::builder() + .connector_layer(Identity::new()) + .no_proxy() + .build() + .unwrap(); + + let res = client.get(url).send().await; + + assert!(res.is_ok()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn non_op_layer_with_timeout() { + let _ = env_logger::try_init(); + + let client = reqwest::Client::builder() + .connector_layer(Identity::new()) + .connect_timeout(Duration::from_millis(200)) + .no_proxy() + .build() + .unwrap(); + + // never returns + let url = "http://192.0.2.1:81/slow"; + + let res = client.get(url).send().await; + + let err = res.unwrap_err(); + + assert!(err.is_connect() && err.is_timeout()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn with_connect_timeout_layer_never_returning() { + let _ = env_logger::try_init(); + + let client = reqwest::Client::builder() + 
.connector_layer(TimeoutLayer::new(Duration::from_millis(100))) + .no_proxy() + .build() + .unwrap(); + + // never returns + let url = "http://192.0.2.1:81/slow"; + + let res = client.get(url).send().await; + + let err = res.unwrap_err(); + + assert!(err.is_connect() && err.is_timeout()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn with_connect_timeout_layer_slow() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(200))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(100))) + .no_proxy() + .build() + .unwrap(); + + let res = client.get(url).send().await; + + let err = res.unwrap_err(); + + assert!(err.is_connect() && err.is_timeout()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn multiple_timeout_layers_under_threshold() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(200))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(300))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(500))) + .connect_timeout(Duration::from_millis(200)) + .no_proxy() + .build() + .unwrap(); + + let res = client.get(url).send().await; + + assert!(res.is_ok()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn multiple_timeout_layers_over_threshold() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::builder() + 
.connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(50))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(50))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(50))) + .connect_timeout(Duration::from_millis(50)) + .no_proxy() + .build() + .unwrap(); + + let res = client.get(url).send().await; + + let err = res.unwrap_err(); + + assert!(err.is_connect() && err.is_timeout()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn with_concurrency_limit_layer_timeout() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(ConcurrencyLimitLayer::new(1)) + .timeout(Duration::from_millis(200)) + .pool_max_idle_per_host(0) // disable connection reuse to force resource contention on the concurrency limit semaphore + .no_proxy() + .build() + .unwrap(); + + // first call succeeds since no resource contention + let res = client.get(url.clone()).send().await; + assert!(res.is_ok()); + + // 3 calls where the second two wait on the first and time out + let mut futures = Vec::new(); + for _ in 0..3 { + futures.push(client.clone().get(url.clone()).send()); + } + + let all_res = join_all(futures).await; + + let timed_out = all_res + .into_iter() + .any(|res| res.is_err_and(|err| err.is_timeout())); + + assert!(timed_out, "at least one request should have timed out"); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn with_concurrency_limit_layer_success() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + 
.connector_layer(TimeoutLayer::new(Duration::from_millis(200))) + .connector_layer(ConcurrencyLimitLayer::new(1)) + .timeout(Duration::from_millis(1000)) + .pool_max_idle_per_host(0) // disable connection reuse to force resource contention on the concurrency limit semaphore + .no_proxy() + .build() + .unwrap(); + + // first call succeeds since no resource contention + let res = client.get(url.clone()).send().await; + assert!(res.is_ok()); + + // 3 calls of which all are individually below the inner timeout + // and the sum is below outer timeout which affects the final call which waited the whole time + let mut futures = Vec::new(); + for _ in 0..3 { + futures.push(client.clone().get(url.clone()).send()); + } + + let all_res = join_all(futures).await; + + for res in all_res.into_iter() { + assert!( + res.is_ok(), + "neither outer long timeout or inner short timeout should be exceeded" + ); + } +} + +#[cfg(feature = "blocking")] +#[test] +fn non_op_layer_blocking_client() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::blocking::Client::builder() + .connector_layer(Identity::new()) + .build() + .unwrap(); + + let res = client.get(url).send(); + + assert!(res.is_ok()); +} + +#[cfg(feature = "blocking")] +#[test] +fn timeout_layer_blocking_client() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::blocking::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(50))) + .no_proxy() + .build() + .unwrap(); + + let res = client.get(url).send(); + let err = res.unwrap_err(); + + assert!(err.is_connect() && err.is_timeout()); +} + +#[cfg(feature = "blocking")] +#[test] +fn 
concurrency_layer_blocking_client_timeout() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::blocking::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(ConcurrencyLimitLayer::new(1)) + .timeout(Duration::from_millis(200)) + .pool_max_idle_per_host(0) // disable connection reuse to force resource contention on the concurrency limit semaphore + .build() + .unwrap(); + + let res = client.get(url.clone()).send(); + + assert!(res.is_ok()); + + // 3 calls where the second two wait on the first and time out + let mut join_handles = Vec::new(); + for _ in 0..3 { + let client = client.clone(); + let url = url.clone(); + let join_handle = std::thread::spawn(move || client.get(url.clone()).send()); + join_handles.push(join_handle); + } + + let timed_out = join_handles + .into_iter() + .any(|handle| handle.join().unwrap().is_err_and(|err| err.is_timeout())); + + assert!(timed_out, "at least one request should have timed out"); +} + +#[cfg(feature = "blocking")] +#[test] +fn concurrency_layer_blocking_client_success() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::blocking::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(200))) + .connector_layer(ConcurrencyLimitLayer::new(1)) + .timeout(Duration::from_millis(1000)) + .pool_max_idle_per_host(0) // disable connection reuse to force resource contention on the concurrency limit semaphore + .build() + .unwrap(); + + let res = client.get(url.clone()).send(); + + assert!(res.is_ok()); + + // 3 calls of which all are individually below the inner timeout + // and the sum is below outer timeout which affects 
the final call which waited the whole time + let mut join_handles = Vec::new(); + for _ in 0..3 { + let client = client.clone(); + let url = url.clone(); + let join_handle = std::thread::spawn(move || client.get(url.clone()).send()); + join_handles.push(join_handle); + } + + for handle in join_handles { + let res = handle.join().unwrap(); + assert!( + res.is_ok(), + "neither outer long timeout or inner short timeout should be exceeded" + ); + } +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn no_generic_bounds_required_for_client_new() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::new(); + let res = client.get(url).send().await; + + assert!(res.is_ok()); +} + +#[cfg(feature = "blocking")] +#[test] +fn no_generic_bounds_required_for_client_new_blocking() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::blocking::Client::new(); + let res = client.get(url).send(); + + assert!(res.is_ok()); +} diff --git a/tests/support/delay_layer.rs b/tests/support/delay_layer.rs new file mode 100644 index 000000000..b8eec42a1 --- /dev/null +++ b/tests/support/delay_layer.rs @@ -0,0 +1,119 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; + +use pin_project_lite::pin_project; +use tokio::time::Sleep; +use tower::{BoxError, Layer, Service}; + +/// This tower layer injects an arbitrary delay before calling downstream layers. 
+#[derive(Clone)] +pub struct DelayLayer { + delay: Duration, +} + +impl DelayLayer { + pub const fn new(delay: Duration) -> Self { + DelayLayer { delay } + } +} + +impl Layer for DelayLayer { + type Service = Delay; + fn layer(&self, service: S) -> Self::Service { + Delay::new(service, self.delay) + } +} + +impl std::fmt::Debug for DelayLayer { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("DelayLayer") + .field("delay", &self.delay) + .finish() + } +} + +/// This tower service injects an arbitrary delay before calling downstream layers. +#[derive(Debug, Clone)] +pub struct Delay { + inner: S, + delay: Duration, +} +impl Delay { + pub fn new(inner: S, delay: Duration) -> Self { + Delay { inner, delay } + } +} + +impl Service for Delay +where + S: Service, + S::Error: Into, +{ + type Response = S::Response; + + type Error = BoxError; + + type Future = ResponseFuture; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.inner.poll_ready(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(r) => Poll::Ready(r.map_err(Into::into)), + } + } + + fn call(&mut self, req: Request) -> Self::Future { + let response = self.inner.call(req); + let sleep = tokio::time::sleep(self.delay); + + ResponseFuture::new(response, sleep) + } +} + +// `Delay` response future +pin_project! 
{ + #[derive(Debug)] + pub struct ResponseFuture { + #[pin] + response: S, + #[pin] + sleep: Sleep, + } +} + +impl ResponseFuture { + pub(crate) fn new(response: S, sleep: Sleep) -> Self { + ResponseFuture { response, sleep } + } +} + +impl Future for ResponseFuture +where + F: Future>, + E: Into, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + // First poll the sleep until complete + match this.sleep.poll(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(_) => {} + } + + // Then poll the inner future + match this.response.poll(cx) { + Poll::Ready(v) => Poll::Ready(v.map_err(Into::into)), + Poll::Pending => Poll::Pending, + } + } +} diff --git a/tests/support/mod.rs b/tests/support/mod.rs index c796956d8..9d4ce7b9b 100644 --- a/tests/support/mod.rs +++ b/tests/support/mod.rs @@ -1,3 +1,4 @@ +pub mod delay_layer; pub mod delay_server; pub mod server; diff --git a/tests/timeouts.rs b/tests/timeouts.rs index 79a6fbb4d..71dc0ce66 100644 --- a/tests/timeouts.rs +++ b/tests/timeouts.rs @@ -337,6 +337,24 @@ fn timeout_blocking_request() { assert_eq!(err.url().map(|u| u.as_str()), Some(url.as_str())); } +#[cfg(feature = "blocking")] +#[test] +fn connect_timeout_blocking_request() { + let _ = env_logger::try_init(); + + let client = reqwest::blocking::Client::builder() + .connect_timeout(Duration::from_millis(100)) + .build() + .unwrap(); + + // never returns + let url = "http://192.0.2.1:81/slow"; + + let err = client.get(url).send().unwrap_err(); + + assert!(err.is_timeout()); +} + #[cfg(feature = "blocking")] #[cfg(feature = "stream")] #[test] From 44ca5ee864ebff81e987263de74be7002fc6b353 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Mon, 23 Dec 2024 10:14:10 -0500 Subject: [PATCH 5/6] remove Clone from connect::Unnameable for now (#2502) --- src/connect.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/connect.rs b/src/connect.rs index 
c366473cc..dfaf028a9 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -755,7 +755,7 @@ type BoxConn = Box; pub(crate) mod sealed { use super::*; - #[derive(Debug, Clone)] + #[derive(Debug)] pub struct Unnameable(pub(super) Uri); pin_project! { From 3ce98b5f2288637e22dad98e881c210246567021 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Mon, 23 Dec 2024 15:19:17 -0500 Subject: [PATCH 6/6] fix: propagate Body::size_hint when wrapping bodies (#2503) --- src/async_impl/body.rs | 53 ++++++++++++++++++++++++++++++++++++------ tests/client.rs | 2 +- 2 files changed, 47 insertions(+), 8 deletions(-) diff --git a/src/async_impl/body.rs b/src/async_impl/body.rs index c2f1257c1..454046dd0 100644 --- a/src/async_impl/body.rs +++ b/src/async_impl/body.rs @@ -148,10 +148,7 @@ impl Body { { use http_body_util::BodyExt; - let boxed = inner - .map_frame(|f| f.map_data(Into::into)) - .map_err(Into::into) - .boxed(); + let boxed = IntoBytesBody { inner }.map_err(Into::into).boxed(); Body { inner: Inner::Streaming(boxed), @@ -461,6 +458,47 @@ where } } +// ===== impl IntoBytesBody ===== + +pin_project! { + struct IntoBytesBody { + #[pin] + inner: B, + } +} + +// We can't use `map_frame()` because that loses the hint data (for good reason). +// But we aren't transforming the data. 
+impl hyper::body::Body for IntoBytesBody +where + B: hyper::body::Body, + B::Data: Into, +{ + type Data = Bytes; + type Error = B::Error; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll, Self::Error>>> { + match futures_core::ready!(self.project().inner.poll_frame(cx)) { + Some(Ok(f)) => Poll::Ready(Some(Ok(f.map_data(Into::into)))), + Some(Err(e)) => Poll::Ready(Some(Err(e))), + None => Poll::Ready(None), + } + } + + #[inline] + fn size_hint(&self) -> http_body::SizeHint { + self.inner.size_hint() + } + + #[inline] + fn is_end_stream(&self) -> bool { + self.inner.is_end_stream() + } +} + #[cfg(test)] mod tests { use http_body::Body as _; @@ -484,8 +522,9 @@ mod tests { assert!(!bytes_body.is_end_stream()); assert_eq!(bytes_body.size_hint().exact(), Some(3)); - let stream_body = Body::wrap(bytes_body); - assert!(!stream_body.is_end_stream()); - assert_eq!(stream_body.size_hint().exact(), None); + // can delegate even when wrapped + let stream_body = Body::wrap(empty_body); + assert!(stream_body.is_end_stream()); + assert_eq!(stream_body.size_hint().exact(), Some(0)); } } diff --git a/tests/client.rs b/tests/client.rs index 51fb9dfa0..f99418322 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -197,7 +197,7 @@ async fn body_pipe_response() { http::Response::new("pipe me".into()) } else { assert_eq!(req.uri(), "/pipe"); - assert_eq!(req.headers()["transfer-encoding"], "chunked"); + assert_eq!(req.headers()["content-length"], "7"); let full: Vec = req .into_body()