Skip to content

Commit

Permalink
codec send Message instead of Vec<Message>
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai committed Feb 8, 2024
1 parent bd7fbe8 commit 24fa40c
Show file tree
Hide file tree
Showing 12 changed files with 145 additions and 163 deletions.
10 changes: 4 additions & 6 deletions shotover/benches/benches/codec/kafka.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,10 @@ fn criterion_benchmark(c: &mut Criterion) {
)
},
|((mut decoder, _encoder), mut input)| {
let mut result = decoder.decode(&mut input).unwrap().unwrap();
for message in &mut result {
while let Some(mut message) = decoder.decode(&mut input).unwrap() {
message.frame();
black_box(message);
}
black_box(result)
},
BatchSize::SmallInput,
)
Expand Down Expand Up @@ -98,11 +97,10 @@ fn criterion_benchmark(c: &mut Criterion) {
)
},
|((mut decoder, _encoder), mut input)| {
let mut result = decoder.decode(&mut input).unwrap().unwrap();
for message in &mut result {
while let Some(mut message) = decoder.decode(&mut input).unwrap() {
message.frame();
black_box(message);
}
black_box(result)
},
BatchSize::SmallInput,
)
Expand Down
44 changes: 21 additions & 23 deletions shotover/src/codec/cassandra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ impl CassandraDecoder {
compression: Compression,
handshake_complete: bool,
received_at: Instant,
) -> Result<Vec<Message>> {
) -> Result<Message> {
match (version, handshake_complete) {
(Version::V5, true) => match compression {
Compression::None => {
Expand Down Expand Up @@ -278,10 +278,10 @@ impl CassandraDecoder {
frame_bytes.advance(UNCOMPRESSED_FRAME_HEADER_LENGTH);
let payload = frame_bytes.split_to(payload_length).freeze();

let envelopes =
let mut envelopes =
self.extract_envelopes_from_payload(payload, self_contained, received_at)?;

Ok(envelopes)
Ok(envelopes.pop().unwrap()) //TODO
}
Compression::Lz4 => {
let mut frame_bytes = src.split_to(frame_len);
Expand Down Expand Up @@ -339,10 +339,10 @@ impl CassandraDecoder {
.into()
};

let envelopes =
let mut envelopes =
self.extract_envelopes_from_payload(payload, self_contained, received_at)?;

Ok(envelopes)
Ok(envelopes.pop().unwrap()) //TODO
}
_ => Err(anyhow!("Only Lz4 compression is supported for v5")),
},
Expand All @@ -368,7 +368,7 @@ impl CassandraDecoder {
Some(received_at),
);

Ok(vec![message])
Ok(message)
}
}
}
Expand Down Expand Up @@ -564,7 +564,7 @@ fn set_startup_state(
}

impl Decoder for CassandraDecoder {
type Item = Messages;
type Item = Message;
type Error = CodecReadError;

fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, CodecReadError> {
Expand All @@ -575,7 +575,7 @@ impl Decoder for CassandraDecoder {

match self.check_size(src, version, compression, handshake_complete) {
Ok(frame_len) => {
let mut messages = self
let mut message = self
.decode_frame(
src,
frame_len,
Expand All @@ -586,22 +586,20 @@ impl Decoder for CassandraDecoder {
)
.map_err(CodecReadError::Parser)?;

for message in messages.iter_mut() {
if let Ok(Metadata::Cassandra(CassandraMetadata {
opcode: Opcode::Query | Opcode::Batch,
..
})) = message.metadata()
{
if let Some(keyspace) = get_use_keyspace(message) {
self.current_use_keyspace = Some(keyspace);
}
if let Ok(Metadata::Cassandra(CassandraMetadata {
opcode: Opcode::Query | Opcode::Batch,
..
})) = message.metadata()
{
if let Some(keyspace) = get_use_keyspace(&mut message) {
self.current_use_keyspace = Some(keyspace);
}

if let Some(keyspace) = &self.current_use_keyspace {
set_default_keyspace(message, keyspace);
}
if let Some(keyspace) = &self.current_use_keyspace {
set_default_keyspace(&mut message, keyspace);
}
}
Ok(Some(messages))
Ok(Some(message))
}
Err(CheckFrameSizeError::NotEnoughBytes) => Ok(None),
Err(CheckFrameSizeError::UnsupportedVersion(version)) => {
Expand Down Expand Up @@ -1019,10 +1017,10 @@ mod cassandra_protocol_tests {
) {
let (mut decoder, mut encoder) = codec.build();
// test decode
let decoded_messages = decoder
let decoded_messages = vec![decoder
.decode(&mut BytesMut::from(raw_frame))
.unwrap()
.unwrap();
.unwrap()];

// test messages parse correctly
let mut parsed_messages = decoded_messages.clone();
Expand Down
6 changes: 3 additions & 3 deletions shotover/src/codec/kafka.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ fn get_length_of_full_message(src: &BytesMut) -> Option<usize> {
}

impl Decoder for KafkaDecoder {
type Item = Messages;
type Item = Message;
type Error = CodecReadError;

fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
Expand All @@ -108,11 +108,11 @@ impl Decoder for KafkaDecoder {
} else {
None
};
Ok(Some(vec![Message::from_bytes_at_instant(
Ok(Some(Message::from_bytes_at_instant(
bytes.freeze(),
ProtocolType::Kafka { request_header },
Some(received_at),
)]))
)))
} else {
Ok(None)
}
Expand Down
6 changes: 3 additions & 3 deletions shotover/src/codec/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Codec types to use for connecting to a DB in a sink transform
use crate::message::Messages;
use crate::message::{Message, Messages};
#[cfg(feature = "cassandra")]
use cassandra_protocol::compression::Compression;
use core::fmt;
Expand Down Expand Up @@ -114,8 +114,8 @@ impl From<std::io::Error> for CodecWriteError {
}

// TODO: Replace with trait_alias (rust-lang/rust#41517).
pub trait DecoderHalf: Decoder<Item = Messages, Error = CodecReadError> + Send {}
impl<T: Decoder<Item = Messages, Error = CodecReadError> + Send> DecoderHalf for T {}
pub trait DecoderHalf: Decoder<Item = Message, Error = CodecReadError> + Send {}
impl<T: Decoder<Item = Message, Error = CodecReadError> + Send> DecoderHalf for T {}

// TODO: Replace with trait_alias (rust-lang/rust#41517).
pub trait EncoderHalf: Encoder<Messages, Error = CodecWriteError> + Send {}
Expand Down
12 changes: 6 additions & 6 deletions shotover/src/codec/opensearch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ enum State {
}

impl Decoder for OpenSearchDecoder {
type Item = Messages;
type Item = Message;
type Error = CodecReadError;

fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, CodecReadError> {
Expand Down Expand Up @@ -218,13 +218,13 @@ impl Decoder for OpenSearchDecoder {
}
State::ReadingBody(http_headers, content_length) => {
if let Some(Method::HEAD) = *self.last_outgoing_method.lock().unwrap() {
return Ok(Some(vec![Message::from_frame_at_instant(
return Ok(Some(Message::from_frame_at_instant(
Frame::OpenSearch(OpenSearchFrame::new(
http_headers,
bytes::Bytes::new(),
)),
Some(received_at),
)]));
)));
}

if src.len() < content_length {
Expand All @@ -233,10 +233,10 @@ impl Decoder for OpenSearchDecoder {
}

let body = src.split_to(content_length).freeze();
return Ok(Some(vec![Message::from_frame_at_instant(
return Ok(Some(Message::from_frame_at_instant(
Frame::OpenSearch(OpenSearchFrame::new(http_headers, body)),
Some(received_at),
)]));
)));
}
}
}
Expand Down Expand Up @@ -357,7 +357,7 @@ mod opensearch_tests {
.unwrap();

let mut dest = BytesMut::new();
encoder.encode(message, &mut dest).unwrap();
encoder.encode(vec![message], &mut dest).unwrap();
assert_eq!(raw_frame, &dest);
}

Expand Down
10 changes: 5 additions & 5 deletions shotover/src/codec/redis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl RedisDecoder {
}

impl Decoder for RedisDecoder {
type Item = Messages;
type Item = Message;
type Error = CodecReadError;

fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
Expand All @@ -71,11 +71,11 @@ impl Decoder for RedisDecoder {
self.direction,
pretty_hex::pretty_hex(&bytes)
);
Ok(Some(vec![Message::from_bytes_and_frame_at_instant(
Ok(Some(Message::from_bytes_and_frame_at_instant(
bytes,
Frame::Redis(frame),
Some(received_at),
)]))
)))
}
None => Ok(None),
}
Expand Down Expand Up @@ -157,10 +157,10 @@ mod redis_tests {
fn test_frame(raw_frame: &[u8]) {
let (mut decoder, mut encoder) =
RedisCodecBuilder::new(Direction::Sink, "redis".to_owned()).build();
let message = decoder
let message = vec![decoder
.decode(&mut BytesMut::from(raw_frame))
.unwrap()
.unwrap();
.unwrap()];

let mut dest = BytesMut::new();
encoder.encode(message, &mut dest).unwrap();
Expand Down
18 changes: 10 additions & 8 deletions shotover/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ async fn spawn_websocket_read_write_tasks<
>(
codec: C,
stream: S,
in_tx: UnboundedSender<Messages>,
in_tx: UnboundedSender<Message>,
mut out_rx: UnboundedReceiver<Messages>,
out_tx: UnboundedSender<Messages>,
websocket_subprotocol: &str,
Expand Down Expand Up @@ -442,7 +442,7 @@ fn spawn_read_write_tasks<
codec: C,
rx: R,
tx: W,
in_tx: UnboundedSender<Messages>,
in_tx: UnboundedSender<Message>,
mut out_rx: UnboundedReceiver<Messages>,
out_tx: UnboundedSender<Messages>,
) {
Expand Down Expand Up @@ -580,7 +580,7 @@ impl<C: CodecBuilder + 'static> Handler<C> {
.unwrap_or_else(|_| "Unknown peer".to_string());
tracing::debug!("New connection from {}", client_details);

let (in_tx, in_rx) = mpsc::unbounded_channel::<Messages>();
let (in_tx, in_rx) = mpsc::unbounded_channel::<Message>();
let (out_tx, out_rx) = mpsc::unbounded_channel::<Messages>();

let local_addr = stream.local_addr()?;
Expand Down Expand Up @@ -669,9 +669,9 @@ impl<C: CodecBuilder + 'static> Handler<C> {

async fn receive_with_timeout(
timeout: Option<Duration>,
in_rx: &mut UnboundedReceiver<Vec<Message>>,
in_rx: &mut UnboundedReceiver<Message>,
client_details: &str,
) -> Option<Vec<Message>> {
) -> Option<Message> {
if let Some(timeout) = timeout {
match tokio::time::timeout(timeout, in_rx.recv()).await {
Ok(messages) => messages,
Expand All @@ -689,7 +689,7 @@ impl<C: CodecBuilder + 'static> Handler<C> {
&mut self,
client_details: &str,
local_addr: SocketAddr,
mut in_rx: mpsc::UnboundedReceiver<Messages>,
mut in_rx: mpsc::UnboundedReceiver<Message>,
out_tx: mpsc::UnboundedSender<Messages>,
) -> Result<()> {
// As long as the shutdown signal has not been received, try to read a
Expand All @@ -700,9 +700,11 @@ impl<C: CodecBuilder + 'static> Handler<C> {
let responses = tokio::select! {
requests = Self::receive_with_timeout(self.timeout, &mut in_rx, client_details) => {
match requests {
Some(mut requests) => {
Some(request) => {
// TODO: use tokio method
let mut requests = vec!(request);
while let Ok(x) = in_rx.try_recv() {
requests.extend(x);
requests.push(x);
}
debug!("Received requests from client {:?}", requests);
self.process_forward(client_details, local_addr, &out_tx, requests).await?
Expand Down
26 changes: 12 additions & 14 deletions shotover/src/transforms/cassandra/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,20 +280,18 @@ async fn rx_process<T: AsyncRead>(
response = reader.next() => {
match response {
Some(Ok(response)) => {
for m in response {
let meta = m.metadata();
if let Ok(Metadata::Cassandra(CassandraMetadata { opcode: Opcode::Event, .. })) = meta {
if let Some(pushed_messages_tx) = pushed_messages_tx.as_ref() {
pushed_messages_tx.send(vec![m]).ok();
}
} else if let Some(stream_id) = m.stream_id() {
match from_tx_process.remove(&stream_id) {
None => {
from_server.insert(stream_id, m);
},
Some(return_tx) => {
return_tx.send(Ok(m)).ok();
}
let meta = response.metadata();
if let Ok(Metadata::Cassandra(CassandraMetadata { opcode: Opcode::Event, .. })) = meta {
if let Some(pushed_messages_tx) = pushed_messages_tx.as_ref() {
pushed_messages_tx.send(vec![response]).ok();
}
} else if let Some(stream_id) = response.stream_id() {
match from_tx_process.remove(&stream_id) {
None => {
from_server.insert(stream_id, response);
},
Some(return_tx) => {
return_tx.send(Ok(response)).ok();
}
}
}
Expand Down
7 changes: 1 addition & 6 deletions shotover/src/transforms/redis/cluster_ports_rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,12 +288,7 @@ mod test {
fn test_rewrite_port_slots() {
let slots_pcap: &[u8] = b"*3\r\n*4\r\n:10923\r\n:16383\r\n*3\r\n$12\r\n192.168.80.6\r\n:6379\r\n$40\r\n3a7c357ed75d2aa01fca1e14ef3735a2b2b8ffac\r\n*3\r\n$12\r\n192.168.80.3\r\n:6379\r\n$40\r\n77c01b0ddd8668fff05e3f6a8aaf5f3ccd454a79\r\n*4\r\n:5461\r\n:10922\r\n*3\r\n$12\r\n192.168.80.5\r\n:6379\r\n$40\r\n969c6215d064e68593d384541ceeb57e9520dbed\r\n*3\r\n$12\r\n192.168.80.2\r\n:6379\r\n$40\r\n3929f69990a75be7b2d49594c57fe620862e6fd6\r\n*4\r\n:0\r\n:5460\r\n*3\r\n$12\r\n192.168.80.7\r\n:6379\r\n$40\r\n15d52a65d1fc7a53e34bf9193415aa39136882b2\r\n*3\r\n$12\r\n192.168.80.4\r\n:6379\r\n$40\r\ncd023916a3528fae7e606a10d8289a665d6c47b0\r\n";
let mut codec = RedisDecoder::new(Direction::Sink);
let mut message = codec
.decode(&mut slots_pcap.into())
.unwrap()
.unwrap()
.pop()
.unwrap();
let mut message = codec.decode(&mut slots_pcap.into()).unwrap().unwrap();

rewrite_port_slot(message.frame().unwrap(), 6380).unwrap();

Expand Down
7 changes: 1 addition & 6 deletions shotover/src/transforms/redis/sink_cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1147,12 +1147,7 @@ mod test {

let mut codec = RedisDecoder::new(Direction::Sink);

let mut message = codec
.decode(&mut slots_pcap.into())
.unwrap()
.unwrap()
.pop()
.unwrap();
let mut message = codec.decode(&mut slots_pcap.into()).unwrap().unwrap();

let slots_frames = match message.frame().unwrap() {
Frame::Redis(RedisFrame::Array(frames)) => frames,
Expand Down
Loading

0 comments on commit 24fa40c

Please sign in to comment.