Skip to content

Commit

Permalink
Send index definitions on collection response
Browse files Browse the repository at this point in the history
  • Loading branch information
var77 committed Mar 7, 2024
1 parent f62b863 commit c52d4af
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 12 deletions.
115 changes: 106 additions & 9 deletions lantern_cli/src/http_server/collection.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,83 @@
use bytes::BytesMut;
use futures::SinkExt;
use itertools::Itertools;
use regex::Regex;
use std::collections::HashMap;

use actix_web::{
delete,
error::{ErrorBadRequest, ErrorInternalServerError},
error::{ErrorBadRequest, ErrorInternalServerError, ErrorNotFound},
get,
http::StatusCode,
post, put, web, HttpResponse, Responder, Result,
};

use crate::utils::quote_ident;
use crate::{external_index::cli::UMetricKind, utils::quote_ident};

use super::{AppState, COLLECTION_TABLE_NAME};
use serde::{Deserialize, Serialize};

/// Extracts index parameters from a Postgres index definition string.
///
/// The returned map contains:
/// * one entry for each of `ef_construction`, `ef`, `m`, `dim` and `pq` that appears in
///   `definition` as `key='value'`, holding the quoted value;
/// * `external` = `"true"` when an `_experimental_index_path='...'` option is present
///   (the path itself is not returned);
/// * `metric`, which is always set. It defaults to `"l2sq"` and is replaced by the metric
///   that `UMetricKind::from_ops` resolves for the operator class found inside
///   `hnsw (...)`.
///
/// A captured token that `UMetricKind::from_ops` rejects leaves the default metric in
/// place instead of panicking: the operator-class pattern is loose and may capture
/// something that is not an operator class at all.
fn parse_index_def(definition: &str) -> HashMap<String, String> {
    const PARAM_KEYS: [&str; 6] = [
        "ef_construction",
        "ef",
        "m",
        "dim",
        "pq",
        "_experimental_index_path",
    ];

    let mut result = HashMap::new();
    for &key in PARAM_KEYS.iter() {
        // The key has to follow `(`, `,` or whitespace and be directly followed by `='`,
        // so `ef` cannot match inside `ef_construction`.
        let regex = Regex::new(&format!(r"[\(,\s]{}='(.*?)'", key))
            .expect("hard-coded parameter pattern is a valid regex");
        if let Some(captures) = regex.captures(definition) {
            if key == "_experimental_index_path" {
                result.insert("external".to_string(), "true".to_string());
            } else {
                result.insert(key.to_string(), captures[1].to_string());
            }
        }
    }

    // The last word before the closing parenthesis of `hnsw (...)` is treated as the
    // operator class; anything unrecognised falls back to the default metric.
    let metric = Regex::new(r"hnsw\s*\(.*?\s+(\w+)\s*\)")
        .expect("hard-coded operator class pattern is a valid regex")
        .captures(definition)
        .and_then(|captures| {
            let operator_class = captures[1].to_string();
            UMetricKind::from_ops(&operator_class).ok()
        })
        .map(|umetric| umetric.to_string())
        .unwrap_or_else(|| "l2sq".to_string());

    result.insert("metric".to_string(), metric);

    result
}

/// Converts raw index rows into parsed index descriptions.
///
/// Every input map must carry a `definition` and a `name` entry (the function panics
/// otherwise). The definition is run through `parse_index_def` and the index name is
/// added to the resulting map under `name`.
fn parse_indexes(index_definitions: Vec<HashMap<String, String>>) -> Vec<HashMap<String, String>> {
    index_definitions
        .iter()
        .map(|raw| {
            let mut parsed = parse_index_def(raw.get("definition").unwrap());
            let index_name = raw.get("name").unwrap().to_string();
            parsed.insert("name".to_owned(), index_name);
            parsed
        })
        .collect()
}

/// Builds the SQL statement that lists collections with their column schema and indexes.
///
/// `filter` is spliced verbatim into the inner query, right after the join on
/// `information_schema.columns` (pass `""` for no filter). Each result row holds the
/// collection name, a JSON object text mapping column names to data types, and a JSON
/// array text of `{name, definition}` objects for the indexes whose definition matches
/// `USING lantern_hnsw` (`'[]'` when there are none).
fn get_collection_query(filter: &str) -> String {
    format!(
        "SELECT b.name, b.schema, \
         COALESCE(json_agg(json_build_object('name', i.indexname , 'definition', i.indexdef)) \
         FILTER (WHERE i.indexname IS NOT NULL), '[]')::text as indexes \
         FROM (SELECT c.name, json_object_agg(t.column_name, t.data_type)::text as schema \
         FROM {table} c \
         INNER JOIN information_schema.columns t ON t.table_name=c.name {filter} \
         GROUP BY c.name) b \
         LEFT JOIN pg_indexes i ON i.tablename=b.name AND i.indexdef ILIKE '%USING lantern_hnsw%' \
         GROUP BY b.name, b.schema",
        table = COLLECTION_TABLE_NAME,
        filter = filter,
    )
}

/// Description of a collection as returned by the collection endpoints.
#[derive(Serialize, Debug, utoipa::ToSchema)]
pub struct CollectionInfo {
/// Name of the collection.
name: String,
/// Column name to data type mapping of the collection's table.
schema: HashMap<String, String>,
/// One map per index, as produced by `parse_indexes`; empty when there are none.
indexes: Vec<HashMap<String, String>>,
}
/// Get all collections
#[utoipa::path(
Expand All @@ -33,24 +92,60 @@ pub struct CollectionInfo {
pub async fn list(data: web::Data<AppState>) -> Result<impl Responder> {
let client = data.pool.get().await?;
let rows = client
.query(&format!("SELECT name FROM {COLLECTION_TABLE_NAME}"), &[])
.await;

let rows = match rows {
Ok(rows) => rows,
Err(e) => return Err(ErrorInternalServerError(e)),
};
.query(&get_collection_query(""), &[])
.await
.map_err(ErrorInternalServerError)?;

let tables: Vec<CollectionInfo> = rows
.iter()
.map(|r| CollectionInfo {
name: r.get::<usize, String>(0),
schema: serde_json::from_str(r.get::<usize, &str>(1)).unwrap(),
indexes: parse_indexes(serde_json::from_str(r.get::<usize, &str>(2)).unwrap()),
})
.collect();

Ok(web::Json(tables))
}

/// Get collection by name
#[utoipa::path(
get,
path = "/collections/{name}",
responses(
(status = 200, description = "Returns the collection data", body = CollectionInfo),
(status = 500, description = "Internal Server Error")
),
params(
("name", description = "Collection name")
),
)]
#[get("/collections/{name}")]
pub async fn get(data: web::Data<AppState>, name: web::Path<String>) -> Result<impl Responder> {
let client = data.pool.get().await?;
let rows = client
.query(
&get_collection_query("WHERE c.name=$1"),
&[&name.to_string()],
)
.await
.map_err(ErrorInternalServerError)?;

if rows.is_empty() {
return Err(ErrorNotFound("Collection not found"));
}

let first_row = rows.first().unwrap();

let table: CollectionInfo = CollectionInfo {
name: first_row.get::<usize, String>(0),
schema: serde_json::from_str(first_row.get::<usize, &str>(1)).unwrap(),
indexes: parse_indexes(serde_json::from_str(first_row.get::<usize, &str>(2)).unwrap()),
};

Ok(web::Json(table))
}

#[derive(Deserialize, Debug, Clone, utoipa::ToSchema)]
pub struct CreateTableInput {
name: String,
Expand Down Expand Up @@ -143,6 +238,8 @@ pub async fn create(

Ok(web::Json(CollectionInfo {
name: body.name.clone(),
schema,
indexes: Vec::new(),
}))
}

Expand Down
4 changes: 3 additions & 1 deletion lantern_cli/src/http_server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ pub struct AppState {
The API endpoints are not SQL injection safe, so it can provide maximum flexibility for data manipulation, so please sanitize user input before sending requests to this API."
),
paths(
collection::list,
collection::create,
collection::list,
collection::get,
collection::delete,
collection::insert_data,
search::vector_search,
Expand Down Expand Up @@ -118,6 +119,7 @@ pub async fn start(
.url("/api-docs/openapi.json", ApiDoc::openapi()),
)
.service(collection::list)
.service(collection::get)
.service(collection::create)
.service(collection::delete)
.service(collection::insert_data)
Expand Down
44 changes: 42 additions & 2 deletions lantern_cli/tests/http_server_test_with_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ async fn test_collection_create() -> AnyhowVoidResult {
let mut body: Vec<u8> = Vec::with_capacity(body.capacity());
response.copy_to(&mut body)?;
let body_json = String::from_utf8(body)?;
let body_json: HashMap<String, String> = serde_json::from_str(&body_json)?;
println!("Response: {:?}", body_json);
let body_json: HashMap<String, serde_json::Value> = serde_json::from_str(&body_json)?;
assert_eq!(body_json.get("name").unwrap(), TEST_COLLECTION_NAME);

Ok(())
Expand Down Expand Up @@ -128,7 +129,8 @@ async fn test_collection_list() -> AnyhowVoidResult {
let mut body: Vec<u8> = Vec::new();
response.copy_to(&mut body)?;
let body_json = String::from_utf8(body)?;
let body_json: Vec<HashMap<String, String>> = serde_json::from_str(&body_json)?;
println!("Response: {:?}", body_json);
let body_json: Vec<HashMap<String, serde_json::Value>> = serde_json::from_str(&body_json)?;

assert_eq!(body_json.len(), 1);
assert_eq!(
Expand All @@ -139,6 +141,26 @@ async fn test_collection_list() -> AnyhowVoidResult {
Ok(())
}

async fn test_collection_get() -> AnyhowVoidResult {
let mut response = isahc::get(&format!("{SERVER_URL}/collections/{TEST_COLLECTION_NAME}"))?;

let mut body: Vec<u8> = Vec::new();
response.copy_to(&mut body)?;
let body_json = String::from_utf8(body)?;
println!("Response: {:?}", body_json);
let body_json: HashMap<String, serde_json::Value> = serde_json::from_str(&body_json)?;

assert_eq!(body_json.get("name").unwrap(), TEST_COLLECTION_NAME);
assert_eq!(body_json.get("schema").unwrap().get("v").unwrap(), "ARRAY");
assert_eq!(
body_json.get("schema").unwrap().get("id").unwrap(),
"integer"
);
assert_eq!(body_json.get("schema").unwrap().get("m").unwrap(), "jsonb");

Ok(())
}

async fn test_collection_delete() -> AnyhowVoidResult {
isahc::delete(&format!("{SERVER_URL}/collections/{TEST_COLLECTION_NAME}"))?;

Expand Down Expand Up @@ -195,6 +217,23 @@ async fn test_index_create() -> AnyhowVoidResult {
let response = isahc::send(request)?;
assert_eq!(response.status(), StatusCode::from_u16(200)?);

let mut response = isahc::get(&format!("{SERVER_URL}/collections/{TEST_COLLECTION_NAME}"))?;

let mut body: Vec<u8> = Vec::new();
response.copy_to(&mut body)?;
let body_json = String::from_utf8(body)?;
println!("Response: {:?}", body_json);
let body_json: HashMap<String, serde_json::Value> = serde_json::from_str(&body_json)?;
let indexes: Vec<HashMap<String, String>> =
serde_json::from_value(body_json.get("indexes").unwrap().clone())?;

assert_eq!(indexes.len(), 1);
assert_eq!(indexes[0].get("name").unwrap(), "test_idx");
assert_eq!(indexes[0].get("m").unwrap(), "16");
assert_eq!(indexes[0].get("ef_construction").unwrap(), "128");
assert_eq!(indexes[0].get("ef").unwrap(), "64");
assert_eq!(indexes[0].get("dim").unwrap(), "3");
assert_eq!(indexes[0].get("metric").unwrap(), "cos");
Ok(())
}

Expand Down Expand Up @@ -356,6 +395,7 @@ async fn test_http_server() {
let tx = test_setup().await;
test_collection_create().await.unwrap();
test_collection_list().await.unwrap();
test_collection_get().await.unwrap();
test_collection_insert().await.unwrap();
test_pq().await.unwrap();
test_index_create().await.unwrap();
Expand Down

0 comments on commit c52d4af

Please sign in to comment.