Skip to content

Commit

Permalink
Merge branch 'main' into adopt-secret-amazon_chat
Browse files Browse the repository at this point in the history
  • Loading branch information
davidsbatista authored Feb 14, 2024
2 parents d6668d5 + 9383e28 commit 8daad4f
Show file tree
Hide file tree
Showing 8 changed files with 562 additions and 3 deletions.
9 changes: 6 additions & 3 deletions integrations/weaviate/pydoc/config.yml
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
loaders:
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
search_path: [../src]
modules: [
"haystack_integrations.document_stores.weaviate.document_store",
]
modules:
[
"haystack_integrations.document_stores.weaviate.document_store",
"haystack_integrations.components.retrievers.weaviate.bm25_retriever",
"haystack_integrations.components.retrievers.weaviate.embedding_retriever",
]
ignore_when_discovered: ["__init__"]
processors:
- type: filter
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .bm25_retriever import WeaviateBM25Retriever
from .embedding_retriever import WeaviateEmbeddingRetriever

__all__ = ["WeaviateBM25Retriever", "WeaviateEmbeddingRetriever"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Any, Dict, List, Optional

from haystack import Document, component, default_from_dict, default_to_dict
from haystack_integrations.document_stores.weaviate import WeaviateDocumentStore


@component
class WeaviateBM25Retriever:
"""
Retriever that uses BM25 to find the most promising documents for a given query.
"""

def __init__(
self,
*,
document_store: WeaviateDocumentStore,
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
):
"""
Create a new instance of WeaviateBM25Retriever.
:param document_store: Instance of WeaviateDocumentStore that will be associated with this retriever.
:param filters: Custom filters applied when running the retriever, defaults to None
:param top_k: Maximum number of documents to return, defaults to 10
"""
self._document_store = document_store
self._filters = filters or {}
self._top_k = top_k

def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
filters=self._filters,
top_k=self._top_k,
document_store=self._document_store.to_dict(),
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "WeaviateBM25Retriever":
data["init_parameters"]["document_store"] = WeaviateDocumentStore.from_dict(
data["init_parameters"]["document_store"]
)
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
def run(self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None):
filters = filters or self._filters
top_k = top_k or self._top_k
return self._document_store._bm25_retrieval(query=query, filters=filters, top_k=top_k)
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from typing import Any, Dict, List, Optional

from haystack import Document, component, default_from_dict, default_to_dict
from haystack_integrations.document_stores.weaviate import WeaviateDocumentStore


@component
class WeaviateEmbeddingRetriever:
"""
A retriever that uses Weaviate's vector search to find similar documents based on the embeddings of the query.
"""

def __init__(
self,
*,
document_store: WeaviateDocumentStore,
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
distance: Optional[float] = None,
certainty: Optional[float] = None,
):
"""
Create a new instance of WeaviateEmbeddingRetriever.
Raises ValueError if both `distance` and `certainty` are provided.
See the official Weaviate documentation to learn more about the `distance` and `certainty` parameters:
https://weaviate.io/developers/weaviate/api/graphql/search-operators#variables
:param document_store: Instance of WeaviateDocumentStore that will be associated with this retriever.
:param filters: Custom filters applied when running the retriever, defaults to None
:param top_k: Maximum number of documents to return, defaults to 10
:param distance: The maximum allowed distance between Documents' embeddings, defaults to None
:param certainty: Normalized distance between the result item and the search vector, defaults to None
"""
if distance is not None and certainty is not None:
msg = "Can't use 'distance' and 'certainty' parameters together"
raise ValueError(msg)

self._document_store = document_store
self._filters = filters or {}
self._top_k = top_k
self._distance = distance
self._certainty = certainty

def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
filters=self._filters,
top_k=self._top_k,
distance=self._distance,
certainty=self._certainty,
document_store=self._document_store.to_dict(),
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "WeaviateEmbeddingRetriever":
data["init_parameters"]["document_store"] = WeaviateDocumentStore.from_dict(
data["init_parameters"]["document_store"]
)
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
def run(
self,
query_embedding: List[float],
filters: Optional[Dict[str, Any]] = None,
top_k: Optional[int] = None,
distance: Optional[float] = None,
certainty: Optional[float] = None,
):
filters = filters or self._filters
top_k = top_k or self._top_k
distance = distance or self._distance
certainty = certainty or self._certainty
return self._document_store._embedding_retrieval(
query_embedding=query_embedding,
filters=filters,
top_k=top_k,
distance=distance,
certainty=certainty,
)
Original file line number Diff line number Diff line change
Expand Up @@ -419,3 +419,66 @@ def delete_documents(self, document_ids: List[str]) -> None:
"valueTextArray": [generate_uuid5(doc_id) for doc_id in document_ids],
},
)

def _bm25_retrieval(
self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None
) -> List[Document]:
collection_name = self._collection_settings["class"]
properties = self._client.schema.get(self._collection_settings["class"]).get("properties", [])
properties = [prop["name"] for prop in properties]

query_builder = (
self._client.query.get(collection_name, properties=properties)
.with_bm25(query=query, properties=["content"])
.with_additional(["vector"])
)

if filters:
query_builder = query_builder.with_where(convert_filters(filters))

if top_k:
query_builder = query_builder.with_limit(top_k)

result = query_builder.do()

return [self._to_document(doc) for doc in result["data"]["Get"][collection_name]]

def _embedding_retrieval(
self,
query_embedding: List[float],
filters: Optional[Dict[str, Any]] = None,
top_k: Optional[int] = None,
distance: Optional[float] = None,
certainty: Optional[float] = None,
) -> List[Document]:
if distance is not None and certainty is not None:
msg = "Can't use 'distance' and 'certainty' parameters together"
raise ValueError(msg)

collection_name = self._collection_settings["class"]
properties = self._client.schema.get(self._collection_settings["class"]).get("properties", [])
properties = [prop["name"] for prop in properties]

near_vector: Dict[str, Union[float, List[float]]] = {
"vector": query_embedding,
}
if distance is not None:
near_vector["distance"] = distance

if certainty is not None:
near_vector["certainty"] = certainty

query_builder = (
self._client.query.get(collection_name, properties=properties)
.with_near_vector(near_vector)
.with_additional(["vector"])
)

if filters:
query_builder = query_builder.with_where(convert_filters(filters))

if top_k:
query_builder = query_builder.with_limit(top_k)

result = query_builder.do()
return [self._to_document(doc) for doc in result["data"]["Get"][collection_name]]
102 changes: 102 additions & 0 deletions integrations/weaviate/tests/test_bm25_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from unittest.mock import Mock, patch

from haystack_integrations.components.retrievers.weaviate import WeaviateBM25Retriever
from haystack_integrations.document_stores.weaviate import WeaviateDocumentStore


def test_init_default():
mock_document_store = Mock(spec=WeaviateDocumentStore)
retriever = WeaviateBM25Retriever(document_store=mock_document_store)
assert retriever._document_store == mock_document_store
assert retriever._filters == {}
assert retriever._top_k == 10


@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate")
def test_to_dict(_mock_weaviate):
document_store = WeaviateDocumentStore()
retriever = WeaviateBM25Retriever(document_store=document_store)
assert retriever.to_dict() == {
"type": "haystack_integrations.components.retrievers.weaviate.bm25_retriever.WeaviateBM25Retriever",
"init_parameters": {
"filters": {},
"top_k": 10,
"document_store": {
"type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore",
"init_parameters": {
"url": None,
"collection_settings": {
"class": "Default",
"invertedIndexConfig": {"indexNullState": True},
"properties": [
{"name": "_original_id", "dataType": ["text"]},
{"name": "content", "dataType": ["text"]},
{"name": "dataframe", "dataType": ["text"]},
{"name": "blob_data", "dataType": ["blob"]},
{"name": "blob_mime_type", "dataType": ["text"]},
{"name": "score", "dataType": ["number"]},
],
},
"auth_client_secret": None,
"timeout_config": (10, 60),
"proxies": None,
"trust_env": False,
"additional_headers": None,
"startup_period": 5,
"embedded_options": None,
"additional_config": None,
},
},
},
}


@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate")
def test_from_dict(_mock_weaviate):
retriever = WeaviateBM25Retriever.from_dict(
{
"type": "haystack_integrations.components.retrievers.weaviate.bm25_retriever.WeaviateBM25Retriever",
"init_parameters": {
"filters": {},
"top_k": 10,
"document_store": {
"type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore",
"init_parameters": {
"url": None,
"collection_settings": {
"class": "Default",
"invertedIndexConfig": {"indexNullState": True},
"properties": [
{"name": "_original_id", "dataType": ["text"]},
{"name": "content", "dataType": ["text"]},
{"name": "dataframe", "dataType": ["text"]},
{"name": "blob_data", "dataType": ["blob"]},
{"name": "blob_mime_type", "dataType": ["text"]},
{"name": "score", "dataType": ["number"]},
],
},
"auth_client_secret": None,
"timeout_config": (10, 60),
"proxies": None,
"trust_env": False,
"additional_headers": None,
"startup_period": 5,
"embedded_options": None,
"additional_config": None,
},
},
},
}
)
assert retriever._document_store
assert retriever._filters == {}
assert retriever._top_k == 10


@patch("haystack_integrations.components.retrievers.weaviate.bm25_retriever.WeaviateDocumentStore")
def test_run(mock_document_store):
retriever = WeaviateBM25Retriever(document_store=mock_document_store)
query = "some query"
filters = {"field": "content", "operator": "==", "value": "Some text"}
retriever.run(query=query, filters=filters, top_k=5)
mock_document_store._bm25_retrieval.assert_called_once_with(query=query, filters=filters, top_k=5)
Loading

0 comments on commit 8daad4f

Please sign in to comment.