Use attention mask to zero padded tokens when mean pooling
var77 committed Mar 1, 2024
1 parent b77db72 commit 5a6cd81
Showing 2 changed files with 113 additions and 17 deletions.
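
The core of the change: naive mean pooling (mean_axis over the token axis) averages padded positions into the sentence embedding, so shorter sequences in a batch get diluted vectors. The fix expands the attention mask across the embedding dimension, zeroes the padded tokens, and divides by the count of real tokens instead of the full sequence length. A minimal standalone sketch of the same idea, using plain ndarray arrays instead of the crate's ViewHolder and SessionInput types (shapes and values here are illustrative):

use ndarray::{array, Array2, Array3, Axis};

/// Masked mean pooling: average each sequence's token embeddings,
/// counting only positions where the attention mask is 1.
/// `embeddings` has shape (batch, seq_len, dim); `mask` has shape (batch, seq_len).
fn masked_mean(embeddings: &Array3<f32>, mask: &Array2<i64>) -> Array2<f32> {
    let (batch, seq_len, dim) = embeddings.dim();
    // Expand the mask to (batch, seq_len, dim) so it aligns with the embeddings.
    let expanded = mask
        .mapv(|v| v as f32)
        .insert_axis(Axis(2))
        .broadcast((batch, seq_len, dim))
        .unwrap()
        .to_owned();
    // Zero out padded positions, then sum over the sequence axis.
    let summed = (embeddings * &expanded).sum_axis(Axis(1));
    // Divide by the count of real tokens per sequence; clamp to avoid division by zero.
    let counts = expanded.sum_axis(Axis(1)).mapv(|x| x.max(1e-9));
    summed / counts
}

fn main() {
    // Two sequences of three tokens, embedding dim 2; the second sequence
    // has one padded position (mask = 0).
    let emb: Array3<f32> = array![
        [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
        [[2.0, 2.0], [4.0, 4.0], [9.0, 9.0]],
    ];
    let mask: Array2<i64> = array![[1, 1, 1], [1, 1, 0]];
    let pooled = masked_mean(&emb, &mask);
    // Row 0 is the plain mean; row 1 averages only the two unmasked tokens,
    // so the padded [9.0, 9.0] contributes nothing.
    assert_eq!(pooled, array![[3.0, 4.0], [3.0, 3.0]]);
}
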
lantern_cli/src/embeddings/core/ort_runtime.rs (84 changes: 67 additions & 17 deletions)
@@ -33,23 +33,63 @@ pub enum PoolingStrategy {
 }
 
 impl PoolingStrategy {
-    fn cls_pooling(embeddings: ViewHolder<'_, f32, Dim<IxDynImpl>>) -> Vec<f32> {
-        embeddings.slice(s![.., 0, ..]).iter().map(|s| *s).collect()
+    fn cls_pooling(
+        embeddings: ViewHolder<'_, f32, Dim<IxDynImpl>>,
+        output_dims: usize,
+    ) -> Vec<Vec<f32>> {
+        embeddings
+            .slice(s![.., 0, ..])
+            .iter()
+            .map(|s| *s)
+            .chunks(output_dims)
+            .into_iter()
+            .map(|b| b.collect())
+            .collect()
     }
 
-    fn mean_pooling(embeddings: ViewHolder<'_, f32, Dim<IxDynImpl>>) -> Vec<f32> {
-        embeddings
-            .mean_axis(Axis(1))
+    fn mean_pooling(
+        embeddings: ViewHolder<'_, f32, Dim<IxDynImpl>>,
+        attention_mask: &SessionInput,
+        output_dims: usize,
+    ) -> Vec<Vec<f32>> {
+        let attention_mask_shape = attention_mask.shape();
+        let input_mask_expanded = attention_mask.clone().insert_axis(Axis(2)).into_owned();
+        let input_mask_expanded = input_mask_expanded
+            .broadcast((
+                attention_mask_shape[0],
+                attention_mask_shape[1],
+                output_dims,
+            ))
+            .unwrap()
+            .to_owned();
+        let input_mask_expanded = input_mask_expanded.mapv(|v| v as f32);
+        let embeddings = embeddings.to_owned();
+        let masked_embeddings = &embeddings * &input_mask_expanded;
+
+        let mean_embeddings = masked_embeddings.sum_axis(Axis(1));
+        let divider = input_mask_expanded.sum_axis(Axis(1)).mapv(|x| x.max(1e-9));
+        let embeddings = mean_embeddings / divider;
+
+        embeddings
+            .iter()
+            .map(|s| *s)
+            .chunks(output_dims)
+            .into_iter()
+            .map(|b| b.collect())
+            .collect()
     }
 
-    pub fn pool(&self, embeddings: ViewHolder<'_, f32, Dim<IxDynImpl>>) -> Vec<f32> {
+    pub fn pool(
+        &self,
+        embeddings: ViewHolder<'_, f32, Dim<IxDynImpl>>,
+        attention_mask: &SessionInput,
+        output_dims: usize,
+    ) -> Vec<Vec<f32>> {
         match self {
-            &PoolingStrategy::CLS => PoolingStrategy::cls_pooling(embeddings),
-            &PoolingStrategy::Mean => PoolingStrategy::mean_pooling(embeddings),
+            &PoolingStrategy::CLS => PoolingStrategy::cls_pooling(embeddings, output_dims),
+            &PoolingStrategy::Mean => {
+                PoolingStrategy::mean_pooling(embeddings, attention_mask, output_dims)
+            }
         }
     }
 }
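
Both pooling paths now return Vec<Vec<f32>> directly: the pooled (batch, output_dims) matrix is flattened and regrouped into one row per sequence with itertools' chunks. A self-contained example of that regrouping idiom (values made up):

use itertools::Itertools;

fn main() {
    // A flat buffer of pooled values for three sequences with output_dims = 2.
    let flat = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let output_dims = 2;
    // Same idiom as the diff: group every `output_dims` values into one row.
    let rows: Vec<Vec<f32>> = flat
        .into_iter()
        .chunks(output_dims)
        .into_iter()
        .map(|row| row.collect())
        .collect();
    assert_eq!(rows, vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]]);
}
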
@@ -363,6 +403,7 @@ impl EncoderService {
         let mut vecs = Vec::with_capacity(session.inputs.len());
         let mut processed_tokens = 0;
 
+        let mut attention_mask_idx = None;
         for input in &session.inputs {
             let mut val = None;
             match input.name.as_str() {
@@ -380,6 +421,7 @@
                         .map(|i| i.get_attention_mask().iter().map(|b| *b as i64).collect())
                         .concat();
                     val = Some(v);
+                    attention_mask_idx = Some(vecs.len())
                 }
                 "token_type_ids" => {
                     let v: Vec<i64> = preprocessed
@@ -395,6 +437,16 @@
             }
         }
 
+        if attention_mask_idx.is_none() {
+            return Err(format!(
+                "Could not get attention_mask_idx from session inputs: {:?}",
+                session.inputs
+            )
+            .into());
+        }
+
+        let attention_mask_idx = attention_mask_idx.unwrap();
+
         let input_chunks = self.chunk_session_input(vecs, text_len)?;
         let embeddings = input_chunks
             .iter()
@@ -409,17 +461,15 @@
 
                 let binding = outputs[0].try_extract()?;
                 let embeddings = binding.view();
-                let embeddings: Vec<f32> = self.model_params.pooling_strategy.pool(embeddings);
 
+                let attention_mask = &chunk[attention_mask_idx];
+                let output_dims = session.outputs[0].dimensions.last().unwrap().unwrap() as usize;
+                let embeddings: Vec<Vec<f32>> = self.model_params.pooling_strategy.pool(
+                    embeddings,
+                    attention_mask,
+                    output_dims,
+                );
 
-                Ok(embeddings
-                    .iter()
-                    .map(|s| *s)
-                    .chunks(output_dims)
-                    .into_iter()
-                    .map(|b| b.collect())
-                    .collect::<Vec<Vec<f32>>>())
+                Ok(embeddings)
             })
             .collect::<Result<Vec<Vec<Vec<f32>>>, anyhow::Error>>();
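
The remaining hunks record which positional slot of the assembled inputs holds the attention mask, so the pooling step can fetch it per chunk, and fail early when the model exposes no attention_mask input. A toy sketch of the same lookup-and-fail pattern (the helper name and plain string inputs are assumptions, not the crate's API):

// Illustrative only: the commit records the index while iterating session.inputs;
// here a plain slice of names stands in for the ONNX session's input list.
fn find_attention_mask_idx(input_names: &[&str]) -> Result<usize, String> {
    input_names
        .iter()
        .position(|name| *name == "attention_mask")
        .ok_or_else(|| {
            format!(
                "Could not get attention_mask_idx from session inputs: {:?}",
                input_names
            )
        })
}

fn main() {
    let names = ["input_ids", "attention_mask", "token_type_ids"];
    assert_eq!(find_attention_mask_idx(&names), Ok(1));
    assert!(find_attention_mask_idx(&["input_ids"]).is_err());
}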
