Skip to content

Commit

Permalink
codec send Message instead of Vec<Message>
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai committed Feb 18, 2024
1 parent 8629f70 commit 7cf784b
Show file tree
Hide file tree
Showing 12 changed files with 172 additions and 172 deletions.
6 changes: 2 additions & 4 deletions shotover/benches/benches/codec/kafka.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ fn criterion_benchmark(c: &mut Criterion) {
)
},
|((mut decoder, encoder), mut input)| {
let mut message =
decoder.decode(&mut input).unwrap().unwrap().pop().unwrap();
let mut message = decoder.decode(&mut input).unwrap().unwrap();
message.frame();

// avoid measuring any drops
Expand All @@ -57,8 +56,7 @@ fn criterion_benchmark(c: &mut Criterion) {
let (mut decoder, encoder) =
KafkaCodecBuilder::new(Direction::Source, "kafka".to_owned()).build();
let mut input = input.clone();
let mut message =
decoder.decode(&mut input).unwrap().unwrap().pop().unwrap();
let mut message = decoder.decode(&mut input).unwrap().unwrap();
message.frame();
assert!(decoder.decode(&mut input).unwrap().is_none());
(decoder, encoder, message)
Expand Down
59 changes: 32 additions & 27 deletions shotover/src/codec/cassandra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Instant;
use std::vec::IntoIter;
use tokio_util::codec::{Decoder, Encoder};
use tracing::info;

Expand Down Expand Up @@ -179,6 +180,7 @@ pub struct CassandraDecoder {
version_counter: VersionCounter,
expected_payload_len: Option<usize>,
payload_buffer: BytesMut,
v5_envelopes: IntoIter<Message>,
}

impl CassandraDecoder {
Expand All @@ -198,6 +200,7 @@ impl CassandraDecoder {
version_counter,
payload_buffer: BytesMut::new(),
expected_payload_len: None,
v5_envelopes: vec![].into_iter(),
}
}
}
Expand Down Expand Up @@ -243,7 +246,7 @@ impl CassandraDecoder {
compression: Compression,
handshake_complete: bool,
received_at: Instant,
) -> Result<Vec<Message>> {
) -> Result<Message> {
match (version, handshake_complete) {
(Version::V5, true) => match compression {
Compression::None => {
Expand Down Expand Up @@ -278,10 +281,10 @@ impl CassandraDecoder {
frame_bytes.advance(UNCOMPRESSED_FRAME_HEADER_LENGTH);
let payload = frame_bytes.split_to(payload_length).freeze();

let envelopes =
self.extract_envelopes_from_payload(payload, self_contained, received_at)?;

Ok(envelopes)
self.v5_envelopes = self
.extract_envelopes_from_payload(payload, self_contained, received_at)?
.into_iter();
Ok(self.v5_envelopes.next().unwrap())
}
Compression::Lz4 => {
let mut frame_bytes = src.split_to(frame_len);
Expand Down Expand Up @@ -339,10 +342,10 @@ impl CassandraDecoder {
.into()
};

let envelopes =
self.extract_envelopes_from_payload(payload, self_contained, received_at)?;

Ok(envelopes)
self.v5_envelopes = self
.extract_envelopes_from_payload(payload, self_contained, received_at)?
.into_iter();
Ok(self.v5_envelopes.next().unwrap())
}
_ => Err(anyhow!("Only Lz4 compression is supported for v5")),
},
Expand All @@ -368,7 +371,7 @@ impl CassandraDecoder {
Some(received_at),
);

Ok(vec![message])
Ok(message)
}
}
}
Expand Down Expand Up @@ -564,18 +567,22 @@ fn set_startup_state(
}

impl Decoder for CassandraDecoder {
type Item = Messages;
type Item = Message;
type Error = CodecReadError;

fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, CodecReadError> {
if let Some(message) = self.v5_envelopes.next() {
return Ok(Some(message));
}

let version: Version = self.version.load(Ordering::Relaxed).into();
let compression: Compression = self.compression.load(Ordering::Relaxed).into();
let handshake_complete = self.handshake_complete.load(Ordering::Relaxed);
let received_at = Instant::now();

match self.check_size(src, version, compression, handshake_complete) {
Ok(frame_len) => {
let mut messages = self
let mut message = self
.decode_frame(
src,
frame_len,
Expand All @@ -586,22 +593,20 @@ impl Decoder for CassandraDecoder {
)
.map_err(CodecReadError::Parser)?;

for message in messages.iter_mut() {
if let Ok(Metadata::Cassandra(CassandraMetadata {
opcode: Opcode::Query | Opcode::Batch,
..
})) = message.metadata()
{
if let Some(keyspace) = get_use_keyspace(message) {
self.current_use_keyspace = Some(keyspace);
}
if let Ok(Metadata::Cassandra(CassandraMetadata {
opcode: Opcode::Query | Opcode::Batch,
..
})) = message.metadata()
{
if let Some(keyspace) = get_use_keyspace(&mut message) {
self.current_use_keyspace = Some(keyspace);
}

if let Some(keyspace) = &self.current_use_keyspace {
set_default_keyspace(message, keyspace);
}
if let Some(keyspace) = &self.current_use_keyspace {
set_default_keyspace(&mut message, keyspace);
}
}
Ok(Some(messages))
Ok(Some(message))
}
Err(CheckFrameSizeError::NotEnoughBytes) => Ok(None),
Err(CheckFrameSizeError::UnsupportedVersion(version)) => {
Expand Down Expand Up @@ -1019,10 +1024,10 @@ mod cassandra_protocol_tests {
) {
let (mut decoder, mut encoder) = codec.build();
// test decode
let decoded_messages = decoder
let decoded_messages = vec![decoder
.decode(&mut BytesMut::from(raw_frame))
.unwrap()
.unwrap();
.unwrap()];

// test messages parse correctly
let mut parsed_messages = decoded_messages.clone();
Expand Down
4 changes: 2 additions & 2 deletions shotover/src/codec/kafka.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ fn get_length_of_full_message(src: &BytesMut) -> Option<usize> {
}

impl Decoder for KafkaDecoder {
type Item = Messages;
type Item = Message;
type Error = CodecReadError;

fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
Expand Down Expand Up @@ -128,7 +128,7 @@ impl Decoder for KafkaDecoder {
Some(received_at),
)
};
Ok(Some(vec![message]))
Ok(Some(message))
} else {
Ok(None)
}
Expand Down
6 changes: 3 additions & 3 deletions shotover/src/codec/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Codec types to use for connecting to a DB in a sink transform
use crate::message::Messages;
use crate::message::{Message, Messages};
#[cfg(feature = "cassandra")]
use cassandra_protocol::compression::Compression;
use core::fmt;
Expand Down Expand Up @@ -114,8 +114,8 @@ impl From<std::io::Error> for CodecWriteError {
}

// TODO: Replace with trait_alias (rust-lang/rust#41517).
pub trait DecoderHalf: Decoder<Item = Messages, Error = CodecReadError> + Send {}
impl<T: Decoder<Item = Messages, Error = CodecReadError> + Send> DecoderHalf for T {}
pub trait DecoderHalf: Decoder<Item = Message, Error = CodecReadError> + Send {}
impl<T: Decoder<Item = Message, Error = CodecReadError> + Send> DecoderHalf for T {}

// TODO: Replace with trait_alias (rust-lang/rust#41517).
pub trait EncoderHalf: Encoder<Messages, Error = CodecWriteError> + Send {}
Expand Down
12 changes: 6 additions & 6 deletions shotover/src/codec/opensearch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ enum State {
}

impl Decoder for OpenSearchDecoder {
type Item = Messages;
type Item = Message;
type Error = CodecReadError;

fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, CodecReadError> {
Expand Down Expand Up @@ -236,13 +236,13 @@ impl Decoder for OpenSearchDecoder {
}
State::ReadingBody(http_headers, content_length) => {
if let Some(Method::HEAD) = *self.last_outgoing_method.lock().unwrap() {
return Ok(Some(vec![Message::from_frame_at_instant(
return Ok(Some(Message::from_frame_at_instant(
Frame::OpenSearch(OpenSearchFrame::new(
http_headers,
bytes::Bytes::new(),
)),
Some(received_at),
)]));
)));
}

if src.len() < content_length {
Expand All @@ -261,7 +261,7 @@ impl Decoder for OpenSearchDecoder {
})?;
message.set_request_id(id);
}
return Ok(Some(vec![message]));
return Ok(Some(message));
}
}
}
Expand Down Expand Up @@ -399,7 +399,7 @@ mod opensearch_tests {
.unwrap();

let mut dest = BytesMut::new();
encoder.encode(message, &mut dest).unwrap();
encoder.encode(vec![message], &mut dest).unwrap();
assert_eq!(raw_frame, &dest);
}

Expand Down Expand Up @@ -429,7 +429,7 @@ mod opensearch_tests {
.unwrap();

let mut dest = BytesMut::new();
encoder.encode(message, &mut dest).unwrap();
encoder.encode(vec![message], &mut dest).unwrap();
assert_eq!(raw_frame, &dest);
}

Expand Down
8 changes: 4 additions & 4 deletions shotover/src/codec/redis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ impl RedisDecoder {
}

impl Decoder for RedisDecoder {
type Item = Messages;
type Item = Message;
type Error = CodecReadError;

// TODO: this duplicates a bunch of logic from sink_single.rs
Expand Down Expand Up @@ -184,7 +184,7 @@ impl Decoder for RedisDecoder {
}
}
}
Ok(Some(vec![message]))
Ok(Some(message))
}
None => Ok(None),
}
Expand Down Expand Up @@ -291,10 +291,10 @@ mod redis_tests {
fn test_frame(raw_frame: &[u8]) {
let (mut decoder, mut encoder) =
RedisCodecBuilder::new(Direction::Source, "redis".to_owned()).build();
let message = decoder
let message = vec![decoder
.decode(&mut BytesMut::from(raw_frame))
.unwrap()
.unwrap();
.unwrap()];

let mut dest = BytesMut::new();
encoder.encode(message, &mut dest).unwrap();
Expand Down
43 changes: 28 additions & 15 deletions shotover/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ use tokio_util::codec::{Decoder, Encoder, FramedRead, FramedWrite};
use tracing::Instrument;
use tracing::{debug, error, warn};

// TODO: move into https://github.com/shotover/shotover-proxy/pull/1465
const DECODE_BUFFER_LEN: usize = 10_000;

pub struct TcpCodecListener<C: CodecBuilder> {
chain_builder: TransformChainBuilder,
source_name: String,
Expand Down Expand Up @@ -291,7 +294,7 @@ async fn spawn_websocket_read_write_tasks<
>(
codec: C,
stream: S,
in_tx: mpsc::Sender<Messages>,
in_tx: mpsc::Sender<Message>,
mut out_rx: UnboundedReceiver<Messages>,
out_tx: UnboundedSender<Messages>,
websocket_subprotocol: &str,
Expand Down Expand Up @@ -443,7 +446,7 @@ fn spawn_read_write_tasks<
codec: C,
rx: R,
tx: W,
in_tx: mpsc::Sender<Messages>,
in_tx: mpsc::Sender<Message>,
mut out_rx: UnboundedReceiver<Messages>,
out_tx: UnboundedSender<Messages>,
) {
Expand Down Expand Up @@ -585,7 +588,7 @@ impl<C: CodecBuilder + 'static> Handler<C> {
// A particular scenario we are concerned about is if it takes longer to send to the server
// than for the client to send to us, the buffer will grow indefinitely, increasing latency until the buffer triggers an OoM.
// To avoid that we have currently hardcoded a limit of 10,000 but if we start hitting that in production we should make this user configurable.
let (in_tx, in_rx) = mpsc::channel::<Messages>(10_000);
let (in_tx, in_rx) = mpsc::channel::<Message>(10_000);
let (out_tx, out_rx) = mpsc::unbounded_channel::<Messages>();

let local_addr = stream.local_addr()?;
Expand Down Expand Up @@ -674,41 +677,51 @@ impl<C: CodecBuilder + 'static> Handler<C> {

async fn receive_with_timeout(
timeout: Option<Duration>,
in_rx: &mut mpsc::Receiver<Vec<Message>>,
in_rx: &mut mpsc::Receiver<Message>,
client_details: &str,
) -> Option<Vec<Message>> {
last_received: usize,
) -> Option<Messages> {
let mut messages = Vec::with_capacity((last_received * 2).min(DECODE_BUFFER_LEN).max(16));
if let Some(timeout) = timeout {
match tokio::time::timeout(timeout, in_rx.recv()).await {
Ok(messages) => messages,
match tokio::time::timeout(timeout, in_rx.recv_many(&mut messages, DECODE_BUFFER_LEN))
.await
{
Ok(_) => {}
Err(_) => {
debug!("Dropping connection to {client_details} due to being idle for more than {timeout:?}");
None
return None;
}
}
} else {
in_rx.recv().await
in_rx.recv_many(&mut messages, DECODE_BUFFER_LEN).await;
}

if messages.is_empty() {
// No messages indicates that the channel has been closed
None
} else {
Some(messages)
}
}

async fn process_messages(
&mut self,
client_details: &str,
local_addr: SocketAddr,
mut in_rx: mpsc::Receiver<Messages>,
mut in_rx: mpsc::Receiver<Message>,
out_tx: mpsc::UnboundedSender<Messages>,
) -> Result<()> {
// As long as the shutdown signal has not been received, try to read a
// new request frame.
let mut last_received = 0;
while !self.shutdown.is_shutdown() {
// While reading a request frame, also listen for the shutdown signal
debug!("Waiting for message {client_details}");
let responses = tokio::select! {
requests = Self::receive_with_timeout(self.timeout, &mut in_rx, client_details) => {
requests = Self::receive_with_timeout(self.timeout, &mut in_rx, client_details,last_received) => {
match requests {
Some(mut requests) => {
while let Ok(x) = in_rx.try_recv() {
requests.extend(x);
}
Some(requests) => {
last_received = requests.len();
debug!("Received requests from client {:?}", requests);
self.process_forward(client_details, local_addr, &out_tx, requests).await?
}
Expand Down
Loading

0 comments on commit 7cf784b

Please sign in to comment.