
Commit

Merge branch 'main' into pgvector-draft
anakin87 authored Jan 21, 2024
2 parents 5617a95 + cd78080 commit b0a1d8f
Showing 29 changed files with 119 additions and 127 deletions.
4 changes: 2 additions & 2 deletions integrations/chroma/example/example.py
@@ -23,7 +23,7 @@

querying = Pipeline()
querying.add_component("retriever", ChromaQueryRetriever(document_store))
- results = querying.run({"retriever": {"queries": ["Variable declarations"], "top_k": 3}})
+ results = querying.run({"retriever": {"query": "Variable declarations", "top_k": 3}})

- for d in results["retriever"]["documents"][0]:
+ for d in results["retriever"]["documents"]:
print(d.meta, d.score)
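
For context, the updated Chroma example boils down to the following minimal sketch; the `chroma_haystack` import paths are assumptions based on the integration of that period (they are not shown in this diff) and may differ in your installed version:

```python
# Minimal sketch of the updated usage: ChromaQueryRetriever now takes a single
# plain-text "query" (the old API took "queries": [...]) and returns a flat
# List[Document] instead of a list of lists.
# The chroma_haystack import paths are assumptions, not shown in this diff.
from haystack import Pipeline
from chroma_haystack import ChromaDocumentStore
from chroma_haystack.retriever import ChromaQueryRetriever

document_store = ChromaDocumentStore()  # assumes a default local collection

querying = Pipeline()
querying.add_component("retriever", ChromaQueryRetriever(document_store))
results = querying.run({"retriever": {"query": "Variable declarations", "top_k": 3}})

for d in results["retriever"]["documents"]:
    print(d.meta, d.score)
```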
@@ -1,3 +1,3 @@
- from .retriever import ChromaEmbeddingRetriever, ChromaQueryRetriever, ChromaSingleQueryRetriever
+ from .retriever import ChromaEmbeddingRetriever, ChromaQueryRetriever

- __all__ = ["ChromaQueryRetriever", "ChromaEmbeddingRetriever", "ChromaSingleQueryRetriever"]
+ __all__ = ["ChromaQueryRetriever", "ChromaEmbeddingRetriever"]
@@ -25,24 +25,24 @@ def __init__(self, document_store: ChromaDocumentStore, filters: Optional[Dict[s
self.top_k = top_k
self.document_store = document_store

- @component.output_types(documents=List[List[Document]])
+ @component.output_types(documents=List[Document])
def run(
self,
- queries: List[str],
+ query: str,
_: Optional[Dict[str, Any]] = None, # filters not yet supported
top_k: Optional[int] = None,
):
"""
Run the retriever on the given input data.
- :param queries: The input data for the retriever. In this case, a list of queries.
+ :param query: The input data for the retriever. In this case, a plain-text query.
:return: The retrieved documents.
:raises ValueError: If the specified document store is not found or is not a MemoryDocumentStore instance.
"""
- if not top_k:
- top_k = 3
- return {"documents": self.document_store.search(queries, top_k)}
+ top_k = top_k or self.top_k
+
+ return {"documents": self.document_store.search([query], top_k)[0]}

def to_dict(self) -> Dict[str, Any]:
"""
@@ -60,24 +60,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "ChromaQueryRetriever":
return default_from_dict(cls, data)


- @component
- class ChromaSingleQueryRetriever(ChromaQueryRetriever):
- """
- A convenient wrapper to the standard query retriever that accepts a single query
- and returns a list of documents
- """
-
- @component.output_types(documents=List[Document])
- def run(
- self,
- query: str,
- filters: Optional[Dict[str, Any]] = None, # filters not yet supported
- top_k: Optional[int] = None,
- ):
- queries = [query]
- return super().run(queries, filters, top_k)[0]


@component
class ChromaEmbeddingRetriever(ChromaQueryRetriever):
@component.output_types(documents=List[Document])
@@ -95,8 +77,7 @@ def run(
:raises ValueError: If the specified document store is not found or is not a MemoryDocumentStore instance.
"""
- if not top_k:
- top_k = 3
+ top_k = top_k or self.top_k

query_embeddings = [query_embedding]
return {"documents": self.document_store.search_embeddings(query_embeddings, top_k)[0]}
19 changes: 10 additions & 9 deletions integrations/elasticsearch/pyproject.toml
@@ -33,6 +33,9 @@ Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/m
Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues"
Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/elasticsearch"

+ [tool.hatch.build.targets.wheel]
+ packages = ["src/haystack_integrations"]

[tool.hatch.version]
source = "vcs"
tag-pattern = 'integrations\/elasticsearch-v(?P<version>.*)'
@@ -70,7 +73,7 @@ dependencies = [
"ruff>=0.0.243",
]
[tool.hatch.envs.lint.scripts]
typing = "mypy --install-types --non-interactive {args:src/elasticsearch_haystack tests}"
typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
style = [
"ruff {args:.}",
"black --check --diff {args:.}",
@@ -139,26 +142,23 @@ unfixable = [
]

[tool.ruff.isort]
- known-first-party = ["elasticsearch_haystack"]
+ known-first-party = ["src"]

[tool.ruff.flake8-tidy-imports]
- ban-relative-imports = "all"
+ ban-relative-imports = "parents"

[tool.ruff.per-file-ignores]
# Tests can use magic values, assertions, and relative imports
"tests/**/*" = ["PLR2004", "S101", "TID252"]

[tool.coverage.run]
- source_pkgs = ["elasticsearch_haystack", "tests"]
+ source_pkgs = ["src", "tests"]
branch = true
parallel = true
- omit = [
- "src/elasticsearch_haystack/__about__.py",
- ]

[tool.coverage.paths]
- elasticsearch_haystack = ["src/elasticsearch_haystack", "*/elasticsearch-haystack/src/elasticsearch_haystack"]
- tests = ["tests", "*/elasticsearch-haystack/tests"]
+ elasticsearch_haystack = ["src/haystack_integrations", "*/elasticsearch/src/haystack_integrations"]
+ tests = ["tests", "*/elasticsearch/src/tests"]

[tool.coverage.report]
exclude_lines = [
@@ -177,6 +177,7 @@ markers = [
[[tool.mypy.overrides]]
module = [
"haystack.*",
"haystack_integrations.*",
"pytest.*"
]
ignore_missing_imports = true
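
The new `[tool.hatch.build.targets.wheel]` entry ships `src/haystack_integrations` so that several integration distributions can install into the same top-level package. A quick sketch for checking this from an environment where the integration is installed (assumes a standard PEP 420 namespace package, which is not stated in this hunk):

```python
# Sketch: haystack_integrations is expected to be a PEP 420 namespace package,
# so it has no __file__ of its own and its __path__ can span every installed
# integration distribution.
import haystack_integrations

print(getattr(haystack_integrations, "__file__", None))  # typically None for a namespace package
print(list(haystack_integrations.__path__))              # one entry per installed integration
```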
@@ -0,0 +1,7 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from .bm25_retriever import ElasticsearchBM25Retriever
from .embedding_retriever import ElasticsearchEmbeddingRetriever

__all__ = ["ElasticsearchBM25Retriever", "ElasticsearchEmbeddingRetriever"]
@@ -5,8 +5,7 @@

from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import Document

- from elasticsearch_haystack.document_store import ElasticsearchDocumentStore
+ from haystack_integrations.document_stores.elasticsearch.document_store import ElasticsearchDocumentStore


@component
@@ -19,8 +18,8 @@ class ElasticsearchBM25Retriever:
Usage example:
```python
from haystack import Document
- from elasticsearch_haystack.document_store import ElasticsearchDocumentStore
- from elasticsearch_haystack.bm25_retriever import ElasticsearchBM25Retriever
+ from haystack_integrations.document_stores.elasticsearch import ElasticsearchDocumentStore
+ from haystack_integrations.components.retrievers.elasticsearch import ElasticsearchBM25Retriever
document_store = ElasticsearchDocumentStore(hosts="http://localhost:9200")
retriever = ElasticsearchBM25Retriever(document_store=document_store)
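
Extending the updated docstring example into something runnable against a local Elasticsearch instance; the `run()` call shape (a single query string in, a "documents" list out) follows the usual Haystack retriever convention and is assumed here rather than shown in this hunk:

```python
# End-to-end sketch, assuming Elasticsearch is reachable at localhost:9200 and
# that the retriever follows the usual Haystack interface (query in, documents out).
from haystack import Document
from haystack_integrations.components.retrievers.elasticsearch import ElasticsearchBM25Retriever
from haystack_integrations.document_stores.elasticsearch import ElasticsearchDocumentStore

document_store = ElasticsearchDocumentStore(hosts="http://localhost:9200")
document_store.write_documents([
    Document(content="Elasticsearch stores and indexes documents."),
    Document(content="BM25 is a lexical ranking function."),
])

retriever = ElasticsearchBM25Retriever(document_store=document_store)
result = retriever.run(query="ranking function")
for doc in result["documents"]:
    print(doc.content, doc.score)
```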
@@ -5,8 +5,7 @@

from haystack import component, default_from_dict, default_to_dict
from haystack.dataclasses import Document

- from elasticsearch_haystack.document_store import ElasticsearchDocumentStore
+ from haystack_integrations.document_stores.elasticsearch.document_store import ElasticsearchDocumentStore


@component
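
For symmetry, a sketch of the embedding retriever, which expects a pre-computed query embedding rather than a text query; the `query_embedding` parameter name follows the common Haystack convention and is an assumption here, as is the vector dimension:

```python
# Sketch only: the embedding retriever takes a pre-computed embedding vector.
# In a real pipeline this vector would come from a text embedder component;
# the hard-coded vector and its dimension are placeholders.
from haystack_integrations.components.retrievers.elasticsearch import ElasticsearchEmbeddingRetriever
from haystack_integrations.document_stores.elasticsearch import ElasticsearchDocumentStore

document_store = ElasticsearchDocumentStore(hosts="http://localhost:9200")
retriever = ElasticsearchEmbeddingRetriever(document_store=document_store)

fake_query_embedding = [0.1] * 768  # placeholder vector
result = retriever.run(query_embedding=fake_query_embedding)
print(result["documents"])
```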
@@ -1,6 +1,6 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
- from elasticsearch_haystack.document_store import ElasticsearchDocumentStore
+ from .document_store import ElasticsearchDocumentStore

__all__ = ["ElasticsearchDocumentStore"]
@@ -8,14 +8,15 @@

# There are no import stubs for elastic_transport and elasticsearch so mypy fails
from elastic_transport import NodeConfig # type: ignore[import-not-found]
- from elasticsearch import Elasticsearch, helpers # type: ignore[import-not-found]
from haystack import default_from_dict, default_to_dict
from haystack.dataclasses import Document
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
from haystack.document_stores.types import DuplicatePolicy
from haystack.utils.filters import convert

- from elasticsearch_haystack.filters import _normalize_filters
+ from elasticsearch import Elasticsearch, helpers # type: ignore[import-not-found]
+
+ from .filters import _normalize_filters

logger = logging.getLogger(__name__)

17 changes: 8 additions & 9 deletions integrations/elasticsearch/tests/test_bm25_retriever.py
@@ -4,9 +4,8 @@
from unittest.mock import Mock, patch

from haystack.dataclasses import Document

- from elasticsearch_haystack.bm25_retriever import ElasticsearchBM25Retriever
- from elasticsearch_haystack.document_store import ElasticsearchDocumentStore
+ from haystack_integrations.components.retrievers.elasticsearch import ElasticsearchBM25Retriever
+ from haystack_integrations.document_stores.elasticsearch import ElasticsearchDocumentStore


def test_init_default():
@@ -18,21 +17,21 @@ def test_init_default():
assert not retriever._scale_score


@patch("elasticsearch_haystack.document_store.Elasticsearch")
@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch")
def test_to_dict(_mock_elasticsearch_client):
document_store = ElasticsearchDocumentStore(hosts="some fake host")
retriever = ElasticsearchBM25Retriever(document_store=document_store)
res = retriever.to_dict()
assert res == {
"type": "elasticsearch_haystack.bm25_retriever.ElasticsearchBM25Retriever",
"type": "haystack_integrations.components.retrievers.elasticsearch.bm25_retriever.ElasticsearchBM25Retriever",
"init_parameters": {
"document_store": {
"init_parameters": {
"hosts": "some fake host",
"index": "default",
"embedding_similarity_function": "cosine",
},
"type": "elasticsearch_haystack.document_store.ElasticsearchDocumentStore",
"type": "haystack_integrations.document_stores.elasticsearch.document_store.ElasticsearchDocumentStore",
},
"filters": {},
"fuzziness": "AUTO",
@@ -42,14 +41,14 @@
}


@patch("elasticsearch_haystack.document_store.Elasticsearch")
@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch")
def test_from_dict(_mock_elasticsearch_client):
data = {
"type": "elasticsearch_haystack.bm25_retriever.ElasticsearchBM25Retriever",
"type": "haystack_integrations.components.retrievers.elasticsearch.bm25_retriever.ElasticsearchBM25Retriever",
"init_parameters": {
"document_store": {
"init_parameters": {"hosts": "some fake host", "index": "default"},
"type": "elasticsearch_haystack.document_store.ElasticsearchDocumentStore",
"type": "haystack_integrations.document_stores.elasticsearch.document_store.ElasticsearchDocumentStore",
},
"filters": {},
"fuzziness": "AUTO",
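
The test churn above is mostly about `@patch` targets: `unittest.mock.patch` must name the module where `Elasticsearch` is looked up, so the move under `haystack_integrations` changes every target string. A minimal sketch of the pattern, mirroring the test above:

```python
# Sketch: patch targets must reference the new module path where Elasticsearch
# is imported, i.e. haystack_integrations.document_stores.elasticsearch.document_store.
from unittest.mock import patch

from haystack_integrations.document_stores.elasticsearch import ElasticsearchDocumentStore


@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch")
def check_to_dict(_mock_elasticsearch_client):
    document_store = ElasticsearchDocumentStore(hosts="some fake host")
    assert document_store.to_dict()["init_parameters"]["hosts"] == "some fake host"


check_to_dict()
```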
12 changes: 6 additions & 6 deletions integrations/elasticsearch/tests/test_document_store.py
@@ -12,10 +12,10 @@
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError
from haystack.document_stores.types import DuplicatePolicy
from haystack.testing.document_store import DocumentStoreBaseTests

- from elasticsearch_haystack.document_store import ElasticsearchDocumentStore
+ from haystack_integrations.document_stores.elasticsearch import ElasticsearchDocumentStore


+ @pytest.mark.integration
class TestDocumentStore(DocumentStoreBaseTests):
"""
Common test cases will be provided by `DocumentStoreBaseTests` but
@@ -67,23 +67,23 @@ def assert_documents_are_equal(self, received: List[Document], expected: List[Do

super().assert_documents_are_equal(received, expected)

@patch("elasticsearch_haystack.document_store.Elasticsearch")
@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch")
def test_to_dict(self, _mock_elasticsearch_client):
document_store = ElasticsearchDocumentStore(hosts="some hosts")
res = document_store.to_dict()
assert res == {
"type": "elasticsearch_haystack.document_store.ElasticsearchDocumentStore",
"type": "haystack_integrations.document_stores.elasticsearch.document_store.ElasticsearchDocumentStore",
"init_parameters": {
"hosts": "some hosts",
"index": "default",
"embedding_similarity_function": "cosine",
},
}

@patch("elasticsearch_haystack.document_store.Elasticsearch")
@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch")
def test_from_dict(self, _mock_elasticsearch_client):
data = {
"type": "elasticsearch_haystack.document_store.ElasticsearchDocumentStore",
"type": "haystack_integrations.document_stores.elasticsearch.document_store.ElasticsearchDocumentStore",
"init_parameters": {
"hosts": "some hosts",
"index": "default",
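
Because `to_dict()` now embeds the full `haystack_integrations` import path in the `type` field, `from_dict()` only round-trips when that path is importable. A short sketch of the round-trip exercised by the tests above (the mocked client keeps it offline):

```python
# Sketch of the serialization round-trip: the "type" field now carries the
# haystack_integrations path, and from_dict() rebuilds the store from it.
from unittest.mock import patch

from haystack_integrations.document_stores.elasticsearch import ElasticsearchDocumentStore

with patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch"):
    original = ElasticsearchDocumentStore(hosts="some hosts")
    data = original.to_dict()
    assert data["type"].startswith("haystack_integrations.document_stores.elasticsearch")
    restored = ElasticsearchDocumentStore.from_dict(data)
    assert restored.to_dict() == data
```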
19 changes: 10 additions & 9 deletions integrations/elasticsearch/tests/test_embedding_retriever.py
@@ -4,9 +4,8 @@
from unittest.mock import Mock, patch

from haystack.dataclasses import Document

- from elasticsearch_haystack.document_store import ElasticsearchDocumentStore
- from elasticsearch_haystack.embedding_retriever import ElasticsearchEmbeddingRetriever
+ from haystack_integrations.components.retrievers.elasticsearch import ElasticsearchEmbeddingRetriever
+ from haystack_integrations.document_stores.elasticsearch import ElasticsearchDocumentStore


def test_init_default():
@@ -18,21 +17,22 @@ def test_init_default():
assert retriever._num_candidates is None


@patch("elasticsearch_haystack.document_store.Elasticsearch")
@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch")
def test_to_dict(_mock_elasticsearch_client):
document_store = ElasticsearchDocumentStore(hosts="some fake host")
retriever = ElasticsearchEmbeddingRetriever(document_store=document_store)
res = retriever.to_dict()
t = "haystack_integrations.components.retrievers.elasticsearch.embedding_retriever.ElasticsearchEmbeddingRetriever"
assert res == {
"type": "elasticsearch_haystack.embedding_retriever.ElasticsearchEmbeddingRetriever",
"type": t,
"init_parameters": {
"document_store": {
"init_parameters": {
"hosts": "some fake host",
"index": "default",
"embedding_similarity_function": "cosine",
},
"type": "elasticsearch_haystack.document_store.ElasticsearchDocumentStore",
"type": "haystack_integrations.document_stores.elasticsearch.document_store.ElasticsearchDocumentStore",
},
"filters": {},
"top_k": 10,
@@ -41,14 +41,15 @@
}


@patch("elasticsearch_haystack.document_store.Elasticsearch")
@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch")
def test_from_dict(_mock_elasticsearch_client):
t = "haystack_integrations.components.retrievers.elasticsearch.embedding_retriever.ElasticsearchEmbeddingRetriever"
data = {
"type": "elasticsearch_haystack.embedding_retriever.ElasticsearchEmbeddingRetriever",
"type": t,
"init_parameters": {
"document_store": {
"init_parameters": {"hosts": "some fake host", "index": "default"},
"type": "elasticsearch_haystack.document_store.ElasticsearchDocumentStore",
"type": "haystack_integrations.document_stores.elasticsearch.document_store.ElasticsearchDocumentStore",
},
"filters": {},
"top_k": 10,
3 changes: 1 addition & 2 deletions integrations/elasticsearch/tests/test_filters.py
@@ -3,8 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
from haystack.errors import FilterError

- from elasticsearch_haystack.filters import _normalize_filters, _normalize_ranges
+ from haystack_integrations.document_stores.elasticsearch.filters import _normalize_filters, _normalize_ranges

filters_data = [
(