diff --git a/README.md b/README.md index 6562952..ede1a0f 100644 --- a/README.md +++ b/README.md @@ -87,13 +87,7 @@ async fn main() { Some(Cache::new(128)), ); - let client = DefaultClient::new( - tokio_native_tls::native_tls::TlsConnector::builder() - // You must set ALPN if you want to support HTTP/2 - .request_alpns(&["h2", "http/1.1"]) - .build() - .unwrap(), - ); + let client = DefaultClient::new().unwrap(); let server = proxy .bind(("127.0.0.1", 3003), move |_client_addr, req| { let client = client.clone(); @@ -111,8 +105,6 @@ async fn main() { // Modifying upgraded traffic is not supported yet. // You can try https://echo.websocket.org/.ws to test websocket. - // But you need to disable alpn of DefaultClient to disable HTTP2 because echo.websocket.org does not support HTTP/2 for Websocket. - // It should be match incoming and outgoing HTTP version on DefaultClient, I'll fix this later. #54 println!("Upgrade connection"); let Upgrade { mut client_to_server, diff --git a/examples/dev_proxy.rs b/examples/dev_proxy.rs index 15cb149..c0e3494 100644 --- a/examples/dev_proxy.rs +++ b/examples/dev_proxy.rs @@ -80,13 +80,7 @@ async fn main() { Some(Cache::new(128)), ); - let client = DefaultClient::new( - tokio_native_tls::native_tls::TlsConnector::builder() - // You must set ALPN if you want to support HTTP/2 - .request_alpns(&["h2", "http/1.1"]) - .build() - .unwrap(), - ); + let client = DefaultClient::new().unwrap(); let proxy = proxy .bind(("127.0.0.1", 3003), move |_client_addr, mut req| { let client = client.clone(); diff --git a/examples/proxy.rs b/examples/proxy.rs index 35752bb..ca9fc01 100644 --- a/examples/proxy.rs +++ b/examples/proxy.rs @@ -74,13 +74,7 @@ async fn main() { Some(Cache::new(128)), ); - let client = DefaultClient::new( - tokio_native_tls::native_tls::TlsConnector::builder() - // You must set ALPN if you want to support HTTP/2 - .request_alpns(&["h2", "http/1.1"]) - .build() - .unwrap(), - ); + let client = DefaultClient::new().unwrap(); let server = proxy .bind(("127.0.0.1", 3003), move |_client_addr, req| { let client = client.clone(); @@ -98,8 +92,6 @@ async fn main() { // Modifying upgraded traffic is not supported yet. // You can try https://echo.websocket.org/.ws to test websocket. - // But you need to disable alpn of DefaultClient to disable HTTP2 because echo.websocket.org does not support HTTP/2 for Websocket. - // It should be match incoming and outgoing HTTP version on DefaultClient, I'll fix this later. #54 println!("Upgrade connection"); let Upgrade { mut client_to_server, diff --git a/src/default_client.rs b/src/default_client.rs index 877c31f..8f90878 100644 --- a/src/default_client.rs +++ b/src/default_client.rs @@ -3,7 +3,7 @@ use futures::channel::mpsc::{UnboundedReceiver, UnboundedSender}; use http_body_util::Empty; use hyper::{ body::{Body, Incoming}, - client, header, Request, Response, StatusCode, Uri, + client, header, Request, Response, StatusCode, Uri, Version, }; use hyper_util::rt::{TokioExecutor, TokioIo}; use std::task::{Context, Poll}; @@ -37,10 +37,28 @@ pub struct Upgrade { } #[derive(Clone)] /// Default HTTP client for this crate -pub struct DefaultClient(tokio_native_tls::TlsConnector); +pub struct DefaultClient { + tls_connector_no_alpn: tokio_native_tls::TlsConnector, + tls_connector_alpn_h2: tokio_native_tls::TlsConnector, +} impl DefaultClient { - pub fn new(tls_connector: native_tls::TlsConnector) -> Self { - Self(tls_connector.into()) + pub fn new() -> native_tls::Result { + let tls_connector_no_alpn = native_tls::TlsConnector::builder().build()?; + let tls_connector_alpn_h2 = native_tls::TlsConnector::builder() + .request_alpns(&["h2", "http/1.1"]) + .build()?; + + Ok(Self { + tls_connector_no_alpn: tokio_native_tls::TlsConnector::from(tls_connector_no_alpn), + tls_connector_alpn_h2: tokio_native_tls::TlsConnector::from(tls_connector_alpn_h2), + }) + } + + fn tls_connector(&self, http_version: Version) -> &tokio_native_tls::TlsConnector { + match http_version { + Version::HTTP_2 => &self.tls_connector_alpn_h2, + _ => &self.tls_connector_no_alpn, + } } /// Send a request and return a response. @@ -55,7 +73,7 @@ impl DefaultClient { B::Data: Send, B::Error: Into>, { - let mut send_request = self.connect(req.uri()).await?; + let mut send_request = self.connect(req.uri(), req.version()).await?; let (req_parts, req_body) = req.into_parts(); @@ -99,7 +117,7 @@ impl DefaultClient { Ok((res, None)) } - async fn connect(&self, uri: &Uri) -> Result, Error> + async fn connect(&self, uri: &Uri, http_version: Version) -> Result, Error> where B: Body + Unpin + Send + 'static, B::Data: Send, @@ -120,7 +138,7 @@ impl DefaultClient { if uri.scheme() == Some(&hyper::http::uri::Scheme::HTTPS) { let tls = self - .0 + .tls_connector(http_version) .connect(host, tcp) .await .map_err(|err| Error::TlsConnectError(uri.clone(), err))?; diff --git a/tests/test.rs b/tests/test.rs index 41424a8..ab232ce 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -73,12 +73,7 @@ fn client_tls(proxy_port: u16, cert: &rcgen::CertifiedKey) -> reqwest::Client { } fn proxy_client() -> DefaultClient { - DefaultClient::new( - tokio_native_tls::native_tls::TlsConnector::builder() - .request_alpns(&["h2", "http/1.1"]) - .build() - .unwrap(), - ) + DefaultClient::new().unwrap() } async fn setup(app: Router, service: S) -> (u16, u16)