Skip to content

Commit

Permalink
codec sends Message instead of Vec<Message>
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai committed Feb 9, 2024
1 parent bd7fbe8 commit 598b2b6
Show file tree
Hide file tree
Showing 12 changed files with 206 additions and 203 deletions.
71 changes: 34 additions & 37 deletions shotover/benches/benches/codec/kafka.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ fn criterion_benchmark(c: &mut Criterion) {
{
let mut input = BytesMut::new();
input.extend_from_slice(message);
group.bench_function(format!("decode_{file_name}"), |b| {
group.bench_function(format!("decode_{file_name}_create"), |b| {
b.iter_batched(
|| {
(
Expand All @@ -39,12 +39,33 @@ fn criterion_benchmark(c: &mut Criterion) {
input.clone(),
)
},
|((mut decoder, _encoder), mut input)| {
let mut result = decoder.decode(&mut input).unwrap().unwrap();
for message in &mut result {
message.frame();
}
black_box(result)
|((mut decoder, encoder), mut input)| {
let mut message = decoder.decode(&mut input).unwrap().unwrap();
message.frame();

// avoid measuring any drops
(decoder, encoder, input, message)
},
BatchSize::SmallInput,
)
});
group.bench_function(format!("decode_{file_name}_drop"), |b| {
b.iter_batched(
|| {
// Recreate everything from scratch to ensure that we don't have any `Bytes` references held onto, preventing a full drop
let (mut decoder, encoder) =
KafkaCodecBuilder::new(Direction::Source, "kafka".to_owned()).build();
let mut input = input.clone();
let mut message = decoder.decode(&mut input).unwrap().unwrap();
message.frame();
assert!(decoder.decode(&mut input).unwrap().is_none());
(decoder, encoder, message)
},
|(decoder, encoder, message)| {
std::mem::drop(message);

// avoid measuring any drops other than the message
(decoder, encoder)
},
BatchSize::SmallInput,
)
Expand Down Expand Up @@ -72,43 +93,18 @@ fn criterion_benchmark(c: &mut Criterion) {
messages.clone(),
)
},
|((_decoder, mut encoder), messages)| {
|((decoder, mut encoder), messages)| {
let mut bytes = BytesMut::new();
encoder.encode(messages, &mut bytes).unwrap();
black_box(bytes)
std::mem::drop(black_box(bytes));
(encoder, decoder)
},
BatchSize::SmallInput,
)
});
}
}

{
let mut input = BytesMut::new();
for (message, _) in KAFKA_REQUESTS {
input.extend_from_slice(message);
}
group.bench_function("decode_all", |b| {
b.iter_batched(
|| {
(
// recreate codec since it is stateful
KafkaCodecBuilder::new(Direction::Source, "kafka".to_owned()).build(),
input.clone(),
)
},
|((mut decoder, _encoder), mut input)| {
let mut result = decoder.decode(&mut input).unwrap().unwrap();
for message in &mut result {
message.frame();
}
black_box(result)
},
BatchSize::SmallInput,
)
});
}

{
let mut messages = vec![];
for (message, _) in KAFKA_REQUESTS {
Expand All @@ -134,10 +130,11 @@ fn criterion_benchmark(c: &mut Criterion) {
messages.clone(),
)
},
|((_decoder, mut encoder), messages)| {
|((decoder, mut encoder), messages)| {
let mut bytes = BytesMut::new();
encoder.encode(messages, &mut bytes).unwrap();
black_box(bytes)
std::mem::drop(black_box(bytes));
(encoder, decoder)
},
BatchSize::SmallInput,
)
Expand Down
60 changes: 35 additions & 25 deletions shotover/src/codec/cassandra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@ use cql3_parser::common::Identifier;
use lz4_flex::{block::get_maximum_output_size, compress_into, decompress};
use metrics::{register_counter, Counter, Histogram};
use std::collections::HashMap;
use std::iter::Rev;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Instant;
use std::vec::IntoIter;
use tokio_util::codec::{Decoder, Encoder};
use tracing::info;

Expand Down Expand Up @@ -179,6 +181,7 @@ pub struct CassandraDecoder {
version_counter: VersionCounter,
expected_payload_len: Option<usize>,
payload_buffer: BytesMut,
v5_envelopes: Rev<IntoIter<Message>>,
}

impl CassandraDecoder {
Expand All @@ -198,6 +201,7 @@ impl CassandraDecoder {
version_counter,
payload_buffer: BytesMut::new(),
expected_payload_len: None,
v5_envelopes: vec![].into_iter().rev(),
}
}
}
Expand Down Expand Up @@ -243,7 +247,7 @@ impl CassandraDecoder {
compression: Compression,
handshake_complete: bool,
received_at: Instant,
) -> Result<Vec<Message>> {
) -> Result<Message> {
match (version, handshake_complete) {
(Version::V5, true) => match compression {
Compression::None => {
Expand Down Expand Up @@ -278,10 +282,12 @@ impl CassandraDecoder {
frame_bytes.advance(UNCOMPRESSED_FRAME_HEADER_LENGTH);
let payload = frame_bytes.split_to(payload_length).freeze();

let envelopes =
self.extract_envelopes_from_payload(payload, self_contained, received_at)?;
self.v5_envelopes = self
.extract_envelopes_from_payload(payload, self_contained, received_at)?
.into_iter()
.rev();

Ok(envelopes)
Ok(self.v5_envelopes.next().unwrap())
}
Compression::Lz4 => {
let mut frame_bytes = src.split_to(frame_len);
Expand Down Expand Up @@ -339,10 +345,12 @@ impl CassandraDecoder {
.into()
};

let envelopes =
self.extract_envelopes_from_payload(payload, self_contained, received_at)?;
self.v5_envelopes = self
.extract_envelopes_from_payload(payload, self_contained, received_at)?
.into_iter()
.rev();

Ok(envelopes)
Ok(self.v5_envelopes.next().unwrap())
}
_ => Err(anyhow!("Only Lz4 compression is supported for v5")),
},
Expand All @@ -368,7 +376,7 @@ impl CassandraDecoder {
Some(received_at),
);

Ok(vec![message])
Ok(message)
}
}
}
Expand Down Expand Up @@ -564,18 +572,22 @@ fn set_startup_state(
}

impl Decoder for CassandraDecoder {
type Item = Messages;
type Item = Message;
type Error = CodecReadError;

fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, CodecReadError> {
if let Some(message) = self.v5_envelopes.next() {
return Ok(Some(message));
}

let version: Version = self.version.load(Ordering::Relaxed).into();
let compression: Compression = self.compression.load(Ordering::Relaxed).into();
let handshake_complete = self.handshake_complete.load(Ordering::Relaxed);
let received_at = Instant::now();

match self.check_size(src, version, compression, handshake_complete) {
Ok(frame_len) => {
let mut messages = self
let mut message = self
.decode_frame(
src,
frame_len,
Expand All @@ -586,22 +598,20 @@ impl Decoder for CassandraDecoder {
)
.map_err(CodecReadError::Parser)?;

for message in messages.iter_mut() {
if let Ok(Metadata::Cassandra(CassandraMetadata {
opcode: Opcode::Query | Opcode::Batch,
..
})) = message.metadata()
{
if let Some(keyspace) = get_use_keyspace(message) {
self.current_use_keyspace = Some(keyspace);
}
if let Ok(Metadata::Cassandra(CassandraMetadata {
opcode: Opcode::Query | Opcode::Batch,
..
})) = message.metadata()
{
if let Some(keyspace) = get_use_keyspace(&mut message) {
self.current_use_keyspace = Some(keyspace);
}

if let Some(keyspace) = &self.current_use_keyspace {
set_default_keyspace(message, keyspace);
}
if let Some(keyspace) = &self.current_use_keyspace {
set_default_keyspace(&mut message, keyspace);
}
}
Ok(Some(messages))
Ok(Some(message))
}
Err(CheckFrameSizeError::NotEnoughBytes) => Ok(None),
Err(CheckFrameSizeError::UnsupportedVersion(version)) => {
Expand Down Expand Up @@ -1019,10 +1029,10 @@ mod cassandra_protocol_tests {
) {
let (mut decoder, mut encoder) = codec.build();
// test decode
let decoded_messages = decoder
let decoded_messages = vec![decoder
.decode(&mut BytesMut::from(raw_frame))
.unwrap()
.unwrap();
.unwrap()];

// test messages parse correctly
let mut parsed_messages = decoded_messages.clone();
Expand Down
6 changes: 3 additions & 3 deletions shotover/src/codec/kafka.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ fn get_length_of_full_message(src: &BytesMut) -> Option<usize> {
}

impl Decoder for KafkaDecoder {
type Item = Messages;
type Item = Message;
type Error = CodecReadError;

fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
Expand All @@ -108,11 +108,11 @@ impl Decoder for KafkaDecoder {
} else {
None
};
Ok(Some(vec![Message::from_bytes_at_instant(
Ok(Some(Message::from_bytes_at_instant(
bytes.freeze(),
ProtocolType::Kafka { request_header },
Some(received_at),
)]))
)))
} else {
Ok(None)
}
Expand Down
6 changes: 3 additions & 3 deletions shotover/src/codec/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Codec types to use for connecting to a DB in a sink transform
use crate::message::Messages;
use crate::message::{Message, Messages};
#[cfg(feature = "cassandra")]
use cassandra_protocol::compression::Compression;
use core::fmt;
Expand Down Expand Up @@ -114,8 +114,8 @@ impl From<std::io::Error> for CodecWriteError {
}

// TODO: Replace with trait_alias (rust-lang/rust#41517).
pub trait DecoderHalf: Decoder<Item = Messages, Error = CodecReadError> + Send {}
impl<T: Decoder<Item = Messages, Error = CodecReadError> + Send> DecoderHalf for T {}
pub trait DecoderHalf: Decoder<Item = Message, Error = CodecReadError> + Send {}
impl<T: Decoder<Item = Message, Error = CodecReadError> + Send> DecoderHalf for T {}

// TODO: Replace with trait_alias (rust-lang/rust#41517).
pub trait EncoderHalf: Encoder<Messages, Error = CodecWriteError> + Send {}
Expand Down
12 changes: 6 additions & 6 deletions shotover/src/codec/opensearch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ enum State {
}

impl Decoder for OpenSearchDecoder {
type Item = Messages;
type Item = Message;
type Error = CodecReadError;

fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, CodecReadError> {
Expand Down Expand Up @@ -218,13 +218,13 @@ impl Decoder for OpenSearchDecoder {
}
State::ReadingBody(http_headers, content_length) => {
if let Some(Method::HEAD) = *self.last_outgoing_method.lock().unwrap() {
return Ok(Some(vec![Message::from_frame_at_instant(
return Ok(Some(Message::from_frame_at_instant(
Frame::OpenSearch(OpenSearchFrame::new(
http_headers,
bytes::Bytes::new(),
)),
Some(received_at),
)]));
)));
}

if src.len() < content_length {
Expand All @@ -233,10 +233,10 @@ impl Decoder for OpenSearchDecoder {
}

let body = src.split_to(content_length).freeze();
return Ok(Some(vec![Message::from_frame_at_instant(
return Ok(Some(Message::from_frame_at_instant(
Frame::OpenSearch(OpenSearchFrame::new(http_headers, body)),
Some(received_at),
)]));
)));
}
}
}
Expand Down Expand Up @@ -357,7 +357,7 @@ mod opensearch_tests {
.unwrap();

let mut dest = BytesMut::new();
encoder.encode(message, &mut dest).unwrap();
encoder.encode(vec![message], &mut dest).unwrap();
assert_eq!(raw_frame, &dest);
}

Expand Down
10 changes: 5 additions & 5 deletions shotover/src/codec/redis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl RedisDecoder {
}

impl Decoder for RedisDecoder {
type Item = Messages;
type Item = Message;
type Error = CodecReadError;

fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
Expand All @@ -71,11 +71,11 @@ impl Decoder for RedisDecoder {
self.direction,
pretty_hex::pretty_hex(&bytes)
);
Ok(Some(vec![Message::from_bytes_and_frame_at_instant(
Ok(Some(Message::from_bytes_and_frame_at_instant(
bytes,
Frame::Redis(frame),
Some(received_at),
)]))
)))
}
None => Ok(None),
}
Expand Down Expand Up @@ -157,10 +157,10 @@ mod redis_tests {
fn test_frame(raw_frame: &[u8]) {
let (mut decoder, mut encoder) =
RedisCodecBuilder::new(Direction::Sink, "redis".to_owned()).build();
let message = decoder
let message = vec![decoder
.decode(&mut BytesMut::from(raw_frame))
.unwrap()
.unwrap();
.unwrap()];

let mut dest = BytesMut::new();
encoder.encode(message, &mut dest).unwrap();
Expand Down
Loading

0 comments on commit 598b2b6

Please sign in to comment.