From d4a2d6718a756b8c05a6ab8b33f6d5ecc046fa3f Mon Sep 17 00:00:00 2001 From: Lucas Kent Date: Mon, 19 Aug 2024 12:56:02 +1000 Subject: [PATCH] server.rs - move response handling into .process() --- shotover/src/server.rs | 64 +++++++++++++++++++++++------------------- 1 file changed, 35 insertions(+), 29 deletions(-) diff --git a/shotover/src/server.rs b/shotover/src/server.rs index c2f253182..c993cc7b9 100644 --- a/shotover/src/server.rs +++ b/shotover/src/server.rs @@ -6,7 +6,7 @@ use crate::sources::Transport; use crate::tls::{AcceptError, TlsAcceptor}; use crate::transforms::chain::{TransformChain, TransformChainBuilder}; use crate::transforms::{TransformContextBuilder, TransformContextConfig, Wrapper}; -use anyhow::{anyhow, Context, Result}; +use anyhow::{anyhow, Result}; use bytes::BytesMut; use futures::future::join_all; use futures::{SinkExt, StreamExt}; @@ -631,7 +631,7 @@ impl Handler { }; let result = self - .process_messages(&client_details, local_addr, in_rx, out_tx, force_run_chain) + .run_loop(&client_details, local_addr, in_rx, out_tx, force_run_chain) .await; // Only flush messages if we are shutting down due to application shutdown @@ -670,7 +670,7 @@ impl Handler { } } - async fn process_messages( + async fn run_loop( &mut self, client_details: &str, local_addr: SocketAddr, @@ -683,7 +683,7 @@ impl Handler { while !self.shutdown.is_shutdown() { // While reading a request frame, also listen for the shutdown signal debug!("Waiting for message {client_details}"); - let responses = tokio::select! { + tokio::select! { biased; _ = self.shutdown.recv() => { // If a shutdown signal is received, return from `run`. @@ -696,7 +696,9 @@ impl Handler { requests.extend(x); } debug!("A transform in the chain requested that a chain run occur, requests {:?}", requests); - self.process(local_addr, &out_tx, requests).await? + if let Some(_close_reason) = self.send_receive_chain(local_addr, &out_tx, requests).await? { + return Ok(()) + } }, requests = Self::receive_with_timeout(self.timeout, &mut in_rx, client_details) => { match requests { @@ -705,56 +707,60 @@ impl Handler { requests.extend(x); } debug!("Received requests from client {:?}", requests); - self.process(local_addr, &out_tx, requests).await? - } - None => { - // Either we timed out the connection or the client disconnected, so terminate this connection - return Ok(()) + if let Some(_close_reason) = self.send_receive_chain(local_addr, &out_tx, requests).await? { + return Ok(()) + } } + // Either we timed out the connection or the client disconnected, so terminate this connection + None => return Ok(()), } }, }; - - // send the result of the process up stream - if !responses.is_empty() { - debug!("sending response to client: {:?}", responses); - if out_tx.send(responses).is_err() { - // the client has disconnected so we should terminate this connection - return Ok(()); - } - } } Ok(()) } - async fn process( + async fn send_receive_chain( &mut self, local_addr: SocketAddr, out_tx: &mpsc::UnboundedSender, requests: Messages, - ) -> Result { + ) -> Result> { self.pending_requests.process_requests(&requests); let mut wrapper = Wrapper::new_with_addr(requests, local_addr); - match self.chain.process_request(&mut wrapper).await.context( - "Chain failed to send and/or receive messages, the connection will now be closed.", - ) { - Ok(x) => { - self.pending_requests.process_responses(&x); - Ok(x) - } + let responses = match self.chain.process_request(&mut wrapper).await { + Ok(x) => x, Err(err) => { + let err = err.context("Chain failed to send and/or receive messages, the connection will now be closed."); // The connection is going to be closed once we return Err. // So first make a best effort attempt of responding to any pending requests with an error response. out_tx.send(self.pending_requests.to_errors(&err))?; - Err(err) + return Err(err); + } + }; + self.pending_requests.process_responses(&responses); + + // send the result of the process up stream + if !responses.is_empty() { + debug!("sending response to client: {:?}", responses); + if out_tx.send(responses).is_err() { + // the client has disconnected so we should terminate this connection + return Ok(Some(CloseReason::Generic)); } } + + Ok(None) } } +/// Indicates that the connection to the client must be closed. +enum CloseReason { + Generic, +} + /// Listens for the server shutdown signal. /// /// Shutdown is signaled using a `broadcast::Receiver`. Only a single value is