Skip to content

Commit

Permalink
fix tuple parsing for bit arrays, add test for hamming distance
Browse files Browse the repository at this point in the history
  • Loading branch information
var77 committed Jul 25, 2024
1 parent fd30bb3 commit ef0f94b
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 17 deletions.
1 change: 1 addition & 0 deletions lantern_cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ utoipa-swagger-ui = { version = "6.0.0", features = ["actix-web"], optional = tr
actix-web-httpauth = { version = "0.8.1", optional = true }
tokio-util = "0.7.11"
byteorder = "1.5.0"
bitvec = "1.0.1"

[features]
default = ["cli", "daemon", "http-server", "autotune", "pq", "external-index", "embeddings"]
Expand Down
79 changes: 62 additions & 17 deletions lantern_cli/src/external_index/server.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use super::cli::{IndexServerArgs, UMetricKind};
use bitvec::prelude::*;
use byteorder::{ByteOrder, LittleEndian};
use itertools::Itertools;
use rand::Rng;
use std::fs;
use std::io::{Read, Write};
Expand All @@ -9,14 +11,15 @@ use std::path::Path;
use std::sync::mpsc::{self, Receiver, SyncSender};
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};
use usearch::ffi::{IndexOptions, ScalarKind};
use usearch::ffi::{IndexOptions, MetricKind, ScalarKind};
use usearch::Index;

use crate::logger::{LogLevel, Logger};
use crate::types::*;

const LABEL_SIZE: usize = 8; // for now we are only using 32bit integers
const INTEGER_SIZE: usize = 4; // for now we are only using 32bit integers
const CHAR_BITS: usize = 8;
const LABEL_SIZE: usize = 8;
const INTEGER_SIZE: usize = 4;
const SOCKET_TIMEOUT: u64 = 5;
pub const PROTOCOL_HEADER_SIZE: usize = 4;
pub const INIT_MSG: u32 = 0x13333337;
Expand All @@ -25,7 +28,13 @@ pub const ERR_MSG: u32 = 0x37333337;
// magic byte + pq + metric_kind + quantization + dim + m + efc + ef + num_centroids +
// num_subvectors + capacity
static INDEX_HEADER_LENGTH: usize = INTEGER_SIZE * 11;
type Row = (u64, Vec<f32>);

/// Vector payload of a streamed row, tagged by the element type it was decoded into.
#[derive(Debug, Clone, PartialEq)]
enum VectorType {
    /// Components decoded as 32-bit floats.
    F32(Vec<f32>),
    /// Components stored as signed bytes; `parse_tuple` produces this variant for
    /// 1-bit elements, emitting one `i8` per input bit.
    I8(Vec<i8>),
}

type Row = (u64, VectorType);

struct ThreadSafeIndex(Index);

Expand Down Expand Up @@ -120,17 +129,30 @@ fn bytes_to_f32_vec_le(bytes: &[u8]) -> Vec<f32> {
float_vec
}

fn parse_tuple(buf: &[u8]) -> Result<Row, anyhow::Error> {
/// Parses one wire-format row: a `LABEL_SIZE`-byte little-endian label followed by
/// the vector payload.
///
/// `element_bits` selects how the payload is decoded:
/// * `1` - every payload byte is expanded, least-significant bit first, into eight
///   `i8` elements and returned as `VectorType::I8`;
/// * anything else - the payload is decoded by `bytes_to_f32_vec_le` and returned as
///   `VectorType::F32`.
///
/// # Errors
/// Returns an error when `buf` is shorter than `LABEL_SIZE` bytes.
fn parse_tuple(buf: &[u8], element_bits: usize) -> Result<Row, anyhow::Error> {
    // Reject short frames explicitly instead of panicking on the slice below.
    if buf.len() < LABEL_SIZE {
        anyhow::bail!(
            "tuple buffer too short: expected at least {} bytes, got {}",
            LABEL_SIZE,
            buf.len()
        );
    }
    let (label_bytes, payload) = buf.split_at(LABEL_SIZE);
    let label = u64::from_le_bytes(label_bytes.try_into()?);

    // Decode the payload exactly once, according to the element width.
    let vec: VectorType = match element_bits {
        1 => VectorType::I8(
            payload
                .iter()
                .map(|e| {
                    // NOTE(review): a set bit becomes 0 and a clear bit becomes 1.
                    // Complementing every vector leaves Hamming distances unchanged,
                    // but confirm the inversion is intended.
                    BitSlice::<_, Lsb0>::from_element(e)
                        .iter()
                        .map(|n| if *n.as_ref() { 0 } else { 1 })
                        .collect::<Vec<i8>>()
                })
                .concat(),
        ),
        _ => VectorType::F32(bytes_to_f32_vec_le(payload)),
    };

    Ok((label, vec))
}

fn initialize_index(
logger: Arc<Logger>,
stream: Arc<Mutex<TcpStream>>,
) -> Result<ThreadSafeIndex, anyhow::Error> {
) -> Result<(usize, ThreadSafeIndex), anyhow::Error> {
let mut buf = vec![0 as u8; INDEX_HEADER_LENGTH];
let mut soc_stream = stream.lock().unwrap();
match read_frame(
Expand Down Expand Up @@ -165,7 +187,13 @@ fn initialize_index(
let mut soc_stream = stream.lock().unwrap();
// send success code
soc_stream.write(&[0]).unwrap();
Ok(ThreadSafeIndex(index))

let element_bits = match index_options.metric {
MetricKind::Hamming => 1,
_ => INTEGER_SIZE * CHAR_BITS,
};

Ok((element_bits, ThreadSafeIndex(index)))
}
_ => anyhow::bail!("send init message first"),
}
Expand All @@ -176,18 +204,28 @@ fn receive_rows(
logger: Arc<Logger>,
index: Arc<RwLock<ThreadSafeIndex>>,
worker_tx: SyncSender<Row>,
element_bits: usize,
) -> AnyhowVoidResult {
let mut current_capacity = index.read().unwrap().0.capacity();
let idx = index.read().unwrap();
let mut current_capacity = idx.0.capacity();
let mut stream = stream.lock().unwrap();
let mut received_rows = 0;
let expected_payload_size = LABEL_SIZE + INTEGER_SIZE * index.read().unwrap().0.dimensions();

let expected_payload_size = if element_bits < CHAR_BITS {
LABEL_SIZE + idx.0.dimensions().div_ceil(CHAR_BITS)
} else {
LABEL_SIZE + idx.0.dimensions() * (element_bits / CHAR_BITS)
};

let mut buf = vec![0 as u8; expected_payload_size];

drop(idx);

loop {
match read_frame(&mut stream, &mut buf, expected_payload_size, None)? {
ProtocolMessage::Exit => break,
ProtocolMessage::Data(buf) => {
let row = parse_tuple(&buf)?;
let row = parse_tuple(&buf, element_bits)?;

if received_rows == current_capacity {
current_capacity *= 2;
Expand Down Expand Up @@ -256,10 +294,8 @@ pub fn create_streaming_usearch_index(
let start_time = Instant::now();
let num_cores: usize = std::thread::available_parallelism().unwrap().into();
logger.info(&format!("Number of available CPU cores: {}", num_cores));
let index = Arc::new(RwLock::new(initialize_index(
logger.clone(),
stream.clone(),
)?));
let (element_bits, index) = initialize_index(logger.clone(), stream.clone())?;
let index = Arc::new(RwLock::new(index));

// Create a vector to store thread handles
let mut handles = vec![];
Expand Down Expand Up @@ -288,7 +324,10 @@ pub fn create_streaming_usearch_index(

if let Ok(row) = row_result {
let index = index_ref.read().unwrap();
index.0.add(row.0, &row.1)?;
match row.1 {
VectorType::F32(vec) => index.0.add(row.0, &vec)?,
VectorType::I8(vec) => index.0.add(row.0, &vec)?,
}
} else {
// channel has been closed
break;
Expand All @@ -300,7 +339,13 @@ pub fn create_streaming_usearch_index(
handles.push(handle);
}

receive_rows(stream.clone(), logger.clone(), index.clone(), tx)?;
receive_rows(
stream.clone(),
logger.clone(),
index.clone(),
tx,
element_bits,
)?;

// Wait for all threads to finish processing
for handle in handles {
Expand Down
107 changes: 107 additions & 0 deletions lantern_cli/tests/external_index_server_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,113 @@ async fn test_external_index_server_indexing_scalar_quantization() {
assert_eq!(index.size(), received_index.size());
}

/// Streams 14 rows to the external index server using the Hamming metric and checks
/// that the index sent back holds as many vectors as a locally built reference index.
///
/// Each row carries three 4-byte values (96 bits), which matches the 96 single-bit
/// dimensions configured in `index_options`.
#[tokio::test]
async fn test_external_index_server_indexing_hamming_distance() {
    initialize();
    let pq_codebook: *const f32 = std::ptr::null();
    let index_options = IndexOptions {
        dimensions: 3 * 32,
        metric: UMetricKind::from_u32(8).unwrap().value(),
        quantization: ScalarKind::B1,
        multi: false,
        connectivity: 12,
        expansion_add: 64,
        expansion_search: 32,
        num_threads: 0, // automatic
        pq_construction: false,
        pq_output: false,
        num_centroids: 0,
        num_subvectors: 0,
        codebook: pq_codebook,
    };

    let tuples = vec![
        (0, vec![0.0, 0.0, 0.0]),
        (1, vec![0.0, 0.0, 1.0]),
        (2, vec![0.0, 0.0, 2.0]),
        (3, vec![0.0, 0.0, 3.0]),
        (4, vec![0.0, 1.0, 0.0]),
        (5, vec![0.0, 1.0, 1.0]),
        (6, vec![0.0, 1.0, 2.0]),
        (7, vec![0.0, 1.0, 3.0]),
        (8, vec![1.0, 0.0, 0.0]),
        (9, vec![1.0, 0.0, 1.0]),
        (10, vec![1.0, 0.0, 2.0]),
        (11, vec![1.0, 0.0, 3.0]),
        (12, vec![1.0, 1.0, 0.0]),
        (13, vec![1.0, 1.0, 1.0]),
    ];

    // NOTE(review): presumably `initialize()` starts the server on this port - confirm.
    let mut stream = TcpStream::connect("127.0.0.1:8998").unwrap();

    // Init frame: eleven little-endian u32 fields.
    let init_msg = [
        INIT_MSG.to_le_bytes(),
        (0 as u32).to_le_bytes(), // pq
        (8 as u32).to_le_bytes(), // metric kind (same value passed to UMetricKind above)
        (5 as u32).to_le_bytes(), // quantization - presumably B1; TODO confirm
        (index_options.dimensions as u32).to_le_bytes(),
        (index_options.connectivity as u32).to_le_bytes(),
        (index_options.expansion_add as u32).to_le_bytes(),
        (index_options.expansion_search as u32).to_le_bytes(),
        (index_options.num_centroids as u32).to_le_bytes(),
        (index_options.num_subvectors as u32).to_le_bytes(),
        (tuples.len() as u32).to_le_bytes(),
    ]
    .concat();

    // `write_all` retries on partial writes instead of asserting on the byte count.
    stream.write_all(&init_msg).unwrap();

    // The server acknowledges a successful init with a single zero byte.
    let mut buf: [u8; 1] = [1; 1];
    stream.read_exact(&mut buf).unwrap();
    assert_eq!(buf[0], 0);

    let index = Index::new(&index_options).unwrap();
    index.reserve(tuples.len()).unwrap();
    for tuple in &tuples {
        // Build the local reference index alongside the streamed one.
        index.add(tuple.0 as u64, &*tuple.1).unwrap();

        let byte_count = tuple.1.len() * std::mem::size_of::<f32>();
        // SAFETY: `tuple.1` is a live `Vec` that is not mutated while this slice is
        // alive, its elements are expected to be `f32`, so `byte_count` initialized
        // bytes are in bounds of its allocation; `u8` has no alignment requirement.
        let vector_bytes =
            unsafe { std::slice::from_raw_parts(tuple.1.as_ptr() as *const u8, byte_count) };

        // Wire format of a row: 8-byte little-endian label followed by the raw vector bytes.
        let label = (tuple.0 as u64).to_le_bytes();
        let tuple_buf = [label.as_slice(), vector_bytes].concat();
        stream.write_all(&tuple_buf).unwrap();
    }

    // Signal the end of the row stream.
    stream.write_all(&END_MSG.to_le_bytes()).unwrap();

    // NOTE(review): this path carries no test-specific suffix; if other tests in this
    // file save to the same location they can race when run in parallel - confirm.
    let index_file_name = "/tmp/test_external_index_server_indexing.usearch";
    let index_file_path = Path::new(&index_file_name);
    index.save(index_file_name).unwrap();
    let expected_index_buffer = fs::read(index_file_path).unwrap();
    assert!(!expected_index_buffer.is_empty());

    // Receive the number of tuples the server added.
    let mut uint64_buf = [0; 8];
    stream.read_exact(&mut uint64_buf).unwrap();
    assert_eq!(u64::from_le_bytes(uint64_buf), tuples.len() as u64);

    // Receive the serialized index: its size first, then the bytes.
    stream.read_exact(&mut uint64_buf).unwrap();
    let received_index_size = u64::from_le_bytes(uint64_buf);
    assert!(received_index_size > 0);

    let mut received_index_buffer = vec![0; received_index_size as usize];
    stream.read_exact(&mut received_index_buffer).unwrap();

    let received_index = Index::new(&index_options).unwrap();
    received_index.reserve(tuples.len()).unwrap();
    Index::load_from_buffer(&received_index, &received_index_buffer).unwrap();

    assert_eq!(index.size(), received_index.size());
}

#[tokio::test]
async fn test_external_index_server_indexing_pq() {
initialize();
Expand Down

0 comments on commit ef0f94b

Please sign in to comment.