diff --git a/shotover-proxy/tests/test-configs/tee/switch_chain.yaml b/shotover-proxy/tests/test-configs/tee/switch_chain.yaml index 9af3ed9e1..f6c435ff0 100644 --- a/shotover-proxy/tests/test-configs/tee/switch_chain.yaml +++ b/shotover-proxy/tests/test-configs/tee/switch_chain.yaml @@ -15,7 +15,7 @@ sources: - DebugReturner: Redis: "a" - Redis: - name: "redis-3" + name: "redis-2" listen_addr: "127.0.0.1:6372" connection_limit: chain: diff --git a/shotover-proxy/tests/transforms/tee.rs b/shotover-proxy/tests/transforms/tee.rs index 78d3b40e3..1216308be 100644 --- a/shotover-proxy/tests/transforms/tee.rs +++ b/shotover-proxy/tests/transforms/tee.rs @@ -1,7 +1,7 @@ use crate::shotover_process; use test_helpers::connection::redis_connection; use test_helpers::docker_compose::docker_compose; -use test_helpers::shotover_process::{EventMatcher, Level}; +use test_helpers::shotover_process::{Count, EventMatcher, Level}; #[tokio::test(flavor = "multi_thread")] async fn test_ignore_matches() { @@ -77,24 +77,15 @@ async fn test_log_with_mismatch() { assert_eq!("42", result); shotover - .shutdown_and_then_consume_events(&[ - EventMatcher::new() - .with_level(Level::Warn) - .with_target("shotover::transforms::tee") - .with_message( - r#"Tee mismatch: -chain response: ["Redis BulkString(b\"42\")", "Redis BulkString(b\"42\")"] -tee response: ["Redis BulkString(b\"41\")", "Redis BulkString(b\"41\")"]"#, - ), - EventMatcher::new() - .with_level(Level::Warn) - .with_target("shotover::transforms::tee") - .with_message( - r#"Tee mismatch: -chain response: ["Redis BulkString(b\"42\")"] -tee response: ["Redis BulkString(b\"41\")"]"#, - ), - ]) + .shutdown_and_then_consume_events(&[EventMatcher::new() + .with_level(Level::Warn) + .with_count(Count::Times(3)) + .with_target("shotover::transforms::tee") + .with_message( + r#"Tee mismatch: +result-source response: Redis BulkString(b"42") +other response: Redis BulkString(b"41")"#, + )]) .await; } @@ -211,6 +202,7 @@ async fn 
test_switch_main_chain() { .await; for i in 1..=3 { + println!("{i}"); let redis_port = 6370 + i; let switch_port = 1230 + i; @@ -254,9 +246,27 @@ async fn test_switch_main_chain() { } shotover - .shutdown_and_then_consume_events(&[EventMatcher::new() - .with_level(Level::Warn) - // 1 warning per loop above + 1 warning from the redis-rs driver connection handshake - .with_count(tokio_bin_process::event_matcher::Count::Times(4))]) + .shutdown_and_then_consume_events(&[ + EventMatcher::new() + .with_level(Level::Warn) + // generated by the final loop above, 2 by the requests + 2 by the redis-rs driver connection handshake + .with_count(Count::Times(4)) + .with_target("shotover::transforms::tee") + .with_message( + r#"Tee mismatch: +result-source response: Redis BulkString(b"a") +other response: Redis BulkString(b"b")"#, + ), + EventMatcher::new() + .with_level(Level::Warn) + // generated by the final loop above, by the request made while the result-source is flipped. + .with_count(Count::Times(1)) + .with_target("shotover::transforms::tee") + .with_message( + r#"Tee mismatch: +result-source response: Redis BulkString(b"b") +other response: Redis BulkString(b"a")"#, + ), + ]) .await; } diff --git a/shotover/src/transforms/tee.rs b/shotover/src/transforms/tee.rs index 061c2de88..4ebaedaaa 100644 --- a/shotover/src/transforms/tee.rs +++ b/shotover/src/transforms/tee.rs @@ -1,6 +1,6 @@ use super::TransformContextConfig; use crate::config::chain::TransformChainConfig; -use crate::message::Messages; +use crate::message::{Message, MessageIdMap, Messages}; use crate::transforms::chain::{BufferedChain, TransformChainBuilder}; use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::{anyhow, Context, Result}; @@ -13,6 +13,7 @@ use hyper::{ }; use metrics::{counter, Counter}; use serde::{Deserialize, Serialize}; +use std::collections::VecDeque; use std::fmt; use std::sync::atomic::Ordering; use std::{convert::Infallible, net::SocketAddr, str, 
sync::Arc}; @@ -25,6 +26,7 @@ pub struct TeeBuilder { pub timeout_micros: Option, dropped_messages: Counter, result_source: Arc, + protocol_is_inorder: bool, } pub enum ConsistencyBehaviorBuilder { @@ -41,6 +43,7 @@ impl TeeBuilder { behavior: ConsistencyBehaviorBuilder, timeout_micros: Option, switch_port: Option, + protocol_is_inorder: bool, ) -> Self { let result_source = Arc::new(AtomicResultSource::new(ResultSource::RegularChain)); @@ -59,6 +62,7 @@ impl TeeBuilder { timeout_micros, dropped_messages, result_source, + protocol_is_inorder, } } } @@ -74,13 +78,27 @@ impl TransformBuilder for TeeBuilder { } ConsistencyBehaviorBuilder::FailOnMismatch => ConsistencyBehavior::FailOnMismatch, ConsistencyBehaviorBuilder::SubchainOnMismatch(chain) => { - ConsistencyBehavior::SubchainOnMismatch(chain.build_buffered(self.buffer_size)) + ConsistencyBehavior::SubchainOnMismatch( + chain.build_buffered(self.buffer_size), + Default::default(), + ) } }, buffer_size: self.buffer_size, timeout_micros: self.timeout_micros, dropped_messages: self.dropped_messages.clone(), result_source: self.result_source.clone(), + incoming_responses: if self.protocol_is_inorder { + IncomingResponses::InOrder { + tee: VecDeque::new(), + chain: VecDeque::new(), + } + } else { + IncomingResponses::OutOfOrder { + tee_by_request_id: Default::default(), + chain_by_request_id: Default::default(), + } + }, }) } @@ -120,6 +138,7 @@ pub struct Tee { pub timeout_micros: Option, dropped_messages: Counter, result_source: Arc, + incoming_responses: IncomingResponses, } #[atomic_enum] @@ -141,7 +160,7 @@ pub enum ConsistencyBehavior { Ignore, LogWarningOnMismatch, FailOnMismatch, - SubchainOnMismatch(BufferedChain), + SubchainOnMismatch(BufferedChain, MessageIdMap), } #[derive(Serialize, Deserialize, Debug)] @@ -206,6 +225,7 @@ impl TransformConfig for TeeConfig { behavior, self.timeout_micros, self.switch_port, + transform_context.protocol.is_inorder(), ))) } } @@ -225,46 +245,54 @@ impl Transform for Tee { 
.process_request(requests_wrapper.clone(), self.timeout_micros), requests_wrapper.call_next_transform() ); - let mut tee_response = tee_result?; - let mut chain_response = chain_result?; - - if !chain_response.eq(&tee_response) { - debug!( - "Tee mismatch:\nchain response: {:?}\ntee response: {:?}", - chain_response - .iter_mut() - .map(|m| m.to_high_level_string()) - .collect::>(), - tee_response - .iter_mut() - .map(|m| m.to_high_level_string()) - .collect::>() - ); - - for message in &mut chain_response { - *message = message.to_error_response( - "ERR The responses from the Tee subchain and down-chain did not match and behavior is set to fail on mismatch".into())?; - } - } - Ok(self.return_response(tee_response, chain_response).await) + let keep: ResultSource = self.result_source.load(Ordering::Relaxed); + let responses = self.incoming_responses.new_responses( + tee_result?, + chain_result?, + keep, + |keep_message, mut other_message| { + debug!( + "Tee mismatch:\nresult-source response: {}\nother response: {}", + keep_message.to_high_level_string(), + other_message.to_high_level_string() + ); + *keep_message = keep_message.to_error_response( + "ERR The responses from the Tee subchain and down-chain did not match and behavior is set to fail on mismatch".into() + ).unwrap(); + }, + ); + + Ok(responses) } - ConsistencyBehavior::SubchainOnMismatch(mismatch_chain) => { - let failed_message = requests_wrapper.clone(); + ConsistencyBehavior::SubchainOnMismatch(mismatch_chain, requests) => { + let address = requests_wrapper.local_addr; + for request in &requests_wrapper.requests { + requests.insert(request.id(), request.clone()); + } let (tee_result, chain_result) = tokio::join!( self.tx .process_request(requests_wrapper.clone(), self.timeout_micros), requests_wrapper.call_next_transform() ); - let tee_response = tee_result?; - let chain_response = chain_result?; - - if !chain_response.eq(&tee_response) { - mismatch_chain.process_request(failed_message, None).await?; - 
} + let mut mismatched_requests = vec![]; + let keep: ResultSource = self.result_source.load(Ordering::Relaxed); + let responses = self.incoming_responses.new_responses( + tee_result?, + chain_result?, + keep, + |keep_message, _| { + if let Some(id) = keep_message.request_id() { + mismatched_requests.push(requests.remove(&id).unwrap()); + } + }, + ); + mismatch_chain + .process_request(Wrapper::new_with_addr(mismatched_requests, address), None) + .await?; - Ok(self.return_response(tee_response, chain_response).await) + Ok(responses) } ConsistencyBehavior::LogWarningOnMismatch => { let (tee_result, chain_result) = tokio::join!( @@ -273,37 +301,178 @@ impl Transform for Tee { requests_wrapper.call_next_transform() ); - let mut tee_response = tee_result?; - let mut chain_response = chain_result?; - - if !chain_response.eq(&tee_response) { - warn!( - "Tee mismatch:\nchain response: {:?}\ntee response: {:?}", - chain_response - .iter_mut() - .map(|m| m.to_high_level_string()) - .collect::>(), - tee_response - .iter_mut() - .map(|m| m.to_high_level_string()) - .collect::>() - ); - } - Ok(self.return_response(tee_response, chain_response).await) + let keep: ResultSource = self.result_source.load(Ordering::Relaxed); + let responses = self.incoming_responses.new_responses( + tee_result?, + chain_result?, + keep, + |keep_message, mut other_message| { + warn!( + "Tee mismatch:\nresult-source response: {}\nother response: {}", + keep_message.to_high_level_string(), + other_message.to_high_level_string() + ); + }, + ); + + Ok(responses) } } } } -impl Tee { - async fn return_response(&mut self, tee_result: Messages, chain_result: Messages) -> Messages { - let result_source: ResultSource = self.result_source.load(Ordering::Relaxed); - match result_source { - ResultSource::RegularChain => chain_result, - ResultSource::TeeChain => tee_result, +enum IncomingResponses { + /// We must handle in order protocols separately because we must maintain their order + InOrder { + tee: 
VecDeque, + chain: VecDeque, + }, + /// We must handle out of order protocols separately because they could arrive in any order. + OutOfOrder { + tee_by_request_id: MessageIdMap, + chain_by_request_id: MessageIdMap, + }, +} + +impl IncomingResponses { + /// Processes incoming responses. + /// If we have a complete response pair then immediately process and return them. + /// Otherwise store the individual response so that we may eventually match it up with its pair. + /// Responses with no corresponding request are immediately returned or dropped as they will not have any pair. + fn new_responses( + &mut self, + tee_responses: Vec, + chain_responses: Vec, + keep: ResultSource, + mut on_mismatch: F, + ) -> Vec + where + F: FnMut(&mut Message, Message), + { + let mut result = vec![]; + match self { + IncomingResponses::InOrder { tee, chain } => { + tee.extend(tee_responses); + chain.extend(chain_responses); + + // process all responses where we have received from tee and chain + while !tee.is_empty() && !chain.is_empty() { + // handle responses with no request + if tee.front().unwrap().request_id().is_none() { + result.push(tee.pop_front().unwrap()); + // need to start the loop again otherwise we might find there are no responses to pop! 
+ continue; + } + if chain.front().unwrap().request_id().is_none() { + result.push(chain.pop_front().unwrap()); + continue; + } + + let mut tee_response = tee.pop_front().unwrap(); + let mut chain_response = chain.pop_front().unwrap(); + match keep { + ResultSource::RegularChain => { + if tee_response != chain_response { + on_mismatch(&mut chain_response, tee_response); + } + result.push(chain_response); + } + ResultSource::TeeChain => { + if tee_response != chain_response { + on_mismatch(&mut tee_response, chain_response); + } + result.push(tee_response); + } + } + } + + // once again, handle responses with no request + // we need to recheck to ensure we haven't left any requestless responses lingering + if tee + .front() + .map(|x| x.request_id().is_none()) + .unwrap_or(false) + { + result.push(tee.pop_front().unwrap()); + } + if chain + .front() + .map(|x| x.request_id().is_none()) + .unwrap_or(false) + { + result.push(chain.pop_front().unwrap()); + } + } + IncomingResponses::OutOfOrder { + tee_by_request_id, + chain_by_request_id, + } => { + // Handle all incoming tee responses that have a matching stored chain response + for mut tee_response in tee_responses { + if let Some(request_id) = tee_response.request_id() { + // a requested response, compare against the other chain before sending it on. + if let Some(mut chain_response) = chain_by_request_id.remove(&request_id) { + match keep { + ResultSource::TeeChain => { + if tee_response != chain_response { + on_mismatch(&mut tee_response, chain_response); + } + result.push(tee_response); + } + ResultSource::RegularChain => { + if tee_response != chain_response { + on_mismatch(&mut chain_response, tee_response); + } + result.push(chain_response); + } + } + } else { + tee_by_request_id.insert(request_id, tee_response); + } + } else { + // unrequested response, so just send it on if it's from the keep chain. 
+ if let ResultSource::TeeChain = keep { + result.push(tee_response); + } + } + } + + // Handle all incoming chain responses that have a matching tee response which was just added in the previous block + for mut chain_response in chain_responses { + if let Some(request_id) = chain_response.request_id() { + // a requested response, compare against the other chain before sending it on. + if let Some(mut tee_response) = tee_by_request_id.remove(&request_id) { + match keep { + ResultSource::RegularChain => { + if tee_response != chain_response { + on_mismatch(&mut chain_response, tee_response); + } + result.push(chain_response); + } + ResultSource::TeeChain => { + if tee_response != chain_response { + on_mismatch(&mut tee_response, chain_response); + } + result.push(tee_response); + } + } + } else { + chain_by_request_id.insert(request_id, chain_response); + } + } else { + // unrequested response, so just send it on if it's from the keep chain. + if let ResultSource::RegularChain = keep { + result.push(chain_response); + } + } + } + } } + result } +} +impl Tee { async fn ignore_behaviour<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result { let result_source: ResultSource = self.result_source.load(Ordering::Relaxed); match result_source {