From ed58436bfe5362038e156ab08953b64fc5e9b776 Mon Sep 17 00:00:00 2001 From: awinml <97467100+awinml@users.noreply.github.com> Date: Mon, 26 Feb 2024 17:20:15 +0530 Subject: [PATCH 1/8] Add Diversity Ranker --- haystack/components/rankers/__init__.py | 3 +- haystack/components/rankers/diversity.py | 215 +++++++++++ ...add-diversity-ranker-6ecee21134eda673.yaml | 4 + test/components/rankers/test_diversity.py | 344 ++++++++++++++++++ 4 files changed, 565 insertions(+), 1 deletion(-) create mode 100644 haystack/components/rankers/diversity.py create mode 100644 releasenotes/notes/add-diversity-ranker-6ecee21134eda673.yaml create mode 100644 test/components/rankers/test_diversity.py diff --git a/haystack/components/rankers/__init__.py b/haystack/components/rankers/__init__.py index bb8c7dd999..330b0576fa 100644 --- a/haystack/components/rankers/__init__.py +++ b/haystack/components/rankers/__init__.py @@ -1,5 +1,6 @@ +from haystack.components.rankers.diversity import DiversityRanker from haystack.components.rankers.lost_in_the_middle import LostInTheMiddleRanker from haystack.components.rankers.meta_field import MetaFieldRanker from haystack.components.rankers.transformers_similarity import TransformersSimilarityRanker -__all__ = ["LostInTheMiddleRanker", "MetaFieldRanker", "TransformersSimilarityRanker"] +__all__ = ["DiversityRanker", "LostInTheMiddleRanker", "MetaFieldRanker", "TransformersSimilarityRanker"] diff --git a/haystack/components/rankers/diversity.py b/haystack/components/rankers/diversity.py new file mode 100644 index 0000000000..4251f288a9 --- /dev/null +++ b/haystack/components/rankers/diversity.py @@ -0,0 +1,215 @@ +import logging +from typing import Any, Dict, List, Literal, Optional + +from haystack import ComponentError, Document, component, default_from_dict, default_to_dict +from haystack.lazy_imports import LazyImport +from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace + +logger = logging.getLogger(__name__) + + +with LazyImport(message="Run 'pip install \"sentence-transformers>=2.2.0\"'") as torch_and_transformers_import: + import torch + from sentence_transformers import SentenceTransformer + + +@component +class DiversityRanker: + """ + Implements a document ranking algorithm that orders documents in such a way as to maximize the overall diversity + of the documents. + It uses a pre-trained Sentence Transformers model to embed the query and the Documents. + + Usage example: + ```python + from haystack import Document + from haystack.components.rankers import DiversityRanker + + ranker = DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity="cosine") + docs = [Document(content="Paris"), Document(content="Berlin")] + query = "What is the capital of germany?" + output = ranker.run(query=query, documents=docs) + docs = output["documents"] + assert len(docs) == 2 + assert docs[0].content == "Paris" + ``` + """ + + def __init__( + self, + model: str = "sentence-transformers/all-MiniLM-L6-v2", + top_k: int = 10, + device: Optional[ComponentDevice] = None, + token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False), + similarity: Literal["dot_product", "cosine"] = "dot_product", + prefix: str = "", + suffix: str = "", + meta_fields_to_embed: Optional[List[str]] = None, + embedding_separator: str = "\n", + ): + """ + Initialize a DiversityRanker. + + :param model: Local path or name of the model in Hugging Face's model hub, + such as `'sentence-transformers/all-MiniLM-L6-v2'`. + :param top_k: The maximum number of Documents to return per query. + :param device: The device on which the model is loaded. If `None`, the default device is automatically + selected. + :param token: The API token used to download private models from Hugging Face. + :param similarity: Similarity metric for comparing embeddings. Can be set to "dot_product" (default) or + "cosine". + :param prefix: A string to add to the beginning of each Document text before embedding. + Can be used to prepend the text with an instruction, as required by some embedding models, + such as E5 and bge. + :param suffix: A string to add to the end of each Document text before embedding. + :param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content. + :param embedding_separator: Separator used to concatenate the meta fields to the Document content. + """ + torch_and_transformers_import.check() + + self.model_name_or_path = model + if top_k is None or top_k <= 0: + raise ValueError(f"top_k must be > 0, but got {top_k}") + self.top_k = top_k + self.device = ComponentDevice.resolve_device(device) + self.token = token + self.model = None + if similarity not in ["dot_product", "cosine"]: + raise ValueError(f"Similarity must be one of 'dot_product' or 'cosine', but got {similarity}.") + self.similarity = similarity + self.prefix = prefix + self.suffix = suffix + self.meta_fields_to_embed = meta_fields_to_embed or [] + self.embedding_separator = embedding_separator + + def warm_up(self): + """ + Warm up the model used for scoring the Documents. + """ + if self.model is None: + self.model = SentenceTransformer( + model_name_or_path=self.model_name_or_path, + device=self.device.to_torch_str(), + use_auth_token=self.token.resolve_value() if self.token else None, + ) + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + """ + return default_to_dict( + self, + model=self.model_name_or_path, + device=self.device.to_dict(), + token=self.token.to_dict() if self.token else None, + top_k=self.top_k, + similarity=self.similarity, + prefix=self.prefix, + suffix=self.suffix, + meta_fields_to_embed=self.meta_fields_to_embed, + embedding_separator=self.embedding_separator, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "DiversityRanker": + """ + Deserialize this component from a dictionary. + """ + serialized_device = data["init_parameters"]["device"] + data["init_parameters"]["device"] = ComponentDevice.from_dict(serialized_device) + + deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) + return default_from_dict(cls, data) + + def _greedy_diversity_order(self, query: str, documents: List[Document]) -> List[Document]: + """ + Orders the given list of documents to maximize diversity. + + The algorithm first calculates embeddings for each document and the query. It starts by selecting the document + that is semantically closest to the query. Then, for each remaining document, it selects the one that, on + average, is least similar to the already selected documents. This process continues until all documents are + selected, resulting in a list where each subsequent document contributes the most to the overall diversity of + the selected set. + + :param query: The search query. + :param documents: The list of Document objects to be ranked. + + :return: A list of documents ordered to maximize diversity. + """ + + texts_to_embed = [] + for doc in documents: + meta_values_to_embed = [ + str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] + ] + text_to_embed = ( + self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + self.suffix + ) + texts_to_embed.append(text_to_embed) + + # Calculate embeddings + doc_embeddings = self.model.encode(texts_to_embed, convert_to_tensor=True) # type: ignore[attr-defined] + query_embedding = self.model.encode([query], convert_to_tensor=True) # type: ignore[attr-defined] + + # Normalize embeddings to unit length for computing cosine similarity + if self.similarity == "cosine": + doc_embeddings /= torch.norm(doc_embeddings, p=2, dim=-1).unsqueeze(-1) + query_embedding /= torch.norm(query_embedding, p=2, dim=-1).unsqueeze(-1) + + n = len(documents) + selected: List[int] = [] + + # Compute the similarity vector between the query and documents + query_doc_sim = query_embedding @ doc_embeddings.T + + # Start with the document with the highest similarity to the query + selected.append(int(torch.argmax(query_doc_sim).item())) + + selected_sum = doc_embeddings[selected[0]] / n + + while len(selected) < n: + # Compute mean of dot products of all selected documents and all other documents + similarities = selected_sum @ doc_embeddings.T + # Mask documents that are already selected + similarities[selected] = torch.inf + # Select the document with the lowest total similarity score + index_unselected = int(torch.argmin(similarities).item()) + selected.append(index_unselected) + # It's enough just to add to the selected vectors because dot product is distributive + # It's divided by n for numerical stability + selected_sum += doc_embeddings[index_unselected] / n + + ranked_docs: List[Document] = [documents[i] for i in selected] + + return ranked_docs + + @component.output_types(documents=List[Document]) + def run(self, query: str, documents: List[Document], top_k: Optional[int] = None): + """ + Rank the documents based on their diversity and return the top_k documents. + + :param query: The query. + :param documents: A list of Document objects that should be ranked. + :param top_k: The maximum number of documents to return. + + :return: A list of top_k documents ranked based on diversity. + """ + if query is None or len(query) == 0: + raise ValueError("Query is empty") + + if not documents: + return {"documents": []} + + if top_k is None: + top_k = self.top_k + elif top_k <= 0: + raise ValueError(f"top_k must be > 0, but got {top_k}") + + if self.model is None: + raise ComponentError( + f"The component {self.__class__.__name__} wasn't warmed up. Run 'warm_up()' before calling 'run()'." + ) + + diversity_sorted = self._greedy_diversity_order(query=query, documents=documents) + + return {"documents": diversity_sorted[:top_k]} diff --git a/releasenotes/notes/add-diversity-ranker-6ecee21134eda673.yaml b/releasenotes/notes/add-diversity-ranker-6ecee21134eda673.yaml new file mode 100644 index 0000000000..b8a65c10fb --- /dev/null +++ b/releasenotes/notes/add-diversity-ranker-6ecee21134eda673.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + Add `DiversityRanker`. Diversity Ranker orders documents in such a way as to maximize the overall diversity of the given documents. The ranker leverages sentence-transformer models to calculate semantic embeddings for each document and the query. diff --git a/test/components/rankers/test_diversity.py b/test/components/rankers/test_diversity.py new file mode 100644 index 0000000000..11f7f01f8f --- /dev/null +++ b/test/components/rankers/test_diversity.py @@ -0,0 +1,344 @@ +import pytest + +from haystack import ComponentError, Document +from haystack.components.rankers import DiversityRanker +from haystack.utils import ComponentDevice +from haystack.utils.auth import Secret + + +class TestDiversityRanker: + def test_init(self): + component = DiversityRanker() + assert component.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2" + assert component.top_k == 10 + assert component.device == ComponentDevice.resolve_device(None) + assert component.similarity == "dot_product" + assert component.token == Secret.from_env_var("HF_API_TOKEN", strict=False) + assert component.prefix == "" + assert component.suffix == "" + assert component.meta_fields_to_embed == [] + assert component.embedding_separator == "\n" + + def test_init_with_custom_init_parameters(self): + component = DiversityRanker( + model="sentence-transformers/msmarco-distilbert-base-v4", + top_k=5, + device=ComponentDevice.from_str("cuda:0"), + token=Secret.from_token("fake-api-token"), + similarity="cosine", + prefix="query:", + suffix="document:", + meta_fields_to_embed=["meta_field"], + embedding_separator="--", + ) + assert component.model_name_or_path == "sentence-transformers/msmarco-distilbert-base-v4" + assert component.top_k == 5 + assert component.device == ComponentDevice.from_str("cuda:0") + assert component.similarity == "cosine" + assert component.token == Secret.from_token("fake-api-token") + assert component.prefix == "query:" + assert component.suffix == "document:" + assert component.meta_fields_to_embed == ["meta_field"] + assert component.embedding_separator == "--" + + def test_to_and_from_dict(self): + component = DiversityRanker() + data = component.to_dict() + assert data == { + "type": "haystack.components.rankers.diversity.DiversityRanker", + "init_parameters": { + "model": "sentence-transformers/all-MiniLM-L6-v2", + "top_k": 10, + "device": ComponentDevice.resolve_device(None).to_dict(), + "similarity": "dot_product", + "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, + "prefix": "", + "suffix": "", + "meta_fields_to_embed": [], + "embedding_separator": "\n", + }, + } + + ranker = DiversityRanker.from_dict(data) + + assert ranker.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2" + assert ranker.top_k == 10 + assert ranker.device == ComponentDevice.resolve_device(None) + assert ranker.similarity == "dot_product" + assert ranker.token == Secret.from_env_var("HF_API_TOKEN", strict=False) + assert ranker.prefix == "" + assert ranker.suffix == "" + assert ranker.meta_fields_to_embed == [] + assert ranker.embedding_separator == "\n" + + def test_to_and_from_dict_with_custom_init_parameters(self): + component = DiversityRanker( + model="sentence-transformers/msmarco-distilbert-base-v4", + top_k=5, + device=ComponentDevice.from_str("cuda:0"), + token=Secret.from_env_var("ENV_VAR", strict=False), + similarity="cosine", + prefix="query:", + suffix="document:", + meta_fields_to_embed=["meta_field"], + embedding_separator="--", + ) + data = component.to_dict() + assert data == { + "type": "haystack.components.rankers.diversity.DiversityRanker", + "init_parameters": { + "model": "sentence-transformers/msmarco-distilbert-base-v4", + "top_k": 5, + "device": ComponentDevice.from_str("cuda:0").to_dict(), + "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, + "similarity": "cosine", + "prefix": "query:", + "suffix": "document:", + "meta_fields_to_embed": ["meta_field"], + "embedding_separator": "--", + }, + } + + ranker = DiversityRanker.from_dict(data) + + assert ranker.model_name_or_path == "sentence-transformers/msmarco-distilbert-base-v4" + assert ranker.top_k == 5 + assert ranker.device == ComponentDevice.from_str("cuda:0") + assert ranker.similarity == "cosine" + assert ranker.token == Secret.from_env_var("ENV_VAR", strict=False) + assert ranker.prefix == "query:" + assert ranker.suffix == "document:" + assert ranker.meta_fields_to_embed == ["meta_field"] + assert ranker.embedding_separator == "--" + + def test_run_incorrect_similarity(self): + """ + Tests that run method raises ValueError if similarity is incorrect + """ + with pytest.raises(ValueError): + DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity="incorrect") + + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_run_without_warm_up(self, similarity): + """ + Tests that run method raises ComponentError if model is not warmed up + """ + ranker = DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", top_k=1, similarity=similarity) + documents = [Document(content="doc1"), Document(content="doc2")] + + with pytest.raises(ComponentError): + ranker.run(query="test query", documents=documents) + + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_run_empty_query(self, similarity): + """ + Test that run method raises ValueError if query is empty or None. + """ + ranker = DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", top_k=3, similarity=similarity) + ranker.warm_up() + documents = [Document(content="doc1"), Document(content="doc2")] + + with pytest.raises(ValueError): + ranker.run(query="", documents=documents) + + with pytest.raises(ValueError): + ranker.run(query=None, documents=documents) + + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_run_top_k(self, similarity): + """ + Test that run method returns the correct number of documents for different top_k values passed at + initialization and runtime. + """ + ranker = DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=3) + ranker.warm_up() + query = "test query" + documents = [ + Document(content="doc1"), + Document(content="doc2"), + Document(content="doc3"), + Document(content="doc4"), + ] + + result = ranker.run(query=query, documents=documents) + ranked_docs = result["documents"] + + assert isinstance(ranked_docs, list) + assert len(ranked_docs) == 3 + assert all(isinstance(doc, Document) for doc in ranked_docs) + + # Passing a different top_k at runtime + result = ranker.run(query=query, documents=documents, top_k=2) + ranked_docs = result["documents"] + + assert isinstance(ranked_docs, list) + assert len(ranked_docs) == 2 + assert all(isinstance(doc, Document) for doc in ranked_docs) + + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_diversity_ranker_negative_top_k(self, similarity): + """ + Tests that run method raises an error for negative top-k. + """ + ranker = DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=10) + ranker.warm_up() + query = "test" + documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] + + # Setting top_k at runtime + with pytest.raises(ValueError): + ranker.run(query=query, documents=documents, top_k=-5) + + # Setting top_k at init + with pytest.raises(ValueError): + DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=-5) + + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_diversity_ranker_top_k_is_none(self, similarity): + """ + Tests that run method returns the correct order of documents for top-k set to None. + """ + # Setting top_k to None at init should raise error + with pytest.raises(ValueError): + DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=None) + + # Setting top_k to None is ignored during runtime, it should use top_k set at init. + ranker = DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=2) + ranker.warm_up() + query = "test" + documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] + result = ranker.run(query=query, documents=documents, top_k=None) + + assert len(result["documents"]) == 2 + + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_run_with_less_documents_than_top_k(self, similarity): + """ + Tests that run method returns the correct number of documents for top_k values greater than number of documents. + """ + ranker = DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=5) + ranker.warm_up() + query = "test" + documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] + result = ranker.run(query=query, documents=documents) + + assert len(result["documents"]) == 3 + + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_run_single_document_corner_case(self, similarity): + """ + Tests that run method returns the correct number of documents for a single document + """ + ranker = DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity) + ranker.warm_up() + query = "test" + documents = [Document(content="doc1")] + result = ranker.run(query=query, documents=documents) + + assert len(result["documents"]) == 1 + + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_run_no_documents_provided(self, similarity): + """ + Test that run method returns an empty list if no documents are supplied. + """ + ranker = DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity) + ranker.warm_up() + query = "test query" + documents = [] + results = ranker.run(query=query, documents=documents) + + assert len(results["documents"]) == 0 + + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_run(self, similarity): + """ + Tests that run method returns documents in the correct order + """ + ranker = DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity) + ranker.warm_up() + query = "city" + documents = [ + Document(content="France"), + Document(content="Germany"), + Document(content="Eiffel Tower"), + Document(content="Berlin"), + Document(content="Bananas"), + Document(content="Silicon Valley"), + Document(content="Brandenburg Gate"), + ] + result = ranker.run(query=query, documents=documents) + ranked_docs = result["documents"] + ranked_order = ", ".join([doc.content for doc in ranked_docs]) + expected_order = "Berlin, Bananas, Eiffel Tower, Silicon Valley, France, Brandenburg Gate, Germany" + + assert ranked_order == expected_order + + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_run_real_world_use_case(self, similarity): + ranker = DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity) + ranker.warm_up() + query = "What are the reasons for long-standing animosities between Russia and Poland?" + + doc1 = Document( + "One of the earliest known events in Russian-Polish history dates back to 981, when the Grand Prince of Kiev , " + "Vladimir Svyatoslavich , seized the Cherven Cities from the Duchy of Poland . The relationship between two by " + "that time was mostly close and cordial, as there had been no serious wars between both. In 966, Poland " + "accepted Christianity from Rome while Kievan Rus' —the ancestor of Russia, Ukraine and Belarus—was " + "Christianized by Constantinople. In 1054, the internal Christian divide formally split the Church into " + "the Catholic and Orthodox branches separating the Poles from the Eastern Slavs." + ) + doc2 = Document( + "Since the fall of the Soviet Union , with Lithuania , Ukraine and Belarus regaining independence, the " + "Polish Russian border has mostly been replaced by borders with the respective countries, but there still " + "is a 210 km long border between Poland and the Kaliningrad Oblast" + ) + doc3 = Document( + "As part of Poland's plans to become fully energy independent from Russia within the next years, Piotr " + "Wozniak, president of state-controlled oil and gas company PGNiG , stated in February 2019: 'The strategy of " + "the company is just to forget about Eastern suppliers and especially about Gazprom .'[53] In 2020, the " + "Stockholm Arbitrary Tribunal ruled that PGNiG's long-term contract gas price with Gazprom linked to oil prices " + "should be changed to approximate the Western European gas market price, backdated to 1 November 2014 when " + "PGNiG requested a price review under the contract. Gazprom had to refund about $1.5 billion to PGNiG." + ) + doc4 = Document( + "Both Poland and Russia had accused each other for their historical revisionism . Russia has repeatedly " + "accused Poland for not honoring Soviet Red Army soldiers fallen in World War II for Poland, notably in " + "2017, in which Poland was thought on 'attempting to impose its own version of history' after Moscow was " + "not allowed to join an international effort to renovate a World War II museum at Sobibór , site of a " + "notorious Sobibor extermination camp." + ) + doc5 = Document( + "President of Russia Vladimir Putin and Prime Minister of Poland Leszek Miller in 2002 Modern Polish Russian " + "relations begin with the fall of communism in1989 in Poland ( Solidarity and the Polish Round Table " + "Agreement ) and 1991 in Russia ( dissolution of the Soviet Union ). With a new democratic government after " + "the 1989 elections , Poland regained full sovereignty, [2] and what was the Soviet Union, became 15 newly " + "independent states , including the Russian Federation . Relations between modern Poland and Russia suffer " + "from constant ups and downs." + ) + doc6 = Document( + "Soviet influence in Poland finally ended with the Round Table Agreement of 1989 guaranteeing free elections " + "in Poland, the Revolutions of 1989 against Soviet-sponsored Communist governments in the Eastern Block , and " + "finally the formal dissolution of the Warsaw Pact." + ) + doc7 = Document( + "Dmitry Medvedev and then Polish Prime Minister Donald Tusk , 6 December 2010 BBC News reported that one of " + "the main effects of the 2010 Polish Air Force Tu-154 crash would be the impact it has on Russian-Polish " + "relations. [38] It was thought if the inquiry into the crash were not transparent, it would increase " + "suspicions toward Russia in Poland." + ) + doc8 = Document( + "Soviet control over the Polish People's Republic lessened after Stalin's death and Gomulka's Thaw , and " + "ceased completely after the fall of the communist government in Poland in late 1989, although the " + "Soviet-Russian Northern Group of Forces did not leave Polish soil until 1993. The continuing Soviet military " + "presence allowed the Soviet Union to heavily influence Polish politics." + ) + + documents = [doc1, doc2, doc3, doc4, doc5, doc6, doc7, doc8] + result = ranker.run(query=query, documents=documents) + expected_order = [doc5, doc7, doc3, doc1, doc4, doc2, doc6, doc8] + expected_content = " ".join([doc.content or "" for doc in expected_order]) + result_content = " ".join([doc.content or "" for doc in result["documents"]]) + + # Check the order of ranked documents by comparing the content of the ranked documents + assert result_content == expected_content From 0aed48c1c55b557d4a1e7629f4e7fbe7fe72671e Mon Sep 17 00:00:00 2001 From: awinml <97467100+awinml@users.noreply.github.com> Date: Mon, 26 Feb 2024 17:49:51 +0530 Subject: [PATCH 2/8] Update tests --- test/components/rankers/test_diversity.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/components/rankers/test_diversity.py b/test/components/rankers/test_diversity.py index 11f7f01f8f..1d21b37e48 100644 --- a/test/components/rankers/test_diversity.py +++ b/test/components/rankers/test_diversity.py @@ -129,6 +129,7 @@ def test_run_without_warm_up(self, similarity): with pytest.raises(ComponentError): ranker.run(query="test query", documents=documents) + @pytest.mark.integration @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) def test_run_empty_query(self, similarity): """ @@ -144,6 +145,7 @@ def test_run_empty_query(self, similarity): with pytest.raises(ValueError): ranker.run(query=None, documents=documents) + @pytest.mark.integration @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) def test_run_top_k(self, similarity): """ @@ -175,6 +177,7 @@ def test_run_top_k(self, similarity): assert len(ranked_docs) == 2 assert all(isinstance(doc, Document) for doc in ranked_docs) + @pytest.mark.integration @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) def test_diversity_ranker_negative_top_k(self, similarity): """ @@ -193,6 +196,7 @@ def test_diversity_ranker_negative_top_k(self, similarity): with pytest.raises(ValueError): DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=-5) + @pytest.mark.integration @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) def test_diversity_ranker_top_k_is_none(self, similarity): """ @@ -211,6 +215,7 @@ def test_diversity_ranker_top_k_is_none(self, similarity): assert len(result["documents"]) == 2 + @pytest.mark.integration @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) def test_run_with_less_documents_than_top_k(self, similarity): """ @@ -224,6 +229,7 @@ def test_run_with_less_documents_than_top_k(self, similarity): assert len(result["documents"]) == 3 + @pytest.mark.integration @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) def test_run_single_document_corner_case(self, similarity): """ @@ -237,6 +243,7 @@ def test_run_single_document_corner_case(self, similarity): assert len(result["documents"]) == 1 + @pytest.mark.integration @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) def test_run_no_documents_provided(self, similarity): """ @@ -250,6 +257,7 @@ def test_run_no_documents_provided(self, similarity): assert len(results["documents"]) == 0 + @pytest.mark.integration @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) def test_run(self, similarity): """ @@ -274,6 +282,7 @@ def test_run(self, similarity): assert ranked_order == expected_order + @pytest.mark.integration @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) def test_run_real_world_use_case(self, similarity): ranker = DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity) From 9955ada1fb1a803d0dbb0dfcf445630645387634 Mon Sep 17 00:00:00 2001 From: awinml <97467100+awinml@users.noreply.github.com> Date: Mon, 26 Feb 2024 21:18:08 +0530 Subject: [PATCH 3/8] Add separate suffix, prefix params for query and documents; allow empty query --- haystack/components/rankers/diversity.py | 37 ++++++++------ test/components/rankers/test_diversity.py | 59 +++++++++++++++-------- 2 files changed, 61 insertions(+), 35 deletions(-) diff --git a/haystack/components/rankers/diversity.py b/haystack/components/rankers/diversity.py index 4251f288a9..e8b66e137c 100644 --- a/haystack/components/rankers/diversity.py +++ b/haystack/components/rankers/diversity.py @@ -42,8 +42,10 @@ def __init__( device: Optional[ComponentDevice] = None, token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False), similarity: Literal["dot_product", "cosine"] = "dot_product", - prefix: str = "", - suffix: str = "", + query_prefix: str = "", + document_prefix: str = "", + query_suffix: str = "", + document_suffix: str = "", meta_fields_to_embed: Optional[List[str]] = None, embedding_separator: str = "\n", ): @@ -58,10 +60,14 @@ def __init__( :param token: The API token used to download private models from Hugging Face. :param similarity: Similarity metric for comparing embeddings. Can be set to "dot_product" (default) or "cosine". - :param prefix: A string to add to the beginning of each Document text before embedding. + :param query_prefix: A string to add to the beginning of the query text before ranking. + Can be used to prepend the text with an instruction, as required by some re-ranking models, + such as E5 and BGE. + :param document_prefix: A string to add to the beginning of each Document text before ranking. Can be used to prepend the text with an instruction, as required by some embedding models, - such as E5 and bge. - :param suffix: A string to add to the end of each Document text before embedding. + such as E5 and BGE. + :param query_suffix: A string to add to the end of the query text before ranking. + :param document_suffix: A string to add to the end of each Document text before ranking. :param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content. :param embedding_separator: Separator used to concatenate the meta fields to the Document content. """ @@ -77,8 +83,10 @@ def __init__( if similarity not in ["dot_product", "cosine"]: raise ValueError(f"Similarity must be one of 'dot_product' or 'cosine', but got {similarity}.") self.similarity = similarity - self.prefix = prefix - self.suffix = suffix + self.query_prefix = query_prefix + self.document_prefix = document_prefix + self.query_suffix = query_suffix + self.document_suffix = document_suffix self.meta_fields_to_embed = meta_fields_to_embed or [] self.embedding_separator = embedding_separator @@ -104,8 +112,10 @@ def to_dict(self) -> Dict[str, Any]: token=self.token.to_dict() if self.token else None, top_k=self.top_k, similarity=self.similarity, - prefix=self.prefix, - suffix=self.suffix, + query_prefix=self.query_prefix, + document_prefix=self.document_prefix, + query_suffix=self.query_suffix, + document_suffix=self.document_suffix, meta_fields_to_embed=self.meta_fields_to_embed, embedding_separator=self.embedding_separator, ) @@ -143,13 +153,15 @@ def _greedy_diversity_order(self, query: str, documents: List[Document]) -> List str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] ] text_to_embed = ( - self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + self.suffix + self.document_prefix + + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + + self.document_suffix ) texts_to_embed.append(text_to_embed) # Calculate embeddings doc_embeddings = self.model.encode(texts_to_embed, convert_to_tensor=True) # type: ignore[attr-defined] - query_embedding = self.model.encode([query], convert_to_tensor=True) # type: ignore[attr-defined] + query_embedding = self.model.encode([self.query_prefix + query + self.query_suffix], convert_to_tensor=True) # type: ignore[attr-defined] # Normalize embeddings to unit length for computing cosine similarity if self.similarity == "cosine": @@ -194,9 +206,6 @@ def run(self, query: str, documents: List[Document], top_k: Optional[int] = None :return: A list of top_k documents ranked based on diversity. """ - if query is None or len(query) == 0: - raise ValueError("Query is empty") - if not documents: return {"documents": []} diff --git a/test/components/rankers/test_diversity.py b/test/components/rankers/test_diversity.py index 1d21b37e48..68cbcd3bfe 100644 --- a/test/components/rankers/test_diversity.py +++ b/test/components/rankers/test_diversity.py @@ -14,8 +14,10 @@ def test_init(self): assert component.device == ComponentDevice.resolve_device(None) assert component.similarity == "dot_product" assert component.token == Secret.from_env_var("HF_API_TOKEN", strict=False) - assert component.prefix == "" - assert component.suffix == "" + assert component.query_prefix == "" + assert component.document_prefix == "" + assert component.query_suffix == "" + assert component.document_suffix == "" assert component.meta_fields_to_embed == [] assert component.embedding_separator == "\n" @@ -26,8 +28,10 @@ def test_init_with_custom_init_parameters(self): device=ComponentDevice.from_str("cuda:0"), token=Secret.from_token("fake-api-token"), similarity="cosine", - prefix="query:", - suffix="document:", + query_prefix="query:", + document_prefix="document:", + query_suffix="query suffix", + document_suffix="document suffix", meta_fields_to_embed=["meta_field"], embedding_separator="--", ) @@ -36,8 +40,10 @@ def test_init_with_custom_init_parameters(self): assert component.device == ComponentDevice.from_str("cuda:0") assert component.similarity == "cosine" assert component.token == Secret.from_token("fake-api-token") - assert component.prefix == "query:" - assert component.suffix == "document:" + assert component.query_prefix == "query:" + assert component.document_prefix == "document:" + assert component.query_suffix == "query suffix" + assert component.document_suffix == "document suffix" assert component.meta_fields_to_embed == ["meta_field"] assert component.embedding_separator == "--" @@ -52,8 +58,10 @@ def test_to_and_from_dict(self): "device": ComponentDevice.resolve_device(None).to_dict(), "similarity": "dot_product", "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, - "prefix": "", - "suffix": "", + "query_prefix": "", + "document_prefix": "", + "query_suffix": "", + "document_suffix": "", "meta_fields_to_embed": [], "embedding_separator": "\n", }, @@ -66,8 +74,10 @@ def test_to_and_from_dict(self): assert ranker.device == ComponentDevice.resolve_device(None) assert ranker.similarity == "dot_product" assert ranker.token == Secret.from_env_var("HF_API_TOKEN", strict=False) - assert ranker.prefix == "" - assert ranker.suffix == "" + assert ranker.query_prefix == "" + assert ranker.document_prefix == "" + assert ranker.query_suffix == "" + assert ranker.document_suffix == "" assert ranker.meta_fields_to_embed == [] assert ranker.embedding_separator == "\n" @@ -78,8 +88,10 @@ def test_to_and_from_dict_with_custom_init_parameters(self): device=ComponentDevice.from_str("cuda:0"), token=Secret.from_env_var("ENV_VAR", strict=False), similarity="cosine", - prefix="query:", - suffix="document:", + query_prefix="query:", + document_prefix="document:", + query_suffix="query suffix", + document_suffix="document suffix", meta_fields_to_embed=["meta_field"], embedding_separator="--", ) @@ -92,8 +104,10 @@ def test_to_and_from_dict_with_custom_init_parameters(self): "device": ComponentDevice.from_str("cuda:0").to_dict(), "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, "similarity": "cosine", - "prefix": "query:", - "suffix": "document:", + "query_prefix": "query:", + "document_prefix": "document:", + "query_suffix": "query suffix", + "document_suffix": "document suffix", "meta_fields_to_embed": ["meta_field"], "embedding_separator": "--", }, @@ -106,8 +120,10 @@ def test_to_and_from_dict_with_custom_init_parameters(self): assert ranker.device == ComponentDevice.from_str("cuda:0") assert ranker.similarity == "cosine" assert ranker.token == Secret.from_env_var("ENV_VAR", strict=False) - assert ranker.prefix == "query:" - assert ranker.suffix == "document:" + assert ranker.query_prefix == "query:" + assert ranker.document_prefix == "document:" + assert ranker.query_suffix == "query suffix" + assert ranker.document_suffix == "document suffix" assert ranker.meta_fields_to_embed == ["meta_field"] assert ranker.embedding_separator == "--" @@ -133,17 +149,18 @@ def test_run_without_warm_up(self, similarity): @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) def test_run_empty_query(self, similarity): """ - Test that run method raises ValueError if query is empty or None. + Test that ranker can be run with an empty query. """ ranker = DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", top_k=3, similarity=similarity) ranker.warm_up() documents = [Document(content="doc1"), Document(content="doc2")] - with pytest.raises(ValueError): - ranker.run(query="", documents=documents) + result = ranker.run(query="", documents=documents) + ranked_docs = result["documents"] - with pytest.raises(ValueError): - ranker.run(query=None, documents=documents) + assert isinstance(ranked_docs, list) + assert len(ranked_docs) == 2 + assert all(isinstance(doc, Document) for doc in ranked_docs) @pytest.mark.integration @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) From 5ab2a327a9961cb3922bc842195ab1050ec1f4f5 Mon Sep 17 00:00:00 2001 From: awinml <97467100+awinml@users.noreply.github.com> Date: Tue, 27 Feb 2024 00:15:35 +0530 Subject: [PATCH 4/8] Update docstrings --- haystack/components/rankers/diversity.py | 36 ++++++++++++++++-------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/haystack/components/rankers/diversity.py b/haystack/components/rankers/diversity.py index e8b66e137c..0047d7b272 100644 --- a/haystack/components/rankers/diversity.py +++ b/haystack/components/rankers/diversity.py @@ -18,7 +18,10 @@ class DiversityRanker: """ Implements a document ranking algorithm that orders documents in such a way as to maximize the overall diversity of the documents. - It uses a pre-trained Sentence Transformers model to embed the query and the Documents. + + This component provides functionality to rank a list of documents based on their similarity with respect to the + query to maximize the overall diversity. It uses a pre-trained Sentence Transformers model to embed the query and + the Documents. Usage example: ```python @@ -26,12 +29,12 @@ class DiversityRanker: from haystack.components.rankers import DiversityRanker ranker = DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity="cosine") + ranker.warm_up() + docs = [Document(content="Paris"), Document(content="Berlin")] query = "What is the capital of germany?" output = ranker.run(query=query, documents=docs) docs = output["documents"] - assert len(docs) == 2 - assert docs[0].content == "Paris" ``` """ @@ -92,7 +95,7 @@ def __init__( def warm_up(self): """ - Warm up the model used for scoring the Documents. + Initializes the component. """ if self.model is None: self.model = SentenceTransformer( @@ -103,7 +106,10 @@ def warm_up(self): def to_dict(self) -> Dict[str, Any]: """ - Serialize this component to a dictionary. + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. """ return default_to_dict( self, @@ -123,7 +129,12 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data: Dict[str, Any]) -> "DiversityRanker": """ - Deserialize this component from a dictionary. + Deserializes the component from a dictionary. + + :param data: + The dictionary to deserialize from. + :returns: + The deserialized component. """ serialized_device = data["init_parameters"]["device"] data["init_parameters"]["device"] = ComponentDevice.from_dict(serialized_device) @@ -198,13 +209,16 @@ def _greedy_diversity_order(self, query: str, documents: List[Document]) -> List @component.output_types(documents=List[Document]) def run(self, query: str, documents: List[Document], top_k: Optional[int] = None): """ - Rank the documents based on their diversity and return the top_k documents. + Rank the documents based on their diversity. + + :param query: The search query. + :param documents: List of Document objects to be ranker. + :param top_k: Optional. An integer to override the top_k set during initialization. - :param query: The query. - :param documents: A list of Document objects that should be ranked. - :param top_k: The maximum number of documents to return. + :returns: A dictionary with the following key: + - `documents`: List of Document objects that have been selected based on the diversity ranking. - :return: A list of top_k documents ranked based on diversity. + :raises ValueError: If the top_k value is less than or equal to 0. """ if not documents: return {"documents": []} From 0d923b211ab4535fb4711ec716ed59bf8b3c7e1b Mon Sep 17 00:00:00 2001 From: awinml <97467100+awinml@users.noreply.github.com> Date: Thu, 7 Mar 2024 14:34:30 +0530 Subject: [PATCH 5/8] Make changes based on review --- haystack/components/rankers/__init__.py | 9 +- ....py => sentence_transformers_diversity.py} | 25 +++--- ...> test_sentence_transformers_diversity.py} | 90 ++++++++++++------- 3 files changed, 76 insertions(+), 48 deletions(-) rename haystack/components/rankers/{diversity.py => sentence_transformers_diversity.py} (93%) rename test/components/rankers/{test_diversity.py => test_sentence_transformers_diversity.py} (85%) diff --git a/haystack/components/rankers/__init__.py b/haystack/components/rankers/__init__.py index 330b0576fa..282cf5cf2f 100644 --- a/haystack/components/rankers/__init__.py +++ b/haystack/components/rankers/__init__.py @@ -1,6 +1,11 @@ -from haystack.components.rankers.diversity import DiversityRanker from haystack.components.rankers.lost_in_the_middle import LostInTheMiddleRanker from haystack.components.rankers.meta_field import MetaFieldRanker +from haystack.components.rankers.sentence_transformers_diversity import SentenceTransformersDiversityRanker from haystack.components.rankers.transformers_similarity import TransformersSimilarityRanker -__all__ = ["DiversityRanker", "LostInTheMiddleRanker", "MetaFieldRanker", "TransformersSimilarityRanker"] +__all__ = [ + "LostInTheMiddleRanker", + "MetaFieldRanker", + "SentenceTransformersDiversityRanker", + "TransformersSimilarityRanker", +] diff --git a/haystack/components/rankers/diversity.py b/haystack/components/rankers/sentence_transformers_diversity.py similarity index 93% rename from haystack/components/rankers/diversity.py rename to haystack/components/rankers/sentence_transformers_diversity.py index 0047d7b272..382daaa661 100644 --- a/haystack/components/rankers/diversity.py +++ b/haystack/components/rankers/sentence_transformers_diversity.py @@ -1,20 +1,19 @@ -import logging from typing import Any, Dict, List, Literal, Optional -from haystack import ComponentError, Document, component, default_from_dict, default_to_dict +from haystack import ComponentError, Document, component, default_from_dict, default_to_dict, logging from haystack.lazy_imports import LazyImport from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace logger = logging.getLogger(__name__) -with LazyImport(message="Run 'pip install \"sentence-transformers>=2.2.0\"'") as torch_and_transformers_import: +with LazyImport(message="Run 'pip install \"sentence-transformers>=2.2.0\"'") as torch_and_sentence_transformers_import: import torch from sentence_transformers import SentenceTransformer @component -class DiversityRanker: +class SentenceTransformersDiversityRanker: """ Implements a document ranking algorithm that orders documents in such a way as to maximize the overall diversity of the documents. @@ -26,9 +25,9 @@ class DiversityRanker: Usage example: ```python from haystack import Document - from haystack.components.rankers import DiversityRanker + from haystack.components.rankers import SentenceTransformersDiversityRanker - ranker = DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity="cosine") + ranker = SentenceTransformersDiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity="cosine") ranker.warm_up() docs = [Document(content="Paris"), Document(content="Berlin")] @@ -44,16 +43,16 @@ def __init__( top_k: int = 10, device: Optional[ComponentDevice] = None, token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False), - similarity: Literal["dot_product", "cosine"] = "dot_product", + similarity: Literal["dot_product", "cosine"] = "cosine", query_prefix: str = "", - document_prefix: str = "", query_suffix: str = "", + document_prefix: str = "", document_suffix: str = "", meta_fields_to_embed: Optional[List[str]] = None, embedding_separator: str = "\n", ): """ - Initialize a DiversityRanker. + Initialize a SentenceTransformersDiversityRanker. :param model: Local path or name of the model in Hugging Face's model hub, such as `'sentence-transformers/all-MiniLM-L6-v2'`. @@ -64,17 +63,17 @@ def __init__( :param similarity: Similarity metric for comparing embeddings. Can be set to "dot_product" (default) or "cosine". :param query_prefix: A string to add to the beginning of the query text before ranking. - Can be used to prepend the text with an instruction, as required by some re-ranking models, + Can be used to prepend the text with an instruction, as required by some embedding models, such as E5 and BGE. + :param query_suffix: A string to add to the end of the query text before ranking. :param document_prefix: A string to add to the beginning of each Document text before ranking. Can be used to prepend the text with an instruction, as required by some embedding models, such as E5 and BGE. - :param query_suffix: A string to add to the end of the query text before ranking. :param document_suffix: A string to add to the end of each Document text before ranking. :param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content. :param embedding_separator: Separator used to concatenate the meta fields to the Document content. """ - torch_and_transformers_import.check() + torch_and_sentence_transformers_import.check() self.model_name_or_path = model if top_k is None or top_k <= 0: @@ -127,7 +126,7 @@ def to_dict(self) -> Dict[str, Any]: ) @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "DiversityRanker": + def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDiversityRanker": """ Deserializes the component from a dictionary. diff --git a/test/components/rankers/test_diversity.py b/test/components/rankers/test_sentence_transformers_diversity.py similarity index 85% rename from test/components/rankers/test_diversity.py rename to test/components/rankers/test_sentence_transformers_diversity.py index 68cbcd3bfe..6f3780c8dd 100644 --- a/test/components/rankers/test_diversity.py +++ b/test/components/rankers/test_sentence_transformers_diversity.py @@ -1,18 +1,18 @@ import pytest from haystack import ComponentError, Document -from haystack.components.rankers import DiversityRanker +from haystack.components.rankers import SentenceTransformersDiversityRanker from haystack.utils import ComponentDevice from haystack.utils.auth import Secret -class TestDiversityRanker: +class TestSentenceTransformersDiversityRanker: def test_init(self): - component = DiversityRanker() + component = SentenceTransformersDiversityRanker() assert component.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2" assert component.top_k == 10 assert component.device == ComponentDevice.resolve_device(None) - assert component.similarity == "dot_product" + assert component.similarity == "cosine" assert component.token == Secret.from_env_var("HF_API_TOKEN", strict=False) assert component.query_prefix == "" assert component.document_prefix == "" @@ -22,12 +22,12 @@ def test_init(self): assert component.embedding_separator == "\n" def test_init_with_custom_init_parameters(self): - component = DiversityRanker( + component = SentenceTransformersDiversityRanker( model="sentence-transformers/msmarco-distilbert-base-v4", top_k=5, device=ComponentDevice.from_str("cuda:0"), token=Secret.from_token("fake-api-token"), - similarity="cosine", + similarity="dot_product", query_prefix="query:", document_prefix="document:", query_suffix="query suffix", @@ -38,7 +38,7 @@ def test_init_with_custom_init_parameters(self): assert component.model_name_or_path == "sentence-transformers/msmarco-distilbert-base-v4" assert component.top_k == 5 assert component.device == ComponentDevice.from_str("cuda:0") - assert component.similarity == "cosine" + assert component.similarity == "dot_product" assert component.token == Secret.from_token("fake-api-token") assert component.query_prefix == "query:" assert component.document_prefix == "document:" @@ -48,15 +48,15 @@ def test_init_with_custom_init_parameters(self): assert component.embedding_separator == "--" def test_to_and_from_dict(self): - component = DiversityRanker() + component = SentenceTransformersDiversityRanker() data = component.to_dict() assert data == { - "type": "haystack.components.rankers.diversity.DiversityRanker", + "type": "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker", "init_parameters": { "model": "sentence-transformers/all-MiniLM-L6-v2", "top_k": 10, "device": ComponentDevice.resolve_device(None).to_dict(), - "similarity": "dot_product", + "similarity": "cosine", "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, "query_prefix": "", "document_prefix": "", @@ -67,12 +67,12 @@ def test_to_and_from_dict(self): }, } - ranker = DiversityRanker.from_dict(data) + ranker = SentenceTransformersDiversityRanker.from_dict(data) assert ranker.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2" assert ranker.top_k == 10 assert ranker.device == ComponentDevice.resolve_device(None) - assert ranker.similarity == "dot_product" + assert ranker.similarity == "cosine" assert ranker.token == Secret.from_env_var("HF_API_TOKEN", strict=False) assert ranker.query_prefix == "" assert ranker.document_prefix == "" @@ -82,12 +82,12 @@ def test_to_and_from_dict(self): assert ranker.embedding_separator == "\n" def test_to_and_from_dict_with_custom_init_parameters(self): - component = DiversityRanker( + component = SentenceTransformersDiversityRanker( model="sentence-transformers/msmarco-distilbert-base-v4", top_k=5, device=ComponentDevice.from_str("cuda:0"), token=Secret.from_env_var("ENV_VAR", strict=False), - similarity="cosine", + similarity="dot_product", query_prefix="query:", document_prefix="document:", query_suffix="query suffix", @@ -97,13 +97,13 @@ def test_to_and_from_dict_with_custom_init_parameters(self): ) data = component.to_dict() assert data == { - "type": "haystack.components.rankers.diversity.DiversityRanker", + "type": "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker", "init_parameters": { "model": "sentence-transformers/msmarco-distilbert-base-v4", "top_k": 5, "device": ComponentDevice.from_str("cuda:0").to_dict(), "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, - "similarity": "cosine", + "similarity": "dot_product", "query_prefix": "query:", "document_prefix": "document:", "query_suffix": "query suffix", @@ -113,12 +113,12 @@ def test_to_and_from_dict_with_custom_init_parameters(self): }, } - ranker = DiversityRanker.from_dict(data) + ranker = SentenceTransformersDiversityRanker.from_dict(data) assert ranker.model_name_or_path == "sentence-transformers/msmarco-distilbert-base-v4" assert ranker.top_k == 5 assert ranker.device == ComponentDevice.from_str("cuda:0") - assert ranker.similarity == "cosine" + assert ranker.similarity == "dot_product" assert ranker.token == Secret.from_env_var("ENV_VAR", strict=False) assert ranker.query_prefix == "query:" assert ranker.document_prefix == "document:" @@ -132,14 +132,16 @@ def test_run_incorrect_similarity(self): Tests that run method raises ValueError if similarity is incorrect """ with pytest.raises(ValueError): - DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity="incorrect") + SentenceTransformersDiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity="incorrect") @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) def test_run_without_warm_up(self, similarity): """ Tests that run method raises ComponentError if model is not warmed up """ - ranker = DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", top_k=1, similarity=similarity) + ranker = SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", top_k=1, similarity=similarity + ) documents = [Document(content="doc1"), Document(content="doc2")] with pytest.raises(ComponentError): @@ -151,7 +153,9 @@ def test_run_empty_query(self, similarity): """ Test that ranker can be run with an empty query. """ - ranker = DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", top_k=3, similarity=similarity) + ranker = SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", top_k=3, similarity=similarity + ) ranker.warm_up() documents = [Document(content="doc1"), Document(content="doc2")] @@ -169,7 +173,9 @@ def test_run_top_k(self, similarity): Test that run method returns the correct number of documents for different top_k values passed at initialization and runtime. """ - ranker = DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=3) + ranker = SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=3 + ) ranker.warm_up() query = "test query" documents = [ @@ -196,11 +202,13 @@ def test_run_top_k(self, similarity): @pytest.mark.integration @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) - def test_diversity_ranker_negative_top_k(self, similarity): + def test_run_negative_top_k(self, similarity): """ Tests that run method raises an error for negative top-k. """ - ranker = DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=10) + ranker = SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=10 + ) ranker.warm_up() query = "test" documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] @@ -211,20 +219,26 @@ def test_diversity_ranker_negative_top_k(self, similarity): # Setting top_k at init with pytest.raises(ValueError): - DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=-5) + SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=-5 + ) @pytest.mark.integration @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) - def test_diversity_ranker_top_k_is_none(self, similarity): + def test_run_top_k_is_none(self, similarity): """ Tests that run method returns the correct order of documents for top-k set to None. """ # Setting top_k to None at init should raise error with pytest.raises(ValueError): - DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=None) + SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=None + ) # Setting top_k to None is ignored during runtime, it should use top_k set at init. - ranker = DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=2) + ranker = SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=2 + ) ranker.warm_up() query = "test" documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] @@ -238,7 +252,9 @@ def test_run_with_less_documents_than_top_k(self, similarity): """ Tests that run method returns the correct number of documents for top_k values greater than number of documents. """ - ranker = DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=5) + ranker = SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=5 + ) ranker.warm_up() query = "test" documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] @@ -252,7 +268,9 @@ def test_run_single_document_corner_case(self, similarity): """ Tests that run method returns the correct number of documents for a single document """ - ranker = DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity) + ranker = SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity + ) ranker.warm_up() query = "test" documents = [Document(content="doc1")] @@ -266,7 +284,9 @@ def test_run_no_documents_provided(self, similarity): """ Test that run method returns an empty list if no documents are supplied. """ - ranker = DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity) + ranker = SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity + ) ranker.warm_up() query = "test query" documents = [] @@ -280,7 +300,9 @@ def test_run(self, similarity): """ Tests that run method returns documents in the correct order """ - ranker = DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity) + ranker = SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity + ) ranker.warm_up() query = "city" documents = [ @@ -302,7 +324,9 @@ def test_run(self, similarity): @pytest.mark.integration @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) def test_run_real_world_use_case(self, similarity): - ranker = DiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity) + ranker = SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity + ) ranker.warm_up() query = "What are the reasons for long-standing animosities between Russia and Poland?" From b6cdbcbbea5b93b3a3c2957e3e90797047e0e575 Mon Sep 17 00:00:00 2001 From: awinml <97467100+awinml@users.noreply.github.com> Date: Mon, 11 Mar 2024 15:59:05 +0530 Subject: [PATCH 6/8] Add additional tests --- .../sentence_transformers_diversity.py | 37 +-- .../test_sentence_transformers_diversity.py | 215 ++++++++++++++---- 2 files changed, 198 insertions(+), 54 deletions(-) diff --git a/haystack/components/rankers/sentence_transformers_diversity.py b/haystack/components/rankers/sentence_transformers_diversity.py index 382daaa661..86915eb657 100644 --- a/haystack/components/rankers/sentence_transformers_diversity.py +++ b/haystack/components/rankers/sentence_transformers_diversity.py @@ -141,6 +141,24 @@ def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDiversityRanker deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) return default_from_dict(cls, data) + def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: + """ + Prepare the texts to embed by concatenating the Document text with the metadata fields to embed. + """ + texts_to_embed = [] + for doc in documents: + meta_values_to_embed = [ + str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] + ] + text_to_embed = ( + self.document_prefix + + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + + self.document_suffix + ) + texts_to_embed.append(text_to_embed) + + return texts_to_embed + def _greedy_diversity_order(self, query: str, documents: List[Document]) -> List[Document]: """ Orders the given list of documents to maximize diversity. @@ -156,18 +174,7 @@ def _greedy_diversity_order(self, query: str, documents: List[Document]) -> List :return: A list of documents ordered to maximize diversity. """ - - texts_to_embed = [] - for doc in documents: - meta_values_to_embed = [ - str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] - ] - text_to_embed = ( - self.document_prefix - + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) - + self.document_suffix - ) - texts_to_embed.append(text_to_embed) + texts_to_embed = self._prepare_texts_to_embed(documents) # Calculate embeddings doc_embeddings = self.model.encode(texts_to_embed, convert_to_tensor=True) # type: ignore[attr-defined] @@ -228,9 +235,11 @@ def run(self, query: str, documents: List[Document], top_k: Optional[int] = None raise ValueError(f"top_k must be > 0, but got {top_k}") if self.model is None: - raise ComponentError( - f"The component {self.__class__.__name__} wasn't warmed up. Run 'warm_up()' before calling 'run()'." + error_msg = ( + "The component SentenceTransformersDiversityRanker wasn't warmed up. " + "Run 'warm_up()' before calling 'run()'." ) + raise ComponentError(error_msg) diversity_sorted = self._greedy_diversity_order(query=query, documents=documents) diff --git a/test/components/rankers/test_sentence_transformers_diversity.py b/test/components/rankers/test_sentence_transformers_diversity.py index 6f3780c8dd..600730c2ad 100644 --- a/test/components/rankers/test_sentence_transformers_diversity.py +++ b/test/components/rankers/test_sentence_transformers_diversity.py @@ -1,4 +1,7 @@ +from unittest.mock import MagicMock, call + import pytest +import torch from haystack import ComponentError, Document from haystack.components.rankers import SentenceTransformersDiversityRanker @@ -6,6 +9,15 @@ from haystack.utils.auth import Secret +def mock_encode_response(texts, **kwargs): + if texts == ["city"]: + return torch.tensor([[1.0, 1.0]]) + elif texts == ["Eiffel Tower", "Berlin", "Bananas"]: + return torch.tensor([[1.0, 0.0], [0.8, 0.8], [0.0, 1.0]]) + else: + return torch.tensor([[0.0, 1.0]] * len(texts)) + + class TestSentenceTransformersDiversityRanker: def test_init(self): component = SentenceTransformersDiversityRanker() @@ -47,7 +59,7 @@ def test_init_with_custom_init_parameters(self): assert component.meta_fields_to_embed == ["meta_field"] assert component.embedding_separator == "--" - def test_to_and_from_dict(self): + def test_to_dict(self): component = SentenceTransformersDiversityRanker() data = component.to_dict() assert data == { @@ -67,6 +79,23 @@ def test_to_and_from_dict(self): }, } + def test_from_dict(self): + data = { + "type": "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker", + "init_parameters": { + "model": "sentence-transformers/all-MiniLM-L6-v2", + "top_k": 10, + "device": ComponentDevice.resolve_device(None).to_dict(), + "similarity": "cosine", + "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, + "query_prefix": "", + "document_prefix": "", + "query_suffix": "", + "document_suffix": "", + "meta_fields_to_embed": [], + "embedding_separator": "\n", + }, + } ranker = SentenceTransformersDiversityRanker.from_dict(data) assert ranker.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2" @@ -81,7 +110,7 @@ def test_to_and_from_dict(self): assert ranker.meta_fields_to_embed == [] assert ranker.embedding_separator == "\n" - def test_to_and_from_dict_with_custom_init_parameters(self): + def test_to_dict_with_custom_init_parameters(self): component = SentenceTransformersDiversityRanker( model="sentence-transformers/msmarco-distilbert-base-v4", top_k=5, @@ -113,6 +142,23 @@ def test_to_and_from_dict_with_custom_init_parameters(self): }, } + def test_from_dict_with_custom_init_parameters(self): + data = { + "type": "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker", + "init_parameters": { + "model": "sentence-transformers/msmarco-distilbert-base-v4", + "top_k": 5, + "device": ComponentDevice.from_str("cuda:0").to_dict(), + "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, + "similarity": "dot_product", + "query_prefix": "query:", + "document_prefix": "document:", + "query_suffix": "query suffix", + "document_suffix": "document suffix", + "meta_fields_to_embed": ["meta_field"], + "embedding_separator": "--", + }, + } ranker = SentenceTransformersDiversityRanker.from_dict(data) assert ranker.model_name_or_path == "sentence-transformers/msmarco-distilbert-base-v4" @@ -131,8 +177,11 @@ def test_run_incorrect_similarity(self): """ Tests that run method raises ValueError if similarity is incorrect """ - with pytest.raises(ValueError): - SentenceTransformersDiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity="incorrect") + similarity = "incorrect" + with pytest.raises( + ValueError, match=f"Similarity must be one of 'dot_product' or 'cosine', but got {similarity}." + ): + SentenceTransformersDiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity) @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) def test_run_without_warm_up(self, similarity): @@ -144,10 +193,10 @@ def test_run_without_warm_up(self, similarity): ) documents = [Document(content="doc1"), Document(content="doc2")] - with pytest.raises(ComponentError): + error_msg = "The component SentenceTransformersDiversityRanker wasn't warmed up." + with pytest.raises(ComponentError, match=error_msg): ranker.run(query="test query", documents=documents) - @pytest.mark.integration @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) def test_run_empty_query(self, similarity): """ @@ -156,7 +205,8 @@ def test_run_empty_query(self, similarity): ranker = SentenceTransformersDiversityRanker( model="sentence-transformers/all-MiniLM-L6-v2", top_k=3, similarity=similarity ) - ranker.warm_up() + ranker.model = MagicMock() + ranker.model.encode = MagicMock(side_effect=mock_encode_response) documents = [Document(content="doc1"), Document(content="doc2")] result = ranker.run(query="", documents=documents) @@ -166,7 +216,6 @@ def test_run_empty_query(self, similarity): assert len(ranked_docs) == 2 assert all(isinstance(doc, Document) for doc in ranked_docs) - @pytest.mark.integration @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) def test_run_top_k(self, similarity): """ @@ -176,7 +225,8 @@ def test_run_top_k(self, similarity): ranker = SentenceTransformersDiversityRanker( model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=3 ) - ranker.warm_up() + ranker.model = MagicMock() + ranker.model.encode = MagicMock(side_effect=mock_encode_response) query = "test query" documents = [ Document(content="doc1"), @@ -200,53 +250,73 @@ def test_run_top_k(self, similarity): assert len(ranked_docs) == 2 assert all(isinstance(doc, Document) for doc in ranked_docs) - @pytest.mark.integration + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_run_negative_top_k_at_init(self, similarity): + """ + Tests that run method raises an error for negative top-k set at init. + """ + with pytest.raises(ValueError, match="top_k must be > 0, but got"): + SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=-5 + ) + + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_run_top_k_is_none_at_init(self, similarity): + """ + Tests that run method raises an error for top-k set to None at init. + """ + with pytest.raises(ValueError, match="top_k must be > 0, but got"): + SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=None + ) + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) def test_run_negative_top_k(self, similarity): """ - Tests that run method raises an error for negative top-k. + Tests that run method raises an error for negative top-k set at runtime. """ ranker = SentenceTransformersDiversityRanker( model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=10 ) - ranker.warm_up() + ranker.model = MagicMock() query = "test" documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] - # Setting top_k at runtime - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="top_k must be > 0, but got"): ranker.run(query=query, documents=documents, top_k=-5) - # Setting top_k at init - with pytest.raises(ValueError): - SentenceTransformersDiversityRanker( - model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=-5 - ) - - @pytest.mark.integration @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) def test_run_top_k_is_none(self, similarity): """ Tests that run method returns the correct order of documents for top-k set to None. """ - # Setting top_k to None at init should raise error - with pytest.raises(ValueError): - SentenceTransformersDiversityRanker( - model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=None - ) - # Setting top_k to None is ignored during runtime, it should use top_k set at init. ranker = SentenceTransformersDiversityRanker( model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=2 ) - ranker.warm_up() + ranker.model = MagicMock() + ranker.model.encode = MagicMock(side_effect=mock_encode_response) query = "test" documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] result = ranker.run(query=query, documents=documents, top_k=None) assert len(result["documents"]) == 2 - @pytest.mark.integration + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_run_no_documents_provided(self, similarity): + """ + Test that run method returns an empty list if no documents are supplied. + """ + ranker = SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity + ) + ranker.model = MagicMock() + query = "test query" + documents = [] + results = ranker.run(query=query, documents=documents) + + assert len(results["documents"]) == 0 + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) def test_run_with_less_documents_than_top_k(self, similarity): """ @@ -255,14 +325,14 @@ def test_run_with_less_documents_than_top_k(self, similarity): ranker = SentenceTransformersDiversityRanker( model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=5 ) - ranker.warm_up() + ranker.model = MagicMock() + ranker.model.encode = MagicMock(side_effect=mock_encode_response) query = "test" documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] result = ranker.run(query=query, documents=documents) assert len(result["documents"]) == 3 - @pytest.mark.integration @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) def test_run_single_document_corner_case(self, similarity): """ @@ -271,28 +341,93 @@ def test_run_single_document_corner_case(self, similarity): ranker = SentenceTransformersDiversityRanker( model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity ) - ranker.warm_up() + ranker.model = MagicMock() + ranker.model.encode = MagicMock(side_effect=mock_encode_response) query = "test" documents = [Document(content="doc1")] result = ranker.run(query=query, documents=documents) assert len(result["documents"]) == 1 - @pytest.mark.integration @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) - def test_run_no_documents_provided(self, similarity): + def test_prepare_texts_to_embed(self, similarity): """ - Test that run method returns an empty list if no documents are supplied. + Test creation of texts to embed from documents with meta fields, document prefix and suffix. + """ + ranker = SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", + similarity=similarity, + document_prefix="test doc: ", + document_suffix=" end doc.", + meta_fields_to_embed=["meta_field"], + embedding_separator="\n", + ) + documents = [Document(content=f"document number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)] + texts = ranker._prepare_texts_to_embed(documents=documents) + + assert texts == [ + "test doc: meta_value 0\ndocument number 0 end doc.", + "test doc: meta_value 1\ndocument number 1 end doc.", + "test doc: meta_value 2\ndocument number 2 end doc.", + "test doc: meta_value 3\ndocument number 3 end doc.", + "test doc: meta_value 4\ndocument number 4 end doc.", + ] + + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_encode_text(self, similarity): + """ + Test addition of suffix and prefix to the query and documents when creating embeddings. + """ + ranker = SentenceTransformersDiversityRanker( + model="sentence-transformers/all-MiniLM-L6-v2", + similarity=similarity, + query_prefix="test query: ", + query_suffix=" end query.", + document_prefix="test doc: ", + document_suffix=" end doc.", + meta_fields_to_embed=["meta_field"], + embedding_separator="\n", + ) + query = "query" + documents = [Document(content=f"document number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)] + ranker.model = MagicMock() + ranker.model.encode = MagicMock(side_effect=mock_encode_response) + ranker.run(query=query, documents=documents) + + assert ranker.model.encode.call_count == 2 + ranker.model.assert_has_calls( + [ + call.encode( + [ + "test doc: meta_value 0\ndocument number 0 end doc.", + "test doc: meta_value 1\ndocument number 1 end doc.", + "test doc: meta_value 2\ndocument number 2 end doc.", + "test doc: meta_value 3\ndocument number 3 end doc.", + "test doc: meta_value 4\ndocument number 4 end doc.", + ], + convert_to_tensor=True, + ), + call.encode(["test query: query end query."], convert_to_tensor=True), + ] + ) + + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_run_greedy_diversity_order(self, similarity): + """ + Tests that the given list of documents is ordered to maximize diversity. """ ranker = SentenceTransformersDiversityRanker( model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity ) - ranker.warm_up() - query = "test query" - documents = [] - results = ranker.run(query=query, documents=documents) + query = "city" + documents = [Document(content="Eiffel Tower"), Document(content="Berlin"), Document(content="Bananas")] + ranker.model = MagicMock() + ranker.model.encode = MagicMock(side_effect=mock_encode_response) - assert len(results["documents"]) == 0 + ranked_docs = ranker._greedy_diversity_order(query=query, documents=documents) + ranked_text = " ".join([doc.content for doc in ranked_docs]) + + assert ranked_text == "Berlin Eiffel Tower Bananas" @pytest.mark.integration @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) From 30929776cd57b9eceb2c759fd976dc85c88ea69b Mon Sep 17 00:00:00 2001 From: awinml <97467100+awinml@users.noreply.github.com> Date: Mon, 11 Mar 2024 16:39:02 +0530 Subject: [PATCH 7/8] Add test for warm up --- .../test_sentence_transformers_diversity.py | 27 ++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/test/components/rankers/test_sentence_transformers_diversity.py b/test/components/rankers/test_sentence_transformers_diversity.py index 600730c2ad..0ed6a579bb 100644 --- a/test/components/rankers/test_sentence_transformers_diversity.py +++ b/test/components/rankers/test_sentence_transformers_diversity.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock, call +from unittest.mock import MagicMock, call, patch import pytest import torch @@ -197,6 +197,31 @@ def test_run_without_warm_up(self, similarity): with pytest.raises(ComponentError, match=error_msg): ranker.run(query="test query", documents=documents) + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) + def test_warm_up(self, similarity): + """ + Test that ranker loads the SentenceTransformer model correctly during warm up. + """ + mock_model_class = MagicMock() + mock_model_instance = MagicMock() + mock_model_class.return_value = mock_model_instance + + with patch( + "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformer", new=mock_model_class + ): + ranker = SentenceTransformersDiversityRanker(model="mock_model_name", similarity=similarity) + + assert ranker.model is None + + ranker.warm_up() + + mock_model_class.assert_called_once_with( + model_name_or_path="mock_model_name", + device=ComponentDevice.resolve_device(None).to_torch_str(), + use_auth_token=None, + ) + assert ranker.model == mock_model_instance + @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) def test_run_empty_query(self, similarity): """ From 00c9e0fb8927068cd003fed89be655b2f16b547a Mon Sep 17 00:00:00 2001 From: awinml <97467100+awinml@users.noreply.github.com> Date: Mon, 11 Mar 2024 16:57:35 +0530 Subject: [PATCH 8/8] Update release notes --- releasenotes/notes/add-diversity-ranker-6ecee21134eda673.yaml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/releasenotes/notes/add-diversity-ranker-6ecee21134eda673.yaml b/releasenotes/notes/add-diversity-ranker-6ecee21134eda673.yaml index b8a65c10fb..0a9dfa9058 100644 --- a/releasenotes/notes/add-diversity-ranker-6ecee21134eda673.yaml +++ b/releasenotes/notes/add-diversity-ranker-6ecee21134eda673.yaml @@ -1,4 +1,6 @@ --- features: - | - Add `DiversityRanker`. Diversity Ranker orders documents in such a way as to maximize the overall diversity of the given documents. The ranker leverages sentence-transformer models to calculate semantic embeddings for each document and the query. + Add `SentenceTransformersDiversityRanker`. + The Diversity Ranker orders documents in such a way as to maximize the overall diversity of the given documents. + The ranker leverages sentence-transformer models to calculate semantic embeddings for each document and the query.