Skip to content

Commit

Permalink
Splade v3 model (#123)
Browse files Browse the repository at this point in the history
* add splade v3 model

* Add new splade-v3 model; change RwLock to Mutex to avoid race conditions

* add github badges

* update cli version in github action
  • Loading branch information
var77 authored Jun 12, 2024
1 parent 75dcd6b commit ffdea61
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 39 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/publish-cli-docker.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ on:
type: string
description: "CLI version"
required: true
default: "0.3.1"
default: "0.3.2"
IMAGE_NAME:
type: string
description: "Container image name to tag"
Expand Down
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Lantern Extras

[![build](https://github.com/lanterndata/lantern_extras/actions/workflows/build.yaml/badge.svg?branch=main)](https://github.com/lanterndata/lantern_extras/actions/workflows/build.yaml)
[![test](https://github.com/lanterndata/lantern_extras/actions/workflows/test.yaml/badge.svg?branch=main)](https://github.com/lanterndata/lantern_extras/actions/workflows/test.yaml)
[![codecov](https://codecov.io/github/lanterndata/lantern_extras/branch/main/graph/badge.svg)](https://codecov.io/github/lanterndata/lantern_extras)

This extension makes it easy to experiment with embeddings from inside a Postgres database. We use this extension along with [Lantern](https://github.com/lanterndata/lantern) to make vector operations performant. But all the helpers here are standalone and may be used without the main database.

**NOTE**: Functions defined in this extension use Postgres in ways Postgres is usually not used.
Expand Down
2 changes: 1 addition & 1 deletion lantern_cli/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "lantern_cli"
version = "0.3.1"
version = "0.3.2"
edition = "2021"

[[bin]]
Expand Down
82 changes: 64 additions & 18 deletions lantern_cli/src/embeddings/core/ort_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use std::{
collections::HashMap,
io::Cursor,
path::{Path, PathBuf},
sync::{Arc, Mutex, RwLock},
sync::{Arc, Mutex},
time::Duration,
};
use sysinfo::{System, SystemExt};
Expand All @@ -30,9 +30,53 @@ type SessionInput<'a> = ArrayBase<CowRepr<'a, i64>, Dim<IxDynImpl>>;
pub enum PoolingStrategy {
CLS,
Mean,
ReluLogMaxPooling,
}

impl PoolingStrategy {
fn relu_log_max_pooling(
embeddings: ViewHolder<'_, f32, Dim<IxDynImpl>>,
attention_mask: &SessionInput,
output_dims: usize,
) -> Vec<Vec<f32>> {
// Apply ReLU: max(0, x)
let relu_embeddings = embeddings.mapv(|x| x.max(0.0));

// Apply log(1 + x)
let relu_log_embeddings = relu_embeddings.mapv(|x| (1.0 + x).ln());

// Expand attention mask to match embeddings dimensions
let attention_mask_shape = attention_mask.shape();
let input_mask_expanded = attention_mask.clone().insert_axis(Axis(2)).into_owned();
let input_mask_expanded = input_mask_expanded
.broadcast((
attention_mask_shape[0],
attention_mask_shape[1],
output_dims,
))
.unwrap()
.to_owned();
let input_mask_expanded = input_mask_expanded.mapv(|v| v as f32);

// Apply attention mask
let relu_log_embeddings = relu_log_embeddings.to_owned();
let masked_embeddings = &relu_log_embeddings * &input_mask_expanded;

// Find the maximum value across the sequence dimension (Axis 1)
let max_embeddings = masked_embeddings.map_axis(Axis(1), |view| {
view.fold(f32::NEG_INFINITY, |a, &b| a.max(b))
});

// Convert the resulting max_embeddings to a Vec<Vec<f32>>
max_embeddings
.iter()
.map(|s| *s)
.chunks(output_dims)
.into_iter()
.map(|b| b.collect())
.collect()
}

fn cls_pooling(
embeddings: ViewHolder<'_, f32, Dim<IxDynImpl>>,
output_dims: usize,
Expand Down Expand Up @@ -90,6 +134,9 @@ impl PoolingStrategy {
&PoolingStrategy::Mean => {
PoolingStrategy::mean_pooling(embeddings, attention_mask, output_dims)
}
&PoolingStrategy::ReluLogMaxPooling => {
PoolingStrategy::relu_log_max_pooling(embeddings, attention_mask, output_dims)
}
}
}
}
Expand Down Expand Up @@ -242,7 +289,7 @@ impl ModelInfoBuilder {
}

lazy_static! {
static ref MODEL_INFO_MAP: RwLock<HashMap<&'static str, ModelInfo>> = RwLock::new(HashMap::from([
static ref MODEL_INFO_MAP: Mutex<HashMap<&'static str, ModelInfo>> = Mutex::new(HashMap::from([
("clip/ViT-B-32-textual", ModelInfoBuilder::new("https://huggingface.co/varik77/onnx-models/resolve/main/openai/ViT-B-32/textual").with_tokenizer(true).build()),
("clip/ViT-B-32-visual", ModelInfoBuilder::new("https://huggingface.co/varik77/onnx-models/resolve/main/openai/ViT-B-32/visual").with_visual(true).with_input_image_size(224).build()),
("BAAI/bge-small-en", ModelInfoBuilder::new("https://huggingface.co/varik77/onnx-models/resolve/main/BAAI/bge-small-en-v1.5").with_tokenizer(true).build()),
Expand All @@ -258,7 +305,8 @@ lazy_static! {
("microsoft/all-mpnet-base-v2", ModelInfoBuilder::new("https://huggingface.co/varik77/onnx-models/resolve/main/microsoft/all-mpnet-base-v2").with_tokenizer(true).build()),
("transformers/multi-qa-mpnet-base-dot-v1", ModelInfoBuilder::new("https://huggingface.co/varik77/onnx-models/resolve/main/transformers/multi-qa-mpnet-base-dot-v1").with_tokenizer(true).build()),
("jinaai/jina-embeddings-v2-small-en", ModelInfoBuilder::new("https://huggingface.co/varik77/onnx-models/resolve/main/jinaai/jina-embeddings-v2-small-en").with_tokenizer(true).with_layer_cnt(4).with_head_cnt(4).with_head_dim(64).with_pooling_strategy(PoolingStrategy::Mean).build()),
("jinaai/jina-embeddings-v2-base-en", ModelInfoBuilder::new("https://huggingface.co/varik77/onnx-models/resolve/main/jinaai/jina-embeddings-v2-base-en").with_tokenizer(true).with_layer_cnt(12).with_head_cnt(12).with_head_dim(64).with_pooling_strategy(PoolingStrategy::Mean).build())
("jinaai/jina-embeddings-v2-base-en", ModelInfoBuilder::new("https://huggingface.co/varik77/onnx-models/resolve/main/jinaai/jina-embeddings-v2-base-en").with_tokenizer(true).with_layer_cnt(12).with_head_cnt(12).with_head_dim(64).with_pooling_strategy(PoolingStrategy::Mean).build()),
("naver/splade-v3", ModelInfoBuilder::new("https://huggingface.co/varik77/onnx-models/resolve/main/naver/splade-v3").with_tokenizer(true).with_pooling_strategy(PoolingStrategy::ReluLogMaxPooling).build())
]));
}

Expand Down Expand Up @@ -722,16 +770,19 @@ impl<'a> OrtRuntime<'a> {
Ok(())
}

fn check_and_download_files(&self, model_name: &str) -> Result<(), anyhow::Error> {
fn check_and_download_files(
&self,
model_name: &str,
mut models_map: &mut HashMap<&'static str, ModelInfo>,
) -> Result<(), anyhow::Error> {
{
let map = MODEL_INFO_MAP.read().unwrap();
let model_info = map.get(model_name);
let model_info = models_map.get(model_name);

if model_info.is_none() {
anyhow::bail!(
"Model \"{}\" not found.\nAvailable models: {}",
model_name,
map.keys().join(", ")
models_map.keys().join(", ")
)
}

Expand All @@ -743,8 +794,7 @@ impl<'a> OrtRuntime<'a> {
}
}

let mut map_write = MODEL_INFO_MAP.write().unwrap();
let model_info = map_write.get_mut(model_name).unwrap();
let model_info = models_map.get_mut(model_name).unwrap();

let model_folder = Path::join(&Path::new(&self.data_path), model_name);
let model_path = Path::join(&model_folder, "model.onnx");
Expand Down Expand Up @@ -773,9 +823,9 @@ impl<'a> OrtRuntime<'a> {
}

// Check available memory
self.check_available_memory(&model_path, &mut map_write)?;
self.check_available_memory(&model_path, &mut models_map)?;

let model_info = map_write.get_mut(model_name).unwrap();
let model_info = models_map.get_mut(model_name).unwrap();
let encoder = EncoderService::new(
&ONNX_ENV,
model_name,
Expand All @@ -787,7 +837,6 @@ impl<'a> OrtRuntime<'a> {
match encoder {
Ok(enc) => model_info.encoder = Some(enc),
Err(err) => {
drop(map_write);
anyhow::bail!(err)
}
}
Expand Down Expand Up @@ -891,13 +940,13 @@ impl<'a> EmbeddingRuntime for OrtRuntime<'a> {
model_name: &str,
inputs: &Vec<&str>,
) -> Result<EmbeddingResult, anyhow::Error> {
let download_result = self.check_and_download_files(model_name);
let mut map = MODEL_INFO_MAP.lock().unwrap();
let download_result = self.check_and_download_files(model_name, &mut map);

if let Err(err) = download_result {
anyhow::bail!("{:?}", err);
}

let map = MODEL_INFO_MAP.read().unwrap();
let model_info = map.get(model_name).unwrap();

let result;
Expand Down Expand Up @@ -978,10 +1027,7 @@ impl<'a> EmbeddingRuntime for OrtRuntime<'a> {
result = model_info.encoder.as_ref().unwrap().process_text(inputs);
}

drop(map);

if !self.cache {
let mut map = MODEL_INFO_MAP.write().unwrap();
let model_info = map.get_mut(model_name).unwrap();
model_info.encoder = None;
}
Expand All @@ -995,7 +1041,7 @@ impl<'a> EmbeddingRuntime for OrtRuntime<'a> {
}

fn get_available_models(&self) -> (String, Vec<(String, bool)>) {
let map = MODEL_INFO_MAP.read().unwrap();
let map = MODEL_INFO_MAP.lock().unwrap();
let mut res = String::new();
let data_path = &self.data_path;
let mut models = Vec::with_capacity(map.len());
Expand Down
1 change: 1 addition & 0 deletions lantern_cli/src/embeddings/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,7 @@ pub fn get_default_batch_size(model: &str) -> usize {
"thenlper/gte-base" => 1000,
"thenlper/gte-large" => 800,
"microsoft/all-MiniLM-L12-v2" => 1000,
"naver/splade-v3" => 1000,
"microsoft/all-mpnet-base-v2" => 400,
"transformers/multi-qa-mpnet-base-dot-v1" => 300,
"openai/text-embedding-ada-002" => 500,
Expand Down
59 changes: 41 additions & 18 deletions lantern_cli/tests/text_embedding_test.rs

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion lantern_extras/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "lantern_extras"
version = "0.2.0"
version = "0.2.1"
edition = "2021"

[lib]
Expand Down

0 comments on commit ffdea61

Please sign in to comment.