Skip to content

Commit

Permalink
KafkaSinkCluster - handle receive errors
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai committed Aug 20, 2024
1 parent 17ccad5 commit e365edd
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 41 deletions.
184 changes: 164 additions & 20 deletions shotover/src/transforms/kafka/sink_cluster/connections.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
use crate::connection::SinkConnection;
use anyhow::{Context, Result};
use crate::{
connection::{ConnectionError, SinkConnection},
message::Message,
};
use anyhow::{anyhow, Context, Result};
use fnv::FnvBuildHasher;
use kafka_protocol::{messages::BrokerId, protocol::StrBytes};
use metrics::Counter;
use rand::{rngs::SmallRng, seq::SliceRandom};
use std::collections::HashMap;
use std::{
collections::HashMap,
time::{Duration, Instant},
};

use super::{
node::{ConnectionFactory, KafkaAddress, KafkaNode},
Expand All @@ -27,8 +33,31 @@ pub enum Destination {
ControlConnection,
}

pub struct KafkaConnection {
    /// The connection currently used for sending requests and receiving responses.
    pub connection: SinkConnection,
    /// A previous connection that was replaced by `connection` but still has
    /// responses in flight. It is drained before `connection` is read and is
    /// reset to `None` once it has no pending requests left.
    old_connection: Option<SinkConnection>,
    /// When `connection` was created; compared against a timeout to decide when
    /// the connection is at risk of being closed by the broker.
    created_at: Instant,
}

impl KafkaConnection {
    /// Receives any currently available responses into `responses` without blocking.
    ///
    /// To preserve response ordering across a connection swap, all responses from
    /// the old connection are drained first; the new connection is only polled once
    /// the old connection has no requests left in flight.
    pub fn try_recv_into(&mut self, responses: &mut Vec<Message>) -> Result<(), ConnectionError> {
        // ensure old connection is completely drained before receiving from new connection
        if let Some(old_connection) = &mut self.old_connection {
            old_connection.try_recv_into(responses)?;
            if old_connection.pending_requests_count() == 0 {
                // Old connection is fully drained: discard it and start receiving
                // from the new connection.
                self.old_connection = None;
                self.connection.try_recv_into(responses)
            } else {
                // Old connection still has responses in flight; do not pull from the
                // new connection yet or responses would be delivered out of order.
                Ok(())
            }
        } else {
            self.connection.try_recv_into(responses)
        }
    }
}

pub struct Connections {
pub connections: HashMap<Destination, SinkConnection, FnvBuildHasher>,
pub connections: HashMap<Destination, KafkaConnection, FnvBuildHasher>,
out_of_rack_requests: Counter,
}

Expand All @@ -51,7 +80,7 @@ impl Connections {
contact_points: &[KafkaAddress],
local_rack: &StrBytes,
destination: Destination,
) -> Result<&mut SinkConnection> {
) -> Result<&mut KafkaConnection> {
let node = match destination {
Destination::Id(id) => Some(nodes.iter().find(|x| x.broker_id == id).unwrap()),
Destination::ControlConnection => None,
Expand All @@ -67,22 +96,137 @@ impl Connections {
}
}

// map entry API can not be used with async
#[allow(clippy::map_entry)]
if !self.connections.contains_key(&destination) {
let address = match &node {
Some(node) => &node.kafka_address,
None => contact_points.choose(rng).unwrap(),
};
match self.get_connection_state(authorize_scram_over_mtls, destination) {
ConnectionState::Open => Ok(self.connections.get_mut(&destination).unwrap()),
ConnectionState::Unopened => {
let address = match &node {
Some(node) => &node.kafka_address,
None => contact_points.choose(rng).unwrap(),
};

self.connections.insert(
destination,
KafkaConnection {
connection: connection_factory
.create_connection(address, authorize_scram_over_mtls, sasl_mechanism)
.await
.context("Failed to create a new connection")?,
old_connection: None,
created_at: Instant::now(),
},
);
Ok(self.connections.get_mut(&destination).unwrap())
}
ConnectionState::AtRiskOfTimeout => {
let address = match &node {
Some(node) => &node.kafka_address,
None => contact_points.choose(rng).unwrap(),
};

let old_connection = self.connections.remove(&destination).unwrap();
if old_connection.old_connection.is_some() {
return Err(anyhow!("Old connection had an old connection"));
}
let old_connection = if old_connection.connection.pending_requests_count() == 0 {
None
} else {
Some(old_connection.connection)
};

self.connections.insert(
destination,
connection_factory
.create_connection(address, authorize_scram_over_mtls, sasl_mechanism)
.await
.context("Failed to create a new connection")?,
);
self.connections.insert(
destination,
KafkaConnection {
connection: connection_factory
.create_connection(address, authorize_scram_over_mtls, sasl_mechanism)
.await
.context("Failed to create a new connection")?,
old_connection,
created_at: Instant::now(),
},
);
tracing::info!("Recreated outgoing connection due to risk of timeout");
Ok(self.connections.get_mut(&destination).unwrap())
}
}
Ok(self.connections.get_mut(&destination).unwrap())
}

/// Classifies the connection to `destination`: already open, never opened, or open
/// long enough that the broker may be about to close it (so it should be replaced).
///
/// The timeout is derived from the delegation token lifetime when SCRAM over mTLS
/// is in use, otherwise from the default broker idle timeout (`connections.max.idle.ms`).
fn get_connection_state(
    &self,
    authorize_scram_over_mtls: &Option<AuthorizeScramOverMtls>,
    destination: Destination,
) -> ConnectionState {
    let timeout = if let Some(scram_over_mtls) = authorize_scram_over_mtls {
        // The delegation token is recreated after `0.5 * delegation_token_lifetime`
        // Consider what happens when we match that timing for our connection timeout here:
        //
        //     create token t1      create token t2
        //     |--------------------|--------------------|
        //     |                    ^ all connections created after this point use token t2 instead of token t1
        //     |                    |
        //     | token t1 lifetime  |
        //     |-----------------------------------------|
        //     |                    |                     ^
        //     |                    |                       after this point, connections still alive that were authed with token t1 will be closed by the broker.
        //     |                    |
        //     |                    |
        //     |                    |
        //     |                    | token t2 lifetime
        //     |                    |-----------------------------------------|
        //     |                    | ^ all connections created after this point use token t2
        //     |                    |
        //     |                    |
        //     |                    |
        //     | connection lifetime using token t1      |
        //     |  |--------------------|                 |
        //     This case is fine, the connection exists entirely within the lifetime of token t1.
        //     |                    |
        //     |                    |
        //     |                    |
        //     |                    | connection lifetime using token t2
        //     |                    |  |--------------------|
        //     This case is fine, the connection exists entirely within the lifetime of token t2.
        //     |                    |
        //     |                    |
        //     |                    |
        //     | connection lifetime using token t2 (or is token t1?)
        //     |                  |--------------------|
        //     This case is potentially a race condition.
        //     We could start with either token t2 or t1.
        //     If we start with t1 we could go past the end of t1's lifetime.
        //     To avoid this issue we reduce the size of the connection lifetime by a further 0.75 seconds.
        //
        // At low values of delegation_token_lifetime all of this falls apart since something
        // like a VM migration could delay shotover execution for many seconds.
        // However for sufficiently large delegation_token_lifetime values (> 1 hour) this should be fine.
        scram_over_mtls
            .delegation_token_lifetime
            // match token recreation time
            .mul_f32(0.5)
            // further reduce connection timeout
            .mul_f32(0.75)
    } else {
        // TODO: make this configurable
        // for now, use the default value of connections.max.idle.ms (10 minutes)
        let connections_max_idle = Duration::from_secs(60 * 10);
        connections_max_idle.mul_f32(0.75)
    };
    if let Some(connection) = self.connections.get(&destination) {
        // TODO: subtract a batch level Instant::now instead of using elapsed
        // use 3/4 of the timeout to make sure we trigger this well before it actually times out
        if connection.created_at.elapsed() > timeout {
            ConnectionState::AtRiskOfTimeout
        } else {
            ConnectionState::Open
        }
    } else {
        ConnectionState::Unopened
    }
}
}

enum ConnectionState {
    /// A connection exists and is believed safe to use.
    Open,
    /// No connection has been created for this destination yet.
    Unopened,
    // TODO: maybe combine with Unopened since old_connection can just easily take the appropriate Option value
    /// The connection has been open long enough that the broker may soon close it,
    /// so it should be preemptively replaced.
    AtRiskOfTimeout,
}
40 changes: 20 additions & 20 deletions shotover/src/transforms/kafka/sink_cluster/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,8 +403,8 @@ impl KafkaSinkCluster {
Destination::ControlConnection,
)
.await?;
connection.send(vec![requests])?;
Ok(connection.recv().await?.remove(0))
connection.connection.send(vec![requests])?;
Ok(connection.connection.recv().await?.remove(0))
}

fn store_topic_names(&self, topics: &mut Vec<TopicName>, topic: TopicName) {
Expand Down Expand Up @@ -1026,6 +1026,7 @@ routing message to a random node so that:
destination,
)
.await?
.connection
.send(requests.requests)?;
}

Expand All @@ -1038,24 +1039,23 @@ routing message to a random node so that:
// Convert all received PendingRequestTy::Sent into PendingRequestTy::Received
for (connection_destination, connection) in &mut self.connections.connections {
self.temp_responses_buffer.clear();
if let Ok(()) = connection.try_recv_into(&mut self.temp_responses_buffer) {
for response in self.temp_responses_buffer.drain(..) {
let mut response = Some(response);
for pending_request in &mut self.pending_requests {
if let PendingRequestTy::Sent { destination, index } =
&mut pending_request.ty
{
if destination == connection_destination {
// Store the PendingRequestTy::Received at the location of the next PendingRequestTy::Sent
// All other PendingRequestTy::Sent need to be decremented, in order to determine the PendingRequestTy::Sent
// to be used next time, and the time after that, and ...
if *index == 0 {
pending_request.ty = PendingRequestTy::Received {
response: response.take().unwrap(),
};
} else {
*index -= 1;
}
connection
.try_recv_into(&mut self.temp_responses_buffer)
.with_context(|| format!("Failed to receive from {connection_destination:?}"))?;
for response in self.temp_responses_buffer.drain(..) {
let mut response = Some(response);
for pending_request in &mut self.pending_requests {
if let PendingRequestTy::Sent { destination, index } = &mut pending_request.ty {
if destination == connection_destination {
// Store the PendingRequestTy::Received at the location of the next PendingRequestTy::Sent
// All other PendingRequestTy::Sent need to be decremented, in order to determine the PendingRequestTy::Sent
// to be used next time, and the time after that, and ...
if *index == 0 {
pending_request.ty = PendingRequestTy::Received {
response: response.take().unwrap(),
};
} else {
*index -= 1;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,18 +225,21 @@ impl AuthorizeScramOverMtlsConfig {
.iter()
.map(|x| KafkaAddress::from_str(x))
.collect();
let delegation_token_lifetime = Duration::from_secs(self.delegation_token_lifetime_seconds);
Ok(AuthorizeScramOverMtlsBuilder {
token_task: TokenTask::new(
mtls_connection_factory,
contact_points?,
Duration::from_secs(self.delegation_token_lifetime_seconds),
delegation_token_lifetime,
),
delegation_token_lifetime,
})
}
}

pub struct AuthorizeScramOverMtlsBuilder {
pub token_task: TokenTask,
pub delegation_token_lifetime: Duration,
}

impl AuthorizeScramOverMtlsBuilder {
Expand All @@ -245,6 +248,7 @@ impl AuthorizeScramOverMtlsBuilder {
original_scram_state: OriginalScramState::WaitingOnServerFirst,
token_task: self.token_task.clone(),
username: String::new(),
delegation_token_lifetime: self.delegation_token_lifetime,
}
}
}
Expand All @@ -256,6 +260,7 @@ pub struct AuthorizeScramOverMtls {
token_task: TokenTask,
/// The username used in the original scram auth to generate the delegation token
username: String,
pub delegation_token_lifetime: Duration,
}

impl AuthorizeScramOverMtls {
Expand Down

0 comments on commit e365edd

Please sign in to comment.