diff --git a/haystack/preview/components/rankers/__init__.py b/haystack/preview/components/rankers/__init__.py index 27337481eb..ac1001eea2 100644 --- a/haystack/preview/components/rankers/__init__.py +++ b/haystack/preview/components/rankers/__init__.py @@ -1,3 +1,3 @@ -from haystack.preview.components.rankers.similarity import SimilarityRanker +from haystack.preview.components.rankers.transformers_similarity import TransformersSimilarityRanker -__all__ = ["SimilarityRanker"] +__all__ = ["TransformersSimilarityRanker"] diff --git a/haystack/preview/components/rankers/similarity.py b/haystack/preview/components/rankers/transformers_similarity.py similarity index 90% rename from haystack/preview/components/rankers/similarity.py rename to haystack/preview/components/rankers/transformers_similarity.py index 6bcc8ba00f..7733721e98 100644 --- a/haystack/preview/components/rankers/similarity.py +++ b/haystack/preview/components/rankers/transformers_similarity.py @@ -14,19 +14,20 @@ @component -class SimilarityRanker: +class TransformersSimilarityRanker: """ Ranks documents based on query similarity. + It uses a pre-trained cross-encoder model (from Hugging Face Hub) to embed the query and documents. Usage example: ``` from haystack.preview import Document - from haystack.preview.components.rankers import SimilarityRanker + from haystack.preview.components.rankers import TransformersSimilarityRanker - sampler = SimilarityRanker() + ranker = TransformersSimilarityRanker() docs = [Document(text="Paris"), Document(text="Berlin")] query = "City in Germany" - output = sampler.run(query=query, documents=docs) + output = ranker.run(query=query, documents=docs) docs = output["documents"] assert len(docs) == 2 assert docs[0].text == "Berlin" @@ -41,9 +42,10 @@ def __init__( top_k: int = 10, ): """ - Creates an instance of SimilarityRanker. + Creates an instance of TransformersSimilarityRanker. - :param model_name_or_path: Path to a pre-trained sentence-transformers model. + :param model_name_or_path: The name or path of a pre-trained cross-encoder model + from Hugging Face Hub. :param device: torch device (for example, cuda:0, cpu, mps) to limit model inference to a specific device. :param token: The API token used to download private models from Hugging Face. If this parameter is set to `True`, then the token generated when running diff --git a/releasenotes/notes/rename_similarity_ranker-d755c2cd00449ecc.yaml b/releasenotes/notes/rename_similarity_ranker-d755c2cd00449ecc.yaml new file mode 100644 index 0000000000..a899742830 --- /dev/null +++ b/releasenotes/notes/rename_similarity_ranker-d755c2cd00449ecc.yaml @@ -0,0 +1,5 @@ +--- +preview: + - | + Rename `SimilarityRanker` to `TransformersSimilarityRanker`, + as there will be more similarity rankers in the future. diff --git a/test/preview/components/rankers/test_similarity.py b/test/preview/components/rankers/test_transformers_similarity.py similarity index 82% rename from test/preview/components/rankers/test_similarity.py rename to test/preview/components/rankers/test_transformers_similarity.py index b9c5fe0ddf..6e5511b2e1 100644 --- a/test/preview/components/rankers/test_similarity.py +++ b/test/preview/components/rankers/test_transformers_similarity.py @@ -1,30 +1,32 @@ import pytest from haystack.preview import Document, ComponentError -from haystack.preview.components.rankers.similarity import SimilarityRanker +from haystack.preview.components.rankers.transformers_similarity import TransformersSimilarityRanker class TestSimilarityRanker: @pytest.mark.unit def test_to_dict(self): - component = SimilarityRanker() + component = TransformersSimilarityRanker() data = component.to_dict() assert data == { - "type": "SimilarityRanker", + "type": "TransformersSimilarityRanker", "init_parameters": { "device": "cpu", "top_k": 10, - "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2", "token": None, + "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2", }, } @pytest.mark.unit def test_to_dict_with_custom_init_parameters(self): - component = SimilarityRanker(model_name_or_path="my_model", device="cuda", token="my_token", top_k=5) + component = TransformersSimilarityRanker( + model_name_or_path="my_model", device="cuda", token="my_token", top_k=5 + ) data = component.to_dict() assert data == { - "type": "SimilarityRanker", + "type": "TransformersSimilarityRanker", "init_parameters": { "device": "cuda", "model_name_or_path": "my_model", @@ -46,7 +48,7 @@ def test_run(self, query, docs_before_texts, expected_first_text): """ Test if the component ranks documents correctly. """ - ranker = SimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2") + ranker = TransformersSimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2") ranker.warm_up() docs_before = [Document(text=text) for text in docs_before_texts] output = ranker.run(query=query, documents=docs_before) @@ -61,7 +63,7 @@ def test_run(self, query, docs_before_texts, expected_first_text): # Returns an empty list if no documents are provided @pytest.mark.integration def test_returns_empty_list_if_no_documents_are_provided(self): - sampler = SimilarityRanker() + sampler = TransformersSimilarityRanker() sampler.warm_up() output = sampler.run(query="City in Germany", documents=[]) assert output["documents"] == [] @@ -69,7 +71,7 @@ def test_returns_empty_list_if_no_documents_are_provided(self): # Raises ComponentError if model is not warmed up @pytest.mark.integration def test_raises_component_error_if_model_not_warmed_up(self): - sampler = SimilarityRanker() + sampler = TransformersSimilarityRanker() with pytest.raises(ComponentError): sampler.run(query="query", documents=[Document(text="document")]) @@ -87,7 +89,7 @@ def test_run_top_k(self, query, docs_before_texts, expected_first_text): """ Test if the component ranks documents correctly with a custom top_k. """ - ranker = SimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2", top_k=2) + ranker = TransformersSimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2", top_k=2) ranker.warm_up() docs_before = [Document(text=text) for text in docs_before_texts] output = ranker.run(query=query, documents=docs_before)