diff --git a/Cargo.lock b/Cargo.lock index be718f93e..09c09a684 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -247,7 +247,7 @@ dependencies = [ "fastrand", "hex", "http 0.2.11", - "hyper", + "hyper 0.14.28", "ring", "time", "tokio", @@ -283,7 +283,7 @@ dependencies = [ "bytes", "fastrand", "http 0.2.11", - "http-body", + "http-body 0.4.6", "percent-encoding", "pin-project-lite", "tracing", @@ -425,7 +425,7 @@ dependencies = [ "bytes-utils", "futures-core", "http 0.2.11", - "http-body", + "http-body 0.4.6", "once_cell", "percent-encoding", "pin-project-lite", @@ -464,10 +464,10 @@ dependencies = [ "aws-smithy-types", "bytes", "fastrand", - "h2", + "h2 0.3.24", "http 0.2.11", - "http-body", - "hyper", + "http-body 0.4.6", + "hyper 0.14.28", "hyper-rustls", "once_cell", "pin-project-lite", @@ -504,7 +504,7 @@ dependencies = [ "bytes-utils", "futures-core", "http 0.2.11", - "http-body", + "http-body 0.4.6", "itoa", "num-integer", "pin-project-lite", @@ -560,6 +560,58 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1236b4b292f6c4d6dc34604bb5120d85c3fe1d1aa596bd5cc52ca054d13e7b9e" +dependencies = [ + "async-trait", + "axum-core", + "bytes", + "futures-util", + "http 1.0.0", + "http-body 1.0.0", + "http-body-util", + "hyper 1.2.0", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a15c63fd72d41492dc4f497196f5da1fb04fb7529e631d73630d1b491e47a2e3" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http 1.0.0", + "http-body 1.0.0", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "backtrace" version = "0.3.69" @@ -2070,6 +2122,25 @@ dependencies = [ "tracing", ] +[[package]] +name = "h2" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31d030e59af851932b72ceebadf4a2b5986dba4c3b99dd2493f8273a0f151943" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http 1.0.0", + "indexmap 2.2.2", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "half" version = "2.3.1" @@ -2204,6 +2275,29 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "http-body" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643" +dependencies = [ + "bytes", + "http 1.0.0", +] + +[[package]] +name = "http-body-util" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41cb79eb393015dadd30fc252023adb0b2400a0caee0fa2a077e6e21a551e840" +dependencies = [ + "bytes", + "futures-util", + "http 1.0.0", + "http-body 1.0.0", + "pin-project-lite", +] + [[package]] name = "httparse" version = "1.8.0" @@ -2226,9 +2320,9 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "h2", + "h2 0.3.24", "http 0.2.11", - "http-body", + "http-body 0.4.6", "httparse", "httpdate", "itoa", @@ -2240,6 +2334,26 @@ dependencies = [ "want", ] +[[package]] +name = "hyper" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "186548d73ac615b32a73aafe38fb4f56c0d340e110e5a200bcadbaf2e199263a" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "h2 0.4.2", + "http 1.0.0", + "http-body 1.0.0", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", +] + [[package]] name = "hyper-rustls" version = "0.24.2" @@ -2248,7 +2362,7 @@ checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" dependencies = [ "futures-util", "http 0.2.11", - "hyper", + "hyper 0.14.28", "log", "rustls 0.21.10", "rustls-native-certs 0.6.3", @@ -2263,12 +2377,28 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" dependencies = [ "bytes", - "hyper", + "hyper 0.14.28", "native-tls", "tokio", "tokio-native-tls", ] +[[package]] +name = "hyper-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca38ef113da30126bbff9cd1705f9273e15d45498615d138b0c20279ac7a76aa" +dependencies = [ + "bytes", + "futures-util", + "http 1.0.0", + "http-body 1.0.0", + "hyper 1.2.0", + "pin-project-lite", + "socket2 0.5.5", + "tokio", +] + [[package]] name = "iana-time-zone" version = "0.1.60" @@ -2563,6 +2693,12 @@ dependencies = [ "regex-automata 0.1.10", ] +[[package]] +name = "matchit" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" + [[package]] name = "md5" version = "0.7.0" @@ -3134,6 +3270,26 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "pin-project" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0302c4a0442c456bd56f841aee5c3bfd17967563f6fadc9ceb9f9c23cf3807e0" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "266c042b60c9c76b8d53061e52b2e0d1116abc57cefc8c5cd671619a56ac3690" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", +] + [[package]] name = "pin-project-lite" version = "0.2.13" @@ -3642,10 +3798,10 @@ dependencies = [ "encoding_rs", "futures-core", "futures-util", - "h2", + "h2 0.3.24", "http 0.2.11", - "http-body", - "hyper", + "http-body 0.4.6", + "hyper 0.14.28", "hyper-tls", "ipnet", "js-sys", @@ -4296,6 +4452,7 @@ dependencies = [ "atomic_enum", "aws-config", "aws-sdk-kms", + "axum", "backtrace", "backtrace-ext", "base64", @@ -4322,7 +4479,6 @@ dependencies = [ "hex-literal", "http 1.0.0", "httparse", - "hyper", "itertools 0.12.1", "kafka-protocol", "lz4_flex", @@ -4377,7 +4533,6 @@ dependencies = [ "futures", "hex", "hex-literal", - "hyper", "itertools 0.12.1", "nix", "opensearch", @@ -4970,6 +5125,27 @@ dependencies = [ "winnow", ] +[[package]] +name = "tower" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +dependencies = [ + "futures-core", + "futures-util", + "pin-project", + "pin-project-lite", + "tokio", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-layer" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" + [[package]] name = "tower-service" version = "0.3.2" diff --git a/Cargo.toml b/Cargo.toml index e9c990085..8a85558c3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -58,5 +58,4 @@ typetag = "0.2.5" aws-throwaway = { version = "0.6.0", default-features = false } tokio-bin-process = "0.4.0" ordered-float = { version = "4.0.0", features = ["serde"] } -hyper = { version = "0.14.14", features = ["server", "tcp", "http1"] } shell-quote = { default-features = false, version = "0.5.0" } diff --git a/shotover-proxy/Cargo.toml b/shotover-proxy/Cargo.toml index 2cfbeda61..7aed9955f 100644 --- a/shotover-proxy/Cargo.toml +++ b/shotover-proxy/Cargo.toml @@ -54,7 +54,6 @@ cql-ws = { git = "https://github.com/shotover/cql-ws" } opensearch = "2.1.0" serde_json = "1.0.103" time = { version = "0.3.25" } -hyper.workspace = true shell-quote.workspace = true [features] diff --git a/shotover-proxy/tests/transforms/tee.rs b/shotover-proxy/tests/transforms/tee.rs index 5f3bcfa33..78d3b40e3 100644 --- a/shotover-proxy/tests/transforms/tee.rs +++ b/shotover-proxy/tests/transforms/tee.rs @@ -1,5 +1,4 @@ use crate::shotover_process; -use hyper::{body, Body, Client, Method, Request, Response}; use test_helpers::connection::redis_connection; use test_helpers::docker_compose::docker_compose; use test_helpers::shotover_process::{EventMatcher, Level}; @@ -205,23 +204,6 @@ async fn test_subchain_with_mismatch() { shotover.shutdown_and_then_consume_events(&[]).await; } -async fn read_response_body(res: Response) -> Result { - let bytes = body::to_bytes(res.into_body()).await?; - Ok(String::from_utf8(bytes.to_vec()).expect("response was not valid utf-8")) -} - -async fn hyper_request(uri: String, method: Method, body: Body) -> Response { - let client = Client::new(); - - let req = Request::builder() - .method(method) - .uri(uri) - .body(body) - .expect("request builder"); - - client.request(req).await.unwrap() -} - #[tokio::test(flavor = "multi_thread")] async fn test_switch_main_chain() { let shotover = shotover_process("tests/test-configs/tee/switch_chain.yaml") @@ -243,26 +225,11 @@ async fn test_switch_main_chain() { assert_eq!("a", result); - let _ = hyper_request( - format!( - "http://localhost:{}/transform/tee/result-source", - switch_port - ), - Method::PUT, - Body::from("tee-chain"), - ) - .await; + let url = format!("http://localhost:{switch_port}/transform/tee/result-source"); + let client = reqwest::Client::new(); + client.put(&url).body("tee-chain").send().await.unwrap(); - let res = hyper_request( - format!( - "http://localhost:{}/transform/tee/result-source", - switch_port - ), - Method::GET, - Body::empty(), - ) - .await; - let body = read_response_body(res).await.unwrap(); + let body = client.get(&url).send().await.unwrap().text().await.unwrap(); assert_eq!("tee-chain", body); let result = redis::cmd("SET") @@ -274,15 +241,7 @@ async fn test_switch_main_chain() { assert_eq!("b", result); - let _ = hyper_request( - format!( - "http://localhost:{}/transform/tee/result-source", - switch_port - ), - Method::PUT, - Body::from("regular-chain"), - ) - .await; + client.put(&url).body("regular-chain").send().await.unwrap(); let result = redis::cmd("SET") .arg("key") diff --git a/shotover/Cargo.toml b/shotover/Cargo.toml index b0e714bdb..b055f0cff 100644 --- a/shotover/Cargo.toml +++ b/shotover/Cargo.toml @@ -47,6 +47,7 @@ opensearch = [ default = ["cassandra", "redis", "kafka", "opensearch"] [dependencies] +axum = { version = "0.7", default-features = false, features = ["tokio", "tracing", "http1"] } atomic_enum = "0.2.0" pretty-hex = "0.4.0" tokio-stream = "0.1.2" @@ -96,7 +97,6 @@ metrics-exporter-prometheus = { version = "0.13.0", default-features = false } tracing.workspace = true tracing-subscriber.workspace = true tracing-appender.workspace = true -hyper.workspace = true halfbrown = { version = "0.2.1", optional = true } # Transform dependencies diff --git a/shotover/src/http.rs b/shotover/src/http.rs new file mode 100644 index 000000000..f637a8b8b --- /dev/null +++ b/shotover/src/http.rs @@ -0,0 +1,27 @@ +use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; + +// Make our own error that wraps `anyhow::Error`. +pub(crate) struct HttpServerError(pub anyhow::Error); + +// Tell axum how to convert `AppError` into a response. +impl IntoResponse for HttpServerError { + fn into_response(self) -> Response { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("HTTP 500 error: {}", self.0), + ) + .into_response() + } +} + +// This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into +// `Result<_, AppError>`. That way you don't need to do that manually. +impl From for HttpServerError +where + E: Into, +{ + fn from(err: E) -> Self { + Self(err.into()) + } +} diff --git a/shotover/src/lib.rs b/shotover/src/lib.rs index 44180d51c..a80b70749 100644 --- a/shotover/src/lib.rs +++ b/shotover/src/lib.rs @@ -55,6 +55,7 @@ pub mod codec; pub mod config; mod connection_span; pub mod frame; +mod http; pub mod message; mod observability; pub mod runner; diff --git a/shotover/src/observability/mod.rs b/shotover/src/observability/mod.rs index dd96acfe8..4728089e8 100644 --- a/shotover/src/observability/mod.rs +++ b/shotover/src/observability/mod.rs @@ -1,12 +1,8 @@ +use crate::http::HttpServerError; use crate::runner::ReloadHandle; use anyhow::{anyhow, Context, Result}; -use bytes::Bytes; -use hyper::{ - service::{make_service_fn, service_fn}, - Method, Request, StatusCode, {Body, Response, Server}, -}; +use axum::{extract::State, response::Html, Router}; use metrics_exporter_prometheus::PrometheusHandle; -use std::convert::Infallible; use std::str; use std::{net::SocketAddr, sync::Arc}; use tracing::{error, trace}; @@ -18,21 +14,6 @@ pub(crate) struct LogFilterHttpExporter { tracing_handle: ReloadHandle, } -/// Sets the `tracing_suscriber` filter level to the value of `bytes` on `handle` -fn set_filter(bytes: Bytes, handle: &ReloadHandle) -> Result<()> { - let body = str::from_utf8(bytes.as_ref())?; - trace!(request.body = ?body); - let new_filter = body.parse::()?; - handle.reload(new_filter) -} - -fn rsp(status: StatusCode, body: impl Into) -> Response { - Response::builder() - .status(status) - .body(body.into()) - .expect("builder with known status code must not fail") -} - impl LogFilterHttpExporter { /// Creates a new [`LogFilterHttpExporter`] that listens on the given `address`. /// @@ -58,55 +39,46 @@ impl LogFilterHttpExporter { } async fn async_run_inner(self) -> Result<()> { - let recorder_handle = Arc::new(self.recorder_handle); - let tracing_handle = Arc::new(self.tracing_handle); + let state = AppState { + recorder_handle: Arc::new(self.recorder_handle), + tracing_handle: Arc::new(self.tracing_handle), + }; - let make_svc = make_service_fn(move |_| { - let recorder_handle = recorder_handle.clone(); - let tracing_handle = tracing_handle.clone(); - - async move { - Ok::<_, Infallible>(service_fn(move |req: Request| { - let recorder_handle = recorder_handle.clone(); - let tracing_handle = tracing_handle.clone(); - - async move { - let response = match (req.method(), req.uri().path()) { - (&Method::GET, "/metrics") => { - Response::new(Body::from(recorder_handle.as_ref().render())) - } - (&Method::PUT, "/filter") => { - trace!("setting filter"); - match hyper::body::to_bytes(req).await { - Ok(body) => match set_filter(body, &tracing_handle) { - Err(error) => { - error!(?error, "setting filter failed!"); - rsp( - StatusCode::INTERNAL_SERVER_ERROR, - format!("{:?}", error), - ) - } - Ok(()) => rsp(StatusCode::NO_CONTENT, Body::empty()), - }, - Err(error) => { - error!(%error, "setting filter failed - Couldn't read bytes"); - rsp(StatusCode::INTERNAL_SERVER_ERROR, format!("{error:?}")) - } - } - } - _ => rsp(StatusCode::NOT_FOUND, "try '/filter' or `/metrics`"), - }; - Ok::<_, Infallible>(response) - } - })) - } - }); + let app = Router::new() + .route("/", axum::routing::get(root)) + .route("/metrics", axum::routing::get(serve_metrics)) + .route("/filter", axum::routing::put(put_filter)) + .with_state(state); let address = self.address; - Server::try_bind(&address) - .with_context(|| format!("Failed to bind to {}", address))? - .serve(make_svc) + let listener = tokio::net::TcpListener::bind(address) .await - .map_err(|e| anyhow!(e)) + .with_context(|| format!("Failed to bind to {}", address))?; + axum::serve(listener, app).await.map_err(|e| anyhow!(e)) } } + +async fn root() -> Html<&'static str> { + Html("try /filter or /metrics") +} + +async fn serve_metrics(State(state): State) -> Html { + Html(state.recorder_handle.as_ref().render()) +} + +async fn put_filter( + State(state): State, + new_filter_string: String, +) -> Result, HttpServerError> { + trace!("setting filter to: {new_filter_string}"); + let new_filter = new_filter_string.parse::()?; + state.tracing_handle.reload(new_filter)?; + tracing::info!("filter set to: {new_filter_string}"); + Ok(Html("Filter set")) +} + +#[derive(Clone)] +struct AppState { + tracing_handle: Arc, + recorder_handle: Arc, +} diff --git a/shotover/src/transforms/tee.rs b/shotover/src/transforms/tee.rs index 061c2de88..8987bdade 100644 --- a/shotover/src/transforms/tee.rs +++ b/shotover/src/transforms/tee.rs @@ -1,21 +1,20 @@ use super::TransformContextConfig; use crate::config::chain::TransformChainConfig; +use crate::http::HttpServerError; use crate::message::Messages; use crate::transforms::chain::{BufferedChain, TransformChainBuilder}; use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::{anyhow, Context, Result}; use async_trait::async_trait; use atomic_enum::atomic_enum; -use bytes::Bytes; -use hyper::{ - service::{make_service_fn, service_fn}, - Method, Request, StatusCode, {Body, Response, Server}, -}; +use axum::extract::State; +use axum::response::Html; +use axum::Router; use metrics::{counter, Counter}; use serde::{Deserialize, Serialize}; use std::fmt; use std::sync::atomic::Ordering; -use std::{convert::Infallible, net::SocketAddr, str, sync::Arc}; +use std::{net::SocketAddr, str, sync::Arc}; use tracing::{debug, error, trace, warn}; pub struct TeeBuilder { @@ -344,36 +343,6 @@ impl ChainSwitchListener { Self { address } } - fn rsp(status: StatusCode, body: impl Into) -> Response { - Response::builder() - .status(status) - .body(body.into()) - .expect("builder with known status code must not fail") - } - - async fn set_result_source_chain( - body: Bytes, - result_source: Arc, - ) -> Result<()> { - let new_result_source = str::from_utf8(body.as_ref())?; - - let new_value = match new_result_source { - "tee-chain" => ResultSource::TeeChain, - "regular-chain" => ResultSource::RegularChain, - _ => { - return Err(anyhow!( - r"Invalid value for result source: {}, should be 'tee-chain' or 'regular-chain'", - new_result_source - )) - } - }; - - debug!("Setting result source to {}", new_value); - - result_source.store(new_value, Ordering::Relaxed); - Ok(()) - } - async fn async_run(self, result_source: Arc) { if let Err(err) = self.async_run_inner(result_source).await { error!("Error in ChainSwitchListener: {}", err); @@ -381,67 +350,61 @@ impl ChainSwitchListener { } async fn async_run_inner(self, result_source: Arc) -> Result<()> { - let make_svc = make_service_fn(move |_| { - let result_source = result_source.clone(); - async move { - Ok::<_, Infallible>(service_fn(move |req: Request| { - let result_source = result_source.clone(); - async move { - let response = match (req.method(), req.uri().path()) { - (&Method::GET, "/transform/tee/result-source") => { - let result_source: ResultSource = - result_source.load(Ordering::Relaxed); - Self::rsp(StatusCode::OK, result_source.to_string()) - } - (&Method::PUT, "/transform/tee/result-source") => { - match hyper::body::to_bytes(req.into_body()).await { - Ok(body) => { - match Self::set_result_source_chain( - body, - result_source.clone(), - ) - .await - { - Err(error) => { - error!(?error, "setting result source failed"); - Self::rsp( - StatusCode::BAD_REQUEST, - format!( - "setting result source failed: {error}" - ), - ) - } - Ok(()) => Self::rsp(StatusCode::OK, Body::empty()), - } - } - Err(error) => { - error!(%error, "setting result source failed - Couldn't read bytes"); - Self::rsp( - StatusCode::INTERNAL_SERVER_ERROR, - format!("{error:?}"), - ) - } - } - } - _ => { - Self::rsp(StatusCode::NOT_FOUND, "try /tranform/tee/result-source") - } - }; - Ok::<_, Infallible>(response) - } - })) - } - }); + let app = Router::new() + .route("/", axum::routing::get(root)) + .route( + "/transform/tee/result-source", + axum::routing::get(get_result_source), + ) + .route( + "/transform/tee/result-source", + axum::routing::put(put_result_source), + ) + .with_state(AppState { result_source }); let address = self.address; - Server::try_bind(&address) - .with_context(|| format!("Failed to bind to {}", address))? - .serve(make_svc) + let listener = tokio::net::TcpListener::bind(address) .await - .map_err(|e| anyhow!(e)) + .with_context(|| format!("Failed to bind to {}", address))?; + axum::serve(listener, app).await.map_err(|e| anyhow!(e)) } } +async fn root() -> Html<&'static str> { + Html("try /transform/tee/result-source") +} + +async fn get_result_source(State(state): State) -> Html { + let result_source: ResultSource = state.result_source.load(Ordering::Relaxed); + Html(result_source.to_string()) +} + +async fn put_result_source( + State(state): State, + new_result_source: String, +) -> Result<(), HttpServerError> { + let new_value = match new_result_source.as_str() { + "tee-chain" => ResultSource::TeeChain, + "regular-chain" => ResultSource::RegularChain, + _ => { + return Err(HttpServerError(anyhow!( + r"Invalid value for result source: {:?}, should be 'tee-chain' or 'regular-chain'", + new_result_source + ))); + } + }; + + state.result_source.store(new_value, Ordering::Relaxed); + tracing::info!("result source set to {new_value}"); + + Ok(()) +} + +#[derive(Clone)] +struct AppState { + result_source: Arc, +} + #[cfg(all(test, feature = "redis"))] mod tests { use super::*;