Skip to content

Commit

Permalink
add a guard when doing batch statements to prevent making calls to th…
Browse files Browse the repository at this point in the history
…e server when the number of batch queries is greater than u16::MAX, as well as adding some tests
  • Loading branch information
samuelorji committed Oct 6, 2023
1 parent 4c013e8 commit 772c46f
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 1 deletion.
4 changes: 4 additions & 0 deletions scylla-cql/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,10 @@ pub enum BadQuery {
#[error("Passed invalid keyspace name to use: {0}")]
BadKeyspaceName(#[from] BadKeyspaceName),

/// Too many queries in the batch statement
#[error("Number of Queries in Batch Statement has exceeded the max value of 65,536")]
TooManyQueriesInBatchStatement,

/// Other reasons of bad query
#[error("{0}")]
Other(String),
Expand Down
2 changes: 1 addition & 1 deletion scylla-cql/src/frame/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use byteorder::{BigEndian, ReadBytesExt};
use bytes::{Buf, BufMut};
use num_enum::TryFromPrimitive;
use std::collections::HashMap;
use std::convert::{Infallible, TryFrom};
use std::convert::TryInto;
use std::convert::{Infallible, TryFrom};
use std::net::IpAddr;
use std::net::SocketAddr;
use std::str;
Expand Down
1 change: 1 addition & 0 deletions scylla/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ criterion = "0.4" # Note: v0.5 needs at least rust 1.70.0
tracing-subscriber = { version = "0.3.14", features = ["env-filter"] }
assert_matches = "1.5.0"
rand_chacha = "0.3.1"
bcs = "0.1.5"

[[bench]]
name = "benchmark"
Expand Down
106 changes: 106 additions & 0 deletions scylla/src/transport/large_batch_statements_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
use bcs::serialize_into;
use scylla_cql::errors::{BadQuery, QueryError};

use crate::batch::BatchType;
use crate::query::Query;
use crate::{
batch::Batch,
prepared_statement::PreparedStatement,
test_utils::{create_new_session_builder, unique_keyspace_name},
IntoTypedRows, QueryResult, Session,
};

#[tokio::test]
async fn test_large_batch_statements() {
let mut session = create_new_session_builder().build().await.unwrap();
let ks = unique_keyspace_name();
session = create_test_session(session, &ks).await;

let max_number_of_queries = u16::MAX as usize;
write_batch(&session, max_number_of_queries).await;

let key_prefix = vec![0];
let keys = find_keys_by_prefix(&session, key_prefix.clone()).await;
assert_eq!(keys.len(), max_number_of_queries);

let too_many_queries = u16::MAX as usize + 1;

let err = write_batch(&session, too_many_queries).await;

assert!(err.is_err());
}

async fn create_test_session(session: Session, ks: &String) -> Session {
session
.query(
format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{ 'class' : 'SimpleStrategy', 'replication_factor' : 1 }}",ks),
&[],
)
.await.unwrap();
session
.query("DROP TABLE IF EXISTS kv.pairs;", &[])
.await
.unwrap();
session
.query(
"CREATE TABLE IF NOT EXISTS kv.pairs (dummy int, k blob, v blob, primary key (dummy, k))",
&[],
)
.await.unwrap();
session
}

async fn write_batch(session: &Session, n: usize) -> Result<QueryResult, QueryError> {
let mut batch_query = Batch::new(BatchType::Logged);
let mut batch_values = Vec::new();
for i in 0..n {
let mut key = vec![0];
serialize_into(&mut key, &(i as usize)).unwrap();
let value = key.clone();
let query = "INSERT INTO kv.pairs (dummy, k, v) VALUES (0, ?, ?)";
let values = vec![key, value];
batch_values.push(values);
let query = Query::new(query);
batch_query.append_statement(query);
}
session.batch(&batch_query, batch_values).await
}

async fn find_keys_by_prefix(session: &Session, key_prefix: Vec<u8>) -> Vec<Vec<u8>> {
let len = key_prefix.len();
let rows = match get_upper_bound_option(&key_prefix) {
None => {
let values = (key_prefix,);
let query = "SELECT k FROM kv.pairs WHERE dummy = 0 AND k >= ? ALLOW FILTERING";
session.query(query, values).await.unwrap()
}
Some(upper_bound) => {
let values = (key_prefix, upper_bound);
let query =
"SELECT k FROM kv.pairs WHERE dummy = 0 AND k >= ? AND k < ? ALLOW FILTERING";
session.query(query, values).await.unwrap()
}
};
let mut keys = Vec::new();
if let Some(rows) = rows.rows {
for row in rows.into_typed::<(Vec<u8>,)>() {
let key = row.unwrap();
let short_key = key.0[len..].to_vec();
keys.push(short_key);
}
}
keys
}

fn get_upper_bound_option(key_prefix: &[u8]) -> Option<Vec<u8>> {
let len = key_prefix.len();
for i in (0..len).rev() {
let val = key_prefix[i];
if val < u8::MAX {
let mut upper_bound = key_prefix[0..i + 1].to_vec();
upper_bound[i] += 1;
return Some(upper_bound);
}
}
None
}
2 changes: 2 additions & 0 deletions scylla/src/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ mod silent_prepare_batch_test;
mod cql_types_test;
#[cfg(test)]
mod cql_value_test;
#[cfg(test)]
mod large_batch_statements_test;

pub use cluster::ClusterData;
pub use node::{KnownNode, Node, NodeAddr, NodeRef};
7 changes: 7 additions & 0 deletions scylla/src/transport/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ pub use crate::transport::connection_pool::PoolSize;
use crate::authentication::AuthenticatorProvider;
#[cfg(feature = "ssl")]
use openssl::ssl::SslContext;
use scylla_cql::errors::BadQuery;

/// Translates IP addresses received from ScyllaDB nodes into locally reachable addresses.
///
Expand Down Expand Up @@ -1143,6 +1144,12 @@ impl Session {
// Shard-awareness behavior for batch will be to pick shard based on first batch statement's shard
// If users batch statements by shard, they will be rewarded with full shard awareness

// check to ensure that we don't send a batch statement with more than u16::MAX queries
if batch.statements.len() > u16::MAX as usize {
return Err(QueryError::BadQuery(
BadQuery::TooManyQueriesInBatchStatement,
));
}
// Extract first serialized_value
let first_serialized_value = values.batch_values_iter().next_serialized().transpose()?;
let first_serialized_value = first_serialized_value.as_deref();
Expand Down

0 comments on commit 772c46f

Please sign in to comment.