Skip to content

Commit

Permalink
Fix client timeout (#1399)
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai authored Dec 21, 2023
1 parent 36e6cc3 commit a998cb7
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 32 deletions.
57 changes: 29 additions & 28 deletions shotover/src/server.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::codec::{CodecBuilder, CodecReadError, CodecWriteError};
use crate::config::chain::TransformChainConfig;
use crate::message::Messages;
use crate::message::{Message, Messages};
use crate::sources::Transport;
use crate::tls::{AcceptError, TlsAcceptor};
use crate::transforms::chain::{TransformChain, TransformChainBuilder};
Expand All @@ -19,7 +19,6 @@ use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio::sync::{mpsc, watch, OwnedSemaphorePermit, Semaphore};
use tokio::task::JoinHandle;
use tokio::time;
use tokio::time::timeout;
use tokio::time::Duration;
use tokio_tungstenite::tungstenite::{
handshake::server::{Request, Response},
Expand Down Expand Up @@ -67,8 +66,8 @@ pub struct TcpCodecListener<C: CodecBuilder> {

available_connections_gauge: Gauge,

/// Timeout in seconds after which to kill an idle connection. No timeout means connections will never be timed out.
timeout: Option<u64>,
/// Timeout after which to kill an idle connection. No timeout means connections will never be timed out.
timeout: Option<Duration>,

connection_handles: Vec<JoinHandle<()>>,

Expand All @@ -86,7 +85,7 @@ impl<C: CodecBuilder + 'static> TcpCodecListener<C> {
limit_connections: Arc<Semaphore>,
trigger_shutdown_rx: watch::Receiver<bool>,
tls: Option<TlsAcceptor>,
timeout: Option<u64>,
timeout: Option<Duration>,
transport: Transport,
) -> Result<Self, Vec<String>> {
let available_connections_gauge = register_gauge!("shotover_available_connections_count", "source" => source_name.clone());
Expand Down Expand Up @@ -283,7 +282,7 @@ pub struct Handler<C: CodecBuilder> {
/// which point the connection is terminated.
shutdown: Shutdown,
/// Timeout in seconds after which to kill an idle connection. No timeout means connections will never be timed out.
timeout: Option<u64>,
timeout: Option<Duration>,
pushed_messages_rx: UnboundedReceiver<Messages>,
_permit: OwnedSemaphorePermit,
}
Expand Down Expand Up @@ -678,6 +677,24 @@ impl<C: CodecBuilder + 'static> Handler<C> {
result
}

async fn receive_with_timeout(
timeout: Option<Duration>,
in_rx: &mut UnboundedReceiver<Vec<Message>>,
client_details: &str,
) -> Option<Vec<Message>> {
if let Some(timeout) = timeout {
match tokio::time::timeout(timeout, in_rx.recv()).await {
Ok(messages) => messages,
Err(_) => {
debug!("Dropping connection to {client_details} due to being idle for more than {timeout:?}");
None
}
}
} else {
in_rx.recv().await
}
}

async fn process_messages(
&mut self,
client_details: &str,
Expand All @@ -687,34 +704,18 @@ impl<C: CodecBuilder + 'static> Handler<C> {
) -> Result<()> {
// As long as the shutdown signal has not been received, try to read a
// new request frame.
let mut idle_time_seconds: u64 = 1;

while !self.shutdown.is_shutdown() {
// While reading a request frame, also listen for the shutdown signal
debug!("Waiting for message {client_details}");
let mut reverse_chain = false;

let messages = tokio::select! {
res = timeout(Duration::from_secs(idle_time_seconds), in_rx.recv()) => {
match res {
Ok(maybe_message) => {
idle_time_seconds = 1;
match maybe_message {
Some(m) => m,
None => return Ok(())
}
},
Err(_) => {
if let Some(timeout) = self.timeout {
if idle_time_seconds < timeout {
debug!("Connection Idle for more than {} seconds {}", timeout, client_details);
} else {
debug!("Dropping. Connection Idle for more than {} seconds {}", timeout, client_details);
return Ok(());
}
}
idle_time_seconds *= 2;
continue
requests = Self::receive_with_timeout(self.timeout, &mut in_rx, client_details) => {
match requests {
Some(requests) => requests,
None => {
// Either we timed out the connection or the client disconnected, so terminate this connection
return Ok(())
}
}
},
Expand Down
3 changes: 2 additions & 1 deletion shotover/src/sources/cassandra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::tls::{TlsAcceptor, TlsAcceptorConfig};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{watch, Semaphore};
use tokio::task::JoinHandle;
use tracing::{error, info};
Expand Down Expand Up @@ -75,7 +76,7 @@ impl CassandraSource {
Arc::new(Semaphore::new(connection_limit.unwrap_or(512))),
trigger_shutdown_rx.clone(),
tls.map(TlsAcceptor::new).transpose()?,
timeout,
timeout.map(Duration::from_secs),
transport.unwrap_or(Transport::Tcp),
)
.await?;
Expand Down
3 changes: 2 additions & 1 deletion shotover/src/sources/kafka.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::tls::{TlsAcceptor, TlsAcceptorConfig};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{watch, Semaphore};
use tokio::task::JoinHandle;
use tracing::{error, info};
Expand Down Expand Up @@ -71,7 +72,7 @@ impl KafkaSource {
Arc::new(Semaphore::new(connection_limit.unwrap_or(512))),
trigger_shutdown_rx.clone(),
tls.map(TlsAcceptor::new).transpose()?,
timeout,
timeout.map(Duration::from_secs),
Transport::Tcp,
)
.await?;
Expand Down
3 changes: 2 additions & 1 deletion shotover/src/sources/opensearch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::sources::{Source, Transport};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{watch, Semaphore};
use tokio::task::JoinHandle;
use tracing::{error, info};
Expand Down Expand Up @@ -65,7 +66,7 @@ impl OpenSearchSource {
Arc::new(Semaphore::new(connection_limit.unwrap_or(512))),
trigger_shutdown_rx.clone(),
None,
timeout,
timeout.map(Duration::from_secs),
Transport::Tcp,
)
.await?;
Expand Down
3 changes: 2 additions & 1 deletion shotover/src/sources/redis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::tls::{TlsAcceptor, TlsAcceptorConfig};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{watch, Semaphore};
use tokio::task::JoinHandle;
use tracing::{error, info};
Expand Down Expand Up @@ -71,7 +72,7 @@ impl RedisSource {
Arc::new(Semaphore::new(connection_limit.unwrap_or(512))),
trigger_shutdown_rx.clone(),
tls.map(TlsAcceptor::new).transpose()?,
timeout,
timeout.map(Duration::from_secs),
Transport::Tcp,
)
.await?;
Expand Down

0 comments on commit a998cb7

Please sign in to comment.