Skip to content

Commit

Permalink
Pinecone - dense retriever (#145)
Browse files Browse the repository at this point in the history
* filters

* Update integrations/pinecone/src/pinecone_haystack/filters.py

Co-authored-by: Massimiliano Pippi <[email protected]>

* improv from PR review

* fmt

* dense retriever!

---------

Co-authored-by: Massimiliano Pippi <[email protected]>
  • Loading branch information
anakin87 and masci authored Dec 22, 2023
1 parent 30e0b7c commit 5668b83
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 0 deletions.
72 changes: 72 additions & 0 deletions integrations/pinecone/src/pinecone_haystack/dense_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional

from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import Document

from pinecone_haystack.document_store import PineconeDocumentStore


@component
class PineconeDenseRetriever:
"""
Retrieves documents from the PineconeDocumentStore, based on their dense embeddings.
Needs to be connected to the PineconeDocumentStore.
"""

def __init__(
self,
*,
document_store: PineconeDocumentStore,
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
):
"""
Create the PineconeDenseRetriever component.
:param document_store: An instance of PineconeDocumentStore.
:param filters: Filters applied to the retrieved Documents. Defaults to None.
:param top_k: Maximum number of Documents to return, defaults to 10.
:raises ValueError: If `document_store` is not an instance of PineconeDocumentStore.
"""
if not isinstance(document_store, PineconeDocumentStore):
msg = "document_store must be an instance of PineconeDocumentStore"
raise ValueError(msg)

self.document_store = document_store
self.filters = filters or {}
self.top_k = top_k

def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
filters=self.filters,
top_k=self.top_k,
document_store=self.document_store.to_dict(),
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "PineconeDenseRetriever":
data["init_parameters"]["document_store"] = default_from_dict(
PineconeDocumentStore, data["init_parameters"]["document_store"]
)
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
def run(self, query_embedding: List[float]):
"""
Retrieve documents from the PineconeDocumentStore, based on their dense embeddings.
:param query_embedding: Embedding of the query.
:return: List of Document similar to `query_embedding`.
"""
docs = self.document_store._embedding_retrieval(
query_embedding=query_embedding,
filters=self.filters,
top_k=self.top_k,
)
return {"documents": docs}
100 changes: 100 additions & 0 deletions integrations/pinecone/tests/test_dense_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from unittest.mock import Mock, patch

from haystack.dataclasses import Document

from pinecone_haystack.dense_retriever import PineconeDenseRetriever
from pinecone_haystack.document_store import PineconeDocumentStore


def test_init_default():
mock_store = Mock(spec=PineconeDocumentStore)
retriever = PineconeDenseRetriever(document_store=mock_store)
assert retriever.document_store == mock_store
assert retriever.filters == {}
assert retriever.top_k == 10


@patch("pinecone_haystack.document_store.pinecone")
def test_to_dict(mock_pinecone):
mock_pinecone.Index.return_value.describe_index_stats.return_value = {"dimension": 512}
document_store = PineconeDocumentStore(
api_key="test-key",
environment="gcp-starter",
index="default",
namespace="test-namespace",
batch_size=50,
dimension=512,
)
retriever = PineconeDenseRetriever(document_store=document_store)
res = retriever.to_dict()
assert res == {
"type": "pinecone_haystack.dense_retriever.PineconeDenseRetriever",
"init_parameters": {
"document_store": {
"init_parameters": {
"environment": "gcp-starter",
"index": "default",
"namespace": "test-namespace",
"batch_size": 50,
"dimension": 512,
},
"type": "pinecone_haystack.document_store.PineconeDocumentStore",
},
"filters": {},
"top_k": 10,
},
}


@patch("pinecone_haystack.document_store.pinecone")
def test_from_dict(mock_pinecone, monkeypatch):
data = {
"type": "pinecone_haystack.dense_retriever.PineconeDenseRetriever",
"init_parameters": {
"document_store": {
"init_parameters": {
"environment": "gcp-starter",
"index": "default",
"namespace": "test-namespace",
"batch_size": 50,
"dimension": 512,
},
"type": "pinecone_haystack.document_store.PineconeDocumentStore",
},
"filters": {},
"top_k": 10,
},
}

mock_pinecone.Index.return_value.describe_index_stats.return_value = {"dimension": 512}
monkeypatch.setenv("PINECONE_API_KEY", "test-key")
retriever = PineconeDenseRetriever.from_dict(data)

document_store = retriever.document_store
assert document_store.environment == "gcp-starter"
assert document_store.index == "default"
assert document_store.namespace == "test-namespace"
assert document_store.batch_size == 50
assert document_store.dimension == 512

assert retriever.filters == {}
assert retriever.top_k == 10


def test_run():
mock_store = Mock(spec=PineconeDocumentStore)
mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])]
retriever = PineconeDenseRetriever(document_store=mock_store)
res = retriever.run(query_embedding=[0.5, 0.7])
mock_store._embedding_retrieval.assert_called_once_with(
query_embedding=[0.5, 0.7],
filters={},
top_k=10,
)
assert len(res) == 1
assert len(res["documents"]) == 1
assert res["documents"][0].content == "Test doc"
assert res["documents"][0].embedding == [0.1, 0.2]

0 comments on commit 5668b83

Please sign in to comment.