From b897441335f3daad9fc9a5599287af6ddd3c5334 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Fri, 6 Sep 2024 16:25:54 +0200 Subject: [PATCH 01/33] test: Add tests for VertexAIChatGeminiGenerator and migrate from preview package in vertexai (#1042) * Add tests for chat generator and migrate from preview package to a stable version of vertexai generative_model --- .github/workflows/google_vertex.yml | 2 +- .../generators/google_vertex/chat/gemini.py | 2 +- .../google_vertex/tests/chat/test_gemini.py | 295 ++++++++++++++++++ .../google_vertex/tests/test_gemini.py | 25 +- 4 files changed, 319 insertions(+), 5 deletions(-) create mode 100644 integrations/google_vertex/tests/chat/test_gemini.py diff --git a/.github/workflows/google_vertex.yml b/.github/workflows/google_vertex.yml index 78ba5694b..34c0cf07c 100644 --- a/.github/workflows/google_vertex.yml +++ b/.github/workflows/google_vertex.yml @@ -30,7 +30,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, windows-latest, macos-latest] - python-version: ["3.9", "3.10"] + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - name: Support longpaths diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index e5ca1166d..8cdb58d2d 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -8,7 +8,7 @@ from haystack.dataclasses.chat_message import ChatMessage, ChatRole from haystack.utils import deserialize_callable, serialize_callable from vertexai import init as vertexai_init -from vertexai.preview.generative_models import ( +from vertexai.generative_models import ( Content, GenerationConfig, GenerationResponse, diff --git a/integrations/google_vertex/tests/chat/test_gemini.py b/integrations/google_vertex/tests/chat/test_gemini.py new file mode 100644 index 000000000..a1564b9f2 --- /dev/null +++ b/integrations/google_vertex/tests/chat/test_gemini.py @@ -0,0 +1,295 @@ +from unittest.mock import MagicMock, Mock, patch + +import pytest +from haystack import Pipeline +from haystack.components.builders import ChatPromptBuilder +from haystack.dataclasses import ChatMessage, StreamingChunk +from vertexai.generative_models import ( + Content, + FunctionDeclaration, + GenerationConfig, + GenerationResponse, + HarmBlockThreshold, + HarmCategory, + Part, + Tool, +) + +from haystack_integrations.components.generators.google_vertex import VertexAIGeminiChatGenerator + +GET_CURRENT_WEATHER_FUNC = FunctionDeclaration( + name="get_current_weather", + description="Get the current weather in a given location", + parameters={ + "type_": "OBJECT", + "properties": { + "location": {"type_": "STRING", "description": "The city and state, e.g. 
San Francisco, CA"}, + "unit": { + "type_": "STRING", + "enum": [ + "celsius", + "fahrenheit", + ], + }, + }, + "required": ["location"], + }, +) + + +@pytest.fixture +def chat_messages(): + return [ + ChatMessage.from_system("You are a helpful assistant"), + ChatMessage.from_user("What's the capital of France"), + ] + + +@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.vertexai_init") +@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") +def test_init(mock_vertexai_init, _mock_generative_model): + + generation_config = GenerationConfig( + candidate_count=1, + stop_sequences=["stop"], + max_output_tokens=10, + temperature=0.5, + top_p=0.5, + top_k=0.5, + ) + safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} + + tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) + + gemini = VertexAIGeminiChatGenerator( + project_id="TestID123", + location="TestLocation", + generation_config=generation_config, + safety_settings=safety_settings, + tools=[tool], + ) + mock_vertexai_init.assert_called() + assert gemini._model_name == "gemini-1.5-flash" + assert gemini._generation_config == generation_config + assert gemini._safety_settings == safety_settings + assert gemini._tools == [tool] + + +@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.vertexai_init") +@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") +def test_to_dict(_mock_vertexai_init, _mock_generative_model): + + gemini = VertexAIGeminiChatGenerator( + project_id="TestID123", + ) + assert gemini.to_dict() == { + "type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator", + "init_parameters": { + "model": "gemini-1.5-flash", + "project_id": "TestID123", + "location": None, + "generation_config": None, + "safety_settings": None, + "streaming_callback": None, + "tools": None, + }, + } + + +@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.vertexai_init") +@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") +def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): + generation_config = GenerationConfig( + candidate_count=1, + stop_sequences=["stop"], + max_output_tokens=10, + temperature=0.5, + top_p=0.5, + top_k=2, + ) + safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} + + tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) + + gemini = VertexAIGeminiChatGenerator( + project_id="TestID123", + generation_config=generation_config, + safety_settings=safety_settings, + tools=[tool], + ) + + assert gemini.to_dict() == { + "type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator", + "init_parameters": { + "model": "gemini-1.5-flash", + "project_id": "TestID123", + "location": None, + "generation_config": { + "temperature": 0.5, + "top_p": 0.5, + "top_k": 2.0, + "candidate_count": 1, + "max_output_tokens": 10, + "stop_sequences": ["stop"], + }, + "safety_settings": {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}, + "streaming_callback": None, + "tools": [ + { + "function_declarations": [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type_": "OBJECT", + "properties": { + "location": { + "type_": 
"STRING", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + ] + } + ], + }, + } + + +@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.vertexai_init") +@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") +def test_from_dict(_mock_vertexai_init, _mock_generative_model): + gemini = VertexAIGeminiChatGenerator.from_dict( + { + "type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator", + "init_parameters": { + "project_id": "TestID123", + "model": "gemini-1.5-flash", + "generation_config": None, + "safety_settings": None, + "tools": None, + "streaming_callback": None, + }, + } + ) + + assert gemini._model_name == "gemini-1.5-flash" + assert gemini._project_id == "TestID123" + assert gemini._safety_settings is None + assert gemini._tools is None + assert gemini._generation_config is None + + +@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.vertexai_init") +@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") +def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): + gemini = VertexAIGeminiChatGenerator.from_dict( + { + "type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator", + "init_parameters": { + "project_id": "TestID123", + "model": "gemini-1.5-flash", + "generation_config": { + "temperature": 0.5, + "top_p": 0.5, + "top_k": 0.5, + "candidate_count": 1, + "max_output_tokens": 10, + "stop_sequences": ["stop"], + }, + "safety_settings": {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}, + "tools": [ + { + "function_declarations": [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type_": "OBJECT", + "properties": { + "location": { + "type_": "STRING", + "description": "The city and state, e.g. 
San Francisco, CA", + }, + "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + ] + } + ], + "streaming_callback": None, + }, + } + ) + + assert gemini._model_name == "gemini-1.5-flash" + assert gemini._project_id == "TestID123" + assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} + assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])]) + assert isinstance(gemini._generation_config, GenerationConfig) + + +@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") +def test_run(mock_generative_model): + mock_model = Mock() + mock_candidate = Mock(content=Content(parts=[Part.from_text("This is a generated response.")], role="model")) + mock_response = MagicMock(spec=GenerationResponse, candidates=[mock_candidate]) + + mock_model.send_message.return_value = mock_response + mock_model.start_chat.return_value = mock_model + mock_generative_model.return_value = mock_model + + messages = [ + ChatMessage.from_system("You are a helpful assistant"), + ChatMessage.from_user("What's the capital of France?"), + ] + gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None) + gemini.run(messages=messages) + + mock_model.send_message.assert_called_once() + + +@patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") +def test_run_with_streaming_callback(mock_generative_model): + mock_model = Mock() + mock_responses = iter( + [MagicMock(spec=GenerationResponse, text="First part"), MagicMock(spec=GenerationResponse, text="Second part")] + ) + + mock_model.send_message.return_value = mock_responses + mock_model.start_chat.return_value = mock_model + mock_generative_model.return_value = mock_model + + streaming_callback_called = [] + + def streaming_callback(chunk: StreamingChunk) -> None: + streaming_callback_called.append(chunk.content) + + gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None, streaming_callback=streaming_callback) + messages = [ + ChatMessage.from_system("You are a helpful assistant"), + ChatMessage.from_user("What's the capital of France?"), + ] + gemini.run(messages=messages) + + mock_model.send_message.assert_called_once() + assert streaming_callback_called == ["First part", "Second part"] + + +def test_serialization_deserialization_pipeline(): + + pipeline = Pipeline() + template = [ChatMessage.from_user("Translate to {{ target_language }}. 
Context: {{ snippet }}; Translation:")] + pipeline.add_component("prompt_builder", ChatPromptBuilder(template=template)) + pipeline.add_component("gemini", VertexAIGeminiChatGenerator(project_id="TestID123")) + pipeline.connect("prompt_builder.prompt", "gemini.messages") + + pipeline_dict = pipeline.to_dict() + + new_pipeline = Pipeline.from_dict(pipeline_dict) + assert new_pipeline == pipeline diff --git a/integrations/google_vertex/tests/test_gemini.py b/integrations/google_vertex/tests/test_gemini.py index 8d08e0859..bb96ec409 100644 --- a/integrations/google_vertex/tests/test_gemini.py +++ b/integrations/google_vertex/tests/test_gemini.py @@ -1,7 +1,9 @@ from unittest.mock import MagicMock, Mock, patch +from haystack import Pipeline +from haystack.components.builders import PromptBuilder from haystack.dataclasses import StreamingChunk -from vertexai.preview.generative_models import ( +from vertexai.generative_models import ( FunctionDeclaration, GenerationConfig, HarmBlockThreshold, @@ -191,18 +193,18 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): "function_declarations": [ { "name": "get_current_weather", - "description": "Get the current weather in a given location", "parameters": { "type_": "OBJECT", "properties": { + "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, "location": { "type_": "STRING", "description": "The city and state, e.g. San Francisco, CA", }, - "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, }, "required": ["location"], }, + "description": "Get the current weather in a given location", } ] } @@ -254,3 +256,20 @@ def streaming_callback(_chunk: StreamingChunk) -> None: gemini = VertexAIGeminiGenerator(model="gemini-pro", project_id="TestID123", streaming_callback=streaming_callback) gemini.run(["Come on, stream!"]) assert streaming_callback_called + + +def test_serialization_deserialization_pipeline(): + template = """ + Answer the following questions: + 1. What is the weather like today? 
+ """ + pipeline = Pipeline() + + pipeline.add_component("prompt_builder", PromptBuilder(template=template)) + pipeline.add_component("gemini", VertexAIGeminiGenerator(project_id="TestID123")) + pipeline.connect("prompt_builder", "gemini") + + pipeline_dict = pipeline.to_dict() + + new_pipeline = Pipeline.from_dict(pipeline_dict) + assert new_pipeline == pipeline From ec236b8618524e04586bf7de77ccef8c95887f5f Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Fri, 6 Sep 2024 21:41:10 -0700 Subject: [PATCH 02/33] fix: Weaviate - fix connection issues with some WCS URLs (#1058) * try fix * small fix in docstring * simplify condition * better condition and comments * only lint on 3.9 like other integrations --- .github/workflows/weaviate.yml | 1 + .../document_stores/weaviate/document_store.py | 13 +++++++------ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/.github/workflows/weaviate.yml b/.github/workflows/weaviate.yml index 5e29eafe7..06a4bc289 100644 --- a/.github/workflows/weaviate.yml +++ b/.github/workflows/weaviate.yml @@ -44,6 +44,7 @@ jobs: run: pip install --upgrade hatch - name: Lint + if: matrix.python-version == '3.9' && runner.os == 'Linux' run: hatch run lint:all - name: Run Weaviate container diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py index 82088dd89..09e0a673d 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py @@ -68,7 +68,7 @@ class WeaviateDocumentStore: from haystack_integrations.document_stores.weaviate.auth import AuthApiKey from haystack_integrations.document_stores.weaviate.document_store import WeaviateDocumentStore - os.environ["WEAVIATE_API_KEY"] = "MY_API_KEY + os.environ["WEAVIATE_API_KEY"] = "MY_API_KEY" document_store = WeaviateDocumentStore( url="rAnD0mD1g1t5.something.weaviate.cloud", @@ -172,17 +172,18 @@ def client(self): if self._client: return self._client - if self._url and self._url.startswith("http") and self._url.endswith(".weaviate.network"): - # We use this utility function instead of using WeaviateClient directly like in other cases - # otherwise we'd have to parse the URL to get some information about the connection. - # This utility function does all that for us. - self._client = weaviate.connect_to_wcs( + if self._url and self._url.endswith((".weaviate.network", ".weaviate.cloud")): + # If we detect that the URL is a Weaviate Cloud URL, we use the utility function to connect + # instead of using WeaviateClient directly like in other cases. + # Among other things, the utility function takes care of parsing the URL. + self._client = weaviate.connect_to_weaviate_cloud( self._url, auth_credentials=self._auth_client_secret.resolve_value() if self._auth_client_secret else None, headers=self._additional_headers, additional_config=self._additional_config, ) else: + # Embedded, local Docker deployment or custom connection. 
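+            # For example, a local Docker instance is typically reachable at http://localhost:8080.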
# proxies, timeout_config, trust_env are part of additional_config now # startup_period has been removed self._client = weaviate.WeaviateClient( From cb449364e43ce8639e2cfb086ea73c8a89a656af Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Sat, 7 Sep 2024 04:42:10 +0000 Subject: [PATCH 03/33] Update the changelog --- integrations/weaviate/CHANGELOG.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/integrations/weaviate/CHANGELOG.md b/integrations/weaviate/CHANGELOG.md index bddde1b7d..d934826bc 100644 --- a/integrations/weaviate/CHANGELOG.md +++ b/integrations/weaviate/CHANGELOG.md @@ -1,6 +1,6 @@ # Changelog -## [unreleased] +## [integrations/weaviate-v2.2.1] - 2024-09-07 ### ๐Ÿš€ Features @@ -10,6 +10,12 @@ - Weaviate filter error (#811) - Fix connection to Weaviate Cloud Service (#624) +- Pin weaviate-client (#1046) +- Weaviate - fix connection issues with some WCS URLs (#1058) + +### ๐Ÿงช Testing + +- Do not retry tests in `hatch run test` command (#954) ### โš™๏ธ Miscellaneous Tasks From 79217e802ec391fc54b469ad65762ac6b6ca145c Mon Sep 17 00:00:00 2001 From: Ulises M <30765968+lbux@users.noreply.github.com> Date: Sat, 7 Sep 2024 11:53:02 -0700 Subject: [PATCH 04/33] refactor!: use ollama python library instead of calling the API with `requests` (#1059) * switch from api to python library * minor fixes, remove template as not applicable * expect proper error * fix typing issues * client as internal attr * lint * remove requests from deps * impr readme --------- Co-authored-by: anakin87 --- README.md | 2 +- integrations/ollama/README.md | 2 +- .../ollama/examples/chat_generator_example.py | 2 +- integrations/ollama/pyproject.toml | 4 +- .../embedders/ollama/document_embedder.py | 18 ++---- .../embedders/ollama/text_embedder.py | 24 +++----- .../generators/ollama/chat/chat_generator.py | 56 ++++++------------- .../components/generators/ollama/generator.py | 51 +++++------------ .../ollama/tests/test_chat_generator.py | 34 +++-------- .../ollama/tests/test_document_embedder.py | 10 ++-- integrations/ollama/tests/test_generator.py | 52 ++--------------- .../ollama/tests/test_text_embedder.py | 10 ++-- 12 files changed, 71 insertions(+), 194 deletions(-) diff --git a/README.md b/README.md index c4178184b..7ba853d62 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ Please check out our [Contribution Guidelines](CONTRIBUTING.md) for all the deta | [mistral-haystack](integrations/mistral/) | Embedder, Generator | [![PyPI - Version](https://img.shields.io/pypi/v/mistral-haystack.svg)](https://pypi.org/project/mistral-haystack) | [![Test / mistral](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/mistral.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/mistral.yml) | | [mongodb-atlas-haystack](integrations/mongodb_atlas/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/mongodb-atlas-haystack.svg?color=orange)](https://pypi.org/project/mongodb-atlas-haystack) | [![Test / mongodb-atlas](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/mongodb_atlas.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/mongodb_atlas.yml) | | [nvidia-haystack](integrations/nvidia/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/nvidia-haystack.svg?color=orange)](https://pypi.org/project/nvidia-haystack) | [![Test / 
nvidia](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/nvidia.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/nvidia.yml) | -| [ollama-haystack](integrations/ollama/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/ollama-haystack.svg?color=orange)](https://pypi.org/project/ollama-haystack) | [![Test / ollama](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/ollama.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/ollama.yml) | +| [ollama-haystack](integrations/ollama/) | Embedder, Generator | [![PyPI - Version](https://img.shields.io/pypi/v/ollama-haystack.svg?color=orange)](https://pypi.org/project/ollama-haystack) | [![Test / ollama](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/ollama.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/ollama.yml) | | [opensearch-haystack](integrations/opensearch/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/opensearch-haystack.svg)](https://pypi.org/project/opensearch-haystack) | [![Test / opensearch](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/opensearch.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/opensearch.yml) | | [optimum-haystack](integrations/optimum/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/optimum-haystack.svg)](https://pypi.org/project/optimum-haystack) | [![Test / optimum](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/optimum.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/optimum.yml) | | [pinecone-haystack](integrations/pinecone/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/pinecone-haystack.svg?color=orange)](https://pypi.org/project/pinecone-haystack) | [![Test / pinecone](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/pinecone.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/pinecone.yml) | diff --git a/integrations/ollama/README.md b/integrations/ollama/README.md index c842cddf1..a8ec1d526 100644 --- a/integrations/ollama/README.md +++ b/integrations/ollama/README.md @@ -36,4 +36,4 @@ Then run tests: hatch run test ``` -The default model used here is ``orca-mini`` \ No newline at end of file +The default model used here is ``orca-mini`` for generation and ``nomic-embed-text`` for embeddings \ No newline at end of file diff --git a/integrations/ollama/examples/chat_generator_example.py b/integrations/ollama/examples/chat_generator_example.py index 2326ba708..3dfd01065 100644 --- a/integrations/ollama/examples/chat_generator_example.py +++ b/integrations/ollama/examples/chat_generator_example.py @@ -17,7 +17,7 @@ ), ChatMessage.from_user("How do I get started?"), ] -client = OllamaChatGenerator(model="orca-mini", timeout=45, url="http://localhost:11434/api/chat") +client = OllamaChatGenerator(model="orca-mini", timeout=45, url="http://localhost:11434") response = client.run(messages, generation_kwargs={"temperature": 0.2}) diff --git a/integrations/ollama/pyproject.toml b/integrations/ollama/pyproject.toml index 57aee153b..1174d3b78 100644 --- a/integrations/ollama/pyproject.toml +++ b/integrations/ollama/pyproject.toml @@ -27,7 +27,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming 
Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "requests"] +dependencies = ["haystack-ai", "ollama"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/ollama#readme" @@ -161,5 +161,5 @@ markers = [ addopts = ["--import-mode=importlib"] [[tool.mypy.overrides]] -module = ["haystack.*", "haystack_integrations.*", "pytest.*"] +module = ["haystack.*", "haystack_integrations.*", "pytest.*", "ollama.*"] ignore_missing_imports = true diff --git a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py index b5783c611..ac8f38f35 100644 --- a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py +++ b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py @@ -1,9 +1,10 @@ from typing import Any, Dict, List, Optional -import requests from haystack import Document, component from tqdm import tqdm +from ollama import Client + @component class OllamaDocumentEmbedder: @@ -27,7 +28,7 @@ class OllamaDocumentEmbedder: def __init__( self, model: str = "nomic-embed-text", - url: str = "http://localhost:11434/api/embeddings", + url: str = "http://localhost:11434", generation_kwargs: Optional[Dict[str, Any]] = None, timeout: int = 120, prefix: str = "", @@ -40,7 +41,7 @@ def __init__( :param model: The name of the model to use. The model should be available in the running Ollama instance. :param url: - The URL of the chat endpoint of a running Ollama instance. + The URL of a running Ollama instance. :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature, top_p, and others. 
See the available arguments in @@ -59,11 +60,7 @@ def __init__( self.suffix = suffix self.prefix = prefix - def _create_json_payload(self, text: str, generation_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]: - """ - Returns A dictionary of JSON arguments for a POST request to an Ollama service - """ - return {"model": self.model, "prompt": text, "options": {**self.generation_kwargs, **(generation_kwargs or {})}} + self._client = Client(host=self.url, timeout=self.timeout) def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: """ @@ -103,10 +100,7 @@ def _embed_batch( range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings" ): batch = texts_to_embed[i] # Single batch only - payload = self._create_json_payload(batch, generation_kwargs) - response = requests.post(url=self.url, json=payload, timeout=self.timeout) - response.raise_for_status() - result = response.json() + result = self._client.embeddings(model=self.model, prompt=batch, options=generation_kwargs) all_embeddings.append(result["embedding"]) meta["model"] = self.model diff --git a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/text_embedder.py b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/text_embedder.py index 5a28ba393..7779c6d6e 100644 --- a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/text_embedder.py +++ b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/text_embedder.py @@ -1,8 +1,9 @@ from typing import Any, Dict, List, Optional -import requests from haystack import component +from ollama import Client + @component class OllamaTextEmbedder: @@ -23,7 +24,7 @@ class OllamaTextEmbedder: def __init__( self, model: str = "nomic-embed-text", - url: str = "http://localhost:11434/api/embeddings", + url: str = "http://localhost:11434", generation_kwargs: Optional[Dict[str, Any]] = None, timeout: int = 120, ): @@ -31,7 +32,7 @@ def __init__( :param model: The name of the model to use. The model should be available in the running Ollama instance. :param url: - The URL of the chat endpoint of a running Ollama instance. + The URL of a running Ollama instance. :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature, top_p, and others. 
See the available arguments in @@ -44,11 +45,7 @@ def __init__( self.url = url self.model = model - def _create_json_payload(self, text: str, generation_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]: - """ - Returns A dictionary of JSON arguments for a POST request to an Ollama service - """ - return {"model": self.model, "prompt": text, "options": {**self.generation_kwargs, **(generation_kwargs or {})}} + self._client = Client(host=self.url, timeout=self.timeout) @component.output_types(embedding=List[float], meta=Dict[str, Any]) def run(self, text: str, generation_kwargs: Optional[Dict[str, Any]] = None): @@ -65,14 +62,7 @@ def run(self, text: str, generation_kwargs: Optional[Dict[str, Any]] = None): - `embedding`: The computed embeddings - `meta`: The metadata collected during the embedding process """ - - payload = self._create_json_payload(text, generation_kwargs) - - response = requests.post(url=self.url, json=payload, timeout=self.timeout) - - response.raise_for_status() - - result = response.json() - result["meta"] = {"model": self.model, "duration": response.elapsed} + result = self._client.embeddings(model=self.model, prompt=text, options=generation_kwargs) + result["meta"] = {"model": self.model} return result diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py index a95d8c4fb..1f3a0bf1e 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py @@ -1,10 +1,9 @@ -import json from typing import Any, Callable, Dict, List, Optional -import requests from haystack import component from haystack.dataclasses import ChatMessage, StreamingChunk -from requests import Response + +from ollama import Client @component @@ -19,7 +18,7 @@ class OllamaChatGenerator: from haystack.dataclasses import ChatMessage generator = OllamaChatGenerator(model="zephyr", - url = "http://localhost:11434/api/chat", + url = "http://localhost:11434", generation_kwargs={ "num_predict": 100, "temperature": 0.9, @@ -35,9 +34,8 @@ class OllamaChatGenerator: def __init__( self, model: str = "orca-mini", - url: str = "http://localhost:11434/api/chat", + url: str = "http://localhost:11434", generation_kwargs: Optional[Dict[str, Any]] = None, - template: Optional[str] = None, timeout: int = 120, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): @@ -45,13 +43,11 @@ def __init__( :param model: The name of the model to use. The model should be available in the running Ollama instance. :param url: - The URL of the chat endpoint of a running Ollama instance. + The URL of a running Ollama instance. :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature, top_p, and others. See the available arguments in [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values). - :param template: - The full prompt template (overrides what is defined in the Ollama Modelfile). :param timeout: The number of seconds before throwing a timeout error from the Ollama API. 
:param streaming_callback: @@ -60,35 +56,22 @@ def __init__( """ self.timeout = timeout - self.template = template self.generation_kwargs = generation_kwargs or {} self.url = url self.model = model self.streaming_callback = streaming_callback + self._client = Client(host=self.url, timeout=self.timeout) + def _message_to_dict(self, message: ChatMessage) -> Dict[str, str]: return {"role": message.role.value, "content": message.content} - def _create_json_payload(self, messages: List[ChatMessage], stream=False, generation_kwargs=None) -> Dict[str, Any]: - """ - Returns A dictionary of JSON arguments for a POST request to an Ollama service - """ - generation_kwargs = generation_kwargs or {} - return { - "messages": [self._message_to_dict(message) for message in messages], - "model": self.model, - "stream": stream, - "template": self.template, - "options": generation_kwargs, - } - - def _build_message_from_ollama_response(self, ollama_response: Response) -> ChatMessage: + def _build_message_from_ollama_response(self, ollama_response: Dict[str, Any]) -> ChatMessage: """ Converts the non-streaming response from the Ollama API to a ChatMessage. """ - json_content = ollama_response.json() - message = ChatMessage.from_assistant(content=json_content["message"]["content"]) - message.meta.update({key: value for key, value in json_content.items() if key != "message"}) + message = ChatMessage.from_assistant(content=ollama_response["message"]["content"]) + message.meta.update({key: value for key, value in ollama_response.items() if key != "message"}) return message def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]: @@ -105,11 +88,9 @@ def _build_chunk(self, chunk_response: Any) -> StreamingChunk: """ Converts the response from the Ollama API to a StreamingChunk. 
""" - decoded_chunk = json.loads(chunk_response.decode("utf-8")) - - content = decoded_chunk["message"]["content"] - meta = {key: value for key, value in decoded_chunk.items() if key != "message"} - meta["role"] = decoded_chunk["message"]["role"] + content = chunk_response["message"]["content"] + meta = {key: value for key, value in chunk_response.items() if key != "message"} + meta["role"] = chunk_response["message"]["role"] chunk_message = StreamingChunk(content, meta) return chunk_message @@ -119,7 +100,7 @@ def _handle_streaming_response(self, response) -> List[StreamingChunk]: Handles Streaming response cases """ chunks: List[StreamingChunk] = [] - for chunk in response.iter_lines(): + for chunk in response: chunk_delta: StreamingChunk = self._build_chunk(chunk) chunks.append(chunk_delta) if self.streaming_callback is not None: @@ -149,13 +130,8 @@ def run( generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} stream = self.streaming_callback is not None - - json_payload = self._create_json_payload(messages, stream, generation_kwargs) - - response = requests.post(url=self.url, json=json_payload, timeout=self.timeout, stream=stream) - - # throw error on unsuccessful response - response.raise_for_status() + messages = [self._message_to_dict(message) for message in messages] + response = self._client.chat(model=self.model, messages=messages, stream=stream, options=generation_kwargs) if stream: chunks: List[StreamingChunk] = self._handle_streaming_response(response) diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py index 50c65b650..d92932c3e 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py @@ -1,11 +1,10 @@ -import json from typing import Any, Callable, Dict, List, Optional -import requests from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import StreamingChunk from haystack.utils.callable_serialization import deserialize_callable, serialize_callable -from requests import Response + +from ollama import Client @component @@ -18,7 +17,7 @@ class OllamaGenerator: from haystack_integrations.components.generators.ollama import OllamaGenerator generator = OllamaGenerator(model="zephyr", - url = "http://localhost:11434/api/generate", + url = "http://localhost:11434", generation_kwargs={ "num_predict": 100, "temperature": 0.9, @@ -31,7 +30,7 @@ class OllamaGenerator: def __init__( self, model: str = "orca-mini", - url: str = "http://localhost:11434/api/generate", + url: str = "http://localhost:11434", generation_kwargs: Optional[Dict[str, Any]] = None, system_prompt: Optional[str] = None, template: Optional[str] = None, @@ -43,7 +42,7 @@ def __init__( :param model: The name of the model to use. The model should be available in the running Ollama instance. :param url: - The URL of the generation endpoint of a running Ollama instance. + The URL of a running Ollama instance. :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature, top_p, and others. 
See the available arguments in @@ -70,6 +69,8 @@ def __init__( self.generation_kwargs = generation_kwargs or {} self.streaming_callback = streaming_callback + self._client = Client(host=self.url, timeout=self.timeout) + def to_dict(self) -> Dict[str, Any]: """ Serializes the component to a dictionary. @@ -106,30 +107,13 @@ def from_dict(cls, data: Dict[str, Any]) -> "OllamaGenerator": data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) return default_from_dict(cls, data) - def _create_json_payload(self, prompt: str, stream: bool, generation_kwargs=None) -> Dict[str, Any]: - """ - Returns a dictionary of JSON arguments for a POST request to an Ollama service. - """ - generation_kwargs = generation_kwargs or {} - return { - "prompt": prompt, - "model": self.model, - "stream": stream, - "raw": self.raw, - "template": self.template, - "system": self.system_prompt, - "options": generation_kwargs, - } - - def _convert_to_response(self, ollama_response: Response) -> Dict[str, List[Any]]: + def _convert_to_response(self, ollama_response: Dict[str, Any]) -> Dict[str, List[Any]]: """ Converts a response from the Ollama API to the required Haystack format. """ - resp_dict = ollama_response.json() - - replies = [resp_dict["response"]] - meta = {key: value for key, value in resp_dict.items() if key != "response"} + replies = [ollama_response["response"]] + meta = {key: value for key, value in ollama_response.items() if key != "response"} return {"replies": replies, "meta": [meta]} @@ -148,7 +132,7 @@ def _handle_streaming_response(self, response) -> List[StreamingChunk]: Handles Streaming response cases """ chunks: List[StreamingChunk] = [] - for chunk in response.iter_lines(): + for chunk in response: chunk_delta: StreamingChunk = self._build_chunk(chunk) chunks.append(chunk_delta) if self.streaming_callback is not None: @@ -159,10 +143,8 @@ def _build_chunk(self, chunk_response: Any) -> StreamingChunk: """ Converts the response from the Ollama API to a StreamingChunk. 
""" - decoded_chunk = json.loads(chunk_response.decode("utf-8")) - - content = decoded_chunk["response"] - meta = {key: value for key, value in decoded_chunk.items() if key != "response"} + content = chunk_response["response"] + meta = {key: value for key, value in chunk_response.items() if key != "response"} chunk_message = StreamingChunk(content, meta) return chunk_message @@ -190,12 +172,7 @@ def run( stream = self.streaming_callback is not None - json_payload = self._create_json_payload(prompt, stream, generation_kwargs) - - response = requests.post(url=self.url, json=json_payload, timeout=self.timeout, stream=stream) - - # throw error on unsuccessful response - response.raise_for_status() + response = self._client.generate(model=self.model, prompt=prompt, stream=stream, options=generation_kwargs) if stream: chunks: List[StreamingChunk] = self._handle_streaming_response(response) diff --git a/integrations/ollama/tests/test_chat_generator.py b/integrations/ollama/tests/test_chat_generator.py index dd4e746aa..79d70675a 100644 --- a/integrations/ollama/tests/test_chat_generator.py +++ b/integrations/ollama/tests/test_chat_generator.py @@ -3,7 +3,7 @@ import pytest from haystack.dataclasses import ChatMessage, ChatRole -from requests import HTTPError, Response +from ollama._types import ResponseError from haystack_integrations.components.generators.ollama import OllamaChatGenerator @@ -22,47 +22,27 @@ class TestOllamaChatGenerator: def test_init_default(self): component = OllamaChatGenerator() assert component.model == "orca-mini" - assert component.url == "http://localhost:11434/api/chat" + assert component.url == "http://localhost:11434" assert component.generation_kwargs == {} - assert component.template is None assert component.timeout == 120 def test_init(self): component = OllamaChatGenerator( model="llama2", - url="http://my-custom-endpoint:11434/api/chat", + url="http://my-custom-endpoint:11434", generation_kwargs={"temperature": 0.5}, timeout=5, ) assert component.model == "llama2" - assert component.url == "http://my-custom-endpoint:11434/api/chat" + assert component.url == "http://my-custom-endpoint:11434" assert component.generation_kwargs == {"temperature": 0.5} - assert component.template is None assert component.timeout == 5 - def test_create_json_payload(self, chat_messages): - observed = OllamaChatGenerator(model="some_model")._create_json_payload( - chat_messages, False, {"temperature": 0.1} - ) - expected = { - "messages": [ - {"role": "user", "content": "Tell me about why Super Mario is the greatest superhero"}, - {"role": "assistant", "content": "Super Mario has prevented Bowser from destroying the world"}, - ], - "model": "some_model", - "stream": False, - "template": None, - "options": {"temperature": 0.1}, - } - - assert observed == expected - def test_build_message_from_ollama_response(self): model = "some_model" - mock_ollama_response = Mock(Response) - mock_ollama_response.json.return_value = { + ollama_response = { "model": model, "created_at": "2023-12-12T14:13:43.416799Z", "message": {"role": "assistant", "content": "Hello! How are you today?"}, @@ -75,7 +55,7 @@ def test_build_message_from_ollama_response(self): "eval_duration": 4799921000, } - observed = OllamaChatGenerator(model=model)._build_message_from_ollama_response(mock_ollama_response) + observed = OllamaChatGenerator(model=model)._build_message_from_ollama_response(ollama_response) assert observed.role == "assistant" assert observed.content == "Hello! How are you today?" 
@@ -123,7 +103,7 @@ def test_run_with_chat_history(self): def test_run_model_unavailable(self): component = OllamaChatGenerator(model="Alistair_and_Stefano_are_great") - with pytest.raises(HTTPError): + with pytest.raises(ResponseError): message = ChatMessage.from_user( "Based on your infinite wisdom, can you tell me why Alistair and Stefano are so great?" ) diff --git a/integrations/ollama/tests/test_document_embedder.py b/integrations/ollama/tests/test_document_embedder.py index 0f5b55881..4fe3cfbb3 100644 --- a/integrations/ollama/tests/test_document_embedder.py +++ b/integrations/ollama/tests/test_document_embedder.py @@ -1,6 +1,6 @@ import pytest from haystack import Document -from requests import HTTPError +from ollama._types import ResponseError from haystack_integrations.components.embedders.ollama import OllamaDocumentEmbedder @@ -11,27 +11,27 @@ def test_init_defaults(self): assert embedder.timeout == 120 assert embedder.generation_kwargs == {} - assert embedder.url == "http://localhost:11434/api/embeddings" + assert embedder.url == "http://localhost:11434" assert embedder.model == "nomic-embed-text" def test_init(self): embedder = OllamaDocumentEmbedder( model="nomic-embed-text", - url="http://my-custom-endpoint:11434/api/embeddings", + url="http://my-custom-endpoint:11434", generation_kwargs={"temperature": 0.5}, timeout=3000, ) assert embedder.timeout == 3000 assert embedder.generation_kwargs == {"temperature": 0.5} - assert embedder.url == "http://my-custom-endpoint:11434/api/embeddings" + assert embedder.url == "http://my-custom-endpoint:11434" assert embedder.model == "nomic-embed-text" @pytest.mark.integration def test_model_not_found(self): embedder = OllamaDocumentEmbedder(model="cheese") - with pytest.raises(HTTPError): + with pytest.raises(ResponseError): embedder.run([Document("hello")]) @pytest.mark.integration diff --git a/integrations/ollama/tests/test_generator.py b/integrations/ollama/tests/test_generator.py index 069bbd227..c4c6906db 100644 --- a/integrations/ollama/tests/test_generator.py +++ b/integrations/ollama/tests/test_generator.py @@ -2,12 +2,10 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Any - import pytest from haystack.components.generators.utils import print_streaming_chunk from haystack.dataclasses import StreamingChunk -from requests import HTTPError +from ollama._types import ResponseError from haystack_integrations.components.generators.ollama import OllamaGenerator @@ -35,13 +33,13 @@ def test_run_capital_cities(self): def test_run_model_unavailable(self): component = OllamaGenerator(model="Alistair_is_great") - with pytest.raises(HTTPError): + with pytest.raises(ResponseError): component.run(prompt="Why is Alistair so great?") def test_init_default(self): component = OllamaGenerator() assert component.model == "orca-mini" - assert component.url == "http://localhost:11434/api/generate" + assert component.url == "http://localhost:11434" assert component.generation_kwargs == {} assert component.system_prompt is None assert component.template is None @@ -55,14 +53,14 @@ def callback(x: StreamingChunk): component = OllamaGenerator( model="llama2", - url="http://my-custom-endpoint:11434/api/generate", + url="http://my-custom-endpoint:11434", generation_kwargs={"temperature": 0.5}, system_prompt="You are Luigi from Super Mario Bros.", timeout=5, streaming_callback=callback, ) assert component.model == "llama2" - assert component.url == "http://my-custom-endpoint:11434/api/generate" + assert component.url == 
"http://my-custom-endpoint:11434" assert component.generation_kwargs == {"temperature": 0.5} assert component.system_prompt == "You are Luigi from Super Mario Bros." assert component.template is None @@ -80,7 +78,7 @@ def callback(x: StreamingChunk): "template": None, "system_prompt": None, "model": "orca-mini", - "url": "http://localhost:11434/api/generate", + "url": "http://localhost:11434", "streaming_callback": None, "generation_kwargs": {}, }, @@ -128,44 +126,6 @@ def test_from_dict(self): assert component.url == "going_to_51_pegasi_b_for_weekend" assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} - @pytest.mark.parametrize( - "configuration", - [ - { - "model": "some_model", - "url": "https://localhost:11434/api/generate", - "raw": True, - "system_prompt": "You are mario from Super Mario Bros.", - "template": None, - }, - { - "model": "some_model2", - "url": "https://localhost:11434/api/generate", - "raw": False, - "system_prompt": None, - "template": "some template", - }, - ], - ) - @pytest.mark.parametrize("stream", [True, False]) - def test_create_json_payload(self, configuration: dict[str, Any], stream: bool): - prompt = "hello" - component = OllamaGenerator(**configuration) - - observed = component._create_json_payload(prompt=prompt, stream=stream) - - expected = { - "prompt": prompt, - "model": configuration["model"], - "stream": stream, - "system": configuration["system_prompt"], - "raw": configuration["raw"], - "template": configuration["template"], - "options": {}, - } - - assert observed == expected - @pytest.mark.integration def test_ollama_generator_run_streaming(self): class Callback: diff --git a/integrations/ollama/tests/test_text_embedder.py b/integrations/ollama/tests/test_text_embedder.py index e7b69460f..d0f74c377 100644 --- a/integrations/ollama/tests/test_text_embedder.py +++ b/integrations/ollama/tests/test_text_embedder.py @@ -1,5 +1,5 @@ import pytest -from requests import HTTPError +from ollama._types import ResponseError from haystack_integrations.components.embedders.ollama import OllamaTextEmbedder @@ -10,27 +10,27 @@ def test_init_defaults(self): assert embedder.timeout == 120 assert embedder.generation_kwargs == {} - assert embedder.url == "http://localhost:11434/api/embeddings" + assert embedder.url == "http://localhost:11434" assert embedder.model == "nomic-embed-text" def test_init(self): embedder = OllamaTextEmbedder( model="llama2", - url="http://my-custom-endpoint:11434/api/embeddings", + url="http://my-custom-endpoint:11434", generation_kwargs={"temperature": 0.5}, timeout=3000, ) assert embedder.timeout == 3000 assert embedder.generation_kwargs == {"temperature": 0.5} - assert embedder.url == "http://my-custom-endpoint:11434/api/embeddings" + assert embedder.url == "http://my-custom-endpoint:11434" assert embedder.model == "llama2" @pytest.mark.integration def test_model_not_found(self): embedder = OllamaTextEmbedder(model="cheese") - with pytest.raises(HTTPError): + with pytest.raises(ResponseError): embedder.run("hello") @pytest.mark.integration From 647f668e7aa69c109a76e90213f6c80e3735b4ee Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Sat, 7 Sep 2024 18:54:20 +0000 Subject: [PATCH 05/33] Update the changelog --- integrations/ollama/CHANGELOG.md | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/integrations/ollama/CHANGELOG.md b/integrations/ollama/CHANGELOG.md index 6467aa868..8f51237e9 100644 --- a/integrations/ollama/CHANGELOG.md +++ b/integrations/ollama/CHANGELOG.md @@ -1,5 
+1,24 @@ # Changelog +## [integrations/ollama-v1.0.0] - 2024-09-07 + +### ๐Ÿ› Bug Fixes + +- Chat roles for model responses in chat generators (#1030) + +### ๐Ÿšœ Refactor + +- [**breaking**] Use ollama python library instead of calling the API with `requests` (#1059) + +### ๐Ÿงช Testing + +- Do not retry tests in `hatch run test` command (#954) + +### โš™๏ธ Miscellaneous Tasks + +- Retry tests to reduce flakyness (#836) +- Update ruff invocation to include check parameter (#853) + ## [integrations/ollama-v0.0.7] - 2024-05-31 ### ๐Ÿš€ Features From 2fd6d1a9643c860ae1b5eb04bf6cb1745529d13a Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Mon, 9 Sep 2024 08:43:19 +0100 Subject: [PATCH 06/33] ChromaDocumentStore lint fix (#1065) --- .../document_stores/chroma/document_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py index 3ea84780f..addcba296 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py @@ -453,7 +453,7 @@ def _query_result_to_documents(result: QueryResult) -> List[List[Document]]: for j in range(len(answers)): document_dict: Dict[str, Any] = { "id": result["ids"][i][j], - "content": documents[i][j], + "content": answers[j], } # prepare metadata From ee61033753b7dbd3a366a6d833d2f253e5b650c3 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Mon, 9 Sep 2024 01:14:03 -0700 Subject: [PATCH 07/33] fix chroma linting; rm numpy (#1063) Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> --- .../document_stores/chroma/document_store.py | 3 +-- integrations/chroma/tests/test_document_store.py | 7 ++++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py index addcba296..3353ed5aa 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py @@ -6,7 +6,6 @@ from typing import Any, Dict, List, Literal, Optional, Tuple import chromadb -import numpy as np from chromadb.api.types import GetResult, QueryResult, validate_where, validate_where_document from haystack import default_from_dict, default_to_dict from haystack.dataclasses import Document @@ -465,7 +464,7 @@ def _query_result_to_documents(result: QueryResult) -> List[List[Document]]: pass if embeddings := result.get("embeddings"): - document_dict["embedding"] = np.array(embeddings[i][j]) + document_dict["embedding"] = embeddings[i][j] if distances := result.get("distances"): document_dict["score"] = distances[i][j] diff --git a/integrations/chroma/tests/test_document_store.py b/integrations/chroma/tests/test_document_store.py index d4b6ed272..5a7e12b3d 100644 --- a/integrations/chroma/tests/test_document_store.py +++ b/integrations/chroma/tests/test_document_store.py @@ -106,7 +106,12 @@ def test_search(self): # Assertions to verify correctness assert len(result) == 1 - assert result[0][0].content == "Third document" + doc = result[0][0] + assert doc.content == "Third document" + assert doc.meta == {"author": "Author2"} + assert doc.embedding + assert 
isinstance(doc.embedding, list) + assert all(isinstance(el, float) for el in doc.embedding) def test_write_documents_unsupported_meta_values(self, document_store: ChromaDocumentStore): """ From 484895e2e272a80a6367eed86a8b69fe23992bbb Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Tue, 10 Sep 2024 10:11:15 +0000 Subject: [PATCH 08/33] Update the changelog --- integrations/google_vertex/CHANGELOG.md | 58 +++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 integrations/google_vertex/CHANGELOG.md diff --git a/integrations/google_vertex/CHANGELOG.md b/integrations/google_vertex/CHANGELOG.md new file mode 100644 index 000000000..17a730b60 --- /dev/null +++ b/integrations/google_vertex/CHANGELOG.md @@ -0,0 +1,58 @@ +# Changelog + +## [unreleased] + +### ๐Ÿš€ Features + +- Enable streaming for VertexAIGeminiChatGenerator (#1014) +- Add tests for VertexAIGeminiGenerator and enable streaming (#1012) + +### ๐Ÿ› Bug Fixes + +- Remove the use of deprecated gemini models (#1032) +- Chat roles for model responses in chat generators (#1030) + +### ๐Ÿงช Testing + +- Do not retry tests in `hatch run test` command (#954) +- Add tests for VertexAIChatGeminiGenerator and migrate from preview package in vertexai (#1042) + +### โš™๏ธ Miscellaneous Tasks + +- Retry tests to reduce flakyness (#836) +- Update ruff invocation to include check parameter (#853) + +## [integrations/google_vertex-v1.1.0] - 2024-03-28 + +## [integrations/google_vertex-v1.0.0] - 2024-03-27 + +### ๐Ÿ› Bug Fixes + +- Fix order of API docs (#447) + +This PR will also push the docs to Readme + +### ๐Ÿ“š Documentation + +- Update category slug (#442) +- Review google vertex integration (#535) +- Small consistency improvements (#536) +- Disable-class-def (#556) + +### Google_vertex + +- Create api docs (#355) + +## [integrations/google_vertex-v0.2.0] - 2024-01-26 + +## [integrations/google_vertex-v0.1.0] - 2024-01-03 + +### ๐Ÿ› Bug Fixes + +- The default model of VertexAIImagegenerator (#158) + +### โš™๏ธ Miscellaneous Tasks + +- Replace - with _ (#114) + + From e0292f5a72d88d9f4f2d35b7074c6aa11e92c3ee Mon Sep 17 00:00:00 2001 From: Madeesh Kannan Date: Wed, 11 Sep 2024 12:27:35 +0200 Subject: [PATCH 09/33] fix: Add upper-bound pin to `ragas` dependency in `ragas-haystack` (#1076) --- integrations/ragas/pyproject.toml | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/integrations/ragas/pyproject.toml b/integrations/ragas/pyproject.toml index edc33eee1..d9ae6ca02 100644 --- a/integrations/ragas/pyproject.toml +++ b/integrations/ragas/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "ragas>=0.1.11"] +dependencies = ["haystack-ai", "ragas>=0.1.11,<=0.1.16"] [project.urls] Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/ragas" @@ -41,7 +41,13 @@ root = "../.." 
git_describe_command = 'git describe --tags --match="integrations/ragas-v[0-9]*"' [tool.hatch.envs.default] -dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "haystack-pydoc-tools", "pytest-asyncio"] +dependencies = [ + "coverage[toml]>=6.5", + "pytest", + "pytest-rerunfailures", + "haystack-pydoc-tools", + "pytest-asyncio", +] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" From 50519ef87b0462683e675212c81ea4196a14923d Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Wed, 11 Sep 2024 10:29:10 +0000 Subject: [PATCH 10/33] Update the changelog --- integrations/ragas/CHANGELOG.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/integrations/ragas/CHANGELOG.md b/integrations/ragas/CHANGELOG.md index 7055f1931..94946bddc 100644 --- a/integrations/ragas/CHANGELOG.md +++ b/integrations/ragas/CHANGELOG.md @@ -1,5 +1,15 @@ # Changelog +## [integrations/ragas-v1.0.1] - 2024-09-11 + +### ๐Ÿ› Bug Fixes + +- Add upper-bound pin to `ragas` dependency in `ragas-haystack` (#1076) + +### ๐Ÿงช Testing + +- Do not retry tests in `hatch run test` command (#954) + ## [integrations/ragas-v1.0.0] - 2024-07-24 ### โš™๏ธ Miscellaneous Tasks From 70e8356c43bee1a47dcc4aacf734de5e247541dc Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Wed, 11 Sep 2024 11:25:45 +0000 Subject: [PATCH 11/33] Update the changelog --- integrations/langfuse/CHANGELOG.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/integrations/langfuse/CHANGELOG.md b/integrations/langfuse/CHANGELOG.md index 2efa17a68..0a90a7121 100644 --- a/integrations/langfuse/CHANGELOG.md +++ b/integrations/langfuse/CHANGELOG.md @@ -1,5 +1,25 @@ # Changelog +## [unreleased] + +### ๐Ÿšœ Refactor + +- Remove usage of deprecated `ChatMessage.to_openai_format` (#1001) + +### ๐Ÿ“š Documentation + +- Add link to langfuse in LangfuseConnector (#981) + +### ๐Ÿงช Testing + +- Do not retry tests in `hatch run test` command (#954) + +### โš™๏ธ Miscellaneous Tasks + +- Retry tests to reduce flakyness (#836) +- `Langfuse` - replace DynamicChatPromptBuilder with ChatPromptBuilder (#925) +- Remove all `DynamicChatPromptBuilder` references in Langfuse integration (#931) + ## [integrations/langfuse-v0.2.0] - 2024-06-18 ## [integrations/langfuse-v0.1.0] - 2024-06-13 From 6daefc60c4127f4e25256aa12a10a4665fcaac37 Mon Sep 17 00:00:00 2001 From: Daria Fokina Date: Wed, 11 Sep 2024 14:28:43 +0200 Subject: [PATCH 12/33] update links (#1073) --- .../components/converters/unstructured/converter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/unstructured/src/haystack_integrations/components/converters/unstructured/converter.py b/integrations/unstructured/src/haystack_integrations/components/converters/unstructured/converter.py index 637c0840f..9230ecb0d 100644 --- a/integrations/unstructured/src/haystack_integrations/components/converters/unstructured/converter.py +++ b/integrations/unstructured/src/haystack_integrations/components/converters/unstructured/converter.py @@ -27,7 +27,7 @@ class UnstructuredFileConverter: A component for converting files to Haystack Documents using the Unstructured API (hosted or running locally). For the supported file types and the specific API parameters, see - [Unstructured docs](https://unstructured-io.github.io/unstructured/api.html). + [Unstructured docs](https://docs.unstructured.io/api-reference/api-services/overview). 
Usage example: ```python @@ -68,7 +68,7 @@ def __init__( :param separator: Separator between elements when concatenating them into one text field. :param unstructured_kwargs: Additional parameters that are passed to the Unstructured API. For the available parameters, see - [Unstructured API docs](https://unstructured-io.github.io/unstructured/apis/api_parameters.html). + [Unstructured API docs](https://docs.unstructured.io/api-reference/api-services/api-parameters). :param progress_bar: Whether to show a progress bar during the conversion. """ From e7de37e18f74be24cad664bacf75abd0e40f2636 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Wed, 11 Sep 2024 16:58:11 +0200 Subject: [PATCH 13/33] Update docstrings to remove vertex.preview (#1074) --- .../components/generators/google_vertex/chat/gemini.py | 8 ++++---- .../components/generators/google_vertex/gemini.py | 8 ++++---- .../generators/google_vertex/image_generator.py | 2 +- integrations/google_vertex/tests/test_image_generator.py | 2 +- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index 8cdb58d2d..e693c10f4 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -67,14 +67,14 @@ def __init__( :param location: The default location to use when making API calls, if not set uses us-central-1. Defaults to None. :param generation_config: Configuration for the generation process. - See the [GenerationConfig documentation](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.preview.generative_models.GenerationConfig + See the [GenerationConfig documentation](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.GenerationConfig for a list of supported arguments. :param safety_settings: Safety settings to use when generating content. See the documentation - for [HarmBlockThreshold](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.preview.generative_models.HarmBlockThreshold) - and [HarmCategory](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.preview.generative_models.HarmCategory) + for [HarmBlockThreshold](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.HarmBlockThreshold) + and [HarmCategory](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.HarmCategory) for more details. :param tools: List of tools to use when generating content. See the documentation for - [Tool](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.preview.generative_models.Tool) + [Tool](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.Tool) the list of supported arguments. :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. 
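Since the docstring only names the callback's argument type, here is a minimal sketch of a conforming callback, assuming nothing beyond the documented `StreamingChunk.content` attribute:

```python
# Minimal sketch of a streaming callback; assumes only the documented
# StreamingChunk.content attribute.
from haystack.dataclasses import StreamingChunk

def on_chunk(chunk: StreamingChunk) -> None:
    # Emit each token fragment as soon as the model streams it.
    print(chunk.content, end="", flush=True)
```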
diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py index 11592671f..7394211bf 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py @@ -70,7 +70,7 @@ def __init__( :param model: Name of the model to use. For available models, see https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models. :param location: The default location to use when making API calls, if not set uses us-central-1. :param generation_config: The generation config to use. - Can either be a [`GenerationConfig`](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.preview.generative_models.GenerationConfig) + Can either be a [`GenerationConfig`](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.GenerationConfig) object or a dictionary of parameters. Accepted fields are: - temperature @@ -80,11 +80,11 @@ def __init__( - max_output_tokens - stop_sequences :param safety_settings: The safety settings to use. See the documentation - for [HarmBlockThreshold](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.preview.generative_models.HarmBlockThreshold) - and [HarmCategory](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.preview.generative_models.HarmCategory) + for [HarmBlockThreshold](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.HarmBlockThreshold) + and [HarmCategory](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.HarmCategory) for more details. :param tools: List of tools to use when generating content. See the documentation for - [Tool](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.preview.generative_models.Tool) + [Tool](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.Tool) the list of supported arguments. :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. 
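The accepted-fields list in the docstring above maps directly onto the dictionary form of `generation_config`. A hedged example, with illustrative values:

```python
# Dictionary form of generation_config; keys follow the accepted-fields
# list in the docstring above, values are purely illustrative.
generation_config = {
    "temperature": 0.2,
    "top_p": 0.95,
    "top_k": 40,
    "candidate_count": 1,
    "max_output_tokens": 256,
    "stop_sequences": ["\n\n"],
}
```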
diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/image_generator.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/image_generator.py index ae8c4892f..0534a20f2 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/image_generator.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/image_generator.py @@ -5,7 +5,7 @@ from haystack.core.component import component from haystack.core.serialization import default_from_dict, default_to_dict from haystack.dataclasses.byte_stream import ByteStream -from vertexai.preview.vision_models import ImageGenerationModel +from vertexai.vision_models import ImageGenerationModel logger = logging.getLogger(__name__) diff --git a/integrations/google_vertex/tests/test_image_generator.py b/integrations/google_vertex/tests/test_image_generator.py index 42cc0a0a3..6cd42a11c 100644 --- a/integrations/google_vertex/tests/test_image_generator.py +++ b/integrations/google_vertex/tests/test_image_generator.py @@ -1,6 +1,6 @@ from unittest.mock import Mock, patch -from vertexai.preview.vision_models import ImageGenerationResponse +from vertexai.vision_models import ImageGenerationResponse from haystack_integrations.components.generators.google_vertex import VertexAIImageGenerator From 66d11b31290ac73a58deac587a179ba7e51a6b6e Mon Sep 17 00:00:00 2001 From: Kane Norman <51185594+kanenorman@users.noreply.github.com> Date: Thu, 12 Sep 2024 05:40:00 -0500 Subject: [PATCH 14/33] fix: AstraDocumentStore filter by id (#1053) --- .../haystack_integrations/document_stores/astra/filters.py | 4 ++-- integrations/astra/tests/test_document_store.py | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/filters.py b/integrations/astra/src/haystack_integrations/document_stores/astra/filters.py index 61f3e5402..340e95ba9 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/filters.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/filters.py @@ -30,8 +30,6 @@ def _convert_filters(filters: Optional[Dict[str, Any]] = None) -> Optional[Dict[ if key in {"$and", "$or"}: filter_statements[key] = value else: - if key == "id": - filter_statements[key] = {"_id": value} if key != "$in" and isinstance(value, list): filter_statements[key] = {"$in": value} elif isinstance(value, pd.DataFrame): @@ -45,6 +43,8 @@ def _convert_filters(filters: Optional[Dict[str, Any]] = None) -> Optional[Dict[ filter_statements[key] = converted else: filter_statements[key] = value + if key == "id": + filter_statements["_id"] = filter_statements.pop("id") return filter_statements diff --git a/integrations/astra/tests/test_document_store.py b/integrations/astra/tests/test_document_store.py index df181ad8c..c4d1b6347 100644 --- a/integrations/astra/tests/test_document_store.py +++ b/integrations/astra/tests/test_document_store.py @@ -200,6 +200,12 @@ def test_filter_documents_nested_filters(self, document_store, filterable_docs): ], ) + def test_filter_documents_by_id(self, document_store): + docs = [Document(id="1", content="test doc 1"), Document(id="2", content="test doc 2")] + document_store.write_documents(docs) + result = document_store.filter_documents(filters={"field": "id", "operator": "==", "value": "1"}) + self.assert_documents_are_equal(result, [docs[0]]) + 
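The Astra fix above works because value conversion must happen before Haystack's `id` key is renamed to Astra's `_id`. A self-contained illustration of the corrected ordering, with a made-up filter value:

```python
# Illustration of the corrected key handling: convert list values to
# "$in" form first, then rename "id" to "_id". The filter is made up.
filter_statements = {"id": ["1", "2"]}

for key, value in list(filter_statements.items()):
    if key != "$in" and isinstance(value, list):
        filter_statements[key] = {"$in": value}

if "id" in filter_statements:
    filter_statements["_id"] = filter_statements.pop("id")

print(filter_statements)  # {'_id': {'$in': ['1', '2']}}
```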
@pytest.mark.skip(reason="Unsupported filter operator not.") def test_not_operator(self, document_store, filterable_docs): pass From e1860d440a9f747bca661576d175d094d79c324d Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Thu, 12 Sep 2024 10:42:58 +0000 Subject: [PATCH 15/33] Update the changelog --- integrations/astra/CHANGELOG.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/integrations/astra/CHANGELOG.md b/integrations/astra/CHANGELOG.md index 55c22f540..79bb9e35d 100644 --- a/integrations/astra/CHANGELOG.md +++ b/integrations/astra/CHANGELOG.md @@ -1,5 +1,16 @@ # Changelog +## [integrations/astra-v0.9.3] - 2024-09-12 + +### ๐Ÿ› Bug Fixes + +- Astra DB, improved warnings and guidance about indexing-related mismatches (#932) +- AstraDocumentStore filter by id (#1053) + +### ๐Ÿงช Testing + +- Do not retry tests in `hatch run test` command (#954) + ## [integrations/astra-v0.9.2] - 2024-07-22 ## [integrations/astra-v0.9.1] - 2024-07-15 From 8bac5def887b61bd8aaae5420ba863fa31a90648 Mon Sep 17 00:00:00 2001 From: Corentin Meyer Date: Thu, 12 Sep 2024 13:11:11 +0200 Subject: [PATCH 16/33] feat: Qdrant - Add group_by and group_size optional parameters to Retrievers (#1054) * Qdrant: Add group_by and group_size optional parameters to Retrievers * Simplify ifs --------- Co-authored-by: Silvano Cerza --- .../components/retrievers/qdrant/retriever.py | 66 +++++- .../document_stores/qdrant/document_store.py | 214 +++++++++++++----- .../qdrant/tests/test_document_store.py | 33 +++ integrations/qdrant/tests/test_retriever.py | 103 +++++++++ 4 files changed, 355 insertions(+), 61 deletions(-) diff --git a/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py b/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py index 408b2458a..fee9a6182 100644 --- a/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py +++ b/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py @@ -44,13 +44,16 @@ def __init__( return_embedding: bool = False, filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, score_threshold: Optional[float] = None, + group_by: Optional[str] = None, + group_size: Optional[int] = None, ): """ Create a QdrantEmbeddingRetriever component. :param document_store: An instance of QdrantDocumentStore. :param filters: A dictionary with filters to narrow down the search space. - :param top_k: The maximum number of documents to retrieve. + :param top_k: The maximum number of documents to retrieve. If using `group_by` parameters, maximum number of + groups to return. :param scale_score: Whether to scale the scores of the retrieved documents or not. :param return_embedding: Whether to return the embedding of the retrieved Documents. :param filter_policy: Policy to determine how filters are applied. @@ -58,6 +61,9 @@ def __init__( Score of the returned result might be higher or smaller than the threshold depending on the `similarity` function specified in the Document Store. E.g. for cosine similarity only higher scores will be returned. + :param group_by: Payload field to group by, must be a string or number field. If the field contains more than 1 + value, all values will be used for grouping. One point can be in multiple groups. + :param group_size: Maximum amount of points to return per group. Default is 3. :raises ValueError: If `document_store` is not an instance of `QdrantDocumentStore`. 
""" @@ -75,6 +81,8 @@ def __init__( filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) ) self._score_threshold = score_threshold + self._group_by = group_by + self._group_size = group_size def to_dict(self) -> Dict[str, Any]: """ @@ -92,6 +100,8 @@ def to_dict(self) -> Dict[str, Any]: scale_score=self._scale_score, return_embedding=self._return_embedding, score_threshold=self._score_threshold, + group_by=self._group_by, + group_size=self._group_size, ) d["init_parameters"]["document_store"] = self._document_store.to_dict() @@ -124,16 +134,22 @@ def run( scale_score: Optional[bool] = None, return_embedding: Optional[bool] = None, score_threshold: Optional[float] = None, + group_by: Optional[str] = None, + group_size: Optional[int] = None, ): """ Run the Embedding Retriever on the given input data. :param query_embedding: Embedding of the query. :param filters: A dictionary with filters to narrow down the search space. - :param top_k: The maximum number of documents to return. + :param top_k: The maximum number of documents to return. If using `group_by` parameters, maximum number of + groups to return. :param scale_score: Whether to scale the scores of the retrieved documents or not. :param return_embedding: Whether to return the embedding of the retrieved Documents. :param score_threshold: A minimal score threshold for the result. + :param group_by: Payload field to group by, must be a string or number field. If the field contains more than 1 + value, all values will be used for grouping. One point can be in multiple groups. + :param group_size: Maximum amount of points to return per group. Default is 3. :returns: The retrieved documents. @@ -147,6 +163,8 @@ def run( scale_score=scale_score or self._scale_score, return_embedding=return_embedding or self._return_embedding, score_threshold=score_threshold or self._score_threshold, + group_by=group_by or self._group_by, + group_size=group_size or self._group_size, ) return {"documents": docs} @@ -188,13 +206,16 @@ def __init__( return_embedding: bool = False, filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, score_threshold: Optional[float] = None, + group_by: Optional[str] = None, + group_size: Optional[int] = None, ): """ Create a QdrantSparseEmbeddingRetriever component. :param document_store: An instance of QdrantDocumentStore. :param filters: A dictionary with filters to narrow down the search space. - :param top_k: The maximum number of documents to retrieve. + :param top_k: The maximum number of documents to retrieve. If using `group_by` parameters, maximum number of + groups to return. :param scale_score: Whether to scale the scores of the retrieved documents or not. :param return_embedding: Whether to return the sparse embedding of the retrieved Documents. :param filter_policy: Policy to determine how filters are applied. Defaults to "replace". @@ -202,6 +223,9 @@ def __init__( Score of the returned result might be higher or smaller than the threshold depending on the Distance function used. E.g. for cosine similarity only higher scores will be returned. + :param group_by: Payload field to group by, must be a string or number field. If the field contains more than 1 + value, all values will be used for grouping. One point can be in multiple groups. + :param group_size: Maximum amount of points to return per group. Default is 3. :raises ValueError: If `document_store` is not an instance of `QdrantDocumentStore`. 
""" @@ -219,6 +243,8 @@ def __init__( filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) ) self._score_threshold = score_threshold + self._group_by = group_by + self._group_size = group_size def to_dict(self) -> Dict[str, Any]: """ @@ -236,6 +262,8 @@ def to_dict(self) -> Dict[str, Any]: filter_policy=self._filter_policy.value, return_embedding=self._return_embedding, score_threshold=self._score_threshold, + group_by=self._group_by, + group_size=self._group_size, ) d["init_parameters"]["document_store"] = self._document_store.to_dict() @@ -268,6 +296,8 @@ def run( scale_score: Optional[bool] = None, return_embedding: Optional[bool] = None, score_threshold: Optional[float] = None, + group_by: Optional[str] = None, + group_size: Optional[int] = None, ): """ Run the Sparse Embedding Retriever on the given input data. @@ -276,13 +306,17 @@ def run( :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on the `filter_policy` chosen at retriever initialization. See init method docstring for more details. - :param top_k: The maximum number of documents to return. + :param top_k: The maximum number of documents to return. If using `group_by` parameters, maximum number of + groups to return. :param scale_score: Whether to scale the scores of the retrieved documents or not. :param return_embedding: Whether to return the embedding of the retrieved Documents. :param score_threshold: A minimal score threshold for the result. Score of the returned result might be higher or smaller than the threshold depending on the Distance function used. E.g. for cosine similarity only higher scores will be returned. + :param group_by: Payload field to group by, must be a string or number field. If the field contains more than 1 + value, all values will be used for grouping. One point can be in multiple groups. + :param group_size: Maximum amount of points to return per group. Default is 3. :returns: The retrieved documents. @@ -296,6 +330,8 @@ def run( scale_score=scale_score or self._scale_score, return_embedding=return_embedding or self._return_embedding, score_threshold=score_threshold or self._score_threshold, + group_by=group_by or self._group_by, + group_size=group_size or self._group_size, ) return {"documents": docs} @@ -342,19 +378,25 @@ def __init__( return_embedding: bool = False, filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, score_threshold: Optional[float] = None, + group_by: Optional[str] = None, + group_size: Optional[int] = None, ): """ Create a QdrantHybridRetriever component. :param document_store: An instance of QdrantDocumentStore. :param filters: A dictionary with filters to narrow down the search space. - :param top_k: The maximum number of documents to retrieve. + :param top_k: The maximum number of documents to retrieve. If using `group_by` parameters, maximum number of + groups to return. :param return_embedding: Whether to return the embeddings of the retrieved Documents. :param filter_policy: Policy to determine how filters are applied. :param score_threshold: A minimal score threshold for the result. Score of the returned result might be higher or smaller than the threshold depending on the Distance function used. E.g. for cosine similarity only higher scores will be returned. + :param group_by: Payload field to group by, must be a string or number field. If the field contains more than 1 + value, all values will be used for grouping. One point can be in multiple groups. 
+ :param group_size: Maximum amount of points to return per group. Default is 3. :raises ValueError: If 'document_store' is not an instance of QdrantDocumentStore. """ @@ -371,6 +413,8 @@ def __init__( filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) ) self._score_threshold = score_threshold + self._group_by = group_by + self._group_size = group_size def to_dict(self) -> Dict[str, Any]: """ @@ -387,6 +431,8 @@ def to_dict(self) -> Dict[str, Any]: filter_policy=self._filter_policy.value, return_embedding=self._return_embedding, score_threshold=self._score_threshold, + group_by=self._group_by, + group_size=self._group_size, ) @classmethod @@ -416,6 +462,8 @@ def run( top_k: Optional[int] = None, return_embedding: Optional[bool] = None, score_threshold: Optional[float] = None, + group_by: Optional[str] = None, + group_size: Optional[int] = None, ): """ Run the Sparse Embedding Retriever on the given input data. @@ -425,12 +473,16 @@ def run( :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on the `filter_policy` chosen at retriever initialization. See init method docstring for more details. - :param top_k: The maximum number of documents to return. + :param top_k: The maximum number of documents to return. If using `group_by` parameters, maximum number of + groups to return. :param return_embedding: Whether to return the embedding of the retrieved Documents. :param score_threshold: A minimal score threshold for the result. Score of the returned result might be higher or smaller than the threshold depending on the Distance function used. E.g. for cosine similarity only higher scores will be returned. + :param group_by: Payload field to group by, must be a string or number field. If the field contains more than 1 + value, all values will be used for grouping. One point can be in multiple groups. + :param group_size: Maximum amount of points to return per group. Default is 3. :returns: The retrieved documents. @@ -444,6 +496,8 @@ def run( top_k=top_k or self._top_k, return_embedding=return_embedding or self._return_embedding, score_threshold=score_threshold or self._score_threshold, + group_by=group_by or self._group_by, + group_size=group_size or self._group_size, ) return {"documents": docs} diff --git a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py index a436fba55..da48e0f28 100644 --- a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py +++ b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py @@ -506,19 +506,25 @@ def _query_by_sparse( scale_score: bool = False, return_embedding: bool = False, score_threshold: Optional[float] = None, + group_by: Optional[str] = None, + group_size: Optional[int] = None, ) -> List[Document]: """ Queries Qdrant using a sparse embedding and returns the most relevant documents. :param query_sparse_embedding: Sparse embedding of the query. :param filters: Filters applied to the retrieved documents. - :param top_k: Maximum number of documents to return. + :param top_k: Maximum number of documents to return. If using `group_by` parameters, maximum number of + groups to return. :param scale_score: Whether to scale the scores of the retrieved documents. :param return_embedding: Whether to return the embeddings of the retrieved documents. 
:param score_threshold: A minimal score threshold for the result. Score of the returned result might be higher or smaller than the threshold depending on the Distance function used. E.g. for cosine similarity only higher scores will be returned. + :param group_by: Payload field to group by, must be a string or number field. If the field contains more than 1 + value, all values will be used for grouping. One point can be in multiple groups. + :param group_size: Maximum amount of points to return per group. Default is 3. :returns: List of documents that are most similar to `query_sparse_embedding`. @@ -536,22 +542,47 @@ def _query_by_sparse( qdrant_filters = convert_filters_to_qdrant(filters) query_indices = query_sparse_embedding.indices query_values = query_sparse_embedding.values - points = self.client.query_points( - collection_name=self.index, - query=rest.SparseVector( - indices=query_indices, - values=query_values, - ), - using=SPARSE_VECTORS_NAME, - query_filter=qdrant_filters, - limit=top_k, - with_vectors=return_embedding, - score_threshold=score_threshold, - ).points - results = [ - convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=self.use_sparse_embeddings) - for point in points - ] + if group_by: + groups = self.client.query_points_groups( + collection_name=self.index, + query=rest.SparseVector( + indices=query_indices, + values=query_values, + ), + using=SPARSE_VECTORS_NAME, + query_filter=qdrant_filters, + limit=top_k, + group_by=group_by, + group_size=group_size, + with_vectors=return_embedding, + score_threshold=score_threshold, + ).groups + results = ( + [ + convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=self.use_sparse_embeddings) + for group in groups + for point in group.hits + ] + if groups + else [] + ) + else: + points = self.client.query_points( + collection_name=self.index, + query=rest.SparseVector( + indices=query_indices, + values=query_values, + ), + using=SPARSE_VECTORS_NAME, + query_filter=qdrant_filters, + limit=top_k, + with_vectors=return_embedding, + score_threshold=score_threshold, + ).points + results = [ + convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=self.use_sparse_embeddings) + for point in points + ] if scale_score: for document in results: score = document.score @@ -567,37 +598,65 @@ def _query_by_embedding( scale_score: bool = False, return_embedding: bool = False, score_threshold: Optional[float] = None, + group_by: Optional[str] = None, + group_size: Optional[int] = None, ) -> List[Document]: """ Queries Qdrant using a dense embedding and returns the most relevant documents. :param query_embedding: Dense embedding of the query. :param filters: Filters applied to the retrieved documents. - :param top_k: Maximum number of documents to return. + :param top_k: Maximum number of documents to return. If using `group_by` parameters, maximum number of + groups to return. :param scale_score: Whether to scale the scores of the retrieved documents. :param return_embedding: Whether to return the embeddings of the retrieved documents. :param score_threshold: A minimal score threshold for the result. Score of the returned result might be higher or smaller than the threshold depending on the Distance function used. E.g. for cosine similarity only higher scores will be returned. + :param group_by: Payload field to group by, must be a string or number field. If the field contains more than 1 + value, all values will be used for grouping. One point can be in multiple groups. 
+ :param group_size: Maximum amount of points to return per group. Default is 3. :returns: List of documents that are most similar to `query_embedding`. """ qdrant_filters = convert_filters_to_qdrant(filters) + if group_by: + groups = self.client.query_points_groups( + collection_name=self.index, + query=query_embedding, + using=DENSE_VECTORS_NAME if self.use_sparse_embeddings else None, + query_filter=qdrant_filters, + limit=top_k, + group_by=group_by, + group_size=group_size, + with_vectors=return_embedding, + score_threshold=score_threshold, + ).groups + results = ( + [ + convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=self.use_sparse_embeddings) + for group in groups + for point in group.hits + ] + if groups + else [] + ) + else: + points = self.client.query_points( + collection_name=self.index, + query=query_embedding, + using=DENSE_VECTORS_NAME if self.use_sparse_embeddings else None, + query_filter=qdrant_filters, + limit=top_k, + with_vectors=return_embedding, + score_threshold=score_threshold, + ).points + results = [ + convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=self.use_sparse_embeddings) + for point in points + ] - points = self.client.query_points( - collection_name=self.index, - query=query_embedding, - using=DENSE_VECTORS_NAME if self.use_sparse_embeddings else None, - query_filter=qdrant_filters, - limit=top_k, - with_vectors=return_embedding, - score_threshold=score_threshold, - ).points - results = [ - convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=self.use_sparse_embeddings) - for point in points - ] if scale_score: for document in results: score = document.score @@ -616,6 +675,8 @@ def _query_hybrid( top_k: int = 10, return_embedding: bool = False, score_threshold: Optional[float] = None, + group_by: Optional[str] = None, + group_size: Optional[int] = None, ) -> List[Document]: """ Retrieves documents based on dense and sparse embeddings and fuses the results using Reciprocal Rank Fusion. @@ -626,12 +687,16 @@ def _query_hybrid( :param query_embedding: Dense embedding of the query. :param query_sparse_embedding: Sparse embedding of the query. :param filters: Filters applied to the retrieved documents. - :param top_k: Maximum number of documents to return. + :param top_k: Maximum number of documents to return. If using `group_by` parameters, maximum number of + groups to return. :param return_embedding: Whether to return the embeddings of the retrieved documents. :param score_threshold: A minimal score threshold for the result. Score of the returned result might be higher or smaller than the threshold depending on the Distance function used. E.g. for cosine similarity only higher scores will be returned. + :param group_by: Payload field to group by, must be a string or number field. If the field contains more than 1 + value, all values will be used for grouping. One point can be in multiple groups. + :param group_size: Maximum amount of points to return per group. Default is 3. :returns: List of Document that are most similar to `query_embedding` and `query_sparse_embedding`. 
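Taken together, `group_by` and `group_size` change what `top_k` bounds: groups rather than individual documents. A hedged usage sketch against an in-memory store, where the meta field, embedding size, and store settings are all illustrative:

```python
# Hedged usage sketch of the new grouping parameters; the meta field,
# embedding size, and store settings are illustrative only.
from haystack import Document
from haystack_integrations.components.retrievers.qdrant import QdrantEmbeddingRetriever
from haystack_integrations.document_stores.qdrant import QdrantDocumentStore

store = QdrantDocumentStore(location=":memory:", embedding_dim=4)
store.write_documents(
    [
        Document(content=f"doc {i}", embedding=[0.1 * i] * 4, meta={"group_field": i // 2})
        for i in range(8)
    ]
)

retriever = QdrantEmbeddingRetriever(
    document_store=store,
    group_by="meta.group_field",  # payload field to group hits by
    group_size=2,                 # at most two documents per group
)
# With grouping enabled, top_k caps the number of groups, so this call
# returns at most 3 groups x 2 documents = 6 documents.
docs = retriever.run(query_embedding=[0.1, 0.2, 0.3, 0.4], top_k=3)["documents"]
```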
@@ -651,34 +716,73 @@ def _query_hybrid( qdrant_filters = convert_filters_to_qdrant(filters) try: - points = self.client.query_points( - collection_name=self.index, - prefetch=[ - rest.Prefetch( - query=rest.SparseVector( - indices=query_sparse_embedding.indices, - values=query_sparse_embedding.values, + if group_by: + groups = self.client.query_points_groups( + collection_name=self.index, + prefetch=[ + rest.Prefetch( + query=rest.SparseVector( + indices=query_sparse_embedding.indices, + values=query_sparse_embedding.values, + ), + using=SPARSE_VECTORS_NAME, + filter=qdrant_filters, ), - using=SPARSE_VECTORS_NAME, - filter=qdrant_filters, - ), - rest.Prefetch( - query=query_embedding, - using=DENSE_VECTORS_NAME, - filter=qdrant_filters, - ), - ], - query=rest.FusionQuery(fusion=rest.Fusion.RRF), - limit=top_k, - score_threshold=score_threshold, - with_payload=True, - with_vectors=return_embedding, - ).points + rest.Prefetch( + query=query_embedding, + using=DENSE_VECTORS_NAME, + filter=qdrant_filters, + ), + ], + query=rest.FusionQuery(fusion=rest.Fusion.RRF), + limit=top_k, + group_by=group_by, + group_size=group_size, + score_threshold=score_threshold, + with_payload=True, + with_vectors=return_embedding, + ).groups + else: + points = self.client.query_points( + collection_name=self.index, + prefetch=[ + rest.Prefetch( + query=rest.SparseVector( + indices=query_sparse_embedding.indices, + values=query_sparse_embedding.values, + ), + using=SPARSE_VECTORS_NAME, + filter=qdrant_filters, + ), + rest.Prefetch( + query=query_embedding, + using=DENSE_VECTORS_NAME, + filter=qdrant_filters, + ), + ], + query=rest.FusionQuery(fusion=rest.Fusion.RRF), + limit=top_k, + score_threshold=score_threshold, + with_payload=True, + with_vectors=return_embedding, + ).points + except Exception as e: msg = "Error during hybrid search" raise QdrantStoreError(msg) from e - results = [convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=True) for point in points] + if group_by: + results = ( + [ + convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=self.use_sparse_embeddings) + for group in groups + for point in group.hits + ] + if groups + else [] + ) + else: + results = [convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=True) for point in points] return results diff --git a/integrations/qdrant/tests/test_document_store.py b/integrations/qdrant/tests/test_document_store.py index 112b7e5ac..79523531b 100644 --- a/integrations/qdrant/tests/test_document_store.py +++ b/integrations/qdrant/tests/test_document_store.py @@ -97,6 +97,39 @@ def test_query_hybrid(self, generate_sparse_embedding): assert document.sparse_embedding assert document.embedding + def test_query_hybrid_with_group_by(self, generate_sparse_embedding): + document_store = QdrantDocumentStore(location=":memory:", use_sparse_embeddings=True) + + docs = [] + for i in range(20): + docs.append( + Document( + content=f"doc {i}", + sparse_embedding=generate_sparse_embedding(), + embedding=_random_embeddings(768), + meta={"group_field": i // 2}, + ) + ) + + document_store.write_documents(docs) + + sparse_embedding = SparseEmbedding(indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33]) + embedding = [0.1] * 768 + + results: List[Document] = document_store._query_hybrid( + query_sparse_embedding=sparse_embedding, + query_embedding=embedding, + top_k=3, + return_embedding=True, + group_by="meta.group_field", + group_size=2, + ) + assert len(results) == 6 + + for document in results: + assert 
document.sparse_embedding + assert document.embedding + def test_query_hybrid_fail_without_sparse_embedding(self, document_store): sparse_embedding = SparseEmbedding(indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33]) embedding = [0.1] * 768 diff --git a/integrations/qdrant/tests/test_retriever.py b/integrations/qdrant/tests/test_retriever.py index eb0386828..bd6b92842 100644 --- a/integrations/qdrant/tests/test_retriever.py +++ b/integrations/qdrant/tests/test_retriever.py @@ -27,6 +27,8 @@ def test_init_default(self): assert retriever._filter_policy == FilterPolicy.REPLACE assert retriever._return_embedding is False assert retriever._score_threshold is None + assert retriever._group_by is None + assert retriever._group_size is None retriever = QdrantEmbeddingRetriever(document_store=document_store, filter_policy="replace") assert retriever._filter_policy == FilterPolicy.REPLACE @@ -87,6 +89,8 @@ def test_to_dict(self): "scale_score": False, "return_embedding": False, "score_threshold": None, + "group_by": None, + "group_size": None, }, } @@ -104,6 +108,8 @@ def test_from_dict(self): "scale_score": False, "return_embedding": True, "score_threshold": None, + "group_by": None, + "group_size": None, }, } retriever = QdrantEmbeddingRetriever.from_dict(data) @@ -115,6 +121,8 @@ def test_from_dict(self): assert retriever._scale_score is False assert retriever._return_embedding is True assert retriever._score_threshold is None + assert retriever._group_by is None + assert retriever._group_size is None def test_run(self, filterable_docs: List[Document]): document_store = QdrantDocumentStore(location=":memory:", index="Boi", use_sparse_embeddings=False) @@ -200,6 +208,26 @@ def test_run_with_sparse_activated(self, filterable_docs: List[Document]): for document in results: assert document.embedding is None + def test_run_with_group_by(self, filterable_docs: List[Document]): + document_store = QdrantDocumentStore(location=":memory:", index="Boi", use_sparse_embeddings=True) + # Add group_field metadata to documents + for index, doc in enumerate(filterable_docs): + doc.meta = {"group_field": index // 2} # So at least two docs have same group each time + document_store.write_documents(filterable_docs) + + retriever = QdrantEmbeddingRetriever(document_store=document_store) + results = retriever.run( + query_embedding=_random_embeddings(768), + top_k=3, + return_embedding=False, + group_by="meta.group_field", + group_size=2, + )["documents"] + assert len(results) >= 3 # This test is Flaky + assert len(results) <= 6 # This test is Flaky + for document in results: + assert document.embedding is None + class TestQdrantSparseEmbeddingRetriever(FilterableDocsFixtureMixin): def test_init_default(self): @@ -211,6 +239,8 @@ def test_init_default(self): assert retriever._filter_policy == FilterPolicy.REPLACE assert retriever._return_embedding is False assert retriever._score_threshold is None + assert retriever._group_by is None + assert retriever._group_size is None retriever = QdrantSparseEmbeddingRetriever(document_store=document_store, filter_policy="replace") assert retriever._filter_policy == FilterPolicy.REPLACE @@ -271,6 +301,8 @@ def test_to_dict(self): "return_embedding": False, "filter_policy": "replace", "score_threshold": None, + "group_by": None, + "group_size": None, }, } @@ -288,6 +320,8 @@ def test_from_dict(self): "return_embedding": True, "filter_policy": "replace", "score_threshold": None, + "group_by": None, + "group_size": None, }, } retriever = 
QdrantSparseEmbeddingRetriever.from_dict(data) @@ -299,6 +333,8 @@ def test_from_dict(self): assert retriever._scale_score is False assert retriever._return_embedding is True assert retriever._score_threshold is None + assert retriever._group_by is None + assert retriever._group_size is None def test_from_dict_no_filter_policy(self): data = { @@ -313,6 +349,8 @@ def test_from_dict_no_filter_policy(self): "scale_score": False, "return_embedding": True, "score_threshold": None, + "group_by": None, + "group_size": None, }, } retriever = QdrantSparseEmbeddingRetriever.from_dict(data) @@ -324,6 +362,8 @@ def test_from_dict_no_filter_policy(self): assert retriever._scale_score is False assert retriever._return_embedding is True assert retriever._score_threshold is None + assert retriever._group_by is None + assert retriever._group_size is None def test_run(self, filterable_docs: List[Document], generate_sparse_embedding): document_store = QdrantDocumentStore(location=":memory:", index="Boi", use_sparse_embeddings=True) @@ -345,6 +385,29 @@ def test_run(self, filterable_docs: List[Document], generate_sparse_embedding): for document in results: assert document.sparse_embedding + def test_run_with_group_by(self, filterable_docs: List[Document], generate_sparse_embedding): + document_store = QdrantDocumentStore(location=":memory:", index="Boi", use_sparse_embeddings=True) + + # Add fake sparse embedding to documents + for index, doc in enumerate(filterable_docs): + doc.sparse_embedding = generate_sparse_embedding() + doc.meta = {"group_field": index // 2} # So at least two docs have same group each time + document_store.write_documents(filterable_docs) + retriever = QdrantSparseEmbeddingRetriever(document_store=document_store) + sparse_embedding = SparseEmbedding(indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33]) + results = retriever.run( + query_sparse_embedding=sparse_embedding, + top_k=3, + return_embedding=True, + group_by="meta.group_field", + group_size=2, + )["documents"] + assert len(results) >= 3 # This test is Flaky + assert len(results) <= 6 # This test is Flaky + + for document in results: + assert document.sparse_embedding + class TestQdrantHybridRetriever: def test_init_default(self): @@ -357,6 +420,8 @@ def test_init_default(self): assert retriever._filter_policy == FilterPolicy.REPLACE assert retriever._return_embedding is False assert retriever._score_threshold is None + assert retriever._group_by is None + assert retriever._group_size is None retriever = QdrantHybridRetriever(document_store=document_store, filter_policy="replace") assert retriever._filter_policy == FilterPolicy.REPLACE @@ -416,6 +481,8 @@ def test_to_dict(self): "filter_policy": "replace", "return_embedding": True, "score_threshold": None, + "group_by": None, + "group_size": None, }, } @@ -432,6 +499,8 @@ def test_from_dict(self): "filter_policy": "replace", "return_embedding": True, "score_threshold": None, + "group_by": None, + "group_size": None, }, } retriever = QdrantHybridRetriever.from_dict(data) @@ -442,6 +511,8 @@ def test_from_dict(self): assert retriever._filter_policy == FilterPolicy.REPLACE assert retriever._return_embedding assert retriever._score_threshold is None + assert retriever._group_by is None + assert retriever._group_size is None def test_from_dict_no_filter_policy(self): data = { @@ -455,6 +526,8 @@ def test_from_dict_no_filter_policy(self): "top_k": 5, "return_embedding": True, "score_threshold": None, + "group_by": None, + "group_size": None, }, } retriever = 
QdrantHybridRetriever.from_dict(data) @@ -465,6 +538,8 @@ def test_from_dict_no_filter_policy(self): assert retriever._filter_policy == FilterPolicy.REPLACE # defaults to REPLACE assert retriever._return_embedding assert retriever._score_threshold is None + assert retriever._group_by is None + assert retriever._group_size is None def test_run(self): mock_store = Mock(spec=QdrantDocumentStore) @@ -488,3 +563,31 @@ def test_run(self): assert res["documents"][0].content == "Test doc" assert res["documents"][0].embedding == [0.1, 0.2] assert res["documents"][0].sparse_embedding == sparse_embedding + + def test_run_with_group_by(self): + mock_store = Mock(spec=QdrantDocumentStore) + sparse_embedding = SparseEmbedding(indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33]) + mock_store._query_hybrid.return_value = [ + Document(content="Test doc", embedding=[0.1, 0.2], sparse_embedding=sparse_embedding) + ] + + retriever = QdrantHybridRetriever(document_store=mock_store) + res = retriever.run( + query_embedding=[0.5, 0.7], + query_sparse_embedding=SparseEmbedding(indices=[0, 5], values=[0.1, 0.7]), + group_by="meta.group_field", + group_size=2, + ) + + call_args = mock_store._query_hybrid.call_args + assert call_args[1]["query_embedding"] == [0.5, 0.7] + assert call_args[1]["query_sparse_embedding"].indices == [0, 5] + assert call_args[1]["query_sparse_embedding"].values == [0.1, 0.7] + assert call_args[1]["top_k"] == 10 + assert call_args[1]["return_embedding"] is False + assert call_args[1]["group_by"] == "meta.group_field" + assert call_args[1]["group_size"] == 2 + + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + assert res["documents"][0].sparse_embedding == sparse_embedding From 77751ece31d3ab575c00020ec2477c3a0175393e Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Thu, 12 Sep 2024 11:12:29 +0000 Subject: [PATCH 17/33] Update the changelog --- integrations/qdrant/CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/integrations/qdrant/CHANGELOG.md b/integrations/qdrant/CHANGELOG.md index ad664bdd4..edc936fb2 100644 --- a/integrations/qdrant/CHANGELOG.md +++ b/integrations/qdrant/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## [integrations/qdrant-v5.1.0] - 2024-09-12 + +### ๐Ÿš€ Features + +- Qdrant - Add group_by and group_size optional parameters to Retrievers (#1054) + ## [integrations/qdrant-v5.0.0] - 2024-09-02 ## [integrations/qdrant-v4.2.0] - 2024-08-27 From f467983fe261d3623ba8f43af2316b695cb47745 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 12 Sep 2024 14:59:00 +0200 Subject: [PATCH 18/33] chore: MongoDB - remove legacy filter support (#1066) * Remove legacy filter support * Lint * Add _normalize_filters * Lint all * Linting issues * Error msg fmt --- .../document_stores/mongodb_atlas/filters.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py index 4583d6cd3..0b5986222 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py @@ -5,7 +5,6 @@ from typing import Any, Dict from haystack.errors import FilterError -from haystack.utils.filters import convert from pandas import DataFrame UNSUPPORTED_TYPES_FOR_COMPARISON = (list, DataFrame) @@ -20,7 
+19,8 @@ def _normalize_filters(filters: Dict[str, Any]) -> Dict[str, Any]: raise FilterError(msg) if "operator" not in filters and "conditions" not in filters: - filters = convert(filters) + msg = "Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering for details." + raise ValueError(msg) if "field" in filters: return _parse_comparison_condition(filters) From 64be1581ab1d88184395833bd36528bb084a4f00 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 12 Sep 2024 14:59:11 +0200 Subject: [PATCH 19/33] chore: OpenSearch - remove legacy filter support (#1067) * Remove legacy filter support * Lint * More linting * Small fix * Remove outdated test * Lint tests * Remove outdated test * Improve error message * Error msg fmt * More formatting --- .../opensearch/document_store.py | 11 ++- .../opensearch/tests/test_document_store.py | 91 ------------------- 2 files changed, 6 insertions(+), 96 deletions(-) diff --git a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py index 3a6056bd2..6f7a6c96e 100644 --- a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py +++ b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py @@ -9,7 +9,6 @@ from haystack.dataclasses import Document from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy -from haystack.utils.filters import convert from opensearchpy import OpenSearch from opensearchpy.helpers import bulk @@ -238,14 +237,14 @@ def _search_documents(self, **kwargs) -> List[Document]: def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: if filters and "operator" not in filters and "conditions" not in filters: - filters = convert(filters) + msg = "Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering for details." + raise ValueError(msg) if filters: query = {"bool": {"filter": normalize_filters(filters)}} documents = self._search_documents(query=query, size=10_000) else: documents = self._search_documents(size=10_000) - return documents def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int: @@ -384,7 +383,8 @@ def _bm25_retrieval( :returns: List of Document that match `query` """ if filters and "operator" not in filters and "conditions" not in filters: - filters = convert(filters) + msg = "Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering for details." + raise ValueError(msg) if not query: body: Dict[str, Any] = {"query": {"bool": {"must": {"match_all": {}}}}} @@ -478,7 +478,8 @@ def _embedding_retrieval( :returns: List of Document that are most similar to `query_embedding` """ if filters and "operator" not in filters and "conditions" not in filters: - filters = convert(filters) + msg = "Legacy filters support has been removed. Please see documentation for new filter syntax." 
+ raise ValueError(msg) if not query_embedding: msg = "query_embedding must be a non-empty list of floats" diff --git a/integrations/opensearch/tests/test_document_store.py b/integrations/opensearch/tests/test_document_store.py index a7d516b3e..9cc4bf4ea 100644 --- a/integrations/opensearch/tests/test_document_store.py +++ b/integrations/opensearch/tests/test_document_store.py @@ -574,76 +574,6 @@ def test_bm25_retrieval_with_filters(self, document_store: OpenSearchDocumentSto retrieved_ids = sorted([doc.id for doc in res]) assert retrieved_ids == ["1", "2", "3", "4", "5"] - def test_bm25_retrieval_with_legacy_filters(self, document_store: OpenSearchDocumentStore): - document_store.write_documents( - [ - Document( - content="Haskell is a functional programming language", - meta={"likes": 100000, "language_type": "functional"}, - id="1", - ), - Document( - content="Lisp is a functional programming language", - meta={"likes": 10000, "language_type": "functional"}, - id="2", - ), - Document( - content="Exilir is a functional programming language", - meta={"likes": 1000, "language_type": "functional"}, - id="3", - ), - Document( - content="F# is a functional programming language", - meta={"likes": 100, "language_type": "functional"}, - id="4", - ), - Document( - content="C# is a functional programming language", - meta={"likes": 10, "language_type": "functional"}, - id="5", - ), - Document( - content="C++ is an object oriented programming language", - meta={"likes": 100000, "language_type": "object_oriented"}, - id="6", - ), - Document( - content="Dart is an object oriented programming language", - meta={"likes": 10000, "language_type": "object_oriented"}, - id="7", - ), - Document( - content="Go is an object oriented programming language", - meta={"likes": 1000, "language_type": "object_oriented"}, - id="8", - ), - Document( - content="Python is a object oriented programming language", - meta={"likes": 100, "language_type": "object_oriented"}, - id="9", - ), - Document( - content="Ruby is a object oriented programming language", - meta={"likes": 10, "language_type": "object_oriented"}, - id="10", - ), - Document( - content="PHP is a object oriented programming language", - meta={"likes": 1, "language_type": "object_oriented"}, - id="11", - ), - ] - ) - - res = document_store._bm25_retrieval( - "programming", - top_k=10, - filters={"language_type": "functional"}, - ) - assert len(res) == 5 - retrieved_ids = sorted([doc.id for doc in res]) - assert retrieved_ids == ["1", "2", "3", "4", "5"] - def test_bm25_retrieval_with_custom_query(self, document_store: OpenSearchDocumentStore): document_store.write_documents( [ @@ -760,27 +690,6 @@ def test_embedding_retrieval_with_filters(self, document_store_embedding_dim_4: assert len(results) == 1 assert results[0].content == "Not very similar document with meta field" - def test_embedding_retrieval_with_legacy_filters(self, document_store_embedding_dim_4: OpenSearchDocumentStore): - docs = [ - Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]), - Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]), - Document( - content="Not very similar document with meta field", - embedding=[0.0, 0.8, 0.3, 0.9], - meta={"meta_field": "custom_value"}, - ), - ] - document_store_embedding_dim_4.write_documents(docs) - - filters = {"meta_field": "custom_value"} - # we set top_k=3, to make the test pass as we are not sure whether efficient filtering is supported for nmslib - # TODO: remove top_k=3, when efficient filtering is 
supported for nmslib - results = document_store_embedding_dim_4._embedding_retrieval( - query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=3, filters=filters - ) - assert len(results) == 1 - assert results[0].content == "Not very similar document with meta field" - def test_embedding_retrieval_pagination(self, document_store_embedding_dim_4: OpenSearchDocumentStore): """ Test that handling of pagination works as expected, when the matching documents are > 10. From 3290da69ffd48090c291e37fbcb98847f786bd48 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 12 Sep 2024 14:59:24 +0200 Subject: [PATCH 20/33] chore: PgVector - remove legacy filter support (#1068) * Remove legacy filter support * Linting * Error msg fmt --- .../document_stores/pgvector/document_store.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py index ae4878aba..a02c46200 100644 --- a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py +++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py @@ -9,7 +9,6 @@ from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy from haystack.utils.auth import Secret, deserialize_secrets_inplace -from haystack.utils.filters import convert from psycopg import Error, IntegrityError, connect from psycopg.abc import Query from psycopg.cursor import Cursor @@ -389,7 +388,8 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc msg = "Filters must be a dictionary" raise TypeError(msg) if "operator" not in filters and "conditions" not in filters: - filters = convert(filters) + msg = "Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering for details." + raise ValueError(msg) sql_filter = SQL("SELECT * FROM {table_name}").format(table_name=Identifier(self.table_name)) From 69946c0f8c89f13d42ff9f9e35098652fc3e7db5 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 12 Sep 2024 14:59:35 +0200 Subject: [PATCH 21/33] chore: Pinecone - remove legacy filter support (#1069) * Remove legacy filter support * Linting * Improve error message * Improve error message - lint * Change message to be more generic * Error msg fmt --- .../document_stores/pinecone/document_store.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py b/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py index 75d6270ca..07f217f5b 100644 --- a/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py +++ b/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py @@ -11,7 +11,6 @@ from haystack.dataclasses import Document from haystack.document_stores.types import DuplicatePolicy from haystack.utils import Secret, deserialize_secrets_inplace -from haystack.utils.filters import convert from pinecone import Pinecone, PodSpec, ServerlessSpec @@ -201,6 +200,10 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc :returns: A list of Documents that match the given filters. 
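Across MongoDB, OpenSearch, PgVector, Pinecone, Weaviate, and Elasticsearch, these commits apply the same gate: a filter missing both `operator` and `conditions` at the top level is treated as legacy syntax and rejected instead of being converted. A hedged side-by-side of the two styles, with an illustrative field name:

```python
# Hedged side-by-side of the two filter styles; the field name is
# illustrative. The detection mirrors the check added in each store.
legacy_filter = {"language_type": "functional"}  # implicit-equality style, now rejected
current_filter = {
    "field": "meta.language_type",
    "operator": "==",
    "value": "functional",
}

def is_legacy(filters: dict) -> bool:
    # A filter without top-level "operator" or "conditions" keys is legacy.
    return "operator" not in filters and "conditions" not in filters

print(is_legacy(legacy_filter))   # True  -> the store raises ValueError
print(is_legacy(current_filter))  # False -> the filter is normalized and applied
```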
""" + if filters and "operator" not in filters and "conditions" not in filters: + msg = "Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering for details." + raise ValueError(msg) + # Pinecone only performs vector similarity search # here we are querying with a dummy vector and the max compatible top_k documents = self._embedding_retrieval(query_embedding=self._dummy_vector, filters=filters, top_k=TOP_K_LIMIT) @@ -253,7 +256,8 @@ def _embedding_retrieval( raise ValueError(msg) if filters and "operator" not in filters and "conditions" not in filters: - filters = convert(filters) + msg = "Legacy filters support has been removed. Please see documentation for new filter syntax." + raise ValueError(msg) filters = _normalize_filters(filters) if filters else None result = self.index.query( From 4f19d5726a89d000b17501426feccebd4748457a Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 12 Sep 2024 14:59:46 +0200 Subject: [PATCH 22/33] chore: Weaviate - remove legacy filter support (#1070) * Remove legacy filter support * Linting * Remove outdated test * Improve error message * Error msg fmt --- .../document_stores/weaviate/document_store.py | 4 ++-- integrations/weaviate/tests/test_document_store.py | 9 --------- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py index 09e0a673d..e312b1473 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py @@ -12,7 +12,6 @@ from haystack.dataclasses.document import Document from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types.policy import DuplicatePolicy -from haystack.utils.filters import convert import weaviate from weaviate.collections.classes.data import DataObject @@ -388,7 +387,8 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc :returns: A list of Documents that match the given filters. """ if filters and "operator" not in filters and "conditions" not in filters: - filters = convert(filters) + msg = "Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering for details." 
+ raise ValueError(msg) result = [] if filters: diff --git a/integrations/weaviate/tests/test_document_store.py b/integrations/weaviate/tests/test_document_store.py index 8d531cade..190c23408 100644 --- a/integrations/weaviate/tests/test_document_store.py +++ b/integrations/weaviate/tests/test_document_store.py @@ -660,15 +660,6 @@ def test_embedding_retrieval_with_distance_and_certainty(self, document_store): with pytest.raises(ValueError): document_store._embedding_retrieval(query_embedding=[], distance=0.1, certainty=0.1) - def test_filter_documents_with_legacy_filters(self, document_store): - docs = [] - for index in range(10): - docs.append(Document(content="This is some content", meta={"index": index})) - document_store.write_documents(docs) - result = document_store.filter_documents({"content": {"$eq": "This is some content"}}) - - assert len(result) == 10 - def test_filter_documents_below_default_limit(self, document_store): docs = [] for index in range(9998): From 3d2693d2e429e584250f222c6b7c149b0142aa3b Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 12 Sep 2024 15:30:03 +0200 Subject: [PATCH 23/33] chore: ElasticSearch - remove legacy filters elasticsearch (#1078) * Remove legacy filter support * Improve error message * Error msg fmt --- .../document_stores/elasticsearch/document_store.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/document_store.py b/integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/document_store.py index 11016e3fc..734e2d2b8 100644 --- a/integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/document_store.py +++ b/integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/document_store.py @@ -12,7 +12,6 @@ from haystack.dataclasses import Document from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy -from haystack.utils.filters import convert from haystack.version import __version__ as haystack_version from elasticsearch import Elasticsearch, helpers # type: ignore[import-not-found] @@ -224,7 +223,8 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc :returns: List of `Document`s that match the filters. """ if filters and "operator" not in filters and "conditions" not in filters: - filters = convert(filters) + msg = "Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering for details." 
+        raise ValueError(msg)
 
         query = {"bool": {"filter": _normalize_filters(filters)}} if filters else None
         documents = self._search_documents(query=query)

From 8ef0e6d956b27cb67dc9fe9b5838b0c00d800e7c Mon Sep 17 00:00:00 2001
From: HaystackBot
Date: Thu, 12 Sep 2024 13:42:01 +0000
Subject: [PATCH 24/33] Update the changelog

---
 integrations/mongodb_atlas/CHANGELOG.md | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/integrations/mongodb_atlas/CHANGELOG.md b/integrations/mongodb_atlas/CHANGELOG.md
index 851858355..91b073102 100644
--- a/integrations/mongodb_atlas/CHANGELOG.md
+++ b/integrations/mongodb_atlas/CHANGELOG.md
@@ -12,10 +12,16 @@
 - Pass empty dict to filter instead of None (#775)
 - `Mongo` - Fallback to default filter policy when deserializing retrievers without the init parameter (#899)
 
+### ๐Ÿงช Testing
+
+- Do not retry tests in `hatch run test` command (#954)
+
 ### โš™๏ธ Miscellaneous Tasks
 
 - Retry tests to reduce flakyness (#836)
 - Update ruff invocation to include check parameter (#853)
+- Update mongodb test for the new `apply_filter_policy` usage (#971)
+- MongoDB - remove legacy filter support (#1066)
 
 ## [integrations/mongodb_atlas-v0.2.1] - 2024-04-09

From ac0c5800070cf06ab45d8f71ba5a90b28fe89a08 Mon Sep 17 00:00:00 2001
From: Vladimir Blagojevic
Date: Thu, 12 Sep 2024 16:50:19 +0200
Subject: [PATCH 25/33] chore: Update changelog after removing legacy filters (#1083)

* Update changelog after removing legacy filters

* Updates

---
 integrations/elasticsearch/CHANGELOG.md |  7 ++++++-
 integrations/opensearch/CHANGELOG.md    | 18 ++++++++++++++++++
 integrations/pgvector/CHANGELOG.md      |  7 ++++++-
 integrations/pinecone/CHANGELOG.md      |  6 ++++++
 integrations/weaviate/CHANGELOG.md      |  6 ++++++
 5 files changed, 42 insertions(+), 2 deletions(-)

diff --git a/integrations/elasticsearch/CHANGELOG.md b/integrations/elasticsearch/CHANGELOG.md
index a825234bc..5d2b66470 100644
--- a/integrations/elasticsearch/CHANGELOG.md
+++ b/integrations/elasticsearch/CHANGELOG.md
@@ -1,6 +1,6 @@
 # Changelog
 
-## [unreleased]
+## [integrations/elasticsearch-v1.0.0] - 2024-09-12
 
 ### ๐Ÿš€ Features
 
@@ -11,10 +11,15 @@
 - `ElasticSearch` - Fallback to default filter policy when deserializing retrievers without the init parameter (#898)
 
+### ๐Ÿงช Testing
+
+- Do not retry tests in `hatch run test` command (#954)
+
 ### โš™๏ธ Miscellaneous Tasks
 
 - Retry tests to reduce flakyness (#836)
 - Update ruff invocation to include check parameter (#853)
+- ElasticSearch - remove legacy filters elasticsearch (#1078)
 
 ## [integrations/elasticsearch-v0.5.0] - 2024-05-24

diff --git a/integrations/opensearch/CHANGELOG.md b/integrations/opensearch/CHANGELOG.md
index 6509d1e0f..713848915 100644
--- a/integrations/opensearch/CHANGELOG.md
+++ b/integrations/opensearch/CHANGELOG.md
@@ -1,5 +1,23 @@
 # Changelog
 
+## [integrations/opensearch-v1.0.0] - 2024-09-12
+
+### ๐Ÿ“š Documentation
+
+- Update opensearch retriever docstrings (#1035)
+
+### ๐Ÿงช Testing
+
+- Do not retry tests in `hatch run test` command (#954)
+
+### โš™๏ธ Miscellaneous Tasks
+
+- OpenSearch - remove legacy filter support (#1067)
+
+### Docs
+
+- Update BM25 docstrings (#945)
+
 ## [integrations/opensearch-v0.9.0] - 2024-08-01
 
 ### ๐Ÿš€ Features

diff --git a/integrations/pgvector/CHANGELOG.md b/integrations/pgvector/CHANGELOG.md
index deb6faece..0fe5f4fa4 100644
--- a/integrations/pgvector/CHANGELOG.md
+++ b/integrations/pgvector/CHANGELOG.md
@@ -1,6 +1,6 @@
 # Changelog
 
-## [unreleased]
+## [integrations/pgvector-v1.0.0] - 2024-09-12
 
 ### ๐Ÿš€ Features
 
@@ -10,10 +10,15 @@
 - `PgVector` - Fallback to default filter policy when deserializing retrievers without the init parameter (#900)
 
+### ๐Ÿงช Testing
+
+- Do not retry tests in `hatch run test` command (#954)
+
 ### โš™๏ธ Miscellaneous Tasks
 
 - Retry tests to reduce flakyness (#836)
 - Update ruff invocation to include check parameter (#853)
+- PgVector - remove legacy filter support (#1068)
 
 ## [integrations/pgvector-v0.4.0] - 2024-06-20

diff --git a/integrations/pinecone/CHANGELOG.md b/integrations/pinecone/CHANGELOG.md
index a041d63de..7810e486c 100644
--- a/integrations/pinecone/CHANGELOG.md
+++ b/integrations/pinecone/CHANGELOG.md
@@ -1,5 +1,11 @@
 # Changelog
 
+## [integrations/pinecone-v2.0.0] - 2024-09-12
+
+### โš™๏ธ Miscellaneous Tasks
+
+- Pinecone - remove legacy filter support (#1069)
+
 ## [integrations/pinecone-v1.2.3] - 2024-08-29
 
 ### ๐Ÿš€ Features

diff --git a/integrations/weaviate/CHANGELOG.md b/integrations/weaviate/CHANGELOG.md
index d934826bc..dacf3fef8 100644
--- a/integrations/weaviate/CHANGELOG.md
+++ b/integrations/weaviate/CHANGELOG.md
@@ -1,5 +1,11 @@
 # Changelog
 
+## [integrations/weaviate-v3.0.0] - 2024-09-12
+
+### โš™๏ธ Miscellaneous Tasks
+
+- Weaviate - remove legacy filter support (#1070)
+
 ## [integrations/weaviate-v2.2.1] - 2024-09-07
 
 ### ๐Ÿš€ Features

From 781941a4481368a7139883077d37b1221ea7f21f Mon Sep 17 00:00:00 2001
From: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Date: Fri, 13 Sep 2024 10:14:27 +0200
Subject: [PATCH 26/33] Remove support for deprecated legacy filters in Qdrant (#1084)

* Remove support for deprecated legacy filters in Qdrant

* Remove legacy filters tests

---
 .../document_stores/qdrant/document_store.py |   4 +-
 .../qdrant/tests/test_legacy_filters.py      | 442 ------------------
 2 files changed, 2 insertions(+), 444 deletions(-)
 delete mode 100644 integrations/qdrant/tests/test_legacy_filters.py

diff --git a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py
index da48e0f28..88afd8f65 100644
--- a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py
+++ b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py
@@ -11,7 +11,6 @@
 from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
 from haystack.document_stores.types import DuplicatePolicy
 from haystack.utils import Secret, deserialize_secrets_inplace
-from haystack.utils.filters import convert as convert_legacy_filters
 from qdrant_client import grpc
 from qdrant_client.http import models as rest
 from qdrant_client.http.exceptions import UnexpectedResponse
@@ -323,7 +322,8 @@ def filter_documents(
             raise ValueError(msg)
 
         if filters and not isinstance(filters, rest.Filter) and "operator" not in filters:
-            filters = convert_legacy_filters(filters)
+            msg = "Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering for details."
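+            # Descriptive comment (added for clarity): qdrant-client rest.Filter objects are still
+            # accepted as-is; plain dicts must use the Haystack 2.x "operator"/"conditions" syntax.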
+ raise ValueError(msg) return list( self.get_documents_generator( filters, diff --git a/integrations/qdrant/tests/test_legacy_filters.py b/integrations/qdrant/tests/test_legacy_filters.py deleted file mode 100644 index a09f042c5..000000000 --- a/integrations/qdrant/tests/test_legacy_filters.py +++ /dev/null @@ -1,442 +0,0 @@ -from typing import List - -import pytest -from haystack import Document -from haystack.document_stores.types import DocumentStore -from haystack.testing.document_store import LegacyFilterDocumentsTest -from haystack.utils.filters import FilterError - -from haystack_integrations.document_stores.qdrant import QdrantDocumentStore - -# The tests below are from haystack.testing.document_store.LegacyFilterDocumentsTest -# Updated to include `meta` prefix for filter keys wherever necessary -# And skip tests that are not supported in Qdrant(Dataframes, embeddings) - - -class TestQdrantLegacyFilterDocuments(LegacyFilterDocumentsTest): - """ - Utility class to test a Document Store `filter_documents` method using different types of legacy filters - """ - - @pytest.fixture - def document_store(self) -> QdrantDocumentStore: - return QdrantDocumentStore( - ":memory:", - recreate_index=True, - return_embedding=True, - wait_result_from_api=True, - ) - - def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): - """ - Assert that two lists of Documents are equal. - This is used in every test. - """ - - # Check that the lengths of the lists are the same - assert len(received) == len(expected) - - # Check that the sets are equal, meaning the content and IDs match regardless of order - assert {doc.id for doc in received} == {doc.id for doc in expected} - - def test_filter_simple_metadata_value(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"meta.page": "100"}) - self.assert_documents_are_equal(result, [doc for doc in filterable_docs if doc.meta.get("page") == "100"]) - - @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") - def test_filter_document_dataframe(self, document_store: DocumentStore, filterable_docs: List[Document]): ... - - def test_eq_filter_explicit(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"meta.page": {"$eq": "100"}}) - self.assert_documents_are_equal(result, [doc for doc in filterable_docs if doc.meta.get("page") == "100"]) - - def test_eq_filter_implicit(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"meta.page": "100"}) - self.assert_documents_are_equal(result, [doc for doc in filterable_docs if doc.meta.get("page") == "100"]) - - @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") - def test_eq_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): ... - - @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") - def test_eq_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): ... 
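The pattern is the same across these stores: the `convert` shim from `haystack.utils.filters` is dropped, and a plain-dict filter is only accepted if it already uses the Haystack 2.x shape. Below is a minimal sketch of the two shapes, assuming the documented 2.x `field`/`operator`/`value` syntax; the guard function is illustrative, not part of any integration:

```python
# Legacy (1.x) syntax, now rejected by filter_documents:
legacy_filter = {"meta.page": {"$eq": "100"}}

# Current (2.x) syntax, which passes the new top-level check:
current_filter = {
    "operator": "AND",
    "conditions": [
        {"field": "meta.page", "operator": "==", "value": "100"},
        {"field": "meta.number", "operator": "<", "value": 1.0},
    ],
}


def uses_current_syntax(filters: dict) -> bool:
    # Mirrors the guard each store now applies before normalizing filters:
    # a top-level "operator" or "conditions" key marks the 2.x syntax.
    return "operator" in filters or "conditions" in filters


assert not uses_current_syntax(legacy_filter)
assert uses_current_syntax(current_filter)
```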
- - # LegacyFilterDocumentsNotEqualTest - - def test_ne_filter(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"meta.page": {"$ne": "100"}}) - self.assert_documents_are_equal(result, [doc for doc in filterable_docs if doc.meta.get("page") != "100"]) - - @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") - def test_ne_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): ... - - @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") - def test_ne_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): ... - - # LegacyFilterDocumentsInTest - - def test_filter_simple_list_single_element(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"meta.page": ["100"]}) - self.assert_documents_are_equal(result, [doc for doc in filterable_docs if doc.meta.get("page") == "100"]) - - def test_filter_simple_list_one_value(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"meta.page": ["100"]}) - self.assert_documents_are_equal(result, [doc for doc in filterable_docs if doc.meta.get("page") in ["100"]]) - - def test_filter_simple_list(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"meta.page": ["100", "123"]}) - self.assert_documents_are_equal( - result, - [doc for doc in filterable_docs if doc.meta.get("page") in ["100", "123"]], - ) - - def test_incorrect_filter_value(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"meta.page": ["nope"]}) - self.assert_documents_are_equal(result, []) - - def test_in_filter_explicit(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"meta.page": {"$in": ["100", "123", "n.a."]}}) - self.assert_documents_are_equal( - result, - [doc for doc in filterable_docs if doc.meta.get("page") in ["100", "123"]], - ) - - def test_in_filter_implicit(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"meta.page": ["100", "123", "n.a."]}) - self.assert_documents_are_equal( - result, - [doc for doc in filterable_docs if doc.meta.get("page") in ["100", "123"]], - ) - - @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") - def test_in_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): ... - - @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") - def test_in_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): ... - - # LegacyFilterDocumentsNotInTest - - @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") - def test_nin_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): ... 
- - @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") - def test_nin_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): ... - - def test_nin_filter(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"meta.page": {"$nin": ["100", "123", "n.a."]}}) - self.assert_documents_are_equal( - result, - [doc for doc in filterable_docs if doc.meta.get("page") not in ["100", "123"]], - ) - - # LegacyFilterDocumentsGreaterThanTest - - def test_gt_filter(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"meta.number": {"$gt": 0.0}}) - self.assert_documents_are_equal( - result, - [doc for doc in filterable_docs if "number" in doc.meta and doc.meta["number"] > 0], - ) - - def test_gt_filter_non_numeric(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - with pytest.raises(FilterError): - document_store.filter_documents(filters={"meta.page": {"$gt": "100"}}) - - @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") - def test_gt_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): ... - - @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") - def test_gt_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): ... - - # LegacyFilterDocumentsGreaterThanEqualTest - - def test_gte_filter(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"meta.number": {"$gte": -2}}) - self.assert_documents_are_equal( - result, - [doc for doc in filterable_docs if "number" in doc.meta and doc.meta["number"] >= -2], - ) - - def test_gte_filter_non_numeric(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - with pytest.raises(FilterError): - document_store.filter_documents(filters={"meta.page": {"$gte": "100"}}) - - @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") - def test_gte_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): ... - - @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") - def test_gte_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): ... - - # LegacyFilterDocumentsLessThanTest - - def test_lt_filter(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"meta.number": {"$lt": 0.0}}) - self.assert_documents_are_equal( - result, - [doc for doc in filterable_docs if doc.meta.get("number") is not None and doc.meta["number"] < 0], - ) - - def test_lt_filter_non_numeric(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - with pytest.raises(FilterError): - document_store.filter_documents(filters={"meta.page": {"$lt": "100"}}) - - @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") - def test_lt_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): ... 
- - @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") - def test_lt_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): ... - - # LegacyFilterDocumentsLessThanEqualTest - - def test_lte_filter(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"meta.number": {"$lte": 2.0}}) - self.assert_documents_are_equal( - result, - [doc for doc in filterable_docs if doc.meta.get("number") is not None and doc.meta["number"] <= 2.0], - ) - - def test_lte_filter_non_numeric(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - with pytest.raises(FilterError): - document_store.filter_documents(filters={"meta.page": {"$lte": "100"}}) - - @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") - def test_lte_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): ... - - @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") - def test_lte_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): ... - - # LegacyFilterDocumentsSimpleLogicalTest - - def test_filter_simple_or(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - filters = { - "$or": { - "meta.name": {"$in": ["name_0", "name_1"]}, - "meta.number": {"$lt": 1.0}, - } - } - result = document_store.filter_documents(filters=filters) - self.assert_documents_are_equal( - result, - [ - doc - for doc in filterable_docs - if (doc.meta.get("number") is not None and doc.meta["number"] < 1) - or doc.meta.get("name") in ["name_0", "name_1"] - ], - ) - - def test_filter_simple_implicit_and_with_multi_key_dict( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"meta.number": {"$lte": 2.0, "$gte": 0.0}}) - self.assert_documents_are_equal( - result, - [ - doc - for doc in filterable_docs - if "number" in doc.meta and doc.meta["number"] >= 0.0 and doc.meta["number"] <= 2.0 - ], - ) - - def test_filter_simple_explicit_and_with_list(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"meta.number": {"$and": [{"$lte": 2}, {"$gte": 0}]}}) - self.assert_documents_are_equal( - result, - [ - doc - for doc in filterable_docs - if "number" in doc.meta and doc.meta["number"] <= 2.0 and doc.meta["number"] >= 0.0 - ], - ) - - def test_filter_simple_implicit_and(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - result = document_store.filter_documents(filters={"meta.number": {"$lte": 2.0, "$gte": 0}}) - self.assert_documents_are_equal( - result, - [ - doc - for doc in filterable_docs - if "number" in doc.meta and doc.meta["number"] <= 2.0 and doc.meta["number"] >= 0.0 - ], - ) - - # LegacyFilterDocumentsNestedLogicalTest( - - def test_filter_nested_implicit_and(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - filters_simplified = { - "meta.number": {"$lte": 2, "$gte": 0}, - "meta.name": ["name_0", "name_1"], - } - result = 
document_store.filter_documents(filters=filters_simplified) - self.assert_documents_are_equal( - result, - [ - doc - for doc in filterable_docs - if ( - "number" in doc.meta - and doc.meta["number"] <= 2 - and doc.meta["number"] >= 0 - and doc.meta.get("name") in ["name_0", "name_1"] - ) - ], - ) - - def test_filter_nested_or(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - filters = { - "$or": { - "meta.name": {"$in": ["name_0", "name_1"]}, - "meta.number": {"$lt": 1.0}, - } - } - result = document_store.filter_documents(filters=filters) - self.assert_documents_are_equal( - result, - [ - doc - for doc in filterable_docs - if ( - doc.meta.get("name") in ["name_0", "name_1"] - or (doc.meta.get("number") is not None and doc.meta["number"] < 1) - ) - ], - ) - - def test_filter_nested_and_or_explicit(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - filters_simplified = { - "$and": { - "meta.page": {"$eq": "123"}, - "$or": { - "meta.name": {"$in": ["name_0", "name_1"]}, - "meta.number": {"$lt": 1.0}, - }, - } - } - result = document_store.filter_documents(filters=filters_simplified) - self.assert_documents_are_equal( - result, - [ - doc - for doc in filterable_docs - if ( - doc.meta.get("page") in ["123"] - and ( - doc.meta.get("name") in ["name_0", "name_1"] - or ("number" in doc.meta and doc.meta["number"] < 1) - ) - ) - ], - ) - - def test_filter_nested_and_or_implicit(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - filters_simplified = { - "meta.page": {"$eq": "123"}, - "$or": { - "meta.name": {"$in": ["name_0", "name_1"]}, - "meta.number": {"$lt": 1.0}, - }, - } - result = document_store.filter_documents(filters=filters_simplified) - self.assert_documents_are_equal( - result, - [ - doc - for doc in filterable_docs - if ( - doc.meta.get("page") in ["123"] - and ( - doc.meta.get("name") in ["name_0", "name_1"] - or ("number" in doc.meta and doc.meta["number"] < 1) - ) - ) - ], - ) - - def test_filter_nested_or_and(self, document_store: DocumentStore, filterable_docs: List[Document]): - document_store.write_documents(filterable_docs) - filters_simplified = { - "$or": { - "meta.number": {"$lt": 1}, - "$and": { - "meta.name": {"$in": ["name_0", "name_1"]}, - "$not": {"meta.chapter": {"$eq": "intro"}}, - }, - } - } - result = document_store.filter_documents(filters=filters_simplified) - self.assert_documents_are_equal( - result, - [ - doc - for doc in filterable_docs - if ( - (doc.meta.get("number") is not None and doc.meta["number"] < 1) - or (doc.meta.get("name") in ["name_0", "name_1"] and (doc.meta.get("chapter") != "intro")) - ) - ], - ) - - def test_filter_nested_multiple_identical_operators_same_level( - self, document_store: DocumentStore, filterable_docs: List[Document] - ): - document_store.write_documents(filterable_docs) - filters = { - "$or": [ - { - "$and": { - "meta.name": {"$in": ["name_0", "name_1"]}, - "meta.page": "100", - } - }, - { - "$and": { - "meta.chapter": {"$in": ["intro", "abstract"]}, - "meta.page": "123", - } - }, - ] - } - result = document_store.filter_documents(filters=filters) - self.assert_documents_are_equal( - result, - [ - doc - for doc in filterable_docs - if ( - (doc.meta.get("name") in ["name_0", "name_1"] and doc.meta.get("page") == "100") - or (doc.meta.get("chapter") in ["intro", "abstract"] and doc.meta.get("page") == "123") - ) - 
], - ) - - def test_no_filter_not_empty(self, document_store: DocumentStore): - docs = [Document(content="test doc")] - document_store.write_documents(docs) - self.assert_documents_are_equal(document_store.filter_documents(), docs) - self.assert_documents_are_equal(document_store.filter_documents(filters={}), docs) From 704847f45da8aa7cf631a07bcd86ee4927c4b10a Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Fri, 13 Sep 2024 08:15:29 +0000 Subject: [PATCH 27/33] Update the changelog --- integrations/qdrant/CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/integrations/qdrant/CHANGELOG.md b/integrations/qdrant/CHANGELOG.md index edc936fb2..a275529f8 100644 --- a/integrations/qdrant/CHANGELOG.md +++ b/integrations/qdrant/CHANGELOG.md @@ -1,5 +1,7 @@ # Changelog +## [integrations/qdrant-v6.0.0] - 2024-09-13 + ## [integrations/qdrant-v5.1.0] - 2024-09-12 ### ๐Ÿš€ Features From b72d8572d338e3bcccfcf6ae685cac48d1219ad8 Mon Sep 17 00:00:00 2001 From: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> Date: Fri, 13 Sep 2024 11:06:19 +0200 Subject: [PATCH 28/33] Unpin protobuf dependency in Google Vertex integration (#1085) --- integrations/google_vertex/pyproject.toml | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/integrations/google_vertex/pyproject.toml b/integrations/google_vertex/pyproject.toml index 71158f712..747bbecbf 100644 --- a/integrations/google_vertex/pyproject.toml +++ b/integrations/google_vertex/pyproject.toml @@ -22,12 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = [ - "haystack-ai", - "google-cloud-aiplatform>=1.38", - "pyarrow>3", - "protobuf<5.28", -] +dependencies = ["haystack-ai", "google-cloud-aiplatform>=1.38", "pyarrow>3"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/google_vertex#readme" From b47583f798d7f89a91daf44d59c7e3dea4ba60f4 Mon Sep 17 00:00:00 2001 From: Mo Sriha <22803208+medsriha@users.noreply.github.com> Date: Mon, 16 Sep 2024 07:58:49 -0500 Subject: [PATCH 29/33] feat: Add Snowflake integration (#1064) * initial commit * add unit tests * add pyproject.toml * add pydoc config * add CHANGELOG file * update pyproject.toml * lint file * add example and fix lint * update comments * add header and trailing line * update based on review --- integrations/snowflake/CHANGELOG.md | 1 + integrations/snowflake/README.md | 23 + .../snowflake/example/text2sql_example.py | 120 ++++ integrations/snowflake/pydoc/config.yml | 30 + integrations/snowflake/pyproject.toml | 149 +++++ .../retrievers/snowflake/__init__.py | 7 + .../snowflake/snowflake_table_retriever.py | 335 ++++++++++ integrations/snowflake/tests/__init__.py | 3 + .../tests/test_snowflake_table_retriever.py | 611 ++++++++++++++++++ 9 files changed, 1279 insertions(+) create mode 100644 integrations/snowflake/CHANGELOG.md create mode 100644 integrations/snowflake/README.md create mode 100644 integrations/snowflake/example/text2sql_example.py create mode 100644 integrations/snowflake/pydoc/config.yml create mode 100644 integrations/snowflake/pyproject.toml create mode 100644 integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/__init__.py create mode 100644 integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py create mode 100644 integrations/snowflake/tests/__init__.py create mode 100644 
 integrations/snowflake/tests/test_snowflake_table_retriever.py

diff --git a/integrations/snowflake/CHANGELOG.md b/integrations/snowflake/CHANGELOG.md
new file mode 100644
index 000000000..0553a3f4b
--- /dev/null
+++ b/integrations/snowflake/CHANGELOG.md
@@ -0,0 +1 @@
+## [integrations/snowflake-v0.0.1] - 2024-09-06
\ No newline at end of file
diff --git a/integrations/snowflake/README.md b/integrations/snowflake/README.md
new file mode 100644
index 000000000..30f0aee1a
--- /dev/null
+++ b/integrations/snowflake/README.md
@@ -0,0 +1,23 @@
+# snowflake-haystack
+
+[![PyPI - Version](https://img.shields.io/pypi/v/snowflake-haystack.svg)](https://pypi.org/project/snowflake-haystack)
+[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/snowflake-haystack.svg)](https://pypi.org/project/snowflake-haystack)
+
+-----
+
+**Table of Contents**
+
+- [Installation](#installation)
+- [License](#license)
+
+## Installation
+
+```console
+pip install snowflake-haystack
+```
+## Examples
+You can find a code example showing how to use the Retriever under the `example/` folder of this repo.
+
+## License
+
+`snowflake-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license.
\ No newline at end of file
diff --git a/integrations/snowflake/example/text2sql_example.py b/integrations/snowflake/example/text2sql_example.py
new file mode 100644
index 000000000..b85a4c677
--- /dev/null
+++ b/integrations/snowflake/example/text2sql_example.py
@@ -0,0 +1,120 @@
+from dotenv import load_dotenv
+from haystack import Pipeline
+from haystack.components.builders import PromptBuilder
+from haystack.components.converters import OutputAdapter
+from haystack.components.generators import OpenAIGenerator
+from haystack.utils import Secret
+
+from haystack_integrations.components.retrievers.snowflake import SnowflakeTableRetriever
+
+load_dotenv()
+
+sql_template = """
+    You are a SQL expert working with Snowflake.
+
+    Your task is to create a Snowflake SQL query for the given question.
+
+    Refrain from explaining your answer. Your answer must be the SQL query
+    in plain text format without using Markdown.
+
+    Here are some relevant tables, a description of each, and their
+    columns:
+
+    Database name: DEMO_DB
+    Schema name: ADVENTURE_WORKS
+    Table names:
+    - ADDRESS: Employees Address Table
+    - EMPLOYEE: Employees directory
+    - SALESTERRITORY: Sales territory lookup table.
+    - SALESORDERHEADER: General sales order information.
+
+    User's question: {{ question }}
+
+    Generated SQL query:
+"""
+
+sql_builder = PromptBuilder(template=sql_template)
+
+analyst_template = """
+    You are an expert data analyst.
+
+    Your role is to answer the user's question {{ question }} using the information
+    in the table.
+
+    You will base your response solely on the information provided in the
+    table.
+
+    Do not rely on your knowledge base; use only the data that is in the table.
+
+    Refrain from using the term "table" in your response; instead, use
+    the word "data".
+
+    If the table is blank, say:
+
+    "The specific answer can't be found in the database. Try rephrasing your
+    question."
+
+    Additionally, you will present the table in a tabular format and provide
+    the SQL query used to extract the relevant rows from the database in
+    Markdown.
+
+    If the table is larger than 10 rows, display the most important rows up
+    to 10 rows. Your answer must be detailed and provide insights based on
+    the question and the available data.
+ + SQL query: + + {{ sql_query }} + + Table: + + {{ table }} + + Answer: +""" + +analyst_builder = PromptBuilder(template=analyst_template) + +# LLM responsible for generating the SQL query +sql_llm = OpenAIGenerator( + model="gpt-4o", + api_key=Secret.from_env_var("OPENAI_API_KEY"), + generation_kwargs={"temperature": 0.0, "max_tokens": 1000}, +) + +# LLM responsible for analyzing the table +analyst_llm = OpenAIGenerator( + model="gpt-4o", + api_key=Secret.from_env_var("OPENAI_API_KEY"), + generation_kwargs={"temperature": 0.0, "max_tokens": 2000}, +) + +snowflake = SnowflakeTableRetriever( + user="", + account="", + api_key=Secret.from_env_var("SNOWFLAKE_API_KEY"), + warehouse="", +) + +adapter = OutputAdapter(template="{{ replies[0] }}", output_type=str) + +pipeline = Pipeline() + +pipeline.add_component(name="sql_builder", instance=sql_builder) +pipeline.add_component(name="sql_llm", instance=sql_llm) +pipeline.add_component(name="adapter", instance=adapter) +pipeline.add_component(name="snowflake", instance=snowflake) +pipeline.add_component(name="analyst_builder", instance=analyst_builder) +pipeline.add_component(name="analyst_llm", instance=analyst_llm) + + +pipeline.connect("sql_builder.prompt", "sql_llm.prompt") +pipeline.connect("sql_llm.replies", "adapter.replies") +pipeline.connect("adapter.output", "snowflake.query") +pipeline.connect("snowflake.table", "analyst_builder.table") +pipeline.connect("adapter.output", "analyst_builder.sql_query") +pipeline.connect("analyst_builder.prompt", "analyst_llm.prompt") + +question = "What are my top territories by number of orders and by sales value?" + +response = pipeline.run(data={"sql_builder": {"question": question}, "analyst_builder": {"question": question}}) diff --git a/integrations/snowflake/pydoc/config.yml b/integrations/snowflake/pydoc/config.yml new file mode 100644 index 000000000..7237b3816 --- /dev/null +++ b/integrations/snowflake/pydoc/config.yml @@ -0,0 +1,30 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../src] + modules: + [ + "haystack_integrations.components.retrievers.snowflake.snowflake_retriever" + ] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer + excerpt: Snowflake integration for Haystack + category_slug: integrations-api + title: Snowflake + slug: integrations-Snowflake + order: 130 + markdown: + descriptive_class_title: false + classdef_code_block: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: _readme_snowflake.md \ No newline at end of file diff --git a/integrations/snowflake/pyproject.toml b/integrations/snowflake/pyproject.toml new file mode 100644 index 000000000..68f9ec477 --- /dev/null +++ b/integrations/snowflake/pyproject.toml @@ -0,0 +1,149 @@ +[build-system] +requires = ["hatchling", "hatch-vcs"] +build-backend = "hatchling.build" + +[project] +name = "snowflake-haystack" +dynamic = ["version"] +description = 'A Snowflake integration for the Haystack framework.' 
+readme = "README.md" +requires-python = ">=3.8" +license = "Apache-2.0" +keywords = [] +authors = [{ name = "deepset GmbH", email = "info@deepset.ai" }, + { name = "Mohamed Sriha", email = "mohamed.sriha@deepset.ai" }] +classifiers = [ + "License :: OSI Approved :: Apache Software License", + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dependencies = ["haystack-ai", "snowflake-connector-python>=3.10.1", "tabulate>=0.9.0"] + +[project.urls] +Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/snowflake#readme" +Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" +Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/snowflake" + +[tool.hatch.build.targets.wheel] +packages = ["src/haystack_integrations"] + +[tool.hatch.version] +source = "vcs" +tag-pattern = 'integrations\/snowflake-v(?P.*)' + +[tool.hatch.version.raw-options] +root = "../.." +git_describe_command = 'git describe --tags --match="integrations/snowflake-v[0-9]*"' + +[tool.hatch.envs.default] +dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "haystack-pydoc-tools"] +[tool.hatch.envs.default.scripts] +test = "pytest {args:tests}" +test-cov = "coverage run -m pytest {args:tests}" +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] +docs = ["pydoc-markdown pydoc/config.yml"] + + +[[tool.hatch.envs.all.matrix]] +python = ["3.8", "3.9", "3.10", "3.11"] + +[tool.hatch.envs.lint] +detached = true +dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +[tool.hatch.envs.lint.scripts] +typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" +style = ["ruff check {args:. --exclude tests/}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff --fix {args:. 
--exclude tests/}", "style"] +all = ["style", "typing"] + +[tool.black] +target-version = ["py38"] +line-length = 120 +skip-string-normalization = true + +[tool.ruff] +target-version = "py38" +line-length = 120 +select = [ + "A", + "ARG", + "B", + "C", + "DTZ", + "E", + "EM", + "F", + "I", + "ICN", + "ISC", + "N", + "PLC", + "PLE", + "PLR", + "PLW", + "Q", + "RUF", + "S", + "T", + "TID", + "UP", + "W", + "YTT", +] +ignore = [ + # Allow non-abstract empty methods in abstract base classes + "B027", + # Ignore checks for possible passwords + "S105", + "S106", + "S107", + # Ignore complexity + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", + # Ignore SQL injection + "S608", + # Unused method argument + "ARG002" +] +unfixable = [ + # Don't touch unused imports + "F401", +] + +[tool.ruff.isort] +known-first-party = ["snowflake_haystack"] + +[tool.ruff.flake8-tidy-imports] +ban-relative-imports = "parents" + +[tool.ruff.per-file-ignores] +# Tests can use magic values, assertions, and relative imports +"tests/**/*" = ["PLR2004", "S101", "TID252"] + +[tool.coverage.run] +source = ["haystack_integrations"] +branch = true +parallel = false + + +[tool.coverage.report] +omit = ["*/tests/*", "*/__init__.py"] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] + +[[tool.mypy.overrides]] +module = ["haystack.*", "haystack_integrations.*", "pytest.*", "openai.*", "snowflake.*"] +ignore_missing_imports = true \ No newline at end of file diff --git a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/__init__.py b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/__init__.py new file mode 100644 index 000000000..294d3cce4 --- /dev/null +++ b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from .snowflake_table_retriever import SnowflakeTableRetriever + +__all__ = ["SnowflakeTableRetriever"] diff --git a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py new file mode 100644 index 000000000..aa6f5ff4d --- /dev/null +++ b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py @@ -0,0 +1,335 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import re +from typing import Any, Dict, Final, Optional, Union + +import pandas as pd +from haystack import component, default_from_dict, default_to_dict, logging +from haystack.lazy_imports import LazyImport +from haystack.utils import Secret, deserialize_secrets_inplace + +with LazyImport("Run 'pip install snowflake-connector-python>=3.10.1'") as snow_import: + import snowflake.connector + from snowflake.connector.connection import SnowflakeConnection + from snowflake.connector.errors import ( + DatabaseError, + ForbiddenError, + ProgrammingError, + ) + +logger = logging.getLogger(__name__) + +MAX_SYS_ROWS: Final = 1000000 # Max rows to fetch from a table + + +@component +class SnowflakeTableRetriever: + """ + Connects to a Snowflake database to execute a SQL query. + For more information, see [Snowflake documentation](https://docs.snowflake.com/en/developer-guide/python-connector/python-connector). 
+ + ### Usage example: + + ```python + executor = SnowflakeTableRetriever( + user="", + account="", + api_key=Secret.from_env_var("SNOWFLAKE_API_KEY"), + database="", + db_schema="", + warehouse="", + ) + + # When database and schema are provided during component initialization. + query = "SELECT * FROM table_name" + + # or + + # When database and schema are NOT provided during component initialization. + query = "SELECT * FROM database_name.schema_name.table_name" + + results = executor.run(query=query) + + print(results["dataframe"].head(2)) # Pandas dataframe + # Column 1 Column 2 + # 0 Value1 Value2 + # 1 Value1 Value2 + + print(results["table"]) # Markdown + # | Column 1 | Column 2 | + # |:----------|:----------| + # | Value1 | Value2 | + # | Value1 | Value2 | + ``` + """ + + def __init__( + self, + user: str, + account: str, + api_key: Secret = Secret.from_env_var("SNOWFLAKE_API_KEY"), # noqa: B008 + database: Optional[str] = None, + db_schema: Optional[str] = None, + warehouse: Optional[str] = None, + login_timeout: Optional[int] = None, + ) -> None: + """ + :param user: User's login. + :param account: Snowflake account identifier. + :param api_key: Snowflake account password. + :param database: Name of the database to use. + :param db_schema: Name of the schema to use. + :param warehouse: Name of the warehouse to use. + :param login_timeout: Timeout in seconds for login. By default, 60 seconds. + """ + + self.user = user + self.account = account + self.api_key = api_key + self.database = database + self.db_schema = db_schema + self.warehouse = warehouse + self.login_timeout = login_timeout or 60 + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + user=self.user, + account=self.account, + api_key=self.api_key.to_dict(), + database=self.database, + db_schema=self.db_schema, + warehouse=self.warehouse, + login_timeout=self.login_timeout, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "SnowflakeTableRetriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ + init_params = data.get("init_parameters", {}) + deserialize_secrets_inplace(init_params, ["api_key"]) + return default_from_dict(cls, data) + + @staticmethod + def _snowflake_connector(connect_params: Dict[str, Any]) -> Union[SnowflakeConnection, None]: + """ + Connect to a Snowflake database. + + :param connect_params: Snowflake connection parameters. + """ + try: + return snowflake.connector.connect(**connect_params) + except DatabaseError as e: + logger.error("{error_msg} ", errno=e.errno, error_msg=e.msg) + return None + + @staticmethod + def _extract_table_names(query: str) -> list: + """ + Extract table names from an SQL query using regex. + The extracted table names will be checked for privilege. + + :param query: SQL query to extract table names from. 
+ """ + + suffix = "\\s+([a-zA-Z0-9_.]+)" # Regular expressions to match table names in various clauses + + patterns = [ + "MERGE\\s+INTO", + "USING", + "JOIN", + "FROM", + "UPDATE", + "DROP\\s+TABLE", + "TRUNCATE\\s+TABLE", + "CREATE\\s+TABLE", + "INSERT\\s+INTO", + "DELETE\\s+FROM", + ] + + # Combine all patterns into a single regex + combined_pattern = "|".join([pattern + suffix for pattern in patterns]) + + # Find all matches in the query + matches = re.findall(pattern=combined_pattern, string=query, flags=re.IGNORECASE) + + # Flatten the list of tuples and remove duplication + matches = list(set(sum(matches, ()))) + + # Clean and return unique table names + return [match.strip('`"[]').upper() for match in matches if match] + + @staticmethod + def _execute_sql_query(conn: SnowflakeConnection, query: str) -> pd.DataFrame: + """ + Execute an SQL query and fetch the results. + + :param conn: An open connection to Snowflake. + :param query: The query to execute. + """ + cur = conn.cursor() + try: + cur.execute(query) + rows = cur.fetchmany(size=MAX_SYS_ROWS) # set a limit to avoid fetching too many rows + + df = pd.DataFrame(rows, columns=[desc.name for desc in cur.description]) # Convert data to a dataframe + return df + except Exception as e: + if isinstance(e, ProgrammingError): + logger.warning( + "{error_msg} Use the following ID to check the status of the query in Snowflake UI (ID: {sfqid})", + error_msg=e.msg, + sfqid=e.sfqid, + ) + else: + logger.warning("An unexpected error occurred: {error_msg}", error_msg=e) + + return pd.DataFrame() + + @staticmethod + def _has_select_privilege(privileges: list, table_name: str) -> bool: + """ + Check user's privilege for a specific table. + + :param privileges: List of privileges. + :param table_name: Name of the table. + """ + + for privilege in reversed(privileges): + if table_name.lower() == privilege[3].lower() and re.match( + pattern="truncate|update|insert|delete|operate|references", + string=privilege[1], + flags=re.IGNORECASE, + ): + return False + + return True + + def _check_privilege( + self, + conn: SnowflakeConnection, + query: str, + user: str, + ) -> bool: + """ + Check whether a user has a `select`-only access to the table. + + :param conn: An open connection to Snowflake. + :param query: The query from where to extract table names to check read-only access. + """ + cur = conn.cursor() + + cur.execute(f"SHOW GRANTS TO USER {user};") + + # Get user's latest role + roles = cur.fetchall() + if not roles: + logger.error("User does not exist") + return False + + # Last row second column from GRANT table + role = roles[-1][1] + + # Get role privilege + cur.execute(f"SHOW GRANTS TO ROLE {role};") + + # Keep table level privileges + table_privileges = [row for row in cur.fetchall() if row[2] == "TABLE"] + + # Get table names to check for privilege + table_names = self._extract_table_names(query=query) + + for table_name in table_names: + if not self._has_select_privilege( + privileges=table_privileges, + table_name=table_name, + ): + return False + return True + + def _fetch_data( + self, + query: str, + ) -> pd.DataFrame: + """ + Fetch data from a database using a SQL query. + + :param query: SQL query to use to fetch the data from the database. Query must be a valid SQL query. 
+ """ + + df = pd.DataFrame() + if not query: + return df + try: + # Create a new connection with every run + conn = self._snowflake_connector( + connect_params={ + "user": self.user, + "account": self.account, + "password": self.api_key.resolve_value(), + "database": self.database, + "schema": self.db_schema, + "warehouse": self.warehouse, + "login_timeout": self.login_timeout, + } + ) + if conn is None: + return df + except (ForbiddenError, ProgrammingError) as e: + logger.error( + "Error connecting to Snowflake ({errno}): {error_msg}", + errno=e.errno, + error_msg=e.msg, + ) + return df + + try: + # Check if user has `select` privilege on the table + if self._check_privilege( + conn=conn, + query=query, + user=self.user, + ): + df = self._execute_sql_query(conn=conn, query=query) + else: + logger.error("User does not have `Select` privilege on the table.") + + except Exception as e: + logger.error("An unexpected error has occurred: {error}", error=e) + + # Close connection after every execution + conn.close() + return df + + @component.output_types(dataframe=pd.DataFrame, table=str) + def run(self, query: str) -> Dict[str, Any]: + """ + Execute a SQL query against a Snowflake database. + + :param query: A SQL query to execute. + """ + if not query: + logger.error("Provide a valid SQL query.") + return { + "dataframe": pd.DataFrame, + "table": "", + } + else: + df = self._fetch_data(query) + table_markdown = df.to_markdown(index=False) if not df.empty else "" + + return {"dataframe": df, "table": table_markdown} diff --git a/integrations/snowflake/tests/__init__.py b/integrations/snowflake/tests/__init__.py new file mode 100644 index 000000000..6b5e14dc1 --- /dev/null +++ b/integrations/snowflake/tests/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/snowflake/tests/test_snowflake_table_retriever.py b/integrations/snowflake/tests/test_snowflake_table_retriever.py new file mode 100644 index 000000000..547f7e1b1 --- /dev/null +++ b/integrations/snowflake/tests/test_snowflake_table_retriever.py @@ -0,0 +1,611 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from datetime import datetime +from typing import Generator +from unittest.mock import MagicMock, patch + +import pandas as pd +import pytest +from dateutil.tz import tzlocal +from haystack import Pipeline +from haystack.components.converters import OutputAdapter +from haystack.components.generators import OpenAIGenerator +from haystack.components.builders import PromptBuilder +from haystack.utils import Secret +from openai.types.chat import ChatCompletion, ChatCompletionMessage +from openai.types.chat.chat_completion import Choice +from pytest import LogCaptureFixture +from snowflake.connector.errors import DatabaseError, ForbiddenError, ProgrammingError + +from haystack_integrations.components.retrievers.snowflake import SnowflakeTableRetriever + + +class TestSnowflakeTableRetriever: + @pytest.fixture + def snowflake_table_retriever(self) -> SnowflakeTableRetriever: + return SnowflakeTableRetriever( + user="test_user", + account="test_account", + api_key=Secret.from_token("test-api-key"), + database="test_database", + db_schema="test_schema", + warehouse="test_warehouse", + login_timeout=30, + ) + + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_snowflake_connector( + self, mock_connect: MagicMock, 
snowflake_table_retriever: SnowflakeTableRetriever + ) -> None: + mock_conn = MagicMock() + mock_connect.return_value = mock_conn + + conn = snowflake_table_retriever._snowflake_connector( + connect_params={ + "user": "test_user", + "account": "test_account", + "api_key": Secret.from_token("test-api-key"), + "database": "test_database", + "schema": "test_schema", + "warehouse": "test_warehouse", + "login_timeout": 30, + } + ) + mock_connect.assert_called_once_with( + user="test_user", + account="test_account", + api_key=Secret.from_token("test-api-key"), + database="test_database", + schema="test_schema", + warehouse="test_warehouse", + login_timeout=30, + ) + + assert conn == mock_conn + + def test_query_is_empty( + self, snowflake_table_retriever: SnowflakeTableRetriever, caplog: LogCaptureFixture + ) -> None: + query = "" + result = snowflake_table_retriever.run(query=query) + + assert result["table"] == "" + assert result["dataframe"].empty + assert "Provide a valid SQL query" in caplog.text + + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_exception( + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever, caplog: LogCaptureFixture + ) -> None: + mock_connect = mock_connect.return_value + mock_connect._fetch_data.side_effect = Exception("Unknown error") + + query = 4 + result = snowflake_table_retriever.run(query=query) + + assert result["table"] == "" + assert result["dataframe"].empty + + assert "An unexpected error has occurred" in caplog.text + + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_forbidden_error_during_connection( + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever, caplog: LogCaptureFixture + ) -> None: + mock_connect.side_effect = ForbiddenError(msg="Forbidden error", errno=403) + + result = snowflake_table_retriever._fetch_data(query="SELECT * FROM test_table") + + assert result.empty + assert "000403: Forbidden error" in caplog.text + + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_programing_error_during_connection( + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever, caplog: LogCaptureFixture + ) -> None: + mock_connect.side_effect = ProgrammingError(msg="Programming error", errno=403) + + result = snowflake_table_retriever._fetch_data(query="SELECT * FROM test_table") + + assert result.empty + assert "000403: Programming error" in caplog.text + + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_execute_sql_query_programming_error( + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever, caplog: LogCaptureFixture + ) -> None: + mock_conn = MagicMock() + mock_cursor = mock_conn.cursor.return_value + + mock_cursor.execute.side_effect = ProgrammingError(msg="Simulated programming error", sfqid="ABC-123") + + result = snowflake_table_retriever._execute_sql_query(mock_conn, "SELECT * FROM some_table") + + assert result.empty + + assert ( + "Simulated programming error Use the following ID to check the status of " + "the query in Snowflake UI (ID: ABC-123)" in caplog.text + ) + + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + 
def test_run_connection_error( + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever + ) -> None: + mock_connect.side_effect = DatabaseError(msg="Connection error", errno=1234) + + query = "SELECT * FROM test_table" + result = snowflake_table_retriever.run(query=query) + + assert result["table"] == "" + assert result["dataframe"].empty + + def test_extract_single_table_name(self, snowflake_table_retriever: SnowflakeTableRetriever) -> None: + queries = [ + "SELECT * FROM table_a", + "SELECT name, value FROM (SELECT name, value FROM table_a) AS subquery", + "SELECT name, value FROM (SELECT name, value FROM table_a ) AS subquery", + "UPDATE table_a SET value = 'new_value' WHERE id = 1", + "INSERT INTO table_a (id, name, value) VALUES (1, 'name1', 'value1')", + "DELETE FROM table_a WHERE id = 1", + "TRUNCATE TABLE table_a", + "DROP TABLE table_a", + ] + for query in queries: + result = snowflake_table_retriever._extract_table_names(query) + assert result == ["TABLE_A"] + + def test_extract_database_and_schema_from_query(self, snowflake_table_retriever: SnowflakeTableRetriever) -> None: + # when database and schema are next to table name + assert snowflake_table_retriever._extract_table_names(query="SELECT * FROM DB.SCHEMA.TABLE_A") == [ + "DB.SCHEMA.TABLE_A" + ] + # No database + assert snowflake_table_retriever._extract_table_names(query="SELECT * FROM SCHEMA.TABLE_A") == [ + "SCHEMA.TABLE_A" + ] + + def test_extract_multiple_table_names(self, snowflake_table_retriever: SnowflakeTableRetriever) -> None: + queries = [ + "MERGE INTO table_a USING table_b ON table_a.id = table_b.id WHEN MATCHED", + "SELECT a.name, b.value FROM table_a AS a FULL OUTER JOIN table_b AS b ON a.id = b.id", + "SELECT a.name, b.value FROM table_a AS a RIGHT JOIN table_b AS b ON a.id = b.id", + ] + for query in queries: + result = snowflake_table_retriever._extract_table_names(query) + # Due to using set when deduplicating + assert result == ["TABLE_A", "TABLE_B"] or ["TABLE_B", "TABLE_A"] + + def test_extract_multiple_db_schema_from_table_names( + self, snowflake_table_retriever: SnowflakeTableRetriever + ) -> None: + assert ( + snowflake_table_retriever._extract_table_names( + query="""SELECT a.name, b.value FROM DB.SCHEMA.TABLE_A AS a + FULL OUTER JOIN DATABASE.SCHEMA.TABLE_b AS b ON a.id = b.id""" + ) + == ["DB.SCHEMA.TABLE_A", "DATABASE.SCHEMA.TABLE_A"] + or ["DATABASE.SCHEMA.TABLE_A", "DB.SCHEMA.TABLE_B"] + ) + # No database + assert ( + snowflake_table_retriever._extract_table_names( + query="""SELECT a.name, b.value FROM SCHEMA.TABLE_A AS a + FULL OUTER JOIN SCHEMA.TABLE_b AS b ON a.id = b.id""" + ) + == ["SCHEMA.TABLE_A", "SCHEMA.TABLE_A"] + or ["SCHEMA.TABLE_A", "SCHEMA.TABLE_B"] + ) + + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_execute_sql_query( + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever + ) -> None: + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_col1 = MagicMock() + mock_col2 = MagicMock() + mock_col1.name = "City" + mock_col2.name = "State" + mock_cursor.fetchmany.return_value = [("Chicago", "Illinois")] + mock_cursor.description = [mock_col1, mock_col2] + mock_conn.cursor.return_value = mock_cursor + mock_connect.return_value = mock_conn + + query = "SELECT * FROM test_table" + expected = pd.DataFrame(data={"City": ["Chicago"], "State": ["Illinois"]}) + result = snowflake_table_retriever._execute_sql_query(conn=mock_conn, 
query=query) + + mock_cursor.execute.assert_called_once_with(query) + + assert result.equals(expected) + + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_is_select_only( + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever, caplog: LogCaptureFixture + ) -> None: + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_conn.cursor.return_value = mock_cursor + mock_connect.return_value = mock_conn + mock_cursor.fetchall.side_effect = [ + [("DATETIME", "ROLE_NAME", "USER", "USER_NAME", "GRANTED_BY")], # User roles + [ + ( + "DATETIME", + "SELECT", + "TABLE", + "LOCATIONS", + "ROLE", + "ROLE_NAME", + "GRANT_OPTION", + "GRANTED_BY", + ) + ], # Table privileges + ] + + query = "select * from locations" + result = snowflake_table_retriever._check_privilege(conn=mock_conn, user="test_user", query=query) + assert result + + mock_cursor.fetchall.side_effect = [ + [("DATETIME", "ROLE_NAME", "USER", "USER_NAME", "GRANTED_BY")], # User roles + [ + ( + "DATETIME", + "INSERT", + "TABLE", + "LOCATIONS", + "ROLE", + "ROLE_NAME", + "GRANT_OPTION", + "GRANTED_BY", + ) + ], + ] + + result = snowflake_table_retriever._check_privilege(conn=mock_conn, user="test_user", query=query) + + assert not result + + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_column_after_from( + self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever + ) -> None: + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_col1 = MagicMock() + mock_col2 = MagicMock() + mock_col1.name = "id" + mock_col2.name = "year" + mock_cursor.fetchmany.return_value = [(1233, 1998)] + mock_cursor.description = [mock_col1, mock_col2] + mock_conn.cursor.return_value = mock_cursor + mock_connect.return_value = mock_conn + + query = "SELECT id, extract(year from date_col) as year FROM test_table" + expected = pd.DataFrame(data={"id": [1233], "year": [1998]}) + result = snowflake_table_retriever._execute_sql_query(conn=mock_conn, query=query) + mock_cursor.execute.assert_called_once_with(query) + + assert result.equals(expected) + + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_run(self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever) -> None: + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_col1 = MagicMock() + mock_col2 = MagicMock() + mock_cursor.fetchall.side_effect = [ + [("DATETIME", "ROLE_NAME", "USER", "USER_NAME", "GRANTED_BY")], # User roles + [ + ( + "DATETIME", + "SELECT", + "TABLE", + "locations", + "ROLE", + "ROLE_NAME", + "GRANT_OPTION", + "GRANTED_BY", + ) + ], + ] + mock_col1.name = "City" + mock_col2.name = "State" + mock_cursor.description = [mock_col1, mock_col2] + + mock_cursor.fetchmany.return_value = [("Chicago", "Illinois")] + mock_conn.cursor.return_value = mock_cursor + mock_connect.return_value = mock_conn + + query = "SELECT * FROM locations" + + expected = { + "dataframe": pd.DataFrame(data={"City": ["Chicago"], "State": ["Illinois"]}), + "table": "| City | State |\n|:--------|:---------|\n| Chicago | Illinois |", + } + + result = snowflake_table_retriever.run(query=query) + + assert result["dataframe"].equals(expected["dataframe"]) + assert result["table"] == expected["table"] + + @pytest.fixture + def mock_chat_completion(self) -> Generator: + """ + Mock the 
OpenAI API completion response and reuse it for tests + """ + with patch("openai.resources.chat.completions.Completions.create") as mock_chat_completion_create: + completion = ChatCompletion( + id="foo", + model="gpt-4o-mini", + object="chat.completion", + choices=[ + Choice( + finish_reason="stop", + logprobs=None, + index=0, + message=ChatCompletionMessage(content="select locations from table_a", role="assistant"), + ) + ], + created=int(datetime.now(tz=tzlocal()).timestamp()), + usage={"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, + ) + + mock_chat_completion_create.return_value = completion + yield mock_chat_completion_create + + @patch( + "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect" + ) + def test_run_pipeline( + self, + mock_connect: MagicMock, + mock_chat_completion: MagicMock, + snowflake_table_retriever: SnowflakeTableRetriever, + ) -> None: + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_col1 = MagicMock() + mock_cursor.fetchall.side_effect = [ + [("DATETIME", "ROLE_NAME", "USER", "USER_NAME", "GRANTED_BY")], # User roles + [ + ( + "DATETIME", + "SELECT", + "TABLE", + "test_database.test_schema.table_a", + "ROLE", + "ROLE_NAME", + "GRANT_OPTION", + "GRANTED_BY", + ) + ], + ] + mock_col1.name = "locations" + + mock_cursor.description = [mock_col1] + + mock_cursor.fetchmany.return_value = [("Chicago",), ("Miami",), ("Berlin",)] + mock_conn.cursor.return_value = mock_cursor + mock_connect.return_value = mock_conn + + expected = { + "dataframe": pd.DataFrame(data={"locations": ["Chicago", "Miami", "Berlin"]}), + "table": "| locations |\n|:------------|\n| Chicago |\n| Miami |\n| Berlin |", + } + + llm = OpenAIGenerator(model="gpt-4o-mini", api_key=Secret.from_token("test-api-key")) + adapter = OutputAdapter(template="{{ replies[0] }}", output_type=str) + pipeline = Pipeline() + + pipeline.add_component("llm", llm) + pipeline.add_component("adapter", adapter) + pipeline.add_component("snowflake", snowflake_table_retriever) + + pipeline.connect(sender="llm.replies", receiver="adapter.replies") + pipeline.connect(sender="adapter.output", receiver="snowflake.query") + + result = pipeline.run(data={"llm": {"prompt": "Generate a SQL query that extract all locations from table_a"}}) + + assert result["snowflake"]["dataframe"].equals(expected["dataframe"]) + assert result["snowflake"]["table"] == expected["table"] + + def test_from_dict(self, monkeypatch: MagicMock) -> None: + monkeypatch.setenv("SNOWFLAKE_API_KEY", "test-api-key") + data = { + "type": "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever" + ".SnowflakeTableRetriever", + "init_parameters": { + "api_key": { + "env_vars": ["SNOWFLAKE_API_KEY"], + "strict": True, + "type": "env_var", + }, + "user": "test_user", + "account": "new_account", + "database": "test_database", + "db_schema": "test_schema", + "warehouse": "test_warehouse", + "login_timeout": 3, + }, + } + component = SnowflakeTableRetriever.from_dict(data) + + assert component.user == "test_user" + assert component.account == "new_account" + assert component.api_key == Secret.from_env_var("SNOWFLAKE_API_KEY") + assert component.database == "test_database" + assert component.db_schema == "test_schema" + assert component.warehouse == "test_warehouse" + assert component.login_timeout == 3 + + def test_to_dict_default(self, monkeypatch: MagicMock) -> None: + monkeypatch.setenv("SNOWFLAKE_API_KEY", "test-api-key") + component = 
SnowflakeTableRetriever(
+            user="test_user",
+            api_key=Secret.from_env_var("SNOWFLAKE_API_KEY"),
+            account="test_account",
+            database="test_database",
+            db_schema="test_schema",
+            warehouse="test_warehouse",
+            login_timeout=30,
+        )
+
+        data = component.to_dict()
+
+        assert data == {
+            "type": "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.SnowflakeTableRetriever",
+            "init_parameters": {
+                "api_key": {
+                    "env_vars": ["SNOWFLAKE_API_KEY"],
+                    "strict": True,
+                    "type": "env_var",
+                },
+                "user": "test_user",
+                "account": "test_account",
+                "database": "test_database",
+                "db_schema": "test_schema",
+                "warehouse": "test_warehouse",
+                "login_timeout": 30,
+            },
+        }
+
+    def test_to_dict_with_parameters(self, monkeypatch: MagicMock) -> None:
+        monkeypatch.setenv("SNOWFLAKE_API_KEY", "test-api-key")
+        component = SnowflakeTableRetriever(
+            user="John",
+            api_key=Secret.from_env_var("SNOWFLAKE_API_KEY"),
+            account="TGMD-EEREW",
+            database="CITY",
+            db_schema="SMALL_TOWNS",
+            warehouse="COMPUTE_WH",
+            login_timeout=30,
+        )
+
+        data = component.to_dict()
+
+        assert data == {
+            "type": "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.SnowflakeTableRetriever",
+            "init_parameters": {
+                "api_key": {
+                    "env_vars": ["SNOWFLAKE_API_KEY"],
+                    "strict": True,
+                    "type": "env_var",
+                },
+                "user": "John",
+                "account": "TGMD-EEREW",
+                "database": "CITY",
+                "db_schema": "SMALL_TOWNS",
+                "warehouse": "COMPUTE_WH",
+                "login_timeout": 30,
+            },
+        }
+
+    @patch(
+        "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect"
+    )
+    def test_has_select_privilege(
+        self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever
+    ) -> None:
+        # Each privileges row has the shape (_, privilege, _, table_name)
+        test_cases = [
+            # Test case 1: SELECT privilege granted on the table
+            {
+                "privileges": [[None, "SELECT", None, "table"]],
+                "table_name": "table",
+                "expected_result": True,
+            },
+            # Test case 2: privilege does not match
+            {
+                "privileges": [[None, "INSERT", None, "table"]],
+                "table_name": "table",
+                "expected_result": False,
+            },
+            # Test case 3: match is case-insensitive
+            {
+                "privileges": [[None, "select", None, "table"]],
+                "table_name": "TABLE",
+                "expected_result": True,
+            },
+        ]
+
+        for case in test_cases:
+            result = snowflake_table_retriever._has_select_privilege(
+                privileges=case["privileges"],  # type: ignore
+                table_name=case["table_name"],  # type: ignore
+            )
+            assert result == case["expected_result"]  # type: ignore
+
+    @patch(
+        "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect"
+    )
+    def test_user_does_not_exist(
+        self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever, caplog: LogCaptureFixture
+    ) -> None:
+        mock_conn = MagicMock()
+        mock_connect.return_value = mock_conn
+
+        mock_cursor = mock_conn.cursor.return_value
+        mock_cursor.fetchall.return_value = []
+
+        result = snowflake_table_retriever._fetch_data(query="""SELECT * FROM test_table""")
+
+        assert result.empty
+        assert "User does not exist" in caplog.text
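A minimal sketch of the check exercised by `test_has_select_privilege` above — assuming
each privileges row carries the grant type at index 1 and the object name at index 3,
compared case-insensitively; an illustration rather than the shipped implementation:

    def has_select_privilege(privileges: list, table_name: str) -> bool:
        # True as soon as any row grants SELECT on the requested table.
        return any(
            str(row[1]).upper() == "SELECT" and str(row[3]).upper() == table_name.upper()
            for row in privileges
        )
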
+
+    @patch(
+        "haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever.snowflake.connector.connect"
+    )
+    def test_empty_query(self, mock_connect: MagicMock, snowflake_table_retriever: SnowflakeTableRetriever) -> None:
+        result = snowflake_table_retriever._fetch_data(query="")
+
+        assert result.empty
+
+    def test_serialization_deserialization_pipeline(self) -> None:
+        pipeline = Pipeline()
+        pipeline.add_component("snow", SnowflakeTableRetriever(user="test_user", account="test_account"))
+        pipeline.add_component("prompt_builder", PromptBuilder(template="Display results {{ table }}"))
+        pipeline.connect("snow.table", "prompt_builder.table")
+
+        pipeline_dict = pipeline.to_dict()
+
+        new_pipeline = Pipeline.from_dict(pipeline_dict)
+        assert new_pipeline == pipeline

From 26bb3288e4c4934544bec9c2e55b292d8e113718 Mon Sep 17 00:00:00 2001
From: Vladimir Blagojevic
Date: Mon, 16 Sep 2024 16:22:11 +0200
Subject: [PATCH 30/33] feat: Cohere LLM - adjust token counting meta to match
 OpenAI format (#1086)

* Cohere - adjust token counting in meta

* Update integration test

* Lint

---
 .../components/generators/cohere/chat/chat_generator.py | 8 +++++---
 integrations/cohere/tests/test_cohere_chat_generator.py | 7 +++++++
 2 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py
index 568a26979..e635e291c 100644
--- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py
+++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py
@@ -178,7 +178,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
                 if finish_response.meta.billed_units:
                     tokens_in = finish_response.meta.billed_units.input_tokens or -1
                     tokens_out = finish_response.meta.billed_units.output_tokens or -1
-                    chat_message.meta["usage"] = tokens_in + tokens_out
+                    chat_message.meta["usage"] = {"prompt_tokens": tokens_in, "completion_tokens": tokens_out}
                 chat_message.meta.update(
                     {
                         "model": self.model,
@@ -220,11 +220,13 @@ def _build_message(self, cohere_response):
             message = ChatMessage.from_assistant(cohere_response.tool_calls[0].json())
         elif cohere_response.text:
             message = ChatMessage.from_assistant(content=cohere_response.text)
-        total_tokens = cohere_response.meta.billed_units.input_tokens + cohere_response.meta.billed_units.output_tokens
         message.meta.update(
             {
                 "model": self.model,
-                "usage": total_tokens,
+                "usage": {
+                    "prompt_tokens": cohere_response.meta.billed_units.input_tokens,
+                    "completion_tokens": cohere_response.meta.billed_units.output_tokens,
+                },
                 "index": 0,
                 "finish_reason": cohere_response.finish_reason,
                 "documents": cohere_response.documents,
diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py
index 6521503f2..fe9b7f43e 100644
--- a/integrations/cohere/tests/test_cohere_chat_generator.py
+++ b/integrations/cohere/tests/test_cohere_chat_generator.py
@@ -169,6 +169,9 @@ def test_live_run(self):
         assert len(results["replies"]) == 1
         message: ChatMessage = results["replies"][0]
         assert "Paris" in message.content
+        assert "usage" in message.meta
+        assert "prompt_tokens" in message.meta["usage"]
+        assert "completion_tokens" in message.meta["usage"]
 
     @pytest.mark.skipif(
         not os.environ.get("COHERE_API_KEY", None) and not
os.environ.get("CO_API_KEY", None), @@ -210,6 +213,10 @@ def __call__(self, chunk: StreamingChunk) -> None: assert callback.counter > 1 assert "Paris" in callback.responses + assert "usage" in message.meta + assert "prompt_tokens" in message.meta["usage"] + assert "completion_tokens" in message.meta["usage"] + @pytest.mark.skipif( not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", From 803aaa813a0e626ac7ea13dd88b1c8aba91dbe94 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Mon, 16 Sep 2024 14:57:37 +0000 Subject: [PATCH 31/33] Update the changelog --- integrations/cohere/CHANGELOG.md | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/integrations/cohere/CHANGELOG.md b/integrations/cohere/CHANGELOG.md index 3067b0a5e..3f36836cc 100644 --- a/integrations/cohere/CHANGELOG.md +++ b/integrations/cohere/CHANGELOG.md @@ -1,15 +1,30 @@ # Changelog -## [unreleased] +## [integrations/cohere-v2.0.0] - 2024-09-16 ### ๐Ÿš€ Features - Update Anthropic/Cohere for tools use (#790) - Update Cohere default LLMs, add examples and update unit tests (#838) +- Cohere LLM - adjust token counting meta to match OpenAI format (#1086) + +### ๐Ÿ› Bug Fixes + +- Lints in `cohere-haystack` (#995) + +### ๐Ÿงช Testing + +- Do not retry tests in `hatch run test` command (#954) ### โš™๏ธ Miscellaneous Tasks - Retry tests to reduce flakyness (#836) +- Update ruff invocation to include check parameter (#853) + +### Docs + +- Update CohereChatGenerator docstrings (#958) +- Update CohereGenerator docstrings (#960) ## [integrations/cohere-v1.1.1] - 2024-06-12 From eda6c9faa01e6d7d3cfe05ade77c8d2583fa9b71 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 17 Sep 2024 14:44:07 +0200 Subject: [PATCH 32/33] feat: Langfuse - support generation span for more LLMs (#1087) * Langfuse: support generation span for more LLMs * Add example instructions * Avoid instantiation of all generators, only selected * Linting * Formatting and naming * Add integration test for Anthropic * Add cohere integration test * Lint * Parametrize integration test * Linting * Simplify test parameters * Move LLM deps to test env --- integrations/langfuse/example/chat.py | 31 ++++++++++- integrations/langfuse/pyproject.toml | 2 + .../tracing/langfuse/tracer.py | 18 +++++- integrations/langfuse/tests/test_tracing.py | 55 +++++++++++-------- 4 files changed, 78 insertions(+), 28 deletions(-) diff --git a/integrations/langfuse/example/chat.py b/integrations/langfuse/example/chat.py index 443d65a13..0d9c42787 100644 --- a/integrations/langfuse/example/chat.py +++ b/integrations/langfuse/example/chat.py @@ -1,19 +1,46 @@ import os +# See README.md for more information on how to set up the environment variables +# before running this script + +# In addition to setting the environment variables, you need to install the following packages: +# pip install cohere-haystack anthropic-haystack os.environ["HAYSTACK_CONTENT_TRACING_ENABLED"] = "true" from haystack import Pipeline from haystack.components.builders import ChatPromptBuilder -from haystack.components.generators.chat import OpenAIChatGenerator +from haystack.components.generators.chat import HuggingFaceAPIChatGenerator, OpenAIChatGenerator from haystack.dataclasses import ChatMessage +from haystack.utils.auth import Secret +from haystack.utils.hf import HFGenerationAPIType + from 
haystack_integrations.components.connectors.langfuse import LangfuseConnector
+from haystack_integrations.components.generators.anthropic import AnthropicChatGenerator
+from haystack_integrations.components.generators.cohere import CohereChatGenerator
+
+os.environ["HAYSTACK_CONTENT_TRACING_ENABLED"] = "true"
+
+selected_generator_name = "openai"
+
+generators = {
+    "openai": OpenAIChatGenerator,
+    "anthropic": AnthropicChatGenerator,
+    "hf_api": lambda: HuggingFaceAPIChatGenerator(
+        api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
+        api_params={"model": "mistralai/Mixtral-8x7B-Instruct-v0.1"},
+        token=Secret.from_token(os.environ["HF_API_KEY"]),
+    ),
+    "cohere": CohereChatGenerator,
+}
+
+selected_chat_generator = generators[selected_generator_name]()
 
 if __name__ == "__main__":
     pipe = Pipeline()
     pipe.add_component("tracer", LangfuseConnector("Chat example"))
     pipe.add_component("prompt_builder", ChatPromptBuilder())
-    pipe.add_component("llm", OpenAIChatGenerator(model="gpt-3.5-turbo"))
+    pipe.add_component("llm", selected_chat_generator)
 
     pipe.connect("prompt_builder.prompt", "llm.messages")
diff --git a/integrations/langfuse/pyproject.toml b/integrations/langfuse/pyproject.toml
index d92c62668..6f9213be7 100644
--- a/integrations/langfuse/pyproject.toml
+++ b/integrations/langfuse/pyproject.toml
@@ -47,6 +47,8 @@ dependencies = [
   "pytest",
   "pytest-rerunfailures",
   "haystack-pydoc-tools",
+  "anthropic-haystack",
+  "cohere-haystack"
 ]
 
 [tool.hatch.envs.default.scripts]
 test = "pytest {args:tests}"
diff --git a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py
index 7d141c08c..94064a0d1 100644
--- a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py
+++ b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py
@@ -10,8 +10,22 @@
 import langfuse
 
 HAYSTACK_LANGFUSE_ENFORCE_FLUSH_ENV_VAR = "HAYSTACK_LANGFUSE_ENFORCE_FLUSH"
-_SUPPORTED_GENERATORS = ["AzureOpenAIGenerator", "OpenAIGenerator"]
-_SUPPORTED_CHAT_GENERATORS = ["AzureOpenAIChatGenerator", "OpenAIChatGenerator"]
+_SUPPORTED_GENERATORS = [
+    "AzureOpenAIGenerator",
+    "OpenAIGenerator",
+    "AnthropicGenerator",
+    "HuggingFaceAPIGenerator",
+    "HuggingFaceLocalGenerator",
+    "CohereGenerator",
+]
+_SUPPORTED_CHAT_GENERATORS = [
+    "AzureOpenAIChatGenerator",
+    "OpenAIChatGenerator",
+    "AnthropicChatGenerator",
+    "HuggingFaceAPIChatGenerator",
+    "HuggingFaceLocalChatGenerator",
+    "CohereChatGenerator",
+]
 _ALL_SUPPORTED_GENERATORS = _SUPPORTED_GENERATORS + _SUPPORTED_CHAT_GENERATORS
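A minimal sketch of how these allow-lists are typically consumed — assuming the tracer
opens a Langfuse `generation` observation for components whose class name appears in
the list and a plain span otherwise; the helper name is illustrative, not the
integration's exact internals:

    def start_observation(trace, component_type: str, name: str):
        # Generations carry model and token-usage metadata in Langfuse;
        # every other component gets a generic span.
        if component_type in _ALL_SUPPORTED_GENERATORS:
            return trace.generation(name=name)
        return trace.span(name=name)
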
diff --git a/integrations/langfuse/tests/test_tracing.py b/integrations/langfuse/tests/test_tracing.py
index 111d89dfd..4e8c679be 100644
--- a/integrations/langfuse/tests/test_tracing.py
+++ b/integrations/langfuse/tests/test_tracing.py
@@ -1,34 +1,39 @@
 import os
 
 # don't remove (or move) this env var setting from here, it's needed to turn tracing on
 os.environ["HAYSTACK_CONTENT_TRACING_ENABLED"] = "true"
 
-from urllib.parse import urlparse
-
 import pytest
+from urllib.parse import urlparse
 import requests
-
+from requests.auth import HTTPBasicAuth
 from haystack import Pipeline
 from haystack.components.builders import ChatPromptBuilder
-from haystack.components.generators.chat import OpenAIChatGenerator
 from haystack.dataclasses import ChatMessage
-from requests.auth import HTTPBasicAuth
-
 from haystack_integrations.components.connectors.langfuse import LangfuseConnector
+from haystack.components.generators.chat import OpenAIChatGenerator
+
+from haystack_integrations.components.generators.anthropic import AnthropicChatGenerator
+from haystack_integrations.components.generators.cohere import CohereChatGenerator
 
 
 @pytest.mark.integration
-@pytest.mark.skipif(
-    not os.environ.get("LANGFUSE_SECRET_KEY", None) and not os.environ.get("LANGFUSE_PUBLIC_KEY", None),
-    reason="Export an env var called LANGFUSE_SECRET_KEY and LANGFUSE_PUBLIC_KEY containing Langfuse credentials.",
+@pytest.mark.parametrize(
+    "llm_class, env_var, expected_trace",
+    [
+        (OpenAIChatGenerator, "OPENAI_API_KEY", "OpenAI"),
+        (AnthropicChatGenerator, "ANTHROPIC_API_KEY", "Anthropic"),
+        (CohereChatGenerator, "COHERE_API_KEY", "Cohere"),
+    ],
 )
-def test_tracing_integration():
+def test_tracing_integration(llm_class, env_var, expected_trace):
+    if not all([os.environ.get("LANGFUSE_SECRET_KEY"), os.environ.get("LANGFUSE_PUBLIC_KEY"), os.environ.get(env_var)]):
+        pytest.skip(f"Missing required environment variables: LANGFUSE_SECRET_KEY, LANGFUSE_PUBLIC_KEY, or {env_var}")
     pipe = Pipeline()
-    pipe.add_component("tracer", LangfuseConnector(name="Chat example", public=True))  # public so anyone can verify run
+    pipe.add_component("tracer", LangfuseConnector(name=f"Chat example - {expected_trace}", public=True))
     pipe.add_component("prompt_builder", ChatPromptBuilder())
-    pipe.add_component("llm", OpenAIChatGenerator(model="gpt-3.5-turbo"))
-
+    pipe.add_component("llm", llm_class())
     pipe.connect("prompt_builder.prompt", "llm.messages")
 
     messages = [
@@ -39,17 +44,20 @@ def test_tracing_integration():
     response = pipe.run(data={"prompt_builder": {"template_variables": {"location": "Berlin"}, "template": messages}})
     assert "Berlin" in response["llm"]["replies"][0].content
     assert response["tracer"]["trace_url"]
+
     url = "https://cloud.langfuse.com/api/public/traces/"
     trace_url = response["tracer"]["trace_url"]
-    parsed_url = urlparse(trace_url)
-    # trace id is the last part of the path (after the last '/')
-    uuid = os.path.basename(parsed_url.path)
+    uuid = os.path.basename(urlparse(trace_url).path)
+
     try:
-        # GET request with Basic Authentication on the Langfuse API
         response = requests.get(
-            url + uuid, auth=HTTPBasicAuth(os.environ.get("LANGFUSE_PUBLIC_KEY"), os.environ.get("LANGFUSE_SECRET_KEY"))
+            url + uuid, auth=HTTPBasicAuth(os.environ["LANGFUSE_PUBLIC_KEY"], os.environ["LANGFUSE_SECRET_KEY"])
         )
-        assert response.status_code == 200, f"Failed to retrieve data from Langfuse API: {response.status_code}"
+
+        # check if the trace contains the expected LLM name
+        assert expected_trace in str(response.content)
+        # check if the trace contains the expected generation span
+        assert "GENERATION" in str(response.content)
     except requests.exceptions.RequestException as e:
-        assert False, f"Failed to retrieve data from Langfuse API: {e}"
+        pytest.fail(f"Failed to retrieve data from Langfuse API: {e}")

From b32f620b20fcc4cc246ae58b624a34c916333a8f Mon Sep 17 00:00:00 2001
From: HaystackBot
Date: Tue, 17 Sep 2024 14:39:59 +0000
Subject: [PATCH 33/33] Update the changelog

---
 integrations/langfuse/CHANGELOG.md | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/integrations/langfuse/CHANGELOG.md b/integrations/langfuse/CHANGELOG.md
index 0a90a7121..ccd68ded3 100644
--- a/integrations/langfuse/CHANGELOG.md
+++ b/integrations/langfuse/CHANGELOG.md
@@ -1,6 +1,10 @@
 # Changelog
 
-## [unreleased]
+## [integrations/langfuse-v0.4.0] - 2024-09-17 + +### ๐Ÿš€ Features + +- Langfuse - support generation span for more LLMs (#1087) ### ๐Ÿšœ Refactor