diff --git a/crates/core/src/ranking/bm25.rs b/crates/core/src/ranking/bm25.rs
index e79de857..06914828 100644
--- a/crates/core/src/ranking/bm25.rs
+++ b/crates/core/src/ranking/bm25.rs
@@ -1,5 +1,6 @@
 // source: https://github.com/quickwit-oss/tantivy/blob/main/src/query/bm25.rs
 
+use itertools::Itertools;
 use serde::{Deserialize, Serialize};
 use tantivy::fieldnorm::FieldNormReader;
 
@@ -35,27 +36,20 @@ pub struct Bm25Params {
 }
 
 #[derive(Clone)]
-pub struct Bm25Weight {
-    idf_explain: Explanation,
-    weight: Score,
-    cache: [Score; 256],
-    average_fieldnorm: Score,
+pub struct MultiBm25Weight {
+    weights: Vec<Bm25Weight>,
 }
 
-impl Bm25Weight {
-    pub fn boost_by(&self, boost: Score) -> Bm25Weight {
-        Bm25Weight {
-            idf_explain: self.idf_explain.clone(),
-            weight: self.weight * boost,
-            cache: self.cache,
-            average_fieldnorm: self.average_fieldnorm,
+impl MultiBm25Weight {
+    pub fn for_terms(searcher: &Searcher, terms: &[Term]) -> tantivy::Result<Self> {
+        if terms.is_empty() {
+            return Ok(Self {
+                weights: Vec::new(),
+            });
         }
-    }
 
-    pub fn for_terms(searcher: &Searcher, terms: &[Term]) -> tantivy::Result<Bm25Weight> {
-        assert!(!terms.is_empty(), "Bm25 requires at least one term");
         let field = terms[0].field();
-        for term in &terms[1..] {
+        for term in terms.iter().skip(1) {
             assert_eq!(
                 term.field(),
                 field,
@@ -72,21 +66,50 @@ impl Bm25Weight {
         }
         let average_fieldnorm = total_num_tokens as Score / total_num_docs as Score;
 
-        if terms.len() == 1 {
-            let term_doc_freq = searcher.doc_freq(&terms[0])?;
-            Ok(Bm25Weight::for_one_term(
+        let mut weights = Vec::new();
+
+        for term in terms {
+            let term_doc_freq = searcher.doc_freq(term)?;
+            weights.push(Bm25Weight::for_one_term(
                 term_doc_freq,
                 total_num_docs,
                 average_fieldnorm,
-            ))
-        } else {
-            let mut idf_sum: Score = 0.0;
-            for term in terms {
-                let term_doc_freq = searcher.doc_freq(term)?;
-                idf_sum += idf(term_doc_freq, total_num_docs);
-            }
-            let idf_explain = Explanation::new("idf", idf_sum);
-            Ok(Bm25Weight::new(idf_explain, average_fieldnorm))
+            ));
+        }
+
+        Ok(Self { weights })
+    }
+
+    #[inline]
+    pub fn score(&self, stats: impl Iterator<Item = (u8, u32)>) -> Score {
+        stats
+            .zip_eq(self.weights.iter())
+            .map(|((fieldnorm_id, term_freq), weight)| weight.score(fieldnorm_id, term_freq))
+            .sum()
+    }
+
+    pub fn boost_by(&self, boost: Score) -> Self {
+        Self {
+            weights: self.weights.iter().map(|w| w.boost_by(boost)).collect(),
+        }
+    }
+}
+
+#[derive(Clone)]
+pub struct Bm25Weight {
+    idf_explain: Explanation,
+    weight: Score,
+    cache: [Score; 256],
+    average_fieldnorm: Score,
+}
+
+impl Bm25Weight {
+    pub fn boost_by(&self, boost: Score) -> Bm25Weight {
+        Bm25Weight {
+            idf_explain: self.idf_explain.clone(),
+            weight: self.weight * boost,
+            cache: self.cache,
+            average_fieldnorm: self.average_fieldnorm,
         }
     }
 
diff --git a/crates/core/src/ranking/signal.rs b/crates/core/src/ranking/signal.rs
index 35ab50a9..aafe3b0b 100644
--- a/crates/core/src/ranking/signal.rs
+++ b/crates/core/src/ranking/signal.rs
@@ -45,7 +45,7 @@ use crate::{
     webpage::region::{Region, RegionCount},
 };
 
-use super::bm25::Bm25Weight;
+use super::bm25::MultiBm25Weight;
 use super::models::linear::LinearRegression;
 use super::{inbound_similarity, query_centrality};
 
@@ -249,19 +249,17 @@ fn bm25(field: &mut TextFieldData, doc: DocId) -> f64 {
         return 0.0;
     }
 
-    let mut term_freq = 0;
-    for posting in &mut field.postings {
-        if posting.doc() == doc || (posting.doc() < doc && posting.seek(doc) == doc) {
-            term_freq += posting.term_freq();
-        }
-    }
-
-    if term_freq == 0 {
-        return 0.0;
-    }
-
     let fieldnorm_id = field.fieldnorm_reader.fieldnorm_id(doc);
-    field.weight.score(fieldnorm_id, term_freq) as f64
+
+    field
+        .weight
+        .score(field.postings.iter_mut().map(move |posting| {
+            if posting.doc() == doc || (posting.doc() < doc && posting.seek(doc) == doc) {
+                (fieldnorm_id, posting.term_freq())
+            } else {
+                (fieldnorm_id, 0)
+            }
+        })) as f64
 }
 
 impl Signal {
@@ -635,7 +633,7 @@ impl SignalCoefficient {
 #[derive(Clone)]
 struct TextFieldData {
     postings: Vec<SegmentPostings>,
-    weight: Bm25Weight,
+    weight: MultiBm25Weight,
     fieldnorm_reader: FieldNormReader,
 }
 
@@ -797,19 +795,20 @@ impl SignalAggregator {
                 continue;
             }
 
-            let weight = Bm25Weight::for_terms(tv_searcher, &terms)?;
-
            let fieldnorm_reader = segment_reader.get_fieldnorms_reader(tv_field)?;
             let inverted_index = segment_reader.inverted_index(tv_field)?;
 
+            let mut matching_terms = Vec::with_capacity(terms.len());
             let mut postings = Vec::with_capacity(terms.len());
             for term in &terms {
                 if let Some(p) = inverted_index.read_postings(term, text_field.index_option())? {
                     postings.push(p);
+                    matching_terms.push(term.clone());
                 }
             }
+            let weight = MultiBm25Weight::for_terms(tv_searcher, &matching_terms)?;
 
             text_fields.insert(
                 text_field,