From 51035883d2db70fed9a7af964c871e9cb0f5390a Mon Sep 17 00:00:00 2001 From: Igor Novgorodov Date: Fri, 19 Apr 2024 21:44:47 +0200 Subject: [PATCH] more work --- Cargo.toml | 2 ++ src/core.rs | 1 + src/http/server.rs | 26 +++++++++++++++++--------- src/tls/cert/mod.rs | 5 +---- src/tls/cert/providers/dir.rs | 11 +++-------- src/tls/cert/storage.rs | 16 +++++++--------- src/tls/mod.rs | 31 +++++++++++++++++++++++-------- src/tls/resolver.rs | 34 ++++++++++++++++++++++++++++++++++ 8 files changed, 88 insertions(+), 38 deletions(-) create mode 100644 src/tls/resolver.rs diff --git a/Cargo.toml b/Cargo.toml index 51bacf7..7e56995 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ candid = "0.10" clap = { version = "4.5", features = ["derive", "string"] } clap_derive = "4.5" ctrlc = { version = "3.4", features = ["termination"] } +derive-new = "0.6" fqdn = "0.3" futures = "0.3" futures-util = "0.3" @@ -53,6 +54,7 @@ reqwest = { git = "https://github.com/blind-oracle/reqwest.git", default_feature "stream", ] } rustls = "0.23" +rustls-acme = "0.9" rustls-pemfile = "2" webpki-roots = "0.26" serde = "1.0" diff --git a/src/core.rs b/src/core.rs index e378582..71b98fb 100644 --- a/src/core.rs +++ b/src/core.rs @@ -21,6 +21,7 @@ pub trait Run: Send + Sync { } async fn handle(request: axum::extract::Request) -> impl IntoResponse { + tokio::time::sleep(std::time::Duration::from_secs(5)).await; warn!("{:?}", request.extensions().get::>()); "Hello" } diff --git a/src/http/server.rs b/src/http/server.rs index 9d1c7c5..2e20425 100644 --- a/src/http/server.rs +++ b/src/http/server.rs @@ -5,7 +5,7 @@ use std::{ }; use crate::core::Run; -use anyhow::{anyhow, Error}; +use anyhow::{anyhow, Context, Error}; use async_trait::async_trait; use axum::{extract::Request, Router}; use hyper::body::Incoming; @@ -15,7 +15,7 @@ use hyper_util::{ }; use rustls::{server::ServerConnection, CipherSuite, ProtocolVersion}; use tokio::{ - io::{AsyncRead, AsyncWrite}, + io::{AsyncRead, AsyncWrite, AsyncWriteExt}, net::{TcpListener, TcpSocket, TcpStream}, select, }; @@ -24,7 +24,7 @@ use tokio_util::{sync::CancellationToken, task::TaskTracker}; use tower_service::Service; use tracing::{debug, warn}; -use crate::{cli, tls}; +use crate::{cli, tls::is_http_alpn}; // Blanket async read+write trait to box streams trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Sync + Unpin {} @@ -134,7 +134,15 @@ impl Conn { // Perform TLS handshake if we're in TLS mode let (stream, tls_info): (Box, _) = if self.tls_acceptor.is_some() { - let (stream, tls_info) = self.tls_handshake(stream).await?; + let (mut stream, tls_info) = self.tls_handshake(stream).await?; + + // Close the connection if agreed ALPN is not HTTP - probably it was an ACME challenge + if !is_http_alpn(tls_info.alpn.as_bytes()) { + debug!("Not HTTP ALPN ('{}') - closing connection", tls_info.alpn); + stream.shutdown().await.context("error in shutdown()")?; + return Ok(()); + } + (Box::new(stream), Some(tls_info)) } else { (Box::new(stream), None) @@ -189,11 +197,6 @@ impl Conn { }, } - debug!( - "Server {}: {}: connection finished", - self.addr, self.remote_addr - ); - Ok(()) } } @@ -295,6 +298,11 @@ impl Run for Server { if let Err(e) = conn.handle(stream).await { warn!("Server {}: {}: failed to handle connection: {e}", conn.addr, remote_addr); } + + debug!( + "Server {}: {}: connection finished", + conn.addr, remote_addr + ); }); } } diff --git a/src/tls/cert/mod.rs b/src/tls/cert/mod.rs index b76413f..0e888e3 100644 --- a/src/tls/cert/mod.rs +++ b/src/tls/cert/mod.rs @@ -117,16 +117,13 @@ pub fn pem_convert_to_rustls(key: &[u8], certs: &[u8]) -> Result } // Collects certificates from providers and stores them in a given storage +#[derive(derive_new::new)] pub struct Aggregator { providers: Vec>, storage: Arc, } impl Aggregator { - pub fn new(providers: Vec>, storage: Arc) -> Self { - Self { providers, storage } - } - // Fetches certificates concurrently from all providers async fn fetch(&self) -> Result, Error> { let certs = join_all( diff --git a/src/tls/cert/providers/dir.rs b/src/tls/cert/providers/dir.rs index 7ccae00..bbc2c7f 100644 --- a/src/tls/cert/providers/dir.rs +++ b/src/tls/cert/providers/dir.rs @@ -6,17 +6,12 @@ use async_trait::async_trait; use tokio::fs::read_dir; use tracing::info; -pub struct Provider { - path: PathBuf, -} - // It searches for .pem files in the given directory and tries to find the // corresponding .key files with the same base name. // After that it loads & parses each pair. -impl Provider { - pub const fn new(path: PathBuf) -> Self { - Self { path } - } +#[derive(derive_new::new)] +pub struct Provider { + path: PathBuf, } #[async_trait] diff --git a/src/tls/cert/storage.rs b/src/tls/cert/storage.rs index 0848ea3..dd5e94f 100644 --- a/src/tls/cert/storage.rs +++ b/src/tls/cert/storage.rs @@ -4,13 +4,10 @@ use anyhow::{anyhow, Error}; use arc_swap::ArcSwapOption; use candid::Principal; use fqdn::FQDN; -use rustls::{ - server::{ClientHello, ResolvesServerCert}, - sign::CertifiedKey, -}; +use rustls::{server::ClientHello, sign::CertifiedKey}; use super::{Cert, LookupCanister}; -use crate::tls; +use crate::tls::{self, resolver}; #[derive(Debug)] struct StorageInner { @@ -58,6 +55,7 @@ impl Storage { let mut canisters = HashMap::new(); for c in cert_list { + // Take note of the canister ID if let Some(v) = c.custom { if canisters.insert(v.name.clone(), v.canister_id).is_some() { return Err(anyhow!("Duplicate name detected: {}", v.name)); @@ -79,11 +77,11 @@ impl Storage { } // Implement certificate resolving for Rustls -impl ResolvesServerCert for StorageKey { - fn resolve(&self, ch: ClientHello) -> Option> { - // Make sure we've got an ALPN list and they're all HTTP, otherwise refuse resolving +impl resolver::ResolvesServerCert for StorageKey { + fn resolve(&self, ch: &ClientHello) -> Option> { + // Make sure we've got an ALPN list and they're all HTTP, otherwise refuse resolving. // This is to make sure we don't answer to e.g. ACME challenges here - if !ch.alpn()?.all(|x| tls::ALPN_HTTP.contains(&x)) { + if !ch.alpn()?.all(tls::is_http_alpn) { return None; } diff --git a/src/tls/mod.rs b/src/tls/mod.rs index 1a08e61..cfdd505 100644 --- a/src/tls/mod.rs +++ b/src/tls/mod.rs @@ -1,4 +1,5 @@ mod cert; +pub mod resolver; mod test; use std::sync::Arc; @@ -10,19 +11,27 @@ use rustls::{ version::{TLS12, TLS13}, RootCertStore, }; +use rustls_acme::acme::ACME_TLS_ALPN_NAME; use crate::{ cli::Cli, core::Run, http, - tls::cert::{providers, storage::Storage, Aggregator}, + tls::{ + cert::{providers, storage::Storage, Aggregator}, + resolver::AggregatingResolver, + }, }; use cert::providers::ProvidesCertificates; -pub const ALPN_H1: &[u8] = b"http/1.1"; -pub const ALPN_H2: &[u8] = b"h2"; -pub const ALPN_HTTP: &[&[u8]] = &[ALPN_H1, ALPN_H2]; +const ALPN_H1: &[u8] = b"http/1.1"; +const ALPN_H2: &[u8] = b"h2"; +const ALPN_HTTP: &[&[u8]] = &[ALPN_H1, ALPN_H2]; + +pub fn is_http_alpn(alpn: &[u8]) -> bool { + ALPN_HTTP.contains(&alpn) +} pub fn prepare_server_config(resolver: Arc) -> ServerConfig { let mut cfg = ServerConfig::builder_with_protocol_versions(&[&TLS13, &TLS12]) @@ -31,7 +40,12 @@ pub fn prepare_server_config(resolver: Arc) -> ServerCon // Create custom session storage with higher limit to allow effective TLS session resumption cfg.session_storage = ServerSessionMemoryCache::new(131_072); - cfg.alpn_protocols = vec![ALPN_H2.to_vec(), ALPN_H1.to_vec()]; + cfg.alpn_protocols = vec![ + ALPN_H2.to_vec(), + ALPN_H1.to_vec(), + // Support ACME challenge ALPN too + ACME_TLS_ALPN_NAME.to_vec(), + ]; cfg } @@ -79,8 +93,9 @@ pub fn setup( } let storage = Arc::new(Storage::new()); - let aggregator = Arc::new(Aggregator::new(providers, storage.clone())); - let config = prepare_server_config(storage); + let cert_aggregator = Arc::new(Aggregator::new(providers, storage.clone())); + let resolve_aggregator = Arc::new(AggregatingResolver::new(None, vec![storage])); + let config = prepare_server_config(resolve_aggregator); - Ok((aggregator, config)) + Ok((cert_aggregator, config)) } diff --git a/src/tls/resolver.rs b/src/tls/resolver.rs new file mode 100644 index 0000000..4aa9c36 --- /dev/null +++ b/src/tls/resolver.rs @@ -0,0 +1,34 @@ +use std::{fmt::Debug, sync::Arc}; + +use rustls::{ + server::{ClientHello, ResolvesServerCert as ResolvesServerCertRustls}, + sign::CertifiedKey, +}; + +// Custom ResolvesServerCert trait that takes ClientHello by reference. +// It's needed because Rustls' ResolvesServerCert consumes ClientHello +pub trait ResolvesServerCert: Debug + Send + Sync { + fn resolve(&self, client_hello: &ClientHello) -> Option>; +} + +// Combines several certificate resolvers into one +// Only one Rustls-compatible resolver can be used (acme) since it takes ClientHello by value +#[derive(Debug, derive_new::new)] +pub struct AggregatingResolver { + acme: Option>, + resolvers: Vec>, +} + +// Implement certificate resolving for Rustls +impl ResolvesServerCertRustls for AggregatingResolver { + fn resolve(&self, ch: ClientHello) -> Option> { + // Iterate over our resolvers to find matching cert if any + let cert = self.resolvers.iter().find_map(|x| x.resolve(&ch)); + if let Some(v) = cert { + return Some(v); + } + + // Otherwise try the ACME resolver with Rustls trait that consumes ClientHello + self.acme.as_ref().and_then(|x| x.resolve(ch)) + } +}