diff --git a/scylla/src/transport/session_test.rs b/scylla/src/transport/session_test.rs index 318848c57..f318aa601 100644 --- a/scylla/src/transport/session_test.rs +++ b/scylla/src/transport/session_test.rs @@ -1,4 +1,4 @@ -use crate::batch::{Batch, BatchStatement}; +use crate::batch::{Batch, BatchStatement, BatchType}; use crate::deserialize::DeserializeOwnedValue; use crate::prepared_statement::PreparedStatement; use crate::query::Query; @@ -33,7 +33,7 @@ use scylla_cql::types::serialize::value::SerializeValue; use std::collections::{BTreeMap, HashMap}; use std::collections::{BTreeSet, HashSet}; use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use tokio::net::TcpListener; use uuid::Uuid; @@ -1328,6 +1328,82 @@ async fn test_timestamp() { assert_eq!(results, expected_results); } +#[tokio::test] +async fn test_timestamp_generator() { + use crate::transport::timestamp_generator::TimestampGenerator; + use std::time::{SystemTime, UNIX_EPOCH}; + setup_tracing(); + struct LocalTimestampGenerator { + generated_timestamps: Arc>>, + } + impl TimestampGenerator for LocalTimestampGenerator { + fn next_timestamp(&self) -> i64 { + let timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() as i64; + self.generated_timestamps.lock().unwrap().insert(timestamp); + timestamp + } + } + let timestamps = Arc::new(Mutex::new(HashSet::new())); + let generator = LocalTimestampGenerator { + generated_timestamps: timestamps.clone(), + }; + + let session = create_new_session_builder() + .timestamp_generator(Arc::new(generator)) + .build() + .await + .unwrap(); + let ks = unique_keyspace_name(); + session.ddl(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 1}}", ks)).await.unwrap(); + session + .ddl(format!( + "CREATE TABLE IF NOT EXISTS {}.t_generator (a int primary key, b int)", + ks + )) + .await + .unwrap(); + let prepared = session + .prepare(format!( + "INSERT INTO {}.t_generator (a, b) VALUES (1, 1)", + ks + )) + .await + .unwrap(); + session.execute_unpaged(&prepared, []).await.unwrap(); + let unprepared = Query::new(format!( + "INSERT INTO {}.t_generator (a, b) VALUES (2, 2)", + ks + )); + session.query_unpaged(unprepared, []).await.unwrap(); + let mut batch = Batch::new(BatchType::Unlogged); + let stmt = session + .prepare(format!( + "INSERT INTO {}.t_generator (a, b) VALUES (3, 3)", + ks + )) + .await + .unwrap(); + batch.append_statement(stmt); + session.batch(&batch, &((),)).await.unwrap(); + + let query_rows_result = session + .query_unpaged( + format!("SELECT a, b, WRITETIME(b) FROM {}.t_generator", ks), + &[], + ) + .await + .unwrap() + .into_rows_result() + .unwrap(); + assert!(query_rows_result + .rows::<(i32, i32, i64)>() + .unwrap() + .all(|x| timestamps.lock().unwrap().contains(&x.unwrap().2))) +} + #[ignore = "works on remote Scylla instances only (local ones are too fast)"] #[tokio::test] async fn test_request_timeout() { diff --git a/scylla/src/transport/timestamp_generator.rs b/scylla/src/transport/timestamp_generator.rs index 838fb60c8..9f12b32dc 100644 --- a/scylla/src/transport/timestamp_generator.rs +++ b/scylla/src/transport/timestamp_generator.rs @@ -128,3 +128,46 @@ impl TimestampGenerator for MonotonicTimestampGenerator { } } } + +#[tokio::test] +async fn monotonic_timestamp_generator_is_monotonic() { + const NUMBER_OF_ITERATIONS: u32 = 1000; + + let mut prev = None; + let mut cur; + let generator = MonotonicTimestampGenerator::new(); + for _ in 0..NUMBER_OF_ITERATIONS { + cur = generator.next_timestamp(); + if let Some(prev_val) = prev { + assert!(cur > prev_val); + } + prev = Some(cur); + } +} + +#[tokio::test] +async fn monotonic_timestamp_generator_is_monotonic_with_concurrency() { + use std::sync::Arc; + use tokio::sync::mpsc::unbounded_channel; + const NUMBER_OF_ITERATIONS: u32 = 1000; + let (sender, mut receiver) = unbounded_channel(); + let mut prev = None; + let mut cur; + let generator = Arc::new(MonotonicTimestampGenerator::new()); + for _ in 0..10 { + let sender = sender.clone(); + let generator = generator.clone(); + tokio::task::spawn(async move { + for _ in 0..NUMBER_OF_ITERATIONS { + sender.send(generator.next_timestamp()).unwrap(); + } + }); + } + for _ in 0..(10 * NUMBER_OF_ITERATIONS) { + cur = receiver.recv().await; + if let Some(x) = prev { + assert!(cur > x); + } + prev = Some(cur); + } +}