Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor!: improve the deserialization logic for components that use a Document Store #6466

Merged
merged 12 commits on Dec 4, 2023
23 changes: 19 additions & 4 deletions haystack/components/caching/url_cache_checker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from typing import List, Dict, Any

import importlib

import logging

from haystack import component, Document, default_from_dict, default_to_dict, DeserializationError
from haystack.document_stores import DocumentStore, document_store
from haystack.document_stores import DocumentStore


logger = logging.getLogger(__name__)


@component
Expand Down Expand Up @@ -34,9 +41,17 @@ def from_dict(cls, data: Dict[str, Any]) -> "UrlCacheChecker":
raise DeserializationError("Missing 'document_store' in serialization data")
if "type" not in init_params["document_store"]:
raise DeserializationError("Missing 'type' in document store's serialization data")
if init_params["document_store"]["type"] not in document_store.registry:
raise DeserializationError(f"DocumentStore of type '{init_params['document_store']['type']}' not found.")
docstore_class = document_store.registry[init_params["document_store"]["type"]]

try:
module_name, type_ = init_params["document_store"]["type"].rsplit(".", 1)
logger.debug("Trying to import %s", module_name)
module = importlib.import_module(module_name)
except (ImportError, DeserializationError) as e:
raise DeserializationError(
f"DocumentStore of type '{init_params['document_store']['type']}' not correctly imported"
) from e

docstore_class = getattr(module, type_)
docstore = docstore_class.from_dict(init_params["document_store"])

data["init_parameters"]["document_store"] = docstore
Expand Down
11 changes: 4 additions & 7 deletions haystack/components/retrievers/in_memory_bm25_retriever.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Dict, List, Any, Optional

from haystack import component, Document, default_to_dict, default_from_dict, DeserializationError
from haystack.document_stores import InMemoryDocumentStore, document_store
from haystack.document_stores import InMemoryDocumentStore


@component
Expand Down Expand Up @@ -67,12 +67,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "InMemoryBM25Retriever":
raise DeserializationError("Missing 'document_store' in serialization data")
if "type" not in init_params["document_store"]:
raise DeserializationError("Missing 'type' in document store's serialization data")
if init_params["document_store"]["type"] not in document_store.registry:
raise DeserializationError(f"DocumentStore type '{init_params['document_store']['type']}' not found")

docstore_class = document_store.registry[init_params["document_store"]["type"]]
docstore = docstore_class.from_dict(init_params["document_store"])
data["init_parameters"]["document_store"] = docstore
data["init_parameters"]["document_store"] = InMemoryDocumentStore.from_dict(
data["init_parameters"]["document_store"]
)
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
Expand Down
haystack/components/retrievers/in_memory_embedding_retriever.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Dict, List, Any, Optional

from haystack import component, Document, default_to_dict, default_from_dict, DeserializationError
from haystack.document_stores import InMemoryDocumentStore, document_store
from haystack.document_stores import InMemoryDocumentStore


@component
Expand Down Expand Up @@ -75,12 +75,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "InMemoryEmbeddingRetriever":
raise DeserializationError("Missing 'document_store' in serialization data")
if "type" not in init_params["document_store"]:
raise DeserializationError("Missing 'type' in document store's serialization data")
if init_params["document_store"]["type"] not in document_store.registry:
raise DeserializationError(f"DocumentStore type '{init_params['document_store']['type']}' not found")

docstore_class = document_store.registry[init_params["document_store"]["type"]]
docstore = docstore_class.from_dict(init_params["document_store"])
data["init_parameters"]["document_store"] = docstore
data["init_parameters"]["document_store"] = InMemoryDocumentStore.from_dict(
data["init_parameters"]["document_store"]
)
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
Expand Down
21 changes: 17 additions & 4 deletions haystack/components/writers/document_writer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from typing import List, Optional, Dict, Any

import importlib
import logging

from haystack import component, Document, default_from_dict, default_to_dict, DeserializationError
from haystack.document_stores import DocumentStore, DuplicatePolicy, document_store
from haystack.document_stores import DocumentStore, DuplicatePolicy

logger = logging.getLogger(__name__)


@component
Expand Down Expand Up @@ -41,9 +46,17 @@ def from_dict(cls, data: Dict[str, Any]) -> "DocumentWriter":
raise DeserializationError("Missing 'document_store' in serialization data")
if "type" not in init_params["document_store"]:
raise DeserializationError("Missing 'type' in document store's serialization data")
if init_params["document_store"]["type"] not in document_store.registry:
raise DeserializationError(f"DocumentStore of type '{init_params['document_store']['type']}' not found.")
docstore_class = document_store.registry[init_params["document_store"]["type"]]

try:
module_name, type_ = init_params["document_store"]["type"].rsplit(".", 1)
logger.debug("Trying to import %s", module_name)
module = importlib.import_module(module_name)
except (ImportError, DeserializationError) as e:
raise DeserializationError(
f"DocumentStore of type '{init_params['document_store']['type']}' not correctly imported"
) from e

docstore_class = getattr(module, type_)
docstore = docstore_class.from_dict(init_params["document_store"])

data["init_parameters"]["document_store"] = docstore
Expand Down
2 changes: 0 additions & 2 deletions haystack/document_stores/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from haystack.document_stores.protocol import DocumentStore, DuplicatePolicy
from haystack.document_stores.in_memory.document_store import InMemoryDocumentStore
from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError, MissingDocumentError
from haystack.document_stores.decorator import document_store

__all__ = [
"DocumentStore",
Expand All @@ -10,5 +9,4 @@
"DocumentStoreError",
"DuplicateDocumentError",
"MissingDocumentError",
"document_store",
]
39 changes: 0 additions & 39 deletions haystack/document_stores/decorator.py

This file was deleted.

2 changes: 0 additions & 2 deletions haystack/document_stores/in_memory/document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from tqdm.auto import tqdm

from haystack import default_from_dict, default_to_dict
from haystack.document_stores.decorator import document_store
from haystack.dataclasses import Document
from haystack.document_stores.protocol import DuplicatePolicy
from haystack.utils.filters import document_matches_filter, convert
Expand All @@ -27,7 +26,6 @@
DOT_PRODUCT_SCALING_FACTOR = 100


@document_store
class InMemoryDocumentStore:
"""
Stores data in-memory. It's ephemeral and cannot be saved to disk.
Expand Down
4 changes: 2 additions & 2 deletions haystack/testing/factory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Dict, Optional, Tuple, Type, List, Union

from haystack.dataclasses import Document
from haystack.document_stores import document_store, DocumentStore, DuplicatePolicy
from haystack.document_stores import DocumentStore, DuplicatePolicy
from haystack.core.component import component, Component
from haystack.core.serialization import default_to_dict, default_from_dict

Expand Down Expand Up @@ -117,7 +117,7 @@ def to_dict(self) -> Dict[str, Any]:
bases = (object,)

cls = type(name, bases, fields)
return document_store(cls)
return cls


def component_class(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
enhancements:
- |
Improve the deserialization logic for components that use a Document Store.
Remove the @document_store decorator and the registry of Document Stores.
12 changes: 7 additions & 5 deletions test/components/caching/test_url_cache_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,18 @@ def test_to_dict_with_custom_init_parameters(self):
}

def test_from_dict(self):
mocked_docstore_class = document_store_class("MockedDocumentStore")
data = {
"type": "haystack.components.caching.url_cache_checker.UrlCacheChecker",
"init_parameters": {
"document_store": {"type": "haystack.testing.factory.MockedDocumentStore", "init_parameters": {}},
"document_store": {
"type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
"init_parameters": {},
},
"url_field": "my_url_field",
},
}
component = UrlCacheChecker.from_dict(data)
assert isinstance(component.document_store, mocked_docstore_class)
assert isinstance(component.document_store, InMemoryDocumentStore)
assert component.url_field == "my_url_field"

def test_from_dict_without_docstore(self):
Expand All @@ -60,9 +62,9 @@ def test_from_dict_without_docstore_type(self):
def test_from_dict_nonexisting_docstore(self):
data = {
"type": "haystack.components.caching.url_cache_checker.UrlCacheChecker",
"init_parameters": {"document_store": {"type": "NonexistingDocumentStore", "init_parameters": {}}},
"init_parameters": {"document_store": {"type": "Nonexisting.DocumentStore", "init_parameters": {}}},
}
with pytest.raises(DeserializationError, match="DocumentStore of type 'NonexistingDocumentStore' not found."):
with pytest.raises(DeserializationError):
UrlCacheChecker.from_dict(data)

def test_run(self):
Expand Down
22 changes: 13 additions & 9 deletions test/components/retrievers/test_in_memory_bm25_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,29 +57,33 @@ def test_to_dict(self):
}

def test_to_dict_with_custom_init_parameters(self):
MyFakeStore = document_store_class("MyFakeStore", bases=(InMemoryDocumentStore,))
document_store = MyFakeStore()
document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
ds = InMemoryDocumentStore()
serialized_ds = ds.to_dict()

component = InMemoryBM25Retriever(
document_store=document_store, filters={"name": "test.txt"}, top_k=5, scale_score=True
document_store=InMemoryDocumentStore(), filters={"name": "test.txt"}, top_k=5, scale_score=True
)
data = component.to_dict()
assert data == {
"type": "haystack.components.retrievers.in_memory_bm25_retriever.InMemoryBM25Retriever",
"init_parameters": {
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
"document_store": serialized_ds,
"filters": {"name": "test.txt"},
"top_k": 5,
"scale_score": True,
},
}

#

def test_from_dict(self):
document_store_class("MyFakeStore", bases=(InMemoryDocumentStore,))
data = {
"type": "haystack.components.retrievers.in_memory_bm25_retriever.InMemoryBM25Retriever",
"init_parameters": {
"document_store": {"type": "haystack.testing.factory.MyFakeStore", "init_parameters": {}},
"document_store": {
"type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
"init_parameters": {},
},
"filters": {"name": "test.txt"},
"top_k": 5,
},
Expand All @@ -103,9 +107,9 @@ def test_from_dict_without_docstore_type(self):
def test_from_dict_nonexisting_docstore(self):
data = {
"type": "haystack.components.retrievers.in_memory_bm25_retriever.InMemoryBM25Retriever",
"init_parameters": {"document_store": {"type": "NonexistingDocstore", "init_parameters": {}}},
"init_parameters": {"document_store": {"type": "Nonexisting.Docstore", "init_parameters": {}}},
}
with pytest.raises(DeserializationError, match="DocumentStore type 'NonexistingDocstore' not found"):
with pytest.raises(DeserializationError):
InMemoryBM25Retriever.from_dict(data)

def test_retriever_valid_run(self, mock_docs):
Expand Down
test/components/retrievers/test_in_memory_embedding_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,13 @@ def test_to_dict_with_custom_init_parameters(self):
}

def test_from_dict(self):
document_store_class("MyFakeStore", bases=(InMemoryDocumentStore,))
data = {
"type": "haystack.components.retrievers.in_memory_embedding_retriever.InMemoryEmbeddingRetriever",
"init_parameters": {
"document_store": {"type": "haystack.testing.factory.MyFakeStore", "init_parameters": {}},
"document_store": {
"type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
"init_parameters": {},
},
"filters": {"name": "test.txt"},
"top_k": 5,
},
Expand All @@ -99,15 +101,15 @@ def test_from_dict_without_docstore_type(self):
"type": "haystack.components.retrievers.in_memory_embedding_retriever.InMemoryEmbeddingRetriever",
"init_parameters": {"document_store": {"init_parameters": {}}},
}
with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"):
with pytest.raises(DeserializationError):
InMemoryEmbeddingRetriever.from_dict(data)

def test_from_dict_nonexisting_docstore(self):
data = {
"type": "haystack.components.retrievers.in_memory_embedding_retriever.InMemoryEmbeddingRetriever",
"init_parameters": {"document_store": {"type": "NonexistingDocstore", "init_parameters": {}}},
"init_parameters": {"document_store": {"type": "Nonexisting.Docstore", "init_parameters": {}}},
}
with pytest.raises(DeserializationError, match="DocumentStore type 'NonexistingDocstore' not found"):
with pytest.raises(DeserializationError):
InMemoryEmbeddingRetriever.from_dict(data)

def test_valid_run(self):
Expand Down
12 changes: 7 additions & 5 deletions test/components/writers/test_document_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,18 @@ def test_to_dict_with_custom_init_parameters(self):
}

def test_from_dict(self):
mocked_docstore_class = document_store_class("MockedDocumentStore")
data = {
"type": "haystack.components.writers.document_writer.DocumentWriter",
"init_parameters": {
"document_store": {"type": "haystack.testing.factory.MockedDocumentStore", "init_parameters": {}},
"document_store": {
"type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
"init_parameters": {},
},
"policy": "SKIP",
},
}
component = DocumentWriter.from_dict(data)
assert isinstance(component.document_store, mocked_docstore_class)
assert isinstance(component.document_store, InMemoryDocumentStore)
assert component.policy == DuplicatePolicy.SKIP

def test_from_dict_without_docstore(self):
Expand All @@ -58,9 +60,9 @@ def test_from_dict_without_docstore_type(self):
def test_from_dict_nonexisting_docstore(self):
data = {
"type": "DocumentWriter",
"init_parameters": {"document_store": {"type": "NonexistingDocumentStore", "init_parameters": {}}},
"init_parameters": {"document_store": {"type": "Nonexisting.DocumentStore", "init_parameters": {}}},
}
with pytest.raises(DeserializationError, match="DocumentStore of type 'NonexistingDocumentStore' not found."):
with pytest.raises(DeserializationError):
DocumentWriter.from_dict(data)

def test_run(self):
Expand Down
6 changes: 0 additions & 6 deletions test/testing/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from haystack.dataclasses import Document
from haystack.testing.factory import document_store_class, component_class
from haystack.document_stores.decorator import document_store
from haystack.core.component import component


Expand All @@ -23,11 +22,6 @@ def test_document_store_from_dict():
assert isinstance(store, MyStore)


def test_document_store_class_is_registered():
MyStore = document_store_class("MyStore")
assert document_store.registry["haystack.testing.factory.MyStore"] == MyStore


def test_document_store_class_with_documents():
doc = Document(id="fake_id", content="This is a document")
MyStore = document_store_class("MyStore", documents=[doc])
Expand Down