Skip to content

Commit

Permalink
Merge pull request scylladb#1065 from wprzytula/new-deserialization-a…
Browse files Browse the repository at this point in the history
…pi-preparations

New deserialization API - preparations
  • Loading branch information
wprzytula authored Aug 29, 2024
2 parents 47e9864 + d47aa81 commit 31f512c
Show file tree
Hide file tree
Showing 15 changed files with 238 additions and 111 deletions.
13 changes: 13 additions & 0 deletions scylla-cql/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use crate::frame::frame_errors::{CqlResponseParseError, FrameError, ParseError};
use crate::frame::protocol_features::ProtocolFeatures;
use crate::frame::value::SerializeValuesError;
use crate::types::deserialize::{DeserializationError, TypeCheckError};
use crate::types::serialize::SerializationError;
use crate::Consistency;
use bytes::Bytes;
Expand Down Expand Up @@ -461,6 +462,18 @@ impl From<SerializationError> for QueryError {
}
}

impl From<DeserializationError> for QueryError {
fn from(value: DeserializationError) -> Self {
Self::InvalidMessage(value.to_string())
}
}

impl From<TypeCheckError> for QueryError {
fn from(value: TypeCheckError) -> Self {
Self::InvalidMessage(value.to_string())
}
}

impl From<ParseError> for QueryError {
fn from(parse_error: ParseError) -> QueryError {
QueryError::InvalidMessage(format!("Error parsing message: {}", parse_error))
Expand Down
6 changes: 5 additions & 1 deletion scylla-cql/src/frame/frame_errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;
use super::TryFromPrimitiveError;
use crate::cql_to_rust::CqlTypeError;
use crate::frame::value::SerializeValuesError;
use crate::types::deserialize::DeserializationError;
use crate::types::deserialize::{DeserializationError, TypeCheckError};
use crate::types::serialize::SerializationError;
use thiserror::Error;

Expand Down Expand Up @@ -46,6 +46,8 @@ pub enum ParseError {
#[error(transparent)]
DeserializationError(#[from] DeserializationError),
#[error(transparent)]
DeserializationTypeCheckError(#[from] TypeCheckError),
#[error(transparent)]
IoError(#[from] std::io::Error),
#[error(transparent)]
SerializeValuesError(#[from] SerializeValuesError),
Expand Down Expand Up @@ -216,6 +218,8 @@ pub enum PreparedParseError {
ResultMetadataParseError(ResultMetadataParseError),
#[error("Invalid prepared metadata: {0}")]
PreparedMetadataParseError(ResultMetadataParseError),
#[error("Non-zero paging state in result metadata: {0:?}")]
NonZeroPagingState(Arc<[u8]>),
}

/// An error type returned when deserialization
Expand Down
8 changes: 8 additions & 0 deletions scylla-cql/src/frame/request/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,14 @@ impl PagingStateResponse {
Self::NoMorePages => ControlFlow::Break(()),
}
}

/// Swaps the paging state response with PagingStateResponse::NoMorePages.
///
/// Only for use in driver's inner code, as an optimisation.
#[doc(hidden)]
pub fn take(&mut self) -> Self {
std::mem::replace(self, Self::NoMorePages)
}
}

/// The state of a paged query, i.e. where to resume fetching result rows
Expand Down
11 changes: 8 additions & 3 deletions scylla-cql/src/frame/response/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ pub mod event;
pub mod result;
pub mod supported;

use std::sync::Arc;

pub use error::Error;
pub use supported::Supported;

Expand Down Expand Up @@ -65,17 +67,20 @@ impl Response {
pub fn deserialize(
features: &ProtocolFeatures,
opcode: ResponseOpcode,
buf: &mut &[u8],
cached_metadata: Option<&ResultMetadata>,
buf_bytes: bytes::Bytes,
cached_metadata: Option<&Arc<ResultMetadata>>,
) -> Result<Response, CqlResponseParseError> {
let buf = &mut &*buf_bytes;
let response = match opcode {
ResponseOpcode::Error => Response::Error(Error::deserialize(features, buf)?),
ResponseOpcode::Ready => Response::Ready,
ResponseOpcode::Authenticate => {
Response::Authenticate(authenticate::Authenticate::deserialize(buf)?)
}
ResponseOpcode::Supported => Response::Supported(Supported::deserialize(buf)?),
ResponseOpcode::Result => Response::Result(result::deserialize(buf, cached_metadata)?),
ResponseOpcode::Result => {
Response::Result(result::deserialize(buf_bytes, cached_metadata)?)
}
ResponseOpcode::Event => Response::Event(event::Event::deserialize(buf)?),
ResponseOpcode::AuthChallenge => {
Response::AuthChallenge(authenticate::AuthChallenge::deserialize(buf)?)
Expand Down
81 changes: 48 additions & 33 deletions scylla-cql/src/frame/response/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use crate::types::deserialize::value::{
use crate::types::deserialize::{DeserializationError, FrameSlice};
use bytes::{Buf, Bytes};
use std::borrow::Cow;
use std::sync::Arc;
use std::{net::IpAddr, result::Result as StdResult, str};
use uuid::Uuid;

Expand Down Expand Up @@ -431,7 +432,6 @@ pub struct ColumnSpec {
#[derive(Debug, Clone)]
pub struct ResultMetadata {
col_count: usize,
pub paging_state: PagingStateResponse,
pub col_specs: Vec<ColumnSpec>,
}

Expand All @@ -440,10 +440,18 @@ impl ResultMetadata {
pub fn mock_empty() -> Self {
Self {
col_count: 0,
paging_state: PagingStateResponse::NoMorePages,
col_specs: Vec::new(),
}
}

#[inline]
#[doc(hidden)]
pub fn new_for_test(col_count: usize, col_specs: Vec<ColumnSpec>) -> Self {
Self {
col_count,
col_specs,
}
}
}

#[derive(Debug, Copy, Clone)]
Expand Down Expand Up @@ -478,7 +486,8 @@ impl Row {

#[derive(Debug)]
pub struct Rows {
pub metadata: ResultMetadata,
pub metadata: Arc<ResultMetadata>,
pub paging_state_response: PagingStateResponse,
pub rows_count: usize,
pub rows: Vec<Row>,
/// Original size of the serialized rows.
Expand Down Expand Up @@ -620,7 +629,9 @@ fn deser_col_specs(
Ok(col_specs)
}

fn deser_result_metadata(buf: &mut &[u8]) -> StdResult<ResultMetadata, ResultMetadataParseError> {
fn deser_result_metadata(
buf: &mut &[u8],
) -> StdResult<(ResultMetadata, PagingStateResponse), ResultMetadataParseError> {
let flags = types::read_int(buf)
.map_err(|err| ResultMetadataParseError::FlagsParseError(err.into()))?;
let global_tables_spec = flags & 0x0001 != 0;
Expand All @@ -635,27 +646,23 @@ fn deser_result_metadata(buf: &mut &[u8]) -> StdResult<ResultMetadata, ResultMet
.transpose()?;
let paging_state = PagingStateResponse::new_from_raw_bytes(raw_paging_state);

if no_metadata {
return Ok(ResultMetadata {
col_count,
paging_state,
col_specs: vec![],
});
}

let global_table_spec = if global_tables_spec {
Some(deser_table_spec(buf)?)
let col_specs = if no_metadata {
vec![]
} else {
None
};
let global_table_spec = if global_tables_spec {
Some(deser_table_spec(buf)?)
} else {
None
};

let col_specs = deser_col_specs(buf, &global_table_spec, col_count)?;
deser_col_specs(buf, &global_table_spec, col_count)?
};

Ok(ResultMetadata {
let metadata = ResultMetadata {
col_count,
paging_state,
col_specs,
})
};
Ok((metadata, paging_state))
}

fn deser_prepared_metadata(
Expand Down Expand Up @@ -859,17 +866,14 @@ pub fn deser_cql_value(
}

fn deser_rows(
buf: &mut &[u8],
cached_metadata: Option<&ResultMetadata>,
buf_bytes: Bytes,
cached_metadata: Option<&Arc<ResultMetadata>>,
) -> StdResult<Rows, RowsParseError> {
let server_metadata = deser_result_metadata(buf)?;
let buf = &mut &*buf_bytes;
let (server_metadata, paging_state_response) = deser_result_metadata(buf)?;

let metadata = match cached_metadata {
Some(cached) => ResultMetadata {
col_count: cached.col_count,
paging_state: server_metadata.paging_state,
col_specs: cached.col_specs.clone(),
},
Some(cached) => Arc::clone(cached),
None => {
// No cached_metadata provided. Server is supposed to provide the result metadata.
if server_metadata.col_count != server_metadata.col_specs.len() {
Expand All @@ -878,7 +882,7 @@ fn deser_rows(
col_specs_count: server_metadata.col_specs.len(),
});
}
server_metadata
Arc::new(server_metadata)
}
};

Expand All @@ -899,6 +903,7 @@ fn deser_rows(

Ok(Rows {
metadata,
paging_state_response,
rows_count,
rows,
serialized_size: original_size - buf.len(),
Expand All @@ -919,8 +924,17 @@ fn deser_prepared(buf: &mut &[u8]) -> StdResult<Prepared, PreparedParseError> {
buf.advance(id_len);
let prepared_metadata =
deser_prepared_metadata(buf).map_err(PreparedParseError::PreparedMetadataParseError)?;
let result_metadata =
let (result_metadata, paging_state_response) =
deser_result_metadata(buf).map_err(PreparedParseError::ResultMetadataParseError)?;
if let PagingStateResponse::HasMorePages { state } = paging_state_response {
return Err(PreparedParseError::NonZeroPagingState(
state
.as_bytes_slice()
.cloned()
.unwrap_or_else(|| Arc::from([])),
));
}

Ok(Prepared {
id,
prepared_metadata,
Expand All @@ -935,16 +949,17 @@ fn deser_schema_change(buf: &mut &[u8]) -> StdResult<SchemaChange, SchemaChangeE
}

pub fn deserialize(
buf: &mut &[u8],
cached_metadata: Option<&ResultMetadata>,
buf_bytes: Bytes,
cached_metadata: Option<&Arc<ResultMetadata>>,
) -> StdResult<Result, CqlResultParseError> {
let buf = &mut &*buf_bytes;
use self::Result::*;
Ok(
match types::read_int(buf)
.map_err(|err| CqlResultParseError::ResultIdParseError(err.into()))?
{
0x0001 => Void,
0x0002 => Rows(deser_rows(buf, cached_metadata)?),
0x0002 => Rows(deser_rows(buf_bytes.slice_ref(buf), cached_metadata)?),
0x0003 => SetKeyspace(deser_set_keyspace(buf)?),
0x0004 => Prepared(deser_prepared(buf)?),
0x0005 => SchemaChange(deser_schema_change(buf)?),
Expand Down
2 changes: 2 additions & 0 deletions scylla-cql/src/types/deserialize/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use super::{DeserializationError, FrameSlice, TypeCheckError};
use std::marker::PhantomData;

/// Iterates over the whole result, returning rows.
#[derive(Debug)]
pub struct RowIterator<'frame> {
specs: &'frame [ColumnSpec],
remaining: usize,
Expand Down Expand Up @@ -76,6 +77,7 @@ impl<'frame> Iterator for RowIterator<'frame> {

/// A typed version of [RowIterator] which deserializes the rows before
/// returning them.
#[derive(Debug)]
pub struct TypedRowIterator<'frame, R> {
inner: RowIterator<'frame>,
_phantom: PhantomData<R>,
Expand Down
Loading

0 comments on commit 31f512c

Please sign in to comment.