Skip to content

Commit

Permalink
KafkaSinkCluster - handle receive errors
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai committed Aug 20, 2024
1 parent 70cf0be commit 9ff6fc8
Show file tree
Hide file tree
Showing 3 changed files with 225 additions and 42 deletions.
215 changes: 195 additions & 20 deletions shotover/src/transforms/kafka/sink_cluster/connections.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
use crate::connection::SinkConnection;
use anyhow::{Context, Result};
use crate::{
connection::{ConnectionError, SinkConnection},
message::Message,
};
use anyhow::{anyhow, Context, Result};
use fnv::FnvBuildHasher;
use kafka_protocol::{messages::BrokerId, protocol::StrBytes};
use metrics::Counter;
use rand::{rngs::SmallRng, seq::SliceRandom};
use std::collections::HashMap;
use std::{
collections::HashMap,
time::{Duration, Instant},
};

use super::{
node::{ConnectionFactory, KafkaAddress, KafkaNode},
Expand All @@ -27,8 +33,33 @@ pub enum Destination {
ControlConnection,
}

/// An outgoing connection together with the bookkeeping needed to replace it
/// before it is at risk of being timed out.
pub struct KafkaConnection {
/// The current connection; callers send requests on this field directly.
pub connection: SinkConnection,
/// When a connection is recreated to avoid timeouts,
/// the old connection will be kept around until all responses have been received from it.
old_connection: Option<SinkConnection>,
/// When this value was created (set from `Instant::now()`); compared against the
/// timeout threshold to decide whether the connection needs to be recreated.
created_at: Instant,
}

impl KafkaConnection {
    /// Calls `try_recv_into` on the underlying connection(s), appending whatever
    /// they return to `responses`.
    ///
    /// If a replaced `old_connection` is still being kept around, it is read first and
    /// the new connection is only read once the old one has no pending requests left.
    /// This guarantees that responses from the new connection are never appended while
    /// the old connection still owes responses.
    ///
    /// # Errors
    /// Returns the [`ConnectionError`] produced by whichever underlying connection failed.
    pub fn try_recv_into(&mut self, responses: &mut Vec<Message>) -> Result<(), ConnectionError> {
        // ensure old connection is completely drained before receiving from new connection
        if let Some(old_connection) = &mut self.old_connection {
            old_connection.try_recv_into(responses)?;
            if old_connection.pending_requests_count() > 0 {
                // The old connection still owes responses: reading from the new
                // connection now would append its responses ahead of them.
                return Ok(());
            }
            // Fully drained, the old connection is no longer needed.
            self.old_connection = None;
        }
        self.connection.try_recv_into(responses)
    }
}

/// Holds the outgoing connections, keyed by [`Destination`].
pub struct Connections {
pub connections: HashMap<Destination, SinkConnection, FnvBuildHasher>,
pub connections: HashMap<Destination, KafkaConnection, FnvBuildHasher>,
out_of_rack_requests: Counter,
}

Expand All @@ -50,8 +81,9 @@ impl Connections {
nodes: &[KafkaNode],
contact_points: &[KafkaAddress],
local_rack: &StrBytes,
recent_instant: Instant,
destination: Destination,
) -> Result<&mut SinkConnection> {
) -> Result<&mut KafkaConnection> {
let node = match destination {
Destination::Id(id) => Some(nodes.iter().find(|x| x.broker_id == id).unwrap()),
Destination::ControlConnection => None,
Expand All @@ -67,22 +99,165 @@ impl Connections {
}
}

// map entry API can not be used with async
#[allow(clippy::map_entry)]
if !self.connections.contains_key(&destination) {
let address = match &node {
Some(node) => &node.kafka_address,
None => contact_points.choose(rng).unwrap(),
};

self.connections.insert(
destination,
connection_factory
match self.get_connection_state(authorize_scram_over_mtls, recent_instant, destination) {
ConnectionState::Open => Ok(self.connections.get_mut(&destination).unwrap()),
ConnectionState::Unopened => {
self.create_connection(
rng,
connection_factory,
authorize_scram_over_mtls,
sasl_mechanism,
node,
contact_points,
None,
destination,
)
.await
.context("Failed to create a new connection")?;

Ok(self.connections.get_mut(&destination).unwrap())
}
ConnectionState::AtRiskOfTimeout => {
let old_connection = self.connections.remove(&destination).unwrap();
if old_connection.old_connection.is_some() {
return Err(anyhow!("Old connection had an old connection"));
}
let old_connection = if old_connection.connection.pending_requests_count() == 0 {
None
} else {
Some(old_connection.connection)
};

self.create_connection(
rng,
connection_factory,
authorize_scram_over_mtls,
sasl_mechanism,
node,
contact_points,
old_connection,
destination,
)
.await
.context("Failed to create a new connection to replace an old connection")?;

tracing::info!("Recreated outgoing connection due to risk of timeout");
Ok(self.connections.get_mut(&destination).unwrap())
}
}
}

/// Creates a new connection and stores it in `self.connections` under `destination`.
///
/// The connection is made to `node`'s address when `node` is `Some`, otherwise to a
/// randomly chosen entry of `contact_points`.
/// `old_connection` is stored alongside the new connection and `created_at` is set to
/// the current time.
///
/// # Panics
/// Panics if `node` is `None` and `contact_points` is empty.
#[allow(clippy::too_many_arguments)]
async fn create_connection(
&mut self,
rng: &mut SmallRng,
connection_factory: &ConnectionFactory,
authorize_scram_over_mtls: &Option<AuthorizeScramOverMtls>,
sasl_mechanism: &Option<String>,
node: Option<&KafkaNode>,
contact_points: &[KafkaAddress],
old_connection: Option<SinkConnection>,
destination: Destination,
) -> Result<()> {
// No specific node means any contact point will do, so pick one at random.
let address = match &node {
Some(node) => &node.kafka_address,
None => contact_points.choose(rng).unwrap(),
};

self.connections.insert(
destination,
KafkaConnection {
connection: connection_factory
.create_connection(address, authorize_scram_over_mtls, sasl_mechanism)
.await
.context("Failed to create a new connection")?,
);
.await?,
old_connection,
created_at: Instant::now(),
},
);

Ok(())
}

/// Classifies the connection stored for `destination`.
///
/// Returns `ConnectionState::Unopened` when no connection is stored,
/// `ConnectionState::AtRiskOfTimeout` when the time from the connection's `created_at`
/// to `recent_instant` exceeds the timeout threshold computed below,
/// and `ConnectionState::Open` otherwise.
fn get_connection_state(
&self,
authorize_scram_over_mtls: &Option<AuthorizeScramOverMtls>,
recent_instant: Instant,
destination: Destination,
) -> ConnectionState {
let timeout = if let Some(scram_over_mtls) = authorize_scram_over_mtls {
// The delegation token is recreated after `0.5 * delegation_token_lifetime`
// Consider what happens when we match that timing for our connection timeout here:
//
// create token t1 create token t2
// |--------------------|--------------------|
// | ^ all connections created after this point use token t2 instead of token t1
// | |
// | token t1 lifetime |
// |-----------------------------------------|
// | ^
// | after this point, connections still alive that were authed with token t1 will be closed by the broker.
// | |
// | |
// | |
// | token t2 lifetime
// | |-----------------------------------------|
// | ^ all connections created after this point use token t2
// | |
// | |
// | |
// | connection lifetime using token t1 |
// | |--------------------| |
// This case is fine, the connection exists entirely within the lifetime of token t1.
// | |
// | |
// | |
// | connection lifetime using token t2
// | |--------------------|
// This case is fine, the connection exists entirely within the lifetime of token t2.
// | |
// | |
// | |
// | connection lifetime using token t?
// | |--------------------|
// This case is a race condition.
// We could start with either token t2 or t1.
// If we start with t1 we could go past the end of t1's lifetime.
// To avoid this issue we reduce the size of the connection lifetime by a further 25%
//
// At low values of delegation_token_lifetime all of this falls apart since something
// like a VM migration could delay shotover execution for many seconds.
// However for sufficiently large delegation_token_lifetime values (> 1 hour) this should be fine.
//
// Overall the threshold is 0.5 * 0.75 = 37.5% of the delegation token lifetime.
scram_over_mtls
.delegation_token_lifetime
// match token recreation time
.mul_f32(0.5)
// further reduce connection timeout
.mul_f32(0.75)
} else {
// use 3/4 of the timeout to make sure we trigger this well before it actually times out
// (with the 10 minute default this is 7.5 minutes)
CONNECTIONS_MAX_IDLE_DEFAULT.mul_f32(0.75)
// TODO: relying on the default value to be unchanged is not ideal, so either:
// * query the broker for the actual value of connections.max.idle.ms
// * have the user configure it in shotover's topology.yaml
};
if let Some(connection) = self.connections.get(&destination) {
// Since we can't be 100% exact with time anyway, we use a recent instant that can be reused to reduce syscalls.
if recent_instant.duration_since(connection.created_at) > timeout {
ConnectionState::AtRiskOfTimeout
} else {
ConnectionState::Open
}
} else {
ConnectionState::Unopened
}
Ok(self.connections.get_mut(&destination).unwrap())
}
}

/// default value of kafka broker config connections.max.idle.ms (10 minutes)
const CONNECTIONS_MAX_IDLE_DEFAULT: Duration = Duration::from_secs(60 * 10);

/// State of the connection stored for a destination, as determined by `get_connection_state`.
enum ConnectionState {
/// A connection exists and its age has not exceeded the timeout threshold.
Open,
/// No connection is stored for the destination.
Unopened,
/// A connection exists but its age exceeds the timeout threshold, so it gets recreated.
AtRiskOfTimeout,
}
45 changes: 24 additions & 21 deletions shotover/src/transforms/kafka/sink_cluster/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ use std::collections::{HashMap, VecDeque};
use std::hash::Hasher;
use std::sync::atomic::AtomicI64;
use std::sync::Arc;
use std::time::Duration;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use uuid::Uuid;

Expand Down Expand Up @@ -403,11 +403,12 @@ impl KafkaSinkCluster {
&self.nodes,
&self.first_contact_points,
&self.rack,
Instant::now(),
Destination::ControlConnection,
)
.await?;
connection.send(vec![requests])?;
Ok(connection.recv().await?.remove(0))
connection.connection.send(vec![requests])?;
Ok(connection.connection.recv().await?.remove(0))
}

fn store_topic_names(&self, topics: &mut Vec<TopicName>, topic: TopicName) {
Expand Down Expand Up @@ -1016,6 +1017,7 @@ routing message to a random node so that:
}
}

let recent_instant = Instant::now();
for (destination, requests) in broker_to_routed_requests {
self.connections
.get_or_open_connection(
Expand All @@ -1026,9 +1028,11 @@ routing message to a random node so that:
&self.nodes,
&self.first_contact_points,
&self.rack,
recent_instant,
destination,
)
.await?
.connection
.send(requests.requests)?;
}

Expand All @@ -1041,24 +1045,23 @@ routing message to a random node so that:
// Convert all received PendingRequestTy::Sent into PendingRequestTy::Received
for (connection_destination, connection) in &mut self.connections.connections {
self.temp_responses_buffer.clear();
if let Ok(()) = connection.try_recv_into(&mut self.temp_responses_buffer) {
for response in self.temp_responses_buffer.drain(..) {
let mut response = Some(response);
for pending_request in &mut self.pending_requests {
if let PendingRequestTy::Sent { destination, index } =
&mut pending_request.ty
{
if destination == connection_destination {
// Store the PendingRequestTy::Received at the location of the next PendingRequestTy::Sent
// All other PendingRequestTy::Sent need to be decremented, in order to determine the PendingRequestTy::Sent
// to be used next time, and the time after that, and ...
if *index == 0 {
pending_request.ty = PendingRequestTy::Received {
response: response.take().unwrap(),
};
} else {
*index -= 1;
}
connection
.try_recv_into(&mut self.temp_responses_buffer)
.with_context(|| format!("Failed to receive from {connection_destination:?}"))?;
for response in self.temp_responses_buffer.drain(..) {
let mut response = Some(response);
for pending_request in &mut self.pending_requests {
if let PendingRequestTy::Sent { destination, index } = &mut pending_request.ty {
if destination == connection_destination {
// Store the PendingRequestTy::Received at the location of the next PendingRequestTy::Sent
// All other PendingRequestTy::Sent need to be decremented, in order to determine the PendingRequestTy::Sent
// to be used next time, and the time after that, and ...
if *index == 0 {
pending_request.ty = PendingRequestTy::Received {
response: response.take().unwrap(),
};
} else {
*index -= 1;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,18 +225,21 @@ impl AuthorizeScramOverMtlsConfig {
.iter()
.map(|x| KafkaAddress::from_str(x))
.collect();
let delegation_token_lifetime = Duration::from_secs(self.delegation_token_lifetime_seconds);
Ok(AuthorizeScramOverMtlsBuilder {
token_task: TokenTask::new(
mtls_connection_factory,
contact_points?,
Duration::from_secs(self.delegation_token_lifetime_seconds),
delegation_token_lifetime,
),
delegation_token_lifetime,
})
}
}

/// Holds the values that are handed to each instance produced by this builder.
pub struct AuthorizeScramOverMtlsBuilder {
/// Cloned into the value this builder produces.
pub token_task: TokenTask,
/// Delegation token lifetime taken from `delegation_token_lifetime_seconds` in the config;
/// copied into the value this builder produces.
pub delegation_token_lifetime: Duration,
}

impl AuthorizeScramOverMtlsBuilder {
Expand All @@ -245,6 +248,7 @@ impl AuthorizeScramOverMtlsBuilder {
original_scram_state: OriginalScramState::WaitingOnServerFirst,
token_task: self.token_task.clone(),
username: String::new(),
delegation_token_lifetime: self.delegation_token_lifetime,
}
}
}
Expand All @@ -256,6 +260,7 @@ pub struct AuthorizeScramOverMtls {
token_task: TokenTask,
/// The username used in the original scram auth to generate the delegation token
username: String,
pub delegation_token_lifetime: Duration,
}

impl AuthorizeScramOverMtls {
Expand Down

0 comments on commit 9ff6fc8

Please sign in to comment.