Skip to content

Commit

Permalink
Lantern embeddings: fix row unlock functionality and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
var77 committed Feb 6, 2024
1 parent 3095ad2 commit 82f7bbb
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
19 changes: 15 additions & 4 deletions lantern_daemon/src/embedding_jobs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use itertools::Itertools;
use lantern_embeddings::cli::EmbeddingArgs;
use lantern_logger::Logger;
use lantern_utils::get_full_table_name;
use tokio_postgres::types::ToSql;
use std::collections::HashMap;
use std::path::Path;
use std::process;
Expand Down Expand Up @@ -88,17 +89,27 @@ async fn unlock_rows(
job_id: i32,
row_ids: &Vec<String>,
) {
let row_ids_str = row_ids.iter().map(|r| format!("'{r}'")).join(",");
let mut row_ids_query = "".to_owned();
let mut params: Vec<&(dyn ToSql + Sync)> = row_ids
.iter()
.enumerate()
.map(|(idx, id)| {
let comma = if idx < row_ids.len() - 1 { "," } else { "" };
row_ids_query = format!("{row_ids_query}${}{comma}", idx+1);
id as &(dyn ToSql + Sync)
})
.collect();
params.push(&job_id);
let res = client
.execute(
&format!("DELETE FROM {lock_table_name} WHERE job_id=$1 AND row_id IN ($2)"),
&[&job_id, &row_ids_str],
&format!("DELETE FROM {lock_table_name} WHERE job_id=${job_id_pos} AND row_id IN ({row_ids_query})", job_id_pos=params.len()),
&params,
)
.await;

if let Err(e) = res {
logger.error(&format!(
"Error while unlocking rows: {row_ids_str} for job: {job_id} : {e}"
"Error while unlocking rows: {:?} for job: {job_id} : {e}",row_ids
));
}
}
Expand Down
11 changes: 11 additions & 0 deletions lantern_daemon/tests/daemon_test_with_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use lantern_daemon::{
};
use tokio_postgres::{Client, NoTls};

static EMB_LOCK_TABLE_NAME: &'static str = "_lantern_emb_job_locks";
static EMBEDDING_JOBS_TABLE_NAME: &'static str = "_lantern_daemon_embedding_jobs";
static AUTOTUNE_JOBS_TABLE_NAME: &'static str = "_lantern_daemon_autotune_jobs";
static INDEX_JOBS_TABLE_NAME: &'static str = "_lantern_daemon_index_jobs";
Expand Down Expand Up @@ -396,6 +397,16 @@ async fn test_embedding_generation_runtime(
.unwrap();
let updated_embedding = updated_embedding.get::<usize, Vec<f32>>(0);
assert_ne!(old_embedding, updated_embedding);

// Check that all row locks are removed
let locks = db_client
.query(
&format!("SELECT * FROM lantern_test.{EMB_LOCK_TABLE_NAME}"),
&[],
)
.await
.unwrap();
assert_eq!(locks.len(), 0);
stop_tx.send(()).unwrap();
}

Expand Down

0 comments on commit 82f7bbb

Please sign in to comment.