diff --git a/lantern_cli/src/external_index/server.rs b/lantern_cli/src/external_index/server.rs index d4cc614..569a9ce 100644 --- a/lantern_cli/src/external_index/server.rs +++ b/lantern_cli/src/external_index/server.rs @@ -9,7 +9,7 @@ use rustls::{ServerConfig, ServerConnection, StreamOwned}; use std::cmp; use std::fs::{self, File}; use std::io::{BufReader, Read, Write}; -use std::net::{TcpListener, TcpStream}; +use std::net::{Shutdown, TcpListener, TcpStream}; use std::os::unix::fs::MetadataExt; use std::path::Path; use std::sync::mpsc::{self, Receiver, SyncSender}; @@ -26,6 +26,7 @@ const LABEL_SIZE: usize = 8; const INTEGER_SIZE: usize = 4; const SOCKET_TIMEOUT: u64 = 5; pub const PROTOCOL_HEADER_SIZE: usize = 4; +pub const PROTOCOL_VERSION: u32 = 1; pub const SERVER_TYPE: u32 = 0x1; // (0x1: indexing server, 0x2: router server) pub const INIT_MSG: u32 = 0x13333337; pub const END_MSG: u32 = 0x31333337; @@ -193,6 +194,7 @@ fn initialize_index( ) -> Result<(usize, ThreadSafeIndex), anyhow::Error> { let buf = vec![0 as u8; INDEX_HEADER_LENGTH]; let mut soc_stream = stream.lock().unwrap(); + soc_stream.write_data(&PROTOCOL_VERSION.to_le_bytes())?; soc_stream.write_data(&SERVER_TYPE.to_le_bytes())?; match read_frame(&mut soc_stream, buf, INDEX_HEADER_LENGTH, Some(INIT_MSG))? { ProtocolMessage::Init(buf) => { @@ -627,7 +629,8 @@ fn start_status_server( stream.write("Content-Type: application/json\r\n".as_bytes())?; stream.write(format!("Content-Length: {response_len}\r\n\r\n").as_bytes())?; stream.write(response_bytes)?; - stream.write(&[0x0D, 0x0A])?; + stream.write(&[0x0D, 0x0A])?; // \r\n + stream.shutdown(Shutdown::Both)?; } Err(e) => { logger.error(&format!("Connection error: {e}")); diff --git a/lantern_cli/tests/external_index_server_test.rs b/lantern_cli/tests/external_index_server_test.rs index 3d37d39..6e6b2e6 100644 --- a/lantern_cli/tests/external_index_server_test.rs +++ b/lantern_cli/tests/external_index_server_test.rs @@ -1,6 +1,8 @@ use isahc::{ReadResponseExt, Request}; use lantern_cli::external_index::cli::UMetricKind; -use lantern_cli::external_index::server::{END_MSG, ERR_MSG, INIT_MSG, PROTOCOL_HEADER_SIZE}; +use lantern_cli::external_index::server::{ + END_MSG, ERR_MSG, INIT_MSG, PROTOCOL_HEADER_SIZE, PROTOCOL_VERSION, +}; use lantern_cli::external_index::{self, cli::IndexServerArgs}; use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; use rustls::pki_types::{CertificateDer, UnixTime}; @@ -142,6 +144,8 @@ async fn test_external_index_server_invalid_header() { let mut stream = TcpStream::connect("127.0.0.1:8998").unwrap(); let mut uint32_buf = [0; 4]; stream.read_exact(&mut uint32_buf).unwrap(); + assert_eq!(u32::from_le_bytes(uint32_buf), PROTOCOL_VERSION); + stream.read_exact(&mut uint32_buf).unwrap(); assert_eq!(u32::from_le_bytes(uint32_buf), 0x1); let bytes_written = stream.write(&[0, 1, 1, 1, 1, 1]).unwrap(); assert_eq!(bytes_written, 6); @@ -168,6 +172,8 @@ async fn test_external_index_server_short_message() { let mut stream = TcpStream::connect("127.0.0.1:8998").unwrap(); let mut uint32_buf = [0; 4]; stream.read_exact(&mut uint32_buf).unwrap(); + assert_eq!(u32::from_le_bytes(uint32_buf), PROTOCOL_VERSION); + stream.read_exact(&mut uint32_buf).unwrap(); assert_eq!(u32::from_le_bytes(uint32_buf), 0x1); let bytes_written = stream.write(&[0, 1]).unwrap(); assert_eq!(bytes_written, 2); @@ -228,6 +234,8 @@ async fn test_external_index_server_indexing() { let mut stream = TcpStream::connect("127.0.0.1:8998").unwrap(); let mut uint32_buf = [0; 4]; stream.read_exact(&mut uint32_buf).unwrap(); + assert_eq!(u32::from_le_bytes(uint32_buf), PROTOCOL_VERSION); + stream.read_exact(&mut uint32_buf).unwrap(); assert_eq!(u32::from_le_bytes(uint32_buf), 0x1); let init_msg = [ @@ -374,6 +382,8 @@ async fn test_external_index_server_indexing_ssl() { let mut stream = rustls::Stream::new(&mut conn, &mut sock); let mut uint32_buf = [0; 4]; stream.read_exact(&mut uint32_buf).unwrap(); + assert_eq!(u32::from_le_bytes(uint32_buf), PROTOCOL_VERSION); + stream.read_exact(&mut uint32_buf).unwrap(); assert_eq!(u32::from_le_bytes(uint32_buf), 0x1); let init_msg = [ @@ -485,6 +495,8 @@ async fn test_external_index_server_indexing_scalar_quantization() { let mut stream = TcpStream::connect("127.0.0.1:8998").unwrap(); let mut uint32_buf = [0; 4]; stream.read_exact(&mut uint32_buf).unwrap(); + assert_eq!(u32::from_le_bytes(uint32_buf), PROTOCOL_VERSION); + stream.read_exact(&mut uint32_buf).unwrap(); assert_eq!(u32::from_le_bytes(uint32_buf), 0x1); let init_msg = [ INIT_MSG.to_le_bytes(), @@ -595,6 +607,8 @@ async fn test_external_index_server_indexing_hamming_distance() { let mut stream = TcpStream::connect("127.0.0.1:8998").unwrap(); let mut uint32_buf = [0; 4]; stream.read_exact(&mut uint32_buf).unwrap(); + assert_eq!(u32::from_le_bytes(uint32_buf), PROTOCOL_VERSION); + stream.read_exact(&mut uint32_buf).unwrap(); assert_eq!(u32::from_le_bytes(uint32_buf), 0x1); let init_msg = [ INIT_MSG.to_le_bytes(), @@ -712,6 +726,8 @@ async fn test_external_index_server_indexing_pq() { let mut stream = TcpStream::connect("127.0.0.1:8998").unwrap(); let mut uint32_buf = [0; 4]; stream.read_exact(&mut uint32_buf).unwrap(); + assert_eq!(u32::from_le_bytes(uint32_buf), PROTOCOL_VERSION); + stream.read_exact(&mut uint32_buf).unwrap(); assert_eq!(u32::from_le_bytes(uint32_buf), 0x1); let init_msg = [ INIT_MSG.to_le_bytes(),