From 067adba2673304bf7f02a642393f1338c7298524 Mon Sep 17 00:00:00 2001 From: Eric Hare Date: Tue, 22 Oct 2024 06:33:36 -0700 Subject: [PATCH] feat: Update astradb integration for latest client library (#1145) * Update astradb integration for latest client library * Update CHANGELOG.md * Ruff check update * Black linting updates * Tweak to versioning for astrapy * removing CHANGELOG.MD changes since those are automatically added --------- Co-authored-by: David S. Batista --- integrations/astra/README.md | 24 ++- integrations/astra/examples/requirements.txt | 2 +- integrations/astra/pyproject.toml | 7 +- .../document_stores/astra/astra_client.py | 173 ++++++++---------- .../astra/tests/test_document_store.py | 26 +-- 5 files changed, 109 insertions(+), 123 deletions(-) diff --git a/integrations/astra/README.md b/integrations/astra/README.md index f679b7207..9ee47b8c9 100644 --- a/integrations/astra/README.md +++ b/integrations/astra/README.md @@ -6,17 +6,18 @@ ```bash pip install astra-haystack - ``` ### Local Development + install astra-haystack package locally to run integration tests: Open in gitpod: [![Open in Gitpod](https://gitpod.io/button/open-in-gitpod.svg)](https://gitpod.io/#https://github.com/Anant/astra-haystack/tree/main) -Switch Python version to 3.9 (Requires 3.8+ but not 3.12) -``` +Switch Python version to 3.9 (Requires 3.9+ but not 3.12) + +```bash pyenv install 3.9 pyenv local 3.9 ``` @@ -33,7 +34,8 @@ Install requirements `pip install -r requirements.txt` Export environment variables -``` + +```bash export ASTRA_DB_API_ENDPOINT="https://-.apps.astra.datastax.com" export ASTRA_DB_APPLICATION_TOKEN="AstraCS:..." export COLLECTION_NAME="my_collection" @@ -49,22 +51,25 @@ or This package includes Astra Document Store and Astra Embedding Retriever classes that integrate with Haystack, allowing you to easily perform document retrieval or RAG with Astra, and include those functions in Haystack pipelines. -### In order to use the Document Store directly: +### Use the Document Store Directly Import the Document Store: -``` + +```python from haystack_integrations.document_stores.astra import AstraDocumentStore from haystack.document_stores.types.policy import DuplicatePolicy ``` Load in environment variables: -``` + +```python namespace = os.environ.get("ASTRA_DB_KEYSPACE") collection_name = os.environ.get("COLLECTION_NAME", "haystack_vector_search") ``` Create the Document Store object (API Endpoint and Token are read off the environment): -``` + +```python document_store = AstraDocumentStore( collection_name=collection_name, namespace=namespace, @@ -80,7 +85,7 @@ Then you can use the document store functions like count_document below: Create the Document Store object like above, then import and create the Pipeline: -``` +```python from haystack import Pipeline pipeline = Pipeline() ``` @@ -101,7 +106,6 @@ or, > Astra DB collection '...' is detected as having the following indexing policy: {...}. This does not match the requested indexing policy for this object: {...}. In particular, there may be stricter limitations on the amount of text each string in a document can store. Consider indexing anew on a fresh collection to be able to store longer texts. - The reason for the warning is that the requested collection already exists on the database, and it is configured to [index all of its fields for search](https://docs.datastax.com/en/astra-db-serverless/api-reference/collections.html#the-indexing-option), possibly implicitly, by default. When the Haystack object tries to create it, it attempts to enforce, instead, an indexing policy tailored to the prospected usage: this is both to enable storing very long texts and to avoid indexing fields that will never be used in filtering a search (indexing those would also have a slight performance cost for writes). Typically there are two reasons why you may encounter the warning: diff --git a/integrations/astra/examples/requirements.txt b/integrations/astra/examples/requirements.txt index 710749bbe..221138666 100644 --- a/integrations/astra/examples/requirements.txt +++ b/integrations/astra/examples/requirements.txt @@ -1,4 +1,4 @@ haystack-ai sentence_transformers==2.2.2 openai==1.6.1 -astrapy>=0.7.7 \ No newline at end of file +astrapy>=1.5.0,<2.0 diff --git a/integrations/astra/pyproject.toml b/integrations/astra/pyproject.toml index f9e8fe982..5645cd5d3 100644 --- a/integrations/astra/pyproject.toml +++ b/integrations/astra/pyproject.toml @@ -7,7 +7,7 @@ name = "astra-haystack" dynamic = ["version"] description = '' readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.9" license = "Apache-2.0" keywords = [] authors = [{ name = "Anant Corporation", email = "support@anant.us" }] @@ -15,14 +15,13 @@ classifiers = [ "License :: OSI Approved :: Apache Software License", "Development Status :: 4 - Beta", "Programming Language :: Python", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "pydantic", "typing_extensions", "astrapy"] +dependencies = ["haystack-ai", "pydantic", "typing_extensions", "astrapy>=1.5.0,<2.0"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/astra#readme" @@ -57,7 +56,7 @@ cov = ["test-cov", "cov-report"] cov-retry = ["test-cov-retry", "cov-report"] docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] -python = ["3.8", "3.9", "3.10", "3.11"] +python = ["3.9", "3.10", "3.11"] [tool.hatch.envs.lint] installer = "uv" diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py index b594f87d3..6f2289786 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py @@ -3,8 +3,9 @@ from typing import Dict, List, Optional, Union from warnings import warn -from astrapy.api import APIRequestError -from astrapy.db import AstraDB +from astrapy import DataAPIClient as AstraDBClient +from astrapy.constants import ReturnDocument +from astrapy.exceptions import CollectionAlreadyExistsException from haystack.version import __version__ as integration_version from pydantic.dataclasses import dataclass @@ -65,83 +66,78 @@ def __init__( self.similarity_function = similarity_function self.namespace = namespace - # Build the Astra DB object - self._astra_db = AstraDB( + # Get the keyspace from the collection name + my_client = AstraDBClient( + callers=[(CALLER_NAME, integration_version)], + ) + + # Get the database object + self._astra_db = my_client.get_database( api_endpoint=api_endpoint, token=token, - namespace=namespace, - caller_name=CALLER_NAME, - caller_version=integration_version, + keyspace=namespace, ) - indexing_options = {"indexing": {"deny": NON_INDEXED_FIELDS}} + indexing_options = {"deny": NON_INDEXED_FIELDS} try: # Create and connect to the newly created collection self._astra_db_collection = self._astra_db.create_collection( - collection_name=collection_name, + name=collection_name, dimension=embedding_dimension, - options=indexing_options, + indexing=indexing_options, ) - except APIRequestError: + except CollectionAlreadyExistsException as _: # possibly the collection is preexisting and has legacy # indexing settings: verify - get_coll_response = self._astra_db.get_collections(options={"explain": True}) - - collections = (get_coll_response["status"] or {}).get("collections") or [] - - preexisting = [collection for collection in collections if collection["name"] == collection_name] + preexisting = [ + coll_descriptor + for coll_descriptor in self._astra_db.list_collections() + if coll_descriptor.name == collection_name + ] if preexisting: - pre_collection = preexisting[0] # if it has no "indexing", it is a legacy collection; - # otherwise it's unexpected warn and proceed at user's risk - pre_col_options = pre_collection.get("options") or {} - if "indexing" not in pre_col_options: + # otherwise it's unexpected: warn and proceed at user's risk + pre_col_idx_opts = preexisting[0].options.indexing or {} + if not pre_col_idx_opts: warn( ( - f"Astra DB collection '{collection_name}' is " - "detected as having indexing turned on for all " - "fields (either created manually or by older " - "versions of this plugin). This implies stricter " - "limitations on the amount of text each string in a " - "document can store. Consider indexing anew on a " - "fresh collection to be able to store longer texts. " - "See https://github.com/deepset-ai/haystack-core-" - "integrations/blob/main/integrations/astra/README" - ".md#warnings-about-indexing for more details." + f"Collection '{collection_name}' is detected as " + "having indexing turned on for all fields " + "(either created manually or by older versions " + "of this plugin). This implies stricter " + "limitations on the amount of text" + " each entry can store. Consider indexing anew on a" + " fresh collection to be able to store longer texts." ), UserWarning, stacklevel=2, ) - self._astra_db_collection = self._astra_db.collection( - collection_name=collection_name, + self._astra_db_collection = self._astra_db.get_collection( + collection_name, + ) + # check if the indexing options match entirely + elif pre_col_idx_opts == indexing_options: + self._astra_db_collection = self._astra_db.get_collection( + collection_name, ) - elif pre_col_options["indexing"] != indexing_options["indexing"]: - detected_options_json = json.dumps(pre_col_options["indexing"]) - indexing_options_json = json.dumps(indexing_options["indexing"]) + else: + options_json = json.dumps(pre_col_idx_opts) warn( ( - f"Astra DB collection '{collection_name}' is " - "detected as having the following indexing policy: " - f"{detected_options_json}. This does not match the requested " - f"indexing policy for this object: {indexing_options_json}. " - "In particular, there may be stricter " - "limitations on the amount of text each string in a " - "document can store. Consider indexing anew on a " - "fresh collection to be able to store longer texts. " - "See https://github.com/deepset-ai/haystack-core-" - "integrations/blob/main/integrations/astra/README" - ".md#warnings-about-indexing for more details." + f"Collection '{collection_name}' has unexpected 'indexing'" + f" settings (options.indexing = {options_json})." + " This can result in odd behaviour when running " + " metadata filtering and/or unwarranted limitations" + " on storing long texts. Consider indexing anew on a" + " fresh collection." ), UserWarning, stacklevel=2, ) - self._astra_db_collection = self._astra_db.collection( - collection_name=collection_name, + self._collection = self._astra_db.get_collection( + collection_name, ) - else: - # the collection mismatch lies elsewhere than the indexing - raise else: # other exception raise @@ -180,7 +176,7 @@ def query( return formatted_response def _query_without_vector(self, top_k, filters=None): - query = {"filter": filters, "options": {"limit": top_k}} + query = {"filter": filters, "limit": top_k} return self.find_documents(query) @@ -196,8 +192,11 @@ def _format_query_response(responses, include_metadata, include_values): score = response.pop("$similarity", None) text = response.pop("content", None) values = response.pop("$vector", None) if include_values else [] + metadata = response if include_metadata else {} # Add all remaining fields to the metadata + rsp = Response(_id, text, values, metadata, score) + final_res.append(rsp) return QueryResponse(final_res) @@ -219,17 +218,21 @@ def find_documents(self, find_query): :param find_query: a dictionary with the query options :returns: the documents found in the index """ - response_dict = self._astra_db_collection.find( + find_cursor = self._astra_db_collection.find( filter=find_query.get("filter"), sort=find_query.get("sort"), - options=find_query.get("options"), + limit=find_query.get("limit"), projection={"*": 1}, ) - if "data" in response_dict and "documents" in response_dict["data"]: - return response_dict["data"]["documents"] - else: - logger.warning(f"No documents found: {response_dict}") + find_results = [] + for result in find_cursor: + find_results.append(result) + + if not find_results: + logger.warning("No documents found.") + + return find_results def find_one_document(self, find_query): """ @@ -238,16 +241,15 @@ def find_one_document(self, find_query): :param find_query: a dictionary with the query options :returns: the document found in the index """ - response_dict = self._astra_db_collection.find_one( + find_result = self._astra_db_collection.find_one( filter=find_query.get("filter"), - options=find_query.get("options"), projection={"*": 1}, ) - if "data" in response_dict and "document" in response_dict["data"]: - return response_dict["data"]["document"] - else: - logger.warning(f"No document found: {response_dict}") + if not find_result: + logger.warning("No document found.") + + return find_result def get_documents(self, ids: List[str], batch_size: int = 20) -> QueryResponse: """ @@ -281,15 +283,8 @@ def insert(self, documents: List[Dict]): :param documents: a list of documents to insert :returns: the IDs of the inserted documents """ - response_dict = self._astra_db_collection.insert_many(documents=documents) - - inserted_ids = ( - response_dict["status"]["insertedIds"] - if "status" in response_dict and "insertedIds" in response_dict["status"] - else [] - ) - if "errors" in response_dict: - logger.error(response_dict["errors"]) + insert_result = self._astra_db_collection.insert_many(documents=documents) + inserted_ids = [str(_id) for _id in insert_result.inserted_ids] return inserted_ids @@ -303,23 +298,21 @@ def update_document(self, document: Dict, id_key: str): """ document_id = document.pop(id_key) - response_dict = self._astra_db_collection.find_one_and_update( + update_result = self._astra_db_collection.find_one_and_update( filter={id_key: document_id}, update={"$set": document}, - options={"returnDocument": "after"}, + return_document=ReturnDocument.AFTER, projection={"*": 1}, ) document[id_key] = document_id - if "status" in response_dict and "errors" not in response_dict: - if "matchedCount" in response_dict["status"] and "modifiedCount" in response_dict["status"]: - if response_dict["status"]["matchedCount"] == 1 and response_dict["status"]["modifiedCount"] == 1: - return True + if update_result is None: + logger.warning(f"Documents {document_id} not updated in Astra DB.") - logger.warning(f"Documents {document_id} not updated in Astra DB.") + return False - return False + return True def delete( self, @@ -345,23 +338,13 @@ def delete( if "filter" in query["deleteMany"]: filter_dict = query["deleteMany"]["filter"] - deletion_counter = 0 - moredata = True - while moredata: - response_dict = self._astra_db_collection.delete_many(filter=filter_dict) - - if "moreData" not in response_dict.get("status", {}): - moredata = False + delete_result = self._astra_db_collection.delete_many(filter=filter_dict) - deletion_counter += int(response_dict["status"].get("deletedCount", 0)) + return delete_result.deleted_count - return deletion_counter - - def count_documents(self) -> int: + def count_documents(self, upper_bound: int = 10000) -> int: """ Count the number of documents in the Astra index. :returns: the number of documents in the index """ - documents_count = self._astra_db_collection.count_documents() - - return documents_count["status"]["count"] + return self._astra_db_collection.count_documents({}, upper_bound=upper_bound) diff --git a/integrations/astra/tests/test_document_store.py b/integrations/astra/tests/test_document_store.py index c4d1b6347..ef00b6b25 100644 --- a/integrations/astra/tests/test_document_store.py +++ b/integrations/astra/tests/test_document_store.py @@ -20,25 +20,14 @@ def mock_auth(monkeypatch): monkeypatch.setenv("ASTRA_DB_APPLICATION_TOKEN", "test_token") -@mock.patch("haystack_integrations.document_stores.astra.astra_client.AstraDB") +@mock.patch("haystack_integrations.document_stores.astra.astra_client.AstraDBClient") def test_init_is_lazy(_mock_client, mock_auth): # noqa _ = AstraDocumentStore() _mock_client.assert_not_called() -def test_namespace_init(mock_auth): # noqa - with mock.patch("haystack_integrations.document_stores.astra.astra_client.AstraDB") as client: - _ = AstraDocumentStore().index - assert "namespace" in client.call_args.kwargs - assert client.call_args.kwargs["namespace"] is None - - _ = AstraDocumentStore(namespace="foo").index - assert "namespace" in client.call_args.kwargs - assert client.call_args.kwargs["namespace"] == "foo" - - def test_to_dict(mock_auth): # noqa - with mock.patch("haystack_integrations.document_stores.astra.astra_client.AstraDB"): + with mock.patch("haystack_integrations.document_stores.astra.astra_client.AstraDBClient"): ds = AstraDocumentStore() result = ds.to_dict() assert result["type"] == "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore" @@ -206,6 +195,17 @@ def test_filter_documents_by_id(self, document_store): result = document_store.filter_documents(filters={"field": "id", "operator": "==", "value": "1"}) self.assert_documents_are_equal(result, [docs[0]]) + def test_filter_documents_by_in_operator(self, document_store): + docs = [Document(id="3", content="test doc 3"), Document(id="4", content="test doc 4")] + document_store.write_documents(docs) + result = document_store.filter_documents(filters={"field": "id", "operator": "in", "value": ["3", "4"]}) + + # Sort the result in place by the id field + result.sort(key=lambda x: x.id) + + self.assert_documents_are_equal([result[0]], [docs[0]]) + self.assert_documents_are_equal([result[1]], [docs[1]]) + @pytest.mark.skip(reason="Unsupported filter operator not.") def test_not_operator(self, document_store, filterable_docs): pass