feat(foundations): add keep-alive timeout
TroyKomodo committed Aug 28, 2024
1 parent 5cbe01c commit 784a69a
Showing 8 changed files with 182 additions and 431 deletions.
11 changes: 11 additions & 0 deletions foundations/src/http/server/builder.rs
@@ -1,5 +1,6 @@
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;

use hyper_util::rt::TokioExecutor;

@@ -25,6 +26,7 @@ pub struct ServerBuilder {
http1_2: hyper_util::server::conn::auto::Builder<TokioExecutor>,
#[cfg(feature = "http3")]
quic: Option<super::Quic>,
keep_alive_timeout: Option<std::time::Duration>,
worker_count: usize,
}

@@ -39,6 +41,7 @@ impl Default for ServerBuilder {
http1_2: hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()),
#[cfg(feature = "http3")]
quic: None,
keep_alive_timeout: Some(Duration::from_secs(30)),
worker_count: 1,
}
}
@@ -91,6 +94,13 @@ impl ServerBuilder {
self
}

/// Set the keep-alive timeout for the server.
/// Defaults to 30 seconds; pass None to disable the timeout.
pub fn with_keep_alive_timeout(mut self, timeout: impl Into<Option<std::time::Duration>>) -> Self {
self.keep_alive_timeout = timeout.into();
self
}

/// Build the server.
pub fn build<M>(self, make_service: M) -> Result<Server<M>, Error>
where
@@ -124,6 +134,7 @@ impl ServerBuilder {
backends: Vec::new(),
handler: None,
worker_count: self.worker_count,
keep_alive_timeout: self.keep_alive_timeout,
})
}
}
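
For context, here is a minimal usage sketch of the new builder option (not part of the diff). It assumes the server is built through ServerBuilder::default() and that make_service is a placeholder for whatever MakeService implementation the application supplies.

use std::time::Duration;

// Hypothetical wiring, for illustration only: make_service stands in for
// the application's MakeService implementation.
let server = ServerBuilder::default()
    // Override the 30-second default; passing None disables the idle timeout.
    .with_keep_alive_timeout(Duration::from_secs(60))
    .build(make_service)?;
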
10 changes: 7 additions & 3 deletions foundations/src/http/server/mod.rs
@@ -48,6 +48,7 @@ pub struct Server<M> {
backends: Vec<AbortOnDrop>,
handler: Option<crate::context::Handler>,
worker_count: usize,
keep_alive_timeout: Option<std::time::Duration>,
}

#[derive(Debug, thiserror::Error)]
@@ -152,7 +153,8 @@ impl<M: MakeService> Server<M> {
for i in 0..self.worker_count {
let tcp_listener = make_tcp_listener(self.bind)?;
let make_service = self.make_service.clone();
let backend = TlsBackend::new(tcp_listener, acceptor.clone(), self.http1_2.clone(), &ctx);
let backend = TlsBackend::new(tcp_listener, acceptor.clone(), self.http1_2.clone(), &ctx)
.with_keep_alive_timeout(self.keep_alive_timeout);
let span = tracing::info_span!("tls", addr = %self.bind, worker = i);
self.backends
.push(AbortOnDrop::new(spawn(backend.serve(make_service).instrument(span))));
@@ -170,7 +172,8 @@ impl<M: MakeService> Server<M> {
for i in 0..self.worker_count {
let tcp_listener = make_tcp_listener(addr)?;
let make_service = self.make_service.clone();
let backend = TcpBackend::new(tcp_listener, self.http1_2.clone(), &ctx);
let backend = TcpBackend::new(tcp_listener, self.http1_2.clone(), &ctx)
.with_keep_alive_timeout(self.keep_alive_timeout);
let span = tracing::info_span!("tcp", addr = %addr, worker = i);
self.backends
.push(AbortOnDrop::new(spawn(backend.serve(make_service).instrument(span))));
@@ -188,7 +191,8 @@ impl<M: MakeService> Server<M> {
quinn::default_runtime().unwrap(),
)?;
let make_service = self.make_service.clone();
let backend = QuicBackend::new(endpoint, quic.h3.clone(), &ctx);
let backend =
QuicBackend::new(endpoint, quic.h3.clone(), &ctx).with_keep_alive_timeout(self.keep_alive_timeout);
let span = tracing::info_span!("quic", addr = %self.bind, worker = i);
self.backends
.push(AbortOnDrop::new(spawn(backend.serve(make_service).instrument(span))));
23 changes: 23 additions & 0 deletions foundations/src/http/server/stream/mod.rs
@@ -5,10 +5,12 @@ pub mod tcp;
pub mod tls;

use std::convert::Infallible;
use std::sync::Arc;

pub use axum::body::Body;
pub use axum::extract::Request;
pub use axum::response::{IntoResponse, Response};
use rand::Rng;

use super::Error;

@@ -152,3 +154,24 @@ pub enum SocketKind {
TlsTcp,
Quic,
}

fn jitter(duration: std::time::Duration) -> std::time::Duration {
let mut rng = rand::thread_rng();
let jitter = rng.gen_range(0..duration.as_millis() / 10);
duration + std::time::Duration::from_millis(jitter as u64)
}

struct ActiveRequestsGuard(Arc<std::sync::atomic::AtomicUsize>);

impl Drop for ActiveRequestsGuard {
fn drop(&mut self) {
self.0.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
}
}

impl ActiveRequestsGuard {
fn new(active_requests: Arc<std::sync::atomic::AtomicUsize>) -> Self {
active_requests.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
Self(active_requests)
}
}
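
The two helpers added above are used by the stream backends: jitter() stretches a duration by a random extra amount of up to 10% of its length, so many idle connections do not all time out on the same tick, and ActiveRequestsGuard is an RAII counter that decrements the shared atomic when a request finishes. A rough sketch of how they compose, mirroring the tokio::select! arm added to the QUIC backend below (the function name is illustrative and assumes jitter() from this module is in scope):

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;

// Sleep in jittered keep-alive windows; return once a check finds no
// requests in flight. The QUIC backend races this against accept().
async fn wait_until_idle(keep_alive_timeout: Duration, active_requests: Arc<AtomicUsize>) {
    loop {
        tokio::time::sleep(jitter(keep_alive_timeout)).await;
        if active_requests.load(Ordering::Relaxed) == 0 {
            break;
        }
    }
}
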
49 changes: 46 additions & 3 deletions foundations/src/http/server/stream/quic.rs
@@ -20,6 +20,7 @@ use tracing::Instrument;

use super::{Backend, IncomingConnection, MakeService, ServiceHandler, SocketKind};
use crate::context::ContextFutExt;
use crate::http::server::stream::{jitter, ActiveRequestsGuard};
use crate::http::server::Error;
#[cfg(feature = "runtime")]
use crate::runtime::spawn;
@@ -30,6 +31,7 @@ pub struct QuicBackend {
endpoint: quinn::Endpoint,
builder: Arc<Builder>,
handler: crate::context::Handler,
keep_alive_timeout: Option<std::time::Duration>,
}

impl QuicBackend {
@@ -38,8 +40,14 @@ impl QuicBackend {
endpoint,
builder,
handler: ctx.new_child().1,
keep_alive_timeout: None,
}
}

pub fn with_keep_alive_timeout(mut self, timeout: impl Into<Option<std::time::Duration>>) -> Self {
self.keep_alive_timeout = timeout.into();
self
}
}

struct IncomingQuicConnection<'a> {
@@ -82,7 +90,13 @@ impl Backend for QuicBackend {
break;
};

let connection = connection.accept()?;
let connection = match connection.accept() {
Ok(connection) => connection,
Err(e) => {
tracing::debug!(error = %e, "failed to accept quic connection");
continue;
}
};

let span = tracing::trace_span!("connection", remote_addr = %connection.remote_address());
let _guard = span.enter();
@@ -106,6 +120,7 @@ impl Backend for QuicBackend {
connection,
builder: self.builder.clone(),
service,
keep_alive_timeout: self.keep_alive_timeout,
parent_ctx: ctx,
}
.serve()
@@ -125,6 +140,7 @@ struct Connection<S: ServiceHandler> {
connection: Connecting,
builder: Arc<Builder>,
service: S,
keep_alive_timeout: Option<std::time::Duration>,
parent_ctx: crate::context::Context,
}

@@ -138,7 +154,10 @@ impl<S: ServiceHandler> Connection<S> {
self.service.on_close().await;
return;
}
None => return,
None => {
self.service.on_close().await;
return;
}
};

let mut connection = match self
@@ -153,7 +172,10 @@ impl<S: ServiceHandler> Connection<S> {
self.service.on_close().await;
return;
}
None => return,
None => {
self.service.on_close().await;
return;
}
};

let (hijack_conn_tx, mut hijack_conn_rx) = tokio::sync::mpsc::channel::<SendQuicConnection>(1);
@@ -170,6 +192,8 @@ impl<S: ServiceHandler> Connection<S> {
// When the above is cancelled, the connection is allowed to finish.
let connection_handle = crate::context::Handler::new();

let active_requests = Arc::new(std::sync::atomic::AtomicUsize::new(0));

loop {
let (request, stream) = tokio::select! {
request = connection.accept() => {
@@ -189,6 +213,23 @@ impl Backend for QuicBackend {
break;
}
}
},
Some(_) = async {
if let Some(keep_alive_timeout) = self.keep_alive_timeout {
loop {
tokio::time::sleep(jitter(keep_alive_timeout)).await;
if active_requests.load(std::sync::atomic::Ordering::Relaxed) != 0 {
continue;
}

break Some(());
}
} else {
None
}
} => {
tracing::debug!("keep alive timeout");
break;
}
// This happens when the connection has been upgraded to a WebTransport connection.
Some(send_hijack_conn) = hijack_conn_rx.recv() => {
@@ -201,6 +242,7 @@ impl<S: ServiceHandler> Connection<S> {
};

tracing::trace!("new request");
let active_requests = ActiveRequestsGuard::new(active_requests.clone());

let service = self.service.clone();
let stream = QuinnStream::new(stream);
@@ -226,6 +268,7 @@ impl<S: ServiceHandler> Connection<S> {
service.on_error(err).await;
}

drop(active_requests);
drop(ctx);
}
.with_context(connection_context)
