diff --git a/integrations/astra/README.md b/integrations/astra/README.md index d14544df4..f8b6f7c31 100644 --- a/integrations/astra/README.md +++ b/integrations/astra/README.md @@ -3,6 +3,13 @@ # Astra Store ## Installation + +```bash +pip install astra-haystack + +``` + +### Local Development install astra-haystack package locally to run integration tests: Open in gitpod: @@ -46,8 +53,8 @@ This package includes Astra Document Store and Astra Embedding Retriever classes Import the Document Store: ``` -from astra_store.document_store import AstraDocumentStore -from haystack.preview.document_stores import DuplicatePolicy +from haystack_integrations.document_stores.astra import AstraDocumentStore +from haystack.document_stores.types.policy import DuplicatePolicy ``` Load in environment variables: @@ -76,7 +83,7 @@ Then you can use the document store functions like count_document below: Create the Document Store object like above, then import and create the Pipeline: ``` -from haystack.preview import Pipeline +from haystack import Pipeline pipeline = Pipeline() ``` diff --git a/integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py b/integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py index 7236b5749..2b9ac7d28 100644 --- a/integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py +++ b/integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py @@ -13,15 +13,28 @@ class AstraEmbeddingRetriever: """ A component for retrieving documents from an AstraDocumentStore. 
+ + Usage example: + ```python + from haystack_integrations.document_stores.astra import AstraDocumentStore + from haystack_integrations.components.retrievers.astra import AstraEmbeddingRetriever + + document_store = AstraDocumentStore( + api_endpoint=api_endpoint, + token=token, + collection_name=collection_name, + duplicates_policy=DuplicatePolicy.SKIP, + embedding_dimension=384, + ) + + retriever = AstraEmbeddingRetriever(document_store=document_store) + ``` """ def __init__(self, document_store: AstraDocumentStore, filters: Optional[Dict[str, Any]] = None, top_k: int = 10): """ - Create an AstraEmbeddingRetriever component. Usually you pass some basic configuration - parameters to the constructor. - - :param filters: A dictionary with filters to narrow down the search space (default is None). - :param top_k: The maximum number of documents to retrieve (default is 10). + :param filters: a dictionary with filters to narrow down the search space. + :param top_k: the maximum number of documents to retrieve. """ self.filters = filters self.top_k = top_k @@ -33,13 +46,13 @@ def __init__(self, document_store: AstraDocumentStore, filters: Optional[Dict[st @component.output_types(documents=List[Document]) def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None): - """Run the retriever on the given list of queries. + """Retrieve documents from the AstraDocumentStore. - Args: - query_embedding (List[str]): An input list of queries - filters (Optional[Dict[str, Any]], optional): A dictionary with filters to narrow down the search space. - Defaults to None. - top_k (Optional[int], optional): The maximum number of documents to retrieve. Defaults to None. + :param query_embedding: floats representing the query embedding + :param filters: filters to narrow down the search space. + :param top_k: the maximum number of documents to retrieve. 
+ :returns: a dictionary with the following keys: + - documents: A list of documents retrieved from the AstraDocumentStore. """ if not top_k: @@ -51,6 +64,12 @@ def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = return {"documents": self.document_store.search(query_embedding, top_k, filters=filters)} def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ return default_to_dict( self, filters=self.filters, @@ -60,6 +79,14 @@ def to_dict(self) -> Dict[str, Any]: @classmethod def from_dict(cls, data: Dict[str, Any]) -> "AstraEmbeddingRetriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ document_store = AstraDocumentStore.from_dict(data["init_parameters"]["document_store"]) data["init_parameters"]["document_store"] = document_store return default_from_dict(cls, data) diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py index 6e7b1a33f..c1eb1f6a7 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py @@ -43,6 +43,19 @@ def __init__( similarity_function: str, namespace: Optional[str] = None, ): + """ + The connection to Astra DB is established and managed through the JSON API. + The required credentials (api endpoint and application token) can be generated + through the UI by clicking on the connect tab, and then selecting JSON API and + Generate Configuration. + + :param api_endpoint: the Astra DB API endpoint. + :param token: the Astra DB application token. + :param collection_name: the current collection in the keyspace in the current Astra DB. 
+ :param embedding_dimension: dimension of embedding vector. + :param similarity_function: the similarity function to use for the index. + :param namespace: the namespace to use for the collection. + """ self.api_endpoint = api_endpoint self.token = token self.collection_name = collection_name @@ -119,23 +132,17 @@ def query( include_values: Optional[bool] = None, ) -> QueryResponse: """ - The Query operation searches a namespace, using a query vector. - It retrieves the ids of the most similar items in a namespace, along with their similarity scores. - - Args: - vector (List[float]): The query vector. This should be the same length as the dimension of the index - being queried. Each `query()` request can contain only one of the parameters - `queries`, `id` or `vector`... [optional] - top_k (int): The number of results to return for each query. Must be an integer greater than 1. - query_filter (Dict[str, Union[str, float, int, bool, List, dict]): - The filter to apply. You can use vector metadata to limit your search. [optional] - include_metadata (bool): Indicates whether metadata is included in the response as well as the ids. - If omitted the server will use the default value of False [optional] - include_values (bool): Indicates whether values/vector is included in the response as well as the ids. - If omitted the server will use the default value of False [optional] - - Returns: object which contains the list of the closest vectors as ScoredVector objects, - and namespace name. + Search the Astra index using a query vector. + + :param vector: the query vector. This should be the same length as the dimension of the index being queried. + Each `query()` request can contain only one of the parameters `queries`, `id` or `vector`. + :param query_filter: the filter to apply. You can use vector metadata to limit your search. + :param top_k: the number of results to return for each query. Must be an integer greater than 1. 
+ :param include_metadata: indicates whether metadata is included in the response as well as the ids. + If omitted the server will use the default value of `False`. + :param include_values: indicates whether values/vector is included in the response as well as the ids. + If omitted the server will use the default value of `False`. + :returns: object which contains the list of the closest vectors as ScoredVector objects, and namespace name. """ # get vector data and scores if vector is None: @@ -183,6 +190,12 @@ def _query(self, vector, top_k, filters=None): return result def find_documents(self, find_query): + """ + Find documents in the Astra index. + + :param find_query: a dictionary with the query options + :returns: the documents found in the index + """ response_dict = self._astra_db_collection.find( filter=find_query.get("filter"), sort=find_query.get("sort"), @@ -195,6 +208,13 @@ def find_documents(self, find_query): logger.warning(f"No documents found: {response_dict}") def get_documents(self, ids: List[str], batch_size: int = 20) -> QueryResponse: + """ + Get documents from the Astra index by their ids. + + :param ids: a list of document ids + :param batch_size: the batch size to use when querying the index + :returns: the documents found in the index + """ document_batch = [] def batch_generator(chunks, batch_size): @@ -213,6 +233,12 @@ def batch_generator(chunks, batch_size): return formatted_docs def insert(self, documents: List[Dict]): + """ + Insert documents into the Astra index. + + :param documents: a list of documents to insert + :returns: the IDs of the inserted documents + """ response_dict = self._astra_db_collection.insert_many(documents=documents) inserted_ids = ( @@ -226,6 +252,13 @@ def insert(self, documents: List[Dict]): return inserted_ids def update_document(self, document: Dict, id_key: str): + """ + Update a document in the Astra index. 
+ + :param document: the document to update + :param id_key: the key to use as the document id + :returns: whether the document was updated successfully + """ document_id = document.pop(id_key) response_dict = self._astra_db_collection.find_one_and_update( @@ -251,6 +284,13 @@ def delete( delete_all: Optional[bool] = None, filters: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, ) -> int: + """Delete documents from the Astra index. + + :param ids: the ids of the documents to delete + :param delete_all: if `True`, delete all documents from the index + :param filters: additional filters to apply when deleting documents + :returns: the number of documents deleted + """ if delete_all: query = {"deleteMany": {}} # type: dict if ids is not None: @@ -276,7 +316,8 @@ def delete( def count_documents(self) -> int: """ - Returns how many documents are present in the document store. + Count the number of documents in the Astra index. + :returns: the number of documents in the index """ documents_count = self._astra_db_collection.count_documents() diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py index 1a4ec9d17..2f8f0928d 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py @@ -32,6 +32,19 @@ def _batches(input_list, batch_size): class AstraDocumentStore: """ An AstraDocumentStore document store for Haystack. 
+ + Example Usage: + ```python + from haystack_integrations.document_stores.astra import AstraDocumentStore + + document_store = AstraDocumentStore( + api_endpoint=api_endpoint, + token=token, + collection_name=collection_name, + duplicates_policy=DuplicatePolicy.SKIP, + embedding_dimension=384, + ) + ``` """ def __init__( @@ -45,22 +58,24 @@ def __init__( ): """ The connection to Astra DB is established and managed through the JSON API. - The required credentials (api endpoint andapplication token) can be generated + The required credentials (api endpoint and application token) can be generated through the UI by clicking and the connect tab, and then selecting JSON API and Generate Configuration. - :param api_endpoint: The Astra DB API endpoint. - :param token: The Astra DB application token. - :param collection_name: The current collection in the keyspace in the current Astra DB. - :param embedding_dimension: Dimension of embedding vector. - :param duplicates_policy: Handle duplicate documents based on DuplicatePolicy parameter options. - Parameter options : (SKIP, OVERWRITE, FAIL, NONE) - - `DuplicatePolicy.NONE`: Default policy, If a Document with the same id already exists, it is skipped and not written. - - `DuplicatePolicy.SKIP`: If a Document with the same id already exists, it is skipped and not written. - - `DuplicatePolicy.OVERWRITE`: If a Document with the same id already exists, it is overwritten. 
- - `DuplicatePolicy.FAIL`: If a Document with the same id already exists, an error is raised. - :param similarity: The similarity function used to compare document vectors. + - `DuplicatePolicy.SKIP`: if a Document with the same ID already exists, it is skipped and not written. + - `DuplicatePolicy.OVERWRITE`: if a Document with the same ID already exists, it is overwritten. + - `DuplicatePolicy.FAIL`: if a Document with the same ID already exists, an error is raised. + :param similarity: the similarity function used to compare document vectors. + + :raises ValueError: if the API endpoint or token is not set. """ resolved_api_endpoint = api_endpoint.resolve_value() if resolved_api_endpoint is None: @@ -95,10 +110,24 @@ def __init__( @classmethod def from_dict(cls, data: Dict[str, Any]) -> "AstraDocumentStore": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ deserialize_secrets_inplace(data["init_parameters"], keys=["api_endpoint", "token"]) return default_from_dict(cls, data) def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ return default_to_dict( self, api_endpoint=self.api_endpoint.to_dict(), @@ -118,15 +147,18 @@ def write_documents( Indexes documents for later queries. :param documents: a list of Haystack Document objects. - :param policy: Handle duplicate documents based on DuplicatePolicy parameter options. - Parameter options : (SKIP, OVERWRITE, FAIL, NONE) - - `DuplicatePolicy.NONE`: Default policy, If a Document with the same id already exists, - it is skipped and not written. - - `DuplicatePolicy.SKIP`: If a Document with the same id already exists, - it is skipped and not written. - - `DuplicatePolicy.OVERWRITE`: If a Document with the same id already exists, it is overwritten. 
- - `DuplicatePolicy.FAIL`: If a Document with the same id already exists, an error is raised. - :return: int + :param policy: handle duplicate documents based on DuplicatePolicy parameter options. + Parameter options : (`SKIP`, `OVERWRITE`, `FAIL`, `NONE`) + - `DuplicatePolicy.NONE`: Default policy, If a Document with the same ID already exists, + it is skipped and not written. + - `DuplicatePolicy.SKIP`: If a Document with the same ID already exists, + it is skipped and not written. + - `DuplicatePolicy.OVERWRITE`: If a Document with the same ID already exists, it is overwritten. + - `DuplicatePolicy.FAIL`: If a Document with the same ID already exists, an error is raised. + :returns: number of documents written. + :raises ValueError: if the documents are not of type Document or dict. + :raises DuplicateDocumentError: if a document with the same ID already exists and policy is set to FAIL. + :raises Exception: if the document ID is not a string or if `id` and `_id` are both present in the document. """ if policy is None or policy == DuplicatePolicy.NONE: if self.duplicates_policy is not None and self.duplicates_policy != DuplicatePolicy.NONE: @@ -226,21 +258,19 @@ def _convert_input_document(document: Union[dict, Document]): def count_documents(self) -> int: """ - Returns how many documents are present in the document store. + Counts the number of documents in the document store. + + :returns: the number of documents in the document store. """ return self.index.count_documents() def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: - """Returns at most 1000 documents that match the filter - - Args: - filters (Optional[Dict[str, Any]], optional): Filters to apply. Defaults to None. - - Raises: - AstraDocumentStoreFilterError: If the filter is invalid or not supported by this class. + """ + Returns at most 1000 documents that match the filter. - Returns: - List[Document]: A list of matching documents. 
+ :param filters: filters to apply. + :returns: matching documents. + :raises AstraDocumentStoreFilterError: if the filter is invalid or not supported by this class. """ if not isinstance(filters, dict) and filters is not None: msg = "Filters must be a dictionary or None" @@ -299,7 +329,10 @@ def _get_result_to_documents(results) -> List[Document]: def get_documents_by_id(self, ids: List[str]) -> List[Document]: """ - Returns documents with given ids. + Gets documents by their IDs. + + :param ids: the IDs of the documents to retrieve. + :returns: the matching documents. """ results = self.index.get_documents(ids=ids) ret = self._get_result_to_documents(results) @@ -307,9 +340,11 @@ def get_documents_by_id(self, ids: List[str]) -> List[Document]: def get_document_by_id(self, document_id: str) -> Document: """ - :param document_id: id of the document to retrieve - Returns documents with given ids. - Raises MissingDocumentError when document_id does not exist in document store + Gets a document by its ID. + + :param document_id: the ID to filter by + :returns: the found document + :raises MissingDocumentError: if the document is not found """ document = self.index.get_documents(ids=[document_id]) ret = self._get_result_to_documents(document) @@ -321,15 +356,13 @@ def get_document_by_id(self, document_id: str) -> Document: def search( self, query_embedding: List[float], top_k: int, filters: Optional[Dict[str, Any]] = None ) -> List[Document]: - """Perform a search for a list of queries. - - Args: - query_embedding (List[float]): A list of query embeddings. - top_k (int): The number of results to return. - filters (Optional[Dict[str, Any]], optional): Filters to apply during search. Defaults to None. + """ + Perform a search using a query embedding. - Returns: - List[Document]: A list of matching documents. + :param query_embedding: the query embedding (a list of floats). + :param top_k: the number of results to return. + :param filters: filters to apply during search. 
+ :returns: matching documents. """ converted_filters = _convert_filters(filters) @@ -348,13 +381,12 @@ def search( def delete_documents(self, document_ids: Optional[List[str]] = None, delete_all: Optional[bool] = None) -> None: """ - Deletes all documents with a matching document_ids from the document store. - Fails with `MissingDocumentError` if no document with this id is present in the store. + Deletes documents from the document store. - :param document_ids: the document_ids to delete. - :param delete_all: delete all documents. + :param document_ids: IDs of the documents to delete. + :param delete_all: if `True`, delete all documents. + :raises MissingDocumentError: if no document was deleted but document IDs were provided. """ - deletion_counter = 0 if self.index.count_documents() > 0: if document_ids is not None: diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/errors.py b/integrations/astra/src/haystack_integrations/document_stores/astra/errors.py index 186a8fef2..493f62917 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/errors.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/errors.py @@ -6,12 +6,18 @@ class AstraDocumentStoreError(DocumentStoreError): + """Parent class for all AstraDocumentStore errors.""" + pass class AstraDocumentStoreFilterError(FilterError): + """Raised when an invalid filter is passed to AstraDocumentStore.""" + pass class AstraDocumentStoreConfigError(AstraDocumentStoreError): + """Raised when an invalid configuration is passed to AstraDocumentStore.""" + pass diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/filters.py b/integrations/astra/src/haystack_integrations/document_stores/astra/filters.py index 6b628486b..44cac25e6 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/filters.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/filters.py @@ -19,7 +19,7 @@ 
def _normalize_filters(filters: Dict[str, Any]) -> Dict[str, Any]: def _convert_filters(filters: Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]: """ - Convert haystack filters to astra filterstring capturing all boolean operators + Convert haystack filters to astra filter string capturing all boolean operators """ if not filters: return None