add enforce_target_node on batch

Ten0 committed Jul 2, 2023
1 parent 20c32b8 commit 7e47f67
Showing 5 changed files with 132 additions and 5 deletions.
79 changes: 78 additions & 1 deletion scylla/src/statement/batch.rs
@@ -3,7 +3,8 @@ use std::sync::Arc;
 use crate::history::HistoryListener;
 use crate::retry_policy::RetryPolicy;
 use crate::statement::{prepared_statement::PreparedStatement, query::Query};
-use crate::transport::execution_profile::ExecutionProfileHandle;
+use crate::transport::{execution_profile::ExecutionProfileHandle, Node};
+use crate::Session;
 
 use super::StatementConfig;
 pub use super::{Consistency, SerialConsistency};
@@ -144,6 +145,82 @@ impl Batch {
     pub fn get_execution_profile_handle(&self) -> Option<&ExecutionProfileHandle> {
         self.config.execution_profile_handle.as_ref()
     }

+    /// Associates the batch with a new execution profile whose load balancing policy
+    /// enforces the use of the provided [`Node`] to the extent possible.
+    ///
+    /// This should typically be used in conjunction with [`Session::shard_for_statement`]:
+    /// you would build a batch by grouping together all the statements that would be
+    /// executed on the same shard.
+    ///
+    /// Since subsequent calls to the load balancer are not guaranteed to assign a
+    /// statement to the same node, use this method to enforce for the batch the use
+    /// of the node originally envisioned by `shard_for_statement`:
+    ///
+    /// ```rust
+    /// # use scylla::Session;
+    /// # use std::error::Error;
+    /// # async fn check_only_compiles(session: &Session) -> Result<(), Box<dyn Error>> {
+    /// use scylla::{
+    ///     batch::Batch,
+    ///     frame::value::{SerializedValues, ValueList},
+    /// };
+    ///
+    /// let prepared_statement = session
+    ///     .prepare("INSERT INTO ks.tab(a, b) VALUES(?, ?)")
+    ///     .await?;
+    ///
+    /// let serialized_values: SerializedValues = (1, 2).serialized()?.into_owned();
+    /// let shard = session.shard_for_statement(&prepared_statement, &serialized_values)?;
+    ///
+    /// // Send that to a task that will handle statements targeted to the same shard
+    ///
+    /// // On that task:
+    /// // Build a batch with all the statements that would be executed on the same shard
+    ///
+    /// let mut batch: Batch = Default::default();
+    /// if let Some((node, _shard_idx)) = shard {
+    ///     batch.enforce_target_node(&node, &session);
+    /// }
+    /// let mut batch_values = Vec::new();
+    ///
+    /// // As the task handling statements targeted to this shard receives them,
+    /// // it appends them to the batch
+    /// batch.append_statement(prepared_statement);
+    /// batch_values.push(serialized_values);
+    ///
+    /// // Run the batch
+    /// session.batch(&batch, batch_values).await?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    ///
+    ///
+    /// If the target node is no longer available when the statement is executed, it
+    /// will fall back to the original load balancing policy:
+    /// - Either the one currently set on the [`Batch`], if any
+    /// - Or that of the [`Session`] if the `Batch` has none
+    pub fn enforce_target_node(
+        &mut self,
+        node: &Arc<Node>,
+        base_execution_profile_from_session: &Session,
+    ) {
+        let execution_profile_handle = self.get_execution_profile_handle().unwrap_or_else(|| {
+            base_execution_profile_from_session.get_default_execution_profile_handle()
+        });
+        self.set_execution_profile_handle(Some(
+            execution_profile_handle
+                .pointee_to_builder()
+                .load_balancing_policy(Arc::new(
+                    crate::load_balancing::EnforceTargetNodePolicy::new(
+                        node,
+                        execution_profile_handle.load_balancing_policy(),
+                    ),
+                ))
+                .build()
+                .into_handle(),
+        ))
+    }
 }
 
 impl Default for Batch {
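The doc example above handles a single statement. As a fuller illustration of the workflow this commit enables, here is a minimal sketch (not part of the commit) of a caller that groups rows into one pinned batch per target, assuming a `ks.tab(a, b)` table as in the example and that the shard index is a hashable integer; `run_shard_aware_batches` is a hypothetical helper:

```rust
use scylla::frame::value::{SerializedValues, ValueList};
use scylla::{batch::Batch, Session};
use std::collections::HashMap;

// Hypothetical: group statements by (node, shard) and pin each batch
// to the node that `shard_for_statement` originally picked for it.
async fn run_shard_aware_batches(
    session: &Session,
    rows: Vec<(i32, i32)>,
) -> Result<(), Box<dyn std::error::Error>> {
    let prepared = session
        .prepare("INSERT INTO ks.tab(a, b) VALUES(?, ?)")
        .await?;

    // One (Batch, values) pair per (host_id, shard) target.
    let mut batches = HashMap::new();

    for row in rows {
        let values: SerializedValues = row.serialized()?.into_owned();
        if let Some((node, shard)) = session.shard_for_statement(&prepared, &values)? {
            let (batch, batch_values) = batches
                .entry((node.host_id, shard))
                .or_insert_with(|| {
                    let mut batch = Batch::default();
                    // Pin the batch to the node envisioned for this shard.
                    batch.enforce_target_node(&node, session);
                    (batch, Vec::new())
                });
            batch.append_statement(prepared.clone());
            batch_values.push(values);
        }
    }

    for (batch, batch_values) in batches.into_values() {
        session.batch(&batch, batch_values).await?;
    }
    Ok(())
}
```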
4 changes: 4 additions & 0 deletions scylla/src/transport/execution_profile.rs
@@ -485,4 +485,8 @@ impl ExecutionProfileHandle {
     pub fn map_to_another_profile(&mut self, profile: ExecutionProfile) {
         self.0 .0.store(profile.0)
     }
+
+    pub fn load_balancing_policy(&self) -> Arc<dyn LoadBalancingPolicy> {
+        self.0 .0.load().load_balancing_policy.clone()
+    }
 }
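The new accessor exposes the load balancing policy stored behind the handle. A small usage sketch (assumed usage, not from this commit; `log_current_policy` is a hypothetical helper):

```rust
use scylla::Session;

// Hypothetical: print the name of the policy the session's default
// execution profile currently points to.
fn log_current_policy(session: &Session) {
    let handle = session.get_default_execution_profile_handle();
    let policy = handle.load_balancing_policy();
    println!("current load balancing policy: {}", policy.name());
}
```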
42 changes: 42 additions & 0 deletions scylla/src/transport/load_balancing/enforce_node.rs
@@ -0,0 +1,42 @@
+use super::{FallbackPlan, LoadBalancingPolicy, NodeRef, RoutingInfo};
+use crate::transport::{cluster::ClusterData, Node};
+use std::sync::Arc;
+
+#[derive(Debug)]
+pub struct EnforceTargetNodePolicy {
+    target_node: uuid::Uuid,
+    fallback: Arc<dyn LoadBalancingPolicy>,
+}
+
+impl EnforceTargetNodePolicy {
+    pub fn new(target_node: &Arc<Node>, fallback: Arc<dyn LoadBalancingPolicy>) -> Self {
+        Self {
+            target_node: target_node.host_id,
+            fallback,
+        }
+    }
+}
+impl LoadBalancingPolicy for EnforceTargetNodePolicy {
+    fn pick<'a>(&'a self, query: &'a RoutingInfo, cluster: &'a ClusterData) -> Option<NodeRef<'a>> {
+        cluster
+            .known_peers
+            .get(&self.target_node)
+            .or_else(|| self.fallback.pick(query, cluster))
+    }
+
+    fn fallback<'a>(
+        &'a self,
+        query: &'a RoutingInfo,
+        cluster: &'a ClusterData,
+    ) -> FallbackPlan<'a> {
+        self.fallback.fallback(query, cluster)
+    }
+
+    fn name(&self) -> String {
+        format!(
+            "Enforce target node load balancing policy - Node: {} - fallback: {}",
+            self.target_node,
+            self.fallback.name()
+        )
+    }
+}
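`Batch::enforce_target_node` builds this policy for you, but it can also be wired into an execution profile by hand. A sketch under the assumption that `ExecutionProfile::builder()` is available from the driver's execution profile API (`pinned_profile` is a hypothetical helper):

```rust
use scylla::load_balancing::{EnforceTargetNodePolicy, LoadBalancingPolicy};
use scylla::transport::execution_profile::ExecutionProfile;
use scylla::transport::Node;
use std::sync::Arc;

// Hypothetical: a profile that pins statements to `node`, delegating to
// `fallback` whenever that node is unavailable.
fn pinned_profile(node: &Arc<Node>, fallback: Arc<dyn LoadBalancingPolicy>) -> ExecutionProfile {
    ExecutionProfile::builder()
        .load_balancing_policy(Arc::new(EnforceTargetNodePolicy::new(node, fallback)))
        .build()
}
```

The resulting profile can then be converted with `into_handle()` and set on a statement or batch, which is exactly what `enforce_target_node` automates above.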
6 changes: 5 additions & 1 deletion scylla/src/transport/load_balancing/mod.rs
@@ -9,9 +9,13 @@ use scylla_cql::{errors::QueryError, frame::types};
 use std::time::Duration;
 
 mod default;
+mod enforce_node;
 mod plan;
-pub use default::{DefaultPolicy, DefaultPolicyBuilder, LatencyAwarenessBuilder};
 pub use plan::Plan;
+pub use {
+    default::{DefaultPolicy, DefaultPolicyBuilder, LatencyAwarenessBuilder},
+    enforce_node::EnforceTargetNodePolicy,
+};

 /// Represents info about statement that can be used by load balancing policies.
 #[derive(Default, Clone, Debug)]
6 changes: 3 additions & 3 deletions scylla/src/transport/session.rs
@@ -1812,11 +1812,11 @@ impl Session {
             .map(|partition_key| prepared.get_partitioner_name().hash(&partition_key)))
     }
 
-    /// Get the first node/shard that the load balancer would target if running this query
+    /// Get a node/shard that the load balancer would potentially target if running this query
     ///
-    /// This may help constitute shard-aware batches
+    /// This may help constitute shard-aware batches (see [`Batch::enforce_target_node`])
     #[allow(clippy::type_complexity)]
-    pub fn first_shard_for_statement(
+    pub fn shard_for_statement(
         &self,
         prepared: &PreparedStatement,
         serialized_values: &SerializedValues,