diff --git a/integrations/cohere/README.md b/integrations/cohere/README.md index 86a43bf83..ca3922615 100644 --- a/integrations/cohere/README.md +++ b/integrations/cohere/README.md @@ -47,6 +47,11 @@ To only run generators tests: hatch run test -m"generators" ``` +To only run ranker tests: +``` +hatch run test -m"ranker" +``` + Markers can be combined, for example you can run only integration tests for embedders with: ``` hatch run test -m"integrations and embedders" diff --git a/integrations/cohere/examples/cohere_ranker_in_a_pipeline.py b/integrations/cohere/examples/cohere_ranker_in_a_pipeline.py new file mode 100644 index 000000000..2234eb3a6 --- /dev/null +++ b/integrations/cohere/examples/cohere_ranker_in_a_pipeline.py @@ -0,0 +1,29 @@ +from haystack import Document, Pipeline +from haystack.components.retrievers.in_memory import InMemoryBM25Retriever +from haystack.document_stores.in_memory import InMemoryDocumentStore +from haystack_integrations.components.rankers.cohere import CohereRanker + +# Note set your API key by running the below command in your terminal +# export CO_API_KEY="" + +docs = [ + Document(content="Paris is in France"), + Document(content="Berlin is in Germany"), + Document(content="Lyon is in France"), +] +document_store = InMemoryDocumentStore() +document_store.write_documents(docs) + +retriever = InMemoryBM25Retriever(document_store=document_store) +ranker = CohereRanker(model="rerank-english-v2.0", top_k=3) + +document_ranker_pipeline = Pipeline() +document_ranker_pipeline.add_component(instance=retriever, name="retriever") +document_ranker_pipeline.add_component(instance=ranker, name="ranker") + +document_ranker_pipeline.connect("retriever.documents", "ranker.documents") + +query = "Cities in France" +res = document_ranker_pipeline.run( + data={"retriever": {"query": query, "top_k": 3}, "ranker": {"query": query, "top_k": 3}} +) diff --git a/integrations/cohere/pydoc/config.yml b/integrations/cohere/pydoc/config.yml index 48608625c..5d4e747f5 100644 --- a/integrations/cohere/pydoc/config.yml +++ b/integrations/cohere/pydoc/config.yml @@ -7,6 +7,7 @@ loaders: "haystack_integrations.components.embedders.cohere.utils", "haystack_integrations.components.generators.cohere.generator", "haystack_integrations.components.generators.cohere.chat.chat_generator", + "haystack_integrations.components.rankers.cohere.ranker", ] ignore_when_discovered: ["__init__"] processors: diff --git a/integrations/cohere/pyproject.toml b/integrations/cohere/pyproject.toml index e88534052..fd34b7743 100644 --- a/integrations/cohere/pyproject.toml +++ b/integrations/cohere/pyproject.toml @@ -163,5 +163,6 @@ markers = [ "embedders: embedders tests", "generators: generators tests", "chat_generators: chat_generators tests", + "ranker: ranker tests" ] log_cli = true diff --git a/integrations/cohere/src/haystack_integrations/components/rankers/cohere/__init__.py b/integrations/cohere/src/haystack_integrations/components/rankers/cohere/__init__.py new file mode 100644 index 000000000..b4e09d00a --- /dev/null +++ b/integrations/cohere/src/haystack_integrations/components/rankers/cohere/__init__.py @@ -0,0 +1,3 @@ +from .ranker import CohereRanker + +__all__ = ["CohereRanker"] diff --git a/integrations/cohere/src/haystack_integrations/components/rankers/cohere/ranker.py b/integrations/cohere/src/haystack_integrations/components/rankers/cohere/ranker.py new file mode 100644 index 000000000..f902a286c --- /dev/null +++ b/integrations/cohere/src/haystack_integrations/components/rankers/cohere/ranker.py @@ -0,0 +1,165 @@ +from typing import Any, Dict, List, Optional + +from haystack import Document, component, default_from_dict, default_to_dict, logging +from haystack.utils import Secret, deserialize_secrets_inplace + +import cohere + +logger = logging.getLogger(__name__) + +MAX_NUM_DOCS_FOR_COHERE_RANKER = 1000 + + +@component +class CohereRanker: + """ + Ranks Documents based on their similarity to the query using [Cohere models](https://docs.cohere.com/reference/rerank-1). + + Documents are indexed from most to least semantically relevant to the query. + + Usage example: + ```python + from haystack import Document + from haystack.components.rankers import CohereRanker + + ranker = CohereRanker(model="rerank-english-v2.0", top_k=2) + + docs = [Document(content="Paris"), Document(content="Berlin")] + query = "What is the capital of germany?" + output = ranker.run(query=query, documents=docs) + docs = output["documents"] + ``` + """ + + def __init__( + self, + model: str = "rerank-english-v2.0", + top_k: int = 10, + api_key: Secret = Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]), + api_base_url: str = cohere.COHERE_API_URL, + max_chunks_per_doc: Optional[int] = None, + meta_fields_to_embed: Optional[List[str]] = None, + meta_data_separator: str = "\n", + ): + """ + Creates an instance of the 'CohereRanker'. + + :param model: Cohere model name. Check the list of supported models in the [Cohere documentation](https://docs.cohere.com/docs/models). + :param top_k: The maximum number of documents to return. + :param api_key: Cohere API key. + :param api_base_url: the base URL of the Cohere API. + :param max_chunks_per_doc: If your document exceeds 512 tokens, this determines the maximum number of + chunks a document can be split into. If `None`, the default of 10 is used. + For example, if your document is 6000 tokens, with the default of 10, the document will be split into 10 + chunks each of 512 tokens and the last 880 tokens will be disregarded. + Check [Cohere docs](https://docs.cohere.com/docs/reranking-best-practices) for more information. + :param meta_fields_to_embed: List of meta fields that should be concatenated + with the document content for reranking. + :param meta_data_separator: Separator used to concatenate the meta fields + to the Document content. + """ + self.model_name = model + self.api_key = api_key + self.api_base_url = api_base_url + self.top_k = top_k + self.max_chunks_per_doc = max_chunks_per_doc + self.meta_fields_to_embed = meta_fields_to_embed or [] + self.meta_data_separator = meta_data_separator + self._cohere_client = cohere.Client( + api_key=self.api_key.resolve_value(), api_url=self.api_base_url, client_name="haystack" + ) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + model=self.model_name, + api_key=self.api_key.to_dict() if self.api_key else None, + api_base_url=self.api_base_url, + top_k=self.top_k, + max_chunks_per_doc=self.max_chunks_per_doc, + meta_fields_to_embed=self.meta_fields_to_embed, + meta_data_separator=self.meta_data_separator, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "CohereRanker": + """ + Deserializes the component from a dictionary. + + :param data: + The dictionary to deserialize from. + :returns: + The deserialized component. + """ + deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) + return default_from_dict(cls, data) + + def _prepare_cohere_input_docs(self, documents: List[Document]) -> List[str]: + """ + Prepare the input by concatenating the document text with the metadata fields specified. + :param documents: The list of Document objects. + + :return: A list of strings to be given as input to Cohere model. + """ + concatenated_input_list = [] + for doc in documents: + meta_values_to_embed = [ + str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta.get(key) + ] + concatenated_input = self.meta_data_separator.join([*meta_values_to_embed, doc.content or ""]) + concatenated_input_list.append(concatenated_input) + + return concatenated_input_list + + @component.output_types(documents=List[Document]) + def run(self, query: str, documents: List[Document], top_k: Optional[int] = None): + """ + Use the Cohere Reranker to re-rank the list of documents based on the query. + + :param query: + Query string. + :param documents: + List of Documents. + :param top_k: + The maximum number of Documents you want the Ranker to return. + :returns: + A dictionary with the following keys: + - `documents`: List of Documents most similar to the given query in descending order of similarity. + + :raises ValueError: If `top_k` is not > 0. + """ + top_k = top_k or self.top_k + if top_k <= 0: + msg = f"top_k must be > 0, but got {top_k}" + raise ValueError(msg) + + cohere_input_docs = self._prepare_cohere_input_docs(documents) + if len(cohere_input_docs) > MAX_NUM_DOCS_FOR_COHERE_RANKER: + logger.warning( + f"The Cohere reranking endpoint only supports {MAX_NUM_DOCS_FOR_COHERE_RANKER} documents.\ + The number of documents has been truncated to {MAX_NUM_DOCS_FOR_COHERE_RANKER} \ + from {len(cohere_input_docs)}." + ) + cohere_input_docs = cohere_input_docs[:MAX_NUM_DOCS_FOR_COHERE_RANKER] + + response = self._cohere_client.rerank( + model=self.model_name, + query=query, + documents=cohere_input_docs, + max_chunks_per_doc=self.max_chunks_per_doc, + top_n=top_k, + ) + indices = [output.index for output in response.results] + scores = [output.relevance_score for output in response.results] + sorted_docs = [] + for idx, score in zip(indices, scores): + doc = documents[idx] + doc.score = score + sorted_docs.append(documents[idx]) + return {"documents": sorted_docs} diff --git a/integrations/cohere/tests/test_cohere_ranker.py b/integrations/cohere/tests/test_cohere_ranker.py new file mode 100644 index 000000000..08e01c647 --- /dev/null +++ b/integrations/cohere/tests/test_cohere_ranker.py @@ -0,0 +1,345 @@ +import os +from unittest.mock import Mock, patch + +import pytest +from cohere import COHERE_API_URL +from haystack import Document +from haystack.utils.auth import Secret +from haystack_integrations.components.rankers.cohere import CohereRanker + +pytestmark = pytest.mark.ranker + + +@pytest.fixture +def mock_ranker_response(): + """ + Mock the Cohere ranker API response and reuse it for tests + The `response` is an object of + and `response.results` is list : [RerankResult, + RerankResult, + RerankResult] + """ + with patch("cohere.Client.rerank", autospec=True) as mock_ranker_response: + + mock_response = Mock() + + mock_ranker_res_obj1 = Mock() + mock_ranker_res_obj1.index = 2 + mock_ranker_res_obj1.relevance_score = 0.98 + + mock_ranker_res_obj2 = Mock() + mock_ranker_res_obj2.index = 1 + mock_ranker_res_obj2.relevance_score = 0.95 + + mock_response.results = [mock_ranker_res_obj1, mock_ranker_res_obj2] + mock_ranker_response.return_value = mock_response + yield mock_ranker_response + + +class TestCohereRanker: + def test_init_default(self, monkeypatch): + monkeypatch.setenv("CO_API_KEY", "test-api-key") + component = CohereRanker() + assert component.model_name == "rerank-english-v2.0" + assert component.top_k == 10 + assert component.api_key == Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]) + assert component.api_base_url == COHERE_API_URL + assert component.max_chunks_per_doc is None + assert component.meta_fields_to_embed == [] + assert component.meta_data_separator == "\n" + + def test_init_fail_wo_api_key(self, monkeypatch): + monkeypatch.delenv("CO_API_KEY", raising=False) + monkeypatch.delenv("COHERE_API_KEY", raising=False) + with pytest.raises(ValueError, match="None of the following authentication environment variables are set: *"): + CohereRanker() + + def test_init_with_parameters(self, monkeypatch): + monkeypatch.setenv("CO_API_KEY", "test-api-key") + component = CohereRanker( + model="rerank-multilingual-v2.0", + top_k=5, + api_key=Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]), + api_base_url="test-base-url", + max_chunks_per_doc=40, + meta_fields_to_embed=["meta_field_1", "meta_field_2"], + meta_data_separator=",", + ) + assert component.model_name == "rerank-multilingual-v2.0" + assert component.top_k == 5 + assert component.api_key == Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]) + assert component.api_base_url == "test-base-url" + assert component.max_chunks_per_doc == 40 + assert component.meta_fields_to_embed == ["meta_field_1", "meta_field_2"] + assert component.meta_data_separator == "," + + def test_to_dict_default(self, monkeypatch): + monkeypatch.setenv("CO_API_KEY", "test-api-key") + component = CohereRanker() + data = component.to_dict() + assert data == { + "type": "haystack_integrations.components.rankers.cohere.ranker.CohereRanker", + "init_parameters": { + "model": "rerank-english-v2.0", + "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, + "api_base_url": COHERE_API_URL, + "top_k": 10, + "max_chunks_per_doc": None, + "meta_fields_to_embed": [], + "meta_data_separator": "\n", + }, + } + + def test_to_dict_with_parameters(self, monkeypatch): + monkeypatch.setenv("CO_API_KEY", "test-api-key") + component = CohereRanker( + model="rerank-multilingual-v2.0", + top_k=2, + api_key=Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]), + api_base_url="test-base-url", + max_chunks_per_doc=50, + meta_fields_to_embed=["meta_field_1", "meta_field_2"], + meta_data_separator=",", + ) + data = component.to_dict() + assert data == { + "type": "haystack_integrations.components.rankers.cohere.ranker.CohereRanker", + "init_parameters": { + "model": "rerank-multilingual-v2.0", + "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, + "api_base_url": "test-base-url", + "top_k": 2, + "max_chunks_per_doc": 50, + "meta_fields_to_embed": ["meta_field_1", "meta_field_2"], + "meta_data_separator": ",", + }, + } + + def test_from_dict(self, monkeypatch): + monkeypatch.setenv("CO_API_KEY", "test-api-key") + data = { + "type": "haystack_integrations.components.rankers.cohere.ranker.CohereRanker", + "init_parameters": { + "model": "rerank-multilingual-v2.0", + "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, + "api_base_url": "test-base-url", + "top_k": 2, + "max_chunks_per_doc": 50, + "meta_fields_to_embed": ["meta_field_1", "meta_field_2"], + "meta_data_separator": ",", + }, + } + component = CohereRanker.from_dict(data) + assert component.model_name == "rerank-multilingual-v2.0" + assert component.top_k == 2 + assert component.api_key == Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]) + assert component.api_base_url == "test-base-url" + assert component.max_chunks_per_doc == 50 + assert component.meta_fields_to_embed == ["meta_field_1", "meta_field_2"] + assert component.meta_data_separator == "," + + def test_from_dict_fail_wo_env_var(self, monkeypatch): + monkeypatch.delenv("CO_API_KEY", raising=False) + monkeypatch.delenv("COHERE_API_KEY", raising=False) + data = { + "type": "haystack_integrations.components.rankers.cohere.ranker.CohereRanker", + "init_parameters": { + "model": "rerank-multilingual-v2.0", + "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, + "top_k": 2, + "max_chunks_per_doc": 50, + }, + } + with pytest.raises(ValueError, match="None of the following authentication environment variables are set: *"): + CohereRanker.from_dict(data) + + def test_prepare_cohere_input_docs_default_separator(self, monkeypatch): + monkeypatch.setenv("CO_API_KEY", "test-api-key") + component = CohereRanker(meta_fields_to_embed=["meta_field_1", "meta_field_2"]) + documents = [ + Document( + content=f"document number {i}", + meta={ + "meta_field_1": f"meta_value_1 {i}", + "meta_field_2": f"meta_value_2 {i+5}", + "meta_field_3": f"meta_value_3 {i+15}", + }, + ) + for i in range(5) + ] + + texts = component._prepare_cohere_input_docs(documents=documents) + + assert texts == [ + "meta_value_1 0\nmeta_value_2 5\ndocument number 0", + "meta_value_1 1\nmeta_value_2 6\ndocument number 1", + "meta_value_1 2\nmeta_value_2 7\ndocument number 2", + "meta_value_1 3\nmeta_value_2 8\ndocument number 3", + "meta_value_1 4\nmeta_value_2 9\ndocument number 4", + ] + + def test_prepare_cohere_input_docs_custom_separator(self, monkeypatch): + monkeypatch.setenv("CO_API_KEY", "test-api-key") + component = CohereRanker(meta_fields_to_embed=["meta_field_1", "meta_field_2"], meta_data_separator=" ") + documents = [ + Document( + content=f"document number {i}", + meta={ + "meta_field_1": f"meta_value_1 {i}", + "meta_field_2": f"meta_value_2 {i+5}", + "meta_field_3": f"meta_value_3 {i+15}", + }, + ) + for i in range(5) + ] + + texts = component._prepare_cohere_input_docs(documents=documents) + + assert texts == [ + "meta_value_1 0 meta_value_2 5 document number 0", + "meta_value_1 1 meta_value_2 6 document number 1", + "meta_value_1 2 meta_value_2 7 document number 2", + "meta_value_1 3 meta_value_2 8 document number 3", + "meta_value_1 4 meta_value_2 9 document number 4", + ] + + def test_prepare_cohere_input_docs_no_meta_data(self, monkeypatch): + monkeypatch.setenv("CO_API_KEY", "test-api-key") + component = CohereRanker(meta_fields_to_embed=["meta_field_1", "meta_field_2"], meta_data_separator=" ") + documents = [Document(content=f"document number {i}") for i in range(5)] + + texts = component._prepare_cohere_input_docs(documents=documents) + + assert texts == [ + "document number 0", + "document number 1", + "document number 2", + "document number 3", + "document number 4", + ] + + def test_prepare_cohere_input_docs_no_docs(self, monkeypatch): + monkeypatch.setenv("CO_API_KEY", "test-api-key") + component = CohereRanker(meta_fields_to_embed=["meta_field_1", "meta_field_2"], meta_data_separator=" ") + documents = [] + + texts = component._prepare_cohere_input_docs(documents=documents) + + assert texts == [] + + def test_run_negative_topk_in_init(self, monkeypatch): + monkeypatch.setenv("CO_API_KEY", "test-api-key") + ranker = CohereRanker(top_k=-2) + query = "test" + documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] + with pytest.raises(ValueError, match="top_k must be > 0, but got *"): + ranker.run(query, documents) + + def test_run_zero_topk_in_init(self, monkeypatch): + monkeypatch.setenv("CO_API_KEY", "test-api-key") + ranker = CohereRanker(top_k=0) + query = "test" + documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] + with pytest.raises(ValueError, match="top_k must be > 0, but got *"): + ranker.run(query, documents) + + def test_run_negative_topk_in_run(self, monkeypatch): + monkeypatch.setenv("CO_API_KEY", "test-api-key") + ranker = CohereRanker() + query = "test" + documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] + with pytest.raises(ValueError, match="top_k must be > 0, but got *"): + ranker.run(query, documents, -3) + + def test_run_zero_topk_in_run_and_init(self, monkeypatch): + monkeypatch.setenv("CO_API_KEY", "test-api-key") + ranker = CohereRanker(top_k=0) + query = "test" + documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] + with pytest.raises(ValueError, match="top_k must be > 0, but got *"): + ranker.run(query, documents, 0) + + def test_run_documents_provided(self, monkeypatch, mock_ranker_response): # noqa: ARG002 + monkeypatch.setenv("CO_API_KEY", "test-api-key") + ranker = CohereRanker() + query = "test" + documents = [ + Document(id="abcd", content="doc1", meta={"meta_field": "meta_value_1"}), + Document(id="efgh", content="doc2", meta={"meta_field": "meta_value_2"}), + Document(id="ijkl", content="doc3", meta={"meta_field": "meta_value_3"}), + ] + ranker_results = ranker.run(query, documents, 2) + + assert isinstance(ranker_results, dict) + reranked_docs = ranker_results["documents"] + assert reranked_docs == [ + Document(id="ijkl", content="doc3", meta={"meta_field": "meta_value_3"}, score=0.98), + Document(id="efgh", content="doc2", meta={"meta_field": "meta_value_2"}, score=0.95), + ] + + def test_run_topk_set_in_init(self, monkeypatch, mock_ranker_response): # noqa: ARG002 + monkeypatch.setenv("CO_API_KEY", "test-api-key") + ranker = CohereRanker(top_k=2) + query = "test" + documents = [ + Document(id="abcd", content="doc1"), + Document(id="efgh", content="doc2"), + Document(id="ijkl", content="doc3"), + ] + + ranker_results = ranker.run(query, documents) + + assert isinstance(ranker_results, dict) + reranked_docs = ranker_results["documents"] + assert reranked_docs == [ + Document(id="ijkl", content="doc3", score=0.98), + Document(id="efgh", content="doc2", score=0.95), + ] + + @pytest.mark.skipif( + not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), + reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", + ) + @pytest.mark.integration + def test_live_run(self): + component = CohereRanker() + documents = [ + Document(id="abcd", content="Paris is in France"), + Document(id="efgh", content="Berlin is in Germany"), + Document(id="ijkl", content="Lyon is in France"), + ] + + ranker_result = component.run("Cities in France", documents, 2) + expected_documents = [documents[0], documents[2]] + expected_documents_content = [doc.content for doc in expected_documents] + result_documents_contents = [doc.content for doc in ranker_result["documents"]] + + assert isinstance(ranker_result, dict) + assert isinstance(ranker_result["documents"], list) + assert len(ranker_result["documents"]) == 2 + assert all(isinstance(doc, Document) for doc in ranker_result["documents"]) + assert set(result_documents_contents) == set(expected_documents_content) + + @pytest.mark.skipif( + not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), + reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", + ) + @pytest.mark.integration + def test_live_run_topk_greater_than_docs(self): + component = CohereRanker() + documents = [ + Document(id="abcd", content="Paris is in France"), + Document(id="efgh", content="Berlin is in Germany"), + Document(id="ijkl", content="Lyon is in France"), + ] + + ranker_result = component.run("Cities in France", documents, 5) + expected_documents = [documents[0], documents[2], documents[1]] + expected_documents_content = [doc.content for doc in expected_documents] + result_documents_contents = [doc.content for doc in ranker_result["documents"]] + + assert isinstance(ranker_result, dict) + assert isinstance(ranker_result["documents"], list) + assert len(ranker_result["documents"]) == 3 + assert all(isinstance(doc, Document) for doc in ranker_result["documents"]) + assert set(result_documents_contents) == set(expected_documents_content)