Skip to content

Commit

Permalink
more TLS work
Browse files Browse the repository at this point in the history
  • Loading branch information
blind-oracle committed Apr 18, 2024
1 parent 363666d commit 3505182
Show file tree
Hide file tree
Showing 10 changed files with 279 additions and 92 deletions.
6 changes: 3 additions & 3 deletions src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,17 @@ pub struct HttpServer {
pub http2_max_streams: u32,

/// How long to wait for the existing connections to finish before shutting down
#[clap(long = "http-server-grace-period", default_value = "30s", value_parser = parse_duration)]
#[clap(long = "http-server-grace-period", default_value = "10s", value_parser = parse_duration)]
pub grace_period: Duration,
}

#[derive(Args)]
pub struct Cert {
/// Read certificates from given directories, each certificate should be a pair .pem + .key files with the same base name
#[clap(long = "cert-dir")]
#[clap(long = "cert-provider-dir")]
pub dir: Vec<PathBuf>,

/// Request certificates from the 'certificate-issuer' instances reachable over given URLs
#[clap(long = "cert-syncer-url")]
#[clap(long = "cert-provider-syncer-url")]
pub syncer_urls: Vec<Url>,
}
10 changes: 8 additions & 2 deletions src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,13 @@ pub async fn main(cli: Cli) -> Result<(), Error> {
let handler_token = token.clone();
ctrlc::set_handler(move || handler_token.cancel())?;

let router = axum::Router::new().route("/", axum::routing::get(|| async { "Hello, World!" }));
let router = axum::Router::new().route(
"/",
axum::routing::get(|| async {
tokio::time::sleep(std::time::Duration::from_secs(15)).await;
"Hello, World!"
}),
);

let mut runners: Vec<(String, Arc<dyn Run>)> = vec![];

Expand All @@ -45,7 +51,7 @@ pub async fn main(cli: Cli) -> Result<(), Error> {
runners.push(("http_server".into(), http_server));

// Set up HTTPS
let (aggregator, rustls_cfg) = tls::setup(&cli, http_client.clone());
let (aggregator, rustls_cfg) = tls::setup(&cli, http_client.clone())?;
runners.push(("aggregator".into(), aggregator));

let https_server = Arc::new(Server::new(
Expand Down
4 changes: 2 additions & 2 deletions src/http/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;
use async_trait::async_trait;
use mockall::automock;

use crate::{cli, core::SERVICE_NAME, http::dns::Resolver, tls::prepare_rustls_client_config};
use crate::{cli, core::SERVICE_NAME, http::dns::Resolver, tls::prepare_client_config};

#[automock]
#[async_trait]
Expand All @@ -19,7 +19,7 @@ impl ReqwestClient {
let http = &cli.http_client;

let client = reqwest::Client::builder()
.use_preconfigured_tls(prepare_rustls_client_config())
.use_preconfigured_tls(prepare_client_config())
.dns_resolver(Arc::new(Resolver::new(&cli.dns)))
.connect_timeout(http.timeout_connect)
.timeout(http.timeout)
Expand Down
156 changes: 118 additions & 38 deletions src/http/server.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
use std::{net::SocketAddr, sync::Arc, time::Duration};
use std::{
net::SocketAddr,
sync::Arc,
time::{Duration, Instant},
};

use crate::core::Run;
use anyhow::{anyhow, Error};
use async_trait::async_trait;
use axum::{extract::Request, Router};
use futures_util::pin_mut;
use hyper::body::Incoming;
use hyper_util::{
rt::{TokioExecutor, TokioIo},
rt::{TokioExecutor, TokioIo, TokioTimer},
server::conn::auto::Builder,
};
use rustls::{server::ServerConnection, CipherSuite, ProtocolVersion};
use tokio::{
io::{AsyncRead, AsyncWrite},
net::{TcpListener, TcpSocket, TcpStream},
Expand All @@ -19,17 +24,43 @@ use tokio_util::{sync::CancellationToken, task::TaskTracker};
use tower_service::Service;
use tracing::{debug, warn};

use crate::core::Run;

// Blanket async read+write trait to box streams
trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Sync + Unpin {}
impl<T: AsyncRead + AsyncWrite + Send + Sync + Unpin> AsyncReadWrite for T {}

// TLS information about the connection
// To be injected into the request as an extension
#[derive(Clone)]
struct TlsInfo {
sni: String,
alpn: String,
protocol: ProtocolVersion,
cipher: CipherSuite,
}

impl From<&ServerConnection> for TlsInfo {
fn from(c: &ServerConnection) -> Self {
Self {
sni: c.server_name().unwrap_or("unknown").into(),
alpn: c
.alpn_protocol()
.map_or("unknown".into(), |x| String::from_utf8_lossy(x).to_string()),
protocol: c.protocol_version().unwrap_or(ProtocolVersion::Unknown(0)),
cipher: c
.negotiated_cipher_suite()
// Some default cipher, it should never be None in fact, but just in case we don't use unwrap()
.map_or(rustls::CipherSuite::TLS13_AES_128_CCM_SHA256, |x| x.suite()),
}
}
}

struct Conn {
addr: SocketAddr,
remote_addr: SocketAddr,
router: Router,
builder: Builder<TokioExecutor>,
token: CancellationToken,
grace_period: Duration,
tls_acceptor: Option<TlsAcceptor>,
}

Expand All @@ -44,50 +75,77 @@ impl Conn {
stream.set_nodelay(true)?;

// Perform TLS handshake if we're in TLS mode
let stream: Box<dyn AsyncReadWrite> = if let Some(v) = &self.tls_acceptor {
debug!("{}: performing TLS handshake", self.remote_addr);
Box::new(v.accept(stream).await?)
let (stream, tls_info): (Box<dyn AsyncReadWrite>, _) = if let Some(v) = &self.tls_acceptor {
debug!(
"Server {}: {}: performing TLS handshake",
self.addr, self.remote_addr
);

let start = Instant::now();
let stream = v.accept(stream).await?;
let latency = start.elapsed();

let conn = stream.get_ref().1;
let tls_info = TlsInfo::from(conn);

debug!(
"Server {}: {}: handshake finished in {}ms (server: {}, proto: {:?}, cipher: {:?}, ALPN: {})",
self.addr,
self.remote_addr,
latency.as_millis(),
tls_info.sni,
tls_info.protocol,
tls_info.cipher,
tls_info.alpn,
);

(Box::new(stream), Some(tls_info))
} else {
Box::new(stream)
(Box::new(stream), None)
};

// Convert stream from Tokio to Hyper
let stream = TokioIo::new(stream);

// Convert router to Hyper service
let service = hyper::service::service_fn(move |request: Request<Incoming>| {
let service = hyper::service::service_fn(move |mut request: Request<Incoming>| {
// Inject TLS information if it's a TLS session
// TODO avoid cloning due to Fn() somehow?
if let Some(v) = tls_info.clone() {
request.extensions_mut().insert(v);
}

self.router.clone().call(request)
});

// Call the service
let mut builder = Builder::new(TokioExecutor::new());
let conn = self.builder.serve_connection(stream, service);
// Using mutable future reference requires pinning, otherwise .await consumes it
tokio::pin!(conn);

// Some sensible defaults
// TODO make configurable?
builder
.http2()
.adaptive_window(true)
.max_concurrent_streams(Some(100))
.keep_alive_interval(Some(Duration::from_secs(20)))
.keep_alive_timeout(Duration::from_secs(10));
select! {
biased; // Poll top-down

let conn = builder.serve_connection(stream, service);
pin_mut!(conn);

loop {
select! {
v = conn.as_mut() => {
if let Err(e) = v {
return Err(anyhow!("Unable to serve connection: {e}"));
}
() = self.token.cancelled() => {
// Start graceful shutdown of the connection
// For H2: sends GOAWAY frames to the client
// For H1: disables keepalives
conn.as_mut().graceful_shutdown();

break;
},

() = self.token.cancelled() => {
conn.as_mut().graceful_shutdown();
// Wait for the grace period to finish or connection to complete.
// Connection must still be polled for shutdown to proceed.
select! {
biased;
() = tokio::time::sleep(self.grace_period) => {},
_ = conn.as_mut() => {},
}
}

v = conn.as_mut() => {
if let Err(e) = v {
return Err(anyhow!("Unable to serve connection: {e}"));
}
},
}

debug!(
Expand Down Expand Up @@ -132,7 +190,19 @@ impl Server {
impl Run for Server {
async fn run(&self, token: CancellationToken) -> Result<(), Error> {
let listener = listen_tcp_backlog(self.addr, self.backlog)?;
pin_mut!(listener);

// Setup Hyper connection builder with some defaults
// TODO make configurable?
let mut builder = Builder::new(TokioExecutor::new());
builder
.http1()
.keep_alive(true)
.http2()
.adaptive_window(true)
.max_concurrent_streams(Some(100))
.timer(TokioTimer::new()) // Needed for the keepalives below
.keep_alive_interval(Some(Duration::from_secs(20)))
.keep_alive_timeout(Duration::from_secs(10));

warn!(
"Server {}: running (TLS: {})",
Expand All @@ -142,13 +212,20 @@ impl Run for Server {

loop {
select! {
biased; // Poll top-down

() = token.cancelled() => {
// Stop accepting new connections
drop(listener);

warn!("Server {}: shutting down, waiting for the active connections to close for {}s", self.addr, self.grace_period.as_secs());
self.tracker.close();
select! {
() = self.tracker.wait() => {},
() = tokio::time::sleep(self.grace_period) => {},
}
self.tracker.wait().await;

// select! {
// () = self.tracker.wait() => {},
// () = tokio::time::sleep(self.grace_period) => {},
// }
warn!("Server {}: shut down", self.addr);
return Ok(());
},
Expand All @@ -165,11 +242,14 @@ impl Run for Server {

// Create a new connection
// Router & TlsAcceptor are both Arc<> inside so it's cheap to clone
// Builder is a bit more complex, but cloning is better than to create it again
let conn = Conn {
addr: self.addr,
remote_addr,
router: self.router.clone(),
builder: builder.clone(),
token: token.child_token(),
grace_period: self.grace_period,
tls_acceptor: self.tls_acceptor.clone(),
};

Expand Down
23 changes: 16 additions & 7 deletions src/tls/cert/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,38 @@ use std::{

use anyhow::{anyhow, Context, Error};
use async_trait::async_trait;
use candid::Principal;
use futures::future::join_all;
use rustls::{crypto::aws_lc_rs, sign::CertifiedKey};
use tokio::select;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, warn};
use x509_parser::prelude::*;

use crate::{core::Run, tls::cert::storage::StorageKey};
use crate::core::Run;
use providers::ProvidesCertificates;
use storage::StorageKey;

#[derive(Clone, Debug)]
pub struct CustomDomain {
name: String,
canister_id: Principal,
}

// Generic certificate and a list of its SANs
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct Cert<T: Clone> {
san: Vec<String>,
cert: T,
pub custom: Option<CustomDomain>,
}

// Commonly used concrete type of the above for Rustls
pub type CertKey = Cert<Arc<CertifiedKey>>;

// Trait that the certificate providers should implement
// It should return a vector of Rustls-compatible keys
#[async_trait]
pub trait ProvidesCertificates: Sync + Send {
async fn get_certificates(&self) -> Result<Vec<CertKey>, Error>;
// Looks up custom domain canister id by hostname
pub trait LookupCanister: Sync + Send {
fn lookup_canister(&self, hostname: &str) -> Option<Principal>;
}

// Extracts a list of SubjectAlternativeName from a single certificate, formatted as strings.
Expand Down Expand Up @@ -104,6 +112,7 @@ pub fn pem_convert_to_rustls(key: &[u8], certs: &[u8]) -> Result<CertKey, Error>
Ok(Cert {
san,
cert: Arc::new(CertifiedKey::new(certs, key)),
custom: None,
})
}

Expand Down
2 changes: 1 addition & 1 deletion src/tls/cert/providers/dir.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::path::PathBuf;

use crate::tls::cert::{pem_convert_to_rustls, CertKey, ProvidesCertificates};
use crate::tls::cert::{pem_convert_to_rustls, providers::ProvidesCertificates, CertKey};
use anyhow::{Context, Error};
use async_trait::async_trait;
use tokio::fs::read_dir;
Expand Down
11 changes: 11 additions & 0 deletions src/tls/cert/providers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,14 @@ pub mod syncer;

pub use dir::Provider as Dir;
pub use syncer::CertificatesImporter as Syncer;

use async_trait::async_trait;

use super::CertKey;

// Trait that the certificate providers should implement
// It should return a vector of Rustls-compatible keys
#[async_trait]
pub trait ProvidesCertificates: Sync + Send {
async fn get_certificates(&self) -> Result<Vec<CertKey>, anyhow::Error>;
}
Loading

0 comments on commit 3505182

Please sign in to comment.