diff --git a/shotover/src/transforms/kafka/sink_cluster/connections.rs b/shotover/src/transforms/kafka/sink_cluster/connections.rs index b9fbe1789..342c04788 100644 --- a/shotover/src/transforms/kafka/sink_cluster/connections.rs +++ b/shotover/src/transforms/kafka/sink_cluster/connections.rs @@ -1,14 +1,18 @@ -use crate::connection::SinkConnection; +use crate::{ + connection::{ConnectionError, SinkConnection}, + message::Message, +}; use anyhow::{Context, Result}; use fnv::FnvBuildHasher; use kafka_protocol::{messages::BrokerId, protocol::StrBytes}; use metrics::Counter; use rand::{rngs::SmallRng, seq::SliceRandom}; -use std::collections::HashMap; +use std::{collections::HashMap, time::Instant}; use super::{ node::{ConnectionFactory, KafkaAddress, KafkaNode}, - scram_over_mtls::AuthorizeScramOverMtls, + scram_over_mtls::{connection::ScramOverMtlsConnection, AuthorizeScramOverMtls}, + SASL_SCRAM_MECHANISMS, }; #[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] @@ -28,7 +32,7 @@ pub enum Destination { } pub struct Connections { - pub connections: HashMap, + pub connections: HashMap, out_of_rack_requests: Counter, } @@ -50,8 +54,9 @@ impl Connections { nodes: &[KafkaNode], contact_points: &[KafkaAddress], local_rack: &StrBytes, + recent_instant: Instant, destination: Destination, - ) -> Result<&mut SinkConnection> { + ) -> Result<&mut KafkaConnection> { let node = match destination { Destination::Id(id) => Some(nodes.iter().find(|x| x.broker_id == id).unwrap()), Destination::ControlConnection => None, @@ -67,22 +72,178 @@ impl Connections { } } - // map entry API can not be used with async - #[allow(clippy::map_entry)] - if !self.connections.contains_key(&destination) { - let address = match &node { - Some(node) => &node.kafka_address, - None => contact_points.choose(rng).unwrap(), - }; - - self.connections.insert( - destination, - connection_factory - .create_connection(address, authorize_scram_over_mtls, sasl_mechanism) - .await - .context("Failed to create a new connection")?, - ); + match self.get_connection_state(recent_instant, destination) { + ConnectionState::Open => { + // connection already open + } + ConnectionState::Unopened => { + self.create_and_insert_connection( + rng, + connection_factory, + authorize_scram_over_mtls, + sasl_mechanism, + node, + contact_points, + None, + destination, + ) + .await + .context("Failed to create a new connection")?; + } + // This variant is only returned when scram_over_mtls is in use + ConnectionState::AtRiskOfAuthTokenExpiry => { + let old_connection = self.connections.remove(&destination); + + self.create_and_insert_connection( + rng, + connection_factory, + authorize_scram_over_mtls, + sasl_mechanism, + node, + contact_points, + old_connection, + destination, + ) + .await + .context("Failed to create a new connection to replace a connection that is at risk of having its delegation token expire")?; + + tracing::info!( + "Recreated outgoing connection due to risk of delegation token expiring" + ); + } } Ok(self.connections.get_mut(&destination).unwrap()) } + + #[allow(clippy::too_many_arguments)] + async fn create_and_insert_connection( + &mut self, + rng: &mut SmallRng, + connection_factory: &ConnectionFactory, + authorize_scram_over_mtls: &Option, + sasl_mechanism: &Option, + node: Option<&KafkaNode>, + contact_points: &[KafkaAddress], + old_connection: Option, + destination: Destination, + ) -> Result<()> { + let address = match &node { + Some(node) => &node.kafka_address, + None => contact_points.choose(rng).unwrap(), + }; + let connection = connection_factory + .create_connection(address, authorize_scram_over_mtls, sasl_mechanism) + .await?; + + self.connections.insert( + destination, + KafkaConnection::new( + authorize_scram_over_mtls, + sasl_mechanism, + connection, + old_connection, + )?, + ); + + Ok(()) + } + + fn get_connection_state( + &self, + recent_instant: Instant, + destination: Destination, + ) -> ConnectionState { + if let Some(connection) = self.connections.get(&destination) { + connection.state(recent_instant) + } else { + ConnectionState::Unopened + } + } +} + +pub enum KafkaConnection { + Regular(SinkConnection), + ScramOverMtls(ScramOverMtlsConnection), +} + +impl KafkaConnection { + pub fn new( + authorize_scram_over_mtls: &Option, + sasl_mechanism: &Option, + connection: SinkConnection, + old_connection: Option, + ) -> Result { + let using_scram_over_mtls = authorize_scram_over_mtls.is_some() + && sasl_mechanism + .as_ref() + .map(|x| SASL_SCRAM_MECHANISMS.contains(&x.as_str())) + .unwrap_or(false); + if using_scram_over_mtls { + let old_connection = old_connection.map(|x| match x { + KafkaConnection::Regular(_) => { + panic!("Cannot replace a Regular connection with ScramOverMtlsConnection") + } + KafkaConnection::ScramOverMtls(old_connection) => old_connection, + }); + Ok(KafkaConnection::ScramOverMtls( + ScramOverMtlsConnection::new( + connection, + old_connection, + authorize_scram_over_mtls, + )?, + )) + } else { + Ok(KafkaConnection::Regular(connection)) + } + } + + /// Attempts to receive messages, if there are no messages available it immediately returns an empty vec. + /// If there is a problem with the connection an error is returned. + pub fn try_recv_into(&mut self, responses: &mut Vec) -> Result<(), ConnectionError> { + match self { + KafkaConnection::Regular(c) => c.try_recv_into(responses), + KafkaConnection::ScramOverMtls(c) => c.try_recv_into(responses), + } + } + + /// Send messages. + /// If there is a problem with the connection an error is returned. + pub fn send(&mut self, messages: Vec) -> Result<(), ConnectionError> { + match self { + KafkaConnection::Regular(c) => c.send(messages), + KafkaConnection::ScramOverMtls(c) => c.send(messages), + } + } + + /// Receives messages, if there are no messages available it awaits until there are messages. + /// If there is a problem with the connection an error is returned. + pub async fn recv(&mut self) -> Result, ConnectionError> { + match self { + KafkaConnection::Regular(c) => c.recv().await, + KafkaConnection::ScramOverMtls(c) => c.recv().await, + } + } + + /// Number of requests waiting on a response. + /// The count includes requests that will have a dummy response generated by shotover. + pub fn pending_requests_count(&self) -> usize { + match self { + KafkaConnection::Regular(c) => c.pending_requests_count(), + KafkaConnection::ScramOverMtls(c) => c.pending_requests_count(), + } + } + + /// Returns either ConnectionState::Open or ConnectionState::AtRiskOfAuthTokenExpiry + pub fn state(&self, recent_instant: Instant) -> ConnectionState { + match self { + KafkaConnection::Regular(_) => ConnectionState::Open, + KafkaConnection::ScramOverMtls(c) => c.state(recent_instant), + } + } +} + +pub enum ConnectionState { + Open, + Unopened, + AtRiskOfAuthTokenExpiry, } diff --git a/shotover/src/transforms/kafka/sink_cluster/mod.rs b/shotover/src/transforms/kafka/sink_cluster/mod.rs index b591b4b62..61a640900 100644 --- a/shotover/src/transforms/kafka/sink_cluster/mod.rs +++ b/shotover/src/transforms/kafka/sink_cluster/mod.rs @@ -38,7 +38,7 @@ use std::collections::{HashMap, VecDeque}; use std::hash::Hasher; use std::sync::atomic::AtomicI64; use std::sync::Arc; -use std::time::Duration; +use std::time::{Duration, Instant}; use tokio::sync::RwLock; use uuid::Uuid; @@ -400,6 +400,7 @@ impl KafkaSinkCluster { &self.nodes, &self.first_contact_points, &self.rack, + Instant::now(), Destination::ControlConnection, ) .await?; @@ -1013,6 +1014,7 @@ routing message to a random node so that: } } + let recent_instant = Instant::now(); for (destination, requests) in broker_to_routed_requests { self.connections .get_or_open_connection( @@ -1023,6 +1025,7 @@ routing message to a random node so that: &self.nodes, &self.first_contact_points, &self.rack, + recent_instant, destination, ) .await? @@ -1037,8 +1040,14 @@ routing message to a random node so that: fn recv_responses(&mut self) -> Result> { // Convert all received PendingRequestTy::Sent into PendingRequestTy::Received for (connection_destination, connection) in &mut self.connections.connections { - self.temp_responses_buffer.clear(); - if let Ok(()) = connection.try_recv_into(&mut self.temp_responses_buffer) { + // skip recv when no pending requests to avoid timeouts on old connections + if connection.pending_requests_count() != 0 { + self.temp_responses_buffer.clear(); + connection + .try_recv_into(&mut self.temp_responses_buffer) + .with_context(|| { + format!("Failed to receive from {connection_destination:?}") + })?; for response in self.temp_responses_buffer.drain(..) { let mut response = Some(response); for pending_request in &mut self.pending_requests { diff --git a/shotover/src/transforms/kafka/sink_cluster/scram_over_mtls.rs b/shotover/src/transforms/kafka/sink_cluster/scram_over_mtls.rs index 485dcb80c..4c10d0642 100644 --- a/shotover/src/transforms/kafka/sink_cluster/scram_over_mtls.rs +++ b/shotover/src/transforms/kafka/sink_cluster/scram_over_mtls.rs @@ -19,6 +19,7 @@ use tokio::sync::Notify; use tokio::sync::{mpsc, oneshot}; use tokio_stream::StreamExt; +pub(crate) mod connection; mod create_token; mod recreate_token_queue; @@ -225,18 +226,21 @@ impl AuthorizeScramOverMtlsConfig { .iter() .map(|x| KafkaAddress::from_str(x)) .collect(); + let delegation_token_lifetime = Duration::from_secs(self.delegation_token_lifetime_seconds); Ok(AuthorizeScramOverMtlsBuilder { token_task: TokenTask::new( mtls_connection_factory, contact_points?, - Duration::from_secs(self.delegation_token_lifetime_seconds), + delegation_token_lifetime, ), + delegation_token_lifetime, }) } } pub struct AuthorizeScramOverMtlsBuilder { pub token_task: TokenTask, + pub delegation_token_lifetime: Duration, } impl AuthorizeScramOverMtlsBuilder { @@ -245,6 +249,7 @@ impl AuthorizeScramOverMtlsBuilder { original_scram_state: OriginalScramState::WaitingOnServerFirst, token_task: self.token_task.clone(), username: String::new(), + delegation_token_lifetime: self.delegation_token_lifetime, } } } @@ -256,6 +261,7 @@ pub struct AuthorizeScramOverMtls { token_task: TokenTask, /// The username used in the original scram auth to generate the delegation token username: String, + pub delegation_token_lifetime: Duration, } impl AuthorizeScramOverMtls { diff --git a/shotover/src/transforms/kafka/sink_cluster/scram_over_mtls/connection.rs b/shotover/src/transforms/kafka/sink_cluster/scram_over_mtls/connection.rs new file mode 100644 index 000000000..cf0fdd11d --- /dev/null +++ b/shotover/src/transforms/kafka/sink_cluster/scram_over_mtls/connection.rs @@ -0,0 +1,162 @@ +use crate::{ + connection::{ConnectionError, SinkConnection}, + message::Message, + transforms::kafka::sink_cluster::connections::ConnectionState, +}; +use anyhow::{anyhow, Result}; +use std::time::{Duration, Instant}; + +use super::AuthorizeScramOverMtls; + +pub struct ScramOverMtlsConnection { + connection: SinkConnection, + /// When a connection is recreated to avoid timeouts, + /// the old connection will be kept around until all responses have been received from it. + old_connection: Option, + created_at: Instant, + timeout: Duration, +} + +impl ScramOverMtlsConnection { + pub fn new( + connection: SinkConnection, + old_connection: Option, + authorize_scram_over_mtls: &Option, + ) -> Result { + let old_connection = old_connection + .map(|x| x.into_old_connection()) + .transpose()? + .flatten(); + Ok(ScramOverMtlsConnection { + connection, + old_connection, + created_at: Instant::now(), + timeout: Self::calculate_timeout(authorize_scram_over_mtls), + }) + } + + fn into_old_connection(self) -> Result> { + if self.old_connection.is_some() { + return Err(anyhow!("The connection to be replaced had an old_connection. For this to occur a response needs to have been pending for longer than the timeout period which indicates other problems.")); + } + if self.connection.pending_requests_count() == 0 { + Ok(None) + } else { + Ok(Some(self.connection)) + } + } + + /// Attempts to receive messages, if there are no messages available it immediately returns an empty vec. + /// If there is a problem with the connection an error is returned. + pub fn try_recv_into(&mut self, responses: &mut Vec) -> Result<(), ConnectionError> { + // ensure old connection is completely drained before receiving from new connection + if let Some(old_connection) = &mut self.old_connection { + old_connection.try_recv_into(responses)?; + if old_connection.pending_requests_count() == 0 { + self.old_connection = None; + self.connection.try_recv_into(responses)?; + } + Ok(()) + } else { + self.connection.try_recv_into(responses) + } + } + + /// Send messages. + /// If there is a problem with the connection an error is returned. + pub fn send(&mut self, messages: Vec) -> Result<(), ConnectionError> { + self.connection.send(messages) + } + + /// Receives messages, if there are no messages available it awaits until there are messages. + /// If there is a problem with the connection an error is returned. + pub async fn recv(&mut self) -> Result, ConnectionError> { + // ensure old connection is completely drained before receiving from new connection + if let Some(old_connection) = &mut self.old_connection { + let mut received = old_connection.recv().await?; + if old_connection.pending_requests_count() == 0 { + self.old_connection = None; + // Do not use `recv` method here since we already have at least one message due to previous `recv`, + // so we avoid blocking by calling `try_recv_into` instead. + self.connection.try_recv_into(&mut received)?; + } + Ok(received) + } else { + self.connection.recv().await + } + } + + pub fn pending_requests_count(&self) -> usize { + self.connection.pending_requests_count() + + self + .old_connection + .as_ref() + .map(|x| x.pending_requests_count()) + .unwrap_or_default() + } + + fn calculate_timeout(authorize_scram_over_mtls: &Option) -> Duration { + // The delegation token is recreated after `0.5 * delegation_token_lifetime` + // Consider what happens when we match that timing for our connection timeout, + // in this timeline going from left to right: + // + // create token t1 create token t2 + // |--------------------|--------------------| + // | ^ all connections created after this point use token t2 instead of token t1 + // | | + // | token t1 lifetime | + // |-----------------------------------------| + // | ^ + // | after this point, connections still alive that were authed with token t1 will be closed by the broker. + // | | + // | | + // | | + // | token t2 lifetime + // | |-----------------------------------------| + // | ^ all connections created after this point use token t2 + // | | + // | | + // | | + // | connection lifetime using token t1 | + // | |--------------------| | + // This case is fine, the connection exists entirely within the lifetime of token t1. + // | | + // | | + // | | + // | connection lifetime using token t2 + // | |--------------------| + // This case is fine, the connection exists entirely within the lifetime of token t2. + // | | + // | | + // | | + // | connection lifetime using token t? + // | |--------------------| + // This case is a race condition. + // We could start with either token t2 or t1. + // If we start with t1 we could go past the end of t1's lifetime. + // To avoid this issue we reduce the size of the connection lifetime by a further 25% + // + // At low values of delegation_token_lifetime all of this falls apart since something + // like a VM migration could delay shotover execution for many seconds. + // However for sufficently large delegation_token_lifetime values (> 1 hour) this should be fine. + authorize_scram_over_mtls + .as_ref() + .unwrap() + .delegation_token_lifetime + .mul_f32( + // match token recreation time + 0.5 * + // further reduce connection timeout + 0.75, + ) + } + + pub fn state(&self, recent_instant: Instant) -> ConnectionState { + // Since we cant be 100% exact with time anyway, we use a recent instant that can be reused to reduce syscalls. + if recent_instant.duration_since(self.created_at) > self.timeout { + ConnectionState::AtRiskOfAuthTokenExpiry + } else { + ConnectionState::Open + } + } +}