diff --git a/scylla-cql/Cargo.toml b/scylla-cql/Cargo.toml index 042564fd25..547404cae7 100644 --- a/scylla-cql/Cargo.toml +++ b/scylla-cql/Cargo.toml @@ -33,6 +33,7 @@ assert_matches = "1.5.0" criterion = "0.4" # Note: v0.5 needs at least rust 1.70.0 # Use large-dates feature to test potential edge cases time = { version = "0.3.21", features = ["large-dates"] } +uuid = { version = "1.0", features = ["v4"] } [[bench]] name = "benchmark" diff --git a/scylla-cql/src/frame/frame_errors.rs b/scylla-cql/src/frame/frame_errors.rs index a203762c53..68757331fc 100644 --- a/scylla-cql/src/frame/frame_errors.rs +++ b/scylla-cql/src/frame/frame_errors.rs @@ -1,6 +1,7 @@ use super::TryFromPrimitiveError; use crate::cql_to_rust::CqlTypeError; use crate::frame::value::SerializeValuesError; +use crate::types::deserialize::DeserializationError; use crate::types::serialize::SerializationError; use thiserror::Error; @@ -39,6 +40,8 @@ pub enum ParseError { #[error("Could not deserialize frame: {0}")] BadIncomingData(String), #[error(transparent)] + DeserializationError(#[from] DeserializationError), + #[error(transparent)] IoError(#[from] std::io::Error), #[error("type not yet implemented, id: {0}")] TypeNotImplemented(u16), diff --git a/scylla-cql/src/frame/response/result.rs b/scylla-cql/src/frame/response/result.rs index 527d481eb2..f961f4b99e 100644 --- a/scylla-cql/src/frame/response/result.rs +++ b/scylla-cql/src/frame/response/result.rs @@ -1,19 +1,15 @@ use crate::cql_to_rust::{FromRow, FromRowError}; use crate::frame::response::event::SchemaChangeEvent; -use crate::frame::types::vint_decode; use crate::frame::value::{ Counter, CqlDate, CqlDecimal, CqlDuration, CqlTime, CqlTimestamp, CqlTimeuuid, CqlVarint, }; use crate::frame::{frame_errors::ParseError, types}; -use byteorder::{BigEndian, ReadBytesExt}; +use crate::types::deserialize::result::{RowIterator, TypedRowIterator}; +use crate::types::deserialize::value::{DeserializeValue, MapIterator, UdtIterator}; +use 
crate::types::deserialize::{DeserializationError, FrameSlice}; use bytes::{Buf, Bytes}; use std::borrow::Cow; -use std::{ - convert::{TryFrom, TryInto}, - net::IpAddr, - result::Result as StdResult, - str, -}; +use std::{convert::TryInto, net::IpAddr, result::Result as StdResult, str}; use uuid::Uuid; #[cfg(feature = "chrono")] @@ -655,6 +651,11 @@ pub fn deser_cql_value(typ: &ColumnType, buf: &mut &[u8]) -> StdResult return Ok(CqlValue::Empty), } } + // The `new_borrowed` version of FrameSlice is deficient in that it does not hold + // a `Bytes` reference to the frame, only a slice. + // This is not a problem here, fortunately, because none of CqlValue variants contain + // any `Bytes` - only exclusively owned types - so we never call FrameSlice::to_bytes(). + let v = Some(FrameSlice::new_borrowed(buf)); Ok(match typ { Custom(type_str) => { @@ -664,239 +665,112 @@ pub fn deser_cql_value(typ: &ColumnType, buf: &mut &[u8]) -> StdResult { - if !buf.is_ascii() { - return Err(ParseError::BadIncomingData( - "String is not ascii!".to_string(), - )); - } - CqlValue::Ascii(str::from_utf8(buf)?.to_owned()) + let s = String::deserialize(typ, v)?; + CqlValue::Ascii(s) } Boolean => { - if buf.len() != 1 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 1 not {}", - buf.len() - ))); - } - CqlValue::Boolean(buf[0] != 0x00) + let b = bool::deserialize(typ, v)?; + CqlValue::Boolean(b) + } + Blob => { + let b = Vec::::deserialize(typ, v)?; + CqlValue::Blob(b) } - Blob => CqlValue::Blob(buf.to_vec()), Date => { - if buf.len() != 4 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 4 not {}", - buf.len() - ))); - } - - let date_value = buf.read_u32::()?; - CqlValue::Date(CqlDate(date_value)) + let d = CqlDate::deserialize(typ, v)?; + CqlValue::Date(d) } Counter => { - if buf.len() != 8 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 8 not {}", - buf.len() - ))); - } - 
CqlValue::Counter(crate::frame::value::Counter(buf.read_i64::()?)) + let c = crate::frame::response::result::Counter::deserialize(typ, v)?; + CqlValue::Counter(c) } Decimal => { - let scale = types::read_int(buf)?; - let bytes = buf.to_vec(); - let big_decimal: CqlDecimal = - CqlDecimal::from_signed_be_bytes_and_exponent(bytes, scale); - - CqlValue::Decimal(big_decimal) + let d = CqlDecimal::deserialize(typ, v)?; + CqlValue::Decimal(d) } Double => { - if buf.len() != 8 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 8 not {}", - buf.len() - ))); - } - CqlValue::Double(buf.read_f64::()?) + let d = f64::deserialize(typ, v)?; + CqlValue::Double(d) } Float => { - if buf.len() != 4 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 4 not {}", - buf.len() - ))); - } - CqlValue::Float(buf.read_f32::()?) + let f = f32::deserialize(typ, v)?; + CqlValue::Float(f) } Int => { - if buf.len() != 4 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 4 not {}", - buf.len() - ))); - } - CqlValue::Int(buf.read_i32::()?) + let i = i32::deserialize(typ, v)?; + CqlValue::Int(i) } SmallInt => { - if buf.len() != 2 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 2 not {}", - buf.len() - ))); - } - - CqlValue::SmallInt(buf.read_i16::()?) + let si = i16::deserialize(typ, v)?; + CqlValue::SmallInt(si) } TinyInt => { - if buf.len() != 1 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 1 not {}", - buf.len() - ))); - } - CqlValue::TinyInt(buf.read_i8()?) + let ti = i8::deserialize(typ, v)?; + CqlValue::TinyInt(ti) } BigInt => { - if buf.len() != 8 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 8 not {}", - buf.len() - ))); - } - CqlValue::BigInt(buf.read_i64::()?) 
+ let bi = i64::deserialize(typ, v)?; + CqlValue::BigInt(bi) + } + Text => { + let s = String::deserialize(typ, v)?; + CqlValue::Text(s) } - Text => CqlValue::Text(str::from_utf8(buf)?.to_owned()), Timestamp => { - if buf.len() != 8 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 8 not {}", - buf.len() - ))); - } - let millis = buf.read_i64::()?; - - CqlValue::Timestamp(CqlTimestamp(millis)) + let t = CqlTimestamp::deserialize(typ, v)?; + CqlValue::Timestamp(t) } Time => { - if buf.len() != 8 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 8 not {}", - buf.len() - ))); - } - let nanoseconds: i64 = buf.read_i64::()?; - - // Valid values are in the range 0 to 86399999999999 - if !(0..=86399999999999).contains(&nanoseconds) { - return Err(ParseError::BadIncomingData(format! { - "Invalid time value only 0 to 86399999999999 allowed: {}.", nanoseconds - })); - } - - CqlValue::Time(CqlTime(nanoseconds)) + let t = CqlTime::deserialize(typ, v)?; + CqlValue::Time(t) } Timeuuid => { - if buf.len() != 16 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 16 not {}", - buf.len() - ))); - } - let uuid = uuid::Uuid::from_slice(buf).expect("Deserializing Uuid failed."); - CqlValue::Timeuuid(CqlTimeuuid::from(uuid)) + let t = CqlTimeuuid::deserialize(typ, v)?; + CqlValue::Timeuuid(t) } Duration => { - let months = i32::try_from(vint_decode(buf)?)?; - let days = i32::try_from(vint_decode(buf)?)?; - let nanoseconds = vint_decode(buf)?; - - CqlValue::Duration(CqlDuration { - months, - days, - nanoseconds, - }) + let d = CqlDuration::deserialize(typ, v)?; + CqlValue::Duration(d) + } + Inet => { + let i = IpAddr::deserialize(typ, v)?; + CqlValue::Inet(i) } - Inet => CqlValue::Inet(match buf.len() { - 4 => { - let ret = IpAddr::from(<[u8; 4]>::try_from(&buf[0..4])?); - buf.advance(4); - ret - } - 16 => { - let ret = IpAddr::from(<[u8; 16]>::try_from(&buf[0..16])?); - buf.advance(16); - ret - } - v 
=> { - return Err(ParseError::BadIncomingData(format!( - "Invalid inet bytes length: {}", - v - ))); - } - }), Uuid => { - if buf.len() != 16 { - return Err(ParseError::BadIncomingData(format!( - "Buffer length should be 16 not {}", - buf.len() - ))); - } - let uuid = uuid::Uuid::from_slice(buf).expect("Deserializing Uuid failed."); + let uuid = uuid::Uuid::deserialize(typ, v)?; CqlValue::Uuid(uuid) } - Varint => CqlValue::Varint(CqlVarint::from_signed_bytes_be(buf.to_vec())), - List(type_name) => { - let len: usize = types::read_int(buf)?.try_into()?; - let mut res = Vec::with_capacity(len); - for _ in 0..len { - let mut b = types::read_bytes(buf)?; - res.push(deser_cql_value(type_name, &mut b)?); - } - CqlValue::List(res) + Varint => { + let vi = CqlVarint::deserialize(typ, v)?; + CqlValue::Varint(vi) } - Map(key_type, value_type) => { - let len: usize = types::read_int(buf)?.try_into()?; - let mut res = Vec::with_capacity(len); - for _ in 0..len { - let mut b = types::read_bytes(buf)?; - let key = deser_cql_value(key_type, &mut b)?; - b = types::read_bytes(buf)?; - let val = deser_cql_value(value_type, &mut b)?; - res.push((key, val)); - } - CqlValue::Map(res) + List(_type_name) => { + let l = Vec::::deserialize(typ, v)?; + CqlValue::List(l) } - Set(type_name) => { - let len: usize = types::read_int(buf)?.try_into()?; - let mut res = Vec::with_capacity(len); - for _ in 0..len { - // TODO: is `null` allowed as set element? Should we use read_bytes_opt? - let mut b = types::read_bytes(buf)?; - res.push(deser_cql_value(type_name, &mut b)?); - } - CqlValue::Set(res) + Map(_key_type, _value_type) => { + let iter = MapIterator::<'_, CqlValue, CqlValue>::deserialize(typ, v)?; + let m: Vec<(CqlValue, CqlValue)> = iter.collect::>()?; + CqlValue::Map(m) + } + Set(_type_name) => { + let s = Vec::::deserialize(typ, v)?; + CqlValue::Set(s) } UserDefinedType { type_name, keyspace, - field_types, + .. 
} => { - let mut fields: Vec<(String, Option)> = Vec::new(); - - for (field_name, field_type) in field_types { - // If a field is added to a UDT and we read an old (frozen ?) version of it, - // the driver will fail to parse the whole UDT. - // This is why we break the parsing after we reach the end of the serialized UDT. - if buf.is_empty() { - break; - } - - let mut field_value: Option = None; - if let Some(mut field_val_bytes) = types::read_bytes_opt(buf)? { - field_value = Some(deser_cql_value(field_type, &mut field_val_bytes)?); - } - - fields.push((field_name.clone(), field_value)); - } + let iter = UdtIterator::deserialize(typ, v)?; + let fields: Vec<(String, Option)> = iter + .map(|((col_name, col_type), res)| { + res.and_then(|v| { + let val = Option::::deserialize(col_type, v.flatten())?; + Ok((col_name.clone(), val)) + }) + }) + .collect::>()?; CqlValue::UserDefinedType { keyspace: keyspace.clone(), @@ -905,15 +779,19 @@ pub fn deser_cql_value(typ: &ColumnType, buf: &mut &[u8]) -> StdResult { - let mut res = Vec::with_capacity(type_names.len()); - for type_name in type_names { - match types::read_bytes_opt(buf)? { - Some(mut b) => res.push(Some(deser_cql_value(type_name, &mut b)?)), - None => res.push(None), - }; - } - - CqlValue::Tuple(res) + let t = type_names + .iter() + .map(|typ| { + types::read_bytes_opt(buf).and_then(|v| { + v.map(|v| { + CqlValue::deserialize(typ, Some(FrameSlice::new_borrowed(v))) + .map_err(Into::into) + }) + .transpose() + }) + }) + .collect::>()?; + CqlValue::Tuple(t) } }) } @@ -943,19 +821,16 @@ fn deser_rows( let rows_count: usize = types::read_int(buf)?.try_into()?; - let mut rows = Vec::with_capacity(rows_count); - for _ in 0..rows_count { - let mut columns = Vec::with_capacity(metadata.col_count); - for i in 0..metadata.col_count { - let v = if let Some(mut b) = types::read_bytes_opt(buf)? { - Some(deser_cql_value(&metadata.col_specs[i].typ, &mut b)?) 
- } else { - None - }; - columns.push(v); - } - rows.push(Row { columns }); - } + let raw_rows_iter = RowIterator::new( + rows_count, + &metadata.col_specs, + FrameSlice::new_borrowed(buf), + ); + let rows_iter = TypedRowIterator::::new(raw_rows_iter) + .map_err(|err| DeserializationError::new(err.0))?; + + let rows = rows_iter.collect::>()?; + Ok(Rows { metadata, rows_count, diff --git a/scylla-cql/src/types/deserialize/frame_slice.rs b/scylla-cql/src/types/deserialize/frame_slice.rs new file mode 100644 index 0000000000..cfc98d5ce5 --- /dev/null +++ b/scylla-cql/src/types/deserialize/frame_slice.rs @@ -0,0 +1,213 @@ +use bytes::Bytes; + +use crate::frame::frame_errors::ParseError; +use crate::frame::types; + +/// A reference to a part of the frame. +// +// # Design justification +// +// ## Why we need a borrowed type +// +// The reason why we need to store a borrowed slice is that we want a lifetime that is longer than one obtained +// when coercing Bytes to a slice in the body of a function. That is, we want to allow deserializing types +// that borrow from the frame, which resides in QueryResult. +// Consider a function with the signature: +// +// fn fun(b: Bytes) { ... } +// +// This function cannot return a type that borrows from the frame, because any slice created from `b` +// inside `fun` cannot escape `fun`. +// Conversely, if a function has signature: +// +// fn fun(s: &'frame [u8]) { ... } +// +// then it can happily return types with lifetime 'frame. +// +// ## Why we need the full frame +// +// We don't. We only need to be able to return Bytes encompassing our subslice. However, the design choice +// was made to only store a reference to the original Bytes object residing in QueryResult, so that we avoid +// cloning Bytes when performing subslicing on FrameSlice. We delay the Bytes cloning, normally a moderately +// expensive operation involving cloning an Arc, up until it is really needed. 
+// +// ## Why not a different design +// +// - why not a &'frame [u8] only? Because we want to enable deserializing types containing owned Bytes, too. +// - why not a Bytes only? Because we need to propagate the 'frame lifetime. +// - why not a &'frame Bytes only? Because we want to somehow represent subslices, and subslicing +// &'frame Bytes returns Bytes, not &'frame Bytes. +#[derive(Clone, Copy, Debug)] +pub struct FrameSlice<'frame> { + // The actual subslice represented by this FrameSlice. + frame_subslice: &'frame [u8], + + // Reference to the original Bytes object that this FrameSlice is derived + // from. It is used to convert the `frame_subslice` slice into a full-blown Bytes + // object via the Bytes::slice_ref method. + original_frame: &'frame Bytes, +} + +static EMPTY_BYTES: Bytes = Bytes::new(); + +impl<'frame> FrameSlice<'frame> { + /// Creates a new FrameSlice from a reference to a Bytes object. + /// + /// This method is exposed to allow writing deserialization tests + /// for custom types. + #[inline] + pub fn new(frame: &'frame Bytes) -> Self { + Self { + frame_subslice: frame, + original_frame: frame, + } + } + + /// Creates an empty FrameSlice. + #[inline] + pub fn new_empty() -> Self { + Self { + frame_subslice: &EMPTY_BYTES, + original_frame: &EMPTY_BYTES, + } + } + + /// Creates a new FrameSlice from a reference to a slice. + /// + /// This method creates a not-fully-valid FrameSlice that does not hold + /// the valid original frame Bytes. Thus, it is intended to be used in + /// legacy code that does not operate on Bytes, but rather on a borrowed slice only. + /// For correctness in the unlikely case that someone calls `to_bytes()` on such + /// a deficient slice, a special treatment is added there that copies + /// the slice into a new-allocation-based Bytes. + /// This is pub(crate) for the above reason.
+ #[inline] + pub(crate) fn new_borrowed(frame_subslice: &'frame [u8]) -> Self { + Self { + frame_subslice, + original_frame: &EMPTY_BYTES, + } + } + + /// Returns `true` if the slice has a length of 0. + #[inline] + pub fn is_empty(&self) -> bool { + self.frame_subslice.is_empty() + } + + /// Returns the subslice. + #[inline] + pub fn as_slice(&self) -> &'frame [u8] { + self.frame_subslice + } + + /// Returns a mutable reference to the subslice. + #[inline] + pub fn as_slice_mut(&mut self) -> &mut &'frame [u8] { + &mut self.frame_subslice + } + + /// Returns a reference to the Bytes object which encompasses the whole frame slice. + /// + /// The Bytes object will usually be larger than the slice returned by + /// [FrameSlice::as_slice]. If you wish to obtain a new Bytes object that + /// points only to the subslice represented by the FrameSlice object, + /// see [FrameSlice::to_bytes]. + #[inline] + pub fn as_original_frame_bytes(&self) -> &'frame Bytes { + self.original_frame + } + + /// Returns a new Bytes object which is a subslice of the original Bytes + /// frame slice object. + #[inline] + pub fn to_bytes(&self) -> Bytes { + if self.original_frame.is_empty() { + // Special case for the borrowed, deficient version of FrameSlice - the one created + // with FrameSlice::new_borrowed, whose original frame is the empty EMPTY_BYTES. + // If someone calls FrameSlice::to_bytes on it (even though it's not intended for the + // borrowed version), new Bytes are created by copying the slice into a new allocation. + // Note that this is not expected to ever happen. + return Bytes::copy_from_slice(self.as_slice()); + } + + self.original_frame.slice_ref(self.frame_subslice) + } + + /// Reads and consumes a `[bytes]` item from the beginning of the frame, + /// returning a subslice that encompasses that item. + /// + /// If the operation fails then the slice remains unchanged.
+ #[inline] + pub(super) fn read_cql_bytes(&mut self) -> Result>, ParseError> { + // We copy the slice reference, not to mutate the FrameSlice in case of an error. + let mut slice = self.frame_subslice; + + let cql_bytes = types::read_bytes_opt(&mut slice)?; + + // `read_bytes_opt` hasn't failed, so now we must update the FrameSlice. + self.frame_subslice = slice; + + Ok(cql_bytes.map(|slice| Self { + frame_subslice: slice, + original_frame: self.original_frame, + })) + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + + use super::super::tests::{serialize_cells, CELL1, CELL2}; + use super::FrameSlice; + + #[test] + fn test_cql_bytes_consumption() { + let frame = serialize_cells([Some(CELL1), None, Some(CELL2)]); + let mut slice = FrameSlice::new(&frame); + assert!(!slice.is_empty()); + + assert_eq!( + slice.read_cql_bytes().unwrap().map(|s| s.as_slice()), + Some(CELL1) + ); + assert!(!slice.is_empty()); + assert!(slice.read_cql_bytes().unwrap().is_none()); + assert!(!slice.is_empty()); + assert_eq!( + slice.read_cql_bytes().unwrap().map(|s| s.as_slice()), + Some(CELL2) + ); + assert!(slice.is_empty()); + slice.read_cql_bytes().unwrap_err(); + assert!(slice.is_empty()); + } + + #[test] + fn test_cql_bytes_owned() { + let frame = serialize_cells([Some(CELL1), Some(CELL2)]); + let mut slice = FrameSlice::new(&frame); + + let subslice1 = slice.read_cql_bytes().unwrap().unwrap(); + let subslice2 = slice.read_cql_bytes().unwrap().unwrap(); + + assert_eq!(subslice1.as_slice(), CELL1); + assert_eq!(subslice2.as_slice(), CELL2); + + assert_eq!( + subslice1.as_original_frame_bytes() as *const Bytes, + &frame as *const Bytes + ); + assert_eq!( + subslice2.as_original_frame_bytes() as *const Bytes, + &frame as *const Bytes + ); + + let subslice1_bytes = subslice1.to_bytes(); + let subslice2_bytes = subslice2.to_bytes(); + + assert_eq!(subslice1.as_slice(), subslice1_bytes.as_ref()); + assert_eq!(subslice2.as_slice(), subslice2_bytes.as_ref()); + } +} diff --git 
a/scylla-cql/src/types/deserialize/mod.rs b/scylla-cql/src/types/deserialize/mod.rs new file mode 100644 index 0000000000..12e73052ba --- /dev/null +++ b/scylla-cql/src/types/deserialize/mod.rs @@ -0,0 +1,298 @@ +//! Framework for deserialization of data returned by database queries. +//! +//! Deserialization is based on two traits: +//! +//! - A type that implements `DeserializeValue<'frame>` can be deserialized +//! from a single _CQL value_ - i.e. an element of a row in the query result, +//! - A type that implements `DeserializeRow<'frame>` can be deserialized +//! from a single _row_ of a query result. +//! +//! Those traits are quite similar to each other, both in the idea behind them +//! and the interface that they expose. +//! +//! # `type_check` and `deserialize` +//! +//! The deserialization process is divided into two parts: type checking and +//! actual deserialization, represented by `DeserializeValue`/`DeserializeRow`'s +//! methods called `type_check` and `deserialize`. +//! +//! The `deserialize` method can assume that `type_check` was called before, so +//! it doesn't have to verify the type again. This can be a performance gain +//! when deserializing query results with multiple rows: as each row in a result +//! has the same type, it is only necessary to call `type_check` once for the +//! whole result and then `deserialize` for each row. +//! +//! Note that `deserialize` is not an `unsafe` method - although you can be +//! sure that the driver will call `type_check` before `deserialize`, you +//! shouldn't do unsafe things based on this assumption. +//! +//! # Data ownership +//! +//! Some CQL types can be easily consumed while still partially serialized. +//! For example, types like `blob` or `text` can be just represented with +//! `&[u8]` and `&str` that just point to a part of the serialized response. +//! This is more efficient than using `Vec` or `String` because it avoids +//! 
an allocation and a copy, however it is less convenient because those types +//! are bound with a lifetime. +//! +//! The framework supports types that refer to the serialized response's memory +//! in three different ways: +//! +//! ## Owned types +//! +//! Some types don't borrow anything and fully own their data, e.g. `i32` or +//! `String`. They aren't constrained by any lifetime and should implement +//! the respective trait for _all_ lifetimes, i.e.: +//! +//! ```rust +//! # use scylla_cql::frame::response::result::ColumnType; +//! # use scylla_cql::frame::frame_errors::ParseError; +//! # use scylla_cql::types::deserialize::{DeserializationError, FrameSlice, TypeCheckError}; +//! # use scylla_cql::types::deserialize::value::DeserializeValue; +//! struct MyVec(Vec); +//! impl<'frame> DeserializeValue<'frame> for MyVec { +//! fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { +//! if let ColumnType::Blob = typ { +//! return Ok(()); +//! } +//! Err(TypeCheckError::new( +//! ParseError::BadIncomingData("Expected bytes".to_owned()) +//! )) +//! } +//! +//! fn deserialize( +//! _typ: &'frame ColumnType, +//! v: Option>, +//! ) -> Result { +//! v.ok_or_else(|| { +//! DeserializationError::new( +//! ParseError::BadIncomingData("Expected non-null value".to_owned()) +//! ) +//! }) +//! .map(|v| Self(v.as_slice().to_vec())) +//! } +//! } +//! ``` +//! +//! ## Borrowing types +//! +//! Some types do not fully contain their data but rather will point to some +//! bytes in the serialized response, e.g. `&str` or `&[u8]`. Those types will +//! usually contain a lifetime in their definition. In order to properly +//! implement `DeserializeValue` or `DeserializeRow` for such a type, the `impl` +//! should still have a generic lifetime parameter, but the lifetimes from the +//! type definition should be constrained with the generic lifetime parameter. +//! For example: +//! +//! ```rust +//! # use scylla_cql::frame::frame_errors::ParseError; +//! 
# use scylla_cql::frame::response::result::ColumnType; +//! # use scylla_cql::types::deserialize::{DeserializationError, FrameSlice, TypeCheckError}; +//! # use scylla_cql::types::deserialize::value::DeserializeValue; +//! struct MySlice<'a>(&'a [u8]); +//! impl<'a, 'frame> DeserializeValue<'frame> for MySlice<'a> +//! where +//! 'frame: 'a, +//! { +//! fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { +//! if let ColumnType::Blob = typ { +//! return Ok(()); +//! } +//! Err(TypeCheckError::new( +//! ParseError::BadIncomingData("Expected bytes".to_owned()) +//! )) +//! } +//! +//! fn deserialize( +//! _typ: &'frame ColumnType, +//! v: Option>, +//! ) -> Result { +//! v.ok_or_else(|| { +//! DeserializationError::new( +//! ParseError::BadIncomingData("Expected non-null value".to_owned()) +//! ) +//! }) +//! .map(|v| Self(v.as_slice())) +//! } +//! } +//! ``` +//! +//! ## Reference-counted types +//! +//! Internally, the driver uses the `bytes::Bytes` type to keep the contents +//! of the serialized response. It supports creating derived `Bytes` objects +//! which point to a subslice but keep the whole, original `Bytes` object alive. +//! +//! During deserialization, a type can obtain a `Bytes` subslice that points +//! to the serialized value. This approach combines advantages of the previous +//! two approaches - creating a derived `Bytes` object can be cheaper than +//! allocation and a copy (it supports `Arc`-like semantics) and the `Bytes` +//! type is not constrained by a lifetime. However, you should be aware that +//! the subslice will keep the whole `Bytes` object that holds the frame alive. +//! It is not recommended to use this approach for long-living objects because +//! it can introduce space leaks. +//! +//! Example: +//! +//! ```rust +//! # use scylla_cql::frame::frame_errors::ParseError; +//! # use scylla_cql::frame::response::result::ColumnType; +//! 
# use scylla_cql::types::deserialize::{DeserializationError, FrameSlice, TypeCheckError}; +//! # use scylla_cql::types::deserialize::value::DeserializeValue; +//! # use bytes::Bytes; +//! struct MyBytes(Bytes); +//! impl<'frame> DeserializeValue<'frame> for MyBytes { +//! fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { +//! if let ColumnType::Blob = typ { +//! return Ok(()); +//! } +//! Err(TypeCheckError::new(ParseError::BadIncomingData("Expected bytes".to_owned()))) +//! } +//! +//! fn deserialize( +//! _typ: &'frame ColumnType, +//! v: Option>, +//! ) -> Result { +//! v.ok_or_else(|| { +//! DeserializationError::new(ParseError::BadIncomingData("Expected non-null value".to_owned())) +//! }) +//! .map(|v| Self(v.to_bytes())) +//! } +//! } +//! ``` +// TODO: in the above module docstring, stop abusing ParseError once errors are refactored. + +pub mod frame_slice; +pub mod result; +pub mod row; +pub mod value; + +pub use frame_slice::FrameSlice; + +pub use row::DeserializeRow; +pub use value::DeserializeValue; + +use std::error::Error; +use std::fmt::Display; +use std::sync::Arc; + +use thiserror::Error; + +// Errors + +/// An error indicating that a failure happened during type check. +/// +/// The error is type-erased so that the crate users can define their own +/// type check impls and their errors. +/// As for the impls defined or generated +/// by the driver itself, the following errors can be returned: +/// +/// - [`row::BuiltinTypeCheckError`] is returned when type check of +/// one of types with an impl built into the driver fails. It is also returned +/// from impls generated by the `DeserializeRow` macro. +/// - [`value::BuiltinTypeCheckError`] is analogous to the above but is +/// returned from [`DeserializeValue::type_check`] instead both in the case of +/// builtin impls and impls generated by the `DeserializeValue` macro. 
+/// It won't be returned by the `Session` directly, but it might be nested +/// in the [`row::BuiltinTypeCheckError`]. +#[derive(Debug, Clone, Error)] +#[error(transparent)] +pub struct TypeCheckError(pub(crate) Arc); + +impl TypeCheckError { + /// Constructs a new `TypeCheckError`. + #[inline] + pub fn new(err: impl std::error::Error + Send + Sync + 'static) -> Self { + Self(Arc::new(err)) + } +} + +/// An error indicating that a failure happened during deserialization. +/// +/// The error is type-erased so that the crate users can define their own +/// deserialization impls and their errors. As for the impls defined or generated +/// by the driver itself, the following errors can be returned: +/// +/// - [`row::BuiltinDeserializationError`] is returned when deserialization of +/// one of types with an impl built into the driver fails. It is also returned +/// from impls generated by the `DeserializeRow` macro. +/// - [`value::BuiltinDeserializationError`] is analogous to the above but is +/// returned from [`DeserializeValue::deserialize`] instead both in the case of +/// builtin impls and impls generated by the `DeserializeValue` macro. +/// It won't be returned by the `Session` directly, but it might be nested +/// in the [`row::BuiltinDeserializationError`]. +#[derive(Debug, Clone, Error)] +pub struct DeserializationError(Arc); + +impl DeserializationError { + /// Constructs a new `DeserializationError`. + #[inline] + pub fn new(err: impl Error + Send + Sync + 'static) -> Self { + Self(Arc::new(err)) + } +} + +impl Display for DeserializationError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "DeserializationError: {}", self.0) + } +} + +// This is a hack to enable setting the proper Rust type name in error messages, +// even though the error originates from some helper type used underneath. 
+// ASSUMPTION: This should be used: +// - ONLY in proper type_check()/deserialize() implementation, +// - BEFORE an error is cloned (because otherwise the Arc::get_mut fails). +macro_rules! make_error_replace_rust_name { + ($fn_name: ident, $outer_err: ty, $inner_err: ty) => { + fn $fn_name(mut err: $outer_err) -> $outer_err { + // Safety: the assumed usage of this function guarantees that the Arc has not yet been cloned. + let arc_mut = std::sync::Arc::get_mut(&mut err.0).unwrap(); + + let rust_name: &mut &str = { + if let Some(err) = arc_mut.downcast_mut::<$inner_err>() { + &mut err.rust_name + } else { + unreachable!(concat!( + "This function is assumed to be called only on built-in ", + stringify!($inner_err), + " kinds." + )) + } + }; + + *rust_name = std::any::type_name::(); + err + } + }; +} +use make_error_replace_rust_name; + +#[cfg(test)] +mod tests { + use bytes::{Bytes, BytesMut}; + + use crate::frame::response::result::{ColumnSpec, ColumnType, TableSpec}; + use crate::frame::types; + + pub(super) static CELL1: &[u8] = &[1, 2, 3]; + pub(super) static CELL2: &[u8] = &[4, 5, 6, 7]; + + pub(super) fn serialize_cells( + cells: impl IntoIterator>>, + ) -> Bytes { + let mut bytes = BytesMut::new(); + for cell in cells { + types::write_bytes_opt(cell, &mut bytes).unwrap(); + } + bytes.freeze() + } + + pub(super) fn spec(name: &str, typ: ColumnType) -> ColumnSpec { + ColumnSpec { + name: name.to_owned(), + typ, + table_spec: TableSpec::borrowed("ks", "tbl"), + } + } +} diff --git a/scylla-cql/src/types/deserialize/result.rs b/scylla-cql/src/types/deserialize/result.rs new file mode 100644 index 0000000000..036b909afb --- /dev/null +++ b/scylla-cql/src/types/deserialize/result.rs @@ -0,0 +1,200 @@ +use crate::frame::response::result::ColumnSpec; + +use super::row::{mk_deser_err, BuiltinDeserializationErrorKind, ColumnIterator, DeserializeRow}; +use super::{DeserializationError, FrameSlice, TypeCheckError}; +use std::marker::PhantomData; + +/// Iterates over the 
whole result, returning rows. +pub struct RowIterator<'frame> { + specs: &'frame [ColumnSpec], + remaining: usize, + slice: FrameSlice<'frame>, +} + +impl<'frame> RowIterator<'frame> { + /// Creates a new iterator over rows from a serialized response. + /// + /// - `remaining` - number of the remaining rows in the serialized response, + /// - `specs` - information about columns of the serialized response, + /// - `slice` - a [FrameSlice] that points to the serialized rows data. + #[inline] + pub fn new(remaining: usize, specs: &'frame [ColumnSpec], slice: FrameSlice<'frame>) -> Self { + Self { + specs, + remaining, + slice, + } + } + + /// Returns information about the columns of rows that are iterated over. + #[inline] + pub fn specs(&self) -> &'frame [ColumnSpec] { + self.specs + } + + /// Returns the remaining number of rows that this iterator is supposed + /// to return. + #[inline] + pub fn rows_remaining(&self) -> usize { + self.remaining + } +} + +impl<'frame> Iterator for RowIterator<'frame> { + type Item = Result, DeserializationError>; + + #[inline] + fn next(&mut self) -> Option { + self.remaining = self.remaining.checked_sub(1)?; + + let iter = ColumnIterator::new(self.specs, self.slice); + + // Skip the row here, manually + for (column_index, spec) in self.specs.iter().enumerate() { + if let Err(err) = self.slice.read_cql_bytes() { + return Some(Err(mk_deser_err::( + BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { + column_index, + column_name: spec.name.clone(), + err: DeserializationError::new(err), + }, + ))); + } + } + + Some(Ok(iter)) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + // The iterator will always return exactly `self.remaining` + // elements: Oks until an error is encountered and then Errs + // containing that same first encountered error. + (self.remaining, Some(self.remaining)) + } +} + +/// A typed version of [RowIterator] which deserializes the rows before +/// returning them. 
+pub struct TypedRowIterator<'frame, R> { + inner: RowIterator<'frame>, + _phantom: PhantomData, +} + +impl<'frame, R> TypedRowIterator<'frame, R> +where + R: DeserializeRow<'frame>, +{ + /// Creates a new [TypedRowIterator] from given [RowIterator]. + /// + /// Calls `R::type_check` and fails if the type check fails. + #[inline] + pub fn new(raw: RowIterator<'frame>) -> Result { + R::type_check(raw.specs())?; + Ok(Self { + inner: raw, + _phantom: PhantomData, + }) + } + + /// Returns information about the columns of rows that are iterated over. + #[inline] + pub fn specs(&self) -> &'frame [ColumnSpec] { + self.inner.specs() + } + + /// Returns the remaining number of rows that this iterator is supposed + /// to return. + #[inline] + pub fn rows_remaining(&self) -> usize { + self.inner.rows_remaining() + } +} + +impl<'frame, R> Iterator for TypedRowIterator<'frame, R> +where + R: DeserializeRow<'frame>, +{ + type Item = Result; + + #[inline] + fn next(&mut self) -> Option { + self.inner + .next() + .map(|raw| raw.and_then(|raw| R::deserialize(raw))) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + + use crate::frame::response::result::ColumnType; + + use super::super::tests::{serialize_cells, spec, CELL1, CELL2}; + use super::{FrameSlice, RowIterator, TypedRowIterator}; + + #[test] + fn test_row_iterator_basic_parse() { + let raw_data = serialize_cells([Some(CELL1), Some(CELL2), Some(CELL2), Some(CELL1)]); + let specs = [spec("b1", ColumnType::Blob), spec("b2", ColumnType::Blob)]; + let mut iter = RowIterator::new(2, &specs, FrameSlice::new(&raw_data)); + + let mut row1 = iter.next().unwrap().unwrap(); + let c11 = row1.next().unwrap().unwrap(); + assert_eq!(c11.slice.unwrap().as_slice(), CELL1); + let c12 = row1.next().unwrap().unwrap(); + assert_eq!(c12.slice.unwrap().as_slice(), CELL2); + assert!(row1.next().is_none()); + + let mut row2 = iter.next().unwrap().unwrap(); 
+ let c21 = row2.next().unwrap().unwrap(); + assert_eq!(c21.slice.unwrap().as_slice(), CELL2); + let c22 = row2.next().unwrap().unwrap(); + assert_eq!(c22.slice.unwrap().as_slice(), CELL1); + assert!(row2.next().is_none()); + + assert!(iter.next().is_none()); + } + + #[test] + fn test_row_iterator_too_few_rows() { + let raw_data = serialize_cells([Some(CELL1), Some(CELL2)]); + let specs = [spec("b1", ColumnType::Blob), spec("b2", ColumnType::Blob)]; + let mut iter = RowIterator::new(2, &specs, FrameSlice::new(&raw_data)); + + iter.next().unwrap().unwrap(); + assert!(iter.next().unwrap().is_err()); + } + + #[test] + fn test_typed_row_iterator_basic_parse() { + let raw_data = serialize_cells([Some(CELL1), Some(CELL2), Some(CELL2), Some(CELL1)]); + let specs = [spec("b1", ColumnType::Blob), spec("b2", ColumnType::Blob)]; + let iter = RowIterator::new(2, &specs, FrameSlice::new(&raw_data)); + let mut iter = TypedRowIterator::<'_, (&[u8], Vec)>::new(iter).unwrap(); + + let (c11, c12) = iter.next().unwrap().unwrap(); + assert_eq!(c11, CELL1); + assert_eq!(c12, CELL2); + + let (c21, c22) = iter.next().unwrap().unwrap(); + assert_eq!(c21, CELL2); + assert_eq!(c22, CELL1); + + assert!(iter.next().is_none()); + } + + #[test] + fn test_typed_row_iterator_wrong_type() { + let raw_data = Bytes::new(); + let specs = [spec("b1", ColumnType::Blob), spec("b2", ColumnType::Blob)]; + let iter = RowIterator::new(0, &specs, FrameSlice::new(&raw_data)); + assert!(TypedRowIterator::<'_, (i32, i64)>::new(iter).is_err()); + } +} diff --git a/scylla-cql/src/types/deserialize/row.rs b/scylla-cql/src/types/deserialize/row.rs new file mode 100644 index 0000000000..c66f3c7328 --- /dev/null +++ b/scylla-cql/src/types/deserialize/row.rs @@ -0,0 +1,691 @@ +//! Provides types for dealing with row deserialization. 
+ +use std::fmt::Display; + +use thiserror::Error; + +use super::value::DeserializeValue; +use super::{make_error_replace_rust_name, DeserializationError, FrameSlice, TypeCheckError}; +use crate::frame::response::result::{ColumnSpec, ColumnType, CqlValue, Row}; + +/// Represents a raw, unparsed column value. +#[non_exhaustive] +pub struct RawColumn<'frame> { + pub index: usize, + pub spec: &'frame ColumnSpec, + pub slice: Option>, +} + +/// Iterates over columns of a single row. +#[derive(Clone, Debug)] +pub struct ColumnIterator<'frame> { + specs: std::iter::Enumerate>, + slice: FrameSlice<'frame>, +} + +impl<'frame> ColumnIterator<'frame> { + /// Creates a new iterator over a single row. + /// + /// - `specs` - information about columns of the serialized response, + /// - `slice` - a [FrameSlice] which points to the serialized row. + #[inline] + pub(crate) fn new(specs: &'frame [ColumnSpec], slice: FrameSlice<'frame>) -> Self { + Self { + specs: specs.iter().enumerate(), + slice, + } + } + + /// Returns the remaining number of columns that this iterator is expected + /// to return. + #[inline] + pub fn columns_remaining(&self) -> usize { + self.specs.len() + } +} + +impl<'frame> Iterator for ColumnIterator<'frame> { + type Item = Result, DeserializationError>; + + #[inline] + fn next(&mut self) -> Option { + let (column_index, spec) = self.specs.next()?; + Some( + self.slice + .read_cql_bytes() + .map(|slice| RawColumn { + index: column_index, + spec, + slice, + }) + .map_err(|err| { + mk_deser_err::( + BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { + column_index, + column_name: spec.name.clone(), + err: DeserializationError::new(err), + }, + ) + }), + ) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.specs.size_hint() + } +} + +/// A type that can be deserialized from a row that was returned from a query. 
+/// +/// For tips on how to write a custom implementation of this trait, see the +/// documentation of the parent module. +/// +/// The crate also provides a derive macro which allows to automatically +/// implement the trait for a custom type. For more details on what the macro +/// is capable of, see its documentation. +pub trait DeserializeRow<'frame> +where + Self: Sized, +{ + /// Checks that the schema of the result matches what this type expects. + /// + /// This function can check whether column types and names match the + /// expectations. + fn type_check(specs: &[ColumnSpec]) -> Result<(), TypeCheckError>; + + /// Deserializes a row from given column iterator. + /// + /// This function can assume that the driver called `type_check` to verify + /// the row's type. Note that `deserialize` is not an unsafe function, + /// so it should not use the assumption about `type_check` being called + /// as an excuse to run `unsafe` code. + fn deserialize(row: ColumnIterator<'frame>) -> Result; +} + +// raw deserialization as ColumnIterator + +// What is the purpose of implementing DeserializeRow for ColumnIterator? +// +// Sometimes users might be interested in operating on ColumnIterator directly. +// Implementing DeserializeRow for it allows us to simplify our interface. For example, +// we have `QueryResult::rows()` - you can put T = ColumnIterator +// instead of having a separate rows_raw function or something like this. 
+impl<'frame> DeserializeRow<'frame> for ColumnIterator<'frame> { + #[inline] + fn type_check(_specs: &[ColumnSpec]) -> Result<(), TypeCheckError> { + Ok(()) + } + + #[inline] + fn deserialize(row: ColumnIterator<'frame>) -> Result { + Ok(row) + } +} + +make_error_replace_rust_name!( + _typck_error_replace_rust_name, + TypeCheckError, + BuiltinTypeCheckError +); + +make_error_replace_rust_name!( + deser_error_replace_rust_name, + DeserializationError, + BuiltinDeserializationError +); + +// legacy/dynamic deserialization as Row +// +/// While no longer encouraged (because the new framework encourages deserializing +/// directly into desired types, entirely bypassing [CqlValue]), this can be indispensable +/// for some use cases, i.e. those involving dynamic parsing (ORMs?). +impl<'frame> DeserializeRow<'frame> for Row { + #[inline] + fn type_check(_specs: &[ColumnSpec]) -> Result<(), TypeCheckError> { + // CqlValues accept all types, no type checking needed. + Ok(()) + } + + #[inline] + fn deserialize(mut row: ColumnIterator<'frame>) -> Result { + let mut columns = Vec::with_capacity(row.size_hint().0); + while let Some(column) = row + .next() + .transpose() + .map_err(deser_error_replace_rust_name::)? + { + columns.push( + >::deserialize(&column.spec.typ, column.slice).map_err(|err| { + mk_deser_err::( + BuiltinDeserializationErrorKind::ColumnDeserializationFailed { + column_index: column.index, + column_name: column.spec.name.clone(), + err, + }, + ) + })?, + ); + } + Ok(Self { columns }) + } +} + +// tuples +// +/// This is the new encouraged way for deserializing a row. +/// If only you know the exact column types in advance, you had better deserialize the row +/// to a tuple. The new deserialization framework will take care of all type checking +/// and needed conversions, issuing meaningful errors in case something goes wrong. +macro_rules! 
impl_tuple { + ($($Ti:ident),*; $($idx:literal),*; $($idf:ident),*) => { + impl<'frame, $($Ti),*> DeserializeRow<'frame> for ($($Ti,)*) + where + $($Ti: DeserializeValue<'frame>),* + { + fn type_check(specs: &[ColumnSpec]) -> Result<(), TypeCheckError> { + const TUPLE_LEN: usize = (&[$($idx),*] as &[i32]).len(); + + let column_types_iter = || specs.iter().map(|spec| spec.typ.clone()); + if let [$($idf),*] = &specs { + $( + <$Ti as DeserializeValue<'frame>>::type_check(&$idf.typ) + .map_err(|err| mk_typck_err::(column_types_iter(), BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { + column_index: $idx, + column_name: specs[$idx].name.clone(), + err + }))?; + )* + Ok(()) + } else { + Err(mk_typck_err::(column_types_iter(), BuiltinTypeCheckErrorKind::WrongColumnCount { + rust_cols: TUPLE_LEN, cql_cols: specs.len() + })) + } + } + + fn deserialize(mut row: ColumnIterator<'frame>) -> Result { + const TUPLE_LEN: usize = (&[$($idx),*] as &[i32]).len(); + + let ret = ( + $({ + let column = row.next().unwrap_or_else(|| unreachable!( + "Typecheck should have prevented this scenario! Column count mismatch: rust type {}, cql row {}", + TUPLE_LEN, + $idx + )).map_err(deser_error_replace_rust_name::)?; + + <$Ti as DeserializeValue<'frame>>::deserialize(&column.spec.typ, column.slice) + .map_err(|err| mk_deser_err::(BuiltinDeserializationErrorKind::ColumnDeserializationFailed { + column_index: column.index, + column_name: column.spec.name.clone(), + err, + }))? + },)* + ); + assert!( + row.next().is_none(), + "Typecheck should have prevented this scenario! Column count mismatch: rust type {}, cql row is bigger", + TUPLE_LEN, + ); + Ok(ret) + } + } + } +} + +use super::value::impl_tuple_multiple; + +// Implements row-to-tuple deserialization for all tuple sizes up to 16. 
+impl_tuple_multiple!( + T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15; + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15; + t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15 +); + +// Error facilities + +/// Failed to type check incoming result column types again given Rust type, +/// one of the types having support built into the driver. +#[derive(Debug, Error, Clone)] +#[error("Failed to type check the Rust type {rust_name} against CQL column types {cql_types:?} : {kind}")] +pub struct BuiltinTypeCheckError { + /// Name of the Rust type used to represent the values. + pub rust_name: &'static str, + + /// The CQL types of the values that the Rust type was being deserialized from. + pub cql_types: Vec, + + /// Detailed information about the failure. + pub kind: BuiltinTypeCheckErrorKind, +} + +fn mk_typck_err( + cql_types: impl IntoIterator, + kind: impl Into, +) -> TypeCheckError { + mk_typck_err_named(std::any::type_name::(), cql_types, kind) +} + +fn mk_typck_err_named( + name: &'static str, + cql_types: impl IntoIterator, + kind: impl Into, +) -> TypeCheckError { + TypeCheckError::new(BuiltinTypeCheckError { + rust_name: name, + cql_types: Vec::from_iter(cql_types), + kind: kind.into(), + }) +} + +/// Describes why type checking incoming result column types again given Rust type failed. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum BuiltinTypeCheckErrorKind { + /// The Rust type expects `rust_cols` columns, but the statement operates on `cql_cols`. + WrongColumnCount { + /// The number of values that the Rust type provides. + rust_cols: usize, + + /// The number of columns that the statement operates on. + cql_cols: usize, + }, + + /// Column type check failed between Rust type and DB type at given position (=in given column). + ColumnTypeCheckFailed { + /// Index of the column. + column_index: usize, + + /// Name of the column, as provided by the DB. 
+ column_name: String, + + /// Inner type check error due to the type mismatch. + err: TypeCheckError, + }, +} + +impl Display for BuiltinTypeCheckErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BuiltinTypeCheckErrorKind::WrongColumnCount { + rust_cols, + cql_cols, + } => { + write!(f, "wrong column count: the statement operates on {cql_cols} columns, but the given rust types contains {rust_cols}") + } + BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { + column_index, + column_name, + err, + } => write!( + f, + "mismatched types in column {column_name} at index {column_index}: {err}" + ), + } + } +} + +/// Failed to deserialize a row from the DB response, represented by one of the types +/// built into the driver. +#[derive(Debug, Error, Clone)] +#[error("Failed to deserialize query result row {rust_name}: {kind}")] +pub struct BuiltinDeserializationError { + /// Name of the Rust type used to represent the row. + pub rust_name: &'static str, + + /// Detailed information about the failure. + pub kind: BuiltinDeserializationErrorKind, +} + +pub(super) fn mk_deser_err( + kind: impl Into, +) -> DeserializationError { + mk_deser_err_named(std::any::type_name::(), kind) +} + +pub(super) fn mk_deser_err_named( + name: &'static str, + kind: impl Into, +) -> DeserializationError { + DeserializationError::new(BuiltinDeserializationError { + rust_name: name, + kind: kind.into(), + }) +} + +/// Describes why deserializing a result row failed. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum BuiltinDeserializationErrorKind { + /// One of the columns failed to deserialize. + ColumnDeserializationFailed { + /// Index of the column that failed to deserialize. + column_index: usize, + + /// Name of the column that failed to deserialize. + column_name: String, + + /// The error that caused the column deserialization to fail. 
+ err: DeserializationError, + }, + + /// One of the raw columns failed to deserialize, most probably + /// due to the invalid column structure inside a row in the frame. + RawColumnDeserializationFailed { + /// Index of the raw column that failed to deserialize. + column_index: usize, + + /// Name of the raw column that failed to deserialize. + column_name: String, + + /// The error that caused the raw column deserialization to fail. + err: DeserializationError, + }, +} + +impl Display for BuiltinDeserializationErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BuiltinDeserializationErrorKind::ColumnDeserializationFailed { + column_index, + column_name, + err, + } => { + write!( + f, + "failed to deserialize column {column_name} at index {column_index}: {err}" + ) + } + BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { + column_index, + column_name, + err, + } => { + write!( + f, + "failed to deserialize raw column {column_name} at index {column_index} (most probably due to invalid column structure inside a row): {err}" + ) + } + } + } +} + +#[cfg(test)] +mod tests { + use assert_matches::assert_matches; + use bytes::Bytes; + + use crate::frame::frame_errors::ParseError; + use crate::frame::response::result::{ColumnSpec, ColumnType}; + use crate::types::deserialize::row::BuiltinDeserializationErrorKind; + use crate::types::deserialize::{DeserializationError, FrameSlice}; + + use super::super::tests::{serialize_cells, spec}; + use super::{BuiltinDeserializationError, ColumnIterator, CqlValue, DeserializeRow, Row}; + use super::{BuiltinTypeCheckError, BuiltinTypeCheckErrorKind}; + + #[test] + fn test_tuple_deserialization() { + // Empty tuple + deserialize::<()>(&[], &Bytes::new()).unwrap(); + + // 1-elem tuple + let (a,) = deserialize::<(i32,)>( + &[spec("i", ColumnType::Int)], + &serialize_cells([val_int(123)]), + ) + .unwrap(); + assert_eq!(a, 123); + + // 3-elem tuple + let (a, b, c) = 
deserialize::<(i32, i32, i32)>( + &[ + spec("i1", ColumnType::Int), + spec("i2", ColumnType::Int), + spec("i3", ColumnType::Int), + ], + &serialize_cells([val_int(123), val_int(456), val_int(789)]), + ) + .unwrap(); + assert_eq!((a, b, c), (123, 456, 789)); + + // Make sure that column type mismatch is detected + deserialize::<(i32, String, i32)>( + &[ + spec("i1", ColumnType::Int), + spec("i2", ColumnType::Int), + spec("i3", ColumnType::Int), + ], + &serialize_cells([val_int(123), val_int(456), val_int(789)]), + ) + .unwrap_err(); + + // Make sure that borrowing types compile and work correctly + let specs = &[spec("s", ColumnType::Text)]; + let byts = serialize_cells([val_str("abc")]); + let (s,) = deserialize::<(&str,)>(specs, &byts).unwrap(); + assert_eq!(s, "abc"); + } + + #[test] + fn test_deserialization_as_column_iterator() { + let col_specs = [ + spec("i1", ColumnType::Int), + spec("i2", ColumnType::Text), + spec("i3", ColumnType::Counter), + ]; + let serialized_values = serialize_cells([val_int(123), val_str("ScyllaDB"), None]); + let mut iter = deserialize::(&col_specs, &serialized_values).unwrap(); + + let col1 = iter.next().unwrap().unwrap(); + assert_eq!(col1.spec.name, "i1"); + assert_eq!(col1.spec.typ, ColumnType::Int); + assert_eq!(col1.slice.unwrap().as_slice(), &123i32.to_be_bytes()); + + let col2 = iter.next().unwrap().unwrap(); + assert_eq!(col2.spec.name, "i2"); + assert_eq!(col2.spec.typ, ColumnType::Text); + assert_eq!(col2.slice.unwrap().as_slice(), "ScyllaDB".as_bytes()); + + let col3 = iter.next().unwrap().unwrap(); + assert_eq!(col3.spec.name, "i3"); + assert_eq!(col3.spec.typ, ColumnType::Counter); + assert!(col3.slice.is_none()); + + assert!(iter.next().is_none()); + } + + fn val_int(i: i32) -> Option> { + Some(i.to_be_bytes().to_vec()) + } + + fn val_str(s: &str) -> Option> { + Some(s.as_bytes().to_vec()) + } + + fn deserialize<'frame, R>( + specs: &'frame [ColumnSpec], + byts: &'frame Bytes, + ) -> Result + where + R: 
DeserializeRow<'frame>, + { + >::type_check(specs) + .map_err(|typecheck_err| DeserializationError(typecheck_err.0))?; + let slice = FrameSlice::new(byts); + let iter = ColumnIterator::new(specs, slice); + >::deserialize(iter) + } + + #[track_caller] + fn get_typck_err(err: &DeserializationError) -> &BuiltinTypeCheckError { + match err.0.downcast_ref() { + Some(err) => err, + None => panic!("not a BuiltinTypeCheckError: {:?}", err), + } + } + + #[track_caller] + fn get_deser_err(err: &DeserializationError) -> &BuiltinDeserializationError { + match err.0.downcast_ref() { + Some(err) => err, + None => panic!("not a BuiltinDeserializationError: {:?}", err), + } + } + + #[test] + fn test_tuple_errors() { + // Column type check failure + { + let col_name: &str = "i"; + let specs = &[spec(col_name, ColumnType::Int)]; + let err = deserialize::<(i64,)>(specs, &serialize_cells([val_int(123)])).unwrap_err(); + let err = get_typck_err(&err); + assert_eq!(err.rust_name, std::any::type_name::<(i64,)>()); + assert_eq!( + err.cql_types, + specs + .iter() + .map(|spec| spec.typ.clone()) + .collect::>() + ); + let BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { + column_index, + column_name, + err, + } = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + assert_eq!(*column_index, 0); + assert_eq!(column_name, col_name); + let err = super::super::value::tests::get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::Int); + assert_matches!( + &err.kind, + super::super::value::BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::BigInt, ColumnType::Counter] + } + ); + } + + // Column deserialization failure + { + let col_name: &str = "i"; + let err = deserialize::<(i64,)>( + &[spec(col_name, ColumnType::BigInt)], + &serialize_cells([val_int(123)]), + ) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::<(i64,)>()); + let 
BuiltinDeserializationErrorKind::ColumnDeserializationFailed { + column_name, + err, + .. + } = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + assert_eq!(column_name, col_name); + let err = super::super::value::tests::get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::BigInt); + assert_matches!( + err.kind, + super::super::value::BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 8, + got: 4 + } + ); + } + + // Raw column deserialization failure + { + let col_name: &str = "i"; + let err = deserialize::<(i64,)>( + &[spec(col_name, ColumnType::BigInt)], + &Bytes::from_static(b"alamakota"), + ) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::<(i64,)>()); + let BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { + column_index: _column_index, + column_name, + err: _err, + } = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + assert_eq!(column_name, col_name); + } + } + + #[test] + fn test_row_errors() { + // Column type check failure - happens never, because Row consists of CqlValues, + // which accept all CQL types. 
+ + // Column deserialization failure + { + let col_name: &str = "i"; + let err = deserialize::( + &[spec(col_name, ColumnType::BigInt)], + &serialize_cells([val_int(123)]), + ) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + let BuiltinDeserializationErrorKind::ColumnDeserializationFailed { + column_index: _column_index, + column_name, + err, + } = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + assert_eq!(column_name, col_name); + let err = super::super::value::tests::get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::BigInt); + let super::super::value::BuiltinDeserializationErrorKind::GenericParseError( + ParseError::DeserializationError(d), + ) = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + let err = super::super::value::tests::get_deser_err(d); + let super::super::value::BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 8, + got: 4, + } = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + } + + // Raw column deserialization failure + { + let col_name: &str = "i"; + let err = deserialize::( + &[spec(col_name, ColumnType::BigInt)], + &Bytes::from_static(b"alamakota"), + ) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + let BuiltinDeserializationErrorKind::RawColumnDeserializationFailed { + column_index: _column_index, + column_name, + err: _err, + } = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + assert_eq!(column_name, col_name); + } + } +} diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs new file mode 100644 index 0000000000..95103998e2 --- /dev/null +++ b/scylla-cql/src/types/deserialize/value.rs @@ -0,0 +1,2919 @@ +//! Provides types for dealing with CQL value deserialization. 
+ +use std::{ + collections::{BTreeMap, BTreeSet, HashMap, HashSet}, + hash::{BuildHasher, Hash}, + net::IpAddr, +}; + +use bytes::Bytes; +use uuid::Uuid; + +use std::fmt::Display; + +use thiserror::Error; + +use super::{make_error_replace_rust_name, DeserializationError, FrameSlice, TypeCheckError}; +use crate::frame::frame_errors::ParseError; +use crate::frame::response::result::{deser_cql_value, ColumnType, CqlValue}; +use crate::frame::types; +use crate::frame::value::{ + Counter, CqlDate, CqlDecimal, CqlDuration, CqlTime, CqlTimestamp, CqlTimeuuid, CqlVarint, +}; + +/// A type that can be deserialized from a column value inside a row that was +/// returned from a query. +/// +/// For tips on how to write a custom implementation of this trait, see the +/// documentation of the parent module. +/// +/// The crate also provides a derive macro which allows to automatically +/// implement the trait for a custom type. For more details on what the macro +/// is capable of, see its documentation. +pub trait DeserializeValue<'frame> +where + Self: Sized, +{ + /// Checks that the column type matches what this type expects. + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError>; + + /// Deserialize a column value from given serialized representation. + /// + /// This function can assume that the driver called `type_check` to verify + /// the column's type. Note that `deserialize` is not an unsafe function, + /// so it should not use the assumption about `type_check` being called + /// as an excuse to run `unsafe` code. 
+ fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result; +} + +impl<'frame> DeserializeValue<'frame> for CqlValue { + fn type_check(_typ: &ColumnType) -> Result<(), TypeCheckError> { + // CqlValue accepts all possible CQL types + Ok(()) + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + let mut val = ensure_not_null_slice::(typ, v)?; + let cql = deser_cql_value(typ, &mut val).map_err(|err| { + mk_deser_err::(typ, BuiltinDeserializationErrorKind::GenericParseError(err)) + })?; + Ok(cql) + } +} + +// Option represents nullability of CQL values: +// None corresponds to null, +// Some(val) to non-null values. +impl<'frame, T> DeserializeValue<'frame> for Option +where + T: DeserializeValue<'frame>, +{ + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + T::type_check(typ) + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + v.map(|_| T::deserialize(typ, v)).transpose() + } +} + +/// Values that may be empty or not. +/// +/// In CQL, some types can have a special value of "empty", represented as +/// a serialized value of length 0. An example of this are integral types: +/// the "int" type can actually hold 2^32 + 1 possible values because of this +/// quirk. Note that this is distinct from being NULL. +/// +/// Rust types that cannot represent an empty value (e.g. i32) should implement +/// this trait in order to be deserialized as [MaybeEmpty]. +pub trait Emptiable {} + +/// A value that may be empty or not. +/// +/// `MaybeEmpty` was introduced to help support the quirk described in [Emptiable] +/// for Rust types which can't represent the empty, additional value. 
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)] +pub enum MaybeEmpty { + Empty, + Value(T), +} + +impl<'frame, T> DeserializeValue<'frame> for MaybeEmpty +where + T: DeserializeValue<'frame> + Emptiable, +{ + #[inline] + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + >::type_check(typ) + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + let val = ensure_not_null_slice::(typ, v)?; + if val.is_empty() { + Ok(MaybeEmpty::Empty) + } else { + let v = >::deserialize(typ, v)?; + Ok(MaybeEmpty::Value(v)) + } + } +} + +macro_rules! impl_strict_type { + ($t:ty, [$($cql:ident)|+], $conv:expr $(, $l:lifetime)?) => { + impl<$($l,)? 'frame> DeserializeValue<'frame> for $t + where + $('frame: $l)? + { + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + // TODO: Format the CQL type names in the same notation + // that ScyllaDB/Cassandra uses internally and include them + // in such form in the error message + exact_type_check!(typ, $($cql),*); + Ok(()) + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + $conv(typ, v) + } + } + }; + + // Convenience pattern for omitting brackets if type-checking as single types. + ($t:ty, $cql:ident, $conv:expr $(, $l:lifetime)?) => { + impl_strict_type!($t, [$cql], $conv $(, $l)*); + }; +} + +macro_rules! impl_emptiable_strict_type { + ($t:ty, [$($cql:ident)|+], $conv:expr $(, $l:lifetime)?) => { + impl<$($l,)?> Emptiable for $t {} + + impl_strict_type!($t, [$($cql)|*], $conv $(, $l)*); + }; + + // Convenience pattern for omitting brackets if type-checking as single types. + ($t:ty, $cql:ident, $conv:expr $(, $l:lifetime)?) => { + impl_emptiable_strict_type!($t, [$cql], $conv $(, $l)*); + }; + +} + +// fixed numeric types + +macro_rules! 
impl_fixed_numeric_type { + ($t:ty, [$($cql:ident)|+]) => { + impl_emptiable_strict_type!( + $t, + [$($cql)|*], + |typ: &'frame ColumnType, v: Option>| { + const SIZE: usize = std::mem::size_of::<$t>(); + let val = ensure_not_null_slice::(typ, v)?; + let arr = ensure_exact_length::(typ, val)?; + Ok(<$t>::from_be_bytes(*arr)) + } + ); + }; + + // Convenience pattern for omitting brackets if type-checking as single types. + ($t:ty, $cql:ident) => { + impl_fixed_numeric_type!($t, [$cql]); + }; +} + +impl_emptiable_strict_type!( + bool, + Boolean, + |typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null_slice::(typ, v)?; + let arr = ensure_exact_length::(typ, val)?; + Ok(arr[0] != 0x00) + } +); + +impl_fixed_numeric_type!(i8, TinyInt); +impl_fixed_numeric_type!(i16, SmallInt); +impl_fixed_numeric_type!(i32, Int); +impl_fixed_numeric_type!(i64, [BigInt | Counter]); +impl_fixed_numeric_type!(f32, Float); +impl_fixed_numeric_type!(f64, Double); + +// other numeric types + +impl_emptiable_strict_type!( + CqlVarint, + Varint, + |typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null_slice::(typ, v)?; + Ok(CqlVarint::from_signed_bytes_be_slice(val)) + } +); + +#[cfg(feature = "num-bigint-03")] +impl_emptiable_strict_type!(num_bigint_03::BigInt, Varint, |typ: &'frame ColumnType, + v: Option< + FrameSlice<'frame>, +>| { + let val = ensure_not_null_slice::(typ, v)?; + Ok(num_bigint_03::BigInt::from_signed_bytes_be(val)) +}); + +#[cfg(feature = "num-bigint-04")] +impl_emptiable_strict_type!(num_bigint_04::BigInt, Varint, |typ: &'frame ColumnType, + v: Option< + FrameSlice<'frame>, +>| { + let val = ensure_not_null_slice::(typ, v)?; + Ok(num_bigint_04::BigInt::from_signed_bytes_be(val)) +}); + +impl_emptiable_strict_type!( + CqlDecimal, + Decimal, + |typ: &'frame ColumnType, v: Option>| { + let mut val = ensure_not_null_slice::(typ, v)?; + let scale = types::read_int(&mut val).map_err(|err| { + mk_deser_err::( + typ, + 
BuiltinDeserializationErrorKind::GenericParseError(err.into()), + ) + })?; + Ok(CqlDecimal::from_signed_be_bytes_slice_and_exponent( + val, scale, + )) + } +); + +#[cfg(feature = "bigdecimal-04")] +impl_emptiable_strict_type!( + bigdecimal_04::BigDecimal, + Decimal, + |typ: &'frame ColumnType, v: Option>| { + let mut val = ensure_not_null_slice::(typ, v)?; + let scale = types::read_int(&mut val).map_err(|err| { + mk_deser_err::( + typ, + BuiltinDeserializationErrorKind::GenericParseError(err.into()), + ) + })? as i64; + let int_value = bigdecimal_04::num_bigint::BigInt::from_signed_bytes_be(val); + Ok(bigdecimal_04::BigDecimal::from((int_value, scale))) + } +); + +// blob + +impl_strict_type!( + &'a [u8], + Blob, + |typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null_slice::(typ, v)?; + Ok(val) + }, + 'a +); +impl_strict_type!( + Vec, + Blob, + |typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null_slice::(typ, v)?; + Ok(val.to_vec()) + } +); +impl_strict_type!( + Bytes, + Blob, + |typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null_owned::(typ, v)?; + Ok(val) + } +); + +// string + +macro_rules! impl_string_type { + ($t:ty, $conv:expr $(, $l:lifetime)?) => { + impl_strict_type!( + $t, + [Ascii | Text], + $conv + $(, $l)? 
+ ); + } +} + +fn check_ascii(typ: &ColumnType, s: &[u8]) -> Result<(), DeserializationError> { + if matches!(typ, ColumnType::Ascii) && !s.is_ascii() { + return Err(mk_deser_err::( + typ, + BuiltinDeserializationErrorKind::ExpectedAscii, + )); + } + Ok(()) +} + +impl_string_type!( + &'a str, + |typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null_slice::(typ, v)?; + check_ascii::<&str>(typ, val)?; + let s = std::str::from_utf8(val).map_err(|err| { + mk_deser_err::(typ, BuiltinDeserializationErrorKind::InvalidUtf8(err)) + })?; + Ok(s) + }, + 'a +); +impl_string_type!( + String, + |typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null_slice::(typ, v)?; + check_ascii::(typ, val)?; + let s = std::str::from_utf8(val).map_err(|err| { + mk_deser_err::(typ, BuiltinDeserializationErrorKind::InvalidUtf8(err)) + })?; + Ok(s.to_string()) + } +); + +// TODO: Consider support for deserialization of string::String + +// counter + +impl_strict_type!( + Counter, + Counter, + |typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null_slice::(typ, v)?; + let arr = ensure_exact_length::(typ, val)?; + let counter = i64::from_be_bytes(*arr); + Ok(Counter(counter)) + } +); + +// date and time types + +// duration +impl_strict_type!( + CqlDuration, + Duration, + |typ: &'frame ColumnType, v: Option>| { + let mut val = ensure_not_null_slice::(typ, v)?; + + macro_rules! 
mk_err { + ($err: expr) => { + mk_deser_err::(typ, $err) + }; + } + + let months_i64 = types::vint_decode(&mut val).map_err(|err| { + mk_err!(BuiltinDeserializationErrorKind::GenericParseError( + err.into() + )) + })?; + let months = i32::try_from(months_i64) + .map_err(|_| mk_err!(BuiltinDeserializationErrorKind::ValueOverflow))?; + + let days_i64 = types::vint_decode(&mut val).map_err(|err| { + mk_err!(BuiltinDeserializationErrorKind::GenericParseError( + err.into() + )) + })?; + let days = i32::try_from(days_i64) + .map_err(|_| mk_err!(BuiltinDeserializationErrorKind::ValueOverflow))?; + + let nanoseconds = types::vint_decode(&mut val).map_err(|err| { + mk_err!(BuiltinDeserializationErrorKind::GenericParseError( + err.into() + )) + })?; + + Ok(CqlDuration { + months, + days, + nanoseconds, + }) + } +); + +impl_emptiable_strict_type!( + CqlDate, + Date, + |typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null_slice::(typ, v)?; + let arr = ensure_exact_length::(typ, val)?; + let days = u32::from_be_bytes(*arr); + Ok(CqlDate(days)) + } +); + +#[cfg(any(feature = "chrono", feature = "time"))] +fn get_days_since_epoch_from_date_column( + typ: &ColumnType, + v: Option>, +) -> Result { + let val = ensure_not_null_slice::(typ, v)?; + let arr = ensure_exact_length::(typ, val)?; + let days = u32::from_be_bytes(*arr); + let days_since_epoch = days as i64 - (1i64 << 31); + Ok(days_since_epoch) +} + +#[cfg(feature = "chrono")] +impl_emptiable_strict_type!( + chrono::NaiveDate, + Date, + |typ: &'frame ColumnType, v: Option>| { + let fail = || mk_deser_err::(typ, BuiltinDeserializationErrorKind::ValueOverflow); + let days_since_epoch = + chrono::Duration::try_days(get_days_since_epoch_from_date_column::(typ, v)?) 
+ .ok_or_else(fail)?; + chrono::NaiveDate::from_ymd_opt(1970, 1, 1) + .unwrap() + .checked_add_signed(days_since_epoch) + .ok_or_else(fail) + } +); + +#[cfg(feature = "time")] +impl_emptiable_strict_type!( + time::Date, + Date, + |typ: &'frame ColumnType, v: Option>| { + let days_since_epoch = + time::Duration::days(get_days_since_epoch_from_date_column::(typ, v)?); + time::Date::from_calendar_date(1970, time::Month::January, 1) + .unwrap() + .checked_add(days_since_epoch) + .ok_or_else(|| { + mk_deser_err::(typ, BuiltinDeserializationErrorKind::ValueOverflow) + }) + } +); + +fn get_nanos_from_time_column( + typ: &ColumnType, + v: Option>, +) -> Result { + let val = ensure_not_null_slice::(typ, v)?; + let arr = ensure_exact_length::(typ, val)?; + let nanoseconds = i64::from_be_bytes(*arr); + + // Valid values are in the range 0 to 86399999999999 + if !(0..=86399999999999).contains(&nanoseconds) { + return Err(mk_deser_err::( + typ, + BuiltinDeserializationErrorKind::ValueOverflow, + )); + } + + Ok(nanoseconds) +} + +impl_emptiable_strict_type!( + CqlTime, + Time, + |typ: &'frame ColumnType, v: Option>| { + let nanoseconds = get_nanos_from_time_column::(typ, v)?; + + Ok(CqlTime(nanoseconds)) + } +); + +#[cfg(feature = "chrono")] +impl_emptiable_strict_type!( + chrono::NaiveTime, + Time, + |typ: &'frame ColumnType, v: Option>| { + let nanoseconds = get_nanos_from_time_column::(typ, v)?; + + let naive_time: chrono::NaiveTime = CqlTime(nanoseconds).try_into().map_err(|_| { + mk_deser_err::(typ, BuiltinDeserializationErrorKind::ValueOverflow) + })?; + Ok(naive_time) + } +); + +#[cfg(feature = "time")] +impl_emptiable_strict_type!( + time::Time, + Time, + |typ: &'frame ColumnType, v: Option>| { + let nanoseconds = get_nanos_from_time_column::(typ, v)?; + + let time: time::Time = CqlTime(nanoseconds).try_into().map_err(|_| { + mk_deser_err::(typ, BuiltinDeserializationErrorKind::ValueOverflow) + })?; + Ok(time) + } +); + +fn get_millis_from_timestamp_column( + typ: 
&ColumnType, + v: Option>, +) -> Result { + let val = ensure_not_null_slice::(typ, v)?; + let arr = ensure_exact_length::(typ, val)?; + let millis = i64::from_be_bytes(*arr); + + Ok(millis) +} + +impl_emptiable_strict_type!( + CqlTimestamp, + Timestamp, + |typ: &'frame ColumnType, v: Option>| { + let millis = get_millis_from_timestamp_column::(typ, v)?; + Ok(CqlTimestamp(millis)) + } +); + +#[cfg(feature = "chrono")] +impl_emptiable_strict_type!( + chrono::DateTime, + Timestamp, + |typ: &'frame ColumnType, v: Option>| { + use chrono::TimeZone as _; + + let millis = get_millis_from_timestamp_column::(typ, v)?; + match chrono::Utc.timestamp_millis_opt(millis) { + chrono::LocalResult::Single(datetime) => Ok(datetime), + _ => Err(mk_deser_err::( + typ, + BuiltinDeserializationErrorKind::ValueOverflow, + )), + } + } +); + +#[cfg(feature = "time")] +impl_emptiable_strict_type!( + time::OffsetDateTime, + Timestamp, + |typ: &'frame ColumnType, v: Option>| { + let millis = get_millis_from_timestamp_column::(typ, v)?; + time::OffsetDateTime::from_unix_timestamp_nanos(millis as i128 * 1_000_000) + .map_err(|_| mk_deser_err::(typ, BuiltinDeserializationErrorKind::ValueOverflow)) + } +); + +// inet + +impl_emptiable_strict_type!( + IpAddr, + Inet, + |typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null_slice::(typ, v)?; + if let Ok(ipv4) = <[u8; 4]>::try_from(val) { + Ok(IpAddr::from(ipv4)) + } else if let Ok(ipv6) = <[u8; 16]>::try_from(val) { + Ok(IpAddr::from(ipv6)) + } else { + Err(mk_deser_err::( + typ, + BuiltinDeserializationErrorKind::BadInetLength { got: val.len() }, + )) + } + } +); + +// uuid + +impl_emptiable_strict_type!( + Uuid, + Uuid, + |typ: &'frame ColumnType, v: Option>| { + let val = ensure_not_null_slice::(typ, v)?; + let arr = ensure_exact_length::(typ, val)?; + let i = u128::from_be_bytes(*arr); + Ok(uuid::Uuid::from_u128(i)) + } +); + +impl_emptiable_strict_type!( + CqlTimeuuid, + Timeuuid, + |typ: &'frame ColumnType, v: Option>| { + let 
val = ensure_not_null_slice::(typ, v)?; + let arr = ensure_exact_length::(typ, val)?; + let i = u128::from_be_bytes(*arr); + Ok(CqlTimeuuid::from(uuid::Uuid::from_u128(i))) + } +); + +// secrecy +#[cfg(feature = "secret")] +impl<'frame, T> DeserializeValue<'frame> for secrecy::Secret +where + T: DeserializeValue<'frame> + secrecy::Zeroize, +{ + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + >::type_check(typ) + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + >::deserialize(typ, v).map(secrecy::Secret::new) + } +} + +// collections + +make_error_replace_rust_name!( + typck_error_replace_rust_name, + TypeCheckError, + BuiltinTypeCheckError +); + +make_error_replace_rust_name!( + deser_error_replace_rust_name, + DeserializationError, + BuiltinDeserializationError +); + +// lists and sets + +/// An iterator over either a CQL set or list. +pub struct ListlikeIterator<'frame, T> { + coll_typ: &'frame ColumnType, + elem_typ: &'frame ColumnType, + raw_iter: FixedLengthBytesSequenceIterator<'frame>, + phantom_data: std::marker::PhantomData, +} + +impl<'frame, T> ListlikeIterator<'frame, T> { + fn new( + coll_typ: &'frame ColumnType, + elem_typ: &'frame ColumnType, + count: usize, + slice: FrameSlice<'frame>, + ) -> Self { + Self { + coll_typ, + elem_typ, + raw_iter: FixedLengthBytesSequenceIterator::new(count, slice), + phantom_data: std::marker::PhantomData, + } + } +} + +impl<'frame, T> DeserializeValue<'frame> for ListlikeIterator<'frame, T> +where + T: DeserializeValue<'frame>, +{ + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + match typ { + ColumnType::List(el_t) | ColumnType::Set(el_t) => { + >::type_check(el_t).map_err(|err| { + mk_typck_err::( + typ, + SetOrListTypeCheckErrorKind::ElementTypeCheckFailed(err), + ) + }) + } + _ => Err(mk_typck_err::( + typ, + BuiltinTypeCheckErrorKind::SetOrListError( + SetOrListTypeCheckErrorKind::NotSetOrList, + ), + )), + } + } + + fn deserialize( + typ: 
&'frame ColumnType, + v: Option>, + ) -> Result { + let mut v = ensure_not_null_frame_slice::(typ, v)?; + let count = types::read_int_length(v.as_slice_mut()).map_err(|err| { + mk_deser_err::( + typ, + SetOrListDeserializationErrorKind::LengthDeserializationFailed( + DeserializationError::new(err), + ), + ) + })?; + let elem_typ = match typ { + ColumnType::List(elem_typ) | ColumnType::Set(elem_typ) => elem_typ, + _ => { + unreachable!("Typecheck should have prevented this scenario!") + } + }; + Ok(Self::new(typ, elem_typ, count, v)) + } +} + +impl<'frame, T> Iterator for ListlikeIterator<'frame, T> +where + T: DeserializeValue<'frame>, +{ + type Item = Result; + + fn next(&mut self) -> Option { + let raw = self.raw_iter.next()?.map_err(|err| { + mk_deser_err::( + self.coll_typ, + BuiltinDeserializationErrorKind::GenericParseError(err), + ) + }); + Some(raw.and_then(|raw| { + T::deserialize(self.elem_typ, raw).map_err(|err| { + mk_deser_err::( + self.coll_typ, + SetOrListDeserializationErrorKind::ElementDeserializationFailed(err), + ) + }) + })) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.raw_iter.size_hint() + } +} + +impl<'frame, T> DeserializeValue<'frame> for Vec +where + T: DeserializeValue<'frame>, +{ + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + // It makes sense for both Set and List to deserialize to Vec. + ListlikeIterator::<'frame, T>::type_check(typ) + .map_err(typck_error_replace_rust_name::) + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + ListlikeIterator::<'frame, T>::deserialize(typ, v) + .and_then(|it| it.collect::>()) + .map_err(deser_error_replace_rust_name::) + } +} + +impl<'frame, T> DeserializeValue<'frame> for BTreeSet +where + T: DeserializeValue<'frame> + Ord, +{ + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + // It only makes sense for Set to deserialize to BTreeSet. + // Deserializing List straight to BTreeSet would be lossy. 
+ match typ { + ColumnType::Set(el_t) => >::type_check(el_t) + .map_err(typck_error_replace_rust_name::), + _ => Err(mk_typck_err::( + typ, + SetOrListTypeCheckErrorKind::NotSet, + )), + } + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + ListlikeIterator::<'frame, T>::deserialize(typ, v) + .and_then(|it| it.collect::>()) + .map_err(deser_error_replace_rust_name::) + } +} + +impl<'frame, T, S> DeserializeValue<'frame> for HashSet +where + T: DeserializeValue<'frame> + Eq + Hash, + S: BuildHasher + Default + 'frame, +{ + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + // It only makes sense for Set to deserialize to HashSet. + // Deserializing List straight to HashSet would be lossy. + match typ { + ColumnType::Set(el_t) => >::type_check(el_t) + .map_err(typck_error_replace_rust_name::), + _ => Err(mk_typck_err::( + typ, + SetOrListTypeCheckErrorKind::NotSet, + )), + } + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + ListlikeIterator::<'frame, T>::deserialize(typ, v) + .and_then(|it| it.collect::>()) + .map_err(deser_error_replace_rust_name::) + } +} + +/// An iterator over a CQL map. 
+pub struct MapIterator<'frame, K, V> { + coll_typ: &'frame ColumnType, + k_typ: &'frame ColumnType, + v_typ: &'frame ColumnType, + raw_iter: FixedLengthBytesSequenceIterator<'frame>, + phantom_data_k: std::marker::PhantomData, + phantom_data_v: std::marker::PhantomData, +} + +impl<'frame, K, V> MapIterator<'frame, K, V> { + fn new( + coll_typ: &'frame ColumnType, + k_typ: &'frame ColumnType, + v_typ: &'frame ColumnType, + count: usize, + slice: FrameSlice<'frame>, + ) -> Self { + Self { + coll_typ, + k_typ, + v_typ, + raw_iter: FixedLengthBytesSequenceIterator::new(count, slice), + phantom_data_k: std::marker::PhantomData, + phantom_data_v: std::marker::PhantomData, + } + } +} + +impl<'frame, K, V> DeserializeValue<'frame> for MapIterator<'frame, K, V> +where + K: DeserializeValue<'frame>, + V: DeserializeValue<'frame>, +{ + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + match typ { + ColumnType::Map(k_t, v_t) => { + >::type_check(k_t).map_err(|err| { + mk_typck_err::(typ, MapTypeCheckErrorKind::KeyTypeCheckFailed(err)) + })?; + >::type_check(v_t).map_err(|err| { + mk_typck_err::(typ, MapTypeCheckErrorKind::ValueTypeCheckFailed(err)) + })?; + Ok(()) + } + _ => Err(mk_typck_err::(typ, MapTypeCheckErrorKind::NotMap)), + } + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + let mut v = ensure_not_null_frame_slice::(typ, v)?; + let count = types::read_int_length(v.as_slice_mut()).map_err(|err| { + mk_deser_err::( + typ, + MapDeserializationErrorKind::LengthDeserializationFailed( + DeserializationError::new(err), + ), + ) + })?; + let (k_typ, v_typ) = match typ { + ColumnType::Map(k_t, v_t) => (k_t, v_t), + _ => { + unreachable!("Typecheck should have prevented this scenario!") + } + }; + Ok(Self::new(typ, k_typ, v_typ, 2 * count, v)) + } +} + +impl<'frame, K, V> Iterator for MapIterator<'frame, K, V> +where + K: DeserializeValue<'frame>, + V: DeserializeValue<'frame>, +{ + type Item = Result<(K, V), 
DeserializationError>; + + fn next(&mut self) -> Option { + let raw_k = match self.raw_iter.next() { + Some(Ok(raw_k)) => raw_k, + Some(Err(err)) => { + return Some(Err(mk_deser_err::( + self.coll_typ, + BuiltinDeserializationErrorKind::GenericParseError(err), + ))); + } + None => return None, + }; + let raw_v = match self.raw_iter.next() { + Some(Ok(raw_v)) => raw_v, + Some(Err(err)) => { + return Some(Err(mk_deser_err::( + self.coll_typ, + BuiltinDeserializationErrorKind::GenericParseError(err), + ))); + } + None => return None, + }; + + let do_next = || -> Result<(K, V), DeserializationError> { + let k = K::deserialize(self.k_typ, raw_k).map_err(|err| { + mk_deser_err::( + self.coll_typ, + MapDeserializationErrorKind::KeyDeserializationFailed(err), + ) + })?; + let v = V::deserialize(self.v_typ, raw_v).map_err(|err| { + mk_deser_err::( + self.coll_typ, + MapDeserializationErrorKind::ValueDeserializationFailed(err), + ) + })?; + Ok((k, v)) + }; + Some(do_next()) + } + + fn size_hint(&self) -> (usize, Option) { + self.raw_iter.size_hint() + } +} + +impl<'frame, K, V> DeserializeValue<'frame> for BTreeMap +where + K: DeserializeValue<'frame> + Ord, + V: DeserializeValue<'frame>, +{ + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + MapIterator::<'frame, K, V>::type_check(typ).map_err(typck_error_replace_rust_name::) + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + MapIterator::<'frame, K, V>::deserialize(typ, v) + .and_then(|it| it.collect::>()) + .map_err(deser_error_replace_rust_name::) + } +} + +impl<'frame, K, V, S> DeserializeValue<'frame> for HashMap +where + K: DeserializeValue<'frame> + Eq + Hash, + V: DeserializeValue<'frame>, + S: BuildHasher + Default + 'frame, +{ + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + MapIterator::<'frame, K, V>::type_check(typ).map_err(typck_error_replace_rust_name::) + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + 
MapIterator::<'frame, K, V>::deserialize(typ, v) + .and_then(|it| it.collect::>()) + .map_err(deser_error_replace_rust_name::) + } +} + +// tuples + +// Implements tuple deserialization. +// The generated impl expects that the serialized data contains exactly the given amount of values. +macro_rules! impl_tuple { + ($($Ti:ident),*; $($idx:literal),*; $($idf:ident),*) => { + impl<'frame, $($Ti),*> DeserializeValue<'frame> for ($($Ti,)*) + where + $($Ti: DeserializeValue<'frame>),* + { + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + const TUPLE_LEN: usize = (&[$($idx),*] as &[i32]).len(); + let [$($idf),*] = ensure_tuple_type::<($($Ti,)*), TUPLE_LEN>(typ)?; + $( + <$Ti>::type_check($idf).map_err(|err| mk_typck_err::( + typ, + TupleTypeCheckErrorKind::FieldTypeCheckFailed { + position: $idx, + err, + } + ))?; + )* + Ok(()) + } + + fn deserialize(typ: &'frame ColumnType, v: Option>) -> Result { + const TUPLE_LEN: usize = (&[$($idx),*] as &[i32]).len(); + // Safety: we are allowed to assume that type_check() was already called. + let [$($idf),*] = ensure_tuple_type::<($($Ti,)*), TUPLE_LEN>(typ) + .expect("Type check should have prevented this!"); + + // Ignore the warning for the zero-sized tuple + #[allow(unused)] + let mut v = ensure_not_null_frame_slice::(typ, v)?; + let ret = ( + $( + v.read_cql_bytes() + .map_err(|err| DeserializationError::new(err)) + .and_then(|cql_bytes| <$Ti>::deserialize($idf, cql_bytes)) + .map_err(|err| mk_deser_err::( + typ, + TupleDeserializationErrorKind::FieldDeserializationFailed { + position: $idx, + err, + } + ) + )?, + )* + ); + Ok(ret) + } + } + } +} + +// Implements tuple deserialization for all tuple sizes up to predefined size. +// Accepts 3 lists, (see usage below the definition): +// - type parameters for the consecutive fields, +// - indices of the consecutive fields, +// - consecutive names for variables corresponding to each field. 
+// +// The idea is to recursively build prefixes of those lists (starting with an empty prefix) +// and for each prefix, implement deserialization for generic tuple represented by it. +// The < > brackets aid syntactically to separate the prefixes (positioned inside them) +// from the remaining suffixes (positioned beyond them). +macro_rules! impl_tuple_multiple { + // The entry point to the macro. + // Begins with implementing deserialization for (), then proceeds to the main recursive call. + ($($Ti:ident),*; $($idx:literal),*; $($idf:ident),*) => { + impl_tuple!(;;); + impl_tuple_multiple!( + $($Ti),* ; < > ; + $($idx),*; < > ; + $($idf),*; < > + ); + }; + + // The termination condition. No more fields given to extend the tuple with. + (;< $($Ti:ident,)* >;;< $($idx:literal,)* >;;< $($idf:ident,)* >) => {}; + + // The recursion. Upon each call, a new field is appended to the tuple + // and deserialization is implemented for it. + ( + $T_head:ident $(,$T_suffix:ident)*; < $($T_prefix:ident,)* > ; + $idx_head:literal $(,$idx_suffix:literal)*; < $($idx_prefix:literal,)* >; + $idf_head:ident $(,$idf_suffix:ident)* ; <$($idf_prefix:ident,)*> + ) => { + impl_tuple!( + $($T_prefix,)* $T_head; + $($idx_prefix, )* $idx_head; + $($idf_prefix, )* $idf_head + ); + impl_tuple_multiple!( + $($T_suffix),* ; < $($T_prefix,)* $T_head, > ; + $($idx_suffix),*; < $($idx_prefix, )* $idx_head, > ; + $($idf_suffix),*; < $($idf_prefix, )* $idf_head, > + ); + } +} + +pub(super) use impl_tuple_multiple; + +// Implements tuple deserialization for all tuple sizes up to 16. +impl_tuple_multiple!( + T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15; + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15; + t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15 +); + +// udts + +/// An iterator over fields of a User Defined Type. +/// +/// # Note +/// +/// A serialized UDT will generally have one value for each field, but it is +/// allowed to have fewer. 
This iterator differentiates null values +/// from non-existent values in the following way: +/// +/// - `None` - missing from the serialized form +/// - `Some(None)` - present, but null +/// - `Some(Some(...))` - non-null, present value +pub struct UdtIterator<'frame> { + all_fields: &'frame [(String, ColumnType)], + type_name: &'frame str, + keyspace: &'frame str, + remaining_fields: &'frame [(String, ColumnType)], + raw_iter: BytesSequenceIterator<'frame>, +} + +impl<'frame> UdtIterator<'frame> { + fn new( + fields: &'frame [(String, ColumnType)], + type_name: &'frame str, + keyspace: &'frame str, + slice: FrameSlice<'frame>, + ) -> Self { + Self { + all_fields: fields, + remaining_fields: fields, + type_name, + keyspace, + raw_iter: BytesSequenceIterator::new(slice), + } + } + + #[inline] + pub fn fields(&self) -> &'frame [(String, ColumnType)] { + self.remaining_fields + } +} + +impl<'frame> DeserializeValue<'frame> for UdtIterator<'frame> { + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + match typ { + ColumnType::UserDefinedType { .. } => Ok(()), + _ => Err(mk_typck_err::(typ, UdtTypeCheckErrorKind::NotUdt)), + } + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + let v = ensure_not_null_frame_slice::(typ, v)?; + let (fields, type_name, keyspace) = match typ { + ColumnType::UserDefinedType { + field_types, + type_name, + keyspace, + } => (field_types.as_ref(), type_name.as_ref(), keyspace.as_ref()), + _ => { + unreachable!("Typecheck should have prevented this scenario!") + } + }; + Ok(Self::new(fields, type_name, keyspace, v)) + } +} + +impl<'frame> Iterator for UdtIterator<'frame> { + type Item = ( + &'frame (String, ColumnType), + Result>>, DeserializationError>, + ); + + fn next(&mut self) -> Option { + // TODO: Should we fail when there are too many fields? 
+ let (head, fields) = self.remaining_fields.split_first()?; + self.remaining_fields = fields; + let raw_res = match self.raw_iter.next() { + // The field is there and it was parsed correctly + Some(Ok(raw)) => Ok(Some(raw)), + + // There were some bytes but they didn't parse as correct field value + Some(Err(err)) => Err(mk_deser_err::( + &ColumnType::UserDefinedType { + type_name: self.type_name.to_owned(), + keyspace: self.keyspace.to_owned(), + field_types: self.all_fields.to_owned(), + }, + BuiltinDeserializationErrorKind::GenericParseError(err), + )), + + // The field is just missing from the serialized form + None => Ok(None), + }; + Some((head, raw_res)) + } + + fn size_hint(&self) -> (usize, Option) { + self.raw_iter.size_hint() + } +} + +// Utilities + +fn ensure_not_null_frame_slice<'frame, T>( + typ: &ColumnType, + v: Option>, +) -> Result, DeserializationError> { + v.ok_or_else(|| mk_deser_err::(typ, BuiltinDeserializationErrorKind::ExpectedNonNull)) +} + +fn ensure_not_null_slice<'frame, T>( + typ: &ColumnType, + v: Option>, +) -> Result<&'frame [u8], DeserializationError> { + ensure_not_null_frame_slice::(typ, v).map(|frame_slice| frame_slice.as_slice()) +} + +fn ensure_not_null_owned( + typ: &ColumnType, + v: Option, +) -> Result { + ensure_not_null_frame_slice::(typ, v).map(|frame_slice| frame_slice.to_bytes()) +} + +fn ensure_exact_length<'frame, T, const SIZE: usize>( + typ: &ColumnType, + v: &'frame [u8], +) -> Result<&'frame [u8; SIZE], DeserializationError> { + v.try_into().map_err(|_| { + mk_deser_err::( + typ, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: SIZE, + got: v.len(), + }, + ) + }) +} + +fn ensure_tuple_type( + typ: &ColumnType, +) -> Result<&[ColumnType; SIZE], TypeCheckError> { + if let ColumnType::Tuple(typs_v) = typ { + typs_v.as_slice().try_into().map_err(|_| { + BuiltinTypeCheckErrorKind::TupleError(TupleTypeCheckErrorKind::WrongElementCount { + rust_type_el_count: SIZE, + cql_type_el_count: typs_v.len(), 
+ }) + }) + } else { + Err(BuiltinTypeCheckErrorKind::TupleError( + TupleTypeCheckErrorKind::NotTuple, + )) + } + .map_err(|kind| mk_typck_err::(typ, kind)) +} + +// Helper iterators + +/// Iterates over a sequence of `[bytes]` items from a frame subslice, expecting +/// a particular number of items. +/// +/// The iterator does not consider it to be an error if there are some bytes +/// remaining in the slice after parsing requested amount of items. +#[derive(Clone, Copy, Debug)] +pub struct FixedLengthBytesSequenceIterator<'frame> { + slice: FrameSlice<'frame>, + remaining: usize, +} + +impl<'frame> FixedLengthBytesSequenceIterator<'frame> { + fn new(count: usize, slice: FrameSlice<'frame>) -> Self { + Self { + slice, + remaining: count, + } + } +} + +impl<'frame> Iterator for FixedLengthBytesSequenceIterator<'frame> { + type Item = Result>, ParseError>; + + fn next(&mut self) -> Option { + self.remaining = self.remaining.checked_sub(1)?; + Some(self.slice.read_cql_bytes()) + } +} + +/// Iterates over a sequence of `[bytes]` items from a frame subslice. +/// +/// The `[bytes]` items are parsed until the end of subslice is reached. +#[derive(Clone, Copy, Debug)] +pub struct BytesSequenceIterator<'frame> { + slice: FrameSlice<'frame>, +} + +impl<'frame> BytesSequenceIterator<'frame> { + fn new(slice: FrameSlice<'frame>) -> Self { + Self { slice } + } +} + +impl<'frame> From> for BytesSequenceIterator<'frame> { + #[inline] + fn from(slice: FrameSlice<'frame>) -> Self { + Self::new(slice) + } +} + +impl<'frame> Iterator for BytesSequenceIterator<'frame> { + type Item = Result>, ParseError>; + + fn next(&mut self) -> Option { + if self.slice.as_slice().is_empty() { + None + } else { + Some(self.slice.read_cql_bytes()) + } + } +} + +// Error facilities + +/// Type checking of one of the built-in types failed. 
+#[derive(Debug, Error, Clone)] +#[error("Failed to type check Rust type {rust_name} against CQL type {cql_type:?}: {kind}")] +pub struct BuiltinTypeCheckError { + /// Name of the Rust type being deserialized. + pub rust_name: &'static str, + + /// The CQL type that the Rust type was being deserialized from. + pub cql_type: ColumnType, + + /// Detailed information about the failure. + pub kind: BuiltinTypeCheckErrorKind, +} + +fn mk_typck_err( + cql_type: &ColumnType, + kind: impl Into, +) -> TypeCheckError { + mk_typck_err_named(std::any::type_name::(), cql_type, kind) +} + +fn mk_typck_err_named( + name: &'static str, + cql_type: &ColumnType, + kind: impl Into, +) -> TypeCheckError { + TypeCheckError::new(BuiltinTypeCheckError { + rust_name: name, + cql_type: cql_type.clone(), + kind: kind.into(), + }) +} + +macro_rules! exact_type_check { + ($typ:ident, $($cql:tt),*) => { + match $typ { + $(ColumnType::$cql)|* => {}, + _ => return Err(mk_typck_err::( + $typ, + BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[$(ColumnType::$cql),*], + } + )) + } + }; +} +use exact_type_check; + +/// Describes why type checking some of the built-in types failed. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum BuiltinTypeCheckErrorKind { + /// Expected one from a list of particular types. + MismatchedType { + /// The list of types that the Rust type can deserialize from. + expected: &'static [ColumnType], + }, + + /// A type check failure specific to a CQL set or list. + SetOrListError(SetOrListTypeCheckErrorKind), + + /// A type check failure specific to a CQL map. + MapError(MapTypeCheckErrorKind), + + /// A type check failure specific to a CQL tuple. + TupleError(TupleTypeCheckErrorKind), + + /// A type check failure specific to a CQL UDT. 
+ UdtError(UdtTypeCheckErrorKind), +} + +impl From for BuiltinTypeCheckErrorKind { + #[inline] + fn from(value: SetOrListTypeCheckErrorKind) -> Self { + BuiltinTypeCheckErrorKind::SetOrListError(value) + } +} + +impl From for BuiltinTypeCheckErrorKind { + #[inline] + fn from(value: MapTypeCheckErrorKind) -> Self { + BuiltinTypeCheckErrorKind::MapError(value) + } +} + +impl From for BuiltinTypeCheckErrorKind { + #[inline] + fn from(value: TupleTypeCheckErrorKind) -> Self { + BuiltinTypeCheckErrorKind::TupleError(value) + } +} + +impl From for BuiltinTypeCheckErrorKind { + #[inline] + fn from(value: UdtTypeCheckErrorKind) -> Self { + BuiltinTypeCheckErrorKind::UdtError(value) + } +} + +impl Display for BuiltinTypeCheckErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BuiltinTypeCheckErrorKind::MismatchedType { expected } => { + write!(f, "expected one of the CQL types: {expected:?}") + } + BuiltinTypeCheckErrorKind::SetOrListError(err) => err.fmt(f), + BuiltinTypeCheckErrorKind::MapError(err) => err.fmt(f), + BuiltinTypeCheckErrorKind::TupleError(err) => err.fmt(f), + BuiltinTypeCheckErrorKind::UdtError(err) => err.fmt(f), + } + } +} + +/// Describes why type checking of a set or list type failed. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum SetOrListTypeCheckErrorKind { + /// The CQL type is neither a set nor a list. + NotSetOrList, + /// The CQL type is not a set. + NotSet, + /// Incompatible element types. 
+ ElementTypeCheckFailed(TypeCheckError), +} + +impl Display for SetOrListTypeCheckErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SetOrListTypeCheckErrorKind::NotSetOrList => { + f.write_str("the CQL type the Rust type was attempted to be type checked against was neither a set nor a list") + } + SetOrListTypeCheckErrorKind::NotSet => { + f.write_str("the CQL type the Rust type was attempted to be type checked against was not a set") + } + SetOrListTypeCheckErrorKind::ElementTypeCheckFailed(err) => { + write!(f, "the set or list element types between the CQL type and the Rust type failed to type check against each other: {}", err) + } + } + } +} + +/// Describes why type checking of a map type failed. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum MapTypeCheckErrorKind { + /// The CQL type is not a map. + NotMap, + /// Incompatible key types. + KeyTypeCheckFailed(TypeCheckError), + /// Incompatible value types. + ValueTypeCheckFailed(TypeCheckError), +} + +impl Display for MapTypeCheckErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + MapTypeCheckErrorKind::NotMap => { + f.write_str("the CQL type the Rust type was attempted to be type checked against was neither a map") + } + MapTypeCheckErrorKind::KeyTypeCheckFailed(err) => { + write!(f, "the map key types between the CQL type and the Rust type failed to type check against each other: {}", err) + }, + MapTypeCheckErrorKind::ValueTypeCheckFailed(err) => { + write!(f, "the map value types between the CQL type and the Rust type failed to type check against each other: {}", err) + }, + } + } +} + +/// Describes why type checking of a tuple failed. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum TupleTypeCheckErrorKind { + /// The CQL type is not a tuple. + NotTuple, + + /// The tuple has the wrong element count. + WrongElementCount { + /// The number of elements that the Rust tuple has. 
+ rust_type_el_count: usize, + + /// The number of elements that the CQL tuple type has. + cql_type_el_count: usize, + }, + + /// The CQL type and the Rust type of a tuple field failed to type check against each other. + FieldTypeCheckFailed { + /// The index of the field whose type check failed. + position: usize, + + /// The type check error that occurred. + err: TypeCheckError, + }, +} + +impl Display for TupleTypeCheckErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TupleTypeCheckErrorKind::NotTuple => write!( + f, + "the CQL type the tuple was attempted to be serialized to is not a tuple" + ), + TupleTypeCheckErrorKind::WrongElementCount { + rust_type_el_count, + cql_type_el_count, + } => write!( + f, + "wrong tuple element count: CQL type has {cql_type_el_count}, the Rust tuple has {rust_type_el_count}" + ), + + TupleTypeCheckErrorKind::FieldTypeCheckFailed { position, err } => write!( + f, + "the CQL type and the Rust type of the tuple field {} failed to type check against each other: {}", + position, + err + ) + } + } +} + +/// Describes why type checking of a user defined type failed. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum UdtTypeCheckErrorKind { + /// The CQL type is not a user defined type. + NotUdt, +} + +impl Display for UdtTypeCheckErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + UdtTypeCheckErrorKind::NotUdt => write!( + f, + "the CQL type the Rust type was attempted to be type checked against is not a UDT" + ), + } + } +} + +/// Deserialization of one of the built-in types failed. +#[derive(Debug, Error)] +#[error("Failed to deserialize Rust type {rust_name} from CQL type {cql_type:?}: {kind}")] +pub struct BuiltinDeserializationError { + /// Name of the Rust type being deserialized. + pub rust_name: &'static str, + + /// The CQL type that the Rust type was being deserialized from. 
+ pub cql_type: ColumnType, + + /// Detailed information about the failure. + pub kind: BuiltinDeserializationErrorKind, +} + +fn mk_deser_err( + cql_type: &ColumnType, + kind: impl Into, +) -> DeserializationError { + mk_deser_err_named(std::any::type_name::(), cql_type, kind) +} + +fn mk_deser_err_named( + name: &'static str, + cql_type: &ColumnType, + kind: impl Into, +) -> DeserializationError { + DeserializationError::new(BuiltinDeserializationError { + rust_name: name, + cql_type: cql_type.clone(), + kind: kind.into(), + }) +} + +/// Describes why deserialization of some of the built-in types failed. +#[derive(Debug)] +#[non_exhaustive] +pub enum BuiltinDeserializationErrorKind { + /// A generic deserialization failure - legacy error type. + GenericParseError(ParseError), + + /// Expected non-null value, got null. + ExpectedNonNull, + + /// The length of read value in bytes is different than expected for the Rust type. + ByteLengthMismatch { expected: usize, got: usize }, + + /// Expected valid ASCII string. + ExpectedAscii, + + /// Invalid UTF-8 string. + InvalidUtf8(std::str::Utf8Error), + + /// The read value is out of range supported by the Rust type. + // TODO: consider storing additional info here (what exactly did not fit and why) + ValueOverflow, + + /// The length of read value in bytes is not suitable for IP address. + BadInetLength { got: usize }, + + /// A deserialization failure specific to a CQL set or list. + SetOrListError(SetOrListDeserializationErrorKind), + + /// A deserialization failure specific to a CQL map. + MapError(MapDeserializationErrorKind), + + /// A deserialization failure specific to a CQL tuple. 
+ TupleError(TupleDeserializationErrorKind), +} + +impl Display for BuiltinDeserializationErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BuiltinDeserializationErrorKind::GenericParseError(err) => err.fmt(f), + BuiltinDeserializationErrorKind::ExpectedNonNull => { + f.write_str("expected a non-null value, got null") + } + BuiltinDeserializationErrorKind::ByteLengthMismatch { expected, got } => write!( + f, + "the CQL type requires {} bytes, but got {}", + expected, got, + ), + BuiltinDeserializationErrorKind::ExpectedAscii => { + f.write_str("expected a valid ASCII string") + } + BuiltinDeserializationErrorKind::InvalidUtf8(err) => err.fmt(f), + BuiltinDeserializationErrorKind::ValueOverflow => { + // TODO: consider storing Arc of the offending value + // inside this variant for debug purposes. + f.write_str("read value is out of representable range") + } + BuiltinDeserializationErrorKind::BadInetLength { got } => write!( + f, + "the length of read value in bytes ({got}) is not suitable for IP address; expected 4 or 16" + ), + BuiltinDeserializationErrorKind::SetOrListError(err) => err.fmt(f), + BuiltinDeserializationErrorKind::MapError(err) => err.fmt(f), + BuiltinDeserializationErrorKind::TupleError(err) => err.fmt(f), + } + } +} + +/// Describes why deserialization of a set or list type failed. +#[derive(Debug)] +#[non_exhaustive] +pub enum SetOrListDeserializationErrorKind { + /// Failed to deserialize set or list's length. + LengthDeserializationFailed(DeserializationError), + + /// One of the elements of the set/list failed to deserialize. 
+ ElementDeserializationFailed(DeserializationError), +} + +impl Display for SetOrListDeserializationErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SetOrListDeserializationErrorKind::LengthDeserializationFailed(err) => { + write!(f, "failed to deserialize set or list's length: {}", err) + } + SetOrListDeserializationErrorKind::ElementDeserializationFailed(err) => { + write!(f, "failed to deserialize one of the elements: {}", err) + } + } + } +} + +impl From for BuiltinDeserializationErrorKind { + #[inline] + fn from(err: SetOrListDeserializationErrorKind) -> Self { + Self::SetOrListError(err) + } +} + +/// Describes why deserialization of a map type failed. +#[derive(Debug)] +#[non_exhaustive] +pub enum MapDeserializationErrorKind { + /// Failed to deserialize map's length. + LengthDeserializationFailed(DeserializationError), + + /// One of the keys in the map failed to deserialize. + KeyDeserializationFailed(DeserializationError), + + /// One of the values in the map failed to deserialize. + ValueDeserializationFailed(DeserializationError), +} + +impl Display for MapDeserializationErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + MapDeserializationErrorKind::LengthDeserializationFailed(err) => { + write!(f, "failed to deserialize map's length: {}", err) + } + MapDeserializationErrorKind::KeyDeserializationFailed(err) => { + write!(f, "failed to deserialize one of the keys: {}", err) + } + MapDeserializationErrorKind::ValueDeserializationFailed(err) => { + write!(f, "failed to deserialize one of the values: {}", err) + } + } + } +} + +impl From for BuiltinDeserializationErrorKind { + fn from(err: MapDeserializationErrorKind) -> Self { + Self::MapError(err) + } +} + +/// Describes why deserialization of a tuple failed. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum TupleDeserializationErrorKind { + /// One of the tuple fields failed to deserialize. 
+ FieldDeserializationFailed { + /// Index of the tuple field that failed to deserialize. + position: usize, + + /// The error that caused the tuple field deserialization to fail. + err: DeserializationError, + }, +} + +impl Display for TupleDeserializationErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TupleDeserializationErrorKind::FieldDeserializationFailed { + position: index, + err, + } => { + write!(f, "field no. {index} failed to deserialize: {err}") + } + } + } +} + +impl From for BuiltinDeserializationErrorKind { + fn from(err: TupleDeserializationErrorKind) -> Self { + Self::TupleError(err) + } +} + +#[cfg(test)] +pub(super) mod tests { + use assert_matches::assert_matches; + use bytes::{BufMut, Bytes, BytesMut}; + use uuid::Uuid; + + use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; + use std::fmt::Debug; + use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; + + use crate::frame::response::result::{ColumnType, CqlValue}; + use crate::frame::value::{ + Counter, CqlDate, CqlDecimal, CqlDuration, CqlTime, CqlTimestamp, CqlTimeuuid, CqlVarint, + }; + use crate::types::deserialize::value::{ + TupleDeserializationErrorKind, TupleTypeCheckErrorKind, + }; + use crate::types::deserialize::{DeserializationError, FrameSlice, TypeCheckError}; + use crate::types::serialize::value::SerializeValue; + use crate::types::serialize::CellWriter; + + use super::{ + mk_deser_err, BuiltinDeserializationError, BuiltinDeserializationErrorKind, + BuiltinTypeCheckError, BuiltinTypeCheckErrorKind, DeserializeValue, ListlikeIterator, + MapDeserializationErrorKind, MapIterator, MapTypeCheckErrorKind, MaybeEmpty, + SetOrListDeserializationErrorKind, SetOrListTypeCheckErrorKind, + }; + + #[test] + fn test_deserialize_bytes() { + const ORIGINAL_BYTES: &[u8] = &[1, 5, 2, 4, 3]; + + let bytes = make_bytes(ORIGINAL_BYTES); + + let decoded_slice = deserialize::<&[u8]>(&ColumnType::Blob, &bytes).unwrap(); + let decoded_vec = 
deserialize::>(&ColumnType::Blob, &bytes).unwrap(); + let decoded_bytes = deserialize::(&ColumnType::Blob, &bytes).unwrap(); + + assert_eq!(decoded_slice, ORIGINAL_BYTES); + assert_eq!(decoded_vec, ORIGINAL_BYTES); + assert_eq!(decoded_bytes, ORIGINAL_BYTES); + + // ser/de identity + + // Nonempty blob + assert_ser_de_identity(&ColumnType::Blob, &ORIGINAL_BYTES, &mut Bytes::new()); + + // Empty blob + assert_ser_de_identity(&ColumnType::Blob, &(&[] as &[u8]), &mut Bytes::new()); + } + + #[test] + fn test_deserialize_ascii() { + const ASCII_TEXT: &str = "The quick brown fox jumps over the lazy dog"; + + let ascii = make_bytes(ASCII_TEXT.as_bytes()); + + for typ in [ColumnType::Ascii, ColumnType::Text].iter() { + let decoded_str = deserialize::<&str>(typ, &ascii).unwrap(); + let decoded_string = deserialize::(typ, &ascii).unwrap(); + + assert_eq!(decoded_str, ASCII_TEXT); + assert_eq!(decoded_string, ASCII_TEXT); + + // ser/de identity + + // Empty string + assert_ser_de_identity(typ, &"", &mut Bytes::new()); + assert_ser_de_identity(typ, &"".to_owned(), &mut Bytes::new()); + + // Nonempty string + assert_ser_de_identity(typ, &ASCII_TEXT, &mut Bytes::new()); + assert_ser_de_identity(typ, &ASCII_TEXT.to_owned(), &mut Bytes::new()); + } + } + + #[test] + fn test_deserialize_text() { + const UNICODE_TEXT: &str = "Zażółć gęślą jaźń"; + + let unicode = make_bytes(UNICODE_TEXT.as_bytes()); + + // Should fail because it's not an ASCII string + deserialize::<&str>(&ColumnType::Ascii, &unicode).unwrap_err(); + deserialize::(&ColumnType::Ascii, &unicode).unwrap_err(); + + let decoded_text_str = deserialize::<&str>(&ColumnType::Text, &unicode).unwrap(); + let decoded_text_string = deserialize::(&ColumnType::Text, &unicode).unwrap(); + assert_eq!(decoded_text_str, UNICODE_TEXT); + assert_eq!(decoded_text_string, UNICODE_TEXT); + + // ser/de identity + + assert_ser_de_identity(&ColumnType::Text, &UNICODE_TEXT, &mut Bytes::new()); + assert_ser_de_identity( + &ColumnType::Text, + 
&UNICODE_TEXT.to_owned(), + &mut Bytes::new(), + ); + } + + #[test] + fn test_integral() { + let tinyint = make_bytes(&[0x01]); + let decoded_tinyint = deserialize::(&ColumnType::TinyInt, &tinyint).unwrap(); + assert_eq!(decoded_tinyint, 0x01); + + let smallint = make_bytes(&[0x01, 0x02]); + let decoded_smallint = deserialize::(&ColumnType::SmallInt, &smallint).unwrap(); + assert_eq!(decoded_smallint, 0x0102); + + let int = make_bytes(&[0x01, 0x02, 0x03, 0x04]); + let decoded_int = deserialize::(&ColumnType::Int, &int).unwrap(); + assert_eq!(decoded_int, 0x01020304); + + let bigint = make_bytes(&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]); + let decoded_bigint = deserialize::(&ColumnType::BigInt, &bigint).unwrap(); + assert_eq!(decoded_bigint, 0x0102030405060708); + + // ser/de identity + assert_ser_de_identity(&ColumnType::TinyInt, &42_i8, &mut Bytes::new()); + assert_ser_de_identity(&ColumnType::SmallInt, &2137_i16, &mut Bytes::new()); + assert_ser_de_identity(&ColumnType::Int, &21372137_i32, &mut Bytes::new()); + assert_ser_de_identity(&ColumnType::BigInt, &0_i64, &mut Bytes::new()); + } + + #[test] + fn test_bool() { + for boolean in [true, false] { + let boolean_bytes = make_bytes(&[boolean as u8]); + let decoded_bool = deserialize::(&ColumnType::Boolean, &boolean_bytes).unwrap(); + assert_eq!(decoded_bool, boolean); + + // ser/de identity + assert_ser_de_identity(&ColumnType::Boolean, &boolean, &mut Bytes::new()); + } + } + + #[test] + fn test_floating_point() { + let float = make_bytes(&[63, 0, 0, 0]); + let decoded_float = deserialize::(&ColumnType::Float, &float).unwrap(); + assert_eq!(decoded_float, 0.5); + + let double = make_bytes(&[64, 0, 0, 0, 0, 0, 0, 0]); + let decoded_double = deserialize::(&ColumnType::Double, &double).unwrap(); + assert_eq!(decoded_double, 2.0); + + // ser/de identity + assert_ser_de_identity(&ColumnType::Float, &21.37_f32, &mut Bytes::new()); + assert_ser_de_identity(&ColumnType::Double, &2137.2137_f64, &mut Bytes::new()); 
+ } + + #[test] + fn test_varlen_numbers() { + // varint + assert_ser_de_identity( + &ColumnType::Varint, + &CqlVarint::from_signed_bytes_be_slice(b"Ala ma kota"), + &mut Bytes::new(), + ); + + #[cfg(feature = "num-bigint-03")] + assert_ser_de_identity( + &ColumnType::Varint, + &num_bigint_03::BigInt::from_signed_bytes_be(b"Kot ma Ale"), + &mut Bytes::new(), + ); + + #[cfg(feature = "num-bigint-04")] + assert_ser_de_identity( + &ColumnType::Varint, + &num_bigint_04::BigInt::from_signed_bytes_be(b"Kot ma Ale"), + &mut Bytes::new(), + ); + + // decimal + assert_ser_de_identity( + &ColumnType::Decimal, + &CqlDecimal::from_signed_be_bytes_slice_and_exponent(b"Ala ma kota", 42), + &mut Bytes::new(), + ); + + #[cfg(feature = "bigdecimal-04")] + assert_ser_de_identity( + &ColumnType::Decimal, + &bigdecimal_04::BigDecimal::new( + bigdecimal_04::num_bigint::BigInt::from_signed_bytes_be(b"Ala ma kota"), + 42, + ), + &mut Bytes::new(), + ); + } + + #[test] + fn test_date_time_types() { + // duration + assert_ser_de_identity( + &ColumnType::Duration, + &CqlDuration { + months: 21, + days: 37, + nanoseconds: 42, + }, + &mut Bytes::new(), + ); + + // date + assert_ser_de_identity(&ColumnType::Date, &CqlDate(0xbeaf), &mut Bytes::new()); + + #[cfg(feature = "chrono")] + assert_ser_de_identity( + &ColumnType::Date, + &chrono::NaiveDate::from_yo_opt(1999, 99).unwrap(), + &mut Bytes::new(), + ); + + #[cfg(feature = "time")] + assert_ser_de_identity( + &ColumnType::Date, + &time::Date::from_ordinal_date(1999, 99).unwrap(), + &mut Bytes::new(), + ); + + // time + assert_ser_de_identity(&ColumnType::Time, &CqlTime(0xdeed), &mut Bytes::new()); + + #[cfg(feature = "chrono")] + assert_ser_de_identity( + &ColumnType::Time, + &chrono::NaiveTime::from_hms_micro_opt(21, 37, 21, 37).unwrap(), + &mut Bytes::new(), + ); + + #[cfg(feature = "time")] + assert_ser_de_identity( + &ColumnType::Time, + &time::Time::from_hms_micro(21, 37, 21, 37).unwrap(), + &mut Bytes::new(), + ); + + // timestamp + 
assert_ser_de_identity( + &ColumnType::Timestamp, + &CqlTimestamp(0xceed), + &mut Bytes::new(), + ); + + #[cfg(feature = "chrono")] + assert_ser_de_identity( + &ColumnType::Timestamp, + &chrono::DateTime::::from_timestamp_millis(0xdead_cafe_deaf).unwrap(), + &mut Bytes::new(), + ); + + #[cfg(feature = "time")] + assert_ser_de_identity( + &ColumnType::Timestamp, + &time::OffsetDateTime::from_unix_timestamp(0xdead_cafe).unwrap(), + &mut Bytes::new(), + ); + } + + #[test] + fn test_inet() { + assert_ser_de_identity( + &ColumnType::Inet, + &IpAddr::V4(Ipv4Addr::BROADCAST), + &mut Bytes::new(), + ); + + assert_ser_de_identity( + &ColumnType::Inet, + &IpAddr::V6(Ipv6Addr::LOCALHOST), + &mut Bytes::new(), + ); + } + + #[test] + fn test_uuid() { + assert_ser_de_identity( + &ColumnType::Uuid, + &Uuid::from_u128(0xdead_cafe_deaf_feed_beaf_bead), + &mut Bytes::new(), + ); + + assert_ser_de_identity( + &ColumnType::Timeuuid, + &CqlTimeuuid::from_u128(0xdead_cafe_deaf_feed_beaf_bead), + &mut Bytes::new(), + ); + } + + #[test] + fn test_null_and_empty() { + // non-nullable emptiable deserialization, non-empty value + let int = make_bytes(&[21, 37, 0, 0]); + let decoded_int = deserialize::>(&ColumnType::Int, &int).unwrap(); + assert_eq!(decoded_int, MaybeEmpty::Value((21 << 24) + (37 << 16))); + + // non-nullable emptiable deserialization, empty value + let int = make_bytes(&[]); + let decoded_int = deserialize::>(&ColumnType::Int, &int).unwrap(); + assert_eq!(decoded_int, MaybeEmpty::Empty); + + // nullable non-emptiable deserialization, non-null value + let int = make_bytes(&[21, 37, 0, 0]); + let decoded_int = deserialize::>(&ColumnType::Int, &int).unwrap(); + assert_eq!(decoded_int, Some((21 << 24) + (37 << 16))); + + // nullable non-emptiable deserialization, null value + let int = make_null(); + let decoded_int = deserialize::>(&ColumnType::Int, &int).unwrap(); + assert_eq!(decoded_int, None); + + // nullable emptiable deserialization, non-null empty value + let int = 
make_bytes(&[]); + let decoded_int = deserialize::>>(&ColumnType::Int, &int).unwrap(); + assert_eq!(decoded_int, Some(MaybeEmpty::Empty)); + + // ser/de identity + assert_ser_de_identity(&ColumnType::Int, &Some(12321_i32), &mut Bytes::new()); + assert_ser_de_identity(&ColumnType::Double, &None::, &mut Bytes::new()); + assert_ser_de_identity( + &ColumnType::Set(Box::new(ColumnType::Ascii)), + &None::>, + &mut Bytes::new(), + ); + } + + #[test] + fn test_maybe_empty() { + let empty = make_bytes(&[]); + let decoded_empty = deserialize::>(&ColumnType::TinyInt, &empty).unwrap(); + assert_eq!(decoded_empty, MaybeEmpty::Empty); + + let non_empty = make_bytes(&[0x01]); + let decoded_non_empty = + deserialize::>(&ColumnType::TinyInt, &non_empty).unwrap(); + assert_eq!(decoded_non_empty, MaybeEmpty::Value(0x01)); + } + + #[test] + fn test_cql_value() { + assert_ser_de_identity( + &ColumnType::Counter, + &CqlValue::Counter(Counter(765)), + &mut Bytes::new(), + ); + + assert_ser_de_identity( + &ColumnType::Timestamp, + &CqlValue::Timestamp(CqlTimestamp(2136)), + &mut Bytes::new(), + ); + + assert_ser_de_identity(&ColumnType::Boolean, &CqlValue::Empty, &mut Bytes::new()); + + assert_ser_de_identity( + &ColumnType::Text, + &CqlValue::Text("kremówki".to_owned()), + &mut Bytes::new(), + ); + assert_ser_de_identity( + &ColumnType::Ascii, + &CqlValue::Ascii("kremowy".to_owned()), + &mut Bytes::new(), + ); + + assert_ser_de_identity( + &ColumnType::Set(Box::new(ColumnType::Text)), + &CqlValue::Set(vec![CqlValue::Text("Ala ma kota".to_owned())]), + &mut Bytes::new(), + ); + } + + #[test] + fn test_list_and_set() { + let mut collection_contents = BytesMut::new(); + collection_contents.put_i32(3); + append_bytes(&mut collection_contents, "quick".as_bytes()); + append_bytes(&mut collection_contents, "brown".as_bytes()); + append_bytes(&mut collection_contents, "fox".as_bytes()); + + let collection = make_bytes(&collection_contents); + + let list_typ = 
ColumnType::List(Box::new(ColumnType::Ascii)); + let set_typ = ColumnType::Set(Box::new(ColumnType::Ascii)); + + // iterator + let mut iter = deserialize::>(&list_typ, &collection).unwrap(); + assert_eq!(iter.next().transpose().unwrap(), Some("quick")); + assert_eq!(iter.next().transpose().unwrap(), Some("brown")); + assert_eq!(iter.next().transpose().unwrap(), Some("fox")); + assert_eq!(iter.next().transpose().unwrap(), None); + + let expected_vec_str = vec!["quick", "brown", "fox"]; + let expected_vec_string = vec!["quick".to_string(), "brown".to_string(), "fox".to_string()]; + + // list + let decoded_vec_str = deserialize::>(&list_typ, &collection).unwrap(); + let decoded_vec_string = deserialize::>(&list_typ, &collection).unwrap(); + assert_eq!(decoded_vec_str, expected_vec_str); + assert_eq!(decoded_vec_string, expected_vec_string); + + // hash set + let decoded_hash_str = deserialize::>(&set_typ, &collection).unwrap(); + let decoded_hash_string = deserialize::>(&set_typ, &collection).unwrap(); + assert_eq!( + decoded_hash_str, + expected_vec_str.clone().into_iter().collect(), + ); + assert_eq!( + decoded_hash_string, + expected_vec_string.clone().into_iter().collect(), + ); + + // btree set + let decoded_btree_str = deserialize::>(&set_typ, &collection).unwrap(); + let decoded_btree_string = deserialize::>(&set_typ, &collection).unwrap(); + assert_eq!( + decoded_btree_str, + expected_vec_str.clone().into_iter().collect(), + ); + assert_eq!( + decoded_btree_string, + expected_vec_string.into_iter().collect(), + ); + + // ser/de identity + assert_ser_de_identity(&list_typ, &vec!["qwik"], &mut Bytes::new()); + assert_ser_de_identity(&set_typ, &vec!["qwik"], &mut Bytes::new()); + assert_ser_de_identity( + &set_typ, + &HashSet::<&str, std::collections::hash_map::RandomState>::from_iter(["qwik"]), + &mut Bytes::new(), + ); + assert_ser_de_identity( + &set_typ, + &BTreeSet::<&str>::from_iter(["qwik"]), + &mut Bytes::new(), + ); + } + + #[test] + fn test_map() { + 
let mut collection_contents = BytesMut::new(); + collection_contents.put_i32(3); + append_bytes(&mut collection_contents, &1i32.to_be_bytes()); + append_bytes(&mut collection_contents, "quick".as_bytes()); + append_bytes(&mut collection_contents, &2i32.to_be_bytes()); + append_bytes(&mut collection_contents, "brown".as_bytes()); + append_bytes(&mut collection_contents, &3i32.to_be_bytes()); + append_bytes(&mut collection_contents, "fox".as_bytes()); + + let collection = make_bytes(&collection_contents); + + let typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Ascii)); + + // iterator + let mut iter = deserialize::>(&typ, &collection).unwrap(); + assert_eq!(iter.next().transpose().unwrap(), Some((1, "quick"))); + assert_eq!(iter.next().transpose().unwrap(), Some((2, "brown"))); + assert_eq!(iter.next().transpose().unwrap(), Some((3, "fox"))); + assert_eq!(iter.next().transpose().unwrap(), None); + + let expected_str = vec![(1, "quick"), (2, "brown"), (3, "fox")]; + let expected_string = vec![ + (1, "quick".to_string()), + (2, "brown".to_string()), + (3, "fox".to_string()), + ]; + + // hash map + let decoded_hash_str = deserialize::>(&typ, &collection).unwrap(); + let decoded_hash_string = deserialize::>(&typ, &collection).unwrap(); + assert_eq!(decoded_hash_str, expected_str.clone().into_iter().collect()); + assert_eq!( + decoded_hash_string, + expected_string.clone().into_iter().collect(), + ); + + // btree map + let decoded_btree_str = deserialize::>(&typ, &collection).unwrap(); + let decoded_btree_string = deserialize::>(&typ, &collection).unwrap(); + assert_eq!( + decoded_btree_str, + expected_str.clone().into_iter().collect(), + ); + assert_eq!(decoded_btree_string, expected_string.into_iter().collect()); + + // ser/de identity + assert_ser_de_identity( + &typ, + &HashMap::::from_iter([( + -42, "qwik", + )]), + &mut Bytes::new(), + ); + assert_ser_de_identity( + &typ, + &BTreeMap::::from_iter([(-42, "qwik")]), + &mut Bytes::new(), + ); + } 
+ + #[test] + fn test_tuples() { + let mut tuple_contents = BytesMut::new(); + append_bytes(&mut tuple_contents, &42i32.to_be_bytes()); + append_bytes(&mut tuple_contents, "foo".as_bytes()); + append_null(&mut tuple_contents); + + let tuple = make_bytes(&tuple_contents); + + let typ = ColumnType::Tuple(vec![ColumnType::Int, ColumnType::Ascii, ColumnType::Uuid]); + + let tup = deserialize::<(i32, &str, Option)>(&typ, &tuple).unwrap(); + assert_eq!(tup, (42, "foo", None)); + + // ser/de identity + + // () does not implement SerializeValue, yet it does implement DeserializeValue. + // assert_ser_de_identity(&ColumnType::Tuple(vec![]), &(), &mut Bytes::new()); + + // nonempty, varied tuple + assert_ser_de_identity( + &ColumnType::Tuple(vec![ + ColumnType::List(Box::new(ColumnType::Boolean)), + ColumnType::BigInt, + ColumnType::Uuid, + ColumnType::Inet, + ]), + &( + vec![true, false, true], + 42_i64, + Uuid::from_u128(0xdead_cafe_deaf_feed_beaf_bead), + IpAddr::V6(Ipv6Addr::new(0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x10, 0x11)), + ), + &mut Bytes::new(), + ); + + // nested tuples + assert_ser_de_identity( + &ColumnType::Tuple(vec![ColumnType::Tuple(vec![ColumnType::Tuple(vec![ + ColumnType::Text, + ])])]), + &((("",),),), + &mut Bytes::new(), + ); + } + + #[test] + fn test_custom_type_parser() { + #[derive(Default, Debug, PartialEq, Eq)] + struct SwappedPair(B, A); + impl<'frame, A, B> DeserializeValue<'frame> for SwappedPair + where + A: DeserializeValue<'frame>, + B: DeserializeValue<'frame>, + { + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + <(B, A) as DeserializeValue<'frame>>::type_check(typ) + } + + fn deserialize( + typ: &'frame ColumnType, + v: Option>, + ) -> Result { + <(B, A) as DeserializeValue<'frame>>::deserialize(typ, v).map(|(b, a)| Self(b, a)) + } + } + + let mut tuple_contents = BytesMut::new(); + append_bytes(&mut tuple_contents, "foo".as_bytes()); + append_bytes(&mut tuple_contents, &42i32.to_be_bytes()); + let tuple = 
make_bytes(&tuple_contents); + + let typ = ColumnType::Tuple(vec![ColumnType::Ascii, ColumnType::Int]); + + let tup = deserialize::>(&typ, &tuple).unwrap(); + assert_eq!(tup, SwappedPair("foo", 42)); + } + + fn deserialize<'frame, T>( + typ: &'frame ColumnType, + bytes: &'frame Bytes, + ) -> Result + where + T: DeserializeValue<'frame>, + { + >::type_check(typ) + .map_err(|typecheck_err| DeserializationError(typecheck_err.0))?; + let mut frame_slice = FrameSlice::new(bytes); + let value = frame_slice.read_cql_bytes().map_err(|err| { + mk_deser_err::(typ, BuiltinDeserializationErrorKind::GenericParseError(err)) + })?; + >::deserialize(typ, value) + } + + fn make_bytes(cell: &[u8]) -> Bytes { + let mut b = BytesMut::new(); + append_bytes(&mut b, cell); + b.freeze() + } + + fn serialize(typ: &ColumnType, value: &dyn SerializeValue) -> Bytes { + let mut bytes = Bytes::new(); + serialize_to_buf(typ, value, &mut bytes); + bytes + } + + fn serialize_to_buf(typ: &ColumnType, value: &dyn SerializeValue, buf: &mut Bytes) { + let mut v = Vec::new(); + let writer = CellWriter::new(&mut v); + value.serialize(typ, writer).unwrap(); + *buf = v.into(); + } + + fn append_bytes(b: &mut impl BufMut, cell: &[u8]) { + b.put_i32(cell.len() as i32); + b.put_slice(cell); + } + + fn make_null() -> Bytes { + let mut b = BytesMut::new(); + append_null(&mut b); + b.freeze() + } + + fn append_null(b: &mut impl BufMut) { + b.put_i32(-1); + } + + fn assert_ser_de_identity<'f, T: SerializeValue + DeserializeValue<'f> + PartialEq + Debug>( + typ: &'f ColumnType, + v: &'f T, + buf: &'f mut Bytes, // `buf` must be passed as a reference from outside, because otherwise + // we cannot specify the lifetime for DeserializeValue. 
+ ) { + serialize_to_buf(typ, v, buf); + let deserialized = deserialize::(typ, buf).unwrap(); + assert_eq!(&deserialized, v); + } + + /* Errors checks */ + + #[track_caller] + pub(crate) fn get_typeck_err_inner<'a>( + err: &'a (dyn std::error::Error + 'static), + ) -> &'a BuiltinTypeCheckError { + match err.downcast_ref() { + Some(err) => err, + None => panic!("not a BuiltinTypeCheckError: {:?}", err), + } + } + + #[track_caller] + pub(crate) fn get_typeck_err(err: &DeserializationError) -> &BuiltinTypeCheckError { + get_typeck_err_inner(err.0.as_ref()) + } + + #[track_caller] + pub(crate) fn get_deser_err(err: &DeserializationError) -> &BuiltinDeserializationError { + match err.0.downcast_ref() { + Some(err) => err, + None => panic!("not a BuiltinDeserializationError: {:?}", err), + } + } + + macro_rules! assert_given_error { + ($get_err:ident, $bytes:expr, $DestT:ty, $cql_typ:expr, $kind:pat) => { + let cql_typ = $cql_typ.clone(); + let err = deserialize::<$DestT>(&cql_typ, $bytes).unwrap_err(); + let err = $get_err(&err); + assert_eq!(err.rust_name, std::any::type_name::<$DestT>()); + assert_eq!(err.cql_type, cql_typ); + assert_matches::assert_matches!(err.kind, $kind); + }; + } + + macro_rules! assert_type_check_error { + ($bytes:expr, $DestT:ty, $cql_typ:expr, $kind:pat) => { + assert_given_error!(get_typeck_err, $bytes, $DestT, $cql_typ, $kind); + }; + } + + macro_rules! assert_deser_error { + ($bytes:expr, $DestT:ty, $cql_typ:expr, $kind:pat) => { + assert_given_error!(get_deser_err, $bytes, $DestT, $cql_typ, $kind); + }; + } + + #[test] + fn test_native_errors() { + // Simple type mismatch + { + let v = 123_i32; + let bytes = serialize(&ColumnType::Int, &v); + + // Incompatible types render type check error. 
+ assert_type_check_error!( + &bytes, + f64, + ColumnType::Int, + super::BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::Double], + } + ); + + // ColumnType is said to be Double (8 bytes expected), but in reality the serialized form has 4 bytes only. + assert_deser_error!( + &bytes, + f64, + ColumnType::Double, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 8, + got: 4, + } + ); + + // ColumnType is said to be Float, but in reality Int was serialized. + // As these types have the same size, though, and every binary number in [0, 2^32) is a valid + // value for both of them, this always succeeds. + { + deserialize::(&ColumnType::Float, &bytes).unwrap(); + } + } + + // str (and also Uuid) are interesting because they accept two types. + { + let v = "Ala ma kota"; + let bytes = serialize(&ColumnType::Ascii, &v); + + assert_type_check_error!( + &bytes, + &str, + ColumnType::Double, + BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::Ascii, ColumnType::Text], + } + ); + + // ColumnType is said to be BigInt (8 bytes expected), but in reality the serialized form + // (the string) has 11 bytes. + assert_deser_error!( + &bytes, + i64, + ColumnType::BigInt, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 8, + got: 11, // str len + } + ); + } + { + // -126 (byte 0x82) is neither valid ASCII nor valid UTF-8 on its own. 
+ let v = -126_i8; + let bytes = serialize(&ColumnType::TinyInt, &v); + + assert_deser_error!( + &bytes, + &str, + ColumnType::Ascii, + BuiltinDeserializationErrorKind::ExpectedAscii + ); + + assert_deser_error!( + &bytes, + &str, + ColumnType::Text, + BuiltinDeserializationErrorKind::InvalidUtf8(_) + ); + } + } + + #[test] + fn test_set_or_list_errors() { + // Not a set or list + { + assert_type_check_error!( + &Bytes::new(), + Vec, + ColumnType::Float, + BuiltinTypeCheckErrorKind::SetOrListError( + SetOrListTypeCheckErrorKind::NotSetOrList + ) + ); + + // Type check of Rust set against CQL list must fail, because it would be lossy. + assert_type_check_error!( + &Bytes::new(), + BTreeSet, + ColumnType::List(Box::new(ColumnType::Int)), + BuiltinTypeCheckErrorKind::SetOrListError(SetOrListTypeCheckErrorKind::NotSet) + ); + } + + // Got null + { + type RustTyp = Vec; + let ser_typ = ColumnType::List(Box::new(ColumnType::Int)); + + let err = RustTyp::deserialize(&ser_typ, None).unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ser_typ); + assert_matches!(err.kind, BuiltinDeserializationErrorKind::ExpectedNonNull); + } + + // Bad element type + { + assert_type_check_error!( + &Bytes::new(), + Vec, + ColumnType::List(Box::new(ColumnType::Ascii)), + BuiltinTypeCheckErrorKind::SetOrListError( + SetOrListTypeCheckErrorKind::ElementTypeCheckFailed(_) + ) + ); + + let err = deserialize::>( + &ColumnType::List(Box::new(ColumnType::Varint)), + &Bytes::new(), + ) + .unwrap_err(); + let err = get_typeck_err(&err); + assert_eq!(err.rust_name, std::any::type_name::>()); + assert_eq!(err.cql_type, ColumnType::List(Box::new(ColumnType::Varint)),); + let BuiltinTypeCheckErrorKind::SetOrListError( + SetOrListTypeCheckErrorKind::ElementTypeCheckFailed(ref err), + ) = err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + let err = get_typeck_err_inner(err.0.as_ref()); + 
assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::Varint); + assert_matches!( + err.kind, + BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::BigInt, ColumnType::Counter] + } + ); + } + + { + let ser_typ = ColumnType::List(Box::new(ColumnType::Int)); + let v = vec![123_i32]; + let bytes = serialize(&ser_typ, &v); + + { + let err = deserialize::>( + &ColumnType::List(Box::new(ColumnType::BigInt)), + &bytes, + ) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::>()); + assert_eq!(err.cql_type, ColumnType::List(Box::new(ColumnType::BigInt)),); + let BuiltinDeserializationErrorKind::SetOrListError( + SetOrListDeserializationErrorKind::ElementDeserializationFailed(err), + ) = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + let err = get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::BigInt); + assert_matches!( + err.kind, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 8, + got: 4 + } + ); + } + } + } + + #[test] + fn test_map_errors() { + // Not a map + { + let ser_typ = ColumnType::Float; + let v = 2.12_f32; + let bytes = serialize(&ser_typ, &v); + + assert_type_check_error!( + &bytes, + HashMap, + ser_typ, + BuiltinTypeCheckErrorKind::MapError( + MapTypeCheckErrorKind::NotMap, + ) + ); + } + + // Got null + { + type RustTyp = HashMap; + let ser_typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Boolean)); + + let err = RustTyp::deserialize(&ser_typ, None).unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ser_typ); + assert_matches!(err.kind, BuiltinDeserializationErrorKind::ExpectedNonNull); + } + + // Key type mismatch + { + let err = deserialize::>( + &ColumnType::Map(Box::new(ColumnType::Varint), Box::new(ColumnType::Boolean)), + &Bytes::new(), + ) + 
.unwrap_err(); + let err = get_typeck_err(&err); + assert_eq!(err.rust_name, std::any::type_name::>()); + assert_eq!( + err.cql_type, + ColumnType::Map(Box::new(ColumnType::Varint), Box::new(ColumnType::Boolean)) + ); + let BuiltinTypeCheckErrorKind::MapError(MapTypeCheckErrorKind::KeyTypeCheckFailed( + ref err, + )) = err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::Varint); + assert_matches!( + err.kind, + BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::BigInt, ColumnType::Counter] + } + ); + } + + // Value type mismatch + { + let err = deserialize::>( + &ColumnType::Map(Box::new(ColumnType::BigInt), Box::new(ColumnType::Boolean)), + &Bytes::new(), + ) + .unwrap_err(); + let err = get_typeck_err(&err); + assert_eq!(err.rust_name, std::any::type_name::>()); + assert_eq!( + err.cql_type, + ColumnType::Map(Box::new(ColumnType::BigInt), Box::new(ColumnType::Boolean)) + ); + let BuiltinTypeCheckErrorKind::MapError(MapTypeCheckErrorKind::ValueTypeCheckFailed( + ref err, + )) = err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::<&str>()); + assert_eq!(err.cql_type, ColumnType::Boolean); + assert_matches!( + err.kind, + BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::Ascii, ColumnType::Text] + } + ); + } + + // Key length mismatch + { + let ser_typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Boolean)); + let v = HashMap::from([(42, false), (2137, true)]); + let bytes = serialize(&ser_typ, &v as &dyn SerializeValue); + + let err = deserialize::>( + &ColumnType::Map(Box::new(ColumnType::BigInt), Box::new(ColumnType::Boolean)), + &bytes, + ) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, 
std::any::type_name::>()); + assert_eq!( + err.cql_type, + ColumnType::Map(Box::new(ColumnType::BigInt), Box::new(ColumnType::Boolean)) + ); + let BuiltinDeserializationErrorKind::MapError( + MapDeserializationErrorKind::KeyDeserializationFailed(err), + ) = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + let err = get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::BigInt); + assert_matches!( + err.kind, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 8, + got: 4 + } + ); + } + + // Value length mismatch + { + let ser_typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Boolean)); + let v = HashMap::from([(42, false), (2137, true)]); + let bytes = serialize(&ser_typ, &v as &dyn SerializeValue); + + let err = deserialize::>( + &ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::SmallInt)), + &bytes, + ) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::>()); + assert_eq!( + err.cql_type, + ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::SmallInt)) + ); + let BuiltinDeserializationErrorKind::MapError( + MapDeserializationErrorKind::ValueDeserializationFailed(err), + ) = &err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + let err = get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::SmallInt); + assert_matches!( + err.kind, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 2, + got: 1 + } + ); + } + } + + #[test] + fn test_tuple_errors() { + // Not a tuple + { + assert_type_check_error!( + &Bytes::new(), + (i64,), + ColumnType::BigInt, + BuiltinTypeCheckErrorKind::TupleError(TupleTypeCheckErrorKind::NotTuple) + ); + } + // Wrong element count + { + assert_type_check_error!( + &Bytes::new(), + (i64,), + ColumnType::Tuple(vec![]), + 
BuiltinTypeCheckErrorKind::TupleError(TupleTypeCheckErrorKind::WrongElementCount { + rust_type_el_count: 1, + cql_type_el_count: 0, + }) + ); + + assert_type_check_error!( + &Bytes::new(), + (f32,), + ColumnType::Tuple(vec![ColumnType::Float, ColumnType::Float]), + BuiltinTypeCheckErrorKind::TupleError(TupleTypeCheckErrorKind::WrongElementCount { + rust_type_el_count: 1, + cql_type_el_count: 2, + }) + ); + } + + // Bad field type + { + { + let err = deserialize::<(i64,)>( + &ColumnType::Tuple(vec![ColumnType::SmallInt]), + &Bytes::new(), + ) + .unwrap_err(); + let err = get_typeck_err(&err); + assert_eq!(err.rust_name, std::any::type_name::<(i64,)>()); + assert_eq!(err.cql_type, ColumnType::Tuple(vec![ColumnType::SmallInt])); + let BuiltinTypeCheckErrorKind::TupleError( + TupleTypeCheckErrorKind::FieldTypeCheckFailed { ref err, position }, + ) = err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + assert_eq!(position, 0); + let err = get_typeck_err_inner(err.0.as_ref()); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::SmallInt); + assert_matches!( + err.kind, + BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[ColumnType::BigInt, ColumnType::Counter] + } + ); + } + } + + { + let ser_typ = ColumnType::Tuple(vec![ColumnType::Int, ColumnType::Float]); + let v = (123_i32, 123.123_f32); + let bytes = serialize(&ser_typ, &v); + + { + let err = deserialize::<(i32, f64)>( + &ColumnType::Tuple(vec![ColumnType::Int, ColumnType::Double]), + &bytes, + ) + .unwrap_err(); + let err = get_deser_err(&err); + assert_eq!(err.rust_name, std::any::type_name::<(i32, f64)>()); + assert_eq!( + err.cql_type, + ColumnType::Tuple(vec![ColumnType::Int, ColumnType::Double]) + ); + let BuiltinDeserializationErrorKind::TupleError( + TupleDeserializationErrorKind::FieldDeserializationFailed { + ref err, + position: index, + }, + ) = err.kind + else { + panic!("unexpected error kind: {}", err.kind) + }; + assert_eq!(index, 
1); + let err = get_deser_err(err); + assert_eq!(err.rust_name, std::any::type_name::()); + assert_eq!(err.cql_type, ColumnType::Double); + assert_matches!( + err.kind, + BuiltinDeserializationErrorKind::ByteLengthMismatch { + expected: 8, + got: 4 + } + ); + } + } + } + + #[test] + fn test_null_errors() { + let ser_typ = ColumnType::Map(Box::new(ColumnType::Int), Box::new(ColumnType::Boolean)); + let v = HashMap::from([(42, false), (2137, true)]); + let bytes = serialize(&ser_typ, &v as &dyn SerializeValue); + + deserialize::>(&ser_typ, &bytes).unwrap_err(); + } +} diff --git a/scylla-cql/src/types/mod.rs b/scylla-cql/src/types/mod.rs index ec9942a885..59bce9522d 100644 --- a/scylla-cql/src/types/mod.rs +++ b/scylla-cql/src/types/mod.rs @@ -1 +1,2 @@ +pub mod deserialize; pub mod serialize; diff --git a/scylla/src/lib.rs b/scylla/src/lib.rs index 1b46559698..e7b9afb7ee 100644 --- a/scylla/src/lib.rs +++ b/scylla/src/lib.rs @@ -126,7 +126,41 @@ pub mod frame { } } -pub use scylla_cql::types::serialize; +/// Serializing bound values of a query to be sent to the DB. +pub mod serialize { + pub use scylla_cql::types::serialize::*; +} + +/// Deserializing DB response containing CQL query results. +pub mod deserialize { + pub use scylla_cql::types::deserialize::{ + DeserializationError, DeserializeRow, DeserializeValue, FrameSlice, TypeCheckError, + }; + + /// Deserializing the whole query result contents. + pub mod result { + pub use scylla_cql::types::deserialize::result::{RowIterator, TypedRowIterator}; + } + + /// Deserializing a row of the query result. + pub mod row { + pub use scylla_cql::types::deserialize::row::{ + BuiltinDeserializationError, BuiltinDeserializationErrorKind, BuiltinTypeCheckError, + BuiltinTypeCheckErrorKind, ColumnIterator, RawColumn, + }; + } + + /// Deserializing a single CQL value from a column of the query result row. 
+ pub mod value { + pub use scylla_cql::types::deserialize::value::{ + BuiltinDeserializationError, BuiltinDeserializationErrorKind, BuiltinTypeCheckError, + BuiltinTypeCheckErrorKind, Emptiable, ListlikeIterator, MapDeserializationErrorKind, + MapIterator, MapTypeCheckErrorKind, MaybeEmpty, SetOrListDeserializationErrorKind, + SetOrListTypeCheckErrorKind, TupleDeserializationErrorKind, TupleTypeCheckErrorKind, + UdtIterator, UdtTypeCheckErrorKind, + }; + } +} pub mod authentication; #[cfg(feature = "cloud")]