diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ac2a460..eaf4d5a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -54,7 +54,7 @@ jobs: - uses: actions/checkout@v2 - uses: sfackler/actions/rustup@master with: - version: 1.58.1 + version: 1.68.0 - run: echo "::set-output name=version::$(rustc --version)" id: rust-version - uses: actions/cache@v1 @@ -74,3 +74,4 @@ jobs: path: target key: test-target-${{ runner.os }}-${{ steps.rust-version.outputs.version }}-${{ hashFiles('Cargo.lock') }}y - run: cargo test --all + - run: cargo test --all --all-features diff --git a/Cargo.toml b/Cargo.toml index d76e175..1e73a27 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,25 +7,38 @@ description = "Hyper TLS support via OpenSSL" license = "MIT/Apache-2.0" repository = "https://github.com/sfackler/hyper-openssl" readme = "README.md" -exclude = ["test/*"] +rust-version = "1.68" [features] -default = ["tcp"] -tcp = ["hyper/tcp"] +default = [] + +client-legacy = [ + "dep:http", + "dep:hyper-util", + "dep:linked_hash_set", + "dep:once_cell", + "dep:parking_lot", + "dep:pin-project", + "dep:tower-layer", + "dep:tower-service", + "hyper-util?/client-legacy", +] + [dependencies] -http = "0.2" -hyper = { version = "0.14.2", default-features = false, features = ["client"] } -linked_hash_set = "0.1" -once_cell = "1.0" +http = { version = "1.0.0", optional = true } +hyper = "1.0.1" +hyper-util = { version = "0.1", optional = true } +linked_hash_set = { version = "0.1", optional = true } +once_cell = { version = "1", optional = true } openssl = "0.10.32" openssl-sys = "0.9.26" -parking_lot = "0.12.0" -tokio = "1.0" -tokio-openssl = "0.6" -tower-layer = "0.3" +parking_lot = { version = "0.12", optional = true } +pin-project = { version = "1.1.3", optional = true } +tower-layer = { version = "0.3", optional = true } +tower-service = { version = "0.3", optional = true } [dev-dependencies] -hyper = { version = "0.14", features = ["full"] } -tokio = { version = "1.0", features = ["full"] } -futures = "0.3" +hyper = { version = "1", features = ["full"] } +hyper-util = { version = "0.1", features = ["full"] } +tokio = { version = "1", features = ["full"] } diff --git a/src/cache.rs b/src/client/cache.rs similarity index 89% rename from src/cache.rs rename to src/client/cache.rs index 6e2d348..fd0ceed 100644 --- a/src/cache.rs +++ b/src/client/cache.rs @@ -3,7 +3,8 @@ use linked_hash_set::LinkedHashSet; use openssl::ssl::SslVersion; use openssl::ssl::{SslSession, SslSessionRef}; use std::borrow::Borrow; -use std::collections::hash_map::{Entry, HashMap}; +use std::collections::hash_map::Entry; +use std::collections::HashMap; use std::hash::{Hash, Hasher}; #[derive(Hash, PartialEq, Eq, Clone)] @@ -28,7 +29,7 @@ impl Hash for HashSession { where H: Hasher, { - self.0.id().hash(state); + self.0.id().hash(state) } } @@ -56,16 +57,14 @@ impl SessionCache { self.sessions .entry(key.clone()) - .or_insert_with(LinkedHashSet::new) + .or_default() .insert(session.clone()); self.reverse.insert(session, key); } pub fn get(&mut self, key: &SessionKey) -> Option { - let session = { - let sessions = self.sessions.get_mut(key)?; - sessions.front().cloned()?.0 - }; + let sessions = self.sessions.get_mut(key)?; + let session = sessions.front().cloned()?.0; #[cfg(ossl111)] { diff --git a/src/client/legacy.rs b/src/client/legacy.rs new file mode 100644 index 0000000..96a114d --- /dev/null +++ b/src/client/legacy.rs @@ -0,0 +1,315 @@ +//! hyper-util legacy client support. +use crate::client::cache::{SessionCache, SessionKey}; +use crate::SslStream; +use http::uri::Scheme; +use hyper::rt::{Read, ReadBufCursor, Write}; +use hyper::Uri; +use hyper_util::client::legacy::connect::{Connected, Connection, HttpConnector}; +use once_cell::sync::OnceCell; +use openssl::error::ErrorStack; +use openssl::ex_data::Index; +use openssl::ssl::{ + self, ConnectConfiguration, Ssl, SslConnector, SslConnectorBuilder, SslMethod, + SslSessionCacheMode, +}; +use openssl::x509::X509VerifyResult; +use parking_lot::Mutex; +use pin_project::pin_project; +use std::error::Error; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::{fmt, io}; +use tower_layer::Layer; +use tower_service::Service; + +type ConfigureCallback = + dyn Fn(&mut ConnectConfiguration, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send; + +fn key_index() -> Result, ErrorStack> { + static IDX: OnceCell> = OnceCell::new(); + IDX.get_or_try_init(Ssl::new_ex_index).copied() +} + +#[derive(Clone)] +struct Inner { + ssl: SslConnector, + cache: Arc>, + callback: Option>, +} + +/// A [`Layer`] which wraps services in an `HttpsConnector`. +pub struct HttpsLayer { + inner: Inner, +} + +impl HttpsLayer { + /// Creates a new `HttpsLayer` with default settings. + /// + /// ALPN is configured to support both HTTP/1.1 and HTTP/2. + pub fn new() -> Result { + let mut ssl = SslConnector::builder(SslMethod::tls())?; + + #[cfg(ossl102)] + ssl.set_alpn_protos(b"\x02h2\x08http/1.1")?; + + Self::with_connector(ssl) + } + + /// Creates a new `HttpsLayer`. + /// + /// The session cache configuration of `ssl` will be overwritten. + pub fn with_connector(mut ssl: SslConnectorBuilder) -> Result { + let cache = Arc::new(Mutex::new(SessionCache::new())); + + ssl.set_session_cache_mode(SslSessionCacheMode::CLIENT); + + ssl.set_new_session_callback({ + let cache = cache.clone(); + move |ssl, session| { + if let Some(key) = key_index().ok().and_then(|idx| ssl.ex_data(idx)) { + cache.lock().insert(key.clone(), session); + } + } + }); + + ssl.set_remove_session_callback({ + let cache = cache.clone(); + move |_, session| cache.lock().remove(session) + }); + + Ok(HttpsLayer { + inner: Inner { + ssl: ssl.build(), + cache, + callback: None, + }, + }) + } + + /// Registers a callback which can customize the configuration of each connection. + pub fn set_callback(&mut self, callback: F) + where + F: Fn(&mut ConnectConfiguration, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send, + { + self.inner.callback = Some(Arc::new(callback)); + } +} + +impl Layer for HttpsLayer { + type Service = HttpsConnector; + + fn layer(&self, inner: S) -> Self::Service { + HttpsConnector { + http: inner, + inner: self.inner.clone(), + } + } +} + +/// A Connector using OpenSSL supporting `http` and `https` schemes. +#[derive(Clone)] +pub struct HttpsConnector { + http: T, + inner: Inner, +} + +impl HttpsConnector { + /// Creates a new `HttpsConnector` using default settings. + /// + /// The Hyper [`HttpConnector`] is used to perform the TCP socket connection. ALPN is configured to support both + /// HTTP/1.1 and HTTP/2. + pub fn new() -> Result { + let mut http = HttpConnector::new(); + http.enforce_http(false); + + HttpsLayer::new().map(|l| l.layer(http)) + } +} + +impl HttpsConnector { + /// Creates a new `HttpsConnector`. + /// + /// The session cache configuration of `ssl` will be overwritten. + pub fn with_connector(http: S, ssl: SslConnectorBuilder) -> Result { + HttpsLayer::with_connector(ssl).map(|l| l.layer(http)) + } + + /// Registers a callback which can customize the configuration of each connection. + pub fn set_callback(&mut self, callback: F) + where + F: Fn(&mut ConnectConfiguration, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send, + { + self.inner.callback = Some(Arc::new(callback)); + } +} + +impl Service for HttpsConnector +where + S: Service, + S::Future: 'static + Send, + S::Error: Into>, + S::Response: Read + Write + Unpin + Connection + Send, +{ + type Response = MaybeHttpsStream; + type Error = Box; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.http.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, req: Uri) -> Self::Future { + let tls_setup = if req.scheme() == Some(&Scheme::HTTPS) { + Some((self.inner.clone(), req.clone())) + } else { + None + }; + + let connect = self.http.call(req); + + Box::pin(async move { + let conn = connect.await.map_err(Into::into)?; + + let Some((inner, uri)) = tls_setup else { + return Ok(MaybeHttpsStream::Http(conn)); + }; + + let Some(host) = uri.host() else { + return Err("URI missing host".into()); + }; + + let mut config = inner.ssl.configure()?; + + if let Some(callback) = &inner.callback { + callback(&mut config, &uri)?; + } + + let key = SessionKey { + host: host.to_string(), + port: uri.port_u16().unwrap_or(443), + }; + + if let Some(session) = inner.cache.lock().get(&key) { + unsafe { + config.set_session(&session)?; + } + } + + let idx = key_index()?; + config.set_ex_data(idx, key); + + let ssl = config.into_ssl(host)?; + + let mut stream = SslStream::new(ssl, conn)?; + + match Pin::new(&mut stream).connect().await { + Ok(()) => Ok(MaybeHttpsStream::Https(stream)), + Err(error) => Err(Box::new(ConnectError { + error, + verify_result: stream.ssl().verify_result(), + }) as _), + } + }) + } +} + +#[derive(Debug)] +struct ConnectError { + error: ssl::Error, + verify_result: X509VerifyResult, +} + +impl fmt::Display for ConnectError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.error, fmt)?; + + if self.verify_result != X509VerifyResult::OK { + fmt.write_str(": ")?; + fmt::Display::fmt(&self.verify_result, fmt)?; + } + + Ok(()) + } +} + +impl Error for ConnectError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + Some(&self.error) + } +} + +/// A stream which may be wrapped with TLS. +#[pin_project(project = MaybeHttpsStreamProj)] +pub enum MaybeHttpsStream { + /// A raw HTTP stream. + Http(#[pin] T), + /// A TLS-wrapped HTTP stream. + Https(#[pin] SslStream), +} + +impl Read for MaybeHttpsStream +where + T: Read + Write, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: ReadBufCursor<'_>, + ) -> Poll> { + match self.project() { + MaybeHttpsStreamProj::Http(s) => s.poll_read(cx, buf), + MaybeHttpsStreamProj::Https(s) => s.poll_read(cx, buf), + } + } +} + +impl Write for MaybeHttpsStream +where + T: Read + Write, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.project() { + MaybeHttpsStreamProj::Http(s) => s.poll_write(cx, buf), + MaybeHttpsStreamProj::Https(s) => s.poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project() { + MaybeHttpsStreamProj::Http(s) => s.poll_flush(cx), + MaybeHttpsStreamProj::Https(s) => s.poll_flush(cx), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project() { + MaybeHttpsStreamProj::Http(s) => s.poll_shutdown(cx), + MaybeHttpsStreamProj::Https(s) => s.poll_shutdown(cx), + } + } +} + +impl Connection for MaybeHttpsStream +where + T: Connection, +{ + fn connected(&self) -> Connected { + match self { + MaybeHttpsStream::Http(s) => s.connected(), + MaybeHttpsStream::Https(s) => { + let mut connected = s.get_ref().connected(); + #[cfg(ossl102)] + if s.ssl().selected_alpn_protocol() == Some(b"h2") { + connected = connected.negotiated_h2(); + } + connected + } + } + } +} diff --git a/src/client/mod.rs b/src/client/mod.rs new file mode 100644 index 0000000..05a74f4 --- /dev/null +++ b/src/client/mod.rs @@ -0,0 +1,4 @@ +//! hyper-util client support. + +mod cache; +pub mod legacy; diff --git a/src/lib.rs b/src/lib.rs index 7d05c59..b116466 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,346 +1,262 @@ -//! Hyper SSL support via OpenSSL. +//! Hyper TLS support via OpenSSL. #![warn(missing_docs)] -#![doc(html_root_url = "https://docs.rs/hyper-openssl/0.8")] -use crate::cache::{SessionCache, SessionKey}; -use http::uri::Scheme; -use hyper::client::connect::{Connected, Connection}; -#[cfg(feature = "tcp")] -use hyper::client::HttpConnector; -use hyper::service::Service; -use hyper::Uri; -use once_cell::sync::OnceCell; +use hyper::rt::{Read, ReadBuf, ReadBufCursor, Write}; use openssl::error::ErrorStack; -use openssl::ex_data::Index; -use openssl::ssl::{ - self, ConnectConfiguration, Ssl, SslConnector, SslConnectorBuilder, SslMethod, - SslSessionCacheMode, -}; -use openssl::x509::X509VerifyResult; -use parking_lot::Mutex; -use std::error::Error; +use openssl::ssl::{self, ErrorCode, Ssl, SslRef}; use std::fmt; -use std::future::Future; -use std::io; +use std::future; +use std::io::{self, Read as _, Write as _}; use std::pin::Pin; -use std::sync::Arc; +use std::ptr; +use std::slice; use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use tokio_openssl::SslStream; -use tower_layer::Layer; -mod cache; +#[cfg(feature = "client-legacy")] +pub mod client; #[cfg(test)] mod test; -fn key_index() -> Result, ErrorStack> { - static IDX: OnceCell> = OnceCell::new(); - IDX.get_or_try_init(Ssl::new_ex_index).map(|v| *v) +struct StreamWrapper { + stream: S, + context: *mut Context<'static>, } -#[derive(Clone)] -struct Inner { - ssl: SslConnector, - cache: Arc>, - #[allow(clippy::type_complexity)] - callback: Option< - Arc Result<(), ErrorStack> + Sync + Send>, - >, -} - -impl Inner { - fn setup_ssl(&self, uri: &Uri, host: &str) -> Result { - let mut conf = self.ssl.configure()?; - - if let Some(ref callback) = self.callback { - callback(&mut conf, uri)?; - } +unsafe impl Sync for StreamWrapper where S: Sync {} +unsafe impl Send for StreamWrapper where S: Send {} - let key = SessionKey { - host: host.to_string(), - port: uri.port_u16().unwrap_or(443), - }; - - if let Some(session) = self.cache.lock().get(&key) { - unsafe { - conf.set_session(&session)?; - } - } - - let idx = key_index()?; - conf.set_ex_data(idx, key); - - Ok(conf) +impl fmt::Debug for StreamWrapper +where + S: fmt::Debug, +{ + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&self.stream, fmt) } } -/// A layer which wraps services in an `HttpsConnector`. -pub struct HttpsLayer { - inner: Inner, -} - -impl HttpsLayer { - /// Creates a new `HttpsLayer` with default settings. +impl StreamWrapper { + /// # Safety /// - /// ALPN is configured to support both HTTP/2 and HTTP/1.1. - pub fn new() -> Result { - let mut ssl = SslConnector::builder(SslMethod::tls())?; - - #[cfg(ossl102)] - ssl.set_alpn_protos(b"\x02h2\x08http/1.1")?; - - Self::with_connector(ssl) - } - - /// Creates a new `HttpsLayer`. - /// - /// The session cache configuration of `ssl` will be overwritten. - pub fn with_connector(mut ssl: SslConnectorBuilder) -> Result { - let cache = Arc::new(Mutex::new(SessionCache::new())); - - ssl.set_session_cache_mode(SslSessionCacheMode::CLIENT); - - ssl.set_new_session_callback({ - let cache = cache.clone(); - move |ssl, session| { - if let Some(key) = key_index().ok().and_then(|idx| ssl.ex_data(idx)) { - cache.lock().insert(key.clone(), session); - } - } - }); - - ssl.set_remove_session_callback({ - let cache = cache.clone(); - move |_, session| cache.lock().remove(session) - }); - - Ok(HttpsLayer { - inner: Inner { - ssl: ssl.build(), - cache, - callback: None, - }, - }) - } - - /// Registers a callback which can customize the configuration of each connection. - pub fn set_callback(&mut self, callback: F) - where - F: Fn(&mut ConnectConfiguration, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send, - { - self.inner.callback = Some(Arc::new(callback)); + /// Must be called with `context` set to a valid pointer to a live `Context` object, and the + /// wrapper must be pinned in memory. + unsafe fn parts(&mut self) -> (Pin<&mut S>, &mut Context<'_>) { + debug_assert_ne!(self.context, ptr::null_mut()); + let stream = Pin::new_unchecked(&mut self.stream); + let context = &mut *self.context.cast(); + (stream, context) } } -impl Layer for HttpsLayer { - type Service = HttpsConnector; - - fn layer(&self, inner: S) -> HttpsConnector { - HttpsConnector { - http: inner, - inner: self.inner.clone(), +impl io::Read for StreamWrapper +where + S: Read, +{ + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let (stream, cx) = unsafe { self.parts() }; + let mut buf = ReadBuf::new(buf); + match stream.poll_read(cx, buf.unfilled())? { + Poll::Ready(()) => Ok(buf.filled().len()), + Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), } } } -/// A Connector using OpenSSL to support `http` and `https` schemes. -#[derive(Clone)] -pub struct HttpsConnector { - http: T, - inner: Inner, -} - -#[cfg(feature = "tcp")] -impl HttpsConnector { - /// Creates a a new `HttpsConnector` using default settings. - /// - /// The Hyper `HttpConnector` is used to perform the TCP socket connection. ALPN is configured to support both - /// HTTP/2 and HTTP/1.1. - /// - /// Requires the `tcp` Cargo feature. - pub fn new() -> Result, ErrorStack> { - let mut http = HttpConnector::new(); - http.enforce_http(false); - - HttpsLayer::new().map(|l| l.layer(http)) - } -} - -impl HttpsConnector +impl io::Write for StreamWrapper where - S: Service + Send, - S::Error: Into>, - S::Future: Send + 'static, - T: AsyncRead + AsyncWrite + Connection + Unpin, + S: Write, { - /// Creates a new `HttpsConnector`. - /// - /// The session cache configuration of `ssl` will be overwritten. - pub fn with_connector( - http: S, - ssl: SslConnectorBuilder, - ) -> Result, ErrorStack> { - HttpsLayer::with_connector(ssl).map(|l| l.layer(http)) + fn write(&mut self, buf: &[u8]) -> io::Result { + let (stream, cx) = unsafe { self.parts() }; + match stream.poll_write(cx, buf) { + Poll::Ready(r) => r, + Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), + } } - /// Registers a callback which can customize the configuration of each connection. - pub fn set_callback(&mut self, callback: F) - where - F: Fn(&mut ConnectConfiguration, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send, - { - self.inner.callback = Some(Arc::new(callback)); + fn flush(&mut self) -> io::Result<()> { + let (stream, cx) = unsafe { self.parts() }; + match stream.poll_flush(cx) { + Poll::Ready(r) => r, + Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), + } } } -impl Service for HttpsConnector +/// A Hyper stream type using OpenSSL. +#[derive(Debug)] +pub struct SslStream(ssl::SslStream>); + +impl SslStream where - S: Service + Send, - S::Error: Into>, - S::Future: Send + 'static, - S::Response: AsyncRead + AsyncWrite + Connection + Send + Unpin, + S: Read + Write, { - type Response = MaybeHttpsStream; - type Error = Box; - #[allow(clippy::type_complexity)] - type Future = Pin> + Send>>; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.http.poll_ready(cx).map_err(Into::into) + /// Like [`SslStream::new`](ssl::SslStream::new). + pub fn new(ssl: Ssl, stream: S) -> Result { + ssl::SslStream::new( + ssl, + StreamWrapper { + stream, + context: ptr::null_mut(), + }, + ) + .map(SslStream) } - fn call(&mut self, uri: Uri) -> Self::Future { - let tls_setup = if uri.scheme() == Some(&Scheme::HTTPS) { - Some((self.inner.clone(), uri.clone())) - } else { - None - }; - - let connect = self.http.call(uri); - - let f = async { - let conn = connect.await.map_err(Into::into)?; - - let (inner, uri) = match tls_setup { - Some((inner, uri)) => (inner, uri), - None => return Ok(MaybeHttpsStream::Http(conn)), - }; - - let host = uri.host().ok_or("URI missing host")?; + /// Like [`SslStream::connect`](ssl::SslStream::connect). + pub fn poll_connect( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.with_context(cx, |s| cvt_ossl(s.connect())) + } - let config = inner.setup_ssl(&uri, host)?; - let ssl = config.into_ssl(host)?; + /// A convenience method wrapping [`poll_connect`](Self::poll_connect). + pub async fn connect(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> { + future::poll_fn(|cx| self.as_mut().poll_connect(cx)).await + } - let mut stream = SslStream::new(ssl, conn)?; + /// Like [`SslStream::accept`](ssl::SslStream::accept). + pub fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.with_context(cx, |s| cvt_ossl(s.accept())) + } - match Pin::new(&mut stream).connect().await { - Ok(()) => Ok(MaybeHttpsStream::Https(stream)), - Err(error) => Err(Box::new(ConnectError { - error, - verify_result: stream.ssl().verify_result(), - }) as _), - } - }; + /// A convenience method wrapping [`poll_accept`](Self::poll_accept). + pub async fn accept(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> { + future::poll_fn(|cx| self.as_mut().poll_accept(cx)).await + } - Box::pin(f) + /// Like [`SslStream::do_handshake`](ssl::SslStream::do_handshake). + pub fn poll_do_handshake( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.with_context(cx, |s| cvt_ossl(s.do_handshake())) } -} -#[derive(Debug)] -struct ConnectError { - error: ssl::Error, - verify_result: X509VerifyResult, + /// A convenience method wrapping [`poll_do_handshake`](Self::poll_do_handshake). + pub async fn do_handshake(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> { + future::poll_fn(|cx| self.as_mut().poll_do_handshake(cx)).await + } } -impl fmt::Display for ConnectError { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Display::fmt(&self.error, fmt)?; +impl SslStream { + /// Returns a shared reference to the `Ssl` object associated with this stream. + pub fn ssl(&self) -> &SslRef { + self.0.ssl() + } - if self.verify_result != X509VerifyResult::OK { - fmt.write_str(": ")?; - fmt::Display::fmt(&self.verify_result, fmt)?; - } + /// Returns a shared reference to the underlying stream. + pub fn get_ref(&self) -> &S { + &self.0.get_ref().stream + } - Ok(()) + /// Returns a mutable reference to the underlying stream. + pub fn get_mut(&mut self) -> &mut S { + &mut self.0.get_mut().stream } -} -impl Error for ConnectError { - fn source(&self) -> Option<&(dyn Error + 'static)> { - Some(&self.error) + /// Returns a pinned mutable reference to the underlying stream. + pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> { + unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().0.get_mut().stream) } } -} -/// A stream which may be wrapped with TLS. -pub enum MaybeHttpsStream { - /// A raw HTTP stream. - Http(T), - /// An SSL-wrapped HTTP stream. - Https(SslStream), + fn with_context(self: Pin<&mut Self>, ctx: &mut Context<'_>, f: F) -> R + where + F: FnOnce(&mut ssl::SslStream>) -> R, + { + unsafe { + let this = self.get_unchecked_mut(); + this.0.get_mut().context = (ctx as *mut Context<'_>).cast(); + let r = f(&mut this.0); + this.0.get_mut().context = ptr::null_mut(); + r + } + } } -impl AsyncRead for MaybeHttpsStream +impl Read for SslStream where - T: AsyncRead + AsyncWrite + Unpin, + S: Read + Write, { fn poll_read( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, + mut buf: ReadBufCursor<'_>, ) -> Poll> { - match &mut *self { - MaybeHttpsStream::Http(s) => Pin::new(s).poll_read(cx, buf), - MaybeHttpsStream::Https(s) => Pin::new(s).poll_read(cx, buf), - } + self.with_context(cx, |s| { + // This isn't really "proper", but rust-openssl doesn't currently expose a suitable + // interface even though OpenSSL itself doesn't require the buffer to be initialized. + // So this is good enough for now. + let slice = unsafe { + let buf = buf.as_mut(); + slice::from_raw_parts_mut(buf.as_mut_ptr().cast::(), buf.len()) + }; + + match cvt(s.read(slice))? { + Poll::Ready(nread) => unsafe { + buf.advance(nread); + Poll::Ready(Ok(())) + }, + Poll::Pending => Poll::Pending, + } + }) } } -impl AsyncWrite for MaybeHttpsStream +impl Write for SslStream where - T: AsyncRead + AsyncWrite + Unpin, + S: Read + Write, { fn poll_write( - mut self: Pin<&mut Self>, - ctx: &mut Context<'_>, + self: Pin<&mut Self>, + cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - match &mut *self { - MaybeHttpsStream::Http(s) => Pin::new(s).poll_write(ctx, buf), - MaybeHttpsStream::Https(s) => Pin::new(s).poll_write(ctx, buf), - } + self.with_context(cx, |s| cvt(s.write(buf))) } - fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { - match &mut *self { - MaybeHttpsStream::Http(s) => Pin::new(s).poll_flush(ctx), - MaybeHttpsStream::Https(s) => Pin::new(s).poll_flush(ctx), - } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.with_context(cx, |s| cvt(s.flush())) } - fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll> { - match &mut *self { - MaybeHttpsStream::Http(s) => Pin::new(s).poll_shutdown(ctx), - MaybeHttpsStream::Https(s) => Pin::new(s).poll_shutdown(ctx), - } + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.with_context(cx, |s| { + match s.shutdown() { + Ok(_) => {} + Err(e) => match e.code() { + ErrorCode::ZERO_RETURN => {} + ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => return Poll::Pending, + _ => { + return Poll::Ready(Err(e + .into_io_error() + .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e)))) + } + }, + } + + let (stream, cx) = unsafe { s.get_mut().parts() }; + stream.poll_shutdown(cx) + }) } } -impl Connection for MaybeHttpsStream -where - T: Connection, -{ - fn connected(&self) -> Connected { - match self { - MaybeHttpsStream::Http(s) => s.connected(), - MaybeHttpsStream::Https(s) => { - let mut connected = s.get_ref().connected(); - #[cfg(ossl102)] - { - if s.ssl().selected_alpn_protocol() == Some(b"h2") { - connected = connected.negotiated_h2(); - } - } - connected - } - } +fn cvt(r: io::Result) -> Poll> { + match r { + Ok(v) => Poll::Ready(Ok(v)), + Err(e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending, + Err(e) => Poll::Ready(Err(e)), + } +} + +fn cvt_ossl(r: Result) -> Poll> { + match r { + Ok(v) => Poll::Ready(Ok(v)), + Err(e) => match e.code() { + ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Poll::Pending, + _ => Poll::Ready(Err(e)), + }, } } diff --git a/src/test.rs b/src/test.rs index f202daf..023bc6f 100644 --- a/src/test.rs +++ b/src/test.rs @@ -1,39 +1,83 @@ -use super::*; -use futures::StreamExt; -use hyper::client::HttpConnector; -use hyper::server::conn::Http; -use hyper::{service, Response}; -use hyper::{Body, Client}; -use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod}; -use tokio::net::TcpListener; +use crate::SslStream; +use hyper::client::conn::http1; +use hyper::{server, service, Request, Response}; +use hyper_util::rt::TokioIo; +use openssl::ssl::{Ssl, SslAcceptor, SslConnector, SslFiletype, SslMethod}; +use std::io; +use std::pin::Pin; +use tokio::net::{TcpListener, TcpStream}; #[tokio::test] -#[cfg(feature = "tcp")] -async fn google() { - let ssl = HttpsConnector::new().unwrap(); - let client = Client::builder() - .pool_max_idle_per_host(0) - .build::<_, Body>(ssl); +async fn raw_client_server() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); - for _ in 0..3 { - let resp = client - .get("https://www.google.com".parse().unwrap()) + tokio::spawn(async move { + let mut acceptor = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls()).unwrap(); + acceptor + .set_private_key_file("test/key.pem", SslFiletype::PEM) + .unwrap(); + acceptor + .set_certificate_chain_file("test/cert.pem") + .unwrap(); + let acceptor = acceptor.build(); + + let ssl = Ssl::new(acceptor.context()).unwrap(); + let stream = listener.accept().await.unwrap().0; + let mut stream = SslStream::new(ssl, TokioIo::new(stream)).unwrap(); + + Pin::new(&mut stream).accept().await.unwrap(); + + let service = + service::service_fn(|_| async { Ok::<_, io::Error>(Response::new(String::new())) }); + + server::conn::http1::Builder::new() + .serve_connection(stream, service) .await .unwrap(); - assert!(resp.status().is_success(), "{}", resp.status()); - let mut body = resp.into_body(); - while body.next().await.transpose().unwrap().is_some() {} - } + }); + + let stream = TcpStream::connect(addr).await.unwrap(); + let stream = TokioIo::new(stream); + + let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + builder.set_ca_file("test/cert.pem").unwrap(); + let ssl = builder + .build() + .configure() + .unwrap() + .into_ssl("localhost") + .unwrap(); + let mut stream = SslStream::new(ssl, stream).unwrap(); + Pin::new(&mut stream).connect().await.unwrap(); + + let (mut tx, conn) = http1::handshake(stream).await.unwrap(); + tokio::spawn(async move { + if let Err(err) = conn.await { + panic!("Connection failed: {:?}", err); + } + }); + + let req = Request::builder().body(String::new()).unwrap(); + tx.send_request(req).await.unwrap(); } #[tokio::test] -async fn localhost() { +#[cfg(feature = "client-legacy")] +async fn legacy_client_server() { + use crate::client::legacy::HttpsConnector; + use hyper::body::Body; + use hyper_util::client::legacy::connect::HttpConnector; + use hyper_util::client::legacy::Client; + use hyper_util::rt::TokioExecutor; + use std::future; + use std::pin::pin; + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let port = listener.local_addr().unwrap().port(); - let server = async move { - let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); - acceptor.set_session_id_context(b"test").unwrap(); + tokio::spawn(async move { + let mut acceptor = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls()).unwrap(); acceptor .set_private_key_file("test/key.pem", SslFiletype::PEM) .unwrap(); @@ -43,23 +87,22 @@ async fn localhost() { let acceptor = acceptor.build(); for _ in 0..3 { - let stream = listener.accept().await.unwrap().0; let ssl = Ssl::new(acceptor.context()).unwrap(); - let mut stream = SslStream::new(ssl, stream).unwrap(); + let stream = listener.accept().await.unwrap().0; + let mut stream = SslStream::new(ssl, TokioIo::new(stream)).unwrap(); Pin::new(&mut stream).accept().await.unwrap(); let service = - service::service_fn(|_| async { Ok::<_, io::Error>(Response::new(Body::empty())) }); + service::service_fn(|_| async { Ok::<_, io::Error>(Response::new(String::new())) }); - Http::new() - .http1_keep_alive(false) + server::conn::http1::Builder::new() + .keep_alive(false) .serve_connection(stream, service) .await .unwrap(); } - }; - tokio::spawn(server); + }); let mut connector = HttpConnector::new(); connector.enforce_http(false); @@ -78,23 +121,35 @@ async fn localhost() { } let ssl = HttpsConnector::with_connector(connector, ssl).unwrap(); - let client = Client::builder().build::<_, Body>(ssl); + let client = Client::builder(TokioExecutor::new()).build::<_, String>(ssl); for _ in 0..3 { let resp = client - .get(format!("https://localhost:{}", port).parse().unwrap()) + .get(format!("https://localhost:{port}").parse().unwrap()) .await .unwrap(); - assert!(resp.status().is_success(), "{}", resp.status()); - let mut body = resp.into_body(); - while body.next().await.transpose().unwrap().is_some() {} + assert!(resp.status().is_success()); + let mut body = pin!(resp.into_body()); + while future::poll_fn(|cx| body.as_mut().poll_frame(cx)) + .await + .transpose() + .unwrap() + .is_some() + {} } } #[tokio::test] -#[cfg(ossl102)] -async fn alpn_h2() { +#[cfg(all(feature = "client-legacy", ossl102))] +async fn legacy_alpn_h2() { + use crate::client::legacy::HttpsConnector; + use hyper::body::Body; + use hyper_util::client::legacy::connect::HttpConnector; + use hyper_util::client::legacy::Client; + use hyper_util::rt::TokioExecutor; use openssl::ssl::{self, AlpnError}; + use std::future; + use std::pin::pin; let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let port = listener.local_addr().unwrap().port(); @@ -114,16 +169,15 @@ async fn alpn_h2() { let stream = listener.accept().await.unwrap().0; let ssl = Ssl::new(acceptor.context()).unwrap(); - let mut stream = SslStream::new(ssl, stream).unwrap(); + let mut stream = SslStream::new(ssl, TokioIo::new(stream)).unwrap(); Pin::new(&mut stream).accept().await.unwrap(); assert_eq!(stream.ssl().selected_alpn_protocol().unwrap(), b"h2"); let service = - service::service_fn(|_| async { Ok::<_, io::Error>(Response::new(Body::empty())) }); + service::service_fn(|_| async { Ok::<_, io::Error>(Response::new(String::new())) }); - Http::new() - .http2_only(true) + server::conn::http2::Builder::new(TokioExecutor::new()) .serve_connection(stream, service) .await .unwrap(); @@ -137,13 +191,18 @@ async fn alpn_h2() { ssl.set_alpn_protos(b"\x02h2\x08http/1.1").unwrap(); let ssl = HttpsConnector::with_connector(connector, ssl).unwrap(); - let client = Client::builder().build::<_, Body>(ssl); + let client = Client::builder(TokioExecutor::new()).build::<_, String>(ssl); let resp = client .get(format!("https://localhost:{}", port).parse().unwrap()) .await .unwrap(); assert!(resp.status().is_success(), "{}", resp.status()); - let mut body = resp.into_body(); - while body.next().await.transpose().unwrap().is_some() {} + let mut body = pin!(resp.into_body()); + while future::poll_fn(|cx| body.as_mut().poll_frame(cx)) + .await + .transpose() + .unwrap() + .is_some() + {} }