diff --git a/.github/labeler.yml b/.github/labeler.yml index 355e37231..319f7c726 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -4,6 +4,11 @@ integration:amazon-bedrock: - any-glob-to-any-file: "integrations/amazon_bedrock/**/*" - any-glob-to-any-file: ".github/workflows/amazon_bedrock.yml" +integration:astra: + - changed-files: + - any-glob-to-any-file: "integrations/astra/**/*" + - any-glob-to-any-file: ".github/workflows/astra.yml" + integration:chroma: - changed-files: - any-glob-to-any-file: "integrations/chroma/**/*" diff --git a/integrations/astra/examples/example.py b/integrations/astra/examples/example.py index af2c12b4c..ac93f43ed 100644 --- a/integrations/astra/examples/example.py +++ b/integrations/astra/examples/example.py @@ -1,3 +1,4 @@ +import logging import os from pathlib import Path @@ -12,9 +13,13 @@ from astra_haystack.document_store import AstraDocumentStore from astra_haystack.retriever import AstraRetriever +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + HERE = Path(__file__).resolve().parent file_paths = [HERE / "data" / Path(name) for name in os.listdir("integrations/astra/examples/data")] -print(file_paths) +logger.info(file_paths) astra_id = os.getenv("ASTRA_DB_ID", "") astra_region = os.getenv("ASTRA_DB_REGION", "us-east1") @@ -67,15 +72,18 @@ question = "This chapter introduces the manuals available with Vim" result = q.run({"embedder": {"text": question}, "retriever": {"top_k": 1}}) -print(result) - -print("count:") -print(document_store.count_documents()) -assert document_store.count_documents() == 9 - -print("filter:") -print( - document_store.filter_documents( +logger.info(result) + +ALL_DOCUMENTS_COUNT = 9 +documents_count = document_store.count_documents() +logger.info("count:") +logger.info(documents_count) +if documents_count != ALL_DOCUMENTS_COUNT: + msg = f"count mismatch, expected 9 documents, got {documents_count}" + raise ValueError(msg) + +logger.info( + f"""filter 
results: {document_store.filter_documents( { "field": "meta", "operator": "==", @@ -85,22 +93,27 @@ }, } ) +}""" ) -print("get_document_by_id") -print(document_store.get_document_by_id("92ef055fbae55b2b0fc79d34cbf8a80b0ad7700ca526053223b0cc6d1351df10")) -print("get_documents_by_ids") -print( - document_store.get_documents_by_id( +logger.info( + f"""get_document_by_id {document_store.get_document_by_id( + "92ef055fbae55b2b0fc79d34cbf8a80b0ad7700ca526053223b0cc6d1351df10")}""" +) + +logger.info( + f"""get_documents_by_ids {document_store.get_documents_by_id( [ "92ef055fbae55b2b0fc79d34cbf8a80b0ad7700ca526053223b0cc6d1351df10", "6f2450a51eaa3eeb9239d875402bcfe24b2d3534ff27f26c1f3fc8133b04e756", ] - ) + )}""" ) document_store.delete_documents(["92ef055fbae55b2b0fc79d34cbf8a80b0ad7700ca526053223b0cc6d1351df10"]) -print("count:") -print(document_store.count_documents()) -assert document_store.count_documents() == 8 +documents_count = document_store.count_documents() +logger.info(f"count: {documents_count}") +if documents_count != ALL_DOCUMENTS_COUNT - 1: + msg = f"count mismatch, expected {ALL_DOCUMENTS_COUNT - 1} documents, got {documents_count}" + raise ValueError(msg) diff --git a/integrations/astra/examples/pipeline_example.py b/integrations/astra/examples/pipeline_example.py index 1fd49fd44..fb13c3d93 100644 --- a/integrations/astra/examples/pipeline_example.py +++ b/integrations/astra/examples/pipeline_example.py @@ -1,3 +1,4 @@ +import logging import os from haystack import Document, Pipeline @@ -11,6 +12,9 @@ from astra_haystack.document_store import AstraDocumentStore from astra_haystack.retriever import AstraRetriever +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + # Create a RAG query pipeline prompt_template = """ Given these documents, answer the question. 
@@ -48,10 +52,12 @@ documents = [ Document(content="There are over 7,000 languages spoken around the world today."), Document( - content="Elephants have been observed to behave in a way that indicates a high level of self-awareness, such as recognizing themselves in mirrors." + content="Elephants have been observed to behave in a way that indicates" + " a high level of self-awareness, such as recognizing themselves in mirrors." ), Document( - content="In certain parts of the world, like the Maldives, Puerto Rico, and San Diego, you can witness the phenomenon of bioluminescent waves." + content="In certain parts of the world, like the Maldives, Puerto Rico, " + "and San Diego, you can witness the phenomenon of bioluminescent waves." ), ] p = Pipeline() @@ -97,4 +103,5 @@ "answer_builder": {"query": question}, } ) -print(result) + +logger.info(result) diff --git a/integrations/astra/pyproject.toml b/integrations/astra/pyproject.toml index aa891b257..0c53b7c79 100644 --- a/integrations/astra/pyproject.toml +++ b/integrations/astra/pyproject.toml @@ -93,11 +93,6 @@ target-version = ["py37"] line-length = 120 skip-string-normalization = true -[tool.isort] -profile = 'black' -line_length = 79 -skip_gitignore = true - [tool.ruff] target-version = "py37" line-length = 120 @@ -142,6 +137,7 @@ unfixable = [ # Don't touch unused imports "F401", ] +exclude = ["examples"] [tool.ruff.isort] known-first-party = ["astra_haystack"] @@ -157,6 +153,9 @@ ban-relative-imports = "all" source_pkgs = ["astra_haystack", "tests"] branch = true parallel = true +omit = [ + "examples/*" +] [tool.coverage.paths] astra_haystack = ["src/astra_haystack", "*/astra-store/src/astra_haystack"] diff --git a/integrations/astra/src/astra_haystack/astra_client.py b/integrations/astra/src/astra_haystack/astra_client.py index 3f21ef628..ec0263a5a 100644 --- a/integrations/astra/src/astra_haystack/astra_client.py +++ b/integrations/astra/src/astra_haystack/astra_client.py @@ -10,7 +10,7 @@ @dataclass class 
Response: - id: str + document_id: str text: Optional[str] values: Optional[list] metadata: Optional[dict] @@ -80,13 +80,15 @@ def find_index(self): collection_embedding_dim = collection_name_matches[0]["options"]["vector"]["dimension"] if collection_embedding_dim != self.embedding_dim: - raise Exception( + msg = ( f"Collection vector dimension is not valid, expected {self.embedding_dim}, " f"found {collection_embedding_dim}" ) + raise Exception(msg) else: - raise Exception(f"status not in response: {response.text}") + msg = f"status not in response: {response.text}" + raise Exception(msg) return True @@ -107,9 +109,8 @@ def create_index(self): def query( self, vector: Optional[List[float]] = None, - filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, + query_filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, top_k: Optional[int] = None, - namespace: Optional[str] = None, include_metadata: Optional[bool] = None, include_values: Optional[bool] = None, ) -> QueryResponse: @@ -122,7 +123,7 @@ def query( being queried. Each `query()` request can contain only one of the parameters `queries`, `id` or `vector`... [optional] top_k (int): The number of results to return for each query. Must be an integer greater than 1. - filter (Dict[str, Union[str, float, int, bool, List, dict]): + query_filter (Dict[str, Union[str, float, int, bool, List, dict]): The filter to apply. You can use vector metadata to limit your search. [optional] include_metadata (bool): Indicates whether metadata is included in the response as well as the ids. 
If omitted the server will use the default value of False [optional] @@ -134,9 +135,9 @@ def query( """ # get vector data and scores if vector is None: - responses = self._query_without_vector(top_k, filter) + responses = self._query_without_vector(top_k, query_filter) else: - responses = self._query(vector, top_k, filter) + responses = self._query(vector, top_k, query_filter) # include_metadata means return all columns in the table (including text that got embedded) # include_values means return the vector of the embedding for the searched items @@ -158,7 +159,7 @@ def _format_query_response(responses, include_metadata, include_values): score = response.pop("$similarity", None) text = response.pop("content", None) values = response.pop("$vector", None) if include_values else [] - metadata = response if include_metadata else dict() # Add all remaining fields to the metadata + metadata = response if include_metadata else {} # Add all remaining fields to the metadata rsp = Response(_id, text, values, metadata, score) final_res.append(rsp) return QueryResponse(final_res) @@ -185,7 +186,7 @@ def find_documents(self, find_query): if "data" in response_dict and "documents" in response_dict["data"]: return response_dict["data"]["documents"] else: - logger.warning("No documents found", response_dict) + logger.warning(f"No documents found: {response_dict}") def get_documents(self, ids: List[str], batch_size: int = 20) -> QueryResponse: document_batch = [] @@ -253,14 +254,14 @@ def delete( self, ids: Optional[List[str]] = None, delete_all: Optional[bool] = None, - filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, + filters: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, ) -> int: if delete_all: - query = {"deleteMany": {}} + query = {"deleteMany": {}} # type: dict if ids is not None: query = {"deleteMany": {"filter": {"_id": {"$in": ids}}}} - if filter is not None: - query = {"deleteMany": {"filter": filter}} + if filters is 
not None: + query = {"deleteMany": {"filter": filters}} deletion_counter = 0 moredata = True diff --git a/integrations/astra/src/astra_haystack/document_store.py b/integrations/astra/src/astra_haystack/document_store.py index 6d1a887dd..3aa90a276 100644 --- a/integrations/astra/src/astra_haystack/document_store.py +++ b/integrations/astra/src/astra_haystack/document_store.py @@ -21,7 +21,9 @@ from astra_haystack.filters import _convert_filters logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) + + +MAX_BATCH_SIZE = 20 def _batches(input_list, batch_size): @@ -44,7 +46,7 @@ def __init__( astra_keyspace: str, astra_collection: str, embedding_dim: Optional[int] = 768, - duplicates_policy: Optional[DuplicatePolicy] = DuplicatePolicy.NONE, + duplicates_policy: DuplicatePolicy = DuplicatePolicy.NONE, similarity: str = "cosine", ): """ @@ -62,7 +64,8 @@ def __init__( :param similarity: The similarity function used to compare document vectors. :param duplicates_policy: Handle duplicate documents based on DuplicatePolicy parameter options. Parameter options : (SKIP, OVERWRITE, FAIL, NONE) - - `DuplicatePolicy.NONE`: Default policy, If a Document with the same id already exists, it is skipped and not written. + - `DuplicatePolicy.NONE`: Default policy, If a Document with the same id already exists, + it is skipped and not written. - `DuplicatePolicy.SKIP`: If a Document with the same id already exists, it is skipped and not written. - `DuplicatePolicy.OVERWRITE`: If a Document with the same id already exists, it is overwritten. - `DuplicatePolicy.FAIL`: If a Document with the same id already exists, an error is raised. @@ -107,7 +110,7 @@ def write_documents( self, documents: List[Document], index: Optional[str] = None, - batch_size: Optional[int] = 20, + batch_size: int = 20, policy: DuplicatePolicy = DuplicatePolicy.NONE, ): """ @@ -119,13 +122,19 @@ def write_documents( :param batch_size: Number of documents that are passed to bulk function at a time. 
:param policy: Handle duplicate documents based on DuplicatePolicy parameter options. Parameter options : (SKIP, OVERWRITE, FAIL, NONE) - - `DuplicatePolicy.NONE`: Default policy, If a Document with the same id already exists, it is skipped and not written. - - `DuplicatePolicy.SKIP`: If a Document with the same id already exists, it is skipped and not written. + - `DuplicatePolicy.NONE`: Default policy, If a Document with the same id already exists, + it is skipped and not written. + - `DuplicatePolicy.SKIP`: If a Document with the same id already exists, + it is skipped and not written. - `DuplicatePolicy.OVERWRITE`: If a Document with the same id already exists, it is overwritten. - `DuplicatePolicy.FAIL`: If a Document with the same id already exists, an error is raised. :return: int """ + if index is None and self.index is None: + msg = "No Astra client provided" + raise ValueError(msg) + if index is None: index = self.index @@ -135,12 +144,12 @@ def write_documents( else: policy = DuplicatePolicy.SKIP - if batch_size > 20: + if batch_size > MAX_BATCH_SIZE: logger.warning( f"batch_size set to {batch_size}, " f"but maximum batch_size for Astra when using the JSON API is 20. batch_size set to 20." ) - batch_size = 20 + batch_size = MAX_BATCH_SIZE def _convert_input_document(document: Union[dict, Document]): if isinstance(document, Document): @@ -148,21 +157,22 @@ def _convert_input_document(document: Union[dict, Document]): elif isinstance(document, dict): document_dict = document else: - raise ValueError(f"Unsupported type for documents, documents is of type {type(document)}.") + msg = f"Unsupported type for documents, documents is of type {type(document)}." 
+ raise ValueError(msg) if "id" in document_dict: if "_id" not in document_dict: document_dict["_id"] = document_dict.pop("id") elif "_id" in document_dict: - raise Exception( - f"Duplicate id definitions, both 'id' and '_id' present in document {document_dict}" - ) + msg = f"Duplicate id definitions, both 'id' and '_id' present in document {document_dict}" + raise Exception(msg) if "_id" in document_dict: if not isinstance(document_dict["_id"], str): - raise Exception( + msg = ( f"Document id {document_dict['_id']} is not a string, " f"but is of type {type(document_dict['_id'])}" ) + raise Exception(msg) if "dataframe" in document_dict and document_dict["dataframe"] is not None: document_dict["dataframe"] = document_dict.pop("dataframe").to_json() @@ -180,7 +190,8 @@ def _convert_input_document(document: Union[dict, Document]): response = self.index.find_documents({"filter": {"_id": doc["_id"]}}) if response: if policy == DuplicatePolicy.FAIL: - raise DuplicateDocumentError(f"ID '{doc['_id']}' already exists.") + msg = f"ID '{doc['_id']}' already exists." 
+ raise DuplicateDocumentError(msg) duplicate_documents.append(doc) else: new_documents.append(doc) @@ -264,13 +275,17 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc for vector in vectors: converted_filters = _convert_filters(filters) results = self.index.query( - vector=vector, filter=converted_filters, top_k=1000, include_values=True, include_metadata=True + vector=vector, + query_filter=converted_filters, + top_k=1000, + include_values=True, + include_metadata=True, ) documents.extend(self._get_result_to_documents(results)) else: converted_filters = _convert_filters(filters) results = self.index.query( - vector=vector, filter=converted_filters, top_k=1000, include_values=True, include_metadata=True + vector=vector, query_filter=converted_filters, top_k=1000, include_values=True, include_metadata=True ) documents = self._get_result_to_documents(results) return documents @@ -286,7 +301,7 @@ def _get_result_to_documents(results) -> List[Document]: df = None document = Document( content=match.text, - id=match.id, + id=match.document_id, embedding=match.values, dataframe=df, blob=match.metadata.pop("blob", None), @@ -313,7 +328,8 @@ def get_document_by_id(self, document_id: str) -> Document: document = self.index.get_documents(ids=[document_id]) ret = self._get_result_to_documents(document) if not ret: - raise MissingDocumentError(f"Document {document_id} does not exist") + msg = f"Document {document_id} does not exist" + raise MissingDocumentError(msg) return ret[0] def search( @@ -335,7 +351,7 @@ def search( self.index.query( vector=query_embedding, top_k=top_k, - filter=converted_filters, + query_filter=converted_filters, include_metadata=True, include_values=True, ) @@ -344,7 +360,7 @@ def search( return result - def delete_documents(self, document_ids: List[str] = None, delete_all: Optional[bool] = None) -> None: + def delete_documents(self, document_ids: Optional[List[str]] = None, delete_all: Optional[bool] = None) -> None: 
""" Deletes all documents with a matching document_ids from the document store. Fails with `MissingDocumentError` if no document with this id is present in the store. @@ -356,13 +372,14 @@ def delete_documents(self, document_ids: List[str] = None, delete_all: Optional[ deletion_counter = 0 if self.index.count_documents() > 0: if document_ids is not None: - for batch in _batches(document_ids, 20): + for batch in _batches(document_ids, MAX_BATCH_SIZE): deletion_counter += self.index.delete(ids=batch) else: deletion_counter = self.index.delete(delete_all=delete_all) logger.info(f"{deletion_counter} documents deleted") if document_ids is not None and deletion_counter == 0: - raise MissingDocumentError(f"Document {document_ids} does not exist") + msg = f"Document {document_ids} does not exist" + raise MissingDocumentError(msg) else: logger.info("No documents in document store") diff --git a/integrations/astra/src/astra_haystack/filters.py b/integrations/astra/src/astra_haystack/filters.py index 605d5ecaa..6b628486b 100644 --- a/integrations/astra/src/astra_haystack/filters.py +++ b/integrations/astra/src/astra_haystack/filters.py @@ -32,19 +32,19 @@ def _convert_filters(filters: Optional[Dict[str, Any]] = None) -> Optional[Dict[ else: if key == "id": filter_statements[key] = {"_id": value} - if key != "$in" and type(value) is list: + if key != "$in" and isinstance(value, list): filter_statements[key] = {"$in": value} + elif isinstance(value, pd.DataFrame): + filter_statements[key] = value.to_json() + elif isinstance(value, dict): + for dkey, dvalue in value.items(): + if dkey == "$in" and not isinstance(dvalue, list): + exception_message = f"$in operator must have `ARRAY`, got {dvalue} of type {type(dvalue)}" + raise FilterError(exception_message) + converted = {dkey: dvalue} + filter_statements[key] = converted else: - if type(value) is pd.DataFrame: - filter_statements[key] = value.to_json() - elif type(value) is dict: - for dkey, dvalue in value.items(): - if dkey == 
"$in" and type(dvalue) is not list: - raise FilterError(f"$in operator must have `ARRAY`, got {dvalue} of type {type(dvalue)}") - converted = {dkey: dvalue} - filter_statements[key] = converted - else: - filter_statements[key] = value + filter_statements[key] = value return filter_statements @@ -77,7 +77,8 @@ def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]: if len(conditions) > 1: conditions = _normalize_ranges(conditions) if operator not in OPERATORS: - raise FilterError(f"Unknown operator {operator}") + msg = f"Unknown operator {operator}" + raise FilterError(msg) return {OPERATORS[operator]: conditions} diff --git a/integrations/astra/src/astra_haystack/retriever.py b/integrations/astra/src/astra_haystack/retriever.py index 22c8f2664..47304df2c 100644 --- a/integrations/astra/src/astra_haystack/retriever.py +++ b/integrations/astra/src/astra_haystack/retriever.py @@ -28,7 +28,8 @@ def __init__(self, document_store: AstraDocumentStore, filters: Optional[Dict[st self.document_store = document_store if not isinstance(document_store, AstraDocumentStore): - raise Exception("document_store must be an instance of AstraDocumentStore") + message = "document_store must be an instance of AstraDocumentStore" + raise Exception(message) @component.output_types(documents=List[Document]) def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None): @@ -36,7 +37,8 @@ def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = Args: query_embedding (List[str]): An input list of queries - filters (Optional[Dict[str, Any]], optional): A dictionary with filters to narrow down the search space. Defaults to None. + filters (Optional[Dict[str, Any]], optional): A dictionary with filters to narrow down the search space. + Defaults to None. top_k (Optional[int], optional): The maximum number of documents to retrieve. Defaults to None. 
""" diff --git a/integrations/astra/tests/__init__.py b/integrations/astra/tests/__init__.py index ad09dadb6..f5e799e88 100644 --- a/integrations/astra/tests/__init__.py +++ b/integrations/astra/tests/__init__.py @@ -1,6 +1,3 @@ # SPDX-FileCopyrightText: 2023-present Anant Corporation # # SPDX-License-Identifier: Apache-2.0 -import sys - -sys.path.append("../src/astra_haystack/")