diff --git a/redis/src/cluster.rs b/redis/src/cluster.rs index 9fa84a4f99..7d3f9dbc69 100644 --- a/redis/src/cluster.rs +++ b/redis/src/cluster.rs @@ -330,7 +330,7 @@ where route: &Route, ) -> RedisResult<(String, &'a mut C)> { let slots = self.slots.borrow(); - if let Some(addr) = slots.slot_addr_for_route(route, self.read_from_replicas) { + if let Some(addr) = slots.slot_addr_for_route(route) { Ok(( addr.to_string(), self.get_connection_by_addr(connections, addr)?, @@ -362,12 +362,12 @@ where let addr_for_slot = |route: Route| -> RedisResult { let slot_addr = slots - .slot_addr_for_route(&route, self.read_from_replicas) + .slot_addr_for_route(&route) .ok_or((ErrorKind::ClusterDown, "Missing slot coverage"))?; Ok(slot_addr.to_string()) }; - match RoutingInfo::for_routable(cmd) { + match RoutingInfo::for_routable(cmd, self.read_from_replicas) { Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) => { let mut rng = thread_rng(); Ok(addr_for_slot(Route::new( @@ -415,7 +415,7 @@ where let mut results = HashMap::new(); // TODO: reconnect and shit - let addresses = slots.addresses_for_multi_routing(&routing, self.read_from_replicas); + let addresses = slots.addresses_for_multi_routing(&routing); for addr in addresses { let addr = addr.to_string(); if let Some(connection) = connections.get_mut(&addr) { @@ -433,7 +433,7 @@ where T: MergeResults + std::fmt::Debug, F: FnMut(&mut C) -> RedisResult, { - let route = match RoutingInfo::for_routable(cmd) { + let route = match RoutingInfo::for_routable(cmd, self.read_from_replicas) { Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) => None, Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(route))) => { Some(route) diff --git a/redis/src/cluster_async/mod.rs b/redis/src/cluster_async/mod.rs index 5ad0768bcd..e3c6fc7791 100644 --- a/redis/src/cluster_async/mod.rs +++ b/redis/src/cluster_async/mod.rs @@ -78,7 +78,10 @@ use tokio::sync::{mpsc, oneshot, RwLock}; /// underlying connections maintained for each node in the cluster, as well /// as common parameters for connecting to nodes and executing commands. #[derive(Clone)] -pub struct ClusterConnection(mpsc::Sender>); +pub struct ClusterConnection { + sender: mpsc::Sender>, + params: ClusterParams, +} impl ClusterConnection where @@ -88,6 +91,7 @@ where initial_nodes: &[ConnectionInfo], cluster_params: ClusterParams, ) -> RedisResult> { + let cloned_params = cluster_params.clone(); ClusterConnInner::new(initial_nodes, cluster_params) .await .map(|inner| { @@ -103,7 +107,10 @@ where #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] AsyncStd::spawn(stream); - ClusterConnection(tx) + ClusterConnection { + sender: tx, + params: cloned_params, + } }) } @@ -113,14 +120,15 @@ where cmd: &Cmd, routing: Option, ) -> RedisResult { + let allow_replicas = self.params.read_from_replicas; trace!("send_packed_command"); let (sender, receiver) = oneshot::channel(); - self.0 + self.sender .send(Message { cmd: CmdArg::Cmd { cmd: Arc::new(cmd.clone()), // TODO Remove this clone? routing: CommandRouting::Route( - routing.or_else(|| RoutingInfo::for_routable(cmd)), + routing.or_else(|| RoutingInfo::for_routable(cmd, allow_replicas)), ), response_policy: RoutingInfo::response_policy(cmd), }, @@ -155,14 +163,15 @@ where count: usize, route: Option, ) -> RedisResult> { + let allow_replicas = self.params.read_from_replicas; let (sender, receiver) = oneshot::channel(); - self.0 + self.sender .send(Message { cmd: CmdArg::Pipeline { pipeline: Arc::new(pipeline.clone()), // TODO Remove this clone? offset, count, - route: route.or_else(|| route_pipeline(pipeline)), + route: route.or_else(|| route_pipeline(pipeline, allow_replicas)), }, sender, }) @@ -226,9 +235,9 @@ enum CmdArg { }, } -fn route_pipeline(pipeline: &crate::Pipeline) -> Option { - fn route_for_command(cmd: &Cmd) -> Option { - match RoutingInfo::for_routable(cmd) { +fn route_pipeline(pipeline: &crate::Pipeline, allow_replica: bool) -> Option { + let route_for_command = |cmd| -> Option { + match RoutingInfo::for_routable(cmd, allow_replica) { Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) => None, Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(route))) => { Some(route) @@ -236,7 +245,7 @@ fn route_pipeline(pipeline: &crate::Pipeline) -> Option { Some(RoutingInfo::MultiNode(_)) => None, None => None, } - } + }; // Find first specific slot and send to it. There's no need to check If later commands // should be routed to a different slot, since the server will return an error indicating this. @@ -715,7 +724,7 @@ where let read_guard = core.conn_lock.read().await; let (receivers, requests): (Vec<_>, Vec<_>) = read_guard .1 - .addresses_for_multi_routing(routing, core.cluster_params.read_from_replicas) + .addresses_for_multi_routing(routing) .into_iter() .enumerate() .filter_map(|(index, addr)| { @@ -874,11 +883,7 @@ where Some(Redirect::Ask(ask_addr)) => Some(ask_addr), None => route .as_ref() - .and_then(|route| { - read_guard - .1 - .slot_addr_for_route(route, core.cluster_params.read_from_replicas) - }) + .and_then(|route| read_guard.1.slot_addr_for_route(route)) .map(|addr| addr.to_string()), } .map(|addr| { @@ -1319,9 +1324,14 @@ mod pipeline_routing_tests { .add_command(cmd("EVAL")); // route randomly assert_eq!( - route_pipeline(&pipeline), + route_pipeline(&pipeline, true), Some(Route::new(12182, SlotAddr::Replica)) ); + + assert_eq!( + route_pipeline(&pipeline, false), + Some(Route::new(12182, SlotAddr::Master)) + ); } #[test] @@ -1334,7 +1344,7 @@ mod pipeline_routing_tests { .get("foo"); // route to slot 12182 assert_eq!( - route_pipeline(&pipeline), + route_pipeline(&pipeline, false), Some(Route::new(4813, SlotAddr::Master)) ); } diff --git a/redis/src/cluster_routing.rs b/redis/src/cluster_routing.rs index ac31c2c987..f18c8e248b 100644 --- a/redis/src/cluster_routing.rs +++ b/redis/src/cluster_routing.rs @@ -272,7 +272,7 @@ impl RoutingInfo { } /// Returns the routing info for `r`. - pub fn for_routable(r: &R) -> Option + pub fn for_routable(r: &R, allow_replica: bool) -> Option where R: Routable + ?Sized, { @@ -318,7 +318,8 @@ impl RoutingInfo { if key_count == 0 { Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) } else { - r.arg_idx(3).map(|key| RoutingInfo::for_key(cmd, key)) + r.arg_idx(3) + .map(|key| RoutingInfo::for_key(cmd, key, allow_replica)) } } b"XGROUP CREATE" @@ -328,22 +329,24 @@ impl RoutingInfo { | b"XGROUP SETID" | b"XINFO CONSUMERS" | b"XINFO GROUPS" - | b"XINFO STREAM" => r.arg_idx(2).map(|key| RoutingInfo::for_key(cmd, key)), + | b"XINFO STREAM" => r + .arg_idx(2) + .map(|key| RoutingInfo::for_key(cmd, key, allow_replica)), b"XREAD" | b"XREADGROUP" => { let streams_position = r.position(b"STREAMS")?; r.arg_idx(streams_position + 1) - .map(|key| RoutingInfo::for_key(cmd, key)) + .map(|key| RoutingInfo::for_key(cmd, key, allow_replica)) } _ => match r.arg_idx(1) { - Some(key) => Some(RoutingInfo::for_key(cmd, key)), + Some(key) => Some(RoutingInfo::for_key(cmd, key, allow_replica)), None => Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)), }, } } - fn for_key(cmd: &[u8], key: &[u8]) -> RoutingInfo { + fn for_key(cmd: &[u8], key: &[u8], allow_replica: bool) -> RoutingInfo { RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(get_route( - is_readonly_cmd(cmd), + allow_replica && is_readonly_cmd(cmd), key, ))) } @@ -467,8 +470,8 @@ impl SlotAddrs { Self([master_node, replica]) } - pub(crate) fn slot_addr(&self, slot_addr: SlotAddr, allow_replica: bool) -> &str { - if allow_replica && slot_addr == SlotAddr::Replica { + pub(crate) fn slot_addr(&self, slot_addr: SlotAddr) -> &str { + if slot_addr == SlotAddr::Replica { self.0[1].as_str() } else { self.0[0].as_str() @@ -545,16 +548,16 @@ mod tests { lower.arg("streams").arg("foo").arg(0); assert_eq!( - RoutingInfo::for_routable(&upper).unwrap(), - RoutingInfo::for_routable(&lower).unwrap() + RoutingInfo::for_routable(&upper, false).unwrap(), + RoutingInfo::for_routable(&lower, false).unwrap() ); let mut mixed = cmd("xReAd"); mixed.arg("StReAmS").arg("foo").arg(0); assert_eq!( - RoutingInfo::for_routable(&lower).unwrap(), - RoutingInfo::for_routable(&mixed).unwrap() + RoutingInfo::for_routable(&lower, false).unwrap(), + RoutingInfo::for_routable(&mixed, false).unwrap() ); } @@ -605,8 +608,8 @@ mod tests { for cmd in test_cmds { let value = parse_redis_value(&cmd.get_packed_command()).unwrap(); assert_eq!( - RoutingInfo::for_routable(&value).unwrap(), - RoutingInfo::for_routable(&cmd).unwrap(), + RoutingInfo::for_routable(&value, false).unwrap(), + RoutingInfo::for_routable(&cmd, false).unwrap(), ); } @@ -622,7 +625,7 @@ mod tests { cmd("SCRIPT KILL"), ] { assert_eq!( - RoutingInfo::for_routable(&cmd), + RoutingInfo::for_routable(&cmd, false), Some(RoutingInfo::MultiNode(MultipleNodeRoutingInfo::AllMasters)) ); } @@ -637,7 +640,7 @@ mod tests { cmd("BITOP"), ] { assert_eq!( - RoutingInfo::for_routable(&cmd), + RoutingInfo::for_routable(&cmd, false), None, "{}", std::str::from_utf8(cmd.arg_idx(0).unwrap()).unwrap() @@ -649,7 +652,7 @@ mod tests { cmd("EVALSHA").arg(r#"redis.call("PING");"#).arg(0), ] { assert_eq!( - RoutingInfo::for_routable(cmd), + RoutingInfo::for_routable(cmd, false), Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) ); } @@ -719,7 +722,42 @@ mod tests { ), ] { assert_eq!( - RoutingInfo::for_routable(cmd), + RoutingInfo::for_routable(cmd, true), + expected, + "{}", + std::str::from_utf8(cmd.arg_idx(0).unwrap()).unwrap() + ); + } + } + + #[test] + fn test_routing_info_without_allowing_replicas() { + for (cmd, expected) in vec![ + ( + cmd("XINFO").arg("GROUPS").arg("foo"), + Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route::new(slot(b"foo"), SlotAddr::Master)), + )), + ), + ( + cmd("XREAD") + .arg("COUNT") + .arg("2") + .arg("STREAMS") + .arg("mystream") + .arg("writers") + .arg("0-0") + .arg("0-0"), + Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route::new( + slot(b"mystream"), + SlotAddr::Master, + )), + )), + ), + ] { + assert_eq!( + RoutingInfo::for_routable(cmd, false), expected, "{}", std::str::from_utf8(cmd.arg_idx(0).unwrap()).unwrap() @@ -732,28 +770,28 @@ mod tests { assert!(matches!(RoutingInfo::for_routable(&parse_redis_value(&[ 42, 50, 13, 10, 36, 54, 13, 10, 69, 88, 73, 83, 84, 83, 13, 10, 36, 49, 54, 13, 10, 244, 93, 23, 40, 126, 127, 253, 33, 89, 47, 185, 204, 171, 249, 96, 139, 13, 10 - ]).unwrap()), Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route(slot, SlotAddr::Replica)))) if slot == 964)); + ]).unwrap(), true), Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route(slot, SlotAddr::Replica)))) if slot == 964)); assert!(matches!(RoutingInfo::for_routable(&parse_redis_value(&[ 42, 54, 13, 10, 36, 51, 13, 10, 83, 69, 84, 13, 10, 36, 49, 54, 13, 10, 36, 241, 197, 111, 180, 254, 5, 175, 143, 146, 171, 39, 172, 23, 164, 145, 13, 10, 36, 52, 13, 10, 116, 114, 117, 101, 13, 10, 36, 50, 13, 10, 78, 88, 13, 10, 36, 50, 13, 10, 80, 88, 13, 10, 36, 55, 13, 10, 49, 56, 48, 48, 48, 48, 48, 13, 10 - ]).unwrap()), Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route(slot, SlotAddr::Master)))) if slot == 8352)); + ]).unwrap(), true), Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route(slot, SlotAddr::Master)))) if slot == 8352)); assert!(matches!(RoutingInfo::for_routable(&parse_redis_value(&[ 42, 54, 13, 10, 36, 51, 13, 10, 83, 69, 84, 13, 10, 36, 49, 54, 13, 10, 169, 233, 247, 59, 50, 247, 100, 232, 123, 140, 2, 101, 125, 221, 66, 170, 13, 10, 36, 52, 13, 10, 116, 114, 117, 101, 13, 10, 36, 50, 13, 10, 78, 88, 13, 10, 36, 50, 13, 10, 80, 88, 13, 10, 36, 55, 13, 10, 49, 56, 48, 48, 48, 48, 48, 13, 10 - ]).unwrap()), Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route(slot, SlotAddr::Master)))) if slot == 5210)); + ]).unwrap(), true), Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route(slot, SlotAddr::Master)))) if slot == 5210)); } #[test] fn test_multi_shard() { let mut cmd = cmd("DEL"); cmd.arg("foo").arg("bar").arg("baz").arg("{bar}vaz"); - let routing = RoutingInfo::for_routable(&cmd); + let routing = RoutingInfo::for_routable(&cmd, true); let mut expected = std::collections::HashMap::new(); expected.insert(Route(4813, SlotAddr::Master), vec![3]); expected.insert(Route(5061, SlotAddr::Master), vec![2, 4]); @@ -769,7 +807,7 @@ mod tests { let mut cmd = crate::cmd("MGET"); cmd.arg("foo").arg("bar").arg("baz").arg("{bar}vaz"); - let routing = RoutingInfo::for_routable(&cmd); + let routing = RoutingInfo::for_routable(&cmd, true); let mut expected = std::collections::HashMap::new(); expected.insert(Route(4813, SlotAddr::Replica), vec![3]); expected.insert(Route(5061, SlotAddr::Replica), vec![2, 4]); @@ -788,7 +826,7 @@ mod tests { fn test_combine_multi_shard_to_single_node_when_all_keys_are_in_same_slot() { let mut cmd = cmd("DEL"); cmd.arg("foo").arg("{foo}bar").arg("{foo}baz"); - let routing = RoutingInfo::for_routable(&cmd); + let routing = RoutingInfo::for_routable(&cmd, true); assert!( matches!( @@ -821,53 +859,47 @@ mod tests { assert_eq!( "node1:6379", slot_map - .slot_addr_for_route(&Route::new(1, SlotAddr::Master), false) + .slot_addr_for_route(&Route::new(1, SlotAddr::Master)) .unwrap() ); assert_eq!( "node1:6379", slot_map - .slot_addr_for_route(&Route::new(500, SlotAddr::Master), false) + .slot_addr_for_route(&Route::new(500, SlotAddr::Master)) .unwrap() ); assert_eq!( "node1:6379", slot_map - .slot_addr_for_route(&Route::new(1000, SlotAddr::Master), false) + .slot_addr_for_route(&Route::new(1000, SlotAddr::Master)) .unwrap() ); assert_eq!( "replica1:6379", slot_map - .slot_addr_for_route(&Route::new(1000, SlotAddr::Replica), true) - .unwrap() - ); - assert_eq!( - "node1:6379", - slot_map - .slot_addr_for_route(&Route::new(1000, SlotAddr::Replica), false) + .slot_addr_for_route(&Route::new(1000, SlotAddr::Replica)) .unwrap() ); assert_eq!( "node2:6379", slot_map - .slot_addr_for_route(&Route::new(1001, SlotAddr::Master), false) + .slot_addr_for_route(&Route::new(1001, SlotAddr::Master)) .unwrap() ); assert_eq!( "node2:6379", slot_map - .slot_addr_for_route(&Route::new(1500, SlotAddr::Master), false) + .slot_addr_for_route(&Route::new(1500, SlotAddr::Master)) .unwrap() ); assert_eq!( "node2:6379", slot_map - .slot_addr_for_route(&Route::new(2000, SlotAddr::Master), false) + .slot_addr_for_route(&Route::new(2000, SlotAddr::Master)) .unwrap() ); assert!(slot_map - .slot_addr_for_route(&Route::new(2001, SlotAddr::Master), false) + .slot_addr_for_route(&Route::new(2001, SlotAddr::Master)) .is_none()); } } diff --git a/redis/src/cluster_topology.rs b/redis/src/cluster_topology.rs index fa4947717e..15b6e2c6a1 100644 --- a/redis/src/cluster_topology.rs +++ b/redis/src/cluster_topology.rs @@ -50,11 +50,11 @@ impl SlotMap { } } - pub fn slot_addr_for_route(&self, route: &Route, allow_replica: bool) -> Option<&str> { + pub fn slot_addr_for_route(&self, route: &Route) -> Option<&str> { self.0 .range(route.slot()..) .next() - .map(|(_, slot_addrs)| slot_addrs.slot_addr(route.slot_addr(), allow_replica)) + .map(|(_, slot_addrs)| slot_addrs.slot_addr(route.slot_addr())) } pub fn clear(&mut self) { @@ -68,20 +68,16 @@ impl SlotMap { fn all_unique_addresses(&self, only_primaries: bool) -> HashSet<&str> { let mut addresses = HashSet::new(); for slot in self.values() { - addresses.insert(slot.slot_addr(SlotAddr::Master, false)); + addresses.insert(slot.slot_addr(SlotAddr::Master)); if !only_primaries { - addresses.insert(slot.slot_addr(SlotAddr::Replica, true)); + addresses.insert(slot.slot_addr(SlotAddr::Replica)); } } addresses } - pub fn addresses_for_multi_routing( - &self, - routing: &MultipleNodeRoutingInfo, - allow_replica: bool, - ) -> Vec<&str> { + pub fn addresses_for_multi_routing(&self, routing: &MultipleNodeRoutingInfo) -> Vec<&str> { match routing { MultipleNodeRoutingInfo::AllNodes => { self.all_unique_addresses(false).into_iter().collect() @@ -91,7 +87,7 @@ impl SlotMap { } MultipleNodeRoutingInfo::MultiSlot(routes) => routes .iter() - .flat_map(|(route, _)| self.slot_addr_for_route(route, allow_replica)) + .flat_map(|(route, _)| self.slot_addr_for_route(route)) .collect(), } }