From 24fa40c7401be0c982d41ddb1ecd705b0349a9de Mon Sep 17 00:00:00 2001 From: Lucas Kent Date: Fri, 9 Feb 2024 10:21:58 +1100 Subject: [PATCH] codec send Message instead of Vec --- shotover/benches/benches/codec/kafka.rs | 10 +- shotover/src/codec/cassandra.rs | 44 +++--- shotover/src/codec/kafka.rs | 6 +- shotover/src/codec/mod.rs | 6 +- shotover/src/codec/opensearch.rs | 12 +- shotover/src/codec/redis.rs | 10 +- shotover/src/server.rs | 18 ++- .../src/transforms/cassandra/connection.rs | 26 ++-- .../transforms/redis/cluster_ports_rewrite.rs | 7 +- shotover/src/transforms/redis/sink_cluster.rs | 7 +- shotover/src/transforms/redis/sink_single.rs | 140 +++++++++--------- .../util/cluster_connection_pool.rs | 22 ++- 12 files changed, 145 insertions(+), 163 deletions(-) diff --git a/shotover/benches/benches/codec/kafka.rs b/shotover/benches/benches/codec/kafka.rs index 99810c13e..b72f90549 100644 --- a/shotover/benches/benches/codec/kafka.rs +++ b/shotover/benches/benches/codec/kafka.rs @@ -40,11 +40,10 @@ fn criterion_benchmark(c: &mut Criterion) { ) }, |((mut decoder, _encoder), mut input)| { - let mut result = decoder.decode(&mut input).unwrap().unwrap(); - for message in &mut result { + while let Some(mut message) = decoder.decode(&mut input).unwrap() { message.frame(); + black_box(message); } - black_box(result) }, BatchSize::SmallInput, ) @@ -98,11 +97,10 @@ fn criterion_benchmark(c: &mut Criterion) { ) }, |((mut decoder, _encoder), mut input)| { - let mut result = decoder.decode(&mut input).unwrap().unwrap(); - for message in &mut result { + while let Some(mut message) = decoder.decode(&mut input).unwrap() { message.frame(); + black_box(message); } - black_box(result) }, BatchSize::SmallInput, ) diff --git a/shotover/src/codec/cassandra.rs b/shotover/src/codec/cassandra.rs index 87d8be543..ab8306ecb 100644 --- a/shotover/src/codec/cassandra.rs +++ b/shotover/src/codec/cassandra.rs @@ -243,7 +243,7 @@ impl CassandraDecoder { compression: Compression, handshake_complete: bool, received_at: Instant, - ) -> Result> { + ) -> Result { match (version, handshake_complete) { (Version::V5, true) => match compression { Compression::None => { @@ -278,10 +278,10 @@ impl CassandraDecoder { frame_bytes.advance(UNCOMPRESSED_FRAME_HEADER_LENGTH); let payload = frame_bytes.split_to(payload_length).freeze(); - let envelopes = + let mut envelopes = self.extract_envelopes_from_payload(payload, self_contained, received_at)?; - Ok(envelopes) + Ok(envelopes.pop().unwrap()) //TODO } Compression::Lz4 => { let mut frame_bytes = src.split_to(frame_len); @@ -339,10 +339,10 @@ impl CassandraDecoder { .into() }; - let envelopes = + let mut envelopes = self.extract_envelopes_from_payload(payload, self_contained, received_at)?; - Ok(envelopes) + Ok(envelopes.pop().unwrap()) //TODO } _ => Err(anyhow!("Only Lz4 compression is supported for v5")), }, @@ -368,7 +368,7 @@ impl CassandraDecoder { Some(received_at), ); - Ok(vec![message]) + Ok(message) } } } @@ -564,7 +564,7 @@ fn set_startup_state( } impl Decoder for CassandraDecoder { - type Item = Messages; + type Item = Message; type Error = CodecReadError; fn decode(&mut self, src: &mut BytesMut) -> Result, CodecReadError> { @@ -575,7 +575,7 @@ impl Decoder for CassandraDecoder { match self.check_size(src, version, compression, handshake_complete) { Ok(frame_len) => { - let mut messages = self + let mut message = self .decode_frame( src, frame_len, @@ -586,22 +586,20 @@ impl Decoder for CassandraDecoder { ) .map_err(CodecReadError::Parser)?; - for message in messages.iter_mut() { - if let Ok(Metadata::Cassandra(CassandraMetadata { - opcode: Opcode::Query | Opcode::Batch, - .. - })) = message.metadata() - { - if let Some(keyspace) = get_use_keyspace(message) { - self.current_use_keyspace = Some(keyspace); - } + if let Ok(Metadata::Cassandra(CassandraMetadata { + opcode: Opcode::Query | Opcode::Batch, + .. + })) = message.metadata() + { + if let Some(keyspace) = get_use_keyspace(&mut message) { + self.current_use_keyspace = Some(keyspace); + } - if let Some(keyspace) = &self.current_use_keyspace { - set_default_keyspace(message, keyspace); - } + if let Some(keyspace) = &self.current_use_keyspace { + set_default_keyspace(&mut message, keyspace); } } - Ok(Some(messages)) + Ok(Some(message)) } Err(CheckFrameSizeError::NotEnoughBytes) => Ok(None), Err(CheckFrameSizeError::UnsupportedVersion(version)) => { @@ -1019,10 +1017,10 @@ mod cassandra_protocol_tests { ) { let (mut decoder, mut encoder) = codec.build(); // test decode - let decoded_messages = decoder + let decoded_messages = vec![decoder .decode(&mut BytesMut::from(raw_frame)) .unwrap() - .unwrap(); + .unwrap()]; // test messages parse correctly let mut parsed_messages = decoded_messages.clone(); diff --git a/shotover/src/codec/kafka.rs b/shotover/src/codec/kafka.rs index 567a02390..9a980b851 100644 --- a/shotover/src/codec/kafka.rs +++ b/shotover/src/codec/kafka.rs @@ -88,7 +88,7 @@ fn get_length_of_full_message(src: &BytesMut) -> Option { } impl Decoder for KafkaDecoder { - type Item = Messages; + type Item = Message; type Error = CodecReadError; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { @@ -108,11 +108,11 @@ impl Decoder for KafkaDecoder { } else { None }; - Ok(Some(vec![Message::from_bytes_at_instant( + Ok(Some(Message::from_bytes_at_instant( bytes.freeze(), ProtocolType::Kafka { request_header }, Some(received_at), - )])) + ))) } else { Ok(None) } diff --git a/shotover/src/codec/mod.rs b/shotover/src/codec/mod.rs index 76f2a89a7..a02de1db3 100644 --- a/shotover/src/codec/mod.rs +++ b/shotover/src/codec/mod.rs @@ -1,6 +1,6 @@ //! Codec types to use for connecting to a DB in a sink transform -use crate::message::Messages; +use crate::message::{Message, Messages}; #[cfg(feature = "cassandra")] use cassandra_protocol::compression::Compression; use core::fmt; @@ -114,8 +114,8 @@ impl From for CodecWriteError { } // TODO: Replace with trait_alias (rust-lang/rust#41517). -pub trait DecoderHalf: Decoder + Send {} -impl + Send> DecoderHalf for T {} +pub trait DecoderHalf: Decoder + Send {} +impl + Send> DecoderHalf for T {} // TODO: Replace with trait_alias (rust-lang/rust#41517). pub trait EncoderHalf: Encoder + Send {} diff --git a/shotover/src/codec/opensearch.rs b/shotover/src/codec/opensearch.rs index 287629349..7914511e3 100644 --- a/shotover/src/codec/opensearch.rs +++ b/shotover/src/codec/opensearch.rs @@ -186,7 +186,7 @@ enum State { } impl Decoder for OpenSearchDecoder { - type Item = Messages; + type Item = Message; type Error = CodecReadError; fn decode(&mut self, src: &mut BytesMut) -> Result, CodecReadError> { @@ -218,13 +218,13 @@ impl Decoder for OpenSearchDecoder { } State::ReadingBody(http_headers, content_length) => { if let Some(Method::HEAD) = *self.last_outgoing_method.lock().unwrap() { - return Ok(Some(vec![Message::from_frame_at_instant( + return Ok(Some(Message::from_frame_at_instant( Frame::OpenSearch(OpenSearchFrame::new( http_headers, bytes::Bytes::new(), )), Some(received_at), - )])); + ))); } if src.len() < content_length { @@ -233,10 +233,10 @@ impl Decoder for OpenSearchDecoder { } let body = src.split_to(content_length).freeze(); - return Ok(Some(vec![Message::from_frame_at_instant( + return Ok(Some(Message::from_frame_at_instant( Frame::OpenSearch(OpenSearchFrame::new(http_headers, body)), Some(received_at), - )])); + ))); } } } @@ -357,7 +357,7 @@ mod opensearch_tests { .unwrap(); let mut dest = BytesMut::new(); - encoder.encode(message, &mut dest).unwrap(); + encoder.encode(vec![message], &mut dest).unwrap(); assert_eq!(raw_frame, &dest); } diff --git a/shotover/src/codec/redis.rs b/shotover/src/codec/redis.rs index 0c6462600..1567542e3 100644 --- a/shotover/src/codec/redis.rs +++ b/shotover/src/codec/redis.rs @@ -57,7 +57,7 @@ impl RedisDecoder { } impl Decoder for RedisDecoder { - type Item = Messages; + type Item = Message; type Error = CodecReadError; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { @@ -71,11 +71,11 @@ impl Decoder for RedisDecoder { self.direction, pretty_hex::pretty_hex(&bytes) ); - Ok(Some(vec![Message::from_bytes_and_frame_at_instant( + Ok(Some(Message::from_bytes_and_frame_at_instant( bytes, Frame::Redis(frame), Some(received_at), - )])) + ))) } None => Ok(None), } @@ -157,10 +157,10 @@ mod redis_tests { fn test_frame(raw_frame: &[u8]) { let (mut decoder, mut encoder) = RedisCodecBuilder::new(Direction::Sink, "redis".to_owned()).build(); - let message = decoder + let message = vec![decoder .decode(&mut BytesMut::from(raw_frame)) .unwrap() - .unwrap(); + .unwrap()]; let mut dest = BytesMut::new(); encoder.encode(message, &mut dest).unwrap(); diff --git a/shotover/src/server.rs b/shotover/src/server.rs index 0c19f2630..c9ba1d599 100644 --- a/shotover/src/server.rs +++ b/shotover/src/server.rs @@ -290,7 +290,7 @@ async fn spawn_websocket_read_write_tasks< >( codec: C, stream: S, - in_tx: UnboundedSender, + in_tx: UnboundedSender, mut out_rx: UnboundedReceiver, out_tx: UnboundedSender, websocket_subprotocol: &str, @@ -442,7 +442,7 @@ fn spawn_read_write_tasks< codec: C, rx: R, tx: W, - in_tx: UnboundedSender, + in_tx: UnboundedSender, mut out_rx: UnboundedReceiver, out_tx: UnboundedSender, ) { @@ -580,7 +580,7 @@ impl Handler { .unwrap_or_else(|_| "Unknown peer".to_string()); tracing::debug!("New connection from {}", client_details); - let (in_tx, in_rx) = mpsc::unbounded_channel::(); + let (in_tx, in_rx) = mpsc::unbounded_channel::(); let (out_tx, out_rx) = mpsc::unbounded_channel::(); let local_addr = stream.local_addr()?; @@ -669,9 +669,9 @@ impl Handler { async fn receive_with_timeout( timeout: Option, - in_rx: &mut UnboundedReceiver>, + in_rx: &mut UnboundedReceiver, client_details: &str, - ) -> Option> { + ) -> Option { if let Some(timeout) = timeout { match tokio::time::timeout(timeout, in_rx.recv()).await { Ok(messages) => messages, @@ -689,7 +689,7 @@ impl Handler { &mut self, client_details: &str, local_addr: SocketAddr, - mut in_rx: mpsc::UnboundedReceiver, + mut in_rx: mpsc::UnboundedReceiver, out_tx: mpsc::UnboundedSender, ) -> Result<()> { // As long as the shutdown signal has not been received, try to read a @@ -700,9 +700,11 @@ impl Handler { let responses = tokio::select! { requests = Self::receive_with_timeout(self.timeout, &mut in_rx, client_details) => { match requests { - Some(mut requests) => { + Some(request) => { + // TODO: use tokio method + let mut requests = vec!(request); while let Ok(x) = in_rx.try_recv() { - requests.extend(x); + requests.push(x); } debug!("Received requests from client {:?}", requests); self.process_forward(client_details, local_addr, &out_tx, requests).await? diff --git a/shotover/src/transforms/cassandra/connection.rs b/shotover/src/transforms/cassandra/connection.rs index 03a12406e..e709b1748 100644 --- a/shotover/src/transforms/cassandra/connection.rs +++ b/shotover/src/transforms/cassandra/connection.rs @@ -280,20 +280,18 @@ async fn rx_process( response = reader.next() => { match response { Some(Ok(response)) => { - for m in response { - let meta = m.metadata(); - if let Ok(Metadata::Cassandra(CassandraMetadata { opcode: Opcode::Event, .. })) = meta { - if let Some(pushed_messages_tx) = pushed_messages_tx.as_ref() { - pushed_messages_tx.send(vec![m]).ok(); - } - } else if let Some(stream_id) = m.stream_id() { - match from_tx_process.remove(&stream_id) { - None => { - from_server.insert(stream_id, m); - }, - Some(return_tx) => { - return_tx.send(Ok(m)).ok(); - } + let meta = response.metadata(); + if let Ok(Metadata::Cassandra(CassandraMetadata { opcode: Opcode::Event, .. })) = meta { + if let Some(pushed_messages_tx) = pushed_messages_tx.as_ref() { + pushed_messages_tx.send(vec![response]).ok(); + } + } else if let Some(stream_id) = response.stream_id() { + match from_tx_process.remove(&stream_id) { + None => { + from_server.insert(stream_id, response); + }, + Some(return_tx) => { + return_tx.send(Ok(response)).ok(); } } } diff --git a/shotover/src/transforms/redis/cluster_ports_rewrite.rs b/shotover/src/transforms/redis/cluster_ports_rewrite.rs index b2904b03c..48467944b 100644 --- a/shotover/src/transforms/redis/cluster_ports_rewrite.rs +++ b/shotover/src/transforms/redis/cluster_ports_rewrite.rs @@ -288,12 +288,7 @@ mod test { fn test_rewrite_port_slots() { let slots_pcap: &[u8] = b"*3\r\n*4\r\n:10923\r\n:16383\r\n*3\r\n$12\r\n192.168.80.6\r\n:6379\r\n$40\r\n3a7c357ed75d2aa01fca1e14ef3735a2b2b8ffac\r\n*3\r\n$12\r\n192.168.80.3\r\n:6379\r\n$40\r\n77c01b0ddd8668fff05e3f6a8aaf5f3ccd454a79\r\n*4\r\n:5461\r\n:10922\r\n*3\r\n$12\r\n192.168.80.5\r\n:6379\r\n$40\r\n969c6215d064e68593d384541ceeb57e9520dbed\r\n*3\r\n$12\r\n192.168.80.2\r\n:6379\r\n$40\r\n3929f69990a75be7b2d49594c57fe620862e6fd6\r\n*4\r\n:0\r\n:5460\r\n*3\r\n$12\r\n192.168.80.7\r\n:6379\r\n$40\r\n15d52a65d1fc7a53e34bf9193415aa39136882b2\r\n*3\r\n$12\r\n192.168.80.4\r\n:6379\r\n$40\r\ncd023916a3528fae7e606a10d8289a665d6c47b0\r\n"; let mut codec = RedisDecoder::new(Direction::Sink); - let mut message = codec - .decode(&mut slots_pcap.into()) - .unwrap() - .unwrap() - .pop() - .unwrap(); + let mut message = codec.decode(&mut slots_pcap.into()).unwrap().unwrap(); rewrite_port_slot(message.frame().unwrap(), 6380).unwrap(); diff --git a/shotover/src/transforms/redis/sink_cluster.rs b/shotover/src/transforms/redis/sink_cluster.rs index ae330d517..db6bcc9af 100644 --- a/shotover/src/transforms/redis/sink_cluster.rs +++ b/shotover/src/transforms/redis/sink_cluster.rs @@ -1147,12 +1147,7 @@ mod test { let mut codec = RedisDecoder::new(Direction::Sink); - let mut message = codec - .decode(&mut slots_pcap.into()) - .unwrap() - .unwrap() - .pop() - .unwrap(); + let mut message = codec.decode(&mut slots_pcap.into()).unwrap().unwrap(); let slots_frames = match message.frame().unwrap() { Frame::Redis(RedisFrame::Array(frames)) => frames, diff --git a/shotover/src/transforms/redis/sink_single.rs b/shotover/src/transforms/redis/sink_single.rs index 54af80b13..c20cabd5b 100644 --- a/shotover/src/transforms/redis/sink_single.rs +++ b/shotover/src/transforms/redis/sink_single.rs @@ -248,96 +248,94 @@ async fn server_response_processing_task( /// returns true when the task should shutdown async fn process_server_response( - responses: Option>, + response: Option>, subscribe_tx: &Option>, response_messages_tx: &mpsc::UnboundedSender, is_subscribed: &mut bool, sent_message_type: &mut mpsc::UnboundedReceiver, ) -> bool { - match responses { - Some(Ok(messages)) => { - for mut message in messages { - // Notes on subscription responses - // - // There are 3 types of pubsub responses and the type is determined by the first value in the array: - // * `subscribe` - a response to a SUBSCRIBE, PSUBSCRIBE or SSUBSCRIBE request - // * `unsubscribe` - a response to an UNSUBSCRIBE, PUNSUBSCRIBE or SUNSUBSCRIBE request - // * `message` - a subscription message - // - // Additionally redis will: - // * accept a few regular commands while in pubsub mode: PING, RESET and QUIT - // * return an error response when a nonexistent or non pubsub compatible command is used - // - // Note: PING has a custom response when in pubsub mode. - // It returns an array ['pong', $pingMessage] instead of directly returning $pingMessage. - // But this doesnt cause any problems for us. + match response { + Some(Ok(mut message)) => { + // Notes on subscription responses + // + // There are 3 types of pubsub responses and the type is determined by the first value in the array: + // * `subscribe` - a response to a SUBSCRIBE, PSUBSCRIBE or SSUBSCRIBE request + // * `unsubscribe` - a response to an UNSUBSCRIBE, PUNSUBSCRIBE or SUNSUBSCRIBE request + // * `message` - a subscription message + // + // Additionally redis will: + // * accept a few regular commands while in pubsub mode: PING, RESET and QUIT + // * return an error response when a nonexistent or non pubsub compatible command is used + // + // Note: PING has a custom response when in pubsub mode. + // It returns an array ['pong', $pingMessage] instead of directly returning $pingMessage. + // But this doesnt cause any problems for us. - // Determine if message is a `message` subscription message - // - // Because PING, RESET, QUIT and error responses never return a RedisFrame::Array starting with `message`, - // they have no way to collide with the `message` value of a subscription message. - // So while we are in subscription mode we can use that to determine if an - // incoming message is a subscription message. - let is_subscription_message = if *is_subscribed { - if let Some(Frame::Redis(RedisFrame::Array(array))) = message.frame() { - if let [RedisFrame::BulkString(ty), ..] = array.as_slice() { - ty.as_ref() == b"message" - } else { - false - } + // Determine if message is a `message` subscription message + // + // Because PING, RESET, QUIT and error responses never return a RedisFrame::Array starting with `message`, + // they have no way to collide with the `message` value of a subscription message. + // So while we are in subscription mode we can use that to determine if an + // incoming message is a subscription message. + let is_subscription_message = if *is_subscribed { + if let Some(Frame::Redis(RedisFrame::Array(array))) = message.frame() { + if let [RedisFrame::BulkString(ty), ..] = array.as_slice() { + ty.as_ref() == b"message" } else { false } } else { false - }; + } + } else { + false + }; - // Update is_subscribed state - // - // In order to make sense of a response we need the main task to - // send us the type of its corresponding request. - // - // In order to keep the incoming request MessageTypes in sync with their corresponding responses - // we must only process a MessageType when the message is not a subscription message. - // This is fine because subscription messages cannot affect the is_subscribed state. - if !is_subscription_message { - match sent_message_type.recv().await { - Some(MessageType::Subscribe) | Some(MessageType::Unsubscribe) => { - if let Some(Frame::Redis(RedisFrame::Array(array))) = message.frame() { - if let Some(RedisFrame::Integer(number_of_subscribed_channels)) = - array.get(2) - { - *is_subscribed = *number_of_subscribed_channels != 0; - } + // Update is_subscribed state + // + // In order to make sense of a response we need the main task to + // send us the type of its corresponding request. + // + // In order to keep the incoming request MessageTypes in sync with their corresponding responses + // we must only process a MessageType when the message is not a subscription message. + // This is fine because subscription messages cannot affect the is_subscribed state. + if !is_subscription_message { + match sent_message_type.recv().await { + Some(MessageType::Subscribe) | Some(MessageType::Unsubscribe) => { + if let Some(Frame::Redis(RedisFrame::Array(array))) = message.frame() { + if let Some(RedisFrame::Integer(number_of_subscribed_channels)) = + array.get(2) + { + *is_subscribed = *number_of_subscribed_channels != 0; } } - Some(MessageType::Other) => {} - Some(MessageType::Reset) => { - *is_subscribed = false; - } - None => { - tracing::debug!("RedisSinkSingle dropped after a message was received from server, RedisSinkSingle request processor task shutting down"); - return true; - } + } + Some(MessageType::Other) => {} + Some(MessageType::Reset) => { + *is_subscribed = false; + } + None => { + tracing::debug!("RedisSinkSingle dropped after a message was received from server, RedisSinkSingle request processor task shutting down"); + return true; } } + } - // Route the message down the correct path: - // * `message` subscription messages: - // needs to be routed down the pushed_messages chain - // * everything else: - // needs to be routed down the regular chain - if is_subscription_message { - // subscribe_tx may not exist if we are e.g. in an alternate chain of a tee transform - if let Some(subscribe_tx) = subscribe_tx { - if let Err(mpsc::error::SendError(_)) = subscribe_tx.send(vec![message]) { - tracing::debug!("shotover chain is terminated, will continue running until Transform is dropped"); - } + // Route the message down the correct path: + // * `message` subscription messages: + // needs to be routed down the pushed_messages chain + // * everything else: + // needs to be routed down the regular chain + if is_subscription_message { + // subscribe_tx may not exist if we are e.g. in an alternate chain of a tee transform + if let Some(subscribe_tx) = subscribe_tx { + if let Err(mpsc::error::SendError(_)) = subscribe_tx.send(vec![message]) { + tracing::debug!("shotover chain is terminated, will continue running until Transform is dropped"); } - } else if let Err(mpsc::error::SendError(_)) = response_messages_tx.send(message) { - tracing::debug!("RedisSinkSingle dropped after a message was received from server, RedisSinkSingle request processor task shutting down"); - return true; } + } else if let Err(mpsc::error::SendError(_)) = response_messages_tx.send(message) { + tracing::debug!("RedisSinkSingle dropped after a message was received from server, RedisSinkSingle request processor task shutting down"); + return true; } false } diff --git a/shotover/src/transforms/util/cluster_connection_pool.rs b/shotover/src/transforms/util/cluster_connection_pool.rs index 949eb33ff..fd051ba38 100644 --- a/shotover/src/transforms/util/cluster_connection_pool.rs +++ b/shotover/src/transforms/util/cluster_connection_pool.rs @@ -258,18 +258,16 @@ async fn rx_process( // TODO: This reader.next() may perform reads after tx_process has shutdown the write half. // This may result in unexpected ConnectionReset errors. // refer to the cassandra connection logic. - while let Some(responses) = reader.next().await { - match responses { - Ok(responses) => { - for response_message in responses { - loop { - if let Some(Some(ret)) = return_rx.recv().await { - // If the receiver hangs up, just silently ignore - let _ = ret.send(Response { - response: Ok(response_message), - }); - break; - } + while let Some(response) = reader.next().await { + match response { + Ok(response_message) => { + loop { + if let Some(Some(ret)) = return_rx.recv().await { + // If the receiver hangs up, just silently ignore + let _ = ret.send(Response { + response: Ok(response_message), + }); + break; } } }