From 58bb67aa0f5b1084a1acd65fbf5d4af013883af6 Mon Sep 17 00:00:00 2001 From: Eric Pinzur Date: Tue, 8 Oct 2024 18:18:02 +0200 Subject: [PATCH 01/11] updated graph to match cassandraGraphVectorStore --- .../langchain_astradb/graph_vectorstores.py | 1000 ++++++++++++----- libs/astradb/langchain_astradb/utils/mmr.py | 110 -- .../utils/{mmr_traversal.py => mmr_helper.py} | 40 +- .../astradb/langchain_astradb/vectorstores.py | 125 ++- .../test_graphvectorstore.py | 9 +- .../tests/unit_tests/test_mmr_helper.py | 120 +- 6 files changed, 971 insertions(+), 433 deletions(-) delete mode 100644 libs/astradb/langchain_astradb/utils/mmr.py rename libs/astradb/langchain_astradb/utils/{mmr_traversal.py => mmr_helper.py} (88%) diff --git a/libs/astradb/langchain_astradb/graph_vectorstores.py b/libs/astradb/langchain_astradb/graph_vectorstores.py index 4c13fe8..1e481b5 100644 --- a/libs/astradb/langchain_astradb/graph_vectorstores.py +++ b/libs/astradb/langchain_astradb/graph_vectorstores.py @@ -2,24 +2,28 @@ from __future__ import annotations +import asyncio +import json +import logging import secrets -from dataclasses import dataclass +from dataclasses import asdict, is_dataclass from typing import ( TYPE_CHECKING, Any, + AsyncIterable, Iterable, Sequence, + cast, ) -from langchain_community.graph_vectorstores.base import ( - GraphVectorStore, - Node, -) +from langchain_community.graph_vectorstores.base import GraphVectorStore, Node +from langchain_community.graph_vectorstores.links import METADATA_LINKS_KEY, Link +from langchain_core._api import beta from langchain_core.documents import Document from typing_extensions import override from langchain_astradb.utils.astradb import COMPONENT_NAME_GRAPHVECTORSTORE -from langchain_astradb.utils.mmr_traversal import MmrHelper +from langchain_astradb.utils.mmr_helper import MmrHelper from langchain_astradb.vectorstores import AstraDBVectorStore if TYPE_CHECKING: @@ -33,28 +37,76 @@ DEFAULT_INDEXING_OPTIONS = {"allow": ["metadata"]} -@dataclass -class _Edge: - target_content_id: str - target_text_embedding: list[float] - target_link_to_tags: set[str] - target_doc: Document +logger = logging.getLogger(__name__) + + +class AdjacentNode: + id: str + links: list[Link] + embedding: list[float] + + def __init__(self, node: Node, embedding: list[float]) -> None: + """Create an Adjacent Node.""" + self.id = node.id or "" + self.links = node.links + self.embedding = embedding + + +def _serialize_links(links: list[Link]) -> str: + class SetAndLinkEncoder(json.JSONEncoder): + def default(self, obj: Any) -> Any: # noqa: ANN401 + if not isinstance(obj, type) and is_dataclass(obj): + return asdict(obj) + + if isinstance(obj, Iterable): + return list(obj) + + # Let the base class default method raise the TypeError + return super().default(obj) + + return json.dumps(links, cls=SetAndLinkEncoder) + + +def _deserialize_links(json_blob: str | None) -> set[Link]: + return { + Link(kind=link["kind"], direction=link["direction"], tag=link["tag"]) + for link in cast(list[dict[str, Any]], json.loads(json_blob or "")) + } + + +def _metadata_link_key(link: Link) -> str: + return f"link:{link.kind}:{link.tag}" -# NOTE: Conversion to string is necessary -# because AstraDB doesn't support matching on arrays of tuples -def _tag_to_str(kind: str, tag: str) -> str: - return f"{kind}:{tag}" +def _doc_to_node(doc: Document) -> Node: + metadata = doc.metadata.copy() + links = _deserialize_links(metadata.get(METADATA_LINKS_KEY)) + metadata[METADATA_LINKS_KEY] = links + return Node( + id=doc.id, + 
text=doc.page_content, + metadata=metadata, + links=list(links), + ) + +def _incoming_links(node: Node | AdjacentNode) -> set[Link]: + return {link for link in node.links if link.direction in ["in", "bidir"]} + + +def _outgoing_links(node: Node | AdjacentNode) -> set[Link]: + return {link for link in node.links if link.direction in ["out", "bidir"]} + + +@beta() class AstraDBGraphVectorStore(GraphVectorStore): def __init__( self, *, embedding: Embeddings, collection_name: str, - link_to_metadata_key: str = "links_to", - link_from_metadata_key: str = "links_from", + metadata_incoming_links_key: str = "incoming_links", token: str | TokenProvider | None = None, api_endpoint: str | None = None, namespace: str | None = None, @@ -68,7 +120,7 @@ def __init__( pre_delete_collection: bool = False, metadata_indexing_include: Iterable[str] | None = None, metadata_indexing_exclude: Iterable[str] | None = None, - collection_indexing_policy: dict[str, Any] | None = None, + collection_indexing_policy: dict[str, list[str]] | None = None, content_field: str | None = None, ignore_invalid_documents: bool = False, autodetect_collection: bool = False, @@ -81,10 +133,8 @@ def __init__( Args: embedding: the embeddings function. collection_name: name of the Astra DB collection to create/use. - link_to_metadata_key: document metadata key where the outgoing links are - stored. - link_from_metadata_key: document metadata key where the incoming links are - stored. + metadata_incoming_links_key: document metadata key where the incoming + links are stored (and indexed). token: API token for Astra DB usage, either in the form of a string or a subclass of ``astrapy.authentication.TokenProvider``. If not provided, the environment variable @@ -123,7 +173,8 @@ def __init__( This dict must conform to to the API specifications (see https://docs.datastax.com/en/astra-db-serverless/api-reference/collections.html#the-indexing-option) content_field: name of the field containing the textual content - in the documents when saved on Astra DB. Defaults to "content". + in the documents when saved on Astra DB. Defaults to + "content". The special value "*" can be passed only if autodetect_collection=True. In this case, the actual name of the key for the textual content is guessed by inspection of a few documents from the collection, under the @@ -168,11 +219,24 @@ def __init__( you can pass an already-created 'astrapy.db.AsyncAstraDB' instance (alternatively to 'token', 'api_endpoint' and 'environment'). """ - self.link_to_metadata_key = link_to_metadata_key - self.link_from_metadata_key = link_from_metadata_key + self.metadata_incoming_links_key = metadata_incoming_links_key self.embedding = embedding - self.vectorstore = AstraDBVectorStore( + # update indexing policy to ensure incoming_links are indexed, and the + # full links blob is not. 
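# Illustrative aside, a hedged sketch rather than part of the patch: how the helper
# functions above shape a node's metadata. The full `links` list is stored as an
# opaque JSON string under METADATA_LINKS_KEY (excluded from indexing just below),
# while compact "link:<kind>:<tag>" strings under the incoming-links key stay indexed
# so traversal queries can filter on them.
from langchain_community.graph_vectorstores.links import Link

link = Link(kind="keyword", direction="bidir", tag="astra")
blob = _serialize_links([link])
assert _deserialize_links(blob) == {link}  # the JSON blob round-trips to Link objects
assert _metadata_link_key(link) == "link:keyword:astra"  # value kept under "incoming_links"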
+ if collection_indexing_policy is not None: + collection_indexing_policy["allow"].append(self.metadata_incoming_links_key) + collection_indexing_policy["deny"].append(METADATA_LINKS_KEY) + elif metadata_indexing_include is not None: + metadata_indexing_include = set(metadata_indexing_include) + metadata_indexing_include.add(self.metadata_incoming_links_key) + elif metadata_indexing_exclude is not None: + metadata_indexing_exclude = set(metadata_indexing_exclude) + metadata_indexing_exclude.add(METADATA_LINKS_KEY) + elif not autodetect_collection: + metadata_indexing_exclude = [METADATA_LINKS_KEY] + + self.vector_store = AstraDBVectorStore( collection_name=collection_name, embedding=embedding, token=token, @@ -198,48 +262,74 @@ def __init__( async_astra_db_client=async_astra_db_client, ) - self.astra_env = self.vectorstore.astra_env + self.astra_env = self.vector_store.astra_env @property @override - def embeddings(self) -> Embeddings | None: + def embeddings(self) -> Embeddings: return self.embedding + def _get_metadata_filter( + self, + metadata: dict[str, Any] | None = None, + outgoing_link: Link | None = None, + ) -> dict[str, Any]: + if outgoing_link is None: + return metadata or {} + + metadata_filter = {} if metadata is None else metadata.copy() + metadata_filter[self.metadata_incoming_links_key] = _metadata_link_key( + link=outgoing_link + ) + return metadata_filter + + def _restore_links(self, doc: Document) -> Document: + """Restores the links in the document by deserializing them from metadata. + + Args: + doc: A single Document + + Returns: + The same Document with restored links. + """ + links = _deserialize_links(doc.metadata.get(METADATA_LINKS_KEY)) + doc.metadata[METADATA_LINKS_KEY] = links + del doc.metadata[self.metadata_incoming_links_key] + return doc + + # TODO: Async (aadd_nodes) @override def add_nodes( self, nodes: Iterable[Node], **kwargs: Any, ) -> Iterable[str]: + """Add nodes to the graph store. + + Args: + nodes: the nodes to add. + **kwargs: Additional keyword arguments. + """ docs = [] ids = [] for node in nodes: node_id = secrets.token_hex(8) if not node.id else node.id - link_to_tags = set() # link to these tags - link_from_tags = set() # link from these tags - - for tag in node.links: - if tag.direction in {"in", "bidir"}: - # An incoming link should be linked *from* nodes with the given - # tag. 
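# Illustrative aside, a hedged sketch rather than part of the patch: `_get_metadata_filter`
# (defined above) is what lets a traversal follow an edge. Given an optional user filter
# and an outgoing link, it builds a filter that matches documents declaring the matching
# incoming link. `store` is a hypothetical AstraDBGraphVectorStore instance.
from langchain_community.graph_vectorstores.links import Link

outgoing = Link(kind="kw", direction="out", tag="astra")
assert store._get_metadata_filter(
    metadata={"topic": "graphs"}, outgoing_link=outgoing
) == {"topic": "graphs", "incoming_links": "link:kw:astra"}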
- link_from_tags.add(_tag_to_str(tag.kind, tag.tag)) - if tag.direction in {"out", "bidir"}: - link_to_tags.add(_tag_to_str(tag.kind, tag.tag)) - - metadata = node.metadata - metadata[self.link_to_metadata_key] = list(link_to_tags) - metadata[self.link_from_metadata_key] = list(link_from_tags) + combined_metadata = node.metadata.copy() + combined_metadata[METADATA_LINKS_KEY] = _serialize_links(node.links) + combined_metadata[self.metadata_incoming_links_key] = [ + _metadata_link_key(link=link) for link in _incoming_links(node=node) + ] doc = Document( page_content=node.text, - metadata=metadata, + metadata=combined_metadata, id=node_id, ) docs.append(doc) ids.append(node_id) - return self.vectorstore.add_documents(docs, ids=ids) + return self.vector_store.add_documents(docs, ids=ids) @classmethod @override @@ -251,6 +341,7 @@ def from_texts( ids: Iterable[str] | None = None, **kwargs: Any, ) -> AstraDBGraphVectorStore: + """Return AstraDBGraphVectorStore initialized from texts and embeddings.""" store = cls(embedding=embedding, **kwargs) store.add_texts(texts, metadatas, ids=ids) return store @@ -261,10 +352,12 @@ def from_documents( cls: type[AstraDBGraphVectorStore], documents: Iterable[Document], embedding: Embeddings, + ids: Iterable[str] | None = None, **kwargs: Any, ) -> AstraDBGraphVectorStore: + """Return AstraDBGraphVectorStore initialized from docs and embeddings.""" store = cls(embedding=embedding, **kwargs) - store.add_documents(documents) + store.add_documents(documents, ids=ids) return store @override @@ -272,152 +365,171 @@ def similarity_search( self, query: str, k: int = 4, - metadata_filter: dict[str, Any] | None = None, + filter: dict[str, Any] | None = None, **kwargs: Any, ) -> list[Document]: - return self.vectorstore.similarity_search(query, k, metadata_filter, **kwargs) + """Retrieve documents from this graph store. + + Args: + query: The query string. + k: The number of Documents to return. Defaults to 4. + filter: Optional metadata to filter the results. + **kwargs: Additional keyword arguments. + + Returns: + Collection of retrieved documents. + """ + return [ + self._restore_links(doc) + for doc in self.vector_store.similarity_search( + query=query, + k=k, + filter=filter, + **kwargs, + ) + ] @override - def similarity_search_by_vector( + async def asimilarity_search( self, - embedding: list[float], + query: str, k: int = 4, - metadata_filter: dict[str, Any] | None = None, + filter: dict[str, Any] | None = None, **kwargs: Any, ) -> list[Document]: - return self.vectorstore.similarity_search_by_vector( - embedding, k, metadata_filter, **kwargs - ) + """Retrieve documents from this graph store. + + Args: + query: The query string. + k: The number of Documents to return. Defaults to 4. + filter: Optional metadata to filter the results. + **kwargs: Additional keyword arguments. + + Returns: + Collection of retrieved documents. + """ + return [ + self._restore_links(doc) + for doc in await self.vector_store.asimilarity_search( + query=query, + k=k, + filter=filter, + **kwargs, + ) + ] @override - def traversal_search( # noqa: C901 + def similarity_search_by_vector( self, - query: str, - *, + embedding: list[float], k: int = 4, - depth: int = 1, - adjacent_k: int = 10, - metadata_filter: dict[str, Any] | None = None, + filter: dict[str, Any] | None = None, **kwargs: Any, - ) -> Iterable[Document]: - # Map from visited ID to depth - visited_ids: dict[str, int] = {} - visited_docs: list[Document] = [] - - # Map from visited tag `(kind, tag)` to depth. 
Allows skipping queries - # for tags that we've already traversed. - visited_tags: dict[str, int] = {} - - def visit_documents(d: int, docs: Iterable[Any]) -> None: - nonlocal visited_ids, visited_docs, visited_tags - - # Visit documents at the given depth. - # Each document has `id`, `link_from_tags` and `link_to_tags`. - - # Iterate over documents, tracking the *new* outgoing kind tags for this - # depth. This is tags that are either new, or newly discovered at a - # lower depth. - outgoing_tags = set() - for doc in docs: - # Add visited ID. If it is closer it is a new document at this depth: - if d <= visited_ids.get(doc.id, depth): - visited_ids[doc.id] = d - visited_docs.append(doc) - - # If we can continue traversing from this document, - if d < depth and doc.metadata[self.link_to_metadata_key]: - # Record any new (or newly discovered at a lower depth) - # tags to the set to traverse. - for tag in doc.metadata[self.link_to_metadata_key]: - if d <= visited_tags.get(tag, depth): - # Record that we'll query this tag at the - # given depth, so we don't fetch it again - # (unless we find it at an earlier depth) - visited_tags[tag] = d - outgoing_tags.add(tag) - - if outgoing_tags: - # If there are new tags to visit at the next depth, query for the - # doc IDs. - for tag in outgoing_tags: - m_filter = (metadata_filter or {}).copy() - m_filter[self.link_from_metadata_key] = tag - - rows = self.vectorstore.similarity_search( - query=query, k=adjacent_k, filter=m_filter, **kwargs - ) - visit_targets(d, rows) - - def visit_targets(d: int, targets: Sequence[Document]) -> None: - nonlocal visited_ids + ) -> list[Document]: + """Return docs most similar to embedding vector. - new_docs_at_next_depth = {} - for target in targets: - if target.id is None: - continue - if d < visited_ids.get(target.id, depth): - new_docs_at_next_depth[target.id] = target + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter: Filter on the metadata to apply. + **kwargs: Additional arguments are ignored. - if new_docs_at_next_depth: - visit_documents(d + 1, new_docs_at_next_depth.values()) + Returns: + The list of Documents most similar to the query vector. + """ + return [ + self._restore_links(doc) + for doc in self.vector_store.similarity_search_by_vector( + embedding, + k=k, + filter=filter, + **kwargs, + ) + ] - docs = self.vectorstore.similarity_search( - query=query, - k=k, - filter=metadata_filter, - **kwargs, - ) - visit_documents(0, docs) + @override + async def asimilarity_search_by_vector( + self, + embedding: list[float], + k: int = 4, + filter: dict[str, Any] | None = None, + **kwargs: Any, + ) -> list[Document]: + """Return docs most similar to embedding vector. - return visited_docs + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter: Filter on the metadata to apply. + **kwargs: Additional arguments are ignored. - def filter_to_query(self, filter_dict: dict[str, Any] | None) -> dict[str, Any]: - """Prepare a query for use on DB based on metadata filter. + Returns: + The list of Documents most similar to the query vector. + """ + return [ + self._restore_links(doc) + for doc in await self.vector_store.asimilarity_search_by_vector( + embedding, + k=k, + filter=filter, + **kwargs, + ) + ] - Encode an "abstract" filter clause on metadata into a query filter - condition aware of the collection schema choice. 
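# Illustrative aside, a hedged usage sketch rather than part of the patch (`store` is a
# hypothetical AstraDBGraphVectorStore): the overridden searches above behave like the
# plain vector-store searches, except that each returned Document has its links
# rehydrated by `_restore_links`, so callers see Link objects instead of a JSON string.
hits = store.similarity_search("graph retrieval", k=4, filter={"topic": "graphs"})
for hit in hits:
    print(hit.id, hit.metadata[METADATA_LINKS_KEY])  # a set of Link objects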
+ def metadata_search( + self, + filter: dict[str, Any] | None = None, # noqa: A002 + n: int = 5, + ) -> Iterable[Document]: + """Get documents via a metadata search. Args: - filter_dict: a metadata condition in the form {"field": "value"} - or related. - - Returns: - the corresponding mapping ready for use in queries, - aware of the details of the schema used to encode the document on DB. + filter: the metadata to query for. + n: the maximum number of documents to return. """ - return self.vectorstore.filter_to_query(filter_dict) + return [ + self._restore_links(doc) + for doc in self.vector_store.metadata_search( + filter=filter, + n=n, + ) + ] - def _get_outgoing_tags( + async def ametadata_search( self, - source_ids: Iterable[str], - ) -> set[str]: - """Return the set of outgoing tags for the given source ID(s). + filter: dict[str, Any] | None = None, # noqa: A002 + n: int = 5, + ) -> Iterable[Document]: + """Get documents via a metadata search. Args: - source_ids: The IDs of the source nodes to retrieve outgoing tags for. + filter: the metadata to query for. + n: the maximum number of documents to return. """ - tags = set() - - for source_id in source_ids: - hits = list( - self.astra_env.collection.find( - filter=self.vectorstore.document_codec.encode_id(source_id), - # NOTE: Really, only the link-to metadata value is needed here - projection=self.vectorstore.document_codec.base_projection, - ) + return [ + self._restore_links(doc) + for doc in await self.vector_store.ametadata_search( + filter=filter, + n=n, ) + ] - for hit in hits: - doc = self.vectorstore.document_codec.decode(hit) - if doc is None: - continue - metadata = doc.metadata or {} - tags.update(metadata.get(self.link_to_metadata_key, [])) + def get_node(self, node_id: str) -> Node | None: + """Retrieve a single node from the store, given its ID. - return tags + Args: + node_id: The node ID + + Returns: + The the node if it exists. Otherwise None. + """ + doc = self.vector_store.get_by_document_id(document_id=node_id) + if doc is None: + return None + return _doc_to_node(doc=doc) @override - def mmr_traversal_search( # noqa: C901 + async def ammr_traversal_search( # noqa: C901 self, query: str, *, @@ -428,9 +540,41 @@ def mmr_traversal_search( # noqa: C901 adjacent_k: int = 10, lambda_mult: float = 0.5, score_threshold: float = float("-inf"), - metadata_filter: dict[str, Any] | None = None, + filter: dict[str, Any] | None = None, **kwargs: Any, - ) -> Iterable[Document]: + ) -> AsyncIterable[Document]: + """Retrieve documents from this graph store using MMR-traversal. + + This strategy first retrieves the top `fetch_k` results by similarity to + the question. It then selects the top `k` results based on + maximum-marginal relevance using the given `lambda_mult`. + + At each step, it considers the (remaining) documents from `fetch_k` as + well as any documents connected by edges to a selected document + retrieved based on similarity (a "root"). + + Args: + query: The query string to search for. + initial_roots: Optional list of document IDs to use for initializing search. + The top `adjacent_k` nodes adjacent to each initial root will be + included in the set of initial candidates. To fetch only in the + neighborhood of these nodes, set `fetch_k = 0`. + k: Number of Documents to return. Defaults to 4. + fetch_k: Number of initial Documents to fetch via similarity. + Will be added to the nodes adjacent to `initial_roots`. + Defaults to 100. + adjacent_k: Number of adjacent Documents to fetch. + Defaults to 10. 
+ depth: Maximum depth of a node (number of edges) from a node + retrieved via similarity. Defaults to 2. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding to maximum + diversity and 1 to minimum diversity. Defaults to 0.5. + score_threshold: Only documents with a score greater than or equal + this threshold will be chosen. Defaults to -infinity. + filter: Optional metadata to filter the results. + **kwargs: Additional keyword arguments. + """ query_embedding = self.embedding.embed_query(query) helper = MmrHelper( k=k, @@ -439,119 +583,72 @@ def mmr_traversal_search( # noqa: C901 score_threshold=score_threshold, ) - # For each unselected node, stores the outgoing tags. - outgoing_tags: dict[str, set[str]] = {} - - visited_tags: set[str] = set() - - def get_adjacent(tags: set[str]) -> Iterable[_Edge]: - targets: dict[str, _Edge] = {} - - # TODO: Would be better parallelized - for tag in tags: - m_filter = (metadata_filter or {}).copy() - m_filter[self.link_from_metadata_key] = tag - metadata_parameter = self.filter_to_query(m_filter) - - hits = list( - self.astra_env.collection.find( - filter=metadata_parameter, - projection=self.vectorstore.document_codec.full_projection, - limit=adjacent_k, - include_similarity=True, - include_sort_vector=True, - sort=self.vectorstore.document_codec.encode_vector_sort( - query_embedding - ), - ) - ) + # For each unselected node, stores the outgoing links. + outgoing_links_map: dict[str, set[Link]] = {} + visited_links: set[Link] = set() + # Map from id to Document + retrieved_docs: dict[str, Document] = {} - for hit in hits: - doc = self.vectorstore.document_codec.decode(hit) - if doc is None or doc.id is None: - continue - - vector = self.vectorstore.document_codec.decode_vector(hit) - if vector is None: - continue - - if doc.id not in targets: - targets[doc.id] = _Edge( - target_content_id=doc.id, - target_text_embedding=vector, - target_link_to_tags=set( - hit.get(self.link_to_metadata_key, []) - ), - target_doc=doc, - ) - - # TODO: Consider a combined limit based on the similarity and/or - # predicated MMR score? - return targets.values() + async def fetch_neighborhood(neighborhood: Sequence[str]) -> None: + nonlocal outgoing_links_map, visited_links, retrieved_docs - def fetch_neighborhood(neighborhood: Sequence[str]) -> None: - # Put the neighborhood into the outgoing tags, to avoid adding it + # Put the neighborhood into the outgoing links, to avoid adding it # to the candidate set in the future. - outgoing_tags.update({content_id: set() for content_id in neighborhood}) + outgoing_links_map.update( + {content_id: set() for content_id in neighborhood} + ) - # Initialize the visited_tags with the set of outgoing from the + # Initialize the visited_links with the set of outgoing links from the # neighborhood. This prevents re-visiting them. - visited_tags = self._get_outgoing_tags(neighborhood) + visited_links = await self._get_outgoing_links(neighborhood) # Call `self._get_adjacent` to fetch the candidates. 
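# Illustrative aside, a hedged sketch rather than part of the patch: the MmrHelper used
# above scores every candidate with the usual maximal-marginal-relevance trade-off,
#     mmr = lambda_mult * sim(query, doc) - (1 - lambda_mult) * max(sim(doc, s) for s in selected)
# with redundancy taken as 0 while nothing has been selected yet. A tiny self-contained
# version of that scoring (names here are made up for the example):
import numpy as np

def _cos(a: list[float], b: list[float]) -> float:
    a_v, b_v = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    return float(a_v @ b_v / (np.linalg.norm(a_v) * np.linalg.norm(b_v)))

def mmr_score(
    query: list[float],
    doc: list[float],
    selected: list[list[float]],
    lambda_mult: float = 0.5,
) -> float:
    redundancy = max((_cos(doc, s) for s in selected), default=0.0)
    return lambda_mult * _cos(query, doc) - (1 - lambda_mult) * redundancy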
- adjacents = get_adjacent(visited_tags) - - new_candidates = {} - for adjacent in adjacents: - if adjacent.target_content_id not in outgoing_tags: - outgoing_tags[adjacent.target_content_id] = ( - adjacent.target_link_to_tags - ) + adjacent_nodes = await self._get_adjacent( + links=visited_links, + query_embedding=query_embedding, + k_per_link=adjacent_k, + filter=filter, + retrieved_docs=retrieved_docs, + ) - new_candidates[adjacent.target_content_id] = ( - adjacent.target_doc, - adjacent.target_text_embedding, + new_candidates: dict[str, list[float]] = {} + for adjacent_node in adjacent_nodes: + if adjacent_node.id not in outgoing_links_map: + outgoing_links_map[adjacent_node.id] = _outgoing_links( + node=adjacent_node ) + new_candidates[adjacent_node.id] = adjacent_node.embedding helper.add_candidates(new_candidates) - def fetch_initial_candidates() -> None: - metadata_parameter = self.filter_to_query(metadata_filter).copy() - hits = list( - self.astra_env.collection.find( - filter=metadata_parameter, - projection=self.vectorstore.document_codec.full_projection, - limit=fetch_k, - include_similarity=True, - include_sort_vector=True, - sort=self.vectorstore.document_codec.encode_vector_sort( - query_embedding - ), + async def fetch_initial_candidates() -> None: + nonlocal outgoing_links_map, visited_links, retrieved_docs + + results = ( + await self.vector_store.asimilarity_search_with_embedding_id_by_vector( + embedding=query_embedding, + k=fetch_k, + filter=filter, ) ) - candidates = {} - for hit in hits: - doc = self.vectorstore.document_codec.decode(hit) - if doc is None or doc.id is None: - continue - - vector = self.vectorstore.document_codec.decode_vector(hit) - if vector is None: - continue - - candidates[doc.id] = (doc, vector) - tags = set(doc.metadata.get(self.link_to_metadata_key, [])) - outgoing_tags[doc.id] = tags + candidates: dict[str, list[float]] = {} + for doc, embedding, doc_id in results: + if doc_id not in retrieved_docs: + retrieved_docs[doc_id] = doc + if doc_id not in outgoing_links_map: + node = _doc_to_node(doc) + outgoing_links_map[doc_id] = _outgoing_links(node=node) + candidates[doc_id] = embedding helper.add_candidates(candidates) if initial_roots: - fetch_neighborhood(initial_roots) + await fetch_neighborhood(initial_roots) if fetch_k > 0: - fetch_initial_candidates() + await fetch_initial_candidates() # Tracks the depth of each candidate. - depths = dict.fromkeys(helper.candidate_ids(), 0) + depths = {candidate_id: 0 for candidate_id in helper.candidate_ids()} # Select the best item, K times. for _ in range(k): @@ -564,36 +661,33 @@ def fetch_initial_candidates() -> None: if next_depth < depth: # If the next nodes would not exceed the depth limit, find the # adjacent nodes. - # - # TODO: For a big performance win, we should track which tags we've - # already incorporated. We don't need to issue adjacent queries for - # those. - # Find the tags linked to from the selected ID. - link_to_tags = outgoing_tags.pop(selected_id) + # Find the links linked to from the selected ID. + selected_outgoing_links = outgoing_links_map.pop(selected_id) - # Don't re-visit already visited tags. - link_to_tags.difference_update(visited_tags) + # Don't re-visit already visited links. + selected_outgoing_links.difference_update(visited_links) - # Find the nodes with incoming links from those tags. - adjacents = get_adjacent(link_to_tags) + # Find the nodes with incoming links from those links. 
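# Illustrative aside, a hedged usage sketch rather than part of the patch (`store` is a
# hypothetical AstraDBGraphVectorStore): `ammr_traversal_search` is an async generator,
# and each yielded Document carries the scores recorded at the moment it was selected.
async def demo_mmr(store: AstraDBGraphVectorStore) -> None:
    async for doc in store.ammr_traversal_search("graph retrieval", k=4, depth=2):
        print(doc.id, doc.metadata["similarity_score"], doc.metadata["mmr_score"])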
+ adjacent_nodes = await self._get_adjacent( + links=selected_outgoing_links, + query_embedding=query_embedding, + k_per_link=adjacent_k, + filter=filter, + retrieved_docs=retrieved_docs, + ) - # Record the link_to_tags as visited. - visited_tags.update(link_to_tags) + # Record the selected_outgoing_links as visited. + visited_links.update(selected_outgoing_links) new_candidates = {} - for adjacent in adjacents: - if adjacent.target_content_id not in outgoing_tags: - outgoing_tags[adjacent.target_content_id] = ( - adjacent.target_link_to_tags + for adjacent_node in adjacent_nodes: + if adjacent_node.id not in outgoing_links_map: + outgoing_links_map[adjacent_node.id] = _outgoing_links( + node=adjacent_node ) - new_candidates[adjacent.target_content_id] = ( - adjacent.target_doc, - adjacent.target_text_embedding, - ) - if next_depth < depths.get( - adjacent.target_content_id, depth + 1 - ): + new_candidates[adjacent_node.id] = adjacent_node.embedding + if next_depth < depths.get(adjacent_node.id, depth + 1): # If this is a new shortest depth, or there was no # previous depth, update the depths. This ensures that # when we discover a node we will have the shortest @@ -604,7 +698,355 @@ def fetch_initial_candidates() -> None: # a shorter path via nodes selected later. This is # currently "intended", but may be worth experimenting # with. - depths[adjacent.target_content_id] = next_depth + depths[adjacent_node.id] = next_depth helper.add_candidates(new_candidates) - return [helper.candidate_docs[sid] for sid in helper.selected_ids] + for doc_id, similarity_score, mmr_score in zip( + helper.selected_ids, + helper.selected_similarity_scores, + helper.selected_mmr_scores, + ): + if doc_id in retrieved_docs: + doc = self._restore_links(retrieved_docs[doc_id]) + doc.metadata["similarity_score"] = similarity_score + doc.metadata["mmr_score"] = mmr_score + yield doc + else: + msg = f"retrieved_docs should contain id: {doc_id}" + raise RuntimeError(msg) + + @override + def mmr_traversal_search( + self, + query: str, + *, + initial_roots: Sequence[str] = (), + k: int = 4, + depth: int = 2, + fetch_k: int = 100, + adjacent_k: int = 10, + lambda_mult: float = 0.5, + score_threshold: float = float("-inf"), + filter: dict[str, Any] | None = None, + **kwargs: Any, + ) -> Iterable[Document]: + """Retrieve documents from this graph store using MMR-traversal. + + This strategy first retrieves the top `fetch_k` results by similarity to + the question. It then selects the top `k` results based on + maximum-marginal relevance using the given `lambda_mult`. + + At each step, it considers the (remaining) documents from `fetch_k` as + well as any documents connected by edges to a selected document + retrieved based on similarity (a "root"). + + Args: + query: The query string to search for. + initial_roots: Optional list of document IDs to use for initializing search. + The top `adjacent_k` nodes adjacent to each initial root will be + included in the set of initial candidates. To fetch only in the + neighborhood of these nodes, set `fetch_k = 0`. + k: Number of Documents to return. Defaults to 4. + fetch_k: Number of initial Documents to fetch via similarity. + Will be added to the nodes adjacent to `initial_roots`. + Defaults to 100. + adjacent_k: Number of adjacent Documents to fetch. + Defaults to 10. + depth: Maximum depth of a node (number of edges) from a node + retrieved via similarity. Defaults to 2. 
+ lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding to maximum + diversity and 1 to minimum diversity. Defaults to 0.5. + score_threshold: Only documents with a score greater than or equal + this threshold will be chosen. Defaults to -infinity. + filter: Optional metadata to filter the results. + **kwargs: Additional keyword arguments. + """ + + async def collect_docs() -> Iterable[Document]: + async_iter = self.ammr_traversal_search( + query=query, + initial_roots=initial_roots, + k=k, + depth=depth, + fetch_k=fetch_k, + adjacent_k=adjacent_k, + lambda_mult=lambda_mult, + score_threshold=score_threshold, + filter=filter, + **kwargs, + ) + return [doc async for doc in async_iter] + + return asyncio.run(collect_docs()) + + @override + async def atraversal_search( # noqa: C901 + self, + query: str, + *, + k: int = 4, + depth: int = 1, + filter: dict[str, Any] | None = None, + **kwargs: Any, + ) -> AsyncIterable[Document]: + """Retrieve documents from this knowledge store. + + First, `k` nodes are retrieved using a vector search for the `query` string. + Then, additional nodes are discovered up to the given `depth` from those + starting nodes. + + Args: + query: The query string. + k: The number of Documents to return from the initial vector search. + Defaults to 4. + depth: The maximum depth of edges to traverse. Defaults to 1. + filter: Optional metadata to filter the results. + **kwargs: Additional keyword arguments. + + Returns: + Collection of retrieved documents. + """ + # Depth 0: + # Query for `k` nodes similar to the question. + # Retrieve `content_id` and `outgoing_links()`. + # + # Depth 1: + # Query for nodes that have an incoming link in the `outgoing_links()` set. + # Combine node IDs. + # Query for `outgoing_links()` of those "new" node IDs. + # + # ... + + # Map from visited ID to depth + visited_ids: dict[str, int] = {} + + # Map from visited link to depth + visited_links: dict[Link, int] = {} + + # Map from id to Document + retrieved_docs: dict[str, Document] = {} + + async def visit_nodes(d: int, docs: Iterable[Document]) -> None: + """Recursively visit nodes and their outgoing links.""" + nonlocal visited_ids, visited_links, retrieved_docs + + # Iterate over nodes, tracking the *new* outgoing links for this + # depth. These are links that are either new, or newly discovered at a + # lower depth. + outgoing_links: set[Link] = set() + for doc in docs: + if doc.id is not None: + if doc.id not in retrieved_docs: + retrieved_docs[doc.id] = doc + + # If this node is at a closer depth, update visited_ids + if d <= visited_ids.get(doc.id, depth): + visited_ids[doc.id] = d + + # If we can continue traversing from this node, + if d < depth: + node = _doc_to_node(doc=doc) + # Record any new (or newly discovered at a lower depth) + # links to the set to traverse. 
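# Illustrative aside, a hedged usage sketch rather than part of the patch (`store` is a
# hypothetical AstraDBGraphVectorStore): traversal starts from a plain vector search and
# then follows outgoing links breadth-first, up to `depth` edges away.
docs = list(store.traversal_search("graph retrieval", k=4, depth=1))
reached = {doc.id for doc in docs}  # the initial hits plus their one-hop link targets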
+ for link in _outgoing_links(node=node): + if d <= visited_links.get(link, depth): + # Record that we'll query this link at the + # given depth, so we don't fetch it again + # (unless we find it an earlier depth) + visited_links[link] = d + outgoing_links.add(link) + + if outgoing_links: + metadata_search_tasks = [] + for outgoing_link in outgoing_links: + metadata_filter = self._get_metadata_filter( + metadata=filter, + outgoing_link=outgoing_link, + ) + metadata_search_tasks.append( + asyncio.create_task( + self.vector_store.ametadata_search( + filter=metadata_filter, n=1000 + ) + ) + ) + results = await asyncio.gather(*metadata_search_tasks) + + # Visit targets concurrently + visit_target_tasks = [ + visit_targets(d=d + 1, docs=docs) for docs in results + ] + await asyncio.gather(*visit_target_tasks) + + async def visit_targets(d: int, docs: Iterable[Document]) -> None: + """Visit target nodes retrieved from outgoing links.""" + nonlocal visited_ids, retrieved_docs + + new_ids_at_next_depth = set() + for doc in docs: + if doc.id is not None: + if doc.id not in retrieved_docs: + retrieved_docs[doc.id] = doc + + if d <= visited_ids.get(doc.id, depth): + new_ids_at_next_depth.add(doc.id) + + if new_ids_at_next_depth: + visit_node_tasks = [ + visit_nodes(d=d, docs=[retrieved_docs[doc_id]]) + for doc_id in new_ids_at_next_depth + if doc_id in retrieved_docs + ] + + fetch_tasks = [ + asyncio.create_task( + self.vector_store.aget_by_document_id(document_id=doc_id) + ) + for doc_id in new_ids_at_next_depth + if doc_id not in retrieved_docs + ] + + new_docs: list[Document | None] = await asyncio.gather(*fetch_tasks) + + visit_node_tasks.extend( + visit_nodes(d=d, docs=[new_doc]) + for new_doc in new_docs + if new_doc is not None + ) + + await asyncio.gather(*visit_node_tasks) + + # Start the traversal + initial_docs = self.vector_store.similarity_search( + query=query, + k=k, + filter=filter, + ) + await visit_nodes(d=0, docs=initial_docs) + + for doc_id in visited_ids: + if doc_id in retrieved_docs: + yield self._restore_links(retrieved_docs[doc_id]) + else: + msg = f"retrieved_docs should contain id: {doc_id}" + raise RuntimeError(msg) + + @override + def traversal_search( + self, + query: str, + *, + k: int = 4, + depth: int = 1, + filter: dict[str, Any] | None = None, + **kwargs: Any, + ) -> Iterable[Document]: + """Retrieve documents from this knowledge store. + + First, `k` nodes are retrieved using a vector search for the `query` string. + Then, additional nodes are discovered up to the given `depth` from those + starting nodes. + + Args: + query: The query string. + k: The number of Documents to return from the initial vector search. + Defaults to 4. + depth: The maximum depth of edges to traverse. Defaults to 1. + filter: Optional metadata to filter the results. + **kwargs: Additional keyword arguments. + + Returns: + Collection of retrieved documents. + """ + + async def collect_docs() -> Iterable[Document]: + async_iter = self.atraversal_search( + query=query, + k=k, + depth=depth, + filter=filter, + **kwargs, + ) + return [doc async for doc in async_iter] + + return asyncio.run(collect_docs()) + + async def _get_outgoing_links(self, source_ids: Iterable[str]) -> set[Link]: + """Return the set of outgoing links for the given source IDs asynchronously. + + Args: + source_ids: The IDs of the source nodes to retrieve outgoing links for. + + Returns: + A set of `Link` objects representing the outgoing links from the source + nodes. 
+ """ + links = set() + + # Create coroutine objects without scheduling them yet + coroutines = [ + self.vector_store.aget_by_document_id(document_id=source_id) + for source_id in source_ids + ] + + # Schedule and await all coroutines + docs = await asyncio.gather(*coroutines) + + for doc in docs: + if doc is not None: + node = _doc_to_node(doc=doc) + links.update(_outgoing_links(node=node)) + + return links + + async def _get_adjacent( + self, + links: set[Link], + query_embedding: list[float], + retrieved_docs: dict[str, Document], + k_per_link: int | None = None, + filter: dict[str, Any] | None = None, # noqa: A002 + ) -> Iterable[AdjacentNode]: + """Return the target nodes with incoming links from any of the given links. + + Args: + links: The links to look for. + query_embedding: The query embedding. Used to rank target nodes. + retrieved_docs: A cache of retrieved docs. This will be added to. + k_per_link: The number of target nodes to fetch for each link. + filter: Optional metadata to filter the results. + + Returns: + Iterable of adjacent edges. + """ + targets: dict[str, AdjacentNode] = {} + + tasks = [] + for link in links: + metadata_filter = self._get_metadata_filter( + metadata=filter, + outgoing_link=link, + ) + + tasks.append( + self.vector_store.asimilarity_search_with_embedding_id_by_vector( + embedding=query_embedding, + k=k_per_link or 10, + filter=metadata_filter, + ) + ) + + results = await asyncio.gather(*tasks) + + for result in results: + for doc, embedding, doc_id in result: + if doc_id not in retrieved_docs: + retrieved_docs[doc_id] = doc + if doc_id not in targets: + node = _doc_to_node(doc=doc) + targets[doc_id] = AdjacentNode(node=node, embedding=embedding) + + # TODO: Consider a combined limit based on the similarity and/or + # predicated MMR score? + return targets.values() diff --git a/libs/astradb/langchain_astradb/utils/mmr.py b/libs/astradb/langchain_astradb/utils/mmr.py deleted file mode 100644 index aa4e198..0000000 --- a/libs/astradb/langchain_astradb/utils/mmr.py +++ /dev/null @@ -1,110 +0,0 @@ -"""Tools for the Maximal Marginal Relevance (MMR) reranking. - -Duplicated from langchain_community to avoid cross-dependencies. - -Functions "maximal_marginal_relevance" and "cosine_similarity" -are duplicated in this utility respectively from modules: - - "libs/community/langchain_community/vectorstores/utils.py" - - "libs/community/langchain_community/utils/math.py" -""" - -from __future__ import annotations - -import logging -from typing import List, Union - -import numpy as np - -logger = logging.getLogger(__name__) - -Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray] - - -def cosine_similarity(x: Matrix, y: Matrix) -> np.ndarray: - """Row-wise cosine similarity between two equal-width matrices. - - Args: - x: First matrix. - y: Second matrix. - - Returns: - np.ndarray: Cosine similarity matrix. - """ - if len(x) == 0 or len(y) == 0: - return np.array([]) - - x = np.array(x) - y = np.array(y) - if x.shape[1] != y.shape[1]: - msg = ( - f"Number of columns in X and Y must be the same. X has shape {x.shape} " - f"and Y has shape {y.shape}." - ) - raise ValueError(msg) - try: - import simsimd as simd # type: ignore[import] - except ImportError: - logger.info( - "Unable to import simsimd, defaulting to NumPy implementation. If you want " - "to use simsimd please install with `pip install simsimd`." 
- ) - x_norm = np.linalg.norm(x, axis=1) - y_norm = np.linalg.norm(y, axis=1) - # Ignore divide by zero errors run time warnings as those are handled below. - with np.errstate(divide="ignore", invalid="ignore"): - similarity = np.dot(x, y.T) / np.outer(x_norm, y_norm) - similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0 - return similarity - else: - x = np.array(x, dtype=np.float32) - y = np.array(y, dtype=np.float32) - z = 1 - np.array(simd.cdist(x, y, metric="cosine")) - if isinstance(z, float): - return np.array([z]) - return z - - -def maximal_marginal_relevance( - query_embedding: np.ndarray, - embedding_list: list, - lambda_mult: float = 0.5, - k: int = 4, -) -> list[int]: - """Calculate maximal marginal relevance. - - Args: - query_embedding: Query embedding to compare. - embedding_list: List of embeddings to select from. - lambda_mult: Number between 0 and 1 that determines the degree - of diversity among the results with 0 corresponding - to maximum diversity and 1 to minimum diversity. - k: Number of embeddings to select. - - Returns: - List of indices of selected embeddings. - """ - if min(k, len(embedding_list)) <= 0: - return [] - if query_embedding.ndim == 1: - query_embedding = np.expand_dims(query_embedding, axis=0) - similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0] - most_similar = int(np.argmax(similarity_to_query)) - idxs = [most_similar] - selected = np.array([embedding_list[most_similar]]) - while len(idxs) < min(k, len(embedding_list)): - best_score = -np.inf - idx_to_add = -1 - similarity_to_selected = cosine_similarity(embedding_list, selected) - for i, query_score in enumerate(similarity_to_query): - if i in idxs: - continue - redundant_score = max(similarity_to_selected[i]) - equation_score = ( - lambda_mult * query_score - (1 - lambda_mult) * redundant_score - ) - if equation_score > best_score: - best_score = equation_score - idx_to_add = i - idxs.append(idx_to_add) - selected = np.append(selected, [embedding_list[idx_to_add]], axis=0) - return idxs diff --git a/libs/astradb/langchain_astradb/utils/mmr_traversal.py b/libs/astradb/langchain_astradb/utils/mmr_helper.py similarity index 88% rename from libs/astradb/langchain_astradb/utils/mmr_traversal.py rename to libs/astradb/langchain_astradb/utils/mmr_helper.py index 188ba03..3c2f1e0 100644 --- a/libs/astradb/langchain_astradb/utils/mmr_traversal.py +++ b/libs/astradb/langchain_astradb/utils/mmr_helper.py @@ -6,11 +6,9 @@ from typing import TYPE_CHECKING, Iterable import numpy as np - -from langchain_astradb.utils.mmr import cosine_similarity +from langchain_community.utils.math import cosine_similarity if TYPE_CHECKING: - from langchain_core.documents import Document from numpy.typing import NDArray @@ -27,6 +25,7 @@ def _emb_to_ndarray(embedding: list[float]) -> NDArray[np.float32]: @dataclasses.dataclass class _Candidate: id: str + similarity: float weighted_similarity: float weighted_redundancy: float score: float = dataclasses.field(init=False) @@ -72,6 +71,13 @@ class MmrHelper: selected_ids: list[str] """List of selected IDs (in selection order).""" + + selected_mmr_scores: list[float] + """List of MMR score at the time each document is selected.""" + + selected_similarity_scores: list[float] + """List of similarity score for each selected document.""" + selected_embeddings: NDArray[np.float32] """(N, dim) ndarray with a row for each selected node.""" @@ -82,8 +88,6 @@ class MmrHelper: Same order as rows in `candidate_embeddings`. 
""" - candidate_docs: dict[str, Document] - """Dict containing the documents associated with each candidate ID.""" candidate_embeddings: NDArray[np.float32] """(N, dim) ndarray with a row for each candidate.""" @@ -106,12 +110,13 @@ def __init__( self.score_threshold = score_threshold self.selected_ids = [] + self.selected_similarity_scores = [] + self.selected_mmr_scores = [] # List of selected embeddings (in selection order). self.selected_embeddings = np.ndarray((k, self.dimensions), dtype=np.float32) self.candidate_id_to_index = {} - self.candidate_docs = {} # List of the candidates. self.candidates = [] @@ -130,11 +135,11 @@ def _already_selected_embeddings(self) -> NDArray[np.float32]: selected = len(self.selected_ids) return np.vsplit(self.selected_embeddings, [selected])[0] - def _pop_candidate(self, candidate_id: str) -> NDArray[np.float32]: + def _pop_candidate(self, candidate_id: str) -> tuple[float, NDArray[np.float32]]: """Pop the candidate with the given ID. Returns: - The embedding of the candidate. + The similarity score and embedding of the candidate. """ # Get the embedding for the id. index = self.candidate_id_to_index.pop(candidate_id) @@ -150,12 +155,15 @@ def _pop_candidate(self, candidate_id: str) -> NDArray[np.float32]: # candidate_embeddings. last_index = self.candidate_embeddings.shape[0] - 1 + similarity = 0.0 if index == last_index: # Already the last item. We don't need to swap. - self.candidates.pop() + similarity = self.candidates.pop().similarity else: self.candidate_embeddings[index] = self.candidate_embeddings[last_index] + similarity = self.candidates[index].similarity + old_last = self.candidates.pop() self.candidates[index] = old_last self.candidate_id_to_index[old_last.id] = index @@ -164,7 +172,7 @@ def _pop_candidate(self, candidate_id: str) -> NDArray[np.float32]: 0 ] - return embedding + return similarity, embedding def pop_best(self) -> str | None: """Select and pop the best item being considered. @@ -179,11 +187,13 @@ def pop_best(self) -> str | None: # Get the selection and remove from candidates. selected_id = self.best_id - selected_embedding = self._pop_candidate(selected_id) + selected_similarity, selected_embedding = self._pop_candidate(selected_id) # Add the ID and embedding to the selected information. selection_index = len(self.selected_ids) self.selected_ids.append(selected_id) + self.selected_mmr_scores.append(self.best_score) + self.selected_similarity_scores.append(selected_similarity) self.selected_embeddings[selection_index] = selected_embedding # Reset the best score / best ID. @@ -203,9 +213,7 @@ def pop_best(self) -> str | None: return selected_id - def add_candidates( - self, candidates: dict[str, tuple[Document, list[float]]] - ) -> None: + def add_candidates(self, candidates: dict[str, list[float]]) -> None: """Add candidates to the consideration set.""" # Determine the keys to actually include. # These are the candidates that aren't already selected @@ -227,9 +235,8 @@ def add_candidates( for index, candidate_id in enumerate(include_ids): if candidate_id in include_ids: self.candidate_id_to_index[candidate_id] = offset + index - doc, embedding = candidates[candidate_id] + embedding = candidates[candidate_id] new_embeddings[index] = embedding - self.candidate_docs[candidate_id] = doc # Compute the similarity to the query. 
similarity = cosine_similarity(new_embeddings, self.query_embedding) @@ -245,6 +252,7 @@ def add_candidates( max_redundancy = redundancy[index].max() candidate = _Candidate( id=candidate_id, + similarity=similarity[index][0], weighted_similarity=self.lambda_mult * similarity[index][0], weighted_redundancy=self.lambda_mult_complement * max_redundancy, ) diff --git a/libs/astradb/langchain_astradb/vectorstores.py b/libs/astradb/langchain_astradb/vectorstores.py index c431b07..645aa24 100644 --- a/libs/astradb/langchain_astradb/vectorstores.py +++ b/libs/astradb/langchain_astradb/vectorstores.py @@ -22,6 +22,7 @@ import numpy as np from astrapy.exceptions import InsertManyException +from langchain_community.vectorstores.utils import maximal_marginal_relevance from langchain_core.runnables.utils import gather_with_concurrency from langchain_core.vectorstores import VectorStore from typing_extensions import override @@ -36,7 +37,6 @@ _AstraDBCollectionEnvironment, _survey_collection, ) -from langchain_astradb.utils.mmr import maximal_marginal_relevance from langchain_astradb.utils.vector_store_autodetect import ( _detect_document_codec, ) @@ -1339,6 +1339,91 @@ async def _update_document( return sum(u_res.update_info["n"] for u_res in update_results) + def metadata_search( + self, + filter: dict[str, Any] | None = None, # noqa: A002 + n: int = 5, + ) -> list[Document]: + """Get documents via a metadata search. + + Args: + filter: the metadata to query for. + n: the maximum number of documents to return. + """ + self.astra_env.ensure_db_setup() + metadata_parameter = self.filter_to_query(filter) + hits_ite = self.astra_env.collection.find( + filter=metadata_parameter, + projection=self.document_codec.base_projection, + limit=n, + ) + docs = [self.document_codec.decode(hit) for hit in hits_ite] + return [doc for doc in docs if doc is not None] + + async def ametadata_search( + self, + filter: dict[str, Any] | None = None, # noqa: A002 + n: int = 5, + ) -> Iterable[Document]: + """Get documents via a metadata search. + + Args: + filter: the metadata to query for. + n: the maximum number of documents to return. + """ + await self.astra_env.aensure_db_setup() + metadata_parameter = self.filter_to_query(filter) + return [ + doc + async for doc in ( + self.document_codec.decode(hit) + async for hit in self.astra_env.async_collection.find( + filter=metadata_parameter, + projection=self.document_codec.base_projection, + limit=n, + ) + ) + if doc is not None + ] + + def get_by_document_id(self, document_id: str) -> Document | None: + """Retrieve a single document from the store, given its document ID. + + Args: + document_id: The document ID + + Returns: + The the document if it exists. Otherwise None. + """ + self.astra_env.ensure_db_setup() + # self.collection is not None (by _ensure_astra_db_client) + hit = self.astra_env.collection.find_one( + {"_id": document_id}, + projection=self.document_codec.base_projection, + ) + if hit is None: + return None + return self.document_codec.decode(hit) + + async def aget_by_document_id(self, document_id: str) -> Document | None: + """Retrieve a single document from the store, given its document ID. + + Args: + document_id: The document ID + + Returns: + The the document if it exists. Otherwise None. 
+ """ + await self.astra_env.aensure_db_setup() + # self.collection is not None (by _ensure_astra_db_client) + hit = await self.astra_env.async_collection.find_one( + {"_id": document_id}, + projection=self.document_codec.base_projection, + ) + if hit is None: + return None + return self.document_codec.decode(hit) + @override def similarity_search( self, @@ -1702,6 +1787,44 @@ async def asimilarity_search_with_score_id_by_vector( filter=filter, ) + async def asimilarity_search_with_embedding_id_by_vector( + self, + embedding: list[float], + k: int = 4, + filter: dict[str, Any] | None = None, # noqa: A002 + ) -> list[tuple[Document, list[float], str]]: + """Return docs most similar to embedding vector. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter: Filter on the metadata to apply. + + Returns: + List of (Document, embedding, id), the most similar to the query vector. + """ + await self.astra_env.aensure_db_setup() + metadata_parameter = self.filter_to_query(filter).copy() + results: list[tuple[Document, list[float], str]] = [] + async for hit in self.astra_env.async_collection.find( + filter=metadata_parameter, + projection=self.document_codec.full_projection, + limit=k, + include_similarity=True, + include_sort_vector=True, + sort=self.document_codec.encode_vector_sort(embedding), + ): + doc = self.document_codec.decode(hit) + if doc is None or doc.id is None: + continue + + vector = self.document_codec.decode_vector(hit) + if vector is None: + continue + + results.append((doc, vector, doc.id)) + return results + async def _asimilarity_search_with_score_id_by_sort( self, sort: dict[str, Any], diff --git a/libs/astradb/tests/integration_tests/test_graphvectorstore.py b/libs/astradb/tests/integration_tests/test_graphvectorstore.py index 3b84248..60bd4e1 100644 --- a/libs/astradb/tests/integration_tests/test_graphvectorstore.py +++ b/libs/astradb/tests/integration_tests/test_graphvectorstore.py @@ -163,8 +163,7 @@ def autodetect_populated_graph_vector_store_d2( gstore = AstraDBGraphVectorStore( embedding=embedding_d2, collection_name=ephemeral_collection_cleaner_idxall_d2, - link_to_metadata_key="x_link_to_x", - link_from_metadata_key="x_link_from_x", + metadata_incoming_links_key="x_link_to_x", token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], @@ -216,7 +215,7 @@ def test_gvs_similarity_search( ss_by_v_labels = [doc.metadata["label"] for doc in ss_by_v_response] assert ss_by_v_labels == ["AR", "A0"] if is_autodetected: - assert_all_flat_docs(store.vectorstore.astra_env.collection) + assert_all_flat_docs(store.vector_store.astra_env.collection) @pytest.mark.parametrize( ("store_name", "is_autodetected"), @@ -241,7 +240,7 @@ def test_gvs_traversal_search( ts_labels = {doc.metadata["label"] for doc in ts_response} assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"} if is_autodetected: - assert_all_flat_docs(store.vectorstore.astra_env.collection) + assert_all_flat_docs(store.vector_store.astra_env.collection) @pytest.mark.parametrize( ("store_name", "is_autodetected"), @@ -272,7 +271,7 @@ def test_gvs_mmr_traversal_search( mt_labels = {doc.metadata["label"] for doc in mt_response} assert mt_labels == {"AR", "BR"} if is_autodetected: - assert_all_flat_docs(store.vectorstore.astra_env.collection) + assert_all_flat_docs(store.vector_store.astra_env.collection) def test_gvs_from_texts( self, diff --git 
a/libs/astradb/tests/unit_tests/test_mmr_helper.py b/libs/astradb/tests/unit_tests/test_mmr_helper.py index 02167c5..e9bec6d 100644 --- a/libs/astradb/tests/unit_tests/test_mmr_helper.py +++ b/libs/astradb/tests/unit_tests/test_mmr_helper.py @@ -1,6 +1,8 @@ -from langchain_core.documents import Document +from __future__ import annotations -from langchain_astradb.utils.mmr_traversal import MmrHelper +import math + +from langchain_astradb.utils.mmr_helper import MmrHelper IDS = { "-1", @@ -22,19 +24,19 @@ def test_mmr_helper_functional(self) -> None: assert len(list(helper.candidate_ids())) == 0 - helper.add_candidates({"-1": (Document(page_content="-1"), [3, 5])}) - helper.add_candidates({"-2": (Document(page_content="-2"), [3, 5])}) - helper.add_candidates({"-3": (Document(page_content="-3"), [2, 6])}) - helper.add_candidates({"-4": (Document(page_content="-4"), [1, 6])}) - helper.add_candidates({"-5": (Document(page_content="-5"), [0, 6])}) + helper.add_candidates({"-1": [3, 5]}) + helper.add_candidates({"-2": [3, 5]}) + helper.add_candidates({"-3": [2, 6]}) + helper.add_candidates({"-4": [1, 6]}) + helper.add_candidates({"-5": [0, 6]}) assert len(list(helper.candidate_ids())) == 5 - helper.add_candidates({"+1": (Document(page_content="+1"), [5, 3])}) - helper.add_candidates({"+2": (Document(page_content="+2"), [5, 3])}) - helper.add_candidates({"+3": (Document(page_content="+3"), [6, 2])}) - helper.add_candidates({"+4": (Document(page_content="+4"), [6, 1])}) - helper.add_candidates({"+5": (Document(page_content="+5"), [6, 0])}) + helper.add_candidates({"+1": [5, 3]}) + helper.add_candidates({"+2": [5, 3]}) + helper.add_candidates({"+3": [6, 2]}) + helper.add_candidates({"+4": [6, 1]}) + helper.add_candidates({"+5": [6, 0]}) assert len(list(helper.candidate_ids())) == 10 @@ -46,22 +48,96 @@ def test_mmr_helper_functional(self) -> None: def test_mmr_helper_max_diversity(self) -> None: helper = MmrHelper(k=2, query_embedding=[6, 5], lambda_mult=0) - helper.add_candidates({"-1": (Document(page_content="-1"), [3, 5])}) - helper.add_candidates({"-2": (Document(page_content="-2"), [3, 5])}) - helper.add_candidates({"-3": (Document(page_content="-3"), [2, 6])}) - helper.add_candidates({"-4": (Document(page_content="-4"), [1, 6])}) - helper.add_candidates({"-5": (Document(page_content="-5"), [0, 6])}) + helper.add_candidates({"-1": [3, 5]}) + helper.add_candidates({"-2": [3, 5]}) + helper.add_candidates({"-3": [2, 6]}) + helper.add_candidates({"-4": [1, 6]}) + helper.add_candidates({"-5": [0, 6]}) best = {helper.pop_best(), helper.pop_best()} assert best == {"-1", "-5"} def test_mmr_helper_max_similarity(self) -> None: helper = MmrHelper(k=2, query_embedding=[6, 5], lambda_mult=1) - helper.add_candidates({"-1": (Document(page_content="-1"), [3, 5])}) - helper.add_candidates({"-2": (Document(page_content="-2"), [3, 5])}) - helper.add_candidates({"-3": (Document(page_content="-3"), [2, 6])}) - helper.add_candidates({"-4": (Document(page_content="-4"), [1, 6])}) - helper.add_candidates({"-5": (Document(page_content="-5"), [0, 6])}) + helper.add_candidates({"-1": [3, 5]}) + helper.add_candidates({"-2": [3, 5]}) + helper.add_candidates({"-3": [2, 6]}) + helper.add_candidates({"-4": [1, 6]}) + helper.add_candidates({"-5": [0, 6]}) best = {helper.pop_best(), helper.pop_best()} assert best == {"-1", "-2"} + + def test_mmr_helper_add_candidate(self) -> None: + helper = MmrHelper(5, [0.0, 1.0]) + helper.add_candidates( + { + "a": [0.0, 1.0], + "b": [1.0, 0.0], + } + ) + assert helper.best_id == "a" + + 
def test_mmr_helper_pop_best(self) -> None: + helper = MmrHelper(5, [0.0, 1.0]) + helper.add_candidates( + { + "a": [0.0, 1.0], + "b": [1.0, 0.0], + } + ) + assert helper.pop_best() == "a" + assert helper.pop_best() == "b" + assert helper.pop_best() is None + + def angular_embedding(self, angle: float) -> list[float]: + return [math.cos(angle * math.pi), math.sin(angle * math.pi)] + + def test_mmr_helper_added_documents(self) -> None: + """Test end to end construction and MMR search. + The embedding function used here ensures `texts` become + the following vectors on a circle (numbered v0 through v3): + + ______ v2 + / \ + / | v1 + v3 | . | query + | / v0 + |______/ (N.B. very crude drawing) + + + With fetch_k==2 and k==2, when query is at 0.0, (1, ), + one expects that v2 and v0 are returned (in some order) + because v1 is "too close" to v0 (and v0 is closer than v1)). + + Both v2 and v3 are discovered after v0. + """ + helper = MmrHelper(5, self.angular_embedding(0.0)) + + # Fetching the 2 nearest neighbors to 0.0 + helper.add_candidates( + { + "v0": self.angular_embedding(-0.124), + "v1": self.angular_embedding(+0.127), + } + ) + assert helper.pop_best() == "v0" + + # After v0 is seletected, new nodes are discovered. + # v2 is closer than v3. v1 is "too similar" to "v0" so it's not included. + helper.add_candidates( + { + "v2": self.angular_embedding(+0.25), + "v3": self.angular_embedding(+1.0), + } + ) + assert helper.pop_best() == "v2" + + assert math.isclose( + helper.selected_similarity_scores[0], 0.9251, abs_tol=0.0001 + ) + assert math.isclose( + helper.selected_similarity_scores[1], 0.7071, abs_tol=0.0001 + ) + assert math.isclose(helper.selected_mmr_scores[0], 0.4625, abs_tol=0.0001) + assert math.isclose(helper.selected_mmr_scores[1], 0.1608, abs_tol=0.0001) From 1b1bdee7a9842f2df826de3ff1a3506be67a1990 Mon Sep 17 00:00:00 2001 From: Eric Pinzur Date: Thu, 10 Oct 2024 15:25:37 +0200 Subject: [PATCH 02/11] fix tests --- libs/astradb/tests/integration_tests/test_graphvectorstore.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/astradb/tests/integration_tests/test_graphvectorstore.py b/libs/astradb/tests/integration_tests/test_graphvectorstore.py index 60bd4e1..544ffd0 100644 --- a/libs/astradb/tests/integration_tests/test_graphvectorstore.py +++ b/libs/astradb/tests/integration_tests/test_graphvectorstore.py @@ -353,9 +353,9 @@ def test_gvs_add_nodes( assert hits[0].page_content == "[0, 2]" md0 = hits[0].metadata assert md0["m"] == 0 - assert any(isinstance(v, list) for k, v in md0.items() if k != "m") + assert any(isinstance(v, set) for k, v in md0.items() if k != "m") assert hits[1].id != "id0" assert hits[1].page_content == "[0, 1]" md1 = hits[1].metadata assert md1["m"] == 1 - assert any(isinstance(v, list) for k, v in md1.items() if k != "m") + assert any(isinstance(v, set) for k, v in md1.items() if k != "m") From 931bfdf690f658cf0e050c73e773963c67b0ed58 Mon Sep 17 00:00:00 2001 From: Eric Pinzur Date: Fri, 11 Oct 2024 12:09:28 +0200 Subject: [PATCH 03/11] some fixes --- .../langchain_astradb/graph_vectorstores.py | 86 ++++++++++++++----- .../tests/unit_tests/test_mmr_helper.py | 12 +-- 2 files changed, 72 insertions(+), 26 deletions(-) diff --git a/libs/astradb/langchain_astradb/graph_vectorstores.py b/libs/astradb/langchain_astradb/graph_vectorstores.py index 1e481b5..39daf71 100644 --- a/libs/astradb/langchain_astradb/graph_vectorstores.py +++ b/libs/astradb/langchain_astradb/graph_vectorstores.py @@ -27,9 +27,10 @@ from 
langchain_astradb.vectorstores import AstraDBVectorStore if TYPE_CHECKING: - from astrapy.authentication import TokenProvider + from astrapy.authentication import EmbeddingHeadersProvider, TokenProvider from astrapy.db import AstraDB as AstraDBClient from astrapy.db import AsyncAstraDB as AsyncAstraDBClient + from astrapy.info import CollectionVectorServiceOptions from langchain_core.embeddings import Embeddings from langchain_astradb.utils.astradb import SetupMode @@ -70,7 +71,7 @@ def default(self, obj: Any) -> Any: # noqa: ANN401 def _deserialize_links(json_blob: str | None) -> set[Link]: return { Link(kind=link["kind"], direction=link["direction"], tag=link["tag"]) - for link in cast(list[dict[str, Any]], json.loads(json_blob or "")) + for link in cast(list[dict[str, Any]], json.loads(json_blob or "[]")) } @@ -104,13 +105,13 @@ class AstraDBGraphVectorStore(GraphVectorStore): def __init__( self, *, - embedding: Embeddings, collection_name: str, + embedding: Embeddings, metadata_incoming_links_key: str = "incoming_links", token: str | TokenProvider | None = None, api_endpoint: str | None = None, - namespace: str | None = None, environment: str | None = None, + namespace: str | None = None, metric: str | None = None, batch_size: int | None = None, bulk_insert_batch_concurrency: int | None = None, @@ -120,18 +121,25 @@ def __init__( pre_delete_collection: bool = False, metadata_indexing_include: Iterable[str] | None = None, metadata_indexing_exclude: Iterable[str] | None = None, - collection_indexing_policy: dict[str, list[str]] | None = None, + collection_indexing_policy: dict[str, Any] | None = None, + collection_vector_service_options: CollectionVectorServiceOptions | None = None, + collection_embedding_api_key: str | EmbeddingHeadersProvider | None = None, content_field: str | None = None, ignore_invalid_documents: bool = False, autodetect_collection: bool = False, ext_callers: list[tuple[str | None, str | None] | str | None] | None = None, + component_name: str = COMPONENT_NAME_GRAPHVECTORSTORE, astra_db_client: AstraDBClient | None = None, async_astra_db_client: AsyncAstraDBClient | None = None, ): """Graph Vector Store backed by AstraDB. Args: - embedding: the embeddings function. + embedding: the embeddings function or service to use. + This enables client-side embedding functions or calls to external + embedding providers. If ``embedding`` is provided, arguments + ``collection_vector_service_options`` and + ``collection_embedding_api_key`` cannot be provided. collection_name: name of the Astra DB collection to create/use. metadata_incoming_links_key: document metadata key where the incoming links are stored (and indexed). @@ -142,12 +150,12 @@ def __init__( api_endpoint: full URL to the API endpoint, such as ``https://-us-east1.apps.astra.datastax.com``. If not provided, the environment variable ASTRA_DB_API_ENDPOINT is inspected. - namespace: namespace (aka keyspace) where the collection is created. - If not provided, the environment variable ASTRA_DB_KEYSPACE is - inspected. Defaults to the database's "default namespace". environment: a string specifying the environment of the target Data API. If omitted, defaults to "prod" (Astra DB production). Other values are in ``astrapy.constants.Environment`` enum class. + namespace: namespace (aka keyspace) where the collection is created. + If not provided, the environment variable ASTRA_DB_KEYSPACE is + inspected. Defaults to the database's "default namespace". metric: similarity function to use out of those available in Astra DB. 
If left out, it will use Astra DB API's defaults (i.e. "cosine" - but, for performance reasons, "dot_product" is suggested if embeddings are @@ -172,9 +180,21 @@ def __init__( what fields should be indexed for later filtering in searches. This dict must conform to to the API specifications (see https://docs.datastax.com/en/astra-db-serverless/api-reference/collections.html#the-indexing-option) + collection_vector_service_options: specifies the use of server-side + embeddings within Astra DB. If passing this parameter, ``embedding`` + cannot be provided. + collection_embedding_api_key: for usage of server-side embeddings + within Astra DB. With this parameter one can supply an API Key + that will be passed to Astra DB with each data request. + This parameter can be either a string or a subclass of + ``astrapy.authentication.EmbeddingHeadersProvider``. + This is useful when the service is configured for the collection, + but no corresponding secret is stored within + Astra's key management system. content_field: name of the field containing the textual content - in the documents when saved on Astra DB. Defaults to - "content". + in the documents when saved on Astra DB. For vectorize collections, + this cannot be specified; for non-vectorize collection, defaults + to "content". The special value "*" can be passed only if autodetect_collection=True. In this case, the actual name of the key for the textual content is guessed by inspection of a few documents from the collection, under the @@ -191,8 +211,10 @@ def __init__( The store will look for an existing collection of the provided name and infer the store settings from it. Default is False. In autodetect mode, ``content_field`` can be given as ``"*"``, meaning - that an attempt will be made to determine it by inspection. - In autodetect mode, the store switches + that an attempt will be made to determine it by inspection (unless + vectorize is enabled, in which case ``content_field`` is ignored). + In autodetect mode, the store not only determines whether embeddings + are client- or server-side, but - most importantly - switches automatically between "nested" and "flat" representations of documents on DB (i.e. having the metadata key-value pairs grouped in a ``metadata`` field or spread at the documents' top-level). The former @@ -202,12 +224,18 @@ def __init__( an AstraDBVectorStore to them. Note that the following parameters cannot be used if this is True: ``metric``, ``setup_mode``, ``metadata_indexing_include``, - ``metadata_indexing_exclude``, ``collection_indexing_policy``. + ``metadata_indexing_exclude``, ``collection_indexing_policy``, + ``collection_vector_service_options``. ext_callers: one or more caller identities to identify Data API calls in the User-Agent header. This is a list of (name, version) pairs, or just strings if no version info is provided, which, if supplied, becomes the leading part of the User-Agent string in all API requests related to this component. + component_name: the string identifying this specific component in the + stack of usage info passed as the User-Agent string to the Data API. + Defaults to "langchain_graphvectorstore", but can be overridden if this + component actually serves as the building block for another component + (such as a Graph Vector Store). 
astra_db_client: *DEPRECATED starting from version 0.3.5.* *Please use 'token', 'api_endpoint' and optionally 'environment'.* @@ -218,6 +246,23 @@ def __init__( *Please use 'token', 'api_endpoint' and optionally 'environment'.* you can pass an already-created 'astrapy.db.AsyncAstraDB' instance (alternatively to 'token', 'api_endpoint' and 'environment'). + + Note: + For concurrency in synchronous :meth:`~add_texts`:, as a rule of thumb, + on a typical client machine it is suggested to keep the quantity + bulk_insert_batch_concurrency * bulk_insert_overwrite_concurrency + much below 1000 to avoid exhausting the client multithreading/networking + resources. The hardcoded defaults are somewhat conservative to meet + most machines' specs, but a sensible choice to test may be: + + - bulk_insert_batch_concurrency = 80 + - bulk_insert_overwrite_concurrency = 10 + + A bit of experimentation is required to nail the best results here, + depending on both the machine/network specs and the expected workload + (specifically, how often a write is an update of an existing id). + Remember you can pass concurrency settings to individual calls to + :meth:`~add_texts` and :meth:`~add_documents` as well. """ self.metadata_incoming_links_key = metadata_incoming_links_key self.embedding = embedding @@ -226,7 +271,6 @@ def __init__( # full links blob is not. if collection_indexing_policy is not None: collection_indexing_policy["allow"].append(self.metadata_incoming_links_key) - collection_indexing_policy["deny"].append(METADATA_LINKS_KEY) elif metadata_indexing_include is not None: metadata_indexing_include = set(metadata_indexing_include) metadata_indexing_include.add(self.metadata_incoming_links_key) @@ -241,8 +285,8 @@ def __init__( embedding=embedding, token=token, api_endpoint=api_endpoint, - namespace=namespace, environment=environment, + namespace=namespace, metric=metric, batch_size=batch_size, bulk_insert_batch_concurrency=bulk_insert_batch_concurrency, @@ -253,11 +297,13 @@ def __init__( metadata_indexing_include=metadata_indexing_include, metadata_indexing_exclude=metadata_indexing_exclude, collection_indexing_policy=collection_indexing_policy, + collection_vector_service_options=collection_vector_service_options, + collection_embedding_api_key=collection_embedding_api_key, content_field=content_field, ignore_invalid_documents=ignore_invalid_documents, autodetect_collection=autodetect_collection, ext_callers=ext_callers, - component_name=COMPONENT_NAME_GRAPHVECTORSTORE, + component_name=component_name, astra_db_client=astra_db_client, async_astra_db_client=async_astra_db_client, ) @@ -266,7 +312,7 @@ def __init__( @property @override - def embeddings(self) -> Embeddings: + def embeddings(self) -> Embeddings | None: return self.embedding def _get_metadata_filter( @@ -490,7 +536,7 @@ def metadata_search( return [ self._restore_links(doc) for doc in self.vector_store.metadata_search( - filter=filter, + filter=filter or {}, n=n, ) ] @@ -509,7 +555,7 @@ async def ametadata_search( return [ self._restore_links(doc) for doc in await self.vector_store.ametadata_search( - filter=filter, + filter=filter or {}, n=n, ) ] diff --git a/libs/astradb/tests/unit_tests/test_mmr_helper.py b/libs/astradb/tests/unit_tests/test_mmr_helper.py index e9bec6d..bb5c7af 100644 --- a/libs/astradb/tests/unit_tests/test_mmr_helper.py +++ b/libs/astradb/tests/unit_tests/test_mmr_helper.py @@ -98,12 +98,12 @@ def test_mmr_helper_added_documents(self) -> None: The embedding function used here ensures `texts` become the following 
vectors on a circle (numbered v0 through v3): - ______ v2 - / \ - / | v1 + ______ v2 + / \ + / | v1 v3 | . | query - | / v0 - |______/ (N.B. very crude drawing) + | / v0 + |______/ (N.B. very crude drawing) With fetch_k==2 and k==2, when query is at 0.0, (1, ), @@ -123,7 +123,7 @@ def test_mmr_helper_added_documents(self) -> None: ) assert helper.pop_best() == "v0" - # After v0 is seletected, new nodes are discovered. + # After v0 is selected, new nodes are discovered. # v2 is closer than v3. v1 is "too similar" to "v0" so it's not included. helper.add_candidates( { From 4c6c46e8c101f5958a606ec20ab323a2dae5d798 Mon Sep 17 00:00:00 2001 From: Eric Pinzur Date: Mon, 14 Oct 2024 12:59:03 +0200 Subject: [PATCH 04/11] added initial upgrade test --- .../langchain_astradb/graph_vectorstores.py | 119 +++++++++----- .../langchain_astradb/utils/astradb.py | 43 +++--- .../test_upgrade_to_graphvectorstore.py | 146 ++++++++++++++++++ 3 files changed, 250 insertions(+), 58 deletions(-) create mode 100644 libs/astradb/tests/integration_tests/test_upgrade_to_graphvectorstore.py diff --git a/libs/astradb/langchain_astradb/graph_vectorstores.py b/libs/astradb/langchain_astradb/graph_vectorstores.py index 39daf71..af8c760 100644 --- a/libs/astradb/langchain_astradb/graph_vectorstores.py +++ b/libs/astradb/langchain_astradb/graph_vectorstores.py @@ -267,46 +267,61 @@ def __init__( self.metadata_incoming_links_key = metadata_incoming_links_key self.embedding = embedding - # update indexing policy to ensure incoming_links are indexed, and the - # full links blob is not. - if collection_indexing_policy is not None: - collection_indexing_policy["allow"].append(self.metadata_incoming_links_key) - elif metadata_indexing_include is not None: + # update indexing policy to ensure incoming_links are indexed + if metadata_indexing_include is not None: metadata_indexing_include = set(metadata_indexing_include) metadata_indexing_include.add(self.metadata_incoming_links_key) - elif metadata_indexing_exclude is not None: - metadata_indexing_exclude = set(metadata_indexing_exclude) - metadata_indexing_exclude.add(METADATA_LINKS_KEY) - elif not autodetect_collection: - metadata_indexing_exclude = [METADATA_LINKS_KEY] - - self.vector_store = AstraDBVectorStore( - collection_name=collection_name, - embedding=embedding, - token=token, - api_endpoint=api_endpoint, - environment=environment, - namespace=namespace, - metric=metric, - batch_size=batch_size, - bulk_insert_batch_concurrency=bulk_insert_batch_concurrency, - bulk_insert_overwrite_concurrency=bulk_insert_overwrite_concurrency, - bulk_delete_concurrency=bulk_delete_concurrency, - setup_mode=setup_mode, - pre_delete_collection=pre_delete_collection, - metadata_indexing_include=metadata_indexing_include, - metadata_indexing_exclude=metadata_indexing_exclude, - collection_indexing_policy=collection_indexing_policy, - collection_vector_service_options=collection_vector_service_options, - collection_embedding_api_key=collection_embedding_api_key, - content_field=content_field, - ignore_invalid_documents=ignore_invalid_documents, - autodetect_collection=autodetect_collection, - ext_callers=ext_callers, - component_name=component_name, - astra_db_client=astra_db_client, - async_astra_db_client=async_astra_db_client, - ) + elif collection_indexing_policy is not None: + allow_list = collection_indexing_policy.get("allow") + if allow_list is not None: + allow_list = set(allow_list) + allow_list.add(self.metadata_incoming_links_key) + collection_indexing_policy["allow"] = 
list(allow_list) + + try: + self.vector_store = AstraDBVectorStore( + collection_name=collection_name, + embedding=embedding, + token=token, + api_endpoint=api_endpoint, + environment=environment, + namespace=namespace, + metric=metric, + batch_size=batch_size, + bulk_insert_batch_concurrency=bulk_insert_batch_concurrency, + bulk_insert_overwrite_concurrency=bulk_insert_overwrite_concurrency, + bulk_delete_concurrency=bulk_delete_concurrency, + setup_mode=setup_mode, + pre_delete_collection=pre_delete_collection, + metadata_indexing_include=metadata_indexing_include, + metadata_indexing_exclude=metadata_indexing_exclude, + collection_indexing_policy=collection_indexing_policy, + collection_vector_service_options=collection_vector_service_options, + collection_embedding_api_key=collection_embedding_api_key, + content_field=content_field, + ignore_invalid_documents=ignore_invalid_documents, + autodetect_collection=autodetect_collection, + ext_callers=ext_callers, + component_name=component_name, + astra_db_client=astra_db_client, + async_astra_db_client=async_astra_db_client, + ) + + # # attempt a query to see if the table is setup correctly + + # self.metadata_search(filter = { + # self.metadata_incoming_links_key : "test" + # }, n=1) + except BaseException as exp: + # determine if error is because of a un-indexed column. Ref: + # https://docs.datastax.com/en/astra-db-serverless/api-reference/collections.html#considerations-for-selective-indexing + error_message = str(exp).lower() + if ("unindexed filter path" in error_message) or ( + "incompatible with the requested indexing policy" in error_message + ): + msg = "The collection configuration is incompatible with vector graph store. Please create a new collection." # noqa: E501 + raise ValueError(msg) from exp + raise exp # noqa: TRY201 self.astra_env = self.vector_store.astra_env @@ -340,7 +355,8 @@ def _restore_links(self, doc: Document) -> Document: """ links = _deserialize_links(doc.metadata.get(METADATA_LINKS_KEY)) doc.metadata[METADATA_LINKS_KEY] = links - del doc.metadata[self.metadata_incoming_links_key] + if self.metadata_incoming_links_key in doc.metadata: + del doc.metadata[self.metadata_incoming_links_key] return doc # TODO: Async (aadd_nodes) @@ -560,6 +576,31 @@ async def ametadata_search( ) ] + def get_by_document_id(self, document_id: str) -> Document | None: + """Retrieve a single document from the store, given its document ID. + + Args: + document_id: The document ID + + Returns: + The the document if it exists. Otherwise None. + """ + doc = self.vector_store.get_by_document_id(document_id=document_id) + return self._restore_links(doc) if doc is not None else None + + async def aget_by_document_id(self, document_id: str) -> Document | None: + """Retrieve a single document from the store, given its document ID. + + Args: + document_id: The document ID + + Returns: + The the document if it exists. Otherwise None. + """ + await self.astra_env.aensure_db_setup() + doc = await self.vector_store.aget_by_document_id(document_id=document_id) + return self._restore_links(doc) if doc is not None else None + def get_node(self, node_id: str) -> Node | None: """Retrieve a single node from the store, given its ID. 
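As a rough illustration of the indexing-policy handling introduced in this patch, the sketch below mirrors how a caller-supplied "allow" list would be extended with the incoming-links metadata key before the underlying vector store is created. The helper name `merge_incoming_links_key` is illustrative only and not part of the library API; a "deny"-style policy is left untouched here because everything not explicitly denied remains indexed (which is why such collections remain compatible).

    # Illustrative sketch only (not library code): fold the incoming-links
    # metadata key into an "allow"-style indexing policy, mirroring the
    # allow-list branch of the constructor logic in this patch.
    def merge_incoming_links_key(
        policy: dict | None, incoming_links_key: str
    ) -> dict | None:
        if policy is None or "allow" not in policy:
            # No allow-list to extend; deny-style (or default) policies
            # already leave the incoming-links key indexed.
            return policy
        allow = set(policy["allow"])
        allow.add(incoming_links_key)
        return {**policy, "allow": sorted(allow)}

    # An allow-list policy gains the incoming-links key ...
    assert merge_incoming_links_key({"allow": ["test"]}, "incoming_links") == {
        "allow": ["incoming_links", "test"]
    }
    # ... while a deny-list policy is passed through unchanged.
    assert merge_incoming_links_key({"deny": ["test"]}, "incoming_links") == {
        "deny": ["test"]
    }
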
diff --git a/libs/astradb/langchain_astradb/utils/astradb.py b/libs/astradb/langchain_astradb/utils/astradb.py index 9b976c3..442c53b 100644 --- a/libs/astradb/langchain_astradb/utils/astradb.py +++ b/libs/astradb/langchain_astradb/utils/astradb.py @@ -366,7 +366,7 @@ def __init__( self.database.drop_collection(collection_name) if inspect.isawaitable(embedding_dimension): msg = ( - "Cannot use an awaitable embedding_dimension with async_setup " + "Cannot use an awaitable embedding_dimension with sync_setup " "set to False" ) raise ValueError(msg) @@ -380,18 +380,20 @@ def __init__( service=collection_vector_service_options, check_exists=False, ) - except DataAPIException: + except DataAPIException as data_api_exception: # possibly the collection is preexisting and may have legacy, # or custom, indexing settings: verify collection_descriptors = list(self.database.list_collections()) - if not self._validate_indexing_policy( - collection_descriptors=collection_descriptors, - collection_name=self.collection_name, - requested_indexing_policy=requested_indexing_policy, - default_indexing_policy=default_indexing_policy, - ): - # other reasons for the exception - raise + try: + if not self._validate_indexing_policy( + collection_descriptors=collection_descriptors, + collection_name=self.collection_name, + requested_indexing_policy=requested_indexing_policy, + default_indexing_policy=default_indexing_policy, + ): + raise data_api_exception # noqa: TRY201 + except ValueError as validation_error: + raise validation_error from data_api_exception async def _asetup_db( self, @@ -420,20 +422,23 @@ async def _asetup_db( service=collection_vector_service_options, check_exists=False, ) - except DataAPIException: + except DataAPIException as data_api_exception: # possibly the collection is preexisting and may have legacy, # or custom, indexing settings: verify collection_descriptors = [ coll_desc async for coll_desc in self.async_database.list_collections() ] - if not self._validate_indexing_policy( - collection_descriptors=collection_descriptors, - collection_name=self.collection_name, - requested_indexing_policy=requested_indexing_policy, - default_indexing_policy=default_indexing_policy, - ): - # other reasons for the exception - raise + try: + if not self._validate_indexing_policy( + collection_descriptors=collection_descriptors, + collection_name=self.collection_name, + requested_indexing_policy=requested_indexing_policy, + default_indexing_policy=default_indexing_policy, + ): + # other reasons for the exception + raise data_api_exception # noqa: TRY201 + except ValueError as validation_error: + raise validation_error from data_api_exception @staticmethod def _validate_indexing_policy( diff --git a/libs/astradb/tests/integration_tests/test_upgrade_to_graphvectorstore.py b/libs/astradb/tests/integration_tests/test_upgrade_to_graphvectorstore.py new file mode 100644 index 0000000..653c214 --- /dev/null +++ b/libs/astradb/tests/integration_tests/test_upgrade_to_graphvectorstore.py @@ -0,0 +1,146 @@ +"""Test of Upgrading to Astra DB graph vector store class: +`AstraDBGraphVectorStore` from an existing collection used +by the Astra DB vector store class: `AstraDBVectorStore` + +Refer to `test_vectorstores.py` for the requirements to run. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pytest +from astrapy.authentication import StaticTokenProvider +from langchain_core.documents import Document + +from langchain_astradb.graph_vectorstores import AstraDBGraphVectorStore +from langchain_astradb.vectorstores import AstraDBVectorStore + +from .conftest import ( + astra_db_env_vars_available, +) + +if TYPE_CHECKING: + from langchain_core.embeddings import Embeddings + + from .conftest import AstraDBCredentials + + +@pytest.fixture +def default_vector_store_d2( + astra_db_credentials: AstraDBCredentials, + embedding_d2: Embeddings, + ephemeral_collection_cleaner_d2: str, +) -> AstraDBVectorStore: + return AstraDBVectorStore( + embedding=embedding_d2, + collection_name=ephemeral_collection_cleaner_d2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + ) + + +@pytest.fixture +def vector_store_d2_with_indexing_allow_list( + astra_db_credentials: AstraDBCredentials, + embedding_d2: Embeddings, + ephemeral_collection_cleaner_d2: str, +) -> AstraDBVectorStore: + return AstraDBVectorStore( + embedding=embedding_d2, + collection_name=ephemeral_collection_cleaner_d2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + # this is the only difference from the `default_vector_store_d2` fixture above + collection_indexing_policy={"allow": ["test"]}, + ) + + +@pytest.fixture +def vector_store_d2_with_indexing_deny_list( + astra_db_credentials: AstraDBCredentials, + embedding_d2: Embeddings, + ephemeral_collection_cleaner_d2: str, +) -> AstraDBVectorStore: + return AstraDBVectorStore( + embedding=embedding_d2, + collection_name=ephemeral_collection_cleaner_d2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + # this is the only difference from the `default_vector_store_d2` fixture above + collection_indexing_policy={"deny": ["test"]}, + ) + + +@pytest.mark.skipif( + not astra_db_env_vars_available(), reason="Missing Astra DB env. 
vars" +) +class TestUpgradeToGraphVectorStore: + @pytest.mark.parametrize( + ("store_name", "indexing_policy", "expect_success"), + [ + ("default_vector_store_d2", None, True), + ("vector_store_d2_with_indexing_allow_list", {"allow": ["test"]}, False), + ("vector_store_d2_with_indexing_allow_list", None, False), + ("vector_store_d2_with_indexing_deny_list", {"deny": ["test"]}, True), + ("vector_store_d2_with_indexing_deny_list", None, False), + ], + ids=[ + "default_store_upgrade_should_succeed", + "allow_store_upgrade_with_allow_policy_should_fail", + "allow_store_upgrade_with_no_policy_should_fail", + "deny_store_upgrade_with_deny_policy_should_succeed", + "deny_store_upgrade_with_no_policy_should_fail", + ], + ) + def test_upgrade_to_gvs( + self, + astra_db_credentials: AstraDBCredentials, + embedding_d2: Embeddings, + ephemeral_collection_cleaner_d2: str, + *, + store_name: str, + indexing_policy: dict[str, Any] | None, + expect_success: bool, + request: pytest.FixtureRequest, + ) -> None: + # Create Vector Store, load a document + v_store: AstraDBVectorStore = request.getfixturevalue(store_name) + doc_id = "AL" + doc_al = Document(id=doc_id, page_content="[-1, 9]", metadata={"label": "AL"}) + v_store.add_documents([doc_al]) + + # Try to create a GRAPH Vector Store using the existing collection from above + try: + gv_store = AstraDBGraphVectorStore( + embedding=embedding_d2, + collection_name=ephemeral_collection_cleaner_d2, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + collection_indexing_policy=indexing_policy, + ) + + if not expect_success: + pytest.fail("Expected ValueError but none was raised") + + except ValueError as value_error: + if expect_success: + pytest.fail(f"Unexpected ValueError raised: {value_error}") + else: + assert ( # noqa: PT017 + str(value_error) + == "The collection configuration is incompatible with vector graph store. Please create a new collection." # noqa: E501 + ) + + if expect_success: + doc = gv_store.get_by_document_id(document_id=doc_id) + assert doc is not None + assert doc.page_content == doc_al.page_content From 8097a09ad395c44feeb751f5fb1daf5434477089 Mon Sep 17 00:00:00 2001 From: Eric Pinzur Date: Mon, 14 Oct 2024 15:06:40 +0200 Subject: [PATCH 05/11] simplified insertion --- .../langchain_astradb/graph_vectorstores.py | 61 +++++++++++++------ 1 file changed, 41 insertions(+), 20 deletions(-) diff --git a/libs/astradb/langchain_astradb/graph_vectorstores.py b/libs/astradb/langchain_astradb/graph_vectorstores.py index af8c760..93d8586 100644 --- a/libs/astradb/langchain_astradb/graph_vectorstores.py +++ b/libs/astradb/langchain_astradb/graph_vectorstores.py @@ -359,40 +359,62 @@ def _restore_links(self, doc: Document) -> Document: del doc.metadata[self.metadata_incoming_links_key] return doc - # TODO: Async (aadd_nodes) - @override - def add_nodes( - self, - nodes: Iterable[Node], - **kwargs: Any, - ) -> Iterable[str]: - """Add nodes to the graph store. + def _get_node_metadata_for_insertion(self, node: Node) -> dict[str, Any]: + metadata = node.metadata.copy() + metadata[METADATA_LINKS_KEY] = _serialize_links(node.links) + metadata[self.metadata_incoming_links_key] = [ + _metadata_link_key(link=link) for link in _incoming_links(node=node) + ] + return metadata - Args: - nodes: the nodes to add. - **kwargs: Additional keyword arguments. 
- """ + def _get_docs_for_insertion( + self, nodes: Iterable[Node] + ) -> tuple[list[Document], list[str]]: docs = [] ids = [] for node in nodes: node_id = secrets.token_hex(8) if not node.id else node.id - combined_metadata = node.metadata.copy() - combined_metadata[METADATA_LINKS_KEY] = _serialize_links(node.links) - combined_metadata[self.metadata_incoming_links_key] = [ - _metadata_link_key(link=link) for link in _incoming_links(node=node) - ] - doc = Document( page_content=node.text, - metadata=combined_metadata, + metadata=self._get_node_metadata_for_insertion(node=node), id=node_id, ) docs.append(doc) ids.append(node_id) + return (docs, ids) + + @override + def add_nodes( + self, + nodes: Iterable[Node], + **kwargs: Any, + ) -> Iterable[str]: + """Add nodes to the graph store. + Args: + nodes: the nodes to add. + **kwargs: Additional keyword arguments. + """ + (docs, ids) = self._get_docs_for_insertion(nodes=nodes) return self.vector_store.add_documents(docs, ids=ids) + @override + async def aadd_nodes( + self, + nodes: Iterable[Node], + **kwargs: Any, + ) -> AsyncIterable[str]: + """Add nodes to the graph store. + + Args: + nodes: the nodes to add. + **kwargs: Additional keyword arguments. + """ + (docs, ids) = self._get_docs_for_insertion(nodes=nodes) + for inserted_id in await self.vector_store.aadd_documents(docs, ids=ids): + yield inserted_id + @classmethod @override def from_texts( @@ -597,7 +619,6 @@ async def aget_by_document_id(self, document_id: str) -> Document | None: Returns: The the document if it exists. Otherwise None. """ - await self.astra_env.aensure_db_setup() doc = await self.vector_store.aget_by_document_id(document_id=document_id) return self._restore_links(doc) if doc is not None else None From a0c20d4964403f0e5d00b3d5e4355b595463d4aa Mon Sep 17 00:00:00 2001 From: Eric Pinzur Date: Tue, 15 Oct 2024 14:46:16 +0200 Subject: [PATCH 06/11] improved testing --- .../langchain_astradb/graph_vectorstores.py | 44 ++- libs/astradb/tests/conftest.py | 1 - .../tests/integration_tests/conftest.py | 4 + .../test_graphvectorstore.py | 270 +++++++++++++- .../test_upgrade_to_graphvectorstore.py | 340 +++++++++++++----- .../integration_tests/test_vectorstore.py | 211 +++++++++-- 6 files changed, 729 insertions(+), 141 deletions(-) diff --git a/libs/astradb/langchain_astradb/graph_vectorstores.py b/libs/astradb/langchain_astradb/graph_vectorstores.py index 93d8586..69ebd44 100644 --- a/libs/astradb/langchain_astradb/graph_vectorstores.py +++ b/libs/astradb/langchain_astradb/graph_vectorstores.py @@ -22,7 +22,7 @@ from langchain_core.documents import Document from typing_extensions import override -from langchain_astradb.utils.astradb import COMPONENT_NAME_GRAPHVECTORSTORE +from langchain_astradb.utils.astradb import COMPONENT_NAME_GRAPHVECTORSTORE, SetupMode from langchain_astradb.utils.mmr_helper import MmrHelper from langchain_astradb.vectorstores import AstraDBVectorStore @@ -33,8 +33,6 @@ from astrapy.info import CollectionVectorServiceOptions from langchain_core.embeddings import Embeddings - from langchain_astradb.utils.astradb import SetupMode - DEFAULT_INDEXING_OPTIONS = {"allow": ["metadata"]} @@ -307,11 +305,43 @@ def __init__( async_astra_db_client=async_astra_db_client, ) - # # attempt a query to see if the table is setup correctly + # for the test search, if setup_mode is ASYNC, + # create a temp store with SYNC + if setup_mode == SetupMode.ASYNC: + test_vs = AstraDBVectorStore( + collection_name=collection_name, + embedding=embedding, + token=token, + 
api_endpoint=api_endpoint, + environment=environment, + namespace=namespace, + metric=metric, + batch_size=batch_size, + bulk_insert_batch_concurrency=bulk_insert_batch_concurrency, + bulk_insert_overwrite_concurrency=bulk_insert_overwrite_concurrency, + bulk_delete_concurrency=bulk_delete_concurrency, + setup_mode=SetupMode.SYNC, + pre_delete_collection=pre_delete_collection, + metadata_indexing_include=metadata_indexing_include, + metadata_indexing_exclude=metadata_indexing_exclude, + collection_indexing_policy=collection_indexing_policy, + collection_vector_service_options=collection_vector_service_options, + collection_embedding_api_key=collection_embedding_api_key, + content_field=content_field, + ignore_invalid_documents=ignore_invalid_documents, + autodetect_collection=autodetect_collection, + ext_callers=ext_callers, + component_name=component_name, + astra_db_client=astra_db_client, + async_astra_db_client=async_astra_db_client, + ) + else: + test_vs = self.vector_store - # self.metadata_search(filter = { - # self.metadata_incoming_links_key : "test" - # }, n=1) + # try a simple search to ensure that the indexes are setup properly + test_vs.metadata_search( + filter={self.metadata_incoming_links_key: "test"}, n=1 + ) except BaseException as exp: # determine if error is because of a un-indexed column. Ref: # https://docs.datastax.com/en/astra-db-serverless/api-reference/collections.html#considerations-for-selective-indexing diff --git a/libs/astradb/tests/conftest.py b/libs/astradb/tests/conftest.py index 767023b..db44211 100644 --- a/libs/astradb/tests/conftest.py +++ b/libs/astradb/tests/conftest.py @@ -33,7 +33,6 @@ def embed_query(self, text: str) -> list[float]: try: vals = json.loads(text) except json.JSONDecodeError: - print(f'[ParserEmbeddings] Returning a moot vector for "{text}"') return [0.0] * self.dimension else: assert len(vals) == self.dimension diff --git a/libs/astradb/tests/integration_tests/conftest.py b/libs/astradb/tests/integration_tests/conftest.py index 2799088..0ce30d9 100644 --- a/libs/astradb/tests/integration_tests/conftest.py +++ b/libs/astradb/tests/integration_tests/conftest.py @@ -66,8 +66,10 @@ # for KMS (aka shared_secret) vectorize setup (vectorstores) EPHEMERAL_COLLECTION_NAME_VZ_KMS = "lc_test_vz_kms_short" # indexing-related collection names (function-lived) (vectorstores) +EPHEMERAL_ALLOW_IDX_NAME_D2 = "lc_test_allow_idx_d2_short" EPHEMERAL_CUSTOM_IDX_NAME_D2 = "lc_test_custom_idx_d2_short" EPHEMERAL_DEFAULT_IDX_NAME_D2 = "lc_test_default_idx_d2_short" +EPHEMERAL_DENY_IDX_NAME_D2 = "lc_test_deny_idx_d2_short" EPHEMERAL_LEGACY_IDX_NAME_D2 = "lc_test_legacy_idx_d2_short" # indexing-related collection names (function-lived) (storage) EPHEMERAL_CUSTOM_IDX_NAME = "lc_test_custom_idx_short" @@ -515,8 +517,10 @@ def ephemeral_indexing_collections_cleaner( """ collection_names = [ + EPHEMERAL_ALLOW_IDX_NAME_D2, EPHEMERAL_CUSTOM_IDX_NAME_D2, EPHEMERAL_DEFAULT_IDX_NAME_D2, + EPHEMERAL_DENY_IDX_NAME_D2, EPHEMERAL_LEGACY_IDX_NAME_D2, EPHEMERAL_CUSTOM_IDX_NAME, EPHEMERAL_LEGACY_IDX_NAME, diff --git a/libs/astradb/tests/integration_tests/test_graphvectorstore.py b/libs/astradb/tests/integration_tests/test_graphvectorstore.py index 544ffd0..7b90f2b 100644 --- a/libs/astradb/tests/integration_tests/test_graphvectorstore.py +++ b/libs/astradb/tests/integration_tests/test_graphvectorstore.py @@ -61,24 +61,24 @@ def graph_vector_store_docs() -> list[Document]: """ docs_a = [ - Document(page_content="[-1, 9]", metadata={"label": "AL"}), - 
Document(page_content="[0, 10]", metadata={"label": "A0"}), - Document(page_content="[1, 9]", metadata={"label": "AR"}), + Document(id="AL", page_content="[-1, 9]", metadata={"label": "AL"}), + Document(id="A0", page_content="[0, 10]", metadata={"label": "A0"}), + Document(id="AR", page_content="[1, 9]", metadata={"label": "AR"}), ] docs_b = [ - Document(page_content="[9, 1]", metadata={"label": "BL"}), - Document(page_content="[10, 0]", metadata={"label": "B0"}), - Document(page_content="[9, -1]", metadata={"label": "BR"}), + Document(id="BL", page_content="[9, 1]", metadata={"label": "BL"}), + Document(id="B0", page_content="[10, 0]", metadata={"label": "B0"}), + Document(id="BL", page_content="[9, -1]", metadata={"label": "BR"}), ] docs_f = [ - Document(page_content="[1, -9]", metadata={"label": "BL"}), - Document(page_content="[0, -10]", metadata={"label": "B0"}), - Document(page_content="[-1, -9]", metadata={"label": "BR"}), + Document(id="FL", page_content="[1, -9]", metadata={"label": "FL"}), + Document(id="F0", page_content="[0, -10]", metadata={"label": "F0"}), + Document(id="FR", page_content="[-1, -9]", metadata={"label": "FR"}), ] docs_t = [ - Document(page_content="[-9, -1]", metadata={"label": "TL"}), - Document(page_content="[-10, 0]", metadata={"label": "T0"}), - Document(page_content="[-9, 1]", metadata={"label": "TR"}), + Document(id="TL", page_content="[-9, -1]", metadata={"label": "TL"}), + Document(id="T0", page_content="[-10, 0]", metadata={"label": "T0"}), + Document(id="TR", page_content="[-9, 1]", metadata={"label": "TR"}), ] for doc_a, suffix in zip(docs_a, ["l", "0", "r"]): add_links(doc_a, Link.bidir(kind="ab_example", tag=f"tag_{suffix}")) @@ -199,7 +199,7 @@ class TestAstraDBGraphVectorStore: ], ids=["native_store", "autodetected_store"], ) - def test_gvs_similarity_search( + def test_gvs_similarity_search_sync( self, *, store_name: str, @@ -225,7 +225,35 @@ def test_gvs_similarity_search( ], ids=["native_store", "autodetected_store"], ) - def test_gvs_traversal_search( + async def test_gvs_similarity_search_async( + self, + *, + store_name: str, + is_autodetected: bool, + request: pytest.FixtureRequest, + ) -> None: + """Simple (non-graph) similarity search on a graph vector store.""" + store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) + ss_response = await store.asimilarity_search(query="[2, 10]", k=2) + ss_labels = [doc.metadata["label"] for doc in ss_response] + assert ss_labels == ["AR", "A0"] + ss_by_v_response = await store.asimilarity_search_by_vector( + embedding=[2, 10], k=2 + ) + ss_by_v_labels = [doc.metadata["label"] for doc in ss_by_v_response] + assert ss_by_v_labels == ["AR", "A0"] + if is_autodetected: + assert_all_flat_docs(store.vector_store.astra_env.collection) + + @pytest.mark.parametrize( + ("store_name", "is_autodetected"), + [ + ("populated_graph_vector_store_d2", False), + ("autodetect_populated_graph_vector_store_d2", True), + ], + ids=["native_store", "autodetected_store"], + ) + def test_gvs_traversal_search_sync( self, *, store_name: str, @@ -250,7 +278,33 @@ def test_gvs_traversal_search( ], ids=["native_store", "autodetected_store"], ) - def test_gvs_mmr_traversal_search( + async def test_gvs_traversal_search_async( + self, + *, + store_name: str, + is_autodetected: bool, + request: pytest.FixtureRequest, + ) -> None: + """Graph traversal search on a graph vector store.""" + store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) + ts_labels = set() + async for doc in 
store.atraversal_search(query="[2, 10]", k=2, depth=2): + ts_labels.add(doc.metadata["label"]) + # this is a set, as some of the internals of trav.search are set-driven + # so ordering is not deterministic: + assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"} + if is_autodetected: + assert_all_flat_docs(store.vector_store.astra_env.collection) + + @pytest.mark.parametrize( + ("store_name", "is_autodetected"), + [ + ("populated_graph_vector_store_d2", False), + ("autodetect_populated_graph_vector_store_d2", True), + ], + ids=["native_store", "autodetected_store"], + ) + def test_gvs_mmr_traversal_search_sync( self, *, store_name: str, @@ -273,6 +327,158 @@ def test_gvs_mmr_traversal_search( if is_autodetected: assert_all_flat_docs(store.vector_store.astra_env.collection) + @pytest.mark.parametrize( + ("store_name", "is_autodetected"), + [ + ("populated_graph_vector_store_d2", False), + ("autodetect_populated_graph_vector_store_d2", True), + ], + ids=["native_store", "autodetected_store"], + ) + async def test_gvs_mmr_traversal_search_async( + self, + *, + store_name: str, + is_autodetected: bool, + request: pytest.FixtureRequest, + ) -> None: + """MMR Graph traversal search on a graph vector store.""" + store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) + mt_labels = set() + async for doc in store.ammr_traversal_search( + query="[2, 10]", + k=2, + depth=2, + fetch_k=1, + adjacent_k=2, + lambda_mult=0.1, + ): + mt_labels.add(doc.metadata["label"]) + # TODO: can this rightfully be a list (or must it be a set)? + assert mt_labels == {"AR", "BR"} + if is_autodetected: + assert_all_flat_docs(store.vector_store.astra_env.collection) + + @pytest.mark.parametrize( + ("store_name"), + [ + ("populated_graph_vector_store_d2"), + ("autodetect_populated_graph_vector_store_d2"), + ], + ids=["native_store", "autodetected_store"], + ) + def test_gvs_metadata_search_sync( + self, + *, + store_name: str, + request: pytest.FixtureRequest, + ) -> None: + """Metadata search on a graph vector store.""" + store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) + mt_response = store.metadata_search( + filter={"label": "T0"}, + n=2, + ) + doc: Document = next(iter(mt_response)) + assert doc.page_content == "[-10, 0]" + links = doc.metadata["links"] + assert len(links) == 1 + link: Link = links.pop() + assert isinstance(link, Link) + assert link.direction == "in" + assert link.kind == "at_example" + assert link.tag == "tag_0" + + @pytest.mark.parametrize( + ("store_name"), + [ + ("populated_graph_vector_store_d2"), + ("autodetect_populated_graph_vector_store_d2"), + ], + ids=["native_store", "autodetected_store"], + ) + async def test_gvs_metadata_search_async( + self, + *, + store_name: str, + request: pytest.FixtureRequest, + ) -> None: + """Metadata search on a graph vector store.""" + store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) + mt_response = await store.ametadata_search( + filter={"label": "T0"}, + n=2, + ) + doc: Document = next(iter(mt_response)) + assert doc.page_content == "[-10, 0]" + links: set[Link] = doc.metadata["links"] + assert len(links) == 1 + link: Link = links.pop() + assert isinstance(link, Link) + assert link.direction == "in" + assert link.kind == "at_example" + assert link.tag == "tag_0" + + @pytest.mark.parametrize( + ("store_name"), + [ + ("populated_graph_vector_store_d2"), + ("autodetect_populated_graph_vector_store_d2"), + ], + ids=["native_store", "autodetected_store"], + ) + def test_gvs_get_by_document_id_sync( + self, + 
*, + store_name: str, + request: pytest.FixtureRequest, + ) -> None: + """Get by document_id on a graph vector store.""" + store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) + doc = store.get_by_document_id(document_id="FL") + assert doc is not None + assert doc.page_content == "[1, -9]" + links = doc.metadata["links"] + assert len(links) == 1 + link: Link = links.pop() + assert isinstance(link, Link) + assert link.direction == "out" + assert link.kind == "af_example" + assert link.tag == "tag_l" + + invalid_doc = store.get_by_document_id(document_id="invalid") + assert invalid_doc is None + + @pytest.mark.parametrize( + ("store_name"), + [ + ("populated_graph_vector_store_d2"), + ("autodetect_populated_graph_vector_store_d2"), + ], + ids=["native_store", "autodetected_store"], + ) + async def test_gvs_get_by_document_id_async( + self, + *, + store_name: str, + request: pytest.FixtureRequest, + ) -> None: + """Get by document_id on a graph vector store.""" + store: AstraDBGraphVectorStore = request.getfixturevalue(store_name) + doc = await store.aget_by_document_id(document_id="FL") + assert doc is not None + assert doc.page_content == "[1, -9]" + links = doc.metadata["links"] + assert len(links) == 1 + link: Link = links.pop() + assert isinstance(link, Link) + assert link.direction == "out" + assert link.kind == "af_example" + assert link.tag == "tag_l" + + invalid_doc = await store.aget_by_document_id(document_id="invalid") + assert invalid_doc is None + def test_gvs_from_texts( self, *, @@ -330,7 +536,7 @@ def test_gvs_from_documents_containing_ids( # there may be more re:graph structure. assert hits[0].metadata["md"] == 1 - def test_gvs_add_nodes( + def test_gvs_add_nodes_sync( self, *, graph_vector_store_d2: AstraDBGraphVectorStore, @@ -359,3 +565,35 @@ def test_gvs_add_nodes( md1 = hits[1].metadata assert md1["m"] == 1 assert any(isinstance(v, set) for k, v in md1.items() if k != "m") + + async def test_gvs_add_nodes_async( + self, + *, + graph_vector_store_d2: AstraDBGraphVectorStore, + ) -> None: + links0 = [ + Link(kind="kA", direction="out", tag="tA"), + Link(kind="kB", direction="bidir", tag="tB"), + ] + links1 = [ + Link(kind="kC", direction="in", tag="tC"), + ] + nodes = [ + Node(id="id0", text="[0, 2]", metadata={"m": 0}, links=links0), + Node(text="[0, 1]", metadata={"m": 1}, links=links1), + ] + async for _ in graph_vector_store_d2.aadd_nodes(nodes): + pass + + hits = await graph_vector_store_d2.asimilarity_search_by_vector([0, 3]) + assert len(hits) == 2 + assert hits[0].id == "id0" + assert hits[0].page_content == "[0, 2]" + md0 = hits[0].metadata + assert md0["m"] == 0 + assert any(isinstance(v, set) for k, v in md0.items() if k != "m") + assert hits[1].id != "id0" + assert hits[1].page_content == "[0, 1]" + md1 = hits[1].metadata + assert md1["m"] == 1 + assert any(isinstance(v, set) for k, v in md1.items() if k != "m") diff --git a/libs/astradb/tests/integration_tests/test_upgrade_to_graphvectorstore.py b/libs/astradb/tests/integration_tests/test_upgrade_to_graphvectorstore.py index 653c214..5238f0b 100644 --- a/libs/astradb/tests/integration_tests/test_upgrade_to_graphvectorstore.py +++ b/libs/astradb/tests/integration_tests/test_upgrade_to_graphvectorstore.py @@ -14,9 +14,13 @@ from langchain_core.documents import Document from langchain_astradb.graph_vectorstores import AstraDBGraphVectorStore +from langchain_astradb.utils.astradb import SetupMode from langchain_astradb.vectorstores import AstraDBVectorStore from .conftest import ( + 
EPHEMERAL_ALLOW_IDX_NAME_D2, + EPHEMERAL_DEFAULT_IDX_NAME_D2, + EPHEMERAL_DENY_IDX_NAME_D2, astra_db_env_vars_available, ) @@ -26,121 +30,281 @@ from .conftest import AstraDBCredentials -@pytest.fixture -def default_vector_store_d2( - astra_db_credentials: AstraDBCredentials, - embedding_d2: Embeddings, - ephemeral_collection_cleaner_d2: str, -) -> AstraDBVectorStore: - return AstraDBVectorStore( - embedding=embedding_d2, - collection_name=ephemeral_collection_cleaner_d2, - token=StaticTokenProvider(astra_db_credentials["token"]), - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - ) +def _vs_indexing_policy(collection_name: str) -> dict[str, Any] | None: + if collection_name == EPHEMERAL_ALLOW_IDX_NAME_D2: + return {"allow": ["test"]} + if collection_name == EPHEMERAL_DEFAULT_IDX_NAME_D2: + return None + if collection_name == EPHEMERAL_DENY_IDX_NAME_D2: + return {"deny": ["test"]} + msg = f"Unknown collection_name: {collection_name} in _vs_indexing_policy()" + raise ValueError(msg) -@pytest.fixture -def vector_store_d2_with_indexing_allow_list( - astra_db_credentials: AstraDBCredentials, - embedding_d2: Embeddings, - ephemeral_collection_cleaner_d2: str, -) -> AstraDBVectorStore: - return AstraDBVectorStore( - embedding=embedding_d2, - collection_name=ephemeral_collection_cleaner_d2, - token=StaticTokenProvider(astra_db_credentials["token"]), - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - # this is the only difference from the `default_vector_store_d2` fixture above - collection_indexing_policy={"allow": ["test"]}, +@pytest.mark.skipif( + not astra_db_env_vars_available(), reason="Missing Astra DB env. vars" +) +class TestUpgradeToGraphVectorStore: + @pytest.mark.usefixtures("ephemeral_indexing_collections_cleaner") + @pytest.mark.parametrize( + ("collection_name", "gvs_setup_mode", "gvs_indexing_policy"), + [ + (EPHEMERAL_DEFAULT_IDX_NAME_D2, SetupMode.SYNC, None), + (EPHEMERAL_DENY_IDX_NAME_D2, SetupMode.SYNC, {"deny": ["test"]}), + (EPHEMERAL_DEFAULT_IDX_NAME_D2, SetupMode.OFF, None), + (EPHEMERAL_DENY_IDX_NAME_D2, SetupMode.OFF, {"deny": ["test"]}), + # for this one, even though the passed policy doesn't + # match the policy used to create the collection, + # there is no error since the SetupMode is OFF and + # and no attempt is made to re-create the collection. 
+ (EPHEMERAL_DENY_IDX_NAME_D2, SetupMode.OFF, None), + ], + ids=[ + "default_upgrade_no_policy_sync", + "deny_list_upgrade_same_policy_sync", + "default_upgrade_no_policy_off", + "deny_list_upgrade_same_policy_off", + "deny_list_upgrade_change_policy_off", + ], ) + def test_upgrade_to_gvs_success_sync( + self, + astra_db_credentials: AstraDBCredentials, + embedding_d2: Embeddings, + *, + gvs_setup_mode: SetupMode, + collection_name: str, + gvs_indexing_policy: dict[str, Any] | None, + ) -> None: + # Create vector store using SetupMode.SYNC + v_store = AstraDBVectorStore( + embedding=embedding_d2, + collection_name=collection_name, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + collection_indexing_policy=_vs_indexing_policy( + collection_name=collection_name + ), + setup_mode=SetupMode.SYNC, + ) + + # load a document to the vector store + doc_id = "AL" + doc_al = Document(id=doc_id, page_content="[-1, 9]", metadata={"label": "AL"}) + v_store.add_documents([doc_al]) + + # get the document from the vector store + v_doc = v_store.get_by_document_id(document_id=doc_id) + assert v_doc is not None + assert v_doc.page_content == doc_al.page_content + # Create a GRAPH Vector Store using the existing collection from above + # with setup_mode=gvs_setup_mode and indexing_policy=gvs_indexing_policy + gv_store = AstraDBGraphVectorStore( + embedding=embedding_d2, + collection_name=collection_name, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + collection_indexing_policy=gvs_indexing_policy, + setup_mode=gvs_setup_mode, + ) -@pytest.fixture -def vector_store_d2_with_indexing_deny_list( - astra_db_credentials: AstraDBCredentials, - embedding_d2: Embeddings, - ephemeral_collection_cleaner_d2: str, -) -> AstraDBVectorStore: - return AstraDBVectorStore( - embedding=embedding_d2, - collection_name=ephemeral_collection_cleaner_d2, - token=StaticTokenProvider(astra_db_credentials["token"]), - api_endpoint=astra_db_credentials["api_endpoint"], - namespace=astra_db_credentials["namespace"], - environment=astra_db_credentials["environment"], - # this is the only difference from the `default_vector_store_d2` fixture above - collection_indexing_policy={"deny": ["test"]}, + # get the document from the GRAPH vector store + gv_doc = gv_store.get_by_document_id(document_id=doc_id) + assert gv_doc is not None + assert gv_doc.page_content == doc_al.page_content + + @pytest.mark.usefixtures("ephemeral_indexing_collections_cleaner") + @pytest.mark.parametrize( + ("collection_name", "gvs_setup_mode", "gvs_indexing_policy"), + [ + (EPHEMERAL_DEFAULT_IDX_NAME_D2, SetupMode.ASYNC, None), + (EPHEMERAL_DENY_IDX_NAME_D2, SetupMode.ASYNC, {"deny": ["test"]}), + ], + ids=[ + "default_upgrade_no_policy_async", + "deny_list_upgrade_same_policy_async", + ], ) + async def test_upgrade_to_gvs_success_async( + self, + astra_db_credentials: AstraDBCredentials, + embedding_d2: Embeddings, + *, + gvs_setup_mode: SetupMode, + collection_name: str, + gvs_indexing_policy: dict[str, Any] | None, + ) -> None: + # Create vector store using SetupMode.ASYNC + v_store = AstraDBVectorStore( + embedding=embedding_d2, + collection_name=collection_name, + token=StaticTokenProvider(astra_db_credentials["token"]), + 
api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + collection_indexing_policy=_vs_indexing_policy( + collection_name=collection_name + ), + setup_mode=SetupMode.ASYNC, + ) + + # load a document to the vector store + doc_id = "AL" + doc_al = Document(id=doc_id, page_content="[-1, 9]", metadata={"label": "AL"}) + await v_store.aadd_documents([doc_al]) + # get the document from the vector store + v_doc = await v_store.aget_by_document_id(document_id=doc_id) + assert v_doc is not None + assert v_doc.page_content == doc_al.page_content -@pytest.mark.skipif( - not astra_db_env_vars_available(), reason="Missing Astra DB env. vars" -) -class TestUpgradeToGraphVectorStore: + # Create a GRAPH Vector Store using the existing collection from above + # with setup_mode=gvs_setup_mode and indexing_policy=gvs_indexing_policy + gv_store = AstraDBGraphVectorStore( + embedding=embedding_d2, + collection_name=collection_name, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + collection_indexing_policy=gvs_indexing_policy, + setup_mode=gvs_setup_mode, + ) + + # get the document from the GRAPH vector store + gv_doc = await gv_store.aget_by_document_id(document_id=doc_id) + assert gv_doc is not None + assert gv_doc.page_content == doc_al.page_content + + @pytest.mark.usefixtures("ephemeral_indexing_collections_cleaner") @pytest.mark.parametrize( - ("store_name", "indexing_policy", "expect_success"), + ("collection_name", "gvs_setup_mode", "gvs_indexing_policy"), [ - ("default_vector_store_d2", None, True), - ("vector_store_d2_with_indexing_allow_list", {"allow": ["test"]}, False), - ("vector_store_d2_with_indexing_allow_list", None, False), - ("vector_store_d2_with_indexing_deny_list", {"deny": ["test"]}, True), - ("vector_store_d2_with_indexing_deny_list", None, False), + (EPHEMERAL_ALLOW_IDX_NAME_D2, SetupMode.SYNC, {"allow": ["test"]}), + (EPHEMERAL_ALLOW_IDX_NAME_D2, SetupMode.SYNC, None), + (EPHEMERAL_DENY_IDX_NAME_D2, SetupMode.SYNC, None), + (EPHEMERAL_ALLOW_IDX_NAME_D2, SetupMode.OFF, {"allow": ["test"]}), + (EPHEMERAL_ALLOW_IDX_NAME_D2, SetupMode.OFF, None), ], ids=[ - "default_store_upgrade_should_succeed", - "allow_store_upgrade_with_allow_policy_should_fail", - "allow_store_upgrade_with_no_policy_should_fail", - "deny_store_upgrade_with_deny_policy_should_succeed", - "deny_store_upgrade_with_no_policy_should_fail", + "allow_list_upgrade_same_policy_sync", + "allow_list_upgrade_change_policy_sync", + "deny_list_upgrade_change_policy_sync", + "allow_list_upgrade_same_policy_off", + "allow_list_upgrade_change_policy_off", ], ) - def test_upgrade_to_gvs( + def test_upgrade_to_gvs_failure_sync( self, astra_db_credentials: AstraDBCredentials, embedding_d2: Embeddings, - ephemeral_collection_cleaner_d2: str, *, - store_name: str, - indexing_policy: dict[str, Any] | None, - expect_success: bool, - request: pytest.FixtureRequest, + gvs_setup_mode: SetupMode, + collection_name: str, + gvs_indexing_policy: dict[str, Any] | None, ) -> None: - # Create Vector Store, load a document - v_store: AstraDBVectorStore = request.getfixturevalue(store_name) + # Create vector store using SetupMode.SYNC + v_store = AstraDBVectorStore( + embedding=embedding_d2, + collection_name=collection_name, + token=StaticTokenProvider(astra_db_credentials["token"]), + 
api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + collection_indexing_policy=_vs_indexing_policy( + collection_name=collection_name + ), + setup_mode=SetupMode.SYNC, + ) + + # load a document to the vector store doc_id = "AL" doc_al = Document(id=doc_id, page_content="[-1, 9]", metadata={"label": "AL"}) v_store.add_documents([doc_al]) - # Try to create a GRAPH Vector Store using the existing collection from above - try: - gv_store = AstraDBGraphVectorStore( + # get the document from the vector store + v_doc = v_store.get_by_document_id(document_id=doc_id) + assert v_doc is not None + assert v_doc.page_content == doc_al.page_content + + expected_msg = "The collection configuration is incompatible with vector graph store. Please create a new collection." # noqa: E501 + with pytest.raises(ValueError, match=expected_msg): + # Create a GRAPH Vector Store using the existing collection from above + # with setup_mode=gvs_setup_mode and indexing_policy=gvs_indexing_policy + _ = AstraDBGraphVectorStore( embedding=embedding_d2, - collection_name=ephemeral_collection_cleaner_d2, + collection_name=collection_name, token=StaticTokenProvider(astra_db_credentials["token"]), api_endpoint=astra_db_credentials["api_endpoint"], namespace=astra_db_credentials["namespace"], environment=astra_db_credentials["environment"], - collection_indexing_policy=indexing_policy, + collection_indexing_policy=gvs_indexing_policy, + setup_mode=gvs_setup_mode, ) - if not expect_success: - pytest.fail("Expected ValueError but none was raised") - - except ValueError as value_error: - if expect_success: - pytest.fail(f"Unexpected ValueError raised: {value_error}") - else: - assert ( # noqa: PT017 - str(value_error) - == "The collection configuration is incompatible with vector graph store. Please create a new collection." 
# noqa: E501 - ) - - if expect_success: - doc = gv_store.get_by_document_id(document_id=doc_id) - assert doc is not None - assert doc.page_content == doc_al.page_content + @pytest.mark.usefixtures("ephemeral_indexing_collections_cleaner") + @pytest.mark.parametrize( + ("collection_name", "gvs_setup_mode", "gvs_indexing_policy"), + [ + (EPHEMERAL_ALLOW_IDX_NAME_D2, SetupMode.ASYNC, {"allow": ["test"]}), + (EPHEMERAL_ALLOW_IDX_NAME_D2, SetupMode.ASYNC, None), + (EPHEMERAL_DENY_IDX_NAME_D2, SetupMode.ASYNC, None), + ], + ids=[ + "allow_list_upgrade_same_policy_async", + "allow_list_upgrade_change_policy_async", + "deny_list_upgrade_change_policy_async", + ], + ) + async def test_upgrade_to_gvs_failure_async( + self, + astra_db_credentials: AstraDBCredentials, + embedding_d2: Embeddings, + *, + gvs_setup_mode: SetupMode, + collection_name: str, + gvs_indexing_policy: dict[str, Any] | None, + ) -> None: + # Create vector store using SetupMode.ASYNC + v_store = AstraDBVectorStore( + embedding=embedding_d2, + collection_name=collection_name, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + collection_indexing_policy=_vs_indexing_policy( + collection_name=collection_name + ), + setup_mode=SetupMode.ASYNC, + ) + + # load a document to the vector store + doc_id = "AL" + doc_al = Document(id=doc_id, page_content="[-1, 9]", metadata={"label": "AL"}) + await v_store.aadd_documents([doc_al]) + + # get the document from the vector store + v_doc = await v_store.aget_by_document_id(document_id=doc_id) + assert v_doc is not None + assert v_doc.page_content == doc_al.page_content + + expected_msg = "The collection configuration is incompatible with vector graph store. Please create a new collection." 
# noqa: E501 + with pytest.raises(ValueError, match=expected_msg): + # Create a GRAPH Vector Store using the existing collection from above + # with setup_mode=gvs_setup_mode and indexing_policy=gvs_indexing_policy + _ = AstraDBGraphVectorStore( + embedding=embedding_d2, + collection_name=collection_name, + token=StaticTokenProvider(astra_db_credentials["token"]), + api_endpoint=astra_db_credentials["api_endpoint"], + namespace=astra_db_credentials["namespace"], + environment=astra_db_credentials["environment"], + collection_indexing_policy=gvs_indexing_policy, + setup_mode=gvs_setup_mode, + ) diff --git a/libs/astradb/tests/integration_tests/test_vectorstore.py b/libs/astradb/tests/integration_tests/test_vectorstore.py index bd1a3ba..f707ec3 100644 --- a/libs/astradb/tests/integration_tests/test_vectorstore.py +++ b/libs/astradb/tests/integration_tests/test_vectorstore.py @@ -29,6 +29,43 @@ from .conftest import AstraDBCredentials +@pytest.fixture +def metadata_documents() -> list[Document]: + """Documents for metadata and id tests""" + return [ + Document( + id="q", + page_content="[1,2]", + metadata={"ord": ord("q"), "group": "consonant", "letter": "q"}, + ), + Document( + id="w", + page_content="[3,4]", + metadata={"ord": ord("w"), "group": "consonant", "letter": "w"}, + ), + Document( + id="r", + page_content="[5,6]", + metadata={"ord": ord("r"), "group": "consonant", "letter": "r"}, + ), + Document( + id="e", + page_content="[-1,2]", + metadata={"ord": ord("e"), "group": "vowel", "letter": "e"}, + ), + Document( + id="i", + page_content="[-3,4]", + metadata={"ord": ord("i"), "group": "vowel", "letter": "i"}, + ), + Document( + id="o", + page_content="[-5,6]", + metadata={"ord": ord("o"), "group": "vowel", "letter": "o"}, + ), + ] + + @pytest.mark.skipif( not astra_db_env_vars_available(), reason="Missing Astra DB env. 
vars" ) @@ -1100,41 +1137,15 @@ async def test_astradb_vectorstore_mmr_vectorize_async( "vector_store_vz", ], ) - def test_astradb_vectorstore_metadata( + def test_astradb_vectorstore_metadata_filter( self, vector_store: str, request: pytest.FixtureRequest, + metadata_documents: list[Document], ) -> None: """Metadata filtering.""" vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) - vstore.add_documents( - [ - Document( - page_content="[1,2]", - metadata={"ord": ord("q"), "group": "consonant", "letter": "q"}, - ), - Document( - page_content="[3,4]", - metadata={"ord": ord("w"), "group": "consonant", "letter": "w"}, - ), - Document( - page_content="[5,6]", - metadata={"ord": ord("r"), "group": "consonant", "letter": "r"}, - ), - Document( - page_content="[-1,2]", - metadata={"ord": ord("e"), "group": "vowel", "letter": "e"}, - ), - Document( - page_content="[-3,4]", - metadata={"ord": ord("i"), "group": "vowel", "letter": "i"}, - ), - Document( - page_content="[-5,6]", - metadata={"ord": ord("o"), "group": "vowel", "letter": "o"}, - ), - ] - ) + vstore.add_documents(metadata_documents) # no filters res0 = vstore.similarity_search("[-1,-1]", k=10) assert {doc.metadata["letter"] for doc in res0} == set("qwreio") @@ -1167,6 +1178,148 @@ def test_astradb_vectorstore_metadata( ) assert {doc.metadata["letter"] for doc in res4} == {"q", "r"} + @pytest.mark.parametrize( + "vector_store", + [ + "vector_store_d2", + "vector_store_vz", + ], + ) + def test_astradb_vectorstore_metadata_search_sync( + self, + vector_store: str, + request: pytest.FixtureRequest, + metadata_documents: list[Document], + ) -> None: + """Metadata Search""" + vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) + vstore.add_documents(metadata_documents) + # no filters + res0 = vstore.metadata_search(filter={}, n=10) + assert {doc.metadata["letter"] for doc in res0} == set("qwreio") + # single filter + res1 = vstore.metadata_search( + n=10, + filter={"group": "vowel"}, + ) + assert {doc.metadata["letter"] for doc in res1} == set("eio") + # multiple filters + res2 = vstore.metadata_search( + n=10, + filter={"group": "consonant", "ord": ord("q")}, + ) + assert {doc.metadata["letter"] for doc in res2} == set("q") + # excessive filters + res3 = vstore.metadata_search( + n=10, + filter={"group": "consonant", "ord": ord("q"), "case": "upper"}, + ) + assert res3 == [] + # filter with logical operator + res4 = vstore.metadata_search( + n=10, + filter={"$or": [{"ord": ord("q")}, {"ord": ord("r")}]}, + ) + assert {doc.metadata["letter"] for doc in res4} == {"q", "r"} + + @pytest.mark.parametrize( + "vector_store", + [ + "vector_store_d2", + "vector_store_vz", + ], + ) + async def test_astradb_vectorstore_metadata_search_async( + self, + vector_store: str, + request: pytest.FixtureRequest, + metadata_documents: list[Document], + ) -> None: + """Metadata Search""" + vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) + await vstore.aadd_documents(metadata_documents) + # no filters + res0 = await vstore.ametadata_search(filter={}, n=10) + assert {doc.metadata["letter"] for doc in res0} == set("qwreio") + # single filter + res1 = vstore.metadata_search( + n=10, + filter={"group": "vowel"}, + ) + assert {doc.metadata["letter"] for doc in res1} == set("eio") + # multiple filters + res2 = await vstore.ametadata_search( + n=10, + filter={"group": "consonant", "ord": ord("q")}, + ) + assert {doc.metadata["letter"] for doc in res2} == set("q") + # excessive filters + res3 = await 
vstore.ametadata_search( + n=10, + filter={"group": "consonant", "ord": ord("q"), "case": "upper"}, + ) + assert res3 == [] + # filter with logical operator + res4 = await vstore.ametadata_search( + n=10, + filter={"$or": [{"ord": ord("q")}, {"ord": ord("r")}]}, + ) + assert {doc.metadata["letter"] for doc in res4} == {"q", "r"} + + @pytest.mark.parametrize( + "vector_store", + [ + "vector_store_d2", + "vector_store_vz", + ], + ) + def test_astradb_vectorstore_get_by_document_id_sync( + self, + vector_store: str, + request: pytest.FixtureRequest, + metadata_documents: list[Document], + ) -> None: + """Get by document_id""" + vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) + vstore.add_documents(metadata_documents) + # invalid id + invalid = vstore.get_by_document_id(document_id="z") + assert invalid is None + # valid id + valid = vstore.get_by_document_id(document_id="q") + assert isinstance(valid, Document) + assert valid.id == "q" + assert valid.page_content == "[1,2]" + assert valid.metadata["group"] == "consonant" + assert valid.metadata["letter"] == "q" + + @pytest.mark.parametrize( + "vector_store", + [ + "vector_store_d2", + "vector_store_vz", + ], + ) + async def test_astradb_vectorstore_get_by_document_id_async( + self, + vector_store: str, + request: pytest.FixtureRequest, + metadata_documents: list[Document], + ) -> None: + """Get by document_id""" + vstore: AstraDBVectorStore = request.getfixturevalue(vector_store) + await vstore.aadd_documents(metadata_documents) + # invalid id + invalid = await vstore.aget_by_document_id(document_id="z") + assert invalid is None + # valid id + valid = await vstore.aget_by_document_id(document_id="q") + assert isinstance(valid, Document) + assert valid.id == "q" + assert valid.page_content == "[1,2]" + assert valid.metadata["group"] == "consonant" + assert valid.metadata["letter"] == "q" + @pytest.mark.parametrize( ("is_vectorize", "vector_store", "texts", "query"), [ From 26e4475b5e3f30ee418b24e03058baf7545c8060 Mon Sep 17 00:00:00 2001 From: Eric Pinzur Date: Tue, 15 Oct 2024 16:46:00 +0200 Subject: [PATCH 07/11] improve error msg --- libs/astradb/langchain_astradb/graph_vectorstores.py | 9 +++++++-- .../test_upgrade_to_graphvectorstore.py | 12 ++++++++++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/libs/astradb/langchain_astradb/graph_vectorstores.py b/libs/astradb/langchain_astradb/graph_vectorstores.py index 69ebd44..322793d 100644 --- a/libs/astradb/langchain_astradb/graph_vectorstores.py +++ b/libs/astradb/langchain_astradb/graph_vectorstores.py @@ -342,14 +342,19 @@ def __init__( test_vs.metadata_search( filter={self.metadata_incoming_links_key: "test"}, n=1 ) - except BaseException as exp: + except ValueError as exp: # determine if error is because of a un-indexed column. Ref: # https://docs.datastax.com/en/astra-db-serverless/api-reference/collections.html#considerations-for-selective-indexing error_message = str(exp).lower() if ("unindexed filter path" in error_message) or ( "incompatible with the requested indexing policy" in error_message ): - msg = "The collection configuration is incompatible with vector graph store. Please create a new collection." # noqa: E501 + msg = ( + "The collection configuration is incompatible with vector graph " + "store. Please create a new collection and make sure the path " + f"`{self.metadata_incoming_links_key}` is not excluded by indexing." 
+ ) + raise ValueError(msg) from exp raise exp # noqa: TRY201 diff --git a/libs/astradb/tests/integration_tests/test_upgrade_to_graphvectorstore.py b/libs/astradb/tests/integration_tests/test_upgrade_to_graphvectorstore.py index 5238f0b..3164513 100644 --- a/libs/astradb/tests/integration_tests/test_upgrade_to_graphvectorstore.py +++ b/libs/astradb/tests/integration_tests/test_upgrade_to_graphvectorstore.py @@ -232,7 +232,11 @@ def test_upgrade_to_gvs_failure_sync( assert v_doc is not None assert v_doc.page_content == doc_al.page_content - expected_msg = "The collection configuration is incompatible with vector graph store. Please create a new collection." # noqa: E501 + expected_msg = ( + "The collection configuration is incompatible with vector graph " + "store. Please create a new collection and make sure the path " + "`incoming_links` is not excluded by indexing." + ) with pytest.raises(ValueError, match=expected_msg): # Create a GRAPH Vector Store using the existing collection from above # with setup_mode=gvs_setup_mode and indexing_policy=gvs_indexing_policy @@ -294,7 +298,11 @@ async def test_upgrade_to_gvs_failure_async( assert v_doc is not None assert v_doc.page_content == doc_al.page_content - expected_msg = "The collection configuration is incompatible with vector graph store. Please create a new collection." # noqa: E501 + expected_msg = ( + "The collection configuration is incompatible with vector graph " + "store. Please create a new collection and make sure the path " + "`incoming_links` is not excluded by indexing." + ) with pytest.raises(ValueError, match=expected_msg): # Create a GRAPH Vector Store using the existing collection from above # with setup_mode=gvs_setup_mode and indexing_policy=gvs_indexing_policy From a0a8262550223975d60aa16e19004fc6f769aa84 Mon Sep 17 00:00:00 2001 From: Eric Pinzur Date: Tue, 15 Oct 2024 17:15:39 +0200 Subject: [PATCH 08/11] fixed unit test --- libs/astradb/tests/unit_tests/test_callers.py | 63 +++++++++++++++++-- 1 file changed, 57 insertions(+), 6 deletions(-) diff --git a/libs/astradb/tests/unit_tests/test_callers.py b/libs/astradb/tests/unit_tests/test_callers.py index 63a80d8..f23ec48 100644 --- a/libs/astradb/tests/unit_tests/test_callers.py +++ b/libs/astradb/tests/unit_tests/test_callers.py @@ -294,6 +294,63 @@ def test_callers_component_vectorstore(self, httpserver: HTTPServer) -> None: ext_callers=[("ec0", "ev0")], ) + def test_callers_component_graphvectorstore(self, httpserver: HTTPServer) -> None: + """ + End-to-end testing of callers passed through the components. + The graphvectorstore, which can also do autodetect operations, + requires separate handling. 
+ """ + base_endpoint = httpserver.url_for("/") + base_path = "/v1/ks" + + # through the init flow + httpserver.expect_oneshot_request( + base_path, + method="POST", + headers={ + "User-Agent": "ec0/ev0", + }, + header_value_matcher=hv_prefix_matcher_factory( + COMPONENT_NAME_GRAPHVECTORSTORE + ), + ).respond_with_json({}) + + # the metadata_search test call + httpserver.expect_oneshot_request( + base_path + "/my_graph_coll", + method="POST", + headers={ + "User-Agent": "ec0/ev0", + }, + header_value_matcher=hv_prefix_matcher_factory( + COMPONENT_NAME_GRAPHVECTORSTORE + ), + ).respond_with_json( + { + "status": { + "collections": [ + { + "name": "my_graph_coll", + "options": {"vector": {"dimension": 2}}, + } + ] + }, + "data": { + "nextPageState": None, + "documents": [], + }, + } + ) + + AstraDBGraphVectorStore( + collection_name="my_graph_coll", + api_endpoint=base_endpoint, + environment=Environment.OTHER, + namespace="ks", + embedding=ParserEmbeddings(2), + ext_callers=[("ec0", "ev0")], + ) + @pytest.mark.parametrize( ("component_class", "component_name", "kwargs"), [ @@ -304,11 +361,6 @@ def test_callers_component_vectorstore(self, httpserver: HTTPServer) -> None: COMPONENT_NAME_CHATMESSAGEHISTORY, {"session_id": "x"}, ), - ( - AstraDBGraphVectorStore, - COMPONENT_NAME_GRAPHVECTORSTORE, - {"embedding": ParserEmbeddings(2)}, - ), ( AstraDBSemanticCache, COMPONENT_NAME_SEMANTICCACHE, @@ -320,7 +372,6 @@ def test_callers_component_vectorstore(self, httpserver: HTTPServer) -> None: "Byte store", "Cache", "Chat message history", - "Graph vector store", "Semantic cache", "Store", ], From cb9f56c4265506549a8f3971ee39660b4c6e46ed Mon Sep 17 00:00:00 2001 From: Eric Pinzur Date: Tue, 15 Oct 2024 17:48:06 +0200 Subject: [PATCH 09/11] added test of asimilarity_search_with_embedding_id_by_vector --- libs/astradb/langchain_astradb/utils/astradb.py | 2 +- .../tests/integration_tests/test_vectorstore.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/libs/astradb/langchain_astradb/utils/astradb.py b/libs/astradb/langchain_astradb/utils/astradb.py index 442c53b..66d88dd 100644 --- a/libs/astradb/langchain_astradb/utils/astradb.py +++ b/libs/astradb/langchain_astradb/utils/astradb.py @@ -366,7 +366,7 @@ def __init__( self.database.drop_collection(collection_name) if inspect.isawaitable(embedding_dimension): msg = ( - "Cannot use an awaitable embedding_dimension with sync_setup " + "Cannot use an awaitable embedding_dimension with async_setup " "set to False" ) raise ValueError(msg) diff --git a/libs/astradb/tests/integration_tests/test_vectorstore.py b/libs/astradb/tests/integration_tests/test_vectorstore.py index f707ec3..6a6a482 100644 --- a/libs/astradb/tests/integration_tests/test_vectorstore.py +++ b/libs/astradb/tests/integration_tests/test_vectorstore.py @@ -824,6 +824,7 @@ async def test_astradb_vectorstore_massive_insert_replace_async( all_ids = [f"doc_{idx}" for idx in range(full_size)] all_texts = [f"[0,{idx + 1}]" for idx in range(full_size)] + all_embeddings = [[0, idx + 1] for idx in range(full_size)] # massive insertion on empty group0_ids = all_ids[0:first_group_size] @@ -855,6 +856,16 @@ async def test_astradb_vectorstore_massive_insert_replace_async( ) for doc, _, doc_id in full_results: assert doc.page_content == expected_text_by_id[doc_id] + expected_embedding_by_id = dict(zip(all_ids, all_embeddings)) + full_results_with_embeddings = ( + await vector_store_d2.asimilarity_search_with_embedding_id_by_vector( + embedding=[1.0, 1.0], + k=full_size, + ) + 
) + for doc, embedding, doc_id in full_results_with_embeddings: + assert doc.page_content == expected_text_by_id[doc_id] + assert embedding == expected_embedding_by_id[doc_id] def test_astradb_vectorstore_delete_by_metadata_sync( self, From 01691cf8060a8d8860e99a991e582c6931f5afda Mon Sep 17 00:00:00 2001 From: Eric Pinzur Date: Wed, 16 Oct 2024 16:43:55 +0200 Subject: [PATCH 10/11] made suggested fixes --- libs/astradb/langchain_astradb/graph_vectorstores.py | 7 +++---- .../integration_tests/test_upgrade_to_graphvectorstore.py | 8 ++++---- libs/astradb/tests/unit_tests/test_callers.py | 8 -------- 3 files changed, 7 insertions(+), 16 deletions(-) diff --git a/libs/astradb/langchain_astradb/graph_vectorstores.py b/libs/astradb/langchain_astradb/graph_vectorstores.py index 322793d..d387851 100644 --- a/libs/astradb/langchain_astradb/graph_vectorstores.py +++ b/libs/astradb/langchain_astradb/graph_vectorstores.py @@ -232,8 +232,7 @@ def __init__( component_name: the string identifying this specific component in the stack of usage info passed as the User-Agent string to the Data API. Defaults to "langchain_graphvectorstore", but can be overridden if this - component actually serves as the building block for another component - (such as a Graph Vector Store). + component actually serves as the building block for another component. astra_db_client: *DEPRECATED starting from version 0.3.5.* *Please use 'token', 'api_endpoint' and optionally 'environment'.* @@ -351,8 +350,8 @@ def __init__( ): msg = ( "The collection configuration is incompatible with vector graph " - "store. Please create a new collection and make sure the path " - f"`{self.metadata_incoming_links_key}` is not excluded by indexing." + "store. Please create a new collection and make sure the metadata " + f"path is not excluded by indexing." ) raise ValueError(msg) from exp diff --git a/libs/astradb/tests/integration_tests/test_upgrade_to_graphvectorstore.py b/libs/astradb/tests/integration_tests/test_upgrade_to_graphvectorstore.py index 3164513..c3a5fe5 100644 --- a/libs/astradb/tests/integration_tests/test_upgrade_to_graphvectorstore.py +++ b/libs/astradb/tests/integration_tests/test_upgrade_to_graphvectorstore.py @@ -234,8 +234,8 @@ def test_upgrade_to_gvs_failure_sync( expected_msg = ( "The collection configuration is incompatible with vector graph " - "store. Please create a new collection and make sure the path " - "`incoming_links` is not excluded by indexing." + "store. Please create a new collection and make sure the metadata " + f"path is not excluded by indexing." ) with pytest.raises(ValueError, match=expected_msg): # Create a GRAPH Vector Store using the existing collection from above @@ -300,8 +300,8 @@ async def test_upgrade_to_gvs_failure_async( expected_msg = ( "The collection configuration is incompatible with vector graph " - "store. Please create a new collection and make sure the path " - "`incoming_links` is not excluded by indexing." + "store. Please create a new collection and make sure the metadata " + f"path is not excluded by indexing." 
) with pytest.raises(ValueError, match=expected_msg): # Create a GRAPH Vector Store using the existing collection from above diff --git a/libs/astradb/tests/unit_tests/test_callers.py b/libs/astradb/tests/unit_tests/test_callers.py index f23ec48..c511c28 100644 --- a/libs/astradb/tests/unit_tests/test_callers.py +++ b/libs/astradb/tests/unit_tests/test_callers.py @@ -327,14 +327,6 @@ def test_callers_component_graphvectorstore(self, httpserver: HTTPServer) -> Non ), ).respond_with_json( { - "status": { - "collections": [ - { - "name": "my_graph_coll", - "options": {"vector": {"dimension": 2}}, - } - ] - }, "data": { "nextPageState": None, "documents": [], From cf5fe11172e1d9c9078b88b2d3a9b869fc607d8e Mon Sep 17 00:00:00 2001 From: Eric Pinzur Date: Wed, 16 Oct 2024 17:06:11 +0200 Subject: [PATCH 11/11] fix lint --- libs/astradb/langchain_astradb/graph_vectorstores.py | 2 +- .../integration_tests/test_upgrade_to_graphvectorstore.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/libs/astradb/langchain_astradb/graph_vectorstores.py b/libs/astradb/langchain_astradb/graph_vectorstores.py index d387851..0556d18 100644 --- a/libs/astradb/langchain_astradb/graph_vectorstores.py +++ b/libs/astradb/langchain_astradb/graph_vectorstores.py @@ -351,7 +351,7 @@ def __init__( msg = ( "The collection configuration is incompatible with vector graph " "store. Please create a new collection and make sure the metadata " - f"path is not excluded by indexing." + "path is not excluded by indexing." ) raise ValueError(msg) from exp diff --git a/libs/astradb/tests/integration_tests/test_upgrade_to_graphvectorstore.py b/libs/astradb/tests/integration_tests/test_upgrade_to_graphvectorstore.py index c3a5fe5..90561ba 100644 --- a/libs/astradb/tests/integration_tests/test_upgrade_to_graphvectorstore.py +++ b/libs/astradb/tests/integration_tests/test_upgrade_to_graphvectorstore.py @@ -235,7 +235,7 @@ def test_upgrade_to_gvs_failure_sync( expected_msg = ( "The collection configuration is incompatible with vector graph " "store. Please create a new collection and make sure the metadata " - f"path is not excluded by indexing." + "path is not excluded by indexing." ) with pytest.raises(ValueError, match=expected_msg): # Create a GRAPH Vector Store using the existing collection from above @@ -301,7 +301,7 @@ async def test_upgrade_to_gvs_failure_async( expected_msg = ( "The collection configuration is incompatible with vector graph " "store. Please create a new collection and make sure the metadata " - f"path is not excluded by indexing." + "path is not excluded by indexing." ) with pytest.raises(ValueError, match=expected_msg): # Create a GRAPH Vector Store using the existing collection from above