diff --git a/lantern_daemon/src/embedding_jobs.rs b/lantern_daemon/src/embedding_jobs.rs index 606a0b1..7a4b218 100644 --- a/lantern_daemon/src/embedding_jobs.rs +++ b/lantern_daemon/src/embedding_jobs.rs @@ -33,6 +33,7 @@ use itertools::Itertools; use lantern_embeddings::cli::EmbeddingArgs; use lantern_logger::Logger; use lantern_utils::get_full_table_name; +use tokio_postgres::types::ToSql; use std::collections::HashMap; use std::path::Path; use std::process; @@ -88,17 +89,27 @@ async fn unlock_rows( job_id: i32, row_ids: &Vec, ) { - let row_ids_str = row_ids.iter().map(|r| format!("'{r}'")).join(","); + let mut row_ids_query = "".to_owned(); + let mut params: Vec<&(dyn ToSql + Sync)> = row_ids + .iter() + .enumerate() + .map(|(idx, id)| { + let comma = if idx < row_ids.len() - 1 { "," } else { "" }; + row_ids_query = format!("{row_ids_query}${}{comma}", idx+1); + id as &(dyn ToSql + Sync) + }) + .collect(); + params.push(&job_id); let res = client .execute( - &format!("DELETE FROM {lock_table_name} WHERE job_id=$1 AND row_id IN ($2)"), - &[&job_id, &row_ids_str], + &format!("DELETE FROM {lock_table_name} WHERE job_id=${job_id_pos} AND row_id IN ({row_ids_query})", job_id_pos=params.len()), + ¶ms, ) .await; if let Err(e) = res { logger.error(&format!( - "Error while unlocking rows: {row_ids_str} for job: {job_id} : {e}" + "Error while unlocking rows: {:?} for job: {job_id} : {e}",row_ids )); } } diff --git a/lantern_daemon/tests/daemon_test_with_db.rs b/lantern_daemon/tests/daemon_test_with_db.rs index 225351c..bda14ba 100644 --- a/lantern_daemon/tests/daemon_test_with_db.rs +++ b/lantern_daemon/tests/daemon_test_with_db.rs @@ -10,6 +10,7 @@ use lantern_daemon::{ }; use tokio_postgres::{Client, NoTls}; +static EMB_LOCK_TABLE_NAME: &'static str = "_lantern_emb_job_locks"; static EMBEDDING_JOBS_TABLE_NAME: &'static str = "_lantern_daemon_embedding_jobs"; static AUTOTUNE_JOBS_TABLE_NAME: &'static str = "_lantern_daemon_autotune_jobs"; static INDEX_JOBS_TABLE_NAME: &'static str = "_lantern_daemon_index_jobs"; @@ -396,6 +397,16 @@ async fn test_embedding_generation_runtime( .unwrap(); let updated_embedding = updated_embedding.get::>(0); assert_ne!(old_embedding, updated_embedding); + + // Check that all row locks are removed + let locks = db_client + .query( + &format!("SELECT * FROM lantern_test.{EMB_LOCK_TABLE_NAME}"), + &[], + ) + .await + .unwrap(); + assert_eq!(locks.len(), 0); stop_tx.send(()).unwrap(); }