diff --git a/pyproject.toml b/pyproject.toml index e7c064e..6f11de2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,6 @@ dependencies = [ "typing_extensions", "pymilvus", "milvus", - "farm-haystack" ] [project.urls] @@ -68,7 +67,7 @@ dependencies = [ "ruff>=0.0.243", ] [tool.hatch.envs.lint.scripts] -typing = "mypy --install-types --non-interactive {args:src/milvus_haystack tests}" +typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = [ "ruff {args:.}", "black --check --diff {args:.}", @@ -175,6 +174,10 @@ markers = [ [[tool.mypy.overrides]] module = [ "haystack.*", + "milvus_haystack.*", + "pymilvus.*", + "numpy", + "milvus", "pytest.*" ] ignore_missing_imports = true diff --git a/src/milvus_haystack/__init__.py b/src/milvus_haystack/__init__.py index 099e217..2e67592 100644 --- a/src/milvus_haystack/__init__.py +++ b/src/milvus_haystack/__init__.py @@ -1,10 +1,7 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from .document_store import MilvusDocumentStore -from .milvus_embedding_retriever import MilvusEmbeddingRetriever +from .document_store import MilvusDocumentStore # noqa: TID252 +from .milvus_embedding_retriever import MilvusEmbeddingRetriever # noqa: TID252 -__all__ = [ - "MilvusDocumentStore", - "MilvusEmbeddingRetriever" -] +__all__ = ["MilvusDocumentStore", "MilvusEmbeddingRetriever"] diff --git a/src/milvus_haystack/document_store.py b/src/milvus_haystack/document_store.py index f898ac1..c4c8813 100644 --- a/src/milvus_haystack/document_store.py +++ b/src/milvus_haystack/document_store.py @@ -1,10 +1,12 @@ import logging -from typing import List, Dict, Optional, Union, Any +from typing import Any, Dict, List, Optional, Union from uuid import uuid4 + from haystack import Document, default_from_dict, default_to_dict from haystack.document_stores.types import DuplicatePolicy from haystack.errors import FilterError from pymilvus import MilvusException + from milvus_haystack.filters import parse_filters logger = logging.getLogger(__name__) @@ -24,23 +26,23 @@ class MilvusDocumentStore: """ def __init__( - self, - collection_name: str = "HaystackCollection", - collection_description: str = "", - collection_properties: Optional[Dict[str, Any]] = None, - connection_args: Optional[Dict[str, Any]] = None, - consistency_level: str = "Session", - index_params: Optional[dict] = None, - search_params: Optional[dict] = None, - drop_old: Optional[bool] = False, - *, - primary_field: str = "id", - text_field: str = "text", - vector_field: str = "vector", - partition_key_field: Optional[str] = None, - partition_names: Optional[list] = None, - replica_number: int = 1, - timeout: Optional[float] = None, + self, + collection_name: str = "HaystackCollection", + collection_description: str = "", + collection_properties: Optional[Dict[str, Any]] = None, + connection_args: Optional[Dict[str, Any]] = None, + consistency_level: str = "Session", + index_params: Optional[dict] = None, + search_params: Optional[dict] = None, + drop_old: Optional[bool] = False, # noqa: FBT002 + *, + primary_field: str = "id", + text_field: str = "text", + vector_field: str = "vector", + partition_key_field: Optional[str] = None, + partition_names: Optional[list] = None, + replica_number: int = 1, + timeout: Optional[float] = None, ): """ Initialize the Milvus vector store. @@ -232,9 +234,7 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc output_fields=output_fields, ) except MilvusException as err: - logger.error( - "Failed to query documents with filters expr: %s", expr - ) + logger.error("Failed to query documents with filters expr: %s", expr) raise FilterError(err) from err docs = [self._parse_document(d) for d in res] return docs @@ -293,7 +293,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D return 0 # If the collection hasn't been initialized yet, perform all steps to do so - kwargs = {} + kwargs: Dict[str, Any] = {} if not isinstance(self.col, Collection): kwargs = {"embeddings": embeddings, "metas": metas} if self.partition_names: @@ -322,10 +322,9 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D vectors: list = insert_dict[self._vector_field] total_count = len(vectors) - ids: list[str] = [] - batch_size = 1000 - assert isinstance(self.col, Collection) + if not isinstance(self.col, Collection): + raise MilvusException(message="Collection is not initialized") for i in range(0, total_count, batch_size): # Grab end index end = min(i + batch_size, total_count) @@ -337,9 +336,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D res = self.col.insert(insert_list, timeout=None, **kwargs) ids.extend(res.primary_keys) except MilvusException as err: - logger.error( - "Failed to insert batch starting at entity: %s/%s", i, total_count - ) + logger.error("Failed to insert batch starting at entity: %s/%s", i, total_count) raise err self.col.flush() return len(ids) @@ -432,11 +429,11 @@ def _create_connection_alias(self, connection_args: dict) -> str: for con in connections.list_connections(): addr = connections.get_connection_addr(con[0]) if ( - con[1] - and ("address" in addr) - and (addr["address"] == given_address) - and ("user" in addr) - and (addr["user"] == tmp_user) + con[1] + and ("address" in addr) + and (addr["address"] == given_address) + and ("user" in addr) + and (addr["user"] == tmp_user) ): logger.debug("Using previous connection: %s", con[0]) return con[0] @@ -452,12 +449,12 @@ def _create_connection_alias(self, connection_args: dict) -> str: raise err def _init( - self, - embeddings: Optional[List] = None, - metas: Optional[List[Dict]] = None, - partition_names: Optional[List] = None, - replica_number: int = 1, - timeout: Optional[float] = None, + self, + embeddings: Optional[List] = None, + metas: Optional[List[Dict]] = None, + partition_names: Optional[List] = None, + replica_number: int = 1, + timeout: Optional[float] = None, ) -> None: if embeddings is not None: self._create_collection(embeddings, metas) @@ -470,9 +467,7 @@ def _init( timeout=timeout, ) - def _create_collection( - self, embeddings: list, metas: Optional[List[Dict]] = None - ) -> None: + def _create_collection(self, embeddings: list, metas: Optional[List[Dict]] = None) -> None: from pymilvus import ( Collection, CollectionSchema, @@ -498,26 +493,16 @@ def _create_collection( raise ValueError(err_msg) # Datatype is a string/varchar equivalent elif dtype == DataType.VARCHAR: - fields.append( - FieldSchema(key, DataType.VARCHAR, max_length=65_535) - ) + fields.append(FieldSchema(key, DataType.VARCHAR, max_length=65_535)) else: fields.append(FieldSchema(key, dtype)) # Create the text field - fields.append( - FieldSchema(self._text_field, DataType.VARCHAR, max_length=65_535) - ) + fields.append(FieldSchema(self._text_field, DataType.VARCHAR, max_length=65_535)) # Create the primary key field - fields.append( - FieldSchema( - self._primary_field, DataType.VARCHAR, is_primary=True, max_length=65_535 - ) - ) + fields.append(FieldSchema(self._primary_field, DataType.VARCHAR, is_primary=True, max_length=65_535)) # Create the vector field, supports binary or float vectors - fields.append( - FieldSchema(self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim) - ) + fields.append(FieldSchema(self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim)) # Create the schema for the collection schema = CollectionSchema( @@ -538,9 +523,7 @@ def _create_collection( if self.collection_properties is not None: self.col.set_properties(self.collection_properties) except MilvusException as err: - logger.error( - "Failed to create collection: %s error: %s", self.collection_name, err - ) + logger.error("Failed to create collection: %s error: %s", self.collection_name, err) raise err def _extract_fields(self) -> None: @@ -592,9 +575,7 @@ def _create_index(self) -> None: ) except MilvusException as err: - logger.error( - "Failed to create an index on collection: %s", self.collection_name - ) + logger.error("Failed to create an index on collection: %s", self.collection_name) raise err def _create_search_params(self) -> None: @@ -620,20 +601,19 @@ def _get_index(self) -> Optional[Dict[str, Any]]: return None def _load( - self, - partition_names: Optional[list] = None, - replica_number: int = 1, - timeout: Optional[float] = None, + self, + partition_names: Optional[list] = None, + replica_number: int = 1, + timeout: Optional[float] = None, ) -> None: """Load the collection if available.""" from pymilvus import Collection, utility from pymilvus.client.types import LoadState if ( - isinstance(self.col, Collection) - and self._get_index() is not None - and utility.load_state(self.collection_name, using=self.alias) - == LoadState.NotLoad + isinstance(self.col, Collection) + and self._get_index() is not None + and utility.load_state(self.collection_name, using=self.alias) == LoadState.NotLoad ): self.col.load( partition_names=partition_names, @@ -642,11 +622,8 @@ def _load( ) def _embedding_retrieval( - self, - query_embedding: List[float], - filters: Optional[Dict[str, Any]] = None, - top_k: int = 10 - ): + self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: int = 10 + ) -> List[Document]: if self.col is None: logger.debug("No existing collection to search.") return [] diff --git a/src/milvus_haystack/filters.py b/src/milvus_haystack/filters.py index 6aca2e5..5f527d9 100644 --- a/src/milvus_haystack/filters.py +++ b/src/milvus_haystack/filters.py @@ -1,4 +1,5 @@ -from typing import Dict, Union, Any +from typing import Any, Dict, Union + from haystack.errors import FilterError LOGIC_OPERATORS = [ @@ -55,10 +56,12 @@ def _parse_comparison(filters: Dict[str, Any]) -> str: def _assert_comparison_filter(filters: Dict[str, Any]): - assert "operator" in filters, "operator must be specified in filters" - assert "field" in filters, "field must be specified in filters" - assert "value" in filters, "value must be specified in filters" - assert filters["operator"] in COMPARISON_OPERATORS, FilterError("operator must be one of: %s" % LOGIC_OPERATORS) + assert "operator" in filters, "operator must be specified in filters" # noqa: S101 + assert "field" in filters, "field must be specified in filters" # noqa: S101 + assert "value" in filters, "value must be specified in filters" # noqa: S101 + assert filters["operator"] in COMPARISON_OPERATORS, FilterError( # noqa: S101 + "operator must be one of: %s" % LOGIC_OPERATORS + ) def _parse_logic(filters: Dict[str, Any]) -> str: @@ -80,7 +83,7 @@ def _parse_logic(filters: Dict[str, Any]) -> str: def _assert_logic_filter(filters: Dict[str, Any]): - assert "operator" in filters, "operator must be specified in filters" - assert "conditions" in filters, "conditions must be specified in filters" - assert filters["operator"] in LOGIC_OPERATORS, "operator must be one of: %s" % LOGIC_OPERATORS - assert isinstance(filters["conditions"], list), "conditions must be a list" + assert "operator" in filters, "operator must be specified in filters" # noqa: S101 + assert "conditions" in filters, "conditions must be specified in filters" # noqa: S101 + assert filters["operator"] in LOGIC_OPERATORS, "operator must be one of: %s" % LOGIC_OPERATORS # noqa: S101 + assert isinstance(filters["conditions"], list), "conditions must be a list" # noqa: S101 diff --git a/src/milvus_haystack/milvus_embedding_retriever.py b/src/milvus_haystack/milvus_embedding_retriever.py index 0257167..5817c33 100644 --- a/src/milvus_haystack/milvus_embedding_retriever.py +++ b/src/milvus_haystack/milvus_embedding_retriever.py @@ -1,5 +1,7 @@ -from typing import Any, Dict, Optional, List -from haystack import component, Document +from typing import Any, Dict, List, Optional + +from haystack import Document, component + from milvus_haystack import MilvusDocumentStore @@ -22,7 +24,7 @@ def __init__(self, document_store: MilvusDocumentStore, filters: Optional[Dict[s self.document_store = document_store @component.output_types(documents=List[Document]) - def run(self, query_embedding: List[float]): + def run(self, query_embedding: List[float]) -> Dict[str, List[Document]]: """ Retrieve documents from the `MilvusDocumentStore`, based on their dense embeddings. diff --git a/tests/test_document_store.py b/tests/test_document_store.py index b1bebff..c5222f3 100644 --- a/tests/test_document_store.py +++ b/tests/test_document_store.py @@ -1,27 +1,34 @@ +import logging +import time + import pytest from haystack import Document from haystack.document_stores.types import DocumentStore -from haystack.testing.document_store import CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest +from haystack.testing.document_store import CountDocumentsTest, DeleteDocumentsTest, WriteDocumentsTest + from src.milvus_haystack import MilvusDocumentStore +logger = logging.getLogger(__name__) + class TestDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest): from milvus import MilvusServer + milvus_server = MilvusServer() milvus_server.set_base_dir("test_milvus_base") milvus_server.listen_port = 19530 try: milvus_server.stop() - except: - pass + except Exception as err: + logger.debug("Can not stop Milvus server. %s", err) try: milvus_server.cleanup() - except: - pass + except Exception as err: + logger.debug("Can not cleanup Milvus. %s", err) try: milvus_server.start() - except: - pass + except Exception as err: + logger.debug("Can not start Milvus server. %s", err) @pytest.fixture def document_store(self) -> MilvusDocumentStore: @@ -42,17 +49,26 @@ def test_write_documents(self, document_store: DocumentStore): ) assert document_store.count_documents() == 3 + def test_delete_documents(self, document_store: DocumentStore): + """ + Test delete_documents() normal behaviour. + """ + doc = Document(content="test doc") + document_store.write_documents([doc]) + assert document_store.count_documents() == 1 + + document_store.delete_documents([doc.id]) + time.sleep(1) + assert document_store.count_documents() == 0 + @pytest.mark.skip(reason="Milvus does not currently check if entity primary keys are duplicates") - def test_write_documents_duplicate_fail(self, document_store: DocumentStore): - ... + def test_write_documents_duplicate_fail(self, document_store: DocumentStore): ... @pytest.mark.skip(reason="Milvus does not currently check if entity primary keys are duplicates") - def test_write_documents_duplicate_skip(self, document_store: DocumentStore): - ... + def test_write_documents_duplicate_skip(self, document_store: DocumentStore): ... @pytest.mark.skip(reason="Milvus does not currently check if entity primary keys are duplicates") - def test_write_documents_duplicate_overwrite(self, document_store: DocumentStore): - ... + def test_write_documents_duplicate_overwrite(self, document_store: DocumentStore): ... def test_to_and_from_dict(self, document_store: MilvusDocumentStore): document_store_dict = document_store.to_dict() @@ -62,13 +78,7 @@ def test_to_and_from_dict(self, document_store: MilvusDocumentStore): "collection_name": "HaystackCollection", "collection_description": "", "collection_properties": None, - "connection_args": { - "host": "localhost", - "port": "19530", - "user": "", - "password": "", - "secure": False - }, + "connection_args": {"host": "localhost", "port": "19530", "user": "", "password": "", "secure": False}, "consistency_level": "Session", "index_params": None, "search_params": None, @@ -79,8 +89,8 @@ def test_to_and_from_dict(self, document_store: MilvusDocumentStore): "partition_key_field": None, "partition_names": None, "replica_number": 1, - "timeout": None - } + "timeout": None, + }, } assert document_store_dict == expected_dict reconstructed_document_store = MilvusDocumentStore.from_dict(document_store_dict) diff --git a/tests/test_embedding_retriever.py b/tests/test_embedding_retriever.py index ba6daf0..4d898f0 100644 --- a/tests/test_embedding_retriever.py +++ b/tests/test_embedding_retriever.py @@ -1,26 +1,32 @@ +import logging + import pytest from haystack import Document + from src.milvus_haystack import MilvusDocumentStore from src.milvus_haystack.milvus_embedding_retriever import MilvusEmbeddingRetriever +logger = logging.getLogger(__name__) + class TestMilvusEmbeddingTests: from milvus import MilvusServer + milvus_server = MilvusServer() milvus_server.set_base_dir("test_milvus_base") milvus_server.listen_port = 19530 try: milvus_server.stop() - except: - pass + except Exception as err: + logger.debug("Can not stop Milvus server. %s", err) try: milvus_server.cleanup() - except: - pass + except Exception as err: + logger.debug("Can not cleanup Milvus. %s", err) try: milvus_server.start() - except: - pass + except Exception as err: + logger.debug("Can not start Milvus server. %s", err) @pytest.fixture def document_store(self) -> MilvusDocumentStore: @@ -50,7 +56,9 @@ def test_run(self, document_store: MilvusDocumentStore): ) documents.append(doc) document_store.write_documents(documents) - retriever = MilvusEmbeddingRetriever(document_store, ) + retriever = MilvusEmbeddingRetriever( + document_store, + ) query_embedding = [-10.0] * 128 res = retriever.run(query_embedding) - assert res["documents"] == doc + assert res["documents"] == documents diff --git a/tests/test_filters.py b/tests/test_filters.py index 5484b2e..b7a82da 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -1,28 +1,34 @@ +import logging +from typing import List + import numpy as np import pytest -from typing import List from haystack import Document from haystack.testing.document_store import FilterDocumentsTest, _random_embeddings + from src.milvus_haystack import MilvusDocumentStore +logger = logging.getLogger(__name__) + class TestMilvusFilters(FilterDocumentsTest): from milvus import MilvusServer + milvus_server = MilvusServer() - milvus_server.set_base_dir('test_milvus_base') + milvus_server.set_base_dir("test_milvus_base") milvus_server.listen_port = 19530 try: milvus_server.stop() - except: - pass + except Exception as err: + logger.debug("Can not stop Milvus server. %s", err) try: milvus_server.cleanup() - except: - pass + except Exception as err: + logger.debug("Can not cleanup Milvus. %s", err) try: milvus_server.start() - except: - pass + except Exception as err: + logger.debug("Can not start Milvus server. %s", err) @pytest.fixture def filterable_docs(self) -> List[Document]: @@ -103,65 +109,49 @@ def assert_documents_are_equal(self, received: List[Document], expected: List[Do assert np.allclose(np.array(r.embedding), np.array(e.embedding), atol=1e-4) @pytest.mark.skip(reason="Milvus doesn't support comparison with dataframe") - def test_comparison_equal_with_dataframe(self, document_store, filterable_docs): - ... + def test_comparison_equal_with_dataframe(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Milvus doesn't support comparison with dataframe") - def test_comparison_not_equal_with_dataframe(self, document_store, filterable_docs): - ... + def test_comparison_not_equal_with_dataframe(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Milvus doesn't support comparison with dataframe") - def test_comparison_greater_than_with_dataframe(self, document_store, filterable_docs): - ... + def test_comparison_greater_than_with_dataframe(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Milvus doesn't support comparison with dataframe") - def test_comparison_greater_than_equal_with_dataframe(self, document_store, filterable_docs): - ... + def test_comparison_greater_than_equal_with_dataframe(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Milvus doesn't support comparison with dataframe") - def test_comparison_less_than_with_dataframe(self, document_store, filterable_docs): - ... + def test_comparison_less_than_with_dataframe(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Milvus doesn't support comparison with dataframe") - def test_comparison_less_than_equal_with_dataframe(self, document_store, filterable_docs): - ... + def test_comparison_less_than_equal_with_dataframe(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Milvus doesn't support comparison with Dates") - def test_comparison_greater_than_with_iso_date(self, document_store, filterable_docs): - ... + def test_comparison_greater_than_with_iso_date(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Milvus doesn't support comparison with Dates") - def test_comparison_greater_than_equal_with_iso_date(self, document_store, filterable_docs): - ... + def test_comparison_greater_than_equal_with_iso_date(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Milvus doesn't support comparison with Dates") - def test_comparison_less_than_with_iso_date(self, document_store, filterable_docs): - ... + def test_comparison_less_than_with_iso_date(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Milvus doesn't support comparison with Dates") - def test_comparison_less_than_equal_with_iso_date(self, document_store, filterable_docs): - ... + def test_comparison_less_than_equal_with_iso_date(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Milvus doesn't support comparison with None") - def test_comparison_equal_with_none(self, document_store, filterable_docs): - ... + def test_comparison_equal_with_none(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Milvus doesn't support comparison with None") - def test_comparison_not_equal_with_none(self, document_store, filterable_docs): - ... + def test_comparison_not_equal_with_none(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Milvus doesn't support comparison with None") - def test_comparison_greater_than_with_none(self, document_store, filterable_docs): - ... + def test_comparison_greater_than_with_none(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Milvus doesn't support comparison with None") - def test_comparison_greater_than_equal_with_none(self, document_store, filterable_docs): - ... + def test_comparison_greater_than_equal_with_none(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Milvus doesn't support comparison with None") - def test_comparison_less_than_with_none(self, document_store, filterable_docs): - ... + def test_comparison_less_than_with_none(self, document_store, filterable_docs): ... @pytest.mark.skip(reason="Milvus doesn't support comparison with None") - def test_comparison_less_than_equal_with_none(self, document_store, filterable_docs): - ... + def test_comparison_less_than_equal_with_none(self, document_store, filterable_docs): ...