diff --git a/shotover/src/transforms/cassandra/sink_cluster/mod.rs b/shotover/src/transforms/cassandra/sink_cluster/mod.rs index 03773b338..6db127faa 100644 --- a/shotover/src/transforms/cassandra/sink_cluster/mod.rs +++ b/shotover/src/transforms/cassandra/sink_cluster/mod.rs @@ -29,6 +29,7 @@ use tokio::sync::{mpsc, oneshot, watch}; use topology::{create_topology_task, TaskConnectionInfo}; use uuid::Uuid; +mod murmur; pub mod node; mod node_pool; mod rewrite; @@ -396,7 +397,6 @@ impl CassandraSinkCluster { .get_replica_connection_in_dc( execute, rack, - self.version.unwrap(), &mut self.rng, &self.connection_factory, ) diff --git a/shotover/src/transforms/cassandra/sink_cluster/murmur.rs b/shotover/src/transforms/cassandra/sink_cluster/murmur.rs new file mode 100644 index 000000000..28011670f --- /dev/null +++ b/shotover/src/transforms/cassandra/sink_cluster/murmur.rs @@ -0,0 +1,175 @@ +//! Taken from https://github.com/scylladb/scylla-rust-driver/blob/4a4fd0e5e785031956f560ecf22cb8653eea122b/scylla/src/routing.rs +//! We cant import it as that would bring the openssl dependency into shotover. + +use bytes::Buf; +use cassandra_protocol::token::Murmur3Token; +use std::num::Wrapping; + +pub struct Murmur3PartitionerHasher { + total_len: usize, + buf: [u8; Self::BUF_CAPACITY], + h1: Wrapping, + h2: Wrapping, +} + +impl Murmur3PartitionerHasher { + const BUF_CAPACITY: usize = 16; + + const C1: Wrapping = Wrapping(0x87c3_7b91_1142_53d5_u64 as i64); + const C2: Wrapping = Wrapping(0x4cf5_ad43_2745_937f_u64 as i64); + + pub fn new() -> Self { + Self { + total_len: 0, + buf: Default::default(), + h1: Wrapping(0), + h2: Wrapping(0), + } + } + + fn hash_16_bytes(&mut self, mut k1: Wrapping, mut k2: Wrapping) { + k1 *= Self::C1; + k1 = Self::rotl64(k1, 31); + k1 *= Self::C2; + self.h1 ^= k1; + + self.h1 = Self::rotl64(self.h1, 27); + self.h1 += self.h2; + self.h1 = self.h1 * Wrapping(5) + Wrapping(0x52dce729); + + k2 *= Self::C2; + k2 = Self::rotl64(k2, 33); + k2 *= Self::C1; + self.h2 ^= k2; + + self.h2 = Self::rotl64(self.h2, 31); + self.h2 += self.h1; + self.h2 = self.h2 * Wrapping(5) + Wrapping(0x38495ab5); + } + + fn fetch_16_bytes_from_buf(buf: &mut &[u8]) -> (Wrapping, Wrapping) { + let k1 = Wrapping(buf.get_i64_le()); + let k2 = Wrapping(buf.get_i64_le()); + (k1, k2) + } + + #[inline] + fn rotl64(v: Wrapping, n: u32) -> Wrapping { + Wrapping((v.0 << n) | (v.0 as u64 >> (64 - n)) as i64) + } + + #[inline] + fn fmix(mut k: Wrapping) -> Wrapping { + k ^= Wrapping((k.0 as u64 >> 33) as i64); + k *= Wrapping(0xff51afd7ed558ccd_u64 as i64); + k ^= Wrapping((k.0 as u64 >> 33) as i64); + k *= Wrapping(0xc4ceb9fe1a85ec53_u64 as i64); + k ^= Wrapping((k.0 as u64 >> 33) as i64); + + k + } +} + +// The implemented Murmur3 algorithm is roughly as follows: +// 1. while there are at least 16 bytes given: +// consume 16 bytes by parsing them into i64s, then +// include them in h1, h2, k1, k2; +// 2. do some magic with remaining n < 16 bytes, +// include them in h1, h2, k1, k2; +// 3. compute the token based on h1, h2, k1, k2. +// +// Therefore, the buffer of capacity 16 is used. As soon as it gets full, +// point 1. is executed. Points 2. and 3. are exclusively done in `finish()`, +// so they don't mutate the state. +impl Murmur3PartitionerHasher { + pub fn write(&mut self, mut pk_part: &[u8]) { + let mut buf_len = self.total_len % Self::BUF_CAPACITY; + self.total_len += pk_part.len(); + + // If the buffer is nonempty and can be filled completely, so that we can fetch two i64s from it, + // fill it and hash its contents, then make it empty. + if buf_len > 0 && Self::BUF_CAPACITY - buf_len <= pk_part.len() { + // First phase: populate buffer until full, then consume two i64s. + let to_write = Ord::min(Self::BUF_CAPACITY - buf_len, pk_part.len()); + self.buf[buf_len..buf_len + to_write].copy_from_slice(&pk_part[..to_write]); + pk_part.advance(to_write); + buf_len += to_write; + + debug_assert_eq!(buf_len, Self::BUF_CAPACITY); + // consume 16 bytes from internal buf + let mut buf_ptr = &self.buf[..]; + let (k1, k2) = Self::fetch_16_bytes_from_buf(&mut buf_ptr); + debug_assert!(buf_ptr.is_empty()); + self.hash_16_bytes(k1, k2); + buf_len = 0; + } + + // If there were enough data, now we have an empty buffer. Further data, if enough, can be hence + // hashed directly from the external buffer. + if buf_len == 0 { + // Second phase: fast path for big values. + while pk_part.len() >= Self::BUF_CAPACITY { + let (k1, k2) = Self::fetch_16_bytes_from_buf(&mut pk_part); + self.hash_16_bytes(k1, k2); + } + } + + // Third phase: move remaining bytes to the buffer. + debug_assert!(pk_part.len() < Self::BUF_CAPACITY - buf_len); + let to_write = pk_part.len(); + self.buf[buf_len..buf_len + to_write].copy_from_slice(&pk_part[..to_write]); + pk_part.advance(to_write); + buf_len += to_write; + debug_assert!(pk_part.is_empty()); + + debug_assert!(buf_len < Self::BUF_CAPACITY); + } + + pub fn finish(&self) -> Murmur3Token { + let mut h1 = self.h1; + let mut h2 = self.h2; + + let mut k1 = Wrapping(0_i64); + let mut k2 = Wrapping(0_i64); + + let buf_len = self.total_len % Self::BUF_CAPACITY; + + if buf_len > 8 { + for i in (8..buf_len).rev() { + k2 ^= Wrapping(self.buf[i] as i8 as i64) << ((i - 8) * 8); + } + + k2 *= Self::C2; + k2 = Self::rotl64(k2, 33); + k2 *= Self::C1; + h2 ^= k2; + } + + if buf_len > 0 { + for i in (0..std::cmp::min(8, buf_len)).rev() { + k1 ^= Wrapping(self.buf[i] as i8 as i64) << (i * 8); + } + + k1 *= Self::C1; + k1 = Self::rotl64(k1, 31); + k1 *= Self::C2; + h1 ^= k1; + } + + h1 ^= Wrapping(self.total_len as i64); + h2 ^= Wrapping(self.total_len as i64); + + h1 += h2; + h2 += h1; + + h1 = Self::fmix(h1); + h2 = Self::fmix(h2); + + h1 += h2; + h2 += h1; + + Murmur3Token { + value: (((h2.0 as i128) << 64) | h1.0 as i128) as i64, + } + } +} diff --git a/shotover/src/transforms/cassandra/sink_cluster/node_pool.rs b/shotover/src/transforms/cassandra/sink_cluster/node_pool.rs index 30400fe85..94b899884 100644 --- a/shotover/src/transforms/cassandra/sink_cluster/node_pool.rs +++ b/shotover/src/transforms/cassandra/sink_cluster/node_pool.rs @@ -1,13 +1,10 @@ -use crate::transforms::cassandra::connection::CassandraConnection; - use super::node::{CassandraNode, ConnectionFactory}; use super::routing_key::calculate_routing_key; use super::token_map::TokenMap; use super::KeyspaceChanRx; +use crate::transforms::cassandra::connection::CassandraConnection; use anyhow::{anyhow, Context, Error, Result}; use cassandra_protocol::frame::message_execute::BodyReqExecuteOwned; -use cassandra_protocol::frame::Version; -use cassandra_protocol::token::Murmur3Token; use cassandra_protocol::types::CBytesShort; use metrics::{register_counter, Counter}; use rand::prelude::*; @@ -187,7 +184,6 @@ impl NodePool { &mut self, execute: &BodyReqExecuteOwned, rack: &str, - version: Version, rng: &mut SmallRng, ) -> Result, GetReplicaErr> { let metadata = { @@ -195,6 +191,7 @@ impl NodePool { read_lock .get(&execute.id) .ok_or(GetReplicaErr::NoPreparedMetadata)? + // TODO: wrap metadata in arc or something to make clone cheap .clone() }; @@ -213,13 +210,13 @@ impl NodePool { execute.query_parameters.values.as_ref().ok_or_else(|| { GetReplicaErr::Other(anyhow!("Execute body does not have query parameters")) })?, - version, ) .ok_or(GetReplicaErr::NoRoutingKey)?; + // TODO: How does scylla implement this? let replica_host_ids = self .token_map - .iter_replica_nodes(self.nodes(), Murmur3Token::generate(&routing_key), keyspace) + .iter_replica_nodes(self.nodes(), routing_key, keyspace) .collect::>(); let mut nodes: Vec<&mut CassandraNode> = self @@ -253,6 +250,7 @@ impl NodePool { "Shotover with designated rack {rack:?} found replica nodes {replica_host_ids:?}" ); + // TODO: can we return an iterator instead of a Vec? Ok(nodes) } @@ -260,13 +258,10 @@ impl NodePool { &mut self, execute: &BodyReqExecuteOwned, rack: &str, - version: Version, rng: &mut SmallRng, connection_factory: &ConnectionFactory, ) -> Result<&CassandraConnection, GetReplicaErr> { - let nodes = self - .get_replica_node_in_dc(execute, rack, version, rng) - .await?; + let nodes = self.get_replica_node_in_dc(execute, rack, rng).await?; get_accessible_node(connection_factory, nodes) .await diff --git a/shotover/src/transforms/cassandra/sink_cluster/routing_key.rs b/shotover/src/transforms/cassandra/sink_cluster/routing_key.rs index 1a07f8f9e..1d6dc4fa0 100644 --- a/shotover/src/transforms/cassandra/sink_cluster/routing_key.rs +++ b/shotover/src/transforms/cassandra/sink_cluster/routing_key.rs @@ -1,63 +1,71 @@ -use cassandra_protocol::frame::{Serialize, Version}; +use super::murmur::Murmur3PartitionerHasher; use cassandra_protocol::query::QueryValues; +use cassandra_protocol::token::Murmur3Token; use cassandra_protocol::types::value::Value; -use cassandra_protocol::types::CIntShort; -use itertools::Itertools; -use std::io::{Cursor, Write}; // functions taken from https://github.com/krojew/cdrs-tokio/blob/9246dcf4227c1d4b1ff1eafaf0abfae2d831eec4/cdrs-tokio/src/cluster/session.rs#L126 pub fn calculate_routing_key( pk_indexes: &[i16], query_values: &QueryValues, - version: Version, -) -> Option> { +) -> Option { let values = match query_values { QueryValues::SimpleValues(values) => values, _ => panic!("handle named"), }; - serialize_routing_key_with_indexes(values, pk_indexes, version) + serialize_routing_key_with_indexes(values, pk_indexes) } fn serialize_routing_key_with_indexes( values: &[Value], pk_indexes: &[i16], - version: Version, -) -> Option> { +) -> Option { + let mut partitioner_hasher = Murmur3PartitionerHasher::new(); + match pk_indexes.len() { 0 => None, 1 => values .get(pk_indexes[0] as usize) .and_then(|value| match value { - Value::Some(value) => Some(value.serialize_to_vec(version)), + Value::Some(value) => { + partitioner_hasher.write(value); + Some(Murmur3Token { + value: partitioner_hasher.finish().value, + }) + } _ => None, }), _ => { - let mut buf = vec![]; - if pk_indexes - .iter() - .map(|index| values.get(*index as usize)) - .fold_options(Cursor::new(&mut buf), |mut cursor, value| { - if let Value::Some(value) = value { - serialize_routing_value(&mut cursor, value, version) + for index in pk_indexes { + match values.get(*index as usize) { + Some(Value::Some(value)) => { + // logic for hashing in this case is not documented but implemented at: + // https://github.com/apache/cassandra/blob/3a950b45c321e051a9744721408760c568c05617/src/java/org/apache/cassandra/db/marshal/CompositeType.java#L39 + let len = value.len(); + let attempt: Result = len.try_into(); + match attempt { + Ok(len) => partitioner_hasher.write(&len.to_be_bytes()), + Err(_) => tracing::error!( + "could not route cassandra request as value was too long: {len}", + ), + } + partitioner_hasher.write(value); + partitioner_hasher.write(&[0u8]); + } + Some(Value::Null | Value::NotSet) => { + // write nothing + } + _ => { + // do not perform routing + return None; } - cursor - }) - .is_some() - { - Some(buf) - } else { - None + } } + // serialize_routing_value(&mut cursor, value, version); + Some(Murmur3Token { + value: partitioner_hasher.finish().value, + }) } } } - -// https://github.com/apache/cassandra/blob/3a950b45c321e051a9744721408760c568c05617/src/java/org/apache/cassandra/db/marshal/CompositeType.java#L39 -fn serialize_routing_value(cursor: &mut Cursor<&mut Vec>, value: &Vec, version: Version) { - let size: CIntShort = value.len().try_into().unwrap(); - size.serialize(cursor, version); - value.serialize(cursor, version); - let _ = cursor.write(&[0]); -} diff --git a/shotover/src/transforms/cassandra/sink_cluster/test_router.rs b/shotover/src/transforms/cassandra/sink_cluster/test_router.rs index f6705f168..4cb1750be 100644 --- a/shotover/src/transforms/cassandra/sink_cluster/test_router.rs +++ b/shotover/src/transforms/cassandra/sink_cluster/test_router.rs @@ -9,7 +9,6 @@ mod test_token_aware_router { use crate::transforms::cassandra::sink_cluster::{KeyspaceChanRx, KeyspaceChanTx}; use cassandra_protocol::consistency::Consistency::One; use cassandra_protocol::frame::message_execute::BodyReqExecuteOwned; - use cassandra_protocol::frame::Version; use cassandra_protocol::query::QueryParams; use cassandra_protocol::query::QueryValues::SimpleValues; use cassandra_protocol::token::Murmur3Token; @@ -69,14 +68,11 @@ mod test_token_aware_router { now_in_seconds: None, }; - let token = Murmur3Token::generate( - &calculate_routing_key( - &prepared_metadata().pk_indexes, - query_parameters.values.as_ref().unwrap(), - Version::V4, - ) - .unwrap(), - ); + let token = calculate_routing_key( + &prepared_metadata().pk_indexes, + query_parameters.values.as_ref().unwrap(), + ) + .unwrap(); assert_eq!(token, test_token); @@ -84,7 +80,6 @@ mod test_token_aware_router { .get_replica_node_in_dc( &execute_body(id.clone(), query_parameters), "rack1", - Version::V4, &mut rng, ) .await