diff --git a/Cargo.lock b/Cargo.lock index 5104da9e8..af81ffceb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4362,7 +4362,6 @@ dependencies = [ "serde_json", "serde_yaml", "string", - "strum_macros 0.26.1", "thiserror", "tokio", "tokio-rustls 0.25.0", diff --git a/changelog.md b/changelog.md index 772248e15..dc6b95a47 100644 --- a/changelog.md +++ b/changelog.md @@ -5,6 +5,29 @@ This assists us in knowing when to make the next release a breaking release and ## 0.3.0 +## shotover rust API + +`TransformBuilder::build` now returns `Box` instead of `Transforms`. +This means that custom transforms should implement the builder as: + +```rust +impl TransformBuilder for CustomBuilder { + fn build(&self) -> Box { + Box::new(CustomTransform::new()) + } +} +``` + +Instead of: + +```rust +impl TransformBuilder for CustomBuilder { + fn build(&self) -> Transforms { + Transforms::Custom(CustomTransform::new()) + } +} +``` + ### metrics The prometheus metrics were renamed to better follow the official reccomended naming scheme: diff --git a/custom-transforms-example/src/redis_get_rewrite.rs b/custom-transforms-example/src/redis_get_rewrite.rs index bf9c8a1ea..3af52b232 100644 --- a/custom-transforms-example/src/redis_get_rewrite.rs +++ b/custom-transforms-example/src/redis_get_rewrite.rs @@ -3,7 +3,7 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; use shotover::frame::{Frame, RedisFrame}; use shotover::message::Messages; -use shotover::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper}; +use shotover::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; #[derive(Serialize, Deserialize, Debug, Clone)] #[serde(deny_unknown_fields)] @@ -11,6 +11,7 @@ pub struct RedisGetRewriteConfig { pub result: String, } +const NAME: &str = "RedisGetRewrite"; #[typetag::serde(name = "RedisGetRewrite")] #[async_trait(?Send)] impl TransformConfig for RedisGetRewriteConfig { @@ -27,14 +28,14 @@ pub struct RedisGetRewriteBuilder { } impl TransformBuilder for RedisGetRewriteBuilder { - fn build(&self) -> Transforms { - Transforms::Custom(Box::new(RedisGetRewrite { + fn build(&self) -> Box { + Box::new(RedisGetRewrite { result: self.result.clone(), - })) + }) } fn get_name(&self) -> &'static str { - "RedisGetRewrite" + NAME } } @@ -44,6 +45,10 @@ pub struct RedisGetRewrite { #[async_trait] impl Transform for RedisGetRewrite { + fn get_name(&self) -> &'static str { + NAME + } + async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result { let mut get_indices = vec![]; for (i, message) in requests_wrapper.requests.iter_mut().enumerate() { diff --git a/shotover/Cargo.toml b/shotover/Cargo.toml index 8b29aff86..dc9c1531e 100644 --- a/shotover/Cargo.toml +++ b/shotover/Cargo.toml @@ -76,7 +76,6 @@ ordered-float.workspace = true #Crypto aws-config = "1.0.0" aws-sdk-kms = "1.1.0" -strum_macros = "0.26" chacha20poly1305 = { version = "0.10.0", features = ["std"] } generic-array = { version = "0.14", features = ["serde"] } kafka-protocol = "0.8.0" diff --git a/shotover/src/lib.rs b/shotover/src/lib.rs index 1fa941900..709aace81 100644 --- a/shotover/src/lib.rs +++ b/shotover/src/lib.rs @@ -30,6 +30,7 @@ #![deny(clippy::print_stdout)] #![deny(clippy::print_stderr)] #![allow(clippy::needless_doctest_main)] +#![allow(clippy::box_default)] pub mod codec; pub mod config; diff --git a/shotover/src/runner.rs b/shotover/src/runner.rs index bedd247b0..e8018176b 100644 --- a/shotover/src/runner.rs +++ b/shotover/src/runner.rs @@ -2,8 +2,6 @@ use crate::config::topology::Topology; use crate::config::Config; use crate::observability::LogFilterHttpExporter; -use crate::transforms::Transforms; -use crate::transforms::Wrapper; use anyhow::Context; use anyhow::{anyhow, Result}; use clap::{crate_version, Parser}; @@ -13,7 +11,7 @@ use std::net::SocketAddr; use tokio::runtime::{self, Runtime}; use tokio::signal::unix::{signal, SignalKind}; use tokio::sync::watch; -use tracing::{debug, error, info}; +use tracing::{error, info}; use tracing_appender::non_blocking::{NonBlocking, WorkerGuard}; use tracing_subscriber::filter::Directive; use tracing_subscriber::fmt::format::DefaultFields; @@ -302,16 +300,6 @@ async fn run( info!(configuration = ?config); info!(topology = ?topology); - debug!( - "Transform overhead size on stack is {}", - std::mem::size_of::() - ); - - debug!( - "Wrapper overhead size on stack is {}", - std::mem::size_of::>() - ); - match topology.run_chains(trigger_shutdown_rx).await { Ok(sources) => { futures::future::join_all(sources.into_iter().map(|x| x.into_join_handle())).await; diff --git a/shotover/src/transforms/cassandra/peers_rewrite.rs b/shotover/src/transforms/cassandra/peers_rewrite.rs index 40cddf976..a87e9e391 100644 --- a/shotover/src/transforms/cassandra/peers_rewrite.rs +++ b/shotover/src/transforms/cassandra/peers_rewrite.rs @@ -4,7 +4,7 @@ use crate::frame::{ }; use crate::message::{Message, Messages}; use crate::transforms::cassandra::peers_rewrite::CassandraOperation::Event; -use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::Result; use async_trait::async_trait; use cassandra_protocol::frame::events::{ServerEvent, StatusChange}; @@ -19,6 +19,7 @@ pub struct CassandraPeersRewriteConfig { pub port: u16, } +const NAME: &str = "CassandraPeersRewrite"; #[typetag::serde(name = "CassandraPeersRewrite")] #[async_trait(?Send)] impl TransformConfig for CassandraPeersRewriteConfig { @@ -43,17 +44,21 @@ impl CassandraPeersRewrite { } impl TransformBuilder for CassandraPeersRewrite { - fn build(&self) -> Transforms { - Transforms::CassandraPeersRewrite(self.clone()) + fn build(&self) -> Box { + Box::new(self.clone()) } fn get_name(&self) -> &'static str { - "CassandraPeersRewrite" + NAME } } #[async_trait] impl Transform for CassandraPeersRewrite { + fn get_name(&self) -> &'static str { + NAME + } + async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result { // Find the indices of queries to system.peers & system.peers_v2 // we need to know which columns in which CQL queries in which messages have system peers diff --git a/shotover/src/transforms/cassandra/sink_cluster/mod.rs b/shotover/src/transforms/cassandra/sink_cluster/mod.rs index 79cae182d..02462850d 100644 --- a/shotover/src/transforms/cassandra/sink_cluster/mod.rs +++ b/shotover/src/transforms/cassandra/sink_cluster/mod.rs @@ -5,7 +5,7 @@ use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame}; use crate::message::{Message, Messages, Metadata}; use crate::tls::{TlsConnector, TlsConnectorConfig}; use crate::transforms::cassandra::connection::{CassandraConnection, Response, ResponseError}; -use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::{anyhow, Context, Result}; use async_trait::async_trait; use cassandra_protocol::events::ServerEvent; @@ -62,6 +62,7 @@ pub struct CassandraSinkClusterConfig { pub read_timeout: Option, } +const NAME: &str = "CassandraSinkCluster"; #[typetag::serde(name = "CassandraSinkCluster")] #[async_trait(?Send)] impl TransformConfig for CassandraSinkClusterConfig { @@ -152,8 +153,8 @@ impl CassandraSinkClusterBuilder { } impl TransformBuilder for CassandraSinkClusterBuilder { - fn build(&self) -> crate::transforms::Transforms { - Transforms::CassandraSinkCluster(Box::new(CassandraSinkCluster { + fn build(&self) -> Box { + Box::new(CassandraSinkCluster { contact_points: self.contact_points.clone(), message_rewriter: self.message_rewriter.clone(), control_connection: None, @@ -170,11 +171,11 @@ impl TransformBuilder for CassandraSinkClusterBuilder { keyspaces_rx: self.keyspaces_rx.clone(), rng: SmallRng::from_rng(rand::thread_rng()).unwrap(), task_handshake_tx: self.task_handshake_tx.clone(), - })) + }) } fn get_name(&self) -> &'static str { - "CassandraSinkCluster" + NAME } fn is_terminating(&self) -> bool { @@ -718,6 +719,10 @@ fn is_use_statement_successful(response: Option>) -> bool { #[async_trait] impl Transform for CassandraSinkCluster { + fn get_name(&self) -> &'static str { + NAME + } + async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result { self.send_message(requests_wrapper.requests).await } diff --git a/shotover/src/transforms/cassandra/sink_single.rs b/shotover/src/transforms/cassandra/sink_single.rs index 796c99a33..488595f13 100644 --- a/shotover/src/transforms/cassandra/sink_single.rs +++ b/shotover/src/transforms/cassandra/sink_single.rs @@ -4,7 +4,7 @@ use crate::frame::cassandra::CassandraMetadata; use crate::message::{Messages, Metadata}; use crate::tls::{TlsConnector, TlsConnectorConfig}; use crate::transforms::cassandra::connection::Response; -use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::{anyhow, Result}; use async_trait::async_trait; use cassandra_protocol::frame::Version; @@ -25,6 +25,7 @@ pub struct CassandraSinkSingleConfig { pub read_timeout: Option, } +const NAME: &str = "CassandraSinkSingle"; #[typetag::serde(name = "CassandraSinkSingle")] #[async_trait(?Send)] impl TransformConfig for CassandraSinkSingleConfig { @@ -77,8 +78,8 @@ impl CassandraSinkSingleBuilder { } impl TransformBuilder for CassandraSinkSingleBuilder { - fn build(&self) -> Transforms { - Transforms::CassandraSinkSingle(CassandraSinkSingle { + fn build(&self) -> Box { + Box::new(CassandraSinkSingle { outbound: None, version: self.version, address: self.address.clone(), @@ -92,7 +93,7 @@ impl TransformBuilder for CassandraSinkSingleBuilder { } fn get_name(&self) -> &'static str { - "CassandraSinkSingle" + NAME } fn is_terminating(&self) -> bool { @@ -168,6 +169,10 @@ impl CassandraSinkSingle { #[async_trait] impl Transform for CassandraSinkSingle { + fn get_name(&self) -> &'static str { + NAME + } + async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result { self.send_message(requests_wrapper.requests).await } diff --git a/shotover/src/transforms/chain.rs b/shotover/src/transforms/chain.rs index a53571d9a..c61110e04 100644 --- a/shotover/src/transforms/chain.rs +++ b/shotover/src/transforms/chain.rs @@ -1,5 +1,5 @@ use crate::message::Messages; -use crate::transforms::{TransformBuilder, Transforms, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, Wrapper}; use anyhow::{anyhow, Result}; use derivative::Derivative; use futures::TryFutureExt; @@ -51,18 +51,13 @@ impl BufferedChainMessages { /// Transform chains can be of arbitary complexity and a transform can even have its own set of child transform chains. /// Transform chains are defined by the user in Shotover's configuration file and are linked to sources. /// -/// The transform chain is a vector of mutable references to the enum [Transforms] (which is an enum dispatch wrapper around the various transform types). -#[derive(Derivative)] -#[derivative(Debug)] +/// The transform chain is a vector of mutable references to the enum [Transform] (which is an enum dispatch wrapper around the various transform types). pub struct TransformChain { pub name: &'static str, pub chain: InnerChain, - #[derivative(Debug = "ignore")] chain_total: Counter, - #[derivative(Debug = "ignore")] chain_failures: Counter, - #[derivative(Debug = "ignore")] chain_batch_size: Histogram, } @@ -194,27 +189,19 @@ impl TransformChain { } } -#[derive(Derivative)] -#[derivative(Debug)] pub struct TransformAndMetrics { - pub transform: Transforms, - #[derivative(Debug = "ignore")] + pub transform: Box, pub transform_total: Counter, - #[derivative(Debug = "ignore")] pub transform_failures: Counter, - #[derivative(Debug = "ignore")] pub transform_latency: Histogram, - #[derivative(Debug = "ignore")] pub transform_pushed_total: Counter, - #[derivative(Debug = "ignore")] pub transform_pushed_failures: Counter, - #[derivative(Debug = "ignore")] pub transform_pushed_latency: Histogram, } impl TransformAndMetrics { #[cfg(test)] - pub fn new(transform: Transforms) -> Self { + pub fn new(transform: Box) -> Self { TransformAndMetrics { transform, transform_total: Counter::noop(), diff --git a/shotover/src/transforms/coalesce.rs b/shotover/src/transforms/coalesce.rs index e2fc7a118..49e4106d6 100644 --- a/shotover/src/transforms/coalesce.rs +++ b/shotover/src/transforms/coalesce.rs @@ -1,5 +1,5 @@ use crate::message::Messages; -use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::Result; use async_trait::async_trait; use serde::{Deserialize, Serialize}; @@ -20,6 +20,7 @@ pub struct CoalesceConfig { pub flush_when_millis_since_last_flush: Option, } +const NAME: &str = "Coalesce"; #[typetag::serde(name = "Coalesce")] #[async_trait(?Send)] impl TransformConfig for CoalesceConfig { @@ -34,12 +35,12 @@ impl TransformConfig for CoalesceConfig { } impl TransformBuilder for Coalesce { - fn build(&self) -> Transforms { - Transforms::Coalesce(self.clone()) + fn build(&self) -> Box { + Box::new(self.clone()) } fn get_name(&self) -> &'static str { - "Coalesce" + NAME } fn validate(&self) -> Vec { @@ -64,6 +65,10 @@ impl TransformBuilder for Coalesce { #[async_trait] impl Transform for Coalesce { + fn get_name(&self) -> &'static str { + NAME + } + async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result { self.buffer.append(&mut requests_wrapper.requests); @@ -96,7 +101,7 @@ mod test { use crate::transforms::chain::TransformAndMetrics; use crate::transforms::coalesce::Coalesce; use crate::transforms::loopback::Loopback; - use crate::transforms::{Transform, Transforms, Wrapper}; + use crate::transforms::{Transform, Wrapper}; use std::time::{Duration, Instant}; #[tokio::test(flavor = "multi_thread")] @@ -108,9 +113,7 @@ mod test { last_write: Instant::now(), }; - let mut chain = vec![TransformAndMetrics::new(Transforms::Loopback( - Loopback::default(), - ))]; + let mut chain = vec![TransformAndMetrics::new(Box::new(Loopback::default()))]; let messages: Vec<_> = (0..25) .map(|_| Message::from_frame(Frame::Redis(RedisFrame::Null))) @@ -149,9 +152,7 @@ mod test { last_write: Instant::now(), }; - let mut chain = vec![TransformAndMetrics::new(Transforms::Loopback( - Loopback::default(), - ))]; + let mut chain = vec![TransformAndMetrics::new(Box::new(Loopback::default()))]; let messages: Vec<_> = (0..25) .map(|_| Message::from_frame(Frame::Redis(RedisFrame::Null))) @@ -190,9 +191,7 @@ mod test { last_write: Instant::now(), }; - let mut chain = vec![TransformAndMetrics::new(Transforms::Loopback( - Loopback::default(), - ))]; + let mut chain = vec![TransformAndMetrics::new(Box::new(Loopback::default()))]; let messages: Vec<_> = (0..25) .map(|_| Message::from_frame(Frame::Redis(RedisFrame::Null))) diff --git a/shotover/src/transforms/debug/force_parse.rs b/shotover/src/transforms/debug/force_parse.rs index 26d57e009..fd739b397 100644 --- a/shotover/src/transforms/debug/force_parse.rs +++ b/shotover/src/transforms/debug/force_parse.rs @@ -7,7 +7,7 @@ use crate::message::Messages; /// It could also be used to ensure that messages round trip correctly when parsed. #[cfg(feature = "alpha-transforms")] use crate::transforms::TransformConfig; -use crate::transforms::{Transform, TransformBuilder, Transforms, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, Wrapper}; use anyhow::Result; use async_trait::async_trait; use serde::{Deserialize, Serialize}; @@ -44,6 +44,7 @@ pub struct DebugForceEncodeConfig { pub encode_responses: bool, } +const NAME: &str = "DebugForceEncode"; #[cfg(feature = "alpha-transforms")] #[typetag::serde(name = "DebugForceEncode")] #[async_trait(?Send)] @@ -67,17 +68,21 @@ pub struct DebugForceParse { } impl TransformBuilder for DebugForceParse { - fn build(&self) -> Transforms { - Transforms::DebugForceParse(self.clone()) + fn build(&self) -> Box { + Box::new(self.clone()) } fn get_name(&self) -> &'static str { - "DebugForceParse" + NAME } } #[async_trait] impl Transform for DebugForceParse { + fn get_name(&self) -> &'static str { + NAME + } + async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result { for message in &mut requests_wrapper.requests { if self.parse_requests { diff --git a/shotover/src/transforms/debug/log_to_file.rs b/shotover/src/transforms/debug/log_to_file.rs index 75a0f0d85..98a0a4803 100644 --- a/shotover/src/transforms/debug/log_to_file.rs +++ b/shotover/src/transforms/debug/log_to_file.rs @@ -1,5 +1,5 @@ use crate::message::{Encodable, Message}; -use crate::transforms::{Transform, TransformBuilder, Transforms, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, Wrapper}; use anyhow::{Context, Result}; use async_trait::async_trait; use serde::{Deserialize, Serialize}; @@ -12,6 +12,7 @@ use tracing::{error, info}; #[serde(deny_unknown_fields)] pub struct DebugLogToFileConfig; +const NAME: &str = "DebugLogToFile"; #[cfg(feature = "alpha-transforms")] #[typetag::serde(name = "DebugLogToFile")] #[async_trait(?Send)] @@ -31,7 +32,7 @@ pub struct DebugLogToFileBuilder { } impl TransformBuilder for DebugLogToFileBuilder { - fn build(&self) -> Transforms { + fn build(&self) -> Box { self.connection_counter.fetch_add(1, Ordering::Relaxed); let connection_current = self.connection_counter.load(Ordering::Relaxed); @@ -50,7 +51,7 @@ impl TransformBuilder for DebugLogToFileBuilder { .context("failed to create directory for logging responses") .unwrap(); - Transforms::DebugLogToFile(DebugLogToFile { + Box::new(DebugLogToFile { request_counter: 0, response_counter: 0, requests, @@ -59,7 +60,7 @@ impl TransformBuilder for DebugLogToFileBuilder { } fn get_name(&self) -> &'static str { - "DebugLogToFile" + NAME } } @@ -72,6 +73,10 @@ pub struct DebugLogToFile { #[async_trait] impl Transform for DebugLogToFile { + fn get_name(&self) -> &'static str { + NAME + } + async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result> { for message in &requests_wrapper.requests { self.request_counter += 1; diff --git a/shotover/src/transforms/debug/printer.rs b/shotover/src/transforms/debug/printer.rs index 50e58ae34..2d9c89bf4 100644 --- a/shotover/src/transforms/debug/printer.rs +++ b/shotover/src/transforms/debug/printer.rs @@ -1,5 +1,5 @@ use crate::message::Messages; -use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::Result; use async_trait::async_trait; use serde::{Deserialize, Serialize}; @@ -9,6 +9,7 @@ use tracing::info; #[serde(deny_unknown_fields)] pub struct DebugPrinterConfig; +const NAME: &str = "DebugPrinter"; #[typetag::serde(name = "DebugPrinter")] #[async_trait(?Send)] impl TransformConfig for DebugPrinterConfig { @@ -35,17 +36,21 @@ impl DebugPrinter { } impl TransformBuilder for DebugPrinter { - fn build(&self) -> Transforms { - Transforms::DebugPrinter(self.clone()) + fn build(&self) -> Box { + Box::new(self.clone()) } fn get_name(&self) -> &'static str { - "DebugPrinter" + NAME } } #[async_trait] impl Transform for DebugPrinter { + fn get_name(&self) -> &'static str { + NAME + } + async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result { for request in &mut requests_wrapper.requests { info!("Request: {}", request.to_high_level_string()); diff --git a/shotover/src/transforms/debug/random_delay.rs b/shotover/src/transforms/debug/random_delay.rs index 1594d0f19..d885ea82d 100644 --- a/shotover/src/transforms/debug/random_delay.rs +++ b/shotover/src/transforms/debug/random_delay.rs @@ -6,6 +6,8 @@ use rand_distr::Distribution; use rand_distr::Normal; use tokio::time::Duration; +const NAME: &str = "DebugRandomDelay"; + #[derive(Debug, Clone)] pub struct DebugRandomDelay { pub delay: u64, @@ -14,6 +16,10 @@ pub struct DebugRandomDelay { #[async_trait] impl Transform for DebugRandomDelay { + fn get_name(&self) -> &'static str { + NAME + } + async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result { let delay = if let Some(dist) = self.distribution { Duration::from_millis(dist.sample(&mut rand::thread_rng()) as u64 + self.delay) diff --git a/shotover/src/transforms/debug/returner.rs b/shotover/src/transforms/debug/returner.rs index 70f3f2351..8b6b3b5c2 100644 --- a/shotover/src/transforms/debug/returner.rs +++ b/shotover/src/transforms/debug/returner.rs @@ -1,6 +1,6 @@ use crate::frame::{Frame, RedisFrame}; use crate::message::{Message, Messages}; -use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::{anyhow, Result}; use async_trait::async_trait; use serde::{Deserialize, Serialize}; @@ -12,6 +12,7 @@ pub struct DebugReturnerConfig { response: Response, } +const NAME: &str = "DebugReturner"; #[typetag::serde(name = "DebugReturner")] #[async_trait(?Send)] impl TransformConfig for DebugReturnerConfig { @@ -41,12 +42,12 @@ impl DebugReturner { } impl TransformBuilder for DebugReturner { - fn build(&self) -> Transforms { - Transforms::DebugReturner(self.clone()) + fn build(&self) -> Box { + Box::new(self.clone()) } fn get_name(&self) -> &'static str { - "DebugReturner" + NAME } fn is_terminating(&self) -> bool { @@ -56,6 +57,10 @@ impl TransformBuilder for DebugReturner { #[async_trait] impl Transform for DebugReturner { + fn get_name(&self) -> &'static str { + NAME + } + async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result { match &self.response { Response::Message(message) => Ok(message.clone()), diff --git a/shotover/src/transforms/distributed/tuneable_consistency_scatter.rs b/shotover/src/transforms/distributed/tuneable_consistency_scatter.rs index 72d59d7aa..66111d44a 100644 --- a/shotover/src/transforms/distributed/tuneable_consistency_scatter.rs +++ b/shotover/src/transforms/distributed/tuneable_consistency_scatter.rs @@ -2,7 +2,7 @@ use crate::config::chain::TransformChainConfig; use crate::frame::{Frame, RedisFrame}; use crate::message::{Message, Messages, QueryType}; use crate::transforms::chain::{BufferedChain, TransformChainBuilder}; -use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::Result; use async_trait::async_trait; use futures::stream::FuturesUnordered; @@ -19,6 +19,7 @@ pub struct TuneableConsistencyScatterConfig { pub read_consistency: i32, } +const NAME: &str = "TuneableConsistencyScatter"; #[typetag::serde(name = "TuneableConsistencyScatter")] #[async_trait(?Send)] impl TransformConfig for TuneableConsistencyScatterConfig { @@ -46,8 +47,8 @@ pub struct TuneableConsistencyScatterBuilder { } impl TransformBuilder for TuneableConsistencyScatterBuilder { - fn build(&self) -> Transforms { - Transforms::TuneableConsistencyScatter(TuneableConsistentencyScatter { + fn build(&self) -> Box { + Box::new(TuneableConsistentencyScatter { route_map: self .route_map .iter() @@ -59,7 +60,7 @@ impl TransformBuilder for TuneableConsistencyScatterBuilder { } fn get_name(&self) -> &'static str { - "TuneableConsistencyScatter" + NAME } fn is_terminating(&self) -> bool { @@ -182,6 +183,10 @@ enum Resolver { #[async_trait] impl Transform for TuneableConsistentencyScatter { + fn get_name(&self) -> &'static str { + NAME + } + async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result { let consistency: Vec<_> = requests_wrapper .requests @@ -287,7 +292,7 @@ mod scatter_transform_tests { TuneableConsistencyScatterBuilder, TuneableConsistentencyScatter, }; use crate::transforms::null::NullSink; - use crate::transforms::{TransformBuilder, Transforms, Wrapper}; + use crate::transforms::{Transform, TransformBuilder, Wrapper}; use bytes::Bytes; use std::collections::HashMap; @@ -337,12 +342,11 @@ mod scatter_transform_tests { TransformChainBuilder::new(vec![err_repeat.clone()], "three"), ); - let mut tuneable_success_consistency = - Transforms::TuneableConsistencyScatter(TuneableConsistentencyScatter { - route_map: build_chains(two_of_three).await, - write_consistency: 2, - read_consistency: 2, - }); + let mut tuneable_success_consistency = Box::new(TuneableConsistentencyScatter { + route_map: build_chains(two_of_three).await, + write_consistency: 2, + read_consistency: 2, + }); let test = tuneable_success_consistency .transform(wrapper.clone()) @@ -365,12 +369,11 @@ mod scatter_transform_tests { TransformChainBuilder::new(vec![err_repeat.clone()], "three"), ); - let mut tuneable_fail_consistency = - Transforms::TuneableConsistencyScatter(TuneableConsistentencyScatter { - route_map: build_chains(one_of_three).await, - write_consistency: 2, - read_consistency: 2, - }); + let mut tuneable_fail_consistency = Box::new(TuneableConsistentencyScatter { + route_map: build_chains(one_of_three).await, + write_consistency: 2, + read_consistency: 2, + }); let response_fail = tuneable_fail_consistency .transform(wrapper.clone()) diff --git a/shotover/src/transforms/filter.rs b/shotover/src/transforms/filter.rs index 770f280a8..5473f7ab6 100644 --- a/shotover/src/transforms/filter.rs +++ b/shotover/src/transforms/filter.rs @@ -5,8 +5,6 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; use std::sync::atomic::{AtomicBool, Ordering}; -use super::Transforms; - static SHOWN_ERROR: AtomicBool = AtomicBool::new(false); #[derive(Serialize, Deserialize, Debug, Clone)] @@ -28,6 +26,7 @@ pub struct QueryTypeFilterConfig { pub filter: Filter, } +const NAME: &str = "QueryTypeFilter"; #[typetag::serde(name = "QueryTypeFilter")] #[async_trait(?Send)] impl TransformConfig for QueryTypeFilterConfig { @@ -39,17 +38,21 @@ impl TransformConfig for QueryTypeFilterConfig { } impl TransformBuilder for QueryTypeFilter { - fn build(&self) -> Transforms { - Transforms::QueryTypeFilter(self.clone()) + fn build(&self) -> Box { + Box::new(self.clone()) } fn get_name(&self) -> &'static str { - "QueryTypeFilter" + NAME } } #[async_trait] impl Transform for QueryTypeFilter { + fn get_name(&self) -> &'static str { + NAME + } + async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result { let removed_indexes: Result> = requests_wrapper .requests @@ -117,7 +120,7 @@ mod test { use crate::transforms::chain::TransformAndMetrics; use crate::transforms::filter::QueryTypeFilter; use crate::transforms::loopback::Loopback; - use crate::transforms::{Transform, Transforms, Wrapper}; + use crate::transforms::{Transform, Wrapper}; #[tokio::test(flavor = "multi_thread")] async fn test_filter_denylist() { @@ -125,9 +128,7 @@ mod test { filter: Filter::DenyList(vec![QueryType::Read]), }; - let mut chain = vec![TransformAndMetrics::new(Transforms::Loopback( - Loopback::default(), - ))]; + let mut chain = vec![TransformAndMetrics::new(Box::new(Loopback::default()))]; let messages: Vec<_> = (0..26) .map(|i| { @@ -181,9 +182,7 @@ mod test { filter: Filter::AllowList(vec![QueryType::Write]), }; - let mut chain = vec![TransformAndMetrics::new(Transforms::Loopback( - Loopback::default(), - ))]; + let mut chain = vec![TransformAndMetrics::new(Box::new(Loopback::default()))]; let messages: Vec<_> = (0..26) .map(|i| { diff --git a/shotover/src/transforms/kafka/sink_cluster.rs b/shotover/src/transforms/kafka/sink_cluster.rs index dcca0e064..799adf069 100644 --- a/shotover/src/transforms/kafka/sink_cluster.rs +++ b/shotover/src/transforms/kafka/sink_cluster.rs @@ -6,7 +6,8 @@ use crate::message::{Message, Messages}; use crate::tcp; use crate::transforms::util::cluster_connection_pool::{spawn_read_write_tasks, Connection}; use crate::transforms::util::{Request, Response}; -use crate::transforms::{Transform, TransformBuilder, Transforms, Wrapper}; +use crate::transforms::TransformConfig; +use crate::transforms::{Transform, TransformBuilder, Wrapper}; use anyhow::{anyhow, Result}; use async_trait::async_trait; use dashmap::DashMap; @@ -40,8 +41,7 @@ pub struct KafkaSinkClusterConfig { pub read_timeout: Option, } -use crate::transforms::TransformConfig; - +const NAME: &str = "KafkaSinkCluster"; #[typetag::serde(name = "KafkaSinkCluster")] #[async_trait(?Send)] impl TransformConfig for KafkaSinkClusterConfig { @@ -103,8 +103,8 @@ impl KafkaSinkClusterBuilder { } impl TransformBuilder for KafkaSinkClusterBuilder { - fn build(&self) -> Transforms { - Transforms::KafkaSinkCluster(KafkaSinkCluster { + fn build(&self) -> Box { + Box::new(KafkaSinkCluster { first_contact_points: self.first_contact_points.clone(), shotover_nodes: self.shotover_nodes.clone(), pushed_messages_tx: None, @@ -120,7 +120,7 @@ impl TransformBuilder for KafkaSinkClusterBuilder { } fn get_name(&self) -> &'static str { - "KafkaSinkCluster" + NAME } fn is_terminating(&self) -> bool { @@ -166,6 +166,10 @@ pub struct KafkaSinkCluster { #[async_trait] impl Transform for KafkaSinkCluster { + fn get_name(&self) -> &'static str { + NAME + } + async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result { if requests_wrapper.requests.is_empty() { return Ok(vec![]); diff --git a/shotover/src/transforms/kafka/sink_single.rs b/shotover/src/transforms/kafka/sink_single.rs index 7bf57dfae..6a8740cef 100644 --- a/shotover/src/transforms/kafka/sink_single.rs +++ b/shotover/src/transforms/kafka/sink_single.rs @@ -6,7 +6,7 @@ use crate::tcp; use crate::transforms::kafka::common::produce_channel; use crate::transforms::util::cluster_connection_pool::{spawn_read_write_tasks, Connection}; use crate::transforms::util::{Request, Response}; -use crate::transforms::{Transform, TransformBuilder, Transforms, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, Wrapper}; use anyhow::{anyhow, Result}; use async_trait::async_trait; use serde::{Deserialize, Serialize}; @@ -26,6 +26,7 @@ pub struct KafkaSinkSingleConfig { use crate::transforms::TransformConfig; +const NAME: &str = "KafkaSinkSingle"; #[typetag::serde(name = "KafkaSinkSingle")] #[async_trait(?Send)] impl TransformConfig for KafkaSinkSingleConfig { @@ -64,8 +65,8 @@ impl KafkaSinkSingleBuilder { } impl TransformBuilder for KafkaSinkSingleBuilder { - fn build(&self) -> Transforms { - Transforms::KafkaSinkSingle(KafkaSinkSingle { + fn build(&self) -> Box { + Box::new(KafkaSinkSingle { outbound: None, address_port: self.address_port, pushed_messages_tx: None, @@ -75,7 +76,7 @@ impl TransformBuilder for KafkaSinkSingleBuilder { } fn get_name(&self) -> &'static str { - "KafkaSinkSingle" + NAME } fn is_terminating(&self) -> bool { @@ -93,6 +94,10 @@ pub struct KafkaSinkSingle { #[async_trait] impl Transform for KafkaSinkSingle { + fn get_name(&self) -> &'static str { + NAME + } + async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result { if self.outbound.is_none() { let codec = KafkaCodecBuilder::new(Direction::Sink, "KafkaSinkSingle".to_owned()); diff --git a/shotover/src/transforms/load_balance.rs b/shotover/src/transforms/load_balance.rs index 6e6209694..03edc6d15 100644 --- a/shotover/src/transforms/load_balance.rs +++ b/shotover/src/transforms/load_balance.rs @@ -1,4 +1,3 @@ -use super::Transforms; use crate::config::chain::TransformChainConfig; use crate::message::Messages; use crate::transforms::chain::{BufferedChain, TransformChainBuilder}; @@ -17,6 +16,7 @@ pub struct ConnectionBalanceAndPoolConfig { pub chain: TransformChainConfig, } +const NAME: &str = "ConnectionBalanceAndPool"; #[typetag::serde(name = "ConnectionBalanceAndPool")] #[async_trait(?Send)] impl TransformConfig for ConnectionBalanceAndPoolConfig { @@ -39,8 +39,8 @@ pub struct ConnectionBalanceAndPoolBuilder { } impl TransformBuilder for ConnectionBalanceAndPoolBuilder { - fn build(&self) -> Transforms { - Transforms::PoolConnections(ConnectionBalanceAndPool { + fn build(&self) -> Box { + Box::new(ConnectionBalanceAndPool { active_connection: None, max_connections: self.max_connections, all_connections: self.all_connections.clone(), @@ -53,7 +53,7 @@ impl TransformBuilder for ConnectionBalanceAndPoolBuilder { } fn get_name(&self) -> &'static str { - "ConnectionBalanceAndPool" + NAME } } @@ -61,14 +61,18 @@ impl TransformBuilder for ConnectionBalanceAndPoolBuilder { /// Once this happens cloned instances will reuse connections from earlier clones. #[derive(Debug)] pub struct ConnectionBalanceAndPool { - pub active_connection: Option, - pub max_connections: usize, - pub all_connections: Arc>>, - pub chain_to_clone: Arc, + active_connection: Option, + max_connections: usize, + all_connections: Arc>>, + chain_to_clone: Arc, } #[async_trait] impl Transform for ConnectionBalanceAndPool { + fn get_name(&self) -> &'static str { + NAME + } + async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result { if self.active_connection.is_none() { let mut all_connections = self.all_connections.lock().await; @@ -91,48 +95,3 @@ impl Transform for ConnectionBalanceAndPool { .await } } - -#[cfg(test)] -mod test { - use crate::message::Messages; - use crate::transforms::chain::TransformChainBuilder; - use crate::transforms::debug::returner::{DebugReturner, Response}; - use crate::transforms::load_balance::ConnectionBalanceAndPoolBuilder; - use crate::transforms::{Transforms, Wrapper}; - use std::sync::Arc; - - #[tokio::test(flavor = "multi_thread")] - pub async fn test_balance() { - let transform = Box::new(ConnectionBalanceAndPoolBuilder { - max_connections: 3, - all_connections: Arc::new(Default::default()), - chain_to_clone: Arc::new(TransformChainBuilder::new( - vec![Box::new(DebugReturner::new(Response::Message( - Messages::new(), - )))], - "child_test", - )), - }); - - let chain = TransformChainBuilder::new(vec![transform], "test"); - - for _ in 0..90 { - chain - .build() - .process_request(Wrapper::new_test(Messages::new())) - .await - .unwrap(); - } - - match chain.chain[0].builder.build() { - Transforms::PoolConnections(p) => { - let all_connections = p.all_connections.lock().await; - assert_eq!(all_connections.len(), 3); - for bc in all_connections.iter() { - assert_eq!(bc.count.load(std::sync::atomic::Ordering::Relaxed), 30); - } - } - _ => panic!("whoops"), - } - } -} diff --git a/shotover/src/transforms/loopback.rs b/shotover/src/transforms/loopback.rs index a268d27f6..0e39ba1ec 100644 --- a/shotover/src/transforms/loopback.rs +++ b/shotover/src/transforms/loopback.rs @@ -1,18 +1,20 @@ use crate::message::Messages; -use crate::transforms::{Transform, TransformBuilder, Transforms, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, Wrapper}; use anyhow::Result; use async_trait::async_trait; +const NAME: &str = "Loopback"; + #[derive(Debug, Clone, Default)] pub struct Loopback {} impl TransformBuilder for Loopback { - fn build(&self) -> Transforms { - Transforms::Loopback(self.clone()) + fn build(&self) -> Box { + Box::new(self.clone()) } fn get_name(&self) -> &'static str { - "Loopback" + NAME } fn is_terminating(&self) -> bool { @@ -22,6 +24,10 @@ impl TransformBuilder for Loopback { #[async_trait] impl Transform for Loopback { + fn get_name(&self) -> &'static str { + NAME + } + async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result { Ok(requests_wrapper.requests) } diff --git a/shotover/src/transforms/mod.rs b/shotover/src/transforms/mod.rs index 71dcfb918..25a3d4f07 100644 --- a/shotover/src/transforms/mod.rs +++ b/shotover/src/transforms/mod.rs @@ -1,34 +1,6 @@ //! Various types required for defining a transform use crate::message::Messages; -use crate::transforms::cassandra::peers_rewrite::CassandraPeersRewrite; -use crate::transforms::cassandra::sink_cluster::CassandraSinkCluster; -use crate::transforms::cassandra::sink_single::CassandraSinkSingle; -use crate::transforms::coalesce::Coalesce; -use crate::transforms::debug::force_parse::DebugForceParse; -use crate::transforms::debug::log_to_file::DebugLogToFile; -use crate::transforms::debug::printer::DebugPrinter; -use crate::transforms::debug::random_delay::DebugRandomDelay; -use crate::transforms::debug::returner::DebugReturner; -use crate::transforms::distributed::tuneable_consistency_scatter::TuneableConsistentencyScatter; -use crate::transforms::filter::QueryTypeFilter; -use crate::transforms::kafka::sink_cluster::KafkaSinkCluster; -use crate::transforms::kafka::sink_single::KafkaSinkSingle; -use crate::transforms::load_balance::ConnectionBalanceAndPool; -use crate::transforms::loopback::Loopback; -use crate::transforms::null::NullSink; -#[cfg(feature = "alpha-transforms")] -use crate::transforms::opensearch::OpenSearchSinkSingle; -use crate::transforms::parallel_map::ParallelMap; -use crate::transforms::protect::Protect; -use crate::transforms::query_counter::QueryCounter; -use crate::transforms::redis::cache::SimpleRedisCache; -use crate::transforms::redis::cluster_ports_rewrite::RedisClusterPortsRewrite; -use crate::transforms::redis::sink_cluster::RedisSinkCluster; -use crate::transforms::redis::sink_single::RedisSinkSingle; -use crate::transforms::redis::timestamp_tagging::RedisTimestampTagger; -use crate::transforms::tee::Tee; -use crate::transforms::throttling::RequestThrottling; use anyhow::{anyhow, Result}; use async_trait::async_trait; use core::fmt; @@ -38,7 +10,6 @@ use std::iter::Rev; use std::net::SocketAddr; use std::pin::Pin; use std::slice::IterMut; -use strum_macros::IntoStaticStr; use tokio::sync::mpsc; use tokio::time::Instant; @@ -67,7 +38,7 @@ pub mod throttling; pub mod util; pub trait TransformBuilder: Send + Sync { - fn build(&self) -> Transforms; + fn build(&self) -> Box; fn get_name(&self) -> &'static str; @@ -86,161 +57,6 @@ impl Debug for dyn TransformBuilder { } } -//TODO Generate the trait implementation for this passthrough enum via a macro -/// The [`crate::transforms::Transforms`] enum is responsible for [`crate::transforms::Transform`] registration and enum dispatch -/// in the transform chain. This is largely a performance optimisation by using enum dispatch rather -/// than using dynamic trait objects. -#[derive(IntoStaticStr)] -pub enum Transforms { - KafkaSinkSingle(KafkaSinkSingle), - KafkaSinkCluster(KafkaSinkCluster), - CassandraSinkSingle(CassandraSinkSingle), - CassandraSinkCluster(Box), - RedisSinkSingle(RedisSinkSingle), - CassandraPeersRewrite(CassandraPeersRewrite), - RedisCache(SimpleRedisCache), - Tee(Tee), - NullSink(NullSink), - Loopback(Loopback), - Protect(Box), - TuneableConsistencyScatter(TuneableConsistentencyScatter), - RedisTimestampTagger(RedisTimestampTagger), - RedisSinkCluster(RedisSinkCluster), - RedisClusterPortsRewrite(RedisClusterPortsRewrite), - DebugReturner(DebugReturner), - DebugRandomDelay(DebugRandomDelay), - DebugPrinter(DebugPrinter), - DebugLogToFile(DebugLogToFile), - DebugForceParse(DebugForceParse), - ParallelMap(ParallelMap), - PoolConnections(ConnectionBalanceAndPool), - Coalesce(Coalesce), - QueryTypeFilter(QueryTypeFilter), - QueryCounter(QueryCounter), - RequestThrottling(RequestThrottling), - Custom(Box), - #[cfg(feature = "alpha-transforms")] - OpenSearchSinkSingle(OpenSearchSinkSingle), -} - -impl Debug for Transforms { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "Transform: {}", self.get_name()) - } -} - -impl Transforms { - async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result { - match self { - Transforms::KafkaSinkSingle(c) => c.transform(requests_wrapper).await, - Transforms::KafkaSinkCluster(c) => c.transform(requests_wrapper).await, - Transforms::CassandraSinkSingle(c) => c.transform(requests_wrapper).await, - Transforms::CassandraSinkCluster(c) => c.transform(requests_wrapper).await, - Transforms::CassandraPeersRewrite(c) => c.transform(requests_wrapper).await, - Transforms::RedisCache(r) => r.transform(requests_wrapper).await, - Transforms::Tee(m) => m.transform(requests_wrapper).await, - Transforms::DebugPrinter(p) => p.transform(requests_wrapper).await, - Transforms::DebugLogToFile(p) => p.transform(requests_wrapper).await, - Transforms::DebugForceParse(p) => p.transform(requests_wrapper).await, - Transforms::NullSink(n) => n.transform(requests_wrapper).await, - Transforms::Loopback(n) => n.transform(requests_wrapper).await, - Transforms::Protect(p) => p.transform(requests_wrapper).await, - Transforms::DebugReturner(p) => p.transform(requests_wrapper).await, - Transforms::DebugRandomDelay(p) => p.transform(requests_wrapper).await, - Transforms::TuneableConsistencyScatter(tc) => tc.transform(requests_wrapper).await, - Transforms::RedisSinkSingle(r) => r.transform(requests_wrapper).await, - Transforms::RedisTimestampTagger(r) => r.transform(requests_wrapper).await, - Transforms::RedisClusterPortsRewrite(r) => r.transform(requests_wrapper).await, - Transforms::RedisSinkCluster(r) => r.transform(requests_wrapper).await, - Transforms::ParallelMap(s) => s.transform(requests_wrapper).await, - Transforms::PoolConnections(s) => s.transform(requests_wrapper).await, - Transforms::Coalesce(s) => s.transform(requests_wrapper).await, - Transforms::QueryTypeFilter(s) => s.transform(requests_wrapper).await, - Transforms::QueryCounter(s) => s.transform(requests_wrapper).await, - Transforms::RequestThrottling(s) => s.transform(requests_wrapper).await, - Transforms::Custom(s) => s.transform(requests_wrapper).await, - #[cfg(feature = "alpha-transforms")] - Transforms::OpenSearchSinkSingle(s) => s.transform(requests_wrapper).await, - } - } - - async fn transform_pushed<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result { - match self { - Transforms::KafkaSinkSingle(c) => c.transform_pushed(requests_wrapper).await, - Transforms::KafkaSinkCluster(c) => c.transform_pushed(requests_wrapper).await, - Transforms::CassandraSinkSingle(c) => c.transform_pushed(requests_wrapper).await, - Transforms::CassandraSinkCluster(c) => c.transform_pushed(requests_wrapper).await, - Transforms::CassandraPeersRewrite(c) => c.transform_pushed(requests_wrapper).await, - Transforms::RedisCache(r) => r.transform_pushed(requests_wrapper).await, - Transforms::Tee(m) => m.transform_pushed(requests_wrapper).await, - Transforms::DebugPrinter(p) => p.transform_pushed(requests_wrapper).await, - Transforms::DebugLogToFile(p) => p.transform_pushed(requests_wrapper).await, - Transforms::DebugForceParse(p) => p.transform_pushed(requests_wrapper).await, - Transforms::NullSink(n) => n.transform_pushed(requests_wrapper).await, - Transforms::Loopback(n) => n.transform_pushed(requests_wrapper).await, - Transforms::Protect(p) => p.transform_pushed(requests_wrapper).await, - Transforms::DebugReturner(p) => p.transform_pushed(requests_wrapper).await, - Transforms::DebugRandomDelay(p) => p.transform_pushed(requests_wrapper).await, - Transforms::TuneableConsistencyScatter(tc) => { - tc.transform_pushed(requests_wrapper).await - } - Transforms::RedisSinkSingle(r) => r.transform_pushed(requests_wrapper).await, - Transforms::RedisTimestampTagger(r) => r.transform_pushed(requests_wrapper).await, - Transforms::RedisClusterPortsRewrite(r) => r.transform_pushed(requests_wrapper).await, - Transforms::RedisSinkCluster(r) => r.transform_pushed(requests_wrapper).await, - Transforms::ParallelMap(s) => s.transform_pushed(requests_wrapper).await, - Transforms::PoolConnections(s) => s.transform_pushed(requests_wrapper).await, - Transforms::Coalesce(s) => s.transform_pushed(requests_wrapper).await, - Transforms::QueryTypeFilter(s) => s.transform_pushed(requests_wrapper).await, - Transforms::QueryCounter(s) => s.transform_pushed(requests_wrapper).await, - Transforms::RequestThrottling(s) => s.transform_pushed(requests_wrapper).await, - Transforms::Custom(s) => s.transform_pushed(requests_wrapper).await, - #[cfg(feature = "alpha-transforms")] - Transforms::OpenSearchSinkSingle(s) => s.transform_pushed(requests_wrapper).await, - } - } - - fn get_name(&self) -> &'static str { - self.into() - } - - fn set_pushed_messages_tx(&mut self, pushed_messages_tx: mpsc::UnboundedSender) { - match self { - Transforms::KafkaSinkSingle(c) => c.set_pushed_messages_tx(pushed_messages_tx), - Transforms::KafkaSinkCluster(c) => c.set_pushed_messages_tx(pushed_messages_tx), - Transforms::CassandraSinkSingle(c) => c.set_pushed_messages_tx(pushed_messages_tx), - Transforms::CassandraSinkCluster(c) => c.set_pushed_messages_tx(pushed_messages_tx), - Transforms::CassandraPeersRewrite(c) => c.set_pushed_messages_tx(pushed_messages_tx), - Transforms::RedisCache(r) => r.set_pushed_messages_tx(pushed_messages_tx), - Transforms::Tee(t) => t.set_pushed_messages_tx(pushed_messages_tx), - Transforms::RedisSinkSingle(r) => r.set_pushed_messages_tx(pushed_messages_tx), - Transforms::TuneableConsistencyScatter(c) => { - c.set_pushed_messages_tx(pushed_messages_tx) - } - Transforms::RedisTimestampTagger(r) => r.set_pushed_messages_tx(pushed_messages_tx), - Transforms::RedisClusterPortsRewrite(r) => r.set_pushed_messages_tx(pushed_messages_tx), - Transforms::DebugPrinter(p) => p.set_pushed_messages_tx(pushed_messages_tx), - Transforms::DebugLogToFile(p) => p.set_pushed_messages_tx(pushed_messages_tx), - Transforms::DebugForceParse(p) => p.set_pushed_messages_tx(pushed_messages_tx), - Transforms::NullSink(n) => n.set_pushed_messages_tx(pushed_messages_tx), - Transforms::RedisSinkCluster(r) => r.set_pushed_messages_tx(pushed_messages_tx), - Transforms::ParallelMap(s) => s.set_pushed_messages_tx(pushed_messages_tx), - Transforms::PoolConnections(s) => s.set_pushed_messages_tx(pushed_messages_tx), - Transforms::Coalesce(s) => s.set_pushed_messages_tx(pushed_messages_tx), - Transforms::QueryTypeFilter(s) => s.set_pushed_messages_tx(pushed_messages_tx), - Transforms::QueryCounter(s) => s.set_pushed_messages_tx(pushed_messages_tx), - Transforms::Loopback(l) => l.set_pushed_messages_tx(pushed_messages_tx), - Transforms::Protect(p) => p.set_pushed_messages_tx(pushed_messages_tx), - Transforms::DebugReturner(d) => d.set_pushed_messages_tx(pushed_messages_tx), - Transforms::DebugRandomDelay(d) => d.set_pushed_messages_tx(pushed_messages_tx), - Transforms::RequestThrottling(d) => d.set_pushed_messages_tx(pushed_messages_tx), - Transforms::Custom(d) => d.set_pushed_messages_tx(pushed_messages_tx), - #[cfg(feature = "alpha-transforms")] - Transforms::OpenSearchSinkSingle(s) => s.set_pushed_messages_tx(pushed_messages_tx), - } - } -} - #[typetag::serde] #[async_trait(?Send)] pub trait TransformConfig: Debug { @@ -250,7 +66,6 @@ pub trait TransformConfig: Debug { /// The [`Wrapper`] struct is passed into each transform and contains a list of mutable references to the /// remaining transforms that will process the messages attached to this [`Wrapper`]. /// Most [`Transform`] authors will only be interested in [`wrapper.requests`]. -#[derive(Debug)] pub struct Wrapper<'a> { pub requests: Messages, transforms: TransformIter<'a>, @@ -263,7 +78,6 @@ pub struct Wrapper<'a> { pub flush: bool, } -#[derive(Debug)] enum TransformIter<'a> { Forwards(IterMut<'a, TransformAndMetrics>), Backwards(Rev>), @@ -447,12 +261,8 @@ impl<'a> Wrapper<'a> { /// then wrapping a member in an [`Arc>`](std::sync::Mutex) will achieve that, /// but make sure to copy the value from the TransformBuilder to ensure all instances are referring to the same value. /// -/// Once you have created your [`Transform`], you will need to create -/// new enum variants in [Transforms]. -/// And implement the [TransformBuilder] and [TransformConfig] traits to make them configurable in Shotover. -/// Shotover uses a concept called enum dispatch to provide dynamic configuration of transform chains -/// with minimal impact on performance. -/// +/// Once you have created your [`Transform`], you will need to implement the [TransformBuilder] and [TransformConfig] traits +/// to make them configurable in Shotover. /// Implementing this trait is usually done using `#[async_trait]` macros. #[async_trait] pub trait Transform: Send { @@ -480,7 +290,7 @@ pub trait Transform: Send { /// [`crate::transforms::cassandra::sink_single::CassandraSinkSingle`]. This type of transform /// is called a Terminating transform (as no subsequent transforms in the chain will be called). /// * _Message count_ - requests_wrapper.requests will contain 0 or more messages. - /// Your transform should return the same number of responses as messages received in requests_wrapper.requests. Transforms that + /// Your transform should return the same number of responses as messages received in requests_wrapper.requests. Transform that /// don't do this explicitly for each call, should return the same number of responses as messages it receives over the lifetime /// of the transform chain. A good example of this is the [`crate::transforms::coalesce::Coalesce`] transform. The /// [`crate::transforms::sampler::Sampler`] transform is also another example of this, with a slightly different twist. @@ -488,10 +298,10 @@ pub trait Transform: Send { /// changing the behavior of the main chain. /// /// ## Naming - /// Transforms also have different naming conventions. - /// * Transforms that interact with an external system are called Sinks. - /// * Transforms that don't call subsequent chains via `requests_wrapper.call_next_transform()` are called terminating transforms. - /// * Transforms that do call subsquent chains via `requests_wrapper.call_next_transform()` are non-terminating transforms. + /// Transform also have different naming conventions. + /// * Transform that interact with an external system are called Sinks. + /// * Transform that don't call subsequent chains via `requests_wrapper.call_next_transform()` are called terminating transforms. + /// * Transform that do call subsquent chains via `requests_wrapper.call_next_transform()` are non-terminating transforms. /// /// You can have have a transforms that is both non-terminating and a sink. async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result; @@ -515,6 +325,8 @@ pub trait Transform: Send { Ok(response) } + fn get_name(&self) -> &'static str; + fn set_pushed_messages_tx(&mut self, _pushed_messages_tx: mpsc::UnboundedSender) {} } diff --git a/shotover/src/transforms/noop.rs b/shotover/src/transforms/noop.rs index a7e290b3f..9182091ed 100644 --- a/shotover/src/transforms/noop.rs +++ b/shotover/src/transforms/noop.rs @@ -3,6 +3,8 @@ use crate::transforms::{Transform, Wrapper}; use anyhow::Result; use async_trait::async_trait; +const NAME: &str = "NoOp"; + #[derive(Debug, Clone)] pub struct NoOp {} @@ -20,6 +22,10 @@ impl Default for NoOp { #[async_trait] impl Transform for NoOp { + fn get_name(&self) -> &'static str { + NAME + } + async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result { requests_wrapper.call_next_transform().await } diff --git a/shotover/src/transforms/null.rs b/shotover/src/transforms/null.rs index bab52e015..b04f5ff14 100644 --- a/shotover/src/transforms/null.rs +++ b/shotover/src/transforms/null.rs @@ -1,5 +1,5 @@ use crate::message::Messages; -use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::Result; use async_trait::async_trait; use serde::{Deserialize, Serialize}; @@ -8,6 +8,7 @@ use serde::{Deserialize, Serialize}; #[serde(deny_unknown_fields)] pub struct NullSinkConfig; +const NAME: &str = "NullSink"; #[typetag::serde(name = "NullSink")] #[async_trait(?Send)] impl TransformConfig for NullSinkConfig { @@ -20,12 +21,12 @@ impl TransformConfig for NullSinkConfig { pub struct NullSink {} impl TransformBuilder for NullSink { - fn build(&self) -> super::Transforms { - Transforms::NullSink(self.clone()) + fn build(&self) -> Box { + Box::new(self.clone()) } fn get_name(&self) -> &'static str { - "NullSink" + NAME } fn is_terminating(&self) -> bool { @@ -35,6 +36,10 @@ impl TransformBuilder for NullSink { #[async_trait] impl Transform for NullSink { + fn get_name(&self) -> &'static str { + NAME + } + async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result { for message in &mut requests_wrapper.requests { *message = diff --git a/shotover/src/transforms/opensearch/mod.rs b/shotover/src/transforms/opensearch/mod.rs index 76c5647e1..43da41f47 100644 --- a/shotover/src/transforms/opensearch/mod.rs +++ b/shotover/src/transforms/opensearch/mod.rs @@ -1,7 +1,5 @@ use crate::tcp; -use crate::transforms::{ - Messages, Transform, TransformBuilder, TransformConfig, Transforms, Wrapper, -}; +use crate::transforms::{Messages, Transform, TransformBuilder, TransformConfig, Wrapper}; use crate::{ codec::{opensearch::OpenSearchCodecBuilder, CodecBuilder, Direction}, transforms::util::{ @@ -23,6 +21,7 @@ pub struct OpenSearchSinkSingleConfig { connect_timeout_ms: u64, } +const NAME: &str = "OpenSearchSinkSingle"; #[typetag::serde(name = "OpenSearchSinkSingle")] #[async_trait(?Send)] impl TransformConfig for OpenSearchSinkSingleConfig { @@ -53,8 +52,8 @@ impl OpenSearchSinkSingleBuilder { } impl TransformBuilder for OpenSearchSinkSingleBuilder { - fn build(&self) -> Transforms { - Transforms::OpenSearchSinkSingle(OpenSearchSinkSingle { + fn build(&self) -> Box { + Box::new(OpenSearchSinkSingle { address: self.address.clone(), connect_timeout: self.connect_timeout, codec_builder: OpenSearchCodecBuilder::new(Direction::Sink, self.get_name().to_owned()), @@ -63,7 +62,7 @@ impl TransformBuilder for OpenSearchSinkSingleBuilder { } fn get_name(&self) -> &'static str { - "OpenSearchSinkSingle" + NAME } fn is_terminating(&self) -> bool { @@ -80,6 +79,10 @@ pub struct OpenSearchSinkSingle { #[async_trait] impl Transform for OpenSearchSinkSingle { + fn get_name(&self) -> &'static str { + NAME + } + async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result { // Return immediately if we have no messages. // If we tried to send no messages we would block forever waiting for a reply that will never come. diff --git a/shotover/src/transforms/parallel_map.rs b/shotover/src/transforms/parallel_map.rs index 8e5f117c4..cec3fd3c2 100644 --- a/shotover/src/transforms/parallel_map.rs +++ b/shotover/src/transforms/parallel_map.rs @@ -1,7 +1,7 @@ use crate::config::chain::TransformChainConfig; use crate::message::Messages; use crate::transforms::chain::{TransformChain, TransformChainBuilder}; -use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::Result; use async_trait::async_trait; use futures::stream::{FuturesOrdered, FuturesUnordered}; @@ -18,7 +18,6 @@ pub struct ParallelMapBuilder { ordered: bool, } -#[derive(Debug)] pub struct ParallelMap { chains: Vec, ordered: bool, @@ -71,6 +70,7 @@ pub struct ParallelMapConfig { pub ordered_results: bool, } +const NAME: &str = "ParallelMap"; #[typetag::serde(name = "ParallelMap")] #[async_trait(?Send)] impl TransformConfig for ParallelMapConfig { @@ -89,6 +89,10 @@ impl TransformConfig for ParallelMapConfig { #[async_trait] impl Transform for ParallelMap { + fn get_name(&self) -> &'static str { + NAME + } + async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result { let mut results = Vec::with_capacity(requests_wrapper.requests.len()); let mut message_iter = requests_wrapper.requests.into_iter(); @@ -119,15 +123,15 @@ impl Transform for ParallelMap { } impl TransformBuilder for ParallelMapBuilder { - fn build(&self) -> Transforms { - Transforms::ParallelMap(ParallelMap { + fn build(&self) -> Box { + Box::new(ParallelMap { chains: self.chains.iter().map(|x| x.build()).collect(), ordered: self.ordered, }) } fn get_name(&self) -> &'static str { - "ParallelMap" + NAME } fn validate(&self) -> Vec { diff --git a/shotover/src/transforms/protect/mod.rs b/shotover/src/transforms/protect/mod.rs index 179b5a30f..4627aadbf 100644 --- a/shotover/src/transforms/protect/mod.rs +++ b/shotover/src/transforms/protect/mod.rs @@ -4,7 +4,7 @@ use crate::frame::{ use crate::message::Messages; use crate::transforms::protect::key_management::KeyManager; pub use crate::transforms::protect::key_management::KeyManagerConfig; -use crate::transforms::{Transform, TransformBuilder, Transforms, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, Wrapper}; use anyhow::Result; use async_trait::async_trait; use cql3_parser::cassandra_statement::CassandraStatement; @@ -30,6 +30,7 @@ pub struct ProtectConfig { #[cfg(feature = "alpha-transforms")] use crate::transforms::TransformConfig; +const NAME: &str = "Protect"; #[cfg(feature = "alpha-transforms")] #[typetag::serde(name = "Protect")] #[async_trait(?Send)] @@ -70,12 +71,12 @@ pub struct Protect { } impl TransformBuilder for Protect { - fn build(&self) -> Transforms { - Transforms::Protect(Box::new(self.clone())) + fn build(&self) -> Box { + Box::new(self.clone()) } fn get_name(&self) -> &'static str { - "Protect" + NAME } } @@ -167,6 +168,10 @@ impl Protect { #[async_trait] impl Transform for Protect { + fn get_name(&self) -> &'static str { + NAME + } + async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result { // encrypt the values included in any INSERT or UPDATE queries for message in requests_wrapper.requests.iter_mut() { diff --git a/shotover/src/transforms/query_counter.rs b/shotover/src/transforms/query_counter.rs index 72e3c7833..2297de32f 100644 --- a/shotover/src/transforms/query_counter.rs +++ b/shotover/src/transforms/query_counter.rs @@ -1,7 +1,7 @@ use crate::frame::Frame; use crate::message::Messages; use crate::transforms::TransformConfig; -use crate::transforms::{Transform, TransformBuilder, Transforms, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, Wrapper}; use anyhow::Result; use async_trait::async_trait; use metrics::{counter, register_counter}; @@ -28,17 +28,21 @@ impl QueryCounter { } impl TransformBuilder for QueryCounter { - fn build(&self) -> Transforms { - Transforms::QueryCounter(self.clone()) + fn build(&self) -> Box { + Box::new(self.clone()) } fn get_name(&self) -> &'static str { - "QueryCounter" + NAME } } #[async_trait] impl Transform for QueryCounter { + fn get_name(&self) -> &'static str { + NAME + } + async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result { for m in &mut requests_wrapper.requests { match m.frame() { @@ -73,6 +77,7 @@ impl Transform for QueryCounter { } } +const NAME: &str = "QueryCounter"; #[typetag::serde(name = "QueryCounter")] #[async_trait(?Send)] impl TransformConfig for QueryCounterConfig { diff --git a/shotover/src/transforms/redis/cache.rs b/shotover/src/transforms/redis/cache.rs index ba6a48c08..be48c7461 100644 --- a/shotover/src/transforms/redis/cache.rs +++ b/shotover/src/transforms/redis/cache.rs @@ -2,7 +2,7 @@ use crate::config::chain::TransformChainConfig; use crate::frame::{CassandraFrame, CassandraOperation, Frame, RedisFrame}; use crate::message::{Message, Messages}; use crate::transforms::chain::{TransformChain, TransformChainBuilder}; -use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::{bail, Result}; use async_trait::async_trait; use bytes::Bytes; @@ -81,6 +81,7 @@ pub struct RedisConfig { pub chain: TransformChainConfig, } +const NAME: &str = "RedisCache"; #[typetag::serde(name = "RedisCache")] #[async_trait(?Send)] impl TransformConfig for RedisConfig { @@ -108,8 +109,8 @@ pub struct SimpleRedisCacheBuilder { } impl TransformBuilder for SimpleRedisCacheBuilder { - fn build(&self) -> Transforms { - Transforms::RedisCache(SimpleRedisCache { + fn build(&self) -> Box { + Box::new(SimpleRedisCache { cache_chain: self.cache_chain.build(), caching_schema: self.caching_schema.clone(), missed_requests: self.missed_requests.clone(), @@ -117,7 +118,7 @@ impl TransformBuilder for SimpleRedisCacheBuilder { } fn get_name(&self) -> &'static str { - "RedisCache" + NAME } fn validate(&self) -> Vec { @@ -568,6 +569,10 @@ fn build_redis_key_from_cql3( #[async_trait] impl Transform for SimpleRedisCache { + fn get_name(&self) -> &'static str { + NAME + } + async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result { let cache_responses = self .read_from_cache(&mut requests_wrapper.requests, requests_wrapper.local_addr) diff --git a/shotover/src/transforms/redis/cluster_ports_rewrite.rs b/shotover/src/transforms/redis/cluster_ports_rewrite.rs index 8c13cf6cc..b2904b03c 100644 --- a/shotover/src/transforms/redis/cluster_ports_rewrite.rs +++ b/shotover/src/transforms/redis/cluster_ports_rewrite.rs @@ -1,7 +1,7 @@ use crate::frame::Frame; use crate::frame::RedisFrame; use crate::message::Messages; -use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::{anyhow, bail, Context, Result}; use async_trait::async_trait; use bytes::{BufMut, Bytes, BytesMut}; @@ -14,6 +14,7 @@ pub struct RedisClusterPortsRewriteConfig { pub new_port: u16, } +const NAME: &str = "RedisClusterPortsRewrite"; #[typetag::serde(name = "RedisClusterPortsRewrite")] #[async_trait(?Send)] impl TransformConfig for RedisClusterPortsRewriteConfig { @@ -25,12 +26,12 @@ impl TransformConfig for RedisClusterPortsRewriteConfig { } impl TransformBuilder for RedisClusterPortsRewrite { - fn build(&self) -> Transforms { - Transforms::RedisClusterPortsRewrite(self.clone()) + fn build(&self) -> Box { + Box::new(self.clone()) } fn get_name(&self) -> &'static str { - "RedisClusterPortsRewrite" + NAME } } @@ -47,6 +48,10 @@ impl RedisClusterPortsRewrite { #[async_trait] impl Transform for RedisClusterPortsRewrite { + fn get_name(&self) -> &'static str { + NAME + } + async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result { // Find the indices of cluster slot messages let mut cluster_slots_indices = vec![]; diff --git a/shotover/src/transforms/redis/sink_cluster.rs b/shotover/src/transforms/redis/sink_cluster.rs index 71ff98856..ae330d517 100644 --- a/shotover/src/transforms/redis/sink_cluster.rs +++ b/shotover/src/transforms/redis/sink_cluster.rs @@ -7,9 +7,7 @@ use crate::transforms::redis::RedisError; use crate::transforms::redis::TransformError; use crate::transforms::util::cluster_connection_pool::{Authenticator, ConnectionPool}; use crate::transforms::util::{Request, Response}; -use crate::transforms::{ - ResponseFuture, Transform, TransformBuilder, TransformConfig, Transforms, Wrapper, -}; +use crate::transforms::{ResponseFuture, Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::{anyhow, bail, ensure, Context, Result}; use async_trait::async_trait; use bytes::Bytes; @@ -46,6 +44,7 @@ pub struct RedisSinkClusterConfig { pub connect_timeout_ms: u64, } +const NAME: &str = "RedisSinkCluster"; #[typetag::serde(name = "RedisSinkCluster")] #[async_trait(?Send)] impl TransformConfig for RedisSinkClusterConfig { @@ -77,8 +76,8 @@ pub struct RedisSinkClusterBuilder { } impl TransformBuilder for RedisSinkClusterBuilder { - fn build(&self) -> Transforms { - Transforms::RedisSinkCluster(RedisSinkCluster::new( + fn build(&self) -> Box { + Box::new(RedisSinkCluster::new( self.first_contact_points.clone(), self.direct_destination.clone(), self.connection_count, @@ -89,7 +88,7 @@ impl TransformBuilder for RedisSinkClusterBuilder { } fn get_name(&self) -> &'static str { - "RedisSinkCluster" + NAME } fn is_terminating(&self) -> bool { @@ -185,7 +184,7 @@ impl RedisSinkCluster { } fn get_name(&self) -> &'static str { - "RedisSinkCluster" + NAME } #[inline] @@ -985,6 +984,10 @@ fn short_circuit(frame: RedisFrame) -> Result { #[async_trait] impl Transform for RedisSinkCluster { + fn get_name(&self) -> &'static str { + NAME + } + async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result { if !self.has_run_init { self.topology = (*self.shared_topology.read().await).clone(); diff --git a/shotover/src/transforms/redis/sink_single.rs b/shotover/src/transforms/redis/sink_single.rs index f13808b65..54af80b13 100644 --- a/shotover/src/transforms/redis/sink_single.rs +++ b/shotover/src/transforms/redis/sink_single.rs @@ -6,7 +6,7 @@ use crate::frame::{Frame, RedisFrame}; use crate::message::{Message, Messages}; use crate::tcp; use crate::tls::{AsyncStream, TlsConnector, TlsConnectorConfig}; -use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::{anyhow, Result}; use async_trait::async_trait; use futures::{FutureExt, SinkExt, StreamExt}; @@ -29,6 +29,7 @@ pub struct RedisSinkSingleConfig { pub connect_timeout_ms: u64, } +const NAME: &str = "RedisSinkSingle"; #[typetag::serde(name = "RedisSinkSingle")] #[async_trait(?Send)] impl TransformConfig for RedisSinkSingleConfig { @@ -71,8 +72,8 @@ impl RedisSinkSingleBuilder { } impl TransformBuilder for RedisSinkSingleBuilder { - fn build(&self) -> Transforms { - Transforms::RedisSinkSingle(RedisSinkSingle { + fn build(&self) -> Box { + Box::new(RedisSinkSingle { address: self.address.clone(), tls: self.tls.clone(), connection: None, @@ -83,7 +84,7 @@ impl TransformBuilder for RedisSinkSingleBuilder { } fn get_name(&self) -> &'static str { - "RedisSinkSingle" + NAME } fn is_terminating(&self) -> bool { @@ -110,6 +111,10 @@ pub struct RedisSinkSingle { #[async_trait] impl Transform for RedisSinkSingle { + fn get_name(&self) -> &'static str { + NAME + } + async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result { // Return immediately if we have no messages. // If we tried to send no messages we would block forever waiting for a reply that will never come. diff --git a/shotover/src/transforms/redis/timestamp_tagging.rs b/shotover/src/transforms/redis/timestamp_tagging.rs index 3da99349f..8fb10c93b 100644 --- a/shotover/src/transforms/redis/timestamp_tagging.rs +++ b/shotover/src/transforms/redis/timestamp_tagging.rs @@ -1,7 +1,7 @@ use crate::frame::redis::redis_query_type; use crate::frame::{Frame, RedisFrame}; use crate::message::{Message, Messages, QueryType}; -use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::{anyhow, Result}; use async_trait::async_trait; use bytes::Bytes; @@ -15,6 +15,7 @@ use tracing::{debug, trace}; #[serde(deny_unknown_fields)] pub struct RedisTimestampTaggerConfig; +const NAME: &str = "RedisTimestampTagger"; #[typetag::serde(name = "RedisTimestampTagger")] #[async_trait(?Send)] impl TransformConfig for RedisTimestampTaggerConfig { @@ -33,12 +34,12 @@ impl RedisTimestampTagger { } impl TransformBuilder for RedisTimestampTagger { - fn build(&self) -> Transforms { - Transforms::RedisTimestampTagger(self.clone()) + fn build(&self) -> Box { + Box::new(self.clone()) } fn get_name(&self) -> &'static str { - "RedisTimestampTagger" + NAME } } @@ -182,6 +183,10 @@ fn unwrap_response(message: &mut Message) -> Result<()> { #[async_trait] impl Transform for RedisTimestampTagger { + fn get_name(&self) -> &'static str { + NAME + } + async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result { // TODO: This is wrong. We need to keep track of tagged_success per message let mut tagged_success = true; diff --git a/shotover/src/transforms/sampler.rs b/shotover/src/transforms/sampler.rs index 1c19fadcf..fe465dfe3 100644 --- a/shotover/src/transforms/sampler.rs +++ b/shotover/src/transforms/sampler.rs @@ -6,6 +6,8 @@ use async_trait::async_trait; use tokio::macros::support::thread_rng_n; use tracing::warn; +const NAME: &str = "Sampler"; + #[derive(Debug)] pub struct SamplerBuilder { pub numerator: u32, @@ -29,7 +31,6 @@ impl Default for SamplerBuilder { } } -#[derive(Debug)] pub struct Sampler { numerator: u32, denominator: u32, @@ -38,6 +39,10 @@ pub struct Sampler { #[async_trait] impl Transform for Sampler { + fn get_name(&self) -> &'static str { + NAME + } + async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result { let chance = thread_rng_n(self.denominator); if chance < self.numerator { diff --git a/shotover/src/transforms/tee.rs b/shotover/src/transforms/tee.rs index 84c9b74bd..38c7911e2 100644 --- a/shotover/src/transforms/tee.rs +++ b/shotover/src/transforms/tee.rs @@ -1,7 +1,7 @@ use crate::config::chain::TransformChainConfig; use crate::message::Messages; use crate::transforms::chain::{BufferedChain, TransformChainBuilder}; -use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper}; +use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper}; use anyhow::{anyhow, Context, Result}; use async_trait::async_trait; use atomic_enum::atomic_enum; @@ -64,8 +64,8 @@ impl TeeBuilder { } impl TransformBuilder for TeeBuilder { - fn build(&self) -> Transforms { - Transforms::Tee(Tee { + fn build(&self) -> Box { + Box::new(Tee { tx: self.tx.build_buffered(self.buffer_size), behavior: match &self.behavior { ConsistencyBehaviorBuilder::Ignore => ConsistencyBehavior::Ignore, @@ -85,7 +85,7 @@ impl TransformBuilder for TeeBuilder { } fn get_name(&self) -> &'static str { - "Tee" + NAME } fn validate(&self) -> Vec { @@ -163,6 +163,7 @@ pub enum ConsistencyBehaviorConfig { SubchainOnMismatch(TransformChainConfig), } +const NAME: &str = "Tee"; #[typetag::serde(name = "Tee")] #[async_trait(?Send)] impl TransformConfig for TeeConfig { @@ -199,6 +200,10 @@ impl TransformConfig for TeeConfig { #[async_trait] impl Transform for Tee { + fn get_name(&self) -> &'static str { + NAME + } + async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result { match &mut self.behavior { ConsistencyBehavior::Ignore => self.ignore_behaviour(requests_wrapper).await, diff --git a/shotover/src/transforms/throttling.rs b/shotover/src/transforms/throttling.rs index 26343afbd..2bae4b1db 100644 --- a/shotover/src/transforms/throttling.rs +++ b/shotover/src/transforms/throttling.rs @@ -13,14 +13,13 @@ use serde::{Deserialize, Serialize}; use std::num::NonZeroU32; use std::sync::Arc; -use super::Transforms; - #[derive(Serialize, Deserialize, Debug)] #[serde(deny_unknown_fields)] pub struct RequestThrottlingConfig { pub max_requests_per_second: NonZeroU32, } +const NAME: &str = "RequestThrottling"; #[typetag::serde(name = "RequestThrottling")] #[async_trait(?Send)] impl TransformConfig for RequestThrottlingConfig { @@ -41,12 +40,12 @@ pub struct RequestThrottling { } impl TransformBuilder for RequestThrottling { - fn build(&self) -> Transforms { - Transforms::RequestThrottling(self.clone()) + fn build(&self) -> Box { + Box::new(self.clone()) } fn get_name(&self) -> &'static str { - "RequestThrottlingConfig" + NAME } fn validate(&self) -> Vec { @@ -63,6 +62,10 @@ impl TransformBuilder for RequestThrottling { #[async_trait] impl Transform for RequestThrottling { + fn get_name(&self) -> &'static str { + NAME + } + async fn transform<'a>(&'a mut self, mut requests_wrapper: Wrapper<'a>) -> Result { // extract throttled messages from the requests_wrapper let throttled_messages: Vec<(Message, usize)> = (0..requests_wrapper.requests.len())