Skip to content

Commit

Permalink
transform configs: require Serialize
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai committed Sep 14, 2023
1 parent 6021ef5 commit 261793b
Show file tree
Hide file tree
Showing 38 changed files with 121 additions and 107 deletions.
6 changes: 3 additions & 3 deletions custom-transforms-example/src/redis_get_rewrite.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
use anyhow::Result;
use async_trait::async_trait;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use shotover::frame::{Frame, RedisFrame};
use shotover::message::Messages;
use shotover::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper};

#[derive(Deserialize, Debug, Clone)]
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(deny_unknown_fields)]
pub struct RedisGetRewriteConfig {
pub result: String,
}

#[typetag::deserialize(name = "RedisGetRewrite")]
#[typetag::serde(name = "RedisGetRewrite")]
#[async_trait(?Send)]
impl TransformConfig for RedisGetRewriteConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
Expand Down
4 changes: 2 additions & 2 deletions shotover/src/config/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ use crate::transforms::chain::TransformChainBuilder;
use crate::transforms::{TransformBuilder, TransformConfig};
use anyhow::Result;
use serde::de::{DeserializeSeed, Deserializer, MapAccess, SeqAccess, Visitor};
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use std::fmt::{self, Debug};
use std::iter;

#[derive(Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct TransformChainConfig(
#[serde(rename = "TransformChain", deserialize_with = "vec_transform_config")]
Expand Down
2 changes: 2 additions & 0 deletions shotover/src/config/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//! Config types, used for serializing/deserializing shotover configuration files
use anyhow::{Context, Result};
use serde::Deserialize;

Expand Down
13 changes: 11 additions & 2 deletions shotover/src/config/topology.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ use crate::sources::{Source, SourceConfig};
use crate::transforms::chain::TransformChainBuilder;
use anyhow::{anyhow, Context, Result};
use itertools::Itertools;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tokio::sync::watch;
use tracing::info;

#[derive(Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct Topology {
pub sources: HashMap<String, SourceConfig>,
Expand All @@ -17,6 +17,7 @@ pub struct Topology {
}

impl Topology {
/// Load the topology.yaml from the provided path into a Topology instance
pub fn from_file(filepath: &str) -> Result<Topology> {
let file = std::fs::File::open(filepath)
.with_context(|| format!("Couldn't open the topology file {}", filepath))?;
Expand All @@ -26,6 +27,14 @@ impl Topology {
.with_context(|| format!("Failed to parse topology file {}", filepath))
}

/// Generate the yaml representation of this instance
pub fn serialize(&self) -> Result<String> {
let mut output = vec![];
let mut serializer = serde_yaml::Serializer::new(&mut output);
serde_yaml::with::singleton_map_recursive::serialize(self, &mut serializer)?;
Ok(String::from_utf8(output).unwrap())
}

async fn build_chains(&self) -> Result<HashMap<String, Option<TransformChainBuilder>>> {
let mut result = HashMap::new();
for (source_name, chain_name) in &self.source_to_chain_mapping {
Expand Down
4 changes: 2 additions & 2 deletions shotover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@
#![allow(clippy::needless_doctest_main)]

pub mod codec;
mod config;
pub mod config;
pub mod frame;
pub mod message;
mod observability;
pub mod runner;
mod server;
mod sources;
pub mod sources;
pub mod tcp;
pub mod tls;
mod tracing_panic_handler;
Expand Down
4 changes: 2 additions & 2 deletions shotover/src/message/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use bytes::{Buf, Bytes};
use cassandra_protocol::compression::Compression;
use cassandra_protocol::frame::message_error::{ErrorBody, ErrorType};
use nonzero_ext::nonzero;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use std::num::NonZeroU32;

pub enum Metadata {
Expand Down Expand Up @@ -436,7 +436,7 @@ pub enum Encodable {
Frame(Frame),
}

#[derive(PartialEq, Debug, Clone, Deserialize)]
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
#[serde(deny_unknown_fields)]
pub enum QueryType {
Read,
Expand Down
4 changes: 2 additions & 2 deletions shotover/src/sources/cassandra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ use crate::sources::{Source, Transport};
use crate::tls::{TlsAcceptor, TlsAcceptorConfig};
use crate::transforms::chain::TransformChainBuilder;
use anyhow::Result;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::{watch, Semaphore};
use tokio::task::JoinHandle;
use tracing::{error, info};

#[derive(Deserialize, Debug, Clone)]
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(deny_unknown_fields)]
pub struct CassandraConfig {
pub listen_addr: String,
Expand Down
4 changes: 2 additions & 2 deletions shotover/src/sources/kafka.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ use crate::sources::{Source, Transport};
use crate::tls::{TlsAcceptor, TlsAcceptorConfig};
use crate::transforms::chain::TransformChainBuilder;
use anyhow::Result;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::{watch, Semaphore};
use tokio::task::JoinHandle;
use tracing::{error, info};

#[derive(Deserialize, Debug, Clone)]
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(deny_unknown_fields)]
pub struct KafkaConfig {
pub listen_addr: String,
Expand Down
8 changes: 5 additions & 3 deletions shotover/src/sources/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
//! Sources used to listen for connections and send/recieve with the client.
use crate::sources::cassandra::{CassandraConfig, CassandraSource};
use crate::sources::kafka::{KafkaConfig, KafkaSource};
use crate::sources::opensearch::{OpenSearchConfig, OpenSearchSource};
use crate::sources::redis::{RedisConfig, RedisSource};
use crate::transforms::chain::TransformChainBuilder;
use anyhow::Result;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use tokio::sync::watch;
use tokio::task::JoinHandle;

Expand All @@ -13,7 +15,7 @@ pub mod kafka;
pub mod opensearch;
pub mod redis;

#[derive(Deserialize, Debug, Clone, Copy)]
#[derive(Serialize, Deserialize, Debug, Clone, Copy)]
#[serde(deny_unknown_fields)]
pub enum Transport {
Tcp,
Expand All @@ -39,7 +41,7 @@ impl Source {
}
}

#[derive(Deserialize, Debug, Clone)]
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(deny_unknown_fields)]
pub enum SourceConfig {
Cassandra(CassandraConfig),
Expand Down
4 changes: 2 additions & 2 deletions shotover/src/sources/opensearch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ use crate::server::TcpCodecListener;
use crate::sources::{Source, Transport};
use crate::transforms::chain::TransformChainBuilder;
use anyhow::Result;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::{watch, Semaphore};
use tokio::task::JoinHandle;
use tracing::{error, info};

#[derive(Deserialize, Debug, Clone)]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct OpenSearchConfig {
pub listen_addr: String,
pub connection_limit: Option<usize>,
Expand Down
4 changes: 2 additions & 2 deletions shotover/src/sources/redis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ use crate::sources::{Source, Transport};
use crate::tls::{TlsAcceptor, TlsAcceptorConfig};
use crate::transforms::chain::TransformChainBuilder;
use anyhow::Result;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::{watch, Semaphore};
use tokio::task::JoinHandle;
use tracing::{error, info};

#[derive(Deserialize, Debug, Clone)]
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(deny_unknown_fields)]
pub struct RedisConfig {
pub listen_addr: String,
Expand Down
6 changes: 3 additions & 3 deletions shotover/src/transforms/cassandra/peers_rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ use cassandra_protocol::frame::events::{ServerEvent, StatusChange};
use cql3_parser::cassandra_statement::CassandraStatement;
use cql3_parser::common::{FQName, Identifier};
use cql3_parser::select::SelectElement;
use serde::Deserialize;
use serde::{Deserialize, Serialize};

#[derive(Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct CassandraPeersRewriteConfig {
pub port: u16,
}

#[typetag::deserialize(name = "CassandraPeersRewrite")]
#[typetag::serde(name = "CassandraPeersRewrite")]
#[async_trait(?Send)]
impl TransformConfig for CassandraPeersRewriteConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
Expand Down
8 changes: 4 additions & 4 deletions shotover/src/transforms/cassandra/sink_cluster/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use metrics::{register_counter, Counter};
use node::{CassandraNode, ConnectionFactory};
use node_pool::{GetReplicaErr, KeyspaceMetadata, NodePool};
use rand::prelude::*;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::time::Duration;
Expand All @@ -47,7 +47,7 @@ const SYSTEM_KEYSPACES: [IdentifierRef<'static>; 3] = [
IdentifierRef::Quoted("system_distributed"),
];

#[derive(Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct CassandraSinkClusterConfig {
/// contact points must be within the specified data_center and rack.
Expand All @@ -61,7 +61,7 @@ pub struct CassandraSinkClusterConfig {
pub read_timeout: Option<u64>,
}

#[typetag::deserialize(name = "CassandraSinkCluster")]
#[typetag::serde(name = "CassandraSinkCluster")]
#[async_trait(?Send)]
impl TransformConfig for CassandraSinkClusterConfig {
async fn get_builder(&self, chain_name: String) -> Result<Box<dyn TransformBuilder>> {
Expand Down Expand Up @@ -181,7 +181,7 @@ impl TransformBuilder for CassandraSinkClusterBuilder {
}
}

#[derive(Deserialize, Debug, Clone)]
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(deny_unknown_fields)]
pub struct ShotoverNode {
pub address: SocketAddr,
Expand Down
6 changes: 3 additions & 3 deletions shotover/src/transforms/cassandra/sink_single.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ use async_trait::async_trait;
use cassandra_protocol::frame::Version;
use futures::stream::FuturesOrdered;
use metrics::{register_counter, Counter};
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tokio::sync::{mpsc, oneshot};
use tracing::trace;

#[derive(Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct CassandraSinkSingleConfig {
#[serde(rename = "remote_address")]
Expand All @@ -25,7 +25,7 @@ pub struct CassandraSinkSingleConfig {
pub read_timeout: Option<u64>,
}

#[typetag::deserialize(name = "CassandraSinkSingle")]
#[typetag::serde(name = "CassandraSinkSingle")]
#[async_trait(?Send)]
impl TransformConfig for CassandraSinkSingleConfig {
async fn get_builder(&self, chain_name: String) -> Result<Box<dyn TransformBuilder>> {
Expand Down
6 changes: 3 additions & 3 deletions shotover/src/transforms/coalesce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::message::Messages;
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper};
use anyhow::Result;
use async_trait::async_trait;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use std::time::Instant;

#[derive(Debug, Clone)]
Expand All @@ -13,14 +13,14 @@ pub struct Coalesce {
last_write: Instant,
}

#[derive(Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct CoalesceConfig {
pub flush_when_buffered_message_count: Option<usize>,
pub flush_when_millis_since_last_flush: Option<u128>,
}

#[typetag::deserialize(name = "Coalesce")]
#[typetag::serde(name = "Coalesce")]
#[async_trait(?Send)]
impl TransformConfig for CoalesceConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
Expand Down
10 changes: 5 additions & 5 deletions shotover/src/transforms/debug/force_parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,19 @@ use crate::transforms::TransformConfig;
use crate::transforms::{Transform, TransformBuilder, Transforms, Wrapper};
use anyhow::Result;
use async_trait::async_trait;
use serde::Deserialize;
use serde::{Deserialize, Serialize};

/// Messages that pass through this transform will be parsed.
/// Must be individually enabled at the request or response level.
#[derive(Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct DebugForceParseConfig {
pub parse_requests: bool,
pub parse_responses: bool,
}

#[cfg(feature = "alpha-transforms")]
#[typetag::deserialize(name = "DebugForceParse")]
#[typetag::serde(name = "DebugForceParse")]
#[async_trait(?Send)]
impl TransformConfig for DebugForceParseConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
Expand All @@ -37,15 +37,15 @@ impl TransformConfig for DebugForceParseConfig {

/// Messages that pass through this transform will be parsed and then reencoded.
/// Must be individually enabled at the request or response level.
#[derive(Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct DebugForceEncodeConfig {
pub encode_requests: bool,
pub encode_responses: bool,
}

#[cfg(feature = "alpha-transforms")]
#[typetag::deserialize(name = "DebugForceEncode")]
#[typetag::serde(name = "DebugForceEncode")]
#[async_trait(?Send)]
impl TransformConfig for DebugForceEncodeConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
Expand Down
6 changes: 3 additions & 3 deletions shotover/src/transforms/debug/log_to_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@ use crate::message::{Encodable, Message};
use crate::transforms::{Transform, TransformBuilder, Transforms, Wrapper};
use anyhow::{Context, Result};
use async_trait::async_trait;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tracing::{error, info};

#[derive(Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct DebugLogToFileConfig;

#[cfg(feature = "alpha-transforms")]
#[typetag::deserialize(name = "DebugLogToFile")]
#[typetag::serde(name = "DebugLogToFile")]
#[async_trait(?Send)]
impl crate::transforms::TransformConfig for DebugLogToFileConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
Expand Down
6 changes: 3 additions & 3 deletions shotover/src/transforms/debug/printer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ use crate::message::Messages;
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Transforms, Wrapper};
use anyhow::Result;
use async_trait::async_trait;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use tracing::info;

#[derive(Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct DebugPrinterConfig;

#[typetag::deserialize(name = "DebugPrinter")]
#[typetag::serde(name = "DebugPrinter")]
#[async_trait(?Send)]
impl TransformConfig for DebugPrinterConfig {
async fn get_builder(&self, _chain_name: String) -> Result<Box<dyn TransformBuilder>> {
Expand Down
Loading

0 comments on commit 261793b

Please sign in to comment.