Skip to content

Commit

Permalink
Add Wrapper::close_client_connection
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai committed Aug 20, 2024
1 parent 17ccad5 commit 82b0897
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 20 deletions.
38 changes: 22 additions & 16 deletions shotover/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -634,9 +634,9 @@ impl<C: CodecBuilder + 'static> Handler<C> {
.run_loop(&client_details, local_addr, in_rx, out_tx, force_run_chain)
.await;

// Only flush messages if we are shutting down due to application shutdown
// Only flush messages if we are shutting down due to shotover shutdown or client disconnect
// If a Transform::transform returns an Err the transform is no longer in a usable state and needs to be destroyed without reusing.
if result.is_ok() {
if let Ok(CloseReason::ShotoverShutdown | CloseReason::ClientClosed) = result {
match self.chain.process_request(&mut Wrapper::flush()).await {
Ok(_) => {}
Err(e) => error!(
Expand All @@ -649,7 +649,7 @@ impl<C: CodecBuilder + 'static> Handler<C> {
}
}

result
result.map(|_| ())
}

async fn receive_with_timeout(
Expand Down Expand Up @@ -677,7 +677,7 @@ impl<C: CodecBuilder + 'static> Handler<C> {
mut in_rx: mpsc::Receiver<Messages>,
out_tx: mpsc::UnboundedSender<Messages>,
force_run_chain: Arc<Notify>,
) -> Result<()> {
) -> Result<CloseReason> {
// As long as the shutdown signal has not been received, try to read a
// new request frame.
while !self.shutdown.is_shutdown() {
Expand All @@ -688,16 +688,16 @@ impl<C: CodecBuilder + 'static> Handler<C> {
_ = self.shutdown.recv() => {
// If a shutdown signal is received, return from `run`.
// This will result in the task terminating.
return Ok(());
return Ok(CloseReason::ShotoverShutdown);
}
() = force_run_chain.notified() => {
let mut requests = vec!();
while let Ok(x) = in_rx.try_recv() {
requests.extend(x);
}
debug!("A transform in the chain requested that a chain run occur, requests {:?}", requests);
if let Some(_close_reason) = self.send_receive_chain(local_addr, &out_tx, requests).await? {
return Ok(())
if let Some(close_reason) = self.process_requests(local_addr, &out_tx, requests).await? {
return Ok(close_reason)
}
},
requests = Self::receive_with_timeout(self.timeout, &mut in_rx, client_details) => {
Expand All @@ -707,30 +707,29 @@ impl<C: CodecBuilder + 'static> Handler<C> {
requests.extend(x);
}
debug!("Received requests from client {:?}", requests);
if let Some(_close_reason) = self.send_receive_chain(local_addr, &out_tx, requests).await? {
return Ok(())
if let Some(close_reason) = self.process_requests(local_addr, &out_tx, requests).await? {
return Ok(close_reason)
}
}
// Either we timed out the connection or the client disconnected, so terminate this connection
None => return Ok(()),
None => return Ok(CloseReason::ClientClosed),
}
},
};
}

Ok(())
Ok(CloseReason::ShotoverShutdown)
}

async fn send_receive_chain(
async fn process_requests(
&mut self,
local_addr: SocketAddr,
out_tx: &mpsc::UnboundedSender<Messages>,
requests: Messages,
) -> Result<Option<CloseReason>> {
self.pending_requests.process_requests(&requests);

let mut wrapper = Wrapper::new_with_addr(requests, local_addr);

self.pending_requests.process_requests(&wrapper.requests);
let responses = match self.chain.process_request(&mut wrapper).await {
Ok(x) => x,
Err(err) => {
Expand All @@ -748,17 +747,24 @@ impl<C: CodecBuilder + 'static> Handler<C> {
debug!("sending response to client: {:?}", responses);
if out_tx.send(responses).is_err() {
// the client has disconnected so we should terminate this connection
return Ok(Some(CloseReason::Generic));
return Ok(Some(CloseReason::ClientClosed));
}
}

// if requested by a transform, close connection AFTER sending any responses back to the client
if wrapper.close_client_connection {
return Ok(Some(CloseReason::TransformRequested));
}

Ok(None)
}
}

/// Indicates that the connection to the client must be closed.
enum CloseReason {
Generic,
TransformRequested,
ClientClosed,
ShotoverShutdown,
}

/// Listens for the server shutdown signal.
Expand Down
21 changes: 17 additions & 4 deletions shotover/src/transforms/kafka/sink_cluster/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -375,9 +375,12 @@ impl Transform for KafkaSinkCluster {
.context("Failed to receive responses")?
};

self.process_responses(&mut responses)
.await
.context("Failed to process responses")?;
self.process_responses(
&mut responses,
&mut requests_wrapper.close_client_connection,
)
.await
.context("Failed to process responses")?;
Ok(responses)
}
}
Expand Down Expand Up @@ -1163,7 +1166,11 @@ routing message to a random node so that:
Ok(base)
}

async fn process_responses(&mut self, responses: &mut [Message]) -> Result<()> {
async fn process_responses(
&mut self,
responses: &mut [Message],
close_client_connection: &mut bool,
) -> Result<()> {
for response in responses.iter_mut() {
let request_id = response.request_id().unwrap();
match response.frame() {
Expand Down Expand Up @@ -1209,6 +1216,12 @@ routing message to a random node so that:
body: ResponseBody::SaslAuthenticate(authenticate),
..
})) => {
// The broker always closes the connection after an auth failure response,
// so we should do the same.
if authenticate.error_code != 0 {
tracing::debug!("Closing connection to client after responses are sent");
*close_client_connection = true;
}
self.process_sasl_authenticate(authenticate)?;
}
Some(Frame::Kafka(KafkaFrame::Response {
Expand Down
8 changes: 8 additions & 0 deletions shotover/src/transforms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,9 @@ pub struct Wrapper<'a> {
/// This can occur at any time but will always occur before the transform is destroyed due to either
/// shotover or the transform's chain shutting down.
pub flush: bool,
/// Set to false by default.
/// When set to true the connection to the client is closed after the stack of `Transform::transform` calls returns.
pub close_client_connection: bool,
}

/// [`Wrapper`] will not (cannot) bring the current list of transforms that it needs to traverse with it
Expand All @@ -170,6 +173,7 @@ impl<'a> Clone for Wrapper<'a> {
transforms: [].iter_mut(),
local_addr: self.local_addr,
flush: self.flush,
close_client_connection: self.close_client_connection,
}
}
}
Expand All @@ -181,6 +185,7 @@ impl<'shorter, 'longer: 'shorter> Wrapper<'longer> {
transforms: std::mem::take(&mut self.transforms),
local_addr: self.local_addr,
flush: self.flush,
close_client_connection: self.close_client_connection,
}
}

Expand Down Expand Up @@ -232,6 +237,7 @@ impl<'shorter, 'longer: 'shorter> Wrapper<'longer> {
transforms: [].iter_mut(),
local_addr: "127.0.0.1:8000".parse().unwrap(),
flush: false,
close_client_connection: false,
}
}

Expand All @@ -241,6 +247,7 @@ impl<'shorter, 'longer: 'shorter> Wrapper<'longer> {
transforms: [].iter_mut(),
local_addr,
flush: false,
close_client_connection: false,
}
}

Expand All @@ -251,6 +258,7 @@ impl<'shorter, 'longer: 'shorter> Wrapper<'longer> {
// The connection is closed so we need to just fake an address here
local_addr: "127.0.0.1:10000".parse().unwrap(),
flush: true,
close_client_connection: false,
}
}

Expand Down

0 comments on commit 82b0897

Please sign in to comment.