diff --git a/shotover/src/transforms/kafka/sink_cluster/connections.rs b/shotover/src/transforms/kafka/sink_cluster/connections.rs index b9fbe1789..cb1499dc5 100644 --- a/shotover/src/transforms/kafka/sink_cluster/connections.rs +++ b/shotover/src/transforms/kafka/sink_cluster/connections.rs @@ -1,10 +1,16 @@ -use crate::connection::SinkConnection; -use anyhow::{Context, Result}; +use crate::{ + connection::{ConnectionError, SinkConnection}, + message::Message, +}; +use anyhow::{anyhow, Context, Result}; use fnv::FnvBuildHasher; use kafka_protocol::{messages::BrokerId, protocol::StrBytes}; use metrics::Counter; use rand::{rngs::SmallRng, seq::SliceRandom}; -use std::collections::HashMap; +use std::{ + collections::HashMap, + time::{Duration, Instant}, +}; use super::{ node::{ConnectionFactory, KafkaAddress, KafkaNode}, @@ -27,8 +33,33 @@ pub enum Destination { ControlConnection, } +pub struct KafkaConnection { + pub connection: SinkConnection, + /// When a connection is recreated to avoid timeouts, + /// the old connection will be kept around until all responses have been received from it. + old_connection: Option, + created_at: Instant, +} + +impl KafkaConnection { + pub fn try_recv_into(&mut self, responses: &mut Vec) -> Result<(), ConnectionError> { + // ensure old connection is completely drained before receiving from new connection + if let Some(old_connection) = &mut self.old_connection { + old_connection.try_recv_into(responses)?; + if old_connection.pending_requests_count() == 0 { + self.old_connection = None; + Ok(()) + } else { + self.connection.try_recv_into(responses) + } + } else { + self.connection.try_recv_into(responses) + } + } +} + pub struct Connections { - pub connections: HashMap, + pub connections: HashMap, out_of_rack_requests: Counter, } @@ -50,8 +81,9 @@ impl Connections { nodes: &[KafkaNode], contact_points: &[KafkaAddress], local_rack: &StrBytes, + recent_instant: Instant, destination: Destination, - ) -> Result<&mut SinkConnection> { + ) -> Result<&mut KafkaConnection> { let node = match destination { Destination::Id(id) => Some(nodes.iter().find(|x| x.broker_id == id).unwrap()), Destination::ControlConnection => None, @@ -67,22 +99,166 @@ impl Connections { } } - // map entry API can not be used with async - #[allow(clippy::map_entry)] - if !self.connections.contains_key(&destination) { - let address = match &node { - Some(node) => &node.kafka_address, - None => contact_points.choose(rng).unwrap(), - }; - - self.connections.insert( - destination, - connection_factory + match self.get_connection_state(authorize_scram_over_mtls, recent_instant, destination) { + ConnectionState::Open => Ok(self.connections.get_mut(&destination).unwrap()), + ConnectionState::Unopened => { + self.create_connection( + rng, + connection_factory, + authorize_scram_over_mtls, + sasl_mechanism, + node, + contact_points, + None, + destination, + ) + .await + .context("Failed to create a new connection")?; + + Ok(self.connections.get_mut(&destination).unwrap()) + } + ConnectionState::AtRiskOfTimeout => { + let old_connection = self.connections.remove(&destination).unwrap(); + if old_connection.old_connection.is_some() { + return Err(anyhow!("Old connection had an old connection")); + } + let old_connection = if old_connection.connection.pending_requests_count() == 0 { + None + } else { + Some(old_connection.connection) + }; + + self.create_connection( + rng, + connection_factory, + authorize_scram_over_mtls, + sasl_mechanism, + node, + contact_points, + old_connection, + destination, + ) + .await + .context("Failed to create a new connection to replace an old connection")?; + + tracing::info!("Recreated outgoing connection due to risk of timeout"); + Ok(self.connections.get_mut(&destination).unwrap()) + } + } + } + + #[allow(clippy::too_many_arguments)] + async fn create_connection( + &mut self, + rng: &mut SmallRng, + connection_factory: &ConnectionFactory, + authorize_scram_over_mtls: &Option, + sasl_mechanism: &Option, + node: Option<&KafkaNode>, + contact_points: &[KafkaAddress], + old_connection: Option, + destination: Destination, + ) -> Result<()> { + let address = match &node { + Some(node) => &node.kafka_address, + None => contact_points.choose(rng).unwrap(), + }; + + self.connections.insert( + destination, + KafkaConnection { + connection: connection_factory .create_connection(address, authorize_scram_over_mtls, sasl_mechanism) - .await - .context("Failed to create a new connection")?, - ); + .await?, + old_connection, + created_at: Instant::now(), + }, + ); + + Ok(()) + } + + fn get_connection_state( + &self, + authorize_scram_over_mtls: &Option, + recent_instant: Instant, + destination: Destination, + ) -> ConnectionState { + let timeout = if let Some(scram_over_mtls) = authorize_scram_over_mtls { + // The delegation token is recreated after `0.5 * delegation_token_lifetime` + // Consider what happens when we match that timing for our connection timeout here: + // + // create token t1 create token t2 + // |--------------------|--------------------| + // | ^ all connections created after this point use token t2 instead of token t1 + // | | + // | token t1 lifetime | + // |-----------------------------------------| + // | ^ + // | after this point, connections still alive that were authed with token t1 will be closed by the broker. + // | | + // | | + // | | + // | token t2 lifetime + // | |-----------------------------------------| + // | ^ all connections created after this point use token t2 + // | | + // | | + // | | + // | connection lifetime using token t1 | + // | |--------------------| | + // This case is fine, the connection exists entirely within the lifetime of token t1. + // | | + // | | + // | | + // | connection lifetime using token t2 + // | |--------------------| + // This case is fine, the connection exists entirely within the lifetime of token t2. + // | | + // | | + // | | + // | connection lifetime using token t? + // | |--------------------| + // This case is a race condition. + // We could start with either token t2 or t1. + // If we start with t1 we could go past the end of t1's lifetime. + // To avoid this issue we reduce the size of the connection lifetime by a further 25% + // + // At low values of delegation_token_lifetime all of this falls apart since something + // like a VM migration could delay shotover execution for many seconds. + // However for sufficently large delegation_token_lifetime values (> 1 hour) this should be fine. + scram_over_mtls + .delegation_token_lifetime + // match token recreation time + .mul_f32(0.5) + // further reduce connection timeout + .mul_f32(0.75) + } else { + // use 3/4 of the timeout to make sure we trigger this well before it actually times out + CONNECTIONS_MAX_IDLE_DEFAULT.mul_f32(0.75) + // TODO: relying on the default value to be unchanged is not ideal, so either: + // * query the broker for the actual value of connections.max.idle.ms + // * have the user configure it in shotover's topology.yaml + }; + if let Some(connection) = self.connections.get(&destination) { + // Since we cant be 100% exact with time anyway, we use a recent instant that can be reused to reduce syscalls. + if recent_instant.duration_since(connection.created_at) > timeout { + ConnectionState::AtRiskOfTimeout + } else { + ConnectionState::Open + } + } else { + ConnectionState::Unopened } - Ok(self.connections.get_mut(&destination).unwrap()) } } + +/// default value of kafka broker config connections.max.idle.ms (10 minutes) +const CONNECTIONS_MAX_IDLE_DEFAULT: Duration = Duration::from_secs(60 * 10); + +enum ConnectionState { + Open, + Unopened, + // TODO: maybe combine with Unopened since old_connection can just easily take the appropriate Option value + AtRiskOfTimeout, +} diff --git a/shotover/src/transforms/kafka/sink_cluster/mod.rs b/shotover/src/transforms/kafka/sink_cluster/mod.rs index a72844de2..c44ea4005 100644 --- a/shotover/src/transforms/kafka/sink_cluster/mod.rs +++ b/shotover/src/transforms/kafka/sink_cluster/mod.rs @@ -38,7 +38,7 @@ use std::collections::{HashMap, VecDeque}; use std::hash::Hasher; use std::sync::atomic::AtomicI64; use std::sync::Arc; -use std::time::Duration; +use std::time::{Duration, Instant}; use tokio::sync::RwLock; use uuid::Uuid; @@ -403,11 +403,12 @@ impl KafkaSinkCluster { &self.nodes, &self.first_contact_points, &self.rack, + Instant::now(), Destination::ControlConnection, ) .await?; - connection.send(vec![requests])?; - Ok(connection.recv().await?.remove(0)) + connection.connection.send(vec![requests])?; + Ok(connection.connection.recv().await?.remove(0)) } fn store_topic_names(&self, topics: &mut Vec, topic: TopicName) { @@ -1016,6 +1017,7 @@ routing message to a random node so that: } } + let recent_instant = Instant::now(); for (destination, requests) in broker_to_routed_requests { self.connections .get_or_open_connection( @@ -1026,9 +1028,11 @@ routing message to a random node so that: &self.nodes, &self.first_contact_points, &self.rack, + recent_instant, destination, ) .await? + .connection .send(requests.requests)?; } @@ -1041,24 +1045,23 @@ routing message to a random node so that: // Convert all received PendingRequestTy::Sent into PendingRequestTy::Received for (connection_destination, connection) in &mut self.connections.connections { self.temp_responses_buffer.clear(); - if let Ok(()) = connection.try_recv_into(&mut self.temp_responses_buffer) { - for response in self.temp_responses_buffer.drain(..) { - let mut response = Some(response); - for pending_request in &mut self.pending_requests { - if let PendingRequestTy::Sent { destination, index } = - &mut pending_request.ty - { - if destination == connection_destination { - // Store the PendingRequestTy::Received at the location of the next PendingRequestTy::Sent - // All other PendingRequestTy::Sent need to be decremented, in order to determine the PendingRequestTy::Sent - // to be used next time, and the time after that, and ... - if *index == 0 { - pending_request.ty = PendingRequestTy::Received { - response: response.take().unwrap(), - }; - } else { - *index -= 1; - } + connection + .try_recv_into(&mut self.temp_responses_buffer) + .with_context(|| format!("Failed to receive from {connection_destination:?}"))?; + for response in self.temp_responses_buffer.drain(..) { + let mut response = Some(response); + for pending_request in &mut self.pending_requests { + if let PendingRequestTy::Sent { destination, index } = &mut pending_request.ty { + if destination == connection_destination { + // Store the PendingRequestTy::Received at the location of the next PendingRequestTy::Sent + // All other PendingRequestTy::Sent need to be decremented, in order to determine the PendingRequestTy::Sent + // to be used next time, and the time after that, and ... + if *index == 0 { + pending_request.ty = PendingRequestTy::Received { + response: response.take().unwrap(), + }; + } else { + *index -= 1; } } } diff --git a/shotover/src/transforms/kafka/sink_cluster/scram_over_mtls.rs b/shotover/src/transforms/kafka/sink_cluster/scram_over_mtls.rs index 485dcb80c..70461a8e6 100644 --- a/shotover/src/transforms/kafka/sink_cluster/scram_over_mtls.rs +++ b/shotover/src/transforms/kafka/sink_cluster/scram_over_mtls.rs @@ -225,18 +225,21 @@ impl AuthorizeScramOverMtlsConfig { .iter() .map(|x| KafkaAddress::from_str(x)) .collect(); + let delegation_token_lifetime = Duration::from_secs(self.delegation_token_lifetime_seconds); Ok(AuthorizeScramOverMtlsBuilder { token_task: TokenTask::new( mtls_connection_factory, contact_points?, - Duration::from_secs(self.delegation_token_lifetime_seconds), + delegation_token_lifetime, ), + delegation_token_lifetime, }) } } pub struct AuthorizeScramOverMtlsBuilder { pub token_task: TokenTask, + pub delegation_token_lifetime: Duration, } impl AuthorizeScramOverMtlsBuilder { @@ -245,6 +248,7 @@ impl AuthorizeScramOverMtlsBuilder { original_scram_state: OriginalScramState::WaitingOnServerFirst, token_task: self.token_task.clone(), username: String::new(), + delegation_token_lifetime: self.delegation_token_lifetime, } } } @@ -256,6 +260,7 @@ pub struct AuthorizeScramOverMtls { token_task: TokenTask, /// The username used in the original scram auth to generate the delegation token username: String, + pub delegation_token_lifetime: Duration, } impl AuthorizeScramOverMtls {