Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass read_from_replica flag when finding routing. #18

Merged
merged 1 commit into from
Aug 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions redis/src/cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ where
route: &Route,
) -> RedisResult<(String, &'a mut C)> {
let slots = self.slots.borrow();
if let Some(addr) = slots.slot_addr_for_route(route, self.read_from_replicas) {
if let Some(addr) = slots.slot_addr_for_route(route) {
Ok((
addr.to_string(),
self.get_connection_by_addr(connections, addr)?,
Expand Down Expand Up @@ -362,12 +362,12 @@ where

let addr_for_slot = |route: Route| -> RedisResult<String> {
let slot_addr = slots
.slot_addr_for_route(&route, self.read_from_replicas)
.slot_addr_for_route(&route)
.ok_or((ErrorKind::ClusterDown, "Missing slot coverage"))?;
Ok(slot_addr.to_string())
};

match RoutingInfo::for_routable(cmd) {
match RoutingInfo::for_routable(cmd, self.read_from_replicas) {
Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) => {
let mut rng = thread_rng();
Ok(addr_for_slot(Route::new(
Expand Down Expand Up @@ -415,7 +415,7 @@ where
let mut results = HashMap::new();

// TODO: reconnect and shit
let addresses = slots.addresses_for_multi_routing(&routing, self.read_from_replicas);
let addresses = slots.addresses_for_multi_routing(&routing);
for addr in addresses {
let addr = addr.to_string();
if let Some(connection) = connections.get_mut(&addr) {
Expand All @@ -433,7 +433,7 @@ where
T: MergeResults + std::fmt::Debug,
F: FnMut(&mut C) -> RedisResult<T>,
{
let route = match RoutingInfo::for_routable(cmd) {
let route = match RoutingInfo::for_routable(cmd, self.read_from_replicas) {
Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) => None,
Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(route))) => {
Some(route)
Expand Down
46 changes: 28 additions & 18 deletions redis/src/cluster_async/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,10 @@ use tokio::sync::{mpsc, oneshot, RwLock};
/// underlying connections maintained for each node in the cluster, as well
/// as common parameters for connecting to nodes and executing commands.
#[derive(Clone)]
pub struct ClusterConnection<C = MultiplexedConnection>(mpsc::Sender<Message<C>>);
pub struct ClusterConnection<C = MultiplexedConnection> {
sender: mpsc::Sender<Message<C>>,
read_from_replicas: bool,
}

impl<C> ClusterConnection<C>
where
Expand All @@ -88,6 +91,7 @@ where
initial_nodes: &[ConnectionInfo],
cluster_params: ClusterParams,
) -> RedisResult<ClusterConnection<C>> {
let read_from_replicas = cluster_params.read_from_replicas;
ClusterConnInner::new(initial_nodes, cluster_params)
.await
.map(|inner| {
Expand All @@ -103,7 +107,10 @@ where
#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))]
AsyncStd::spawn(stream);

ClusterConnection(tx)
ClusterConnection {
sender: tx,
read_from_replicas,
}
})
}

Expand All @@ -113,14 +120,15 @@ where
cmd: &Cmd,
routing: Option<RoutingInfo>,
) -> RedisResult<Value> {
let allow_replicas = self.read_from_replicas;
trace!("send_packed_command");
let (sender, receiver) = oneshot::channel();
self.0
self.sender
.send(Message {
cmd: CmdArg::Cmd {
cmd: Arc::new(cmd.clone()), // TODO Remove this clone?
routing: CommandRouting::Route(
routing.or_else(|| RoutingInfo::for_routable(cmd)),
routing.or_else(|| RoutingInfo::for_routable(cmd, allow_replicas)),
),
response_policy: RoutingInfo::response_policy(cmd),
},
Expand Down Expand Up @@ -155,14 +163,15 @@ where
count: usize,
route: Option<Route>,
) -> RedisResult<Vec<Value>> {
let allow_replicas = self.read_from_replicas;
let (sender, receiver) = oneshot::channel();
self.0
self.sender
.send(Message {
cmd: CmdArg::Pipeline {
pipeline: Arc::new(pipeline.clone()), // TODO Remove this clone?
offset,
count,
route: route.or_else(|| route_pipeline(pipeline)),
route: route.or_else(|| route_pipeline(pipeline, allow_replicas)),
},
sender,
})
Expand Down Expand Up @@ -226,17 +235,17 @@ enum CmdArg<C> {
},
}

fn route_pipeline(pipeline: &crate::Pipeline) -> Option<Route> {
fn route_for_command(cmd: &Cmd) -> Option<Route> {
match RoutingInfo::for_routable(cmd) {
fn route_pipeline(pipeline: &crate::Pipeline, allow_replica: bool) -> Option<Route> {
let route_for_command = |cmd| -> Option<Route> {
match RoutingInfo::for_routable(cmd, allow_replica) {
Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) => None,
Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(route))) => {
Some(route)
}
Some(RoutingInfo::MultiNode(_)) => None,
None => None,
}
}
};

// Find first specific slot and send to it. There's no need to check If later commands
// should be routed to a different slot, since the server will return an error indicating this.
Expand Down Expand Up @@ -715,7 +724,7 @@ where
let read_guard = core.conn_lock.read().await;
let (receivers, requests): (Vec<_>, Vec<_>) = read_guard
.1
.addresses_for_multi_routing(routing, core.cluster_params.read_from_replicas)
.addresses_for_multi_routing(routing)
.into_iter()
.enumerate()
.filter_map(|(index, addr)| {
Expand Down Expand Up @@ -874,11 +883,7 @@ where
Some(Redirect::Ask(ask_addr)) => Some(ask_addr),
None => route
.as_ref()
.and_then(|route| {
read_guard
.1
.slot_addr_for_route(route, core.cluster_params.read_from_replicas)
})
.and_then(|route| read_guard.1.slot_addr_for_route(route))
.map(|addr| addr.to_string()),
}
.map(|addr| {
Expand Down Expand Up @@ -1319,9 +1324,14 @@ mod pipeline_routing_tests {
.add_command(cmd("EVAL")); // route randomly

assert_eq!(
route_pipeline(&pipeline),
route_pipeline(&pipeline, true),
Some(Route::new(12182, SlotAddr::Replica))
);

assert_eq!(
route_pipeline(&pipeline, false),
Some(Route::new(12182, SlotAddr::Master))
);
}

#[test]
Expand All @@ -1334,7 +1344,7 @@ mod pipeline_routing_tests {
.get("foo"); // route to slot 12182

assert_eq!(
route_pipeline(&pipeline),
route_pipeline(&pipeline, false),
Some(Route::new(4813, SlotAddr::Master))
);
}
Expand Down
Loading