Skip to content

Commit

Permalink
Merge pull request #18 from nihohit/read-replica-fix
Browse files Browse the repository at this point in the history
Pass read_from_replica flag when finding routing.
  • Loading branch information
nihohit authored Aug 7, 2023
2 parents 63adbe0 + f364939 commit 1a80090
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 72 deletions.
10 changes: 5 additions & 5 deletions redis/src/cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ where
route: &Route,
) -> RedisResult<(String, &'a mut C)> {
let slots = self.slots.borrow();
if let Some(addr) = slots.slot_addr_for_route(route, self.read_from_replicas) {
if let Some(addr) = slots.slot_addr_for_route(route) {
Ok((
addr.to_string(),
self.get_connection_by_addr(connections, addr)?,
Expand Down Expand Up @@ -362,12 +362,12 @@ where

let addr_for_slot = |route: Route| -> RedisResult<String> {
let slot_addr = slots
.slot_addr_for_route(&route, self.read_from_replicas)
.slot_addr_for_route(&route)
.ok_or((ErrorKind::ClusterDown, "Missing slot coverage"))?;
Ok(slot_addr.to_string())
};

match RoutingInfo::for_routable(cmd) {
match RoutingInfo::for_routable(cmd, self.read_from_replicas) {
Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) => {
let mut rng = thread_rng();
Ok(addr_for_slot(Route::new(
Expand Down Expand Up @@ -415,7 +415,7 @@ where
let mut results = HashMap::new();

// TODO: reconnect and shit
let addresses = slots.addresses_for_multi_routing(&routing, self.read_from_replicas);
let addresses = slots.addresses_for_multi_routing(&routing);
for addr in addresses {
let addr = addr.to_string();
if let Some(connection) = connections.get_mut(&addr) {
Expand All @@ -433,7 +433,7 @@ where
T: MergeResults + std::fmt::Debug,
F: FnMut(&mut C) -> RedisResult<T>,
{
let route = match RoutingInfo::for_routable(cmd) {
let route = match RoutingInfo::for_routable(cmd, self.read_from_replicas) {
Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) => None,
Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(route))) => {
Some(route)
Expand Down
46 changes: 28 additions & 18 deletions redis/src/cluster_async/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,10 @@ use tokio::sync::{mpsc, oneshot, RwLock};
/// underlying connections maintained for each node in the cluster, as well
/// as common parameters for connecting to nodes and executing commands.
#[derive(Clone)]
pub struct ClusterConnection<C = MultiplexedConnection>(mpsc::Sender<Message<C>>);
pub struct ClusterConnection<C = MultiplexedConnection> {
sender: mpsc::Sender<Message<C>>,
read_from_replicas: bool,
}

impl<C> ClusterConnection<C>
where
Expand All @@ -88,6 +91,7 @@ where
initial_nodes: &[ConnectionInfo],
cluster_params: ClusterParams,
) -> RedisResult<ClusterConnection<C>> {
let read_from_replicas = cluster_params.read_from_replicas;
ClusterConnInner::new(initial_nodes, cluster_params)
.await
.map(|inner| {
Expand All @@ -103,7 +107,10 @@ where
#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))]
AsyncStd::spawn(stream);

ClusterConnection(tx)
ClusterConnection {
sender: tx,
read_from_replicas,
}
})
}

Expand All @@ -113,14 +120,15 @@ where
cmd: &Cmd,
routing: Option<RoutingInfo>,
) -> RedisResult<Value> {
let allow_replicas = self.read_from_replicas;
trace!("send_packed_command");
let (sender, receiver) = oneshot::channel();
self.0
self.sender
.send(Message {
cmd: CmdArg::Cmd {
cmd: Arc::new(cmd.clone()), // TODO Remove this clone?
routing: CommandRouting::Route(
routing.or_else(|| RoutingInfo::for_routable(cmd)),
routing.or_else(|| RoutingInfo::for_routable(cmd, allow_replicas)),
),
response_policy: RoutingInfo::response_policy(cmd),
},
Expand Down Expand Up @@ -155,14 +163,15 @@ where
count: usize,
route: Option<Route>,
) -> RedisResult<Vec<Value>> {
let allow_replicas = self.read_from_replicas;
let (sender, receiver) = oneshot::channel();
self.0
self.sender
.send(Message {
cmd: CmdArg::Pipeline {
pipeline: Arc::new(pipeline.clone()), // TODO Remove this clone?
offset,
count,
route: route.or_else(|| route_pipeline(pipeline)),
route: route.or_else(|| route_pipeline(pipeline, allow_replicas)),
},
sender,
})
Expand Down Expand Up @@ -226,17 +235,17 @@ enum CmdArg<C> {
},
}

fn route_pipeline(pipeline: &crate::Pipeline) -> Option<Route> {
fn route_for_command(cmd: &Cmd) -> Option<Route> {
match RoutingInfo::for_routable(cmd) {
fn route_pipeline(pipeline: &crate::Pipeline, allow_replica: bool) -> Option<Route> {
let route_for_command = |cmd| -> Option<Route> {
match RoutingInfo::for_routable(cmd, allow_replica) {
Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) => None,
Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(route))) => {
Some(route)
}
Some(RoutingInfo::MultiNode(_)) => None,
None => None,
}
}
};

// Find first specific slot and send to it. There's no need to check If later commands
// should be routed to a different slot, since the server will return an error indicating this.
Expand Down Expand Up @@ -715,7 +724,7 @@ where
let read_guard = core.conn_lock.read().await;
let (receivers, requests): (Vec<_>, Vec<_>) = read_guard
.1
.addresses_for_multi_routing(routing, core.cluster_params.read_from_replicas)
.addresses_for_multi_routing(routing)
.into_iter()
.enumerate()
.filter_map(|(index, addr)| {
Expand Down Expand Up @@ -874,11 +883,7 @@ where
Some(Redirect::Ask(ask_addr)) => Some(ask_addr),
None => route
.as_ref()
.and_then(|route| {
read_guard
.1
.slot_addr_for_route(route, core.cluster_params.read_from_replicas)
})
.and_then(|route| read_guard.1.slot_addr_for_route(route))
.map(|addr| addr.to_string()),
}
.map(|addr| {
Expand Down Expand Up @@ -1319,9 +1324,14 @@ mod pipeline_routing_tests {
.add_command(cmd("EVAL")); // route randomly

assert_eq!(
route_pipeline(&pipeline),
route_pipeline(&pipeline, true),
Some(Route::new(12182, SlotAddr::Replica))
);

assert_eq!(
route_pipeline(&pipeline, false),
Some(Route::new(12182, SlotAddr::Master))
);
}

#[test]
Expand All @@ -1334,7 +1344,7 @@ mod pipeline_routing_tests {
.get("foo"); // route to slot 12182

assert_eq!(
route_pipeline(&pipeline),
route_pipeline(&pipeline, false),
Some(Route::new(4813, SlotAddr::Master))
);
}
Expand Down
Loading

0 comments on commit 1a80090

Please sign in to comment.