From 414fb98a2ddbd313cc9845de8c1fa5548c7e1e80 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Wed, 6 Dec 2023 18:13:34 +0100 Subject: [PATCH] review feedback --- .../src/cohere_haystack/embedders/document_embedder.py | 6 +++++- integrations/cohere/tests/test_document_embedder.py | 9 ++++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py b/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py index e14474fd9..681471947 100644 --- a/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py +++ b/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py @@ -119,13 +119,17 @@ def run(self, documents: List[Document]): :param documents: A list of Documents to embed. """ - if not isinstance(documents, list) or not isinstance(documents[0], Document): + if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): msg = ( "CohereDocumentEmbedder expects a list of Documents as input." "In case you want to embed a string, please use the CohereTextEmbedder." ) raise TypeError(msg) + if not documents: + # return early if we were passed an empty list + return {"documents": [], "metadata": {}} + texts_to_embed = self._prepare_texts_to_embed(documents) if self.use_async_client: diff --git a/integrations/cohere/tests/test_document_embedder.py b/integrations/cohere/tests/test_document_embedder.py index 5fab29e5b..d6309704c 100644 --- a/integrations/cohere/tests/test_document_embedder.py +++ b/integrations/cohere/tests/test_document_embedder.py @@ -128,10 +128,9 @@ def test_run(self): def test_run_wrong_input_format(self): embedder = CohereDocumentEmbedder(api_key="test-api-key") - string_input = "text" - list_integers_input = [1, 2, 3] - with pytest.raises(TypeError, match="CohereDocumentEmbedder expects a list of Documents as input"): - embedder.run(documents=string_input) + embedder.run(documents="text") with pytest.raises(TypeError, match="CohereDocumentEmbedder expects a list of Documents as input"): - embedder.run(documents=list_integers_input) + embedder.run(documents=[1, 2, 3]) + + assert embedder.run(documents=[]) == {"documents": [], "metadata": {}}