Skip to content

Commit

Permalink
Changed the cluster initialization to retrieve all IP entries from th…
Browse files Browse the repository at this point in the history
…e initial nodes and use all resolved IPs.
  • Loading branch information
barshaul committed Aug 8, 2023
1 parent 63adbe0 commit 23528ea
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 8 deletions.
3 changes: 2 additions & 1 deletion redis/src/aio/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ where
}
}

async fn get_socket_addrs(
pub(crate) async fn get_socket_addrs(
host: &str,
port: u16,
) -> RedisResult<impl Iterator<Item = SocketAddr> + Send + '_> {
Expand All @@ -383,6 +383,7 @@ async fn get_socket_addrs(
pub(crate) async fn connect_simple<T: RedisRuntime>(
connection_info: &ConnectionInfo,
) -> RedisResult<T> {
println!("connect simple was called with {:?}", connection_info);
Ok(match connection_info.addr {
ConnectionAddr::Tcp(ref host, port) => {
let socket_addrs = get_socket_addrs(host, port).await?;
Expand Down
87 changes: 80 additions & 7 deletions redis/src/cluster_async/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use std::{
};

use crate::{
aio::{ConnectionLike, MultiplexedConnection},
aio::{get_socket_addrs, ConnectionLike, MultiplexedConnection},
cluster::{get_connection_info, slot_cmd},
cluster_client::{ClusterParams, RetryParams},
cluster_routing::{
Expand Down Expand Up @@ -482,11 +482,44 @@ where
initial_nodes: &[ConnectionInfo],
params: &ClusterParams,
) -> RedisResult<ConnectionMap<C>> {
// Go through each of the initial nodes and attempt to retrieve all IP entries from them.
// If there's a DNS endpoint that directs to several IP addresses, add all addresses to the initial nodes list.
let initial_nodes: Vec<String> = stream::iter(initial_nodes)
.fold(
Vec::with_capacity(initial_nodes.len()),
|mut acc, info| async {
let (host, port) = match &info.addr {
crate::ConnectionAddr::Tcp(host, port) => (host, port),
crate::ConnectionAddr::TcpTls {
host,
port,
insecure: _,
} => (host, port),
crate::ConnectionAddr::Unix(_) => {
// We don't support multiple addresses for a Unix address. Store the initial node address and continue
acc.push(info.addr.to_string());
return acc;
}
};
let socket_addrs = match get_socket_addrs(host, *port).await {
Ok(socket_addr) => socket_addr,
Err(_) => {
// Couldn't find socket addresses, store the initial node address and continue
acc.push(info.addr.to_string());
return acc;
}
};
for addr in socket_addrs {
acc.push(addr.to_string());
}
acc
},
)
.await;
let connections = stream::iter(initial_nodes.iter().cloned())
.map(|info| {
.map(|addr| {
let params = params.clone();
async move {
let addr = info.addr.to_string();
let result = connect_and_check(&addr, params).await;
match result {
Ok(conn) => Some((addr, async { conn }.boxed().shared())),
Expand Down Expand Up @@ -675,15 +708,37 @@ where
inner.cluster_params.tls,
num_of_nodes_to_query,
)?;

// Create a new connection vector of the found nodes
let connections: &ConnectionMap<C> = &read_guard.0;
let mut nodes = new_slots.values().flatten().collect::<Vec<_>>();
nodes.sort_unstable();
nodes.dedup();
let nodes_len = nodes.len();
let addresses_and_connections_iter = nodes
.into_iter()
.map(|addr| (addr, connections.get(addr).cloned()));
let addresses_and_connections_iter = nodes.into_iter().map(|addr| async move {
let conn = match connections.get(addr).cloned() {
Some(conn) => Some(conn),
None => {
// If it's a DNS endpoint, it could have been stored in the existing connections vector using the resolved IP address instead of the DNS endpoint's name.
// We shall check if a connection is already exists under the resolved IP name.
let (host, port) = match get_host_and_port_from_addr(addr) {
Some((host, port)) => (host, port),
None => return (addr, None),
};
match get_socket_addrs(host, port).await {
Ok(socket_addresses) => {
socket_addresses.fold(None, |acc, addr| match acc {
ok @ Some(_) => ok,
None => connections.get(&addr.to_string()).cloned(),
})
}
Err(_) => None,
}
}
};
(addr, conn)
});
let addresses_and_connections_iter =
futures::future::join_all(addresses_and_connections_iter).await;
let new_connections: HashMap<String, ConnectionFuture<C>> =
stream::iter(addresses_and_connections_iter)
.fold(
Expand All @@ -700,6 +755,7 @@ where
.await;

drop(read_guard);
// Replace the current slot map and connection vector with the new ones
let mut write_guard = inner.conn_lock.write().await;
write_guard.1 = new_slots;
write_guard.0 = new_connections;
Expand Down Expand Up @@ -1059,6 +1115,7 @@ where
Err(_) => connect_and_check(addr, params.clone()).await,
}
} else {
println!("connection is None");
connect_and_check(addr, params.clone()).await
}
}
Expand Down Expand Up @@ -1301,6 +1358,22 @@ where
(addr, conn)
}

/// Splits a string address into host and port. If the passed address cannot be parsed, None is returned.
/// [addr] should be in the following format: "<host>:<port>".
fn get_host_and_port_from_addr(addr: &str) -> Option<(&str, u16)> {
let parts: Vec<&str> = addr.split(':').collect();
if parts.len() != 2 {
return None;
}
let host = parts.first().unwrap();
let port = parts.get(1).unwrap();
let port = match port.parse::<u16>() {
Ok(port) => port,
Err(_) => return None,
};
Some((host, port))
}

#[cfg(test)]
mod pipeline_routing_tests {
use super::route_pipeline;
Expand Down

0 comments on commit 23528ea

Please sign in to comment.