Skip to content

Commit

Permalink
KafkaSinkCluster - handle receive errors (#1728)
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai authored Aug 22, 2024
1 parent 09ba622 commit 307bb10
Show file tree
Hide file tree
Showing 4 changed files with 362 additions and 24 deletions.
201 changes: 181 additions & 20 deletions shotover/src/transforms/kafka/sink_cluster/connections.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
use crate::connection::SinkConnection;
use crate::{
connection::{ConnectionError, SinkConnection},
message::Message,
};
use anyhow::{Context, Result};
use fnv::FnvBuildHasher;
use kafka_protocol::{messages::BrokerId, protocol::StrBytes};
use metrics::Counter;
use rand::{rngs::SmallRng, seq::SliceRandom};
use std::collections::HashMap;
use std::{collections::HashMap, time::Instant};

use super::{
node::{ConnectionFactory, KafkaAddress, KafkaNode},
scram_over_mtls::AuthorizeScramOverMtls,
scram_over_mtls::{connection::ScramOverMtlsConnection, AuthorizeScramOverMtls},
SASL_SCRAM_MECHANISMS,
};

#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
Expand All @@ -28,7 +32,7 @@ pub enum Destination {
}

pub struct Connections {
pub connections: HashMap<Destination, SinkConnection, FnvBuildHasher>,
pub connections: HashMap<Destination, KafkaConnection, FnvBuildHasher>,
out_of_rack_requests: Counter,
}

Expand All @@ -50,8 +54,9 @@ impl Connections {
nodes: &[KafkaNode],
contact_points: &[KafkaAddress],
local_rack: &StrBytes,
recent_instant: Instant,
destination: Destination,
) -> Result<&mut SinkConnection> {
) -> Result<&mut KafkaConnection> {
let node = match destination {
Destination::Id(id) => Some(nodes.iter().find(|x| x.broker_id == id).unwrap()),
Destination::ControlConnection => None,
Expand All @@ -67,22 +72,178 @@ impl Connections {
}
}

// map entry API can not be used with async
#[allow(clippy::map_entry)]
if !self.connections.contains_key(&destination) {
let address = match &node {
Some(node) => &node.kafka_address,
None => contact_points.choose(rng).unwrap(),
};

self.connections.insert(
destination,
connection_factory
.create_connection(address, authorize_scram_over_mtls, sasl_mechanism)
.await
.context("Failed to create a new connection")?,
);
match self.get_connection_state(recent_instant, destination) {
ConnectionState::Open => {
// connection already open
}
ConnectionState::Unopened => {
self.create_and_insert_connection(
rng,
connection_factory,
authorize_scram_over_mtls,
sasl_mechanism,
node,
contact_points,
None,
destination,
)
.await
.context("Failed to create a new connection")?;
}
// This variant is only returned when scram_over_mtls is in use
ConnectionState::AtRiskOfAuthTokenExpiry => {
let old_connection = self.connections.remove(&destination);

self.create_and_insert_connection(
rng,
connection_factory,
authorize_scram_over_mtls,
sasl_mechanism,
node,
contact_points,
old_connection,
destination,
)
.await
.context("Failed to create a new connection to replace a connection that is at risk of having its delegation token expire")?;

tracing::info!(
"Recreated outgoing connection due to risk of delegation token expiring"
);
}
}
Ok(self.connections.get_mut(&destination).unwrap())
}

/// Open a brand new connection for `destination` and store it in the connection map.
///
/// When `old_connection` is provided, it is handed to [`KafkaConnection::new`] so that
/// an at-risk scram-over-mtls connection can be replaced without losing in-flight requests.
#[allow(clippy::too_many_arguments)]
async fn create_and_insert_connection(
    &mut self,
    rng: &mut SmallRng,
    connection_factory: &ConnectionFactory,
    authorize_scram_over_mtls: &Option<AuthorizeScramOverMtls>,
    sasl_mechanism: &Option<String>,
    node: Option<&KafkaNode>,
    contact_points: &[KafkaAddress],
    old_connection: Option<KafkaConnection>,
    destination: Destination,
) -> Result<()> {
    // Dial the broker's own address when routing to a specific node,
    // otherwise pick a random contact point (control connection case).
    let address = if let Some(node) = &node {
        &node.kafka_address
    } else {
        contact_points.choose(rng).unwrap()
    };

    let new_connection = connection_factory
        .create_connection(address, authorize_scram_over_mtls, sasl_mechanism)
        .await?;
    let wrapped = KafkaConnection::new(
        authorize_scram_over_mtls,
        sasl_mechanism,
        new_connection,
        old_connection,
    )?;
    self.connections.insert(destination, wrapped);

    Ok(())
}

/// Report the state of the connection currently held for `destination`.
///
/// Returns [`ConnectionState::Unopened`] when no connection exists yet for
/// that destination; otherwise defers to the connection's own `state`.
fn get_connection_state(
    &self,
    recent_instant: Instant,
    destination: Destination,
) -> ConnectionState {
    match self.connections.get(&destination) {
        Some(connection) => connection.state(recent_instant),
        None => ConnectionState::Unopened,
    }
}
}

/// An outgoing connection to a kafka broker.
pub enum KafkaConnection {
    /// A plain sink connection, used when scram-over-mtls is not in effect.
    Regular(SinkConnection),
    /// A connection wrapper used when SCRAM auth is authorized over mTLS;
    /// created by `KafkaConnection::new` only when a SCRAM mechanism is in use.
    ScramOverMtls(ScramOverMtlsConnection),
}

impl KafkaConnection {
    /// Wrap a freshly created [`SinkConnection`].
    ///
    /// When scram-over-mtls is authorized AND the SASL mechanism in use is one of
    /// `SASL_SCRAM_MECHANISMS`, the connection is wrapped in a
    /// [`ScramOverMtlsConnection`], optionally taking over from `old_connection`.
    /// Otherwise a plain `Regular` connection is returned.
    ///
    /// # Panics
    /// Panics if `old_connection` is a `Regular` connection while the replacement
    /// is scram-over-mtls — replacing across kinds is a logic error.
    pub fn new(
        authorize_scram_over_mtls: &Option<AuthorizeScramOverMtls>,
        sasl_mechanism: &Option<String>,
        connection: SinkConnection,
        old_connection: Option<KafkaConnection>,
    ) -> Result<Self> {
        let scram_mechanism_in_use = matches!(
            sasl_mechanism.as_deref(),
            Some(mechanism) if SASL_SCRAM_MECHANISMS.contains(&mechanism)
        );
        if authorize_scram_over_mtls.is_some() && scram_mechanism_in_use {
            let old_connection = match old_connection {
                None => None,
                Some(Self::ScramOverMtls(old)) => Some(old),
                Some(Self::Regular(_)) => {
                    panic!("Cannot replace a Regular connection with ScramOverMtlsConnection")
                }
            };
            let connection = ScramOverMtlsConnection::new(
                connection,
                old_connection,
                authorize_scram_over_mtls,
            )?;
            Ok(Self::ScramOverMtls(connection))
        } else {
            Ok(Self::Regular(connection))
        }
    }

    /// Attempts to receive messages, if there are no messages available it immediately returns an empty vec.
    /// If there is a problem with the connection an error is returned.
    pub fn try_recv_into(&mut self, responses: &mut Vec<Message>) -> Result<(), ConnectionError> {
        match self {
            Self::Regular(inner) => inner.try_recv_into(responses),
            Self::ScramOverMtls(inner) => inner.try_recv_into(responses),
        }
    }

    /// Send messages.
    /// If there is a problem with the connection an error is returned.
    pub fn send(&mut self, messages: Vec<Message>) -> Result<(), ConnectionError> {
        match self {
            Self::Regular(inner) => inner.send(messages),
            Self::ScramOverMtls(inner) => inner.send(messages),
        }
    }

    /// Receives messages, if there are no messages available it awaits until there are messages.
    /// If there is a problem with the connection an error is returned.
    pub async fn recv(&mut self) -> Result<Vec<Message>, ConnectionError> {
        match self {
            Self::Regular(inner) => inner.recv().await,
            Self::ScramOverMtls(inner) => inner.recv().await,
        }
    }

    /// Number of requests waiting on a response.
    /// The count includes requests that will have a dummy response generated by shotover.
    pub fn pending_requests_count(&self) -> usize {
        match self {
            Self::Regular(inner) => inner.pending_requests_count(),
            Self::ScramOverMtls(inner) => inner.pending_requests_count(),
        }
    }

    /// Returns either ConnectionState::Open or ConnectionState::AtRiskOfAuthTokenExpiry.
    /// Regular connections are always reported as Open; only scram-over-mtls
    /// connections can be at risk of auth token expiry.
    pub fn state(&self, recent_instant: Instant) -> ConnectionState {
        match self {
            Self::Regular(_) => ConnectionState::Open,
            Self::ScramOverMtls(inner) => inner.state(recent_instant),
        }
    }
}

/// The lifecycle state of an outgoing connection, as tracked by `Connections`.
pub enum ConnectionState {
    /// A usable connection is open for this destination.
    Open,
    /// No connection has been opened yet for this destination.
    Unopened,
    /// The connection is open but should be replaced soon.
    /// Only ever reported for scram-over-mtls connections, whose delegation
    /// token is at risk of expiring.
    AtRiskOfAuthTokenExpiry,
}
15 changes: 12 additions & 3 deletions shotover/src/transforms/kafka/sink_cluster/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ use std::collections::{HashMap, VecDeque};
use std::hash::Hasher;
use std::sync::atomic::AtomicI64;
use std::sync::Arc;
use std::time::Duration;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use uuid::Uuid;

Expand Down Expand Up @@ -400,6 +400,7 @@ impl KafkaSinkCluster {
&self.nodes,
&self.first_contact_points,
&self.rack,
Instant::now(),
Destination::ControlConnection,
)
.await?;
Expand Down Expand Up @@ -1013,6 +1014,7 @@ routing message to a random node so that:
}
}

let recent_instant = Instant::now();
for (destination, requests) in broker_to_routed_requests {
self.connections
.get_or_open_connection(
Expand All @@ -1023,6 +1025,7 @@ routing message to a random node so that:
&self.nodes,
&self.first_contact_points,
&self.rack,
recent_instant,
destination,
)
.await?
Expand All @@ -1037,8 +1040,14 @@ routing message to a random node so that:
fn recv_responses(&mut self) -> Result<Vec<Message>> {
// Convert all received PendingRequestTy::Sent into PendingRequestTy::Received
for (connection_destination, connection) in &mut self.connections.connections {
self.temp_responses_buffer.clear();
if let Ok(()) = connection.try_recv_into(&mut self.temp_responses_buffer) {
// skip recv when no pending requests to avoid timeouts on old connections
if connection.pending_requests_count() != 0 {
self.temp_responses_buffer.clear();
connection
.try_recv_into(&mut self.temp_responses_buffer)
.with_context(|| {
format!("Failed to receive from {connection_destination:?}")
})?;
for response in self.temp_responses_buffer.drain(..) {
let mut response = Some(response);
for pending_request in &mut self.pending_requests {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use tokio::sync::Notify;
use tokio::sync::{mpsc, oneshot};
use tokio_stream::StreamExt;

pub(crate) mod connection;
mod create_token;
mod recreate_token_queue;

Expand Down Expand Up @@ -225,18 +226,21 @@ impl AuthorizeScramOverMtlsConfig {
.iter()
.map(|x| KafkaAddress::from_str(x))
.collect();
let delegation_token_lifetime = Duration::from_secs(self.delegation_token_lifetime_seconds);
Ok(AuthorizeScramOverMtlsBuilder {
token_task: TokenTask::new(
mtls_connection_factory,
contact_points?,
Duration::from_secs(self.delegation_token_lifetime_seconds),
delegation_token_lifetime,
),
delegation_token_lifetime,
})
}
}

pub struct AuthorizeScramOverMtlsBuilder {
pub token_task: TokenTask,
pub delegation_token_lifetime: Duration,
}

impl AuthorizeScramOverMtlsBuilder {
Expand All @@ -245,6 +249,7 @@ impl AuthorizeScramOverMtlsBuilder {
original_scram_state: OriginalScramState::WaitingOnServerFirst,
token_task: self.token_task.clone(),
username: String::new(),
delegation_token_lifetime: self.delegation_token_lifetime,
}
}
}
Expand All @@ -256,6 +261,7 @@ pub struct AuthorizeScramOverMtls {
token_task: TokenTask,
/// The username used in the original scram auth to generate the delegation token
username: String,
pub delegation_token_lifetime: Duration,
}

impl AuthorizeScramOverMtls {
Expand Down
Loading

0 comments on commit 307bb10

Please sign in to comment.