Skip to content

Commit

Permalink
more work
Browse files Browse the repository at this point in the history
  • Loading branch information
blind-oracle committed Apr 19, 2024
1 parent 6e8850b commit 5103588
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 38 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ candid = "0.10"
clap = { version = "4.5", features = ["derive", "string"] }
clap_derive = "4.5"
ctrlc = { version = "3.4", features = ["termination"] }
derive-new = "0.6"
fqdn = "0.3"
futures = "0.3"
futures-util = "0.3"
Expand Down Expand Up @@ -53,6 +54,7 @@ reqwest = { git = "https://github.com/blind-oracle/reqwest.git", default_feature
"stream",
] }
rustls = "0.23"
rustls-acme = "0.9"
rustls-pemfile = "2"
webpki-roots = "0.26"
serde = "1.0"
Expand Down
1 change: 1 addition & 0 deletions src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub trait Run: Send + Sync {
}

async fn handle(request: axum::extract::Request) -> impl IntoResponse {
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
warn!("{:?}", request.extensions().get::<Arc<ConnInfo>>());
"Hello"
}
Expand Down
26 changes: 17 additions & 9 deletions src/http/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{
};

use crate::core::Run;
use anyhow::{anyhow, Error};
use anyhow::{anyhow, Context, Error};
use async_trait::async_trait;
use axum::{extract::Request, Router};
use hyper::body::Incoming;
Expand All @@ -15,7 +15,7 @@ use hyper_util::{
};
use rustls::{server::ServerConnection, CipherSuite, ProtocolVersion};
use tokio::{
io::{AsyncRead, AsyncWrite},
io::{AsyncRead, AsyncWrite, AsyncWriteExt},
net::{TcpListener, TcpSocket, TcpStream},
select,
};
Expand All @@ -24,7 +24,7 @@ use tokio_util::{sync::CancellationToken, task::TaskTracker};
use tower_service::Service;
use tracing::{debug, warn};

use crate::{cli, tls};
use crate::{cli, tls::is_http_alpn};

// Blanket async read+write trait to box streams
trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Sync + Unpin {}
Expand Down Expand Up @@ -134,7 +134,15 @@ impl Conn {

// Perform TLS handshake if we're in TLS mode
let (stream, tls_info): (Box<dyn AsyncReadWrite>, _) = if self.tls_acceptor.is_some() {
let (stream, tls_info) = self.tls_handshake(stream).await?;
let (mut stream, tls_info) = self.tls_handshake(stream).await?;

// Close the connection if agreed ALPN is not HTTP - probably it was an ACME challenge
if !is_http_alpn(tls_info.alpn.as_bytes()) {
debug!("Not HTTP ALPN ('{}') - closing connection", tls_info.alpn);
stream.shutdown().await.context("error in shutdown()")?;
return Ok(());
}

(Box::new(stream), Some(tls_info))
} else {
(Box::new(stream), None)
Expand Down Expand Up @@ -189,11 +197,6 @@ impl Conn {
},
}

debug!(
"Server {}: {}: connection finished",
self.addr, self.remote_addr
);

Ok(())
}
}
Expand Down Expand Up @@ -295,6 +298,11 @@ impl Run for Server {
if let Err(e) = conn.handle(stream).await {
warn!("Server {}: {}: failed to handle connection: {e}", conn.addr, remote_addr);
}

debug!(
"Server {}: {}: connection finished",
conn.addr, remote_addr
);
});
}
}
Expand Down
5 changes: 1 addition & 4 deletions src/tls/cert/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,13 @@ pub fn pem_convert_to_rustls(key: &[u8], certs: &[u8]) -> Result<CertKey, Error>
}

// Collects certificates from providers and stores them in a given storage
#[derive(derive_new::new)]
pub struct Aggregator {
providers: Vec<Arc<dyn ProvidesCertificates>>,
storage: Arc<StorageKey>,
}

impl Aggregator {
pub fn new(providers: Vec<Arc<dyn ProvidesCertificates>>, storage: Arc<StorageKey>) -> Self {
Self { providers, storage }
}

// Fetches certificates concurrently from all providers
async fn fetch(&self) -> Result<Vec<CertKey>, Error> {
let certs = join_all(
Expand Down
11 changes: 3 additions & 8 deletions src/tls/cert/providers/dir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,12 @@ use async_trait::async_trait;
use tokio::fs::read_dir;
use tracing::info;

pub struct Provider {
path: PathBuf,
}

// It searches for .pem files in the given directory and tries to find the
// corresponding .key files with the same base name.
// After that it loads & parses each pair.
impl Provider {
pub const fn new(path: PathBuf) -> Self {
Self { path }
}
#[derive(derive_new::new)]
pub struct Provider {
path: PathBuf,
}

#[async_trait]
Expand Down
16 changes: 7 additions & 9 deletions src/tls/cert/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,10 @@ use anyhow::{anyhow, Error};
use arc_swap::ArcSwapOption;
use candid::Principal;
use fqdn::FQDN;
use rustls::{
server::{ClientHello, ResolvesServerCert},
sign::CertifiedKey,
};
use rustls::{server::ClientHello, sign::CertifiedKey};

use super::{Cert, LookupCanister};
use crate::tls;
use crate::tls::{self, resolver};

#[derive(Debug)]
struct StorageInner<T: Clone> {
Expand Down Expand Up @@ -58,6 +55,7 @@ impl<T: Clone> Storage<T> {
let mut canisters = HashMap::new();

for c in cert_list {
// Take note of the canister ID
if let Some(v) = c.custom {
if canisters.insert(v.name.clone(), v.canister_id).is_some() {
return Err(anyhow!("Duplicate name detected: {}", v.name));
Expand All @@ -79,11 +77,11 @@ impl<T: Clone> Storage<T> {
}

// Implement certificate resolving for Rustls
impl ResolvesServerCert for StorageKey {
fn resolve(&self, ch: ClientHello) -> Option<Arc<CertifiedKey>> {
// Make sure we've got an ALPN list and they're all HTTP, otherwise refuse resolving
impl resolver::ResolvesServerCert for StorageKey {
fn resolve(&self, ch: &ClientHello) -> Option<Arc<CertifiedKey>> {
// Make sure we've got an ALPN list and they're all HTTP, otherwise refuse resolving.
// This is to make sure we don't answer to e.g. ACME challenges here
if !ch.alpn()?.all(|x| tls::ALPN_HTTP.contains(&x)) {
if !ch.alpn()?.all(tls::is_http_alpn) {
return None;
}

Expand Down
31 changes: 23 additions & 8 deletions src/tls/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod cert;
pub mod resolver;
mod test;

use std::sync::Arc;
Expand All @@ -10,19 +11,27 @@ use rustls::{
version::{TLS12, TLS13},
RootCertStore,
};
use rustls_acme::acme::ACME_TLS_ALPN_NAME;

use crate::{
cli::Cli,
core::Run,
http,
tls::cert::{providers, storage::Storage, Aggregator},
tls::{
cert::{providers, storage::Storage, Aggregator},
resolver::AggregatingResolver,
},
};

use cert::providers::ProvidesCertificates;

pub const ALPN_H1: &[u8] = b"http/1.1";
pub const ALPN_H2: &[u8] = b"h2";
pub const ALPN_HTTP: &[&[u8]] = &[ALPN_H1, ALPN_H2];
const ALPN_H1: &[u8] = b"http/1.1";
const ALPN_H2: &[u8] = b"h2";
const ALPN_HTTP: &[&[u8]] = &[ALPN_H1, ALPN_H2];

pub fn is_http_alpn(alpn: &[u8]) -> bool {
ALPN_HTTP.contains(&alpn)
}

pub fn prepare_server_config(resolver: Arc<dyn ResolvesServerCert>) -> ServerConfig {
let mut cfg = ServerConfig::builder_with_protocol_versions(&[&TLS13, &TLS12])
Expand All @@ -31,7 +40,12 @@ pub fn prepare_server_config(resolver: Arc<dyn ResolvesServerCert>) -> ServerCon

// Create custom session storage with higher limit to allow effective TLS session resumption
cfg.session_storage = ServerSessionMemoryCache::new(131_072);
cfg.alpn_protocols = vec![ALPN_H2.to_vec(), ALPN_H1.to_vec()];
cfg.alpn_protocols = vec![
ALPN_H2.to_vec(),
ALPN_H1.to_vec(),
// Support ACME challenge ALPN too
ACME_TLS_ALPN_NAME.to_vec(),
];

cfg
}
Expand Down Expand Up @@ -79,8 +93,9 @@ pub fn setup(
}

let storage = Arc::new(Storage::new());
let aggregator = Arc::new(Aggregator::new(providers, storage.clone()));
let config = prepare_server_config(storage);
let cert_aggregator = Arc::new(Aggregator::new(providers, storage.clone()));
let resolve_aggregator = Arc::new(AggregatingResolver::new(None, vec![storage]));
let config = prepare_server_config(resolve_aggregator);

Ok((aggregator, config))
Ok((cert_aggregator, config))
}
34 changes: 34 additions & 0 deletions src/tls/resolver.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use std::{fmt::Debug, sync::Arc};

use rustls::{
server::{ClientHello, ResolvesServerCert as ResolvesServerCertRustls},
sign::CertifiedKey,
};

// Custom ResolvesServerCert trait that takes ClientHello by reference.
// It's needed because Rustls' ResolvesServerCert consumes ClientHello
pub trait ResolvesServerCert: Debug + Send + Sync {
fn resolve(&self, client_hello: &ClientHello) -> Option<Arc<CertifiedKey>>;
}

// Combines several certificate resolvers into one
// Only one Rustls-compatible resolver can be used (acme) since it takes ClientHello by value
#[derive(Debug, derive_new::new)]
pub struct AggregatingResolver {
acme: Option<Arc<dyn ResolvesServerCertRustls>>,
resolvers: Vec<Arc<dyn ResolvesServerCert>>,
}

// Implement certificate resolving for Rustls
impl ResolvesServerCertRustls for AggregatingResolver {
fn resolve(&self, ch: ClientHello) -> Option<Arc<CertifiedKey>> {
// Iterate over our resolvers to find matching cert if any
let cert = self.resolvers.iter().find_map(|x| x.resolve(&ch));
if let Some(v) = cert {
return Some(v);
}

// Otherwise try the ACME resolver with Rustls trait that consumes ClientHello
self.acme.as_ref().and_then(|x| x.resolve(ch))
}
}

0 comments on commit 5103588

Please sign in to comment.