Skip to content

Commit

Permalink
Merge branch 'main' into dummy_requests
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai authored Feb 19, 2024
2 parents 738f92a + 022529e commit b5f37e9
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 37 deletions.
24 changes: 16 additions & 8 deletions custom-transforms-example/src/redis_get_rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use shotover::frame::{Frame, RedisFrame};
use shotover::message::Messages;
use shotover::message::{MessageId, Messages};
use shotover::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use std::collections::HashSet;

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(deny_unknown_fields)]
Expand All @@ -30,6 +31,7 @@ pub struct RedisGetRewriteBuilder {
impl TransformBuilder for RedisGetRewriteBuilder {
fn build(&self) -> Box<dyn Transform> {
Box::new(RedisGetRewrite {
get_requests: HashSet::new(),
result: self.result.clone(),
})
}
Expand All @@ -40,6 +42,7 @@ impl TransformBuilder for RedisGetRewriteBuilder {
}

pub struct RedisGetRewrite {
get_requests: HashSet<MessageId>,
result: String,
}

Expand All @@ -50,20 +53,25 @@ impl Transform for RedisGetRewrite {
}

async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result<Messages> {
let mut get_indices = vec![];
for (i, message) in requests_wrapper.requests.iter_mut().enumerate() {
for message in requests_wrapper.requests.iter_mut() {
if let Some(frame) = message.frame() {
if is_get(frame) {
get_indices.push(i);
self.get_requests.insert(message.id());
}
}
}
let mut responses = requests_wrapper.call_next_transform().await?;

for i in get_indices {
if let Some(frame) = responses[i].frame() {
rewrite_get(frame, &self.result);
responses[i].invalidate_cache();
for response in responses.iter_mut() {
if response
.request_id()
.map(|id| self.get_requests.remove(&id))
.unwrap_or(false)
{
if let Some(frame) = response.frame() {
rewrite_get(frame, &self.result);
response.invalidate_cache();
}
}
}

Expand Down
74 changes: 45 additions & 29 deletions shotover/src/transforms/redis/cluster_ports_rewrite.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use std::collections::HashMap;

use crate::frame::Frame;
use crate::frame::RedisFrame;
use crate::message::MessageId;
use crate::message::Messages;
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use anyhow::{anyhow, bail, Context, Result};
Expand All @@ -19,9 +22,7 @@ const NAME: &str = "RedisClusterPortsRewrite";
#[async_trait(?Send)]
impl TransformConfig for RedisClusterPortsRewriteConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
Ok(Box::new(RedisClusterPortsRewrite {
new_port: self.new_port,
}))
Ok(Box::new(RedisClusterPortsRewrite::new(self.new_port)))
}
}

Expand All @@ -38,11 +39,21 @@ impl TransformBuilder for RedisClusterPortsRewrite {
#[derive(Clone)]
pub struct RedisClusterPortsRewrite {
new_port: u16,
request_type: HashMap<MessageId, RequestType>,
}

#[derive(Clone)]
enum RequestType {
ClusterSlot,
ClusterNodes,
}

impl RedisClusterPortsRewrite {
pub fn new(new_port: u16) -> Self {
RedisClusterPortsRewrite { new_port }
RedisClusterPortsRewrite {
new_port,
request_type: HashMap::new(),
}
}
}

Expand All @@ -53,43 +64,48 @@ impl Transform for RedisClusterPortsRewrite {
}

async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result<Messages> {
// Find the indices of cluster slot messages
let mut cluster_slots_indices = vec![];
let mut cluster_nodes_indices = vec![];

for (i, message) in requests_wrapper.requests.iter_mut().enumerate() {
for message in requests_wrapper.requests.iter_mut() {
let message_id = message.id();
if let Some(frame) = message.frame() {
if is_cluster_slots(frame) {
cluster_slots_indices.push(i);
self.request_type
.insert(message_id, RequestType::ClusterSlot);
}

if is_cluster_nodes(frame) {
cluster_nodes_indices.push(i);
self.request_type
.insert(message_id, RequestType::ClusterNodes);
}
}
}

let mut response = requests_wrapper.call_next_transform().await?;

// Rewrite the ports in the cluster slots responses
for i in cluster_slots_indices {
if let Some(frame) = response[i].frame() {
rewrite_port_slot(frame, self.new_port)
.context("failed to rewrite CLUSTER SLOTS port")?;
}
response[i].invalidate_cache();
}

// Rewrite the ports in the cluster nodes responses
for i in cluster_nodes_indices {
if let Some(frame) = response[i].frame() {
rewrite_port_node(frame, self.new_port)
.context("failed to rewrite CLUSTER NODES port")?;
let mut responses = requests_wrapper.call_next_transform().await?;

for response in &mut responses {
if let Some(request_id) = response.request_id() {
match self.request_type.remove(&request_id) {
// Rewrite the ports in the cluster slots responses
Some(RequestType::ClusterSlot) => {
if let Some(frame) = response.frame() {
rewrite_port_slot(frame, self.new_port)
.context("failed to rewrite CLUSTER SLOTS port")?;
}
response.invalidate_cache();
}
// Rewrite the ports in the cluster nodes responses
Some(RequestType::ClusterNodes) => {
if let Some(frame) = response.frame() {
rewrite_port_node(frame, self.new_port)
.context("failed to rewrite CLUSTER NODES port")?;
}
response.invalidate_cache();
}
None => {}
}
}
response[i].invalidate_cache();
}

Ok(response)
Ok(responses)
}
}

Expand Down

0 comments on commit b5f37e9

Please sign in to comment.