diff --git a/scylla/src/statement/batch.rs b/scylla/src/statement/batch.rs
index ab120aea62..20b4bbe90e 100644
--- a/scylla/src/statement/batch.rs
+++ b/scylla/src/statement/batch.rs
@@ -3,7 +3,8 @@ use std::sync::Arc;
 use crate::history::HistoryListener;
 use crate::retry_policy::RetryPolicy;
 use crate::statement::{prepared_statement::PreparedStatement, query::Query};
-use crate::transport::execution_profile::ExecutionProfileHandle;
+use crate::transport::{execution_profile::ExecutionProfileHandle, Node};
+use crate::Session;
 
 use super::StatementConfig;
 pub use super::{Consistency, SerialConsistency};
@@ -144,6 +145,81 @@ impl Batch {
     pub fn get_execution_profile_handle(&self) -> Option<&ExecutionProfileHandle> {
         self.config.execution_profile_handle.as_ref()
     }
+
+    /// Associates the batch with a new execution profile whose load balancing policy
+    /// enforces the use of the provided [`Node`] to the extent possible.
+    ///
+    /// This should typically be used in conjunction with [`Session::shard_for_statement`]:
+    /// you would build each batch by grouping together all the statements that would be
+    /// executed on the same shard.
+    ///
+    /// Since subsequent calls to the load balancer are not guaranteed to assign a statement
+    /// to the same node, use this method to pin the batch to the node originally chosen by
+    /// `shard_for_statement`:
+    ///
+    /// ```rust
+    /// # use scylla::Session;
+    /// # use std::error::Error;
+    /// # async fn check_only_compiles(session: &Session) -> Result<(), Box<dyn Error>> {
+    /// use scylla::{
+    ///     batch::Batch,
+    ///     frame::value::{SerializedValues, ValueList},
+    /// };
+    ///
+    /// let prepared_statement = session
+    ///     .prepare("INSERT INTO ks.tab(a, b) VALUES(?, ?)")
+    ///     .await?;
+    ///
+    /// let serialized_values: SerializedValues = (1, 2).serialized()?.into_owned();
+    /// let shard = session.shard_for_statement(&prepared_statement, &serialized_values)?;
+    ///
+    /// // Send that to a task that will handle statements targeted to the same shard
+    ///
+    /// // On that task:
+    /// // Constitute a batch with all the statements that would be executed on the same shard
+    ///
+    /// let mut batch: Batch = Default::default();
+    /// if let Some((node, _shard_idx)) = shard {
+    ///     batch.enforce_target_node(&node, &session);
+    /// }
+    /// let mut batch_values = Vec::new();
+    ///
+    /// // As the task handling statements targeted to this shard receives them,
+    /// // it appends them to the batch
+    /// batch.append_statement(prepared_statement);
+    /// batch_values.push(serialized_values);
+    ///
+    /// // Run the batch
+    /// session.batch(&batch, batch_values).await?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    ///
+    /// If the target node is no longer available when the batch is executed, the query
+    /// will fall back to the original load balancing policy:
+    /// - either the one currently set on the [`Batch`], if any,
+    /// - or the [`Session`]'s default policy if the `Batch` has none.
+    pub fn enforce_target_node(
+        &mut self,
+        node: &Arc<Node>,
+        base_execution_profile_from_session: &Session,
+    ) {
+        let execution_profile_handle = self.get_execution_profile_handle().unwrap_or_else(|| {
+            base_execution_profile_from_session.get_default_execution_profile_handle()
+        });
+        self.set_execution_profile_handle(Some(
+            execution_profile_handle
+                .pointee_to_builder()
+                .load_balancing_policy(Arc::new(
+                    crate::load_balancing::EnforceTargetNodePolicy::new(
+                        node,
+                        execution_profile_handle.load_balancing_policy(),
+                    ),
+                ))
+                .build()
+                .into_handle(),
+        ))
+    }
 }
 
 impl Default for Batch {
diff --git a/scylla/src/transport/execution_profile.rs b/scylla/src/transport/execution_profile.rs
index 245beffab9..7f92fb3b18 100644
--- a/scylla/src/transport/execution_profile.rs
+++ b/scylla/src/transport/execution_profile.rs
@@ -485,4 +485,9 @@ impl ExecutionProfileHandle {
     pub fn map_to_another_profile(&mut self, profile: ExecutionProfile) {
         self.0 .0.store(profile.0)
     }
+
+    /// Returns the load balancing policy of the profile this handle currently points to.
+    pub fn load_balancing_policy(&self) -> Arc<dyn LoadBalancingPolicy> {
+        self.0 .0.load().load_balancing_policy.clone()
+    }
 }
diff --git a/scylla/src/transport/load_balancing/enforce_node.rs b/scylla/src/transport/load_balancing/enforce_node.rs
new file mode 100644
index 0000000000..12c17bbcbf
--- /dev/null
+++ b/scylla/src/transport/load_balancing/enforce_node.rs
@@ -0,0 +1,47 @@
+use super::{FallbackPlan, LoadBalancingPolicy, NodeRef, RoutingInfo};
+use crate::transport::{cluster::ClusterData, Node};
+use std::sync::Arc;
+
+/// Load balancing policy that enforces a single target [`Node`] whenever it is
+/// still known to the cluster, delegating to a fallback policy otherwise.
+#[derive(Debug)]
+pub struct EnforceTargetNodePolicy {
+    target_node: uuid::Uuid,
+    fallback: Arc<dyn LoadBalancingPolicy>,
+}
+
+impl EnforceTargetNodePolicy {
+    pub fn new(target_node: &Arc<Node>, fallback: Arc<dyn LoadBalancingPolicy>) -> Self {
+        Self {
+            target_node: target_node.host_id,
+            fallback,
+        }
+    }
+}
+
+impl LoadBalancingPolicy for EnforceTargetNodePolicy {
+    fn pick<'a>(&'a self, query: &'a RoutingInfo, cluster: &'a ClusterData) -> Option<NodeRef<'a>> {
+        // Pick the target node if it is still part of the cluster,
+        // otherwise defer to the fallback policy.
+        cluster
+            .known_peers
+            .get(&self.target_node)
+            .or_else(|| self.fallback.pick(query, cluster))
+    }
+
+    fn fallback<'a>(
+        &'a self,
+        query: &'a RoutingInfo,
+        cluster: &'a ClusterData,
+    ) -> FallbackPlan<'a> {
+        self.fallback.fallback(query, cluster)
+    }
+
+    fn name(&self) -> String {
+        format!(
+            "Enforce target node load balancing policy - Node: {} - fallback: {}",
+            self.target_node,
+            self.fallback.name()
+        )
+    }
+}
diff --git a/scylla/src/transport/load_balancing/mod.rs b/scylla/src/transport/load_balancing/mod.rs
index d4095743c3..64760685e7 100644
--- a/scylla/src/transport/load_balancing/mod.rs
+++ b/scylla/src/transport/load_balancing/mod.rs
@@ -9,9 +9,13 @@ use scylla_cql::{errors::QueryError, frame::types};
 use std::time::Duration;
 
 mod default;
+mod enforce_node;
 mod plan;
-pub use default::{DefaultPolicy, DefaultPolicyBuilder, LatencyAwarenessBuilder};
 pub use plan::Plan;
+pub use {
+    default::{DefaultPolicy, DefaultPolicyBuilder, LatencyAwarenessBuilder},
+    enforce_node::EnforceTargetNodePolicy,
+};
 
 /// Represents info about statement that can be used by load balancing policies.
 #[derive(Default, Clone, Debug)]
diff --git a/scylla/src/transport/session.rs b/scylla/src/transport/session.rs
index 2efd26f4aa..53ae000ed3 100644
--- a/scylla/src/transport/session.rs
+++ b/scylla/src/transport/session.rs
@@ -1812,11 +1812,11 @@ impl Session {
             .map(|partition_key| prepared.get_partitioner_name().hash(&partition_key)))
     }
 
-    /// Get the first node/shard that the load balancer would target if running this query
+    /// Get a node/shard that the load balancer would potentially target if running this query
     ///
-    /// This may help constituting shard-aware batches
+    /// This may help when constituting shard-aware batches (see [`Batch::enforce_target_node`])
     #[allow(clippy::type_complexity)]
-    pub fn first_shard_for_statement(
+    pub fn shard_for_statement(
         &self,
         prepared: &PreparedStatement,
         serialized_values: &SerializedValues,
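
Beyond the doc-test above, it may help reviewers to see the intended end-to-end flow. The sketch below is illustrative only and not part of the patch: the `run_shard_aware_batches` helper, the `(i32, i32)` row type, and the `ks.tab` schema are assumptions, and it presumes `shard_for_statement` returns `Option<(Arc<Node>, shard_index)>` as in the doc example.

```rust
use std::collections::HashMap;

use scylla::{
    batch::Batch,
    frame::value::{SerializedValues, ValueList},
    prepared_statement::PreparedStatement,
    Session,
};

// Hypothetical helper (not part of this patch): group rows by the (host, shard)
// that the load balancer would currently pick for them, then run one pinned
// batch per group.
async fn run_shard_aware_batches(
    session: &Session,
    prepared: &PreparedStatement,
    rows: Vec<(i32, i32)>,
) -> Result<(), Box<dyn std::error::Error>> {
    // Group serialized values by (host id, shard index) so that statements the
    // load balancer would send to the same shard end up in the same batch.
    let mut groups = HashMap::new();

    for row in rows {
        let values: SerializedValues = row.serialized()?.into_owned();
        if let Some((node, shard_idx)) = session.shard_for_statement(prepared, &values)? {
            groups
                .entry((node.host_id, shard_idx))
                .or_insert_with(|| (node.clone(), Vec::new()))
                .1
                .push(values);
        }
    }

    for (_key, (node, batch_values)) in groups {
        // One statement per row in this group; all rows share the same prepared statement.
        let mut batch: Batch = Default::default();
        for _ in &batch_values {
            batch.append_statement(prepared.clone());
        }
        // Pin the batch to the node picked above so the load balancer
        // cannot re-route it at execution time.
        batch.enforce_target_node(&node, session);
        session.batch(&batch, batch_values).await?;
    }
    Ok(())
}
```

Keying the groups on `host_id` plus the shard index yields one batch per shard rather than per node, which matches the shard-aware intent of `shard_for_statement`.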