From feece72e300c5ef95cb600dcd35434ee24b51c8f Mon Sep 17 00:00:00 2001 From: Lucas Kent Date: Mon, 19 Feb 2024 10:57:45 +1100 Subject: [PATCH] Add dummy requests to better support certain transforms --- .../redis_int_tests/basic_driver_tests.rs | 4 +- shotover-proxy/tests/redis_int_tests/mod.rs | 8 +- shotover/benches/benches/chain.rs | 11 +-- shotover/src/codec/cassandra.rs | 5 ++ shotover/src/message/mod.rs | 26 ++++-- .../src/transforms/cassandra/connection.rs | 78 +++++++++-------- shotover/src/transforms/debug/returner.rs | 46 +++++----- .../tuneable_consistency_scatter.rs | 4 +- shotover/src/transforms/filter.rs | 85 +++++++------------ shotover/src/transforms/loopback.rs | 7 +- shotover/src/transforms/throttling.rs | 58 ++++++------- .../util/cluster_connection_pool.rs | 59 ++++++++----- 12 files changed, 208 insertions(+), 183 deletions(-) diff --git a/shotover-proxy/tests/redis_int_tests/basic_driver_tests.rs b/shotover-proxy/tests/redis_int_tests/basic_driver_tests.rs index 6748be0b2..17fd67064 100644 --- a/shotover-proxy/tests/redis_int_tests/basic_driver_tests.rs +++ b/shotover-proxy/tests/redis_int_tests/basic_driver_tests.rs @@ -1161,7 +1161,7 @@ pub async fn test_cluster_replication( replication_connection: &mut ClusterConnection, ) { // According to the coalesce config the writes are only flushed to the replication cluster after 2000 total writes pass through shotover - for i in 0..1000 { + for i in 0..500 { // 2000 writes havent occured yet so this must be true assert!( replication_connection.get::<&str, i32>("foo").is_err(), @@ -1189,7 +1189,7 @@ pub async fn test_cluster_replication( // although we do need to account for the race condition of shotover returning a response before flushing to the replication cluster let mut value1 = Ok(1); // These dummy values are fine because they get overwritten on the first loop let mut value2 = Ok(b"".to_vec()); - for _ in 0..100 { + for _ in 0..200 { sleep(Duration::from_millis(100)); value1 = 
replication_connection.get("foo"); value2 = replication_connection.get("bar"); diff --git a/shotover-proxy/tests/redis_int_tests/mod.rs b/shotover-proxy/tests/redis_int_tests/mod.rs index e33799534..4f711e3c4 100644 --- a/shotover-proxy/tests/redis_int_tests/mod.rs +++ b/shotover-proxy/tests/redis_int_tests/mod.rs @@ -334,11 +334,5 @@ async fn cluster_dr() { test_dr_auth().await; run_all_cluster_hiding(&mut connection, &mut flusher).await; - shotover - .shutdown_and_then_consume_events(&[EventMatcher::new() - .with_level(Level::Error) - .with_target("shotover::transforms::filter") - .with_message("The current filter transform implementation does not obey the current transform invariants. see https://github.com/shotover/shotover-proxy/issues/499") - ]) - .await; + shotover.shutdown_and_then_consume_events(&[]).await; } diff --git a/shotover/benches/benches/chain.rs b/shotover/benches/benches/chain.rs index 3639f09e0..b05037845 100644 --- a/shotover/benches/benches/chain.rs +++ b/shotover/benches/benches/chain.rs @@ -6,7 +6,7 @@ use hex_literal::hex; use shotover::frame::cassandra::{parse_statement_single, Tracing}; use shotover::frame::RedisFrame; use shotover::frame::{CassandraFrame, CassandraOperation, Frame}; -use shotover::message::{Message, ProtocolType, QueryType}; +use shotover::message::{Message, MessageIdMap, ProtocolType, QueryType}; use shotover::transforms::cassandra::peers_rewrite::CassandraPeersRewrite; use shotover::transforms::chain::{TransformChain, TransformChainBuilder}; use shotover::transforms::debug::returner::{DebugReturner, Response}; @@ -70,6 +70,7 @@ fn criterion_benchmark(c: &mut Criterion) { vec![ Box::new(QueryTypeFilter { filter: Filter::DenyList(vec![QueryType::Read]), + filtered_requests: MessageIdMap::default(), }), Box::new(DebugReturner::new(Response::Redis("a".into()))), ], @@ -106,12 +107,12 @@ fn criterion_benchmark(c: &mut Criterion) { let chain = TransformChainBuilder::new( vec![ Box::new(RedisTimestampTagger::new()), - 
Box::new(DebugReturner::new(Response::Message(vec![ - Message::from_frame(Frame::Redis(RedisFrame::Array(vec![ + Box::new(DebugReturner::new(Response::Message(Message::from_frame( + Frame::Redis(RedisFrame::Array(vec![ RedisFrame::BulkString(Bytes::from_static(b"1")), // real frame RedisFrame::BulkString(Bytes::from_static(b"1")), // timestamp - ]))), - ]))), + ])), + )))), ], "bench", ); diff --git a/shotover/src/codec/cassandra.rs b/shotover/src/codec/cassandra.rs index ce4d839dc..f14e5b004 100644 --- a/shotover/src/codec/cassandra.rs +++ b/shotover/src/codec/cassandra.rs @@ -779,6 +779,11 @@ impl CassandraEncoder { compression: Compression, handshake_complete: bool, ) -> Result<()> { + if m.is_dummy() { + // skip dummy messages + return Ok(()); + } + match (version, handshake_complete) { (Version::V5, true) => { match compression { diff --git a/shotover/src/message/mod.rs b/shotover/src/message/mod.rs index 22d3a2070..1fa6d5e2d 100644 --- a/shotover/src/message/mod.rs +++ b/shotover/src/message/mod.rs @@ -421,12 +421,30 @@ impl Message { }, } } + /// Set this `Message` to a dummy frame so that the message will never reach the client or DB. + /// For requests, the dummy frame will be dropped when it reaches the Sink. + /// Additionally a corresponding dummy response will be generated with its request_id set to the request's id. + /// For responses, the dummy frame will be dropped when it reaches the Source. 
+ pub fn replace_with_dummy(&mut self) { + self.inner = Some(MessageInner::Modified { + frame: Frame::Dummy, + }); + } + + pub fn is_dummy(&self) -> bool { + matches!( + self.inner, + Some(MessageInner::Modified { + frame: Frame::Dummy + }) + ) + } /// Set this `Message` to a backpressure response - pub fn set_backpressure(&mut self) -> Result<()> { + pub fn to_backpressure(&mut self) -> Result { let metadata = self.metadata()?; - *self = Message::from_frame_at_instant( + Ok(Message::from_frame_at_instant( match metadata { #[cfg(feature = "cassandra")] Metadata::Cassandra(metadata) => Frame::Cassandra(metadata.backpressure_response()), @@ -440,9 +458,7 @@ impl Message { // reachable with feature = cassandra #[allow(unreachable_code)] self.received_from_source_or_sink_at, - ); - - Ok(()) + )) } // Retrieves the stream_id without parsing the rest of the frame. diff --git a/shotover/src/transforms/cassandra/connection.rs b/shotover/src/transforms/cassandra/connection.rs index dca8316b5..4d84f242f 100644 --- a/shotover/src/transforms/cassandra/connection.rs +++ b/shotover/src/transforms/cassandra/connection.rs @@ -25,7 +25,7 @@ use tracing::Instrument; struct Request { message: Message, return_chan: oneshot::Sender, - stream_id: i16, + stream_id: Option, } pub type Response = Result; @@ -36,16 +36,17 @@ pub struct ResponseError { #[source] pub cause: anyhow::Error, pub destination: SocketAddr, - pub stream_id: i16, + pub stream_id: Option, } impl ResponseError { pub fn to_response(&self, version: Version) -> Message { - Message::from_frame(Frame::Cassandra(CassandraFrame::shotover_error( - self.stream_id, - version, - &format!("{}", self), - ))) + match self.stream_id { + Some(stream_id) => Message::from_frame(Frame::Cassandra( + CassandraFrame::shotover_error(stream_id, version, &format!("{}", self)), + )), + None => Message::from_frame(Frame::Dummy), + } } } @@ -53,7 +54,8 @@ impl ResponseError { struct ReturnChannel { return_chan: oneshot::Sender, request_id: 
MessageId, - stream_id: i16, + stream_id: Option, + is_dummy: bool, } #[derive(Clone, Derivative)] @@ -148,19 +150,16 @@ impl CassandraConnection { /// But this indicates a bug within CassandraConnection and should be fixed here. pub fn send(&self, message: Message) -> Result> { let (return_chan_tx, return_chan_rx) = oneshot::channel(); - // Convert the message to `Request` and send upstream - if let Some(stream_id) = message.stream_id() { - self.connection - .send(Request { - message, - return_chan: return_chan_tx, - stream_id, - }) - .map(|_| return_chan_rx) - .map_err(|x| x.into()) - } else { - Err(anyhow!("no cassandra frame found")) - } + let stream_id = message.stream_id(); + self.connection + .send(Request { + message, + return_chan: return_chan_tx, + // TODO: delete the stream_id field, we won't need it when we are handling cassandra out of order + stream_id, + }) + .map(|_| return_chan_rx) + .map_err(|x| x.into()) } } @@ -187,6 +186,7 @@ async fn tx_process( loop { if let Some(request) = out_rx.recv().await { let request_id = request.message.id(); + let is_dummy = request.message.is_dummy(); if let Some(error) = &connection_dead_error { send_error_to_request(request.return_chan, request.stream_id, destination, error); } else if let Err(error) = in_w.send(vec![request.message]).await { @@ -197,6 +197,7 @@ async fn tx_process( return_chan: request.return_chan, stream_id: request.stream_id, request_id, + is_dummy, }) { let error = rx_process_has_shutdown_rx .try_recv() @@ -235,7 +236,7 @@ async fn tx_process( fn send_error_to_request( return_chan: oneshot::Sender, - stream_id: i16, + stream_id: Option, destination: SocketAddr, error: &str, ) { @@ -273,10 +274,11 @@ async fn rx_process( // In order to handle that we have two seperate maps. 
// // We store the sender here if we receive from the tx_process task first - let mut from_tx_process: HashMap, MessageId)> = HashMap::new(); + let mut from_tx_process: HashMap, (oneshot::Sender, MessageId)> = + HashMap::new(); // We store the response message here if we receive from the server first. - let mut from_server: HashMap = HashMap::new(); + let mut from_server: HashMap, Message> = HashMap::new(); loop { tokio::select! { @@ -289,7 +291,8 @@ async fn rx_process( if let Some(pushed_messages_tx) = pushed_messages_tx.as_ref() { pushed_messages_tx.send(vec![m]).ok(); } - } else if let Some(stream_id) = m.stream_id() { + } else { + let stream_id = m.stream_id(); match from_tx_process.remove(&stream_id) { None => { from_server.insert(stream_id, m); @@ -322,14 +325,21 @@ async fn rx_process( } }, original_request = return_rx.recv() => { - if let Some(ReturnChannel { return_chan, stream_id,request_id }) = original_request { - match from_server.remove(&stream_id) { - None => { - from_tx_process.insert(stream_id, (return_chan, request_id)); - } - Some(mut m) => { - m.set_request_id(request_id); - return_chan.send(Ok(m)).ok(); + if let Some(ReturnChannel { return_chan, stream_id, request_id, is_dummy }) = original_request { + if is_dummy { + // There will be no response from the DB for this message so we need to generate a dummy response instead. 
+ let mut response = Message::from_frame(Frame::Dummy); + response.set_request_id(request_id); + return_chan.send(Ok(response)).ok(); + } else { + match from_server.remove(&stream_id) { + None => { + from_tx_process.insert(stream_id, (return_chan, request_id)); + } + Some(mut m) => { + m.set_request_id(request_id); + return_chan.send(Ok(m)).ok(); + } } } } else { @@ -346,7 +356,7 @@ async fn rx_process( async fn send_errors_and_shutdown( mut return_rx: mpsc::UnboundedReceiver, - mut waiting: HashMap, MessageId)>, + mut waiting: HashMap, (oneshot::Sender, MessageId)>, rx_process_has_shutdown_tx: oneshot::Sender, destination: SocketAddr, message: &str, diff --git a/shotover/src/transforms/debug/returner.rs b/shotover/src/transforms/debug/returner.rs index efeb183bc..47a2fef3c 100644 --- a/shotover/src/transforms/debug/returner.rs +++ b/shotover/src/transforms/debug/returner.rs @@ -1,4 +1,4 @@ -use crate::message::Messages; +use crate::message::{Message, Messages}; use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::{anyhow, Result}; use async_trait::async_trait; @@ -24,7 +24,7 @@ impl TransformConfig for DebugReturnerConfig { #[serde(deny_unknown_fields)] pub enum Response { #[serde(skip)] - Message(Messages), + Message(Message), #[cfg(feature = "redis")] Redis(String), Fail, @@ -61,24 +61,28 @@ impl Transform for DebugReturner { NAME } - async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result { - match &self.response { - Response::Message(message) => Ok(message.clone()), - #[cfg(feature = "redis")] - Response::Redis(string) => { - use crate::frame::{Frame, RedisFrame}; - use crate::message::Message; - Ok(requests_wrapper - .requests - .iter() - .map(|_| { - Message::from_frame(Frame::Redis(RedisFrame::BulkString( - string.to_string().into(), - ))) - }) - .collect()) - } - Response::Fail => Err(anyhow!("Intentional Fail")), - } + async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> 
Result { + requests_wrapper + .requests + .iter_mut() + .map(|request| match &self.response { + Response::Message(message) => { + let mut message = message.clone(); + message.set_request_id(request.id()); + Ok(message) + } + #[cfg(feature = "redis")] + Response::Redis(string) => { + use crate::frame::{Frame, RedisFrame}; + use crate::message::Message; + let mut message = Message::from_frame(Frame::Redis(RedisFrame::BulkString( + string.to_string().into(), + ))); + message.set_request_id(request.id()); + Ok(message) + } + Response::Fail => Err(anyhow!("Intentional Fail")), + }) + .collect() } } diff --git a/shotover/src/transforms/distributed/tuneable_consistency_scatter.rs b/shotover/src/transforms/distributed/tuneable_consistency_scatter.rs index 5899153bb..b5816fa34 100644 --- a/shotover/src/transforms/distributed/tuneable_consistency_scatter.rs +++ b/shotover/src/transforms/distributed/tuneable_consistency_scatter.rs @@ -317,9 +317,7 @@ mod scatter_transform_tests { #[tokio::test(flavor = "multi_thread")] async fn test_scatter_success() { - let response = vec![Message::from_frame(Frame::Redis(RedisFrame::BulkString( - "OK".into(), - )))]; + let response = Message::from_frame(Frame::Redis(RedisFrame::BulkString("OK".into()))); let wrapper = Wrapper::new_test(vec![Message::from_frame(Frame::Redis( RedisFrame::BulkString(Bytes::from_static(b"foo")), diff --git a/shotover/src/transforms/filter.rs b/shotover/src/transforms/filter.rs index 1d16d21f2..fc5ae28e5 100644 --- a/shotover/src/transforms/filter.rs +++ b/shotover/src/transforms/filter.rs @@ -1,11 +1,8 @@ -use crate::message::{Message, Messages, QueryType}; +use crate::message::{Message, MessageIdMap, Messages, QueryType}; use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::Result; use async_trait::async_trait; use serde::{Deserialize, Serialize}; -use std::sync::atomic::{AtomicBool, Ordering}; - -static SHOWN_ERROR: AtomicBool = AtomicBool::new(false); 
#[derive(Serialize, Deserialize, Debug, Clone)] #[serde(deny_unknown_fields)] @@ -17,6 +14,7 @@ pub enum Filter { #[derive(Debug, Clone)] pub struct QueryTypeFilter { pub filter: Filter, + pub filtered_requests: MessageIdMap, } #[derive(Serialize, Deserialize, Debug)] @@ -33,6 +31,7 @@ impl TransformConfig for QueryTypeFilterConfig { async fn get_builder(&self, _chain_name: String) -> Result> { Ok(Box::new(QueryTypeFilter { filter: self.filter.clone(), + filtered_requests: MessageIdMap::default(), })) } } @@ -54,60 +53,33 @@ impl Transform for QueryTypeFilter { } async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result { - let removed_indexes: Result> = requests_wrapper - .requests - .iter_mut() - .enumerate() - .filter_map(|(i, m)| match self.filter { - Filter::AllowList(ref allow_list) => { - if allow_list.contains(&m.get_query_type()) { - None - } else { - Some((i, m)) - } - } - Filter::DenyList(ref deny_list) => { - if deny_list.contains(&m.get_query_type()) { - Some((i, m)) - } else { - None - } - } - }) - .map(|(i, m)| { - Ok(( - i, - m.to_error_response("Message was filtered out by shotover".to_owned()) - .map_err(|e| e.context("Failed to filter message {e:?}"))?, - )) - }) - .collect(); - - let removed_indexes = removed_indexes?; - - for (i, _) in removed_indexes.iter().rev() { - requests_wrapper.requests.remove(*i); + for request in requests_wrapper.requests.iter_mut() { + let filter_out = match &self.filter { + Filter::AllowList(allow_list) => !allow_list.contains(&request.get_query_type()), + Filter::DenyList(deny_list) => deny_list.contains(&request.get_query_type()), + }; + + if filter_out { + self.filtered_requests.insert( + request.id(), + request + .to_error_response("Message was filtered out by shotover".to_owned()) + .map_err(|e| e.context("Failed to filter message"))?, + ); + request.replace_with_dummy(); + } } - let mut shown_error = SHOWN_ERROR.load(Ordering::Relaxed); - - requests_wrapper - .call_next_transform() - 
.await - .map(|mut messages| { - - for (i, message) in removed_indexes.into_iter() { - if i <= messages.len() { - messages.insert(i, message); - } - else if !shown_error{ - tracing::error!("The current filter transform implementation does not obey the current transform invariants. see https://github.com/shotover/shotover-proxy/issues/499"); - shown_error = true; - SHOWN_ERROR.store(true , Ordering::Relaxed); - } + let mut responses = requests_wrapper.call_next_transform().await?; + for response in responses.iter_mut() { + if let Some(request_id) = response.request_id() { + if let Some(error_response) = self.filtered_requests.remove(&request_id) { + *response = error_response; } - messages - }) + } + } + + Ok(responses) } } @@ -116,6 +88,7 @@ mod test { use super::Filter; use crate::frame::Frame; use crate::frame::RedisFrame; + use crate::message::MessageIdMap; use crate::message::{Message, QueryType}; use crate::transforms::chain::TransformAndMetrics; use crate::transforms::filter::QueryTypeFilter; @@ -126,6 +99,7 @@ mod test { async fn test_filter_denylist() { let mut filter_transform = QueryTypeFilter { filter: Filter::DenyList(vec![QueryType::Read]), + filtered_requests: MessageIdMap::default(), }; let mut chain = vec![TransformAndMetrics::new(Box::new(Loopback::default()))]; @@ -180,6 +154,7 @@ mod test { async fn test_filter_allowlist() { let mut filter_transform = QueryTypeFilter { filter: Filter::AllowList(vec![QueryType::Write]), + filtered_requests: MessageIdMap::default(), }; let mut chain = vec![TransformAndMetrics::new(Box::new(Loopback::default()))]; diff --git a/shotover/src/transforms/loopback.rs b/shotover/src/transforms/loopback.rs index 0e39ba1ec..7345a2832 100644 --- a/shotover/src/transforms/loopback.rs +++ b/shotover/src/transforms/loopback.rs @@ -28,7 +28,12 @@ impl Transform for Loopback { NAME } - async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result { + async fn transform<'a>(&'a mut self, mut requests_wrapper: 
Wrapper<'a>) -> Result { + // This transform ultimately doesn't make a lot of sense semantically + // but make a vague attempt to follow transform invariants anyway. + for request in &mut requests_wrapper.requests { + request.set_request_id(request.id()); + } Ok(requests_wrapper.requests) } } diff --git a/shotover/src/transforms/throttling.rs b/shotover/src/transforms/throttling.rs index 2bae4b1db..8dde15487 100644 --- a/shotover/src/transforms/throttling.rs +++ b/shotover/src/transforms/throttling.rs @@ -1,4 +1,4 @@ -use crate::message::{Message, Messages}; +use crate::message::{Message, MessageIdMap, Messages}; use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::Result; use async_trait::async_trait; @@ -29,6 +29,7 @@ impl TransformConfig for RequestThrottlingConfig { self.max_requests_per_second, ))), max_requests_per_second: self.max_requests_per_second, + throttled_requests: MessageIdMap::default(), })) } } @@ -37,6 +38,7 @@ impl TransformConfig for RequestThrottlingConfig { pub struct RequestThrottling { limiter: Arc>, max_requests_per_second: NonZeroU32, + throttled_requests: MessageIdMap, } impl TransformBuilder for RequestThrottling { @@ -67,43 +69,37 @@ impl Transform for RequestThrottling { } async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result { - // extract throttled messages from the requests_wrapper - let throttled_messages: Vec<(Message, usize)> = (0..requests_wrapper.requests.len()) - .rev() - .filter_map(|i| { - match self - .limiter - .check_n(requests_wrapper.requests[i].cell_count().ok()?) - { - // occurs if all cells can be accommodated and - Ok(Ok(())) => None, + for request in &mut requests_wrapper.requests { + if let Ok(cell_count) = request.cell_count() { + let throttle = match self.limiter.check_n(cell_count) { + // occurs if all cells can be accommodated + Ok(Ok(())) => false, // occurs if not all cells can be accommodated. 
- Ok(Err(_)) => { - let message = requests_wrapper.requests.remove(i); - Some((message, i)) - } + Ok(Err(_)) => true, // occurs when the batch can never go through, meaning the rate limiter's quota's burst size is too low for the given number of cells to be ever allowed through Err(_) => { tracing::warn!("A message was received that could never have been successfully delivered since it contains more sub messages than can ever be allowed through via the `RequestThrottling` transforms `max_requests_per_second` configuration."); - let message = requests_wrapper.requests.remove(i); - Some((message, i)) + true } + }; + if throttle { + self.throttled_requests + .insert(request.id(), request.to_backpressure()?); + request.replace_with_dummy(); } - }) - .collect(); + } + } - // if every message got backpressured we can skip this - let mut responses = if !requests_wrapper.requests.is_empty() { - // send allowed messages to Cassandra - requests_wrapper.call_next_transform().await? - } else { - vec![] - }; + // send allowed messages to Cassandra + let mut responses = requests_wrapper.call_next_transform().await?; - // reinsert backpressure error responses back into responses - for (mut message, i) in throttled_messages.into_iter().rev() { - message.set_backpressure()?; - responses.insert(i, message); + // replace dummy responses with throttle messages + for response in responses.iter_mut() { + if let Some(request_id) = response.request_id() { + if let Some(error_response) = self.throttled_requests.remove(&request_id) { + *response = error_response; + } + } } Ok(responses) @@ -124,6 +120,7 @@ mod test { Box::new(RequestThrottling { limiter: Arc::new(RateLimiter::direct(Quota::per_second(nonzero!(20u32)))), max_requests_per_second: nonzero!(20u32), + throttled_requests: MessageIdMap::default(), }), Box::::default(), ], @@ -146,6 +143,7 @@ mod test { Box::new(RequestThrottling { limiter: Arc::new(RateLimiter::direct(Quota::per_second(nonzero!(100u32)))), max_requests_per_second: 
nonzero!(100u32), + throttled_requests: MessageIdMap::default(), }), Box::::default(), ], diff --git a/shotover/src/transforms/util/cluster_connection_pool.rs b/shotover/src/transforms/util/cluster_connection_pool.rs index 949eb33ff..1aaafa369 100644 --- a/shotover/src/transforms/util/cluster_connection_pool.rs +++ b/shotover/src/transforms/util/cluster_connection_pool.rs @@ -1,5 +1,7 @@ use super::Response; use crate::codec::{CodecBuilder, CodecWriteError, DecoderHalf, EncoderHalf}; +use crate::frame::Frame; +use crate::message::{Message, MessageId}; use crate::tcp; use crate::tls::{TlsConnector, TlsConnectorConfig}; use crate::transforms::util::{ConnectionError, Request}; @@ -193,6 +195,7 @@ pub fn spawn_read_write_tasks< stream_rx: R, stream_tx: W, ) -> Connection { + let (dummy_request_tx, dummy_request_rx) = tokio::sync::mpsc::unbounded_channel(); let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel(); let (return_tx, return_rx) = tokio::sync::mpsc::unbounded_channel(); let (closed_tx, closed_rx) = tokio::sync::oneshot::channel(); @@ -201,7 +204,7 @@ pub fn spawn_read_write_tasks< tokio::spawn(async move { tokio::select! 
{ - result = tx_process(stream_tx, out_rx, return_tx, encoder) => if let Err(e) = result { + result = tx_process(dummy_request_tx, stream_tx, out_rx, return_tx, encoder) => if let Err(e) = result { trace!("connection write-closed with error: {:?}", e); } else { trace!("connection write-closed gracefully"); @@ -214,7 +217,7 @@ pub fn spawn_read_write_tasks< tokio::spawn( async move { - if let Err(e) = rx_process(stream_rx, return_rx, decoder).await { + if let Err(e) = rx_process(dummy_request_rx, stream_rx, return_rx, decoder).await { trace!("connection read-closed with error: {:?}", e); } else { trace!("connection read-closed gracefully"); @@ -230,6 +233,7 @@ pub fn spawn_read_write_tasks< } async fn tx_process( + dummy_request_tx: UnboundedSender, write: W, out_rx: UnboundedReceiver, return_tx: UnboundedSender, @@ -237,6 +241,9 @@ async fn tx_process( ) -> Result<(), CodecWriteError> { let writer = FramedWrite::new(write, codec); let rx_stream = UnboundedReceiverStream::new(out_rx).map(|x| { + if x.message.is_dummy() { + dummy_request_tx.send(x.message.id()).ok(); + } let ret = Ok(vec![x.message]); return_tx .send(x.return_chan) @@ -249,6 +256,7 @@ async fn tx_process( type ReturnChan = Option>; async fn rx_process( + mut dummy_request_rx: UnboundedReceiver, read: R, mut return_rx: UnboundedReceiver, codec: C, @@ -258,29 +266,40 @@ async fn rx_process( // TODO: This reader.next() may perform reads after tx_process has shutdown the write half. // This may result in unexpected ConnectionReset errors. // refer to the cassandra connection logic. 
- while let Some(responses) = reader.next().await { - match responses { - Ok(responses) => { - for response_message in responses { - loop { - if let Some(Some(ret)) = return_rx.recv().await { - // If the receiver hangs up, just silently ignore - let _ = ret.send(Response { - response: Ok(response_message), - }); - break; + loop { + tokio::select!( + responses = reader.next() => { + match responses { + Some(Ok(responses)) => { + for response_message in responses { + if let Some(Some(ret)) = return_rx.recv().await { + // If the receiver hangs up, just silently ignore + ret.send(Response { + response: Ok(response_message), + }).ok(); + } } } + Some(Err(e)) => return Err(anyhow!("Couldn't decode message from upstream host {e:?}")), + None => { + // connection closed + break; + } } } - Err(e) => { - debug!("Couldn't decode message from upstream host {:?}", e); - return Err(anyhow!( - "Couldn't decode message from upstream host {:?}", - e - )); + request_id = dummy_request_rx.recv() => { + match request_id { + Some(request_id) => if let Some(Some(ret)) = return_rx.recv().await { + let mut response= Message::from_frame(Frame::Dummy); + response.set_request_id(request_id); + ret.send(Response { response: Ok(response) }).ok(); + } + None => { + break; + } + } } - } + ) } Ok(())