diff --git a/shotover/src/transforms/redis/sink_cluster.rs b/shotover/src/transforms/redis/sink_cluster.rs index 0934710fd..80d8005f0 100644 --- a/shotover/src/transforms/redis/sink_cluster.rs +++ b/shotover/src/transforms/redis/sink_cluster.rs @@ -27,10 +27,11 @@ use rand::SeedableRng; use redis_protocol::types::Redirection; use serde::{Deserialize, Serialize}; use std::collections::{BTreeMap, HashMap, HashSet}; +use std::sync::Arc; use tokio::sync::mpsc::UnboundedSender; -use tokio::sync::oneshot; +use tokio::sync::{oneshot, RwLock}; use tokio::time::{timeout, Duration}; -use tracing::{debug, error, info, trace, warn}; +use tracing::{debug, error, trace, warn}; const SLOT_SIZE: usize = 16384; @@ -50,35 +51,73 @@ pub struct RedisSinkClusterConfig { #[async_trait(?Send)] impl TransformConfig for RedisSinkClusterConfig { async fn get_builder(&self, chain_name: String) -> Result> { - let mut cluster = RedisSinkCluster::new( - self.first_contact_points.clone(), - self.direct_destination.clone(), - self.connection_count.unwrap_or(1), + let connection_pool = ConnectionPool::new_with_auth( + Duration::from_millis(self.connect_timeout_ms), + RedisCodecBuilder::new(Direction::Sink), + RedisAuthenticator {}, self.tls.clone(), - self.connect_timeout_ms, - chain_name, )?; + Ok(Box::new(RedisSinkClusterBuilder { + first_contact_points: self.first_contact_points.clone(), + direct_destination: self.direct_destination.clone(), + connection_count: self.connection_count.unwrap_or(1), + connection_pool, + chain_name, + shared_topology: Arc::new(RwLock::new(Topology::new())), + })) + } +} - match cluster.build_connections(None).await { - Ok(()) => { - info!("connected to upstream"); - } - Err(TransformError::Upstream(RedisError::NotAuthenticated)) => { - info!("upstream requires auth"); - } - Err(e) => { - return Err(anyhow!(e).context("failed to connect to upstream")); - } - } +pub struct RedisSinkClusterBuilder { + first_contact_points: Vec, + direct_destination: Option, + 
connection_count: usize, + connection_pool: ConnectionPool, + chain_name: String, + shared_topology: Arc>, +} + +impl TransformBuilder for RedisSinkClusterBuilder { + fn build(&self) -> Transforms { + Transforms::RedisSinkCluster(RedisSinkCluster::new( + self.first_contact_points.clone(), + self.direct_destination.clone(), + self.connection_count, + self.chain_name.clone(), + self.shared_topology.clone(), + self.connection_pool.clone(), + )) + } + + fn get_name(&self) -> &'static str { + "RedisSinkCluster" + } + + fn is_terminating(&self) -> bool { + true + } +} + +#[derive(Debug, Clone)] +struct Topology { + slots: SlotMap, + channels: ChannelMap, +} - Ok(Box::new(cluster)) +impl Topology { + fn new() -> Self { + Topology { + slots: SlotMap::new(), + channels: ChannelMap::new(), + } } } -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct RedisSinkCluster { - pub slots: SlotMap, - pub channels: ChannelMap, + has_run_init: bool, + topology: Topology, + shared_topology: Arc>, direct_connection: Option>, load_scores: HashMap<(String, usize), usize>, rng: SmallRng, @@ -92,42 +131,37 @@ pub struct RedisSinkCluster { } impl RedisSinkCluster { - pub fn new( + fn new( first_contact_points: Vec, direct_destination: Option, connection_count: usize, - tls: Option, - connect_timeout_ms: u64, chain_name: String, - ) -> Result { - let authenticator = RedisAuthenticator {}; - - let connect_timeout = Duration::from_millis(connect_timeout_ms); - let connection_pool = ConnectionPool::new_with_auth( - connect_timeout, - RedisCodecBuilder::new(Direction::Sink), - authenticator, - tls, - )?; - + shared_topology: Arc>, + connection_pool: ConnectionPool< + RedisCodecBuilder, + RedisAuthenticator, + UsernamePasswordToken, + >, + ) -> Self { let sink_cluster = RedisSinkCluster { + has_run_init: false, first_contact_points, direct_destination, - slots: SlotMap::new(), - channels: ChannelMap::new(), + topology: Topology::new(), + shared_topology, direct_connection: None, load_scores: 
HashMap::new(), rng: SmallRng::from_rng(rand::thread_rng()).unwrap(), connection_count, connection_pool, reason_for_no_nodes: None, - rebuild_connections: false, + rebuild_connections: true, token: None, }; register_counter!("failed_requests", "chain" => chain_name, "transform" => sink_cluster.get_name()); - Ok(sink_cluster) + sink_cluster } async fn direct_connection(&mut self) -> Result<&UnboundedSender> { @@ -173,7 +207,7 @@ impl RedisSinkCluster { slot: u16, message: Message, ) -> Result { - if let Some((_, lookup)) = self.slots.masters.range(&slot..).next() { + if let Some((_, lookup)) = self.topology.slots.masters.range(&slot..).next() { let lookup = lookup.to_string(); let one_rx = self.choose_and_send(&lookup, message).await?; Ok(Box::pin( @@ -254,9 +288,14 @@ impl RedisSinkCluster { } fn latest_contact_points(&self) -> Vec<&str> { - if !self.slots.nodes.is_empty() { + if !self.topology.slots.nodes.is_empty() { // Use latest node addresses as contact points. - self.slots.nodes.iter().map(|x| x.as_str()).collect() + self.topology + .slots + .nodes + .iter() + .map(|x| x.as_str()) + .collect() } else { // Fallback to initial contact points. 
self.first_contact_points @@ -312,10 +351,12 @@ impl RedisSinkCluster { match self.build_connections_inner(&token).await { Ok((slots, channels)) => { debug!("connected to cluster: {:?}", channels.keys()); + self.topology = Topology { slots, channels }; + if token.is_none() { + // when authentication isn't used we can share topology between connections + *self.shared_topology.write().await = self.topology.clone(); + } self.token = token; - self.slots = slots; - self.channels = channels; - self.reason_for_no_nodes = None; self.rebuild_connections = false; Ok(()) @@ -365,7 +406,7 @@ impl RedisSinkCluster { #[inline] async fn choose_and_send(&mut self, host: &str, message: Message) -> Result { - let channel = match self.channels.get_mut(host) { + let channel = match self.topology.channels.get_mut(host) { Some(channels) if channels.len() == 1 => channels.get_mut(0), Some(channels) if channels.len() > 1 => { let candidates = rand::seq::index::sample(&mut self.rng, channels.len(), 2); @@ -409,8 +450,13 @@ impl RedisSinkCluster { { Ok(Ok(connections)) => { debug!("Found {} live connections for {}", connections.len(), host); - self.channels.insert(host.to_string(), connections); - self.channels.get_mut(host).unwrap().get_mut(0).unwrap() + self.topology.channels.insert(host.to_string(), connections); + self.topology + .channels + .get_mut(host) + .unwrap() + .get_mut(0) + .unwrap() } Ok(Err(e)) => { debug!("failed to connect to {}: {}", host, e); @@ -434,7 +480,7 @@ impl RedisSinkCluster { .is_err() { self.rebuild_connections = true; - self.channels.remove(host); + self.topology.channels.remove(host); return self.short_circuit_with_error(); } @@ -450,7 +496,7 @@ impl RedisSinkCluster { RoutingInfo::Slot(slot) => self.send_message_to_slot(slot, message).await, RoutingInfo::AllNodes(_) => { self.send_message_to_channels( - &self.slots.nodes.iter().cloned().collect_vec(), + &self.topology.slots.nodes.iter().cloned().collect_vec(), message, routing_info, ) @@ -458,7 +504,7 @@ 
impl RedisSinkCluster { } RoutingInfo::AllMasters(_) => { self.send_message_to_channels( - &self.slots.masters.values().cloned().collect_vec(), + &self.topology.slots.masters.values().cloned().collect_vec(), message, routing_info, ) @@ -466,6 +512,7 @@ impl RedisSinkCluster { } RoutingInfo::Random => { let lookup = self + .topology .slots .masters .values() @@ -932,26 +979,36 @@ fn short_circuit(frame: RedisFrame) -> Result { })) } -impl TransformBuilder for RedisSinkCluster { - fn build(&self) -> Transforms { - Transforms::RedisSinkCluster(self.clone()) - } - - fn get_name(&self) -> &'static str { - "RedisSinkCluster" - } - - fn is_terminating(&self) -> bool { - true - } -} - #[async_trait] impl Transform for RedisSinkCluster { async fn transform<'a>(&'a mut self, requests_wrapper: Wrapper<'a>) -> Result { + if !self.has_run_init { + self.topology = (*self.shared_topology.read().await).clone(); + if self.topology.channels.is_empty() { + // The code paths for authenticated and unauthenticated redis are quite different. + // * For unauthenticated redis this initial build_connections should succeed. + // + This is required to process the messages we are about to receive. + // + We also share the results to skip having to run build_connections again for new connections + // * For authenticated redis this initial build_connections always fails + // + The first message to come through should be an AUTH command which will give us the credentials required for us to run build_connections. + // As soon as we receive it we will rerun build_connections so we can process other message types afterwards. + // + It is important we do not share the results of the successful build_connections as that would leak authenticated shotover<->redis connections to other client<->shotover connections. 
+ if let Err(err) = self.build_connections(self.token.clone()).await { + match err { + TransformError::Upstream(RedisError::NotAuthenticated) => { + // Build_connections sent an internal `CLUSTER SLOTS` command to redis and redis refused to respond because it is enforcing authentication. + // When the client sends an AUTH message we will rerun build_connections. + } + _ => tracing::warn!("Error when building connections: {err:?}"), + } + } + } + self.has_run_init = true; + } + if self.rebuild_connections { if let Err(err) = self.build_connections(self.token.clone()).await { - tracing::warn!("Error when rebuilding connections {err:?}"); + tracing::warn!("Error when rebuilding connections: {err:?}"); } } @@ -989,7 +1046,7 @@ impl Transform for RedisSinkCluster { debug!("Got MOVE {} {}", slot, server); // The destination of a MOVE should always be a master. - self.slots.masters.insert(slot, server.clone()); + self.topology.slots.masters.insert(slot, server.clone()); self.rebuild_connections = true;