diff --git a/Cargo.lock.msrv b/Cargo.lock.msrv index eafc5cc234..c328de739e 100644 --- a/Cargo.lock.msrv +++ b/Cargo.lock.msrv @@ -479,6 +479,7 @@ dependencies = [ "clap", "futures", "openssl", + "rand", "rustyline", "rustyline-derive", "scylla", diff --git a/docs/source/load-balancing/load-balancing.md b/docs/source/load-balancing/load-balancing.md index 5fc8a069c1..3ec27dd7e1 100644 --- a/docs/source/load-balancing/load-balancing.md +++ b/docs/source/load-balancing/load-balancing.md @@ -2,8 +2,8 @@ ## Introduction -The driver uses a load balancing policy to determine which node(s) to contact -when executing a query. Load balancing policies implement the +The driver uses a load balancing policy to determine which node(s) and shard(s) +to contact when executing a query. Load balancing policies implement the `LoadBalancingPolicy` trait, which contains methods to generate a load balancing plan based on the query information and the state of the cluster. @@ -12,12 +12,14 @@ being opened. For a node connection blacklist configuration refer to `scylla::transport::host_filter::HostFilter`, which can be set session-wide using `SessionBuilder::host_filter` method. +In this chapter, "target" will refer to a pair `<node, shard>`. + ## Plan When a query is prepared to be sent to the database, the load balancing policy -constructs a load balancing plan. This plan is essentially a list of nodes to +constructs a load balancing plan. This plan is essentially a list of targets to which the driver will try to send the query. The first elements of the plan are -the nodes which are the best to contact (e.g. they might be replicas for the +the targets which are the best to contact (e.g. they might be replicas for the requested data or have the best latency). ## Policy @@ -84,17 +86,16 @@ first element of the load balancing plan is needed, so it's usually unnecessary to compute entire load balancing plan. To optimize this common case, the `LoadBalancingPolicy` trait provides two methods: `pick` and `fallback`. -`pick` returns the first node to contact for a given query, which is usually -the best based on a particular load balancing policy. If `pick` returns `None`, -then `fallback` will not be called. +`pick` returns the first target to contact for a given query, which is usually +the best based on a particular load balancing policy. -`fallback`, returns an iterator that provides the rest of the nodes in the load -balancing plan. `fallback` is called only when using the initial picked node -fails (or when executing speculatively). +`fallback`, returns an iterator that provides the rest of the targets in the +load balancing plan. `fallback` is called when using the initial picked +target fails (or when executing speculatively) or when `pick` returned `None`. -It's possible for the `fallback` method to include the same node that was +It's possible for the `fallback` method to include the same target that was returned by the `pick` method. In such cases, the query execution layer filters -out the picked node from the iterator returned by `fallback`. +out the picked target from the iterator returned by `fallback`.
### `on_query_success` and `on_query_failure`: diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 467963f93f..b068ee9e3c 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -20,6 +20,7 @@ uuid = "1.0" tower = "0.4" stats_alloc = "0.1" clap = { version = "3.2.4", features = ["derive"] } +rand = "0.8.5" [[example]] name = "auth" diff --git a/examples/compare-tokens.rs b/examples/compare-tokens.rs index 294dc7842b..47bad6418a 100644 --- a/examples/compare-tokens.rs +++ b/examples/compare-tokens.rs @@ -41,7 +41,7 @@ async fn main() -> Result<()> { .get_cluster_data() .get_token_endpoints("examples_ks", Token { value: t }) .iter() - .map(|n| n.address) + .map(|(node, _shard)| node.address) .collect::>() ); diff --git a/examples/custom_load_balancing_policy.rs b/examples/custom_load_balancing_policy.rs index c4b2d32002..fb1ae0cb7c 100644 --- a/examples/custom_load_balancing_policy.rs +++ b/examples/custom_load_balancing_policy.rs @@ -1,23 +1,37 @@ use anyhow::Result; +use rand::thread_rng; +use rand::Rng; +use scylla::transport::NodeRef; use scylla::{ load_balancing::{LoadBalancingPolicy, RoutingInfo}, + routing::Shard, transport::{ClusterData, ExecutionProfile}, Session, SessionBuilder, }; use std::{env, sync::Arc}; /// Example load balancing policy that prefers nodes from favorite datacenter +/// This is, of course, very naive, as it is completely non token-aware. +/// For more realistic implementation, see [`DefaultPolicy`](scylla::load_balancing::DefaultPolicy). #[derive(Debug)] struct CustomLoadBalancingPolicy { fav_datacenter_name: String, } +fn with_random_shard(node: NodeRef) -> (NodeRef, Shard) { + let nr_shards = node + .sharder() + .map(|sharder| sharder.nr_shards.get()) + .unwrap_or(1); + (node, thread_rng().gen_range(0..nr_shards) as Shard) +} + impl LoadBalancingPolicy for CustomLoadBalancingPolicy { fn pick<'a>( &'a self, _info: &'a RoutingInfo, cluster: &'a ClusterData, - ) -> Option> { + ) -> Option<(NodeRef<'a>, Shard)> { self.fallback(_info, cluster).next() } @@ -31,9 +45,9 @@ impl LoadBalancingPolicy for CustomLoadBalancingPolicy { .unique_nodes_in_datacenter_ring(&self.fav_datacenter_name); match fav_dc_nodes { - Some(nodes) => Box::new(nodes.iter()), + Some(nodes) => Box::new(nodes.iter().map(with_random_shard)), // If there is no dc with provided name, fallback to other datacenters - None => Box::new(cluster.get_nodes_info().iter()), + None => Box::new(cluster.get_nodes_info().iter().map(with_random_shard)), } } diff --git a/scylla/src/transport/cluster.rs b/scylla/src/transport/cluster.rs index 0098391854..6985498885 100644 --- a/scylla/src/transport/cluster.rs +++ b/scylla/src/transport/cluster.rs @@ -1,7 +1,7 @@ /// Cluster manages up to date information and connections to database nodes use crate::frame::response::event::{Event, StatusChangeEvent}; use crate::prepared_statement::TokenCalculationError; -use crate::routing::Token; +use crate::routing::{Shard, Token}; use crate::transport::host_filter::HostFilter; use crate::transport::{ connection::{Connection, VerifiedKeyspaceName}, @@ -27,6 +27,7 @@ use tracing::{debug, warn}; use uuid::Uuid; use super::node::{KnownNode, NodeAddr}; +use super::NodeRef; use super::locator::ReplicaLocator; use super::partitioner::calculate_token_for_partition_key; @@ -408,9 +409,9 @@ impl ClusterData { } /// Access to replicas owning a given token - pub fn get_token_endpoints(&self, keyspace: &str, token: Token) -> Vec> { + pub fn get_token_endpoints(&self, keyspace: &str, token: Token) -> Vec<(Arc, Shard)> { 
self.get_token_endpoints_iter(keyspace, token) - .cloned() + .map(|(node, shard)| (node.clone(), shard)) .collect() } @@ -418,7 +419,7 @@ impl ClusterData { &self, keyspace: &str, token: Token, - ) -> impl Iterator> { + ) -> impl Iterator, Shard)> { let keyspace = self.keyspaces.get(keyspace); let strategy = keyspace .map(|k| &k.strategy) @@ -436,7 +437,7 @@ impl ClusterData { keyspace: &str, table: &str, partition_key: &SerializedValues, - ) -> Result>, BadQuery> { + ) -> Result, Shard)>, BadQuery> { Ok(self.get_token_endpoints( keyspace, self.compute_token(keyspace, table, partition_key)?, diff --git a/scylla/src/transport/connection_pool.rs b/scylla/src/transport/connection_pool.rs index f26ea36ac2..11c2d671ec 100644 --- a/scylla/src/transport/connection_pool.rs +++ b/scylla/src/transport/connection_pool.rs @@ -1,7 +1,7 @@ #[cfg(feature = "cloud")] use crate::cloud::set_ssl_config_for_scylla_cloud_host; -use crate::routing::{Shard, ShardCount, Sharder, Token}; +use crate::routing::{Shard, ShardCount, Sharder}; use crate::transport::errors::QueryError; use crate::transport::{ connection, @@ -28,7 +28,7 @@ use std::time::Duration; use tokio::sync::{broadcast, mpsc, Notify}; use tracing::instrument::WithSubscriber; -use tracing::{debug, trace, warn}; +use tracing::{debug, error, trace, warn}; /// The target size of a per-node connection pool. #[derive(Debug, Clone, Copy)] @@ -235,22 +235,25 @@ impl NodeConnectionPool { .unwrap_or(None) } - pub(crate) fn connection_for_token(&self, token: Token) -> Result, QueryError> { - trace!(token = token.value, "Selecting connection for token"); + pub(crate) fn connection_for_shard(&self, shard: Shard) -> Result, QueryError> { + trace!(shard = shard, "Selecting connection for shard"); self.with_connections(|pool_conns| match pool_conns { PoolConnections::NotSharded(conns) => { Self::choose_random_connection_from_slice(conns).unwrap() } PoolConnections::Sharded { - sharder, connections, + sharder } => { - let shard: u16 = sharder - .shard_of(token) + let shard = shard .try_into() - .expect("Shard number doesn't fit in u16"); - trace!(shard = shard, "Selecting connection for token"); - Self::connection_for_shard(shard, sharder.nr_shards, connections.as_slice()) + // It's safer to use 0 rather that panic here, as shards are returned by `LoadBalancingPolicy` + // now, which can be implemented by a user in an arbitrary way. + .unwrap_or_else(|_| { + error!("The provided shard number: {} does not fit u16! Using 0 as the shard number. 
Check your LoadBalancingPolicy implementation.", shard); + 0 + }); + Self::connection_for_shard_helper(shard, sharder.nr_shards, connections.as_slice()) } }) } @@ -266,13 +269,13 @@ impl NodeConnectionPool { connections, } => { let shard: u16 = rand::thread_rng().gen_range(0..sharder.nr_shards.get()); - Self::connection_for_shard(shard, sharder.nr_shards, connections.as_slice()) + Self::connection_for_shard_helper(shard, sharder.nr_shards, connections.as_slice()) } }) } // Tries to get a connection to given shard, if it's broken returns any working connection - fn connection_for_shard( + fn connection_for_shard_helper( shard: u16, nr_shards: ShardCount, shard_conns: &[Vec>], diff --git a/scylla/src/transport/iterator.rs b/scylla/src/transport/iterator.rs index 366a7ccb4a..f277f5b646 100644 --- a/scylla/src/transport/iterator.rs +++ b/scylla/src/transport/iterator.rs @@ -35,7 +35,7 @@ use crate::transport::connection::{Connection, NonErrorQueryResponse, QueryRespo use crate::transport::load_balancing::{self, RoutingInfo}; use crate::transport::metrics::Metrics; use crate::transport::retry_policy::{QueryInfo, RetryDecision, RetrySession}; -use crate::transport::{Node, NodeRef}; +use crate::transport::NodeRef; use tracing::{trace, trace_span, warn, Instrument}; use uuid::Uuid; @@ -160,8 +160,6 @@ impl RowIterator { let worker_task = async move { let query_ref = &query; - let choose_connection = |node: Arc| async move { node.random_connection().await }; - let page_query = |connection: Arc, consistency: Consistency, paging_state: Option| { @@ -187,7 +185,6 @@ impl RowIterator { let worker = RowIteratorWorker { sender: sender.into(), - choose_connection, page_query, statement_info: routing_info, query_is_idempotent: query.config.is_idempotent, @@ -259,13 +256,6 @@ impl RowIterator { is_confirmed_lwt: config.prepared.is_confirmed_lwt(), }; - let choose_connection = |node: Arc| async move { - match token { - Some(token) => node.connection_for_token(token).await, - None => node.random_connection().await, - } - }; - let page_query = |connection: Arc, consistency: Consistency, paging_state: Option| async move { @@ -290,7 +280,7 @@ impl RowIterator { config .cluster_data .get_token_endpoints_iter(keyspace, token) - .cloned() + .map(|(node, shard)| (node.clone(), shard)) .collect(), ) } else { @@ -311,7 +301,6 @@ impl RowIterator { let worker = RowIteratorWorker { sender: sender.into(), - choose_connection, page_query, statement_info, query_is_idempotent: config.prepared.config.is_idempotent, @@ -496,13 +485,9 @@ type PageSendAttemptedProof = SendAttemptedProof { +struct RowIteratorWorker<'a, QueryFunc, SpanCreatorFunc> { sender: ProvingSender>, - // Closure used to choose a connection from a node - // AsyncFn(Arc) -> Result, QueryError> - choose_connection: ConnFunc, - // Closure used to perform a single page query // AsyncFn(Arc, Option) -> Result page_query: QueryFunc, @@ -524,11 +509,8 @@ struct RowIteratorWorker<'a, ConnFunc, QueryFunc, SpanCreatorFunc> { span_creator: SpanCreatorFunc, } -impl - RowIteratorWorker<'_, ConnFunc, QueryFunc, SpanCreator> +impl RowIteratorWorker<'_, QueryFunc, SpanCreator> where - ConnFunc: Fn(Arc) -> ConnFut, - ConnFut: Future, QueryError>>, QueryFunc: Fn(Arc, Consistency, Option) -> QueryFut, QueryFut: Future>, SpanCreator: Fn() -> RequestSpan, @@ -546,12 +528,13 @@ where self.log_query_start(); - 'nodes_in_plan: for node in query_plan { + 'nodes_in_plan: for (node, shard) in query_plan { let span = trace_span!(parent: &self.parent_span, "Executing query", node = 
%node.address); // For each node in the plan choose a connection to use // This connection will be reused for same node retries to preserve paging cache on the shard - let connection: Arc = match (self.choose_connection)(node.clone()) + let connection: Arc = match node + .connection_for_shard(shard) .instrument(span.clone()) .await { diff --git a/scylla/src/transport/load_balancing/default.rs b/scylla/src/transport/load_balancing/default.rs index 3fdeef18ef..df5c5c5520 100644 --- a/scylla/src/transport/load_balancing/default.rs +++ b/scylla/src/transport/load_balancing/default.rs @@ -3,7 +3,7 @@ pub use self::latency_awareness::LatencyAwarenessBuilder; use super::{FallbackPlan, LoadBalancingPolicy, NodeRef, RoutingInfo}; use crate::{ - routing::Token, + routing::{Shard, Token}, transport::{cluster::ClusterData, locator::ReplicaSet, node::Node, topology::Strategy}, }; use itertools::{Either, Itertools}; @@ -70,13 +70,14 @@ enum StatementType { /// It can be configured to be datacenter-aware and token-aware. /// Datacenter failover for queries with non local consistency mode is also supported. /// Latency awareness is available, althrough not recommended. +#[allow(clippy::type_complexity)] pub struct DefaultPolicy { preferences: NodeLocationPreference, is_token_aware: bool, permit_dc_failover: bool, - pick_predicate: Box bool + Send + Sync>, + pick_predicate: Box, Shard)) -> bool + Send + Sync>, latency_awareness: Option, - fixed_shuffle_seed: Option, + fixed_seed: Option, } impl fmt::Debug for DefaultPolicy { @@ -86,13 +87,17 @@ impl fmt::Debug for DefaultPolicy { .field("is_token_aware", &self.is_token_aware) .field("permit_dc_failover", &self.permit_dc_failover) .field("latency_awareness", &self.latency_awareness) - .field("fixed_shuffle_seed", &self.fixed_shuffle_seed) + .field("fixed_shuffle_seed", &self.fixed_seed) .finish_non_exhaustive() } } impl LoadBalancingPolicy for DefaultPolicy { - fn pick<'a>(&'a self, query: &'a RoutingInfo, cluster: &'a ClusterData) -> Option> { + fn pick<'a>( + &'a self, + query: &'a RoutingInfo, + cluster: &'a ClusterData, + ) -> Option<(NodeRef<'a>, Shard)> { let routing_info = self.routing_info(query, cluster); if let Some(ref token_with_strategy) = routing_info.token_with_strategy { if self.preferences.datacenter().is_some() @@ -177,7 +182,7 @@ or refrain from preferring datacenters (which may ban all other datacenters, if &self.pick_predicate, NodeLocationCriteria::DatacenterAndRack(dc, rack), ); - let local_rack_picked = Self::pick_node(nodes, rack_predicate); + let local_rack_picked = self.pick_node(nodes, rack_predicate); if let Some(alive_local_rack) = local_rack_picked { return Some(alive_local_rack); @@ -185,14 +190,14 @@ or refrain from preferring datacenters (which may ban all other datacenters, if } // Try to pick some alive local random node. - if let Some(alive_local) = Self::pick_node(nodes, &self.pick_predicate) { + if let Some(alive_local) = self.pick_node(nodes, &self.pick_predicate) { return Some(alive_local); } let all_nodes = cluster.replica_locator().unique_nodes_in_global_ring(); // If a datacenter failover is possible, loosen restriction about locality. 
if self.is_datacenter_failover_possible(&routing_info) { - let picked = Self::pick_node(all_nodes, &self.pick_predicate); + let picked = self.pick_node(all_nodes, &self.pick_predicate); if let Some(alive_maybe_remote) = picked { return Some(alive_maybe_remote); } @@ -200,21 +205,21 @@ or refrain from preferring datacenters (which may ban all other datacenters, if // Previous checks imply that every node we could have selected is down. // Let's try to return a down node that wasn't disabled. - let picked = Self::pick_node(nodes, |node| node.is_enabled()); + let picked = self.pick_node(nodes, |(node, _shard)| node.is_enabled()); if let Some(down_but_enabled_local_node) = picked { return Some(down_but_enabled_local_node); } // If a datacenter failover is possible, loosen restriction about locality. if self.is_datacenter_failover_possible(&routing_info) { - let picked = Self::pick_node(all_nodes, |node| node.is_enabled()); + let picked = self.pick_node(all_nodes, |(node, _shard)| node.is_enabled()); if let Some(down_but_enabled_maybe_remote_node) = picked { return Some(down_but_enabled_maybe_remote_node); } } // Every node is disabled. This could be due to a bad host filter - configuration error. - nodes.first() + nodes.first().map(|node| self.with_random_shard(node)) } fn fallback<'a>( @@ -285,7 +290,7 @@ or refrain from preferring datacenters (which may ban all other datacenters, if .chain(maybe_remote_replicas), ) } else { - Either::Right(std::iter::empty::>()) + Either::Right(std::iter::empty::<(NodeRef<'a>, Shard)>()) }; // Get a list of all local alive nodes, and apply a round robin to it @@ -297,29 +302,37 @@ or refrain from preferring datacenters (which may ban all other datacenters, if &self.pick_predicate, NodeLocationCriteria::DatacenterAndRack(dc, rack), ); - Either::Left(Self::round_robin_nodes(local_nodes, rack_predicate)) + Either::Left(self.round_robin_nodes_with_shards(local_nodes, rack_predicate)) } else { - Either::Right(std::iter::empty::>()) + Either::Right(std::iter::empty::<(NodeRef<'a>, Shard)>()) }; - let robined_local_nodes = Self::round_robin_nodes(local_nodes, Self::is_alive); + let robinned_local_nodes = self.round_robin_nodes_with_shards(local_nodes, Self::is_alive); let all_nodes = cluster.replica_locator().unique_nodes_in_global_ring(); // If a datacenter failover is possible, loosen restriction about locality. let maybe_remote_nodes = if self.is_datacenter_failover_possible(&routing_info) { - let robined_all_nodes = Self::round_robin_nodes(all_nodes, Self::is_alive); + let robinned_all_nodes = self.round_robin_nodes_with_shards(all_nodes, Self::is_alive); - Either::Left(robined_all_nodes) + Either::Left(robinned_all_nodes) } else { - Either::Right(std::iter::empty::>()) + Either::Right(std::iter::empty::<(NodeRef<'a>, Shard)>()) }; // Even if we consider some enabled nodes to be down, we should try contacting them in the last resort. - let maybe_down_local_nodes = local_nodes.iter().filter(|node| node.is_enabled()); + let maybe_down_local_nodes = local_nodes + .iter() + .filter(|node| node.is_enabled()) + .map(|node| self.with_random_shard(node)); // If a datacenter failover is possible, loosen restriction about locality. 
let maybe_down_nodes = if self.is_datacenter_failover_possible(&routing_info) { - Either::Left(all_nodes.iter().filter(|node| node.is_enabled())) + Either::Left( + all_nodes + .iter() + .filter(|node| node.is_enabled()) + .map(|node| self.with_random_shard(node)), + ) } else { Either::Right(std::iter::empty()) }; @@ -327,7 +340,7 @@ or refrain from preferring datacenters (which may ban all other datacenters, if // Construct a fallback plan as a composition of replicas, local nodes and remote nodes. let plan = maybe_replicas .chain(maybe_local_rack_nodes) - .chain(robined_local_nodes) + .chain(robinned_local_nodes) .chain(maybe_remote_nodes) .chain(maybe_down_local_nodes) .chain(maybe_down_nodes) @@ -420,13 +433,15 @@ impl DefaultPolicy { /// Wraps the provided predicate, adding the requirement for rack to match. fn make_rack_predicate<'a>( - predicate: impl Fn(&NodeRef<'a>) -> bool + 'a, + predicate: impl Fn(&(NodeRef<'a>, Shard)) -> bool + 'a, replica_location: NodeLocationCriteria<'a>, - ) -> impl Fn(&NodeRef<'a>) -> bool { - move |node| match replica_location { - NodeLocationCriteria::Any | NodeLocationCriteria::Datacenter(_) => predicate(node), + ) -> impl Fn(&(NodeRef<'a>, Shard)) -> bool { + move |node_and_shard @ (node, _shard)| match replica_location { + NodeLocationCriteria::Any | NodeLocationCriteria::Datacenter(_) => { + predicate(node_and_shard) + } NodeLocationCriteria::DatacenterAndRack(_, rack) => { - predicate(node) && node.rack.as_deref() == Some(rack) + predicate(node_and_shard) && node.rack.as_deref() == Some(rack) } } } @@ -435,10 +450,10 @@ impl DefaultPolicy { &'a self, ts: &TokenWithStrategy<'a>, replica_location: NodeLocationCriteria<'a>, - predicate: impl Fn(&NodeRef<'a>) -> bool + 'a, + predicate: impl Fn(&(NodeRef<'a>, Shard)) -> bool + 'a, cluster: &'a ClusterData, order: ReplicaOrder, - ) -> impl Iterator> { + ) -> impl Iterator, Shard)> { let predicate = Self::make_rack_predicate(predicate, replica_location); let replica_iter = match order { @@ -452,17 +467,17 @@ impl DefaultPolicy { .into_iter(), ), }; - replica_iter.filter(move |node: &NodeRef<'a>| predicate(node)) + replica_iter.filter(move |node_and_shard: &(NodeRef<'a>, Shard)| predicate(node_and_shard)) } fn pick_replica<'a>( &'a self, ts: &TokenWithStrategy<'a>, replica_location: NodeLocationCriteria<'a>, - predicate: &'a impl Fn(&NodeRef<'a>) -> bool, + predicate: &'a impl Fn(&(NodeRef<'a>, Shard)) -> bool, cluster: &'a ClusterData, statement_type: StatementType, - ) -> Option> { + ) -> Option<(NodeRef<'a>, Shard)> { match statement_type { StatementType::Lwt => self.pick_first_replica(ts, replica_location, predicate, cluster), StatementType::NonLwt => { @@ -487,9 +502,9 @@ impl DefaultPolicy { &'a self, ts: &TokenWithStrategy<'a>, replica_location: NodeLocationCriteria<'a>, - predicate: &'a impl Fn(&NodeRef<'a>) -> bool, + predicate: &'a impl Fn(&(NodeRef<'a>, Shard)) -> bool, cluster: &'a ClusterData, - ) -> Option> { + ) -> Option<(NodeRef<'a>, Shard)> { match replica_location { NodeLocationCriteria::Any => { // ReplicaSet returned by ReplicaLocator for this case: @@ -519,7 +534,7 @@ impl DefaultPolicy { self.replicas( ts, replica_location, - move |node| predicate(node), + move |node_and_shard| predicate(node_and_shard), cluster, ReplicaOrder::RingOrder, ) @@ -532,14 +547,14 @@ impl DefaultPolicy { &'a self, ts: &TokenWithStrategy<'a>, replica_location: NodeLocationCriteria<'a>, - predicate: &'a impl Fn(&NodeRef<'a>) -> bool, + predicate: &'a impl Fn(&(NodeRef<'a>, Shard)) -> bool, cluster: &'a 
ClusterData, - ) -> Option> { + ) -> Option<(NodeRef<'a>, Shard)> { let predicate = Self::make_rack_predicate(predicate, replica_location); let replica_set = self.nonfiltered_replica_set(ts, replica_location, cluster); - if let Some(fixed) = self.fixed_shuffle_seed { + if let Some(fixed) = self.fixed_seed { let mut gen = Pcg32::new(fixed, 0); replica_set.choose_filtered(&mut gen, predicate) } else { @@ -551,10 +566,10 @@ impl DefaultPolicy { &'a self, ts: &TokenWithStrategy<'a>, replica_location: NodeLocationCriteria<'a>, - predicate: impl Fn(&NodeRef<'a>) -> bool + 'a, + predicate: impl Fn(&(NodeRef<'_>, Shard)) -> bool + 'a, cluster: &'a ClusterData, statement_type: StatementType, - ) -> impl Iterator> { + ) -> impl Iterator, Shard)> { let order = match statement_type { StatementType::Lwt => ReplicaOrder::RingOrder, StatementType::NonLwt => ReplicaOrder::Arbitrary, @@ -587,27 +602,33 @@ impl DefaultPolicy { } fn pick_node<'a>( + &'a self, nodes: &'a [Arc], - predicate: impl Fn(&NodeRef<'a>) -> bool, - ) -> Option> { + predicate: impl Fn(&(NodeRef<'a>, Shard)) -> bool, + ) -> Option<(NodeRef<'_>, Shard)> { // Select the first node that matches the predicate - Self::randomly_rotated_nodes(nodes).find(predicate) + Self::randomly_rotated_nodes(nodes) + .map(|node| self.with_random_shard(node)) + .find(predicate) } - fn round_robin_nodes<'a>( + fn round_robin_nodes_with_shards<'a>( + &'a self, nodes: &'a [Arc], - predicate: impl Fn(&NodeRef<'a>) -> bool, - ) -> impl Iterator> { - Self::randomly_rotated_nodes(nodes).filter(predicate) + predicate: impl Fn(&(NodeRef<'a>, Shard)) -> bool, + ) -> impl Iterator, Shard)> { + Self::randomly_rotated_nodes(nodes) + .map(|node| self.with_random_shard(node)) + .filter(predicate) } fn shuffle<'a>( &self, - iter: impl Iterator>, - ) -> impl Iterator> { - let mut vec: Vec> = iter.collect(); + iter: impl Iterator, Shard)>, + ) -> impl Iterator, Shard)> { + let mut vec: Vec<(NodeRef<'_>, Shard)> = iter.collect(); - if let Some(fixed) = self.fixed_shuffle_seed { + if let Some(fixed) = self.fixed_seed { let mut gen = Pcg32::new(fixed, 0); vec.shuffle(&mut gen); } else { @@ -617,7 +638,23 @@ impl DefaultPolicy { vec.into_iter() } - fn is_alive(node: &NodeRef<'_>) -> bool { + fn with_random_shard<'a>(&self, node: NodeRef<'a>) -> (NodeRef<'a>, Shard) { + let nr_shards = node + .sharder() + .map(|sharder| sharder.nr_shards.get()) + .unwrap_or(1); + ( + node, + (if let Some(fixed) = self.fixed_seed { + let mut gen = Pcg32::new(fixed, 0); + gen.gen_range(0..nr_shards) + } else { + thread_rng().gen_range(0..nr_shards) + }) as Shard, + ) + } + + fn is_alive(&(node, _shard): &(NodeRef<'_>, Shard)) -> bool { // For now, we leave this as stub, until we have time to improve node events. 
// node.is_enabled() && !node.is_down() node.is_enabled() @@ -638,7 +675,7 @@ impl Default for DefaultPolicy { permit_dc_failover: false, pick_predicate: Box::new(Self::is_alive), latency_awareness: None, - fixed_shuffle_seed: None, + fixed_seed: None, } } } @@ -683,8 +720,11 @@ impl DefaultPolicyBuilder { let latency_awareness = self.latency_awareness.map(|builder| builder.build()); let pick_predicate = if let Some(ref latency_awareness) = latency_awareness { let latency_predicate = latency_awareness.generate_predicate(); - Box::new(move |node: &NodeRef| DefaultPolicy::is_alive(node) && latency_predicate(node)) - as Box bool + Send + Sync + 'static> + Box::new( + move |node_and_shard @ (node, _shard): &(NodeRef<'_>, Shard)| { + DefaultPolicy::is_alive(node_and_shard) && latency_predicate(node) + }, + ) as Box, Shard)) -> bool + Send + Sync + 'static> } else { Box::new(DefaultPolicy::is_alive) }; @@ -695,7 +735,7 @@ impl DefaultPolicyBuilder { permit_dc_failover: self.permit_dc_failover, pick_predicate, latency_awareness, - fixed_shuffle_seed: (!self.enable_replica_shuffle).then(rand::random), + fixed_seed: (!self.enable_replica_shuffle).then(rand::random), }) } @@ -1123,7 +1163,8 @@ mod tests { cluster: &ClusterData, ) -> Vec { let plan = Plan::new(policy, query_info, cluster); - plan.map(|node| node.address.port()).collect::>() + plan.map(|(node, _shard)| node.address.port()) + .collect::>() } } @@ -1239,7 +1280,7 @@ mod tests { preferences: NodeLocationPreference::Datacenter("eu".to_owned()), is_token_aware: true, permit_dc_failover: true, - fixed_shuffle_seed: Some(123), + fixed_seed: Some(123), ..Default::default() }, routing_info: RoutingInfo { @@ -1332,7 +1373,7 @@ mod tests { preferences: NodeLocationPreference::Datacenter("eu".to_owned()), is_token_aware: true, permit_dc_failover: true, - fixed_shuffle_seed: Some(123), + fixed_seed: Some(123), ..Default::default() }, routing_info: RoutingInfo { @@ -1578,7 +1619,7 @@ mod tests { ), is_token_aware: true, permit_dc_failover: false, - fixed_shuffle_seed: Some(123), + fixed_seed: Some(123), ..Default::default() }, routing_info: RoutingInfo { @@ -1710,7 +1751,7 @@ mod tests { preferences: NodeLocationPreference::Datacenter("eu".to_owned()), is_token_aware: true, permit_dc_failover: true, - fixed_shuffle_seed: Some(123), + fixed_seed: Some(123), ..Default::default() }, routing_info: RoutingInfo { @@ -1807,7 +1848,7 @@ mod tests { preferences: NodeLocationPreference::Datacenter("eu".to_owned()), is_token_aware: true, permit_dc_failover: true, - fixed_shuffle_seed: Some(123), + fixed_seed: Some(123), ..Default::default() }, routing_info: RoutingInfo { @@ -2064,7 +2105,7 @@ mod tests { ), is_token_aware: true, permit_dc_failover: false, - fixed_shuffle_seed: Some(123), + fixed_seed: Some(123), ..Default::default() }, routing_info: RoutingInfo { @@ -2136,7 +2177,7 @@ mod latency_awareness { use tracing::{instrument::WithSubscriber, trace, warn}; use uuid::Uuid; - use crate::{load_balancing::NodeRef, transport::node::Node}; + use crate::{load_balancing::NodeRef, routing::Shard, transport::node::Node}; use std::{ collections::HashMap, ops::Deref, @@ -2347,8 +2388,8 @@ mod latency_awareness { pub(super) fn wrap<'a>( &self, - fallback: impl Iterator>, - ) -> impl Iterator> { + fallback: impl Iterator, Shard)>, + ) -> impl Iterator, Shard)> { let min_avg_latency = match self.last_min_latency.load() { Some(min_avg) => min_avg, None => return Either::Left(fallback), // noop, as no latency data has been collected yet @@ -2669,8 +2710,8 @@ mod 
latency_awareness { struct IteratorWithSkippedNodes<'a, Fast, Penalised> where - Fast: Iterator>, - Penalised: Iterator>, + Fast: Iterator, Shard)>, + Penalised: Iterator, Shard)>, { fast_nodes: Fast, penalised_nodes: Penalised, @@ -2679,13 +2720,13 @@ mod latency_awareness { impl<'a> IteratorWithSkippedNodes< 'a, - std::vec::IntoIter>, - std::vec::IntoIter>, + std::vec::IntoIter<(NodeRef<'a>, Shard)>, + std::vec::IntoIter<(NodeRef<'a>, Shard)>, > { fn new( average_latencies: &HashMap>>, - nodes: impl Iterator>, + nodes: impl Iterator, Shard)>, exclusion_threshold: f64, retry_period: Duration, minimum_measurements: usize, @@ -2694,7 +2735,7 @@ mod latency_awareness { let mut fast_nodes = vec![]; let mut penalised_nodes = vec![]; - for node in nodes { + for node_and_shard @ (node, _shard) in nodes { match fast_enough( average_latencies, node.host_id, @@ -2703,11 +2744,11 @@ mod latency_awareness { minimum_measurements, min_avg, ) { - FastEnough::Yes => fast_nodes.push(node), + FastEnough::Yes => fast_nodes.push(node_and_shard), FastEnough::No { average } => { trace!("Latency awareness: Penalising node {{address={}, datacenter={:?}, rack={:?}}} for being on average at least {} times slower (latency: {}ms) than the fastest ({}ms).", node.address, node.datacenter, node.rack, exclusion_threshold, average.as_millis(), min_avg.as_millis()); - penalised_nodes.push(node); + penalised_nodes.push(node_and_shard); } } } @@ -2721,10 +2762,10 @@ mod latency_awareness { impl<'a, Fast, Penalised> Iterator for IteratorWithSkippedNodes<'a, Fast, Penalised> where - Fast: Iterator>, - Penalised: Iterator>, + Fast: Iterator, Shard)>, + Penalised: Iterator, Shard)>, { - type Item = &'a Arc; + type Item = (NodeRef<'a>, Shard); fn next(&mut self) -> Option { self.fast_nodes @@ -2743,7 +2784,8 @@ mod latency_awareness { }; use crate::{ - load_balancing::default::NodeLocationPreference, test_utils::create_new_session_builder, + load_balancing::default::NodeLocationPreference, routing::Shard, + test_utils::create_new_session_builder, }; use crate::{ load_balancing::{ @@ -2804,9 +2846,12 @@ mod latency_awareness { ) -> DefaultPolicy { let pick_predicate = { let latency_predicate = latency_awareness.generate_predicate(); - Box::new(move |node: &NodeRef| { - DefaultPolicy::is_alive(node) && latency_predicate(node) - }) as Box bool + Send + Sync + 'static> + Box::new( + move |node_and_shard @ (node, _shard): &(NodeRef<'_>, Shard)| { + DefaultPolicy::is_alive(node_and_shard) && latency_predicate(node) + }, + ) + as Box, Shard)) -> bool + Send + Sync + 'static> }; DefaultPolicy { @@ -2815,7 +2860,7 @@ mod latency_awareness { is_token_aware: true, pick_predicate, latency_awareness: Some(latency_awareness), - fixed_shuffle_seed: None, + fixed_seed: None, } } diff --git a/scylla/src/transport/load_balancing/mod.rs b/scylla/src/transport/load_balancing/mod.rs index d4095743c3..977e3d508f 100644 --- a/scylla/src/transport/load_balancing/mod.rs +++ b/scylla/src/transport/load_balancing/mod.rs @@ -3,7 +3,7 @@ //! See [the book](https://rust-driver.docs.scylladb.com/stable/load-balancing/load-balancing.html) for more information use super::{cluster::ClusterData, NodeRef}; -use crate::routing::Token; +use crate::routing::{Shard, Token}; use scylla_cql::{errors::QueryError, frame::types}; use std::time::Duration; @@ -39,18 +39,19 @@ pub struct RoutingInfo<'a> { /// /// It is computed on-demand, only if querying the most preferred node fails /// (or when speculative execution is triggered). 
-pub type FallbackPlan<'a> = Box> + Send + Sync + 'a>; +pub type FallbackPlan<'a> = Box, Shard)> + Send + Sync + 'a>; -/// Policy that decides which nodes to contact for each query. +/// Policy that decides which nodes and shards to contact for each query. /// /// When a query is prepared to be sent to ScyllaDB/Cassandra, a `LoadBalancingPolicy` -/// implementation constructs a load balancing plan. That plan is a list of nodes to which -/// the driver will try to send the query. The first elements of the plan are the nodes which are +/// implementation constructs a load balancing plan. That plan is a list of +/// targets (target is a node + an optional shard) to which +/// the driver will try to send the query. The first elements of the plan are the targets which are /// the best to contact (e.g. they might have the lowest latency). /// -/// Most queries are send on the first try, so the query execution layer rarely needs to know more -/// than one node from plan. To better optimize that case, `LoadBalancingPolicy` has two methods: -/// `pick` and `fallback`. `pick` returns a first node to contact for a given query, `fallback` +/// Most queries are sent on the first try, so the query execution layer rarely needs to know more +/// than one target from plan. To better optimize that case, `LoadBalancingPolicy` has two methods: +/// `pick` and `fallback`. `pick` returns the first target to contact for a given query, `fallback` /// returns the rest of the load balancing plan. /// /// `fallback` is called not only if a send to `pick`ed node failed (or when executing @@ -62,7 +63,11 @@ pub type FallbackPlan<'a> = Box> + Send + Sync + /// This trait is used to produce an iterator of nodes to contact for a given query. pub trait LoadBalancingPolicy: Send + Sync + std::fmt::Debug { /// Returns the first node to contact for a given query. - fn pick<'a>(&'a self, query: &'a RoutingInfo, cluster: &'a ClusterData) -> Option>; + fn pick<'a>( + &'a self, + query: &'a RoutingInfo, + cluster: &'a ClusterData, + ) -> Option<(NodeRef<'a>, Shard)>; /// Returns all contact-appropriate nodes for a given query. fn fallback<'a>(&'a self, query: &'a RoutingInfo, cluster: &'a ClusterData) diff --git a/scylla/src/transport/load_balancing/plan.rs b/scylla/src/transport/load_balancing/plan.rs index e49d4cb012..55413cada0 100644 --- a/scylla/src/transport/load_balancing/plan.rs +++ b/scylla/src/transport/load_balancing/plan.rs @@ -1,15 +1,15 @@ use tracing::error; use super::{FallbackPlan, LoadBalancingPolicy, NodeRef, RoutingInfo}; -use crate::transport::ClusterData; +use crate::{routing::Shard, transport::ClusterData}; enum PlanState<'a> { Created, PickedNone, // This always means an abnormal situation: it means that no nodes satisfied locality/node filter requirements. 
- Picked(NodeRef<'a>), + Picked((NodeRef<'a>, Shard)), Fallback { iter: FallbackPlan<'a>, - node_to_filter_out: NodeRef<'a>, + node_to_filter_out: (NodeRef<'a>, Shard), }, } @@ -44,7 +44,7 @@ impl<'a> Plan<'a> { } impl<'a> Iterator for Plan<'a> { - type Item = NodeRef<'a>; + type Item = (NodeRef<'a>, Shard); fn next(&mut self) -> Option { match &mut self.state { @@ -77,7 +77,7 @@ impl<'a> Iterator for Plan<'a> { PlanState::Picked(node) => { self.state = PlanState::Fallback { iter: self.policy.fallback(self.routing_info, self.cluster), - node_to_filter_out: node, + node_to_filter_out: *node, }; self.next() @@ -112,24 +112,27 @@ mod tests { use super::*; - fn expected_nodes() -> Vec> { - vec![Arc::new(Node::new_for_test( - NodeAddr::Translatable(SocketAddr::from_str("127.0.0.1:9042").unwrap()), - None, - None, - ))] + fn expected_nodes() -> Vec<(Arc, Shard)> { + vec![( + Arc::new(Node::new_for_test( + NodeAddr::Translatable(SocketAddr::from_str("127.0.0.1:9042").unwrap()), + None, + None, + )), + 42, + )] } #[derive(Debug)] struct PickingNonePolicy { - expected_nodes: Vec>, + expected_nodes: Vec<(Arc, Shard)>, } impl LoadBalancingPolicy for PickingNonePolicy { fn pick<'a>( &'a self, _query: &'a RoutingInfo, _cluster: &'a ClusterData, - ) -> Option> { + ) -> Option<(NodeRef<'a>, Shard)> { None } @@ -138,7 +141,11 @@ mod tests { _query: &'a RoutingInfo, _cluster: &'a ClusterData, ) -> FallbackPlan<'a> { - Box::new(self.expected_nodes.iter()) + Box::new( + self.expected_nodes + .iter() + .map(|(node_ref, shard)| (node_ref, *shard)), + ) } fn name(&self) -> String { @@ -159,6 +166,9 @@ mod tests { }; let routing_info = RoutingInfo::default(); let plan = Plan::new(&policy, &routing_info, &cluster_data); - assert_eq!(Vec::from_iter(plan.cloned()), policy.expected_nodes); + assert_eq!( + Vec::from_iter(plan.map(|(node, shard)| (node.clone(), shard))), + policy.expected_nodes + ); } } diff --git a/scylla/src/transport/locator/mod.rs b/scylla/src/transport/locator/mod.rs index 4ff44891d1..31a9a0cb9e 100644 --- a/scylla/src/transport/locator/mod.rs +++ b/scylla/src/transport/locator/mod.rs @@ -9,7 +9,7 @@ use rand::{seq::IteratorRandom, Rng}; pub use token_ring::TokenRing; use super::{topology::Strategy, Node, NodeRef}; -use crate::routing::Token; +use crate::routing::{Shard, Token}; use itertools::Itertools; use precomputed_replicas::PrecomputedReplicas; use replicas::{ReplicasArray, EMPTY_REPLICAS}; @@ -49,8 +49,9 @@ impl ReplicaLocator { let datacenters = replication_data .get_global_ring() .iter() - .filter_map(|(_, node)| node.datacenter.clone()) + .filter_map(|(_, node)| node.datacenter.as_deref()) .unique() + .map(ToOwned::to_owned) .collect(); Self { @@ -84,16 +85,20 @@ impl ReplicaLocator { if let Some(datacenter) = datacenter { let replicas = self.get_simple_strategy_replicas(token, *replication_factor); - return ReplicaSetInner::FilteredSimple { - replicas, - datacenter, - } - .into(); + return ReplicaSet { + inner: ReplicaSetInner::FilteredSimple { + replicas, + datacenter, + }, + token, + }; } else { - return ReplicaSetInner::Plain( - self.get_simple_strategy_replicas(token, *replication_factor), - ) - .into(); + return ReplicaSet { + inner: ReplicaSetInner::Plain( + self.get_simple_strategy_replicas(token, *replication_factor), + ), + token, + }; } } Strategy::NetworkTopologyStrategy { @@ -101,21 +106,28 @@ impl ReplicaLocator { } => { if let Some(dc) = datacenter { if let Some(repfactor) = datacenter_repfactors.get(dc) { - return ReplicaSetInner::Plain( - 
self.get_network_strategy_replicas(token, dc, *repfactor), - ) - .into(); + return ReplicaSet { + inner: ReplicaSetInner::Plain( + self.get_network_strategy_replicas(token, dc, *repfactor), + ), + token, + }; } else { debug!("Datacenter ({}) does not exist!", dc); - return EMPTY_REPLICAS.into(); + return ReplicaSet { + inner: ReplicaSetInner::Plain(EMPTY_REPLICAS), + token, + }; } } else { - return ReplicaSetInner::ChainedNTS { - datacenter_repfactors, - locator: self, + return ReplicaSet { + inner: ReplicaSetInner::ChainedNTS { + datacenter_repfactors, + locator: self, + token, + }, token, - } - .into(); + }; } } Strategy::Other { name, .. } => { @@ -210,6 +222,14 @@ impl ReplicaLocator { } } +fn with_computed_shard(node: NodeRef, token: Token) -> (NodeRef, Shard) { + let shard = node + .sharder() + .map(|sharder| sharder.shard_of(token)) + .unwrap_or(0); + (node, shard) +} + #[derive(Debug)] enum ReplicaSetInner<'a> { Plain(ReplicasArray<'a>), @@ -237,6 +257,7 @@ enum ReplicaSetInner<'a> { #[derive(Debug)] pub struct ReplicaSet<'a> { inner: ReplicaSetInner<'a>, + token: Token, } impl<'a> ReplicaSet<'a> { @@ -244,8 +265,8 @@ impl<'a> ReplicaSet<'a> { pub fn choose_filtered( self, rng: &mut R, - predicate: impl Fn(&NodeRef<'a>) -> bool, - ) -> Option> + predicate: impl Fn(&(NodeRef<'a>, Shard)) -> bool, + ) -> Option<(NodeRef<'a>, Shard)> where R: Rng + ?Sized, { @@ -302,7 +323,7 @@ impl<'a> ReplicaSet<'a> { self.len() == 0 } - fn choose(&self, rng: &mut R) -> Option> + fn choose(&self, rng: &mut R) -> Option<(NodeRef<'a>, Shard)> where R: Rng + ?Sized, { @@ -311,14 +332,17 @@ impl<'a> ReplicaSet<'a> { let index = rng.gen_range(0..len); match &self.inner { - ReplicaSetInner::Plain(replicas) => replicas.get(index), + ReplicaSetInner::Plain(replicas) => replicas + .get(index) + .map(|node| with_computed_shard(node, self.token)), ReplicaSetInner::FilteredSimple { replicas, datacenter, } => replicas .iter() .filter(|node| node.datacenter.as_deref() == Some(*datacenter)) - .nth(index), + .nth(index) + .map(|node| with_computed_shard(node, self.token)), ReplicaSetInner::ChainedNTS { datacenter_repfactors, locator, @@ -338,7 +362,8 @@ impl<'a> ReplicaSet<'a> { if nodes_to_skip < repfactor { return locator .get_network_strategy_replicas(*token, datacenter, repfactor) - .get(nodes_to_skip); + .get(nodes_to_skip) + .map(|node| with_computed_shard(node, self.token)); } nodes_to_skip -= repfactor; @@ -354,7 +379,7 @@ impl<'a> ReplicaSet<'a> { } impl<'a> IntoIterator for ReplicaSet<'a> { - type Item = NodeRef<'a>; + type Item = (NodeRef<'a>, Shard); type IntoIter = ReplicaSetIterator<'a>; /// Converts the replica set into iterator. Order defined by that iterator does not have to @@ -399,23 +424,9 @@ impl<'a> IntoIterator for ReplicaSet<'a> { } }; - ReplicaSetIterator { inner } - } -} - -impl<'a> From> for ReplicaSet<'a> { - fn from(item: ReplicaSetInner<'a>) -> Self { - Self { inner: item } - } -} - -impl<'a, T> From for ReplicaSet<'a> -where - T: Into>, -{ - fn from(item: T) -> Self { - Self { - inner: ReplicaSetInner::Plain(item.into()), + ReplicaSetIterator { + inner, + token: self.token, } } } @@ -444,17 +455,18 @@ enum ReplicaSetIteratorInner<'a> { /// Iterator that returns replicas from some replica set. 
pub struct ReplicaSetIterator<'a> { inner: ReplicaSetIteratorInner<'a>, + token: Token, } impl<'a> Iterator for ReplicaSetIterator<'a> { - type Item = NodeRef<'a>; + type Item = (NodeRef<'a>, Shard); fn next(&mut self) -> Option { match &mut self.inner { ReplicaSetIteratorInner::Plain { replicas, idx } => { if let Some(replica) = replicas.get(*idx) { *idx += 1; - return Some(replica); + return Some(with_computed_shard(replica, self.token)); } None @@ -467,7 +479,7 @@ impl<'a> Iterator for ReplicaSetIterator<'a> { while let Some(replica) = replicas.get(*idx) { *idx += 1; if replica.datacenter.as_deref() == Some(*datacenter) { - return Some(replica); + return Some(with_computed_shard(replica, self.token)); } } @@ -483,7 +495,7 @@ impl<'a> Iterator for ReplicaSetIterator<'a> { } => { if let Some(replica) = replicas.get(*replicas_idx) { *replicas_idx += 1; - Some(replica) + Some(with_computed_shard(replica, self.token)) } else if *datacenter_idx + 1 < locator.datacenters.len() { *datacenter_idx += 1; *replicas_idx = 0; @@ -589,7 +601,12 @@ enum ReplicasOrderedIteratorInner<'a> { }, } -enum ReplicasOrderedNTSIterator<'a> { +struct ReplicasOrderedNTSIterator<'a> { + token: Token, + inner: ReplicasOrderedNTSIteratorInner<'a>, +} + +enum ReplicasOrderedNTSIteratorInner<'a> { FreshForPick { datacenter_repfactors: &'a HashMap, locator: &'a ReplicaLocator, @@ -608,11 +625,11 @@ enum ReplicasOrderedNTSIterator<'a> { } impl<'a> Iterator for ReplicasOrderedNTSIterator<'a> { - type Item = NodeRef<'a>; + type Item = (NodeRef<'a>, Shard); fn next(&mut self) -> Option { - match *self { - Self::FreshForPick { + match self.inner { + ReplicasOrderedNTSIteratorInner::FreshForPick { datacenter_repfactors, locator, token, @@ -624,19 +641,19 @@ impl<'a> Iterator for ReplicasOrderedNTSIterator<'a> { if let Some(dc) = &node.datacenter { if datacenter_repfactors.get(dc).is_some() { // ...then this node must be the primary replica. - *self = Self::Picked { + self.inner = ReplicasOrderedNTSIteratorInner::Picked { datacenter_repfactors, locator, token, picked: node, }; - return Some(node); + return Some(with_computed_shard(node, self.token)); } } } None } - Self::Picked { + ReplicasOrderedNTSIteratorInner::Picked { datacenter_repfactors, locator, token, @@ -673,19 +690,19 @@ impl<'a> Iterator for ReplicasOrderedNTSIterator<'a> { "all_replicas somehow contained a node that wasn't present in the global ring!" 
); - *self = Self::ComputedFallback { + self.inner = ReplicasOrderedNTSIteratorInner::ComputedFallback { replicas: ReplicasArray::Owned(replicas_ordered), idx: 0, }; self.next() } - Self::ComputedFallback { + ReplicasOrderedNTSIteratorInner::ComputedFallback { ref replicas, ref mut idx, } => { if let Some(replica) = replicas.get(*idx) { *idx += 1; - Some(replica) + Some(with_computed_shard(replica, self.token)) } else { None } @@ -695,7 +712,7 @@ impl<'a> Iterator for ReplicasOrderedNTSIterator<'a> { } impl<'a> Iterator for ReplicasOrderedIterator<'a> { - type Item = NodeRef<'a>; + type Item = (NodeRef<'a>, Shard); fn next(&mut self) -> Option { match &mut self.inner { @@ -710,7 +727,7 @@ impl<'a> Iterator for ReplicasOrderedIterator<'a> { } impl<'a> IntoIterator for ReplicasOrdered<'a> { - type Item = NodeRef<'a>; + type Item = (NodeRef<'a>, Shard); type IntoIter = ReplicasOrderedIterator<'a>; fn into_iter(self) -> Self::IntoIter { @@ -727,10 +744,13 @@ impl<'a> IntoIterator for ReplicasOrdered<'a> { locator, token, } => ReplicasOrderedIteratorInner::PolyDatacenterNTS { - replicas_ordered_iter: ReplicasOrderedNTSIterator::FreshForPick { - datacenter_repfactors, - locator, - token, + replicas_ordered_iter: ReplicasOrderedNTSIterator { + token: replica_set.token, + inner: ReplicasOrderedNTSIteratorInner::FreshForPick { + datacenter_repfactors, + locator, + token, + }, }, }, }, @@ -755,7 +775,7 @@ mod tests { let replicas_ordered = replica_set.into_replicas_ordered(); let ids: Vec<_> = replicas_ordered .into_iter() - .map(|node| node.address.port()) + .map(|(node, _shard)| node.address.port()) .collect(); assert_eq!(expected, ids); }; diff --git a/scylla/src/transport/locator/test.rs b/scylla/src/transport/locator/test.rs index bb74ee0469..004d7809d2 100644 --- a/scylla/src/transport/locator/test.rs +++ b/scylla/src/transport/locator/test.rs @@ -175,7 +175,7 @@ fn assert_same_node_ids<'a>(left: impl Iterator>, ids: &[u16] } fn assert_replica_set_equal_to(nodes: ReplicaSet<'_>, ids: &[u16]) { - assert_same_node_ids(nodes.into_iter(), ids) + assert_same_node_ids(nodes.into_iter().map(|(node, _shard)| node), ids) } pub(crate) fn create_ring(metadata: &Metadata) -> impl Iterator)> { @@ -501,7 +501,7 @@ fn test_replica_set_choose(locator: &ReplicaLocator) { let mut chosen_replicas = HashSet::new(); for _ in 0..32 { let set = replica_set_generator(); - let node = set + let (node, _shard) = set .choose(&mut rng) .expect("choose from non-empty set must return some node"); chosen_replicas.insert(node.host_id); @@ -541,8 +541,10 @@ fn test_replica_set_choose_filtered(locator: &ReplicaLocator) { let mut chosen_replicas = HashSet::new(); for _ in 0..32 { let set = replica_set_generator(); - let node = set - .choose_filtered(&mut rng, |node| node.datacenter == Some("eu".into())) + let (node, _shard) = set + .choose_filtered(&mut rng, |(node, _shard)| { + node.datacenter == Some("eu".into()) + }) .expect("choose from non-empty set must return some node"); chosen_replicas.insert(node.host_id); } diff --git a/scylla/src/transport/node.rs b/scylla/src/transport/node.rs index 97b2679461..07c34e1302 100644 --- a/scylla/src/transport/node.rs +++ b/scylla/src/transport/node.rs @@ -3,7 +3,7 @@ use tracing::warn; use uuid::Uuid; /// Node represents a cluster node along with it's data and connections -use crate::routing::{Sharder, Token}; +use crate::routing::{Shard, Sharder}; use crate::transport::connection::Connection; use crate::transport::connection::VerifiedKeyspaceName; use 
crate::transport::connection_pool::{NodeConnectionPool, PoolConfig}; @@ -152,18 +152,13 @@ impl Node { self.pool.as_ref()?.sharder() } - /// Get connection which should be used to connect using given token - /// If this connection is broken get any random connection to this Node - pub(crate) async fn connection_for_token( + /// Get a connection targetting the given shard + /// If such connection is broken, get any random connection to this `Node` + pub(crate) async fn connection_for_shard( &self, - token: Token, + shard: Shard, ) -> Result, QueryError> { - self.get_pool()?.connection_for_token(token) - } - - /// Get random connection - pub(crate) async fn random_connection(&self) -> Result, QueryError> { - self.get_pool()?.random_connection() + self.get_pool()?.connection_for_shard(shard) } pub fn is_down(&self) -> bool { diff --git a/scylla/src/transport/session.rs b/scylla/src/transport/session.rs index f4f5ab2365..721b2af1d5 100644 --- a/scylla/src/transport/session.rs +++ b/scylla/src/transport/session.rs @@ -51,7 +51,7 @@ use crate::frame::response::cql_to_rust::FromRowError; use crate::frame::response::result; use crate::prepared_statement::PreparedStatement; use crate::query::Query; -use crate::routing::Token; +use crate::routing::{Shard, Token}; use crate::statement::Consistency; use crate::tracing::{TracingEvent, TracingInfo}; use crate::transport::cluster::{Cluster, ClusterData, ClusterNeatDebug}; @@ -655,7 +655,6 @@ impl Session { statement_info, &query.config, execution_profile, - |node: Arc| async move { node.random_connection().await }, |connection: Arc, consistency: Consistency, execution_profile: &ExecutionProfileInner| { @@ -1024,12 +1023,6 @@ impl Session { statement_info, &prepared.config, execution_profile, - |node: Arc| async move { - match token { - Some(token) => node.connection_for_token(token).await, - None => node.random_connection().await, - } - }, |connection: Arc, consistency: Consistency, execution_profile: &ExecutionProfileInner| { @@ -1236,14 +1229,6 @@ impl Session { statement_info, &batch.config, execution_profile, - |node: Arc| async move { - match first_value_token { - Some(first_value_token) => { - node.connection_for_token(first_value_token).await - } - None => node.random_connection().await, - } - }, |connection: Arc, consistency: Consistency, execution_profile: &ExecutionProfileInner| { @@ -1507,28 +1492,23 @@ impl Session { } // This method allows to easily run a query using load balancing, retry policy etc. - // Requires some information about the query and two closures - // First closure is used to choose a connection - // - query will use node.random_connection() - // - execute will use node.connection_for_token() - // The second closure is used to do the query itself on a connection + // Requires some information about the query and a closure. + // The closure is used to do the query itself on a connection. 
// - query will use connection.query() // - execute will use connection.execute() // If this query closure fails with some errors retry policy is used to perform retries // On success this query's result is returned // I tried to make this closures take a reference instead of an Arc but failed // maybe once async closures get stabilized this can be fixed - async fn run_query<'a, ConnFut, QueryFut, ResT>( + async fn run_query<'a, QueryFut, ResT>( &'a self, statement_info: RoutingInfo<'a>, statement_config: &'a StatementConfig, execution_profile: Arc, - choose_connection: impl Fn(Arc) -> ConnFut, do_query: impl Fn(Arc, Consistency, &ExecutionProfileInner) -> QueryFut, request_span: &'a RequestSpan, ) -> Result, QueryError> where - ConnFut: Future, QueryError>>, QueryFut: Future>, ResT: AllowedRunQueryResTType, { @@ -1550,16 +1530,16 @@ impl Session { // can be shared safely. struct SharedPlan<'a, I> where - I: Iterator>, + I: Iterator, Shard)>, { iter: std::sync::Mutex, } impl<'a, I> Iterator for &SharedPlan<'a, I> where - I: Iterator>, + I: Iterator, Shard)>, { - type Item = NodeRef<'a>; + type Item = (NodeRef<'a>, Shard); fn next(&mut self) -> Option { self.iter.lock().unwrap().next() @@ -1602,7 +1582,6 @@ impl Session { self.execute_query( &shared_query_plan, - &choose_connection, &do_query, &execution_profile, ExecuteQueryContext { @@ -1638,7 +1617,6 @@ impl Session { }); self.execute_query( query_plan, - &choose_connection, &do_query, &execution_profile, ExecuteQueryContext { @@ -1684,16 +1662,14 @@ impl Session { result } - async fn execute_query<'a, ConnFut, QueryFut, ResT>( + async fn execute_query<'a, QueryFut, ResT>( &'a self, - query_plan: impl Iterator>, - choose_connection: impl Fn(Arc) -> ConnFut, + query_plan: impl Iterator, Shard)>, do_query: impl Fn(Arc, Consistency, &ExecutionProfileInner) -> QueryFut, execution_profile: &ExecutionProfileInner, mut context: ExecuteQueryContext<'a>, ) -> Option, QueryError>> where - ConnFut: Future, QueryError>>, QueryFut: Future>, ResT: AllowedRunQueryResTType, { @@ -1702,14 +1678,11 @@ impl Session { .consistency_set_on_statement .unwrap_or(execution_profile.consistency); - 'nodes_in_plan: for node in query_plan { + 'nodes_in_plan: for (node, shard) in query_plan { let span = trace_span!("Executing query", node = %node.address); 'same_node_retries: loop { trace!(parent: &span, "Execution started"); - let connection: Arc = match choose_connection(node.clone()) - .instrument(span.clone()) - .await - { + let connection = match node.connection_for_shard(shard).await { Ok(connection) => connection, Err(e) => { trace!( @@ -2027,19 +2000,19 @@ impl RequestSpan { self.span.record("result_rows", rows.rows.len()); } - pub(crate) fn record_replicas<'a>(&'a self, replicas: &'a [impl Borrow>]) { - struct ReplicaIps<'a, N>(&'a [N]); + pub(crate) fn record_replicas<'a>(&'a self, replicas: &'a [(impl Borrow>, Shard)]) { + struct ReplicaIps<'a, N>(&'a [(N, Shard)]); impl<'a, N> Display for ReplicaIps<'a, N> where N: Borrow>, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut nodes = self.0.iter(); - if let Some(node) = nodes.next() { - write!(f, "{}", node.borrow().address.ip())?; + let mut nodes_with_shards = self.0.iter(); + if let Some((node, shard)) = nodes_with_shards.next() { + write!(f, "{}-shard{}", node.borrow().address.ip(), shard)?; - for node in nodes { - write!(f, ",{}", node.borrow().address.ip())?; + for (node, shard) in nodes_with_shards { + write!(f, ",{}-shard{}", node.borrow().address.ip(), shard)?; } } 
Ok(()) diff --git a/scylla/tests/integration/consistency.rs b/scylla/tests/integration/consistency.rs index e265265335..a85a4bb60d 100644 --- a/scylla/tests/integration/consistency.rs +++ b/scylla/tests/integration/consistency.rs @@ -4,9 +4,10 @@ use scylla::execution_profile::{ExecutionProfileBuilder, ExecutionProfileHandle} use scylla::load_balancing::{DefaultPolicy, LoadBalancingPolicy, RoutingInfo}; use scylla::prepared_statement::PreparedStatement; use scylla::retry_policy::FallthroughRetryPolicy; -use scylla::routing::Token; +use scylla::routing::{Shard, Token}; use scylla::test_utils::unique_keyspace_name; use scylla::transport::session::Session; +use scylla::transport::NodeRef; use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender}; use scylla::statement::batch::BatchStatement; @@ -377,7 +378,7 @@ impl LoadBalancingPolicy for RoutingInfoReportingWrapper { &'a self, query: &'a RoutingInfo, cluster: &'a scylla::transport::ClusterData, - ) -> Option> { + ) -> Option<(NodeRef<'a>, Shard)> { self.routing_info_tx .send(OwnedRoutingInfo::from(query.clone())) .unwrap(); diff --git a/scylla/tests/integration/execution_profiles.rs b/scylla/tests/integration/execution_profiles.rs index 119487a609..38cdc85bb6 100644 --- a/scylla/tests/integration/execution_profiles.rs +++ b/scylla/tests/integration/execution_profiles.rs @@ -6,6 +6,7 @@ use assert_matches::assert_matches; use scylla::batch::BatchStatement; use scylla::batch::{Batch, BatchType}; use scylla::query::Query; +use scylla::routing::Shard; use scylla::statement::SerialConsistency; use scylla::transport::NodeRef; use scylla::{ @@ -46,9 +47,13 @@ impl BoundToPredefinedNodePolicy { } impl LoadBalancingPolicy for BoundToPredefinedNodePolicy { - fn pick<'a>(&'a self, _info: &'a RoutingInfo, cluster: &'a ClusterData) -> Option> { + fn pick<'a>( + &'a self, + _info: &'a RoutingInfo, + cluster: &'a ClusterData, + ) -> Option<(NodeRef<'a>, Shard)> { self.report_node(Report::LoadBalancing); - cluster.get_nodes_info().iter().next() + cluster.get_nodes_info().iter().next().map(|node| (node, 0)) } fn fallback<'a>( diff --git a/scylla/tests/integration/utils.rs b/scylla/tests/integration/utils.rs index f32ffa2764..52270c8942 100644 --- a/scylla/tests/integration/utils.rs +++ b/scylla/tests/integration/utils.rs @@ -1,6 +1,8 @@ use futures::Future; use itertools::Itertools; use scylla::load_balancing::LoadBalancingPolicy; +use scylla::routing::Shard; +use scylla::transport::NodeRef; use std::collections::HashMap; use std::env; use std::net::SocketAddr; @@ -16,6 +18,14 @@ pub fn init_logger() { .try_init(); } +fn with_pseudorandom_shard(node: NodeRef) -> (NodeRef, Shard) { + let nr_shards = node + .sharder() + .map(|sharder| sharder.nr_shards.get()) + .unwrap_or(1); + (node, ((nr_shards - 1) % 42) as Shard) +} + #[derive(Debug)] pub struct FixedOrderLoadBalancer; impl LoadBalancingPolicy for FixedOrderLoadBalancer { @@ -23,12 +33,13 @@ impl LoadBalancingPolicy for FixedOrderLoadBalancer { &'a self, _info: &'a scylla::load_balancing::RoutingInfo, cluster: &'a scylla::transport::ClusterData, - ) -> Option> { + ) -> Option<(NodeRef<'a>, Shard)> { cluster .get_nodes_info() .iter() .sorted_by(|node1, node2| Ord::cmp(&node1.address, &node2.address)) .next() + .map(with_pseudorandom_shard) } fn fallback<'a>( @@ -40,7 +51,8 @@ impl LoadBalancingPolicy for FixedOrderLoadBalancer { cluster .get_nodes_info() .iter() - .sorted_by(|node1, node2| Ord::cmp(&node1.address, &node2.address)), + .sorted_by(|node1, node2| Ord::cmp(&node1.address, 
&node2.address)) + .map(with_pseudorandom_shard), ) } @@ -48,7 +60,7 @@ impl LoadBalancingPolicy for FixedOrderLoadBalancer { &self, _: &scylla::load_balancing::RoutingInfo, _: std::time::Duration, - _: scylla::transport::NodeRef<'_>, + _: NodeRef<'_>, ) { } @@ -56,7 +68,7 @@ impl LoadBalancingPolicy for FixedOrderLoadBalancer { &self, _: &scylla::load_balancing::RoutingInfo, _: std::time::Duration, - _: scylla::transport::NodeRef<'_>, + _: NodeRef<'_>, _: &scylla_cql::errors::QueryError, ) { }
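For readers adapting their own `LoadBalancingPolicy` implementation to the target-based API introduced by this diff, here is a minimal sketch. It assumes the post-change API exactly as shown above (`pick` returning `Option<(NodeRef<'a>, Shard)>`, `FallbackPlan` being an iterator over `(NodeRef, Shard)` targets) and mirrors the pattern of the updated `examples/custom_load_balancing_policy.rs`; the policy name `EveryNodeRandomShardPolicy` and its helper are illustrative, not part of the patch.

```rust
use rand::{thread_rng, Rng};
use scylla::load_balancing::{FallbackPlan, LoadBalancingPolicy, RoutingInfo};
use scylla::routing::Shard;
use scylla::transport::{ClusterData, NodeRef};

/// Illustrative policy: offers every known node, each paired with a random shard.
#[derive(Debug)]
struct EveryNodeRandomShardPolicy;

/// Pick a random shard for a node; nodes without a sharder (e.g. Cassandra) get shard 0.
fn with_random_shard(node: NodeRef) -> (NodeRef, Shard) {
    let nr_shards = node
        .sharder()
        .map(|sharder| sharder.nr_shards.get())
        .unwrap_or(1);
    (node, thread_rng().gen_range(0..nr_shards) as Shard)
}

impl LoadBalancingPolicy for EveryNodeRandomShardPolicy {
    fn pick<'a>(
        &'a self,
        info: &'a RoutingInfo,
        cluster: &'a ClusterData,
    ) -> Option<(NodeRef<'a>, Shard)> {
        // Cheap pick: the first target of the fallback plan.
        self.fallback(info, cluster).next()
    }

    fn fallback<'a>(
        &'a self,
        _info: &'a RoutingInfo,
        cluster: &'a ClusterData,
    ) -> FallbackPlan<'a> {
        // Attach a random shard to every node known to the cluster.
        Box::new(cluster.get_nodes_info().iter().map(with_random_shard))
    }

    fn name(&self) -> String {
        "EveryNodeRandomShardPolicy".to_string()
    }
}
```

Note that, per the `connection_for_shard` change in `connection_pool.rs` above, a shard number that does not fit in `u16` is logged as an error and replaced with shard 0 rather than causing a panic, so an imperfect shard choice from a custom policy degrades gracefully instead of failing the query.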