diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index aabd79d16..74fbae5ee 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -103,23 +103,23 @@ jobs: - name: windows / stable-x86_64-msvc os: windows-latest target: x86_64-pc-windows-msvc - features: "--features blocking,gzip,brotli,deflate,json,multipart,stream" + features: "--features blocking,gzip,brotli,zstd,deflate,json,multipart,stream" - name: windows / stable-i686-msvc os: windows-latest target: i686-pc-windows-msvc - features: "--features blocking,gzip,brotli,deflate,json,multipart,stream" + features: "--features blocking,gzip,brotli,zstd,deflate,json,multipart,stream" - name: windows / stable-x86_64-gnu os: windows-latest rust: stable-x86_64-pc-windows-gnu target: x86_64-pc-windows-gnu - features: "--features blocking,gzip,brotli,deflate,json,multipart,stream" + features: "--features blocking,gzip,brotli,zstd,deflate,json,multipart,stream" package_name: mingw-w64-x86_64-gcc mingw64_path: "C:\\msys64\\mingw64\\bin" - name: windows / stable-i686-gnu os: windows-latest rust: stable-i686-pc-windows-gnu target: i686-pc-windows-gnu - features: "--features blocking,gzip,brotli,deflate,json,multipart,stream" + features: "--features blocking,gzip,brotli,zstd,deflate,json,multipart,stream" package_name: mingw-w64-i686-gcc mingw64_path: "C:\\msys64\\mingw32\\bin" @@ -145,6 +145,8 @@ jobs: features: "--features gzip,stream" - name: "feat.: brotli" features: "--features brotli,stream" + - name: "feat.: zstd" + features: "--features zstd,stream" - name: "feat.: deflate" features: "--features deflate,stream" - name: "feat.: json" diff --git a/Cargo.toml b/Cargo.toml index 859a3c92c..23c0476bf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -55,6 +55,8 @@ gzip = ["dep:async-compression", "async-compression?/gzip", "dep:tokio-util"] brotli = ["dep:async-compression", "async-compression?/brotli", "dep:tokio-util"] +zstd = ["dep:async-compression", "async-compression?/zstd", "dep:tokio-util"] + deflate = ["dep:async-compression", "async-compression?/zlib", "dep:tokio-util"] json = ["dep:serde_json"] @@ -167,6 +169,7 @@ hyper-util = { version = "0.1", features = ["http1", "http2", "client", "client- serde = { version = "1.0", features = ["derive"] } libflate = "1.0" brotli_crate = { package = "brotli", version = "3.3.0" } +zstd_crate = { package = "zstd", version = "0.13" } doc-comment = "0.3" tokio = { version = "1.0", default-features = false, features = ["macros", "rt-multi-thread"] } futures-util = { version = "0.3.0", default-features = false, features = ["std", "alloc"] } @@ -258,6 +261,11 @@ name = "brotli" path = "tests/brotli.rs" required-features = ["brotli", "stream"] +[[test]] +name = "zstd" +path = "tests/zstd.rs" +required-features = ["zstd", "stream"] + [[test]] name = "deflate" path = "tests/deflate.rs" diff --git a/src/async_impl/client.rs b/src/async_impl/client.rs index a64869582..7a3f43937 100644 --- a/src/async_impl/client.rs +++ b/src/async_impl/client.rs @@ -915,6 +915,29 @@ impl ClientBuilder { self } + /// Enable auto zstd decompression by checking the `Content-Encoding` response header. + /// + /// If auto zstd decompression is turned on: + /// + /// - When sending a request and if the request's headers do not already contain + /// an `Accept-Encoding` **and** `Range` values, the `Accept-Encoding` header is set to `zstd`. + /// The request body is **not** automatically compressed. + /// - When receiving a response, if its headers contain a `Content-Encoding` value of + /// `zstd`, both `Content-Encoding` and `Content-Length` are removed from the + /// headers' set. The response body is automatically decompressed. + /// + /// If the `zstd` feature is turned on, the default option is enabled. + /// + /// # Optional + /// + /// This requires the optional `zstd` feature to be enabled + #[cfg(feature = "zstd")] + #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))] + pub fn zstd(mut self, enable: bool) -> ClientBuilder { + self.config.accepts.zstd = enable; + self + } + /// Enable auto deflate decompression by checking the `Content-Encoding` response header. /// /// If auto deflate decompression is turned on: @@ -972,6 +995,23 @@ impl ClientBuilder { } } + /// Disable auto response body zstd decompression. + /// + /// This method exists even if the optional `zstd` feature is not enabled. + /// This can be used to ensure a `Client` doesn't use zstd decompression + /// even if another dependency were to enable the optional `zstd` feature. + pub fn no_zstd(self) -> ClientBuilder { + #[cfg(feature = "zstd")] + { + self.zstd(false) + } + + #[cfg(not(feature = "zstd"))] + { + self + } + } + /// Disable auto response body deflate decompression. /// /// This method exists even if the optional `deflate` feature is not enabled. diff --git a/src/async_impl/decoder.rs b/src/async_impl/decoder.rs index 128f77ecb..4e05428ae 100644 --- a/src/async_impl/decoder.rs +++ b/src/async_impl/decoder.rs @@ -9,6 +9,9 @@ use async_compression::tokio::bufread::GzipDecoder; #[cfg(feature = "brotli")] use async_compression::tokio::bufread::BrotliDecoder; +#[cfg(feature = "zstd")] +use async_compression::tokio::bufread::ZstdDecoder; + #[cfg(feature = "deflate")] use async_compression::tokio::bufread::ZlibDecoder; @@ -19,9 +22,19 @@ use http::HeaderMap; use hyper::body::Body as HttpBody; use hyper::body::Frame; -#[cfg(any(feature = "gzip", feature = "brotli", feature = "deflate"))] +#[cfg(any( + feature = "gzip", + feature = "brotli", + feature = "zstd", + feature = "deflate" +))] use tokio_util::codec::{BytesCodec, FramedRead}; -#[cfg(any(feature = "gzip", feature = "brotli", feature = "deflate"))] +#[cfg(any( + feature = "gzip", + feature = "brotli", + feature = "zstd", + feature = "deflate" +))] use tokio_util::io::StreamReader; use super::body::ResponseBody; @@ -33,6 +46,8 @@ pub(super) struct Accepts { pub(super) gzip: bool, #[cfg(feature = "brotli")] pub(super) brotli: bool, + #[cfg(feature = "zstd")] + pub(super) zstd: bool, #[cfg(feature = "deflate")] pub(super) deflate: bool, } @@ -44,6 +59,8 @@ impl Accepts { gzip: false, #[cfg(feature = "brotli")] brotli: false, + #[cfg(feature = "zstd")] + zstd: false, #[cfg(feature = "deflate")] deflate: false, } @@ -59,7 +76,12 @@ pub(crate) struct Decoder { type PeekableIoStream = Peekable; -#[cfg(any(feature = "gzip", feature = "brotli", feature = "deflate"))] +#[cfg(any( + feature = "gzip", + feature = "zstd", + feature = "brotli", + feature = "deflate" +))] type PeekableIoStreamReader = StreamReader; enum Inner { @@ -74,12 +96,21 @@ enum Inner { #[cfg(feature = "brotli")] Brotli(Pin, BytesCodec>>>), + /// A `Zstd` decoder will uncompress the zstd compressed response content before returning it. + #[cfg(feature = "zstd")] + Zstd(Pin, BytesCodec>>>), + /// A `Deflate` decoder will uncompress the deflated response content before returning it. #[cfg(feature = "deflate")] Deflate(Pin, BytesCodec>>>), /// A decoder that doesn't have a value yet. - #[cfg(any(feature = "brotli", feature = "gzip", feature = "deflate"))] + #[cfg(any( + feature = "brotli", + feature = "zstd", + feature = "gzip", + feature = "deflate" + ))] Pending(Pin>), } @@ -93,6 +124,8 @@ enum DecoderType { Gzip, #[cfg(feature = "brotli")] Brotli, + #[cfg(feature = "zstd")] + Zstd, #[cfg(feature = "deflate")] Deflate, } @@ -155,6 +188,21 @@ impl Decoder { } } + /// A zstd decoder. + /// + /// This decoder will buffer and decompress chunks that are zstd compressed. + #[cfg(feature = "zstd")] + fn zstd(body: ResponseBody) -> Decoder { + use futures_util::StreamExt; + + Decoder { + inner: Inner::Pending(Box::pin(Pending( + IoStream(body).peekable(), + DecoderType::Zstd, + ))), + } + } + /// A deflate decoder. /// /// This decoder will buffer and decompress chunks that are deflated. @@ -170,7 +218,12 @@ impl Decoder { } } - #[cfg(any(feature = "brotli", feature = "gzip", feature = "deflate"))] + #[cfg(any( + feature = "brotli", + feature = "zstd", + feature = "gzip", + feature = "deflate" + ))] fn detect_encoding(headers: &mut HeaderMap, encoding_str: &str) -> bool { use http::header::{CONTENT_ENCODING, CONTENT_LENGTH, TRANSFER_ENCODING}; use log::warn; @@ -225,6 +278,13 @@ impl Decoder { } } + #[cfg(feature = "zstd")] + { + if _accepts.zstd && Decoder::detect_encoding(_headers, "zstd") { + return Decoder::zstd(body); + } + } + #[cfg(feature = "deflate")] { if _accepts.deflate && Decoder::detect_encoding(_headers, "deflate") { @@ -245,7 +305,12 @@ impl HttpBody for Decoder { cx: &mut Context, ) -> Poll, Self::Error>>> { match self.inner { - #[cfg(any(feature = "brotli", feature = "gzip", feature = "deflate"))] + #[cfg(any( + feature = "brotli", + feature = "zstd", + feature = "gzip", + feature = "deflate" + ))] Inner::Pending(ref mut future) => match Pin::new(future).poll(cx) { Poll::Ready(Ok(inner)) => { self.inner = inner; @@ -277,6 +342,14 @@ impl HttpBody for Decoder { None => Poll::Ready(None), } } + #[cfg(feature = "zstd")] + Inner::Zstd(ref mut decoder) => { + match futures_core::ready!(Pin::new(decoder).poll_next(cx)) { + Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))), + Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))), + None => Poll::Ready(None), + } + } #[cfg(feature = "deflate")] Inner::Deflate(ref mut decoder) => { match futures_core::ready!(Pin::new(decoder).poll_next(cx)) { @@ -292,7 +365,12 @@ impl HttpBody for Decoder { match self.inner { Inner::PlainText(ref body) => HttpBody::size_hint(body), // the rest are "unknown", so default - #[cfg(any(feature = "brotli", feature = "gzip", feature = "deflate"))] + #[cfg(any( + feature = "brotli", + feature = "zstd", + feature = "gzip", + feature = "deflate" + ))] _ => http_body::SizeHint::default(), } } @@ -332,6 +410,11 @@ impl Future for Pending { BrotliDecoder::new(StreamReader::new(_body)), BytesCodec::new(), ))))), + #[cfg(feature = "zstd")] + DecoderType::Zstd => Poll::Ready(Ok(Inner::Zstd(Box::pin(FramedRead::new( + ZstdDecoder::new(StreamReader::new(_body)), + BytesCodec::new(), + ))))), #[cfg(feature = "gzip")] DecoderType::Gzip => Poll::Ready(Ok(Inner::Gzip(Box::pin(FramedRead::new( GzipDecoder::new(StreamReader::new(_body)), @@ -381,6 +464,8 @@ impl Accepts { gzip: false, #[cfg(feature = "brotli")] brotli: false, + #[cfg(feature = "zstd")] + zstd: false, #[cfg(feature = "deflate")] deflate: false, } @@ -388,15 +473,28 @@ impl Accepts { */ pub(super) fn as_str(&self) -> Option<&'static str> { - match (self.is_gzip(), self.is_brotli(), self.is_deflate()) { - (true, true, true) => Some("gzip, br, deflate"), - (true, true, false) => Some("gzip, br"), - (true, false, true) => Some("gzip, deflate"), - (false, true, true) => Some("br, deflate"), - (true, false, false) => Some("gzip"), - (false, true, false) => Some("br"), - (false, false, true) => Some("deflate"), - (false, false, false) => None, + match ( + self.is_gzip(), + self.is_brotli(), + self.is_zstd(), + self.is_deflate(), + ) { + (true, true, true, true) => Some("gzip, br, zstd, deflate"), + (true, true, false, true) => Some("gzip, br, deflate"), + (true, true, true, false) => Some("gzip, br, zstd"), + (true, true, false, false) => Some("gzip, br"), + (true, false, true, true) => Some("gzip, zstd, deflate"), + (true, false, false, true) => Some("gzip, zstd, deflate"), + (false, true, true, true) => Some("br, zstd, deflate"), + (false, true, false, true) => Some("br, zstd, deflate"), + (true, false, true, false) => Some("gzip, zstd"), + (true, false, false, false) => Some("gzip"), + (false, true, true, false) => Some("br, zstd"), + (false, true, false, false) => Some("br"), + (false, false, true, true) => Some("zstd, deflate"), + (false, false, true, false) => Some("zstd"), + (false, false, false, true) => Some("deflate"), + (false, false, false, false) => None, } } @@ -424,6 +522,18 @@ impl Accepts { } } + fn is_zstd(&self) -> bool { + #[cfg(feature = "zstd")] + { + self.zstd + } + + #[cfg(not(feature = "zstd"))] + { + false + } + } + fn is_deflate(&self) -> bool { #[cfg(feature = "deflate")] { @@ -444,6 +554,8 @@ impl Default for Accepts { gzip: true, #[cfg(feature = "brotli")] brotli: true, + #[cfg(feature = "zstd")] + zstd: true, #[cfg(feature = "deflate")] deflate: true, } diff --git a/src/blocking/client.rs b/src/blocking/client.rs index 5b861cb3e..6dc19e6b9 100644 --- a/src/blocking/client.rs +++ b/src/blocking/client.rs @@ -260,6 +260,28 @@ impl ClientBuilder { self.with_inner(|inner| inner.brotli(enable)) } + /// Enable auto zstd decompression by checking the `Content-Encoding` response header. + /// + /// If auto zstd decompression is turned on: + /// + /// - When sending a request and if the request's headers do not already contain + /// an `Accept-Encoding` **and** `Range` values, the `Accept-Encoding` header is set to `zstd`. + /// The request body is **not** automatically compressed. + /// - When receiving a response, if its headers contain a `Content-Encoding` value of + /// `zstd`, both `Content-Encoding` and `Content-Length` are removed from the + /// headers' set. The response body is automatically decompressed. + /// + /// If the `zstd` feature is turned on, the default option is enabled. + /// + /// # Optional + /// + /// This requires the optional `zstd` feature to be enabled + #[cfg(feature = "zstd")] + #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))] + pub fn zstd(self, enable: bool) -> ClientBuilder { + self.with_inner(|inner| inner.zstd(enable)) + } + /// Enable auto deflate decompression by checking the `Content-Encoding` response header. /// /// If auto deflate decompresson is turned on: @@ -300,6 +322,15 @@ impl ClientBuilder { self.with_inner(|inner| inner.no_brotli()) } + /// Disable auto response body zstd decompression. + /// + /// This method exists even if the optional `zstd` feature is not enabled. + /// This can be used to ensure a `Client` doesn't use zstd decompression + /// even if another dependency were to enable the optional `zstd` feature. + pub fn no_zstd(self) -> ClientBuilder { + self.with_inner(|inner| inner.no_zstd()) + } + /// Disable auto response body deflate decompression. /// /// This method exists even if the optional `deflate` feature is not enabled. diff --git a/src/lib.rs b/src/lib.rs index ce4549dd9..d62cb8210 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -197,6 +197,7 @@ //! - **cookies**: Provides cookie session support. //! - **gzip**: Provides response body gzip decompression. //! - **brotli**: Provides response body brotli decompression. +//! - **zstd**: Provides response body zstd decompression. //! - **deflate**: Provides response body deflate decompression. //! - **json**: Provides serialization and deserialization for JSON bodies. //! - **multipart**: Provides functionality for multipart forms. diff --git a/tests/client.rs b/tests/client.rs index 02f32fc85..ab13fc9a5 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -29,6 +29,12 @@ async fn auto_headers() { .unwrap() .contains("br")); } + if cfg!(feature = "zstd") { + assert!(req.headers()["accept-encoding"] + .to_str() + .unwrap() + .contains("zstd")); + } if cfg!(feature = "deflate") { assert!(req.headers()["accept-encoding"] .to_str() diff --git a/tests/zstd.rs b/tests/zstd.rs new file mode 100644 index 000000000..d1886ee49 --- /dev/null +++ b/tests/zstd.rs @@ -0,0 +1,144 @@ +mod support; +use support::server; + +#[tokio::test] +async fn zstd_response() { + zstd_case(10_000, 4096).await; +} + +#[tokio::test] +async fn zstd_single_byte_chunks() { + zstd_case(10, 1).await; +} + +#[tokio::test] +async fn test_zstd_empty_body() { + let server = server::http(move |req| async move { + assert_eq!(req.method(), "HEAD"); + + http::Response::builder() + .header("content-encoding", "zstd") + .body(Default::default()) + .unwrap() + }); + + let client = reqwest::Client::new(); + let res = client + .head(&format!("http://{}/zstd", server.addr())) + .send() + .await + .unwrap(); + + let body = res.text().await.unwrap(); + + assert_eq!(body, ""); +} + +#[tokio::test] +async fn test_accept_header_is_not_changed_if_set() { + let server = server::http(move |req| async move { + assert_eq!(req.headers()["accept"], "application/json"); + assert!(req.headers()["accept-encoding"] + .to_str() + .unwrap() + .contains("zstd")); + http::Response::default() + }); + + let client = reqwest::Client::new(); + + let res = client + .get(&format!("http://{}/accept", server.addr())) + .header( + reqwest::header::ACCEPT, + reqwest::header::HeaderValue::from_static("application/json"), + ) + .send() + .await + .unwrap(); + + assert_eq!(res.status(), reqwest::StatusCode::OK); +} + +#[tokio::test] +async fn test_accept_encoding_header_is_not_changed_if_set() { + let server = server::http(move |req| async move { + assert_eq!(req.headers()["accept"], "*/*"); + assert_eq!(req.headers()["accept-encoding"], "identity"); + http::Response::default() + }); + + let client = reqwest::Client::new(); + + let res = client + .get(&format!("http://{}/accept-encoding", server.addr())) + .header( + reqwest::header::ACCEPT_ENCODING, + reqwest::header::HeaderValue::from_static("identity"), + ) + .send() + .await + .unwrap(); + + assert_eq!(res.status(), reqwest::StatusCode::OK); +} + +async fn zstd_case(response_size: usize, chunk_size: usize) { + use futures_util::stream::StreamExt; + + let content: String = (0..response_size) + .into_iter() + .map(|i| format!("test {i}")) + .collect(); + + let zstded_content = zstd_crate::encode_all(content.as_bytes(), 3).unwrap(); + + let mut response = format!( + "\ + HTTP/1.1 200 OK\r\n\ + Server: test-accept\r\n\ + Content-Encoding: zstd\r\n\ + Content-Length: {}\r\n\ + \r\n", + &zstded_content.len() + ) + .into_bytes(); + response.extend(&zstded_content); + + let server = server::http(move |req| { + assert!(req.headers()["accept-encoding"] + .to_str() + .unwrap() + .contains("zstd")); + + let zstded = zstded_content.clone(); + async move { + let len = zstded.len(); + let stream = + futures_util::stream::unfold((zstded, 0), move |(zstded, pos)| async move { + let chunk = zstded.chunks(chunk_size).nth(pos)?.to_vec(); + + Some((chunk, (zstded, pos + 1))) + }); + + let body = reqwest::Body::wrap_stream(stream.map(Ok::<_, std::convert::Infallible>)); + + http::Response::builder() + .header("content-encoding", "zstd") + .header("content-length", len) + .body(body) + .unwrap() + } + }); + + let client = reqwest::Client::new(); + + let res = client + .get(&format!("http://{}/zstd", server.addr())) + .send() + .await + .expect("response"); + + let body = res.text().await.expect("text"); + assert_eq!(body, content); +}