diff --git a/custom-transforms-example/src/redis_get_rewrite.rs b/custom-transforms-example/src/redis_get_rewrite.rs index 7d31f4260..55d757790 100644 --- a/custom-transforms-example/src/redis_get_rewrite.rs +++ b/custom-transforms-example/src/redis_get_rewrite.rs @@ -3,7 +3,9 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; use shotover::frame::{Frame, RedisFrame}; use shotover::message::{MessageIdSet, Messages}; -use shotover::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; +use shotover::transforms::{ + Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper, +}; #[derive(Serialize, Deserialize, Debug, Clone)] #[serde(deny_unknown_fields)] @@ -15,7 +17,10 @@ const NAME: &str = "RedisGetRewrite"; #[typetag::serde(name = "RedisGetRewrite")] #[async_trait(?Send)] impl TransformConfig for RedisGetRewriteConfig { - async fn get_builder(&self, _chain_name: String) -> Result> { + async fn get_builder( + &self, + _transform_context: TransformContextConfig, + ) -> Result> { Ok(Box::new(RedisGetRewriteBuilder { result: self.result.clone(), })) diff --git a/shotover/benches/benches/chain.rs b/shotover/benches/benches/chain.rs index 4c0f4b59c..140dc476d 100644 --- a/shotover/benches/benches/chain.rs +++ b/shotover/benches/benches/chain.rs @@ -5,8 +5,8 @@ use criterion::{criterion_group, BatchSize, Criterion}; use hex_literal::hex; use shotover::codec::CodecState; use shotover::frame::cassandra::{parse_statement_single, Tracing}; -use shotover::frame::RedisFrame; use shotover::frame::{CassandraFrame, CassandraOperation, Frame}; +use shotover::frame::{MessageType, RedisFrame}; use shotover::message::{Message, MessageIdMap, QueryType}; use shotover::transforms::cassandra::peers_rewrite::CassandraPeersRewrite; use shotover::transforms::chain::{TransformChain, TransformChainBuilder}; @@ -19,7 +19,7 @@ use shotover::transforms::protect::{KeyManagerConfig, ProtectConfig}; use shotover::transforms::redis::cluster_ports_rewrite::RedisClusterPortsRewrite; use shotover::transforms::redis::timestamp_tagging::RedisTimestampTagger; use shotover::transforms::throttling::RequestThrottlingConfig; -use shotover::transforms::{TransformConfig, Wrapper}; +use shotover::transforms::{TransformConfig, TransformContextConfig, Wrapper}; fn criterion_benchmark(c: &mut Criterion) { crate::init(); @@ -194,7 +194,10 @@ fn criterion_benchmark(c: &mut Criterion) { // an absurdly large value is given so that all messages will pass through max_requests_per_second: std::num::NonZeroU32::new(100_000_000).unwrap(), } - .get_builder("".to_owned()), + .get_builder(TransformContextConfig { + chain_name: "".into(), + protocol: MessageType::Redis, + }), ) .unwrap(), Box::::default(), @@ -303,7 +306,10 @@ fn criterion_benchmark(c: &mut Criterion) { kek_id: "".to_string(), }, } - .get_builder("".to_owned()), + .get_builder(TransformContextConfig { + chain_name: "".into(), + protocol: MessageType::Redis, + }), ) .unwrap(), Box::::default(), diff --git a/shotover/src/config/chain.rs b/shotover/src/config/chain.rs index 53695c959..0974f87f3 100644 --- a/shotover/src/config/chain.rs +++ b/shotover/src/config/chain.rs @@ -1,5 +1,5 @@ use crate::transforms::chain::TransformChainBuilder; -use crate::transforms::{TransformBuilder, TransformConfig}; +use crate::transforms::{TransformBuilder, TransformConfig, TransformContextConfig}; use anyhow::Result; use serde::de::{DeserializeSeed, Deserializer, MapAccess, SeqAccess, Visitor}; use serde::{Deserialize, Serialize}; @@ -14,12 +14,18 @@ pub struct TransformChainConfig( ); impl TransformChainConfig { - pub async fn get_builder(&self, name: String) -> Result { + pub async fn get_builder( + &self, + transform_context: TransformContextConfig, + ) -> Result { let mut transforms: Vec> = Vec::new(); for tc in &self.0 { - transforms.push(tc.get_builder(name.clone()).await?) + transforms.push(tc.get_builder(transform_context.clone()).await?) } - Ok(TransformChainBuilder::new(transforms, name.leak())) + Ok(TransformChainBuilder::new( + transforms, + transform_context.chain_name.leak(), + )) } } diff --git a/shotover/src/server.rs b/shotover/src/server.rs index 4da1c5fc8..2eef16403 100644 --- a/shotover/src/server.rs +++ b/shotover/src/server.rs @@ -4,7 +4,7 @@ use crate::message::{Message, Messages}; use crate::sources::Transport; use crate::tls::{AcceptError, TlsAcceptor}; use crate::transforms::chain::{TransformChain, TransformChainBuilder}; -use crate::transforms::Wrapper; +use crate::transforms::{TransformContextConfig, Wrapper}; use anyhow::{anyhow, Context, Result}; use bytes::BytesMut; use futures::future::join_all; @@ -92,8 +92,12 @@ impl TcpCodecListener { gauge!("shotover_available_connections_count", "source" => source_name.clone()); available_connections_gauge.set(limit_connections.available_permits() as f64); + let chain_usage_config = TransformContextConfig { + chain_name: source_name.clone(), + protocol: codec.protocol(), + }; let chain_builder = chain_config - .get_builder(source_name.clone()) + .get_builder(chain_usage_config) .await .map_err(|x| vec![format!("{x:?}")])?; diff --git a/shotover/src/transforms/cassandra/peers_rewrite.rs b/shotover/src/transforms/cassandra/peers_rewrite.rs index a87e9e391..c1708073c 100644 --- a/shotover/src/transforms/cassandra/peers_rewrite.rs +++ b/shotover/src/transforms/cassandra/peers_rewrite.rs @@ -1,10 +1,13 @@ -use crate::frame::{ - value::{GenericValue, IntSize}, - CassandraOperation, CassandraResult, Frame, -}; use crate::message::{Message, Messages}; use crate::transforms::cassandra::peers_rewrite::CassandraOperation::Event; use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; +use crate::{ + frame::{ + value::{GenericValue, IntSize}, + CassandraOperation, CassandraResult, Frame, + }, + transforms::TransformContextConfig, +}; use anyhow::Result; use async_trait::async_trait; use cassandra_protocol::frame::events::{ServerEvent, StatusChange}; @@ -23,7 +26,10 @@ const NAME: &str = "CassandraPeersRewrite"; #[typetag::serde(name = "CassandraPeersRewrite")] #[async_trait(?Send)] impl TransformConfig for CassandraPeersRewriteConfig { - async fn get_builder(&self, _chain_name: String) -> Result> { + async fn get_builder( + &self, + _transform_context: TransformContextConfig, + ) -> Result> { Ok(Box::new(CassandraPeersRewrite::new(self.port))) } } diff --git a/shotover/src/transforms/cassandra/sink_cluster/mod.rs b/shotover/src/transforms/cassandra/sink_cluster/mod.rs index 9de065400..5fb4fd222 100644 --- a/shotover/src/transforms/cassandra/sink_cluster/mod.rs +++ b/shotover/src/transforms/cassandra/sink_cluster/mod.rs @@ -5,7 +5,9 @@ use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame}; use crate::message::{Message, MessageIdMap, Messages, Metadata}; use crate::tls::{TlsConnector, TlsConnectorConfig}; use crate::transforms::cassandra::connection::{CassandraConnection, Response, ResponseError}; -use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; +use crate::transforms::{ + Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper, +}; use anyhow::{anyhow, Context, Result}; use async_trait::async_trait; use cassandra_protocol::events::ServerEvent; @@ -66,7 +68,10 @@ const NAME: &str = "CassandraSinkCluster"; #[typetag::serde(name = "CassandraSinkCluster")] #[async_trait(?Send)] impl TransformConfig for CassandraSinkClusterConfig { - async fn get_builder(&self, chain_name: String) -> Result> { + async fn get_builder( + &self, + transform_context: TransformContextConfig, + ) -> Result> { let tls = self.tls.clone().map(TlsConnector::new).transpose()?; let mut shotover_nodes = self.shotover_nodes.clone(); let index = self @@ -84,7 +89,7 @@ impl TransformConfig for CassandraSinkClusterConfig { Ok(Box::new(CassandraSinkClusterBuilder::new( self.first_contact_points.clone(), shotover_nodes, - chain_name, + transform_context.chain_name, local_node, tls, self.connect_timeout_ms, diff --git a/shotover/src/transforms/cassandra/sink_single.rs b/shotover/src/transforms/cassandra/sink_single.rs index eba0f1fc9..527ca5db9 100644 --- a/shotover/src/transforms/cassandra/sink_single.rs +++ b/shotover/src/transforms/cassandra/sink_single.rs @@ -4,7 +4,9 @@ use crate::frame::cassandra::CassandraMetadata; use crate::message::{Messages, Metadata}; use crate::tls::{TlsConnector, TlsConnectorConfig}; use crate::transforms::cassandra::connection::Response; -use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; +use crate::transforms::{ + Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper, +}; use anyhow::{anyhow, Result}; use async_trait::async_trait; use cassandra_protocol::frame::Version; @@ -29,11 +31,14 @@ const NAME: &str = "CassandraSinkSingle"; #[typetag::serde(name = "CassandraSinkSingle")] #[async_trait(?Send)] impl TransformConfig for CassandraSinkSingleConfig { - async fn get_builder(&self, chain_name: String) -> Result> { + async fn get_builder( + &self, + transform_context: TransformContextConfig, + ) -> Result> { let tls = self.tls.clone().map(TlsConnector::new).transpose()?; Ok(Box::new(CassandraSinkSingleBuilder::new( self.address.clone(), - chain_name, + transform_context.chain_name, tls, self.connect_timeout_ms, self.read_timeout, diff --git a/shotover/src/transforms/coalesce.rs b/shotover/src/transforms/coalesce.rs index 3b2939c57..b2a0add1f 100644 --- a/shotover/src/transforms/coalesce.rs +++ b/shotover/src/transforms/coalesce.rs @@ -1,3 +1,4 @@ +use super::TransformContextConfig; use crate::message::Messages; use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::Result; @@ -24,7 +25,10 @@ const NAME: &str = "Coalesce"; #[typetag::serde(name = "Coalesce")] #[async_trait(?Send)] impl TransformConfig for CoalesceConfig { - async fn get_builder(&self, _chain_name: String) -> Result> { + async fn get_builder( + &self, + _transform_context: TransformContextConfig, + ) -> Result> { Ok(Box::new(Coalesce { buffer: Vec::with_capacity(self.flush_when_buffered_message_count.unwrap_or(0)), flush_when_buffered_message_count: self.flush_when_buffered_message_count, diff --git a/shotover/src/transforms/debug/force_parse.rs b/shotover/src/transforms/debug/force_parse.rs index fd739b397..d71ff8e3f 100644 --- a/shotover/src/transforms/debug/force_parse.rs +++ b/shotover/src/transforms/debug/force_parse.rs @@ -7,6 +7,8 @@ use crate::message::Messages; /// It could also be used to ensure that messages round trip correctly when parsed. #[cfg(feature = "alpha-transforms")] use crate::transforms::TransformConfig; +#[cfg(feature = "alpha-transforms")] +use crate::transforms::TransformContextConfig; use crate::transforms::{Transform, TransformBuilder, Wrapper}; use anyhow::Result; use async_trait::async_trait; @@ -25,7 +27,10 @@ pub struct DebugForceParseConfig { #[typetag::serde(name = "DebugForceParse")] #[async_trait(?Send)] impl TransformConfig for DebugForceParseConfig { - async fn get_builder(&self, _chain_name: String) -> Result> { + async fn get_builder( + &self, + _transform_context: TransformContextConfig, + ) -> Result> { Ok(Box::new(DebugForceParse { parse_requests: self.parse_requests, parse_responses: self.parse_responses, @@ -49,7 +54,10 @@ const NAME: &str = "DebugForceEncode"; #[typetag::serde(name = "DebugForceEncode")] #[async_trait(?Send)] impl TransformConfig for DebugForceEncodeConfig { - async fn get_builder(&self, _chain_name: String) -> Result> { + async fn get_builder( + &self, + _transform_context: TransformContextConfig, + ) -> Result> { Ok(Box::new(DebugForceParse { parse_requests: self.encode_requests, parse_responses: self.encode_responses, diff --git a/shotover/src/transforms/debug/log_to_file.rs b/shotover/src/transforms/debug/log_to_file.rs index 98a0a4803..9a4b707ff 100644 --- a/shotover/src/transforms/debug/log_to_file.rs +++ b/shotover/src/transforms/debug/log_to_file.rs @@ -17,7 +17,10 @@ const NAME: &str = "DebugLogToFile"; #[typetag::serde(name = "DebugLogToFile")] #[async_trait(?Send)] impl crate::transforms::TransformConfig for DebugLogToFileConfig { - async fn get_builder(&self, _chain_name: String) -> Result> { + async fn get_builder( + &self, + _transform_context: crate::transforms::TransformContextConfig, + ) -> Result> { // This transform is used for debugging a specific run, so we clean out any logs left over from the previous run std::fs::remove_dir_all("message-log").ok(); diff --git a/shotover/src/transforms/debug/printer.rs b/shotover/src/transforms/debug/printer.rs index 2d9c89bf4..2ba8e8f67 100644 --- a/shotover/src/transforms/debug/printer.rs +++ b/shotover/src/transforms/debug/printer.rs @@ -1,5 +1,7 @@ use crate::message::Messages; -use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; +use crate::transforms::{ + Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper, +}; use anyhow::Result; use async_trait::async_trait; use serde::{Deserialize, Serialize}; @@ -13,7 +15,10 @@ const NAME: &str = "DebugPrinter"; #[typetag::serde(name = "DebugPrinter")] #[async_trait(?Send)] impl TransformConfig for DebugPrinterConfig { - async fn get_builder(&self, _chain_name: String) -> Result> { + async fn get_builder( + &self, + _transform_context: TransformContextConfig, + ) -> Result> { Ok(Box::new(DebugPrinter::new())) } } diff --git a/shotover/src/transforms/debug/returner.rs b/shotover/src/transforms/debug/returner.rs index 47a2fef3c..5f62acc68 100644 --- a/shotover/src/transforms/debug/returner.rs +++ b/shotover/src/transforms/debug/returner.rs @@ -1,5 +1,7 @@ use crate::message::{Message, Messages}; -use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; +use crate::transforms::{ + Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper, +}; use anyhow::{anyhow, Result}; use async_trait::async_trait; use serde::{Deserialize, Serialize}; @@ -15,7 +17,10 @@ const NAME: &str = "DebugReturner"; #[typetag::serde(name = "DebugReturner")] #[async_trait(?Send)] impl TransformConfig for DebugReturnerConfig { - async fn get_builder(&self, _chain_name: String) -> Result> { + async fn get_builder( + &self, + _transform_context: TransformContextConfig, + ) -> Result> { Ok(Box::new(DebugReturner::new(self.response.clone()))) } } diff --git a/shotover/src/transforms/distributed/tuneable_consistency_scatter.rs b/shotover/src/transforms/distributed/tuneable_consistency_scatter.rs index b5816fa34..efffb0025 100644 --- a/shotover/src/transforms/distributed/tuneable_consistency_scatter.rs +++ b/shotover/src/transforms/distributed/tuneable_consistency_scatter.rs @@ -2,7 +2,9 @@ use crate::config::chain::TransformChainConfig; use crate::frame::{Frame, RedisFrame}; use crate::message::{Message, Messages, QueryType}; use crate::transforms::chain::{BufferedChain, TransformChainBuilder}; -use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; +use crate::transforms::{ + Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper, +}; use anyhow::Result; use async_trait::async_trait; use futures::stream::FuturesUnordered; @@ -23,12 +25,19 @@ const NAME: &str = "TuneableConsistencyScatter"; #[typetag::serde(name = "TuneableConsistencyScatter")] #[async_trait(?Send)] impl TransformConfig for TuneableConsistencyScatterConfig { - async fn get_builder(&self, _chain_name: String) -> Result> { + async fn get_builder( + &self, + transform_context: TransformContextConfig, + ) -> Result> { let mut route_map = Vec::with_capacity(self.route_map.len()); warn!("Using this transform is considered unstable - Does not work with REDIS pipelines"); for (key, value) in &self.route_map { - route_map.push(value.get_builder(key.clone()).await?); + let chain_config = TransformContextConfig { + chain_name: key.clone(), + protocol: transform_context.protocol, + }; + route_map.push(value.get_builder(chain_config).await?); } route_map.sort_by_key(|x| x.name); diff --git a/shotover/src/transforms/filter.rs b/shotover/src/transforms/filter.rs index fc5ae28e5..4c9e0b51f 100644 --- a/shotover/src/transforms/filter.rs +++ b/shotover/src/transforms/filter.rs @@ -1,3 +1,4 @@ +use super::TransformContextConfig; use crate::message::{Message, MessageIdMap, Messages, QueryType}; use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::Result; @@ -28,7 +29,10 @@ const NAME: &str = "QueryTypeFilter"; #[typetag::serde(name = "QueryTypeFilter")] #[async_trait(?Send)] impl TransformConfig for QueryTypeFilterConfig { - async fn get_builder(&self, _chain_name: String) -> Result> { + async fn get_builder( + &self, + _transform_context: TransformContextConfig, + ) -> Result> { Ok(Box::new(QueryTypeFilter { filter: self.filter.clone(), filtered_requests: MessageIdMap::default(), diff --git a/shotover/src/transforms/kafka/sink_cluster.rs b/shotover/src/transforms/kafka/sink_cluster.rs index 71f1fc8df..276eeb330 100644 --- a/shotover/src/transforms/kafka/sink_cluster.rs +++ b/shotover/src/transforms/kafka/sink_cluster.rs @@ -6,8 +6,8 @@ use crate::message::{Message, Messages}; use crate::tcp; use crate::transforms::util::cluster_connection_pool::{spawn_read_write_tasks, Connection}; use crate::transforms::util::{Request, Response}; -use crate::transforms::TransformConfig; use crate::transforms::{Transform, TransformBuilder, Wrapper}; +use crate::transforms::{TransformConfig, TransformContextConfig}; use anyhow::{anyhow, Result}; use async_trait::async_trait; use dashmap::DashMap; @@ -45,11 +45,14 @@ const NAME: &str = "KafkaSinkCluster"; #[typetag::serde(name = "KafkaSinkCluster")] #[async_trait(?Send)] impl TransformConfig for KafkaSinkClusterConfig { - async fn get_builder(&self, chain_name: String) -> Result> { + async fn get_builder( + &self, + transform_context: TransformContextConfig, + ) -> Result> { Ok(Box::new(KafkaSinkClusterBuilder::new( self.first_contact_points.clone(), self.shotover_nodes.clone(), - chain_name, + transform_context.chain_name, self.connect_timeout_ms, self.read_timeout, ))) diff --git a/shotover/src/transforms/kafka/sink_single.rs b/shotover/src/transforms/kafka/sink_single.rs index 3d985f303..b78e873bc 100644 --- a/shotover/src/transforms/kafka/sink_single.rs +++ b/shotover/src/transforms/kafka/sink_single.rs @@ -7,7 +7,7 @@ use crate::tls::{TlsConnector, TlsConnectorConfig}; use crate::transforms::kafka::common::produce_channel; use crate::transforms::util::cluster_connection_pool::{spawn_read_write_tasks, Connection}; use crate::transforms::util::{Request, Response}; -use crate::transforms::{Transform, TransformBuilder, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, TransformContextConfig, Wrapper}; use anyhow::{anyhow, Result}; use async_trait::async_trait; use serde::{Deserialize, Serialize}; @@ -33,11 +33,14 @@ const NAME: &str = "KafkaSinkSingle"; #[typetag::serde(name = "KafkaSinkSingle")] #[async_trait(?Send)] impl TransformConfig for KafkaSinkSingleConfig { - async fn get_builder(&self, chain_name: String) -> Result> { + async fn get_builder( + &self, + transform_context: TransformContextConfig, + ) -> Result> { let tls = self.tls.clone().map(TlsConnector::new).transpose()?; Ok(Box::new(KafkaSinkSingleBuilder::new( self.destination_port, - chain_name, + transform_context.chain_name, self.connect_timeout_ms, self.read_timeout, tls, diff --git a/shotover/src/transforms/load_balance.rs b/shotover/src/transforms/load_balance.rs index 03edc6d15..9d82929f9 100644 --- a/shotover/src/transforms/load_balance.rs +++ b/shotover/src/transforms/load_balance.rs @@ -1,3 +1,4 @@ +use super::TransformContextConfig; use crate::config::chain::TransformChainConfig; use crate::message::Messages; use crate::transforms::chain::{BufferedChain, TransformChainBuilder}; @@ -20,8 +21,11 @@ const NAME: &str = "ConnectionBalanceAndPool"; #[typetag::serde(name = "ConnectionBalanceAndPool")] #[async_trait(?Send)] impl TransformConfig for ConnectionBalanceAndPoolConfig { - async fn get_builder(&self, _chain_name: String) -> Result> { - let chain = Arc::new(self.chain.get_builder(self.name.clone()).await?); + async fn get_builder( + &self, + transform_context: TransformContextConfig, + ) -> Result> { + let chain = Arc::new(self.chain.get_builder(transform_context).await?); Ok(Box::new(ConnectionBalanceAndPoolBuilder { max_connections: self.max_connections, diff --git a/shotover/src/transforms/mod.rs b/shotover/src/transforms/mod.rs index 51614cb79..1c1dd51d8 100644 --- a/shotover/src/transforms/mod.rs +++ b/shotover/src/transforms/mod.rs @@ -1,6 +1,7 @@ //! Various types required for defining a transform use self::chain::TransformAndMetrics; +use crate::frame::MessageType; use crate::message::{Message, MessageIdMap, Messages}; use anyhow::{anyhow, Result}; use async_trait::async_trait; @@ -64,7 +65,17 @@ impl Debug for dyn TransformBuilder { #[typetag::serde] #[async_trait(?Send)] pub trait TransformConfig: Debug { - async fn get_builder(&self, chain_name: String) -> Result>; + async fn get_builder( + &self, + transform_context: TransformContextConfig, + ) -> Result>; +} + +/// Provides extra context that may be needed when creating a TransformBuilder +#[derive(Clone)] +pub struct TransformContextConfig { + pub chain_name: String, + pub protocol: MessageType, } /// The [`Wrapper`] struct is passed into each transform and contains a list of mutable references to the diff --git a/shotover/src/transforms/null.rs b/shotover/src/transforms/null.rs index 7cdd24e0c..0c46e6a09 100644 --- a/shotover/src/transforms/null.rs +++ b/shotover/src/transforms/null.rs @@ -1,3 +1,4 @@ +use super::TransformContextConfig; use crate::message::Messages; use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::Result; @@ -12,7 +13,10 @@ const NAME: &str = "NullSink"; #[typetag::serde(name = "NullSink")] #[async_trait(?Send)] impl TransformConfig for NullSinkConfig { - async fn get_builder(&self, _chain_name: String) -> Result> { + async fn get_builder( + &self, + _transform_context: TransformContextConfig, + ) -> Result> { Ok(Box::new(NullSink {})) } } diff --git a/shotover/src/transforms/opensearch/mod.rs b/shotover/src/transforms/opensearch/mod.rs index 43da41f47..6464d29c5 100644 --- a/shotover/src/transforms/opensearch/mod.rs +++ b/shotover/src/transforms/opensearch/mod.rs @@ -1,3 +1,4 @@ +use super::TransformContextConfig; use crate::tcp; use crate::transforms::{Messages, Transform, TransformBuilder, TransformConfig, Wrapper}; use crate::{ @@ -25,10 +26,13 @@ const NAME: &str = "OpenSearchSinkSingle"; #[typetag::serde(name = "OpenSearchSinkSingle")] #[async_trait(?Send)] impl TransformConfig for OpenSearchSinkSingleConfig { - async fn get_builder(&self, chain_name: String) -> Result> { + async fn get_builder( + &self, + transform_context: TransformContextConfig, + ) -> Result> { Ok(Box::new(OpenSearchSinkSingleBuilder::new( self.address.clone(), - chain_name, + transform_context.chain_name, self.connect_timeout_ms, ))) } diff --git a/shotover/src/transforms/parallel_map.rs b/shotover/src/transforms/parallel_map.rs index cec3fd3c2..a716fd7e6 100644 --- a/shotover/src/transforms/parallel_map.rs +++ b/shotover/src/transforms/parallel_map.rs @@ -12,6 +12,8 @@ use serde::{Deserialize, Serialize}; use std::future::Future; use std::pin::Pin; +use super::TransformContextConfig; + #[derive(Debug)] pub struct ParallelMapBuilder { chains: Vec, @@ -74,10 +76,17 @@ const NAME: &str = "ParallelMap"; #[typetag::serde(name = "ParallelMap")] #[async_trait(?Send)] impl TransformConfig for ParallelMapConfig { - async fn get_builder(&self, _chain_name: String) -> Result> { + async fn get_builder( + &self, + transform_context: TransformContextConfig, + ) -> Result> { let mut chains = vec![]; for _ in 0..self.parallelism { - chains.push(self.chain.get_builder("parallel_map_chain".into()).await?); + let transform_context_config = TransformContextConfig { + chain_name: "parallel_map_chain".into(), + protocol: transform_context.protocol, + }; + chains.push(self.chain.get_builder(transform_context_config).await?); } Ok(Box::new(ParallelMapBuilder { diff --git a/shotover/src/transforms/protect/mod.rs b/shotover/src/transforms/protect/mod.rs index 2f81ade14..f2041daac 100644 --- a/shotover/src/transforms/protect/mod.rs +++ b/shotover/src/transforms/protect/mod.rs @@ -27,15 +27,15 @@ pub struct ProtectConfig { pub key_manager: KeyManagerConfig, } -#[cfg(feature = "alpha-transforms")] -use crate::transforms::TransformConfig; - const NAME: &str = "Protect"; #[cfg(feature = "alpha-transforms")] #[typetag::serde(name = "Protect")] #[async_trait(?Send)] -impl TransformConfig for ProtectConfig { - async fn get_builder(&self, _chain_name: String) -> Result> { +impl crate::transforms::TransformConfig for ProtectConfig { + async fn get_builder( + &self, + _transform_context: crate::transforms::TransformContextConfig, + ) -> Result> { Ok(Box::new(Protect { keyspace_table_columns: self .keyspace_table_columns diff --git a/shotover/src/transforms/query_counter.rs b/shotover/src/transforms/query_counter.rs index f71012fce..575ee9d72 100644 --- a/shotover/src/transforms/query_counter.rs +++ b/shotover/src/transforms/query_counter.rs @@ -8,6 +8,8 @@ use metrics::counter; use serde::Deserialize; use serde::Serialize; +use super::TransformContextConfig; + #[derive(Debug, Clone)] pub struct QueryCounter { counter_name: String, @@ -85,7 +87,10 @@ const NAME: &str = "QueryCounter"; #[typetag::serde(name = "QueryCounter")] #[async_trait(?Send)] impl TransformConfig for QueryCounterConfig { - async fn get_builder(&self, _chain_name: String) -> Result> { + async fn get_builder( + &self, + _transform_context: TransformContextConfig, + ) -> Result> { Ok(Box::new(QueryCounter::new(self.name.clone()))) } } diff --git a/shotover/src/transforms/redis/cache.rs b/shotover/src/transforms/redis/cache.rs index ab629874e..a5948fb67 100644 --- a/shotover/src/transforms/redis/cache.rs +++ b/shotover/src/transforms/redis/cache.rs @@ -1,8 +1,10 @@ use crate::config::chain::TransformChainConfig; -use crate::frame::{CassandraFrame, CassandraOperation, Frame, RedisFrame}; +use crate::frame::{CassandraFrame, CassandraOperation, Frame, MessageType, RedisFrame}; use crate::message::{Message, Messages}; use crate::transforms::chain::{TransformChain, TransformChainBuilder}; -use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; +use crate::transforms::{ + Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper, +}; use anyhow::{bail, Result}; use async_trait::async_trait; use bytes::Bytes; @@ -85,7 +87,10 @@ const NAME: &str = "RedisCache"; #[typetag::serde(name = "RedisCache")] #[async_trait(?Send)] impl TransformConfig for RedisConfig { - async fn get_builder(&self, _chain_name: String) -> Result> { + async fn get_builder( + &self, + _transform_context: TransformContextConfig, + ) -> Result> { let missed_requests = counter!("shotover_cache_miss_count"); let caching_schema: HashMap = self @@ -94,8 +99,13 @@ impl TransformConfig for RedisConfig { .map(|(k, v)| (FQName::parse(k), v.into())) .collect(); + let transform_context_config = TransformContextConfig { + chain_name: "cache_chain".into(), + protocol: MessageType::Redis, + }; + Ok(Box::new(SimpleRedisCacheBuilder { - cache_chain: self.chain.get_builder("cache_chain".to_string()).await?, + cache_chain: self.chain.get_builder(transform_context_config).await?, caching_schema, missed_requests, })) diff --git a/shotover/src/transforms/redis/cluster_ports_rewrite.rs b/shotover/src/transforms/redis/cluster_ports_rewrite.rs index f740ee96f..0e704cfcf 100644 --- a/shotover/src/transforms/redis/cluster_ports_rewrite.rs +++ b/shotover/src/transforms/redis/cluster_ports_rewrite.rs @@ -1,6 +1,7 @@ use crate::frame::Frame; use crate::frame::RedisFrame; use crate::message::{MessageIdMap, Messages}; +use crate::transforms::TransformContextConfig; use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::{anyhow, bail, Context, Result}; use async_trait::async_trait; @@ -18,7 +19,10 @@ const NAME: &str = "RedisClusterPortsRewrite"; #[typetag::serde(name = "RedisClusterPortsRewrite")] #[async_trait(?Send)] impl TransformConfig for RedisClusterPortsRewriteConfig { - async fn get_builder(&self, _chain_name: String) -> Result> { + async fn get_builder( + &self, + _transform_context: TransformContextConfig, + ) -> Result> { Ok(Box::new(RedisClusterPortsRewrite::new(self.new_port))) } } diff --git a/shotover/src/transforms/redis/sink_cluster.rs b/shotover/src/transforms/redis/sink_cluster.rs index 952267c27..5e423dcf8 100644 --- a/shotover/src/transforms/redis/sink_cluster.rs +++ b/shotover/src/transforms/redis/sink_cluster.rs @@ -7,7 +7,9 @@ use crate::transforms::redis::RedisError; use crate::transforms::redis::TransformError; use crate::transforms::util::cluster_connection_pool::{Authenticator, ConnectionPool}; use crate::transforms::util::{Request, Response}; -use crate::transforms::{ResponseFuture, Transform, TransformBuilder, TransformConfig, Wrapper}; +use crate::transforms::{ + ResponseFuture, Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper, +}; use anyhow::{anyhow, bail, ensure, Context, Result}; use async_trait::async_trait; use bytes::Bytes; @@ -48,7 +50,10 @@ const NAME: &str = "RedisSinkCluster"; #[typetag::serde(name = "RedisSinkCluster")] #[async_trait(?Send)] impl TransformConfig for RedisSinkClusterConfig { - async fn get_builder(&self, chain_name: String) -> Result> { + async fn get_builder( + &self, + transform_context: TransformContextConfig, + ) -> Result> { let connection_pool = ConnectionPool::new_with_auth( Duration::from_millis(self.connect_timeout_ms), RedisCodecBuilder::new(Direction::Sink, "RedisSinkCluster".to_owned()), @@ -60,7 +65,7 @@ impl TransformConfig for RedisSinkClusterConfig { direct_destination: self.direct_destination.clone(), connection_count: self.connection_count.unwrap_or(1), connection_pool, - chain_name, + chain_name: transform_context.chain_name, shared_topology: Arc::new(RwLock::new(Topology::new())), })) } diff --git a/shotover/src/transforms/redis/sink_single.rs b/shotover/src/transforms/redis/sink_single.rs index 9521de68b..5329127bf 100644 --- a/shotover/src/transforms/redis/sink_single.rs +++ b/shotover/src/transforms/redis/sink_single.rs @@ -1,12 +1,15 @@ -use crate::codec::{ - redis::{RedisCodecBuilder, RedisDecoder, RedisEncoder}, - CodecBuilder, CodecReadError, Direction, -}; use crate::frame::{Frame, RedisFrame}; use crate::message::{Message, Messages}; use crate::tcp; use crate::tls::{AsyncStream, TlsConnector, TlsConnectorConfig}; use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; +use crate::{ + codec::{ + redis::{RedisCodecBuilder, RedisDecoder, RedisEncoder}, + CodecBuilder, CodecReadError, Direction, + }, + transforms::TransformContextConfig, +}; use anyhow::{anyhow, Result}; use async_trait::async_trait; use futures::{FutureExt, SinkExt, StreamExt}; @@ -33,12 +36,15 @@ const NAME: &str = "RedisSinkSingle"; #[typetag::serde(name = "RedisSinkSingle")] #[async_trait(?Send)] impl TransformConfig for RedisSinkSingleConfig { - async fn get_builder(&self, chain_name: String) -> Result> { + async fn get_builder( + &self, + transform_context: TransformContextConfig, + ) -> Result> { let tls = self.tls.clone().map(TlsConnector::new).transpose()?; Ok(Box::new(RedisSinkSingleBuilder::new( self.address.clone(), tls, - chain_name, + transform_context.chain_name, self.connect_timeout_ms, ))) } diff --git a/shotover/src/transforms/redis/timestamp_tagging.rs b/shotover/src/transforms/redis/timestamp_tagging.rs index 8fb10c93b..2716c5591 100644 --- a/shotover/src/transforms/redis/timestamp_tagging.rs +++ b/shotover/src/transforms/redis/timestamp_tagging.rs @@ -1,7 +1,9 @@ use crate::frame::redis::redis_query_type; use crate::frame::{Frame, RedisFrame}; use crate::message::{Message, Messages, QueryType}; -use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; +use crate::transforms::{ + Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper, +}; use anyhow::{anyhow, Result}; use async_trait::async_trait; use bytes::Bytes; @@ -19,7 +21,10 @@ const NAME: &str = "RedisTimestampTagger"; #[typetag::serde(name = "RedisTimestampTagger")] #[async_trait(?Send)] impl TransformConfig for RedisTimestampTaggerConfig { - async fn get_builder(&self, _chain_name: String) -> Result> { + async fn get_builder( + &self, + _transform_context: TransformContextConfig, + ) -> Result> { Ok(Box::new(RedisTimestampTagger {})) } } diff --git a/shotover/src/transforms/tee.rs b/shotover/src/transforms/tee.rs index aaca1f5cd..061c2de88 100644 --- a/shotover/src/transforms/tee.rs +++ b/shotover/src/transforms/tee.rs @@ -1,3 +1,4 @@ +use super::TransformContextConfig; use crate::config::chain::TransformChainConfig; use crate::message::Messages; use crate::transforms::chain::{BufferedChain, TransformChainBuilder}; @@ -166,7 +167,10 @@ const NAME: &str = "Tee"; #[typetag::serde(name = "Tee")] #[async_trait(?Send)] impl TransformConfig for TeeConfig { - async fn get_builder(&self, _chain_name: String) -> Result> { + async fn get_builder( + &self, + transform_context: TransformContextConfig, + ) -> Result> { let buffer_size = self.buffer_size.unwrap_or(5); let behavior = match &self.behavior { Some(ConsistencyBehaviorConfig::Ignore) => ConsistencyBehaviorBuilder::Ignore, @@ -179,13 +183,22 @@ impl TransformConfig for TeeConfig { Some(ConsistencyBehaviorConfig::SubchainOnMismatch(mismatch_chain)) => { ConsistencyBehaviorBuilder::SubchainOnMismatch( mismatch_chain - .get_builder("mismatch_chain".to_string()) + .get_builder(TransformContextConfig { + chain_name: "mismatch_chain".to_string(), + protocol: transform_context.protocol, + }) .await?, ) } None => ConsistencyBehaviorBuilder::Ignore, }; - let tee_chain = self.chain.get_builder("tee_chain".to_string()).await?; + let tee_chain = self + .chain + .get_builder(TransformContextConfig { + chain_name: "tee_chain".to_string(), + protocol: transform_context.protocol, + }) + .await?; Ok(Box::new(TeeBuilder::new( tee_chain, @@ -429,10 +442,10 @@ impl ChainSwitchListener { } } -#[cfg(test)] +#[cfg(all(test, feature = "redis"))] mod tests { use super::*; - use crate::transforms::null::NullSinkConfig; + use crate::{frame::MessageType, transforms::null::NullSinkConfig}; #[tokio::test] async fn test_validate_subchain_valid() { @@ -444,7 +457,11 @@ mod tests { switch_port: None, }; - let transform = config.get_builder("".to_owned()).await.unwrap(); + let transform_context_config = TransformContextConfig { + chain_name: "".into(), + protocol: MessageType::Redis, + }; + let transform = config.get_builder(transform_context_config).await.unwrap(); let result = transform.validate(); assert_eq!(result, Vec::::new()); } @@ -459,7 +476,11 @@ mod tests { switch_port: None, }; - let transform = config.get_builder("".to_owned()).await.unwrap(); + let transform_context_config = TransformContextConfig { + chain_name: "".into(), + protocol: MessageType::Redis, + }; + let transform = config.get_builder(transform_context_config).await.unwrap(); let result = transform.validate().join("\n"); let expected = r#"Tee: tee_chain chain: @@ -476,7 +497,11 @@ mod tests { buffer_size: None, switch_port: None, }; - let transform = config.get_builder("".to_owned()).await.unwrap(); + let transform_context_config = TransformContextConfig { + chain_name: "".into(), + protocol: MessageType::Redis, + }; + let transform = config.get_builder(transform_context_config).await.unwrap(); let result = transform.validate(); assert_eq!(result, Vec::::new()); } @@ -490,7 +515,11 @@ mod tests { buffer_size: None, switch_port: None, }; - let transform = config.get_builder("".to_owned()).await.unwrap(); + let transform_context_config = TransformContextConfig { + chain_name: "".into(), + protocol: MessageType::Redis, + }; + let transform = config.get_builder(transform_context_config).await.unwrap(); let result = transform.validate(); assert_eq!(result, Vec::::new()); } @@ -507,7 +536,11 @@ mod tests { switch_port: None, }; - let transform = config.get_builder("".to_owned()).await.unwrap(); + let transform_context_config = TransformContextConfig { + chain_name: "".into(), + protocol: MessageType::Redis, + }; + let transform = config.get_builder(transform_context_config).await.unwrap(); let result = transform.validate().join("\n"); let expected = r#"Tee: mismatch_chain chain: @@ -527,7 +560,11 @@ mod tests { switch_port: None, }; - let transform = config.get_builder("".to_owned()).await.unwrap(); + let transform_context_config = TransformContextConfig { + chain_name: "".into(), + protocol: MessageType::Redis, + }; + let transform = config.get_builder(transform_context_config).await.unwrap(); let result = transform.validate(); assert_eq!(result, Vec::::new()); } diff --git a/shotover/src/transforms/throttling.rs b/shotover/src/transforms/throttling.rs index 8dde15487..1d0f4845a 100644 --- a/shotover/src/transforms/throttling.rs +++ b/shotover/src/transforms/throttling.rs @@ -1,3 +1,4 @@ +use super::TransformContextConfig; use crate::message::{Message, MessageIdMap, Messages}; use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::Result; @@ -23,7 +24,10 @@ const NAME: &str = "RequestThrottling"; #[typetag::serde(name = "RequestThrottling")] #[async_trait(?Send)] impl TransformConfig for RequestThrottlingConfig { - async fn get_builder(&self, _chain_name: String) -> Result> { + async fn get_builder( + &self, + _transform_context: TransformContextConfig, + ) -> Result> { Ok(Box::new(RequestThrottling { limiter: Arc::new(RateLimiter::direct(Quota::per_second( self.max_requests_per_second,