diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 74852c0..0e7f979 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -3,7 +3,12 @@ name: rustls permissions: contents: read -on: [push, pull_request] +on: + push: + pull_request: + merge_group: + schedule: + - cron: '23 6 * * 5' jobs: build: @@ -78,7 +83,7 @@ jobs: - name: Install rust toolchain uses: dtolnay/rust-toolchain@master with: - toolchain: "1.60" + toolchain: "1.63" - name: Check MSRV run: cargo check --lib --all-features diff --git a/Cargo.toml b/Cargo.toml index 6daff84..d8d9240 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ name = "hyper-rustls" version = "0.24.1" edition = "2021" -rust-version = "1.60" +rust-version = "1.63" license = "Apache-2.0 OR ISC OR MIT" readme = "README.md" description = "Rustls+hyper integration for pure rust HTTPS" @@ -15,7 +15,7 @@ http = "0.2" hyper = { version = "0.14", default-features = false, features = ["client"] } log = { version = "0.4.4", optional = true } rustls-native-certs = { version = "0.6", optional = true } -rustls = { version = "0.21.0", default-features = false } +rustls = { version = "0.21.6", default-features = false } tokio = "1.0" tokio-rustls = { version = "0.24.0", default-features = false } webpki-roots = { version = "0.25", optional = true } diff --git a/src/acceptor.rs b/src/acceptor.rs index e843be2..4cf816d 100644 --- a/src/acceptor.rs +++ b/src/acceptor.rs @@ -9,18 +9,60 @@ use hyper::server::{ accept::Accept, conn::{AddrIncoming, AddrStream}, }; -use rustls::ServerConfig; +use rustls::{ServerConfig, ServerConnection}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; mod builder; pub use builder::AcceptorBuilder; use builder::WantsTlsConfig; -enum State { - Handshaking(tokio_rustls::Accept), - Streaming(tokio_rustls::server::TlsStream), +/// A TLS acceptor that can be used with hyper servers. +pub struct TlsAcceptor { + config: Arc, + incoming: AddrIncoming, +} + +/// An Acceptor for the `https` scheme. +impl TlsAcceptor { + /// Provides a builder for a `TlsAcceptor`. + pub fn builder() -> AcceptorBuilder { + AcceptorBuilder::new() + } + + /// Creates a new `TlsAcceptor` from a `ServerConfig` and an `AddrIncoming`. + pub fn new(config: Arc, incoming: AddrIncoming) -> Self { + Self { config, incoming } + } +} + +impl Accept for TlsAcceptor { + type Conn = TlsStream; + type Error = io::Error; + + fn poll_accept( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + let pin = self.get_mut(); + Poll::Ready(match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) { + Some(Ok(sock)) => Some(Ok(TlsStream::new(sock, pin.config.clone()))), + Some(Err(e)) => Some(Err(e)), + None => None, + }) + } +} + +impl From<(C, I)> for TlsAcceptor +where + C: Into>, + I: Into, +{ + fn from((config, incoming): (C, I)) -> Self { + Self::new(config.into(), incoming.into()) + } } +/// A TLS stream constructed by a [`TlsAcceptor`]. // tokio_rustls::server::TlsStream doesn't expose constructor methods, // so we have to TlsAcceptor::accept and handshake to have access to it // TlsStream implements AsyncRead/AsyncWrite by handshaking with tokio_rustls::Accept first @@ -29,12 +71,32 @@ pub struct TlsStream { } impl TlsStream { - fn new(stream: AddrStream, config: Arc) -> TlsStream { + fn new(stream: AddrStream, config: Arc) -> Self { let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream); - TlsStream { + Self { state: State::Handshaking(accept), } } + + /// Returns a reference to the underlying IO stream. + /// + /// This should always return `Some`, except if an error has already been yielded. + pub fn io(&self) -> Option<&AddrStream> { + match &self.state { + State::Handshaking(accept) => accept.get_ref(), + State::Streaming(stream) => Some(stream.get_ref().0), + } + } + + /// Returns a reference to the underlying [`rustls::ServerConnection']. + /// + /// This will start yielding `Some` only after the handshake has completed. + pub fn connection(&self) -> Option<&ServerConnection> { + match &self.state { + State::Handshaking(_) => None, + State::Streaming(stream) => Some(stream.get_ref().1), + } + } } impl AsyncRead for TlsStream { @@ -44,17 +106,19 @@ impl AsyncRead for TlsStream { buf: &mut ReadBuf, ) -> Poll> { let pin = self.get_mut(); - match pin.state { - State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) { - Ok(mut stream) => { - let result = Pin::new(&mut stream).poll_read(cx, buf); - pin.state = State::Streaming(stream); - result - } - Err(err) => Poll::Ready(Err(err)), - }, - State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf), - } + let accept = match &mut pin.state { + State::Handshaking(accept) => accept, + State::Streaming(stream) => return Pin::new(stream).poll_read(cx, buf), + }; + + let mut stream = match ready!(Pin::new(accept).poll(cx)) { + Ok(stream) => stream, + Err(err) => return Poll::Ready(Err(err)), + }; + + let result = Pin::new(&mut stream).poll_read(cx, buf); + pin.state = State::Streaming(stream); + result } } @@ -65,75 +129,37 @@ impl AsyncWrite for TlsStream { buf: &[u8], ) -> Poll> { let pin = self.get_mut(); - match pin.state { - State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) { - Ok(mut stream) => { - let result = Pin::new(&mut stream).poll_write(cx, buf); - pin.state = State::Streaming(stream); - result - } - Err(err) => Poll::Ready(Err(err)), - }, - State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf), - } + let accept = match &mut pin.state { + State::Handshaking(accept) => accept, + State::Streaming(stream) => return Pin::new(stream).poll_write(cx, buf), + }; + + let mut stream = match ready!(Pin::new(accept).poll(cx)) { + Ok(stream) => stream, + Err(err) => return Poll::Ready(Err(err)), + }; + + let result = Pin::new(&mut stream).poll_write(cx, buf); + pin.state = State::Streaming(stream); + result } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.state { + match &mut self.state { State::Handshaking(_) => Poll::Ready(Ok(())), - State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx), + State::Streaming(stream) => Pin::new(stream).poll_flush(cx), } } fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.state { + match &mut self.state { State::Handshaking(_) => Poll::Ready(Ok(())), - State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx), + State::Streaming(stream) => Pin::new(stream).poll_shutdown(cx), } } } -/// A TLS acceptor that can be used with hyper servers. -pub struct TlsAcceptor { - config: Arc, - incoming: AddrIncoming, -} - -/// An Acceptor for the `https` scheme. -impl TlsAcceptor { - /// Provides a builder for a `TlsAcceptor`. - pub fn builder() -> AcceptorBuilder { - AcceptorBuilder::new() - } - /// Creates a new `TlsAcceptor` from a `ServerConfig` and an `AddrIncoming`. - pub fn new(config: Arc, incoming: AddrIncoming) -> TlsAcceptor { - TlsAcceptor { config, incoming } - } -} - -impl From<(C, I)> for TlsAcceptor -where - C: Into>, - I: Into, -{ - fn from((config, incoming): (C, I)) -> TlsAcceptor { - TlsAcceptor::new(config.into(), incoming.into()) - } -} - -impl Accept for TlsAcceptor { - type Conn = TlsStream; - type Error = io::Error; - - fn poll_accept( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - let pin = self.get_mut(); - match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) { - Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))), - Some(Err(e)) => Poll::Ready(Some(Err(e))), - None => Poll::Ready(None), - } - } +enum State { + Handshaking(tokio_rustls::Accept), + Streaming(tokio_rustls::server::TlsStream), } diff --git a/src/config.rs b/src/config.rs index c4b4624..256856c 100644 --- a/src/config.rs +++ b/src/config.rs @@ -54,7 +54,7 @@ impl ConfigBuilderExt for ConfigBuilder { #[cfg_attr(docsrs, doc(cfg(feature = "webpki-roots")))] fn with_webpki_roots(self) -> ConfigBuilder { let mut roots = rustls::RootCertStore::empty(); - roots.add_server_trust_anchors( + roots.add_trust_anchors( webpki_roots::TLS_SERVER_ROOTS .iter() .map(|ta| { diff --git a/src/connector.rs b/src/connector.rs index 18ad689..1e3cc4e 100644 --- a/src/connector.rs +++ b/src/connector.rs @@ -10,7 +10,7 @@ use tokio_rustls::TlsConnector; use crate::stream::MaybeHttpsStream; -pub mod builder; +pub(crate) mod builder; type BoxError = Box; @@ -45,7 +45,7 @@ where C: Into>, { fn from((http, cfg): (H, C)) -> Self { - HttpsConnector { + Self { force_https: false, http, tls_config: cfg.into(), diff --git a/src/lib.rs b/src/lib.rs index b207568..308f71e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -76,18 +76,19 @@ //! # fn main() {} //! ``` -#![warn(missing_docs)] +#![warn(missing_docs, unreachable_pub, clippy::use_self)] #![cfg_attr(docsrs, feature(doc_cfg))] #[cfg(feature = "acceptor")] -mod acceptor; +/// TLS acceptor implementing hyper's `Accept` trait. +pub mod acceptor; mod config; mod connector; mod stream; #[cfg(feature = "logging")] mod log { - pub use log::{debug, trace}; + pub(crate) use log::{debug, trace}; } #[cfg(not(feature = "logging"))] diff --git a/src/stream.rs b/src/stream.rs index 64ddc48..a0e8c68 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -21,8 +21,8 @@ pub enum MaybeHttpsStream { impl Connection for MaybeHttpsStream { fn connected(&self) -> Connected { match self { - MaybeHttpsStream::Http(s) => s.connected(), - MaybeHttpsStream::Https(s) => { + Self::Http(s) => s.connected(), + Self::Https(s) => { let (tcp, tls) = s.get_ref(); if tls.alpn_protocol() == Some(b"h2") { tcp.connected().negotiated_h2() @@ -37,21 +37,21 @@ impl Connection for MaybeHttpsSt impl fmt::Debug for MaybeHttpsStream { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - MaybeHttpsStream::Http(..) => f.pad("Http(..)"), - MaybeHttpsStream::Https(..) => f.pad("Https(..)"), + Self::Http(..) => f.pad("Http(..)"), + Self::Https(..) => f.pad("Https(..)"), } } } impl From for MaybeHttpsStream { fn from(inner: T) -> Self { - MaybeHttpsStream::Http(inner) + Self::Http(inner) } } impl From> for MaybeHttpsStream { fn from(inner: TlsStream) -> Self { - MaybeHttpsStream::Https(inner) + Self::Https(inner) } } @@ -63,8 +63,8 @@ impl AsyncRead for MaybeHttpsStream { buf: &mut ReadBuf<'_>, ) -> Poll> { match Pin::get_mut(self) { - MaybeHttpsStream::Http(s) => Pin::new(s).poll_read(cx, buf), - MaybeHttpsStream::Https(s) => Pin::new(s).poll_read(cx, buf), + Self::Http(s) => Pin::new(s).poll_read(cx, buf), + Self::Https(s) => Pin::new(s).poll_read(cx, buf), } } } @@ -77,24 +77,24 @@ impl AsyncWrite for MaybeHttpsStream { buf: &[u8], ) -> Poll> { match Pin::get_mut(self) { - MaybeHttpsStream::Http(s) => Pin::new(s).poll_write(cx, buf), - MaybeHttpsStream::Https(s) => Pin::new(s).poll_write(cx, buf), + Self::Http(s) => Pin::new(s).poll_write(cx, buf), + Self::Https(s) => Pin::new(s).poll_write(cx, buf), } } #[inline] fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match Pin::get_mut(self) { - MaybeHttpsStream::Http(s) => Pin::new(s).poll_flush(cx), - MaybeHttpsStream::Https(s) => Pin::new(s).poll_flush(cx), + Self::Http(s) => Pin::new(s).poll_flush(cx), + Self::Https(s) => Pin::new(s).poll_flush(cx), } } #[inline] fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match Pin::get_mut(self) { - MaybeHttpsStream::Http(s) => Pin::new(s).poll_shutdown(cx), - MaybeHttpsStream::Https(s) => Pin::new(s).poll_shutdown(cx), + Self::Http(s) => Pin::new(s).poll_shutdown(cx), + Self::Https(s) => Pin::new(s).poll_shutdown(cx), } } } diff --git a/tests/tests.rs b/tests/tests.rs index 72dd114..e9cb839 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -24,7 +24,7 @@ fn client_command() -> Command { fn wait_for_server(addr: &str) { for i in 0..10 { - if let Ok(_) = TcpStream::connect(addr) { + if TcpStream::connect(addr).is_ok() { return; } thread::sleep(time::Duration::from_millis(i * 100)); @@ -72,7 +72,7 @@ fn server() { println!("curl stderr:\n{}", String::from_utf8_lossy(&output.stderr)); } - assert_eq!(String::from_utf8_lossy(&*output.stdout), "Try POST /echo\n"); + assert_eq!(String::from_utf8_lossy(&output.stdout), "Try POST /echo\n"); } #[test]