diff --git a/CHANGELOG.md b/CHANGELOG.md index eb59db2f..d046f3f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog], and this project adheres to - **changed**: Updated `tower` from `0.4` to `0.5`. - **added**: Support reading PKCS\#1 and SEC1 private keys with Rustls. +- **added**: Support for http1-only and http2-only servers. # 0.7.1 (31. July 2024) diff --git a/Cargo.toml b/Cargo.toml index 68c312c0..22fe91cf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ tls-openssl = ["arc-swap", "openssl", "tokio-openssl", "dep:pin-project-lite"] [dependencies] bytes = "1" +either = "1.13" http = "1.1" http-body = "1.0" hyper = { version = "1.4", features = ["http1", "http2", "server"] } diff --git a/src/server.rs b/src/server.rs index 32c75b54..67336637 100644 --- a/src/server.rs +++ b/src/server.rs @@ -3,6 +3,7 @@ use crate::{ handle::Handle, service::{MakeService, SendService}, }; +use either::Either; use http::Request; use hyper::body::Incoming; use hyper_util::{ @@ -28,6 +29,7 @@ pub struct Server { builder: Builder, listener: Listener, handle: Handle, + http_version: Option, } // Builder doesn't implement Debug or Clone right now @@ -72,6 +74,7 @@ impl Server { builder, listener: Listener::Bind(addr), handle, + http_version: None, } } @@ -86,10 +89,17 @@ impl Server { builder, listener: Listener::Std(listener), handle, + http_version: None, } } } +#[derive(Clone, Copy, Eq, PartialEq)] +enum HttpVersion { + Http1, + Http2, +} + impl Server { /// Overwrite acceptor. pub fn acceptor(self, acceptor: Acceptor) -> Server { @@ -98,6 +108,7 @@ impl Server { builder: self.builder, listener: self.listener, handle: self.handle, + http_version: None, } } @@ -111,6 +122,7 @@ impl Server { builder: self.builder, listener: self.listener, handle: self.handle, + http_version: None, } } @@ -129,6 +141,20 @@ impl Server { &mut self.builder } + /// Only accepts HTTP/1 + pub fn http1_only(mut self) -> Self { + self.http_version = Some(HttpVersion::Http1); + self.builder = self.builder.http1_only(); + self + } + + /// Only accepts HTTP/2 + pub fn http2_only(mut self) -> Self { + self.http_version = Some(HttpVersion::Http2); + self.builder = self.builder.http2_only(); + self + } + /// Provide a handle for additional utilities. pub fn handle(mut self, handle: Handle) -> Self { self.handle = handle; @@ -192,20 +218,26 @@ impl Server { let acceptor = acceptor.clone(); let watcher = handle.watcher(); let builder = builder.clone(); + let http_version = self.http_version; tokio::spawn(async move { if let Ok((stream, send_service)) = acceptor.accept(tcp_stream, service).await { let io = TokioIo::new(stream); let service = send_service.into_service(); let service = TowerToHyperService::new(service); - - let serve_future = builder.serve_connection_with_upgrades(io, service); + let serve_future = match http_version { + Some(_) => Either::Left(builder.serve_connection(io, service)), + _ => Either::Right(builder.serve_connection_with_upgrades(io, service)), + }; tokio::pin!(serve_future); - + let mut serve_future = serve_future.as_pin_mut(); tokio::select! { biased; _ = watcher.wait_graceful_shutdown() => { - serve_future.as_mut().graceful_shutdown(); + match &mut serve_future { + Either::Left(serve_future) => serve_future.as_mut().graceful_shutdown(), + Either::Right(serve_future) => serve_future.as_mut().graceful_shutdown(), + } tokio::select! { biased; _ = watcher.wait_shutdown() => (), @@ -270,7 +302,10 @@ pub(crate) fn io_other>(error: E) -> io::Error { #[cfg(test)] mod tests { - use crate::{handle::Handle, server::Server}; + use crate::{ + handle::Handle, + server::{HttpVersion, Server}, + }; use axum::body::Body; use axum::response::Response; use axum::routing::post; @@ -280,9 +315,10 @@ mod tests { use http::{Method, Request, Uri}; use http_body::Frame; use http_body_util::{BodyExt, StreamBody}; - use hyper::client::conn::http1::handshake; - use hyper::client::conn::http1::SendRequest; - use hyper_util::rt::TokioIo; + use hyper::client; + use hyper::client::conn::http1; + use hyper::client::conn::http2; + use hyper_util::rt::{TokioExecutor, TokioIo}; use std::{io, net::SocketAddr, time::Duration}; use tokio::sync::oneshot; use tokio::{net::TcpStream, task::JoinHandle, time::timeout}; @@ -295,7 +331,7 @@ mod tests { // Client can send requests - do_empty_request(&mut client).await.unwrap(); + do_empty_request_h1(&mut client).await.unwrap(); do_slow_request(&mut client, Duration::from_millis(50)) .await @@ -309,12 +345,12 @@ mod tests { let (mut client, conn) = connect(addr).await; // Client can send request before shutdown. - do_empty_request(&mut client).await.unwrap(); + do_empty_request_h1(&mut client).await.unwrap(); handle.shutdown(); // After shutdown, all client requests should fail. - do_empty_request(&mut client).await.unwrap_err(); + do_empty_request_h1(&mut client).await.unwrap_err(); // Connection should finish soon. let _ = timeout(Duration::from_secs(1), conn).await.unwrap(); @@ -329,8 +365,8 @@ mod tests { let (mut client2, _conn2) = connect(addr).await; // Clients can send request before graceful shutdown. - do_empty_request(&mut client1).await.unwrap(); - do_empty_request(&mut client2).await.unwrap(); + do_empty_request_h1(&mut client1).await.unwrap(); + do_empty_request_h1(&mut client2).await.unwrap(); let start = tokio::time::Instant::now(); @@ -355,9 +391,9 @@ mod tests { handle.graceful_shutdown(None); // Any new requests after graceful shutdown begins will fail - do_empty_request(&mut client2).await.unwrap_err(); - do_empty_request(&mut client2).await.unwrap_err(); - do_empty_request(&mut client2).await.unwrap_err(); + do_empty_request_h1(&mut client2).await.unwrap_err(); + do_empty_request_h1(&mut client2).await.unwrap_err(); + do_empty_request_h1(&mut client2).await.unwrap_err(); }; tokio::join!(fut1, fut2); @@ -384,8 +420,8 @@ mod tests { let (mut client2, _conn2) = connect(addr).await; // Clients can send request before graceful shutdown. - do_empty_request(&mut client1).await.unwrap(); - do_empty_request(&mut client2).await.unwrap(); + do_empty_request_h1(&mut client1).await.unwrap(); + do_empty_request_h1(&mut client2).await.unwrap(); let start = tokio::time::Instant::now(); @@ -430,7 +466,43 @@ mod tests { tokio::join!(task1, task2, task3); } - async fn start_server() -> (Handle, JoinHandle>, SocketAddr) { + #[tokio::test] + async fn test_http1_only() { + let (_handle, _server_task, addr) = + start_server_with_http_version(Some(HttpVersion::Http1)).await; + + let (mut client, _conn) = connect_h1(addr).await; + + do_empty_request_h1(&mut client).await.unwrap(); + + do_slow_request(&mut client, Duration::from_millis(50)) + .await + .unwrap(); + + let (mut client, _conn) = connect_h2(addr).await; + do_empty_request_h2(&mut client).await.unwrap_err(); + } + + #[tokio::test] + async fn test_http2_only() { + let (_handle, _server_task, addr) = + start_server_with_http_version(Some(HttpVersion::Http2)).await; + + let (mut client, _conn) = connect_h2(addr).await; + + do_empty_request_h2(&mut client).await.unwrap(); + + do_slow_request_h2(&mut client, Duration::from_millis(50)) + .await + .unwrap(); + + let (mut client, _conn) = connect_h1(addr).await; + do_empty_request_h1(&mut client).await.unwrap_err(); + } + + async fn start_server_with_http_version( + http_version: Option, + ) -> (Handle, JoinHandle>, SocketAddr) { let handle = Handle::new(); let server_handle = handle.clone(); @@ -446,8 +518,14 @@ mod tests { ); let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - - Server::bind(addr) + let server = Server::bind(addr); + let server = match http_version { + Some(HttpVersion::Http1) => server.http1_only(), + Some(HttpVersion::Http2) => server.http2_only(), + None => server, + }; + + server .handle(server_handle) .serve(app.into_make_service()) .await @@ -458,9 +536,17 @@ mod tests { (handle, server_task, addr) } - async fn connect(addr: SocketAddr) -> (SendRequest, JoinHandle<()>) { + async fn start_server() -> (Handle, JoinHandle>, SocketAddr) { + start_server_with_http_version(None).await + } + + async fn connect(addr: SocketAddr) -> (http1::SendRequest, JoinHandle<()>) { + connect_h1(addr).await + } + + async fn connect_h1(addr: SocketAddr) -> (http1::SendRequest, JoinHandle<()>) { let stream = TokioIo::new(TcpStream::connect(addr).await.unwrap()); - let (send_request, connection) = handshake(stream).await.unwrap(); + let (send_request, connection) = client::conn::http1::handshake(stream).await.unwrap(); let task = tokio::spawn(async move { let _ = connection.await; @@ -469,8 +555,36 @@ mod tests { (send_request, task) } + async fn connect_h2(addr: SocketAddr) -> (http2::SendRequest, JoinHandle<()>) { + let stream = TokioIo::new(TcpStream::connect(addr).await.unwrap()); + let (send_request, connection) = + client::conn::http2::handshake(TokioExecutor::new(), stream) + .await + .unwrap(); + + let task = tokio::spawn(async move { + let _ = connection.await; + }); + + (send_request, task) + } + + // Send a basic `GET /` request. + async fn do_empty_request_h1(client: &mut http1::SendRequest) -> hyper::Result<()> { + client.ready().await?; + + let body = client + .send_request(Request::new(Body::empty())) + .await? + .into_body(); + + let body = body.collect().await?.to_bytes(); + assert_eq!(body.as_ref(), b"Hello, world!"); + Ok(()) + } + // Send a basic `GET /` request. - async fn do_empty_request(client: &mut SendRequest) -> hyper::Result<()> { + async fn do_empty_request_h2(client: &mut http2::SendRequest) -> hyper::Result<()> { client.ready().await?; let body = client @@ -486,15 +600,36 @@ mod tests { // Send a request with a body streamed byte-by-byte, over a given duration, // then wait for the full response. async fn do_slow_request( - client: &mut SendRequest, + client: &mut http1::SendRequest, duration: Duration, ) -> hyper::Result<()> { let response = send_slow_request(client, duration).await?; recv_slow_response_body(response).await } + async fn do_slow_request_h2( + client: &mut http2::SendRequest, + duration: Duration, + ) -> hyper::Result<()> { + let response = send_slow_request_h2(client, duration).await?; + recv_slow_response_body(response).await + } + async fn send_slow_request( - client: &mut SendRequest, + client: &mut http1::SendRequest, + duration: Duration, + ) -> hyper::Result> { + let req_body_len: usize = 10; + let mut req = Request::new(slow_body(req_body_len, duration)); + *req.method_mut() = Method::POST; + *req.uri_mut() = Uri::from_static("/echo_slowly"); + + client.ready().await?; + client.send_request(req).await + } + + async fn send_slow_request_h2( + client: &mut http2::SendRequest, duration: Duration, ) -> hyper::Result> { let req_body_len: usize = 10;