diff --git a/src/milvus_haystack/milvus_embedding_retriever.py b/src/milvus_haystack/milvus_embedding_retriever.py index 5817c33..d8611ac 100644 --- a/src/milvus_haystack/milvus_embedding_retriever.py +++ b/src/milvus_haystack/milvus_embedding_retriever.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Optional -from haystack import Document, component +from haystack import DeserializationError, Document, component, default_from_dict, default_to_dict from milvus_haystack import MilvusDocumentStore @@ -23,6 +23,35 @@ def __init__(self, document_store: MilvusDocumentStore, filters: Optional[Dict[s self.top_k = top_k self.document_store = document_store + def to_dict(self) -> Dict[str, Any]: + """ + Returns a dictionary representation of the retriever component. + + :returns: + A dictionary representation of the retriever component. + """ + return default_to_dict( + self, document_store=self.document_store.to_dict(), filters=self.filters, top_k=self.top_k + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "MilvusEmbeddingRetriever": + """ + Creates a new retriever from a dictionary. + + :param data: The dictionary to use to create the retriever. + :return: A new retriever. + """ + init_params = data.get("init_parameters", {}) + if "document_store" not in init_params: + err_msg = "Missing 'document_store' in serialization data" + raise DeserializationError(err_msg) + + docstore = MilvusDocumentStore.from_dict(init_params["document_store"]) + data["init_parameters"]["document_store"] = docstore + + return default_from_dict(cls, data) + @component.output_types(documents=List[Document]) def run(self, query_embedding: List[float]) -> Dict[str, List[Document]]: """ diff --git a/tests/test_embedding_retriever.py b/tests/test_embedding_retriever.py index 0c88538..f41b482 100644 --- a/tests/test_embedding_retriever.py +++ b/tests/test_embedding_retriever.py @@ -44,3 +44,81 @@ def test_run(self, document_store: MilvusDocumentStore): query_embedding = [-10.0] * 128 res = retriever.run(query_embedding) assert res["documents"] == documents + + def test_to_dict(self, document_store: MilvusDocumentStore): + expected_dict = { + "type": "src.milvus_haystack.document_store.MilvusDocumentStore", + "init_parameters": { + "collection_name": "HaystackCollection", + "collection_description": "", + "collection_properties": None, + "connection_args": {"host": "localhost", "port": "19530", "user": "", "password": "", "secure": False}, + "consistency_level": "Session", + "index_params": None, + "search_params": None, + "drop_old": True, + "primary_field": "id", + "text_field": "text", + "vector_field": "vector", + "partition_key_field": None, + "partition_names": None, + "replica_number": 1, + "timeout": None, + }, + } + retriever = MilvusEmbeddingRetriever(document_store) + result = retriever.to_dict() + + assert result["type"] == "src.milvus_haystack.milvus_embedding_retriever.MilvusEmbeddingRetriever" + assert result["init_parameters"]["document_store"] == expected_dict + + def test_from_dict(self, document_store: MilvusDocumentStore): + retriever_dict = { + "type": "src.milvus_haystack.milvus_embedding_retriever.MilvusEmbeddingRetriever", + "init_parameters": { + "document_store": { + "type": "milvus_haystack.document_store.MilvusDocumentStore", + "init_parameters": { + "collection_name": "HaystackCollection", + "collection_description": "", + "collection_properties": None, + "connection_args": { + "host": "localhost", + "port": "19530", + "user": "", + "password": "", + "secure": False, + }, + "consistency_level": "Session", + "index_params": None, + "search_params": None, + "drop_old": True, + "primary_field": "id", + "text_field": "text", + "vector_field": "vector", + "partition_key_field": None, + "partition_names": None, + "replica_number": 1, + "timeout": None, + }, + }, + "filters": None, + "top_k": 10, + }, + } + + retriever = MilvusEmbeddingRetriever(document_store) + + reconstructed_retriever = MilvusEmbeddingRetriever.from_dict(retriever_dict) + for field in vars(reconstructed_retriever): + if field.startswith("__"): + continue + elif field == "document_store": + for doc_store_field in vars(document_store): + if doc_store_field.startswith("__"): + continue + assert getattr(reconstructed_retriever.document_store, doc_store_field) == getattr( + document_store, doc_store_field + ) + else: + assert getattr(reconstructed_retriever, field) == getattr(retriever, field)