From 73d3d6ef990de47cf4484085d95cc673e2613df5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Wed, 8 May 2024 07:42:40 +0200 Subject: [PATCH 01/41] types/deser: introduce TypeCheckError In deserialization, contrary to serialization, we have two distinct functions: type_check() and deserialize(). As their task is much different, their returned errors should have distinct types, too. --- scylla-cql/src/types/deserialize/mod.rs | 30 +++++++++++++++++++++++++ scylla-cql/src/types/mod.rs | 1 + 2 files changed, 31 insertions(+) create mode 100644 scylla-cql/src/types/deserialize/mod.rs diff --git a/scylla-cql/src/types/deserialize/mod.rs b/scylla-cql/src/types/deserialize/mod.rs new file mode 100644 index 0000000000..9a26b38ea8 --- /dev/null +++ b/scylla-cql/src/types/deserialize/mod.rs @@ -0,0 +1,30 @@ +use std::sync::Arc; + +use thiserror::Error; + +/// An error indicating that a failure happened during type check. +/// +/// The error is type-erased so that the crate users can define their own +/// type check impls and their errors. +/// As for the impls defined or generated +/// by the driver itself, the following errors can be returned: +/// +/// - [`row::BuiltinTypeCheckError`] is returned when type check of +/// one of types with an impl built into the driver fails. It is also returned +/// from impls generated by the `DeserializeRow` macro. +/// - [`value::BuiltinTypeCheckError`] is analogous to the above but is +/// returned from [`DeserializeValue::type_check`] instead both in the case of +/// builtin impls and impls generated by the `DeserializeValue` macro. +/// It won't be returned by the `Session` directly, but it might be nested +/// in the [`row::BuiltinTypeCheckError`]. +#[derive(Debug, Clone, Error)] +#[error(transparent)] +pub struct TypeCheckError(pub(crate) Arc); + +impl TypeCheckError { + /// Constructs a new `TypeCheckError`. 
+ #[inline] + pub fn new(err: impl std::error::Error + Send + Sync + 'static) -> Self { + Self(Arc::new(err)) + } +} diff --git a/scylla-cql/src/types/mod.rs b/scylla-cql/src/types/mod.rs index ec9942a885..59bce9522d 100644 --- a/scylla-cql/src/types/mod.rs +++ b/scylla-cql/src/types/mod.rs @@ -1 +1,2 @@ +pub mod deserialize; pub mod serialize; From 67dbe83d8b950eadd452a33f6f195b0205ff7ec2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Fri, 22 Mar 2024 20:48:46 +0100 Subject: [PATCH 02/41] types/deser: introduce FrameSlice FrameSlice is going to be the main hero of the new lazy deserialization framework. FrameSlice, unsurprisingly, represents a slice of a result frame. What is interesting though is that apart from a regular &[u8] slice, it holds a Bytes object that keeps the whole frame alive. This allows deserializing blobs as owned ('static!) objects, by subslicing the original Bytes object. This commit contains a suite of tests for FrameSlice. Co-authored-by: Piotr Dulikowski --- .../src/types/deserialize/frame_slice.rs | 187 ++++++++++++++++++ scylla-cql/src/types/deserialize/mod.rs | 33 ++++ 2 files changed, 220 insertions(+) create mode 100644 scylla-cql/src/types/deserialize/frame_slice.rs diff --git a/scylla-cql/src/types/deserialize/frame_slice.rs b/scylla-cql/src/types/deserialize/frame_slice.rs new file mode 100644 index 0000000000..02713bbe7c --- /dev/null +++ b/scylla-cql/src/types/deserialize/frame_slice.rs @@ -0,0 +1,187 @@ +use bytes::Bytes; + +use crate::frame::frame_errors::ParseError; +use crate::frame::types; + +/// A reference to a part of the frame. +// +// # Design justification +// +// ## Why we need a borrowed type +// +// The reason why we need to store a borrowed slice is that we want a lifetime that is longer than one obtained +// when coercing Bytes to a slice in the body of a function. That is, we want to allow deserializing types +// that borrow from the frame, which resides in QueryResult. 
+// Consider a function with the signature: +// +// fn fun(b: Bytes) { ... } +// +// This function cannot return a type that borrows from the frame, because any slice created from `b` +// inside `fun` cannot escape `fun`. +// Conversely, if a function has signature: +// +// fn fun(s: &'frame [u8]) { ... } +// +// then it can happily return types with lifetime 'frame. +// +// ## Why we need the full frame +// +// We don't. We only need to be able to return Bytes encompassing our subslice. However, the design choice +// was made to only store a reference to the original Bytes object residing in QueryResult, so that we avoid +// cloning Bytes when performing subslicing on FrameSlice. We delay the Bytes cloning, normally a moderately +// expensive operation involving cloning an Arc, up until it is really needed. +// +// ## Why not different design +// +// - why not a &'frame [u8] only? Because we want to enable deserializing types containing owned Bytes, too. +// - why not a Bytes only? Because we need to propagate the 'frame lifetime. +// - why not a &'frame Bytes only? Because we want to somehow represent subslices, and subslicing +// &'frame Bytes return Bytes, not &'frame Bytes. +#[derive(Clone, Copy, Debug)] +pub struct FrameSlice<'frame> { + // The actual subslice represented by this FrameSlice. + frame_subslice: &'frame [u8], + + // Reference to the original Bytes object that this FrameSlice is derived + // from. It is used to convert the `mem` slice into a fully blown Bytes + // object via Bytes::slice_ref method. + original_frame: &'frame Bytes, +} + +static EMPTY_BYTES: Bytes = Bytes::new(); + +impl<'frame> FrameSlice<'frame> { + /// Creates a new FrameSlice from a reference of a Bytes object. + /// + /// This method is exposed to allow writing deserialization tests + /// for custom types. + #[inline] + pub fn new(frame: &'frame Bytes) -> Self { + Self { + frame_subslice: frame, + original_frame: frame, + } + } + + /// Creates an empty FrameSlice. 
+ #[inline] + pub fn new_empty() -> Self { + Self { + frame_subslice: &EMPTY_BYTES, + original_frame: &EMPTY_BYTES, + } + } + + /// Returns `true` if the slice has length of 0. + #[inline] + pub fn is_empty(&self) -> bool { + self.frame_subslice.is_empty() + } + + /// Returns the subslice. + #[inline] + pub fn as_slice(&self) -> &'frame [u8] { + self.frame_subslice + } + + /// Returns a mutable reference to the subslice. + #[inline] + pub fn as_slice_mut(&mut self) -> &mut &'frame [u8] { + &mut self.frame_subslice + } + + /// Returns a reference to the Bytes object which encompasses the whole frame slice. + /// + /// The Bytes object will usually be larger than the slice returned by + /// [FrameSlice::as_slice]. If you wish to obtain a new Bytes object that + /// points only to the subslice represented by the FrameSlice object, + /// see [FrameSlice::to_bytes]. + #[inline] + pub fn as_original_frame_bytes(&self) -> &'frame Bytes { + self.original_frame + } + + /// Returns a new Bytes object which is a subslice of the original Bytes + /// frame slice object. + #[inline] + pub fn to_bytes(&self) -> Bytes { + self.original_frame.slice_ref(self.frame_subslice) + } + + /// Reads and consumes a `[bytes]` item from the beginning of the frame, + /// returning a subslice that encompasses that item. + /// + /// If the operation fails then the slice remains unchanged. + #[inline] + pub(super) fn read_cql_bytes(&mut self) -> Result>, ParseError> { + // We copy the slice reference, not to mutate the FrameSlice in case of an error. + let mut slice = self.frame_subslice; + + let cql_bytes = types::read_bytes_opt(&mut slice)?; + + // `read_bytes_opt` hasn't failed, so now we must update the FrameSlice. 
+ self.frame_subslice = slice; + + Ok(cql_bytes.map(|slice| Self { + frame_subslice: slice, + original_frame: self.original_frame, + })) + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + + use super::super::tests::{serialize_cells, CELL1, CELL2}; + use super::FrameSlice; + + #[test] + fn test_cql_bytes_consumption() { + let frame = serialize_cells([Some(CELL1), None, Some(CELL2)]); + let mut slice = FrameSlice::new(&frame); + assert!(!slice.is_empty()); + + assert_eq!( + slice.read_cql_bytes().unwrap().map(|s| s.as_slice()), + Some(CELL1) + ); + assert!(!slice.is_empty()); + assert!(slice.read_cql_bytes().unwrap().is_none()); + assert!(!slice.is_empty()); + assert_eq!( + slice.read_cql_bytes().unwrap().map(|s| s.as_slice()), + Some(CELL2) + ); + assert!(slice.is_empty()); + slice.read_cql_bytes().unwrap_err(); + assert!(slice.is_empty()); + } + + #[test] + fn test_cql_bytes_owned() { + let frame = serialize_cells([Some(CELL1), Some(CELL2)]); + let mut slice = FrameSlice::new(&frame); + + let subslice1 = slice.read_cql_bytes().unwrap().unwrap(); + let subslice2 = slice.read_cql_bytes().unwrap().unwrap(); + + assert_eq!(subslice1.as_slice(), CELL1); + assert_eq!(subslice2.as_slice(), CELL2); + + assert_eq!( + subslice1.as_original_frame_bytes() as *const Bytes, + &frame as *const Bytes + ); + assert_eq!( + subslice2.as_original_frame_bytes() as *const Bytes, + &frame as *const Bytes + ); + + let subslice1_bytes = subslice1.to_bytes(); + let subslice2_bytes = subslice2.to_bytes(); + + assert_eq!(subslice1.as_slice(), subslice1_bytes.as_ref()); + assert_eq!(subslice2.as_slice(), subslice2_bytes.as_ref()); + } +} diff --git a/scylla-cql/src/types/deserialize/mod.rs b/scylla-cql/src/types/deserialize/mod.rs index 9a26b38ea8..7b214352b6 100644 --- a/scylla-cql/src/types/deserialize/mod.rs +++ b/scylla-cql/src/types/deserialize/mod.rs @@ -1,3 +1,7 @@ +pub mod frame_slice; + +pub use frame_slice::FrameSlice; + use std::sync::Arc; use thiserror::Error; @@ -28,3 
+32,32 @@ impl TypeCheckError { Self(Arc::new(err)) } } + +#[cfg(test)] +mod tests { + use bytes::{Bytes, BytesMut}; + + use crate::frame::response::result::{ColumnSpec, ColumnType, TableSpec}; + use crate::frame::types; + + pub(super) static CELL1: &[u8] = &[1, 2, 3]; + pub(super) static CELL2: &[u8] = &[4, 5, 6, 7]; + + pub(super) fn serialize_cells( + cells: impl IntoIterator>>, + ) -> Bytes { + let mut bytes = BytesMut::new(); + for cell in cells { + types::write_bytes_opt(cell, &mut bytes).unwrap(); + } + bytes.freeze() + } + + pub(super) fn spec(name: &str, typ: ColumnType) -> ColumnSpec { + ColumnSpec { + name: name.to_owned(), + typ, + table_spec: TableSpec::borrowed("ks", "tbl"), + } + } +} From 88e9e06d1339c38d1f6951083241a938fd06b0d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Wed, 8 May 2024 08:17:49 +0200 Subject: [PATCH 03/41] types/deser: introduce DeserializationError It is fully analogous to SerializationError. It features dynamic dispatch of Error trait, which enables runtime downcasts in tests, greatly improving testability while retaining flexibility of errors. 
--- scylla-cql/src/frame/frame_errors.rs | 3 +++ scylla-cql/src/types/deserialize/mod.rs | 35 +++++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/scylla-cql/src/frame/frame_errors.rs b/scylla-cql/src/frame/frame_errors.rs index a203762c53..68757331fc 100644 --- a/scylla-cql/src/frame/frame_errors.rs +++ b/scylla-cql/src/frame/frame_errors.rs @@ -1,6 +1,7 @@ use super::TryFromPrimitiveError; use crate::cql_to_rust::CqlTypeError; use crate::frame::value::SerializeValuesError; +use crate::types::deserialize::DeserializationError; use crate::types::serialize::SerializationError; use thiserror::Error; @@ -39,6 +40,8 @@ pub enum ParseError { #[error("Could not deserialize frame: {0}")] BadIncomingData(String), #[error(transparent)] + DeserializationError(#[from] DeserializationError), + #[error(transparent)] IoError(#[from] std::io::Error), #[error("type not yet implemented, id: {0}")] TypeNotImplemented(u16), diff --git a/scylla-cql/src/types/deserialize/mod.rs b/scylla-cql/src/types/deserialize/mod.rs index 7b214352b6..1d442a0bb1 100644 --- a/scylla-cql/src/types/deserialize/mod.rs +++ b/scylla-cql/src/types/deserialize/mod.rs @@ -2,10 +2,14 @@ pub mod frame_slice; pub use frame_slice::FrameSlice; +use std::error::Error; +use std::fmt::Display; use std::sync::Arc; use thiserror::Error; +// Errors + /// An error indicating that a failure happened during type check. /// /// The error is type-erased so that the crate users can define their own @@ -33,6 +37,37 @@ impl TypeCheckError { } } +/// An error indicating that a failure happened during deserialization. +/// +/// The error is type-erased so that the crate users can define their own +/// deserialization impls and their errors. As for the impls defined or generated +/// by the driver itself, the following errors can be returned: +/// +/// - [`row::BuiltinDeserializationError`] is returned when deserialization of +/// one of types with an impl built into the driver fails. 
It is also returned +/// from impls generated by the `DeserializeRow` macro. +/// - [`value::BuiltinDeserializationError`] is analogous to the above but is +/// returned from [`DeserializeValue::deserialize`] instead both in the case of +/// builtin impls and impls generated by the `DeserializeValue` macro. +/// It won't be returned by the `Session` directly, but it might be nested +/// in the [`row::BuiltinDeserializationError`]. +#[derive(Debug, Clone, Error)] +pub struct DeserializationError(Arc); + +impl DeserializationError { + /// Constructs a new `DeserializationError`. + #[inline] + pub fn new(err: impl Error + Send + Sync + 'static) -> Self { + Self(Arc::new(err)) + } +} + +impl Display for DeserializationError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "DeserializationError: {}", self.0) + } +} + #[cfg(test)] mod tests { use bytes::{Bytes, BytesMut}; From 074d53103ea3f0b79cf9d458e7180bd6c3e14606 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Wed, 8 May 2024 10:23:33 +0200 Subject: [PATCH 04/41] deser/value: error facilities --- scylla-cql/src/types/deserialize/mod.rs | 1 + scylla-cql/src/types/deserialize/value.rs | 113 ++++++++++++++++++++++ 2 files changed, 114 insertions(+) create mode 100644 scylla-cql/src/types/deserialize/value.rs diff --git a/scylla-cql/src/types/deserialize/mod.rs b/scylla-cql/src/types/deserialize/mod.rs index 1d442a0bb1..e48fa5a57e 100644 --- a/scylla-cql/src/types/deserialize/mod.rs +++ b/scylla-cql/src/types/deserialize/mod.rs @@ -1,4 +1,5 @@ pub mod frame_slice; +pub mod value; pub use frame_slice::FrameSlice; diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs new file mode 100644 index 0000000000..f3b196a979 --- /dev/null +++ b/scylla-cql/src/types/deserialize/value.rs @@ -0,0 +1,113 @@ +//! Provides types for dealing with CQL value deserialization. 
+ +use std::fmt::Display; + +use thiserror::Error; + +use super::{DeserializationError, TypeCheckError}; +use crate::frame::response::result::ColumnType; + +// Error facilities + +/// Type checking of one of the built-in types failed. +#[derive(Debug, Error, Clone)] +#[error("Failed to type check Rust type {rust_name} against CQL type {cql_type:?}: {kind}")] +pub struct BuiltinTypeCheckError { + /// Name of the Rust type being deserialized. + pub rust_name: &'static str, + + /// The CQL type that the Rust type was being deserialized from. + pub cql_type: ColumnType, + + /// Detailed information about the failure. + pub kind: BuiltinTypeCheckErrorKind, +} + +fn mk_typck_err( + cql_type: &ColumnType, + kind: impl Into, +) -> TypeCheckError { + mk_typck_err_named(std::any::type_name::(), cql_type, kind) +} + +fn mk_typck_err_named( + name: &'static str, + cql_type: &ColumnType, + kind: impl Into, +) -> TypeCheckError { + TypeCheckError::new(BuiltinTypeCheckError { + rust_name: name, + cql_type: cql_type.clone(), + kind: kind.into(), + }) +} + +macro_rules! exact_type_check { + ($typ:ident, $($cql:tt),*) => { + match $typ { + $(ColumnType::$cql)|* => {}, + _ => return Err(mk_typck_err::( + $typ, + BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[$(ColumnType::$cql),*], + } + )) + } + }; +} +use exact_type_check; + +/// Describes why type checking some of the built-in types failed. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum BuiltinTypeCheckErrorKind {} + +impl Display for BuiltinTypeCheckErrorKind { + fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Ok(()) + } +} + +/// Deserialization of one of the built-in types failed. +#[derive(Debug, Error)] +#[error("Failed to deserialize Rust type {rust_name} from CQL type {cql_type:?}: {kind}")] +pub struct BuiltinDeserializationError { + /// Name of the Rust type being deserialized. + pub rust_name: &'static str, + + /// The CQL type that the Rust type was being deserialized from. 
+ pub cql_type: ColumnType, + + /// Detailed information about the failure. + pub kind: BuiltinDeserializationErrorKind, +} + +fn mk_deser_err( + cql_type: &ColumnType, + kind: impl Into, +) -> DeserializationError { + mk_deser_err_named(std::any::type_name::(), cql_type, kind) +} + +fn mk_deser_err_named( + name: &'static str, + cql_type: &ColumnType, + kind: impl Into, +) -> DeserializationError { + DeserializationError::new(BuiltinDeserializationError { + rust_name: name, + cql_type: cql_type.clone(), + kind: kind.into(), + }) +} + +/// Describes why deserialization of some of the built-in types failed. +#[derive(Debug)] +#[non_exhaustive] +pub enum BuiltinDeserializationErrorKind {} + +impl Display for BuiltinDeserializationErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Ok(()) + } +} From bea0f8e63a7370bc948af18661d944899d4f9324 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Wed, 8 May 2024 10:51:28 +0200 Subject: [PATCH 05/41] deser/row: error facilities --- scylla-cql/src/types/deserialize/mod.rs | 1 + scylla-cql/src/types/deserialize/row.rs | 95 +++++++++++++++++++++++++ 2 files changed, 96 insertions(+) create mode 100644 scylla-cql/src/types/deserialize/row.rs diff --git a/scylla-cql/src/types/deserialize/mod.rs b/scylla-cql/src/types/deserialize/mod.rs index e48fa5a57e..9d5e8220a0 100644 --- a/scylla-cql/src/types/deserialize/mod.rs +++ b/scylla-cql/src/types/deserialize/mod.rs @@ -1,4 +1,5 @@ pub mod frame_slice; +pub mod row; pub mod value; pub use frame_slice::FrameSlice; diff --git a/scylla-cql/src/types/deserialize/row.rs b/scylla-cql/src/types/deserialize/row.rs new file mode 100644 index 0000000000..272dca73dd --- /dev/null +++ b/scylla-cql/src/types/deserialize/row.rs @@ -0,0 +1,95 @@ +//! Provides types for dealing with row deserialization. 
+ +use std::fmt::Display; + +use thiserror::Error; + +use super::{DeserializationError, TypeCheckError}; + +use crate::frame::response::result::ColumnType; + +// Error facilities + +/// Failed to type check incoming result column types again given Rust type, +/// one of the types having support built into the driver. +#[derive(Debug, Error, Clone)] +#[error("Failed to type check the Rust type {rust_name} against CQL column types {cql_types:?} : {kind}")] +pub struct BuiltinTypeCheckError { + /// Name of the Rust type used to represent the values. + pub rust_name: &'static str, + + /// The CQL types of the values that the Rust type was being deserialized from. + pub cql_types: Vec, + + /// Detailed information about the failure. + pub kind: BuiltinTypeCheckErrorKind, +} + +fn mk_typck_err( + cql_types: impl IntoIterator, + kind: impl Into, +) -> TypeCheckError { + mk_typck_err_named(std::any::type_name::(), cql_types, kind) +} + +fn mk_typck_err_named( + name: &'static str, + cql_types: impl IntoIterator, + kind: impl Into, +) -> TypeCheckError { + TypeCheckError::new(BuiltinTypeCheckError { + rust_name: name, + cql_types: Vec::from_iter(cql_types), + kind: kind.into(), + }) +} + +/// Describes why type checking incoming result column types again given Rust type failed. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum BuiltinTypeCheckErrorKind {} + +impl Display for BuiltinTypeCheckErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Ok(()) + } +} + +/// Failed to deserialize a row from the DB response, represented by one of the types +/// built into the driver. +#[derive(Debug, Error, Clone)] +#[error("Failed to deserialize query result row {rust_name}: {kind}")] +pub struct BuiltinDeserializationError { + /// Name of the Rust type used to represent the row. + pub rust_name: &'static str, + + /// Detailed information about the failure. 
+ pub kind: BuiltinDeserializationErrorKind, +} + +pub(super) fn mk_deser_err( + kind: impl Into, +) -> DeserializationError { + mk_deser_err_named(std::any::type_name::(), kind) +} + +pub(super) fn mk_deser_err_named( + name: &'static str, + kind: impl Into, +) -> DeserializationError { + DeserializationError::new(BuiltinDeserializationError { + rust_name: name, + kind: kind.into(), + }) +} + +/// Describes why deserializing a result row failed. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum BuiltinDeserializationErrorKind {} + +impl Display for BuiltinDeserializationErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Ok(()) + } +} From 704158812b429661fa34215ddd3e2277605ea463 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Fri, 10 May 2024 09:22:24 +0200 Subject: [PATCH 06/41] deser/row: introduce ColumnIterator ColumnIterator, well, iterates over raw columns of a single row returned from the DB. Among others, it will serve as a basis of the new DeserializeRow trait. Implementors of that trait will be exhausting ColumnIterator and deserializing consecutive columns on the fly. --- scylla-cql/src/types/deserialize/row.rs | 99 ++++++++++++++++++++++++- 1 file changed, 95 insertions(+), 4 deletions(-) diff --git a/scylla-cql/src/types/deserialize/row.rs b/scylla-cql/src/types/deserialize/row.rs index 272dca73dd..56bf97e49a 100644 --- a/scylla-cql/src/types/deserialize/row.rs +++ b/scylla-cql/src/types/deserialize/row.rs @@ -4,9 +4,76 @@ use std::fmt::Display; use thiserror::Error; -use super::{DeserializationError, TypeCheckError}; +use super::{DeserializationError, FrameSlice, TypeCheckError}; +use crate::frame::response::result::{ColumnSpec, ColumnType}; -use crate::frame::response::result::ColumnType; +/// Represents a raw, unparsed column value. 
+#[non_exhaustive] +pub struct RawColumn<'frame> { + pub index: usize, + pub spec: &'frame ColumnSpec, + pub slice: Option>, +} + +/// Iterates over columns of a single row. +#[derive(Clone, Debug)] +pub struct ColumnIterator<'frame> { + specs: std::iter::Enumerate>, + slice: FrameSlice<'frame>, +} + +impl<'frame> ColumnIterator<'frame> { + /// Creates a new iterator over a single row. + /// + /// - `specs` - information about columns of the serialized response, + /// - `slice` - a [FrameSlice] which points to the serialized row. + #[inline] + pub(crate) fn new(specs: &'frame [ColumnSpec], slice: FrameSlice<'frame>) -> Self { + Self { + specs: specs.iter().enumerate(), + slice, + } + } + + /// Returns the remaining number of columns that this iterator is expected + /// to return. + #[inline] + pub fn columns_remaining(&self) -> usize { + self.specs.len() + } +} + +impl<'frame> Iterator for ColumnIterator<'frame> { + type Item = Result, DeserializationError>; + + #[inline] + fn next(&mut self) -> Option { + let (column_index, spec) = self.specs.next()?; + Some( + self.slice + .read_cql_bytes() + .map(|slice| RawColumn { + index: column_index, + spec, + slice, + }) + .map_err(|err| { + mk_deser_err::( + BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { + column_index, + column_name: spec.name.clone(), + err: DeserializationError::new(err), + }, + ) + }), + ) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.specs.size_hint() + } +} // Error facilities @@ -86,10 +153,34 @@ pub(super) fn mk_deser_err_named( /// Describes why deserializing a result row failed. #[derive(Debug, Clone)] #[non_exhaustive] -pub enum BuiltinDeserializationErrorKind {} +pub enum BuiltinDeserializationErrorKind { + /// One of the raw columns failed to deserialize, most probably + /// due to the invalid column structure inside a row in the frame. + RawColumnDeserializationFailed { + /// Index of the raw column that failed to deserialize. 
+ column_index: usize, + + /// Name of the raw column that failed to deserialize. + column_name: String, + + /// The error that caused the raw column deserialization to fail. + err: DeserializationError, + }, +} impl Display for BuiltinDeserializationErrorKind { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - Ok(()) + match self { + BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { + column_index, + column_name, + err, + } => { + write!( + f, + "failed to deserialize raw column {column_name} at index {column_index} (most probably due to invalid column structure inside a row): {err}" + ) + } + } } } From fe704858deadf21d52fb5ed75cd19379f4001f66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Fri, 10 May 2024 09:28:16 +0200 Subject: [PATCH 07/41] scylla-cql/types: introduce new deserialization traits This commit introduces traits that will serve as the base of the new deserialization framework. It is meant to replace the old way of deserialization based on CqlValue, FromCqlVal and FromRow. The new traits are `DeserializeValue<'f>` and `DeserializeRow<'f>`. If a type implements `DeserializeValue<'f>`, this means that it can take a buffer of lifetime `'f` and deserialize itself, treating the contents of the buffer as a CQL value. Analogously, `DeserializeRow<'f>` allows for deserialization of types that are supposed to represent a whole row returned in the results. Deserialization is now split into two phases: type checking and actual deserialization. Both traits have two methods: `type_check` and `deserialize`. The idea behind this is to validate the column types of the response only once, after the response is received. 
The `deserialize` method is then called for each row/value to perform the actual parsing: that method is allowed to assume that `type_check` was called and may skip some type checking for performance (although it cannot use that assumption to perform unsafe operations, as it is not an unsafe method). The previous framework only supported deserialization of types that exclusively own their data, for example String or Vec<u8>. However, the new framework allows for types that borrow or co-own the frame buffer: - Types that borrow the buffer are &'f str and &'f [u8]. They just point to the raw data from the serialized response. Rust's lifetime system makes sure that the user doesn't deallocate the serialized response until the deserialized types that point to the buffer are dropped. Apart from the aforementioned types, a bunch of iterator types are introduced that allow consuming collections without allocations. - Serialized frame is represented with bytes::Bytes. It is possible to create a subslice (also of type Bytes) that keeps the original Bytes alive. The type being deserialized can obtain access to such a subslice. Like in the case of e.g. &str vs. String, keeping Bytes also allows avoiding an allocation, but it is also easier to handle as the deserialized type doesn't have to be bound by any lifetime. It is important to be careful when handling such types as they keep the whole serialized frame alive, even if they only point to a small subslice. This can lead to a space leak and more memory being used than necessary.
Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/mod.rs | 164 ++++++++++++++++++++++ scylla-cql/src/types/deserialize/row.rs | 47 +++++++ scylla-cql/src/types/deserialize/value.rs | 30 +++- scylla/src/lib.rs | 5 +- 4 files changed, 244 insertions(+), 2 deletions(-) diff --git a/scylla-cql/src/types/deserialize/mod.rs b/scylla-cql/src/types/deserialize/mod.rs index 9d5e8220a0..70b9f7f7c4 100644 --- a/scylla-cql/src/types/deserialize/mod.rs +++ b/scylla-cql/src/types/deserialize/mod.rs @@ -1,3 +1,167 @@ +//! Framework for deserialization of data returned by database queries. +//! +//! Deserialization is based on two traits: +//! +//! - A type that implements `DeserializeValue<'frame>` can be deserialized +//! from a single _CQL value_ - i.e. an element of a row in the query result, +//! - A type that implements `DeserializeRow<'frame>` can be deserialized +//! from a single _row_ of a query result. +//! +//! Those traits are quite similar to each other, both in the idea behind them +//! and the interface that they expose. +//! +//! # `type_check` and `deserialize` +//! +//! The deserialization process is divided into two parts: type checking and +//! actual deserialization, represented by `DeserializeValue`/`DeserializeRow`'s +//! methods called `type_check` and `deserialize`. +//! +//! The `deserialize` method can assume that `type_check` was called before, so +//! it doesn't have to verify the type again. This can be a performance gain +//! when deserializing query results with multiple rows: as each row in a result +//! has the same type, it is only necessary to call `type_check` once for the +//! whole result and then `deserialize` for each row. +//! +//! Note that `deserialize` is not an `unsafe` method - although you can be +//! sure that the driver will call `type_check` before `deserialize`, you +//! shouldn't do unsafe things based on this assumption. +//! +//! # Data ownership +//! +//! 
Some CQL types can be easily consumed while still partially serialized. +//! For example, types like `blob` or `text` can be just represented with +//! `&[u8]` and `&str` that just point to a part of the serialized response. +//! This is more efficient than using `Vec` or `String` because it avoids +//! an allocation and a copy, however it is less convenient because those types +//! are bound with a lifetime. +//! +//! The framework supports types that refer to the serialized response's memory +//! in three different ways: +//! +//! ## Owned types +//! +//! Some types don't borrow anything and fully own their data, e.g. `i32` or +//! `String`. They aren't constrained by any lifetime and should implement +//! the respective trait for _all_ lifetimes, i.e.: +//! +//! ```rust +//! # use scylla_cql::frame::response::result::ColumnType; +//! # use scylla_cql::frame::frame_errors::ParseError; +//! # use scylla_cql::types::deserialize::{DeserializationError, FrameSlice, TypeCheckError}; +//! # use scylla_cql::types::deserialize::value::DeserializeValue; +//! struct MyVec(Vec); +//! impl<'frame> DeserializeValue<'frame> for MyVec { +//! fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { +//! if let ColumnType::Blob = typ { +//! return Ok(()); +//! } +//! Err(TypeCheckError::new( +//! ParseError::BadIncomingData("Expected bytes".to_owned()) +//! )) +//! } +//! +//! fn deserialize( +//! _typ: &'frame ColumnType, +//! v: Option>, +//! ) -> Result { +//! v.ok_or_else(|| { +//! DeserializationError::new( +//! ParseError::BadIncomingData("Expected non-null value".to_owned()) +//! ) +//! }) +//! .map(|v| Self(v.as_slice().to_vec())) +//! } +//! } +//! ``` +//! +//! ## Borrowing types +//! +//! Some types do not fully contain their data but rather will point to some +//! bytes in the serialized response, e.g. `&str` or `&[u8]`. Those types will +//! usually contain a lifetime in their definition. In order to properly +//! 
implement `DeserializeValue` or `DeserializeRow` for such a type, the `impl` +//! should still have a generic lifetime parameter, but the lifetimes from the +//! type definition should be constrained with the generic lifetime parameter. +//! For example: +//! +//! ```rust +//! # use scylla_cql::frame::frame_errors::ParseError; +//! # use scylla_cql::frame::response::result::ColumnType; +//! # use scylla_cql::types::deserialize::{DeserializationError, FrameSlice, TypeCheckError}; +//! # use scylla_cql::types::deserialize::value::DeserializeValue; +//! struct MySlice<'a>(&'a [u8]); +//! impl<'a, 'frame> DeserializeValue<'frame> for MySlice<'a> +//! where +//! 'frame: 'a, +//! { +//! fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { +//! if let ColumnType::Blob = typ { +//! return Ok(()); +//! } +//! Err(TypeCheckError::new( +//! ParseError::BadIncomingData("Expected bytes".to_owned()) +//! )) +//! } +//! +//! fn deserialize( +//! _typ: &'frame ColumnType, +//! v: Option>, +//! ) -> Result { +//! v.ok_or_else(|| { +//! DeserializationError::new( +//! ParseError::BadIncomingData("Expected non-null value".to_owned()) +//! ) +//! }) +//! .map(|v| Self(v.as_slice())) +//! } +//! } +//! ``` +//! +//! ## Reference-counted types +//! +//! Internally, the driver uses the `bytes::Bytes` type to keep the contents +//! of the serialized response. It supports creating derived `Bytes` objects +//! which point to a subslice but keep the whole, original `Bytes` object alive. +//! +//! During deserialization, a type can obtain a `Bytes` subslice that points +//! to the serialized value. This approach combines advantages of the previous +//! two approaches - creating a derived `Bytes` object can be cheaper than +//! allocation and a copy (it supports `Arc`-like semantics) and the `Bytes` +//! type is not constrained by a lifetime. However, you should be aware that +//! the subslice will keep the whole `Bytes` object that holds the frame alive. +//! 
It is not recommended to use this approach for long-living objects because +//! it can introduce space leaks. +//! +//! Example: +//! +//! ```rust +//! # use scylla_cql::frame::frame_errors::ParseError; +//! # use scylla_cql::frame::response::result::ColumnType; +//! # use scylla_cql::types::deserialize::{DeserializationError, FrameSlice, TypeCheckError}; +//! # use scylla_cql::types::deserialize::value::DeserializeValue; +//! # use bytes::Bytes; +//! struct MyBytes(Bytes); +//! impl<'frame> DeserializeValue<'frame> for MyBytes { +//! fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { +//! if let ColumnType::Blob = typ { +//! return Ok(()); +//! } +//! Err(TypeCheckError::new(ParseError::BadIncomingData("Expected bytes".to_owned()))) +//! } +//! +//! fn deserialize( +//! _typ: &'frame ColumnType, +//! v: Option>, +//! ) -> Result { +//! v.ok_or_else(|| { +//! DeserializationError::new(ParseError::BadIncomingData("Expected non-null value".to_owned())) +//! }) +//! .map(|v| Self(v.to_bytes())) +//! } +//! } +//! ``` +// TODO: in the above module docstring, stop abusing ParseError once errors are refactored. + pub mod frame_slice; pub mod row; pub mod value; diff --git a/scylla-cql/src/types/deserialize/row.rs b/scylla-cql/src/types/deserialize/row.rs index 56bf97e49a..5eaa46b91c 100644 --- a/scylla-cql/src/types/deserialize/row.rs +++ b/scylla-cql/src/types/deserialize/row.rs @@ -75,6 +75,53 @@ impl<'frame> Iterator for ColumnIterator<'frame> { } } +/// A type that can be deserialized from a row that was returned from a query. +/// +/// For tips on how to write a custom implementation of this trait, see the +/// documentation of the parent module. +/// +/// The crate also provides a derive macro which allows to automatically +/// implement the trait for a custom type. For more details on what the macro +/// is capable of, see its documentation. 
+pub trait DeserializeRow<'frame> +where + Self: Sized, +{ + /// Checks that the schema of the result matches what this type expects. + /// + /// This function can check whether column types and names match the + /// expectations. + fn type_check(specs: &[ColumnSpec]) -> Result<(), TypeCheckError>; + + /// Deserializes a row from given column iterator. + /// + /// This function can assume that the driver called `type_check` to verify + /// the row's type. Note that `deserialize` is not an unsafe function, + /// so it should not use the assumption about `type_check` being called + /// as an excuse to run `unsafe` code. + fn deserialize(row: ColumnIterator<'frame>) -> Result; +} + +// raw deserialization as ColumnIterator + +// What is the purpose of implementing DeserializeRow for ColumnIterator? +// +// Sometimes users might be interested in operating on ColumnIterator directly. +// Implementing DeserializeRow for it allows us to simplify our interface. For example, +// we have `QueryResult::rows()` - you can put T = ColumnIterator +// instead of having a separate rows_raw function or something like this. 
+impl<'frame> DeserializeRow<'frame> for ColumnIterator<'frame> { + #[inline] + fn type_check(_specs: &[ColumnSpec]) -> Result<(), TypeCheckError> { + Ok(()) + } + + #[inline] + fn deserialize(row: ColumnIterator<'frame>) -> Result { + Ok(row) + } +} + // Error facilities /// Failed to type check incoming result column types again given Rust type, diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index f3b196a979..3d3a3e0c27 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -4,9 +4,37 @@ use std::fmt::Display; use thiserror::Error; -use super::{DeserializationError, TypeCheckError}; +use super::{DeserializationError, FrameSlice, TypeCheckError}; use crate::frame::response::result::ColumnType; +/// A type that can be deserialized from a column value inside a row that was +/// returned from a query. +/// +/// For tips on how to write a custom implementation of this trait, see the +/// documentation of the parent module. +/// +/// The crate also provides a derive macro which allows to automatically +/// implement the trait for a custom type. For more details on what the macro +/// is capable of, see its documentation. +pub trait DeserializeValue<'frame> +where + Self: Sized, +{ + /// Checks that the column type matches what this type expects. + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError>; + + /// Deserialize a column value from given serialized representation. + /// + /// This function can assume that the driver called `type_check` to verify + /// the column's type. Note that `deserialize` is not an unsafe function, + /// so it should not use the assumption about `type_check` being called + /// as an excuse to run `unsafe` code. + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result; +} + // Error facilities /// Type checking of one of the built-in types failed. 
diff --git a/scylla/src/lib.rs b/scylla/src/lib.rs index 1b46559698..818c5ebbd0 100644 --- a/scylla/src/lib.rs +++ b/scylla/src/lib.rs @@ -126,7 +126,10 @@ pub mod frame { } } -pub use scylla_cql::types::serialize; +// FIXME: finer-grained control over exports +// Some types are `pub` in scylla-cql just for scylla crate, +// and those shouldn't be exposed for users. +pub use scylla_cql::types::{deserialize, serialize}; pub mod authentication; #[cfg(feature = "cloud")] From a2a83db7d0efd408e0ad76b7ffed75a3d0b9827a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Thu, 21 Mar 2024 17:25:21 +0100 Subject: [PATCH 08/41] deser/value: set up test machinery The machinery is going to be used for testing proper deserialization of all supported types. The tests will involve compatibility checks with the previous framework. Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/value.rs | 97 ++++++++++++++++++++++- 1 file changed, 95 insertions(+), 2 deletions(-) diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index 3d3a3e0c27..9cfdcc1b13 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -5,6 +5,7 @@ use std::fmt::Display; use thiserror::Error; use super::{DeserializationError, FrameSlice, TypeCheckError}; +use crate::frame::frame_errors::ParseError; use crate::frame::response::result::ColumnType; /// A type that can be deserialized from a column value inside a row that was @@ -132,10 +133,102 @@ fn mk_deser_err_named( /// Describes why deserialization of some of the built-in types failed. #[derive(Debug)] #[non_exhaustive] -pub enum BuiltinDeserializationErrorKind {} +pub enum BuiltinDeserializationErrorKind { + /// A generic deserialization failure - legacy error type. 
+ GenericParseError(ParseError), +} impl Display for BuiltinDeserializationErrorKind { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - Ok(()) + match self { + BuiltinDeserializationErrorKind::GenericParseError(err) => err.fmt(f), + } + } +} + +#[cfg(test)] +mod tests { + use bytes::{BufMut, Bytes, BytesMut}; + + use std::fmt::Debug; + + use crate::frame::response::cql_to_rust::FromCqlVal; + use crate::frame::response::result::{deser_cql_value, ColumnType, CqlValue}; + use crate::frame::types; + use crate::types::deserialize::{DeserializationError, FrameSlice}; + use crate::types::serialize::value::SerializeValue; + use crate::types::serialize::CellWriter; + + use super::{mk_deser_err, BuiltinDeserializationErrorKind, DeserializeValue}; + + // Checks that both new and old serialization framework + // produces the same results in this case + fn compat_check(typ: &ColumnType, raw: Bytes) + where + T: for<'f> DeserializeValue<'f>, + T: FromCqlVal>, + T: Debug + PartialEq, + { + let mut slice = raw.as_ref(); + let mut cell = types::read_bytes_opt(&mut slice).unwrap(); + let old = T::from_cql( + cell.as_mut() + .map(|c| deser_cql_value(typ, c)) + .transpose() + .unwrap(), + ) + .unwrap(); + let new = deserialize::(typ, &raw).unwrap(); + assert_eq!(old, new); + } + + fn compat_check_serialized(typ: &ColumnType, val: &dyn SerializeValue) + where + T: for<'f> DeserializeValue<'f>, + T: FromCqlVal>, + T: Debug + PartialEq, + { + let raw = serialize(typ, val); + compat_check::(typ, raw); + } + + fn deserialize<'frame, T>( + typ: &'frame ColumnType, + bytes: &'frame Bytes, + ) -> Result + where + T: DeserializeValue<'frame>, + { + >::type_check(typ) + .map_err(|typecheck_err| DeserializationError(typecheck_err.0))?; + let mut frame_slice = FrameSlice::new(bytes); + let value = frame_slice.read_cql_bytes().map_err(|err| { + mk_deser_err::(typ, BuiltinDeserializationErrorKind::GenericParseError(err)) + })?; + >::deserialize(typ, value) + } + + fn 
make_bytes(cell: &[u8]) -> Bytes { + let mut b = BytesMut::new(); + append_bytes(&mut b, cell); + b.freeze() + } + + fn serialize(typ: &ColumnType, value: &dyn SerializeValue) -> Bytes { + let mut bytes = Bytes::new(); + serialize_to_buf(typ, value, &mut bytes); + bytes + } + + fn serialize_to_buf(typ: &ColumnType, value: &dyn SerializeValue, buf: &mut Bytes) { + let mut v = Vec::new(); + let writer = CellWriter::new(&mut v); + value.serialize(typ, writer).unwrap(); + *buf = v.into(); + } + + fn append_bytes(b: &mut impl BufMut, cell: &[u8]) { + b.put_i32(cell.len() as i32); + b.put_slice(cell); } } From ec8378baa8e0fe884f2bbf8f6e4e4ea39e1be4bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Fri, 10 May 2024 08:46:02 +0200 Subject: [PATCH 09/41] value: impl DeserializeValue for CqlValue CqlValue's role is reduced to specialised use cases, where ability to hold any CQL value in one type is required. Otherwise, it is encouraged to deserialize data straight to end Rust types, not using CqlValue at all. Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/value.rs | 42 ++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index 9cfdcc1b13..351edd26fb 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -6,7 +6,7 @@ use thiserror::Error; use super::{DeserializationError, FrameSlice, TypeCheckError}; use crate::frame::frame_errors::ParseError; -use crate::frame::response::result::ColumnType; +use crate::frame::response::result::{deser_cql_value, ColumnType, CqlValue}; /// A type that can be deserialized from a column value inside a row that was /// returned from a query. 
@@ -36,6 +36,40 @@ where ) -> Result; } +impl<'frame> DeserializeValue<'frame> for CqlValue { + fn type_check(_typ: &ColumnType) -> Result<(), TypeCheckError> { + // CqlValue accepts all possible CQL types + Ok(()) + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + let mut val = ensure_not_null_slice::(typ, v)?; + let cql = deser_cql_value(typ, &mut val).map_err(|err| { + mk_deser_err::(typ, BuiltinDeserializationErrorKind::GenericParseError(err)) + })?; + Ok(cql) + } +} + +// Utilities + +fn ensure_not_null_frame_slice<'frame, T>( + typ: &ColumnType, + v: Option>, +) -> Result, DeserializationError> { + v.ok_or_else(|| mk_deser_err::(typ, BuiltinDeserializationErrorKind::ExpectedNonNull)) +} + +fn ensure_not_null_slice<'frame, T>( + typ: &ColumnType, + v: Option>, +) -> Result<&'frame [u8], DeserializationError> { + ensure_not_null_frame_slice::(typ, v).map(|frame_slice| frame_slice.as_slice()) +} + // Error facilities /// Type checking of one of the built-in types failed. @@ -136,12 +170,18 @@ fn mk_deser_err_named( pub enum BuiltinDeserializationErrorKind { /// A generic deserialization failure - legacy error type. GenericParseError(ParseError), + + /// Expected non-null value, got null. + ExpectedNonNull, } impl Display for BuiltinDeserializationErrorKind { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { BuiltinDeserializationErrorKind::GenericParseError(err) => err.fmt(f), + BuiltinDeserializationErrorKind::ExpectedNonNull => { + f.write_str("expected a non-null value, got null") + } } } } From 1d005f19aba80aeae86e8d045383a812b6693951 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Fri, 10 May 2024 08:48:04 +0200 Subject: [PATCH 10/41] value: impl DeserializeValue for Option (null values) Option, as before, is going to represent null values. In the next commit, MaybeEmpty is introduced, which is a distinct type used to represent empty values. 
Although this is a quirky discrepancy in the CQL protocol, we want to represent it strictly using Rust type system. Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/value.rs | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index 351edd26fb..c60ab7269f 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -54,6 +54,25 @@ impl<'frame> DeserializeValue<'frame> for CqlValue { } } +// Option represents nullability of CQL values: +// None corresponds to null, +// Some(val) to non-null values. +impl<'frame, T> DeserializeValue<'frame> for Option +where + T: DeserializeValue<'frame>, +{ + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + T::type_check(typ) + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + v.map(|_| T::deserialize(typ, v)).transpose() + } +} + // Utilities fn ensure_not_null_frame_slice<'frame, T>( From 0e1f91678808ba0fc776382bfa78df8703f33329 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Thu, 21 Mar 2024 15:35:50 +0100 Subject: [PATCH 11/41] value: impl DeserializeValue for MaybeEmpty Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/value.rs | 44 +++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index c60ab7269f..70edb3f061 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -73,6 +73,50 @@ where } } +/// Values that may be empty or not. +/// +/// In CQL, some types can have a special value of "empty", represented as +/// a serialized value of length 0. An example of this are integral types: +/// the "int" type can actually hold 2^32 + 1 possible values because of this +/// quirk. Note that this is distinct from being NULL. 
+/// +/// Rust types that cannot represent an empty value (e.g. i32) should implement +/// this trait in order to be deserialized as [MaybeEmpty]. +pub trait Emptiable {} + +/// A value that may be empty or not. +/// +/// `MaybeEmpty` was introduced to help support the quirk described in [Emptiable] +/// for Rust types which can't represent the empty, additional value. +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)] +pub enum MaybeEmpty { + Empty, + Value(T), +} + +impl<'frame, T> DeserializeValue<'frame> for MaybeEmpty +where + T: DeserializeValue<'frame> + Emptiable, +{ + #[inline] + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + >::type_check(typ) + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + let val = ensure_not_null_slice::(typ, v)?; + if val.is_empty() { + Ok(MaybeEmpty::Empty) + } else { + let v = >::deserialize(typ, v)?; + Ok(MaybeEmpty::Value(v)) + } + } +} + // Utilities fn ensure_not_null_frame_slice<'frame, T>( From b9465a5d453ae3e5a099c834fad949d84d294d24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Thu, 21 Mar 2024 14:46:29 +0100 Subject: [PATCH 12/41] value: impl DeserializeValue for fixed numeric types Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/value.rs | 198 +++++++++++++++++++++- 1 file changed, 195 insertions(+), 3 deletions(-) diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index 70edb3f061..83b170fa22 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -117,6 +117,88 @@ where } } +macro_rules! impl_strict_type { + ($t:ty, [$($cql:ident)|+], $conv:expr $(, $l:lifetime)?) => { + impl<$($l,)? 'frame> DeserializeValue<'frame> for $t + where + $('frame: $l)? 
+ { + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + // TODO: Format the CQL type names in the same notation + // that ScyllaDB/Cassandra uses internally and include them + // in such form in the error message + exact_type_check!(typ, $($cql),*); + Ok(()) + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + $conv(typ, v) + } + } + }; + + // Convenience pattern for omitting brackets if type-checking as single types. + ($t:ty, $cql:ident, $conv:expr $(, $l:lifetime)?) => { + impl_strict_type!($t, [$cql], $conv $(, $l)*); + }; +} + +macro_rules! impl_emptiable_strict_type { + ($t:ty, [$($cql:ident)|+], $conv:expr $(, $l:lifetime)?) => { + impl<$($l,)?> Emptiable for $t {} + + impl_strict_type!($t, [$($cql)|*], $conv $(, $l)*); + }; + + // Convenience pattern for omitting brackets if type-checking as single types. + ($t:ty, $cql:ident, $conv:expr $(, $l:lifetime)?) => { + impl_emptiable_strict_type!($t, [$cql], $conv $(, $l)*); + }; + +} + +// fixed numeric types + +macro_rules! impl_fixed_numeric_type { + ($t:ty, [$($cql:ident)|+]) => { + impl_emptiable_strict_type!( + $t, + [$($cql)|*], + |typ: &'frame ColumnType, v: Option>| { + const SIZE: usize = std::mem::size_of::<$t>(); + let val = ensure_not_null_slice::(typ, v)?; + let arr = ensure_exact_length::(typ, val)?; + Ok(<$t>::from_be_bytes(*arr)) + } + ); + }; + + // Convenience pattern for omitting brackets if type-checking as single types. 
+ ($t:ty, $cql:ident) => { + impl_fixed_numeric_type!($t, [$cql]); + }; +} + +impl_emptiable_strict_type!( + bool, + Boolean, + |typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null_slice::(typ, v)?; + let arr = ensure_exact_length::(typ, val)?; + Ok(arr[0] != 0x00) + } +); + +impl_fixed_numeric_type!(i8, TinyInt); +impl_fixed_numeric_type!(i16, SmallInt); +impl_fixed_numeric_type!(i32, Int); +impl_fixed_numeric_type!(i64, [BigInt | Counter]); +impl_fixed_numeric_type!(f32, Float); +impl_fixed_numeric_type!(f64, Double); + // Utilities fn ensure_not_null_frame_slice<'frame, T>( @@ -133,6 +215,21 @@ fn ensure_not_null_slice<'frame, T>( ensure_not_null_frame_slice::(typ, v).map(|frame_slice| frame_slice.as_slice()) } +fn ensure_exact_length<'frame, T, const SIZE: usize>( + typ: &ColumnType, + v: &'frame [u8], +) -> Result<&'frame [u8; SIZE], DeserializationError> { + v.try_into().map_err(|_| { + mk_deser_err::( + typ, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: SIZE, + got: v.len(), + }, + ) + }) +} + // Error facilities /// Type checking of one of the built-in types failed. @@ -186,11 +283,21 @@ use exact_type_check; /// Describes why type checking some of the built-in types failed. #[derive(Debug, Clone)] #[non_exhaustive] -pub enum BuiltinTypeCheckErrorKind {} +pub enum BuiltinTypeCheckErrorKind { + /// Expected one from a list of particular types. + MismatchedType { + /// The list of types that the Rust type can deserialize from. + expected: &'static [ColumnType], + }, +} impl Display for BuiltinTypeCheckErrorKind { - fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - Ok(()) + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BuiltinTypeCheckErrorKind::MismatchedType { expected } => { + write!(f, "expected one of the CQL types: {expected:?}") + } + } } } @@ -236,6 +343,9 @@ pub enum BuiltinDeserializationErrorKind { /// Expected non-null value, got null. 
ExpectedNonNull, + + /// The length of read value in bytes is different than expected for the Rust type. + ByteLengthMismatch { expected: usize, got: usize }, } impl Display for BuiltinDeserializationErrorKind { @@ -245,6 +355,11 @@ impl Display for BuiltinDeserializationErrorKind { BuiltinDeserializationErrorKind::ExpectedNonNull => { f.write_str("expected a non-null value, got null") } + BuiltinDeserializationErrorKind::ByteLengthMismatch { expected, got } => write!( + f, + "the CQL type requires {} bytes, but got {}", + expected, got, + ), } } } @@ -264,6 +379,83 @@ mod tests { use super::{mk_deser_err, BuiltinDeserializationErrorKind, DeserializeValue}; + #[test] + fn test_integral() { + let tinyint = make_bytes(&[0x01]); + let decoded_tinyint = deserialize::(&ColumnType::TinyInt, &tinyint).unwrap(); + assert_eq!(decoded_tinyint, 0x01); + + let smallint = make_bytes(&[0x01, 0x02]); + let decoded_smallint = deserialize::(&ColumnType::SmallInt, &smallint).unwrap(); + assert_eq!(decoded_smallint, 0x0102); + + let int = make_bytes(&[0x01, 0x02, 0x03, 0x04]); + let decoded_int = deserialize::(&ColumnType::Int, &int).unwrap(); + assert_eq!(decoded_int, 0x01020304); + + let bigint = make_bytes(&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]); + let decoded_bigint = deserialize::(&ColumnType::BigInt, &bigint).unwrap(); + assert_eq!(decoded_bigint, 0x0102030405060708); + } + + #[test] + fn test_bool() { + for boolean in [true, false] { + let boolean_bytes = make_bytes(&[boolean as u8]); + let decoded_bool = deserialize::(&ColumnType::Boolean, &boolean_bytes).unwrap(); + assert_eq!(decoded_bool, boolean); + } + } + + #[test] + fn test_floating_point() { + let float = make_bytes(&[63, 0, 0, 0]); + let decoded_float = deserialize::(&ColumnType::Float, &float).unwrap(); + assert_eq!(decoded_float, 0.5); + + let double = make_bytes(&[64, 0, 0, 0, 0, 0, 0, 0]); + let decoded_double = deserialize::(&ColumnType::Double, &double).unwrap(); + assert_eq!(decoded_double, 2.0); + 
} + + #[test] + fn test_from_cql_value_compatibility() { + // This test should have a sub-case for each type + // that implements FromCqlValue + + // fixed size integers + for i in 0..7 { + let v: i8 = 1 << i; + compat_check::(&ColumnType::TinyInt, make_bytes(&v.to_be_bytes())); + compat_check::(&ColumnType::TinyInt, make_bytes(&(-v).to_be_bytes())); + } + for i in 0..15 { + let v: i16 = 1 << i; + compat_check::(&ColumnType::SmallInt, make_bytes(&v.to_be_bytes())); + compat_check::(&ColumnType::SmallInt, make_bytes(&(-v).to_be_bytes())); + } + for i in 0..31 { + let v: i32 = 1 << i; + compat_check::(&ColumnType::Int, make_bytes(&v.to_be_bytes())); + compat_check::(&ColumnType::Int, make_bytes(&(-v).to_be_bytes())); + } + for i in 0..63 { + let v: i64 = 1 << i; + compat_check::(&ColumnType::BigInt, make_bytes(&v.to_be_bytes())); + compat_check::(&ColumnType::BigInt, make_bytes(&(-v).to_be_bytes())); + } + + // bool + compat_check::(&ColumnType::Boolean, make_bytes(&[0])); + compat_check::(&ColumnType::Boolean, make_bytes(&[1])); + + // fixed size floating point types + compat_check::(&ColumnType::Float, make_bytes(&123f32.to_be_bytes())); + compat_check::(&ColumnType::Float, make_bytes(&(-123f32).to_be_bytes())); + compat_check::(&ColumnType::Double, make_bytes(&123f64.to_be_bytes())); + compat_check::(&ColumnType::Double, make_bytes(&(-123f64).to_be_bytes())); + } + // Checks that both new and old serialization framework // produces the same results in this case fn compat_check(typ: &ColumnType, raw: Bytes) From 3cd5b698272a048cf9d39d0fa60fb79b0432e20e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Thu, 21 Mar 2024 14:47:18 +0100 Subject: [PATCH 13/41] value: impl DeserializeValue for variable length numeric types Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/value.rs | 143 ++++++++++++++++++++++ 1 file changed, 143 insertions(+) diff --git a/scylla-cql/src/types/deserialize/value.rs 
b/scylla-cql/src/types/deserialize/value.rs index 83b170fa22..7ee2134431 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -7,6 +7,8 @@ use thiserror::Error; use super::{DeserializationError, FrameSlice, TypeCheckError}; use crate::frame::frame_errors::ParseError; use crate::frame::response::result::{deser_cql_value, ColumnType, CqlValue}; +use crate::frame::types; +use crate::frame::value::{CqlDecimal, CqlVarint}; /// A type that can be deserialized from a column value inside a row that was /// returned from a query. @@ -199,6 +201,69 @@ impl_fixed_numeric_type!(i64, [BigInt | Counter]); impl_fixed_numeric_type!(f32, Float); impl_fixed_numeric_type!(f64, Double); +// other numeric types + +impl_emptiable_strict_type!( + CqlVarint, + Varint, + |typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null_slice::(typ, v)?; + Ok(CqlVarint::from_signed_bytes_be_slice(val)) + } +); + +#[cfg(feature = "num-bigint-03")] +impl_emptiable_strict_type!(num_bigint_03::BigInt, Varint, |typ: &'frame ColumnType, + v: Option< + FrameSlice<'frame>, +>| { + let val = ensure_not_null_slice::(typ, v)?; + Ok(num_bigint_03::BigInt::from_signed_bytes_be(val)) +}); + +#[cfg(feature = "num-bigint-04")] +impl_emptiable_strict_type!(num_bigint_04::BigInt, Varint, |typ: &'frame ColumnType, + v: Option< + FrameSlice<'frame>, +>| { + let val = ensure_not_null_slice::(typ, v)?; + Ok(num_bigint_04::BigInt::from_signed_bytes_be(val)) +}); + +impl_emptiable_strict_type!( + CqlDecimal, + Decimal, + |typ: &'frame ColumnType, v: Option>| { + let mut val = ensure_not_null_slice::(typ, v)?; + let scale = types::read_int(&mut val).map_err(|err| { + mk_deser_err::( + typ, + BuiltinDeserializationErrorKind::GenericParseError(err.into()), + ) + })?; + Ok(CqlDecimal::from_signed_be_bytes_slice_and_exponent( + val, scale, + )) + } +); + +#[cfg(feature = "bigdecimal-04")] +impl_emptiable_strict_type!( + bigdecimal_04::BigDecimal, + Decimal, + |typ: 
&'frame ColumnType, v: Option>| { + let mut val = ensure_not_null_slice::(typ, v)?; + let scale = types::read_int(&mut val).map_err(|err| { + mk_deser_err::( + typ, + BuiltinDeserializationErrorKind::GenericParseError(err.into()), + ) + })? as i64; + let int_value = bigdecimal_04::num_bigint::BigInt::from_signed_bytes_be(val); + Ok(bigdecimal_04::BigDecimal::from((int_value, scale))) + } +); + // Utilities fn ensure_not_null_frame_slice<'frame, T>( @@ -373,6 +438,7 @@ mod tests { use crate::frame::response::cql_to_rust::FromCqlVal; use crate::frame::response::result::{deser_cql_value, ColumnType, CqlValue}; use crate::frame::types; + use crate::frame::value::{CqlDecimal, CqlVarint}; use crate::types::deserialize::{DeserializationError, FrameSlice}; use crate::types::serialize::value::SerializeValue; use crate::types::serialize::CellWriter; @@ -454,6 +520,83 @@ mod tests { compat_check::(&ColumnType::Float, make_bytes(&(-123f32).to_be_bytes())); compat_check::(&ColumnType::Double, make_bytes(&123f64.to_be_bytes())); compat_check::(&ColumnType::Double, make_bytes(&(-123f64).to_be_bytes())); + + // big integers + const PI_STR: &[u8] = b"3.1415926535897932384626433832795028841971693993751058209749445923"; + let num1 = &PI_STR[2..]; + let num2 = [b'-'] + .into_iter() + .chain(PI_STR[2..].iter().copied()) + .collect::>(); + let num3 = &b"0"[..]; + + // native - CqlVarint + { + let num1 = CqlVarint::from_signed_bytes_be_slice(num1); + let num2 = CqlVarint::from_signed_bytes_be_slice(&num2); + let num3 = CqlVarint::from_signed_bytes_be_slice(num3); + compat_check_serialized::(&ColumnType::Varint, &num1); + compat_check_serialized::(&ColumnType::Varint, &num2); + compat_check_serialized::(&ColumnType::Varint, &num3); + } + + #[cfg(feature = "num-bigint-03")] + { + use num_bigint_03::BigInt; + + let num1 = BigInt::parse_bytes(num1, 10).unwrap(); + let num2 = BigInt::parse_bytes(&num2, 10).unwrap(); + let num3 = BigInt::parse_bytes(num3, 10).unwrap(); + 
compat_check_serialized::(&ColumnType::Varint, &num1); + compat_check_serialized::(&ColumnType::Varint, &num2); + compat_check_serialized::(&ColumnType::Varint, &num3); + } + + #[cfg(feature = "num-bigint-04")] + { + use num_bigint_04::BigInt; + + let num1 = BigInt::parse_bytes(num1, 10).unwrap(); + let num2 = BigInt::parse_bytes(&num2, 10).unwrap(); + let num3 = BigInt::parse_bytes(num3, 10).unwrap(); + compat_check_serialized::(&ColumnType::Varint, &num1); + compat_check_serialized::(&ColumnType::Varint, &num2); + compat_check_serialized::(&ColumnType::Varint, &num3); + } + + // big decimals + { + let scale1 = 0; + let scale2 = -42; + let scale3 = 2137; + let num1 = CqlDecimal::from_signed_be_bytes_slice_and_exponent(num1, scale1); + let num2 = CqlDecimal::from_signed_be_bytes_and_exponent(num2, scale2); + let num3 = CqlDecimal::from_signed_be_bytes_slice_and_exponent(num3, scale3); + compat_check_serialized::(&ColumnType::Decimal, &num1); + compat_check_serialized::(&ColumnType::Decimal, &num2); + compat_check_serialized::(&ColumnType::Decimal, &num3); + } + + // native - CqlDecimal + + #[cfg(feature = "bigdecimal-04")] + { + use bigdecimal_04::BigDecimal; + + let num1 = PI_STR.to_vec(); + let num2 = vec![b'-'] + .into_iter() + .chain(PI_STR.iter().copied()) + .collect::>(); + let num3 = b"0.0".to_vec(); + + let num1 = BigDecimal::parse_bytes(&num1, 10).unwrap(); + let num2 = BigDecimal::parse_bytes(&num2, 10).unwrap(); + let num3 = BigDecimal::parse_bytes(&num3, 10).unwrap(); + compat_check_serialized::(&ColumnType::Decimal, &num1); + compat_check_serialized::(&ColumnType::Decimal, &num2); + compat_check_serialized::(&ColumnType::Decimal, &num3); + } } // Checks that both new and old serialization framework From 0abe185841ac17b04d8eb99d9727e0503f468a92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Thu, 21 Mar 2024 15:31:24 +0100 Subject: [PATCH 14/41] value: impl DeserializeValue for Blob Co-authored-by: Piotr Dulikowski --- 
scylla-cql/src/types/deserialize/value.rs | 56 +++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index 7ee2134431..257dbf1ab4 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -1,5 +1,7 @@ //! Provides types for dealing with CQL value deserialization. +use bytes::Bytes; + use std::fmt::Display; use thiserror::Error; @@ -264,6 +266,34 @@ impl_emptiable_strict_type!( } ); +// blob + +impl_strict_type!( + &'a [u8], + Blob, + |typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null_slice::(typ, v)?; + Ok(val) + }, + 'a +); +impl_strict_type!( + Vec, + Blob, + |typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null_slice::(typ, v)?; + Ok(val.to_vec()) + } +); +impl_strict_type!( + Bytes, + Blob, + |typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null_owned::(typ, v)?; + Ok(val) + } +); + // Utilities fn ensure_not_null_frame_slice<'frame, T>( @@ -280,6 +310,13 @@ fn ensure_not_null_slice<'frame, T>( ensure_not_null_frame_slice::(typ, v).map(|frame_slice| frame_slice.as_slice()) } +fn ensure_not_null_owned( + typ: &ColumnType, + v: Option, +) -> Result { + ensure_not_null_frame_slice::(typ, v).map(|frame_slice| frame_slice.to_bytes()) +} + fn ensure_exact_length<'frame, T, const SIZE: usize>( typ: &ColumnType, v: &'frame [u8], @@ -445,6 +482,21 @@ mod tests { use super::{mk_deser_err, BuiltinDeserializationErrorKind, DeserializeValue}; + #[test] + fn test_deserialize_bytes() { + const ORIGINAL_BYTES: &[u8] = &[1, 5, 2, 4, 3]; + + let bytes = make_bytes(ORIGINAL_BYTES); + + let decoded_slice = deserialize::<&[u8]>(&ColumnType::Blob, &bytes).unwrap(); + let decoded_vec = deserialize::>(&ColumnType::Blob, &bytes).unwrap(); + let decoded_bytes = deserialize::(&ColumnType::Blob, &bytes).unwrap(); + + assert_eq!(decoded_slice, ORIGINAL_BYTES); + assert_eq!(decoded_vec, 
ORIGINAL_BYTES); + assert_eq!(decoded_bytes, ORIGINAL_BYTES); + } + #[test] fn test_integral() { let tinyint = make_bytes(&[0x01]); @@ -597,6 +649,10 @@ mod tests { compat_check_serialized::(&ColumnType::Decimal, &num2); compat_check_serialized::(&ColumnType::Decimal, &num3); } + + // blob + compat_check::>(&ColumnType::Blob, make_bytes(&[])); + compat_check::>(&ColumnType::Blob, make_bytes(&[1, 9, 2, 8, 3, 7, 4, 6, 5])); } // Checks that both new and old serialization framework From b759ad92c92cd7851346a86d819d6031bda1fc3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Thu, 21 Mar 2024 15:32:05 +0100 Subject: [PATCH 15/41] value: impl DeserializeValue for string Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/value.rs | 99 +++++++++++++++++++++++ 1 file changed, 99 insertions(+) diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index 257dbf1ab4..ca51fdc11a 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -294,6 +294,55 @@ impl_strict_type!( } ); +// string + +macro_rules! impl_string_type { + ($t:ty, $conv:expr $(, $l:lifetime)?) => { + impl_strict_type!( + $t, + [Ascii | Text], + $conv + $(, $l)? 
+ ); + } +} + +fn check_ascii(typ: &ColumnType, s: &[u8]) -> Result<(), DeserializationError> { + if matches!(typ, ColumnType::Ascii) && !s.is_ascii() { + return Err(mk_deser_err::( + typ, + BuiltinDeserializationErrorKind::ExpectedAscii, + )); + } + Ok(()) +} + +impl_string_type!( + &'a str, + |typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null_slice::(typ, v)?; + check_ascii::<&str>(typ, val)?; + let s = std::str::from_utf8(val).map_err(|err| { + mk_deser_err::(typ, BuiltinDeserializationErrorKind::InvalidUtf8(err)) + })?; + Ok(s) + }, + 'a +); +impl_string_type!( + String, + |typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null_slice::(typ, v)?; + check_ascii::(typ, val)?; + let s = std::str::from_utf8(val).map_err(|err| { + mk_deser_err::(typ, BuiltinDeserializationErrorKind::InvalidUtf8(err)) + })?; + Ok(s.to_string()) + } +); + +// TODO: Consider support for deserialization of string::String + // Utilities fn ensure_not_null_frame_slice<'frame, T>( @@ -448,6 +497,12 @@ pub enum BuiltinDeserializationErrorKind { /// The length of read value in bytes is different than expected for the Rust type. ByteLengthMismatch { expected: usize, got: usize }, + + /// Expected valid ASCII string. + ExpectedAscii, + + /// Invalid UTF-8 string. 
+ InvalidUtf8(std::str::Utf8Error), } impl Display for BuiltinDeserializationErrorKind { @@ -462,6 +517,10 @@ impl Display for BuiltinDeserializationErrorKind { "the CQL type requires {} bytes, but got {}", expected, got, ), + BuiltinDeserializationErrorKind::ExpectedAscii => { + f.write_str("expected a valid ASCII string") + } + BuiltinDeserializationErrorKind::InvalidUtf8(err) => err.fmt(f), } } } @@ -497,6 +556,39 @@ mod tests { assert_eq!(decoded_bytes, ORIGINAL_BYTES); } + #[test] + fn test_deserialize_ascii() { + const ASCII_TEXT: &str = "The quick brown fox jumps over the lazy dog"; + + let ascii = make_bytes(ASCII_TEXT.as_bytes()); + + let decoded_ascii_str = deserialize::<&str>(&ColumnType::Ascii, &ascii).unwrap(); + let decoded_ascii_string = deserialize::(&ColumnType::Ascii, &ascii).unwrap(); + let decoded_text_str = deserialize::<&str>(&ColumnType::Text, &ascii).unwrap(); + let decoded_text_string = deserialize::(&ColumnType::Text, &ascii).unwrap(); + + assert_eq!(decoded_ascii_str, ASCII_TEXT); + assert_eq!(decoded_ascii_string, ASCII_TEXT); + assert_eq!(decoded_text_str, ASCII_TEXT); + assert_eq!(decoded_text_string, ASCII_TEXT); + } + + #[test] + fn test_deserialize_text() { + const UNICODE_TEXT: &str = "Zażółć gęślą jaźń"; + + let unicode = make_bytes(UNICODE_TEXT.as_bytes()); + + // Should fail because it's not an ASCII string + deserialize::<&str>(&ColumnType::Ascii, &unicode).unwrap_err(); + deserialize::(&ColumnType::Ascii, &unicode).unwrap_err(); + + let decoded_text_str = deserialize::<&str>(&ColumnType::Text, &unicode).unwrap(); + let decoded_text_string = deserialize::(&ColumnType::Text, &unicode).unwrap(); + assert_eq!(decoded_text_str, UNICODE_TEXT); + assert_eq!(decoded_text_string, UNICODE_TEXT); + } + #[test] fn test_integral() { let tinyint = make_bytes(&[0x01]); @@ -653,6 +745,13 @@ mod tests { // blob compat_check::>(&ColumnType::Blob, make_bytes(&[])); compat_check::>(&ColumnType::Blob, make_bytes(&[1, 9, 2, 8, 3, 7, 4, 6, 5])); + + 
// text types + for typ in &[ColumnType::Ascii, ColumnType::Text] { + compat_check::(typ, make_bytes("".as_bytes())); + compat_check::(typ, make_bytes("foo".as_bytes())); + compat_check::(typ, make_bytes("superfragilisticexpialidocious".as_bytes())); + } } // Checks that both new and old serialization framework From f1d7097fe850c2f0437d92d3ac3f289f16988c08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Thu, 21 Mar 2024 15:32:31 +0100 Subject: [PATCH 16/41] value: impl DeserializeValue for Counter Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/value.rs | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index ca51fdc11a..581699c2fd 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -10,7 +10,7 @@ use super::{DeserializationError, FrameSlice, TypeCheckError}; use crate::frame::frame_errors::ParseError; use crate::frame::response::result::{deser_cql_value, ColumnType, CqlValue}; use crate::frame::types; -use crate::frame::value::{CqlDecimal, CqlVarint}; +use crate::frame::value::{Counter, CqlDecimal, CqlVarint}; /// A type that can be deserialized from a column value inside a row that was /// returned from a query. 
@@ -343,6 +343,19 @@ impl_string_type!( // TODO: Consider support for deserialization of string::String +// counter + +impl_strict_type!( + Counter, + Counter, + |typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null_slice::(typ, v)?; + let arr = ensure_exact_length::(typ, val)?; + let counter = i64::from_be_bytes(*arr); + Ok(Counter(counter)) + } +); + // Utilities fn ensure_not_null_frame_slice<'frame, T>( @@ -534,7 +547,7 @@ mod tests { use crate::frame::response::cql_to_rust::FromCqlVal; use crate::frame::response::result::{deser_cql_value, ColumnType, CqlValue}; use crate::frame::types; - use crate::frame::value::{CqlDecimal, CqlVarint}; + use crate::frame::value::{Counter, CqlDecimal, CqlVarint}; use crate::types::deserialize::{DeserializationError, FrameSlice}; use crate::types::serialize::value::SerializeValue; use crate::types::serialize::CellWriter; @@ -752,6 +765,12 @@ mod tests { compat_check::(typ, make_bytes("foo".as_bytes())); compat_check::(typ, make_bytes("superfragilisticexpialidocious".as_bytes())); } + + // counters + for i in 0..63 { + let v: i64 = 1 << i; + compat_check::(&ColumnType::Counter, make_bytes(&v.to_be_bytes())); + } } // Checks that both new and old serialization framework From d76cd29f12ff51cde05acce43007a1fee33547c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Sat, 23 Mar 2024 15:03:10 +0100 Subject: [PATCH 17/41] value: impl DeserializeValue for duration types Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/value.rs | 72 ++++++++++++++++++++++- 1 file changed, 70 insertions(+), 2 deletions(-) diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index 581699c2fd..e22cad4023 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -10,7 +10,7 @@ use super::{DeserializationError, FrameSlice, TypeCheckError}; use crate::frame::frame_errors::ParseError; use 
crate::frame::response::result::{deser_cql_value, ColumnType, CqlValue}; use crate::frame::types; -use crate::frame::value::{Counter, CqlDecimal, CqlVarint}; +use crate::frame::value::{Counter, CqlDecimal, CqlDuration, CqlVarint}; /// A type that can be deserialized from a column value inside a row that was /// returned from a query. @@ -356,6 +356,51 @@ impl_strict_type!( } ); +// date and time types + +// duration +impl_strict_type!( + CqlDuration, + Duration, + |typ: &'frame ColumnType, v: Option>| { + let mut val = ensure_not_null_slice::(typ, v)?; + + macro_rules! mk_err { + ($err: expr) => { + mk_deser_err::(typ, $err) + }; + } + + let months_i64 = types::vint_decode(&mut val).map_err(|err| { + mk_err!(BuiltinDeserializationErrorKind::GenericParseError( + err.into() + )) + })?; + let months = i32::try_from(months_i64) + .map_err(|_| mk_err!(BuiltinDeserializationErrorKind::ValueOverflow))?; + + let days_i64 = types::vint_decode(&mut val).map_err(|err| { + mk_err!(BuiltinDeserializationErrorKind::GenericParseError( + err.into() + )) + })?; + let days = i32::try_from(days_i64) + .map_err(|_| mk_err!(BuiltinDeserializationErrorKind::ValueOverflow))?; + + let nanoseconds = types::vint_decode(&mut val).map_err(|err| { + mk_err!(BuiltinDeserializationErrorKind::GenericParseError( + err.into() + )) + })?; + + Ok(CqlDuration { + months, + days, + nanoseconds, + }) + } +); + // Utilities fn ensure_not_null_frame_slice<'frame, T>( @@ -516,6 +561,10 @@ pub enum BuiltinDeserializationErrorKind { /// Invalid UTF-8 string. InvalidUtf8(std::str::Utf8Error), + + /// The read value is out of range supported by the Rust type. 
+ // TODO: consider storing additional info here (what exactly did not fit and why) + ValueOverflow, } impl Display for BuiltinDeserializationErrorKind { @@ -534,6 +583,11 @@ impl Display for BuiltinDeserializationErrorKind { f.write_str("expected a valid ASCII string") } BuiltinDeserializationErrorKind::InvalidUtf8(err) => err.fmt(f), + BuiltinDeserializationErrorKind::ValueOverflow => { + // TODO: consider storing Arc of the offending value + // inside this variant for debug purposes. + f.write_str("read value is out of representable range") + } } } } @@ -547,7 +601,7 @@ mod tests { use crate::frame::response::cql_to_rust::FromCqlVal; use crate::frame::response::result::{deser_cql_value, ColumnType, CqlValue}; use crate::frame::types; - use crate::frame::value::{Counter, CqlDecimal, CqlVarint}; + use crate::frame::value::{Counter, CqlDecimal, CqlDuration, CqlVarint}; use crate::types::deserialize::{DeserializationError, FrameSlice}; use crate::types::serialize::value::SerializeValue; use crate::types::serialize::CellWriter; @@ -771,6 +825,20 @@ mod tests { let v: i64 = 1 << i; compat_check::(&ColumnType::Counter, make_bytes(&v.to_be_bytes())); } + + // duration + let duration1 = CqlDuration { + days: 123, + months: 456, + nanoseconds: 789, + }; + let duration2 = CqlDuration { + days: 987, + months: 654, + nanoseconds: 321, + }; + compat_check_serialized::(&ColumnType::Duration, &duration1); + compat_check_serialized::(&ColumnType::Duration, &duration2); } // Checks that both new and old serialization framework From f1ff504228bf6e459f3d047a09ef25654386712a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Thu, 21 Mar 2024 15:33:31 +0100 Subject: [PATCH 18/41] value: impl DeserializeValue for date types Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/value.rs | 82 ++++++++++++++++++++++- 1 file changed, 80 insertions(+), 2 deletions(-) diff --git a/scylla-cql/src/types/deserialize/value.rs 
b/scylla-cql/src/types/deserialize/value.rs index e22cad4023..28b815eb8f 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -10,7 +10,7 @@ use super::{DeserializationError, FrameSlice, TypeCheckError}; use crate::frame::frame_errors::ParseError; use crate::frame::response::result::{deser_cql_value, ColumnType, CqlValue}; use crate::frame::types; -use crate::frame::value::{Counter, CqlDecimal, CqlDuration, CqlVarint}; +use crate::frame::value::{Counter, CqlDate, CqlDecimal, CqlDuration, CqlVarint}; /// A type that can be deserialized from a column value inside a row that was /// returned from a query. @@ -401,6 +401,61 @@ impl_strict_type!( } ); +impl_emptiable_strict_type!( + CqlDate, + Date, + |typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null_slice::(typ, v)?; + let arr = ensure_exact_length::(typ, val)?; + let days = u32::from_be_bytes(*arr); + Ok(CqlDate(days)) + } +); + +#[cfg(any(feature = "chrono", feature = "time"))] +fn get_days_since_epoch_from_date_column( + typ: &ColumnType, + v: Option>, +) -> Result { + let val = ensure_not_null_slice::(typ, v)?; + let arr = ensure_exact_length::(typ, val)?; + let days = u32::from_be_bytes(*arr); + let days_since_epoch = days as i64 - (1i64 << 31); + Ok(days_since_epoch) +} + +#[cfg(feature = "chrono")] +impl_emptiable_strict_type!( + chrono::NaiveDate, + Date, + |typ: &'frame ColumnType, v: Option>| { + let fail = || mk_deser_err::(typ, BuiltinDeserializationErrorKind::ValueOverflow); + let days_since_epoch = + chrono::Duration::try_days(get_days_since_epoch_from_date_column::(typ, v)?) 
+ .ok_or_else(fail)?; + chrono::NaiveDate::from_ymd_opt(1970, 1, 1) + .unwrap() + .checked_add_signed(days_since_epoch) + .ok_or_else(fail) + } +); + +#[cfg(feature = "time")] +impl_emptiable_strict_type!( + time::Date, + Date, + |typ: &'frame ColumnType, v: Option>| { + let days_since_epoch = + time::Duration::days(get_days_since_epoch_from_date_column::(typ, v)?); + time::Date::from_calendar_date(1970, time::Month::January, 1) + .unwrap() + .checked_add(days_since_epoch) + .ok_or_else(|| { + mk_deser_err::(typ, BuiltinDeserializationErrorKind::ValueOverflow) + }) + } +); + // Utilities fn ensure_not_null_frame_slice<'frame, T>( @@ -601,7 +656,7 @@ mod tests { use crate::frame::response::cql_to_rust::FromCqlVal; use crate::frame::response::result::{deser_cql_value, ColumnType, CqlValue}; use crate::frame::types; - use crate::frame::value::{Counter, CqlDecimal, CqlDuration, CqlVarint}; + use crate::frame::value::{Counter, CqlDate, CqlDecimal, CqlDuration, CqlVarint}; use crate::types::deserialize::{DeserializationError, FrameSlice}; use crate::types::serialize::value::SerializeValue; use crate::types::serialize::CellWriter; @@ -839,6 +894,29 @@ mod tests { }; compat_check_serialized::(&ColumnType::Duration, &duration1); compat_check_serialized::(&ColumnType::Duration, &duration2); + + // date + let date1 = (2u32.pow(31)).to_be_bytes(); + let date2 = (2u32.pow(31) - 30).to_be_bytes(); + let date3 = (2u32.pow(31) + 30).to_be_bytes(); + + compat_check::(&ColumnType::Date, make_bytes(&date1)); + compat_check::(&ColumnType::Date, make_bytes(&date2)); + compat_check::(&ColumnType::Date, make_bytes(&date3)); + + #[cfg(feature = "chrono")] + { + compat_check::(&ColumnType::Date, make_bytes(&date1)); + compat_check::(&ColumnType::Date, make_bytes(&date2)); + compat_check::(&ColumnType::Date, make_bytes(&date3)); + } + + #[cfg(feature = "time")] + { + compat_check::(&ColumnType::Date, make_bytes(&date1)); + compat_check::(&ColumnType::Date, make_bytes(&date2)); + 
compat_check::(&ColumnType::Date, make_bytes(&date3)); + } } // Checks that both new and old serialization framework From f49adc32f865c5514bc5f4f98a2ce377baa2ae9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Sat, 23 Mar 2024 15:00:39 +0100 Subject: [PATCH 19/41] value: impl DeserializeValue for time types Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/value.rs | 84 ++++++++++++++++++++++- 1 file changed, 82 insertions(+), 2 deletions(-) diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index 28b815eb8f..d9375bafb6 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -10,7 +10,7 @@ use super::{DeserializationError, FrameSlice, TypeCheckError}; use crate::frame::frame_errors::ParseError; use crate::frame::response::result::{deser_cql_value, ColumnType, CqlValue}; use crate::frame::types; -use crate::frame::value::{Counter, CqlDate, CqlDecimal, CqlDuration, CqlVarint}; +use crate::frame::value::{Counter, CqlDate, CqlDecimal, CqlDuration, CqlTime, CqlVarint}; /// A type that can be deserialized from a column value inside a row that was /// returned from a query. 
@@ -456,6 +456,63 @@ impl_emptiable_strict_type!( } ); +fn get_nanos_from_time_column( + typ: &ColumnType, + v: Option>, +) -> Result { + let val = ensure_not_null_slice::(typ, v)?; + let arr = ensure_exact_length::(typ, val)?; + let nanoseconds = i64::from_be_bytes(*arr); + + // Valid values are in the range 0 to 86399999999999 + if !(0..=86399999999999).contains(&nanoseconds) { + return Err(mk_deser_err::( + typ, + BuiltinDeserializationErrorKind::ValueOverflow, + )); + } + + Ok(nanoseconds) +} + +impl_emptiable_strict_type!( + CqlTime, + Time, + |typ: &'frame ColumnType, v: Option>| { + let nanoseconds = get_nanos_from_time_column::(typ, v)?; + + Ok(CqlTime(nanoseconds)) + } +); + +#[cfg(feature = "chrono")] +impl_emptiable_strict_type!( + chrono::NaiveTime, + Time, + |typ: &'frame ColumnType, v: Option>| { + let nanoseconds = get_nanos_from_time_column::(typ, v)?; + + let naive_time: chrono::NaiveTime = CqlTime(nanoseconds).try_into().map_err(|_| { + mk_deser_err::(typ, BuiltinDeserializationErrorKind::ValueOverflow) + })?; + Ok(naive_time) + } +); + +#[cfg(feature = "time")] +impl_emptiable_strict_type!( + time::Time, + Time, + |typ: &'frame ColumnType, v: Option>| { + let nanoseconds = get_nanos_from_time_column::(typ, v)?; + + let time: time::Time = CqlTime(nanoseconds).try_into().map_err(|_| { + mk_deser_err::(typ, BuiltinDeserializationErrorKind::ValueOverflow) + })?; + Ok(time) + } +); + // Utilities fn ensure_not_null_frame_slice<'frame, T>( @@ -656,7 +713,7 @@ mod tests { use crate::frame::response::cql_to_rust::FromCqlVal; use crate::frame::response::result::{deser_cql_value, ColumnType, CqlValue}; use crate::frame::types; - use crate::frame::value::{Counter, CqlDate, CqlDecimal, CqlDuration, CqlVarint}; + use crate::frame::value::{Counter, CqlDate, CqlDecimal, CqlDuration, CqlTime, CqlVarint}; use crate::types::deserialize::{DeserializationError, FrameSlice}; use crate::types::serialize::value::SerializeValue; use crate::types::serialize::CellWriter; 
@@ -917,6 +974,29 @@ mod tests { compat_check::(&ColumnType::Date, make_bytes(&date2)); compat_check::(&ColumnType::Date, make_bytes(&date3)); } + + // time + let time1 = CqlTime(0); + let time2 = CqlTime(123456789); + let time3 = CqlTime(86399999999999); // maximum allowed + + compat_check_serialized::(&ColumnType::Time, &time1); + compat_check_serialized::(&ColumnType::Time, &time2); + compat_check_serialized::(&ColumnType::Time, &time3); + + #[cfg(feature = "chrono")] + { + compat_check_serialized::(&ColumnType::Time, &time1); + compat_check_serialized::(&ColumnType::Time, &time2); + compat_check_serialized::(&ColumnType::Time, &time3); + } + + #[cfg(feature = "time")] + { + compat_check_serialized::(&ColumnType::Time, &time1); + compat_check_serialized::(&ColumnType::Time, &time2); + compat_check_serialized::(&ColumnType::Time, &time3); + } } // Checks that both new and old serialization framework From cf96543ba721e679d9825d3db516f9981498668a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Sat, 23 Mar 2024 14:59:01 +0100 Subject: [PATCH 20/41] value: impl DeserializeValue for timestamp types Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/value.rs | 89 ++++++++++++++++++++++- 1 file changed, 87 insertions(+), 2 deletions(-) diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index d9375bafb6..6754d7511c 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -10,7 +10,9 @@ use super::{DeserializationError, FrameSlice, TypeCheckError}; use crate::frame::frame_errors::ParseError; use crate::frame::response::result::{deser_cql_value, ColumnType, CqlValue}; use crate::frame::types; -use crate::frame::value::{Counter, CqlDate, CqlDecimal, CqlDuration, CqlTime, CqlVarint}; +use crate::frame::value::{ + Counter, CqlDate, CqlDecimal, CqlDuration, CqlTime, CqlTimestamp, CqlVarint, +}; /// A type that can be deserialized from a 
column value inside a row that was /// returned from a query. @@ -513,6 +515,55 @@ impl_emptiable_strict_type!( } ); +fn get_millis_from_timestamp_column( + typ: &ColumnType, + v: Option>, +) -> Result { + let val = ensure_not_null_slice::(typ, v)?; + let arr = ensure_exact_length::(typ, val)?; + let millis = i64::from_be_bytes(*arr); + + Ok(millis) +} + +impl_emptiable_strict_type!( + CqlTimestamp, + Timestamp, + |typ: &'frame ColumnType, v: Option>| { + let millis = get_millis_from_timestamp_column::(typ, v)?; + Ok(CqlTimestamp(millis)) + } +); + +#[cfg(feature = "chrono")] +impl_emptiable_strict_type!( + chrono::DateTime, + Timestamp, + |typ: &'frame ColumnType, v: Option>| { + use chrono::TimeZone as _; + + let millis = get_millis_from_timestamp_column::(typ, v)?; + match chrono::Utc.timestamp_millis_opt(millis) { + chrono::LocalResult::Single(datetime) => Ok(datetime), + _ => Err(mk_deser_err::( + typ, + BuiltinDeserializationErrorKind::ValueOverflow, + )), + } + } +); + +#[cfg(feature = "time")] +impl_emptiable_strict_type!( + time::OffsetDateTime, + Timestamp, + |typ: &'frame ColumnType, v: Option>| { + let millis = get_millis_from_timestamp_column::(typ, v)?; + time::OffsetDateTime::from_unix_timestamp_nanos(millis as i128 * 1_000_000) + .map_err(|_| mk_deser_err::(typ, BuiltinDeserializationErrorKind::ValueOverflow)) + } +); + // Utilities fn ensure_not_null_frame_slice<'frame, T>( @@ -713,7 +764,9 @@ mod tests { use crate::frame::response::cql_to_rust::FromCqlVal; use crate::frame::response::result::{deser_cql_value, ColumnType, CqlValue}; use crate::frame::types; - use crate::frame::value::{Counter, CqlDate, CqlDecimal, CqlDuration, CqlTime, CqlVarint}; + use crate::frame::value::{ + Counter, CqlDate, CqlDecimal, CqlDuration, CqlTime, CqlTimestamp, CqlVarint, + }; use crate::types::deserialize::{DeserializationError, FrameSlice}; use crate::types::serialize::value::SerializeValue; use crate::types::serialize::CellWriter; @@ -997,6 +1050,38 @@ mod tests { 
compat_check_serialized::(&ColumnType::Time, &time2); compat_check_serialized::(&ColumnType::Time, &time3); } + + // timestamp + let timestamp1 = CqlTimestamp(0); + let timestamp2 = CqlTimestamp(123456789); + let timestamp3 = CqlTimestamp(98765432123456); + + compat_check_serialized::(&ColumnType::Timestamp, &timestamp1); + compat_check_serialized::(&ColumnType::Timestamp, &timestamp2); + compat_check_serialized::(&ColumnType::Timestamp, &timestamp3); + + #[cfg(feature = "chrono")] + { + compat_check_serialized::>( + &ColumnType::Timestamp, + &timestamp1, + ); + compat_check_serialized::>( + &ColumnType::Timestamp, + &timestamp2, + ); + compat_check_serialized::>( + &ColumnType::Timestamp, + &timestamp3, + ); + } + + #[cfg(feature = "time")] + { + compat_check_serialized::(&ColumnType::Timestamp, &timestamp1); + compat_check_serialized::(&ColumnType::Timestamp, &timestamp2); + compat_check_serialized::(&ColumnType::Timestamp, &timestamp3); + } } // Checks that both new and old serialization framework From 9b15e238299ba45c50552f929505543b32574024 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Thu, 21 Mar 2024 15:34:06 +0100 Subject: [PATCH 21/41] value: impl DeserializeValue for inet Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/value.rs | 43 +++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index 6754d7511c..ae1a2ad864 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -1,5 +1,7 @@ //! Provides types for dealing with CQL value deserialization. 
+use std::net::IpAddr; + use bytes::Bytes; use std::fmt::Display; @@ -564,6 +566,26 @@ impl_emptiable_strict_type!( } ); +// inet + +impl_emptiable_strict_type!( + IpAddr, + Inet, + |typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null_slice::(typ, v)?; + if let Ok(ipv4) = <[u8; 4]>::try_from(val) { + Ok(IpAddr::from(ipv4)) + } else if let Ok(ipv6) = <[u8; 16]>::try_from(val) { + Ok(IpAddr::from(ipv6)) + } else { + Err(mk_deser_err::( + typ, + BuiltinDeserializationErrorKind::BadInetLength { got: val.len() }, + )) + } + } +); + // Utilities fn ensure_not_null_frame_slice<'frame, T>( @@ -728,6 +750,9 @@ pub enum BuiltinDeserializationErrorKind { /// The read value is out of range supported by the Rust type. // TODO: consider storing additional info here (what exactly did not fit and why) ValueOverflow, + + /// The length of read value in bytes is not suitable for IP address. + BadInetLength { got: usize }, } impl Display for BuiltinDeserializationErrorKind { @@ -751,6 +776,10 @@ impl Display for BuiltinDeserializationErrorKind { // inside this variant for debug purposes. 
f.write_str("read value is out of representable range") } + BuiltinDeserializationErrorKind::BadInetLength { got } => write!( + f, + "the length of read value in bytes ({got}) is not suitable for IP address; expected 4 or 16" + ), } } } @@ -760,6 +789,7 @@ mod tests { use bytes::{BufMut, Bytes, BytesMut}; use std::fmt::Debug; + use std::net::{IpAddr, Ipv6Addr}; use crate::frame::response::cql_to_rust::FromCqlVal; use crate::frame::response::result::{deser_cql_value, ColumnType, CqlValue}; @@ -1082,6 +1112,12 @@ mod tests { compat_check_serialized::(&ColumnType::Timestamp, &timestamp2); compat_check_serialized::(&ColumnType::Timestamp, &timestamp3); } + + // inet + let ipv4 = IpAddr::from([127u8, 0, 0, 1]); + let ipv6: IpAddr = Ipv6Addr::LOCALHOST.into(); + compat_check::(&ColumnType::Inet, make_ip_address(ipv4)); + compat_check::(&ColumnType::Inet, make_ip_address(ipv6)); } // Checks that both new and old serialization framework @@ -1150,6 +1186,13 @@ mod tests { *buf = v.into(); } + fn make_ip_address(ip: IpAddr) -> Bytes { + match ip { + IpAddr::V4(v4) => make_bytes(&v4.octets()), + IpAddr::V6(v6) => make_bytes(&v6.octets()), + } + } + fn append_bytes(b: &mut impl BufMut, cell: &[u8]) { b.put_i32(cell.len() as i32); b.put_slice(cell); From 6738ac9a1530615b203d195c1d5e120301a07972 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Wed, 8 May 2024 11:37:12 +0200 Subject: [PATCH 22/41] value: impl DeserializeValue for Uuid and Timeuuid Co-authored-by: Piotr Dulikowski --- scylla-cql/Cargo.toml | 1 + scylla-cql/src/types/deserialize/value.rs | 37 +++++++++++++++++++++-- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/scylla-cql/Cargo.toml b/scylla-cql/Cargo.toml index 042564fd25..547404cae7 100644 --- a/scylla-cql/Cargo.toml +++ b/scylla-cql/Cargo.toml @@ -33,6 +33,7 @@ assert_matches = "1.5.0" criterion = "0.4" # Note: v0.5 needs at least rust 1.70.0 # Use large-dates feature to test potential edge cases time = { version = "0.3.21", 
features = ["large-dates"] } +uuid = { version = "1.0", features = ["v4"] } [[bench]] name = "benchmark" diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index ae1a2ad864..8971ebbe79 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -3,6 +3,7 @@ use std::net::IpAddr; use bytes::Bytes; +use uuid::Uuid; use std::fmt::Display; @@ -13,7 +14,7 @@ use crate::frame::frame_errors::ParseError; use crate::frame::response::result::{deser_cql_value, ColumnType, CqlValue}; use crate::frame::types; use crate::frame::value::{ - Counter, CqlDate, CqlDecimal, CqlDuration, CqlTime, CqlTimestamp, CqlVarint, + Counter, CqlDate, CqlDecimal, CqlDuration, CqlTime, CqlTimestamp, CqlTimeuuid, CqlVarint, }; /// A type that can be deserialized from a column value inside a row that was @@ -586,6 +587,30 @@ impl_emptiable_strict_type!( } ); +// uuid + +impl_emptiable_strict_type!( + Uuid, + Uuid, + |typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null_slice::(typ, v)?; + let arr = ensure_exact_length::(typ, val)?; + let i = u128::from_be_bytes(*arr); + Ok(uuid::Uuid::from_u128(i)) + } +); + +impl_emptiable_strict_type!( + CqlTimeuuid, + Timeuuid, + |typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null_slice::(typ, v)?; + let arr = ensure_exact_length::(typ, val)?; + let i = u128::from_be_bytes(*arr); + Ok(CqlTimeuuid::from(uuid::Uuid::from_u128(i))) + } +); + // Utilities fn ensure_not_null_frame_slice<'frame, T>( @@ -787,6 +812,7 @@ impl Display for BuiltinDeserializationErrorKind { #[cfg(test)] mod tests { use bytes::{BufMut, Bytes, BytesMut}; + use uuid::Uuid; use std::fmt::Debug; use std::net::{IpAddr, Ipv6Addr}; @@ -795,7 +821,7 @@ mod tests { use crate::frame::response::result::{deser_cql_value, ColumnType, CqlValue}; use crate::frame::types; use crate::frame::value::{ - Counter, CqlDate, CqlDecimal, CqlDuration, CqlTime, CqlTimestamp, CqlVarint, + 
Counter, CqlDate, CqlDecimal, CqlDuration, CqlTime, CqlTimestamp, CqlTimeuuid, CqlVarint, }; use crate::types::deserialize::{DeserializationError, FrameSlice}; use crate::types::serialize::value::SerializeValue; @@ -1118,6 +1144,13 @@ mod tests { let ipv6: IpAddr = Ipv6Addr::LOCALHOST.into(); compat_check::(&ColumnType::Inet, make_ip_address(ipv4)); compat_check::(&ColumnType::Inet, make_ip_address(ipv6)); + + // uuid and timeuuid + // new_v4 generates random UUIDs, so these are different cases + for uuid in std::iter::repeat_with(Uuid::new_v4).take(3) { + compat_check_serialized::(&ColumnType::Uuid, &uuid); + compat_check_serialized::(&ColumnType::Timeuuid, &CqlTimeuuid::from(uuid)); + } } // Checks that both new and old serialization framework From 4f8364327447821256d5cb5b3b28e2ff1ae5322a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Thu, 21 Mar 2024 15:36:10 +0100 Subject: [PATCH 23/41] value: impl DeserializeValue for Secrecy Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/value.rs | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index 8971ebbe79..796b2fec8c 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -611,6 +611,24 @@ impl_emptiable_strict_type!( } ); +// secrecy +#[cfg(feature = "secret")] +impl<'frame, T> DeserializeValue<'frame> for secrecy::Secret +where + T: DeserializeValue<'frame> + secrecy::Zeroize, +{ + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + >::type_check(typ) + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + >::deserialize(typ, v).map(secrecy::Secret::new) + } +} + // Utilities fn ensure_not_null_frame_slice<'frame, T>( From bd6196ccec792a7d2e3c549b45a96f310144de8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Fri, 10 May 2024 07:19:26 +0200 
Subject: [PATCH 24/41] value: Option/null tests Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/value.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index 796b2fec8c..84712fe946 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -1169,6 +1169,13 @@ mod tests { compat_check_serialized::(&ColumnType::Uuid, &uuid); compat_check_serialized::(&ColumnType::Timeuuid, &CqlTimeuuid::from(uuid)); } + + // empty values + // ...are implemented via MaybeEmpty and are handled in other tests + + // nulls, represented via Option + compat_check_serialized::>(&ColumnType::Int, &123i32); + compat_check::>(&ColumnType::Int, make_null()); } // Checks that both new and old serialization framework @@ -1248,4 +1255,14 @@ mod tests { b.put_i32(cell.len() as i32); b.put_slice(cell); } + + fn make_null() -> Bytes { + let mut b = BytesMut::new(); + append_null(&mut b); + b.freeze() + } + + fn append_null(b: &mut impl BufMut) { + b.put_i32(-1); + } } From ebbac62ac01c0a2ace37ae8aebce106c530e582c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Wed, 8 May 2024 09:28:25 +0200 Subject: [PATCH 25/41] value: test null and empty --- scylla-cql/src/types/deserialize/value.rs | 30 ++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index 84712fe946..18930e10c7 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -845,7 +845,7 @@ mod tests { use crate::types::serialize::value::SerializeValue; use crate::types::serialize::CellWriter; - use super::{mk_deser_err, BuiltinDeserializationErrorKind, DeserializeValue}; + use super::{mk_deser_err, BuiltinDeserializationErrorKind, DeserializeValue, MaybeEmpty}; #[test] fn 
test_deserialize_bytes() { @@ -934,6 +934,34 @@ mod tests { assert_eq!(decoded_double, 2.0); } + #[test] + fn test_null_and_empty() { + // non-nullable emptiable deserialization, non-empty value + let int = make_bytes(&[21, 37, 0, 0]); + let decoded_int = deserialize::>(&ColumnType::Int, &int).unwrap(); + assert_eq!(decoded_int, MaybeEmpty::Value((21 << 24) + (37 << 16))); + + // non-nullable emptiable deserialization, empty value + let int = make_bytes(&[]); + let decoded_int = deserialize::>(&ColumnType::Int, &int).unwrap(); + assert_eq!(decoded_int, MaybeEmpty::Empty); + + // nullable non-emptiable deserialization, non-null value + let int = make_bytes(&[21, 37, 0, 0]); + let decoded_int = deserialize::>(&ColumnType::Int, &int).unwrap(); + assert_eq!(decoded_int, Some((21 << 24) + (37 << 16))); + + // nullable non-emptiable deserialization, null value + let int = make_null(); + let decoded_int = deserialize::>(&ColumnType::Int, &int).unwrap(); + assert_eq!(decoded_int, None); + + // nullable emptiable deserialization, non-null empty value + let int = make_bytes(&[]); + let decoded_int = deserialize::>>(&ColumnType::Int, &int).unwrap(); + assert_eq!(decoded_int, Some(MaybeEmpty::Empty)); + } + #[test] fn test_from_cql_value_compatibility() { // This test should have a sub-case for each type From e13bfa941d3418505d02721922044ff349fe1c89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Thu, 21 Mar 2024 16:00:16 +0100 Subject: [PATCH 26/41] value: impl DeserializeValue for List and Set There is a purposeful change from the previous framework: CQL List can no longer be deserialized straight to a HashSet or a BTreeSet. Such deserialization would be lossy, which is a property we don't want in our framework. 
Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/mod.rs | 30 ++ scylla-cql/src/types/deserialize/value.rs | 398 +++++++++++++++++++++- 2 files changed, 425 insertions(+), 3 deletions(-) diff --git a/scylla-cql/src/types/deserialize/mod.rs b/scylla-cql/src/types/deserialize/mod.rs index 70b9f7f7c4..f148f23484 100644 --- a/scylla-cql/src/types/deserialize/mod.rs +++ b/scylla-cql/src/types/deserialize/mod.rs @@ -234,6 +234,36 @@ impl Display for DeserializationError { } } +// This is a hack to enable setting the proper Rust type name in error messages, +// even though the error originates from some helper type used underneath. +// ASSUMPTION: This should be used: +// - ONLY in proper type_check()/deserialize() implementation, +// - BEFORE an error is cloned (because otherwise the Arc::get_mut fails). +macro_rules! make_error_replace_rust_name { + ($fn_name: ident, $outer_err: ty, $inner_err: ty) => { + fn $fn_name(mut err: $outer_err) -> $outer_err { + // Safety: the assumed usage of this function guarantees that the Arc has not yet been cloned. + let arc_mut = std::sync::Arc::get_mut(&mut err.0).unwrap(); + + let rust_name: &mut &str = { + if let Some(err) = arc_mut.downcast_mut::<$inner_err>() { + &mut err.rust_name + } else { + unreachable!(concat!( + "This function is assumed to be called only on built-in ", + stringify!($inner_err), + " kinds." + )) + } + }; + + *rust_name = std::any::type_name::(); + err + } + }; +} +use make_error_replace_rust_name; + #[cfg(test)] mod tests { use bytes::{Bytes, BytesMut}; diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index 18930e10c7..1808843fe8 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -1,6 +1,10 @@ //! Provides types for dealing with CQL value deserialization. 
-use std::net::IpAddr; +use std::{ + collections::{BTreeSet, HashSet}, + hash::{BuildHasher, Hash}, + net::IpAddr, +}; use bytes::Bytes; use uuid::Uuid; @@ -9,7 +13,7 @@ use std::fmt::Display; use thiserror::Error; -use super::{DeserializationError, FrameSlice, TypeCheckError}; +use super::{make_error_replace_rust_name, DeserializationError, FrameSlice, TypeCheckError}; use crate::frame::frame_errors::ParseError; use crate::frame::response::result::{deser_cql_value, ColumnType, CqlValue}; use crate::frame::types; @@ -629,6 +633,196 @@ where } } +// collections + +make_error_replace_rust_name!( + typck_error_replace_rust_name, + TypeCheckError, + BuiltinTypeCheckError +); + +make_error_replace_rust_name!( + deser_error_replace_rust_name, + DeserializationError, + BuiltinDeserializationError +); + +// lists and sets + +/// An iterator over either a CQL set or list. +pub struct ListlikeIterator<'frame, T> { + coll_typ: &'frame ColumnType, + elem_typ: &'frame ColumnType, + raw_iter: FixedLengthBytesSequenceIterator<'frame>, + phantom_data: std::marker::PhantomData, +} + +impl<'frame, T> ListlikeIterator<'frame, T> { + fn new( + coll_typ: &'frame ColumnType, + elem_typ: &'frame ColumnType, + count: usize, + slice: FrameSlice<'frame>, + ) -> Self { + Self { + coll_typ, + elem_typ, + raw_iter: FixedLengthBytesSequenceIterator::new(count, slice), + phantom_data: std::marker::PhantomData, + } + } +} + +impl<'frame, T> DeserializeValue<'frame> for ListlikeIterator<'frame, T> +where + T: DeserializeValue<'frame>, +{ + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + match typ { + ColumnType::List(el_t) | ColumnType::Set(el_t) => { + >::type_check(el_t).map_err(|err| { + mk_typck_err::( + typ, + SetOrListTypeCheckErrorKind::ElementTypeCheckFailed(err), + ) + }) + } + _ => Err(mk_typck_err::( + typ, + BuiltinTypeCheckErrorKind::SetOrListError( + SetOrListTypeCheckErrorKind::NotSetOrList, + ), + )), + } + } + + fn deserialize( + typ: &'frame ColumnType, + v: 
Option>, + ) -> Result { + let mut v = ensure_not_null_frame_slice::(typ, v)?; + let count = types::read_int_length(v.as_slice_mut()).map_err(|err| { + mk_deser_err::( + typ, + SetOrListDeserializationErrorKind::LengthDeserializationFailed( + DeserializationError::new(err), + ), + ) + })?; + let elem_typ = match typ { + ColumnType::List(elem_typ) | ColumnType::Set(elem_typ) => elem_typ, + _ => { + unreachable!("Typecheck should have prevented this scenario!") + } + }; + Ok(Self::new(typ, elem_typ, count, v)) + } +} + +impl<'frame, T> Iterator for ListlikeIterator<'frame, T> +where + T: DeserializeValue<'frame>, +{ + type Item = Result; + + fn next(&mut self) -> Option { + let raw = self.raw_iter.next()?.map_err(|err| { + mk_deser_err::( + self.coll_typ, + BuiltinDeserializationErrorKind::GenericParseError(err), + ) + }); + Some(raw.and_then(|raw| { + T::deserialize(self.elem_typ, raw).map_err(|err| { + mk_deser_err::( + self.coll_typ, + SetOrListDeserializationErrorKind::ElementDeserializationFailed(err), + ) + }) + })) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.raw_iter.size_hint() + } +} + +impl<'frame, T> DeserializeValue<'frame> for Vec +where + T: DeserializeValue<'frame>, +{ + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + // It makes sense for both Set and List to deserialize to Vec. + ListlikeIterator::<'frame, T>::type_check(typ) + .map_err(typck_error_replace_rust_name::) + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + ListlikeIterator::<'frame, T>::deserialize(typ, v) + .and_then(|it| it.collect::>()) + .map_err(deser_error_replace_rust_name::) + } +} + +impl<'frame, T> DeserializeValue<'frame> for BTreeSet +where + T: DeserializeValue<'frame> + Ord, +{ + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + // It only makes sense for Set to deserialize to BTreeSet. + // Deserializing List straight to BTreeSet would be lossy. 
+ match typ { + ColumnType::Set(el_t) => >::type_check(el_t) + .map_err(typck_error_replace_rust_name::), + _ => Err(mk_typck_err::( + typ, + SetOrListTypeCheckErrorKind::NotSet, + )), + } + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + ListlikeIterator::<'frame, T>::deserialize(typ, v) + .and_then(|it| it.collect::>()) + .map_err(deser_error_replace_rust_name::) + } +} + +impl<'frame, T, S> DeserializeValue<'frame> for HashSet +where + T: DeserializeValue<'frame> + Eq + Hash, + S: BuildHasher + Default + 'frame, +{ + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + // It only makes sense for Set to deserialize to HashSet. + // Deserializing List straight to HashSet would be lossy. + match typ { + ColumnType::Set(el_t) => >::type_check(el_t) + .map_err(typck_error_replace_rust_name::), + _ => Err(mk_typck_err::( + typ, + SetOrListTypeCheckErrorKind::NotSet, + )), + } + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + ListlikeIterator::<'frame, T>::deserialize(typ, v) + .and_then(|it| it.collect::>()) + .map_err(deser_error_replace_rust_name::) + } +} + // Utilities fn ensure_not_null_frame_slice<'frame, T>( @@ -667,6 +861,37 @@ fn ensure_exact_length<'frame, T, const SIZE: usize>( }) } +// Helper iterators + +/// Iterates over a sequence of `[bytes]` items from a frame subslice, expecting +/// a particular number of items. +/// +/// The iterator does not consider it to be an error if there are some bytes +/// remaining in the slice after parsing requested amount of items. 
+#[derive(Clone, Copy, Debug)] +pub struct FixedLengthBytesSequenceIterator<'frame> { + slice: FrameSlice<'frame>, + remaining: usize, +} + +impl<'frame> FixedLengthBytesSequenceIterator<'frame> { + fn new(count: usize, slice: FrameSlice<'frame>) -> Self { + Self { + slice, + remaining: count, + } + } +} + +impl<'frame> Iterator for FixedLengthBytesSequenceIterator<'frame> { + type Item = Result>, ParseError>; + + fn next(&mut self) -> Option { + self.remaining = self.remaining.checked_sub(1)?; + Some(self.slice.read_cql_bytes()) + } +} + // Error facilities /// Type checking of one of the built-in types failed. @@ -726,6 +951,16 @@ pub enum BuiltinTypeCheckErrorKind { /// The list of types that the Rust type can deserialize from. expected: &'static [ColumnType], }, + + /// A type check failure specific to a CQL set or list. + SetOrListError(SetOrListTypeCheckErrorKind), +} + +impl From for BuiltinTypeCheckErrorKind { + #[inline] + fn from(value: SetOrListTypeCheckErrorKind) -> Self { + BuiltinTypeCheckErrorKind::SetOrListError(value) + } } impl Display for BuiltinTypeCheckErrorKind { @@ -734,6 +969,35 @@ impl Display for BuiltinTypeCheckErrorKind { BuiltinTypeCheckErrorKind::MismatchedType { expected } => { write!(f, "expected one of the CQL types: {expected:?}") } + BuiltinTypeCheckErrorKind::SetOrListError(err) => err.fmt(f), + } + } +} + +/// Describes why type checking of a set or list type failed. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum SetOrListTypeCheckErrorKind { + /// The CQL type is neither a set nor a list. + NotSetOrList, + /// The CQL type is not a set. + NotSet, + /// Incompatible element types.
+ ElementTypeCheckFailed(TypeCheckError), +} + +impl Display for SetOrListTypeCheckErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SetOrListTypeCheckErrorKind::NotSetOrList => { + f.write_str("the CQL type the Rust type was attempted to be type checked against was neither a set nor a list") + } + SetOrListTypeCheckErrorKind::NotSet => { + f.write_str("the CQL type the Rust type was attempted to be type checked against was not a set") + } + SetOrListTypeCheckErrorKind::ElementTypeCheckFailed(err) => { + write!(f, "the set or list element types between the CQL type and the Rust type failed to type check against each other: {}", err) + } } } } @@ -796,6 +1060,9 @@ pub enum BuiltinDeserializationErrorKind { /// The length of read value in bytes is not suitable for IP address. BadInetLength { got: usize }, + + /// A deserialization failure specific to a CQL set or list. + SetOrListError(SetOrListDeserializationErrorKind), } impl Display for BuiltinDeserializationErrorKind { @@ -823,15 +1090,48 @@ impl Display for BuiltinDeserializationErrorKind { f, "the length of read value in bytes ({got}) is not suitable for IP address; expected 4 or 16" ), + BuiltinDeserializationErrorKind::SetOrListError(err) => err.fmt(f), + } + } +} + +/// Describes why deserialization of a set or list type failed. +#[derive(Debug)] +#[non_exhaustive] +pub enum SetOrListDeserializationErrorKind { + /// Failed to deserialize set or list's length. + LengthDeserializationFailed(DeserializationError), + + /// One of the elements of the set/list failed to deserialize. 
+ ElementDeserializationFailed(DeserializationError), +} + +impl Display for SetOrListDeserializationErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SetOrListDeserializationErrorKind::LengthDeserializationFailed(err) => { + write!(f, "failed to deserialize set or list's length: {}", err) + } + SetOrListDeserializationErrorKind::ElementDeserializationFailed(err) => { + write!(f, "failed to deserialize one of the elements: {}", err) + } } } } +impl From for BuiltinDeserializationErrorKind { + #[inline] + fn from(err: SetOrListDeserializationErrorKind) -> Self { + Self::SetOrListError(err) + } +} + #[cfg(test)] mod tests { use bytes::{BufMut, Bytes, BytesMut}; use uuid::Uuid; + use std::collections::{BTreeSet, HashSet}; use std::fmt::Debug; use std::net::{IpAddr, Ipv6Addr}; @@ -845,7 +1145,10 @@ mod tests { use crate::types::serialize::value::SerializeValue; use crate::types::serialize::CellWriter; - use super::{mk_deser_err, BuiltinDeserializationErrorKind, DeserializeValue, MaybeEmpty}; + use super::{ + mk_deser_err, BuiltinDeserializationErrorKind, DeserializeValue, ListlikeIterator, + MaybeEmpty, + }; #[test] fn test_deserialize_bytes() { @@ -1204,6 +1507,95 @@ mod tests { // nulls, represented via Option compat_check_serialized::>(&ColumnType::Int, &123i32); compat_check::>(&ColumnType::Int, make_null()); + + // collections + let mut list = BytesMut::new(); + list.put_i32(3); + append_bytes(&mut list, &123i32.to_be_bytes()); + append_bytes(&mut list, &456i32.to_be_bytes()); + append_bytes(&mut list, &789i32.to_be_bytes()); + let list = make_bytes(&list); + let list_type = ColumnType::List(Box::new(ColumnType::Int)); + compat_check::>(&list_type, list.clone()); + // Support for deserialization List -> {Hash,BTree}Set was removed not to cause confusion. + // Such deserialization would be lossy, which is unwanted. 
+ + let mut set = BytesMut::new(); + set.put_i32(3); + append_bytes(&mut set, &123i32.to_be_bytes()); + append_bytes(&mut set, &456i32.to_be_bytes()); + append_bytes(&mut set, &789i32.to_be_bytes()); + let set = make_bytes(&set); + let set_type = ColumnType::Set(Box::new(ColumnType::Int)); + compat_check::>(&set_type, set.clone()); + compat_check::>(&set_type, set.clone()); + compat_check::>(&set_type, set); + } + + #[test] + fn test_maybe_empty() { + let empty = make_bytes(&[]); + let decoded_empty = deserialize::>(&ColumnType::TinyInt, &empty).unwrap(); + assert_eq!(decoded_empty, MaybeEmpty::Empty); + + let non_empty = make_bytes(&[0x01]); + let decoded_non_empty = + deserialize::>(&ColumnType::TinyInt, &non_empty).unwrap(); + assert_eq!(decoded_non_empty, MaybeEmpty::Value(0x01)); + } + + #[test] + fn test_list_and_set() { + let mut collection_contents = BytesMut::new(); + collection_contents.put_i32(3); + append_bytes(&mut collection_contents, "quick".as_bytes()); + append_bytes(&mut collection_contents, "brown".as_bytes()); + append_bytes(&mut collection_contents, "fox".as_bytes()); + + let collection = make_bytes(&collection_contents); + + let list_typ = ColumnType::List(Box::new(ColumnType::Ascii)); + let set_typ = ColumnType::Set(Box::new(ColumnType::Ascii)); + + // iterator + let mut iter = deserialize::>(&list_typ, &collection).unwrap(); + assert_eq!(iter.next().transpose().unwrap(), Some("quick")); + assert_eq!(iter.next().transpose().unwrap(), Some("brown")); + assert_eq!(iter.next().transpose().unwrap(), Some("fox")); + assert_eq!(iter.next().transpose().unwrap(), None); + + let expected_vec_str = vec!["quick", "brown", "fox"]; + let expected_vec_string = vec!["quick".to_string(), "brown".to_string(), "fox".to_string()]; + + // list + let decoded_vec_str = deserialize::>(&list_typ, &collection).unwrap(); + let decoded_vec_string = deserialize::>(&list_typ, &collection).unwrap(); + assert_eq!(decoded_vec_str, expected_vec_str); + 
assert_eq!(decoded_vec_string, expected_vec_string); + + // hash set + let decoded_hash_str = deserialize::>(&set_typ, &collection).unwrap(); + let decoded_hash_string = deserialize::>(&set_typ, &collection).unwrap(); + assert_eq!( + decoded_hash_str, + expected_vec_str.clone().into_iter().collect(), + ); + assert_eq!( + decoded_hash_string, + expected_vec_string.clone().into_iter().collect(), + ); + + // btree set + let decoded_btree_str = deserialize::>(&set_typ, &collection).unwrap(); + let decoded_btree_string = deserialize::>(&set_typ, &collection).unwrap(); + assert_eq!( + decoded_btree_str, + expected_vec_str.clone().into_iter().collect(), + ); + assert_eq!( + decoded_btree_string, + expected_vec_string.into_iter().collect(), + ); } // Checks that both new and old serialization framework From 92a432493f06db6d602bdb419e0c4cc32db1fee8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Thu, 21 Mar 2024 16:01:47 +0100 Subject: [PATCH 27/41] value: impl DeserializeValue for Map Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/value.rs | 309 +++++++++++++++++++++- 1 file changed, 306 insertions(+), 3 deletions(-) diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index 1808843fe8..65fab12d7f 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -1,7 +1,7 @@ //! Provides types for dealing with CQL value deserialization. use std::{ - collections::{BTreeSet, HashSet}, + collections::{BTreeMap, BTreeSet, HashMap, HashSet}, hash::{BuildHasher, Hash}, net::IpAddr, }; @@ -823,6 +823,169 @@ where } } +/// An iterator over a CQL map. 
+pub struct MapIterator<'frame, K, V> { + coll_typ: &'frame ColumnType, + k_typ: &'frame ColumnType, + v_typ: &'frame ColumnType, + raw_iter: FixedLengthBytesSequenceIterator<'frame>, + phantom_data_k: std::marker::PhantomData, + phantom_data_v: std::marker::PhantomData, +} + +impl<'frame, K, V> MapIterator<'frame, K, V> { + fn new( + coll_typ: &'frame ColumnType, + k_typ: &'frame ColumnType, + v_typ: &'frame ColumnType, + count: usize, + slice: FrameSlice<'frame>, + ) -> Self { + Self { + coll_typ, + k_typ, + v_typ, + raw_iter: FixedLengthBytesSequenceIterator::new(count, slice), + phantom_data_k: std::marker::PhantomData, + phantom_data_v: std::marker::PhantomData, + } + } +} + +impl<'frame, K, V> DeserializeValue<'frame> for MapIterator<'frame, K, V> +where + K: DeserializeValue<'frame>, + V: DeserializeValue<'frame>, +{ + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + match typ { + ColumnType::Map(k_t, v_t) => { + >::type_check(k_t).map_err(|err| { + mk_typck_err::(typ, MapTypeCheckErrorKind::KeyTypeCheckFailed(err)) + })?; + >::type_check(v_t).map_err(|err| { + mk_typck_err::(typ, MapTypeCheckErrorKind::ValueTypeCheckFailed(err)) + })?; + Ok(()) + } + _ => Err(mk_typck_err::(typ, MapTypeCheckErrorKind::NotMap)), + } + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + let mut v = ensure_not_null_frame_slice::(typ, v)?; + let count = types::read_int_length(v.as_slice_mut()).map_err(|err| { + mk_deser_err::( + typ, + MapDeserializationErrorKind::LengthDeserializationFailed( + DeserializationError::new(err), + ), + ) + })?; + let (k_typ, v_typ) = match typ { + ColumnType::Map(k_t, v_t) => (k_t, v_t), + _ => { + unreachable!("Typecheck should have prevented this scenario!") + } + }; + Ok(Self::new(typ, k_typ, v_typ, 2 * count, v)) + } +} + +impl<'frame, K, V> Iterator for MapIterator<'frame, K, V> +where + K: DeserializeValue<'frame>, + V: DeserializeValue<'frame>, +{ + type Item = Result<(K, V), 
DeserializationError>; + + fn next(&mut self) -> Option { + let raw_k = match self.raw_iter.next() { + Some(Ok(raw_k)) => raw_k, + Some(Err(err)) => { + return Some(Err(mk_deser_err::( + self.coll_typ, + BuiltinDeserializationErrorKind::GenericParseError(err), + ))); + } + None => return None, + }; + let raw_v = match self.raw_iter.next() { + Some(Ok(raw_v)) => raw_v, + Some(Err(err)) => { + return Some(Err(mk_deser_err::( + self.coll_typ, + BuiltinDeserializationErrorKind::GenericParseError(err), + ))); + } + None => return None, + }; + + let do_next = || -> Result<(K, V), DeserializationError> { + let k = K::deserialize(self.k_typ, raw_k).map_err(|err| { + mk_deser_err::( + self.coll_typ, + MapDeserializationErrorKind::KeyDeserializationFailed(err), + ) + })?; + let v = V::deserialize(self.v_typ, raw_v).map_err(|err| { + mk_deser_err::( + self.coll_typ, + MapDeserializationErrorKind::ValueDeserializationFailed(err), + ) + })?; + Ok((k, v)) + }; + Some(do_next()) + } + + fn size_hint(&self) -> (usize, Option) { + self.raw_iter.size_hint() + } +} + +impl<'frame, K, V> DeserializeValue<'frame> for BTreeMap +where + K: DeserializeValue<'frame> + Ord, + V: DeserializeValue<'frame>, +{ + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + MapIterator::<'frame, K, V>::type_check(typ).map_err(typck_error_replace_rust_name::) + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + MapIterator::<'frame, K, V>::deserialize(typ, v) + .and_then(|it| it.collect::>()) + .map_err(deser_error_replace_rust_name::) + } +} + +impl<'frame, K, V, S> DeserializeValue<'frame> for HashMap +where + K: DeserializeValue<'frame> + Eq + Hash, + V: DeserializeValue<'frame>, + S: BuildHasher + Default + 'frame, +{ + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + MapIterator::<'frame, K, V>::type_check(typ).map_err(typck_error_replace_rust_name::) + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + 
MapIterator::<'frame, K, V>::deserialize(typ, v) + .and_then(|it| it.collect::>()) + .map_err(deser_error_replace_rust_name::) + } +} + // Utilities fn ensure_not_null_frame_slice<'frame, T>( @@ -954,6 +1117,9 @@ pub enum BuiltinTypeCheckErrorKind { /// A type check failure specific to a CQL set or list. SetOrListError(SetOrListTypeCheckErrorKind), + + /// A type check failure specific to a CQL map. + MapError(MapTypeCheckErrorKind), } impl From for BuiltinTypeCheckErrorKind { @@ -963,6 +1129,13 @@ impl From for BuiltinTypeCheckErrorKind { } } +impl From for BuiltinTypeCheckErrorKind { + #[inline] + fn from(value: MapTypeCheckErrorKind) -> Self { + BuiltinTypeCheckErrorKind::MapError(value) + } +} + impl Display for BuiltinTypeCheckErrorKind { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -970,6 +1143,7 @@ impl Display for BuiltinTypeCheckErrorKind { write!(f, "expected one of the CQL types: {expected:?}") } BuiltinTypeCheckErrorKind::SetOrListError(err) => err.fmt(f), + BuiltinTypeCheckErrorKind::MapError(err) => err.fmt(f), } } } @@ -1002,6 +1176,34 @@ impl Display for SetOrListTypeCheckErrorKind { } } +/// Describes why type checking of a map type failed. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum MapTypeCheckErrorKind { + /// The CQL type is not a map. + NotMap, + /// Incompatible key types. + KeyTypeCheckFailed(TypeCheckError), + /// Incompatible value types. 
+ ValueTypeCheckFailed(TypeCheckError), +} + +impl Display for MapTypeCheckErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + MapTypeCheckErrorKind::NotMap => { + f.write_str("the CQL type the Rust type was attempted to be type checked against was neither a map") + } + MapTypeCheckErrorKind::KeyTypeCheckFailed(err) => { + write!(f, "the map key types between the CQL type and the Rust type failed to type check against each other: {}", err) + }, + MapTypeCheckErrorKind::ValueTypeCheckFailed(err) => { + write!(f, "the map value types between the CQL type and the Rust type failed to type check against each other: {}", err) + }, + } + } +} + /// Deserialization of one of the built-in types failed. #[derive(Debug, Error)] #[error("Failed to deserialize Rust type {rust_name} from CQL type {cql_type:?}: {kind}")] @@ -1063,6 +1265,9 @@ pub enum BuiltinDeserializationErrorKind { /// A deserialization failure specific to a CQL set or list. SetOrListError(SetOrListDeserializationErrorKind), + + /// A deserialization failure specific to a CQL map. + MapError(MapDeserializationErrorKind), } impl Display for BuiltinDeserializationErrorKind { @@ -1091,6 +1296,7 @@ impl Display for BuiltinDeserializationErrorKind { "the length of read value in bytes ({got}) is not suitable for IP address; expected 4 or 16" ), BuiltinDeserializationErrorKind::SetOrListError(err) => err.fmt(f), + BuiltinDeserializationErrorKind::MapError(err) => err.fmt(f), } } } @@ -1126,12 +1332,48 @@ impl From for BuiltinDeserializationErrorKind } } +/// Describes why deserialization of a map type failed. +#[derive(Debug)] +#[non_exhaustive] +pub enum MapDeserializationErrorKind { + /// Failed to deserialize map's length. + LengthDeserializationFailed(DeserializationError), + + /// One of the keys in the map failed to deserialize. + KeyDeserializationFailed(DeserializationError), + + /// One of the values in the map failed to deserialize. 
+ ValueDeserializationFailed(DeserializationError), +} + +impl Display for MapDeserializationErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + MapDeserializationErrorKind::LengthDeserializationFailed(err) => { + write!(f, "failed to deserialize map's length: {}", err) + } + MapDeserializationErrorKind::KeyDeserializationFailed(err) => { + write!(f, "failed to deserialize one of the keys: {}", err) + } + MapDeserializationErrorKind::ValueDeserializationFailed(err) => { + write!(f, "failed to deserialize one of the values: {}", err) + } + } + } +} + +impl From for BuiltinDeserializationErrorKind { + fn from(err: MapDeserializationErrorKind) -> Self { + Self::MapError(err) + } +} + #[cfg(test)] mod tests { use bytes::{BufMut, Bytes, BytesMut}; use uuid::Uuid; - use std::collections::{BTreeSet, HashSet}; + use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; use std::fmt::Debug; use std::net::{IpAddr, Ipv6Addr}; @@ -1147,7 +1389,7 @@ mod tests { use super::{ mk_deser_err, BuiltinDeserializationErrorKind, DeserializeValue, ListlikeIterator, - MaybeEmpty, + MapIterator, MaybeEmpty, }; #[test] @@ -1530,6 +1772,19 @@ mod tests { compat_check::>(&set_type, set.clone()); compat_check::>(&set_type, set.clone()); compat_check::>(&set_type, set); + + let mut map = BytesMut::new(); + map.put_i32(3); + append_bytes(&mut map, &123i32.to_be_bytes()); + append_bytes(&mut map, "quick".as_bytes()); + append_bytes(&mut map, &456i32.to_be_bytes()); + append_bytes(&mut map, "brown".as_bytes()); + append_bytes(&mut map, &789i32.to_be_bytes()); + append_bytes(&mut map, "fox".as_bytes()); + let map = make_bytes(&map); + let map_type = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Text)); + compat_check::>(&map_type, map.clone()); + compat_check::>(&map_type, map); } #[test] @@ -1598,6 +1853,54 @@ mod tests { ); } + #[test] + fn test_map() { + let mut collection_contents = BytesMut::new(); + 
 collection_contents.put_i32(3); + append_bytes(&mut collection_contents, &1i32.to_be_bytes()); + append_bytes(&mut collection_contents, "quick".as_bytes()); + append_bytes(&mut collection_contents, &2i32.to_be_bytes()); + append_bytes(&mut collection_contents, "brown".as_bytes()); + append_bytes(&mut collection_contents, &3i32.to_be_bytes()); + append_bytes(&mut collection_contents, "fox".as_bytes()); + + let collection = make_bytes(&collection_contents); + + let typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Ascii)); + + // iterator + let mut iter = deserialize::>(&typ, &collection).unwrap(); + assert_eq!(iter.next().transpose().unwrap(), Some((1, "quick"))); + assert_eq!(iter.next().transpose().unwrap(), Some((2, "brown"))); + assert_eq!(iter.next().transpose().unwrap(), Some((3, "fox"))); + assert_eq!(iter.next().transpose().unwrap(), None); + + let expected_str = vec![(1, "quick"), (2, "brown"), (3, "fox")]; + let expected_string = vec![ + (1, "quick".to_string()), + (2, "brown".to_string()), + (3, "fox".to_string()), + ]; + + // hash map + let decoded_hash_str = deserialize::>(&typ, &collection).unwrap(); + let decoded_hash_string = deserialize::>(&typ, &collection).unwrap(); + assert_eq!(decoded_hash_str, expected_str.clone().into_iter().collect()); + assert_eq!( + decoded_hash_string, + expected_string.clone().into_iter().collect(), + ); + + // btree map + let decoded_btree_str = deserialize::>(&typ, &collection).unwrap(); + let decoded_btree_string = deserialize::>(&typ, &collection).unwrap(); + assert_eq!( + decoded_btree_str, + expected_str.clone().into_iter().collect(), + ); + assert_eq!(decoded_btree_string, expected_string.into_iter().collect(),); + } + // Checks that both new and old serialization framework // produces the same results in this case fn compat_check(typ: &ColumnType, raw: Bytes) From 2bf4ca05501b1344e88471bb0f478f0c83360ad4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Wed, 8 May 2024
11:05:06 +0200 Subject: [PATCH 28/41] value: impl DeserializeValue for tuples Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/value.rs | 241 ++++++++++++++++++++++ 1 file changed, 241 insertions(+) diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index 65fab12d7f..6996dc9e74 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -986,6 +986,115 @@ where } } +// tuples + +// Implements tuple deserialization. +// The generated impl expects that the serialized data contains exactly the given amount of values. +macro_rules! impl_tuple { + ($($Ti:ident),*; $($idx:literal),*; $($idf:ident),*) => { + impl<'frame, $($Ti),*> DeserializeValue<'frame> for ($($Ti,)*) + where + $($Ti: DeserializeValue<'frame>),* + { + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + const TUPLE_LEN: usize = (&[$($idx),*] as &[i32]).len(); + let [$($idf),*] = ensure_tuple_type::<($($Ti,)*), TUPLE_LEN>(typ)?; + $( + <$Ti>::type_check($idf).map_err(|err| mk_typck_err::( + typ, + TupleTypeCheckErrorKind::FieldTypeCheckFailed { + position: $idx, + err, + } + ))?; + )* + Ok(()) + } + + fn deserialize(typ: &'frame ColumnType, v: Option>) -> Result { + const TUPLE_LEN: usize = (&[$($idx),*] as &[i32]).len(); + // Safety: we are allowed to assume that type_check() was already called. 
+ let [$($idf),*] = ensure_tuple_type::<($($Ti,)*), TUPLE_LEN>(typ) + .expect("Type check should have prevented this!"); + + // Ignore the warning for the zero-sized tuple + #[allow(unused)] + let mut v = ensure_not_null_frame_slice::(typ, v)?; + let ret = ( + $( + v.read_cql_bytes() + .map_err(|err| DeserializationError::new(err)) + .and_then(|cql_bytes| <$Ti>::deserialize($idf, cql_bytes)) + .map_err(|err| mk_deser_err::( + typ, + TupleDeserializationErrorKind::FieldDeserializationFailed { + position: $idx, + err, + } + ) + )?, + )* + ); + Ok(ret) + } + } + } +} + +// Implements tuple deserialization for all tuple sizes up to predefined size. +// Accepts 3 lists, (see usage below the definition): +// - type parameters for the consecutive fields, +// - indices of the consecutive fields, +// - consecutive names for variables corresponding to each field. +// +// The idea is to recursively build prefixes of those lists (starting with an empty prefix) +// and for each prefix, implement deserialization for generic tuple represented by it. +// The < > brackets aid syntactically to separate the prefixes (positioned inside them) +// from the remaining suffixes (positioned beyond them). +macro_rules! impl_tuple_multiple { + // The entry point to the macro. + // Begins with implementing deserialization for (), then proceeds to the main recursive call. + ($($Ti:ident),*; $($idx:literal),*; $($idf:ident),*) => { + impl_tuple!(;;); + impl_tuple_multiple!( + $($Ti),* ; < > ; + $($idx),*; < > ; + $($idf),*; < > + ); + }; + + // The termination condition. No more fields given to extend the tuple with. + (;< $($Ti:ident,)* >;;< $($idx:literal,)* >;;< $($idf:ident,)* >) => {}; + + // The recursion. Upon each call, a new field is appended to the tuple + // and deserialization is implemented for it. 
+ ( + $T_head:ident $(,$T_suffix:ident)*; < $($T_prefix:ident,)* > ; + $idx_head:literal $(,$idx_suffix:literal)*; < $($idx_prefix:literal,)* >; + $idf_head:ident $(,$idf_suffix:ident)* ; <$($idf_prefix:ident,)*> + ) => { + impl_tuple!( + $($T_prefix,)* $T_head; + $($idx_prefix, )* $idx_head; + $($idf_prefix, )* $idf_head + ); + impl_tuple_multiple!( + $($T_suffix),* ; < $($T_prefix,)* $T_head, > ; + $($idx_suffix),*; < $($idx_prefix, )* $idx_head, > ; + $($idf_suffix),*; < $($idf_prefix, )* $idf_head, > + ); + } +} + +pub(super) use impl_tuple_multiple; + +// Implements tuple deserialization for all tuple sizes up to 16. +impl_tuple_multiple!( + T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15; + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15; + t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15 +); + // Utilities fn ensure_not_null_frame_slice<'frame, T>( @@ -1024,6 +1133,24 @@ fn ensure_exact_length<'frame, T, const SIZE: usize>( }) } +fn ensure_tuple_type( + typ: &ColumnType, +) -> Result<&[ColumnType; SIZE], TypeCheckError> { + if let ColumnType::Tuple(typs_v) = typ { + typs_v.as_slice().try_into().map_err(|_| { + BuiltinTypeCheckErrorKind::TupleError(TupleTypeCheckErrorKind::WrongElementCount { + rust_type_el_count: SIZE, + cql_type_el_count: typs_v.len(), + }) + }) + } else { + Err(BuiltinTypeCheckErrorKind::TupleError( + TupleTypeCheckErrorKind::NotTuple, + )) + } + .map_err(|kind| mk_typck_err::(typ, kind)) +} + // Helper iterators /// Iterates over a sequence of `[bytes]` items from a frame subslice, expecting @@ -1120,6 +1247,9 @@ pub enum BuiltinTypeCheckErrorKind { /// A type check failure specific to a CQL map. MapError(MapTypeCheckErrorKind), + + /// A type check failure specific to a CQL tuple. 
+ TupleError(TupleTypeCheckErrorKind), } impl From for BuiltinTypeCheckErrorKind { @@ -1136,6 +1266,13 @@ impl From for BuiltinTypeCheckErrorKind { } } +impl From for BuiltinTypeCheckErrorKind { + #[inline] + fn from(value: TupleTypeCheckErrorKind) -> Self { + BuiltinTypeCheckErrorKind::TupleError(value) + } +} + impl Display for BuiltinTypeCheckErrorKind { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -1144,6 +1281,7 @@ impl Display for BuiltinTypeCheckErrorKind { } BuiltinTypeCheckErrorKind::SetOrListError(err) => err.fmt(f), BuiltinTypeCheckErrorKind::MapError(err) => err.fmt(f), + BuiltinTypeCheckErrorKind::TupleError(err) => err.fmt(f), } } } @@ -1204,6 +1342,57 @@ impl Display for MapTypeCheckErrorKind { } } +/// Describes why type checking of a tuple failed. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum TupleTypeCheckErrorKind { + /// The CQL type is not a tuple. + NotTuple, + + /// The tuple has the wrong element count. + WrongElementCount { + /// The number of elements that the Rust tuple has. + rust_type_el_count: usize, + + /// The number of elements that the CQL tuple type has. + cql_type_el_count: usize, + }, + + /// The CQL type and the Rust type of a tuple field failed to type check against each other. + FieldTypeCheckFailed { + /// The index of the field whose type check failed. + position: usize, + + /// The type check error that occurred. 
+ err: TypeCheckError, + }, +} + +impl Display for TupleTypeCheckErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TupleTypeCheckErrorKind::NotTuple => write!( + f, + "the CQL type the tuple was attempted to be serialized to is not a tuple" + ), + TupleTypeCheckErrorKind::WrongElementCount { + rust_type_el_count, + cql_type_el_count, + } => write!( + f, + "wrong tuple element count: CQL type has {cql_type_el_count}, the Rust tuple has {rust_type_el_count}" + ), + + TupleTypeCheckErrorKind::FieldTypeCheckFailed { position, err } => write!( + f, + "the CQL type and the Rust type of the tuple field {} failed to type check against each other: {}", + position, + err + ) + } + } +} + /// Deserialization of one of the built-in types failed. #[derive(Debug, Error)] #[error("Failed to deserialize Rust type {rust_name} from CQL type {cql_type:?}: {kind}")] @@ -1268,6 +1457,9 @@ pub enum BuiltinDeserializationErrorKind { /// A deserialization failure specific to a CQL map. MapError(MapDeserializationErrorKind), + + /// A deserialization failure specific to a CQL tuple. + TupleError(TupleDeserializationErrorKind), } impl Display for BuiltinDeserializationErrorKind { @@ -1297,6 +1489,7 @@ impl Display for BuiltinDeserializationErrorKind { ), BuiltinDeserializationErrorKind::SetOrListError(err) => err.fmt(f), BuiltinDeserializationErrorKind::MapError(err) => err.fmt(f), + BuiltinDeserializationErrorKind::TupleError(err) => err.fmt(f), } } } @@ -1368,6 +1561,39 @@ impl From for BuiltinDeserializationErrorKind { } } +/// Describes why deserialization of a tuple failed. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum TupleDeserializationErrorKind { + /// One of the tuple fields failed to deserialize. + FieldDeserializationFailed { + /// Index of the tuple field that failed to deserialize. + position: usize, + + /// The error that caused the tuple field deserialization to fail. 
+ err: DeserializationError, + }, +} + +impl Display for TupleDeserializationErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TupleDeserializationErrorKind::FieldDeserializationFailed { + position: index, + err, + } => { + write!(f, "field no. {index} failed to deserialize: {err}") + } + } + } +} + +impl From for BuiltinDeserializationErrorKind { + fn from(err: TupleDeserializationErrorKind) -> Self { + Self::TupleError(err) + } +} + #[cfg(test)] mod tests { use bytes::{BufMut, Bytes, BytesMut}; @@ -1901,6 +2127,21 @@ mod tests { assert_eq!(decoded_btree_string, expected_string.into_iter().collect(),); } + #[test] + fn test_tuples() { + let mut tuple_contents = BytesMut::new(); + append_bytes(&mut tuple_contents, &42i32.to_be_bytes()); + append_bytes(&mut tuple_contents, "foo".as_bytes()); + append_null(&mut tuple_contents); + + let tuple = make_bytes(&tuple_contents); + + let typ = ColumnType::Tuple(vec![ColumnType::Int, ColumnType::Ascii, ColumnType::Uuid]); + + let tup = deserialize::<(i32, &str, Option)>(&typ, &tuple).unwrap(); + assert_eq!(tup, (42, "foo", None)); + } + // Checks that both new and old serialization framework // produces the same results in this case fn compat_check(typ: &ColumnType, raw: Bytes) From b02f021f1d952e29faf753d5203089f35e29032d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Thu, 21 Mar 2024 16:03:02 +0100 Subject: [PATCH 29/41] value: impl DeserializeValue for UDTs Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/value.rs | 202 +++++++++++++++++++++- 1 file changed, 201 insertions(+), 1 deletion(-) diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index 6996dc9e74..6fe399d08b 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -1095,6 +1095,111 @@ impl_tuple_multiple!( t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, 
t15 ); +// udts + +/// An iterator over fields of a User Defined Type. +/// +/// # Note +/// +/// A serialized UDT will generally have one value for each field, but it is +/// allowed to have fewer. This iterator differentiates null values +/// from non-existent values in the following way: +/// +/// - `None` - missing from the serialized form +/// - `Some(None)` - present, but null +/// - `Some(Some(...))` - non-null, present value +pub struct UdtIterator<'frame> { + all_fields: &'frame [(String, ColumnType)], + type_name: &'frame str, + keyspace: &'frame str, + remaining_fields: &'frame [(String, ColumnType)], + raw_iter: BytesSequenceIterator<'frame>, +} + +impl<'frame> UdtIterator<'frame> { + fn new( + fields: &'frame [(String, ColumnType)], + type_name: &'frame str, + keyspace: &'frame str, + slice: FrameSlice<'frame>, + ) -> Self { + Self { + all_fields: fields, + remaining_fields: fields, + type_name, + keyspace, + raw_iter: BytesSequenceIterator::new(slice), + } + } + + #[inline] + pub fn fields(&self) -> &'frame [(String, ColumnType)] { + self.remaining_fields + } +} + +impl<'frame> DeserializeValue<'frame> for UdtIterator<'frame> { + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + match typ { + ColumnType::UserDefinedType { .. 
} => Ok(()), + _ => Err(mk_typck_err::(typ, UdtTypeCheckErrorKind::NotUdt)), + } + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + let v = ensure_not_null_frame_slice::(typ, v)?; + let (fields, type_name, keyspace) = match typ { + ColumnType::UserDefinedType { + field_types, + type_name, + keyspace, + } => (field_types.as_ref(), type_name.as_ref(), keyspace.as_ref()), + _ => { + unreachable!("Typecheck should have prevented this scenario!") + } + }; + Ok(Self::new(fields, type_name, keyspace, v)) + } +} + +impl<'frame> Iterator for UdtIterator<'frame> { + type Item = ( + &'frame (String, ColumnType), + Result>>, DeserializationError>, + ); + + fn next(&mut self) -> Option { + // TODO: Should we fail when there are too many fields? + let (head, fields) = self.remaining_fields.split_first()?; + self.remaining_fields = fields; + let raw_res = match self.raw_iter.next() { + // The field is there and it was parsed correctly + Some(Ok(raw)) => Ok(Some(raw)), + + // There were some bytes but they didn't parse as correct field value + Some(Err(err)) => Err(mk_deser_err::( + &ColumnType::UserDefinedType { + type_name: self.type_name.to_owned(), + keyspace: self.keyspace.to_owned(), + field_types: self.all_fields.to_owned(), + }, + BuiltinDeserializationErrorKind::GenericParseError(err), + )), + + // The field is just missing from the serialized form + None => Ok(None), + }; + Some((head, raw_res)) + } + + fn size_hint(&self) -> (usize, Option) { + self.raw_iter.size_hint() + } +} + // Utilities fn ensure_not_null_frame_slice<'frame, T>( @@ -1182,6 +1287,39 @@ impl<'frame> Iterator for FixedLengthBytesSequenceIterator<'frame> { } } +/// Iterates over a sequence of `[bytes]` items from a frame subslice. +/// +/// The `[bytes]` items are parsed until the end of subslice is reached. 
+#[derive(Clone, Copy, Debug)] +pub struct BytesSequenceIterator<'frame> { + slice: FrameSlice<'frame>, +} + +impl<'frame> BytesSequenceIterator<'frame> { + fn new(slice: FrameSlice<'frame>) -> Self { + Self { slice } + } +} + +impl<'frame> From> for BytesSequenceIterator<'frame> { + #[inline] + fn from(slice: FrameSlice<'frame>) -> Self { + Self::new(slice) + } +} + +impl<'frame> Iterator for BytesSequenceIterator<'frame> { + type Item = Result>, ParseError>; + + fn next(&mut self) -> Option { + if self.slice.as_slice().is_empty() { + None + } else { + Some(self.slice.read_cql_bytes()) + } + } +} + // Error facilities /// Type checking of one of the built-in types failed. @@ -1250,6 +1388,9 @@ pub enum BuiltinTypeCheckErrorKind { /// A type check failure specific to a CQL tuple. TupleError(TupleTypeCheckErrorKind), + + /// A type check failure specific to a CQL UDT. + UdtError(UdtTypeCheckErrorKind), } impl From for BuiltinTypeCheckErrorKind { @@ -1273,6 +1414,13 @@ impl From for BuiltinTypeCheckErrorKind { } } +impl From for BuiltinTypeCheckErrorKind { + #[inline] + fn from(value: UdtTypeCheckErrorKind) -> Self { + BuiltinTypeCheckErrorKind::UdtError(value) + } +} + impl Display for BuiltinTypeCheckErrorKind { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -1282,6 +1430,7 @@ impl Display for BuiltinTypeCheckErrorKind { BuiltinTypeCheckErrorKind::SetOrListError(err) => err.fmt(f), BuiltinTypeCheckErrorKind::MapError(err) => err.fmt(f), BuiltinTypeCheckErrorKind::TupleError(err) => err.fmt(f), + BuiltinTypeCheckErrorKind::UdtError(err) => err.fmt(f), } } } @@ -1393,6 +1542,25 @@ impl Display for TupleTypeCheckErrorKind { } } +/// Describes why type checking of a user defined type failed. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum UdtTypeCheckErrorKind { + /// The CQL type is not a user defined type. 
+ NotUdt, +} + +impl Display for UdtTypeCheckErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + UdtTypeCheckErrorKind::NotUdt => write!( + f, + "the CQL type the Rust type was attempted to be type checked against is not a UDT" + ), + } + } +} + /// Deserialization of one of the built-in types failed. #[derive(Debug, Error)] #[error("Failed to deserialize Rust type {rust_name} from CQL type {cql_type:?}: {kind}")] @@ -1609,7 +1777,7 @@ mod tests { use crate::frame::value::{ Counter, CqlDate, CqlDecimal, CqlDuration, CqlTime, CqlTimestamp, CqlTimeuuid, CqlVarint, }; - use crate::types::deserialize::{DeserializationError, FrameSlice}; + use crate::types::deserialize::{DeserializationError, FrameSlice, TypeCheckError}; use crate::types::serialize::value::SerializeValue; use crate::types::serialize::CellWriter; @@ -2142,6 +2310,38 @@ mod tests { assert_eq!(tup, (42, "foo", None)); } + #[test] + fn test_custom_type_parser() { + #[derive(Default, Debug, PartialEq, Eq)] + struct SwappedPair(B, A); + impl<'frame, A, B> DeserializeValue<'frame> for SwappedPair + where + A: DeserializeValue<'frame>, + B: DeserializeValue<'frame>, + { + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + <(B, A) as DeserializeValue<'frame>>::type_check(typ) + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + <(B, A) as DeserializeValue<'frame>>::deserialize(typ, v).map(|(b, a)| Self(b, a)) + } + } + + let mut tuple_contents = BytesMut::new(); + append_bytes(&mut tuple_contents, "foo".as_bytes()); + append_bytes(&mut tuple_contents, &42i32.to_be_bytes()); + let tuple = make_bytes(&tuple_contents); + + let typ = ColumnType::Tuple(vec![ColumnType::Ascii, ColumnType::Int]); + + let tup = deserialize::>(&typ, &tuple).unwrap(); + assert_eq!(tup, SwappedPair("foo", 42)); + } + // Checks that both new and old serialization framework // produces the same results in this case fn compat_check(typ: &ColumnType, 
raw: Bytes) From 82b5d9687723651a16f9de23d6329e52969ef77a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Sun, 24 Mar 2024 13:49:36 +0100 Subject: [PATCH 30/41] deser/row: test deser as ColumnIterator Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/row.rs | 61 +++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/scylla-cql/src/types/deserialize/row.rs b/scylla-cql/src/types/deserialize/row.rs index 5eaa46b91c..99518ff6ca 100644 --- a/scylla-cql/src/types/deserialize/row.rs +++ b/scylla-cql/src/types/deserialize/row.rs @@ -231,3 +231,64 @@ impl Display for BuiltinDeserializationErrorKind { } } } + +#[cfg(test)] +mod tests { + use bytes::Bytes; + + use crate::frame::response::result::{ColumnSpec, ColumnType}; + use crate::types::deserialize::{DeserializationError, FrameSlice}; + + use super::super::tests::{serialize_cells, spec}; + use super::{ColumnIterator, DeserializeRow}; + + #[test] + fn test_deserialization_as_column_iterator() { + let col_specs = [ + spec("i1", ColumnType::Int), + spec("i2", ColumnType::Text), + spec("i3", ColumnType::Counter), + ]; + let serialized_values = serialize_cells([val_int(123), val_str("ScyllaDB"), None]); + let mut iter = deserialize::(&col_specs, &serialized_values).unwrap(); + + let col1 = iter.next().unwrap().unwrap(); + assert_eq!(col1.spec.name, "i1"); + assert_eq!(col1.spec.typ, ColumnType::Int); + assert_eq!(col1.slice.unwrap().as_slice(), &123i32.to_be_bytes()); + + let col2 = iter.next().unwrap().unwrap(); + assert_eq!(col2.spec.name, "i2"); + assert_eq!(col2.spec.typ, ColumnType::Text); + assert_eq!(col2.slice.unwrap().as_slice(), "ScyllaDB".as_bytes()); + + let col3 = iter.next().unwrap().unwrap(); + assert_eq!(col3.spec.name, "i3"); + assert_eq!(col3.spec.typ, ColumnType::Counter); + assert!(col3.slice.is_none()); + + assert!(iter.next().is_none()); + } + + fn val_int(i: i32) -> Option> { + Some(i.to_be_bytes().to_vec()) + } + + fn val_str(s: &str) -> 
Option> { + Some(s.as_bytes().to_vec()) + } + + fn deserialize<'frame, R>( + specs: &'frame [ColumnSpec], + byts: &'frame Bytes, + ) -> Result + where + R: DeserializeRow<'frame>, + { + >::type_check(specs) + .map_err(|typecheck_err| DeserializationError(typecheck_err.0))?; + let slice = FrameSlice::new(byts); + let iter = ColumnIterator::new(specs, slice); + >::deserialize(iter) + } +} From 8e0f29a815dcc1ea906c88c7fb008773afc01521 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Fri, 22 Mar 2024 20:16:21 +0100 Subject: [PATCH 31/41] deser/row: impl DeserializeRow for tuples Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/row.rs | 192 +++++++++++++++++++++++- 1 file changed, 189 insertions(+), 3 deletions(-) diff --git a/scylla-cql/src/types/deserialize/row.rs b/scylla-cql/src/types/deserialize/row.rs index 99518ff6ca..007e8041fa 100644 --- a/scylla-cql/src/types/deserialize/row.rs +++ b/scylla-cql/src/types/deserialize/row.rs @@ -4,7 +4,8 @@ use std::fmt::Display; use thiserror::Error; -use super::{DeserializationError, FrameSlice, TypeCheckError}; +use super::value::DeserializeValue; +use super::{make_error_replace_rust_name, DeserializationError, FrameSlice, TypeCheckError}; use crate::frame::response::result::{ColumnSpec, ColumnType}; /// Represents a raw, unparsed column value. @@ -122,6 +123,90 @@ impl<'frame> DeserializeRow<'frame> for ColumnIterator<'frame> { } } +make_error_replace_rust_name!( + _typck_error_replace_rust_name, + TypeCheckError, + BuiltinTypeCheckError +); + +make_error_replace_rust_name!( + deser_error_replace_rust_name, + DeserializationError, + BuiltinDeserializationError +); + +// tuples +// +/// This is the new encouraged way for deserializing a row. +/// If only you know the exact column types in advance, you had better deserialize the row +/// to a tuple. 
The new deserialization framework will take care of all type checking +/// and needed conversions, issuing meaningful errors in case something goes wrong. +macro_rules! impl_tuple { + ($($Ti:ident),*; $($idx:literal),*; $($idf:ident),*) => { + impl<'frame, $($Ti),*> DeserializeRow<'frame> for ($($Ti,)*) + where + $($Ti: DeserializeValue<'frame>),* + { + fn type_check(specs: &[ColumnSpec]) -> Result<(), TypeCheckError> { + const TUPLE_LEN: usize = (&[$($idx),*] as &[i32]).len(); + + let column_types_iter = || specs.iter().map(|spec| spec.typ.clone()); + if let [$($idf),*] = &specs { + $( + <$Ti as DeserializeValue<'frame>>::type_check(&$idf.typ) + .map_err(|err| mk_typck_err::(column_types_iter(), BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { + column_index: $idx, + column_name: specs[$idx].name.clone(), + err + }))?; + )* + Ok(()) + } else { + Err(mk_typck_err::(column_types_iter(), BuiltinTypeCheckErrorKind::WrongColumnCount { + rust_cols: TUPLE_LEN, cql_cols: specs.len() + })) + } + } + + fn deserialize(mut row: ColumnIterator<'frame>) -> Result { + const TUPLE_LEN: usize = (&[$($idx),*] as &[i32]).len(); + + let ret = ( + $({ + let column = row.next().unwrap_or_else(|| unreachable!( + "Typecheck should have prevented this scenario! Column count mismatch: rust type {}, cql row {}", + TUPLE_LEN, + $idx + )).map_err(deser_error_replace_rust_name::)?; + + <$Ti as DeserializeValue<'frame>>::deserialize(&column.spec.typ, column.slice) + .map_err(|err| mk_deser_err::(BuiltinDeserializationErrorKind::ColumnDeserializationFailed { + column_index: column.index, + column_name: column.spec.name.clone(), + err, + }))? + },)* + ); + assert!( + row.next().is_none(), + "Typecheck should have prevented this scenario! Column count mismatch: rust type {}, cql row is bigger", + TUPLE_LEN, + ); + Ok(ret) + } + } + } +} + +use super::value::impl_tuple_multiple; + +// Implements row-to-tuple deserialization for all tuple sizes up to 16. 
+impl_tuple_multiple!( + T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15; + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15; + t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15 +); + // Error facilities /// Failed to type check incoming result column types again given Rust type, @@ -161,11 +246,47 @@ fn mk_typck_err_named( /// Describes why type checking incoming result column types again given Rust type failed. #[derive(Debug, Clone)] #[non_exhaustive] -pub enum BuiltinTypeCheckErrorKind {} +pub enum BuiltinTypeCheckErrorKind { + /// The Rust type expects `rust_cols` columns, but the statement operates on `cql_cols`. + WrongColumnCount { + /// The number of values that the Rust type provides. + rust_cols: usize, + + /// The number of columns that the statement operates on. + cql_cols: usize, + }, + + /// Column type check failed between Rust type and DB type at given position (=in given column). + ColumnTypeCheckFailed { + /// Index of the column. + column_index: usize, + + /// Name of the column, as provided by the DB. + column_name: String, + + /// Inner type check error due to the type mismatch. + err: TypeCheckError, + }, +} impl Display for BuiltinTypeCheckErrorKind { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - Ok(()) + match self { + BuiltinTypeCheckErrorKind::WrongColumnCount { + rust_cols, + cql_cols, + } => { + write!(f, "wrong column count: the statement operates on {cql_cols} columns, but the given rust types contains {rust_cols}") + } + BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { + column_index, + column_name, + err, + } => write!( + f, + "mismatched types in column {column_name} at index {column_index}: {err}" + ), + } } } @@ -201,6 +322,18 @@ pub(super) fn mk_deser_err_named( #[derive(Debug, Clone)] #[non_exhaustive] pub enum BuiltinDeserializationErrorKind { + /// One of the columns failed to deserialize. 
+ ColumnDeserializationFailed { + /// Index of the column that failed to deserialize. + column_index: usize, + + /// Name of the column that failed to deserialize. + column_name: String, + + /// The error that caused the column deserialization to fail. + err: DeserializationError, + }, + /// One of the raw columns failed to deserialize, most probably /// due to the invalid column structure inside a row in the frame. RawColumnDeserializationFailed { @@ -218,6 +351,16 @@ pub enum BuiltinDeserializationErrorKind { impl Display for BuiltinDeserializationErrorKind { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { + BuiltinDeserializationErrorKind::ColumnDeserializationFailed { + column_index, + column_name, + err, + } => { + write!( + f, + "failed to deserialize column {column_name} at index {column_index}: {err}" + ) + } BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { column_index, column_name, @@ -242,6 +385,49 @@ mod tests { use super::super::tests::{serialize_cells, spec}; use super::{ColumnIterator, DeserializeRow}; + #[test] + fn test_tuple_deserialization() { + // Empty tuple + deserialize::<()>(&[], &Bytes::new()).unwrap(); + + // 1-elem tuple + let (a,) = deserialize::<(i32,)>( + &[spec("i", ColumnType::Int)], + &serialize_cells([val_int(123)]), + ) + .unwrap(); + assert_eq!(a, 123); + + // 3-elem tuple + let (a, b, c) = deserialize::<(i32, i32, i32)>( + &[ + spec("i1", ColumnType::Int), + spec("i2", ColumnType::Int), + spec("i3", ColumnType::Int), + ], + &serialize_cells([val_int(123), val_int(456), val_int(789)]), + ) + .unwrap(); + assert_eq!((a, b, c), (123, 456, 789)); + + // Make sure that column type mismatch is detected + deserialize::<(i32, String, i32)>( + &[ + spec("i1", ColumnType::Int), + spec("i2", ColumnType::Int), + spec("i3", ColumnType::Int), + ], + &serialize_cells([val_int(123), val_int(456), val_int(789)]), + ) + .unwrap_err(); + + // Make sure that borrowing types compile and work 
correctly + let specs = &[spec("s", ColumnType::Text)]; + let byts = serialize_cells([val_str("abc")]); + let (s,) = deserialize::<(&str,)>(specs, &byts).unwrap(); + assert_eq!(s, "abc"); + } + #[test] fn test_deserialization_as_column_iterator() { let col_specs = [ From c1f81e4fb85cc5f8a83364cebbe583469416a13b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Tue, 19 Mar 2024 16:45:11 +0100 Subject: [PATCH 32/41] deser/row: impl DeserializeRow for Row This implementation is important for two reasons: 1. It enables using the upper layers of the old framework over the new one, which makes the transition smoother. 2. Some users (perhaps ORM users?) are going to need the dynamic capabilities that the previous framework offered: receiving rows consisting of arbitrary number of columns of arbitrary types. This is a perfect use case for Row. Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/row.rs | 38 ++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/scylla-cql/src/types/deserialize/row.rs b/scylla-cql/src/types/deserialize/row.rs index 007e8041fa..8ff109942a 100644 --- a/scylla-cql/src/types/deserialize/row.rs +++ b/scylla-cql/src/types/deserialize/row.rs @@ -6,7 +6,7 @@ use thiserror::Error; use super::value::DeserializeValue; use super::{make_error_replace_rust_name, DeserializationError, FrameSlice, TypeCheckError}; -use crate::frame::response::result::{ColumnSpec, ColumnType}; +use crate::frame::response::result::{ColumnSpec, ColumnType, CqlValue, Row}; /// Represents a raw, unparsed column value. #[non_exhaustive] @@ -135,6 +135,42 @@ make_error_replace_rust_name!( BuiltinDeserializationError ); +// legacy/dynamic deserialization as Row +// +/// While no longer encouraged (because the new framework encourages deserializing +/// directly into desired types, entirely bypassing CqlValue), this can be indispensable +/// for some use cases, i.e. those involving dynamic parsing (ORMs?). 
+impl<'frame> DeserializeRow<'frame> for Row { + #[inline] + fn type_check(_specs: &[ColumnSpec]) -> Result<(), TypeCheckError> { + // CqlValues accept all types, no type checking needed. + Ok(()) + } + + #[inline] + fn deserialize(mut row: ColumnIterator<'frame>) -> Result { + let mut columns = Vec::with_capacity(row.size_hint().0); + while let Some(column) = row + .next() + .transpose() + .map_err(deser_error_replace_rust_name::)? + { + columns.push( + >::deserialize(&column.spec.typ, column.slice).map_err(|err| { + mk_deser_err::( + BuiltinDeserializationErrorKind::ColumnDeserializationFailed { + column_index: column.index, + column_name: column.spec.name.clone(), + err, + }, + ) + })?, + ); + } + Ok(Self { columns }) + } +} + // tuples // /// This is the new encouraged way for deserializing a row. From da357cfb7e6b5f80dd2fc32aa0cf3a707239f43f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Fri, 22 Mar 2024 20:36:00 +0100 Subject: [PATCH 33/41] deser/result: introduce RowIterator This is an iterator over rows, allowing lazy and flexible deserialization. Returns ColumnIterator for each row. Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/mod.rs | 1 + scylla-cql/src/types/deserialize/result.rs | 115 +++++++++++++++++++++ 2 files changed, 116 insertions(+) create mode 100644 scylla-cql/src/types/deserialize/result.rs diff --git a/scylla-cql/src/types/deserialize/mod.rs b/scylla-cql/src/types/deserialize/mod.rs index f148f23484..0e9b2a0582 100644 --- a/scylla-cql/src/types/deserialize/mod.rs +++ b/scylla-cql/src/types/deserialize/mod.rs @@ -163,6 +163,7 @@ // TODO: in the above module docstring, stop abusing ParseError once errors are refactored. 
pub mod frame_slice; +pub mod result; pub mod row; pub mod value; diff --git a/scylla-cql/src/types/deserialize/result.rs b/scylla-cql/src/types/deserialize/result.rs new file mode 100644 index 0000000000..1c81f61af4 --- /dev/null +++ b/scylla-cql/src/types/deserialize/result.rs @@ -0,0 +1,115 @@ +use crate::frame::response::result::ColumnSpec; + +use super::row::{mk_deser_err, BuiltinDeserializationErrorKind, ColumnIterator}; +use super::{DeserializationError, FrameSlice}; + +/// Iterates over the whole result, returning rows. +pub struct RowIterator<'frame> { + specs: &'frame [ColumnSpec], + remaining: usize, + slice: FrameSlice<'frame>, +} + +impl<'frame> RowIterator<'frame> { + /// Creates a new iterator over rows from a serialized response. + /// + /// - `remaining` - number of the remaining rows in the serialized response, + /// - `specs` - information about columns of the serialized response, + /// - `slice` - a [FrameSlice] that points to the serialized rows data. + #[inline] + pub fn new(remaining: usize, specs: &'frame [ColumnSpec], slice: FrameSlice<'frame>) -> Self { + Self { + specs, + remaining, + slice, + } + } + + /// Returns information about the columns of rows that are iterated over. + #[inline] + pub fn specs(&self) -> &'frame [ColumnSpec] { + self.specs + } + + /// Returns the remaining number of rows that this iterator is supposed + /// to return. 
+ #[inline] + pub fn rows_remaining(&self) -> usize { + self.remaining + } +} + +impl<'frame> Iterator for RowIterator<'frame> { + type Item = Result, DeserializationError>; + + #[inline] + fn next(&mut self) -> Option { + self.remaining = self.remaining.checked_sub(1)?; + + let iter = ColumnIterator::new(self.specs, self.slice); + + // Skip the row here, manually + for (column_index, spec) in self.specs.iter().enumerate() { + if let Err(err) = self.slice.read_cql_bytes() { + return Some(Err(mk_deser_err::( + BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { + column_index, + column_name: spec.name.clone(), + err: DeserializationError::new(err), + }, + ))); + } + } + + Some(Ok(iter)) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + // The iterator will always return exactly `self.remaining` + // elements: Oks until an error is encountered and then Errs + // containing that same first encountered error. + (self.remaining, Some(self.remaining)) + } +} + +#[cfg(test)] +mod tests { + use crate::frame::response::result::ColumnType; + + use super::super::tests::{serialize_cells, spec, CELL1, CELL2}; + use super::{FrameSlice, RowIterator}; + + #[test] + fn test_row_iterator_basic_parse() { + let raw_data = serialize_cells([Some(CELL1), Some(CELL2), Some(CELL2), Some(CELL1)]); + let specs = [spec("b1", ColumnType::Blob), spec("b2", ColumnType::Blob)]; + let mut iter = RowIterator::new(2, &specs, FrameSlice::new(&raw_data)); + + let mut row1 = iter.next().unwrap().unwrap(); + let c11 = row1.next().unwrap().unwrap(); + assert_eq!(c11.slice.unwrap().as_slice(), CELL1); + let c12 = row1.next().unwrap().unwrap(); + assert_eq!(c12.slice.unwrap().as_slice(), CELL2); + assert!(row1.next().is_none()); + + let mut row2 = iter.next().unwrap().unwrap(); + let c21 = row2.next().unwrap().unwrap(); + assert_eq!(c21.slice.unwrap().as_slice(), CELL2); + let c22 = row2.next().unwrap().unwrap(); + assert_eq!(c22.slice.unwrap().as_slice(), CELL1); + 
assert!(row2.next().is_none()); + + assert!(iter.next().is_none()); + } + + #[test] + fn test_row_iterator_too_few_rows() { + let raw_data = serialize_cells([Some(CELL1), Some(CELL2)]); + let specs = [spec("b1", ColumnType::Blob), spec("b2", ColumnType::Blob)]; + let mut iter = RowIterator::new(2, &specs, FrameSlice::new(&raw_data)); + + iter.next().unwrap().unwrap(); + assert!(iter.next().unwrap().is_err()); + } +} From c458b6a8359d9541cdf9d0bdd2ffa0d7284b73c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Fri, 22 Mar 2024 20:27:15 +0100 Subject: [PATCH 34/41] deser/result: introduce TypedRowIterator This iterator wraps over RowIterator and for each row consumes the ColumnIterator and deserializes the given type. Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/result.rs | 91 +++++++++++++++++++++- 1 file changed, 88 insertions(+), 3 deletions(-) diff --git a/scylla-cql/src/types/deserialize/result.rs b/scylla-cql/src/types/deserialize/result.rs index 1c81f61af4..036b909afb 100644 --- a/scylla-cql/src/types/deserialize/result.rs +++ b/scylla-cql/src/types/deserialize/result.rs @@ -1,7 +1,8 @@ use crate::frame::response::result::ColumnSpec; -use super::row::{mk_deser_err, BuiltinDeserializationErrorKind, ColumnIterator}; -use super::{DeserializationError, FrameSlice}; +use super::row::{mk_deser_err, BuiltinDeserializationErrorKind, ColumnIterator, DeserializeRow}; +use super::{DeserializationError, FrameSlice, TypeCheckError}; +use std::marker::PhantomData; /// Iterates over the whole result, returning rows. pub struct RowIterator<'frame> { @@ -73,12 +74,70 @@ impl<'frame> Iterator for RowIterator<'frame> { } } +/// A typed version of [RowIterator] which deserializes the rows before +/// returning them. 
+pub struct TypedRowIterator<'frame, R> { + inner: RowIterator<'frame>, + _phantom: PhantomData, +} + +impl<'frame, R> TypedRowIterator<'frame, R> +where + R: DeserializeRow<'frame>, +{ + /// Creates a new [TypedRowIterator] from given [RowIterator]. + /// + /// Calls `R::type_check` and fails if the type check fails. + #[inline] + pub fn new(raw: RowIterator<'frame>) -> Result { + R::type_check(raw.specs())?; + Ok(Self { + inner: raw, + _phantom: PhantomData, + }) + } + + /// Returns information about the columns of rows that are iterated over. + #[inline] + pub fn specs(&self) -> &'frame [ColumnSpec] { + self.inner.specs() + } + + /// Returns the remaining number of rows that this iterator is supposed + /// to return. + #[inline] + pub fn rows_remaining(&self) -> usize { + self.inner.rows_remaining() + } +} + +impl<'frame, R> Iterator for TypedRowIterator<'frame, R> +where + R: DeserializeRow<'frame>, +{ + type Item = Result; + + #[inline] + fn next(&mut self) -> Option { + self.inner + .next() + .map(|raw| raw.and_then(|raw| R::deserialize(raw))) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + #[cfg(test)] mod tests { + use bytes::Bytes; + use crate::frame::response::result::ColumnType; use super::super::tests::{serialize_cells, spec, CELL1, CELL2}; - use super::{FrameSlice, RowIterator}; + use super::{FrameSlice, RowIterator, TypedRowIterator}; #[test] fn test_row_iterator_basic_parse() { @@ -112,4 +171,30 @@ mod tests { iter.next().unwrap().unwrap(); assert!(iter.next().unwrap().is_err()); } + + #[test] + fn test_typed_row_iterator_basic_parse() { + let raw_data = serialize_cells([Some(CELL1), Some(CELL2), Some(CELL2), Some(CELL1)]); + let specs = [spec("b1", ColumnType::Blob), spec("b2", ColumnType::Blob)]; + let iter = RowIterator::new(2, &specs, FrameSlice::new(&raw_data)); + let mut iter = TypedRowIterator::<'_, (&[u8], Vec)>::new(iter).unwrap(); + + let (c11, c12) = iter.next().unwrap().unwrap(); + 
assert_eq!(c11, CELL1); + assert_eq!(c12, CELL2); + + let (c21, c22) = iter.next().unwrap().unwrap(); + assert_eq!(c21, CELL2); + assert_eq!(c22, CELL1); + + assert!(iter.next().is_none()); + } + + #[test] + fn test_typed_row_iterator_wrong_type() { + let raw_data = Bytes::new(); + let specs = [spec("b1", ColumnType::Blob), spec("b2", ColumnType::Blob)]; + let iter = RowIterator::new(0, &specs, FrameSlice::new(&raw_data)); + assert!(TypedRowIterator::<'_, (i32, i64)>::new(iter).is_err()); + } } From dd5e71e962ac4429af2202dc9486fc4ccde6d655 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Wed, 8 May 2024 11:39:51 +0200 Subject: [PATCH 35/41] deser/value: errors tests --- scylla-cql/src/types/deserialize/value.rs | 522 +++++++++++++++++++++- 1 file changed, 519 insertions(+), 3 deletions(-) diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index 6fe399d08b..7fb8c8e098 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -1763,7 +1763,8 @@ impl From for BuiltinDeserializationErrorKind { } #[cfg(test)] -mod tests { +pub(super) mod tests { + use assert_matches::assert_matches; use bytes::{BufMut, Bytes, BytesMut}; use uuid::Uuid; @@ -1777,13 +1778,18 @@ mod tests { use crate::frame::value::{ Counter, CqlDate, CqlDecimal, CqlDuration, CqlTime, CqlTimestamp, CqlTimeuuid, CqlVarint, }; + use crate::types::deserialize::value::{ + TupleDeserializationErrorKind, TupleTypeCheckErrorKind, + }; use crate::types::deserialize::{DeserializationError, FrameSlice, TypeCheckError}; use crate::types::serialize::value::SerializeValue; use crate::types::serialize::CellWriter; use super::{ - mk_deser_err, BuiltinDeserializationErrorKind, DeserializeValue, ListlikeIterator, - MapIterator, MaybeEmpty, + mk_deser_err, BuiltinDeserializationError, BuiltinDeserializationErrorKind, + BuiltinTypeCheckError, BuiltinTypeCheckErrorKind, DeserializeValue, ListlikeIterator, + 
MapDeserializationErrorKind, MapIterator, MapTypeCheckErrorKind, MaybeEmpty, + SetOrListDeserializationErrorKind, SetOrListTypeCheckErrorKind, }; #[test] @@ -2429,4 +2435,514 @@ mod tests { fn append_null(b: &mut impl BufMut) { b.put_i32(-1); } + + /* Errors checks */ + + #[track_caller] + pub(crate) fn get_typeck_err_inner<'a>( + err: &'a (dyn std::error::Error + 'static), + ) -> &'a BuiltinTypeCheckError { + match err.downcast_ref() { + Some(err) => err, + None => panic!("not a BuiltinTypeCheckError: {:?}", err), + } + } + + #[track_caller] + pub(crate) fn get_typeck_err(err: &DeserializationError) -> &BuiltinTypeCheckError { + get_typeck_err_inner(err.0.as_ref()) + } + + #[track_caller] + pub(crate) fn get_deser_err(err: &DeserializationError) -> &BuiltinDeserializationError { + match err.0.downcast_ref() { + Some(err) => err, + None => panic!("not a BuiltinDeserializationError: {:?}", err), + } + } + + macro_rules! assert_given_error { + ($get_err:ident, $bytes:expr, $DestT:ty, $cql_typ:expr, $kind:pat) => { + let cql_typ = $cql_typ.clone(); + let err = deserialize::<$DestT>(&cql_typ, $bytes).unwrap_err(); + let err = $get_err(&err); + assert_eq!(err.rust_name, std::any::type_name::<$DestT>()); + assert_eq!(err.cql_type, cql_typ); + assert_matches::assert_matches!(err.kind, $kind); + }; + } + + macro_rules! assert_type_check_error { + ($bytes:expr, $DestT:ty, $cql_typ:expr, $kind:pat) => { + assert_given_error!(get_typeck_err, $bytes, $DestT, $cql_typ, $kind); + }; + } + + macro_rules! assert_deser_error { + ($bytes:expr, $DestT:ty, $cql_typ:expr, $kind:pat) => { + assert_given_error!(get_deser_err, $bytes, $DestT, $cql_typ, $kind); + }; + } + + #[test] + fn test_native_errors() { + // Simple type mismatch + { + let v = 123_i32; + let bytes = serialize(&ColumnType::Int, &v); + + // Incompatible types render type check error. 
+ assert_type_check_error!( + &bytes, + f64, + ColumnType::Int, + super::BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::Double], + } + ); + + // ColumnType is said to be Double (8 bytes expected), but in reality the serialized form has 4 bytes only. + assert_deser_error!( + &bytes, + f64, + ColumnType::Double, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 8, + got: 4, + } + ); + + // ColumnType is said to be Float, but in reality Int was serialized. + // As these types have the same size, though, and every binary number in [0, 2^32] is a valid + // value for both of them, this always succeeds. + { + deserialize::(&ColumnType::Float, &bytes).unwrap(); + } + } + + // str (and also Uuid) are interesting because they accept two types. + { + let v = "Ala ma kota"; + let bytes = serialize(&ColumnType::Ascii, &v); + + assert_type_check_error!( + &bytes, + &str, + ColumnType::Double, + BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::Ascii, ColumnType::Text], + } + ); + + // ColumnType is said to be BigInt (8 bytes expected), but in reality the serialized form + // (the string) has 11 bytes. + assert_deser_error!( + &bytes, + i64, + ColumnType::BigInt, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 8, + got: 11, // str len + } + ); + } + { + // -126 is not a valid ASCII nor UTF-8 byte. 
+ let v = -126_i8; + let bytes = serialize(&ColumnType::TinyInt, &v); + + assert_deser_error!( + &bytes, + &str, + ColumnType::Ascii, + BuiltinDeserializationErrorKind::ExpectedAscii + ); + + assert_deser_error!( + &bytes, + &str, + ColumnType::Text, + BuiltinDeserializationErrorKind::InvalidUtf8(_) + ); + } + } + + #[test] + fn test_set_or_list_errors() { + // Not a set or list + { + assert_type_check_error!( + &Bytes::new(), + Vec, + ColumnType::Float, + BuiltinTypeCheckErrorKind::SetOrListError( + SetOrListTypeCheckErrorKind::NotSetOrList + ) + ); + + // Type check of Rust set against CQL list must fail, because it would be lossy. + assert_type_check_error!( + &Bytes::new(), + BTreeSet, + ColumnType::List(Box::new(ColumnType::Int)), + BuiltinTypeCheckErrorKind::SetOrListError(SetOrListTypeCheckErrorKind::NotSet) + ); + } + + // Got null + { + type RustTyp = Vec; + let ser_typ = ColumnType::List(Box::new(ColumnType::Int)); + + let err = RustTyp::deserialize(&ser_typ, None).unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ser_typ); + assert_matches!(err.kind, BuiltinDeserializationErrorKind::ExpectedNonNull); + } + + // Bad element type + { + assert_type_check_error!( + &Bytes::new(), + Vec, + ColumnType::List(Box::new(ColumnType::Ascii)), + BuiltinTypeCheckErrorKind::SetOrListError( + SetOrListTypeCheckErrorKind::ElementTypeCheckFailed(_) + ) + ); + + let err = deserialize::>( + &ColumnType::List(Box::new(ColumnType::Varint)), + &Bytes::new(), + ) + .unwrap_err(); + let err = get_typeck_err(&err); + assert_eq!(err.rust_name, std::any::type_name::>()); + assert_eq!(err.cql_type, ColumnType::List(Box::new(ColumnType::Varint)),); + let BuiltinTypeCheckErrorKind::SetOrListError( + SetOrListTypeCheckErrorKind::ElementTypeCheckFailed(ref err), + ) = err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + let err = get_typeck_err_inner(err.0.as_ref()); + 
assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::Varint); + assert_matches!( + err.kind, + BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::BigInt, ColumnType::Counter] + } + ); + } + + { + let ser_typ = ColumnType::List(Box::new(ColumnType::Int)); + let v = vec![123_i32]; + let bytes = serialize(&ser_typ, &v); + + { + let err = deserialize::>( + &ColumnType::List(Box::new(ColumnType::BigInt)), + &bytes, + ) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::>()); + assert_eq!(err.cql_type, ColumnType::List(Box::new(ColumnType::BigInt)),); + let BuiltinDeserializationErrorKind::SetOrListError( + SetOrListDeserializationErrorKind::ElementDeserializationFailed(err), + ) = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + let err = get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::BigInt); + assert_matches!( + err.kind, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 8, + got: 4 + } + ); + } + } + } + + #[test] + fn test_map_errors() { + // Not a map + { + let ser_typ = ColumnType::Float; + let v = 2.12_f32; + let bytes = serialize(&ser_typ, &v); + + assert_type_check_error!( + &bytes, + HashMap, + ser_typ, + BuiltinTypeCheckErrorKind::MapError( + MapTypeCheckErrorKind::NotMap, + ) + ); + } + + // Got null + { + type RustTyp = HashMap; + let ser_typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Boolean)); + + let err = RustTyp::deserialize(&ser_typ, None).unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ser_typ); + assert_matches!(err.kind, BuiltinDeserializationErrorKind::ExpectedNonNull); + } + + // Key type mismatch + { + let err = deserialize::>( + &ColumnType::Map(Box::new(ColumnType::Varint), Box::new(ColumnType::Boolean)), + &Bytes::new(), + ) + 
.unwrap_err(); + let err = get_typeck_err(&err); + assert_eq!(err.rust_name, std::any::type_name::>()); + assert_eq!( + err.cql_type, + ColumnType::Map(Box::new(ColumnType::Varint), Box::new(ColumnType::Boolean)) + ); + let BuiltinTypeCheckErrorKind::MapError(MapTypeCheckErrorKind::KeyTypeCheckFailed( + ref err, + )) = err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::Varint); + assert_matches!( + err.kind, + BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::BigInt, ColumnType::Counter] + } + ); + } + + // Value type mismatch + { + let err = deserialize::>( + &ColumnType::Map(Box::new(ColumnType::BigInt), Box::new(ColumnType::Boolean)), + &Bytes::new(), + ) + .unwrap_err(); + let err = get_typeck_err(&err); + assert_eq!(err.rust_name, std::any::type_name::>()); + assert_eq!( + err.cql_type, + ColumnType::Map(Box::new(ColumnType::BigInt), Box::new(ColumnType::Boolean)) + ); + let BuiltinTypeCheckErrorKind::MapError(MapTypeCheckErrorKind::ValueTypeCheckFailed( + ref err, + )) = err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::<&str>()); + assert_eq!(err.cql_type, ColumnType::Boolean); + assert_matches!( + err.kind, + BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::Ascii, ColumnType::Text] + } + ); + } + + // Key length mismatch + { + let ser_typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Boolean)); + let v = HashMap::from([(42, false), (2137, true)]); + let bytes = serialize(&ser_typ, &v as &dyn SerializeValue); + + let err = deserialize::>( + &ColumnType::Map(Box::new(ColumnType::BigInt), Box::new(ColumnType::Boolean)), + &bytes, + ) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, 
std::any::type_name::>()); + assert_eq!( + err.cql_type, + ColumnType::Map(Box::new(ColumnType::BigInt), Box::new(ColumnType::Boolean)) + ); + let BuiltinDeserializationErrorKind::MapError( + MapDeserializationErrorKind::KeyDeserializationFailed(err), + ) = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + let err = get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::BigInt); + assert_matches!( + err.kind, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 8, + got: 4 + } + ); + } + + // Value length mismatch + { + let ser_typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Boolean)); + let v = HashMap::from([(42, false), (2137, true)]); + let bytes = serialize(&ser_typ, &v as &dyn SerializeValue); + + let err = deserialize::>( + &ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::SmallInt)), + &bytes, + ) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::>()); + assert_eq!( + err.cql_type, + ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::SmallInt)) + ); + let BuiltinDeserializationErrorKind::MapError( + MapDeserializationErrorKind::ValueDeserializationFailed(err), + ) = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + let err = get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::SmallInt); + assert_matches!( + err.kind, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 2, + got: 1 + } + ); + } + } + + #[test] + fn test_tuple_errors() { + // Not a tuple + { + assert_type_check_error!( + &Bytes::new(), + (i64,), + ColumnType::BigInt, + BuiltinTypeCheckErrorKind::TupleError(TupleTypeCheckErrorKind::NotTuple) + ); + } + // Wrong element count + { + assert_type_check_error!( + &Bytes::new(), + (i64,), + ColumnType::Tuple(vec![]), + 
BuiltinTypeCheckErrorKind::TupleError(TupleTypeCheckErrorKind::WrongElementCount { + rust_type_el_count: 1, + cql_type_el_count: 0, + }) + ); + + assert_type_check_error!( + &Bytes::new(), + (f32,), + ColumnType::Tuple(vec![ColumnType::Float, ColumnType::Float]), + BuiltinTypeCheckErrorKind::TupleError(TupleTypeCheckErrorKind::WrongElementCount { + rust_type_el_count: 1, + cql_type_el_count: 2, + }) + ); + } + + // Bad field type + { + { + let err = deserialize::<(i64,)>( + &ColumnType::Tuple(vec![ColumnType::SmallInt]), + &Bytes::new(), + ) + .unwrap_err(); + let err = get_typeck_err(&err); + assert_eq!(err.rust_name, std::any::type_name::<(i64,)>()); + assert_eq!(err.cql_type, ColumnType::Tuple(vec![ColumnType::SmallInt])); + let BuiltinTypeCheckErrorKind::TupleError( + TupleTypeCheckErrorKind::FieldTypeCheckFailed { ref err, position }, + ) = err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + assert_eq!(position, 0); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::SmallInt); + assert_matches!( + err.kind, + BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::BigInt, ColumnType::Counter] + } + ); + } + } + + { + let ser_typ = ColumnType::Tuple(vec![ColumnType::Int, ColumnType::Float]); + let v = (123_i32, 123.123_f32); + let bytes = serialize(&ser_typ, &v); + + { + let err = deserialize::<(i32, f64)>( + &ColumnType::Tuple(vec![ColumnType::Int, ColumnType::Double]), + &bytes, + ) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::<(i32, f64)>()); + assert_eq!( + err.cql_type, + ColumnType::Tuple(vec![ColumnType::Int, ColumnType::Double]) + ); + let BuiltinDeserializationErrorKind::TupleError( + TupleDeserializationErrorKind::FieldDeserializationFailed { + ref err, + position: index, + }, + ) = err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + assert_eq!(index, 
1); + let err = get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::Double); + assert_matches!( + err.kind, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 8, + got: 4 + } + ); + } + } + } + + #[test] + fn test_null_errors() { + let ser_typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Boolean)); + let v = HashMap::from([(42, false), (2137, true)]); + let bytes = serialize(&ser_typ, &v as &dyn SerializeValue); + + deserialize::>(&ser_typ, &bytes).unwrap_err(); + } } From 58f523f26ad5d03772f4c9f5962d7131e3f4147f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Fri, 10 May 2024 06:55:00 +0200 Subject: [PATCH 36/41] deser/row: errors tests --- scylla-cql/src/types/deserialize/row.rs | 170 +++++++++++++++++++++++- 1 file changed, 169 insertions(+), 1 deletion(-) diff --git a/scylla-cql/src/types/deserialize/row.rs b/scylla-cql/src/types/deserialize/row.rs index 8ff109942a..c68658f4e1 100644 --- a/scylla-cql/src/types/deserialize/row.rs +++ b/scylla-cql/src/types/deserialize/row.rs @@ -413,13 +413,17 @@ impl Display for BuiltinDeserializationErrorKind { #[cfg(test)] mod tests { + use assert_matches::assert_matches; use bytes::Bytes; + use crate::frame::frame_errors::ParseError; use crate::frame::response::result::{ColumnSpec, ColumnType}; + use crate::types::deserialize::row::BuiltinDeserializationErrorKind; use crate::types::deserialize::{DeserializationError, FrameSlice}; use super::super::tests::{serialize_cells, spec}; - use super::{ColumnIterator, DeserializeRow}; + use super::{BuiltinDeserializationError, ColumnIterator, CqlValue, DeserializeRow, Row}; + use super::{BuiltinTypeCheckError, BuiltinTypeCheckErrorKind}; #[test] fn test_tuple_deserialization() { @@ -513,4 +517,168 @@ mod tests { let iter = ColumnIterator::new(specs, slice); >::deserialize(iter) } + + #[track_caller] + fn get_typck_err(err: &DeserializationError) -> 
&BuiltinTypeCheckError { + match err.0.downcast_ref() { + Some(err) => err, + None => panic!("not a BuiltinTypeCheckError: {:?}", err), + } + } + + #[track_caller] + fn get_deser_err(err: &DeserializationError) -> &BuiltinDeserializationError { + match err.0.downcast_ref() { + Some(err) => err, + None => panic!("not a BuiltinDeserializationError: {:?}", err), + } + } + + #[test] + fn test_tuple_errors() { + // Column type check failure + { + let col_name: &str = "i"; + let specs = &[spec(col_name, ColumnType::Int)]; + let err = deserialize::<(i64,)>(specs, &serialize_cells([val_int(123)])).unwrap_err(); + let err = get_typck_err(&err); + assert_eq!(err.rust_name, std::any::type_name::<(i64,)>()); + assert_eq!( + err.cql_types, + specs + .iter() + .map(|spec| spec.typ.clone()) + .collect::>() + ); + let BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { + column_index, + column_name, + err, + } = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + assert_eq!(*column_index, 0); + assert_eq!(column_name, col_name); + let err = super::super::value::tests::get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::Int); + assert_matches!( + &err.kind, + super::super::value::BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::BigInt, ColumnType::Counter] + } + ); + } + + // Column deserialization failure + { + let col_name: &str = "i"; + let err = deserialize::<(i64,)>( + &[spec(col_name, ColumnType::BigInt)], + &serialize_cells([val_int(123)]), + ) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::<(i64,)>()); + let BuiltinDeserializationErrorKind::ColumnDeserializationFailed { + column_name, + err, + .. 
+ } = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + assert_eq!(column_name, col_name); + let err = super::super::value::tests::get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::BigInt); + assert_matches!( + err.kind, + super::super::value::BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 8, + got: 4 + } + ); + } + + // Raw column deserialization failure + { + let col_name: &str = "i"; + let err = deserialize::<(i64,)>( + &[spec(col_name, ColumnType::BigInt)], + &Bytes::from_static(b"alamakota"), + ) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::<(i64,)>()); + let BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { + column_index: _column_index, + column_name, + err: _err, + } = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + assert_eq!(column_name, col_name); + } + } + + #[test] + fn test_row_errors() { + // Column type check failure - happens never, because Row consists of CqlValues, + // which accept all CQL types. 
+ + // Column deserialization failure + { + let col_name: &str = "i"; + let err = deserialize::( + &[spec(col_name, ColumnType::BigInt)], + &serialize_cells([val_int(123)]), + ) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + let BuiltinDeserializationErrorKind::ColumnDeserializationFailed { + column_index: _column_index, + column_name, + err, + } = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + assert_eq!(column_name, col_name); + let err = super::super::value::tests::get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::BigInt); + let super::super::value::BuiltinDeserializationErrorKind::GenericParseError( + ParseError::BadIncomingData(info), + ) = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + assert_eq!(info, "Buffer length should be 8 not 4"); + } + + // Raw column deserialization failure + { + let col_name: &str = "i"; + let err = deserialize::( + &[spec(col_name, ColumnType::BigInt)], + &Bytes::from_static(b"alamakota"), + ) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + let BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { + column_index: _column_index, + column_name, + err: _err, + } = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + assert_eq!(column_name, col_name); + } + } } From 802e6835424a188127fcc96e60cda387e1a23e00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Tue, 28 May 2024 12:08:58 +0200 Subject: [PATCH 37/41] result/deser_cql_value: use DeserializeValue impls In the future, we will probably deprecate and remove `deser_cql_value` altogether. For now, let's make it at least less bloaty. To reduce code duplication, `deser_cql_value()` now uses DeserializeValue impls for nearly all of the deserialized types. Two notable exceptions are: 1. 
CQL Map - because it is represented as Vec<(CqlValue, CqlValue)> in CqlValue, and Vec<T> is only deserializable from CQL Set|List. Therefore, MapIterator is deserialized using its DeserializeValue impl, and then collected into Vec. 2. CQL Tuple - because it is represented in CqlValue much differently than in DeserializeValue impls: Vec<CqlValue> vs (T1, T2, ..., Tn). Therefore, it is handled similarly to how it was before; only the style is changed from imperative to iterator-based, and DeserializeValue impl is called instead of `deser_cql_value` there. As a bonus, we get more descriptive error messages (as compared to old `ParseError::BadIncomingData` ones). --- scylla-cql/src/frame/response/result.rs | 299 ++++++------------ .../src/types/deserialize/frame_slice.rs | 26 ++ scylla-cql/src/types/deserialize/row.rs | 11 +- 3 files changed, 123 insertions(+), 213 deletions(-) diff --git a/scylla-cql/src/frame/response/result.rs b/scylla-cql/src/frame/response/result.rs index 527d481eb2..3d9b1b0914 100644 --- a/scylla-cql/src/frame/response/result.rs +++ b/scylla-cql/src/frame/response/result.rs @@ -1,19 +1,14 @@ use crate::cql_to_rust::{FromRow, FromRowError}; use crate::frame::response::event::SchemaChangeEvent; -use crate::frame::types::vint_decode; use crate::frame::value::{ Counter, CqlDate, CqlDecimal, CqlDuration, CqlTime, CqlTimestamp, CqlTimeuuid, CqlVarint, }; use crate::frame::{frame_errors::ParseError, types}; -use byteorder::{BigEndian, ReadBytesExt}; +use crate::types::deserialize::value::{DeserializeValue, MapIterator, UdtIterator}; +use crate::types::deserialize::FrameSlice; use bytes::{Buf, Bytes}; use std::borrow::Cow; -use std::{ - convert::{TryFrom, TryInto}, - net::IpAddr, - result::Result as StdResult, - str, -}; +use std::{convert::TryInto, net::IpAddr, result::Result as StdResult, str}; use uuid::Uuid; #[cfg(feature = "chrono")] @@ -655,6 +650,11 @@ pub fn deser_cql_value(typ: &ColumnType, buf: &mut &[u8]) -> StdResult return Ok(CqlValue::Empty), } } + // The
`new_borrowed` version of FrameSlice is deficient in that it does not hold + // a `Bytes` reference to the frame, only a slice. + // This is not a problem here, fortunately, because none of CqlValue variants contain + // any `Bytes` - only exclusively owned types - so we never call FrameSlice::to_bytes(). + let v = Some(FrameSlice::new_borrowed(buf)); Ok(match typ { Custom(type_str) => { @@ -664,239 +664,112 @@ pub fn deser_cql_value(typ: &ColumnType, buf: &mut &[u8]) -> StdResult { - if !buf.is_ascii() { - return Err(ParseError::BadIncomingData( - "String is not ascii!".to_string(), - )); - } - CqlValue::Ascii(str::from_utf8(buf)?.to_owned()) + let s = String::deserialize(typ, v)?; + CqlValue::Ascii(s) } Boolean => { - if buf.len() != 1 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 1 not {}", - buf.len() - ))); - } - CqlValue::Boolean(buf[0] != 0x00) + let b = bool::deserialize(typ, v)?; + CqlValue::Boolean(b) + } + Blob => { + let b = Vec::::deserialize(typ, v)?; + CqlValue::Blob(b) } - Blob => CqlValue::Blob(buf.to_vec()), Date => { - if buf.len() != 4 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 4 not {}", - buf.len() - ))); - } - - let date_value = buf.read_u32::()?; - CqlValue::Date(CqlDate(date_value)) + let d = CqlDate::deserialize(typ, v)?; + CqlValue::Date(d) } Counter => { - if buf.len() != 8 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 8 not {}", - buf.len() - ))); - } - CqlValue::Counter(crate::frame::value::Counter(buf.read_i64::()?)) + let c = crate::frame::response::result::Counter::deserialize(typ, v)?; + CqlValue::Counter(c) } Decimal => { - let scale = types::read_int(buf)?; - let bytes = buf.to_vec(); - let big_decimal: CqlDecimal = - CqlDecimal::from_signed_be_bytes_and_exponent(bytes, scale); - - CqlValue::Decimal(big_decimal) + let d = CqlDecimal::deserialize(typ, v)?; + CqlValue::Decimal(d) } Double => { - if buf.len() != 8 { - return 
Err(ParseError::BadIncomingData(format!( - "Buffer length should be 8 not {}", - buf.len() - ))); - } - CqlValue::Double(buf.read_f64::()?) + let d = f64::deserialize(typ, v)?; + CqlValue::Double(d) } Float => { - if buf.len() != 4 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 4 not {}", - buf.len() - ))); - } - CqlValue::Float(buf.read_f32::()?) + let f = f32::deserialize(typ, v)?; + CqlValue::Float(f) } Int => { - if buf.len() != 4 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 4 not {}", - buf.len() - ))); - } - CqlValue::Int(buf.read_i32::()?) + let i = i32::deserialize(typ, v)?; + CqlValue::Int(i) } SmallInt => { - if buf.len() != 2 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 2 not {}", - buf.len() - ))); - } - - CqlValue::SmallInt(buf.read_i16::()?) + let si = i16::deserialize(typ, v)?; + CqlValue::SmallInt(si) } TinyInt => { - if buf.len() != 1 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 1 not {}", - buf.len() - ))); - } - CqlValue::TinyInt(buf.read_i8()?) + let ti = i8::deserialize(typ, v)?; + CqlValue::TinyInt(ti) } BigInt => { - if buf.len() != 8 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 8 not {}", - buf.len() - ))); - } - CqlValue::BigInt(buf.read_i64::()?) 
+ let bi = i64::deserialize(typ, v)?; + CqlValue::BigInt(bi) + } + Text => { + let s = String::deserialize(typ, v)?; + CqlValue::Text(s) } - Text => CqlValue::Text(str::from_utf8(buf)?.to_owned()), Timestamp => { - if buf.len() != 8 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 8 not {}", - buf.len() - ))); - } - let millis = buf.read_i64::()?; - - CqlValue::Timestamp(CqlTimestamp(millis)) + let t = CqlTimestamp::deserialize(typ, v)?; + CqlValue::Timestamp(t) } Time => { - if buf.len() != 8 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 8 not {}", - buf.len() - ))); - } - let nanoseconds: i64 = buf.read_i64::()?; - - // Valid values are in the range 0 to 86399999999999 - if !(0..=86399999999999).contains(&nanoseconds) { - return Err(ParseError::BadIncomingData(format! { - "Invalid time value only 0 to 86399999999999 allowed: {}.", nanoseconds - })); - } - - CqlValue::Time(CqlTime(nanoseconds)) + let t = CqlTime::deserialize(typ, v)?; + CqlValue::Time(t) } Timeuuid => { - if buf.len() != 16 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 16 not {}", - buf.len() - ))); - } - let uuid = uuid::Uuid::from_slice(buf).expect("Deserializing Uuid failed."); - CqlValue::Timeuuid(CqlTimeuuid::from(uuid)) + let t = CqlTimeuuid::deserialize(typ, v)?; + CqlValue::Timeuuid(t) } Duration => { - let months = i32::try_from(vint_decode(buf)?)?; - let days = i32::try_from(vint_decode(buf)?)?; - let nanoseconds = vint_decode(buf)?; - - CqlValue::Duration(CqlDuration { - months, - days, - nanoseconds, - }) + let d = CqlDuration::deserialize(typ, v)?; + CqlValue::Duration(d) + } + Inet => { + let i = IpAddr::deserialize(typ, v)?; + CqlValue::Inet(i) } - Inet => CqlValue::Inet(match buf.len() { - 4 => { - let ret = IpAddr::from(<[u8; 4]>::try_from(&buf[0..4])?); - buf.advance(4); - ret - } - 16 => { - let ret = IpAddr::from(<[u8; 16]>::try_from(&buf[0..16])?); - buf.advance(16); - ret - } - v 
=> { - return Err(ParseError::BadIncomingData(format!( - "Invalid inet bytes length: {}", - v - ))); - } - }), Uuid => { - if buf.len() != 16 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 16 not {}", - buf.len() - ))); - } - let uuid = uuid::Uuid::from_slice(buf).expect("Deserializing Uuid failed."); + let uuid = uuid::Uuid::deserialize(typ, v)?; CqlValue::Uuid(uuid) } - Varint => CqlValue::Varint(CqlVarint::from_signed_bytes_be(buf.to_vec())), - List(type_name) => { - let len: usize = types::read_int(buf)?.try_into()?; - let mut res = Vec::with_capacity(len); - for _ in 0..len { - let mut b = types::read_bytes(buf)?; - res.push(deser_cql_value(type_name, &mut b)?); - } - CqlValue::List(res) + Varint => { + let vi = CqlVarint::deserialize(typ, v)?; + CqlValue::Varint(vi) } - Map(key_type, value_type) => { - let len: usize = types::read_int(buf)?.try_into()?; - let mut res = Vec::with_capacity(len); - for _ in 0..len { - let mut b = types::read_bytes(buf)?; - let key = deser_cql_value(key_type, &mut b)?; - b = types::read_bytes(buf)?; - let val = deser_cql_value(value_type, &mut b)?; - res.push((key, val)); - } - CqlValue::Map(res) + List(_type_name) => { + let l = Vec::::deserialize(typ, v)?; + CqlValue::List(l) } - Set(type_name) => { - let len: usize = types::read_int(buf)?.try_into()?; - let mut res = Vec::with_capacity(len); - for _ in 0..len { - // TODO: is `null` allowed as set element? Should we use read_bytes_opt? - let mut b = types::read_bytes(buf)?; - res.push(deser_cql_value(type_name, &mut b)?); - } - CqlValue::Set(res) + Map(_key_type, _value_type) => { + let iter = MapIterator::<'_, CqlValue, CqlValue>::deserialize(typ, v)?; + let m: Vec<(CqlValue, CqlValue)> = iter.collect::>()?; + CqlValue::Map(m) + } + Set(_type_name) => { + let s = Vec::::deserialize(typ, v)?; + CqlValue::Set(s) } UserDefinedType { type_name, keyspace, - field_types, + .. 
} => { - let mut fields: Vec<(String, Option)> = Vec::new(); - - for (field_name, field_type) in field_types { - // If a field is added to a UDT and we read an old (frozen ?) version of it, - // the driver will fail to parse the whole UDT. - // This is why we break the parsing after we reach the end of the serialized UDT. - if buf.is_empty() { - break; - } - - let mut field_value: Option = None; - if let Some(mut field_val_bytes) = types::read_bytes_opt(buf)? { - field_value = Some(deser_cql_value(field_type, &mut field_val_bytes)?); - } - - fields.push((field_name.clone(), field_value)); - } + let iter = UdtIterator::deserialize(typ, v)?; + let fields: Vec<(String, Option)> = iter + .map(|((col_name, col_type), res)| { + res.and_then(|v| { + let val = Option::::deserialize(col_type, v.flatten())?; + Ok((col_name.clone(), val)) + }) + }) + .collect::>()?; CqlValue::UserDefinedType { keyspace: keyspace.clone(), @@ -905,15 +778,19 @@ pub fn deser_cql_value(typ: &ColumnType, buf: &mut &[u8]) -> StdResult { - let mut res = Vec::with_capacity(type_names.len()); - for type_name in type_names { - match types::read_bytes_opt(buf)? { - Some(mut b) => res.push(Some(deser_cql_value(type_name, &mut b)?)), - None => res.push(None), - }; - } - - CqlValue::Tuple(res) + let t = type_names + .iter() + .map(|typ| { + types::read_bytes_opt(buf).and_then(|v| { + v.map(|v| { + CqlValue::deserialize(typ, Some(FrameSlice::new_borrowed(v))) + .map_err(Into::into) + }) + .transpose() + }) + }) + .collect::>()?; + CqlValue::Tuple(t) } }) } diff --git a/scylla-cql/src/types/deserialize/frame_slice.rs b/scylla-cql/src/types/deserialize/frame_slice.rs index 02713bbe7c..cfc98d5ce5 100644 --- a/scylla-cql/src/types/deserialize/frame_slice.rs +++ b/scylla-cql/src/types/deserialize/frame_slice.rs @@ -72,6 +72,23 @@ impl<'frame> FrameSlice<'frame> { } } + /// Creates a new FrameSlice from a reference to a slice. 
+ /// + /// This method creates a not-fully-valid FrameSlice that does not hold + /// the valid original frame Bytes. Thus, it is intended to be used in + /// legacy code that does not operate on Bytes, but rather on a borrowed slice only. + /// For correctness in the unlikely case that someone calls `to_bytes()` on such + /// a deficient slice, special treatment is added there that copies + /// the slice into a new-allocation-based Bytes. + /// This is pub(crate) for the above reason. + #[inline] + pub(crate) fn new_borrowed(frame_subslice: &'frame [u8]) -> Self { + Self { + frame_subslice, + original_frame: &EMPTY_BYTES, + } + } + /// Returns `true` if the slice has length of 0. #[inline] pub fn is_empty(&self) -> bool { @@ -105,6 +122,15 @@ impl<'frame> FrameSlice<'frame> { /// frame slice object. #[inline] pub fn to_bytes(&self) -> Bytes { + if self.original_frame.is_empty() { + // Special case for the borrowed, deficient version of FrameSlice - the one created + // with FrameSlice::new_borrowed. So that FrameSlice::to_bytes works properly if someone + // calls it on such a slice (even though it's not intended for the borrowed version), + // new Bytes are created by copying the slice into a new allocation. + // Note that this is not expected to ever be done.
+ return Bytes::copy_from_slice(self.as_slice()); + } + self.original_frame.slice_ref(self.frame_subslice) } diff --git a/scylla-cql/src/types/deserialize/row.rs b/scylla-cql/src/types/deserialize/row.rs index c68658f4e1..1eea286802 100644 --- a/scylla-cql/src/types/deserialize/row.rs +++ b/scylla-cql/src/types/deserialize/row.rs @@ -652,12 +652,19 @@ mod tests { assert_eq!(err.rust_name, std::any::type_name::()); assert_eq!(err.cql_type, ColumnType::BigInt); let super::super::value::BuiltinDeserializationErrorKind::GenericParseError( - ParseError::BadIncomingData(info), + ParseError::DeserializationError(d), ) = &err.kind else { panic!("unexpected error kind: {}", err.kind) }; - assert_eq!(info, "Buffer length should be 8 not 4"); + let err = super::super::value::tests::get_deser_err(d); + let super::super::value::BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 8, + got: 4, + } = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; } // Raw column deserialization failure From ff375352b2795ae73c2e08ae03326c8154fce8f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Tue, 28 May 2024 12:09:40 +0200 Subject: [PATCH 38/41] deser_rows: use DeserializeRow impl for Row In a manner similar to the previous commit, the old imperative logic in `deser_rows` is replaced with a new iterator-based one, which uses the new deserialization framework. As a bonus, we get more descriptive error messages (as compared to the old `ParseError::BadIncomingData` ones).
--- scylla-cql/src/frame/response/result.rs | 26 ++++++++++++------------- scylla-cql/src/types/deserialize/row.rs | 2 +- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/scylla-cql/src/frame/response/result.rs b/scylla-cql/src/frame/response/result.rs index 3d9b1b0914..f961f4b99e 100644 --- a/scylla-cql/src/frame/response/result.rs +++ b/scylla-cql/src/frame/response/result.rs @@ -4,8 +4,9 @@ use crate::frame::value::{ Counter, CqlDate, CqlDecimal, CqlDuration, CqlTime, CqlTimestamp, CqlTimeuuid, CqlVarint, }; use crate::frame::{frame_errors::ParseError, types}; +use crate::types::deserialize::result::{RowIterator, TypedRowIterator}; use crate::types::deserialize::value::{DeserializeValue, MapIterator, UdtIterator}; -use crate::types::deserialize::FrameSlice; +use crate::types::deserialize::{DeserializationError, FrameSlice}; use bytes::{Buf, Bytes}; use std::borrow::Cow; use std::{convert::TryInto, net::IpAddr, result::Result as StdResult, str}; @@ -820,19 +821,16 @@ fn deser_rows( let rows_count: usize = types::read_int(buf)?.try_into()?; - let mut rows = Vec::with_capacity(rows_count); - for _ in 0..rows_count { - let mut columns = Vec::with_capacity(metadata.col_count); - for i in 0..metadata.col_count { - let v = if let Some(mut b) = types::read_bytes_opt(buf)? { - Some(deser_cql_value(&metadata.col_specs[i].typ, &mut b)?) 
- } else { - None - }; - columns.push(v); - } - rows.push(Row { columns }); - } + let raw_rows_iter = RowIterator::new( + rows_count, + &metadata.col_specs, + FrameSlice::new_borrowed(buf), + ); + let rows_iter = TypedRowIterator::::new(raw_rows_iter) + .map_err(|err| DeserializationError::new(err.0))?; + + let rows = rows_iter.collect::>()?; + Ok(Rows { metadata, rows_count, diff --git a/scylla-cql/src/types/deserialize/row.rs b/scylla-cql/src/types/deserialize/row.rs index 1eea286802..c66f3c7328 100644 --- a/scylla-cql/src/types/deserialize/row.rs +++ b/scylla-cql/src/types/deserialize/row.rs @@ -138,7 +138,7 @@ make_error_replace_rust_name!( // legacy/dynamic deserialization as Row // /// While no longer encouraged (because the new framework encourages deserializing -/// directly into desired types, entirely bypassing CqlValue), this can be indispensable +/// directly into desired types, entirely bypassing [CqlValue]), this can be indispensable /// for some use cases, i.e. those involving dynamic parsing (ORMs?). impl<'frame> DeserializeRow<'frame> for Row { #[inline] From 9b8f4f3357ccbb079ddab428f10b95c3e80d974f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Tue, 4 Jun 2024 10:14:33 +0200 Subject: [PATCH 39/41] value: remove compatibility tests As `deser_cql_value` was made to use the new deserialization framework, these tests no longer test anything. Therefore, they are deleted. Their presence in the previous commits is useful, though, to prove that compatibility. It is worth noting that now, with `deser_cql_value` using the new framework, the tests in frame/response/result.rs are used to test the deserialization implementation in the new framework.
--- scylla-cql/src/types/deserialize/value.rs | 313 ---------------------- 1 file changed, 313 deletions(-) diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index 7fb8c8e098..b0c5858d18 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -1772,9 +1772,7 @@ pub(super) mod tests { use std::fmt::Debug; use std::net::{IpAddr, Ipv6Addr}; - use crate::frame::response::cql_to_rust::FromCqlVal; use crate::frame::response::result::{deser_cql_value, ColumnType, CqlValue}; - use crate::frame::types; use crate::frame::value::{ Counter, CqlDate, CqlDecimal, CqlDuration, CqlTime, CqlTimestamp, CqlTimeuuid, CqlVarint, }; @@ -1907,286 +1905,6 @@ pub(super) mod tests { assert_eq!(decoded_int, Some(MaybeEmpty::Empty)); } - #[test] - fn test_from_cql_value_compatibility() { - // This test should have a sub-case for each type - // that implements FromCqlValue - - // fixed size integers - for i in 0..7 { - let v: i8 = 1 << i; - compat_check::(&ColumnType::TinyInt, make_bytes(&v.to_be_bytes())); - compat_check::(&ColumnType::TinyInt, make_bytes(&(-v).to_be_bytes())); - } - for i in 0..15 { - let v: i16 = 1 << i; - compat_check::(&ColumnType::SmallInt, make_bytes(&v.to_be_bytes())); - compat_check::(&ColumnType::SmallInt, make_bytes(&(-v).to_be_bytes())); - } - for i in 0..31 { - let v: i32 = 1 << i; - compat_check::(&ColumnType::Int, make_bytes(&v.to_be_bytes())); - compat_check::(&ColumnType::Int, make_bytes(&(-v).to_be_bytes())); - } - for i in 0..63 { - let v: i64 = 1 << i; - compat_check::(&ColumnType::BigInt, make_bytes(&v.to_be_bytes())); - compat_check::(&ColumnType::BigInt, make_bytes(&(-v).to_be_bytes())); - } - - // bool - compat_check::(&ColumnType::Boolean, make_bytes(&[0])); - compat_check::(&ColumnType::Boolean, make_bytes(&[1])); - - // fixed size floating point types - compat_check::(&ColumnType::Float, make_bytes(&123f32.to_be_bytes())); - 
compat_check::(&ColumnType::Float, make_bytes(&(-123f32).to_be_bytes())); - compat_check::(&ColumnType::Double, make_bytes(&123f64.to_be_bytes())); - compat_check::(&ColumnType::Double, make_bytes(&(-123f64).to_be_bytes())); - - // big integers - const PI_STR: &[u8] = b"3.1415926535897932384626433832795028841971693993751058209749445923"; - let num1 = &PI_STR[2..]; - let num2 = [b'-'] - .into_iter() - .chain(PI_STR[2..].iter().copied()) - .collect::>(); - let num3 = &b"0"[..]; - - // native - CqlVarint - { - let num1 = CqlVarint::from_signed_bytes_be_slice(num1); - let num2 = CqlVarint::from_signed_bytes_be_slice(&num2); - let num3 = CqlVarint::from_signed_bytes_be_slice(num3); - compat_check_serialized::(&ColumnType::Varint, &num1); - compat_check_serialized::(&ColumnType::Varint, &num2); - compat_check_serialized::(&ColumnType::Varint, &num3); - } - - #[cfg(feature = "num-bigint-03")] - { - use num_bigint_03::BigInt; - - let num1 = BigInt::parse_bytes(num1, 10).unwrap(); - let num2 = BigInt::parse_bytes(&num2, 10).unwrap(); - let num3 = BigInt::parse_bytes(num3, 10).unwrap(); - compat_check_serialized::(&ColumnType::Varint, &num1); - compat_check_serialized::(&ColumnType::Varint, &num2); - compat_check_serialized::(&ColumnType::Varint, &num3); - } - - #[cfg(feature = "num-bigint-04")] - { - use num_bigint_04::BigInt; - - let num1 = BigInt::parse_bytes(num1, 10).unwrap(); - let num2 = BigInt::parse_bytes(&num2, 10).unwrap(); - let num3 = BigInt::parse_bytes(num3, 10).unwrap(); - compat_check_serialized::(&ColumnType::Varint, &num1); - compat_check_serialized::(&ColumnType::Varint, &num2); - compat_check_serialized::(&ColumnType::Varint, &num3); - } - - // big decimals - { - let scale1 = 0; - let scale2 = -42; - let scale3 = 2137; - let num1 = CqlDecimal::from_signed_be_bytes_slice_and_exponent(num1, scale1); - let num2 = CqlDecimal::from_signed_be_bytes_and_exponent(num2, scale2); - let num3 = CqlDecimal::from_signed_be_bytes_slice_and_exponent(num3, scale3); - 
compat_check_serialized::(&ColumnType::Decimal, &num1); - compat_check_serialized::(&ColumnType::Decimal, &num2); - compat_check_serialized::(&ColumnType::Decimal, &num3); - } - - // native - CqlDecimal - - #[cfg(feature = "bigdecimal-04")] - { - use bigdecimal_04::BigDecimal; - - let num1 = PI_STR.to_vec(); - let num2 = vec![b'-'] - .into_iter() - .chain(PI_STR.iter().copied()) - .collect::>(); - let num3 = b"0.0".to_vec(); - - let num1 = BigDecimal::parse_bytes(&num1, 10).unwrap(); - let num2 = BigDecimal::parse_bytes(&num2, 10).unwrap(); - let num3 = BigDecimal::parse_bytes(&num3, 10).unwrap(); - compat_check_serialized::(&ColumnType::Decimal, &num1); - compat_check_serialized::(&ColumnType::Decimal, &num2); - compat_check_serialized::(&ColumnType::Decimal, &num3); - } - - // blob - compat_check::>(&ColumnType::Blob, make_bytes(&[])); - compat_check::>(&ColumnType::Blob, make_bytes(&[1, 9, 2, 8, 3, 7, 4, 6, 5])); - - // text types - for typ in &[ColumnType::Ascii, ColumnType::Text] { - compat_check::(typ, make_bytes("".as_bytes())); - compat_check::(typ, make_bytes("foo".as_bytes())); - compat_check::(typ, make_bytes("superfragilisticexpialidocious".as_bytes())); - } - - // counters - for i in 0..63 { - let v: i64 = 1 << i; - compat_check::(&ColumnType::Counter, make_bytes(&v.to_be_bytes())); - } - - // duration - let duration1 = CqlDuration { - days: 123, - months: 456, - nanoseconds: 789, - }; - let duration2 = CqlDuration { - days: 987, - months: 654, - nanoseconds: 321, - }; - compat_check_serialized::(&ColumnType::Duration, &duration1); - compat_check_serialized::(&ColumnType::Duration, &duration2); - - // date - let date1 = (2u32.pow(31)).to_be_bytes(); - let date2 = (2u32.pow(31) - 30).to_be_bytes(); - let date3 = (2u32.pow(31) + 30).to_be_bytes(); - - compat_check::(&ColumnType::Date, make_bytes(&date1)); - compat_check::(&ColumnType::Date, make_bytes(&date2)); - compat_check::(&ColumnType::Date, make_bytes(&date3)); - - #[cfg(feature = "chrono")] - { - 
compat_check::(&ColumnType::Date, make_bytes(&date1)); - compat_check::(&ColumnType::Date, make_bytes(&date2)); - compat_check::(&ColumnType::Date, make_bytes(&date3)); - } - - #[cfg(feature = "time")] - { - compat_check::(&ColumnType::Date, make_bytes(&date1)); - compat_check::(&ColumnType::Date, make_bytes(&date2)); - compat_check::(&ColumnType::Date, make_bytes(&date3)); - } - - // time - let time1 = CqlTime(0); - let time2 = CqlTime(123456789); - let time3 = CqlTime(86399999999999); // maximum allowed - - compat_check_serialized::(&ColumnType::Time, &time1); - compat_check_serialized::(&ColumnType::Time, &time2); - compat_check_serialized::(&ColumnType::Time, &time3); - - #[cfg(feature = "chrono")] - { - compat_check_serialized::(&ColumnType::Time, &time1); - compat_check_serialized::(&ColumnType::Time, &time2); - compat_check_serialized::(&ColumnType::Time, &time3); - } - - #[cfg(feature = "time")] - { - compat_check_serialized::(&ColumnType::Time, &time1); - compat_check_serialized::(&ColumnType::Time, &time2); - compat_check_serialized::(&ColumnType::Time, &time3); - } - - // timestamp - let timestamp1 = CqlTimestamp(0); - let timestamp2 = CqlTimestamp(123456789); - let timestamp3 = CqlTimestamp(98765432123456); - - compat_check_serialized::(&ColumnType::Timestamp, ×tamp1); - compat_check_serialized::(&ColumnType::Timestamp, ×tamp2); - compat_check_serialized::(&ColumnType::Timestamp, ×tamp3); - - #[cfg(feature = "chrono")] - { - compat_check_serialized::>( - &ColumnType::Timestamp, - ×tamp1, - ); - compat_check_serialized::>( - &ColumnType::Timestamp, - ×tamp2, - ); - compat_check_serialized::>( - &ColumnType::Timestamp, - ×tamp3, - ); - } - - #[cfg(feature = "time")] - { - compat_check_serialized::(&ColumnType::Timestamp, ×tamp1); - compat_check_serialized::(&ColumnType::Timestamp, ×tamp2); - compat_check_serialized::(&ColumnType::Timestamp, ×tamp3); - } - - // inet - let ipv4 = IpAddr::from([127u8, 0, 0, 1]); - let ipv6: IpAddr = 
Ipv6Addr::LOCALHOST.into(); - compat_check::(&ColumnType::Inet, make_ip_address(ipv4)); - compat_check::(&ColumnType::Inet, make_ip_address(ipv6)); - - // uuid and timeuuid - // new_v4 generates random UUIDs, so these are different cases - for uuid in std::iter::repeat_with(Uuid::new_v4).take(3) { - compat_check_serialized::(&ColumnType::Uuid, &uuid); - compat_check_serialized::(&ColumnType::Timeuuid, &CqlTimeuuid::from(uuid)); - } - - // empty values - // ...are implemented via MaybeEmpty and are handled in other tests - - // nulls, represented via Option - compat_check_serialized::>(&ColumnType::Int, &123i32); - compat_check::>(&ColumnType::Int, make_null()); - - // collections - let mut list = BytesMut::new(); - list.put_i32(3); - append_bytes(&mut list, &123i32.to_be_bytes()); - append_bytes(&mut list, &456i32.to_be_bytes()); - append_bytes(&mut list, &789i32.to_be_bytes()); - let list = make_bytes(&list); - let list_type = ColumnType::List(Box::new(ColumnType::Int)); - compat_check::>(&list_type, list.clone()); - // Support for deserialization List -> {Hash,BTree}Set was removed not to cause confusion. - // Such deserialization would be lossy, which is unwanted. 
- - let mut set = BytesMut::new(); - set.put_i32(3); - append_bytes(&mut set, &123i32.to_be_bytes()); - append_bytes(&mut set, &456i32.to_be_bytes()); - append_bytes(&mut set, &789i32.to_be_bytes()); - let set = make_bytes(&set); - let set_type = ColumnType::Set(Box::new(ColumnType::Int)); - compat_check::>(&set_type, set.clone()); - compat_check::>(&set_type, set.clone()); - compat_check::>(&set_type, set); - - let mut map = BytesMut::new(); - map.put_i32(3); - append_bytes(&mut map, &123i32.to_be_bytes()); - append_bytes(&mut map, "quick".as_bytes()); - append_bytes(&mut map, &456i32.to_be_bytes()); - append_bytes(&mut map, "brown".as_bytes()); - append_bytes(&mut map, &789i32.to_be_bytes()); - append_bytes(&mut map, "fox".as_bytes()); - let map = make_bytes(&map); - let map_type = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Text)); - compat_check::>(&map_type, map.clone()); - compat_check::>(&map_type, map); - } - #[test] fn test_maybe_empty() { let empty = make_bytes(&[]); @@ -2348,37 +2066,6 @@ pub(super) mod tests { assert_eq!(tup, SwappedPair("foo", 42)); } - // Checks that both new and old serialization framework - // produces the same results in this case - fn compat_check(typ: &ColumnType, raw: Bytes) - where - T: for<'f> DeserializeValue<'f>, - T: FromCqlVal>, - T: Debug + PartialEq, - { - let mut slice = raw.as_ref(); - let mut cell = types::read_bytes_opt(&mut slice).unwrap(); - let old = T::from_cql( - cell.as_mut() - .map(|c| deser_cql_value(typ, c)) - .transpose() - .unwrap(), - ) - .unwrap(); - let new = deserialize::(typ, &raw).unwrap(); - assert_eq!(old, new); - } - - fn compat_check_serialized(typ: &ColumnType, val: &dyn SerializeValue) - where - T: for<'f> DeserializeValue<'f>, - T: FromCqlVal>, - T: Debug + PartialEq, - { - let raw = serialize(typ, val); - compat_check::(typ, raw); - } - fn deserialize<'frame, T>( typ: &'frame ColumnType, bytes: &'frame Bytes, From 9a3e6e634460a098a85ad8b8afc350476b18276a Mon Sep 17 
00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Tue, 4 Jun 2024 16:55:13 +0200 Subject: [PATCH 40/41] value: add ser/de identity tests A suite of tests is added, which assert that serialization composed with deserialization yields identity. --- scylla-cql/src/types/deserialize/value.rs | 320 ++++++++++++++++++++-- 1 file changed, 302 insertions(+), 18 deletions(-) diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index b0c5858d18..95103998e2 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -1770,9 +1770,9 @@ pub(super) mod tests { use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; use std::fmt::Debug; - use std::net::{IpAddr, Ipv6Addr}; + use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; - use crate::frame::response::result::{deser_cql_value, ColumnType, CqlValue}; + use crate::frame::response::result::{ColumnType, CqlValue}; use crate::frame::value::{ Counter, CqlDate, CqlDecimal, CqlDuration, CqlTime, CqlTimestamp, CqlTimeuuid, CqlVarint, }; @@ -1803,6 +1803,14 @@ pub(super) mod tests { assert_eq!(decoded_slice, ORIGINAL_BYTES); assert_eq!(decoded_vec, ORIGINAL_BYTES); assert_eq!(decoded_bytes, ORIGINAL_BYTES); + + // ser/de identity + + // Nonempty blob + assert_ser_de_identity(&ColumnType::Blob, &ORIGINAL_BYTES, &mut Bytes::new()); + + // Empty blob + assert_ser_de_identity(&ColumnType::Blob, &(&[] as &[u8]), &mut Bytes::new()); } #[test] @@ -1811,15 +1819,23 @@ pub(super) mod tests { let ascii = make_bytes(ASCII_TEXT.as_bytes()); - let decoded_ascii_str = deserialize::<&str>(&ColumnType::Ascii, &ascii).unwrap(); - let decoded_ascii_string = deserialize::(&ColumnType::Ascii, &ascii).unwrap(); - let decoded_text_str = deserialize::<&str>(&ColumnType::Text, &ascii).unwrap(); - let decoded_text_string = deserialize::(&ColumnType::Text, &ascii).unwrap(); + for typ in [ColumnType::Ascii, ColumnType::Text].iter() { + let decoded_str = 
deserialize::<&str>(typ, &ascii).unwrap(); + let decoded_string = deserialize::(typ, &ascii).unwrap(); + + assert_eq!(decoded_str, ASCII_TEXT); + assert_eq!(decoded_string, ASCII_TEXT); + + // ser/de identity + + // Empty string + assert_ser_de_identity(typ, &"", &mut Bytes::new()); + assert_ser_de_identity(typ, &"".to_owned(), &mut Bytes::new()); - assert_eq!(decoded_ascii_str, ASCII_TEXT); - assert_eq!(decoded_ascii_string, ASCII_TEXT); - assert_eq!(decoded_text_str, ASCII_TEXT); - assert_eq!(decoded_text_string, ASCII_TEXT); + // Nonempty string + assert_ser_de_identity(typ, &ASCII_TEXT, &mut Bytes::new()); + assert_ser_de_identity(typ, &ASCII_TEXT.to_owned(), &mut Bytes::new()); + } } #[test] @@ -1836,6 +1852,15 @@ pub(super) mod tests { let decoded_text_string = deserialize::(&ColumnType::Text, &unicode).unwrap(); assert_eq!(decoded_text_str, UNICODE_TEXT); assert_eq!(decoded_text_string, UNICODE_TEXT); + + // ser/de identity + + assert_ser_de_identity(&ColumnType::Text, &UNICODE_TEXT, &mut Bytes::new()); + assert_ser_de_identity( + &ColumnType::Text, + &UNICODE_TEXT.to_owned(), + &mut Bytes::new(), + ); } #[test] @@ -1855,6 +1880,12 @@ pub(super) mod tests { let bigint = make_bytes(&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]); let decoded_bigint = deserialize::(&ColumnType::BigInt, &bigint).unwrap(); assert_eq!(decoded_bigint, 0x0102030405060708); + + // ser/de identity + assert_ser_de_identity(&ColumnType::TinyInt, &42_i8, &mut Bytes::new()); + assert_ser_de_identity(&ColumnType::SmallInt, &2137_i16, &mut Bytes::new()); + assert_ser_de_identity(&ColumnType::Int, &21372137_i32, &mut Bytes::new()); + assert_ser_de_identity(&ColumnType::BigInt, &0_i64, &mut Bytes::new()); } #[test] @@ -1863,6 +1894,9 @@ pub(super) mod tests { let boolean_bytes = make_bytes(&[boolean as u8]); let decoded_bool = deserialize::(&ColumnType::Boolean, &boolean_bytes).unwrap(); assert_eq!(decoded_bool, boolean); + + // ser/de identity + 
assert_ser_de_identity(&ColumnType::Boolean, &boolean, &mut Bytes::new()); } } @@ -1875,6 +1909,150 @@ pub(super) mod tests { let double = make_bytes(&[64, 0, 0, 0, 0, 0, 0, 0]); let decoded_double = deserialize::(&ColumnType::Double, &double).unwrap(); assert_eq!(decoded_double, 2.0); + + // ser/de identity + assert_ser_de_identity(&ColumnType::Float, &21.37_f32, &mut Bytes::new()); + assert_ser_de_identity(&ColumnType::Double, &2137.2137_f64, &mut Bytes::new()); + } + + #[test] + fn test_varlen_numbers() { + // varint + assert_ser_de_identity( + &ColumnType::Varint, + &CqlVarint::from_signed_bytes_be_slice(b"Ala ma kota"), + &mut Bytes::new(), + ); + + #[cfg(feature = "num-bigint-03")] + assert_ser_de_identity( + &ColumnType::Varint, + &num_bigint_03::BigInt::from_signed_bytes_be(b"Kot ma Ale"), + &mut Bytes::new(), + ); + + #[cfg(feature = "num-bigint-04")] + assert_ser_de_identity( + &ColumnType::Varint, + &num_bigint_04::BigInt::from_signed_bytes_be(b"Kot ma Ale"), + &mut Bytes::new(), + ); + + // decimal + assert_ser_de_identity( + &ColumnType::Decimal, + &CqlDecimal::from_signed_be_bytes_slice_and_exponent(b"Ala ma kota", 42), + &mut Bytes::new(), + ); + + #[cfg(feature = "bigdecimal-04")] + assert_ser_de_identity( + &ColumnType::Decimal, + &bigdecimal_04::BigDecimal::new( + bigdecimal_04::num_bigint::BigInt::from_signed_bytes_be(b"Ala ma kota"), + 42, + ), + &mut Bytes::new(), + ); + } + + #[test] + fn test_date_time_types() { + // duration + assert_ser_de_identity( + &ColumnType::Duration, + &CqlDuration { + months: 21, + days: 37, + nanoseconds: 42, + }, + &mut Bytes::new(), + ); + + // date + assert_ser_de_identity(&ColumnType::Date, &CqlDate(0xbeaf), &mut Bytes::new()); + + #[cfg(feature = "chrono")] + assert_ser_de_identity( + &ColumnType::Date, + &chrono::NaiveDate::from_yo_opt(1999, 99).unwrap(), + &mut Bytes::new(), + ); + + #[cfg(feature = "time")] + assert_ser_de_identity( + &ColumnType::Date, + &time::Date::from_ordinal_date(1999, 99).unwrap(), + 
&mut Bytes::new(), + ); + + // time + assert_ser_de_identity(&ColumnType::Time, &CqlTime(0xdeed), &mut Bytes::new()); + + #[cfg(feature = "chrono")] + assert_ser_de_identity( + &ColumnType::Time, + &chrono::NaiveTime::from_hms_micro_opt(21, 37, 21, 37).unwrap(), + &mut Bytes::new(), + ); + + #[cfg(feature = "time")] + assert_ser_de_identity( + &ColumnType::Time, + &time::Time::from_hms_micro(21, 37, 21, 37).unwrap(), + &mut Bytes::new(), + ); + + // timestamp + assert_ser_de_identity( + &ColumnType::Timestamp, + &CqlTimestamp(0xceed), + &mut Bytes::new(), + ); + + #[cfg(feature = "chrono")] + assert_ser_de_identity( + &ColumnType::Timestamp, + &chrono::DateTime::::from_timestamp_millis(0xdead_cafe_deaf).unwrap(), + &mut Bytes::new(), + ); + + #[cfg(feature = "time")] + assert_ser_de_identity( + &ColumnType::Timestamp, + &time::OffsetDateTime::from_unix_timestamp(0xdead_cafe).unwrap(), + &mut Bytes::new(), + ); + } + + #[test] + fn test_inet() { + assert_ser_de_identity( + &ColumnType::Inet, + &IpAddr::V4(Ipv4Addr::BROADCAST), + &mut Bytes::new(), + ); + + assert_ser_de_identity( + &ColumnType::Inet, + &IpAddr::V6(Ipv6Addr::LOCALHOST), + &mut Bytes::new(), + ); + } + + #[test] + fn test_uuid() { + assert_ser_de_identity( + &ColumnType::Uuid, + &Uuid::from_u128(0xdead_cafe_deaf_feed_beaf_bead), + &mut Bytes::new(), + ); + + assert_ser_de_identity( + &ColumnType::Timeuuid, + &CqlTimeuuid::from_u128(0xdead_cafe_deaf_feed_beaf_bead), + &mut Bytes::new(), + ); } #[test] @@ -1903,6 +2081,15 @@ pub(super) mod tests { let int = make_bytes(&[]); let decoded_int = deserialize::>>(&ColumnType::Int, &int).unwrap(); assert_eq!(decoded_int, Some(MaybeEmpty::Empty)); + + // ser/de identity + assert_ser_de_identity(&ColumnType::Int, &Some(12321_i32), &mut Bytes::new()); + assert_ser_de_identity(&ColumnType::Double, &None::, &mut Bytes::new()); + assert_ser_de_identity( + &ColumnType::Set(Box::new(ColumnType::Ascii)), + &None::>, + &mut Bytes::new(), + ); } #[test] @@ -1917,6 
+2104,40 @@ pub(super) mod tests { assert_eq!(decoded_non_empty, MaybeEmpty::Value(0x01)); } + #[test] + fn test_cql_value() { + assert_ser_de_identity( + &ColumnType::Counter, + &CqlValue::Counter(Counter(765)), + &mut Bytes::new(), + ); + + assert_ser_de_identity( + &ColumnType::Timestamp, + &CqlValue::Timestamp(CqlTimestamp(2136)), + &mut Bytes::new(), + ); + + assert_ser_de_identity(&ColumnType::Boolean, &CqlValue::Empty, &mut Bytes::new()); + + assert_ser_de_identity( + &ColumnType::Text, + &CqlValue::Text("kremówki".to_owned()), + &mut Bytes::new(), + ); + assert_ser_de_identity( + &ColumnType::Ascii, + &CqlValue::Ascii("kremowy".to_owned()), + &mut Bytes::new(), + ); + + assert_ser_de_identity( + &ColumnType::Set(Box::new(ColumnType::Text)), + &CqlValue::Set(vec![CqlValue::Text("Ala ma kota".to_owned())]), + &mut Bytes::new(), + ); + } + #[test] fn test_list_and_set() { let mut collection_contents = BytesMut::new(); @@ -1969,6 +2190,20 @@ pub(super) mod tests { decoded_btree_string, expected_vec_string.into_iter().collect(), ); + + // ser/de identity + assert_ser_de_identity(&list_typ, &vec!["qwik"], &mut Bytes::new()); + assert_ser_de_identity(&set_typ, &vec!["qwik"], &mut Bytes::new()); + assert_ser_de_identity( + &set_typ, + &HashSet::<&str, std::collections::hash_map::RandomState>::from_iter(["qwik"]), + &mut Bytes::new(), + ); + assert_ser_de_identity( + &set_typ, + &BTreeSet::<&str>::from_iter(["qwik"]), + &mut Bytes::new(), + ); } #[test] @@ -2016,7 +2251,21 @@ pub(super) mod tests { decoded_btree_str, expected_str.clone().into_iter().collect(), ); - assert_eq!(decoded_btree_string, expected_string.into_iter().collect(),); + assert_eq!(decoded_btree_string, expected_string.into_iter().collect()); + + // ser/de identity + assert_ser_de_identity( + &typ, + &HashMap::::from_iter([( + -42, "qwik", + )]), + &mut Bytes::new(), + ); + assert_ser_de_identity( + &typ, + &BTreeMap::::from_iter([(-42, "qwik")]), + &mut Bytes::new(), + ); } #[test] @@ -2032,6 
+2281,37 @@ pub(super) mod tests { let tup = deserialize::<(i32, &str, Option)>(&typ, &tuple).unwrap(); assert_eq!(tup, (42, "foo", None)); + + // ser/de identity + + // () does not implement SerializeValue, yet it does implement DeserializeValue. + // assert_ser_de_identity(&ColumnType::Tuple(vec![]), &(), &mut Bytes::new()); + + // nonempty, varied tuple + assert_ser_de_identity( + &ColumnType::Tuple(vec![ + ColumnType::List(Box::new(ColumnType::Boolean)), + ColumnType::BigInt, + ColumnType::Uuid, + ColumnType::Inet, + ]), + &( + vec![true, false, true], + 42_i64, + Uuid::from_u128(0xdead_cafe_deaf_feed_beaf_bead), + IpAddr::V6(Ipv6Addr::new(0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x10, 0x11)), + ), + &mut Bytes::new(), + ); + + // nested tuples + assert_ser_de_identity( + &ColumnType::Tuple(vec![ColumnType::Tuple(vec![ColumnType::Tuple(vec![ + ColumnType::Text, + ])])]), + &((("",),),), + &mut Bytes::new(), + ); } #[test] @@ -2101,13 +2381,6 @@ pub(super) mod tests { *buf = v.into(); } - fn make_ip_address(ip: IpAddr) -> Bytes { - match ip { - IpAddr::V4(v4) => make_bytes(&v4.octets()), - IpAddr::V6(v6) => make_bytes(&v6.octets()), - } - } - fn append_bytes(b: &mut impl BufMut, cell: &[u8]) { b.put_i32(cell.len() as i32); b.put_slice(cell); @@ -2123,6 +2396,17 @@ pub(super) mod tests { b.put_i32(-1); } + fn assert_ser_de_identity<'f, T: SerializeValue + DeserializeValue<'f> + PartialEq + Debug>( + typ: &'f ColumnType, + v: &'f T, + buf: &'f mut Bytes, // `buf` must be passed as a reference from outside, because otherwise + // we cannot specify the lifetime for DeserializeValue. 
+ ) { + serialize_to_buf(typ, v, buf); + let deserialized = deserialize::(typ, buf).unwrap(); + assert_eq!(&deserialized, v); + } + /* Errors checks */ #[track_caller] From 30d51493d4577412ccf6633bc4765e900c8ec176 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Wed, 12 Jun 2024 13:39:49 +0200 Subject: [PATCH 41/41] scylla lib: constrain pub deserialize reexports --- scylla-cql/src/types/deserialize/mod.rs | 3 ++ scylla/src/lib.rs | 39 ++++++++++++++++++++++--- 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/scylla-cql/src/types/deserialize/mod.rs b/scylla-cql/src/types/deserialize/mod.rs index 0e9b2a0582..12e73052ba 100644 --- a/scylla-cql/src/types/deserialize/mod.rs +++ b/scylla-cql/src/types/deserialize/mod.rs @@ -169,6 +169,9 @@ pub mod value; pub use frame_slice::FrameSlice; +pub use row::DeserializeRow; +pub use value::DeserializeValue; + use std::error::Error; use std::fmt::Display; use std::sync::Arc; diff --git a/scylla/src/lib.rs b/scylla/src/lib.rs index 818c5ebbd0..e7b9afb7ee 100644 --- a/scylla/src/lib.rs +++ b/scylla/src/lib.rs @@ -126,10 +126,41 @@ pub mod frame { } } -// FIXME: finer-grained control over exports -// Some types are `pub` in scylla-cql just for scylla crate, -// and those shouldn't be exposed for users. -pub use scylla_cql::types::{deserialize, serialize}; +/// Serializing bound values of a query to be sent to the DB. +pub mod serialize { + pub use scylla_cql::types::serialize::*; +} + +/// Deserializing DB response containing CQL query results. +pub mod deserialize { + pub use scylla_cql::types::deserialize::{ + DeserializationError, DeserializeRow, DeserializeValue, FrameSlice, TypeCheckError, + }; + + /// Deserializing the whole query result contents. + pub mod result { + pub use scylla_cql::types::deserialize::result::{RowIterator, TypedRowIterator}; + } + + /// Deserializing a row of the query result. 
+ pub mod row { + pub use scylla_cql::types::deserialize::row::{ + BuiltinDeserializationError, BuiltinDeserializationErrorKind, BuiltinTypeCheckError, + BuiltinTypeCheckErrorKind, ColumnIterator, RawColumn, + }; + } + + /// Deserializing a single CQL value from a column of the query result row. + pub mod value { + pub use scylla_cql::types::deserialize::value::{ + BuiltinDeserializationError, BuiltinDeserializationErrorKind, BuiltinTypeCheckError, + BuiltinTypeCheckErrorKind, Emptiable, ListlikeIterator, MapDeserializationErrorKind, + MapIterator, MapTypeCheckErrorKind, MaybeEmpty, SetOrListDeserializationErrorKind, + SetOrListTypeCheckErrorKind, TupleDeserializationErrorKind, TupleTypeCheckErrorKind, + UdtIterator, UdtTypeCheckErrorKind, + }; + } +} pub mod authentication; #[cfg(feature = "cloud")]