Skip to content

Commit

Permalink
Optimize cassandra routing runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai committed Dec 22, 2023
1 parent 871da4c commit e0ff300
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 54 deletions.
2 changes: 1 addition & 1 deletion shotover/src/transforms/cassandra/sink_cluster/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use tokio::sync::{mpsc, oneshot, watch};
use topology::{create_topology_task, TaskConnectionInfo};
use uuid::Uuid;

mod murmur;
pub mod node;
mod node_pool;
mod rewrite;
Expand Down Expand Up @@ -396,7 +397,6 @@ impl CassandraSinkCluster {
.get_replica_connection_in_dc(
execute,
rack,
self.version.unwrap(),
&mut self.rng,
&self.connection_factory,
)
Expand Down
175 changes: 175 additions & 0 deletions shotover/src/transforms/cassandra/sink_cluster/murmur.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
//! Taken from https://github.com/scylladb/scylla-rust-driver/blob/4a4fd0e5e785031956f560ecf22cb8653eea122b/scylla/src/routing.rs
//! We can't import it as that would bring the openssl dependency into shotover.
use bytes::Buf;
use cassandra_protocol::token::Murmur3Token;
use std::num::Wrapping;

/// Incremental Murmur3 hasher used to derive a `Murmur3Token` from partition key bytes.
///
/// Bytes are supplied through `write` in any number of calls; complete 16-byte
/// blocks are mixed into the state as soon as they are available and the
/// remaining tail is kept in `buf` until `finish` is called.
pub struct Murmur3PartitionerHasher {
    /// Total number of bytes written so far.
    /// `total_len % BUF_CAPACITY` is the number of pending bytes held in `buf`.
    total_len: usize,
    /// Trailing bytes that have not yet formed a complete 16-byte block.
    buf: [u8; Self::BUF_CAPACITY],
    /// First half of the running hash state.
    h1: Wrapping<i64>,
    /// Second half of the running hash state.
    h2: Wrapping<i64>,
}

impl Murmur3PartitionerHasher {
    /// Size in bytes of one block consumed by `hash_16_bytes`.
    const BUF_CAPACITY: usize = 16;

    /// Multiplicative mixing constants applied to each 64-bit half of a block.
    const C1: Wrapping<i64> = Wrapping(0x87c3_7b91_1142_53d5_u64 as i64);
    const C2: Wrapping<i64> = Wrapping(0x4cf5_ad43_2745_937f_u64 as i64);

    /// Creates a hasher with zeroed state and an empty buffer.
    pub fn new() -> Self {
        Self {
            total_len: 0,
            buf: [0; Self::BUF_CAPACITY],
            h1: Wrapping(0),
            h2: Wrapping(0),
        }
    }

    /// Mixes one 16-byte block, already split into two little-endian words
    /// `k1` and `k2`, into the running state `h1`/`h2`.
    fn hash_16_bytes(&mut self, mut k1: Wrapping<i64>, mut k2: Wrapping<i64>) {
        k1 *= Self::C1;
        k1 = Self::rotl64(k1, 31);
        k1 *= Self::C2;
        self.h1 ^= k1;

        self.h1 = Self::rotl64(self.h1, 27);
        self.h1 += self.h2;
        self.h1 = self.h1 * Wrapping(5) + Wrapping(0x52dce729);

        k2 *= Self::C2;
        k2 = Self::rotl64(k2, 33);
        k2 *= Self::C1;
        self.h2 ^= k2;

        self.h2 = Self::rotl64(self.h2, 31);
        // Note: this uses the value of h1 that was updated just above.
        self.h2 += self.h1;
        self.h2 = self.h2 * Wrapping(5) + Wrapping(0x38495ab5);
    }

    /// Reads two little-endian `i64` words from the front of `buf` and advances
    /// `buf` past the 16 bytes that were consumed.
    ///
    /// # Panics
    ///
    /// Panics if `buf` holds fewer than 16 bytes.
    fn fetch_16_bytes_from_buf(buf: &mut &[u8]) -> (Wrapping<i64>, Wrapping<i64>) {
        let (block, rest) = buf.split_at(Self::BUF_CAPACITY);
        *buf = rest;

        let mut low = [0u8; 8];
        let mut high = [0u8; 8];
        low.copy_from_slice(&block[..8]);
        high.copy_from_slice(&block[8..]);

        (
            Wrapping(i64::from_le_bytes(low)),
            Wrapping(i64::from_le_bytes(high)),
        )
    }

    /// Rotates the bits of `v` left by `n` positions.
    ///
    /// Delegates to `i64::rotate_left`, which is well defined for every `n`;
    /// a hand-written `(v << n) | (v >> (64 - n))` overflows the shift when `n == 0`.
    #[inline]
    fn rotl64(v: Wrapping<i64>, n: u32) -> Wrapping<i64> {
        Wrapping(v.0.rotate_left(n))
    }

    /// Final mixing step applied to each half of the state in `finish`.
    ///
    /// The right shifts are performed on the `u64` reinterpretation so that they
    /// are logical (zero filling) rather than arithmetic.
    #[inline]
    fn fmix(mut k: Wrapping<i64>) -> Wrapping<i64> {
        k ^= Wrapping((k.0 as u64 >> 33) as i64);
        k *= Wrapping(0xff51afd7ed558ccd_u64 as i64);
        k ^= Wrapping((k.0 as u64 >> 33) as i64);
        k *= Wrapping(0xc4ceb9fe1a85ec53_u64 as i64);
        k ^= Wrapping((k.0 as u64 >> 33) as i64);

        k
    }
}

// The implemented Murmur3 algorithm is roughly as follows:
// 1. while there are at least 16 bytes given:
// consume 16 bytes by parsing them into i64s, then
// include them in h1, h2, k1, k2;
// 2. do some magic with remaining n < 16 bytes,
// include them in h1, h2, k1, k2;
// 3. compute the token based on h1, h2, k1, k2.
//
// Therefore, the buffer of capacity 16 is used. As soon as it gets full,
// point 1. is executed. Points 2. and 3. are exclusively done in `finish()`,
// so they don't mutate the state.
impl Murmur3PartitionerHasher {
    /// Feeds `pk_part` into the hasher.
    ///
    /// Every complete 16-byte block (counting bytes left over from earlier
    /// calls) is hashed immediately; whatever is left afterwards, always fewer
    /// than 16 bytes, is stored in the internal buffer for a later `write` or
    /// for `finish`.
    pub fn write(&mut self, mut pk_part: &[u8]) {
        // Bytes still pending in the internal buffer from previous calls.
        let mut buffered = self.total_len % Self::BUF_CAPACITY;
        self.total_len += pk_part.len();

        // Phase one: if earlier calls left a partial block and the new data can
        // complete it, top the buffer up and hash it.
        let vacancy = Self::BUF_CAPACITY - buffered;
        if buffered != 0 && pk_part.len() >= vacancy {
            self.buf[buffered..].copy_from_slice(&pk_part[..vacancy]);
            pk_part.advance(vacancy);

            let mut block = &self.buf[..];
            let (k1, k2) = Self::fetch_16_bytes_from_buf(&mut block);
            debug_assert!(block.is_empty());
            self.hash_16_bytes(k1, k2);
            buffered = 0;
        }

        // Phase two: with nothing pending, whole blocks can be hashed straight
        // from the caller's slice without copying them first.
        if buffered == 0 {
            while pk_part.len() >= Self::BUF_CAPACITY {
                let (k1, k2) = Self::fetch_16_bytes_from_buf(&mut pk_part);
                self.hash_16_bytes(k1, k2);
            }
        }

        // Phase three: stash the remainder behind whatever is already pending.
        debug_assert!(buffered + pk_part.len() < Self::BUF_CAPACITY);
        self.buf[buffered..buffered + pk_part.len()].copy_from_slice(pk_part);
    }

    /// Computes the token for all bytes written so far.
    ///
    /// Works on copies of the running state, so the hasher itself is left
    /// untouched.
    pub fn finish(&self) -> Murmur3Token {
        let tail_len = self.total_len % Self::BUF_CAPACITY;
        let tail = &self.buf[..tail_len];

        let mut h1 = self.h1;
        let mut h2 = self.h2;

        // Bytes 8.. of the tail feed the second half of the state.
        if tail_len > 8 {
            let k2 = Self::tail_word(&tail[8..]);
            h2 ^= Self::rotl64(k2 * Self::C2, 33) * Self::C1;
        }

        // Bytes ..8 of the tail feed the first half of the state.
        if tail_len > 0 {
            let k1 = Self::tail_word(&tail[..tail_len.min(8)]);
            h1 ^= Self::rotl64(k1 * Self::C1, 31) * Self::C2;
        }

        let length = Wrapping(self.total_len as i64);
        h1 ^= length;
        h2 ^= length;

        h1 += h2;
        h2 += h1;

        h1 = Self::fmix(h1);
        h2 = Self::fmix(h2);

        h1 += h2;

        // Only the low 64 bits of the 128-bit result (h2 in the high half, h1
        // in the low half) become the token, so h1 alone determines the value.
        Murmur3Token { value: h1.0 }
    }

    /// Folds up to eight tail bytes into one word: the byte at position `i` is
    /// sign-extended to 64 bits and XORed in at bit offset `8 * i`.
    fn tail_word(bytes: &[u8]) -> Wrapping<i64> {
        bytes
            .iter()
            .enumerate()
            .fold(Wrapping(0_i64), |word, (pos, &byte)| {
                word ^ (Wrapping(byte as i8 as i64) << (pos * 8))
            })
    }
}
17 changes: 6 additions & 11 deletions shotover/src/transforms/cassandra/sink_cluster/node_pool.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
use crate::transforms::cassandra::connection::CassandraConnection;

use super::node::{CassandraNode, ConnectionFactory};
use super::routing_key::calculate_routing_key;
use super::token_map::TokenMap;
use super::KeyspaceChanRx;
use crate::transforms::cassandra::connection::CassandraConnection;
use anyhow::{anyhow, Context, Error, Result};
use cassandra_protocol::frame::message_execute::BodyReqExecuteOwned;
use cassandra_protocol::frame::Version;
use cassandra_protocol::token::Murmur3Token;
use cassandra_protocol::types::CBytesShort;
use metrics::{register_counter, Counter};
use rand::prelude::*;
Expand Down Expand Up @@ -187,14 +184,14 @@ impl NodePool {
&mut self,
execute: &BodyReqExecuteOwned,
rack: &str,
version: Version,
rng: &mut SmallRng,
) -> Result<Vec<&mut CassandraNode>, GetReplicaErr> {
let metadata = {
let read_lock = self.prepared_metadata.read().await;
read_lock
.get(&execute.id)
.ok_or(GetReplicaErr::NoPreparedMetadata)?
// TODO: wrap metadata in arc or something to make clone cheap
.clone()
};

Expand All @@ -213,13 +210,13 @@ impl NodePool {
execute.query_parameters.values.as_ref().ok_or_else(|| {
GetReplicaErr::Other(anyhow!("Execute body does not have query parameters"))
})?,
version,
)
.ok_or(GetReplicaErr::NoRoutingKey)?;

// TODO: How does scylla implement this?
let replica_host_ids = self
.token_map
.iter_replica_nodes(self.nodes(), Murmur3Token::generate(&routing_key), keyspace)
.iter_replica_nodes(self.nodes(), routing_key, keyspace)
.collect::<Vec<uuid::Uuid>>();

let mut nodes: Vec<&mut CassandraNode> = self
Expand Down Expand Up @@ -253,20 +250,18 @@ impl NodePool {
"Shotover with designated rack {rack:?} found replica nodes {replica_host_ids:?}"
);

// TODO: can we return an iterator instead of a Vec?
Ok(nodes)
}

pub async fn get_replica_connection_in_dc(
&mut self,
execute: &BodyReqExecuteOwned,
rack: &str,
version: Version,
rng: &mut SmallRng,
connection_factory: &ConnectionFactory,
) -> Result<&CassandraConnection, GetReplicaErr> {
let nodes = self
.get_replica_node_in_dc(execute, rack, version, rng)
.await?;
let nodes = self.get_replica_node_in_dc(execute, rack, rng).await?;

get_accessible_node(connection_factory, nodes)
.await
Expand Down
72 changes: 40 additions & 32 deletions shotover/src/transforms/cassandra/sink_cluster/routing_key.rs
Original file line number Diff line number Diff line change
@@ -1,63 +1,71 @@
use cassandra_protocol::frame::{Serialize, Version};
use super::murmur::Murmur3PartitionerHasher;
use cassandra_protocol::query::QueryValues;
use cassandra_protocol::token::Murmur3Token;
use cassandra_protocol::types::value::Value;
use cassandra_protocol::types::CIntShort;
use itertools::Itertools;
use std::io::{Cursor, Write};

// functions taken from https://github.com/krojew/cdrs-tokio/blob/9246dcf4227c1d4b1ff1eafaf0abfae2d831eec4/cdrs-tokio/src/cluster/session.rs#L126

pub fn calculate_routing_key(
pk_indexes: &[i16],
query_values: &QueryValues,
version: Version,
) -> Option<Vec<u8>> {
) -> Option<Murmur3Token> {
let values = match query_values {
QueryValues::SimpleValues(values) => values,
_ => panic!("handle named"),
};

serialize_routing_key_with_indexes(values, pk_indexes, version)
serialize_routing_key_with_indexes(values, pk_indexes)
}

fn serialize_routing_key_with_indexes(
values: &[Value],
pk_indexes: &[i16],
version: Version,
) -> Option<Vec<u8>> {
) -> Option<Murmur3Token> {
let mut partitioner_hasher = Murmur3PartitionerHasher::new();

match pk_indexes.len() {
0 => None,
1 => values
.get(pk_indexes[0] as usize)
.and_then(|value| match value {
Value::Some(value) => Some(value.serialize_to_vec(version)),
Value::Some(value) => {
partitioner_hasher.write(value);
Some(Murmur3Token {
value: partitioner_hasher.finish().value,
})
}
_ => None,
}),
_ => {
let mut buf = vec![];
if pk_indexes
.iter()
.map(|index| values.get(*index as usize))
.fold_options(Cursor::new(&mut buf), |mut cursor, value| {
if let Value::Some(value) = value {
serialize_routing_value(&mut cursor, value, version)
for index in pk_indexes {
match values.get(*index as usize) {
Some(Value::Some(value)) => {
// logic for hashing in this case is not documented but implemented at:
// https://github.com/apache/cassandra/blob/3a950b45c321e051a9744721408760c568c05617/src/java/org/apache/cassandra/db/marshal/CompositeType.java#L39
let len = value.len();
let attempt: Result<u16, _> = len.try_into();
match attempt {
Ok(len) => partitioner_hasher.write(&len.to_be_bytes()),
Err(_) => tracing::error!(
"could not route cassandra request as value was too long: {len}",
),
}
partitioner_hasher.write(value);
partitioner_hasher.write(&[0u8]);
}
Some(Value::Null | Value::NotSet) => {
// write nothing
}
_ => {
// do not perform routing
return None;
}
cursor
})
.is_some()
{
Some(buf)
} else {
None
}
}
// serialize_routing_value(&mut cursor, value, version);
Some(Murmur3Token {
value: partitioner_hasher.finish().value,
})
}
}
}

// https://github.com/apache/cassandra/blob/3a950b45c321e051a9744721408760c568c05617/src/java/org/apache/cassandra/db/marshal/CompositeType.java#L39
fn serialize_routing_value(cursor: &mut Cursor<&mut Vec<u8>>, value: &Vec<u8>, version: Version) {
let size: CIntShort = value.len().try_into().unwrap();
size.serialize(cursor, version);
value.serialize(cursor, version);
let _ = cursor.write(&[0]);
}
Loading

0 comments on commit e0ff300

Please sign in to comment.