Skip to content

Commit

Permalink
Cassandra 5 vector support (#1676)
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai authored Jul 15, 2024
1 parent aa0c320 commit ce4a153
Show file tree
Hide file tree
Showing 9 changed files with 203 additions and 11 deletions.
16 changes: 16 additions & 0 deletions shotover-proxy/tests/cassandra_int_tests/collections/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use test_helpers::connection::cassandra::{
mod list;
mod map;
mod set;
mod vector;

fn supported_native_col_types(connection: &CassandraConnection) -> &'static [ColType] {
match connection {
Expand Down Expand Up @@ -338,3 +339,18 @@ pub async fn test(connection: &CassandraConnection, driver: CassandraDriver) {
set::test(connection, driver).await;
map::test(connection, driver).await;
}

pub async fn test_cassandra_5(connection: &CassandraConnection, driver: CassandraDriver) {
run_query(
connection,
"CREATE KEYSPACE collections WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };"
).await;

list::test(connection, driver).await;
set::test(connection, driver).await;
map::test(connection, driver).await;

if let CassandraDriver::Java = driver {
vector::test(connection).await;
}
}
47 changes: 47 additions & 0 deletions shotover-proxy/tests/cassandra_int_tests/collections/vector.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
use super::*;

// We dont use the collection abstractions used by list/map/set since vectors only support a small subset of data types.

async fn create(connection: &CassandraConnection) {
run_query(
connection,
"CREATE TABLE collections.vector (id int PRIMARY KEY, col0 vector<int, 1>, col1 vector<bigint, 2>, col2 vector<float, 2>, col3 vector<double, 2>);",
)
.await;
}

async fn insert(connection: &CassandraConnection) {
run_query(
connection,
"INSERT INTO collections.vector (id, col0, col1, col2, col3) VALUES (1, [1], [2, 3], [4.1, 5.2], [6.1, 7.2]);",
)
.await;
}

async fn select(connection: &CassandraConnection) {
let results = vec![
ResultValue::Vector(vec![ResultValue::Int(1)]),
ResultValue::Vector(vec![ResultValue::BigInt(2), ResultValue::BigInt(3)]),
ResultValue::Vector(vec![
ResultValue::Float(4.1.into()),
ResultValue::Float(5.2.into()),
]),
ResultValue::Vector(vec![
ResultValue::Double(6.1.into()),
ResultValue::Double(7.2.into()),
]),
];

assert_query_result(
connection,
"SELECT col0, col1, col2, col3 FROM collections.vector;",
&[&results],
)
.await;
}

pub async fn test(connection: &CassandraConnection) {
create(connection).await;
insert(connection).await;
select(connection).await;
}
27 changes: 24 additions & 3 deletions shotover-proxy/tests/cassandra_int_tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,27 @@ where
timestamp::test(&connection).await;
}

async fn standard_test_suite_cassandra5<Fut>(
connection_creator: impl Fn() -> Fut,
driver: CassandraDriver,
) where
Fut: Future<Output = CassandraConnection>,
{
// reuse a single connection a bunch to save time recreating connections
let connection = connection_creator().await;

keyspace::test(&connection).await;
table::test(&connection).await;
udt::test(&connection).await;
native_types::test(&connection).await;
collections::test_cassandra_5(&connection, driver).await;
functions::test(&connection).await;
prepared_statements_simple::test(&connection, connection_creator, 1).await;
prepared_statements_all::test(&connection, 1).await;
batch_statements::test(&connection).await;
timestamp::test(&connection).await;
}

async fn standard_test_suite_rf3<Fut>(connection_creator: impl Fn() -> Fut, driver: CassandraDriver)
where
Fut: Future<Output = CassandraConnection>,
Expand Down Expand Up @@ -1066,7 +1087,7 @@ async fn passthrough_tls_websockets() {

#[apply(all_cassandra_drivers)]
#[tokio::test(flavor = "multi_thread")]
async fn cassandra_5(#[case] driver: CassandraDriver) {
async fn cassandra_5_passthrough(#[case] driver: CassandraDriver) {
let _compose = docker_compose("tests/test-configs/cassandra/cassandra-5/docker-compose.yaml");

let shotover = shotover_process("tests/test-configs/cassandra/cassandra-5/topology.yaml")
Expand All @@ -1075,7 +1096,7 @@ async fn cassandra_5(#[case] driver: CassandraDriver) {

let connection = || CassandraConnectionBuilder::new("127.0.0.1", 9042, driver).build();

standard_test_suite(&connection, driver).await;
standard_test_suite_cassandra5(&connection, driver).await;

shotover.shutdown_and_then_consume_events(&[]).await;
}
Expand All @@ -1102,7 +1123,7 @@ async fn cassandra_5_cluster(#[case] driver: CassandraDriver) {
connection
};

standard_test_suite(&connection, driver).await;
standard_test_suite_cassandra5(&connection, driver).await;

shotover.shutdown_and_then_consume_events(&[]).await;
}
1 change: 1 addition & 0 deletions shotover/src/frame/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ pub enum GenericValue {
Counter(i64),
Tuple(Vec<GenericValue>),
Udt(BTreeMap<String, GenericValue>),
Custom(Bytes),
}

#[derive(PartialEq, Debug, Clone, Serialize, Deserialize, Hash, Eq, PartialOrd, Ord)]
Expand Down
33 changes: 29 additions & 4 deletions shotover/src/frame/value/cassandra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use super::{Duration, GenericValue, IntSize};
use crate::frame::cassandra::to_cassandra_type;
use bigdecimal::BigDecimal;
use bytes::Bytes;
use cassandra_protocol::frame::message_result::ColType;
use cassandra_protocol::frame::Serialize as FrameSerialize;
use cassandra_protocol::types::CInt;
use cassandra_protocol::{
Expand All @@ -25,6 +26,18 @@ impl From<&Operand> for GenericValue {
}
}

/// This type is a hack and should eventually be removed.
/// To properly resolve this we really need to:
/// * remove GenericValue from the cassandra AST
/// + its a holdover from a long time ago
/// - shotover no longer attempts to provide a generic abstraction over multiple DB types
/// * Replace cassandra-protocol with something simpler and faster.
/// Once that is done we can just include `Custom` directly as a variant within our own CassandraType enum.
enum CustomOrStandardType {
Custom(Option<Vec<u8>>),
Standard(CassandraType),
}

impl GenericValue {
pub fn value_byte_string(string: String) -> GenericValue {
GenericValue::Bytes(Bytes::from(string))
Expand All @@ -40,16 +53,27 @@ impl GenericValue {
data: &CBytes,
) -> GenericValue {
let cassandra_type = GenericValue::into_cassandra_type(version, &spec.col_type, data);
GenericValue::create_element(cassandra_type)
match cassandra_type {
CustomOrStandardType::Custom(Some(bytes)) => GenericValue::Custom(bytes.into()),
CustomOrStandardType::Custom(None) => GenericValue::Null,
CustomOrStandardType::Standard(element) => Self::create_element(element),
}
}

fn into_cassandra_type(
version: Version,
col_type: &ColTypeOption,
data: &CBytes,
) -> CassandraType {
let wrapper = wrapper_fn(&col_type.id);
wrapper(data, col_type, version).unwrap()
) -> CustomOrStandardType {
// cassandra-protocol will error on an unknown custom type,
// but we need to continue succesfully with custom types even if that means treating them as a magical bag of bytes.
// so we check for custom type before running the cassandra-protocol parser.
if col_type.id == ColType::Custom {
CustomOrStandardType::Custom(data.clone().into_bytes())
} else {
let wrapper = wrapper_fn(&col_type.id);
CustomOrStandardType::Standard(wrapper(data, col_type, version).unwrap())
}
}

fn create_element(element: CassandraType) -> GenericValue {
Expand Down Expand Up @@ -120,6 +144,7 @@ impl GenericValue {
pub fn cassandra_serialize(&self, cursor: &mut Cursor<&mut Vec<u8>>) {
match self {
GenericValue::Null => cursor.write_all(&[255, 255, 255, 255]).unwrap(),
GenericValue::Custom(b) => serialize_bytes(cursor, b),
GenericValue::Bytes(b) => serialize_bytes(cursor, b),
GenericValue::Strings(s) => serialize_bytes(cursor, s.as_bytes()),
GenericValue::Integer(x, size) => match size {
Expand Down
1 change: 1 addition & 0 deletions shotover/src/frame/value/redis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ impl From<GenericValue> for RedisFrame {
GenericValue::Tuple(_) => todo!(),
GenericValue::Udt(_) => todo!(),
GenericValue::Duration(_) => todo!(),
GenericValue::Custom(_) => todo!(),
}
}
}
85 changes: 82 additions & 3 deletions test-helpers/src/connection/cassandra/connection/java.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use super::{Compression, Consistency, PreparedQuery, ProtocolVersion, Tls};
use crate::connection::cassandra::ResultValue;
use crate::connection::java::{Jvm, Value};
use cdrs_tokio::frame::message_error::ErrorBody;
use cdrs_tokio::frame::message_result::ColType;
use std::net::IpAddr;

pub struct JavaConnection {
Expand Down Expand Up @@ -249,7 +248,17 @@ impl JavaConnection {
})
.collect()
}),
ty => todo!("{ty}"),
ColType::Vector => ResultValue::Vector({
value
.cast("com.datastax.oss.driver.api.core.data.CqlVector")
.call("iterator", vec![])
.into_iter()
// TODO: no way to provide a correct raw_bytes value here,
// need to change tests to not use raw_bytes instead.
.map(|value| Self::java_value_to_rust(value, raw_bytes, &ty.element_col_type()))
.collect()
}),
ty => todo!("{ty:?}"),
}
}

Expand Down Expand Up @@ -410,7 +419,46 @@ struct DataType(Value);
impl DataType {
fn col_type(&self) -> ColType {
let code: i32 = self.0.call("getProtocolCode", vec![]).into_rust();
ColType::try_from(code as i16).unwrap()
match code {
0x00 => {
if self
.0
.cast_fallible("com.datastax.oss.driver.api.core.type.VectorType")
.is_ok()
{
ColType::Vector
} else {
ColType::Custom
}
}
0x01 => ColType::Ascii,
0x02 => ColType::Bigint,
0x03 => ColType::Blob,
0x04 => ColType::Boolean,
0x05 => ColType::Counter,
0x06 => ColType::Decimal,
0x07 => ColType::Double,
0x08 => ColType::Float,
0x09 => ColType::Int,
0x0B => ColType::Timestamp,
0x0C => ColType::Uuid,
0x0D => ColType::Varchar,
0x0E => ColType::Varint,
0x0F => ColType::Timeuuid,
0x10 => ColType::Inet,
0x11 => ColType::Date,
0x12 => ColType::Time,
0x13 => ColType::Smallint,
0x14 => ColType::Tinyint,
0x15 => ColType::Duration,
0x20 => ColType::List,
0x21 => ColType::Map,
0x22 => ColType::Set,
0x30 => ColType::Udt,
0x31 => ColType::Tuple,
0x80 => ColType::Varchar,
code => panic!("unknown type code {code:?}"),
}
}

fn element_col_type(&self) -> DataType {
Expand All @@ -429,3 +477,34 @@ impl DataType {
)
}
}

#[derive(Debug, Clone, Copy)]
pub enum ColType {
Custom,
Ascii,
Bigint,
Blob,
Boolean,
Counter,
Decimal,
Double,
Float,
Int,
Timestamp,
Uuid,
Varchar,
Varint,
Timeuuid,
Inet,
Date,
Time,
Smallint,
Tinyint,
Duration,
List,
Map,
Set,
Udt,
Tuple,
Vector,
}
2 changes: 2 additions & 0 deletions test-helpers/src/connection/cassandra/result_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ pub enum ResultValue {
Set(Vec<ResultValue>),
List(Vec<ResultValue>),
Tuple(Vec<ResultValue>),
Vector(Vec<ResultValue>),
Map(Vec<(ResultValue, ResultValue)>),
Null,
/// Never output by the DB
Expand Down Expand Up @@ -66,6 +67,7 @@ impl PartialEq for ResultValue {
(Self::List(l0), Self::List(r0)) => l0 == r0,
(Self::Tuple(l0), Self::Tuple(r0)) => l0 == r0,
(Self::Map(l0), Self::Map(r0)) => l0 == r0,
(Self::Vector(l0), Self::Vector(r0)) => l0 == r0,
(Self::Null, Self::Null) => true,
(Self::Any, _) => true,
(_, Self::Any) => true,
Expand Down
2 changes: 1 addition & 1 deletion test-helpers/src/connection/java.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ impl Value {
self.cast_fallible(name).unwrap()
}

fn cast_fallible(&self, name: &str) -> Result<Self> {
pub(crate) fn cast_fallible(&self, name: &str) -> Result<Self> {
let instance = self.jvm.cast(&self.instance, name)?;
Ok(Self {
instance,
Expand Down

0 comments on commit ce4a153

Please sign in to comment.