From f20542c9154fe28b77a35b85e7cf882da4700a1a Mon Sep 17 00:00:00 2001 From: Igor Novgorodov Date: Tue, 30 Apr 2024 16:42:23 +0200 Subject: [PATCH] metrics module --- src/cli.rs | 10 + src/core.rs | 23 ++- src/http/mod.rs | 20 ++ src/http/server.rs | 4 + src/main.rs | 2 +- src/metrics/body.rs | 27 ++- src/metrics/mod.rs | 275 +++++++++++++++++++++++++-- src/routing/canister.rs | 74 +++++-- src/routing/middleware/request_id.rs | 10 +- src/routing/middleware/validate.rs | 10 +- src/routing/mod.rs | 8 +- 11 files changed, 407 insertions(+), 56 deletions(-) diff --git a/src/cli.rs b/src/cli.rs index 6cc5a9f..9bb56d1 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -38,6 +38,9 @@ pub struct Cli { #[command(flatten, next_help_heading = "Policy")] pub policy: Policy, + #[command(flatten, next_help_heading = "Metrics")] + pub metrics: Metrics, + #[command(flatten, next_help_heading = "Misc")] pub misc: Misc, } @@ -175,6 +178,13 @@ pub struct Policy { pub denylist_poll_interval: Duration, } +#[derive(Args)] +pub struct Metrics { + /// Where to listen for Prometheus metrics scraping + #[clap(long = "metrics-listen")] + pub listen: Option, +} + #[derive(Args)] pub struct Misc { /// Path to a GeoIP database diff --git a/src/core.rs b/src/core.rs index e7f22e4..0fe897a 100644 --- a/src/core.rs +++ b/src/core.rs @@ -2,13 +2,14 @@ use anyhow::Error; use async_trait::async_trait; use prometheus::Registry; use rustls::sign::CertifiedKey; -use std::sync::Arc; +use std::{sync::Arc, time::Duration}; use tokio_util::{sync::CancellationToken, task::TaskTracker}; use tracing::{error, warn}; use crate::{ cli::Cli, http::{server, ReqwestClient, Server}, + metrics, routing::{ self, canister::{CanisterResolver, ResolvesCanister}, @@ -30,11 +31,10 @@ pub trait Run: Send + Sync { } pub async fn main(cli: &Cli) -> Result<(), Error> { + // Prepare some general stuff let token = CancellationToken::new(); let tracker = TaskTracker::new(); - let registry = Registry::new(); - let http_client = Arc::new(ReqwestClient::new(cli)?); // Handle SIGTERM/SIGHUP and Ctrl+C @@ -46,14 +46,17 @@ pub async fn main(cli: &Cli) -> Result<(), Error> { let mut domains = cli.domain.domains_system.clone(); domains.extend(cli.domain.domains_app.clone()); + // Prepare certificate storage let storage = Arc::new(Storage::new()); + + // Prepare canister resolver to infer canister_id from requests let canister_resolver = CanisterResolver::new( domains, cli.domain.canister_aliases.clone(), storage.clone() as Arc, )?; - // List of cancellable tasks to execute & watch + // List of cancellable tasks to execute & track let mut runners: Vec<(String, Arc)> = vec![]; // Create a router @@ -68,6 +71,7 @@ pub async fn main(cli: &Cli) -> Result<(), Error> { } let server_options = server::Options::from(&cli.http_server); + // Set up HTTP let http_server = Arc::new(Server::new( cli.http_server.http, @@ -94,7 +98,16 @@ pub async fn main(cli: &Cli) -> Result<(), Error> { )) as Arc; runners.push(("https_server".into(), https_server)); - // Spawn runners + // Setup metrics + if let Some(addr) = cli.metrics.listen { + let (router, runner) = metrics::setup(®istry); + runners.push(("metrics_runner".into(), runner)); + + let srv = Arc::new(Server::new(addr, router, server_options, None)); + runners.push(("metrics_server".into(), srv as Arc)); + } + + // Spawn & track runners for (name, obj) in runners { let token = token.child_token(); tracker.spawn(async move { diff --git a/src/http/mod.rs b/src/http/mod.rs index 6d1a77e..4b2841d 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -2,5 +2,25 @@ pub mod client; pub mod dns; pub mod server; +use http::{HeaderMap, Version}; + pub use client::{Client, ReqwestClient}; pub use server::{ConnInfo, Server}; + +// Calculate very approximate HTTP request/response headers size in bytes. +// More or less accurate only for http/1.1 since in h2 headers are in HPACK-compressed. +// But it seems there's no better way. +pub fn calc_headers_size(h: &HeaderMap) -> usize { + h.iter().map(|(k, v)| k.as_str().len() + v.len() + 2).sum() +} + +pub const fn http_version(v: Version) -> &'static str { + match v { + Version::HTTP_09 => "0.9", + Version::HTTP_10 => "1.0", + Version::HTTP_11 => "1.1", + Version::HTTP_2 => "2.0", + Version::HTTP_3 => "3.0", + _ => "-", + } +} diff --git a/src/http/server.rs b/src/http/server.rs index 0465b95..a110b01 100644 --- a/src/http/server.rs +++ b/src/http/server.rs @@ -90,6 +90,7 @@ impl TryFrom<&ServerConnection> for TlsInfo { #[derive(Clone, Debug)] pub struct ConnInfo { + pub accepted_at: Instant, pub local_addr: SocketAddr, pub remote_addr: SocketAddr, pub tls: Option, @@ -137,6 +138,8 @@ impl Conn { } pub async fn handle(&self, stream: TcpStream) -> Result<(), Error> { + let accepted_at = Instant::now(); + debug!( "Server {}: {}: got a new connection", self.addr, self.remote_addr @@ -164,6 +167,7 @@ impl Conn { // Since it will be cloned for each request served over this connection // it's probably better to wrap it into Arc let conn_info = ConnInfo { + accepted_at, local_addr: self.addr, remote_addr: self.remote_addr, tls: tls_info, diff --git a/src/main.rs b/src/main.rs index 712f7fb..1c7ee11 100644 --- a/src/main.rs +++ b/src/main.rs @@ -25,7 +25,7 @@ async fn main() -> Result<(), Error> { let cli = Cli::parse(); let subscriber = tracing_subscriber::FmtSubscriber::builder() - .with_max_level(tracing::Level::INFO) + .with_max_level(tracing::Level::DEBUG) .finish(); tracing::subscriber::set_global_default(subscriber)?; diff --git a/src/metrics/body.rs b/src/metrics/body.rs index 6387dfe..f48d099 100644 --- a/src/metrics/body.rs +++ b/src/metrics/body.rs @@ -6,6 +6,8 @@ use std::{ task::{Context, Poll}, }; +use crate::http::calc_headers_size; + // Body that counts the bytes streamed pub struct CountingBody { inner: Pin + Send + 'static>>, @@ -74,21 +76,22 @@ where // There is still some data available Poll::Ready(Some(v)) => match v { Ok(buf) => { - // Ignore if it's not a data frame for now. - // It can also be trailers that are uncommon + // Normal data frame if buf.is_data() { self.bytes_sent += buf.data_ref().unwrap().remaining() as u64; + } else if buf.is_trailers() { + // Trailers are very uncommon, for the sake of completeness + self.bytes_sent += calc_headers_size(buf.trailers_ref().unwrap()) as u64; + } - // Check if we already got what was expected - if Some(self.bytes_sent) >= self.expected_size { - self.do_callback(Ok(())); - } + // Check if we already got what was expected + if Some(self.bytes_sent) >= self.expected_size { + self.do_callback(Ok(())); } } // Error occured, execute callback Err(e) => { - // Error is not Copy/Clone so use string instead self.do_callback(Err(e.to_string())); } }, @@ -117,8 +120,13 @@ mod test { #[tokio::test] async fn test_body_stream() { - let data = b"foobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblah"; - let mut stream = tokio_util::io::ReaderStream::new(&data[..]); + let data = b"foobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarbl\ + ahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahbla\ + hfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoob\ + arblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarblahblahfoobarbla\ + blahfoobarblahblah"; + + let stream = tokio_util::io::ReaderStream::new(&data[..]); let body = axum::body::Body::from_stream(stream); let (tx, rx) = std::sync::mpsc::channel(); @@ -141,7 +149,6 @@ mod test { #[tokio::test] async fn test_body_full() { let data = vec![0; 512]; - let buf = bytes::Bytes::from_iter(data.clone()); let body = http_body_util::Full::new(buf); diff --git a/src/metrics/mod.rs b/src/metrics/mod.rs index 0529d91..2a38248 100644 --- a/src/metrics/mod.rs +++ b/src/metrics/mod.rs @@ -4,26 +4,42 @@ use std::{ net::SocketAddr, pin::Pin, sync::{atomic::AtomicBool, Arc}, - time::Instant, + time::{Duration, Instant}, }; +use anyhow::Error; use axum::{ + async_trait, + body::Body, extract::{Extension, Request, State}, middleware::Next, response::{IntoResponse, Response}, + routing::get, + Router, }; +use http::header::CONTENT_TYPE; +use hyper::client::conn; +use jemalloc_ctl::{epoch, stats}; use prometheus::{ proto::MetricFamily, register_histogram_vec_with_registry, register_int_counter_vec_with_registry, register_int_gauge_vec_with_registry, register_int_gauge_with_registry, Encoder, HistogramOpts, HistogramVec, IntCounterVec, IntGauge, IntGaugeVec, Registry, TextEncoder, }; -use tracing::warn; +use tokio::{select, sync::RwLock}; +use tokio_util::sync::CancellationToken; +use tower_http::compression::CompressionLayer; +use tracing::{debug, info, warn}; -use crate::routing::middleware::request_id::RequestId; +use crate::{ + core::Run, + http::{calc_headers_size, http_version, server::ConnInfo}, + routing::{middleware::request_id::RequestId, ErrorCause, RequestCtx}, +}; use body::CountingBody; const KB: f64 = 1024.0; +const METRICS_CACHE_CAPACITY: usize = 15 * 1024 * 1024; pub const HTTP_DURATION_BUCKETS: &[f64] = &[0.05, 0.2, 1.0, 2.0]; pub const HTTP_REQUEST_SIZE_BUCKETS: &[f64] = &[128.0, KB, 2.0 * KB, 4.0 * KB, 8.0 * KB]; @@ -32,20 +48,148 @@ pub const HTTP_RESPONSE_SIZE_BUCKETS: &[f64] = &[1.0 * KB, 8.0 * KB, 64.0 * KB, // https://prometheus.io/docs/instrumenting/exposition_formats/#basic-info const PROMETHEUS_CONTENT_TYPE: &str = "text/plain; version=0.0.4"; +pub struct MetricsCache { + buffer: Vec, +} + +impl MetricsCache { + pub fn new(capacity: usize) -> Self { + Self { + // Preallocate a large enough vector, it'll be expanded if needed + buffer: Vec::with_capacity(capacity), + } + } +} + +pub struct MetricsRunner { + metrics_cache: Arc>, + registry: Registry, + encoder: TextEncoder, + mem_allocated: IntGauge, + mem_resident: IntGauge, +} + +// Snapshots & encodes the metrics for the handler to export +impl MetricsRunner { + pub fn new(metrics_cache: Arc>, registry: &Registry) -> Self { + let mem_allocated = register_int_gauge_with_registry!( + format!("memory_allocated"), + format!("Allocated memory in bytes"), + registry + ) + .unwrap(); + + let mem_resident = register_int_gauge_with_registry!( + format!("memory_resident"), + format!("Resident memory in bytes"), + registry + ) + .unwrap(); + + Self { + metrics_cache, + registry: registry.clone(), + encoder: TextEncoder::new(), + mem_allocated, + mem_resident, + } + } +} + +impl MetricsRunner { + async fn update(&self) -> Result<(), Error> { + // Record jemalloc memory usage + epoch::advance().unwrap(); + self.mem_allocated + .set(stats::allocated::read().unwrap() as i64); + self.mem_resident + .set(stats::resident::read().unwrap() as i64); + + // Get a snapshot of metrics + let metric_families = self.registry.gather(); + + // Take a write lock, truncate the vector and encode the metrics into it + let mut metrics_cache = self.metrics_cache.write().await; + metrics_cache.buffer.clear(); + self.encoder + .encode(&metric_families, &mut metrics_cache.buffer)?; + drop(metrics_cache); // clippy + + Ok(()) + } +} + +#[async_trait] +impl Run for MetricsRunner { + async fn run(&self, token: CancellationToken) -> Result<(), Error> { + let mut interval = tokio::time::interval(Duration::from_secs(5)); + + warn!("MetricsRunner: started"); + loop { + select! { + biased; + + () = token.cancelled() => { + warn!("MetricsRunner: exited"); + return Ok(()); + } + + _ = interval.tick() => { + let start = Instant::now(); + if let Err(e) = self.update().await { + warn!("Unable to update metrics: {e}"); + } else { + debug!("Metrics updated in {}ms", start.elapsed().as_millis()); + } + } + } + } + } +} + +pub async fn handler(State(state): State>>) -> impl IntoResponse { + // Get a read lock and clone the buffer contents + ( + [(CONTENT_TYPE, PROMETHEUS_CONTENT_TYPE)], + state.read().await.buffer.clone(), + ) +} + +pub fn setup(registry: &Registry) -> (Router, Arc) { + let cache = Arc::new(RwLock::new(MetricsCache::new(METRICS_CACHE_CAPACITY))); + let runner = Arc::new(MetricsRunner::new(cache.clone(), registry)); + + let router = Router::new() + .route("/metrics", get(handler)) + .layer( + CompressionLayer::new() + .gzip(true) + .br(true) + .zstd(true) + .deflate(true), + ) + .with_state(cache); + + (router, runner as Arc) +} + #[derive(Clone)] pub struct HttpMetricParams { - pub counter: IntCounterVec, - pub durationer: HistogramVec, - pub request_sizer: HistogramVec, - pub response_sizer: HistogramVec, + pub requests: IntCounterVec, + pub duration: HistogramVec, + pub duration_full: HistogramVec, + pub request_size: HistogramVec, + pub response_size: HistogramVec, } impl HttpMetricParams { pub fn new(registry: &Registry) -> Self { - const LABELS_HTTP: &[&str] = &["domain", "status_code", "error_cause", "cache_status"]; + const LABELS_HTTP: &[&str] = &[ + "tls", "method", "http", "domain", "status", "error", "cache", + ]; Self { - counter: register_int_counter_vec_with_registry!( + requests: register_int_counter_vec_with_registry!( format!("http_total"), format!("Counts occurrences of requests"), LABELS_HTTP, @@ -53,7 +197,7 @@ impl HttpMetricParams { ) .unwrap(), - durationer: register_histogram_vec_with_registry!( + duration: register_histogram_vec_with_registry!( format!("http_duration_sec"), format!("Records the duration of request processing in seconds"), LABELS_HTTP, @@ -62,7 +206,16 @@ impl HttpMetricParams { ) .unwrap(), - request_sizer: register_histogram_vec_with_registry!( + duration_full: register_histogram_vec_with_registry!( + format!("http_duration_full_sec"), + format!("Records the full duration of request processing including response streaming in seconds"), + LABELS_HTTP, + HTTP_DURATION_BUCKETS.to_vec(), + registry + ) + .unwrap(), + + request_size: register_histogram_vec_with_registry!( format!("http_request_size"), format!("Records the size of requests"), LABELS_HTTP, @@ -71,7 +224,7 @@ impl HttpMetricParams { ) .unwrap(), - response_sizer: register_histogram_vec_with_registry!( + response_size: register_histogram_vec_with_registry!( format!("http_response_size"), format!("Records the size of responses"), LABELS_HTTP, @@ -84,17 +237,107 @@ impl HttpMetricParams { } pub async fn middleware( - State(state): State, + State(state): State>, + Extension(conn_info): Extension>, Extension(request_id): Extension, request: Request, next: Next, ) -> impl IntoResponse { + // Prepare to execute the request and count its body size + let (parts, body) = request.into_parts(); + let (tx, rx) = std::sync::mpsc::sync_channel(1); + let request_callback = move |size: u64, _: Result<(), String>| { + let _ = tx.send(size); + }; + let body = Body::new(CountingBody::new(body, request_callback)); + let request = Request::from_parts(parts, body); + + // Gather needed stuff from request before it's consumed + let method = request.method().clone(); + let http_version = http_version(request.version()); + let request_size_headers = calc_headers_size(request.headers()) as u64; + let uri = request.uri().clone(); + + // Execute the request let response = next.run(request).await; + let duration = conn_info.accepted_at.elapsed(); + + let ctx = response.extensions().get::>().cloned(); + let error_cause = response.extensions().get::().cloned(); + let status = response.status().as_u16().to_string(); + + // By this time the channel should already have the data + // since the response headers are already received -> request body was for sure read + let request_size = rx.recv().unwrap_or(0) + request_size_headers; + let (parts, body) = response.into_parts(); - let record_metrics = - move |response_size: u64, _body_result: Result<(), String>| warn!("{}", response_size); + let response_callback = move |response_size: u64, _: Result<(), String>| { + let duration_full = conn_info.accepted_at.elapsed(); + + let (tls_version, tls_cipher) = conn_info + .tls + .as_ref() + .map(|x| (x.protocol.as_str().unwrap(), x.cipher.as_str().unwrap())) + .unwrap_or(("unknown", "unknown")); + let domain = ctx + .as_ref() + .map(|x| x.canister.domain.to_string()) + .unwrap_or_else(|| "unknown".into()); + let error_cause = error_cause + .clone() + .map(|x| x.to_string()) + .unwrap_or_else(|| "no".into()); + + let labels = &[ + tls_version, + method.as_str(), + http_version, + &domain, + &status, + &error_cause, + "BYPASS", // TODO fill when cache is implemented + ]; + + // Update metrics + state.requests.with_label_values(labels).inc(); + state + .duration + .with_label_values(labels) + .observe(duration.as_secs_f64()); + state + .duration_full + .with_label_values(labels) + .observe(duration_full.as_secs_f64()); + state + .request_size + .with_label_values(labels) + .observe(request_size as f64); + state + .response_size + .with_label_values(labels) + .observe(response_size as f64); + + // Log the request + info!( + request_id = request_id.to_string(), + method = method.as_str(), + http = http_version, + status, + tls_version, + tls_cipher, + domain, + host = ctx.as_ref().map(|x| x.authority.to_string()), + path = uri.path(), + canister_id = ctx.as_ref().map(|x| x.canister.id.to_string()), + error = error_cause, + req_size = request_size, + resp_size = response_size, + dur = duration.as_millis(), + dur_full = duration_full.as_millis(), + ); + }; - let body = CountingBody::new(body, record_metrics); + let body = CountingBody::new(body, response_callback); Response::from_parts(parts, body) } diff --git a/src/routing/canister.rs b/src/routing/canister.rs index a61a439..47596d6 100644 --- a/src/routing/canister.rs +++ b/src/routing/canister.rs @@ -6,8 +6,6 @@ use fqdn::{Fqdn, FQDN}; use crate::tls::cert::LooksupCustomDomain; -const INVALID_ALIAS_FORMAT: &str = "Invalid alias format, must be 'alias:canister_id'"; - // Alias for a canister under all served domains. // E.g. an alias 'nns' would resolve under both 'nns.ic0.app' and 'nns.icp0.io' #[derive(Clone)] @@ -17,6 +15,8 @@ impl FromStr for CanisterAlias { type Err = Error; fn from_str(value: &str) -> Result { + const INVALID_ALIAS_FORMAT: &str = "Invalid alias format, must be ':'"; + match value.split_once(':') { Some((alias, principal)) => { if alias.is_empty() { @@ -36,9 +36,10 @@ impl FromStr for CanisterAlias { } // Combination of canister id and whether we need to verify the response -#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[derive(Clone, Debug, Eq, PartialEq)] pub struct Canister { pub id: Principal, + pub domain: FQDN, pub verify: bool, } @@ -67,6 +68,7 @@ impl CanisterResolver { FQDN::from_str(&format!("{}.{d}", a.0))?, Canister { id: a.1, + domain: d.clone(), verify: true, }, )); @@ -86,7 +88,7 @@ impl CanisterResolver { self.aliases .iter() .find(|x| host.is_subdomain_of(&x.0)) - .map(|x| x.1) + .map(|x| x.1.clone()) } // Tries to resolve canister id from . or .raw. formatted hostname @@ -117,7 +119,7 @@ impl CanisterResolver { return None; } - Some(Canister { id, verify }) + Some(Canister { id, domain, verify }) } } @@ -128,7 +130,11 @@ impl ResolvesCanister for CanisterResolver { .or_else(|| self.resolve_domain(host)) .or_else(|| { let id = self.custom_domains.lookup_custom_domain(host)?; - Some(Canister { id, verify: true }) + Some(Canister { + id, + domain: host.to_owned(), + verify: true, + }) }) } } @@ -153,6 +159,9 @@ mod test { let a = CanisterAlias::from_str(":aaaaa-aa"); assert!(a.is_err()); + let a = CanisterAlias::from_str("|||:aaaaa-aa"); + assert!(a.is_err()); + // All is empty let a = CanisterAlias::from_str(":"); assert!(a.is_err()); @@ -198,6 +207,7 @@ mod test { canister, Some(Canister { id: a.1, + domain: d.clone(), verify: true }) ); @@ -225,30 +235,54 @@ mod test { // Normal & raw assert_eq!( resolver.resolve_domain(&fqdn!("aaaaa-aa.ic0.app")), - Some(Canister { id, verify: true }) + Some(Canister { + id, + domain: fqdn!("ic0.app"), + verify: true + }) ); assert_eq!( resolver.resolve_domain(&fqdn!("aaaaa-aa.icp0.io")), - Some(Canister { id, verify: true }) + Some(Canister { + id, + domain: fqdn!("icp0.io"), + verify: true + }) ); assert_eq!( resolver.resolve_domain(&fqdn!("aaaaa-aa.raw.ic0.app")), - Some(Canister { id, verify: false }) + Some(Canister { + id, + domain: fqdn!("ic0.app"), + verify: false + }) ); assert_eq!( resolver.resolve_domain(&fqdn!("aaaaa-aa.raw.icp0.io")), - Some(Canister { id, verify: false }) + Some(Canister { + id, + domain: fqdn!("icp0.io"), + verify: false + }) ); // foo-- assert_eq!( resolver.resolve_domain(&fqdn!("foo--aaaaa-aa.ic0.app")), - Some(Canister { id, verify: true }) + Some(Canister { + id, + domain: fqdn!("ic0.app"), + verify: true + }) ); assert_eq!( resolver.resolve_domain(&fqdn!("asndjasldfajlsd--aaaaa-aa.ic0.app")), - Some(Canister { id, verify: true }) + Some(Canister { + id, + domain: fqdn!("ic0.app"), + verify: true + }) ); // Nested subdomain should not match @@ -267,6 +301,7 @@ mod test { resolver.resolve_canister(&fqdn!("nns.ic0.app")), Some(Canister { id: Principal::from_text("qoctq-giaaa-aaaaa-aaaea-cai").unwrap(), + domain: fqdn!("ic0.app"), verify: true }) ); @@ -274,12 +309,20 @@ mod test { // Resolve from hostname assert_eq!( resolver.resolve_canister(&fqdn!("aaaaa-aa.ic0.app")), - Some(Canister { id, verify: true }) + Some(Canister { + id, + domain: fqdn!("ic0.app"), + verify: true + }) ); assert_eq!( resolver.resolve_canister(&fqdn!("aaaaa-aa.raw.ic0.app")), - Some(Canister { id, verify: false }) + Some(Canister { + id, + domain: fqdn!("ic0.app"), + verify: false + }) ); // Resolve custom domain @@ -287,7 +330,8 @@ mod test { resolver.resolve_canister(&fqdn!("foo.baz")), Some(Canister { id: Principal::from_text(TEST_CANISTER_ID).unwrap(), - verify: true + domain: fqdn!("foo.baz"), + verify: true, }) ); diff --git a/src/routing/middleware/request_id.rs b/src/routing/middleware/request_id.rs index eba74e7..e7b12d1 100644 --- a/src/routing/middleware/request_id.rs +++ b/src/routing/middleware/request_id.rs @@ -1,3 +1,5 @@ +use std::fmt::Display; + use axum::{extract::Request, middleware::Next, response::Response}; use bytes::Bytes; use http::header::{HeaderName, HeaderValue}; @@ -9,7 +11,7 @@ const HEADER: HeaderName = HeaderName::from_static("x-request-id"); #[derive(Clone, Copy)] pub struct RequestId(Uuid); -// Generate & insert request uuid into extensions and headers +// Generate & insert request UUID into extensions and headers pub async fn middleware(mut request: Request, next: Next) -> Response { let request_id = RequestId(Uuid::now_v7()); let hdr = request_id.0.as_hyphenated().to_string(); @@ -23,3 +25,9 @@ pub async fn middleware(mut request: Request, next: Next) -> Response { response.headers_mut().insert(HEADER, hdr); response } + +impl Display for RequestId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0.hyphenated()) + } +} diff --git a/src/routing/middleware/validate.rs b/src/routing/middleware/validate.rs index a5d26f7..769e1a0 100644 --- a/src/routing/middleware/validate.rs +++ b/src/routing/middleware/validate.rs @@ -20,6 +20,7 @@ pub async fn middleware( mut request: Request, next: Next, ) -> Result { + // Extract the authority let authority = match extract_authority(&request) { Some(v) => v, None => return Err(ErrorCause::NoAuthority), @@ -37,13 +38,14 @@ pub async fn middleware( .resolve_canister(&authority) .ok_or(ErrorCause::CanisterIdNotFound)?; - println!("{:?}", canister.id.to_string()); - let ctx = Arc::new(RequestCtx { authority, canister, }); - request.extensions_mut().insert(ctx); + request.extensions_mut().insert(ctx.clone()); + + let mut response = next.run(request).await; + response.extensions_mut().insert(ctx); - Ok(next.run(request).await) + Ok(response) } diff --git a/src/routing/mod.rs b/src/routing/mod.rs index 2c861a7..b63da96 100644 --- a/src/routing/mod.rs +++ b/src/routing/mod.rs @@ -29,8 +29,8 @@ use self::canister::{Canister, ResolvesCanister}; pub struct RequestCtx { // HTTP2 authority or HTTP1 Host header - authority: FQDN, - canister: Canister, + pub authority: FQDN, + pub canister: Canister, } #[derive(Debug, Clone, Display)] @@ -70,7 +70,7 @@ pub enum ErrorCause { } impl ErrorCause { - pub fn status_code(&self) -> StatusCode { + pub const fn status_code(&self) -> StatusCode { match self { Self::Other(_) => StatusCode::INTERNAL_SERVER_ERROR, Self::PayloadTooLarge(_) => StatusCode::PAYLOAD_TOO_LARGE, @@ -200,7 +200,7 @@ pub fn setup_router( // Metrics let metrics_mw = from_fn_with_state( - metrics::HttpMetricParams::new(registry), + Arc::new(metrics::HttpMetricParams::new(registry)), metrics::middleware, );