diff --git a/libs/community/langchain_community/vectorstores/neo4j_vector.py b/libs/community/langchain_community/vectorstores/neo4j_vector.py
index a6aa4390fd687..94b803437af80 100644
--- a/libs/community/langchain_community/vectorstores/neo4j_vector.py
+++ b/libs/community/langchain_community/vectorstores/neo4j_vector.py
@@ -15,13 +15,17 @@
     Type,
 )
 
+import numpy as np
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.utils import get_from_dict_or_env
 from langchain_core.vectorstores import VectorStore
 
 from langchain_community.graphs import Neo4jGraph
-from langchain_community.vectorstores.utils import DistanceStrategy
+from langchain_community.vectorstores.utils import (
+    DistanceStrategy,
+    maximal_marginal_relevance,
+)
 
 DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE
 DISTANCE_MAPPING = {
@@ -1042,17 +1046,35 @@ def similarity_search_with_score_by_vector(
             filter_params = {}
 
         if self._index_type == IndexType.RELATIONSHIP:
-            default_retrieval = (
-                f"RETURN relationship.`{self.text_node_property}` AS text, score, "
-                f"relationship {{.*, `{self.text_node_property}`: Null, "
-                f"`{self.embedding_node_property}`: Null, id: Null }} AS metadata"
-            )
+            if kwargs.get("return_embeddings"):
+                default_retrieval = (
+                    f"RETURN relationship.`{self.text_node_property}` AS text, score, "
+                    f"relationship {{.*, `{self.text_node_property}`: Null, "
+                    f"`{self.embedding_node_property}`: Null, id: Null, "
+                    f"_embedding_: relationship.`{self.embedding_node_property}`}} "
+                    "AS metadata"
+                )
+            else:
+                default_retrieval = (
+                    f"RETURN relationship.`{self.text_node_property}` AS text, score, "
+                    f"relationship {{.*, `{self.text_node_property}`: Null, "
+                    f"`{self.embedding_node_property}`: Null, id: Null }} AS metadata"
+                )
+
         else:
-            default_retrieval = (
-                f"RETURN node.`{self.text_node_property}` AS text, score, "
-                f"node {{.*, `{self.text_node_property}`: Null, "
-                f"`{self.embedding_node_property}`: Null, id: Null }} AS metadata"
-            )
+            if kwargs.get("return_embeddings"):
+                default_retrieval = (
+                    f"RETURN node.`{self.text_node_property}` AS text, score, "
+                    f"node {{.*, `{self.text_node_property}`: Null, "
+                    f"`{self.embedding_node_property}`: Null, id: Null, "
+                    f"_embedding_: node.`{self.embedding_node_property}`}} AS metadata"
+                )
+            else:
+                default_retrieval = (
+                    f"RETURN node.`{self.text_node_property}` AS text, score, "
+                    f"node {{.*, `{self.text_node_property}`: Null, "
+                    f"`{self.embedding_node_property}`: Null, id: Null }} AS metadata"
+                )
 
         retrieval_query = (
             self.retrieval_query if self.retrieval_query else default_retrieval
@@ -1083,6 +1105,20 @@ def similarity_search_with_score_by_vector(
                 "Inspect the `retrieval_query` and ensure it doesn't "
                 "return None for the `text` column"
             )
+        if kwargs.get("return_embeddings") and any(
+            result["metadata"]["_embedding_"] is None for result in results
+        ):
+            if not self.retrieval_query:
+                raise ValueError(
+                    f"Make sure that none of the `{self.embedding_node_property}` "
+                    f"properties on nodes with label `{self.node_label}` "
+                    "are missing or empty"
+                )
+            else:
+                raise ValueError(
+                    "Inspect the `retrieval_query` and ensure it doesn't "
+                    "return None for the `_embedding_` metadata column"
+                )
 
         docs = [
             (
@@ -1487,6 +1523,64 @@ def from_existing_graph(
                 break
         return store
 
+    def max_marginal_relevance_search(
+        self,
+        query: str,
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        filter: Optional[dict] = None,
+        **kwargs: Any,
+    ) -> List[Document]:
+        """Return docs selected using the maximal marginal relevance.
+
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
+
+        Args:
+            query: search query text.
+            k: Number of Documents to return. Defaults to 4.
+            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+            lambda_mult: Number between 0 and 1 that determines the degree
+                of diversity among the results with 0 corresponding
+                to maximum diversity and 1 to minimum diversity.
+                Defaults to 0.5.
+            filter: Filter on metadata properties, e.g.
+                {
+                    "str_property": "foo",
+                    "int_property": 123
+                }
+        Returns:
+            List of Documents selected by maximal marginal relevance.
+        """
+        # Embed the query
+        query_embedding = self.embedding.embed_query(query)
+
+        # Fetch the initial documents
+        got_docs = self.similarity_search_with_score_by_vector(
+            embedding=query_embedding,
+            query=query,
+            k=fetch_k,
+            return_embeddings=True,
+            filter=filter,
+            **kwargs,
+        )
+
+        # Get the embeddings for the fetched documents
+        got_embeddings = [doc.metadata["_embedding_"] for doc, _ in got_docs]
+
+        # Select documents using maximal marginal relevance
+        selected_indices = maximal_marginal_relevance(
+            np.array(query_embedding), got_embeddings, lambda_mult=lambda_mult, k=k
+        )
+        selected_docs = [got_docs[i][0] for i in selected_indices]
+
+        # Remove embedding values from metadata
+        for doc in selected_docs:
+            del doc.metadata["_embedding_"]
+
+        return selected_docs
+
     def _select_relevance_score_fn(self) -> Callable[[float], float]:
         """
         The 'correct' relevance function
diff --git a/libs/community/tests/integration_tests/vectorstores/test_neo4jvector.py b/libs/community/tests/integration_tests/vectorstores/test_neo4jvector.py
index 761450f4a91e8..d8586e089ccf9 100644
--- a/libs/community/tests/integration_tests/vectorstores/test_neo4jvector.py
+++ b/libs/community/tests/integration_tests/vectorstores/test_neo4jvector.py
@@ -14,7 +14,10 @@
     _get_search_index_query,
 )
 from langchain_community.vectorstores.utils import DistanceStrategy
-from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
+from tests.integration_tests.vectorstores.fake_embeddings import (
+    AngularTwoDimensionalEmbeddings,
+    FakeEmbeddings,
+)
 from tests.integration_tests.vectorstores.fixtures.filtering_test_cases import (
     DOCUMENTS,
     TYPE_1_FILTERING_TEST_CASES,
@@ -928,6 +931,45 @@ def test_neo4jvector_relationship_index_retrieval() -> None:
     drop_vector_indexes(docsearch)
 
 
+def test_neo4j_max_marginal_relevance_search() -> None:
+    """
+    Test end to end construction and MMR search.
+    The embedding function used here ensures `texts` become
+    the following vectors on a circle (numbered v0 through v3):
+
+           ______ v2
+          /      \
+         /        |  v1
+        v3 |     .    | query
+           |   /  v0
+           |______/                 (N.B. very crude drawing)
+
+    With fetch_k==3 and k==2, when query is at (1, ),
+    one expects that v2 and v0 are returned (in some order).
+    """
+    texts = ["-0.124", "+0.127", "+0.25", "+1.0"]
+    metadatas = [{"page": i} for i in range(len(texts))]
+    docsearch = Neo4jVector.from_texts(
+        texts,
+        metadatas=metadatas,
+        embedding=AngularTwoDimensionalEmbeddings(),
+        pre_delete_collection=True,
+    )
+
+    expected_set = {
+        ("+0.25", 2),
+        ("-0.124", 0),
+    }
+
+    output = docsearch.max_marginal_relevance_search("0.0", k=2, fetch_k=3)
+    output_set = {
+        (mmr_doc.page_content, mmr_doc.metadata["page"]) for mmr_doc in output
+    }
+    assert output_set == expected_set
+
+    drop_vector_indexes(docsearch)
+
+
 def test_neo4jvector_passing_graph_object() -> None:
     """Test end to end construction and search with passing graph object."""
     graph = Neo4jGraph()