Skip to content

Commit

Permalink
Merge pull request scylladb#969 from Lorak-mmk/fix_shard_load_balancing
Browse files Browse the repository at this point in the history
LBP: Return Option<Shard> instead of Shard
  • Loading branch information
wprzytula authored Mar 28, 2024
2 parents 8155fd4 + e90f102 commit 1fcadc9
Show file tree
Hide file tree
Showing 7 changed files with 269 additions and 263 deletions.
6 changes: 3 additions & 3 deletions examples/custom_load_balancing_policy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,20 @@ struct CustomLoadBalancingPolicy {
fav_datacenter_name: String,
}

fn with_random_shard(node: NodeRef) -> (NodeRef, Shard) {
fn with_random_shard(node: NodeRef) -> (NodeRef, Option<Shard>) {
let nr_shards = node
.sharder()
.map(|sharder| sharder.nr_shards.get())
.unwrap_or(1);
(node, thread_rng().gen_range(0..nr_shards) as Shard)
(node, Some(thread_rng().gen_range(0..nr_shards) as Shard))
}

impl LoadBalancingPolicy for CustomLoadBalancingPolicy {
fn pick<'a>(
&'a self,
_info: &'a RoutingInfo,
cluster: &'a ClusterData,
) -> Option<(NodeRef<'a>, Shard)> {
) -> Option<(NodeRef<'a>, Option<Shard>)> {
self.fallback(_info, cluster).next()
}

Expand Down
358 changes: 181 additions & 177 deletions scylla/src/transport/load_balancing/default.rs

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions scylla/src/transport/load_balancing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ pub struct RoutingInfo<'a> {
///
/// It is computed on-demand, only if querying the most preferred node fails
/// (or when speculative execution is triggered).
pub type FallbackPlan<'a> = Box<dyn Iterator<Item = (NodeRef<'a>, Shard)> + Send + Sync + 'a>;
pub type FallbackPlan<'a> =
Box<dyn Iterator<Item = (NodeRef<'a>, Option<Shard>)> + Send + Sync + 'a>;

/// Policy that decides which nodes and shards to contact for each query.
///
Expand Down Expand Up @@ -67,7 +68,7 @@ pub trait LoadBalancingPolicy: Send + Sync + std::fmt::Debug {
&'a self,
query: &'a RoutingInfo,
cluster: &'a ClusterData,
) -> Option<(NodeRef<'a>, Shard)>;
) -> Option<(NodeRef<'a>, Option<Shard>)>;

/// Returns all contact-appropriate nodes for a given query.
fn fallback<'a>(&'a self, query: &'a RoutingInfo, cluster: &'a ClusterData)
Expand Down
89 changes: 75 additions & 14 deletions scylla/src/transport/load_balancing/plan.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use rand::{thread_rng, Rng};
use tracing::error;

use super::{FallbackPlan, LoadBalancingPolicy, NodeRef, RoutingInfo};
Expand All @@ -6,20 +7,65 @@ use crate::{routing::Shard, transport::ClusterData};
enum PlanState<'a> {
Created,
PickedNone, // This always means an abnormal situation: it means that no nodes satisfied locality/node filter requirements.
Picked((NodeRef<'a>, Shard)),
Picked((NodeRef<'a>, Option<Shard>)),
Fallback {
iter: FallbackPlan<'a>,
node_to_filter_out: (NodeRef<'a>, Shard),
target_to_filter_out: (NodeRef<'a>, Option<Shard>),
},
}

/// The list of nodes constituting the query plan.
/// The list of targets constituting the query plan. Target here is a pair `(NodeRef<'a>, Shard)`.
///
/// The plan is partly lazily computed, with the first node computed
/// eagerly in the first place and the remaining nodes computed on-demand
/// The plan is partly lazily computed, with the first target computed
/// eagerly in the first place and the remaining targets computed on-demand
/// (all at once).
/// This significantly reduces the allocation overhead on "the happy path"
/// (when the first node successfully handles the request),
/// (when the first target successfully handles the request).
///
/// `Plan` implements `Iterator<Item=(NodeRef<'a>, Shard)>` but LoadBalancingPolicy
/// returns `Option<Shard>` instead of `Shard` both in `pick` and in `fallback`.
/// `Plan` handles the `None` case by using random shard for a given node.
/// There is currently no way to configure RNG used by `Plan`.
/// If you don't want `Plan` to do randomize shards or you want to control the RNG,
/// use custom LBP that will always return non-`None` shards.
/// Example of LBP that always uses shard 0, preventing `Plan` from using random numbers:
///
/// ```
/// # use std::sync::Arc;
/// # use scylla::load_balancing::LoadBalancingPolicy;
/// # use scylla::load_balancing::RoutingInfo;
/// # use scylla::transport::ClusterData;
/// # use scylla::transport::NodeRef;
/// # use scylla::routing::Shard;
/// # use scylla::load_balancing::FallbackPlan;
///
/// #[derive(Debug)]
/// struct NonRandomLBP {
/// inner: Arc<dyn LoadBalancingPolicy>,
/// }
/// impl LoadBalancingPolicy for NonRandomLBP {
/// fn pick<'a>(
/// &'a self,
/// info: &'a RoutingInfo,
/// cluster: &'a ClusterData,
/// ) -> Option<(NodeRef<'a>, Option<Shard>)> {
/// self.inner
/// .pick(info, cluster)
/// .map(|(node, shard)| (node, shard.or(Some(0))))
/// }
///
/// fn fallback<'a>(&'a self, info: &'a RoutingInfo, cluster: &'a ClusterData) -> FallbackPlan<'a> {
/// Box::new(self.inner
/// .fallback(info, cluster)
/// .map(|(node, shard)| (node, shard.or(Some(0)))))
/// }
///
/// fn name(&self) -> String {
/// "NonRandomLBP".to_string()
/// }
/// }
/// ```
pub struct Plan<'a> {
policy: &'a dyn LoadBalancingPolicy,
routing_info: &'a RoutingInfo<'a>,
Expand All @@ -41,6 +87,21 @@ impl<'a> Plan<'a> {
state: PlanState::Created,
}
}

fn with_random_shard_if_unknown(
(node, shard): (NodeRef<'_>, Option<Shard>),
) -> (NodeRef<'_>, Shard) {
(
node,
shard.unwrap_or_else(|| {
let nr_shards = node
.sharder()
.map(|sharder| sharder.nr_shards.get())
.unwrap_or(1);
thread_rng().gen_range(0..nr_shards).into()
}),
)
}
}

impl<'a> Iterator for Plan<'a> {
Expand All @@ -52,7 +113,7 @@ impl<'a> Iterator for Plan<'a> {
let picked = self.policy.pick(self.routing_info, self.cluster);
if let Some(picked) = picked {
self.state = PlanState::Picked(picked);
Some(picked)
Some(Self::with_random_shard_if_unknown(picked))
} else {
// `pick()` returned None, which semantically means that a first node cannot be computed _cheaply_.
// This, however, does not imply that fallback would return an empty plan, too.
Expand All @@ -64,9 +125,9 @@ impl<'a> Iterator for Plan<'a> {
if let Some(node) = first_fallback_node {
self.state = PlanState::Fallback {
iter,
node_to_filter_out: node,
target_to_filter_out: node,
};
Some(node)
Some(Self::with_random_shard_if_unknown(node))
} else {
error!("Load balancing policy returned an empty plan! The query cannot be executed. Routing info: {:?}", self.routing_info);
self.state = PlanState::PickedNone;
Expand All @@ -77,20 +138,20 @@ impl<'a> Iterator for Plan<'a> {
PlanState::Picked(node) => {
self.state = PlanState::Fallback {
iter: self.policy.fallback(self.routing_info, self.cluster),
node_to_filter_out: *node,
target_to_filter_out: *node,
};

self.next()
}
PlanState::Fallback {
iter,
node_to_filter_out,
target_to_filter_out: node_to_filter_out,
} => {
for node in iter {
if node == *node_to_filter_out {
continue;
} else {
return Some(node);
return Some(Self::with_random_shard_if_unknown(node));
}
}

Expand Down Expand Up @@ -135,7 +196,7 @@ mod tests {
&'a self,
_query: &'a RoutingInfo,
_cluster: &'a ClusterData,
) -> Option<(NodeRef<'a>, Shard)> {
) -> Option<(NodeRef<'a>, Option<Shard>)> {
None
}

Expand All @@ -147,7 +208,7 @@ mod tests {
Box::new(
self.expected_nodes
.iter()
.map(|(node_ref, shard)| (node_ref, *shard)),
.map(|(node_ref, shard)| (node_ref, Some(*shard))),
)
}

Expand Down
2 changes: 1 addition & 1 deletion scylla/tests/integration/consistency.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ impl LoadBalancingPolicy for RoutingInfoReportingWrapper {
&'a self,
query: &'a RoutingInfo,
cluster: &'a scylla::transport::ClusterData,
) -> Option<(NodeRef<'a>, Shard)> {
) -> Option<(NodeRef<'a>, Option<Shard>)> {
self.routing_info_tx
.send(OwnedRoutingInfo::from(query.clone()))
.unwrap();
Expand Down
8 changes: 6 additions & 2 deletions scylla/tests/integration/execution_profiles.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,13 @@ impl<const NODE: u8> LoadBalancingPolicy for BoundToPredefinedNodePolicy<NODE> {
&'a self,
_info: &'a RoutingInfo,
cluster: &'a ClusterData,
) -> Option<(NodeRef<'a>, Shard)> {
) -> Option<(NodeRef<'a>, Option<Shard>)> {
self.report_node(Report::LoadBalancing);
cluster.get_nodes_info().iter().next().map(|node| (node, 0))
cluster
.get_nodes_info()
.iter()
.next()
.map(|node| (node, None))
}

fn fallback<'a>(
Expand Down
64 changes: 0 additions & 64 deletions scylla/tests/integration/utils.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
use futures::Future;
use itertools::Itertools;
use scylla::load_balancing::LoadBalancingPolicy;
use scylla::routing::Shard;
use scylla::transport::NodeRef;
use std::collections::HashMap;
use std::env;
use std::net::SocketAddr;
Expand All @@ -19,66 +15,6 @@ pub(crate) fn setup_tracing() {
.try_init();
}

fn with_pseudorandom_shard(node: NodeRef) -> (NodeRef, Shard) {
let nr_shards = node
.sharder()
.map(|sharder| sharder.nr_shards.get())
.unwrap_or(1);
(node, ((nr_shards - 1) % 42) as Shard)
}

#[derive(Debug)]
pub(crate) struct FixedOrderLoadBalancer;
impl LoadBalancingPolicy for FixedOrderLoadBalancer {
fn pick<'a>(
&'a self,
_info: &'a scylla::load_balancing::RoutingInfo,
cluster: &'a scylla::transport::ClusterData,
) -> Option<(NodeRef<'a>, Shard)> {
cluster
.get_nodes_info()
.iter()
.sorted_by(|node1, node2| Ord::cmp(&node1.address, &node2.address))
.next()
.map(with_pseudorandom_shard)
}

fn fallback<'a>(
&'a self,
_info: &'a scylla::load_balancing::RoutingInfo,
cluster: &'a scylla::transport::ClusterData,
) -> scylla::load_balancing::FallbackPlan<'a> {
Box::new(
cluster
.get_nodes_info()
.iter()
.sorted_by(|node1, node2| Ord::cmp(&node1.address, &node2.address))
.map(with_pseudorandom_shard),
)
}

fn on_query_success(
&self,
_: &scylla::load_balancing::RoutingInfo,
_: std::time::Duration,
_: NodeRef<'_>,
) {
}

fn on_query_failure(
&self,
_: &scylla::load_balancing::RoutingInfo,
_: std::time::Duration,
_: NodeRef<'_>,
_: &scylla_cql::errors::QueryError,
) {
}

fn name(&self) -> String {
"FixedOrderLoadBalancer".to_string()
}
}

pub(crate) async fn test_with_3_node_cluster<F, Fut>(
shard_awareness: ShardAwareness,
test: F,
Expand Down

0 comments on commit 1fcadc9

Please sign in to comment.