Add dummy requests to better support certain transforms
rukai committed Feb 19, 2024
1 parent 8629f70 commit ba0fffb
Showing 8 changed files with 156 additions and 131 deletions.
11 changes: 7 additions & 4 deletions shotover/benches/benches/chain.rs
@@ -1,3 +1,5 @@
use std::collections::HashMap;

use bytes::Bytes;
use cassandra_protocol::compression::Compression;
use cassandra_protocol::{consistency::Consistency, frame::Version, query::QueryParams};
@@ -70,6 +72,7 @@ fn criterion_benchmark(c: &mut Criterion) {
vec![
Box::new(QueryTypeFilter {
filter: Filter::DenyList(vec![QueryType::Read]),
filtered_requests: HashMap::new(),
}),
Box::new(DebugReturner::new(Response::Redis("a".into()))),
],
@@ -106,12 +109,12 @@ fn criterion_benchmark(c: &mut Criterion) {
let chain = TransformChainBuilder::new(
vec![
Box::new(RedisTimestampTagger::new()),
Box::new(DebugReturner::new(Response::Message(vec![
Message::from_frame(Frame::Redis(RedisFrame::Array(vec![
Box::new(DebugReturner::new(Response::Message(Message::from_frame(
Frame::Redis(RedisFrame::Array(vec![
RedisFrame::BulkString(Bytes::from_static(b"1")), // real frame
RedisFrame::BulkString(Bytes::from_static(b"1")), // timestamp
]))),
]))),
])),
)))),
],
"bench",
);
18 changes: 18 additions & 0 deletions shotover/src/message/mod.rs
@@ -416,6 +416,24 @@ impl Message {
},
}
}
/// Set this `Message` to a dummy frame so that the message will never reach the client or DB.
/// For requests, the dummy frame will be dropped when it reaches the Sink.
/// Additionally, a corresponding dummy response will be generated with its request_id set to the request's id.
/// For responses, the dummy frame will be dropped when it reaches the Source.
pub fn replace_with_dummy(&mut self) {
self.inner = Some(MessageInner::Modified {
frame: Frame::Dummy,
});
}

pub fn is_dummy(&self) -> bool {
matches!(
self.inner,
Some(MessageInner::Modified {
frame: Frame::Dummy
})
)
}

/// Set this `Message` to a backpressure response
pub fn set_backpressure(&mut self) -> Result<()> {
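Aside (not part of this commit): the new replace_with_dummy / is_dummy API enables a request/response pairing pattern: a transform stashes a locally generated response keyed by the request's id, neutralises the request with a dummy frame so the sink drops it, and then swaps the stashed response back in when the matching dummy response comes back. The filter and throttling changes below follow this pattern. A rough sketch, assuming the Transform and Wrapper signatures match the diffs in this commit; RejectAll, its name string, and the error text are purely illustrative:

    use std::collections::HashMap;

    use crate::message::{Message, MessageId, Messages};
    use crate::transforms::{Transform, Wrapper};
    use anyhow::Result;
    use async_trait::async_trait;

    /// Hypothetical transform that rejects every request without contacting the DB.
    #[derive(Debug, Default)]
    pub struct RejectAll {
        /// Locally generated responses, keyed by the id of the request they answer.
        pending: HashMap<MessageId, Message>,
    }

    #[async_trait]
    impl Transform for RejectAll {
        fn get_name(&self) -> &'static str {
            "RejectAll"
        }

        async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result<Messages> {
            // Request path: stash an error response and replace the request with a
            // dummy frame so the sink drops it instead of forwarding it upstream.
            for request in requests_wrapper.requests.iter_mut() {
                let error = request.to_error_response("rejected".to_owned())?;
                self.pending.insert(request.id(), error);
                request.replace_with_dummy();
            }

            // Response path: a dummy response is generated for each dummy request,
            // so match on request_id and swap the stashed response back in.
            let mut responses = requests_wrapper.call_next_transform().await?;
            for response in responses.iter_mut() {
                if let Some(id) = response.request_id() {
                    if let Some(stashed) = self.pending.remove(&id) {
                        *response = stashed;
                    }
                }
            }
            Ok(responses)
        }
    }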
46 changes: 25 additions & 21 deletions shotover/src/transforms/debug/returner.rs
@@ -1,4 +1,4 @@
use crate::message::Messages;
use crate::message::{Message, Messages};
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
@@ -24,7 +24,7 @@ impl TransformConfig for DebugReturnerConfig {
#[serde(deny_unknown_fields)]
pub enum Response {
#[serde(skip)]
Message(Messages),
Message(Message),
#[cfg(feature = "redis")]
Redis(String),
Fail,
@@ -61,24 +61,28 @@ impl Transform for DebugReturner {
NAME
}

async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result<Messages> {
match &self.response {
Response::Message(message) => Ok(message.clone()),
#[cfg(feature = "redis")]
Response::Redis(string) => {
use crate::frame::{Frame, RedisFrame};
use crate::message::Message;
Ok(requests_wrapper
.requests
.iter()
.map(|_| {
Message::from_frame(Frame::Redis(RedisFrame::BulkString(
string.to_string().into(),
)))
})
.collect())
}
Response::Fail => Err(anyhow!("Intentional Fail")),
}
async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result<Messages> {
requests_wrapper
.requests
.iter_mut()
.map(|request| match &self.response {
Response::Message(message) => {
let mut message = message.clone();
message.set_request_id(request.id());
Ok(message)
}
#[cfg(feature = "redis")]
Response::Redis(string) => {
use crate::frame::{Frame, RedisFrame};
use crate::message::Message;
let mut message = Message::from_frame(Frame::Redis(RedisFrame::BulkString(
string.to_string().into(),
)));
message.set_request_id(request.id());
Ok(message)
}
Response::Fail => Err(anyhow!("Intentional Fail")),
})
.collect()
}
}
@@ -317,9 +317,7 @@ mod scatter_transform_tests {

#[tokio::test(flavor = "multi_thread")]
async fn test_scatter_success() {
let response = vec![Message::from_frame(Frame::Redis(RedisFrame::BulkString(
"OK".into(),
)))];
let response = Message::from_frame(Frame::Redis(RedisFrame::BulkString("OK".into())));

let wrapper = Wrapper::new_test(vec![Message::from_frame(Frame::Redis(
RedisFrame::BulkString(Bytes::from_static(b"foo")),
86 changes: 31 additions & 55 deletions shotover/src/transforms/filter.rs
@@ -1,11 +1,9 @@
use crate::message::{Message, Messages, QueryType};
use crate::message::{Message, MessageId, Messages, QueryType};
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicBool, Ordering};

static SHOWN_ERROR: AtomicBool = AtomicBool::new(false);
use std::collections::HashMap;

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(deny_unknown_fields)]
@@ -17,6 +15,7 @@ pub enum Filter {
#[derive(Debug, Clone)]
pub struct QueryTypeFilter {
pub filter: Filter,
pub filtered_requests: HashMap<MessageId, Message>,
}

#[derive(Serialize, Deserialize, Debug)]
@@ -33,6 +32,7 @@ impl TransformConfig for QueryTypeFilterConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
Ok(Box::new(QueryTypeFilter {
filter: self.filter.clone(),
filtered_requests: HashMap::new(),
}))
}
}
@@ -54,60 +54,33 @@ impl Transform for QueryTypeFilter {
}

async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result<Messages> {
let removed_indexes: Result<Vec<(usize, Message)>> = requests_wrapper
.requests
.iter_mut()
.enumerate()
.filter_map(|(i, m)| match self.filter {
Filter::AllowList(ref allow_list) => {
if allow_list.contains(&m.get_query_type()) {
None
} else {
Some((i, m))
}
}
Filter::DenyList(ref deny_list) => {
if deny_list.contains(&m.get_query_type()) {
Some((i, m))
} else {
None
}
}
})
.map(|(i, m)| {
Ok((
i,
m.to_error_response("Message was filtered out by shotover".to_owned())
.map_err(|e| e.context("Failed to filter message {e:?}"))?,
))
})
.collect();

let removed_indexes = removed_indexes?;

for (i, _) in removed_indexes.iter().rev() {
requests_wrapper.requests.remove(*i);
for request in requests_wrapper.requests.iter_mut() {
let filter_out = match &self.filter {
Filter::AllowList(allow_list) => !allow_list.contains(&request.get_query_type()),
Filter::DenyList(deny_list) => deny_list.contains(&request.get_query_type()),
};

if filter_out {
self.filtered_requests.insert(
request.id(),
request
.to_error_response("Message was filtered out by shotover".to_owned())
.map_err(|e| e.context("Failed to filter message"))?,
);
request.replace_with_dummy();
}
}

let mut shown_error = SHOWN_ERROR.load(Ordering::Relaxed);

requests_wrapper
.call_next_transform()
.await
.map(|mut messages| {

for (i, message) in removed_indexes.into_iter() {
if i <= messages.len() {
messages.insert(i, message);
}
else if !shown_error{
tracing::error!("The current filter transform implementation does not obey the current transform invariants. see https://github.com/shotover/shotover-proxy/issues/499");
shown_error = true;
SHOWN_ERROR.store(true , Ordering::Relaxed);
}
let mut responses = requests_wrapper.call_next_transform().await?;
for response in responses.iter_mut() {
if let Some(request_id) = response.request_id() {
if let Some(error_response) = self.filtered_requests.remove(&request_id) {
*response = error_response;
}
messages
})
}
}

Ok(responses)
}
}

@@ -121,11 +94,13 @@ mod test {
use crate::transforms::filter::QueryTypeFilter;
use crate::transforms::loopback::Loopback;
use crate::transforms::{Transform, Wrapper};
use std::collections::HashMap;

#[tokio::test(flavor = "multi_thread")]
async fn test_filter_denylist() {
let mut filter_transform = QueryTypeFilter {
filter: Filter::DenyList(vec![QueryType::Read]),
filtered_requests: HashMap::new(),
};

let mut chain = vec![TransformAndMetrics::new(Box::new(Loopback::default()))];
@@ -180,6 +155,7 @@
async fn test_filter_allowlist() {
let mut filter_transform = QueryTypeFilter {
filter: Filter::AllowList(vec![QueryType::Write]),
filtered_requests: HashMap::new(),
};

let mut chain = vec![TransformAndMetrics::new(Box::new(Loopback::default()))];
7 changes: 6 additions & 1 deletion shotover/src/transforms/loopback.rs
@@ -28,7 +28,12 @@ impl Transform for Loopback {
NAME
}

async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result<Messages> {
async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result<Messages> {
// This transform ultimately doesn't make a lot of sense semantically
// but make a vague attempt to follow transform invariants anyway.
for request in &mut requests_wrapper.requests {
request.set_request_id(request.id());
}
Ok(requests_wrapper.requests)
}
}
54 changes: 27 additions & 27 deletions shotover/src/transforms/throttling.rs
@@ -1,4 +1,4 @@
use crate::message::{Message, Messages};
use crate::message::{MessageId, Messages};
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use anyhow::Result;
use async_trait::async_trait;
@@ -10,6 +10,7 @@ use governor::{
};
use nonzero_ext::nonzero;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::num::NonZeroU32;
use std::sync::Arc;

@@ -29,6 +30,7 @@ impl TransformConfig for RequestThrottlingConfig {
self.max_requests_per_second,
))),
max_requests_per_second: self.max_requests_per_second,
throttled_requests: HashSet::new(),
}))
}
}
@@ -37,6 +39,7 @@ impl TransformConfig for RequestThrottlingConfig {
pub struct RequestThrottling {
limiter: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock, NoOpMiddleware>>,
max_requests_per_second: NonZeroU32,
throttled_requests: HashSet<MessageId>,
}

impl TransformBuilder for RequestThrottling {
@@ -67,43 +70,38 @@ impl Transform for RequestThrottling {
}

async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result<Messages> {
// extract throttled messages from the requests_wrapper
let throttled_messages: Vec<(Message, usize)> = (0..requests_wrapper.requests.len())
.rev()
.filter_map(|i| {
match self
.limiter
.check_n(requests_wrapper.requests[i].cell_count().ok()?)
{
// occurs if all cells can be accommodated and
Ok(Ok(())) => None,
for request in &mut requests_wrapper.requests {
if let Ok(cell_count) = request.cell_count() {
match self.limiter.check_n(cell_count) {
// occurs if all cells can be accommodated
Ok(Ok(())) => {}
// occurs if not all cells can be accommodated.
Ok(Err(_)) => {
let message = requests_wrapper.requests.remove(i);
Some((message, i))
self.throttled_requests.insert(request.id());
request.replace_with_dummy();
}
// occurs when the batch can never go through, meaning the rate limiter's quota's burst size is too low for the given number of cells to be ever allowed through
Err(_) => {
tracing::warn!("A message was received that could never have been successfully delivered since it contains more sub messages than can ever be allowed through via the `RequestThrottling` transforms `max_requests_per_second` configuration.");
let message = requests_wrapper.requests.remove(i);
Some((message, i))
self.throttled_requests.insert(request.id());
request.replace_with_dummy();
}
}
})
.collect();
}
}

// if every message got backpressured we can skip this
let mut responses = if !requests_wrapper.requests.is_empty() {
// send allowed messages to Cassandra
requests_wrapper.call_next_transform().await?
} else {
vec![]
};
// send allowed messages to Cassandra
let mut responses = requests_wrapper.call_next_transform().await?;

// reinsert backpressure error responses back into responses
for (mut message, i) in throttled_messages.into_iter().rev() {
message.set_backpressure()?;
responses.insert(i, message);
for response in responses.iter_mut() {
if response
.request_id()
.map(|id| self.throttled_requests.remove(&id))
.unwrap_or(false)
{
response.set_backpressure()?;
}
}

Ok(responses)
@@ -124,6 +122,7 @@ mod test {
Box::new(RequestThrottling {
limiter: Arc::new(RateLimiter::direct(Quota::per_second(nonzero!(20u32)))),
max_requests_per_second: nonzero!(20u32),
throttled_requests: HashSet::new(),
}),
Box::<NullSink>::default(),
],
@@ -146,6 +145,7 @@ mod test {
Box::new(RequestThrottling {
limiter: Arc::new(RateLimiter::direct(Quota::per_second(nonzero!(100u32)))),
max_requests_per_second: nonzero!(100u32),
throttled_requests: HashSet::new(),
}),
Box::<NullSink>::default(),
],
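Aside (not part of this commit): the three-way match in the throttling transform above follows the shape of the governor crate's check_n, which returns a nested Result: the outer Err means the batch holds more cells than the quota's burst size can ever admit, while the inner Err means the cells cannot be admitted right now. A minimal standalone sketch with illustrative numbers:

    use governor::{Quota, RateLimiter};
    use nonzero_ext::nonzero;

    fn main() {
        // Same construction as in the tests above, with an illustrative rate.
        let limiter = RateLimiter::direct(Quota::per_second(nonzero!(5u32)));

        match limiter.check_n(nonzero!(3u32)) {
            // All cells admitted: forward the request as-is.
            Ok(Ok(())) => println!("allowed"),
            // Over the current rate: throttle and answer with a backpressure response.
            Ok(Err(_not_until)) => println!("throttled"),
            // More cells than the burst size: this batch can never be admitted.
            Err(_insufficient_capacity) => println!("impossible batch"),
        }
    }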