Skip to content

Commit

Permalink
Merge pull request #13 from zc277584121/main
Browse files Browse the repository at this point in the history
fix lint
  • Loading branch information
jaelgu authored Mar 13, 2024
2 parents 496c412 + b6e3958 commit a828d4d
Show file tree
Hide file tree
Showing 9 changed files with 165 additions and 165 deletions.
10 changes: 10 additions & 0 deletions OWNERS
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
filters:
".*":
reviewers:
- jaelgu
- zc277584121
- codingjaguar
approvers:
- jaelgu
- zc277584121
- codingjaguar
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ dependencies = [
"typing_extensions",
"pymilvus",
"milvus",
"farm-haystack"
]

[project.urls]
Expand Down Expand Up @@ -68,7 +67,7 @@ dependencies = [
"ruff>=0.0.243",
]
[tool.hatch.envs.lint.scripts]
typing = "mypy --install-types --non-interactive {args:src/milvus_haystack tests}"
typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
style = [
"ruff {args:.}",
"black --check --diff {args:.}",
Expand Down Expand Up @@ -175,6 +174,10 @@ markers = [
[[tool.mypy.overrides]]
module = [
"haystack.*",
"milvus_haystack.*",
"pymilvus.*",
"numpy",
"milvus",
"pytest.*"
]
ignore_missing_imports = true
9 changes: 3 additions & 6 deletions src/milvus_haystack/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from .document_store import MilvusDocumentStore
from .milvus_embedding_retriever import MilvusEmbeddingRetriever
from .document_store import MilvusDocumentStore # noqa: TID252
from .milvus_embedding_retriever import MilvusEmbeddingRetriever # noqa: TID252

__all__ = [
"MilvusDocumentStore",
"MilvusEmbeddingRetriever"
]
__all__ = ["MilvusDocumentStore", "MilvusEmbeddingRetriever"]
127 changes: 52 additions & 75 deletions src/milvus_haystack/document_store.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import logging
from typing import List, Dict, Optional, Union, Any
from typing import Any, Dict, List, Optional, Union
from uuid import uuid4

from haystack import Document, default_from_dict, default_to_dict
from haystack.document_stores.types import DuplicatePolicy
from haystack.errors import FilterError
from pymilvus import MilvusException

from milvus_haystack.filters import parse_filters

logger = logging.getLogger(__name__)
Expand All @@ -24,23 +26,23 @@ class MilvusDocumentStore:
"""

def __init__(
self,
collection_name: str = "HaystackCollection",
collection_description: str = "",
collection_properties: Optional[Dict[str, Any]] = None,
connection_args: Optional[Dict[str, Any]] = None,
consistency_level: str = "Session",
index_params: Optional[dict] = None,
search_params: Optional[dict] = None,
drop_old: Optional[bool] = False,
*,
primary_field: str = "id",
text_field: str = "text",
vector_field: str = "vector",
partition_key_field: Optional[str] = None,
partition_names: Optional[list] = None,
replica_number: int = 1,
timeout: Optional[float] = None,
self,
collection_name: str = "HaystackCollection",
collection_description: str = "",
collection_properties: Optional[Dict[str, Any]] = None,
connection_args: Optional[Dict[str, Any]] = None,
consistency_level: str = "Session",
index_params: Optional[dict] = None,
search_params: Optional[dict] = None,
drop_old: Optional[bool] = False, # noqa: FBT002
*,
primary_field: str = "id",
text_field: str = "text",
vector_field: str = "vector",
partition_key_field: Optional[str] = None,
partition_names: Optional[list] = None,
replica_number: int = 1,
timeout: Optional[float] = None,
):
"""
Initialize the Milvus vector store.
Expand Down Expand Up @@ -232,9 +234,7 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc
output_fields=output_fields,
)
except MilvusException as err:
logger.error(
"Failed to query documents with filters expr: %s", expr
)
logger.error("Failed to query documents with filters expr: %s", expr)
raise FilterError(err) from err
docs = [self._parse_document(d) for d in res]
return docs
Expand Down Expand Up @@ -293,7 +293,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
return 0

# If the collection hasn't been initialized yet, perform all steps to do so
kwargs = {}
kwargs: Dict[str, Any] = {}
if not isinstance(self.col, Collection):
kwargs = {"embeddings": embeddings, "metas": metas}
if self.partition_names:
Expand Down Expand Up @@ -322,10 +322,9 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
vectors: list = insert_dict[self._vector_field]
total_count = len(vectors)

ids: list[str] = []

batch_size = 1000
assert isinstance(self.col, Collection)
if not isinstance(self.col, Collection):
raise MilvusException(message="Collection is not initialized")
for i in range(0, total_count, batch_size):
# Grab end index
end = min(i + batch_size, total_count)
Expand All @@ -337,9 +336,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
res = self.col.insert(insert_list, timeout=None, **kwargs)
ids.extend(res.primary_keys)
except MilvusException as err:
logger.error(
"Failed to insert batch starting at entity: %s/%s", i, total_count
)
logger.error("Failed to insert batch starting at entity: %s/%s", i, total_count)
raise err
self.col.flush()
return len(ids)
Expand Down Expand Up @@ -432,11 +429,11 @@ def _create_connection_alias(self, connection_args: dict) -> str:
for con in connections.list_connections():
addr = connections.get_connection_addr(con[0])
if (
con[1]
and ("address" in addr)
and (addr["address"] == given_address)
and ("user" in addr)
and (addr["user"] == tmp_user)
con[1]
and ("address" in addr)
and (addr["address"] == given_address)
and ("user" in addr)
and (addr["user"] == tmp_user)
):
logger.debug("Using previous connection: %s", con[0])
return con[0]
Expand All @@ -452,12 +449,12 @@ def _create_connection_alias(self, connection_args: dict) -> str:
raise err

def _init(
self,
embeddings: Optional[List] = None,
metas: Optional[List[Dict]] = None,
partition_names: Optional[List] = None,
replica_number: int = 1,
timeout: Optional[float] = None,
self,
embeddings: Optional[List] = None,
metas: Optional[List[Dict]] = None,
partition_names: Optional[List] = None,
replica_number: int = 1,
timeout: Optional[float] = None,
) -> None:
if embeddings is not None:
self._create_collection(embeddings, metas)
Expand All @@ -470,9 +467,7 @@ def _init(
timeout=timeout,
)

def _create_collection(
self, embeddings: list, metas: Optional[List[Dict]] = None
) -> None:
def _create_collection(self, embeddings: list, metas: Optional[List[Dict]] = None) -> None:
from pymilvus import (
Collection,
CollectionSchema,
Expand All @@ -498,26 +493,16 @@ def _create_collection(
raise ValueError(err_msg)
# Datatype is a string/varchar equivalent
elif dtype == DataType.VARCHAR:
fields.append(
FieldSchema(key, DataType.VARCHAR, max_length=65_535)
)
fields.append(FieldSchema(key, DataType.VARCHAR, max_length=65_535))
else:
fields.append(FieldSchema(key, dtype))

# Create the text field
fields.append(
FieldSchema(self._text_field, DataType.VARCHAR, max_length=65_535)
)
fields.append(FieldSchema(self._text_field, DataType.VARCHAR, max_length=65_535))
# Create the primary key field
fields.append(
FieldSchema(
self._primary_field, DataType.VARCHAR, is_primary=True, max_length=65_535
)
)
fields.append(FieldSchema(self._primary_field, DataType.VARCHAR, is_primary=True, max_length=65_535))
# Create the vector field, supports binary or float vectors
fields.append(
FieldSchema(self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim)
)
fields.append(FieldSchema(self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim))

# Create the schema for the collection
schema = CollectionSchema(
Expand All @@ -538,9 +523,7 @@ def _create_collection(
if self.collection_properties is not None:
self.col.set_properties(self.collection_properties)
except MilvusException as err:
logger.error(
"Failed to create collection: %s error: %s", self.collection_name, err
)
logger.error("Failed to create collection: %s error: %s", self.collection_name, err)
raise err

def _extract_fields(self) -> None:
Expand Down Expand Up @@ -592,9 +575,7 @@ def _create_index(self) -> None:
)

except MilvusException as err:
logger.error(
"Failed to create an index on collection: %s", self.collection_name
)
logger.error("Failed to create an index on collection: %s", self.collection_name)
raise err

def _create_search_params(self) -> None:
Expand All @@ -620,20 +601,19 @@ def _get_index(self) -> Optional[Dict[str, Any]]:
return None

def _load(
self,
partition_names: Optional[list] = None,
replica_number: int = 1,
timeout: Optional[float] = None,
self,
partition_names: Optional[list] = None,
replica_number: int = 1,
timeout: Optional[float] = None,
) -> None:
"""Load the collection if available."""
from pymilvus import Collection, utility
from pymilvus.client.types import LoadState

if (
isinstance(self.col, Collection)
and self._get_index() is not None
and utility.load_state(self.collection_name, using=self.alias)
== LoadState.NotLoad
isinstance(self.col, Collection)
and self._get_index() is not None
and utility.load_state(self.collection_name, using=self.alias) == LoadState.NotLoad
):
self.col.load(
partition_names=partition_names,
Expand All @@ -642,11 +622,8 @@ def _load(
)

def _embedding_retrieval(
self,
query_embedding: List[float],
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10
):
self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: int = 10
) -> List[Document]:
if self.col is None:
logger.debug("No existing collection to search.")
return []
Expand Down
21 changes: 12 additions & 9 deletions src/milvus_haystack/filters.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Dict, Union, Any
from typing import Any, Dict, Union

from haystack.errors import FilterError

LOGIC_OPERATORS = [
Expand Down Expand Up @@ -55,10 +56,12 @@ def _parse_comparison(filters: Dict[str, Any]) -> str:


def _assert_comparison_filter(filters: Dict[str, Any]):
assert "operator" in filters, "operator must be specified in filters"
assert "field" in filters, "field must be specified in filters"
assert "value" in filters, "value must be specified in filters"
assert filters["operator"] in COMPARISON_OPERATORS, FilterError("operator must be one of: %s" % LOGIC_OPERATORS)
assert "operator" in filters, "operator must be specified in filters" # noqa: S101
assert "field" in filters, "field must be specified in filters" # noqa: S101
assert "value" in filters, "value must be specified in filters" # noqa: S101
assert filters["operator"] in COMPARISON_OPERATORS, FilterError( # noqa: S101
"operator must be one of: %s" % LOGIC_OPERATORS
)


def _parse_logic(filters: Dict[str, Any]) -> str:
Expand All @@ -80,7 +83,7 @@ def _parse_logic(filters: Dict[str, Any]) -> str:


def _assert_logic_filter(filters: Dict[str, Any]):
assert "operator" in filters, "operator must be specified in filters"
assert "conditions" in filters, "conditions must be specified in filters"
assert filters["operator"] in LOGIC_OPERATORS, "operator must be one of: %s" % LOGIC_OPERATORS
assert isinstance(filters["conditions"], list), "conditions must be a list"
assert "operator" in filters, "operator must be specified in filters" # noqa: S101
assert "conditions" in filters, "conditions must be specified in filters" # noqa: S101
assert filters["operator"] in LOGIC_OPERATORS, "operator must be one of: %s" % LOGIC_OPERATORS # noqa: S101
assert isinstance(filters["conditions"], list), "conditions must be a list" # noqa: S101
8 changes: 5 additions & 3 deletions src/milvus_haystack/milvus_embedding_retriever.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Any, Dict, Optional, List
from haystack import component, Document
from typing import Any, Dict, List, Optional

from haystack import Document, component

from milvus_haystack import MilvusDocumentStore


Expand All @@ -22,7 +24,7 @@ def __init__(self, document_store: MilvusDocumentStore, filters: Optional[Dict[s
self.document_store = document_store

@component.output_types(documents=List[Document])
def run(self, query_embedding: List[float]):
def run(self, query_embedding: List[float]) -> Dict[str, List[Document]]:
"""
Retrieve documents from the `MilvusDocumentStore`, based on their dense embeddings.
Expand Down
Loading

0 comments on commit a828d4d

Please sign in to comment.