diff --git a/.github/ISSUE_TEMPLATE/new-integration-proposal copy.md b/.github/ISSUE_TEMPLATE/breaking-change-proposal.md similarity index 55% rename from .github/ISSUE_TEMPLATE/new-integration-proposal copy.md rename to .github/ISSUE_TEMPLATE/breaking-change-proposal.md index 8ead69467..6c6fb9017 100644 --- a/.github/ISSUE_TEMPLATE/new-integration-proposal copy.md +++ b/.github/ISSUE_TEMPLATE/breaking-change-proposal.md @@ -13,8 +13,14 @@ Briefly explain how the change is breaking and why is needed. ## Checklist -- [ ] The change is documented with docstrings and was merged in the `main` branch -- [ ] Integration tile on https://github.com/deepset-ai/haystack-integrations was updated +```[tasklist] +### Tasks +- [ ] The changes are merged in the `main` branch (Code + Docstrings) +- [ ] New package version declares the breaking change +- [ ] The package has been released on PyPI - [ ] Docs at https://docs.haystack.deepset.ai/ were updated +- [ ] Integration tile on https://github.com/deepset-ai/haystack-integrations was updated - [ ] Notebooks on https://github.com/deepset-ai/haystack-cookbook were updated (if needed) -- [ ] New package version declares the breaking change and package has been released on PyPI +- [ ] Tutorials on https://github.com/deepset-ai/haystack-tutorials were updated (if needed) +- [ ] Articles on https://github.com/deepset-ai/haystack-home/tree/main/content were updated (if needed) +``` diff --git a/.github/ISSUE_TEMPLATE/new-integration-proposal.md b/.github/ISSUE_TEMPLATE/new-integration-proposal.md index a40388eef..60e88c555 100644 --- a/.github/ISSUE_TEMPLATE/new-integration-proposal.md +++ b/.github/ISSUE_TEMPLATE/new-integration-proposal.md @@ -20,7 +20,8 @@ Also, if there's any new terminology involved, define it here. ## Checklist If the request is accepted, ensure the following checklist is complete before closing this issue. 
- +```[tasklist] +### Tasks - [ ] The code is documented with docstrings and was merged in the `main` branch - [ ] Docs are published at https://docs.haystack.deepset.ai/ - [ ] There is a Github workflow running the tests for the integration nightly and at every PR @@ -31,3 +32,4 @@ If the request is accepted, ensure the following checklist is complete before cl - [ ] The integration has been listed in the [Inventory section](https://github.com/deepset-ai/haystack-core-integrations#inventory) of this repo README - [ ] There is an example available to demonstrate the feature - [ ] The feature was announced through social media +``` \ No newline at end of file diff --git a/.github/labeler.yml b/.github/labeler.yml index 93eba1d82..4d060772c 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -4,6 +4,11 @@ integration:amazon-bedrock: - any-glob-to-any-file: "integrations/amazon_bedrock/**/*" - any-glob-to-any-file: ".github/workflows/amazon_bedrock.yml" +integration:amazon-sagemaker: + - changed-files: + - any-glob-to-any-file: "integrations/amazon_sagemaker/**/*" + - any-glob-to-any-file: ".github/workflows/amazon_sagemaker.yml" + integration:astra: - changed-files: - any-glob-to-any-file: "integrations/astra/**/*" @@ -64,6 +69,11 @@ integration:opensearch: - any-glob-to-any-file: "integrations/opensearch/**/*" - any-glob-to-any-file: ".github/workflows/opensearch.yml" +integration:pgvector: + - changed-files: + - any-glob-to-any-file: "integrations/pgvector/**/*" + - any-glob-to-any-file: ".github/workflows/pgvector.yml" + integration:pinecone: - changed-files: - any-glob-to-any-file: "integrations/pinecone/**/*" @@ -76,8 +86,13 @@ integration:qdrant: integration:unstructured-fileconverter: - changed-files: - - any-glob-to-any-file: "integrations/unstructured/fileconverter/**/*" - - any-glob-to-any-file: ".github/workflows/unstructured_fileconverter.yml" + - any-glob-to-any-file: "integrations/unstructured/**/*" + - any-glob-to-any-file: 
".github/workflows/unstructured.yml" + +integration:uptrain: + - changed-files: + - any-glob-to-any-file: "integrations/uptrain/**/*" + - any-glob-to-any-file: ".github/workflows/uptrain.yml" integration:weaviate: - changed-files: diff --git a/.github/workflows/amazon_sagemaker.yml b/.github/workflows/amazon_sagemaker.yml new file mode 100644 index 000000000..88f397c85 --- /dev/null +++ b/.github/workflows/amazon_sagemaker.yml @@ -0,0 +1,56 @@ +# This workflow comes from https://github.com/ofek/hatch-mypyc +# https://github.com/ofek/hatch-mypyc/blob/5a198c0ba8660494d02716cfc9d79ce4adfb1442/.github/workflows/test.yml +name: Test / amazon-sagemaker + +on: + schedule: + - cron: "0 0 * * *" + pull_request: + paths: + - "integrations/amazon_sagemaker/**" + - ".github/workflows/amazon_sagemaker.yml" + +defaults: + run: + working-directory: integrations/amazon_sagemaker + +concurrency: + group: amazon-sagemaker-${{ github.head_ref }} + cancel-in-progress: true + +env: + PYTHONUNBUFFERED: "1" + FORCE_COLOR: "1" + +jobs: + run: + name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + python-version: ["3.9", "3.10"] + + steps: + - name: Support longpaths + if: matrix.os == 'windows-latest' + working-directory: . 
+ run: git config --system core.longpaths true + + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Hatch + run: pip install --upgrade hatch + + - name: Lint + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run lint:all + + - name: Run tests + run: hatch run cov diff --git a/.github/workflows/astra.yml b/.github/workflows/astra.yml index b751550de..a1aab7154 100644 --- a/.github/workflows/astra.yml +++ b/.github/workflows/astra.yml @@ -53,6 +53,10 @@ jobs: if: matrix.python-version == '3.9' && runner.os == 'Linux' run: hatch run lint:all + - name: Generate docs + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run docs + - name: Run tests env: ASTRA_DB_APPLICATION_TOKEN: ${{ secrets.ASTRA_DB_APPLICATION_TOKEN }} diff --git a/.github/workflows/chroma.yml b/.github/workflows/chroma.yml index 89b6a5b24..b7f158cfe 100644 --- a/.github/workflows/chroma.yml +++ b/.github/workflows/chroma.yml @@ -52,5 +52,9 @@ jobs: if: matrix.python-version == '3.9' && runner.os == 'Linux' run: hatch run lint:all + - name: Generate docs + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run docs + - name: Run tests - run: hatch run cov \ No newline at end of file + run: hatch run cov diff --git a/.github/workflows/cohere.yml b/.github/workflows/cohere.yml index 0f0030ec1..562556e47 100644 --- a/.github/workflows/cohere.yml +++ b/.github/workflows/cohere.yml @@ -52,5 +52,9 @@ jobs: if: matrix.python-version == '3.9' && runner.os == 'Linux' run: hatch run lint:all + - name: Generate docs + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run docs + - name: Run tests run: hatch run cov \ No newline at end of file diff --git a/.github/workflows/elasticsearch.yml b/.github/workflows/elasticsearch.yml index eb2c1748d..688e5c48f 100644 --- 
a/.github/workflows/elasticsearch.yml +++ b/.github/workflows/elasticsearch.yml @@ -10,6 +10,10 @@ on: - "integrations/elasticsearch/**" - ".github/workflows/elasticsearch.yml" +defaults: + run: + working-directory: integrations/elasticsearch + concurrency: group: elasticsearch-${{ github.head_ref }} cancel-in-progress: true @@ -40,14 +44,15 @@ jobs: run: pip install --upgrade hatch - name: Lint - working-directory: integrations/elasticsearch if: matrix.python-version == '3.9' run: hatch run lint:all - name: Run ElasticSearch container - working-directory: integrations/elasticsearch run: docker-compose up -d + - name: Generate docs + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run docs + - name: Run tests - working-directory: integrations/elasticsearch run: hatch run cov diff --git a/.github/workflows/ollama.yml b/.github/workflows/ollama.yml index 7f61af14e..a375fc7db 100644 --- a/.github/workflows/ollama.yml +++ b/.github/workflows/ollama.yml @@ -54,6 +54,11 @@ jobs: - name: Pull the LLM in the Ollama service run: docker exec ollama ollama pull ${{ env.LLM_FOR_TESTS }} + - name: Generate docs + working-directory: integrations/ollama + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run docs + - name: Run tests working-directory: integrations/ollama run: hatch run cov diff --git a/.github/workflows/opensearch.yml b/.github/workflows/opensearch.yml index aacb4ce71..72a01d090 100644 --- a/.github/workflows/opensearch.yml +++ b/.github/workflows/opensearch.yml @@ -18,6 +18,10 @@ env: PYTHONUNBUFFERED: "1" FORCE_COLOR: "1" +defaults: + run: + working-directory: integrations/opensearch + jobs: run: name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} @@ -40,14 +44,16 @@ jobs: run: pip install --upgrade hatch - name: Lint - working-directory: integrations/opensearch if: matrix.python-version == '3.9' run: hatch run 
lint:all - name: Run opensearch container - working-directory: integrations/opensearch run: docker-compose up -d + - name: Generate docs + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run docs + - name: Run tests working-directory: integrations/opensearch run: hatch run cov diff --git a/.github/workflows/pgvector.yml b/.github/workflows/pgvector.yml new file mode 100644 index 000000000..badb2565b --- /dev/null +++ b/.github/workflows/pgvector.yml @@ -0,0 +1,64 @@ +# This workflow comes from https://github.com/ofek/hatch-mypyc +# https://github.com/ofek/hatch-mypyc/blob/5a198c0ba8660494d02716cfc9d79ce4adfb1442/.github/workflows/test.yml +name: Test / pgvector + +on: + schedule: + - cron: "0 0 * * *" + pull_request: + paths: + - "integrations/pgvector/**" + - ".github/workflows/pgvector.yml" + +concurrency: + group: pgvector-${{ github.head_ref }} + cancel-in-progress: true + +env: + PYTHONUNBUFFERED: "1" + FORCE_COLOR: "1" + +defaults: + run: + working-directory: integrations/pgvector + +jobs: + run: + name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python-version: ["3.9","3.10","3.11"] + services: + pgvector: + image: ankane/pgvector:latest + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: postgres + ports: + - 5432:5432 + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Hatch + run: pip install --upgrade hatch + + - name: Lint + if: matrix.python-version == '3.9' + run: hatch run lint:all + + - name: Generate docs + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run docs + + - name: Run tests + run: hatch run cov diff --git 
a/.github/workflows/uptrain.yml b/.github/workflows/uptrain.yml new file mode 100644 index 000000000..bacfa27fb --- /dev/null +++ b/.github/workflows/uptrain.yml @@ -0,0 +1,56 @@ +# This workflow comes from https://github.com/ofek/hatch-mypyc +# https://github.com/ofek/hatch-mypyc/blob/5a198c0ba8660494d02716cfc9d79ce4adfb1442/.github/workflows/test.yml +name: Test / uptrain + +on: + schedule: + - cron: "0 0 * * *" + pull_request: + paths: + - "integrations/uptrain/**" + - ".github/workflows/uptrain.yml" + +defaults: + run: + working-directory: integrations/uptrain + +concurrency: + group: uptrain-${{ github.head_ref }} + cancel-in-progress: true + +env: + PYTHONUNBUFFERED: "1" + FORCE_COLOR: "1" + +jobs: + run: + name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + python-version: ["3.9", "3.10"] + + steps: + - name: Support longpaths + if: matrix.os == 'windows-latest' + working-directory: . 
+ run: git config --system core.longpaths true + + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Hatch + run: pip install --upgrade hatch + + - name: Lint + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run lint:all + + - name: Run tests + run: hatch run cov diff --git a/.github/workflows/weaviate.yml b/.github/workflows/weaviate.yml index c638773f0..03cbd45a5 100644 --- a/.github/workflows/weaviate.yml +++ b/.github/workflows/weaviate.yml @@ -29,15 +29,10 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-latest, windows-latest, macos-latest] + os: [ubuntu-latest] python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] steps: - - name: Support longpaths - if: matrix.os == 'windows-latest' - working-directory: . - run: git config --system core.longpaths true - - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} @@ -49,8 +44,10 @@ jobs: run: pip install --upgrade hatch - name: Lint - if: runner.os == 'Linux' run: hatch run lint:all + - name: Run Weaviate container + run: docker-compose up -d + - name: Run tests run: hatch run cov diff --git a/.gitignore b/.gitignore index 1815e02f8..8634bc259 100644 --- a/.gitignore +++ b/.gitignore @@ -131,3 +131,6 @@ dmypy.json # IDEs .vscode + +# Docs generation artifacts +_readme_*.md diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 000000000..2b2d0bf2f --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,128 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal 
appearance, race, religion, or sexual identity +and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the + overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or + advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email + address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. 
+Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +haystack@deepset.ai. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series +of actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or +permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. 
No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within +the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.0, available at +https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by [Mozilla's code of conduct +enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the FAQ at +https://www.contributor-covenant.org/faq. Translations are available at +https://www.contributor-covenant.org/translations. 
diff --git a/README.md b/README.md index d4d34fd7d..39d669322 100644 --- a/README.md +++ b/README.md @@ -60,21 +60,24 @@ deepset-haystack ## Inventory -| Package | Type | PyPi Package | Status | -| ------------------------------------------------------------------------------- | ------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| [astra-haystack](integrations/astra/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/astra-haystack.svg)](https://pypi.org/project/astra-haystack) | [![Test / astra](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/astra.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/astra.yml) | -| [amazon-bedrock-haystack](integrations/amazon-bedrock/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/amazon-bedrock-haystack.svg)](https://pypi.org/project/amazon-bedrock-haystack) | [![Test / amazon_bedrock](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/amazon_bedrock.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/amazon_bedrock.yml) | -| [chroma-haystack](integrations/chroma/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/chroma-haystack.svg)](https://pypi.org/project/chroma-haystack) | [![Test / chroma](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/chroma.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/chroma.yml) | -| [cohere-haystack](integrations/cohere/) | Embedder, Generator | [![PyPI - 
Version](https://img.shields.io/pypi/v/cohere-haystack.svg)](https://pypi.org/project/cohere-haystack) | [![Test / cohere](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cohere.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cohere.yml) | -| [elasticsearch-haystack](integrations/elasticsearch/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/elasticsearch-haystack.svg)](https://pypi.org/project/elasticsearch-haystack) | [![Test / elasticsearch](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/elasticsearch.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/elasticsearch.yml) | -| [google-ai-haystack](integrations/google_ai/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/google-ai-haystack.svg)](https://pypi.org/project/google-ai-haystack) | [![Test / google-ai](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_ai.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_ai.yml) | -| [google-vertex-haystack](integrations/google_vertex/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/google-vertex-haystack.svg)](https://pypi.org/project/google-vertex-haystack) | [![Test / google-vertex](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_vertex.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_vertex.yml) | -| [gradient-haystack](integrations/gradient/) | Embedder, Generator | [![PyPI - Version](https://img.shields.io/pypi/v/gradient-haystack.svg)](https://pypi.org/project/gradient-haystack) | [![Test / gradient](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/gradient.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/gradient.yml) | -| 
[instructor-embedders-haystack](integrations/instructor_embedders/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/instructor-embedders-haystack.svg)](https://pypi.org/project/instructor-embedders-haystack) | [![Test / instructor-embedders](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml) | -| [jina-haystack](integrations/jina/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/jina-haystack.svg)](https://pypi.org/project/jina-haystack) | [![Test / jina](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml) | -| [llama-cpp-haystack](integrations/llama_cpp/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/ollama-haystack.svg?color=orange)](https://pypi.org/project/llama-cpp-haystack) | [![Test / llama-cpp](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/llama_cpp.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/llama_cpp.yml) | -| [ollama-haystack](integrations/ollama/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/ollama-haystack.svg?color=orange)](https://pypi.org/project/ollama-haystack) | [![Test / ollama](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/ollama.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/ollama.yml) | -| [opensearch-haystack](integrations/opensearch/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/opensearch-haystack.svg)](https://pypi.org/project/opensearch-haystack) | [![Test / 
opensearch](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/opensearch.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/opensearch.yml) | -| [pinecone-haystack](integrations/pinecone/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/pinecone-haystack.svg?color=orange)](https://pypi.org/project/pinecone-haystack) | [![Test / pinecone](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/pinecone.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/pinecone.yml) | -| [qdrant-haystack](integrations/qdrant/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/qdrant-haystack.svg?color=orange)](https://pypi.org/project/qdrant-haystack) | [![Test / qdrant](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/qdrant.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/qdrant.yml) | -| [unstructured-fileconverter-haystack](integrations/unstructured/) | File converter | [![PyPI - Version](https://img.shields.io/pypi/v/unstructured-fileconverter-haystack.svg)](https://pypi.org/project/unstructured-fileconverter-haystack) | [![Test / unstructured](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured.yml) | +| Package | Type | PyPi Package | Status | +| ------------------------------------------------------------------- | ------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- | 
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| [astra-haystack](integrations/astra/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/astra-haystack.svg)](https://pypi.org/project/astra-haystack) | [![Test / astra](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/astra.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/astra.yml) | +| [amazon-bedrock-haystack](integrations/amazon_bedrock/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/amazon-bedrock-haystack.svg)](https://pypi.org/project/amazon-bedrock-haystack) | [![Test / amazon_bedrock](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/amazon_bedrock.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/amazon_bedrock.yml) | +| [chroma-haystack](integrations/chroma/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/chroma-haystack.svg)](https://pypi.org/project/chroma-haystack) | [![Test / chroma](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/chroma.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/chroma.yml) | +| [cohere-haystack](integrations/cohere/) | Embedder, Generator | [![PyPI - Version](https://img.shields.io/pypi/v/cohere-haystack.svg)](https://pypi.org/project/cohere-haystack) | [![Test / cohere](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cohere.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cohere.yml) | +| [elasticsearch-haystack](integrations/elasticsearch/) | Document Store | [![PyPI - 
Version](https://img.shields.io/pypi/v/elasticsearch-haystack.svg)](https://pypi.org/project/elasticsearch-haystack) | [![Test / elasticsearch](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/elasticsearch.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/elasticsearch.yml) | +| [google-ai-haystack](integrations/google_ai/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/google-ai-haystack.svg)](https://pypi.org/project/google-ai-haystack) | [![Test / google-ai](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_ai.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_ai.yml) | +| [google-vertex-haystack](integrations/google_vertex/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/google-vertex-haystack.svg)](https://pypi.org/project/google-vertex-haystack) | [![Test / google-vertex](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_vertex.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_vertex.yml) | +| [gradient-haystack](integrations/gradient/) | Embedder, Generator | [![PyPI - Version](https://img.shields.io/pypi/v/gradient-haystack.svg)](https://pypi.org/project/gradient-haystack) | [![Test / gradient](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/gradient.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/gradient.yml) | +| [instructor-embedders-haystack](integrations/instructor_embedders/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/instructor-embedders-haystack.svg)](https://pypi.org/project/instructor-embedders-haystack) | [![Test / 
instructor-embedders](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml) | +| [jina-haystack](integrations/jina/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/jina-haystack.svg)](https://pypi.org/project/jina-haystack) | [![Test / jina](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml) | +| [llama-cpp-haystack](integrations/llama_cpp/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/llama-cpp-haystack.svg?color=orange)](https://pypi.org/project/llama-cpp-haystack) | [![Test / llama-cpp](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/llama_cpp.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/llama_cpp.yml) | +| [ollama-haystack](integrations/ollama/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/ollama-haystack.svg?color=orange)](https://pypi.org/project/ollama-haystack) | [![Test / ollama](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/ollama.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/ollama.yml) | +| [opensearch-haystack](integrations/opensearch/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/opensearch-haystack.svg)](https://pypi.org/project/opensearch-haystack) | [![Test / opensearch](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/opensearch.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/opensearch.yml) | +| [pinecone-haystack](integrations/pinecone/) | Document Store | [![PyPI - 
Version](https://img.shields.io/pypi/v/pinecone-haystack.svg?color=orange)](https://pypi.org/project/pinecone-haystack) | [![Test / pinecone](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/pinecone.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/pinecone.yml) | +| [pgvector-haystack](integrations/pgvector/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/pgvector-haystack.svg?color=orange)](https://pypi.org/project/pgvector-haystack) | [![Test / pgvector](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/pgvector.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/pgvector.yml) | +| [qdrant-haystack](integrations/qdrant/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/qdrant-haystack.svg?color=orange)](https://pypi.org/project/qdrant-haystack) | [![Test / qdrant](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/qdrant.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/qdrant.yml) | +| [unstructured-fileconverter-haystack](integrations/unstructured/) | File converter | [![PyPI - Version](https://img.shields.io/pypi/v/unstructured-fileconverter-haystack.svg)](https://pypi.org/project/unstructured-fileconverter-haystack) | [![Test / unstructured](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured.yml) | +| [uptrain-haystack](integrations/uptrain/) | Evaluator | [![PyPI - Version](https://img.shields.io/pypi/v/uptrain-haystack.svg)](https://pypi.org/project/uptrain-haystack) | [![Test / uptrain](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/uptrain.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/uptrain.yml) | +| 
[amazon-sagemaker-haystack](integrations/amazon_sagemaker/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/amazon-sagemaker-haystack.svg)](https://pypi.org/project/amazon-sagemaker-haystack) | [![Test / amazon_sagemaker](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/amazon_sagemaker.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/amazon_sagemaker.yml) | diff --git a/integrations/amazon_bedrock/README.md b/integrations/amazon_bedrock/README.md index f84c8f3c4..3a689ef3b 100644 --- a/integrations/amazon_bedrock/README.md +++ b/integrations/amazon_bedrock/README.md @@ -8,6 +8,7 @@ **Table of Contents** - [Installation](#installation) +- [Contributing](#contributing) - [License](#license) ## Installation @@ -16,6 +17,24 @@ pip install amazon-bedrock-haystack ``` +## Contributing + +`hatch` is the best way to interact with this project, to install it: +```sh +pip install hatch +``` + +With `hatch` installed, to run all the tests: +``` +hatch run test +``` +> Note: there are no integration tests for this project. + +To run the linters `ruff` and `mypy`: +``` +hatch run lint:all +``` + ## License `amazon-bedrock-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. 
diff --git a/integrations/amazon_bedrock/pyproject.toml b/integrations/amazon_bedrock/pyproject.toml index 7e82924a8..6a2ce3eab 100644 --- a/integrations/amazon_bedrock/pyproject.toml +++ b/integrations/amazon_bedrock/pyproject.toml @@ -35,6 +35,9 @@ Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/m Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/amazon_bedrock" +[tool.hatch.build.targets.wheel] +packages = ["src/haystack_integrations"] + [tool.hatch.version] source = "vcs" tag-pattern = 'integrations\/amazon_bedrock-v(?P<version>.*)' @@ -71,7 +74,8 @@ dependencies = [ "ruff>=0.0.243", ] [tool.hatch.envs.lint.scripts] -typing = "mypy --install-types --non-interactive {args:src/amazon_bedrock_haystack tests}" +typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" + style = [ "ruff {args:.}", "black --check --diff {args:.}", @@ -136,26 +140,24 @@ unfixable = [ ] [tool.ruff.isort] -known-first-party = ["amazon_bedrock_haystack"] +known-first-party = ["haystack_integrations"] [tool.ruff.flake8-tidy-imports] -ban-relative-imports = "all" +ban-relative-imports = "parents" [tool.ruff.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] [tool.coverage.run] -source_pkgs = ["amazon_bedrock_haystack", "tests"] +source_pkgs = ["src", "tests"] branch = true parallel = true -omit = [ - "src/amazon_bedrock_haystack/__about__.py", -] + [tool.coverage.paths] -amazon_bedrock_haystack = ["src/amazon_bedrock_haystack", "*/amazon_bedrock/src/amazon_bedrock_haystack"] -tests = ["tests", "*/amazon_bedrock_haystack/tests"] +amazon_bedrock_haystack = ["src/*"] +tests = ["tests"] [tool.coverage.report] exclude_lines = [ @@ -170,6 +172,7 @@ module = [ "transformers.*", "boto3.*", "haystack.*", + "haystack_integrations.*", "pytest.*", "numpy.*", 
] diff --git a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/__init__.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py similarity index 63% rename from integrations/amazon_bedrock/src/amazon_bedrock_haystack/__init__.py rename to integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py index 3e05179c0..236347b61 100644 --- a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/__init__.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from amazon_bedrock_haystack.generators.amazon_bedrock import AmazonBedrockGenerator +from .generator import AmazonBedrockGenerator __all__ = ["AmazonBedrockGenerator"] diff --git a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock_adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py similarity index 98% rename from integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock_adapters.py rename to integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py index bec172867..40ba0bc67 100644 --- a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock_adapters.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional -from amazon_bedrock_haystack.generators.amazon_bedrock_handlers import TokenStreamingHandler +from .handlers import TokenStreamingHandler class BedrockModelAdapter(ABC): diff --git a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/errors.py 
b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/errors.py similarity index 100% rename from integrations/amazon_bedrock/src/amazon_bedrock_haystack/errors.py rename to integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/errors.py diff --git a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py similarity index 98% rename from integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock.py rename to integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index dda84fe14..4c43c9a09 100644 --- a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -7,12 +7,7 @@ from botocore.exceptions import BotoCoreError, ClientError from haystack import component, default_from_dict, default_to_dict -from amazon_bedrock_haystack.errors import ( - AmazonBedrockConfigurationError, - AmazonBedrockInferenceError, - AWSConfigurationError, -) -from amazon_bedrock_haystack.generators.amazon_bedrock_adapters import ( +from .adapters import ( AI21LabsJurassic2Adapter, AmazonTitanAdapter, AnthropicClaudeAdapter, @@ -20,7 +15,12 @@ CohereCommandAdapter, MetaLlama2ChatAdapter, ) -from amazon_bedrock_haystack.generators.amazon_bedrock_handlers import ( +from .errors import ( + AmazonBedrockConfigurationError, + AmazonBedrockInferenceError, + AWSConfigurationError, +) +from .handlers import ( DefaultPromptHandler, DefaultTokenStreamingHandler, TokenStreamingHandler, diff --git a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock_handlers.py 
b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py similarity index 100% rename from integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock_handlers.py rename to integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py diff --git a/integrations/amazon_bedrock/tests/test_amazon_bedrock.py b/integrations/amazon_bedrock/tests/test_amazon_bedrock.py index a05c95ba3..b08e9dfd5 100644 --- a/integrations/amazon_bedrock/tests/test_amazon_bedrock.py +++ b/integrations/amazon_bedrock/tests/test_amazon_bedrock.py @@ -4,9 +4,8 @@ import pytest from botocore.exceptions import BotoCoreError -from amazon_bedrock_haystack.errors import AmazonBedrockConfigurationError -from amazon_bedrock_haystack.generators.amazon_bedrock import AmazonBedrockGenerator -from amazon_bedrock_haystack.generators.amazon_bedrock_adapters import ( +from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator +from haystack_integrations.components.generators.amazon_bedrock.adapters import ( AI21LabsJurassic2Adapter, AmazonTitanAdapter, AnthropicClaudeAdapter, @@ -14,6 +13,7 @@ CohereCommandAdapter, MetaLlama2ChatAdapter, ) +from haystack_integrations.components.generators.amazon_bedrock.errors import AmazonBedrockConfigurationError @pytest.fixture @@ -34,7 +34,7 @@ def mock_boto3_session(): @pytest.fixture def mock_prompt_handler(): with patch( - "amazon_bedrock_haystack.generators.amazon_bedrock_handlers.DefaultPromptHandler" + "haystack_integrations.components.generators.amazon_bedrock.handlers.DefaultPromptHandler" ) as mock_prompt_handler: yield mock_prompt_handler @@ -55,7 +55,7 @@ def test_to_dict(mock_auto_tokenizer, mock_boto3_session): ) expected_dict = { - "type": "amazon_bedrock_haystack.generators.amazon_bedrock.AmazonBedrockGenerator", + "type": 
"haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator", "init_parameters": { "model": "anthropic.claude-v2", "max_length": 99, @@ -72,7 +72,7 @@ def test_from_dict(mock_auto_tokenizer, mock_boto3_session): """ generator = AmazonBedrockGenerator.from_dict( { - "type": "amazon_bedrock_haystack.generators.amazon_bedrock.AmazonBedrockGenerator", + "type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator", "init_parameters": { "model": "anthropic.claude-v2", "max_length": 99, @@ -235,7 +235,7 @@ def test_supports_for_valid_aws_configuration(): # Patch the class method to return the mock session with patch( - "amazon_bedrock_haystack.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", + "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", return_value=mock_session, ): supported = AmazonBedrockGenerator.supports( @@ -266,7 +266,7 @@ def test_supports_for_invalid_bedrock_config(): # Patch the class method to return the mock session with patch( - "amazon_bedrock_haystack.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", + "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", return_value=mock_session, ), pytest.raises(AmazonBedrockConfigurationError, match="Could not connect to Amazon Bedrock."): AmazonBedrockGenerator.supports( @@ -282,7 +282,7 @@ def test_supports_for_invalid_bedrock_config_error_on_list_models(): # Patch the class method to return the mock session with patch( - "amazon_bedrock_haystack.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", + "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", return_value=mock_session, ), pytest.raises(AmazonBedrockConfigurationError, match="Could not connect to Amazon Bedrock."): AmazonBedrockGenerator.supports( @@ -314,7 +314,7 @@ def 
test_supports_with_stream_true_for_model_that_supports_streaming(): # Patch the class method to return the mock session with patch( - "amazon_bedrock_haystack.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", + "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", return_value=mock_session, ): supported = AmazonBedrockGenerator.supports( @@ -335,7 +335,7 @@ def test_supports_with_stream_true_for_model_that_does_not_support_streaming(): # Patch the class method to return the mock session with patch( - "amazon_bedrock_haystack.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", + "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", return_value=mock_session, ), pytest.raises( AmazonBedrockConfigurationError, diff --git a/integrations/amazon_sagemaker/LICENSE.txt b/integrations/amazon_sagemaker/LICENSE.txt new file mode 100644 index 000000000..137069b82 --- /dev/null +++ b/integrations/amazon_sagemaker/LICENSE.txt @@ -0,0 +1,73 @@ +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + +"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. 
+ +"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. + +"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. + +"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." + +"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: + + (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. + + You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + +To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
+ +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/integrations/amazon_sagemaker/README.md b/integrations/amazon_sagemaker/README.md new file mode 100644 index 000000000..1ea01871d --- /dev/null +++ b/integrations/amazon_sagemaker/README.md @@ -0,0 +1,52 @@ +# amazon-sagemaker-haystack + +[![PyPI - Version](https://img.shields.io/pypi/v/amazon-sagemaker-haystack.svg)](https://pypi.org/project/amazon-sagemaker-haystack) +[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/amazon-sagemaker-haystack.svg)](https://pypi.org/project/amazon-sagemaker-haystack) + +----- + +**Table of Contents** + +- [Installation](#installation) +- [Contributing](#contributing) +- [License](#license) + +## Installation + +```console +pip install amazon-sagemaker-haystack +``` + +## Contributing + +`hatch` is the best way to interact with this project, to install it: +```sh +pip install hatch +``` + +With `hatch` installed, to run all the tests: +``` +hatch run test +``` + +> Note: You need to export your AWS credentials for Sagemaker integration tests to run (`AWS_ACCESS_KEY_ID` and +`AWS_SECRET_ACCESS_KEY`). If those are missing, the integration tests will be skipped. 
+ +To only run unit tests: +``` +hatch run test -m "not integration" +``` + +To only run integration tests: +``` +hatch run test -m "integration" +``` + +To run the linters `ruff` and `mypy`: +``` +hatch run lint:all +``` + +## License + +`amazon-sagemaker-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. diff --git a/integrations/amazon_sagemaker/pyproject.toml b/integrations/amazon_sagemaker/pyproject.toml new file mode 100644 index 000000000..916307156 --- /dev/null +++ b/integrations/amazon_sagemaker/pyproject.toml @@ -0,0 +1,177 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +[build-system] +requires = ["hatchling", "hatch-vcs"] +build-backend = "hatchling.build" + +[project] +name = "amazon-sagemaker-haystack" +dynamic = ["version"] +description = 'An integration of Amazon Sagemaker as a SagemakerGenerator component.' +readme = "README.md" +requires-python = ">=3.8" +license = "Apache-2.0" +keywords = [] +authors = [ + { name = "deepset GmbH", email = "info@deepset.ai" }, +] +classifiers = [ + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dependencies = [ + "haystack-ai", + "boto3>=1.28.57", +] + +[project.urls] +Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/amazon_sagemaker#readme" +Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" +Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/amazon_sagemaker" 
+ +[tool.hatch.build.targets.wheel] +packages = ["src/haystack_integrations"] + +[tool.hatch.version] +source = "vcs" +tag-pattern = 'integrations\/amazon_sagemaker-v(?P<version>.*)' + +[tool.hatch.version.raw-options] +root = "../.." +git_describe_command = 'git describe --tags --match="integrations/amazon_sagemaker-v[0-9]*"' + +[tool.hatch.envs.default] +dependencies = [ + "coverage[toml]>=6.5", + "pytest", +] +[tool.hatch.envs.default.scripts] +test = "pytest {args:tests}" +test-cov = "coverage run -m pytest {args:tests}" +cov-report = [ + "- coverage combine", + "coverage report", +] +cov = [ + "test-cov", + "cov-report", +] + +[[tool.hatch.envs.all.matrix]] +python = ["3.7", "3.8", "3.9", "3.10", "3.11"] + +[tool.hatch.envs.lint] +detached = true +dependencies = [ + "black>=23.1.0", + "mypy>=1.0.0", + "ruff>=0.0.243", +] +[tool.hatch.envs.lint.scripts] +typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" +style = [ + "ruff {args:.}", + "black --check --diff {args:.}", +] +fmt = [ + "black {args:.}", + "ruff --fix {args:.}", + "style", +] +all = [ + "style", + "typing", +] + +[tool.black] +target-version = ["py37"] +line-length = 120 +skip-string-normalization = true + +[tool.ruff] +target-version = "py37" +line-length = 120 +select = [ + "A", + "ARG", + "B", + "C", + "DTZ", + "E", + "EM", + "F", + "FBT", + "I", + "ICN", + "ISC", + "N", + "PLC", + "PLE", + "PLR", + "PLW", + "Q", + "RUF", + "S", + "T", + "TID", + "UP", + "W", + "YTT", +] +ignore = [ + # Import sorting doesn't seem to work + "I001", + # Allow non-abstract empty methods in abstract base classes + "B027", + # Allow boolean positional values in function calls, like `dict.get(... 
True)` + "FBT003", + # Ignore checks for possible passwords + "S105", "S106", "S107", + # Ignore complexity + "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", +] +unfixable = [ + # Don't touch unused imports + "F401", +] + +[tool.ruff.isort] +known-first-party = ["haystack_integrations"] + +[tool.ruff.flake8-tidy-imports] +ban-relative-imports = "parents" + +[tool.ruff.per-file-ignores] +# Tests can use magic values, assertions, and relative imports +"tests/**/*" = ["PLR2004", "S101", "TID252"] + +[tool.coverage.run] +branch = true +parallel = true + +[tool.coverage.paths] +amazon_sagemaker_haystack = ["src"] +tests = ["tests"] + +[tool.coverage.report] +exclude_lines = [ + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] +[[tool.mypy.overrides]] +module = [ + "haystack.*", + "haystack_integrations.*", + "pytest.*", + "numpy.*", +] +ignore_missing_imports = true \ No newline at end of file diff --git a/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/__init__.py b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/__init__.py new file mode 100644 index 000000000..0fe45a8a1 --- /dev/null +++ b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/__init__.py @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from haystack_integrations.components.generators.amazon_sagemaker.sagemaker import SagemakerGenerator + +__all__ = ["SagemakerGenerator"] diff --git a/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/errors.py b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/errors.py new file mode 100644 index 000000000..6c13d0fcb --- /dev/null +++ b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/errors.py @@ -0,0 +1,46 @@ +from 
typing import Optional + + +class SagemakerError(Exception): + """ + Error generated by the Amazon Sagemaker integration. + """ + + def __init__( + self, + message: Optional[str] = None, + ): + super().__init__() + if message: + self.message = message + + def __getattr__(self, attr): + # If self.__cause__ is None, it will raise the expected AttributeError + getattr(self.__cause__, attr) + + def __str__(self): + return self.message + + def __repr__(self): + return str(self) + + +class AWSConfigurationError(SagemakerError): + """Exception raised when AWS is not configured correctly""" + + def __init__(self, message: Optional[str] = None): + super().__init__(message=message) + + +class SagemakerNotReadyError(SagemakerError): + """Exception for issues that occur during Sagemaker inference""" + + def __init__(self, message: Optional[str] = None): + super().__init__(message=message) + + +class SagemakerInferenceError(SagemakerError): + """Exception for issues that occur during Sagemaker inference""" + + def __init__(self, message: Optional[str] = None): + super().__init__(message=message) diff --git a/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py new file mode 100644 index 000000000..35e54a055 --- /dev/null +++ b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py @@ -0,0 +1,224 @@ +import json +import logging +import os +from typing import Any, ClassVar, Dict, List, Optional + +import requests +from haystack import component, default_from_dict, default_to_dict +from haystack.lazy_imports import LazyImport +from haystack_integrations.components.generators.amazon_sagemaker.errors import ( + AWSConfigurationError, + SagemakerInferenceError, + SagemakerNotReadyError, +) + +with LazyImport(message="Run 'pip install boto3'") as boto3_import: + import 
logger = logging.getLogger(__name__)


# HTTP status code that `run()` interprets as "the model is not ready yet".
MODEL_NOT_READY_STATUS_CODE = 429


@component
class SagemakerGenerator:
    """
    Enables text generation using Sagemaker. It supports Large Language Models (LLMs) hosted and deployed on a SageMaker
    Inference Endpoint. For guidance on how to deploy a model to SageMaker, refer to the
    [SageMaker JumpStart foundation models documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/jumpstart-foundation-models-use.html).

    **Example:**

    First export your AWS credentials as environment variables:
    ```bash
    export AWS_ACCESS_KEY_ID=<your_access_key_id>
    export AWS_SECRET_ACCESS_KEY=<your_secret_access_key>
    ```
    (Note: you may also need to set the session token and region name, depending on your AWS configuration)

    Then you can use the generator as follows:
    ```python
    from haystack_integrations.components.generators.amazon_sagemaker import SagemakerGenerator
    generator = SagemakerGenerator(model="jumpstart-dft-hf-llm-falcon-7b-instruct-bf16")
    generator.warm_up()
    response = generator.run("What's Natural Language Processing? Be brief.")
    print(response)
    ```
    ```
    >> {'replies': ['Natural Language Processing (NLP) is a branch of artificial intelligence that focuses on
    >> the interaction between computers and human language. It involves enabling computers to understand, interpret,
    >> and respond to natural human language in a way that is both meaningful and useful.'], 'meta': [{}]}
    ```
    """

    # Keys under which the generated text may be stored in the endpoint response; checked in order.
    model_generation_keys: ClassVar = ["generated_text", "generation"]

    def __init__(
        self,
        model: str,
        aws_access_key_id_var: str = "AWS_ACCESS_KEY_ID",
        aws_secret_access_key_var: str = "AWS_SECRET_ACCESS_KEY",
        aws_session_token_var: str = "AWS_SESSION_TOKEN",
        aws_region_name_var: str = "AWS_REGION",
        aws_profile_name_var: str = "AWS_PROFILE",
        aws_custom_attributes: Optional[Dict[str, Any]] = None,
        generation_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Instantiates the session with SageMaker.

        :param model: The name for SageMaker Model Endpoint.
        :param aws_access_key_id_var: The name of the env var where the AWS access key ID is stored.
        :param aws_secret_access_key_var: The name of the env var where the AWS secret access key is stored.
        :param aws_session_token_var: The name of the env var where the AWS session token is stored.
        :param aws_region_name_var: The name of the env var where the AWS region name is stored.
        :param aws_profile_name_var: The name of the env var where the AWS profile name is stored.
        :param aws_custom_attributes: Custom attributes to be passed to SageMaker, for example `{"accept_eula": True}`
            in case of Llama-2 models.
        :param generation_kwargs: Additional keyword arguments for text generation. For a list of supported parameters
            see your model's documentation page, for example here for HuggingFace models:
            https://huggingface.co/blog/sagemaker-huggingface-llm#4-run-inference-and-chat-with-our-model

            Specifically, Llama-2 models support the following inference payload parameters:

            - `max_new_tokens`: Model generates text until the output length (excluding the input context length)
                reaches `max_new_tokens`. If specified, it must be a positive integer.
            - `temperature`: Controls the randomness in the output. Higher temperature results in output sequence with
                low-probability words and lower temperature results in output sequence with high-probability words.
                If `temperature=0`, it results in greedy decoding. If specified, it must be a positive float.
            - `top_p`: In each step of text generation, sample from the smallest possible set of words with cumulative
                probability `top_p`. If specified, it must be a float between 0 and 1.
            - `return_full_text`: If `True`, input text will be part of the output generated text. If specified, it must
                be boolean. The default value for it is `False`.

        :raises AWSConfigurationError: If the access key ID or the secret access key env var is unset or empty.
        """
        self.model = model
        self.aws_access_key_id_var = aws_access_key_id_var
        self.aws_secret_access_key_var = aws_secret_access_key_var
        self.aws_session_token_var = aws_session_token_var
        self.aws_region_name_var = aws_region_name_var
        self.aws_profile_name_var = aws_profile_name_var
        self.aws_custom_attributes = aws_custom_attributes or {}
        self.generation_kwargs = generation_kwargs or {"max_new_tokens": 1024}
        # The client is created lazily by `warm_up()`.
        self.client: Optional[BaseClient] = None

        if not os.getenv(self.aws_access_key_id_var) or not os.getenv(self.aws_secret_access_key_var):
            msg = (
                f"Please provide AWS credentials via environment variables '{self.aws_access_key_id_var}' and "
                f"'{self.aws_secret_access_key_var}'."
            )
            raise AWSConfigurationError(msg)

    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        return {"model": self.model}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize the object to a dictionary.

        Only the names of the environment variables are serialized, never their values.
        """
        return default_to_dict(
            self,
            model=self.model,
            aws_access_key_id_var=self.aws_access_key_id_var,
            aws_secret_access_key_var=self.aws_secret_access_key_var,
            aws_session_token_var=self.aws_session_token_var,
            aws_region_name_var=self.aws_region_name_var,
            aws_profile_name_var=self.aws_profile_name_var,
            aws_custom_attributes=self.aws_custom_attributes,
            generation_kwargs=self.generation_kwargs,
        )

    @classmethod
    def from_dict(cls, data) -> "SagemakerGenerator":
        """
        Deserialize the dictionary into an instance of SagemakerGenerator.
        """
        return default_from_dict(cls, data)

    def warm_up(self):
        """
        Initializes the SageMaker Inference client.

        :raises AWSConfigurationError: If the boto3 session or the client cannot be created.
        """
        boto3_import.check()
        try:
            session = boto3.Session(
                aws_access_key_id=os.getenv(self.aws_access_key_id_var),
                aws_secret_access_key=os.getenv(self.aws_secret_access_key_var),
                aws_session_token=os.getenv(self.aws_session_token_var),
                region_name=os.getenv(self.aws_region_name_var),
                profile_name=os.getenv(self.aws_profile_name_var),
            )
            self.client = session.client("sagemaker-runtime")
        except Exception as e:
            msg = (
                f"Could not connect to SageMaker Inference Endpoint '{self.model}'. "
                f"Make sure the Endpoint exists and AWS environment is configured."
            )
            raise AWSConfigurationError(msg) from e

    @component.output_types(replies=List[str], meta=List[Dict[str, Any]])
    def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
        """
        Invoke the text generation inference based on the provided messages and generation parameters.

        :param prompt: The string prompt to use for text generation.
        :param generation_kwargs: Additional keyword arguments for text generation. When provided (and non-empty),
            they are used instead of the parameters passed in the `__init__` method.

        :return: A dictionary with a list of strings containing the generated responses (`replies`) and a list of
            dictionaries containing the metadata for each response (`meta`), one entry per reply.
        :raises ValueError: If `warm_up()` was not called, or if the endpoint returns an empty or unexpected payload.
        :raises SagemakerNotReadyError: If the endpoint reports that the model is not ready.
        :raises SagemakerInferenceError: If the endpoint returns any other HTTP error.
        """
        if self.client is None:
            msg = "SageMaker Inference client is not initialized. Please call warm_up() first."
            raise ValueError(msg)

        generation_kwargs = generation_kwargs or self.generation_kwargs
        # Custom attributes are sent as `key=value` pairs joined by `;`; booleans are lower-cased (`true`/`false`).
        custom_attributes = ";".join(
            f"{k}={str(v).lower() if isinstance(v, bool) else str(v)}" for k, v in self.aws_custom_attributes.items()
        )
        try:
            body = json.dumps({"inputs": prompt, "parameters": generation_kwargs})
            response = self.client.invoke_endpoint(
                EndpointName=self.model,
                Body=body,
                ContentType="application/json",
                Accept="application/json",
                CustomAttributes=custom_attributes,
            )
            response_json = response.get("Body").read().decode("utf-8")
            output = json.loads(response_json)

            # The output might be either a list of dictionaries or a single dictionary.
            # Empty payloads (`{}` or `[]`) are rejected: there is nothing to extract a reply from.
            list_output: List[Dict[str, Any]]
            if output and isinstance(output, dict):
                list_output = [output]
            elif output and isinstance(output, list) and all(isinstance(o, dict) for o in output):
                list_output = output
            else:
                msg = f"Unexpected model response type: {type(output)}"
                raise ValueError(msg)

            # The key where the replies are stored changes from model to model, so we need to look for it.
            # All other keys in the response are added to the metadata.
            # Unfortunately every model returns different metadata, most of them return none at all,
            # so we can't replicate the metadata structure of other generators.
            first_item = list_output[0]
            reply_key = next((key for key in self.model_generation_keys if key in first_item), None)
            # If no known key is present, every reply is None (same as before).
            replies = [item.pop(reply_key, None) for item in list_output]

            # After popping the reply, each item holds exactly the metadata of the reply at the same index.
            return {"replies": replies, "meta": list_output}

        # NOTE(review): the client is created through boto3, which presumably reports failed calls as
        # `botocore.exceptions.ClientError` rather than `requests.HTTPError` - confirm whether this
        # handler can ever be reached.
        except requests.HTTPError as err:
            res = err.response
            if res.status_code == MODEL_NOT_READY_STATUS_CODE:
                msg = f"Sagemaker model not ready: {res.text}"
                raise SagemakerNotReadyError(msg) from err

            msg = f"SageMaker Inference returned an error. Status code: {res.status_code} Response body: {res.text}"
            # The status code is already part of the message; the error class only accepts a message.
            raise SagemakerInferenceError(msg) from err
class TestSagemakerGenerator:
    """Unit and integration tests for ``SagemakerGenerator``."""

    @staticmethod
    def _assert_valid_response(response):
        """Check the shape shared by every ``run()`` result in these tests: one string reply, one metadata dict."""
        assert isinstance(response, dict)
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) == 1
        # `all(...)` is required here: asserting on the bare list would pass for any non-empty list.
        assert all(isinstance(reply, str) for reply in response["replies"])

        assert "meta" in response
        assert isinstance(response["meta"], list)
        assert len(response["meta"]) == 1
        assert all(isinstance(meta, dict) for meta in response["meta"])

    def test_init_default(self, monkeypatch):
        monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test-access-key")
        monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key")

        component = SagemakerGenerator(model="test-model")
        assert component.model == "test-model"
        assert component.aws_access_key_id_var == "AWS_ACCESS_KEY_ID"
        assert component.aws_secret_access_key_var == "AWS_SECRET_ACCESS_KEY"
        assert component.aws_session_token_var == "AWS_SESSION_TOKEN"
        assert component.aws_region_name_var == "AWS_REGION"
        assert component.aws_profile_name_var == "AWS_PROFILE"
        assert component.aws_custom_attributes == {}
        assert component.generation_kwargs == {"max_new_tokens": 1024}
        assert component.client is None

    def test_init_fail_wo_access_key_or_secret_key(self, monkeypatch):
        # Neither variable set
        monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False)
        monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False)
        with pytest.raises(AWSConfigurationError):
            SagemakerGenerator(model="test-model")

        # Only the access key ID set
        monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test-access-key")
        monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False)
        with pytest.raises(AWSConfigurationError):
            SagemakerGenerator(model="test-model")

        # Only the secret access key set
        monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False)
        monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key")
        with pytest.raises(AWSConfigurationError):
            SagemakerGenerator(model="test-model")

    def test_init_with_parameters(self, monkeypatch):
        monkeypatch.setenv("MY_ACCESS_KEY_ID", "test-access-key")
        monkeypatch.setenv("MY_SECRET_ACCESS_KEY", "test-secret-key")

        component = SagemakerGenerator(
            model="test-model",
            aws_access_key_id_var="MY_ACCESS_KEY_ID",
            aws_secret_access_key_var="MY_SECRET_ACCESS_KEY",
            aws_session_token_var="MY_SESSION_TOKEN",
            aws_region_name_var="MY_REGION",
            aws_profile_name_var="MY_PROFILE",
            aws_custom_attributes={"custom": "attr"},
            generation_kwargs={"generation": "kwargs"},
        )
        assert component.model == "test-model"
        assert component.aws_access_key_id_var == "MY_ACCESS_KEY_ID"
        assert component.aws_secret_access_key_var == "MY_SECRET_ACCESS_KEY"
        assert component.aws_session_token_var == "MY_SESSION_TOKEN"
        assert component.aws_region_name_var == "MY_REGION"
        assert component.aws_profile_name_var == "MY_PROFILE"
        assert component.aws_custom_attributes == {"custom": "attr"}
        assert component.generation_kwargs == {"generation": "kwargs"}
        assert component.client is None

    def test_to_from_dict(self, monkeypatch):
        monkeypatch.setenv("MY_ACCESS_KEY_ID", "test-access-key")
        monkeypatch.setenv("MY_SECRET_ACCESS_KEY", "test-secret-key")

        component = SagemakerGenerator(
            model="test-model",
            aws_access_key_id_var="MY_ACCESS_KEY_ID",
            aws_secret_access_key_var="MY_SECRET_ACCESS_KEY",
            aws_session_token_var="MY_SESSION_TOKEN",
            aws_region_name_var="MY_REGION",
            aws_profile_name_var="MY_PROFILE",
            aws_custom_attributes={"custom": "attr"},
            generation_kwargs={"generation": "kwargs"},
        )
        serialized = component.to_dict()
        assert serialized == {
            "type": "haystack_integrations.components.generators.amazon_sagemaker.sagemaker.SagemakerGenerator",
            "init_parameters": {
                "model": "test-model",
                "aws_access_key_id_var": "MY_ACCESS_KEY_ID",
                "aws_secret_access_key_var": "MY_SECRET_ACCESS_KEY",
                "aws_session_token_var": "MY_SESSION_TOKEN",
                "aws_region_name_var": "MY_REGION",
                "aws_profile_name_var": "MY_PROFILE",
                "aws_custom_attributes": {"custom": "attr"},
                "generation_kwargs": {"generation": "kwargs"},
            },
        }
        deserialized = SagemakerGenerator.from_dict(serialized)
        assert deserialized.model == "test-model"
        assert deserialized.aws_access_key_id_var == "MY_ACCESS_KEY_ID"
        assert deserialized.aws_secret_access_key_var == "MY_SECRET_ACCESS_KEY"
        assert deserialized.aws_session_token_var == "MY_SESSION_TOKEN"
        assert deserialized.aws_region_name_var == "MY_REGION"
        assert deserialized.aws_profile_name_var == "MY_PROFILE"
        assert deserialized.aws_custom_attributes == {"custom": "attr"}
        assert deserialized.generation_kwargs == {"generation": "kwargs"}
        assert deserialized.client is None

    def test_run_with_list_of_dictionaries(self, monkeypatch):
        monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test-access-key")
        monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key")
        client_mock = Mock()
        client_mock.invoke_endpoint.return_value = {
            "Body": Mock(read=lambda: b'[{"generated_text": "test-reply", "other": "metadata"}]')
        }

        component = SagemakerGenerator(model="test-model")
        component.client = client_mock  # Simulate warm_up()
        response = component.run("What's Natural Language Processing?")

        # check that the component returns the expected replies and metadata
        self._assert_valid_response(response)
        assert "test-reply" in response["replies"][0]
        assert response["meta"][0]["other"] == "metadata"

    def test_run_with_single_dictionary(self, monkeypatch):
        monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test-access-key")
        monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key")
        client_mock = Mock()
        client_mock.invoke_endpoint.return_value = {
            "Body": Mock(read=lambda: b'{"generation": "test-reply", "other": "metadata"}')
        }

        component = SagemakerGenerator(model="test-model")
        component.client = client_mock  # Simulate warm_up()
        response = component.run("What's Natural Language Processing?")

        # check that the component returns the expected replies and metadata
        self._assert_valid_response(response)
        assert "test-reply" in response["replies"][0]
        assert response["meta"][0]["other"] == "metadata"

    @pytest.mark.skipif(
        (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)),
        reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.",
    )
    @pytest.mark.integration
    def test_run_falcon(self):
        component = SagemakerGenerator(
            model="jumpstart-dft-hf-llm-falcon-7b-instruct-bf16", generation_kwargs={"max_new_tokens": 10}
        )
        component.warm_up()
        response = component.run("What's Natural Language Processing?")

        self._assert_valid_response(response)

        # Coarse check: assuming no more than 4 chars per token. In any case it
        # will fail if the `max_new_tokens` parameter is not respected, as the
        # default is either 256 or 1024
        assert all(len(reply) <= 40 for reply in response["replies"])

    @pytest.mark.skipif(
        (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)),
        reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.",
    )
    @pytest.mark.integration
    def test_run_llama2(self):
        component = SagemakerGenerator(
            model="jumpstart-dft-meta-textgenerationneuron-llama-2-7b",
            generation_kwargs={"max_new_tokens": 10},
            aws_custom_attributes={"accept_eula": True},
        )
        component.warm_up()
        response = component.run("What's Natural Language Processing?")

        self._assert_valid_response(response)

        # Coarse check: assuming no more than 4 chars per token. In any case it
        # will fail if the `max_new_tokens` parameter is not respected, as the
        # default is either 256 or 1024
        assert all(len(reply) <= 40 for reply in response["replies"])

    @pytest.mark.skipif(
        (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)),
        reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.",
    )
    @pytest.mark.integration
    def test_run_bloomz(self):
        component = SagemakerGenerator(
            model="jumpstart-dft-hf-textgeneration-bloomz-1b1", generation_kwargs={"max_new_tokens": 10}
        )
        component.warm_up()
        response = component.run("What's Natural Language Processing?")

        self._assert_valid_response(response)

        # Coarse check: assuming no more than 4 chars per token. In any case it
        # will fail if the `max_new_tokens` parameter is not respected, as the
        # default is either 256 or 1024
        assert all(len(reply) <= 40 for reply in response["replies"])
haystack_integrations.components.retrievers.astra import AstraRetriever +from haystack_integrations.document_stores.astra import AstraDocumentStore logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) diff --git a/integrations/astra/pydoc/config.yml b/integrations/astra/pydoc/config.yml new file mode 100644 index 000000000..68cc1c809 --- /dev/null +++ b/integrations/astra/pydoc/config.yml @@ -0,0 +1,30 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../src] + modules: [ + "haystack_integrations.components.retrievers.astra.retriever", + "haystack_integrations.document_stores.astra.document_store", + "haystack_integrations.document_stores.astra.errors", + ] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + excerpt: Astra integration for Haystack + category_slug: haystack-integrations + title: Astra + slug: integrations-astra + order: 20 + markdown: + descriptive_class_title: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: _readme_astra.md \ No newline at end of file diff --git a/integrations/astra/pyproject.toml b/integrations/astra/pyproject.toml index b99449e03..7599797a8 100644 --- a/integrations/astra/pyproject.toml +++ b/integrations/astra/pyproject.toml @@ -35,6 +35,9 @@ Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/m Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/astra" +[tool.hatch.build.targets.wheel] +packages = ["src/haystack_integrations"] + [tool.hatch.version] source = "vcs" tag-pattern = 'integrations\/astra-v(?P.*)' @@ -47,6 +50,7 @@ git_describe_command 
= 'git describe --tags --match="integrations/astra-v[0-9]*" dependencies = [ "coverage[toml]>=6.5", "pytest", + "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" @@ -59,7 +63,9 @@ cov = [ "test-cov", "cov-report", ] - +docs = [ + "pydoc-markdown pydoc/config.yml" +] [[tool.hatch.envs.all.matrix]] python = ["3.7", "3.8", "3.9", "3.10", "3.11"] @@ -71,7 +77,7 @@ dependencies = [ "ruff>=0.0.243", ] [tool.hatch.envs.lint.scripts] -typing = "mypy --install-types --non-interactive {args:src/astra_haystack tests}" +typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = [ "ruff {args:.}", "black --check --diff {args:.}", @@ -141,17 +147,17 @@ unfixable = [ exclude = ["example"] [tool.ruff.isort] -known-first-party = ["astra_haystack"] +known-first-party = ["haystack_integrations"] [tool.ruff.flake8-tidy-imports] -ban-relative-imports = "all" +ban-relative-imports = "parents" [tool.ruff.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] [tool.coverage.run] -source_pkgs = ["astra_haystack", "tests"] +source_pkgs = ["haystack_integrations", "tests"] branch = true parallel = true omit = [ @@ -159,7 +165,7 @@ omit = [ ] [tool.coverage.paths] -astra_haystack = ["src/astra_haystack", "*/astra-store/src/astra_haystack"] +astra_haystack = ["src"] tests = ["tests"] [tool.coverage.report] @@ -178,10 +184,10 @@ markers = [ [[tool.mypy.overrides]] module = [ - "astra_haystack.*", "astra_client.*", "pydantic.*", "haystack.*", + "haystack_integrations.*", "pytest.*" ] ignore_missing_imports = true diff --git a/integrations/astra/src/haystack_integrations/components/retrievers/astra/__init__.py b/integrations/astra/src/haystack_integrations/components/retrievers/astra/__init__.py new file mode 100644 index 000000000..33ef6d15e --- /dev/null +++ b/integrations/astra/src/haystack_integrations/components/retrievers/astra/__init__.py 
@@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: 2023-present Anant Corporation +# +# SPDX-License-Identifier: Apache-2.0 +from .retriever import AstraRetriever + +__all__ = ["AstraRetriever"] diff --git a/integrations/astra/src/astra_haystack/retriever.py b/integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py similarity index 96% rename from integrations/astra/src/astra_haystack/retriever.py rename to integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py index 47304df2c..fdf9b0722 100644 --- a/integrations/astra/src/astra_haystack/retriever.py +++ b/integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py @@ -6,7 +6,7 @@ from haystack import Document, component, default_from_dict, default_to_dict -from astra_haystack.document_store import AstraDocumentStore +from haystack_integrations.document_stores.astra import AstraDocumentStore @component diff --git a/integrations/astra/src/astra_haystack/__init__.py b/integrations/astra/src/haystack_integrations/document_stores/astra/__init__.py similarity index 71% rename from integrations/astra/src/astra_haystack/__init__.py rename to integrations/astra/src/haystack_integrations/document_stores/astra/__init__.py index 5c99dedf6..4618beb08 100644 --- a/integrations/astra/src/astra_haystack/__init__.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/__init__.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: 2023-present Anant Corporation # # SPDX-License-Identifier: Apache-2.0 -from astra_haystack.document_store import AstraDocumentStore +from .document_store import AstraDocumentStore __all__ = ["AstraDocumentStore"] diff --git a/integrations/astra/src/astra_haystack/astra_client.py b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py similarity index 100% rename from integrations/astra/src/astra_haystack/astra_client.py rename to 
integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py diff --git a/integrations/astra/src/astra_haystack/document_store.py b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py similarity index 91% rename from integrations/astra/src/astra_haystack/document_store.py rename to integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py index 6e630bef5..8e03de4a6 100644 --- a/integrations/astra/src/astra_haystack/document_store.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py @@ -12,9 +12,9 @@ from haystack.document_stores.errors import DuplicateDocumentError, MissingDocumentError from haystack.document_stores.types import DuplicatePolicy -from astra_haystack.astra_client import AstraClient -from astra_haystack.errors import AstraDocumentStoreFilterError -from astra_haystack.filters import _convert_filters +from .astra_client import AstraClient +from .errors import AstraDocumentStoreFilterError +from .filters import _convert_filters logger = logging.getLogger(__name__) @@ -40,7 +40,7 @@ def __init__( astra_application_token: str, astra_keyspace: str, astra_collection: str, - embedding_dim: Optional[int] = 768, + embedding_dim: int = 768, duplicates_policy: DuplicatePolicy = DuplicatePolicy.NONE, similarity: str = "cosine", ): @@ -104,17 +104,12 @@ def to_dict(self) -> Dict[str, Any]: def write_documents( self, documents: List[Document], - index: Optional[str] = None, - batch_size: int = 20, policy: DuplicatePolicy = DuplicatePolicy.NONE, ): """ Indexes documents for later queries. :param documents: a list of Haystack Document objects. - :param index: Optional name of index where the documents shall be written to. - If None, the DocumentStore's default index (self.index) will be used. - :param batch_size: Number of documents that are passed to bulk function at a time. 
:param policy: Handle duplicate documents based on DuplicatePolicy parameter options. Parameter options : (SKIP, OVERWRITE, FAIL, NONE) - `DuplicatePolicy.NONE`: Default policy, If a Document with the same id already exists, @@ -125,26 +120,13 @@ def write_documents( - `DuplicatePolicy.FAIL`: If a Document with the same id already exists, an error is raised. :return: int """ - - if index is None and self.index is None: - msg = "No Astra client provided" - raise ValueError(msg) - - if index is None: - index = self.index - if policy is None or policy == DuplicatePolicy.NONE: if self.duplicates_policy is not None and self.duplicates_policy != DuplicatePolicy.NONE: policy = self.duplicates_policy else: policy = DuplicatePolicy.SKIP - if batch_size > MAX_BATCH_SIZE: - logger.warning( - f"batch_size set to {batch_size}, " - f"but maximum batch_size for Astra when using the JSON API is 20. batch_size set to 20." - ) - batch_size = MAX_BATCH_SIZE + batch_size = MAX_BATCH_SIZE def _convert_input_document(document: Union[dict, Document]): if isinstance(document, Document): @@ -196,7 +178,7 @@ def _convert_input_document(document: Union[dict, Document]): if policy == DuplicatePolicy.SKIP: if len(new_documents) > 0: for batch in _batches(new_documents, batch_size): - inserted_ids = index.insert(batch) # type: ignore + inserted_ids = self.index.insert(batch) # type: ignore insertion_counter += len(inserted_ids) logger.info(f"write_documents inserted documents with id {inserted_ids}") else: @@ -205,7 +187,7 @@ def _convert_input_document(document: Union[dict, Document]): elif policy == DuplicatePolicy.OVERWRITE: if len(new_documents) > 0: for batch in _batches(new_documents, batch_size): - inserted_ids = index.insert(batch) # type: ignore + inserted_ids = self.index.insert(batch) # type: ignore insertion_counter += len(inserted_ids) logger.info(f"write_documents inserted documents with id {inserted_ids}") else: @@ -214,7 +196,7 @@ def _convert_input_document(document: 
Union[dict, Document]): if len(duplicate_documents) > 0: updated_ids = [] for duplicate_doc in duplicate_documents: - updated = index.update_document(duplicate_doc, "_id") # type: ignore + updated = self.index.update_document(duplicate_doc, "_id") # type: ignore if updated: updated_ids.append(duplicate_doc["_id"]) insertion_counter = insertion_counter + len(updated_ids) @@ -225,7 +207,7 @@ def _convert_input_document(document: Union[dict, Document]): elif policy == DuplicatePolicy.FAIL: if len(new_documents) > 0: for batch in _batches(new_documents, batch_size): - inserted_ids = index.insert(batch) # type: ignore + inserted_ids = self.index.insert(batch) # type: ignore insertion_counter = insertion_counter + len(inserted_ids) logger.info(f"write_documents inserted documents with id {inserted_ids}") else: diff --git a/integrations/astra/src/astra_haystack/errors.py b/integrations/astra/src/haystack_integrations/document_stores/astra/errors.py similarity index 100% rename from integrations/astra/src/astra_haystack/errors.py rename to integrations/astra/src/haystack_integrations/document_stores/astra/errors.py diff --git a/integrations/astra/src/astra_haystack/filters.py b/integrations/astra/src/haystack_integrations/document_stores/astra/filters.py similarity index 100% rename from integrations/astra/src/astra_haystack/filters.py rename to integrations/astra/src/haystack_integrations/document_stores/astra/filters.py diff --git a/integrations/astra/tests/conftest.py b/integrations/astra/tests/conftest.py index 02f5d7cad..274b38352 100644 --- a/integrations/astra/tests/conftest.py +++ b/integrations/astra/tests/conftest.py @@ -3,7 +3,7 @@ import pytest from haystack.document_stores.types import DuplicatePolicy -from astra_haystack.document_store import AstraDocumentStore +from haystack_integrations.document_stores.astra import AstraDocumentStore @pytest.fixture diff --git a/integrations/astra/tests/test_document_store.py 
b/integrations/astra/tests/test_document_store.py index f203ab721..019a66398 100644 --- a/integrations/astra/tests/test_document_store.py +++ b/integrations/astra/tests/test_document_store.py @@ -10,7 +10,7 @@ from haystack.document_stores.types import DuplicatePolicy from haystack.testing.document_store import DocumentStoreBaseTests -from astra_haystack.document_store import AstraDocumentStore +from haystack_integrations.document_stores.astra import AstraDocumentStore @pytest.mark.skipif( diff --git a/integrations/astra/tests/test_retriever.py b/integrations/astra/tests/test_retriever.py index 2212d44fd..eb9260590 100644 --- a/integrations/astra/tests/test_retriever.py +++ b/integrations/astra/tests/test_retriever.py @@ -5,7 +5,7 @@ import pytest -from astra_haystack.retriever import AstraRetriever +from haystack_integrations.components.retrievers.astra import AstraRetriever @pytest.mark.skipif( @@ -16,7 +16,7 @@ def test_retriever_to_json(document_store): retriever = AstraRetriever(document_store, filters={"foo": "bar"}, top_k=99) assert retriever.to_dict() == { - "type": "astra_haystack.retriever.AstraRetriever", + "type": "haystack_integrations.components.retrievers.astra.retriever.AstraRetriever", "init_parameters": { "filters": {"foo": "bar"}, "top_k": 99, @@ -30,7 +30,7 @@ def test_retriever_to_json(document_store): "embedding_dim": 768, "similarity": "cosine", }, - "type": "astra_haystack.document_store.AstraDocumentStore", + "type": "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore", }, }, } @@ -43,7 +43,7 @@ def test_retriever_to_json(document_store): @pytest.mark.integration def test_retriever_from_json(): data = { - "type": "astra_haystack.retriever.AstraRetriever", + "type": "haystack_integrations.components.retrievers.astra.retriever.AstraRetriever", "init_parameters": { "filters": {"bar": "baz"}, "top_k": 42, @@ -58,7 +58,7 @@ def test_retriever_from_json(): "embedding_dim": 768, "similarity": "cosine", }, - "type": 
"astra_haystack.document_store.AstraDocumentStore", + "type": "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore", }, }, } diff --git a/integrations/chroma/pydoc/config.yml b/integrations/chroma/pydoc/config.yml new file mode 100644 index 000000000..fd362d7e0 --- /dev/null +++ b/integrations/chroma/pydoc/config.yml @@ -0,0 +1,31 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../src] + modules: [ + "haystack_integrations.components.retrievers.chroma.retriever", + "haystack_integrations.document_stores.chroma.document_store", + "haystack_integrations.document_stores.chroma.errors", + "haystack_integrations.document_stores.chroma.utils", + ] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + excerpt: Chroma integration for Haystack + category_slug: haystack-integrations + title: Chroma + slug: integrations-chroma + order: 1 + markdown: + descriptive_class_title: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: _readme_chroma.md diff --git a/integrations/chroma/pyproject.toml b/integrations/chroma/pyproject.toml index 2e531005b..2653c491f 100644 --- a/integrations/chroma/pyproject.toml +++ b/integrations/chroma/pyproject.toml @@ -25,6 +25,7 @@ classifiers = [ dependencies = [ "haystack-ai", "chromadb<0.4.20", # FIXME: investigate why filtering tests broke on 0.4.20 + "typing_extensions>=4.8.0", ] [project.urls] @@ -47,6 +48,7 @@ git_describe_command = 'git describe --tags --match="integrations/chroma-v[0-9]* dependencies = [ "coverage[toml]>=6.5", "pytest", + "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" @@ -59,6 +61,9 @@ cov = [ "test-cov", "cov-report", ] +docs = [ + 
"pydoc-markdown pydoc/config.yml" +] [[tool.hatch.envs.all.matrix]] python = ["3.9", "3.10"] diff --git a/integrations/cohere/pydoc/config.yml b/integrations/cohere/pydoc/config.yml new file mode 100644 index 000000000..9418739b5 --- /dev/null +++ b/integrations/cohere/pydoc/config.yml @@ -0,0 +1,32 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../src] + modules: [ + "haystack_integrations.components.embedders.cohere.document_embedder", + "haystack_integrations.components.embedders.cohere.text_embedder", + "haystack_integrations.components.embedders.cohere.utils", + "haystack_integrations.components.generators.cohere.generator", + "haystack_integrations.components.generators.cohere.chat.chat_generator", + ] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + excerpt: Cohere integration for Haystack + category_slug: haystack-integrations + title: Cohere + slug: integrations-cohere + order: 40 + markdown: + descriptive_class_title: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: _readme_cohere.md diff --git a/integrations/cohere/pyproject.toml b/integrations/cohere/pyproject.toml index ae96f114d..332471674 100644 --- a/integrations/cohere/pyproject.toml +++ b/integrations/cohere/pyproject.toml @@ -34,6 +34,9 @@ Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/m Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/cohere" +[tool.hatch.build.targets.wheel] +packages = ["src/haystack_integrations"] + [tool.hatch.version] source = "vcs" tag-pattern = 'integrations\/cohere-v(?P<version>.*)' @@ -46,6 +49,7 @@
git_describe_command = 'git describe --tags --match="integrations/cohere-v[0-9]* dependencies = [ "coverage[toml]>=6.5", "pytest", + "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" @@ -58,6 +62,9 @@ cov = [ "test-cov", "cov-report", ] +docs = [ + "pydoc-markdown pydoc/config.yml" +] [[tool.hatch.envs.all.matrix]] python = ["3.7", "3.8", "3.9", "3.10", "3.11"] @@ -70,7 +77,7 @@ dependencies = [ "ruff>=0.0.243", ] [tool.hatch.envs.lint.scripts] -typing = "mypy --install-types --non-interactive {args:src/cohere_haystack tests}" +typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = [ "ruff {args:.}", "black --check --diff {args:.}", @@ -133,26 +140,23 @@ unfixable = [ ] [tool.ruff.isort] -known-first-party = ["cohere_haystack"] +known-first-party = ["src"] [tool.ruff.flake8-tidy-imports] -ban-relative-imports = "all" +ban-relative-imports = "parents" [tool.ruff.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] [tool.coverage.run] -source_pkgs = ["cohere_haystack", "tests"] +source_pkgs = ["src", "tests"] branch = true parallel = true -omit = [ - "src/cohere_haystack/__about__.py", -] [tool.coverage.paths] -cohere_haystack = ["src/cohere_haystack", "*/cohere-haystack/src/cohere_haystack"] -tests = ["tests", "*/cohere-haystack/tests"] +cohere_haystack = ["src/haystack_integrations", "*/cohere/src/haystack_integrations"] +tests = ["tests", "*/cohere/tests"] [tool.coverage.report] exclude_lines = [ @@ -165,6 +169,7 @@ exclude_lines = [ module = [ "cohere.*", "haystack.*", + "haystack_integrations.*", "pytest.*", "numpy.*", ] diff --git a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/__init__.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/__init__.py new file mode 100644 index 000000000..73a863a73 --- /dev/null +++ 
b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from .document_embedder import CohereDocumentEmbedder +from .text_embedder import CohereTextEmbedder + +__all__ = ["CohereDocumentEmbedder", "CohereTextEmbedder"] diff --git a/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py similarity index 98% rename from integrations/cohere/src/cohere_haystack/embedders/document_embedder.py rename to integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py index 151c4f794..69308ad19 100644 --- a/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py @@ -5,10 +5,10 @@ import os from typing import Any, Dict, List, Optional -from cohere import COHERE_API_URL, AsyncClient, Client from haystack import Document, component, default_to_dict +from haystack_integrations.components.embedders.cohere.utils import get_async_response, get_response -from cohere_haystack.embedders.utils import get_async_response, get_response +from cohere import COHERE_API_URL, AsyncClient, Client @component diff --git a/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py similarity index 98% rename from integrations/cohere/src/cohere_haystack/embedders/text_embedder.py rename to integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py index bfef97dc3..2fa922004 100644 --- a/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py @@ -5,10 +5,10 @@ 
import os from typing import Any, Dict, List, Optional -from cohere import COHERE_API_URL, AsyncClient, Client from haystack import component, default_to_dict +from haystack_integrations.components.embedders.cohere.utils import get_async_response, get_response -from cohere_haystack.embedders.utils import get_async_response, get_response +from cohere import COHERE_API_URL, AsyncClient, Client @component diff --git a/integrations/cohere/src/cohere_haystack/embedders/utils.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/utils.py similarity index 99% rename from integrations/cohere/src/cohere_haystack/embedders/utils.py rename to integrations/cohere/src/haystack_integrations/components/embedders/cohere/utils.py index 1c1049852..7b9c90730 100644 --- a/integrations/cohere/src/cohere_haystack/embedders/utils.py +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/utils.py @@ -3,9 +3,10 @@ # SPDX-License-Identifier: Apache-2.0 from typing import Any, Dict, List, Tuple -from cohere import AsyncClient, Client, CohereError from tqdm import tqdm +from cohere import AsyncClient, Client, CohereError + async def get_async_response(cohere_async_client: AsyncClient, texts: List[str], model_name, input_type, truncate): all_embeddings: List[List[float]] = [] diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/__init__.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/__init__.py new file mode 100644 index 000000000..93c0947e4 --- /dev/null +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from .chat.chat_generator import CohereChatGenerator +from .generator import CohereGenerator + +__all__ = ["CohereGenerator", "CohereChatGenerator"] diff --git a/integrations/cohere/src/cohere_haystack/__init__.py 
b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/__init__.py similarity index 100% rename from integrations/cohere/src/cohere_haystack/__init__.py rename to integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/__init__.py diff --git a/integrations/cohere/src/cohere_haystack/chat/chat_generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py similarity index 98% rename from integrations/cohere/src/cohere_haystack/chat/chat_generator.py rename to integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py index 0ff29ce14..c632bed83 100644 --- a/integrations/cohere/src/cohere_haystack/chat/chat_generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py @@ -12,6 +12,7 @@ logger = logging.getLogger(__name__) +@component class CohereChatGenerator: """Enables text generation using Cohere's chat endpoint. This component is designed to inference Cohere's chat models. @@ -123,10 +124,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "CohereChatGenerator": return default_from_dict(cls, data) def _message_to_dict(self, message: ChatMessage) -> Dict[str, str]: - if message.role == ChatRole.USER: - role = "User" - elif message.role == ChatRole.ASSISTANT: - role = "Chatbot" + role = "User" if message.role == ChatRole.USER else "Chatbot" chat_message = {"user_name": role, "text": message.content} return chat_message @@ -179,7 +177,6 @@ def _build_chunk(self, chunk) -> StreamingChunk: :param choice: The choice returned by the OpenAI API. :return: The StreamingChunk. 
""" - # if chunk.event_type == "text-generation": chat_message = StreamingChunk(content=chunk.text, meta={"index": chunk.index, "event_type": chunk.event_type}) return chat_message diff --git a/integrations/cohere/src/cohere_haystack/generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py similarity index 94% rename from integrations/cohere/src/cohere_haystack/generator.py rename to integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py index 9917f17ea..fee410eab 100644 --- a/integrations/cohere/src/cohere_haystack/generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py @@ -6,9 +6,11 @@ import sys from typing import Any, Callable, Dict, List, Optional, cast +from haystack import DeserializationError, component, default_from_dict, default_to_dict +from haystack.dataclasses import StreamingChunk + from cohere import COHERE_API_URL, Client from cohere.responses import Generations -from haystack import DeserializationError, component, default_from_dict, default_to_dict logger = logging.getLogger(__name__) @@ -147,8 +149,8 @@ def run(self, prompt: str): if self.streaming_callback: metadata_dict: Dict[str, Any] = {} for chunk in response: - self.streaming_callback(chunk) - metadata_dict["index"] = chunk.index + stream_chunk = self._build_chunk(chunk) + self.streaming_callback(stream_chunk) replies = response.texts metadata_dict["finish_reason"] = response.finish_reason metadata = [metadata_dict] @@ -160,6 +162,15 @@ def run(self, prompt: str): self._check_truncated_answers(metadata) return {"replies": replies, "meta": metadata} + def _build_chunk(self, chunk) -> StreamingChunk: + """ + Converts the response from the Cohere API to a StreamingChunk. + :param chunk: The chunk returned by the Cohere API. + :return: The StreamingChunk.
+ """ + streaming_chunk = StreamingChunk(content=chunk.text, meta={"index": chunk.index}) + return streaming_chunk + def _check_truncated_answers(self, metadata: List[Dict[str, Any]]): """ Check the `finish_reason` returned with the Cohere response. diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index e93db51fd..c91ada419 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -5,8 +5,7 @@ import pytest from haystack.components.generators.utils import default_streaming_callback from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk - -from cohere_haystack.chat.chat_generator import CohereChatGenerator +from haystack_integrations.components.generators.cohere import CohereChatGenerator pytestmark = pytest.mark.chat_generators @@ -88,7 +87,7 @@ def test_to_dict_default(self): component = CohereChatGenerator(api_key="test-api-key") data = component.to_dict() assert data == { - "type": "cohere_haystack.chat.chat_generator.CohereChatGenerator", + "type": "haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator", "init_parameters": { "model": "command", "streaming_callback": None, @@ -108,7 +107,7 @@ def test_to_dict_with_parameters(self): ) data = component.to_dict() assert data == { - "type": "cohere_haystack.chat.chat_generator.CohereChatGenerator", + "type": "haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator", "init_parameters": { "model": "command-nightly", "streaming_callback": "haystack.components.generators.utils.default_streaming_callback", @@ -128,7 +127,7 @@ def test_to_dict_with_lambda_streaming_callback(self): ) data = component.to_dict() assert data == { - "type": "cohere_haystack.chat.chat_generator.CohereChatGenerator", + "type": 
"haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator", "init_parameters": { "model": "command", "api_base_url": "test-base-url", @@ -141,7 +140,7 @@ def test_to_dict_with_lambda_streaming_callback(self): def test_from_dict(self, monkeypatch): monkeypatch.setenv("COHERE_API_KEY", "fake-api-key") data = { - "type": "cohere_haystack.chat.chat_generator.CohereChatGenerator", + "type": "haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator", "init_parameters": { "model": "command", "api_base_url": "test-base-url", @@ -159,7 +158,7 @@ def test_from_dict(self, monkeypatch): def test_from_dict_fail_wo_env_var(self, monkeypatch): monkeypatch.delenv("COHERE_API_KEY", raising=False) data = { - "type": "cohere_haystack.chat.chat_generator.CohereChatGenerator", + "type": "haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator", "init_parameters": { "model": "command", "api_base_url": "test-base-url", diff --git a/integrations/cohere/tests/test_cohere_generators.py b/integrations/cohere/tests/test_cohere_generators.py index f22b38843..e2ce10405 100644 --- a/integrations/cohere/tests/test_cohere_generators.py +++ b/integrations/cohere/tests/test_cohere_generators.py @@ -5,8 +5,7 @@ import pytest from cohere import COHERE_API_URL - -from cohere_haystack.generator import CohereGenerator +from haystack_integrations.components.generators.cohere import CohereGenerator pytestmark = pytest.mark.generators @@ -48,7 +47,7 @@ def test_to_dict_default(self): component = CohereGenerator(api_key="test-api-key") data = component.to_dict() assert data == { - "type": "cohere_haystack.generator.CohereGenerator", + "type": "haystack_integrations.components.generators.cohere.generator.CohereGenerator", "init_parameters": { "model": "command", "streaming_callback": None, @@ -67,7 +66,7 @@ def test_to_dict_with_parameters(self): ) data = component.to_dict() assert data == { - "type": 
"cohere_haystack.generator.CohereGenerator", + "type": "haystack_integrations.components.generators.cohere.generator.CohereGenerator", "init_parameters": { "model": "command-light", "max_tokens": 10, @@ -88,7 +87,7 @@ def test_to_dict_with_lambda_streaming_callback(self): ) data = component.to_dict() assert data == { - "type": "cohere_haystack.generator.CohereGenerator", + "type": "haystack_integrations.components.generators.cohere.generator.CohereGenerator", "init_parameters": { "model": "command", "streaming_callback": "tests.test_cohere_generators.<lambda>", @@ -101,7 +100,7 @@ def test_to_dict_with_lambda_streaming_callback(self): def test_from_dict(self, monkeypatch): monkeypatch.setenv("COHERE_API_KEY", "test-key") data = { - "type": "cohere_haystack.generator.CohereGenerator", + "type": "haystack_integrations.components.generators.cohere.generator.CohereGenerator", "init_parameters": { "model": "command", "max_tokens": 10, diff --git a/integrations/cohere/tests/test_document_embedder.py b/integrations/cohere/tests/test_document_embedder.py index c9770737e..efe8eb36a 100644 --- a/integrations/cohere/tests/test_document_embedder.py +++ b/integrations/cohere/tests/test_document_embedder.py @@ -6,8 +6,7 @@ import pytest from cohere import COHERE_API_URL from haystack import Document - -from cohere_haystack.embedders.document_embedder import CohereDocumentEmbedder +from haystack_integrations.components.embedders.cohere import CohereDocumentEmbedder pytestmark = pytest.mark.embedders @@ -60,7 +59,7 @@ def test_to_dict(self): embedder_component = CohereDocumentEmbedder(api_key="test-api-key") component_dict = embedder_component.to_dict() assert component_dict == { - "type": "cohere_haystack.embedders.document_embedder.CohereDocumentEmbedder", + "type": "haystack_integrations.components.embedders.cohere.document_embedder.CohereDocumentEmbedder", "init_parameters": { "model": "embed-english-v2.0", "input_type": "search_document", @@ -93,7 +92,7 @@ def
test_to_dict_with_custom_init_parameters(self): ) component_dict = embedder_component.to_dict() assert component_dict == { - "type": "cohere_haystack.embedders.document_embedder.CohereDocumentEmbedder", + "type": "haystack_integrations.components.embedders.cohere.document_embedder.CohereDocumentEmbedder", "init_parameters": { "model": "embed-multilingual-v2.0", "input_type": "search_query", diff --git a/integrations/cohere/tests/test_text_embedder.py b/integrations/cohere/tests/test_text_embedder.py index 7e91b4812..657d8df83 100644 --- a/integrations/cohere/tests/test_text_embedder.py +++ b/integrations/cohere/tests/test_text_embedder.py @@ -5,8 +5,7 @@ import pytest from cohere import COHERE_API_URL - -from cohere_haystack.embedders.text_embedder import CohereTextEmbedder +from haystack_integrations.components.embedders.cohere import CohereTextEmbedder pytestmark = pytest.mark.embedders @@ -57,7 +56,7 @@ def test_to_dict(self): embedder_component = CohereTextEmbedder(api_key="test-api-key") component_dict = embedder_component.to_dict() assert component_dict == { - "type": "cohere_haystack.embedders.text_embedder.CohereTextEmbedder", + "type": "haystack_integrations.components.embedders.cohere.text_embedder.CohereTextEmbedder", "init_parameters": { "model": "embed-english-v2.0", "input_type": "search_query", @@ -85,7 +84,7 @@ def test_to_dict_with_custom_init_parameters(self): ) component_dict = embedder_component.to_dict() assert component_dict == { - "type": "cohere_haystack.embedders.text_embedder.CohereTextEmbedder", + "type": "haystack_integrations.components.embedders.cohere.text_embedder.CohereTextEmbedder", "init_parameters": { "model": "embed-multilingual-v2.0", "input_type": "classification", diff --git a/integrations/elasticsearch/pydoc/config.yml b/integrations/elasticsearch/pydoc/config.yml new file mode 100644 index 000000000..dc5a090bc --- /dev/null +++ b/integrations/elasticsearch/pydoc/config.yml @@ -0,0 +1,31 @@ +loaders: + - type: 
haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../src] + modules: [ + "haystack_integrations.components.retrievers.elasticsearch.bm25_retriever", + "haystack_integrations.components.retrievers.elasticsearch.embedding_retriever", + "haystack_integrations.document_stores.elasticsearch.document_store", + "haystack_integrations.document_stores.elasticsearch.filters", + ] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + excerpt: Elasticsearch integration for Haystack + category_slug: haystack-integrations + title: Elasticsearch + slug: integrations-elasticsearch + order: 50 + markdown: + descriptive_class_title: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: _readme_elasticsearch.md \ No newline at end of file diff --git a/integrations/elasticsearch/pyproject.toml b/integrations/elasticsearch/pyproject.toml index af3d89c0c..b67df7e03 100644 --- a/integrations/elasticsearch/pyproject.toml +++ b/integrations/elasticsearch/pyproject.toml @@ -49,6 +49,7 @@ dependencies = [ "coverage[toml]>=6.5", "pytest", "pytest-xdist", + "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" @@ -61,6 +62,9 @@ cov = [ "test-cov", "cov-report", ] +docs = [ + "pydoc-markdown pydoc/config.yml" +] [[tool.hatch.envs.all.matrix]] python = ["3.8", "3.9", "3.10", "3.11"] diff --git a/integrations/google_ai/pyproject.toml b/integrations/google_ai/pyproject.toml index 91fcd655b..1127dc6bf 100644 --- a/integrations/google_ai/pyproject.toml +++ b/integrations/google_ai/pyproject.toml @@ -34,6 +34,9 @@ Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/m Issues = 
"https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/google_ai_haystack/issues" Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/google_ai_haystack" +[tool.hatch.build.targets.wheel] +packages = ["src/haystack_integrations"] + [tool.hatch.version] source = "vcs" tag-pattern = 'integrations\/google_ai-v(?P<version>.*)' @@ -70,7 +73,7 @@ dependencies = [ "ruff>=0.0.243", ] [tool.hatch.envs.lint.scripts] -typing = "mypy --install-types --non-interactive {args:src/google_ai_haystack tests}" +typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = [ "ruff {args:.}", "black --check --diff {args:.}", @@ -136,26 +139,22 @@ unfixable = [ ] [tool.ruff.isort] -known-first-party = ["google_ai_haystack"] +known-first-party = ["haystack_integrations"] [tool.ruff.flake8-tidy-imports] -ban-relative-imports = "all" +ban-relative-imports = "parents" [tool.ruff.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] [tool.coverage.run] -source_pkgs = ["google_ai_haystack", "tests"] branch = true parallel = true -omit = [ - "src/google_ai_haystack/__about__.py", -] [tool.coverage.paths] -google_ai_haystack = ["src/google_ai_haystack", "*/google_ai_haystack/src/google_ai_haystack"] -tests = ["tests", "*/google_ai_haystack/tests"] +google_ai_haystack = ["src"] +tests = ["tests"] [tool.coverage.report] exclude_lines = [ @@ -167,6 +166,7 @@ exclude_lines = [ module = [ "google.*", "haystack.*", + "haystack_integrations.*", "pytest.*", "numpy.*", ] diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/__init__.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/__init__.py new file mode 100644 index 000000000..2b77c813f --- /dev/null +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/__init__.py @@ -0,0 +1,7
@@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from .chat.gemini import GoogleAIGeminiChatGenerator +from .gemini import GoogleAIGeminiGenerator + +__all__ = ["GoogleAIGeminiGenerator", "GoogleAIGeminiChatGenerator"] diff --git a/integrations/google_ai/src/google_ai_haystack/generators/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py similarity index 98% rename from integrations/google_ai/src/google_ai_haystack/generators/chat/gemini.py rename to integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index 9bf33d8d3..030505860 100644 --- a/integrations/google_ai/src/google_ai_haystack/generators/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -21,7 +21,7 @@ class GoogleAIGeminiChatGenerator: Sample usage: ```python from haystack.dataclasses.chat_message import ChatMessage - from google_ai_haystack.generators.chat.gemini import GoogleAIGeminiChatGenerator + from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro", api_key="") @@ -43,7 +43,7 @@ class GoogleAIGeminiChatGenerator: from haystack.dataclasses.chat_message import ChatMessage from google.ai.generativelanguage import FunctionDeclaration, Tool - from google_ai_haystack.generators.chat.gemini import GoogleAIGeminiChatGenerator + from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator # Example function to get the current weather def get_current_weather(location: str, unit: str = "celsius") -> str: diff --git a/integrations/google_ai/src/google_ai_haystack/generators/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py similarity index 97% rename from 
integrations/google_ai/src/google_ai_haystack/generators/gemini.py rename to integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py index bd4ab5150..bd4f1d5e6 100644 --- a/integrations/google_ai/src/google_ai_haystack/generators/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py @@ -20,7 +20,7 @@ class GoogleAIGeminiGenerator: Sample usage: ```python - from google_ai_haystack.generators.gemini import GoogleAIGeminiGenerator + from haystack_integrations.components.generators.google_ai import GoogleAIGeminiGenerator gemini = GoogleAIGeminiGenerator(model="gemini-pro", api_key="") res = gemini.run(parts = ["What is the most interesting thing you know?"]) @@ -32,7 +32,7 @@ class GoogleAIGeminiGenerator: ```python import requests from haystack.dataclasses.byte_stream import ByteStream - from google_ai_haystack.generators.gemini import GoogleAIGeminiGenerator + from haystack_integrations.components.generators.google_ai import GoogleAIGeminiGenerator BASE_URL = ( "https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations" diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py index 16a2af236..7b2b80088 100644 --- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py +++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py @@ -7,7 +7,7 @@ from google.generativeai.types import HarmBlockThreshold, HarmCategory from haystack.dataclasses.chat_message import ChatMessage -from google_ai_haystack.generators.chat.gemini import GoogleAIGeminiChatGenerator +from haystack_integrations.components.generators.google_ai import GoogleAIGeminiChatGenerator def test_init(): @@ -40,7 +40,9 @@ def test_init(): ) tool = Tool(function_declarations=[get_current_weather_func]) - with patch("google_ai_haystack.generators.chat.gemini.genai.configure") as mock_genai_configure: + 
with patch( + "haystack_integrations.components.generators.google_ai.chat.gemini.genai.configure" + ) as mock_genai_configure: gemini = GoogleAIGeminiChatGenerator( generation_config=generation_config, safety_settings=safety_settings, @@ -85,14 +87,14 @@ def test_to_dict(): tool = Tool(function_declarations=[get_current_weather_func]) - with patch("google_ai_haystack.generators.chat.gemini.genai.configure"): + with patch("haystack_integrations.components.generators.google_ai.chat.gemini.genai.configure"): gemini = GoogleAIGeminiChatGenerator( generation_config=generation_config, safety_settings=safety_settings, tools=[tool], ) assert gemini.to_dict() == { - "type": "google_ai_haystack.generators.chat.gemini.GoogleAIGeminiChatGenerator", + "type": "haystack_integrations.components.generators.google_ai.chat.gemini.GoogleAIGeminiChatGenerator", "init_parameters": { "model": "gemini-pro-vision", "generation_config": { @@ -114,10 +116,10 @@ def test_to_dict(): def test_from_dict(): - with patch("google_ai_haystack.generators.chat.gemini.genai.configure"): + with patch("haystack_integrations.components.generators.google_ai.chat.gemini.genai.configure"): gemini = GoogleAIGeminiChatGenerator.from_dict( { - "type": "google_ai_haystack.generators.chat.gemini.GoogleAIGeminiChatGenerator", + "type": "haystack_integrations.components.generators.google_ai.chat.gemini.GoogleAIGeminiChatGenerator", "init_parameters": { "model": "gemini-pro-vision", "generation_config": { diff --git a/integrations/google_ai/tests/generators/test_gemini.py b/integrations/google_ai/tests/generators/test_gemini.py index c01c8b158..9ef818144 100644 --- a/integrations/google_ai/tests/generators/test_gemini.py +++ b/integrations/google_ai/tests/generators/test_gemini.py @@ -6,7 +6,7 @@ from google.generativeai import GenerationConfig, GenerativeModel from google.generativeai.types import HarmBlockThreshold, HarmCategory -from google_ai_haystack.generators.gemini import GoogleAIGeminiGenerator +from 
haystack_integrations.components.generators.google_ai import GoogleAIGeminiGenerator def test_init(): @@ -39,7 +39,7 @@ def test_init(): ) tool = Tool(function_declarations=[get_current_weather_func]) - with patch("google_ai_haystack.generators.gemini.genai.configure") as mock_genai_configure: + with patch("haystack_integrations.components.generators.google_ai.gemini.genai.configure") as mock_genai_configure: gemini = GoogleAIGeminiGenerator( generation_config=generation_config, safety_settings=safety_settings, @@ -84,14 +84,14 @@ def test_to_dict(): tool = Tool(function_declarations=[get_current_weather_func]) - with patch("google_ai_haystack.generators.gemini.genai.configure"): + with patch("haystack_integrations.components.generators.google_ai.gemini.genai.configure"): gemini = GoogleAIGeminiGenerator( generation_config=generation_config, safety_settings=safety_settings, tools=[tool], ) assert gemini.to_dict() == { - "type": "google_ai_haystack.generators.gemini.GoogleAIGeminiGenerator", + "type": "haystack_integrations.components.generators.google_ai.gemini.GoogleAIGeminiGenerator", "init_parameters": { "model": "gemini-pro-vision", "generation_config": { @@ -113,10 +113,10 @@ def test_to_dict(): def test_from_dict(): - with patch("google_ai_haystack.generators.gemini.genai.configure"): + with patch("haystack_integrations.components.generators.google_ai.gemini.genai.configure"): gemini = GoogleAIGeminiGenerator.from_dict( { - "type": "google_ai_haystack.generators.gemini.GoogleAIGeminiGenerator", + "type": "haystack_integrations.components.generators.google_ai.gemini.GoogleAIGeminiGenerator", "init_parameters": { "model": "gemini-pro-vision", "generation_config": { diff --git a/integrations/google_vertex/pyproject.toml b/integrations/google_vertex/pyproject.toml index 1d15a4270..ecd509f15 100644 --- a/integrations/google_vertex/pyproject.toml +++ b/integrations/google_vertex/pyproject.toml @@ -33,6 +33,9 @@ Documentation = 
"https://github.com/deepset-ai/haystack-core-integrations/tree/m Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/google_vertex" +[tool.hatch.build.targets.wheel] +packages = ["src/haystack_integrations"] + [tool.hatch.version] source = "vcs" tag-pattern = 'integrations\/google_vertex-v(?P<version>.*)' @@ -69,7 +72,7 @@ dependencies = [ "ruff>=0.0.243", ] [tool.hatch.envs.lint.scripts] -typing = "mypy --install-types --non-interactive {args:src/google_vertex_haystack tests}" +typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = [ "ruff {args:.}", "black --check --diff {args:.}", @@ -132,26 +135,23 @@ unfixable = [ ] [tool.ruff.isort] -known-first-party = ["google_vertex_haystack"] +known-first-party = ["haystack_integrations"] [tool.ruff.flake8-tidy-imports] -ban-relative-imports = "all" +ban-relative-imports = "parents" [tool.ruff.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] [tool.coverage.run] -source_pkgs = ["google_vertex_haystack", "tests"] +source_pkgs = ["haystack_integrations", "tests"] branch = true parallel = true -omit = [ - "src/google_vertex_haystack/__about__.py", -] [tool.coverage.paths] -google_vertex_haystack = ["src/google_vertex_haystack", "*/google_vertex/src/google_vertex_haystack"] -tests = ["tests", "*/google_vertex_haystack/tests"] +google_vertex_haystack = ["src/"] +tests = ["tests"] [tool.coverage.report] exclude_lines = [ @@ -164,6 +164,7 @@ exclude_lines = [ module = [ "vertexai.*", "haystack.*", + "haystack_integrations.*", "pytest.*", "numpy.*", ] diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/__init__.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/__init__.py new file mode 100644 index
000000000..07c2a5260 --- /dev/null +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from .captioner import VertexAIImageCaptioner +from .chat.gemini import VertexAIGeminiChatGenerator +from .code_generator import VertexAICodeGenerator +from .gemini import VertexAIGeminiGenerator +from .image_generator import VertexAIImageGenerator +from .question_answering import VertexAIImageQA +from .text_generator import VertexAITextGenerator + +__all__ = [ + "VertexAICodeGenerator", + "VertexAIGeminiGenerator", + "VertexAIGeminiChatGenerator", + "VertexAIImageCaptioner", + "VertexAIImageGenerator", + "VertexAIImageQA", + "VertexAITextGenerator", +] diff --git a/integrations/google_vertex/src/google_vertex_haystack/generators/captioner.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/captioner.py similarity index 100% rename from integrations/google_vertex/src/google_vertex_haystack/generators/captioner.py rename to integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/captioner.py diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/__init__.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/integrations/google_vertex/src/google_vertex_haystack/generators/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py similarity index 100% rename from integrations/google_vertex/src/google_vertex_haystack/generators/chat/gemini.py rename to integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py diff --git 
a/integrations/google_vertex/src/google_vertex_haystack/generators/code_generator.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/code_generator.py similarity index 100% rename from integrations/google_vertex/src/google_vertex_haystack/generators/code_generator.py rename to integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/code_generator.py diff --git a/integrations/google_vertex/src/google_vertex_haystack/generators/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py similarity index 100% rename from integrations/google_vertex/src/google_vertex_haystack/generators/gemini.py rename to integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py diff --git a/integrations/google_vertex/src/google_vertex_haystack/generators/image_generator.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/image_generator.py similarity index 100% rename from integrations/google_vertex/src/google_vertex_haystack/generators/image_generator.py rename to integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/image_generator.py diff --git a/integrations/google_vertex/src/google_vertex_haystack/generators/question_answering.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/question_answering.py similarity index 100% rename from integrations/google_vertex/src/google_vertex_haystack/generators/question_answering.py rename to integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/question_answering.py diff --git a/integrations/google_vertex/src/google_vertex_haystack/generators/text_generator.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/text_generator.py similarity index 100% rename from 
integrations/google_vertex/src/google_vertex_haystack/generators/text_generator.py rename to integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/text_generator.py diff --git a/integrations/google_vertex/tests/test_captioner.py b/integrations/google_vertex/tests/test_captioner.py index bc7e4f829..26249dbee 100644 --- a/integrations/google_vertex/tests/test_captioner.py +++ b/integrations/google_vertex/tests/test_captioner.py @@ -2,11 +2,11 @@ from haystack.dataclasses.byte_stream import ByteStream -from google_vertex_haystack.generators.captioner import VertexAIImageCaptioner +from haystack_integrations.components.generators.google_vertex import VertexAIImageCaptioner -@patch("google_vertex_haystack.generators.captioner.vertexai") -@patch("google_vertex_haystack.generators.captioner.ImageTextModel") +@patch("haystack_integrations.components.generators.google_vertex.captioner.vertexai") +@patch("haystack_integrations.components.generators.google_vertex.captioner.ImageTextModel") def test_init(mock_model_class, mock_vertexai): captioner = VertexAIImageCaptioner( model="imagetext", project_id="myproject-123456", number_of_results=1, language="it" @@ -19,14 +19,14 @@ def test_init(mock_model_class, mock_vertexai): assert captioner._kwargs == {"number_of_results": 1, "language": "it"} -@patch("google_vertex_haystack.generators.captioner.vertexai") -@patch("google_vertex_haystack.generators.captioner.ImageTextModel") +@patch("haystack_integrations.components.generators.google_vertex.captioner.vertexai") +@patch("haystack_integrations.components.generators.google_vertex.captioner.ImageTextModel") def test_to_dict(_mock_model_class, _mock_vertexai): captioner = VertexAIImageCaptioner( model="imagetext", project_id="myproject-123456", number_of_results=1, language="it" ) assert captioner.to_dict() == { - "type": "google_vertex_haystack.generators.captioner.VertexAIImageCaptioner", + "type": 
"haystack_integrations.components.generators.google_vertex.captioner.VertexAIImageCaptioner", "init_parameters": { "model": "imagetext", "project_id": "myproject-123456", @@ -37,12 +37,12 @@ def test_to_dict(_mock_model_class, _mock_vertexai): } -@patch("google_vertex_haystack.generators.captioner.vertexai") -@patch("google_vertex_haystack.generators.captioner.ImageTextModel") +@patch("haystack_integrations.components.generators.google_vertex.captioner.vertexai") +@patch("haystack_integrations.components.generators.google_vertex.captioner.ImageTextModel") def test_from_dict(_mock_model_class, _mock_vertexai): captioner = VertexAIImageCaptioner.from_dict( { - "type": "google_vertex_haystack.generators.captioner.VertexAIImageCaptioner", + "type": "haystack_integrations.components.generators.google_vertex.captioner.VertexAIImageCaptioner", "init_parameters": { "model": "imagetext", "project_id": "myproject-123456", @@ -58,8 +58,8 @@ def test_from_dict(_mock_model_class, _mock_vertexai): assert captioner._model is not None -@patch("google_vertex_haystack.generators.captioner.vertexai") -@patch("google_vertex_haystack.generators.captioner.ImageTextModel") +@patch("haystack_integrations.components.generators.google_vertex.captioner.vertexai") +@patch("haystack_integrations.components.generators.google_vertex.captioner.ImageTextModel") def test_run_calls_get_captions(mock_model_class, _mock_vertexai): mock_model = Mock() mock_model_class.from_pretrained.return_value = mock_model diff --git a/integrations/google_vertex/tests/test_code_generator.py b/integrations/google_vertex/tests/test_code_generator.py index c2a2e5aa9..129954062 100644 --- a/integrations/google_vertex/tests/test_code_generator.py +++ b/integrations/google_vertex/tests/test_code_generator.py @@ -2,11 +2,11 @@ from vertexai.language_models import TextGenerationResponse -from google_vertex_haystack.generators.code_generator import VertexAICodeGenerator +from 
haystack_integrations.components.generators.google_vertex import VertexAICodeGenerator -@patch("google_vertex_haystack.generators.code_generator.vertexai") -@patch("google_vertex_haystack.generators.code_generator.CodeGenerationModel") +@patch("haystack_integrations.components.generators.google_vertex.code_generator.vertexai") +@patch("haystack_integrations.components.generators.google_vertex.code_generator.CodeGenerationModel") def test_init(mock_model_class, mock_vertexai): generator = VertexAICodeGenerator( model="code-bison", project_id="myproject-123456", candidate_count=3, temperature=0.5 @@ -19,14 +19,14 @@ def test_init(mock_model_class, mock_vertexai): assert generator._kwargs == {"candidate_count": 3, "temperature": 0.5} -@patch("google_vertex_haystack.generators.code_generator.vertexai") -@patch("google_vertex_haystack.generators.code_generator.CodeGenerationModel") +@patch("haystack_integrations.components.generators.google_vertex.code_generator.vertexai") +@patch("haystack_integrations.components.generators.google_vertex.code_generator.CodeGenerationModel") def test_to_dict(_mock_model_class, _mock_vertexai): generator = VertexAICodeGenerator( model="code-bison", project_id="myproject-123456", candidate_count=3, temperature=0.5 ) assert generator.to_dict() == { - "type": "google_vertex_haystack.generators.code_generator.VertexAICodeGenerator", + "type": "haystack_integrations.components.generators.google_vertex.code_generator.VertexAICodeGenerator", "init_parameters": { "model": "code-bison", "project_id": "myproject-123456", @@ -37,12 +37,12 @@ def test_to_dict(_mock_model_class, _mock_vertexai): } -@patch("google_vertex_haystack.generators.code_generator.vertexai") -@patch("google_vertex_haystack.generators.code_generator.CodeGenerationModel") +@patch("haystack_integrations.components.generators.google_vertex.code_generator.vertexai") +@patch("haystack_integrations.components.generators.google_vertex.code_generator.CodeGenerationModel") def 
test_from_dict(_mock_model_class, _mock_vertexai): generator = VertexAICodeGenerator.from_dict( { - "type": "google_vertex_haystack.generators.code_generator.VertexAICodeGenerator", + "type": "haystack_integrations.components.generators.google_vertex.code_generator.VertexAICodeGenerator", "init_parameters": { "model": "code-bison", "project_id": "myproject-123456", @@ -58,8 +58,8 @@ def test_from_dict(_mock_model_class, _mock_vertexai): assert generator._model is not None -@patch("google_vertex_haystack.generators.code_generator.vertexai") -@patch("google_vertex_haystack.generators.code_generator.CodeGenerationModel") +@patch("haystack_integrations.components.generators.google_vertex.code_generator.vertexai") +@patch("haystack_integrations.components.generators.google_vertex.code_generator.CodeGenerationModel") def test_run_calls_predict(mock_model_class, _mock_vertexai): mock_model = Mock() mock_model.predict.return_value = TextGenerationResponse("answer", None) diff --git a/integrations/google_vertex/tests/test_image_generator.py b/integrations/google_vertex/tests/test_image_generator.py index 1c5381a48..42cc0a0a3 100644 --- a/integrations/google_vertex/tests/test_image_generator.py +++ b/integrations/google_vertex/tests/test_image_generator.py @@ -2,11 +2,11 @@ from vertexai.preview.vision_models import ImageGenerationResponse -from google_vertex_haystack.generators.image_generator import VertexAIImageGenerator +from haystack_integrations.components.generators.google_vertex import VertexAIImageGenerator -@patch("google_vertex_haystack.generators.image_generator.vertexai") -@patch("google_vertex_haystack.generators.image_generator.ImageGenerationModel") +@patch("haystack_integrations.components.generators.google_vertex.image_generator.vertexai") +@patch("haystack_integrations.components.generators.google_vertex.image_generator.ImageGenerationModel") def test_init(mock_model_class, mock_vertexai): generator = VertexAIImageGenerator( model="imagetext", @@ -25,8 
+25,8 @@ def test_init(mock_model_class, mock_vertexai): } -@patch("google_vertex_haystack.generators.image_generator.vertexai") -@patch("google_vertex_haystack.generators.image_generator.ImageGenerationModel") +@patch("haystack_integrations.components.generators.google_vertex.image_generator.vertexai") +@patch("haystack_integrations.components.generators.google_vertex.image_generator.ImageGenerationModel") def test_to_dict(_mock_model_class, _mock_vertexai): generator = VertexAIImageGenerator( model="imagetext", @@ -35,7 +35,7 @@ def test_to_dict(_mock_model_class, _mock_vertexai): number_of_images=3, ) assert generator.to_dict() == { - "type": "google_vertex_haystack.generators.image_generator.VertexAIImageGenerator", + "type": "haystack_integrations.components.generators.google_vertex.image_generator.VertexAIImageGenerator", "init_parameters": { "model": "imagetext", "project_id": "myproject-123456", @@ -46,12 +46,12 @@ def test_to_dict(_mock_model_class, _mock_vertexai): } -@patch("google_vertex_haystack.generators.image_generator.vertexai") -@patch("google_vertex_haystack.generators.image_generator.ImageGenerationModel") +@patch("haystack_integrations.components.generators.google_vertex.image_generator.vertexai") +@patch("haystack_integrations.components.generators.google_vertex.image_generator.ImageGenerationModel") def test_from_dict(_mock_model_class, _mock_vertexai): generator = VertexAIImageGenerator.from_dict( { - "type": "google_vertex_haystack.generators.image_generator.VertexAIImageGenerator", + "type": "haystack_integrations.components.generators.google_vertex.image_generator.VertexAIImageGenerator", "init_parameters": { "model": "imagetext", "project_id": "myproject-123456", @@ -70,8 +70,8 @@ def test_from_dict(_mock_model_class, _mock_vertexai): } -@patch("google_vertex_haystack.generators.image_generator.vertexai") -@patch("google_vertex_haystack.generators.image_generator.ImageGenerationModel") 
+@patch("haystack_integrations.components.generators.google_vertex.image_generator.vertexai") +@patch("haystack_integrations.components.generators.google_vertex.image_generator.ImageGenerationModel") def test_run_calls_generate_images(mock_model_class, _mock_vertexai): mock_model = Mock() mock_model.generate_images.return_value = ImageGenerationResponse(images=[]) diff --git a/integrations/google_vertex/tests/test_question_answering.py b/integrations/google_vertex/tests/test_question_answering.py index 3495afcb2..3f414f0e0 100644 --- a/integrations/google_vertex/tests/test_question_answering.py +++ b/integrations/google_vertex/tests/test_question_answering.py @@ -2,11 +2,11 @@ from haystack.dataclasses.byte_stream import ByteStream -from google_vertex_haystack.generators.question_answering import VertexAIImageQA +from haystack_integrations.components.generators.google_vertex import VertexAIImageQA -@patch("google_vertex_haystack.generators.question_answering.vertexai") -@patch("google_vertex_haystack.generators.question_answering.ImageTextModel") +@patch("haystack_integrations.components.generators.google_vertex.question_answering.vertexai") +@patch("haystack_integrations.components.generators.google_vertex.question_answering.ImageTextModel") def test_init(mock_model_class, mock_vertexai): generator = VertexAIImageQA( model="imagetext", @@ -21,8 +21,8 @@ def test_init(mock_model_class, mock_vertexai): assert generator._kwargs == {"number_of_results": 3} -@patch("google_vertex_haystack.generators.question_answering.vertexai") -@patch("google_vertex_haystack.generators.question_answering.ImageTextModel") +@patch("haystack_integrations.components.generators.google_vertex.question_answering.vertexai") +@patch("haystack_integrations.components.generators.google_vertex.question_answering.ImageTextModel") def test_to_dict(_mock_model_class, _mock_vertexai): generator = VertexAIImageQA( model="imagetext", @@ -30,7 +30,7 @@ def test_to_dict(_mock_model_class, 
_mock_vertexai): number_of_results=3, ) assert generator.to_dict() == { - "type": "google_vertex_haystack.generators.question_answering.VertexAIImageQA", + "type": "haystack_integrations.components.generators.google_vertex.question_answering.VertexAIImageQA", "init_parameters": { "model": "imagetext", "project_id": "myproject-123456", @@ -40,12 +40,12 @@ def test_to_dict(_mock_model_class, _mock_vertexai): } -@patch("google_vertex_haystack.generators.question_answering.vertexai") -@patch("google_vertex_haystack.generators.question_answering.ImageTextModel") +@patch("haystack_integrations.components.generators.google_vertex.question_answering.vertexai") +@patch("haystack_integrations.components.generators.google_vertex.question_answering.ImageTextModel") def test_from_dict(_mock_model_class, _mock_vertexai): generator = VertexAIImageQA.from_dict( { - "type": "google_vertex_haystack.generators.question_answering.VertexAIImageQA", + "type": "haystack_integrations.components.generators.google_vertex.question_answering.VertexAIImageQA", "init_parameters": { "model": "imagetext", "project_id": "myproject-123456", @@ -60,8 +60,8 @@ def test_from_dict(_mock_model_class, _mock_vertexai): assert generator._kwargs == {"number_of_results": 3} -@patch("google_vertex_haystack.generators.question_answering.vertexai") -@patch("google_vertex_haystack.generators.question_answering.ImageTextModel") +@patch("haystack_integrations.components.generators.google_vertex.question_answering.vertexai") +@patch("haystack_integrations.components.generators.google_vertex.question_answering.ImageTextModel") def test_run_calls_ask_question(mock_model_class, _mock_vertexai): mock_model = Mock() mock_model.ask_question.return_value = [] diff --git a/integrations/google_vertex/tests/test_text_generator.py b/integrations/google_vertex/tests/test_text_generator.py index f2edbfc3b..3e5248dc7 100644 --- a/integrations/google_vertex/tests/test_text_generator.py +++ 
b/integrations/google_vertex/tests/test_text_generator.py @@ -2,11 +2,11 @@ from vertexai.language_models import GroundingSource -from google_vertex_haystack.generators.text_generator import VertexAITextGenerator +from haystack_integrations.components.generators.google_vertex import VertexAITextGenerator -@patch("google_vertex_haystack.generators.text_generator.vertexai") -@patch("google_vertex_haystack.generators.text_generator.TextGenerationModel") +@patch("haystack_integrations.components.generators.google_vertex.text_generator.vertexai") +@patch("haystack_integrations.components.generators.google_vertex.text_generator.TextGenerationModel") def test_init(mock_model_class, mock_vertexai): grounding_source = GroundingSource.VertexAISearch("1234", "us-central-1") generator = VertexAITextGenerator( @@ -20,15 +20,15 @@ def test_init(mock_model_class, mock_vertexai): assert generator._kwargs == {"temperature": 0.2, "grounding_source": grounding_source} -@patch("google_vertex_haystack.generators.text_generator.vertexai") -@patch("google_vertex_haystack.generators.text_generator.TextGenerationModel") +@patch("haystack_integrations.components.generators.google_vertex.text_generator.vertexai") +@patch("haystack_integrations.components.generators.google_vertex.text_generator.TextGenerationModel") def test_to_dict(_mock_model_class, _mock_vertexai): grounding_source = GroundingSource.VertexAISearch("1234", "us-central-1") generator = VertexAITextGenerator( model="text-bison", project_id="myproject-123456", temperature=0.2, grounding_source=grounding_source ) assert generator.to_dict() == { - "type": "google_vertex_haystack.generators.text_generator.VertexAITextGenerator", + "type": "haystack_integrations.components.generators.google_vertex.text_generator.VertexAITextGenerator", "init_parameters": { "model": "text-bison", "project_id": "myproject-123456", @@ -47,12 +47,12 @@ def test_to_dict(_mock_model_class, _mock_vertexai): } 
-@patch("google_vertex_haystack.generators.text_generator.vertexai") -@patch("google_vertex_haystack.generators.text_generator.TextGenerationModel") +@patch("haystack_integrations.components.generators.google_vertex.text_generator.vertexai") +@patch("haystack_integrations.components.generators.google_vertex.text_generator.TextGenerationModel") def test_from_dict(_mock_model_class, _mock_vertexai): generator = VertexAITextGenerator.from_dict( { - "type": "google_vertex_haystack.generators.text_generator.VertexAITextGenerator", + "type": "haystack_integrations.components.generators.google_vertex.text_generator.VertexAITextGenerator", "init_parameters": { "model": "text-bison", "project_id": "myproject-123456", @@ -79,8 +79,8 @@ def test_from_dict(_mock_model_class, _mock_vertexai): } -@patch("google_vertex_haystack.generators.text_generator.vertexai") -@patch("google_vertex_haystack.generators.text_generator.TextGenerationModel") +@patch("haystack_integrations.components.generators.google_vertex.text_generator.vertexai") +@patch("haystack_integrations.components.generators.google_vertex.text_generator.TextGenerationModel") def test_run_calls_get_captions(mock_model_class, _mock_vertexai): mock_model = Mock() mock_model.predict.return_value = MagicMock() diff --git a/integrations/instructor_embedders/instructor_embedders_haystack/__init__.py b/integrations/instructor_embedders/instructor_embedders_haystack/__init__.py deleted file mode 100644 index 88e2e9df2..000000000 --- a/integrations/instructor_embedders/instructor_embedders_haystack/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# SPDX-FileCopyrightText: 2023-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 -from instructor_embedders_haystack.instructor_document_embedder import InstructorDocumentEmbedder -from instructor_embedders_haystack.instructor_text_embedder import InstructorTextEmbedder - -__all__ = ["InstructorDocumentEmbedder", "InstructorTextEmbedder"] diff --git 
a/integrations/instructor_embedders/pyproject.toml b/integrations/instructor_embedders/pyproject.toml index 63fb9703b..c8a591b69 100644 --- a/integrations/instructor_embedders/pyproject.toml +++ b/integrations/instructor_embedders/pyproject.toml @@ -38,7 +38,7 @@ dependencies = [ "requests>=2.26.0", "scikit_learn>=1.0.2", "scipy", - "sentence_transformers>=2.2.0", + "sentence_transformers>=2.2.0,<2.3.0", "torch", "tqdm", "rich", @@ -54,6 +54,9 @@ Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/m Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/instructor_embedders" +[tool.hatch.build.targets.wheel] +packages = ["src/haystack_integrations"] + [tool.hatch.version] source = "vcs" tag-pattern = 'integrations\/instructor_embedders-v(?P<version>.*)' @@ -81,7 +84,7 @@ dependencies = [ "ruff>=0.0.243", ] [tool.hatch.envs.lint.scripts] -typing = "mypy --install-types --non-interactive {args:instructor_embedders_haystack tests}" +typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = [ "ruff {args:.}", "black --check --diff {args:.}", @@ -99,7 +102,6 @@ all = [ [tool.coverage.run] branch = true parallel = true -omit = ["instructor_embedders/__about__.py"] [tool.coverage.report] exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] @@ -152,7 +154,7 @@ unfixable = [ known-first-party = ["instructor_embedders"] [tool.ruff.flake8-tidy-imports] -ban-relative-imports = "all" +ban-relative-imports = "parents" [tool.ruff.per-file-ignores] # Tests can use magic values, assertions, and relative imports @@ -172,6 +174,7 @@ module = [ "instructor_embedders_haystack.*", "InstructorEmbedding.*", "haystack.*", + "haystack_integrations.*", "pytest.*", "numpy.*", ] diff --git
a/integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/__init__.py b/integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/__init__.py new file mode 100644 index 000000000..f68f20a81 --- /dev/null +++ b/integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from .instructor_document_embedder import InstructorDocumentEmbedder +from .instructor_text_embedder import InstructorTextEmbedder + +__all__ = ["InstructorDocumentEmbedder", "InstructorTextEmbedder"] diff --git a/integrations/cohere/src/cohere_haystack/embedders/__init__.py b/integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/embedding_backend/__init__.py similarity index 100% rename from integrations/cohere/src/cohere_haystack/embedders/__init__.py rename to integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/embedding_backend/__init__.py diff --git a/integrations/instructor_embedders/instructor_embedders_haystack/embedding_backend/instructor_backend.py b/integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/embedding_backend/instructor_backend.py similarity index 100% rename from integrations/instructor_embedders/instructor_embedders_haystack/embedding_backend/instructor_backend.py rename to integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/embedding_backend/instructor_backend.py diff --git a/integrations/instructor_embedders/instructor_embedders_haystack/instructor_document_embedder.py b/integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/instructor_document_embedder.py similarity index 98% rename 
from integrations/instructor_embedders/instructor_embedders_haystack/instructor_document_embedder.py rename to integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/instructor_document_embedder.py index 7a40f43cd..34912a2a3 100644 --- a/integrations/instructor_embedders/instructor_embedders_haystack/instructor_document_embedder.py +++ b/integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/instructor_document_embedder.py @@ -5,7 +5,7 @@ from haystack import Document, component, default_from_dict, default_to_dict -from instructor_embedders_haystack.embedding_backend.instructor_backend import _InstructorEmbeddingBackendFactory +from .embedding_backend.instructor_backend import _InstructorEmbeddingBackendFactory @component diff --git a/integrations/instructor_embedders/instructor_embedders_haystack/instructor_text_embedder.py b/integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/instructor_text_embedder.py similarity index 97% rename from integrations/instructor_embedders/instructor_embedders_haystack/instructor_text_embedder.py rename to integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/instructor_text_embedder.py index 5a2c66e65..39b8d6a29 100644 --- a/integrations/instructor_embedders/instructor_embedders_haystack/instructor_text_embedder.py +++ b/integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/instructor_text_embedder.py @@ -5,7 +5,7 @@ from haystack import component, default_from_dict, default_to_dict -from instructor_embedders_haystack.embedding_backend.instructor_backend import _InstructorEmbeddingBackendFactory +from .embedding_backend.instructor_backend import _InstructorEmbeddingBackendFactory @component diff --git a/integrations/instructor_embedders/tests/test_instructor_backend.py 
b/integrations/instructor_embedders/tests/test_instructor_backend.py index 27e31317a..f3fd1653a 100644 --- a/integrations/instructor_embedders/tests/test_instructor_backend.py +++ b/integrations/instructor_embedders/tests/test_instructor_backend.py @@ -1,9 +1,13 @@ from unittest.mock import patch -from instructor_embedders_haystack.embedding_backend.instructor_backend import _InstructorEmbeddingBackendFactory +from haystack_integrations.components.embedders.instructor_embedders.embedding_backend.instructor_backend import ( + _InstructorEmbeddingBackendFactory, +) -@patch("instructor_embedders_haystack.embedding_backend.instructor_backend.INSTRUCTOR") +@patch( + "haystack_integrations.components.embedders.instructor_embedders.embedding_backend.instructor_backend.INSTRUCTOR" +) def test_factory_behavior(mock_instructor): # noqa: ARG001 embedding_backend = _InstructorEmbeddingBackendFactory.get_embedding_backend( model_name_or_path="hkunlp/instructor-large", device="cpu" @@ -20,7 +24,9 @@ def test_factory_behavior(mock_instructor): # noqa: ARG001 _InstructorEmbeddingBackendFactory._instances = {} -@patch("instructor_embedders_haystack.embedding_backend.instructor_backend.INSTRUCTOR") +@patch( + "haystack_integrations.components.embedders.instructor_embedders.embedding_backend.instructor_backend.INSTRUCTOR" +) def test_model_initialization(mock_instructor): _InstructorEmbeddingBackendFactory.get_embedding_backend( model_name_or_path="hkunlp/instructor-base", device="cpu", use_auth_token="huggingface_auth_token" @@ -32,7 +38,9 @@ def test_model_initialization(mock_instructor): _InstructorEmbeddingBackendFactory._instances = {} -@patch("instructor_embedders_haystack.embedding_backend.instructor_backend.INSTRUCTOR") +@patch( + "haystack_integrations.components.embedders.instructor_embedders.embedding_backend.instructor_backend.INSTRUCTOR" +) def test_embedding_function_with_kwargs(mock_instructor): # noqa: ARG001 embedding_backend = 
_InstructorEmbeddingBackendFactory.get_embedding_backend( model_name_or_path="hkunlp/instructor-base" diff --git a/integrations/instructor_embedders/tests/test_instructor_document_embedder.py b/integrations/instructor_embedders/tests/test_instructor_document_embedder.py index b1d0d8fe6..4f01c1742 100644 --- a/integrations/instructor_embedders/tests/test_instructor_document_embedder.py +++ b/integrations/instructor_embedders/tests/test_instructor_document_embedder.py @@ -3,8 +3,7 @@ import numpy as np import pytest from haystack import Document - -from instructor_embedders_haystack.instructor_document_embedder import InstructorDocumentEmbedder +from haystack_integrations.components.embedders.instructor_embedders import InstructorDocumentEmbedder class TestInstructorDocumentEmbedder: @@ -55,7 +54,7 @@ def test_to_dict(self): embedder = InstructorDocumentEmbedder(model="hkunlp/instructor-base") embedder_dict = embedder.to_dict() assert embedder_dict == { - "type": "instructor_embedders_haystack.instructor_document_embedder.InstructorDocumentEmbedder", + "type": "haystack_integrations.components.embedders.instructor_embedders.instructor_document_embedder.InstructorDocumentEmbedder", # noqa "init_parameters": { "model": "hkunlp/instructor-base", "device": "cpu", @@ -86,7 +85,7 @@ def test_to_dict_with_custom_init_parameters(self): ) embedder_dict = embedder.to_dict() assert embedder_dict == { - "type": "instructor_embedders_haystack.instructor_document_embedder.InstructorDocumentEmbedder", + "type": "haystack_integrations.components.embedders.instructor_embedders.instructor_document_embedder.InstructorDocumentEmbedder", # noqa "init_parameters": { "model": "hkunlp/instructor-base", "device": "cuda", @@ -105,7 +104,7 @@ def test_from_dict(self): Test deserialization of InstructorDocumentEmbedder from a dictionary, using default initialization parameters. 
""" embedder_dict = { - "type": "instructor_embedders_haystack.instructor_document_embedder.InstructorDocumentEmbedder", + "type": "haystack_integrations.components.embedders.instructor_embedders.instructor_document_embedder.InstructorDocumentEmbedder", # noqa "init_parameters": { "model": "hkunlp/instructor-base", "device": "cpu", @@ -134,7 +133,7 @@ def test_from_dict_with_custom_init_parameters(self): Test deserialization of InstructorDocumentEmbedder from a dictionary, using custom initialization parameters. """ embedder_dict = { - "type": "instructor_embedders_haystack.instructor_document_embedder.InstructorDocumentEmbedder", + "type": "haystack_integrations.components.embedders.instructor_embedders.instructor_document_embedder.InstructorDocumentEmbedder", # noqa "init_parameters": { "model": "hkunlp/instructor-base", "device": "cuda", @@ -158,7 +157,9 @@ def test_from_dict_with_custom_init_parameters(self): assert embedder.meta_fields_to_embed == ["test_field"] assert embedder.embedding_separator == " | " - @patch("instructor_embedders_haystack.instructor_document_embedder._InstructorEmbeddingBackendFactory") + @patch( + "haystack_integrations.components.embedders.instructor_embedders.instructor_document_embedder._InstructorEmbeddingBackendFactory" + ) def test_warmup(self, mocked_factory): """ Test for checking embedder instances after warm-up. @@ -170,7 +171,9 @@ def test_warmup(self, mocked_factory): model_name_or_path="hkunlp/instructor-base", device="cpu", use_auth_token=None ) - @patch("instructor_embedders_haystack.instructor_document_embedder._InstructorEmbeddingBackendFactory") + @patch( + "haystack_integrations.components.embedders.instructor_embedders.instructor_document_embedder._InstructorEmbeddingBackendFactory" + ) def test_warmup_does_not_reload(self, mocked_factory): """ Test for checking backend instances after multiple warm-ups. 
diff --git a/integrations/instructor_embedders/tests/test_instructor_text_embedder.py b/integrations/instructor_embedders/tests/test_instructor_text_embedder.py index bc00f7348..edede9f39 100644 --- a/integrations/instructor_embedders/tests/test_instructor_text_embedder.py +++ b/integrations/instructor_embedders/tests/test_instructor_text_embedder.py @@ -2,8 +2,7 @@ import numpy as np import pytest - -from instructor_embedders_haystack.instructor_text_embedder import InstructorTextEmbedder +from haystack_integrations.components.embedders.instructor_embedders import InstructorTextEmbedder class TestInstructorTextEmbedder: @@ -48,7 +47,7 @@ def test_to_dict(self): embedder = InstructorTextEmbedder(model="hkunlp/instructor-base") embedder_dict = embedder.to_dict() assert embedder_dict == { - "type": "instructor_embedders_haystack.instructor_text_embedder.InstructorTextEmbedder", + "type": "haystack_integrations.components.embedders.instructor_embedders.instructor_text_embedder.InstructorTextEmbedder", # noqa "init_parameters": { "model": "hkunlp/instructor-base", "device": "cpu", @@ -75,7 +74,7 @@ def test_to_dict_with_custom_init_parameters(self): ) embedder_dict = embedder.to_dict() assert embedder_dict == { - "type": "instructor_embedders_haystack.instructor_text_embedder.InstructorTextEmbedder", + "type": "haystack_integrations.components.embedders.instructor_embedders.instructor_text_embedder.InstructorTextEmbedder", # noqa "init_parameters": { "model": "hkunlp/instructor-base", "device": "cuda", @@ -92,7 +91,7 @@ def test_from_dict(self): Test deserialization of InstructorTextEmbedder from a dictionary, using default initialization parameters. 
""" embedder_dict = { - "type": "instructor_embedders_haystack.instructor_text_embedder.InstructorTextEmbedder", + "type": "haystack_integrations.components.embedders.instructor_embedders.instructor_text_embedder.InstructorTextEmbedder", # noqa "init_parameters": { "model": "hkunlp/instructor-base", "device": "cpu", @@ -117,7 +116,7 @@ def test_from_dict_with_custom_init_parameters(self): Test deserialization of InstructorTextEmbedder from a dictionary, using custom initialization parameters. """ embedder_dict = { - "type": "instructor_embedders_haystack.instructor_text_embedder.InstructorTextEmbedder", + "type": "haystack_integrations.components.embedders.instructor_embedders.instructor_text_embedder.InstructorTextEmbedder", # noqa "init_parameters": { "model": "hkunlp/instructor-base", "device": "cuda", @@ -137,7 +136,9 @@ def test_from_dict_with_custom_init_parameters(self): assert embedder.progress_bar is False assert embedder.normalize_embeddings is True - @patch("instructor_embedders_haystack.instructor_text_embedder._InstructorEmbeddingBackendFactory") + @patch( + "haystack_integrations.components.embedders.instructor_embedders.instructor_text_embedder._InstructorEmbeddingBackendFactory" + ) def test_warmup(self, mocked_factory): """ Test for checking embedder instances after warm-up. @@ -149,7 +150,9 @@ def test_warmup(self, mocked_factory): model_name_or_path="hkunlp/instructor-base", device="cpu", use_auth_token=None ) - @patch("instructor_embedders_haystack.instructor_text_embedder._InstructorEmbeddingBackendFactory") + @patch( + "haystack_integrations.components.embedders.instructor_embedders.instructor_text_embedder._InstructorEmbeddingBackendFactory" + ) def test_warmup_does_not_reload(self, mocked_factory): """ Test for checking backend instances after multiple warm-ups. 
diff --git a/integrations/jina/pyproject.toml b/integrations/jina/pyproject.toml index 0fa01a7ab..1136db797 100644 --- a/integrations/jina/pyproject.toml +++ b/integrations/jina/pyproject.toml @@ -31,6 +31,9 @@ Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/m Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/jina" +[tool.hatch.build.targets.wheel] +packages = ["src/haystack_integrations"] + [tool.hatch.version] source = "vcs" tag-pattern = 'integrations\/jina-v(?P<version>.*)' @@ -67,7 +70,7 @@ dependencies = [ "ruff>=0.0.243", ] [tool.hatch.envs.lint.scripts] -typing = "mypy --install-types --non-interactive {args:src/jina_haystack tests}" +typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = [ "ruff {args:.}", "black --check --diff {args:.}", @@ -133,7 +136,7 @@ unfixable = [ known-first-party = ["jina_haystack"] [tool.ruff.flake8-tidy-imports] -ban-relative-imports = "all" +ban-relative-imports = "parents" [tool.ruff.per-file-ignores] # Tests can use magic values, assertions, and relative imports @@ -143,12 +146,9 @@ ban-relative-imports = "all" source_pkgs = ["jina_haystack", "tests"] branch = true parallel = true -omit = [ - "src/jina_haystack/__about__.py", -] [tool.coverage.paths] -jina_haystack = ["src/jina_haystack", "*/jina-haystack/src/jina_haystack"] +jina_haystack = ["src"] tests = ["tests", "*/jina-haystack/tests"] [tool.coverage.report] @@ -161,6 +161,7 @@ exclude_lines = [ [[tool.mypy.overrides]] module = [ "haystack.*", + "haystack_integrations.*", "pytest.*" ] ignore_missing_imports = true diff --git a/integrations/jina/src/jina_haystack/__init__.py b/integrations/jina/src/haystack_integrations/components/embedders/jina/__init__.py similarity index 57% rename from integrations/jina/src/jina_haystack/__init__.py rename to 
integrations/jina/src/haystack_integrations/components/embedders/jina/__init__.py index 581b23df5..c98f63398 100644 --- a/integrations/jina/src/jina_haystack/__init__.py +++ b/integrations/jina/src/haystack_integrations/components/embedders/jina/__init__.py @@ -1,8 +1,7 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 - -from jina_haystack.document_embedder import JinaDocumentEmbedder -from jina_haystack.text_embedder import JinaTextEmbedder +from .document_embedder import JinaDocumentEmbedder +from .text_embedder import JinaTextEmbedder __all__ = ["JinaDocumentEmbedder", "JinaTextEmbedder"] diff --git a/integrations/jina/src/jina_haystack/document_embedder.py b/integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py similarity index 100% rename from integrations/jina/src/jina_haystack/document_embedder.py rename to integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py diff --git a/integrations/jina/src/jina_haystack/text_embedder.py b/integrations/jina/src/haystack_integrations/components/embedders/jina/text_embedder.py similarity index 100% rename from integrations/jina/src/jina_haystack/text_embedder.py rename to integrations/jina/src/haystack_integrations/components/embedders/jina/text_embedder.py diff --git a/integrations/jina/tests/test_document_embedder.py b/integrations/jina/tests/test_document_embedder.py index 43b6930c5..4dd91860e 100644 --- a/integrations/jina/tests/test_document_embedder.py +++ b/integrations/jina/tests/test_document_embedder.py @@ -7,8 +7,7 @@ import pytest import requests from haystack import Document - -from jina_haystack import JinaDocumentEmbedder +from haystack_integrations.components.embedders.jina import JinaDocumentEmbedder def mock_session_post_response(*args, **kwargs): # noqa: ARG001 @@ -65,7 +64,7 @@ def test_to_dict(self): component = JinaDocumentEmbedder(api_key="fake-api-key") data = component.to_dict() 
assert data == { - "type": "jina_haystack.document_embedder.JinaDocumentEmbedder", + "type": "haystack_integrations.components.embedders.jina.document_embedder.JinaDocumentEmbedder", "init_parameters": { "model": "jina-embeddings-v2-base-en", "prefix": "", @@ -90,7 +89,7 @@ def test_to_dict_with_custom_init_parameters(self): ) data = component.to_dict() assert data == { - "type": "jina_haystack.document_embedder.JinaDocumentEmbedder", + "type": "haystack_integrations.components.embedders.jina.document_embedder.JinaDocumentEmbedder", "init_parameters": { "model": "model", "prefix": "prefix", diff --git a/integrations/jina/tests/test_text_embedder.py b/integrations/jina/tests/test_text_embedder.py index e2b68603d..a4f6fd934 100644 --- a/integrations/jina/tests/test_text_embedder.py +++ b/integrations/jina/tests/test_text_embedder.py @@ -6,8 +6,7 @@ import pytest import requests - -from jina_haystack import JinaTextEmbedder +from haystack_integrations.components.embedders.jina import JinaTextEmbedder class TestJinaTextEmbedder: @@ -39,7 +38,7 @@ def test_to_dict(self): component = JinaTextEmbedder(api_key="fake-api-key") data = component.to_dict() assert data == { - "type": "jina_haystack.text_embedder.JinaTextEmbedder", + "type": "haystack_integrations.components.embedders.jina.text_embedder.JinaTextEmbedder", "init_parameters": { "model": "jina-embeddings-v2-base-en", "prefix": "", @@ -56,7 +55,7 @@ def test_to_dict_with_custom_init_parameters(self): ) data = component.to_dict() assert data == { - "type": "jina_haystack.text_embedder.JinaTextEmbedder", + "type": "haystack_integrations.components.embedders.jina.text_embedder.JinaTextEmbedder", "init_parameters": { "model": "model", "prefix": "prefix", diff --git a/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/generator.py b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/generator.py index 8ae482310..619a61cbe 100644 --- 
a/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/generator.py +++ b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/generator.py @@ -17,7 +17,7 @@ class LlamaCppGenerator: Usage example: ```python from llama_cpp_haystack import LlamaCppGenerator - generator = LlamaCppGenerator(model_path="zephyr-7b-beta.Q4_0.gguf", n_ctx=2048, n_batch=512) + generator = LlamaCppGenerator(model="zephyr-7b-beta.Q4_0.gguf", n_ctx=2048, n_batch=512) print(generator.run("Who is the best American actor?", generation_kwargs={"max_tokens": 128})) # {'replies': ['John Cusack'], 'meta': [{"object": "text_completion", ...}]} @@ -26,23 +26,23 @@ class LlamaCppGenerator: def __init__( self, - model_path: str, + model: str, n_ctx: Optional[int] = 0, n_batch: Optional[int] = 512, model_kwargs: Optional[Dict[str, Any]] = None, generation_kwargs: Optional[Dict[str, Any]] = None, ): """ - :param model_path: The path of a quantized model for text generation, + :param model: The path of a quantized model for text generation, for example, "zephyr-7b-beta.Q4_0.gguf". - If the model_path is also specified in the `model_kwargs`, this parameter will be ignored. + If the model path is also specified in the `model_kwargs`, this parameter will be ignored. :param n_ctx: The number of tokens in the context. When set to 0, the context will be taken from the model. If the n_ctx is also specified in the `model_kwargs`, this parameter will be ignored. :param n_batch: Prompt processing maximum batch size. Defaults to 512. If the n_batch is also specified in the `model_kwargs`, this parameter will be ignored. :param model_kwargs: Dictionary containing keyword arguments used to initialize the LLM for text generation. These keyword arguments provide fine-grained control over the model loading. - In case of duplication, these kwargs override `model_path`, `n_ctx`, and `n_batch` init parameters. 
+ In case of duplication, these kwargs override `model`, `n_ctx`, and `n_batch` init parameters. See Llama.cpp's [documentation](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__init__) for more information on the available kwargs. :param generation_kwargs: A dictionary containing keyword arguments to customize text generation. @@ -56,11 +56,11 @@ def __init__( # check if the huggingface_pipeline_kwargs contain the essential parameters # otherwise, populate them with values from init parameters - model_kwargs.setdefault("model_path", model_path) + model_kwargs.setdefault("model_path", model) model_kwargs.setdefault("n_ctx", n_ctx) model_kwargs.setdefault("n_batch", n_batch) - self.model_path = model_path + self.model_path = model self.n_ctx = n_ctx self.n_batch = n_batch self.model_kwargs = model_kwargs diff --git a/integrations/llama_cpp/tests/test_generator.py b/integrations/llama_cpp/tests/test_generator.py index 0b95c03a4..04b8339e5 100644 --- a/integrations/llama_cpp/tests/test_generator.py +++ b/integrations/llama_cpp/tests/test_generator.py @@ -40,14 +40,14 @@ def generator(self, model_path, capsys): download_file(ggml_model_path, str(model_path / filename), capsys) model_path = str(model_path / filename) - generator = LlamaCppGenerator(model_path=model_path, n_ctx=128, n_batch=128) + generator = LlamaCppGenerator(model=model_path, n_ctx=128, n_batch=128) generator.warm_up() return generator @pytest.fixture def generator_mock(self): mock_model = MagicMock() - generator = LlamaCppGenerator(model_path="test_model.gguf", n_ctx=2048, n_batch=512) + generator = LlamaCppGenerator(model="test_model.gguf", n_ctx=2048, n_batch=512) generator.model = mock_model return generator, mock_model @@ -55,7 +55,7 @@ def test_default_init(self): """ Test default initialization parameters. 
""" - generator = LlamaCppGenerator(model_path="test_model.gguf") + generator = LlamaCppGenerator(model="test_model.gguf") assert generator.model_path == "test_model.gguf" assert generator.n_ctx == 0 @@ -68,7 +68,7 @@ def test_custom_init(self): Test custom initialization parameters. """ generator = LlamaCppGenerator( - model_path="test_model.gguf", + model="test_model.gguf", n_ctx=2048, n_batch=512, ) @@ -84,7 +84,7 @@ def test_ignores_model_path_if_specified_in_model_kwargs(self): Test that model_path is ignored if already specified in model_kwargs. """ generator = LlamaCppGenerator( - model_path="test_model.gguf", + model="test_model.gguf", n_ctx=512, n_batch=512, model_kwargs={"model_path": "other_model.gguf"}, @@ -95,25 +95,21 @@ def test_ignores_n_ctx_if_specified_in_model_kwargs(self): """ Test that n_ctx is ignored if already specified in model_kwargs. """ - generator = LlamaCppGenerator( - model_path="test_model.gguf", n_ctx=512, n_batch=512, model_kwargs={"n_ctx": 1024} - ) + generator = LlamaCppGenerator(model="test_model.gguf", n_ctx=512, n_batch=512, model_kwargs={"n_ctx": 1024}) assert generator.model_kwargs["n_ctx"] == 1024 def test_ignores_n_batch_if_specified_in_model_kwargs(self): """ Test that n_batch is ignored if already specified in model_kwargs. """ - generator = LlamaCppGenerator( - model_path="test_model.gguf", n_ctx=512, n_batch=512, model_kwargs={"n_batch": 1024} - ) + generator = LlamaCppGenerator(model="test_model.gguf", n_ctx=512, n_batch=512, model_kwargs={"n_batch": 1024}) assert generator.model_kwargs["n_batch"] == 1024 def test_raises_error_without_warm_up(self): """ Test that the generator raises an error if warm_up() is not called before running. 
""" - generator = LlamaCppGenerator(model_path="test_model.gguf", n_ctx=512, n_batch=512) + generator = LlamaCppGenerator(model="test_model.gguf", n_ctx=512, n_batch=512) with pytest.raises(RuntimeError): generator.run("What is the capital of China?") diff --git a/integrations/ollama/pydoc/config.yml b/integrations/ollama/pydoc/config.yml new file mode 100644 index 000000000..768694991 --- /dev/null +++ b/integrations/ollama/pydoc/config.yml @@ -0,0 +1,29 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../src] + modules: [ + "haystack_integrations.components.generators.ollama.generator", + "haystack_integrations.components.generators.ollama.chat.chat_generator" + ] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + excerpt: Ollama integration for Haystack + category_slug: haystack-integrations + title: Ollama + slug: integrations-ollama + order: 120 + markdown: + descriptive_class_title: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: _readme_ollama.md \ No newline at end of file diff --git a/integrations/ollama/pyproject.toml b/integrations/ollama/pyproject.toml index 69cc2ed16..e3bb738b6 100644 --- a/integrations/ollama/pyproject.toml +++ b/integrations/ollama/pyproject.toml @@ -48,6 +48,7 @@ git_describe_command = 'git describe --tags --match="integrations/ollama-v[0-9]* dependencies = [ "coverage[toml]>=6.5", "pytest", + "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" @@ -60,6 +61,9 @@ cov = [ "test-cov", "cov-report", ] +docs = [ + "pydoc-markdown pydoc/config.yml" +] [[tool.hatch.envs.all.matrix]] python = ["3.8", "3.9", "3.10", "3.11", "3.12"] diff --git 
a/integrations/opensearch/pydoc/config.yml b/integrations/opensearch/pydoc/config.yml new file mode 100644 index 000000000..3e07f6625 --- /dev/null +++ b/integrations/opensearch/pydoc/config.yml @@ -0,0 +1,31 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../src] + modules: [ + "haystack_integrations.components.retrievers.opensearch.bm25_retriever", + "haystack_integrations.components.retrievers.opensearch.embedding_retriever", + "haystack_integrations.document_stores.opensearch.document_store", + "haystack_integrations.document_stores.opensearch.filters", + ] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + excerpt: OpenSearch integration for Haystack + category_slug: haystack-integrations + title: OpenSearch + slug: integrations-opensearch + order: 130 + markdown: + descriptive_class_title: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: _readme_opensearch.md diff --git a/integrations/opensearch/pyproject.toml b/integrations/opensearch/pyproject.toml index 3edd544a2..794fa73fa 100644 --- a/integrations/opensearch/pyproject.toml +++ b/integrations/opensearch/pyproject.toml @@ -49,6 +49,7 @@ dependencies = [ "coverage[toml]>=6.5", "pytest", "pytest-xdist", + "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" @@ -62,6 +63,10 @@ cov = [ "cov-report", ] +docs = [ + "pydoc-markdown pydoc/config.yml" +] + [[tool.hatch.envs.all.matrix]] python = ["3.8", "3.9", "3.10", "3.11"] diff --git a/integrations/pgvector/LICENSE.txt b/integrations/pgvector/LICENSE.txt new file mode 100644 index 000000000..137069b82 --- /dev/null +++ b/integrations/pgvector/LICENSE.txt @@ -0,0 +1,73 @@ +Apache License +Version 2.0, 
January 2004 +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + +"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. + +"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. + +"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. 
For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. + +"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." + +"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: + + (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed 
as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. + + You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. 
+ +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + +To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
diff --git a/integrations/pgvector/README.md b/integrations/pgvector/README.md new file mode 100644 index 000000000..277c732f4 --- /dev/null +++ b/integrations/pgvector/README.md @@ -0,0 +1,31 @@ +# pgvector-haystack + +[![PyPI - Version](https://img.shields.io/pypi/v/pgvector-haystack.svg)](https://pypi.org/project/pgvector-haystack) +[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/pgvector-haystack.svg)](https://pypi.org/project/pgvector-haystack) + +--- + +**Table of Contents** + +- [pgvector-haystack](#pgvector-haystack) + - [Installation](#installation) + - [Testing](#testing) + - [License](#license) + +## Installation + +```console +pip install pgvector-haystack +``` + +## Testing + +TODO + +```console +hatch run test +``` + +## License + +`pgvector-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. diff --git a/integrations/pgvector/pydoc/config.yml b/integrations/pgvector/pydoc/config.yml new file mode 100644 index 000000000..79974b4a1 --- /dev/null +++ b/integrations/pgvector/pydoc/config.yml @@ -0,0 +1,30 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../src] + modules: [ + "haystack_integrations.components.retrievers.pgvector.embedding_retriever", + "haystack_integrations.document_stores.pgvector.document_store", + "haystack_integrations.document_stores.pgvector.filters", + ] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + excerpt: Pgvector integration for Haystack + category_slug: haystack-integrations + title: Pgvector + slug: integrations-pgvector + order: 140 + markdown: + descriptive_class_title: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: 
_readme_pgvector.md diff --git a/integrations/pgvector/pyproject.toml b/integrations/pgvector/pyproject.toml new file mode 100644 index 000000000..10ef5d314 --- /dev/null +++ b/integrations/pgvector/pyproject.toml @@ -0,0 +1,182 @@ +[build-system] +requires = ["hatchling", "hatch-vcs"] +build-backend = "hatchling.build" + +[project] +name = "pgvector-haystack" +dynamic = ["version"] +description = "An integration of pgvector (vector search extension for Postgres) with Haystack" +readme = "README.md" +requires-python = ">=3.8" +license = "Apache-2.0" +keywords = [] +authors = [ + { name = "deepset GmbH", email = "info@deepset.ai" }, +] +classifiers = [ + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dependencies = [ + "haystack-ai", + "pgvector", + "psycopg[binary]" +] + +[project.urls] +Source = "https://github.com/deepset-ai/haystack-core-integrations" +Documentation = "https://github.com/deepset-ai/haystack-core-integrations/blob/main/integrations/pgvector/README.md" +Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" + +[tool.hatch.build.targets.wheel] +packages = ["src/haystack_integrations"] + +[tool.hatch.version] +source = "vcs" +tag-pattern = 'integrations\/pgvector-v(?P<version>.*)' + +[tool.hatch.version.raw-options] +root = "../.." 
+git_describe_command = 'git describe --tags --match="integrations/pgvector-v[0-9]*"' + +[tool.hatch.envs.default] +dependencies = [ + "coverage[toml]>=6.5", + "pytest", + "ipython", + "haystack-pydoc-tools", +] +[tool.hatch.envs.default.scripts] +test = "pytest {args:tests}" +test-cov = "coverage run -m pytest {args:tests}" +cov-report = [ + "- coverage combine", + "coverage report", +] +cov = [ + "test-cov", + "cov-report", +] +docs = [ + "pydoc-markdown pydoc/config.yml" +] + +[[tool.hatch.envs.all.matrix]] +python = ["3.8", "3.9", "3.10", "3.11", "3.12"] + +[tool.hatch.envs.lint] +detached = true +dependencies = [ + "black>=23.1.0", + "mypy>=1.0.0", + "ruff>=0.0.243", +] +[tool.hatch.envs.lint.scripts] +typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" +style = [ + "ruff {args:.}", + "black --check --diff {args:.}", +] +fmt = [ + "black {args:.}", + "ruff --fix {args:.}", + "style", +] +all = [ + "style", + "typing", +] + +[tool.black] +target-version = ["py37"] +line-length = 120 +skip-string-normalization = true + +[tool.ruff] +target-version = "py37" +line-length = 120 +select = [ + "A", + "ARG", + "B", + "C", + "DTZ", + "E", + "EM", + "F", + "FBT", + "I", + "ICN", + "ISC", + "N", + "PLC", + "PLE", + "PLR", + "PLW", + "Q", + "RUF", + "S", + "T", + "TID", + "UP", + "W", + "YTT", +] +ignore = [ + # Allow non-abstract empty methods in abstract base classes + "B027", + # Allow boolean positional values in function calls, like `dict.get(... 
True)` + "FBT003", + # Ignore checks for possible passwords + "S105", "S106", "S107", + # Ignore complexity + "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", +] +unfixable = [ + # Don't touch unused imports + "F401", +] + +[tool.ruff.isort] +known-first-party = ["src"] + +[tool.ruff.flake8-tidy-imports] +ban-relative-imports = "parents" + +[tool.ruff.per-file-ignores] +# Tests can use magic values, assertions, and relative imports +"tests/**/*" = ["PLR2004", "S101", "TID252"] + +[tool.coverage.run] +source_pkgs = ["src", "tests"] +branch = true +parallel = true + + +[tool.coverage.paths] +pgvector_haystack = ["src/haystack_integrations", "*/pgvector-haystack/src"] +tests = ["tests", "*/pgvector-haystack/tests"] + +[tool.coverage.report] +exclude_lines = [ + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] + +[[tool.mypy.overrides]] +module = [ + "haystack.*", + "haystack_integrations.*", + "pgvector.*", + "psycopg.*", + "pytest.*" +] +ignore_missing_imports = true diff --git a/integrations/google_vertex/src/google_vertex_haystack/__init__.py b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/__init__.py similarity index 51% rename from integrations/google_vertex/src/google_vertex_haystack/__init__.py rename to integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/__init__.py index e873bc332..ec0cf0dc4 100644 --- a/integrations/google_vertex/src/google_vertex_haystack/__init__.py +++ b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/__init__.py @@ -1,3 +1,6 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +from .embedding_retriever import PgvectorEmbeddingRetriever + +__all__ = ["PgvectorEmbeddingRetriever"] diff --git a/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py 
b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py new file mode 100644 index 000000000..26807e9bd --- /dev/null +++ b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from typing import Any, Dict, List, Literal, Optional + +from haystack import component, default_from_dict, default_to_dict +from haystack.dataclasses import Document +from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore +from haystack_integrations.document_stores.pgvector.document_store import VALID_VECTOR_FUNCTIONS + + +@component +class PgvectorEmbeddingRetriever: + """ + Retrieves documents from the PgvectorDocumentStore, based on their dense embeddings. + + Needs to be connected to the PgvectorDocumentStore. + """ + + def __init__( + self, + *, + document_store: PgvectorDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + vector_function: Optional[Literal["cosine_similarity", "inner_product", "l2_distance"]] = None, + ): + """ + Create the PgvectorEmbeddingRetriever component. + + :param document_store: An instance of PgvectorDocumentStore. + :param filters: Filters applied to the retrieved Documents. Defaults to None. + :param top_k: Maximum number of Documents to return, defaults to 10. + :param vector_function: The similarity function to use when searching for similar embeddings. + Defaults to the one set in the `document_store` instance. + "cosine_similarity" and "inner_product" are similarity functions and + higher scores indicate greater similarity between the documents. + "l2_distance" returns the straight-line distance between vectors, + and the most similar documents are the ones with the smallest score. 
+ + Important: if the document store is using the "hnsw" search strategy, the vector function + should match the one utilized during index creation to take advantage of the index. + :type vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"] + + :raises ValueError: If `document_store` is not an instance of PgvectorDocumentStore. + """ + if not isinstance(document_store, PgvectorDocumentStore): + msg = "document_store must be an instance of PgvectorDocumentStore" + raise ValueError(msg) + + if vector_function and vector_function not in VALID_VECTOR_FUNCTIONS: + msg = f"vector_function must be one of {VALID_VECTOR_FUNCTIONS}" + raise ValueError(msg) + + self.document_store = document_store + self.filters = filters or {} + self.top_k = top_k + self.vector_function = vector_function or document_store.vector_function + + def to_dict(self) -> Dict[str, Any]: + return default_to_dict( + self, + filters=self.filters, + top_k=self.top_k, + vector_function=self.vector_function, + document_store=self.document_store.to_dict(), + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "PgvectorEmbeddingRetriever": + data["init_parameters"]["document_store"] = default_from_dict( + PgvectorDocumentStore, data["init_parameters"]["document_store"] + ) + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run( + self, + query_embedding: List[float], + filters: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = None, + vector_function: Optional[Literal["cosine_similarity", "inner_product", "l2_distance"]] = None, + ): + """ + Retrieve documents from the PgvectorDocumentStore, based on their embeddings. + + :param query_embedding: Embedding of the query. + :param filters: Filters applied to the retrieved Documents. + :param top_k: Maximum number of Documents to return. + :param vector_function: The similarity function to use when searching for similar embeddings. 
+ :type vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"] + :return: List of Documents similar to `query_embedding`. + """ + filters = filters or self.filters + top_k = top_k or self.top_k + vector_function = vector_function or self.vector_function + + docs = self.document_store._embedding_retrieval( + query_embedding=query_embedding, + filters=filters, + top_k=top_k, + vector_function=vector_function, + ) + return {"documents": docs} diff --git a/integrations/google_vertex/src/google_vertex_haystack/generators/__init__.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/__init__.py similarity index 55% rename from integrations/google_vertex/src/google_vertex_haystack/generators/__init__.py rename to integrations/pgvector/src/haystack_integrations/document_stores/pgvector/__init__.py index e873bc332..613962549 100644 --- a/integrations/google_vertex/src/google_vertex_haystack/generators/__init__.py +++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/__init__.py @@ -1,3 +1,6 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +from .document_store import PgvectorDocumentStore + +__all__ = ["PgvectorDocumentStore"] diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py new file mode 100644 index 000000000..097e86c7e --- /dev/null +++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py @@ -0,0 +1,519 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import logging +from typing import Any, Dict, List, Literal, Optional + +from haystack import default_to_dict +from haystack.dataclasses.document import ByteStream, Document +from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError +from 
haystack.document_stores.types import DuplicatePolicy +from haystack.utils.filters import convert +from psycopg import Error, IntegrityError, connect +from psycopg.abc import Query +from psycopg.cursor import Cursor +from psycopg.rows import dict_row +from psycopg.sql import SQL, Identifier +from psycopg.sql import Literal as SQLLiteral +from psycopg.types.json import Jsonb + +from pgvector.psycopg import register_vector + +from .filters import _convert_filters_to_where_clause_and_params + +logger = logging.getLogger(__name__) + +CREATE_TABLE_STATEMENT = """ +CREATE TABLE IF NOT EXISTS {table_name} ( +id VARCHAR(128) PRIMARY KEY, +embedding VECTOR({embedding_dimension}), +content TEXT, +dataframe JSONB, +blob_data BYTEA, +blob_meta JSONB, +blob_mime_type VARCHAR(255), +meta JSONB) +""" + +INSERT_STATEMENT = """ +INSERT INTO {table_name} +(id, embedding, content, dataframe, blob_data, blob_meta, blob_mime_type, meta) +VALUES (%(id)s, %(embedding)s, %(content)s, %(dataframe)s, %(blob_data)s, %(blob_meta)s, %(blob_mime_type)s, %(meta)s) +""" + +UPDATE_STATEMENT = """ +ON CONFLICT (id) DO UPDATE SET +embedding = EXCLUDED.embedding, +content = EXCLUDED.content, +dataframe = EXCLUDED.dataframe, +blob_data = EXCLUDED.blob_data, +blob_meta = EXCLUDED.blob_meta, +blob_mime_type = EXCLUDED.blob_mime_type, +meta = EXCLUDED.meta +""" + +VALID_VECTOR_FUNCTIONS = ["cosine_similarity", "inner_product", "l2_distance"] + +VECTOR_FUNCTION_TO_POSTGRESQL_OPS = { + "cosine_similarity": "vector_cosine_ops", + "inner_product": "vector_ip_ops", + "l2_distance": "vector_l2_ops", +} + +HNSW_INDEX_CREATION_VALID_KWARGS = ["m", "ef_construction"] + +HNSW_INDEX_NAME = "haystack_hnsw_index" + + +class PgvectorDocumentStore: + def __init__( + self, + *, + connection_string: str, + table_name: str = "haystack_documents", + embedding_dimension: int = 768, + vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"] = "cosine_similarity", + recreate_table: bool = False, + 
search_strategy: Literal["exact_nearest_neighbor", "hnsw"] = "exact_nearest_neighbor", + hnsw_recreate_index_if_exists: bool = False, + hnsw_index_creation_kwargs: Optional[Dict[str, int]] = None, + hnsw_ef_search: Optional[int] = None, + ): + """ + Creates a new PgvectorDocumentStore instance. + It is meant to be connected to a PostgreSQL database with the pgvector extension installed. + A specific table to store Haystack documents will be created if it doesn't exist yet. + + :param connection_string: The connection string to use to connect to the PostgreSQL database. + e.g. "postgresql://USER:PASSWORD@HOST:PORT/DB_NAME" + :param table_name: The name of the table to use to store Haystack documents. Defaults to "haystack_documents". + :param embedding_dimension: The dimension of the embedding. Defaults to 768. + :param vector_function: The similarity function to use when searching for similar embeddings. + Defaults to "cosine_similarity". "cosine_similarity" and "inner_product" are similarity functions and + higher scores indicate greater similarity between the documents. + "l2_distance" returns the straight-line distance between vectors, + and the most similar documents are the ones with the smallest score. + + Important: when using the "hnsw" search strategy, an index will be created that depends on the + `vector_function` passed here. Make sure subsequent queries will keep using the same + vector similarity function in order to take advantage of the index. + :type vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"] + :param recreate_table: Whether to recreate the table if it already exists. Defaults to False. + :param search_strategy: The search strategy to use when searching for similar embeddings. + Defaults to "exact_nearest_neighbor". "hnsw" is an approximate nearest neighbor search strategy, + which trades off some accuracy for speed; it is recommended for large numbers of documents. 
+ + Important: when using the "hnsw" search strategy, an index will be created that depends on the + `vector_function` passed here. Make sure subsequent queries will keep using the same + vector similarity function in order to take advantage of the index. + :type search_strategy: Literal["exact_nearest_neighbor", "hnsw"] + :param hnsw_recreate_index_if_exists: Whether to recreate the HNSW index if it already exists. + Defaults to False. Only used if search_strategy is set to "hnsw". + :param hnsw_index_creation_kwargs: Additional keyword arguments to pass to the HNSW index creation. + Only used if search_strategy is set to "hnsw". You can find the list of valid arguments in the + pgvector documentation: https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw + :param hnsw_ef_search: The ef_search parameter to use at query time. Only used if search_strategy is set to + "hnsw". You can find more information about this parameter in the pgvector documentation: + https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw + """ + + self.connection_string = connection_string + self.table_name = table_name + self.embedding_dimension = embedding_dimension + if vector_function not in VALID_VECTOR_FUNCTIONS: + msg = f"vector_function must be one of {VALID_VECTOR_FUNCTIONS}, but got {vector_function}" + raise ValueError(msg) + self.vector_function = vector_function + self.recreate_table = recreate_table + self.search_strategy = search_strategy + self.hnsw_recreate_index_if_exists = hnsw_recreate_index_if_exists + self.hnsw_index_creation_kwargs = hnsw_index_creation_kwargs or {} + self.hnsw_ef_search = hnsw_ef_search + + connection = connect(connection_string) + connection.autocommit = True + self._connection = connection + + # we create a generic cursor and another one that returns dictionaries + self._cursor = connection.cursor() + self._dict_cursor = connection.cursor(row_factory=dict_row) + + connection.execute("CREATE EXTENSION IF NOT EXISTS vector") + 
register_vector(connection) + + if recreate_table: + self.delete_table() + self._create_table_if_not_exists() + + if search_strategy == "hnsw": + self._handle_hnsw() + + def to_dict(self) -> Dict[str, Any]: + return default_to_dict( + self, + connection_string=self.connection_string, + table_name=self.table_name, + embedding_dimension=self.embedding_dimension, + vector_function=self.vector_function, + recreate_table=self.recreate_table, + search_strategy=self.search_strategy, + hnsw_recreate_index_if_exists=self.hnsw_recreate_index_if_exists, + hnsw_index_creation_kwargs=self.hnsw_index_creation_kwargs, + hnsw_ef_search=self.hnsw_ef_search, + ) + + def _execute_sql( + self, sql_query: Query, params: Optional[tuple] = None, error_msg: str = "", cursor: Optional[Cursor] = None + ): + """ + Internal method to execute SQL statements and handle exceptions. + + :param sql_query: The SQL query to execute. + :param params: The parameters to pass to the SQL query. + :param error_msg: The error message to use if an exception is raised. + :param cursor: The cursor to use to execute the SQL query. Defaults to self._cursor. + """ + + params = params or () + cursor = cursor or self._cursor + + sql_query_str = sql_query.as_string(cursor) if not isinstance(sql_query, str) else sql_query + logger.debug("SQL query: %s\nParameters: %s", sql_query_str, params) + + try: + result = cursor.execute(sql_query, params) + except Error as e: + self._connection.rollback() + detailed_error_msg = f"{error_msg}.\nYou can find the SQL query and the parameters in the debug logs." + raise DocumentStoreError(detailed_error_msg) from e + + return result + + def _create_table_if_not_exists(self): + """ + Creates the table to store Haystack documents if it doesn't exist yet. 
+ """ + + create_sql = SQL(CREATE_TABLE_STATEMENT).format( + table_name=Identifier(self.table_name), embedding_dimension=SQLLiteral(self.embedding_dimension) + ) + + self._execute_sql(create_sql, error_msg="Could not create table in PgvectorDocumentStore") + + def delete_table(self): + """ + Deletes the table used to store Haystack documents. + """ + + delete_sql = SQL("DROP TABLE IF EXISTS {table_name}").format(table_name=Identifier(self.table_name)) + + self._execute_sql(delete_sql, error_msg=f"Could not delete table {self.table_name} in PgvectorDocumentStore") + + def _handle_hnsw(self): + """ + Internal method to handle the HNSW index creation. + It also sets the hnsw.ef_search parameter for queries if it is specified. + """ + + if self.hnsw_ef_search: + sql_set_hnsw_ef_search = SQL("SET hnsw.ef_search = {hnsw_ef_search}").format( + hnsw_ef_search=SQLLiteral(self.hnsw_ef_search) + ) + self._execute_sql(sql_set_hnsw_ef_search, error_msg="Could not set hnsw.ef_search") + + index_esists = bool( + self._execute_sql( + "SELECT 1 FROM pg_indexes WHERE tablename = %s AND indexname = %s", + (self.table_name, HNSW_INDEX_NAME), + "Could not check if HNSW index exists", + ).fetchone() + ) + + if index_esists and not self.hnsw_recreate_index_if_exists: + logger.warning( + "HNSW index already exists and won't be recreated. " + "If you want to recreate it, pass 'hnsw_recreate_index_if_exists=True' to the " + "Document Store constructor" + ) + return + + sql_drop_index = SQL("DROP INDEX IF EXISTS {index_name}").format(index_name=Identifier(HNSW_INDEX_NAME)) + self._execute_sql(sql_drop_index, error_msg="Could not drop HNSW index") + + self._create_hnsw_index() + + def _create_hnsw_index(self): + """ + Internal method to create the HNSW index. 
+ """ + + pg_ops = VECTOR_FUNCTION_TO_POSTGRESQL_OPS[self.vector_function] + actual_hnsw_index_creation_kwargs = { + key: value + for key, value in self.hnsw_index_creation_kwargs.items() + if key in HNSW_INDEX_CREATION_VALID_KWARGS + } + + sql_create_index = SQL("CREATE INDEX {index_name} ON {table_name} USING hnsw (embedding {ops}) ").format( + index_name=Identifier(HNSW_INDEX_NAME), table_name=Identifier(self.table_name), ops=SQL(pg_ops) + ) + + if actual_hnsw_index_creation_kwargs: + actual_hnsw_index_creation_kwargs_str = ", ".join( + f"{key} = {value}" for key, value in actual_hnsw_index_creation_kwargs.items() + ) + sql_add_creation_kwargs = SQL("WITH ({creation_kwargs_str})").format( + creation_kwargs_str=SQL(actual_hnsw_index_creation_kwargs_str) + ) + sql_create_index += sql_add_creation_kwargs + + self._execute_sql(sql_create_index, error_msg="Could not create HNSW index") + + def count_documents(self) -> int: + """ + Returns how many documents are present in the document store. + """ + + sql_count = SQL("SELECT COUNT(*) FROM {table_name}").format(table_name=Identifier(self.table_name)) + + count = self._execute_sql(sql_count, error_msg="Could not count documents in PgvectorDocumentStore").fetchone()[ + 0 + ] + return count + + def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: + """ + Returns the documents that match the filters provided. + + For a detailed specification of the filters, + refer to the [documentation](https://docs.haystack.deepset.ai/v2.0/docs/metadata-filtering) + + :param filters: The filters to apply to the document list. + :return: A list of Documents that match the given filters. 
+ """ + if filters: + if not isinstance(filters, dict): + msg = "Filters must be a dictionary" + raise TypeError(msg) + if "operator" not in filters and "conditions" not in filters: + filters = convert(filters) + + sql_filter = SQL("SELECT * FROM {table_name}").format(table_name=Identifier(self.table_name)) + + params = () + if filters: + sql_where_clause, params = _convert_filters_to_where_clause_and_params(filters) + sql_filter += sql_where_clause + + result = self._execute_sql( + sql_filter, + params, + error_msg="Could not filter documents from PgvectorDocumentStore.", + cursor=self._dict_cursor, + ) + + records = result.fetchall() + docs = self._from_pg_to_haystack_documents(records) + return docs + + def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int: + """ + Writes documents into the PgvectorDocumentStore. + + :param documents: A list of Documents to write to the document store. + :param policy: The duplicate policy to use when writing documents. + :raises DuplicateDocumentError: If a document with the same id already exists in the document store + and the policy is set to DuplicatePolicy.FAIL (or not specified). + :return: The number of documents written to the document store. 
+ """ + + if len(documents) > 0: + if not isinstance(documents[0], Document): + msg = "param 'documents' must contain a list of objects of type Document" + raise ValueError(msg) + + if policy == DuplicatePolicy.NONE: + policy = DuplicatePolicy.FAIL + + db_documents = self._from_haystack_to_pg_documents(documents) + + sql_insert = SQL(INSERT_STATEMENT).format(table_name=Identifier(self.table_name)) + + if policy == DuplicatePolicy.OVERWRITE: + sql_insert += SQL(UPDATE_STATEMENT) + elif policy == DuplicatePolicy.SKIP: + sql_insert += SQL("ON CONFLICT DO NOTHING") + + sql_insert += SQL(" RETURNING id") + + sql_query_str = sql_insert.as_string(self._cursor) if not isinstance(sql_insert, str) else sql_insert + logger.debug("SQL query: %s\nParameters: %s", sql_query_str, db_documents) + + try: + self._cursor.executemany(sql_insert, db_documents, returning=True) + except IntegrityError as ie: + self._connection.rollback() + raise DuplicateDocumentError from ie + except Error as e: + self._connection.rollback() + error_msg = ( + "Could not write documents to PgvectorDocumentStore. \n" + "You can find the SQL query and the parameters in the debug logs." + ) + raise DocumentStoreError(error_msg) from e + + # get the number of the inserted documents, inspired by psycopg3 docs + # https://www.psycopg.org/psycopg3/docs/api/cursors.html#psycopg.Cursor.executemany + written_docs = 0 + while True: + if self._cursor.fetchone(): + written_docs += 1 + if not self._cursor.nextset(): + break + + return written_docs + + def _from_haystack_to_pg_documents(self, documents: List[Document]) -> List[Dict[str, Any]]: + """ + Internal method to convert a list of Haystack Documents to a list of dictionaries that can be used to insert + documents into the PgvectorDocumentStore. 
+ """ + + db_documents = [] + for document in documents: + db_document = {k: v for k, v in document.to_dict(flatten=False).items() if k not in ["score", "blob"]} + + blob = document.blob + db_document["blob_data"] = blob.data if blob else None + db_document["blob_meta"] = Jsonb(blob.meta) if blob and blob.meta else None + db_document["blob_mime_type"] = blob.mime_type if blob and blob.mime_type else None + + db_document["dataframe"] = Jsonb(db_document["dataframe"]) if db_document["dataframe"] else None + db_document["meta"] = Jsonb(db_document["meta"]) + + db_documents.append(db_document) + + return db_documents + + def _from_pg_to_haystack_documents(self, documents: List[Dict[str, Any]]) -> List[Document]: + """ + Internal method to convert a list of dictionaries from pgvector to a list of Haystack Documents. + """ + + haystack_documents = [] + for document in documents: + haystack_dict = dict(document) + blob_data = haystack_dict.pop("blob_data") + blob_meta = haystack_dict.pop("blob_meta") + blob_mime_type = haystack_dict.pop("blob_mime_type") + + # postgresql returns the embedding as a string + # so we need to convert it to a list of floats + if document.get("embedding"): + haystack_dict["embedding"] = [float(el) for el in document["embedding"].strip("[]").split(",")] + + haystack_document = Document.from_dict(haystack_dict) + + if blob_data: + blob = ByteStream(data=blob_data, meta=blob_meta, mime_type=blob_mime_type) + haystack_document.blob = blob + + haystack_documents.append(haystack_document) + + return haystack_documents + + def delete_documents(self, document_ids: List[str]) -> None: + """ + Deletes all documents with matching document_ids from the document store. 
+ + :param document_ids: the document ids to delete + """ + + if not document_ids: + return + + document_ids_str = ", ".join(f"'{document_id}'" for document_id in document_ids) + + delete_sql = SQL("DELETE FROM {table_name} WHERE id IN ({document_ids_str})").format( + table_name=Identifier(self.table_name), document_ids_str=SQL(document_ids_str) + ) + + self._execute_sql(delete_sql, error_msg="Could not delete documents from PgvectorDocumentStore") + + def _embedding_retrieval( + self, + query_embedding: List[float], + *, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + vector_function: Optional[Literal["cosine_similarity", "inner_product", "l2_distance"]] = None, + ) -> List[Document]: + """ + Retrieves documents that are most similar to the query embedding using a vector similarity metric. + + This method is not meant to be part of the public interface of + `PgvectorDocumentStore` and it should not be called directly. + `PgvectorEmbeddingRetriever` uses this method directly and is the public interface for it. + :raises ValueError + :return: List of Documents that are most similar to `query_embedding` + """ + + if not query_embedding: + msg = "query_embedding must be a non-empty list of floats" + raise ValueError(msg) + if len(query_embedding) != self.embedding_dimension: + msg = ( + f"query_embedding dimension ({len(query_embedding)}) does not match PgvectorDocumentStore " + f"embedding dimension ({self.embedding_dimension})." 
+ ) + raise ValueError(msg) + + vector_function = vector_function or self.vector_function + if vector_function not in VALID_VECTOR_FUNCTIONS: + msg = f"vector_function must be one of {VALID_VECTOR_FUNCTIONS}, but got {vector_function}" + raise ValueError(msg) + + # the vector must be a string with this format: "'[3,1,2]'" + query_embedding_for_postgres = f"'[{','.join(str(el) for el in query_embedding)}]'" + + # to compute the scores, we use the approach described in pgvector README: + # https://github.com/pgvector/pgvector?tab=readme-ov-file#distances + # cosine_similarity and inner_product are modified from the result of the operator + if vector_function == "cosine_similarity": + score_definition = f"1 - (embedding <=> {query_embedding_for_postgres}) AS score" + elif vector_function == "inner_product": + score_definition = f"(embedding <#> {query_embedding_for_postgres}) * -1 AS score" + elif vector_function == "l2_distance": + score_definition = f"embedding <-> {query_embedding_for_postgres} AS score" + + sql_select = SQL("SELECT *, {score} FROM {table_name}").format( + table_name=Identifier(self.table_name), + score=SQL(score_definition), + ) + + sql_where_clause = SQL("") + params = () + if filters: + sql_where_clause, params = _convert_filters_to_where_clause_and_params(filters) + + # we always want to return the most similar documents first + # so when using l2_distance, the sort order must be ASC + sort_order = "ASC" if vector_function == "l2_distance" else "DESC" + + sql_sort = SQL(" ORDER BY score {sort_order} LIMIT {top_k}").format( + top_k=SQLLiteral(top_k), + sort_order=SQL(sort_order), + ) + + sql_query = sql_select + sql_where_clause + sql_sort + + result = self._execute_sql( + sql_query, + params, + error_msg="Could not retrieve documents from PgvectorDocumentStore.", + cursor=self._dict_cursor, + ) + + records = result.fetchall() + docs = self._from_pg_to_haystack_documents(records) + return docs diff --git 
a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/filters.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/filters.py new file mode 100644 index 000000000..daa90f502 --- /dev/null +++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/filters.py @@ -0,0 +1,242 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from datetime import datetime +from itertools import chain +from typing import Any, Dict, List + +from haystack.errors import FilterError +from pandas import DataFrame +from psycopg.sql import SQL +from psycopg.types.json import Jsonb + +# we need this mapping to cast meta values to the correct type, +# since they are stored in the JSONB field as strings. +# this dict can be extended if needed +PYTHON_TYPES_TO_PG_TYPES = { + int: "integer", + float: "real", + bool: "boolean", +} + +NO_VALUE = "no_value" + + +def _convert_filters_to_where_clause_and_params(filters: Dict[str, Any]) -> tuple[SQL, tuple]: + """ + Convert Haystack filters to a WHERE clause and a tuple of params to query PostgreSQL. + """ + if "field" in filters: + query, values = _parse_comparison_condition(filters) + else: + query, values = _parse_logical_condition(filters) + + where_clause = SQL(" WHERE ") + SQL(query) + params = tuple(value for value in values if value != NO_VALUE) + + return where_clause, params + + +def _parse_logical_condition(condition: Dict[str, Any]) -> tuple[str, List[Any]]: + if "operator" not in condition: + msg = f"'operator' key missing in {condition}" + raise FilterError(msg) + if "conditions" not in condition: + msg = f"'conditions' key missing in {condition}" + raise FilterError(msg) + + operator = condition["operator"] + if operator not in ["AND", "OR"]: + msg = f"Unknown logical operator '{operator}'. 
Valid operators are: 'AND', 'OR'" + raise FilterError(msg) + + # logical conditions can be nested, so we need to parse them recursively + conditions = [] + for c in condition["conditions"]: + if "field" in c: + query, vals = _parse_comparison_condition(c) + else: + query, vals = _parse_logical_condition(c) + conditions.append((query, vals)) + + query_parts, values = [], [] + for c in conditions: + query_parts.append(c[0]) + values.append(c[1]) + if isinstance(values[0], list): + values = list(chain.from_iterable(values)) + + if operator == "AND": + sql_query = f"({' AND '.join(query_parts)})" + elif operator == "OR": + sql_query = f"({' OR '.join(query_parts)})" + else: + msg = f"Unknown logical operator '{operator}'" + raise FilterError(msg) + + return sql_query, values + + +def _parse_comparison_condition(condition: Dict[str, Any]) -> tuple[str, List[Any]]: + field: str = condition["field"] + if "operator" not in condition: + msg = f"'operator' key missing in {condition}" + raise FilterError(msg) + if "value" not in condition: + msg = f"'value' key missing in {condition}" + raise FilterError(msg) + operator: str = condition["operator"] + if operator not in COMPARISON_OPERATORS: + msg = f"Unknown comparison operator '{operator}'. Valid operators are: {list(COMPARISON_OPERATORS.keys())}" + raise FilterError(msg) + + value: Any = condition["value"] + if isinstance(value, DataFrame): + # DataFrames are stored as JSONB and we query them as such + value = Jsonb(value.to_json()) + field = f"({field})::jsonb" + + if field.startswith("meta."): + field = _treat_meta_field(field, value) + + field, value = COMPARISON_OPERATORS[operator](field, value) + return field, [value] + + +def _treat_meta_field(field: str, value: Any) -> str: + """ + Internal method that modifies the field str + to make the meta JSONB field queryable. 
+ + Examples: + >>> _treat_meta_field(field="meta.number", value=9) + "(meta->>'number')::integer" + + >>> _treat_meta_field(field="meta.name", value="my_name") + "meta->>'name'" + """ + + # use the ->> operator to access keys in the meta JSONB field + field_name = field.split(".", 1)[-1] + field = f"meta->>'{field_name}'" + + # meta fields are stored as strings in the JSONB field, + # so we need to cast them to the correct type + type_value = PYTHON_TYPES_TO_PG_TYPES.get(type(value)) + if isinstance(value, list) and len(value) > 0: + type_value = PYTHON_TYPES_TO_PG_TYPES.get(type(value[0])) + + if type_value: + field = f"({field})::{type_value}" + + return field + + +def _equal(field: str, value: Any) -> tuple[str, Any]: + if value is None: + # NO_VALUE is a placeholder that will be removed in _convert_filters_to_where_clause_and_params + return f"{field} IS NULL", NO_VALUE + return f"{field} = %s", value + + +def _not_equal(field: str, value: Any) -> tuple[str, Any]: + # we use IS DISTINCT FROM to correctly handle NULL values + # (not handled by !=) + return f"{field} IS DISTINCT FROM %s", value + + +def _greater_than(field: str, value: Any) -> tuple[str, Any]: + if isinstance(value, str): + try: + datetime.fromisoformat(value) + except (ValueError, TypeError) as exc: + msg = ( + "Can't compare strings using operators '>', '>=', '<', '<='. " + "Strings are only comparable if they are ISO formatted dates." + ) + raise FilterError(msg) from exc + if type(value) in [list, Jsonb]: + msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" + raise FilterError(msg) + + return f"{field} > %s", value + + +def _greater_than_equal(field: str, value: Any) -> tuple[str, Any]: + if isinstance(value, str): + try: + datetime.fromisoformat(value) + except (ValueError, TypeError) as exc: + msg = ( + "Can't compare strings using operators '>', '>=', '<', '<='. " + "Strings are only comparable if they are ISO formatted dates." 
+ ) + raise FilterError(msg) from exc + if type(value) in [list, Jsonb]: + msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" + raise FilterError(msg) + + return f"{field} >= %s", value + + +def _less_than(field: str, value: Any) -> tuple[str, Any]: + if isinstance(value, str): + try: + datetime.fromisoformat(value) + except (ValueError, TypeError) as exc: + msg = ( + "Can't compare strings using operators '>', '>=', '<', '<='. " + "Strings are only comparable if they are ISO formatted dates." + ) + raise FilterError(msg) from exc + if type(value) in [list, Jsonb]: + msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" + raise FilterError(msg) + + return f"{field} < %s", value + + +def _less_than_equal(field: str, value: Any) -> tuple[str, Any]: + if isinstance(value, str): + try: + datetime.fromisoformat(value) + except (ValueError, TypeError) as exc: + msg = ( + "Can't compare strings using operators '>', '>=', '<', '<='. " + "Strings are only comparable if they are ISO formatted dates." 
+ ) + raise FilterError(msg) from exc + if type(value) in [list, Jsonb]: + msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" + raise FilterError(msg) + + return f"{field} <= %s", value + + +def _not_in(field: str, value: Any) -> tuple[str, List]: + if not isinstance(value, list): + msg = f"{field}'s value must be a list when using 'not in' comparator in Pinecone" + raise FilterError(msg) + + return f"{field} IS NULL OR {field} != ALL(%s)", [value] + + +def _in(field: str, value: Any) -> tuple[str, List]: + if not isinstance(value, list): + msg = f"{field}'s value must be a list when using 'in' comparator in Pinecone" + raise FilterError(msg) + + # see https://www.psycopg.org/psycopg3/docs/basic/adapt.html#lists-adaptation + return f"{field} = ANY(%s)", [value] + + +COMPARISON_OPERATORS = { + "==": _equal, + "!=": _not_equal, + ">": _greater_than, + ">=": _greater_than_equal, + "<": _less_than, + "<=": _less_than_equal, + "in": _in, + "not in": _not_in, +} diff --git a/integrations/google_ai/src/google_ai_haystack/__init__.py b/integrations/pgvector/tests/__init__.py similarity index 100% rename from integrations/google_ai/src/google_ai_haystack/__init__.py rename to integrations/pgvector/tests/__init__.py diff --git a/integrations/pgvector/tests/conftest.py b/integrations/pgvector/tests/conftest.py new file mode 100644 index 000000000..743e8de14 --- /dev/null +++ b/integrations/pgvector/tests/conftest.py @@ -0,0 +1,24 @@ +import pytest +from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore + + +@pytest.fixture +def document_store(request): + connection_string = "postgresql://postgres:postgres@localhost:5432/postgres" + table_name = f"haystack_{request.node.name}" + embedding_dimension = 768 + vector_function = "cosine_similarity" + recreate_table = True + search_strategy = "exact_nearest_neighbor" + + store = PgvectorDocumentStore( + connection_string=connection_string, + table_name=table_name, + 
embedding_dimension=embedding_dimension, + vector_function=vector_function, + recreate_table=recreate_table, + search_strategy=search_strategy, + ) + yield store + + store.delete_table() diff --git a/integrations/pgvector/tests/test_document_store.py b/integrations/pgvector/tests/test_document_store.py new file mode 100644 index 000000000..e8d9107d7 --- /dev/null +++ b/integrations/pgvector/tests/test_document_store.py @@ -0,0 +1,222 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import patch + +import pytest +from haystack.dataclasses.document import ByteStream, Document +from haystack.document_stores.errors import DuplicateDocumentError +from haystack.document_stores.types import DuplicatePolicy +from haystack.testing.document_store import CountDocumentsTest, DeleteDocumentsTest, WriteDocumentsTest +from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore +from pandas import DataFrame + + +class TestDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest): + def test_write_documents(self, document_store: PgvectorDocumentStore): + docs = [Document(id="1")] + assert document_store.write_documents(docs) == 1 + with pytest.raises(DuplicateDocumentError): + document_store.write_documents(docs, DuplicatePolicy.FAIL) + + def test_write_blob(self, document_store: PgvectorDocumentStore): + bytestream = ByteStream(b"test", meta={"meta_key": "meta_value"}, mime_type="mime_type") + docs = [Document(id="1", blob=bytestream)] + document_store.write_documents(docs) + + # TODO: update when filters are implemented + retrieved_docs = document_store.filter_documents() + assert retrieved_docs == docs + + def test_write_dataframe(self, document_store: PgvectorDocumentStore): + dataframe = DataFrame({"col1": [1, 2], "col2": [3, 4]}) + docs = [Document(id="1", dataframe=dataframe)] + + document_store.write_documents(docs) + + # TODO: update when filters are implemented + 
retrieved_docs = document_store.filter_documents() + assert retrieved_docs == docs + + def test_init(self): + document_store = PgvectorDocumentStore( + connection_string="postgresql://postgres:postgres@localhost:5432/postgres", + table_name="my_table", + embedding_dimension=512, + vector_function="l2_distance", + recreate_table=True, + search_strategy="hnsw", + hnsw_recreate_index_if_exists=True, + hnsw_index_creation_kwargs={"m": 32, "ef_construction": 128}, + hnsw_ef_search=50, + ) + + assert document_store.connection_string == "postgresql://postgres:postgres@localhost:5432/postgres" + assert document_store.table_name == "my_table" + assert document_store.embedding_dimension == 512 + assert document_store.vector_function == "l2_distance" + assert document_store.recreate_table + assert document_store.search_strategy == "hnsw" + assert document_store.hnsw_recreate_index_if_exists + assert document_store.hnsw_index_creation_kwargs == {"m": 32, "ef_construction": 128} + assert document_store.hnsw_ef_search == 50 + + def test_to_dict(self): + document_store = PgvectorDocumentStore( + connection_string="postgresql://postgres:postgres@localhost:5432/postgres", + table_name="my_table", + embedding_dimension=512, + vector_function="l2_distance", + recreate_table=True, + search_strategy="hnsw", + hnsw_recreate_index_if_exists=True, + hnsw_index_creation_kwargs={"m": 32, "ef_construction": 128}, + hnsw_ef_search=50, + ) + + assert document_store.to_dict() == { + "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", + "init_parameters": { + "connection_string": "postgresql://postgres:postgres@localhost:5432/postgres", + "table_name": "my_table", + "embedding_dimension": 512, + "vector_function": "l2_distance", + "recreate_table": True, + "search_strategy": "hnsw", + "hnsw_recreate_index_if_exists": True, + "hnsw_index_creation_kwargs": {"m": 32, "ef_construction": 128}, + "hnsw_ef_search": 50, + }, + } + + def 
test_from_haystack_to_pg_documents(self): + haystack_docs = [ + Document( + id="1", + content="This is a text", + meta={"meta_key": "meta_value"}, + embedding=[0.1, 0.2, 0.3], + score=0.5, + ), + Document( + id="2", + dataframe=DataFrame({"col1": [1, 2], "col2": [3, 4]}), + meta={"meta_key": "meta_value"}, + embedding=[0.4, 0.5, 0.6], + score=0.6, + ), + Document( + id="3", + blob=ByteStream(b"test", meta={"blob_meta_key": "blob_meta_value"}, mime_type="mime_type"), + meta={"meta_key": "meta_value"}, + embedding=[0.7, 0.8, 0.9], + score=0.7, + ), + ] + + with patch( + "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore.__init__" + ) as mock_init: + mock_init.return_value = None + ds = PgvectorDocumentStore(connection_string="test") + + pg_docs = ds._from_haystack_to_pg_documents(haystack_docs) + + assert pg_docs[0]["id"] == "1" + assert pg_docs[0]["content"] == "This is a text" + assert pg_docs[0]["dataframe"] is None + assert pg_docs[0]["blob_data"] is None + assert pg_docs[0]["blob_meta"] is None + assert pg_docs[0]["blob_mime_type"] is None + assert pg_docs[0]["meta"].obj == {"meta_key": "meta_value"} + assert pg_docs[0]["embedding"] == [0.1, 0.2, 0.3] + assert "score" not in pg_docs[0] + + assert pg_docs[1]["id"] == "2" + assert pg_docs[1]["content"] is None + assert pg_docs[1]["dataframe"].obj == DataFrame({"col1": [1, 2], "col2": [3, 4]}).to_json() + assert pg_docs[1]["blob_data"] is None + assert pg_docs[1]["blob_meta"] is None + assert pg_docs[1]["blob_mime_type"] is None + assert pg_docs[1]["meta"].obj == {"meta_key": "meta_value"} + assert pg_docs[1]["embedding"] == [0.4, 0.5, 0.6] + assert "score" not in pg_docs[1] + + assert pg_docs[2]["id"] == "3" + assert pg_docs[2]["content"] is None + assert pg_docs[2]["dataframe"] is None + assert pg_docs[2]["blob_data"] == b"test" + assert pg_docs[2]["blob_meta"].obj == {"blob_meta_key": "blob_meta_value"} + assert pg_docs[2]["blob_mime_type"] == "mime_type" + assert 
pg_docs[2]["meta"].obj == {"meta_key": "meta_value"} + assert pg_docs[2]["embedding"] == [0.7, 0.8, 0.9] + assert "score" not in pg_docs[2] + + def test_from_pg_to_haystack_documents(self): + pg_docs = [ + { + "id": "1", + "content": "This is a text", + "dataframe": None, + "blob_data": None, + "blob_meta": None, + "blob_mime_type": None, + "meta": {"meta_key": "meta_value"}, + "embedding": "[0.1, 0.2, 0.3]", + }, + { + "id": "2", + "content": None, + "dataframe": DataFrame({"col1": [1, 2], "col2": [3, 4]}).to_json(), + "blob_data": None, + "blob_meta": None, + "blob_mime_type": None, + "meta": {"meta_key": "meta_value"}, + "embedding": "[0.4, 0.5, 0.6]", + }, + { + "id": "3", + "content": None, + "dataframe": None, + "blob_data": b"test", + "blob_meta": {"blob_meta_key": "blob_meta_value"}, + "blob_mime_type": "mime_type", + "meta": {"meta_key": "meta_value"}, + "embedding": "[0.7, 0.8, 0.9]", + }, + ] + + with patch( + "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore.__init__" + ) as mock_init: + mock_init.return_value = None + ds = PgvectorDocumentStore(connection_string="test") + + haystack_docs = ds._from_pg_to_haystack_documents(pg_docs) + + assert haystack_docs[0].id == "1" + assert haystack_docs[0].content == "This is a text" + assert haystack_docs[0].dataframe is None + assert haystack_docs[0].blob is None + assert haystack_docs[0].meta == {"meta_key": "meta_value"} + assert haystack_docs[0].embedding == [0.1, 0.2, 0.3] + assert haystack_docs[0].score is None + + assert haystack_docs[1].id == "2" + assert haystack_docs[1].content is None + assert haystack_docs[1].dataframe.equals(DataFrame({"col1": [1, 2], "col2": [3, 4]})) + assert haystack_docs[1].blob is None + assert haystack_docs[1].meta == {"meta_key": "meta_value"} + assert haystack_docs[1].embedding == [0.4, 0.5, 0.6] + assert haystack_docs[1].score is None + + assert haystack_docs[2].id == "3" + assert haystack_docs[2].content is None + assert 
haystack_docs[2].dataframe is None + assert haystack_docs[2].blob.data == b"test" + assert haystack_docs[2].blob.meta == {"blob_meta_key": "blob_meta_value"} + assert haystack_docs[2].blob.mime_type == "mime_type" + assert haystack_docs[2].meta == {"meta_key": "meta_value"} + assert haystack_docs[2].embedding == [0.7, 0.8, 0.9] + assert haystack_docs[2].score is None diff --git a/integrations/pgvector/tests/test_embedding_retrieval.py b/integrations/pgvector/tests/test_embedding_retrieval.py new file mode 100644 index 000000000..1d5e8e297 --- /dev/null +++ b/integrations/pgvector/tests/test_embedding_retrieval.py @@ -0,0 +1,130 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import List + +import pytest +from haystack.dataclasses.document import Document +from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore +from numpy.random import rand + + +class TestEmbeddingRetrieval: + @pytest.fixture + def document_store_w_hnsw_index(self, request): + connection_string = "postgresql://postgres:postgres@localhost:5432/postgres" + table_name = f"haystack_hnsw_{request.node.name}" + embedding_dimension = 768 + vector_function = "cosine_similarity" + recreate_table = True + search_strategy = "hnsw" + + store = PgvectorDocumentStore( + connection_string=connection_string, + table_name=table_name, + embedding_dimension=embedding_dimension, + vector_function=vector_function, + recreate_table=recreate_table, + search_strategy=search_strategy, + ) + yield store + + store.delete_table() + + @pytest.mark.parametrize("document_store", ["document_store", "document_store_w_hnsw_index"], indirect=True) + def test_embedding_retrieval_cosine_similarity(self, document_store: PgvectorDocumentStore): + query_embedding = [0.1] * 768 + most_similar_embedding = [0.8] * 768 + second_best_embedding = [0.8] * 700 + [0.1] * 3 + [0.2] * 65 + another_embedding = rand(768).tolist() + + docs = [ + 
Document(content="Most similar document (cosine sim)", embedding=most_similar_embedding), + Document(content="2nd best document (cosine sim)", embedding=second_best_embedding), + Document(content="Not very similar document (cosine sim)", embedding=another_embedding), + ] + + document_store.write_documents(docs) + + results = document_store._embedding_retrieval( + query_embedding=query_embedding, top_k=2, filters={}, vector_function="cosine_similarity" + ) + assert len(results) == 2 + assert results[0].content == "Most similar document (cosine sim)" + assert results[1].content == "2nd best document (cosine sim)" + assert results[0].score > results[1].score + + @pytest.mark.parametrize("document_store", ["document_store", "document_store_w_hnsw_index"], indirect=True) + def test_embedding_retrieval_inner_product(self, document_store: PgvectorDocumentStore): + query_embedding = [0.1] * 768 + most_similar_embedding = [0.8] * 768 + second_best_embedding = [0.8] * 700 + [0.1] * 3 + [0.2] * 65 + another_embedding = rand(768).tolist() + + docs = [ + Document(content="Most similar document (inner product)", embedding=most_similar_embedding), + Document(content="2nd best document (inner product)", embedding=second_best_embedding), + Document(content="Not very similar document (inner product)", embedding=another_embedding), + ] + + document_store.write_documents(docs) + + results = document_store._embedding_retrieval( + query_embedding=query_embedding, top_k=2, filters={}, vector_function="inner_product" + ) + assert len(results) == 2 + assert results[0].content == "Most similar document (inner product)" + assert results[1].content == "2nd best document (inner product)" + assert results[0].score > results[1].score + + @pytest.mark.parametrize("document_store", ["document_store", "document_store_w_hnsw_index"], indirect=True) + def test_embedding_retrieval_l2_distance(self, document_store: PgvectorDocumentStore): + query_embedding = [0.1] * 768 + most_similar_embedding = [0.1] 
* 765 + [0.15] * 3 + second_best_embedding = [0.1] * 700 + [0.1] * 3 + [0.2] * 65 + another_embedding = rand(768).tolist() + + docs = [ + Document(content="Most similar document (l2 dist)", embedding=most_similar_embedding), + Document(content="2nd best document (l2 dist)", embedding=second_best_embedding), + Document(content="Not very similar document (l2 dist)", embedding=another_embedding), + ] + + document_store.write_documents(docs) + + results = document_store._embedding_retrieval( + query_embedding=query_embedding, top_k=2, filters={}, vector_function="l2_distance" + ) + assert len(results) == 2 + assert results[0].content == "Most similar document (l2 dist)" + assert results[1].content == "2nd best document (l2 dist)" + assert results[0].score < results[1].score + + @pytest.mark.parametrize("document_store", ["document_store", "document_store_w_hnsw_index"], indirect=True) + def test_embedding_retrieval_with_filters(self, document_store: PgvectorDocumentStore): + docs = [Document(content=f"Document {i}", embedding=rand(768).tolist()) for i in range(10)] + + for i in range(10): + docs[i].meta["meta_field"] = "custom_value" if i % 2 == 0 else "other_value" + + document_store.write_documents(docs) + + query_embedding = [0.1] * 768 + filters = {"field": "meta.meta_field", "operator": "==", "value": "custom_value"} + + results = document_store._embedding_retrieval(query_embedding=query_embedding, top_k=3, filters=filters) + assert len(results) == 3 + for result in results: + assert result.meta["meta_field"] == "custom_value" + assert results[0].score > results[1].score > results[2].score + + def test_empty_query_embedding(self, document_store: PgvectorDocumentStore): + query_embedding: List[float] = [] + with pytest.raises(ValueError): + document_store._embedding_retrieval(query_embedding=query_embedding) + + def test_query_embedding_wrong_dimension(self, document_store: PgvectorDocumentStore): + query_embedding = [0.1] * 4 + with pytest.raises(ValueError): + 
document_store._embedding_retrieval(query_embedding=query_embedding) diff --git a/integrations/pgvector/tests/test_filters.py b/integrations/pgvector/tests/test_filters.py new file mode 100644 index 000000000..8b2dc8ec9 --- /dev/null +++ b/integrations/pgvector/tests/test_filters.py @@ -0,0 +1,179 @@ +from typing import List + +import pytest +from haystack.dataclasses.document import Document +from haystack.testing.document_store import FilterDocumentsTest +from haystack_integrations.document_stores.pgvector.filters import ( + FilterError, + _convert_filters_to_where_clause_and_params, + _parse_comparison_condition, + _parse_logical_condition, + _treat_meta_field, +) +from pandas import DataFrame +from psycopg.sql import SQL +from psycopg.types.json import Jsonb + + +class TestFilters(FilterDocumentsTest): + def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): + """ + This overrides the default assert_documents_are_equal from FilterDocumentsTest. + It is needed because the embeddings are not exactly the same when they are retrieved from Postgres. 
+ """ + + assert len(received) == len(expected) + received.sort(key=lambda x: x.id) + expected.sort(key=lambda x: x.id) + for received_doc, expected_doc in zip(received, expected): + # we first compare the embeddings approximately + if received_doc.embedding is None: + assert expected_doc.embedding is None + else: + assert received_doc.embedding == pytest.approx(expected_doc.embedding) + + received_doc.embedding, expected_doc.embedding = None, None + assert received_doc == expected_doc + + def test_complex_filter(self, document_store, filterable_docs): + document_store.write_documents(filterable_docs) + filters = { + "operator": "OR", + "conditions": [ + { + "operator": "AND", + "conditions": [ + {"field": "meta.number", "operator": "==", "value": 100}, + {"field": "meta.chapter", "operator": "==", "value": "intro"}, + ], + }, + { + "operator": "AND", + "conditions": [ + {"field": "meta.page", "operator": "==", "value": "90"}, + {"field": "meta.chapter", "operator": "==", "value": "conclusion"}, + ], + }, + ], + } + + result = document_store.filter_documents(filters=filters) + + self.assert_documents_are_equal( + result, + [ + d + for d in filterable_docs + if (d.meta.get("number") == 100 and d.meta.get("chapter") == "intro") + or (d.meta.get("page") == "90" and d.meta.get("chapter") == "conclusion") + ], + ) + + @pytest.mark.skip(reason="NOT operator is not supported in PgvectorDocumentStore") + def test_not_operator(self, document_store, filterable_docs): ... 
+ + def test_treat_meta_field(self): + assert _treat_meta_field(field="meta.number", value=9) == "(meta->>'number')::integer" + assert _treat_meta_field(field="meta.number", value=[1, 2, 3]) == "(meta->>'number')::integer" + assert _treat_meta_field(field="meta.name", value="my_name") == "meta->>'name'" + assert _treat_meta_field(field="meta.name", value=["my_name"]) == "meta->>'name'" + assert _treat_meta_field(field="meta.number", value=1.1) == "(meta->>'number')::real" + assert _treat_meta_field(field="meta.number", value=[1.1, 2.2, 3.3]) == "(meta->>'number')::real" + assert _treat_meta_field(field="meta.bool", value=True) == "(meta->>'bool')::boolean" + assert _treat_meta_field(field="meta.bool", value=[True, False, True]) == "(meta->>'bool')::boolean" + + # do not cast the field if its value is not one of the known types, an empty list or None + assert _treat_meta_field(field="meta.other", value={"a": 3, "b": "example"}) == "meta->>'other'" + assert _treat_meta_field(field="meta.empty_list", value=[]) == "meta->>'empty_list'" + assert _treat_meta_field(field="meta.name", value=None) == "meta->>'name'" + + def test_comparison_condition_dataframe_jsonb_conversion(self): + dataframe = DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}) + condition = {"field": "meta.df", "operator": "==", "value": dataframe} + field, values = _parse_comparison_condition(condition) + assert field == "(meta.df)::jsonb = %s" + + # we check each slot of the Jsonb object because it does not implement __eq__ + assert values[0].obj == Jsonb(dataframe.to_json()).obj + assert values[0].dumps == Jsonb(dataframe.to_json()).dumps + + def test_comparison_condition_missing_operator(self): + condition = {"field": "meta.type", "value": "article"} + with pytest.raises(FilterError): + _parse_comparison_condition(condition) + + def test_comparison_condition_missing_value(self): + condition = {"field": "meta.type", "operator": "=="} + with pytest.raises(FilterError): + 
_parse_comparison_condition(condition) + + def test_comparison_condition_unknown_operator(self): + condition = {"field": "meta.type", "operator": "unknown", "value": "article"} + with pytest.raises(FilterError): + _parse_comparison_condition(condition) + + def test_logical_condition_missing_operator(self): + condition = {"conditions": []} + with pytest.raises(FilterError): + _parse_logical_condition(condition) + + def test_logical_condition_missing_conditions(self): + condition = {"operator": "AND"} + with pytest.raises(FilterError): + _parse_logical_condition(condition) + + def test_logical_condition_unknown_operator(self): + condition = {"operator": "unknown", "conditions": []} + with pytest.raises(FilterError): + _parse_logical_condition(condition) + + def test_logical_condition_nested(self): + condition = { + "operator": "AND", + "conditions": [ + { + "operator": "OR", + "conditions": [ + {"field": "meta.domain", "operator": "!=", "value": "science"}, + {"field": "meta.chapter", "operator": "in", "value": ["intro", "conclusion"]}, + ], + }, + { + "operator": "OR", + "conditions": [ + {"field": "meta.number", "operator": ">=", "value": 90}, + {"field": "meta.author", "operator": "not in", "value": ["John", "Jane"]}, + ], + }, + ], + } + query, values = _parse_logical_condition(condition) + assert query == ( + "((meta->>'domain' IS DISTINCT FROM %s OR meta->>'chapter' = ANY(%s)) " + "AND ((meta->>'number')::integer >= %s OR meta->>'author' IS NULL OR meta->>'author' != ALL(%s)))" + ) + assert values == ["science", [["intro", "conclusion"]], 90, [["John", "Jane"]]] + + def test_convert_filters_to_where_clause_and_params(self): + filters = { + "operator": "AND", + "conditions": [ + {"field": "meta.number", "operator": "==", "value": 100}, + {"field": "meta.chapter", "operator": "==", "value": "intro"}, + ], + } + where_clause, params = _convert_filters_to_where_clause_and_params(filters) + assert where_clause == SQL(" WHERE ") + SQL("((meta->>'number')::integer = 
%s AND meta->>'chapter' = %s)") + assert params == (100, "intro") + + def test_convert_filters_to_where_clause_and_params_handle_null(self): + filters = { + "operator": "AND", + "conditions": [ + {"field": "meta.number", "operator": "==", "value": None}, + {"field": "meta.chapter", "operator": "==", "value": "intro"}, + ], + } + where_clause, params = _convert_filters_to_where_clause_and_params(filters) + assert where_clause == SQL(" WHERE ") + SQL("(meta->>'number' IS NULL AND meta->>'chapter' = %s)") + assert params == ("intro",) diff --git a/integrations/pgvector/tests/test_retriever.py b/integrations/pgvector/tests/test_retriever.py new file mode 100644 index 000000000..cca6bbc9f --- /dev/null +++ b/integrations/pgvector/tests/test_retriever.py @@ -0,0 +1,112 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from unittest.mock import Mock + +from haystack.dataclasses import Document +from haystack_integrations.components.retrievers.pgvector import PgvectorEmbeddingRetriever +from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore + + +class TestRetriever: + def test_init_default(self, document_store: PgvectorDocumentStore): + retriever = PgvectorEmbeddingRetriever(document_store=document_store) + assert retriever.document_store == document_store + assert retriever.filters == {} + assert retriever.top_k == 10 + assert retriever.vector_function == document_store.vector_function + + def test_init(self, document_store: PgvectorDocumentStore): + retriever = PgvectorEmbeddingRetriever( + document_store=document_store, filters={"field": "value"}, top_k=5, vector_function="l2_distance" + ) + assert retriever.document_store == document_store + assert retriever.filters == {"field": "value"} + assert retriever.top_k == 5 + assert retriever.vector_function == "l2_distance" + + def test_to_dict(self, document_store: PgvectorDocumentStore): + retriever = PgvectorEmbeddingRetriever( + 
document_store=document_store, filters={"field": "value"}, top_k=5, vector_function="l2_distance" + ) + res = retriever.to_dict() + t = "haystack_integrations.components.retrievers.pgvector.embedding_retriever.PgvectorEmbeddingRetriever" + assert res == { + "type": t, + "init_parameters": { + "document_store": { + "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", + "init_parameters": { + "connection_string": "postgresql://postgres:postgres@localhost:5432/postgres", + "table_name": "haystack_test_to_dict", + "embedding_dimension": 768, + "vector_function": "cosine_similarity", + "recreate_table": True, + "search_strategy": "exact_nearest_neighbor", + "hnsw_recreate_index_if_exists": False, + "hnsw_index_creation_kwargs": {}, + "hnsw_ef_search": None, + }, + }, + "filters": {"field": "value"}, + "top_k": 5, + "vector_function": "l2_distance", + }, + } + + def test_from_dict(self): + t = "haystack_integrations.components.retrievers.pgvector.embedding_retriever.PgvectorEmbeddingRetriever" + data = { + "type": t, + "init_parameters": { + "document_store": { + "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", + "init_parameters": { + "connection_string": "postgresql://postgres:postgres@localhost:5432/postgres", + "table_name": "haystack_test_to_dict", + "embedding_dimension": 768, + "vector_function": "cosine_similarity", + "recreate_table": True, + "search_strategy": "exact_nearest_neighbor", + "hnsw_recreate_index_if_exists": False, + "hnsw_index_creation_kwargs": {}, + "hnsw_ef_search": None, + }, + }, + "filters": {"field": "value"}, + "top_k": 5, + "vector_function": "l2_distance", + }, + } + + retriever = PgvectorEmbeddingRetriever.from_dict(data) + document_store = retriever.document_store + + assert isinstance(document_store, PgvectorDocumentStore) + assert document_store.connection_string == "postgresql://postgres:postgres@localhost:5432/postgres" + assert 
document_store.table_name == "haystack_test_to_dict" + assert document_store.embedding_dimension == 768 + assert document_store.vector_function == "cosine_similarity" + assert document_store.recreate_table + assert document_store.search_strategy == "exact_nearest_neighbor" + assert not document_store.hnsw_recreate_index_if_exists + assert document_store.hnsw_index_creation_kwargs == {} + assert document_store.hnsw_ef_search is None + + assert retriever.filters == {"field": "value"} + assert retriever.top_k == 5 + assert retriever.vector_function == "l2_distance" + + def test_run(self): + mock_store = Mock(spec=PgvectorDocumentStore) + doc = Document(content="Test doc", embedding=[0.1, 0.2]) + mock_store._embedding_retrieval.return_value = [doc] + + retriever = PgvectorEmbeddingRetriever(document_store=mock_store, vector_function="l2_distance") + res = retriever.run(query_embedding=[0.3, 0.5]) + + mock_store._embedding_retrieval.assert_called_once_with( + query_embedding=[0.3, 0.5], filters={}, top_k=10, vector_function="l2_distance" + ) + + assert res == {"documents": [doc]} diff --git a/integrations/pinecone/pyproject.toml b/integrations/pinecone/pyproject.toml index 5ada5669e..c95ee0aac 100644 --- a/integrations/pinecone/pyproject.toml +++ b/integrations/pinecone/pyproject.toml @@ -34,6 +34,9 @@ Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/m Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/pinecone" +[tool.hatch.build.targets.wheel] +packages = ["src/haystack_integrations"] + [tool.hatch.version] source = "vcs" tag-pattern = 'integrations\/pinecone-v(?P.*)' @@ -51,8 +54,8 @@ dependencies = [ [tool.hatch.envs.default.scripts] # Pinecone tests are slow (require HTTP requests), so we run them in parallel # with pytest-xdist (https://pytest-xdist.readthedocs.io/en/stable/distribution.html) -test = "pytest -n auto 
--maxprocesses=3 {args:tests}" -test-cov = "coverage run -m pytest -n auto --maxprocesses=3 {args:tests}" +test = "pytest -n auto --maxprocesses=2 {args:tests}" +test-cov = "coverage run -m pytest -n auto --maxprocesses=2 {args:tests}" cov-report = [ "- coverage combine", "coverage report", @@ -74,7 +77,7 @@ dependencies = [ "numpy", ] [tool.hatch.envs.lint.scripts] -typing = "mypy --install-types --non-interactive {args:src/pinecone_haystack tests}" +typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = [ "ruff {args:.}", "black --check --diff {args:.}", @@ -143,26 +146,26 @@ unfixable = [ ] [tool.ruff.isort] -known-first-party = ["pinecone_haystack"] +known-first-party = ["haystack_integrations"] [tool.ruff.flake8-tidy-imports] -ban-relative-imports = "all" +ban-relative-imports = "parents" [tool.ruff.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] [tool.coverage.run] -source_pkgs = ["pinecone_haystack", "tests"] +source_pkgs = ["src", "tests"] branch = true parallel = true omit = [ - "example" + "examples" ] [tool.coverage.paths] -pinecone_haystack = ["src/pinecone_haystack", "*/pinecone_haystack/src/pinecone_haystack"] -tests = ["tests", "*/pinecone_haystack/tests"] +pinecone_haystack = ["src/*"] +tests = ["tests"] [tool.coverage.report] exclude_lines = [ @@ -182,6 +185,7 @@ markers = [ module = [ "pinecone.*", "haystack.*", + "haystack_integrations.*", "pytest.*" ] ignore_missing_imports = true diff --git a/integrations/pinecone/src/haystack_integrations/components/retrievers/pinecone/__init__.py b/integrations/pinecone/src/haystack_integrations/components/retrievers/pinecone/__init__.py new file mode 100644 index 000000000..d73d799d4 --- /dev/null +++ b/integrations/pinecone/src/haystack_integrations/components/retrievers/pinecone/__init__.py @@ -0,0 +1,3 @@ +from .dense_retriever import PineconeDenseRetriever + +__all__ = 
["PineconeDenseRetriever"] diff --git a/integrations/pinecone/src/pinecone_haystack/dense_retriever.py b/integrations/pinecone/src/haystack_integrations/components/retrievers/pinecone/dense_retriever.py similarity index 96% rename from integrations/pinecone/src/pinecone_haystack/dense_retriever.py rename to integrations/pinecone/src/haystack_integrations/components/retrievers/pinecone/dense_retriever.py index 3f60f252b..279ef4977 100644 --- a/integrations/pinecone/src/pinecone_haystack/dense_retriever.py +++ b/integrations/pinecone/src/haystack_integrations/components/retrievers/pinecone/dense_retriever.py @@ -6,7 +6,7 @@ from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import Document -from pinecone_haystack.document_store import PineconeDocumentStore +from haystack_integrations.document_stores.pinecone import PineconeDocumentStore @component diff --git a/integrations/instructor_embedders/instructor_embedders_haystack/embedding_backend/__init__.py b/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/__init__.py similarity index 55% rename from integrations/instructor_embedders/instructor_embedders_haystack/embedding_backend/__init__.py rename to integrations/pinecone/src/haystack_integrations/document_stores/pinecone/__init__.py index e873bc332..159a85fae 100644 --- a/integrations/instructor_embedders/instructor_embedders_haystack/embedding_backend/__init__.py +++ b/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/__init__.py @@ -1,3 +1,6 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +from .document_store import PineconeDocumentStore + +__all__ = ["PineconeDocumentStore"] diff --git a/integrations/pinecone/src/pinecone_haystack/document_store.py b/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py similarity index 99% rename from 
integrations/pinecone/src/pinecone_haystack/document_store.py rename to integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py index 8fe579611..92ea987b4 100644 --- a/integrations/pinecone/src/pinecone_haystack/document_store.py +++ b/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py @@ -8,13 +8,14 @@ from typing import Any, Dict, List, Optional import pandas as pd -import pinecone from haystack import default_to_dict from haystack.dataclasses import Document from haystack.document_stores.types import DuplicatePolicy from haystack.utils.filters import convert -from pinecone_haystack.filters import _normalize_filters +import pinecone + +from .filters import _normalize_filters logger = logging.getLogger(__name__) @@ -84,7 +85,7 @@ def __init__( ) self.dimension = actual_dimension or dimension - self._dummy_vector = [0.0] * self.dimension + self._dummy_vector = [-10.0] * self.dimension self.environment = environment self.index = index self.namespace = namespace diff --git a/integrations/pinecone/src/pinecone_haystack/errors.py b/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/errors.py similarity index 100% rename from integrations/pinecone/src/pinecone_haystack/errors.py rename to integrations/pinecone/src/haystack_integrations/document_stores/pinecone/errors.py diff --git a/integrations/pinecone/src/pinecone_haystack/filters.py b/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/filters.py similarity index 100% rename from integrations/pinecone/src/pinecone_haystack/filters.py rename to integrations/pinecone/src/haystack_integrations/document_stores/pinecone/filters.py diff --git a/integrations/pinecone/src/pinecone_haystack/__init__.py b/integrations/pinecone/src/pinecone_haystack/__init__.py deleted file mode 100644 index e3ec258d2..000000000 --- a/integrations/pinecone/src/pinecone_haystack/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# 
SPDX-FileCopyrightText: 2023-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 -from pinecone_haystack.dense_retriever import PineconeDenseRetriever -from pinecone_haystack.document_store import PineconeDocumentStore - -__all__ = ["PineconeDocumentStore", "PineconeDenseRetriever"] diff --git a/integrations/pinecone/tests/conftest.py b/integrations/pinecone/tests/conftest.py index 3ae642ae7..c7a1342d5 100644 --- a/integrations/pinecone/tests/conftest.py +++ b/integrations/pinecone/tests/conftest.py @@ -3,7 +3,7 @@ import pytest from haystack.document_stores.types import DuplicatePolicy -from pinecone_haystack.document_store import PineconeDocumentStore +from haystack_integrations.document_stores.pinecone import PineconeDocumentStore # This is the approximate time it takes for the documents to be available SLEEP_TIME = 20 diff --git a/integrations/pinecone/tests/test_dense_retriever.py b/integrations/pinecone/tests/test_dense_retriever.py index ceb73b687..e0f6dc375 100644 --- a/integrations/pinecone/tests/test_dense_retriever.py +++ b/integrations/pinecone/tests/test_dense_retriever.py @@ -5,8 +5,8 @@ from haystack.dataclasses import Document -from pinecone_haystack.dense_retriever import PineconeDenseRetriever -from pinecone_haystack.document_store import PineconeDocumentStore +from haystack_integrations.components.retrievers.pinecone import PineconeDenseRetriever +from haystack_integrations.document_stores.pinecone import PineconeDocumentStore def test_init_default(): @@ -17,7 +17,7 @@ def test_init_default(): assert retriever.top_k == 10 -@patch("pinecone_haystack.document_store.pinecone") +@patch("haystack_integrations.document_stores.pinecone.document_store.pinecone") def test_to_dict(mock_pinecone): mock_pinecone.Index.return_value.describe_index_stats.return_value = {"dimension": 512} document_store = PineconeDocumentStore( @@ -31,7 +31,7 @@ def test_to_dict(mock_pinecone): retriever = PineconeDenseRetriever(document_store=document_store) res = 
retriever.to_dict() assert res == { - "type": "pinecone_haystack.dense_retriever.PineconeDenseRetriever", + "type": "haystack_integrations.components.retrievers.pinecone.dense_retriever.PineconeDenseRetriever", "init_parameters": { "document_store": { "init_parameters": { @@ -41,7 +41,7 @@ def test_to_dict(mock_pinecone): "batch_size": 50, "dimension": 512, }, - "type": "pinecone_haystack.document_store.PineconeDocumentStore", + "type": "haystack_integrations.document_stores.pinecone.document_store.PineconeDocumentStore", }, "filters": {}, "top_k": 10, @@ -49,10 +49,10 @@ def test_to_dict(mock_pinecone): } -@patch("pinecone_haystack.document_store.pinecone") +@patch("haystack_integrations.document_stores.pinecone.document_store.pinecone") def test_from_dict(mock_pinecone, monkeypatch): data = { - "type": "pinecone_haystack.dense_retriever.PineconeDenseRetriever", + "type": "haystack_integrations.components.retrievers.pinecone.dense_retriever.PineconeDenseRetriever", "init_parameters": { "document_store": { "init_parameters": { @@ -62,7 +62,7 @@ def test_from_dict(mock_pinecone, monkeypatch): "batch_size": 50, "dimension": 512, }, - "type": "pinecone_haystack.document_store.PineconeDocumentStore", + "type": "haystack_integrations.document_stores.pinecone.document_store.PineconeDocumentStore", }, "filters": {}, "top_k": 10, diff --git a/integrations/pinecone/tests/test_document_store.py b/integrations/pinecone/tests/test_document_store.py index 5c9b32698..cd1bb0db3 100644 --- a/integrations/pinecone/tests/test_document_store.py +++ b/integrations/pinecone/tests/test_document_store.py @@ -5,59 +5,85 @@ from haystack import Document from haystack.testing.document_store import CountDocumentsTest, DeleteDocumentsTest, WriteDocumentsTest -from pinecone_haystack.document_store import PineconeDocumentStore - - +from haystack_integrations.document_stores.pinecone import PineconeDocumentStore + + 
+@patch("haystack_integrations.document_stores.pinecone.document_store.pinecone") +def test_init(mock_pinecone): + mock_pinecone.Index.return_value.describe_index_stats.return_value = {"dimension": 30} + + document_store = PineconeDocumentStore( + api_key="fake-api-key", + environment="gcp-starter", + index="my_index", + namespace="test", + batch_size=50, + dimension=30, + metric="euclidean", + ) + + mock_pinecone.init.assert_called_with(api_key="fake-api-key", environment="gcp-starter") + + assert document_store.environment == "gcp-starter" + assert document_store.index == "my_index" + assert document_store.namespace == "test" + assert document_store.batch_size == 50 + assert document_store.dimension == 30 + assert document_store.index_creation_kwargs == {"metric": "euclidean"} + + +@patch("haystack_integrations.document_stores.pinecone.document_store.pinecone") +def test_init_api_key_in_environment_variable(mock_pinecone, monkeypatch): + monkeypatch.setenv("PINECONE_API_KEY", "fake-api-key") + + PineconeDocumentStore( + environment="gcp-starter", + index="my_index", + namespace="test", + batch_size=50, + dimension=30, + metric="euclidean", + ) + + mock_pinecone.init.assert_called_with(api_key="fake-api-key", environment="gcp-starter") + + +@patch("haystack_integrations.document_stores.pinecone.document_store.pinecone") +def test_to_dict(mock_pinecone): + mock_pinecone.Index.return_value.describe_index_stats.return_value = {"dimension": 30} + document_store = PineconeDocumentStore( + api_key="fake-api-key", + environment="gcp-starter", + index="my_index", + namespace="test", + batch_size=50, + dimension=30, + metric="euclidean", + ) + assert document_store.to_dict() == { + "type": "haystack_integrations.document_stores.pinecone.document_store.PineconeDocumentStore", + "init_parameters": { + "environment": "gcp-starter", + "index": "my_index", + "dimension": 30, + "namespace": "test", + "batch_size": 50, + "metric": "euclidean", + }, + } + + 
+@pytest.mark.integration class TestDocumentStore(CountDocumentsTest, DeleteDocumentsTest, WriteDocumentsTest): def test_write_documents(self, document_store: PineconeDocumentStore): docs = [Document(id="1")] assert document_store.write_documents(docs) == 1 @pytest.mark.skip(reason="Pinecone only supports UPSERT operations") - def test_write_documents_duplicate_fail(self, document_store: PineconeDocumentStore): - ... + def test_write_documents_duplicate_fail(self, document_store: PineconeDocumentStore): ... @pytest.mark.skip(reason="Pinecone only supports UPSERT operations") - def test_write_documents_duplicate_skip(self, document_store: PineconeDocumentStore): - ... - - @patch("pinecone_haystack.document_store.pinecone") - def test_init(self, mock_pinecone): - mock_pinecone.Index.return_value.describe_index_stats.return_value = {"dimension": 30} - - document_store = PineconeDocumentStore( - api_key="fake-api-key", - environment="gcp-starter", - index="my_index", - namespace="test", - batch_size=50, - dimension=30, - metric="euclidean", - ) - - mock_pinecone.init.assert_called_with(api_key="fake-api-key", environment="gcp-starter") - - assert document_store.environment == "gcp-starter" - assert document_store.index == "my_index" - assert document_store.namespace == "test" - assert document_store.batch_size == 50 - assert document_store.dimension == 30 - assert document_store.index_creation_kwargs == {"metric": "euclidean"} - - @patch("pinecone_haystack.document_store.pinecone") - def test_init_api_key_in_environment_variable(self, mock_pinecone, monkeypatch): - monkeypatch.setenv("PINECONE_API_KEY", "fake-api-key") - - PineconeDocumentStore( - environment="gcp-starter", - index="my_index", - namespace="test", - batch_size=50, - dimension=30, - metric="euclidean", - ) - - mock_pinecone.init.assert_called_with(api_key="fake-api-key", environment="gcp-starter") + def test_write_documents_duplicate_skip(self, document_store: PineconeDocumentStore): ... 
def test_init_fails_wo_api_key(self, monkeypatch): api_key = None @@ -69,30 +95,6 @@ def test_init_fails_wo_api_key(self, monkeypatch): index="my_index", ) - @patch("pinecone_haystack.document_store.pinecone") - def test_to_dict(self, mock_pinecone): - mock_pinecone.Index.return_value.describe_index_stats.return_value = {"dimension": 30} - document_store = PineconeDocumentStore( - api_key="fake-api-key", - environment="gcp-starter", - index="my_index", - namespace="test", - batch_size=50, - dimension=30, - metric="euclidean", - ) - assert document_store.to_dict() == { - "type": "pinecone_haystack.document_store.PineconeDocumentStore", - "init_parameters": { - "environment": "gcp-starter", - "index": "my_index", - "dimension": 30, - "namespace": "test", - "batch_size": 50, - "metric": "euclidean", - }, - } - def test_embedding_retrieval(self, document_store: PineconeDocumentStore): query_embedding = [0.1] * 768 most_similar_embedding = [0.8] * 768 diff --git a/integrations/pinecone/tests/test_filters.py b/integrations/pinecone/tests/test_filters.py index 1e6aeb0cd..05796cf20 100644 --- a/integrations/pinecone/tests/test_filters.py +++ b/integrations/pinecone/tests/test_filters.py @@ -7,6 +7,7 @@ ) +@pytest.mark.integration class TestFilters(FilterDocumentsTest): def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): for doc in received: @@ -37,45 +38,34 @@ def assert_documents_are_equal(self, received: List[Document], expected: List[Do assert received_doc.embedding == pytest.approx(expected_doc.embedding) @pytest.mark.skip(reason="Pinecone does not support comparison with null values") - def test_comparison_equal_with_none(self, document_store, filterable_docs): - ... + def test_comparison_equal_with_none(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Pinecone does not support comparison with null values") - def test_comparison_not_equal_with_none(self, document_store, filterable_docs): - ... 
+ def test_comparison_not_equal_with_none(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Pinecone does not support comparison with dates") - def test_comparison_greater_than_with_iso_date(self, document_store, filterable_docs): - ... + def test_comparison_greater_than_with_iso_date(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Pinecone does not support comparison with null values") - def test_comparison_greater_than_with_none(self, document_store, filterable_docs): - ... + def test_comparison_greater_than_with_none(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Pinecone does not support comparison with dates") - def test_comparison_greater_than_equal_with_iso_date(self, document_store, filterable_docs): - ... + def test_comparison_greater_than_equal_with_iso_date(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Pinecone does not support comparison with null values") - def test_comparison_greater_than_equal_with_none(self, document_store, filterable_docs): - ... + def test_comparison_greater_than_equal_with_none(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Pinecone does not support comparison with dates") - def test_comparison_less_than_with_iso_date(self, document_store, filterable_docs): - ... + def test_comparison_less_than_with_iso_date(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Pinecone does not support comparison with null values") - def test_comparison_less_than_with_none(self, document_store, filterable_docs): - ... + def test_comparison_less_than_with_none(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Pinecone does not support comparison with dates") - def test_comparison_less_than_equal_with_iso_date(self, document_store, filterable_docs): - ... + def test_comparison_less_than_equal_with_iso_date(self, document_store, filterable_docs): ... 
@pytest.mark.skip(reason="Pinecone does not support comparison with null values") - def test_comparison_less_than_equal_with_none(self, document_store, filterable_docs): - ... + def test_comparison_less_than_equal_with_none(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Pinecone does not support the 'not' operator") - def test_not_operator(self, document_store, filterable_docs): - ... + def test_not_operator(self, document_store, filterable_docs): ... diff --git a/integrations/qdrant/pyproject.toml b/integrations/qdrant/pyproject.toml index d1086fcdf..9c19d144e 100644 --- a/integrations/qdrant/pyproject.toml +++ b/integrations/qdrant/pyproject.toml @@ -35,6 +35,9 @@ Source = "https://github.com/deepset-ai/haystack-core-integrations" Documentation = "https://github.com/deepset-ai/haystack-core-integrations/blob/main/integrations/qdrant/README.md" Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" +[tool.hatch.build.targets.wheel] +packages = ["src/haystack_integrations"] + [tool.hatch.version] source = "vcs" tag-pattern = 'integrations\/qdrant-v(?P.*)' @@ -71,7 +74,7 @@ dependencies = [ "ruff>=0.0.243", ] [tool.hatch.envs.lint.scripts] -typing = "mypy --install-types --non-interactive {args:src/qdrant_haystack tests}" +typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = [ "ruff {args:.}", "black --check --diff {args:.}", @@ -136,23 +139,18 @@ unfixable = [ "F401", ] -[tool.ruff.isort] -known-first-party = ["qdrant_haystack"] - [tool.ruff.flake8-tidy-imports] -ban-relative-imports = "all" +ban-relative-imports = "parents" [tool.ruff.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] [tool.coverage.run] -source_pkgs = ["qdrant_haystack", "tests"] +source_pkgs = ["src", "tests"] branch = true parallel = true -omit = [ - "src/qdrant_haystack/__about__.py", -] + [tool.coverage.paths] qdrant_haystack = 
["src/qdrant_haystack", "*/qdrant-haystack/src/qdrant_haystack"] @@ -168,6 +166,7 @@ exclude_lines = [ [[tool.mypy.overrides]] module = [ "haystack.*", + "haystack_integrations.*", "pytest.*", "qdrant_client.*", "numpy", diff --git a/integrations/google_ai/src/google_ai_haystack/generators/__init__.py b/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/__init__.py similarity index 55% rename from integrations/google_ai/src/google_ai_haystack/generators/__init__.py rename to integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/__init__.py index e873bc332..41b59e42d 100644 --- a/integrations/google_ai/src/google_ai_haystack/generators/__init__.py +++ b/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/__init__.py @@ -1,3 +1,7 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 + +from .retriever import QdrantEmbeddingRetriever + +__all__ = ("QdrantEmbeddingRetriever",) diff --git a/integrations/qdrant/src/qdrant_haystack/retriever.py b/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py similarity index 97% rename from integrations/qdrant/src/qdrant_haystack/retriever.py rename to integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py index bf378688c..e59dca3ad 100644 --- a/integrations/qdrant/src/qdrant_haystack/retriever.py +++ b/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py @@ -1,8 +1,7 @@ from typing import Any, Dict, List, Optional from haystack import Document, component, default_from_dict, default_to_dict - -from qdrant_haystack import QdrantDocumentStore +from haystack_integrations.document_stores.qdrant import QdrantDocumentStore @component diff --git a/integrations/qdrant/src/qdrant_haystack/__init__.py b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/__init__.py similarity index 70% rename from 
integrations/qdrant/src/qdrant_haystack/__init__.py rename to integrations/qdrant/src/haystack_integrations/document_stores/qdrant/__init__.py index 765ced0ef..dc3def997 100644 --- a/integrations/qdrant/src/qdrant_haystack/__init__.py +++ b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/__init__.py @@ -2,6 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 -from qdrant_haystack.document_store import QdrantDocumentStore +from .document_store import QdrantDocumentStore __all__ = ("QdrantDocumentStore",) diff --git a/integrations/qdrant/src/qdrant_haystack/converters.py b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/converters.py similarity index 100% rename from integrations/qdrant/src/qdrant_haystack/converters.py rename to integrations/qdrant/src/haystack_integrations/document_stores/qdrant/converters.py diff --git a/integrations/qdrant/src/qdrant_haystack/document_store.py b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py similarity index 99% rename from integrations/qdrant/src/qdrant_haystack/document_store.py rename to integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py index 4fd724f67..50dd0220c 100644 --- a/integrations/qdrant/src/qdrant_haystack/document_store.py +++ b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py @@ -16,8 +16,8 @@ from qdrant_client.http.exceptions import UnexpectedResponse from tqdm import tqdm -from qdrant_haystack.converters import HaystackToQdrant, QdrantToHaystack -from qdrant_haystack.filters import QdrantFilterConverter +from .converters import HaystackToQdrant, QdrantToHaystack +from .filters import QdrantFilterConverter logger = logging.getLogger(__name__) diff --git a/integrations/qdrant/src/qdrant_haystack/filters.py b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/filters.py similarity index 89% rename from 
integrations/qdrant/src/qdrant_haystack/filters.py rename to integrations/qdrant/src/haystack_integrations/document_stores/qdrant/filters.py index cc6b2b6a5..77d800853 100644 --- a/integrations/qdrant/src/qdrant_haystack/filters.py +++ b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/filters.py @@ -4,7 +4,7 @@ from haystack.utils.filters import COMPARISON_OPERATORS, LOGICAL_OPERATORS, FilterError from qdrant_client.http import models -from qdrant_haystack.converters import HaystackToQdrant +from .converters import HaystackToQdrant COMPARISON_OPERATORS = COMPARISON_OPERATORS.keys() LOGICAL_OPERATORS = LOGICAL_OPERATORS.keys() @@ -113,9 +113,11 @@ def _build_in_condition(self, key: str, value: List[models.ValueVariants]) -> mo raise FilterError(msg) return models.Filter( should=[ - models.FieldCondition(key=key, match=models.MatchText(text=item)) - if isinstance(item, str) and " " not in item - else models.FieldCondition(key=key, match=models.MatchValue(value=item)) + ( + models.FieldCondition(key=key, match=models.MatchText(text=item)) + if isinstance(item, str) and " " not in item + else models.FieldCondition(key=key, match=models.MatchValue(value=item)) + ) for item in value ] ) @@ -123,9 +125,11 @@ def _build_in_condition(self, key: str, value: List[models.ValueVariants]) -> mo def _build_ne_condition(self, key: str, value: models.ValueVariants) -> models.Condition: return models.Filter( must_not=[ - models.FieldCondition(key=key, match=models.MatchText(text=value)) - if isinstance(value, str) and " " not in value - else models.FieldCondition(key=key, match=models.MatchValue(value=value)) + ( + models.FieldCondition(key=key, match=models.MatchText(text=value)) + if isinstance(value, str) and " " not in value + else models.FieldCondition(key=key, match=models.MatchValue(value=value)) + ) ] ) @@ -135,9 +139,11 @@ def _build_nin_condition(self, key: str, value: List[models.ValueVariants]) -> m raise FilterError(msg) return models.Filter( 
must_not=[ - models.FieldCondition(key=key, match=models.MatchText(text=item)) - if isinstance(item, str) and " " not in item - else models.FieldCondition(key=key, match=models.MatchValue(value=item)) + ( + models.FieldCondition(key=key, match=models.MatchText(text=item)) + if isinstance(item, str) and " " not in item + else models.FieldCondition(key=key, match=models.MatchValue(value=item)) + ) for item in value ] ) diff --git a/integrations/qdrant/tests/test_converters.py b/integrations/qdrant/tests/test_converters.py index dc4866293..0c6c5676a 100644 --- a/integrations/qdrant/tests/test_converters.py +++ b/integrations/qdrant/tests/test_converters.py @@ -1,9 +1,8 @@ import numpy as np import pytest +from haystack_integrations.document_stores.qdrant.converters import HaystackToQdrant, QdrantToHaystack from qdrant_client.http import models as rest -from qdrant_haystack.converters import HaystackToQdrant, QdrantToHaystack - CONTENT_FIELD = "content" NAME_FIELD = "name" EMBEDDING_FIELD = "vector" diff --git a/integrations/qdrant/tests/test_dict_converters.py b/integrations/qdrant/tests/test_dict_converters.py index 1a211655c..1c9eb36e2 100644 --- a/integrations/qdrant/tests/test_dict_converters.py +++ b/integrations/qdrant/tests/test_dict_converters.py @@ -1,11 +1,11 @@ -from qdrant_haystack import QdrantDocumentStore +from haystack_integrations.document_stores.qdrant import QdrantDocumentStore def test_to_dict(): document_store = QdrantDocumentStore(location=":memory:", index="test") expected = { - "type": "qdrant_haystack.document_store.QdrantDocumentStore", + "type": "haystack_integrations.document_stores.qdrant.document_store.QdrantDocumentStore", "init_parameters": { "location": ":memory:", "url": None, @@ -50,7 +50,7 @@ def test_to_dict(): def test_from_dict(): document_store = QdrantDocumentStore.from_dict( { - "type": "qdrant_haystack.document_store.QdrantDocumentStore", + "type": 
"haystack_integrations.document_stores.qdrant.document_store.QdrantDocumentStore", "init_parameters": { "location": ":memory:", "index": "test", diff --git a/integrations/qdrant/tests/test_document_store.py b/integrations/qdrant/tests/test_document_store.py index 0118ae0cf..8316ee565 100644 --- a/integrations/qdrant/tests/test_document_store.py +++ b/integrations/qdrant/tests/test_document_store.py @@ -9,8 +9,7 @@ DeleteDocumentsTest, WriteDocumentsTest, ) - -from qdrant_haystack import QdrantDocumentStore +from haystack_integrations.document_stores.qdrant import QdrantDocumentStore class TestQdrantStoreBaseTests(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest): diff --git a/integrations/qdrant/tests/test_filters.py b/integrations/qdrant/tests/test_filters.py index a25f4a672..74bac76ad 100644 --- a/integrations/qdrant/tests/test_filters.py +++ b/integrations/qdrant/tests/test_filters.py @@ -4,8 +4,7 @@ from haystack import Document from haystack.testing.document_store import FilterDocumentsTest from haystack.utils.filters import FilterError - -from qdrant_haystack import QdrantDocumentStore +from haystack_integrations.document_stores.qdrant import QdrantDocumentStore class TestQdrantStoreBaseTests(FilterDocumentsTest): @@ -87,29 +86,22 @@ def test_comparison_less_than_equal_with_none(self, document_store, filterable_d # ======== ========================== ======== @pytest.mark.skip(reason="Qdrant doesn't support comparision with dataframe") - def test_comparison_equal_with_dataframe(self, document_store, filterable_docs): - ... + def test_comparison_equal_with_dataframe(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Qdrant doesn't support comparision with dataframe") - def test_comparison_not_equal_with_dataframe(self, document_store, filterable_docs): - ... + def test_comparison_not_equal_with_dataframe(self, document_store, filterable_docs): ... 
@pytest.mark.skip(reason="Qdrant doesn't support comparision with Dates") - def test_comparison_greater_than_with_iso_date(self, document_store, filterable_docs): - ... + def test_comparison_greater_than_with_iso_date(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Qdrant doesn't support comparision with Dates") - def test_comparison_greater_than_equal_with_iso_date(self, document_store, filterable_docs): - ... + def test_comparison_greater_than_equal_with_iso_date(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Qdrant doesn't support comparision with Dates") - def test_comparison_less_than_with_iso_date(self, document_store, filterable_docs): - ... + def test_comparison_less_than_with_iso_date(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Qdrant doesn't support comparision with Dates") - def test_comparison_less_than_equal_with_iso_date(self, document_store, filterable_docs): - ... + def test_comparison_less_than_equal_with_iso_date(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Cannot distinguish errors yet") - def test_missing_top_level_operator_key(self, document_store, filterable_docs): - ... + def test_missing_top_level_operator_key(self, document_store, filterable_docs): ... 
diff --git a/integrations/qdrant/tests/test_legacy_filters.py b/integrations/qdrant/tests/test_legacy_filters.py index 957603423..60f1fad2b 100644 --- a/integrations/qdrant/tests/test_legacy_filters.py +++ b/integrations/qdrant/tests/test_legacy_filters.py @@ -5,8 +5,7 @@ from haystack.document_stores.types import DocumentStore from haystack.testing.document_store import LegacyFilterDocumentsTest from haystack.utils.filters import FilterError - -from qdrant_haystack import QdrantDocumentStore +from haystack_integrations.document_stores.qdrant import QdrantDocumentStore # The tests below are from haystack.testing.document_store.LegacyFilterDocumentsTest # Updated to include `meta` prefix for filter keys wherever necessary @@ -45,8 +44,7 @@ def test_filter_simple_metadata_value(self, document_store: DocumentStore, filte self.assert_documents_are_equal(result, [doc for doc in filterable_docs if doc.meta.get("page") == "100"]) @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") - def test_filter_document_dataframe(self, document_store: DocumentStore, filterable_docs: List[Document]): - ... + def test_filter_document_dataframe(self, document_store: DocumentStore, filterable_docs: List[Document]): ... def test_eq_filter_explicit(self, document_store: DocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) @@ -59,12 +57,10 @@ def test_eq_filter_implicit(self, document_store: DocumentStore, filterable_docs self.assert_documents_are_equal(result, [doc for doc in filterable_docs if doc.meta.get("page") == "100"]) @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") - def test_eq_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): - ... + def test_eq_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): ... 
@pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") - def test_eq_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): - ... + def test_eq_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): ... # LegacyFilterDocumentsNotEqualTest @@ -74,12 +70,10 @@ def test_ne_filter(self, document_store: DocumentStore, filterable_docs: List[Do self.assert_documents_are_equal(result, [doc for doc in filterable_docs if doc.meta.get("page") != "100"]) @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") - def test_ne_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): - ... + def test_ne_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): ... @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") - def test_ne_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): - ... + def test_ne_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): ... # LegacyFilterDocumentsInTest @@ -123,22 +117,18 @@ def test_in_filter_implicit(self, document_store: DocumentStore, filterable_docs ) @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") - def test_in_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): - ... + def test_in_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): ... @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") - def test_in_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): - ... + def test_in_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): ... 
# LegacyFilterDocumentsNotInTest @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") - def test_nin_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): - ... + def test_nin_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): ... @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") - def test_nin_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): - ... + def test_nin_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): ... def test_nin_filter(self, document_store: DocumentStore, filterable_docs: List[Document]): document_store.write_documents(filterable_docs) @@ -164,12 +154,10 @@ def test_gt_filter_non_numeric(self, document_store: DocumentStore, filterable_d document_store.filter_documents(filters={"meta.page": {"$gt": "100"}}) @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") - def test_gt_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): - ... + def test_gt_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): ... @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") - def test_gt_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): - ... + def test_gt_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): ... # LegacyFilterDocumentsGreaterThanEqualTest @@ -187,12 +175,10 @@ def test_gte_filter_non_numeric(self, document_store: DocumentStore, filterable_ document_store.filter_documents(filters={"meta.page": {"$gte": "100"}}) @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") - def test_gte_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): - ... + def test_gte_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): ... 
@pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") - def test_gte_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): - ... + def test_gte_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): ... # LegacyFilterDocumentsLessThanTest @@ -210,12 +196,10 @@ def test_lt_filter_non_numeric(self, document_store: DocumentStore, filterable_d document_store.filter_documents(filters={"meta.page": {"$lt": "100"}}) @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") - def test_lt_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): - ... + def test_lt_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): ... @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") - def test_lt_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): - ... + def test_lt_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): ... # LegacyFilterDocumentsLessThanEqualTest @@ -233,12 +217,10 @@ def test_lte_filter_non_numeric(self, document_store: DocumentStore, filterable_ document_store.filter_documents(filters={"meta.page": {"$lte": "100"}}) @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") - def test_lte_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): - ... + def test_lte_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): ... @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") - def test_lte_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): - ... + def test_lte_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): ... 
# LegacyFilterDocumentsSimpleLogicalTest diff --git a/integrations/qdrant/tests/test_retriever.py b/integrations/qdrant/tests/test_retriever.py index ed220c5bc..7521642ff 100644 --- a/integrations/qdrant/tests/test_retriever.py +++ b/integrations/qdrant/tests/test_retriever.py @@ -5,9 +5,8 @@ FilterableDocsFixtureMixin, _random_embeddings, ) - -from qdrant_haystack import QdrantDocumentStore -from qdrant_haystack.retriever import QdrantEmbeddingRetriever +from haystack_integrations.components.retrievers.qdrant import QdrantEmbeddingRetriever +from haystack_integrations.document_stores.qdrant import QdrantDocumentStore class TestQdrantRetriever(FilterableDocsFixtureMixin): @@ -24,10 +23,10 @@ def test_to_dict(self): retriever = QdrantEmbeddingRetriever(document_store=document_store) res = retriever.to_dict() assert res == { - "type": "qdrant_haystack.retriever.QdrantEmbeddingRetriever", + "type": "haystack_integrations.components.retrievers.qdrant.retriever.QdrantEmbeddingRetriever", "init_parameters": { "document_store": { - "type": "qdrant_haystack.document_store.QdrantDocumentStore", + "type": "haystack_integrations.document_stores.qdrant.document_store.QdrantDocumentStore", "init_parameters": { "location": ":memory:", "url": None, @@ -74,11 +73,11 @@ def test_to_dict(self): def test_from_dict(self): data = { - "type": "qdrant_haystack.retriever.QdrantEmbeddingRetriever", + "type": "haystack_integrations.components.retrievers.qdrant.retriever.QdrantEmbeddingRetriever", "init_parameters": { "document_store": { "init_parameters": {"location": ":memory:", "index": "test"}, - "type": "qdrant_haystack.document_store.QdrantDocumentStore", + "type": "haystack_integrations.document_stores.qdrant.document_store.QdrantDocumentStore", }, "filters": None, "top_k": 5, diff --git a/integrations/unstructured/pyproject.toml b/integrations/unstructured/pyproject.toml index e199b3c3e..9cc2a0c6a 100644 --- a/integrations/unstructured/pyproject.toml +++ 
b/integrations/unstructured/pyproject.toml @@ -25,7 +25,7 @@ classifiers = [ ] dependencies = [ "haystack-ai", - "unstructured<0.11.4", # FIXME: investigate why 0.11.4 broke the tests + "unstructured", ] [project.urls] diff --git a/integrations/unstructured/tests/test_converter.py b/integrations/unstructured/tests/test_converter.py index 038807b14..ca590ab2f 100644 --- a/integrations/unstructured/tests/test_converter.py +++ b/integrations/unstructured/tests/test_converter.py @@ -188,6 +188,7 @@ def test_run_one_doc_per_element_with_meta_list_folder_fail(self, samples_path): def test_run_one_doc_per_element_with_meta_list_folder(self, samples_path): pdf_path = [samples_path] meta = {"common_meta": "common"} + local_converter = UnstructuredFileConverter( api_url="http://localhost:8000/general/v0/general", document_creation_mode="one-doc-per-element" ) diff --git a/integrations/uptrain/LICENSE.txt b/integrations/uptrain/LICENSE.txt new file mode 100644 index 000000000..137069b82 --- /dev/null +++ b/integrations/uptrain/LICENSE.txt @@ -0,0 +1,73 @@ +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + +"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. 
+ +"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. + +"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. + +"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." + +"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: + + (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. + + You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + +To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
+ +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/integrations/uptrain/README.md b/integrations/uptrain/README.md new file mode 100644 index 000000000..6d7605306 --- /dev/null +++ b/integrations/uptrain/README.md @@ -0,0 +1,36 @@ +# uptrain-haystack + +[![PyPI - Version](https://img.shields.io/pypi/v/uptrain-haystack.svg)](https://pypi.org/project/uptrain-haystack) +[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/uptrain-haystack.svg)](https://pypi.org/project/uptrain-haystack) + +--- + +**Table of Contents** + +- [uptrain-haystack](#uptrain-haystack) + - [Installation](#installation) + - [Testing](#testing) + - [Examples](#examples) + - [License](#license) + +## Installation + +```console +pip install uptrain-haystack +``` + +For more information about the UpTrain evaluation framework, please refer to their [documentation](https://docs.uptrain.ai/getting-started/introduction). + +## Testing + +```console +hatch run test +``` + +## Examples + +You can find a code example showing how to use the Evaluator under the `example/` folder of this repo. + +## License + +`uptrain-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. 
diff --git a/integrations/uptrain/example/example.py b/integrations/uptrain/example/example.py new file mode 100644 index 000000000..b029b9a65 --- /dev/null +++ b/integrations/uptrain/example/example.py @@ -0,0 +1,32 @@ +# A valid OpenAI API key is required to run this example. + +from haystack import Pipeline +from haystack_integrations.components.evaluators import UpTrainEvaluator, UpTrainMetric + +QUESTIONS = [ + "Which is the most popular global sport?", + "Who created the Python language?", +] +CONTEXTS = [ + "The popularity of sports can be measured in various ways, including TV viewership, social media presence, number of participants, and economic impact. Football is undoubtedly the world's most popular sport with major events like the FIFA World Cup and sports personalities like Ronaldo and Messi, drawing a followership of more than 4 billion people.", + "Python, created by Guido van Rossum in the late 1980s, is a high-level general-purpose programming language. Its design philosophy emphasizes code readability, and its language constructs aim to help programmers write clear, logical code for both small and large-scale software projects.", +] +RESPONSES = [ + "Football is the most popular sport with around 4 billion followers worldwide", + "Python language was created by Guido van Rossum.", +] + +pipeline = Pipeline() +evaluator = UpTrainEvaluator( + metric=UpTrainMetric.FACTUAL_ACCURACY, + api="openai", + api_key_env_var="OPENAI_API_KEY", +) +pipeline.add_component("evaluator", evaluator) + +# Each metric expects a specific set of parameters as input. Refer to the +# UpTrainMetric class' documentation for more details. 
+output = pipeline.run({"evaluator": {"questions": QUESTIONS, "contexts": CONTEXTS, "responses": RESPONSES}}) + +for output in output["evaluator"]["results"]: + print(output) diff --git a/integrations/uptrain/pyproject.toml b/integrations/uptrain/pyproject.toml new file mode 100644 index 000000000..498772313 --- /dev/null +++ b/integrations/uptrain/pyproject.toml @@ -0,0 +1,157 @@ +[build-system] +requires = ["hatchling", "hatch-vcs"] +build-backend = "hatchling.build" + +[project] +name = "uptrain-haystack" +dynamic = ["version"] +description = 'An integration of UpTrain LLM evaluation framework with Haystack' +readme = "README.md" +requires-python = ">=3.7" +license = "Apache-2.0" +keywords = [] +authors = [{ name = "deepset GmbH", email = "info@deepset.ai" }] +classifiers = [ + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dependencies = ["haystack-ai", "uptrain>=0.5"] + +[project.urls] +Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/uptrain" +Documentation = "https://github.com/deepset-ai/haystack-core-integrations/blob/main/integrations/uptrain/README.md" +Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" + +[tool.hatch.build.targets.wheel] +packages = ["src/haystack_integrations"] + +[tool.hatch.version] +source = "vcs" +tag-pattern = 'integrations\/uptrain-v(?P<version>.*)' + +[tool.hatch.version.raw-options] +root = "../.." 
+git_describe_command = 'git describe --tags --match="integrations/uptrain-v[0-9]*"' + +[tool.hatch.envs.default] +dependencies = ["coverage[toml]>=6.5", "pytest"] +[tool.hatch.envs.default.scripts] +test = "pytest {args:tests}" +test-cov = "coverage run -m pytest {args:tests}" +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] + +[[tool.hatch.envs.all.matrix]] +python = ["3.7", "3.8", "3.9", "3.10", "3.11"] + +[tool.hatch.envs.lint] +detached = true +dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +[tool.hatch.envs.lint.scripts] +typing = "mypy --install-types --non-interactive {args:src/}" +style = ["ruff {args:.}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] +all = ["style", "typing"] + +[tool.black] +target-version = ["py37"] +line-length = 120 +skip-string-normalization = true + +[tool.ruff] +target-version = "py37" +line-length = 120 +select = [ + "A", + "ARG", + "B", + "C", + "DTZ", + "E", + "EM", + "F", + "FBT", + "I", + "ICN", + "ISC", + "N", + "PLC", + "PLE", + "PLR", + "PLW", + "Q", + "RUF", + "S", + "T", + "TID", + "UP", + "W", + "YTT", +] +ignore = [ + # Allow non-abstract empty methods in abstract base classes + "B027", + # Allow boolean positional values in function calls, like `dict.get(... 
True)` + "FBT003", + # Ignore checks for possible passwords + "S105", + "S106", + "S107", + # Ignore complexity + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", + # Misc + "S101", + "TID252", +] +unfixable = [ + # Don't touch unused imports + "F401", +] +extend-exclude = ["tests", "example"] + +[tool.ruff.isort] +known-first-party = ["src"] + +[tool.ruff.flake8-tidy-imports] +ban-relative-imports = "all" + +[tool.ruff.per-file-ignores] +# Tests can use magic values, assertions, and relative imports +"tests/**/*" = ["PLR2004", "S101", "TID252"] + +[tool.coverage.run] +source_pkgs = ["src", "tests"] +branch = true +parallel = true + +[tool.coverage.paths] +uptrain_haystack = [ + "src/haystack_integrations", + "*/uptrain-haystack/src/uptrain_haystack", +] +tests = ["tests"] + +[tool.coverage.report] +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] + +[[tool.mypy.overrides]] +module = [ + "haystack.*", + "pytest.*", + "uptrain.*", + "numpy", + "grpc", + "haystack_integrations.*", +] +ignore_missing_imports = true diff --git a/integrations/uptrain/src/haystack_integrations/components/evaluators/__init__.py b/integrations/uptrain/src/haystack_integrations/components/evaluators/__init__.py new file mode 100644 index 000000000..e8366dfc0 --- /dev/null +++ b/integrations/uptrain/src/haystack_integrations/components/evaluators/__init__.py @@ -0,0 +1,7 @@ +from .evaluator import UpTrainEvaluator +from .metrics import UpTrainMetric + +__all__ = ( + "UpTrainEvaluator", + "UpTrainMetric", +) diff --git a/integrations/uptrain/src/haystack_integrations/components/evaluators/evaluator.py b/integrations/uptrain/src/haystack_integrations/components/evaluators/evaluator.py new file mode 100644 index 000000000..f99ec8105 --- /dev/null +++ b/integrations/uptrain/src/haystack_integrations/components/evaluators/evaluator.py @@ -0,0 +1,199 @@ +import json +import os +from typing import Any, Dict, List, Optional, Union + +from haystack import 
DeserializationError, component, default_from_dict, default_to_dict +from haystack_integrations.components.evaluators.metrics import ( + METRIC_DESCRIPTORS, + InputConverters, + OutputConverters, + UpTrainMetric, +) +from uptrain import APIClient, EvalLLM, Evals +from uptrain.framework.evals import ParametricEval + + +@component +class UpTrainEvaluator: + """ + A component that uses the UpTrain framework to evaluate inputs against a specific metric. + + The supported metrics are defined by :class:`UpTrainMetric`. The inputs of the component are + metric-dependent. The output is a list of :class:`UpTrainEvaluatorOutput` objects, each + containing a single input and the result of the evaluation performed on it. + """ + + _backend_metric: Union[Evals, ParametricEval] + _backend_client: Union[APIClient, EvalLLM] + + def __init__( + self, + metric: Union[str, UpTrainMetric], + metric_params: Optional[Dict[str, Any]] = None, + *, + api: str = "openai", + api_key_env_var: Optional[str] = "OPENAI_API_KEY", + api_params: Optional[Dict[str, Any]] = None, + ): + """ + Construct a new UpTrain evaluator. + + :param metric: + The metric to use for evaluation. + :param metric_params: + Parameters to pass to the metric's constructor. + :param api: + The API to use for evaluation. + + Supported APIs: "openai", "uptrain". + :param api_key_env_var: + The name of the environment variable containing the API key. + :param api_params: + Additional parameters to pass to the API client. 
+ """ + self.metric = metric if isinstance(metric, UpTrainMetric) else UpTrainMetric.from_str(metric) + self.metric_params = metric_params + self.descriptor = METRIC_DESCRIPTORS[self.metric] + self.api = api + self.api_key_env_var = api_key_env_var + self.api_params = api_params + + self._init_backend() + expected_inputs = self.descriptor.input_parameters + component.set_input_types(self, **expected_inputs) + + @component.output_types(results=List[List[Dict[str, Any]]]) + def run(self, **inputs) -> Dict[str, Any]: + """ + Run the UpTrain evaluator. + + Example: + ```python + pipeline = Pipeline() + evaluator = UpTrainEvaluator( + metric=UpTrainMetric.FACTUAL_ACCURACY, + api="openai", + api_key_env_var="OPENAI_API_KEY", + ) + pipeline.add_component("evaluator", evaluator) + + # Each metric expects a specific set of parameters as input. Refer to the + # UpTrainMetric class' documentation for more details. + output = pipeline.run({"evaluator": { + "questions": ["question"], + "contexts": ["context"], + "responses": ["response"] + }}) + ``` + + :param inputs: + The inputs to evaluate. These are determined by the + metric being calculated. See :class:`UpTrainMetric` for more + information. + :returns: + A nested list of metric results. Each input can have one or more + results, depending on the metric. Each result is a dictionary + containing the following keys and values: + * `name` - The name of the metric. + * `score` - The score of the metric. + * `explanation` - An optional explanation of the score. + """ + # The backend requires random access to the data, so we can't stream it. 
+ InputConverters.validate_input_parameters(self.metric, self.descriptor.input_parameters, inputs) + converted_inputs: List[Dict[str, str]] = list(self.descriptor.input_converter(**inputs)) # type: ignore + + eval_args = {"data": converted_inputs, "checks": [self._backend_metric]} + if self.api_params is not None: + eval_args.update({k: v for k, v in self.api_params.items() if k not in eval_args}) + + results: List[Dict[str, Any]] + if isinstance(self._backend_client, EvalLLM): + results = self._backend_client.evaluate(**eval_args) + else: + results = self._backend_client.log_and_evaluate(**eval_args) + + OutputConverters.validate_outputs(results) + converted_results = [ + [result.to_dict() for result in self.descriptor.output_converter(x, self.metric_params)] for x in results + ] + + return {"results": converted_results} + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + """ + + def check_serializable(obj: Any): + try: + json.dumps(obj) + return True + except (TypeError, OverflowError): + return False + + if not check_serializable(self.api_params) or not check_serializable(self.metric_params): + msg = "UpTrain evaluator cannot serialize the API/metric parameters" + raise DeserializationError(msg) + + return default_to_dict( + self, + metric=self.metric, + metric_params=self.metric_params, + api=self.api, + api_key_env_var=self.api_key_env_var, + api_params=self.api_params, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "UpTrainEvaluator": + """ + Deserialize a component from a dictionary. + + :param data: + The dictionary to deserialize from. + """ + return default_from_dict(cls, data) + + def _init_backend(self): + """ + Initialize the UpTrain backend. 
+ """ + if isinstance(self.descriptor.backend, Evals): + if self.metric_params is not None: + msg = ( + f"Uptrain metric '{self.metric}' received the following unexpected init parameters:" + f"{self.metric_params}" + ) + raise ValueError(msg) + backend_metric = self.descriptor.backend + else: + assert issubclass(self.descriptor.backend, ParametricEval) + if self.metric_params is None: + msg = f"Uptrain metric '{self.metric}' expected init parameters but got none" + raise ValueError(msg) + elif not all(k in self.descriptor.init_parameters for k in self.metric_params.keys()): + msg = ( + f"Invalid init parameters for UpTrain metric '{self.metric}'. " + f"Expected: {list(self.descriptor.init_parameters.keys())}" + ) + + raise ValueError(msg) + backend_metric = self.descriptor.backend(**self.metric_params) + + supported_apis = ("openai", "uptrain") + if self.api not in supported_apis: + msg = f"Unsupported API '{self.api}' for UpTrain evaluator. Supported APIs: {supported_apis}" + raise ValueError(msg) + + api_key = os.environ.get(self.api_key_env_var) + if api_key is None: + msg = f"Missing API key environment variable '{self.api_key_env_var}' for UpTrain evaluator" + raise ValueError(msg) + + if self.api == "openai": + backend_client = EvalLLM(openai_api_key=api_key) + elif self.api == "uptrain": + backend_client = APIClient(uptrain_api_key=api_key) + + self._backend_metric = backend_metric + self._backend_client = backend_client diff --git a/integrations/uptrain/src/haystack_integrations/components/evaluators/metrics.py b/integrations/uptrain/src/haystack_integrations/components/evaluators/metrics.py new file mode 100644 index 000000000..e42b63e21 --- /dev/null +++ b/integrations/uptrain/src/haystack_integrations/components/evaluators/metrics.py @@ -0,0 +1,366 @@ +import dataclasses +import inspect +from dataclasses import dataclass +from enum import Enum +from functools import partial +from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union + 
+from uptrain import CritiqueTone, Evals, GuidelineAdherence, ResponseMatching +from uptrain.framework.evals import ParametricEval + + +class UpTrainMetric(Enum): + """ + Metrics supported by UpTrain. + """ + + #: Context relevance. + #: Inputs - `questions: List[str], contexts: List[str]` + CONTEXT_RELEVANCE = "context_relevance" + + #: Factual accuracy. + #: Inputs - `questions: List[str], contexts: List[str], responses: List[str]` + FACTUAL_ACCURACY = "factual_accuracy" + + #: Response relevance. + #: Inputs - `questions: List[str], responses: List[str]` + RESPONSE_RELEVANCE = "response_relevance" + + #: Response completeness. + #: Inputs - `questions: List[str], responses: List[str]` + RESPONSE_COMPLETENESS = "response_completeness" + + #: Response completeness with respect to context. + #: Inputs - `questions: List[str], contexts: List[str], responses: List[str]` + RESPONSE_COMPLETENESS_WRT_CONTEXT = "response_completeness_wrt_context" + + #: Response consistency. + #: Inputs - `questions: List[str], contexts: List[str], responses: List[str]` + RESPONSE_CONSISTENCY = "response_consistency" + + #: Response conciseness. + #: Inputs - `questions: List[str], responses: List[str]` + RESPONSE_CONCISENESS = "response_conciseness" + + #: Language critique. + #: Inputs - `responses: List[str]` + CRITIQUE_LANGUAGE = "critique_language" + + #: Tone critique. + #: Inputs - `responses: List[str]` + CRITIQUE_TONE = "critique_tone" + + #: Guideline adherence. + #: Inputs - `questions: List[str], responses: List[str]` + GUIDELINE_ADHERENCE = "guideline_adherence" + + #: Response matching. + #: Inputs - `responses: List[str], ground_truths: List[str]` + RESPONSE_MATCHING = "response_matching" + + def __str__(self): + return self.value + + @classmethod + def from_str(cls, string: str) -> "UpTrainMetric": + """ + Create a metric type from a string. + + :param string: + The string to convert. + :returns: + The metric. 
+ """ + enum_map = {e.value: e for e in UpTrainMetric} + metric = enum_map.get(string) + if metric is None: + msg = f"Unknown UpTrain metric '{string}'. Supported metrics: {list(enum_map.keys())}" + raise ValueError(msg) + return metric + + +@dataclass(frozen=True) +class MetricResult: + """ + Result of a metric evaluation. + + :param name: + The name of the metric. + :param score: + The score of the metric. + :param explanation: + An optional explanation of the metric. + """ + + name: str + score: float + explanation: Optional[str] = None + + def to_dict(self): + return dataclasses.asdict(self) + + +@dataclass(frozen=True) +class MetricDescriptor: + """ + Descriptor for a metric. + + :param metric: + The metric. + :param backend: + The associated UpTrain metric class. + :param input_parameters: + Parameters accepted by the metric. This is used + to set the input types of the evaluator component. + :param input_converter: + Callable that converts input parameters to the UpTrain input format. + :param output_converter: + Callable that converts the UpTrain output format to our output format. + :param init_parameters: + Additional parameters that need to be passed to the metric class during initialization. 
+ """ + + metric: UpTrainMetric + backend: Union[Evals, Type[ParametricEval]] + input_parameters: Dict[str, Type] + input_converter: Callable[[Any], Iterable[Dict[str, str]]] + output_converter: Callable[[Dict[str, Any], Optional[Dict[str, Any]]], List[MetricResult]] + init_parameters: Optional[Dict[str, Type[Any]]] = None + + @classmethod + def new( + cls, + metric: UpTrainMetric, + backend: Union[Evals, Type[ParametricEval]], + input_converter: Callable[[Any], Iterable[Dict[str, str]]], + output_converter: Optional[Callable[[Dict[str, Any], Optional[Dict[str, Any]]], List[MetricResult]]] = None, + *, + init_parameters: Optional[Dict[str, Type]] = None, + ) -> "MetricDescriptor": + input_converter_signature = inspect.signature(input_converter) + input_parameters = {} + for name, param in input_converter_signature.parameters.items(): + if name in ("cls", "self"): + continue + elif param.kind not in (inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD): + continue + input_parameters[name] = param.annotation + + return cls( + metric=metric, + backend=backend, + input_parameters=input_parameters, + input_converter=input_converter, + output_converter=output_converter if output_converter is not None else OutputConverters.default(metric), + init_parameters=init_parameters, + ) + + +class InputConverters: + """ + Converters for input parameters. + + The signature of the converter functions serves as the ground-truth of the + expected input parameters of a given metric. They are also responsible for validating + the input parameters and converting them to the format expected by UpTrain. 
+ """ + + @staticmethod + def _validate_input_elements(**kwargs): + for k, collection in kwargs.items(): + if not isinstance(collection, list): + msg = ( + f"UpTrain evaluator expected input '{k}' to be a collection of type 'list', " + f"got '{type(collection).__name__}' instead" + ) + raise ValueError(msg) + elif not all(isinstance(x, str) for x in collection): + msg = f"UpTrain evaluator expects inputs to be of type 'str' in '{k}'" + raise ValueError(msg) + + same_length = len({len(x) for x in kwargs.values()}) == 1 + if not same_length: + msg = f"Mismatching counts in the following inputs: {({k: len(v) for k, v in kwargs.items()})}" + raise ValueError(msg) + + @staticmethod + def validate_input_parameters(metric: UpTrainMetric, expected: Dict[str, Any], received: Dict[str, Any]): + for param, _ in expected.items(): + if param not in received: + msg = f"UpTrain evaluator expected input parameter '{param}' for metric '{metric}'" + raise ValueError(msg) + + @staticmethod + def question_context_response( + questions: List[str], contexts: List[str], responses: List[str] + ) -> Iterable[Dict[str, str]]: + InputConverters._validate_input_elements(questions=questions, contexts=contexts, responses=responses) + for q, c, r in zip(questions, contexts, responses): # type: ignore + yield {"question": q, "context": c, "response": r} + + @staticmethod + def question_context( + questions: List[str], + contexts: List[str], + ) -> Iterable[Dict[str, str]]: + InputConverters._validate_input_elements(questions=questions, contexts=contexts) + for q, c in zip(questions, contexts): # type: ignore + yield {"question": q, "context": c} + + @staticmethod + def question_response( + questions: List[str], + responses: List[str], + ) -> Iterable[Dict[str, str]]: + InputConverters._validate_input_elements(questions=questions, responses=responses) + for q, r in zip(questions, responses): # type: ignore + yield {"question": q, "response": r} + + @staticmethod + def response( + responses: 
List[str], + ) -> Iterable[Dict[str, str]]: + InputConverters._validate_input_elements(responses=responses) + for r in responses: + yield {"response": r} + + @staticmethod + def response_ground_truth( + responses: List[str], + ground_truths: List[str], + ) -> Iterable[Dict[str, str]]: + InputConverters._validate_input_elements(ground_truths=ground_truths, responses=responses) + for r, gt in zip(responses, ground_truths): # type: ignore + yield {"response": r, "ground_truth": gt} + + +class OutputConverters: + """ + Converters for results returned by UpTrain. + + They are responsible for converting the results to our output format. + """ + + @staticmethod + def validate_outputs(outputs: List[Dict[str, Any]]): + msg = None + if not isinstance(outputs, list): + msg = f"Expected response from UpTrain evaluator to be a 'list', got '{type(outputs).__name__}'" + elif not all(isinstance(x, dict) for x in outputs): + msg = "UpTrain evaluator expects outputs to be a list of `dict`s" + elif not all(isinstance(y, str) for x in outputs for y in x.keys()): + msg = "UpTrain evaluator expects keys in the output dicts to be `str`" + elif not all(isinstance(y, (float, str)) for x in outputs for y in x.values()): + msg = "UpTrain evaluator expects values in the output dicts to be either `str` or `float`" + + if msg is not None: + raise ValueError(msg) + + @staticmethod + def _extract_default_results(output: Dict[str, Any], metric_name: str) -> MetricResult: + try: + score_key = f"score_{metric_name}" + explanation_key = f"explanation_{metric_name}" + return MetricResult(name=metric_name, score=output[score_key], explanation=output.get(explanation_key)) + except KeyError as e: + msg = f"UpTrain evaluator did not return an expected output for metric '{metric_name}'" + raise ValueError(msg) from e + + @staticmethod + def default( + metric: UpTrainMetric, + ) -> Callable[[Dict[str, Any], Optional[Dict[str, Any]]], List[MetricResult]]: + def inner( + output: Dict[str, Any], metric_params: 
Optional[Dict[str, Any]], metric: UpTrainMetric # noqa: ARG001 + ) -> List[MetricResult]: + return [OutputConverters._extract_default_results(output, str(metric))] + + return partial(inner, metric=metric) + + @staticmethod + def critique_language( + output: Dict[str, Any], metric_params: Optional[Dict[str, Any]] # noqa: ARG004 + ) -> List[MetricResult]: + out = [] + for expected_key in ("fluency", "coherence", "grammar", "politeness"): + out.append(OutputConverters._extract_default_results(output, expected_key)) + return out + + @staticmethod + def critique_tone( + output: Dict[str, Any], metric_params: Optional[Dict[str, Any]] # noqa: ARG004 + ) -> List[MetricResult]: + return [OutputConverters._extract_default_results(output, "tone")] + + @staticmethod + def guideline_adherence(output: Dict[str, Any], metric_params: Optional[Dict[str, Any]]) -> List[MetricResult]: + assert metric_params is not None + return [OutputConverters._extract_default_results(output, f'{metric_params["guideline_name"]}_adherence')] + + @staticmethod + def response_matching( + output: Dict[str, Any], metric_params: Optional[Dict[str, Any]] # noqa: ARG004 + ) -> List[MetricResult]: + metric_str = "response_match" + out = [OutputConverters._extract_default_results(output, metric_str)] + + # Enumerate other relevant keys. 
+ score_key = f"score_{metric_str}" + for k, v in output.items(): + if k != score_key and metric_str in k and isinstance(v, float): + out.append(MetricResult(name=k, score=v)) + return out + + +METRIC_DESCRIPTORS = { + UpTrainMetric.CONTEXT_RELEVANCE: MetricDescriptor.new( + UpTrainMetric.CONTEXT_RELEVANCE, Evals.CONTEXT_RELEVANCE, InputConverters.question_context # type: ignore + ), + UpTrainMetric.FACTUAL_ACCURACY: MetricDescriptor.new( + UpTrainMetric.FACTUAL_ACCURACY, Evals.FACTUAL_ACCURACY, InputConverters.question_context_response # type: ignore + ), + UpTrainMetric.RESPONSE_RELEVANCE: MetricDescriptor.new( + UpTrainMetric.RESPONSE_RELEVANCE, Evals.RESPONSE_RELEVANCE, InputConverters.question_response # type: ignore + ), + UpTrainMetric.RESPONSE_COMPLETENESS: MetricDescriptor.new( + UpTrainMetric.RESPONSE_COMPLETENESS, Evals.RESPONSE_COMPLETENESS, InputConverters.question_response # type: ignore + ), + UpTrainMetric.RESPONSE_COMPLETENESS_WRT_CONTEXT: MetricDescriptor.new( + UpTrainMetric.RESPONSE_COMPLETENESS_WRT_CONTEXT, + Evals.RESPONSE_COMPLETENESS_WRT_CONTEXT, + InputConverters.question_context_response, # type: ignore + ), + UpTrainMetric.RESPONSE_CONSISTENCY: MetricDescriptor.new( + UpTrainMetric.RESPONSE_CONSISTENCY, Evals.RESPONSE_CONSISTENCY, InputConverters.question_context_response # type: ignore + ), + UpTrainMetric.RESPONSE_CONCISENESS: MetricDescriptor.new( + UpTrainMetric.RESPONSE_CONCISENESS, Evals.RESPONSE_CONCISENESS, InputConverters.question_response # type: ignore + ), + UpTrainMetric.CRITIQUE_LANGUAGE: MetricDescriptor.new( + UpTrainMetric.CRITIQUE_LANGUAGE, + Evals.CRITIQUE_LANGUAGE, + InputConverters.response, + OutputConverters.critique_language, + ), + UpTrainMetric.CRITIQUE_TONE: MetricDescriptor.new( + UpTrainMetric.CRITIQUE_TONE, + CritiqueTone, + InputConverters.response, + OutputConverters.critique_tone, + init_parameters={"llm_persona": str}, + ), + UpTrainMetric.GUIDELINE_ADHERENCE: MetricDescriptor.new( + 
UpTrainMetric.GUIDELINE_ADHERENCE, + GuidelineAdherence, + InputConverters.question_response, # type: ignore + OutputConverters.guideline_adherence, + init_parameters={"guideline": str, "guideline_name": str, "response_schema": Optional[str]}, # type: ignore + ), + UpTrainMetric.RESPONSE_MATCHING: MetricDescriptor.new( + UpTrainMetric.RESPONSE_MATCHING, + ResponseMatching, + InputConverters.response_ground_truth, # type: ignore + OutputConverters.response_matching, + init_parameters={"method": Optional[str]}, # type: ignore + ), +} diff --git a/integrations/uptrain/tests/__init__.py b/integrations/uptrain/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/integrations/uptrain/tests/test_evaluator.py b/integrations/uptrain/tests/test_evaluator.py new file mode 100644 index 000000000..2128e0634 --- /dev/null +++ b/integrations/uptrain/tests/test_evaluator.py @@ -0,0 +1,380 @@ +import copy +import os +from dataclasses import dataclass +from typing import List +from unittest.mock import patch + +import pytest +from haystack import DeserializationError + +from haystack_integrations.components.evaluators import UpTrainEvaluator, UpTrainMetric + +DEFAULT_QUESTIONS = [ + "Which is the most popular global sport?", + "Who created the Python language?", +] +DEFAULT_CONTEXTS = [ + "The popularity of sports can be measured in various ways, including TV viewership, social media presence, number of participants, and economic impact. Football is undoubtedly the world's most popular sport with major events like the FIFA World Cup and sports personalities like Ronaldo and Messi, drawing a followership of more than 4 billion people.", + "Python, created by Guido van Rossum in the late 1980s, is a high-level general-purpose programming language. 
Its design philosophy emphasizes code readability, and its language constructs aim to help programmers write clear, logical code for both small and large-scale software projects.", +] +DEFAULT_RESPONSES = [ + "Football is the most popular sport with around 4 billion followers worldwide", + "Python language was created by Guido van Rossum.", +] + + +@dataclass(frozen=True) +class Unserializable: + something: str + + +# Only returns results for the passed metrics. +class MockBackend: + def __init__(self, metric_outputs: List[UpTrainMetric]) -> None: + self.metrics = metric_outputs + if not self.metrics: + self.metrics = [e for e in UpTrainMetric] + + def log_and_evaluate(self, data, checks, **kwargs): + output_map = { + UpTrainMetric.CONTEXT_RELEVANCE: { + "score_context_relevance": 0.5, + "explanation_context_relevance": "1", + }, + UpTrainMetric.FACTUAL_ACCURACY: { + "score_factual_accuracy": 1.0, + "explanation_factual_accuracy": "2", + }, + UpTrainMetric.RESPONSE_RELEVANCE: { + "score_response_relevance": 1.0, + "explanation_response_relevance": "3", + }, + UpTrainMetric.RESPONSE_COMPLETENESS: { + "score_response_completeness": 0.5, + "explanation_response_completeness": "4", + }, + UpTrainMetric.RESPONSE_COMPLETENESS_WRT_CONTEXT: { + "score_response_completeness_wrt_context": 1.0, + "explanation_response_completeness_wrt_context": "5", + }, + UpTrainMetric.RESPONSE_CONSISTENCY: { + "score_response_consistency": 0.9, + "explanation_response_consistency": "6", + }, + UpTrainMetric.RESPONSE_CONCISENESS: { + "score_response_conciseness": 1.0, + "explanation_response_conciseness": "7", + }, + UpTrainMetric.CRITIQUE_LANGUAGE: { + "score_fluency": 1.0, + "score_coherence": 1.0, + "score_grammar": 1.0, + "score_politeness": 1.0, + "explanation_fluency": "8", + "explanation_coherence": "9", + "explanation_grammar": "10", + "explanation_politeness": "11", + }, + UpTrainMetric.CRITIQUE_TONE: { + "score_tone": 0.4, + "explanation_tone": "12", + }, + 
UpTrainMetric.GUIDELINE_ADHERENCE: { + "score_guideline_adherence": 1.0, + "explanation_guideline_adherence": "13", + }, + UpTrainMetric.RESPONSE_MATCHING: { + "response_match_precision": 1.0, + "response_match_recall": 0.6666666666666666, + "score_response_match": 0.7272727272727273, + }, + } + + data = copy.deepcopy(data) + for x in data: + for m in self.metrics: + x.update(output_map[m]) + return data + + +@patch("os.environ.get") +def test_evaluator_api(os_environ_get): + api_key_var = "test-api-key" + os_environ_get.return_value = api_key_var + + eval = UpTrainEvaluator(UpTrainMetric.RESPONSE_COMPLETENESS) + assert eval.api == "openai" + assert eval.api_key_env_var == "OPENAI_API_KEY" + + eval = UpTrainEvaluator(UpTrainMetric.RESPONSE_COMPLETENESS, api="uptrain", api_key_env_var="UPTRAIN_API_KEY") + assert eval.api == "uptrain" + assert eval.api_key_env_var == "UPTRAIN_API_KEY" + + with pytest.raises(ValueError, match="Unsupported API"): + UpTrainEvaluator(UpTrainMetric.CONTEXT_RELEVANCE, api="cohere") + + os_environ_get.return_value = None + with pytest.raises(ValueError, match="Missing API key"): + UpTrainEvaluator(UpTrainMetric.CONTEXT_RELEVANCE, api="uptrain") + + +@patch("os.environ.get") +def test_evaluator_metric_init_params(os_environ_get): + api_key = "test-api-key" + os_environ_get.return_value = api_key + + eval = UpTrainEvaluator(UpTrainMetric.CRITIQUE_TONE, metric_params={"llm_persona": "village idiot"}) + assert eval._backend_metric.llm_persona == "village idiot" + + with pytest.raises(ValueError, match="Invalid init parameters"): + UpTrainEvaluator(UpTrainMetric.CRITIQUE_TONE, metric_params={"role": "village idiot"}) + + with pytest.raises(ValueError, match="unexpected init parameters"): + UpTrainEvaluator(UpTrainMetric.FACTUAL_ACCURACY, metric_params={"check_numbers": True}) + + with pytest.raises(ValueError, match="expected init parameters"): + UpTrainEvaluator(UpTrainMetric.RESPONSE_MATCHING) + + +@patch("os.environ.get") +def 
test_evaluator_serde(os_environ_get): + os_environ_get.return_value = "abacab" + + init_params = { + "metric": UpTrainMetric.RESPONSE_MATCHING, + "metric_params": {"method": "rouge"}, + "api": "uptrain", + "api_key_env_var": "abacab", + "api_params": {"eval_name": "test"}, + } + eval = UpTrainEvaluator(**init_params) + serde_data = eval.to_dict() + new_eval = UpTrainEvaluator.from_dict(serde_data) + + assert eval.metric == new_eval.metric + assert eval.api == new_eval.api + assert eval.api_key_env_var == new_eval.api_key_env_var + assert eval.metric_params == new_eval.metric_params + assert eval.api_params == new_eval.api_params + assert type(new_eval._backend_client) == type(eval._backend_client) + assert type(new_eval._backend_metric) == type(eval._backend_metric) + + with pytest.raises(DeserializationError, match=r"cannot serialize the API/metric parameters"): + init_params3 = copy.deepcopy(init_params) + init_params3["api_params"] = {"arg": Unserializable("")} + eval = UpTrainEvaluator(**init_params3) + eval.to_dict() + + +@pytest.mark.parametrize( + "metric, inputs, params", + [ + (UpTrainMetric.CONTEXT_RELEVANCE, {"questions": [], "contexts": []}, None), + (UpTrainMetric.FACTUAL_ACCURACY, {"questions": [], "contexts": [], "responses": []}, None), + (UpTrainMetric.RESPONSE_RELEVANCE, {"questions": [], "responses": []}, None), + (UpTrainMetric.RESPONSE_COMPLETENESS, {"questions": [], "responses": []}, None), + (UpTrainMetric.RESPONSE_COMPLETENESS_WRT_CONTEXT, {"questions": [], "contexts": [], "responses": []}, None), + (UpTrainMetric.RESPONSE_CONSISTENCY, {"questions": [], "contexts": [], "responses": []}, None), + (UpTrainMetric.RESPONSE_CONCISENESS, {"questions": [], "responses": []}, None), + (UpTrainMetric.CRITIQUE_LANGUAGE, {"responses": []}, None), + (UpTrainMetric.CRITIQUE_TONE, {"responses": []}, {"llm_persona": "idiot"}), + ( + UpTrainMetric.GUIDELINE_ADHERENCE, + {"questions": [], "responses": []}, + {"guideline": "Do nothing", "guideline_name": 
"somename", "response_schema": None}, + ), + (UpTrainMetric.RESPONSE_MATCHING, {"ground_truths": [], "responses": []}, {"method": "llm"}), + ], +) +@patch("os.environ.get") +def test_evaluator_valid_inputs(os_environ_get, metric, inputs, params): + os_environ_get.return_value = "abacab" + init_params = { + "metric": metric, + "metric_params": params, + "api": "uptrain", + "api_key_env_var": "abacab", + "api_params": None, + } + eval = UpTrainEvaluator(**init_params) + eval._backend_client = MockBackend([metric]) + output = eval.run(**inputs) + + +@pytest.mark.parametrize( + "metric, inputs, error_string, params", + [ + (UpTrainMetric.CONTEXT_RELEVANCE, {"questions": {}, "contexts": []}, "to be a collection of type 'list'", None), + ( + UpTrainMetric.FACTUAL_ACCURACY, + {"questions": [1], "contexts": [2], "responses": [3]}, + "expects inputs to be of type 'str'", + None, + ), + (UpTrainMetric.RESPONSE_RELEVANCE, {"questions": [""], "responses": []}, "Mismatching counts ", None), + (UpTrainMetric.RESPONSE_RELEVANCE, {"responses": []}, "expected input parameter ", None), + ], +) +@patch("os.environ.get") +def test_evaluator_invalid_inputs(os_environ_get, metric, inputs, error_string, params): + os_environ_get.return_value = "abacab" + with pytest.raises(ValueError, match=error_string): + init_params = { + "metric": metric, + "metric_params": params, + "api": "uptrain", + "api_key_env_var": "abacab", + "api_params": None, + } + eval = UpTrainEvaluator(**init_params) + eval._backend_client = MockBackend([metric]) + output = eval.run(**inputs) + + +# This test validates the expected outputs of the evaluator. +# Each output is parameterized as a list of tuples, where each tuple is +# (name, score, explanation). The name and explanation are optional. If +# the name is None, then the metric name is used. 
+@pytest.mark.parametrize( + "metric, inputs, expected_outputs, metric_params", + [ + (UpTrainMetric.CONTEXT_RELEVANCE, {"questions": ["q1"], "contexts": ["c1"]}, [[(None, 0.5, "1")]], None), + ( + UpTrainMetric.FACTUAL_ACCURACY, + {"questions": ["q2"], "contexts": ["c2"], "responses": ["r2"]}, + [[(None, 1.0, "2")]], + None, + ), + (UpTrainMetric.RESPONSE_RELEVANCE, {"questions": ["q3"], "responses": ["r3"]}, [[(None, 1.0, "3")]], None), + (UpTrainMetric.RESPONSE_COMPLETENESS, {"questions": ["q4"], "responses": ["r4"]}, [[(None, 0.5, "4")]], None), + ( + UpTrainMetric.RESPONSE_COMPLETENESS_WRT_CONTEXT, + {"questions": ["q5"], "contexts": ["c5"], "responses": ["r5"]}, + [[(None, 1.0, "5")]], + None, + ), + ( + UpTrainMetric.RESPONSE_CONSISTENCY, + {"questions": ["q6"], "contexts": ["c6"], "responses": ["r6"]}, + [[(None, 0.9, "6")]], + None, + ), + (UpTrainMetric.RESPONSE_CONCISENESS, {"questions": ["q7"], "responses": ["r7"]}, [[(None, 1.0, "7")]], None), + ( + UpTrainMetric.CRITIQUE_LANGUAGE, + {"responses": ["r8"]}, + [ + [ + ("fluency", 1.0, "8"), + ("coherence", 1.0, "9"), + ("grammar", 1.0, "10"), + ("politeness", 1.0, "11"), + ] + ], + None, + ), + (UpTrainMetric.CRITIQUE_TONE, {"responses": ["r9"]}, [[("tone", 0.4, "12")]], {"llm_persona": "idiot"}), + ( + UpTrainMetric.GUIDELINE_ADHERENCE, + {"questions": ["q10"], "responses": ["r10"]}, + [[(None, 1.0, "13")]], + {"guideline": "Do nothing", "guideline_name": "guideline", "response_schema": None}, + ), + ( + UpTrainMetric.RESPONSE_MATCHING, + {"ground_truths": ["g11"], "responses": ["r11"]}, + [ + [ + ("response_match_precision", 1.0, None), + ("response_match_recall", 0.6666666666666666, None), + ("response_match", 0.7272727272727273, None), + ] + ], + {"method": "llm"}, + ), + ], +) +@patch("os.environ.get") +def test_evaluator_outputs(os_environ_get, metric, inputs, expected_outputs, metric_params): + os_environ_get.return_value = "abacab" + init_params = { + "metric": metric, + "metric_params": 
metric_params, + "api": "uptrain", + "api_key_env_var": "abacab", + "api_params": None, + } + eval = UpTrainEvaluator(**init_params) + eval._backend_client = MockBackend([metric]) + results = eval.run(**inputs)["results"] + + assert type(results) == type(expected_outputs) + assert len(results) == len(expected_outputs) + + for r, o in zip(results, expected_outputs): + assert len(r) == len(o) + + expected = {(name if name is not None else str(metric), score, exp) for name, score, exp in o} + got = {(x["name"], x["score"], x["explanation"]) for x in r} + assert got == expected + + +# This integration test validates the evaluator by running it against the +# OpenAI API. It is parameterized by the metric, the inputs to the evaluator +# and the metric parameters. +@pytest.mark.skipif("OPENAI_API_KEY" not in os.environ, reason="OPENAI_API_KEY not set") +@pytest.mark.parametrize( + "metric, inputs, metric_params", + [ + (UpTrainMetric.CONTEXT_RELEVANCE, {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS}, None), + ( + UpTrainMetric.FACTUAL_ACCURACY, + {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "responses": DEFAULT_RESPONSES}, + None, + ), + (UpTrainMetric.RESPONSE_RELEVANCE, {"questions": DEFAULT_QUESTIONS, "responses": DEFAULT_RESPONSES}, None), + (UpTrainMetric.RESPONSE_COMPLETENESS, {"questions": DEFAULT_QUESTIONS, "responses": DEFAULT_RESPONSES}, None), + ( + UpTrainMetric.RESPONSE_COMPLETENESS_WRT_CONTEXT, + {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "responses": DEFAULT_RESPONSES}, + None, + ), + ( + UpTrainMetric.RESPONSE_CONSISTENCY, + {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "responses": DEFAULT_RESPONSES}, + None, + ), + (UpTrainMetric.RESPONSE_CONCISENESS, {"questions": DEFAULT_QUESTIONS, "responses": DEFAULT_RESPONSES}, None), + (UpTrainMetric.CRITIQUE_LANGUAGE, {"responses": DEFAULT_RESPONSES}, None), + (UpTrainMetric.CRITIQUE_TONE, {"responses": DEFAULT_RESPONSES}, {"llm_persona":
"idiot"}), + ( + UpTrainMetric.GUIDELINE_ADHERENCE, + {"questions": DEFAULT_QUESTIONS, "responses": DEFAULT_RESPONSES}, + {"guideline": "Do nothing", "guideline_name": "somename", "response_schema": None}, + ), + ( + UpTrainMetric.RESPONSE_MATCHING, + { + "ground_truths": [ + "Consumerism is the most popular sport in the world", + "Python language was created by some dude.", + ], + "responses": DEFAULT_RESPONSES, + }, + {"method": "llm"}, + ), + ], +) +def test_integration_run(metric, inputs, metric_params): + init_params = { + "metric": metric, + "metric_params": metric_params, + "api": "openai", + } + eval = UpTrainEvaluator(**init_params) + output = eval.run(**inputs) + + assert type(output) == dict + assert len(output) == 1 + assert "results" in output + assert len(output["results"]) == len(next(iter(inputs.values()))) diff --git a/integrations/uptrain/tests/test_metrics.py b/integrations/uptrain/tests/test_metrics.py new file mode 100644 index 000000000..b73b2aa92 --- /dev/null +++ b/integrations/uptrain/tests/test_metrics.py @@ -0,0 +1,11 @@ +import pytest + +from haystack_integrations.components.evaluators import UpTrainMetric + + +def test_uptrain_metric(): + for e in UpTrainMetric: + assert e == UpTrainMetric.from_str(e.value) + + with pytest.raises(ValueError, match="Unknown UpTrain metric"): + UpTrainMetric.from_str("smugness") diff --git a/integrations/weaviate/docker-compose.yml b/integrations/weaviate/docker-compose.yml new file mode 100644 index 000000000..c61b0ed57 --- /dev/null +++ b/integrations/weaviate/docker-compose.yml @@ -0,0 +1,22 @@ +version: '3.4' +services: + weaviate: + command: + - --host + - 0.0.0.0 + - --port + - '8080' + - --scheme + - http + image: semitechnologies/weaviate:1.23.2 + ports: + - 8080:8080 + - 50051:50051 + restart: on-failure:0 + environment: + QUERY_DEFAULTS_LIMIT: 25 + AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true' + PERSISTENCE_DATA_PATH: '/var/lib/weaviate' + DEFAULT_VECTORIZER_MODULE: 'none' + ENABLE_MODULES: '' 
+ CLUSTER_HOSTNAME: 'node1' diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py index 4c15d707e..3d658c316 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py @@ -1,17 +1,20 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +import base64 from dataclasses import asdict from typing import Any, Dict, List, Optional, Tuple, Union from haystack.core.serialization import default_from_dict, default_to_dict from haystack.dataclasses.document import Document -from haystack.document_stores.protocol import DuplicatePolicy +from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError +from haystack.document_stores.types.policy import DuplicatePolicy import weaviate from weaviate.auth import AuthCredentials from weaviate.config import Config, ConnectionConfig from weaviate.embedded import EmbeddedOptions +from weaviate.util import generate_uuid5 Number = Union[int, float] TimeoutType = Union[Tuple[Number, Number], Number] @@ -25,6 +28,20 @@ "weaviate.auth.AuthApiKey": weaviate.auth.AuthApiKey, } +# This is the default collection properties for Weaviate. +# It's a list of properties that will be created on the collection. +# These are extremely similar to the Document dataclass, but with a few differences: +# - `id` is renamed to `_original_id` as the `id` field is reserved by Weaviate. +# - `blob` is split into `blob_data` and `blob_mime_type` as it's more efficient to store them separately. 
+DOCUMENT_COLLECTION_PROPERTIES = [ + {"name": "_original_id", "dataType": ["text"]}, + {"name": "content", "dataType": ["text"]}, + {"name": "dataframe", "dataType": ["text"]}, + {"name": "blob_data", "dataType": ["blob"]}, + {"name": "blob_mime_type", "dataType": ["text"]}, + {"name": "score", "dataType": ["number"]}, +] + class WeaviateDocumentStore: """ @@ -35,7 +52,7 @@ def __init__( self, *, url: Optional[str] = None, - collection_name: str = "default", + collection_settings: Optional[Dict[str, Any]] = None, auth_client_secret: Optional[AuthCredentials] = None, timeout_config: TimeoutType = (10, 60), proxies: Optional[Union[Dict, str]] = None, @@ -49,6 +66,16 @@ def __init__( Create a new instance of WeaviateDocumentStore and connects to the Weaviate instance. :param url: The URL to the weaviate instance, defaults to None. + :param collection_settings: The collection settings to use, defaults to None. + If None it will use a collection named `default` with the following properties: + - _original_id: text + - content: text + - dataframe: text + - blob_data: blob + - blob_mime_type: text + - score: number + See the official `Weaviate documentation`_ + for more information on collections. :param auth_client_secret: Authentication credentials, defaults to None. Can be one of the following types depending on the authentication mode: - `weaviate.auth.AuthBearerToken` to use existing access and (optionally, but recommended) refresh tokens @@ -80,8 +107,6 @@ def __init__( :param embedded_options: If set create an embedded Weaviate cluster inside the client, defaults to None. For a full list of options see `weaviate.embedded.EmbeddedOptions`. :param additional_config: Additional and advanced configuration options for weaviate, defaults to None. - :param collection_name: The name of the collection to use, defaults to "default". - If the collection does not exist it will be created. 
""" self._client = weaviate.Client( url=url, @@ -98,11 +123,22 @@ def __init__( # Test connection, it will raise an exception if it fails. self._client.schema.get() - if not self._client.schema.exists(collection_name): - self._client.schema.create_class({"class": collection_name}) + if collection_settings is None: + collection_settings = { + "class": "Default", + "properties": DOCUMENT_COLLECTION_PROPERTIES, + } + else: + # Set the class if not set + collection_settings["class"] = collection_settings.get("class", "default").capitalize() + # Set the properties if they're not set + collection_settings["properties"] = collection_settings.get("properties", DOCUMENT_COLLECTION_PROPERTIES) + + if not self._client.schema.exists(collection_settings["class"]): + self._client.schema.create_class(collection_settings) self._url = url - self._collection_name = collection_name + self._collection_settings = collection_settings self._auth_client_secret = auth_client_secret self._timeout_config = timeout_config self._proxies = proxies @@ -124,7 +160,7 @@ def to_dict(self) -> Dict[str, Any]: return default_to_dict( self, url=self._url, - collection_name=self._collection_name, + collection_settings=self._collection_settings, auth_client_secret=auth_client_secret, timeout_config=self._timeout_config, proxies=self._proxies, @@ -155,15 +191,195 @@ def from_dict(cls, data: Dict[str, Any]) -> "WeaviateDocumentStore": ) def count_documents(self) -> int: - return 0 + collection_name = self._collection_settings["class"] + res = self._client.query.aggregate(collection_name).with_meta_count().do() + return res.get("data", {}).get("Aggregate", {}).get(collection_name, [{}])[0].get("meta", {}).get("count", 0) + + def _to_data_object(self, document: Document) -> Dict[str, Any]: + """ + Convert a Document to a Weviate data object ready to be saved. + """ + data = document.to_dict(flatten=False) + # Weaviate forces a UUID as an id. 
+ # We don't know if the id of our Document is a UUID or not, so we save it on a different field + # and let Weaviate generate a UUID that we're going to ignore completely. + data["_original_id"] = data.pop("id") + if (blob := data.pop("blob")) is not None: + # Weaviate wants the blob data as a base64 encoded string + # See the official docs for more information: + # https://weaviate.io/developers/weaviate/config-refs/datatypes#datatype-blob + data["blob_data"] = base64.b64encode(bytes(blob.pop("data"))).decode() + data["blob_mime_type"] = blob.pop("mime_type") + # The embedding vector is stored separately from the rest of the data + del data["embedding"] + + # Weaviate doesn't like empty objects, let's delete meta if it's empty + if data["meta"] == {}: + del data["meta"] + + return data + + def _to_document(self, data: Dict[str, Any]) -> Document: + """ + Convert a data object read from Weaviate into a Document. + """ + data["id"] = data.pop("_original_id") + data["embedding"] = data["_additional"].pop("vector") if data["_additional"].get("vector") else None + + if (blob_data := data.get("blob_data")) is not None: + data["blob"] = { + "data": base64.b64decode(blob_data), + "mime_type": data.get("blob_mime_type"), + } + # We always delete these fields as they're not part of the Document dataclass + data.pop("blob_data") + data.pop("blob_mime_type") + + # We don't need these fields anymore, this usually only contains the uuid + # used by Weaviate to identify the object and the embedding vector that we already extracted. 
+ del data["_additional"] + + return Document.from_dict(data) + + def _query(self, properties: List[str], batch_size: int, cursor=None): + collection_name = self._collection_settings["class"] + query = ( + self._client.query.get( + collection_name, + properties, + ) + .with_additional(["id vector"]) + .with_limit(batch_size) + ) + + if cursor: + # Fetch the next set of results + result = query.with_after(cursor).do() + else: + # Fetch the first set of results + result = query.do() + + if "errors" in result: + errors = [e["message"] for e in result.get("errors", {})] + msg = "\n".join(errors) + msg = f"Failed to query documents in Weaviate. Errors:\n{msg}" + raise DocumentStoreError(msg) + + return result["data"]["Get"][collection_name] def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: # noqa: ARG002 - return [] + properties = self._client.schema.get(self._collection_settings["class"]).get("properties", []) + properties = [prop["name"] for prop in properties] - def write_documents( - self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE # noqa: ARG002 - ) -> int: - return 0 + result = [] - def delete_documents(self, document_ids: List[str]) -> None: # noqa: ARG002 - return + cursor = None + while batch := self._query(properties, 100, cursor): + # Take the cursor before we convert the batch to Documents as we manipulate + # the batch dictionary and might lose that information. + cursor = batch[-1]["_additional"]["id"] + + for doc in batch: + result.append(self._to_document(doc)) + # Move the cursor to the last returned uuid + return result + + def _batch_write(self, documents: List[Document]) -> int: + """ + Writes document to Weaviate in batches. + Documents with the same id will be overwritten. + Raises in case of errors. + """ + statuses = [] + for doc in documents: + if not isinstance(doc, Document): + msg = f"Expected a Document, got '{type(doc)}' instead." 
+ raise ValueError(msg) + if self._client.batch.num_objects() == self._client.batch.recommended_num_objects: + # Batch is full, let's create the objects + statuses.extend(self._client.batch.create_objects()) + self._client.batch.add_data_object( + uuid=generate_uuid5(doc.id), + data_object=self._to_data_object(doc), + class_name=self._collection_settings["class"], + vector=doc.embedding, + ) + # Write remaining documents + statuses.extend(self._client.batch.create_objects()) + + errors = [] + # Gather errors and number of written documents + for status in statuses: + result_status = status.get("result", {}).get("status") + if result_status == "FAILED": + errors.extend([e["message"] for e in status["result"]["errors"]["error"]]) + + if errors: + msg = "\n".join(errors) + msg = f"Failed to write documents in Weaviate. Errors:\n{msg}" + raise DocumentStoreError(msg) + + # If the document already exists we get no status message back from Weaviate. + # So we assume that all Documents were written. + return len(documents) + + def _write(self, documents: List[Document], policy: DuplicatePolicy) -> int: + """ + Writes documents to Weaviate using the specified policy. + This doesn't use the batch API, so it's slower than _batch_write. + If policy is set to SKIP it will skip any document that already exists. + If policy is set to FAIL it will raise an exception if any of the documents already exists. + """ + written = 0 + duplicate_errors_ids = [] + for doc in documents: + if not isinstance(doc, Document): + msg = f"Expected a Document, got '{type(doc)}' instead." 
+ raise ValueError(msg) + + if policy == DuplicatePolicy.SKIP and self._client.data_object.exists( + uuid=generate_uuid5(doc.id), + class_name=self._collection_settings["class"], + ): + # This Document already exists, we skip it + continue + + try: + self._client.data_object.create( + uuid=generate_uuid5(doc.id), + data_object=self._to_data_object(doc), + class_name=self._collection_settings["class"], + vector=doc.embedding, + ) + written += 1 + except weaviate.exceptions.ObjectAlreadyExistsException: + if policy == DuplicatePolicy.FAIL: + duplicate_errors_ids.append(doc.id) + if duplicate_errors_ids: + msg = f"IDs '{', '.join(duplicate_errors_ids)}' already exist in the document store." + raise DuplicateDocumentError(msg) + return written + + def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int: + """ + Writes documents to Weaviate using the specified policy. + We recommend using an OVERWRITE policy as it's faster than other policies for Weaviate since it uses + the batch API. + We can't use the batch API for other policies as it doesn't return any information whether the document + already exists or not. That prevents us from returning errors when using the FAIL policy or skipping a + Document when using the SKIP policy. 
+ """ + if policy in [DuplicatePolicy.NONE, DuplicatePolicy.OVERWRITE]: + return self._batch_write(documents) + + return self._write(documents, policy) + + def delete_documents(self, document_ids: List[str]) -> None: + self._client.batch.delete_objects( + class_name=self._collection_settings["class"], + where={ + "path": ["id"], + "operator": "ContainsAny", + "valueTextArray": [generate_uuid5(doc_id) for doc_id in document_ids], + }, + ) diff --git a/integrations/weaviate/tests/conftest.py b/integrations/weaviate/tests/conftest.py new file mode 100644 index 000000000..ed1002409 --- /dev/null +++ b/integrations/weaviate/tests/conftest.py @@ -0,0 +1,8 @@ +from pathlib import Path + +import pytest + + +@pytest.fixture() +def test_files_path(): + return Path(__file__).parent / "test_files" diff --git a/integrations/weaviate/tests/test_document_store.py b/integrations/weaviate/tests/test_document_store.py index 0666151ee..0682282f3 100644 --- a/integrations/weaviate/tests/test_document_store.py +++ b/integrations/weaviate/tests/test_document_store.py @@ -1,6 +1,14 @@ +import base64 from unittest.mock import MagicMock, patch -from haystack_integrations.document_stores.weaviate.document_store import WeaviateDocumentStore +import pytest +from haystack.dataclasses.byte_stream import ByteStream +from haystack.dataclasses.document import Document +from haystack.testing.document_store import CountDocumentsTest, DeleteDocumentsTest, WriteDocumentsTest +from haystack_integrations.document_stores.weaviate.document_store import ( + DOCUMENT_COLLECTION_PROPERTIES, + WeaviateDocumentStore, +) from weaviate.auth import AuthApiKey from weaviate.config import Config from weaviate.embedded import ( @@ -12,7 +20,18 @@ ) -class TestWeaviateDocumentStore: +class TestWeaviateDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest): + @pytest.fixture + def document_store(self, request) -> WeaviateDocumentStore: + # Use a different index for each test so we can run them in 
parallel + collection_settings = {"class": f"{request.node.name}"} + store = WeaviateDocumentStore( + url="http://localhost:8080", + collection_settings=collection_settings, + ) + yield store + store._client.schema.delete_class(collection_settings["class"]) + @patch("haystack_integrations.document_stores.weaviate.document_store.weaviate.Client") def test_init(self, mock_weaviate_client_class): mock_client = MagicMock() @@ -21,7 +40,7 @@ def test_init(self, mock_weaviate_client_class): WeaviateDocumentStore( url="http://localhost:8080", - collection_name="my_collection", + collection_settings={"class": "My_collection"}, auth_client_secret=AuthApiKey("my_api_key"), proxies={"http": "http://proxy:1234"}, additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"}, @@ -54,14 +73,15 @@ def test_init(self, mock_weaviate_client_class): # Verify collection is created mock_client.schema.get.assert_called_once() - mock_client.schema.exists.assert_called_once_with("my_collection") - mock_client.schema.create_class.assert_called_once_with({"class": "my_collection"}) + mock_client.schema.exists.assert_called_once_with("My_collection") + mock_client.schema.create_class.assert_called_once_with( + {"class": "My_collection", "properties": DOCUMENT_COLLECTION_PROPERTIES} + ) @patch("haystack_integrations.document_stores.weaviate.document_store.weaviate") def test_to_dict(self, _mock_weaviate): document_store = WeaviateDocumentStore( url="http://localhost:8080", - collection_name="my_collection", auth_client_secret=AuthApiKey("my_api_key"), proxies={"http": "http://proxy:1234"}, additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"}, @@ -77,7 +97,17 @@ def test_to_dict(self, _mock_weaviate): "type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore", "init_parameters": { "url": "http://localhost:8080", - "collection_name": "my_collection", + "collection_settings": { + "class": "Default", + "properties": [ + {"name": 
"_original_id", "dataType": ["text"]}, + {"name": "content", "dataType": ["text"]}, + {"name": "dataframe", "dataType": ["text"]}, + {"name": "blob_data", "dataType": ["blob"]}, + {"name": "blob_mime_type", "dataType": ["text"]}, + {"name": "score", "dataType": ["number"]}, + ], + }, "auth_client_secret": { "type": "weaviate.auth.AuthApiKey", "init_parameters": {"api_key": "my_api_key"}, @@ -113,7 +143,7 @@ def test_from_dict(self, _mock_weaviate): "type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore", "init_parameters": { "url": "http://localhost:8080", - "collection_name": "my_collection", + "collection_settings": None, "auth_client_secret": { "type": "weaviate.auth.AuthApiKey", "init_parameters": {"api_key": "my_api_key"}, @@ -144,7 +174,17 @@ def test_from_dict(self, _mock_weaviate): ) assert document_store._url == "http://localhost:8080" - assert document_store._collection_name == "my_collection" + assert document_store._collection_settings == { + "class": "Default", + "properties": [ + {"name": "_original_id", "dataType": ["text"]}, + {"name": "content", "dataType": ["text"]}, + {"name": "dataframe", "dataType": ["text"]}, + {"name": "blob_data", "dataType": ["blob"]}, + {"name": "blob_mime_type", "dataType": ["text"]}, + {"name": "score", "dataType": ["number"]}, + ], + } assert document_store._auth_client_secret == AuthApiKey("my_api_key") assert document_store._timeout_config == (10, 60) assert document_store._proxies == {"http": "http://proxy:1234"} @@ -161,3 +201,85 @@ def test_from_dict(self, _mock_weaviate): assert document_store._additional_config.grpc_port_experimental == 12345 assert document_store._additional_config.connection_config.session_pool_connections == 20 assert document_store._additional_config.connection_config.session_pool_maxsize == 20 + + def test_count_not_empty(self, document_store): + # Skipped for the time being as we don't support writing documents + pass + + def test_to_data_object(self, 
document_store, test_files_path): + doc = Document(content="test doc") + data = document_store._to_data_object(doc) + assert data == { + "_original_id": doc.id, + "content": doc.content, + "dataframe": None, + "score": None, + } + + image = ByteStream.from_file_path(test_files_path / "robot1.jpg", mime_type="image/jpeg") + doc = Document( + content="test doc", + blob=image, + embedding=[1, 2, 3], + meta={"key": "value"}, + ) + data = document_store._to_data_object(doc) + assert data == { + "_original_id": doc.id, + "content": doc.content, + "blob_data": base64.b64encode(image.data).decode(), + "blob_mime_type": "image/jpeg", + "dataframe": None, + "score": None, + "meta": {"key": "value"}, + } + + def test_to_document(self, document_store, test_files_path): + image = ByteStream.from_file_path(test_files_path / "robot1.jpg", mime_type="image/jpeg") + data = { + "_additional": { + "vector": [1, 2, 3], + }, + "_original_id": "123", + "content": "some content", + "blob_data": base64.b64encode(image.data).decode(), + "blob_mime_type": "image/jpeg", + "dataframe": None, + "score": None, + "meta": {"key": "value"}, + } + + doc = document_store._to_document(data) + assert doc.id == "123" + assert doc.content == "some content" + assert doc.blob == image + assert doc.embedding == [1, 2, 3] + assert doc.score is None + assert doc.meta == {"key": "value"} + + def test_write_documents(self, document_store): + """ + Test write_documents() with default policy overwrites existing documents. 
+ """ + doc = Document(content="test doc") + assert document_store.write_documents([doc]) == 1 + assert document_store.count_documents() == 1 + + doc.content = "test doc 2" + assert document_store.write_documents([doc]) == 1 + assert document_store.count_documents() == 1 + + def test_write_documents_with_blob_data(self, document_store, test_files_path): + image = ByteStream.from_file_path(test_files_path / "robot1.jpg", mime_type="image/jpeg") + doc = Document(content="test doc", blob=image) + assert document_store.write_documents([doc]) == 1 + + def test_filter_documents_with_blob_data(self, document_store, test_files_path): + image = ByteStream.from_file_path(test_files_path / "robot1.jpg", mime_type="image/jpeg") + doc = Document(content="test doc", blob=image) + assert document_store.write_documents([doc]) == 1 + + docs = document_store.filter_documents() + + assert len(docs) == 1 + assert docs[0].blob == image diff --git a/integrations/weaviate/tests/test_files/robot1.jpg b/integrations/weaviate/tests/test_files/robot1.jpg new file mode 100644 index 000000000..a3962db1b Binary files /dev/null and b/integrations/weaviate/tests/test_files/robot1.jpg differ