Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update AstraDBGraphVectorStore to match implementation of CassandraGraphVectorStore #95

Merged
merged 11 commits into from
Oct 16, 2024
1,216 changes: 900 additions & 316 deletions libs/astradb/langchain_astradb/graph_vectorstores.py

Large diffs are not rendered by default.

41 changes: 23 additions & 18 deletions libs/astradb/langchain_astradb/utils/astradb.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,18 +380,20 @@ def __init__(
service=collection_vector_service_options,
check_exists=False,
)
except DataAPIException:
except DataAPIException as data_api_exception:
# possibly the collection is preexisting and may have legacy,
# or custom, indexing settings: verify
collection_descriptors = list(self.database.list_collections())
if not self._validate_indexing_policy(
collection_descriptors=collection_descriptors,
collection_name=self.collection_name,
requested_indexing_policy=requested_indexing_policy,
default_indexing_policy=default_indexing_policy,
):
# other reasons for the exception
raise
try:
if not self._validate_indexing_policy(
collection_descriptors=collection_descriptors,
collection_name=self.collection_name,
requested_indexing_policy=requested_indexing_policy,
default_indexing_policy=default_indexing_policy,
):
raise data_api_exception # noqa: TRY201
except ValueError as validation_error:
raise validation_error from data_api_exception
hemidactylus marked this conversation as resolved.
Show resolved Hide resolved

async def _asetup_db(
self,
Expand Down Expand Up @@ -420,20 +422,23 @@ async def _asetup_db(
service=collection_vector_service_options,
check_exists=False,
)
except DataAPIException:
except DataAPIException as data_api_exception:
# possibly the collection is preexisting and may have legacy,
# or custom, indexing settings: verify
collection_descriptors = [
coll_desc async for coll_desc in self.async_database.list_collections()
]
if not self._validate_indexing_policy(
collection_descriptors=collection_descriptors,
collection_name=self.collection_name,
requested_indexing_policy=requested_indexing_policy,
default_indexing_policy=default_indexing_policy,
):
# other reasons for the exception
raise
try:
if not self._validate_indexing_policy(
collection_descriptors=collection_descriptors,
collection_name=self.collection_name,
requested_indexing_policy=requested_indexing_policy,
default_indexing_policy=default_indexing_policy,
):
# other reasons for the exception
raise data_api_exception # noqa: TRY201
except ValueError as validation_error:
raise validation_error from data_api_exception

@staticmethod
def _validate_indexing_policy(
Expand Down
110 changes: 0 additions & 110 deletions libs/astradb/langchain_astradb/utils/mmr.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@
from typing import TYPE_CHECKING, Iterable

import numpy as np

from langchain_astradb.utils.mmr import cosine_similarity
from langchain_community.utils.math import cosine_similarity

if TYPE_CHECKING:
from langchain_core.documents import Document
from numpy.typing import NDArray


Expand All @@ -27,6 +25,7 @@ def _emb_to_ndarray(embedding: list[float]) -> NDArray[np.float32]:
@dataclasses.dataclass
class _Candidate:
id: str
similarity: float
weighted_similarity: float
weighted_redundancy: float
score: float = dataclasses.field(init=False)
Expand Down Expand Up @@ -72,6 +71,13 @@ class MmrHelper:

selected_ids: list[str]
"""List of selected IDs (in selection order)."""

selected_mmr_scores: list[float]
"""List of MMR score at the time each document is selected."""

selected_similarity_scores: list[float]
"""List of similarity score for each selected document."""

selected_embeddings: NDArray[np.float32]
"""(N, dim) ndarray with a row for each selected node."""

Expand All @@ -82,8 +88,6 @@ class MmrHelper:

Same order as rows in `candidate_embeddings`.
"""
candidate_docs: dict[str, Document]
"""Dict containing the documents associated with each candidate ID."""
candidate_embeddings: NDArray[np.float32]
"""(N, dim) ndarray with a row for each candidate."""

Expand All @@ -106,12 +110,13 @@ def __init__(
self.score_threshold = score_threshold

self.selected_ids = []
self.selected_similarity_scores = []
self.selected_mmr_scores = []

# List of selected embeddings (in selection order).
self.selected_embeddings = np.ndarray((k, self.dimensions), dtype=np.float32)

self.candidate_id_to_index = {}
self.candidate_docs = {}

# List of the candidates.
self.candidates = []
Expand All @@ -130,11 +135,11 @@ def _already_selected_embeddings(self) -> NDArray[np.float32]:
selected = len(self.selected_ids)
return np.vsplit(self.selected_embeddings, [selected])[0]

def _pop_candidate(self, candidate_id: str) -> NDArray[np.float32]:
def _pop_candidate(self, candidate_id: str) -> tuple[float, NDArray[np.float32]]:
"""Pop the candidate with the given ID.

Returns:
The embedding of the candidate.
The similarity score and embedding of the candidate.
"""
# Get the embedding for the id.
index = self.candidate_id_to_index.pop(candidate_id)
Expand All @@ -150,12 +155,15 @@ def _pop_candidate(self, candidate_id: str) -> NDArray[np.float32]:
# candidate_embeddings.
last_index = self.candidate_embeddings.shape[0] - 1

similarity = 0.0
if index == last_index:
# Already the last item. We don't need to swap.
self.candidates.pop()
similarity = self.candidates.pop().similarity
else:
self.candidate_embeddings[index] = self.candidate_embeddings[last_index]

similarity = self.candidates[index].similarity

old_last = self.candidates.pop()
self.candidates[index] = old_last
self.candidate_id_to_index[old_last.id] = index
Expand All @@ -164,7 +172,7 @@ def _pop_candidate(self, candidate_id: str) -> NDArray[np.float32]:
0
]

return embedding
return similarity, embedding

def pop_best(self) -> str | None:
"""Select and pop the best item being considered.
Expand All @@ -179,11 +187,13 @@ def pop_best(self) -> str | None:

# Get the selection and remove from candidates.
selected_id = self.best_id
selected_embedding = self._pop_candidate(selected_id)
selected_similarity, selected_embedding = self._pop_candidate(selected_id)

# Add the ID and embedding to the selected information.
selection_index = len(self.selected_ids)
self.selected_ids.append(selected_id)
self.selected_mmr_scores.append(self.best_score)
self.selected_similarity_scores.append(selected_similarity)
self.selected_embeddings[selection_index] = selected_embedding

# Reset the best score / best ID.
Expand All @@ -203,9 +213,7 @@ def pop_best(self) -> str | None:

return selected_id

def add_candidates(
self, candidates: dict[str, tuple[Document, list[float]]]
) -> None:
def add_candidates(self, candidates: dict[str, list[float]]) -> None:
"""Add candidates to the consideration set."""
# Determine the keys to actually include.
# These are the candidates that aren't already selected
Expand All @@ -227,9 +235,8 @@ def add_candidates(
for index, candidate_id in enumerate(include_ids):
if candidate_id in include_ids:
self.candidate_id_to_index[candidate_id] = offset + index
doc, embedding = candidates[candidate_id]
embedding = candidates[candidate_id]
new_embeddings[index] = embedding
self.candidate_docs[candidate_id] = doc

# Compute the similarity to the query.
similarity = cosine_similarity(new_embeddings, self.query_embedding)
Expand All @@ -245,6 +252,7 @@ def add_candidates(
max_redundancy = redundancy[index].max()
candidate = _Candidate(
id=candidate_id,
similarity=similarity[index][0],
weighted_similarity=self.lambda_mult * similarity[index][0],
weighted_redundancy=self.lambda_mult_complement * max_redundancy,
)
Expand Down
Loading
Loading