diff --git a/shotover/src/codec/cassandra.rs b/shotover/src/codec/cassandra.rs index f14e5b004..bdfd32e7d 100644 --- a/shotover/src/codec/cassandra.rs +++ b/shotover/src/codec/cassandra.rs @@ -139,8 +139,8 @@ impl CodecBuilder for CassandraCodecBuilder { ) } - fn websocket_subprotocol(&self) -> &'static str { - "cql" + fn protocol(&self) -> MessageType { + MessageType::Cassandra } } diff --git a/shotover/src/codec/kafka.rs b/shotover/src/codec/kafka.rs index cc92681f9..85892a0ba 100644 --- a/shotover/src/codec/kafka.rs +++ b/shotover/src/codec/kafka.rs @@ -52,8 +52,8 @@ impl CodecBuilder for KafkaCodecBuilder { ) } - fn websocket_subprotocol(&self) -> &'static str { - "kafka" + fn protocol(&self) -> MessageType { + MessageType::Kafka } } diff --git a/shotover/src/codec/mod.rs b/shotover/src/codec/mod.rs index 8280c98a3..f750ab526 100644 --- a/shotover/src/codec/mod.rs +++ b/shotover/src/codec/mod.rs @@ -1,6 +1,6 @@ //! Codec types to use for connecting to a DB in a sink transform -use crate::message::Messages; +use crate::{frame::MessageType, message::Messages}; #[cfg(feature = "cassandra")] use cassandra_protocol::compression::Compression; use core::fmt; @@ -128,5 +128,5 @@ pub trait CodecBuilder: Clone + Send { fn new(direction: Direction, destination_name: String) -> Self; - fn websocket_subprotocol(&self) -> &'static str; + fn protocol(&self) -> MessageType; } diff --git a/shotover/src/codec/opensearch.rs b/shotover/src/codec/opensearch.rs index 1024cddaa..c83a6b6a5 100644 --- a/shotover/src/codec/opensearch.rs +++ b/shotover/src/codec/opensearch.rs @@ -56,8 +56,8 @@ impl CodecBuilder for OpenSearchCodecBuilder { ) } - fn websocket_subprotocol(&self) -> &'static str { - "opensearch" + fn protocol(&self) -> MessageType { + MessageType::OpenSearch } } diff --git a/shotover/src/codec/redis.rs b/shotover/src/codec/redis.rs index 3f8366dbf..13ce77835 100644 --- a/shotover/src/codec/redis.rs +++ b/shotover/src/codec/redis.rs @@ -44,8 +44,8 @@ impl CodecBuilder for RedisCodecBuilder { ) } - fn websocket_subprotocol(&self) -> &'static str { - "redis" + fn protocol(&self) -> MessageType { + MessageType::Redis } } diff --git a/shotover/src/frame/mod.rs b/shotover/src/frame/mod.rs index 6dc614697..465a46f1c 100644 --- a/shotover/src/frame/mod.rs +++ b/shotover/src/frame/mod.rs @@ -38,6 +38,36 @@ pub enum MessageType { OpenSearch, } +impl MessageType { + pub fn is_inorder(&self) -> bool { + match self { + #[cfg(feature = "cassandra")] + MessageType::Cassandra => false, + #[cfg(feature = "redis")] + MessageType::Redis => true, + #[cfg(feature = "kafka")] + MessageType::Kafka => true, + #[cfg(feature = "opensearch")] + MessageType::OpenSearch => true, + MessageType::Dummy => false, + } + } + + pub fn websocket_subprotocol(&self) -> &'static str { + match self { + #[cfg(feature = "cassandra")] + MessageType::Cassandra => "cql", + #[cfg(feature = "redis")] + MessageType::Redis => "redis", + #[cfg(feature = "kafka")] + MessageType::Kafka => "kafka", + #[cfg(feature = "opensearch")] + MessageType::OpenSearch => "opensearch", + MessageType::Dummy => "dummy", + } + } +} + impl From<&ProtocolType> for MessageType { fn from(value: &ProtocolType) -> Self { match value { diff --git a/shotover/src/server.rs b/shotover/src/server.rs index d943575fd..4da1c5fc8 100644 --- a/shotover/src/server.rs +++ b/shotover/src/server.rs @@ -594,7 +594,7 @@ impl Handler { match transport { Transport::WebSocket => { - let websocket_subprotocol = codec_builder.websocket_subprotocol(); + let websocket_subprotocol = codec_builder.protocol().websocket_subprotocol(); if let Some(tls) = &self.tls { let tls_stream = match tls.accept(stream).await {