Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow running http1-only or http2-only servers #156

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ The format is based on [Keep a Changelog], and this project adheres to

- **changed**: Updated `tower` from `0.4` to `0.5`.
- **added**: Support reading PKCS\#1 and SEC1 private keys with Rustls.
- **added**: Support for http1-only and http2-only servers.

# 0.7.1 (31. July 2024)

Expand Down
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ tls-openssl = ["arc-swap", "openssl", "tokio-openssl", "dep:pin-project-lite"]

[dependencies]
bytes = "1"
either = "1.13"
http = "1.1"
http-body = "1.0"
hyper = { version = "1.4", features = ["http1", "http2", "server"] }
Expand Down
187 changes: 161 additions & 26 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::{
handle::Handle,
service::{MakeService, SendService},
};
use either::Either;
use http::Request;
use hyper::body::Incoming;
use hyper_util::{
Expand All @@ -28,6 +29,7 @@ pub struct Server<A = DefaultAcceptor> {
builder: Builder<TokioExecutor>,
listener: Listener,
handle: Handle,
http_version: Option<HttpVersion>,
}

// Builder doesn't implement Debug or Clone right now
Expand Down Expand Up @@ -72,6 +74,7 @@ impl Server {
builder,
listener: Listener::Bind(addr),
handle,
http_version: None,
}
}

Expand All @@ -86,10 +89,17 @@ impl Server {
builder,
listener: Listener::Std(listener),
handle,
http_version: None,
}
}
}

#[derive(Clone, Copy, Eq, PartialEq)]
enum HttpVersion {
Http1,
Http2,
}

impl<A> Server<A> {
/// Overwrite acceptor.
pub fn acceptor<Acceptor>(self, acceptor: Acceptor) -> Server<Acceptor> {
Expand All @@ -98,6 +108,7 @@ impl<A> Server<A> {
builder: self.builder,
listener: self.listener,
handle: self.handle,
http_version: None,
}
}

Expand All @@ -111,6 +122,7 @@ impl<A> Server<A> {
builder: self.builder,
listener: self.listener,
handle: self.handle,
http_version: None,
}
}

Expand All @@ -129,6 +141,20 @@ impl<A> Server<A> {
&mut self.builder
}

/// Only accepts HTTP/1
pub fn http1_only(mut self) -> Self {
self.http_version = Some(HttpVersion::Http1);
self.builder = self.builder.http1_only();
self
}

/// Only accepts HTTP/2
pub fn http2_only(mut self) -> Self {
self.http_version = Some(HttpVersion::Http2);
self.builder = self.builder.http2_only();
self
}

/// Provide a handle for additional utilities.
pub fn handle(mut self, handle: Handle) -> Self {
self.handle = handle;
Expand Down Expand Up @@ -192,20 +218,26 @@ impl<A> Server<A> {
let acceptor = acceptor.clone();
let watcher = handle.watcher();
let builder = builder.clone();
let http_version = self.http_version;

tokio::spawn(async move {
if let Ok((stream, send_service)) = acceptor.accept(tcp_stream, service).await {
let io = TokioIo::new(stream);
let service = send_service.into_service();
let service = TowerToHyperService::new(service);

let serve_future = builder.serve_connection_with_upgrades(io, service);
let serve_future = match http_version {
Some(_) => Either::Left(builder.serve_connection(io, service)),
_ => Either::Right(builder.serve_connection_with_upgrades(io, service)),
};
tokio::pin!(serve_future);

let mut serve_future = serve_future.as_pin_mut();
tokio::select! {
biased;
_ = watcher.wait_graceful_shutdown() => {
serve_future.as_mut().graceful_shutdown();
match &mut serve_future {
Either::Left(serve_future) => serve_future.as_mut().graceful_shutdown(),
Either::Right(serve_future) => serve_future.as_mut().graceful_shutdown(),
}
tokio::select! {
biased;
_ = watcher.wait_shutdown() => (),
Expand Down Expand Up @@ -270,7 +302,10 @@ pub(crate) fn io_other<E: Into<BoxError>>(error: E) -> io::Error {

#[cfg(test)]
mod tests {
use crate::{handle::Handle, server::Server};
use crate::{
handle::Handle,
server::{HttpVersion, Server},
};
use axum::body::Body;
use axum::response::Response;
use axum::routing::post;
Expand All @@ -280,9 +315,10 @@ mod tests {
use http::{Method, Request, Uri};
use http_body::Frame;
use http_body_util::{BodyExt, StreamBody};
use hyper::client::conn::http1::handshake;
use hyper::client::conn::http1::SendRequest;
use hyper_util::rt::TokioIo;
use hyper::client;
use hyper::client::conn::http1;
use hyper::client::conn::http2;
use hyper_util::rt::{TokioExecutor, TokioIo};
use std::{io, net::SocketAddr, time::Duration};
use tokio::sync::oneshot;
use tokio::{net::TcpStream, task::JoinHandle, time::timeout};
Expand All @@ -295,7 +331,7 @@ mod tests {

// Client can send requests

do_empty_request(&mut client).await.unwrap();
do_empty_request_h1(&mut client).await.unwrap();

do_slow_request(&mut client, Duration::from_millis(50))
.await
Expand All @@ -309,12 +345,12 @@ mod tests {
let (mut client, conn) = connect(addr).await;

// Client can send request before shutdown.
do_empty_request(&mut client).await.unwrap();
do_empty_request_h1(&mut client).await.unwrap();

handle.shutdown();

// After shutdown, all client requests should fail.
do_empty_request(&mut client).await.unwrap_err();
do_empty_request_h1(&mut client).await.unwrap_err();

// Connection should finish soon.
let _ = timeout(Duration::from_secs(1), conn).await.unwrap();
Expand All @@ -329,8 +365,8 @@ mod tests {
let (mut client2, _conn2) = connect(addr).await;

// Clients can send request before graceful shutdown.
do_empty_request(&mut client1).await.unwrap();
do_empty_request(&mut client2).await.unwrap();
do_empty_request_h1(&mut client1).await.unwrap();
do_empty_request_h1(&mut client2).await.unwrap();

let start = tokio::time::Instant::now();

Expand All @@ -355,9 +391,9 @@ mod tests {
handle.graceful_shutdown(None);

// Any new requests after graceful shutdown begins will fail
do_empty_request(&mut client2).await.unwrap_err();
do_empty_request(&mut client2).await.unwrap_err();
do_empty_request(&mut client2).await.unwrap_err();
do_empty_request_h1(&mut client2).await.unwrap_err();
do_empty_request_h1(&mut client2).await.unwrap_err();
do_empty_request_h1(&mut client2).await.unwrap_err();
};

tokio::join!(fut1, fut2);
Expand All @@ -384,8 +420,8 @@ mod tests {
let (mut client2, _conn2) = connect(addr).await;

// Clients can send request before graceful shutdown.
do_empty_request(&mut client1).await.unwrap();
do_empty_request(&mut client2).await.unwrap();
do_empty_request_h1(&mut client1).await.unwrap();
do_empty_request_h1(&mut client2).await.unwrap();

let start = tokio::time::Instant::now();

Expand Down Expand Up @@ -430,7 +466,43 @@ mod tests {
tokio::join!(task1, task2, task3);
}

async fn start_server() -> (Handle, JoinHandle<io::Result<()>>, SocketAddr) {
#[tokio::test]
async fn test_http1_only() {
let (_handle, _server_task, addr) =
start_server_with_http_version(Some(HttpVersion::Http1)).await;

let (mut client, _conn) = connect_h1(addr).await;

do_empty_request_h1(&mut client).await.unwrap();

do_slow_request(&mut client, Duration::from_millis(50))
.await
.unwrap();

let (mut client, _conn) = connect_h2(addr).await;
do_empty_request_h2(&mut client).await.unwrap_err();
}

#[tokio::test]
async fn test_http2_only() {
let (_handle, _server_task, addr) =
start_server_with_http_version(Some(HttpVersion::Http2)).await;

let (mut client, _conn) = connect_h2(addr).await;

do_empty_request_h2(&mut client).await.unwrap();

do_slow_request_h2(&mut client, Duration::from_millis(50))
.await
.unwrap();

let (mut client, _conn) = connect_h1(addr).await;
do_empty_request_h1(&mut client).await.unwrap_err();
}

async fn start_server_with_http_version(
http_version: Option<HttpVersion>,
) -> (Handle, JoinHandle<io::Result<()>>, SocketAddr) {
let handle = Handle::new();

let server_handle = handle.clone();
Expand All @@ -446,8 +518,14 @@ mod tests {
);

let addr = SocketAddr::from(([127, 0, 0, 1], 0));

Server::bind(addr)
let server = Server::bind(addr);
let server = match http_version {
Some(HttpVersion::Http1) => server.http1_only(),
Some(HttpVersion::Http2) => server.http2_only(),
None => server,
};

server
.handle(server_handle)
.serve(app.into_make_service())
.await
Expand All @@ -458,9 +536,17 @@ mod tests {
(handle, server_task, addr)
}

async fn connect(addr: SocketAddr) -> (SendRequest<Body>, JoinHandle<()>) {
async fn start_server() -> (Handle, JoinHandle<io::Result<()>>, SocketAddr) {
start_server_with_http_version(None).await
}

async fn connect(addr: SocketAddr) -> (http1::SendRequest<Body>, JoinHandle<()>) {
connect_h1(addr).await
}

async fn connect_h1(addr: SocketAddr) -> (http1::SendRequest<Body>, JoinHandle<()>) {
let stream = TokioIo::new(TcpStream::connect(addr).await.unwrap());
let (send_request, connection) = handshake(stream).await.unwrap();
let (send_request, connection) = client::conn::http1::handshake(stream).await.unwrap();

let task = tokio::spawn(async move {
let _ = connection.await;
Expand All @@ -469,8 +555,36 @@ mod tests {
(send_request, task)
}

async fn connect_h2(addr: SocketAddr) -> (http2::SendRequest<Body>, JoinHandle<()>) {
let stream = TokioIo::new(TcpStream::connect(addr).await.unwrap());
let (send_request, connection) =
client::conn::http2::handshake(TokioExecutor::new(), stream)
.await
.unwrap();

let task = tokio::spawn(async move {
let _ = connection.await;
});

(send_request, task)
}

// Send a basic `GET /` request.
async fn do_empty_request_h1(client: &mut http1::SendRequest<Body>) -> hyper::Result<()> {
client.ready().await?;

let body = client
.send_request(Request::new(Body::empty()))
.await?
.into_body();

let body = body.collect().await?.to_bytes();
assert_eq!(body.as_ref(), b"Hello, world!");
Ok(())
}

// Send a basic `GET /` request.
async fn do_empty_request(client: &mut SendRequest<Body>) -> hyper::Result<()> {
async fn do_empty_request_h2(client: &mut http2::SendRequest<Body>) -> hyper::Result<()> {
client.ready().await?;

let body = client
Expand All @@ -486,15 +600,36 @@ mod tests {
// Send a request with a body streamed byte-by-byte, over a given duration,
// then wait for the full response.
async fn do_slow_request(
client: &mut SendRequest<Body>,
client: &mut http1::SendRequest<Body>,
duration: Duration,
) -> hyper::Result<()> {
let response = send_slow_request(client, duration).await?;
recv_slow_response_body(response).await
}

async fn do_slow_request_h2(
client: &mut http2::SendRequest<Body>,
duration: Duration,
) -> hyper::Result<()> {
let response = send_slow_request_h2(client, duration).await?;
recv_slow_response_body(response).await
}

async fn send_slow_request(
client: &mut SendRequest<Body>,
client: &mut http1::SendRequest<Body>,
duration: Duration,
) -> hyper::Result<http::Response<hyper::body::Incoming>> {
let req_body_len: usize = 10;
let mut req = Request::new(slow_body(req_body_len, duration));
*req.method_mut() = Method::POST;
*req.uri_mut() = Uri::from_static("/echo_slowly");

client.ready().await?;
client.send_request(req).await
}

async fn send_slow_request_h2(
client: &mut http2::SendRequest<Body>,
duration: Duration,
) -> hyper::Result<http::Response<hyper::body::Incoming>> {
let req_body_len: usize = 10;
Expand Down