Skip to content

Commit

Permalink
change term scaling from idf-sum to correctly weight each term based …
Browse files Browse the repository at this point in the history
…on the number of documents that match that particular term
  • Loading branch information
mikkeldenker committed Feb 18, 2024
1 parent 2d8973b commit 413a469
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 44 deletions.
79 changes: 51 additions & 28 deletions crates/core/src/ranking/bm25.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// source: https://github.com/quickwit-oss/tantivy/blob/main/src/query/bm25.rs

use itertools::Itertools;
use serde::{Deserialize, Serialize};

use tantivy::fieldnorm::FieldNormReader;
Expand Down Expand Up @@ -35,27 +36,20 @@ pub struct Bm25Params {
}

#[derive(Clone)]
pub struct Bm25Weight {
idf_explain: Explanation,
weight: Score,
cache: [Score; 256],
average_fieldnorm: Score,
pub struct MultiBm25Weight {
weights: Vec<Bm25Weight>,
}

impl Bm25Weight {
pub fn boost_by(&self, boost: Score) -> Bm25Weight {
Bm25Weight {
idf_explain: self.idf_explain.clone(),
weight: self.weight * boost,
cache: self.cache,
average_fieldnorm: self.average_fieldnorm,
impl MultiBm25Weight {
pub fn for_terms(searcher: &Searcher, terms: &[Term]) -> tantivy::Result<Self> {
if terms.is_empty() {
return Ok(Self {
weights: Vec::new(),
});
}
}

pub fn for_terms(searcher: &Searcher, terms: &[Term]) -> tantivy::Result<Bm25Weight> {
assert!(!terms.is_empty(), "Bm25 requires at least one term");
let field = terms[0].field();
for term in &terms[1..] {
for term in terms.iter().skip(1) {
assert_eq!(
term.field(),
field,
Expand All @@ -72,21 +66,50 @@ impl Bm25Weight {
}
let average_fieldnorm = total_num_tokens as Score / total_num_docs as Score;

if terms.len() == 1 {
let term_doc_freq = searcher.doc_freq(&terms[0])?;
Ok(Bm25Weight::for_one_term(
let mut weights = Vec::new();

for term in terms {
let term_doc_freq = searcher.doc_freq(term)?;
weights.push(Bm25Weight::for_one_term(
term_doc_freq,
total_num_docs,
average_fieldnorm,
))
} else {
let mut idf_sum: Score = 0.0;
for term in terms {
let term_doc_freq = searcher.doc_freq(term)?;
idf_sum += idf(term_doc_freq, total_num_docs);
}
let idf_explain = Explanation::new("idf", idf_sum);
Ok(Bm25Weight::new(idf_explain, average_fieldnorm))
));
}

Ok(Self { weights })
}

/// Computes the combined BM25 score of one document across all terms.
///
/// `stats` must yield one `(fieldnorm_id, term_freq)` pair per term, in the
/// same order as the per-term weights held by `self`. Each pair is scored by
/// its own weight and the per-term scores are added together.
///
/// # Panics
///
/// Panics if `stats` does not yield exactly as many items as there are
/// weights, because the two sequences are paired with `zip_eq`.
#[inline]
pub fn score(&self, stats: impl Iterator<Item = (u8, u32)>) -> Score {
    let per_term = stats.zip_eq(self.weights.iter());

    per_term
        .map(|(stat, term_weight)| {
            let (fieldnorm_id, term_freq) = stat;
            term_weight.score(fieldnorm_id, term_freq)
        })
        .sum()
}

/// Returns a new `MultiBm25Weight` in which every per-term weight has been
/// boosted by `boost`; `self` is left untouched.
pub fn boost_by(&self, boost: Score) -> Self {
    let mut weights = Vec::with_capacity(self.weights.len());
    for term_weight in &self.weights {
        weights.push(term_weight.boost_by(boost));
    }

    Self { weights }
}
}

/// BM25 scoring state for a single term.
///
/// `MultiBm25Weight` holds one of these per query term and sums their
/// individual scores.
#[derive(Clone)]
pub struct Bm25Weight {
/// Explanation attached to this weight.
/// NOTE(review): presumably describes the term's idf — the constructor is
/// not visible in this view; confirm.
idf_explain: Explanation,
/// Scalar weight of the term; `boost_by` multiplies it by the boost factor.
weight: Score,
/// Table of 256 precomputed `Score` values, carried over unchanged by `boost_by`.
/// NOTE(review): presumably indexed by the `u8` fieldnorm id — confirm.
cache: [Score; 256],
/// Average field norm, carried over unchanged by `boost_by`.
/// NOTE(review): presumably total tokens / total docs as computed in
/// `for_terms` — confirm against the constructor.
average_fieldnorm: Score,
}

impl Bm25Weight {
pub fn boost_by(&self, boost: Score) -> Bm25Weight {
Bm25Weight {
idf_explain: self.idf_explain.clone(),
weight: self.weight * boost,
cache: self.cache,
average_fieldnorm: self.average_fieldnorm,
}
}

Expand Down
31 changes: 15 additions & 16 deletions crates/core/src/ranking/signal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ use crate::{
webpage::region::{Region, RegionCount},
};

use super::bm25::Bm25Weight;
use super::bm25::MultiBm25Weight;
use super::models::linear::LinearRegression;
use super::{inbound_similarity, query_centrality};

Expand Down Expand Up @@ -249,19 +249,17 @@ fn bm25(field: &mut TextFieldData, doc: DocId) -> f64 {
return 0.0;
}

let mut term_freq = 0;
for posting in &mut field.postings {
if posting.doc() == doc || (posting.doc() < doc && posting.seek(doc) == doc) {
term_freq += posting.term_freq();
}
}

if term_freq == 0 {
return 0.0;
}

let fieldnorm_id = field.fieldnorm_reader.fieldnorm_id(doc);
field.weight.score(fieldnorm_id, term_freq) as f64

field
.weight
.score(field.postings.iter_mut().map(move |posting| {
if posting.doc() == doc || (posting.doc() < doc && posting.seek(doc) == doc) {
(fieldnorm_id, posting.term_freq())
} else {
(fieldnorm_id, 0)
}
})) as f64
}

impl Signal {
Expand Down Expand Up @@ -635,7 +633,7 @@ impl SignalCoefficient {
#[derive(Clone)]
struct TextFieldData {
postings: Vec<SegmentPostings>,
weight: Bm25Weight,
weight: MultiBm25Weight,
fieldnorm_reader: FieldNormReader,
}

Expand Down Expand Up @@ -797,19 +795,20 @@ impl SignalAggregator {
continue;
}

let weight = Bm25Weight::for_terms(tv_searcher, &terms)?;

let fieldnorm_reader = segment_reader.get_fieldnorms_reader(tv_field)?;
let inverted_index = segment_reader.inverted_index(tv_field)?;

let mut matching_terms = Vec::with_capacity(terms.len());
let mut postings = Vec::with_capacity(terms.len());
for term in &terms {
if let Some(p) =
inverted_index.read_postings(term, text_field.index_option())?
{
postings.push(p);
matching_terms.push(term.clone());
}
}
let weight = MultiBm25Weight::for_terms(tv_searcher, &matching_terms)?;

text_fields.insert(
text_field,
Expand Down

0 comments on commit 413a469

Please sign in to comment.