Skip to content

Commit

Permalink
Add force_run_chain Notify (#1525)
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai authored Mar 14, 2024
1 parent 5523020 commit 97f027a
Show file tree
Hide file tree
Showing 31 changed files with 180 additions and 91 deletions.
3 changes: 2 additions & 1 deletion custom-transforms-example/src/redis_get_rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use shotover::frame::{Frame, RedisFrame};
use shotover::message::{MessageIdSet, Messages};
use shotover::transforms::TransformContextBuilder;
use shotover::transforms::{
Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper,
};
Expand Down Expand Up @@ -33,7 +34,7 @@ pub struct RedisGetRewriteBuilder {
}

impl TransformBuilder for RedisGetRewriteBuilder {
fn build(&self) -> Box<dyn Transform> {
fn build(&self, _transform_context: TransformContextBuilder) -> Box<dyn Transform> {
Box::new(RedisGetRewrite {
get_requests: MessageIdSet::default(),
result: self.result.clone(),
Expand Down
24 changes: 13 additions & 11 deletions shotover/benches/benches/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ use shotover::transforms::protect::{KeyManagerConfig, ProtectConfig};
use shotover::transforms::redis::cluster_ports_rewrite::RedisClusterPortsRewrite;
use shotover::transforms::redis::timestamp_tagging::RedisTimestampTagger;
use shotover::transforms::throttling::RequestThrottlingConfig;
use shotover::transforms::{TransformConfig, TransformContextConfig, Wrapper};
use shotover::transforms::{
TransformConfig, TransformContextBuilder, TransformContextConfig, Wrapper,
};

fn criterion_benchmark(c: &mut Criterion) {
crate::init();
Expand All @@ -38,7 +40,7 @@ fn criterion_benchmark(c: &mut Criterion) {
group.bench_function("loopback", |b| {
b.to_async(&rt).iter_batched(
|| BenchInput {
chain: chain.build(),
chain: chain.build(TransformContextBuilder::new()),
wrapper: wrapper.clone(),
},
BenchInput::bench,
Expand All @@ -57,7 +59,7 @@ fn criterion_benchmark(c: &mut Criterion) {
group.bench_function("nullsink", |b| {
b.to_async(&rt).iter_batched(
|| BenchInput {
chain: chain.build(),
chain: chain.build(TransformContextBuilder::new()),
wrapper: wrapper.clone(),
},
BenchInput::bench,
Expand Down Expand Up @@ -95,7 +97,7 @@ fn criterion_benchmark(c: &mut Criterion) {
group.bench_function("redis_filter", |b| {
b.to_async(&rt).iter_batched(
|| BenchInput {
chain: chain.build(),
chain: chain.build(TransformContextBuilder::new()),
wrapper: wrapper.clone(),
},
BenchInput::bench,
Expand Down Expand Up @@ -129,7 +131,7 @@ fn criterion_benchmark(c: &mut Criterion) {
group.bench_function("redis_timestamp_tagger_untagged", |b| {
b.to_async(&rt).iter_batched(
|| BenchInput {
chain: chain.build(),
chain: chain.build(TransformContextBuilder::new()),
wrapper: wrapper_set.clone(),
},
BenchInput::bench,
Expand All @@ -148,7 +150,7 @@ fn criterion_benchmark(c: &mut Criterion) {
group.bench_function("redis_timestamp_tagger_tagged", |b| {
b.to_async(&rt).iter_batched(
|| BenchInput {
chain: chain.build(),
chain: chain.build(TransformContextBuilder::new()),
wrapper: wrapper_get.clone(),
},
BenchInput::bench,
Expand Down Expand Up @@ -177,7 +179,7 @@ fn criterion_benchmark(c: &mut Criterion) {
group.bench_function("redis_cluster_ports_rewrite", |b| {
b.to_async(&rt).iter_batched(
|| BenchInput {
chain: chain.build(),
chain: chain.build(TransformContextBuilder::new()),
wrapper: wrapper.clone(),
},
BenchInput::bench,
Expand Down Expand Up @@ -224,7 +226,7 @@ fn criterion_benchmark(c: &mut Criterion) {
group.bench_function("cassandra_request_throttling_unparsed", |b| {
b.to_async(&rt).iter_batched(
|| BenchInput {
chain: chain.build(),
chain: chain.build(TransformContextBuilder::new()),
wrapper: wrapper.clone(),
},
BenchInput::bench,
Expand Down Expand Up @@ -278,7 +280,7 @@ fn criterion_benchmark(c: &mut Criterion) {
group.bench_function("cassandra_rewrite_peers_passthrough", |b| {
b.to_async(&rt).iter_batched(
|| BenchInput {
chain: chain.build(),
chain: chain.build(TransformContextBuilder::new()),
wrapper: wrapper.clone(),
},
BenchInput::bench,
Expand Down Expand Up @@ -324,7 +326,7 @@ fn criterion_benchmark(c: &mut Criterion) {
group.bench_function("cassandra_protect_unprotected", |b| {
b.to_async(&rt).iter_batched(
|| BenchInput {
chain: chain.build(),
chain: chain.build(TransformContextBuilder::new()),
wrapper: wrapper.clone(),
},
BenchInput::bench,
Expand All @@ -339,7 +341,7 @@ fn criterion_benchmark(c: &mut Criterion) {
group.bench_function("cassandra_protect_protected", |b| {
b.to_async(&rt).iter_batched(
|| BenchInput {
chain: chain.build(),
chain: chain.build(TransformContextBuilder::new()),
wrapper: wrapper.clone(),
},
BenchInput::bench,
Expand Down
50 changes: 35 additions & 15 deletions shotover/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::message::{Message, Messages};
use crate::sources::Transport;
use crate::tls::{AcceptError, TlsAcceptor};
use crate::transforms::chain::{TransformChain, TransformChainBuilder};
use crate::transforms::{TransformContextConfig, Wrapper};
use crate::transforms::{TransformContextBuilder, TransformContextConfig, Wrapper};
use anyhow::{anyhow, Context, Result};
use bytes::BytesMut;
use futures::future::join_all;
Expand All @@ -16,7 +16,7 @@ use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio::sync::{mpsc, watch, OwnedSemaphorePermit, Semaphore};
use tokio::sync::{mpsc, watch, Notify, OwnedSemaphorePermit, Semaphore};
use tokio::task::JoinHandle;
use tokio::time;
use tokio::time::Duration;
Expand Down Expand Up @@ -190,10 +190,15 @@ impl<C: CodecBuilder + 'static> TcpCodecListener<C> {
let (pushed_messages_tx, pushed_messages_rx) =
tokio::sync::mpsc::unbounded_channel::<Messages>();

let force_run_chain = Arc::new(Notify::new());
let context = TransformContextBuilder {
force_run_chain: force_run_chain.clone(),
};

let handler = Handler {
chain: self
.chain_builder
.build_with_pushed_messages(pushed_messages_tx),
.build_with_pushed_messages(pushed_messages_tx, context),
codec: self.codec.clone(),
shutdown: Shutdown::new(self.trigger_shutdown_rx.clone()),
tls: self.tls.clone(),
Expand All @@ -206,7 +211,7 @@ impl<C: CodecBuilder + 'static> TcpCodecListener<C> {
self.connection_handles.push(tokio::spawn(
async move {
// Process the connection. If an error is encountered, log it.
if let Err(err) = handler.run(stream, transport).await {
if let Err(err) = handler.run(stream, transport, force_run_chain).await {
error!(
"{:?}",
err.context("connection was unexpectedly terminated")
Expand Down Expand Up @@ -576,7 +581,12 @@ impl<C: CodecBuilder + 'static> Handler<C> {
///
/// When the shutdown signal is received, the connection is processed until
/// it reaches a safe state, at which point it is terminated.
pub async fn run(mut self, stream: TcpStream, transport: Transport) -> Result<()> {
pub async fn run(
mut self,
stream: TcpStream,
transport: Transport,
force_run_chain: Arc<Notify>,
) -> Result<()> {
stream.set_nodelay(true)?;

let client_details = stream
Expand Down Expand Up @@ -658,7 +668,7 @@ impl<C: CodecBuilder + 'static> Handler<C> {
};

let result = self
.process_messages(&client_details, local_addr, in_rx, out_tx)
.process_messages(&client_details, local_addr, in_rx, out_tx, force_run_chain)
.await;

// Flush messages regardless of if we are shutting down due to a failure or due to application shutdown
Expand Down Expand Up @@ -700,13 +710,32 @@ impl<C: CodecBuilder + 'static> Handler<C> {
local_addr: SocketAddr,
mut in_rx: mpsc::Receiver<Messages>,
out_tx: mpsc::UnboundedSender<Messages>,
force_run_chain: Arc<Notify>,
) -> Result<()> {
// As long as the shutdown signal has not been received, try to read a
// new request frame.
while !self.shutdown.is_shutdown() {
// While reading a request frame, also listen for the shutdown signal
debug!("Waiting for message {client_details}");
let responses = tokio::select! {
biased;
_ = self.shutdown.recv() => {
// If a shutdown signal is received, return from `run`.
// This will result in the task terminating.
return Ok(());
}
Some(responses) = self.pushed_messages_rx.recv() => {
debug!("Received unrequested responses from destination {:?}", responses);
self.process_backward(client_details, local_addr, responses).await?
}
() = force_run_chain.notified() => {
let mut requests = vec!();
while let Ok(x) = in_rx.try_recv() {
requests.extend(x);
}
debug!("A transform in the chain requested that a chain run occur, requests {:?}", requests);
self.process_forward(client_details, local_addr, &out_tx, requests).await?
},
requests = Self::receive_with_timeout(self.timeout, &mut in_rx, client_details) => {
match requests {
Some(mut requests) => {
Expand All @@ -722,15 +751,6 @@ impl<C: CodecBuilder + 'static> Handler<C> {
}
}
},
Some(responses) = self.pushed_messages_rx.recv() => {
debug!("Received unrequested responses from destination {:?}", responses);
self.process_backward(client_details, local_addr, responses).await?
}
_ = self.shutdown.recv() => {
// If a shutdown signal is received, return from `run`.
// This will result in the task terminating.
return Ok(());
}
};

debug!("sending response to client: {:?}", responses);
Expand Down
6 changes: 4 additions & 2 deletions shotover/src/transforms/cassandra/peers_rewrite.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::message::{Message, MessageIdMap, Messages};
use crate::transforms::cassandra::peers_rewrite::CassandraOperation::Event;
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use crate::transforms::{
Transform, TransformBuilder, TransformConfig, TransformContextBuilder, Wrapper,
};
use crate::{
frame::{
value::{GenericValue, IntSize},
Expand Down Expand Up @@ -52,7 +54,7 @@ impl CassandraPeersRewrite {
}

impl TransformBuilder for CassandraPeersRewrite {
fn build(&self) -> Box<dyn Transform> {
fn build(&self, _transform_context: TransformContextBuilder) -> Box<dyn Transform> {
Box::new(self.clone())
}

Expand Down
5 changes: 3 additions & 2 deletions shotover/src/transforms/cassandra/sink_cluster/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ use crate::message::{Message, MessageIdMap, Messages, Metadata};
use crate::tls::{TlsConnector, TlsConnectorConfig};
use crate::transforms::cassandra::connection::{CassandraConnection, Response, ResponseError};
use crate::transforms::{
Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper,
Transform, TransformBuilder, TransformConfig, TransformContextBuilder, TransformContextConfig,
Wrapper,
};
use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
Expand Down Expand Up @@ -160,7 +161,7 @@ impl CassandraSinkClusterBuilder {
}

impl TransformBuilder for CassandraSinkClusterBuilder {
fn build(&self) -> Box<dyn Transform> {
fn build(&self, _transform_context: TransformContextBuilder) -> Box<dyn Transform> {
Box::new(CassandraSinkCluster {
contact_points: self.contact_points.clone(),
message_rewriter: self.message_rewriter.clone(),
Expand Down
5 changes: 3 additions & 2 deletions shotover/src/transforms/cassandra/sink_single.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use crate::message::{Messages, Metadata};
use crate::tls::{TlsConnector, TlsConnectorConfig};
use crate::transforms::cassandra::connection::Response;
use crate::transforms::{
Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper,
Transform, TransformBuilder, TransformConfig, TransformContextBuilder, TransformContextConfig,
Wrapper,
};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
Expand Down Expand Up @@ -83,7 +84,7 @@ impl CassandraSinkSingleBuilder {
}

impl TransformBuilder for CassandraSinkSingleBuilder {
fn build(&self) -> Box<dyn Transform> {
fn build(&self, _transform_context: TransformContextBuilder) -> Box<dyn Transform> {
Box::new(CassandraSinkSingle {
outbound: None,
version: self.version,
Expand Down
25 changes: 18 additions & 7 deletions shotover/src/transforms/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use tokio::sync::{mpsc, oneshot};
use tokio::time::{Duration, Instant};
use tracing::{debug, error, info, trace, Instrument};

use super::TransformContextBuilder;

type InnerChain = Vec<TransformAndMetrics>;

#[derive(Debug)]
Expand Down Expand Up @@ -233,9 +235,9 @@ pub struct TransformBuilderAndMetrics {
}

impl TransformBuilderAndMetrics {
fn build(&self) -> TransformAndMetrics {
fn build(&self, context: TransformContextBuilder) -> TransformAndMetrics {
TransformAndMetrics {
transform: self.builder.build(),
transform: self.builder.build(context),
transform_total: self.transform_total.clone(),
transform_failures: self.transform_failures.clone(),
transform_latency: self.transform_latency.clone(),
Expand Down Expand Up @@ -331,7 +333,11 @@ impl TransformChainBuilder {
errors
}

pub fn build_buffered(&self, buffer_size: usize) -> BufferedChain {
pub fn build_buffered(
&self,
buffer_size: usize,
context: TransformContextBuilder,
) -> BufferedChain {
let (tx, mut rx) = mpsc::channel::<BufferedChainMessages>(buffer_size);

#[cfg(test)]
Expand All @@ -341,7 +347,7 @@ impl TransformChainBuilder {

// Even though we don't keep the join handle, this thread will wrap up once all corresponding senders have been dropped.

let mut chain = self.build();
let mut chain = self.build(context);
let _jh = tokio::spawn(
async move {
while let Some(BufferedChainMessages {
Expand Down Expand Up @@ -398,8 +404,12 @@ impl TransformChainBuilder {
}

/// Clone the chain while adding a producer for the pushed messages channel
pub fn build(&self) -> TransformChain {
let chain = self.chain.iter().map(|x| x.build()).collect();
pub fn build(&self, context: TransformContextBuilder) -> TransformChain {
let chain = self
.chain
.iter()
.map(|x| x.build(context.clone()))
.collect();

TransformChain {
name: self.name,
Expand All @@ -414,12 +424,13 @@ impl TransformChainBuilder {
pub fn build_with_pushed_messages(
&self,
pushed_messages_tx: mpsc::UnboundedSender<Messages>,
context: TransformContextBuilder,
) -> TransformChain {
let chain = self
.chain
.iter()
.map(|x| {
let mut transform = x.build();
let mut transform = x.build(context.clone());
transform
.transform
.set_pushed_messages_tx(pushed_messages_tx.clone());
Expand Down
4 changes: 2 additions & 2 deletions shotover/src/transforms/coalesce.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::TransformContextConfig;
use super::{TransformContextBuilder, TransformContextConfig};
use crate::message::Messages;
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use anyhow::Result;
Expand Down Expand Up @@ -39,7 +39,7 @@ impl TransformConfig for CoalesceConfig {
}

impl TransformBuilder for Coalesce {
fn build(&self) -> Box<dyn Transform> {
fn build(&self, _transform_context: TransformContextBuilder) -> Box<dyn Transform> {
Box::new(self.clone())
}

Expand Down
3 changes: 2 additions & 1 deletion shotover/src/transforms/debug/force_parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::message::Messages;
/// It could also be used to ensure that messages round trip correctly when parsed.
#[cfg(feature = "alpha-transforms")]
use crate::transforms::TransformConfig;
use crate::transforms::TransformContextBuilder;
#[cfg(feature = "alpha-transforms")]
use crate::transforms::TransformContextConfig;
use crate::transforms::{Transform, TransformBuilder, Wrapper};
Expand Down Expand Up @@ -76,7 +77,7 @@ pub struct DebugForceParse {
}

impl TransformBuilder for DebugForceParse {
fn build(&self) -> Box<dyn Transform> {
fn build(&self, _transform_context: TransformContextBuilder) -> Box<dyn Transform> {
Box::new(self.clone())
}

Expand Down
Loading

0 comments on commit 97f027a

Please sign in to comment.