Skip to content

Commit

Permalink
review feedback: Change short to u16 as well as modify tests to be simpler
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelorji committed Oct 9, 2023
1 parent 772c46f commit fdee6d6
Show file tree
Hide file tree
Showing 9 changed files with 49 additions and 87 deletions.
4 changes: 2 additions & 2 deletions scylla-cql/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,8 +349,8 @@ pub enum BadQuery {
BadKeyspaceName(#[from] BadKeyspaceName),

/// Too many queries in the batch statement
#[error("Number of Queries in Batch Statement has exceeded the max value of 65,536")]
TooManyQueriesInBatchStatement,
#[error("Number of Queries in Batch Statement supplied is {0} which has exceeded the max value of 65,535")]
TooManyQueriesInBatchStatement(usize),

/// Other reasons of bad query
#[error("{0}")]
Expand Down
2 changes: 1 addition & 1 deletion scylla-cql/src/frame/frame_errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub enum ParseError {
#[error(transparent)]
IoError(#[from] std::io::Error),
#[error("type not yet implemented, id: {0}")]
TypeNotImplemented(i16),
TypeNotImplemented(u16),
#[error(transparent)]
SerializeValuesError(#[from] SerializeValuesError),
#[error(transparent)]
Expand Down
4 changes: 2 additions & 2 deletions scylla-cql/src/frame/request/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ where
buf.put_u8(self.batch_type as u8);

// Serializing queries
types::write_u16(self.statements.len().try_into()?, buf);
types::write_short(self.statements.len().try_into()?, buf);

let counts_mismatch_err = |n_values: usize, n_statements: usize| {
ParseError::BadDataToSerialize(format!(
Expand Down Expand Up @@ -190,7 +190,7 @@ impl<'b> DeserializableRequest for Batch<'b, BatchStatement<'b>, Vec<SerializedV
fn deserialize(buf: &mut &[u8]) -> Result<Self, ParseError> {
let batch_type = buf.get_u8().try_into()?;

let statements_count: usize = types::read_u16(buf)?.try_into()?;
let statements_count: usize = types::read_short(buf)?.try_into()?;
let statements_with_values = (0..statements_count)
.map(|_| {
let batch_statement = BatchStatement::deserialize(buf)?;
Expand Down
29 changes: 8 additions & 21 deletions scylla-cql/src/frame/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use uuid::Uuid;
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, TryFromPrimitive)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "SCREAMING_SNAKE_CASE"))]
#[repr(i16)]
#[repr(u16)]
pub enum Consistency {
Any = 0x0000,
One = 0x0001,
Expand Down Expand Up @@ -175,8 +175,8 @@ fn type_long() {
}
}

pub fn read_short(buf: &mut &[u8]) -> Result<i16, ParseError> {
let v = buf.read_i16::<BigEndian>()?;
/// Reads a CQL `[short]` — a big-endian unsigned 16-bit integer — from the
/// front of `buf`, advancing the slice past the two consumed bytes.
///
/// Returns a `ParseError` (converted from `std::io::Error`) if fewer than
/// two bytes remain in the buffer.
pub fn read_short(buf: &mut &[u8]) -> Result<u16, ParseError> {
    Ok(buf.read_u16::<BigEndian>()?)
}

Expand All @@ -185,11 +185,7 @@ pub fn read_u16(buf: &mut &[u8]) -> Result<u16, ParseError> {
Ok(v)
}

pub fn write_short(v: i16, buf: &mut impl BufMut) {
buf.put_i16(v);
}

pub fn write_u16(v: u16, buf: &mut impl BufMut) {
/// Writes `v` as a CQL `[short]` — a big-endian unsigned 16-bit integer —
/// into `buf`.
pub fn write_short(v: u16, buf: &mut impl BufMut) {
    buf.put_u16(v);
}

Expand All @@ -200,30 +196,21 @@ pub(crate) fn read_short_length(buf: &mut &[u8]) -> Result<usize, ParseError> {
}

fn write_short_length(v: usize, buf: &mut impl BufMut) -> Result<(), ParseError> {
let v: i16 = v.try_into()?;
let v: u16 = v.try_into()?;
write_short(v, buf);
Ok(())
}

#[test]
fn type_short() {
let vals = [i16::MIN, -1, 0, 1, i16::MAX];
let vals: [u16; 3] = [0, 1, u16::MAX];
for val in vals.iter() {
let mut buf = Vec::new();
write_short(*val, &mut buf);
assert_eq!(read_short(&mut &buf[..]).unwrap(), *val);
}
}

#[test]
fn type_u16() {
let vals = [0, 1, u16::MAX];
for val in vals.iter() {
let mut buf = Vec::new();
write_u16(*val, &mut buf);
assert_eq!(read_u16(&mut &buf[..]).unwrap(), *val);
}
}
// https://github.com/apache/cassandra/blob/trunk/doc/native_protocol_v4.spec#L208
pub fn read_bytes_opt<'a>(buf: &mut &'a [u8]) -> Result<Option<&'a [u8]>, ParseError> {
let len = read_int(buf)?;
Expand Down Expand Up @@ -488,11 +475,11 @@ pub fn read_consistency(buf: &mut &[u8]) -> Result<Consistency, ParseError> {
}

pub fn write_consistency(c: Consistency, buf: &mut impl BufMut) {
write_short(c as i16, buf);
write_short(c as u16, buf);
}

pub fn write_serial_consistency(c: SerialConsistency, buf: &mut impl BufMut) {
write_short(c as i16, buf);
write_short(c as u16, buf);
}

#[test]
Expand Down
10 changes: 5 additions & 5 deletions scylla-cql/src/frame/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pub struct Time(pub Duration);
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct SerializedValues {
serialized_values: Vec<u8>,
values_num: i16,
values_num: u16,
contains_names: bool,
}

Expand Down Expand Up @@ -134,7 +134,7 @@ impl SerializedValues {
if self.contains_names {
return Err(SerializeValuesError::MixingNamedAndNotNamedValues);
}
if self.values_num == i16::MAX {
if self.values_num == u16::MAX {
return Err(SerializeValuesError::TooManyValues);
}

Expand All @@ -158,7 +158,7 @@ impl SerializedValues {
return Err(SerializeValuesError::MixingNamedAndNotNamedValues);
}
self.contains_names = true;
if self.values_num == i16::MAX {
if self.values_num == u16::MAX {
return Err(SerializeValuesError::TooManyValues);
}

Expand All @@ -184,15 +184,15 @@ impl SerializedValues {
}

pub fn write_to_request(&self, buf: &mut impl BufMut) {
buf.put_i16(self.values_num);
buf.put_u16(self.values_num);
buf.put(&self.serialized_values[..]);
}

pub fn is_empty(&self) -> bool {
self.values_num == 0
}

pub fn len(&self) -> i16 {
pub fn len(&self) -> u16 {
self.values_num
}

Expand Down
1 change: 0 additions & 1 deletion scylla/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ criterion = "0.4" # Note: v0.5 needs at least rust 1.70.0
tracing-subscriber = { version = "0.3.14", features = ["env-filter"] }
assert_matches = "1.5.0"
rand_chacha = "0.3.1"
bcs = "0.1.5"

[[bench]]
name = "benchmark"
Expand Down
2 changes: 1 addition & 1 deletion scylla/src/statement/prepared_statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ impl PreparedStatement {
#[derive(Clone, Debug, Error, PartialEq, Eq, PartialOrd, Ord)]
pub enum PartitionKeyExtractionError {
#[error("No value with given pk_index! pk_index: {0}, values.len(): {1}")]
NoPkIndexValue(u16, i16),
NoPkIndexValue(u16, u16),
}

#[derive(Clone, Debug, Error, PartialEq, Eq, PartialOrd, Ord)]
Expand Down
79 changes: 27 additions & 52 deletions scylla/src/transport/large_batch_statements_test.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use bcs::serialize_into;
use assert_matches::assert_matches;

use scylla_cql::errors::{BadQuery, QueryError};
use scylla_cql::Consistency;

use crate::batch::BatchType;
use crate::query::Query;
Expand All @@ -16,48 +18,55 @@ async fn test_large_batch_statements() {
let ks = unique_keyspace_name();
session = create_test_session(session, &ks).await;

// table should be initially empty
let query_result = simple_fetch_all(&session, &ks).await;
assert_eq!(query_result.rows.unwrap().len(), 0);

// Add batch
let max_number_of_queries = u16::MAX as usize;
write_batch(&session, max_number_of_queries).await;
let _ = write_batch(&session, max_number_of_queries, &ks).await;

let key_prefix = vec![0];
let keys = find_keys_by_prefix(&session, key_prefix.clone()).await;
assert_eq!(keys.len(), max_number_of_queries);
// Query batch
let query_result = simple_fetch_all(&session, &ks).await;
assert_eq!(query_result.rows.unwrap().len(), max_number_of_queries);

// Now try with too many queries
let too_many_queries = u16::MAX as usize + 1;

let err = write_batch(&session, too_many_queries).await;

assert!(err.is_err());
let batch_insert_result = write_batch(&session, too_many_queries, &ks).await;
assert_matches!(
batch_insert_result.unwrap_err(),
QueryError::BadQuery(BadQuery::TooManyQueriesInBatchStatement(too_many_queries))
)
}

async fn create_test_session(session: Session, ks: &String) -> Session {
session
.query(
format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{ 'class' : 'SimpleStrategy', 'replication_factor' : 1 }}",ks),
format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{ 'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1 }}",ks),
&[],
)
.await.unwrap();
session
.query("DROP TABLE IF EXISTS kv.pairs;", &[])
.query(format!("DROP TABLE IF EXISTS {}.pairs;", ks), &[])
.await
.unwrap();
session
.query(
"CREATE TABLE IF NOT EXISTS kv.pairs (dummy int, k blob, v blob, primary key (dummy, k))",
format!("CREATE TABLE IF NOT EXISTS {}.pairs (dummy int, k blob, v blob, primary key (dummy, k))", ks),
&[],
)
.await.unwrap();
session
}

async fn write_batch(session: &Session, n: usize) -> Result<QueryResult, QueryError> {
async fn write_batch(session: &Session, n: usize, ks: &String) -> Result<QueryResult, QueryError> {
let mut batch_query = Batch::new(BatchType::Logged);
let mut batch_values = Vec::new();
for i in 0..n {
let mut key = vec![0];
serialize_into(&mut key, &(i as usize)).unwrap();
key.extend(i.to_be_bytes().as_slice());
let value = key.clone();
let query = "INSERT INTO kv.pairs (dummy, k, v) VALUES (0, ?, ?)";
let query = format!("INSERT INTO {}.pairs (dummy, k, v) VALUES (0, ?, ?)", ks);
let values = vec![key, value];
batch_values.push(values);
let query = Query::new(query);
Expand All @@ -66,41 +75,7 @@ async fn write_batch(session: &Session, n: usize) -> Result<QueryResult, QueryEr
session.batch(&batch_query, batch_values).await
}

async fn find_keys_by_prefix(session: &Session, key_prefix: Vec<u8>) -> Vec<Vec<u8>> {
let len = key_prefix.len();
let rows = match get_upper_bound_option(&key_prefix) {
None => {
let values = (key_prefix,);
let query = "SELECT k FROM kv.pairs WHERE dummy = 0 AND k >= ? ALLOW FILTERING";
session.query(query, values).await.unwrap()
}
Some(upper_bound) => {
let values = (key_prefix, upper_bound);
let query =
"SELECT k FROM kv.pairs WHERE dummy = 0 AND k >= ? AND k < ? ALLOW FILTERING";
session.query(query, values).await.unwrap()
}
};
let mut keys = Vec::new();
if let Some(rows) = rows.rows {
for row in rows.into_typed::<(Vec<u8>,)>() {
let key = row.unwrap();
let short_key = key.0[len..].to_vec();
keys.push(short_key);
}
}
keys
}

fn get_upper_bound_option(key_prefix: &[u8]) -> Option<Vec<u8>> {
let len = key_prefix.len();
for i in (0..len).rev() {
let val = key_prefix[i];
if val < u8::MAX {
let mut upper_bound = key_prefix[0..i + 1].to_vec();
upper_bound[i] += 1;
return Some(upper_bound);
}
}
None
/// Fetches every row from the `pairs` table in keyspace `ks`.
///
/// Test helper: panics (via `unwrap`) if the SELECT fails, since a query
/// error here means the test environment itself is broken.
async fn simple_fetch_all(session: &Session, ks: &String) -> QueryResult {
    session
        .query(format!("SELECT * FROM {}.pairs", ks), &[])
        .await
        .unwrap()
}
5 changes: 3 additions & 2 deletions scylla/src/transport/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1145,9 +1145,10 @@ impl Session {
// If users batch statements by shard, they will be rewarded with full shard awareness

// check to ensure that we don't send a batch statement with more than u16::MAX queries
if batch.statements.len() > u16::MAX as usize {
let batch_statements_length = batch.statements.len();
if batch_statements_length > u16::MAX as usize {
return Err(QueryError::BadQuery(
BadQuery::TooManyQueriesInBatchStatement,
BadQuery::TooManyQueriesInBatchStatement(batch_statements_length),
));
}
// Extract first serialized_value
Expand Down

0 comments on commit fdee6d6

Please sign in to comment.