diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..c5b101f
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,17 @@
+# Contributing to Haystack
+
+## Contribute code
+
+### Run code quality checks locally
+
+Install [ruff](https://github.com/astral-sh/ruff) and [hatch](https://github.com/pypa/hatch), and update them to the latest versions.
+
+To check your code style against the linting rules, run:
+```sh
+hatch run lint:all
+```
+
+If the linters report any errors, fix them before checking in your code:
+```sh
+hatch run lint:fmt
+```
diff --git a/README.md b/README.md
index 0ac5397..cb2d069 100644
--- a/README.md
+++ b/README.md
@@ -11,10 +11,6 @@ pip install --upgrade pymilvus milvus-haystack
 ## Usage
 
-By default, if you install the latest version of pymilvus, you don't need to start the milvus service manually.
-Optionally, you
-can [start the Milvus service by docker](https://milvus.io/docs/install_standalone-docker.md#Start-Milvus).
-
 Use the `MilvusDocumentStore` in a Haystack pipeline as a quick start.
 
 ```python
@@ -22,10 +18,7 @@ from haystack import Document
 from milvus_haystack import MilvusDocumentStore
 
 document_store = MilvusDocumentStore(
-    # If you have installed the latest version of pymilvus with milvus lite, you can use a local path as the uri without starting the milvus service.
     connection_args={"uri": "./milvus.db"},
-    # Or, if you have started the milvus standalone service by docker, you can use the specified uri to connect to the service.
-    # connection_args={"uri": "http://localhost:19530"},
     drop_old=True,
 )
 documents = [Document(
@@ -36,6 +29,38 @@ documents = [Document(
 document_store.write_documents(documents)
 print(document_store.count_documents()) # 1
 ```
+### Different ways to connect to Milvus
+
+- For [Milvus Lite](https://milvus.io/docs/milvus_lite.md), the most convenient option, simply set the uri to a local file path.
+```python
+document_store = MilvusDocumentStore(
+    connection_args={"uri": "./milvus.db"},
+    drop_old=True,
+)
+```
+
+- For a Milvus server on [Docker or Kubernetes](https://milvus.io/docs/quickstart.md), recommended when you are dealing with large-scale data, start the Milvus service first and then connect to it with the server uri.
+```python
+document_store = MilvusDocumentStore(
+    connection_args={"uri": "http://localhost:19530"},
+    drop_old=True,
+)
+```
+
+- For [Zilliz Cloud](https://zilliz.com/cloud), the fully managed cloud service for Milvus, set the uri and token to the [Public Endpoint and API key](https://docs.zilliz.com/docs/on-zilliz-cloud-console#free-cluster-details) of your Zilliz Cloud cluster.
+```python
+from haystack.utils import Secret
+document_store = MilvusDocumentStore(
+    connection_args={
+        "uri": "https://in03-ba4234asae.api.gcp-us-west1.zillizcloud.com",  # Your Public Endpoint
+        "token": Secret.from_env_var("ZILLIZ_CLOUD_API_KEY"),  # API key; we recommend loading the token from an environment variable via the Secret class for security.
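+        # Assumption for this example: the ZILLIZ_CLOUD_API_KEY environment variable is set
+        # (e.g. `export ZILLIZ_CLOUD_API_KEY=<your-api-key>`) so that Secret.from_env_var can resolve the token at runtime.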
+ "secure": True + }, + drop_old=True, +) +``` + + ## Dive deep usage @@ -45,12 +70,6 @@ Prepare an OpenAI API key and set it as an environment variable: export OPENAI_API_KEY= ``` -Here are the ways to - -- Create the indexing Pipeline -- Create the retrieval pipeline -- Create the RAG pipeline - ### Create the indexing Pipeline and index some documents ```python @@ -70,10 +89,7 @@ current_file_path = os.path.abspath(__file__) file_paths = [current_file_path] # You can replace it with your own file paths. document_store = MilvusDocumentStore( - # If you have installed the latest version of pymilvus with milvus lite, you can use a local path as the uri without starting the milvus service. connection_args={"uri": "./milvus.db"}, - # Or, if you have started the milvus standalone service by docker, you can use the specified uri to connect to the service. - # connection_args={"uri": "http://localhost:19530"}, drop_old=True, ) indexing_pipeline = Pipeline() @@ -143,6 +159,121 @@ print('RAG answer:', results["generator"]["replies"][0]) ``` +## Sparse Retrieval +```python +from haystack import Document, Pipeline +from haystack.components.writers import DocumentWriter +from haystack.document_stores.types import DuplicatePolicy +from haystack_integrations.components.embedders.fastembed import ( + FastembedSparseDocumentEmbedder, + FastembedSparseTextEmbedder, +) + +from milvus_haystack import MilvusDocumentStore, MilvusSparseEmbeddingRetriever + +document_store = MilvusDocumentStore( + connection_args={"uri": "./milvus.db"}, + sparse_vector_field="sparse_vector", # Specify a name of the sparse vector field to enable sparse retrieval. + drop_old=True, +) + +documents = [ + Document(content="My name is Wolfgang and I live in Berlin"), + Document(content="I saw a black horse running"), + Document(content="Germany has many big cities"), + Document(content="fastembed is supported by and maintained by Milvus."), +] + +sparse_document_embedder = FastembedSparseDocumentEmbedder() +writer = DocumentWriter(document_store=document_store, policy=DuplicatePolicy.OVERWRITE) + +indexing_pipeline = Pipeline() +indexing_pipeline.add_component("sparse_document_embedder", sparse_document_embedder) +indexing_pipeline.add_component("writer", writer) +indexing_pipeline.connect("sparse_document_embedder", "writer") + +indexing_pipeline.run({"sparse_document_embedder": {"documents": documents}}) + +query_pipeline = Pipeline() +query_pipeline.add_component("sparse_text_embedder", FastembedSparseTextEmbedder()) +query_pipeline.add_component("sparse_retriever", MilvusSparseEmbeddingRetriever(document_store=document_store)) +query_pipeline.connect("sparse_text_embedder.sparse_embedding", "sparse_retriever.query_sparse_embedding") + +query = "Who supports fastembed?" 
+ +result = query_pipeline.run({"sparse_text_embedder": {"text": query}}) + +print(result["sparse_retriever"]["documents"][0]) # noqa: T201 + +# Document(id=..., content: 'fastembed is supported by and maintained by Milvus.', sparse_embedding: vector with 48 non-zero elements) +``` + +## Hybrid Retrieval + +```python +from haystack import Document, Pipeline +from haystack.components.embedders import OpenAIDocumentEmbedder, OpenAITextEmbedder +from haystack.components.writers import DocumentWriter +from haystack.document_stores.types import DuplicatePolicy +from haystack_integrations.components.embedders.fastembed import ( + FastembedSparseDocumentEmbedder, + FastembedSparseTextEmbedder, +) + +from milvus_haystack import MilvusDocumentStore, MilvusHybridRetriever + +document_store = MilvusDocumentStore( + connection_args={"uri": "./milvus.db"}, + drop_old=True, + sparse_vector_field="sparse_vector", # Specify a name of the sparse vector field to enable hybrid retrieval. +) + +documents = [ + Document(content="My name is Wolfgang and I live in Berlin"), + Document(content="I saw a black horse running"), + Document(content="Germany has many big cities"), + Document(content="fastembed is supported by and maintained by Milvus."), +] + +writer = DocumentWriter(document_store=document_store, policy=DuplicatePolicy.OVERWRITE) + +indexing_pipeline = Pipeline() +indexing_pipeline.add_component("sparse_doc_embedder", FastembedSparseDocumentEmbedder()) +indexing_pipeline.add_component("dense_doc_embedder", OpenAIDocumentEmbedder()) +indexing_pipeline.add_component("writer", writer) +indexing_pipeline.connect("sparse_doc_embedder", "dense_doc_embedder") +indexing_pipeline.connect("dense_doc_embedder", "writer") + +indexing_pipeline.run({"sparse_doc_embedder": {"documents": documents}}) + +querying_pipeline = Pipeline() +querying_pipeline.add_component("sparse_text_embedder", + FastembedSparseTextEmbedder(model="prithvida/Splade_PP_en_v1")) + +querying_pipeline.add_component("dense_text_embedder", OpenAITextEmbedder()) +querying_pipeline.add_component( + "retriever", + MilvusHybridRetriever( + document_store=document_store, + # reranker=WeightedRanker(0.5, 0.5), # Default is RRFRanker() + ) +) + +querying_pipeline.connect("sparse_text_embedder.sparse_embedding", "retriever.query_sparse_embedding") +querying_pipeline.connect("dense_text_embedder.embedding", "retriever.query_embedding") + +question = "Who supports fastembed?" + +results = querying_pipeline.run( + {"dense_text_embedder": {"text": question}, + "sparse_text_embedder": {"text": question}} +) + +print(results["retriever"]["documents"][0]) + +# Document(id=..., content: 'fastembed is supported by and maintained by Milvus.', embedding: vector of size 1536, sparse_embedding: vector with 48 non-zero elements) + +``` ## License `milvus-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. 
\ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index ee62cd7..b2c88ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,12 +68,12 @@ dependencies = [ [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = [ - "ruff {args:.}", + "ruff check {args:.}", "black --check --diff {args:.}", ] fmt = [ "black {args:.}", - "ruff --fix {args:.}", + "ruff check --fix {args:.}", "style", ] all = [ diff --git a/src/milvus_haystack/__about__.py b/src/milvus_haystack/__about__.py index 4dcc8e0..133d429 100644 --- a/src/milvus_haystack/__about__.py +++ b/src/milvus_haystack/__about__.py @@ -1,4 +1,4 @@ # SPDX-FileCopyrightText: 2023-present Tuana Celik # # SPDX-License-Identifier: Apache-2.0 -__version__ = "0.0.8" +__version__ = "0.0.9" diff --git a/src/milvus_haystack/__init__.py b/src/milvus_haystack/__init__.py index 2e67592..f2b42e4 100644 --- a/src/milvus_haystack/__init__.py +++ b/src/milvus_haystack/__init__.py @@ -2,6 +2,10 @@ # # SPDX-License-Identifier: Apache-2.0 from .document_store import MilvusDocumentStore # noqa: TID252 -from .milvus_embedding_retriever import MilvusEmbeddingRetriever # noqa: TID252 +from .milvus_embedding_retriever import ( # noqa: TID252 + MilvusEmbeddingRetriever, + MilvusHybridRetriever, + MilvusSparseEmbeddingRetriever, +) -__all__ = ["MilvusDocumentStore", "MilvusEmbeddingRetriever"] +__all__ = ["MilvusDocumentStore", "MilvusEmbeddingRetriever", "MilvusSparseEmbeddingRetriever", "MilvusHybridRetriever"] diff --git a/src/milvus_haystack/document_store.py b/src/milvus_haystack/document_store.py index 9787747..c930552 100644 --- a/src/milvus_haystack/document_store.py +++ b/src/milvus_haystack/document_store.py @@ -1,16 +1,27 @@ +import copy import logging +from copy import deepcopy from typing import Any, Dict, List, Optional, Union from uuid import uuid4 from haystack import Document, default_from_dict, default_to_dict +from haystack.dataclasses.sparse_embedding import SparseEmbedding +from haystack.document_stores.errors import DocumentStoreError from haystack.document_stores.types import DuplicatePolicy from haystack.errors import FilterError -from pymilvus import MilvusException +from haystack.utils import Secret, deserialize_secrets_inplace +from pymilvus import AnnSearchRequest, MilvusException, RRFRanker +from pymilvus.client.abstract import BaseRanker from milvus_haystack.filters import parse_filters logger = logging.getLogger(__name__) + +class MilvusStoreError(DocumentStoreError): + pass + + DEFAULT_MILVUS_CONNECTION = { "host": "localhost", "port": "19530", @@ -19,6 +30,8 @@ "secure": False, } +MAX_LIMIT_SIZE = 10_000 + class MilvusDocumentStore: """ @@ -39,6 +52,9 @@ def __init__( primary_field: str = "id", text_field: str = "text", vector_field: str = "vector", + sparse_vector_field: Optional[str] = None, + sparse_index_params: Optional[dict] = None, + sparse_search_params: Optional[dict] = None, partition_key_field: Optional[str] = None, partition_names: Optional[list] = None, replica_number: int = 1, @@ -46,6 +62,8 @@ def __init__( ): """ Initialize the Milvus vector store. + For more information about Milvus, please refer to + https://milvus.io/docs :param collection_name: The name of the collection to be created. "HaystackCollection" as default. @@ -55,14 +73,50 @@ def __init__( If set, will override collection existing properties. For example: {"collection.ttl.seconds": 60}. 
:param connection_args: The connection args used for this class comes in the form of a dict. + - For the case of [Milvus Lite](https://milvus.io/docs/milvus_lite.md), + the most convenient method, just set the uri as a local file. + Examples: + connection_args = { + "uri": "./milvus.db" + } + - For the case of Milvus server on [docker or kubernetes](https://milvus.io/docs/quickstart.md), + it is recommended to use when you are dealing with large scale of data. + Examples: + connection_args = { + "uri": "http://localhost:19530" + } + - For the case of [Zilliz Cloud](https://zilliz.com/cloud), the fully managed + cloud service for Milvus, adjust the uri and token, which correspond to the + [Public Endpoint and Api key](https://docs.zilliz.com/docs/on-zilliz-cloud-console#free-cluster-details) + in Zilliz Cloud. + Examples: + connection_args = { + "uri": "https://in03-ba4234asae.api.gcp-us-west1.zillizcloud.com", # Public Endpoint + "token": Secret.from_env_var("ZILLIZ_CLOUD_API_KEY"), # API key. + "secure": True + } + If you use `token` or `password`, we recommend using the `Secret` class to load + the token from environment variable for security. :param consistency_level: The consistency level to use for a collection. Defaults to "Session". + :param index_params: Which index params to use. :param search_params: Which search params to use. Defaults to default of index. :param drop_old: Whether to drop the current collection. Defaults to False. :param primary_field: Name of the primary key field. Defaults to "id". :param text_field: Name of the text field. Defaults to "text". :param vector_field: Name of the vector field. Defaults to "vector". + :param sparse_vector_field: Name of the sparse vector field. Defaults to None, + which means do not use sparse retrival, + else enable sparse retrieval with this specified field. + For more information about Milvus sparse retrieval, + please refer to https://milvus.io/docs/sparse_vector.md#Sparse-Vector + :param sparse_index_params: Which index params to use for sparse field. + Only useful when `sparse_vector_field` is set. + If not specified, will use a default value. + :param sparse_search_params: Which search params to use for sparse field. + Only useful when `sparse_vector_field` is set. + If not specified, will use a default value. :param partition_key_field: Name of the partition key field. Defaults to None. :param partition_names: List of partition names. Defaults to None. :param replica_number: Number of replicas. Defaults to 1. 
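For readers of the docstring above, a minimal sketch of how the three new sparse parameters fit together (the field name is illustrative; the index and search params shown mirror the defaults this change falls back to in `_create_index` and `_sparse_embedding_retrieval` when they are left unset):

```python
from milvus_haystack import MilvusDocumentStore

document_store = MilvusDocumentStore(
    connection_args={"uri": "./milvus.db"},
    sparse_vector_field="sparse_vector",  # enables sparse retrieval on this field
    sparse_index_params={"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP"},  # same as the built-in default
    sparse_search_params={"metric_type": "IP"},  # same as the built-in default
    drop_old=True,
)
```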
@@ -103,6 +157,9 @@ def __init__( self._primary_field = primary_field self._text_field = text_field self._vector_field = vector_field + self._sparse_vector_field = sparse_vector_field + self.sparse_index_params = sparse_index_params + self.sparse_search_params = sparse_search_params self._partition_key_field = partition_key_field self.fields: List[str] = [] self.partition_names = partition_names @@ -111,8 +168,8 @@ def __init__( # Create the connection to the server if connection_args is None: - connection_args = DEFAULT_MILVUS_CONNECTION - self.alias = self._create_connection_alias(connection_args) + self.connection_args = DEFAULT_MILVUS_CONNECTION + self.alias = self._create_connection_alias(self.connection_args) # type: ignore[arg-type] self.col: Optional[Collection] = None # Grab the existing collection if it exists @@ -134,6 +191,7 @@ def __init__( replica_number=replica_number, timeout=timeout, ) + self._dummy_value = 999.0 def count_documents(self) -> int: """ @@ -226,7 +284,7 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc # Build expr. if not filters: - expr = 'id like "%"' + expr = "" else: expr = parse_filters(filters) @@ -235,6 +293,7 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc res = self.col.query( expr=expr, output_fields=output_fields, + limit=MAX_LIMIT_SIZE, ) except MilvusException as err: logger.error("Failed to query documents with filters expr: %s", expr) @@ -254,7 +313,8 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D from pymilvus import Collection, MilvusException - if len(documents) > 0 and not isinstance(documents[0], Document): + documents_cp = deepcopy(documents) + if len(documents_cp) > 0 and not isinstance(documents_cp[0], Document): err_msg = "param 'documents' must contain a list of objects of type Document" raise ValueError(err_msg) @@ -266,30 +326,46 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D "and if they aren't Milvus may contain multiple entities with duplicate primary keys." ) + # Check embeddings embedding_dim = 128 - for doc in documents: + for doc in documents_cp: if doc.embedding is not None: embedding_dim = len(doc.embedding) break empty_embedding = False - for doc in documents: + empty_sparse_embedding = False + for doc in documents_cp: if doc.embedding is None: empty_embedding = True - dummy_vector = [-10.0] * embedding_dim + dummy_vector = [self._dummy_value] * embedding_dim doc.embedding = dummy_vector + if doc.sparse_embedding is None: + empty_sparse_embedding = True + dummy_sparse_vector = SparseEmbedding( + indices=[0], + values=[self._dummy_value], + ) + doc.sparse_embedding = dummy_sparse_vector if doc.content is None: doc.content = "" - if empty_embedding: + if empty_embedding and self._sparse_vector_field is None: logger.warning( "Milvus is a purely vector database, but document has no embedding. " "A dummy embedding will be used, but this can AFFECT THE SEARCH RESULTS!!! " "Please calculate the embedding in each document first, and then write them to Milvus Store." ) + if empty_sparse_embedding and self._sparse_vector_field is not None: + logger.warning( + "You specified `sparse_vector_field`, but document has no sparse embedding. " + "A dummy sparse embedding will be used, but this can AFFECT THE SEARCH RESULTS!!! " + "Please calculate the sparse embedding in each document first, and then write them to Milvus Store." 
+ ) - embeddings = [doc.embedding for doc in documents] - metas = [doc.meta for doc in documents] - texts = [doc.content for doc in documents] - ids = [doc.id for doc in documents] + embeddings = [doc.embedding for doc in documents_cp] + sparse_embeddings = [self._convert_sparse_to_dict(doc.sparse_embedding) for doc in documents_cp] + metas = [doc.meta for doc in documents_cp] + texts = [doc.content for doc in documents_cp] + ids = [doc.id for doc in documents_cp] if len(embeddings) == 0: logger.debug("Nothing to insert, skipping.") @@ -313,6 +389,8 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D self._vector_field: embeddings, self._primary_field: ids, } + if self._sparse_vector_field: + insert_dict[self._sparse_vector_field] = sparse_embeddings # Collect the meta into the insert dict. if metas is not None: @@ -364,11 +442,17 @@ def to_dict(self) -> Dict[str, Any]: :return: A dictionary representation of the document store. """ + new_connection_args = {} + for conn_arg_key, conn_arg_value in self.connection_args.items(): # type: ignore[union-attr] + if isinstance(conn_arg_value, Secret): + new_connection_args[conn_arg_key] = conn_arg_value.to_dict() + else: + new_connection_args[conn_arg_key] = conn_arg_value init_parameters = { "collection_name": self.collection_name, "collection_description": self.collection_description, "collection_properties": self.collection_properties, - "connection_args": self.connection_args, + "connection_args": new_connection_args, "consistency_level": self.consistency_level, "index_params": self.index_params, "search_params": self.search_params, @@ -376,6 +460,9 @@ def to_dict(self) -> Dict[str, Any]: "primary_field": self._primary_field, "text_field": self._text_field, "vector_field": self._vector_field, + "sparse_vector_field": self._sparse_vector_field, + "sparse_index_params": self.sparse_index_params, + "sparse_search_params": self.sparse_search_params, "partition_key_field": self._partition_key_field, "partition_names": self.partition_names, "replica_number": self.replica_number, @@ -391,18 +478,24 @@ def from_dict(cls, data: Dict[str, Any]) -> "MilvusDocumentStore": :param data: The dictionary to use to create the document store. :return: A new document store. 
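        Note: any `connection_args` values that were serialized from `Secret` objects (dicts whose `"type"` is `"env_var"`) are deserialized back into `Secret` instances before the store is created.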
""" + for conn_arg_key, conn_arg_value in data["init_parameters"]["connection_args"].items(): + if isinstance(conn_arg_value, dict) and "type" in conn_arg_value and conn_arg_value["type"] == "env_var": + deserialize_secrets_inplace(data["init_parameters"]["connection_args"], keys=[conn_arg_key]) return default_from_dict(cls, data) def _create_connection_alias(self, connection_args: dict) -> str: """Create the connection to the Milvus server.""" from pymilvus import MilvusException, connections + connection_args_cp = copy.deepcopy(connection_args) # Grab the connection arguments that are used for checking existing connection - host: str = connection_args.get("host", None) - port: Union[str, int] = connection_args.get("port", None) - address: str = connection_args.get("address", None) - uri: str = connection_args.get("uri", None) - user = connection_args.get("user", None) + host: str = connection_args_cp.get("host", None) + port: Union[str, int] = connection_args_cp.get("port", None) + address: str = connection_args_cp.get("address", None) + uri: str = connection_args_cp.get("uri", None) + user = connection_args_cp.get("user", None) + token: Union[str, Secret] = connection_args_cp.get("token", None) + password: Union[str, Secret] = connection_args_cp.get("password", None) # Order of use is host/port, uri, address if host is not None and port is not None: @@ -442,8 +535,14 @@ def _create_connection_alias(self, connection_args: dict) -> str: # Generate a new connection if one doesn't exist alias = uuid4().hex + token = self._resolve_value(token) + password = self._resolve_value(password) + if token is not None: + connection_args_cp["token"] = token + if password is not None: + connection_args_cp["password"] = password try: - connections.connect(alias=alias, **connection_args) + connections.connect(alias=alias, **connection_args_cp) logger.debug("Created new connection using: %s", alias) return alias except MilvusException as err: @@ -505,6 +604,8 @@ def _create_collection(self, embeddings: list, metas: Optional[List[Dict]] = Non fields.append(FieldSchema(self._primary_field, DataType.VARCHAR, is_primary=True, max_length=65_535)) # Create the vector field, supports binary or float vectors fields.append(FieldSchema(self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim)) + if self._sparse_vector_field: + fields.append(FieldSchema(self._sparse_vector_field, DataType.SPARSE_FLOAT_VECTOR)) # Create the schema for the collection schema = CollectionSchema( @@ -538,7 +639,7 @@ def _extract_fields(self) -> None: self.fields.append(x.name) def _create_index(self) -> None: - """Create a index on the collection""" + """Create an index on the collection""" from pymilvus import Collection, MilvusException if isinstance(self.col, Collection) and self._get_index() is None: @@ -571,6 +672,18 @@ def _create_index(self) -> None: index_params=self.index_params, using=self.alias, ) + if self._sparse_vector_field: + if self.sparse_index_params is None: + self.sparse_index_params = { + "index_type": "SPARSE_INVERTED_INDEX", + "metric_type": "IP", + } + self.col.create_index( + self._sparse_vector_field, + index_params=self.sparse_index_params, + using=self.alias, + ) + logger.debug( "Successfully created an index on collection: %s", self.collection_name, @@ -623,16 +736,27 @@ def _load( timeout=timeout, ) + def _resolve_value(self, secret: Union[str, Secret]): + if isinstance(secret, Secret): + return secret.resolve_value() + if secret: + logger.warning( + "Some secret values are not encrypted. 
Please use `Secret` class to encrypt them. " + "The best way to implement it is to use `Secret.from_env` to load from environment variables. " + "For example:\n" + "from haystack.utils import Secret\n" + "token = Secret.from_env('YOUR_TOKEN_ENV_VAR_NAME')" + ) + return secret + def _embedding_retrieval( self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: int = 10 ) -> List[Document]: + """Dense embedding retrieval""" if self.col is None: logger.debug("No existing collection to search.") return [] - # if param is None: - param = self.search_params - # Determine result metadata fields. output_fields = self.fields[:] @@ -646,7 +770,48 @@ def _embedding_retrieval( res = self.col.search( data=[query_embedding], anns_field=self._vector_field, - param=param, + param=self.search_params, + limit=top_k, + expr=expr, + output_fields=output_fields, + timeout=None, + ) + docs = self._parse_search_result(res, output_fields=output_fields) + return docs + + def _sparse_embedding_retrieval( + self, query_sparse_embedding: SparseEmbedding, filters: Optional[Dict[str, Any]] = None, top_k: int = 10 + ) -> List[Document]: + """Sparse embedding retrieval""" + if self.col is None: + logger.debug("No existing collection to search.") + return [] + if self._sparse_vector_field is None: + message = ( + "You need to specify `sparse_vector_field` in the document store " + "to use sparse embedding retrieval. Such as: " + "MilvusDocumentStore(..., sparse_vector_field='sparse_vector',...)" + ) + raise MilvusStoreError(message) + + if self.sparse_search_params is None: + self.sparse_search_params = {"metric_type": "IP"} + + # Determine result metadata fields. + output_fields = self.fields[:] + + # Build expr. + if not filters: + expr = None + else: + expr = parse_filters(filters) + + # Perform the search. + search_data = self._convert_sparse_to_dict(query_sparse_embedding) + res = self.col.search( + data=[search_data], + anns_field=self._sparse_vector_field, + param=self.sparse_index_params, limit=top_k, expr=expr, output_fields=output_fields, @@ -655,6 +820,55 @@ def _embedding_retrieval( docs = self._parse_search_result(res, output_fields=output_fields) return docs + def _hybrid_retrieval( + self, + query_embedding: List[float], + query_sparse_embedding: SparseEmbedding, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + reranker: Optional[BaseRanker] = None, + ) -> List[Document]: + """Hybrid retrieval using both dense and sparse embeddings""" + if self.col is None: + logger.debug("No existing collection to search.") + return [] + if self._sparse_vector_field is None: + message = ( + "You need to specify `sparse_vector_field` in the document store " + "to use hybrid retrieval. Such as: " + "MilvusDocumentStore(..., sparse_vector_field='sparse_vector',...)" + ) + raise MilvusStoreError(message) + + if reranker is None: + reranker = RRFRanker() + + if self.sparse_search_params is None: + self.sparse_search_params = {"metric_type": "IP"} + + # Determine result metadata fields. + output_fields = self.fields[:] + + # Build expr. + if not filters: + expr = None + else: + expr = parse_filters(filters) + + dense_req = AnnSearchRequest([query_embedding], self._vector_field, self.search_params, limit=top_k, expr=expr) + sparse_req = AnnSearchRequest( + [self._convert_sparse_to_dict(query_sparse_embedding)], + self._sparse_vector_field, + self.sparse_search_params, + limit=top_k, + expr=expr, + ) + + # Search topK docs based on dense and sparse vectors and rerank. 
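+        # The reranker defaults to RRFRanker(); a pymilvus WeightedRanker(dense_weight, sparse_weight)
+        # can be passed instead via MilvusHybridRetriever(reranker=...) to weight the two result sets explicitly.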
+ res = self.col.hybrid_search([dense_req, sparse_req], rerank=reranker, limit=top_k, output_fields=output_fields) + docs = self._parse_search_result(res, output_fields=output_fields) + return docs + def _parse_search_result(self, result, output_fields=None) -> List[Document]: if output_fields is None: output_fields = self.fields[:] @@ -666,9 +880,29 @@ def _parse_search_result(self, result, output_fields=None) -> List[Document]: return docs def _parse_document(self, data: dict) -> Document: + # we store dummy vectors during writing documents if they are not provided, + # so we don't return them if they are dummy vectors + embedding = data.pop(self._vector_field) + if all(x == self._dummy_value for x in embedding): + embedding = None + + sparse_embedding = None + sparse_dict = data.pop(self._sparse_vector_field, None) + if sparse_dict: + sparse_embedding = self._convert_dict_to_sparse(sparse_dict) + if sparse_embedding.values == [self._dummy_value] and sparse_embedding.indices == [0]: + sparse_embedding = None + return Document( id=data.pop(self._primary_field), content=data.pop(self._text_field), - embedding=data.pop(self._vector_field), + embedding=embedding, + sparse_embedding=sparse_embedding, meta=data, ) + + def _convert_sparse_to_dict(self, sparse_embedding: SparseEmbedding) -> Dict: + return dict(zip(sparse_embedding.indices, sparse_embedding.values)) + + def _convert_dict_to_sparse(self, sparse_dict: Dict) -> SparseEmbedding: + return SparseEmbedding(indices=list(sparse_dict.keys()), values=list(sparse_dict.values())) diff --git a/src/milvus_haystack/milvus_embedding_retriever.py b/src/milvus_haystack/milvus_embedding_retriever.py index d8611ac..b3465b5 100644 --- a/src/milvus_haystack/milvus_embedding_retriever.py +++ b/src/milvus_haystack/milvus_embedding_retriever.py @@ -1,6 +1,10 @@ +import importlib from typing import Any, Dict, List, Optional from haystack import DeserializationError, Document, component, default_from_dict, default_to_dict +from haystack.dataclasses.sparse_embedding import SparseEmbedding +from pymilvus import RRFRanker +from pymilvus.client.abstract import BaseRanker from milvus_haystack import MilvusDocumentStore @@ -8,7 +12,7 @@ @component class MilvusEmbeddingRetriever: """ - A component for retrieving documents from an Milvus Document Store. + A component for retrieving documents from a Milvus Document Store. """ def __init__(self, document_store: MilvusDocumentStore, filters: Optional[Dict[str, Any]] = None, top_k: int = 10): @@ -66,3 +70,158 @@ def run(self, query_embedding: List[float]) -> Dict[str, List[Document]]: top_k=self.top_k, ) return {"documents": docs} + + +@component +class MilvusSparseEmbeddingRetriever: + """ + A component for retrieving documents using sparse embeddings from a Milvus Document Store. + """ + + def __init__(self, document_store: MilvusDocumentStore, filters: Optional[Dict[str, Any]] = None, top_k: int = 10): + """ + Initializes a new instance of the MilvusSparseEmbeddingRetriever. + + :param document_store: A Milvus Document Store object used to retrieve documents. + :param filters: A dictionary with filters to narrow down the search space (default is None). + :param top_k: The maximum number of documents to retrieve (default is 10). + """ + self.filters = filters + self.top_k = top_k + self.document_store = document_store + + def to_dict(self) -> Dict[str, Any]: + """ + Returns a dictionary representation of the retriever component. + + :returns: + A dictionary representation of the retriever component. 
+ """ + return default_to_dict( + self, document_store=self.document_store.to_dict(), filters=self.filters, top_k=self.top_k + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "MilvusEmbeddingRetriever": + """ + Creates a new retriever from a dictionary. + + :param data: The dictionary to use to create the retriever. + :return: A new retriever. + """ + init_params = data.get("init_parameters", {}) + if "document_store" not in init_params: + err_msg = "Missing 'document_store' in serialization data" + raise DeserializationError(err_msg) + + docstore = MilvusDocumentStore.from_dict(init_params["document_store"]) + data["init_parameters"]["document_store"] = docstore + + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run(self, query_sparse_embedding: SparseEmbedding) -> Dict[str, List[Document]]: + """ + Retrieve documents from the `MilvusDocumentStore`, based on their sparse embeddings. + + :param query_sparse_embedding: Sparse Embedding of the query. + :return: List of Document similar to `query_embedding`. + """ + docs = self.document_store._sparse_embedding_retrieval( + query_sparse_embedding=query_sparse_embedding, + filters=self.filters, + top_k=self.top_k, + ) + return {"documents": docs} + + +@component +class MilvusHybridRetriever: + """ + A component for retrieving documents using hybrid search from a Milvus Document Store. + """ + + def __init__( + self, + document_store: MilvusDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + reranker: Optional[BaseRanker] = None, + ): + """ + Initializes a new instance of the MilvusHybridRetriever. + + :param document_store: A Milvus Document Store object used to retrieve documents. + :param filters: A dictionary with filters to narrow down the search space (default is None). + :param top_k: The maximum number of documents to retrieve (default is 10). + :param reranker: A PyMilvus ranker used to re-rank the results (default is RRFRanker). + """ + self.filters = filters + self.top_k = top_k + self.document_store = document_store + if reranker is None: + reranker = RRFRanker() + self.reranker = reranker + + def to_dict(self) -> Dict[str, Any]: + """ + Returns a dictionary representation of the retriever component. + + :returns: + A dictionary representation of the retriever component. + """ + return default_to_dict( + self, + document_store=self.document_store.to_dict(), + filters=self.filters, + top_k=self.top_k, + reranker=default_to_dict(self.reranker), + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "MilvusEmbeddingRetriever": + """ + Creates a new retriever from a dictionary. + + :param data: The dictionary to use to create the retriever. + :return: A new retriever. 
+ """ + init_params = data.get("init_parameters", {}) + if "document_store" not in init_params: + err_msg = "Missing 'document_store' in serialization data" + raise DeserializationError(err_msg) + + docstore = MilvusDocumentStore.from_dict(init_params["document_store"]) + data["init_parameters"]["document_store"] = docstore + if "reranker" in init_params: + reranker_type_str = init_params["reranker"]["type"] + reranker_module_name, reranker_class_name = reranker_type_str.rsplit(".", 1) + reranker_module = importlib.import_module(reranker_module_name) + reranker_cls = getattr(reranker_module, reranker_class_name) + reranker_data = { + "type": reranker_type_str, + "init_parameters": data["init_parameters"]["reranker"]["init_parameters"], + } + data["init_parameters"]["reranker"] = default_from_dict( + reranker_cls, + reranker_data, + ) + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run(self, query_embedding: List[float], query_sparse_embedding: SparseEmbedding): + """ + Retrieve documents from the `MilvusDocumentStore`, based on their dense and sparse embeddings. + + :param query_embedding: Dense Embedding of the query. + :param query_sparse_embedding: Sparse Embedding of the query. + :return: List of Document similar to `query_embedding`. + """ + docs = self.document_store._hybrid_retrieval( + query_embedding=query_embedding, + query_sparse_embedding=query_sparse_embedding, + filters=self.filters, + top_k=self.top_k, + reranker=self.reranker, + ) + return {"documents": docs} diff --git a/tests/test_document_store.py b/tests/test_document_store.py index af48adc..282edad 100644 --- a/tests/test_document_store.py +++ b/tests/test_document_store.py @@ -1,5 +1,4 @@ import logging -import time import pytest from haystack import Document @@ -21,6 +20,7 @@ class TestDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsT def document_store(self) -> MilvusDocumentStore: return MilvusDocumentStore( connection_args=DEFAULT_CONNECTION_ARGS, + consistency_level="Strong", drop_old=True, ) @@ -40,7 +40,6 @@ def test_delete_documents(self, document_store: DocumentStore): assert document_store.count_documents() == 1 document_store.delete_documents([doc.id]) - time.sleep(1) assert document_store.count_documents() == 0 @pytest.mark.skip(reason="Milvus does not currently check if entity primary keys are duplicates") @@ -61,13 +60,16 @@ def test_to_and_from_dict(self, document_store: MilvusDocumentStore): "collection_description": "", "collection_properties": None, "connection_args": DEFAULT_CONNECTION_ARGS, - "consistency_level": "Session", + "consistency_level": "Strong", "index_params": None, "search_params": None, "drop_old": True, "primary_field": "id", "text_field": "text", "vector_field": "vector", + "sparse_vector_field": None, + "sparse_index_params": None, + "sparse_search_params": None, "partition_key_field": None, "partition_names": None, "replica_number": 1, diff --git a/tests/test_embedding_retriever.py b/tests/test_embedding_retriever.py index c4fa0d7..32071b2 100644 --- a/tests/test_embedding_retriever.py +++ b/tests/test_embedding_retriever.py @@ -1,10 +1,18 @@ import logging +from typing import List import pytest -from haystack import Document +from haystack import Document, default_to_dict +from haystack.dataclasses.sparse_embedding import SparseEmbedding +from pymilvus import RRFRanker from src.milvus_haystack import MilvusDocumentStore -from src.milvus_haystack.milvus_embedding_retriever import MilvusEmbeddingRetriever +from 
src.milvus_haystack.document_store import MilvusStoreError +from src.milvus_haystack.milvus_embedding_retriever import ( + MilvusEmbeddingRetriever, + MilvusHybridRetriever, + MilvusSparseEmbeddingRetriever, +) logger = logging.getLogger(__name__) @@ -19,6 +27,7 @@ class TestMilvusEmbeddingTests: def document_store(self) -> MilvusDocumentStore: return MilvusDocumentStore( connection_args=DEFAULT_CONNECTION_ARGS, + consistency_level="Strong", drop_old=True, ) @@ -52,13 +61,16 @@ def test_to_dict(self, document_store: MilvusDocumentStore): "collection_description": "", "collection_properties": None, "connection_args": DEFAULT_CONNECTION_ARGS, - "consistency_level": "Session", + "consistency_level": "Strong", "index_params": None, "search_params": None, "drop_old": True, "primary_field": "id", "text_field": "text", "vector_field": "vector", + "sparse_vector_field": None, + "sparse_index_params": None, + "sparse_search_params": None, "partition_key_field": None, "partition_names": None, "replica_number": 1, @@ -70,6 +82,8 @@ def test_to_dict(self, document_store: MilvusDocumentStore): assert result["type"] == "src.milvus_haystack.milvus_embedding_retriever.MilvusEmbeddingRetriever" assert result["init_parameters"]["document_store"] == expected_dict + assert result["init_parameters"]["filters"] is None + assert result["init_parameters"]["top_k"] == 10 def test_from_dict(self, document_store: MilvusDocumentStore): retriever_dict = { @@ -82,13 +96,16 @@ def test_from_dict(self, document_store: MilvusDocumentStore): "collection_description": "", "collection_properties": None, "connection_args": DEFAULT_CONNECTION_ARGS, - "consistency_level": "Session", + "consistency_level": "Strong", "index_params": None, "search_params": None, "drop_old": True, "primary_field": "id", "text_field": "text", "vector_field": "vector", + "sparse_vector_field": None, + "sparse_index_params": None, + "sparse_search_params": None, "partition_key_field": None, "partition_names": None, "replica_number": 1, @@ -115,3 +132,286 @@ def test_from_dict(self, document_store: MilvusDocumentStore): ) else: assert getattr(reconstructed_retriever, field) == getattr(retriever, field) + + +class TestMilvusSparseEmbeddingTests: + @pytest.fixture + def document_store(self) -> MilvusDocumentStore: + return MilvusDocumentStore( + connection_args=DEFAULT_CONNECTION_ARGS, + consistency_level="Strong", + drop_old=True, + sparse_vector_field="sparse_vector", + ) + + @pytest.fixture + def documents(self) -> List[Document]: + documents = [] + doc = Document( + content="A Foo Document", + meta={ + "name": "name_0", + "page": "100", + "chapter": "intro", + "number": 2, + "date": "1969-07-21T20:17:40", + }, + embedding=[-10.0] * 128, + sparse_embedding=SparseEmbedding(indices=[0, 1, 2], values=[1.0, 2.0, 3.0]), + ) + documents.append(doc) + return documents + + def test_run(self, document_store: MilvusDocumentStore, documents: List[Document]): + document_store.write_documents(documents) + retriever = MilvusSparseEmbeddingRetriever( + document_store, + ) + sparse_query_embedding = SparseEmbedding(indices=[0, 1, 2], values=[1.0, 2.0, 3.0]) + res = retriever.run(sparse_query_embedding) + assert res["documents"] == documents + + def test_fail_without_sparse_field(self, documents: List[Document]): + document_store = MilvusDocumentStore( + connection_args=DEFAULT_CONNECTION_ARGS, + consistency_level="Strong", + drop_old=True, + vector_field="vector", + # Missing sparse_vector_field + ) + document_store.write_documents(documents) + retriever = 
MilvusSparseEmbeddingRetriever( + document_store, + ) + sparse_query_embedding = SparseEmbedding(indices=[0, 1, 2], values=[1.0, 2.0, 3.0]) + with pytest.raises(MilvusStoreError): + retriever.run( + query_sparse_embedding=sparse_query_embedding, + ) + + def test_to_dict(self, document_store: MilvusDocumentStore): + expected_dict = { + "type": "src.milvus_haystack.document_store.MilvusDocumentStore", + "init_parameters": { + "collection_name": "HaystackCollection", + "collection_description": "", + "collection_properties": None, + "connection_args": DEFAULT_CONNECTION_ARGS, + "consistency_level": "Strong", + "index_params": None, + "search_params": None, + "drop_old": True, + "primary_field": "id", + "text_field": "text", + "vector_field": "vector", + "sparse_vector_field": "sparse_vector", + "sparse_index_params": None, + "sparse_search_params": None, + "partition_key_field": None, + "partition_names": None, + "replica_number": 1, + "timeout": None, + }, + } + retriever = MilvusSparseEmbeddingRetriever(document_store) + result = retriever.to_dict() + + assert result["type"] == "src.milvus_haystack.milvus_embedding_retriever.MilvusSparseEmbeddingRetriever" + assert result["init_parameters"]["document_store"] == expected_dict + assert result["init_parameters"]["filters"] is None + assert result["init_parameters"]["top_k"] == 10 + + def test_from_dict(self, document_store: MilvusDocumentStore): + retriever_dict = { + "type": "src.milvus_haystack.milvus_embedding_retriever.MilvusSparseEmbeddingRetriever", + "init_parameters": { + "document_store": { + "type": "milvus_haystack.document_store.MilvusDocumentStore", + "init_parameters": { + "collection_name": "HaystackCollection", + "collection_description": "", + "collection_properties": None, + "connection_args": DEFAULT_CONNECTION_ARGS, + "consistency_level": "Strong", + "index_params": None, + "search_params": None, + "drop_old": True, + "primary_field": "id", + "text_field": "text", + "vector_field": "vector", + "sparse_vector_field": "sparse_vector", + "sparse_index_params": None, + "sparse_search_params": None, + "partition_key_field": None, + "partition_names": None, + "replica_number": 1, + "timeout": None, + }, + }, + "filters": None, + "top_k": 10, + }, + } + + retriever = MilvusSparseEmbeddingRetriever(document_store) + + reconstructed_retriever = MilvusSparseEmbeddingRetriever.from_dict(retriever_dict) + for field in vars(reconstructed_retriever): + if field.startswith("__"): + continue + elif field == "document_store": + for doc_store_field in vars(document_store): + if doc_store_field.startswith("__") or doc_store_field == "alias": + continue + assert getattr(reconstructed_retriever.document_store, doc_store_field) == getattr( + document_store, doc_store_field + ) + else: + assert getattr(reconstructed_retriever, field) == getattr(retriever, field) + + +class TestMilvusHybridTests: + @pytest.fixture + def document_store(self) -> MilvusDocumentStore: + return MilvusDocumentStore( + connection_args=DEFAULT_CONNECTION_ARGS, + consistency_level="Strong", + drop_old=True, + vector_field="vector", + sparse_vector_field="sparse_vector", + ) + + @pytest.fixture + def documents(self) -> List[Document]: + documents = [] + doc = Document( + content="A Foo Document", + meta={ + "name": "name_0", + "page": "100", + "chapter": "intro", + "number": 2, + "date": "1969-07-21T20:17:40", + }, + embedding=[-10.0] * 128, + sparse_embedding=SparseEmbedding(indices=[0, 1, 2], values=[1.0, 2.0, 3.0]), + ) + documents.append(doc) + return documents + + def 
test_run(self, document_store: MilvusDocumentStore, documents: List[Document]): + document_store.write_documents(documents) + retriever = MilvusHybridRetriever( + document_store, + ) + query_embedding = [-10.0] * 128 + sparse_query_embedding = SparseEmbedding(indices=[0, 1, 2], values=[1.0, 2.0, 3.0]) + res = retriever.run( + query_embedding=query_embedding, + query_sparse_embedding=sparse_query_embedding, + ) + assert res["documents"] == documents + + def test_fail_without_sparse_field(self, documents: List[Document]): + document_store = MilvusDocumentStore( + connection_args=DEFAULT_CONNECTION_ARGS, + consistency_level="Strong", + drop_old=True, + vector_field="vector", + # Missing sparse_vector_field + ) + document_store.write_documents(documents) + retriever = MilvusHybridRetriever( + document_store, + ) + query_embedding = [-10.0] * 128 + sparse_query_embedding = SparseEmbedding(indices=[0, 1, 2], values=[1.0, 2.0, 3.0]) + with pytest.raises(MilvusStoreError): + retriever.run( + query_embedding=query_embedding, + query_sparse_embedding=sparse_query_embedding, + ) + + def test_to_dict(self, document_store: MilvusDocumentStore): + expected_dict = { + "type": "src.milvus_haystack.document_store.MilvusDocumentStore", + "init_parameters": { + "collection_name": "HaystackCollection", + "collection_description": "", + "collection_properties": None, + "connection_args": DEFAULT_CONNECTION_ARGS, + "consistency_level": "Strong", + "index_params": None, + "search_params": None, + "drop_old": True, + "primary_field": "id", + "text_field": "text", + "vector_field": "vector", + "sparse_vector_field": "sparse_vector", + "sparse_index_params": None, + "sparse_search_params": None, + "partition_key_field": None, + "partition_names": None, + "replica_number": 1, + "timeout": None, + }, + } + retriever = MilvusHybridRetriever(document_store) + result = retriever.to_dict() + + assert result["type"] == "src.milvus_haystack.milvus_embedding_retriever.MilvusHybridRetriever" + assert result["init_parameters"]["document_store"] == expected_dict + assert result["init_parameters"]["filters"] is None + assert result["init_parameters"]["top_k"] == 10 + assert result["init_parameters"]["reranker"] == default_to_dict(RRFRanker()) + + def test_from_dict(self, document_store: MilvusDocumentStore): + retriever_dict = { + "type": "src.milvus_haystack.milvus_embedding_retriever.MilvusHybridRetriever", + "init_parameters": { + "document_store": { + "type": "milvus_haystack.document_store.MilvusDocumentStore", + "init_parameters": { + "collection_name": "HaystackCollection", + "collection_description": "", + "collection_properties": None, + "connection_args": DEFAULT_CONNECTION_ARGS, + "consistency_level": "Strong", + "index_params": None, + "search_params": None, + "drop_old": True, + "primary_field": "id", + "text_field": "text", + "vector_field": "vector", + "sparse_vector_field": "sparse_vector", + "sparse_index_params": None, + "sparse_search_params": None, + "partition_key_field": None, + "partition_names": None, + "replica_number": 1, + "timeout": None, + }, + }, + "filters": None, + "top_k": 10, + "reranker": {"type": "pymilvus.client.abstract.RRFRanker", "init_parameters": {}}, + }, + } + + retriever = MilvusHybridRetriever(document_store) + + reconstructed_retriever = MilvusHybridRetriever.from_dict(retriever_dict) + for field in vars(reconstructed_retriever): + if field.startswith("__"): + continue + elif field == "document_store": + for doc_store_field in vars(document_store): + if 
doc_store_field.startswith("__") or doc_store_field == "alias": + continue + assert getattr(reconstructed_retriever.document_store, doc_store_field) == getattr( + document_store, doc_store_field + ) + elif field == "reranker": + assert default_to_dict(getattr(reconstructed_retriever, field)) == default_to_dict(RRFRanker()) + else: + assert getattr(reconstructed_retriever, field) == getattr(retriever, field) diff --git a/tests/test_filters.py b/tests/test_filters.py index 3598152..c33ae32 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -10,6 +10,12 @@ logger = logging.getLogger(__name__) +DEFAULT_CONNECTION_ARGS = { + "uri": "http://localhost:19530", + # "uri": "./milvus_test.db", # This uri works for Milvus Lite + # Note: milvus lite may fail in some tests due to currently not supporting some expressions +} + class TestMilvusFilters(FilterDocumentsTest): @pytest.fixture @@ -60,13 +66,8 @@ def filterable_docs(self) -> List[Document]: @pytest.fixture def document_store(self) -> MilvusDocumentStore: return MilvusDocumentStore( - connection_args={ - "host": "localhost", - "port": "19530", - "user": "", - "password": "", - "secure": False, - }, + connection_args=DEFAULT_CONNECTION_ARGS, + consistency_level="Strong", drop_old=True, ) @@ -88,7 +89,9 @@ def assert_documents_are_equal(self, received: List[Document], expected: List[Do assert r.content_type == e.content_type assert r.blob == e.blob assert r.score == r.score - assert np.allclose(np.array(r.embedding), np.array(e.embedding), atol=1e-4) + if r.embedding is not None or e.embedding is not None: + assert np.allclose(np.array(r.embedding), np.array(e.embedding), atol=1e-4) + assert r.sparse_embedding == e.sparse_embedding @pytest.mark.skip(reason="Milvus doesn't support comparison with dataframe") def test_comparison_equal_with_dataframe(self, document_store, filterable_docs): ...