Skip to content

Commit

Permalink
Introduce TransformContextConfig
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai committed Feb 21, 2024
1 parent 10b04e6 commit c406e7d
Show file tree
Hide file tree
Showing 36 changed files with 319 additions and 96 deletions.
9 changes: 7 additions & 2 deletions custom-transforms-example/src/redis_get_rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use shotover::frame::{Frame, RedisFrame};
use shotover::message::{MessageIdSet, Messages};
use shotover::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use shotover::transforms::{
Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper,
};

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(deny_unknown_fields)]
Expand All @@ -15,7 +17,10 @@ const NAME: &str = "RedisGetRewrite";
#[typetag::serde(name = "RedisGetRewrite")]
#[async_trait(?Send)]
impl TransformConfig for RedisGetRewriteConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
_transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
Ok(Box::new(RedisGetRewriteBuilder {
result: self.result.clone(),
}))
Expand Down
12 changes: 9 additions & 3 deletions shotover/benches/benches/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use shotover::transforms::protect::{KeyManagerConfig, ProtectConfig};
use shotover::transforms::redis::cluster_ports_rewrite::RedisClusterPortsRewrite;
use shotover::transforms::redis::timestamp_tagging::RedisTimestampTagger;
use shotover::transforms::throttling::RequestThrottlingConfig;
use shotover::transforms::{TransformConfig, Wrapper};
use shotover::transforms::{TransformConfig, TransformContextConfig, Wrapper};

fn criterion_benchmark(c: &mut Criterion) {
crate::init();
Expand Down Expand Up @@ -192,7 +192,10 @@ fn criterion_benchmark(c: &mut Criterion) {
// an absurdly large value is given so that all messages will pass through
max_requests_per_second: std::num::NonZeroU32::new(100_000_000).unwrap(),
}
.get_builder("".to_owned()),
.get_builder(TransformContextConfig {
chain_name: "".into(),
protocol: ProtocolType::Redis,
}),
)
.unwrap(),
Box::<NullSink>::default(),
Expand Down Expand Up @@ -301,7 +304,10 @@ fn criterion_benchmark(c: &mut Criterion) {
kek_id: "".to_string(),
},
}
.get_builder("".to_owned()),
.get_builder(TransformContextConfig {
chain_name: "".into(),
protocol: ProtocolType::Redis,
}),
)
.unwrap(),
Box::<NullSink>::default(),
Expand Down
8 changes: 5 additions & 3 deletions shotover/src/codec/cassandra.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::{CodecBuilder, CodecReadError, CodecWriteError, Direction};
use crate::frame::cassandra::{CassandraMetadata, CassandraOperation, Tracing};
use crate::frame::{CassandraFrame, Frame, MessageType};
use crate::message::{Encodable, Message, Messages, Metadata};
use crate::message::{Encodable, Message, Messages, Metadata, ProtocolType};
use anyhow::{anyhow, Result};
use atomic_enum::atomic_enum;
use bytes::{Buf, BufMut, Bytes, BytesMut};
Expand Down Expand Up @@ -139,8 +139,10 @@ impl CodecBuilder for CassandraCodecBuilder {
)
}

fn websocket_subprotocol(&self) -> &'static str {
"cql"
fn protocol(&self) -> ProtocolType {
ProtocolType::Cassandra {
compression: Compression::None,
}
}
}

Expand Down
6 changes: 4 additions & 2 deletions shotover/src/codec/kafka.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@ impl CodecBuilder for KafkaCodecBuilder {
)
}

fn websocket_subprotocol(&self) -> &'static str {
"kafka"
fn protocol(&self) -> ProtocolType {
ProtocolType::Kafka {
request_header: None,
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions shotover/src/codec/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Codec types to use for connecting to a DB in a sink transform
use crate::message::Messages;
use crate::message::{Messages, ProtocolType};
#[cfg(feature = "cassandra")]
use cassandra_protocol::compression::Compression;
use core::fmt;
Expand Down Expand Up @@ -128,5 +128,5 @@ pub trait CodecBuilder: Clone + Send {

fn new(direction: Direction, destination_name: String) -> Self;

fn websocket_subprotocol(&self) -> &'static str;
fn protocol(&self) -> ProtocolType;
}
6 changes: 3 additions & 3 deletions shotover/src/codec/opensearch.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{CodecBuilder, CodecReadError, CodecWriteError, Direction};
use crate::message::{Encodable, Message, Messages};
use crate::message::{Encodable, Message, Messages, ProtocolType};
use crate::{
frame::{
opensearch::{HttpHead, RequestParts, ResponseParts},
Expand Down Expand Up @@ -56,8 +56,8 @@ impl CodecBuilder for OpenSearchCodecBuilder {
)
}

fn websocket_subprotocol(&self) -> &'static str {
"opensearch"
fn protocol(&self) -> ProtocolType {
ProtocolType::OpenSearch
}
}

Expand Down
6 changes: 3 additions & 3 deletions shotover/src/codec/redis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::time::Instant;
use super::{CodecWriteError, Direction};
use crate::codec::{CodecBuilder, CodecReadError};
use crate::frame::{Frame, MessageType, RedisFrame};
use crate::message::{Encodable, Message, MessageId, Messages};
use crate::message::{Encodable, Message, MessageId, Messages, ProtocolType};
use anyhow::{anyhow, Result};
use bytes::BytesMut;
use metrics::Histogram;
Expand Down Expand Up @@ -44,8 +44,8 @@ impl CodecBuilder for RedisCodecBuilder {
)
}

fn websocket_subprotocol(&self) -> &'static str {
"redis"
fn protocol(&self) -> ProtocolType {
ProtocolType::Redis
}
}

Expand Down
14 changes: 10 additions & 4 deletions shotover/src/config/chain.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::transforms::chain::TransformChainBuilder;
use crate::transforms::{TransformBuilder, TransformConfig};
use crate::transforms::{TransformBuilder, TransformConfig, TransformContextConfig};
use anyhow::Result;
use serde::de::{DeserializeSeed, Deserializer, MapAccess, SeqAccess, Visitor};
use serde::{Deserialize, Serialize};
Expand All @@ -14,12 +14,18 @@ pub struct TransformChainConfig(
);

impl TransformChainConfig {
pub async fn get_builder(&self, name: String) -> Result<TransformChainBuilder> {
pub async fn get_builder(
&self,
transform_context: TransformContextConfig,
) -> Result<TransformChainBuilder> {
let mut transforms: Vec<Box<dyn TransformBuilder>> = Vec::new();
for tc in &self.0 {
transforms.push(tc.get_builder(name.clone()).await?)
transforms.push(tc.get_builder(transform_context.clone()).await?)
}
Ok(TransformChainBuilder::new(transforms, name.leak()))
Ok(TransformChainBuilder::new(
transforms,
transform_context.chain_name.leak(),
))
}
}

Expand Down
30 changes: 29 additions & 1 deletion shotover/src/message/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ pub enum Metadata {
OpenSearch,
}

#[derive(PartialEq)]
#[derive(Clone, PartialEq)]
pub enum ProtocolType {
#[cfg(feature = "cassandra")]
Cassandra { compression: Compression },
Expand All @@ -48,6 +48,34 @@ pub enum ProtocolType {
OpenSearch,
}

impl ProtocolType {
pub fn is_inorder(&self) -> bool {
match self {
#[cfg(feature = "cassandra")]
ProtocolType::Cassandra { .. } => false,
#[cfg(feature = "redis")]
ProtocolType::Redis => true,
#[cfg(feature = "kafka")]
ProtocolType::Kafka { .. } => true,
#[cfg(feature = "opensearch")]
ProtocolType::OpenSearch => true,
}
}

pub fn websocket_subprotocol(&self) -> &'static str {
match self {
#[cfg(feature = "cassandra")]
ProtocolType::Cassandra { .. } => "cql",
#[cfg(feature = "redis")]
ProtocolType::Redis => "redis",
#[cfg(feature = "kafka")]
ProtocolType::Kafka { .. } => "kafka",
#[cfg(feature = "opensearch")]
ProtocolType::OpenSearch => "opensearch",
}
}
}

impl From<&ProtocolType> for CodecState {
fn from(value: &ProtocolType) -> Self {
match value {
Expand Down
10 changes: 7 additions & 3 deletions shotover/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::message::{Message, Messages};
use crate::sources::Transport;
use crate::tls::{AcceptError, TlsAcceptor};
use crate::transforms::chain::{TransformChain, TransformChainBuilder};
use crate::transforms::Wrapper;
use crate::transforms::{TransformContextConfig, Wrapper};
use anyhow::{anyhow, Context, Result};
use bytes::BytesMut;
use futures::future::join_all;
Expand Down Expand Up @@ -92,8 +92,12 @@ impl<C: CodecBuilder + 'static> TcpCodecListener<C> {
gauge!("shotover_available_connections_count", "source" => source_name.clone());
available_connections_gauge.set(limit_connections.available_permits() as f64);

let chain_usage_config = TransformContextConfig {
chain_name: source_name.clone(),
protocol: codec.protocol(),
};
let chain_builder = chain_config
.get_builder(source_name.clone())
.get_builder(chain_usage_config)
.await
.map_err(|x| vec![format!("{x:?}")])?;

Expand Down Expand Up @@ -594,7 +598,7 @@ impl<C: CodecBuilder + 'static> Handler<C> {

match transport {
Transport::WebSocket => {
let websocket_subprotocol = codec_builder.websocket_subprotocol();
let websocket_subprotocol = codec_builder.protocol().websocket_subprotocol();

if let Some(tls) = &self.tls {
let tls_stream = match tls.accept(stream).await {
Expand Down
16 changes: 11 additions & 5 deletions shotover/src/transforms/cassandra/peers_rewrite.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use crate::frame::{
value::{GenericValue, IntSize},
CassandraOperation, CassandraResult, Frame,
};
use crate::message::{Message, Messages};
use crate::transforms::cassandra::peers_rewrite::CassandraOperation::Event;
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use crate::{
frame::{
value::{GenericValue, IntSize},
CassandraOperation, CassandraResult, Frame,
},
transforms::TransformContextConfig,
};
use anyhow::Result;
use async_trait::async_trait;
use cassandra_protocol::frame::events::{ServerEvent, StatusChange};
Expand All @@ -23,7 +26,10 @@ const NAME: &str = "CassandraPeersRewrite";
#[typetag::serde(name = "CassandraPeersRewrite")]
#[async_trait(?Send)]
impl TransformConfig for CassandraPeersRewriteConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
_transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
Ok(Box::new(CassandraPeersRewrite::new(self.port)))
}
}
Expand Down
11 changes: 8 additions & 3 deletions shotover/src/transforms/cassandra/sink_cluster/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame};
use crate::message::{Message, MessageIdMap, Messages, Metadata};
use crate::tls::{TlsConnector, TlsConnectorConfig};
use crate::transforms::cassandra::connection::{CassandraConnection, Response, ResponseError};
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use crate::transforms::{
Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper,
};
use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use cassandra_protocol::events::ServerEvent;
Expand Down Expand Up @@ -66,7 +68,10 @@ const NAME: &str = "CassandraSinkCluster";
#[typetag::serde(name = "CassandraSinkCluster")]
#[async_trait(?Send)]
impl TransformConfig for CassandraSinkClusterConfig {
async fn get_builder(&self, chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
let tls = self.tls.clone().map(TlsConnector::new).transpose()?;
let mut shotover_nodes = self.shotover_nodes.clone();
let index = self
Expand All @@ -84,7 +89,7 @@ impl TransformConfig for CassandraSinkClusterConfig {
Ok(Box::new(CassandraSinkClusterBuilder::new(
self.first_contact_points.clone(),
shotover_nodes,
chain_name,
transform_context.chain_name,
local_node,
tls,
self.connect_timeout_ms,
Expand Down
11 changes: 8 additions & 3 deletions shotover/src/transforms/cassandra/sink_single.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ use crate::frame::cassandra::CassandraMetadata;
use crate::message::{Messages, Metadata};
use crate::tls::{TlsConnector, TlsConnectorConfig};
use crate::transforms::cassandra::connection::Response;
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use crate::transforms::{
Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper,
};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use cassandra_protocol::frame::Version;
Expand All @@ -29,11 +31,14 @@ const NAME: &str = "CassandraSinkSingle";
#[typetag::serde(name = "CassandraSinkSingle")]
#[async_trait(?Send)]
impl TransformConfig for CassandraSinkSingleConfig {
async fn get_builder(&self, chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
let tls = self.tls.clone().map(TlsConnector::new).transpose()?;
Ok(Box::new(CassandraSinkSingleBuilder::new(
self.address.clone(),
chain_name,
transform_context.chain_name,
tls,
self.connect_timeout_ms,
self.read_timeout,
Expand Down
6 changes: 5 additions & 1 deletion shotover/src/transforms/coalesce.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::TransformContextConfig;
use crate::message::Messages;
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use anyhow::Result;
Expand All @@ -24,7 +25,10 @@ const NAME: &str = "Coalesce";
#[typetag::serde(name = "Coalesce")]
#[async_trait(?Send)]
impl TransformConfig for CoalesceConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
_transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
Ok(Box::new(Coalesce {
buffer: Vec::with_capacity(self.flush_when_buffered_message_count.unwrap_or(0)),
flush_when_buffered_message_count: self.flush_when_buffered_message_count,
Expand Down
12 changes: 9 additions & 3 deletions shotover/src/transforms/debug/force_parse.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use crate::message::Messages;
/// This transform will by default parse requests and responses that pass through it.
/// request and response parsing can be individually disabled if desired.
///
Expand All @@ -8,6 +7,7 @@ use crate::message::Messages;
#[cfg(feature = "alpha-transforms")]
use crate::transforms::TransformConfig;
use crate::transforms::{Transform, TransformBuilder, Wrapper};
use crate::{message::Messages, transforms::TransformContextConfig};
use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
Expand All @@ -25,7 +25,10 @@ pub struct DebugForceParseConfig {
#[typetag::serde(name = "DebugForceParse")]
#[async_trait(?Send)]
impl TransformConfig for DebugForceParseConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
_transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
Ok(Box::new(DebugForceParse {
parse_requests: self.parse_requests,
parse_responses: self.parse_responses,
Expand All @@ -49,7 +52,10 @@ const NAME: &str = "DebugForceEncode";
#[typetag::serde(name = "DebugForceEncode")]
#[async_trait(?Send)]
impl TransformConfig for DebugForceEncodeConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
async fn get_builder(
&self,
_transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
Ok(Box::new(DebugForceParse {
parse_requests: self.encode_requests,
parse_responses: self.encode_responses,
Expand Down
Loading

0 comments on commit c406e7d

Please sign in to comment.