Skip to content

Commit

Permalink
test: add cases for async milvus client
Browse files Browse the repository at this point in the history
Signed-off-by: ThreadDao <[email protected]>
  • Loading branch information
ThreadDao committed Dec 24, 2024
1 parent 636e107 commit 9fc90b6
Show file tree
Hide file tree
Showing 9 changed files with 797 additions and 19 deletions.
174 changes: 174 additions & 0 deletions tests/python_client/base/async_milvus_client_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
import asyncio
import sys
from typing import Optional, List, Union, Dict

from pymilvus import (
AsyncMilvusClient,
AnnSearchRequest,
RRFRanker,
)
from pymilvus.orm.types import CONSISTENCY_STRONG
from pymilvus.orm.collection import CollectionSchema

from check.func_check import ResponseChecker
from utils.api_request import api_request, logger_interceptor


class AsyncMilvusClientWrapper:
async_milvus_client = None

def __init__(self, active_trace=False):
self.active_trace = active_trace

def init_async_client(self, uri: str = "http://localhost:19530",
user: str = "",
password: str = "",
db_name: str = "",
token: str = "",
timeout: Optional[float] = None,
active_trace=False,
check_task=None, check_items=None,
**kwargs):
self.active_trace = active_trace

""" In order to distinguish the same name of collection """
func_name = sys._getframe().f_code.co_name
res, is_succ = api_request([AsyncMilvusClient, uri, user, password, db_name, token,
timeout], **kwargs)
self.async_milvus_client = res if is_succ else None
check_result = ResponseChecker(res, func_name, check_task, check_items, is_succ, **kwargs).run()
return res, check_result

@logger_interceptor()
async def create_collection(self,
collection_name: str,
dimension: Optional[int] = None,
primary_field_name: str = "id", # default is "id"
id_type: str = "int", # or "string",
vector_field_name: str = "vector", # default is "vector"
metric_type: str = "COSINE",
auto_id: bool = False,
timeout: Optional[float] = None,
schema: Optional[CollectionSchema] = None,
index_params=None,
**kwargs):
kwargs["consistency_level"] = kwargs.get("consistency_level", CONSISTENCY_STRONG)

return await self.async_milvus_client.create_collection(collection_name, dimension,
primary_field_name,
id_type, vector_field_name, metric_type,
auto_id,
timeout, schema, index_params, **kwargs)

@logger_interceptor()
async def drop_collection(self, collection_name: str, timeout: Optional[float] = None, **kwargs):
return await self.async_milvus_client.drop_collection(collection_name, timeout, **kwargs)

@logger_interceptor()
async def load_collection(self, collection_name: str, timeout: Optional[float] = None, **kwargs):
return await self.async_milvus_client.load_collection(collection_name, timeout, **kwargs)

@logger_interceptor()
async def create_index(self, collection_name: str, index_params, timeout: Optional[float] = None,
**kwargs):
return await self.async_milvus_client.create_index(collection_name, index_params, timeout, **kwargs)

@logger_interceptor()
async def insert(self,
collection_name: str,
data: Union[Dict, List[Dict]],
timeout: Optional[float] = None,
partition_name: Optional[str] = "",
**kwargs):
return await self.async_milvus_client.insert(collection_name, data, timeout, partition_name, **kwargs)

@logger_interceptor()
async def upsert(self,
collection_name: str,
data: Union[Dict, List[Dict]],
timeout: Optional[float] = None,
partition_name: Optional[str] = "",
**kwargs):
return await self.async_milvus_client.upsert(collection_name, data, timeout, partition_name, **kwargs)

@logger_interceptor()
async def search(self,
collection_name: str,
data: Union[List[list], list],
filter: str = "",
limit: int = 10,
output_fields: Optional[List[str]] = None,
search_params: Optional[dict] = None,
timeout: Optional[float] = None,
partition_names: Optional[List[str]] = None,
anns_field: Optional[str] = None,
**kwargs):
return await self.async_milvus_client.search(collection_name, data,
filter,
limit, output_fields, search_params,
timeout,
partition_names, anns_field, **kwargs)

@logger_interceptor()
async def hybrid_search(self,
collection_name: str,
reqs: List[AnnSearchRequest],
ranker: RRFRanker,
limit: int = 10,
output_fields: Optional[List[str]] = None,
timeout: Optional[float] = None,
partition_names: Optional[List[str]] = None,
**kwargs):
return await self.async_milvus_client.hybrid_search(collection_name, reqs,
ranker,
limit, output_fields,
timeout, partition_names, **kwargs)

@logger_interceptor()
async def query(self,
collection_name: str,
filter: str = "",
output_fields: Optional[List[str]] = None,
timeout: Optional[float] = None,
ids: Optional[Union[List, str, int]] = None,
partition_names: Optional[List[str]] = None,
**kwargs):
return await self.async_milvus_client.query(collection_name, filter,
output_fields, timeout,
ids, partition_names,
**kwargs)

@logger_interceptor()
async def get(self,
collection_name: str,
ids: Union[list, str, int],
output_fields: Optional[List[str]] = None,
timeout: Optional[float] = None,
partition_names: Optional[List[str]] = None,
**kwargs):
return await self.async_milvus_client.get(collection_name, ids,
output_fields, timeout,
partition_names,
**kwargs)

@logger_interceptor()
async def delete(self,
collection_name: str,
ids: Optional[Union[list, str, int]] = None,
timeout: Optional[float] = None,
filter: Optional[str] = None,
partition_name: Optional[str] = None,
**kwargs):
return await self.async_milvus_client.delete(collection_name, ids,
timeout, filter,
partition_name,
**kwargs)

@classmethod
def create_schema(cls, **kwargs):
kwargs["check_fields"] = False # do not check fields for now
return CollectionSchema([], **kwargs)

@logger_interceptor()
async def close(self, **kwargs):
return await self.async_milvus_client.close(**kwargs)
13 changes: 13 additions & 0 deletions tests/python_client/base/client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from base.utility_wrapper import ApiUtilityWrapper
from base.schema_wrapper import ApiCollectionSchemaWrapper, ApiFieldSchemaWrapper
from base.high_level_api_wrapper import HighLevelApiWrapper
from base.async_milvus_client_wrapper import AsyncMilvusClientWrapper
from utils.util_log import test_log as log
from common import common_func as cf
from common import common_type as ct
Expand All @@ -35,6 +36,7 @@ class Base:
collection_object_list = []
resource_group_list = []
high_level_api_wrap = None
async_milvus_client_wrap = None
skip_connection = False

def setup_class(self):
Expand All @@ -59,6 +61,7 @@ def _setup_objects(self):
self.field_schema_wrap = ApiFieldSchemaWrapper()
self.database_wrap = ApiDatabaseWrapper()
self.high_level_api_wrap = HighLevelApiWrapper()
self.async_milvus_client_wrap = AsyncMilvusClientWrapper()

def teardown_method(self, method):
log.info(("*" * 35) + " teardown " + ("*" * 35))
Expand Down Expand Up @@ -166,6 +169,16 @@ def _connect(self, enable_milvus_client_api=False):
log.info(f"server version: {server_version}")
return res

def init_async_milvus_client(self):
uri = cf.param_info.param_uri or f"http://{cf.param_info.param_host}:{cf.param_info.param_port}"
kwargs = {
"uri": uri,
"user": cf.param_info.param_user,
"password": cf.param_info.param_password,
"token": cf.param_info.param_token,
}
self.async_milvus_client_wrap.init_async_client(**kwargs)

def init_collection_wrap(self, name=None, schema=None, check_task=None, check_items=None,
enable_dynamic_field=False, with_json=True, **kwargs):
name = cf.gen_unique_str('coll_') if name is None else name
Expand Down
30 changes: 30 additions & 0 deletions tests/python_client/base/high_level_api_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import sys
import time
from typing import Optional

import timeout_decorator
from numpy import NaN

Expand Down Expand Up @@ -40,6 +42,13 @@ def init_milvus_client(self, uri, user="", password="", db_name="", token="", ti
timeout=timeout, **kwargs).run()
return res, check_result

@trace()
def close(self, client, check_task=None, check_items=None):
func_name = sys._getframe().f_code.co_name
res, is_succ = api_request([client.close])
check_result = ResponseChecker(res, func_name, check_task, check_items, is_succ).run()
return res, check_result

@trace()
def create_schema(self, client, timeout=None, check_task=None,
check_items=None, **kwargs):
Expand Down Expand Up @@ -103,6 +112,17 @@ def upsert(self, client, collection_name, data, timeout=None, check_task=None, c
**kwargs).run()
return res, check_result

@trace()
def get_collection_stats(self, client, collection_name, timeout=None, check_task=None, check_items=None, **kwargs):
timeout = TIMEOUT if timeout is None else timeout
kwargs.update({"timeout": timeout})

func_name = sys._getframe().f_code.co_name
res, check = api_request([client.get_collection_stats, collection_name], **kwargs)
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
collection_name=collection_name, **kwargs).run()
return res, check_result

@trace()
def search(self, client, collection_name, data, limit=10, filter=None, output_fields=None, search_params=None,
timeout=None, check_task=None, check_items=None, **kwargs):
Expand Down Expand Up @@ -315,6 +335,16 @@ def rename_collection(self, client, old_name, new_name, timeout=None, check_task
**kwargs).run()
return res, check_result

@trace()
def create_database(self, client, db_name, properties: Optional[dict] = None, check_task=None, check_items=None, **kwargs):
func_name = sys._getframe().f_code.co_name
res, check = api_request([client.create_database, db_name, properties], **kwargs)
check_result = ResponseChecker(res, func_name, check_task,
check_items, check,
db_name=db_name, properties=properties,
**kwargs).run()
return res, check_result

@trace()
def create_partition(self, client, collection_name, partition_name, timeout=None, check_task=None, check_items=None, **kwargs):
timeout = TIMEOUT if timeout is None else timeout
Expand Down
13 changes: 12 additions & 1 deletion tests/python_client/check/func_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,21 @@
from common.common_type import CheckTasks, Connect_Object_Name
# from common.code_mapping import ErrorCode, ErrorMessage
from pymilvus import Collection, Partition, ResourceGroupInfo
from utils.api_request import Error
import check.param_check as pc


class Error:
def __init__(self, error):
self.code = getattr(error, 'code', -1)
self.message = getattr(error, 'message', str(error))

def __str__(self):
return f"Error(code={self.code}, message={self.message})"

def __repr__(self):
return f"Error(code={self.code}, message={self.message})"


class ResponseChecker:
def __init__(self, response, func_name, check_task, check_items, is_succ=True, **kwargs):
self.response = response # response of api request
Expand Down
4 changes: 2 additions & 2 deletions tests/python_client/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def pytest_addoption(parser):
parser.addoption("--user", action="store", default="", help="user name for connection")
parser.addoption("--password", action="store", default="", help="password for connection")
parser.addoption("--db_name", action="store", default="default", help="database name for connection")
parser.addoption("--secure", type=bool, action="store", default=False, help="secure for connection")
parser.addoption("--secure", action="store", default=False, help="secure for connection")
parser.addoption("--milvus_ns", action="store", default="chaos-testing", help="milvus_ns")
parser.addoption("--http_port", action="store", default=19121, help="http's port")
parser.addoption("--handler", action="store", default="GRPC", help="handler of request")
Expand All @@ -45,7 +45,7 @@ def pytest_addoption(parser):
parser.addoption('--term_expr', action='store', default="term_expr", help="expr of query quest")
parser.addoption('--check_content', action='store', default="check_content", help="content of check")
parser.addoption('--field_name', action='store', default="field_name", help="field_name of index")
parser.addoption('--replica_num', type='int', action='store', default=ct.default_replica_num, help="memory replica number")
parser.addoption('--replica_num', action='store', default=ct.default_replica_num, help="memory replica number")
parser.addoption('--minio_host', action='store', default="localhost", help="minio service's ip")
parser.addoption('--uri', action='store', default="", help="uri for high level api")
parser.addoption('--token', action='store', default="", help="token for high level api")
Expand Down
3 changes: 3 additions & 0 deletions tests/python_client/pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ log_date_format = %Y-%m-%d %H:%M:%S

filterwarnings =
ignore::DeprecationWarning

asyncio_default_fixture_loop_scope = function

7 changes: 4 additions & 3 deletions tests/python_client/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ requests==2.26.0
scikit-learn==1.1.3
timeout_decorator==0.5.0
ujson==5.5.0
pytest==7.2.0
pytest==8.3.4
pytest-asyncio==0.24.0
pytest-assume==2.4.3
pytest-timeout==1.3.3
pytest-repeat==0.8.0
Expand All @@ -27,8 +28,8 @@ pytest-parallel
pytest-random-order

# pymilvus
pymilvus==2.5.1rc25
pymilvus[bulk_writer]==2.5.1rc25
pymilvus==2.5.2rc3
pymilvus[bulk_writer]==2.5.2rc3


# for customize config test
Expand Down
Loading

0 comments on commit 9fc90b6

Please sign in to comment.