From 7cf784bf35a12ba9880579e586f63d869f541a7e Mon Sep 17 00:00:00 2001 From: Lucas Kent Date: Fri, 9 Feb 2024 10:21:58 +1100 Subject: [PATCH] codec send Message instead of Vec --- shotover/benches/benches/codec/kafka.rs | 6 +- shotover/src/codec/cassandra.rs | 59 ++++---- shotover/src/codec/kafka.rs | 4 +- shotover/src/codec/mod.rs | 6 +- shotover/src/codec/opensearch.rs | 12 +- shotover/src/codec/redis.rs | 8 +- shotover/src/server.rs | 43 ++++-- .../src/transforms/cassandra/connection.rs | 30 ++-- .../transforms/redis/cluster_ports_rewrite.rs | 7 +- shotover/src/transforms/redis/sink_cluster.rs | 7 +- shotover/src/transforms/redis/sink_single.rs | 140 +++++++++--------- .../util/cluster_connection_pool.rs | 22 ++- 12 files changed, 172 insertions(+), 172 deletions(-) diff --git a/shotover/benches/benches/codec/kafka.rs b/shotover/benches/benches/codec/kafka.rs index 94f92d33d..f741fa9ed 100644 --- a/shotover/benches/benches/codec/kafka.rs +++ b/shotover/benches/benches/codec/kafka.rs @@ -40,8 +40,7 @@ fn criterion_benchmark(c: &mut Criterion) { ) }, |((mut decoder, encoder), mut input)| { - let mut message = - decoder.decode(&mut input).unwrap().unwrap().pop().unwrap(); + let mut message = decoder.decode(&mut input).unwrap().unwrap(); message.frame(); // avoid measuring any drops @@ -57,8 +56,7 @@ fn criterion_benchmark(c: &mut Criterion) { let (mut decoder, encoder) = KafkaCodecBuilder::new(Direction::Source, "kafka".to_owned()).build(); let mut input = input.clone(); - let mut message = - decoder.decode(&mut input).unwrap().unwrap().pop().unwrap(); + let mut message = decoder.decode(&mut input).unwrap().unwrap(); message.frame(); assert!(decoder.decode(&mut input).unwrap().is_none()); (decoder, encoder, message) diff --git a/shotover/src/codec/cassandra.rs b/shotover/src/codec/cassandra.rs index ce4d839dc..2655d2473 100644 --- a/shotover/src/codec/cassandra.rs +++ b/shotover/src/codec/cassandra.rs @@ -18,6 +18,7 @@ use std::collections::HashMap; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::time::Instant; +use std::vec::IntoIter; use tokio_util::codec::{Decoder, Encoder}; use tracing::info; @@ -179,6 +180,7 @@ pub struct CassandraDecoder { version_counter: VersionCounter, expected_payload_len: Option, payload_buffer: BytesMut, + v5_envelopes: IntoIter, } impl CassandraDecoder { @@ -198,6 +200,7 @@ impl CassandraDecoder { version_counter, payload_buffer: BytesMut::new(), expected_payload_len: None, + v5_envelopes: vec![].into_iter(), } } } @@ -243,7 +246,7 @@ impl CassandraDecoder { compression: Compression, handshake_complete: bool, received_at: Instant, - ) -> Result> { + ) -> Result { match (version, handshake_complete) { (Version::V5, true) => match compression { Compression::None => { @@ -278,10 +281,10 @@ impl CassandraDecoder { frame_bytes.advance(UNCOMPRESSED_FRAME_HEADER_LENGTH); let payload = frame_bytes.split_to(payload_length).freeze(); - let envelopes = - self.extract_envelopes_from_payload(payload, self_contained, received_at)?; - - Ok(envelopes) + self.v5_envelopes = self + .extract_envelopes_from_payload(payload, self_contained, received_at)? + .into_iter(); + Ok(self.v5_envelopes.next().unwrap()) } Compression::Lz4 => { let mut frame_bytes = src.split_to(frame_len); @@ -339,10 +342,10 @@ impl CassandraDecoder { .into() }; - let envelopes = - self.extract_envelopes_from_payload(payload, self_contained, received_at)?; - - Ok(envelopes) + self.v5_envelopes = self + .extract_envelopes_from_payload(payload, self_contained, received_at)? + .into_iter(); + Ok(self.v5_envelopes.next().unwrap()) } _ => Err(anyhow!("Only Lz4 compression is supported for v5")), }, @@ -368,7 +371,7 @@ impl CassandraDecoder { Some(received_at), ); - Ok(vec![message]) + Ok(message) } } } @@ -564,10 +567,14 @@ fn set_startup_state( } impl Decoder for CassandraDecoder { - type Item = Messages; + type Item = Message; type Error = CodecReadError; fn decode(&mut self, src: &mut BytesMut) -> Result, CodecReadError> { + if let Some(message) = self.v5_envelopes.next() { + return Ok(Some(message)); + } + let version: Version = self.version.load(Ordering::Relaxed).into(); let compression: Compression = self.compression.load(Ordering::Relaxed).into(); let handshake_complete = self.handshake_complete.load(Ordering::Relaxed); @@ -575,7 +582,7 @@ impl Decoder for CassandraDecoder { match self.check_size(src, version, compression, handshake_complete) { Ok(frame_len) => { - let mut messages = self + let mut message = self .decode_frame( src, frame_len, @@ -586,22 +593,20 @@ impl Decoder for CassandraDecoder { ) .map_err(CodecReadError::Parser)?; - for message in messages.iter_mut() { - if let Ok(Metadata::Cassandra(CassandraMetadata { - opcode: Opcode::Query | Opcode::Batch, - .. - })) = message.metadata() - { - if let Some(keyspace) = get_use_keyspace(message) { - self.current_use_keyspace = Some(keyspace); - } + if let Ok(Metadata::Cassandra(CassandraMetadata { + opcode: Opcode::Query | Opcode::Batch, + .. + })) = message.metadata() + { + if let Some(keyspace) = get_use_keyspace(&mut message) { + self.current_use_keyspace = Some(keyspace); + } - if let Some(keyspace) = &self.current_use_keyspace { - set_default_keyspace(message, keyspace); - } + if let Some(keyspace) = &self.current_use_keyspace { + set_default_keyspace(&mut message, keyspace); } } - Ok(Some(messages)) + Ok(Some(message)) } Err(CheckFrameSizeError::NotEnoughBytes) => Ok(None), Err(CheckFrameSizeError::UnsupportedVersion(version)) => { @@ -1019,10 +1024,10 @@ mod cassandra_protocol_tests { ) { let (mut decoder, mut encoder) = codec.build(); // test decode - let decoded_messages = decoder + let decoded_messages = vec![decoder .decode(&mut BytesMut::from(raw_frame)) .unwrap() - .unwrap(); + .unwrap()]; // test messages parse correctly let mut parsed_messages = decoded_messages.clone(); diff --git a/shotover/src/codec/kafka.rs b/shotover/src/codec/kafka.rs index cc92681f9..096ad452f 100644 --- a/shotover/src/codec/kafka.rs +++ b/shotover/src/codec/kafka.rs @@ -94,7 +94,7 @@ fn get_length_of_full_message(src: &BytesMut) -> Option { } impl Decoder for KafkaDecoder { - type Item = Messages; + type Item = Message; type Error = CodecReadError; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { @@ -128,7 +128,7 @@ impl Decoder for KafkaDecoder { Some(received_at), ) }; - Ok(Some(vec![message])) + Ok(Some(message)) } else { Ok(None) } diff --git a/shotover/src/codec/mod.rs b/shotover/src/codec/mod.rs index 8280c98a3..4552eb2b0 100644 --- a/shotover/src/codec/mod.rs +++ b/shotover/src/codec/mod.rs @@ -1,6 +1,6 @@ //! Codec types to use for connecting to a DB in a sink transform -use crate::message::Messages; +use crate::message::{Message, Messages}; #[cfg(feature = "cassandra")] use cassandra_protocol::compression::Compression; use core::fmt; @@ -114,8 +114,8 @@ impl From for CodecWriteError { } // TODO: Replace with trait_alias (rust-lang/rust#41517). -pub trait DecoderHalf: Decoder + Send {} -impl + Send> DecoderHalf for T {} +pub trait DecoderHalf: Decoder + Send {} +impl + Send> DecoderHalf for T {} // TODO: Replace with trait_alias (rust-lang/rust#41517). pub trait EncoderHalf: Encoder + Send {} diff --git a/shotover/src/codec/opensearch.rs b/shotover/src/codec/opensearch.rs index 1024cddaa..cf93d6dd9 100644 --- a/shotover/src/codec/opensearch.rs +++ b/shotover/src/codec/opensearch.rs @@ -204,7 +204,7 @@ enum State { } impl Decoder for OpenSearchDecoder { - type Item = Messages; + type Item = Message; type Error = CodecReadError; fn decode(&mut self, src: &mut BytesMut) -> Result, CodecReadError> { @@ -236,13 +236,13 @@ impl Decoder for OpenSearchDecoder { } State::ReadingBody(http_headers, content_length) => { if let Some(Method::HEAD) = *self.last_outgoing_method.lock().unwrap() { - return Ok(Some(vec![Message::from_frame_at_instant( + return Ok(Some(Message::from_frame_at_instant( Frame::OpenSearch(OpenSearchFrame::new( http_headers, bytes::Bytes::new(), )), Some(received_at), - )])); + ))); } if src.len() < content_length { @@ -261,7 +261,7 @@ impl Decoder for OpenSearchDecoder { })?; message.set_request_id(id); } - return Ok(Some(vec![message])); + return Ok(Some(message)); } } } @@ -399,7 +399,7 @@ mod opensearch_tests { .unwrap(); let mut dest = BytesMut::new(); - encoder.encode(message, &mut dest).unwrap(); + encoder.encode(vec![message], &mut dest).unwrap(); assert_eq!(raw_frame, &dest); } @@ -429,7 +429,7 @@ mod opensearch_tests { .unwrap(); let mut dest = BytesMut::new(); - encoder.encode(message, &mut dest).unwrap(); + encoder.encode(vec![message], &mut dest).unwrap(); assert_eq!(raw_frame, &dest); } diff --git a/shotover/src/codec/redis.rs b/shotover/src/codec/redis.rs index 3f8366dbf..fa99d117a 100644 --- a/shotover/src/codec/redis.rs +++ b/shotover/src/codec/redis.rs @@ -92,7 +92,7 @@ impl RedisDecoder { } impl Decoder for RedisDecoder { - type Item = Messages; + type Item = Message; type Error = CodecReadError; // TODO: this duplicates a bunch of logic from sink_single.rs @@ -184,7 +184,7 @@ impl Decoder for RedisDecoder { } } } - Ok(Some(vec![message])) + Ok(Some(message)) } None => Ok(None), } @@ -291,10 +291,10 @@ mod redis_tests { fn test_frame(raw_frame: &[u8]) { let (mut decoder, mut encoder) = RedisCodecBuilder::new(Direction::Source, "redis".to_owned()).build(); - let message = decoder + let message = vec![decoder .decode(&mut BytesMut::from(raw_frame)) .unwrap() - .unwrap(); + .unwrap()]; let mut dest = BytesMut::new(); encoder.encode(message, &mut dest).unwrap(); diff --git a/shotover/src/server.rs b/shotover/src/server.rs index d943575fd..7f2112d0b 100644 --- a/shotover/src/server.rs +++ b/shotover/src/server.rs @@ -28,6 +28,9 @@ use tokio_util::codec::{Decoder, Encoder, FramedRead, FramedWrite}; use tracing::Instrument; use tracing::{debug, error, warn}; +// TODO: move into https://github.com/shotover/shotover-proxy/pull/1465 +const DECODE_BUFFER_LEN: usize = 10_000; + pub struct TcpCodecListener { chain_builder: TransformChainBuilder, source_name: String, @@ -291,7 +294,7 @@ async fn spawn_websocket_read_write_tasks< >( codec: C, stream: S, - in_tx: mpsc::Sender, + in_tx: mpsc::Sender, mut out_rx: UnboundedReceiver, out_tx: UnboundedSender, websocket_subprotocol: &str, @@ -443,7 +446,7 @@ fn spawn_read_write_tasks< codec: C, rx: R, tx: W, - in_tx: mpsc::Sender, + in_tx: mpsc::Sender, mut out_rx: UnboundedReceiver, out_tx: UnboundedSender, ) { @@ -585,7 +588,7 @@ impl Handler { // A particular scenario we are concerned about is if it takes longer to send to the server // than for the client to send to us, the buffer will grow indefinitely, increasing latency until the buffer triggers an OoM. // To avoid that we have currently hardcoded a limit of 10,000 but if we start hitting that in production we should make this user configurable. - let (in_tx, in_rx) = mpsc::channel::(10_000); + let (in_tx, in_rx) = mpsc::channel::(10_000); let (out_tx, out_rx) = mpsc::unbounded_channel::(); let local_addr = stream.local_addr()?; @@ -674,19 +677,30 @@ impl Handler { async fn receive_with_timeout( timeout: Option, - in_rx: &mut mpsc::Receiver>, + in_rx: &mut mpsc::Receiver, client_details: &str, - ) -> Option> { + last_received: usize, + ) -> Option { + let mut messages = Vec::with_capacity((last_received * 2).min(DECODE_BUFFER_LEN).max(16)); if let Some(timeout) = timeout { - match tokio::time::timeout(timeout, in_rx.recv()).await { - Ok(messages) => messages, + match tokio::time::timeout(timeout, in_rx.recv_many(&mut messages, DECODE_BUFFER_LEN)) + .await + { + Ok(_) => {} Err(_) => { debug!("Dropping connection to {client_details} due to being idle for more than {timeout:?}"); - None + return None; } } } else { - in_rx.recv().await + in_rx.recv_many(&mut messages, DECODE_BUFFER_LEN).await; + } + + if messages.is_empty() { + // No messages indicates that the channel has been closed + None + } else { + Some(messages) } } @@ -694,21 +708,20 @@ impl Handler { &mut self, client_details: &str, local_addr: SocketAddr, - mut in_rx: mpsc::Receiver, + mut in_rx: mpsc::Receiver, out_tx: mpsc::UnboundedSender, ) -> Result<()> { // As long as the shutdown signal has not been received, try to read a // new request frame. + let mut last_received = 0; while !self.shutdown.is_shutdown() { // While reading a request frame, also listen for the shutdown signal debug!("Waiting for message {client_details}"); let responses = tokio::select! { - requests = Self::receive_with_timeout(self.timeout, &mut in_rx, client_details) => { + requests = Self::receive_with_timeout(self.timeout, &mut in_rx, client_details,last_received) => { match requests { - Some(mut requests) => { - while let Ok(x) = in_rx.try_recv() { - requests.extend(x); - } + Some(requests) => { + last_received = requests.len(); debug!("Received requests from client {:?}", requests); self.process_forward(client_details, local_addr, &out_tx, requests).await? } diff --git a/shotover/src/transforms/cassandra/connection.rs b/shotover/src/transforms/cassandra/connection.rs index dca8316b5..553a7c2c3 100644 --- a/shotover/src/transforms/cassandra/connection.rs +++ b/shotover/src/transforms/cassandra/connection.rs @@ -282,22 +282,20 @@ async fn rx_process( tokio::select! { response = reader.next() => { match response { - Some(Ok(response)) => { - for mut m in response { - let meta = m.metadata(); - if let Ok(Metadata::Cassandra(CassandraMetadata { opcode: Opcode::Event, .. })) = meta { - if let Some(pushed_messages_tx) = pushed_messages_tx.as_ref() { - pushed_messages_tx.send(vec![m]).ok(); - } - } else if let Some(stream_id) = m.stream_id() { - match from_tx_process.remove(&stream_id) { - None => { - from_server.insert(stream_id, m); - }, - Some((return_tx, request_id)) => { - m.set_request_id(request_id); - return_tx.send(Ok(m)).ok(); - } + Some(Ok(mut m)) => { + let meta = m.metadata(); + if let Ok(Metadata::Cassandra(CassandraMetadata { opcode: Opcode::Event, .. })) = meta { + if let Some(pushed_messages_tx) = pushed_messages_tx.as_ref() { + pushed_messages_tx.send(vec![m]).ok(); + } + } else if let Some(stream_id) = m.stream_id() { + match from_tx_process.remove(&stream_id) { + None => { + from_server.insert(stream_id, m); + }, + Some((return_tx, request_id)) => { + m.set_request_id(request_id); + return_tx.send(Ok(m)).ok(); } } } diff --git a/shotover/src/transforms/redis/cluster_ports_rewrite.rs b/shotover/src/transforms/redis/cluster_ports_rewrite.rs index 6154c1c44..ba027eeb6 100644 --- a/shotover/src/transforms/redis/cluster_ports_rewrite.rs +++ b/shotover/src/transforms/redis/cluster_ports_rewrite.rs @@ -288,12 +288,7 @@ mod test { fn test_rewrite_port_slots() { let slots_pcap: &[u8] = b"*3\r\n*4\r\n:10923\r\n:16383\r\n*3\r\n$12\r\n192.168.80.6\r\n:6379\r\n$40\r\n3a7c357ed75d2aa01fca1e14ef3735a2b2b8ffac\r\n*3\r\n$12\r\n192.168.80.3\r\n:6379\r\n$40\r\n77c01b0ddd8668fff05e3f6a8aaf5f3ccd454a79\r\n*4\r\n:5461\r\n:10922\r\n*3\r\n$12\r\n192.168.80.5\r\n:6379\r\n$40\r\n969c6215d064e68593d384541ceeb57e9520dbed\r\n*3\r\n$12\r\n192.168.80.2\r\n:6379\r\n$40\r\n3929f69990a75be7b2d49594c57fe620862e6fd6\r\n*4\r\n:0\r\n:5460\r\n*3\r\n$12\r\n192.168.80.7\r\n:6379\r\n$40\r\n15d52a65d1fc7a53e34bf9193415aa39136882b2\r\n*3\r\n$12\r\n192.168.80.4\r\n:6379\r\n$40\r\ncd023916a3528fae7e606a10d8289a665d6c47b0\r\n"; let mut codec = RedisDecoder::new(None, Direction::Sink); - let mut message = codec - .decode(&mut slots_pcap.into()) - .unwrap() - .unwrap() - .pop() - .unwrap(); + let mut message = codec.decode(&mut slots_pcap.into()).unwrap().unwrap(); rewrite_port_slot(message.frame().unwrap(), 6380).unwrap(); diff --git a/shotover/src/transforms/redis/sink_cluster.rs b/shotover/src/transforms/redis/sink_cluster.rs index 952267c27..8b2b12250 100644 --- a/shotover/src/transforms/redis/sink_cluster.rs +++ b/shotover/src/transforms/redis/sink_cluster.rs @@ -1147,12 +1147,7 @@ mod test { let mut codec = RedisDecoder::new(None, Direction::Sink); - let mut message = codec - .decode(&mut slots_pcap.into()) - .unwrap() - .unwrap() - .pop() - .unwrap(); + let mut message = codec.decode(&mut slots_pcap.into()).unwrap().unwrap(); let slots_frames = match message.frame().unwrap() { Frame::Redis(RedisFrame::Array(frames)) => frames, diff --git a/shotover/src/transforms/redis/sink_single.rs b/shotover/src/transforms/redis/sink_single.rs index 9521de68b..c4a24305c 100644 --- a/shotover/src/transforms/redis/sink_single.rs +++ b/shotover/src/transforms/redis/sink_single.rs @@ -248,96 +248,94 @@ async fn server_response_processing_task( /// returns true when the task should shutdown async fn process_server_response( - responses: Option>, + response: Option>, subscribe_tx: &Option>, response_messages_tx: &mpsc::UnboundedSender, is_subscribed: &mut bool, sent_message_type: &mut mpsc::UnboundedReceiver, ) -> bool { - match responses { - Some(Ok(messages)) => { - for mut message in messages { - // Notes on subscription responses - // - // There are 3 types of pubsub responses and the type is determined by the first value in the array: - // * `subscribe` - a response to a SUBSCRIBE, PSUBSCRIBE or SSUBSCRIBE request - // * `unsubscribe` - a response to an UNSUBSCRIBE, PUNSUBSCRIBE or SUNSUBSCRIBE request - // * `message` - a subscription message - // - // Additionally redis will: - // * accept a few regular commands while in pubsub mode: PING, RESET and QUIT - // * return an error response when a nonexistent or non pubsub compatible command is used - // - // Note: PING has a custom response when in pubsub mode. - // It returns an array ['pong', $pingMessage] instead of directly returning $pingMessage. - // But this doesnt cause any problems for us. + match response { + Some(Ok(mut message)) => { + // Notes on subscription responses + // + // There are 3 types of pubsub responses and the type is determined by the first value in the array: + // * `subscribe` - a response to a SUBSCRIBE, PSUBSCRIBE or SSUBSCRIBE request + // * `unsubscribe` - a response to an UNSUBSCRIBE, PUNSUBSCRIBE or SUNSUBSCRIBE request + // * `message` - a subscription message + // + // Additionally redis will: + // * accept a few regular commands while in pubsub mode: PING, RESET and QUIT + // * return an error response when a nonexistent or non pubsub compatible command is used + // + // Note: PING has a custom response when in pubsub mode. + // It returns an array ['pong', $pingMessage] instead of directly returning $pingMessage. + // But this doesnt cause any problems for us. - // Determine if message is a `message` subscription message - // - // Because PING, RESET, QUIT and error responses never return a RedisFrame::Array starting with `message`, - // they have no way to collide with the `message` value of a subscription message. - // So while we are in subscription mode we can use that to determine if an - // incoming message is a subscription message. - let is_subscription_message = if *is_subscribed { - if let Some(Frame::Redis(RedisFrame::Array(array))) = message.frame() { - if let [RedisFrame::BulkString(ty), ..] = array.as_slice() { - ty.as_ref() == b"message" - } else { - false - } + // Determine if message is a `message` subscription message + // + // Because PING, RESET, QUIT and error responses never return a RedisFrame::Array starting with `message`, + // they have no way to collide with the `message` value of a subscription message. + // So while we are in subscription mode we can use that to determine if an + // incoming message is a subscription message. + let is_subscription_message = if *is_subscribed { + if let Some(Frame::Redis(RedisFrame::Array(array))) = message.frame() { + if let [RedisFrame::BulkString(ty), ..] = array.as_slice() { + ty.as_ref() == b"message" } else { false } } else { false - }; + } + } else { + false + }; - // Update is_subscribed state - // - // In order to make sense of a response we need the main task to - // send us the type of its corresponding request. - // - // In order to keep the incoming request MessageTypes in sync with their corresponding responses - // we must only process a MessageType when the message is not a subscription message. - // This is fine because subscription messages cannot affect the is_subscribed state. - if !is_subscription_message { - match sent_message_type.recv().await { - Some(MessageType::Subscribe) | Some(MessageType::Unsubscribe) => { - if let Some(Frame::Redis(RedisFrame::Array(array))) = message.frame() { - if let Some(RedisFrame::Integer(number_of_subscribed_channels)) = - array.get(2) - { - *is_subscribed = *number_of_subscribed_channels != 0; - } + // Update is_subscribed state + // + // In order to make sense of a response we need the main task to + // send us the type of its corresponding request. + // + // In order to keep the incoming request MessageTypes in sync with their corresponding responses + // we must only process a MessageType when the message is not a subscription message. + // This is fine because subscription messages cannot affect the is_subscribed state. + if !is_subscription_message { + match sent_message_type.recv().await { + Some(MessageType::Subscribe) | Some(MessageType::Unsubscribe) => { + if let Some(Frame::Redis(RedisFrame::Array(array))) = message.frame() { + if let Some(RedisFrame::Integer(number_of_subscribed_channels)) = + array.get(2) + { + *is_subscribed = *number_of_subscribed_channels != 0; } } - Some(MessageType::Other) => {} - Some(MessageType::Reset) => { - *is_subscribed = false; - } - None => { - tracing::debug!("RedisSinkSingle dropped after a message was received from server, RedisSinkSingle request processor task shutting down"); - return true; - } + } + Some(MessageType::Other) => {} + Some(MessageType::Reset) => { + *is_subscribed = false; + } + None => { + tracing::debug!("RedisSinkSingle dropped after a message was received from server, RedisSinkSingle request processor task shutting down"); + return true; } } + } - // Route the message down the correct path: - // * `message` subscription messages: - // needs to be routed down the pushed_messages chain - // * everything else: - // needs to be routed down the regular chain - if is_subscription_message { - // subscribe_tx may not exist if we are e.g. in an alternate chain of a tee transform - if let Some(subscribe_tx) = subscribe_tx { - if let Err(mpsc::error::SendError(_)) = subscribe_tx.send(vec![message]) { - tracing::debug!("shotover chain is terminated, will continue running until Transform is dropped"); - } + // Route the message down the correct path: + // * `message` subscription messages: + // needs to be routed down the pushed_messages chain + // * everything else: + // needs to be routed down the regular chain + if is_subscription_message { + // subscribe_tx may not exist if we are e.g. in an alternate chain of a tee transform + if let Some(subscribe_tx) = subscribe_tx { + if let Err(mpsc::error::SendError(_)) = subscribe_tx.send(vec![message]) { + tracing::debug!("shotover chain is terminated, will continue running until Transform is dropped"); } - } else if let Err(mpsc::error::SendError(_)) = response_messages_tx.send(message) { - tracing::debug!("RedisSinkSingle dropped after a message was received from server, RedisSinkSingle request processor task shutting down"); - return true; } + } else if let Err(mpsc::error::SendError(_)) = response_messages_tx.send(message) { + tracing::debug!("RedisSinkSingle dropped after a message was received from server, RedisSinkSingle request processor task shutting down"); + return true; } false } diff --git a/shotover/src/transforms/util/cluster_connection_pool.rs b/shotover/src/transforms/util/cluster_connection_pool.rs index 949eb33ff..fd051ba38 100644 --- a/shotover/src/transforms/util/cluster_connection_pool.rs +++ b/shotover/src/transforms/util/cluster_connection_pool.rs @@ -258,18 +258,16 @@ async fn rx_process( // TODO: This reader.next() may perform reads after tx_process has shutdown the write half. // This may result in unexpected ConnectionReset errors. // refer to the cassandra connection logic. - while let Some(responses) = reader.next().await { - match responses { - Ok(responses) => { - for response_message in responses { - loop { - if let Some(Some(ret)) = return_rx.recv().await { - // If the receiver hangs up, just silently ignore - let _ = ret.send(Response { - response: Ok(response_message), - }); - break; - } + while let Some(response) = reader.next().await { + match response { + Ok(response_message) => { + loop { + if let Some(Some(ret)) = return_rx.recv().await { + // If the receiver hangs up, just silently ignore + let _ = ret.send(Response { + response: Ok(response_message), + }); + break; } } }