From 82b0897f286e2ece02771de11945ed954ac04b08 Mon Sep 17 00:00:00 2001 From: Lucas Kent Date: Mon, 5 Aug 2024 15:55:28 +1000 Subject: [PATCH] Add Wrapper::close_client_connection --- shotover/src/server.rs | 38 +++++++++++-------- .../src/transforms/kafka/sink_cluster/mod.rs | 21 ++++++++-- shotover/src/transforms/mod.rs | 8 ++++ 3 files changed, 47 insertions(+), 20 deletions(-) diff --git a/shotover/src/server.rs b/shotover/src/server.rs index c993cc7b9..0ea464aa2 100644 --- a/shotover/src/server.rs +++ b/shotover/src/server.rs @@ -634,9 +634,9 @@ impl Handler { .run_loop(&client_details, local_addr, in_rx, out_tx, force_run_chain) .await; - // Only flush messages if we are shutting down due to application shutdown + // Only flush messages if we are shutting down due to shotover shutdown or client disconnect // If a Transform::transform returns an Err the transform is no longer in a usable state and needs to be destroyed without reusing. - if result.is_ok() { + if let Ok(CloseReason::ShotoverShutdown | CloseReason::ClientClosed) = result { match self.chain.process_request(&mut Wrapper::flush()).await { Ok(_) => {} Err(e) => error!( @@ -649,7 +649,7 @@ impl Handler { } } - result + result.map(|_| ()) } async fn receive_with_timeout( @@ -677,7 +677,7 @@ impl Handler { mut in_rx: mpsc::Receiver, out_tx: mpsc::UnboundedSender, force_run_chain: Arc, - ) -> Result<()> { + ) -> Result { // As long as the shutdown signal has not been received, try to read a // new request frame. while !self.shutdown.is_shutdown() { @@ -688,7 +688,7 @@ impl Handler { _ = self.shutdown.recv() => { // If a shutdown signal is received, return from `run`. // This will result in the task terminating. - return Ok(()); + return Ok(CloseReason::ShotoverShutdown); } () = force_run_chain.notified() => { let mut requests = vec!(); @@ -696,8 +696,8 @@ impl Handler { requests.extend(x); } debug!("A transform in the chain requested that a chain run occur, requests {:?}", requests); - if let Some(_close_reason) = self.send_receive_chain(local_addr, &out_tx, requests).await? { - return Ok(()) + if let Some(close_reason) = self.process_requests(local_addr, &out_tx, requests).await? { + return Ok(close_reason) } }, requests = Self::receive_with_timeout(self.timeout, &mut in_rx, client_details) => { @@ -707,30 +707,29 @@ impl Handler { requests.extend(x); } debug!("Received requests from client {:?}", requests); - if let Some(_close_reason) = self.send_receive_chain(local_addr, &out_tx, requests).await? { - return Ok(()) + if let Some(close_reason) = self.process_requests(local_addr, &out_tx, requests).await? { + return Ok(close_reason) } } // Either we timed out the connection or the client disconnected, so terminate this connection - None => return Ok(()), + None => return Ok(CloseReason::ClientClosed), } }, }; } - Ok(()) + Ok(CloseReason::ShotoverShutdown) } - async fn send_receive_chain( + async fn process_requests( &mut self, local_addr: SocketAddr, out_tx: &mpsc::UnboundedSender, requests: Messages, ) -> Result> { - self.pending_requests.process_requests(&requests); - let mut wrapper = Wrapper::new_with_addr(requests, local_addr); + self.pending_requests.process_requests(&wrapper.requests); let responses = match self.chain.process_request(&mut wrapper).await { Ok(x) => x, Err(err) => { @@ -748,17 +747,24 @@ impl Handler { debug!("sending response to client: {:?}", responses); if out_tx.send(responses).is_err() { // the client has disconnected so we should terminate this connection - return Ok(Some(CloseReason::Generic)); + return Ok(Some(CloseReason::ClientClosed)); } } + // if requested by a transform, close connection AFTER sending any responses back to the client + if wrapper.close_client_connection { + return Ok(Some(CloseReason::TransformRequested)); + } + Ok(None) } } /// Indicates that the connection to the client must be closed. enum CloseReason { - Generic, + TransformRequested, + ClientClosed, + ShotoverShutdown, } /// Listens for the server shutdown signal. diff --git a/shotover/src/transforms/kafka/sink_cluster/mod.rs b/shotover/src/transforms/kafka/sink_cluster/mod.rs index a5f2ee714..71daafa2e 100644 --- a/shotover/src/transforms/kafka/sink_cluster/mod.rs +++ b/shotover/src/transforms/kafka/sink_cluster/mod.rs @@ -375,9 +375,12 @@ impl Transform for KafkaSinkCluster { .context("Failed to receive responses")? }; - self.process_responses(&mut responses) - .await - .context("Failed to process responses")?; + self.process_responses( + &mut responses, + &mut requests_wrapper.close_client_connection, + ) + .await + .context("Failed to process responses")?; Ok(responses) } } @@ -1163,7 +1166,11 @@ routing message to a random node so that: Ok(base) } - async fn process_responses(&mut self, responses: &mut [Message]) -> Result<()> { + async fn process_responses( + &mut self, + responses: &mut [Message], + close_client_connection: &mut bool, + ) -> Result<()> { for response in responses.iter_mut() { let request_id = response.request_id().unwrap(); match response.frame() { @@ -1209,6 +1216,12 @@ routing message to a random node so that: body: ResponseBody::SaslAuthenticate(authenticate), .. })) => { + // The broker always closes the connection after an auth failure response, + // so we should do the same. + if authenticate.error_code != 0 { + tracing::debug!("Closing connection to client after responses are sent"); + *close_client_connection = true; + } self.process_sasl_authenticate(authenticate)?; } Some(Frame::Kafka(KafkaFrame::Response { diff --git a/shotover/src/transforms/mod.rs b/shotover/src/transforms/mod.rs index bf0ae819b..f5ef1c4a0 100644 --- a/shotover/src/transforms/mod.rs +++ b/shotover/src/transforms/mod.rs @@ -158,6 +158,9 @@ pub struct Wrapper<'a> { /// This can occur at any time but will always occur before the transform is destroyed due to either /// shotover or the transform's chain shutting down. pub flush: bool, + /// Set to false by default. + /// When set to true the connection to the client is closed after the stack of `Transform::transform` calls returns. + pub close_client_connection: bool, } /// [`Wrapper`] will not (cannot) bring the current list of transforms that it needs to traverse with it @@ -170,6 +173,7 @@ impl<'a> Clone for Wrapper<'a> { transforms: [].iter_mut(), local_addr: self.local_addr, flush: self.flush, + close_client_connection: self.close_client_connection, } } } @@ -181,6 +185,7 @@ impl<'shorter, 'longer: 'shorter> Wrapper<'longer> { transforms: std::mem::take(&mut self.transforms), local_addr: self.local_addr, flush: self.flush, + close_client_connection: self.close_client_connection, } } @@ -232,6 +237,7 @@ impl<'shorter, 'longer: 'shorter> Wrapper<'longer> { transforms: [].iter_mut(), local_addr: "127.0.0.1:8000".parse().unwrap(), flush: false, + close_client_connection: false, } } @@ -241,6 +247,7 @@ impl<'shorter, 'longer: 'shorter> Wrapper<'longer> { transforms: [].iter_mut(), local_addr, flush: false, + close_client_connection: false, } } @@ -251,6 +258,7 @@ impl<'shorter, 'longer: 'shorter> Wrapper<'longer> { // The connection is closed so we need to just fake an address here local_addr: "127.0.0.1:10000".parse().unwrap(), flush: true, + close_client_connection: false, } }