diff --git a/redis/src/aio/connection.rs b/redis/src/aio/connection.rs index 2bcd7cb9ce..3415117982 100644 --- a/redis/src/aio/connection.rs +++ b/redis/src/aio/connection.rs @@ -361,7 +361,7 @@ where } } -async fn get_socket_addrs( +pub(crate) async fn get_socket_addrs( host: &str, port: u16, ) -> RedisResult + Send + '_> { @@ -383,6 +383,7 @@ async fn get_socket_addrs( pub(crate) async fn connect_simple( connection_info: &ConnectionInfo, ) -> RedisResult { + println!("connect simple was called with {:?}", connection_info); Ok(match connection_info.addr { ConnectionAddr::Tcp(ref host, port) => { let socket_addrs = get_socket_addrs(host, port).await?; diff --git a/redis/src/cluster_async/mod.rs b/redis/src/cluster_async/mod.rs index 5ad0768bcd..87f878315b 100644 --- a/redis/src/cluster_async/mod.rs +++ b/redis/src/cluster_async/mod.rs @@ -36,7 +36,7 @@ use std::{ }; use crate::{ - aio::{ConnectionLike, MultiplexedConnection}, + aio::{get_socket_addrs, ConnectionLike, MultiplexedConnection}, cluster::{get_connection_info, slot_cmd}, cluster_client::{ClusterParams, RetryParams}, cluster_routing::{ @@ -482,11 +482,44 @@ where initial_nodes: &[ConnectionInfo], params: &ClusterParams, ) -> RedisResult> { + // Go through each of the initial nodes and attempt to retrieve all IP entries from them. + // If there's a DNS endpoint that directs to several IP addresses, add all addresses to the initial nodes list. + let initial_nodes: Vec = stream::iter(initial_nodes) + .fold( + Vec::with_capacity(initial_nodes.len()), + |mut acc, info| async { + let (host, port) = match &info.addr { + crate::ConnectionAddr::Tcp(host, port) => (host, port), + crate::ConnectionAddr::TcpTls { + host, + port, + insecure: _, + } => (host, port), + crate::ConnectionAddr::Unix(_) => { + // We don't support multiple addresses for a Unix address. Store the initial node address and continue + acc.push(info.addr.to_string()); + return acc; + } + }; + let socket_addrs = match get_socket_addrs(host, *port).await { + Ok(socket_addr) => socket_addr, + Err(_) => { + // Couldn't find socket addresses, store the initial node address and continue + acc.push(info.addr.to_string()); + return acc; + } + }; + for addr in socket_addrs { + acc.push(addr.to_string()); + } + acc + }, + ) + .await; let connections = stream::iter(initial_nodes.iter().cloned()) - .map(|info| { + .map(|addr| { let params = params.clone(); async move { - let addr = info.addr.to_string(); let result = connect_and_check(&addr, params).await; match result { Ok(conn) => Some((addr, async { conn }.boxed().shared())), @@ -675,15 +708,37 @@ where inner.cluster_params.tls, num_of_nodes_to_query, )?; - + // Create a new connection vector of the found nodes let connections: &ConnectionMap = &read_guard.0; let mut nodes = new_slots.values().flatten().collect::>(); nodes.sort_unstable(); nodes.dedup(); let nodes_len = nodes.len(); - let addresses_and_connections_iter = nodes - .into_iter() - .map(|addr| (addr, connections.get(addr).cloned())); + let addresses_and_connections_iter = nodes.into_iter().map(|addr| async move { + let conn = match connections.get(addr).cloned() { + Some(conn) => Some(conn), + None => { + // If it's a DNS endpoint, it could have been stored in the existing connections vector using the resolved IP address instead of the DNS endpoint's name. + // We shall check if a connection is already exists under the resolved IP name. + let (host, port) = match get_host_and_port_from_addr(addr) { + Some((host, port)) => (host, port), + None => return (addr, None), + }; + match get_socket_addrs(host, port).await { + Ok(socket_addresses) => { + socket_addresses.fold(None, |acc, addr| match acc { + ok @ Some(_) => ok, + None => connections.get(&addr.to_string()).cloned(), + }) + } + Err(_) => None, + } + } + }; + (addr, conn) + }); + let addresses_and_connections_iter = + futures::future::join_all(addresses_and_connections_iter).await; let new_connections: HashMap> = stream::iter(addresses_and_connections_iter) .fold( @@ -700,6 +755,7 @@ where .await; drop(read_guard); + // Replace the current slot map and connection vector with the new ones let mut write_guard = inner.conn_lock.write().await; write_guard.1 = new_slots; write_guard.0 = new_connections; @@ -1059,6 +1115,7 @@ where Err(_) => connect_and_check(addr, params.clone()).await, } } else { + println!("connection is None"); connect_and_check(addr, params.clone()).await } } @@ -1301,6 +1358,22 @@ where (addr, conn) } +/// Splits a string address into host and port. If the passed address cannot be parsed, None is returned. +/// [addr] should be in the following format: ":". +fn get_host_and_port_from_addr(addr: &str) -> Option<(&str, u16)> { + let parts: Vec<&str> = addr.split(':').collect(); + if parts.len() != 2 { + return None; + } + let host = parts.first().unwrap(); + let port = parts.get(1).unwrap(); + let port = match port.parse::() { + Ok(port) => port, + Err(_) => return None, + }; + Some((host, port)) +} + #[cfg(test)] mod pipeline_routing_tests { use super::route_pipeline;