Skip to content

Commit

Permalink
community[patch]: Use newer MetadataVectorCassandraTable in Cassandra…
Browse files Browse the repository at this point in the history
… vector store (#15987)

as VectorTable is deprecated

Tested manually with `test_cassandra.py` vector store integration test.
  • Loading branch information
cbornet authored Jan 17, 2024
1 parent 1fa056c commit fb940d1
Showing 1 changed file with 30 additions and 25 deletions.
55 changes: 30 additions & 25 deletions libs/community/langchain_community/vectorstores/cassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(
ttl_seconds: Optional[int] = None,
) -> None:
try:
from cassio.vector import VectorTable
from cassio.table import MetadataVectorCassandraTable
except (ImportError, ModuleNotFoundError):
raise ImportError(
"Could not import cassio python package. "
Expand All @@ -90,11 +90,12 @@ def __init__(
#
self._embedding_dimension = None
#
self.table = VectorTable(
self.table = MetadataVectorCassandraTable(
session=session,
keyspace=keyspace,
table=table_name,
embedding_dimension=self._get_embedding_dimension(),
vector_dimension=self._get_embedding_dimension(),
metadata_indexing="all",
primary_key_type="TEXT",
)

Expand Down Expand Up @@ -127,7 +128,7 @@ def clear(self) -> None:
self.table.clear()

def delete_by_document_id(self, document_id: str) -> None:
return self.table.delete(document_id)
return self.table.delete(row_id=document_id)

def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
"""Delete by vector IDs.
Expand Down Expand Up @@ -188,7 +189,11 @@ def add_texts(

futures = [
self.table.put_async(
text, embedding_vector, text_id, metadata, ttl_seconds
row_id=text_id,
body_blob=text,
vector=embedding_vector,
metadata=metadata or {},
ttl_seconds=ttl_seconds,
)
for text, embedding_vector, text_id, metadata in zip(
batch_texts, batch_embedding_vectors, batch_ids, batch_metadatas
Expand All @@ -215,23 +220,22 @@ def similarity_search_with_score_id_by_vector(
"""
search_metadata = self._filter_to_metadata(filter)
#
hits = self.table.search(
embedding_vector=embedding,
top_k=k,
hits = self.table.metric_ann_search(
vector=embedding,
n=k,
metric="cos",
metric_threshold=None,
metadata=search_metadata,
)
# We stick to 'cos' distance as it can be normalized on a 0-1 axis
# (1=most relevant), as required by this class' contract.
return [
(
Document(
page_content=hit["document"],
page_content=hit["body_blob"],
metadata=hit["metadata"],
),
0.5 + 0.5 * hit["distance"],
hit["document_id"],
hit["row_id"],
)
for hit in hits
]
Expand Down Expand Up @@ -340,31 +344,32 @@ def max_marginal_relevance_search_by_vector(
"""
search_metadata = self._filter_to_metadata(filter)

prefetchHits = self.table.search(
embedding_vector=embedding,
top_k=fetch_k,
metric="cos",
metric_threshold=None,
metadata=search_metadata,
prefetch_hits = list(
self.table.metric_ann_search(
vector=embedding,
n=fetch_k,
metric="cos",
metadata=search_metadata,
)
)
# let the mmr utility pick the *indices* in the above array
mmrChosenIndices = maximal_marginal_relevance(
mmr_chosen_indices = maximal_marginal_relevance(
np.array(embedding, dtype=np.float32),
[pfHit["embedding_vector"] for pfHit in prefetchHits],
[pf_hit["vector"] for pf_hit in prefetch_hits],
k=k,
lambda_mult=lambda_mult,
)
mmrHits = [
pfHit
for pfIndex, pfHit in enumerate(prefetchHits)
if pfIndex in mmrChosenIndices
mmr_hits = [
pf_hit
for pf_index, pf_hit in enumerate(prefetch_hits)
if pf_index in mmr_chosen_indices
]
return [
Document(
page_content=hit["document"],
page_content=hit["body_blob"],
metadata=hit["metadata"],
)
for hit in mmrHits
for hit in mmr_hits
]

def max_marginal_relevance_search(
Expand Down

0 comments on commit fb940d1

Please sign in to comment.