Skip to content

Commit

Permalink
Update AstraDBGraphVectorStore to match implementation of CassandraGraphVectorStore (#95)
Browse files Browse the repository at this point in the history

* updated graph to match cassandraGraphVectorStore

* fix tests

* some fixes

* added initial upgrade test

* simplified insertion

* improved testing

* improve error msg

* fixed unit test

* added test of asimilarity_search_with_embedding_id_by_vector

* made suggested fixes

* fix lint
  • Loading branch information
epinzur authored Oct 16, 2024
1 parent 31c25d4 commit 89250df
Show file tree
Hide file tree
Showing 12 changed files with 1,993 additions and 542 deletions.
1,216 changes: 900 additions & 316 deletions libs/astradb/langchain_astradb/graph_vectorstores.py

Large diffs are not rendered by default.

41 changes: 23 additions & 18 deletions libs/astradb/langchain_astradb/utils/astradb.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,18 +380,20 @@ def __init__(
service=collection_vector_service_options,
check_exists=False,
)
except DataAPIException:
except DataAPIException as data_api_exception:
# possibly the collection is preexisting and may have legacy,
# or custom, indexing settings: verify
collection_descriptors = list(self.database.list_collections())
if not self._validate_indexing_policy(
collection_descriptors=collection_descriptors,
collection_name=self.collection_name,
requested_indexing_policy=requested_indexing_policy,
default_indexing_policy=default_indexing_policy,
):
# other reasons for the exception
raise
try:
if not self._validate_indexing_policy(
collection_descriptors=collection_descriptors,
collection_name=self.collection_name,
requested_indexing_policy=requested_indexing_policy,
default_indexing_policy=default_indexing_policy,
):
raise data_api_exception # noqa: TRY201
except ValueError as validation_error:
raise validation_error from data_api_exception

async def _asetup_db(
self,
Expand Down Expand Up @@ -420,20 +422,23 @@ async def _asetup_db(
service=collection_vector_service_options,
check_exists=False,
)
except DataAPIException:
except DataAPIException as data_api_exception:
# possibly the collection is preexisting and may have legacy,
# or custom, indexing settings: verify
collection_descriptors = [
coll_desc async for coll_desc in self.async_database.list_collections()
]
if not self._validate_indexing_policy(
collection_descriptors=collection_descriptors,
collection_name=self.collection_name,
requested_indexing_policy=requested_indexing_policy,
default_indexing_policy=default_indexing_policy,
):
# other reasons for the exception
raise
try:
if not self._validate_indexing_policy(
collection_descriptors=collection_descriptors,
collection_name=self.collection_name,
requested_indexing_policy=requested_indexing_policy,
default_indexing_policy=default_indexing_policy,
):
# other reasons for the exception
raise data_api_exception # noqa: TRY201
except ValueError as validation_error:
raise validation_error from data_api_exception

@staticmethod
def _validate_indexing_policy(
Expand Down
110 changes: 0 additions & 110 deletions libs/astradb/langchain_astradb/utils/mmr.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@
from typing import TYPE_CHECKING, Iterable

import numpy as np

from langchain_astradb.utils.mmr import cosine_similarity
from langchain_community.utils.math import cosine_similarity

if TYPE_CHECKING:
from langchain_core.documents import Document
from numpy.typing import NDArray


Expand All @@ -27,6 +25,7 @@ def _emb_to_ndarray(embedding: list[float]) -> NDArray[np.float32]:
@dataclasses.dataclass
class _Candidate:
id: str
similarity: float
weighted_similarity: float
weighted_redundancy: float
score: float = dataclasses.field(init=False)
Expand Down Expand Up @@ -72,6 +71,13 @@ class MmrHelper:

selected_ids: list[str]
"""List of selected IDs (in selection order)."""

selected_mmr_scores: list[float]
"""List of MMR score at the time each document is selected."""

selected_similarity_scores: list[float]
"""List of similarity score for each selected document."""

selected_embeddings: NDArray[np.float32]
"""(N, dim) ndarray with a row for each selected node."""

Expand All @@ -82,8 +88,6 @@ class MmrHelper:
Same order as rows in `candidate_embeddings`.
"""
candidate_docs: dict[str, Document]
"""Dict containing the documents associated with each candidate ID."""
candidate_embeddings: NDArray[np.float32]
"""(N, dim) ndarray with a row for each candidate."""

Expand All @@ -106,12 +110,13 @@ def __init__(
self.score_threshold = score_threshold

self.selected_ids = []
self.selected_similarity_scores = []
self.selected_mmr_scores = []

# List of selected embeddings (in selection order).
self.selected_embeddings = np.ndarray((k, self.dimensions), dtype=np.float32)

self.candidate_id_to_index = {}
self.candidate_docs = {}

# List of the candidates.
self.candidates = []
Expand All @@ -130,11 +135,11 @@ def _already_selected_embeddings(self) -> NDArray[np.float32]:
selected = len(self.selected_ids)
return np.vsplit(self.selected_embeddings, [selected])[0]

def _pop_candidate(self, candidate_id: str) -> NDArray[np.float32]:
def _pop_candidate(self, candidate_id: str) -> tuple[float, NDArray[np.float32]]:
"""Pop the candidate with the given ID.
Returns:
The embedding of the candidate.
The similarity score and embedding of the candidate.
"""
# Get the embedding for the id.
index = self.candidate_id_to_index.pop(candidate_id)
Expand All @@ -150,12 +155,15 @@ def _pop_candidate(self, candidate_id: str) -> NDArray[np.float32]:
# candidate_embeddings.
last_index = self.candidate_embeddings.shape[0] - 1

similarity = 0.0
if index == last_index:
# Already the last item. We don't need to swap.
self.candidates.pop()
similarity = self.candidates.pop().similarity
else:
self.candidate_embeddings[index] = self.candidate_embeddings[last_index]

similarity = self.candidates[index].similarity

old_last = self.candidates.pop()
self.candidates[index] = old_last
self.candidate_id_to_index[old_last.id] = index
Expand All @@ -164,7 +172,7 @@ def _pop_candidate(self, candidate_id: str) -> NDArray[np.float32]:
0
]

return embedding
return similarity, embedding

def pop_best(self) -> str | None:
"""Select and pop the best item being considered.
Expand All @@ -179,11 +187,13 @@ def pop_best(self) -> str | None:

# Get the selection and remove from candidates.
selected_id = self.best_id
selected_embedding = self._pop_candidate(selected_id)
selected_similarity, selected_embedding = self._pop_candidate(selected_id)

# Add the ID and embedding to the selected information.
selection_index = len(self.selected_ids)
self.selected_ids.append(selected_id)
self.selected_mmr_scores.append(self.best_score)
self.selected_similarity_scores.append(selected_similarity)
self.selected_embeddings[selection_index] = selected_embedding

# Reset the best score / best ID.
Expand All @@ -203,9 +213,7 @@ def pop_best(self) -> str | None:

return selected_id

def add_candidates(
self, candidates: dict[str, tuple[Document, list[float]]]
) -> None:
def add_candidates(self, candidates: dict[str, list[float]]) -> None:
"""Add candidates to the consideration set."""
# Determine the keys to actually include.
# These are the candidates that aren't already selected
Expand All @@ -227,9 +235,8 @@ def add_candidates(
for index, candidate_id in enumerate(include_ids):
if candidate_id in include_ids:
self.candidate_id_to_index[candidate_id] = offset + index
doc, embedding = candidates[candidate_id]
embedding = candidates[candidate_id]
new_embeddings[index] = embedding
self.candidate_docs[candidate_id] = doc

# Compute the similarity to the query.
similarity = cosine_similarity(new_embeddings, self.query_embedding)
Expand All @@ -245,6 +252,7 @@ def add_candidates(
max_redundancy = redundancy[index].max()
candidate = _Candidate(
id=candidate_id,
similarity=similarity[index][0],
weighted_similarity=self.lambda_mult * similarity[index][0],
weighted_redundancy=self.lambda_mult_complement * max_redundancy,
)
Expand Down
Loading

0 comments on commit 89250df

Please sign in to comment.