Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: add cases for async milvus client #38699

Merged
merged 1 commit into from
Dec 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 174 additions & 0 deletions tests/python_client/base/async_milvus_client_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
import asyncio
import sys
from typing import Optional, List, Union, Dict

from pymilvus import (
AsyncMilvusClient,
AnnSearchRequest,
RRFRanker,
)
from pymilvus.orm.types import CONSISTENCY_STRONG
from pymilvus.orm.collection import CollectionSchema

from check.func_check import ResponseChecker
from utils.api_request import api_request, logger_interceptor


class AsyncMilvusClientWrapper:
async_milvus_client = None

def __init__(self, active_trace=False):
self.active_trace = active_trace

def init_async_client(self, uri: str = "http://localhost:19530",
user: str = "",
password: str = "",
db_name: str = "",
token: str = "",
timeout: Optional[float] = None,
active_trace=False,
check_task=None, check_items=None,
**kwargs):
self.active_trace = active_trace

""" In order to distinguish the same name of collection """
func_name = sys._getframe().f_code.co_name
res, is_succ = api_request([AsyncMilvusClient, uri, user, password, db_name, token,
timeout], **kwargs)
self.async_milvus_client = res if is_succ else None
check_result = ResponseChecker(res, func_name, check_task, check_items, is_succ, **kwargs).run()
return res, check_result

@logger_interceptor()
async def create_collection(self,
collection_name: str,
dimension: Optional[int] = None,
primary_field_name: str = "id", # default is "id"
id_type: str = "int", # or "string",
vector_field_name: str = "vector", # default is "vector"
metric_type: str = "COSINE",
auto_id: bool = False,
timeout: Optional[float] = None,
schema: Optional[CollectionSchema] = None,
index_params=None,
**kwargs):
kwargs["consistency_level"] = kwargs.get("consistency_level", CONSISTENCY_STRONG)

return await self.async_milvus_client.create_collection(collection_name, dimension,
primary_field_name,
id_type, vector_field_name, metric_type,
auto_id,
timeout, schema, index_params, **kwargs)

@logger_interceptor()
async def drop_collection(self, collection_name: str, timeout: Optional[float] = None, **kwargs):
return await self.async_milvus_client.drop_collection(collection_name, timeout, **kwargs)

@logger_interceptor()
async def load_collection(self, collection_name: str, timeout: Optional[float] = None, **kwargs):
return await self.async_milvus_client.load_collection(collection_name, timeout, **kwargs)

@logger_interceptor()
async def create_index(self, collection_name: str, index_params, timeout: Optional[float] = None,
**kwargs):
return await self.async_milvus_client.create_index(collection_name, index_params, timeout, **kwargs)

@logger_interceptor()
async def insert(self,
collection_name: str,
data: Union[Dict, List[Dict]],
timeout: Optional[float] = None,
partition_name: Optional[str] = "",
**kwargs):
return await self.async_milvus_client.insert(collection_name, data, timeout, partition_name, **kwargs)

@logger_interceptor()
async def upsert(self,
collection_name: str,
data: Union[Dict, List[Dict]],
timeout: Optional[float] = None,
partition_name: Optional[str] = "",
**kwargs):
return await self.async_milvus_client.upsert(collection_name, data, timeout, partition_name, **kwargs)

@logger_interceptor()
async def search(self,
collection_name: str,
data: Union[List[list], list],
filter: str = "",
limit: int = 10,
output_fields: Optional[List[str]] = None,
search_params: Optional[dict] = None,
timeout: Optional[float] = None,
partition_names: Optional[List[str]] = None,
anns_field: Optional[str] = None,
**kwargs):
return await self.async_milvus_client.search(collection_name, data,
filter,
limit, output_fields, search_params,
timeout,
partition_names, anns_field, **kwargs)

@logger_interceptor()
async def hybrid_search(self,
collection_name: str,
reqs: List[AnnSearchRequest],
ranker: RRFRanker,
limit: int = 10,
output_fields: Optional[List[str]] = None,
timeout: Optional[float] = None,
partition_names: Optional[List[str]] = None,
**kwargs):
return await self.async_milvus_client.hybrid_search(collection_name, reqs,
ranker,
limit, output_fields,
timeout, partition_names, **kwargs)

@logger_interceptor()
async def query(self,
collection_name: str,
filter: str = "",
output_fields: Optional[List[str]] = None,
timeout: Optional[float] = None,
ids: Optional[Union[List, str, int]] = None,
partition_names: Optional[List[str]] = None,
**kwargs):
return await self.async_milvus_client.query(collection_name, filter,
output_fields, timeout,
ids, partition_names,
**kwargs)

@logger_interceptor()
async def get(self,
collection_name: str,
ids: Union[list, str, int],
output_fields: Optional[List[str]] = None,
timeout: Optional[float] = None,
partition_names: Optional[List[str]] = None,
**kwargs):
return await self.async_milvus_client.get(collection_name, ids,
output_fields, timeout,
partition_names,
**kwargs)

@logger_interceptor()
async def delete(self,
collection_name: str,
ids: Optional[Union[list, str, int]] = None,
timeout: Optional[float] = None,
filter: Optional[str] = None,
partition_name: Optional[str] = None,
**kwargs):
return await self.async_milvus_client.delete(collection_name, ids,
timeout, filter,
partition_name,
**kwargs)

@classmethod
def create_schema(cls, **kwargs):
kwargs["check_fields"] = False # do not check fields for now
return CollectionSchema([], **kwargs)

@logger_interceptor()
async def close(self, **kwargs):
return await self.async_milvus_client.close(**kwargs)
13 changes: 13 additions & 0 deletions tests/python_client/base/client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from base.utility_wrapper import ApiUtilityWrapper
from base.schema_wrapper import ApiCollectionSchemaWrapper, ApiFieldSchemaWrapper
from base.high_level_api_wrapper import HighLevelApiWrapper
from base.async_milvus_client_wrapper import AsyncMilvusClientWrapper
from utils.util_log import test_log as log
from common import common_func as cf
from common import common_type as ct
Expand All @@ -35,6 +36,7 @@ class Base:
collection_object_list = []
resource_group_list = []
high_level_api_wrap = None
async_milvus_client_wrap = None
skip_connection = False

def setup_class(self):
Expand All @@ -59,6 +61,7 @@ def _setup_objects(self):
self.field_schema_wrap = ApiFieldSchemaWrapper()
self.database_wrap = ApiDatabaseWrapper()
self.high_level_api_wrap = HighLevelApiWrapper()
self.async_milvus_client_wrap = AsyncMilvusClientWrapper()

def teardown_method(self, method):
log.info(("*" * 35) + " teardown " + ("*" * 35))
Expand Down Expand Up @@ -166,6 +169,16 @@ def _connect(self, enable_milvus_client_api=False):
log.info(f"server version: {server_version}")
return res

def init_async_milvus_client(self):
uri = cf.param_info.param_uri or f"http://{cf.param_info.param_host}:{cf.param_info.param_port}"
kwargs = {
"uri": uri,
"user": cf.param_info.param_user,
"password": cf.param_info.param_password,
"token": cf.param_info.param_token,
}
self.async_milvus_client_wrap.init_async_client(**kwargs)

def init_collection_wrap(self, name=None, schema=None, check_task=None, check_items=None,
enable_dynamic_field=False, with_json=True, **kwargs):
name = cf.gen_unique_str('coll_') if name is None else name
Expand Down
30 changes: 30 additions & 0 deletions tests/python_client/base/high_level_api_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import sys
import time
from typing import Optional

import timeout_decorator
from numpy import NaN

Expand Down Expand Up @@ -40,6 +42,13 @@ def init_milvus_client(self, uri, user="", password="", db_name="", token="", ti
timeout=timeout, **kwargs).run()
return res, check_result

@trace()
def close(self, client, check_task=None, check_items=None):
func_name = sys._getframe().f_code.co_name
res, is_succ = api_request([client.close])
check_result = ResponseChecker(res, func_name, check_task, check_items, is_succ).run()
return res, check_result

@trace()
def create_schema(self, client, timeout=None, check_task=None,
check_items=None, **kwargs):
Expand Down Expand Up @@ -103,6 +112,17 @@ def upsert(self, client, collection_name, data, timeout=None, check_task=None, c
**kwargs).run()
return res, check_result

@trace()
def get_collection_stats(self, client, collection_name, timeout=None, check_task=None, check_items=None, **kwargs):
timeout = TIMEOUT if timeout is None else timeout
kwargs.update({"timeout": timeout})

func_name = sys._getframe().f_code.co_name
res, check = api_request([client.get_collection_stats, collection_name], **kwargs)
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
collection_name=collection_name, **kwargs).run()
return res, check_result

@trace()
def search(self, client, collection_name, data, limit=10, filter=None, output_fields=None, search_params=None,
timeout=None, check_task=None, check_items=None, **kwargs):
Expand Down Expand Up @@ -315,6 +335,16 @@ def rename_collection(self, client, old_name, new_name, timeout=None, check_task
**kwargs).run()
return res, check_result

@trace()
def create_database(self, client, db_name, properties: Optional[dict] = None, check_task=None, check_items=None, **kwargs):
func_name = sys._getframe().f_code.co_name
res, check = api_request([client.create_database, db_name, properties], **kwargs)
check_result = ResponseChecker(res, func_name, check_task,
check_items, check,
db_name=db_name, properties=properties,
**kwargs).run()
return res, check_result

@trace()
def create_partition(self, client, collection_name, partition_name, timeout=None, check_task=None, check_items=None, **kwargs):
timeout = TIMEOUT if timeout is None else timeout
Expand Down
13 changes: 12 additions & 1 deletion tests/python_client/check/func_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,21 @@
from common.common_type import CheckTasks, Connect_Object_Name
# from common.code_mapping import ErrorCode, ErrorMessage
from pymilvus import Collection, Partition, ResourceGroupInfo
from utils.api_request import Error
import check.param_check as pc


class Error:
def __init__(self, error):
self.code = getattr(error, 'code', -1)
self.message = getattr(error, 'message', str(error))

def __str__(self):
return f"Error(code={self.code}, message={self.message})"

def __repr__(self):
return f"Error(code={self.code}, message={self.message})"


class ResponseChecker:
def __init__(self, response, func_name, check_task, check_items, is_succ=True, **kwargs):
self.response = response # response of api request
Expand Down
4 changes: 2 additions & 2 deletions tests/python_client/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def pytest_addoption(parser):
parser.addoption("--user", action="store", default="", help="user name for connection")
parser.addoption("--password", action="store", default="", help="password for connection")
parser.addoption("--db_name", action="store", default="default", help="database name for connection")
parser.addoption("--secure", type=bool, action="store", default=False, help="secure for connection")
parser.addoption("--secure", action="store", default=False, help="secure for connection")
parser.addoption("--milvus_ns", action="store", default="chaos-testing", help="milvus_ns")
parser.addoption("--http_port", action="store", default=19121, help="http's port")
parser.addoption("--handler", action="store", default="GRPC", help="handler of request")
Expand All @@ -45,7 +45,7 @@ def pytest_addoption(parser):
parser.addoption('--term_expr', action='store', default="term_expr", help="expr of query quest")
parser.addoption('--check_content', action='store', default="check_content", help="content of check")
parser.addoption('--field_name', action='store', default="field_name", help="field_name of index")
parser.addoption('--replica_num', type='int', action='store', default=ct.default_replica_num, help="memory replica number")
parser.addoption('--replica_num', action='store', default=ct.default_replica_num, help="memory replica number")
parser.addoption('--minio_host', action='store', default="localhost", help="minio service's ip")
parser.addoption('--uri', action='store', default="", help="uri for high level api")
parser.addoption('--token', action='store', default="", help="token for high level api")
Expand Down
3 changes: 3 additions & 0 deletions tests/python_client/pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ log_date_format = %Y-%m-%d %H:%M:%S

filterwarnings =
ignore::DeprecationWarning

asyncio_default_fixture_loop_scope = function

7 changes: 4 additions & 3 deletions tests/python_client/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ requests==2.26.0
scikit-learn==1.1.3
timeout_decorator==0.5.0
ujson==5.5.0
pytest==7.2.0
pytest==8.3.4
pytest-asyncio==0.24.0
pytest-assume==2.4.3
pytest-timeout==1.3.3
pytest-repeat==0.8.0
Expand All @@ -27,8 +28,8 @@ pytest-parallel
pytest-random-order

# pymilvus
pymilvus==2.5.1rc25
pymilvus[bulk_writer]==2.5.1rc25
pymilvus==2.5.2rc3
pymilvus[bulk_writer]==2.5.2rc3


# for customize config test
Expand Down
Loading