Skip to content

Commit

Permalink
download: clean up TLS feature guards
Browse files Browse the repository at this point in the history
  • Loading branch information
djc committed Dec 20, 2024
1 parent d3d2e96 commit 43d46a8
Showing 1 changed file with 29 additions and 27 deletions.
56 changes: 29 additions & 27 deletions download/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,17 +160,39 @@ impl Backend {
#[cfg(feature = "curl-backend")]
Self::Curl => curl::download(url, resume_from, callback),
#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))]
Self::Reqwest(tls) => reqwest_be::download(url, resume_from, callback, tls).await,
Self::Reqwest(tls) => tls.download(url, resume_from, callback).await,
}
}
}

#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))]
#[derive(Debug, Copy, Clone)]
pub enum TlsBackend {
#[cfg(feature = "reqwest-rustls-tls")]
Rustls,
#[cfg(feature = "reqwest-native-tls")]
NativeTls,
}

#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))]
impl TlsBackend {
async fn download(
self,
url: &Url,
resume_from: u64,
callback: DownloadCallback<'_>,
) -> Result<()> {
let client = match self {
#[cfg(feature = "reqwest-rustls-tls")]
Self::Rustls => &reqwest_be::CLIENT_RUSTLS_TLS,
#[cfg(feature = "reqwest-native-tls")]
Self::NativeTls => &reqwest_be::CLIENT_NATIVE_TLS,
};

reqwest_be::download(url, resume_from, callback, client).await
}
}

#[derive(Debug, Copy, Clone)]
pub enum Event<'a> {
ResumingPartialDownload,
Expand Down Expand Up @@ -298,12 +320,6 @@ pub mod curl {

#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))]
pub mod reqwest_be {
#[cfg(all(
not(feature = "reqwest-rustls-tls"),
not(feature = "reqwest-native-tls")
))]
compile_error!("Must select a reqwest TLS backend");

use std::io;
#[cfg(feature = "reqwest-rustls-tls")]
use std::sync::Arc;
Expand All @@ -320,20 +336,20 @@ pub mod reqwest_be {
use tokio_stream::StreamExt;
use url::Url;

use super::{DownloadError, Event, TlsBackend};
use super::{DownloadError, Event};

pub async fn download(
url: &Url,
resume_from: u64,
callback: &dyn Fn(Event<'_>) -> Result<()>,
tls: TlsBackend,
client: &Client,
) -> Result<()> {
// Short-circuit reqwest for the "file:" URL scheme
if download_from_file_url(url, resume_from, callback)? {
return Ok(());
}

let res = request(url, resume_from, tls)
let res = request(url, resume_from, client)
.await
.context("failed to make network request")?;

Expand Down Expand Up @@ -367,7 +383,7 @@ pub mod reqwest_be {
}

#[cfg(feature = "reqwest-rustls-tls")]
static CLIENT_RUSTLS_TLS: LazyLock<Client> = LazyLock::new(|| {
pub(super) static CLIENT_RUSTLS_TLS: LazyLock<Client> = LazyLock::new(|| {
let catcher = || {
client_generic()
.use_preconfigured_tls(
Expand All @@ -393,7 +409,7 @@ pub mod reqwest_be {
});

#[cfg(feature = "reqwest-native-tls")]
static CLIENT_DEFAULT_TLS: LazyLock<Client> = LazyLock::new(|| {
pub(super) static CLIENT_NATIVE_TLS: LazyLock<Client> = LazyLock::new(|| {
let catcher = || {
client_generic()
.user_agent(super::REQWEST_DEFAULT_TLS_USER_AGENT)
Expand All @@ -416,22 +432,8 @@ pub mod reqwest_be {
async fn request(
url: &Url,
resume_from: u64,
backend: TlsBackend,
client: &Client,
) -> Result<Response, DownloadError> {
let client: &Client = match backend {
#[cfg(feature = "reqwest-rustls-tls")]
TlsBackend::Rustls => &CLIENT_RUSTLS_TLS,
#[cfg(not(feature = "reqwest-rustls-tls"))]
TlsBackend::Rustls => {
return Err(DownloadError::BackendUnavailable("reqwest rustls"));
}
#[cfg(feature = "reqwest-native-tls")]
TlsBackend::NativeTls => &CLIENT_DEFAULT_TLS,
#[cfg(not(feature = "reqwest-native-tls"))]
TlsBackend::NativeTls => {
return Err(DownloadError::BackendUnavailable("reqwest default TLS"));
}
};
let mut req = client.get(url.as_str());

if resume_from != 0 {
Expand Down

0 comments on commit 43d46a8

Please sign in to comment.