diff --git a/scylla-cql/src/errors.rs b/scylla-cql/src/errors.rs index 1ebf9cfede..fa96b2880c 100644 --- a/scylla-cql/src/errors.rs +++ b/scylla-cql/src/errors.rs @@ -3,6 +3,7 @@ use crate::frame::frame_errors::{CqlResponseParseError, FrameError, ParseError}; use crate::frame::protocol_features::ProtocolFeatures; use crate::frame::value::SerializeValuesError; +use crate::types::deserialize::{DeserializationError, TypeCheckError}; use crate::types::serialize::SerializationError; use crate::Consistency; use bytes::Bytes; @@ -461,6 +462,18 @@ impl From for QueryError { } } +impl From for QueryError { + fn from(value: DeserializationError) -> Self { + Self::InvalidMessage(value.to_string()) + } +} + +impl From for QueryError { + fn from(value: TypeCheckError) -> Self { + Self::InvalidMessage(value.to_string()) + } +} + impl From for QueryError { fn from(parse_error: ParseError) -> QueryError { QueryError::InvalidMessage(format!("Error parsing message: {}", parse_error)) diff --git a/scylla-cql/src/frame/frame_errors.rs b/scylla-cql/src/frame/frame_errors.rs index b90eef3a79..155491ef91 100644 --- a/scylla-cql/src/frame/frame_errors.rs +++ b/scylla-cql/src/frame/frame_errors.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use super::TryFromPrimitiveError; use crate::cql_to_rust::CqlTypeError; use crate::frame::value::SerializeValuesError; -use crate::types::deserialize::DeserializationError; +use crate::types::deserialize::{DeserializationError, TypeCheckError}; use crate::types::serialize::SerializationError; use thiserror::Error; @@ -46,6 +46,8 @@ pub enum ParseError { #[error(transparent)] DeserializationError(#[from] DeserializationError), #[error(transparent)] + DeserializationTypeCheckError(#[from] TypeCheckError), + #[error(transparent)] IoError(#[from] std::io::Error), #[error(transparent)] SerializeValuesError(#[from] SerializeValuesError), @@ -216,6 +218,8 @@ pub enum PreparedParseError { ResultMetadataParseError(ResultMetadataParseError), #[error("Invalid prepared metadata: {0}")] PreparedMetadataParseError(ResultMetadataParseError), + #[error("Non-zero paging state in result metadata: {0:?}")] + NonZeroPagingState(Arc<[u8]>), } /// An error type returned when deserialization diff --git a/scylla-cql/src/frame/request/query.rs b/scylla-cql/src/frame/request/query.rs index 9c755c3db9..2794a2b5d9 100644 --- a/scylla-cql/src/frame/request/query.rs +++ b/scylla-cql/src/frame/request/query.rs @@ -242,6 +242,14 @@ impl PagingStateResponse { Self::NoMorePages => ControlFlow::Break(()), } } + + /// Swaps the paging state response with PagingStateResponse::NoMorePages. + /// + /// Only for use in driver's inner code, as an optimisation. + #[doc(hidden)] + pub fn take(&mut self) -> Self { + std::mem::replace(self, Self::NoMorePages) + } } /// The state of a paged query, i.e. where to resume fetching result rows diff --git a/scylla-cql/src/frame/response/mod.rs b/scylla-cql/src/frame/response/mod.rs index 8e6e7ff335..d084eb71c9 100644 --- a/scylla-cql/src/frame/response/mod.rs +++ b/scylla-cql/src/frame/response/mod.rs @@ -5,6 +5,8 @@ pub mod event; pub mod result; pub mod supported; +use std::sync::Arc; + pub use error::Error; pub use supported::Supported; @@ -65,9 +67,10 @@ impl Response { pub fn deserialize( features: &ProtocolFeatures, opcode: ResponseOpcode, - buf: &mut &[u8], - cached_metadata: Option<&ResultMetadata>, + buf_bytes: bytes::Bytes, + cached_metadata: Option<&Arc>, ) -> Result { + let buf = &mut &*buf_bytes; let response = match opcode { ResponseOpcode::Error => Response::Error(Error::deserialize(features, buf)?), ResponseOpcode::Ready => Response::Ready, @@ -75,7 +78,9 @@ impl Response { Response::Authenticate(authenticate::Authenticate::deserialize(buf)?) } ResponseOpcode::Supported => Response::Supported(Supported::deserialize(buf)?), - ResponseOpcode::Result => Response::Result(result::deserialize(buf, cached_metadata)?), + ResponseOpcode::Result => { + Response::Result(result::deserialize(buf_bytes, cached_metadata)?) + } ResponseOpcode::Event => Response::Event(event::Event::deserialize(buf)?), ResponseOpcode::AuthChallenge => { Response::AuthChallenge(authenticate::AuthChallenge::deserialize(buf)?) diff --git a/scylla-cql/src/frame/response/result.rs b/scylla-cql/src/frame/response/result.rs index b6b51dc7df..4506d6a680 100644 --- a/scylla-cql/src/frame/response/result.rs +++ b/scylla-cql/src/frame/response/result.rs @@ -17,6 +17,7 @@ use crate::types::deserialize::value::{ use crate::types::deserialize::{DeserializationError, FrameSlice}; use bytes::{Buf, Bytes}; use std::borrow::Cow; +use std::sync::Arc; use std::{net::IpAddr, result::Result as StdResult, str}; use uuid::Uuid; @@ -431,7 +432,6 @@ pub struct ColumnSpec { #[derive(Debug, Clone)] pub struct ResultMetadata { col_count: usize, - pub paging_state: PagingStateResponse, pub col_specs: Vec, } @@ -440,10 +440,18 @@ impl ResultMetadata { pub fn mock_empty() -> Self { Self { col_count: 0, - paging_state: PagingStateResponse::NoMorePages, col_specs: Vec::new(), } } + + #[inline] + #[doc(hidden)] + pub fn new_for_test(col_count: usize, col_specs: Vec) -> Self { + Self { + col_count, + col_specs, + } + } } #[derive(Debug, Copy, Clone)] @@ -478,7 +486,8 @@ impl Row { #[derive(Debug)] pub struct Rows { - pub metadata: ResultMetadata, + pub metadata: Arc, + pub paging_state_response: PagingStateResponse, pub rows_count: usize, pub rows: Vec, /// Original size of the serialized rows. @@ -620,7 +629,9 @@ fn deser_col_specs( Ok(col_specs) } -fn deser_result_metadata(buf: &mut &[u8]) -> StdResult { +fn deser_result_metadata( + buf: &mut &[u8], +) -> StdResult<(ResultMetadata, PagingStateResponse), ResultMetadataParseError> { let flags = types::read_int(buf) .map_err(|err| ResultMetadataParseError::FlagsParseError(err.into()))?; let global_tables_spec = flags & 0x0001 != 0; @@ -635,27 +646,23 @@ fn deser_result_metadata(buf: &mut &[u8]) -> StdResult, + buf_bytes: Bytes, + cached_metadata: Option<&Arc>, ) -> StdResult { - let server_metadata = deser_result_metadata(buf)?; + let buf = &mut &*buf_bytes; + let (server_metadata, paging_state_response) = deser_result_metadata(buf)?; let metadata = match cached_metadata { - Some(cached) => ResultMetadata { - col_count: cached.col_count, - paging_state: server_metadata.paging_state, - col_specs: cached.col_specs.clone(), - }, + Some(cached) => Arc::clone(cached), None => { // No cached_metadata provided. Server is supposed to provide the result metadata. if server_metadata.col_count != server_metadata.col_specs.len() { @@ -878,7 +882,7 @@ fn deser_rows( col_specs_count: server_metadata.col_specs.len(), }); } - server_metadata + Arc::new(server_metadata) } }; @@ -899,6 +903,7 @@ fn deser_rows( Ok(Rows { metadata, + paging_state_response, rows_count, rows, serialized_size: original_size - buf.len(), @@ -919,8 +924,17 @@ fn deser_prepared(buf: &mut &[u8]) -> StdResult { buf.advance(id_len); let prepared_metadata = deser_prepared_metadata(buf).map_err(PreparedParseError::PreparedMetadataParseError)?; - let result_metadata = + let (result_metadata, paging_state_response) = deser_result_metadata(buf).map_err(PreparedParseError::ResultMetadataParseError)?; + if let PagingStateResponse::HasMorePages { state } = paging_state_response { + return Err(PreparedParseError::NonZeroPagingState( + state + .as_bytes_slice() + .cloned() + .unwrap_or_else(|| Arc::from([])), + )); + } + Ok(Prepared { id, prepared_metadata, @@ -935,16 +949,17 @@ fn deser_schema_change(buf: &mut &[u8]) -> StdResult, + buf_bytes: Bytes, + cached_metadata: Option<&Arc>, ) -> StdResult { + let buf = &mut &*buf_bytes; use self::Result::*; Ok( match types::read_int(buf) .map_err(|err| CqlResultParseError::ResultIdParseError(err.into()))? { 0x0001 => Void, - 0x0002 => Rows(deser_rows(buf, cached_metadata)?), + 0x0002 => Rows(deser_rows(buf_bytes.slice_ref(buf), cached_metadata)?), 0x0003 => SetKeyspace(deser_set_keyspace(buf)?), 0x0004 => Prepared(deser_prepared(buf)?), 0x0005 => SchemaChange(deser_schema_change(buf)?), diff --git a/scylla-cql/src/types/deserialize/result.rs b/scylla-cql/src/types/deserialize/result.rs index 036b909afb..cb18d14727 100644 --- a/scylla-cql/src/types/deserialize/result.rs +++ b/scylla-cql/src/types/deserialize/result.rs @@ -5,6 +5,7 @@ use super::{DeserializationError, FrameSlice, TypeCheckError}; use std::marker::PhantomData; /// Iterates over the whole result, returning rows. +#[derive(Debug)] pub struct RowIterator<'frame> { specs: &'frame [ColumnSpec], remaining: usize, @@ -76,6 +77,7 @@ impl<'frame> Iterator for RowIterator<'frame> { /// A typed version of [RowIterator] which deserializes the rows before /// returning them. +#[derive(Debug)] pub struct TypedRowIterator<'frame, R> { inner: RowIterator<'frame>, _phantom: PhantomData, diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index 074a7c298a..ca6e22b651 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -670,6 +670,15 @@ impl<'frame, T> ListlikeIterator<'frame, T> { phantom_data: std::marker::PhantomData, } } + + fn empty(coll_typ: &'frame ColumnType, elem_typ: &'frame ColumnType) -> Self { + Self { + coll_typ, + elem_typ, + raw_iter: FixedLengthBytesSequenceIterator::empty(), + phantom_data: std::marker::PhantomData, + } + } } impl<'frame, T> DeserializeValue<'frame> for ListlikeIterator<'frame, T> @@ -699,7 +708,19 @@ where typ: &'frame ColumnType, v: Option>, ) -> Result { - let mut v = ensure_not_null_frame_slice::(typ, v)?; + let elem_typ = match typ { + ColumnType::List(elem_typ) | ColumnType::Set(elem_typ) => elem_typ, + _ => { + unreachable!("Typecheck should have prevented this scenario!") + } + }; + + let mut v = if let Some(v) = v { + v + } else { + return Ok(Self::empty(typ, elem_typ)); + }; + let count = types::read_int_length(v.as_slice_mut()).map_err(|err| { mk_deser_err::( typ, @@ -708,12 +729,7 @@ where ), ) })?; - let elem_typ = match typ { - ColumnType::List(elem_typ) | ColumnType::Set(elem_typ) => elem_typ, - _ => { - unreachable!("Typecheck should have prevented this scenario!") - } - }; + Ok(Self::new(typ, elem_typ, count, v)) } } @@ -849,6 +865,21 @@ impl<'frame, K, V> MapIterator<'frame, K, V> { phantom_data_v: std::marker::PhantomData, } } + + fn empty( + coll_typ: &'frame ColumnType, + k_typ: &'frame ColumnType, + v_typ: &'frame ColumnType, + ) -> Self { + Self { + coll_typ, + k_typ, + v_typ, + raw_iter: FixedLengthBytesSequenceIterator::empty(), + phantom_data_k: std::marker::PhantomData, + phantom_data_v: std::marker::PhantomData, + } + } } impl<'frame, K, V> DeserializeValue<'frame> for MapIterator<'frame, K, V> @@ -875,7 +906,19 @@ where typ: &'frame ColumnType, v: Option>, ) -> Result { - let mut v = ensure_not_null_frame_slice::(typ, v)?; + let (k_typ, v_typ) = match typ { + ColumnType::Map(k_t, v_t) => (k_t, v_t), + _ => { + unreachable!("Typecheck should have prevented this scenario!") + } + }; + + let mut v = if let Some(v) = v { + v + } else { + return Ok(Self::empty(typ, k_typ, v_typ)); + }; + let count = types::read_int_length(v.as_slice_mut()).map_err(|err| { mk_deser_err::( typ, @@ -884,12 +927,7 @@ where ), ) })?; - let (k_typ, v_typ) = match typ { - ColumnType::Map(k_t, v_t) => (k_t, v_t), - _ => { - unreachable!("Typecheck should have prevented this scenario!") - } - }; + Ok(Self::new(typ, k_typ, v_typ, 2 * count, v)) } } @@ -1275,6 +1313,13 @@ impl<'frame> FixedLengthBytesSequenceIterator<'frame> { remaining: count, } } + + fn empty() -> Self { + Self { + slice: FrameSlice::new_empty(), + remaining: 0, + } + } } impl<'frame> Iterator for FixedLengthBytesSequenceIterator<'frame> { diff --git a/scylla-cql/src/types/deserialize/value_tests.rs b/scylla-cql/src/types/deserialize/value_tests.rs index 9375ce47f6..fd14c5e730 100644 --- a/scylla-cql/src/types/deserialize/value_tests.rs +++ b/scylla-cql/src/types/deserialize/value_tests.rs @@ -424,6 +424,24 @@ fn test_list_and_set() { expected_vec_string.into_iter().collect(), ); + // Null collections are interpreted as empty collections, to retain convenience: + // when an empty collection is sent to the DB, the DB nullifies the column instead. + { + let list_typ = ColumnType::List(Box::new(ColumnType::BigInt)); + let set_typ = ColumnType::Set(Box::new(ColumnType::BigInt)); + type CollTyp = i64; + + fn check<'frame, Collection: DeserializeValue<'frame>>(typ: &'frame ColumnType) { + >::type_check(typ).unwrap(); + >::deserialize(typ, None).unwrap(); + } + + check::>(&list_typ); + check::>(&set_typ); + check::>(&set_typ); + check::>(&set_typ); + } + // ser/de identity assert_ser_de_identity(&list_typ, &vec!["qwik"], &mut Bytes::new()); assert_ser_de_identity(&set_typ, &vec!["qwik"], &mut Bytes::new()); @@ -486,6 +504,22 @@ fn test_map() { ); assert_eq!(decoded_btree_string, expected_string.into_iter().collect()); + // Null collections are interpreted as empty collections, to retain convenience: + // when an empty collection is sent to the DB, the DB nullifies the column instead. + { + let map_typ = ColumnType::Map(Box::new(ColumnType::BigInt), Box::new(ColumnType::Ascii)); + type KeyTyp = i64; + type ValueTyp<'s> = &'s str; + + fn check<'frame, Collection: DeserializeValue<'frame>>(typ: &'frame ColumnType) { + >::type_check(typ).unwrap(); + >::deserialize(typ, None).unwrap(); + } + + check::>(&map_typ); + check::>(&map_typ); + } + // ser/de identity assert_ser_de_identity( &typ, @@ -1218,18 +1252,6 @@ fn test_set_or_list_errors() { ); } - // Got null - { - type RustTyp = Vec; - let ser_typ = ColumnType::List(Box::new(ColumnType::Int)); - - let err = RustTyp::deserialize(&ser_typ, None).unwrap_err(); - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, ser_typ); - assert_matches!(err.kind, BuiltinDeserializationErrorKind::ExpectedNonNull); - } - // Bad element type { assert_type_check_error!( @@ -1316,18 +1338,6 @@ fn test_map_errors() { ); } - // Got null - { - type RustTyp = HashMap; - let ser_typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Boolean)); - - let err = RustTyp::deserialize(&ser_typ, None).unwrap_err(); - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, ser_typ); - assert_matches!(err.kind, BuiltinDeserializationErrorKind::ExpectedNonNull); - } - // Key type mismatch { let err = deserialize::>( diff --git a/scylla/src/lib.rs b/scylla/src/lib.rs index 07ba8aaef5..3adce0fc04 100644 --- a/scylla/src/lib.rs +++ b/scylla/src/lib.rs @@ -63,7 +63,8 @@ //! # Ok(()) //! # } //! ``` -//! But the driver will accept anything implementing the trait [ValueList](crate::frame::value::ValueList) +//! But the driver will accept anything implementing the trait [SerializeRow] +//! (crate::serialize::row::SerializeRow) //! //! ### Receiving results //! The easiest way to read rows returned by a query is to cast each row to a tuple of values: diff --git a/scylla/src/statement/prepared_statement.rs b/scylla/src/statement/prepared_statement.rs index 2899c40dd4..809a8fe30c 100644 --- a/scylla/src/statement/prepared_statement.rs +++ b/scylla/src/statement/prepared_statement.rs @@ -102,7 +102,7 @@ pub struct PreparedStatement { #[derive(Debug)] struct PreparedStatementSharedData { metadata: PreparedMetadata, - result_metadata: ResultMetadata, + result_metadata: Arc, statement: String, } @@ -125,7 +125,7 @@ impl PreparedStatement { id: Bytes, is_lwt: bool, metadata: PreparedMetadata, - result_metadata: ResultMetadata, + result_metadata: Arc, statement: String, page_size: PageSize, config: StatementConfig, @@ -417,7 +417,7 @@ impl PreparedStatement { } /// Access metadata about the result of prepared statement returned by the database - pub(crate) fn get_result_metadata(&self) -> &ResultMetadata { + pub(crate) fn get_result_metadata(&self) -> &Arc { &self.shared.result_metadata } diff --git a/scylla/src/transport/caching_session.rs b/scylla/src/transport/caching_session.rs index 0449937956..3fa352cd7d 100644 --- a/scylla/src/transport/caching_session.rs +++ b/scylla/src/transport/caching_session.rs @@ -14,6 +14,7 @@ use scylla_cql::types::serialize::batch::BatchValues; use scylla_cql::types::serialize::row::SerializeRow; use std::collections::hash_map::RandomState; use std::hash::BuildHasher; +use std::sync::Arc; /// Contains just the parts of a prepared statement that were returned /// from the database. All remaining parts (query string, page size, @@ -24,7 +25,7 @@ struct RawPreparedStatementData { id: Bytes, is_confirmed_lwt: bool, metadata: PreparedMetadata, - result_metadata: ResultMetadata, + result_metadata: Arc, partitioner_name: PartitionerName, } diff --git a/scylla/src/transport/connection.rs b/scylla/src/transport/connection.rs index 76a226ce54..456f08d386 100644 --- a/scylla/src/transport/connection.rs +++ b/scylla/src/transport/connection.rs @@ -267,14 +267,14 @@ impl NonErrorQueryResponse { pub(crate) fn into_query_result_and_paging_state( self, ) -> Result<(QueryResult, PagingStateResponse), QueryError> { - let (rows, paging_state, col_specs, serialized_size) = match self.response { + let (rows, paging_state, metadata, serialized_size) = match self.response { NonErrorResponse::Result(result::Result::Rows(rs)) => ( Some(rs.rows), - rs.metadata.paging_state, - rs.metadata.col_specs, + rs.paging_state_response, + Some(rs.metadata), rs.serialized_size, ), - NonErrorResponse::Result(_) => (None, PagingStateResponse::NoMorePages, vec![], 0), + NonErrorResponse::Result(_) => (None, PagingStateResponse::NoMorePages, None, 0), _ => { return Err(QueryError::ProtocolError( "Unexpected server response, expected Result or Error", @@ -287,7 +287,7 @@ impl NonErrorQueryResponse { rows, warnings: self.warnings, tracing_id: self.tracing_id, - col_specs, + metadata, serialized_size, }, paging_state, @@ -778,7 +778,7 @@ impl Connection { .protocol_features .prepared_flags_contain_lwt_mark(p.prepared_metadata.flags as u32), p.prepared_metadata, - p.result_metadata, + Arc::new(p.result_metadata), query.contents.clone(), query.get_validated_page_size(), query.config.clone(), @@ -1273,7 +1273,7 @@ impl Connection { request: &impl SerializableRequest, compress: bool, tracing: bool, - cached_metadata: Option<&ResultMetadata>, + cached_metadata: Option<&Arc>, ) -> Result { let compression = if compress { self.config.compression @@ -1298,7 +1298,7 @@ impl Connection { task_response: TaskResponse, compression: Option, features: &ProtocolFeatures, - cached_metadata: Option<&ResultMetadata>, + cached_metadata: Option<&Arc>, ) -> Result { let body_with_ext = frame::parse_response_body_extensions( task_response.params.flags, @@ -1316,7 +1316,7 @@ impl Connection { let response = Response::deserialize( features, task_response.opcode, - &mut &*body_with_ext.body, + body_with_ext.body, cached_metadata, )?; diff --git a/scylla/src/transport/iterator.rs b/scylla/src/transport/iterator.rs index cb5c26ca87..cb5a8141c8 100644 --- a/scylla/src/transport/iterator.rs +++ b/scylla/src/transport/iterator.rs @@ -410,9 +410,12 @@ impl RowIterator { mod checked_channel_sender { use scylla_cql::{ errors::QueryError, - frame::response::result::{ResultMetadata, Rows}, + frame::{ + request::query::PagingStateResponse, + response::result::{ResultMetadata, Rows}, + }, }; - use std::marker::PhantomData; + use std::{marker::PhantomData, sync::Arc}; use tokio::sync::mpsc; use uuid::Uuid; @@ -453,7 +456,8 @@ mod checked_channel_sender { ) { let empty_page = ReceivedPage { rows: Rows { - metadata: ResultMetadata::mock_empty(), + metadata: Arc::new(ResultMetadata::mock_empty()), + paging_state_response: PagingStateResponse::NoMorePages, rows_count: 0, rows: Vec::new(), serialized_size: 0, @@ -660,7 +664,7 @@ where match query_response { Ok(NonErrorQueryResponse { - response: NonErrorResponse::Result(result::Result::Rows(rows)), + response: NonErrorResponse::Result(result::Result::Rows(mut rows)), tracing_id, .. }) => { @@ -671,9 +675,9 @@ where .load_balancing_policy .on_query_success(&self.statement_info, elapsed, node); - request_span.record_rows_fields(&rows); + let paging_state_response = rows.paging_state_response.take(); - let paging_state_response = rows.metadata.paging_state.clone(); + request_span.record_rows_fields(&rows); let received_page = ReceivedPage { rows, tracing_id }; @@ -840,8 +844,8 @@ where let result = (self.fetcher)(paging_state).await?; let response = result.into_non_error_query_response()?; match response.response { - NonErrorResponse::Result(result::Result::Rows(rows)) => { - let paging_state_response = rows.metadata.paging_state.clone(); + NonErrorResponse::Result(result::Result::Rows(mut rows)) => { + let paging_state_response = rows.paging_state_response.take(); let (proof, send_result) = self .sender diff --git a/scylla/src/transport/query_result.rs b/scylla/src/transport/query_result.rs index 49a68f5c73..db446209f7 100644 --- a/scylla/src/transport/query_result.rs +++ b/scylla/src/transport/query_result.rs @@ -1,13 +1,15 @@ +use std::sync::Arc; + use crate::frame::response::cql_to_rust::{FromRow, FromRowError}; use crate::frame::response::result::ColumnSpec; use crate::frame::response::result::Row; use crate::transport::session::{IntoTypedRows, TypedRowIter}; +use scylla_cql::frame::response::result::ResultMetadata; use thiserror::Error; use uuid::Uuid; /// Result of a single query\ /// Contains all rows returned by the database and some more information -#[non_exhaustive] #[derive(Debug)] pub struct QueryResult { /// Rows returned by the database.\ @@ -18,8 +20,8 @@ pub struct QueryResult { pub warnings: Vec, /// CQL Tracing uuid - can only be Some if tracing is enabled for this query pub tracing_id: Option, - /// Column specification returned from the server - pub col_specs: Vec, + /// Metadata returned along with this response. + pub(crate) metadata: Option>, /// The original size of the serialized rows in request pub serialized_size: usize, } @@ -30,7 +32,7 @@ impl QueryResult { rows: None, warnings: Vec::new(), tracing_id: None, - col_specs: Vec::new(), + metadata: None, serialized_size: 0, } } @@ -134,9 +136,19 @@ impl QueryResult { Ok(self.single_row()?.into_typed::()?) } + /// Returns column specifications. + #[inline] + pub fn col_specs(&self) -> &[ColumnSpec] { + self.metadata + .as_ref() + .map(|metadata| metadata.col_specs.as_slice()) + .unwrap_or_default() + } + /// Returns a column specification for a column with given name, or None if not found + #[inline] pub fn get_column_spec<'a>(&'a self, name: &str) -> Option<(usize, &'a ColumnSpec)> { - self.col_specs + self.col_specs() .iter() .enumerate() .find(|(_id, spec)| spec.name == name) @@ -269,12 +281,13 @@ impl From for SingleRowTypedError { mod tests { use super::*; use crate::{ - frame::response::result::{ColumnSpec, ColumnType, CqlValue, Row, TableSpec}, + frame::response::result::{CqlValue, Row}, test_utils::setup_tracing, }; use std::convert::TryInto; use assert_matches::assert_matches; + use scylla_cql::frame::response::result::{ColumnType, TableSpec}; // Returns specified number of rows, each one containing one int32 value. // Values are 0, 1, 2, 3, 4, ... @@ -301,8 +314,8 @@ mod tests { rows } - fn make_not_rows_query_result() -> QueryResult { - let table_spec = TableSpec::owned("some_keyspace".to_string(), "some_table".to_string()); + fn make_test_metadata() -> ResultMetadata { + let table_spec = TableSpec::borrowed("some_keyspace", "some_table"); let column_spec = ColumnSpec { table_spec, @@ -310,11 +323,15 @@ mod tests { typ: ColumnType::Int, }; + ResultMetadata::new_for_test(1, vec![column_spec]) + } + + fn make_not_rows_query_result() -> QueryResult { QueryResult { rows: None, warnings: vec![], tracing_id: None, - col_specs: vec![column_spec], + metadata: None, serialized_size: 0, } } @@ -322,12 +339,14 @@ mod tests { fn make_rows_query_result(rows_num: usize) -> QueryResult { let mut res = make_not_rows_query_result(); res.rows = Some(make_rows(rows_num)); + res.metadata = Some(Arc::new(make_test_metadata())); res } fn make_string_rows_query_result(rows_num: usize) -> QueryResult { let mut res = make_not_rows_query_result(); res.rows = Some(make_string_rows(rows_num)); + res.metadata = Some(Arc::new(make_test_metadata())); res } diff --git a/scylla/src/transport/session_test.rs b/scylla/src/transport/session_test.rs index a88fa73018..412b76a00a 100644 --- a/scylla/src/transport/session_test.rs +++ b/scylla/src/transport/session_test.rs @@ -1056,7 +1056,7 @@ async fn test_tracing_query_iter(session: &Session, ks: String) { assert!(untraced_row_iter.get_tracing_ids().is_empty()); // The same is true for TypedRowIter - let untraced_typed_row_iter = untraced_row_iter.into_typed::<(i32,)>(); + let untraced_typed_row_iter = untraced_row_iter.into_typed::<(String,)>(); assert!(untraced_typed_row_iter.get_tracing_ids().is_empty()); // A query with tracing enabled has a tracing ids in result @@ -1071,7 +1071,7 @@ async fn test_tracing_query_iter(session: &Session, ks: String) { assert!(!traced_row_iter.get_tracing_ids().is_empty()); // The same is true for TypedRowIter - let traced_typed_row_iter = traced_row_iter.into_typed::<(i32,)>(); + let traced_typed_row_iter = traced_row_iter.into_typed::<(String,)>(); assert!(!traced_typed_row_iter.get_tracing_ids().is_empty()); for tracing_id in traced_typed_row_iter.get_tracing_ids() { @@ -1094,7 +1094,7 @@ async fn test_tracing_execute_iter(session: &Session, ks: String) { assert!(untraced_row_iter.get_tracing_ids().is_empty()); // The same is true for TypedRowIter - let untraced_typed_row_iter = untraced_row_iter.into_typed::<(i32,)>(); + let untraced_typed_row_iter = untraced_row_iter.into_typed::<(String,)>(); assert!(untraced_typed_row_iter.get_tracing_ids().is_empty()); // A prepared statement with tracing enabled has a tracing ids in result @@ -1112,7 +1112,7 @@ async fn test_tracing_execute_iter(session: &Session, ks: String) { assert!(!traced_row_iter.get_tracing_ids().is_empty()); // The same is true for TypedRowIter - let traced_typed_row_iter = traced_row_iter.into_typed::<(i32,)>(); + let traced_typed_row_iter = traced_row_iter.into_typed::<(String,)>(); assert!(!traced_typed_row_iter.get_tracing_ids().is_empty()); for tracing_id in traced_typed_row_iter.get_tracing_ids() { @@ -2570,7 +2570,7 @@ async fn test_batch_lwts() { let batch_res: QueryResult = session.batch(&batch, ((), (), ())).await.unwrap(); // Scylla returns 5 columns, but Cassandra returns only 1 - let is_scylla: bool = batch_res.col_specs.len() == 5; + let is_scylla: bool = batch_res.col_specs().len() == 5; if is_scylla { test_batch_lwts_for_scylla(&session, &batch, batch_res).await;