diff --git a/scylla-cql/benches/benchmark.rs b/scylla-cql/benches/benchmark.rs
index 0aa6c89102..ea334d41b5 100644
--- a/scylla-cql/benches/benchmark.rs
+++ b/scylla-cql/benches/benchmark.rs
@@ -3,11 +3,11 @@ use std::borrow::Cow;
 
 use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
 use scylla_cql::frame::request::SerializableRequest;
-use scylla_cql::frame::value::SerializedValues;
 use scylla_cql::frame::value::ValueList;
 use scylla_cql::frame::{request::query, Compression, SerializedRequest};
+use scylla_cql::types::serialize::row::NewSerializedValues;
 
-fn make_query<'a>(contents: &'a str, values: &'a SerializedValues) -> query::Query<'a> {
+fn make_query<'a>(contents: &'a str, values: &'a NewSerializedValues) -> query::Query<'a> {
     query::Query {
         contents: Cow::Borrowed(contents),
         parameters: query::QueryParameters {
@@ -31,7 +31,7 @@ fn serialized_request_make_bench(c: &mut Criterion) {
            &(1234, "a value", "i am storing a string", "dc0c8cd7-d954-47c1-8722-a857941c43fb").serialized().unwrap()
        ),
    ];
-    let queries = query_args.map(|(q, v)| make_query(q, v));
+    let queries = query_args.map(|(q, v)| make_query(q, todo!()));
 
     for query in queries {
         let query_size = query.to_bytes().unwrap().len();
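Note that the `todo!()` above leaves this benchmark panicking at runtime. A minimal sketch of how the stub could eventually be filled with the new API; the helper name and the `Int`/`Text` column types are illustrative assumptions and would have to match the statements actually being benchmarked:

```rust
use scylla_cql::frame::response::result::ColumnType;
use scylla_cql::types::serialize::row::NewSerializedValues;

// Hypothetical helper, not part of this diff.
fn make_values() -> NewSerializedValues {
    let mut vals = NewSerializedValues::new();
    // Column types must now be supplied explicitly, mirroring what
    // prepared-statement metadata would normally provide.
    vals.add_value(&1234_i32, &ColumnType::Int).unwrap();
    vals.add_value(&"a value", &ColumnType::Text).unwrap();
    vals
}
```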
diff --git a/scylla-cql/src/errors.rs b/scylla-cql/src/errors.rs
index 9e80247e20..e884e37ad5 100644
--- a/scylla-cql/src/errors.rs
+++ b/scylla-cql/src/errors.rs
@@ -3,6 +3,7 @@
 use crate::frame::frame_errors::{FrameError, ParseError};
 use crate::frame::protocol_features::ProtocolFeatures;
 use crate::frame::value::SerializeValuesError;
+use crate::types::serialize::SerializationError;
 use crate::Consistency;
 use bytes::Bytes;
 use std::io::ErrorKind;
@@ -340,6 +341,9 @@ pub enum BadQuery {
     #[error("Serializing values failed: {0} ")]
     SerializeValuesError(#[from] SerializeValuesError),
 
+    #[error("Serializing values failed: {0}")]
+    SerializationError(#[from] SerializationError),
+
     /// Serialized values are too long to compute partition key
     #[error("Serialized values are too long to compute partition key! Length: {0}, Max allowed length: {1}")]
     ValuesTooLongForKey(usize, usize),
@@ -443,6 +447,12 @@ impl From<SerializeValuesError> for QueryError {
     }
 }
 
+impl From<SerializationError> for QueryError {
+    fn from(serialized_err: SerializationError) -> QueryError {
+        QueryError::BadQuery(BadQuery::SerializationError(serialized_err))
+    }
+}
+
 impl From<ParseError> for QueryError {
     fn from(parse_error: ParseError) -> QueryError {
         QueryError::InvalidMessage(format!("Error parsing message: {}", parse_error))
diff --git a/scylla-cql/src/frame/frame_errors.rs b/scylla-cql/src/frame/frame_errors.rs
index 3da4e26d01..9a3b228505 100644
--- a/scylla-cql/src/frame/frame_errors.rs
+++ b/scylla-cql/src/frame/frame_errors.rs
@@ -1,6 +1,7 @@
 use super::response;
 use crate::cql_to_rust::CqlTypeError;
 use crate::frame::value::SerializeValuesError;
+use crate::types::serialize::SerializationError;
 use thiserror::Error;
 
 #[derive(Error, Debug)]
@@ -44,5 +45,7 @@ pub enum ParseError {
     #[error(transparent)]
     SerializeValuesError(#[from] SerializeValuesError),
     #[error(transparent)]
+    SerializationError(#[from] SerializationError),
+    #[error(transparent)]
     CqlTypeError(#[from] CqlTypeError),
 }
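Both the legacy and the new error variants land in the same `BadQuery` bucket, so `?` keeps working across the two serialization paths. A small sketch of the propagation the new `From` impl enables:

```rust
use scylla_cql::errors::QueryError;
use scylla_cql::types::serialize::SerializationError;

// `?` now converts a SerializationError into
// QueryError::BadQuery(BadQuery::SerializationError(_)) via the impl above.
fn propagate(res: Result<(), SerializationError>) -> Result<(), QueryError> {
    res?;
    Ok(())
}
```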
diff --git a/scylla-cql/src/frame/request/mod.rs b/scylla-cql/src/frame/request/mod.rs
index cd41d6bce1..160008cab0 100644
--- a/scylla-cql/src/frame/request/mod.rs
+++ b/scylla-cql/src/frame/request/mod.rs
@@ -112,9 +112,10 @@ mod tests {
             query::{Query, QueryParameters},
             DeserializableRequest, SerializableRequest,
         },
+        response::result::ColumnType,
         types::{self, SerialConsistency},
-        value::SerializedValues,
     },
+    types::serialize::row::NewSerializedValues,
     Consistency,
 };
 
@@ -129,8 +130,8 @@ mod tests {
             page_size: Some(323),
             paging_state: Some(vec![2, 1, 3, 7].into()),
             values: {
-                let mut vals = SerializedValues::new();
-                vals.add_value(&2137).unwrap();
+                let mut vals = NewSerializedValues::new();
+                vals.add_value(&2137, &ColumnType::Int).unwrap();
                 Cow::Owned(vals)
             },
         };
@@ -156,9 +157,9 @@ mod tests {
             page_size: None,
             paging_state: None,
             values: {
-                let mut vals = SerializedValues::new();
-                vals.add_named_value("the_answer", &42).unwrap();
-                vals.add_named_value("really?", &2137).unwrap();
+                let mut vals = NewSerializedValues::new();
+                vals.add_value(&42, &ColumnType::Int).unwrap();
+                vals.add_value(&2137, &ColumnType::Int).unwrap();
                 Cow::Owned(vals)
             },
         };
@@ -189,8 +190,18 @@ mod tests {
             // Not execute's values, because named values are not supported in batches.
             values: vec![
-                query.parameters.values.deref().clone(),
-                query.parameters.values.deref().clone(),
+                query
+                    .parameters
+                    .values
+                    .deref()
+                    .clone()
+                    .into_old_serialized_values(),
+                query
+                    .parameters
+                    .values
+                    .deref()
+                    .clone()
+                    .into_old_serialized_values(),
             ],
         };
         {
@@ -212,7 +223,7 @@ mod tests {
             timestamp: None,
             page_size: None,
             paging_state: None,
-            values: Cow::Owned(SerializedValues::new()),
+            values: Cow::Owned(NewSerializedValues::new()),
         };
         let query = Query {
             contents: contents.clone(),
@@ -261,7 +272,12 @@ mod tests {
             serial_consistency: None,
             timestamp: None,
-            values: vec![query.parameters.values.deref().clone()],
+            values: vec![query
+                .parameters
+                .values
+                .deref()
+                .clone()
+                .into_old_serialized_values()],
         };
         {
             let mut buf = Vec::new();
diff --git a/scylla-cql/src/frame/request/query.rs b/scylla-cql/src/frame/request/query.rs
index ff0b0cc867..ed372c85cc 100644
--- a/scylla-cql/src/frame/request/query.rs
+++ b/scylla-cql/src/frame/request/query.rs
@@ -1,12 +1,14 @@
 use std::borrow::Cow;
 
-use crate::frame::{frame_errors::ParseError, types::SerialConsistency};
+use crate::{
+    frame::{frame_errors::ParseError, types::SerialConsistency},
+    types::serialize::row::NewSerializedValues,
+};
 use bytes::{Buf, BufMut, Bytes};
 
 use crate::{
     frame::request::{RequestOpcode, SerializableRequest},
     frame::types,
-    frame::value::SerializedValues,
 };
 
 use super::DeserializableRequest;
@@ -61,7 +63,7 @@ pub struct QueryParameters<'a> {
     pub timestamp: Option<i64>,
     pub page_size: Option<i32>,
     pub paging_state: Option<Bytes>,
-    pub values: Cow<'a, SerializedValues>,
+    pub values: Cow<'a, NewSerializedValues>,
 }
 
 impl Default for QueryParameters<'_> {
@@ -72,7 +74,7 @@ impl Default for QueryParameters<'_> {
             timestamp: None,
             page_size: None,
             paging_state: None,
-            values: Cow::Borrowed(SerializedValues::EMPTY),
+            values: Cow::Owned(NewSerializedValues::new()),
         }
     }
 }
@@ -102,10 +104,6 @@ impl QueryParameters<'_> {
             flags |= FLAG_WITH_DEFAULT_TIMESTAMP;
         }
 
-        if self.values.has_names() {
-            flags |= FLAG_WITH_NAMES_FOR_VALUES;
-        }
-
         buf.put_u8(flags);
 
         if !self.values.is_empty() {
@@ -151,10 +149,16 @@ impl<'q> QueryParameters<'q> {
         let default_timestamp_flag = (flags & FLAG_WITH_DEFAULT_TIMESTAMP) != 0;
         let values_have_names_flag = (flags & FLAG_WITH_NAMES_FOR_VALUES) != 0;
 
+        if values_have_names_flag {
+            return Err(ParseError::BadIncomingData(
+                "Named values in frame are currently unsupported".to_string(),
+            ));
+        }
+
         let values = Cow::Owned(if values_flag {
-            SerializedValues::new_from_frame(buf, values_have_names_flag)?
+            NewSerializedValues::new_from_frame(buf)?
         } else {
-            SerializedValues::new()
+            NewSerializedValues::new()
         });
 
         let page_size = page_size_flag.then(|| types::read_int(buf)).transpose()?;
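With named values rejected at parse time, every `QueryParameters` now carries positional `NewSerializedValues`. A sketch mirroring the updated tests above (the `2137` literal and the `Int` column type are taken from them):

```rust
use std::borrow::Cow;

use scylla_cql::frame::request::query::QueryParameters;
use scylla_cql::frame::response::result::ColumnType;
use scylla_cql::types::serialize::row::NewSerializedValues;

fn params_with_one_value<'a>() -> QueryParameters<'a> {
    let mut vals = NewSerializedValues::new();
    vals.add_value(&2137_i32, &ColumnType::Int).unwrap();
    QueryParameters {
        values: Cow::Owned(vals),
        // Remaining fields keep their defaults (no paging state, no timestamp, ...).
        ..Default::default()
    }
}
```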
diff --git a/scylla-cql/src/types/serialize/mod.rs b/scylla-cql/src/types/serialize/mod.rs
index 8aa52be9ee..a920c6bac3 100644
--- a/scylla-cql/src/types/serialize/mod.rs
+++ b/scylla-cql/src/types/serialize/mod.rs
@@ -1,9 +1,9 @@
-use std::{any::Any, sync::Arc};
+use std::{error::Error, sync::Arc};
 
 pub mod row;
 pub mod value;
 
-type SerializationError = Arc<dyn Any + Send + Sync>;
+pub type SerializationError = Arc<dyn Error + Send + Sync>;
 
 /// An interface that facilitates writing values for a CQL query.
 pub trait RowWriter {
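Switching the alias from `dyn Any` to `dyn Error` (and making it `pub`) means any std-style error type can be wrapped and surfaced through the public API; the `add_value` code later in this diff relies on exactly this coercion. A minimal sketch, where `DemoError` is a stand-in type and not part of the driver:

```rust
use std::sync::Arc;

use scylla_cql::types::serialize::SerializationError;
use thiserror::Error;

// Stand-in error type for illustration only.
#[derive(Debug, Error)]
#[error("demo serialization failure")]
struct DemoError;

fn wrap() -> SerializationError {
    // Unsized coercion from Arc<DemoError> to Arc<dyn Error + Send + Sync>.
    Arc::new(DemoError) as SerializationError
}
```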
diff --git a/scylla-cql/src/types/serialize/row.rs b/scylla-cql/src/types/serialize/row.rs
index e852343b40..06c3d9c4c6 100644
--- a/scylla-cql/src/types/serialize/row.rs
+++ b/scylla-cql/src/types/serialize/row.rs
@@ -1,11 +1,18 @@
 use std::{collections::HashMap, sync::Arc};
 
+use bytes::BufMut;
 use thiserror::Error;
 
-use crate::frame::value::ValueList;
+use crate::_macro_internal::SerializedValues;
+use crate::frame::frame_errors::ParseError;
+use crate::frame::response::result::{ColumnType, PreparedMetadata};
+use crate::frame::types;
+use crate::frame::value::{SerializeValuesError, ValueList};
 use crate::frame::{response::result::ColumnSpec, types::RawValue};
+use crate::types::serialize::BufBackedRowWriter;
 
-use super::{CellWriter, RowWriter, SerializationError};
+use super::value::SerializeCql;
+use super::{BufBackedCellWriter, CellWriter, RowWriter, SerializationError};
 
 /// Contains information needed to serialize a row.
 pub struct RowSerializationContext<'a> {
@@ -13,6 +20,12 @@ pub struct RowSerializationContext<'a> {
 }
 
 impl<'a> RowSerializationContext<'a> {
+    pub fn from_prepared(prepared: &'a PreparedMetadata) -> Self {
+        Self {
+            columns: prepared.col_specs.as_slice(),
+        }
+    }
+
     /// Returns column/bind marker specifications for given query.
     #[inline]
     pub fn columns(&self) -> &'a [ColumnSpec] {
@@ -37,7 +50,9 @@ pub trait SerializeRow {
     /// Sometimes, a row cannot be fully type checked right away without knowing
     /// the exact values of the columns (e.g. when deserializing to `CqlValue`),
     /// but it's fine to do full type checking later in `serialize`.
-    fn preliminary_type_check(ctx: &RowSerializationContext<'_>) -> Result<(), SerializationError>;
+    fn preliminary_type_check(ctx: &RowSerializationContext<'_>) -> Result<(), SerializationError>
+    where
+        Self: Sized;
 
     /// Serializes the row according to the information in the given context.
     ///
@@ -48,6 +63,8 @@ pub trait SerializeRow {
         ctx: &RowSerializationContext<'_>,
         writer: &mut W,
     ) -> Result<(), SerializationError>;
+
+    fn is_empty(&self) -> bool;
 }
 
 impl<T: ValueList> SerializeRow for T {
@@ -64,6 +81,10 @@ impl<T: ValueList> SerializeRow for T {
     ) -> Result<(), SerializationError> {
         serialize_legacy_row(self, ctx, writer)
     }
+
+    fn is_empty(&self) -> bool {
+        self.serialized().map(|s| s.len()).unwrap_or(0) == 0
+    }
 }
 
 pub fn serialize_legacy_row(
@@ -187,3 +208,122 @@ mod tests {
         assert_eq!(sorted_row_data, unsorted_row_data);
     }
 }
+
+/// Keeps a buffer with serialized values.
+/// Allows adding new values and iterating over the serialized ones.
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
+pub struct NewSerializedValues {
+    serialized_values: Vec<u8>,
+}
+
+impl NewSerializedValues {
+    pub fn new() -> Self {
+        let mut buf = Vec::new();
+        buf.extend_from_slice(&0u16.to_be_bytes());
+        NewSerializedValues {
+            serialized_values: buf,
+        }
+    }
+
+    pub fn from_serializable<T: SerializeRow>(
+        ctx: &RowSerializationContext,
+        row: &T,
+    ) -> Result<Self, SerializationError> {
+        let mut data = Vec::new();
+        let mut writer = BufBackedRowWriter::new(&mut data);
+        T::preliminary_type_check(ctx)?;
+        row.serialize(ctx, &mut writer)?;
+        drop(writer);
+        Ok(NewSerializedValues {
+            serialized_values: data,
+        })
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+
+    pub fn iter(&self) -> impl Iterator<Item = RawValue> {
+        // Skip the first two bytes - they hold the value count, not a value.
+        NewSerializedValuesIterator {
+            serialized_values: &self.serialized_values[2..],
+        }
+    }
+
+    pub fn len(&self) -> u16 {
+        // We initialize the first two bytes in new(), and BufBackedRowWriter
+        // does too, so this unwrap is safe.
+        u16::from_be_bytes(self.serialized_values[0..2].try_into().unwrap())
+    }
+
+    pub fn size(&self) -> usize {
+        self.serialized_values.len()
+    }
+
+    pub fn write_to_request(&self, buf: &mut impl BufMut) {
+        buf.put(self.serialized_values.as_slice())
+    }
+
+    /// Serializes a value and appends it to the list.
+    pub fn add_value<T: SerializeCql>(
+        &mut self,
+        val: &T,
+        typ: &ColumnType,
+    ) -> Result<(), SerializationError> {
+        if self.len() == u16::MAX {
+            return Err(Arc::new(SerializeValuesError::TooManyValues) as SerializationError);
+        }
+
+        T::preliminary_type_check(typ)?;
+
+        let len_before_serialize: usize = self.serialized_values.len();
+
+        let writer = BufBackedCellWriter::new(&mut self.serialized_values);
+        if let Err(e) = val.serialize(typ, writer) {
+            self.serialized_values.resize(len_before_serialize, 0);
+            Err(e)
+        } else {
+            let new_len: u16 = self.len() + 1;
+            self.serialized_values[0..2].copy_from_slice(&new_len.to_be_bytes());
+            Ok(())
+        }
+    }
+
+    /// Creates a value list from the request frame.
+    pub fn new_from_frame(buf: &mut &[u8]) -> Result<Self, ParseError> {
+        let full_buf = *buf;
+        let values_num = types::read_short(buf)?;
+        let values_beg = *buf;
+        for _ in 0..values_num {
+            let _serialized = types::read_value(buf)?;
+        }
+
+        let values_len_in_buf = values_beg.len() - buf.len();
+        let values_in_frame = &full_buf[0..values_len_in_buf + 2];
+        Ok(NewSerializedValues {
+            serialized_values: values_in_frame.to_vec(),
+        })
+    }
+
+    // Temporary function, to be removed when we implement the new batching API
+    // (right now it is needed in the frame::request::mod.rs tests).
+    #[allow(dead_code)]
+    pub(crate) fn into_old_serialized_values(self) -> SerializedValues {
+        SerializedValues::new_from_frame(&mut self.serialized_values.as_slice(), false).unwrap()
+    }
+}
+
+#[derive(Clone, Copy)]
+pub struct NewSerializedValuesIterator<'a> {
+    serialized_values: &'a [u8],
+}
+
+impl<'a> Iterator for NewSerializedValuesIterator<'a> {
+    type Item = RawValue<'a>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.serialized_values.is_empty() {
+            return None;
+        }
+
+        Some(types::read_value(&mut self.serialized_values).expect("badly encoded value"))
+    }
+}
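The type maintains a single invariant: the first two bytes of the backing buffer are the big-endian value count, followed by standard `[length][bytes]` cells. A short usage sketch under that assumption:

```rust
use scylla_cql::frame::response::result::ColumnType;
use scylla_cql::types::serialize::row::NewSerializedValues;

fn main() {
    let mut vals = NewSerializedValues::new();
    vals.add_value(&42_i32, &ColumnType::Int).unwrap();
    vals.add_value(&"hi", &ColumnType::Text).unwrap();

    // len() is O(1): it just decodes the two-byte count prefix.
    assert_eq!(vals.len(), 2);
    assert_eq!(vals.iter().count(), 2);

    // write_to_request() is a single copy of the whole buffer.
    let mut frame = Vec::new();
    vals.write_to_request(&mut frame);
    assert_eq!(&frame[0..2], &2u16.to_be_bytes());
}
```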
diff --git a/scylla/benches/benchmark.rs b/scylla/benches/benchmark.rs
index b33b08a21b..d305cbfe08 100644
--- a/scylla/benches/benchmark.rs
+++ b/scylla/benches/benchmark.rs
@@ -3,9 +3,9 @@ use criterion::{criterion_group, criterion_main, Criterion};
 use bytes::BytesMut;
 use scylla::{
     frame::types,
-    frame::value::ValueList,
     transport::partitioner::{calculate_token_for_partition_key, Murmur3Partitioner},
 };
+use scylla_cql::{frame::response::result::ColumnType, types::serialize::row::NewSerializedValues};
 
 fn types_benchmark(c: &mut Criterion) {
     let mut buf = BytesMut::with_capacity(64);
@@ -40,23 +40,49 @@ fn types_benchmark(c: &mut Criterion) {
 }
 
 fn calculate_token_bench(c: &mut Criterion) {
-    let simple_pk = ("I'm prepared!!!",);
-    let serialized_simple_pk = simple_pk.serialized().unwrap().into_owned();
-    let simple_pk_long_column = (
-        17_i32,
-        16_i32,
-        String::from_iter(std::iter::repeat('.').take(2000)),
-    );
-    let serialized_simple_pk_long_column = simple_pk_long_column.serialized().unwrap().into_owned();
+    let mut serialized_simple_pk = NewSerializedValues::new();
+    serialized_simple_pk
+        .add_value(&"I'm prepared!!!", &ColumnType::Text)
+        .unwrap();
 
-    let complex_pk = (17_i32, 16_i32, "I'm prepared!!!");
-    let serialized_complex_pk = complex_pk.serialized().unwrap().into_owned();
-    let complex_pk_long_column = (
-        17_i32,
-        16_i32,
-        String::from_iter(std::iter::repeat('.').take(2000)),
-    );
-    let serialized_values_long_column = complex_pk_long_column.serialized().unwrap().into_owned();
+    let mut serialized_simple_pk_long_column = NewSerializedValues::new();
+    serialized_simple_pk_long_column
+        .add_value(&17_i32, &ColumnType::Int)
+        .unwrap();
+    serialized_simple_pk_long_column
+        .add_value(&16_i32, &ColumnType::Int)
+        .unwrap();
+    serialized_simple_pk_long_column
+        .add_value(
+            &String::from_iter(std::iter::repeat('.').take(2000)),
+            &ColumnType::Text,
+        )
+        .unwrap();
+
+    let mut serialized_complex_pk = NewSerializedValues::new();
+    serialized_complex_pk
+        .add_value(&17_i32, &ColumnType::Int)
+        .unwrap();
+    serialized_complex_pk
+        .add_value(&16_i32, &ColumnType::Int)
+        .unwrap();
+    serialized_complex_pk
+        .add_value(&"I'm prepared!!!", &ColumnType::Text)
+        .unwrap();
+
+    let mut serialized_values_long_column = NewSerializedValues::new();
+    serialized_values_long_column
+        .add_value(&17_i32, &ColumnType::Int)
+        .unwrap();
+    serialized_values_long_column
+        .add_value(&16_i32, &ColumnType::Int)
+        .unwrap();
+    serialized_values_long_column
+        .add_value(
+            &String::from_iter(std::iter::repeat('.').take(2000)),
+            &ColumnType::Text,
+        )
+        .unwrap();
 
     c.bench_function("calculate_token_from_partition_key simple pk", |b| {
         b.iter(|| calculate_token_for_partition_key(&serialized_simple_pk, &Murmur3Partitioner))
diff --git a/scylla/src/statement/prepared_statement.rs b/scylla/src/statement/prepared_statement.rs
index 58d8b9ea3d..cb90ecec8d 100644
--- a/scylla/src/statement/prepared_statement.rs
+++ b/scylla/src/statement/prepared_statement.rs
@@ -1,6 +1,10 @@
 use bytes::{Bytes, BytesMut};
 use scylla_cql::errors::{BadQuery, QueryError};
 use scylla_cql::frame::types::RawValue;
+use scylla_cql::types::serialize::row::{
+    NewSerializedValues, RowSerializationContext, SerializeRow,
+};
+use scylla_cql::types::serialize::SerializationError;
 use smallvec::{smallvec, SmallVec};
 use std::convert::TryInto;
 use std::sync::Arc;
@@ -13,7 +17,6 @@ use scylla_cql::frame::response::result::ColumnSpec;
 
 use super::StatementConfig;
 use crate::frame::response::result::PreparedMetadata;
 use crate::frame::types::{Consistency, SerialConsistency};
-use crate::frame::value::SerializedValues;
 use crate::history::HistoryListener;
 use crate::retry_policy::RetryPolicy;
 use crate::routing::Token;
@@ -134,9 +137,10 @@ impl PreparedStatement {
     /// [Self::calculate_token()].
     pub fn compute_partition_key(
         &self,
-        bound_values: &SerializedValues,
+        bound_values: &impl SerializeRow,
     ) -> Result<Bytes, PartitionKeyError> {
-        let partition_key = self.extract_partition_key(bound_values)?;
+        let serialized = self.serialize_values(bound_values)?;
+        let partition_key = self.extract_partition_key(&serialized)?;
         let mut buf = BytesMut::new();
         let mut writer = |chunk: &[u8]| buf.extend_from_slice(chunk);
 
@@ -150,7 +154,7 @@ impl PreparedStatement {
     /// This is a preparation step necessary for calculating token based on a prepared statement.
     pub(crate) fn extract_partition_key<'ps>(
         &'ps self,
-        bound_values: &'ps SerializedValues,
+        bound_values: &'ps NewSerializedValues,
     ) -> Result<PartitionKey, PartitionKeyExtractionError> {
         PartitionKey::new(self.get_prepared_metadata(), bound_values)
     }
@@ -158,7 +162,7 @@ impl PreparedStatement {
     pub(crate) fn extract_partition_key_and_calculate_token<'ps>(
         &'ps self,
         partitioner_name: &'ps PartitionerName,
-        serialized_values: &'ps SerializedValues,
+        serialized_values: &'ps NewSerializedValues,
     ) -> Result<Option<(PartitionKey<'ps>, Token)>, QueryError> {
         if !self.is_token_aware() {
             return Ok(None);
@@ -189,12 +193,12 @@ impl PreparedStatement {
     // As this function creates a `PartitionKey`, it is intended rather for external usage (by users).
     // For internal purposes, `PartitionKey::calculate_token()` is preferred, as `PartitionKey`
     // is either way used internally, among others for display in traces.
-    pub fn calculate_token(
-        &self,
-        serialized_values: &SerializedValues,
-    ) -> Result<Option<Token>, QueryError> {
-        self.extract_partition_key_and_calculate_token(&self.partitioner_name, serialized_values)
-            .map(|opt| opt.map(|(_pk, token)| token))
+    pub fn calculate_token(&self, values: &impl SerializeRow) -> Result<Option<Token>, QueryError> {
+        self.extract_partition_key_and_calculate_token(
+            &self.partitioner_name,
+            &self.serialize_values(values)?,
+        )
+        .map(|opt| opt.map(|(_pk, token)| token))
     }
 
     /// Returns the name of the keyspace this statement is operating on.
@@ -335,6 +339,14 @@ impl PreparedStatement {
     pub fn get_execution_profile_handle(&self) -> Option<&ExecutionProfileHandle> {
         self.config.execution_profile_handle.as_ref()
     }
+
+    pub(crate) fn serialize_values(
+        &self,
+        values: &impl SerializeRow,
+    ) -> Result<NewSerializedValues, SerializationError> {
+        let ctx = RowSerializationContext::from_prepared(self.get_prepared_metadata());
+        NewSerializedValues::from_serializable(&ctx, values)
+    }
 }
 
 #[derive(Clone, Debug, Error, PartialEq, Eq, PartialOrd, Ord)]
@@ -349,12 +361,14 @@ pub enum TokenCalculationError {
     ValueTooLong(usize),
 }
 
-#[derive(Clone, Debug, Error, PartialEq, Eq, PartialOrd, Ord)]
+#[derive(Clone, Debug, Error)]
 pub enum PartitionKeyError {
     #[error(transparent)]
     PartitionKeyExtraction(PartitionKeyExtractionError),
     #[error(transparent)]
     TokenCalculation(TokenCalculationError),
+    #[error(transparent)]
+    Serialization(SerializationError),
 }
 
 impl From<PartitionKeyExtractionError> for PartitionKeyError {
@@ -369,6 +383,12 @@ impl From<TokenCalculationError> for PartitionKeyError {
     }
 }
 
+impl From<SerializationError> for PartitionKeyError {
+    fn from(err: SerializationError) -> Self {
+        Self::Serialization(err)
+    }
+}
+
 pub(crate) type PartitionKeyValue<'ps> = (&'ps [u8], &'ps ColumnSpec);
 
 pub(crate) struct PartitionKey<'ps> {
@@ -380,7 +400,7 @@ impl<'ps> PartitionKey<'ps> {
     fn new(
         prepared_metadata: &'ps PreparedMetadata,
-        bound_values: &'ps SerializedValues,
+        bound_values: &'ps NewSerializedValues,
     ) -> Result<Self, PartitionKeyExtractionError> {
         // Iterate on values using sorted pk_indexes (see deser_prepared_metadata),
         // and use PartitionKeyIndex.sequence to insert the value in pk_values with the correct order.
@@ -456,11 +476,11 @@ impl<'ps> PartitionKey<'ps> {
 
 #[cfg(test)]
 mod tests {
-    use scylla_cql::frame::{
-        response::result::{
+    use scylla_cql::{
+        frame::response::result::{
             ColumnSpec, ColumnType, PartitionKeyIndex, PreparedMetadata, TableSpec,
         },
-        value::SerializedValues,
+        types::serialize::row::NewSerializedValues,
     };
 
     use crate::prepared_statement::PartitionKey;
@@ -511,12 +531,14 @@ mod tests {
             ],
             [4, 0, 3],
         );
-        let mut values = SerializedValues::new();
-        values.add_value(&67i8).unwrap();
-        values.add_value(&42i16).unwrap();
-        values.add_value(&23i32).unwrap();
-        values.add_value(&89i64).unwrap();
-        values.add_value(&[1u8, 2, 3, 4, 5]).unwrap();
+        let mut values = NewSerializedValues::new();
+        values.add_value(&67i8, &ColumnType::TinyInt).unwrap();
+        values.add_value(&42i16, &ColumnType::SmallInt).unwrap();
+        values.add_value(&23i32, &ColumnType::Int).unwrap();
+        values.add_value(&89i64, &ColumnType::BigInt).unwrap();
+        values
+            .add_value(&[1u8, 2, 3, 4, 5], &ColumnType::Blob)
+            .unwrap();
 
         let pk = PartitionKey::new(&meta, &values).unwrap();
         let pk_cols = Vec::from_iter(pk.iter());
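The user-visible effect of the `PreparedStatement` changes: token and partition-key computation now accept plain bind values, with serialization against the prepared metadata happening inside. A sketch, assuming a statement whose bind markers are `(int, int, text)` as in the tests elsewhere in this diff:

```rust
use scylla::prepared_statement::PreparedStatement;
use scylla::routing::Token;
use scylla_cql::errors::QueryError;

fn token_for(ps: &PreparedStatement) -> Result<Option<Token>, QueryError> {
    // The tuple is serialized internally via serialize_values();
    // no manual SerializedValues construction is needed anymore.
    ps.calculate_token(&(17_i32, 16_i32, "I'm prepared!!!"))
}
```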
diff --git a/scylla/src/transport/caching_session.rs b/scylla/src/transport/caching_session.rs
index 82e12b1ab2..034ec8793a 100644
--- a/scylla/src/transport/caching_session.rs
+++ b/scylla/src/transport/caching_session.rs
@@ -1,5 +1,5 @@
 use crate::batch::{Batch, BatchStatement};
-use crate::frame::value::{BatchValues, ValueList};
+use crate::frame::value::BatchValues;
 use crate::prepared_statement::PreparedStatement;
 use crate::query::Query;
 use crate::transport::errors::QueryError;
@@ -10,6 +10,7 @@ use bytes::Bytes;
 use dashmap::DashMap;
 use futures::future::try_join_all;
 use scylla_cql::frame::response::result::PreparedMetadata;
+use scylla_cql::types::serialize::row::SerializeRow;
 use std::collections::hash_map::RandomState;
 use std::hash::BuildHasher;
 
@@ -70,38 +71,35 @@ where
     pub async fn execute(
         &self,
         query: impl Into<Query>,
-        values: impl ValueList,
+        values: impl SerializeRow,
     ) -> Result<QueryResult, QueryError> {
         let query = query.into();
         let prepared = self.add_prepared_statement_owned(query).await?;
-        let values = values.serialized()?;
-        self.session.execute(&prepared, values.clone()).await
+        self.session.execute(&prepared, values).await
     }
 
     /// Does the same thing as [`Session::execute_iter`] but uses the prepared statement cache
     pub async fn execute_iter(
         &self,
         query: impl Into<Query>,
-        values: impl ValueList,
+        values: impl SerializeRow,
     ) -> Result<RowIterator, QueryError> {
         let query = query.into();
         let prepared = self.add_prepared_statement_owned(query).await?;
-        let values = values.serialized()?;
-        self.session.execute_iter(prepared, values.clone()).await
+        self.session.execute_iter(prepared, values).await
     }
 
     /// Does the same thing as [`Session::execute_paged`] but uses the prepared statement cache
     pub async fn execute_paged(
         &self,
         query: impl Into<Query>,
-        values: impl ValueList,
+        values: impl SerializeRow,
         paging_state: Option<Bytes>,
     ) -> Result<QueryResult, QueryError> {
         let query = query.into();
         let prepared = self.add_prepared_statement_owned(query).await?;
-        let values = values.serialized()?;
         self.session
-            .execute_paged(&prepared, values.clone(), paging_state.clone())
+            .execute_paged(&prepared, values, paging_state.clone())
             .await
     }
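The `CachingSession` surface keeps its shape; only the bound changes from `ValueList` to `SerializeRow`, which also removes the per-call clone of the serialized buffer. A sketch (the keyspace/table names are illustrative, not from this diff):

```rust
use scylla::transport::caching_session::CachingSession;
use scylla::transport::errors::QueryError;

async fn insert_row(session: &CachingSession) -> Result<(), QueryError> {
    // Values are passed as-is; serialization happens after the statement
    // is fetched from (or added to) the prepared-statement cache.
    session
        .execute("INSERT INTO ks.t (a, b) VALUES (?, ?)", (1_i32, "one"))
        .await?;
    Ok(())
}
```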
diff --git a/scylla/src/transport/cluster.rs b/scylla/src/transport/cluster.rs
index 503d14519d..63ba5840b5 100644
--- a/scylla/src/transport/cluster.rs
+++ b/scylla/src/transport/cluster.rs
@@ -1,6 +1,5 @@
 /// Cluster manages up to date information and connections to database nodes
 use crate::frame::response::event::{Event, StatusChangeEvent};
-use crate::frame::value::ValueList;
 use crate::prepared_statement::TokenCalculationError;
 use crate::routing::Token;
 use crate::transport::host_filter::HostFilter;
@@ -18,6 +17,7 @@ use futures::future::join_all;
 use futures::{future::RemoteHandle, FutureExt};
 use itertools::Itertools;
 use scylla_cql::errors::{BadQuery, NewSessionError};
+use scylla_cql::types::serialize::row::NewSerializedValues;
 use std::collections::HashMap;
 use std::net::SocketAddr;
 use std::sync::Arc;
@@ -390,7 +390,7 @@ impl ClusterData {
         &self,
         keyspace: &str,
         table: &str,
-        partition_key: impl ValueList,
+        partition_key: &NewSerializedValues,
     ) -> Result<Token, BadQuery> {
         let partitioner = self
             .keyspaces
@@ -400,12 +400,11 @@ impl ClusterData {
             .and_then(PartitionerName::from_str)
             .unwrap_or_default();
 
-        calculate_token_for_partition_key(&partition_key.serialized().unwrap(), &partitioner)
-            .map_err(|err| match err {
-                TokenCalculationError::ValueTooLong(values_len) => {
-                    BadQuery::ValuesTooLongForKey(values_len, u16::MAX.into())
-                }
-            })
+        calculate_token_for_partition_key(partition_key, &partitioner).map_err(|err| match err {
+            TokenCalculationError::ValueTooLong(values_len) => {
+                BadQuery::ValuesTooLongForKey(values_len, u16::MAX.into())
+            }
+        })
     }
 
     /// Access to replicas owning a given token
@@ -431,12 +430,13 @@ impl ClusterData {
         replica_set.into_iter()
     }
 
+    // TODO: should this take a `NewSerializedValues` or an `impl SerializeRow`?
     /// Access to replicas owning a given partition key (similar to `nodetool getendpoints`)
     pub fn get_endpoints(
         &self,
         keyspace: &str,
         table: &str,
-        partition_key: impl ValueList,
+        partition_key: &NewSerializedValues,
     ) -> Result<Vec<Arc<Node>>, BadQuery> {
         Ok(self.get_token_endpoints(
             keyspace,
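`compute_token` and `get_endpoints` now require a pre-serialized partition key, since no prepared metadata is available here to drive serialization. A sketch mirroring the updated `session_test.rs` later in this diff; the `ClusterData` re-export path and the `"t2"` table name are assumptions taken from the surrounding code:

```rust
use scylla::routing::Token;
use scylla::transport::ClusterData;
use scylla_cql::errors::BadQuery;
use scylla_cql::frame::response::result::ColumnType;
use scylla_cql::types::serialize::row::NewSerializedValues;

fn token(data: &ClusterData, ks: &str) -> Result<Token, BadQuery> {
    // The caller supplies the column types by hand.
    let mut pk = NewSerializedValues::new();
    pk.add_value(&17_i32, &ColumnType::Int).unwrap();
    data.compute_token(ks, "t2", &pk)
}
```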
diff --git a/scylla/src/transport/connection.rs b/scylla/src/transport/connection.rs
index de79c5d130..21f28644e8 100644
--- a/scylla/src/transport/connection.rs
+++ b/scylla/src/transport/connection.rs
@@ -4,6 +4,7 @@ use scylla_cql::errors::TranslationError;
 use scylla_cql::frame::request::options::Options;
 use scylla_cql::frame::response::Error;
 use scylla_cql::frame::types::SerialConsistency;
+use scylla_cql::types::serialize::row::{NewSerializedValues, SerializeRow};
 use socket2::{SockRef, TcpKeepalive};
 use tokio::io::{split, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
 use tokio::net::{TcpSocket, TcpStream};
@@ -52,7 +53,7 @@ use crate::frame::{
     request::{self, batch, execute, query, register, SerializableRequest},
     response::{event::Event, result, NonErrorResponse, Response, ResponseOpcode},
     server_event_type::EventType,
-    value::{BatchValues, BatchValuesIterator, ValueList},
+    value::{BatchValues, BatchValuesIterator},
     FrameParams, SerializedRequest,
 };
 use crate::query::Query;
@@ -596,7 +597,6 @@ impl Connection {
     pub(crate) async fn query_single_page(
         &self,
         query: impl Into<Query>,
-        values: impl ValueList,
     ) -> Result<QueryResult, QueryError> {
         let query: Query = query.into();
 
         let consistency = query
             .config
             .determine_consistency(self.config.default_consistency);
         let serial_consistency = query.config.serial_consistency;
 
-        self.query_single_page_with_consistency(
-            query,
-            &values,
-            consistency,
-            serial_consistency.flatten(),
-        )
-        .await
+        self.query_single_page_with_consistency(query, consistency, serial_consistency.flatten())
+            .await
     }
 
     pub(crate) async fn query_single_page_with_consistency(
         &self,
         query: impl Into<Query>,
-        values: impl ValueList,
         consistency: Consistency,
         serial_consistency: Option<SerialConsistency>,
     ) -> Result<QueryResult, QueryError> {
         let query: Query = query.into();
-        self.query_with_consistency(&query, &values, consistency, serial_consistency, None)
+        self.query_with_consistency(&query, consistency, serial_consistency, None)
             .await?
             .into_query_result()
     }
@@ -631,13 +625,11 @@ impl Connection {
     pub(crate) async fn query(
         &self,
         query: &Query,
-        values: impl ValueList,
         paging_state: Option<Bytes>,
     ) -> Result<QueryResponse, QueryError> {
         // This method is used only for driver internal queries, so no need to consult execution profile here.
         self.query_with_consistency(
             query,
-            values,
             query
                 .config
                 .determine_consistency(self.config.default_consistency),
@@ -650,33 +642,16 @@ impl Connection {
     pub(crate) async fn query_with_consistency(
         &self,
         query: &Query,
-        values: impl ValueList,
         consistency: Consistency,
         serial_consistency: Option<SerialConsistency>,
         paging_state: Option<Bytes>,
     ) -> Result<QueryResponse, QueryError> {
-        let serialized_values = values.serialized()?;
-
-        let values_size = serialized_values.size();
-        if values_size != 0 {
-            let prepared = self.prepare(query).await?;
-            return self
-                .execute_with_consistency(
-                    &prepared,
-                    values,
-                    consistency,
-                    serial_consistency,
-                    paging_state,
-                )
-                .await;
-        }
-
         let query_frame = query::Query {
             contents: Cow::Borrowed(&query.contents),
             parameters: query::QueryParameters {
                 consistency,
                 serial_consistency,
-                values: serialized_values,
+                values: Cow::Owned(NewSerializedValues::new()),
                 page_size: query.get_page_size(),
                 paging_state,
                 timestamp: query.get_timestamp(),
@@ -687,22 +662,41 @@ impl Connection {
             .await
     }
 
+    #[allow(dead_code)]
+    pub(crate) async fn execute(
+        &self,
+        prepared: PreparedStatement,
+        values: impl SerializeRow,
+        paging_state: Option<Bytes>,
+    ) -> Result<QueryResponse, QueryError> {
+        let serialized_values = prepared.serialize_values(&values)?;
+        // This method is used only for driver internal queries, so no need to consult execution profile here.
+        self.execute_with_consistency(
+            &prepared,
+            &serialized_values,
+            prepared
+                .config
+                .determine_consistency(self.config.default_consistency),
+            prepared.config.serial_consistency.flatten(),
+            paging_state,
+        )
+        .await
+    }
+
     pub(crate) async fn execute_with_consistency(
         &self,
         prepared_statement: &PreparedStatement,
-        values: impl ValueList,
+        values: &NewSerializedValues,
         consistency: Consistency,
         serial_consistency: Option<SerialConsistency>,
         paging_state: Option<Bytes>,
     ) -> Result<QueryResponse, QueryError> {
-        let serialized_values = values.serialized()?;
-
         let execute_frame = execute::Execute {
             id: prepared_statement.get_id().to_owned(),
             parameters: query::QueryParameters {
                 consistency,
                 serial_consistency,
-                values: serialized_values,
+                values: Cow::Borrowed(values),
                 page_size: prepared_statement.get_page_size(),
                 timestamp: prepared_statement.get_timestamp(),
                 paging_state,
@@ -734,19 +728,32 @@ impl Connection {
     pub(crate) async fn query_iter(
         self: Arc<Self>,
         query: Query,
-        values: impl ValueList,
     ) -> Result<RowIterator, QueryError> {
-        let serialized_values = values.serialized()?.into_owned();
-
         let consistency = query
             .config
             .determine_consistency(self.config.default_consistency);
         let serial_consistency = query.config.serial_consistency.flatten();
 
-        RowIterator::new_for_connection_query_iter(
-            query,
+        RowIterator::new_for_connection_query_iter(query, self, consistency, serial_consistency)
+            .await
+    }
+
+    /// Executes a prepared statement and fetches its results over multiple pages, using
+    /// the asynchronous iterator interface.
+    pub(crate) async fn execute_iter(
+        self: Arc<Self>,
+        prepared_statement: PreparedStatement,
+        values: NewSerializedValues,
+    ) -> Result<RowIterator, QueryError> {
+        let consistency = prepared_statement
+            .config
+            .determine_consistency(self.config.default_consistency);
+        let serial_consistency = prepared_statement.config.serial_consistency.flatten();
+
+        RowIterator::new_for_connection_execute_iter(
+            prepared_statement,
+            values,
             self,
-            serialized_values,
             consistency,
             serial_consistency,
         )
         .await
@@ -885,7 +892,7 @@ impl Connection {
             false => format!("USE {}", keyspace_name.as_str()).into(),
         };
 
-        let query_response = self.query(&query, (), None).await?;
+        let query_response = self.query(&query, None).await?;
 
         match query_response.response {
             Response::Result(result::Result::SetKeyspace(set_keyspace)) => {
@@ -929,7 +936,7 @@ impl Connection {
         let (version_id,): (Uuid,) = self
-            .query_single_page(LOCAL_VERSION, &[])
+            .query_single_page(LOCAL_VERSION)
             .await?
             .rows
             .ok_or(QueryError::ProtocolError("Version query returned not rows"))?
@@ -1833,7 +1840,6 @@ mod tests {
     use super::ConnectionConfig;
     use crate::query::Query;
     use crate::transport::connection::open_connection;
-    use crate::transport::connection::QueryResponse;
     use crate::transport::node::ResolvedContactPoint;
     use crate::transport::topology::UntranslatedEndpoint;
     use crate::utils::test_utils::unique_keyspace_name;
@@ -1914,7 +1920,7 @@ mod tests {
         let select_query = Query::new("SELECT p FROM connection_query_iter_tab").with_page_size(7);
         let empty_res = connection
             .clone()
-            .query_iter(select_query.clone(), &[])
+            .query_iter(select_query.clone())
             .await
             .unwrap()
             .try_collect::<Vec<_>>()
@@ -1927,15 +1933,18 @@ mod tests {
         let mut insert_futures = Vec::new();
         let insert_query =
             Query::new("INSERT INTO connection_query_iter_tab (p) VALUES (?)").with_page_size(7);
+        let prepared = connection.prepare(&insert_query).await.unwrap();
         for v in &values {
-            insert_futures.push(connection.query_single_page(insert_query.clone(), (v,)));
+            let prepared_clone = prepared.clone();
+            let fut = async { connection.execute(prepared_clone, (*v,), None).await };
+            insert_futures.push(fut);
         }
 
         futures::future::try_join_all(insert_futures).await.unwrap();
 
         let mut results: Vec<i32> = connection
             .clone()
-            .query_iter(select_query.clone(), &[])
+            .query_iter(select_query.clone())
             .await
             .unwrap()
             .into_typed::<(i32,)>()
@@ -1947,7 +1956,9 @@ mod tests {
         // 3. INSERT query_iter should work and not return any rows.
         let insert_res1 = connection
-            .query_iter(insert_query, (0,))
+            .query_iter(Query::new(
+                "INSERT INTO connection_query_iter_tab (p) VALUES (0)",
+            ))
             .await
             .unwrap()
             .try_collect::<Vec<_>>()
@@ -2007,10 +2018,7 @@ mod tests {
             .await
             .unwrap();
 
-        connection
-            .query(&"TRUNCATE t".into(), (), None)
-            .await
-            .unwrap();
+        connection.query(&"TRUNCATE t".into(), None).await.unwrap();
 
         let mut futs = Vec::new();
 
@@ -2025,8 +2033,9 @@ mod tests {
                     let q = Query::new("INSERT INTO t (p, v) VALUES (?, ?)");
                     let conn = conn.clone();
                     async move {
-                        let response: QueryResponse = conn
-                            .query(&q, (j, vec![j as u8; j as usize]), None)
+                        let prepared = conn.prepare(&q).await.unwrap();
+                        let response = conn
+                            .execute(prepared.clone(), (j, vec![j as u8; j as usize]), None)
                             .await
                             .unwrap();
                         // QueryResponse might contain an error - make sure that there were no errors
@@ -2045,7 +2054,7 @@ mod tests {
         // Check that everything was written properly
         let range_end = arithmetic_sequence_sum(NUM_BATCHES);
         let mut results = connection
-            .query(&"SELECT p, v FROM t".into(), (), None)
+            .query(&"SELECT p, v FROM t".into(), None)
             .await
             .unwrap()
             .into_query_result()
@@ -2198,7 +2207,7 @@ mod tests {
         // As everything is normal, these queries should succeed.
         for _ in 0..3 {
             tokio::time::sleep(Duration::from_millis(500)).await;
-            conn.query_single_page("SELECT host_id FROM system.local", ())
+            conn.query_single_page("SELECT host_id FROM system.local")
                 .await
                 .unwrap();
         }
@@ -2218,7 +2227,7 @@ mod tests {
         // As the router is invalidated, all further queries should immediately
         // return error.
-        conn.query_single_page("SELECT host_id FROM system.local", ())
+        conn.query_single_page("SELECT host_id FROM system.local")
             .await
             .unwrap_err();
diff --git a/scylla/src/transport/iterator.rs b/scylla/src/transport/iterator.rs
index e9389992ed..c794d52103 100644
--- a/scylla/src/transport/iterator.rs
+++ b/scylla/src/transport/iterator.rs
@@ -12,6 +12,7 @@ use bytes::Bytes;
 use futures::Stream;
 use scylla_cql::frame::response::NonErrorResponse;
 use scylla_cql::frame::types::SerialConsistency;
+use scylla_cql::types::serialize::row::NewSerializedValues;
 use std::result::Result;
 use thiserror::Error;
 use tokio::sync::mpsc;
@@ -22,12 +23,9 @@ use super::execution_profile::ExecutionProfileInner;
 use super::session::RequestSpan;
 use crate::cql_to_rust::{FromRow, FromRowError};
-use crate::frame::{
-    response::{
-        result,
-        result::{ColumnSpec, Row, Rows},
-    },
-    value::SerializedValues,
+use crate::frame::response::{
+    result,
+    result::{ColumnSpec, Row, Rows},
 };
 use crate::history::{self, HistoryListener};
 use crate::statement::Consistency;
@@ -73,7 +71,7 @@ struct ReceivedPage {
 
 pub(crate) struct PreparedIteratorConfig {
     pub(crate) prepared: PreparedStatement,
-    pub(crate) values: SerializedValues,
+    pub(crate) values: NewSerializedValues,
     pub(crate) execution_profile: Arc<ExecutionProfileInner>,
     pub(crate) cluster_data: Arc<ClusterData>,
     pub(crate) metrics: Arc<Metrics>,
}
@@ -128,7 +126,6 @@ impl RowIterator {
     pub(crate) async fn new_for_query(
         mut query: Query,
-        values: SerializedValues,
         execution_profile: Arc<ExecutionProfileInner>,
         cluster_data: Arc<ClusterData>,
         metrics: Arc<Metrics>,
@@ -162,29 +159,28 @@ impl RowIterator {
         let parent_span = tracing::Span::current();
         let worker_task = async move {
             let query_ref = &query;
-            let values_ref = &values;
 
             let choose_connection = |node: Arc<Node>| async move { node.random_connection().await };
 
             let page_query = |connection: Arc<Connection>,
                               consistency: Consistency,
-                              paging_state: Option<Bytes>| async move {
-                connection
-                    .query_with_consistency(
-                        query_ref,
-                        values_ref,
-                        consistency,
-                        serial_consistency,
-                        paging_state,
-                    )
-                    .await
+                              paging_state: Option<Bytes>| {
+                async move {
+                    connection
+                        .query_with_consistency(
+                            query_ref,
+                            consistency,
+                            serial_consistency,
+                            paging_state,
+                        )
+                        .await
+                }
             };
 
             let query_ref = &query;
-            let serialized_values_size = values.size();
 
-            let span_creator =
-                move || RequestSpan::new_query(&query_ref.contents, serialized_values_size);
+            let span_creator = move || RequestSpan::new_query(&query_ref.contents);
 
             let worker = RowIteratorWorker {
                 sender: sender.into(),
@@ -337,7 +333,6 @@ impl RowIterator {
     pub(crate) async fn new_for_connection_query_iter(
         mut query: Query,
         connection: Arc<Connection>,
-        values: SerializedValues,
         consistency: Consistency,
         serial_consistency: Option<SerialConsistency>,
     ) -> Result<RowIterator, QueryError> {
@@ -352,6 +347,36 @@ impl RowIterator {
                 fetcher: |paging_state| {
                     connection.query_with_consistency(
                         &query,
+                        consistency,
+                        serial_consistency,
+                        paging_state,
+                    )
+                },
+            };
+            worker.work().await
+        };
+
+        Self::new_from_worker_future(worker_task, receiver).await
+    }
+
+    pub(crate) async fn new_for_connection_execute_iter(
+        mut prepared: PreparedStatement,
+        values: NewSerializedValues,
+        connection: Arc<Connection>,
+        consistency: Consistency,
+        serial_consistency: Option<SerialConsistency>,
+    ) -> Result<RowIterator, QueryError> {
+        if prepared.get_page_size().is_none() {
+            prepared.set_page_size(DEFAULT_ITER_PAGE_SIZE);
+        }
+        let (sender, receiver) = mpsc::channel::<Result<ReceivedPage, QueryError>>(1);
+
+        let worker_task = async move {
+            let worker = SingleConnectionRowIteratorWorker {
+                sender: sender.into(),
+                fetcher: |paging_state| {
+                    connection.execute_with_consistency(
+                        &prepared,
+                        &values,
                         consistency,
                         serial_consistency,
diff --git a/scylla/src/transport/partitioner.rs b/scylla/src/transport/partitioner.rs
index 4526715ab2..28c4fc7f18 100644
--- a/scylla/src/transport/partitioner.rs
+++ b/scylla/src/transport/partitioner.rs
@@ -1,10 +1,8 @@
 use bytes::Buf;
-use scylla_cql::frame::types::RawValue;
+use scylla_cql::{frame::types::RawValue, types::serialize::row::NewSerializedValues};
 use std::num::Wrapping;
 
-use crate::{
-    frame::value::SerializedValues, prepared_statement::TokenCalculationError, routing::Token,
-};
+use crate::{prepared_statement::TokenCalculationError, routing::Token};
 
 #[allow(clippy::upper_case_acronyms)]
 #[derive(Clone, PartialEq, Debug, Default)]
@@ -337,7 +335,7 @@ impl PartitionerHasher for CDCPartitionerHasher {
 /// NOTE: the provided values must completely constitute partition key
 /// and be in the order defined in CREATE TABLE statement.
 pub fn calculate_token_for_partition_key<P: Partitioner>(
-    serialized_partition_key_values: &SerializedValues,
+    serialized_partition_key_values: &NewSerializedValues,
     partitioner: &P,
 ) -> Result<Token, TokenCalculationError> {
     let mut partitioner_hasher = partitioner.build_hasher();
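A self-contained token calculation with the new signature, following the same pattern as the updated benchmark earlier in this diff:

```rust
use scylla::transport::partitioner::{calculate_token_for_partition_key, Murmur3Partitioner};
use scylla_cql::frame::response::result::ColumnType;
use scylla_cql::types::serialize::row::NewSerializedValues;

fn main() {
    // Values must form the complete partition key, in CREATE TABLE order.
    let mut pk = NewSerializedValues::new();
    pk.add_value(&"I'm prepared!!!", &ColumnType::Text).unwrap();

    let token = calculate_token_for_partition_key(&pk, &Murmur3Partitioner).unwrap();
    println!("{token:?}");
}
```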
diff --git a/scylla/src/transport/session.rs b/scylla/src/transport/session.rs
index 24cc481c93..fbecaa5890 100644
--- a/scylla/src/transport/session.rs
+++ b/scylla/src/transport/session.rs
@@ -16,6 +16,7 @@ use itertools::{Either, Itertools};
 pub use scylla_cql::errors::TranslationError;
 use scylla_cql::frame::response::result::{deser_cql_value, ColumnSpec, Rows};
 use scylla_cql::frame::response::NonErrorResponse;
+use scylla_cql::types::serialize::row::SerializeRow;
 use std::borrow::Borrow;
 use std::collections::HashMap;
 use std::fmt::Display;
@@ -46,9 +47,7 @@ use super::NodeRef;
 use crate::cql_to_rust::FromRow;
 use crate::frame::response::cql_to_rust::FromRowError;
 use crate::frame::response::result;
-use crate::frame::value::{
-    BatchValues, BatchValuesFirstSerialized, BatchValuesIterator, ValueList,
-};
+use crate::frame::value::{BatchValues, BatchValuesFirstSerialized, BatchValuesIterator};
 use crate::prepared_statement::PreparedStatement;
 use crate::query::Query;
 use crate::routing::Token;
@@ -603,7 +602,7 @@ impl Session {
     pub async fn query(
         &self,
         query: impl Into<Query>,
-        values: impl ValueList,
+        values: impl SerializeRow,
     ) -> Result<QueryResult, QueryError> {
         self.query_paged(query, values, None).await
     }
@@ -617,11 +616,10 @@ impl Session {
     pub async fn query_paged(
         &self,
         query: impl Into<Query>,
-        values: impl ValueList,
+        values: impl SerializeRow,
         paging_state: Option<Bytes>,
     ) -> Result<QueryResult, QueryError> {
         let query: Query = query.into();
-        let serialized_values = values.serialized()?;
 
         let execution_profile = query
             .get_execution_profile_handle()
@@ -640,7 +638,8 @@ impl Session {
             ..Default::default()
         };
 
-        let span = RequestSpan::new_query(&query.contents, serialized_values.size());
+        let span = RequestSpan::new_query(&query.contents);
+        let span_ref = &span;
         let run_query_result = self
             .run_query(
                 statement_info,
@@ -656,19 +655,35 @@ impl Session {
                         .unwrap_or(execution_profile.serial_consistency);
                     // Needed to avoid moving query and values into async move block
                     let query_ref = &query;
-                    let values_ref = &serialized_values;
+                    let values_ref = &values;
                     let paging_state_ref = &paging_state;
                     async move {
-                        connection
-                            .query_with_consistency(
-                                query_ref,
-                                values_ref,
-                                consistency,
-                                serial_consistency,
-                                paging_state_ref.clone(),
-                            )
-                            .await
-                            .and_then(QueryResponse::into_non_error_query_response)
+                        if values_ref.is_empty() {
+                            span_ref.record_request_size(0);
+                            connection
+                                .query_with_consistency(
+                                    query_ref,
+                                    consistency,
+                                    serial_consistency,
+                                    paging_state_ref.clone(),
+                                )
+                                .await
+                                .and_then(QueryResponse::into_non_error_query_response)
+                        } else {
+                            let prepared = connection.prepare(query_ref).await?;
+                            let serialized = prepared.serialize_values(values_ref)?;
+                            span_ref.record_request_size(serialized.size());
+                            connection
+                                .execute_with_consistency(
+                                    &prepared,
+                                    &serialized,
+                                    consistency,
+                                    serial_consistency,
+                                    paging_state_ref.clone(),
+                                )
+                                .await
+                                .and_then(QueryResponse::into_non_error_query_response)
+                        }
                     }
                 },
                 &span,
@@ -764,24 +779,38 @@ impl Session {
     pub async fn query_iter(
         &self,
         query: impl Into<Query>,
-        values: impl ValueList,
+        values: impl SerializeRow,
     ) -> Result<RowIterator, QueryError> {
         let query: Query = query.into();
-        let serialized_values = values.serialized()?;
 
         let execution_profile = query
             .get_execution_profile_handle()
             .unwrap_or_else(|| self.get_default_execution_profile_handle())
             .access();
 
-        RowIterator::new_for_query(
-            query,
-            serialized_values.into_owned(),
-            execution_profile,
-            self.cluster.get_data(),
-            self.metrics.clone(),
-        )
-        .await
+        if values.is_empty() {
+            RowIterator::new_for_query(
+                query,
+                execution_profile,
+                self.cluster.get_data(),
+                self.metrics.clone(),
+            )
+            .await
+        } else {
+            // Making RowIterator::new_for_query work with values is too hard (if even possible),
+            // so instead of sending one prepare to a specific connection on each iterator query,
+            // we fully prepare the statement beforehand.
+            let prepared = self.prepare(query).await?;
+            let values = prepared.serialize_values(&values)?;
+            RowIterator::new_for_prepared_statement(PreparedIteratorConfig {
+                prepared,
+                values,
+                execution_profile,
+                cluster_data: self.cluster.get_data(),
+                metrics: self.metrics.clone(),
+            })
+            .await
+        }
     }
@@ -916,7 +945,7 @@ impl Session {
     pub async fn execute(
         &self,
         prepared: &PreparedStatement,
-        values: impl ValueList,
+        values: impl SerializeRow,
     ) -> Result<QueryResult, QueryError> {
         self.execute_paged(prepared, values, None).await
     }
@@ -930,10 +959,10 @@ impl Session {
     pub async fn execute_paged(
         &self,
         prepared: &PreparedStatement,
-        values: impl ValueList,
+        values: impl SerializeRow,
         paging_state: Option<Bytes>,
     ) -> Result<QueryResult, QueryError> {
-        let serialized_values = values.serialized()?;
+        let serialized_values = prepared.serialize_values(&values)?;
         let values_ref = &serialized_values;
         let paging_state_ref = &paging_state;
 
@@ -1076,10 +1105,10 @@ impl Session {
     pub async fn execute_iter(
         &self,
         prepared: impl Into<PreparedStatement>,
-        values: impl ValueList,
+        values: impl SerializeRow,
     ) -> Result<RowIterator, QueryError> {
         let prepared = prepared.into();
-        let serialized_values = values.serialized()?;
+        let serialized_values = prepared.serialize_values(&values)?;
 
         let execution_profile = prepared
             .get_execution_profile_handle()
@@ -1088,7 +1117,7 @@ impl Session {
 
         RowIterator::new_for_prepared_statement(PreparedIteratorConfig {
             prepared,
-            values: serialized_values.into_owned(),
+            values: serialized_values,
             execution_profile,
             cluster_data: self.cluster.get_data(),
             metrics: self.metrics.clone(),
@@ -1891,7 +1920,7 @@ pub(crate) struct RequestSpan {
 }
 
 impl RequestSpan {
-    pub(crate) fn new_query(contents: &str, request_size: usize) -> Self {
+    pub(crate) fn new_query(contents: &str) -> Self {
         use tracing::field::Empty;
 
         let span = trace_span!(
@@ -1899,7 +1928,7 @@ impl RequestSpan {
             kind = "unprepared",
             contents = contents,
             //
-            request_size = request_size,
+            request_size = Empty,
             result_size = Empty,
             result_rows = Empty,
             replicas = Empty,
@@ -2013,6 +2042,10 @@ impl RequestSpan {
             .record("replicas", tracing::field::display(&ReplicaIps(replicas)));
     }
 
+    pub(crate) fn record_request_size(&self, size: usize) {
+        self.span.record("request_size", size);
+    }
+
     pub(crate) fn inc_speculative_executions(&self) {
         self.speculative_executions.fetch_add(1, Ordering::Relaxed);
     }
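The observable change in `Session::query`: statements with non-empty values are now transparently prepared first, since column types are required for serialization, while value-less queries still go out as a single QUERY frame. A sketch (keyspace/table names illustrative):

```rust
use scylla::transport::errors::QueryError;
use scylla::Session;

async fn demo(session: &Session) -> Result<(), QueryError> {
    // No values: one unprepared QUERY frame, as before.
    session.query("SELECT host_id FROM system.local", ()).await?;

    // With values: PREPARE followed by EXECUTE under the hood.
    session
        .query("INSERT INTO ks.t (a) VALUES (?)", (42_i32,))
        .await?;
    Ok(())
}
```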
diff --git a/scylla/src/transport/session_test.rs b/scylla/src/transport/session_test.rs
index 805217053d..6168329821 100644
--- a/scylla/src/transport/session_test.rs
+++ b/scylla/src/transport/session_test.rs
@@ -1,7 +1,6 @@
 use crate as scylla;
 use crate::batch::{Batch, BatchStatement};
 use crate::frame::response::result::Row;
-use crate::frame::value::ValueList;
 use crate::prepared_statement::PreparedStatement;
 use crate::query::Query;
 use crate::retry_policy::{QueryInfo, RetryDecision, RetryPolicy, RetrySession};
@@ -28,7 +27,9 @@ use assert_matches::assert_matches;
 use bytes::Bytes;
 use futures::{FutureExt, StreamExt, TryStreamExt};
 use itertools::Itertools;
+use scylla_cql::frame::response::result::ColumnType;
 use scylla_cql::frame::value::Value;
+use scylla_cql::types::serialize::row::{NewSerializedValues, SerializeRow};
 use std::collections::BTreeSet;
 use std::collections::{BTreeMap, HashMap};
 use std::sync::atomic::{AtomicBool, Ordering};
@@ -208,7 +209,9 @@ async fn test_prepared_statement() {
         .unwrap();
 
     let values = (17_i32, 16_i32, "I'm prepared!!!");
-    let serialized_values = values.serialized().unwrap().into_owned();
+    let serialized_values_complex_pk = prepared_complex_pk_statement
+        .serialize_values(&values)
+        .unwrap();
 
     session.execute(&prepared_statement, &values).await.unwrap();
     session
@@ -231,15 +234,14 @@ async fn test_prepared_statement() {
                 .as_bigint()
                 .unwrap(),
         };
-        let prepared_token = Murmur3Partitioner.hash_one(
-            &prepared_statement
-                .compute_partition_key(&serialized_values)
-                .unwrap(),
-        );
+        let prepared_token = Murmur3Partitioner
+            .hash_one(&prepared_statement.compute_partition_key(&values).unwrap());
         assert_eq!(token, prepared_token);
+        let mut pk = NewSerializedValues::new();
+        pk.add_value(&17_i32, &ColumnType::Int).unwrap();
         let cluster_data_token = session
             .get_cluster_data()
-            .compute_token(&ks, "t2", (17_i32,))
+            .compute_token(&ks, "t2", &pk)
             .unwrap();
         assert_eq!(token, cluster_data_token);
     }
@@ -259,13 +261,13 @@ async fn test_prepared_statement() {
         };
         let prepared_token = Murmur3Partitioner.hash_one(
             &prepared_complex_pk_statement
-                .compute_partition_key(&serialized_values)
+                .compute_partition_key(&values)
                 .unwrap(),
         );
         assert_eq!(token, prepared_token);
         let cluster_data_token = session
             .get_cluster_data()
-            .compute_token(&ks, "complex_pk", &serialized_values)
+            .compute_token(&ks, "complex_pk", &serialized_values_complex_pk)
             .unwrap();
         assert_eq!(token, cluster_data_token);
     }
@@ -510,7 +512,7 @@ async fn test_token_calculation() {
             s.push('a');
         }
         let values = (&s,);
-        let serialized_values = values.serialized().unwrap().into_owned();
+        let serialized_values = prepared_statement.serialize_values(&values).unwrap();
         session.execute(&prepared_statement, &values).await.unwrap();
 
         let rs = session
@@ -529,11 +531,8 @@ async fn test_token_calculation() {
                 .as_bigint()
                 .unwrap(),
         };
-        let prepared_token = Murmur3Partitioner.hash_one(
-            &prepared_statement
-                .compute_partition_key(&serialized_values)
-                .unwrap(),
-        );
+        let prepared_token = Murmur3Partitioner
+            .hash_one(&prepared_statement.compute_partition_key(&values).unwrap());
         assert_eq!(token, prepared_token);
         let cluster_data_token = session
             .get_cluster_data()
@@ -2776,23 +2775,22 @@ async fn test_manual_primary_key_computation() {
     async fn assert_tokens_equal(
         session: &Session,
         prepared: &PreparedStatement,
-        pk_values_in_pk_order: impl ValueList,
-        all_values_in_query_order: impl ValueList,
+        pk_values_in_pk_order: impl SerializeRow,
+        all_values_in_query_order: impl SerializeRow,
     ) {
         let serialized_values_in_pk_order =
-            pk_values_in_pk_order.serialized().unwrap().into_owned();
-        let serialized_values_in_query_order =
-            all_values_in_query_order.serialized().unwrap().into_owned();
+            prepared.serialize_values(&pk_values_in_pk_order).unwrap();
+
+        let token_by_prepared = prepared
+            .calculate_token(&all_values_in_query_order)
+            .unwrap()
+            .unwrap();
 
         session
-            .execute(prepared, &serialized_values_in_query_order)
+            .execute(prepared, all_values_in_query_order)
             .await
             .unwrap();
 
-        let token_by_prepared = prepared
-            .calculate_token(&serialized_values_in_query_order)
-            .unwrap()
-            .unwrap();
         let token_by_hand =
             calculate_token_for_partition_key(&serialized_values_in_pk_order, &Murmur3Partitioner)
                 .unwrap();
diff --git a/scylla/src/transport/topology.rs b/scylla/src/transport/topology.rs
index 63ee14f5b2..42101fec6b 100644
--- a/scylla/src/transport/topology.rs
+++ b/scylla/src/transport/topology.rs
@@ -15,7 +15,6 @@ use rand::seq::SliceRandom;
 use rand::{thread_rng, Rng};
 use scylla_cql::errors::NewSessionError;
 use scylla_cql::frame::response::result::Row;
-use scylla_cql::frame::value::ValueList;
 use scylla_macros::FromRow;
 use std::borrow::BorrowMut;
 use std::cell::Cell;
@@ -751,7 +750,7 @@ async fn query_peers(conn: &Arc<Connection>, connect_port: u16) -> Result<Vec<Peer>, QueryError>
-fn query_filter_keyspace_name(
+fn query_filter_keyspace_name<'a>(
     conn: &Arc<Connection>,
     query_str: &str,
-    keyspaces_to_fetch: &[String],
-) -> impl Stream<Item = Result<Row, QueryError>> {
-    let keyspaces = &[keyspaces_to_fetch] as &[&[String]];
+    keyspaces_to_fetch: &'a [String],
+) -> impl Stream<Item = Result<Row, QueryError>> + 'a {
     let (query_str, query_values) = if !keyspaces_to_fetch.is_empty() {
-        (format!("{query_str} where keyspace_name in ?"), keyspaces)
+        (
+            format!("{query_str} where keyspace_name in ?"),
+            keyspaces_to_fetch,
+        )
     } else {
-        (query_str.into(), &[] as &[&[String]])
+        (query_str.into(), &[] as &[String])
     };
-    let query_values = query_values.serialized().map(|sv| sv.into_owned());
     let mut query = Query::new(query_str);
     let conn = conn.clone();
     query.set_page_size(1024);
     let fut = async move {
-        let query_values = query_values?;
-        conn.query_iter(query, query_values).await
+        let vals = &[query_values] as &[&[String]];
+        let prepared = conn.prepare(&query).await?;
+        let serialized_values = prepared.serialize_values(&vals)?;
+        conn.execute_iter(prepared, serialized_values).await
     };
     fut.into_stream().try_flatten()
 }
@@ -1601,7 +1604,7 @@ async fn query_table_partitioners(
     let rows = conn
         .clone()
-        .query_iter(partitioner_query, &[])
+        .query_iter(partitioner_query)
         .into_stream()
         .try_flatten();
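The migration pattern applied throughout this diff, from the caller's point of view: hand over plain values implementing `SerializeRow` and let the driver serialize them against the statement's metadata. A before/after sketch:

```rust
use scylla::prepared_statement::PreparedStatement;
use scylla::transport::errors::QueryError;
use scylla::Session;

async fn migrated(session: &Session, prepared: &PreparedStatement) -> Result<(), QueryError> {
    // Before: session.execute(prepared, values.serialized()?.into_owned()).await?;
    // After: serialization happens inside, driven by the prepared metadata.
    session
        .execute(prepared, (17_i32, 16_i32, "I'm prepared!!!"))
        .await?;
    Ok(())
}
```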