From 8c9bde85776140a0c17a63cd1eac94e7fb9fcab2 Mon Sep 17 00:00:00 2001 From: Shachar Langbeheim Date: Mon, 14 Aug 2023 13:48:41 +0000 Subject: [PATCH] Use round-robin read from replica in clusters. --- redis/src/cluster.rs | 12 +- redis/src/cluster_async/mod.rs | 9 +- redis/src/cluster_client.rs | 13 +- redis/src/cluster_routing.rs | 12 +- redis/src/cluster_topology.rs | 259 +++++++++++++++++++++--------- redis/tests/test_cluster_async.rs | 126 +++++++++++++++ 6 files changed, 335 insertions(+), 96 deletions(-) diff --git a/redis/src/cluster.rs b/redis/src/cluster.rs index f9b368ed52..b686b46223 100644 --- a/redis/src/cluster.rs +++ b/redis/src/cluster.rs @@ -127,7 +127,6 @@ pub struct ClusterConnection { connections: RefCell>, slots: RefCell, auto_reconnect: RefCell, - read_from_replicas: bool, read_timeout: RefCell>, write_timeout: RefCell>, cluster_params: ClusterParams, @@ -143,9 +142,8 @@ where ) -> RedisResult { let connection = Self { connections: RefCell::new(HashMap::new()), - slots: RefCell::new(SlotMap::new(vec![])), + slots: RefCell::new(SlotMap::new(vec![], cluster_params.read_from_replicas)), auto_reconnect: RefCell::new(true), - read_from_replicas: cluster_params.read_from_replicas, cluster_params, read_timeout: RefCell::new(None), write_timeout: RefCell::new(None), @@ -297,7 +295,9 @@ where ))); for conn in samples.iter_mut() { let value = conn.req_command(&slot_cmd())?; - match parse_slots(&value, self.cluster_params.tls).map(build_slot_map) { + match parse_slots(&value, self.cluster_params.tls).map(|slots_data| { + build_slot_map(slots_data, self.cluster_params.read_from_replicas) + }) { Ok(new_slots) => { result = Ok(new_slots); break; @@ -312,7 +312,9 @@ where let info = get_connection_info(node, self.cluster_params.clone())?; let mut conn = C::connect(info, Some(self.cluster_params.connection_timeout))?; - if self.read_from_replicas { + if self.cluster_params.read_from_replicas + != crate::cluster_topology::ReadFromReplicaStrategy::AlwaysFromPrimary + { // If READONLY is sent to primary nodes, it will have no effect cmd("READONLY").query(&mut conn)?; } diff --git a/redis/src/cluster_async/mod.rs b/redis/src/cluster_async/mod.rs index 1863502607..089f8f1056 100644 --- a/redis/src/cluster_async/mod.rs +++ b/redis/src/cluster_async/mod.rs @@ -464,7 +464,10 @@ where ) -> RedisResult { let connections = Self::create_initial_connections(initial_nodes, &cluster_params).await?; let inner = Arc::new(InnerCore { - conn_lock: RwLock::new((connections, Default::default())), + conn_lock: RwLock::new(( + connections, + SlotMap::new(vec![], cluster_params.read_from_replicas), + )), cluster_params, pending_requests: Mutex::new(Vec::new()), }); @@ -713,6 +716,7 @@ where curr_retry, inner.cluster_params.tls, num_of_nodes_to_query, + inner.cluster_params.read_from_replicas, )?; // Create a new connection vector of the found nodes let connections: &ConnectionMap = &read_guard.0; @@ -1313,7 +1317,8 @@ async fn connect_and_check(node: &str, params: ClusterParams) -> RedisResult< where C: ConnectionLike + Connect + Send + 'static, { - let read_from_replicas = params.read_from_replicas; + let read_from_replicas = params.read_from_replicas + != crate::cluster_topology::ReadFromReplicaStrategy::AlwaysFromPrimary; let connection_timeout = params.connection_timeout.into(); let info = get_connection_info(node, params)?; let mut conn: C = C::connect(info).timeout(connection_timeout).await??; diff --git a/redis/src/cluster_client.rs b/redis/src/cluster_client.rs index caf4e656c5..b0c94096d0 100644 --- a/redis/src/cluster_client.rs +++ b/redis/src/cluster_client.rs @@ -2,6 +2,7 @@ use std::time::Duration; use rand::Rng; +use crate::cluster_topology::ReadFromReplicaStrategy; use crate::connection::{ConnectionAddr, ConnectionInfo, IntoConnectionInfo}; use crate::types::{ErrorKind, RedisError, RedisResult}; use crate::{cluster, cluster::TlsMode}; @@ -16,7 +17,7 @@ use crate::cluster_async; struct BuilderParams { password: Option, username: Option, - read_from_replicas: bool, + read_from_replicas: ReadFromReplicaStrategy, tls: Option, retries_configuration: RetryParams, connection_timeout: Option, @@ -64,7 +65,7 @@ impl RetryParams { pub(crate) struct ClusterParams { pub(crate) password: Option, pub(crate) username: Option, - pub(crate) read_from_replicas: bool, + pub(crate) read_from_replicas: ReadFromReplicaStrategy, /// tls indicates tls behavior of connections. /// When Some(TlsMode), connections use tls and verify certification depends on TlsMode. /// When None, connections do not use tls. @@ -237,7 +238,7 @@ impl ClusterClientBuilder { /// If enabled, then read queries will go to the replica nodes & write queries will go to the /// primary nodes. If there are no replica nodes, then all queries will go to the primary nodes. pub fn read_from_replicas(mut self) -> ClusterClientBuilder { - self.builder_params.read_from_replicas = true; + self.builder_params.read_from_replicas = ReadFromReplicaStrategy::RoundRobin; self } @@ -258,7 +259,11 @@ impl ClusterClientBuilder { /// Use `read_from_replicas()`. #[deprecated(since = "0.22.0", note = "Use read_from_replicas()")] pub fn readonly(mut self, read_from_replicas: bool) -> ClusterClientBuilder { - self.builder_params.read_from_replicas = read_from_replicas; + self.builder_params.read_from_replicas = if read_from_replicas { + ReadFromReplicaStrategy::RoundRobin + } else { + ReadFromReplicaStrategy::AlwaysFromPrimary + }; self } } diff --git a/redis/src/cluster_routing.rs b/redis/src/cluster_routing.rs index 564e81f93c..21e99ab15e 100644 --- a/redis/src/cluster_routing.rs +++ b/redis/src/cluster_routing.rs @@ -456,8 +456,8 @@ pub enum SlotAddr { /// a command is executed #[derive(Debug, Eq, PartialEq)] pub(crate) struct SlotAddrs { - primary: String, - replicas: Vec, + pub(crate) primary: String, + pub(crate) replicas: Vec, } impl SlotAddrs { @@ -465,14 +465,6 @@ impl SlotAddrs { Self { primary, replicas } } - pub(crate) fn slot_addr(&self, slot_addr: SlotAddr) -> &str { - if slot_addr == SlotAddr::Master || self.replicas.is_empty() { - self.primary.as_str() - } else { - self.replicas[0].as_str() - } - } - pub(crate) fn from_slot(slot: Slot) -> Self { SlotAddrs::new(slot.master, slot.replicas) } diff --git a/redis/src/cluster_topology.rs b/redis/src/cluster_topology.rs index 0775dd7704..27344f1166 100644 --- a/redis/src/cluster_topology.rs +++ b/redis/src/cluster_topology.rs @@ -13,6 +13,7 @@ use std::collections::BTreeMap; use std::collections::HashMap; use std::collections::HashSet; use std::hash::{Hash, Hasher}; +use std::sync::atomic::AtomicUsize; use std::time::Duration; /// The default number of refersh topology retries @@ -61,6 +62,7 @@ impl TopologyView { struct SlotMapValue { start: u16, addrs: SlotAddrs, + latest_used_replica: AtomicUsize, } impl SlotMapValue { @@ -68,23 +70,57 @@ impl SlotMapValue { Self { start: slot.start(), addrs: SlotAddrs::from_slot(slot), + latest_used_replica: AtomicUsize::new(0), } } } -#[derive(Debug, Default)] -pub(crate) struct SlotMap(BTreeMap); +#[derive(Debug, Default, Clone, PartialEq, Copy)] +pub(crate) enum ReadFromReplicaStrategy { + #[default] + AlwaysFromPrimary, + RoundRobin, +} + +#[derive(Debug)] +pub(crate) struct SlotMap { + slots: BTreeMap, + read_from_replica: ReadFromReplicaStrategy, +} + +fn get_address_from_slot( + slot: &SlotMapValue, + read_from_replica: ReadFromReplicaStrategy, + slot_addr: SlotAddr, +) -> &str { + if slot_addr == SlotAddr::Master || slot.addrs.replicas.is_empty() { + return slot.addrs.primary.as_str(); + } + match read_from_replica { + ReadFromReplicaStrategy::AlwaysFromPrimary => slot.addrs.primary.as_str(), + ReadFromReplicaStrategy::RoundRobin => { + let index = slot + .latest_used_replica + .fetch_add(1, std::sync::atomic::Ordering::Relaxed) + % slot.addrs.replicas.len(); + slot.addrs.replicas[index].as_str() + } + } +} impl SlotMap { - pub fn new(slots: Vec) -> Self { - let mut this = Self(BTreeMap::new()); + pub(crate) fn new(slots: Vec, read_from_replica: ReadFromReplicaStrategy) -> Self { + let mut this = Self { + slots: BTreeMap::new(), + read_from_replica, + }; this.replace_slots(slots); this } fn replace_slots(&mut self, slots: Vec) { - self.0.clear(); - self.0.extend( + self.slots.clear(); + self.slots.extend( slots .into_iter() .map(|slot| (slot.end(), SlotMapValue::from_slot(slot))), @@ -93,26 +129,32 @@ impl SlotMap { pub fn slot_addr_for_route(&self, route: &Route) -> Option<&str> { let slot = route.slot(); - self.0.range(slot..).next().and_then(|(end, slot_value)| { - if slot <= *end && slot_value.start <= slot { - Some(slot_value.addrs.slot_addr(route.slot_addr())) - } else { - None - } - }) + self.slots + .range(slot..) + .next() + .and_then(|(end, slot_value)| { + if slot <= *end && slot_value.start <= slot { + Some(get_address_from_slot( + slot_value, + self.read_from_replica, + route.slot_addr(), + )) + } else { + None + } + }) } pub fn values(&self) -> impl Iterator { - self.0.values().map(|slot_value| &slot_value.addrs) + self.slots.values().map(|slot_value| &slot_value.addrs) } fn all_unique_addresses(&self, only_primaries: bool) -> HashSet<&str> { let mut addresses = HashSet::new(); for slot in self.values() { - if only_primaries { - addresses.insert(slot.slot_addr(SlotAddr::Master)); - } else { - addresses.extend(slot.into_iter().map(|str| str.as_str())); + addresses.insert(slot.primary.as_str()); + if !only_primaries { + addresses.extend(slot.replicas.iter().map(|str| str.as_str())); } } @@ -236,8 +278,11 @@ pub(crate) fn parse_slots(raw_slot_resp: &Value, tls: Option) -> RedisR Ok(result) } -pub(crate) fn build_slot_map(slots_data: Vec) -> SlotMap { - let slot_map = SlotMap::new(slots_data); +pub(crate) fn build_slot_map( + slots_data: Vec, + read_from_replica: ReadFromReplicaStrategy, +) -> SlotMap { + let slot_map = SlotMap::new(slots_data, read_from_replica); trace!("{:?}", slot_map); slot_map } @@ -253,6 +298,7 @@ pub(crate) fn calculate_topology( curr_retry: usize, tls_mode: Option, num_of_queried_nodes: usize, + read_from_replica: ReadFromReplicaStrategy, ) -> Result { if topology_views.is_empty() { return Err(RedisError::from(( @@ -326,7 +372,7 @@ pub(crate) fn calculate_topology( "Failed to parse the slots on the majority view", )))?; - Ok(build_slot_map(slots_data)) + Ok(build_slot_map(slots_data, read_from_replica)) }; if non_unique_max_node_count { @@ -442,7 +488,14 @@ mod tests { get_view(&ViewType::SingleNodeViewFullCoverage), get_view(&ViewType::TwoNodesViewFullCoverage), ]; - let topology_view = calculate_topology(topology_results, 1, None, queried_nodes).unwrap(); + let topology_view = calculate_topology( + topology_results, + 1, + None, + queried_nodes, + ReadFromReplicaStrategy::AlwaysFromPrimary, + ) + .unwrap(); let res: Vec<_> = topology_view.values().collect(); let node_1 = get_node_addr("node1", 6379); let expected: Vec<&SlotAddrs> = vec![&node_1]; @@ -458,7 +511,13 @@ mod tests { get_view(&ViewType::TwoNodesViewFullCoverage), get_view(&ViewType::TwoNodesViewMissingSlots), ]; - let topology_view = calculate_topology(topology_results, 1, None, queried_nodes); + let topology_view = calculate_topology( + topology_results, + 1, + None, + queried_nodes, + ReadFromReplicaStrategy::AlwaysFromPrimary, + ); assert!(topology_view.is_err()); } @@ -471,7 +530,14 @@ mod tests { get_view(&ViewType::TwoNodesViewFullCoverage), get_view(&ViewType::TwoNodesViewMissingSlots), ]; - let topology_view = calculate_topology(topology_results, 3, None, queried_nodes).unwrap(); + let topology_view = calculate_topology( + topology_results, + 3, + None, + queried_nodes, + ReadFromReplicaStrategy::AlwaysFromPrimary, + ) + .unwrap(); let res: Vec<_> = topology_view.values().collect(); let node_1 = get_node_addr("node1", 6379); let node_2 = get_node_addr("node2", 6380); @@ -487,7 +553,14 @@ mod tests { get_view(&ViewType::TwoNodesViewFullCoverage), get_view(&ViewType::TwoNodesViewMissingSlots), ]; - let topology_view = calculate_topology(topology_results, 1, None, queried_nodes).unwrap(); + let topology_view = calculate_topology( + topology_results, + 1, + None, + queried_nodes, + ReadFromReplicaStrategy::AlwaysFromPrimary, + ) + .unwrap(); let res: Vec<_> = topology_view.values().collect(); let node_1 = get_node_addr("node1", 6379); let node_2 = get_node_addr("node2", 6380); @@ -504,7 +577,14 @@ mod tests { get_view(&ViewType::SingleNodeViewMissingSlots), get_view(&ViewType::TwoNodesViewMissingSlots), ]; - let topology_view = calculate_topology(topology_results, 1, None, queried_nodes).unwrap(); + let topology_view = calculate_topology( + topology_results, + 1, + None, + queried_nodes, + ReadFromReplicaStrategy::AlwaysFromPrimary, + ) + .unwrap(); let res: Vec<_> = topology_view.values().collect(); let node_1 = get_node_addr("node3", 6381); let node_2 = get_node_addr("node4", 6382); @@ -521,7 +601,14 @@ mod tests { get_view(&ViewType::TwoNodesViewMissingSlots), get_view(&ViewType::SingleNodeViewMissingSlots), ]; - let topology_view = calculate_topology(topology_results, 1, None, queried_nodes).unwrap(); + let topology_view = calculate_topology( + topology_results, + 1, + None, + queried_nodes, + ReadFromReplicaStrategy::AlwaysFromPrimary, + ) + .unwrap(); let res: Vec<_> = topology_view.values().collect(); let node_1 = get_node_addr("node1", 6379); let expected: Vec<&SlotAddrs> = vec![&node_1]; @@ -530,20 +617,23 @@ mod tests { #[test] fn test_slot_map_retrieve_routes() { - let slot_map = SlotMap::new(vec![ - Slot::new( - 1, - 1000, - "node1:6379".to_owned(), - vec!["replica1:6379".to_owned()], - ), - Slot::new( - 1002, - 2000, - "node2:6379".to_owned(), - vec!["replica2:6379".to_owned()], - ), - ]); + let slot_map = SlotMap::new( + vec![ + Slot::new( + 1, + 1000, + "node1:6379".to_owned(), + vec!["replica1:6379".to_owned()], + ), + Slot::new( + 1002, + 2000, + "node2:6379".to_owned(), + vec!["replica2:6379".to_owned()], + ), + ], + ReadFromReplicaStrategy::AlwaysFromPrimary, + ); assert!(slot_map .slot_addr_for_route(&Route::new(0, SlotAddr::Master)) @@ -593,42 +683,45 @@ mod tests { .is_none()); } - fn get_slot_map() -> SlotMap { - SlotMap::new(vec![ - Slot::new( - 1, - 1000, - "node1:6379".to_owned(), - vec!["replica1:6379".to_owned()], - ), - Slot::new( - 1002, - 2000, - "node2:6379".to_owned(), - vec!["replica2:6379".to_owned(), "replica3:6379".to_owned()], - ), - Slot::new( - 2001, - 3000, - "node3:6379".to_owned(), - vec![ - "replica4:6379".to_owned(), - "replica5:6379".to_owned(), - "replica6:6379".to_owned(), - ], - ), - Slot::new( - 3001, - 4000, - "node2:6379".to_owned(), - vec!["replica2:6379".to_owned(), "replica3:6379".to_owned()], - ), - ]) + fn get_slot_map(read_from_replica: ReadFromReplicaStrategy) -> SlotMap { + SlotMap::new( + vec![ + Slot::new( + 1, + 1000, + "node1:6379".to_owned(), + vec!["replica1:6379".to_owned()], + ), + Slot::new( + 1002, + 2000, + "node2:6379".to_owned(), + vec!["replica2:6379".to_owned(), "replica3:6379".to_owned()], + ), + Slot::new( + 2001, + 3000, + "node3:6379".to_owned(), + vec![ + "replica4:6379".to_owned(), + "replica5:6379".to_owned(), + "replica6:6379".to_owned(), + ], + ), + Slot::new( + 3001, + 4000, + "node2:6379".to_owned(), + vec!["replica2:6379".to_owned(), "replica3:6379".to_owned()], + ), + ], + read_from_replica, + ) } #[test] fn test_slot_map_get_all_primaries() { - let slot_map = get_slot_map(); + let slot_map = get_slot_map(ReadFromReplicaStrategy::AlwaysFromPrimary); let mut addresses = slot_map.addresses_for_multi_routing(&MultipleNodeRoutingInfo::AllMasters); addresses.sort(); @@ -637,7 +730,7 @@ mod tests { #[test] fn test_slot_map_get_all_nodes() { - let slot_map = get_slot_map(); + let slot_map = get_slot_map(ReadFromReplicaStrategy::AlwaysFromPrimary); let mut addresses = slot_map.addresses_for_multi_routing(&MultipleNodeRoutingInfo::AllNodes); addresses.sort(); @@ -659,7 +752,7 @@ mod tests { #[test] fn test_slot_map_get_multi_node() { - let slot_map = get_slot_map(); + let slot_map = get_slot_map(ReadFromReplicaStrategy::RoundRobin); let mut addresses = slot_map.addresses_for_multi_routing(&MultipleNodeRoutingInfo::MultiSlot(vec![ (Route::new(1, SlotAddr::Master), vec![]), @@ -668,4 +761,20 @@ mod tests { addresses.sort(); assert_eq!(addresses, vec!["node1:6379", "replica4:6379"]); } + + #[test] + fn test_slot_map_rotate_read_replicas() { + let slot_map = get_slot_map(ReadFromReplicaStrategy::RoundRobin); + let route = Route::new(2001, SlotAddr::Replica); + let mut addresses = vec![ + slot_map.slot_addr_for_route(&route).unwrap(), + slot_map.slot_addr_for_route(&route).unwrap(), + slot_map.slot_addr_for_route(&route).unwrap(), + ]; + addresses.sort(); + assert_eq!( + addresses, + vec!["replica4:6379", "replica5:6379", "replica6:6379"] + ); + } } diff --git a/redis/tests/test_cluster_async.rs b/redis/tests/test_cluster_async.rs index 13e6264379..e2c1dfa377 100644 --- a/redis/tests/test_cluster_async.rs +++ b/redis/tests/test_cluster_async.rs @@ -1574,3 +1574,129 @@ fn test_async_cluster_non_retryable_error_should_not_retry() { } assert_eq!(completed.load(Ordering::SeqCst), 1); } + +#[test] +fn test_async_cluster_read_from_primary() { + let name = "node"; + let found_ports = Arc::new(std::sync::Mutex::new(Vec::new())); + let ports_clone = found_ports.clone(); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(0), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config( + name, + received_cmd, + Some(vec![ + MockSlotRange { + primary_port: 6379, + replica_ports: vec![6380, 6381], + slot_range: (0..8191), + }, + MockSlotRange { + primary_port: 6382, + replica_ports: vec![6383, 6384], + slot_range: (8192..16383), + }, + ]), + )?; + ports_clone.lock().unwrap().push(port); + Err(Ok(Value::Nil)) + }, + ); + + runtime.block_on(async { + cmd("GET") + .arg("foo") + .query_async::<_, ()>(&mut connection) + .await + .unwrap(); + cmd("GET") + .arg("bar") + .query_async::<_, ()>(&mut connection) + .await + .unwrap(); + cmd("GET") + .arg("foo") + .query_async::<_, ()>(&mut connection) + .await + .unwrap(); + cmd("GET") + .arg("bar") + .query_async::<_, ()>(&mut connection) + .await + .unwrap(); + }); + + found_ports.lock().unwrap().sort(); + assert_eq!(*found_ports.lock().unwrap(), vec![6379, 6379, 6382, 6382]); +} + +#[test] +fn test_async_cluster_round_robin_read_from_replica() { + let name = "node"; + let found_ports = Arc::new(std::sync::Mutex::new(Vec::new())); + let ports_clone = found_ports.clone(); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config( + name, + received_cmd, + Some(vec![ + MockSlotRange { + primary_port: 6379, + replica_ports: vec![6380, 6381], + slot_range: (0..8191), + }, + MockSlotRange { + primary_port: 6382, + replica_ports: vec![6383, 6384], + slot_range: (8192..16383), + }, + ]), + )?; + ports_clone.lock().unwrap().push(port); + Err(Ok(Value::Nil)) + }, + ); + + runtime.block_on(async { + cmd("GET") + .arg("foo") + .query_async::<_, ()>(&mut connection) + .await + .unwrap(); + cmd("GET") + .arg("bar") + .query_async::<_, ()>(&mut connection) + .await + .unwrap(); + cmd("GET") + .arg("foo") + .query_async::<_, ()>(&mut connection) + .await + .unwrap(); + cmd("GET") + .arg("bar") + .query_async::<_, ()>(&mut connection) + .await + .unwrap(); + }); + + found_ports.lock().unwrap().sort(); + assert_eq!(*found_ports.lock().unwrap(), vec![6380, 6381, 6383, 6384]); +}