diff --git a/redis/Cargo.toml b/redis/Cargo.toml
index 2de7272cc6..95c7091bcb 100644
--- a/redis/Cargo.toml
+++ b/redis/Cargo.toml
@@ -43,6 +43,8 @@ pin-project-lite = { version = "0.2", optional = true }
 tokio-util = { version = "0.7", optional = true }
 tokio = { version = "1", features = ["rt", "net", "time"], optional = true }
 socket2 = { version = "0.4", default-features = false, optional = true }
+fast-math = { version = "0.1.1", optional = true }
+dispose = { version = "0.5.0", optional = true }
 
 # Only needed for the connection manager
 arc-swap = { version = "1.1.0", optional = true }
@@ -92,7 +94,7 @@ arcstr = "1.1.5"
 [features]
 default = ["acl", "streams", "geospatial", "script", "keep-alive"]
 acl = []
-aio = ["bytes", "pin-project-lite", "futures-util", "futures-util/alloc", "futures-util/sink", "tokio/io-util", "tokio-util", "tokio-util/codec", "tokio/sync", "combine/tokio", "async-trait", "futures-time"]
+aio = ["bytes", "pin-project-lite", "futures-util", "futures-util/alloc", "futures-util/sink", "tokio/io-util", "tokio-util", "tokio-util/codec", "tokio/sync", "combine/tokio", "async-trait", "futures-time", "fast-math", "dispose"]
 geospatial = []
 json = ["serde", "serde/derive", "serde_json"]
 cluster = ["crc16", "rand", "derivative"]
diff --git a/redis/src/cluster_async/connections_container.rs b/redis/src/cluster_async/connections_container.rs
index c9cf0c4ff2..3e6a10dff8 100644
--- a/redis/src/cluster_async/connections_container.rs
+++ b/redis/src/cluster_async/connections_container.rs
@@ -5,7 +5,7 @@ use arcstr::ArcStr;
 use rand::seq::IteratorRandom;
 
 use crate::cluster_routing::{MultipleNodeRoutingInfo, Route, SlotAddr};
-use crate::cluster_topology::{ReadFromReplicaStrategy, SlotMap, SlotMapValue};
+use crate::cluster_topology::{ReadFromReplicaStrategy, SlotMap, SlotMapValue, TopologyHash};
 
 type IdentifierType = ArcStr;
 
@@ -32,6 +32,7 @@ pub(crate) struct ConnectionsContainer {
     connection_map: HashMap>>,
     slot_map: SlotMap,
     read_from_replica_strategy: ReadFromReplicaStrategy,
+    topology_hash: TopologyHash,
 }
 
 impl Default for ConnectionsContainer {
@@ -40,6 +41,7 @@ impl Default for ConnectionsContainer {
             connection_map: Default::default(),
             slot_map: Default::default(),
             read_from_replica_strategy: ReadFromReplicaStrategy::AlwaysFromPrimary,
+            topology_hash: 0,
         }
     }
 }
@@ -54,6 +56,7 @@ where
         slot_map: SlotMap,
         connection_map: HashMap>,
         read_from_replica_strategy: ReadFromReplicaStrategy,
+        topology_hash: TopologyHash,
     ) -> Self {
         Self {
             connection_map: connection_map
@@ -62,6 +65,7 @@ where
                 .collect(),
             slot_map,
             read_from_replica_strategy,
+            topology_hash,
         }
     }
 
@@ -221,6 +225,10 @@ where
             .filter(|(_, conn_option)| conn_option.is_some())
             .count()
     }
+
+    pub(crate) fn get_current_topology_hash(&self) -> TopologyHash {
+        self.topology_hash
+    }
 }
 
 #[cfg(test)]
@@ -310,6 +318,7 @@ mod tests {
             slot_map,
             connection_map,
             read_from_replica_strategy: stragey,
+            topology_hash: 0,
         }
     }
 
diff --git a/redis/src/cluster_async/mod.rs b/redis/src/cluster_async/mod.rs
index aa6585ab25..14049f7182 100644
--- a/redis/src/cluster_async/mod.rs
+++ b/redis/src/cluster_async/mod.rs
@@ -32,7 +32,7 @@ use std::{
     net::{IpAddr, SocketAddr},
     pin::Pin,
     sync::{
-        atomic::{self, AtomicUsize},
+        atomic::{self, AtomicUsize, Ordering},
         Arc, Mutex,
     },
     task::{self, Poll},
@@ -48,12 +48,13 @@ use crate::{
         SingleNodeRoutingInfo,
     },
     cluster_topology::{
-        calculate_topology, DEFAULT_REFRESH_SLOTS_RETRY_INITIAL_INTERVAL,
-        DEFAULT_REFRESH_SLOTS_RETRY_TIMEOUT,
+        calculate_topology, DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES,
+        DEFAULT_REFRESH_SLOTS_RETRY_INITIAL_INTERVAL, DEFAULT_REFRESH_SLOTS_RETRY_TIMEOUT,
     },
     Cmd, ConnectionInfo, ErrorKind, IntoConnectionInfo, RedisError, RedisFuture, RedisResult,
     Value,
 };
+use std::time::Duration;
 
 #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))]
 use crate::aio::{async_std::AsyncStd, RedisRuntime};
@@ -68,6 +69,7 @@ use backoff_tokio::future::retry;
 #[cfg(feature = "tokio-comp")]
 use backoff_tokio::{Error, ExponentialBackoff};
 
+use dispose::{Disposable, Dispose};
 use futures::{
     future::{self, BoxFuture},
     prelude::*,
@@ -75,6 +77,7 @@
 };
 use futures_time::future::FutureExt;
 use pin_project_lite::pin_project;
+use std::sync::atomic::AtomicBool;
 use tokio::sync::{
     mpsc,
     oneshot::{self, Receiver},
@@ -198,6 +201,7 @@ struct InnerCore {
    conn_lock: RwLock>,
    cluster_params: ClusterParams,
    pending_requests: Mutex>>,
+    slot_refresh_in_progress: AtomicBool,
 }
 
 type Core = Arc>;
@@ -212,6 +216,14 @@ struct ClusterConnInner {
         >,
     >,
     refresh_error: Option,
+    // A flag indicating the connection's closure and the requirement to shut down all related tasks.
+    shutdown_flag: Arc<AtomicBool>,
 }
+
+impl<C> Dispose for ClusterConnInner<C> {
+    fn dispose(self) {
+        self.shutdown_flag.store(true, Ordering::Relaxed);
+    }
+}
 
 #[derive(Clone)]
@@ -478,25 +490,42 @@ where
     async fn new(
         initial_nodes: &[ConnectionInfo],
         cluster_params: ClusterParams,
-    ) -> RedisResult<Self> {
+    ) -> RedisResult<Disposable<Self>> {
         let connections = Self::create_initial_connections(initial_nodes, &cluster_params).await?;
+        let topology_checks_interval = cluster_params.topology_checks_interval;
         let inner = Arc::new(InnerCore {
             conn_lock: RwLock::new(ConnectionsContainer::new(
                 Default::default(),
                 connections,
                 cluster_params.read_from_replicas,
+                0,
             )),
             cluster_params,
             pending_requests: Mutex::new(Vec::new()),
+            slot_refresh_in_progress: AtomicBool::new(false),
         });
-        let mut connection = ClusterConnInner {
+        let shutdown_flag = Arc::new(AtomicBool::new(false));
+        let connection = ClusterConnInner {
             inner,
             in_flight_requests: Default::default(),
             refresh_error: None,
             state: ConnectionState::PollComplete,
+            shutdown_flag: shutdown_flag.clone(),
         };
-        connection.refresh_slots_with_retries().await?;
-        Ok(connection)
+        Self::refresh_slots_with_retries(connection.inner.clone()).await?;
+        if let Some(duration) = topology_checks_interval {
+            let periodic_task = ClusterConnInner::periodic_topology_check(
+                connection.inner.clone(),
+                duration,
+                shutdown_flag,
+            );
+            #[cfg(feature = "tokio-comp")]
+            tokio::spawn(periodic_task);
+            #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))]
+            AsyncStd::spawn(periodic_task);
+        }
+
+        Ok(Disposable::new(connection))
     }
 
     /// Go through each of the initial nodes and attempt to retrieve all IP entries from them.
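Editor's aside, not part of the patch: the changes above tie the lifetime of the spawned periodic task to the connection by wrapping `ClusterConnInner` in `Disposable` and raising a shared `AtomicBool` from `Dispose::dispose`. Below is a minimal, self-contained sketch of that shutdown pattern; it uses the same `dispose` crate API (`Dispose`, `Disposable::new`) but hypothetical names (`Worker`, `spawn_worker`) and a plain thread in place of a tokio/async-std task.

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::Duration;

use dispose::{Disposable, Dispose};

// Hypothetical stand-in for ClusterConnInner: it only owns the shutdown flag.
struct Worker {
    shutdown_flag: Arc<AtomicBool>,
}

impl Dispose for Worker {
    fn dispose(self) {
        // Runs when the Disposable wrapper is dropped; the background loop
        // observes the flag on its next tick and exits.
        self.shutdown_flag.store(true, Ordering::Relaxed);
    }
}

fn spawn_worker() -> Disposable<Worker> {
    let shutdown_flag = Arc::new(AtomicBool::new(false));
    let task_flag = shutdown_flag.clone();
    thread::spawn(move || loop {
        if task_flag.load(Ordering::Relaxed) {
            return; // owner dropped: stop the periodic work
        }
        // ... periodic work (e.g. a topology check) would go here ...
        thread::sleep(Duration::from_millis(100));
    });
    Disposable::new(Worker { shutdown_flag })
}

fn main() {
    let worker = spawn_worker();
    thread::sleep(Duration::from_millis(300));
    drop(worker); // dispose() runs here, so the loop cannot outlive its owner
}
```

The patch gives `ClusterConnInner` the same contract: dropping the `Disposable` wrapper is what stops `periodic_topology_check`, since the spawned task only holds a clone of the flag and checks it at the top of every loop iteration.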
@@ -700,28 +729,91 @@ where
     }
 
     // Query a node to discover slot-> master mappings with retries
-    fn refresh_slots_with_retries(&mut self) -> impl Future<Output = RedisResult<()>> {
-        let inner = self.inner.clone();
-        async move {
+    async fn refresh_slots_with_retries(inner: Arc<InnerCore<C>>) -> RedisResult<()> {
+        if inner
+            .slot_refresh_in_progress
+            .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
+            .is_err()
+        {
+            return Ok(());
+        }
+        let retry_strategy = ExponentialBackoff {
+            initial_interval: DEFAULT_REFRESH_SLOTS_RETRY_INITIAL_INTERVAL,
+            max_interval: DEFAULT_REFRESH_SLOTS_RETRY_TIMEOUT,
+            ..Default::default()
+        };
+        let retries_counter = AtomicUsize::new(0);
+        let res = retry(retry_strategy, || {
+            let curr_retry = retries_counter.fetch_add(1, atomic::Ordering::Relaxed);
+            Self::refresh_slots(inner.clone(), curr_retry).map_err(Error::from)
+        })
+        .await;
+        inner
+            .slot_refresh_in_progress
+            .store(false, Ordering::Relaxed);
+        res
+    }
+
+    async fn periodic_topology_check(
+        inner: Arc<InnerCore<C>>,
+        interval_duration: Duration,
+        shutdown_flag: Arc<AtomicBool>,
+    ) {
+        loop {
+            if shutdown_flag.load(Ordering::Relaxed) {
+                return;
+            }
+            #[cfg(feature = "tokio-comp")]
+            tokio::time::sleep(interval_duration).await;
+            #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))]
+            async_std::task::sleep(interval_duration).await;
             let retry_strategy = ExponentialBackoff {
                 initial_interval: DEFAULT_REFRESH_SLOTS_RETRY_INITIAL_INTERVAL,
                 max_interval: DEFAULT_REFRESH_SLOTS_RETRY_TIMEOUT,
                 ..Default::default()
             };
-            let retries_counter = AtomicUsize::new(0);
-            retry(retry_strategy, || {
-                retries_counter.fetch_add(1, atomic::Ordering::Relaxed);
-                Self::refresh_slots(
-                    inner.clone(),
-                    retries_counter.load(atomic::Ordering::Relaxed),
-                )
-                .map_err(Error::from)
+            let topology_check_res = retry(retry_strategy, || {
+                Self::check_for_topology_diff(inner.clone()).map_err(Error::from)
             })
-            .await?;
-            Ok(())
+            .await;
+            if topology_check_res.is_ok() && topology_check_res.unwrap() {
+                let _ = Self::refresh_slots_with_retries(inner.clone()).await;
+            };
         }
     }
 
+    /// Queries log2n nodes (where n represents the number of cluster nodes) to determine whether their
+    /// topology view differs from the one currently stored in the connection manager.
+    /// Returns true if change was detected, otherwise false.
+    async fn check_for_topology_diff(inner: Arc<InnerCore<C>>) -> RedisResult<bool> {
+        let read_guard = inner.conn_lock.read().await;
+        let num_of_nodes: usize = read_guard.len();
+        // TODO: Starting from Rust V1.67, integers has logarithms support.
+        // When we no longer need to support Rust versions < 1.67, remove fast_math and transition to the ilog2 function.
+        let num_of_nodes_to_query =
+            std::cmp::max(fast_math::log2_raw(num_of_nodes as f32) as usize, 1);
+        let requested_nodes = read_guard.random_connections(num_of_nodes_to_query);
+        let topology_join_results =
+            futures::future::join_all(requested_nodes.map(|conn| async move {
+                let mut conn: C = conn.1.await;
+                conn.req_packed_command(&slot_cmd()).await
+            }))
+            .await;
+        let topology_values: Vec<_> = topology_join_results
+            .into_iter()
+            .filter_map(|r| r.ok())
+            .collect();
+        let (_, found_topology_hash) = calculate_topology(
+            topology_values,
+            DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES,
+            inner.cluster_params.tls,
+            num_of_nodes_to_query,
+            inner.cluster_params.read_from_replicas,
+        )?;
+        let change_found = read_guard.get_current_topology_hash() != found_topology_hash;
+        Ok(change_found)
+    }
+
     // Query a node to discover slot-> master mappings
     async fn refresh_slots(inner: Arc>, curr_retry: usize) -> RedisResult<()> {
         let read_guard = inner.conn_lock.read().await;
@@ -739,7 +831,7 @@ where
             .into_iter()
             .filter_map(|r| r.ok())
             .collect();
-        let new_slots = calculate_topology(
+        let (new_slots, topology_hash) = calculate_topology(
             topology_values,
             curr_retry,
             inner.cluster_params.tls,
@@ -799,6 +891,7 @@ where
             new_slots,
             new_connections,
             inner.cluster_params.read_from_replicas,
+            topology_hash,
         );
         Ok(())
     }
@@ -1075,7 +1168,7 @@ where
             }
             Poll::Ready(Err(err)) => {
                 self.state = ConnectionState::Recover(RecoverFuture::RecoverSlots(Box::pin(
-                    self.refresh_slots_with_retries(),
+                    Self::refresh_slots_with_retries(self.inner.clone()),
                 )));
                 Poll::Ready(Err(err))
             }
@@ -1250,7 +1343,7 @@ impl PollFlushAction {
     }
 }
 
-impl<C> Sink<Message<C>> for ClusterConnInner<C>
+impl<C> Sink<Message<C>> for Disposable<ClusterConnInner<C>>
 where
     C: ConnectionLike + Connect + Clone + Send + Sync + Unpin + 'static,
 {
@@ -1332,9 +1425,10 @@ where
             ConnectionState::PollComplete => match ready!(self.poll_complete(cx)) {
                 PollFlushAction::None => return Poll::Ready(Ok(())),
                 PollFlushAction::RebuildSlots => {
-                    self.state = ConnectionState::Recover(RecoverFuture::RecoverSlots(
-                        Box::pin(self.refresh_slots_with_retries()),
-                    ));
+                    self.state =
+                        ConnectionState::Recover(RecoverFuture::RecoverSlots(Box::pin(
+                            ClusterConnInner::refresh_slots_with_retries(self.inner.clone()),
+                        )));
                 }
                 PollFlushAction::Reconnect(identifiers) => {
                     self.state = ConnectionState::Recover(RecoverFuture::Reconnect(Box::pin(
diff --git a/redis/src/cluster_client.rs b/redis/src/cluster_client.rs
index 78685a70a7..8a42573073 100644
--- a/redis/src/cluster_client.rs
+++ b/redis/src/cluster_client.rs
@@ -21,6 +21,7 @@ struct BuilderParams {
     tls: Option,
     retries_configuration: RetryParams,
     connection_timeout: Option,
+    topology_checks_interval: Option<Duration>,
 }
 
 #[derive(Clone)]
@@ -72,6 +73,7 @@ pub(crate) struct ClusterParams {
     pub(crate) tls: Option,
     pub(crate) retry_params: RetryParams,
     pub(crate) connection_timeout: Duration,
+    pub(crate) topology_checks_interval: Option<Duration>,
 }
 
 impl From for ClusterParams {
@@ -83,6 +85,7 @@ impl From for ClusterParams {
             tls: value.tls,
             retry_params: value.retries_configuration,
             connection_timeout: value.connection_timeout.unwrap_or(Duration::MAX),
+            topology_checks_interval: value.topology_checks_interval,
         }
     }
 }
@@ -250,6 +253,17 @@ impl ClusterClientBuilder {
         self
     }
 
+    /// Enables periodic topology checks for this client.
+    ///
+    /// If enabled, periodic topology checks will be executed at the configured intervals to examine whether there
+    /// have been any changes in the cluster's topology. If a change is detected, it will trigger a slot refresh.
+    /// Unlike slot refreshments, the periodic topology checks only examine a limited number of nodes to query their
+    /// topology, ensuring that the check remains quick and efficient.
+    pub fn periodic_topology_checks(mut self, interval: Duration) -> ClusterClientBuilder {
+        self.builder_params.topology_checks_interval = Some(interval);
+        self
+    }
+
     /// Use `build()`.
     #[deprecated(since = "0.22.0", note = "Use build()")]
     pub fn open(self) -> RedisResult {
diff --git a/redis/src/cluster_topology.rs b/redis/src/cluster_topology.rs
index aa138e5b26..bd0850cea6 100644
--- a/redis/src/cluster_topology.rs
+++ b/redis/src/cluster_topology.rs
@@ -18,12 +18,13 @@
 pub const DEFAULT_REFRESH_SLOTS_RETRY_TIMEOUT: Duration = Duration::from_secs(1);
 pub const DEFAULT_REFRESH_SLOTS_RETRY_INITIAL_INTERVAL: Duration = Duration::from_millis(100);
 pub(crate) const SLOT_SIZE: u16 = 16384;
+pub(crate) type TopologyHash = u64;
 
 #[derive(Derivative)]
 #[derivative(PartialEq, Eq)]
 #[derive(Debug)]
 pub(crate) struct TopologyView {
-    pub(crate) hash_value: u64,
+    pub(crate) hash_value: TopologyHash,
     #[derivative(PartialEq = "ignore")]
     pub(crate) topology_value: Value,
     #[derivative(PartialEq = "ignore")]
@@ -282,7 +283,7 @@ pub(crate) fn calculate_topology(
     tls_mode: Option,
     num_of_queried_nodes: usize,
     read_from_replica: ReadFromReplicaStrategy,
-) -> Result<SlotMap, RedisError> {
+) -> Result<(SlotMap, TopologyHash), RedisError> {
     if topology_views.is_empty() {
         return Err(RedisError::from((
             ErrorKind::ResponseError,
@@ -355,7 +356,10 @@ pub(crate) fn calculate_topology(
             "Failed to parse the slots on the majority view",
         )))?;
 
-        Ok(SlotMap::new(slots_data, read_from_replica))
+        Ok((
+            SlotMap::new(slots_data, read_from_replica),
+            most_frequent_topology.hash_value,
+        ))
     };
 
     if non_unique_max_node_count {
@@ -471,7 +475,7 @@ mod tests {
            get_view(&ViewType::SingleNodeViewFullCoverage),
            get_view(&ViewType::TwoNodesViewFullCoverage),
        ];
-        let topology_view = calculate_topology(
+        let (topology_view, _) = calculate_topology(
            topology_results,
            1,
            None,
@@ -513,7 +517,7 @@ mod tests {
            get_view(&ViewType::TwoNodesViewFullCoverage),
            get_view(&ViewType::TwoNodesViewMissingSlots),
        ];
-        let topology_view = calculate_topology(
+        let (topology_view, _) = calculate_topology(
            topology_results,
            3,
            None,
@@ -536,7 +540,7 @@ mod tests {
            get_view(&ViewType::TwoNodesViewFullCoverage),
            get_view(&ViewType::TwoNodesViewMissingSlots),
        ];
-        let topology_view = calculate_topology(
+        let (topology_view, _) = calculate_topology(
            topology_results,
            1,
            None,
@@ -560,7 +564,7 @@ mod tests {
            get_view(&ViewType::SingleNodeViewMissingSlots),
            get_view(&ViewType::TwoNodesViewMissingSlots),
        ];
-        let topology_view = calculate_topology(
+        let (topology_view, _) = calculate_topology(
            topology_results,
            1,
            None,
@@ -584,7 +588,7 @@ mod tests {
            get_view(&ViewType::TwoNodesViewMissingSlots),
            get_view(&ViewType::SingleNodeViewMissingSlots),
        ];
-        let topology_view = calculate_topology(
+        let (topology_view, _) = calculate_topology(
            topology_results,
            1,
            None,
diff --git a/redis/tests/test_cluster_async.rs b/redis/tests/test_cluster_async.rs
index 6581737cd8..b14555fdc1 100644
--- a/redis/tests/test_cluster_async.rs
+++ b/redis/tests/test_cluster_async.rs
@@ -1,5 +1,6 @@
 #![cfg(feature = "cluster-async")]
 mod support;
+use std::assert;
 use std::net::{IpAddr, SocketAddr};
 use std::sync::{
     atomic::{self, AtomicI32, AtomicU16},
@@ -7,9 +8,15 @@
     Arc,
 };
 
+use crate::support::*;
 use futures::prelude::*;
 use futures::stream;
+use futures_time::task::sleep;
 use once_cell::sync::Lazy;
+use redis::cluster_routing::Route;
+use redis::cluster_routing::SingleNodeRoutingInfo;
+use redis::cluster_routing::SlotAddr;
+
 use redis::{
     aio::{ConnectionLike, MultiplexedConnection},
     cluster::ClusterClient,
@@ -19,9 +26,8 @@
     cmd, parse_redis_value, AsyncCommands, Cmd, ErrorKind, InfoDict, IntoConnectionInfo,
     RedisError, RedisFuture, RedisResult, Script, Value,
 };
-
-use crate::support::*;
-
+use std::str::from_utf8;
+use std::time::Duration;
 #[test]
 fn test_async_cluster_basic_cmd() {
     let cluster = TestClusterContext::new(3, 0);
@@ -1709,3 +1715,89 @@ fn test_async_cluster_round_robin_read_from_replica() {
     found_ports.lock().unwrap().sort();
     assert_eq!(*found_ports.lock().unwrap(), vec![6380, 6381, 6383, 6384]);
 }
+
+fn get_queried_node_id_if_master(cluster_nodes_output: Value) -> Option<String> {
+    // Returns the node ID of the connection that was queried for CLUSTER NODES (using the 'myself' flag), if it's a master.
+    // Otherwise, returns None.
+    match cluster_nodes_output {
+        Value::Data(val) => match from_utf8(&val) {
+            Ok(str_res) => {
+                let parts: Vec<&str> = str_res.split('\n').collect();
+                for node_entry in parts {
+                    if node_entry.contains("myself") && node_entry.contains("master") {
+                        let node_entry_parts: Vec<&str> = node_entry.split(' ').collect();
+                        let node_id = node_entry_parts[0];
+                        return Some(node_id.to_string());
+                    }
+                }
+                None
+            }
+            Err(e) => panic!("failed to decode INFO response: {:?}", e),
+        },
+        _ => panic!("Recieved unexpected response: {:?}", cluster_nodes_output),
+    }
+}
+#[test]
+fn test_async_cluster_periodic_checks_update_topology_after_failover() {
+    // This test aims to validate the functionality of periodic topology checks by detecting and updating topology changes.
+    // We will repeatedly execute CLUSTER NODES commands against the primary node responsible for slot 0, recording its node ID.
+    // Once we've successfully completed commands with the current primary, we will initiate a failover within the same shard.
+    // Since we are not executing key-based commands, we won't encounter MOVED errors that trigger a slot refresh.
+    // Consequently, we anticipate that only the periodic topology check will detect this change and trigger topology refresh.
+    // If successful, the node to which we route the CLUSTER NODES command should be the newly promoted node with a different node ID.
+    let cluster = TestClusterContext::new_with_cluster_client_builder(6, 1, |builder| {
+        builder.periodic_topology_checks(Duration::from_millis(100))
+    });
+
+    block_on_all(async move {
+        let mut connection = cluster.async_connection().await;
+        let mut prev_master_id = "".to_string();
+        let max_requests = 10000;
+        let mut i = 0;
+        loop {
+            if i == 10 {
+                let mut cmd = redis::cmd("CLUSTER");
+                cmd.arg("FAILOVER");
+                cmd.arg("TAKEOVER");
+                let res = connection
+                    .send_packed_command(
+                        &cmd,
+                        Some(RoutingInfo::SingleNode(
+                            SingleNodeRoutingInfo::SpecificNode(Route::new(
+                                0,
+                                SlotAddr::ReplicaRequired,
+                            )),
+                        )),
+                    )
+                    .await;
+                assert!(res.is_ok());
+            } else if i == max_requests {
+                break;
+            } else {
+                let mut cmd = redis::cmd("CLUSTER");
+                cmd.arg("NODES");
+                let res = connection
+                    .send_packed_command(
+                        &cmd,
+                        Some(RoutingInfo::SingleNode(
+                            SingleNodeRoutingInfo::SpecificNode(Route::new(0, SlotAddr::Master)),
+                        )),
+                    )
+                    .await
+                    .expect("Failed executing CLUSTER NODES");
+                let node_id = get_queried_node_id_if_master(res);
+                if let Some(current_master_id) = node_id {
+                    if prev_master_id.is_empty() {
+                        prev_master_id = current_master_id;
+                    } else if prev_master_id != current_master_id {
+                        return Ok::<_, RedisError>(());
+                    }
+                }
+            }
+            i += 1;
+            let _ = sleep(futures_time::time::Duration::from_millis(10)).await;
+        }
+        panic!("Topology change wasn't found!");
+    })
+    .unwrap();
+}
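Editor's aside, not part of the patch: a short sketch of how a caller might opt in to the new periodic checks. It assumes the usual redis-rs entry points (`ClusterClientBuilder::new`, `build`, `get_async_connection`); the seed node addresses and the 10-second interval are placeholder values.

```rust
use std::time::Duration;

use redis::cluster::ClusterClientBuilder;
use redis::RedisResult;

async fn connect_with_periodic_checks() -> RedisResult<()> {
    // Placeholder seed nodes; any reachable cluster nodes work here.
    let nodes = vec!["redis://127.0.0.1:6379/", "redis://127.0.0.1:6380/"];

    // The option added by this patch: at each interval a small sample of nodes
    // (roughly log2 of the cluster size) is queried for its topology, and a
    // differing topology hash triggers a full slot refresh.
    let client = ClusterClientBuilder::new(nodes)
        .periodic_topology_checks(Duration::from_secs(10))
        .build()?;

    let mut connection = client.get_async_connection().await?;
    let pong: String = redis::cmd("PING").query_async(&mut connection).await?;
    assert_eq!(pong, "PONG");
    Ok(())
}
```

When the interval is not configured, behaviour is unchanged: slots are only refreshed on redirection errors such as MOVED, which is exactly the gap the new test exercises after a `CLUSTER FAILOVER TAKEOVER`.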