From 8f7d300ca1ac6b93460f62505c8ffc31122efe06 Mon Sep 17 00:00:00 2001 From: Gilad Wolff Date: Fri, 25 Oct 2024 14:25:04 -0700 Subject: [PATCH 1/6] expose http1_only and http2_only --- src/server.rs | 90 +++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 84 insertions(+), 6 deletions(-) diff --git a/src/server.rs b/src/server.rs index 32c75b54..c0bbf677 100644 --- a/src/server.rs +++ b/src/server.rs @@ -129,6 +129,18 @@ impl Server { &mut self.builder } + /// Only accepts HTTP/1 + pub fn http1_only(mut self) -> Self { + self.builder = self.builder.http1_only(); + self + } + + /// Only accepts HTTP/2 + pub fn http2_only(mut self) -> Self { + self.builder = self.builder.http2_only(); + self + } + /// Provide a handle for additional utilities. pub fn handle(mut self, handle: Handle) -> Self { self.handle = handle; @@ -280,9 +292,10 @@ mod tests { use http::{Method, Request, Uri}; use http_body::Frame; use http_body_util::{BodyExt, StreamBody}; - use hyper::client::conn::http1::handshake; + use hyper::client; use hyper::client::conn::http1::SendRequest; - use hyper_util::rt::TokioIo; + use hyper::client::conn::http2; + use hyper_util::rt::{TokioExecutor, TokioIo}; use std::{io, net::SocketAddr, time::Duration}; use tokio::sync::oneshot; use tokio::{net::TcpStream, task::JoinHandle, time::timeout}; @@ -430,7 +443,32 @@ mod tests { tokio::join!(task1, task2, task3); } - async fn start_server() -> (Handle, JoinHandle>, SocketAddr) { + #[tokio::test] + async fn test_http1_only() { + let (_handle, _server_task, addr) = + start_server_with_http_version(Some(HttpVersion::Http1)).await; + + let (mut client, _conn) = connect_h1(addr).await; + + do_empty_request(&mut client).await.unwrap(); + + do_slow_request(&mut client, Duration::from_millis(50)) + .await + .unwrap(); + + let (mut client, _conn) = connect_h2(addr).await; + do_empty_request_h2(&mut client).await.unwrap_err(); + } + + #[derive(PartialEq, Eq)] + enum HttpVersion { + Http1, + Http2, + } + + async fn start_server_with_http_version( + http_version: Option, + ) -> (Handle, JoinHandle>, SocketAddr) { let handle = Handle::new(); let server_handle = handle.clone(); @@ -446,8 +484,15 @@ mod tests { ); let addr = SocketAddr::from(([127, 0, 0, 1], 0)); - - Server::bind(addr) + let server = Server::bind(addr); + let server = match http_version { + Some(version) if version == HttpVersion::Http1 => server.http1_only(), + Some(version) if version == HttpVersion::Http2 => server.http2_only(), + Some(_) => panic!("Invalid HTTP version"), + None => server, + }; + + server .handle(server_handle) .serve(app.into_make_service()) .await @@ -458,9 +503,28 @@ mod tests { (handle, server_task, addr) } + async fn start_server() -> (Handle, JoinHandle>, SocketAddr) { + start_server_with_http_version(None).await + } + async fn connect(addr: SocketAddr) -> (SendRequest, JoinHandle<()>) { + connect_h1(addr).await + } + + async fn connect_h1(addr: SocketAddr) -> (SendRequest, JoinHandle<()>) { + let stream = TokioIo::new(TcpStream::connect(addr).await.unwrap()); + let (send_request, connection) = client::conn::http1::handshake(stream).await.unwrap(); + + let task = tokio::spawn(async move { + let _ = connection.await; + }); + + (send_request, task) + } + + async fn connect_h2(addr: SocketAddr) -> (http2::SendRequest, JoinHandle<()>) { let stream = TokioIo::new(TcpStream::connect(addr).await.unwrap()); - let (send_request, connection) = handshake(stream).await.unwrap(); + let (send_request, connection) = client::conn::http2::handshake(TokioExecutor::new(), stream).await.unwrap(); let task = tokio::spawn(async move { let _ = connection.await; @@ -483,6 +547,20 @@ mod tests { Ok(()) } + // Send a basic `GET /` request. + async fn do_empty_request_h2(client: &mut http2::SendRequest) -> hyper::Result<()> { + client.ready().await?; + + let body = client + .send_request(Request::new(Body::empty())) + .await? + .into_body(); + + let body = body.collect().await?.to_bytes(); + assert_eq!(body.as_ref(), b"Hello, world!"); + Ok(()) + } + // Send a request with a body streamed byte-by-byte, over a given duration, // then wait for the full response. async fn do_slow_request( From 443a14ec04db4adc3bd8159b5030304898cf856d Mon Sep 17 00:00:00 2001 From: Gilad Wolff Date: Fri, 25 Oct 2024 18:13:48 -0700 Subject: [PATCH 2/6] test passing --- src/server.rs | 57 +++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 48 insertions(+), 9 deletions(-) diff --git a/src/server.rs b/src/server.rs index c0bbf677..f9daec3f 100644 --- a/src/server.rs +++ b/src/server.rs @@ -28,6 +28,7 @@ pub struct Server { builder: Builder, listener: Listener, handle: Handle, + http_version: Option, } // Builder doesn't implement Debug or Clone right now @@ -72,6 +73,7 @@ impl Server { builder, listener: Listener::Bind(addr), handle, + http_version: None, } } @@ -86,10 +88,17 @@ impl Server { builder, listener: Listener::Std(listener), handle, + http_version: None, } } } +#[derive(Clone, Copy)] +enum HttpVersion { + Http1, + Http2, +} + impl Server { /// Overwrite acceptor. pub fn acceptor(self, acceptor: Acceptor) -> Server { @@ -98,6 +107,7 @@ impl Server { builder: self.builder, listener: self.listener, handle: self.handle, + http_version: None, } } @@ -111,6 +121,7 @@ impl Server { builder: self.builder, listener: self.listener, handle: self.handle, + http_version: None, } } @@ -131,12 +142,14 @@ impl Server { /// Only accepts HTTP/1 pub fn http1_only(mut self) -> Self { + self.http_version = Some(HttpVersion::Http1); self.builder = self.builder.http1_only(); self } /// Only accepts HTTP/2 pub fn http2_only(mut self) -> Self { + self.http_version = Some(HttpVersion::Http2); self.builder = self.builder.http2_only(); self } @@ -204,28 +217,51 @@ impl Server { let acceptor = acceptor.clone(); let watcher = handle.watcher(); let builder = builder.clone(); + let http_version = self.http_version.clone(); tokio::spawn(async move { if let Ok((stream, send_service)) = acceptor.accept(tcp_stream, service).await { let io = TokioIo::new(stream); let service = send_service.into_service(); let service = TowerToHyperService::new(service); + match http_version { + Some(_) => { + let serve_future = builder.serve_connection(io, service); + tokio::pin!(serve_future); - let serve_future = builder.serve_connection_with_upgrades(io, service); - tokio::pin!(serve_future); + tokio::select! { + biased; + _ = watcher.wait_graceful_shutdown() => { + serve_future.as_mut().graceful_shutdown(); + tokio::select! { + biased; + _ = watcher.wait_shutdown() => (), + _ = &mut serve_future => (), + } + } + _ = watcher.wait_shutdown() => (), + _ = &mut serve_future => (), + } + } + None => { + let serve_future = + builder.serve_connection_with_upgrades(io, service); + tokio::pin!(serve_future); - tokio::select! { - biased; - _ = watcher.wait_graceful_shutdown() => { - serve_future.as_mut().graceful_shutdown(); tokio::select! { biased; + _ = watcher.wait_graceful_shutdown() => { + serve_future.as_mut().graceful_shutdown(); + tokio::select! { + biased; + _ = watcher.wait_shutdown() => (), + _ = &mut serve_future => (), + } + } _ = watcher.wait_shutdown() => (), _ = &mut serve_future => (), } } - _ = watcher.wait_shutdown() => (), - _ = &mut serve_future => (), } } }); @@ -524,7 +560,10 @@ mod tests { async fn connect_h2(addr: SocketAddr) -> (http2::SendRequest, JoinHandle<()>) { let stream = TokioIo::new(TcpStream::connect(addr).await.unwrap()); - let (send_request, connection) = client::conn::http2::handshake(TokioExecutor::new(), stream).await.unwrap(); + let (send_request, connection) = + client::conn::http2::handshake(TokioExecutor::new(), stream) + .await + .unwrap(); let task = tokio::spawn(async move { let _ = connection.await; From 7d489327ca34bcf53dc397f65950226d523c6e3d Mon Sep 17 00:00:00 2001 From: Gilad Wolff Date: Mon, 28 Oct 2024 09:39:22 -0700 Subject: [PATCH 3/6] switch to either --- Cargo.toml | 1 + src/server.rs | 128 +++++++++++++++++++++++++++++--------------------- 2 files changed, 76 insertions(+), 53 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 68c312c0..22fe91cf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ tls-openssl = ["arc-swap", "openssl", "tokio-openssl", "dep:pin-project-lite"] [dependencies] bytes = "1" +either = "1.13" http = "1.1" http-body = "1.0" hyper = { version = "1.4", features = ["http1", "http2", "server"] } diff --git a/src/server.rs b/src/server.rs index f9daec3f..b765647d 100644 --- a/src/server.rs +++ b/src/server.rs @@ -3,6 +3,7 @@ use crate::{ handle::Handle, service::{MakeService, SendService}, }; +use either::Either; use http::Request; use hyper::body::Incoming; use hyper_util::{ @@ -93,7 +94,7 @@ impl Server { } } -#[derive(Clone, Copy)] +#[derive(Clone, Copy, Eq, PartialEq)] enum HttpVersion { Http1, Http2, @@ -224,44 +225,29 @@ impl Server { let io = TokioIo::new(stream); let service = send_service.into_service(); let service = TowerToHyperService::new(service); - match http_version { + let serve_future = match http_version { Some(_) => { - let serve_future = builder.serve_connection(io, service); - tokio::pin!(serve_future); - - tokio::select! { - biased; - _ = watcher.wait_graceful_shutdown() => { - serve_future.as_mut().graceful_shutdown(); - tokio::select! { - biased; - _ = watcher.wait_shutdown() => (), - _ = &mut serve_future => (), - } - } - _ = watcher.wait_shutdown() => (), - _ = &mut serve_future => (), - } + Either::Left(builder.serve_connection(io, service)) } - None => { - let serve_future = - builder.serve_connection_with_upgrades(io, service); - tokio::pin!(serve_future); - + _ => Either::Right(builder.serve_connection_with_upgrades(io, service)), + }; + tokio::pin!(serve_future); + let mut serve_future = serve_future.as_pin_mut(); + tokio::select! { + biased; + _ = watcher.wait_graceful_shutdown() => { + match &mut serve_future { + Either::Left(serve_future) => serve_future.as_mut().graceful_shutdown(), + Either::Right(serve_future) => serve_future.as_mut().graceful_shutdown(), + } tokio::select! { biased; - _ = watcher.wait_graceful_shutdown() => { - serve_future.as_mut().graceful_shutdown(); - tokio::select! { - biased; - _ = watcher.wait_shutdown() => (), - _ = &mut serve_future => (), - } - } _ = watcher.wait_shutdown() => (), _ = &mut serve_future => (), } } + _ = watcher.wait_shutdown() => (), + _ = &mut serve_future => (), } } }); @@ -318,7 +304,10 @@ pub(crate) fn io_other>(error: E) -> io::Error { #[cfg(test)] mod tests { - use crate::{handle::Handle, server::Server}; + use crate::{ + handle::Handle, + server::{HttpVersion, Server}, + }; use axum::body::Body; use axum::response::Response; use axum::routing::post; @@ -329,7 +318,7 @@ mod tests { use http_body::Frame; use http_body_util::{BodyExt, StreamBody}; use hyper::client; - use hyper::client::conn::http1::SendRequest; + use hyper::client::conn::http1; use hyper::client::conn::http2; use hyper_util::rt::{TokioExecutor, TokioIo}; use std::{io, net::SocketAddr, time::Duration}; @@ -344,7 +333,7 @@ mod tests { // Client can send requests - do_empty_request(&mut client).await.unwrap(); + do_empty_request_h1(&mut client).await.unwrap(); do_slow_request(&mut client, Duration::from_millis(50)) .await @@ -358,12 +347,12 @@ mod tests { let (mut client, conn) = connect(addr).await; // Client can send request before shutdown. - do_empty_request(&mut client).await.unwrap(); + do_empty_request_h1(&mut client).await.unwrap(); handle.shutdown(); // After shutdown, all client requests should fail. - do_empty_request(&mut client).await.unwrap_err(); + do_empty_request_h1(&mut client).await.unwrap_err(); // Connection should finish soon. let _ = timeout(Duration::from_secs(1), conn).await.unwrap(); @@ -378,8 +367,8 @@ mod tests { let (mut client2, _conn2) = connect(addr).await; // Clients can send request before graceful shutdown. - do_empty_request(&mut client1).await.unwrap(); - do_empty_request(&mut client2).await.unwrap(); + do_empty_request_h1(&mut client1).await.unwrap(); + do_empty_request_h1(&mut client2).await.unwrap(); let start = tokio::time::Instant::now(); @@ -404,9 +393,9 @@ mod tests { handle.graceful_shutdown(None); // Any new requests after graceful shutdown begins will fail - do_empty_request(&mut client2).await.unwrap_err(); - do_empty_request(&mut client2).await.unwrap_err(); - do_empty_request(&mut client2).await.unwrap_err(); + do_empty_request_h1(&mut client2).await.unwrap_err(); + do_empty_request_h1(&mut client2).await.unwrap_err(); + do_empty_request_h1(&mut client2).await.unwrap_err(); }; tokio::join!(fut1, fut2); @@ -433,8 +422,8 @@ mod tests { let (mut client2, _conn2) = connect(addr).await; // Clients can send request before graceful shutdown. - do_empty_request(&mut client1).await.unwrap(); - do_empty_request(&mut client2).await.unwrap(); + do_empty_request_h1(&mut client1).await.unwrap(); + do_empty_request_h1(&mut client2).await.unwrap(); let start = tokio::time::Instant::now(); @@ -486,7 +475,7 @@ mod tests { let (mut client, _conn) = connect_h1(addr).await; - do_empty_request(&mut client).await.unwrap(); + do_empty_request_h1(&mut client).await.unwrap(); do_slow_request(&mut client, Duration::from_millis(50)) .await @@ -496,10 +485,21 @@ mod tests { do_empty_request_h2(&mut client).await.unwrap_err(); } - #[derive(PartialEq, Eq)] - enum HttpVersion { - Http1, - Http2, + #[tokio::test] + async fn test_http2_only() { + let (_handle, _server_task, addr) = + start_server_with_http_version(Some(HttpVersion::Http2)).await; + + let (mut client, _conn) = connect_h2(addr).await; + + do_empty_request_h2(&mut client).await.unwrap(); + + do_slow_request_h2(&mut client, Duration::from_millis(50)) + .await + .unwrap(); + + let (mut client, _conn) = connect_h1(addr).await; + do_empty_request_h1(&mut client).await.unwrap_err(); } async fn start_server_with_http_version( @@ -543,11 +543,11 @@ mod tests { start_server_with_http_version(None).await } - async fn connect(addr: SocketAddr) -> (SendRequest, JoinHandle<()>) { + async fn connect(addr: SocketAddr) -> (http1::SendRequest, JoinHandle<()>) { connect_h1(addr).await } - async fn connect_h1(addr: SocketAddr) -> (SendRequest, JoinHandle<()>) { + async fn connect_h1(addr: SocketAddr) -> (http1::SendRequest, JoinHandle<()>) { let stream = TokioIo::new(TcpStream::connect(addr).await.unwrap()); let (send_request, connection) = client::conn::http1::handshake(stream).await.unwrap(); @@ -573,7 +573,7 @@ mod tests { } // Send a basic `GET /` request. - async fn do_empty_request(client: &mut SendRequest) -> hyper::Result<()> { + async fn do_empty_request_h1(client: &mut http1::SendRequest) -> hyper::Result<()> { client.ready().await?; let body = client @@ -603,15 +603,23 @@ mod tests { // Send a request with a body streamed byte-by-byte, over a given duration, // then wait for the full response. async fn do_slow_request( - client: &mut SendRequest, + client: &mut http1::SendRequest, duration: Duration, ) -> hyper::Result<()> { let response = send_slow_request(client, duration).await?; recv_slow_response_body(response).await } + async fn do_slow_request_h2( + client: &mut http2::SendRequest, + duration: Duration, + ) -> hyper::Result<()> { + let response = send_slow_request_h2(client, duration).await?; + recv_slow_response_body(response).await + } + async fn send_slow_request( - client: &mut SendRequest, + client: &mut http1::SendRequest, duration: Duration, ) -> hyper::Result> { let req_body_len: usize = 10; @@ -623,6 +631,20 @@ mod tests { client.send_request(req).await } + async fn send_slow_request_h2( + client: &mut http2::SendRequest, + duration: Duration, + ) -> hyper::Result> { + let req_body_len: usize = 10; + let mut req = Request::new(slow_body(req_body_len, duration)); + *req.method_mut() = Method::POST; + *req.uri_mut() = Uri::from_static("/echo_slowly"); + + client.ready().await?; + client.send_request(req).await + } + + async fn recv_slow_response_body( response: http::Response, ) -> hyper::Result<()> { From bf6a7f208d94317097e80dba71aedf64f65ec1e0 Mon Sep 17 00:00:00 2001 From: Gilad Wolff Date: Mon, 28 Oct 2024 10:13:07 -0700 Subject: [PATCH 4/6] fmt --- src/server.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/server.rs b/src/server.rs index b765647d..cde14248 100644 --- a/src/server.rs +++ b/src/server.rs @@ -226,9 +226,7 @@ impl Server { let service = send_service.into_service(); let service = TowerToHyperService::new(service); let serve_future = match http_version { - Some(_) => { - Either::Left(builder.serve_connection(io, service)) - } + Some(_) => Either::Left(builder.serve_connection(io, service)), _ => Either::Right(builder.serve_connection_with_upgrades(io, service)), }; tokio::pin!(serve_future); @@ -644,7 +642,6 @@ mod tests { client.send_request(req).await } - async fn recv_slow_response_body( response: http::Response, ) -> hyper::Result<()> { From 98d27e6e29fc7ece48710c252e998246ff917af2 Mon Sep 17 00:00:00 2001 From: Gilad Wolff Date: Mon, 28 Oct 2024 10:15:51 -0700 Subject: [PATCH 5/6] clippy --- src/server.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/server.rs b/src/server.rs index cde14248..67336637 100644 --- a/src/server.rs +++ b/src/server.rs @@ -218,7 +218,7 @@ impl Server { let acceptor = acceptor.clone(); let watcher = handle.watcher(); let builder = builder.clone(); - let http_version = self.http_version.clone(); + let http_version = self.http_version; tokio::spawn(async move { if let Ok((stream, send_service)) = acceptor.accept(tcp_stream, service).await { @@ -520,9 +520,8 @@ mod tests { let addr = SocketAddr::from(([127, 0, 0, 1], 0)); let server = Server::bind(addr); let server = match http_version { - Some(version) if version == HttpVersion::Http1 => server.http1_only(), - Some(version) if version == HttpVersion::Http2 => server.http2_only(), - Some(_) => panic!("Invalid HTTP version"), + Some(HttpVersion::Http1) => server.http1_only(), + Some(HttpVersion::Http2) => server.http2_only(), None => server, }; From 8f64315afe9873981c79236ec46043db98a2252e Mon Sep 17 00:00:00 2001 From: Gilad Wolff Date: Mon, 28 Oct 2024 10:28:10 -0700 Subject: [PATCH 6/6] update changlog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index eb59db2f..d046f3f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog], and this project adheres to - **changed**: Updated `tower` from `0.4` to `0.5`. - **added**: Support reading PKCS\#1 and SEC1 private keys with Rustls. +- **added**: Support for http1-only and http2-only servers. # 0.7.1 (31. July 2024)