From 70cf0beeff553b98a5ac9fafb8574da6998ac133 Mon Sep 17 00:00:00 2001 From: Lucas Kent Date: Tue, 20 Aug 2024 16:26:47 +1000 Subject: [PATCH] Add Wrapper::close_client_connection (#1722) --- shotover/src/server.rs | 36 +++++++++++-------- .../src/transforms/kafka/sink_cluster/mod.rs | 25 ++++++++++--- shotover/src/transforms/mod.rs | 10 ++++++ 3 files changed, 51 insertions(+), 20 deletions(-) diff --git a/shotover/src/server.rs b/shotover/src/server.rs index c993cc7b9..f85973a51 100644 --- a/shotover/src/server.rs +++ b/shotover/src/server.rs @@ -634,9 +634,9 @@ impl Handler { .run_loop(&client_details, local_addr, in_rx, out_tx, force_run_chain) .await; - // Only flush messages if we are shutting down due to application shutdown + // Only flush messages if we are shutting down due to shotover shutdown or client disconnect // If a Transform::transform returns an Err the transform is no longer in a usable state and needs to be destroyed without reusing. - if result.is_ok() { + if let Ok(CloseReason::ShotoverShutdown | CloseReason::ClientClosed) = result { match self.chain.process_request(&mut Wrapper::flush()).await { Ok(_) => {} Err(e) => error!( @@ -649,7 +649,7 @@ impl Handler { } } - result + result.map(|_| ()) } async fn receive_with_timeout( @@ -677,7 +677,7 @@ impl Handler { mut in_rx: mpsc::Receiver, out_tx: mpsc::UnboundedSender, force_run_chain: Arc, - ) -> Result<()> { + ) -> Result { // As long as the shutdown signal has not been received, try to read a // new request frame. while !self.shutdown.is_shutdown() { @@ -688,7 +688,7 @@ impl Handler { _ = self.shutdown.recv() => { // If a shutdown signal is received, return from `run`. // This will result in the task terminating. - return Ok(()); + return Ok(CloseReason::ShotoverShutdown); } () = force_run_chain.notified() => { let mut requests = vec!(); @@ -696,8 +696,8 @@ impl Handler { requests.extend(x); } debug!("A transform in the chain requested that a chain run occur, requests {:?}", requests); - if let Some(_close_reason) = self.send_receive_chain(local_addr, &out_tx, requests).await? { - return Ok(()) + if let Some(close_reason) = self.send_receive_chain(local_addr, &out_tx, requests).await? { + return Ok(close_reason) } }, requests = Self::receive_with_timeout(self.timeout, &mut in_rx, client_details) => { @@ -707,18 +707,18 @@ impl Handler { requests.extend(x); } debug!("Received requests from client {:?}", requests); - if let Some(_close_reason) = self.send_receive_chain(local_addr, &out_tx, requests).await? { - return Ok(()) + if let Some(close_reason) = self.send_receive_chain(local_addr, &out_tx, requests).await? { + return Ok(close_reason) } } // Either we timed out the connection or the client disconnected, so terminate this connection - None => return Ok(()), + None => return Ok(CloseReason::ClientClosed), } }, }; } - Ok(()) + Ok(CloseReason::ShotoverShutdown) } async fn send_receive_chain( @@ -727,10 +727,9 @@ impl Handler { out_tx: &mpsc::UnboundedSender, requests: Messages, ) -> Result> { - self.pending_requests.process_requests(&requests); - let mut wrapper = Wrapper::new_with_addr(requests, local_addr); + self.pending_requests.process_requests(&wrapper.requests); let responses = match self.chain.process_request(&mut wrapper).await { Ok(x) => x, Err(err) => { @@ -748,17 +747,24 @@ impl Handler { debug!("sending response to client: {:?}", responses); if out_tx.send(responses).is_err() { // the client has disconnected so we should terminate this connection - return Ok(Some(CloseReason::Generic)); + return Ok(Some(CloseReason::ClientClosed)); } } + // if requested by a transform, close connection AFTER sending any responses back to the client + if wrapper.close_client_connection { + return Ok(Some(CloseReason::TransformRequested)); + } + Ok(None) } } /// Indicates that the connection to the client must be closed. enum CloseReason { - Generic, + TransformRequested, + ClientClosed, + ShotoverShutdown, } /// Listens for the server shutdown signal. diff --git a/shotover/src/transforms/kafka/sink_cluster/mod.rs b/shotover/src/transforms/kafka/sink_cluster/mod.rs index a5f2ee714..a72844de2 100644 --- a/shotover/src/transforms/kafka/sink_cluster/mod.rs +++ b/shotover/src/transforms/kafka/sink_cluster/mod.rs @@ -375,9 +375,12 @@ impl Transform for KafkaSinkCluster { .context("Failed to receive responses")? }; - self.process_responses(&mut responses) - .await - .context("Failed to process responses")?; + self.process_responses( + &mut responses, + &mut requests_wrapper.close_client_connection, + ) + .await + .context("Failed to process responses")?; Ok(responses) } } @@ -1163,7 +1166,11 @@ routing message to a random node so that: Ok(base) } - async fn process_responses(&mut self, responses: &mut [Message]) -> Result<()> { + async fn process_responses( + &mut self, + responses: &mut [Message], + close_client_connection: &mut bool, + ) -> Result<()> { for response in responses.iter_mut() { let request_id = response.request_id().unwrap(); match response.frame() { @@ -1209,7 +1216,7 @@ routing message to a random node so that: body: ResponseBody::SaslAuthenticate(authenticate), .. })) => { - self.process_sasl_authenticate(authenticate)?; + self.process_sasl_authenticate(authenticate, close_client_connection)?; } Some(Frame::Kafka(KafkaFrame::Response { body: ResponseBody::Produce(produce), @@ -1389,7 +1396,15 @@ routing message to a random node so that: fn process_sasl_authenticate( &mut self, authenticate: &mut SaslAuthenticateResponse, + close_client_connection: &mut bool, ) -> Result<()> { + // The broker always closes the connection after an auth failure response, + // so we should do the same. + if authenticate.error_code != 0 { + tracing::debug!("Closing connection to client due to auth failure"); + *close_client_connection = true; + } + if let Some(sasl_mechanism) = &self.sasl_mechanism { if SASL_SCRAM_MECHANISMS.contains(&sasl_mechanism.as_str()) { if let Some(scram_over_mtls) = &mut self.authorize_scram_over_mtls { diff --git a/shotover/src/transforms/mod.rs b/shotover/src/transforms/mod.rs index bf0ae819b..f430eaa72 100644 --- a/shotover/src/transforms/mod.rs +++ b/shotover/src/transforms/mod.rs @@ -157,7 +157,12 @@ pub struct Wrapper<'a> { /// When true transforms must flush any buffered messages into the messages field. /// This can occur at any time but will always occur before the transform is destroyed due to either /// shotover or the transform's chain shutting down. + /// The one exception is if [`Wrapper::close_client_connection`] was set to true, in which case no flush occurs. pub flush: bool, + /// Set to false by default. + /// Transforms can set this to true to force the connection to the client to be closed after the stack of `Transform::transform` calls returns. + /// When closed in this way, the chain will not be flushed and no further calls to the chain will be made before it is dropped. + pub close_client_connection: bool, } /// [`Wrapper`] will not (cannot) bring the current list of transforms that it needs to traverse with it @@ -170,6 +175,7 @@ impl<'a> Clone for Wrapper<'a> { transforms: [].iter_mut(), local_addr: self.local_addr, flush: self.flush, + close_client_connection: self.close_client_connection, } } } @@ -181,6 +187,7 @@ impl<'shorter, 'longer: 'shorter> Wrapper<'longer> { transforms: std::mem::take(&mut self.transforms), local_addr: self.local_addr, flush: self.flush, + close_client_connection: self.close_client_connection, } } @@ -232,6 +239,7 @@ impl<'shorter, 'longer: 'shorter> Wrapper<'longer> { transforms: [].iter_mut(), local_addr: "127.0.0.1:8000".parse().unwrap(), flush: false, + close_client_connection: false, } } @@ -241,6 +249,7 @@ impl<'shorter, 'longer: 'shorter> Wrapper<'longer> { transforms: [].iter_mut(), local_addr, flush: false, + close_client_connection: false, } } @@ -251,6 +260,7 @@ impl<'shorter, 'longer: 'shorter> Wrapper<'longer> { // The connection is closed so we need to just fake an address here local_addr: "127.0.0.1:10000".parse().unwrap(), flush: true, + close_client_connection: false, } }