diff --git a/shotover/src/transforms/cassandra/peers_rewrite.rs b/shotover/src/transforms/cassandra/peers_rewrite.rs index c1708073c..c887a0d16 100644 --- a/shotover/src/transforms/cassandra/peers_rewrite.rs +++ b/shotover/src/transforms/cassandra/peers_rewrite.rs @@ -1,4 +1,4 @@ -use crate::message::{Message, Messages}; +use crate::message::{Message, MessageIdMap, Messages}; use crate::transforms::cassandra::peers_rewrite::CassandraOperation::Event; use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use crate::{ @@ -38,6 +38,7 @@ impl TransformConfig for CassandraPeersRewriteConfig { pub struct CassandraPeersRewrite { port: u16, peer_table: FQName, + column_names_to_rewrite: MessageIdMap>, } impl CassandraPeersRewrite { @@ -45,6 +46,7 @@ impl CassandraPeersRewrite { CassandraPeersRewrite { port, peer_table: FQName::new("system", "peers_v2"), + column_names_to_rewrite: Default::default(), } } } @@ -68,27 +70,21 @@ impl Transform for CassandraPeersRewrite { async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result { // Find the indices of queries to system.peers & system.peers_v2 // we need to know which columns in which CQL queries in which messages have system peers - let column_names: Vec<(usize, Vec)> = requests_wrapper - .requests - .iter_mut() - .enumerate() - .filter_map(|(i, m)| { - let sys_peers = extract_native_port_column(&self.peer_table, m); - if sys_peers.is_empty() { - None - } else { - Some((i, sys_peers)) - } - }) - .collect(); + for request in &mut requests_wrapper.requests { + let sys_peers = extract_native_port_column(&self.peer_table, request); + self.column_names_to_rewrite.insert(request.id(), sys_peers); + } - let mut response = requests_wrapper.call_next_transform().await?; + let mut responses = requests_wrapper.call_next_transform().await?; - for (i, name_list) in column_names { - rewrite_port(&mut response[i], &name_list, self.port); + for response in &mut responses { + if let Some(id) = response.request_id() { + let name_list = self.column_names_to_rewrite.remove(&id).unwrap(); + rewrite_port(response, &name_list, self.port); + } } - Ok(response) + Ok(responses) } async fn transform_pushed<'a>(