diff --git a/Cargo.lock b/Cargo.lock index 513c729aa..a04354f45 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4456,6 +4456,7 @@ dependencies = [ "futures", "hex", "hex-literal", + "hyper", "inferno", "itertools 0.11.0", "nix 0.27.1", diff --git a/Cargo.toml b/Cargo.toml index c9fb1a0f7..bc2679e8c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,3 +64,4 @@ typetag = "0.2.5" aws-throwaway = "0.3.0" tokio-bin-process = "0.4.0" ordered-float = { version = "4.0.0", features = ["serde"] } +hyper = { version = "0.14.14", features = ["server"] } diff --git a/docs/src/transforms.md b/docs/src/transforms.md index 5ae551fa7..1939d595a 100644 --- a/docs/src/transforms.md +++ b/docs/src/transforms.md @@ -506,6 +506,12 @@ This is mainly used in conjunction with the `TuneableConsistencyScatter` transfo This transform sends messages to both the defined sub chain and the remaining down-chain transforms. The response from the down-chain transform is returned back up-chain but various behaviours can be defined by the `behaviour` field to handle the case when the responses from the sub chain and down-chain do not match. +Tee also exposes an optional HTTP API to switch which chain to use as the "result source", that is the chain to return responses from. + +`GET` `/transform/tee/result-source` will return `regular-chain` or `tee-chain` indicating which chain is being used for the result source. + +`PUT` `/transform/tee/result-source` with the body content as either `regular-chain` or `tee-chain` to set the result source. + ```yaml - Tee: # Ignore responses returned by the sub chain @@ -528,6 +534,10 @@ The response from the down-chain transform is returned back up-chain but various # filter: Read # - NullSink + # The port that the HTTP API will listen on. + # When this field is not provided the HTTP API will not be run. + # http_api_port: 1234 + # # Timeout for sending to the sub chain in microseconds timeout_micros: 1000 # The number of message batches that the tee can hold onto in its buffer of messages to send. diff --git a/shotover-proxy/Cargo.toml b/shotover-proxy/Cargo.toml index 7b8de0bff..2e1cb09dd 100644 --- a/shotover-proxy/Cargo.toml +++ b/shotover-proxy/Cargo.toml @@ -54,6 +54,7 @@ serde_json = "1.0.103" time = { version = "0.3.25" } inferno = { version = "0.11.15", default-features = false, features = ["multithreaded", "nameattr"] } shell-quote = "0.3.0" +hyper.workspace = true [features] # Include WIP alpha transforms in the public API diff --git a/shotover-proxy/tests/test-configs/tee/switch_chain.yaml b/shotover-proxy/tests/test-configs/tee/switch_chain.yaml new file mode 100644 index 000000000..9af3ed9e1 --- /dev/null +++ b/shotover-proxy/tests/test-configs/tee/switch_chain.yaml @@ -0,0 +1,46 @@ +--- +sources: + - Redis: + name: "redis-1" + listen_addr: "127.0.0.1:6371" + connection_limit: + chain: + - Tee: + behavior: Ignore + buffer_size: 10000 + switch_port: 1231 + chain: + - DebugReturner: + Redis: "b" + - DebugReturner: + Redis: "a" + - Redis: + name: "redis-3" + listen_addr: "127.0.0.1:6372" + connection_limit: + chain: + - Tee: + behavior: + SubchainOnMismatch: + - NullSink + buffer_size: 10000 + switch_port: 1232 + chain: + - DebugReturner: + Redis: "b" + - DebugReturner: + Redis: "a" + - Redis: + name: "redis-3" + listen_addr: "127.0.0.1:6373" + connection_limit: + chain: + - Tee: + behavior: LogWarningOnMismatch + buffer_size: 10000 + switch_port: 1233 + chain: + - DebugReturner: + Redis: "b" + - DebugReturner: + Redis: "a" diff --git a/shotover-proxy/tests/transforms/tee.rs b/shotover-proxy/tests/transforms/tee.rs index b6bc41758..1f944d2ee 100644 --- a/shotover-proxy/tests/transforms/tee.rs +++ b/shotover-proxy/tests/transforms/tee.rs @@ -1,4 +1,5 @@ use crate::shotover_process; +use hyper::{body, Body, Client, Method, Request, Response}; use test_helpers::connection::redis_connection; use test_helpers::docker_compose::docker_compose; use test_helpers::shotover_process::{EventMatcher, Level}; @@ -193,3 +194,99 @@ async fn test_subchain_with_mismatch() { assert_eq!("myvalue", result); shotover.shutdown_and_then_consume_events(&[]).await; } + +async fn read_response_body(res: Response) -> Result { + let bytes = body::to_bytes(res.into_body()).await?; + Ok(String::from_utf8(bytes.to_vec()).expect("response was not valid utf-8")) +} + +async fn hyper_request(uri: String, method: Method, body: Body) -> Response { + let client = Client::new(); + + let req = Request::builder() + .method(method) + .uri(uri) + .body(body) + .expect("request builder"); + + client.request(req).await.unwrap() +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_switch_main_chain() { + let shotover = shotover_process("tests/test-configs/tee/switch_chain.yaml") + .start() + .await; + + for i in 1..=3 { + let redis_port = 6370 + i; + let switch_port = 1230 + i; + + let mut connection = redis_connection::new_async("127.0.0.1", redis_port).await; + + let result = redis::cmd("SET") + .arg("key") + .arg("myvalue") + .query_async::<_, String>(&mut connection) + .await + .unwrap(); + + assert_eq!("a", result); + + let _ = hyper_request( + format!( + "http://localhost:{}/transform/tee/result-source", + switch_port + ), + Method::PUT, + Body::from("tee-chain"), + ) + .await; + + let res = hyper_request( + format!( + "http://localhost:{}/transform/tee/result-source", + switch_port + ), + Method::GET, + Body::empty(), + ) + .await; + let body = read_response_body(res).await.unwrap(); + assert_eq!("tee-chain", body); + + let result = redis::cmd("SET") + .arg("key") + .arg("myvalue") + .query_async::<_, String>(&mut connection) + .await + .unwrap(); + + assert_eq!("b", result); + + let _ = hyper_request( + format!( + "http://localhost:{}/transform/tee/result-source", + switch_port + ), + Method::PUT, + Body::from("regular-chain"), + ) + .await; + + let result = redis::cmd("SET") + .arg("key") + .arg("myvalue") + .query_async::<_, String>(&mut connection) + .await + .unwrap(); + + assert_eq!("a", result); + } + + shotover + .shutdown_and_then_consume_events(&[EventMatcher::new() + .with_level(Level::Warn) + .with_count(tokio_bin_process::event_matcher::Count::Times(3))]) + .await; +} diff --git a/shotover/Cargo.toml b/shotover/Cargo.toml index 695409b61..d86e66396 100644 --- a/shotover/Cargo.toml +++ b/shotover/Cargo.toml @@ -65,7 +65,7 @@ metrics-exporter-prometheus = "0.12.0" tracing.workspace = true tracing-subscriber.workspace = true tracing-appender.workspace = true -hyper = { version = "0.14.14", features = ["server"] } +hyper.workspace = true halfbrown = "0.2.1" # Transform dependencies diff --git a/shotover/src/transforms/tee.rs b/shotover/src/transforms/tee.rs index a239b78c1..aaca35484 100644 --- a/shotover/src/transforms/tee.rs +++ b/shotover/src/transforms/tee.rs @@ -2,11 +2,20 @@ use crate::config::chain::TransformChainConfig; use crate::message::Messages; use crate::transforms::chain::{BufferedChain, TransformChainBuilder}; use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper}; -use anyhow::Result; +use anyhow::{anyhow, Context, Result}; use async_trait::async_trait; +use atomic_enum::atomic_enum; +use bytes::Bytes; +use hyper::{ + service::{make_service_fn, service_fn}, + Method, Request, StatusCode, {Body, Response, Server}, +}; use metrics::{register_counter, Counter}; use serde::{Deserialize, Serialize}; -use tracing::{debug, trace, warn}; +use std::fmt; +use std::sync::atomic::Ordering; +use std::{convert::Infallible, net::SocketAddr, str, sync::Arc}; +use tracing::{debug, error, trace, warn}; pub struct TeeBuilder { pub tx: TransformChainBuilder, @@ -14,6 +23,7 @@ pub struct TeeBuilder { pub behavior: ConsistencyBehaviorBuilder, pub timeout_micros: Option, dropped_messages: Counter, + result_source: Arc, } pub enum ConsistencyBehaviorBuilder { @@ -29,7 +39,16 @@ impl TeeBuilder { buffer_size: usize, behavior: ConsistencyBehaviorBuilder, timeout_micros: Option, + switch_port: Option, ) -> Self { + let result_source = Arc::new(AtomicResultSource::new(ResultSource::RegularChain)); + + if let Some(switch_port) = switch_port { + let chain_switch_listener = + ChainSwitchListener::new(SocketAddr::from(([127, 0, 0, 1], switch_port))); + tokio::spawn(chain_switch_listener.async_run(result_source.clone())); + } + let dropped_messages = register_counter!("tee_dropped_messages", "chain" => "Tee"); TeeBuilder { @@ -38,6 +57,7 @@ impl TeeBuilder { behavior, timeout_micros, dropped_messages, + result_source, } } } @@ -59,6 +79,7 @@ impl TransformBuilder for TeeBuilder { buffer_size: self.buffer_size, timeout_micros: self.timeout_micros, dropped_messages: self.dropped_messages.clone(), + result_source: self.result_source.clone(), }) } @@ -97,6 +118,22 @@ pub struct Tee { pub behavior: ConsistencyBehavior, pub timeout_micros: Option, dropped_messages: Counter, + result_source: Arc, +} + +#[atomic_enum] +pub enum ResultSource { + RegularChain, + TeeChain, +} + +impl fmt::Display for ResultSource { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ResultSource::RegularChain => write!(f, "regular-chain"), + ResultSource::TeeChain => write!(f, "tee-chain"), + } + } } pub enum ConsistencyBehavior { @@ -113,6 +150,7 @@ pub struct TeeConfig { pub timeout_micros: Option, pub chain: TransformChainConfig, pub buffer_size: Option, + pub switch_port: Option, } #[derive(Serialize, Deserialize, Debug)] @@ -153,6 +191,7 @@ impl TransformConfig for TeeConfig { buffer_size, behavior, self.timeout_micros, + self.switch_port, ))) } } @@ -161,18 +200,7 @@ impl TransformConfig for TeeConfig { impl Transform for Tee { async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result { match &mut self.behavior { - ConsistencyBehavior::Ignore => { - let (tee_result, chain_result) = tokio::join!( - self.tx - .process_request_no_return(requests_wrapper.clone(), self.timeout_micros), - requests_wrapper.call_next_transform() - ); - if let Err(e) = tee_result { - self.dropped_messages.increment(1); - trace!("Tee Ignored error {e}"); - } - chain_result - } + ConsistencyBehavior::Ignore => self.ignore_behaviour(requests_wrapper).await, ConsistencyBehavior::FailOnMismatch => { let (tee_result, chain_result) = tokio::join!( self.tx @@ -200,7 +228,8 @@ impl Transform for Tee { "ERR The responses from the Tee subchain and down-chain did not match and behavior is set to fail on mismatch".into())?; } } - Ok(chain_response) + + Ok(self.return_response(tee_response, chain_response).await) } ConsistencyBehavior::SubchainOnMismatch(mismatch_chain) => { let failed_message = requests_wrapper.clone(); @@ -217,7 +246,7 @@ impl Transform for Tee { mismatch_chain.process_request(failed_message, None).await?; } - Ok(chain_response) + Ok(self.return_response(tee_response, chain_response).await) } ConsistencyBehavior::LogWarningOnMismatch => { let (tee_result, chain_result) = tokio::join!( @@ -242,10 +271,157 @@ impl Transform for Tee { .collect::>() ); } - Ok(chain_response) + Ok(self.return_response(tee_response, chain_response).await) + } + } + } +} + +impl Tee { + async fn return_response(&mut self, tee_result: Messages, chain_result: Messages) -> Messages { + let result_source: ResultSource = self.result_source.load(Ordering::Relaxed); + match result_source { + ResultSource::RegularChain => chain_result, + ResultSource::TeeChain => tee_result, + } + } + + async fn ignore_behaviour<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result { + let result_source: ResultSource = self.result_source.load(Ordering::Relaxed); + match result_source { + ResultSource::RegularChain => { + let (tee_result, chain_result) = tokio::join!( + self.tx + .process_request_no_return(requests_wrapper.clone(), self.timeout_micros), + requests_wrapper.call_next_transform() + ); + if let Err(e) = tee_result { + self.dropped_messages.increment(1); + trace!("Tee Ignored error {e}"); + } + chain_result + } + ResultSource::TeeChain => { + let (tee_result, chain_result) = tokio::join!( + self.tx + .process_request(requests_wrapper.clone(), self.timeout_micros), + requests_wrapper.call_next_transform() + ); + if let Err(e) = chain_result { + self.dropped_messages.increment(1); + trace!("Tee Ignored error {e}"); + } + tee_result + } + } + } +} + +struct ChainSwitchListener { + address: SocketAddr, +} + +impl ChainSwitchListener { + fn new(address: SocketAddr) -> Self { + Self { address } + } + + fn rsp(status: StatusCode, body: impl Into) -> Response { + Response::builder() + .status(status) + .body(body.into()) + .expect("builder with known status code must not fail") + } + + async fn set_result_source_chain( + body: Bytes, + result_source: Arc, + ) -> Result<()> { + let new_result_source = str::from_utf8(body.as_ref())?; + + let new_value = match new_result_source { + "tee-chain" => ResultSource::TeeChain, + "regular-chain" => ResultSource::RegularChain, + _ => { + return Err(anyhow!( + r"Invalid value for result source: {}, should be 'tee-chain' or 'regular-chain'", + new_result_source + )) } + }; + + debug!("Setting result source to {}", new_value); + + result_source.store(new_value, Ordering::Relaxed); + Ok(()) + } + + async fn async_run(self, result_source: Arc) { + if let Err(err) = self.async_run_inner(result_source).await { + error!("Error in ChainSwitchListener: {}", err); } } + + async fn async_run_inner(self, result_source: Arc) -> Result<()> { + let make_svc = make_service_fn(move |_| { + let result_source = result_source.clone(); + async move { + Ok::<_, Infallible>(service_fn(move |req: Request| { + let result_source = result_source.clone(); + async move { + let response = match (req.method(), req.uri().path()) { + (&Method::GET, "/transform/tee/result-source") => { + let result_source: ResultSource = + result_source.load(Ordering::Relaxed); + Self::rsp(StatusCode::OK, result_source.to_string()) + } + (&Method::PUT, "/transform/tee/result-source") => { + match hyper::body::to_bytes(req.into_body()).await { + Ok(body) => { + match Self::set_result_source_chain( + body, + result_source.clone(), + ) + .await + { + Err(error) => { + error!(?error, "setting result source failed"); + Self::rsp( + StatusCode::BAD_REQUEST, + format!( + "setting result source failed: {error}" + ), + ) + } + Ok(()) => Self::rsp(StatusCode::OK, Body::empty()), + } + } + Err(error) => { + error!(%error, "setting result source failed - Couldn't read bytes"); + Self::rsp( + StatusCode::INTERNAL_SERVER_ERROR, + format!("{error:?}"), + ) + } + } + } + _ => { + Self::rsp(StatusCode::NOT_FOUND, "try /tranform/tee/result-source") + } + }; + Ok::<_, Infallible>(response) + } + })) + } + }); + + let address = self.address; + Server::try_bind(&address) + .with_context(|| format!("Failed to bind to {}", address))? + .serve(make_svc) + .await + .map_err(|e| anyhow!(e)) + } } #[cfg(test)] @@ -260,6 +436,7 @@ mod tests { timeout_micros: None, chain: TransformChainConfig(vec![Box::new(NullSinkConfig)]), buffer_size: None, + switch_port: None, }; let transform = config.get_builder("".to_owned()).await.unwrap(); @@ -274,6 +451,7 @@ mod tests { timeout_micros: None, chain: TransformChainConfig(vec![Box::new(NullSinkConfig), Box::new(NullSinkConfig)]), buffer_size: None, + switch_port: None, }; let transform = config.get_builder("".to_owned()).await.unwrap(); @@ -291,6 +469,7 @@ mod tests { timeout_micros: None, chain: TransformChainConfig(vec![Box::new(NullSinkConfig)]), buffer_size: None, + switch_port: None, }; let transform = config.get_builder("".to_owned()).await.unwrap(); let result = transform.validate(); @@ -304,6 +483,7 @@ mod tests { timeout_micros: None, chain: TransformChainConfig(vec![Box::new(NullSinkConfig)]), buffer_size: None, + switch_port: None, }; let transform = config.get_builder("".to_owned()).await.unwrap(); let result = transform.validate(); @@ -319,6 +499,7 @@ mod tests { timeout_micros: None, chain: TransformChainConfig(vec![Box::new(NullSinkConfig)]), buffer_size: None, + switch_port: None, }; let transform = config.get_builder("".to_owned()).await.unwrap(); @@ -338,6 +519,7 @@ mod tests { timeout_micros: None, chain: TransformChainConfig(vec![Box::new(NullSinkConfig)]), buffer_size: None, + switch_port: None, }; let transform = config.get_builder("".to_owned()).await.unwrap();