diff --git a/Cargo.toml b/Cargo.toml index 2bfe7b3..74cb743 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,24 +10,24 @@ readme = "README.md" exclude = ["test/*"] [features] -default = ["runtime"] +# FIXME should not be turning on http1 by default: https://github.com/hyperium/hyper/issues/2376 +default = ["hyper/http1", "tcp"] -runtime = ["hyper/runtime"] +tcp = ["hyper/tcp"] [dependencies] -antidote = "1.0.0" -bytes = "0.5" http = "0.2" -hyper = { version = "0.13", default-features = false } +hyper = { version = "0.14", default-features = false, features = ["client"] } linked_hash_set = "0.1" once_cell = "1.0" -openssl = "0.10.19" +openssl = "0.10.32" openssl-sys = "0.9.26" -tokio = "0.2" -tokio-openssl = "0.4" +parking_lot = "0.11" +tokio = "1.0" +tokio-openssl = "0.6" tower-layer = "0.3" [dev-dependencies] -hyper = "0.13" -tokio = { version = "0.2", features = ["full"] } +hyper = { version = "0.14", features = ["full"] } +tokio = { version = "1.0", features = ["full"] } futures = "0.3" diff --git a/src/lib.rs b/src/lib.rs index d8c708c..f2e026c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,11 +3,9 @@ #![doc(html_root_url = "https://docs.rs/hyper-openssl/0.8")] use crate::cache::{SessionCache, SessionKey}; -use antidote::Mutex; -use bytes::{Buf, BufMut}; use http::uri::Scheme; use hyper::client::connect::{Connected, Connection}; -#[cfg(feature = "runtime")] +#[cfg(feature = "tcp")] use hyper::client::HttpConnector; use hyper::service::Service; use hyper::Uri; @@ -15,17 +13,19 @@ use once_cell::sync::OnceCell; use openssl::error::ErrorStack; use openssl::ex_data::Index; use openssl::ssl::{ - ConnectConfiguration, Ssl, SslConnector, SslConnectorBuilder, SslMethod, SslSessionCacheMode, + self, ConnectConfiguration, Ssl, SslConnector, SslConnectorBuilder, SslMethod, + SslSessionCacheMode, }; +use openssl::x509::X509VerifyResult; +use parking_lot::Mutex; use std::error::Error; -use std::fmt::Debug; +use std::fmt; use std::future::Future; use std::io; -use std::mem::MaybeUninit; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio_openssl::SslStream; use tower_layer::Layer; @@ -151,14 +151,14 @@ pub struct HttpsConnector { inner: Inner, } -#[cfg(feature = "runtime")] +#[cfg(feature = "tcp")] impl HttpsConnector { /// Creates a a new `HttpsConnector` using default settings. /// /// The Hyper `HttpConnector` is used to perform the TCP socket connection. ALPN is configured to support both /// HTTP/2 and HTTP/1.1. /// - /// Requires the `runtime` Cargo feature. + /// Requires the `tcp` Cargo feature. pub fn new() -> Result, ErrorStack> { let mut http = HttpConnector::new(); http.enforce_http(false); @@ -171,8 +171,8 @@ impl HttpsConnector where S: Service + Send, S::Error: Into>, - S::Future: Unpin + Send + 'static, - T: AsyncRead + AsyncWrite + Connection + Unpin + Debug + Sync + Send + 'static, + S::Future: Send + 'static, + T: AsyncRead + AsyncWrite + Connection + Unpin, { /// Creates a new `HttpsConnector`. /// @@ -197,8 +197,8 @@ impl Service for HttpsConnector where S: Service + Send, S::Error: Into>, - S::Future: Unpin + Send + 'static, - S::Response: AsyncRead + AsyncWrite + Connection + Unpin + Debug + Sync + Send + 'static, + S::Future: Send + 'static, + S::Response: AsyncRead + AsyncWrite + Connection + Send + Unpin, { type Response = MaybeHttpsStream; type Error = Box; @@ -228,15 +228,48 @@ where let host = uri.host().ok_or_else(|| "URI missing host")?; let config = inner.setup_ssl(&uri, host)?; - let stream = tokio_openssl::connect(config, host, conn).await?; + let ssl = config.into_ssl(host)?; - Ok(MaybeHttpsStream::Https(stream)) + let mut stream = SslStream::new(ssl, conn)?; + + match Pin::new(&mut stream).connect().await { + Ok(()) => Ok(MaybeHttpsStream::Https(stream)), + Err(error) => Err(Box::new(ConnectError { + error, + verify_result: stream.ssl().verify_result(), + }) as _), + } }; Box::pin(f) } } +#[derive(Debug)] +struct ConnectError { + error: ssl::Error, + verify_result: X509VerifyResult, +} + +impl fmt::Display for ConnectError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.error, fmt)?; + + if self.verify_result != X509VerifyResult::OK { + fmt.write_str(": ")?; + fmt::Display::fmt(&self.verify_result, fmt)?; + } + + Ok(()) + } +} + +impl Error for ConnectError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + Some(&self.error) + } +} + /// A stream which may be wrapped with TLS. pub enum MaybeHttpsStream { /// A raw HTTP stream. @@ -249,35 +282,14 @@ impl AsyncRead for MaybeHttpsStream where T: AsyncRead + AsyncWrite + Unpin, { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit]) -> bool { - match &*self { - MaybeHttpsStream::Http(s) => s.prepare_uninitialized_buffer(buf), - MaybeHttpsStream::Https(s) => s.prepare_uninitialized_buffer(buf), - } - } - fn poll_read( mut self: Pin<&mut Self>, - ctx: &mut Context<'_>, - buf: &mut [u8], - ) -> Poll> { - match &mut *self { - MaybeHttpsStream::Http(s) => Pin::new(s).poll_read(ctx, buf), - MaybeHttpsStream::Https(s) => Pin::new(s).poll_read(ctx, buf), - } - } - - fn poll_read_buf( - mut self: Pin<&mut Self>, - ctx: &mut Context<'_>, - buf: &mut B, - ) -> Poll> - where - B: BufMut, - { + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { match &mut *self { - MaybeHttpsStream::Http(s) => Pin::new(s).poll_read_buf(ctx, buf), - MaybeHttpsStream::Https(s) => Pin::new(s).poll_read_buf(ctx, buf), + MaybeHttpsStream::Http(s) => Pin::new(s).poll_read(cx, buf), + MaybeHttpsStream::Https(s) => Pin::new(s).poll_read(cx, buf), } } } @@ -310,20 +322,6 @@ where MaybeHttpsStream::Https(s) => Pin::new(s).poll_shutdown(ctx), } } - - fn poll_write_buf( - mut self: Pin<&mut Self>, - ctx: &mut Context<'_>, - buf: &mut B, - ) -> Poll> - where - B: Buf, - { - match &mut *self { - MaybeHttpsStream::Http(s) => Pin::new(s).poll_write_buf(ctx, buf), - MaybeHttpsStream::Https(s) => Pin::new(s).poll_write_buf(ctx, buf), - } - } } impl Connection for MaybeHttpsStream diff --git a/src/test.rs b/src/test.rs index 99c47bf..349f686 100644 --- a/src/test.rs +++ b/src/test.rs @@ -8,7 +8,7 @@ use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod}; use tokio::net::TcpListener; #[tokio::test] -#[cfg(feature = "runtime")] +#[cfg(feature = "tcp")] async fn google() { let ssl = HttpsConnector::new().unwrap(); let client = Client::builder() @@ -28,7 +28,7 @@ async fn google() { #[tokio::test] async fn localhost() { - let mut listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let port = listener.local_addr().unwrap().port(); let server = async move { @@ -44,7 +44,10 @@ async fn localhost() { for _ in 0..3 { let stream = listener.accept().await.unwrap().0; - let stream = tokio_openssl::accept(&acceptor, stream).await.unwrap(); + let ssl = Ssl::new(acceptor.context()).unwrap(); + let mut stream = SslStream::new(ssl, stream).unwrap(); + + Pin::new(&mut stream).accept().await.unwrap(); let service = service::service_fn(|_| async { Ok::<_, io::Error>(Response::new(Body::empty())) }); @@ -93,7 +96,7 @@ async fn localhost() { async fn alpn_h2() { use openssl::ssl::{self, AlpnError}; - let mut listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let port = listener.local_addr().unwrap().port(); let server = async move { @@ -110,7 +113,10 @@ async fn alpn_h2() { let acceptor = acceptor.build(); let stream = listener.accept().await.unwrap().0; - let stream = tokio_openssl::accept(&acceptor, stream).await.unwrap(); + let ssl = Ssl::new(acceptor.context()).unwrap(); + let mut stream = SslStream::new(ssl, stream).unwrap(); + + Pin::new(&mut stream).accept().await.unwrap(); assert_eq!(stream.ssl().selected_alpn_protocol().unwrap(), b"h2"); let service =