Add new splade-v3 model, change RwLock to Mutex to avoid race conditions
var77 committed Jun 12, 2024
1 parent d2b6559 commit 8eee625
Showing 4 changed files with 105 additions and 40 deletions.
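The lock change addresses a time-of-check/time-of-use hazard in the old code: a model's state was inspected under a read lock, the lock was released, and a write lock was then taken to mutate the map, so two callers could both pass the check and initialize the same model. Below is a minimal standalone sketch of that racy shape (invented names and a String stand-in for the encoder; not code from this repository):

use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::thread;

fn main() {
    let models: Arc<RwLock<HashMap<&'static str, String>>> =
        Arc::new(RwLock::new(HashMap::new()));
    let handles: Vec<_> = (0..2)
        .map(|i| {
            let models = Arc::clone(&models);
            thread::spawn(move || {
                // Check under a read lock, which is dropped at the end of
                // this statement...
                let missing = !models.read().unwrap().contains_key("splade-v3");
                if missing {
                    // ...so both threads can reach this point and "load" the
                    // model twice, clobbering each other's work.
                    models
                        .write()
                        .unwrap()
                        .insert("splade-v3", format!("encoder built by thread {i}"));
                }
            })
        })
        .collect();
    for h in handles {
        h.join().unwrap();
    }
}

Holding a single Mutex guard across the whole check-then-initialize sequence, as this commit does, closes that window.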
2 changes: 1 addition & 1 deletion lantern_cli/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "lantern_cli"
-version = "0.3.1"
+version = "0.3.2"
 edition = "2021"
 
 [[bin]]
81 changes: 63 additions & 18 deletions lantern_cli/src/embeddings/core/ort_runtime.rs
@@ -12,7 +12,7 @@ use std::{
     collections::HashMap,
     io::Cursor,
     path::{Path, PathBuf},
-    sync::{Arc, Mutex, RwLock},
+    sync::{Arc, Mutex},
     time::Duration,
 };
 use sysinfo::{System, SystemExt};
@@ -30,9 +30,53 @@ type SessionInput<'a> = ArrayBase<CowRepr<'a, i64>, Dim<IxDynImpl>>;
 pub enum PoolingStrategy {
     CLS,
     Mean,
+    ReluLogMaxPooling,
 }
 
 impl PoolingStrategy {
+    fn relu_log_max_pooling(
+        embeddings: ViewHolder<'_, f32, Dim<IxDynImpl>>,
+        attention_mask: &SessionInput,
+        output_dims: usize,
+    ) -> Vec<Vec<f32>> {
+        // Apply ReLU: max(0, x)
+        let relu_embeddings = embeddings.mapv(|x| x.max(0.0));
+
+        // Apply log(1 + x)
+        let relu_log_embeddings = relu_embeddings.mapv(|x| (1.0 + x).ln());
+
+        // Expand attention mask to match embeddings dimensions
+        let attention_mask_shape = attention_mask.shape();
+        let input_mask_expanded = attention_mask.clone().insert_axis(Axis(2)).into_owned();
+        let input_mask_expanded = input_mask_expanded
+            .broadcast((
+                attention_mask_shape[0],
+                attention_mask_shape[1],
+                output_dims,
+            ))
+            .unwrap()
+            .to_owned();
+        let input_mask_expanded = input_mask_expanded.mapv(|v| v as f32);
+
+        // Apply attention mask
+        let relu_log_embeddings = relu_log_embeddings.to_owned();
+        let masked_embeddings = &relu_log_embeddings * &input_mask_expanded;
+
+        // Find the maximum value across the sequence dimension (Axis 1)
+        let max_embeddings = masked_embeddings.map_axis(Axis(1), |view| {
+            view.fold(f32::NEG_INFINITY, |a, &b| a.max(b))
+        });
+
+        // Convert the resulting max_embeddings to a Vec<Vec<f32>>
+        max_embeddings
+            .iter()
+            .map(|s| *s)
+            .chunks(output_dims)
+            .into_iter()
+            .map(|b| b.collect())
+            .collect()
+    }
+
     fn cls_pooling(
         embeddings: ViewHolder<'_, f32, Dim<IxDynImpl>>,
         output_dims: usize,
@@ -90,6 +134,9 @@ impl PoolingStrategy {
             &PoolingStrategy::Mean => {
                 PoolingStrategy::mean_pooling(embeddings, attention_mask, output_dims)
             }
+            &PoolingStrategy::ReluLogMaxPooling => {
+                PoolingStrategy::relu_log_max_pooling(embeddings, attention_mask, output_dims)
+            }
         }
     }
 }
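The new strategy is SPLADE-style term weighting: apply log(1 + ReLU(x)) to the token/vocabulary logits, zero out padding via the attention mask, then take the per-dimension maximum over the sequence axis. A tiny standalone sketch of the same arithmetic (invented values, all-ones mask omitted for brevity; ndarray is already a dependency of this crate):

use ndarray::{array, Axis};

fn main() {
    // One sequence, 2 tokens, 3 output dims.
    let logits = array![[[1.0_f32, -2.0, 0.5], [3.0, 0.0, -1.0]]];
    // log(1 + ReLU(x)), element-wise, as in relu_log_max_pooling above.
    let activated = logits.mapv(|x| (1.0 + x.max(0.0)).ln());
    // Per-dimension maximum over the sequence axis (Axis 1).
    let pooled = activated.map_axis(Axis(1), |tok| {
        tok.fold(f32::NEG_INFINITY, |a, &b| a.max(b))
    });
    println!("{pooled:?}"); // [[1.3863, 0.0, 0.4055]] = [[ln 4, ln 1, ln 1.5]]
}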
@@ -242,7 +289,7 @@ impl ModelInfoBuilder {
 }
 
 lazy_static! {
-    static ref MODEL_INFO_MAP: RwLock<HashMap<&'static str, ModelInfo>> = RwLock::new(HashMap::from([
+    static ref MODEL_INFO_MAP: Mutex<HashMap<&'static str, ModelInfo>> = Mutex::new(HashMap::from([
         ("clip/ViT-B-32-textual", ModelInfoBuilder::new("https://huggingface.co/varik77/onnx-models/resolve/main/openai/ViT-B-32/textual").with_tokenizer(true).build()),
         ("clip/ViT-B-32-visual", ModelInfoBuilder::new("https://huggingface.co/varik77/onnx-models/resolve/main/openai/ViT-B-32/visual").with_visual(true).with_input_image_size(224).build()),
         ("BAAI/bge-small-en", ModelInfoBuilder::new("https://huggingface.co/varik77/onnx-models/resolve/main/BAAI/bge-small-en-v1.5").with_tokenizer(true).build()),
@@ -259,7 +306,7 @@ lazy_static! {
         ("transformers/multi-qa-mpnet-base-dot-v1", ModelInfoBuilder::new("https://huggingface.co/varik77/onnx-models/resolve/main/transformers/multi-qa-mpnet-base-dot-v1").with_tokenizer(true).build()),
         ("jinaai/jina-embeddings-v2-small-en", ModelInfoBuilder::new("https://huggingface.co/varik77/onnx-models/resolve/main/jinaai/jina-embeddings-v2-small-en").with_tokenizer(true).with_layer_cnt(4).with_head_cnt(4).with_head_dim(64).with_pooling_strategy(PoolingStrategy::Mean).build()),
         ("jinaai/jina-embeddings-v2-base-en", ModelInfoBuilder::new("https://huggingface.co/varik77/onnx-models/resolve/main/jinaai/jina-embeddings-v2-base-en").with_tokenizer(true).with_layer_cnt(12).with_head_cnt(12).with_head_dim(64).with_pooling_strategy(PoolingStrategy::Mean).build()),
-        ("naver/splade-v3", ModelInfoBuilder::new("https://huggingface.co/varik77/onnx-models/resolve/main/naver/splade-v3").with_tokenizer(true).build())
+        ("naver/splade-v3", ModelInfoBuilder::new("https://huggingface.co/varik77/onnx-models/resolve/main/naver/splade-v3").with_tokenizer(true).with_pooling_strategy(PoolingStrategy::ReluLogMaxPooling).build())
     ]));
 }

@@ -723,16 +770,19 @@ impl<'a> OrtRuntime<'a> {
         Ok(())
     }
 
-    fn check_and_download_files(&self, model_name: &str) -> Result<(), anyhow::Error> {
+    fn check_and_download_files(
+        &self,
+        model_name: &str,
+        mut models_map: &mut HashMap<&'static str, ModelInfo>,
+    ) -> Result<(), anyhow::Error> {
         {
-            let map = MODEL_INFO_MAP.read().unwrap();
-            let model_info = map.get(model_name);
+            let model_info = models_map.get(model_name);
 
             if model_info.is_none() {
                 anyhow::bail!(
                     "Model \"{}\" not found.\nAvailable models: {}",
                     model_name,
-                    map.keys().join(", ")
+                    models_map.keys().join(", ")
                 )
             }

@@ -744,8 +794,7 @@ impl<'a> OrtRuntime<'a> {
             }
         }
 
-        let mut map_write = MODEL_INFO_MAP.write().unwrap();
-        let model_info = map_write.get_mut(model_name).unwrap();
+        let model_info = models_map.get_mut(model_name).unwrap();
 
         let model_folder = Path::join(&Path::new(&self.data_path), model_name);
         let model_path = Path::join(&model_folder, "model.onnx");
@@ -774,9 +823,9 @@ impl<'a> OrtRuntime<'a> {
         }
 
         // Check available memory
-        self.check_available_memory(&model_path, &mut map_write)?;
+        self.check_available_memory(&model_path, &mut models_map)?;
 
-        let model_info = map_write.get_mut(model_name).unwrap();
+        let model_info = models_map.get_mut(model_name).unwrap();
         let encoder = EncoderService::new(
             &ONNX_ENV,
             model_name,
@@ -788,7 +837,6 @@ impl<'a> OrtRuntime<'a> {
         match encoder {
             Ok(enc) => model_info.encoder = Some(enc),
             Err(err) => {
-                drop(map_write);
                 anyhow::bail!(err)
             }
         }
@@ -892,13 +940,13 @@ impl<'a> EmbeddingRuntime for OrtRuntime<'a> {
         model_name: &str,
         inputs: &Vec<&str>,
     ) -> Result<EmbeddingResult, anyhow::Error> {
-        let download_result = self.check_and_download_files(model_name);
+        let mut map = MODEL_INFO_MAP.lock().unwrap();
+        let download_result = self.check_and_download_files(model_name, &mut map);
 
         if let Err(err) = download_result {
             anyhow::bail!("{:?}", err);
         }
 
-        let map = MODEL_INFO_MAP.read().unwrap();
         let model_info = map.get(model_name).unwrap();
 
         let result;
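The shape of the fix is visible in this hunk: the caller now takes MODEL_INFO_MAP's Mutex once and threads the guarded map through check_and_download_files, so the existence check, download, and encoder initialization all happen inside one critical section instead of the old read-unlock-write-lock sequence. A simplified standalone sketch of that discipline (invented names, String stand-in for the encoder; not the crate's code):

use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::thread;

fn ensure_loaded(name: &'static str, map: &mut HashMap<&'static str, Option<String>>) {
    // Check and initialize under one guard held by the caller: no other
    // thread can interleave between the check and the write.
    let slot = map.entry(name).or_insert(None);
    if slot.is_none() {
        *slot = Some(format!("encoder for {name}"));
    }
}

fn main() {
    let models = Arc::new(Mutex::new(HashMap::new()));
    let handles: Vec<_> = (0..4)
        .map(|_| {
            let models = Arc::clone(&models);
            thread::spawn(move || {
                let mut map = models.lock().unwrap();
                ensure_loaded("naver/splade-v3", &mut map);
            })
        })
        .collect();
    for h in handles {
        h.join().unwrap();
    }
    println!("{:?}", models.lock().unwrap()); // exactly one encoder entry
}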
@@ -979,10 +1027,7 @@ impl<'a> EmbeddingRuntime for OrtRuntime<'a> {
             result = model_info.encoder.as_ref().unwrap().process_text(inputs);
         }
 
-        drop(map);
-
         if !self.cache {
-            let mut map = MODEL_INFO_MAP.write().unwrap();
             let model_info = map.get_mut(model_name).unwrap();
             model_info.encoder = None;
         }
@@ -996,7 +1041,7 @@ impl<'a> EmbeddingRuntime for OrtRuntime<'a> {
     }
 
     fn get_available_models(&self) -> (String, Vec<(String, bool)>) {
-        let map = MODEL_INFO_MAP.read().unwrap();
+        let map = MODEL_INFO_MAP.lock().unwrap();
         let mut res = String::new();
         let data_path = &self.data_path;
         let mut models = Vec::with_capacity(map.len());