From fb940d11df5f275bb0a82f725f76643fd9594307 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Wed, 17 Jan 2024 19:37:07 +0100 Subject: [PATCH] community[patch]: Use newer MetadataVectorCassandraTable in Cassandra vector store (#15987) as VectorTable is deprecated Tested manually with `test_cassandra.py` vector store integration test. --- .../vectorstores/cassandra.py | 55 ++++++++++--------- 1 file changed, 30 insertions(+), 25 deletions(-) diff --git a/libs/community/langchain_community/vectorstores/cassandra.py b/libs/community/langchain_community/vectorstores/cassandra.py index 2a062014339a6..041f699520083 100644 --- a/libs/community/langchain_community/vectorstores/cassandra.py +++ b/libs/community/langchain_community/vectorstores/cassandra.py @@ -75,7 +75,7 @@ def __init__( ttl_seconds: Optional[int] = None, ) -> None: try: - from cassio.vector import VectorTable + from cassio.table import MetadataVectorCassandraTable except (ImportError, ModuleNotFoundError): raise ImportError( "Could not import cassio python package. " @@ -90,11 +90,12 @@ def __init__( # self._embedding_dimension = None # - self.table = VectorTable( + self.table = MetadataVectorCassandraTable( session=session, keyspace=keyspace, table=table_name, - embedding_dimension=self._get_embedding_dimension(), + vector_dimension=self._get_embedding_dimension(), + metadata_indexing="all", primary_key_type="TEXT", ) @@ -127,7 +128,7 @@ def clear(self) -> None: self.table.clear() def delete_by_document_id(self, document_id: str) -> None: - return self.table.delete(document_id) + return self.table.delete(row_id=document_id) def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]: """Delete by vector IDs. @@ -188,7 +189,11 @@ def add_texts( futures = [ self.table.put_async( - text, embedding_vector, text_id, metadata, ttl_seconds + row_id=text_id, + body_blob=text, + vector=embedding_vector, + metadata=metadata or {}, + ttl_seconds=ttl_seconds, ) for text, embedding_vector, text_id, metadata in zip( batch_texts, batch_embedding_vectors, batch_ids, batch_metadatas @@ -215,11 +220,10 @@ def similarity_search_with_score_id_by_vector( """ search_metadata = self._filter_to_metadata(filter) # - hits = self.table.search( - embedding_vector=embedding, - top_k=k, + hits = self.table.metric_ann_search( + vector=embedding, + n=k, metric="cos", - metric_threshold=None, metadata=search_metadata, ) # We stick to 'cos' distance as it can be normalized on a 0-1 axis @@ -227,11 +231,11 @@ def similarity_search_with_score_id_by_vector( return [ ( Document( - page_content=hit["document"], + page_content=hit["body_blob"], metadata=hit["metadata"], ), 0.5 + 0.5 * hit["distance"], - hit["document_id"], + hit["row_id"], ) for hit in hits ] @@ -340,31 +344,32 @@ def max_marginal_relevance_search_by_vector( """ search_metadata = self._filter_to_metadata(filter) - prefetchHits = self.table.search( - embedding_vector=embedding, - top_k=fetch_k, - metric="cos", - metric_threshold=None, - metadata=search_metadata, + prefetch_hits = list( + self.table.metric_ann_search( + vector=embedding, + n=fetch_k, + metric="cos", + metadata=search_metadata, + ) ) # let the mmr utility pick the *indices* in the above array - mmrChosenIndices = maximal_marginal_relevance( + mmr_chosen_indices = maximal_marginal_relevance( np.array(embedding, dtype=np.float32), - [pfHit["embedding_vector"] for pfHit in prefetchHits], + [pf_hit["vector"] for pf_hit in prefetch_hits], k=k, lambda_mult=lambda_mult, ) - mmrHits = [ - pfHit - for pfIndex, pfHit in enumerate(prefetchHits) - if pfIndex in mmrChosenIndices + mmr_hits = [ + pf_hit + for pf_index, pf_hit in enumerate(prefetch_hits) + if pf_index in mmr_chosen_indices ] return [ Document( - page_content=hit["document"], + page_content=hit["body_blob"], metadata=hit["metadata"], ) - for hit in mmrHits + for hit in mmr_hits ] def max_marginal_relevance_search(