From 43d46a893d776522e2bba2adf5913ce835f7bbc0 Mon Sep 17 00:00:00 2001 From: Dirkjan Ochtman Date: Fri, 20 Dec 2024 15:25:43 +0100 Subject: [PATCH] download: clean up TLS feature guards --- download/src/lib.rs | 56 +++++++++++++++++++++++---------------------- 1 file changed, 29 insertions(+), 27 deletions(-) diff --git a/download/src/lib.rs b/download/src/lib.rs index aa0899d5f6..73faeaa6d9 100644 --- a/download/src/lib.rs +++ b/download/src/lib.rs @@ -160,17 +160,39 @@ impl Backend { #[cfg(feature = "curl-backend")] Self::Curl => curl::download(url, resume_from, callback), #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] - Self::Reqwest(tls) => reqwest_be::download(url, resume_from, callback, tls).await, + Self::Reqwest(tls) => tls.download(url, resume_from, callback).await, } } } +#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] #[derive(Debug, Copy, Clone)] pub enum TlsBackend { + #[cfg(feature = "reqwest-rustls-tls")] Rustls, + #[cfg(feature = "reqwest-native-tls")] NativeTls, } +#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] +impl TlsBackend { + async fn download( + self, + url: &Url, + resume_from: u64, + callback: DownloadCallback<'_>, + ) -> Result<()> { + let client = match self { + #[cfg(feature = "reqwest-rustls-tls")] + Self::Rustls => &reqwest_be::CLIENT_RUSTLS_TLS, + #[cfg(feature = "reqwest-native-tls")] + Self::NativeTls => &reqwest_be::CLIENT_NATIVE_TLS, + }; + + reqwest_be::download(url, resume_from, callback, client).await + } +} + #[derive(Debug, Copy, Clone)] pub enum Event<'a> { ResumingPartialDownload, @@ -298,12 +320,6 @@ pub mod curl { #[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))] pub mod reqwest_be { - #[cfg(all( - not(feature = "reqwest-rustls-tls"), - not(feature = "reqwest-native-tls") - ))] - compile_error!("Must select a reqwest TLS backend"); - use std::io; #[cfg(feature = "reqwest-rustls-tls")] use std::sync::Arc; @@ -320,20 +336,20 @@ pub mod reqwest_be { use tokio_stream::StreamExt; use url::Url; - use super::{DownloadError, Event, TlsBackend}; + use super::{DownloadError, Event}; pub async fn download( url: &Url, resume_from: u64, callback: &dyn Fn(Event<'_>) -> Result<()>, - tls: TlsBackend, + client: &Client, ) -> Result<()> { // Short-circuit reqwest for the "file:" URL scheme if download_from_file_url(url, resume_from, callback)? { return Ok(()); } - let res = request(url, resume_from, tls) + let res = request(url, resume_from, client) .await .context("failed to make network request")?; @@ -367,7 +383,7 @@ pub mod reqwest_be { } #[cfg(feature = "reqwest-rustls-tls")] - static CLIENT_RUSTLS_TLS: LazyLock = LazyLock::new(|| { + pub(super) static CLIENT_RUSTLS_TLS: LazyLock = LazyLock::new(|| { let catcher = || { client_generic() .use_preconfigured_tls( @@ -393,7 +409,7 @@ pub mod reqwest_be { }); #[cfg(feature = "reqwest-native-tls")] - static CLIENT_DEFAULT_TLS: LazyLock = LazyLock::new(|| { + pub(super) static CLIENT_NATIVE_TLS: LazyLock = LazyLock::new(|| { let catcher = || { client_generic() .user_agent(super::REQWEST_DEFAULT_TLS_USER_AGENT) @@ -416,22 +432,8 @@ pub mod reqwest_be { async fn request( url: &Url, resume_from: u64, - backend: TlsBackend, + client: &Client, ) -> Result { - let client: &Client = match backend { - #[cfg(feature = "reqwest-rustls-tls")] - TlsBackend::Rustls => &CLIENT_RUSTLS_TLS, - #[cfg(not(feature = "reqwest-rustls-tls"))] - TlsBackend::Rustls => { - return Err(DownloadError::BackendUnavailable("reqwest rustls")); - } - #[cfg(feature = "reqwest-native-tls")] - TlsBackend::NativeTls => &CLIENT_DEFAULT_TLS, - #[cfg(not(feature = "reqwest-native-tls"))] - TlsBackend::NativeTls => { - return Err(DownloadError::BackendUnavailable("reqwest default TLS")); - } - }; let mut req = client.get(url.as_str()); if resume_from != 0 {