diff --git a/shotover-proxy/src/runner.rs b/shotover-proxy/src/runner.rs index 1058909fb..ef6079042 100644 --- a/shotover-proxy/src/runner.rs +++ b/shotover-proxy/src/runner.rs @@ -1,3 +1,4 @@ +use std::env; use std::net::SocketAddr; use anyhow::{anyhow, Result}; @@ -9,6 +10,7 @@ use tokio::sync::broadcast; use tokio::task::JoinHandle; use tracing::{debug, error, info}; use tracing_appender::non_blocking::{NonBlocking, WorkerGuard}; +use tracing_subscriber::filter::Directive; use tracing_subscriber::fmt::format::{DefaultFields, Format}; use tracing_subscriber::fmt::Layer; use tracing_subscriber::layer::Layered; @@ -68,7 +70,7 @@ impl Runner { .build() .unwrap(); - let tracing = TracingState::new(config.main_log_level.as_str()); + let tracing = TracingState::new(config.main_log_level.as_str())?; Ok(Runner { runtime, @@ -127,13 +129,39 @@ struct TracingState { Handle, Registry>>, } +/// Returns a new `EnvFilter` by parsing each directive string, or an error if any directive is invalid. +/// The parsing is robust to formatting, but will reject the first invalid directive (e.g. bad log level). +fn try_parse_log_directives(directives: &[Option<&str>]) -> Result { + let directives: Vec = directives + .iter() + .flat_map(Option::as_deref) + .flat_map(|s| s.split(',')) + .map(str::trim) + .filter(|s| !s.is_empty()) + .map(|s| s.parse().map_err(|e| anyhow!("{}: {}", e, s))) + .collect::>()?; + + let filter = directives + .into_iter() + .fold(EnvFilter::default(), |filter, directive| { + filter.add_directive(directive) + }); + + Ok(filter) +} + impl TracingState { - fn new(log_level: &str) -> Self { + fn new(log_level: &str) -> Result { let (non_blocking, guard) = tracing_appender::non_blocking(std::io::stdout()); let builder = tracing_subscriber::fmt() .with_writer(non_blocking) - .with_env_filter(log_level) + .with_env_filter({ + // Load log directives from shotover config and then from the RUST_LOG env var, with the latter taking priority. + // In the future we might be able to simplify the implementation if work is done on tokio-rs/tracing#1466. + let overrides = env::var(EnvFilter::DEFAULT_ENV).ok(); + try_parse_log_directives(&[Some(log_level), overrides.as_deref()])? + }) .with_filter_reloading(); let handle = builder.reload_handle(); @@ -141,7 +169,7 @@ impl TracingState { // Currently the implementation of try_init will only fail when it is called multiple times. builder.try_init().ok(); - TracingState { guard, handle } + Ok(TracingState { guard, handle }) } } @@ -185,3 +213,23 @@ pub async fn run( } } } + +#[test] +fn test_try_parse_log_directives() { + assert_eq!( + try_parse_log_directives(&[ + Some("info,short=warn,error"), + None, + Some("debug"), + Some("alongname=trace") + ]) + .unwrap() + .to_string(), + // Ordered by descending specificity. + "alongname=trace,short=warn,debug" + ); + match try_parse_log_directives(&[Some("good=info,bad=blah,warn")]) { + Ok(_) => panic!(), + Err(e) => assert_eq!(e.to_string(), "invalid filter directive: bad=blah"), + } +}