Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add force_run_chain Notify #1525

Merged
merged 2 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion custom-transforms-example/src/redis_get_rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use shotover::frame::{Frame, RedisFrame};
use shotover::message::{MessageIdSet, Messages};
use shotover::transforms::TransformContextBuilder;
use shotover::transforms::{
Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper,
};
Expand Down Expand Up @@ -33,7 +34,7 @@ pub struct RedisGetRewriteBuilder {
}

impl TransformBuilder for RedisGetRewriteBuilder {
fn build(&self) -> Box<dyn Transform> {
fn build(&self, _transform_context: TransformContextBuilder) -> Box<dyn Transform> {
Box::new(RedisGetRewrite {
get_requests: MessageIdSet::default(),
result: self.result.clone(),
Expand Down
24 changes: 13 additions & 11 deletions shotover/benches/benches/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ use shotover::transforms::protect::{KeyManagerConfig, ProtectConfig};
use shotover::transforms::redis::cluster_ports_rewrite::RedisClusterPortsRewrite;
use shotover::transforms::redis::timestamp_tagging::RedisTimestampTagger;
use shotover::transforms::throttling::RequestThrottlingConfig;
use shotover::transforms::{TransformConfig, TransformContextConfig, Wrapper};
use shotover::transforms::{
TransformConfig, TransformContextBuilder, TransformContextConfig, Wrapper,
};

fn criterion_benchmark(c: &mut Criterion) {
crate::init();
Expand All @@ -38,7 +40,7 @@ fn criterion_benchmark(c: &mut Criterion) {
group.bench_function("loopback", |b| {
b.to_async(&rt).iter_batched(
|| BenchInput {
chain: chain.build(),
chain: chain.build(TransformContextBuilder::new()),
wrapper: wrapper.clone(),
},
BenchInput::bench,
Expand All @@ -57,7 +59,7 @@ fn criterion_benchmark(c: &mut Criterion) {
group.bench_function("nullsink", |b| {
b.to_async(&rt).iter_batched(
|| BenchInput {
chain: chain.build(),
chain: chain.build(TransformContextBuilder::new()),
wrapper: wrapper.clone(),
},
BenchInput::bench,
Expand Down Expand Up @@ -95,7 +97,7 @@ fn criterion_benchmark(c: &mut Criterion) {
group.bench_function("redis_filter", |b| {
b.to_async(&rt).iter_batched(
|| BenchInput {
chain: chain.build(),
chain: chain.build(TransformContextBuilder::new()),
wrapper: wrapper.clone(),
},
BenchInput::bench,
Expand Down Expand Up @@ -129,7 +131,7 @@ fn criterion_benchmark(c: &mut Criterion) {
group.bench_function("redis_timestamp_tagger_untagged", |b| {
b.to_async(&rt).iter_batched(
|| BenchInput {
chain: chain.build(),
chain: chain.build(TransformContextBuilder::new()),
wrapper: wrapper_set.clone(),
},
BenchInput::bench,
Expand All @@ -148,7 +150,7 @@ fn criterion_benchmark(c: &mut Criterion) {
group.bench_function("redis_timestamp_tagger_tagged", |b| {
b.to_async(&rt).iter_batched(
|| BenchInput {
chain: chain.build(),
chain: chain.build(TransformContextBuilder::new()),
wrapper: wrapper_get.clone(),
},
BenchInput::bench,
Expand Down Expand Up @@ -177,7 +179,7 @@ fn criterion_benchmark(c: &mut Criterion) {
group.bench_function("redis_cluster_ports_rewrite", |b| {
b.to_async(&rt).iter_batched(
|| BenchInput {
chain: chain.build(),
chain: chain.build(TransformContextBuilder::new()),
wrapper: wrapper.clone(),
},
BenchInput::bench,
Expand Down Expand Up @@ -224,7 +226,7 @@ fn criterion_benchmark(c: &mut Criterion) {
group.bench_function("cassandra_request_throttling_unparsed", |b| {
b.to_async(&rt).iter_batched(
|| BenchInput {
chain: chain.build(),
chain: chain.build(TransformContextBuilder::new()),
wrapper: wrapper.clone(),
},
BenchInput::bench,
Expand Down Expand Up @@ -278,7 +280,7 @@ fn criterion_benchmark(c: &mut Criterion) {
group.bench_function("cassandra_rewrite_peers_passthrough", |b| {
b.to_async(&rt).iter_batched(
|| BenchInput {
chain: chain.build(),
chain: chain.build(TransformContextBuilder::new()),
wrapper: wrapper.clone(),
},
BenchInput::bench,
Expand Down Expand Up @@ -324,7 +326,7 @@ fn criterion_benchmark(c: &mut Criterion) {
group.bench_function("cassandra_protect_unprotected", |b| {
b.to_async(&rt).iter_batched(
|| BenchInput {
chain: chain.build(),
chain: chain.build(TransformContextBuilder::new()),
wrapper: wrapper.clone(),
},
BenchInput::bench,
Expand All @@ -339,7 +341,7 @@ fn criterion_benchmark(c: &mut Criterion) {
group.bench_function("cassandra_protect_protected", |b| {
b.to_async(&rt).iter_batched(
|| BenchInput {
chain: chain.build(),
chain: chain.build(TransformContextBuilder::new()),
wrapper: wrapper.clone(),
},
BenchInput::bench,
Expand Down
50 changes: 35 additions & 15 deletions shotover/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::message::{Message, Messages};
use crate::sources::Transport;
use crate::tls::{AcceptError, TlsAcceptor};
use crate::transforms::chain::{TransformChain, TransformChainBuilder};
use crate::transforms::{TransformContextConfig, Wrapper};
use crate::transforms::{TransformContextBuilder, TransformContextConfig, Wrapper};
use anyhow::{anyhow, Context, Result};
use bytes::BytesMut;
use futures::future::join_all;
Expand All @@ -16,7 +16,7 @@ use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio::sync::{mpsc, watch, OwnedSemaphorePermit, Semaphore};
use tokio::sync::{mpsc, watch, Notify, OwnedSemaphorePermit, Semaphore};
use tokio::task::JoinHandle;
use tokio::time;
use tokio::time::Duration;
Expand Down Expand Up @@ -190,10 +190,15 @@ impl<C: CodecBuilder + 'static> TcpCodecListener<C> {
let (pushed_messages_tx, pushed_messages_rx) =
tokio::sync::mpsc::unbounded_channel::<Messages>();

let force_run_chain = Arc::new(Notify::new());
let context = TransformContextBuilder {
force_run_chain: force_run_chain.clone(),
};

let handler = Handler {
chain: self
.chain_builder
.build_with_pushed_messages(pushed_messages_tx),
.build_with_pushed_messages(pushed_messages_tx, context),
codec: self.codec.clone(),
shutdown: Shutdown::new(self.trigger_shutdown_rx.clone()),
tls: self.tls.clone(),
Expand All @@ -206,7 +211,7 @@ impl<C: CodecBuilder + 'static> TcpCodecListener<C> {
self.connection_handles.push(tokio::spawn(
async move {
// Process the connection. If an error is encountered, log it.
if let Err(err) = handler.run(stream, transport).await {
if let Err(err) = handler.run(stream, transport, force_run_chain).await {
error!(
"{:?}",
err.context("connection was unexpectedly terminated")
Expand Down Expand Up @@ -576,7 +581,12 @@ impl<C: CodecBuilder + 'static> Handler<C> {
///
/// When the shutdown signal is received, the connection is processed until
/// it reaches a safe state, at which point it is terminated.
pub async fn run(mut self, stream: TcpStream, transport: Transport) -> Result<()> {
pub async fn run(
mut self,
stream: TcpStream,
transport: Transport,
force_run_chain: Arc<Notify>,
) -> Result<()> {
stream.set_nodelay(true)?;

let client_details = stream
Expand Down Expand Up @@ -658,7 +668,7 @@ impl<C: CodecBuilder + 'static> Handler<C> {
};

let result = self
.process_messages(&client_details, local_addr, in_rx, out_tx)
.process_messages(&client_details, local_addr, in_rx, out_tx, force_run_chain)
.await;

// Flush messages regardless of if we are shutting down due to a failure or due to application shutdown
Expand Down Expand Up @@ -700,13 +710,32 @@ impl<C: CodecBuilder + 'static> Handler<C> {
local_addr: SocketAddr,
mut in_rx: mpsc::Receiver<Messages>,
out_tx: mpsc::UnboundedSender<Messages>,
force_run_chain: Arc<Notify>,
) -> Result<()> {
// As long as the shutdown signal has not been received, try to read a
// new request frame.
while !self.shutdown.is_shutdown() {
// While reading a request frame, also listen for the shutdown signal
debug!("Waiting for message {client_details}");
let responses = tokio::select! {
biased;
_ = self.shutdown.recv() => {
// If a shutdown signal is received, return from `run`.
// This will result in the task terminating.
return Ok(());
}
Some(responses) = self.pushed_messages_rx.recv() => {
debug!("Received unrequested responses from destination {:?}", responses);
self.process_backward(client_details, local_addr, responses).await?
}
() = force_run_chain.notified() => {
let mut requests = vec!();
while let Ok(x) = in_rx.try_recv() {
requests.extend(x);
}
debug!("A transform in the chain requested that a chain run occur, requests {:?}", requests);
self.process_forward(client_details, local_addr, &out_tx, requests).await?
},
requests = Self::receive_with_timeout(self.timeout, &mut in_rx, client_details) => {
match requests {
Some(mut requests) => {
Expand All @@ -722,15 +751,6 @@ impl<C: CodecBuilder + 'static> Handler<C> {
}
}
},
Some(responses) = self.pushed_messages_rx.recv() => {
debug!("Received unrequested responses from destination {:?}", responses);
self.process_backward(client_details, local_addr, responses).await?
}
_ = self.shutdown.recv() => {
// If a shutdown signal is received, return from `run`.
// This will result in the task terminating.
return Ok(());
}
};

debug!("sending response to client: {:?}", responses);
Expand Down
6 changes: 4 additions & 2 deletions shotover/src/transforms/cassandra/peers_rewrite.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::message::{Message, MessageIdMap, Messages};
use crate::transforms::cassandra::peers_rewrite::CassandraOperation::Event;
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use crate::transforms::{
Transform, TransformBuilder, TransformConfig, TransformContextBuilder, Wrapper,
};
use crate::{
frame::{
value::{GenericValue, IntSize},
Expand Down Expand Up @@ -52,7 +54,7 @@ impl CassandraPeersRewrite {
}

impl TransformBuilder for CassandraPeersRewrite {
fn build(&self) -> Box<dyn Transform> {
fn build(&self, _transform_context: TransformContextBuilder) -> Box<dyn Transform> {
Box::new(self.clone())
}

Expand Down
5 changes: 3 additions & 2 deletions shotover/src/transforms/cassandra/sink_cluster/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ use crate::message::{Message, MessageIdMap, Messages, Metadata};
use crate::tls::{TlsConnector, TlsConnectorConfig};
use crate::transforms::cassandra::connection::{CassandraConnection, Response, ResponseError};
use crate::transforms::{
Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper,
Transform, TransformBuilder, TransformConfig, TransformContextBuilder, TransformContextConfig,
Wrapper,
};
use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
Expand Down Expand Up @@ -160,7 +161,7 @@ impl CassandraSinkClusterBuilder {
}

impl TransformBuilder for CassandraSinkClusterBuilder {
fn build(&self) -> Box<dyn Transform> {
fn build(&self, _transform_context: TransformContextBuilder) -> Box<dyn Transform> {
Box::new(CassandraSinkCluster {
contact_points: self.contact_points.clone(),
message_rewriter: self.message_rewriter.clone(),
Expand Down
5 changes: 3 additions & 2 deletions shotover/src/transforms/cassandra/sink_single.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use crate::message::{Messages, Metadata};
use crate::tls::{TlsConnector, TlsConnectorConfig};
use crate::transforms::cassandra::connection::Response;
use crate::transforms::{
Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper,
Transform, TransformBuilder, TransformConfig, TransformContextBuilder, TransformContextConfig,
Wrapper,
};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
Expand Down Expand Up @@ -83,7 +84,7 @@ impl CassandraSinkSingleBuilder {
}

impl TransformBuilder for CassandraSinkSingleBuilder {
fn build(&self) -> Box<dyn Transform> {
fn build(&self, _transform_context: TransformContextBuilder) -> Box<dyn Transform> {
Box::new(CassandraSinkSingle {
outbound: None,
version: self.version,
Expand Down
25 changes: 18 additions & 7 deletions shotover/src/transforms/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use tokio::sync::{mpsc, oneshot};
use tokio::time::{Duration, Instant};
use tracing::{debug, error, info, trace, Instrument};

use super::TransformContextBuilder;

type InnerChain = Vec<TransformAndMetrics>;

#[derive(Debug)]
Expand Down Expand Up @@ -233,9 +235,9 @@ pub struct TransformBuilderAndMetrics {
}

impl TransformBuilderAndMetrics {
fn build(&self) -> TransformAndMetrics {
fn build(&self, context: TransformContextBuilder) -> TransformAndMetrics {
TransformAndMetrics {
transform: self.builder.build(),
transform: self.builder.build(context),
transform_total: self.transform_total.clone(),
transform_failures: self.transform_failures.clone(),
transform_latency: self.transform_latency.clone(),
Expand Down Expand Up @@ -331,7 +333,11 @@ impl TransformChainBuilder {
errors
}

pub fn build_buffered(&self, buffer_size: usize) -> BufferedChain {
pub fn build_buffered(
&self,
buffer_size: usize,
context: TransformContextBuilder,
) -> BufferedChain {
let (tx, mut rx) = mpsc::channel::<BufferedChainMessages>(buffer_size);

#[cfg(test)]
Expand All @@ -341,7 +347,7 @@ impl TransformChainBuilder {

// Even though we don't keep the join handle, this thread will wrap up once all corresponding senders have been dropped.

let mut chain = self.build();
let mut chain = self.build(context);
let _jh = tokio::spawn(
async move {
while let Some(BufferedChainMessages {
Expand Down Expand Up @@ -398,8 +404,12 @@ impl TransformChainBuilder {
}

/// Clone the chain while adding a producer for the pushed messages channel
pub fn build(&self) -> TransformChain {
let chain = self.chain.iter().map(|x| x.build()).collect();
pub fn build(&self, context: TransformContextBuilder) -> TransformChain {
let chain = self
.chain
.iter()
.map(|x| x.build(context.clone()))
.collect();

TransformChain {
name: self.name,
Expand All @@ -414,12 +424,13 @@ impl TransformChainBuilder {
pub fn build_with_pushed_messages(
&self,
pushed_messages_tx: mpsc::UnboundedSender<Messages>,
context: TransformContextBuilder,
) -> TransformChain {
let chain = self
.chain
.iter()
.map(|x| {
let mut transform = x.build();
let mut transform = x.build(context.clone());
transform
.transform
.set_pushed_messages_tx(pushed_messages_tx.clone());
Expand Down
4 changes: 2 additions & 2 deletions shotover/src/transforms/coalesce.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::TransformContextConfig;
use super::{TransformContextBuilder, TransformContextConfig};
use crate::message::Messages;
use crate::transforms::{Transform, TransformBuilder, TransformConfig, Wrapper};
use anyhow::Result;
Expand Down Expand Up @@ -39,7 +39,7 @@ impl TransformConfig for CoalesceConfig {
}

impl TransformBuilder for Coalesce {
fn build(&self) -> Box<dyn Transform> {
fn build(&self, _transform_context: TransformContextBuilder) -> Box<dyn Transform> {
Box::new(self.clone())
}

Expand Down
3 changes: 2 additions & 1 deletion shotover/src/transforms/debug/force_parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::message::Messages;
/// It could also be used to ensure that messages round trip correctly when parsed.
#[cfg(feature = "alpha-transforms")]
use crate::transforms::TransformConfig;
use crate::transforms::TransformContextBuilder;
#[cfg(feature = "alpha-transforms")]
use crate::transforms::TransformContextConfig;
use crate::transforms::{Transform, TransformBuilder, Wrapper};
Expand Down Expand Up @@ -76,7 +77,7 @@ pub struct DebugForceParse {
}

impl TransformBuilder for DebugForceParse {
fn build(&self) -> Box<dyn Transform> {
fn build(&self, _transform_context: TransformContextBuilder) -> Box<dyn Transform> {
Box::new(self.clone())
}

Expand Down
Loading