From 01d13f1c295808279a695061a4d17534a7271412 Mon Sep 17 00:00:00 2001 From: Igor Novgorodov Date: Mon, 13 May 2024 19:28:18 +0200 Subject: [PATCH] issuer proxy, a lot of other changes --- Cargo.toml | 1 + src/cli.rs | 61 ++++++- src/core.rs | 50 ++++-- src/http/client.rs | 35 ++-- src/http/dns.rs | 30 ++-- src/http/mod.rs | 8 + src/http/server.rs | 27 +-- src/metrics/body.rs | 4 - src/metrics/mod.rs | 2 +- src/routing/body.rs | 42 +++++ src/routing/error_cause.rs | 189 +++++++++++++++++++++ src/routing/middleware/policy.rs | 38 +++-- src/routing/mod.rs | 238 +++++++++++---------------- src/tls/acme.rs | 2 +- src/tls/cert/providers/syncer/mod.rs | 3 +- src/tls/cert/storage.rs | 5 +- src/tls/mod.rs | 13 +- src/tls/resolver.rs | 12 +- 18 files changed, 506 insertions(+), 254 deletions(-) create mode 100644 src/routing/body.rs create mode 100644 src/routing/error_cause.rs diff --git a/Cargo.toml b/Cargo.toml index 37aa55f..1d00030 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -63,6 +63,7 @@ serde = "1.0" serde_json = "1.0" strum = "0.26" strum_macros = "0.26" +sync_wrapper = "1.0" thiserror = "1.0" tempfile = "3.10" tokio = { version = "1.36", features = ["full"] } diff --git a/src/cli.rs b/src/cli.rs index 3140afe..3f64c0a 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -14,7 +14,7 @@ use crate::{ core::{AUTHOR_NAME, SERVICE_NAME}, http::dns, routing::canister::CanisterAlias, - tls::acme, + tls::{self, acme}, }; #[derive(Parser)] @@ -146,12 +146,18 @@ pub struct Cert { #[derive(Args)] pub struct Domain { - /// List of domains that we serve system subnets from - #[clap(long = "domain-system")] + /// Specify domains that will be served. This affects the routing, canister extraction, ACME certificate issuing etc. + #[clap(long = "domain")] + pub domains: Vec, + + /// List of domains that we serve system subnets from. This enables domain-canister matching for these domains & adds them to the list of served domains above, do not list them there separately. + /// Requires --domain-app. + #[clap(long = "domain-system", requires = "domains_app")] pub domains_system: Vec, - /// List of domains that we serve app subnets from - #[clap(long = "domain-app")] + /// List of domains that we serve app subnets from. See --domain-system above for details. + /// Requires --domain-system. + #[clap(long = "domain-app", requires = "domains_system")] pub domains_app: Vec, /// List of canister aliases in format ':' @@ -184,7 +190,9 @@ pub struct Policy { #[derive(Args)] pub struct Acme { - /// Type of ACME challenge to use. Currently supported: alpn + /// Type of ACME challenge to use. Currently supported: alpn. + /// If specified it will try to obtain the certificate that is valid for all specified domains + /// (--domain-app & --domain-system). For this to succeed they all should resolve to the hostname where this service is running. #[clap(long = "acme-challenge", requires = "acme_cache_path")] pub acme_challenge: Option, @@ -193,7 +201,7 @@ pub struct Acme { #[clap(long = "acme-cache-path")] pub acme_cache_path: Option, - /// Whether to use LetsEncrypt staging API to avoid hitting the limits + /// Whether to use LetsEncrypt staging API for testing to avoid hitting the limits #[clap(long = "acme-staging")] pub acme_staging: bool, } @@ -211,3 +219,42 @@ pub struct Misc { #[clap(long = "geoip-db")] pub geoip_db: Option, } + +// Some conversions +impl From<&HttpServer> for crate::http::server::Options { + fn from(c: &HttpServer) -> Self { + Self { + backlog: c.backlog, + http2_keepalive_interval: c.http2_keepalive_interval, + http2_keepalive_timeout: c.http2_keepalive_timeout, + http2_max_streams: c.http2_max_streams, + grace_period: c.grace_period, + } + } +} + +impl From<&Dns> for crate::http::dns::Options { + fn from(c: &Dns) -> Self { + Self { + protocol: c.protocol, + servers: c.servers.clone(), + tls_name: c.tls_name.clone(), + cache_size: c.cache_size, + } + } +} + +impl From<&Cli> for crate::http::client::Options { + fn from(c: &Cli) -> Self { + Self { + dns_options: (&c.dns).into(), + timeout_connect: c.http_client.timeout_connect, + timeout: c.http_client.timeout, + tcp_keepalive: Some(c.http_client.tcp_keepalive), + http2_keepalive: Some(c.http_client.http2_keepalive), + http2_keepalive_timeout: c.http_client.http2_keepalive_timeout, + user_agent: crate::core::SERVICE_NAME.into(), + tls_config: tls::prepare_client_config(), + } + } +} diff --git a/src/core.rs b/src/core.rs index 0e8d63c..075953a 100644 --- a/src/core.rs +++ b/src/core.rs @@ -1,15 +1,15 @@ +use std::{error::Error as StdError, sync::Arc}; + use anyhow::{anyhow, Error}; use async_trait::async_trait; use prometheus::Registry; use rustls::sign::CertifiedKey; -use std::sync::Arc; use tokio_util::{sync::CancellationToken, task::TaskTracker}; use tracing::{error, warn}; use crate::{ cli::Cli, - http::{server, ReqwestClient, Server}, - metrics, + http, metrics, routing::{ self, canister::{CanisterResolver, ResolvesCanister}, @@ -31,6 +31,26 @@ pub trait Run: Send + Sync { } pub struct Runner(pub String, pub Arc); +#[async_trait] +impl Run for http::Server { + async fn run(&self, token: CancellationToken) -> Result<(), Error> { + self.serve(token).await + } +} + +pub fn error_source(error: &impl StdError) -> Option<&E> { + let mut source = error.source(); + while let Some(err) = source { + if let Some(v) = err.downcast_ref() { + return Some(v); + } + + source = err.source(); + } + + None +} + pub async fn main(cli: &Cli) -> Result<(), Error> { // Install crypto-provider rustls::crypto::aws_lc_rs::default_provider() @@ -41,7 +61,7 @@ pub async fn main(cli: &Cli) -> Result<(), Error> { let token = CancellationToken::new(); let tracker = TaskTracker::new(); let registry = Registry::new(); - let http_client = Arc::new(ReqwestClient::new(cli)?); + let http_client = Arc::new(http::ReqwestClient::new(cli.into())?); // List of cancellable tasks to execute & track let mut runners: Vec = vec![]; @@ -52,12 +72,13 @@ pub async fn main(cli: &Cli) -> Result<(), Error> { ctrlc::set_handler(move || handler_token.cancel())?; // Make a list of all supported domains - let mut domains = cli.domain.domains_system.clone(); + let mut domains = cli.domain.domains.clone(); + domains.extend_from_slice(&cli.domain.domains_system); domains.extend_from_slice(&cli.domain.domains_app); if domains.is_empty() { return Err(anyhow!( - "No domains specified (use --domain-system and/or --domain-app)" + "No domains to serve specified (use --domain/--domain-system/--domain-app)" )); } @@ -82,13 +103,11 @@ pub async fn main(cli: &Cli) -> Result<(), Error> { runners.push(Runner("denylist_updater".into(), v)); } - let server_options = server::Options::from(&cli.http_server); - // Set up HTTP - let http_server = Arc::new(Server::new( + let http_server = Arc::new(http::Server::new( cli.http_server.http, router.clone(), - server_options, + (&cli.http_server).into(), None, )) as Arc; runners.push(Runner("http_server".into(), http_server)); @@ -103,10 +122,10 @@ pub async fn main(cli: &Cli) -> Result<(), Error> { )?; runners.extend(tls_runners); - let https_server = Arc::new(Server::new( + let https_server = Arc::new(http::Server::new( cli.http_server.https, router, - server_options, + (&cli.http_server).into(), Some(rustls_cfg), )) as Arc; runners.push(Runner("https_server".into(), https_server)); @@ -116,7 +135,12 @@ pub async fn main(cli: &Cli) -> Result<(), Error> { let (router, runner) = metrics::setup(®istry); runners.push(Runner("metrics_runner".into(), runner)); - let srv = Arc::new(Server::new(addr, router, server_options, None)); + let srv = Arc::new(http::Server::new( + addr, + router, + (&cli.http_server).into(), + None, + )); runners.push(Runner("metrics_server".into(), srv as Arc)); } diff --git a/src/http/client.rs b/src/http/client.rs index 90f1387..80dbf6d 100644 --- a/src/http/client.rs +++ b/src/http/client.rs @@ -1,9 +1,9 @@ -use std::sync::Arc; +use std::{sync::Arc, time::Duration}; use async_trait::async_trait; use mockall::automock; -use crate::{cli, core::SERVICE_NAME, http::dns::Resolver, tls::prepare_client_config}; +use super::dns; #[automock] #[async_trait] @@ -11,25 +11,34 @@ pub trait Client: Send + Sync { async fn execute(&self, req: reqwest::Request) -> Result; } +pub struct Options { + pub dns_options: dns::Options, + pub timeout_connect: Duration, + pub timeout: Duration, + pub tcp_keepalive: Option, + pub http2_keepalive: Option, + pub http2_keepalive_timeout: Duration, + pub user_agent: String, + pub tls_config: rustls::ClientConfig, +} + #[derive(Clone)] pub struct ReqwestClient(reqwest::Client); impl ReqwestClient { - pub fn new(cli: &cli::Cli) -> Result { - let http = &cli.http_client; - + pub fn new(opts: Options) -> Result { let client = reqwest::Client::builder() - .use_preconfigured_tls(prepare_client_config()) - .dns_resolver(Arc::new(Resolver::new(&cli.dns))) - .connect_timeout(http.timeout_connect) - .timeout(http.timeout) + .use_preconfigured_tls(opts.tls_config) + .dns_resolver(Arc::new(dns::Resolver::new(opts.dns_options))) + .connect_timeout(opts.timeout_connect) + .timeout(opts.timeout) .tcp_nodelay(true) - .tcp_keepalive(Some(http.tcp_keepalive)) - .http2_keep_alive_interval(Some(http.http2_keepalive)) - .http2_keep_alive_timeout(http.http2_keepalive_timeout) + .tcp_keepalive(opts.tcp_keepalive) + .http2_keep_alive_interval(opts.http2_keepalive) + .http2_keep_alive_timeout(opts.http2_keepalive_timeout) .http2_keep_alive_while_idle(true) .http2_adaptive_window(true) - .user_agent(SERVICE_NAME) + .user_agent(opts.user_agent) .redirect(reqwest::redirect::Policy::none()) .no_proxy() .build()?; diff --git a/src/http/dns.rs b/src/http/dns.rs index 82429e6..55dc6d7 100644 --- a/src/http/dns.rs +++ b/src/http/dns.rs @@ -1,4 +1,7 @@ -use std::{net::SocketAddr, sync::Arc}; +use std::{ + net::{IpAddr, SocketAddr}, + sync::Arc, +}; use hickory_resolver::{ config::{NameServerConfigGroup, ResolverConfig, ResolverOpts}, @@ -8,9 +11,7 @@ use hickory_resolver::{ use reqwest::dns::{Addrs, Name, Resolve, Resolving}; use strum_macros::EnumString; -use crate::cli::Dns; - -#[derive(Clone, Debug, EnumString)] +#[derive(Clone, Copy, Debug, EnumString)] #[strum(serialize_all = "snake_case")] pub enum Protocol { Clear, @@ -21,16 +22,21 @@ pub enum Protocol { #[derive(Debug, Clone)] pub struct Resolver(Arc); +pub struct Options { + pub protocol: Protocol, + pub servers: Vec, + pub tls_name: String, + pub cache_size: usize, +} + // new() must be called in Tokio context impl Resolver { - pub fn new(cli: &Dns) -> Self { - let name_servers = match cli.protocol { - Protocol::Clear => NameServerConfigGroup::from_ips_clear(&cli.servers, 53, true), - Protocol::Tls => { - NameServerConfigGroup::from_ips_tls(&cli.servers, 853, cli.tls_name.clone(), true) - } + pub fn new(o: Options) -> Self { + let name_servers = match o.protocol { + Protocol::Clear => NameServerConfigGroup::from_ips_clear(&o.servers, 53, true), + Protocol::Tls => NameServerConfigGroup::from_ips_tls(&o.servers, 853, o.tls_name, true), Protocol::Https => { - NameServerConfigGroup::from_ips_https(&cli.servers, 443, cli.tls_name.clone(), true) + NameServerConfigGroup::from_ips_https(&o.servers, 443, o.tls_name, true) } }; @@ -38,7 +44,7 @@ impl Resolver { let mut opts = ResolverOpts::default(); opts.rotate = true; - opts.cache_size = cli.cache_size; + opts.cache_size = o.cache_size; opts.use_hosts_file = false; opts.preserve_intermediates = false; opts.try_tcp_on_error = true; diff --git a/src/http/mod.rs b/src/http/mod.rs index 4b2841d..797f3bb 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -7,6 +7,14 @@ use http::{HeaderMap, Version}; pub use client::{Client, ReqwestClient}; pub use server::{ConnInfo, Server}; +pub const ALPN_H1: &[u8] = b"http/1.1"; +pub const ALPN_H2: &[u8] = b"h2"; +pub const ALPN_HTTP: &[&[u8]] = &[ALPN_H1, ALPN_H2]; + +pub fn is_http_alpn(alpn: &[u8]) -> bool { + ALPN_HTTP.contains(&alpn) +} + // Calculate very approximate HTTP request/response headers size in bytes. // More or less accurate only for http/1.1 since in h2 headers are in HPACK-compressed. // But it seems there's no better way. diff --git a/src/http/server.rs b/src/http/server.rs index a110b01..6d3b78c 100644 --- a/src/http/server.rs +++ b/src/http/server.rs @@ -5,9 +5,7 @@ use std::{ time::{Duration, Instant}, }; -use crate::core::Run; use anyhow::{anyhow, Context, Error}; -use async_trait::async_trait; use axum::{extract::Request, Router}; use fqdn::FQDN; use hyper::body::Incoming; @@ -26,7 +24,7 @@ use tokio_util::{sync::CancellationToken, task::TaskTracker}; use tower_service::Service; use tracing::{debug, warn}; -use crate::{cli, tls::is_http_alpn}; +use super::is_http_alpn; // Blanket async read+write trait to box streams trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Sync + Unpin {} @@ -41,18 +39,6 @@ pub struct Options { pub grace_period: Duration, } -impl From<&cli::HttpServer> for Options { - fn from(c: &cli::HttpServer) -> Self { - Self { - backlog: c.backlog, - http2_keepalive_interval: c.http2_keepalive_interval, - http2_keepalive_timeout: c.http2_keepalive_timeout, - http2_max_streams: c.http2_max_streams, - grace_period: c.grace_period, - } - } -} - // TLS information about the connection #[derive(Clone, Debug)] pub struct TlsInfo { @@ -107,7 +93,7 @@ struct Conn { } impl Conn { - pub async fn tls_handshake( + async fn tls_handshake( &self, stream: TcpStream, ) -> Result<(TlsStream, TlsInfo), Error> { @@ -137,7 +123,7 @@ impl Conn { Ok((stream, tls_info)) } - pub async fn handle(&self, stream: TcpStream) -> Result<(), Error> { + async fn handle(&self, stream: TcpStream) -> Result<(), Error> { let accepted_at = Instant::now(); debug!( @@ -242,11 +228,8 @@ impl Server { tls_acceptor: rustls_cfg.map(|x| TlsAcceptor::from(Arc::new(x))), } } -} -#[async_trait] -impl Run for Server { - async fn run(&self, token: CancellationToken) -> Result<(), Error> { + pub async fn serve(&self, token: CancellationToken) -> Result<(), Error> { let listener = listen_tcp_backlog(self.addr, self.options.backlog)?; // Prepare Hyper connection builder @@ -290,7 +273,7 @@ impl Run for Server { Ok(v) => v, Err(e) => { warn!("Unable to accept connection: {e}"); - // Wait few ms just in case that there's an overflowed backlog + // Wait few ms just in case that there's an overflown backlog // so that we don't run into infinite error loop tokio::time::sleep(Duration::from_millis(10)).await; continue; diff --git a/src/metrics/body.rs b/src/metrics/body.rs index f48d099..8d7edd9 100644 --- a/src/metrics/body.rs +++ b/src/metrics/body.rs @@ -130,11 +130,9 @@ mod test { let body = axum::body::Body::from_stream(stream); let (tx, rx) = std::sync::mpsc::channel(); - let callback = move |response_size: u64, _body_result: Result<(), String>| { let _ = tx.send(response_size); }; - let body = CountingBody::new(body, callback); // Check that the body streams the same data back @@ -153,11 +151,9 @@ mod test { let body = http_body_util::Full::new(buf); let (tx, rx) = std::sync::mpsc::channel(); - let callback = move |response_size: u64, _body_result: Result<(), String>| { let _ = tx.send(response_size); }; - let body = CountingBody::new(body, callback); // Check that the body streams the same data back diff --git a/src/metrics/mod.rs b/src/metrics/mod.rs index 679e847..cf485ea 100644 --- a/src/metrics/mod.rs +++ b/src/metrics/mod.rs @@ -30,7 +30,7 @@ use tracing::{debug, info, warn}; use crate::{ core::Run, http::{calc_headers_size, http_version, server::ConnInfo}, - routing::{middleware::request_id::RequestId, ErrorCause, RequestCtx}, + routing::{error_cause::ErrorCause, middleware::request_id::RequestId, RequestCtx}, }; use body::CountingBody; diff --git a/src/routing/body.rs b/src/routing/body.rs new file mode 100644 index 0000000..c4b9b5a --- /dev/null +++ b/src/routing/body.rs @@ -0,0 +1,42 @@ +use std::{ + pin::{pin, Pin}, + task::{Context, Poll}, +}; + +use axum::{body::Body, Error}; +use bytes::Bytes; +use futures::Stream; +use http_body::Body as _; +use sync_wrapper::SyncWrapper; + +// Wrapper for Axum body that makes it Sync to be usable with Request +// TODO find a better way? +pub struct SyncBodyDataStream { + inner: SyncWrapper, +} + +impl SyncBodyDataStream { + pub const fn new(body: Body) -> Self { + Self { + inner: SyncWrapper::new(body), + } + } +} + +impl Stream for SyncBodyDataStream { + type Item = Result; + + #[inline] + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + let mut pinned = pin!(self.inner.get_mut()); + match futures_util::ready!(pinned.as_mut().poll_frame(cx)?) { + Some(frame) => match frame.into_data() { + Ok(data) => return Poll::Ready(Some(Ok(data))), + Err(_frame) => {} + }, + None => return Poll::Ready(None), + } + } + } +} diff --git a/src/routing/error_cause.rs b/src/routing/error_cause.rs new file mode 100644 index 0000000..9482797 --- /dev/null +++ b/src/routing/error_cause.rs @@ -0,0 +1,189 @@ +use std::fmt; + +use axum::response::{IntoResponse, Response}; +use hickory_resolver::error::ResolveError; +use http::StatusCode; +use strum_macros::Display; + +// Process error chain trying to find given error type +pub fn error_infer(error: &anyhow::Error) -> Option<&E> { + for cause in error.chain() { + if let Some(e) = cause.downcast_ref() { + return Some(e); + } + } + None +} + +#[derive(Debug, Clone, Display)] +#[strum(serialize_all = "snake_case")] +pub enum RateLimitCause { + Normal, +} + +// Categorized possible causes for request processing failures +// Not using Error as inner type since it's not cloneable +#[derive(Debug, Clone)] +pub enum ErrorCause { + UnableToReadBody(String), + PayloadTooLarge(usize), + UnableToParseCBOR(String), + UnableToParseHTTPArg(String), + LoadShed, + MalformedRequest(String), + MalformedResponse(String), + NoAuthority, + CanisterIdNotFound, + SNIMismatch, + DomainCanisterMismatch, + Denylisted, + NoRoutingTable, + SubnetNotFound, + NoHealthyNodes, + BackendErrorDNS(String), + BackendErrorConnect, + BackendTimeout, + BackendTLSErrorOther(String), + BackendTLSErrorCert(String), + BackendErrorOther(String), + RateLimited(RateLimitCause), + Other(String), +} + +impl ErrorCause { + pub const fn status_code(&self) -> StatusCode { + match self { + Self::Other(_) => StatusCode::INTERNAL_SERVER_ERROR, + Self::PayloadTooLarge(_) => StatusCode::PAYLOAD_TOO_LARGE, + Self::UnableToReadBody(_) => StatusCode::REQUEST_TIMEOUT, + Self::UnableToParseCBOR(_) => StatusCode::BAD_REQUEST, + Self::UnableToParseHTTPArg(_) => StatusCode::BAD_REQUEST, + Self::LoadShed => StatusCode::TOO_MANY_REQUESTS, + Self::MalformedRequest(_) => StatusCode::BAD_REQUEST, + Self::MalformedResponse(_) => StatusCode::INTERNAL_SERVER_ERROR, + Self::NoAuthority => StatusCode::BAD_REQUEST, + Self::CanisterIdNotFound => StatusCode::BAD_REQUEST, + Self::SNIMismatch => StatusCode::BAD_REQUEST, + Self::DomainCanisterMismatch => StatusCode::FORBIDDEN, + Self::Denylisted => StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS, + Self::NoRoutingTable => StatusCode::SERVICE_UNAVAILABLE, + Self::SubnetNotFound => StatusCode::BAD_REQUEST, // TODO change to 404? + Self::NoHealthyNodes => StatusCode::SERVICE_UNAVAILABLE, + Self::BackendErrorDNS(_) => StatusCode::SERVICE_UNAVAILABLE, + Self::BackendErrorConnect => StatusCode::SERVICE_UNAVAILABLE, + Self::BackendTimeout => StatusCode::INTERNAL_SERVER_ERROR, + Self::BackendTLSErrorOther(_) => StatusCode::SERVICE_UNAVAILABLE, + Self::BackendTLSErrorCert(_) => StatusCode::SERVICE_UNAVAILABLE, + Self::BackendErrorOther(_) => StatusCode::INTERNAL_SERVER_ERROR, + Self::RateLimited(_) => StatusCode::TOO_MANY_REQUESTS, + } + } + + pub fn details(&self) -> Option { + match self { + Self::Other(x) => Some(x.clone()), + Self::PayloadTooLarge(x) => Some(format!("maximum body size is {x} bytes")), + Self::UnableToReadBody(x) => Some(x.clone()), + Self::UnableToParseCBOR(x) => Some(x.clone()), + Self::UnableToParseHTTPArg(x) => Some(x.clone()), + Self::LoadShed => Some("Overloaded".into()), + Self::MalformedRequest(x) => Some(x.clone()), + Self::MalformedResponse(x) => Some(x.clone()), + Self::BackendErrorDNS(x) => Some(x.clone()), + Self::BackendTLSErrorOther(x) => Some(x.clone()), + Self::BackendTLSErrorCert(x) => Some(x.clone()), + Self::BackendErrorOther(x) => Some(x.clone()), + _ => None, + } + } + + pub const fn retriable(&self) -> bool { + matches!( + self, + Self::BackendErrorDNS(_) + | Self::BackendErrorConnect + | Self::BackendTLSErrorOther(_) + | Self::BackendTLSErrorCert(_) + ) + } +} + +impl fmt::Display for ErrorCause { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Other(_) => write!(f, "general_error"), + Self::UnableToReadBody(_) => write!(f, "unable_to_read_body"), + Self::PayloadTooLarge(_) => write!(f, "payload_too_large"), + Self::UnableToParseCBOR(_) => write!(f, "unable_to_parse_cbor"), + Self::UnableToParseHTTPArg(_) => write!(f, "unable_to_parse_http_arg"), + Self::LoadShed => write!(f, "load_shed"), + Self::MalformedRequest(_) => write!(f, "malformed_request"), + Self::MalformedResponse(_) => write!(f, "malformed_response"), + Self::CanisterIdNotFound => write!(f, "canister_id_not_found"), + Self::SNIMismatch => write!(f, "sni_mismatch"), + Self::DomainCanisterMismatch => write!(f, "domain_canister_mismatch"), + Self::Denylisted => write!(f, "denylisted"), + Self::NoAuthority => write!(f, "no_authority"), + Self::NoRoutingTable => write!(f, "no_routing_table"), + Self::SubnetNotFound => write!(f, "subnet_not_found"), + Self::NoHealthyNodes => write!(f, "no_healthy_nodes"), + Self::BackendErrorDNS(_) => write!(f, "backend_error_dns"), + Self::BackendErrorConnect => write!(f, "backend_error_connect"), + Self::BackendTimeout => write!(f, "backend_timeout"), + Self::BackendTLSErrorOther(_) => write!(f, "backend_tls_error"), + Self::BackendTLSErrorCert(_) => write!(f, "backend_tls_error_cert"), + Self::BackendErrorOther(_) => write!(f, "backend_error_other"), + Self::RateLimited(x) => write!(f, "rate_limited_{x}"), + } + } +} + +// Creates the response from ErrorCause and injects itself into extensions to be visible by middleware +impl IntoResponse for ErrorCause { + fn into_response(self) -> Response { + let mut body = self.to_string(); + + if let Some(v) = self.details() { + body = format!("{body}: {v}"); + } + + let mut resp = (self.status_code(), format!("{body}\n")).into_response(); + resp.extensions_mut().insert(self); + resp + } +} + +impl From for ErrorCause { + fn from(e: anyhow::Error) -> Self { + // Check if it's a known Reqwest error + if let Some(e) = error_infer::(&e) { + if e.is_connect() { + return Self::BackendErrorConnect; + } + + if e.is_timeout() { + return Self::BackendTimeout; + } + } + + // Check if it's a DNS error + if let Some(e) = error_infer::(&e) { + return Self::BackendErrorDNS(e.to_string()); + } + + // Check if it's a Rustls error + if let Some(e) = error_infer::(&e) { + return match e { + rustls::Error::InvalidCertificate(v) => { + Self::BackendTLSErrorCert(format!("{:?}", v)) + } + rustls::Error::NoCertificatesPresented => { + Self::BackendTLSErrorCert("no certificate presented".into()) + } + _ => Self::BackendTLSErrorOther(e.to_string()), + }; + } + + Self::BackendErrorOther(e.to_string()) + } +} diff --git a/src/routing/middleware/policy.rs b/src/routing/middleware/policy.rs index 4fc0fa3..e956fdd 100644 --- a/src/routing/middleware/policy.rs +++ b/src/routing/middleware/policy.rs @@ -17,27 +17,33 @@ use crate::{ }; pub struct PolicyState { - domain_canister_matcher: DomainCanisterMatcher, + domain_canister_matcher: Option, denylist: Option>, } +#[allow(clippy::type_complexity)] impl PolicyState { pub fn new( cli: &Cli, http_client: Arc, registry: &Registry, - ) -> Result<(Self, Option>), Error> { + ) -> Result<(Option, Option>), Error> { let pre_isolation_canisters = if let Some(v) = cli.policy.pre_isolation_canisters.as_ref() { load_canister_list(v).context("unable to load pre-isolation canisters")? } else { HashSet::new() }; - let domain_canister_matcher = DomainCanisterMatcher::new( - pre_isolation_canisters, - cli.domain.domains_app.clone(), - cli.domain.domains_system.clone(), - ); + // Enable matcher only if both system and app domains are specified. CLI makes sure that if one is set then another is too. + let domain_canister_matcher = if !cli.domain.domains_app.is_empty() { + Some(DomainCanisterMatcher::new( + pre_isolation_canisters, + cli.domain.domains_app.clone(), + cli.domain.domains_system.clone(), + )) + } else { + None + }; let denylist = if cli.policy.denylist_seed.is_some() || cli.policy.denylist_url.is_some() { Some(Arc::new( @@ -55,11 +61,12 @@ impl PolicyState { None }; + // Return the policy only if at least one of the filters is enabled Ok(( - Self { + (denylist.is_some() && domain_canister_matcher.is_some()).then_some(Self { domain_canister_matcher, denylist: denylist.clone(), - }, + }), denylist.map(|x| x as Arc), )) } @@ -72,19 +79,18 @@ pub async fn middleware( request: Request, next: Next, ) -> Result { - // Check denylisting + // Check denylisting if configured if let Some(v) = state.denylist.as_ref() { if v.is_blocked(ctx.canister.id, country_code.map(|x| x.0)) { return Err(ErrorCause::Denylisted); } } - // Check domain-canister matching - if !state - .domain_canister_matcher - .check(ctx.canister.id, &ctx.authority) - { - return Err(ErrorCause::DomainCanisterMismatch); + // Check domain-canister matching if configured + if let Some(v) = &state.domain_canister_matcher.as_ref() { + if !v.check(ctx.canister.id, &ctx.authority) { + return Err(ErrorCause::DomainCanisterMismatch); + } } Ok(next.run(request).await) diff --git a/src/routing/mod.rs b/src/routing/mod.rs index 6158851..694684c 100644 --- a/src/routing/mod.rs +++ b/src/routing/mod.rs @@ -1,21 +1,29 @@ +pub mod body; pub mod canister; +pub mod error_cause; pub mod middleware; -use std::{fmt, sync::Arc}; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; use anyhow::Error; use axum::{ + body::Body, + extract::{Request, State}, middleware::{from_fn, from_fn_with_state, FromFnLayer}, response::{IntoResponse, Response}, + routing::{get, post}, Router, }; use axum_extra::middleware::option_layer; +use derive_new::new; use fqdn::FQDN; -use http::StatusCode; use prometheus::Registry; -use strum_macros::Display; use tower::ServiceBuilder; use tracing::warn; +use url::Url; use crate::{ cli::Cli, @@ -25,7 +33,10 @@ use crate::{ routing::middleware::{geoip, headers, policy, request_id, validate}, }; -use self::canister::{Canister, ResolvesCanister}; +use { + canister::{Canister, ResolvesCanister}, + error_cause::ErrorCause, +}; pub struct RequestCtx { // HTTP2 authority or HTTP1 Host header @@ -33,146 +44,65 @@ pub struct RequestCtx { pub canister: Canister, } -#[derive(Debug, Clone, Display)] -#[strum(serialize_all = "snake_case")] -pub enum RateLimitCause { - Normal, - LedgerTransfer, -} - -// Categorized possible causes for request processing failures -// Not using Error as inner type since it's not cloneable -#[derive(Debug, Clone)] -pub enum ErrorCause { - UnableToReadBody(String), - PayloadTooLarge(usize), - UnableToParseCBOR(String), - UnableToParseHTTPArg(String), - LoadShed, - MalformedRequest(String), - MalformedResponse(String), - NoAuthority, - CanisterIdNotFound, - SNIMismatch, - DomainCanisterMismatch, - Denylisted, - NoRoutingTable, - SubnetNotFound, - NoHealthyNodes, - ReplicaErrorDNS(String), - ReplicaErrorConnect, - ReplicaTimeout, - ReplicaTLSErrorOther(String), - ReplicaTLSErrorCert(String), - ReplicaErrorOther(String), - RateLimited(RateLimitCause), - Other(String), -} - -impl ErrorCause { - pub const fn status_code(&self) -> StatusCode { - match self { - Self::Other(_) => StatusCode::INTERNAL_SERVER_ERROR, - Self::PayloadTooLarge(_) => StatusCode::PAYLOAD_TOO_LARGE, - Self::UnableToReadBody(_) => StatusCode::REQUEST_TIMEOUT, - Self::UnableToParseCBOR(_) => StatusCode::BAD_REQUEST, - Self::UnableToParseHTTPArg(_) => StatusCode::BAD_REQUEST, - Self::LoadShed => StatusCode::TOO_MANY_REQUESTS, - Self::MalformedRequest(_) => StatusCode::BAD_REQUEST, - Self::MalformedResponse(_) => StatusCode::INTERNAL_SERVER_ERROR, - Self::NoAuthority => StatusCode::BAD_REQUEST, - Self::CanisterIdNotFound => StatusCode::BAD_REQUEST, - Self::SNIMismatch => StatusCode::BAD_REQUEST, - Self::DomainCanisterMismatch => StatusCode::FORBIDDEN, - Self::Denylisted => StatusCode::UNAVAILABLE_FOR_LEGAL_REASONS, - Self::NoRoutingTable => StatusCode::SERVICE_UNAVAILABLE, - Self::SubnetNotFound => StatusCode::BAD_REQUEST, // TODO change to 404? - Self::NoHealthyNodes => StatusCode::SERVICE_UNAVAILABLE, - Self::ReplicaErrorDNS(_) => StatusCode::SERVICE_UNAVAILABLE, - Self::ReplicaErrorConnect => StatusCode::SERVICE_UNAVAILABLE, - Self::ReplicaTimeout => StatusCode::INTERNAL_SERVER_ERROR, - Self::ReplicaTLSErrorOther(_) => StatusCode::SERVICE_UNAVAILABLE, - Self::ReplicaTLSErrorCert(_) => StatusCode::SERVICE_UNAVAILABLE, - Self::ReplicaErrorOther(_) => StatusCode::INTERNAL_SERVER_ERROR, - Self::RateLimited(_) => StatusCode::TOO_MANY_REQUESTS, - } - } - - pub fn details(&self) -> Option { - match self { - Self::Other(x) => Some(x.clone()), - Self::PayloadTooLarge(x) => Some(format!("maximum body size is {x} bytes")), - Self::UnableToReadBody(x) => Some(x.clone()), - Self::UnableToParseCBOR(x) => Some(x.clone()), - Self::UnableToParseHTTPArg(x) => Some(x.clone()), - Self::LoadShed => Some("Overloaded".into()), - Self::MalformedRequest(x) => Some(x.clone()), - Self::MalformedResponse(x) => Some(x.clone()), - Self::ReplicaErrorDNS(x) => Some(x.clone()), - Self::ReplicaTLSErrorOther(x) => Some(x.clone()), - Self::ReplicaTLSErrorCert(x) => Some(x.clone()), - Self::ReplicaErrorOther(x) => Some(x.clone()), - _ => None, - } - } - - pub const fn retriable(&self) -> bool { - matches!( - self, - Self::ReplicaErrorDNS(_) - | Self::ReplicaErrorConnect - | Self::ReplicaTLSErrorOther(_) - | Self::ReplicaTLSErrorCert(_) - ) - } +// Proxies provided Axum request to a given URL using Reqwest Client trait object and returns Axum response +async fn proxy( + url: Url, + request: Request, + http_client: &Arc, +) -> Result { + // Convert Axum request into Reqwest one + let (parts, body) = request.into_parts(); + let mut request = reqwest::Request::new(parts.method.clone(), url); + *request.headers_mut() = parts.headers; + // Use SyncBodyDataStream wrapper that is Sync (Axum body is !Sync) + *request.body_mut() = Some(reqwest::Body::wrap_stream(body::SyncBodyDataStream::new( + body, + ))); + + // Execute the request + let response = http_client.execute(request).await?; + let headers = response.headers().clone(); + + // Convert the Reqwest response back to the Axum one + let mut response = Response::builder() + .status(response.status()) + .body(Body::from_stream(response.bytes_stream()))?; + *response.headers_mut() = headers; + + Ok(response) } -impl fmt::Display for ErrorCause { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Self::Other(_) => write!(f, "general_error"), - Self::UnableToReadBody(_) => write!(f, "unable_to_read_body"), - Self::PayloadTooLarge(_) => write!(f, "payload_too_large"), - Self::UnableToParseCBOR(_) => write!(f, "unable_to_parse_cbor"), - Self::UnableToParseHTTPArg(_) => write!(f, "unable_to_parse_http_arg"), - Self::LoadShed => write!(f, "load_shed"), - Self::MalformedRequest(_) => write!(f, "malformed_request"), - Self::MalformedResponse(_) => write!(f, "malformed_response"), - Self::CanisterIdNotFound => write!(f, "canister_id_not_found"), - Self::SNIMismatch => write!(f, "sni_mismatch"), - Self::DomainCanisterMismatch => write!(f, "domain_canister_mismatch"), - Self::Denylisted => write!(f, "denylisted"), - Self::NoAuthority => write!(f, "no_authority"), - Self::NoRoutingTable => write!(f, "no_routing_table"), - Self::SubnetNotFound => write!(f, "subnet_not_found"), - Self::NoHealthyNodes => write!(f, "no_healthy_nodes"), - Self::ReplicaErrorDNS(_) => write!(f, "replica_error_dns"), - Self::ReplicaErrorConnect => write!(f, "replica_error_connect"), - Self::ReplicaTimeout => write!(f, "replica_timeout"), - Self::ReplicaTLSErrorOther(_) => write!(f, "replica_tls_error"), - Self::ReplicaTLSErrorCert(_) => write!(f, "replica_tls_error_cert"), - Self::ReplicaErrorOther(_) => write!(f, "replica_error_other"), - Self::RateLimited(x) => write!(f, "rate_limited_{x}"), - } - } +#[derive(new)] +struct IssuerProxyState { + http_client: Arc, + issuers: Vec, + #[new(default)] + next: AtomicUsize, } -// Creates the response from ErrorCause and injects itself into extensions to be visible by middleware -impl IntoResponse for ErrorCause { - fn into_response(self) -> Response { - let mut body = self.to_string(); - - if let Some(v) = self.details() { - body = format!("{body}: {v}"); - } - - let mut resp = (self.status_code(), format!("{body}\n")).into_response(); - resp.extensions_mut().insert(self); - resp - } +// Proxies /registrations endpoint to the certificate issuers if they're defined +async fn issuer_proxy( + State(state): State>, + request: Request, +) -> Result { + // Extract path part from the request + let path = request.uri().path(); + + // Pick next issuer using round-robin & generate request URL for it + let next = state.next.fetch_add(1, Ordering::SeqCst) % state.issuers.len(); + let url = state.issuers[next] + .clone() + .join(path) + .map_err(|_| ErrorCause::MalformedRequest("unable to parse path as URL part".into()))?; + + let response = proxy(url, request, &state.http_client) + .await + .map_err(ErrorCause::from)?; + + Ok(response) } -async fn handler(request: axum::extract::Request) -> impl IntoResponse { +async fn handler(request: Request) -> impl IntoResponse { warn!("{:?}", request.extensions().get::>()); warn!("{:?}", request.extensions().get::()); } @@ -195,7 +125,9 @@ pub fn setup_router( .transpose()?; // Policy - let (policy_state, denylist_runner) = policy::PolicyState::new(cli, http_client, registry)?; + let (policy_state, denylist_runner) = + policy::PolicyState::new(cli, http_client.clone(), registry)?; + let policy_mw = policy_state.map(|x| from_fn_with_state(Arc::new(x), policy::middleware)); // Metrics let metrics_mw = from_fn_with_state( @@ -210,14 +142,32 @@ pub fn setup_router( .layer(metrics_mw) .layer(option_layer(geoip_mw)) .layer(from_fn_with_state(canister_resolver, validate::middleware)) - .layer(from_fn_with_state( - Arc::new(policy_state), - policy::middleware, - )); + .layer(option_layer(policy_mw)); let router = axum::Router::new() - .route("/", axum::routing::get(handler)) + .route("/", get(handler)) .layer(common_layers); + // Setup issuer proxy endpoint if we have them configured + let router = if !cli.cert.issuer_urls.is_empty() { + // Strip possible path from URLs + let mut urls = cli.cert.issuer_urls.clone(); + urls.iter_mut().for_each(|x| x.set_path("")); + + let state = Arc::new(IssuerProxyState::new(http_client, urls)); + + router + .route( + "/registrations/:id", + get(issuer_proxy) + .put(issuer_proxy) + .delete(issuer_proxy) + .with_state(state.clone()), + ) + .route("/registrations", post(issuer_proxy).with_state(state)) + } else { + router + }; + Ok((router, denylist_runner)) } diff --git a/src/tls/acme.rs b/src/tls/acme.rs index 78d8865..80fa5a6 100644 --- a/src/tls/acme.rs +++ b/src/tls/acme.rs @@ -64,7 +64,7 @@ impl Run for AcmeTlsAlpn { // Kick the ACME process forward res = state.next() => { match res.unwrap() { - Ok(v) => warn!("AcmeTlsAlpn: success: event {v:?}"), + Ok(v) => warn!("AcmeTlsAlpn: success: {v:?}"), Err(e) => warn!("AcmeTlsAlpn: error: {e}"), } } diff --git a/src/tls/cert/providers/syncer/mod.rs b/src/tls/cert/providers/syncer/mod.rs index dbeb80a..4b9cca3 100644 --- a/src/tls/cert/providers/syncer/mod.rs +++ b/src/tls/cert/providers/syncer/mod.rs @@ -52,7 +52,8 @@ pub struct CertificatesImporter { } impl CertificatesImporter { - pub fn new(http_client: Arc, exporter_url: Url) -> Self { + pub fn new(http_client: Arc, mut exporter_url: Url) -> Self { + exporter_url.set_path(""); let exporter_url = exporter_url.join("/certificates").unwrap(); Self { diff --git a/src/tls/cert/storage.rs b/src/tls/cert/storage.rs index e8765c6..5419823 100644 --- a/src/tls/cert/storage.rs +++ b/src/tls/cert/storage.rs @@ -88,15 +88,14 @@ impl StoresCertificates for Storage { // Implement certificate resolving for Rustls impl resolver::ResolvesServerCert for StorageKey { fn resolve(&self, ch: &ClientHello) -> Option> { - // See if client provided us with an SNI - let sni = ch.server_name()?; - // Make sure we've got an ALPN list and they're all HTTP, otherwise refuse resolving. // This is to make sure we don't answer to e.g. ACME challenges here if !ch.alpn()?.all(tls::is_http_alpn) { return None; } + // See if client provided us with an SNI + let sni = ch.server_name()?; self.lookup_cert(sni) } } diff --git a/src/tls/mod.rs b/src/tls/mod.rs index f2282a0..6f808d8 100644 --- a/src/tls/mod.rs +++ b/src/tls/mod.rs @@ -19,7 +19,7 @@ use rustls_acme::acme::ACME_TLS_ALPN_NAME; use crate::{ cli::Cli, core::Runner, - http, + http::{is_http_alpn, Client, ALPN_H1, ALPN_H2}, tls::{ cert::{providers, Aggregator}, resolver::{AggregatingResolver, ResolvesServerCert}, @@ -28,14 +28,6 @@ use crate::{ use cert::{providers::ProvidesCertificates, storage::StoresCertificates}; -const ALPN_H1: &[u8] = b"http/1.1"; -const ALPN_H2: &[u8] = b"h2"; -const ALPN_HTTP: &[&[u8]] = &[ALPN_H1, ALPN_H2]; - -pub fn is_http_alpn(alpn: &[u8]) -> bool { - ALPN_HTTP.contains(&alpn) -} - pub fn prepare_server_config( resolver: Arc, ) -> ServerConfig { @@ -74,11 +66,10 @@ pub fn prepare_client_config() -> ClientConfig { } // Prepares the stuff needed for serving TLS -#[allow(clippy::type_complexity)] pub fn setup( cli: &Cli, domains: Vec, - http_client: Arc, + http_client: Arc, storage: Arc>>, cert_resolver: Arc, ) -> Result<(Vec, ServerConfig), Error> { diff --git a/src/tls/resolver.rs b/src/tls/resolver.rs index fbd0f54..744f8aa 100644 --- a/src/tls/resolver.rs +++ b/src/tls/resolver.rs @@ -5,15 +5,15 @@ use rustls::{ sign::CertifiedKey, }; -// Custom ResolvesServerCert trait that takes ClientHello by reference. +// Custom ResolvesServerCert trait that borrows ClientHello. // It's needed because Rustls' ResolvesServerCert consumes ClientHello // https://github.com/rustls/rustls/issues/1908 pub trait ResolvesServerCert: Debug + Send + Sync { fn resolve(&self, client_hello: &ClientHello) -> Option>; } -// Combines several certificate resolvers into one -// Only one Rustls-compatible resolver can be used (acme) since it takes ClientHello by value +// Combines several certificate resolvers into one. +// Only one Rustls-compatible resolver can be used since it consumes ClientHello. #[derive(Debug, derive_new::new)] pub struct AggregatingResolver { rustls: Option>, @@ -25,11 +25,11 @@ impl ResolvesServerCertRustls for AggregatingResolver { fn resolve(&self, ch: ClientHello) -> Option> { // Iterate over our resolvers to find matching cert if any let cert = self.resolvers.iter().find_map(|x| x.resolve(&ch)); - if let Some(v) = cert { - return Some(v); + if cert.is_some() { + return cert; } - // Otherwise try the ACME resolver with Rustls trait that consumes ClientHello + // Otherwise try the Rustls-compatible resolver that consumes ClientHello self.rustls.as_ref().and_then(|x| x.resolve(ch)) } }