Skip to content

Commit

Permalink
Add MessageId (progress towards cleaning up transform invariants)
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai committed Feb 14, 2024
1 parent 17f394d commit 26ad3d5
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 57 deletions.
29 changes: 29 additions & 0 deletions shotover/src/message/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ impl From<&ProtocolType> for CodecState {

pub type Messages = Vec<Message>;

/// Unique identifier for the message assigned by shotover at creation time.
pub type MessageId = u128;

/// Message holds a single message/query/result going between the client and database.
/// It is designed to efficiently abstract over the message being in various states of processing.
///
Expand Down Expand Up @@ -98,6 +101,9 @@ pub struct Message {
pub(crate) received_from_source_or_sink_at: Option<Instant>,

pub(crate) codec_state: CodecState,

pub(crate) id: MessageId,
pub(crate) request_id: Option<MessageId>,
}

// `from_*` methods for `Message`
Expand All @@ -118,6 +124,8 @@ impl Message {
meta_timestamp: None,
codec_state: CodecState::from(&protocol_type),
received_from_source_or_sink_at,
id: rand::random(),
request_id: None,
}
}

Expand All @@ -134,6 +142,8 @@ impl Message {
inner: Some(MessageInner::Parsed { bytes, frame }),
meta_timestamp: None,
received_from_source_or_sink_at,
id: rand::random(),
request_id: None,
}
}

Expand All @@ -149,6 +159,8 @@ impl Message {
inner: Some(MessageInner::Modified { frame }),
meta_timestamp: None,
received_from_source_or_sink_at,
id: rand::random(),
request_id: None,
}
}

Expand Down Expand Up @@ -194,6 +206,23 @@ impl Message {
}
}

/// Return the shotover assigned MessageId
pub fn id(&self) -> MessageId {
self.id
}

/// Return the MessageId of the request that resulted in this message
/// Returns None when:
/// * The message is a request
/// * The message is a response but was not created in response to a request. e.g. Cassandra events and redis pubsub
pub fn request_id(&self) -> Option<MessageId> {
self.request_id
}

pub fn set_request_id(&mut self, request_id: MessageId) {
self.request_id = Some(request_id);
}

pub fn ensure_message_type(&self, expected_message_type: MessageType) -> Result<()> {
match self.inner.as_ref().unwrap() {
MessageInner::RawBytes { message_type, .. } => {
Expand Down
23 changes: 14 additions & 9 deletions shotover/src/transforms/cassandra/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::codec::cassandra::{CassandraCodecBuilder, CassandraDecoder, Cassandra
use crate::codec::{CodecBuilder, CodecReadError};
use crate::frame::cassandra::CassandraMetadata;
use crate::frame::{CassandraFrame, Frame};
use crate::message::{Message, Metadata};
use crate::message::{Message, MessageId, Metadata};
use crate::tcp;
use crate::tls::{TlsConnector, ToHostname};
use crate::transforms::Messages;
Expand Down Expand Up @@ -52,6 +52,7 @@ impl ResponseError {
#[derive(Debug)]
struct ReturnChannel {
return_chan: oneshot::Sender<Response>,
request_id: MessageId,
stream_id: i16,
}

Expand Down Expand Up @@ -185,6 +186,7 @@ async fn tx_process<T: AsyncWrite>(
let mut connection_dead_error: Option<String> = None;
loop {
if let Some(request) = out_rx.recv().await {
let request_id = request.message.id();
if let Some(error) = &connection_dead_error {
send_error_to_request(request.return_chan, request.stream_id, destination, error);
} else if let Err(error) = in_w.send(vec![request.message]).await {
Expand All @@ -194,6 +196,7 @@ async fn tx_process<T: AsyncWrite>(
} else if let Err(mpsc::error::SendError(return_chan)) = return_tx.send(ReturnChannel {
return_chan: request.return_chan,
stream_id: request.stream_id,
request_id,
}) {
let error = rx_process_has_shutdown_rx
.try_recv()
Expand Down Expand Up @@ -270,7 +273,7 @@ async fn rx_process<T: AsyncRead>(
// In order to handle that we have two seperate maps.
//
// We store the sender here if we receive from the tx_process task first
let mut from_tx_process: HashMap<i16, oneshot::Sender<Response>> = HashMap::new();
let mut from_tx_process: HashMap<i16, (oneshot::Sender<Response>, MessageId)> = HashMap::new();

// We store the response message here if we receive from the server first.
let mut from_server: HashMap<i16, Message> = HashMap::new();
Expand All @@ -280,7 +283,7 @@ async fn rx_process<T: AsyncRead>(
response = reader.next() => {
match response {
Some(Ok(response)) => {
for m in response {
for mut m in response {
let meta = m.metadata();
if let Ok(Metadata::Cassandra(CassandraMetadata { opcode: Opcode::Event, .. })) = meta {
if let Some(pushed_messages_tx) = pushed_messages_tx.as_ref() {
Expand All @@ -291,7 +294,8 @@ async fn rx_process<T: AsyncRead>(
None => {
from_server.insert(stream_id, m);
},
Some(return_tx) => {
Some((return_tx, request_id)) => {
m.set_request_id(request_id);
return_tx.send(Ok(m)).ok();
}
}
Expand All @@ -318,12 +322,13 @@ async fn rx_process<T: AsyncRead>(
}
},
original_request = return_rx.recv() => {
if let Some(ReturnChannel { return_chan, stream_id }) = original_request {
if let Some(ReturnChannel { return_chan, stream_id,request_id }) = original_request {
match from_server.remove(&stream_id) {
None => {
from_tx_process.insert(stream_id, return_chan);
from_tx_process.insert(stream_id, (return_chan, request_id));
}
Some(m) => {
Some(mut m) => {
m.set_request_id(request_id);
return_chan.send(Ok(m)).ok();
}
}
Expand All @@ -341,7 +346,7 @@ async fn rx_process<T: AsyncRead>(

async fn send_errors_and_shutdown(
mut return_rx: mpsc::UnboundedReceiver<ReturnChannel>,
mut waiting: HashMap<i16, oneshot::Sender<Response>>,
mut waiting: HashMap<i16, (oneshot::Sender<Response>, MessageId)>,
rx_process_has_shutdown_tx: oneshot::Sender<String>,
destination: SocketAddr,
message: &str,
Expand All @@ -355,7 +360,7 @@ async fn send_errors_and_shutdown(

return_rx.close();

for (stream_id, return_tx) in waiting.drain() {
for (stream_id, (return_tx, _)) in waiting.drain() {
return_tx
.send(Err(ResponseError {
cause: anyhow!(message.to_owned()),
Expand Down
71 changes: 41 additions & 30 deletions shotover/src/transforms/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
//! Various types required for defining a transform
use crate::message::Messages;
use crate::message::{Message, MessageId, Messages};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use core::fmt;
use futures::Future;
use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
use std::iter::Rev;
use std::net::SocketAddr;
Expand Down Expand Up @@ -187,6 +188,12 @@ impl<'a> Wrapper<'a> {
result
}

pub fn clone_requests_into_hashmap(&self, destination: &mut HashMap<MessageId, Message>) {
for request in &self.requests {
destination.insert(request.id(), request.clone());
}
}

#[cfg(test)]
pub fn new_test(requests: Messages) -> Self {
Wrapper {
Expand Down Expand Up @@ -271,46 +278,50 @@ impl<'a> Wrapper<'a> {
/// Implementing this trait is usually done using `#[async_trait]` macros.
#[async_trait]
pub trait Transform: Send {
/// This method should be implemented by your transform. The wrapper object contains the queries/
/// frames in a [`Vec<Message>`](crate::message::Message). Some protocols support multiple queries before a response is expected
/// for example pipelined Redis queries or batched Cassandra queries.
/// In order to implement your transform you can modify the messages:
/// * contained in requests_wrapper.requests
/// + these are the requests that will flow into the next transform in the chain.
/// * contained in the return value of `requests_wrapper.call_next_transform()`
/// + These are the responses that will flow back to the previous transform in the chain.
///
/// Shotover expects the same number of messages in [`wrapper.requests`](crate::transforms::Wrapper) to be returned as was passed
/// into the method via the parameter requests_wrapper. For in order protocols (such as Redis) you will
/// also need to ensure the order of responses matches the order of the queries.
/// But while doing so, also make sure to follow the below invariants when modifying the messages.
///
/// You can modify the messages in the wrapper struct to achieve your own designs. Your transform
/// can also modify the response from `requests_wrapper.call_next_transform()` if it needs
/// to. As long as you return the same number of messages as you received, you won't break behavior
/// from other transforms.
/// # Invariants
///
/// ## Invariants
/// Your transform method at a minimum needs to
/// * _Non-terminating_ - If your transform does not send the message to an external system or generate its own response to the query,
/// it will need to call and return the response from `requests_wrapper.call_next_transform()`. This ensures that your
/// transform will call any subsequent downstream transforms without needing to know about what they
/// do. This type of transform is called an non-terminating transform.
/// * _Terminating_ - Your transform can also choose not to call `requests_wrapper.call_next_transform()` if it sends the
/// * Non-terminating specific invariants
/// + If your transform does not send the message to an external system or generate its own response to the query,
/// it will need to call and return the response from `requests_wrapper.call_next_transform()`.
/// + This ensures that your transform will call any subsequent downstream transforms without needing to know about what they
/// do. This type of transform is called a non-terminating transform.
///
/// * Terminating specific invariants
/// + Your transform can also choose not to call `requests_wrapper.call_next_transform()` if it sends the
/// messages to an external system or generates its own response to the query e.g.
/// [`crate::transforms::cassandra::sink_single::CassandraSinkSingle`]. This type of transform
/// is called a Terminating transform (as no subsequent transforms in the chain will be called).
/// * _Message count_ - requests_wrapper.requests will contain 0 or more messages.
/// Your transform should return the same number of responses as messages received in requests_wrapper.requests. Transform that
/// don't do this explicitly for each call, should return the same number of responses as messages it receives over the lifetime
/// of the transform chain. A good example of this is the [`crate::transforms::coalesce::Coalesce`] transform. The
/// [`crate::transforms::sampler::Sampler`] transform is also another example of this, with a slightly different twist.
/// The number of responses will be the sames as the number of messages, as the sampled messages are sent to a subchain rather than
/// changing the behavior of the main chain.
/// [`crate::transforms::cassandra::sink_single::CassandraSinkSingle`].
/// + This type of transform is called a Terminating transform (as no subsequent transforms in the chain will be called).
///
/// * Request/Response invariants:
/// + Transforms must ensure that each request that passes through the transform has a corresponding response returned for it.
/// - A response/request pair can be identified by calling `request_id()` on a response and matching that to the id of a previous request.
/// - The response does not need to be returned within the same call to [`Transform::transform`] that the request was encountered.
/// But it must be returned eventually over the lifetime of the transform.
/// - If a transform deletes a request it must return a dummy frame message with its request_id set to the deleted request.
/// * Deprecated invariants:
/// + Many transforms rely on the number of responses equalling the number of requests and that requests will be in the same order as the responses.
/// Currently shotover maintains this gaurantee for backwards compatibility
/// but the gaurantee will be removed as soon as the transforms have been altered to no longer rely on it.
///
/// ## Naming
/// # Naming
/// Transform also have different naming conventions.
/// * Transform that interact with an external system are called Sinks.
/// * Transform that don't call subsequent chains via `requests_wrapper.call_next_transform()` are called terminating transforms.
/// * Transform that do call subsquent chains via `requests_wrapper.call_next_transform()` are non-terminating transforms.
///
/// You can have have a transforms that is both non-terminating and a sink.
/// You can have have a transform that is both non-terminating and a sink.
async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result<Messages>;

/// TODO: This method should be removed and integrated with `Transform::transform` once we properly support out of order protocols.
///
/// This method should be should be implemented by your transform if it is required to process pushed messages (typically events
/// or messages that your source is subscribed to. The wrapper object contains the queries/frames
/// in a [`Vec<Message`](crate::message::Message).
Expand All @@ -323,7 +334,7 @@ pub trait Transform: Send {
/// carries on through the chain, it will function correctly. You are able to add or remove messages as this method is not expecting
/// request/response pairs.
///
/// ## Invariants
/// # Invariants
/// * _Non-terminating_ - Your `transform_pushed` method should not be terminating as the messages should get passed back to the source, where they will terminate.
async fn transform_pushed<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result<Messages> {
let response = requests_wrapper.call_next_transform_pushed().await?;
Expand Down
42 changes: 24 additions & 18 deletions shotover/src/transforms/protect/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::frame::{
value::GenericValue, CassandraFrame, CassandraOperation, CassandraResult, Frame,
};
use crate::message::Messages;
use crate::message::{Message, MessageId, Messages};
use crate::transforms::protect::key_management::KeyManager;
pub use crate::transforms::protect::key_management::KeyManagerConfig;
use crate::transforms::{Transform, TransformBuilder, Wrapper};
Expand Down Expand Up @@ -56,6 +56,7 @@ impl TransformConfig for ProtectConfig {
.collect(),
key_source: self.key_manager.build().await?,
key_id: "XXXXXXX".to_string(),
requests: HashMap::new(),
}))
}
}
Expand All @@ -68,6 +69,7 @@ pub struct Protect {
// TODO this should be a function to create key_ids based on "something", e.g. primary key
// for the moment this is just a string
key_id: String,
requests: HashMap<MessageId, Message>,
}

impl TransformBuilder for Protect {
Expand Down Expand Up @@ -187,27 +189,31 @@ impl Transform for Protect {
}
}

let mut original_messages = requests_wrapper.requests.clone();
let mut result = requests_wrapper.call_next_transform().await?;

for (response, request) in result.iter_mut().zip(original_messages.iter_mut()) {
let mut invalidate_cache = false;
if let Some(Frame::Cassandra(CassandraFrame { operation, .. })) = request.frame() {
if let Some(Frame::Cassandra(CassandraFrame {
operation: CassandraOperation::Result(CassandraResult::Rows { rows, .. }),
..
})) = response.frame()
{
for statement in operation.queries() {
invalidate_cache |= self.decrypt_results(statement, rows).await?
requests_wrapper.clone_requests_into_hashmap(&mut self.requests);
let mut responses = requests_wrapper.call_next_transform().await?;

for response in &mut responses {
if let Some(request_id) = response.request_id() {
let mut request = self.requests.remove(&request_id).unwrap();

let mut invalidate_cache = false;
if let Some(Frame::Cassandra(CassandraFrame { operation, .. })) = request.frame() {
if let Some(Frame::Cassandra(CassandraFrame {
operation: CassandraOperation::Result(CassandraResult::Rows { rows, .. }),
..
})) = response.frame()
{
for statement in operation.queries() {
invalidate_cache |= self.decrypt_results(statement, rows).await?
}
}
}
}
if invalidate_cache {
response.invalidate_cache();
if invalidate_cache {
response.invalidate_cache();
}
}
}

Ok(result)
Ok(responses)
}
}

0 comments on commit 26ad3d5

Please sign in to comment.