Skip to content

Commit

Permalink
Merge branch 'main' into cluster_ports_rewrite_message_id
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai authored Feb 19, 2024
2 parents 7c68e72 + 203691a commit 3d6bdd0
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions custom-transforms-example/src/redis_get_rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use shotover::frame::{Frame, RedisFrame};
use shotover::message::Messages;
use shotover::message::{MessageId, Messages};
use shotover::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use std::collections::HashSet;

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(deny_unknown_fields)]
Expand All @@ -30,6 +31,7 @@ pub struct RedisGetRewriteBuilder {
impl TransformBuilder for RedisGetRewriteBuilder {
fn build(&self) -> Box<dyn Transform> {
Box::new(RedisGetRewrite {
get_requests: HashSet::new(),
result: self.result.clone(),
})
}
Expand All @@ -40,6 +42,7 @@ impl TransformBuilder for RedisGetRewriteBuilder {
}

pub struct RedisGetRewrite {
get_requests: HashSet<MessageId>,
result: String,
}

Expand All @@ -50,20 +53,25 @@ impl Transform for RedisGetRewrite {
}

async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result<Messages> {
let mut get_indices = vec![];
for (i, message) in requests_wrapper.requests.iter_mut().enumerate() {
for message in requests_wrapper.requests.iter_mut() {
if let Some(frame) = message.frame() {
if is_get(frame) {
get_indices.push(i);
self.get_requests.insert(message.id());
}
}
}
let mut responses = requests_wrapper.call_next_transform().await?;

for i in get_indices {
if let Some(frame) = responses[i].frame() {
rewrite_get(frame, &self.result);
responses[i].invalidate_cache();
for response in responses.iter_mut() {
if response
.request_id()
.map(|id| self.get_requests.remove(&id))
.unwrap_or(false)
{
if let Some(frame) = response.frame() {
rewrite_get(frame, &self.result);
response.invalidate_cache();
}
}
}

Expand Down

0 comments on commit 3d6bdd0

Please sign in to comment.