diff --git a/custom-transforms-example/src/redis_get_rewrite.rs b/custom-transforms-example/src/redis_get_rewrite.rs index 3af52b232..0a0650eb6 100644 --- a/custom-transforms-example/src/redis_get_rewrite.rs +++ b/custom-transforms-example/src/redis_get_rewrite.rs @@ -2,8 +2,9 @@ use anyhow::Result; use async_trait::async_trait; use serde::{Deserialize, Serialize}; use shotover::frame::{Frame, RedisFrame}; -use shotover::message::Messages; +use shotover::message::{MessageId, Messages}; use shotover::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; +use std::collections::HashSet; #[derive(Serialize, Deserialize, Debug, Clone)] #[serde(deny_unknown_fields)] @@ -30,6 +31,7 @@ pub struct RedisGetRewriteBuilder { impl TransformBuilder for RedisGetRewriteBuilder { fn build(&self) -> Box { Box::new(RedisGetRewrite { + get_requests: HashSet::new(), result: self.result.clone(), }) } @@ -40,6 +42,7 @@ impl TransformBuilder for RedisGetRewriteBuilder { } pub struct RedisGetRewrite { + get_requests: HashSet, result: String, } @@ -50,20 +53,25 @@ impl Transform for RedisGetRewrite { } async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result { - let mut get_indices = vec![]; - for (i, message) in requests_wrapper.requests.iter_mut().enumerate() { + for message in requests_wrapper.requests.iter_mut() { if let Some(frame) = message.frame() { if is_get(frame) { - get_indices.push(i); + self.get_requests.insert(message.id()); } } } let mut responses = requests_wrapper.call_next_transform().await?; - for i in get_indices { - if let Some(frame) = responses[i].frame() { - rewrite_get(frame, &self.result); - responses[i].invalidate_cache(); + for response in responses.iter_mut() { + if response + .request_id() + .map(|id| self.get_requests.remove(&id)) + .unwrap_or(false) + { + if let Some(frame) = response.frame() { + rewrite_get(frame, &self.result); + response.invalidate_cache(); + } } } diff --git a/shotover/src/transforms/redis/cluster_ports_rewrite.rs b/shotover/src/transforms/redis/cluster_ports_rewrite.rs index 6154c1c44..0396d6a35 100644 --- a/shotover/src/transforms/redis/cluster_ports_rewrite.rs +++ b/shotover/src/transforms/redis/cluster_ports_rewrite.rs @@ -1,5 +1,8 @@ +use std::collections::HashMap; + use crate::frame::Frame; use crate::frame::RedisFrame; +use crate::message::MessageId; use crate::message::Messages; use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::{anyhow, bail, Context, Result}; @@ -19,9 +22,7 @@ const NAME: &str = "RedisClusterPortsRewrite"; #[async_trait(?Send)] impl TransformConfig for RedisClusterPortsRewriteConfig { async fn get_builder(&self, _chain_name: String) -> Result> { - Ok(Box::new(RedisClusterPortsRewrite { - new_port: self.new_port, - })) + Ok(Box::new(RedisClusterPortsRewrite::new(self.new_port))) } } @@ -38,11 +39,21 @@ impl TransformBuilder for RedisClusterPortsRewrite { #[derive(Clone)] pub struct RedisClusterPortsRewrite { new_port: u16, + request_type: HashMap, +} + +#[derive(Clone)] +enum RequestType { + ClusterSlot, + ClusterNodes, } impl RedisClusterPortsRewrite { pub fn new(new_port: u16) -> Self { - RedisClusterPortsRewrite { new_port } + RedisClusterPortsRewrite { + new_port, + request_type: HashMap::new(), + } } } @@ -53,43 +64,48 @@ impl Transform for RedisClusterPortsRewrite { } async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result { - // Find the indices of cluster slot messages - let mut cluster_slots_indices = vec![]; - let mut cluster_nodes_indices = vec![]; - - for (i, message) in requests_wrapper.requests.iter_mut().enumerate() { + for message in requests_wrapper.requests.iter_mut() { + let message_id = message.id(); if let Some(frame) = message.frame() { if is_cluster_slots(frame) { - cluster_slots_indices.push(i); + self.request_type + .insert(message_id, RequestType::ClusterSlot); } if is_cluster_nodes(frame) { - cluster_nodes_indices.push(i); + self.request_type + .insert(message_id, RequestType::ClusterNodes); } } } - let mut response = requests_wrapper.call_next_transform().await?; - - // Rewrite the ports in the cluster slots responses - for i in cluster_slots_indices { - if let Some(frame) = response[i].frame() { - rewrite_port_slot(frame, self.new_port) - .context("failed to rewrite CLUSTER SLOTS port")?; - } - response[i].invalidate_cache(); - } - - // Rewrite the ports in the cluster nodes responses - for i in cluster_nodes_indices { - if let Some(frame) = response[i].frame() { - rewrite_port_node(frame, self.new_port) - .context("failed to rewrite CLUSTER NODES port")?; + let mut responses = requests_wrapper.call_next_transform().await?; + + for response in &mut responses { + if let Some(request_id) = response.request_id() { + match self.request_type.remove(&request_id) { + // Rewrite the ports in the cluster slots responses + Some(RequestType::ClusterSlot) => { + if let Some(frame) = response.frame() { + rewrite_port_slot(frame, self.new_port) + .context("failed to rewrite CLUSTER SLOTS port")?; + } + response.invalidate_cache(); + } + // Rewrite the ports in the cluster nodes responses + Some(RequestType::ClusterNodes) => { + if let Some(frame) = response.frame() { + rewrite_port_node(frame, self.new_port) + .context("failed to rewrite CLUSTER NODES port")?; + } + response.invalidate_cache(); + } + None => {} + } } - response[i].invalidate_cache(); } - Ok(response) + Ok(responses) } }