diff --git a/Cargo.lock b/Cargo.lock
index 513c729aa..a04354f45 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4456,6 +4456,7 @@ dependencies = [
"futures",
"hex",
"hex-literal",
+ "hyper",
"inferno",
"itertools 0.11.0",
"nix 0.27.1",
diff --git a/Cargo.toml b/Cargo.toml
index c9fb1a0f7..bc2679e8c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -64,3 +64,4 @@ typetag = "0.2.5"
aws-throwaway = "0.3.0"
tokio-bin-process = "0.4.0"
ordered-float = { version = "4.0.0", features = ["serde"] }
+hyper = { version = "0.14.14", features = ["server"] }
diff --git a/docs/src/transforms.md b/docs/src/transforms.md
index 5ae551fa7..1939d595a 100644
--- a/docs/src/transforms.md
+++ b/docs/src/transforms.md
@@ -506,6 +506,12 @@ This is mainly used in conjunction with the `TuneableConsistencyScatter` transfo
This transform sends messages to both the defined sub chain and the remaining down-chain transforms.
The response from the down-chain transform is returned back up-chain but various behaviours can be defined by the `behaviour` field to handle the case when the responses from the sub chain and down-chain do not match.
+Tee also exposes an optional HTTP API to switch which chain to use as the "result source", that is the chain to return responses from.
+
+`GET` `/transform/tee/result-source` will return `regular-chain` or `tee-chain` indicating which chain is being used for the result source.
+
+`PUT` `/transform/tee/result-source` with the body content as either `regular-chain` or `tee-chain` to set the result source.
+
```yaml
- Tee:
# Ignore responses returned by the sub chain
@@ -528,6 +534,10 @@ The response from the down-chain transform is returned back up-chain but various
# filter: Read
# - NullSink
+ # The port that the HTTP API will listen on.
+ # When this field is not provided the HTTP API will not be run.
+ # http_api_port: 1234
+ #
# Timeout for sending to the sub chain in microseconds
timeout_micros: 1000
# The number of message batches that the tee can hold onto in its buffer of messages to send.
diff --git a/shotover-proxy/Cargo.toml b/shotover-proxy/Cargo.toml
index 7b8de0bff..2e1cb09dd 100644
--- a/shotover-proxy/Cargo.toml
+++ b/shotover-proxy/Cargo.toml
@@ -54,6 +54,7 @@ serde_json = "1.0.103"
time = { version = "0.3.25" }
inferno = { version = "0.11.15", default-features = false, features = ["multithreaded", "nameattr"] }
shell-quote = "0.3.0"
+hyper.workspace = true
[features]
# Include WIP alpha transforms in the public API
diff --git a/shotover-proxy/tests/test-configs/tee/switch_chain.yaml b/shotover-proxy/tests/test-configs/tee/switch_chain.yaml
new file mode 100644
index 000000000..9af3ed9e1
--- /dev/null
+++ b/shotover-proxy/tests/test-configs/tee/switch_chain.yaml
@@ -0,0 +1,46 @@
+---
+sources:
+ - Redis:
+ name: "redis-1"
+ listen_addr: "127.0.0.1:6371"
+ connection_limit:
+ chain:
+ - Tee:
+ behavior: Ignore
+ buffer_size: 10000
+ switch_port: 1231
+ chain:
+ - DebugReturner:
+ Redis: "b"
+ - DebugReturner:
+ Redis: "a"
+ - Redis:
+ name: "redis-3"
+ listen_addr: "127.0.0.1:6372"
+ connection_limit:
+ chain:
+ - Tee:
+ behavior:
+ SubchainOnMismatch:
+ - NullSink
+ buffer_size: 10000
+ switch_port: 1232
+ chain:
+ - DebugReturner:
+ Redis: "b"
+ - DebugReturner:
+ Redis: "a"
+ - Redis:
+ name: "redis-3"
+ listen_addr: "127.0.0.1:6373"
+ connection_limit:
+ chain:
+ - Tee:
+ behavior: LogWarningOnMismatch
+ buffer_size: 10000
+ switch_port: 1233
+ chain:
+ - DebugReturner:
+ Redis: "b"
+ - DebugReturner:
+ Redis: "a"
diff --git a/shotover-proxy/tests/transforms/tee.rs b/shotover-proxy/tests/transforms/tee.rs
index b6bc41758..1f944d2ee 100644
--- a/shotover-proxy/tests/transforms/tee.rs
+++ b/shotover-proxy/tests/transforms/tee.rs
@@ -1,4 +1,5 @@
use crate::shotover_process;
+use hyper::{body, Body, Client, Method, Request, Response};
use test_helpers::connection::redis_connection;
use test_helpers::docker_compose::docker_compose;
use test_helpers::shotover_process::{EventMatcher, Level};
@@ -193,3 +194,99 @@ async fn test_subchain_with_mismatch() {
assert_eq!("myvalue", result);
shotover.shutdown_and_then_consume_events(&[]).await;
}
+
+async fn read_response_body(res: Response
) -> Result {
+ let bytes = body::to_bytes(res.into_body()).await?;
+ Ok(String::from_utf8(bytes.to_vec()).expect("response was not valid utf-8"))
+}
+
+async fn hyper_request(uri: String, method: Method, body: Body) -> Response {
+ let client = Client::new();
+
+ let req = Request::builder()
+ .method(method)
+ .uri(uri)
+ .body(body)
+ .expect("request builder");
+
+ client.request(req).await.unwrap()
+}
+
+#[tokio::test(flavor = "multi_thread")]
+async fn test_switch_main_chain() {
+ let shotover = shotover_process("tests/test-configs/tee/switch_chain.yaml")
+ .start()
+ .await;
+
+ for i in 1..=3 {
+ let redis_port = 6370 + i;
+ let switch_port = 1230 + i;
+
+ let mut connection = redis_connection::new_async("127.0.0.1", redis_port).await;
+
+ let result = redis::cmd("SET")
+ .arg("key")
+ .arg("myvalue")
+ .query_async::<_, String>(&mut connection)
+ .await
+ .unwrap();
+
+ assert_eq!("a", result);
+
+ let _ = hyper_request(
+ format!(
+ "http://localhost:{}/transform/tee/result-source",
+ switch_port
+ ),
+ Method::PUT,
+ Body::from("tee-chain"),
+ )
+ .await;
+
+ let res = hyper_request(
+ format!(
+ "http://localhost:{}/transform/tee/result-source",
+ switch_port
+ ),
+ Method::GET,
+ Body::empty(),
+ )
+ .await;
+ let body = read_response_body(res).await.unwrap();
+ assert_eq!("tee-chain", body);
+
+ let result = redis::cmd("SET")
+ .arg("key")
+ .arg("myvalue")
+ .query_async::<_, String>(&mut connection)
+ .await
+ .unwrap();
+
+ assert_eq!("b", result);
+
+ let _ = hyper_request(
+ format!(
+ "http://localhost:{}/transform/tee/result-source",
+ switch_port
+ ),
+ Method::PUT,
+ Body::from("regular-chain"),
+ )
+ .await;
+
+ let result = redis::cmd("SET")
+ .arg("key")
+ .arg("myvalue")
+ .query_async::<_, String>(&mut connection)
+ .await
+ .unwrap();
+
+ assert_eq!("a", result);
+ }
+
+ shotover
+ .shutdown_and_then_consume_events(&[EventMatcher::new()
+ .with_level(Level::Warn)
+ .with_count(tokio_bin_process::event_matcher::Count::Times(3))])
+ .await;
+}
diff --git a/shotover/Cargo.toml b/shotover/Cargo.toml
index 695409b61..d86e66396 100644
--- a/shotover/Cargo.toml
+++ b/shotover/Cargo.toml
@@ -65,7 +65,7 @@ metrics-exporter-prometheus = "0.12.0"
tracing.workspace = true
tracing-subscriber.workspace = true
tracing-appender.workspace = true
-hyper = { version = "0.14.14", features = ["server"] }
+hyper.workspace = true
halfbrown = "0.2.1"
# Transform dependencies
diff --git a/shotover/src/transforms/tee.rs b/shotover/src/transforms/tee.rs
index a239b78c1..aaca35484 100644
--- a/shotover/src/transforms/tee.rs
+++ b/shotover/src/transforms/tee.rs
@@ -2,11 +2,20 @@ use crate::config::chain::TransformChainConfig;
use crate::message::Messages;
use crate::transforms::chain::{BufferedChain, TransformChainBuilder};
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper};
-use anyhow::Result;
+use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
+use atomic_enum::atomic_enum;
+use bytes::Bytes;
+use hyper::{
+ service::{make_service_fn, service_fn},
+ Method, Request, StatusCode, {Body, Response, Server},
+};
use metrics::{register_counter, Counter};
use serde::{Deserialize, Serialize};
-use tracing::{debug, trace, warn};
+use std::fmt;
+use std::sync::atomic::Ordering;
+use std::{convert::Infallible, net::SocketAddr, str, sync::Arc};
+use tracing::{debug, error, trace, warn};
pub struct TeeBuilder {
pub tx: TransformChainBuilder,
@@ -14,6 +23,7 @@ pub struct TeeBuilder {
pub behavior: ConsistencyBehaviorBuilder,
pub timeout_micros: Option,
dropped_messages: Counter,
+ result_source: Arc,
}
pub enum ConsistencyBehaviorBuilder {
@@ -29,7 +39,16 @@ impl TeeBuilder {
buffer_size: usize,
behavior: ConsistencyBehaviorBuilder,
timeout_micros: Option,
+ switch_port: Option,
) -> Self {
+ let result_source = Arc::new(AtomicResultSource::new(ResultSource::RegularChain));
+
+ if let Some(switch_port) = switch_port {
+ let chain_switch_listener =
+ ChainSwitchListener::new(SocketAddr::from(([127, 0, 0, 1], switch_port)));
+ tokio::spawn(chain_switch_listener.async_run(result_source.clone()));
+ }
+
let dropped_messages = register_counter!("tee_dropped_messages", "chain" => "Tee");
TeeBuilder {
@@ -38,6 +57,7 @@ impl TeeBuilder {
behavior,
timeout_micros,
dropped_messages,
+ result_source,
}
}
}
@@ -59,6 +79,7 @@ impl TransformBuilder for TeeBuilder {
buffer_size: self.buffer_size,
timeout_micros: self.timeout_micros,
dropped_messages: self.dropped_messages.clone(),
+ result_source: self.result_source.clone(),
})
}
@@ -97,6 +118,22 @@ pub struct Tee {
pub behavior: ConsistencyBehavior,
pub timeout_micros: Option,
dropped_messages: Counter,
+ result_source: Arc,
+}
+
+#[atomic_enum]
+pub enum ResultSource {
+ RegularChain,
+ TeeChain,
+}
+
+impl fmt::Display for ResultSource {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ match self {
+ ResultSource::RegularChain => write!(f, "regular-chain"),
+ ResultSource::TeeChain => write!(f, "tee-chain"),
+ }
+ }
}
pub enum ConsistencyBehavior {
@@ -113,6 +150,7 @@ pub struct TeeConfig {
pub timeout_micros: Option,
pub chain: TransformChainConfig,
pub buffer_size: Option,
+ pub switch_port: Option,
}
#[derive(Serialize, Deserialize, Debug)]
@@ -153,6 +191,7 @@ impl TransformConfig for TeeConfig {
buffer_size,
behavior,
self.timeout_micros,
+ self.switch_port,
)))
}
}
@@ -161,18 +200,7 @@ impl TransformConfig for TeeConfig {
impl Transform for Tee {
async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result {
match &mut self.behavior {
- ConsistencyBehavior::Ignore => {
- let (tee_result, chain_result) = tokio::join!(
- self.tx
- .process_request_no_return(requests_wrapper.clone(), self.timeout_micros),
- requests_wrapper.call_next_transform()
- );
- if let Err(e) = tee_result {
- self.dropped_messages.increment(1);
- trace!("Tee Ignored error {e}");
- }
- chain_result
- }
+ ConsistencyBehavior::Ignore => self.ignore_behaviour(requests_wrapper).await,
ConsistencyBehavior::FailOnMismatch => {
let (tee_result, chain_result) = tokio::join!(
self.tx
@@ -200,7 +228,8 @@ impl Transform for Tee {
"ERR The responses from the Tee subchain and down-chain did not match and behavior is set to fail on mismatch".into())?;
}
}
- Ok(chain_response)
+
+ Ok(self.return_response(tee_response, chain_response).await)
}
ConsistencyBehavior::SubchainOnMismatch(mismatch_chain) => {
let failed_message = requests_wrapper.clone();
@@ -217,7 +246,7 @@ impl Transform for Tee {
mismatch_chain.process_request(failed_message, None).await?;
}
- Ok(chain_response)
+ Ok(self.return_response(tee_response, chain_response).await)
}
ConsistencyBehavior::LogWarningOnMismatch => {
let (tee_result, chain_result) = tokio::join!(
@@ -242,10 +271,157 @@ impl Transform for Tee {
.collect::>()
);
}
- Ok(chain_response)
+ Ok(self.return_response(tee_response, chain_response).await)
+ }
+ }
+ }
+}
+
+impl Tee {
+ async fn return_response(&mut self, tee_result: Messages, chain_result: Messages) -> Messages {
+ let result_source: ResultSource = self.result_source.load(Ordering::Relaxed);
+ match result_source {
+ ResultSource::RegularChain => chain_result,
+ ResultSource::TeeChain => tee_result,
+ }
+ }
+
+ async fn ignore_behaviour<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result {
+ let result_source: ResultSource = self.result_source.load(Ordering::Relaxed);
+ match result_source {
+ ResultSource::RegularChain => {
+ let (tee_result, chain_result) = tokio::join!(
+ self.tx
+ .process_request_no_return(requests_wrapper.clone(), self.timeout_micros),
+ requests_wrapper.call_next_transform()
+ );
+ if let Err(e) = tee_result {
+ self.dropped_messages.increment(1);
+ trace!("Tee Ignored error {e}");
+ }
+ chain_result
+ }
+ ResultSource::TeeChain => {
+ let (tee_result, chain_result) = tokio::join!(
+ self.tx
+ .process_request(requests_wrapper.clone(), self.timeout_micros),
+ requests_wrapper.call_next_transform()
+ );
+ if let Err(e) = chain_result {
+ self.dropped_messages.increment(1);
+ trace!("Tee Ignored error {e}");
+ }
+ tee_result
+ }
+ }
+ }
+}
+
+struct ChainSwitchListener {
+ address: SocketAddr,
+}
+
+impl ChainSwitchListener {
+ fn new(address: SocketAddr) -> Self {
+ Self { address }
+ }
+
+ fn rsp(status: StatusCode, body: impl Into) -> Response {
+ Response::builder()
+ .status(status)
+ .body(body.into())
+ .expect("builder with known status code must not fail")
+ }
+
+ async fn set_result_source_chain(
+ body: Bytes,
+ result_source: Arc,
+ ) -> Result<()> {
+ let new_result_source = str::from_utf8(body.as_ref())?;
+
+ let new_value = match new_result_source {
+ "tee-chain" => ResultSource::TeeChain,
+ "regular-chain" => ResultSource::RegularChain,
+ _ => {
+ return Err(anyhow!(
+ r"Invalid value for result source: {}, should be 'tee-chain' or 'regular-chain'",
+ new_result_source
+ ))
}
+ };
+
+ debug!("Setting result source to {}", new_value);
+
+ result_source.store(new_value, Ordering::Relaxed);
+ Ok(())
+ }
+
+ async fn async_run(self, result_source: Arc) {
+ if let Err(err) = self.async_run_inner(result_source).await {
+ error!("Error in ChainSwitchListener: {}", err);
}
}
+
+ async fn async_run_inner(self, result_source: Arc) -> Result<()> {
+ let make_svc = make_service_fn(move |_| {
+ let result_source = result_source.clone();
+ async move {
+ Ok::<_, Infallible>(service_fn(move |req: Request| {
+ let result_source = result_source.clone();
+ async move {
+ let response = match (req.method(), req.uri().path()) {
+ (&Method::GET, "/transform/tee/result-source") => {
+ let result_source: ResultSource =
+ result_source.load(Ordering::Relaxed);
+ Self::rsp(StatusCode::OK, result_source.to_string())
+ }
+ (&Method::PUT, "/transform/tee/result-source") => {
+ match hyper::body::to_bytes(req.into_body()).await {
+ Ok(body) => {
+ match Self::set_result_source_chain(
+ body,
+ result_source.clone(),
+ )
+ .await
+ {
+ Err(error) => {
+ error!(?error, "setting result source failed");
+ Self::rsp(
+ StatusCode::BAD_REQUEST,
+ format!(
+ "setting result source failed: {error}"
+ ),
+ )
+ }
+ Ok(()) => Self::rsp(StatusCode::OK, Body::empty()),
+ }
+ }
+ Err(error) => {
+ error!(%error, "setting result source failed - Couldn't read bytes");
+ Self::rsp(
+ StatusCode::INTERNAL_SERVER_ERROR,
+ format!("{error:?}"),
+ )
+ }
+ }
+ }
+ _ => {
+ Self::rsp(StatusCode::NOT_FOUND, "try /tranform/tee/result-source")
+ }
+ };
+ Ok::<_, Infallible>(response)
+ }
+ }))
+ }
+ });
+
+ let address = self.address;
+ Server::try_bind(&address)
+ .with_context(|| format!("Failed to bind to {}", address))?
+ .serve(make_svc)
+ .await
+ .map_err(|e| anyhow!(e))
+ }
}
#[cfg(test)]
@@ -260,6 +436,7 @@ mod tests {
timeout_micros: None,
chain: TransformChainConfig(vec![Box::new(NullSinkConfig)]),
buffer_size: None,
+ switch_port: None,
};
let transform = config.get_builder("".to_owned()).await.unwrap();
@@ -274,6 +451,7 @@ mod tests {
timeout_micros: None,
chain: TransformChainConfig(vec![Box::new(NullSinkConfig), Box::new(NullSinkConfig)]),
buffer_size: None,
+ switch_port: None,
};
let transform = config.get_builder("".to_owned()).await.unwrap();
@@ -291,6 +469,7 @@ mod tests {
timeout_micros: None,
chain: TransformChainConfig(vec![Box::new(NullSinkConfig)]),
buffer_size: None,
+ switch_port: None,
};
let transform = config.get_builder("".to_owned()).await.unwrap();
let result = transform.validate();
@@ -304,6 +483,7 @@ mod tests {
timeout_micros: None,
chain: TransformChainConfig(vec![Box::new(NullSinkConfig)]),
buffer_size: None,
+ switch_port: None,
};
let transform = config.get_builder("".to_owned()).await.unwrap();
let result = transform.validate();
@@ -319,6 +499,7 @@ mod tests {
timeout_micros: None,
chain: TransformChainConfig(vec![Box::new(NullSinkConfig)]),
buffer_size: None,
+ switch_port: None,
};
let transform = config.get_builder("".to_owned()).await.unwrap();
@@ -338,6 +519,7 @@ mod tests {
timeout_micros: None,
chain: TransformChainConfig(vec![Box::new(NullSinkConfig)]),
buffer_size: None,
+ switch_port: None,
};
let transform = config.get_builder("".to_owned()).await.unwrap();