From ba0fffb1fa385ff4cab8896605051386a95645bf Mon Sep 17 00:00:00 2001 From: Lucas Kent Date: Mon, 19 Feb 2024 10:57:45 +1100 Subject: [PATCH] Add dummy requests to better support certain transforms --- shotover/benches/benches/chain.rs | 11 ++- shotover/src/message/mod.rs | 18 ++++ shotover/src/transforms/debug/returner.rs | 46 +++++----- .../tuneable_consistency_scatter.rs | 4 +- shotover/src/transforms/filter.rs | 86 +++++++------------ shotover/src/transforms/loopback.rs | 7 +- shotover/src/transforms/throttling.rs | 54 ++++++------ .../util/cluster_connection_pool.rs | 61 ++++++++----- 8 files changed, 156 insertions(+), 131 deletions(-) diff --git a/shotover/benches/benches/chain.rs b/shotover/benches/benches/chain.rs index 3639f09e0..1226b172b 100644 --- a/shotover/benches/benches/chain.rs +++ b/shotover/benches/benches/chain.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use bytes::Bytes; use cassandra_protocol::compression::Compression; use cassandra_protocol::{consistency::Consistency, frame::Version, query::QueryParams}; @@ -70,6 +72,7 @@ fn criterion_benchmark(c: &mut Criterion) { vec![ Box::new(QueryTypeFilter { filter: Filter::DenyList(vec![QueryType::Read]), + filtered_requests: HashMap::new(), }), Box::new(DebugReturner::new(Response::Redis("a".into()))), ], @@ -106,12 +109,12 @@ fn criterion_benchmark(c: &mut Criterion) { let chain = TransformChainBuilder::new( vec![ Box::new(RedisTimestampTagger::new()), - Box::new(DebugReturner::new(Response::Message(vec![ - Message::from_frame(Frame::Redis(RedisFrame::Array(vec![ + Box::new(DebugReturner::new(Response::Message(Message::from_frame( + Frame::Redis(RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"1")), // real frame RedisFrame::BulkString(Bytes::from_static(b"1")), // timestamp - ]))), - ]))), + ])), + )))), ], "bench", ); diff --git a/shotover/src/message/mod.rs b/shotover/src/message/mod.rs index aacc87b6b..321043a6c 100644 --- a/shotover/src/message/mod.rs +++ b/shotover/src/message/mod.rs @@ -416,6 +416,24 @@ impl Message { }, } } + /// Set this `Message` to a dummy frame so that the message will never reach the client or DB. + /// For requests, the dummy frame will be dropped when it reaches the Sink. + /// Additionally a corresponding dummy response will be generated with its request_id set to the requests id. + /// For responses, the dummy frame will be dropped when it reaches the Source. + pub fn replace_with_dummy(&mut self) { + self.inner = Some(MessageInner::Modified { + frame: Frame::Dummy, + }); + } + + pub fn is_dummy(&self) -> bool { + matches!( + self.inner, + Some(MessageInner::Modified { + frame: Frame::Dummy + }) + ) + } /// Set this `Message` to a backpressure response pub fn set_backpressure(&mut self) -> Result<()> { diff --git a/shotover/src/transforms/debug/returner.rs b/shotover/src/transforms/debug/returner.rs index efeb183bc..47a2fef3c 100644 --- a/shotover/src/transforms/debug/returner.rs +++ b/shotover/src/transforms/debug/returner.rs @@ -1,4 +1,4 @@ -use crate::message::Messages; +use crate::message::{Message, Messages}; use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::{anyhow, Result}; use async_trait::async_trait; @@ -24,7 +24,7 @@ impl TransformConfig for DebugReturnerConfig { #[serde(deny_unknown_fields)] pub enum Response { #[serde(skip)] - Message(Messages), + Message(Message), #[cfg(feature = "redis")] Redis(String), Fail, @@ -61,24 +61,28 @@ impl Transform for DebugReturner { NAME } - async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result { - match &self.response { - Response::Message(message) => Ok(message.clone()), - #[cfg(feature = "redis")] - Response::Redis(string) => { - use crate::frame::{Frame, RedisFrame}; - use crate::message::Message; - Ok(requests_wrapper - .requests - .iter() - .map(|_| { - Message::from_frame(Frame::Redis(RedisFrame::BulkString( - string.to_string().into(), - ))) - }) - .collect()) - } - Response::Fail => Err(anyhow!("Intentional Fail")), - } + async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result { + requests_wrapper + .requests + .iter_mut() + .map(|request| match &self.response { + Response::Message(message) => { + let mut message = message.clone(); + message.set_request_id(request.id()); + Ok(message) + } + #[cfg(feature = "redis")] + Response::Redis(string) => { + use crate::frame::{Frame, RedisFrame}; + use crate::message::Message; + let mut message = Message::from_frame(Frame::Redis(RedisFrame::BulkString( + string.to_string().into(), + ))); + message.set_request_id(request.id()); + Ok(message) + } + Response::Fail => Err(anyhow!("Intentional Fail")), + }) + .collect() } } diff --git a/shotover/src/transforms/distributed/tuneable_consistency_scatter.rs b/shotover/src/transforms/distributed/tuneable_consistency_scatter.rs index 5899153bb..b5816fa34 100644 --- a/shotover/src/transforms/distributed/tuneable_consistency_scatter.rs +++ b/shotover/src/transforms/distributed/tuneable_consistency_scatter.rs @@ -317,9 +317,7 @@ mod scatter_transform_tests { #[tokio::test(flavor = "multi_thread")] async fn test_scatter_success() { - let response = vec![Message::from_frame(Frame::Redis(RedisFrame::BulkString( - "OK".into(), - )))]; + let response = Message::from_frame(Frame::Redis(RedisFrame::BulkString("OK".into()))); let wrapper = Wrapper::new_test(vec![Message::from_frame(Frame::Redis( RedisFrame::BulkString(Bytes::from_static(b"foo")), diff --git a/shotover/src/transforms/filter.rs b/shotover/src/transforms/filter.rs index 1d16d21f2..f6c6b7bb6 100644 --- a/shotover/src/transforms/filter.rs +++ b/shotover/src/transforms/filter.rs @@ -1,11 +1,9 @@ -use crate::message::{Message, Messages, QueryType}; +use crate::message::{Message, MessageId, Messages, QueryType}; use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::Result; use async_trait::async_trait; use serde::{Deserialize, Serialize}; -use std::sync::atomic::{AtomicBool, Ordering}; - -static SHOWN_ERROR: AtomicBool = AtomicBool::new(false); +use std::collections::HashMap; #[derive(Serialize, Deserialize, Debug, Clone)] #[serde(deny_unknown_fields)] @@ -17,6 +15,7 @@ pub enum Filter { #[derive(Debug, Clone)] pub struct QueryTypeFilter { pub filter: Filter, + pub filtered_requests: HashMap, } #[derive(Serialize, Deserialize, Debug)] @@ -33,6 +32,7 @@ impl TransformConfig for QueryTypeFilterConfig { async fn get_builder(&self, _chain_name: String) -> Result> { Ok(Box::new(QueryTypeFilter { filter: self.filter.clone(), + filtered_requests: HashMap::new(), })) } } @@ -54,60 +54,33 @@ impl Transform for QueryTypeFilter { } async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result { - let removed_indexes: Result> = requests_wrapper - .requests - .iter_mut() - .enumerate() - .filter_map(|(i, m)| match self.filter { - Filter::AllowList(ref allow_list) => { - if allow_list.contains(&m.get_query_type()) { - None - } else { - Some((i, m)) - } - } - Filter::DenyList(ref deny_list) => { - if deny_list.contains(&m.get_query_type()) { - Some((i, m)) - } else { - None - } - } - }) - .map(|(i, m)| { - Ok(( - i, - m.to_error_response("Message was filtered out by shotover".to_owned()) - .map_err(|e| e.context("Failed to filter message {e:?}"))?, - )) - }) - .collect(); - - let removed_indexes = removed_indexes?; - - for (i, _) in removed_indexes.iter().rev() { - requests_wrapper.requests.remove(*i); + for request in requests_wrapper.requests.iter_mut() { + let filter_out = match &self.filter { + Filter::AllowList(allow_list) => !allow_list.contains(&request.get_query_type()), + Filter::DenyList(deny_list) => deny_list.contains(&request.get_query_type()), + }; + + if filter_out { + self.filtered_requests.insert( + request.id(), + request + .to_error_response("Message was filtered out by shotover".to_owned()) + .map_err(|e| e.context("Failed to filter message"))?, + ); + request.replace_with_dummy(); + } } - let mut shown_error = SHOWN_ERROR.load(Ordering::Relaxed); - - requests_wrapper - .call_next_transform() - .await - .map(|mut messages| { - - for (i, message) in removed_indexes.into_iter() { - if i <= messages.len() { - messages.insert(i, message); - } - else if !shown_error{ - tracing::error!("The current filter transform implementation does not obey the current transform invariants. see https://github.com/shotover/shotover-proxy/issues/499"); - shown_error = true; - SHOWN_ERROR.store(true , Ordering::Relaxed); - } + let mut responses = requests_wrapper.call_next_transform().await?; + for response in responses.iter_mut() { + if let Some(request_id) = response.request_id() { + if let Some(error_response) = self.filtered_requests.remove(&request_id) { + *response = error_response; } - messages - }) + } + } + + Ok(responses) } } @@ -121,11 +94,13 @@ mod test { use crate::transforms::filter::QueryTypeFilter; use crate::transforms::loopback::Loopback; use crate::transforms::{Transform, Wrapper}; + use std::collections::HashMap; #[tokio::test(flavor = "multi_thread")] async fn test_filter_denylist() { let mut filter_transform = QueryTypeFilter { filter: Filter::DenyList(vec![QueryType::Read]), + filtered_requests: HashMap::new(), }; let mut chain = vec![TransformAndMetrics::new(Box::new(Loopback::default()))]; @@ -180,6 +155,7 @@ mod test { async fn test_filter_allowlist() { let mut filter_transform = QueryTypeFilter { filter: Filter::AllowList(vec![QueryType::Write]), + filtered_requests: HashMap::new(), }; let mut chain = vec![TransformAndMetrics::new(Box::new(Loopback::default()))]; diff --git a/shotover/src/transforms/loopback.rs b/shotover/src/transforms/loopback.rs index 0e39ba1ec..7345a2832 100644 --- a/shotover/src/transforms/loopback.rs +++ b/shotover/src/transforms/loopback.rs @@ -28,7 +28,12 @@ impl Transform for Loopback { NAME } - async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result { + async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result { + // This transform ultimately doesnt make a lot of sense semantically + // but make a vague attempt to follow transform invariants anyway. + for request in &mut requests_wrapper.requests { + request.set_request_id(request.id()); + } Ok(requests_wrapper.requests) } } diff --git a/shotover/src/transforms/throttling.rs b/shotover/src/transforms/throttling.rs index 2bae4b1db..9657ba373 100644 --- a/shotover/src/transforms/throttling.rs +++ b/shotover/src/transforms/throttling.rs @@ -1,4 +1,4 @@ -use crate::message::{Message, Messages}; +use crate::message::{MessageId, Messages}; use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::Result; use async_trait::async_trait; @@ -10,6 +10,7 @@ use governor::{ }; use nonzero_ext::nonzero; use serde::{Deserialize, Serialize}; +use std::collections::HashSet; use std::num::NonZeroU32; use std::sync::Arc; @@ -29,6 +30,7 @@ impl TransformConfig for RequestThrottlingConfig { self.max_requests_per_second, ))), max_requests_per_second: self.max_requests_per_second, + throttled_requests: HashSet::new(), })) } } @@ -37,6 +39,7 @@ impl TransformConfig for RequestThrottlingConfig { pub struct RequestThrottling { limiter: Arc>, max_requests_per_second: NonZeroU32, + throttled_requests: HashSet, } impl TransformBuilder for RequestThrottling { @@ -67,43 +70,38 @@ impl Transform for RequestThrottling { } async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result { - // extract throttled messages from the requests_wrapper - let throttled_messages: Vec<(Message, usize)> = (0..requests_wrapper.requests.len()) - .rev() - .filter_map(|i| { - match self - .limiter - .check_n(requests_wrapper.requests[i].cell_count().ok()?) - { - // occurs if all cells can be accommodated and - Ok(Ok(())) => None, + for request in &mut requests_wrapper.requests { + if let Ok(cell_count) = request.cell_count() { + match self.limiter.check_n(cell_count) { + // occurs if all cells can be accommodated + Ok(Ok(())) => {} // occurs if not all cells can be accommodated. Ok(Err(_)) => { - let message = requests_wrapper.requests.remove(i); - Some((message, i)) + self.throttled_requests.insert(request.id()); + request.replace_with_dummy(); } // occurs when the batch can never go through, meaning the rate limiter's quota's burst size is too low for the given number of cells to be ever allowed through Err(_) => { tracing::warn!("A message was received that could never have been successfully delivered since it contains more sub messages than can ever be allowed through via the `RequestThrottling` transforms `max_requests_per_second` configuration."); - let message = requests_wrapper.requests.remove(i); - Some((message, i)) + self.throttled_requests.insert(request.id()); + request.replace_with_dummy(); } } - }) - .collect(); + } + } - // if every message got backpressured we can skip this - let mut responses = if !requests_wrapper.requests.is_empty() { - // send allowed messages to Cassandra - requests_wrapper.call_next_transform().await? - } else { - vec![] - }; + // send allowed messages to Cassandra + let mut responses = requests_wrapper.call_next_transform().await?; // reinsert backpressure error responses back into responses - for (mut message, i) in throttled_messages.into_iter().rev() { - message.set_backpressure()?; - responses.insert(i, message); + for response in responses.iter_mut() { + if response + .request_id() + .map(|id| self.throttled_requests.remove(&id)) + .unwrap_or(false) + { + response.set_backpressure()?; + } } Ok(responses) @@ -124,6 +122,7 @@ mod test { Box::new(RequestThrottling { limiter: Arc::new(RateLimiter::direct(Quota::per_second(nonzero!(20u32)))), max_requests_per_second: nonzero!(20u32), + throttled_requests: HashSet::new(), }), Box::::default(), ], @@ -146,6 +145,7 @@ mod test { Box::new(RequestThrottling { limiter: Arc::new(RateLimiter::direct(Quota::per_second(nonzero!(100u32)))), max_requests_per_second: nonzero!(100u32), + throttled_requests: HashSet::new(), }), Box::::default(), ], diff --git a/shotover/src/transforms/util/cluster_connection_pool.rs b/shotover/src/transforms/util/cluster_connection_pool.rs index 949eb33ff..8feb51b42 100644 --- a/shotover/src/transforms/util/cluster_connection_pool.rs +++ b/shotover/src/transforms/util/cluster_connection_pool.rs @@ -1,5 +1,7 @@ use super::Response; use crate::codec::{CodecBuilder, CodecWriteError, DecoderHalf, EncoderHalf}; +use crate::frame::Frame; +use crate::message::{Message, MessageId}; use crate::tcp; use crate::tls::{TlsConnector, TlsConnectorConfig}; use crate::transforms::util::{ConnectionError, Request}; @@ -193,6 +195,7 @@ pub fn spawn_read_write_tasks< stream_rx: R, stream_tx: W, ) -> Connection { + let (dummy_request_tx, dummy_request_rx) = tokio::sync::mpsc::unbounded_channel(); let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel(); let (return_tx, return_rx) = tokio::sync::mpsc::unbounded_channel(); let (closed_tx, closed_rx) = tokio::sync::oneshot::channel(); @@ -201,7 +204,7 @@ pub fn spawn_read_write_tasks< tokio::spawn(async move { tokio::select! { - result = tx_process(stream_tx, out_rx, return_tx, encoder) => if let Err(e) = result { + result = tx_process(dummy_request_tx, stream_tx, out_rx, return_tx, encoder) => if let Err(e) = result { trace!("connection write-closed with error: {:?}", e); } else { trace!("connection write-closed gracefully"); @@ -214,7 +217,7 @@ pub fn spawn_read_write_tasks< tokio::spawn( async move { - if let Err(e) = rx_process(stream_rx, return_rx, decoder).await { + if let Err(e) = rx_process(dummy_request_rx, stream_rx, return_rx, decoder).await { trace!("connection read-closed with error: {:?}", e); } else { trace!("connection read-closed gracefully"); @@ -230,6 +233,7 @@ pub fn spawn_read_write_tasks< } async fn tx_process( + dummy_request_tx: UnboundedSender, write: W, out_rx: UnboundedReceiver, return_tx: UnboundedSender, @@ -237,6 +241,9 @@ async fn tx_process( ) -> Result<(), CodecWriteError> { let writer = FramedWrite::new(write, codec); let rx_stream = UnboundedReceiverStream::new(out_rx).map(|x| { + if x.message.is_dummy() { + dummy_request_tx.send(x.message.id()).ok(); + } let ret = Ok(vec![x.message]); return_tx .send(x.return_chan) @@ -249,6 +256,7 @@ async fn tx_process( type ReturnChan = Option>; async fn rx_process( + mut dummy_request_rx: UnboundedReceiver, read: R, mut return_rx: UnboundedReceiver, codec: C, @@ -258,29 +266,42 @@ async fn rx_process( // TODO: This reader.next() may perform reads after tx_process has shutdown the write half. // This may result in unexpected ConnectionReset errors. // refer to the cassandra connection logic. - while let Some(responses) = reader.next().await { - match responses { - Ok(responses) => { - for response_message in responses { - loop { - if let Some(Some(ret)) = return_rx.recv().await { - // If the receiver hangs up, just silently ignore - let _ = ret.send(Response { - response: Ok(response_message), - }); - break; + loop { + tokio::select!( + responses = reader.next() => { + tracing::info!("regular path"); + match responses { + Some(Ok(responses)) => { + for response_message in responses { + if let Some(Some(ret)) = return_rx.recv().await { + // If the receiver hangs up, just silently ignore + ret.send(Response { + response: Ok(response_message), + }).ok(); + } } } + Some(Err(e)) => return Err(anyhow!("Couldn't decode message from upstream host {e:?}")), + None => { + // connection closed + break; + } } } - Err(e) => { - debug!("Couldn't decode message from upstream host {:?}", e); - return Err(anyhow!( - "Couldn't decode message from upstream host {:?}", - e - )); + request_id = dummy_request_rx.recv() => { + tracing::info!("dummy {request_id:?}"); + match request_id { + Some(request_id) => if let Some(Some(ret)) = return_rx.recv().await { + let mut response= Message::from_frame(Frame::Dummy); + response.set_request_id(request_id); + ret.send(Response { response: Ok(response) }).ok(); + } + None => { + break; + } + } } - } + ) } Ok(())