From da17b1e447cb548a14c00d6e3476deba2516068b Mon Sep 17 00:00:00 2001 From: Igor Novgorodov Date: Tue, 16 Apr 2024 16:30:31 +0200 Subject: [PATCH] work continues --- Cargo.toml | 13 ++- src/cli.rs | 44 +++++++--- src/core.rs | 39 ++++++++- src/http/client.rs | 45 +++++++++++ src/http/dns.rs | 89 ++++++++------------ src/http/mod.rs | 49 +---------- src/http/server.rs | 158 ++++++++++++++++++++++++++++++++++++ src/main.rs | 6 ++ src/tls/cert/dir.rs | 85 +++++++++++++++++++ src/tls/cert/mod.rs | 162 ++++++++++++++++++++++++++++++++++--- src/tls/cert/syncer/mod.rs | 15 ++-- src/tls/cert/test.rs | 0 src/tls/mod.rs | 7 +- src/tls/resolver.rs | 23 +++--- 14 files changed, 584 insertions(+), 151 deletions(-) create mode 100644 src/http/client.rs create mode 100644 src/http/server.rs create mode 100644 src/tls/cert/dir.rs delete mode 100644 src/tls/cert/test.rs diff --git a/Cargo.toml b/Cargo.toml index 8de9e43..56aca17 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,15 +10,19 @@ arc-swap = "1" async-scoped = { version = "0.8", features = ["use-tokio"] } async-trait = "0.1" axum = "0.7" +axum-server = { version = "0.6", features = ["tls-rustls"] } candid = "0.10" clap = { version = "4.5", features = ["derive", "string"] } clap_derive = "4.5" +ctrlc = { version = "3.4", features = [ "termination" ] } fqdn = "0.3" futures = "0.3" -hickory-resolver = {version = "0.24", features = ["dns-over-rustls", "dns-over-https-rustls", "dnssec-ring"] } +futures-util = "0.3" +hickory-resolver = { version = "0.24", features = ["dns-over-https-rustls", "webpki-roots", "dnssec-ring"] } http = "1.1" humantime = "2.1" hyper = "1.2" +hyper-util = "0.1" jemallocator = "0.5" jemalloc-ctl = "0.5" lazy_static = "1.4" @@ -48,9 +52,10 @@ webpki-roots = "0.26" serde = "1.0" serde_json = "1.0" strum = "0.26" +strum_macros = "0.26" thiserror = "1.0" -tokio = "1.36" -tokio-util = "0.7" +tokio = { version = "1.36", features = [ "full" ] } +tokio-util = {version = "0.7", features = [ "full" ] } tokio-rustls = "0.26" tower = "0.4" tower_governor = "0.3" @@ -59,6 +64,7 @@ tower-http = { version = "0.5", features = [ "request-id", "compression-full", ] } +tower-service = "0.3" tracing = "0.1" tracing-core = "0.1" tracing-serde = "0.1" @@ -68,3 +74,4 @@ x509-parser = "0.16" [dev-dependencies] criterion = { version = "0.5", features = ["async_tokio"] } +tempfile = "3.10" diff --git a/src/cli.rs b/src/cli.rs index c3a8f77..e9c8305 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -1,8 +1,10 @@ -use std::{net::IpAddr, time::Duration}; +use std::{ + net::{IpAddr, SocketAddr}, + time::Duration, +}; -use anyhow::{anyhow, Error}; use clap::{Args, Parser}; -use hickory_resolver::config::{Protocol, CLOUDFLARE_IPS}; +use hickory_resolver::config::CLOUDFLARE_IPS; use humantime::parse_duration; use crate::{ @@ -10,9 +12,6 @@ use crate::{ http::dns, }; -// Clap does not support prefixes due to macro limitations -// so we have to add prefixes manually (long = "...") - #[derive(Parser)] #[clap(name = SERVICE_NAME)] #[clap(author = AUTHOR_NAME)] @@ -20,10 +19,16 @@ pub struct Cli { #[command(flatten, next_help_heading = "HTTP Client")] pub http_client: HttpClient, - #[command(flatten, next_help_heading = "DNS")] + #[command(flatten, next_help_heading = "DNS Resolver")] pub dns: Dns, + + #[command(flatten, next_help_heading = "HTTP Server")] + pub http_server: HttpServer, } +// Clap does not support prefixes due to macro limitations +// so we have to add them manually (long = "...") + #[derive(Args)] pub struct HttpClient { /// Timeout for HTTP connection phase @@ -47,17 +52,36 @@ pub struct HttpClient { pub http2_keepalive_timeout: Duration, } -#[derive(Args, Clone, Debug)] +#[derive(Args)] pub struct Dns { /// List of DNS servers to use #[clap(long = "dns-servers", default_values_t = CLOUDFLARE_IPS)] pub servers: Vec, - /// DNS protocol to use (udp/tcp/tls/https) + /// DNS protocol to use (clear/tls/https) #[clap(long = "dns-protocol", default_value = "tls")] pub protocol: dns::Protocol, - /// TLS name for DNS-over-TLS and DNS-over-HTTPS protocols + /// TLS name to expect for TLS and HTTPS protocols (e.g. "dns.google" or "cloudflare-dns.com") #[clap(long = "dns-tls-name", default_value = "cloudflare-dns.com")] pub tls_name: String, + + /// Cache size for the resolver (in number of DNS records) + #[clap(long = "dns-cache-size", default_value = "2048")] + pub cache_size: usize, +} + +#[derive(Args)] +pub struct HttpServer { + /// Where to listen for HTTP + #[clap(long = "listen-http", default_value = "[::1]:8080")] + pub http: SocketAddr, + + /// Where to listen for HTTPS + #[clap(long = "listen-https", default_value = "[::1]:8443")] + pub https: SocketAddr, + + /// Backlog of incoming connections to set on the listening socket. + #[clap(long, default_value = "8192")] + pub backlog: u32, } diff --git a/src/core.rs b/src/core.rs index 39442fd..5e52a58 100644 --- a/src/core.rs +++ b/src/core.rs @@ -1,10 +1,45 @@ use anyhow::Error; +use tokio_util::{sync::CancellationToken, task::TaskTracker}; +use tracing::{error, warn}; -use crate::cli::Cli; +use crate::{cli::Cli, http::server::Server}; pub const SERVICE_NAME: &str = "ic_gateway"; pub const AUTHOR_NAME: &str = "Boundary Node Team "; -pub async fn main(_cli: Cli) -> Result<(), Error> { +pub async fn main(cli: Cli) -> Result<(), Error> { + let token = CancellationToken::new(); + let tracker = TaskTracker::new(); + + // Handle SIGTERM/SIGHUP and Ctrl+C + // Cancelling a token cancels all of its clones too, except the ones from .child_token() + let handler_token = token.clone(); + ctrlc::set_handler(move || handler_token.cancel())?; + + let router = axum::Router::new().route("/", axum::routing::get(|| async { "Hello, World!" })); + + let http_server = Server::new( + cli.http_server.http, + cli.http_server.backlog, + router, + token.child_token(), + None, + ); + + tracker.spawn(async move { + match http_server.start().await { + Ok(()) => {} + Err(e) => { + error!("Unable to start server: {e}"); + } + }; + }); + + warn!("Service is running, waiting for the shutdown signal"); + token.cancelled().await; + warn!("Shutdown signal received, cleaning up"); + tracker.close(); + tracker.wait().await; + Ok(()) } diff --git a/src/http/client.rs b/src/http/client.rs new file mode 100644 index 0000000..f0cc359 --- /dev/null +++ b/src/http/client.rs @@ -0,0 +1,45 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use mockall::automock; + +use crate::{cli, core::SERVICE_NAME, http::dns::Resolver, tls::prepare_rustls_client_config}; + +#[automock] +#[async_trait] +pub trait Client: Send + Sync { + async fn execute(&self, req: reqwest::Request) -> Result; +} + +pub struct ReqwestClient(reqwest::Client); + +impl ReqwestClient { + pub fn new(cli: &cli::Cli) -> Result { + let http = &cli.http_client; + + let client = reqwest::Client::builder() + .use_preconfigured_tls(prepare_rustls_client_config()) + .dns_resolver(Arc::new(Resolver::new(&cli.dns))) + .connect_timeout(http.timeout_connect) + .timeout(http.timeout) + .tcp_nodelay(true) + .tcp_keepalive(Some(http.tcp_keepalive)) + .http2_keep_alive_interval(Some(http.http2_keepalive)) + .http2_keep_alive_timeout(http.http2_keepalive_timeout) + .http2_keep_alive_while_idle(true) + .http2_adaptive_window(true) + .user_agent(SERVICE_NAME) + .redirect(reqwest::redirect::Policy::none()) + .no_proxy() + .build()?; + + Ok(Self(client)) + } +} + +#[async_trait] +impl Client for ReqwestClient { + async fn execute(&self, req: reqwest::Request) -> Result { + self.0.execute(req).await + } +} diff --git a/src/http/dns.rs b/src/http/dns.rs index 734741c..82429e6 100644 --- a/src/http/dns.rs +++ b/src/http/dns.rs @@ -1,57 +1,53 @@ -use std::{net::SocketAddr, str::FromStr, sync::Arc}; +use std::{net::SocketAddr, sync::Arc}; -use anyhow::{anyhow, Error}; use hickory_resolver::{ config::{NameServerConfigGroup, ResolverConfig, ResolverOpts}, lookup_ip::LookupIpIntoIter, TokioAsyncResolver, }; -use once_cell::sync::OnceCell; use reqwest::dns::{Addrs, Name, Resolve, Resolving}; +use strum_macros::EnumString; use crate::cli::Dns; -#[derive(Clone, Debug)] +#[derive(Clone, Debug, EnumString)] +#[strum(serialize_all = "snake_case")] pub enum Protocol { Clear, Tls, Https, } -impl FromStr for Protocol { - type Err = Error; +#[derive(Debug, Clone)] +pub struct Resolver(Arc); - fn from_str(s: &str) -> Result { - Ok(match s { - "clear" => Protocol::Clear, - "tls" => Protocol::Tls, - "https" => Protocol::Https, - _ => return Err(anyhow!("Unknown DNS protocol")), - }) +// new() must be called in Tokio context +impl Resolver { + pub fn new(cli: &Dns) -> Self { + let name_servers = match cli.protocol { + Protocol::Clear => NameServerConfigGroup::from_ips_clear(&cli.servers, 53, true), + Protocol::Tls => { + NameServerConfigGroup::from_ips_tls(&cli.servers, 853, cli.tls_name.clone(), true) + } + Protocol::Https => { + NameServerConfigGroup::from_ips_https(&cli.servers, 443, cli.tls_name.clone(), true) + } + }; + + let cfg = ResolverConfig::from_parts(None, vec![], name_servers); + + let mut opts = ResolverOpts::default(); + opts.rotate = true; + opts.cache_size = cli.cache_size; + opts.use_hosts_file = false; + opts.preserve_intermediates = false; + opts.try_tcp_on_error = true; + + let resolver = TokioAsyncResolver::tokio(cfg, opts); + Self(Arc::new(resolver)) } } -pub fn prepare_dns_resolver(cli: Dns) -> TokioAsyncResolver { - let name_servers = match cli.protocol { - Protocol::Clear => NameServerConfigGroup::from_ips_clear(&cli.servers, 53, true), - Protocol::Tls => NameServerConfigGroup::from_ips_tls(&cli.servers, 853, cli.tls_name, true), - Protocol::Https => { - NameServerConfigGroup::from_ips_https(&cli.servers, 443, cli.tls_name, true) - } - }; - - let cfg = ResolverConfig::from_parts(None, vec![], name_servers); - - let mut opts = ResolverOpts::default(); - opts.rotate = true; - opts.cache_size = 2048; - opts.use_hosts_file = false; - opts.preserve_intermediates = false; - opts.try_tcp_on_error = true; - - TokioAsyncResolver::tokio(cfg, opts) -} - struct SocketAddrs { iter: LookupIpIntoIter, } @@ -64,34 +60,13 @@ impl Iterator for SocketAddrs { } } -#[derive(Debug, Clone)] -pub struct DnsResolver { - // Constructor is called most probably not in the Tokio context - // so we delay creation of the resolver using once_cell - state: Arc>, - cli: Dns, -} - -impl DnsResolver { - pub fn new(cli: &Dns) -> Self { - Self { - state: Arc::new(OnceCell::new()), - cli: cli.clone(), - } - } -} - // Implement resolving for Reqwest using Hickory -impl Resolve for DnsResolver { +impl Resolve for Resolver { fn resolve(&self, name: Name) -> Resolving { let resolver = self.clone(); Box::pin(async move { - let resolver = resolver - .state - .get_or_init(|| prepare_dns_resolver(resolver.cli)); - - let lookup = resolver.lookup_ip(name.as_str()).await?; + let lookup = resolver.0.lookup_ip(name.as_str()).await?; let addrs: Addrs = Box::new(SocketAddrs { iter: lookup.into_iter(), }); diff --git a/src/http/mod.rs b/src/http/mod.rs index a85f90c..2aab2d3 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -1,48 +1,3 @@ +pub mod client; pub mod dns; - -use std::sync::Arc; - -use async_trait::async_trait; -use mockall::automock; -use reqwest::{Client, Error, Request, Response}; - -use crate::{cli, core::SERVICE_NAME, http::dns::DnsResolver, tls::prepare_rustls_client_config}; - -#[automock] -#[async_trait] -pub trait HttpClient: Send + Sync { - async fn execute(&self, req: Request) -> Result; -} - -pub struct ReqwestClient(Client); - -impl ReqwestClient { - pub fn new(cli: &cli::Cli) -> Result { - let http = &cli.http_client; - - let client = Client::builder() - .use_preconfigured_tls(prepare_rustls_client_config()) - .dns_resolver(Arc::new(DnsResolver::new(&cli.dns))) - .connect_timeout(http.timeout_connect) - .timeout(http.timeout) - .tcp_nodelay(true) - .tcp_keepalive(Some(http.tcp_keepalive)) - .http2_keep_alive_interval(Some(http.http2_keepalive)) - .http2_keep_alive_timeout(http.http2_keepalive_timeout) - .http2_keep_alive_while_idle(true) - .http2_adaptive_window(true) - .user_agent(SERVICE_NAME) - .redirect(reqwest::redirect::Policy::none()) - .no_proxy() - .build()?; - - Ok(Self(client)) - } -} - -#[async_trait] -impl HttpClient for ReqwestClient { - async fn execute(&self, req: Request) -> Result { - self.0.execute(req).await - } -} +pub mod server; diff --git a/src/http/server.rs b/src/http/server.rs new file mode 100644 index 0000000..e446d80 --- /dev/null +++ b/src/http/server.rs @@ -0,0 +1,158 @@ +use std::{net::SocketAddr, sync::Arc, time::Duration}; + +use anyhow::{anyhow, Error}; +use axum::{extract::Request, Router}; +use futures_util::pin_mut; +use hyper::body::Incoming; +use hyper_util::{ + rt::{TokioExecutor, TokioIo}, + server::conn::auto::Builder, +}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::{TcpListener, TcpSocket, TcpStream}; +use tokio::select; +use tokio_rustls::TlsAcceptor; +use tokio_util::{sync::CancellationToken, task::TaskTracker}; +use tower_service::Service; +use tracing::{debug, warn}; + +// Blanket async read+write trait to box streams +trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Sync + Unpin {} +impl AsyncReadWrite for T {} + +pub struct Conn { + addr: SocketAddr, + remote_addr: SocketAddr, + router: Router, + tls_acceptor: Option, +} + +impl Conn { + pub async fn handle(&self, stream: TcpStream) -> Result<(), Error> { + debug!( + "Server {}: {}: got a new connection", + self.addr, self.remote_addr + ); + + // Disable Nagle's algo + stream.set_nodelay(true)?; + + // Perform TLS handshake if we're in TLS mode + let stream: Box = if let Some(v) = &self.tls_acceptor { + debug!("{}: performing TLS handshake", self.remote_addr); + Box::new(v.accept(stream).await?) + } else { + Box::new(stream) + }; + + // Convert stream from Tokio to Hyper + let stream = TokioIo::new(stream); + + // Convert router to Hyper service + let service = hyper::service::service_fn(move |request: Request| { + self.router.clone().call(request) + }); + + // Call the service + Builder::new(TokioExecutor::new()) + .serve_connection(stream, service) + .await + // It shouldn't really fail since Axum routers are infallible + .map_err(|e| anyhow!("unable to call service: {e}"))?; + + debug!( + "Server {}: {}: connection finished", + self.addr, self.remote_addr + ); + + Ok(()) + } +} + +// Listens for new connections on addr with an optional TLS and serves provided Router +pub struct Server { + addr: SocketAddr, + backlog: u32, + router: Router, + token: CancellationToken, + tracker: TaskTracker, + tls_acceptor: Option, +} + +impl Server { + pub fn new( + addr: SocketAddr, + backlog: u32, + router: Router, + token: CancellationToken, + rustls_cfg: Option, + ) -> Self { + Self { + addr, + backlog, + router, + token, + tracker: TaskTracker::new(), + tls_acceptor: rustls_cfg.map(|x| TlsAcceptor::from(Arc::new(x))), + } + } + + pub async fn start(&self) -> Result<(), Error> { + let listener = listen_tcp_backlog(self.addr, self.backlog)?; + pin_mut!(listener); + + loop { + select! { + () = self.token.cancelled() => { + warn!("Server {}: shutting down, waiting for the active connections to close for 30s", self.addr); + self.tracker.close(); + select! { + () = self.tracker.wait() => {}, + () = tokio::time::sleep(Duration::from_secs(30)) => {}, + } + warn!("Server {}: shut down", self.addr); + return Ok(()); + }, + + // Try to accept the connection + v = listener.accept() => { + let (stream, remote_addr) = match v { + Ok((a, b)) => (a, b), + Err(e) => { + warn!("Unable to accept connection: {e}"); + continue; + } + }; + + // Create a new connection + // Router & TlsAcceptor are both Arc<> inside so it's cheap to clone + let conn = Conn { + addr: self.addr, + remote_addr, + router: self.router.clone(), + tls_acceptor: self.tls_acceptor.clone(), + }; + + // Spawn a task to handle connection + self.tracker.spawn(async move { + match conn.handle(stream).await { + Ok(()) => {}, + Err(e) => warn!("Server {}: {}: failed to handle connection: {e}", conn.addr, remote_addr), + } + }); + } + } + } + } +} + +// Creates a listener with a backlog set +pub fn listen_tcp_backlog(addr: SocketAddr, backlog: u32) -> Result { + let socket = match addr { + SocketAddr::V4(_) => TcpSocket::new_v4()?, + SocketAddr::V6(_) => TcpSocket::new_v6()?, + }; + + socket.bind(addr)?; + Ok(socket.listen(backlog)?) +} diff --git a/src/main.rs b/src/main.rs index b94d01a..704d5fb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -21,5 +21,11 @@ static GLOBAL: Jemalloc = Jemalloc; #[tokio::main] async fn main() -> Result<(), Error> { let cli = Cli::parse(); + + let subscriber = tracing_subscriber::FmtSubscriber::builder() + .with_max_level(tracing::Level::DEBUG) + .finish(); + tracing::subscriber::set_global_default(subscriber)?; + core::main(cli).await } diff --git a/src/tls/cert/dir.rs b/src/tls/cert/dir.rs new file mode 100644 index 0000000..09c8840 --- /dev/null +++ b/src/tls/cert/dir.rs @@ -0,0 +1,85 @@ +use std::path::PathBuf; + +use anyhow::{anyhow, Context, Error}; +use async_trait::async_trait; +use tokio::fs::read_dir; + +use super::{pem_convert_to_rustls, Cert, ProvidesCertificates}; + +pub struct Provider { + path: PathBuf, +} + +// It searches for .pem files in the given folder and tries to find the +// corresponding .key files with the same base name. +// After that it loads & parses each pair. +impl Provider { + pub const fn new(path: PathBuf) -> Self { + Self { path } + } +} + +#[async_trait] +impl ProvidesCertificates for Provider { + async fn get_certificates(&self) -> Result, Error> { + let mut files = read_dir(&self.path).await?; + + let mut certs = vec![]; + while let Some(v) = files.next_entry().await? { + if !v.file_type().await?.is_file() { + continue; + } + + if !v + .path() + .extension() + .map_or(false, |x| x.eq_ignore_ascii_case("pem")) + { + continue; + } + + let path = v.path(); + let base = path.file_stem().unwrap().to_string_lossy(); + let keyfile = self.path.join(format!("{base}.key")); + + let chain = tokio::fs::read(v.path()).await?; + let key = tokio::fs::read(keyfile).await.context(format!( + "Corresponding .key file for {} not found", + v.path().to_string_lossy() + ))?; + + let cert = pem_convert_to_rustls(&key, &chain)?; + certs.push(cert); + } + + Ok(certs) + } +} + +#[cfg(test)] +mod test { + use super::super::test::{CERT, KEY}; + use super::*; + + #[tokio::test] + async fn test() -> Result<(), Error> { + let dir = tempfile::tempdir()?; + + let keyfile = dir.path().join("foobar.key"); + std::fs::write(keyfile, KEY)?; + + let certfile = dir.path().join("foobar.pem"); + std::fs::write(certfile, CERT)?; + + // Some junk to be ignored + std::fs::write(dir.path().join("foobar.baz"), b"foobar")?; + + let prov = Provider::new(dir.path().to_path_buf()); + let certs = prov.get_certificates().await?; + + assert_eq!(certs.len(), 1); + assert_eq!(certs[0].san, vec!["novg"]); + + Ok(()) + } +} diff --git a/src/tls/cert/mod.rs b/src/tls/cert/mod.rs index 1c1e021..1223404 100644 --- a/src/tls/cert/mod.rs +++ b/src/tls/cert/mod.rs @@ -1,15 +1,90 @@ +mod dir; mod syncer; -mod test; -use std::sync::Arc; +use std::{ + any, + collections::HashMap, + convert::TryFrom, + net::{Ipv4Addr, Ipv6Addr}, + sync::Arc, +}; -use anyhow::{anyhow, Error}; +use anyhow::{anyhow, Context, Error}; +use arc_swap::ArcSwapOption; use async_trait::async_trait; use futures::future::join_all; use rustls::{crypto::aws_lc_rs, sign::CertifiedKey}; +use x509_parser::prelude::*; + +// Trait that the certificate sources should implement +// It should return a vector of Rustls-compatible CertifiedKeys +#[async_trait] +pub trait ProvidesCertificates: Sync + Send { + async fn get_certificates(&self) -> Result, Error>; +} + +// Certificate and its SANs +pub struct Cert { + // List of SubjectAlternativeNames + san: Vec, + key: Arc, +} + +// Shared certificate storage +// It's intended to hold CertifiedKey but since we can't create it easily - we use generics +// to be able to do tests using other types +pub type Storage = Arc>>; + +// Extracts a list of SubjectAlternativeName from a single certificate, formatted as strings. +// Fails for everything except DNSName and IPAddress +fn extract_san_from_der(cert: &[u8]) -> Result, Error> { + let cert = X509Certificate::from_der(cert) + .context("Unable to parse DER-encoded certificate")? + .1; + + for ext in cert.extensions() { + if let ParsedExtension::SubjectAlternativeName(san) = ext.parsed_extension() { + let mut names = vec![]; + for name in &san.general_names { + let name = match name { + GeneralName::DNSName(v) => (*v).to_string(), + GeneralName::IPAddress(v) => match v.len() { + 4 => { + let b: [u8; 4] = (*v).try_into().unwrap(); // We already checked that it's 4 + let ip = Ipv4Addr::from(b); + ip.to_string() + } + + 16 => { + let b: [u8; 16] = (*v).try_into().unwrap(); // We already checked that it's 16 + let ip = Ipv6Addr::from(b); + ip.to_string() + } + + _ => return Err(anyhow!("Invalid IP address length")), + }, + + _ => return Err(anyhow!("Unsupported SubjectAlternativeName type")), + }; + + names.push(name); + } + + if names.is_empty() { + return Err(anyhow!( + "No supported names found in SubjectAlternativeName extension" + )); + } + + return Ok(names); + } + } + + Err(anyhow!("SubjectAlternativeName extension not found")) +} // Converts raw PEM certificate chain & private key to a CertifiedKey ready to be consumed by Rustls -pub fn pem_convert_to_rustls(key: &[u8], certs: &[u8]) -> Result, Error> { +pub fn pem_convert_to_rustls(key: &[u8], certs: &[u8]) -> Result { let (key, certs) = (key.to_vec(), certs.to_vec()); let key = rustls_pemfile::private_key(&mut key.as_ref())? @@ -20,15 +95,16 @@ pub fn pem_convert_to_rustls(key: &[u8], certs: &[u8]) -> Result Result>, Error>; + Ok(Cert { + san, + key: Arc::new(CertifiedKey::new(certs, key)), + }) } // Provider that aggregates other providers' output @@ -44,7 +120,7 @@ impl AggregatingProvider { #[async_trait] impl ProvidesCertificates for AggregatingProvider { - async fn get_certificates(&self) -> Result>, Error> { + async fn get_certificates(&self) -> Result, Error> { let certs = join_all( self.providers .iter() @@ -62,3 +138,65 @@ impl ProvidesCertificates for AggregatingProvider { Ok(certs) } } + +#[cfg(test)] +pub mod test { + use super::*; + + pub const CERT: &[u8] = b"-----BEGIN CERTIFICATE-----\n\ + MIIC6TCCAdGgAwIBAgIUK60AjMl8YTJ5nWViMweY043y6/EwDQYJKoZIhvcNAQEL\n\ + BQAwDzENMAsGA1UEAwwEbm92ZzAeFw0yMzAxMDkyMTM5NTZaFw0zMzAxMDYyMTM5\n\ + NTZaMA8xDTALBgNVBAMMBG5vdmcwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEK\n\ + AoIBAQCd/7NXWeENaITmYU+eWMJEJMZa6v74g70RpZlprQzx148U0QOKEw/r6mmd\n\ + SlbN4wsbb9lUu3zmXXpvYDAHYuOTYsDWcuNJXP/gCnPrD2wU8lJt3C5blmeU/9+0\n\ + U6/ppRmu6kf/jmm7CMBnowI0+kdvTF7sbpiUBXTDujXNsqtX0FaksILc9ZAqpUCC\n\ + 2gqRcOXahzT2vnvJ2N+2bhveG+eB0/5oZcKgx0D4QgjR9k1+thWOQZUCJMg32OYS\n\ + k4e57WhOQxu9Kh5N2MU1Ff3fhCYXzg7/GhJtWyDmjt1vNBwGW9Zn0BicySdcVFPC\n\ + mRW3/rZrSpnwvsEnpIuyKGq+NMSXAgMBAAGjPTA7MAkGA1UdEwQCMAAwDwYDVR0R\n\ + BAgwBoIEbm92ZzAdBgNVHQ4EFgQUYHN6l0ihbfbLQXqnKPltmv9DWDkwDQYJKoZI\n\ + hvcNAQELBQADggEBAFBvyns/lJZ+zB4/Tmx3YUryji20XUNwhtlBC6V7rdWCXneY\n\ + kqKVgbyDZ+XAYX2eL3o1gcv+XJxQgHfL+OqHJCVbK2kkYVSCW38WNVZb+oeTp/w3\n\ + pgtmg91JcCjFEw2doqImLZLQDX6KK1gDGdTQ2dtisFcxGEkMUyjzqmZmZNzl+u7d\n\ + JeDygLfGrMleO7ij2hP2vEfgkGbbvM+JCTav0B91Rj8/CbJHBwr8/CW4BJTjsqZC\n\ + mglNb9+hY8N6XAxntoqZsFzuDyDx7ZSxeAW0yVRemrIPSgcPwpLDBFm4dCSwUHJN\n\ + ujBjp7DRCQgg8uUq+0FMQ63ioZoR5mXQ5hzmTqk=\n\ + -----END CERTIFICATE-----\n\ + "; + + pub const KEY: &[u8] = b"-----BEGIN PRIVATE KEY-----\n\ + MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCd/7NXWeENaITm\n\ + YU+eWMJEJMZa6v74g70RpZlprQzx148U0QOKEw/r6mmdSlbN4wsbb9lUu3zmXXpv\n\ + YDAHYuOTYsDWcuNJXP/gCnPrD2wU8lJt3C5blmeU/9+0U6/ppRmu6kf/jmm7CMBn\n\ + owI0+kdvTF7sbpiUBXTDujXNsqtX0FaksILc9ZAqpUCC2gqRcOXahzT2vnvJ2N+2\n\ + bhveG+eB0/5oZcKgx0D4QgjR9k1+thWOQZUCJMg32OYSk4e57WhOQxu9Kh5N2MU1\n\ + Ff3fhCYXzg7/GhJtWyDmjt1vNBwGW9Zn0BicySdcVFPCmRW3/rZrSpnwvsEnpIuy\n\ + KGq+NMSXAgMBAAECggEAKYtxTFAxWZW4kF1ZEqFzH3juAT0WYyE8x1WcY8mhhDvy\n\ + fv5AqH8/qgBe2gGQlp2TL5k2881C184PohaQOnj5rykB3MGj2wgNrgsBlPberBlV\n\ + rFZ/iAyh2u93EpMIx+5mNPScjumTCp+P/BBERcrjmrPhp9ii3RUcMVUWzaoj3Lhc\n\ + wa5trC1r7UqbUZeO7NaVA7cGETZLVm8U7NaL8ccb1dKASUzrC9QCy9VVekJbb2S7\n\ + h38MELR9wvTGS7s4hXQGejb8vEDuXcZzWIFg3YMkJPIyGLAEaRynfeAHm/ji48U0\n\ + zh1ba3CWE/6z6nayDPqWqrwic4Hff6Mz+SIWAz2LyQKBgQDcdeWweNRVXhVkcFUP\n\ + JNpUiLOF5j3f4nqZwk7j5hQBxcXilYO/lmrcimvhvJ3ox97GfqCkvEQM8thTnPmi\n\ + JBagynOfIaUK2qdVwS1BbZ2JpYe3k/rO+iSKtRO4mF94cHgFIafPb5qt0fFz9bDS\n\ + 7D2lnWSbveMvb+mZsp/+FZx2DwKBgQC3eBhAbOSrSGuh7KOuWsav8pROMdcsESpz\n\ + j8el1iEklRsklYiNrVsztlZtNUXE2zSHeNPsGENDGlvKG8qD/vbcdTFsYa1H8Hk5\n\ + NydTLAb0/Bm256Xee1Dm5Wt2yG2aLfc9eG0trJz8VgBDhDlulnjo2kavhWIpTBNm\n\ + 0WmkMQsQ+QKBgQDYXd1PlUbPgcb9DEJu2nxs+r02bQHM+TnaLhm/EdAQ7UmJV7Q2\n\ + FCpMyI2YvsU78O1zYlPHWf5vtucZKLbXqxOKOye+xgZ04KPaRf1keXBj51GLmnBN\n\ + MrMqbw0r3l/UlI02fBF2RNJKRgHzDO6+E51tLUvQjkyqAewCLI1ZkVw9gQKBgD0F\n\ + J2O+E+vX4VxwnRvvOyfn0WWUdBFHAEyBJJDGgC1vniBzz3/3iV7QpTwbPMI1eeoY\n\ + yLs8cpqN2LuGtLtkAGzgWXjHn99OXrMl4eFqwkGW22KW9vbhIs44vZ47GSDvasy6\n\ + Ee3f/DJ81AegoY1jZIFln57fCP/dOpK20aD3YsvZAoGBAKgaWVYbROCRJ6C8CQGd\n\ + yetoZ8n25E7O5JtyKSNGwiQyD0IURgLuotiBpQvCCz9HGS53E6HLzBCc4jZc3GDq\n\ + qVDS5cIgcfWAOBalBQ+JxoHsnLRGXeBBKwvaJB+EzlrV8st1dCmM4gukElBJm/PZ\n\ + TvEPeiHG81OgB1RPgUt3DVIf\n\ + -----END PRIVATE KEY-----\n\ + "; + + #[test] + fn test_pem_convert_to_rustls() -> Result<(), Error> { + let cert = pem_convert_to_rustls(KEY, CERT)?; + assert_eq!(cert.san, vec!["novg"]); + Ok(()) + } +} diff --git a/src/tls/cert/syncer/mod.rs b/src/tls/cert/syncer/mod.rs index 0473718..b6a0587 100644 --- a/src/tls/cert/syncer/mod.rs +++ b/src/tls/cert/syncer/mod.rs @@ -7,15 +7,14 @@ use async_trait::async_trait; use candid::Principal; use mockall::automock; use reqwest::{Method, Request, StatusCode, Url}; -use rustls::sign::CertifiedKey; use serde::Deserialize; use crate::{ - http::HttpClient, + http::client, tls::cert::{ pem_convert_to_rustls, syncer::verify::{Verify, VerifyError, WithVerify}, - ProvidesCertificates, + Cert, ProvidesCertificates, }, }; @@ -48,12 +47,12 @@ pub trait Import: Sync + Send { } pub struct CertificatesImporter { - http_client: Arc, + http_client: Arc, exporter_url: Url, } impl CertificatesImporter { - pub fn new(http_client: Arc, exporter_url: Url) -> Self { + pub fn new(http_client: Arc, exporter_url: Url) -> Self { Self { http_client, exporter_url, @@ -63,7 +62,7 @@ impl CertificatesImporter { #[async_trait] impl ProvidesCertificates for CertificatesImporter { - async fn get_certificates(&self) -> Result>, anyhow::Error> { + async fn get_certificates(&self) -> Result, anyhow::Error> { let certs = self .import() .await? @@ -130,11 +129,11 @@ mod tests { use reqwest::Body; use std::{str::FromStr, sync::Arc}; - use crate::{http::MockHttpClient, tls::cert::syncer::verify::MockVerify}; + use crate::{http::client::MockClient, tls::cert::syncer::verify::MockVerify}; #[tokio::test] async fn import_ok() -> Result<(), AnyhowError> { - let mut http_client = MockHttpClient::new(); + let mut http_client = MockClient::new(); http_client .expect_execute() .times(1) diff --git a/src/tls/cert/test.rs b/src/tls/cert/test.rs deleted file mode 100644 index e69de29..0000000 diff --git a/src/tls/mod.rs b/src/tls/mod.rs index d1ce713..ae7da67 100644 --- a/src/tls/mod.rs +++ b/src/tls/mod.rs @@ -11,6 +11,9 @@ use rustls::{ RootCertStore, }; +const ALPN_H1: &[u8] = b"http/1.1"; +const ALPN_H2: &[u8] = b"h2"; + pub fn prepare_rustls_server_config(resolver: Arc) -> ServerConfig { let mut cfg = ServerConfig::builder_with_protocol_versions(&[&TLS13, &TLS12]) .with_no_client_auth() @@ -18,7 +21,7 @@ pub fn prepare_rustls_server_config(resolver: Arc) -> Se // Create custom session storage with higher limit to allow effective TLS session resumption cfg.session_storage = ServerSessionMemoryCache::new(131_072); - cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + cfg.alpn_protocols = vec![ALPN_H2.to_vec(), ALPN_H1.to_vec()]; cfg } @@ -36,7 +39,7 @@ pub fn prepare_rustls_client_config() -> ClientConfig { // Session resumption let store = ClientSessionMemoryCache::new(2048); cfg.resumption = Resumption::store(Arc::new(store)); - cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + cfg.alpn_protocols = vec![ALPN_H2.to_vec(), ALPN_H1.to_vec()]; cfg } diff --git a/src/tls/resolver.rs b/src/tls/resolver.rs index c882020..ceb4f14 100644 --- a/src/tls/resolver.rs +++ b/src/tls/resolver.rs @@ -1,23 +1,22 @@ -use std::{collections::HashMap, str::FromStr, sync::Arc}; +use std::{str::FromStr, sync::Arc}; -use arc_swap::ArcSwapOption; use fqdn::FQDN; use rustls::{ server::{ClientHello, ResolvesServerCert}, sign::CertifiedKey, }; -pub type CertStorage = Arc>>; +use crate::tls::cert::Storage; // Generic certificate resolver that supports wildcards. // It provides Rustls with a certificate corresponding to the SNI hostname, if there's one. #[derive(Debug)] -pub struct CertResolver { - storage: CertStorage, +pub struct Resolver { + storage: Storage, } -impl CertResolver { - pub fn new(storage: CertStorage) -> Self { +impl Resolver { + pub fn new(storage: Storage) -> Self { Self { storage } } @@ -39,7 +38,7 @@ impl CertResolver { } } -impl ResolvesServerCert for CertResolver> { +impl ResolvesServerCert for Resolver> { fn resolve(&self, ch: ClientHello) -> Option> { // See if client provided us with an SNI let sni = ch.server_name()?; @@ -50,6 +49,10 @@ impl ResolvesServerCert for CertResolver> { #[cfg(test)] mod test { use super::*; + + use arc_swap::ArcSwapOption; + use std::collections::HashMap; + use anyhow::Error; #[test] @@ -60,8 +63,8 @@ mod test { hm.insert("foo.baz".to_string(), "foo.baz".to_string()); hm.insert("bad:hostname".to_string(), "bad".to_string()); - let storage: CertStorage = Arc::new(ArcSwapOption::new(Some(Arc::new(hm)))); - let resolver = CertResolver::new(storage); + let storage: Storage = Arc::new(ArcSwapOption::new(Some(Arc::new(hm)))); + let resolver = Resolver::new(storage); assert_eq!(resolver.find_cert("foo.bar"), Some("foo.bar".into())); assert_eq!(resolver.find_cert("blah.foo.bar"), Some("*.foo.bar".into()));