Merge pull request scylladb#944 from Lorak-mmk/shard-selecting-lb-v2
Shard selecting load balancing
wprzytula authored Mar 14, 2024
2 parents 07df198 + 28ae015 commit 8e845e7
Showing 18 changed files with 368 additions and 296 deletions.
1 change: 1 addition & 0 deletions Cargo.lock.msrv

Some generated files are not rendered by default.

25 changes: 13 additions & 12 deletions docs/source/load-balancing/load-balancing.md
@@ -2,8 +2,8 @@

## Introduction

-The driver uses a load balancing policy to determine which node(s) to contact
-when executing a query. Load balancing policies implement the
+The driver uses a load balancing policy to determine which node(s) and shard(s)
+to contact when executing a query. Load balancing policies implement the
`LoadBalancingPolicy` trait, which contains methods to generate a load
balancing plan based on the query information and the state of the cluster.

@@ -12,12 +12,14 @@ being opened. For a node connection blacklist configuration refer to
`scylla::transport::host_filter::HostFilter`, which can be set session-wide
using `SessionBuilder::host_filter` method.

+In this chapter, "target" will refer to a pair `<node, optional shard>`.

## Plan

When a query is prepared to be sent to the database, the load balancing policy
-constructs a load balancing plan. This plan is essentially a list of nodes to
+constructs a load balancing plan. This plan is essentially a list of targets to
which the driver will try to send the query. The first elements of the plan are
-the nodes which are the best to contact (e.g. they might be replicas for the
+the targets which are the best to contact (e.g. they might be replicas for the
requested data or have the best latency).

## Policy
@@ -84,17 +86,16 @@ first element of the load balancing plan is needed, so it's usually unnecessary
to compute the entire load balancing plan. To optimize this common case, the
`LoadBalancingPolicy` trait provides two methods: `pick` and `fallback`.

-`pick` returns the first node to contact for a given query, which is usually
-the best based on a particular load balancing policy. If `pick` returns `None`,
-then `fallback` will not be called.
+`pick` returns the first target to contact for a given query, which is usually
+the best based on a particular load balancing policy.

-`fallback`, returns an iterator that provides the rest of the nodes in the load
-balancing plan. `fallback` is called only when using the initial picked node
-fails (or when executing speculatively).
+`fallback` returns an iterator that provides the rest of the targets in the
+load balancing plan. `fallback` is called when trying the initially picked
+target fails (or when executing speculatively), or when `pick` returned `None`.

-It's possible for the `fallback` method to include the same node that was
+It's possible for the `fallback` method to include the same target that was
returned by the `pick` method. In such cases, the query execution layer filters
-out the picked node from the iterator returned by `fallback`.
+out the picked target from the iterator returned by `fallback`.
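
To make the `pick`/`fallback` contract concrete, here is a minimal, self-contained sketch of how an execution layer can assemble a deduplicated plan from the two methods. `Target` and `ToyPolicy` are illustrative stand-ins, not the driver's real types:

```rust
// Toy model of plan assembly from `pick` and `fallback`.
type Target = (u32, u32); // (node id, shard id), a hypothetical stand-in

struct ToyPolicy {
    targets: Vec<Target>,
}

impl ToyPolicy {
    // Cheap path: return the single best target without building a full plan.
    fn pick(&self) -> Option<Target> {
        self.targets.first().copied()
    }

    // Lazy path: yield the remaining targets; may repeat the picked one.
    fn fallback(&self) -> impl Iterator<Item = Target> + '_ {
        self.targets.iter().copied()
    }
}

// Picked target first, then the fallback iterator with the picked target
// filtered out, mirroring what the query execution layer does.
fn plan(policy: &ToyPolicy) -> impl Iterator<Item = Target> + '_ {
    let picked = policy.pick();
    picked
        .into_iter()
        .chain(policy.fallback().filter(move |t| Some(*t) != picked))
}

fn main() {
    let policy = ToyPolicy {
        targets: vec![(1, 0), (1, 1), (2, 0)],
    };
    let order: Vec<Target> = plan(&policy).collect();
    assert_eq!(order, vec![(1, 0), (1, 1), (2, 0)]); // no duplicate of the pick
}
```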

### `on_query_success` and `on_query_failure`:

1 change: 1 addition & 0 deletions examples/Cargo.toml
@@ -20,6 +20,7 @@ uuid = "1.0"
tower = "0.4"
stats_alloc = "0.1"
clap = { version = "3.2.4", features = ["derive"] }
+rand = "0.8.5"

[[example]]
name = "auth"
2 changes: 1 addition & 1 deletion examples/compare-tokens.rs
@@ -41,7 +41,7 @@ async fn main() -> Result<()> {
.get_cluster_data()
.get_token_endpoints("examples_ks", Token { value: t })
.iter()
-.map(|n| n.address)
+.map(|(node, _shard)| node.address)
.collect::<Vec<NodeAddr>>()
);

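With this change, `get_token_endpoints` yields `(node, shard)` pairs, and the example above keeps only the address. A variant that also keeps the shard could look like the sketch below (a hypothetical rewrite, assuming `Shard` is imported from `scylla::routing` alongside the example's existing imports):

```rust
// Sketch: keep the shard as well as the address (hypothetical variant of the
// example above; assumes `use scylla::routing::Shard;`).
let owners = session
    .get_cluster_data()
    .get_token_endpoints("examples_ks", Token { value: t })
    .iter()
    .map(|(node, shard)| (node.address, *shard))
    .collect::<Vec<(NodeAddr, Shard)>>();
println!("Owning (address, shard) pairs: {:?}", owners);
```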
20 changes: 17 additions & 3 deletions examples/custom_load_balancing_policy.rs
@@ -1,23 +1,37 @@
use anyhow::Result;
+use rand::thread_rng;
+use rand::Rng;
+use scylla::transport::NodeRef;
use scylla::{
load_balancing::{LoadBalancingPolicy, RoutingInfo},
+routing::Shard,
transport::{ClusterData, ExecutionProfile},
Session, SessionBuilder,
};
use std::{env, sync::Arc};

/// Example load balancing policy that prefers nodes from favorite datacenter
+/// This is, of course, very naive, as it is completely non token-aware.
+/// For a more realistic implementation, see [`DefaultPolicy`](scylla::load_balancing::DefaultPolicy).
#[derive(Debug)]
struct CustomLoadBalancingPolicy {
fav_datacenter_name: String,
}

+fn with_random_shard(node: NodeRef) -> (NodeRef, Shard) {
+let nr_shards = node
+.sharder()
+.map(|sharder| sharder.nr_shards.get())
+.unwrap_or(1);
+(node, thread_rng().gen_range(0..nr_shards) as Shard)
+}

impl LoadBalancingPolicy for CustomLoadBalancingPolicy {
fn pick<'a>(
&'a self,
_info: &'a RoutingInfo,
cluster: &'a ClusterData,
-) -> Option<scylla::transport::NodeRef<'a>> {
+) -> Option<(NodeRef<'a>, Shard)> {
self.fallback(_info, cluster).next()
}

@@ -31,9 +45,9 @@ impl LoadBalancingPolicy for CustomLoadBalancingPolicy {
.unique_nodes_in_datacenter_ring(&self.fav_datacenter_name);

match fav_dc_nodes {
-Some(nodes) => Box::new(nodes.iter()),
+Some(nodes) => Box::new(nodes.iter().map(with_random_shard)),
// If there is no dc with provided name, fallback to other datacenters
-None => Box::new(cluster.get_nodes_info().iter()),
+None => Box::new(cluster.get_nodes_info().iter().map(with_random_shard)),
}
}

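For completeness, wiring such a policy into a session looks roughly like the sketch below, modeled on the driver's execution-profile API as used by the other examples; the datacenter name and node address are placeholders:

```rust
// Sketch: installing the custom policy (placeholder values throughout).
let custom_lb = Arc::new(CustomLoadBalancingPolicy {
    fav_datacenter_name: "eu-central-1".to_string(),
});
let profile = ExecutionProfile::builder()
    .load_balancing_policy(custom_lb)
    .build();
let session: Session = SessionBuilder::new()
    .known_node("127.0.0.1:9042")
    .default_execution_profile_handle(profile.into_handle())
    .build()
    .await?;
```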
11 changes: 6 additions & 5 deletions scylla/src/transport/cluster.rs
@@ -1,7 +1,7 @@
/// Cluster manages up to date information and connections to database nodes
use crate::frame::response::event::{Event, StatusChangeEvent};
use crate::prepared_statement::TokenCalculationError;
-use crate::routing::Token;
+use crate::routing::{Shard, Token};
use crate::transport::host_filter::HostFilter;
use crate::transport::{
connection::{Connection, VerifiedKeyspaceName},
@@ -27,6 +27,7 @@ use tracing::{debug, warn};
use uuid::Uuid;

use super::node::{KnownNode, NodeAddr};
+use super::NodeRef;

use super::locator::ReplicaLocator;
use super::partitioner::calculate_token_for_partition_key;
@@ -408,17 +409,17 @@ impl ClusterData {
}

/// Access to replicas owning a given token
-pub fn get_token_endpoints(&self, keyspace: &str, token: Token) -> Vec<Arc<Node>> {
+pub fn get_token_endpoints(&self, keyspace: &str, token: Token) -> Vec<(Arc<Node>, Shard)> {
self.get_token_endpoints_iter(keyspace, token)
-.cloned()
+.map(|(node, shard)| (node.clone(), shard))
.collect()
}

pub(crate) fn get_token_endpoints_iter(
&self,
keyspace: &str,
token: Token,
-) -> impl Iterator<Item = &Arc<Node>> {
+) -> impl Iterator<Item = (NodeRef<'_>, Shard)> {
let keyspace = self.keyspaces.get(keyspace);
let strategy = keyspace
.map(|k| &k.strategy)
@@ -436,7 +437,7 @@ impl ClusterData {
keyspace: &str,
table: &str,
partition_key: &SerializedValues,
-) -> Result<Vec<Arc<Node>>, BadQuery> {
+) -> Result<Vec<(Arc<Node>, Shard)>, BadQuery> {
Ok(self.get_token_endpoints(
keyspace,
self.compute_token(keyspace, table, partition_key)?,
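A caller consuming the new shape might look like this sketch (assuming `cluster_data` is a `ClusterData` snapshot from a live session; the keyspace name and token are made up):

```rust
// Sketch: iterating shard-annotated replicas ("my_ks" and the token are placeholders).
let owners: Vec<(Arc<Node>, Shard)> =
    cluster_data.get_token_endpoints("my_ks", Token { value: 42 });
for (node, shard) in &owners {
    println!("token owned by {:?} on shard {}", node.address, shard);
}
```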
27 changes: 15 additions & 12 deletions scylla/src/transport/connection_pool.rs
@@ -1,7 +1,7 @@
#[cfg(feature = "cloud")]
use crate::cloud::set_ssl_config_for_scylla_cloud_host;

-use crate::routing::{Shard, ShardCount, Sharder, Token};
+use crate::routing::{Shard, ShardCount, Sharder};
use crate::transport::errors::QueryError;
use crate::transport::{
connection,
@@ -28,7 +28,7 @@ use std::time::Duration;

use tokio::sync::{broadcast, mpsc, Notify};
use tracing::instrument::WithSubscriber;
-use tracing::{debug, trace, warn};
+use tracing::{debug, error, trace, warn};

/// The target size of a per-node connection pool.
#[derive(Debug, Clone, Copy)]
@@ -235,22 +235,25 @@ impl NodeConnectionPool {
.unwrap_or(None)
}

-pub(crate) fn connection_for_token(&self, token: Token) -> Result<Arc<Connection>, QueryError> {
-trace!(token = token.value, "Selecting connection for token");
+pub(crate) fn connection_for_shard(&self, shard: Shard) -> Result<Arc<Connection>, QueryError> {
+trace!(shard = shard, "Selecting connection for shard");
self.with_connections(|pool_conns| match pool_conns {
PoolConnections::NotSharded(conns) => {
Self::choose_random_connection_from_slice(conns).unwrap()
}
PoolConnections::Sharded {
-sharder,
connections,
+sharder
} => {
-let shard: u16 = sharder
-.shard_of(token)
+let shard = shard
.try_into()
-.expect("Shard number doesn't fit in u16");
-trace!(shard = shard, "Selecting connection for token");
-Self::connection_for_shard(shard, sharder.nr_shards, connections.as_slice())
+// It's safer to use 0 rather than panic here, as shards are returned by `LoadBalancingPolicy`
+// now, which can be implemented by a user in an arbitrary way.
+.unwrap_or_else(|_| {
+error!("The provided shard number: {} does not fit u16! Using 0 as the shard number. Check your LoadBalancingPolicy implementation.", shard);
+0
+});
+Self::connection_for_shard_helper(shard, sharder.nr_shards, connections.as_slice())
}
})
}
@@ -266,13 +269,13 @@ impl NodeConnectionPool {
connections,
} => {
let shard: u16 = rand::thread_rng().gen_range(0..sharder.nr_shards.get());
-Self::connection_for_shard(shard, sharder.nr_shards, connections.as_slice())
+Self::connection_for_shard_helper(shard, sharder.nr_shards, connections.as_slice())
}
})
}

// Tries to get a connection to given shard, if it's broken returns any working connection
-fn connection_for_shard(
+fn connection_for_shard_helper(
shard: u16,
nr_shards: ShardCount,
shard_conns: &[Vec<Arc<Connection>>],
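Since shard numbers now come from `LoadBalancingPolicy` implementations, which users can write themselves, the pool clamps an out-of-range shard to 0 and logs an error instead of panicking. The pattern in isolation, as a standalone sketch (assuming the driver's `Shard` alias is `u32`, which the `try_into` to `u16` above suggests):

```rust
// Standalone sketch of the clamping pattern; not the pool's exact code.
fn clamp_shard_to_u16(shard: u32) -> u16 {
    u16::try_from(shard).unwrap_or_else(|_| {
        eprintln!("shard {shard} does not fit in u16; falling back to shard 0");
        0
    })
}

fn main() {
    assert_eq!(clamp_shard_to_u16(3), 3);
    assert_eq!(clamp_shard_to_u16(70_000), 0); // out of range, clamped with a log
}
```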
31 changes: 7 additions & 24 deletions scylla/src/transport/iterator.rs
@@ -35,7 +35,7 @@ use crate::transport::connection::{Connection, NonErrorQueryResponse, QueryRespo
use crate::transport::load_balancing::{self, RoutingInfo};
use crate::transport::metrics::Metrics;
use crate::transport::retry_policy::{QueryInfo, RetryDecision, RetrySession};
-use crate::transport::{Node, NodeRef};
+use crate::transport::NodeRef;
use tracing::{trace, trace_span, warn, Instrument};
use uuid::Uuid;

@@ -160,8 +160,6 @@ impl RowIterator {
let worker_task = async move {
let query_ref = &query;

-let choose_connection = |node: Arc<Node>| async move { node.random_connection().await };

let page_query = |connection: Arc<Connection>,
consistency: Consistency,
paging_state: Option<Bytes>| {
@@ -187,7 +185,6 @@ impl RowIterator {

let worker = RowIteratorWorker {
sender: sender.into(),
-choose_connection,
page_query,
statement_info: routing_info,
query_is_idempotent: query.config.is_idempotent,
@@ -259,13 +256,6 @@ impl RowIterator {
is_confirmed_lwt: config.prepared.is_confirmed_lwt(),
};

-let choose_connection = |node: Arc<Node>| async move {
-match token {
-Some(token) => node.connection_for_token(token).await,
-None => node.random_connection().await,
-}
-};

let page_query = |connection: Arc<Connection>,
consistency: Consistency,
paging_state: Option<Bytes>| async move {
@@ -290,7 +280,7 @@
config
.cluster_data
.get_token_endpoints_iter(keyspace, token)
-.cloned()
+.map(|(node, shard)| (node.clone(), shard))
.collect(),
)
} else {
@@ -311,7 +301,6 @@

let worker = RowIteratorWorker {
sender: sender.into(),
-choose_connection,
page_query,
statement_info,
query_is_idempotent: config.prepared.config.is_idempotent,
@@ -496,13 +485,9 @@ type PageSendAttemptedProof = SendAttemptedProof<Result<ReceivedPage, QueryError

// RowIteratorWorker works in the background to fetch pages
// RowIterator receives them through a channel
-struct RowIteratorWorker<'a, ConnFunc, QueryFunc, SpanCreatorFunc> {
+struct RowIteratorWorker<'a, QueryFunc, SpanCreatorFunc> {
sender: ProvingSender<Result<ReceivedPage, QueryError>>,

-// Closure used to choose a connection from a node
-// AsyncFn(Arc<Node>) -> Result<Arc<Connection>, QueryError>
-choose_connection: ConnFunc,

// Closure used to perform a single page query
// AsyncFn(Arc<Connection>, Option<Bytes>) -> Result<QueryResponse, QueryError>
page_query: QueryFunc,
@@ -524,11 +509,8 @@ struct RowIteratorWorker<'a, ConnFunc, QueryFunc, SpanCreatorFunc> {
span_creator: SpanCreatorFunc,
}

-impl<ConnFunc, ConnFut, QueryFunc, QueryFut, SpanCreator>
-RowIteratorWorker<'_, ConnFunc, QueryFunc, SpanCreator>
+impl<QueryFunc, QueryFut, SpanCreator> RowIteratorWorker<'_, QueryFunc, SpanCreator>
where
-ConnFunc: Fn(Arc<Node>) -> ConnFut,
-ConnFut: Future<Output = Result<Arc<Connection>, QueryError>>,
QueryFunc: Fn(Arc<Connection>, Consistency, Option<Bytes>) -> QueryFut,
QueryFut: Future<Output = Result<QueryResponse, QueryError>>,
SpanCreator: Fn() -> RequestSpan,
@@ -546,12 +528,13 @@

self.log_query_start();

-'nodes_in_plan: for node in query_plan {
+'nodes_in_plan: for (node, shard) in query_plan {
let span =
trace_span!(parent: &self.parent_span, "Executing query", node = %node.address);
// For each node in the plan choose a connection to use
// This connection will be reused for same node retries to preserve paging cache on the shard
-let connection: Arc<Connection> = match (self.choose_connection)(node.clone())
+let connection: Arc<Connection> = match node
+.connection_for_shard(shard)
.instrument(span.clone())
.await
{
