From 2acca753ec27dc4f44fbf6fa039659a842d986c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Sat, 25 May 2024 13:30:48 +0200 Subject: [PATCH 01/29] macros/parser: use syn::Path, not TokenStream syn provides strong typing on syntax items, which is better for us. Co-authored-by: Piotr Dulikowski --- scylla-macros/src/parser.rs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/scylla-macros/src/parser.rs b/scylla-macros/src/parser.rs index ec72a81b1c..7c376c16f8 100644 --- a/scylla-macros/src/parser.rs +++ b/scylla-macros/src/parser.rs @@ -1,5 +1,4 @@ -use syn::{Data, DeriveInput, ExprLit, Fields, FieldsNamed, FieldsUnnamed, Lit}; -use syn::{Expr, Meta}; +use syn::{Data, DeriveInput, Expr, ExprLit, Fields, FieldsNamed, FieldsUnnamed, Lit, Meta}; /// Parses a struct DeriveInput and returns named fields of this struct. pub(crate) fn parse_named_fields<'a>( @@ -51,8 +50,8 @@ pub(crate) fn parse_struct_fields<'a>( } } -pub(crate) fn get_path(input: &DeriveInput) -> Result { - let mut this_path: Option = None; +pub(crate) fn get_path(input: &DeriveInput) -> Result { + let mut this_path: Option = None; for attr in input.attrs.iter() { if !attr.path().is_ident("scylla_crate") { continue; @@ -65,7 +64,7 @@ pub(crate) fn get_path(input: &DeriveInput) -> Result Result Date: Tue, 25 Jun 2024 12:46:02 +0200 Subject: [PATCH 02/29] deserialize: pub-ify items for use in macros --- scylla-cql/src/lib.rs | 4 +--- scylla-cql/src/types/deserialize/mod.rs | 4 +++- scylla-cql/src/types/deserialize/row.rs | 12 +++++++----- scylla-cql/src/types/deserialize/value.rs | 8 ++++++-- 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/scylla-cql/src/lib.rs b/scylla-cql/src/lib.rs index f8b3e07995..9d3ba489ce 100644 --- a/scylla-cql/src/lib.rs +++ b/scylla-cql/src/lib.rs @@ -27,7 +27,7 @@ pub mod _macro_internal { pub use crate::frame::response::cql_to_rust::{ FromCqlVal, FromCqlValError, FromRow, FromRowError, }; - pub use crate::frame::response::result::{CqlValue, Row}; + pub use crate::frame::response::result::{ColumnSpec, ColumnType, CqlValue, Row}; pub use crate::frame::value::{ LegacySerializedValues, SerializedResult, Value, ValueList, ValueTooBig, }; @@ -51,6 +51,4 @@ pub mod _macro_internal { pub use crate::types::serialize::{ CellValueBuilder, CellWriter, RowWriter, SerializationError, }; - - pub use crate::frame::response::result::ColumnType; } diff --git a/scylla-cql/src/types/deserialize/mod.rs b/scylla-cql/src/types/deserialize/mod.rs index 12e73052ba..2d8f0713e8 100644 --- a/scylla-cql/src/types/deserialize/mod.rs +++ b/scylla-cql/src/types/deserialize/mod.rs @@ -245,7 +245,9 @@ impl Display for DeserializationError { // - BEFORE an error is cloned (because otherwise the Arc::get_mut fails). macro_rules! make_error_replace_rust_name { ($fn_name: ident, $outer_err: ty, $inner_err: ty) => { - fn $fn_name(mut err: $outer_err) -> $outer_err { + // Not part of the public API; used in derive macros. + #[doc(hidden)] + pub fn $fn_name(mut err: $outer_err) -> $outer_err { // Safety: the assumed usage of this function guarantees that the Arc has not yet been cloned. let arc_mut = std::sync::Arc::get_mut(&mut err.0).unwrap(); diff --git a/scylla-cql/src/types/deserialize/row.rs b/scylla-cql/src/types/deserialize/row.rs index 5dfec4b12a..9e556537ea 100644 --- a/scylla-cql/src/types/deserialize/row.rs +++ b/scylla-cql/src/types/deserialize/row.rs @@ -260,7 +260,9 @@ pub struct BuiltinTypeCheckError { pub kind: BuiltinTypeCheckErrorKind, } -fn mk_typck_err( +// Not part of the public API; used in derive macros. +#[doc(hidden)] +pub fn mk_typck_err( cql_types: impl IntoIterator, kind: impl Into, ) -> TypeCheckError { @@ -338,13 +340,13 @@ pub struct BuiltinDeserializationError { pub kind: BuiltinDeserializationErrorKind, } -pub(super) fn mk_deser_err( - kind: impl Into, -) -> DeserializationError { +// Not part of the public API; used in derive macros. +#[doc(hidden)] +pub fn mk_deser_err(kind: impl Into) -> DeserializationError { mk_deser_err_named(std::any::type_name::(), kind) } -pub(super) fn mk_deser_err_named( +fn mk_deser_err_named( name: &'static str, kind: impl Into, ) -> DeserializationError { diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index 8431ea17cc..7ba4d38d84 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -1335,7 +1335,9 @@ pub struct BuiltinTypeCheckError { pub kind: BuiltinTypeCheckErrorKind, } -fn mk_typck_err( +// Not part of the public API; used in derive macros. +#[doc(hidden)] +pub fn mk_typck_err( cql_type: &ColumnType, kind: impl Into, ) -> TypeCheckError { @@ -1574,7 +1576,9 @@ pub struct BuiltinDeserializationError { pub kind: BuiltinDeserializationErrorKind, } -pub(crate) fn mk_deser_err( +// Not part of the public API; used in derive macros. +#[doc(hidden)] +pub fn mk_deser_err( cql_type: &ColumnType, kind: impl Into, ) -> DeserializationError { From cc234cdeda251143b52fe913f29ae3c9fdad270d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Tue, 25 Jun 2024 12:50:03 +0200 Subject: [PATCH 03/29] deserialize macros common helpers These helpers will be used by both DeserializeValue and DeserializeRow derive macros. Co-authored-by: Piotr Dulikowski --- scylla-macros/src/deserialize/mod.rs | 180 +++++++++++++++++++++++++++ scylla-macros/src/lib.rs | 2 + 2 files changed, 182 insertions(+) create mode 100644 scylla-macros/src/deserialize/mod.rs diff --git a/scylla-macros/src/deserialize/mod.rs b/scylla-macros/src/deserialize/mod.rs new file mode 100644 index 0000000000..523f06d907 --- /dev/null +++ b/scylla-macros/src/deserialize/mod.rs @@ -0,0 +1,180 @@ +use darling::{FromAttributes, FromField}; +use proc_macro2::Span; +use syn::parse_quote; + +/// Common attributes that all deserialize impls should understand. +trait DeserializeCommonStructAttrs { + /// The path to either `scylla` or `scylla_cql` crate. + fn crate_path(&self) -> Option<&syn::Path>; + + /// The path to `macro_internal` module, + /// which contains exports used by macros. + fn macro_internal_path(&self) -> syn::Path { + match self.crate_path() { + Some(path) => parse_quote!(#path::_macro_internal), + None => parse_quote!(scylla::_macro_internal), + } + } +} + +/// Provides access to attributes that are common to DeserializeValue +/// and DeserializeRow traits. +trait DeserializeCommonFieldAttrs { + /// Does the type of this field need Default to be implemented? + fn needs_default(&self) -> bool; + + /// The type of the field, i.e. what this field deserializes to. + fn deserialize_target(&self) -> &syn::Type; +} + +/// A structure helpful in implementing DeserializeValue and DeserializeRow. +/// +/// It implements some common logic for both traits: +/// - Generates a unique lifetime that binds all other lifetimes in both structs, +/// - Adds appropriate trait bounds (DeserializeValue + Default) +struct StructDescForDeserialize { + name: syn::Ident, + attrs: Attrs, + fields: Vec, + constraint_trait: syn::Path, + constraint_lifetime: syn::Lifetime, + + generics: syn::Generics, +} + +impl StructDescForDeserialize +where + Attrs: FromAttributes + DeserializeCommonStructAttrs, + Field: FromField + DeserializeCommonFieldAttrs, +{ + fn new( + input: &syn::DeriveInput, + trait_name: &str, + constraint_trait: syn::Path, + ) -> Result { + let attrs = Attrs::from_attributes(&input.attrs)?; + + // TODO: support structs with unnamed fields. + // A few things to consider: + // - such support would necessarily require `enforce_order` and `skip_name_checks` attributes to be passed, + // - either: + // - the inner code would have to represent unnamed fields differently and handle the errors differently, + // - or we could use `.0, .1` or `0`, `1` as names for consecutive fields, making representation and error handling uniform. + let fields = crate::parser::parse_named_fields(input, trait_name) + .unwrap_or_else(|err| panic!("{}", err)) + .named + .iter() + .map(Field::from_field) + .collect::>()?; + + let constraint_lifetime = generate_unique_lifetime_for_impl(&input.generics); + + Ok(Self { + name: input.ident.clone(), + attrs, + fields, + constraint_trait, + constraint_lifetime, + generics: input.generics.clone(), + }) + } + + fn struct_attrs(&self) -> &Attrs { + &self.attrs + } + + fn constraint_lifetime(&self) -> &syn::Lifetime { + &self.constraint_lifetime + } + + fn fields(&self) -> &[Field] { + &self.fields + } + + fn generate_impl( + &self, + trait_: syn::Path, + items: impl IntoIterator, + ) -> syn::ItemImpl { + let constraint_lifetime = &self.constraint_lifetime; + let (_, ty_generics, _) = self.generics.split_for_impl(); + let impl_generics = &self.generics.params; + + let macro_internal = self.attrs.macro_internal_path(); + let struct_name = &self.name; + let predicates = generate_lifetime_constraints_for_impl( + &self.generics, + self.constraint_trait.clone(), + &self.constraint_lifetime, + ) + .chain(generate_default_constraints(&self.fields)); + let trait_: syn::Path = parse_quote!(#macro_internal::#trait_); + let items = items.into_iter(); + + parse_quote! { + impl<#constraint_lifetime, #impl_generics> #trait_<#constraint_lifetime> for #struct_name #ty_generics + where #(#predicates),* + { + #(#items)* + } + } + } +} + +/// Generates T: Default constraints for those fields that need it. +fn generate_default_constraints( + fields: &[Field], +) -> impl Iterator + '_ { + fields.iter().filter(|f| f.needs_default()).map(|f| { + let t = f.deserialize_target(); + parse_quote!(#t: std::default::Default) + }) +} + +/// Helps introduce a lifetime to an `impl` definition that constrains +/// other lifetimes and types. +/// +/// The original use case is DeserializeValue and DeserializeRow. Both of those traits +/// are parametrized with a lifetime. If T: DeserializeValue<'a> then this means +/// that you can deserialize T as some CQL value from bytes that have +/// lifetime 'a, similarly for DeserializeRow. In impls for those traits, +/// an additional lifetime must be introduced and properly constrained. +fn generate_lifetime_constraints_for_impl<'a>( + generics: &'a syn::Generics, + trait_full_name: syn::Path, + constraint_lifetime: &'a syn::Lifetime, +) -> impl Iterator + 'a { + // Constrain the new lifetime with the existing lifetime parameters + // 'lifetime: 'a + 'b + 'c ... + let mut lifetimes = generics.lifetimes().map(|l| &l.lifetime).peekable(); + let lifetime_constraints = std::iter::from_fn(move || { + let lifetimes = lifetimes.by_ref(); + lifetimes + .peek() + .is_some() + .then::(|| parse_quote!(#constraint_lifetime: #(#lifetimes)+*)) + }); + + // For each type parameter T, constrain it like this: + // T: DeserializeValue<'lifetime>, + let type_constraints = generics.type_params().map(move |t| { + let t_ident = &t.ident; + parse_quote!(#t_ident: #trait_full_name<#constraint_lifetime>) + }); + + lifetime_constraints.chain(type_constraints) +} + +/// Generates a new lifetime parameter, with a different name to any of the +/// existing generic lifetimes. +fn generate_unique_lifetime_for_impl(generics: &syn::Generics) -> syn::Lifetime { + let mut constraint_lifetime_name = "'lifetime".to_string(); + while generics + .lifetimes() + .any(|l| l.lifetime.to_string() == constraint_lifetime_name) + { + // Extend the lifetime name with another underscore. + constraint_lifetime_name += "_"; + } + syn::Lifetime::new(&constraint_lifetime_name, Span::call_site()) +} diff --git a/scylla-macros/src/lib.rs b/scylla-macros/src/lib.rs index 5022f09f15..babb890481 100644 --- a/scylla-macros/src/lib.rs +++ b/scylla-macros/src/lib.rs @@ -66,3 +66,5 @@ pub fn value_list_derive(tokens_input: TokenStream) -> TokenStream { let res = value_list::value_list_derive(tokens_input); res.unwrap_or_else(|e| e.into_compile_error().into()) } + +mod deserialize; From a1d37d238afa08713d43189a32cd63522cc1e0a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Tue, 25 Jun 2024 13:51:52 +0200 Subject: [PATCH 04/29] DeserializeValue: unordered flavour support DeserializeValue unordered flavour is added, analogous to SerializeValue unordered flavour. Co-authored-by: Piotr Dulikowski --- scylla-cql/src/lib.rs | 11 + scylla-cql/src/types/deserialize/value.rs | 80 +++++ scylla-macros/src/deserialize/mod.rs | 2 + scylla-macros/src/deserialize/value.rs | 395 ++++++++++++++++++++++ scylla-macros/src/lib.rs | 9 +- scylla/src/macros.rs | 75 +++- 6 files changed, 569 insertions(+), 3 deletions(-) create mode 100644 scylla-macros/src/deserialize/value.rs diff --git a/scylla-cql/src/lib.rs b/scylla-cql/src/lib.rs index 9d3ba489ce..2cb5327c3f 100644 --- a/scylla-cql/src/lib.rs +++ b/scylla-cql/src/lib.rs @@ -2,6 +2,7 @@ pub mod errors; pub mod frame; #[macro_use] pub mod macros { + pub use scylla_macros::DeserializeValue; pub use scylla_macros::FromRow; pub use scylla_macros::FromUserType; pub use scylla_macros::IntoUserType; @@ -33,6 +34,16 @@ pub mod _macro_internal { }; pub use crate::macros::*; + pub use crate::types::deserialize::value::{ + deser_error_replace_rust_name as value_deser_error_replace_rust_name, + mk_deser_err as mk_value_deser_err, mk_typck_err as mk_value_typck_err, + BuiltinDeserializationError as BuiltinTypeDeserializationError, + BuiltinDeserializationErrorKind as BuiltinTypeDeserializationErrorKind, + BuiltinTypeCheckErrorKind as DeserBuiltinTypeTypeCheckErrorKind, DeserializeValue, + UdtDeserializationErrorKind, UdtIterator, + UdtTypeCheckErrorKind as DeserUdtTypeCheckErrorKind, + }; + pub use crate::types::deserialize::{DeserializationError, FrameSlice, TypeCheckError}; pub use crate::types::serialize::row::{ BuiltinSerializationError as BuiltinRowSerializationError, BuiltinSerializationErrorKind as BuiltinRowSerializationErrorKind, diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index 7ba4d38d84..1d56300259 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -1549,6 +1549,33 @@ impl Display for TupleTypeCheckErrorKind { pub enum UdtTypeCheckErrorKind { /// The CQL type is not a user defined type. NotUdt, + + /// The CQL UDT type does not have some fields that is required in the Rust struct. + ValuesMissingForUdtFields { + /// Names of fields that the Rust struct requires but are missing in the CQL UDT. + field_names: Vec<&'static str>, + }, + + /// UDT contains an excess field, which does not correspond to any Rust struct's field. + ExcessFieldInUdt { + /// The name of the CQL UDT field. + db_field_name: String, + }, + + /// Duplicated field in serialized data. + DuplicatedField { + /// The name of the duplicated field. + field_name: String, + }, + + /// Type check failed between UDT and Rust type field. + FieldTypeCheckFailed { + /// The name of the field whose type check failed. + field_name: String, + + /// Inner type check error that occured. + err: TypeCheckError, + }, } impl Display for UdtTypeCheckErrorKind { @@ -1558,6 +1585,25 @@ impl Display for UdtTypeCheckErrorKind { f, "the CQL type the Rust type was attempted to be type checked against is not a UDT" ), + UdtTypeCheckErrorKind::ValuesMissingForUdtFields { field_names } => { + write!(f, "the fields {field_names:?} are missing from the DB data but are required by the Rust type") + }, + UdtTypeCheckErrorKind::ExcessFieldInUdt { db_field_name } => write!( + f, + "UDT contains an excess field {}, which does not correspond to any Rust struct's field.", + db_field_name + ), + UdtTypeCheckErrorKind::DuplicatedField { field_name } => write!( + f, + "field {} occurs more than once in CQL UDT type", + field_name + ), + UdtTypeCheckErrorKind::FieldTypeCheckFailed { field_name, err } => write!( + f, + "the UDT field {} types between the CQL type and the Rust type failed to type check against each other: {}", + field_name, + err + ), } } } @@ -1643,6 +1689,9 @@ pub enum BuiltinDeserializationErrorKind { /// A deserialization failure specific to a CQL tuple. TupleError(TupleDeserializationErrorKind), + + /// A deserialization failure specific to a CQL UDT. + UdtError(UdtDeserializationErrorKind), } impl Display for BuiltinDeserializationErrorKind { @@ -1675,6 +1724,7 @@ impl Display for BuiltinDeserializationErrorKind { BuiltinDeserializationErrorKind::SetOrListError(err) => err.fmt(f), BuiltinDeserializationErrorKind::MapError(err) => err.fmt(f), BuiltinDeserializationErrorKind::TupleError(err) => err.fmt(f), + BuiltinDeserializationErrorKind::UdtError(err) => err.fmt(f), BuiltinDeserializationErrorKind::CustomTypeNotSupported(typ) => write!(f, "Support for custom types is not yet implemented: {}", typ), } } @@ -1780,6 +1830,36 @@ impl From for BuiltinDeserializationErrorKind { } } +/// Describes why deserialization of a user defined type failed. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum UdtDeserializationErrorKind { + /// One of the fields failed to deserialize. + FieldDeserializationFailed { + /// Name of the field which failed to deserialize. + field_name: String, + + /// The error that caused the UDT field deserialization to fail. + err: DeserializationError, + }, +} + +impl Display for UdtDeserializationErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + UdtDeserializationErrorKind::FieldDeserializationFailed { field_name, err } => { + write!(f, "field {field_name} failed to deserialize: {err}") + } + } + } +} + +impl From for BuiltinDeserializationErrorKind { + fn from(err: UdtDeserializationErrorKind) -> Self { + Self::UdtError(err) + } +} + #[cfg(test)] pub(super) mod tests { use assert_matches::assert_matches; diff --git a/scylla-macros/src/deserialize/mod.rs b/scylla-macros/src/deserialize/mod.rs index 523f06d907..7533046cc4 100644 --- a/scylla-macros/src/deserialize/mod.rs +++ b/scylla-macros/src/deserialize/mod.rs @@ -2,6 +2,8 @@ use darling::{FromAttributes, FromField}; use proc_macro2::Span; use syn::parse_quote; +pub(crate) mod value; + /// Common attributes that all deserialize impls should understand. trait DeserializeCommonStructAttrs { /// The path to either `scylla` or `scylla_cql` crate. diff --git a/scylla-macros/src/deserialize/value.rs b/scylla-macros/src/deserialize/value.rs new file mode 100644 index 0000000000..03ed4a28fb --- /dev/null +++ b/scylla-macros/src/deserialize/value.rs @@ -0,0 +1,395 @@ +use darling::{FromAttributes, FromField}; +use proc_macro::TokenStream; +use proc_macro2::Span; +use syn::{ext::IdentExt, parse_quote}; + +use super::{DeserializeCommonFieldAttrs, DeserializeCommonStructAttrs}; + +#[derive(FromAttributes)] +#[darling(attributes(scylla))] +struct StructAttrs { + #[darling(rename = "crate")] + crate_path: Option, +} + +impl DeserializeCommonStructAttrs for StructAttrs { + fn crate_path(&self) -> Option<&syn::Path> { + self.crate_path.as_ref() + } +} + +#[derive(FromField)] +#[darling(attributes(scylla))] +struct Field { + // If true, then the field is not parsed at all, but it is initialized + // with Default::default() instead. All other attributes are ignored. + #[darling(default)] + skip: bool, + ident: Option, + ty: syn::Type, +} + +impl DeserializeCommonFieldAttrs for Field { + fn needs_default(&self) -> bool { + self.skip + } + + fn deserialize_target(&self) -> &syn::Type { + &self.ty + } +} + +// derive(DeserializeValue) for the DeserializeValue trait +pub(crate) fn deserialize_value_derive( + tokens_input: TokenStream, +) -> Result { + let input = syn::parse(tokens_input)?; + + let implemented_trait: syn::Path = parse_quote!(DeserializeValue); + let implemented_trait_name = implemented_trait + .segments + .last() + .unwrap() + .ident + .unraw() + .to_string(); + let constraining_trait = implemented_trait.clone(); + let s = StructDesc::new(&input, &implemented_trait_name, constraining_trait)?; + + let items = [ + s.generate_type_check_method().into(), + s.generate_deserialize_method().into(), + ]; + + Ok(s.generate_impl(implemented_trait, items)) +} +impl Field { + // Returns whether this field is mandatory for deserialization. + fn is_required(&self) -> bool { + !self.skip + } + + // A Rust literal representing the name of this field + fn cql_name_literal(&self) -> syn::LitStr { + let field_name = self.ident.as_ref().unwrap().unraw().to_string(); + syn::LitStr::new(&field_name, Span::call_site()) + } +} + +type StructDesc = super::StructDescForDeserialize; + +impl StructDesc { + /// Generates an expression which extracts the UDT fields or returns an error. + fn generate_extract_fields_from_type(&self, typ_expr: syn::Expr) -> syn::Expr { + let macro_internal = &self.struct_attrs().macro_internal_path(); + parse_quote!( + match #typ_expr { + #macro_internal::ColumnType::UserDefinedType { field_types, .. } => field_types, + other => return ::std::result::Result::Err( + #macro_internal::mk_value_typck_err::( + &other, + #macro_internal::DeserUdtTypeCheckErrorKind::NotUdt, + ) + ), + } + ) + } + + fn generate_type_check_method(&self) -> syn::ImplItemFn { + TypeCheckUnorderedGenerator(self).generate() + } + + fn generate_deserialize_method(&self) -> syn::ImplItemFn { + DeserializeUnorderedGenerator(self).generate() + } +} + +struct TypeCheckUnorderedGenerator<'sd>(&'sd StructDesc); + +impl<'sd> TypeCheckUnorderedGenerator<'sd> { + // An identifier for a bool variable that represents whether given + // field was already visited during type check + fn visited_flag_variable(field: &Field) -> syn::Ident { + quote::format_ident!("visited_{}", field.ident.as_ref().unwrap().unraw()) + } + + // Generates a declaration of a "visited" flag for the purpose of type check. + // We generate it even if the flag is not required in order to protect + // from fields appearing more than once + fn generate_visited_flag_decl(field: &Field) -> Option { + (!field.skip).then(|| { + let visited_flag = Self::visited_flag_variable(field); + parse_quote! { + let mut #visited_flag = false; + } + }) + } + + // Generates code that, given variable `typ`, type-checks given field + fn generate_type_check(&self, field: &Field) -> Option { + (!field.skip).then(|| { + let macro_internal = self.0.struct_attrs().macro_internal_path(); + let constraint_lifetime = self.0.constraint_lifetime(); + let visited_flag = Self::visited_flag_variable(field); + let typ = field.deserialize_target(); + let cql_name_literal = field.cql_name_literal(); + let decrement_if_required: Option = field + .is_required() + .then(|| parse_quote! {remaining_required_cql_fields -= 1;}); + + parse_quote! { + { + if !#visited_flag { + <#typ as #macro_internal::DeserializeValue<#constraint_lifetime>>::type_check(cql_field_typ) + .map_err(|err| #macro_internal::mk_value_typck_err::( + typ, + #macro_internal::DeserUdtTypeCheckErrorKind::FieldTypeCheckFailed { + field_name: <_ as ::std::clone::Clone>::clone(cql_field_name), + err, + } + ))?; + #visited_flag = true; + #decrement_if_required + } else { + return ::std::result::Result::Err( + #macro_internal::mk_value_typck_err::( + typ, + #macro_internal::DeserUdtTypeCheckErrorKind::DuplicatedField { + field_name: <_ as ::std::borrow::ToOwned>::to_owned(#cql_name_literal), + } + ) + ) + } + } + } + }) + } + + // Generates code that appends the field name if it is missing. + // The generated code is used to construct a nice error message. + fn generate_append_name(field: &Field) -> Option { + field.is_required().then(|| { + let visited_flag = Self::visited_flag_variable(field); + let cql_name_literal = field.cql_name_literal(); + parse_quote!( + { + if !#visited_flag { + missing_fields.push(#cql_name_literal); + } + } + ) + }) + } + + // Generates the type_check method. + fn generate(&self) -> syn::ImplItemFn { + // The generated method will: + // - Check that every required field appears on the list exactly once, in any order + // - Every type on the list is correct + + let macro_internal = &self.0.struct_attrs().macro_internal_path(); + let rust_fields = self.0.fields(); + let visited_field_declarations = rust_fields + .iter() + .flat_map(Self::generate_visited_flag_decl); + let type_check_blocks = rust_fields.iter().flat_map(|f| self.generate_type_check(f)); + let append_name_blocks = rust_fields.iter().flat_map(Self::generate_append_name); + let rust_nonskipped_field_names = rust_fields + .iter() + .filter(|f| !f.skip) + .map(|f| f.cql_name_literal()); + let required_cql_field_count = rust_fields.iter().filter(|f| f.is_required()).count(); + let required_cql_field_count_lit = + syn::LitInt::new(&required_cql_field_count.to_string(), Span::call_site()); + let extract_cql_fields_expr = self.0.generate_extract_fields_from_type(parse_quote!(typ)); + + parse_quote! { + fn type_check( + typ: &#macro_internal::ColumnType, + ) -> ::std::result::Result<(), #macro_internal::TypeCheckError> { + // Extract information about the field types from the UDT + // type definition. + let cql_fields = #extract_cql_fields_expr; + + // Counts down how many required fields are remaining + let mut remaining_required_cql_fields: ::std::primitive::usize = #required_cql_field_count_lit; + + // For each required field, generate a "visited" boolean flag + #(#visited_field_declarations)* + + for (cql_field_name, cql_field_typ) in cql_fields { + // Pattern match on the name and verify that the type is correct. + match cql_field_name.as_str() { + #(#rust_nonskipped_field_names => #type_check_blocks,)* + _unknown => { + // We ignore excess UDT fields, as this facilitates the process of adding new fields + // to a UDT in running production cluster & clients. + } + } + } + + if remaining_required_cql_fields > 0 { + // If there are some missing required fields, generate an error + // which contains missing field names + let mut missing_fields = ::std::vec::Vec::<&'static str>::with_capacity(remaining_required_cql_fields); + #(#append_name_blocks)* + return ::std::result::Result::Err( + #macro_internal::mk_value_typck_err::( + typ, + #macro_internal::DeserUdtTypeCheckErrorKind::ValuesMissingForUdtFields { + field_names: missing_fields, + } + ) + ) + } + + ::std::result::Result::Ok(()) + } + } + } +} + +struct DeserializeUnorderedGenerator<'sd>(&'sd StructDesc); + +impl<'sd> DeserializeUnorderedGenerator<'sd> { + /// An identifier for a variable that is meant to store the parsed variable + /// before being ultimately moved to the struct on deserialize. + fn deserialize_field_variable(field: &Field) -> syn::Ident { + quote::format_ident!("f_{}", field.ident.as_ref().unwrap().unraw()) + } + + /// Generates an expression which produces a value ready to be put into a field + /// of the target structure. + fn generate_finalize_field(&self, field: &Field) -> syn::Expr { + if field.skip { + // Skipped fields are initialized with Default::default() + return parse_quote!(::std::default::Default::default()); + } + + let deserialize_field = Self::deserialize_field_variable(field); + let cql_name_literal = field.cql_name_literal(); + parse_quote!(#deserialize_field.unwrap_or_else(|| panic!( + "field {} missing in UDT - type check should have prevented this!", + #cql_name_literal + ))) + } + + /// Generates code that performs deserialization when the raw field + /// is being processed. + fn generate_deserialization(&self, field: &Field) -> Option { + (!field.skip).then(|| { + let macro_internal = self.0.struct_attrs().macro_internal_path(); + let constraint_lifetime = self.0.constraint_lifetime(); + let deserialize_field = Self::deserialize_field_variable(field); + let cql_name_literal = field.cql_name_literal(); + let deserializer = field.deserialize_target(); + + let do_deserialize: syn::Expr = parse_quote! { + <#deserializer as #macro_internal::DeserializeValue<#constraint_lifetime>>::deserialize(cql_field_typ, value) + .map_err(|err| #macro_internal::mk_value_deser_err::( + typ, + #macro_internal::UdtDeserializationErrorKind::FieldDeserializationFailed { + field_name: #cql_name_literal.to_owned(), + err, + } + ))? + }; + + parse_quote! { + { + assert!( + #deserialize_field.is_none(), + "duplicated field {} - type check should have prevented this!", + stringify!(#deserialize_field) + ); + + // The value can be either + // - None - missing from the serialized representation + // - Some(None) - present in the serialized representation but null + // For now, we treat both cases as "null". + let value = value.flatten(); + + #deserialize_field = ::std::option::Option::Some( + #do_deserialize + ); + } + } + }) + } + + // Generate a declaration of a variable that temporarily keeps + // the deserialized value + fn generate_deserialize_field_decl(field: &Field) -> Option { + (!field.skip).then(|| { + let deserialize_field = Self::deserialize_field_variable(field); + parse_quote! { + let mut #deserialize_field = ::std::option::Option::None; + } + }) + } + + fn generate(&self) -> syn::ImplItemFn { + let macro_internal = self.0.struct_attrs().macro_internal_path(); + let constraint_lifetime = self.0.constraint_lifetime(); + let fields = self.0.fields(); + + let deserialize_field_decls = fields.iter().map(Self::generate_deserialize_field_decl); + let deserialize_blocks = fields.iter().flat_map(|f| self.generate_deserialization(f)); + let rust_field_idents = fields.iter().map(|f| f.ident.as_ref().unwrap()); + let rust_nonskipped_field_names = fields + .iter() + .filter(|f| !f.skip) + .map(|f| f.cql_name_literal()); + + let field_finalizers = fields.iter().map(|f| self.generate_finalize_field(f)); + + let iterator_type: syn::Type = parse_quote! { + #macro_internal::UdtIterator<#constraint_lifetime> + }; + + // TODO: Allow collecting unrecognized fields into some special field + + parse_quote! { + fn deserialize( + typ: &#constraint_lifetime #macro_internal::ColumnType, + v: ::std::option::Option<#macro_internal::FrameSlice<#constraint_lifetime>>, + ) -> ::std::result::Result { + // Create an iterator over the fields of the UDT. + let cql_field_iter = <#iterator_type as #macro_internal::DeserializeValue<#constraint_lifetime>>::deserialize(typ, v) + .map_err(#macro_internal::value_deser_error_replace_rust_name::)?; + + // Generate fields that will serve as temporary storage + // for the fields' values. Those are of type Option. + #(#deserialize_field_decls)* + + for item in cql_field_iter { + let ((cql_field_name, cql_field_typ), value_res) = item; + let value = value_res.map_err(|err| #macro_internal::mk_value_deser_err::( + typ, + #macro_internal::UdtDeserializationErrorKind::FieldDeserializationFailed { + field_name: ::std::clone::Clone::clone(cql_field_name), + err, + } + ))?; + // Pattern match on the field name and deserialize. + match cql_field_name.as_str() { + #(#rust_nonskipped_field_names => #deserialize_blocks,)* + unknown => { + // Assuming we type checked sucessfully, this must be an excess field. + // Let's skip it. + }, + } + } + + // Create the final struct. The finalizer expressions convert + // the temporary storage fields to the final field values. + // For example, if a field is missing but marked as + // `default_when_null` it will create a default value, otherwise + // it will report an error. + ::std::result::Result::Ok(Self { + #(#rust_field_idents: #field_finalizers,)* + }) + } + } + } +} diff --git a/scylla-macros/src/lib.rs b/scylla-macros/src/lib.rs index babb890481..5d67c85dc4 100644 --- a/scylla-macros/src/lib.rs +++ b/scylla-macros/src/lib.rs @@ -1,5 +1,5 @@ +use darling::ToTokens; use proc_macro::TokenStream; -use quote::ToTokens; mod from_row; mod from_user_type; @@ -68,3 +68,10 @@ pub fn value_list_derive(tokens_input: TokenStream) -> TokenStream { } mod deserialize; +#[proc_macro_derive(DeserializeValue, attributes(scylla))] +pub fn deserialize_value_derive(tokens_input: TokenStream) -> TokenStream { + match deserialize::value::deserialize_value_derive(tokens_input) { + Ok(tokens) => tokens.into_token_stream().into(), + Err(err) => err.into_compile_error().into(), + } +} diff --git a/scylla/src/macros.rs b/scylla/src/macros.rs index 4bb29559bd..a154c95763 100644 --- a/scylla/src/macros.rs +++ b/scylla/src/macros.rs @@ -51,7 +51,7 @@ pub use scylla_cql::macros::IntoUserType; /// /// A UDT defined like this: /// -/// ```notrust +/// ```text /// CREATE TYPE ks.my_udt (a int, b text, c blob); /// ``` /// @@ -154,7 +154,7 @@ pub use scylla_cql::macros::SerializeValue; /// A UDT defined like this: /// Given a table and a query: /// -/// ```notrust +/// ```text /// CREATE TABLE ks.my_t (a int PRIMARY KEY, b text, c blob); /// INSERT INTO ks.my_t (a, b, c) VALUES (?, ?, ?); /// ``` @@ -237,6 +237,77 @@ pub use scylla_cql::macros::SerializeValue; /// pub use scylla_cql::macros::SerializeRow; +/// Derive macro for the `DeserializeValue` trait that generates an implementation +/// which deserializes a User Defined Type with the same layout as the Rust +/// struct. +/// +/// At the moment, only structs with named fields are supported. +/// +/// This macro properly supports structs with lifetimes, meaning that you can +/// deserialize UDTs with fields that borrow memory from the serialized response. +/// +/// # Example +/// +/// A UDT defined like this: +/// +/// ```text +/// CREATE TYPE ks.my_udt (a i32, b text, c blob); +/// ``` +/// +/// ...can be deserialized using the following struct: +/// +/// ```rust +/// # use scylla_cql::macros::DeserializeValue; +/// #[derive(DeserializeValue)] +/// # #[scylla(crate = "scylla_cql")] +/// struct MyUdt<'a> { +/// a: i32, +/// b: Option, +/// c: &'a [u8], +/// } +/// ``` +/// +/// # Attributes +/// +/// The macro supports a number of attributes that customize the generated +/// implementation. Many of the attributes were inspired by procedural macros +/// from `serde` and try to follow the same naming conventions. +/// +/// ## Struct attributes +/// +/// `#[scylla(crate = "crate_name")]` +/// +/// By default, the code generated by the derive macro will refer to the items +/// defined by the driver (types, traits, etc.) via the `::scylla` path. +/// For example, it will refer to the [`DeserializeValue`](crate::deserialize::DeserializeValue) +/// trait using the following path: +/// +/// ```rust,ignore +/// use ::scylla::_macro_internal::DeserializeValue; +/// ``` +/// +/// Most users will simply add `scylla` to their dependencies, then use +/// the derive macro and the path above will work. However, there are some +/// niche cases where this path will _not_ work: +/// +/// - The `scylla` crate is imported under a different name, +/// - The `scylla` crate is _not imported at all_ - the macro actually +/// is defined in the `scylla-macros` crate and the generated code depends +/// on items defined in `scylla-cql`. +/// +/// It's not possible to automatically resolve those issues in the procedural +/// macro itself, so in those cases the user must provide an alternative path +/// to either the `scylla` or `scylla-cql` crate. +/// +/// ## Field attributes +/// +/// `#[scylla(skip)]` +/// +/// The field will be completely ignored during deserialization and will +/// be initialized with `Default::default()`. +/// +pub use scylla_macros::DeserializeValue; + /// #[derive(ValueList)] allows to pass struct as a list of values for a query /// /// --- From e3240290f76f558ecc9b559f9c171f6cf9e1f5a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Tue, 25 Jun 2024 14:54:41 +0200 Subject: [PATCH 05/29] DeserializeValue: unordered flavour tests DeserializeValue unordered flavour is tested in the following aspects: - the macro executes properly on a struct, - the generated type_check() and deserialize() implementations are correct both in valid and invalid cases (i.e. return error in invalid cases and expected value in valid cases). Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/value.rs | 143 ++++++++++++++++++++++ 1 file changed, 143 insertions(+) diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index 1d56300259..d6f98b4bbf 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -2413,6 +2413,149 @@ pub(super) mod tests { ); } + fn udt_def_with_fields( + fields: impl IntoIterator, ColumnType)>, + ) -> ColumnType { + ColumnType::UserDefinedType { + type_name: "udt".to_owned(), + keyspace: "ks".to_owned(), + field_types: fields.into_iter().map(|(s, t)| (s.into(), t)).collect(), + } + } + + #[must_use] + struct UdtSerializer { + buf: BytesMut, + } + + impl UdtSerializer { + fn new() -> Self { + Self { + buf: BytesMut::default(), + } + } + + fn field(mut self, field_bytes: &[u8]) -> Self { + append_bytes(&mut self.buf, field_bytes); + self + } + + fn finalize(&self) -> Bytes { + make_bytes(&self.buf) + } + } + + // Do not remove. It's not used in tests but we keep it here to check that + // we properly ignore warnings about unused variables, unnecessary `mut`s + // etc. that usually pop up when generating code for empty structs. + #[allow(unused)] + #[derive(scylla_macros::DeserializeValue)] + #[scylla(crate = crate)] + struct TestUdtWithNoFieldsUnordered {} + + #[test] + fn test_udt_loose_ordering() { + #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] + #[scylla(crate = "crate")] + struct Udt<'a> { + a: &'a str, + #[scylla(skip)] + x: String, + b: Option, + c: i64, + } + + // UDT fields in correct same order. + { + let udt = UdtSerializer::new() + .field("The quick brown fox".as_bytes()) + .field(&42_i32.to_be_bytes()) + .field(&2137_i64.to_be_bytes()) + .finalize(); + let typ = udt_def_with_fields([ + ("a", ColumnType::Text), + ("b", ColumnType::Int), + ("c", ColumnType::BigInt), + ]); + + let udt = deserialize::>(&typ, &udt).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: Some(42), + c: 2137, + } + ); + } + + // UDT fields switched - should still work. + { + let udt = UdtSerializer::new() + .field(&42_i32.to_be_bytes()) + .field("The quick brown fox".as_bytes()) + .field(&2137_i64.to_be_bytes()) + .finalize(); + let typ = udt_def_with_fields([ + ("b", ColumnType::Int), + ("a", ColumnType::Text), + ("c", ColumnType::BigInt), + ]); + + let udt = deserialize::>(&typ, &udt).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: Some(42), + c: 2137, + } + ); + } + + // An excess UDT field - should still work. + { + let udt = UdtSerializer::new() + .field(&12_i8.to_be_bytes()) + .field(&42_i32.to_be_bytes()) + .field("The quick brown fox".as_bytes()) + .field(&2137_i64.to_be_bytes()) + .finalize(); + let typ = udt_def_with_fields([ + ("d", ColumnType::TinyInt), + ("b", ColumnType::Int), + ("a", ColumnType::Text), + ("c", ColumnType::BigInt), + ]); + + Udt::type_check(&typ).unwrap(); + let udt = deserialize::>(&typ, &udt).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: Some(42), + c: 2137, + } + ); + } + + // Wrong column type + { + let typ = udt_def_with_fields([("a", ColumnType::Text)]); + Udt::type_check(&typ).unwrap_err(); + } + + // Missing required column + { + let typ = udt_def_with_fields([("b", ColumnType::Int)]); + Udt::type_check(&typ).unwrap_err(); + } + } + #[test] fn test_custom_type_parser() { #[derive(Default, Debug, PartialEq, Eq)] From 472badd323b8e7537da352917543d892ae932696 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Tue, 25 Jun 2024 15:07:34 +0200 Subject: [PATCH 06/29] DeserializeValue: unordered flavour errors tests DeserializeValue unordered flavour is tested in the following aspects: - the generated type_check() and deserialize() implementations produce meaningful, appropriate errors in invalid cases. --- scylla-cql/src/types/deserialize/value.rs | 117 ++++++++++++++++++++++ 1 file changed, 117 insertions(+) diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index d6f98b4bbf..a02f3df791 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -1886,6 +1886,7 @@ pub(super) mod tests { BuiltinTypeCheckError, BuiltinTypeCheckErrorKind, DeserializeValue, ListlikeIterator, MapDeserializationErrorKind, MapIterator, MapTypeCheckErrorKind, MaybeEmpty, SetOrListDeserializationErrorKind, SetOrListTypeCheckErrorKind, + UdtDeserializationErrorKind, UdtTypeCheckErrorKind, }; #[test] @@ -3161,4 +3162,120 @@ pub(super) mod tests { deserialize::>(&ser_typ, &bytes).unwrap_err(); } + + #[test] + fn test_udt_errors() { + // Loose ordering + { + #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] + #[scylla(crate = "crate")] + struct Udt<'a> { + a: &'a str, + #[scylla(skip)] + x: String, + b: Option, + c: bool, + } + + // Type check errors + { + // Not UDT + { + let typ = + ColumnType::Map(Box::new(ColumnType::Ascii), Box::new(ColumnType::Blob)); + let err = Udt::type_check(&typ).unwrap_err(); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::NotUdt) = + err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + } + + // UDT missing fields + { + let typ = udt_def_with_fields([("c", ColumnType::Boolean)]); + let err = Udt::type_check(&typ).unwrap_err(); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinTypeCheckErrorKind::UdtError( + UdtTypeCheckErrorKind::ValuesMissingForUdtFields { + field_names: ref missing_fields, + }, + ) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(missing_fields.as_slice(), &["a", "b"]); + } + + // missing UDT field + { + let typ = + udt_def_with_fields([("b", ColumnType::Int), ("a", ColumnType::Text)]); + let err = Udt::type_check(&typ).unwrap_err(); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinTypeCheckErrorKind::UdtError( + UdtTypeCheckErrorKind::ValuesMissingForUdtFields { ref field_names }, + ) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(field_names, &["c"]); + } + + // UDT fields incompatible types - field type check failed + { + let typ = + udt_def_with_fields([("a", ColumnType::Blob), ("b", ColumnType::Int)]); + let err = Udt::type_check(&typ).unwrap_err(); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinTypeCheckErrorKind::UdtError( + UdtTypeCheckErrorKind::FieldTypeCheckFailed { + ref field_name, + ref err, + }, + ) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(field_name.as_str(), "a"); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::<&str>()); + assert_eq!(err.cql_type, ColumnType::Blob); + assert_matches!( + err.kind, + BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::Ascii, ColumnType::Text] + } + ); + } + } + + // Deserialization errors + { + // Got null + { + let typ = udt_def_with_fields([ + ("c", ColumnType::Boolean), + ("a", ColumnType::Blob), + ("b", ColumnType::Int), + ]); + + let err = Udt::deserialize(&typ, None).unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + assert_matches!(err.kind, BuiltinDeserializationErrorKind::ExpectedNonNull); + } + } + } + } } From 894414bb0ea3e47f9691d0ceb4c7d18d320fd2d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Tue, 25 Jun 2024 16:00:18 +0200 Subject: [PATCH 07/29] DeserializeRow: unordered flavour support DeserializeRow unordered flavour is added, analogous to SerializeRow unordered flavour. Co-authored-by: Piotr Dulikowski --- scylla-cql/src/lib.rs | 9 + scylla-cql/src/types/deserialize/row.rs | 74 +++++ scylla-macros/src/deserialize/mod.rs | 1 + scylla-macros/src/deserialize/row.rs | 364 ++++++++++++++++++++++++ scylla-macros/src/lib.rs | 9 + scylla/src/macros.rs | 84 ++++++ 6 files changed, 541 insertions(+) create mode 100644 scylla-macros/src/deserialize/row.rs diff --git a/scylla-cql/src/lib.rs b/scylla-cql/src/lib.rs index 2cb5327c3f..61d9f345c5 100644 --- a/scylla-cql/src/lib.rs +++ b/scylla-cql/src/lib.rs @@ -2,6 +2,7 @@ pub mod errors; pub mod frame; #[macro_use] pub mod macros { + pub use scylla_macros::DeserializeRow; pub use scylla_macros::DeserializeValue; pub use scylla_macros::FromRow; pub use scylla_macros::FromUserType; @@ -34,6 +35,14 @@ pub mod _macro_internal { }; pub use crate::macros::*; + pub use crate::types::deserialize::row::{ + deser_error_replace_rust_name as row_deser_error_replace_rust_name, + mk_deser_err as mk_row_deser_err, mk_typck_err as mk_row_typck_err, + BuiltinDeserializationError as BuiltinRowDeserializationError, + BuiltinDeserializationErrorKind as BuiltinRowDeserializationErrorKind, + BuiltinTypeCheckErrorKind as DeserBuiltinRowTypeCheckErrorKind, ColumnIterator, + DeserializeRow, + }; pub use crate::types::deserialize::value::{ deser_error_replace_rust_name as value_deser_error_replace_rust_name, mk_deser_err as mk_value_deser_err, mk_typck_err as mk_value_typck_err, diff --git a/scylla-cql/src/types/deserialize/row.rs b/scylla-cql/src/types/deserialize/row.rs index 9e556537ea..a45c8f99ac 100644 --- a/scylla-cql/src/types/deserialize/row.rs +++ b/scylla-cql/src/types/deserialize/row.rs @@ -294,6 +294,38 @@ pub enum BuiltinTypeCheckErrorKind { cql_cols: usize, }, + /// The CQL row contains a column for which a corresponding field is not found + /// in the Rust type. + ColumnWithUnknownName { + /// Index of the excess column. + column_index: usize, + + /// Name of the column that is present in CQL row but not in the Rust type. + column_name: String, + }, + + /// Several values required by the Rust type are not provided by the DB. + ValuesMissingForColumns { + /// Names of the columns in the Rust type for which the DB doesn't + /// provide value. + column_names: Vec<&'static str>, + }, + + /// A different column name was expected at given position. + ColumnNameMismatch { + /// Index of the field determining the expected name. + field_index: usize, + + /// Index of the column having mismatched name. + column_index: usize, + + /// Name of the column, as expected by the Rust type. + rust_column_name: &'static str, + + /// Name of the column for which the DB requested a value. + db_column_name: String, + }, + /// Column type check failed between Rust type and DB type at given position (=in given column). ColumnTypeCheckFailed { /// Index of the column. @@ -305,6 +337,15 @@ pub enum BuiltinTypeCheckErrorKind { /// Inner type check error due to the type mismatch. err: TypeCheckError, }, + + /// Duplicated column in DB metadata. + DuplicatedColumn { + /// Column index of the second occurence of the column with the same name. + column_index: usize, + + /// The name of the duplicated column. + column_name: &'static str, + }, } impl Display for BuiltinTypeCheckErrorKind { @@ -316,6 +357,33 @@ impl Display for BuiltinTypeCheckErrorKind { } => { write!(f, "wrong column count: the statement operates on {cql_cols} columns, but the given rust types contains {rust_cols}") } + BuiltinTypeCheckErrorKind::ColumnWithUnknownName { column_name, column_index } => { + write!( + f, + "the CQL row contains a column {} at column index {}, but the corresponding field is not found in the Rust type", + column_name, + column_index, + ) + } + BuiltinTypeCheckErrorKind::ValuesMissingForColumns { column_names } => { + write!( + f, + "values for columns {:?} are missing from the DB data but are required by the Rust type", + column_names + ) + }, + BuiltinTypeCheckErrorKind::ColumnNameMismatch { + field_index, + column_index,rust_column_name, + db_column_name + } => write!( + f, + "expected column with name {} at column index {}, but the Rust field name at corresponding field index {} is {}", + db_column_name, + column_index, + field_index, + rust_column_name, + ), BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { column_index, column_name, @@ -324,6 +392,12 @@ impl Display for BuiltinTypeCheckErrorKind { f, "mismatched types in column {column_name} at index {column_index}: {err}" ), + BuiltinTypeCheckErrorKind::DuplicatedColumn { column_name, column_index } => write!( + f, + "column {} occurs more than once in DB metadata; second occurence is at column index {}", + column_name, + column_index, + ), } } } diff --git a/scylla-macros/src/deserialize/mod.rs b/scylla-macros/src/deserialize/mod.rs index 7533046cc4..9b038ecfb6 100644 --- a/scylla-macros/src/deserialize/mod.rs +++ b/scylla-macros/src/deserialize/mod.rs @@ -2,6 +2,7 @@ use darling::{FromAttributes, FromField}; use proc_macro2::Span; use syn::parse_quote; +pub(crate) mod row; pub(crate) mod value; /// Common attributes that all deserialize impls should understand. diff --git a/scylla-macros/src/deserialize/row.rs b/scylla-macros/src/deserialize/row.rs new file mode 100644 index 0000000000..39503fc697 --- /dev/null +++ b/scylla-macros/src/deserialize/row.rs @@ -0,0 +1,364 @@ +use darling::{FromAttributes, FromField}; +use proc_macro2::Span; +use syn::ext::IdentExt; +use syn::parse_quote; + +use super::{DeserializeCommonFieldAttrs, DeserializeCommonStructAttrs}; + +#[derive(FromAttributes)] +#[darling(attributes(scylla))] +struct StructAttrs { + #[darling(rename = "crate")] + crate_path: Option, +} + +impl DeserializeCommonStructAttrs for StructAttrs { + fn crate_path(&self) -> Option<&syn::Path> { + self.crate_path.as_ref() + } +} + +#[derive(FromField)] +#[darling(attributes(scylla))] +struct Field { + // If true, then the field is not parsed at all, but it is initialized + // with Default::default() instead. All other attributes are ignored. + #[darling(default)] + skip: bool, + + ident: Option, + ty: syn::Type, +} + +impl DeserializeCommonFieldAttrs for Field { + fn needs_default(&self) -> bool { + self.skip + } + + fn deserialize_target(&self) -> &syn::Type { + &self.ty + } +} + +// derive(DeserializeRow) for the new DeserializeRow trait +pub(crate) fn deserialize_row_derive( + tokens_input: proc_macro::TokenStream, +) -> Result { + let input = syn::parse(tokens_input)?; + + let implemented_trait: syn::Path = parse_quote! { DeserializeRow }; + let implemented_trait_name = implemented_trait + .segments + .last() + .unwrap() + .ident + .unraw() + .to_string(); + let constraining_trait = parse_quote! { DeserializeValue }; + let s = StructDesc::new(&input, &implemented_trait_name, constraining_trait)?; + + let items = [ + s.generate_type_check_method().into(), + s.generate_deserialize_method().into(), + ]; + + Ok(s.generate_impl(implemented_trait, items)) +} + +impl Field { + // Returns whether this field is mandatory for deserialization. + fn is_required(&self) -> bool { + !self.skip + } + + // A Rust literal representing the name of this field + fn cql_name_literal(&self) -> syn::LitStr { + let field_name = self.ident.as_ref().unwrap().unraw().to_string(); + syn::LitStr::new(&field_name, Span::call_site()) + } +} + +type StructDesc = super::StructDescForDeserialize; + +impl StructDesc { + fn generate_type_check_method(&self) -> syn::ImplItemFn { + TypeCheckUnorderedGenerator(self).generate() + } + + fn generate_deserialize_method(&self) -> syn::ImplItemFn { + DeserializeUnorderedGenerator(self).generate() + } +} + +struct TypeCheckUnorderedGenerator<'sd>(&'sd StructDesc); + +impl<'sd> TypeCheckUnorderedGenerator<'sd> { + // An identifier for a bool variable that represents whether given + // field was already visited during type check + fn visited_flag_variable(field: &Field) -> syn::Ident { + quote::format_ident!("visited_{}", field.ident.as_ref().unwrap().unraw()) + } + + // Generates a declaration of a "visited" flag for the purpose of type check. + // We generate it even if the flag is not required in order to protect + // from fields appearing more than once + fn generate_visited_flag_decl(field: &Field) -> Option { + (!field.skip).then(|| { + let visited_flag = Self::visited_flag_variable(field); + parse_quote! { + let mut #visited_flag = false; + } + }) + } + + // Generates code that, given variable `typ`, type-checks given field + fn generate_type_check(&self, field: &Field) -> Option { + (!field.skip).then(|| { + let macro_internal = self.0.struct_attrs().macro_internal_path(); + let constraint_lifetime = self.0.constraint_lifetime(); + let visited_flag = Self::visited_flag_variable(field); + let typ = field.deserialize_target(); + let cql_name_literal = field.cql_name_literal(); + let decrement_if_required: Option:: = field.is_required().then(|| parse_quote! { + remaining_required_fields -= 1; + }); + + parse_quote! { + { + if !#visited_flag { + <#typ as #macro_internal::DeserializeValue<#constraint_lifetime>>::type_check(&spec.typ) + .map_err(|err| { + #macro_internal::mk_row_typck_err::( + column_types_iter(), + #macro_internal::DeserBuiltinRowTypeCheckErrorKind::ColumnTypeCheckFailed { + column_index, + column_name: <_ as ::std::borrow::ToOwned>::to_owned(#cql_name_literal), + err, + } + ) + })?; + #visited_flag = true; + #decrement_if_required + } else { + return ::std::result::Result::Err( + #macro_internal::mk_row_typck_err::( + column_types_iter(), + #macro_internal::DeserBuiltinRowTypeCheckErrorKind::DuplicatedColumn { + column_index, + column_name: #cql_name_literal, + } + ) + ) + } + } + } + }) + } + + // Generates code that appends the flag name if it is missing. + // The generated code is used to construct a nice error message. + fn generate_append_name(field: &Field) -> Option { + field.is_required().then(|| { + let visited_flag = Self::visited_flag_variable(field); + let cql_name_literal = field.cql_name_literal(); + parse_quote! { + { + if !#visited_flag { + missing_fields.push(#cql_name_literal); + } + } + } + }) + } + + fn generate(&self) -> syn::ImplItemFn { + let macro_internal = self.0.struct_attrs().macro_internal_path(); + + let fields = self.0.fields(); + let visited_field_declarations = fields.iter().flat_map(Self::generate_visited_flag_decl); + let type_check_blocks = fields.iter().flat_map(|f| self.generate_type_check(f)); + let append_name_blocks = fields.iter().flat_map(Self::generate_append_name); + let nonskipped_field_names = fields + .iter() + .filter(|f| !f.skip) + .map(|f| f.cql_name_literal()); + let field_count_lit = fields.iter().filter(|f| f.is_required()).count(); + + parse_quote! { + fn type_check( + specs: &[#macro_internal::ColumnSpec], + ) -> ::std::result::Result<(), #macro_internal::TypeCheckError> { + // Counts down how many required fields are remaining + let mut remaining_required_fields: ::std::primitive::usize = #field_count_lit; + + // For each required field, generate a "visited" boolean flag + #(#visited_field_declarations)* + + let column_types_iter = || specs.iter().map(|spec| ::std::clone::Clone::clone(&spec.typ)); + + for (column_index, spec) in specs.iter().enumerate() { + // Pattern match on the name and verify that the type is correct. + match spec.name.as_str() { + #(#nonskipped_field_names => #type_check_blocks,)* + _unknown => { + return ::std::result::Result::Err( + #macro_internal::mk_row_typck_err::( + column_types_iter(), + #macro_internal::DeserBuiltinRowTypeCheckErrorKind::ColumnWithUnknownName { + column_index, + column_name: <_ as ::std::clone::Clone>::clone(&spec.name) + } + ) + ) + } + } + } + + if remaining_required_fields > 0 { + // If there are some missing required fields, generate an error + // which contains missing field names + let mut missing_fields = ::std::vec::Vec::<&'static str>::with_capacity(remaining_required_fields); + #(#append_name_blocks)* + return ::std::result::Result::Err( + #macro_internal::mk_row_typck_err::( + column_types_iter(), + #macro_internal::DeserBuiltinRowTypeCheckErrorKind::ValuesMissingForColumns { + column_names: missing_fields + } + ) + ) + } + + ::std::result::Result::Ok(()) + } + } + } +} + +struct DeserializeUnorderedGenerator<'sd>(&'sd StructDesc); + +impl<'sd> DeserializeUnorderedGenerator<'sd> { + // An identifier for a variable that is meant to store the parsed variable + // before being ultimately moved to the struct on deserialize + fn deserialize_field_variable(field: &Field) -> syn::Ident { + quote::format_ident!("f_{}", field.ident.as_ref().unwrap().unraw()) + } + + // Generates an expression which produces a value ready to be put into a field + // of the target structure + fn generate_finalize_field(&self, field: &Field) -> syn::Expr { + if field.skip { + // Skipped fields are initialized with Default::default() + return parse_quote! { + ::std::default::Default::default() + }; + } + + let deserialize_field = Self::deserialize_field_variable(field); + let cql_name_literal = field.cql_name_literal(); + parse_quote! { + #deserialize_field.unwrap_or_else(|| panic!( + "column {} missing in DB row - type check should have prevented this!", + #cql_name_literal + )) + } + } + + // Generated code that performs deserialization when the raw field + // is being processed + fn generate_deserialization(&self, column_index: usize, field: &Field) -> syn::Expr { + assert!(!field.skip); + let macro_internal = self.0.struct_attrs().macro_internal_path(); + let constraint_lifetime = self.0.constraint_lifetime(); + let deserialize_field = Self::deserialize_field_variable(field); + let deserializer = field.deserialize_target(); + + parse_quote! { + { + assert!( + #deserialize_field.is_none(), + "duplicated column {} - type check should have prevented this!", + stringify!(#deserialize_field) + ); + + #deserialize_field = ::std::option::Option::Some( + <#deserializer as #macro_internal::DeserializeValue<#constraint_lifetime>>::deserialize(&col.spec.typ, col.slice) + .map_err(|err| { + #macro_internal::mk_row_deser_err::( + #macro_internal::BuiltinRowDeserializationErrorKind::ColumnDeserializationFailed { + column_index: #column_index, + column_name: <_ as std::clone::Clone>::clone(&col.spec.name), + err, + } + ) + })? + ); + } + } + } + + // Generate a declaration of a variable that temporarily keeps + // the deserialized value + fn generate_deserialize_field_decl(field: &Field) -> Option { + (!field.skip).then(|| { + let deserialize_field = Self::deserialize_field_variable(field); + parse_quote! { + let mut #deserialize_field = ::std::option::Option::None; + } + }) + } + + fn generate(&self) -> syn::ImplItemFn { + let macro_internal = self.0.struct_attrs().macro_internal_path(); + let constraint_lifetime = self.0.constraint_lifetime(); + let fields = self.0.fields(); + + let deserialize_field_decls = fields + .iter() + .flat_map(Self::generate_deserialize_field_decl); + let deserialize_blocks = fields + .iter() + .filter(|f| !f.skip) + .enumerate() + .map(|(col_idx, f)| self.generate_deserialization(col_idx, f)); + let field_idents = fields.iter().map(|f| f.ident.as_ref().unwrap()); + let nonskipped_field_names = fields + .iter() + .filter(|&f| (!f.skip)) + .map(|f| f.cql_name_literal()); + + let field_finalizers = fields.iter().map(|f| self.generate_finalize_field(f)); + + // TODO: Allow collecting unrecognized fields into some special field + + parse_quote! { + fn deserialize( + #[allow(unused_mut)] + mut row: #macro_internal::ColumnIterator<#constraint_lifetime>, + ) -> ::std::result::Result { + + // Generate fields that will serve as temporary storage + // for the fields' values. Those are of type Option. + #(#deserialize_field_decls)* + + for col in row { + let col = col.map_err(#macro_internal::row_deser_error_replace_rust_name::)?; + // Pattern match on the field name and deserialize. + match col.spec.name.as_str() { + #(#nonskipped_field_names => #deserialize_blocks,)* + unknown => unreachable!("Typecheck should have prevented this scenario! Unknown column name: {}", unknown), + } + } + + // Create the final struct. The finalizer expressions convert + // the temporary storage fields to the final field values. + // For example, if a field is missing but marked as + // `default_when_null` it will create a default value, otherwise + // it will report an error. + Ok(Self { + #(#field_idents: #field_finalizers,)* + }) + } + } + } +} diff --git a/scylla-macros/src/lib.rs b/scylla-macros/src/lib.rs index 5d67c85dc4..05e24362f4 100644 --- a/scylla-macros/src/lib.rs +++ b/scylla-macros/src/lib.rs @@ -68,6 +68,15 @@ pub fn value_list_derive(tokens_input: TokenStream) -> TokenStream { } mod deserialize; + +#[proc_macro_derive(DeserializeRow, attributes(scylla))] +pub fn deserialize_row_derive(tokens_input: TokenStream) -> TokenStream { + match deserialize::row::deserialize_row_derive(tokens_input) { + Ok(tokens) => tokens.into_token_stream().into(), + Err(err) => err.into_compile_error().into(), + } +} + #[proc_macro_derive(DeserializeValue, attributes(scylla))] pub fn deserialize_value_derive(tokens_input: TokenStream) -> TokenStream { match deserialize::value::deserialize_value_derive(tokens_input) { diff --git a/scylla/src/macros.rs b/scylla/src/macros.rs index a154c95763..404be45308 100644 --- a/scylla/src/macros.rs +++ b/scylla/src/macros.rs @@ -308,6 +308,90 @@ pub use scylla_cql::macros::SerializeRow; /// pub use scylla_macros::DeserializeValue; +/// Derive macro for the `DeserializeRow` trait that generates an implementation +/// which deserializes a row with a similar layout to the Rust struct. +/// +/// At the moment, only structs with named fields are supported. +/// +/// This macro properly supports structs with lifetimes, meaning that you can +/// deserialize columns that borrow memory from the serialized response. +/// +/// # Example +/// +/// Having a table defined like this: +/// +/// ```text +/// CREATE TABLE ks.my_table (a PRIMARY KEY, b text, c blob); +/// ``` +/// +/// results of a query "SELECT * FROM ks.my_table" +/// or "SELECT a, b, c FROM ks.my_table" +/// can be deserialized using the following struct: +/// +/// ```rust +/// # use scylla_cql::macros::DeserializeRow; +/// #[derive(DeserializeRow)] +/// # #[scylla(crate = "scylla_cql")] +/// struct MyRow<'a> { +/// a: i32, +/// b: Option, +/// c: &'a [u8], +/// } +/// ``` +/// +/// In general, the struct must match the queried names and types, +/// not the table itself. For example, the query +/// "SELECT a AS b FROM ks.my_table" executed against +/// the aforementioned table can be deserialized to the struct: +/// ```rust +/// # use scylla_cql::macros::DeserializeRow; +/// #[derive(DeserializeRow)] +/// # #[scylla(crate = "scylla_cql")] +/// struct MyRow { +/// b: i32, +/// } +/// ``` +/// +/// # Attributes +/// +/// The macro supports a number of attributes that customize the generated +/// implementation. Many of the attributes were inspired by procedural macros +/// from `serde` and try to follow the same naming conventions. +/// +/// ## Struct attributes +/// +/// `#[scylla(crate = "crate_name")]` +/// +/// By default, the code generated by the derive macro will refer to the items +/// defined by the driver (types, traits, etc.) via the `::scylla` path. +/// For example, it will refer to the [`DeserializeValue`](crate::deserialize::DeserializeValue) +/// trait using the following path: +/// +/// ```rust,ignore +/// use ::scylla::_macro_internal::DeserializeValue; +/// ``` +/// +/// Most users will simply add `scylla` to their dependencies, then use +/// the derive macro and the path above will work. However, there are some +/// niche cases where this path will _not_ work: +/// +/// - The `scylla` crate is imported under a different name, +/// - The `scylla` crate is _not imported at all_ - the macro actually +/// is defined in the `scylla-macros` crate and the generated code depends +/// on items defined in `scylla-cql`. +/// +/// It's not possible to automatically resolve those issues in the procedural +/// macro itself, so in those cases the user must provide an alternative path +/// to either the `scylla` or `scylla-cql` crate. +/// +/// ## Field attributes +/// +/// `#[scylla(skip)]` +/// +/// The field will be completely ignored during deserialization and will +/// be initialized with `Default::default()`. +pub use scylla_macros::DeserializeRow; + /// #[derive(ValueList)] allows to pass struct as a list of values for a query /// /// --- From 01f1afdc8f42e64b332064c02ae3050b95943f0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Tue, 25 Jun 2024 16:06:02 +0200 Subject: [PATCH 08/29] DeserializeRow: unordered flavour tests DeserializeRow unordered flavour is tested in the following aspects: - the macro executes properly on a struct, - the generated type_check() and deserialize() implementations are correct both in valid and invalid cases (i.e. return error in invalid cases and expected value in valid cases). Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/row.rs | 55 +++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/scylla-cql/src/types/deserialize/row.rs b/scylla-cql/src/types/deserialize/row.rs index a45c8f99ac..b9ca57abd3 100644 --- a/scylla-cql/src/types/deserialize/row.rs +++ b/scylla-cql/src/types/deserialize/row.rs @@ -491,6 +491,7 @@ impl Display for BuiltinDeserializationErrorKind { mod tests { use assert_matches::assert_matches; use bytes::Bytes; + use scylla_macros::DeserializeRow; use crate::frame::response::result::{ColumnSpec, ColumnType}; use crate::types::deserialize::row::BuiltinDeserializationErrorKind; @@ -571,6 +572,60 @@ mod tests { assert!(iter.next().is_none()); } + // Do not remove. It's not used in tests but we keep it here to check that + // we properly ignore warnings about unused variables, unnecessary `mut`s + // etc. that usually pop up when generating code for empty structs. + #[allow(unused)] + #[derive(DeserializeRow)] + #[scylla(crate = crate)] + struct TestUdtWithNoFieldsUnordered {} + + #[test] + fn test_struct_deserialization_loose_ordering() { + #[derive(DeserializeRow, PartialEq, Eq, Debug)] + #[scylla(crate = "crate")] + struct MyRow<'a> { + a: &'a str, + b: Option, + #[scylla(skip)] + c: String, + } + + // Original order of columns + let specs = &[spec("a", ColumnType::Text), spec("b", ColumnType::Int)]; + let byts = serialize_cells([val_str("abc"), val_int(123)]); + let row = deserialize::>(specs, &byts).unwrap(); + assert_eq!( + row, + MyRow { + a: "abc", + b: Some(123), + c: String::new(), + } + ); + + // Different order of columns - should still work + let specs = &[spec("b", ColumnType::Int), spec("a", ColumnType::Text)]; + let byts = serialize_cells([val_int(123), val_str("abc")]); + let row = deserialize::>(specs, &byts).unwrap(); + assert_eq!( + row, + MyRow { + a: "abc", + b: Some(123), + c: String::new(), + } + ); + + // Missing column + let specs = &[spec("a", ColumnType::Text)]; + MyRow::type_check(specs).unwrap_err(); + + // Wrong column type + let specs = &[spec("a", ColumnType::Int), spec("b", ColumnType::Int)]; + MyRow::type_check(specs).unwrap_err(); + } + fn val_int(i: i32) -> Option> { Some(i.to_be_bytes().to_vec()) } From f740d495e2d058278b2d7e57f643d7c4731fef9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Wed, 26 Jun 2024 16:25:26 +0200 Subject: [PATCH 09/29] DeserializeRow: unordered flavour errors tests DeserializeRow unordered flavour is tested in the following aspects: - the generated type_check() and deserialize() implementations produce meaningful, appropriate errors in invalid cases. --- scylla-cql/src/types/deserialize/row.rs | 199 +++++++++++++++++++++++- 1 file changed, 196 insertions(+), 3 deletions(-) diff --git a/scylla-cql/src/types/deserialize/row.rs b/scylla-cql/src/types/deserialize/row.rs index b9ca57abd3..f3146acad8 100644 --- a/scylla-cql/src/types/deserialize/row.rs +++ b/scylla-cql/src/types/deserialize/row.rs @@ -495,7 +495,7 @@ mod tests { use crate::frame::response::result::{ColumnSpec, ColumnType}; use crate::types::deserialize::row::BuiltinDeserializationErrorKind; - use crate::types::deserialize::{DeserializationError, FrameSlice}; + use crate::types::deserialize::{value, DeserializationError, FrameSlice}; use super::super::tests::{serialize_cells, spec}; use super::{BuiltinDeserializationError, ColumnIterator, CqlValue, DeserializeRow, Row}; @@ -649,13 +649,20 @@ mod tests { } #[track_caller] - fn get_typck_err(err: &DeserializationError) -> &BuiltinTypeCheckError { - match err.0.downcast_ref() { + pub(crate) fn get_typck_err_inner<'a>( + err: &'a (dyn std::error::Error + 'static), + ) -> &'a BuiltinTypeCheckError { + match err.downcast_ref() { Some(err) => err, None => panic!("not a BuiltinTypeCheckError: {:?}", err), } } + #[track_caller] + fn get_typck_err(err: &DeserializationError) -> &BuiltinTypeCheckError { + get_typck_err_inner(err.0.as_ref()) + } + #[track_caller] fn get_deser_err(err: &DeserializationError) -> &BuiltinDeserializationError { match err.0.downcast_ref() { @@ -811,4 +818,190 @@ mod tests { assert_eq!(column_name, col_name); } } + + fn specs_to_types(specs: &[ColumnSpec]) -> Vec { + specs.iter().map(|spec| spec.typ.clone()).collect() + } + + #[test] + fn test_struct_deserialization_errors() { + // Loose ordering + { + #[derive(scylla_macros::DeserializeRow, PartialEq, Eq, Debug)] + #[scylla(crate = "crate")] + struct MyRow<'a> { + a: &'a str, + #[scylla(skip)] + x: String, + b: Option, + c: bool, + } + + // Type check errors + { + // Missing column + { + let specs = [spec("a", ColumnType::Ascii), spec("b", ColumnType::Int)]; + let err = MyRow::type_check(&specs).unwrap_err(); + let err = get_typck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_types, specs_to_types(&specs)); + let BuiltinTypeCheckErrorKind::ValuesMissingForColumns { + column_names: ref missing_fields, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(missing_fields.as_slice(), &["c"]); + } + + // Duplicated column + { + let specs = [ + spec("a", ColumnType::Ascii), + spec("b", ColumnType::Int), + spec("a", ColumnType::Ascii), + ]; + + let err = MyRow::type_check(&specs).unwrap_err(); + let err = get_typck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_types, specs_to_types(&specs)); + let BuiltinTypeCheckErrorKind::DuplicatedColumn { + column_index, + column_name, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(column_index, 2); + assert_eq!(column_name, "a"); + } + + // Unknown column + { + let specs = [ + spec("d", ColumnType::Counter), + spec("a", ColumnType::Ascii), + spec("b", ColumnType::Int), + ]; + + let err = MyRow::type_check(&specs).unwrap_err(); + let err = get_typck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_types, specs_to_types(&specs)); + let BuiltinTypeCheckErrorKind::ColumnWithUnknownName { + column_index, + ref column_name, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(column_index, 0); + assert_eq!(column_name.as_str(), "d"); + } + + // Column incompatible types - column type check failed + { + let specs = [spec("b", ColumnType::Int), spec("a", ColumnType::Blob)]; + let err = MyRow::type_check(&specs).unwrap_err(); + let err = get_typck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_types, specs_to_types(&specs)); + let BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { + column_index, + ref column_name, + ref err, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(column_index, 1); + assert_eq!(column_name.as_str(), "a"); + let err = value::tests::get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::<&str>()); + assert_eq!(err.cql_type, ColumnType::Blob); + assert_matches!( + err.kind, + value::BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::Ascii, ColumnType::Text] + } + ); + } + } + + // Deserialization errors + { + // Got null + { + let specs = [ + spec("c", ColumnType::Boolean), + spec("a", ColumnType::Blob), + spec("b", ColumnType::Int), + ]; + + let err = MyRow::deserialize(ColumnIterator::new( + &specs, + FrameSlice::new(&serialize_cells([Some([true as u8])])), + )) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + let BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { + column_index, + ref column_name, + .. + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(column_index, 1); + assert_eq!(column_name, "a"); + } + + // Column deserialization failed + { + let specs = [ + spec("b", ColumnType::Int), + spec("a", ColumnType::Ascii), + spec("c", ColumnType::Boolean), + ]; + + let row_bytes = serialize_cells( + [ + &0_i32.to_be_bytes(), + "alamakota".as_bytes(), + &42_i16.to_be_bytes(), + ] + .map(Some), + ); + + let err = deserialize::(&specs, &row_bytes).unwrap_err(); + + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + let BuiltinDeserializationErrorKind::ColumnDeserializationFailed { + column_index, + ref column_name, + ref err, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(column_index, 2); + assert_eq!(column_name.as_str(), "c"); + let err = value::tests::get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::Boolean); + assert_matches!( + err.kind, + value::BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 1, + got: 2, + } + ); + } + } + } + } } From bab6db8cd19333fa8fa61722aa0b0008123658fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Tue, 25 Jun 2024 16:08:28 +0200 Subject: [PATCH 10/29] Deserialize{Value,Row}: support `rename` attribute The attribute allows for matching a struct field to a CQL UDT field with a different name. A row test is modified to showcase and check the new feature. Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/row.rs | 3 ++- scylla-macros/src/deserialize/row.rs | 11 ++++++++++- scylla-macros/src/deserialize/value.rs | 11 ++++++++++- scylla/src/macros.rs | 11 +++++++++++ 4 files changed, 33 insertions(+), 3 deletions(-) diff --git a/scylla-cql/src/types/deserialize/row.rs b/scylla-cql/src/types/deserialize/row.rs index f3146acad8..bb2b3ba2e3 100644 --- a/scylla-cql/src/types/deserialize/row.rs +++ b/scylla-cql/src/types/deserialize/row.rs @@ -834,7 +834,8 @@ mod tests { #[scylla(skip)] x: String, b: Option, - c: bool, + #[scylla(rename = "c")] + d: bool, } // Type check errors diff --git a/scylla-macros/src/deserialize/row.rs b/scylla-macros/src/deserialize/row.rs index 39503fc697..c6726c3fc8 100644 --- a/scylla-macros/src/deserialize/row.rs +++ b/scylla-macros/src/deserialize/row.rs @@ -26,6 +26,12 @@ struct Field { #[darling(default)] skip: bool, + // If set, then deserialization will look for the column with given name + // and deserialize it to this Rust field, instead of just using the Rust + // field name. + #[darling(default)] + rename: Option, + ident: Option, ty: syn::Type, } @@ -73,7 +79,10 @@ impl Field { // A Rust literal representing the name of this field fn cql_name_literal(&self) -> syn::LitStr { - let field_name = self.ident.as_ref().unwrap().unraw().to_string(); + let field_name = match self.rename.as_ref() { + Some(rename) => rename.to_owned(), + None => self.ident.as_ref().unwrap().unraw().to_string(), + }; syn::LitStr::new(&field_name, Span::call_site()) } } diff --git a/scylla-macros/src/deserialize/value.rs b/scylla-macros/src/deserialize/value.rs index 03ed4a28fb..2f800dd1a0 100644 --- a/scylla-macros/src/deserialize/value.rs +++ b/scylla-macros/src/deserialize/value.rs @@ -25,6 +25,12 @@ struct Field { // with Default::default() instead. All other attributes are ignored. #[darling(default)] skip: bool, + + // If set, then deserializes from the UDT field with this particular name + // instead of the Rust field name. + #[darling(default)] + rename: Option, + ident: Option, ty: syn::Type, } @@ -71,7 +77,10 @@ impl Field { // A Rust literal representing the name of this field fn cql_name_literal(&self) -> syn::LitStr { - let field_name = self.ident.as_ref().unwrap().unraw().to_string(); + let field_name = match self.rename.as_ref() { + Some(rename) => rename.to_owned(), + None => self.ident.as_ref().unwrap().unraw().to_string(), + }; syn::LitStr::new(&field_name, Span::call_site()) } } diff --git a/scylla/src/macros.rs b/scylla/src/macros.rs index 404be45308..307d975efb 100644 --- a/scylla/src/macros.rs +++ b/scylla/src/macros.rs @@ -306,6 +306,11 @@ pub use scylla_cql::macros::SerializeRow; /// The field will be completely ignored during deserialization and will /// be initialized with `Default::default()`. /// +/// `#[scylla(rename = "field_name")` +/// +/// By default, the generated implementation will try to match the Rust field +/// to a UDT field with the same name. This attribute instead allows to match +/// to a UDT field with provided name. pub use scylla_macros::DeserializeValue; /// Derive macro for the `DeserializeRow` trait that generates an implementation @@ -390,6 +395,12 @@ pub use scylla_macros::DeserializeValue; /// /// The field will be completely ignored during deserialization and will /// be initialized with `Default::default()`. +/// +/// `#[scylla(rename = "field_name")` +/// +/// By default, the generated implementation will try to match the Rust field +/// to a column with the same name. This attribute allows to match to a column +/// with provided name. pub use scylla_macros::DeserializeRow; /// #[derive(ValueList)] allows to pass struct as a list of values for a query From a9ae38ab0349efc553dd5451f66e99d549677f64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Tue, 25 Jun 2024 16:26:53 +0200 Subject: [PATCH 11/29] DeserializeValue: enforce_order flavour support DeserializeValue enforce_order flavour is added, analogous to SerializeValue enforce_order flavour. The flavour requires that fields are ordered the same way in both Rust struct and in CQL UDT definition. Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/value.rs | 29 +++ scylla-macros/src/deserialize/value.rs | 250 +++++++++++++++++++++- scylla/src/macros.rs | 11 + 3 files changed, 287 insertions(+), 3 deletions(-) diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index a02f3df791..a98dfbe429 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -1556,6 +1556,18 @@ pub enum UdtTypeCheckErrorKind { field_names: Vec<&'static str>, }, + /// A different field name was expected at given position. + FieldNameMismatch { + /// Index of the field in the Rust struct. + position: usize, + + /// The name of the Rust field. + rust_field_name: String, + + /// The name of the CQL UDT field. + db_field_name: String, + }, + /// UDT contains an excess field, which does not correspond to any Rust struct's field. ExcessFieldInUdt { /// The name of the CQL UDT field. @@ -1568,6 +1580,13 @@ pub enum UdtTypeCheckErrorKind { field_name: String, }, + /// Fewer fields present in the UDT than required by the Rust type. + TooFewFields { + // TODO: decide whether we are OK with restricting to `&'static str` here. + required_fields: Vec<&'static str>, + present_fields: Vec, + }, + /// Type check failed between UDT and Rust type field. FieldTypeCheckFailed { /// The name of the field whose type check failed. @@ -1588,6 +1607,10 @@ impl Display for UdtTypeCheckErrorKind { UdtTypeCheckErrorKind::ValuesMissingForUdtFields { field_names } => { write!(f, "the fields {field_names:?} are missing from the DB data but are required by the Rust type") }, + UdtTypeCheckErrorKind::FieldNameMismatch { rust_field_name, db_field_name, position } => write!( + f, + "expected field with name {db_field_name} at position {position}, but the Rust field name is {rust_field_name}" + ), UdtTypeCheckErrorKind::ExcessFieldInUdt { db_field_name } => write!( f, "UDT contains an excess field {}, which does not correspond to any Rust struct's field.", @@ -1598,6 +1621,12 @@ impl Display for UdtTypeCheckErrorKind { "field {} occurs more than once in CQL UDT type", field_name ), + UdtTypeCheckErrorKind::TooFewFields { required_fields, present_fields } => write!( + f, + "fewer fields present in the UDT than required by the Rust type: UDT has {:?}, Rust type requires {:?}", + present_fields, + required_fields, + ), UdtTypeCheckErrorKind::FieldTypeCheckFailed { field_name, err } => write!( f, "the UDT field {} types between the CQL type and the Rust type failed to type check against each other: {}", diff --git a/scylla-macros/src/deserialize/value.rs b/scylla-macros/src/deserialize/value.rs index 2f800dd1a0..5ec7182d94 100644 --- a/scylla-macros/src/deserialize/value.rs +++ b/scylla-macros/src/deserialize/value.rs @@ -10,6 +10,13 @@ use super::{DeserializeCommonFieldAttrs, DeserializeCommonStructAttrs}; struct StructAttrs { #[darling(rename = "crate")] crate_path: Option, + + // If true, then the type checking code will require the order of the fields + // to be the same in both the Rust struct and the UDT. This allows the + // deserialization to be slightly faster because looking struct fields up + // by name can be avoided, though it is less convenient. + #[darling(default)] + enforce_order: bool, } impl DeserializeCommonStructAttrs for StructAttrs { @@ -105,11 +112,248 @@ impl StructDesc { } fn generate_type_check_method(&self) -> syn::ImplItemFn { - TypeCheckUnorderedGenerator(self).generate() + if self.attrs.enforce_order { + TypeCheckAssumeOrderGenerator(self).generate() + } else { + TypeCheckUnorderedGenerator(self).generate() + } } fn generate_deserialize_method(&self) -> syn::ImplItemFn { - DeserializeUnorderedGenerator(self).generate() + if self.attrs.enforce_order { + DeserializeAssumeOrderGenerator(self).generate() + } else { + DeserializeUnorderedGenerator(self).generate() + } + } +} + +struct TypeCheckAssumeOrderGenerator<'sd>(&'sd StructDesc); + +impl<'sd> TypeCheckAssumeOrderGenerator<'sd> { + // Generates name and type validation for given Rust struct's field. + fn generate_field_validation(&self, rust_field_idx: usize, field: &Field) -> syn::Expr { + let macro_internal = self.0.struct_attrs().macro_internal_path(); + let constraint_lifetime = self.0.constraint_lifetime(); + let rust_field_name = field.cql_name_literal(); + let rust_field_typ = field.deserialize_target(); + + // Action performed in case of field name mismatch. + let name_mismatch: syn::Expr = parse_quote! { + { + // Error - required value for field not present among the CQL fields. + return ::std::result::Result::Err( + #macro_internal::mk_value_typck_err::( + typ, + #macro_internal::DeserUdtTypeCheckErrorKind::FieldNameMismatch { + position: #rust_field_idx, + rust_field_name: <_ as ::std::borrow::ToOwned>::to_owned(#rust_field_name), + db_field_name: <_ as ::std::borrow::ToOwned>::to_owned(cql_field_name), + } + ) + ); + } + }; + + let name_verification: syn::Expr = parse_quote! { + if #rust_field_name != cql_field_name { + // The read UDT field is not the one expected by the Rust struct. + #name_mismatch + } + }; + + parse_quote! { + 'field: { + let next_cql_field = match cql_field_iter.next() { + ::std::option::Option::Some(cql_field) => cql_field, + ::std::option::Option::None => return Err(too_few_fields()), + }; + let (cql_field_name, cql_field_typ) = next_cql_field; + + 'verifications: { + #name_verification + + // Verify the type + <#rust_field_typ as #macro_internal::DeserializeValue<#constraint_lifetime>>::type_check(cql_field_typ) + .map_err(|err| #macro_internal::mk_value_typck_err::( + typ, + #macro_internal::DeserUdtTypeCheckErrorKind::FieldTypeCheckFailed { + field_name: <_ as ::std::borrow::ToOwned>::to_owned(#rust_field_name), + err, + } + ))?; + } + } + } + } + + // Generates the type_check method for when ensure_order == true. + fn generate(&self) -> syn::ImplItemFn { + // The generated method will: + // - Check that every required field appears on the list in the same order as struct fields + // - Every type on the list is correct + + let macro_internal = self.0.struct_attrs().macro_internal_path(); + + let extract_fields_expr = self.0.generate_extract_fields_from_type(parse_quote!(typ)); + + let required_fields_iter = || self.0.fields().iter().filter(|f| f.is_required()); + + let required_field_count = required_fields_iter().count(); + let required_field_count_lit = + syn::LitInt::new(&required_field_count.to_string(), Span::call_site()); + + let required_fields_names = required_fields_iter().map(|field| field.ident.as_ref()); + + let nonskipped_fields_iter = || { + self.0 + .fields() + .iter() + // It is important that we enumerate **before** filtering, because otherwise we would not + // count the skipped fields, which might be confusing. + .enumerate() + .filter(|(_idx, f)| !f.skip) + }; + + let field_validations = + nonskipped_fields_iter().map(|(idx, field)| self.generate_field_validation(idx, field)); + + parse_quote! { + fn type_check( + typ: &#macro_internal::ColumnType, + ) -> ::std::result::Result<(), #macro_internal::TypeCheckError> { + // Extract information about the field types from the UDT + // type definition. + let fields = #extract_fields_expr; + + let too_few_fields = || #macro_internal::mk_value_typck_err::( + typ, + #macro_internal::DeserUdtTypeCheckErrorKind::TooFewFields { + required_fields: vec![ + #(stringify!(#required_fields_names),)* + ], + present_fields: fields.iter().map(|(name, _typ)| name.clone()).collect(), + } + ); + + // Verify that the field count is correct + if fields.len() < #required_field_count_lit { + return ::std::result::Result::Err(too_few_fields()); + } + + let mut cql_field_iter = fields.iter(); + #( + #field_validations + )* + + // All is good! + ::std::result::Result::Ok(()) + } + } + } +} + +struct DeserializeAssumeOrderGenerator<'sd>(&'sd StructDesc); + +impl<'sd> DeserializeAssumeOrderGenerator<'sd> { + fn generate_finalize_field(&self, field: &Field) -> syn::Expr { + if field.skip { + // Skipped fields are initialized with Default::default() + return parse_quote! { + ::std::default::Default::default() + }; + } + + let macro_internal = self.0.struct_attrs().macro_internal_path(); + let cql_name_literal = field.cql_name_literal(); + let deserializer = field.deserialize_target(); + let constraint_lifetime = self.0.constraint_lifetime(); + + let deserialize: syn::Expr = parse_quote! { + <#deserializer as #macro_internal::DeserializeValue<#constraint_lifetime>>::deserialize(cql_field_typ, value) + .map_err(|err| #macro_internal::mk_value_deser_err::( + typ, + #macro_internal::UdtDeserializationErrorKind::FieldDeserializationFailed { + field_name: #cql_name_literal.to_owned(), + err, + } + ))? + }; + + // Action performed in case of field name mismatch. + let name_mismatch: syn::Expr = parse_quote! { + panic!( + "type check should have prevented this scenario - field name mismatch! Rust field name {}, CQL field name {}", + #cql_name_literal, + cql_field_name + ) + }; + + let name_check_and_deserialize: syn::Expr = parse_quote! { + if #cql_name_literal == cql_field_name { + #deserialize + } else { + #name_mismatch + } + }; + + parse_quote! { + { + let next_cql_field = cql_field_iter.next() + .map(|(specs, value_res)| value_res.map(|value| (specs, value))) + // Type check has ensured that there are enough CQL UDT fields. + .expect("Too few CQL UDT fields - type check should have prevented this scenario!") + // Propagate deserialization errors. + .map_err(|err| #macro_internal::mk_value_deser_err::( + typ, + #macro_internal::UdtDeserializationErrorKind::FieldDeserializationFailed { + field_name: #cql_name_literal.to_owned(), + err, + } + ))?; + + + let ((cql_field_name, cql_field_typ), value) = next_cql_field; + + // The value can be either + // - None - missing from the serialized representation + // - Some(None) - present in the serialized representation but null + // For now, we treat both cases as "null". + let value = value.flatten(); + + #name_check_and_deserialize + } + } + } + + fn generate(&self) -> syn::ImplItemFn { + // We can assume that type_check was called. + + let macro_internal = self.0.struct_attrs().macro_internal_path(); + let constraint_lifetime = self.0.constraint_lifetime(); + let fields = self.0.fields(); + + let field_idents = fields.iter().map(|f| f.ident.as_ref().unwrap()); + let field_finalizers = fields.iter().map(|f| self.generate_finalize_field(f)); + + #[allow(unused_mut)] + let mut iterator_type: syn::Type = + parse_quote!(#macro_internal::UdtIterator<#constraint_lifetime>); + + parse_quote! { + fn deserialize( + typ: &#constraint_lifetime #macro_internal::ColumnType, + v: ::std::option::Option<#macro_internal::FrameSlice<#constraint_lifetime>>, + ) -> ::std::result::Result { + // Create an iterator over the fields of the UDT. + let mut cql_field_iter = <#iterator_type as #macro_internal::DeserializeValue<#constraint_lifetime>>::deserialize(typ, v) + .map_err(#macro_internal::value_deser_error_replace_rust_name::)?; + + ::std::result::Result::Ok(Self { + #(#field_idents: #field_finalizers,)* + }) + } + } } } @@ -190,7 +434,7 @@ impl<'sd> TypeCheckUnorderedGenerator<'sd> { }) } - // Generates the type_check method. + // Generates the type_check method for when ensure_order == false. fn generate(&self) -> syn::ImplItemFn { // The generated method will: // - Check that every required field appears on the list exactly once, in any order diff --git a/scylla/src/macros.rs b/scylla/src/macros.rs index 307d975efb..ca7bc6e160 100644 --- a/scylla/src/macros.rs +++ b/scylla/src/macros.rs @@ -299,6 +299,17 @@ pub use scylla_cql::macros::SerializeRow; /// macro itself, so in those cases the user must provide an alternative path /// to either the `scylla` or `scylla-cql` crate. /// +/// `#[scylla(enforce_order)]` +/// +/// By default, the generated deserialization code will be insensitive +/// to the UDT field order - when processing a field, it will look it up +/// in the Rust struct with the corresponding field and set it. However, +/// if the UDT field order is known to be the same both in the UDT +/// and the Rust struct, then the `enforce_order` annotation can be used +/// so that a more efficient implementation that does not perform lookups +/// is be generated. The UDT field names will still be checked during the +/// type check phase. +/// /// ## Field attributes /// /// `#[scylla(skip)]` From 8e10af6d92e5a30db1a4f3be6ef2ef25b6ce1fd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Tue, 25 Jun 2024 17:22:20 +0200 Subject: [PATCH 12/29] DeserializeValue: enforce_order flavour tests DeserializeValue enforce_order flavour is tested in the following aspects: - the macro executes properly on a struct, - the generated type_check() and deserialize() implementations are correct both in valid and invalid cases (i.e. return error in invalid cases and expected value in valid cases). - the generated type_check() and deserialize() implementations produce meaningful, appropriate errors in invalid cases. Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/value.rs | 289 ++++++++++++++++++++++ 1 file changed, 289 insertions(+) diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index a98dfbe429..ba9b2ec3dd 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -2483,6 +2483,11 @@ pub(super) mod tests { #[scylla(crate = crate)] struct TestUdtWithNoFieldsUnordered {} + #[allow(unused)] + #[derive(scylla_macros::DeserializeValue)] + #[scylla(crate = crate, enforce_order)] + struct TestUdtWithNoFieldsOrdered {} + #[test] fn test_udt_loose_ordering() { #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] @@ -2586,6 +2591,97 @@ pub(super) mod tests { } } + #[test] + fn test_udt_strict_ordering() { + #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] + #[scylla(crate = "crate", enforce_order)] + struct Udt<'a> { + a: &'a str, + #[scylla(skip)] + x: String, + b: Option, + } + + // UDT fields in correct same order + { + let udt = UdtSerializer::new() + .field("The quick brown fox".as_bytes()) + .field(&42i32.to_be_bytes()) + .finalize(); + let typ = udt_def_with_fields([("a", ColumnType::Text), ("b", ColumnType::Int)]); + + let udt = deserialize::>(&typ, &udt).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: Some(42), + } + ); + } + + // The last UDT field is missing in serialized form - it should treat + // as if there were null at the end + { + let udt = UdtSerializer::new() + .field("The quick brown fox".as_bytes()) + .finalize(); + let typ = udt_def_with_fields([("a", ColumnType::Text), ("b", ColumnType::Int)]); + + let udt = deserialize::>(&typ, &udt).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: None, + } + ); + } + + // An excess field at the end of UDT + { + let udt = UdtSerializer::new() + .field("The quick brown fox".as_bytes()) + .field(&42_i32.to_be_bytes()) + .field(&(true as i8).to_be_bytes()) + .finalize(); + let typ = udt_def_with_fields([ + ("a", ColumnType::Text), + ("b", ColumnType::Int), + ("d", ColumnType::Boolean), + ]); + let udt = deserialize::>(&typ, &udt).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: Some(42), + } + ); + } + + // UDT fields switched - will not work + { + let typ = udt_def_with_fields([("b", ColumnType::Int), ("a", ColumnType::Text)]); + Udt::type_check(&typ).unwrap_err(); + } + + // Wrong column type + { + let typ = udt_def_with_fields([("a", ColumnType::Int), ("b", ColumnType::Int)]); + Udt::type_check(&typ).unwrap_err(); + } + + // Missing required column + { + let typ = udt_def_with_fields([("b", ColumnType::Int)]); + Udt::type_check(&typ).unwrap_err(); + } + } + #[test] fn test_custom_type_parser() { #[derive(Default, Debug, PartialEq, Eq)] @@ -3306,5 +3402,198 @@ pub(super) mod tests { } } } + + // Strict ordering + { + #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] + #[scylla(crate = "crate", enforce_order)] + struct Udt<'a> { + a: &'a str, + #[scylla(skip)] + x: String, + b: Option, + c: bool, + } + + // Type check errors + { + // Not UDT + { + let typ = + ColumnType::Map(Box::new(ColumnType::Ascii), Box::new(ColumnType::Blob)); + let err = Udt::type_check(&typ).unwrap_err(); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::NotUdt) = + err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + } + + // UDT too few fields + { + let typ = udt_def_with_fields([("a", ColumnType::Text)]); + let err = Udt::type_check(&typ).unwrap_err(); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::TooFewFields { + ref required_fields, + ref present_fields, + }) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(required_fields.as_slice(), &["a", "b", "c"]); + assert_eq!(present_fields.as_slice(), &["a".to_string()]); + } + + // UDT fields switched - field name mismatch + { + let typ = udt_def_with_fields([ + ("b", ColumnType::Int), + ("a", ColumnType::Text), + ("c", ColumnType::Boolean), + ]); + let err = Udt::type_check(&typ).unwrap_err(); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinTypeCheckErrorKind::UdtError( + UdtTypeCheckErrorKind::FieldNameMismatch { + position, + ref rust_field_name, + ref db_field_name, + }, + ) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(position, 0); + assert_eq!(rust_field_name.as_str(), "a".to_owned()); + assert_eq!(db_field_name.as_str(), "b".to_owned()); + } + + // UDT fields incompatible types - field type check failed + { + let typ = udt_def_with_fields([ + ("a", ColumnType::Blob), + ("b", ColumnType::Int), + ("c", ColumnType::Boolean), + ]); + let err = Udt::type_check(&typ).unwrap_err(); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinTypeCheckErrorKind::UdtError( + UdtTypeCheckErrorKind::FieldTypeCheckFailed { + ref field_name, + ref err, + }, + ) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(field_name.as_str(), "a"); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::<&str>()); + assert_eq!(err.cql_type, ColumnType::Blob); + assert_matches!( + err.kind, + BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::Ascii, ColumnType::Text] + } + ); + } + } + + // Deserialization errors + { + // Got null + { + let typ = udt_def_with_fields([ + ("a", ColumnType::Text), + ("b", ColumnType::Int), + ("c", ColumnType::Boolean), + ]); + + let err = Udt::deserialize(&typ, None).unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + assert_matches!(err.kind, BuiltinDeserializationErrorKind::ExpectedNonNull); + } + + // Bad field format + { + let typ = udt_def_with_fields([ + ("a", ColumnType::Text), + ("b", ColumnType::Int), + ("c", ColumnType::Boolean), + ]); + + let udt_bytes = UdtSerializer::new() + .field(b"alamakota") + .field(&42_i64.to_be_bytes()) + .field(&[true as u8]) + .finalize(); + + let udt_bytes_too_short = udt_bytes.slice(..udt_bytes.len() - 1); + assert!(udt_bytes.len() > udt_bytes_too_short.len()); + + let err = deserialize::(&typ, &udt_bytes_too_short).unwrap_err(); + + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinDeserializationErrorKind::RawCqlBytesReadError(_) = err.kind else { + panic!("unexpected error kind: {:?}", err.kind) + }; + } + + // UDT field deserialization failed + { + let typ = udt_def_with_fields([ + ("a", ColumnType::Text), + ("b", ColumnType::Int), + ("c", ColumnType::Boolean), + ]); + + let udt_bytes = UdtSerializer::new() + .field(b"alamakota") + .field(&42_i64.to_be_bytes()) + .field(&[true as u8]) + .finalize(); + + let err = deserialize::(&typ, &udt_bytes).unwrap_err(); + + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinDeserializationErrorKind::UdtError( + UdtDeserializationErrorKind::FieldDeserializationFailed { + ref field_name, + ref err, + }, + ) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(field_name.as_str(), "b"); + let err = get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::Int); + assert_matches!( + err.kind, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 4, + got: 8, + } + ); + } + } + } } } From cc0457bf0ddd98856f660318916278dbbffbba9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Tue, 25 Jun 2024 17:34:13 +0200 Subject: [PATCH 13/29] DeserializeRow: enforce_order flavour support DeserializeRow enforce_order flavour is added, analogous to SerializeRow enforce_order flavour. The flavour requires that fields in Rust struct are ordered the same way as columns in CQL row definition. Co-authored-by: Piotr Dulikowski --- scylla-macros/src/deserialize/row.rs | 181 ++++++++++++++++++++++++++- scylla/src/macros.rs | 10 ++ 2 files changed, 189 insertions(+), 2 deletions(-) diff --git a/scylla-macros/src/deserialize/row.rs b/scylla-macros/src/deserialize/row.rs index c6726c3fc8..02975de6da 100644 --- a/scylla-macros/src/deserialize/row.rs +++ b/scylla-macros/src/deserialize/row.rs @@ -10,6 +10,13 @@ use super::{DeserializeCommonFieldAttrs, DeserializeCommonStructAttrs}; struct StructAttrs { #[darling(rename = "crate")] crate_path: Option, + + // If true, then the type checking code will require the order of the fields + // to be the same in both the Rust struct and the columns. This allows the + // deserialization to be slightly faster because looking struct fields up + // by name can be avoided, though it is less convenient. + #[darling(default)] + enforce_order: bool, } impl DeserializeCommonStructAttrs for StructAttrs { @@ -91,11 +98,181 @@ type StructDesc = super::StructDescForDeserialize; impl StructDesc { fn generate_type_check_method(&self) -> syn::ImplItemFn { - TypeCheckUnorderedGenerator(self).generate() + if self.attrs.enforce_order { + TypeCheckAssumeOrderGenerator(self).generate() + } else { + TypeCheckUnorderedGenerator(self).generate() + } } fn generate_deserialize_method(&self) -> syn::ImplItemFn { - DeserializeUnorderedGenerator(self).generate() + if self.attrs.enforce_order { + DeserializeAssumeOrderGenerator(self).generate() + } else { + DeserializeUnorderedGenerator(self).generate() + } + } +} + +struct TypeCheckAssumeOrderGenerator<'sd>(&'sd StructDesc); + +impl<'sd> TypeCheckAssumeOrderGenerator<'sd> { + fn generate_name_verification( + &self, + field_index: usize, // These two indices can be different because of `skip` attribute + column_index: usize, // applied to some field. + field: &Field, + column_spec: &syn::Ident, + ) -> syn::Expr { + let macro_internal = self.0.struct_attrs().macro_internal_path(); + let rust_field_name = field.cql_name_literal(); + + parse_quote! { + if #column_spec.name != #rust_field_name { + return ::std::result::Result::Err( + #macro_internal::mk_row_typck_err::( + column_types_iter(), + #macro_internal::DeserBuiltinRowTypeCheckErrorKind::ColumnNameMismatch { + field_index: #field_index, + column_index: #column_index, + rust_column_name: #rust_field_name, + db_column_name: ::std::clone::Clone::clone(&#column_spec.name), + } + ) + ); + } + } + } + + fn generate(&self) -> syn::ImplItemFn { + // The generated method will check that the order and the types + // of the columns correspond fields' names/types. + + let macro_internal = self.0.struct_attrs().macro_internal_path(); + let constraint_lifetime = self.0.constraint_lifetime(); + + let required_fields_iter = || { + self.0 + .fields() + .iter() + .enumerate() + .filter(|(_, f)| f.is_required()) + }; + let required_fields_count = required_fields_iter().count(); + let required_fields_idents: Vec<_> = (0..required_fields_count) + .map(|i| quote::format_ident!("f_{}", i)) + .collect(); + let name_verifications = required_fields_iter() + .zip(required_fields_idents.iter().enumerate()) + .map(|((field_idx, field), (col_idx, fidents))| { + self.generate_name_verification(field_idx, col_idx, field, fidents) + }); + + let required_fields_deserializers = + required_fields_iter().map(|(_, f)| f.deserialize_target()); + let numbers = 0usize..; + + parse_quote! { + fn type_check( + specs: &[#macro_internal::ColumnSpec], + ) -> ::std::result::Result<(), #macro_internal::TypeCheckError> { + let column_types_iter = || specs.iter().map(|spec| ::std::clone::Clone::clone(&spec.typ)); + + match specs { + [#(#required_fields_idents),*] => { + #( + // Verify the name + #name_verifications + + // Verify the type + <#required_fields_deserializers as #macro_internal::DeserializeValue<#constraint_lifetime>>::type_check(&#required_fields_idents.typ) + .map_err(|err| #macro_internal::mk_row_typck_err::( + column_types_iter(), + #macro_internal::DeserBuiltinRowTypeCheckErrorKind::ColumnTypeCheckFailed { + column_index: #numbers, + column_name: ::std::clone::Clone::clone(&#required_fields_idents.name), + err, + } + ))?; + )* + ::std::result::Result::Ok(()) + }, + _ => ::std::result::Result::Err( + #macro_internal::mk_row_typck_err::( + column_types_iter(), + #macro_internal::DeserBuiltinRowTypeCheckErrorKind::WrongColumnCount { + rust_cols: #required_fields_count, + cql_cols: specs.len(), + } + ), + ), + } + } + } + } +} + +struct DeserializeAssumeOrderGenerator<'sd>(&'sd StructDesc); + +impl<'sd> DeserializeAssumeOrderGenerator<'sd> { + fn generate_finalize_field(&self, field_index: usize, field: &Field) -> syn::Expr { + if field.skip { + // Skipped fields are initialized with Default::default() + return parse_quote!(::std::default::Default::default()); + } + + let macro_internal = self.0.struct_attrs().macro_internal_path(); + let cql_name_literal = field.cql_name_literal(); + let deserializer = field.deserialize_target(); + let constraint_lifetime = self.0.constraint_lifetime(); + + parse_quote!( + { + let col = row.next() + .expect("Typecheck should have prevented this scenario! Too few columns in the serialized data.") + .map_err(#macro_internal::row_deser_error_replace_rust_name::)?; + + if col.spec.name.as_str() != #cql_name_literal { + panic!( + "Typecheck should have prevented this scenario - field-column name mismatch! Rust field name {}, CQL column name {}", + #cql_name_literal, + col.spec.name.as_str() + ); + } + + <#deserializer as #macro_internal::DeserializeValue<#constraint_lifetime>>::deserialize(&col.spec.typ, col.slice) + .map_err(|err| #macro_internal::mk_row_deser_err::( + #macro_internal::BuiltinRowDeserializationErrorKind::ColumnDeserializationFailed { + column_index: #field_index, + column_name: <_ as std::clone::Clone>::clone(&col.spec.name), + err, + } + ))? + } + ) + } + + fn generate(&self) -> syn::ImplItemFn { + let macro_internal = self.0.struct_attrs().macro_internal_path(); + let constraint_lifetime = self.0.constraint_lifetime(); + + let fields = self.0.fields(); + let field_idents = fields.iter().map(|f| f.ident.as_ref().unwrap()); + let field_finalizers = fields + .iter() + .enumerate() + .map(|(field_idx, f)| self.generate_finalize_field(field_idx, f)); + + parse_quote! { + fn deserialize( + #[allow(unused_mut)] + mut row: #macro_internal::ColumnIterator<#constraint_lifetime>, + ) -> ::std::result::Result { + ::std::result::Result::Ok(Self { + #(#field_idents: #field_finalizers,)* + }) + } + } } } diff --git a/scylla/src/macros.rs b/scylla/src/macros.rs index ca7bc6e160..4bd5cd36b5 100644 --- a/scylla/src/macros.rs +++ b/scylla/src/macros.rs @@ -400,6 +400,16 @@ pub use scylla_macros::DeserializeValue; /// macro itself, so in those cases the user must provide an alternative path /// to either the `scylla` or `scylla-cql` crate. /// +/// `#[scylla(enforce_order)]` +/// +/// By default, the generated deserialization code will be insensitive +/// to the column order - when processing a column, the corresponding Rust field +/// will be looked up and the column will be deserialized based on its type. +/// However, if the column order and the Rust field order is known to be the +/// same, then the `enforce_order` annotation can be used so that a more +/// efficient implementation that does not perform lookups is be generated. +/// The generated code will still check that the column and field names match. +/// /// ## Field attributes /// /// `#[scylla(skip)]` From b1fa8eb1a7dcd8aea7d3592427db4cf4d701a779 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Tue, 25 Jun 2024 17:37:25 +0200 Subject: [PATCH 14/29] DeserializeRow: enforce_order flavour tests DeserializeRow enforce_order flavour is tested in the following aspects: - the macro executes properly on a struct, - the generated type_check() and deserialize() implementations are correct both in valid and invalid cases (i.e. return error in invalid cases and expected value in valid cases). Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/row.rs | 42 +++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/scylla-cql/src/types/deserialize/row.rs b/scylla-cql/src/types/deserialize/row.rs index bb2b3ba2e3..5d0bfbbc88 100644 --- a/scylla-cql/src/types/deserialize/row.rs +++ b/scylla-cql/src/types/deserialize/row.rs @@ -580,6 +580,11 @@ mod tests { #[scylla(crate = crate)] struct TestUdtWithNoFieldsUnordered {} + #[allow(unused)] + #[derive(DeserializeRow)] + #[scylla(crate = crate, enforce_order)] + struct TestUdtWithNoFieldsOrdered {} + #[test] fn test_struct_deserialization_loose_ordering() { #[derive(DeserializeRow, PartialEq, Eq, Debug)] @@ -626,6 +631,43 @@ mod tests { MyRow::type_check(specs).unwrap_err(); } + #[test] + fn test_struct_deserialization_strict_ordering() { + #[derive(DeserializeRow, PartialEq, Eq, Debug)] + #[scylla(crate = "crate", enforce_order)] + struct MyRow<'a> { + a: &'a str, + b: Option, + #[scylla(skip)] + c: String, + } + + // Correct order of columns + let specs = &[spec("a", ColumnType::Text), spec("b", ColumnType::Int)]; + let byts = serialize_cells([val_str("abc"), val_int(123)]); + let row = deserialize::>(specs, &byts).unwrap(); + assert_eq!( + row, + MyRow { + a: "abc", + b: Some(123), + c: String::new(), + } + ); + + // Wrong order of columns + let specs = &[spec("b", ColumnType::Int), spec("a", ColumnType::Text)]; + MyRow::type_check(specs).unwrap_err(); + + // Missing column + let specs = &[spec("a", ColumnType::Text)]; + MyRow::type_check(specs).unwrap_err(); + + // Wrong column type + let specs = &[spec("a", ColumnType::Int), spec("b", ColumnType::Int)]; + MyRow::type_check(specs).unwrap_err(); + } + fn val_int(i: i32) -> Option> { Some(i.to_be_bytes().to_vec()) } From 11d1a7111151acd50493a920f1bc82d3b05b46e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Wed, 26 Jun 2024 16:25:26 +0200 Subject: [PATCH 15/29] DeserializeRow: enforce_order flavour errors tests DeserializeRow enforce_order flavour is tested in the following aspects: - the generated type_check() and deserialize() implementations produce meaningful, appropriate errors in invalid cases. --- scylla-cql/src/types/deserialize/row.rs | 239 ++++++++++++++++++++++++ 1 file changed, 239 insertions(+) diff --git a/scylla-cql/src/types/deserialize/row.rs b/scylla-cql/src/types/deserialize/row.rs index 5d0bfbbc88..b10f78a154 100644 --- a/scylla-cql/src/types/deserialize/row.rs +++ b/scylla-cql/src/types/deserialize/row.rs @@ -1046,5 +1046,244 @@ mod tests { } } } + + // Strict ordering + { + #[derive(scylla_macros::DeserializeRow, PartialEq, Eq, Debug)] + #[scylla(crate = "crate", enforce_order)] + struct MyRow<'a> { + a: &'a str, + #[scylla(skip)] + x: String, + b: Option, + c: bool, + } + + // Type check errors + { + // Too few columns + { + let specs = [spec("a", ColumnType::Text)]; + let err = MyRow::type_check(&specs).unwrap_err(); + let err = get_typck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_types, specs_to_types(&specs)); + let BuiltinTypeCheckErrorKind::WrongColumnCount { + rust_cols, + cql_cols, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(rust_cols, 3); + assert_eq!(cql_cols, 1); + } + + // Excess columns + { + let specs = [ + spec("a", ColumnType::Text), + spec("b", ColumnType::Int), + spec("c", ColumnType::Boolean), + spec("d", ColumnType::Counter), + ]; + let err = MyRow::type_check(&specs).unwrap_err(); + let err = get_typck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_types, specs_to_types(&specs)); + let BuiltinTypeCheckErrorKind::WrongColumnCount { + rust_cols, + cql_cols, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(rust_cols, 3); + assert_eq!(cql_cols, 4); + } + + // Renamed column name mismatch + { + let specs = [ + spec("a", ColumnType::Text), + spec("b", ColumnType::Int), + spec("d", ColumnType::Boolean), + ]; + let err = MyRow::type_check(&specs).unwrap_err(); + let err = get_typck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + let BuiltinTypeCheckErrorKind::ColumnNameMismatch { + field_index, + column_index, + rust_column_name, + ref db_column_name, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(field_index, 3); + assert_eq!(rust_column_name, "c"); + assert_eq!(column_index, 2); + assert_eq!(db_column_name.as_str(), "d"); + } + + // Columns switched - column name mismatch + { + let specs = [ + spec("b", ColumnType::Int), + spec("a", ColumnType::Text), + spec("c", ColumnType::Boolean), + ]; + let err = MyRow::type_check(&specs).unwrap_err(); + let err = get_typck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_types, specs_to_types(&specs)); + let BuiltinTypeCheckErrorKind::ColumnNameMismatch { + field_index, + column_index, + rust_column_name, + ref db_column_name, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(field_index, 0); + assert_eq!(column_index, 0); + assert_eq!(rust_column_name, "a"); + assert_eq!(db_column_name.as_str(), "b"); + } + + // Column incompatible types - column type check failed + { + let specs = [ + spec("a", ColumnType::Blob), + spec("b", ColumnType::Int), + spec("c", ColumnType::Boolean), + ]; + let err = MyRow::type_check(&specs).unwrap_err(); + let err = get_typck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_types, specs_to_types(&specs)); + let BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { + column_index, + ref column_name, + ref err, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(column_index, 0); + assert_eq!(column_name.as_str(), "a"); + let err = value::tests::get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::<&str>()); + assert_eq!(err.cql_type, ColumnType::Blob); + assert_matches!( + err.kind, + value::BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::Ascii, ColumnType::Text] + } + ); + } + } + + // Deserialization errors + { + // Too few columns + { + let specs = [ + spec("a", ColumnType::Text), + spec("b", ColumnType::Int), + spec("c", ColumnType::Boolean), + ]; + + let err = MyRow::deserialize(ColumnIterator::new( + &specs, + FrameSlice::new(&serialize_cells([Some([true as u8])])), + )) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + let BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { + column_index, + ref column_name, + .. + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(column_index, 1); + assert_eq!(column_name, "b"); + } + + // Bad field format + { + let typ = [ + spec("a", ColumnType::Text), + spec("b", ColumnType::Int), + spec("c", ColumnType::Boolean), + ]; + + let row_bytes = serialize_cells( + [(&b"alamakota"[..]), &42_i32.to_be_bytes(), &[true as u8]].map(Some), + ); + + let row_bytes_too_short = row_bytes.slice(..row_bytes.len() - 1); + assert!(row_bytes.len() > row_bytes_too_short.len()); + + let err = deserialize::(&typ, &row_bytes_too_short).unwrap_err(); + + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + let BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { + column_index, + ref column_name, + .. + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(column_index, 2); + assert_eq!(column_name, "c"); + } + + // Column deserialization failed + { + let specs = [ + spec("a", ColumnType::Text), + spec("b", ColumnType::Int), + spec("c", ColumnType::Boolean), + ]; + + let row_bytes = serialize_cells( + [&b"alamakota"[..], &42_i64.to_be_bytes(), &[true as u8]].map(Some), + ); + + let err = deserialize::(&specs, &row_bytes).unwrap_err(); + + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + let BuiltinDeserializationErrorKind::ColumnDeserializationFailed { + column_index: field_index, + ref column_name, + ref err, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(column_name.as_str(), "b"); + assert_eq!(field_index, 2); + let err = value::tests::get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::Int); + assert_matches!( + err.kind, + value::BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 4, + got: 8, + } + ); + } + } + } } } From e09abc5f682562b8b2aa16bbd58f138364398257 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Tue, 25 Jun 2024 17:42:48 +0200 Subject: [PATCH 16/29] DeserializeValue: `skip_name_checks` flag support `skip_name_checks` is allowed only in `enforce_order` flavour. If enabled, it turns off name match verification (e.g., for performance purposes) between Rust struct and CQL UDT definition. Co-authored-by: Piotr Dulikowski --- scylla-macros/src/deserialize/value.rs | 39 +++++++++++++++++++------- scylla/src/macros.rs | 10 +++++++ 2 files changed, 39 insertions(+), 10 deletions(-) diff --git a/scylla-macros/src/deserialize/value.rs b/scylla-macros/src/deserialize/value.rs index 5ec7182d94..b81c33cea7 100644 --- a/scylla-macros/src/deserialize/value.rs +++ b/scylla-macros/src/deserialize/value.rs @@ -17,6 +17,13 @@ struct StructAttrs { // by name can be avoided, though it is less convenient. #[darling(default)] enforce_order: bool, + + // If true, then the type checking code won't verify the UDT field names. + // UDT fields will be matched to struct fields based solely on the order. + // + // This annotation only works if `enforce_order` is specified. + #[darling(default)] + skip_name_checks: bool, } impl DeserializeCommonStructAttrs for StructAttrs { @@ -137,6 +144,7 @@ impl<'sd> TypeCheckAssumeOrderGenerator<'sd> { let constraint_lifetime = self.0.constraint_lifetime(); let rust_field_name = field.cql_name_literal(); let rust_field_typ = field.deserialize_target(); + let skip_name_checks = self.0.attrs.skip_name_checks; // Action performed in case of field name mismatch. let name_mismatch: syn::Expr = parse_quote! { @@ -155,12 +163,15 @@ impl<'sd> TypeCheckAssumeOrderGenerator<'sd> { } }; - let name_verification: syn::Expr = parse_quote! { - if #rust_field_name != cql_field_name { - // The read UDT field is not the one expected by the Rust struct. - #name_mismatch + // Optional name verification. + let name_verification: Option = (!skip_name_checks).then(|| { + parse_quote! { + if #rust_field_name != cql_field_name { + // The read UDT field is not the one expected by the Rust struct. + #name_mismatch + } } - }; + }); parse_quote! { 'field: { @@ -171,6 +182,7 @@ impl<'sd> TypeCheckAssumeOrderGenerator<'sd> { let (cql_field_name, cql_field_typ) = next_cql_field; 'verifications: { + // Verify the name (unless `skip_name_checks` is specified) #name_verification // Verify the type @@ -268,6 +280,7 @@ impl<'sd> DeserializeAssumeOrderGenerator<'sd> { let cql_name_literal = field.cql_name_literal(); let deserializer = field.deserialize_target(); let constraint_lifetime = self.0.constraint_lifetime(); + let skip_name_checks = self.0.attrs.skip_name_checks; let deserialize: syn::Expr = parse_quote! { <#deserializer as #macro_internal::DeserializeValue<#constraint_lifetime>>::deserialize(cql_field_typ, value) @@ -289,11 +302,17 @@ impl<'sd> DeserializeAssumeOrderGenerator<'sd> { ) }; - let name_check_and_deserialize: syn::Expr = parse_quote! { - if #cql_name_literal == cql_field_name { + let maybe_name_check_and_deserialize: syn::Expr = if skip_name_checks { + parse_quote! { #deserialize - } else { - #name_mismatch + } + } else { + parse_quote! { + if #cql_name_literal == cql_field_name { + #deserialize + } else { + #name_mismatch + } } }; @@ -321,7 +340,7 @@ impl<'sd> DeserializeAssumeOrderGenerator<'sd> { // For now, we treat both cases as "null". let value = value.flatten(); - #name_check_and_deserialize + #maybe_name_check_and_deserialize } } } diff --git a/scylla/src/macros.rs b/scylla/src/macros.rs index 4bd5cd36b5..2e9f15b9d2 100644 --- a/scylla/src/macros.rs +++ b/scylla/src/macros.rs @@ -310,6 +310,16 @@ pub use scylla_cql::macros::SerializeRow; /// is be generated. The UDT field names will still be checked during the /// type check phase. /// +/// #[(scylla(skip_name_checks))] +/// +/// This attribute only works when used with `enforce_order`. +/// +/// If set, the generated implementation will not verify the UDT field names at +/// all. Because it only works with `enforce_order`, it will deserialize first +/// UDT field into the first struct field, second UDT field into the second +/// struct field and so on. It will still verify that the UDT field types +/// and struct field types match. +/// /// ## Field attributes /// /// `#[scylla(skip)]` From 70c3615ed9eadd5bb441ea13303cbe91e7fd5c04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Tue, 25 Jun 2024 17:43:52 +0200 Subject: [PATCH 17/29] DeserializeValue: `skip_name_checks` flag tests New tests are added for `enforce_order` flavour with name match verification turned off. Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/value.rs | 50 +++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index ba9b2ec3dd..89dcbef6dd 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -2682,6 +2682,56 @@ pub(super) mod tests { } } + #[test] + fn test_udt_no_name_check() { + #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] + #[scylla(crate = "crate", enforce_order, skip_name_checks)] + struct Udt<'a> { + a: &'a str, + #[scylla(skip)] + x: String, + b: Option, + } + + // UDT fields in correct same order + { + let udt = UdtSerializer::new() + .field("The quick brown fox".as_bytes()) + .field(&42i32.to_be_bytes()) + .finalize(); + let typ = udt_def_with_fields([("a", ColumnType::Text), ("b", ColumnType::Int)]); + + let udt = deserialize::>(&typ, &udt).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: Some(42), + } + ); + } + + // Correct order of UDT fields, but different names - should still succeed + { + let udt = UdtSerializer::new() + .field("The quick brown fox".as_bytes()) + .field(&42i32.to_be_bytes()) + .finalize(); + let typ = udt_def_with_fields([("k", ColumnType::Text), ("l", ColumnType::Int)]); + + let udt = deserialize::>(&typ, &udt).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: Some(42), + } + ); + } + } + #[test] fn test_custom_type_parser() { #[derive(Default, Debug, PartialEq, Eq)] From 64ead02de31addcaa934c792c82cb1bbd49b58fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Tue, 25 Jun 2024 17:52:22 +0200 Subject: [PATCH 18/29] DeserializeRow: `skip_name_checks` flag support `skip_name_checks` is allowed only in `enforce_order` flavour. If enabled, it turns off name match verification (e.g., for performance purposes) between Rust struct fields and CQL row columns. Co-authored-by: Piotr Dulikowski --- scylla-macros/src/deserialize/row.rs | 63 +++++++++++++++++----------- scylla/src/macros.rs | 9 ++++ 2 files changed, 47 insertions(+), 25 deletions(-) diff --git a/scylla-macros/src/deserialize/row.rs b/scylla-macros/src/deserialize/row.rs index 02975de6da..059e51f7f6 100644 --- a/scylla-macros/src/deserialize/row.rs +++ b/scylla-macros/src/deserialize/row.rs @@ -17,6 +17,13 @@ struct StructAttrs { // by name can be avoided, though it is less convenient. #[darling(default)] enforce_order: bool, + + // If true, then the type checking code won't verify the column names. + // Columns will be matched to struct fields based solely on the order. + // + // This annotation only works if `enforce_order` is specified. + #[darling(default)] + skip_name_checks: bool, } impl DeserializeCommonStructAttrs for StructAttrs { @@ -123,25 +130,27 @@ impl<'sd> TypeCheckAssumeOrderGenerator<'sd> { column_index: usize, // applied to some field. field: &Field, column_spec: &syn::Ident, - ) -> syn::Expr { - let macro_internal = self.0.struct_attrs().macro_internal_path(); - let rust_field_name = field.cql_name_literal(); + ) -> Option { + (!self.0.attrs.skip_name_checks).then(|| { + let macro_internal = self.0.struct_attrs().macro_internal_path(); + let rust_field_name = field.cql_name_literal(); - parse_quote! { - if #column_spec.name != #rust_field_name { - return ::std::result::Result::Err( - #macro_internal::mk_row_typck_err::( - column_types_iter(), - #macro_internal::DeserBuiltinRowTypeCheckErrorKind::ColumnNameMismatch { - field_index: #field_index, - column_index: #column_index, - rust_column_name: #rust_field_name, - db_column_name: ::std::clone::Clone::clone(&#column_spec.name), - } - ) - ); + parse_quote! { + if #column_spec.name != #rust_field_name { + return ::std::result::Result::Err( + #macro_internal::mk_row_typck_err::( + column_types_iter(), + #macro_internal::DeserBuiltinRowTypeCheckErrorKind::ColumnNameMismatch { + field_index: #field_index, + column_index: #column_index, + rust_column_name: #rust_field_name, + db_column_name: ::std::clone::Clone::clone(&#column_spec.name), + } + ) + ); + } } - } + }) } fn generate(&self) -> syn::ImplItemFn { @@ -181,7 +190,7 @@ impl<'sd> TypeCheckAssumeOrderGenerator<'sd> { match specs { [#(#required_fields_idents),*] => { #( - // Verify the name + // Verify the name (unless `skip_name_checks' is specified) #name_verifications // Verify the type @@ -226,19 +235,23 @@ impl<'sd> DeserializeAssumeOrderGenerator<'sd> { let deserializer = field.deserialize_target(); let constraint_lifetime = self.0.constraint_lifetime(); + let name_check: Option = (!self.0.struct_attrs().skip_name_checks).then(|| parse_quote! { + if col.spec.name.as_str() != #cql_name_literal { + panic!( + "Typecheck should have prevented this scenario - field-column name mismatch! Rust field name {}, CQL column name {}", + #cql_name_literal, + col.spec.name.as_str() + ); + } + }); + parse_quote!( { let col = row.next() .expect("Typecheck should have prevented this scenario! Too few columns in the serialized data.") .map_err(#macro_internal::row_deser_error_replace_rust_name::)?; - if col.spec.name.as_str() != #cql_name_literal { - panic!( - "Typecheck should have prevented this scenario - field-column name mismatch! Rust field name {}, CQL column name {}", - #cql_name_literal, - col.spec.name.as_str() - ); - } + #name_check <#deserializer as #macro_internal::DeserializeValue<#constraint_lifetime>>::deserialize(&col.spec.typ, col.slice) .map_err(|err| #macro_internal::mk_row_deser_err::( diff --git a/scylla/src/macros.rs b/scylla/src/macros.rs index 2e9f15b9d2..6f61380882 100644 --- a/scylla/src/macros.rs +++ b/scylla/src/macros.rs @@ -420,6 +420,15 @@ pub use scylla_macros::DeserializeValue; /// efficient implementation that does not perform lookups is be generated. /// The generated code will still check that the column and field names match. /// +/// #[(scylla(skip_name_checks))] +/// +/// This attribute only works when used with `enforce_order`. +/// +/// If set, the generated implementation will not verify the column names at +/// all. Because it only works with `enforce_order`, it will deserialize first +/// column into the first field, second column into the second field and so on. +/// It will still still verify that the column types and field types match. +/// /// ## Field attributes /// /// `#[scylla(skip)]` From 597c62fb6ad7da72e00f38e95043cb04f2eb6d13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Tue, 25 Jun 2024 17:52:38 +0200 Subject: [PATCH 19/29] DeserializeRow: `skip_name_checks` flag tests New tests are added for `enforce_order` flavour with name match verification turned off. Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/row.rs | 38 +++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/scylla-cql/src/types/deserialize/row.rs b/scylla-cql/src/types/deserialize/row.rs index b10f78a154..e5f4e92464 100644 --- a/scylla-cql/src/types/deserialize/row.rs +++ b/scylla-cql/src/types/deserialize/row.rs @@ -668,6 +668,44 @@ mod tests { MyRow::type_check(specs).unwrap_err(); } + #[test] + fn test_struct_deserialization_no_name_check() { + #[derive(DeserializeRow, PartialEq, Eq, Debug)] + #[scylla(crate = "crate", enforce_order, skip_name_checks)] + struct MyRow<'a> { + a: &'a str, + b: Option, + #[scylla(skip)] + c: String, + } + + // Correct order of columns + let specs = &[spec("a", ColumnType::Text), spec("b", ColumnType::Int)]; + let byts = serialize_cells([val_str("abc"), val_int(123)]); + let row = deserialize::>(specs, &byts).unwrap(); + assert_eq!( + row, + MyRow { + a: "abc", + b: Some(123), + c: String::new(), + } + ); + + // Correct order of columns, but different names - should still succeed + let specs = &[spec("z", ColumnType::Text), spec("x", ColumnType::Int)]; + let byts = serialize_cells([val_str("abc"), val_int(123)]); + let row = deserialize::>(specs, &byts).unwrap(); + assert_eq!( + row, + MyRow { + a: "abc", + b: Some(123), + c: String::new(), + } + ); + } + fn val_int(i: i32) -> Option> { Some(i.to_be_bytes().to_vec()) } From 623adc308fd091de10d7c3432821f2be93f9b1e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Tue, 25 Jun 2024 17:56:34 +0200 Subject: [PATCH 20/29] Deserialize{Value,Row}: attribute validation `skip_name_checks` requires `enforce_order`, so let's ensure this at compile time by panicking the macro with an insightful message. The same with `skip_name_checks` excluding `rename`. And we check that `rename`s do not introduce any name clashes. --- scylla-cql/src/types/deserialize/row.rs | 73 ++++++++++++++++++++++ scylla-cql/src/types/deserialize/value.rs | 74 +++++++++++++++++++++++ scylla-macros/src/deserialize/row.rs | 58 ++++++++++++++++-- scylla-macros/src/deserialize/value.rs | 59 ++++++++++++++++-- 4 files changed, 254 insertions(+), 10 deletions(-) diff --git a/scylla-cql/src/types/deserialize/row.rs b/scylla-cql/src/types/deserialize/row.rs index e5f4e92464..c4e8e9cf1c 100644 --- a/scylla-cql/src/types/deserialize/row.rs +++ b/scylla-cql/src/types/deserialize/row.rs @@ -706,6 +706,35 @@ mod tests { ); } + #[test] + fn test_struct_deserialization_cross_rename_fields() { + #[derive(scylla_macros::DeserializeRow, PartialEq, Eq, Debug)] + #[scylla(crate = crate)] + struct TestRow { + #[scylla(rename = "b")] + a: i32, + #[scylla(rename = "a")] + b: String, + } + + // Columns switched wrt fields - should still work. + { + let row_bytes = serialize_cells( + ["The quick brown fox".as_bytes(), &42_i32.to_be_bytes()].map(Some), + ); + let specs = [spec("a", ColumnType::Text), spec("b", ColumnType::Int)]; + + let row = deserialize::(&specs, &row_bytes).unwrap(); + assert_eq!( + row, + TestRow { + a: 42, + b: "The quick brown fox".to_owned(), + } + ); + } + } + fn val_int(i: i32) -> Option> { Some(i.to_be_bytes().to_vec()) } @@ -1325,3 +1354,47 @@ mod tests { } } } + +/// ```compile_fail +/// +/// #[derive(scylla_macros::DeserializeRow)] +/// #[scylla(crate = scylla_cql, skip_name_checks)] +/// struct TestRow {} +/// ``` +fn _test_struct_deserialization_name_check_skip_requires_enforce_order() {} + +/// ```compile_fail +/// +/// #[derive(scylla_macros::DeserializeRow)] +/// #[scylla(crate = scylla_cql, skip_name_checks)] +/// struct TestRow { +/// #[scylla(rename = "b")] +/// a: i32, +/// } +/// ``` +fn _test_struct_deserialization_skip_name_check_conflicts_with_rename() {} + +/// ```compile_fail +/// +/// #[derive(scylla_macros::DeserializeRow)] +/// #[scylla(crate = scylla_cql)] +/// struct TestRow { +/// #[scylla(rename = "b")] +/// a: i32, +/// b: String, +/// } +/// ``` +fn _test_struct_deserialization_skip_rename_collision_with_field() {} + +/// ```compile_fail +/// +/// #[derive(scylla_macros::DeserializeRow)] +/// #[scylla(crate = scylla_cql)] +/// struct TestRow { +/// #[scylla(rename = "c")] +/// a: i32, +/// #[scylla(rename = "c")] +/// b: String, +/// } +/// ``` +fn _test_struct_deserialization_rename_collision_with_another_rename() {} diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index 89dcbef6dd..180eb21b9e 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -2732,6 +2732,36 @@ pub(super) mod tests { } } + #[test] + fn test_udt_cross_rename_fields() { + #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] + #[scylla(crate = crate)] + struct TestUdt { + #[scylla(rename = "b")] + a: i32, + #[scylla(rename = "a")] + b: String, + } + + // UDT fields switched - should still work. + { + let udt = UdtSerializer::new() + .field("The quick brown fox".as_bytes()) + .field(&42_i32.to_be_bytes()) + .finalize(); + let typ = udt_def_with_fields([("a", ColumnType::Text), ("b", ColumnType::Int)]); + + let udt = deserialize::(&typ, &udt).unwrap(); + assert_eq!( + udt, + TestUdt { + a: 42, + b: "The quick brown fox".to_owned(), + } + ); + } + } + #[test] fn test_custom_type_parser() { #[derive(Default, Debug, PartialEq, Eq)] @@ -3647,3 +3677,47 @@ pub(super) mod tests { } } } + +/// ```compile_fail +/// +/// #[derive(scylla_macros::DeserializeValue)] +/// #[scylla(crate = scylla_cql, skip_name_checks)] +/// struct TestUdt {} +/// ``` +fn _test_udt_bad_attributes_skip_name_check_requires_enforce_order() {} + +/// ```compile_fail +/// +/// #[derive(scylla_macros::DeserializeValue)] +/// #[scylla(crate = scylla_cql, enforce_order, skip_name_checks)] +/// struct TestUdt { +/// #[scylla(rename = "b")] +/// a: i32, +/// } +/// ``` +fn _test_udt_bad_attributes_skip_name_check_conflicts_with_rename() {} + +/// ```compile_fail +/// +/// #[derive(scylla_macros::DeserializeValue)] +/// #[scylla(crate = scylla_cql)] +/// struct TestUdt { +/// #[scylla(rename = "b")] +/// a: i32, +/// b: String, +/// } +/// ``` +fn _test_udt_bad_attributes_rename_collision_with_field() {} + +/// ```compile_fail +/// +/// #[derive(scylla_macros::DeserializeValue)] +/// #[scylla(crate = scylla_cql)] +/// struct TestUdt { +/// #[scylla(rename = "c")] +/// a: i32, +/// #[scylla(rename = "c")] +/// b: String, +/// } +/// ``` +fn _test_udt_bad_attributes_rename_collision_with_another_rename() {} diff --git a/scylla-macros/src/deserialize/row.rs b/scylla-macros/src/deserialize/row.rs index 059e51f7f6..1a43c8b343 100644 --- a/scylla-macros/src/deserialize/row.rs +++ b/scylla-macros/src/deserialize/row.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use darling::{FromAttributes, FromField}; use proc_macro2::Span; use syn::ext::IdentExt; @@ -77,6 +79,8 @@ pub(crate) fn deserialize_row_derive( let constraining_trait = parse_quote! { DeserializeValue }; let s = StructDesc::new(&input, &implemented_trait_name, constraining_trait)?; + validate_attrs(&s.attrs, &s.fields)?; + let items = [ s.generate_type_check_method().into(), s.generate_deserialize_method().into(), @@ -85,19 +89,63 @@ pub(crate) fn deserialize_row_derive( Ok(s.generate_impl(implemented_trait, items)) } +fn validate_attrs(attrs: &StructAttrs, fields: &[Field]) -> Result<(), darling::Error> { + let mut errors = darling::Error::accumulator(); + + if attrs.skip_name_checks { + // Skipping name checks is only available in enforce_order mode + if !attrs.enforce_order { + let error = + darling::Error::custom("attribute requires ."); + errors.push(error); + } + + // annotations don't make sense with skipped name checks + for field in fields { + if field.rename.is_some() { + let err = darling::Error::custom( + " annotations don't make sense with attribute", + ) + .with_span(&field.ident); + errors.push(err); + } + } + } else { + // Detect name collisions caused by `rename`. + let mut used_names = HashMap::::new(); + for field in fields { + let column_name = field.column_name(); + if let Some(other_field) = used_names.get(&column_name) { + let other_field_ident = other_field.ident.as_ref().unwrap(); + let msg = format!("the column name `{column_name}` used by this struct field is already used by field `{other_field_ident}`"); + let err = darling::Error::custom(msg).with_span(&field.ident); + errors.push(err); + } else { + used_names.insert(column_name, field); + } + } + } + + errors.finish() +} + impl Field { // Returns whether this field is mandatory for deserialization. fn is_required(&self) -> bool { !self.skip } - // A Rust literal representing the name of this field - fn cql_name_literal(&self) -> syn::LitStr { - let field_name = match self.rename.as_ref() { + // The name of the column corresponding to this Rust struct field + fn column_name(&self) -> String { + match self.rename.as_ref() { Some(rename) => rename.to_owned(), None => self.ident.as_ref().unwrap().unraw().to_string(), - }; - syn::LitStr::new(&field_name, Span::call_site()) + } + } + + // A Rust literal representing the name of this field + fn cql_name_literal(&self) -> syn::LitStr { + syn::LitStr::new(&self.column_name(), Span::call_site()) } } diff --git a/scylla-macros/src/deserialize/value.rs b/scylla-macros/src/deserialize/value.rs index b81c33cea7..2fcbab6af9 100644 --- a/scylla-macros/src/deserialize/value.rs +++ b/scylla-macros/src/deserialize/value.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use darling::{FromAttributes, FromField}; use proc_macro::TokenStream; use proc_macro2::Span; @@ -76,6 +78,8 @@ pub(crate) fn deserialize_value_derive( let constraining_trait = implemented_trait.clone(); let s = StructDesc::new(&input, &implemented_trait_name, constraining_trait)?; + validate_attrs(&s.attrs, s.fields())?; + let items = [ s.generate_type_check_method().into(), s.generate_deserialize_method().into(), @@ -83,19 +87,64 @@ pub(crate) fn deserialize_value_derive( Ok(s.generate_impl(implemented_trait, items)) } + +fn validate_attrs(attrs: &StructAttrs, fields: &[Field]) -> Result<(), darling::Error> { + let mut errors = darling::Error::accumulator(); + + if attrs.skip_name_checks { + // Skipping name checks is only available in enforce_order mode + if !attrs.enforce_order { + let error = + darling::Error::custom("attribute requires ."); + errors.push(error); + } + + // annotations don't make sense with skipped name checks + for field in fields { + if field.rename.is_some() { + let err = darling::Error::custom( + " annotations don't make sense with attribute", + ) + .with_span(&field.ident); + errors.push(err); + } + } + } else { + // Detect name collisions caused by . + let mut used_names = HashMap::::new(); + for field in fields { + let udt_field_name = field.udt_field_name(); + if let Some(other_field) = used_names.get(&udt_field_name) { + let other_field_ident = other_field.ident.as_ref().unwrap(); + let msg = format!("the UDT field name `{udt_field_name}` used by this struct field is already used by field `{other_field_ident}`"); + let err = darling::Error::custom(msg).with_span(&field.ident); + errors.push(err); + } else { + used_names.insert(udt_field_name, field); + } + } + } + + errors.finish() +} + impl Field { // Returns whether this field is mandatory for deserialization. fn is_required(&self) -> bool { !self.skip } - // A Rust literal representing the name of this field - fn cql_name_literal(&self) -> syn::LitStr { - let field_name = match self.rename.as_ref() { + // The name of UDT field corresponding to this Rust struct field + fn udt_field_name(&self) -> String { + match self.rename.as_ref() { Some(rename) => rename.to_owned(), None => self.ident.as_ref().unwrap().unraw().to_string(), - }; - syn::LitStr::new(&field_name, Span::call_site()) + } + } + + // A Rust literal representing the name of this field + fn cql_name_literal(&self) -> syn::LitStr { + syn::LitStr::new(&self.udt_field_name(), Span::call_site()) } } From ddc08b5fc47d5023026b7873024453084c885a9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Tue, 25 Jun 2024 18:04:43 +0200 Subject: [PATCH 21/29] DeserializeValue: `forbid_excess_udt_fields` flag support By default, if a UDT definition contains more fields than the Rust struct (in unordered flavour: anywhere, in enforce_order: in suffix), those excess fields are ignored. The `forbid_excess_udt_fields` attribute is added to fail with an error in case such fields are present. For comparison, DeserializeRow always requires the same number of Rust fields and CQL columns, effectively rejecting rows with excess columns. --- scylla-macros/src/deserialize/value.rs | 50 +++++++++++++++++++++++--- scylla/src/macros.rs | 9 +++++ 2 files changed, 55 insertions(+), 4 deletions(-) diff --git a/scylla-macros/src/deserialize/value.rs b/scylla-macros/src/deserialize/value.rs index 2fcbab6af9..397162f660 100644 --- a/scylla-macros/src/deserialize/value.rs +++ b/scylla-macros/src/deserialize/value.rs @@ -26,6 +26,13 @@ struct StructAttrs { // This annotation only works if `enforce_order` is specified. #[darling(default)] skip_name_checks: bool, + + // If true, then the type checking code will require that the UDT does not + // contain excess fields at its suffix. Otherwise, if UDT has some fields + // at its suffix that do not correspond to Rust struct's fields, + // they will be ignored. With true, an error will be raised. + #[darling(default)] + forbid_excess_udt_fields: bool, } impl DeserializeCommonStructAttrs for StructAttrs { @@ -279,6 +286,20 @@ impl<'sd> TypeCheckAssumeOrderGenerator<'sd> { let field_validations = nonskipped_fields_iter().map(|(idx, field)| self.generate_field_validation(idx, field)); + let check_excess_udt_fields: Option = + self.0.attrs.forbid_excess_udt_fields.then(|| { + parse_quote! { + if let ::std::option::Option::Some((cql_field_name, cql_field_typ)) = cql_field_iter.next() { + return ::std::result::Result::Err(#macro_internal::mk_value_typck_err::( + typ, + #macro_internal::DeserUdtTypeCheckErrorKind::ExcessFieldInUdt { + db_field_name: <_ as ::std::clone::Clone>::clone(cql_field_name), + } + )); + } + } + }); + parse_quote! { fn type_check( typ: &#macro_internal::ColumnType, @@ -307,6 +328,8 @@ impl<'sd> TypeCheckAssumeOrderGenerator<'sd> { #field_validations )* + #check_excess_udt_fields + // All is good! ::std::result::Result::Ok(()) } @@ -509,6 +532,7 @@ impl<'sd> TypeCheckUnorderedGenerator<'sd> { // - Every type on the list is correct let macro_internal = &self.0.struct_attrs().macro_internal_path(); + let forbid_excess_udt_fields = self.0.attrs.forbid_excess_udt_fields; let rust_fields = self.0.fields(); let visited_field_declarations = rust_fields .iter() @@ -524,6 +548,27 @@ impl<'sd> TypeCheckUnorderedGenerator<'sd> { syn::LitInt::new(&required_cql_field_count.to_string(), Span::call_site()); let extract_cql_fields_expr = self.0.generate_extract_fields_from_type(parse_quote!(typ)); + // If UDT contains a field with an unknown name, an error is raised iff + // `forbid_excess_udt_fields` attribute is specified. + let excess_udt_field_action: syn::Expr = if forbid_excess_udt_fields { + parse_quote! { + return ::std::result::Result::Err( + #macro_internal::mk_value_typck_err::( + typ, + #macro_internal::DeserUdtTypeCheckErrorKind::ExcessFieldInUdt { + db_field_name: unknown.to_owned(), + } + ) + ) + } + } else { + parse_quote! { + // We ignore excess UDT fields, as this facilitates the process of adding new fields + // to a UDT in running production cluster & clients. + () + } + }; + parse_quote! { fn type_check( typ: &#macro_internal::ColumnType, @@ -542,10 +587,7 @@ impl<'sd> TypeCheckUnorderedGenerator<'sd> { // Pattern match on the name and verify that the type is correct. match cql_field_name.as_str() { #(#rust_nonskipped_field_names => #type_check_blocks,)* - _unknown => { - // We ignore excess UDT fields, as this facilitates the process of adding new fields - // to a UDT in running production cluster & clients. - } + unknown => #excess_udt_field_action, } } diff --git a/scylla/src/macros.rs b/scylla/src/macros.rs index 6f61380882..ac37b0575d 100644 --- a/scylla/src/macros.rs +++ b/scylla/src/macros.rs @@ -320,6 +320,15 @@ pub use scylla_cql::macros::SerializeRow; /// struct field and so on. It will still verify that the UDT field types /// and struct field types match. /// +/// #[(scylla(forbid_excess_udt_fields))] +/// +/// By default, the generated deserialization code ignores excess UDT fields. +/// I.e., `enforce_order` flavour ignores excess UDT fields in the suffix +/// of the UDT definition, and the default unordered flavour ignores excess +/// UDT fields anywhere. +/// If more strictness is desired, this flag makes sure that no excess fields +/// are present and forces error in case there are some. +/// /// ## Field attributes /// /// `#[scylla(skip)]` From bc9814ad8467147035956c9e82ea5d068f4b31d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Tue, 25 Jun 2024 18:05:32 +0200 Subject: [PATCH 22/29] DeserializeValue: `forbid_excess_udt_fields` tests New tests are added for both flavours that show how the attribute fails type check phase in case that excess fields are present. --- scylla-cql/src/types/deserialize/value.rs | 65 ++++++++++++++++++++++- 1 file changed, 63 insertions(+), 2 deletions(-) diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index 180eb21b9e..a74a086515 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -2663,6 +2663,26 @@ pub(super) mod tests { ); } + // An excess field at the end of UDT, when such are forbidden + { + #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] + #[scylla(crate = "crate", enforce_order, forbid_excess_udt_fields)] + struct Udt<'a> { + a: &'a str, + #[scylla(skip)] + x: String, + b: Option, + } + + let typ = udt_def_with_fields([ + ("a", ColumnType::Text), + ("b", ColumnType::Int), + ("d", ColumnType::Boolean), + ]); + + Udt::type_check(&typ).unwrap_err(); + } + // UDT fields switched - will not work { let typ = udt_def_with_fields([("b", ColumnType::Int), ("a", ColumnType::Text)]); @@ -3373,7 +3393,7 @@ pub(super) mod tests { // Loose ordering { #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] - #[scylla(crate = "crate")] + #[scylla(crate = "crate", forbid_excess_udt_fields)] struct Udt<'a> { a: &'a str, #[scylla(skip)] @@ -3417,6 +3437,26 @@ pub(super) mod tests { assert_eq!(missing_fields.as_slice(), &["a", "b"]); } + // excess fields in UDT + { + let typ = udt_def_with_fields([ + ("d", ColumnType::Boolean), + ("a", ColumnType::Text), + ("b", ColumnType::Int), + ]); + let err = Udt::type_check(&typ).unwrap_err(); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinTypeCheckErrorKind::UdtError( + UdtTypeCheckErrorKind::ExcessFieldInUdt { ref db_field_name }, + ) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(db_field_name.as_str(), "d"); + } + // missing UDT field { let typ = @@ -3486,7 +3526,7 @@ pub(super) mod tests { // Strict ordering { #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] - #[scylla(crate = "crate", enforce_order)] + #[scylla(crate = "crate", enforce_order, forbid_excess_udt_fields)] struct Udt<'a> { a: &'a str, #[scylla(skip)] @@ -3530,6 +3570,27 @@ pub(super) mod tests { assert_eq!(present_fields.as_slice(), &["a".to_string()]); } + // excess fields in UDT + { + let typ = udt_def_with_fields([ + ("a", ColumnType::Text), + ("b", ColumnType::Int), + ("c", ColumnType::Boolean), + ("d", ColumnType::Counter), + ]); + let err = Udt::type_check(&typ).unwrap_err(); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinTypeCheckErrorKind::UdtError( + UdtTypeCheckErrorKind::ExcessFieldInUdt { ref db_field_name }, + ) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(db_field_name.as_str(), "d"); + } + // UDT fields switched - field name mismatch { let typ = udt_def_with_fields([ From 7083ed9d55b0e14a5d0d7b165140cab7785bb288 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Tue, 25 Jun 2024 18:09:27 +0200 Subject: [PATCH 23/29] DeserializeValue: `allow_missing` attribute support By default, if a UDT definition does not contain a field that the Rust struct requires, type checking fails. Instead, if `allow_missing` is specified for a field, type check lets it pass and deserialization yields `Default::default()`. This is important in production: if adding a field to a UDT and clients are updated before the cluster, then using this attribute clients can operate correctly until the cluster gets updated (by extending the UDT definition with the field expected by clients). Co-authored-by: Piotr Dulikowski --- scylla-macros/src/deserialize/value.rs | 173 +++++++++++++++++++------ scylla/src/macros.rs | 5 + 2 files changed, 135 insertions(+), 43 deletions(-) diff --git a/scylla-macros/src/deserialize/value.rs b/scylla-macros/src/deserialize/value.rs index 397162f660..83b1bc4ff6 100644 --- a/scylla-macros/src/deserialize/value.rs +++ b/scylla-macros/src/deserialize/value.rs @@ -49,6 +49,12 @@ struct Field { #[darling(default)] skip: bool, + // If true, then - if this field is missing from the UDT fields metadata + // - it will be initialized to Default::default(). + #[darling(default)] + #[darling(rename = "allow_missing")] + default_when_missing: bool, + // If set, then deserializes from the UDT field with this particular name // instead of the Rust field name. #[darling(default)] @@ -60,7 +66,7 @@ struct Field { impl DeserializeCommonFieldAttrs for Field { fn needs_default(&self) -> bool { - self.skip + self.skip || self.default_when_missing } fn deserialize_target(&self) -> &syn::Type { @@ -138,7 +144,7 @@ fn validate_attrs(attrs: &StructAttrs, fields: &[Field]) -> Result<(), darling:: impl Field { // Returns whether this field is mandatory for deserialization. fn is_required(&self) -> bool { - !self.skip + !self.skip && !self.default_when_missing } // The name of UDT field corresponding to this Rust struct field @@ -200,22 +206,35 @@ impl<'sd> TypeCheckAssumeOrderGenerator<'sd> { let constraint_lifetime = self.0.constraint_lifetime(); let rust_field_name = field.cql_name_literal(); let rust_field_typ = field.deserialize_target(); + let default_when_missing = field.default_when_missing; let skip_name_checks = self.0.attrs.skip_name_checks; // Action performed in case of field name mismatch. - let name_mismatch: syn::Expr = parse_quote! { - { - // Error - required value for field not present among the CQL fields. - return ::std::result::Result::Err( - #macro_internal::mk_value_typck_err::( - typ, - #macro_internal::DeserUdtTypeCheckErrorKind::FieldNameMismatch { - position: #rust_field_idx, - rust_field_name: <_ as ::std::borrow::ToOwned>::to_owned(#rust_field_name), - db_field_name: <_ as ::std::borrow::ToOwned>::to_owned(cql_field_name), - } - ) - ); + let name_mismatch: syn::Expr = if default_when_missing { + parse_quote! { + { + // If the Rust struct's field is marked as `default_when_missing`, then let's assume + // optimistically that the remaining UDT fields match required Rust struct fields. + // For that, store the read UDT field to be fit against the next Rust struct field. + saved_cql_field = ::std::option::Option::Some(next_cql_field); + break 'verifications; // Skip type verification, because the UDT field is absent. + } + } + } else { + parse_quote! { + { + // Error - required value for field not present among the CQL fields. + return ::std::result::Result::Err( + #macro_internal::mk_value_typck_err::( + typ, + #macro_internal::DeserUdtTypeCheckErrorKind::FieldNameMismatch { + position: #rust_field_idx, + rust_field_name: <_ as ::std::borrow::ToOwned>::to_owned(#rust_field_name), + db_field_name: <_ as ::std::borrow::ToOwned>::to_owned(cql_field_name), + } + ) + ); + } } }; @@ -231,14 +250,23 @@ impl<'sd> TypeCheckAssumeOrderGenerator<'sd> { parse_quote! { 'field: { - let next_cql_field = match cql_field_iter.next() { - ::std::option::Option::Some(cql_field) => cql_field, - ::std::option::Option::None => return Err(too_few_fields()), + let next_cql_field = match saved_cql_field + // We may have a stored CQL UDT field that did not match the previous Rust struct's field. + .take() + // If not, simply fetch another CQL UDT field from the iterator. + .or_else(|| cql_field_iter.next()) { + ::std::option::Option::Some(cql_field) => cql_field, + // In case the Rust field allows default-initialisation and there are no more CQL fields, + // simply assume it's going to be default-initialised. + ::std::option::Option::None if #default_when_missing => break 'field, + ::std::option::Option::None => return Err(too_few_fields()), }; let (cql_field_name, cql_field_typ) = next_cql_field; 'verifications: { // Verify the name (unless `skip_name_checks` is specified) + // In a specific case when this Rust field is going to be default-initialised + // due to no corresponding CQL UDT field, the below type verification will be skipped. #name_verification // Verify the type @@ -289,7 +317,9 @@ impl<'sd> TypeCheckAssumeOrderGenerator<'sd> { let check_excess_udt_fields: Option = self.0.attrs.forbid_excess_udt_fields.then(|| { parse_quote! { - if let ::std::option::Option::Some((cql_field_name, cql_field_typ)) = cql_field_iter.next() { + if let ::std::option::Option::Some((cql_field_name, cql_field_typ)) = saved_cql_field + .take() + .or_else(|| cql_field_iter.next()) { return ::std::result::Result::Err(#macro_internal::mk_value_typck_err::( typ, #macro_internal::DeserUdtTypeCheckErrorKind::ExcessFieldInUdt { @@ -324,6 +354,12 @@ impl<'sd> TypeCheckAssumeOrderGenerator<'sd> { } let mut cql_field_iter = fields.iter(); + // A CQL UDT field that has already been fetched from the field iterator, + // but not yet matched to a Rust struct field (because the previous + // Rust struct field didn't match it and had #[allow_missing] specified). + let mut saved_cql_field = ::std::option::Option::None::< + &(::std::string::String, #macro_internal::ColumnType), + >; #( #field_validations )* @@ -352,6 +388,7 @@ impl<'sd> DeserializeAssumeOrderGenerator<'sd> { let cql_name_literal = field.cql_name_literal(); let deserializer = field.deserialize_target(); let constraint_lifetime = self.0.constraint_lifetime(); + let default_when_missing = field.default_when_missing; let skip_name_checks = self.0.attrs.skip_name_checks; let deserialize: syn::Expr = parse_quote! { @@ -366,15 +403,30 @@ impl<'sd> DeserializeAssumeOrderGenerator<'sd> { }; // Action performed in case of field name mismatch. - let name_mismatch: syn::Expr = parse_quote! { - panic!( - "type check should have prevented this scenario - field name mismatch! Rust field name {}, CQL field name {}", - #cql_name_literal, - cql_field_name - ) + let name_mismatch: syn::Expr = if default_when_missing { + parse_quote! { + { + // If the Rust struct's field is marked as `default_when_missing`, then let's assume + // optimistically that the remaining UDT fields match required Rust struct fields. + // For that, store the read UDT field to be fit against the next Rust struct field. + saved_cql_field = ::std::option::Option::Some(next_cql_field); + + ::std::default::Default::default() + } + } + } else { + parse_quote! { + { + panic!( + "type check should have prevented this scenario - field name mismatch! Rust field name {}, CQL field name {}", + #cql_name_literal, + cql_field_name + ); + } + } }; - let maybe_name_check_and_deserialize: syn::Expr = if skip_name_checks { + let maybe_name_check_and_deserialize_or_save: syn::Expr = if skip_name_checks { parse_quote! { #deserialize } @@ -388,12 +440,27 @@ impl<'sd> DeserializeAssumeOrderGenerator<'sd> { } }; + let no_more_fields: syn::Expr = if default_when_missing { + parse_quote! { + ::std::default::Default::default() + } + } else { + parse_quote! { + // Type check has ensured that there are enough CQL UDT fields. + panic!("Too few CQL UDT fields - type check should have prevented this scenario!") + } + }; + parse_quote! { { - let next_cql_field = cql_field_iter.next() - .map(|(specs, value_res)| value_res.map(|value| (specs, value))) - // Type check has ensured that there are enough CQL UDT fields. - .expect("Too few CQL UDT fields - type check should have prevented this scenario!") + let maybe_next_cql_field = saved_cql_field + .take() + .map(::std::result::Result::Ok) + .or_else(|| { + cql_field_iter.next() + .map(|(specs, value_res)| value_res.map(|value| (specs, value))) + }) + .transpose() // Propagate deserialization errors. .map_err(|err| #macro_internal::mk_value_deser_err::( typ, @@ -403,16 +470,19 @@ impl<'sd> DeserializeAssumeOrderGenerator<'sd> { } ))?; + if let Some(next_cql_field) = maybe_next_cql_field { + let ((cql_field_name, cql_field_typ), value) = next_cql_field; - let ((cql_field_name, cql_field_typ), value) = next_cql_field; - - // The value can be either - // - None - missing from the serialized representation - // - Some(None) - present in the serialized representation but null - // For now, we treat both cases as "null". - let value = value.flatten(); + // The value can be either + // - None - missing from the serialized representation + // - Some(None) - present in the serialized representation but null + // For now, we treat both cases as "null". + let value = value.flatten(); - #maybe_name_check_and_deserialize + #maybe_name_check_and_deserialize_or_save + } else { + #no_more_fields + } } } } @@ -440,6 +510,14 @@ impl<'sd> DeserializeAssumeOrderGenerator<'sd> { let mut cql_field_iter = <#iterator_type as #macro_internal::DeserializeValue<#constraint_lifetime>>::deserialize(typ, v) .map_err(#macro_internal::value_deser_error_replace_rust_name::)?; + // This is to hold another field that already popped up from the field iterator but appeared to not match + // the expected nonrequired field. Therefore, that field is stored here, while the expected field + // is default-initialized. + let mut saved_cql_field = ::std::option::Option::None::<( + &(::std::string::String, #macro_internal::ColumnType), + ::std::option::Option<::std::option::Option<#macro_internal::FrameSlice>> + )>; + ::std::result::Result::Ok(Self { #(#field_idents: #field_finalizers,)* }) @@ -630,11 +708,20 @@ impl<'sd> DeserializeUnorderedGenerator<'sd> { } let deserialize_field = Self::deserialize_field_variable(field); - let cql_name_literal = field.cql_name_literal(); - parse_quote!(#deserialize_field.unwrap_or_else(|| panic!( - "field {} missing in UDT - type check should have prevented this!", - #cql_name_literal - ))) + if field.default_when_missing { + // Generate Default::default if the field was missing + parse_quote! { + #deserialize_field.unwrap_or_default() + } + } else { + let cql_name_literal = field.cql_name_literal(); + parse_quote! { + #deserialize_field.unwrap_or_else(|| panic!( + "field {} missing in UDT - type check should have prevented this!", + #cql_name_literal + )) + } + } } /// Generates code that performs deserialization when the raw field diff --git a/scylla/src/macros.rs b/scylla/src/macros.rs index ac37b0575d..44aafa2f09 100644 --- a/scylla/src/macros.rs +++ b/scylla/src/macros.rs @@ -336,6 +336,11 @@ pub use scylla_cql::macros::SerializeRow; /// The field will be completely ignored during deserialization and will /// be initialized with `Default::default()`. /// +/// `#[scylla(allow_missing)]` +/// +/// If the UDT definition does not contain this field, it will be initialized +/// with `Default::default()`. +/// /// `#[scylla(rename = "field_name")` /// /// By default, the generated implementation will try to match the Rust field From 2d76697a6d4aa141866e0367764329000bb69870 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Tue, 25 Jun 2024 18:10:25 +0200 Subject: [PATCH 24/29] DeserializeValue: `allow_missing` usage validation `enforce_order` flavour with `skip_name_checks` can't support `allow_missing` attribute in certain cases. Namely: * Fields with `allow_missing` are only permitted at the end of the struct, i.e. no field without `allow_missing` and `skip` is allowed to be after any field with `allow_missing`. If the condition is not upheld, the problem of matching fields becomes unsolvable. Thus, we don't support such cases, so let's ensure this condition at compile time by panicking the macro with an insightful message. --- scylla-cql/src/types/deserialize/value.rs | 25 +++++++++++++++++++++++ scylla-macros/src/deserialize/value.rs | 20 ++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index a74a086515..58e8afb91f 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -3782,3 +3782,28 @@ fn _test_udt_bad_attributes_rename_collision_with_field() {} /// } /// ``` fn _test_udt_bad_attributes_rename_collision_with_another_rename() {} + +/// ```compile_fail +/// +/// #[derive(scylla_macros::DeserializeValue)] +/// #[scylla(crate = scylla_cql, enforce_order, skip_name_checks)] +/// struct TestUdt { +/// a: i32, +/// #[scylla(allow_missing)] +/// b: bool, +/// c: String, +/// } +/// ``` +fn _test_udt_bad_attributes_name_skip_name_checks_limitations_on_allow_missing() {} + +/// ``` +/// #[derive(scylla_macros::DeserializeValue)] +/// #[scylla(crate = scylla_cql)] +/// struct TestUdt { +/// a: i32, +/// #[scylla(allow_missing)] +/// b: bool, +/// c: String, +/// } +/// ``` +fn _test_udt_unordered_flavour_no_limitations_on_allow_missing() {} diff --git a/scylla-macros/src/deserialize/value.rs b/scylla-macros/src/deserialize/value.rs index 83b1bc4ff6..de0ff0c33b 100644 --- a/scylla-macros/src/deserialize/value.rs +++ b/scylla-macros/src/deserialize/value.rs @@ -112,6 +112,26 @@ fn validate_attrs(attrs: &StructAttrs, fields: &[Field]) -> Result<(), darling:: errors.push(error); } + // Fields with `allow_missing` are only permitted at the end of the + // struct, i.e. no field without `allow_missing` and `skip` is allowed + // to be after any field with `allow_missing`. + let invalid_default_when_missing_field = fields + .iter() + .rev() + // Skip the whole suffix of and . + .skip_while(|field| !field.is_required()) + // skip_while finished either because the iterator is empty or it found a field without both and . + // In either case, there aren't allowed to be any more fields with `allow_missing`. + .find(|field| field.default_when_missing); + if let Some(invalid) = invalid_default_when_missing_field { + let error = + darling::Error::custom( + "when is on, fields with are only permitted at the end of the struct, \ + i.e. no field without and is allowed to be after any field with ." + ).with_span(&invalid.ident); + errors.push(error); + } + // annotations don't make sense with skipped name checks for field in fields { if field.rename.is_some() { From e086dc46984bd2ba7b073166f272937430b87778 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Tue, 25 Jun 2024 18:13:30 +0200 Subject: [PATCH 25/29] DeserializeValue: `allow_missing` tests New tests are added for both flavours that show how the attribute allows for successful type check and deserialization when the corresponding field is missing from the CQL UDT definition. Co-authored-by: Piotr Dulikowski --- scylla-cql/src/types/deserialize/value.rs | 78 +++++++++++++++++++---- 1 file changed, 64 insertions(+), 14 deletions(-) diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index 58e8afb91f..f6d0366776 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -2496,6 +2496,7 @@ pub(super) mod tests { a: &'a str, #[scylla(skip)] x: String, + #[scylla(allow_missing)] b: Option, c: i64, } @@ -2599,6 +2600,7 @@ pub(super) mod tests { a: &'a str, #[scylla(skip)] x: String, + #[scylla(allow_missing)] b: Option, } @@ -2700,6 +2702,22 @@ pub(super) mod tests { let typ = udt_def_with_fields([("b", ColumnType::Int)]); Udt::type_check(&typ).unwrap_err(); } + + // Missing non-required column + { + let udt = UdtSerializer::new().field(b"kotmaale").finalize(); + let typ = udt_def_with_fields([("a", ColumnType::Text)]); + + let udt = deserialize::>(&typ, &udt).unwrap(); + assert_eq!( + udt, + Udt { + a: "kotmaale", + x: String::new(), + b: None, + } + ); + } } #[test] @@ -3398,6 +3416,7 @@ pub(super) mod tests { a: &'a str, #[scylla(skip)] x: String, + #[scylla(allow_missing)] b: Option, c: bool, } @@ -3434,7 +3453,7 @@ pub(super) mod tests { else { panic!("unexpected error kind: {:?}", err.kind) }; - assert_eq!(missing_fields.as_slice(), &["a", "b"]); + assert_eq!(missing_fields.as_slice(), &["a"]); } // excess fields in UDT @@ -3520,6 +3539,43 @@ pub(super) mod tests { assert_eq!(err.cql_type, typ); assert_matches!(err.kind, BuiltinDeserializationErrorKind::ExpectedNonNull); } + + // UDT field deserialization failed + { + let typ = + udt_def_with_fields([("a", ColumnType::Ascii), ("c", ColumnType::Boolean)]); + + let udt_bytes = UdtSerializer::new() + .field("alamakota".as_bytes()) + .field(&42_i16.to_be_bytes()) + .finalize(); + + let err = deserialize::(&typ, &udt_bytes).unwrap_err(); + + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinDeserializationErrorKind::UdtError( + UdtDeserializationErrorKind::FieldDeserializationFailed { + ref field_name, + ref err, + }, + ) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(field_name.as_str(), "c"); + let err = get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::Boolean); + assert_matches!( + err.kind, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 1, + got: 2, + } + ); + } } } @@ -3532,6 +3588,7 @@ pub(super) mod tests { #[scylla(skip)] x: String, b: Option, + #[scylla(allow_missing)] c: bool, } @@ -3566,7 +3623,7 @@ pub(super) mod tests { else { panic!("unexpected error kind: {:?}", err.kind) }; - assert_eq!(required_fields.as_slice(), &["a", "b", "c"]); + assert_eq!(required_fields.as_slice(), &["a", "b"]); assert_eq!(present_fields.as_slice(), &["a".to_string()]); } @@ -3575,8 +3632,7 @@ pub(super) mod tests { let typ = udt_def_with_fields([ ("a", ColumnType::Text), ("b", ColumnType::Int), - ("c", ColumnType::Boolean), - ("d", ColumnType::Counter), + ("d", ColumnType::Boolean), ]); let err = Udt::type_check(&typ).unwrap_err(); let err = get_typeck_err_inner(err.0.as_ref()); @@ -3593,11 +3649,8 @@ pub(super) mod tests { // UDT fields switched - field name mismatch { - let typ = udt_def_with_fields([ - ("b", ColumnType::Int), - ("a", ColumnType::Text), - ("c", ColumnType::Boolean), - ]); + let typ = + udt_def_with_fields([("b", ColumnType::Int), ("a", ColumnType::Text)]); let err = Udt::type_check(&typ).unwrap_err(); let err = get_typeck_err_inner(err.0.as_ref()); assert_eq!(err.rust_name, std::any::type_name::()); @@ -3619,11 +3672,8 @@ pub(super) mod tests { // UDT fields incompatible types - field type check failed { - let typ = udt_def_with_fields([ - ("a", ColumnType::Blob), - ("b", ColumnType::Int), - ("c", ColumnType::Boolean), - ]); + let typ = + udt_def_with_fields([("a", ColumnType::Blob), ("b", ColumnType::Int)]); let err = Udt::type_check(&typ).unwrap_err(); let err = get_typeck_err_inner(err.0.as_ref()); assert_eq!(err.rust_name, std::any::type_name::()); From a85067a83f0f89f2b63f1ea37616e1e1444dd29a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Tue, 25 Jun 2024 18:14:56 +0200 Subject: [PATCH 26/29] DeserializeValue: `default_when_null` attribute support By default, if DB provides null value in serialized data and the corresponding Rust type expects non-null value, deserialization fails. Instead, if `default_when_null` is specified for a field, deserialization yields `Default::default()`. This is important in production: when added a field to a UDT, the existing UDT instances have the new field filled with null, even if the data represented makes no sense to allow nulls. By using `default_when_null`, clients can handle such situation without using `Option` in their Rust struct. --- scylla-macros/src/deserialize/value.rs | 39 ++++++++++++++++++++++++-- scylla/src/macros.rs | 5 ++++ 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/scylla-macros/src/deserialize/value.rs b/scylla-macros/src/deserialize/value.rs index de0ff0c33b..6b89bcfb58 100644 --- a/scylla-macros/src/deserialize/value.rs +++ b/scylla-macros/src/deserialize/value.rs @@ -55,6 +55,12 @@ struct Field { #[darling(rename = "allow_missing")] default_when_missing: bool, + // If true, then - if this field is present among UDT fields metadata + // but at the same time missing from serialized data or set to null + // - it will be initialized to Default::default(). + #[darling(default)] + default_when_null: bool, + // If set, then deserializes from the UDT field with this particular name // instead of the Rust field name. #[darling(default)] @@ -409,6 +415,7 @@ impl<'sd> DeserializeAssumeOrderGenerator<'sd> { let deserializer = field.deserialize_target(); let constraint_lifetime = self.0.constraint_lifetime(); let default_when_missing = field.default_when_missing; + let default_when_null = field.default_when_null; let skip_name_checks = self.0.attrs.skip_name_checks; let deserialize: syn::Expr = parse_quote! { @@ -422,6 +429,20 @@ impl<'sd> DeserializeAssumeOrderGenerator<'sd> { ))? }; + let maybe_default_deserialize: syn::Expr = if default_when_null { + parse_quote! { + if value.is_none() { + ::std::default::Default::default() + } else { + #deserialize + } + } + } else { + parse_quote! { + #deserialize + } + }; + // Action performed in case of field name mismatch. let name_mismatch: syn::Expr = if default_when_missing { parse_quote! { @@ -448,12 +469,12 @@ impl<'sd> DeserializeAssumeOrderGenerator<'sd> { let maybe_name_check_and_deserialize_or_save: syn::Expr = if skip_name_checks { parse_quote! { - #deserialize + #maybe_default_deserialize } } else { parse_quote! { if #cql_name_literal == cql_field_name { - #deserialize + #maybe_default_deserialize } else { #name_mismatch } @@ -765,6 +786,18 @@ impl<'sd> DeserializeUnorderedGenerator<'sd> { ))? }; + let deserialize_action: syn::Expr = if field.default_when_null { + parse_quote! { + if value.is_some() { + #do_deserialize + } else { + ::std::default::Default::default() + } + } + } else { + do_deserialize + }; + parse_quote! { { assert!( @@ -780,7 +813,7 @@ impl<'sd> DeserializeUnorderedGenerator<'sd> { let value = value.flatten(); #deserialize_field = ::std::option::Option::Some( - #do_deserialize + #deserialize_action ); } } diff --git a/scylla/src/macros.rs b/scylla/src/macros.rs index 44aafa2f09..490fb499a9 100644 --- a/scylla/src/macros.rs +++ b/scylla/src/macros.rs @@ -341,6 +341,11 @@ pub use scylla_cql::macros::SerializeRow; /// If the UDT definition does not contain this field, it will be initialized /// with `Default::default()`. /// +/// `#[scylla(default_when_null)]` +/// +/// If the value of the field received from DB is null, the field will be +/// initialized with `Default::default()`. +/// /// `#[scylla(rename = "field_name")` /// /// By default, the generated implementation will try to match the Rust field From 96155f1eaf87a529d0d402f2f3faaed0b8be9ce6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Tue, 25 Jun 2024 18:16:09 +0200 Subject: [PATCH 27/29] DeserializeValue: `default_when_null` tests New tests are added for both flavours that show how the attribute allows for successful deserialization when the field contains null and the corresponding Rust type expects non-null value. --- scylla-cql/src/types/deserialize/value.rs | 70 +++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index f6d0366776..12a89e86ca 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -2470,6 +2470,11 @@ pub(super) mod tests { self } + fn null_field(mut self) -> Self { + append_null(&mut self.buf); + self + } + fn finalize(&self) -> Bytes { make_bytes(&self.buf) } @@ -2498,6 +2503,7 @@ pub(super) mod tests { x: String, #[scylla(allow_missing)] b: Option, + #[scylla(default_when_null)] c: i64, } @@ -2526,6 +2532,30 @@ pub(super) mod tests { ); } + // The last two UDT field are missing in serialized form - it should treat it + // as if there were nulls at the end. + { + let udt = UdtSerializer::new() + .field("The quick brown fox".as_bytes()) + .finalize(); + let typ = udt_def_with_fields([ + ("a", ColumnType::Text), + ("b", ColumnType::Int), + ("c", ColumnType::BigInt), + ]); + + let udt = deserialize::>(&typ, &udt).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: None, + c: 0, + } + ); + } + // UDT fields switched - should still work. { let udt = UdtSerializer::new() @@ -2579,6 +2609,25 @@ pub(super) mod tests { ); } + // Only field 'a' is present + { + let udt = UdtSerializer::new() + .field("The quick brown fox".as_bytes()) + .finalize(); + let typ = udt_def_with_fields([("a", ColumnType::Text), ("c", ColumnType::BigInt)]); + + let udt = deserialize::>(&typ, &udt).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: None, + c: 0, + } + ); + } + // Wrong column type { let typ = udt_def_with_fields([("a", ColumnType::Text)]); @@ -2597,6 +2646,7 @@ pub(super) mod tests { #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] #[scylla(crate = "crate", enforce_order)] struct Udt<'a> { + #[scylla(default_when_null)] a: &'a str, #[scylla(skip)] x: String, @@ -2718,6 +2768,25 @@ pub(super) mod tests { } ); } + + // The first field is null, but `default_when_null` prevents failure. + { + let udt = UdtSerializer::new() + .null_field() + .field(&42i32.to_be_bytes()) + .finalize(); + let typ = udt_def_with_fields([("a", ColumnType::Text), ("b", ColumnType::Int)]); + + let udt = deserialize::>(&typ, &udt).unwrap(); + assert_eq!( + udt, + Udt { + a: "", + x: String::new(), + b: Some(42), + } + ); + } } #[test] @@ -3418,6 +3487,7 @@ pub(super) mod tests { x: String, #[scylla(allow_missing)] b: Option, + #[scylla(default_when_null)] c: bool, } From a970436b0300b93c4a32990055d20f9fbceddfaf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Wed, 26 Jun 2024 08:38:42 +0200 Subject: [PATCH 28/29] value: move tests to a separate file value.rs already got huge, better to put tests aside. Also, if only tests are modified when put in the same file as library code, Cargo will rebuild the lib crate anyway. Conversely, when tests are in a separate file, the lib crate won't be rebuilt and this saves precious time. --- scylla-cql/src/types/deserialize/value.rs | 1970 +---------------- .../src/types/deserialize/value_tests.rs | 1946 ++++++++++++++++ 2 files changed, 1948 insertions(+), 1968 deletions(-) create mode 100644 scylla-cql/src/types/deserialize/value_tests.rs diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index 12a89e86ca..074a7c298a 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -1890,1974 +1890,8 @@ impl From for BuiltinDeserializationErrorKind { } #[cfg(test)] -pub(super) mod tests { - use assert_matches::assert_matches; - use bytes::{BufMut, Bytes, BytesMut}; - use uuid::Uuid; - - use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; - use std::fmt::Debug; - use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; - - use crate::frame::response::result::{ColumnType, CqlValue}; - use crate::frame::value::{ - Counter, CqlDate, CqlDecimal, CqlDuration, CqlTime, CqlTimestamp, CqlTimeuuid, CqlVarint, - }; - use crate::types::deserialize::value::{ - TupleDeserializationErrorKind, TupleTypeCheckErrorKind, - }; - use crate::types::deserialize::{DeserializationError, FrameSlice, TypeCheckError}; - use crate::types::serialize::value::SerializeValue; - use crate::types::serialize::CellWriter; - - use super::{ - mk_deser_err, BuiltinDeserializationError, BuiltinDeserializationErrorKind, - BuiltinTypeCheckError, BuiltinTypeCheckErrorKind, DeserializeValue, ListlikeIterator, - MapDeserializationErrorKind, MapIterator, MapTypeCheckErrorKind, MaybeEmpty, - SetOrListDeserializationErrorKind, SetOrListTypeCheckErrorKind, - UdtDeserializationErrorKind, UdtTypeCheckErrorKind, - }; - - #[test] - fn test_deserialize_bytes() { - const ORIGINAL_BYTES: &[u8] = &[1, 5, 2, 4, 3]; - - let bytes = make_bytes(ORIGINAL_BYTES); - - let decoded_slice = deserialize::<&[u8]>(&ColumnType::Blob, &bytes).unwrap(); - let decoded_vec = deserialize::>(&ColumnType::Blob, &bytes).unwrap(); - let decoded_bytes = deserialize::(&ColumnType::Blob, &bytes).unwrap(); - - assert_eq!(decoded_slice, ORIGINAL_BYTES); - assert_eq!(decoded_vec, ORIGINAL_BYTES); - assert_eq!(decoded_bytes, ORIGINAL_BYTES); - - // ser/de identity - - // Nonempty blob - assert_ser_de_identity(&ColumnType::Blob, &ORIGINAL_BYTES, &mut Bytes::new()); - - // Empty blob - assert_ser_de_identity(&ColumnType::Blob, &(&[] as &[u8]), &mut Bytes::new()); - } - - #[test] - fn test_deserialize_ascii() { - const ASCII_TEXT: &str = "The quick brown fox jumps over the lazy dog"; - - let ascii = make_bytes(ASCII_TEXT.as_bytes()); - - for typ in [ColumnType::Ascii, ColumnType::Text].iter() { - let decoded_str = deserialize::<&str>(typ, &ascii).unwrap(); - let decoded_string = deserialize::(typ, &ascii).unwrap(); - - assert_eq!(decoded_str, ASCII_TEXT); - assert_eq!(decoded_string, ASCII_TEXT); - - // ser/de identity - - // Empty string - assert_ser_de_identity(typ, &"", &mut Bytes::new()); - assert_ser_de_identity(typ, &"".to_owned(), &mut Bytes::new()); - - // Nonempty string - assert_ser_de_identity(typ, &ASCII_TEXT, &mut Bytes::new()); - assert_ser_de_identity(typ, &ASCII_TEXT.to_owned(), &mut Bytes::new()); - } - } - - #[test] - fn test_deserialize_text() { - const UNICODE_TEXT: &str = "Zażółć gęślą jaźń"; - - let unicode = make_bytes(UNICODE_TEXT.as_bytes()); - - // Should fail because it's not an ASCII string - deserialize::<&str>(&ColumnType::Ascii, &unicode).unwrap_err(); - deserialize::(&ColumnType::Ascii, &unicode).unwrap_err(); - - let decoded_text_str = deserialize::<&str>(&ColumnType::Text, &unicode).unwrap(); - let decoded_text_string = deserialize::(&ColumnType::Text, &unicode).unwrap(); - assert_eq!(decoded_text_str, UNICODE_TEXT); - assert_eq!(decoded_text_string, UNICODE_TEXT); - - // ser/de identity - - assert_ser_de_identity(&ColumnType::Text, &UNICODE_TEXT, &mut Bytes::new()); - assert_ser_de_identity( - &ColumnType::Text, - &UNICODE_TEXT.to_owned(), - &mut Bytes::new(), - ); - } - - #[test] - fn test_integral() { - let tinyint = make_bytes(&[0x01]); - let decoded_tinyint = deserialize::(&ColumnType::TinyInt, &tinyint).unwrap(); - assert_eq!(decoded_tinyint, 0x01); - - let smallint = make_bytes(&[0x01, 0x02]); - let decoded_smallint = deserialize::(&ColumnType::SmallInt, &smallint).unwrap(); - assert_eq!(decoded_smallint, 0x0102); - - let int = make_bytes(&[0x01, 0x02, 0x03, 0x04]); - let decoded_int = deserialize::(&ColumnType::Int, &int).unwrap(); - assert_eq!(decoded_int, 0x01020304); - - let bigint = make_bytes(&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]); - let decoded_bigint = deserialize::(&ColumnType::BigInt, &bigint).unwrap(); - assert_eq!(decoded_bigint, 0x0102030405060708); - - // ser/de identity - assert_ser_de_identity(&ColumnType::TinyInt, &42_i8, &mut Bytes::new()); - assert_ser_de_identity(&ColumnType::SmallInt, &2137_i16, &mut Bytes::new()); - assert_ser_de_identity(&ColumnType::Int, &21372137_i32, &mut Bytes::new()); - assert_ser_de_identity(&ColumnType::BigInt, &0_i64, &mut Bytes::new()); - } - - #[test] - fn test_bool() { - for boolean in [true, false] { - let boolean_bytes = make_bytes(&[boolean as u8]); - let decoded_bool = deserialize::(&ColumnType::Boolean, &boolean_bytes).unwrap(); - assert_eq!(decoded_bool, boolean); - - // ser/de identity - assert_ser_de_identity(&ColumnType::Boolean, &boolean, &mut Bytes::new()); - } - } - - #[test] - fn test_floating_point() { - let float = make_bytes(&[63, 0, 0, 0]); - let decoded_float = deserialize::(&ColumnType::Float, &float).unwrap(); - assert_eq!(decoded_float, 0.5); - - let double = make_bytes(&[64, 0, 0, 0, 0, 0, 0, 0]); - let decoded_double = deserialize::(&ColumnType::Double, &double).unwrap(); - assert_eq!(decoded_double, 2.0); - - // ser/de identity - assert_ser_de_identity(&ColumnType::Float, &21.37_f32, &mut Bytes::new()); - assert_ser_de_identity(&ColumnType::Double, &2137.2137_f64, &mut Bytes::new()); - } - - #[test] - fn test_varlen_numbers() { - // varint - assert_ser_de_identity( - &ColumnType::Varint, - &CqlVarint::from_signed_bytes_be_slice(b"Ala ma kota"), - &mut Bytes::new(), - ); - - #[cfg(feature = "num-bigint-03")] - assert_ser_de_identity( - &ColumnType::Varint, - &num_bigint_03::BigInt::from_signed_bytes_be(b"Kot ma Ale"), - &mut Bytes::new(), - ); - - #[cfg(feature = "num-bigint-04")] - assert_ser_de_identity( - &ColumnType::Varint, - &num_bigint_04::BigInt::from_signed_bytes_be(b"Kot ma Ale"), - &mut Bytes::new(), - ); - - // decimal - assert_ser_de_identity( - &ColumnType::Decimal, - &CqlDecimal::from_signed_be_bytes_slice_and_exponent(b"Ala ma kota", 42), - &mut Bytes::new(), - ); - - #[cfg(feature = "bigdecimal-04")] - assert_ser_de_identity( - &ColumnType::Decimal, - &bigdecimal_04::BigDecimal::new( - bigdecimal_04::num_bigint::BigInt::from_signed_bytes_be(b"Ala ma kota"), - 42, - ), - &mut Bytes::new(), - ); - } - - #[test] - fn test_date_time_types() { - // duration - assert_ser_de_identity( - &ColumnType::Duration, - &CqlDuration { - months: 21, - days: 37, - nanoseconds: 42, - }, - &mut Bytes::new(), - ); - - // date - assert_ser_de_identity(&ColumnType::Date, &CqlDate(0xbeaf), &mut Bytes::new()); - - #[cfg(feature = "chrono-04")] - assert_ser_de_identity( - &ColumnType::Date, - &chrono_04::NaiveDate::from_yo_opt(1999, 99).unwrap(), - &mut Bytes::new(), - ); - - #[cfg(feature = "time-03")] - assert_ser_de_identity( - &ColumnType::Date, - &time_03::Date::from_ordinal_date(1999, 99).unwrap(), - &mut Bytes::new(), - ); - - // time - assert_ser_de_identity(&ColumnType::Time, &CqlTime(0xdeed), &mut Bytes::new()); - - #[cfg(feature = "chrono-04")] - assert_ser_de_identity( - &ColumnType::Time, - &chrono_04::NaiveTime::from_hms_micro_opt(21, 37, 21, 37).unwrap(), - &mut Bytes::new(), - ); - - #[cfg(feature = "time-03")] - assert_ser_de_identity( - &ColumnType::Time, - &time_03::Time::from_hms_micro(21, 37, 21, 37).unwrap(), - &mut Bytes::new(), - ); - - // timestamp - assert_ser_de_identity( - &ColumnType::Timestamp, - &CqlTimestamp(0xceed), - &mut Bytes::new(), - ); - - #[cfg(feature = "chrono-04")] - assert_ser_de_identity( - &ColumnType::Timestamp, - &chrono_04::DateTime::::from_timestamp_millis(0xdead_cafe_deaf) - .unwrap(), - &mut Bytes::new(), - ); - - #[cfg(feature = "time-03")] - assert_ser_de_identity( - &ColumnType::Timestamp, - &time_03::OffsetDateTime::from_unix_timestamp(0xdead_cafe).unwrap(), - &mut Bytes::new(), - ); - } - - #[test] - fn test_inet() { - assert_ser_de_identity( - &ColumnType::Inet, - &IpAddr::V4(Ipv4Addr::BROADCAST), - &mut Bytes::new(), - ); - - assert_ser_de_identity( - &ColumnType::Inet, - &IpAddr::V6(Ipv6Addr::LOCALHOST), - &mut Bytes::new(), - ); - } - - #[test] - fn test_uuid() { - assert_ser_de_identity( - &ColumnType::Uuid, - &Uuid::from_u128(0xdead_cafe_deaf_feed_beaf_bead), - &mut Bytes::new(), - ); - - assert_ser_de_identity( - &ColumnType::Timeuuid, - &CqlTimeuuid::from_u128(0xdead_cafe_deaf_feed_beaf_bead), - &mut Bytes::new(), - ); - } - - #[test] - fn test_null_and_empty() { - // non-nullable emptiable deserialization, non-empty value - let int = make_bytes(&[21, 37, 0, 0]); - let decoded_int = deserialize::>(&ColumnType::Int, &int).unwrap(); - assert_eq!(decoded_int, MaybeEmpty::Value((21 << 24) + (37 << 16))); - - // non-nullable emptiable deserialization, empty value - let int = make_bytes(&[]); - let decoded_int = deserialize::>(&ColumnType::Int, &int).unwrap(); - assert_eq!(decoded_int, MaybeEmpty::Empty); - - // nullable non-emptiable deserialization, non-null value - let int = make_bytes(&[21, 37, 0, 0]); - let decoded_int = deserialize::>(&ColumnType::Int, &int).unwrap(); - assert_eq!(decoded_int, Some((21 << 24) + (37 << 16))); - - // nullable non-emptiable deserialization, null value - let int = make_null(); - let decoded_int = deserialize::>(&ColumnType::Int, &int).unwrap(); - assert_eq!(decoded_int, None); - - // nullable emptiable deserialization, non-null non-empty value - let int = make_bytes(&[]); - let decoded_int = deserialize::>>(&ColumnType::Int, &int).unwrap(); - assert_eq!(decoded_int, Some(MaybeEmpty::Empty)); - - // ser/de identity - assert_ser_de_identity(&ColumnType::Int, &Some(12321_i32), &mut Bytes::new()); - assert_ser_de_identity(&ColumnType::Double, &None::, &mut Bytes::new()); - assert_ser_de_identity( - &ColumnType::Set(Box::new(ColumnType::Ascii)), - &None::>, - &mut Bytes::new(), - ); - } - - #[test] - fn test_maybe_empty() { - let empty = make_bytes(&[]); - let decoded_empty = deserialize::>(&ColumnType::TinyInt, &empty).unwrap(); - assert_eq!(decoded_empty, MaybeEmpty::Empty); - - let non_empty = make_bytes(&[0x01]); - let decoded_non_empty = - deserialize::>(&ColumnType::TinyInt, &non_empty).unwrap(); - assert_eq!(decoded_non_empty, MaybeEmpty::Value(0x01)); - } - - #[test] - fn test_cql_value() { - assert_ser_de_identity( - &ColumnType::Counter, - &CqlValue::Counter(Counter(765)), - &mut Bytes::new(), - ); - - assert_ser_de_identity( - &ColumnType::Timestamp, - &CqlValue::Timestamp(CqlTimestamp(2136)), - &mut Bytes::new(), - ); - - assert_ser_de_identity(&ColumnType::Boolean, &CqlValue::Empty, &mut Bytes::new()); - - assert_ser_de_identity( - &ColumnType::Text, - &CqlValue::Text("kremówki".to_owned()), - &mut Bytes::new(), - ); - assert_ser_de_identity( - &ColumnType::Ascii, - &CqlValue::Ascii("kremowy".to_owned()), - &mut Bytes::new(), - ); - - assert_ser_de_identity( - &ColumnType::Set(Box::new(ColumnType::Text)), - &CqlValue::Set(vec![CqlValue::Text("Ala ma kota".to_owned())]), - &mut Bytes::new(), - ); - } - - #[test] - fn test_list_and_set() { - let mut collection_contents = BytesMut::new(); - collection_contents.put_i32(3); - append_bytes(&mut collection_contents, "quick".as_bytes()); - append_bytes(&mut collection_contents, "brown".as_bytes()); - append_bytes(&mut collection_contents, "fox".as_bytes()); - - let collection = make_bytes(&collection_contents); - - let list_typ = ColumnType::List(Box::new(ColumnType::Ascii)); - let set_typ = ColumnType::Set(Box::new(ColumnType::Ascii)); - - // iterator - let mut iter = deserialize::>(&list_typ, &collection).unwrap(); - assert_eq!(iter.next().transpose().unwrap(), Some("quick")); - assert_eq!(iter.next().transpose().unwrap(), Some("brown")); - assert_eq!(iter.next().transpose().unwrap(), Some("fox")); - assert_eq!(iter.next().transpose().unwrap(), None); - - let expected_vec_str = vec!["quick", "brown", "fox"]; - let expected_vec_string = vec!["quick".to_string(), "brown".to_string(), "fox".to_string()]; - - // list - let decoded_vec_str = deserialize::>(&list_typ, &collection).unwrap(); - let decoded_vec_string = deserialize::>(&list_typ, &collection).unwrap(); - assert_eq!(decoded_vec_str, expected_vec_str); - assert_eq!(decoded_vec_string, expected_vec_string); - - // hash set - let decoded_hash_str = deserialize::>(&set_typ, &collection).unwrap(); - let decoded_hash_string = deserialize::>(&set_typ, &collection).unwrap(); - assert_eq!( - decoded_hash_str, - expected_vec_str.clone().into_iter().collect(), - ); - assert_eq!( - decoded_hash_string, - expected_vec_string.clone().into_iter().collect(), - ); - - // btree set - let decoded_btree_str = deserialize::>(&set_typ, &collection).unwrap(); - let decoded_btree_string = deserialize::>(&set_typ, &collection).unwrap(); - assert_eq!( - decoded_btree_str, - expected_vec_str.clone().into_iter().collect(), - ); - assert_eq!( - decoded_btree_string, - expected_vec_string.into_iter().collect(), - ); - - // ser/de identity - assert_ser_de_identity(&list_typ, &vec!["qwik"], &mut Bytes::new()); - assert_ser_de_identity(&set_typ, &vec!["qwik"], &mut Bytes::new()); - assert_ser_de_identity( - &set_typ, - &HashSet::<&str, std::collections::hash_map::RandomState>::from_iter(["qwik"]), - &mut Bytes::new(), - ); - assert_ser_de_identity( - &set_typ, - &BTreeSet::<&str>::from_iter(["qwik"]), - &mut Bytes::new(), - ); - } - - #[test] - fn test_map() { - let mut collection_contents = BytesMut::new(); - collection_contents.put_i32(3); - append_bytes(&mut collection_contents, &1i32.to_be_bytes()); - append_bytes(&mut collection_contents, "quick".as_bytes()); - append_bytes(&mut collection_contents, &2i32.to_be_bytes()); - append_bytes(&mut collection_contents, "brown".as_bytes()); - append_bytes(&mut collection_contents, &3i32.to_be_bytes()); - append_bytes(&mut collection_contents, "fox".as_bytes()); - - let collection = make_bytes(&collection_contents); - - let typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Ascii)); - - // iterator - let mut iter = deserialize::>(&typ, &collection).unwrap(); - assert_eq!(iter.next().transpose().unwrap(), Some((1, "quick"))); - assert_eq!(iter.next().transpose().unwrap(), Some((2, "brown"))); - assert_eq!(iter.next().transpose().unwrap(), Some((3, "fox"))); - assert_eq!(iter.next().transpose().unwrap(), None); - - let expected_str = vec![(1, "quick"), (2, "brown"), (3, "fox")]; - let expected_string = vec![ - (1, "quick".to_string()), - (2, "brown".to_string()), - (3, "fox".to_string()), - ]; - - // hash set - let decoded_hash_str = deserialize::>(&typ, &collection).unwrap(); - let decoded_hash_string = deserialize::>(&typ, &collection).unwrap(); - assert_eq!(decoded_hash_str, expected_str.clone().into_iter().collect()); - assert_eq!( - decoded_hash_string, - expected_string.clone().into_iter().collect(), - ); - - // btree set - let decoded_btree_str = deserialize::>(&typ, &collection).unwrap(); - let decoded_btree_string = deserialize::>(&typ, &collection).unwrap(); - assert_eq!( - decoded_btree_str, - expected_str.clone().into_iter().collect(), - ); - assert_eq!(decoded_btree_string, expected_string.into_iter().collect()); - - // ser/de identity - assert_ser_de_identity( - &typ, - &HashMap::::from_iter([( - -42, "qwik", - )]), - &mut Bytes::new(), - ); - assert_ser_de_identity( - &typ, - &BTreeMap::::from_iter([(-42, "qwik")]), - &mut Bytes::new(), - ); - } - - #[test] - fn test_tuples() { - let mut tuple_contents = BytesMut::new(); - append_bytes(&mut tuple_contents, &42i32.to_be_bytes()); - append_bytes(&mut tuple_contents, "foo".as_bytes()); - append_null(&mut tuple_contents); - - let tuple = make_bytes(&tuple_contents); - - let typ = ColumnType::Tuple(vec![ColumnType::Int, ColumnType::Ascii, ColumnType::Uuid]); - - let tup = deserialize::<(i32, &str, Option)>(&typ, &tuple).unwrap(); - assert_eq!(tup, (42, "foo", None)); - - // ser/de identity - - // () does not implement SerializeValue, yet it does implement DeserializeValue. - // assert_ser_de_identity(&ColumnType::Tuple(vec![]), &(), &mut Bytes::new()); - - // nonempty, varied tuple - assert_ser_de_identity( - &ColumnType::Tuple(vec![ - ColumnType::List(Box::new(ColumnType::Boolean)), - ColumnType::BigInt, - ColumnType::Uuid, - ColumnType::Inet, - ]), - &( - vec![true, false, true], - 42_i64, - Uuid::from_u128(0xdead_cafe_deaf_feed_beaf_bead), - IpAddr::V6(Ipv6Addr::new(0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x10, 0x11)), - ), - &mut Bytes::new(), - ); - - // nested tuples - assert_ser_de_identity( - &ColumnType::Tuple(vec![ColumnType::Tuple(vec![ColumnType::Tuple(vec![ - ColumnType::Text, - ])])]), - &((("",),),), - &mut Bytes::new(), - ); - } - - fn udt_def_with_fields( - fields: impl IntoIterator, ColumnType)>, - ) -> ColumnType { - ColumnType::UserDefinedType { - type_name: "udt".to_owned(), - keyspace: "ks".to_owned(), - field_types: fields.into_iter().map(|(s, t)| (s.into(), t)).collect(), - } - } - - #[must_use] - struct UdtSerializer { - buf: BytesMut, - } - - impl UdtSerializer { - fn new() -> Self { - Self { - buf: BytesMut::default(), - } - } - - fn field(mut self, field_bytes: &[u8]) -> Self { - append_bytes(&mut self.buf, field_bytes); - self - } - - fn null_field(mut self) -> Self { - append_null(&mut self.buf); - self - } - - fn finalize(&self) -> Bytes { - make_bytes(&self.buf) - } - } - - // Do not remove. It's not used in tests but we keep it here to check that - // we properly ignore warnings about unused variables, unnecessary `mut`s - // etc. that usually pop up when generating code for empty structs. - #[allow(unused)] - #[derive(scylla_macros::DeserializeValue)] - #[scylla(crate = crate)] - struct TestUdtWithNoFieldsUnordered {} - - #[allow(unused)] - #[derive(scylla_macros::DeserializeValue)] - #[scylla(crate = crate, enforce_order)] - struct TestUdtWithNoFieldsOrdered {} - - #[test] - fn test_udt_loose_ordering() { - #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] - #[scylla(crate = "crate")] - struct Udt<'a> { - a: &'a str, - #[scylla(skip)] - x: String, - #[scylla(allow_missing)] - b: Option, - #[scylla(default_when_null)] - c: i64, - } - - // UDT fields in correct same order. - { - let udt = UdtSerializer::new() - .field("The quick brown fox".as_bytes()) - .field(&42_i32.to_be_bytes()) - .field(&2137_i64.to_be_bytes()) - .finalize(); - let typ = udt_def_with_fields([ - ("a", ColumnType::Text), - ("b", ColumnType::Int), - ("c", ColumnType::BigInt), - ]); - - let udt = deserialize::>(&typ, &udt).unwrap(); - assert_eq!( - udt, - Udt { - a: "The quick brown fox", - x: String::new(), - b: Some(42), - c: 2137, - } - ); - } - - // The last two UDT field are missing in serialized form - it should treat it - // as if there were nulls at the end. - { - let udt = UdtSerializer::new() - .field("The quick brown fox".as_bytes()) - .finalize(); - let typ = udt_def_with_fields([ - ("a", ColumnType::Text), - ("b", ColumnType::Int), - ("c", ColumnType::BigInt), - ]); - - let udt = deserialize::>(&typ, &udt).unwrap(); - assert_eq!( - udt, - Udt { - a: "The quick brown fox", - x: String::new(), - b: None, - c: 0, - } - ); - } - - // UDT fields switched - should still work. - { - let udt = UdtSerializer::new() - .field(&42_i32.to_be_bytes()) - .field("The quick brown fox".as_bytes()) - .field(&2137_i64.to_be_bytes()) - .finalize(); - let typ = udt_def_with_fields([ - ("b", ColumnType::Int), - ("a", ColumnType::Text), - ("c", ColumnType::BigInt), - ]); - - let udt = deserialize::>(&typ, &udt).unwrap(); - assert_eq!( - udt, - Udt { - a: "The quick brown fox", - x: String::new(), - b: Some(42), - c: 2137, - } - ); - } - - // An excess UDT field - should still work. - { - let udt = UdtSerializer::new() - .field(&12_i8.to_be_bytes()) - .field(&42_i32.to_be_bytes()) - .field("The quick brown fox".as_bytes()) - .field(&2137_i64.to_be_bytes()) - .finalize(); - let typ = udt_def_with_fields([ - ("d", ColumnType::TinyInt), - ("b", ColumnType::Int), - ("a", ColumnType::Text), - ("c", ColumnType::BigInt), - ]); - - Udt::type_check(&typ).unwrap(); - let udt = deserialize::>(&typ, &udt).unwrap(); - assert_eq!( - udt, - Udt { - a: "The quick brown fox", - x: String::new(), - b: Some(42), - c: 2137, - } - ); - } - - // Only field 'a' is present - { - let udt = UdtSerializer::new() - .field("The quick brown fox".as_bytes()) - .finalize(); - let typ = udt_def_with_fields([("a", ColumnType::Text), ("c", ColumnType::BigInt)]); - - let udt = deserialize::>(&typ, &udt).unwrap(); - assert_eq!( - udt, - Udt { - a: "The quick brown fox", - x: String::new(), - b: None, - c: 0, - } - ); - } - - // Wrong column type - { - let typ = udt_def_with_fields([("a", ColumnType::Text)]); - Udt::type_check(&typ).unwrap_err(); - } - - // Missing required column - { - let typ = udt_def_with_fields([("b", ColumnType::Int)]); - Udt::type_check(&typ).unwrap_err(); - } - } - - #[test] - fn test_udt_strict_ordering() { - #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] - #[scylla(crate = "crate", enforce_order)] - struct Udt<'a> { - #[scylla(default_when_null)] - a: &'a str, - #[scylla(skip)] - x: String, - #[scylla(allow_missing)] - b: Option, - } - - // UDT fields in correct same order - { - let udt = UdtSerializer::new() - .field("The quick brown fox".as_bytes()) - .field(&42i32.to_be_bytes()) - .finalize(); - let typ = udt_def_with_fields([("a", ColumnType::Text), ("b", ColumnType::Int)]); - - let udt = deserialize::>(&typ, &udt).unwrap(); - assert_eq!( - udt, - Udt { - a: "The quick brown fox", - x: String::new(), - b: Some(42), - } - ); - } - - // The last UDT field is missing in serialized form - it should treat - // as if there were null at the end - { - let udt = UdtSerializer::new() - .field("The quick brown fox".as_bytes()) - .finalize(); - let typ = udt_def_with_fields([("a", ColumnType::Text), ("b", ColumnType::Int)]); - - let udt = deserialize::>(&typ, &udt).unwrap(); - assert_eq!( - udt, - Udt { - a: "The quick brown fox", - x: String::new(), - b: None, - } - ); - } - - // An excess field at the end of UDT - { - let udt = UdtSerializer::new() - .field("The quick brown fox".as_bytes()) - .field(&42_i32.to_be_bytes()) - .field(&(true as i8).to_be_bytes()) - .finalize(); - let typ = udt_def_with_fields([ - ("a", ColumnType::Text), - ("b", ColumnType::Int), - ("d", ColumnType::Boolean), - ]); - let udt = deserialize::>(&typ, &udt).unwrap(); - assert_eq!( - udt, - Udt { - a: "The quick brown fox", - x: String::new(), - b: Some(42), - } - ); - } - - // An excess field at the end of UDT, when such are forbidden - { - #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] - #[scylla(crate = "crate", enforce_order, forbid_excess_udt_fields)] - struct Udt<'a> { - a: &'a str, - #[scylla(skip)] - x: String, - b: Option, - } - - let typ = udt_def_with_fields([ - ("a", ColumnType::Text), - ("b", ColumnType::Int), - ("d", ColumnType::Boolean), - ]); - - Udt::type_check(&typ).unwrap_err(); - } - - // UDT fields switched - will not work - { - let typ = udt_def_with_fields([("b", ColumnType::Int), ("a", ColumnType::Text)]); - Udt::type_check(&typ).unwrap_err(); - } - - // Wrong column type - { - let typ = udt_def_with_fields([("a", ColumnType::Int), ("b", ColumnType::Int)]); - Udt::type_check(&typ).unwrap_err(); - } - - // Missing required column - { - let typ = udt_def_with_fields([("b", ColumnType::Int)]); - Udt::type_check(&typ).unwrap_err(); - } - - // Missing non-required column - { - let udt = UdtSerializer::new().field(b"kotmaale").finalize(); - let typ = udt_def_with_fields([("a", ColumnType::Text)]); - - let udt = deserialize::>(&typ, &udt).unwrap(); - assert_eq!( - udt, - Udt { - a: "kotmaale", - x: String::new(), - b: None, - } - ); - } - - // The first field is null, but `default_when_null` prevents failure. - { - let udt = UdtSerializer::new() - .null_field() - .field(&42i32.to_be_bytes()) - .finalize(); - let typ = udt_def_with_fields([("a", ColumnType::Text), ("b", ColumnType::Int)]); - - let udt = deserialize::>(&typ, &udt).unwrap(); - assert_eq!( - udt, - Udt { - a: "", - x: String::new(), - b: Some(42), - } - ); - } - } - - #[test] - fn test_udt_no_name_check() { - #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] - #[scylla(crate = "crate", enforce_order, skip_name_checks)] - struct Udt<'a> { - a: &'a str, - #[scylla(skip)] - x: String, - b: Option, - } - - // UDT fields in correct same order - { - let udt = UdtSerializer::new() - .field("The quick brown fox".as_bytes()) - .field(&42i32.to_be_bytes()) - .finalize(); - let typ = udt_def_with_fields([("a", ColumnType::Text), ("b", ColumnType::Int)]); - - let udt = deserialize::>(&typ, &udt).unwrap(); - assert_eq!( - udt, - Udt { - a: "The quick brown fox", - x: String::new(), - b: Some(42), - } - ); - } - - // Correct order of UDT fields, but different names - should still succeed - { - let udt = UdtSerializer::new() - .field("The quick brown fox".as_bytes()) - .field(&42i32.to_be_bytes()) - .finalize(); - let typ = udt_def_with_fields([("k", ColumnType::Text), ("l", ColumnType::Int)]); - - let udt = deserialize::>(&typ, &udt).unwrap(); - assert_eq!( - udt, - Udt { - a: "The quick brown fox", - x: String::new(), - b: Some(42), - } - ); - } - } - - #[test] - fn test_udt_cross_rename_fields() { - #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] - #[scylla(crate = crate)] - struct TestUdt { - #[scylla(rename = "b")] - a: i32, - #[scylla(rename = "a")] - b: String, - } - - // UDT fields switched - should still work. - { - let udt = UdtSerializer::new() - .field("The quick brown fox".as_bytes()) - .field(&42_i32.to_be_bytes()) - .finalize(); - let typ = udt_def_with_fields([("a", ColumnType::Text), ("b", ColumnType::Int)]); - - let udt = deserialize::(&typ, &udt).unwrap(); - assert_eq!( - udt, - TestUdt { - a: 42, - b: "The quick brown fox".to_owned(), - } - ); - } - } - - #[test] - fn test_custom_type_parser() { - #[derive(Default, Debug, PartialEq, Eq)] - struct SwappedPair(B, A); - impl<'frame, A, B> DeserializeValue<'frame> for SwappedPair - where - A: DeserializeValue<'frame>, - B: DeserializeValue<'frame>, - { - fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { - <(B, A) as DeserializeValue<'frame>>::type_check(typ) - } - - fn deserialize( - typ: &'frame ColumnType, - v: Option>, - ) -> Result { - <(B, A) as DeserializeValue<'frame>>::deserialize(typ, v).map(|(b, a)| Self(b, a)) - } - } - - let mut tuple_contents = BytesMut::new(); - append_bytes(&mut tuple_contents, "foo".as_bytes()); - append_bytes(&mut tuple_contents, &42i32.to_be_bytes()); - let tuple = make_bytes(&tuple_contents); - - let typ = ColumnType::Tuple(vec![ColumnType::Ascii, ColumnType::Int]); - - let tup = deserialize::>(&typ, &tuple).unwrap(); - assert_eq!(tup, SwappedPair("foo", 42)); - } - - fn deserialize<'frame, T>( - typ: &'frame ColumnType, - bytes: &'frame Bytes, - ) -> Result - where - T: DeserializeValue<'frame>, - { - >::type_check(typ) - .map_err(|typecheck_err| DeserializationError(typecheck_err.0))?; - let mut frame_slice = FrameSlice::new(bytes); - let value = frame_slice.read_cql_bytes().map_err(|err| { - mk_deser_err::( - typ, - BuiltinDeserializationErrorKind::RawCqlBytesReadError(err), - ) - })?; - >::deserialize(typ, value) - } - - fn make_bytes(cell: &[u8]) -> Bytes { - let mut b = BytesMut::new(); - append_bytes(&mut b, cell); - b.freeze() - } - - fn serialize(typ: &ColumnType, value: &dyn SerializeValue) -> Bytes { - let mut bytes = Bytes::new(); - serialize_to_buf(typ, value, &mut bytes); - bytes - } - - fn serialize_to_buf(typ: &ColumnType, value: &dyn SerializeValue, buf: &mut Bytes) { - let mut v = Vec::new(); - let writer = CellWriter::new(&mut v); - value.serialize(typ, writer).unwrap(); - *buf = v.into(); - } - - fn append_bytes(b: &mut impl BufMut, cell: &[u8]) { - b.put_i32(cell.len() as i32); - b.put_slice(cell); - } - - fn make_null() -> Bytes { - let mut b = BytesMut::new(); - append_null(&mut b); - b.freeze() - } - - fn append_null(b: &mut impl BufMut) { - b.put_i32(-1); - } - - fn assert_ser_de_identity<'f, T: SerializeValue + DeserializeValue<'f> + PartialEq + Debug>( - typ: &'f ColumnType, - v: &'f T, - buf: &'f mut Bytes, // `buf` must be passed as a reference from outside, because otherwise - // we cannot specify the lifetime for DeserializeValue. - ) { - serialize_to_buf(typ, v, buf); - let deserialized = deserialize::(typ, buf).unwrap(); - assert_eq!(&deserialized, v); - } - - /* Errors checks */ - - #[track_caller] - pub(crate) fn get_typeck_err_inner<'a>( - err: &'a (dyn std::error::Error + 'static), - ) -> &'a BuiltinTypeCheckError { - match err.downcast_ref() { - Some(err) => err, - None => panic!("not a BuiltinTypeCheckError: {:?}", err), - } - } - - #[track_caller] - pub(crate) fn get_typeck_err(err: &DeserializationError) -> &BuiltinTypeCheckError { - get_typeck_err_inner(err.0.as_ref()) - } - - #[track_caller] - pub(crate) fn get_deser_err(err: &DeserializationError) -> &BuiltinDeserializationError { - match err.0.downcast_ref() { - Some(err) => err, - None => panic!("not a BuiltinDeserializationError: {:?}", err), - } - } - - macro_rules! assert_given_error { - ($get_err:ident, $bytes:expr, $DestT:ty, $cql_typ:expr, $kind:pat) => { - let cql_typ = $cql_typ.clone(); - let err = deserialize::<$DestT>(&cql_typ, $bytes).unwrap_err(); - let err = $get_err(&err); - assert_eq!(err.rust_name, std::any::type_name::<$DestT>()); - assert_eq!(err.cql_type, cql_typ); - assert_matches::assert_matches!(err.kind, $kind); - }; - } - - macro_rules! assert_type_check_error { - ($bytes:expr, $DestT:ty, $cql_typ:expr, $kind:pat) => { - assert_given_error!(get_typeck_err, $bytes, $DestT, $cql_typ, $kind); - }; - } - - macro_rules! assert_deser_error { - ($bytes:expr, $DestT:ty, $cql_typ:expr, $kind:pat) => { - assert_given_error!(get_deser_err, $bytes, $DestT, $cql_typ, $kind); - }; - } - - #[test] - fn test_native_errors() { - // Simple type mismatch - { - let v = 123_i32; - let bytes = serialize(&ColumnType::Int, &v); - - // Incompatible types render type check error. - assert_type_check_error!( - &bytes, - f64, - ColumnType::Int, - super::BuiltinTypeCheckErrorKind::MismatchedType { - expected: &[ColumnType::Double], - } - ); - - // ColumnType is said to be Double (8 bytes expected), but in reality the serialized form has 4 bytes only. - assert_deser_error!( - &bytes, - f64, - ColumnType::Double, - BuiltinDeserializationErrorKind::ByteLengthMismatch { - expected: 8, - got: 4, - } - ); - - // ColumnType is said to be Float, but in reality Int was serialized. - // As these types have the same size, though, and every binary number in [0, 2^32] is a valid - // value for both of them, this always succeeds. - { - deserialize::(&ColumnType::Float, &bytes).unwrap(); - } - } - - // str (and also Uuid) are interesting because they accept two types. - { - let v = "Ala ma kota"; - let bytes = serialize(&ColumnType::Ascii, &v); - - assert_type_check_error!( - &bytes, - &str, - ColumnType::Double, - BuiltinTypeCheckErrorKind::MismatchedType { - expected: &[ColumnType::Ascii, ColumnType::Text], - } - ); - - // ColumnType is said to be BigInt (8 bytes expected), but in reality the serialized form - // (the string) has 11 bytes. - assert_deser_error!( - &bytes, - i64, - ColumnType::BigInt, - BuiltinDeserializationErrorKind::ByteLengthMismatch { - expected: 8, - got: 11, // str len - } - ); - } - { - // -126 is not a valid ASCII nor UTF-8 byte. - let v = -126_i8; - let bytes = serialize(&ColumnType::TinyInt, &v); - - assert_deser_error!( - &bytes, - &str, - ColumnType::Ascii, - BuiltinDeserializationErrorKind::ExpectedAscii - ); - - assert_deser_error!( - &bytes, - &str, - ColumnType::Text, - BuiltinDeserializationErrorKind::InvalidUtf8(_) - ); - } - } - - #[test] - fn test_set_or_list_errors() { - // Not a set or list - { - assert_type_check_error!( - &Bytes::new(), - Vec, - ColumnType::Float, - BuiltinTypeCheckErrorKind::SetOrListError( - SetOrListTypeCheckErrorKind::NotSetOrList - ) - ); - - // Type check of Rust set against CQL list must fail, because it would be lossy. - assert_type_check_error!( - &Bytes::new(), - BTreeSet, - ColumnType::List(Box::new(ColumnType::Int)), - BuiltinTypeCheckErrorKind::SetOrListError(SetOrListTypeCheckErrorKind::NotSet) - ); - } - - // Got null - { - type RustTyp = Vec; - let ser_typ = ColumnType::List(Box::new(ColumnType::Int)); - - let err = RustTyp::deserialize(&ser_typ, None).unwrap_err(); - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, ser_typ); - assert_matches!(err.kind, BuiltinDeserializationErrorKind::ExpectedNonNull); - } - - // Bad element type - { - assert_type_check_error!( - &Bytes::new(), - Vec, - ColumnType::List(Box::new(ColumnType::Ascii)), - BuiltinTypeCheckErrorKind::SetOrListError( - SetOrListTypeCheckErrorKind::ElementTypeCheckFailed(_) - ) - ); - - let err = deserialize::>( - &ColumnType::List(Box::new(ColumnType::Varint)), - &Bytes::new(), - ) - .unwrap_err(); - let err = get_typeck_err(&err); - assert_eq!(err.rust_name, std::any::type_name::>()); - assert_eq!(err.cql_type, ColumnType::List(Box::new(ColumnType::Varint)),); - let BuiltinTypeCheckErrorKind::SetOrListError( - SetOrListTypeCheckErrorKind::ElementTypeCheckFailed(ref err), - ) = err.kind - else { - panic!("unexpected error kind: {}", err.kind) - }; - let err = get_typeck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, ColumnType::Varint); - assert_matches!( - err.kind, - BuiltinTypeCheckErrorKind::MismatchedType { - expected: &[ColumnType::BigInt, ColumnType::Counter] - } - ); - } - - { - let ser_typ = ColumnType::List(Box::new(ColumnType::Int)); - let v = vec![123_i32]; - let bytes = serialize(&ser_typ, &v); - - { - let err = deserialize::>( - &ColumnType::List(Box::new(ColumnType::BigInt)), - &bytes, - ) - .unwrap_err(); - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::>()); - assert_eq!(err.cql_type, ColumnType::List(Box::new(ColumnType::BigInt)),); - let BuiltinDeserializationErrorKind::SetOrListError( - SetOrListDeserializationErrorKind::ElementDeserializationFailed(err), - ) = &err.kind - else { - panic!("unexpected error kind: {}", err.kind) - }; - let err = get_deser_err(err); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, ColumnType::BigInt); - assert_matches!( - err.kind, - BuiltinDeserializationErrorKind::ByteLengthMismatch { - expected: 8, - got: 4 - } - ); - } - } - } - - #[test] - fn test_map_errors() { - // Not a map - { - let ser_typ = ColumnType::Float; - let v = 2.12_f32; - let bytes = serialize(&ser_typ, &v); - - assert_type_check_error!( - &bytes, - HashMap, - ser_typ, - BuiltinTypeCheckErrorKind::MapError( - MapTypeCheckErrorKind::NotMap, - ) - ); - } - - // Got null - { - type RustTyp = HashMap; - let ser_typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Boolean)); - - let err = RustTyp::deserialize(&ser_typ, None).unwrap_err(); - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, ser_typ); - assert_matches!(err.kind, BuiltinDeserializationErrorKind::ExpectedNonNull); - } - - // Key type mismatch - { - let err = deserialize::>( - &ColumnType::Map(Box::new(ColumnType::Varint), Box::new(ColumnType::Boolean)), - &Bytes::new(), - ) - .unwrap_err(); - let err = get_typeck_err(&err); - assert_eq!(err.rust_name, std::any::type_name::>()); - assert_eq!( - err.cql_type, - ColumnType::Map(Box::new(ColumnType::Varint), Box::new(ColumnType::Boolean)) - ); - let BuiltinTypeCheckErrorKind::MapError(MapTypeCheckErrorKind::KeyTypeCheckFailed( - ref err, - )) = err.kind - else { - panic!("unexpected error kind: {}", err.kind) - }; - let err = get_typeck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, ColumnType::Varint); - assert_matches!( - err.kind, - BuiltinTypeCheckErrorKind::MismatchedType { - expected: &[ColumnType::BigInt, ColumnType::Counter] - } - ); - } - - // Value type mismatch - { - let err = deserialize::>( - &ColumnType::Map(Box::new(ColumnType::BigInt), Box::new(ColumnType::Boolean)), - &Bytes::new(), - ) - .unwrap_err(); - let err = get_typeck_err(&err); - assert_eq!(err.rust_name, std::any::type_name::>()); - assert_eq!( - err.cql_type, - ColumnType::Map(Box::new(ColumnType::BigInt), Box::new(ColumnType::Boolean)) - ); - let BuiltinTypeCheckErrorKind::MapError(MapTypeCheckErrorKind::ValueTypeCheckFailed( - ref err, - )) = err.kind - else { - panic!("unexpected error kind: {}", err.kind) - }; - let err = get_typeck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::<&str>()); - assert_eq!(err.cql_type, ColumnType::Boolean); - assert_matches!( - err.kind, - BuiltinTypeCheckErrorKind::MismatchedType { - expected: &[ColumnType::Ascii, ColumnType::Text] - } - ); - } - - // Key length mismatch - { - let ser_typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Boolean)); - let v = HashMap::from([(42, false), (2137, true)]); - let bytes = serialize(&ser_typ, &v as &dyn SerializeValue); - - let err = deserialize::>( - &ColumnType::Map(Box::new(ColumnType::BigInt), Box::new(ColumnType::Boolean)), - &bytes, - ) - .unwrap_err(); - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::>()); - assert_eq!( - err.cql_type, - ColumnType::Map(Box::new(ColumnType::BigInt), Box::new(ColumnType::Boolean)) - ); - let BuiltinDeserializationErrorKind::MapError( - MapDeserializationErrorKind::KeyDeserializationFailed(err), - ) = &err.kind - else { - panic!("unexpected error kind: {}", err.kind) - }; - let err = get_deser_err(err); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, ColumnType::BigInt); - assert_matches!( - err.kind, - BuiltinDeserializationErrorKind::ByteLengthMismatch { - expected: 8, - got: 4 - } - ); - } - - // Value length mismatch - { - let ser_typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Boolean)); - let v = HashMap::from([(42, false), (2137, true)]); - let bytes = serialize(&ser_typ, &v as &dyn SerializeValue); - - let err = deserialize::>( - &ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::SmallInt)), - &bytes, - ) - .unwrap_err(); - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::>()); - assert_eq!( - err.cql_type, - ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::SmallInt)) - ); - let BuiltinDeserializationErrorKind::MapError( - MapDeserializationErrorKind::ValueDeserializationFailed(err), - ) = &err.kind - else { - panic!("unexpected error kind: {}", err.kind) - }; - let err = get_deser_err(err); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, ColumnType::SmallInt); - assert_matches!( - err.kind, - BuiltinDeserializationErrorKind::ByteLengthMismatch { - expected: 2, - got: 1 - } - ); - } - } - - #[test] - fn test_tuple_errors() { - // Not a tuple - { - assert_type_check_error!( - &Bytes::new(), - (i64,), - ColumnType::BigInt, - BuiltinTypeCheckErrorKind::TupleError(TupleTypeCheckErrorKind::NotTuple) - ); - } - // Wrong element count - { - assert_type_check_error!( - &Bytes::new(), - (i64,), - ColumnType::Tuple(vec![]), - BuiltinTypeCheckErrorKind::TupleError(TupleTypeCheckErrorKind::WrongElementCount { - rust_type_el_count: 1, - cql_type_el_count: 0, - }) - ); - - assert_type_check_error!( - &Bytes::new(), - (f32,), - ColumnType::Tuple(vec![ColumnType::Float, ColumnType::Float]), - BuiltinTypeCheckErrorKind::TupleError(TupleTypeCheckErrorKind::WrongElementCount { - rust_type_el_count: 1, - cql_type_el_count: 2, - }) - ); - } - - // Bad field type - { - { - let err = deserialize::<(i64,)>( - &ColumnType::Tuple(vec![ColumnType::SmallInt]), - &Bytes::new(), - ) - .unwrap_err(); - let err = get_typeck_err(&err); - assert_eq!(err.rust_name, std::any::type_name::<(i64,)>()); - assert_eq!(err.cql_type, ColumnType::Tuple(vec![ColumnType::SmallInt])); - let BuiltinTypeCheckErrorKind::TupleError( - TupleTypeCheckErrorKind::FieldTypeCheckFailed { ref err, position }, - ) = err.kind - else { - panic!("unexpected error kind: {}", err.kind) - }; - assert_eq!(position, 0); - let err = get_typeck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, ColumnType::SmallInt); - assert_matches!( - err.kind, - BuiltinTypeCheckErrorKind::MismatchedType { - expected: &[ColumnType::BigInt, ColumnType::Counter] - } - ); - } - } - - { - let ser_typ = ColumnType::Tuple(vec![ColumnType::Int, ColumnType::Float]); - let v = (123_i32, 123.123_f32); - let bytes = serialize(&ser_typ, &v); - - { - let err = deserialize::<(i32, f64)>( - &ColumnType::Tuple(vec![ColumnType::Int, ColumnType::Double]), - &bytes, - ) - .unwrap_err(); - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::<(i32, f64)>()); - assert_eq!( - err.cql_type, - ColumnType::Tuple(vec![ColumnType::Int, ColumnType::Double]) - ); - let BuiltinDeserializationErrorKind::TupleError( - TupleDeserializationErrorKind::FieldDeserializationFailed { - ref err, - position: index, - }, - ) = err.kind - else { - panic!("unexpected error kind: {}", err.kind) - }; - assert_eq!(index, 1); - let err = get_deser_err(err); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, ColumnType::Double); - assert_matches!( - err.kind, - BuiltinDeserializationErrorKind::ByteLengthMismatch { - expected: 8, - got: 4 - } - ); - } - } - } - - #[test] - fn test_null_errors() { - let ser_typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Boolean)); - let v = HashMap::from([(42, false), (2137, true)]); - let bytes = serialize(&ser_typ, &v as &dyn SerializeValue); - - deserialize::>(&ser_typ, &bytes).unwrap_err(); - } - - #[test] - fn test_udt_errors() { - // Loose ordering - { - #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] - #[scylla(crate = "crate", forbid_excess_udt_fields)] - struct Udt<'a> { - a: &'a str, - #[scylla(skip)] - x: String, - #[scylla(allow_missing)] - b: Option, - #[scylla(default_when_null)] - c: bool, - } - - // Type check errors - { - // Not UDT - { - let typ = - ColumnType::Map(Box::new(ColumnType::Ascii), Box::new(ColumnType::Blob)); - let err = Udt::type_check(&typ).unwrap_err(); - let err = get_typeck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, typ); - let BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::NotUdt) = - err.kind - else { - panic!("unexpected error kind: {:?}", err.kind) - }; - } - - // UDT missing fields - { - let typ = udt_def_with_fields([("c", ColumnType::Boolean)]); - let err = Udt::type_check(&typ).unwrap_err(); - let err = get_typeck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, typ); - let BuiltinTypeCheckErrorKind::UdtError( - UdtTypeCheckErrorKind::ValuesMissingForUdtFields { - field_names: ref missing_fields, - }, - ) = err.kind - else { - panic!("unexpected error kind: {:?}", err.kind) - }; - assert_eq!(missing_fields.as_slice(), &["a"]); - } - - // excess fields in UDT - { - let typ = udt_def_with_fields([ - ("d", ColumnType::Boolean), - ("a", ColumnType::Text), - ("b", ColumnType::Int), - ]); - let err = Udt::type_check(&typ).unwrap_err(); - let err = get_typeck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, typ); - let BuiltinTypeCheckErrorKind::UdtError( - UdtTypeCheckErrorKind::ExcessFieldInUdt { ref db_field_name }, - ) = err.kind - else { - panic!("unexpected error kind: {:?}", err.kind) - }; - assert_eq!(db_field_name.as_str(), "d"); - } - - // missing UDT field - { - let typ = - udt_def_with_fields([("b", ColumnType::Int), ("a", ColumnType::Text)]); - let err = Udt::type_check(&typ).unwrap_err(); - let err = get_typeck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, typ); - let BuiltinTypeCheckErrorKind::UdtError( - UdtTypeCheckErrorKind::ValuesMissingForUdtFields { ref field_names }, - ) = err.kind - else { - panic!("unexpected error kind: {:?}", err.kind) - }; - assert_eq!(field_names, &["c"]); - } - - // UDT fields incompatible types - field type check failed - { - let typ = - udt_def_with_fields([("a", ColumnType::Blob), ("b", ColumnType::Int)]); - let err = Udt::type_check(&typ).unwrap_err(); - let err = get_typeck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, typ); - let BuiltinTypeCheckErrorKind::UdtError( - UdtTypeCheckErrorKind::FieldTypeCheckFailed { - ref field_name, - ref err, - }, - ) = err.kind - else { - panic!("unexpected error kind: {:?}", err.kind) - }; - assert_eq!(field_name.as_str(), "a"); - let err = get_typeck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::<&str>()); - assert_eq!(err.cql_type, ColumnType::Blob); - assert_matches!( - err.kind, - BuiltinTypeCheckErrorKind::MismatchedType { - expected: &[ColumnType::Ascii, ColumnType::Text] - } - ); - } - } - - // Deserialization errors - { - // Got null - { - let typ = udt_def_with_fields([ - ("c", ColumnType::Boolean), - ("a", ColumnType::Blob), - ("b", ColumnType::Int), - ]); - - let err = Udt::deserialize(&typ, None).unwrap_err(); - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, typ); - assert_matches!(err.kind, BuiltinDeserializationErrorKind::ExpectedNonNull); - } - - // UDT field deserialization failed - { - let typ = - udt_def_with_fields([("a", ColumnType::Ascii), ("c", ColumnType::Boolean)]); - - let udt_bytes = UdtSerializer::new() - .field("alamakota".as_bytes()) - .field(&42_i16.to_be_bytes()) - .finalize(); - - let err = deserialize::(&typ, &udt_bytes).unwrap_err(); - - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, typ); - let BuiltinDeserializationErrorKind::UdtError( - UdtDeserializationErrorKind::FieldDeserializationFailed { - ref field_name, - ref err, - }, - ) = err.kind - else { - panic!("unexpected error kind: {:?}", err.kind) - }; - assert_eq!(field_name.as_str(), "c"); - let err = get_deser_err(err); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, ColumnType::Boolean); - assert_matches!( - err.kind, - BuiltinDeserializationErrorKind::ByteLengthMismatch { - expected: 1, - got: 2, - } - ); - } - } - } - - // Strict ordering - { - #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] - #[scylla(crate = "crate", enforce_order, forbid_excess_udt_fields)] - struct Udt<'a> { - a: &'a str, - #[scylla(skip)] - x: String, - b: Option, - #[scylla(allow_missing)] - c: bool, - } - - // Type check errors - { - // Not UDT - { - let typ = - ColumnType::Map(Box::new(ColumnType::Ascii), Box::new(ColumnType::Blob)); - let err = Udt::type_check(&typ).unwrap_err(); - let err = get_typeck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, typ); - let BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::NotUdt) = - err.kind - else { - panic!("unexpected error kind: {:?}", err.kind) - }; - } - - // UDT too few fields - { - let typ = udt_def_with_fields([("a", ColumnType::Text)]); - let err = Udt::type_check(&typ).unwrap_err(); - let err = get_typeck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, typ); - let BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::TooFewFields { - ref required_fields, - ref present_fields, - }) = err.kind - else { - panic!("unexpected error kind: {:?}", err.kind) - }; - assert_eq!(required_fields.as_slice(), &["a", "b"]); - assert_eq!(present_fields.as_slice(), &["a".to_string()]); - } - - // excess fields in UDT - { - let typ = udt_def_with_fields([ - ("a", ColumnType::Text), - ("b", ColumnType::Int), - ("d", ColumnType::Boolean), - ]); - let err = Udt::type_check(&typ).unwrap_err(); - let err = get_typeck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, typ); - let BuiltinTypeCheckErrorKind::UdtError( - UdtTypeCheckErrorKind::ExcessFieldInUdt { ref db_field_name }, - ) = err.kind - else { - panic!("unexpected error kind: {:?}", err.kind) - }; - assert_eq!(db_field_name.as_str(), "d"); - } - - // UDT fields switched - field name mismatch - { - let typ = - udt_def_with_fields([("b", ColumnType::Int), ("a", ColumnType::Text)]); - let err = Udt::type_check(&typ).unwrap_err(); - let err = get_typeck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, typ); - let BuiltinTypeCheckErrorKind::UdtError( - UdtTypeCheckErrorKind::FieldNameMismatch { - position, - ref rust_field_name, - ref db_field_name, - }, - ) = err.kind - else { - panic!("unexpected error kind: {:?}", err.kind) - }; - assert_eq!(position, 0); - assert_eq!(rust_field_name.as_str(), "a".to_owned()); - assert_eq!(db_field_name.as_str(), "b".to_owned()); - } - - // UDT fields incompatible types - field type check failed - { - let typ = - udt_def_with_fields([("a", ColumnType::Blob), ("b", ColumnType::Int)]); - let err = Udt::type_check(&typ).unwrap_err(); - let err = get_typeck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, typ); - let BuiltinTypeCheckErrorKind::UdtError( - UdtTypeCheckErrorKind::FieldTypeCheckFailed { - ref field_name, - ref err, - }, - ) = err.kind - else { - panic!("unexpected error kind: {:?}", err.kind) - }; - assert_eq!(field_name.as_str(), "a"); - let err = get_typeck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::<&str>()); - assert_eq!(err.cql_type, ColumnType::Blob); - assert_matches!( - err.kind, - BuiltinTypeCheckErrorKind::MismatchedType { - expected: &[ColumnType::Ascii, ColumnType::Text] - } - ); - } - } - - // Deserialization errors - { - // Got null - { - let typ = udt_def_with_fields([ - ("a", ColumnType::Text), - ("b", ColumnType::Int), - ("c", ColumnType::Boolean), - ]); - - let err = Udt::deserialize(&typ, None).unwrap_err(); - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, typ); - assert_matches!(err.kind, BuiltinDeserializationErrorKind::ExpectedNonNull); - } - - // Bad field format - { - let typ = udt_def_with_fields([ - ("a", ColumnType::Text), - ("b", ColumnType::Int), - ("c", ColumnType::Boolean), - ]); - - let udt_bytes = UdtSerializer::new() - .field(b"alamakota") - .field(&42_i64.to_be_bytes()) - .field(&[true as u8]) - .finalize(); - - let udt_bytes_too_short = udt_bytes.slice(..udt_bytes.len() - 1); - assert!(udt_bytes.len() > udt_bytes_too_short.len()); - - let err = deserialize::(&typ, &udt_bytes_too_short).unwrap_err(); - - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, typ); - let BuiltinDeserializationErrorKind::RawCqlBytesReadError(_) = err.kind else { - panic!("unexpected error kind: {:?}", err.kind) - }; - } - - // UDT field deserialization failed - { - let typ = udt_def_with_fields([ - ("a", ColumnType::Text), - ("b", ColumnType::Int), - ("c", ColumnType::Boolean), - ]); - - let udt_bytes = UdtSerializer::new() - .field(b"alamakota") - .field(&42_i64.to_be_bytes()) - .field(&[true as u8]) - .finalize(); - - let err = deserialize::(&typ, &udt_bytes).unwrap_err(); - - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, typ); - let BuiltinDeserializationErrorKind::UdtError( - UdtDeserializationErrorKind::FieldDeserializationFailed { - ref field_name, - ref err, - }, - ) = err.kind - else { - panic!("unexpected error kind: {:?}", err.kind) - }; - assert_eq!(field_name.as_str(), "b"); - let err = get_deser_err(err); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, ColumnType::Int); - assert_matches!( - err.kind, - BuiltinDeserializationErrorKind::ByteLengthMismatch { - expected: 4, - got: 8, - } - ); - } - } - } - } -} +#[path = "value_tests.rs"] +pub(super) mod tests; /// ```compile_fail /// diff --git a/scylla-cql/src/types/deserialize/value_tests.rs b/scylla-cql/src/types/deserialize/value_tests.rs new file mode 100644 index 0000000000..9375ce47f6 --- /dev/null +++ b/scylla-cql/src/types/deserialize/value_tests.rs @@ -0,0 +1,1946 @@ +use assert_matches::assert_matches; +use bytes::{BufMut, Bytes, BytesMut}; +use uuid::Uuid; + +use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; +use std::fmt::Debug; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; + +use crate::frame::response::result::{ColumnType, CqlValue}; +use crate::frame::value::{ + Counter, CqlDate, CqlDecimal, CqlDuration, CqlTime, CqlTimestamp, CqlTimeuuid, CqlVarint, +}; +use crate::types::deserialize::value::{TupleDeserializationErrorKind, TupleTypeCheckErrorKind}; +use crate::types::deserialize::{DeserializationError, FrameSlice, TypeCheckError}; +use crate::types::serialize::value::SerializeValue; +use crate::types::serialize::CellWriter; + +use super::{ + mk_deser_err, BuiltinDeserializationError, BuiltinDeserializationErrorKind, + BuiltinTypeCheckError, BuiltinTypeCheckErrorKind, DeserializeValue, ListlikeIterator, + MapDeserializationErrorKind, MapIterator, MapTypeCheckErrorKind, MaybeEmpty, + SetOrListDeserializationErrorKind, SetOrListTypeCheckErrorKind, UdtDeserializationErrorKind, + UdtTypeCheckErrorKind, +}; + +#[test] +fn test_deserialize_bytes() { + const ORIGINAL_BYTES: &[u8] = &[1, 5, 2, 4, 3]; + + let bytes = make_bytes(ORIGINAL_BYTES); + + let decoded_slice = deserialize::<&[u8]>(&ColumnType::Blob, &bytes).unwrap(); + let decoded_vec = deserialize::>(&ColumnType::Blob, &bytes).unwrap(); + let decoded_bytes = deserialize::(&ColumnType::Blob, &bytes).unwrap(); + + assert_eq!(decoded_slice, ORIGINAL_BYTES); + assert_eq!(decoded_vec, ORIGINAL_BYTES); + assert_eq!(decoded_bytes, ORIGINAL_BYTES); + + // ser/de identity + + // Nonempty blob + assert_ser_de_identity(&ColumnType::Blob, &ORIGINAL_BYTES, &mut Bytes::new()); + + // Empty blob + assert_ser_de_identity(&ColumnType::Blob, &(&[] as &[u8]), &mut Bytes::new()); +} + +#[test] +fn test_deserialize_ascii() { + const ASCII_TEXT: &str = "The quick brown fox jumps over the lazy dog"; + + let ascii = make_bytes(ASCII_TEXT.as_bytes()); + + for typ in [ColumnType::Ascii, ColumnType::Text].iter() { + let decoded_str = deserialize::<&str>(typ, &ascii).unwrap(); + let decoded_string = deserialize::(typ, &ascii).unwrap(); + + assert_eq!(decoded_str, ASCII_TEXT); + assert_eq!(decoded_string, ASCII_TEXT); + + // ser/de identity + + // Empty string + assert_ser_de_identity(typ, &"", &mut Bytes::new()); + assert_ser_de_identity(typ, &"".to_owned(), &mut Bytes::new()); + + // Nonempty string + assert_ser_de_identity(typ, &ASCII_TEXT, &mut Bytes::new()); + assert_ser_de_identity(typ, &ASCII_TEXT.to_owned(), &mut Bytes::new()); + } +} + +#[test] +fn test_deserialize_text() { + const UNICODE_TEXT: &str = "Zażółć gęślą jaźń"; + + let unicode = make_bytes(UNICODE_TEXT.as_bytes()); + + // Should fail because it's not an ASCII string + deserialize::<&str>(&ColumnType::Ascii, &unicode).unwrap_err(); + deserialize::(&ColumnType::Ascii, &unicode).unwrap_err(); + + let decoded_text_str = deserialize::<&str>(&ColumnType::Text, &unicode).unwrap(); + let decoded_text_string = deserialize::(&ColumnType::Text, &unicode).unwrap(); + assert_eq!(decoded_text_str, UNICODE_TEXT); + assert_eq!(decoded_text_string, UNICODE_TEXT); + + // ser/de identity + + assert_ser_de_identity(&ColumnType::Text, &UNICODE_TEXT, &mut Bytes::new()); + assert_ser_de_identity( + &ColumnType::Text, + &UNICODE_TEXT.to_owned(), + &mut Bytes::new(), + ); +} + +#[test] +fn test_integral() { + let tinyint = make_bytes(&[0x01]); + let decoded_tinyint = deserialize::(&ColumnType::TinyInt, &tinyint).unwrap(); + assert_eq!(decoded_tinyint, 0x01); + + let smallint = make_bytes(&[0x01, 0x02]); + let decoded_smallint = deserialize::(&ColumnType::SmallInt, &smallint).unwrap(); + assert_eq!(decoded_smallint, 0x0102); + + let int = make_bytes(&[0x01, 0x02, 0x03, 0x04]); + let decoded_int = deserialize::(&ColumnType::Int, &int).unwrap(); + assert_eq!(decoded_int, 0x01020304); + + let bigint = make_bytes(&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]); + let decoded_bigint = deserialize::(&ColumnType::BigInt, &bigint).unwrap(); + assert_eq!(decoded_bigint, 0x0102030405060708); + + // ser/de identity + assert_ser_de_identity(&ColumnType::TinyInt, &42_i8, &mut Bytes::new()); + assert_ser_de_identity(&ColumnType::SmallInt, &2137_i16, &mut Bytes::new()); + assert_ser_de_identity(&ColumnType::Int, &21372137_i32, &mut Bytes::new()); + assert_ser_de_identity(&ColumnType::BigInt, &0_i64, &mut Bytes::new()); +} + +#[test] +fn test_bool() { + for boolean in [true, false] { + let boolean_bytes = make_bytes(&[boolean as u8]); + let decoded_bool = deserialize::(&ColumnType::Boolean, &boolean_bytes).unwrap(); + assert_eq!(decoded_bool, boolean); + + // ser/de identity + assert_ser_de_identity(&ColumnType::Boolean, &boolean, &mut Bytes::new()); + } +} + +#[test] +fn test_floating_point() { + let float = make_bytes(&[63, 0, 0, 0]); + let decoded_float = deserialize::(&ColumnType::Float, &float).unwrap(); + assert_eq!(decoded_float, 0.5); + + let double = make_bytes(&[64, 0, 0, 0, 0, 0, 0, 0]); + let decoded_double = deserialize::(&ColumnType::Double, &double).unwrap(); + assert_eq!(decoded_double, 2.0); + + // ser/de identity + assert_ser_de_identity(&ColumnType::Float, &21.37_f32, &mut Bytes::new()); + assert_ser_de_identity(&ColumnType::Double, &2137.2137_f64, &mut Bytes::new()); +} + +#[test] +fn test_varlen_numbers() { + // varint + assert_ser_de_identity( + &ColumnType::Varint, + &CqlVarint::from_signed_bytes_be_slice(b"Ala ma kota"), + &mut Bytes::new(), + ); + + #[cfg(feature = "num-bigint-03")] + assert_ser_de_identity( + &ColumnType::Varint, + &num_bigint_03::BigInt::from_signed_bytes_be(b"Kot ma Ale"), + &mut Bytes::new(), + ); + + #[cfg(feature = "num-bigint-04")] + assert_ser_de_identity( + &ColumnType::Varint, + &num_bigint_04::BigInt::from_signed_bytes_be(b"Kot ma Ale"), + &mut Bytes::new(), + ); + + // decimal + assert_ser_de_identity( + &ColumnType::Decimal, + &CqlDecimal::from_signed_be_bytes_slice_and_exponent(b"Ala ma kota", 42), + &mut Bytes::new(), + ); + + #[cfg(feature = "bigdecimal-04")] + assert_ser_de_identity( + &ColumnType::Decimal, + &bigdecimal_04::BigDecimal::new( + bigdecimal_04::num_bigint::BigInt::from_signed_bytes_be(b"Ala ma kota"), + 42, + ), + &mut Bytes::new(), + ); +} + +#[test] +fn test_date_time_types() { + // duration + assert_ser_de_identity( + &ColumnType::Duration, + &CqlDuration { + months: 21, + days: 37, + nanoseconds: 42, + }, + &mut Bytes::new(), + ); + + // date + assert_ser_de_identity(&ColumnType::Date, &CqlDate(0xbeaf), &mut Bytes::new()); + + #[cfg(feature = "chrono-04")] + assert_ser_de_identity( + &ColumnType::Date, + &chrono_04::NaiveDate::from_yo_opt(1999, 99).unwrap(), + &mut Bytes::new(), + ); + + #[cfg(feature = "time-03")] + assert_ser_de_identity( + &ColumnType::Date, + &time_03::Date::from_ordinal_date(1999, 99).unwrap(), + &mut Bytes::new(), + ); + + // time + assert_ser_de_identity(&ColumnType::Time, &CqlTime(0xdeed), &mut Bytes::new()); + + #[cfg(feature = "chrono-04")] + assert_ser_de_identity( + &ColumnType::Time, + &chrono_04::NaiveTime::from_hms_micro_opt(21, 37, 21, 37).unwrap(), + &mut Bytes::new(), + ); + + #[cfg(feature = "time-03")] + assert_ser_de_identity( + &ColumnType::Time, + &time_03::Time::from_hms_micro(21, 37, 21, 37).unwrap(), + &mut Bytes::new(), + ); + + // timestamp + assert_ser_de_identity( + &ColumnType::Timestamp, + &CqlTimestamp(0xceed), + &mut Bytes::new(), + ); + + #[cfg(feature = "chrono-04")] + assert_ser_de_identity( + &ColumnType::Timestamp, + &chrono_04::DateTime::::from_timestamp_millis(0xdead_cafe_deaf).unwrap(), + &mut Bytes::new(), + ); + + #[cfg(feature = "time-03")] + assert_ser_de_identity( + &ColumnType::Timestamp, + &time_03::OffsetDateTime::from_unix_timestamp(0xdead_cafe).unwrap(), + &mut Bytes::new(), + ); +} + +#[test] +fn test_inet() { + assert_ser_de_identity( + &ColumnType::Inet, + &IpAddr::V4(Ipv4Addr::BROADCAST), + &mut Bytes::new(), + ); + + assert_ser_de_identity( + &ColumnType::Inet, + &IpAddr::V6(Ipv6Addr::LOCALHOST), + &mut Bytes::new(), + ); +} + +#[test] +fn test_uuid() { + assert_ser_de_identity( + &ColumnType::Uuid, + &Uuid::from_u128(0xdead_cafe_deaf_feed_beaf_bead), + &mut Bytes::new(), + ); + + assert_ser_de_identity( + &ColumnType::Timeuuid, + &CqlTimeuuid::from_u128(0xdead_cafe_deaf_feed_beaf_bead), + &mut Bytes::new(), + ); +} + +#[test] +fn test_null_and_empty() { + // non-nullable emptiable deserialization, non-empty value + let int = make_bytes(&[21, 37, 0, 0]); + let decoded_int = deserialize::>(&ColumnType::Int, &int).unwrap(); + assert_eq!(decoded_int, MaybeEmpty::Value((21 << 24) + (37 << 16))); + + // non-nullable emptiable deserialization, empty value + let int = make_bytes(&[]); + let decoded_int = deserialize::>(&ColumnType::Int, &int).unwrap(); + assert_eq!(decoded_int, MaybeEmpty::Empty); + + // nullable non-emptiable deserialization, non-null value + let int = make_bytes(&[21, 37, 0, 0]); + let decoded_int = deserialize::>(&ColumnType::Int, &int).unwrap(); + assert_eq!(decoded_int, Some((21 << 24) + (37 << 16))); + + // nullable non-emptiable deserialization, null value + let int = make_null(); + let decoded_int = deserialize::>(&ColumnType::Int, &int).unwrap(); + assert_eq!(decoded_int, None); + + // nullable emptiable deserialization, non-null non-empty value + let int = make_bytes(&[]); + let decoded_int = deserialize::>>(&ColumnType::Int, &int).unwrap(); + assert_eq!(decoded_int, Some(MaybeEmpty::Empty)); + + // ser/de identity + assert_ser_de_identity(&ColumnType::Int, &Some(12321_i32), &mut Bytes::new()); + assert_ser_de_identity(&ColumnType::Double, &None::, &mut Bytes::new()); + assert_ser_de_identity( + &ColumnType::Set(Box::new(ColumnType::Ascii)), + &None::>, + &mut Bytes::new(), + ); +} + +#[test] +fn test_maybe_empty() { + let empty = make_bytes(&[]); + let decoded_empty = deserialize::>(&ColumnType::TinyInt, &empty).unwrap(); + assert_eq!(decoded_empty, MaybeEmpty::Empty); + + let non_empty = make_bytes(&[0x01]); + let decoded_non_empty = + deserialize::>(&ColumnType::TinyInt, &non_empty).unwrap(); + assert_eq!(decoded_non_empty, MaybeEmpty::Value(0x01)); +} + +#[test] +fn test_cql_value() { + assert_ser_de_identity( + &ColumnType::Counter, + &CqlValue::Counter(Counter(765)), + &mut Bytes::new(), + ); + + assert_ser_de_identity( + &ColumnType::Timestamp, + &CqlValue::Timestamp(CqlTimestamp(2136)), + &mut Bytes::new(), + ); + + assert_ser_de_identity(&ColumnType::Boolean, &CqlValue::Empty, &mut Bytes::new()); + + assert_ser_de_identity( + &ColumnType::Text, + &CqlValue::Text("kremówki".to_owned()), + &mut Bytes::new(), + ); + assert_ser_de_identity( + &ColumnType::Ascii, + &CqlValue::Ascii("kremowy".to_owned()), + &mut Bytes::new(), + ); + + assert_ser_de_identity( + &ColumnType::Set(Box::new(ColumnType::Text)), + &CqlValue::Set(vec![CqlValue::Text("Ala ma kota".to_owned())]), + &mut Bytes::new(), + ); +} + +#[test] +fn test_list_and_set() { + let mut collection_contents = BytesMut::new(); + collection_contents.put_i32(3); + append_bytes(&mut collection_contents, "quick".as_bytes()); + append_bytes(&mut collection_contents, "brown".as_bytes()); + append_bytes(&mut collection_contents, "fox".as_bytes()); + + let collection = make_bytes(&collection_contents); + + let list_typ = ColumnType::List(Box::new(ColumnType::Ascii)); + let set_typ = ColumnType::Set(Box::new(ColumnType::Ascii)); + + // iterator + let mut iter = deserialize::>(&list_typ, &collection).unwrap(); + assert_eq!(iter.next().transpose().unwrap(), Some("quick")); + assert_eq!(iter.next().transpose().unwrap(), Some("brown")); + assert_eq!(iter.next().transpose().unwrap(), Some("fox")); + assert_eq!(iter.next().transpose().unwrap(), None); + + let expected_vec_str = vec!["quick", "brown", "fox"]; + let expected_vec_string = vec!["quick".to_string(), "brown".to_string(), "fox".to_string()]; + + // list + let decoded_vec_str = deserialize::>(&list_typ, &collection).unwrap(); + let decoded_vec_string = deserialize::>(&list_typ, &collection).unwrap(); + assert_eq!(decoded_vec_str, expected_vec_str); + assert_eq!(decoded_vec_string, expected_vec_string); + + // hash set + let decoded_hash_str = deserialize::>(&set_typ, &collection).unwrap(); + let decoded_hash_string = deserialize::>(&set_typ, &collection).unwrap(); + assert_eq!( + decoded_hash_str, + expected_vec_str.clone().into_iter().collect(), + ); + assert_eq!( + decoded_hash_string, + expected_vec_string.clone().into_iter().collect(), + ); + + // btree set + let decoded_btree_str = deserialize::>(&set_typ, &collection).unwrap(); + let decoded_btree_string = deserialize::>(&set_typ, &collection).unwrap(); + assert_eq!( + decoded_btree_str, + expected_vec_str.clone().into_iter().collect(), + ); + assert_eq!( + decoded_btree_string, + expected_vec_string.into_iter().collect(), + ); + + // ser/de identity + assert_ser_de_identity(&list_typ, &vec!["qwik"], &mut Bytes::new()); + assert_ser_de_identity(&set_typ, &vec!["qwik"], &mut Bytes::new()); + assert_ser_de_identity( + &set_typ, + &HashSet::<&str, std::collections::hash_map::RandomState>::from_iter(["qwik"]), + &mut Bytes::new(), + ); + assert_ser_de_identity( + &set_typ, + &BTreeSet::<&str>::from_iter(["qwik"]), + &mut Bytes::new(), + ); +} + +#[test] +fn test_map() { + let mut collection_contents = BytesMut::new(); + collection_contents.put_i32(3); + append_bytes(&mut collection_contents, &1i32.to_be_bytes()); + append_bytes(&mut collection_contents, "quick".as_bytes()); + append_bytes(&mut collection_contents, &2i32.to_be_bytes()); + append_bytes(&mut collection_contents, "brown".as_bytes()); + append_bytes(&mut collection_contents, &3i32.to_be_bytes()); + append_bytes(&mut collection_contents, "fox".as_bytes()); + + let collection = make_bytes(&collection_contents); + + let typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Ascii)); + + // iterator + let mut iter = deserialize::>(&typ, &collection).unwrap(); + assert_eq!(iter.next().transpose().unwrap(), Some((1, "quick"))); + assert_eq!(iter.next().transpose().unwrap(), Some((2, "brown"))); + assert_eq!(iter.next().transpose().unwrap(), Some((3, "fox"))); + assert_eq!(iter.next().transpose().unwrap(), None); + + let expected_str = vec![(1, "quick"), (2, "brown"), (3, "fox")]; + let expected_string = vec![ + (1, "quick".to_string()), + (2, "brown".to_string()), + (3, "fox".to_string()), + ]; + + // hash set + let decoded_hash_str = deserialize::>(&typ, &collection).unwrap(); + let decoded_hash_string = deserialize::>(&typ, &collection).unwrap(); + assert_eq!(decoded_hash_str, expected_str.clone().into_iter().collect()); + assert_eq!( + decoded_hash_string, + expected_string.clone().into_iter().collect(), + ); + + // btree set + let decoded_btree_str = deserialize::>(&typ, &collection).unwrap(); + let decoded_btree_string = deserialize::>(&typ, &collection).unwrap(); + assert_eq!( + decoded_btree_str, + expected_str.clone().into_iter().collect(), + ); + assert_eq!(decoded_btree_string, expected_string.into_iter().collect()); + + // ser/de identity + assert_ser_de_identity( + &typ, + &HashMap::::from_iter([(-42, "qwik")]), + &mut Bytes::new(), + ); + assert_ser_de_identity( + &typ, + &BTreeMap::::from_iter([(-42, "qwik")]), + &mut Bytes::new(), + ); +} + +#[test] +fn test_tuples() { + let mut tuple_contents = BytesMut::new(); + append_bytes(&mut tuple_contents, &42i32.to_be_bytes()); + append_bytes(&mut tuple_contents, "foo".as_bytes()); + append_null(&mut tuple_contents); + + let tuple = make_bytes(&tuple_contents); + + let typ = ColumnType::Tuple(vec![ColumnType::Int, ColumnType::Ascii, ColumnType::Uuid]); + + let tup = deserialize::<(i32, &str, Option)>(&typ, &tuple).unwrap(); + assert_eq!(tup, (42, "foo", None)); + + // ser/de identity + + // () does not implement SerializeValue, yet it does implement DeserializeValue. + // assert_ser_de_identity(&ColumnType::Tuple(vec![]), &(), &mut Bytes::new()); + + // nonempty, varied tuple + assert_ser_de_identity( + &ColumnType::Tuple(vec![ + ColumnType::List(Box::new(ColumnType::Boolean)), + ColumnType::BigInt, + ColumnType::Uuid, + ColumnType::Inet, + ]), + &( + vec![true, false, true], + 42_i64, + Uuid::from_u128(0xdead_cafe_deaf_feed_beaf_bead), + IpAddr::V6(Ipv6Addr::new(0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x10, 0x11)), + ), + &mut Bytes::new(), + ); + + // nested tuples + assert_ser_de_identity( + &ColumnType::Tuple(vec![ColumnType::Tuple(vec![ColumnType::Tuple(vec![ + ColumnType::Text, + ])])]), + &((("",),),), + &mut Bytes::new(), + ); +} + +fn udt_def_with_fields( + fields: impl IntoIterator, ColumnType)>, +) -> ColumnType { + ColumnType::UserDefinedType { + type_name: "udt".to_owned(), + keyspace: "ks".to_owned(), + field_types: fields.into_iter().map(|(s, t)| (s.into(), t)).collect(), + } +} + +#[must_use] +struct UdtSerializer { + buf: BytesMut, +} + +impl UdtSerializer { + fn new() -> Self { + Self { + buf: BytesMut::default(), + } + } + + fn field(mut self, field_bytes: &[u8]) -> Self { + append_bytes(&mut self.buf, field_bytes); + self + } + + fn null_field(mut self) -> Self { + append_null(&mut self.buf); + self + } + + fn finalize(&self) -> Bytes { + make_bytes(&self.buf) + } +} + +// Do not remove. It's not used in tests but we keep it here to check that +// we properly ignore warnings about unused variables, unnecessary `mut`s +// etc. that usually pop up when generating code for empty structs. +#[allow(unused)] +#[derive(scylla_macros::DeserializeValue)] +#[scylla(crate = crate)] +struct TestUdtWithNoFieldsUnordered {} + +#[allow(unused)] +#[derive(scylla_macros::DeserializeValue)] +#[scylla(crate = crate, enforce_order)] +struct TestUdtWithNoFieldsOrdered {} + +#[test] +fn test_udt_loose_ordering() { + #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] + #[scylla(crate = "crate")] + struct Udt<'a> { + a: &'a str, + #[scylla(skip)] + x: String, + #[scylla(allow_missing)] + b: Option, + #[scylla(default_when_null)] + c: i64, + } + + // UDT fields in correct same order. + { + let udt_bytes = UdtSerializer::new() + .field("The quick brown fox".as_bytes()) + .field(&42_i32.to_be_bytes()) + .field(&2137_i64.to_be_bytes()) + .finalize(); + let typ = udt_def_with_fields([ + ("a", ColumnType::Text), + ("b", ColumnType::Int), + ("c", ColumnType::BigInt), + ]); + + let udt = deserialize::>(&typ, &udt_bytes).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: Some(42), + c: 2137, + } + ); + } + + // The last two UDT field are missing in serialized form - it should treat it + // as if there were nulls at the end. + { + let udt_bytes = UdtSerializer::new() + .field("The quick brown fox".as_bytes()) + .finalize(); + let typ = udt_def_with_fields([ + ("a", ColumnType::Text), + ("b", ColumnType::Int), + ("c", ColumnType::BigInt), + ]); + + let udt = deserialize::>(&typ, &udt_bytes).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: None, + c: 0, + } + ); + } + + // UDT fields switched - should still work. + { + let udt_bytes = UdtSerializer::new() + .field(&42_i32.to_be_bytes()) + .field("The quick brown fox".as_bytes()) + .field(&2137_i64.to_be_bytes()) + .finalize(); + let typ = udt_def_with_fields([ + ("b", ColumnType::Int), + ("a", ColumnType::Text), + ("c", ColumnType::BigInt), + ]); + + let udt = deserialize::>(&typ, &udt_bytes).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: Some(42), + c: 2137, + } + ); + } + + // An excess UDT field - should still work. + { + let udt_bytes = UdtSerializer::new() + .field(&12_i8.to_be_bytes()) + .field(&42_i32.to_be_bytes()) + .field("The quick brown fox".as_bytes()) + .field(&2137_i64.to_be_bytes()) + .finalize(); + let typ = udt_def_with_fields([ + ("d", ColumnType::TinyInt), + ("b", ColumnType::Int), + ("a", ColumnType::Text), + ("c", ColumnType::BigInt), + ]); + + Udt::type_check(&typ).unwrap(); + let udt = deserialize::>(&typ, &udt_bytes).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: Some(42), + c: 2137, + } + ); + } + + // Only field 'a' is present + { + let udt_bytes = UdtSerializer::new() + .field("The quick brown fox".as_bytes()) + .finalize(); + let typ = udt_def_with_fields([("a", ColumnType::Text), ("c", ColumnType::BigInt)]); + + let udt = deserialize::>(&typ, &udt_bytes).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: None, + c: 0, + } + ); + } + + // Wrong column type + { + let typ = udt_def_with_fields([("a", ColumnType::Text)]); + Udt::type_check(&typ).unwrap_err(); + } + + // Missing required column + { + let typ = udt_def_with_fields([("b", ColumnType::Int)]); + Udt::type_check(&typ).unwrap_err(); + } +} + +#[test] +fn test_udt_strict_ordering() { + #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] + #[scylla(crate = "crate", enforce_order)] + struct Udt<'a> { + #[scylla(default_when_null)] + a: &'a str, + #[scylla(skip)] + x: String, + #[scylla(allow_missing)] + b: Option, + } + + // UDT fields in correct same order + { + let udt_bytes = UdtSerializer::new() + .field("The quick brown fox".as_bytes()) + .field(&42i32.to_be_bytes()) + .finalize(); + let typ = udt_def_with_fields([("a", ColumnType::Text), ("b", ColumnType::Int)]); + + let udt = deserialize::>(&typ, &udt_bytes).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: Some(42), + } + ); + } + + // The last UDT field is missing in serialized form - it should treat + // as if there were null at the end + { + let udt_bytes = UdtSerializer::new() + .field("The quick brown fox".as_bytes()) + .finalize(); + let typ = udt_def_with_fields([("a", ColumnType::Text), ("b", ColumnType::Int)]); + + let udt = deserialize::>(&typ, &udt_bytes).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: None, + } + ); + } + + // An excess field at the end of UDT + { + let udt_bytes = UdtSerializer::new() + .field("The quick brown fox".as_bytes()) + .field(&42_i32.to_be_bytes()) + .field(&(true as i8).to_be_bytes()) + .finalize(); + let typ = udt_def_with_fields([ + ("a", ColumnType::Text), + ("b", ColumnType::Int), + ("d", ColumnType::Boolean), + ]); + let udt = deserialize::>(&typ, &udt_bytes).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: Some(42), + } + ); + } + + // An excess field at the end of UDT, when such are forbidden + { + #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] + #[scylla(crate = "crate", enforce_order, forbid_excess_udt_fields)] + struct Udt<'a> { + a: &'a str, + #[scylla(skip)] + x: String, + b: Option, + } + + let typ = udt_def_with_fields([ + ("a", ColumnType::Text), + ("b", ColumnType::Int), + ("d", ColumnType::Boolean), + ]); + + Udt::type_check(&typ).unwrap_err(); + } + + // UDT fields switched - will not work + { + let typ = udt_def_with_fields([("b", ColumnType::Int), ("a", ColumnType::Text)]); + Udt::type_check(&typ).unwrap_err(); + } + + // Wrong column type + { + let typ = udt_def_with_fields([("a", ColumnType::Int), ("b", ColumnType::Int)]); + Udt::type_check(&typ).unwrap_err(); + } + + // Missing required column + { + let typ = udt_def_with_fields([("b", ColumnType::Int)]); + Udt::type_check(&typ).unwrap_err(); + } + + // Missing non-required column + { + let udt_bytes = UdtSerializer::new().field(b"kotmaale").finalize(); + let typ = udt_def_with_fields([("a", ColumnType::Text)]); + + let udt = deserialize::>(&typ, &udt_bytes).unwrap(); + assert_eq!( + udt, + Udt { + a: "kotmaale", + x: String::new(), + b: None, + } + ); + } + + // The first field is null, but `default_when_null` prevents failure. + { + let udt_bytes = UdtSerializer::new() + .null_field() + .field(&42i32.to_be_bytes()) + .finalize(); + let typ = udt_def_with_fields([("a", ColumnType::Text), ("b", ColumnType::Int)]); + + let udt = deserialize::>(&typ, &udt_bytes).unwrap(); + assert_eq!( + udt, + Udt { + a: "", + x: String::new(), + b: Some(42), + } + ); + } +} + +#[test] +fn test_udt_no_name_check() { + #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] + #[scylla(crate = "crate", enforce_order, skip_name_checks)] + struct Udt<'a> { + a: &'a str, + #[scylla(skip)] + x: String, + b: Option, + } + + // UDT fields in correct same order + { + let udt_bytes = UdtSerializer::new() + .field("The quick brown fox".as_bytes()) + .field(&42i32.to_be_bytes()) + .finalize(); + let typ = udt_def_with_fields([("a", ColumnType::Text), ("b", ColumnType::Int)]); + + let udt = deserialize::>(&typ, &udt_bytes).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: Some(42), + } + ); + } + + // Correct order of UDT fields, but different names - should still succeed + { + let udt_bytes = UdtSerializer::new() + .field("The quick brown fox".as_bytes()) + .field(&42i32.to_be_bytes()) + .finalize(); + let typ = udt_def_with_fields([("k", ColumnType::Text), ("l", ColumnType::Int)]); + + let udt = deserialize::>(&typ, &udt_bytes).unwrap(); + assert_eq!( + udt, + Udt { + a: "The quick brown fox", + x: String::new(), + b: Some(42), + } + ); + } +} + +#[test] +fn test_udt_cross_rename_fields() { + #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] + #[scylla(crate = crate)] + struct TestUdt { + #[scylla(rename = "b")] + a: i32, + #[scylla(rename = "a")] + b: String, + } + + // UDT fields switched - should still work. + { + let udt_bytes = UdtSerializer::new() + .field("The quick brown fox".as_bytes()) + .field(&42_i32.to_be_bytes()) + .finalize(); + let typ = udt_def_with_fields([("a", ColumnType::Text), ("b", ColumnType::Int)]); + + let udt = deserialize::(&typ, &udt_bytes).unwrap(); + assert_eq!( + udt, + TestUdt { + a: 42, + b: "The quick brown fox".to_owned(), + } + ); + } +} + +#[test] +fn test_custom_type_parser() { + #[derive(Default, Debug, PartialEq, Eq)] + struct SwappedPair(B, A); + impl<'frame, A, B> DeserializeValue<'frame> for SwappedPair + where + A: DeserializeValue<'frame>, + B: DeserializeValue<'frame>, + { + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + <(B, A) as DeserializeValue<'frame>>::type_check(typ) + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + <(B, A) as DeserializeValue<'frame>>::deserialize(typ, v).map(|(b, a)| Self(b, a)) + } + } + + let mut tuple_contents = BytesMut::new(); + append_bytes(&mut tuple_contents, "foo".as_bytes()); + append_bytes(&mut tuple_contents, &42i32.to_be_bytes()); + let tuple = make_bytes(&tuple_contents); + + let typ = ColumnType::Tuple(vec![ColumnType::Ascii, ColumnType::Int]); + + let tup = deserialize::>(&typ, &tuple).unwrap(); + assert_eq!(tup, SwappedPair("foo", 42)); +} + +fn deserialize<'frame, T>( + typ: &'frame ColumnType, + bytes: &'frame Bytes, +) -> Result +where + T: DeserializeValue<'frame>, +{ + >::type_check(typ) + .map_err(|typecheck_err| DeserializationError(typecheck_err.0))?; + let mut frame_slice = FrameSlice::new(bytes); + let value = frame_slice.read_cql_bytes().map_err(|err| { + mk_deser_err::( + typ, + BuiltinDeserializationErrorKind::RawCqlBytesReadError(err), + ) + })?; + >::deserialize(typ, value) +} + +fn make_bytes(cell: &[u8]) -> Bytes { + let mut b = BytesMut::new(); + append_bytes(&mut b, cell); + b.freeze() +} + +fn serialize(typ: &ColumnType, value: &dyn SerializeValue) -> Bytes { + let mut bytes = Bytes::new(); + serialize_to_buf(typ, value, &mut bytes); + bytes +} + +fn serialize_to_buf(typ: &ColumnType, value: &dyn SerializeValue, buf: &mut Bytes) { + let mut v = Vec::new(); + let writer = CellWriter::new(&mut v); + value.serialize(typ, writer).unwrap(); + *buf = v.into(); +} + +fn append_bytes(b: &mut impl BufMut, cell: &[u8]) { + b.put_i32(cell.len() as i32); + b.put_slice(cell); +} + +fn make_null() -> Bytes { + let mut b = BytesMut::new(); + append_null(&mut b); + b.freeze() +} + +fn append_null(b: &mut impl BufMut) { + b.put_i32(-1); +} + +fn assert_ser_de_identity<'f, T: SerializeValue + DeserializeValue<'f> + PartialEq + Debug>( + typ: &'f ColumnType, + v: &'f T, + buf: &'f mut Bytes, // `buf` must be passed as a reference from outside, because otherwise + // we cannot specify the lifetime for DeserializeValue. +) { + serialize_to_buf(typ, v, buf); + let deserialized = deserialize::(typ, buf).unwrap(); + assert_eq!(&deserialized, v); +} + +/* Errors checks */ + +#[track_caller] +pub(crate) fn get_typeck_err_inner<'a>( + err: &'a (dyn std::error::Error + 'static), +) -> &'a BuiltinTypeCheckError { + match err.downcast_ref() { + Some(err) => err, + None => panic!("not a BuiltinTypeCheckError: {:?}", err), + } +} + +#[track_caller] +pub(crate) fn get_typeck_err(err: &DeserializationError) -> &BuiltinTypeCheckError { + get_typeck_err_inner(err.0.as_ref()) +} + +#[track_caller] +pub(crate) fn get_deser_err(err: &DeserializationError) -> &BuiltinDeserializationError { + match err.0.downcast_ref() { + Some(err) => err, + None => panic!("not a BuiltinDeserializationError: {:?}", err), + } +} + +macro_rules! assert_given_error { + ($get_err:ident, $bytes:expr, $DestT:ty, $cql_typ:expr, $kind:pat) => { + let cql_typ = $cql_typ.clone(); + let err = deserialize::<$DestT>(&cql_typ, $bytes).unwrap_err(); + let err = $get_err(&err); + assert_eq!(err.rust_name, std::any::type_name::<$DestT>()); + assert_eq!(err.cql_type, cql_typ); + assert_matches::assert_matches!(err.kind, $kind); + }; +} + +macro_rules! assert_type_check_error { + ($bytes:expr, $DestT:ty, $cql_typ:expr, $kind:pat) => { + assert_given_error!(get_typeck_err, $bytes, $DestT, $cql_typ, $kind); + }; +} + +macro_rules! assert_deser_error { + ($bytes:expr, $DestT:ty, $cql_typ:expr, $kind:pat) => { + assert_given_error!(get_deser_err, $bytes, $DestT, $cql_typ, $kind); + }; +} + +#[test] +fn test_native_errors() { + // Simple type mismatch + { + let v = 123_i32; + let bytes = serialize(&ColumnType::Int, &v); + + // Incompatible types render type check error. + assert_type_check_error!( + &bytes, + f64, + ColumnType::Int, + super::BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::Double], + } + ); + + // ColumnType is said to be Double (8 bytes expected), but in reality the serialized form has 4 bytes only. + assert_deser_error!( + &bytes, + f64, + ColumnType::Double, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 8, + got: 4, + } + ); + + // ColumnType is said to be Float, but in reality Int was serialized. + // As these types have the same size, though, and every binary number in [0, 2^32] is a valid + // value for both of them, this always succeeds. + { + deserialize::(&ColumnType::Float, &bytes).unwrap(); + } + } + + // str (and also Uuid) are interesting because they accept two types. + { + let v = "Ala ma kota"; + let bytes = serialize(&ColumnType::Ascii, &v); + + assert_type_check_error!( + &bytes, + &str, + ColumnType::Double, + BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::Ascii, ColumnType::Text], + } + ); + + // ColumnType is said to be BigInt (8 bytes expected), but in reality the serialized form + // (the string) has 11 bytes. + assert_deser_error!( + &bytes, + i64, + ColumnType::BigInt, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 8, + got: 11, // str len + } + ); + } + { + // -126 is not a valid ASCII nor UTF-8 byte. + let v = -126_i8; + let bytes = serialize(&ColumnType::TinyInt, &v); + + assert_deser_error!( + &bytes, + &str, + ColumnType::Ascii, + BuiltinDeserializationErrorKind::ExpectedAscii + ); + + assert_deser_error!( + &bytes, + &str, + ColumnType::Text, + BuiltinDeserializationErrorKind::InvalidUtf8(_) + ); + } +} + +#[test] +fn test_set_or_list_errors() { + // Not a set or list + { + assert_type_check_error!( + &Bytes::new(), + Vec, + ColumnType::Float, + BuiltinTypeCheckErrorKind::SetOrListError(SetOrListTypeCheckErrorKind::NotSetOrList) + ); + + // Type check of Rust set against CQL list must fail, because it would be lossy. + assert_type_check_error!( + &Bytes::new(), + BTreeSet, + ColumnType::List(Box::new(ColumnType::Int)), + BuiltinTypeCheckErrorKind::SetOrListError(SetOrListTypeCheckErrorKind::NotSet) + ); + } + + // Got null + { + type RustTyp = Vec; + let ser_typ = ColumnType::List(Box::new(ColumnType::Int)); + + let err = RustTyp::deserialize(&ser_typ, None).unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ser_typ); + assert_matches!(err.kind, BuiltinDeserializationErrorKind::ExpectedNonNull); + } + + // Bad element type + { + assert_type_check_error!( + &Bytes::new(), + Vec, + ColumnType::List(Box::new(ColumnType::Ascii)), + BuiltinTypeCheckErrorKind::SetOrListError( + SetOrListTypeCheckErrorKind::ElementTypeCheckFailed(_) + ) + ); + + let err = deserialize::>( + &ColumnType::List(Box::new(ColumnType::Varint)), + &Bytes::new(), + ) + .unwrap_err(); + let err = get_typeck_err(&err); + assert_eq!(err.rust_name, std::any::type_name::>()); + assert_eq!(err.cql_type, ColumnType::List(Box::new(ColumnType::Varint)),); + let BuiltinTypeCheckErrorKind::SetOrListError( + SetOrListTypeCheckErrorKind::ElementTypeCheckFailed(ref err), + ) = err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::Varint); + assert_matches!( + err.kind, + BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::BigInt, ColumnType::Counter] + } + ); + } + + { + let ser_typ = ColumnType::List(Box::new(ColumnType::Int)); + let v = vec![123_i32]; + let bytes = serialize(&ser_typ, &v); + + { + let err = + deserialize::>(&ColumnType::List(Box::new(ColumnType::BigInt)), &bytes) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::>()); + assert_eq!(err.cql_type, ColumnType::List(Box::new(ColumnType::BigInt)),); + let BuiltinDeserializationErrorKind::SetOrListError( + SetOrListDeserializationErrorKind::ElementDeserializationFailed(err), + ) = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + let err = get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::BigInt); + assert_matches!( + err.kind, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 8, + got: 4 + } + ); + } + } +} + +#[test] +fn test_map_errors() { + // Not a map + { + let ser_typ = ColumnType::Float; + let v = 2.12_f32; + let bytes = serialize(&ser_typ, &v); + + assert_type_check_error!( + &bytes, + HashMap, + ser_typ, + BuiltinTypeCheckErrorKind::MapError( + MapTypeCheckErrorKind::NotMap, + ) + ); + } + + // Got null + { + type RustTyp = HashMap; + let ser_typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Boolean)); + + let err = RustTyp::deserialize(&ser_typ, None).unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ser_typ); + assert_matches!(err.kind, BuiltinDeserializationErrorKind::ExpectedNonNull); + } + + // Key type mismatch + { + let err = deserialize::>( + &ColumnType::Map(Box::new(ColumnType::Varint), Box::new(ColumnType::Boolean)), + &Bytes::new(), + ) + .unwrap_err(); + let err = get_typeck_err(&err); + assert_eq!(err.rust_name, std::any::type_name::>()); + assert_eq!( + err.cql_type, + ColumnType::Map(Box::new(ColumnType::Varint), Box::new(ColumnType::Boolean)) + ); + let BuiltinTypeCheckErrorKind::MapError(MapTypeCheckErrorKind::KeyTypeCheckFailed(ref err)) = + err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::Varint); + assert_matches!( + err.kind, + BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::BigInt, ColumnType::Counter] + } + ); + } + + // Value type mismatch + { + let err = deserialize::>( + &ColumnType::Map(Box::new(ColumnType::BigInt), Box::new(ColumnType::Boolean)), + &Bytes::new(), + ) + .unwrap_err(); + let err = get_typeck_err(&err); + assert_eq!(err.rust_name, std::any::type_name::>()); + assert_eq!( + err.cql_type, + ColumnType::Map(Box::new(ColumnType::BigInt), Box::new(ColumnType::Boolean)) + ); + let BuiltinTypeCheckErrorKind::MapError(MapTypeCheckErrorKind::ValueTypeCheckFailed( + ref err, + )) = err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::<&str>()); + assert_eq!(err.cql_type, ColumnType::Boolean); + assert_matches!( + err.kind, + BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::Ascii, ColumnType::Text] + } + ); + } + + // Key length mismatch + { + let ser_typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Boolean)); + let v = HashMap::from([(42, false), (2137, true)]); + let bytes = serialize(&ser_typ, &v as &dyn SerializeValue); + + let err = deserialize::>( + &ColumnType::Map(Box::new(ColumnType::BigInt), Box::new(ColumnType::Boolean)), + &bytes, + ) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::>()); + assert_eq!( + err.cql_type, + ColumnType::Map(Box::new(ColumnType::BigInt), Box::new(ColumnType::Boolean)) + ); + let BuiltinDeserializationErrorKind::MapError( + MapDeserializationErrorKind::KeyDeserializationFailed(err), + ) = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + let err = get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::BigInt); + assert_matches!( + err.kind, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 8, + got: 4 + } + ); + } + + // Value length mismatch + { + let ser_typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Boolean)); + let v = HashMap::from([(42, false), (2137, true)]); + let bytes = serialize(&ser_typ, &v as &dyn SerializeValue); + + let err = deserialize::>( + &ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::SmallInt)), + &bytes, + ) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::>()); + assert_eq!( + err.cql_type, + ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::SmallInt)) + ); + let BuiltinDeserializationErrorKind::MapError( + MapDeserializationErrorKind::ValueDeserializationFailed(err), + ) = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + let err = get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::SmallInt); + assert_matches!( + err.kind, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 2, + got: 1 + } + ); + } +} + +#[test] +fn test_tuple_errors() { + // Not a tuple + { + assert_type_check_error!( + &Bytes::new(), + (i64,), + ColumnType::BigInt, + BuiltinTypeCheckErrorKind::TupleError(TupleTypeCheckErrorKind::NotTuple) + ); + } + // Wrong element count + { + assert_type_check_error!( + &Bytes::new(), + (i64,), + ColumnType::Tuple(vec![]), + BuiltinTypeCheckErrorKind::TupleError(TupleTypeCheckErrorKind::WrongElementCount { + rust_type_el_count: 1, + cql_type_el_count: 0, + }) + ); + + assert_type_check_error!( + &Bytes::new(), + (f32,), + ColumnType::Tuple(vec![ColumnType::Float, ColumnType::Float]), + BuiltinTypeCheckErrorKind::TupleError(TupleTypeCheckErrorKind::WrongElementCount { + rust_type_el_count: 1, + cql_type_el_count: 2, + }) + ); + } + + // Bad field type + { + { + let err = deserialize::<(i64,)>( + &ColumnType::Tuple(vec![ColumnType::SmallInt]), + &Bytes::new(), + ) + .unwrap_err(); + let err = get_typeck_err(&err); + assert_eq!(err.rust_name, std::any::type_name::<(i64,)>()); + assert_eq!(err.cql_type, ColumnType::Tuple(vec![ColumnType::SmallInt])); + let BuiltinTypeCheckErrorKind::TupleError( + TupleTypeCheckErrorKind::FieldTypeCheckFailed { ref err, position }, + ) = err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + assert_eq!(position, 0); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::SmallInt); + assert_matches!( + err.kind, + BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::BigInt, ColumnType::Counter] + } + ); + } + } + + { + let ser_typ = ColumnType::Tuple(vec![ColumnType::Int, ColumnType::Float]); + let v = (123_i32, 123.123_f32); + let bytes = serialize(&ser_typ, &v); + + { + let err = deserialize::<(i32, f64)>( + &ColumnType::Tuple(vec![ColumnType::Int, ColumnType::Double]), + &bytes, + ) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::<(i32, f64)>()); + assert_eq!( + err.cql_type, + ColumnType::Tuple(vec![ColumnType::Int, ColumnType::Double]) + ); + let BuiltinDeserializationErrorKind::TupleError( + TupleDeserializationErrorKind::FieldDeserializationFailed { + ref err, + position: index, + }, + ) = err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + assert_eq!(index, 1); + let err = get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::Double); + assert_matches!( + err.kind, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 8, + got: 4 + } + ); + } + } +} + +#[test] +fn test_null_errors() { + let ser_typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Boolean)); + let v = HashMap::from([(42, false), (2137, true)]); + let bytes = serialize(&ser_typ, &v as &dyn SerializeValue); + + deserialize::>(&ser_typ, &bytes).unwrap_err(); +} + +#[test] +fn test_udt_errors() { + // Loose ordering + { + #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] + #[scylla(crate = "crate", forbid_excess_udt_fields)] + struct Udt<'a> { + a: &'a str, + #[scylla(skip)] + x: String, + #[scylla(allow_missing)] + b: Option, + #[scylla(default_when_null)] + c: bool, + } + + // Type check errors + { + // Not UDT + { + let typ = ColumnType::Map(Box::new(ColumnType::Ascii), Box::new(ColumnType::Blob)); + let err = Udt::type_check(&typ).unwrap_err(); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::NotUdt) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + } + + // UDT missing fields + { + let typ = udt_def_with_fields([("c", ColumnType::Boolean)]); + let err = Udt::type_check(&typ).unwrap_err(); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinTypeCheckErrorKind::UdtError( + UdtTypeCheckErrorKind::ValuesMissingForUdtFields { + field_names: ref missing_fields, + }, + ) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(missing_fields.as_slice(), &["a"]); + } + + // excess fields in UDT + { + let typ = udt_def_with_fields([ + ("d", ColumnType::Boolean), + ("a", ColumnType::Text), + ("b", ColumnType::Int), + ]); + let err = Udt::type_check(&typ).unwrap_err(); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::ExcessFieldInUdt { + ref db_field_name, + }) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(db_field_name.as_str(), "d"); + } + + // missing UDT field + { + let typ = udt_def_with_fields([("b", ColumnType::Int), ("a", ColumnType::Text)]); + let err = Udt::type_check(&typ).unwrap_err(); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinTypeCheckErrorKind::UdtError( + UdtTypeCheckErrorKind::ValuesMissingForUdtFields { ref field_names }, + ) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(field_names, &["c"]); + } + + // UDT fields incompatible types - field type check failed + { + let typ = udt_def_with_fields([("a", ColumnType::Blob), ("b", ColumnType::Int)]); + let err = Udt::type_check(&typ).unwrap_err(); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinTypeCheckErrorKind::UdtError( + UdtTypeCheckErrorKind::FieldTypeCheckFailed { + ref field_name, + ref err, + }, + ) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(field_name.as_str(), "a"); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::<&str>()); + assert_eq!(err.cql_type, ColumnType::Blob); + assert_matches!( + err.kind, + BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::Ascii, ColumnType::Text] + } + ); + } + } + + // Deserialization errors + { + // Got null + { + let typ = udt_def_with_fields([ + ("c", ColumnType::Boolean), + ("a", ColumnType::Blob), + ("b", ColumnType::Int), + ]); + + let err = Udt::deserialize(&typ, None).unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + assert_matches!(err.kind, BuiltinDeserializationErrorKind::ExpectedNonNull); + } + + // UDT field deserialization failed + { + let typ = + udt_def_with_fields([("a", ColumnType::Ascii), ("c", ColumnType::Boolean)]); + + let udt_bytes = UdtSerializer::new() + .field("alamakota".as_bytes()) + .field(&42_i16.to_be_bytes()) + .finalize(); + + let err = deserialize::(&typ, &udt_bytes).unwrap_err(); + + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinDeserializationErrorKind::UdtError( + UdtDeserializationErrorKind::FieldDeserializationFailed { + ref field_name, + ref err, + }, + ) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(field_name.as_str(), "c"); + let err = get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::Boolean); + assert_matches!( + err.kind, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 1, + got: 2, + } + ); + } + } + } + + // Strict ordering + { + #[derive(scylla_macros::DeserializeValue, PartialEq, Eq, Debug)] + #[scylla(crate = "crate", enforce_order, forbid_excess_udt_fields)] + struct Udt<'a> { + a: &'a str, + #[scylla(skip)] + x: String, + b: Option, + #[scylla(allow_missing)] + c: bool, + } + + // Type check errors + { + // Not UDT + { + let typ = ColumnType::Map(Box::new(ColumnType::Ascii), Box::new(ColumnType::Blob)); + let err = Udt::type_check(&typ).unwrap_err(); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::NotUdt) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + } + + // UDT too few fields + { + let typ = udt_def_with_fields([("a", ColumnType::Text)]); + let err = Udt::type_check(&typ).unwrap_err(); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::TooFewFields { + ref required_fields, + ref present_fields, + }) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(required_fields.as_slice(), &["a", "b"]); + assert_eq!(present_fields.as_slice(), &["a".to_string()]); + } + + // excess fields in UDT + { + let typ = udt_def_with_fields([ + ("a", ColumnType::Text), + ("b", ColumnType::Int), + ("d", ColumnType::Boolean), + ]); + let err = Udt::type_check(&typ).unwrap_err(); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::ExcessFieldInUdt { + ref db_field_name, + }) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(db_field_name.as_str(), "d"); + } + + // UDT fields switched - field name mismatch + { + let typ = udt_def_with_fields([("b", ColumnType::Int), ("a", ColumnType::Text)]); + let err = Udt::type_check(&typ).unwrap_err(); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinTypeCheckErrorKind::UdtError(UdtTypeCheckErrorKind::FieldNameMismatch { + position, + ref rust_field_name, + ref db_field_name, + }) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(position, 0); + assert_eq!(rust_field_name.as_str(), "a".to_owned()); + assert_eq!(db_field_name.as_str(), "b".to_owned()); + } + + // UDT fields incompatible types - field type check failed + { + let typ = udt_def_with_fields([("a", ColumnType::Blob), ("b", ColumnType::Int)]); + let err = Udt::type_check(&typ).unwrap_err(); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinTypeCheckErrorKind::UdtError( + UdtTypeCheckErrorKind::FieldTypeCheckFailed { + ref field_name, + ref err, + }, + ) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(field_name.as_str(), "a"); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::<&str>()); + assert_eq!(err.cql_type, ColumnType::Blob); + assert_matches!( + err.kind, + BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::Ascii, ColumnType::Text] + } + ); + } + } + + // Deserialization errors + { + // Got null + { + let typ = udt_def_with_fields([ + ("a", ColumnType::Text), + ("b", ColumnType::Int), + ("c", ColumnType::Boolean), + ]); + + let err = Udt::deserialize(&typ, None).unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + assert_matches!(err.kind, BuiltinDeserializationErrorKind::ExpectedNonNull); + } + + // Bad field format + { + let typ = udt_def_with_fields([ + ("a", ColumnType::Text), + ("b", ColumnType::Int), + ("c", ColumnType::Boolean), + ]); + + let udt_bytes = UdtSerializer::new() + .field(b"alamakota") + .field(&42_i64.to_be_bytes()) + .field(&[true as u8]) + .finalize(); + + let udt_bytes_too_short = udt_bytes.slice(..udt_bytes.len() - 1); + assert!(udt_bytes.len() > udt_bytes_too_short.len()); + + let err = deserialize::(&typ, &udt_bytes_too_short).unwrap_err(); + + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinDeserializationErrorKind::RawCqlBytesReadError(_) = err.kind else { + panic!("unexpected error kind: {:?}", err.kind) + }; + } + + // UDT field deserialization failed + { + let typ = udt_def_with_fields([ + ("a", ColumnType::Text), + ("b", ColumnType::Int), + ("c", ColumnType::Boolean), + ]); + + let udt_bytes = UdtSerializer::new() + .field(b"alamakota") + .field(&42_i64.to_be_bytes()) + .field(&[true as u8]) + .finalize(); + + let err = deserialize::(&typ, &udt_bytes).unwrap_err(); + + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, typ); + let BuiltinDeserializationErrorKind::UdtError( + UdtDeserializationErrorKind::FieldDeserializationFailed { + ref field_name, + ref err, + }, + ) = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(field_name.as_str(), "b"); + let err = get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::Int); + assert_matches!( + err.kind, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 4, + got: 8, + } + ); + } + } + } +} From f470d1d0d7443ec86f7f6ec11b3c6d3b7f4e7670 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojciech=20Przytu=C5=82a?= Date: Wed, 26 Jun 2024 08:40:18 +0200 Subject: [PATCH 29/29] row: move tests to a separate file If only tests are modified when put in the same file as library code, Cargo will rebuild the lib crate anyway. Conversely, when tests are in a separate file, the lib crate won't be rebuilt and this saves precious time. --- scylla-cql/src/types/deserialize/row.rs | 868 +----------------- scylla-cql/src/types/deserialize/row_tests.rs | 861 +++++++++++++++++ 2 files changed, 863 insertions(+), 866 deletions(-) create mode 100644 scylla-cql/src/types/deserialize/row_tests.rs diff --git a/scylla-cql/src/types/deserialize/row.rs b/scylla-cql/src/types/deserialize/row.rs index c4e8e9cf1c..8971e1a711 100644 --- a/scylla-cql/src/types/deserialize/row.rs +++ b/scylla-cql/src/types/deserialize/row.rs @@ -488,872 +488,8 @@ impl Display for BuiltinDeserializationErrorKind { } #[cfg(test)] -mod tests { - use assert_matches::assert_matches; - use bytes::Bytes; - use scylla_macros::DeserializeRow; - - use crate::frame::response::result::{ColumnSpec, ColumnType}; - use crate::types::deserialize::row::BuiltinDeserializationErrorKind; - use crate::types::deserialize::{value, DeserializationError, FrameSlice}; - - use super::super::tests::{serialize_cells, spec}; - use super::{BuiltinDeserializationError, ColumnIterator, CqlValue, DeserializeRow, Row}; - use super::{BuiltinTypeCheckError, BuiltinTypeCheckErrorKind}; - - #[test] - fn test_tuple_deserialization() { - // Empty tuple - deserialize::<()>(&[], &Bytes::new()).unwrap(); - - // 1-elem tuple - let (a,) = deserialize::<(i32,)>( - &[spec("i", ColumnType::Int)], - &serialize_cells([val_int(123)]), - ) - .unwrap(); - assert_eq!(a, 123); - - // 3-elem tuple - let (a, b, c) = deserialize::<(i32, i32, i32)>( - &[ - spec("i1", ColumnType::Int), - spec("i2", ColumnType::Int), - spec("i3", ColumnType::Int), - ], - &serialize_cells([val_int(123), val_int(456), val_int(789)]), - ) - .unwrap(); - assert_eq!((a, b, c), (123, 456, 789)); - - // Make sure that column type mismatch is detected - deserialize::<(i32, String, i32)>( - &[ - spec("i1", ColumnType::Int), - spec("i2", ColumnType::Int), - spec("i3", ColumnType::Int), - ], - &serialize_cells([val_int(123), val_int(456), val_int(789)]), - ) - .unwrap_err(); - - // Make sure that borrowing types compile and work correctly - let specs = &[spec("s", ColumnType::Text)]; - let byts = serialize_cells([val_str("abc")]); - let (s,) = deserialize::<(&str,)>(specs, &byts).unwrap(); - assert_eq!(s, "abc"); - } - - #[test] - fn test_deserialization_as_column_iterator() { - let col_specs = [ - spec("i1", ColumnType::Int), - spec("i2", ColumnType::Text), - spec("i3", ColumnType::Counter), - ]; - let serialized_values = serialize_cells([val_int(123), val_str("ScyllaDB"), None]); - let mut iter = deserialize::(&col_specs, &serialized_values).unwrap(); - - let col1 = iter.next().unwrap().unwrap(); - assert_eq!(col1.spec.name, "i1"); - assert_eq!(col1.spec.typ, ColumnType::Int); - assert_eq!(col1.slice.unwrap().as_slice(), &123i32.to_be_bytes()); - - let col2 = iter.next().unwrap().unwrap(); - assert_eq!(col2.spec.name, "i2"); - assert_eq!(col2.spec.typ, ColumnType::Text); - assert_eq!(col2.slice.unwrap().as_slice(), "ScyllaDB".as_bytes()); - - let col3 = iter.next().unwrap().unwrap(); - assert_eq!(col3.spec.name, "i3"); - assert_eq!(col3.spec.typ, ColumnType::Counter); - assert!(col3.slice.is_none()); - - assert!(iter.next().is_none()); - } - - // Do not remove. It's not used in tests but we keep it here to check that - // we properly ignore warnings about unused variables, unnecessary `mut`s - // etc. that usually pop up when generating code for empty structs. - #[allow(unused)] - #[derive(DeserializeRow)] - #[scylla(crate = crate)] - struct TestUdtWithNoFieldsUnordered {} - - #[allow(unused)] - #[derive(DeserializeRow)] - #[scylla(crate = crate, enforce_order)] - struct TestUdtWithNoFieldsOrdered {} - - #[test] - fn test_struct_deserialization_loose_ordering() { - #[derive(DeserializeRow, PartialEq, Eq, Debug)] - #[scylla(crate = "crate")] - struct MyRow<'a> { - a: &'a str, - b: Option, - #[scylla(skip)] - c: String, - } - - // Original order of columns - let specs = &[spec("a", ColumnType::Text), spec("b", ColumnType::Int)]; - let byts = serialize_cells([val_str("abc"), val_int(123)]); - let row = deserialize::>(specs, &byts).unwrap(); - assert_eq!( - row, - MyRow { - a: "abc", - b: Some(123), - c: String::new(), - } - ); - - // Different order of columns - should still work - let specs = &[spec("b", ColumnType::Int), spec("a", ColumnType::Text)]; - let byts = serialize_cells([val_int(123), val_str("abc")]); - let row = deserialize::>(specs, &byts).unwrap(); - assert_eq!( - row, - MyRow { - a: "abc", - b: Some(123), - c: String::new(), - } - ); - - // Missing column - let specs = &[spec("a", ColumnType::Text)]; - MyRow::type_check(specs).unwrap_err(); - - // Wrong column type - let specs = &[spec("a", ColumnType::Int), spec("b", ColumnType::Int)]; - MyRow::type_check(specs).unwrap_err(); - } - - #[test] - fn test_struct_deserialization_strict_ordering() { - #[derive(DeserializeRow, PartialEq, Eq, Debug)] - #[scylla(crate = "crate", enforce_order)] - struct MyRow<'a> { - a: &'a str, - b: Option, - #[scylla(skip)] - c: String, - } - - // Correct order of columns - let specs = &[spec("a", ColumnType::Text), spec("b", ColumnType::Int)]; - let byts = serialize_cells([val_str("abc"), val_int(123)]); - let row = deserialize::>(specs, &byts).unwrap(); - assert_eq!( - row, - MyRow { - a: "abc", - b: Some(123), - c: String::new(), - } - ); - - // Wrong order of columns - let specs = &[spec("b", ColumnType::Int), spec("a", ColumnType::Text)]; - MyRow::type_check(specs).unwrap_err(); - - // Missing column - let specs = &[spec("a", ColumnType::Text)]; - MyRow::type_check(specs).unwrap_err(); - - // Wrong column type - let specs = &[spec("a", ColumnType::Int), spec("b", ColumnType::Int)]; - MyRow::type_check(specs).unwrap_err(); - } - - #[test] - fn test_struct_deserialization_no_name_check() { - #[derive(DeserializeRow, PartialEq, Eq, Debug)] - #[scylla(crate = "crate", enforce_order, skip_name_checks)] - struct MyRow<'a> { - a: &'a str, - b: Option, - #[scylla(skip)] - c: String, - } - - // Correct order of columns - let specs = &[spec("a", ColumnType::Text), spec("b", ColumnType::Int)]; - let byts = serialize_cells([val_str("abc"), val_int(123)]); - let row = deserialize::>(specs, &byts).unwrap(); - assert_eq!( - row, - MyRow { - a: "abc", - b: Some(123), - c: String::new(), - } - ); - - // Correct order of columns, but different names - should still succeed - let specs = &[spec("z", ColumnType::Text), spec("x", ColumnType::Int)]; - let byts = serialize_cells([val_str("abc"), val_int(123)]); - let row = deserialize::>(specs, &byts).unwrap(); - assert_eq!( - row, - MyRow { - a: "abc", - b: Some(123), - c: String::new(), - } - ); - } - - #[test] - fn test_struct_deserialization_cross_rename_fields() { - #[derive(scylla_macros::DeserializeRow, PartialEq, Eq, Debug)] - #[scylla(crate = crate)] - struct TestRow { - #[scylla(rename = "b")] - a: i32, - #[scylla(rename = "a")] - b: String, - } - - // Columns switched wrt fields - should still work. - { - let row_bytes = serialize_cells( - ["The quick brown fox".as_bytes(), &42_i32.to_be_bytes()].map(Some), - ); - let specs = [spec("a", ColumnType::Text), spec("b", ColumnType::Int)]; - - let row = deserialize::(&specs, &row_bytes).unwrap(); - assert_eq!( - row, - TestRow { - a: 42, - b: "The quick brown fox".to_owned(), - } - ); - } - } - - fn val_int(i: i32) -> Option> { - Some(i.to_be_bytes().to_vec()) - } - - fn val_str(s: &str) -> Option> { - Some(s.as_bytes().to_vec()) - } - - fn deserialize<'frame, R>( - specs: &'frame [ColumnSpec], - byts: &'frame Bytes, - ) -> Result - where - R: DeserializeRow<'frame>, - { - >::type_check(specs) - .map_err(|typecheck_err| DeserializationError(typecheck_err.0))?; - let slice = FrameSlice::new(byts); - let iter = ColumnIterator::new(specs, slice); - >::deserialize(iter) - } - - #[track_caller] - pub(crate) fn get_typck_err_inner<'a>( - err: &'a (dyn std::error::Error + 'static), - ) -> &'a BuiltinTypeCheckError { - match err.downcast_ref() { - Some(err) => err, - None => panic!("not a BuiltinTypeCheckError: {:?}", err), - } - } - - #[track_caller] - fn get_typck_err(err: &DeserializationError) -> &BuiltinTypeCheckError { - get_typck_err_inner(err.0.as_ref()) - } - - #[track_caller] - fn get_deser_err(err: &DeserializationError) -> &BuiltinDeserializationError { - match err.0.downcast_ref() { - Some(err) => err, - None => panic!("not a BuiltinDeserializationError: {:?}", err), - } - } - - #[test] - fn test_tuple_errors() { - // Column type check failure - { - let col_name: &str = "i"; - let specs = &[spec(col_name, ColumnType::Int)]; - let err = deserialize::<(i64,)>(specs, &serialize_cells([val_int(123)])).unwrap_err(); - let err = get_typck_err(&err); - assert_eq!(err.rust_name, std::any::type_name::<(i64,)>()); - assert_eq!( - err.cql_types, - specs - .iter() - .map(|spec| spec.typ.clone()) - .collect::>() - ); - let BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { - column_index, - column_name, - err, - } = &err.kind - else { - panic!("unexpected error kind: {}", err.kind) - }; - assert_eq!(*column_index, 0); - assert_eq!(column_name, col_name); - let err = super::super::value::tests::get_typeck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, ColumnType::Int); - assert_matches!( - &err.kind, - super::super::value::BuiltinTypeCheckErrorKind::MismatchedType { - expected: &[ColumnType::BigInt, ColumnType::Counter] - } - ); - } - - // Column deserialization failure - { - let col_name: &str = "i"; - let err = deserialize::<(i64,)>( - &[spec(col_name, ColumnType::BigInt)], - &serialize_cells([val_int(123)]), - ) - .unwrap_err(); - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::<(i64,)>()); - let BuiltinDeserializationErrorKind::ColumnDeserializationFailed { - column_name, - err, - .. - } = &err.kind - else { - panic!("unexpected error kind: {}", err.kind) - }; - assert_eq!(column_name, col_name); - let err = super::super::value::tests::get_deser_err(err); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, ColumnType::BigInt); - assert_matches!( - err.kind, - super::super::value::BuiltinDeserializationErrorKind::ByteLengthMismatch { - expected: 8, - got: 4 - } - ); - } - - // Raw column deserialization failure - { - let col_name: &str = "i"; - let err = deserialize::<(i64,)>( - &[spec(col_name, ColumnType::BigInt)], - &Bytes::from_static(b"alamakota"), - ) - .unwrap_err(); - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::<(i64,)>()); - let BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { - column_index: _column_index, - column_name, - err: _err, - } = &err.kind - else { - panic!("unexpected error kind: {}", err.kind) - }; - assert_eq!(column_name, col_name); - } - } - - #[test] - fn test_row_errors() { - // Column type check failure - happens never, because Row consists of CqlValues, - // which accept all CQL types. - - // Column deserialization failure - { - let col_name: &str = "i"; - let err = deserialize::( - &[spec(col_name, ColumnType::BigInt)], - &serialize_cells([val_int(123)]), - ) - .unwrap_err(); - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::()); - let BuiltinDeserializationErrorKind::ColumnDeserializationFailed { - column_index: _column_index, - column_name, - err, - } = &err.kind - else { - panic!("unexpected error kind: {}", err.kind) - }; - assert_eq!(column_name, col_name); - let err = super::super::value::tests::get_deser_err(err); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, ColumnType::BigInt); - let super::super::value::BuiltinDeserializationErrorKind::ByteLengthMismatch { - expected: 8, - got: 4, - } = &err.kind - else { - panic!("unexpected error kind: {}", err.kind) - }; - } - - // Raw column deserialization failure - { - let col_name: &str = "i"; - let err = deserialize::( - &[spec(col_name, ColumnType::BigInt)], - &Bytes::from_static(b"alamakota"), - ) - .unwrap_err(); - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::()); - let BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { - column_index: _column_index, - column_name, - err: _err, - } = &err.kind - else { - panic!("unexpected error kind: {}", err.kind) - }; - assert_eq!(column_name, col_name); - } - } - - fn specs_to_types(specs: &[ColumnSpec]) -> Vec { - specs.iter().map(|spec| spec.typ.clone()).collect() - } - - #[test] - fn test_struct_deserialization_errors() { - // Loose ordering - { - #[derive(scylla_macros::DeserializeRow, PartialEq, Eq, Debug)] - #[scylla(crate = "crate")] - struct MyRow<'a> { - a: &'a str, - #[scylla(skip)] - x: String, - b: Option, - #[scylla(rename = "c")] - d: bool, - } - - // Type check errors - { - // Missing column - { - let specs = [spec("a", ColumnType::Ascii), spec("b", ColumnType::Int)]; - let err = MyRow::type_check(&specs).unwrap_err(); - let err = get_typck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_types, specs_to_types(&specs)); - let BuiltinTypeCheckErrorKind::ValuesMissingForColumns { - column_names: ref missing_fields, - } = err.kind - else { - panic!("unexpected error kind: {:?}", err.kind) - }; - assert_eq!(missing_fields.as_slice(), &["c"]); - } - - // Duplicated column - { - let specs = [ - spec("a", ColumnType::Ascii), - spec("b", ColumnType::Int), - spec("a", ColumnType::Ascii), - ]; - - let err = MyRow::type_check(&specs).unwrap_err(); - let err = get_typck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_types, specs_to_types(&specs)); - let BuiltinTypeCheckErrorKind::DuplicatedColumn { - column_index, - column_name, - } = err.kind - else { - panic!("unexpected error kind: {:?}", err.kind) - }; - assert_eq!(column_index, 2); - assert_eq!(column_name, "a"); - } - - // Unknown column - { - let specs = [ - spec("d", ColumnType::Counter), - spec("a", ColumnType::Ascii), - spec("b", ColumnType::Int), - ]; - - let err = MyRow::type_check(&specs).unwrap_err(); - let err = get_typck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_types, specs_to_types(&specs)); - let BuiltinTypeCheckErrorKind::ColumnWithUnknownName { - column_index, - ref column_name, - } = err.kind - else { - panic!("unexpected error kind: {:?}", err.kind) - }; - assert_eq!(column_index, 0); - assert_eq!(column_name.as_str(), "d"); - } - - // Column incompatible types - column type check failed - { - let specs = [spec("b", ColumnType::Int), spec("a", ColumnType::Blob)]; - let err = MyRow::type_check(&specs).unwrap_err(); - let err = get_typck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_types, specs_to_types(&specs)); - let BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { - column_index, - ref column_name, - ref err, - } = err.kind - else { - panic!("unexpected error kind: {:?}", err.kind) - }; - assert_eq!(column_index, 1); - assert_eq!(column_name.as_str(), "a"); - let err = value::tests::get_typeck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::<&str>()); - assert_eq!(err.cql_type, ColumnType::Blob); - assert_matches!( - err.kind, - value::BuiltinTypeCheckErrorKind::MismatchedType { - expected: &[ColumnType::Ascii, ColumnType::Text] - } - ); - } - } - - // Deserialization errors - { - // Got null - { - let specs = [ - spec("c", ColumnType::Boolean), - spec("a", ColumnType::Blob), - spec("b", ColumnType::Int), - ]; - - let err = MyRow::deserialize(ColumnIterator::new( - &specs, - FrameSlice::new(&serialize_cells([Some([true as u8])])), - )) - .unwrap_err(); - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::()); - let BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { - column_index, - ref column_name, - .. - } = err.kind - else { - panic!("unexpected error kind: {:?}", err.kind) - }; - assert_eq!(column_index, 1); - assert_eq!(column_name, "a"); - } - - // Column deserialization failed - { - let specs = [ - spec("b", ColumnType::Int), - spec("a", ColumnType::Ascii), - spec("c", ColumnType::Boolean), - ]; - - let row_bytes = serialize_cells( - [ - &0_i32.to_be_bytes(), - "alamakota".as_bytes(), - &42_i16.to_be_bytes(), - ] - .map(Some), - ); - - let err = deserialize::(&specs, &row_bytes).unwrap_err(); - - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::()); - let BuiltinDeserializationErrorKind::ColumnDeserializationFailed { - column_index, - ref column_name, - ref err, - } = err.kind - else { - panic!("unexpected error kind: {:?}", err.kind) - }; - assert_eq!(column_index, 2); - assert_eq!(column_name.as_str(), "c"); - let err = value::tests::get_deser_err(err); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, ColumnType::Boolean); - assert_matches!( - err.kind, - value::BuiltinDeserializationErrorKind::ByteLengthMismatch { - expected: 1, - got: 2, - } - ); - } - } - } - - // Strict ordering - { - #[derive(scylla_macros::DeserializeRow, PartialEq, Eq, Debug)] - #[scylla(crate = "crate", enforce_order)] - struct MyRow<'a> { - a: &'a str, - #[scylla(skip)] - x: String, - b: Option, - c: bool, - } - - // Type check errors - { - // Too few columns - { - let specs = [spec("a", ColumnType::Text)]; - let err = MyRow::type_check(&specs).unwrap_err(); - let err = get_typck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_types, specs_to_types(&specs)); - let BuiltinTypeCheckErrorKind::WrongColumnCount { - rust_cols, - cql_cols, - } = err.kind - else { - panic!("unexpected error kind: {:?}", err.kind) - }; - assert_eq!(rust_cols, 3); - assert_eq!(cql_cols, 1); - } - - // Excess columns - { - let specs = [ - spec("a", ColumnType::Text), - spec("b", ColumnType::Int), - spec("c", ColumnType::Boolean), - spec("d", ColumnType::Counter), - ]; - let err = MyRow::type_check(&specs).unwrap_err(); - let err = get_typck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_types, specs_to_types(&specs)); - let BuiltinTypeCheckErrorKind::WrongColumnCount { - rust_cols, - cql_cols, - } = err.kind - else { - panic!("unexpected error kind: {:?}", err.kind) - }; - assert_eq!(rust_cols, 3); - assert_eq!(cql_cols, 4); - } - - // Renamed column name mismatch - { - let specs = [ - spec("a", ColumnType::Text), - spec("b", ColumnType::Int), - spec("d", ColumnType::Boolean), - ]; - let err = MyRow::type_check(&specs).unwrap_err(); - let err = get_typck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::()); - let BuiltinTypeCheckErrorKind::ColumnNameMismatch { - field_index, - column_index, - rust_column_name, - ref db_column_name, - } = err.kind - else { - panic!("unexpected error kind: {:?}", err.kind) - }; - assert_eq!(field_index, 3); - assert_eq!(rust_column_name, "c"); - assert_eq!(column_index, 2); - assert_eq!(db_column_name.as_str(), "d"); - } - - // Columns switched - column name mismatch - { - let specs = [ - spec("b", ColumnType::Int), - spec("a", ColumnType::Text), - spec("c", ColumnType::Boolean), - ]; - let err = MyRow::type_check(&specs).unwrap_err(); - let err = get_typck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_types, specs_to_types(&specs)); - let BuiltinTypeCheckErrorKind::ColumnNameMismatch { - field_index, - column_index, - rust_column_name, - ref db_column_name, - } = err.kind - else { - panic!("unexpected error kind: {:?}", err.kind) - }; - assert_eq!(field_index, 0); - assert_eq!(column_index, 0); - assert_eq!(rust_column_name, "a"); - assert_eq!(db_column_name.as_str(), "b"); - } - - // Column incompatible types - column type check failed - { - let specs = [ - spec("a", ColumnType::Blob), - spec("b", ColumnType::Int), - spec("c", ColumnType::Boolean), - ]; - let err = MyRow::type_check(&specs).unwrap_err(); - let err = get_typck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_types, specs_to_types(&specs)); - let BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { - column_index, - ref column_name, - ref err, - } = err.kind - else { - panic!("unexpected error kind: {:?}", err.kind) - }; - assert_eq!(column_index, 0); - assert_eq!(column_name.as_str(), "a"); - let err = value::tests::get_typeck_err_inner(err.0.as_ref()); - assert_eq!(err.rust_name, std::any::type_name::<&str>()); - assert_eq!(err.cql_type, ColumnType::Blob); - assert_matches!( - err.kind, - value::BuiltinTypeCheckErrorKind::MismatchedType { - expected: &[ColumnType::Ascii, ColumnType::Text] - } - ); - } - } - - // Deserialization errors - { - // Too few columns - { - let specs = [ - spec("a", ColumnType::Text), - spec("b", ColumnType::Int), - spec("c", ColumnType::Boolean), - ]; - - let err = MyRow::deserialize(ColumnIterator::new( - &specs, - FrameSlice::new(&serialize_cells([Some([true as u8])])), - )) - .unwrap_err(); - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::()); - let BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { - column_index, - ref column_name, - .. - } = err.kind - else { - panic!("unexpected error kind: {:?}", err.kind) - }; - assert_eq!(column_index, 1); - assert_eq!(column_name, "b"); - } - - // Bad field format - { - let typ = [ - spec("a", ColumnType::Text), - spec("b", ColumnType::Int), - spec("c", ColumnType::Boolean), - ]; - - let row_bytes = serialize_cells( - [(&b"alamakota"[..]), &42_i32.to_be_bytes(), &[true as u8]].map(Some), - ); - - let row_bytes_too_short = row_bytes.slice(..row_bytes.len() - 1); - assert!(row_bytes.len() > row_bytes_too_short.len()); - - let err = deserialize::(&typ, &row_bytes_too_short).unwrap_err(); - - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::()); - let BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { - column_index, - ref column_name, - .. - } = err.kind - else { - panic!("unexpected error kind: {:?}", err.kind) - }; - assert_eq!(column_index, 2); - assert_eq!(column_name, "c"); - } - - // Column deserialization failed - { - let specs = [ - spec("a", ColumnType::Text), - spec("b", ColumnType::Int), - spec("c", ColumnType::Boolean), - ]; - - let row_bytes = serialize_cells( - [&b"alamakota"[..], &42_i64.to_be_bytes(), &[true as u8]].map(Some), - ); - - let err = deserialize::(&specs, &row_bytes).unwrap_err(); - - let err = get_deser_err(&err); - assert_eq!(err.rust_name, std::any::type_name::()); - let BuiltinDeserializationErrorKind::ColumnDeserializationFailed { - column_index: field_index, - ref column_name, - ref err, - } = err.kind - else { - panic!("unexpected error kind: {:?}", err.kind) - }; - assert_eq!(column_name.as_str(), "b"); - assert_eq!(field_index, 2); - let err = value::tests::get_deser_err(err); - assert_eq!(err.rust_name, std::any::type_name::()); - assert_eq!(err.cql_type, ColumnType::Int); - assert_matches!( - err.kind, - value::BuiltinDeserializationErrorKind::ByteLengthMismatch { - expected: 4, - got: 8, - } - ); - } - } - } - } -} +#[path = "row_tests.rs"] +mod tests; /// ```compile_fail /// diff --git a/scylla-cql/src/types/deserialize/row_tests.rs b/scylla-cql/src/types/deserialize/row_tests.rs new file mode 100644 index 0000000000..f4c79d66a8 --- /dev/null +++ b/scylla-cql/src/types/deserialize/row_tests.rs @@ -0,0 +1,861 @@ +use assert_matches::assert_matches; +use bytes::Bytes; +use scylla_macros::DeserializeRow; + +use crate::frame::response::result::{ColumnSpec, ColumnType}; +use crate::types::deserialize::row::BuiltinDeserializationErrorKind; +use crate::types::deserialize::{value, DeserializationError, FrameSlice}; + +use super::super::tests::{serialize_cells, spec}; +use super::{BuiltinDeserializationError, ColumnIterator, CqlValue, DeserializeRow, Row}; +use super::{BuiltinTypeCheckError, BuiltinTypeCheckErrorKind}; + +#[test] +fn test_tuple_deserialization() { + // Empty tuple + deserialize::<()>(&[], &Bytes::new()).unwrap(); + + // 1-elem tuple + let (a,) = deserialize::<(i32,)>( + &[spec("i", ColumnType::Int)], + &serialize_cells([val_int(123)]), + ) + .unwrap(); + assert_eq!(a, 123); + + // 3-elem tuple + let (a, b, c) = deserialize::<(i32, i32, i32)>( + &[ + spec("i1", ColumnType::Int), + spec("i2", ColumnType::Int), + spec("i3", ColumnType::Int), + ], + &serialize_cells([val_int(123), val_int(456), val_int(789)]), + ) + .unwrap(); + assert_eq!((a, b, c), (123, 456, 789)); + + // Make sure that column type mismatch is detected + deserialize::<(i32, String, i32)>( + &[ + spec("i1", ColumnType::Int), + spec("i2", ColumnType::Int), + spec("i3", ColumnType::Int), + ], + &serialize_cells([val_int(123), val_int(456), val_int(789)]), + ) + .unwrap_err(); + + // Make sure that borrowing types compile and work correctly + let specs = &[spec("s", ColumnType::Text)]; + let byts = serialize_cells([val_str("abc")]); + let (s,) = deserialize::<(&str,)>(specs, &byts).unwrap(); + assert_eq!(s, "abc"); +} + +#[test] +fn test_deserialization_as_column_iterator() { + let col_specs = [ + spec("i1", ColumnType::Int), + spec("i2", ColumnType::Text), + spec("i3", ColumnType::Counter), + ]; + let serialized_values = serialize_cells([val_int(123), val_str("ScyllaDB"), None]); + let mut iter = deserialize::(&col_specs, &serialized_values).unwrap(); + + let col1 = iter.next().unwrap().unwrap(); + assert_eq!(col1.spec.name, "i1"); + assert_eq!(col1.spec.typ, ColumnType::Int); + assert_eq!(col1.slice.unwrap().as_slice(), &123i32.to_be_bytes()); + + let col2 = iter.next().unwrap().unwrap(); + assert_eq!(col2.spec.name, "i2"); + assert_eq!(col2.spec.typ, ColumnType::Text); + assert_eq!(col2.slice.unwrap().as_slice(), "ScyllaDB".as_bytes()); + + let col3 = iter.next().unwrap().unwrap(); + assert_eq!(col3.spec.name, "i3"); + assert_eq!(col3.spec.typ, ColumnType::Counter); + assert!(col3.slice.is_none()); + + assert!(iter.next().is_none()); +} + +// Do not remove. It's not used in tests but we keep it here to check that +// we properly ignore warnings about unused variables, unnecessary `mut`s +// etc. that usually pop up when generating code for empty structs. +#[allow(unused)] +#[derive(DeserializeRow)] +#[scylla(crate = crate)] +struct TestUdtWithNoFieldsUnordered {} + +#[allow(unused)] +#[derive(DeserializeRow)] +#[scylla(crate = crate, enforce_order)] +struct TestUdtWithNoFieldsOrdered {} + +#[test] +fn test_struct_deserialization_loose_ordering() { + #[derive(DeserializeRow, PartialEq, Eq, Debug)] + #[scylla(crate = "crate")] + struct MyRow<'a> { + a: &'a str, + b: Option, + #[scylla(skip)] + c: String, + } + + // Original order of columns + let specs = &[spec("a", ColumnType::Text), spec("b", ColumnType::Int)]; + let byts = serialize_cells([val_str("abc"), val_int(123)]); + let row = deserialize::>(specs, &byts).unwrap(); + assert_eq!( + row, + MyRow { + a: "abc", + b: Some(123), + c: String::new(), + } + ); + + // Different order of columns - should still work + let specs = &[spec("b", ColumnType::Int), spec("a", ColumnType::Text)]; + let byts = serialize_cells([val_int(123), val_str("abc")]); + let row = deserialize::>(specs, &byts).unwrap(); + assert_eq!( + row, + MyRow { + a: "abc", + b: Some(123), + c: String::new(), + } + ); + + // Missing column + let specs = &[spec("a", ColumnType::Text)]; + MyRow::type_check(specs).unwrap_err(); + + // Wrong column type + let specs = &[spec("a", ColumnType::Int), spec("b", ColumnType::Int)]; + MyRow::type_check(specs).unwrap_err(); +} + +#[test] +fn test_struct_deserialization_strict_ordering() { + #[derive(DeserializeRow, PartialEq, Eq, Debug)] + #[scylla(crate = "crate", enforce_order)] + struct MyRow<'a> { + a: &'a str, + b: Option, + #[scylla(skip)] + c: String, + } + + // Correct order of columns + let specs = &[spec("a", ColumnType::Text), spec("b", ColumnType::Int)]; + let byts = serialize_cells([val_str("abc"), val_int(123)]); + let row = deserialize::>(specs, &byts).unwrap(); + assert_eq!( + row, + MyRow { + a: "abc", + b: Some(123), + c: String::new(), + } + ); + + // Wrong order of columns + let specs = &[spec("b", ColumnType::Int), spec("a", ColumnType::Text)]; + MyRow::type_check(specs).unwrap_err(); + + // Missing column + let specs = &[spec("a", ColumnType::Text)]; + MyRow::type_check(specs).unwrap_err(); + + // Wrong column type + let specs = &[spec("a", ColumnType::Int), spec("b", ColumnType::Int)]; + MyRow::type_check(specs).unwrap_err(); +} + +#[test] +fn test_struct_deserialization_no_name_check() { + #[derive(DeserializeRow, PartialEq, Eq, Debug)] + #[scylla(crate = "crate", enforce_order, skip_name_checks)] + struct MyRow<'a> { + a: &'a str, + b: Option, + #[scylla(skip)] + c: String, + } + + // Correct order of columns + let specs = &[spec("a", ColumnType::Text), spec("b", ColumnType::Int)]; + let byts = serialize_cells([val_str("abc"), val_int(123)]); + let row = deserialize::>(specs, &byts).unwrap(); + assert_eq!( + row, + MyRow { + a: "abc", + b: Some(123), + c: String::new(), + } + ); + + // Correct order of columns, but different names - should still succeed + let specs = &[spec("z", ColumnType::Text), spec("x", ColumnType::Int)]; + let byts = serialize_cells([val_str("abc"), val_int(123)]); + let row = deserialize::>(specs, &byts).unwrap(); + assert_eq!( + row, + MyRow { + a: "abc", + b: Some(123), + c: String::new(), + } + ); +} + +#[test] +fn test_struct_deserialization_cross_rename_fields() { + #[derive(scylla_macros::DeserializeRow, PartialEq, Eq, Debug)] + #[scylla(crate = crate)] + struct TestRow { + #[scylla(rename = "b")] + a: i32, + #[scylla(rename = "a")] + b: String, + } + + // Columns switched wrt fields - should still work. + { + let row_bytes = + serialize_cells(["The quick brown fox".as_bytes(), &42_i32.to_be_bytes()].map(Some)); + let specs = [spec("a", ColumnType::Text), spec("b", ColumnType::Int)]; + + let row = deserialize::(&specs, &row_bytes).unwrap(); + assert_eq!( + row, + TestRow { + a: 42, + b: "The quick brown fox".to_owned(), + } + ); + } +} + +fn val_int(i: i32) -> Option> { + Some(i.to_be_bytes().to_vec()) +} + +fn val_str(s: &str) -> Option> { + Some(s.as_bytes().to_vec()) +} + +fn deserialize<'frame, R>( + specs: &'frame [ColumnSpec], + byts: &'frame Bytes, +) -> Result +where + R: DeserializeRow<'frame>, +{ + >::type_check(specs) + .map_err(|typecheck_err| DeserializationError(typecheck_err.0))?; + let slice = FrameSlice::new(byts); + let iter = ColumnIterator::new(specs, slice); + >::deserialize(iter) +} + +#[track_caller] +pub(crate) fn get_typck_err_inner<'a>( + err: &'a (dyn std::error::Error + 'static), +) -> &'a BuiltinTypeCheckError { + match err.downcast_ref() { + Some(err) => err, + None => panic!("not a BuiltinTypeCheckError: {:?}", err), + } +} + +#[track_caller] +fn get_typck_err(err: &DeserializationError) -> &BuiltinTypeCheckError { + get_typck_err_inner(err.0.as_ref()) +} + +#[track_caller] +fn get_deser_err(err: &DeserializationError) -> &BuiltinDeserializationError { + match err.0.downcast_ref() { + Some(err) => err, + None => panic!("not a BuiltinDeserializationError: {:?}", err), + } +} + +#[test] +fn test_tuple_errors() { + // Column type check failure + { + let col_name: &str = "i"; + let specs = &[spec(col_name, ColumnType::Int)]; + let err = deserialize::<(i64,)>(specs, &serialize_cells([val_int(123)])).unwrap_err(); + let err = get_typck_err(&err); + assert_eq!(err.rust_name, std::any::type_name::<(i64,)>()); + assert_eq!( + err.cql_types, + specs + .iter() + .map(|spec| spec.typ.clone()) + .collect::>() + ); + let BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { + column_index, + column_name, + err, + } = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + assert_eq!(*column_index, 0); + assert_eq!(column_name, col_name); + let err = super::super::value::tests::get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::Int); + assert_matches!( + &err.kind, + super::super::value::BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::BigInt, ColumnType::Counter] + } + ); + } + + // Column deserialization failure + { + let col_name: &str = "i"; + let err = deserialize::<(i64,)>( + &[spec(col_name, ColumnType::BigInt)], + &serialize_cells([val_int(123)]), + ) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::<(i64,)>()); + let BuiltinDeserializationErrorKind::ColumnDeserializationFailed { + column_name, err, .. + } = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + assert_eq!(column_name, col_name); + let err = super::super::value::tests::get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::BigInt); + assert_matches!( + err.kind, + super::super::value::BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 8, + got: 4 + } + ); + } + + // Raw column deserialization failure + { + let col_name: &str = "i"; + let err = deserialize::<(i64,)>( + &[spec(col_name, ColumnType::BigInt)], + &Bytes::from_static(b"alamakota"), + ) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::<(i64,)>()); + let BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { + column_index: _column_index, + column_name, + err: _err, + } = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + assert_eq!(column_name, col_name); + } +} + +#[test] +fn test_row_errors() { + // Column type check failure - happens never, because Row consists of CqlValues, + // which accept all CQL types. + + // Column deserialization failure + { + let col_name: &str = "i"; + let err = deserialize::( + &[spec(col_name, ColumnType::BigInt)], + &serialize_cells([val_int(123)]), + ) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + let BuiltinDeserializationErrorKind::ColumnDeserializationFailed { + column_index: _column_index, + column_name, + err, + } = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + assert_eq!(column_name, col_name); + let err = super::super::value::tests::get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::BigInt); + let super::super::value::BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 8, + got: 4, + } = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + } + + // Raw column deserialization failure + { + let col_name: &str = "i"; + let err = deserialize::( + &[spec(col_name, ColumnType::BigInt)], + &Bytes::from_static(b"alamakota"), + ) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + let BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { + column_index: _column_index, + column_name, + err: _err, + } = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + assert_eq!(column_name, col_name); + } +} + +fn specs_to_types(specs: &[ColumnSpec]) -> Vec { + specs.iter().map(|spec| spec.typ.clone()).collect() +} + +#[test] +fn test_struct_deserialization_errors() { + // Loose ordering + { + #[derive(scylla_macros::DeserializeRow, PartialEq, Eq, Debug)] + #[scylla(crate = "crate")] + struct MyRow<'a> { + a: &'a str, + #[scylla(skip)] + x: String, + b: Option, + #[scylla(rename = "c")] + d: bool, + } + + // Type check errors + { + // Missing column + { + let specs = [spec("a", ColumnType::Ascii), spec("b", ColumnType::Int)]; + let err = MyRow::type_check(&specs).unwrap_err(); + let err = get_typck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_types, specs_to_types(&specs)); + let BuiltinTypeCheckErrorKind::ValuesMissingForColumns { + column_names: ref missing_fields, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(missing_fields.as_slice(), &["c"]); + } + + // Duplicated column + { + let specs = [ + spec("a", ColumnType::Ascii), + spec("b", ColumnType::Int), + spec("a", ColumnType::Ascii), + ]; + + let err = MyRow::type_check(&specs).unwrap_err(); + let err = get_typck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_types, specs_to_types(&specs)); + let BuiltinTypeCheckErrorKind::DuplicatedColumn { + column_index, + column_name, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(column_index, 2); + assert_eq!(column_name, "a"); + } + + // Unknown column + { + let specs = [ + spec("d", ColumnType::Counter), + spec("a", ColumnType::Ascii), + spec("b", ColumnType::Int), + ]; + + let err = MyRow::type_check(&specs).unwrap_err(); + let err = get_typck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_types, specs_to_types(&specs)); + let BuiltinTypeCheckErrorKind::ColumnWithUnknownName { + column_index, + ref column_name, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(column_index, 0); + assert_eq!(column_name.as_str(), "d"); + } + + // Column incompatible types - column type check failed + { + let specs = [spec("b", ColumnType::Int), spec("a", ColumnType::Blob)]; + let err = MyRow::type_check(&specs).unwrap_err(); + let err = get_typck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_types, specs_to_types(&specs)); + let BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { + column_index, + ref column_name, + ref err, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(column_index, 1); + assert_eq!(column_name.as_str(), "a"); + let err = value::tests::get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::<&str>()); + assert_eq!(err.cql_type, ColumnType::Blob); + assert_matches!( + err.kind, + value::BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::Ascii, ColumnType::Text] + } + ); + } + } + + // Deserialization errors + { + // Got null + { + let specs = [ + spec("c", ColumnType::Boolean), + spec("a", ColumnType::Blob), + spec("b", ColumnType::Int), + ]; + + let err = MyRow::deserialize(ColumnIterator::new( + &specs, + FrameSlice::new(&serialize_cells([Some([true as u8])])), + )) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + let BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { + column_index, + ref column_name, + .. + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(column_index, 1); + assert_eq!(column_name, "a"); + } + + // Column deserialization failed + { + let specs = [ + spec("b", ColumnType::Int), + spec("a", ColumnType::Ascii), + spec("c", ColumnType::Boolean), + ]; + + let row_bytes = serialize_cells( + [ + &0_i32.to_be_bytes(), + "alamakota".as_bytes(), + &42_i16.to_be_bytes(), + ] + .map(Some), + ); + + let err = deserialize::(&specs, &row_bytes).unwrap_err(); + + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + let BuiltinDeserializationErrorKind::ColumnDeserializationFailed { + column_index, + ref column_name, + ref err, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(column_index, 2); + assert_eq!(column_name.as_str(), "c"); + let err = value::tests::get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::Boolean); + assert_matches!( + err.kind, + value::BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 1, + got: 2, + } + ); + } + } + } + + // Strict ordering + { + #[derive(scylla_macros::DeserializeRow, PartialEq, Eq, Debug)] + #[scylla(crate = "crate", enforce_order)] + struct MyRow<'a> { + a: &'a str, + #[scylla(skip)] + x: String, + b: Option, + c: bool, + } + + // Type check errors + { + // Too few columns + { + let specs = [spec("a", ColumnType::Text)]; + let err = MyRow::type_check(&specs).unwrap_err(); + let err = get_typck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_types, specs_to_types(&specs)); + let BuiltinTypeCheckErrorKind::WrongColumnCount { + rust_cols, + cql_cols, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(rust_cols, 3); + assert_eq!(cql_cols, 1); + } + + // Excess columns + { + let specs = [ + spec("a", ColumnType::Text), + spec("b", ColumnType::Int), + spec("c", ColumnType::Boolean), + spec("d", ColumnType::Counter), + ]; + let err = MyRow::type_check(&specs).unwrap_err(); + let err = get_typck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_types, specs_to_types(&specs)); + let BuiltinTypeCheckErrorKind::WrongColumnCount { + rust_cols, + cql_cols, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(rust_cols, 3); + assert_eq!(cql_cols, 4); + } + + // Renamed column name mismatch + { + let specs = [ + spec("a", ColumnType::Text), + spec("b", ColumnType::Int), + spec("d", ColumnType::Boolean), + ]; + let err = MyRow::type_check(&specs).unwrap_err(); + let err = get_typck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + let BuiltinTypeCheckErrorKind::ColumnNameMismatch { + field_index, + column_index, + rust_column_name, + ref db_column_name, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(field_index, 3); + assert_eq!(rust_column_name, "c"); + assert_eq!(column_index, 2); + assert_eq!(db_column_name.as_str(), "d"); + } + + // Columns switched - column name mismatch + { + let specs = [ + spec("b", ColumnType::Int), + spec("a", ColumnType::Text), + spec("c", ColumnType::Boolean), + ]; + let err = MyRow::type_check(&specs).unwrap_err(); + let err = get_typck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_types, specs_to_types(&specs)); + let BuiltinTypeCheckErrorKind::ColumnNameMismatch { + field_index, + column_index, + rust_column_name, + ref db_column_name, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(field_index, 0); + assert_eq!(column_index, 0); + assert_eq!(rust_column_name, "a"); + assert_eq!(db_column_name.as_str(), "b"); + } + + // Column incompatible types - column type check failed + { + let specs = [ + spec("a", ColumnType::Blob), + spec("b", ColumnType::Int), + spec("c", ColumnType::Boolean), + ]; + let err = MyRow::type_check(&specs).unwrap_err(); + let err = get_typck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_types, specs_to_types(&specs)); + let BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { + column_index, + ref column_name, + ref err, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(column_index, 0); + assert_eq!(column_name.as_str(), "a"); + let err = value::tests::get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::<&str>()); + assert_eq!(err.cql_type, ColumnType::Blob); + assert_matches!( + err.kind, + value::BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::Ascii, ColumnType::Text] + } + ); + } + } + + // Deserialization errors + { + // Too few columns + { + let specs = [ + spec("a", ColumnType::Text), + spec("b", ColumnType::Int), + spec("c", ColumnType::Boolean), + ]; + + let err = MyRow::deserialize(ColumnIterator::new( + &specs, + FrameSlice::new(&serialize_cells([Some([true as u8])])), + )) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + let BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { + column_index, + ref column_name, + .. + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(column_index, 1); + assert_eq!(column_name, "b"); + } + + // Bad field format + { + let typ = [ + spec("a", ColumnType::Text), + spec("b", ColumnType::Int), + spec("c", ColumnType::Boolean), + ]; + + let row_bytes = serialize_cells( + [(&b"alamakota"[..]), &42_i32.to_be_bytes(), &[true as u8]].map(Some), + ); + + let row_bytes_too_short = row_bytes.slice(..row_bytes.len() - 1); + assert!(row_bytes.len() > row_bytes_too_short.len()); + + let err = deserialize::(&typ, &row_bytes_too_short).unwrap_err(); + + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + let BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { + column_index, + ref column_name, + .. + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(column_index, 2); + assert_eq!(column_name, "c"); + } + + // Column deserialization failed + { + let specs = [ + spec("a", ColumnType::Text), + spec("b", ColumnType::Int), + spec("c", ColumnType::Boolean), + ]; + + let row_bytes = serialize_cells( + [&b"alamakota"[..], &42_i64.to_be_bytes(), &[true as u8]].map(Some), + ); + + let err = deserialize::(&specs, &row_bytes).unwrap_err(); + + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + let BuiltinDeserializationErrorKind::ColumnDeserializationFailed { + column_index: field_index, + ref column_name, + ref err, + } = err.kind + else { + panic!("unexpected error kind: {:?}", err.kind) + }; + assert_eq!(column_name.as_str(), "b"); + assert_eq!(field_index, 2); + let err = value::tests::get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::Int); + assert_matches!( + err.kind, + value::BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 4, + got: 8, + } + ); + } + } + } +}