diff --git a/Cargo.toml b/Cargo.toml index 39ff48424..1a0c4abf6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ authors = ["Sean McArthur "] readme = "README.md" license = "MIT OR Apache-2.0" edition = "2021" -rust-version = "1.63.0" +rust-version = "1.64.0" autotests = true [package.metadata.docs.rs] @@ -105,6 +105,7 @@ url = "2.4" bytes = "1.0" serde = "1.0" serde_urlencoded = "0.7.1" +tower = { version = "0.5.2", default-features = false, features = ["timeout", "util"] } tower-service = "0.3" futures-core = { version = "0.3.28", default-features = false } futures-util = { version = "0.3.28", default-features = false } @@ -169,7 +170,6 @@ quinn = { version = "0.11.1", default-features = false, features = ["rustls", "r slab = { version = "0.4.9", optional = true } # just to get minimal versions working with quinn futures-channel = { version = "0.3", optional = true } - [target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] env_logger = "0.10" hyper = { version = "1.1.0", default-features = false, features = ["http1", "http2", "client", "server"] } @@ -222,6 +222,11 @@ features = [ wasm-bindgen = { version = "0.2.89", features = ["serde-serialize"] } wasm-bindgen-test = "0.3" +[dev-dependencies] +tower = { version = "0.5.2", default-features = false, features = ["limit"] } +num_cpus = "1.0" +libc = "0" + [lints.rust] unexpected_cfgs = { level = "warn", check-cfg = ['cfg(reqwest_unstable)'] } @@ -253,6 +258,10 @@ path = "examples/form.rs" name = "simple" path = "examples/simple.rs" +[[example]] +name = "connect_via_lower_priority_tokio_runtime" +path = "examples/connect_via_lower_priority_tokio_runtime.rs" + [[test]] name = "blocking" path = "tests/blocking.rs" diff --git a/examples/connect_via_lower_priority_tokio_runtime.rs b/examples/connect_via_lower_priority_tokio_runtime.rs new file mode 100644 index 000000000..0567a6df7 --- /dev/null +++ b/examples/connect_via_lower_priority_tokio_runtime.rs @@ -0,0 +1,274 @@ +#![deny(warnings)] +// This example demonstrates how to delegate the connect calls, which contain TLS handshakes, +// to a secondary tokio runtime of lower OS thread priority using a custom tower layer. +// This helps to ensure that long-running futures during handshake crypto operations don't block other I/O futures. +// +// This does introduce overhead of additional threads, channels, extra vtables, etc, +// so it is best suited to services with large numbers of incoming connections or that +// are otherwise very sensitive to any blocking futures. Or, you might want fewer threads +// and/or to use the current_thread runtime. +// +// This is using the `tokio` runtime and certain other dependencies: +// +// `tokio = { version = "1", features = ["full"] }` +// `num_cpus = "1.0"` +// `libc = "0"` +// `pin-project-lite = "0.2"` +// `tower = { version = "0.5", default-features = false}` + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::main] +async fn main() -> Result<(), reqwest::Error> { + background_threadpool::init_background_runtime(); + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + + let client = reqwest::Client::builder() + .connector_layer(background_threadpool::BackgroundProcessorLayer::::new()) + .build() + .expect("should be able to build reqwest client"); + + let url = if let Some(url) = std::env::args().nth(1) { + url + } else { + println!("No CLI URL provided, using default."); + "https://hyper.rs".into() + }; + + eprintln!("Fetching {url:?}..."); + + let res = client.get(url).send().await?; + + eprintln!("Response: {:?} {}", res.version(), res.status()); + eprintln!("Headers: {:#?}\n", res.headers()); + + let body = res.text().await?; + + println!("{body}"); + + Ok(()) +} + +// separating out for convenience to avoid a million #[cfg(not(target_arch = "wasm32"))] +#[cfg(not(target_arch = "wasm32"))] +mod background_threadpool { + use std::{ + future::Future, + marker::PhantomData, + pin::Pin, + sync::OnceLock, + task::{Context, Poll}, + }; + + use futures_util::TryFutureExt; + use pin_project_lite::pin_project; + use tokio::{runtime::Handle, select, sync::mpsc::error::TrySendError}; + use tower::{BoxError, Layer, Service}; + + static CPU_HEAVY_THREAD_POOL: OnceLock< + tokio::sync::mpsc::Sender + Send + 'static>>>, + > = OnceLock::new(); + + pub(crate) fn init_background_runtime() { + std::thread::Builder::new() + .name("cpu-heavy-background-threadpool".to_string()) + .spawn(move || { + let rt = tokio::runtime::Builder::new_multi_thread() + .thread_name("cpu-heavy-background-pool-thread") + .worker_threads(num_cpus::get() as usize) + // ref: https://github.com/tokio-rs/tokio/issues/4941 + // consider uncommenting if seeing heavy task contention + // .disable_lifo_slot() + .on_thread_start(move || { + #[cfg(target_os = "linux")] + unsafe { + // Increase thread pool thread niceness, so they are lower priority + // than the foreground executor and don't interfere with I/O tasks + { + *libc::__errno_location() = 0; + if libc::nice(10) == -1 && *libc::__errno_location() != 0 { + let error = std::io::Error::last_os_error(); + log::error!("failed to set threadpool niceness: {}", error); + } + } + } + }) + .enable_all() + .build() + .unwrap_or_else(|e| panic!("cpu heavy runtime failed_to_initialize: {}", e)); + rt.block_on(async { + log::debug!("starting background cpu-heavy work"); + process_cpu_work().await; + }); + }) + .unwrap_or_else(|e| panic!("cpu heavy thread failed_to_initialize: {}", e)); + } + + #[cfg(not(target_arch = "wasm32"))] + async fn process_cpu_work() { + // we only use this channel for routing work, it should move pretty quick, it can be small + let (tx, mut rx) = tokio::sync::mpsc::channel(10); + // share the handle to the background channel globally + CPU_HEAVY_THREAD_POOL.set(tx).unwrap(); + + while let Some(work) = rx.recv().await { + tokio::task::spawn(work); + } + } + + // retrieve the sender to the background channel, and send the future over to it for execution + fn send_to_background_runtime(future: impl Future + Send + 'static) { + let tx = CPU_HEAVY_THREAD_POOL.get().expect( + "start up the secondary tokio runtime before sending to `CPU_HEAVY_THREAD_POOL`", + ); + + match tx.try_send(Box::pin(future)) { + Ok(_) => (), + Err(TrySendError::Closed(_)) => { + panic!("background cpu heavy runtime channel is closed") + } + Err(TrySendError::Full(msg)) => { + log::warn!( + "background cpu heavy runtime channel is full, task spawning loop delayed" + ); + let tx = tx.clone(); + Handle::current().spawn(async move { + tx.send(msg) + .await + .expect("background cpu heavy runtime channel is closed") + }); + } + } + } + + // This tower layer injects futures with a oneshot channel, and then sends them to the background runtime for processing. + // We don't use the Buffer service because that is intended to process sequentially on a single task, whereas we want to + // spawn a new task per call. + pub struct BackgroundProcessorLayer { + _p: PhantomData, + } + impl BackgroundProcessorLayer { + pub fn new() -> Self { + Self { _p: PhantomData } + } + } + impl Layer for BackgroundProcessorLayer { + type Service = BackgroundProcessor; + fn layer(&self, service: S) -> Self::Service { + BackgroundProcessor::new(service) + } + } + + impl std::fmt::Debug for BackgroundProcessorLayer { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("BackgroundProcessorLayer").finish() + } + } + + impl Clone for BackgroundProcessorLayer { + fn clone(&self) -> Self { + Self { _p: PhantomData } + } + } + + impl Copy for BackgroundProcessorLayer {} + + // This tower service injects futures with a oneshot channel, and then sends them to the background runtime for processing. + #[derive(Debug, Clone)] + pub struct BackgroundProcessor { + inner: S, + } + + impl BackgroundProcessor { + pub fn new(inner: S) -> Self { + BackgroundProcessor { inner } + } + } + + impl Service for BackgroundProcessor + where + S: Service, + S::Response: Send + 'static, + S::Error: Into + Send, + S::Future: Send + 'static, + { + type Response = S::Response; + + type Error = BoxError; + + type Future = BackgroundResponseFuture; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.inner.poll_ready(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(r) => Poll::Ready(r.map_err(Into::into)), + } + } + + fn call(&mut self, req: Request) -> Self::Future { + let response = self.inner.call(req); + + // wrap our inner service's future with a future that writes to this oneshot channel + let (mut tx, rx) = tokio::sync::oneshot::channel(); + let future = async move { + select!( + _ = tx.closed() => { + // receiver already dropped, don't need to do anything + } + result = response.map_err(|err| Into::::into(err)) => { + // if this fails, the receiver already dropped, so we don't need to do anything + let _ = tx.send(result); + } + ) + }; + // send the wrapped future to the background + send_to_background_runtime(future); + + BackgroundResponseFuture::new(rx) + } + } + + // `BackgroundProcessor` response future + pin_project! { + #[derive(Debug)] + pub struct BackgroundResponseFuture { + #[pin] + rx: tokio::sync::oneshot::Receiver>, + } + } + + impl BackgroundResponseFuture { + pub(crate) fn new(rx: tokio::sync::oneshot::Receiver>) -> Self { + BackgroundResponseFuture { rx } + } + } + + impl Future for BackgroundResponseFuture + where + S: Send + 'static, + { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + // now poll on the receiver end of the oneshot to get the result + match this.rx.poll(cx) { + Poll::Ready(v) => match v { + Ok(v) => Poll::Ready(v.map_err(Into::into)), + Err(err) => Poll::Ready(Err(Box::new(err) as BoxError)), + }, + Poll::Pending => Poll::Pending, + } + } + } +} + +// The [cfg(not(target_arch = "wasm32"))] above prevent building the tokio::main function +// for wasm32 target, because tokio isn't compatible with wasm32. +// If you aren't building for wasm32, you don't need that line. +// The two lines below avoid the "'main' function not found" error when building for wasm32 target. +#[cfg(any(target_arch = "wasm32"))] +fn main() {} diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index 579050041..bc0e518dd 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -1,27 +1,14 @@ #[cfg(any(feature = "native-tls", feature = "__rustls",))] use std::any::Any; +use std::future::Future; use std::net::IpAddr; +use std::pin::Pin; use std::sync::Arc; +use std::task::{Context, Poll}; use std::time::Duration; use std::{collections::HashMap, convert::TryInto, net::SocketAddr}; use std::{fmt, str}; -use bytes::Bytes; -use http::header::{ - Entry, HeaderMap, HeaderValue, ACCEPT, ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, - CONTENT_TYPE, LOCATION, PROXY_AUTHORIZATION, RANGE, REFERER, TRANSFER_ENCODING, USER_AGENT, -}; -use http::uri::Scheme; -use http::Uri; -use hyper_util::client::legacy::connect::HttpConnector; -#[cfg(feature = "default-tls")] -use native_tls_crate::TlsConnector; -use pin_project_lite::pin_project; -use std::future::Future; -use std::pin::Pin; -use std::task::{Context, Poll}; -use tokio::time::Sleep; - use super::decoder::Accepts; use super::request::{Request, RequestBuilder}; use super::response::Response; @@ -30,13 +17,13 @@ use super::Body; use crate::async_impl::h3_client::connect::H3Connector; #[cfg(feature = "http3")] use crate::async_impl::h3_client::{H3Client, H3ResponseFuture}; -use crate::connect::Connector; +use crate::connect::{Conn, Connector, ConnectorBuilder, ConnectorLayerBuilder, ConnectorService}; #[cfg(feature = "cookies")] use crate::cookie; #[cfg(feature = "hickory-dns")] use crate::dns::hickory::HickoryDnsResolver; use crate::dns::{gai::GaiResolver, DnsResolverWithOverrides, DynResolver, Resolve}; -use crate::error; +use crate::error::{self, BoxError}; use crate::into_url::try_uri; use crate::redirect::{self, remove_sensitive_headers}; #[cfg(feature = "__rustls")] @@ -48,11 +35,24 @@ use crate::Certificate; #[cfg(any(feature = "native-tls", feature = "__rustls"))] use crate::Identity; use crate::{IntoUrl, Method, Proxy, StatusCode, Url}; +use bytes::Bytes; +use http::header::{ + Entry, HeaderMap, HeaderValue, ACCEPT, ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, + CONTENT_TYPE, LOCATION, PROXY_AUTHORIZATION, RANGE, REFERER, TRANSFER_ENCODING, USER_AGENT, +}; +use http::uri::Scheme; +use http::Uri; +use hyper_util::client::legacy::connect::HttpConnector; use log::debug; +#[cfg(feature = "default-tls")] +use native_tls_crate::TlsConnector; +use pin_project_lite::pin_project; #[cfg(feature = "http3")] use quinn::TransportConfig; #[cfg(feature = "http3")] use quinn::VarInt; +use tokio::time::Sleep; +use tower::{layer::util::Stack, Layer, Service, ServiceBuilder}; type HyperResponseFuture = hyper_util::client::legacy::ResponseFuture; @@ -76,8 +76,10 @@ pub struct Client { /// A `ClientBuilder` can be used to create a `Client` with custom configuration. #[must_use] -pub struct ClientBuilder { +pub struct ClientBuilder { config: Config, + // separated out to simplify casting between generic types while copying config + connector_layers: ConnectorLayerBuilder, } enum HttpVersionPref { @@ -175,17 +177,17 @@ struct Config { dns_resolver: Option>, } -impl Default for ClientBuilder { +impl Default for ClientBuilder { fn default() -> Self { Self::new() } } -impl ClientBuilder { +impl ClientBuilder { /// Constructs a new `ClientBuilder`. /// /// This is the same as `Client::builder()`. - pub fn new() -> ClientBuilder { + pub fn new() -> Self { let mut headers: HeaderMap = HeaderMap::with_capacity(2); headers.insert(ACCEPT, HeaderValue::from_static("*/*")); @@ -276,9 +278,21 @@ impl ClientBuilder { quic_send_window: None, dns_resolver: None, }, + connector_layers: ConnectorLayerBuilder { + builder: ServiceBuilder::new(), + has_custom_layers: false, + }, } } +} +impl ClientBuilder +where + ConnectorLayers: Layer, + ConnectorLayers::Service: + Service + Clone + Send + Sync + 'static, + <>::Service as Service>::Future: Send + 'static, +{ /// Returns a `Client` that uses this `ClientBuilder` configuration. /// /// # Errors @@ -302,7 +316,7 @@ impl ClientBuilder { #[cfg(feature = "http3")] let mut h3_connector = None; - let mut connector = { + let mut connector_builder = { #[cfg(feature = "__tls")] fn user_agent(headers: &HeaderMap) -> Option { headers.get(USER_AGENT).cloned() @@ -445,7 +459,7 @@ impl ClientBuilder { tls.max_protocol_version(Some(protocol)); } - Connector::new_default_tls( + ConnectorBuilder::new_default_tls( http, tls, proxies.clone(), @@ -462,7 +476,7 @@ impl ClientBuilder { )? } #[cfg(feature = "native-tls")] - TlsBackend::BuiltNativeTls(conn) => Connector::from_built_default_tls( + TlsBackend::BuiltNativeTls(conn) => ConnectorBuilder::from_built_default_tls( http, conn, proxies.clone(), @@ -489,7 +503,7 @@ impl ClientBuilder { )?; } - Connector::new_rustls_tls( + ConnectorBuilder::new_rustls_tls( http, conn, proxies.clone(), @@ -684,7 +698,7 @@ impl ClientBuilder { )?; } - Connector::new_rustls_tls( + ConnectorBuilder::new_rustls_tls( http, tls, proxies.clone(), @@ -709,7 +723,7 @@ impl ClientBuilder { } #[cfg(not(feature = "__tls"))] - Connector::new( + ConnectorBuilder::new( http, proxies.clone(), config.local_address, @@ -719,8 +733,9 @@ impl ClientBuilder { ) }; - connector.set_timeout(config.connect_timeout); - connector.set_verbose(config.connection_verbose); + connector_builder.set_timeout(config.connect_timeout); + connector_builder.set_verbose(config.connection_verbose); + connector_builder.set_keepalive(config.tcp_keepalive); let mut builder = hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new()); @@ -763,7 +778,6 @@ impl ClientBuilder { builder.pool_timer(hyper_util::rt::TokioTimer::new()); builder.pool_idle_timeout(config.pool_idle_timeout); builder.pool_max_idle_per_host(config.pool_max_idle_per_host); - connector.set_keepalive(config.tcp_keepalive); if config.http09_responses { builder.http09_responses(true); @@ -801,7 +815,7 @@ impl ClientBuilder { } None => None, }, - hyper: builder.build(connector), + hyper: builder.build(connector_builder.build(self.connector_layers)), headers: config.headers, redirect_policy: config.redirect_policy, referer: config.referer, @@ -836,7 +850,7 @@ impl ClientBuilder { /// # Ok(()) /// # } /// ``` - pub fn user_agent(mut self, value: V) -> ClientBuilder + pub fn user_agent(mut self, value: V) -> ClientBuilder where V: TryInto, V::Error: Into, @@ -874,7 +888,7 @@ impl ClientBuilder { /// # Ok(()) /// # } /// ``` - pub fn default_headers(mut self, headers: HeaderMap) -> ClientBuilder { + pub fn default_headers(mut self, headers: HeaderMap) -> ClientBuilder { for (key, value) in headers.iter() { self.config.headers.insert(key, value.clone()); } @@ -897,7 +911,7 @@ impl ClientBuilder { /// This requires the optional `cookies` feature to be enabled. #[cfg(feature = "cookies")] #[cfg_attr(docsrs, doc(cfg(feature = "cookies")))] - pub fn cookie_store(mut self, enable: bool) -> ClientBuilder { + pub fn cookie_store(mut self, enable: bool) -> ClientBuilder { if enable { self.cookie_provider(Arc::new(cookie::Jar::default())) } else { @@ -924,7 +938,7 @@ impl ClientBuilder { pub fn cookie_provider( mut self, cookie_store: Arc, - ) -> ClientBuilder { + ) -> ClientBuilder { self.config.cookie_store = Some(cookie_store as _); self } @@ -947,7 +961,7 @@ impl ClientBuilder { /// This requires the optional `gzip` feature to be enabled #[cfg(feature = "gzip")] #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))] - pub fn gzip(mut self, enable: bool) -> ClientBuilder { + pub fn gzip(mut self, enable: bool) -> ClientBuilder { self.config.accepts.gzip = enable; self } @@ -970,7 +984,7 @@ impl ClientBuilder { /// This requires the optional `brotli` feature to be enabled #[cfg(feature = "brotli")] #[cfg_attr(docsrs, doc(cfg(feature = "brotli")))] - pub fn brotli(mut self, enable: bool) -> ClientBuilder { + pub fn brotli(mut self, enable: bool) -> ClientBuilder { self.config.accepts.brotli = enable; self } @@ -993,7 +1007,7 @@ impl ClientBuilder { /// This requires the optional `zstd` feature to be enabled #[cfg(feature = "zstd")] #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))] - pub fn zstd(mut self, enable: bool) -> ClientBuilder { + pub fn zstd(mut self, enable: bool) -> ClientBuilder { self.config.accepts.zstd = enable; self } @@ -1016,7 +1030,7 @@ impl ClientBuilder { /// This requires the optional `deflate` feature to be enabled #[cfg(feature = "deflate")] #[cfg_attr(docsrs, doc(cfg(feature = "deflate")))] - pub fn deflate(mut self, enable: bool) -> ClientBuilder { + pub fn deflate(mut self, enable: bool) -> ClientBuilder { self.config.accepts.deflate = enable; self } @@ -1026,7 +1040,7 @@ impl ClientBuilder { /// This method exists even if the optional `gzip` feature is not enabled. /// This can be used to ensure a `Client` doesn't use gzip decompression /// even if another dependency were to enable the optional `gzip` feature. - pub fn no_gzip(self) -> ClientBuilder { + pub fn no_gzip(self) -> ClientBuilder { #[cfg(feature = "gzip")] { self.gzip(false) @@ -1043,7 +1057,7 @@ impl ClientBuilder { /// This method exists even if the optional `brotli` feature is not enabled. /// This can be used to ensure a `Client` doesn't use brotli decompression /// even if another dependency were to enable the optional `brotli` feature. - pub fn no_brotli(self) -> ClientBuilder { + pub fn no_brotli(self) -> ClientBuilder { #[cfg(feature = "brotli")] { self.brotli(false) @@ -1060,7 +1074,7 @@ impl ClientBuilder { /// This method exists even if the optional `zstd` feature is not enabled. /// This can be used to ensure a `Client` doesn't use zstd decompression /// even if another dependency were to enable the optional `zstd` feature. - pub fn no_zstd(self) -> ClientBuilder { + pub fn no_zstd(self) -> ClientBuilder { #[cfg(feature = "zstd")] { self.zstd(false) @@ -1077,7 +1091,7 @@ impl ClientBuilder { /// This method exists even if the optional `deflate` feature is not enabled. /// This can be used to ensure a `Client` doesn't use deflate decompression /// even if another dependency were to enable the optional `deflate` feature. - pub fn no_deflate(self) -> ClientBuilder { + pub fn no_deflate(self) -> ClientBuilder { #[cfg(feature = "deflate")] { self.deflate(false) @@ -1094,7 +1108,7 @@ impl ClientBuilder { /// Set a `RedirectPolicy` for this client. /// /// Default will follow redirects up to a maximum of 10. - pub fn redirect(mut self, policy: redirect::Policy) -> ClientBuilder { + pub fn redirect(mut self, policy: redirect::Policy) -> ClientBuilder { self.config.redirect_policy = policy; self } @@ -1102,7 +1116,7 @@ impl ClientBuilder { /// Enable or disable automatic setting of the `Referer` header. /// /// Default is `true`. - pub fn referer(mut self, enable: bool) -> ClientBuilder { + pub fn referer(mut self, enable: bool) -> ClientBuilder { self.config.referer = enable; self } @@ -1114,7 +1128,7 @@ impl ClientBuilder { /// # Note /// /// Adding a proxy will disable the automatic usage of the "system" proxy. - pub fn proxy(mut self, proxy: Proxy) -> ClientBuilder { + pub fn proxy(mut self, proxy: Proxy) -> ClientBuilder { self.config.proxies.push(proxy); self.config.auto_sys_proxy = false; self @@ -1127,7 +1141,7 @@ impl ClientBuilder { /// on all desired proxies instead. /// /// This also disables the automatic usage of the "system" proxy. - pub fn no_proxy(mut self) -> ClientBuilder { + pub fn no_proxy(mut self) -> ClientBuilder { self.config.proxies.clear(); self.config.auto_sys_proxy = false; self @@ -1141,7 +1155,7 @@ impl ClientBuilder { /// response body has finished. Also considered a total deadline. /// /// Default is no timeout. - pub fn timeout(mut self, timeout: Duration) -> ClientBuilder { + pub fn timeout(mut self, timeout: Duration) -> ClientBuilder { self.config.timeout = Some(timeout); self } @@ -1153,7 +1167,7 @@ impl ClientBuilder { /// connections when the size isn't known beforehand. /// /// Default is no timeout. - pub fn read_timeout(mut self, timeout: Duration) -> ClientBuilder { + pub fn read_timeout(mut self, timeout: Duration) -> ClientBuilder { self.config.read_timeout = Some(timeout); self } @@ -1166,7 +1180,7 @@ impl ClientBuilder { /// /// This **requires** the futures be executed in a tokio runtime with /// a tokio timer enabled. - pub fn connect_timeout(mut self, timeout: Duration) -> ClientBuilder { + pub fn connect_timeout(mut self, timeout: Duration) -> ClientBuilder { self.config.connect_timeout = Some(timeout); self } @@ -1177,7 +1191,7 @@ impl ClientBuilder { /// for read and write operations on connections. /// /// [log]: https://crates.io/crates/log - pub fn connection_verbose(mut self, verbose: bool) -> ClientBuilder { + pub fn connection_verbose(mut self, verbose: bool) -> ClientBuilder { self.config.connection_verbose = verbose; self } @@ -1189,7 +1203,7 @@ impl ClientBuilder { /// Pass `None` to disable timeout. /// /// Default is 90 seconds. - pub fn pool_idle_timeout(mut self, val: D) -> ClientBuilder + pub fn pool_idle_timeout(mut self, val: D) -> ClientBuilder where D: Into>, { @@ -1198,13 +1212,13 @@ impl ClientBuilder { } /// Sets the maximum idle connection per host allowed in the pool. - pub fn pool_max_idle_per_host(mut self, max: usize) -> ClientBuilder { + pub fn pool_max_idle_per_host(mut self, max: usize) -> ClientBuilder { self.config.pool_max_idle_per_host = max; self } /// Send headers as title case instead of lowercase. - pub fn http1_title_case_headers(mut self) -> ClientBuilder { + pub fn http1_title_case_headers(mut self) -> ClientBuilder { self.config.http1_title_case_headers = true; self } @@ -1217,14 +1231,17 @@ impl ClientBuilder { pub fn http1_allow_obsolete_multiline_headers_in_responses( mut self, value: bool, - ) -> ClientBuilder { + ) -> ClientBuilder { self.config .http1_allow_obsolete_multiline_headers_in_responses = value; self } /// Sets whether invalid header lines should be silently ignored in HTTP/1 responses. - pub fn http1_ignore_invalid_headers_in_responses(mut self, value: bool) -> ClientBuilder { + pub fn http1_ignore_invalid_headers_in_responses( + mut self, + value: bool, + ) -> ClientBuilder { self.config.http1_ignore_invalid_headers_in_responses = value; self } @@ -1237,20 +1254,20 @@ impl ClientBuilder { pub fn http1_allow_spaces_after_header_name_in_responses( mut self, value: bool, - ) -> ClientBuilder { + ) -> ClientBuilder { self.config .http1_allow_spaces_after_header_name_in_responses = value; self } /// Only use HTTP/1. - pub fn http1_only(mut self) -> ClientBuilder { + pub fn http1_only(mut self) -> ClientBuilder { self.config.http_version_pref = HttpVersionPref::Http1; self } /// Allow HTTP/0.9 responses - pub fn http09_responses(mut self) -> ClientBuilder { + pub fn http09_responses(mut self) -> ClientBuilder { self.config.http09_responses = true; self } @@ -1258,7 +1275,7 @@ impl ClientBuilder { /// Only use HTTP/2. #[cfg(feature = "http2")] #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] - pub fn http2_prior_knowledge(mut self) -> ClientBuilder { + pub fn http2_prior_knowledge(mut self) -> ClientBuilder { self.config.http_version_pref = HttpVersionPref::Http2; self } @@ -1266,7 +1283,7 @@ impl ClientBuilder { /// Only use HTTP/3. #[cfg(feature = "http3")] #[cfg_attr(docsrs, doc(cfg(all(reqwest_unstable, feature = "http3",))))] - pub fn http3_prior_knowledge(mut self) -> ClientBuilder { + pub fn http3_prior_knowledge(mut self) -> ClientBuilder { self.config.http_version_pref = HttpVersionPref::Http3; self } @@ -1276,7 +1293,10 @@ impl ClientBuilder { /// Default is currently 65,535 but may change internally to optimize for common uses. #[cfg(feature = "http2")] #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] - pub fn http2_initial_stream_window_size(mut self, sz: impl Into>) -> ClientBuilder { + pub fn http2_initial_stream_window_size( + mut self, + sz: impl Into>, + ) -> ClientBuilder { self.config.http2_initial_stream_window_size = sz.into(); self } @@ -1289,7 +1309,7 @@ impl ClientBuilder { pub fn http2_initial_connection_window_size( mut self, sz: impl Into>, - ) -> ClientBuilder { + ) -> ClientBuilder { self.config.http2_initial_connection_window_size = sz.into(); self } @@ -1300,7 +1320,7 @@ impl ClientBuilder { /// `http2_initial_connection_window_size`. #[cfg(feature = "http2")] #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] - pub fn http2_adaptive_window(mut self, enabled: bool) -> ClientBuilder { + pub fn http2_adaptive_window(mut self, enabled: bool) -> ClientBuilder { self.config.http2_adaptive_window = enabled; self } @@ -1310,7 +1330,10 @@ impl ClientBuilder { /// Default is currently 16,384 but may change internally to optimize for common uses. #[cfg(feature = "http2")] #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] - pub fn http2_max_frame_size(mut self, sz: impl Into>) -> ClientBuilder { + pub fn http2_max_frame_size( + mut self, + sz: impl Into>, + ) -> ClientBuilder { self.config.http2_max_frame_size = sz.into(); self } @@ -1320,7 +1343,10 @@ impl ClientBuilder { /// Default is currently 16KB, but can change. #[cfg(feature = "http2")] #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] - pub fn http2_max_header_list_size(mut self, max_header_size_bytes: u32) -> ClientBuilder { + pub fn http2_max_header_list_size( + mut self, + max_header_size_bytes: u32, + ) -> ClientBuilder { self.config.http2_max_header_list_size = Some(max_header_size_bytes); self } @@ -1334,7 +1360,7 @@ impl ClientBuilder { pub fn http2_keep_alive_interval( mut self, interval: impl Into>, - ) -> ClientBuilder { + ) -> ClientBuilder { self.config.http2_keep_alive_interval = interval.into(); self } @@ -1346,7 +1372,7 @@ impl ClientBuilder { /// Default is currently disabled. #[cfg(feature = "http2")] #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] - pub fn http2_keep_alive_timeout(mut self, timeout: Duration) -> ClientBuilder { + pub fn http2_keep_alive_timeout(mut self, timeout: Duration) -> ClientBuilder { self.config.http2_keep_alive_timeout = Some(timeout); self } @@ -1359,7 +1385,7 @@ impl ClientBuilder { /// Default is `false`. #[cfg(feature = "http2")] #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] - pub fn http2_keep_alive_while_idle(mut self, enabled: bool) -> ClientBuilder { + pub fn http2_keep_alive_while_idle(mut self, enabled: bool) -> ClientBuilder { self.config.http2_keep_alive_while_idle = enabled; self } @@ -1369,7 +1395,7 @@ impl ClientBuilder { /// Set whether sockets have `TCP_NODELAY` enabled. /// /// Default is `true`. - pub fn tcp_nodelay(mut self, enabled: bool) -> ClientBuilder { + pub fn tcp_nodelay(mut self, enabled: bool) -> ClientBuilder { self.config.nodelay = enabled; self } @@ -1387,7 +1413,7 @@ impl ClientBuilder { /// .local_address(local_addr) /// .build().unwrap(); /// ``` - pub fn local_address(mut self, addr: T) -> ClientBuilder + pub fn local_address(mut self, addr: T) -> ClientBuilder where T: Into>, { @@ -1408,7 +1434,7 @@ impl ClientBuilder { /// .build().unwrap(); /// ``` #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] - pub fn interface(mut self, interface: &str) -> ClientBuilder { + pub fn interface(mut self, interface: &str) -> ClientBuilder { self.config.interface = Some(interface.to_string()); self } @@ -1416,7 +1442,7 @@ impl ClientBuilder { /// Set that all sockets have `SO_KEEPALIVE` set with the supplied duration. /// /// If `None`, the option will not be set. - pub fn tcp_keepalive(mut self, val: D) -> ClientBuilder + pub fn tcp_keepalive(mut self, val: D) -> ClientBuilder where D: Into>, { @@ -1444,7 +1470,7 @@ impl ClientBuilder { feature = "rustls-tls" ))) )] - pub fn add_root_certificate(mut self, cert: Certificate) -> ClientBuilder { + pub fn add_root_certificate(mut self, cert: Certificate) -> ClientBuilder { self.config.root_certs.push(cert); self } @@ -1457,7 +1483,7 @@ impl ClientBuilder { /// This requires the `rustls-tls(-...)` Cargo feature enabled. #[cfg(feature = "__rustls")] #[cfg_attr(docsrs, doc(cfg(feature = "rustls-tls")))] - pub fn add_crl(mut self, crl: CertificateRevocationList) -> ClientBuilder { + pub fn add_crl(mut self, crl: CertificateRevocationList) -> ClientBuilder { self.config.crls.push(crl); self } @@ -1473,7 +1499,7 @@ impl ClientBuilder { pub fn add_crls( mut self, crls: impl IntoIterator, - ) -> ClientBuilder { + ) -> ClientBuilder { self.config.crls.extend(crls); self } @@ -1504,7 +1530,10 @@ impl ClientBuilder { feature = "rustls-tls" ))) )] - pub fn tls_built_in_root_certs(mut self, tls_built_in_root_certs: bool) -> ClientBuilder { + pub fn tls_built_in_root_certs( + mut self, + tls_built_in_root_certs: bool, + ) -> ClientBuilder { self.config.tls_built_in_root_certs = tls_built_in_root_certs; #[cfg(feature = "rustls-tls-webpki-roots-no-provider")] @@ -1525,7 +1554,7 @@ impl ClientBuilder { /// If the feature is enabled, this value is `true` by default. #[cfg(feature = "rustls-tls-webpki-roots-no-provider")] #[cfg_attr(docsrs, doc(cfg(feature = "rustls-tls-webpki-roots-no-provider")))] - pub fn tls_built_in_webpki_certs(mut self, enabled: bool) -> ClientBuilder { + pub fn tls_built_in_webpki_certs(mut self, enabled: bool) -> ClientBuilder { self.config.tls_built_in_certs_webpki = enabled; self } @@ -1535,7 +1564,7 @@ impl ClientBuilder { /// If the feature is enabled, this value is `true` by default. #[cfg(feature = "rustls-tls-native-roots-no-provider")] #[cfg_attr(docsrs, doc(cfg(feature = "rustls-tls-native-roots-no-provider")))] - pub fn tls_built_in_native_certs(mut self, enabled: bool) -> ClientBuilder { + pub fn tls_built_in_native_certs(mut self, enabled: bool) -> ClientBuilder { self.config.tls_built_in_certs_native = enabled; self } @@ -1548,7 +1577,7 @@ impl ClientBuilder { /// enabled. #[cfg(any(feature = "native-tls", feature = "__rustls"))] #[cfg_attr(docsrs, doc(cfg(any(feature = "native-tls", feature = "rustls-tls"))))] - pub fn identity(mut self, identity: Identity) -> ClientBuilder { + pub fn identity(mut self, identity: Identity) -> ClientBuilder { self.config.identity = Some(identity); self } @@ -1580,7 +1609,7 @@ impl ClientBuilder { pub fn danger_accept_invalid_hostnames( mut self, accept_invalid_hostname: bool, - ) -> ClientBuilder { + ) -> ClientBuilder { self.config.hostname_verification = !accept_invalid_hostname; self } @@ -1610,7 +1639,10 @@ impl ClientBuilder { feature = "rustls-tls" ))) )] - pub fn danger_accept_invalid_certs(mut self, accept_invalid_certs: bool) -> ClientBuilder { + pub fn danger_accept_invalid_certs( + mut self, + accept_invalid_certs: bool, + ) -> ClientBuilder { self.config.certs_verification = !accept_invalid_certs; self } @@ -1632,7 +1664,7 @@ impl ClientBuilder { feature = "rustls-tls" ))) )] - pub fn tls_sni(mut self, tls_sni: bool) -> ClientBuilder { + pub fn tls_sni(mut self, tls_sni: bool) -> ClientBuilder { self.config.tls_sni = tls_sni; self } @@ -1661,7 +1693,7 @@ impl ClientBuilder { feature = "rustls-tls" ))) )] - pub fn min_tls_version(mut self, version: tls::Version) -> ClientBuilder { + pub fn min_tls_version(mut self, version: tls::Version) -> ClientBuilder { self.config.min_tls_version = Some(version); self } @@ -1693,7 +1725,7 @@ impl ClientBuilder { feature = "rustls-tls" ))) )] - pub fn max_tls_version(mut self, version: tls::Version) -> ClientBuilder { + pub fn max_tls_version(mut self, version: tls::Version) -> ClientBuilder { self.config.max_tls_version = Some(version); self } @@ -1708,7 +1740,7 @@ impl ClientBuilder { /// This requires the optional `native-tls` feature to be enabled. #[cfg(feature = "native-tls")] #[cfg_attr(docsrs, doc(cfg(feature = "native-tls")))] - pub fn use_native_tls(mut self) -> ClientBuilder { + pub fn use_native_tls(mut self) -> ClientBuilder { self.config.tls = TlsBackend::Default; self } @@ -1723,7 +1755,7 @@ impl ClientBuilder { /// This requires the optional `rustls-tls(-...)` feature to be enabled. #[cfg(feature = "__rustls")] #[cfg_attr(docsrs, doc(cfg(feature = "rustls-tls")))] - pub fn use_rustls_tls(mut self) -> ClientBuilder { + pub fn use_rustls_tls(mut self) -> ClientBuilder { self.config.tls = TlsBackend::Rustls; self } @@ -1748,7 +1780,7 @@ impl ClientBuilder { /// `rustls-tls(-...)` to be enabled. #[cfg(any(feature = "native-tls", feature = "__rustls",))] #[cfg_attr(docsrs, doc(cfg(any(feature = "native-tls", feature = "rustls-tls"))))] - pub fn use_preconfigured_tls(mut self, tls: impl Any) -> ClientBuilder { + pub fn use_preconfigured_tls(mut self, tls: impl Any) -> ClientBuilder { let mut tls = Some(tls); #[cfg(feature = "native-tls")] { @@ -1791,7 +1823,7 @@ impl ClientBuilder { feature = "rustls-tls" ))) )] - pub fn tls_info(mut self, tls_info: bool) -> ClientBuilder { + pub fn tls_info(mut self, tls_info: bool) -> ClientBuilder { self.config.tls_info = tls_info; self } @@ -1799,7 +1831,7 @@ impl ClientBuilder { /// Restrict the Client to be used with HTTPS only requests. /// /// Defaults to false. - pub fn https_only(mut self, enabled: bool) -> ClientBuilder { + pub fn https_only(mut self, enabled: bool) -> ClientBuilder { self.config.https_only = enabled; self } @@ -1808,7 +1840,7 @@ impl ClientBuilder { #[cfg(feature = "hickory-dns")] #[cfg_attr(docsrs, doc(cfg(feature = "hickory-dns")))] #[deprecated(note = "use `hickory_dns` instead")] - pub fn trust_dns(mut self, enable: bool) -> ClientBuilder { + pub fn trust_dns(mut self, enable: bool) -> ClientBuilder { self.config.hickory_dns = enable; self } @@ -1828,14 +1860,14 @@ impl ClientBuilder { /// that the default resolver does #[cfg(feature = "hickory-dns")] #[cfg_attr(docsrs, doc(cfg(feature = "hickory-dns")))] - pub fn hickory_dns(mut self, enable: bool) -> ClientBuilder { + pub fn hickory_dns(mut self, enable: bool) -> ClientBuilder { self.config.hickory_dns = enable; self } #[doc(hidden)] #[deprecated(note = "use `no_hickory_dns` instead")] - pub fn no_trust_dns(self) -> ClientBuilder { + pub fn no_trust_dns(self) -> ClientBuilder { self.no_hickory_dns() } @@ -1844,7 +1876,7 @@ impl ClientBuilder { /// This method exists even if the optional `hickory-dns` feature is not enabled. /// This can be used to ensure a `Client` doesn't use the hickory-dns async resolver /// even if another dependency were to enable the optional `hickory-dns` feature. - pub fn no_hickory_dns(self) -> ClientBuilder { + pub fn no_hickory_dns(self) -> ClientBuilder { #[cfg(feature = "hickory-dns")] { self.hickory_dns(false) @@ -1860,7 +1892,7 @@ impl ClientBuilder { /// /// Set the port to `0` to use the conventional port for the given scheme (e.g. 80 for http). /// Ports in the URL itself will always be used instead of the port in the overridden addr. - pub fn resolve(self, domain: &str, addr: SocketAddr) -> ClientBuilder { + pub fn resolve(self, domain: &str, addr: SocketAddr) -> ClientBuilder { self.resolve_to_addrs(domain, &[addr]) } @@ -1868,7 +1900,11 @@ impl ClientBuilder { /// /// Set the port to `0` to use the conventional port for the given scheme (e.g. 80 for http). /// Ports in the URL itself will always be used instead of the port in the overridden addr. - pub fn resolve_to_addrs(mut self, domain: &str, addrs: &[SocketAddr]) -> ClientBuilder { + pub fn resolve_to_addrs( + mut self, + domain: &str, + addrs: &[SocketAddr], + ) -> ClientBuilder { self.config .dns_overrides .insert(domain.to_ascii_lowercase(), addrs.to_vec()); @@ -1880,7 +1916,10 @@ impl ClientBuilder { /// Pass an `Arc` wrapping a trait object implementing `Resolve`. /// Overrides for specific names passed to `resolve` and `resolve_to_addrs` will /// still be applied on top of this resolver. - pub fn dns_resolver(mut self, resolver: Arc) -> ClientBuilder { + pub fn dns_resolver( + mut self, + resolver: Arc, + ) -> ClientBuilder { self.config.dns_resolver = Some(resolver as _); self } @@ -1891,7 +1930,7 @@ impl ClientBuilder { /// The default is false. #[cfg(feature = "http3")] #[cfg_attr(docsrs, doc(cfg(all(reqwest_unstable, feature = "http3",))))] - pub fn tls_early_data(mut self, enabled: bool) -> ClientBuilder { + pub fn tls_early_data(mut self, enabled: bool) -> ClientBuilder { self.config.tls_enable_early_data = enabled; self } @@ -1903,7 +1942,7 @@ impl ClientBuilder { /// [`TransportConfig`]: https://docs.rs/quinn/latest/quinn/struct.TransportConfig.html #[cfg(feature = "http3")] #[cfg_attr(docsrs, doc(cfg(all(reqwest_unstable, feature = "http3",))))] - pub fn http3_max_idle_timeout(mut self, value: Duration) -> ClientBuilder { + pub fn http3_max_idle_timeout(mut self, value: Duration) -> ClientBuilder { self.config.quic_max_idle_timeout = Some(value); self } @@ -1920,7 +1959,7 @@ impl ClientBuilder { /// Panics if the value is over 2^62. #[cfg(feature = "http3")] #[cfg_attr(docsrs, doc(cfg(all(reqwest_unstable, feature = "http3",))))] - pub fn http3_stream_receive_window(mut self, value: u64) -> ClientBuilder { + pub fn http3_stream_receive_window(mut self, value: u64) -> ClientBuilder { self.config.quic_stream_receive_window = Some(value.try_into().unwrap()); self } @@ -1937,7 +1976,7 @@ impl ClientBuilder { /// Panics if the value is over 2^62. #[cfg(feature = "http3")] #[cfg_attr(docsrs, doc(cfg(all(reqwest_unstable, feature = "http3",))))] - pub fn http3_conn_receive_window(mut self, value: u64) -> ClientBuilder { + pub fn http3_conn_receive_window(mut self, value: u64) -> ClientBuilder { self.config.quic_receive_window = Some(value.try_into().unwrap()); self } @@ -1949,10 +1988,50 @@ impl ClientBuilder { /// [`TransportConfig`]: https://docs.rs/quinn/latest/quinn/struct.TransportConfig.html #[cfg(feature = "http3")] #[cfg_attr(docsrs, doc(cfg(all(reqwest_unstable, feature = "http3",))))] - pub fn http3_send_window(mut self, value: u64) -> ClientBuilder { + pub fn http3_send_window(mut self, value: u64) -> ClientBuilder { self.config.quic_send_window = Some(value); self } + + /// Adds a new Tower [`Layer`](https://docs.rs/tower/latest/tower/trait.Layer.html) to the + /// base connector [`Service`](https://docs.rs/tower/latest/tower/trait.Service.html) which + /// is responsible for connection establishment. + /// + /// Each subsequent invocation of this function will wrap previous layers. + /// + /// If configured, the `connect_timeout` will be the outermost layer. + /// + /// Simple example: + /// ``` + /// use std::time::Duration; + /// + /// # #[cfg(not(feature = "rustls-tls-no-provider"))] + /// let client = reqwest::Client::builder() + /// // resolved to outermost layer, so before the semaphore permit is attempted + /// .connect_timeout(Duration::from_millis(100)) + /// // underneath the concurrency check, so only after a semaphore permit is acquired + /// .connector_layer(tower::timeout::TimeoutLayer::new(Duration::from_millis(50))) + /// .connector_layer(tower::limit::concurrency::ConcurrencyLimitLayer::new(2)) + /// .build() + /// .unwrap(); + /// ``` + /// + /// For a more complex example involving a custom layer, see `examples/connect_via_lower_priority_tokio_runtime.rs`. + pub fn connector_layer(self, layer: L) -> ClientBuilder> + where + S: Send + Sync + Clone + 'static, + L: Layer + Send + Sync + Clone + 'static, + { + let connector_layers = ConnectorLayerBuilder { + builder: self.connector_layers.builder.layer(layer), + has_custom_layers: true, + }; + + ClientBuilder::> { + config: self.config, + connector_layers, + } + } } type HyperClient = hyper_util::client::legacy::Client; @@ -1980,7 +2059,7 @@ impl Client { /// Creates a `ClientBuilder` to configure a `Client`. /// /// This is the same as `ClientBuilder::new()`. - pub fn builder() -> ClientBuilder { + pub fn builder() -> ClientBuilder { ClientBuilder::new() } @@ -2237,7 +2316,7 @@ impl tower_service::Service for &'_ Client { } } -impl fmt::Debug for ClientBuilder { +impl fmt::Debug for ClientBuilder { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let mut builder = f.debug_struct("ClientBuilder"); self.config.fmt_fields(&mut builder); diff --git a/src/blocking/client.rs b/src/blocking/client.rs index 73f25208f..5279d47f3 100644 --- a/src/blocking/client.rs +++ b/src/blocking/client.rs @@ -10,13 +10,20 @@ use std::thread; use std::time::Duration; use http::header::HeaderValue; +use http::Uri; use log::{error, trace}; use tokio::sync::{mpsc, oneshot}; +use tower::layer::util::Stack; +use tower::Layer; +use tower::Service; use super::request::{Request, RequestBuilder}; use super::response::Response; use super::wait; +use crate::connect::Conn; +use crate::connect::ConnectorService; use crate::dns::Resolve; +use crate::error::BoxError; #[cfg(feature = "__tls")] use crate::tls; #[cfg(feature = "__rustls")] @@ -69,28 +76,36 @@ pub struct Client { /// # } /// ``` #[must_use] -pub struct ClientBuilder { - inner: async_impl::ClientBuilder, +pub struct ClientBuilder { + inner: async_impl::ClientBuilder, timeout: Timeout, } -impl Default for ClientBuilder { +impl Default for ClientBuilder { fn default() -> Self { Self::new() } } -impl ClientBuilder { +impl ClientBuilder { /// Constructs a new `ClientBuilder`. /// /// This is the same as `Client::builder()`. - pub fn new() -> ClientBuilder { + pub fn new() -> Self { ClientBuilder { inner: async_impl::ClientBuilder::new(), timeout: Timeout::default(), } } +} +impl ClientBuilder +where + ConnectorLayers: Layer + Send + Sync + 'static, + ConnectorLayers::Service: + Service + Clone + Send + Sync + 'static, + <>::Service as Service>::Future: Send + 'static, +{ /// Returns a `Client` that uses this `ClientBuilder` configuration. /// /// # Errors @@ -128,7 +143,7 @@ impl ClientBuilder { /// # Ok(()) /// # } /// ``` - pub fn user_agent(self, value: V) -> ClientBuilder + pub fn user_agent(self, value: V) -> ClientBuilder where V: TryInto, V::Error: Into, @@ -160,7 +175,7 @@ impl ClientBuilder { /// # Ok(()) /// # } /// ``` - pub fn default_headers(self, headers: header::HeaderMap) -> ClientBuilder { + pub fn default_headers(self, headers: header::HeaderMap) -> ClientBuilder { self.with_inner(move |inner| inner.default_headers(headers)) } @@ -176,7 +191,7 @@ impl ClientBuilder { /// This requires the optional `cookies` feature to be enabled. #[cfg(feature = "cookies")] #[cfg_attr(docsrs, doc(cfg(feature = "cookies")))] - pub fn cookie_store(self, enable: bool) -> ClientBuilder { + pub fn cookie_store(self, enable: bool) -> ClientBuilder { self.with_inner(|inner| inner.cookie_store(enable)) } @@ -195,7 +210,7 @@ impl ClientBuilder { pub fn cookie_provider( self, cookie_store: Arc, - ) -> ClientBuilder { + ) -> ClientBuilder { self.with_inner(|inner| inner.cookie_provider(cookie_store)) } @@ -217,7 +232,7 @@ impl ClientBuilder { /// This requires the optional `gzip` feature to be enabled #[cfg(feature = "gzip")] #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))] - pub fn gzip(self, enable: bool) -> ClientBuilder { + pub fn gzip(self, enable: bool) -> ClientBuilder { self.with_inner(|inner| inner.gzip(enable)) } @@ -239,7 +254,7 @@ impl ClientBuilder { /// This requires the optional `brotli` feature to be enabled #[cfg(feature = "brotli")] #[cfg_attr(docsrs, doc(cfg(feature = "brotli")))] - pub fn brotli(self, enable: bool) -> ClientBuilder { + pub fn brotli(self, enable: bool) -> ClientBuilder { self.with_inner(|inner| inner.brotli(enable)) } @@ -261,7 +276,7 @@ impl ClientBuilder { /// This requires the optional `zstd` feature to be enabled #[cfg(feature = "zstd")] #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))] - pub fn zstd(self, enable: bool) -> ClientBuilder { + pub fn zstd(self, enable: bool) -> ClientBuilder { self.with_inner(|inner| inner.zstd(enable)) } @@ -283,7 +298,7 @@ impl ClientBuilder { /// This requires the optional `deflate` feature to be enabled #[cfg(feature = "deflate")] #[cfg_attr(docsrs, doc(cfg(feature = "deflate")))] - pub fn deflate(self, enable: bool) -> ClientBuilder { + pub fn deflate(self, enable: bool) -> ClientBuilder { self.with_inner(|inner| inner.deflate(enable)) } @@ -292,7 +307,7 @@ impl ClientBuilder { /// This method exists even if the optional `gzip` feature is not enabled. /// This can be used to ensure a `Client` doesn't use gzip decompression /// even if another dependency were to enable the optional `gzip` feature. - pub fn no_gzip(self) -> ClientBuilder { + pub fn no_gzip(self) -> ClientBuilder { self.with_inner(|inner| inner.no_gzip()) } @@ -301,7 +316,7 @@ impl ClientBuilder { /// This method exists even if the optional `brotli` feature is not enabled. /// This can be used to ensure a `Client` doesn't use brotli decompression /// even if another dependency were to enable the optional `brotli` feature. - pub fn no_brotli(self) -> ClientBuilder { + pub fn no_brotli(self) -> ClientBuilder { self.with_inner(|inner| inner.no_brotli()) } @@ -310,7 +325,7 @@ impl ClientBuilder { /// This method exists even if the optional `zstd` feature is not enabled. /// This can be used to ensure a `Client` doesn't use zstd decompression /// even if another dependency were to enable the optional `zstd` feature. - pub fn no_zstd(self) -> ClientBuilder { + pub fn no_zstd(self) -> ClientBuilder { self.with_inner(|inner| inner.no_zstd()) } @@ -319,7 +334,7 @@ impl ClientBuilder { /// This method exists even if the optional `deflate` feature is not enabled. /// This can be used to ensure a `Client` doesn't use deflate decompression /// even if another dependency were to enable the optional `deflate` feature. - pub fn no_deflate(self) -> ClientBuilder { + pub fn no_deflate(self) -> ClientBuilder { self.with_inner(|inner| inner.no_deflate()) } @@ -328,14 +343,14 @@ impl ClientBuilder { /// Set a `redirect::Policy` for this client. /// /// Default will follow redirects up to a maximum of 10. - pub fn redirect(self, policy: redirect::Policy) -> ClientBuilder { + pub fn redirect(self, policy: redirect::Policy) -> ClientBuilder { self.with_inner(move |inner| inner.redirect(policy)) } /// Enable or disable automatic setting of the `Referer` header. /// /// Default is `true`. - pub fn referer(self, enable: bool) -> ClientBuilder { + pub fn referer(self, enable: bool) -> ClientBuilder { self.with_inner(|inner| inner.referer(enable)) } @@ -346,7 +361,7 @@ impl ClientBuilder { /// # Note /// /// Adding a proxy will disable the automatic usage of the "system" proxy. - pub fn proxy(self, proxy: Proxy) -> ClientBuilder { + pub fn proxy(self, proxy: Proxy) -> ClientBuilder { self.with_inner(move |inner| inner.proxy(proxy)) } @@ -357,7 +372,7 @@ impl ClientBuilder { /// on all desired proxies instead. /// /// This also disables the automatic usage of the "system" proxy. - pub fn no_proxy(self) -> ClientBuilder { + pub fn no_proxy(self) -> ClientBuilder { self.with_inner(move |inner| inner.no_proxy()) } @@ -368,7 +383,7 @@ impl ClientBuilder { /// Default is 30 seconds. /// /// Pass `None` to disable timeout. - pub fn timeout(mut self, timeout: T) -> ClientBuilder + pub fn timeout(mut self, timeout: T) -> ClientBuilder where T: Into>, { @@ -379,7 +394,7 @@ impl ClientBuilder { /// Set a timeout for only the connect phase of a `Client`. /// /// Default is `None`. - pub fn connect_timeout(self, timeout: T) -> ClientBuilder + pub fn connect_timeout(self, timeout: T) -> ClientBuilder where T: Into>, { @@ -397,7 +412,7 @@ impl ClientBuilder { /// for read and write operations on connections. /// /// [log]: https://crates.io/crates/log - pub fn connection_verbose(self, verbose: bool) -> ClientBuilder { + pub fn connection_verbose(self, verbose: bool) -> ClientBuilder { self.with_inner(move |inner| inner.connection_verbose(verbose)) } @@ -408,7 +423,7 @@ impl ClientBuilder { /// Pass `None` to disable timeout. /// /// Default is 90 seconds. - pub fn pool_idle_timeout(self, val: D) -> ClientBuilder + pub fn pool_idle_timeout(self, val: D) -> ClientBuilder where D: Into>, { @@ -416,12 +431,12 @@ impl ClientBuilder { } /// Sets the maximum idle connection per host allowed in the pool. - pub fn pool_max_idle_per_host(self, max: usize) -> ClientBuilder { + pub fn pool_max_idle_per_host(self, max: usize) -> ClientBuilder { self.with_inner(move |inner| inner.pool_max_idle_per_host(max)) } /// Send headers as title case instead of lowercase. - pub fn http1_title_case_headers(self) -> ClientBuilder { + pub fn http1_title_case_headers(self) -> ClientBuilder { self.with_inner(|inner| inner.http1_title_case_headers()) } @@ -430,12 +445,18 @@ impl ClientBuilder { /// /// Newline codepoints (`\r` and `\n`) will be transformed to spaces when /// parsing. - pub fn http1_allow_obsolete_multiline_headers_in_responses(self, value: bool) -> ClientBuilder { + pub fn http1_allow_obsolete_multiline_headers_in_responses( + self, + value: bool, + ) -> ClientBuilder { self.with_inner(|inner| inner.http1_allow_obsolete_multiline_headers_in_responses(value)) } /// Sets whether invalid header lines should be silently ignored in HTTP/1 responses. - pub fn http1_ignore_invalid_headers_in_responses(self, value: bool) -> ClientBuilder { + pub fn http1_ignore_invalid_headers_in_responses( + self, + value: bool, + ) -> ClientBuilder { self.with_inner(|inner| inner.http1_ignore_invalid_headers_in_responses(value)) } @@ -444,24 +465,27 @@ impl ClientBuilder { /// /// Newline codepoints (\r and \n) will be transformed to spaces when /// parsing. - pub fn http1_allow_spaces_after_header_name_in_responses(self, value: bool) -> ClientBuilder { + pub fn http1_allow_spaces_after_header_name_in_responses( + self, + value: bool, + ) -> ClientBuilder { self.with_inner(|inner| inner.http1_allow_spaces_after_header_name_in_responses(value)) } /// Only use HTTP/1. - pub fn http1_only(self) -> ClientBuilder { + pub fn http1_only(self) -> ClientBuilder { self.with_inner(|inner| inner.http1_only()) } /// Allow HTTP/0.9 responses - pub fn http09_responses(self) -> ClientBuilder { + pub fn http09_responses(self) -> ClientBuilder { self.with_inner(|inner| inner.http09_responses()) } /// Only use HTTP/2. #[cfg(feature = "http2")] #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] - pub fn http2_prior_knowledge(self) -> ClientBuilder { + pub fn http2_prior_knowledge(self) -> ClientBuilder { self.with_inner(|inner| inner.http2_prior_knowledge()) } @@ -470,7 +494,10 @@ impl ClientBuilder { /// Default is currently 65,535 but may change internally to optimize for common uses. #[cfg(feature = "http2")] #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] - pub fn http2_initial_stream_window_size(self, sz: impl Into>) -> ClientBuilder { + pub fn http2_initial_stream_window_size( + self, + sz: impl Into>, + ) -> ClientBuilder { self.with_inner(|inner| inner.http2_initial_stream_window_size(sz)) } @@ -479,7 +506,10 @@ impl ClientBuilder { /// Default is currently 65,535 but may change internally to optimize for common uses. #[cfg(feature = "http2")] #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] - pub fn http2_initial_connection_window_size(self, sz: impl Into>) -> ClientBuilder { + pub fn http2_initial_connection_window_size( + self, + sz: impl Into>, + ) -> ClientBuilder { self.with_inner(|inner| inner.http2_initial_connection_window_size(sz)) } @@ -489,7 +519,7 @@ impl ClientBuilder { /// `http2_initial_connection_window_size`. #[cfg(feature = "http2")] #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] - pub fn http2_adaptive_window(self, enabled: bool) -> ClientBuilder { + pub fn http2_adaptive_window(self, enabled: bool) -> ClientBuilder { self.with_inner(|inner| inner.http2_adaptive_window(enabled)) } @@ -498,7 +528,10 @@ impl ClientBuilder { /// Default is currently 16,384 but may change internally to optimize for common uses. #[cfg(feature = "http2")] #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] - pub fn http2_max_frame_size(self, sz: impl Into>) -> ClientBuilder { + pub fn http2_max_frame_size( + self, + sz: impl Into>, + ) -> ClientBuilder { self.with_inner(|inner| inner.http2_max_frame_size(sz)) } @@ -507,7 +540,10 @@ impl ClientBuilder { /// Default is currently 16KB, but can change. #[cfg(feature = "http2")] #[cfg_attr(docsrs, doc(cfg(feature = "http2")))] - pub fn http2_max_header_list_size(self, max_header_size_bytes: u32) -> ClientBuilder { + pub fn http2_max_header_list_size( + self, + max_header_size_bytes: u32, + ) -> ClientBuilder { self.with_inner(|inner| inner.http2_max_header_list_size(max_header_size_bytes)) } @@ -515,7 +551,7 @@ impl ClientBuilder { /// enabled. #[cfg(feature = "http3")] #[cfg_attr(docsrs, doc(cfg(feature = "http3")))] - pub fn http3_prior_knowledge(self) -> ClientBuilder { + pub fn http3_prior_knowledge(self) -> ClientBuilder { self.with_inner(|inner| inner.http3_prior_knowledge()) } @@ -524,7 +560,7 @@ impl ClientBuilder { /// Set whether sockets have `TCP_NODELAY` enabled. /// /// Default is `true`. - pub fn tcp_nodelay(self, enabled: bool) -> ClientBuilder { + pub fn tcp_nodelay(self, enabled: bool) -> ClientBuilder { self.with_inner(move |inner| inner.tcp_nodelay(enabled)) } @@ -539,7 +575,7 @@ impl ClientBuilder { /// .local_address(local_addr) /// .build().unwrap(); /// ``` - pub fn local_address(self, addr: T) -> ClientBuilder + pub fn local_address(self, addr: T) -> ClientBuilder where T: Into>, { @@ -557,14 +593,14 @@ impl ClientBuilder { /// .build().unwrap(); /// ``` #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] - pub fn interface(self, interface: &str) -> ClientBuilder { + pub fn interface(self, interface: &str) -> ClientBuilder { self.with_inner(move |inner| inner.interface(interface)) } /// Set that all sockets have `SO_KEEPALIVE` set with the supplied duration. /// /// If `None`, the option will not be set. - pub fn tcp_keepalive(self, val: D) -> ClientBuilder + pub fn tcp_keepalive(self, val: D) -> ClientBuilder where D: Into>, { @@ -613,7 +649,7 @@ impl ClientBuilder { feature = "rustls-tls" ))) )] - pub fn add_root_certificate(self, cert: Certificate) -> ClientBuilder { + pub fn add_root_certificate(self, cert: Certificate) -> ClientBuilder { self.with_inner(move |inner| inner.add_root_certificate(cert)) } @@ -625,7 +661,7 @@ impl ClientBuilder { /// This requires the `rustls-tls(-...)` Cargo feature enabled. #[cfg(feature = "__rustls")] #[cfg_attr(docsrs, doc(cfg(feature = "rustls-tls")))] - pub fn add_crl(self, crl: CertificateRevocationList) -> ClientBuilder { + pub fn add_crl(self, crl: CertificateRevocationList) -> ClientBuilder { self.with_inner(move |inner| inner.add_crl(crl)) } @@ -640,7 +676,7 @@ impl ClientBuilder { pub fn add_crls( self, crls: impl IntoIterator, - ) -> ClientBuilder { + ) -> ClientBuilder { self.with_inner(move |inner| inner.add_crls(crls)) } @@ -661,7 +697,10 @@ impl ClientBuilder { feature = "rustls-tls" ))) )] - pub fn tls_built_in_root_certs(self, tls_built_in_root_certs: bool) -> ClientBuilder { + pub fn tls_built_in_root_certs( + self, + tls_built_in_root_certs: bool, + ) -> ClientBuilder { self.with_inner(move |inner| inner.tls_built_in_root_certs(tls_built_in_root_certs)) } @@ -670,7 +709,7 @@ impl ClientBuilder { /// If the feature is enabled, this value is `true` by default. #[cfg(feature = "rustls-tls-webpki-roots-no-provider")] #[cfg_attr(docsrs, doc(cfg(feature = "rustls-tls-webpki-roots-no-provider")))] - pub fn tls_built_in_webpki_certs(self, enabled: bool) -> ClientBuilder { + pub fn tls_built_in_webpki_certs(self, enabled: bool) -> ClientBuilder { self.with_inner(move |inner| inner.tls_built_in_webpki_certs(enabled)) } @@ -679,7 +718,7 @@ impl ClientBuilder { /// If the feature is enabled, this value is `true` by default. #[cfg(feature = "rustls-tls-native-roots-no-provider")] #[cfg_attr(docsrs, doc(cfg(feature = "rustls-tls-native-roots-no-provider")))] - pub fn tls_built_in_native_certs(self, enabled: bool) -> ClientBuilder { + pub fn tls_built_in_native_certs(self, enabled: bool) -> ClientBuilder { self.with_inner(move |inner| inner.tls_built_in_native_certs(enabled)) } @@ -691,7 +730,7 @@ impl ClientBuilder { /// enabled. #[cfg(any(feature = "native-tls", feature = "__rustls"))] #[cfg_attr(docsrs, doc(cfg(any(feature = "native-tls", feature = "rustls-tls"))))] - pub fn identity(self, identity: Identity) -> ClientBuilder { + pub fn identity(self, identity: Identity) -> ClientBuilder { self.with_inner(move |inner| inner.identity(identity)) } @@ -719,7 +758,10 @@ impl ClientBuilder { feature = "rustls-tls" ))) )] - pub fn danger_accept_invalid_hostnames(self, accept_invalid_hostname: bool) -> ClientBuilder { + pub fn danger_accept_invalid_hostnames( + self, + accept_invalid_hostname: bool, + ) -> ClientBuilder { self.with_inner(|inner| inner.danger_accept_invalid_hostnames(accept_invalid_hostname)) } @@ -743,7 +785,10 @@ impl ClientBuilder { feature = "rustls-tls" ))) )] - pub fn danger_accept_invalid_certs(self, accept_invalid_certs: bool) -> ClientBuilder { + pub fn danger_accept_invalid_certs( + self, + accept_invalid_certs: bool, + ) -> ClientBuilder { self.with_inner(|inner| inner.danger_accept_invalid_certs(accept_invalid_certs)) } @@ -759,7 +804,7 @@ impl ClientBuilder { feature = "rustls-tls" ))) )] - pub fn tls_sni(self, tls_sni: bool) -> ClientBuilder { + pub fn tls_sni(self, tls_sni: bool) -> ClientBuilder { self.with_inner(|inner| inner.tls_sni(tls_sni)) } @@ -787,7 +832,7 @@ impl ClientBuilder { feature = "rustls-tls" ))) )] - pub fn min_tls_version(self, version: tls::Version) -> ClientBuilder { + pub fn min_tls_version(self, version: tls::Version) -> ClientBuilder { self.with_inner(|inner| inner.min_tls_version(version)) } @@ -815,7 +860,7 @@ impl ClientBuilder { feature = "rustls-tls" ))) )] - pub fn max_tls_version(self, version: tls::Version) -> ClientBuilder { + pub fn max_tls_version(self, version: tls::Version) -> ClientBuilder { self.with_inner(|inner| inner.max_tls_version(version)) } @@ -829,7 +874,7 @@ impl ClientBuilder { /// This requires the optional `native-tls` feature to be enabled. #[cfg(feature = "native-tls")] #[cfg_attr(docsrs, doc(cfg(feature = "native-tls")))] - pub fn use_native_tls(self) -> ClientBuilder { + pub fn use_native_tls(self) -> ClientBuilder { self.with_inner(move |inner| inner.use_native_tls()) } @@ -843,7 +888,7 @@ impl ClientBuilder { /// This requires the optional `rustls-tls(-...)` feature to be enabled. #[cfg(feature = "__rustls")] #[cfg_attr(docsrs, doc(cfg(feature = "rustls-tls")))] - pub fn use_rustls_tls(self) -> ClientBuilder { + pub fn use_rustls_tls(self) -> ClientBuilder { self.with_inner(move |inner| inner.use_rustls_tls()) } @@ -862,7 +907,7 @@ impl ClientBuilder { feature = "rustls-tls" ))) )] - pub fn tls_info(self, tls_info: bool) -> ClientBuilder { + pub fn tls_info(self, tls_info: bool) -> ClientBuilder { self.with_inner(|inner| inner.tls_info(tls_info)) } @@ -886,7 +931,7 @@ impl ClientBuilder { /// `rustls-tls(-...)` to be enabled. #[cfg(any(feature = "native-tls", feature = "__rustls",))] #[cfg_attr(docsrs, doc(cfg(any(feature = "native-tls", feature = "rustls-tls"))))] - pub fn use_preconfigured_tls(self, tls: impl Any) -> ClientBuilder { + pub fn use_preconfigured_tls(self, tls: impl Any) -> ClientBuilder { self.with_inner(move |inner| inner.use_preconfigured_tls(tls)) } @@ -900,7 +945,7 @@ impl ClientBuilder { #[cfg(feature = "hickory-dns")] #[cfg_attr(docsrs, doc(cfg(feature = "hickory-dns")))] #[deprecated(note = "use `hickory_dns` instead", since = "0.12.0")] - pub fn trust_dns(self, enable: bool) -> ClientBuilder { + pub fn trust_dns(self, enable: bool) -> ClientBuilder { self.with_inner(|inner| inner.hickory_dns(enable)) } @@ -913,7 +958,7 @@ impl ClientBuilder { /// This requires the optional `hickory-dns` feature to be enabled #[cfg(feature = "hickory-dns")] #[cfg_attr(docsrs, doc(cfg(feature = "hickory-dns")))] - pub fn hickory_dns(self, enable: bool) -> ClientBuilder { + pub fn hickory_dns(self, enable: bool) -> ClientBuilder { self.with_inner(|inner| inner.hickory_dns(enable)) } @@ -923,7 +968,7 @@ impl ClientBuilder { /// This can be used to ensure a `Client` doesn't use the hickory-dns async resolver /// even if another dependency were to enable the optional `hickory-dns` feature. #[deprecated(note = "use `no_hickory_dns` instead", since = "0.12.0")] - pub fn no_trust_dns(self) -> ClientBuilder { + pub fn no_trust_dns(self) -> ClientBuilder { self.with_inner(|inner| inner.no_hickory_dns()) } @@ -932,14 +977,14 @@ impl ClientBuilder { /// This method exists even if the optional `hickory-dns` feature is not enabled. /// This can be used to ensure a `Client` doesn't use the hickory-dns async resolver /// even if another dependency were to enable the optional `hickory-dns` feature. - pub fn no_hickory_dns(self) -> ClientBuilder { + pub fn no_hickory_dns(self) -> ClientBuilder { self.with_inner(|inner| inner.no_hickory_dns()) } /// Restrict the Client to be used with HTTPS only requests. /// /// Defaults to false. - pub fn https_only(self, enabled: bool) -> ClientBuilder { + pub fn https_only(self, enabled: bool) -> ClientBuilder { self.with_inner(|inner| inner.https_only(enabled)) } @@ -947,7 +992,7 @@ impl ClientBuilder { /// /// Set the port to `0` to use the conventional port for the given scheme (e.g. 80 for http). /// Ports in the URL itself will always be used instead of the port in the overridden addr. - pub fn resolve(self, domain: &str, addr: SocketAddr) -> ClientBuilder { + pub fn resolve(self, domain: &str, addr: SocketAddr) -> ClientBuilder { self.resolve_to_addrs(domain, &[addr]) } @@ -955,7 +1000,11 @@ impl ClientBuilder { /// /// Set the port to `0` to use the conventional port for the given scheme (e.g. 80 for http). /// Ports in the URL itself will always be used instead of the port in the overridden addr. - pub fn resolve_to_addrs(self, domain: &str, addrs: &[SocketAddr]) -> ClientBuilder { + pub fn resolve_to_addrs( + self, + domain: &str, + addrs: &[SocketAddr], + ) -> ClientBuilder { self.with_inner(|inner| inner.resolve_to_addrs(domain, addrs)) } @@ -964,23 +1013,63 @@ impl ClientBuilder { /// Pass an `Arc` wrapping a trait object implementing `Resolve`. /// Overrides for specific names passed to `resolve` and `resolve_to_addrs` will /// still be applied on top of this resolver. - pub fn dns_resolver(self, resolver: Arc) -> ClientBuilder { + pub fn dns_resolver( + self, + resolver: Arc, + ) -> ClientBuilder { self.with_inner(|inner| inner.dns_resolver(resolver)) } + /// Adds a new Tower [`Layer`](https://docs.rs/tower/latest/tower/trait.Layer.html) to the + /// base connector [`Service`](https://docs.rs/tower/latest/tower/trait.Service.html) which + /// is responsible for connection establishment. + /// + /// Each subsequent invocation of this function will wrap previous layers. + /// + /// Simple example: + /// ``` + /// use std::time::Duration; + /// + /// let client = reqwest::blocking::Client::builder() + /// // resolved to outermost layer, so before the semaphore permit is attempted + /// .connect_timeout(Duration::from_millis(100)) + /// // underneath the concurrency check, so only after a semaphore permit is acquired + /// .connector_layer(tower::timeout::TimeoutLayer::new(Duration::from_millis(50))) + /// .connector_layer(tower::limit::concurrency::ConcurrencyLimitLayer::new(2)) + /// .build() + /// .unwrap(); + /// ``` + pub fn connector_layer(self, layer: L) -> ClientBuilder> + where + S: Send + Sync + Clone + 'static, + L: Layer + Send + Sync + Clone + 'static, + { + // skipping using `with_inner` here because we need to cast the generic type + let inner = self.inner.connector_layer(layer); + + ClientBuilder::> { + inner, + timeout: self.timeout, + } + } + // private - fn with_inner(mut self, func: F) -> ClientBuilder + fn with_inner(mut self, func: F) -> ClientBuilder where - F: FnOnce(async_impl::ClientBuilder) -> async_impl::ClientBuilder, + F: FnOnce( + async_impl::ClientBuilder, + ) -> async_impl::ClientBuilder, { self.inner = func(self.inner); self } } -impl From for ClientBuilder { - fn from(builder: async_impl::ClientBuilder) -> Self { +impl From> + for ClientBuilder +{ + fn from(builder: async_impl::ClientBuilder) -> Self { Self { inner: builder, timeout: Timeout::default(), @@ -1014,7 +1103,7 @@ impl Client { /// Creates a `ClientBuilder` to configure a `Client`. /// /// This is the same as `ClientBuilder::new()`. - pub fn builder() -> ClientBuilder { + pub fn builder() -> ClientBuilder { ClientBuilder::new() } @@ -1112,7 +1201,7 @@ impl fmt::Debug for Client { } } -impl fmt::Debug for ClientBuilder { +impl fmt::Debug for ClientBuilder { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.inner.fmt(f) } @@ -1149,7 +1238,14 @@ impl Drop for InnerClientHandle { } impl ClientHandle { - fn new(builder: ClientBuilder) -> crate::Result { + fn new(builder: ClientBuilder) -> crate::Result + where + ConnectorLayers: Layer + Send + Sync + 'static, + ConnectorLayers::Service: + Service + Clone + Send + Sync + 'static, + <>::Service as Service>::Future: + Send + 'static, + { let timeout = builder.timeout; let builder = builder.inner; let (tx, rx) = mpsc::unbounded_channel::<(async_impl::Request, OneshotResponse)>(); diff --git a/src/connect.rs b/src/connect.rs index ff86ba3c9..e1fc599b1 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -8,6 +8,7 @@ use hyper_util::client::legacy::connect::{Connected, Connection}; use hyper_util::rt::TokioIo; #[cfg(feature = "default-tls")] use native_tls_crate::{TlsConnector, TlsConnectorBuilder}; +use tower::{timeout::TimeoutLayer, util::BoxCloneSyncService, Layer, ServiceBuilder}; use tower_service::Service; use pin_project_lite::pin_project; @@ -24,13 +25,50 @@ use self::native_tls_conn::NativeTlsConn; #[cfg(feature = "__rustls")] use self::rustls_tls_conn::RustlsTlsConn; use crate::dns::DynResolver; -use crate::error::BoxError; +use crate::error::{cast_to_internal_error, BoxError}; use crate::proxy::{Proxy, ProxyScheme}; pub(crate) type HttpConnector = hyper_util::client::legacy::connect::HttpConnector; #[derive(Clone)] -pub(crate) struct Connector { +pub(crate) enum Connector { + // base service, with or without an embedded timeout + Simple(ConnectorService), + // at least one custom layer along with maybe an outer timeout layer + // from `builder.connect_timeout()` + WithLayers(BoxCloneSyncService), +} + +impl Service for Connector { + type Response = Conn; + type Error = BoxError; + type Future = Connecting; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + match self { + Connector::Simple(service) => service.poll_ready(cx), + Connector::WithLayers(service) => service.poll_ready(cx), + } + } + + fn call(&mut self, dst: Uri) -> Self::Future { + match self { + Connector::Simple(service) => service.call(dst), + Connector::WithLayers(service) => service.call(dst), + } + } +} + +pub(crate) struct ConnectorLayerBuilder { + pub(crate) builder: ServiceBuilder, + // It's not trivial to identify whether the builder stack is `Stack` or not + // so we simply add a boolean flag to track. + // + // Knowing allows us reduce indirection in certain cases. + pub(crate) has_custom_layers: bool, +} + +pub(crate) struct ConnectorBuilder { inner: Inner, proxies: Arc>, verbose: verbose::Wrapper, @@ -43,21 +81,66 @@ pub(crate) struct Connector { user_agent: Option, } -#[derive(Clone)] -enum Inner { - #[cfg(not(feature = "__tls"))] - Http(HttpConnector), - #[cfg(feature = "default-tls")] - DefaultTls(HttpConnector, TlsConnector), - #[cfg(feature = "__rustls")] - RustlsTls { - http: HttpConnector, - tls: Arc, - tls_proxy: Arc, - }, -} +impl ConnectorBuilder { + pub(crate) fn build( + self, + layer: ConnectorLayerBuilder, + ) -> Connector + where + ConnectorLayers: Layer, + ConnectorLayers::Service: + Service + Clone + Send + Sync + 'static, + <>::Service as Service>::Future: + Send + 'static, + { + // construct the inner tower service + let mut base_service = ConnectorService { + inner: self.inner, + proxies: self.proxies, + verbose: self.verbose, + #[cfg(feature = "__tls")] + nodelay: self.nodelay, + #[cfg(feature = "__tls")] + tls_info: self.tls_info, + #[cfg(feature = "__tls")] + user_agent: self.user_agent, + simple_timeout: None, + }; + + // no user-provider layers so we can throw away our generic input layer stack + // and compose with named layers only + if !layer.has_custom_layers { + // if we know we have no other layers, we can embed the timeout directly inside + // our base service call which saves us a Box::pin + base_service.simple_timeout = self.timeout; + return Connector::Simple(base_service); + } + + // we have user-provided generic layer stack + let service = layer.builder.service(base_service); + + if let Some(timeout) = self.timeout { + // add in named timeout layer on the outside of the stack + let service = ServiceBuilder::new() + .layer(TimeoutLayer::new(timeout)) + .service(service); + let service = ServiceBuilder::new() + .map_err(|error: BoxError| cast_to_internal_error(error)) + .service(service); + let service = BoxCloneSyncService::new(service); + return Connector::WithLayers(service); + } + + // no named timeout layer but we still map errors since + // we might have user-provided timeout layer + let service = ServiceBuilder::new().service(service); + let service = ServiceBuilder::new() + .map_err(|error: BoxError| cast_to_internal_error(error)) + .service(service); + let service = BoxCloneSyncService::new(service); + return Connector::WithLayers(service); + } -impl Connector { #[cfg(not(feature = "__tls"))] pub(crate) fn new( mut http: HttpConnector, @@ -66,7 +149,7 @@ impl Connector { #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] interface: Option<&str>, nodelay: bool, - ) -> Connector + ) -> ConnectorBuilder where T: Into>, { @@ -77,10 +160,10 @@ impl Connector { } http.set_nodelay(nodelay); - Connector { + ConnectorBuilder { inner: Inner::Http(http), - verbose: verbose::OFF, proxies, + verbose: verbose::OFF, timeout: None, } } @@ -96,7 +179,7 @@ impl Connector { interface: Option<&str>, nodelay: bool, tls_info: bool, - ) -> crate::Result + ) -> crate::Result where T: Into>, { @@ -125,7 +208,7 @@ impl Connector { interface: Option<&str>, nodelay: bool, tls_info: bool, - ) -> Connector + ) -> ConnectorBuilder where T: Into>, { @@ -137,14 +220,14 @@ impl Connector { http.set_nodelay(nodelay); http.enforce_http(false); - Connector { + ConnectorBuilder { inner: Inner::DefaultTls(http, tls), proxies, verbose: verbose::OFF, - timeout: None, nodelay, tls_info, user_agent, + timeout: None, } } @@ -159,7 +242,7 @@ impl Connector { interface: Option<&str>, nodelay: bool, tls_info: bool, - ) -> Connector + ) -> ConnectorBuilder where T: Into>, { @@ -180,7 +263,7 @@ impl Connector { (Arc::new(tls), Arc::new(tls_proxy)) }; - Connector { + ConnectorBuilder { inner: Inner::RustlsTls { http, tls, @@ -188,10 +271,10 @@ impl Connector { }, proxies, verbose: verbose::OFF, - timeout: None, nodelay, tls_info, user_agent, + timeout: None, } } @@ -203,6 +286,57 @@ impl Connector { self.verbose.0 = enabled; } + pub(crate) fn set_keepalive(&mut self, dur: Option) { + match &mut self.inner { + #[cfg(feature = "default-tls")] + Inner::DefaultTls(http, _tls) => http.set_keepalive(dur), + #[cfg(feature = "__rustls")] + Inner::RustlsTls { http, .. } => http.set_keepalive(dur), + #[cfg(not(feature = "__tls"))] + Inner::Http(http) => http.set_keepalive(dur), + } + } +} + +// Struct is public because we can't have private trait in public bounds +// until we are on MSRV 1.74 when private-in-public was switched +// from error to lint - https://github.com/rust-lang/rfcs/pull/2145 +// but no internal details are exposed. We don't expose debug +// for similar reasons. +#[allow(missing_debug_implementations)] +#[derive(Clone)] +pub struct ConnectorService { + inner: Inner, + proxies: Arc>, + verbose: verbose::Wrapper, + /// When there is a single timeout layer and no other layers, + /// we embed it directly inside our base Service::call(). + /// This lets us avoid an extra `Box::pin` indirection layer + /// since `tokio::time::Timeout` is `Unpin` + simple_timeout: Option, + #[cfg(feature = "__tls")] + nodelay: bool, + #[cfg(feature = "__tls")] + tls_info: bool, + #[cfg(feature = "__tls")] + user_agent: Option, +} + +#[derive(Clone)] +enum Inner { + #[cfg(not(feature = "__tls"))] + Http(HttpConnector), + #[cfg(feature = "default-tls")] + DefaultTls(HttpConnector, TlsConnector), + #[cfg(feature = "__rustls")] + RustlsTls { + http: HttpConnector, + tls: Arc, + tls_proxy: Arc, + }, +} + +impl ConnectorService { #[cfg(feature = "socks")] async fn connect_socks(&self, dst: Uri, proxy: ProxyScheme) -> Result { let dns = match proxy { @@ -449,17 +583,6 @@ impl Connector { self.connect_with_maybe_proxy(proxy_dst, true).await } - - pub fn set_keepalive(&mut self, dur: Option) { - match &mut self.inner { - #[cfg(feature = "default-tls")] - Inner::DefaultTls(http, _tls) => http.set_keepalive(dur), - #[cfg(feature = "__rustls")] - Inner::RustlsTls { http, .. } => http.set_keepalive(dur), - #[cfg(not(feature = "__tls"))] - Inner::Http(http) => http.set_keepalive(dur), - } - } } fn into_uri(scheme: Scheme, host: Authority) -> Uri { @@ -487,7 +610,7 @@ where } } -impl Service for Connector { +impl Service for ConnectorService { type Response = Conn; type Error = BoxError; type Future = Connecting; @@ -498,7 +621,7 @@ impl Service for Connector { fn call(&mut self, dst: Uri) -> Self::Future { log::debug!("starting new connection: {dst:?}"); - let timeout = self.timeout; + let timeout = self.simple_timeout; for prox in self.proxies.iter() { if let Some(proxy_scheme) = prox.intercept(&dst) { return Box::pin(with_timeout( @@ -634,11 +757,18 @@ impl AsyncConnWithInfo for T {} type BoxConn = Box; pin_project! { + // Struct is public because we can't have private trait in public bounds + // until we are on MSRV 1.74 when private-in-public was switched + // from error to lint - https://github.com/rust-lang/rfcs/pull/2145 + // but no internal details are exposed. We don't expose debug + // for similar reasons. + /// Note: the `is_proxy` member means *is plain text HTTP proxy*. /// This tells hyper whether the URI should be written in /// * origin-form (`GET /just/a/path HTTP/1.1`), when `is_proxy == false`, or /// * absolute-form (`GET http://foo.bar/and/a/path HTTP/1.1`), otherwise. - pub(crate) struct Conn { + #[allow(missing_debug_implementations)] + pub struct Conn { #[pin] inner: BoxConn, is_proxy: bool, diff --git a/src/error.rs b/src/error.rs index ca7413fd6..6a9f07e51 100644 --- a/src/error.rs +++ b/src/error.rs @@ -165,6 +165,18 @@ impl Error { } } +/// Converts from external types to reqwest's +/// internal equivalents. +/// +/// Currently only is used for `tower::timeout::error::Elapsed`. +pub(crate) fn cast_to_internal_error(error: BoxError) -> BoxError { + if error.is::() { + Box::new(crate::error::TimedOut) as BoxError + } else { + error + } +} + impl fmt::Debug for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let mut builder = f.debug_struct("reqwest::Error"); diff --git a/tests/connector_layers.rs b/tests/connector_layers.rs new file mode 100644 index 000000000..1be18aeb8 --- /dev/null +++ b/tests/connector_layers.rs @@ -0,0 +1,374 @@ +#![cfg(not(target_arch = "wasm32"))] +#![cfg(not(feature = "rustls-tls-manual-roots-no-provider"))] +mod support; + +use std::time::Duration; + +use futures_util::future::join_all; +use tower::layer::util::Identity; +use tower::limit::ConcurrencyLimitLayer; +use tower::timeout::TimeoutLayer; + +use support::{delay_layer::DelayLayer, server}; + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn non_op_layer() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::builder() + .connector_layer(Identity::new()) + .no_proxy() + .build() + .unwrap(); + + let res = client.get(url).send().await; + + assert!(res.is_ok()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn non_op_layer_with_timeout() { + let _ = env_logger::try_init(); + + let client = reqwest::Client::builder() + .connector_layer(Identity::new()) + .connect_timeout(Duration::from_millis(200)) + .no_proxy() + .build() + .unwrap(); + + // never returns + let url = "http://192.0.2.1:81/slow"; + + let res = client.get(url).send().await; + + let err = res.unwrap_err(); + + assert!(err.is_connect() && err.is_timeout()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn with_connect_timeout_layer_never_returning() { + let _ = env_logger::try_init(); + + let client = reqwest::Client::builder() + .connector_layer(TimeoutLayer::new(Duration::from_millis(100))) + .no_proxy() + .build() + .unwrap(); + + // never returns + let url = "http://192.0.2.1:81/slow"; + + let res = client.get(url).send().await; + + let err = res.unwrap_err(); + + assert!(err.is_connect() && err.is_timeout()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn with_connect_timeout_layer_slow() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(200))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(100))) + .no_proxy() + .build() + .unwrap(); + + let res = client.get(url).send().await; + + let err = res.unwrap_err(); + + assert!(err.is_connect() && err.is_timeout()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn multiple_timeout_layers_under_threshold() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(200))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(300))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(500))) + .connect_timeout(Duration::from_millis(200)) + .no_proxy() + .build() + .unwrap(); + + let res = client.get(url).send().await; + + assert!(res.is_ok()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn multiple_timeout_layers_over_threshold() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(50))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(50))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(50))) + .connect_timeout(Duration::from_millis(50)) + .no_proxy() + .build() + .unwrap(); + + let res = client.get(url).send().await; + + let err = res.unwrap_err(); + + assert!(err.is_connect() && err.is_timeout()); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn with_concurrency_limit_layer_timeout() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(ConcurrencyLimitLayer::new(1)) + .timeout(Duration::from_millis(200)) + .pool_max_idle_per_host(0) // disable connection reuse to force resource contention on the concurrency limit semaphore + .no_proxy() + .build() + .unwrap(); + + // first call succeeds since no resource contention + let res = client.get(url.clone()).send().await; + assert!(res.is_ok()); + + // 3 calls where the second two wait on the first and time out + let mut futures = Vec::new(); + for _ in 0..3 { + futures.push(client.clone().get(url.clone()).send()); + } + + let all_res = join_all(futures).await; + + let timed_out = all_res + .into_iter() + .any(|res| res.is_err_and(|err| err.is_timeout())); + + assert!(timed_out, "at least one request should have timed out"); +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn with_concurrency_limit_layer_success() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(200))) + .connector_layer(ConcurrencyLimitLayer::new(1)) + .timeout(Duration::from_millis(1000)) + .pool_max_idle_per_host(0) // disable connection reuse to force resource contention on the concurrency limit semaphore + .no_proxy() + .build() + .unwrap(); + + // first call succeeds since no resource contention + let res = client.get(url.clone()).send().await; + assert!(res.is_ok()); + + // 3 calls of which all are individually below the inner timeout + // and the sum is below outer timeout which affects the final call which waited the whole time + let mut futures = Vec::new(); + for _ in 0..3 { + futures.push(client.clone().get(url.clone()).send()); + } + + let all_res = join_all(futures).await; + + for res in all_res.into_iter() { + assert!( + res.is_ok(), + "neither outer long timeout or inner short timeout should be exceeded" + ); + } +} + +#[cfg(feature = "blocking")] +#[test] +fn non_op_layer_blocking_client() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::blocking::Client::builder() + .connector_layer(Identity::new()) + .build() + .unwrap(); + + let res = client.get(url).send(); + + assert!(res.is_ok()); +} + +#[cfg(feature = "blocking")] +#[test] +fn timeout_layer_blocking_client() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::blocking::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(50))) + .no_proxy() + .build() + .unwrap(); + + let res = client.get(url).send(); + let err = res.unwrap_err(); + + assert!(err.is_connect() && err.is_timeout()); +} + +#[cfg(feature = "blocking")] +#[test] +fn concurrency_layer_blocking_client_timeout() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::blocking::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(ConcurrencyLimitLayer::new(1)) + .timeout(Duration::from_millis(200)) + .pool_max_idle_per_host(0) // disable connection reuse to force resource contention on the concurrency limit semaphore + .build() + .unwrap(); + + let res = client.get(url.clone()).send(); + + assert!(res.is_ok()); + + // 3 calls where the second two wait on the first and time out + let mut join_handles = Vec::new(); + for _ in 0..3 { + let client = client.clone(); + let url = url.clone(); + let join_handle = std::thread::spawn(move || client.get(url.clone()).send()); + join_handles.push(join_handle); + } + + let timed_out = join_handles + .into_iter() + .any(|handle| handle.join().unwrap().is_err_and(|err| err.is_timeout())); + + assert!(timed_out, "at least one request should have timed out"); +} + +#[cfg(feature = "blocking")] +#[test] +fn concurrency_layer_blocking_client_success() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::blocking::Client::builder() + .connector_layer(DelayLayer::new(Duration::from_millis(100))) + .connector_layer(TimeoutLayer::new(Duration::from_millis(200))) + .connector_layer(ConcurrencyLimitLayer::new(1)) + .timeout(Duration::from_millis(1000)) + .pool_max_idle_per_host(0) // disable connection reuse to force resource contention on the concurrency limit semaphore + .build() + .unwrap(); + + let res = client.get(url.clone()).send(); + + assert!(res.is_ok()); + + // 3 calls of which all are individually below the inner timeout + // and the sum is below outer timeout which affects the final call which waited the whole time + let mut join_handles = Vec::new(); + for _ in 0..3 { + let client = client.clone(); + let url = url.clone(); + let join_handle = std::thread::spawn(move || client.get(url.clone()).send()); + join_handles.push(join_handle); + } + + for handle in join_handles { + let res = handle.join().unwrap(); + assert!( + res.is_ok(), + "neither outer long timeout or inner short timeout should be exceeded" + ); + } +} + +#[cfg(not(target_arch = "wasm32"))] +#[tokio::test] +async fn no_generic_bounds_required_for_client_new() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::Client::new(); + let res = client.get(url).send().await; + + assert!(res.is_ok()); +} + +#[cfg(feature = "blocking")] +#[test] +fn no_generic_bounds_required_for_client_new_blocking() { + let _ = env_logger::try_init(); + + let server = server::http(move |_req| async { http::Response::default() }); + + let url = format!("http://{}", server.addr()); + + let client = reqwest::blocking::Client::new(); + let res = client.get(url).send(); + + assert!(res.is_ok()); +} diff --git a/tests/support/delay_layer.rs b/tests/support/delay_layer.rs new file mode 100644 index 000000000..b8eec42a1 --- /dev/null +++ b/tests/support/delay_layer.rs @@ -0,0 +1,119 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; + +use pin_project_lite::pin_project; +use tokio::time::Sleep; +use tower::{BoxError, Layer, Service}; + +/// This tower layer injects an arbitrary delay before calling downstream layers. +#[derive(Clone)] +pub struct DelayLayer { + delay: Duration, +} + +impl DelayLayer { + pub const fn new(delay: Duration) -> Self { + DelayLayer { delay } + } +} + +impl Layer for DelayLayer { + type Service = Delay; + fn layer(&self, service: S) -> Self::Service { + Delay::new(service, self.delay) + } +} + +impl std::fmt::Debug for DelayLayer { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("DelayLayer") + .field("delay", &self.delay) + .finish() + } +} + +/// This tower service injects an arbitrary delay before calling downstream layers. +#[derive(Debug, Clone)] +pub struct Delay { + inner: S, + delay: Duration, +} +impl Delay { + pub fn new(inner: S, delay: Duration) -> Self { + Delay { inner, delay } + } +} + +impl Service for Delay +where + S: Service, + S::Error: Into, +{ + type Response = S::Response; + + type Error = BoxError; + + type Future = ResponseFuture; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.inner.poll_ready(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(r) => Poll::Ready(r.map_err(Into::into)), + } + } + + fn call(&mut self, req: Request) -> Self::Future { + let response = self.inner.call(req); + let sleep = tokio::time::sleep(self.delay); + + ResponseFuture::new(response, sleep) + } +} + +// `Delay` response future +pin_project! { + #[derive(Debug)] + pub struct ResponseFuture { + #[pin] + response: S, + #[pin] + sleep: Sleep, + } +} + +impl ResponseFuture { + pub(crate) fn new(response: S, sleep: Sleep) -> Self { + ResponseFuture { response, sleep } + } +} + +impl Future for ResponseFuture +where + F: Future>, + E: Into, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + // First poll the sleep until complete + match this.sleep.poll(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(_) => {} + } + + // Then poll the inner future + match this.response.poll(cx) { + Poll::Ready(v) => Poll::Ready(v.map_err(Into::into)), + Poll::Pending => Poll::Pending, + } + } +} diff --git a/tests/support/mod.rs b/tests/support/mod.rs index c796956d8..9d4ce7b9b 100644 --- a/tests/support/mod.rs +++ b/tests/support/mod.rs @@ -1,3 +1,4 @@ +pub mod delay_layer; pub mod delay_server; pub mod server; diff --git a/tests/timeouts.rs b/tests/timeouts.rs index 79a6fbb4d..71dc0ce66 100644 --- a/tests/timeouts.rs +++ b/tests/timeouts.rs @@ -337,6 +337,24 @@ fn timeout_blocking_request() { assert_eq!(err.url().map(|u| u.as_str()), Some(url.as_str())); } +#[cfg(feature = "blocking")] +#[test] +fn connect_timeout_blocking_request() { + let _ = env_logger::try_init(); + + let client = reqwest::blocking::Client::builder() + .connect_timeout(Duration::from_millis(100)) + .build() + .unwrap(); + + // never returns + let url = "http://192.0.2.1:81/slow"; + + let err = client.get(url).send().unwrap_err(); + + assert!(err.is_timeout()); +} + #[cfg(feature = "blocking")] #[cfg(feature = "stream")] #[test]