Skip to content

Commit

Permalink
Add mmr to neo4j vector (#25765)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomasonjo authored Aug 27, 2024
1 parent 995305f commit f359e6b
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 12 deletions.
116 changes: 105 additions & 11 deletions libs/community/langchain_community/vectorstores/neo4j_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@
Type,
)

import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import get_from_dict_or_env
from langchain_core.vectorstores import VectorStore

from langchain_community.graphs import Neo4jGraph
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_community.vectorstores.utils import (
DistanceStrategy,
maximal_marginal_relevance,
)

DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE
DISTANCE_MAPPING = {
Expand Down Expand Up @@ -1042,17 +1046,35 @@ def similarity_search_with_score_by_vector(
filter_params = {}

if self._index_type == IndexType.RELATIONSHIP:
default_retrieval = (
f"RETURN relationship.`{self.text_node_property}` AS text, score, "
f"relationship {{.*, `{self.text_node_property}`: Null, "
f"`{self.embedding_node_property}`: Null, id: Null }} AS metadata"
)
if kwargs.get("return_embeddings"):
default_retrieval = (
f"RETURN relationship.`{self.text_node_property}` AS text, score, "
f"relationship {{.*, `{self.text_node_property}`: Null, "
f"`{self.embedding_node_property}`: Null, id: Null, "
f"_embedding_: relationship.`{self.embedding_node_property}`}} "
"AS metadata"
)
else:
default_retrieval = (
f"RETURN relationship.`{self.text_node_property}` AS text, score, "
f"relationship {{.*, `{self.text_node_property}`: Null, "
f"`{self.embedding_node_property}`: Null, id: Null }} AS metadata"
)

else:
default_retrieval = (
f"RETURN node.`{self.text_node_property}` AS text, score, "
f"node {{.*, `{self.text_node_property}`: Null, "
f"`{self.embedding_node_property}`: Null, id: Null }} AS metadata"
)
if kwargs.get("return_embeddings"):
default_retrieval = (
f"RETURN node.`{self.text_node_property}` AS text, score, "
f"node {{.*, `{self.text_node_property}`: Null, "
f"`{self.embedding_node_property}`: Null, id: Null, "
f"_embedding_: node.`{self.embedding_node_property}`}} AS metadata"
)
else:
default_retrieval = (
f"RETURN node.`{self.text_node_property}` AS text, score, "
f"node {{.*, `{self.text_node_property}`: Null, "
f"`{self.embedding_node_property}`: Null, id: Null }} AS metadata"
)

retrieval_query = (
self.retrieval_query if self.retrieval_query else default_retrieval
Expand Down Expand Up @@ -1083,6 +1105,20 @@ def similarity_search_with_score_by_vector(
"Inspect the `retrieval_query` and ensure it doesn't "
"return None for the `text` column"
)
if kwargs.get("return_embeddings") and any(
result["metadata"]["_embedding_"] is None for result in results
):
if not self.retrieval_query:
raise ValueError(
f"Make sure that none of the `{self.embedding_node_property}` "
f"properties on nodes with label `{self.node_label}` "
"are missing or empty"
)
else:
raise ValueError(
"Inspect the `retrieval_query` and ensure it doesn't "
"return None for the `_embedding_` metadata column"
)

docs = [
(
Expand Down Expand Up @@ -1487,6 +1523,64 @@ def from_existing_graph(
break
return store

def max_marginal_relevance_search(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filter: Optional[dict] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return documents chosen by maximal marginal relevance (MMR).

    MMR balances similarity to the query against diversity among the
    documents that are returned.

    Args:
        query: Search query text.
        k: Number of Documents to return. Defaults to 4.
        fetch_k: Number of candidate Documents fetched from the store and
            handed to the MMR algorithm. Defaults to 20.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results, with 0 corresponding to
            maximum diversity and 1 to minimum diversity.
            Defaults to 0.5.
        filter: Filter on metadata properties, e.g.
            {
                "str_property": "foo",
                "int_property": 123
            }
        **kwargs: Forwarded unchanged to
            ``similarity_search_with_score_by_vector``.

    Returns:
        List of Documents selected by maximal marginal relevance.
    """
    query_vector = self.embedding.embed_query(query)

    # Candidate pool: ask the similarity search to also return each stored
    # embedding (exposed under the `_embedding_` metadata key).
    candidates = self.similarity_search_with_score_by_vector(
        embedding=query_vector,
        query=query,
        k=fetch_k,
        return_embeddings=True,
        filter=filter,
        **kwargs,
    )

    # Pull the stored vectors out of the candidates, preserving their order
    # so that MMR indices map straight back onto `candidates`.
    candidate_vectors = []
    for candidate, _score in candidates:
        candidate_vectors.append(candidate.metadata["_embedding_"])

    picked_positions = maximal_marginal_relevance(
        np.array(query_vector),
        candidate_vectors,
        lambda_mult=lambda_mult,
        k=k,
    )

    chosen = []
    for position in picked_positions:
        chosen.append(candidates[position][0])

    # The embedding was only needed for the MMR computation; strip it so it
    # does not leak into the metadata handed back to the caller.
    for document in chosen:
        del document.metadata["_embedding_"]

    return chosen

def _select_relevance_score_fn(self) -> Callable[[float], float]:
"""
The 'correct' relevance function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
_get_search_index_query,
)
from langchain_community.vectorstores.utils import DistanceStrategy
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
from tests.integration_tests.vectorstores.fake_embeddings import (
AngularTwoDimensionalEmbeddings,
FakeEmbeddings,
)
from tests.integration_tests.vectorstores.fixtures.filtering_test_cases import (
DOCUMENTS,
TYPE_1_FILTERING_TEST_CASES,
Expand Down Expand Up @@ -928,6 +931,45 @@ def test_neo4jvector_relationship_index_retrieval() -> None:
drop_vector_indexes(docsearch)


def test_neo4j_max_marginal_relevance_search() -> None:
    r"""
    Test end to end construction and MMR search.

    The embedding function used here ensures `texts` become
    the following vectors on a circle (numbered v0 through v3):

           ______ v2
          /      \
         /        |  v1
    v3  |     .    | query
         |        /  v0
          |______/                 (N.B. very crude drawing)

    With fetch_k==3 and k==2, one expects that v2 and v0 are returned
    (in some order), which is why the result is compared as a set.
    """
    angle_texts = ["-0.124", "+0.127", "+0.25", "+1.0"]
    # Tag every text with its position so results can be identified by page.
    page_metadatas = [{"page": page} for page, _ in enumerate(angle_texts)]
    store = Neo4jVector.from_texts(
        angle_texts,
        metadatas=page_metadatas,
        embedding=AngularTwoDimensionalEmbeddings(),
        pre_delete_collection=True,
    )

    mmr_docs = store.max_marginal_relevance_search("0.0", k=2, fetch_k=3)
    found = {(doc.page_content, doc.metadata["page"]) for doc in mmr_docs}

    assert found == {("+0.25", 2), ("-0.124", 0)}

    drop_vector_indexes(store)


def test_neo4jvector_passing_graph_object() -> None:
"""Test end to end construction and search with passing graph object."""
graph = Neo4jGraph()
Expand Down

0 comments on commit f359e6b

Please sign in to comment.