Skip to content

Commit

Permalink
perf: enhanced InMemoryDocumentStore BM25 query efficiency with incremental indexing (#7549)

Browse files Browse the repository at this point in the history

* incorporating better bm25 impl without breaking interface

* all three bm25 algos

* 1. setting algo post-init not allowed; 2. remove extra underscore for naming consistency; 3. remove unused import

* 1. rename attribute name for IDF computation 2. organize document statistics as a dataclass instead of tuple to improve readability

* fix score type initialization (int -> float) to pass mypy check

* release note included

* fixing linting issues and mypy

* fixing tests

* removing heapq import and cleaning up logging

* changing indexing order

* adding more tests

* increasing tests

* removing rank_bm25 from pyproject.toml

---------

Co-authored-by: David S. Batista <[email protected]>
  • Loading branch information
Guest400123064 and davidsbatista authored May 3, 2024
1 parent 48c7c6a commit cd66a80
Show file tree
Hide file tree
Showing 4 changed files with 292 additions and 59 deletions.
317 changes: 265 additions & 52 deletions haystack/document_stores/in_memory/document_store.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import math
import re
from typing import Any, Dict, Iterable, List, Literal, Optional
from collections import Counter
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple

import numpy as np
from haystack_bm25 import rank_bm25
from tqdm.auto import tqdm

from haystack import default_from_dict, default_to_dict, logging
from haystack.dataclasses import Document
Expand All @@ -24,6 +25,19 @@
DOT_PRODUCT_SCALING_FACTOR = 100


@dataclass
class BM25DocumentStats:
    """
    A dataclass for managing per-document statistics used by BM25 retrieval.

    Instances are created/updated incrementally in ``write_documents`` and
    reverted in ``delete_documents``, so scoring never has to re-tokenize
    the whole corpus at query time.

    :param freq_token: A Counter of token frequencies in the document.
    :param doc_len: Number of tokens in the document.
    """

    freq_token: Dict[str, int]  # token -> occurrence count within this document
    doc_len: int  # total token count of the document (not the character length)


class InMemoryDocumentStore:
"""
Stores data in-memory. It's ephemeral and cannot be saved to disk.
Expand All @@ -50,15 +64,206 @@ def __init__(
To choose the most appropriate function, look for information about your embedding model.
"""
self.storage: Dict[str, Document] = {}
self._bm25_tokenization_regex = bm25_tokenization_regex
self.bm25_tokenization_regex = bm25_tokenization_regex
self.tokenizer = re.compile(bm25_tokenization_regex).findall
algorithm_class = getattr(rank_bm25, bm25_algorithm)
if algorithm_class is None:
raise ValueError(f"BM25 algorithm '{bm25_algorithm}' not found.")
self.bm25_algorithm = algorithm_class

self.bm25_algorithm = bm25_algorithm
self.bm25_algorithm_inst = self._dispatch_bm25()
self.bm25_parameters = bm25_parameters or {}
self.embedding_similarity_function = embedding_similarity_function

# Global BM25 statistics
self._avg_doc_len: float = 0.0
self._freq_vocab_for_idf: Counter = Counter()

# Per-document statistics
self._bm25_attr: Dict[str, BM25DocumentStats] = {}

def _dispatch_bm25(self):
    """
    Resolve the scoring callable for the configured BM25 variant.

    :returns:
        The bound scoring method matching ``self.bm25_algorithm``.
    :raises ValueError:
        If the configured algorithm name is not a supported variant.
    """
    dispatch = {
        "BM25Okapi": self._score_bm25okapi,
        "BM25L": self._score_bm25l,
        "BM25Plus": self._score_bm25plus,
    }
    scorer = dispatch.get(self.bm25_algorithm)
    if scorer is None:
        raise ValueError(f"BM25 algorithm '{self.bm25_algorithm}' is not supported.")
    return scorer

def _tokenize_bm25(self, text: str) -> List[str]:
    """
    Tokenize text using the BM25 tokenization regex.

    This method deliberately centralizes all BM25 token pre-processing
    (currently just lowercasing) so that indexing and querying are
    guaranteed to use the exact same tokenization at any given time.

    :param text:
        The text to tokenize.
    :returns:
        A list of tokens.
    """
    return self.tokenizer(text.lower())

def _score_bm25l(self, query: str, documents: List[Document]) -> List[Tuple[Document, float]]:
    """
    Calculate BM25L scores for the given query and filtered documents.

    BM25L shifts the length-normalized term frequency by ``delta`` to
    reduce BM25's bias against longer documents.

    :param query:
        The query string.
    :param documents:
        The list of documents to score, should be produced by
        the filter_documents method; may be an empty list.
    :returns:
        A list of tuples, each containing a Document and its BM25L score.
    """
    k1 = self.bm25_parameters.get("k1", 1.5)
    b = self.bm25_parameters.get("b", 0.75)
    delta = self.bm25_parameters.get("delta", 0.5)
    n_corpus = len(self._bm25_attr)

    # IDF per distinct query token; tokens absent from the corpus weigh 0.
    idf: Dict[str, float] = {}
    for token in self._tokenize_bm25(query):
        n_docs_with_token = self._freq_vocab_for_idf.get(token, 0)
        if n_docs_with_token:
            idf[token] = math.log((n_corpus + 1.0) / (n_docs_with_token + 0.5))
        else:
            idf[token] = 0.0

    scored = []
    for document in documents:
        stats = self._bm25_attr[document.id]
        # Length normalization factor is constant for a given document.
        norm_len = 1 - b + b * stats.doc_len / self._avg_doc_len

        score = 0.0
        for token, token_idf in idf.items():
            ctd = stats.freq_token.get(token, 0.0) / norm_len
            score += token_idf * ((1.0 + k1) * (ctd + delta) / (k1 + ctd + delta))
        scored.append((document, score))

    return scored

def _score_bm25okapi(self, query: str, documents: List[Document]) -> List[Tuple[Document, float]]:
    """
    Calculate BM25Okapi scores for the given query and filtered documents.

    :param query:
        The query string.
    :param documents:
        The list of documents to score, should be produced by
        the filter_documents method; may be an empty list.
    :returns:
        A list of tuples, each containing a Document and its BM25Okapi score.
    """
    k = self.bm25_parameters.get("k1", 1.5)
    b = self.bm25_parameters.get("b", 0.75)
    epsilon = self.bm25_parameters.get("epsilon", 0.25)
    # Guard against a degenerate index where every stored document tokenized
    # to nothing (e.g. empty-string content): fall back to 1.0 so the length
    # normalization is a no-op instead of raising ZeroDivisionError. All term
    # frequencies are 0 in that case, so every score is 0 either way.
    avg_doc_len = self._avg_doc_len or 1.0

    def _compute_idf(tokens: List[str]) -> Dict[str, float]:
        """Per-token IDF computation for all query tokens."""
        sum_idf = 0.0
        neg_idf_tokens = []

        # Although this is a global statistic, we compute it here
        # to make the computation more self-contained. And the
        # complexity is O(vocab_size), which is acceptable.
        idf = {}
        for tok, n in self._freq_vocab_for_idf.items():
            idf[tok] = math.log((len(self._bm25_attr) - n + 0.5) / (n + 0.5))
            sum_idf += idf[tok]
            if idf[tok] < 0:
                neg_idf_tokens.append(tok)

        # BM25Okapi floors negative IDF values at a small positive epsilon
        # derived from the average IDF. Skip when the vocabulary is empty
        # to avoid dividing by zero (idf is empty then anyway).
        if self._freq_vocab_for_idf:
            eps = epsilon * sum_idf / len(self._freq_vocab_for_idf)
            for tok in neg_idf_tokens:
                idf[tok] = eps
        return {tok: idf.get(tok, 0.0) for tok in tokens}

    def _compute_tf(token: str, freq: Dict[str, int], doc_len: int) -> float:
        """Per-token BM25Okapi term-frequency computation."""
        freq_term = freq.get(token, 0.0)
        freq_norm = freq_term + k * (1 - b + b * doc_len / avg_doc_len)
        return freq_term * (1.0 + k) / freq_norm

    idf = _compute_idf(self._tokenize_bm25(query))
    bm25_attr = {doc.id: self._bm25_attr[doc.id] for doc in documents}

    ret = []
    for doc in documents:
        doc_stats = bm25_attr[doc.id]
        freq = doc_stats.freq_token
        doc_len = doc_stats.doc_len

        score = 0.0
        for tok in idf.keys():  # pylint: disable=consider-using-dict-items
            score += idf[tok] * _compute_tf(tok, freq, doc_len)
        ret.append((doc, score))

    return ret

def _score_bm25plus(self, query: str, documents: List[Document]) -> List[Tuple[Document, float]]:
    """
    Calculate BM25+ scores for the given query and filtered documents.

    This implementation follows the BM25 Wikipedia page: it adds a
    smoothing factor of 1 to the document frequency when computing IDF,
    and a constant ``delta`` to the normalized term frequency.

    :param query:
        The query string.
    :param documents:
        The list of documents to score, should be produced by
        the filter_documents method; may be an empty list.
    :returns:
        A list of tuples, each containing a Document and its BM25+ score.
    """
    k1 = self.bm25_parameters.get("k1", 1.5)
    b = self.bm25_parameters.get("b", 0.75)
    delta = self.bm25_parameters.get("delta", 1.0)
    n_corpus = len(self._bm25_attr)

    # IDF per distinct query token; tokens absent from the corpus weigh 0.
    idf: Dict[str, float] = {}
    for token in self._tokenize_bm25(query):
        n_docs_with_token = self._freq_vocab_for_idf.get(token, 0)
        if n_docs_with_token:
            idf[token] = math.log(1 + (n_corpus - n_docs_with_token + 0.5) / (n_docs_with_token + 0.5))
        else:
            idf[token] = 0.0

    scored = []
    for document in documents:
        stats = self._bm25_attr[document.id]
        # Damping term depends only on the document length, so hoist it.
        freq_damp = k1 * (1 - b + b * stats.doc_len / self._avg_doc_len)

        score = 0.0
        for token, token_idf in idf.items():
            freq_term = stats.freq_token.get(token, 0.0)
            score += token_idf * (freq_term * (1.0 + k1) / (freq_term + freq_damp) + delta)
        scored.append((document, score))

    return scored

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
Expand All @@ -68,8 +273,8 @@ def to_dict(self) -> Dict[str, Any]:
"""
return default_to_dict(
self,
bm25_tokenization_regex=self._bm25_tokenization_regex,
bm25_algorithm=self.bm25_algorithm.__name__,
bm25_tokenization_regex=self.bm25_tokenization_regex,
bm25_algorithm=self.bm25_algorithm,
bm25_parameters=self.bm25_parameters,
embedding_similarity_function=self.embedding_similarity_function,
)
Expand Down Expand Up @@ -132,7 +337,36 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
logger.warning("ID '{document_id}' already exists", document_id=document.id)
written_documents -= 1
continue

# Since the statistics are updated in an incremental manner,
# we need to explicitly remove the existing document to revert
# the statistics before updating them with the new document.
if document.id in self.storage.keys():
self.delete_documents([document.id])

# This processing logic is extracted from the original bm25_retrieval method.
# Since we are creating index incrementally before the first retrieval,
# we need to determine what content to use for indexing here, not at query time.
if document.content is not None:
if document.dataframe is not None:
logger.warning(
"Document '{document_id}' has both text and dataframe content. "
"Using text content for retrieval and skipping dataframe content.",
document_id=document.id,
)
tokens = self._tokenize_bm25(document.content)
elif document.dataframe is not None:
str_content = document.dataframe.astype(str)
csv_content = str_content.to_csv(index=False)
tokens = self._tokenize_bm25(csv_content)
else:
tokens = []

self.storage[document.id] = document

self._bm25_attr[document.id] = BM25DocumentStats(Counter(tokens), len(tokens))
self._freq_vocab_for_idf.update(set(tokens))
self._avg_doc_len = (len(tokens) + self._avg_doc_len * len(self._bm25_attr)) / (len(self._bm25_attr) + 1)
return written_documents

def delete_documents(self, document_ids: List[str]) -> None:
Expand All @@ -146,6 +380,17 @@ def delete_documents(self, document_ids: List[str]) -> None:
continue
del self.storage[doc_id]

# Update statistics accordingly
doc_stats = self._bm25_attr.pop(doc_id)
freq = doc_stats.freq_token
doc_len = doc_stats.doc_len

self._freq_vocab_for_idf.subtract(Counter(freq.keys()))
try:
self._avg_doc_len = (self._avg_doc_len * (len(self._bm25_attr) + 1) - doc_len) / len(self._bm25_attr)
except ZeroDivisionError:
self._avg_doc_len = 0

def bm25_retrieval(
self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, scale_score: bool = False
) -> List[Document]:
Expand Down Expand Up @@ -174,65 +419,33 @@ def bm25_retrieval(
filters = {"operator": "AND", "conditions": [content_type_filter, filters]}
else:
filters = content_type_filter
all_documents = self.filter_documents(filters=filters)

# Lowercase all documents
lower_case_documents = []
for doc in all_documents:
if doc.content is None and doc.dataframe is None:
logger.info(
"Document '{document_id}' has no text or dataframe content. Skipping it.", document_id=doc.id
)
else:
if doc.content is not None:
lower_case_documents.append(doc.content.lower())
if doc.dataframe is not None:
logger.warning(
"Document '{document_id}' has both text and dataframe content. "
"Using text content and skipping dataframe content.",
document_id=doc.id,
)
continue
if doc.dataframe is not None:
str_content = doc.dataframe.astype(str)
csv_content = str_content.to_csv(index=False)
lower_case_documents.append(csv_content.lower())

# Tokenize the entire content of the DocumentStore
tokenized_corpus = [
self.tokenizer(doc) for doc in tqdm(lower_case_documents, unit=" docs", desc="Ranking by BM25...")
]
if len(tokenized_corpus) == 0:
all_documents = self.filter_documents(filters=filters)
if len(all_documents) == 0:
logger.info("No documents found for BM25 retrieval. Returning empty list.")
return []

# initialize BM25
bm25_scorer = self.bm25_algorithm(tokenized_corpus, **self.bm25_parameters)
# tokenize query
tokenized_query = self.tokenizer(query.lower())
# get scores for the query against the corpus
docs_scores = bm25_scorer.get_scores(tokenized_query)
if scale_score:
docs_scores = [expit(float(score / BM25_SCALING_FACTOR)) for score in docs_scores]
# get the last top_k indexes and reverse them
top_docs_positions = np.argsort(docs_scores)[-top_k:][::-1]
results = sorted(self.bm25_algorithm_inst(query, all_documents), key=lambda x: x[1], reverse=True)[:top_k]

# BM25Okapi can return meaningful negative values, so they should not be filtered out when scale_score is False.
# It's the only algorithm supported by rank_bm25 at the time of writing (2024) that can return negative scores.
# see https://github.com/deepset-ai/haystack/pull/6889 for more context.
negatives_are_valid = self.bm25_algorithm is rank_bm25.BM25Okapi and not scale_score
negatives_are_valid = self.bm25_algorithm == "BM25Okapi" and not scale_score

# Create documents with the BM25 score to return them
return_documents = []
for i in top_docs_positions:
doc = all_documents[i]
score = docs_scores[i]
for doc, score in results:
if scale_score:
score = expit(score / BM25_SCALING_FACTOR)

if not negatives_are_valid and score <= 0.0:
continue

doc_fields = doc.to_dict()
doc_fields["score"] = score
return_document = Document.from_dict(doc_fields)
return_documents.append(return_document)

return return_documents

def embedding_retrieval(
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ classifiers = [
]
dependencies = [
"pandas",
"haystack-bm25",
"tqdm",
"tenacity",
"lazy-imports",
Expand Down
Loading

0 comments on commit cd66a80

Please sign in to comment.