diff --git a/scylla-cql/src/types/serialize/mod.rs b/scylla-cql/src/types/serialize/mod.rs index 0cda84e252..44b740c611 100644 --- a/scylla-cql/src/types/serialize/mod.rs +++ b/scylla-cql/src/types/serialize/mod.rs @@ -4,3 +4,426 @@ pub mod row; pub mod value; type SerializationError = Arc; + +/// An interface that facilitates writing values for a CQL query. +pub trait RowWriter { + type CellWriter<'a>: CellWriter + where + Self: 'a; + + /// Appends a new value to the sequence and returns an object that allows + /// to fill it in. + fn make_cell_writer(&mut self) -> Self::CellWriter<'_>; +} + +/// Represents a handle to a CQL value that needs to be written into. +/// +/// The writer can either be transformed into a ready value right away +/// (via [`set_null`](CellWriter::set_null), +/// [`set_unset`](CellWriter::set_unset) +/// or [`set_value`](CellWriter::set_value) or transformed into +/// the [`CellWriter::ValueBuilder`] in order to gradually initialize +/// the value when the contents are not available straight away. +/// +/// After the value is fully initialized, the handle is consumed and +/// a [`WrittenCellProof`](CellWriter::WrittenCellProof) object is returned +/// in its stead. This is a type-level proof that the value was fully initialized +/// and is used in [`SerializeCql::serialize`](`value::SerializeCql::serialize`) +/// in order to enforce the implementor to fully initialize the provided handle +/// to CQL value. +/// +/// Dropping this type without calling any of its methods will result +/// in nothing being written. +pub trait CellWriter { + /// The type of the value builder, returned by the [`CellWriter::set_value`] + /// method. + type ValueBuilder: CellValueBuilder; + + /// An object that serves as a proof that the cell was fully initialized. + /// + /// This type is returned by [`set_null`](CellWriter::set_null), + /// [`set_unset`](CellWriter::set_unset), + /// [`set_value`](CellWriter::set_value) + /// and also [`CellValueBuilder::finish`] - generally speaking, after + /// the value is fully initialized and the `CellWriter` is destroyed. + /// + /// The purpose of this type is to enforce the contract of + /// [`SerializeCql::serialize`](value::SerializeCql::serialize): either + /// the method succeeds and returns a proof that it serialized itself + /// into the given value, or it fails and returns an error or panics. + /// The exact type of [`WrittenCellProof`](CellWriter::WrittenCellProof) + /// is not important as the value is not used at all - it's only + /// a compile-time check. + type WrittenCellProof; + + /// Sets this value to be null, consuming this object. + fn set_null(self) -> Self::WrittenCellProof; + + /// Sets this value to represent an unset value, consuming this object. + fn set_unset(self) -> Self::WrittenCellProof; + + /// Sets this value to a non-zero, non-unset value with given contents. + /// + /// Prefer this to [`into_value_builder`](CellWriter::into_value_builder) + /// if you have all of the contents of the value ready up front (e.g. for + /// fixed size types). + fn set_value(self, contents: &[u8]) -> Self::WrittenCellProof; + + /// Turns this writter into a [`CellValueBuilder`] which can be used + /// to gradually initialize the CQL value. + /// + /// This method should be used if you don't have all of the data + /// up front, e.g. when serializing compound types such as collections + /// or UDTs. + fn into_value_builder(self) -> Self::ValueBuilder; +} + +/// Allows appending bytes to a non-null, non-unset cell. +/// +/// This object needs to be dropped in order for the value to be correctly +/// serialized. Failing to drop this value will result in a payload that will +/// not be parsed by the database correctly, but otherwise should not cause +/// data to be misinterpreted. +pub trait CellValueBuilder { + type SubCellWriter<'a>: CellWriter + where + Self: 'a; + + type WrittenCellProof; + + /// Appends raw bytes to this cell. + fn append_bytes(&mut self, bytes: &[u8]); + + /// Appends a sub-value to the end of the current contents of the cell + /// and returns an object that allows to fill it in. + fn make_sub_writer(&mut self) -> Self::SubCellWriter<'_>; + + /// Finishes serializing the value. + fn finish(self) -> Self::WrittenCellProof; +} + +/// A row writer backed by a buffer (vec). +pub struct BufBackedRowWriter<'buf> { + // Buffer that this value should be serialized to. + buf: &'buf mut Vec, + + // Number of values written so far. + value_count: u16, +} + +impl<'buf> BufBackedRowWriter<'buf> { + /// Creates a new row writer based on an existing Vec. + /// + /// The newly created row writer will append data to the end of the vec. + #[inline] + pub fn new(buf: &'buf mut Vec) -> Self { + Self { + buf, + value_count: 0, + } + } + + /// Returns the number of values that were written so far. + #[inline] + pub fn value_count(&self) -> u16 { + self.value_count + } +} + +impl<'buf> RowWriter for BufBackedRowWriter<'buf> { + type CellWriter<'a> = BufBackedCellWriter<'a> where Self: 'a; + + #[inline] + fn make_cell_writer(&mut self) -> Self::CellWriter<'_> { + self.value_count = self + .value_count + .checked_add(1) + .expect("tried to serialize too many values for a query (more than u16::MAX)"); + BufBackedCellWriter::new(self.buf) + } +} + +/// A cell writer backed by a buffer (vec). +pub struct BufBackedCellWriter<'buf> { + // Buffer that this value should be serialized to. + buf: &'buf mut Vec, +} + +impl<'buf> BufBackedCellWriter<'buf> { + #[inline] + fn new(buf: &'buf mut Vec) -> Self { + BufBackedCellWriter { buf } + } +} + +impl<'buf> CellWriter for BufBackedCellWriter<'buf> { + type ValueBuilder = BufBackedCellValueBuilder<'buf>; + + type WrittenCellProof = (); + + #[inline] + fn set_null(self) { + self.buf.extend_from_slice(&(-1i32).to_be_bytes()); + } + + #[inline] + fn set_unset(self) { + self.buf.extend_from_slice(&(-2i32).to_be_bytes()); + } + + #[inline] + fn set_value(self, bytes: &[u8]) { + let value_len: i32 = bytes + .len() + .try_into() + .expect("value is too big to fit into a CQL [bytes] object (larger than i32::MAX)"); + self.buf.extend_from_slice(&value_len.to_be_bytes()); + self.buf.extend_from_slice(bytes); + } + + #[inline] + fn into_value_builder(self) -> Self::ValueBuilder { + BufBackedCellValueBuilder::new(self.buf) + } +} + +/// A cell value builder backed by a buffer (vec). +pub struct BufBackedCellValueBuilder<'buf> { + // Buffer that this value should be serialized to. + buf: &'buf mut Vec, + + // Starting position of the value in the buffer. + starting_pos: usize, +} + +impl<'buf> BufBackedCellValueBuilder<'buf> { + #[inline] + fn new(buf: &'buf mut Vec) -> Self { + // "Length" of a [bytes] frame can either be a non-negative i32, + // -1 (null) or -1 (not set). Push an invalid value here. It will be + // overwritten eventually either by set_null, set_unset or Drop. + // If the CellSerializer is not dropped as it should, this will trigger + // an error on the DB side and the serialized data + // won't be misinterpreted. + let starting_pos = buf.len(); + buf.extend_from_slice(&(-3i32).to_be_bytes()); + BufBackedCellValueBuilder { buf, starting_pos } + } +} + +impl<'buf> CellValueBuilder for BufBackedCellValueBuilder<'buf> { + type SubCellWriter<'a> = BufBackedCellWriter<'a> + where + Self: 'a; + + type WrittenCellProof = (); + + #[inline] + fn append_bytes(&mut self, bytes: &[u8]) { + self.buf.extend_from_slice(bytes); + } + + #[inline] + fn make_sub_writer(&mut self) -> Self::SubCellWriter<'_> { + BufBackedCellWriter::new(self.buf) + } + + #[inline] + fn finish(self) { + // TODO: Should this panic, or should we catch this error earlier? + // Vec will panic anyway if we overflow isize, so at least this + // behavior is consistent with what the stdlib does. + let value_len: i32 = (self.buf.len() - self.starting_pos - 4) + .try_into() + .expect("value is too big to fit into a CQL [bytes] object (larger than i32::MAX)"); + self.buf[self.starting_pos..self.starting_pos + 4] + .copy_from_slice(&value_len.to_be_bytes()); + } +} + +/// A writer that does not actually write anything, just counts the bytes. +/// +/// It can serve as a: +/// +/// - [`RowWriter`] +/// - [`CellWriter`] +/// - [`CellValueBuilder`] +pub struct CountingWriter<'buf> { + buf: &'buf mut usize, +} + +impl<'buf> CountingWriter<'buf> { + /// Creates a new writer which increments the counter under given reference + /// when bytes are appended. + #[inline] + fn new(buf: &'buf mut usize) -> Self { + CountingWriter { buf } + } +} + +impl<'buf> RowWriter for CountingWriter<'buf> { + type CellWriter<'a> = CountingWriter<'a> where Self: 'a; + + #[inline] + fn make_cell_writer(&mut self) -> Self::CellWriter<'_> { + CountingWriter::new(self.buf) + } +} + +impl<'buf> CellWriter for CountingWriter<'buf> { + type ValueBuilder = CountingWriter<'buf>; + + type WrittenCellProof = (); + + #[inline] + fn set_null(self) { + *self.buf += 4; + } + + #[inline] + fn set_unset(self) { + *self.buf += 4; + } + + #[inline] + fn set_value(self, contents: &[u8]) { + *self.buf += 4 + contents.len(); + } + + #[inline] + fn into_value_builder(self) -> Self::ValueBuilder { + *self.buf += 4; + CountingWriter::new(self.buf) + } +} + +impl<'buf> CellValueBuilder for CountingWriter<'buf> { + type SubCellWriter<'a> = CountingWriter<'a> + where + Self: 'a; + + type WrittenCellProof = (); + + #[inline] + fn append_bytes(&mut self, bytes: &[u8]) { + *self.buf += bytes.len(); + } + + #[inline] + fn make_sub_writer(&mut self) -> Self::SubCellWriter<'_> { + CountingWriter::new(self.buf) + } + + #[inline] + fn finish(self) -> Self::WrittenCellProof {} +} + +#[cfg(test)] +mod tests { + use crate::types::serialize::{BufBackedRowWriter, CellValueBuilder}; + + use super::{BufBackedCellWriter, CellWriter, CountingWriter, RowWriter}; + + // We want to perform the same computation for both buf backed writer + // and counting writer, but Rust does not support generic closures. + // This trait comes to the rescue. + trait CellSerializeCheck { + fn check(&self, writer: W); + } + + fn check_cell_serialize(c: C) -> Vec { + let mut data = Vec::new(); + let writer = BufBackedCellWriter::new(&mut data); + c.check(writer); + + let mut byte_count = 0usize; + let counting_writer = CountingWriter::new(&mut byte_count); + c.check(counting_writer); + + assert_eq!(data.len(), byte_count); + data + } + + #[test] + fn test_cell_writer() { + struct Check; + impl CellSerializeCheck for Check { + fn check(&self, writer: W) { + let mut sub_writer = writer.into_value_builder(); + sub_writer.make_sub_writer().set_null(); + sub_writer.make_sub_writer().set_value(&[1, 2, 3, 4]); + sub_writer.make_sub_writer().set_unset(); + sub_writer.finish(); + } + } + + let data = check_cell_serialize(Check); + assert_eq!( + data, + [ + 0, 0, 0, 16, // Length of inner data is 16 + 255, 255, 255, 255, // Null (encoded as -1) + 0, 0, 0, 4, 1, 2, 3, 4, // Four byte value + 255, 255, 255, 254, // Unset (encoded as -2) + ] + ); + } + + #[test] + fn test_poisoned_appender() { + struct Check; + impl CellSerializeCheck for Check { + fn check(&self, writer: W) { + let _ = writer.into_value_builder(); + } + } + + let data = check_cell_serialize(Check); + assert_eq!( + data, + [ + 255, 255, 255, 253, // Invalid value + ] + ); + } + + trait RowSerializeCheck { + fn check(&self, writer: &mut W); + } + + fn check_row_serialize(c: C) -> Vec { + let mut data = Vec::new(); + let mut writer = BufBackedRowWriter::new(&mut data); + c.check(&mut writer); + std::mem::drop(writer); + + let mut byte_count = 0usize; + let mut counting_writer = CountingWriter::new(&mut byte_count); + c.check(&mut counting_writer); + + assert_eq!(data.len(), byte_count); + data + } + + #[test] + fn test_row_writer() { + struct Check; + impl RowSerializeCheck for Check { + fn check(&self, writer: &mut W) { + writer.make_cell_writer().set_null(); + writer.make_cell_writer().set_value(&[1, 2, 3, 4]); + writer.make_cell_writer().set_unset(); + } + } + + let data = check_row_serialize(Check); + assert_eq!( + data, + [ + 255, 255, 255, 255, // Null (encoded as -1) + 0, 0, 0, 4, 1, 2, 3, 4, // Four byte value + 255, 255, 255, 254, // Unset (encoded as -2) + ] + ) + } +} diff --git a/scylla-cql/src/types/serialize/row.rs b/scylla-cql/src/types/serialize/row.rs index 2e9832412d..d1c55e1302 100644 --- a/scylla-cql/src/types/serialize/row.rs +++ b/scylla-cql/src/types/serialize/row.rs @@ -1,20 +1,25 @@ -use std::sync::Arc; +use std::{collections::HashMap, sync::Arc}; + +use thiserror::Error; -use crate::frame::response::result::ColumnSpec; use crate::frame::value::ValueList; +use crate::frame::{response::result::ColumnSpec, types::RawValue}; -use super::SerializationError; +use super::{CellWriter, RowWriter, SerializationError}; +/// Contains information needed to serialize a row. pub struct RowSerializationContext<'a> { columns: &'a [ColumnSpec], } impl<'a> RowSerializationContext<'a> { + /// Returns column/bind marker specifications for given query. #[inline] pub fn columns(&self) -> &'a [ColumnSpec] { self.columns } + /// Looks up and returns a column/bind marker by name. // TODO: change RowSerializationContext to make this faster #[inline] pub fn column_by_name(&self, target: &str) -> Option<&ColumnSpec> { @@ -23,11 +28,25 @@ impl<'a> RowSerializationContext<'a> { } pub trait SerializeRow { + /// Checks if it _might_ be possible to serialize the row according to the + /// information in the context. + /// + /// This function is intended to serve as an optimization in the future, + /// if we were ever to introduce prepared statements parametrized by types. + /// + /// Sometimes, a row cannot be fully type checked right away without knowing + /// the exact values of the columns (e.g. when deserializing to `CqlValue`), + /// but it's fine to do full type checking later in `serialize`. fn preliminary_type_check(ctx: &RowSerializationContext<'_>) -> Result<(), SerializationError>; - fn serialize( + + /// Serializes the row according to the information in the given context. + /// + /// The function may assume that `preliminary_type_check` was called, + /// though it must not do anything unsafe if this assumption does not hold. + fn serialize( &self, ctx: &RowSerializationContext<'_>, - out: &mut Vec, + writer: &mut W, ) -> Result<(), SerializationError>; } @@ -38,12 +57,136 @@ impl SerializeRow for T { Ok(()) } - fn serialize( + fn serialize( &self, - _ctx: &RowSerializationContext<'_>, - out: &mut Vec, + ctx: &RowSerializationContext<'_>, + writer: &mut W, ) -> Result<(), SerializationError> { - self.write_to_request(out) - .map_err(|err| Arc::new(err) as SerializationError) + serialize_legacy_row(self, ctx, writer) + } +} + +pub fn serialize_legacy_row( + r: &T, + ctx: &RowSerializationContext<'_>, + writer: &mut impl RowWriter, +) -> Result<(), SerializationError> { + let serialized = + ::serialized(r).map_err(|err| Arc::new(err) as SerializationError)?; + + let mut append_value = |value: RawValue| { + let cell_writer = writer.make_cell_writer(); + let _proof = match value { + RawValue::Null => cell_writer.set_null(), + RawValue::Unset => cell_writer.set_unset(), + RawValue::Value(v) => cell_writer.set_value(v), + }; + }; + + if !serialized.has_names() { + serialized.iter().for_each(append_value); + } else { + let values_by_name = serialized + .iter_name_value_pairs() + .map(|(k, v)| (k.unwrap(), v)) + .collect::>(); + + for col in ctx.columns() { + let val = values_by_name.get(col.name.as_str()).ok_or_else(|| { + Arc::new(ValueListToSerializeRowAdapterError::NoBindMarkerWithName { + name: col.name.clone(), + }) as SerializationError + })?; + append_value(*val); + } + } + + Ok(()) +} + +#[derive(Error, Debug)] +pub enum ValueListToSerializeRowAdapterError { + #[error("There is no bind marker with name {name}, but a value for it was provided")] + NoBindMarkerWithName { name: String }, +} + +#[cfg(test)] +mod tests { + use crate::frame::response::result::{ColumnSpec, ColumnType, TableSpec}; + use crate::frame::value::{MaybeUnset, SerializedValues, ValueList}; + use crate::types::serialize::BufBackedRowWriter; + + use super::{RowSerializationContext, SerializeRow}; + + fn col_spec(name: &str, typ: ColumnType) -> ColumnSpec { + ColumnSpec { + table_spec: TableSpec { + ks_name: "ks".to_string(), + table_name: "tbl".to_string(), + }, + name: name.to_string(), + typ, + } + } + + #[test] + fn test_legacy_fallback() { + let row = ( + 1i32, + "Ala ma kota", + None::, + MaybeUnset::Unset::, + ); + + let mut legacy_data = Vec::new(); + <_ as ValueList>::write_to_request(&row, &mut legacy_data).unwrap(); + + let mut new_data = Vec::new(); + let mut new_data_writer = BufBackedRowWriter::new(&mut new_data); + let ctx = RowSerializationContext { columns: &[] }; + <_ as SerializeRow>::serialize(&row, &ctx, &mut new_data_writer).unwrap(); + assert_eq!(new_data_writer.value_count(), 4); + std::mem::drop(new_data_writer); + + // Skip the value count + assert_eq!(&legacy_data[2..], new_data); + } + + #[test] + fn test_legacy_fallback_with_names() { + let sorted_row = ( + 1i32, + "Ala ma kota", + None::, + MaybeUnset::Unset::, + ); + + let mut sorted_row_data = Vec::new(); + <_ as ValueList>::write_to_request(&sorted_row, &mut sorted_row_data).unwrap(); + + let mut unsorted_row = SerializedValues::new(); + unsorted_row.add_named_value("a", &1i32).unwrap(); + unsorted_row.add_named_value("b", &"Ala ma kota").unwrap(); + unsorted_row + .add_named_value("d", &MaybeUnset::Unset::) + .unwrap(); + unsorted_row.add_named_value("c", &None::).unwrap(); + + let mut unsorted_row_data = Vec::new(); + let mut unsorted_row_data_writer = BufBackedRowWriter::new(&mut unsorted_row_data); + let ctx = RowSerializationContext { + columns: &[ + col_spec("a", ColumnType::Int), + col_spec("b", ColumnType::Text), + col_spec("c", ColumnType::BigInt), + col_spec("d", ColumnType::Ascii), + ], + }; + <_ as SerializeRow>::serialize(&unsorted_row, &ctx, &mut unsorted_row_data_writer).unwrap(); + assert_eq!(unsorted_row_data_writer.value_count(), 4); + std::mem::drop(unsorted_row_data_writer); + + // Skip the value count + assert_eq!(&sorted_row_data[2..], unsorted_row_data); } } diff --git a/scylla-cql/src/types/serialize/value.rs b/scylla-cql/src/types/serialize/value.rs index 43eb9ef738..25d605d13d 100644 --- a/scylla-cql/src/types/serialize/value.rs +++ b/scylla-cql/src/types/serialize/value.rs @@ -1,13 +1,32 @@ use std::sync::Arc; +use thiserror::Error; + use crate::frame::response::result::ColumnType; use crate::frame::value::Value; -use super::SerializationError; +use super::{CellWriter, SerializationError}; pub trait SerializeCql { + /// Given a CQL type, checks if it _might_ be possible to serialize to that type. + /// + /// This function is intended to serve as an optimization in the future, + /// if we were ever to introduce prepared statements parametrized by types. + /// + /// Some types cannot be type checked without knowing the exact value, + /// this is the case e.g. for `CqlValue`. It's also fine to do it later in + /// `serialize`. fn preliminary_type_check(typ: &ColumnType) -> Result<(), SerializationError>; - fn serialize(&self, typ: &ColumnType, buf: &mut Vec) -> Result<(), SerializationError>; + + /// Serializes the value to given CQL type. + /// + /// The function may assume that `preliminary_type_check` was called, + /// though it must not do anything unsafe if this assumption does not hold. + fn serialize( + &self, + typ: &ColumnType, + writer: W, + ) -> Result; } impl SerializeCql for T { @@ -15,8 +34,89 @@ impl SerializeCql for T { Ok(()) } - fn serialize(&self, _typ: &ColumnType, buf: &mut Vec) -> Result<(), SerializationError> { - self.serialize(buf) - .map_err(|err| Arc::new(err) as SerializationError) + fn serialize( + &self, + _typ: &ColumnType, + writer: W, + ) -> Result { + serialize_legacy_value(self, writer) + } +} + +pub fn serialize_legacy_value( + v: &T, + writer: W, +) -> Result { + // It's an inefficient and slightly tricky but correct implementation. + let mut buf = Vec::new(); + ::serialize(v, &mut buf).map_err(|err| Arc::new(err) as SerializationError)?; + + // Analyze the output. + // All this dance shows how unsafe our previous interface was... + if buf.len() < 4 { + return Err(Arc::new(ValueToSerializeCqlAdapterError::TooShort { + size: buf.len(), + })); + } + + let (len_bytes, contents) = buf.split_at(4); + let len = i32::from_be_bytes(len_bytes.try_into().unwrap()); + match len { + -2 => Ok(writer.set_unset()), + -1 => Ok(writer.set_null()), + len if len >= 0 => { + if contents.len() != len as usize { + Err(Arc::new( + ValueToSerializeCqlAdapterError::DeclaredVsActualSizeMismatch { + declared: len as usize, + actual: contents.len(), + }, + )) + } else { + Ok(writer.set_value(contents)) + } + } + _ => Err(Arc::new( + ValueToSerializeCqlAdapterError::InvalidDeclaredSize { size: len }, + )), + } +} + +#[derive(Error, Debug)] +pub enum ValueToSerializeCqlAdapterError { + #[error("Output produced by the Value trait is too short to be considered a value: {size} < 4 minimum bytes")] + TooShort { size: usize }, + + #[error("Mismatch between the declared value size vs. actual size: {declared} != {actual}")] + DeclaredVsActualSizeMismatch { declared: usize, actual: usize }, + + #[error("Invalid declared value size: {size}")] + InvalidDeclaredSize { size: i32 }, +} + +#[cfg(test)] +mod tests { + use crate::frame::response::result::ColumnType; + use crate::frame::value::{MaybeUnset, Value}; + use crate::types::serialize::BufBackedCellWriter; + + use super::SerializeCql; + + fn check_compat(v: V) { + let mut legacy_data = Vec::new(); + ::serialize(&v, &mut legacy_data).unwrap(); + + let mut new_data = Vec::new(); + let new_data_writer = BufBackedCellWriter::new(&mut new_data); + ::serialize(&v, &ColumnType::Int, new_data_writer).unwrap(); + + assert_eq!(legacy_data, new_data); + } + + #[test] + fn test_legacy_fallback() { + check_compat(123i32); + check_compat(None::); + check_compat(MaybeUnset::Unset::); } }