Skip to content

Commit

Permalink
Pass in Wrapper as &mut
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai committed Aug 5, 2024
1 parent 8438fa3 commit 9cb4503
Show file tree
Hide file tree
Showing 29 changed files with 161 additions and 57 deletions.
5 changes: 4 additions & 1 deletion custom-transforms-example/src/redis_get_rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ impl Transform for RedisGetRewrite {
NAME
}

async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result<Messages> {
async fn transform<'a>(
&'a mut self,
requests_wrapper: &'a mut Wrapper<'a>,
) -> Result<Messages> {
for message in requests_wrapper.requests.iter_mut() {
if let Some(frame) = message.frame() {
if is_get(frame) {
Expand Down
5 changes: 3 additions & 2 deletions shotover/benches/benches/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ impl<'a> BenchInput<'a> {
let mut chain = chain.build(TransformContextBuilder::new_test());

// Run the chain once so we are measuring the chain once each transform has been fully initialized
futures::executor::block_on(chain.process_request(wrapper.clone())).unwrap();
futures::executor::block_on(chain.process_request(&mut wrapper.clone())).unwrap();

BenchInput {
chain,
Expand All @@ -372,8 +372,9 @@ impl<'a> BenchInput<'a> {

async fn bench(mut self) -> (Vec<Message>, TransformChain) {
// Return both the chain itself and the response to avoid measuring the time to drop the values in the benchmark
let mut wrapper = self.wrapper;
(
self.chain.process_request(self.wrapper).await.unwrap(),
self.chain.process_request(&mut wrapper).await.unwrap(),
self.chain,
)
}
Expand Down
6 changes: 3 additions & 3 deletions shotover/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@ impl<C: CodecBuilder + 'static> Handler<C> {
// Only flush messages if we are shutting down due to application shutdown
// If a Transform::transform returns an Err the transform is no longer in a usable state and needs to be destroyed without reusing.
if result.is_ok() {
match self.chain.process_request(Wrapper::flush()).await {
match self.chain.process_request(&mut Wrapper::flush()).await {
Ok(_) => {}
Err(e) => error!(
"{:?}",
Expand Down Expand Up @@ -736,9 +736,9 @@ impl<C: CodecBuilder + 'static> Handler<C> {
) -> Result<Messages> {
self.pending_requests.process_requests(&requests);

let wrapper = Wrapper::new_with_addr(requests, local_addr);
let mut wrapper = Wrapper::new_with_addr(requests, local_addr);

match self.chain.process_request(wrapper).await.context(
match self.chain.process_request(&mut wrapper).await.context(
"Chain failed to send and/or receive messages, the connection will now be closed.",
) {
Ok(x) => {
Expand Down
5 changes: 4 additions & 1 deletion shotover/src/transforms/cassandra/peers_rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@ impl Transform for CassandraPeersRewrite {
NAME
}

async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result<Messages> {
async fn transform<'a>(
&'a mut self,
requests_wrapper: &'a mut Wrapper<'a>,
) -> Result<Messages> {
// Find the indices of queries to system.peers & system.peers_v2
// we need to know which columns in which CQL queries in which messages have system peers
for request in &mut requests_wrapper.requests {
Expand Down
8 changes: 6 additions & 2 deletions shotover/src/transforms/cassandra/sink_cluster/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,11 @@ impl Transform for CassandraSinkCluster {
NAME
}

async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result<Messages> {
self.send_message(requests_wrapper.requests).await
async fn transform<'a>(
&'a mut self,
requests_wrapper: &'a mut Wrapper<'a>,
) -> Result<Messages> {
self.send_message(std::mem::take(&mut requests_wrapper.requests))
.await
}
}
8 changes: 6 additions & 2 deletions shotover/src/transforms/cassandra/sink_single.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,11 @@ impl Transform for CassandraSinkSingle {
NAME
}

async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result<Messages> {
self.send_message(requests_wrapper.requests).await
async fn transform<'a>(
&'a mut self,
requests_wrapper: &'a mut Wrapper<'a>,
) -> Result<Messages> {
self.send_message(std::mem::take(&mut requests_wrapper.requests))
.await
}
}
9 changes: 6 additions & 3 deletions shotover/src/transforms/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,10 @@ impl BufferedChain {
}

impl TransformChain {
pub async fn process_request(&mut self, mut wrapper: Wrapper<'_>) -> Result<Messages> {
pub async fn process_request<'a>(
&'a mut self,
wrapper: &'a mut Wrapper<'a>,
) -> Result<Messages> {
let start = Instant::now();
wrapper.reset(&mut self.chain);

Expand Down Expand Up @@ -319,7 +322,7 @@ impl TransformChainBuilder {

let mut wrapper = Wrapper::new_with_addr(messages, local_addr);
wrapper.flush = flush;
let chain_response = chain.process_request(wrapper).await;
let chain_response = chain.process_request(&mut wrapper).await;

if let Err(e) = &chain_response {
error!("Internal error in buffered chain: {e:?}");
Expand All @@ -338,7 +341,7 @@ impl TransformChainBuilder {
debug!("buffered chain processing thread exiting, stopping chain loop and dropping");

match chain
.process_request(Wrapper::flush())
.process_request(&mut Wrapper::flush())
.await
{
Ok(_) => info!("Buffered chain {} was shutdown", chain.name),
Expand Down
5 changes: 4 additions & 1 deletion shotover/src/transforms/coalesce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,10 @@ impl Transform for Coalesce {
NAME
}

async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result<Messages> {
async fn transform<'a>(
&'a mut self,
requests_wrapper: &'a mut Wrapper<'a>,
) -> Result<Messages> {
self.buffer.append(&mut requests_wrapper.requests);

let flush_buffer = requests_wrapper.flush
Expand Down
5 changes: 4 additions & 1 deletion shotover/src/transforms/debug/force_parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,10 @@ impl Transform for DebugForceParse {
NAME
}

async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result<Messages> {
async fn transform<'a>(
&'a mut self,
requests_wrapper: &'a mut Wrapper<'a>,
) -> Result<Messages> {
for message in &mut requests_wrapper.requests {
if self.parse_requests {
message.frame();
Expand Down
5 changes: 4 additions & 1 deletion shotover/src/transforms/debug/log_to_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,10 @@ impl Transform for DebugLogToFile {
NAME
}

async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result<Vec<Message>> {
async fn transform<'a>(
&'a mut self,
requests_wrapper: &'a mut Wrapper<'a>,
) -> Result<Vec<Message>> {
for message in &requests_wrapper.requests {
self.request_counter += 1;
let path = self
Expand Down
5 changes: 4 additions & 1 deletion shotover/src/transforms/debug/printer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,10 @@ impl Transform for DebugPrinter {
NAME
}

async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result<Messages> {
async fn transform<'a>(
&'a mut self,
requests_wrapper: &'a mut Wrapper<'a>,
) -> Result<Messages> {
for request in &mut requests_wrapper.requests {
info!("Request: {}", request.to_high_level_string());
}
Expand Down
5 changes: 4 additions & 1 deletion shotover/src/transforms/debug/returner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@ impl Transform for DebugReturner {
NAME
}

async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result<Messages> {
async fn transform<'a>(
&'a mut self,
requests_wrapper: &'a mut Wrapper<'a>,
) -> Result<Messages> {
requests_wrapper
.requests
.iter_mut()
Expand Down
15 changes: 12 additions & 3 deletions shotover/src/transforms/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ impl Transform for QueryTypeFilter {
NAME
}

async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result<Messages> {
async fn transform<'a>(
&'a mut self,
requests_wrapper: &'a mut Wrapper<'a>,
) -> Result<Messages> {
for request in requests_wrapper.requests.iter_mut() {
let filter_out = match &self.filter {
Filter::AllowList(allow_list) => !allow_list.contains(&request.get_query_type()),
Expand Down Expand Up @@ -138,7 +141,10 @@ mod test {

let mut requests_wrapper = Wrapper::new_test(messages);
requests_wrapper.reset(&mut chain);
let result = filter_transform.transform(requests_wrapper).await.unwrap();
let result = filter_transform
.transform(&mut requests_wrapper)
.await
.unwrap();

assert_eq!(result.len(), 26);

Expand Down Expand Up @@ -193,7 +199,10 @@ mod test {

let mut requests_wrapper = Wrapper::new_test(messages);
requests_wrapper.reset(&mut chain);
let result = filter_transform.transform(requests_wrapper).await.unwrap();
let result = filter_transform
.transform(&mut requests_wrapper)
.await
.unwrap();

assert_eq!(result.len(), 26);

Expand Down
7 changes: 5 additions & 2 deletions shotover/src/transforms/kafka/sink_cluster/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,10 @@ impl Transform for KafkaSinkCluster {
NAME
}

async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result<Messages> {
async fn transform<'a>(
&'a mut self,
requests_wrapper: &'a mut Wrapper<'a>,
) -> Result<Messages> {
let mut responses = if requests_wrapper.requests.is_empty() {
// there are no requests, so no point sending any, but we should check for any responses without awaiting
self.recv_responses()
Expand All @@ -364,7 +367,7 @@ impl Transform for KafkaSinkCluster {
}
}

self.route_requests(requests_wrapper.requests)
self.route_requests(std::mem::take(&mut requests_wrapper.requests))
.await
.context("Failed to route requests")?;
self.send_requests().await?;
Expand Down
7 changes: 5 additions & 2 deletions shotover/src/transforms/kafka/sink_single.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,10 @@ impl Transform for KafkaSinkSingle {
NAME
}

async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result<Messages> {
async fn transform<'a>(
&'a mut self,
requests_wrapper: &'a mut Wrapper<'a>,
) -> Result<Messages> {
if self.connection.is_none() {
let codec = KafkaCodecBuilder::new(Direction::Sink, "KafkaSinkSingle".to_owned());
let address = (requests_wrapper.local_addr.ip(), self.address_port);
Expand Down Expand Up @@ -161,7 +164,7 @@ impl Transform for KafkaSinkSingle {
// send
let connection = self.connection.as_mut().unwrap();
let requests_count = requests_wrapper.requests.len();
connection.send(requests_wrapper.requests)?;
connection.send(std::mem::take(&mut requests_wrapper.requests))?;

// receive
while responses.len() < requests_count {
Expand Down
7 changes: 5 additions & 2 deletions shotover/src/transforms/load_balance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,10 @@ impl Transform for ConnectionBalanceAndPool {
NAME
}

async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result<Messages> {
async fn transform<'a>(
&'a mut self,
requests_wrapper: &'a mut Wrapper<'a>,
) -> Result<Messages> {
if self.active_connection.is_none() {
let mut all_connections = self.all_connections.lock().await;
if all_connections.len() < self.max_connections {
Expand All @@ -105,7 +108,7 @@ impl Transform for ConnectionBalanceAndPool {
self.active_connection
.as_mut()
.unwrap()
.process_request(requests_wrapper, None)
.process_request(requests_wrapper.take(), None)
.await
}
}
7 changes: 5 additions & 2 deletions shotover/src/transforms/loopback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,15 @@ impl Transform for Loopback {
NAME
}

async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result<Messages> {
async fn transform<'a>(
&'a mut self,
requests_wrapper: &'a mut Wrapper<'a>,
) -> Result<Messages> {
// This transform ultimately doesnt make a lot of sense semantically
// but make a vague attempt to follow transform invariants anyway.
for request in &mut requests_wrapper.requests {
request.set_request_id(request.id());
}
Ok(requests_wrapper.requests)
Ok(std::mem::take(&mut requests_wrapper.requests))
}
}
14 changes: 12 additions & 2 deletions shotover/src/transforms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,15 @@ impl<'a> Clone for Wrapper<'a> {
}

impl<'a> Wrapper<'a> {
fn take(&mut self) -> Self {
Wrapper {
requests: std::mem::take(&mut self.requests),
transforms: std::mem::take(&mut self.transforms),
local_addr: self.local_addr,
flush: self.flush,
}
}

/// This function will take a mutable reference to the next transform out of the [`Wrapper`] structs
/// vector of transform references. It then sets up the chain name and transform name in the local
/// thread scope for structured logging.
Expand All @@ -183,7 +192,7 @@ impl<'a> Wrapper<'a> {
/// the execution time of the [Transform::transform] function as a metrics latency histogram.
///
/// The result of calling the next transform is then provided as a response.
pub async fn call_next_transform(mut self) -> Result<Messages> {
pub async fn call_next_transform(&'a mut self) -> Result<Messages> {
let TransformAndMetrics {
transform,
transform_total,
Expand Down Expand Up @@ -327,7 +336,8 @@ pub trait Transform: Send {
/// * Transform that do call subsquent chains via `requests_wrapper.call_next_transform()` are non-terminating transforms.
///
/// You can have have a transform that is both non-terminating and a sink.
async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result<Messages>;
async fn transform<'a>(&'a mut self, requests_wrapper: &'a mut Wrapper<'a>)
-> Result<Messages>;

/// Name of the transform used in logs and displayed to the user
fn get_name(&self) -> &'static str;
Expand Down
7 changes: 5 additions & 2 deletions shotover/src/transforms/null.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,15 @@ impl Transform for NullSink {
NAME
}

async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result<Messages> {
async fn transform<'a>(
&'a mut self,
requests_wrapper: &'a mut Wrapper<'a>,
) -> Result<Messages> {
for request in &mut requests_wrapper.requests {
// reuse the requests to hold the responses to avoid an allocation
*request = request
.from_request_to_error_response("Handled by shotover null transform".to_string())?;
}
Ok(requests_wrapper.requests)
Ok(std::mem::take(&mut requests_wrapper.requests))
}
}
9 changes: 6 additions & 3 deletions shotover/src/transforms/opensearch/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,14 @@ impl Transform for OpenSearchSinkSingle {
NAME
}

async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result<Messages> {
async fn transform<'a>(
&'a mut self,
requests_wrapper: &'a mut Wrapper<'a>,
) -> Result<Messages> {
// Return immediately if we have no messages.
// If we tried to send no messages we would block forever waiting for a reply that will never come.
if requests_wrapper.requests.is_empty() {
return Ok(requests_wrapper.requests);
return Ok(vec![]);
}

if self.connection.is_none() {
Expand All @@ -115,7 +118,7 @@ impl Transform for OpenSearchSinkSingle {
let messages_len = requests_wrapper.requests.len();

let mut result = Vec::with_capacity(messages_len);
for message in requests_wrapper.requests {
for message in requests_wrapper.requests.drain(..) {
let (tx, rx) = oneshot::channel();

connection
Expand Down
Loading

0 comments on commit 9cb4503

Please sign in to comment.