Skip to content

Commit

Permalink
Add dummy requests to better support certain transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai committed Feb 19, 2024
1 parent fea7468 commit feece72
Show file tree
Hide file tree
Showing 12 changed files with 208 additions and 183 deletions.
4 changes: 2 additions & 2 deletions shotover-proxy/tests/redis_int_tests/basic_driver_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1161,7 +1161,7 @@ pub async fn test_cluster_replication(
replication_connection: &mut ClusterConnection,
) {
// According to the coalesce config the writes are only flushed to the replication cluster after 2000 total writes pass through shotover
for i in 0..1000 {
for i in 0..500 {
// 2000 writes havent occured yet so this must be true
assert!(
replication_connection.get::<&str, i32>("foo").is_err(),
Expand Down Expand Up @@ -1189,7 +1189,7 @@ pub async fn test_cluster_replication(
// although we do need to account for the race condition of shotover returning a response before flushing to the replication cluster
let mut value1 = Ok(1); // These dummy values are fine because they get overwritten on the first loop
let mut value2 = Ok(b"".to_vec());
for _ in 0..100 {
for _ in 0..200 {
sleep(Duration::from_millis(100));
value1 = replication_connection.get("foo");
value2 = replication_connection.get("bar");
Expand Down
8 changes: 1 addition & 7 deletions shotover-proxy/tests/redis_int_tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,11 +334,5 @@ async fn cluster_dr() {
test_dr_auth().await;
run_all_cluster_hiding(&mut connection, &mut flusher).await;

shotover
.shutdown_and_then_consume_events(&[EventMatcher::new()
.with_level(Level::Error)
.with_target("shotover::transforms::filter")
.with_message("The current filter transform implementation does not obey the current transform invariants. see https://github.com/shotover/shotover-proxy/issues/499")
])
.await;
shotover.shutdown_and_then_consume_events(&[]).await;
}
11 changes: 6 additions & 5 deletions shotover/benches/benches/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use hex_literal::hex;
use shotover::frame::cassandra::{parse_statement_single, Tracing};
use shotover::frame::RedisFrame;
use shotover::frame::{CassandraFrame, CassandraOperation, Frame};
use shotover::message::{Message, ProtocolType, QueryType};
use shotover::message::{Message, MessageIdMap, ProtocolType, QueryType};
use shotover::transforms::cassandra::peers_rewrite::CassandraPeersRewrite;
use shotover::transforms::chain::{TransformChain, TransformChainBuilder};
use shotover::transforms::debug::returner::{DebugReturner, Response};
Expand Down Expand Up @@ -70,6 +70,7 @@ fn criterion_benchmark(c: &mut Criterion) {
vec![
Box::new(QueryTypeFilter {
filter: Filter::DenyList(vec![QueryType::Read]),
filtered_requests: MessageIdMap::default(),
}),
Box::new(DebugReturner::new(Response::Redis("a".into()))),
],
Expand Down Expand Up @@ -106,12 +107,12 @@ fn criterion_benchmark(c: &mut Criterion) {
let chain = TransformChainBuilder::new(
vec![
Box::new(RedisTimestampTagger::new()),
Box::new(DebugReturner::new(Response::Message(vec![
Message::from_frame(Frame::Redis(RedisFrame::Array(vec![
Box::new(DebugReturner::new(Response::Message(Message::from_frame(
Frame::Redis(RedisFrame::Array(vec![
RedisFrame::BulkString(Bytes::from_static(b"1")), // real frame
RedisFrame::BulkString(Bytes::from_static(b"1")), // timestamp
]))),
]))),
])),
)))),
],
"bench",
);
Expand Down
5 changes: 5 additions & 0 deletions shotover/src/codec/cassandra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,11 @@ impl CassandraEncoder {
compression: Compression,
handshake_complete: bool,
) -> Result<()> {
if m.is_dummy() {
// skip dummy messages
return Ok(());
}

match (version, handshake_complete) {
(Version::V5, true) => {
match compression {
Expand Down
26 changes: 21 additions & 5 deletions shotover/src/message/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -421,12 +421,30 @@ impl Message {
},
}
}
/// Set this `Message` to a dummy frame so that the message will never reach the client or DB.
/// For requests, the dummy frame will be dropped when it reaches the Sink.
/// Additionally a corresponding dummy response will be generated with its request_id set to the requests id.
/// For responses, the dummy frame will be dropped when it reaches the Source.
pub fn replace_with_dummy(&mut self) {
self.inner = Some(MessageInner::Modified {
frame: Frame::Dummy,
});
}

pub fn is_dummy(&self) -> bool {
matches!(
self.inner,
Some(MessageInner::Modified {
frame: Frame::Dummy
})
)
}

/// Set this `Message` to a backpressure response
pub fn set_backpressure(&mut self) -> Result<()> {
pub fn to_backpressure(&mut self) -> Result<Message> {
let metadata = self.metadata()?;

*self = Message::from_frame_at_instant(
Ok(Message::from_frame_at_instant(
match metadata {
#[cfg(feature = "cassandra")]
Metadata::Cassandra(metadata) => Frame::Cassandra(metadata.backpressure_response()),
Expand All @@ -440,9 +458,7 @@ impl Message {
// reachable with feature = cassandra
#[allow(unreachable_code)]
self.received_from_source_or_sink_at,
);

Ok(())
))
}

// Retrieves the stream_id without parsing the rest of the frame.
Expand Down
78 changes: 44 additions & 34 deletions shotover/src/transforms/cassandra/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use tracing::Instrument;
struct Request {
message: Message,
return_chan: oneshot::Sender<Response>,
stream_id: i16,
stream_id: Option<i16>,
}

pub type Response = Result<Message, ResponseError>;
Expand All @@ -36,24 +36,26 @@ pub struct ResponseError {
#[source]
pub cause: anyhow::Error,
pub destination: SocketAddr,
pub stream_id: i16,
pub stream_id: Option<i16>,
}

impl ResponseError {
pub fn to_response(&self, version: Version) -> Message {
Message::from_frame(Frame::Cassandra(CassandraFrame::shotover_error(
self.stream_id,
version,
&format!("{}", self),
)))
match self.stream_id {
Some(stream_id) => Message::from_frame(Frame::Cassandra(
CassandraFrame::shotover_error(stream_id, version, &format!("{}", self)),
)),
None => Message::from_frame(Frame::Dummy),
}
}
}

#[derive(Debug)]
struct ReturnChannel {
return_chan: oneshot::Sender<Response>,
request_id: MessageId,
stream_id: i16,
stream_id: Option<i16>,
is_dummy: bool,
}

#[derive(Clone, Derivative)]
Expand Down Expand Up @@ -148,19 +150,16 @@ impl CassandraConnection {
/// But this indicates a bug within CassandraConnection and should be fixed here.
pub fn send(&self, message: Message) -> Result<oneshot::Receiver<Response>> {
let (return_chan_tx, return_chan_rx) = oneshot::channel();
// Convert the message to `Request` and send upstream
if let Some(stream_id) = message.stream_id() {
self.connection
.send(Request {
message,
return_chan: return_chan_tx,
stream_id,
})
.map(|_| return_chan_rx)
.map_err(|x| x.into())
} else {
Err(anyhow!("no cassandra frame found"))
}
let stream_id = message.stream_id();
self.connection
.send(Request {
message,
return_chan: return_chan_tx,
// TODO: delete the stream_id field, we wont need it when we are handling cassandra out of order
stream_id,
})
.map(|_| return_chan_rx)
.map_err(|x| x.into())
}
}

Expand All @@ -187,6 +186,7 @@ async fn tx_process<T: AsyncWrite>(
loop {
if let Some(request) = out_rx.recv().await {
let request_id = request.message.id();
let is_dummy = request.message.is_dummy();
if let Some(error) = &connection_dead_error {
send_error_to_request(request.return_chan, request.stream_id, destination, error);
} else if let Err(error) = in_w.send(vec![request.message]).await {
Expand All @@ -197,6 +197,7 @@ async fn tx_process<T: AsyncWrite>(
return_chan: request.return_chan,
stream_id: request.stream_id,
request_id,
is_dummy,
}) {
let error = rx_process_has_shutdown_rx
.try_recv()
Expand Down Expand Up @@ -235,7 +236,7 @@ async fn tx_process<T: AsyncWrite>(

fn send_error_to_request(
return_chan: oneshot::Sender<Response>,
stream_id: i16,
stream_id: Option<i16>,
destination: SocketAddr,
error: &str,
) {
Expand Down Expand Up @@ -273,10 +274,11 @@ async fn rx_process<T: AsyncRead>(
// In order to handle that we have two seperate maps.
//
// We store the sender here if we receive from the tx_process task first
let mut from_tx_process: HashMap<i16, (oneshot::Sender<Response>, MessageId)> = HashMap::new();
let mut from_tx_process: HashMap<Option<i16>, (oneshot::Sender<Response>, MessageId)> =
HashMap::new();

// We store the response message here if we receive from the server first.
let mut from_server: HashMap<i16, Message> = HashMap::new();
let mut from_server: HashMap<Option<i16>, Message> = HashMap::new();

loop {
tokio::select! {
Expand All @@ -289,7 +291,8 @@ async fn rx_process<T: AsyncRead>(
if let Some(pushed_messages_tx) = pushed_messages_tx.as_ref() {
pushed_messages_tx.send(vec![m]).ok();
}
} else if let Some(stream_id) = m.stream_id() {
} else {
let stream_id = m.stream_id();
match from_tx_process.remove(&stream_id) {
None => {
from_server.insert(stream_id, m);
Expand Down Expand Up @@ -322,14 +325,21 @@ async fn rx_process<T: AsyncRead>(
}
},
original_request = return_rx.recv() => {
if let Some(ReturnChannel { return_chan, stream_id,request_id }) = original_request {
match from_server.remove(&stream_id) {
None => {
from_tx_process.insert(stream_id, (return_chan, request_id));
}
Some(mut m) => {
m.set_request_id(request_id);
return_chan.send(Ok(m)).ok();
if let Some(ReturnChannel { return_chan, stream_id, request_id, is_dummy }) = original_request {
if is_dummy {
// There will be no response from the DB for this message so we need to generate a dummy response instead.
let mut response = Message::from_frame(Frame::Dummy);
response.set_request_id(request_id);
return_chan.send(Ok(response)).ok();
} else {
match from_server.remove(&stream_id) {
None => {
from_tx_process.insert(stream_id, (return_chan, request_id));
}
Some(mut m) => {
m.set_request_id(request_id);
return_chan.send(Ok(m)).ok();
}
}
}
} else {
Expand All @@ -346,7 +356,7 @@ async fn rx_process<T: AsyncRead>(

async fn send_errors_and_shutdown(
mut return_rx: mpsc::UnboundedReceiver<ReturnChannel>,
mut waiting: HashMap<i16, (oneshot::Sender<Response>, MessageId)>,
mut waiting: HashMap<Option<i16>, (oneshot::Sender<Response>, MessageId)>,
rx_process_has_shutdown_tx: oneshot::Sender<String>,
destination: SocketAddr,
message: &str,
Expand Down
46 changes: 25 additions & 21 deletions shotover/src/transforms/debug/returner.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::message::Messages;
use crate::message::{Message, Messages};
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
Expand All @@ -24,7 +24,7 @@ impl TransformConfig for DebugReturnerConfig {
#[serde(deny_unknown_fields)]
pub enum Response {
#[serde(skip)]
Message(Messages),
Message(Message),
#[cfg(feature = "redis")]
Redis(String),
Fail,
Expand Down Expand Up @@ -61,24 +61,28 @@ impl Transform for DebugReturner {
NAME
}

async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result<Messages> {
match &self.response {
Response::Message(message) => Ok(message.clone()),
#[cfg(feature = "redis")]
Response::Redis(string) => {
use crate::frame::{Frame, RedisFrame};
use crate::message::Message;
Ok(requests_wrapper
.requests
.iter()
.map(|_| {
Message::from_frame(Frame::Redis(RedisFrame::BulkString(
string.to_string().into(),
)))
})
.collect())
}
Response::Fail => Err(anyhow!("Intentional Fail")),
}
async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result<Messages> {
requests_wrapper
.requests
.iter_mut()
.map(|request| match &self.response {
Response::Message(message) => {
let mut message = message.clone();
message.set_request_id(request.id());
Ok(message)
}
#[cfg(feature = "redis")]
Response::Redis(string) => {
use crate::frame::{Frame, RedisFrame};
use crate::message::Message;
let mut message = Message::from_frame(Frame::Redis(RedisFrame::BulkString(
string.to_string().into(),
)));
message.set_request_id(request.id());
Ok(message)
}
Response::Fail => Err(anyhow!("Intentional Fail")),
})
.collect()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -317,9 +317,7 @@ mod scatter_transform_tests {

#[tokio::test(flavor = "multi_thread")]
async fn test_scatter_success() {
let response = vec![Message::from_frame(Frame::Redis(RedisFrame::BulkString(
"OK".into(),
)))];
let response = Message::from_frame(Frame::Redis(RedisFrame::BulkString("OK".into())));

let wrapper = Wrapper::new_test(vec![Message::from_frame(Frame::Redis(
RedisFrame::BulkString(Bytes::from_static(b"foo")),
Expand Down
Loading

0 comments on commit feece72

Please sign in to comment.