diff --git a/shotover/src/transforms/kafka/sink_cluster/connections.rs b/shotover/src/transforms/kafka/sink_cluster/connections.rs
index b9fbe1789..d01eb2a0e 100644
--- a/shotover/src/transforms/kafka/sink_cluster/connections.rs
+++ b/shotover/src/transforms/kafka/sink_cluster/connections.rs
@@ -1,10 +1,16 @@
-use crate::connection::SinkConnection;
-use anyhow::{Context, Result};
+use crate::{
+    connection::{ConnectionError, SinkConnection},
+    message::Message,
+};
+use anyhow::{anyhow, Context, Result};
 use fnv::FnvBuildHasher;
 use kafka_protocol::{messages::BrokerId, protocol::StrBytes};
 use metrics::Counter;
 use rand::{rngs::SmallRng, seq::SliceRandom};
-use std::collections::HashMap;
+use std::{
+    collections::HashMap,
+    time::{Duration, Instant},
+};
 
 use super::{
     node::{ConnectionFactory, KafkaAddress, KafkaNode},
@@ -27,8 +33,31 @@ pub enum Destination {
     ControlConnection,
 }
 
+pub struct KafkaConnection {
+    pub connection: SinkConnection,
+    old_connection: Option<SinkConnection>,
+    created_at: Instant,
+}
+
+impl KafkaConnection {
+    pub fn try_recv_into(&mut self, responses: &mut Vec<Message>) -> Result<(), ConnectionError> {
+        // ensure old connection is completely drained before receiving from new connection
+        if let Some(old_connection) = &mut self.old_connection {
+            old_connection.try_recv_into(responses)?;
+            if old_connection.pending_requests_count() == 0 {
+                self.old_connection = None;
+                self.connection.try_recv_into(responses)
+            } else {
+                Ok(())
+            }
+        } else {
+            self.connection.try_recv_into(responses)
+        }
+    }
+}
+
 pub struct Connections {
-    pub connections: HashMap<Destination, SinkConnection, FnvBuildHasher>,
+    pub connections: HashMap<Destination, KafkaConnection, FnvBuildHasher>,
     out_of_rack_requests: Counter,
 }
 
@@ -51,7 +80,7 @@ impl Connections {
         contact_points: &[KafkaAddress],
         local_rack: &StrBytes,
         destination: Destination,
-    ) -> Result<&mut SinkConnection> {
+    ) -> Result<&mut KafkaConnection> {
         let node = match destination {
             Destination::Id(id) => Some(nodes.iter().find(|x| x.broker_id == id).unwrap()),
             Destination::ControlConnection => None,
@@ -67,22 +96,137 @@ impl Connections {
             }
         }
 
-        // map entry API can not be used with async
-        #[allow(clippy::map_entry)]
-        if !self.connections.contains_key(&destination) {
-            let address = match &node {
-                Some(node) => &node.kafka_address,
-                None => contact_points.choose(rng).unwrap(),
-            };
+        match self.get_connection_state(authorize_scram_over_mtls, destination) {
+            ConnectionState::Open => Ok(self.connections.get_mut(&destination).unwrap()),
+            ConnectionState::Unopened => {
+                let address = match &node {
+                    Some(node) => &node.kafka_address,
+                    None => contact_points.choose(rng).unwrap(),
+                };
+
+                self.connections.insert(
+                    destination,
+                    KafkaConnection {
+                        connection: connection_factory
+                            .create_connection(address, authorize_scram_over_mtls, sasl_mechanism)
+                            .await
+                            .context("Failed to create a new connection")?,
+                        old_connection: None,
+                        created_at: Instant::now(),
+                    },
+                );
+                Ok(self.connections.get_mut(&destination).unwrap())
+            }
+            ConnectionState::AtRiskOfTimeout => {
+                let address = match &node {
+                    Some(node) => &node.kafka_address,
+                    None => contact_points.choose(rng).unwrap(),
+                };
+
+                let old_connection = self.connections.remove(&destination).unwrap();
+                if old_connection.old_connection.is_some() {
+                    return Err(anyhow!("Old connection had an old connection"));
+                }
+                let old_connection = if old_connection.connection.pending_requests_count() == 0 {
+                    None
+                } else {
+                    Some(old_connection.connection)
+                };
 
-            self.connections.insert(
-                destination,
-                connection_factory
-                    .create_connection(address, authorize_scram_over_mtls, sasl_mechanism)
-                    .await
-                    .context("Failed to create a new connection")?,
-            );
+                self.connections.insert(
+                    destination,
+                    KafkaConnection {
+                        connection: connection_factory
+                            .create_connection(address, authorize_scram_over_mtls, sasl_mechanism)
+                            .await
+                            .context("Failed to create a new connection")?,
+                        old_connection,
+                        created_at: Instant::now(),
+                    },
+                );
+                tracing::info!("Recreated outgoing connection due to risk of timeout");
+                Ok(self.connections.get_mut(&destination).unwrap())
+            }
         }
-        Ok(self.connections.get_mut(&destination).unwrap())
     }
+
+    fn get_connection_state(
+        &self,
+        authorize_scram_over_mtls: &Option<AuthorizeScramOverMtls>,
+        destination: Destination,
+    ) -> ConnectionState {
+        let timeout = if let Some(scram_over_mtls) = authorize_scram_over_mtls {
+            // The delegation token is recreated after `0.5 * delegation_token_lifetime`
+            // Consider what happens when we match that timing for our connection timeout here:
+            //
+            //     create token t1      create token t2
+            // |--------------------|--------------------|
+            // |                    ^ all connections created after this point use token t2 instead of token t1
+            // |                    |
+            // | token t1 lifetime  |
+            // |-----------------------------------------|
+            // |                    |                     ^
+            // |                    | after this point, connections still alive that were authed with token t1 will be closed by the broker.
+            // |                    |
+            // |                    | token t2 lifetime
+            // |                    |-----------------------------------------|
+            // |                    ^ all connections created after this point use token t2
+            // |                    |
+            // | connection lifetime using token t1      |
+            // |   |--------------------|                |
+            // This case is fine, the connection exists entirely within the lifetime of token t1.
+            // |                    |
+            // |                    | connection lifetime using token t2
+            // |                    |   |--------------------|
+            // This case is fine, the connection exists entirely within the lifetime of token t2.
+            // |                    |
+            // |                    | connection lifetime using token t2 (or is it token t1?)
+            // |                    |--------------------|
+            // This case is potentially a race condition.
+            // We could start with either token t2 or t1.
+            // If we start with t1 we could go past the end of t1's lifetime.
+            // To avoid this issue we reduce the connection lifetime by a further factor of 0.75.
+            //
+            // At low values of delegation_token_lifetime all of this falls apart since something
+            // like a VM migration could delay shotover execution for many seconds.
+            // However for sufficiently large delegation_token_lifetime values (> 1 hour) this should be fine.
+            scram_over_mtls
+                .delegation_token_lifetime
+                // match token recreation time
+                .mul_f32(0.5)
+                // further reduce connection timeout
+                .mul_f32(0.75)
+        } else {
+            // TODO: make this configurable
+            // for now, use the default value of connections.max.idle.ms (10 minutes)
+            let connections_max_idle = Duration::from_secs(60 * 10);
+            connections_max_idle.mul_f32(0.75)
+        };
+        if let Some(connection) = self.connections.get(&destination) {
+            // TODO: subtract a batch level Instant::now instead of using elapsed
+            // use 3/4 of the timeout to make sure we trigger this well before it actually times out
+            if connection.created_at.elapsed() > timeout {
+                ConnectionState::AtRiskOfTimeout
+            } else {
+                ConnectionState::Open
+            }
+        } else {
+            ConnectionState::Unopened
+        }
+    }
+}
+
+enum ConnectionState {
+    Open,
+    Unopened,
+    // TODO: maybe combine with Unopened since old_connection can just take the appropriate Option value
+    AtRiskOfTimeout,
+}
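Note (not part of the patch): the two `mul_f32` calls in `get_connection_state` compose multiplicatively. A minimal standalone sketch of the resulting timeouts, std only, with illustrative numbers:

use std::time::Duration;

// Mirrors the selection above: half the token lifetime to match token
// recreation, then a further 0.75 safety factor. Without SCRAM-over-mTLS
// the assumed connections.max.idle.ms default of 10 minutes applies.
fn connection_timeout(delegation_token_lifetime: Option<Duration>) -> Duration {
    match delegation_token_lifetime {
        Some(lifetime) => lifetime.mul_f32(0.5).mul_f32(0.75),
        None => Duration::from_secs(60 * 10).mul_f32(0.75),
    }
}

fn main() {
    // 24h token lifetime: 24h * 0.5 * 0.75 = 9h
    assert_eq!(
        connection_timeout(Some(Duration::from_secs(24 * 60 * 60))),
        Duration::from_secs(9 * 60 * 60)
    );
    // no delegation tokens: 10min * 0.75 = 7.5min
    assert_eq!(connection_timeout(None), Duration::from_secs(450));
}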
diff --git a/shotover/src/transforms/kafka/sink_cluster/mod.rs b/shotover/src/transforms/kafka/sink_cluster/mod.rs
index a5f2ee714..4c1c38cf2 100644
--- a/shotover/src/transforms/kafka/sink_cluster/mod.rs
+++ b/shotover/src/transforms/kafka/sink_cluster/mod.rs
@@ -403,8 +403,8 @@ impl KafkaSinkCluster {
                 Destination::ControlConnection,
             )
             .await?;
-        connection.send(vec![requests])?;
-        Ok(connection.recv().await?.remove(0))
+        connection.connection.send(vec![requests])?;
+        Ok(connection.connection.recv().await?.remove(0))
     }
 
     fn store_topic_names(&self, topics: &mut Vec<TopicName>, topic: TopicName) {
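A note on the `connection.connection.send(...)` spelling: a `Deref`/`DerefMut` forwarding impl, sketched below hypothetically (it is not in this patch), would have left old call sites compiling unchanged. The explicit field access arguably reads better here, since it keeps raw sends on the new connection visibly distinct from the drain-aware `KafkaConnection::try_recv_into`.

// Hypothetical alternative, NOT what the patch does: forward to the inner
// SinkConnection so that existing `connection.send(...)` call sites still compile.
use std::ops::{Deref, DerefMut};

impl Deref for KafkaConnection {
    type Target = SinkConnection;
    fn deref(&self) -> &SinkConnection {
        &self.connection
    }
}

impl DerefMut for KafkaConnection {
    fn deref_mut(&mut self) -> &mut SinkConnection {
        &mut self.connection
    }
}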
@@ -1026,6 +1026,7 @@ routing message to a random node so that:
                     destination,
                 )
                 .await?
+                .connection
                 .send(requests.requests)?;
         }
 
@@ -1038,24 +1039,23 @@ routing message to a random node so that:
         // Convert all received PendingRequestTy::Sent into PendingRequestTy::Received
         for (connection_destination, connection) in &mut self.connections.connections {
             self.temp_responses_buffer.clear();
-            if let Ok(()) = connection.try_recv_into(&mut self.temp_responses_buffer) {
-                for response in self.temp_responses_buffer.drain(..) {
-                    let mut response = Some(response);
-                    for pending_request in &mut self.pending_requests {
-                        if let PendingRequestTy::Sent { destination, index } =
-                            &mut pending_request.ty
-                        {
-                            if destination == connection_destination {
-                                // Store the PendingRequestTy::Received at the location of the next PendingRequestTy::Sent
-                                // All other PendingRequestTy::Sent need to be decremented, in order to determine the PendingRequestTy::Sent
-                                // to be used next time, and the time after that, and ...
-                                if *index == 0 {
-                                    pending_request.ty = PendingRequestTy::Received {
-                                        response: response.take().unwrap(),
-                                    };
-                                } else {
-                                    *index -= 1;
-                                }
+            connection
+                .try_recv_into(&mut self.temp_responses_buffer)
+                .with_context(|| format!("Failed to receive from {connection_destination:?}"))?;
+            for response in self.temp_responses_buffer.drain(..) {
+                let mut response = Some(response);
+                for pending_request in &mut self.pending_requests {
+                    if let PendingRequestTy::Sent { destination, index } = &mut pending_request.ty {
+                        if destination == connection_destination {
+                            // Store the PendingRequestTy::Received at the location of the next PendingRequestTy::Sent
+                            // All other PendingRequestTy::Sent need to be decremented, in order to determine the PendingRequestTy::Sent
+                            // to be used next time, and the time after that, and ...
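The `index` bookkeeping above is the subtle part of this hunk: each `PendingRequestTy::Sent` records how many earlier unanswered requests to the same destination precede it, so responses arriving per connection can be matched back to requests in order. A self-contained toy model of that matching, using simplified stand-in types rather than shotover's:

// A response from `from` completes the oldest Sent entry for that
// destination (index 0) and shifts every later Sent entry down by one.
#[derive(Debug, PartialEq)]
enum Pending {
    Sent { destination: u32, index: usize },
    Received { response: String },
}

fn on_response(pending: &mut [Pending], from: u32, response: String) {
    let mut response = Some(response);
    for p in pending {
        if let Pending::Sent { destination, index } = p {
            if *destination == from {
                if *index == 0 {
                    *p = Pending::Received {
                        response: response.take().unwrap(),
                    };
                } else {
                    *index -= 1;
                }
            }
        }
    }
}

fn main() {
    // Two in-flight requests to broker 0, interleaved with one to broker 1.
    let mut pending = vec![
        Pending::Sent { destination: 0, index: 0 },
        Pending::Sent { destination: 1, index: 0 },
        Pending::Sent { destination: 0, index: 1 },
    ];
    // A broker 0 response fills the first broker 0 slot and promotes the second.
    on_response(&mut pending, 0, "r1".into());
    assert_eq!(pending[0], Pending::Received { response: "r1".into() });
    assert_eq!(pending[2], Pending::Sent { destination: 0, index: 0 });
}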
+                            if *index == 0 {
+                                pending_request.ty = PendingRequestTy::Received {
+                                    response: response.take().unwrap(),
+                                };
+                            } else {
+                                *index -= 1;
+                            }
                         }
                     }
diff --git a/shotover/src/transforms/kafka/sink_cluster/scram_over_mtls.rs b/shotover/src/transforms/kafka/sink_cluster/scram_over_mtls.rs
index 485dcb80c..70461a8e6 100644
--- a/shotover/src/transforms/kafka/sink_cluster/scram_over_mtls.rs
+++ b/shotover/src/transforms/kafka/sink_cluster/scram_over_mtls.rs
@@ -225,18 +225,21 @@ impl AuthorizeScramOverMtlsConfig {
             .iter()
             .map(|x| KafkaAddress::from_str(x))
             .collect();
+        let delegation_token_lifetime = Duration::from_secs(self.delegation_token_lifetime_seconds);
         Ok(AuthorizeScramOverMtlsBuilder {
             token_task: TokenTask::new(
                 mtls_connection_factory,
                 contact_points?,
-                Duration::from_secs(self.delegation_token_lifetime_seconds),
+                delegation_token_lifetime,
             ),
+            delegation_token_lifetime,
         })
     }
 }
 
 pub struct AuthorizeScramOverMtlsBuilder {
     pub token_task: TokenTask,
+    pub delegation_token_lifetime: Duration,
 }
 
 impl AuthorizeScramOverMtlsBuilder {
@@ -245,6 +248,7 @@ impl AuthorizeScramOverMtlsBuilder {
             original_scram_state: OriginalScramState::WaitingOnServerFirst,
             token_task: self.token_task.clone(),
             username: String::new(),
+            delegation_token_lifetime: self.delegation_token_lifetime,
         }
     }
 }
@@ -256,6 +260,7 @@ pub struct AuthorizeScramOverMtls {
     token_task: TokenTask,
     /// The username used in the original scram auth to generate the delegation token
     username: String,
+    pub delegation_token_lifetime: Duration,
 }
 
 impl AuthorizeScramOverMtls {
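Taken together with the new `get_connection_state`, the `delegation_token_lifetime` plumbed through here provides the margin the timing diagram argues for: tokens are recreated at 0.5x the lifetime and connections at 0.375x, so even a connection opened just before a token rollover is replaced before its token expires. A quick check of that arithmetic, with an assumed illustrative lifetime of 24 hours:

use std::time::Duration;

fn main() {
    // Assumed config value for illustration: delegation_token_lifetime_seconds = 86400 (24h)
    let token_lifetime = Duration::from_secs(86400);
    let token_recreation = token_lifetime.mul_f32(0.5); // tokens rotate every 12h
    let connection_timeout = token_recreation.mul_f32(0.75); // connections recreated after 9h

    // Worst case: a connection opened an instant before token t2 replaces t1
    // may still hold t1; it must be recreated before t1 expires.
    assert!(token_recreation + connection_timeout < token_lifetime); // 12h + 9h < 24h
}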