diff --git a/.github/ISSUE_TEMPLATE/breaking-change-proposal.md b/.github/ISSUE_TEMPLATE/breaking-change-proposal.md index 71aa2a5e9..6c6fb9017 100644 --- a/.github/ISSUE_TEMPLATE/breaking-change-proposal.md +++ b/.github/ISSUE_TEMPLATE/breaking-change-proposal.md @@ -15,9 +15,12 @@ Briefly explain how the change is breaking and why is needed. ```[tasklist] ### Tasks -- [ ] The change is documented with docstrings and was merged in the `main` branch -- [ ] Integration tile on https://github.com/deepset-ai/haystack-integrations was updated +- [ ] The changes are merged in the `main` branch (Code + Docstrings) +- [ ] New package version declares the breaking change +- [ ] The package has been released on PyPI - [ ] Docs at https://docs.haystack.deepset.ai/ were updated +- [ ] Integration tile on https://github.com/deepset-ai/haystack-integrations was updated - [ ] Notebooks on https://github.com/deepset-ai/haystack-cookbook were updated (if needed) -- [ ] New package version declares the breaking change and package has been released on PyPI -``` \ No newline at end of file +- [ ] Tutorials on https://github.com/deepset-ai/haystack-tutorials were updated (if needed) +- [ ] Articles on https://github.com/deepset-ai/haystack-home/tree/main/content were updated (if needed) +``` diff --git a/.github/labeler.yml b/.github/labeler.yml index ba74c43a2..4d060772c 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -4,6 +4,11 @@ integration:amazon-bedrock: - any-glob-to-any-file: "integrations/amazon_bedrock/**/*" - any-glob-to-any-file: ".github/workflows/amazon_bedrock.yml" +integration:amazon-sagemaker: + - changed-files: + - any-glob-to-any-file: "integrations/amazon_sagemaker/**/*" + - any-glob-to-any-file: ".github/workflows/amazon_sagemaker.yml" + integration:astra: - changed-files: - any-glob-to-any-file: "integrations/astra/**/*" @@ -64,6 +69,11 @@ integration:opensearch: - any-glob-to-any-file: "integrations/opensearch/**/*" - any-glob-to-any-file: ".github/workflows/opensearch.yml" +integration:pgvector: + - changed-files: + - any-glob-to-any-file: "integrations/pgvector/**/*" + - any-glob-to-any-file: ".github/workflows/pgvector.yml" + integration:pinecone: - changed-files: - any-glob-to-any-file: "integrations/pinecone/**/*" @@ -76,8 +86,8 @@ integration:qdrant: integration:unstructured-fileconverter: - changed-files: - - any-glob-to-any-file: "integrations/unstructured/fileconverter/**/*" - - any-glob-to-any-file: ".github/workflows/unstructured_fileconverter.yml" + - any-glob-to-any-file: "integrations/unstructured/**/*" + - any-glob-to-any-file: ".github/workflows/unstructured.yml" integration:uptrain: - changed-files: diff --git a/.github/workflows/amazon_sagemaker.yml b/.github/workflows/amazon_sagemaker.yml new file mode 100644 index 000000000..88f397c85 --- /dev/null +++ b/.github/workflows/amazon_sagemaker.yml @@ -0,0 +1,56 @@ +# This workflow comes from https://github.com/ofek/hatch-mypyc +# https://github.com/ofek/hatch-mypyc/blob/5a198c0ba8660494d02716cfc9d79ce4adfb1442/.github/workflows/test.yml +name: Test / amazon-sagemaker + +on: + schedule: + - cron: "0 0 * * *" + pull_request: + paths: + - "integrations/amazon_sagemaker/**" + - ".github/workflows/amazon_sagemaker.yml" + +defaults: + run: + working-directory: integrations/amazon_sagemaker + +concurrency: + group: amazon-sagemaker-${{ github.head_ref }} + cancel-in-progress: true + +env: + PYTHONUNBUFFERED: "1" + FORCE_COLOR: "1" + +jobs: + run: + name: Python ${{ matrix.python-version }} on ${{ 
startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + python-version: ["3.9", "3.10"] + + steps: + - name: Support longpaths + if: matrix.os == 'windows-latest' + working-directory: . + run: git config --system core.longpaths true + + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Hatch + run: pip install --upgrade hatch + + - name: Lint + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run lint:all + + - name: Run tests + run: hatch run cov diff --git a/.github/workflows/astra.yml b/.github/workflows/astra.yml index b751550de..a1aab7154 100644 --- a/.github/workflows/astra.yml +++ b/.github/workflows/astra.yml @@ -53,6 +53,10 @@ jobs: if: matrix.python-version == '3.9' && runner.os == 'Linux' run: hatch run lint:all + - name: Generate docs + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run docs + - name: Run tests env: ASTRA_DB_APPLICATION_TOKEN: ${{ secrets.ASTRA_DB_APPLICATION_TOKEN }} diff --git a/.github/workflows/cohere.yml b/.github/workflows/cohere.yml index 0f0030ec1..562556e47 100644 --- a/.github/workflows/cohere.yml +++ b/.github/workflows/cohere.yml @@ -52,5 +52,9 @@ jobs: if: matrix.python-version == '3.9' && runner.os == 'Linux' run: hatch run lint:all + - name: Generate docs + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run docs + - name: Run tests run: hatch run cov \ No newline at end of file diff --git a/.github/workflows/elasticsearch.yml b/.github/workflows/elasticsearch.yml index eb2c1748d..688e5c48f 100644 --- a/.github/workflows/elasticsearch.yml +++ b/.github/workflows/elasticsearch.yml @@ -10,6 +10,10 @@ on: - "integrations/elasticsearch/**" - ".github/workflows/elasticsearch.yml" +defaults: + run: + working-directory: integrations/elasticsearch + concurrency: group: elasticsearch-${{ github.head_ref }} cancel-in-progress: true @@ -40,14 +44,15 @@ jobs: run: pip install --upgrade hatch - name: Lint - working-directory: integrations/elasticsearch if: matrix.python-version == '3.9' run: hatch run lint:all - name: Run ElasticSearch container - working-directory: integrations/elasticsearch run: docker-compose up -d + - name: Generate docs + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run docs + - name: Run tests - working-directory: integrations/elasticsearch run: hatch run cov diff --git a/.github/workflows/ollama.yml b/.github/workflows/ollama.yml index 7f61af14e..a375fc7db 100644 --- a/.github/workflows/ollama.yml +++ b/.github/workflows/ollama.yml @@ -54,6 +54,11 @@ jobs: - name: Pull the LLM in the Ollama service run: docker exec ollama ollama pull ${{ env.LLM_FOR_TESTS }} + - name: Generate docs + working-directory: integrations/ollama + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run docs + - name: Run tests working-directory: integrations/ollama run: hatch run cov diff --git a/.github/workflows/opensearch.yml b/.github/workflows/opensearch.yml index aacb4ce71..72a01d090 100644 --- a/.github/workflows/opensearch.yml +++ b/.github/workflows/opensearch.yml @@ -18,6 +18,10 @@ env: PYTHONUNBUFFERED: "1" FORCE_COLOR: "1" +defaults: + run: + working-directory: integrations/opensearch + jobs: run: name: Python ${{ 
matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} @@ -40,14 +44,16 @@ jobs: run: pip install --upgrade hatch - name: Lint - working-directory: integrations/opensearch if: matrix.python-version == '3.9' run: hatch run lint:all - name: Run opensearch container - working-directory: integrations/opensearch run: docker-compose up -d + - name: Generate docs + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run docs + - name: Run tests working-directory: integrations/opensearch run: hatch run cov diff --git a/.github/workflows/pgvector.yml b/.github/workflows/pgvector.yml index c985b765a..badb2565b 100644 --- a/.github/workflows/pgvector.yml +++ b/.github/workflows/pgvector.yml @@ -18,6 +18,10 @@ env: PYTHONUNBUFFERED: "1" FORCE_COLOR: "1" +defaults: + run: + working-directory: integrations/pgvector + jobs: run: name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} @@ -49,10 +53,12 @@ jobs: run: pip install --upgrade hatch - name: Lint - working-directory: integrations/pgvector if: matrix.python-version == '3.9' - run: hatch run lint:all + run: hatch run lint:all + + - name: Generate docs + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run docs - name: Run tests - working-directory: integrations/pgvector run: hatch run cov diff --git a/README.md b/README.md index 20b17b377..39d669322 100644 --- a/README.md +++ b/README.md @@ -80,3 +80,4 @@ deepset-haystack | [qdrant-haystack](integrations/qdrant/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/qdrant-haystack.svg?color=orange)](https://pypi.org/project/qdrant-haystack) | [![Test / qdrant](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/qdrant.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/qdrant.yml) | | [unstructured-fileconverter-haystack](integrations/unstructured/) | File converter | [![PyPI - Version](https://img.shields.io/pypi/v/unstructured-fileconverter-haystack.svg)](https://pypi.org/project/unstructured-fileconverter-haystack) | [![Test / unstructured](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured.yml) | | [uptrain-haystack](integrations/uptrain/) | Evaluator | [![PyPI - Version](https://img.shields.io/pypi/v/uptrain-haystack.svg)](https://pypi.org/project/uptrain-haystack) | [![Test / uptrain](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/uptrain.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/uptrain.yml) | +| [amazon-sagemaker-haystack](integrations/amazon_sagemaker/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/amazon-sagemaker-haystack.svg)](https://pypi.org/project/amazon-sagemaker-haystack) | [![Test / amazon_sagemaker](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/amazon_sagemaker.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/amazon_sagemaker.yml) | diff --git a/integrations/amazon_bedrock/README.md b/integrations/amazon_bedrock/README.md index f84c8f3c4..3a689ef3b 100644 --- a/integrations/amazon_bedrock/README.md +++ b/integrations/amazon_bedrock/README.md @@ -8,6 +8,7 @@ **Table of Contents** - 
[Installation](#installation) +- [Contributing](#contributing) - [License](#license) ## Installation @@ -16,6 +17,24 @@ pip install amazon-bedrock-haystack ``` +## Contributing + +`hatch` is the best way to interact with this project, to install it: +```sh +pip install hatch +``` + +With `hatch` installed, to run all the tests: +``` +hatch run test +``` +> Note: there are no integration tests for this project. + +To run the linters `ruff` and `mypy`: +``` +hatch run lint:all +``` + ## License `amazon-bedrock-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. diff --git a/integrations/amazon_bedrock/pyproject.toml b/integrations/amazon_bedrock/pyproject.toml index 7e82924a8..6a2ce3eab 100644 --- a/integrations/amazon_bedrock/pyproject.toml +++ b/integrations/amazon_bedrock/pyproject.toml @@ -35,6 +35,9 @@ Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/m Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/amazon_bedrock" +[tool.hatch.build.targets.wheel] +packages = ["src/haystack_integrations"] + [tool.hatch.version] source = "vcs" tag-pattern = 'integrations\/amazon_bedrock-v(?P.*)' @@ -71,7 +74,8 @@ dependencies = [ "ruff>=0.0.243", ] [tool.hatch.envs.lint.scripts] -typing = "mypy --install-types --non-interactive {args:src/amazon_bedrock_haystack tests}" +typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" + style = [ "ruff {args:.}", "black --check --diff {args:.}", @@ -136,26 +140,24 @@ unfixable = [ ] [tool.ruff.isort] -known-first-party = ["amazon_bedrock_haystack"] +known-first-party = ["haystack_integrations"] [tool.ruff.flake8-tidy-imports] -ban-relative-imports = "all" +ban-relative-imports = "parents" [tool.ruff.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] [tool.coverage.run] -source_pkgs = ["amazon_bedrock_haystack", "tests"] +source_pkgs = ["src", "tests"] branch = true parallel = true -omit = [ - "src/amazon_bedrock_haystack/__about__.py", -] + [tool.coverage.paths] -amazon_bedrock_haystack = ["src/amazon_bedrock_haystack", "*/amazon_bedrock/src/amazon_bedrock_haystack"] -tests = ["tests", "*/amazon_bedrock_haystack/tests"] +amazon_bedrock_haystack = ["src/*"] +tests = ["tests"] [tool.coverage.report] exclude_lines = [ @@ -170,6 +172,7 @@ module = [ "transformers.*", "boto3.*", "haystack.*", + "haystack_integrations.*", "pytest.*", "numpy.*", ] diff --git a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/__init__.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py similarity index 63% rename from integrations/amazon_bedrock/src/amazon_bedrock_haystack/__init__.py rename to integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py index 3e05179c0..236347b61 100644 --- a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/__init__.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from amazon_bedrock_haystack.generators.amazon_bedrock import AmazonBedrockGenerator +from .generator import AmazonBedrockGenerator __all__ = ["AmazonBedrockGenerator"] diff --git 
a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock_adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py similarity index 98% rename from integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock_adapters.py rename to integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py index bec172867..40ba0bc67 100644 --- a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock_adapters.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional -from amazon_bedrock_haystack.generators.amazon_bedrock_handlers import TokenStreamingHandler +from .handlers import TokenStreamingHandler class BedrockModelAdapter(ABC): diff --git a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/errors.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/errors.py similarity index 100% rename from integrations/amazon_bedrock/src/amazon_bedrock_haystack/errors.py rename to integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/errors.py diff --git a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py similarity index 98% rename from integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock.py rename to integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index dda84fe14..4c43c9a09 100644 --- a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -7,12 +7,7 @@ from botocore.exceptions import BotoCoreError, ClientError from haystack import component, default_from_dict, default_to_dict -from amazon_bedrock_haystack.errors import ( - AmazonBedrockConfigurationError, - AmazonBedrockInferenceError, - AWSConfigurationError, -) -from amazon_bedrock_haystack.generators.amazon_bedrock_adapters import ( +from .adapters import ( AI21LabsJurassic2Adapter, AmazonTitanAdapter, AnthropicClaudeAdapter, @@ -20,7 +15,12 @@ CohereCommandAdapter, MetaLlama2ChatAdapter, ) -from amazon_bedrock_haystack.generators.amazon_bedrock_handlers import ( +from .errors import ( + AmazonBedrockConfigurationError, + AmazonBedrockInferenceError, + AWSConfigurationError, +) +from .handlers import ( DefaultPromptHandler, DefaultTokenStreamingHandler, TokenStreamingHandler, diff --git a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock_handlers.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py similarity index 100% rename from integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock_handlers.py rename to integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py diff --git a/integrations/amazon_bedrock/tests/test_amazon_bedrock.py b/integrations/amazon_bedrock/tests/test_amazon_bedrock.py index a05c95ba3..b08e9dfd5 100644 --- a/integrations/amazon_bedrock/tests/test_amazon_bedrock.py +++ 
b/integrations/amazon_bedrock/tests/test_amazon_bedrock.py @@ -4,9 +4,8 @@ import pytest from botocore.exceptions import BotoCoreError -from amazon_bedrock_haystack.errors import AmazonBedrockConfigurationError -from amazon_bedrock_haystack.generators.amazon_bedrock import AmazonBedrockGenerator -from amazon_bedrock_haystack.generators.amazon_bedrock_adapters import ( +from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator +from haystack_integrations.components.generators.amazon_bedrock.adapters import ( AI21LabsJurassic2Adapter, AmazonTitanAdapter, AnthropicClaudeAdapter, @@ -14,6 +13,7 @@ CohereCommandAdapter, MetaLlama2ChatAdapter, ) +from haystack_integrations.components.generators.amazon_bedrock.errors import AmazonBedrockConfigurationError @pytest.fixture @@ -34,7 +34,7 @@ def mock_boto3_session(): @pytest.fixture def mock_prompt_handler(): with patch( - "amazon_bedrock_haystack.generators.amazon_bedrock_handlers.DefaultPromptHandler" + "haystack_integrations.components.generators.amazon_bedrock.handlers.DefaultPromptHandler" ) as mock_prompt_handler: yield mock_prompt_handler @@ -55,7 +55,7 @@ def test_to_dict(mock_auto_tokenizer, mock_boto3_session): ) expected_dict = { - "type": "amazon_bedrock_haystack.generators.amazon_bedrock.AmazonBedrockGenerator", + "type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator", "init_parameters": { "model": "anthropic.claude-v2", "max_length": 99, @@ -72,7 +72,7 @@ def test_from_dict(mock_auto_tokenizer, mock_boto3_session): """ generator = AmazonBedrockGenerator.from_dict( { - "type": "amazon_bedrock_haystack.generators.amazon_bedrock.AmazonBedrockGenerator", + "type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator", "init_parameters": { "model": "anthropic.claude-v2", "max_length": 99, @@ -235,7 +235,7 @@ def test_supports_for_valid_aws_configuration(): # Patch the class method to return the mock session with patch( - "amazon_bedrock_haystack.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", + "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", return_value=mock_session, ): supported = AmazonBedrockGenerator.supports( @@ -266,7 +266,7 @@ def test_supports_for_invalid_bedrock_config(): # Patch the class method to return the mock session with patch( - "amazon_bedrock_haystack.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", + "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", return_value=mock_session, ), pytest.raises(AmazonBedrockConfigurationError, match="Could not connect to Amazon Bedrock."): AmazonBedrockGenerator.supports( @@ -282,7 +282,7 @@ def test_supports_for_invalid_bedrock_config_error_on_list_models(): # Patch the class method to return the mock session with patch( - "amazon_bedrock_haystack.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", + "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", return_value=mock_session, ), pytest.raises(AmazonBedrockConfigurationError, match="Could not connect to Amazon Bedrock."): AmazonBedrockGenerator.supports( @@ -314,7 +314,7 @@ def test_supports_with_stream_true_for_model_that_supports_streaming(): # Patch the class method to return the mock session with patch( - "amazon_bedrock_haystack.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", + 
"haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", return_value=mock_session, ): supported = AmazonBedrockGenerator.supports( @@ -335,7 +335,7 @@ def test_supports_with_stream_true_for_model_that_does_not_support_streaming(): # Patch the class method to return the mock session with patch( - "amazon_bedrock_haystack.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", + "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", return_value=mock_session, ), pytest.raises( AmazonBedrockConfigurationError, diff --git a/integrations/amazon_sagemaker/LICENSE.txt b/integrations/amazon_sagemaker/LICENSE.txt new file mode 100644 index 000000000..137069b82 --- /dev/null +++ b/integrations/amazon_sagemaker/LICENSE.txt @@ -0,0 +1,73 @@ +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + +"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. + +"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. + +"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. + +"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." + +"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: + + (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. + + You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + +To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. 
(Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.
+
+Copyright [yyyy] [name of copyright owner]
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
diff --git a/integrations/amazon_sagemaker/README.md b/integrations/amazon_sagemaker/README.md
new file mode 100644
index 000000000..1ea01871d
--- /dev/null
+++ b/integrations/amazon_sagemaker/README.md
@@ -0,0 +1,52 @@
+# amazon-sagemaker-haystack
+
+[![PyPI - Version](https://img.shields.io/pypi/v/amazon-sagemaker-haystack.svg)](https://pypi.org/project/amazon-sagemaker-haystack)
+[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/amazon-sagemaker-haystack.svg)](https://pypi.org/project/amazon-sagemaker-haystack)
+
+-----
+
+**Table of Contents**
+
+- [Installation](#installation)
+- [Contributing](#contributing)
+- [License](#license)
+
+## Installation
+
+```console
+pip install amazon-sagemaker-haystack
+```
+
+## Contributing
+
+`hatch` is the best way to interact with this project. To install it:
+```sh
+pip install hatch
+```
+
+With `hatch` installed, to run all the tests:
+```
+hatch run test
+```
+
+> Note: You need to export your AWS credentials for Sagemaker integration tests to run (`AWS_ACCESS_KEY_ID` and
+`AWS_SECRET_ACCESS_KEY`). If those are missing, the integration tests will be skipped.
+
+To only run unit tests:
+```
+hatch run test -m "not integration"
+```
+
+To only run integration tests:
+```
+hatch run test -m "integration"
+```
+
+To run the linters `ruff` and `mypy`:
+```
+hatch run lint:all
+```
+
+## License
+
+`amazon-sagemaker-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license.
diff --git a/integrations/amazon_sagemaker/pyproject.toml b/integrations/amazon_sagemaker/pyproject.toml
new file mode 100644
index 000000000..916307156
--- /dev/null
+++ b/integrations/amazon_sagemaker/pyproject.toml
@@ -0,0 +1,177 @@
+# SPDX-FileCopyrightText: 2023-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+[build-system]
+requires = ["hatchling", "hatch-vcs"]
+build-backend = "hatchling.build"
+
+[project]
+name = "amazon-sagemaker-haystack"
+dynamic = ["version"]
+description = 'An integration of Amazon SageMaker as a SagemakerGenerator component.'
+readme = "README.md"
+requires-python = ">=3.8"
+license = "Apache-2.0"
+keywords = []
+authors = [
+  { name = "deepset GmbH", email = "info@deepset.ai" },
+]
+classifiers = [
+  "Development Status :: 4 - Beta",
+  "Programming Language :: Python",
+  "Programming Language :: Python :: 3.8",
+  "Programming Language :: Python :: 3.9",
+  "Programming Language :: Python :: 3.10",
+  "Programming Language :: Python :: 3.11",
+  "Programming Language :: Python :: 3.12",
+  "Programming Language :: Python :: Implementation :: CPython",
+  "Programming Language :: Python :: Implementation :: PyPy",
+]
+dependencies = [
+  "haystack-ai",
+  "boto3>=1.28.57",
+]
+
+[project.urls]
+Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/amazon_sagemaker#readme"
+Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues"
+Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/amazon_sagemaker"
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/haystack_integrations"]
+
+[tool.hatch.version]
+source = "vcs"
+tag-pattern = 'integrations\/amazon_sagemaker-v(?P<version>.*)'
+
+[tool.hatch.version.raw-options]
+root = "../.."
+git_describe_command = 'git describe --tags --match="integrations/amazon_sagemaker-v[0-9]*"'
+
+[tool.hatch.envs.default]
+dependencies = [
+  "coverage[toml]>=6.5",
+  "pytest",
+]
+[tool.hatch.envs.default.scripts]
+test = "pytest {args:tests}"
+test-cov = "coverage run -m pytest {args:tests}"
+cov-report = [
+  "- coverage combine",
+  "coverage report",
+]
+cov = [
+  "test-cov",
+  "cov-report",
+]
+
+[[tool.hatch.envs.all.matrix]]
+python = ["3.7", "3.8", "3.9", "3.10", "3.11"]
+
+[tool.hatch.envs.lint]
+detached = true
+dependencies = [
+  "black>=23.1.0",
+  "mypy>=1.0.0",
+  "ruff>=0.0.243",
+]
+[tool.hatch.envs.lint.scripts]
+typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
+style = [
+  "ruff {args:.}",
+  "black --check --diff {args:.}",
+]
+fmt = [
+  "black {args:.}",
+  "ruff --fix {args:.}",
+  "style",
+]
+all = [
+  "style",
+  "typing",
+]
+
+[tool.black]
+target-version = ["py37"]
+line-length = 120
+skip-string-normalization = true
+
+[tool.ruff]
+target-version = "py37"
+line-length = 120
+select = [
+  "A",
+  "ARG",
+  "B",
+  "C",
+  "DTZ",
+  "E",
+  "EM",
+  "F",
+  "FBT",
+  "I",
+  "ICN",
+  "ISC",
+  "N",
+  "PLC",
+  "PLE",
+  "PLR",
+  "PLW",
+  "Q",
+  "RUF",
+  "S",
+  "T",
+  "TID",
+  "UP",
+  "W",
+  "YTT",
+]
+ignore = [
+  # Import sorting doesn't seem to work
+  "I001",
+  # Allow non-abstract empty methods in abstract base classes
+  "B027",
+  # Allow boolean positional values in function calls, like `dict.get(... 
True)` + "FBT003", + # Ignore checks for possible passwords + "S105", "S106", "S107", + # Ignore complexity + "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", +] +unfixable = [ + # Don't touch unused imports + "F401", +] + +[tool.ruff.isort] +known-first-party = ["haystack_integrations"] + +[tool.ruff.flake8-tidy-imports] +ban-relative-imports = "parents" + +[tool.ruff.per-file-ignores] +# Tests can use magic values, assertions, and relative imports +"tests/**/*" = ["PLR2004", "S101", "TID252"] + +[tool.coverage.run] +branch = true +parallel = true + +[tool.coverage.paths] +amazon_sagemaker_haystack = ["src"] +tests = ["tests"] + +[tool.coverage.report] +exclude_lines = [ + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] +[[tool.mypy.overrides]] +module = [ + "haystack.*", + "haystack_integrations.*", + "pytest.*", + "numpy.*", +] +ignore_missing_imports = true \ No newline at end of file diff --git a/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/__init__.py b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/__init__.py new file mode 100644 index 000000000..0fe45a8a1 --- /dev/null +++ b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/__init__.py @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from haystack_integrations.components.generators.amazon_sagemaker.sagemaker import SagemakerGenerator + +__all__ = ["SagemakerGenerator"] diff --git a/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/errors.py b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/errors.py new file mode 100644 index 000000000..6c13d0fcb --- /dev/null +++ b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/errors.py @@ -0,0 +1,46 @@ +from typing import Optional + + +class SagemakerError(Exception): + """ + Error generated by the Amazon Sagemaker integration. 
+    """
+
+    def __init__(
+        self,
+        message: Optional[str] = None,
+    ):
+        super().__init__()
+        if message:
+            self.message = message
+
+    def __getattr__(self, attr):
+        # If self.__cause__ is None, it will raise the expected AttributeError
+        return getattr(self.__cause__, attr)
+
+    def __str__(self):
+        return self.message
+
+    def __repr__(self):
+        return str(self)
+
+
+class AWSConfigurationError(SagemakerError):
+    """Exception raised when AWS is not configured correctly"""
+
+    def __init__(self, message: Optional[str] = None):
+        super().__init__(message=message)
+
+
+class SagemakerNotReadyError(SagemakerError):
+    """Exception raised when the SageMaker endpoint is not yet ready to accept requests"""
+
+    def __init__(self, message: Optional[str] = None):
+        super().__init__(message=message)
+
+
+class SagemakerInferenceError(SagemakerError):
+    """Exception for issues that occur during Sagemaker inference"""
+
+    def __init__(self, message: Optional[str] = None):
+        super().__init__(message=message)
diff --git a/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py
new file mode 100644
index 000000000..35e54a055
--- /dev/null
+++ b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py
@@ -0,0 +1,224 @@
+import json
+import logging
+import os
+from typing import Any, ClassVar, Dict, List, Optional
+
+import requests
+from haystack import component, default_from_dict, default_to_dict
+from haystack.lazy_imports import LazyImport
+from haystack_integrations.components.generators.amazon_sagemaker.errors import (
+    AWSConfigurationError,
+    SagemakerInferenceError,
+    SagemakerNotReadyError,
+)
+
+with LazyImport(message="Run 'pip install boto3'") as boto3_import:
+    import boto3  # type: ignore
+    from botocore.client import BaseClient  # type: ignore
+
+
+logger = logging.getLogger(__name__)
+
+
+MODEL_NOT_READY_STATUS_CODE = 429
+
+
+@component
+class SagemakerGenerator:
+    """
+    Enables text generation using Amazon SageMaker. It supports Large Language Models (LLMs) hosted and deployed on a
+    SageMaker Inference Endpoint. For guidance on how to deploy a model to SageMaker, refer to the
+    [SageMaker JumpStart foundation models documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/jumpstart-foundation-models-use.html).
+
+    **Example:**
+
+    First export your AWS credentials as environment variables:
+    ```bash
+    export AWS_ACCESS_KEY_ID=<your-access-key-id>
+    export AWS_SECRET_ACCESS_KEY=<your-secret-access-key>
+    ```
+    (Note: you may also need to set the session token and region name, depending on your AWS configuration)
+
+    Then you can use the generator as follows:
+    ```python
+    from haystack_integrations.components.generators.amazon_sagemaker import SagemakerGenerator
+    generator = SagemakerGenerator(model="jumpstart-dft-hf-llm-falcon-7b-instruct-bf16")
+    generator.warm_up()
+    response = generator.run("What's Natural Language Processing? Be brief.")
+    print(response)
+    ```
+    ```
+    >> {'replies': ['Natural Language Processing (NLP) is a branch of artificial intelligence that focuses on
+    >> the interaction between computers and human language. 
It involves enabling computers to understand, interpret, + >> and respond to natural human language in a way that is both meaningful and useful.'], 'meta': [{}]} + ``` + """ + + model_generation_keys: ClassVar = ["generated_text", "generation"] + + def __init__( + self, + model: str, + aws_access_key_id_var: str = "AWS_ACCESS_KEY_ID", + aws_secret_access_key_var: str = "AWS_SECRET_ACCESS_KEY", + aws_session_token_var: str = "AWS_SESSION_TOKEN", + aws_region_name_var: str = "AWS_REGION", + aws_profile_name_var: str = "AWS_PROFILE", + aws_custom_attributes: Optional[Dict[str, Any]] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + ): + """ + Instantiates the session with SageMaker. + + :param model: The name for SageMaker Model Endpoint. + :param aws_access_key_id_var: The name of the env var where the AWS access key ID is stored. + :param aws_secret_access_key_var: The name of the env var where the AWS secret access key is stored. + :param aws_session_token_var: The name of the env var where the AWS session token is stored. + :param aws_region_name_var: The name of the env var where the AWS region name is stored. + :param aws_profile_name_var: The name of the env var where the AWS profile name is stored. + :param aws_custom_attributes: Custom attributes to be passed to SageMaker, for example `{"accept_eula": True}` + in case of Llama-2 models. + :param generation_kwargs: Additional keyword arguments for text generation. For a list of supported parameters + see your model's documentation page, for example here for HuggingFace models: + https://huggingface.co/blog/sagemaker-huggingface-llm#4-run-inference-and-chat-with-our-model + + Specifically, Llama-2 models support the following inference payload parameters: + + - `max_new_tokens`: Model generates text until the output length (excluding the input context length) + reaches `max_new_tokens`. If specified, it must be a positive integer. + - `temperature`: Controls the randomness in the output. Higher temperature results in output sequence with + low-probability words and lower temperature results in output sequence with high-probability words. + If `temperature=0`, it results in greedy decoding. If specified, it must be a positive float. + - `top_p`: In each step of text generation, sample from the smallest possible set of words with cumulative + probability `top_p`. If specified, it must be a float between 0 and 1. + - `return_full_text`: If `True`, input text will be part of the output generated text. If specified, it must + be boolean. The default value for it is `False`. + """ + self.model = model + self.aws_access_key_id_var = aws_access_key_id_var + self.aws_secret_access_key_var = aws_secret_access_key_var + self.aws_session_token_var = aws_session_token_var + self.aws_region_name_var = aws_region_name_var + self.aws_profile_name_var = aws_profile_name_var + self.aws_custom_attributes = aws_custom_attributes or {} + self.generation_kwargs = generation_kwargs or {"max_new_tokens": 1024} + self.client: Optional[BaseClient] = None + + if not os.getenv(self.aws_access_key_id_var) or not os.getenv(self.aws_secret_access_key_var): + msg = ( + f"Please provide AWS credentials via environment variables '{self.aws_access_key_id_var}' and " + f"'{self.aws_secret_access_key_var}'." + ) + raise AWSConfigurationError(msg) + + def _get_telemetry_data(self) -> Dict[str, Any]: + """ + Data that is sent to Posthog for usage analytics. 
+ """ + return {"model": self.model} + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize the object to a dictionary. + """ + return default_to_dict( + self, + model=self.model, + aws_access_key_id_var=self.aws_access_key_id_var, + aws_secret_access_key_var=self.aws_secret_access_key_var, + aws_session_token_var=self.aws_session_token_var, + aws_region_name_var=self.aws_region_name_var, + aws_profile_name_var=self.aws_profile_name_var, + aws_custom_attributes=self.aws_custom_attributes, + generation_kwargs=self.generation_kwargs, + ) + + @classmethod + def from_dict(cls, data) -> "SagemakerGenerator": + """ + Deserialize the dictionary into an instance of SagemakerGenerator. + """ + return default_from_dict(cls, data) + + def warm_up(self): + """ + Initializes the SageMaker Inference client. + """ + boto3_import.check() + try: + session = boto3.Session( + aws_access_key_id=os.getenv(self.aws_access_key_id_var), + aws_secret_access_key=os.getenv(self.aws_secret_access_key_var), + aws_session_token=os.getenv(self.aws_session_token_var), + region_name=os.getenv(self.aws_region_name_var), + profile_name=os.getenv(self.aws_profile_name_var), + ) + self.client = session.client("sagemaker-runtime") + except Exception as e: + msg = ( + f"Could not connect to SageMaker Inference Endpoint '{self.model}'." + f"Make sure the Endpoint exists and AWS environment is configured." + ) + raise AWSConfigurationError(msg) from e + + @component.output_types(replies=List[str], meta=List[Dict[str, Any]]) + def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): + """ + Invoke the text generation inference based on the provided messages and generation parameters. + + :param prompt: The string prompt to use for text generation. + :param generation_kwargs: Additional keyword arguments for text generation. These parameters will + potentially override the parameters passed in the `__init__` method. + + :return: A list of strings containing the generated responses and a list of dictionaries containing the metadata + for each response. + """ + if self.client is None: + msg = "SageMaker Inference client is not initialized. Please call warm_up() first." + raise ValueError(msg) + + generation_kwargs = generation_kwargs or self.generation_kwargs + custom_attributes = ";".join( + f"{k}={str(v).lower() if isinstance(v, bool) else str(v)}" for k, v in self.aws_custom_attributes.items() + ) + try: + body = json.dumps({"inputs": prompt, "parameters": generation_kwargs}) + response = self.client.invoke_endpoint( + EndpointName=self.model, + Body=body, + ContentType="application/json", + Accept="application/json", + CustomAttributes=custom_attributes, + ) + response_json = response.get("Body").read().decode("utf-8") + output: Dict[str, Dict[str, Any]] = json.loads(response_json) + + # The output might be either a list of dictionaries or a single dictionary + list_output: List[Dict[str, Any]] + if output and isinstance(output, dict): + list_output = [output] + elif isinstance(output, list) and all(isinstance(o, dict) for o in output): + list_output = output + else: + msg = f"Unexpected model response type: {type(output)}" + raise ValueError(msg) + + # The key where the replies are stored changes from model to model, so we need to look for it. + # All other keys in the response are added to the metadata. + # Unfortunately every model returns different metadata, most of them return none at all, + # so we can't replicate the metadata structure of other generators. 
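+            # Illustrative (hypothetical) example of this extraction step: a response body of
+            # '[{"generated_text": "NLP is ...", "score": 0.9}]' would yield
+            # replies == ["NLP is ..."] and meta == [{"score": 0.9}].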
+            for key in self.model_generation_keys:
+                if key in list_output[0]:
+                    break
+            replies = [o.pop(key, None) for o in list_output]
+
+            return {"replies": replies, "meta": list_output * len(replies)}
+
+        except requests.HTTPError as err:
+            res = err.response
+            if res.status_code == MODEL_NOT_READY_STATUS_CODE:
+                msg = f"Sagemaker model not ready: {res.text}"
+                raise SagemakerNotReadyError(msg) from err
+
+            msg = f"SageMaker Inference returned an error. Status code: {res.status_code} Response body: {res.text}"
+            raise SagemakerInferenceError(msg) from err
diff --git a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/__init__.py b/integrations/amazon_sagemaker/tests/__init__.py
similarity index 100%
rename from integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/__init__.py
rename to integrations/amazon_sagemaker/tests/__init__.py
diff --git a/integrations/amazon_sagemaker/tests/test_sagemaker.py b/integrations/amazon_sagemaker/tests/test_sagemaker.py
new file mode 100644
index 000000000..a22634be1
--- /dev/null
+++ b/integrations/amazon_sagemaker/tests/test_sagemaker.py
@@ -0,0 +1,243 @@
+import os
+from unittest.mock import Mock
+
+import pytest
+from haystack_integrations.components.generators.amazon_sagemaker import SagemakerGenerator
+from haystack_integrations.components.generators.amazon_sagemaker.errors import AWSConfigurationError
+
+
+class TestSagemakerGenerator:
+    def test_init_default(self, monkeypatch):
+        monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test-access-key")
+        monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key")
+
+        component = SagemakerGenerator(model="test-model")
+        assert component.model == "test-model"
+        assert component.aws_access_key_id_var == "AWS_ACCESS_KEY_ID"
+        assert component.aws_secret_access_key_var == "AWS_SECRET_ACCESS_KEY"
+        assert component.aws_session_token_var == "AWS_SESSION_TOKEN"
+        assert component.aws_region_name_var == "AWS_REGION"
+        assert component.aws_profile_name_var == "AWS_PROFILE"
+        assert component.aws_custom_attributes == {}
+        assert component.generation_kwargs == {"max_new_tokens": 1024}
+        assert component.client is None
+
+    def test_init_fail_wo_access_key_or_secret_key(self, monkeypatch):
+        monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False)
+        monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False)
+        with pytest.raises(AWSConfigurationError):
+            SagemakerGenerator(model="test-model")
+
+        monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test-access-key")
+        monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False)
+        with pytest.raises(AWSConfigurationError):
+            SagemakerGenerator(model="test-model")
+
+        monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False)
+        monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key")
+        with pytest.raises(AWSConfigurationError):
+            SagemakerGenerator(model="test-model")
+
+    def test_init_with_parameters(self, monkeypatch):
+        monkeypatch.setenv("MY_ACCESS_KEY_ID", "test-access-key")
+        monkeypatch.setenv("MY_SECRET_ACCESS_KEY", "test-secret-key")
+
+        component = SagemakerGenerator(
+            model="test-model",
+            aws_access_key_id_var="MY_ACCESS_KEY_ID",
+            aws_secret_access_key_var="MY_SECRET_ACCESS_KEY",
+            aws_session_token_var="MY_SESSION_TOKEN",
+            aws_region_name_var="MY_REGION",
+            aws_profile_name_var="MY_PROFILE",
+            aws_custom_attributes={"custom": "attr"},
+            generation_kwargs={"generation": "kwargs"},
+        )
+        assert component.model == "test-model"
+        assert component.aws_access_key_id_var == "MY_ACCESS_KEY_ID"
+        assert 
component.aws_secret_access_key_var == "MY_SECRET_ACCESS_KEY" + assert component.aws_session_token_var == "MY_SESSION_TOKEN" + assert component.aws_region_name_var == "MY_REGION" + assert component.aws_profile_name_var == "MY_PROFILE" + assert component.aws_custom_attributes == {"custom": "attr"} + assert component.generation_kwargs == {"generation": "kwargs"} + assert component.client is None + + def test_to_from_dict(self, monkeypatch): + monkeypatch.setenv("MY_ACCESS_KEY_ID", "test-access-key") + monkeypatch.setenv("MY_SECRET_ACCESS_KEY", "test-secret-key") + + component = SagemakerGenerator( + model="test-model", + aws_access_key_id_var="MY_ACCESS_KEY_ID", + aws_secret_access_key_var="MY_SECRET_ACCESS_KEY", + aws_session_token_var="MY_SESSION_TOKEN", + aws_region_name_var="MY_REGION", + aws_profile_name_var="MY_PROFILE", + aws_custom_attributes={"custom": "attr"}, + generation_kwargs={"generation": "kwargs"}, + ) + serialized = component.to_dict() + assert serialized == { + "type": "haystack_integrations.components.generators.amazon_sagemaker.sagemaker.SagemakerGenerator", + "init_parameters": { + "model": "test-model", + "aws_access_key_id_var": "MY_ACCESS_KEY_ID", + "aws_secret_access_key_var": "MY_SECRET_ACCESS_KEY", + "aws_session_token_var": "MY_SESSION_TOKEN", + "aws_region_name_var": "MY_REGION", + "aws_profile_name_var": "MY_PROFILE", + "aws_custom_attributes": {"custom": "attr"}, + "generation_kwargs": {"generation": "kwargs"}, + }, + } + deserialized = SagemakerGenerator.from_dict(serialized) + assert deserialized.model == "test-model" + assert deserialized.aws_access_key_id_var == "MY_ACCESS_KEY_ID" + assert deserialized.aws_secret_access_key_var == "MY_SECRET_ACCESS_KEY" + assert deserialized.aws_session_token_var == "MY_SESSION_TOKEN" + assert deserialized.aws_region_name_var == "MY_REGION" + assert deserialized.aws_profile_name_var == "MY_PROFILE" + assert deserialized.aws_custom_attributes == {"custom": "attr"} + assert deserialized.generation_kwargs == {"generation": "kwargs"} + assert deserialized.client is None + + def test_run_with_list_of_dictionaries(self, monkeypatch): + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test-access-key") + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key") + client_mock = Mock() + client_mock.invoke_endpoint.return_value = { + "Body": Mock(read=lambda: b'[{"generated_text": "test-reply", "other": "metadata"}]') + } + + component = SagemakerGenerator(model="test-model") + component.client = client_mock # Simulate warm_up() + response = component.run("What's Natural Language Processing?") + + # check that the component returns the correct ChatMessage response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, str) for reply in response["replies"]] + assert "test-reply" in response["replies"][0] + + assert "meta" in response + assert isinstance(response["meta"], list) + assert len(response["meta"]) == 1 + assert [isinstance(reply, dict) for reply in response["meta"]] + assert response["meta"][0]["other"] == "metadata" + + def test_run_with_single_dictionary(self, monkeypatch): + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test-access-key") + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key") + client_mock = Mock() + client_mock.invoke_endpoint.return_value = { + "Body": Mock(read=lambda: b'{"generation": "test-reply", "other": "metadata"}') + } + + component = 
SagemakerGenerator(model="test-model") + component.client = client_mock # Simulate warm_up() + response = component.run("What's Natural Language Processing?") + + # check that the component returns the correct ChatMessage response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, str) for reply in response["replies"]] + assert "test-reply" in response["replies"][0] + + assert "meta" in response + assert isinstance(response["meta"], list) + assert len(response["meta"]) == 1 + assert [isinstance(reply, dict) for reply in response["meta"]] + assert response["meta"][0]["other"] == "metadata" + + @pytest.mark.skipif( + (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), + reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.", + ) + @pytest.mark.integration + def test_run_falcon(self): + component = SagemakerGenerator( + model="jumpstart-dft-hf-llm-falcon-7b-instruct-bf16", generation_kwargs={"max_new_tokens": 10} + ) + component.warm_up() + response = component.run("What's Natural Language Processing?") + + # check that the component returns the correct ChatMessage response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, str) for reply in response["replies"]] + + # Coarse check: assuming no more than 4 chars per token. In any case it + # will fail if the `max_new_tokens` parameter is not respected, as the + # default is either 256 or 1024 + assert all(len(reply) <= 40 for reply in response["replies"]) + + assert "meta" in response + assert isinstance(response["meta"], list) + assert len(response["meta"]) == 1 + assert [isinstance(reply, dict) for reply in response["meta"]] + + @pytest.mark.skipif( + (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), + reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.", + ) + @pytest.mark.integration + def test_run_llama2(self): + component = SagemakerGenerator( + model="jumpstart-dft-meta-textgenerationneuron-llama-2-7b", + generation_kwargs={"max_new_tokens": 10}, + aws_custom_attributes={"accept_eula": True}, + ) + component.warm_up() + response = component.run("What's Natural Language Processing?") + + # check that the component returns the correct ChatMessage response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, str) for reply in response["replies"]] + + # Coarse check: assuming no more than 4 chars per token. 
In any case it + # will fail if the `max_new_tokens` parameter is not respected, as the + # default is either 256 or 1024 + assert all(len(reply) <= 40 for reply in response["replies"]) + + assert "meta" in response + assert isinstance(response["meta"], list) + assert len(response["meta"]) == 1 + assert all(isinstance(reply, dict) for reply in response["meta"]) + + @pytest.mark.skipif( + (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), + reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.", + ) + @pytest.mark.integration + def test_run_bloomz(self): + component = SagemakerGenerator( + model="jumpstart-dft-hf-textgeneration-bloomz-1b1", generation_kwargs={"max_new_tokens": 10} + ) + component.warm_up() + response = component.run("What's Natural Language Processing?") + + # check that the component returns the expected response format + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert all(isinstance(reply, str) for reply in response["replies"]) + + # Coarse check: assuming no more than 4 chars per token. In any case it + # will fail if the `max_new_tokens` parameter is not respected, as the + # default is either 256 or 1024 + assert all(len(reply) <= 40 for reply in response["replies"]) + + assert "meta" in response + assert isinstance(response["meta"], list) + assert len(response["meta"]) == 1 + assert all(isinstance(reply, dict) for reply in response["meta"]) diff --git a/integrations/astra/pydoc/config.yml b/integrations/astra/pydoc/config.yml new file mode 100644 index 000000000..68cc1c809 --- /dev/null +++ b/integrations/astra/pydoc/config.yml @@ -0,0 +1,30 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../src] + modules: [ + "haystack_integrations.components.retrievers.astra.retriever", + "haystack_integrations.document_stores.astra.document_store", + "haystack_integrations.document_stores.astra.errors", + ] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + excerpt: Astra integration for Haystack + category_slug: haystack-integrations + title: Astra + slug: integrations-astra + order: 20 + markdown: + descriptive_class_title: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: _readme_astra.md \ No newline at end of file diff --git a/integrations/astra/pyproject.toml b/integrations/astra/pyproject.toml index 6b4e2565d..7599797a8 100644 --- a/integrations/astra/pyproject.toml +++ b/integrations/astra/pyproject.toml @@ -50,6 +50,7 @@ git_describe_command = 'git describe --tags --match="integrations/astra-v[0-9]*" dependencies = [ "coverage[toml]>=6.5", "pytest", + "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" @@ -62,7 +63,9 @@ cov = [ "test-cov", "cov-report", ] - +docs = [ + "pydoc-markdown pydoc/config.yml" +] [[tool.hatch.envs.all.matrix]] python = ["3.7", "3.8", "3.9", "3.10", "3.11"] diff --git a/integrations/chroma/pyproject.toml b/integrations/chroma/pyproject.toml index ce4641611..2653c491f 100644 --- a/integrations/chroma/pyproject.toml +++ b/integrations/chroma/pyproject.toml @@ -25,6 +25,7 @@ classifiers = [
dependencies = [ "haystack-ai", "chromadb<0.4.20", # FIXME: investigate why filtering tests broke on 0.4.20 + "typing_extensions>=4.8.0", ] [project.urls] diff --git a/integrations/cohere/pydoc/config.yml b/integrations/cohere/pydoc/config.yml new file mode 100644 index 000000000..9418739b5 --- /dev/null +++ b/integrations/cohere/pydoc/config.yml @@ -0,0 +1,32 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../src] + modules: [ + "haystack_integrations.components.embedders.cohere.document_embedder", + "haystack_integrations.components.embedders.cohere.text_embedder", + "haystack_integrations.components.embedders.cohere.utils", + "haystack_integrations.components.generators.cohere.generator", + "haystack_integrations.components.generators.cohere.chat.chat_generator", + ] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + excerpt: Cohere integration for Haystack + category_slug: haystack-integrations + title: Cohere + slug: integrations-cohere + order: 40 + markdown: + descriptive_class_title: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: _readme_cohere.md diff --git a/integrations/cohere/pyproject.toml b/integrations/cohere/pyproject.toml index 42349d9fb..332471674 100644 --- a/integrations/cohere/pyproject.toml +++ b/integrations/cohere/pyproject.toml @@ -49,6 +49,7 @@ git_describe_command = 'git describe --tags --match="integrations/cohere-v[0-9]* dependencies = [ "coverage[toml]>=6.5", "pytest", + "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" @@ -61,6 +62,9 @@ cov = [ "test-cov", "cov-report", ] +docs = [ + "pydoc-markdown pydoc/config.yml" +] [[tool.hatch.envs.all.matrix]] python = ["3.7", "3.8", "3.9", "3.10", "3.11"] diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index c91ada419..edefc1a43 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -260,7 +260,10 @@ def test_live_run(self): @pytest.mark.integration def test_live_run_wrong_model(self, chat_messages): component = CohereChatGenerator(model="something-obviously-wrong", api_key=os.environ.get("COHERE_API_KEY")) - with pytest.raises(cohere.CohereAPIError, match="finetuned model something-obviously-wrong is not valid"): + with pytest.raises( + cohere.CohereAPIError, + match="model not found, make sure the correct model ID was used and that you have access to the model.", + ): component.run(chat_messages) @pytest.mark.skipif( diff --git a/integrations/cohere/tests/test_cohere_generators.py b/integrations/cohere/tests/test_cohere_generators.py index e2ce10405..90d4d3e28 100644 --- a/integrations/cohere/tests/test_cohere_generators.py +++ b/integrations/cohere/tests/test_cohere_generators.py @@ -164,7 +164,7 @@ def __init__(self): self.responses = "" def __call__(self, chunk): - self.responses += chunk.text + self.responses += chunk.content return chunk callback = Callback() diff --git a/integrations/elasticsearch/pydoc/config.yml b/integrations/elasticsearch/pydoc/config.yml new file mode 100644 index 000000000..dc5a090bc --- /dev/null +++ b/integrations/elasticsearch/pydoc/config.yml @@ -0,0 +1,31 @@ 
+loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../src] + modules: [ + "haystack_integrations.components.retrievers.elasticsearch.bm25_retriever", + "haystack_integrations.components.retrievers.elasticsearch.embedding_retriever", + "haystack_integrations.document_stores.elasticsearch.document_store", + "haystack_integrations.document_stores.elasticsearch.filters", + ] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + excerpt: Elasticsearch integration for Haystack + category_slug: haystack-integrations + title: Elasticsearch + slug: integrations-elasticsearch + order: 50 + markdown: + descriptive_class_title: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: _readme_elasticsearch.md \ No newline at end of file diff --git a/integrations/elasticsearch/pyproject.toml b/integrations/elasticsearch/pyproject.toml index af3d89c0c..b67df7e03 100644 --- a/integrations/elasticsearch/pyproject.toml +++ b/integrations/elasticsearch/pyproject.toml @@ -49,6 +49,7 @@ dependencies = [ "coverage[toml]>=6.5", "pytest", "pytest-xdist", + "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" @@ -61,6 +62,9 @@ cov = [ "test-cov", "cov-report", ] +docs = [ + "pydoc-markdown pydoc/config.yml" +] [[tool.hatch.envs.all.matrix]] python = ["3.8", "3.9", "3.10", "3.11"] diff --git a/integrations/instructor_embedders/pyproject.toml b/integrations/instructor_embedders/pyproject.toml index 67cbcb7af..c8a591b69 100644 --- a/integrations/instructor_embedders/pyproject.toml +++ b/integrations/instructor_embedders/pyproject.toml @@ -38,7 +38,7 @@ dependencies = [ "requests>=2.26.0", "scikit_learn>=1.0.2", "scipy", - "sentence_transformers>=2.2.0", + "sentence_transformers>=2.2.0,<2.3.0", "torch", "tqdm", "rich", diff --git a/integrations/ollama/pydoc/config.yml b/integrations/ollama/pydoc/config.yml new file mode 100644 index 000000000..768694991 --- /dev/null +++ b/integrations/ollama/pydoc/config.yml @@ -0,0 +1,29 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../src] + modules: [ + "haystack_integrations.components.generators.ollama.generator", + "haystack_integrations.components.generators.ollama.chat.chat_generator" + ] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + excerpt: Ollama integration for Haystack + category_slug: haystack-integrations + title: Ollama + slug: integrations-ollama + order: 120 + markdown: + descriptive_class_title: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: _readme_ollama.md \ No newline at end of file diff --git a/integrations/ollama/pyproject.toml b/integrations/ollama/pyproject.toml index 69cc2ed16..e3bb738b6 100644 --- a/integrations/ollama/pyproject.toml +++ b/integrations/ollama/pyproject.toml @@ -48,6 +48,7 @@ git_describe_command = 'git describe --tags --match="integrations/ollama-v[0-9]* dependencies = [ "coverage[toml]>=6.5", "pytest", + "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] 
test = "pytest {args:tests}" @@ -60,6 +61,9 @@ cov = [ "test-cov", "cov-report", ] +docs = [ + "pydoc-markdown pydoc/config.yml" +] [[tool.hatch.envs.all.matrix]] python = ["3.8", "3.9", "3.10", "3.11", "3.12"] diff --git a/integrations/opensearch/pydoc/config.yml b/integrations/opensearch/pydoc/config.yml new file mode 100644 index 000000000..3e07f6625 --- /dev/null +++ b/integrations/opensearch/pydoc/config.yml @@ -0,0 +1,31 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../src] + modules: [ + "haystack_integrations.components.retrievers.opensearch.bm25_retriever", + "haystack_integrations.components.retrievers.opensearch.embedding_retriever", + "haystack_integrations.document_stores.opensearch.document_store", + "haystack_integrations.document_stores.opensearch.filters", + ] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + excerpt: OpenSearch integration for Haystack + category_slug: haystack-integrations + title: OpenSearch + slug: integrations-opensearch + order: 130 + markdown: + descriptive_class_title: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: _readme_opensearch.md diff --git a/integrations/opensearch/pyproject.toml b/integrations/opensearch/pyproject.toml index 3edd544a2..794fa73fa 100644 --- a/integrations/opensearch/pyproject.toml +++ b/integrations/opensearch/pyproject.toml @@ -49,6 +49,7 @@ dependencies = [ "coverage[toml]>=6.5", "pytest", "pytest-xdist", + "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" @@ -62,6 +63,10 @@ cov = [ "cov-report", ] +docs = [ + "pydoc-markdown pydoc/config.yml" +] + [[tool.hatch.envs.all.matrix]] python = ["3.8", "3.9", "3.10", "3.11"] diff --git a/integrations/pgvector/pydoc/config.yml b/integrations/pgvector/pydoc/config.yml new file mode 100644 index 000000000..79974b4a1 --- /dev/null +++ b/integrations/pgvector/pydoc/config.yml @@ -0,0 +1,30 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../src] + modules: [ + "haystack_integrations.components.retrievers.pgvector.embedding_retriever", + "haystack_integrations.document_stores.pgvector.document_store", + "haystack_integrations.document_stores.pgvector.filters", + ] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + excerpt: Pgvector integration for Haystack + category_slug: haystack-integrations + title: Pgvector + slug: integrations-pgvector + order: 140 + markdown: + descriptive_class_title: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: _readme_pgvector.md diff --git a/integrations/pgvector/pyproject.toml b/integrations/pgvector/pyproject.toml index b361af8b1..10ef5d314 100644 --- a/integrations/pgvector/pyproject.toml +++ b/integrations/pgvector/pyproject.toml @@ -51,6 +51,7 @@ dependencies = [ "coverage[toml]>=6.5", "pytest", "ipython", + "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" @@ -63,6 +64,9 @@ cov = [ "test-cov", "cov-report", ] +docs = [ + 
"pydoc-markdown pydoc/config.yml" +] [[tool.hatch.envs.all.matrix]] python = ["3.8", "3.9", "3.10", "3.11", "3.12"] diff --git a/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/__init__.py b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/__init__.py new file mode 100644 index 000000000..ec0cf0dc4 --- /dev/null +++ b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/__init__.py @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from .embedding_retriever import PgvectorEmbeddingRetriever + +__all__ = ["PgvectorEmbeddingRetriever"] diff --git a/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py new file mode 100644 index 000000000..26807e9bd --- /dev/null +++ b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from typing import Any, Dict, List, Literal, Optional + +from haystack import component, default_from_dict, default_to_dict +from haystack.dataclasses import Document +from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore +from haystack_integrations.document_stores.pgvector.document_store import VALID_VECTOR_FUNCTIONS + + +@component +class PgvectorEmbeddingRetriever: + """ + Retrieves documents from the PgvectorDocumentStore, based on their dense embeddings. + + Needs to be connected to the PgvectorDocumentStore. + """ + + def __init__( + self, + *, + document_store: PgvectorDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + vector_function: Optional[Literal["cosine_similarity", "inner_product", "l2_distance"]] = None, + ): + """ + Create the PgvectorEmbeddingRetriever component. + + :param document_store: An instance of PgvectorDocumentStore. + :param filters: Filters applied to the retrieved Documents. Defaults to None. + :param top_k: Maximum number of Documents to return, defaults to 10. + :param vector_function: The similarity function to use when searching for similar embeddings. + Defaults to the one set in the `document_store` instance. + "cosine_similarity" and "inner_product" are similarity functions and + higher scores indicate greater similarity between the documents. + "l2_distance" returns the straight-line distance between vectors, + and the most similar documents are the ones with the smallest score. + + Important: if the document store is using the "hnsw" search strategy, the vector function + should match the one utilized during index creation to take advantage of the index. + :type vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"] + + :raises ValueError: If `document_store` is not an instance of PgvectorDocumentStore. 
+ """ + if not isinstance(document_store, PgvectorDocumentStore): + msg = "document_store must be an instance of PgvectorDocumentStore" + raise ValueError(msg) + + if vector_function and vector_function not in VALID_VECTOR_FUNCTIONS: + msg = f"vector_function must be one of {VALID_VECTOR_FUNCTIONS}" + raise ValueError(msg) + + self.document_store = document_store + self.filters = filters or {} + self.top_k = top_k + self.vector_function = vector_function or document_store.vector_function + + def to_dict(self) -> Dict[str, Any]: + return default_to_dict( + self, + filters=self.filters, + top_k=self.top_k, + vector_function=self.vector_function, + document_store=self.document_store.to_dict(), + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "PgvectorEmbeddingRetriever": + data["init_parameters"]["document_store"] = default_from_dict( + PgvectorDocumentStore, data["init_parameters"]["document_store"] + ) + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run( + self, + query_embedding: List[float], + filters: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = None, + vector_function: Optional[Literal["cosine_similarity", "inner_product", "l2_distance"]] = None, + ): + """ + Retrieve documents from the PgvectorDocumentStore, based on their embeddings. + + :param query_embedding: Embedding of the query. + :param filters: Filters applied to the retrieved Documents. + :param top_k: Maximum number of Documents to return. + :param vector_function: The similarity function to use when searching for similar embeddings. + :type vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"] + :return: List of Documents similar to `query_embedding`. + """ + filters = filters or self.filters + top_k = top_k or self.top_k + vector_function = vector_function or self.vector_function + + docs = self.document_store._embedding_retrieval( + query_embedding=query_embedding, + filters=filters, + top_k=top_k, + vector_function=vector_function, + ) + return {"documents": docs} diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py index bb1915a6f..097e86c7e 100644 --- a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py +++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py @@ -8,6 +8,7 @@ from haystack.dataclasses.document import ByteStream, Document from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy +from haystack.utils.filters import convert from psycopg import Error, IntegrityError, connect from psycopg.abc import Query from psycopg.cursor import Cursor @@ -18,6 +19,8 @@ from pgvector.psycopg import register_vector +from .filters import _convert_filters_to_where_clause_and_params + logger = logging.getLogger(__name__) CREATE_TABLE_STATEMENT = """ @@ -49,8 +52,10 @@ meta = EXCLUDED.meta """ +VALID_VECTOR_FUNCTIONS = ["cosine_similarity", "inner_product", "l2_distance"] + VECTOR_FUNCTION_TO_POSTGRESQL_OPS = { - "cosine_distance": "vector_cosine_ops", + "cosine_similarity": "vector_cosine_ops", "inner_product": "vector_ip_ops", "l2_distance": "vector_l2_ops", } @@ -67,7 +72,7 @@ def __init__( connection_string: str, table_name: str = "haystack_documents", embedding_dimension: int = 768, - vector_function: 
Literal["cosine_distance", "inner_product", "l2_distance"] = "cosine_distance", + vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"] = "cosine_similarity", recreate_table: bool = False, search_strategy: Literal["exact_nearest_neighbor", "hnsw"] = "exact_nearest_neighbor", hnsw_recreate_index_if_exists: bool = False, @@ -84,12 +89,23 @@ def __init__( :param table_name: The name of the table to use to store Haystack documents. Defaults to "haystack_documents". :param embedding_dimension: The dimension of the embedding. Defaults to 768. :param vector_function: The similarity function to use when searching for similar embeddings. - Defaults to "cosine_distance". Set it to one of the following values: - :type vector_function: Literal["cosine_distance", "inner_product", "l2_distance"] + Defaults to "cosine_similarity". "cosine_similarity" and "inner_product" are similarity functions and + higher scores indicate greater similarity between the documents. + "l2_distance" returns the straight-line distance between vectors, + and the most similar documents are the ones with the smallest score. + + Important: when using the "hnsw" search strategy, an index will be created that depends on the + `vector_function` passed here. Make sure subsequent queries will keep using the same + vector similarity function in order to take advantage of the index. + :type vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"] :param recreate_table: Whether to recreate the table if it already exists. Defaults to False. :param search_strategy: The search strategy to use when searching for similar embeddings. Defaults to "exact_nearest_neighbor". "hnsw" is an approximate nearest neighbor search strategy, which trades off some accuracy for speed; it is recommended for large numbers of documents. + + Important: when using the "hnsw" search strategy, an index will be created that depends on the + `vector_function` passed here. Make sure subsequent queries will keep using the same + vector similarity function in order to take advantage of the index. :type search_strategy: Literal["exact_nearest_neighbor", "hnsw"] :param hnsw_recreate_index_if_exists: Whether to recreate the HNSW index if it already exists. Defaults to False. Only used if search_strategy is set to "hnsw". @@ -104,6 +120,9 @@ def __init__( self.connection_string = connection_string self.table_name = table_name self.embedding_dimension = embedding_dimension + if vector_function not in VALID_VECTOR_FUNCTIONS: + msg = f"vector_function must be one of {VALID_VECTOR_FUNCTIONS}, but got {vector_function}" + raise ValueError(msg) self.vector_function = vector_function self.recreate_table = recreate_table self.search_strategy = search_strategy @@ -158,11 +177,16 @@ def _execute_sql( params = params or () cursor = cursor or self._cursor + sql_query_str = sql_query.as_string(cursor) if not isinstance(sql_query, str) else sql_query + logger.debug("SQL query: %s\nParameters: %s", sql_query_str, params) + try: result = cursor.execute(sql_query, params) except Error as e: self._connection.rollback() - raise DocumentStoreError(error_msg) from e + detailed_error_msg = f"{error_msg}.\nYou can find the SQL query and the parameters in the debug logs." 
+ raise DocumentStoreError(detailed_error_msg) from e + return result def _create_table_if_not_exists(self): @@ -257,15 +281,37 @@ def count_documents(self) -> int: ] return count - def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: # noqa: ARG002 - # TODO: implement filters - sql_get_docs = SQL("SELECT * FROM {table_name}").format(table_name=Identifier(self.table_name)) + def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: + """ + Returns the documents that match the filters provided. + + For a detailed specification of the filters, + refer to the [documentation](https://docs.haystack.deepset.ai/v2.0/docs/metadata-filtering) + + :param filters: The filters to apply to the document list. + :return: A list of Documents that match the given filters. + """ + if filters: + if not isinstance(filters, dict): + msg = "Filters must be a dictionary" + raise TypeError(msg) + if "operator" not in filters and "conditions" not in filters: + filters = convert(filters) + + sql_filter = SQL("SELECT * FROM {table_name}").format(table_name=Identifier(self.table_name)) + + params = () + if filters: + sql_where_clause, params = _convert_filters_to_where_clause_and_params(filters) + sql_filter += sql_where_clause result = self._execute_sql( - sql_get_docs, error_msg="Could not filter documents from PgvectorDocumentStore", cursor=self._dict_cursor + sql_filter, + params, + error_msg="Could not filter documents from PgvectorDocumentStore.", + cursor=self._dict_cursor, ) - # Fetch all the records records = result.fetchall() docs = self._from_pg_to_haystack_documents(records) return docs @@ -300,6 +346,9 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D sql_insert += SQL(" RETURNING id") + sql_query_str = sql_insert.as_string(self._cursor) if not isinstance(sql_insert, str) else sql_insert + logger.debug("SQL query: %s\nParameters: %s", sql_query_str, db_documents) + try: self._cursor.executemany(sql_insert, db_documents, returning=True) except IntegrityError as ie: @@ -307,7 +356,11 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D raise DuplicateDocumentError from ie except Error as e: self._connection.rollback() - raise DocumentStoreError from e + error_msg = ( + "Could not write documents to PgvectorDocumentStore. \n" + "You can find the SQL query and the parameters in the debug logs." 
+ ) + raise DocumentStoreError(error_msg) from e # get the number of the inserted documents, inspired by psycopg3 docs # https://www.psycopg.org/psycopg3/docs/api/cursors.html#psycopg.Cursor.executemany @@ -356,7 +409,7 @@ def _from_pg_to_haystack_documents(self, documents: List[Dict[str, Any]]) -> Lis # postgresql returns the embedding as a string # so we need to convert it to a list of floats - if "embedding" in document and document["embedding"]: + if document.get("embedding"): haystack_dict["embedding"] = [float(el) for el in document["embedding"].strip("[]").split(",")] haystack_document = Document.from_dict(haystack_dict) @@ -386,3 +439,81 @@ def delete_documents(self, document_ids: List[str]) -> None: ) self._execute_sql(delete_sql, error_msg="Could not delete documents from PgvectorDocumentStore") + + def _embedding_retrieval( + self, + query_embedding: List[float], + *, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + vector_function: Optional[Literal["cosine_similarity", "inner_product", "l2_distance"]] = None, + ) -> List[Document]: + """ + Retrieves documents that are most similar to the query embedding using a vector similarity metric. + + This method is not meant to be part of the public interface of + `PgvectorDocumentStore` and it should not be called directly. + `PgvectorEmbeddingRetriever` uses this method directly and is the public interface for it. + :raises ValueError + :return: List of Documents that are most similar to `query_embedding` + """ + + if not query_embedding: + msg = "query_embedding must be a non-empty list of floats" + raise ValueError(msg) + if len(query_embedding) != self.embedding_dimension: + msg = ( + f"query_embedding dimension ({len(query_embedding)}) does not match PgvectorDocumentStore " + f"embedding dimension ({self.embedding_dimension})." 
+ ) + raise ValueError(msg) + + vector_function = vector_function or self.vector_function + if vector_function not in VALID_VECTOR_FUNCTIONS: + msg = f"vector_function must be one of {VALID_VECTOR_FUNCTIONS}, but got {vector_function}" + raise ValueError(msg) + + # the vector must be a string with this format: "'[3,1,2]'" + query_embedding_for_postgres = f"'[{','.join(str(el) for el in query_embedding)}]'" + + # to compute the scores, we use the approach described in pgvector README: + # https://github.com/pgvector/pgvector?tab=readme-ov-file#distances + # cosine_similarity and inner_product are modified from the result of the operator + if vector_function == "cosine_similarity": + score_definition = f"1 - (embedding <=> {query_embedding_for_postgres}) AS score" + elif vector_function == "inner_product": + score_definition = f"(embedding <#> {query_embedding_for_postgres}) * -1 AS score" + elif vector_function == "l2_distance": + score_definition = f"embedding <-> {query_embedding_for_postgres} AS score" + + sql_select = SQL("SELECT *, {score} FROM {table_name}").format( + table_name=Identifier(self.table_name), + score=SQL(score_definition), + ) + + sql_where_clause = SQL("") + params = () + if filters: + sql_where_clause, params = _convert_filters_to_where_clause_and_params(filters) + + # we always want to return the most similar documents first + # so when using l2_distance, the sort order must be ASC + sort_order = "ASC" if vector_function == "l2_distance" else "DESC" + + sql_sort = SQL(" ORDER BY score {sort_order} LIMIT {top_k}").format( + top_k=SQLLiteral(top_k), + sort_order=SQL(sort_order), + ) + + sql_query = sql_select + sql_where_clause + sql_sort + + result = self._execute_sql( + sql_query, + params, + error_msg="Could not retrieve documents from PgvectorDocumentStore.", + cursor=self._dict_cursor, + ) + + records = result.fetchall() + docs = self._from_pg_to_haystack_documents(records) + return docs diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/filters.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/filters.py new file mode 100644 index 000000000..daa90f502 --- /dev/null +++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/filters.py @@ -0,0 +1,242 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from datetime import datetime +from itertools import chain +from typing import Any, Dict, List + +from haystack.errors import FilterError +from pandas import DataFrame +from psycopg.sql import SQL +from psycopg.types.json import Jsonb + +# we need this mapping to cast meta values to the correct type, +# since they are stored in the JSONB field as strings. +# this dict can be extended if needed +PYTHON_TYPES_TO_PG_TYPES = { + int: "integer", + float: "real", + bool: "boolean", +} + +NO_VALUE = "no_value" + + +def _convert_filters_to_where_clause_and_params(filters: Dict[str, Any]) -> tuple[SQL, tuple]: + """ + Convert Haystack filters to a WHERE clause and a tuple of params to query PostgreSQL. 
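For illustration, the conversion this helper performs, mirroring the expectations asserted in tests/test_filters.py later in this diff (the filter values are arbitrary examples):

filters = {
    "operator": "AND",
    "conditions": [
        {"field": "meta.number", "operator": "==", "value": 100},
        {"field": "meta.chapter", "operator": "==", "value": "intro"},
    ],
}
where_clause, params = _convert_filters_to_where_clause_and_params(filters)
# where_clause renders as: " WHERE ((meta->>'number')::integer = %s AND meta->>'chapter' = %s)"
# params == (100, "intro")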
+ """ + if "field" in filters: + query, values = _parse_comparison_condition(filters) + else: + query, values = _parse_logical_condition(filters) + + where_clause = SQL(" WHERE ") + SQL(query) + params = tuple(value for value in values if value != NO_VALUE) + + return where_clause, params + + +def _parse_logical_condition(condition: Dict[str, Any]) -> tuple[str, List[Any]]: + if "operator" not in condition: + msg = f"'operator' key missing in {condition}" + raise FilterError(msg) + if "conditions" not in condition: + msg = f"'conditions' key missing in {condition}" + raise FilterError(msg) + + operator = condition["operator"] + if operator not in ["AND", "OR"]: + msg = f"Unknown logical operator '{operator}'. Valid operators are: 'AND', 'OR'" + raise FilterError(msg) + + # logical conditions can be nested, so we need to parse them recursively + conditions = [] + for c in condition["conditions"]: + if "field" in c: + query, vals = _parse_comparison_condition(c) + else: + query, vals = _parse_logical_condition(c) + conditions.append((query, vals)) + + query_parts, values = [], [] + for c in conditions: + query_parts.append(c[0]) + values.append(c[1]) + if isinstance(values[0], list): + values = list(chain.from_iterable(values)) + + if operator == "AND": + sql_query = f"({' AND '.join(query_parts)})" + elif operator == "OR": + sql_query = f"({' OR '.join(query_parts)})" + else: + msg = f"Unknown logical operator '{operator}'" + raise FilterError(msg) + + return sql_query, values + + +def _parse_comparison_condition(condition: Dict[str, Any]) -> tuple[str, List[Any]]: + field: str = condition["field"] + if "operator" not in condition: + msg = f"'operator' key missing in {condition}" + raise FilterError(msg) + if "value" not in condition: + msg = f"'value' key missing in {condition}" + raise FilterError(msg) + operator: str = condition["operator"] + if operator not in COMPARISON_OPERATORS: + msg = f"Unknown comparison operator '{operator}'. Valid operators are: {list(COMPARISON_OPERATORS.keys())}" + raise FilterError(msg) + + value: Any = condition["value"] + if isinstance(value, DataFrame): + # DataFrames are stored as JSONB and we query them as such + value = Jsonb(value.to_json()) + field = f"({field})::jsonb" + + if field.startswith("meta."): + field = _treat_meta_field(field, value) + + field, value = COMPARISON_OPERATORS[operator](field, value) + return field, [value] + + +def _treat_meta_field(field: str, value: Any) -> str: + """ + Internal method that modifies the field str + to make the meta JSONB field queryable. 
+ + Examples: + >>> _treat_meta_field(field="meta.number", value=9) + "(meta->>'number')::integer" + + >>> _treat_meta_field(field="meta.name", value="my_name") + "meta->>'name'" + """ + + # use the ->> operator to access keys in the meta JSONB field + field_name = field.split(".", 1)[-1] + field = f"meta->>'{field_name}'" + + # meta fields are stored as strings in the JSONB field, + # so we need to cast them to the correct type + type_value = PYTHON_TYPES_TO_PG_TYPES.get(type(value)) + if isinstance(value, list) and len(value) > 0: + type_value = PYTHON_TYPES_TO_PG_TYPES.get(type(value[0])) + + if type_value: + field = f"({field})::{type_value}" + + return field + + +def _equal(field: str, value: Any) -> tuple[str, Any]: + if value is None: + # NO_VALUE is a placeholder that will be removed in _convert_filters_to_where_clause_and_params + return f"{field} IS NULL", NO_VALUE + return f"{field} = %s", value + + +def _not_equal(field: str, value: Any) -> tuple[str, Any]: + # we use IS DISTINCT FROM to correctly handle NULL values + # (not handled by !=) + return f"{field} IS DISTINCT FROM %s", value + + +def _greater_than(field: str, value: Any) -> tuple[str, Any]: + if isinstance(value, str): + try: + datetime.fromisoformat(value) + except (ValueError, TypeError) as exc: + msg = ( + "Can't compare strings using operators '>', '>=', '<', '<='. " + "Strings are only comparable if they are ISO formatted dates." + ) + raise FilterError(msg) from exc + if type(value) in [list, Jsonb]: + msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" + raise FilterError(msg) + + return f"{field} > %s", value + + +def _greater_than_equal(field: str, value: Any) -> tuple[str, Any]: + if isinstance(value, str): + try: + datetime.fromisoformat(value) + except (ValueError, TypeError) as exc: + msg = ( + "Can't compare strings using operators '>', '>=', '<', '<='. " + "Strings are only comparable if they are ISO formatted dates." + ) + raise FilterError(msg) from exc + if type(value) in [list, Jsonb]: + msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" + raise FilterError(msg) + + return f"{field} >= %s", value + + +def _less_than(field: str, value: Any) -> tuple[str, Any]: + if isinstance(value, str): + try: + datetime.fromisoformat(value) + except (ValueError, TypeError) as exc: + msg = ( + "Can't compare strings using operators '>', '>=', '<', '<='. " + "Strings are only comparable if they are ISO formatted dates." + ) + raise FilterError(msg) from exc + if type(value) in [list, Jsonb]: + msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" + raise FilterError(msg) + + return f"{field} < %s", value + + +def _less_than_equal(field: str, value: Any) -> tuple[str, Any]: + if isinstance(value, str): + try: + datetime.fromisoformat(value) + except (ValueError, TypeError) as exc: + msg = ( + "Can't compare strings using operators '>', '>=', '<', '<='. " + "Strings are only comparable if they are ISO formatted dates." 
+ ) + raise FilterError(msg) from exc + if type(value) in [list, Jsonb]: + msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" + raise FilterError(msg) + + return f"{field} <= %s", value + + +def _not_in(field: str, value: Any) -> tuple[str, List]: + if not isinstance(value, list): + msg = f"{field}'s value must be a list when using 'not in' comparator in Pgvector" + raise FilterError(msg) + + return f"{field} IS NULL OR {field} != ALL(%s)", [value] + + +def _in(field: str, value: Any) -> tuple[str, List]: + if not isinstance(value, list): + msg = f"{field}'s value must be a list when using 'in' comparator in Pgvector" + raise FilterError(msg) + + # see https://www.psycopg.org/psycopg3/docs/basic/adapt.html#lists-adaptation + return f"{field} = ANY(%s)", [value] + + +COMPARISON_OPERATORS = { + "==": _equal, + "!=": _not_equal, + ">": _greater_than, + ">=": _greater_than_equal, + "<": _less_than, + "<=": _less_than_equal, + "in": _in, + "not in": _not_in, +} diff --git a/integrations/pgvector/tests/conftest.py b/integrations/pgvector/tests/conftest.py new file mode 100644 index 000000000..743e8de14 --- /dev/null +++ b/integrations/pgvector/tests/conftest.py @@ -0,0 +1,24 @@ +import pytest +from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore + + +@pytest.fixture +def document_store(request): + connection_string = "postgresql://postgres:postgres@localhost:5432/postgres" + table_name = f"haystack_{request.node.name}" + embedding_dimension = 768 + vector_function = "cosine_similarity" + recreate_table = True + search_strategy = "exact_nearest_neighbor" + + store = PgvectorDocumentStore( + connection_string=connection_string, + table_name=table_name, + embedding_dimension=embedding_dimension, + vector_function=vector_function, + recreate_table=recreate_table, + search_strategy=search_strategy, + ) + yield store + + store.delete_table() diff --git a/integrations/pgvector/tests/test_document_store.py b/integrations/pgvector/tests/test_document_store.py index 9f3521838..e8d9107d7 100644 --- a/integrations/pgvector/tests/test_document_store.py +++ b/integrations/pgvector/tests/test_document_store.py @@ -14,27 +14,6 @@ class TestDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest): - @pytest.fixture - def document_store(self, request): - connection_string = "postgresql://postgres:postgres@localhost:5432/postgres" - table_name = f"haystack_{request.node.name}" - embedding_dimension = 768 - vector_function = "cosine_distance" - recreate_table = True - search_strategy = "exact_nearest_neighbor" - - store = PgvectorDocumentStore( - connection_string=connection_string, - table_name=table_name, - embedding_dimension=embedding_dimension, - vector_function=vector_function, - recreate_table=recreate_table, - search_strategy=search_strategy, - ) - yield store - - store.delete_table() - def test_write_documents(self, document_store: PgvectorDocumentStore): docs = [Document(id="1")] assert document_store.write_documents(docs) == 1 diff --git a/integrations/pgvector/tests/test_embedding_retrieval.py b/integrations/pgvector/tests/test_embedding_retrieval.py new file mode 100644 index 000000000..1d5e8e297 --- /dev/null +++ b/integrations/pgvector/tests/test_embedding_retrieval.py @@ -0,0 +1,130 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import List + +import pytest +from haystack.dataclasses.document import Document +from
haystack_integrations.document_stores.pgvector import PgvectorDocumentStore +from numpy.random import rand + + +class TestEmbeddingRetrieval: + @pytest.fixture + def document_store_w_hnsw_index(self, request): + connection_string = "postgresql://postgres:postgres@localhost:5432/postgres" + table_name = f"haystack_hnsw_{request.node.name}" + embedding_dimension = 768 + vector_function = "cosine_similarity" + recreate_table = True + search_strategy = "hnsw" + + store = PgvectorDocumentStore( + connection_string=connection_string, + table_name=table_name, + embedding_dimension=embedding_dimension, + vector_function=vector_function, + recreate_table=recreate_table, + search_strategy=search_strategy, + ) + yield store + + store.delete_table() + + @pytest.mark.parametrize("document_store", ["document_store", "document_store_w_hnsw_index"], indirect=True) + def test_embedding_retrieval_cosine_similarity(self, document_store: PgvectorDocumentStore): + query_embedding = [0.1] * 768 + most_similar_embedding = [0.8] * 768 + second_best_embedding = [0.8] * 700 + [0.1] * 3 + [0.2] * 65 + another_embedding = rand(768).tolist() + + docs = [ + Document(content="Most similar document (cosine sim)", embedding=most_similar_embedding), + Document(content="2nd best document (cosine sim)", embedding=second_best_embedding), + Document(content="Not very similar document (cosine sim)", embedding=another_embedding), + ] + + document_store.write_documents(docs) + + results = document_store._embedding_retrieval( + query_embedding=query_embedding, top_k=2, filters={}, vector_function="cosine_similarity" + ) + assert len(results) == 2 + assert results[0].content == "Most similar document (cosine sim)" + assert results[1].content == "2nd best document (cosine sim)" + assert results[0].score > results[1].score + + @pytest.mark.parametrize("document_store", ["document_store", "document_store_w_hnsw_index"], indirect=True) + def test_embedding_retrieval_inner_product(self, document_store: PgvectorDocumentStore): + query_embedding = [0.1] * 768 + most_similar_embedding = [0.8] * 768 + second_best_embedding = [0.8] * 700 + [0.1] * 3 + [0.2] * 65 + another_embedding = rand(768).tolist() + + docs = [ + Document(content="Most similar document (inner product)", embedding=most_similar_embedding), + Document(content="2nd best document (inner product)", embedding=second_best_embedding), + Document(content="Not very similar document (inner product)", embedding=another_embedding), + ] + + document_store.write_documents(docs) + + results = document_store._embedding_retrieval( + query_embedding=query_embedding, top_k=2, filters={}, vector_function="inner_product" + ) + assert len(results) == 2 + assert results[0].content == "Most similar document (inner product)" + assert results[1].content == "2nd best document (inner product)" + assert results[0].score > results[1].score + + @pytest.mark.parametrize("document_store", ["document_store", "document_store_w_hnsw_index"], indirect=True) + def test_embedding_retrieval_l2_distance(self, document_store: PgvectorDocumentStore): + query_embedding = [0.1] * 768 + most_similar_embedding = [0.1] * 765 + [0.15] * 3 + second_best_embedding = [0.1] * 700 + [0.1] * 3 + [0.2] * 65 + another_embedding = rand(768).tolist() + + docs = [ + Document(content="Most similar document (l2 dist)", embedding=most_similar_embedding), + Document(content="2nd best document (l2 dist)", embedding=second_best_embedding), + Document(content="Not very similar document (l2 dist)", embedding=another_embedding), + ] + + 
document_store.write_documents(docs) + + results = document_store._embedding_retrieval( + query_embedding=query_embedding, top_k=2, filters={}, vector_function="l2_distance" + ) + assert len(results) == 2 + assert results[0].content == "Most similar document (l2 dist)" + assert results[1].content == "2nd best document (l2 dist)" + assert results[0].score < results[1].score + + @pytest.mark.parametrize("document_store", ["document_store", "document_store_w_hnsw_index"], indirect=True) + def test_embedding_retrieval_with_filters(self, document_store: PgvectorDocumentStore): + docs = [Document(content=f"Document {i}", embedding=rand(768).tolist()) for i in range(10)] + + for i in range(10): + docs[i].meta["meta_field"] = "custom_value" if i % 2 == 0 else "other_value" + + document_store.write_documents(docs) + + query_embedding = [0.1] * 768 + filters = {"field": "meta.meta_field", "operator": "==", "value": "custom_value"} + + results = document_store._embedding_retrieval(query_embedding=query_embedding, top_k=3, filters=filters) + assert len(results) == 3 + for result in results: + assert result.meta["meta_field"] == "custom_value" + assert results[0].score > results[1].score > results[2].score + + def test_empty_query_embedding(self, document_store: PgvectorDocumentStore): + query_embedding: List[float] = [] + with pytest.raises(ValueError): + document_store._embedding_retrieval(query_embedding=query_embedding) + + def test_query_embedding_wrong_dimension(self, document_store: PgvectorDocumentStore): + query_embedding = [0.1] * 4 + with pytest.raises(ValueError): + document_store._embedding_retrieval(query_embedding=query_embedding) diff --git a/integrations/pgvector/tests/test_filters.py b/integrations/pgvector/tests/test_filters.py new file mode 100644 index 000000000..8b2dc8ec9 --- /dev/null +++ b/integrations/pgvector/tests/test_filters.py @@ -0,0 +1,179 @@ +from typing import List + +import pytest +from haystack.dataclasses.document import Document +from haystack.testing.document_store import FilterDocumentsTest +from haystack_integrations.document_stores.pgvector.filters import ( + FilterError, + _convert_filters_to_where_clause_and_params, + _parse_comparison_condition, + _parse_logical_condition, + _treat_meta_field, +) +from pandas import DataFrame +from psycopg.sql import SQL +from psycopg.types.json import Jsonb + + +class TestFilters(FilterDocumentsTest): + def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): + """ + This overrides the default assert_documents_are_equal from FilterDocumentsTest. + It is needed because the embeddings are not exactly the same when they are retrieved from Postgres. 
+ """ + + assert len(received) == len(expected) + received.sort(key=lambda x: x.id) + expected.sort(key=lambda x: x.id) + for received_doc, expected_doc in zip(received, expected): + # we first compare the embeddings approximately + if received_doc.embedding is None: + assert expected_doc.embedding is None + else: + assert received_doc.embedding == pytest.approx(expected_doc.embedding) + + received_doc.embedding, expected_doc.embedding = None, None + assert received_doc == expected_doc + + def test_complex_filter(self, document_store, filterable_docs): + document_store.write_documents(filterable_docs) + filters = { + "operator": "OR", + "conditions": [ + { + "operator": "AND", + "conditions": [ + {"field": "meta.number", "operator": "==", "value": 100}, + {"field": "meta.chapter", "operator": "==", "value": "intro"}, + ], + }, + { + "operator": "AND", + "conditions": [ + {"field": "meta.page", "operator": "==", "value": "90"}, + {"field": "meta.chapter", "operator": "==", "value": "conclusion"}, + ], + }, + ], + } + + result = document_store.filter_documents(filters=filters) + + self.assert_documents_are_equal( + result, + [ + d + for d in filterable_docs + if (d.meta.get("number") == 100 and d.meta.get("chapter") == "intro") + or (d.meta.get("page") == "90" and d.meta.get("chapter") == "conclusion") + ], + ) + + @pytest.mark.skip(reason="NOT operator is not supported in PgvectorDocumentStore") + def test_not_operator(self, document_store, filterable_docs): ... + + def test_treat_meta_field(self): + assert _treat_meta_field(field="meta.number", value=9) == "(meta->>'number')::integer" + assert _treat_meta_field(field="meta.number", value=[1, 2, 3]) == "(meta->>'number')::integer" + assert _treat_meta_field(field="meta.name", value="my_name") == "meta->>'name'" + assert _treat_meta_field(field="meta.name", value=["my_name"]) == "meta->>'name'" + assert _treat_meta_field(field="meta.number", value=1.1) == "(meta->>'number')::real" + assert _treat_meta_field(field="meta.number", value=[1.1, 2.2, 3.3]) == "(meta->>'number')::real" + assert _treat_meta_field(field="meta.bool", value=True) == "(meta->>'bool')::boolean" + assert _treat_meta_field(field="meta.bool", value=[True, False, True]) == "(meta->>'bool')::boolean" + + # do not cast the field if its value is not one of the known types, an empty list or None + assert _treat_meta_field(field="meta.other", value={"a": 3, "b": "example"}) == "meta->>'other'" + assert _treat_meta_field(field="meta.empty_list", value=[]) == "meta->>'empty_list'" + assert _treat_meta_field(field="meta.name", value=None) == "meta->>'name'" + + def test_comparison_condition_dataframe_jsonb_conversion(self): + dataframe = DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}) + condition = {"field": "meta.df", "operator": "==", "value": dataframe} + field, values = _parse_comparison_condition(condition) + assert field == "(meta.df)::jsonb = %s" + + # we check each slot of the Jsonb object because it does not implement __eq__ + assert values[0].obj == Jsonb(dataframe.to_json()).obj + assert values[0].dumps == Jsonb(dataframe.to_json()).dumps + + def test_comparison_condition_missing_operator(self): + condition = {"field": "meta.type", "value": "article"} + with pytest.raises(FilterError): + _parse_comparison_condition(condition) + + def test_comparison_condition_missing_value(self): + condition = {"field": "meta.type", "operator": "=="} + with pytest.raises(FilterError): + _parse_comparison_condition(condition) + + def test_comparison_condition_unknown_operator(self): + 
condition = {"field": "meta.type", "operator": "unknown", "value": "article"} + with pytest.raises(FilterError): + _parse_comparison_condition(condition) + + def test_logical_condition_missing_operator(self): + condition = {"conditions": []} + with pytest.raises(FilterError): + _parse_logical_condition(condition) + + def test_logical_condition_missing_conditions(self): + condition = {"operator": "AND"} + with pytest.raises(FilterError): + _parse_logical_condition(condition) + + def test_logical_condition_unknown_operator(self): + condition = {"operator": "unknown", "conditions": []} + with pytest.raises(FilterError): + _parse_logical_condition(condition) + + def test_logical_condition_nested(self): + condition = { + "operator": "AND", + "conditions": [ + { + "operator": "OR", + "conditions": [ + {"field": "meta.domain", "operator": "!=", "value": "science"}, + {"field": "meta.chapter", "operator": "in", "value": ["intro", "conclusion"]}, + ], + }, + { + "operator": "OR", + "conditions": [ + {"field": "meta.number", "operator": ">=", "value": 90}, + {"field": "meta.author", "operator": "not in", "value": ["John", "Jane"]}, + ], + }, + ], + } + query, values = _parse_logical_condition(condition) + assert query == ( + "((meta->>'domain' IS DISTINCT FROM %s OR meta->>'chapter' = ANY(%s)) " + "AND ((meta->>'number')::integer >= %s OR meta->>'author' IS NULL OR meta->>'author' != ALL(%s)))" + ) + assert values == ["science", [["intro", "conclusion"]], 90, [["John", "Jane"]]] + + def test_convert_filters_to_where_clause_and_params(self): + filters = { + "operator": "AND", + "conditions": [ + {"field": "meta.number", "operator": "==", "value": 100}, + {"field": "meta.chapter", "operator": "==", "value": "intro"}, + ], + } + where_clause, params = _convert_filters_to_where_clause_and_params(filters) + assert where_clause == SQL(" WHERE ") + SQL("((meta->>'number')::integer = %s AND meta->>'chapter' = %s)") + assert params == (100, "intro") + + def test_convert_filters_to_where_clause_and_params_handle_null(self): + filters = { + "operator": "AND", + "conditions": [ + {"field": "meta.number", "operator": "==", "value": None}, + {"field": "meta.chapter", "operator": "==", "value": "intro"}, + ], + } + where_clause, params = _convert_filters_to_where_clause_and_params(filters) + assert where_clause == SQL(" WHERE ") + SQL("(meta->>'number' IS NULL AND meta->>'chapter' = %s)") + assert params == ("intro",) diff --git a/integrations/pgvector/tests/test_retriever.py b/integrations/pgvector/tests/test_retriever.py new file mode 100644 index 000000000..cca6bbc9f --- /dev/null +++ b/integrations/pgvector/tests/test_retriever.py @@ -0,0 +1,112 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from unittest.mock import Mock + +from haystack.dataclasses import Document +from haystack_integrations.components.retrievers.pgvector import PgvectorEmbeddingRetriever +from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore + + +class TestRetriever: + def test_init_default(self, document_store: PgvectorDocumentStore): + retriever = PgvectorEmbeddingRetriever(document_store=document_store) + assert retriever.document_store == document_store + assert retriever.filters == {} + assert retriever.top_k == 10 + assert retriever.vector_function == document_store.vector_function + + def test_init(self, document_store: PgvectorDocumentStore): + retriever = PgvectorEmbeddingRetriever( + document_store=document_store, filters={"field": "value"}, top_k=5, 
vector_function="l2_distance" + ) + assert retriever.document_store == document_store + assert retriever.filters == {"field": "value"} + assert retriever.top_k == 5 + assert retriever.vector_function == "l2_distance" + + def test_to_dict(self, document_store: PgvectorDocumentStore): + retriever = PgvectorEmbeddingRetriever( + document_store=document_store, filters={"field": "value"}, top_k=5, vector_function="l2_distance" + ) + res = retriever.to_dict() + t = "haystack_integrations.components.retrievers.pgvector.embedding_retriever.PgvectorEmbeddingRetriever" + assert res == { + "type": t, + "init_parameters": { + "document_store": { + "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", + "init_parameters": { + "connection_string": "postgresql://postgres:postgres@localhost:5432/postgres", + "table_name": "haystack_test_to_dict", + "embedding_dimension": 768, + "vector_function": "cosine_similarity", + "recreate_table": True, + "search_strategy": "exact_nearest_neighbor", + "hnsw_recreate_index_if_exists": False, + "hnsw_index_creation_kwargs": {}, + "hnsw_ef_search": None, + }, + }, + "filters": {"field": "value"}, + "top_k": 5, + "vector_function": "l2_distance", + }, + } + + def test_from_dict(self): + t = "haystack_integrations.components.retrievers.pgvector.embedding_retriever.PgvectorEmbeddingRetriever" + data = { + "type": t, + "init_parameters": { + "document_store": { + "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", + "init_parameters": { + "connection_string": "postgresql://postgres:postgres@localhost:5432/postgres", + "table_name": "haystack_test_to_dict", + "embedding_dimension": 768, + "vector_function": "cosine_similarity", + "recreate_table": True, + "search_strategy": "exact_nearest_neighbor", + "hnsw_recreate_index_if_exists": False, + "hnsw_index_creation_kwargs": {}, + "hnsw_ef_search": None, + }, + }, + "filters": {"field": "value"}, + "top_k": 5, + "vector_function": "l2_distance", + }, + } + + retriever = PgvectorEmbeddingRetriever.from_dict(data) + document_store = retriever.document_store + + assert isinstance(document_store, PgvectorDocumentStore) + assert document_store.connection_string == "postgresql://postgres:postgres@localhost:5432/postgres" + assert document_store.table_name == "haystack_test_to_dict" + assert document_store.embedding_dimension == 768 + assert document_store.vector_function == "cosine_similarity" + assert document_store.recreate_table + assert document_store.search_strategy == "exact_nearest_neighbor" + assert not document_store.hnsw_recreate_index_if_exists + assert document_store.hnsw_index_creation_kwargs == {} + assert document_store.hnsw_ef_search is None + + assert retriever.filters == {"field": "value"} + assert retriever.top_k == 5 + assert retriever.vector_function == "l2_distance" + + def test_run(self): + mock_store = Mock(spec=PgvectorDocumentStore) + doc = Document(content="Test doc", embedding=[0.1, 0.2]) + mock_store._embedding_retrieval.return_value = [doc] + + retriever = PgvectorEmbeddingRetriever(document_store=mock_store, vector_function="l2_distance") + res = retriever.run(query_embedding=[0.3, 0.5]) + + mock_store._embedding_retrieval.assert_called_once_with( + query_embedding=[0.3, 0.5], filters={}, top_k=10, vector_function="l2_distance" + ) + + assert res == {"documents": [doc]} diff --git a/integrations/pinecone/pyproject.toml b/integrations/pinecone/pyproject.toml index 2d73cdf58..c95ee0aac 100644 --- 
a/integrations/pinecone/pyproject.toml +++ b/integrations/pinecone/pyproject.toml @@ -54,8 +54,8 @@ dependencies = [ [tool.hatch.envs.default.scripts] # Pinecone tests are slow (require HTTP requests), so we run them in parallel # with pytest-xdist (https://pytest-xdist.readthedocs.io/en/stable/distribution.html) -test = "pytest -n auto --maxprocesses=3 {args:tests}" -test-cov = "coverage run -m pytest -n auto --maxprocesses=3 {args:tests}" +test = "pytest -n auto --maxprocesses=2 {args:tests}" +test-cov = "coverage run -m pytest -n auto --maxprocesses=2 {args:tests}" cov-report = [ "- coverage combine", "coverage report", diff --git a/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py b/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py index a755b7e47..92ea987b4 100644 --- a/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py +++ b/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py @@ -85,7 +85,7 @@ def __init__( ) self.dimension = actual_dimension or dimension - self._dummy_vector = [0.0] * self.dimension + self._dummy_vector = [-10.0] * self.dimension self.environment = environment self.index = index self.namespace = namespace diff --git a/integrations/uptrain/pyproject.toml b/integrations/uptrain/pyproject.toml index 631b7dab8..498772313 100644 --- a/integrations/uptrain/pyproject.toml +++ b/integrations/uptrain/pyproject.toml @@ -34,11 +34,11 @@ packages = ["src/haystack_integrations"] [tool.hatch.version] source = "vcs" -tag-pattern = 'integrations\/uptrain(?P<version>.*)' +tag-pattern = 'integrations\/uptrain-v(?P<version>.*)' [tool.hatch.version.raw-options] root = "../.." -git_describe_command = 'git describe --tags --match="integrations/uptrain[0-9]*"' +git_describe_command = 'git describe --tags --match="integrations/uptrain-v[0-9]*"' [tool.hatch.envs.default] dependencies = ["coverage[toml]>=6.5", "pytest"]
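To illustrate what the corrected uptrain tag pattern implies, a small sketch (the tag name below is hypothetical, and the named group is assumed to be `version`, which is the group hatch-vcs reads):

import re

# release tags for this package must now carry a "-v" prefix before the version number
pattern = r"integrations\/uptrain-v(?P<version>.*)"
match = re.match(pattern, "integrations/uptrain-v0.2.0")
print(match.group("version"))  # -> 0.2.0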