Add Wrapper::close_client_connection
rukai committed Aug 20, 2024
1 parent 17ccad5 commit 6fd0705
Showing 5 changed files with 229 additions and 61 deletions.
38 changes: 22 additions & 16 deletions shotover/src/server.rs
@@ -634,9 +634,9 @@ impl<C: CodecBuilder + 'static> Handler<C> {
.run_loop(&client_details, local_addr, in_rx, out_tx, force_run_chain)
.await;

// Only flush messages if we are shutting down due to application shutdown
// Only flush messages if we are shutting down due to shotover shutdown or client disconnect
// If a Transform::transform returns an Err the transform is no longer in a usable state and needs to be destroyed without reusing.
if result.is_ok() {
if let Ok(CloseReason::ShotoverShutdown | CloseReason::ClientClosed) = result {
match self.chain.process_request(&mut Wrapper::flush()).await {
Ok(_) => {}
Err(e) => error!(
@@ -649,7 +649,7 @@ impl<C: CodecBuilder + 'static> Handler<C> {
}
}

result
result.map(|_| ())
}

async fn receive_with_timeout(
@@ -677,7 +677,7 @@ impl<C: CodecBuilder + 'static> Handler<C> {
mut in_rx: mpsc::Receiver<Messages>,
out_tx: mpsc::UnboundedSender<Messages>,
force_run_chain: Arc<Notify>,
) -> Result<()> {
) -> Result<CloseReason> {
// As long as the shutdown signal has not been received, try to read a
// new request frame.
while !self.shutdown.is_shutdown() {
@@ -688,16 +688,16 @@
_ = self.shutdown.recv() => {
// If a shutdown signal is received, return from `run`.
// This will result in the task terminating.
return Ok(());
return Ok(CloseReason::ShotoverShutdown);
}
() = force_run_chain.notified() => {
let mut requests = vec!();
while let Ok(x) = in_rx.try_recv() {
requests.extend(x);
}
debug!("A transform in the chain requested that a chain run occur, requests {:?}", requests);
if let Some(_close_reason) = self.send_receive_chain(local_addr, &out_tx, requests).await? {
return Ok(())
if let Some(close_reason) = self.process_requests(local_addr, &out_tx, requests).await? {
return Ok(close_reason)
}
},
requests = Self::receive_with_timeout(self.timeout, &mut in_rx, client_details) => {
@@ -707,30 +707,29 @@
requests.extend(x);
}
debug!("Received requests from client {:?}", requests);
if let Some(_close_reason) = self.send_receive_chain(local_addr, &out_tx, requests).await? {
return Ok(())
if let Some(close_reason) = self.process_requests(local_addr, &out_tx, requests).await? {
return Ok(close_reason)
}
}
// Either we timed out the connection or the client disconnected, so terminate this connection
None => return Ok(()),
None => return Ok(CloseReason::ClientClosed),
}
},
};
}

Ok(())
Ok(CloseReason::ShotoverShutdown)
}

async fn send_receive_chain(
async fn process_requests(
&mut self,
local_addr: SocketAddr,
out_tx: &mpsc::UnboundedSender<Messages>,
requests: Messages,
) -> Result<Option<CloseReason>> {
self.pending_requests.process_requests(&requests);

let mut wrapper = Wrapper::new_with_addr(requests, local_addr);

self.pending_requests.process_requests(&wrapper.requests);
let responses = match self.chain.process_request(&mut wrapper).await {
Ok(x) => x,
Err(err) => {
@@ -748,17 +747,24 @@
debug!("sending response to client: {:?}", responses);
if out_tx.send(responses).is_err() {
// the client has disconnected so we should terminate this connection
return Ok(Some(CloseReason::Generic));
return Ok(Some(CloseReason::ClientClosed));
}
}

// if requested by a transform, close connection AFTER sending any responses back to the client
if wrapper.close_client_connection {
return Ok(Some(CloseReason::TransformRequested));
}

Ok(None)
}
}

/// Indicates that the connection to the client must be closed.
enum CloseReason {
Generic,
TransformRequested,
ClientClosed,
ShotoverShutdown,
}

/// Listens for the server shutdown signal.
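The server-side changes above exist to support the new `Wrapper::close_client_connection` flag named in the commit title: `process_requests` checks the flag after forwarding responses and returns `CloseReason::TransformRequested`, so the connection is only closed once the pending responses have reached the client. Below is a minimal sketch of how a transform might request this; the module paths, the `call_next_transform` call, and the standalone function (standing in for a full `Transform` trait impl) are assumptions for illustration, not part of this diff.

```rust
use anyhow::Result;
use shotover::message::Messages;
use shotover::transforms::Wrapper;

// Illustrative transform body (assumed API): ask shotover to close the client
// connection, but only after the responses returned here have been sent.
async fn transform_and_close(requests_wrapper: &mut Wrapper<'_>) -> Result<Messages> {
    // Forward the requests down the rest of the chain as usual.
    let responses = requests_wrapper.call_next_transform().await?;

    // New in this commit: the server loop sees this flag in process_requests
    // and returns CloseReason::TransformRequested AFTER sending `responses`.
    requests_wrapper.close_client_connection = true;

    Ok(responses)
}
```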
177 changes: 157 additions & 20 deletions shotover/src/transforms/kafka/sink_cluster/connections.rs
@@ -1,10 +1,16 @@
use crate::connection::SinkConnection;
use anyhow::{Context, Result};
use crate::{
connection::{ConnectionError, SinkConnection},
message::Message,
};
use anyhow::{anyhow, Context, Result};
use fnv::FnvBuildHasher;
use kafka_protocol::{messages::BrokerId, protocol::StrBytes};
use metrics::Counter;
use rand::{rngs::SmallRng, seq::SliceRandom};
use std::collections::HashMap;
use std::{
collections::HashMap,
time::{Duration, Instant},
};

use super::{
node::{ConnectionFactory, KafkaAddress, KafkaNode},
@@ -27,8 +33,31 @@ pub enum Destination {
ControlConnection,
}

pub struct KafkaConnection {
pub connection: SinkConnection,
old_connection: Option<SinkConnection>,
created_at: Instant,
}

impl KafkaConnection {
pub fn try_recv_into(&mut self, responses: &mut Vec<Message>) -> Result<(), ConnectionError> {
// ensure old connection is completely drained before receiving from new connection
if let Some(old_connection) = &mut self.old_connection {
old_connection.try_recv_into(responses)?;
if old_connection.pending_requests_count() == 0 {
self.old_connection = None;
Ok(())
} else {
self.connection.try_recv_into(responses)
}
} else {
self.connection.try_recv_into(responses)
}
}
}
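
The drain-old-first rule in `try_recv_into` is what preserves response ordering when a connection is replaced further down in this file. A rough usage sketch, assuming a caller that simply wants every available response from every destination (the helper function and its name are illustrative, not part of this diff):

```rust
// Illustrative helper (not from the diff): collect responses from every broker
// connection tracked by Connections.
fn drain_all_responses(
    connections: &mut Connections,
) -> Result<Vec<Message>, ConnectionError> {
    let mut responses = Vec::new();
    for connection in connections.connections.values_mut() {
        // KafkaConnection::try_recv_into empties the old (pre-swap) connection
        // before reading from its replacement, so response order is preserved.
        connection.try_recv_into(&mut responses)?;
    }
    Ok(responses)
}
```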

pub struct Connections {
pub connections: HashMap<Destination, SinkConnection, FnvBuildHasher>,
pub connections: HashMap<Destination, KafkaConnection, FnvBuildHasher>,
out_of_rack_requests: Counter,
}

@@ -51,7 +80,7 @@ impl Connections {
contact_points: &[KafkaAddress],
local_rack: &StrBytes,
destination: Destination,
) -> Result<&mut SinkConnection> {
) -> Result<&mut KafkaConnection> {
let node = match destination {
Destination::Id(id) => Some(nodes.iter().find(|x| x.broker_id == id).unwrap()),
Destination::ControlConnection => None,
@@ -67,22 +96,130 @@
}
}

// map entry API can not be used with async
#[allow(clippy::map_entry)]
if !self.connections.contains_key(&destination) {
let address = match &node {
Some(node) => &node.kafka_address,
None => contact_points.choose(rng).unwrap(),
};
match self.get_connection_state(authorize_scram_over_mtls, destination) {
ConnectionState::Open => Ok(self.connections.get_mut(&destination).unwrap()),
ConnectionState::Unopened => {
let address = match &node {
Some(node) => &node.kafka_address,
None => contact_points.choose(rng).unwrap(),
};

self.connections.insert(
destination,
KafkaConnection {
connection: connection_factory
.create_connection(address, authorize_scram_over_mtls, sasl_mechanism)
.await
.context("Failed to create a new connection")?,
old_connection: None,
created_at: Instant::now(),
},
);
Ok(self.connections.get_mut(&destination).unwrap())
}
ConnectionState::AtRiskOfTimeout => {
let address = match &node {
Some(node) => &node.kafka_address,
None => contact_points.choose(rng).unwrap(),
};

let old_connection = self.connections.remove(&destination).unwrap();
if old_connection.old_connection.is_some() {
return Err(anyhow!("Old connection had an old connection"));
}
let old_connection = if old_connection.connection.pending_requests_count() == 0 {
None
} else {
Some(old_connection.connection)
};

self.connections.insert(
destination,
connection_factory
.create_connection(address, authorize_scram_over_mtls, sasl_mechanism)
.await
.context("Failed to create a new connection")?,
);
self.connections.insert(
destination,
KafkaConnection {
connection: connection_factory
.create_connection(address, authorize_scram_over_mtls, sasl_mechanism)
.await
.context("Failed to create a new connection")?,
old_connection,
created_at: Instant::now(),
},
);
tracing::info!("Recreated outgoing connection due to risk of timeout");
Ok(self.connections.get_mut(&destination).unwrap())
}
}
Ok(self.connections.get_mut(&destination).unwrap())
}

fn get_connection_state(
&self,
authorize_scram_over_mtls: &Option<AuthorizeScramOverMtls>,
destination: Destination,
) -> ConnectionState {
let timeout = if let Some(scram_over_mtls) = authorize_scram_over_mtls {
// The delegation token is recreated after `0.5 * token_lifetime`
// Consider what happens when we match that timing for our connection timeout here:
//
// create token t1 create token t2
// |--------------------|--------------------|
// ^ using a token after this point means we are working with a different token
//
// token t1 lifetime
// |-----------------------------------------|
// ^ we cannot use token t1 past this point
//
// token t2 lifetime
// |-----------------------------------------|
// ^ all connections created after this point use token t2
//
// connection lifetime using token t1
// |--------------------|
// This case is fine, it exists entirely within a lifetime of a connection.
//
//
// connection lifetime using token t2
// |--------------------|
// This case is fine, it exists entirely within a lifetime of a connection.
//
//
// connection lifetime using token t2 (or is token t1?)
// |--------------------|
// This case is potentially a race condition.
// We could start with either token t2 or t1.
// If we start with t1 we could go past the end of t1's lifetime due to race conditions.
// To avoid this issue we shrink the connection lifetime by a further factor of 0.75.
//
// At low values of delegation_token_lifetime all of this falls apart since something
// like a VM migration could delay execution for many seconds.
// However for sufficiently large delegation_token_lifetime values (> 1 hour) this should be fine.
scram_over_mtls
.delegation_token_lifetime
// match token recreation time
.mul_f32(0.5)
// further restrict connection timeout
.mul_f32(0.75)
} else {
// TODO: make this configurable
// for now, use the default value of connections.max.idle.ms (10 minutes)
let connections_max_idle = Duration::from_secs(60 * 10);
connections_max_idle.mul_f32(0.75)
};
if let Some(connection) = self.connections.get(&destination) {
// TODO: subtract a batch level Instant::now instead of using elapsed
// use 3/4 of the timeout to make sure we trigger this well before it actually times out
if connection.created_at.elapsed() > timeout {
ConnectionState::AtRiskOfTimeout
} else {
ConnectionState::Open
}
} else {
ConnectionState::Unopened
}
}
}

enum ConnectionState {
Open,
Unopened,
// TODO: maybe combine with Unopened since old_connection can just easily take the appropriate Option value
AtRiskOfTimeout,
}
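
For reference, the timeout that `get_connection_state` compares `created_at` against works out as below; this is just the arithmetic from the code above restated as a standalone sketch, with a one-hour token lifetime assumed as an example value.

```rust
use std::time::Duration;

fn example_connection_timeouts() {
    // SCRAM over mTLS: half the delegation token lifetime (the point at which
    // a fresh token exists), shrunk by a further 0.75 safety factor.
    let delegation_token_lifetime = Duration::from_secs(60 * 60); // assumed: 1 hour
    let scram_timeout = delegation_token_lifetime.mul_f32(0.5).mul_f32(0.75);
    assert_eq!(scram_timeout, Duration::from_secs(1350)); // 22.5 minutes

    // Otherwise: 0.75 of the default connections.max.idle.ms of 10 minutes.
    let idle_timeout = Duration::from_secs(60 * 10).mul_f32(0.75);
    assert_eq!(idle_timeout, Duration::from_secs(450)); // 7.5 minutes
}
```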

