-
Notifications
You must be signed in to change notification settings - Fork 126
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* filters * Update integrations/pinecone/src/pinecone_haystack/filters.py Co-authored-by: Massimiliano Pippi <[email protected]> * improv from PR review * fmt * dense retriever! --------- Co-authored-by: Massimiliano Pippi <[email protected]>
- Loading branch information
Showing
2 changed files
with
172 additions
and
0 deletions.
There are no files selected for viewing
72 changes: 72 additions & 0 deletions
72
integrations/pinecone/src/pinecone_haystack/dense_retriever.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
from typing import Any, Dict, List, Optional | ||
|
||
from haystack import component, default_from_dict, default_to_dict | ||
from haystack.dataclasses import Document | ||
|
||
from pinecone_haystack.document_store import PineconeDocumentStore | ||
|
||
|
||
@component | ||
class PineconeDenseRetriever: | ||
""" | ||
Retrieves documents from the PineconeDocumentStore, based on their dense embeddings. | ||
Needs to be connected to the PineconeDocumentStore. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
*, | ||
document_store: PineconeDocumentStore, | ||
filters: Optional[Dict[str, Any]] = None, | ||
top_k: int = 10, | ||
): | ||
""" | ||
Create the PineconeDenseRetriever component. | ||
:param document_store: An instance of PineconeDocumentStore. | ||
:param filters: Filters applied to the retrieved Documents. Defaults to None. | ||
:param top_k: Maximum number of Documents to return, defaults to 10. | ||
:raises ValueError: If `document_store` is not an instance of PineconeDocumentStore. | ||
""" | ||
if not isinstance(document_store, PineconeDocumentStore): | ||
msg = "document_store must be an instance of PineconeDocumentStore" | ||
raise ValueError(msg) | ||
|
||
self.document_store = document_store | ||
self.filters = filters or {} | ||
self.top_k = top_k | ||
|
||
def to_dict(self) -> Dict[str, Any]: | ||
return default_to_dict( | ||
self, | ||
filters=self.filters, | ||
top_k=self.top_k, | ||
document_store=self.document_store.to_dict(), | ||
) | ||
|
||
@classmethod | ||
def from_dict(cls, data: Dict[str, Any]) -> "PineconeDenseRetriever": | ||
data["init_parameters"]["document_store"] = default_from_dict( | ||
PineconeDocumentStore, data["init_parameters"]["document_store"] | ||
) | ||
return default_from_dict(cls, data) | ||
|
||
@component.output_types(documents=List[Document]) | ||
def run(self, query_embedding: List[float]): | ||
""" | ||
Retrieve documents from the PineconeDocumentStore, based on their dense embeddings. | ||
:param query_embedding: Embedding of the query. | ||
:return: List of Document similar to `query_embedding`. | ||
""" | ||
docs = self.document_store._embedding_retrieval( | ||
query_embedding=query_embedding, | ||
filters=self.filters, | ||
top_k=self.top_k, | ||
) | ||
return {"documents": docs} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
from unittest.mock import Mock, patch | ||
|
||
from haystack.dataclasses import Document | ||
|
||
from pinecone_haystack.dense_retriever import PineconeDenseRetriever | ||
from pinecone_haystack.document_store import PineconeDocumentStore | ||
|
||
|
||
def test_init_default(): | ||
mock_store = Mock(spec=PineconeDocumentStore) | ||
retriever = PineconeDenseRetriever(document_store=mock_store) | ||
assert retriever.document_store == mock_store | ||
assert retriever.filters == {} | ||
assert retriever.top_k == 10 | ||
|
||
|
||
@patch("pinecone_haystack.document_store.pinecone") | ||
def test_to_dict(mock_pinecone): | ||
mock_pinecone.Index.return_value.describe_index_stats.return_value = {"dimension": 512} | ||
document_store = PineconeDocumentStore( | ||
api_key="test-key", | ||
environment="gcp-starter", | ||
index="default", | ||
namespace="test-namespace", | ||
batch_size=50, | ||
dimension=512, | ||
) | ||
retriever = PineconeDenseRetriever(document_store=document_store) | ||
res = retriever.to_dict() | ||
assert res == { | ||
"type": "pinecone_haystack.dense_retriever.PineconeDenseRetriever", | ||
"init_parameters": { | ||
"document_store": { | ||
"init_parameters": { | ||
"environment": "gcp-starter", | ||
"index": "default", | ||
"namespace": "test-namespace", | ||
"batch_size": 50, | ||
"dimension": 512, | ||
}, | ||
"type": "pinecone_haystack.document_store.PineconeDocumentStore", | ||
}, | ||
"filters": {}, | ||
"top_k": 10, | ||
}, | ||
} | ||
|
||
|
||
@patch("pinecone_haystack.document_store.pinecone") | ||
def test_from_dict(mock_pinecone, monkeypatch): | ||
data = { | ||
"type": "pinecone_haystack.dense_retriever.PineconeDenseRetriever", | ||
"init_parameters": { | ||
"document_store": { | ||
"init_parameters": { | ||
"environment": "gcp-starter", | ||
"index": "default", | ||
"namespace": "test-namespace", | ||
"batch_size": 50, | ||
"dimension": 512, | ||
}, | ||
"type": "pinecone_haystack.document_store.PineconeDocumentStore", | ||
}, | ||
"filters": {}, | ||
"top_k": 10, | ||
}, | ||
} | ||
|
||
mock_pinecone.Index.return_value.describe_index_stats.return_value = {"dimension": 512} | ||
monkeypatch.setenv("PINECONE_API_KEY", "test-key") | ||
retriever = PineconeDenseRetriever.from_dict(data) | ||
|
||
document_store = retriever.document_store | ||
assert document_store.environment == "gcp-starter" | ||
assert document_store.index == "default" | ||
assert document_store.namespace == "test-namespace" | ||
assert document_store.batch_size == 50 | ||
assert document_store.dimension == 512 | ||
|
||
assert retriever.filters == {} | ||
assert retriever.top_k == 10 | ||
|
||
|
||
def test_run(): | ||
mock_store = Mock(spec=PineconeDocumentStore) | ||
mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] | ||
retriever = PineconeDenseRetriever(document_store=mock_store) | ||
res = retriever.run(query_embedding=[0.5, 0.7]) | ||
mock_store._embedding_retrieval.assert_called_once_with( | ||
query_embedding=[0.5, 0.7], | ||
filters={}, | ||
top_k=10, | ||
) | ||
assert len(res) == 1 | ||
assert len(res["documents"]) == 1 | ||
assert res["documents"][0].content == "Test doc" | ||
assert res["documents"][0].embedding == [0.1, 0.2] |