From 9fc90b6cc1e87c9914809b7f3c5288cbd589f0f5 Mon Sep 17 00:00:00 2001 From: ThreadDao Date: Fri, 20 Dec 2024 18:50:41 +0800 Subject: [PATCH] test: add cases for async milvus client Signed-off-by: ThreadDao --- .../base/async_milvus_client_wrapper.py | 174 ++++++ tests/python_client/base/client_base.py | 13 + .../base/high_level_api_wrapper.py | 30 ++ tests/python_client/check/func_check.py | 13 +- tests/python_client/conftest.py | 4 +- tests/python_client/pytest.ini | 3 + tests/python_client/requirements.txt | 7 +- .../async_milvus_client/test_e2e_async.py | 509 ++++++++++++++++++ tests/python_client/utils/api_request.py | 63 ++- 9 files changed, 797 insertions(+), 19 deletions(-) create mode 100644 tests/python_client/base/async_milvus_client_wrapper.py create mode 100644 tests/python_client/testcases/async_milvus_client/test_e2e_async.py diff --git a/tests/python_client/base/async_milvus_client_wrapper.py b/tests/python_client/base/async_milvus_client_wrapper.py new file mode 100644 index 0000000000000..82f3f2f59f216 --- /dev/null +++ b/tests/python_client/base/async_milvus_client_wrapper.py @@ -0,0 +1,174 @@ +import asyncio +import sys +from typing import Optional, List, Union, Dict + +from pymilvus import ( + AsyncMilvusClient, + AnnSearchRequest, + RRFRanker, +) +from pymilvus.orm.types import CONSISTENCY_STRONG +from pymilvus.orm.collection import CollectionSchema + +from check.func_check import ResponseChecker +from utils.api_request import api_request, logger_interceptor + + +class AsyncMilvusClientWrapper: + async_milvus_client = None + + def __init__(self, active_trace=False): + self.active_trace = active_trace + + def init_async_client(self, uri: str = "http://localhost:19530", + user: str = "", + password: str = "", + db_name: str = "", + token: str = "", + timeout: Optional[float] = None, + active_trace=False, + check_task=None, check_items=None, + **kwargs): + self.active_trace = active_trace + + """ In order to distinguish the same name of 
collection """ + func_name = sys._getframe().f_code.co_name + res, is_succ = api_request([AsyncMilvusClient, uri, user, password, db_name, token, + timeout], **kwargs) + self.async_milvus_client = res if is_succ else None + check_result = ResponseChecker(res, func_name, check_task, check_items, is_succ, **kwargs).run() + return res, check_result + + @logger_interceptor() + async def create_collection(self, + collection_name: str, + dimension: Optional[int] = None, + primary_field_name: str = "id", # default is "id" + id_type: str = "int", # or "string", + vector_field_name: str = "vector", # default is "vector" + metric_type: str = "COSINE", + auto_id: bool = False, + timeout: Optional[float] = None, + schema: Optional[CollectionSchema] = None, + index_params=None, + **kwargs): + kwargs["consistency_level"] = kwargs.get("consistency_level", CONSISTENCY_STRONG) + + return await self.async_milvus_client.create_collection(collection_name, dimension, + primary_field_name, + id_type, vector_field_name, metric_type, + auto_id, + timeout, schema, index_params, **kwargs) + + @logger_interceptor() + async def drop_collection(self, collection_name: str, timeout: Optional[float] = None, **kwargs): + return await self.async_milvus_client.drop_collection(collection_name, timeout, **kwargs) + + @logger_interceptor() + async def load_collection(self, collection_name: str, timeout: Optional[float] = None, **kwargs): + return await self.async_milvus_client.load_collection(collection_name, timeout, **kwargs) + + @logger_interceptor() + async def create_index(self, collection_name: str, index_params, timeout: Optional[float] = None, + **kwargs): + return await self.async_milvus_client.create_index(collection_name, index_params, timeout, **kwargs) + + @logger_interceptor() + async def insert(self, + collection_name: str, + data: Union[Dict, List[Dict]], + timeout: Optional[float] = None, + partition_name: Optional[str] = "", + **kwargs): + return await 
self.async_milvus_client.insert(collection_name, data, timeout, partition_name, **kwargs) + + @logger_interceptor() + async def upsert(self, + collection_name: str, + data: Union[Dict, List[Dict]], + timeout: Optional[float] = None, + partition_name: Optional[str] = "", + **kwargs): + return await self.async_milvus_client.upsert(collection_name, data, timeout, partition_name, **kwargs) + + @logger_interceptor() + async def search(self, + collection_name: str, + data: Union[List[list], list], + filter: str = "", + limit: int = 10, + output_fields: Optional[List[str]] = None, + search_params: Optional[dict] = None, + timeout: Optional[float] = None, + partition_names: Optional[List[str]] = None, + anns_field: Optional[str] = None, + **kwargs): + return await self.async_milvus_client.search(collection_name, data, + filter, + limit, output_fields, search_params, + timeout, + partition_names, anns_field, **kwargs) + + @logger_interceptor() + async def hybrid_search(self, + collection_name: str, + reqs: List[AnnSearchRequest], + ranker: RRFRanker, + limit: int = 10, + output_fields: Optional[List[str]] = None, + timeout: Optional[float] = None, + partition_names: Optional[List[str]] = None, + **kwargs): + return await self.async_milvus_client.hybrid_search(collection_name, reqs, + ranker, + limit, output_fields, + timeout, partition_names, **kwargs) + + @logger_interceptor() + async def query(self, + collection_name: str, + filter: str = "", + output_fields: Optional[List[str]] = None, + timeout: Optional[float] = None, + ids: Optional[Union[List, str, int]] = None, + partition_names: Optional[List[str]] = None, + **kwargs): + return await self.async_milvus_client.query(collection_name, filter, + output_fields, timeout, + ids, partition_names, + **kwargs) + + @logger_interceptor() + async def get(self, + collection_name: str, + ids: Union[list, str, int], + output_fields: Optional[List[str]] = None, + timeout: Optional[float] = None, + partition_names: 
Optional[List[str]] = None, + **kwargs): + return await self.async_milvus_client.get(collection_name, ids, + output_fields, timeout, + partition_names, + **kwargs) + + @logger_interceptor() + async def delete(self, + collection_name: str, + ids: Optional[Union[list, str, int]] = None, + timeout: Optional[float] = None, + filter: Optional[str] = None, + partition_name: Optional[str] = None, + **kwargs): + return await self.async_milvus_client.delete(collection_name, ids, + timeout, filter, + partition_name, + **kwargs) + + @classmethod + def create_schema(cls, **kwargs): + kwargs["check_fields"] = False # do not check fields for now + return CollectionSchema([], **kwargs) + + @logger_interceptor() + async def close(self, **kwargs): + return await self.async_milvus_client.close(**kwargs) \ No newline at end of file diff --git a/tests/python_client/base/client_base.py b/tests/python_client/base/client_base.py index 56c4e56ce88ff..2a708f100d68c 100644 --- a/tests/python_client/base/client_base.py +++ b/tests/python_client/base/client_base.py @@ -13,6 +13,7 @@ from base.utility_wrapper import ApiUtilityWrapper from base.schema_wrapper import ApiCollectionSchemaWrapper, ApiFieldSchemaWrapper from base.high_level_api_wrapper import HighLevelApiWrapper +from base.async_milvus_client_wrapper import AsyncMilvusClientWrapper from utils.util_log import test_log as log from common import common_func as cf from common import common_type as ct @@ -35,6 +36,7 @@ class Base: collection_object_list = [] resource_group_list = [] high_level_api_wrap = None + async_milvus_client_wrap = None skip_connection = False def setup_class(self): @@ -59,6 +61,7 @@ def _setup_objects(self): self.field_schema_wrap = ApiFieldSchemaWrapper() self.database_wrap = ApiDatabaseWrapper() self.high_level_api_wrap = HighLevelApiWrapper() + self.async_milvus_client_wrap = AsyncMilvusClientWrapper() def teardown_method(self, method): log.info(("*" * 35) + " teardown " + ("*" * 35)) @@ -166,6 +169,16 @@ def 
_connect(self, enable_milvus_client_api=False): log.info(f"server version: {server_version}") return res + def init_async_milvus_client(self): + uri = cf.param_info.param_uri or f"http://{cf.param_info.param_host}:{cf.param_info.param_port}" + kwargs = { + "uri": uri, + "user": cf.param_info.param_user, + "password": cf.param_info.param_password, + "token": cf.param_info.param_token, + } + self.async_milvus_client_wrap.init_async_client(**kwargs) + def init_collection_wrap(self, name=None, schema=None, check_task=None, check_items=None, enable_dynamic_field=False, with_json=True, **kwargs): name = cf.gen_unique_str('coll_') if name is None else name diff --git a/tests/python_client/base/high_level_api_wrapper.py b/tests/python_client/base/high_level_api_wrapper.py index 4dc62e17f2792..814f2001774b9 100644 --- a/tests/python_client/base/high_level_api_wrapper.py +++ b/tests/python_client/base/high_level_api_wrapper.py @@ -1,5 +1,7 @@ import sys import time +from typing import Optional + import timeout_decorator from numpy import NaN @@ -40,6 +42,13 @@ def init_milvus_client(self, uri, user="", password="", db_name="", token="", ti timeout=timeout, **kwargs).run() return res, check_result + @trace() + def close(self, client, check_task=None, check_items=None): + func_name = sys._getframe().f_code.co_name + res, is_succ = api_request([client.close]) + check_result = ResponseChecker(res, func_name, check_task, check_items, is_succ).run() + return res, check_result + @trace() def create_schema(self, client, timeout=None, check_task=None, check_items=None, **kwargs): @@ -103,6 +112,17 @@ def upsert(self, client, collection_name, data, timeout=None, check_task=None, c **kwargs).run() return res, check_result + @trace() + def get_collection_stats(self, client, collection_name, timeout=None, check_task=None, check_items=None, **kwargs): + timeout = TIMEOUT if timeout is None else timeout + kwargs.update({"timeout": timeout}) + + func_name = sys._getframe().f_code.co_name + 
res, check = api_request([client.get_collection_stats, collection_name], **kwargs) + check_result = ResponseChecker(res, func_name, check_task, check_items, check, + collection_name=collection_name, **kwargs).run() + return res, check_result + @trace() def search(self, client, collection_name, data, limit=10, filter=None, output_fields=None, search_params=None, timeout=None, check_task=None, check_items=None, **kwargs): @@ -315,6 +335,16 @@ def rename_collection(self, client, old_name, new_name, timeout=None, check_task **kwargs).run() return res, check_result + @trace() + def create_database(self, client, db_name, properties: Optional[dict] = None, check_task=None, check_items=None, **kwargs): + func_name = sys._getframe().f_code.co_name + res, check = api_request([client.create_database, db_name, properties], **kwargs) + check_result = ResponseChecker(res, func_name, check_task, + check_items, check, + db_name=db_name, properties=properties, + **kwargs).run() + return res, check_result + @trace() def create_partition(self, client, collection_name, partition_name, timeout=None, check_task=None, check_items=None, **kwargs): timeout = TIMEOUT if timeout is None else timeout diff --git a/tests/python_client/check/func_check.py b/tests/python_client/check/func_check.py index 7f757eb5325a1..d1b70206c2384 100644 --- a/tests/python_client/check/func_check.py +++ b/tests/python_client/check/func_check.py @@ -7,10 +7,21 @@ from common.common_type import CheckTasks, Connect_Object_Name # from common.code_mapping import ErrorCode, ErrorMessage from pymilvus import Collection, Partition, ResourceGroupInfo -from utils.api_request import Error import check.param_check as pc +class Error: + def __init__(self, error): + self.code = getattr(error, 'code', -1) + self.message = getattr(error, 'message', str(error)) + + def __str__(self): + return f"Error(code={self.code}, message={self.message})" + + def __repr__(self): + return f"Error(code={self.code}, message={self.message})" + + 
class ResponseChecker: def __init__(self, response, func_name, check_task, check_items, is_succ=True, **kwargs): self.response = response # response of api request diff --git a/tests/python_client/conftest.py b/tests/python_client/conftest.py index 04a03dc71072f..03f6b2007d57f 100644 --- a/tests/python_client/conftest.py +++ b/tests/python_client/conftest.py @@ -25,7 +25,7 @@ def pytest_addoption(parser): parser.addoption("--user", action="store", default="", help="user name for connection") parser.addoption("--password", action="store", default="", help="password for connection") parser.addoption("--db_name", action="store", default="default", help="database name for connection") - parser.addoption("--secure", type=bool, action="store", default=False, help="secure for connection") + parser.addoption("--secure", action="store", default=False, help="secure for connection") parser.addoption("--milvus_ns", action="store", default="chaos-testing", help="milvus_ns") parser.addoption("--http_port", action="store", default=19121, help="http's port") parser.addoption("--handler", action="store", default="GRPC", help="handler of request") @@ -45,7 +45,7 @@ def pytest_addoption(parser): parser.addoption('--term_expr', action='store', default="term_expr", help="expr of query quest") parser.addoption('--check_content', action='store', default="check_content", help="content of check") parser.addoption('--field_name', action='store', default="field_name", help="field_name of index") - parser.addoption('--replica_num', type='int', action='store', default=ct.default_replica_num, help="memory replica number") + parser.addoption('--replica_num', action='store', default=ct.default_replica_num, help="memory replica number") parser.addoption('--minio_host', action='store', default="localhost", help="minio service's ip") parser.addoption('--uri', action='store', default="", help="uri for high level api") parser.addoption('--token', action='store', default="", help="token for high level 
api") diff --git a/tests/python_client/pytest.ini b/tests/python_client/pytest.ini index c89c29238acf6..8869b6d5de330 100644 --- a/tests/python_client/pytest.ini +++ b/tests/python_client/pytest.ini @@ -10,3 +10,6 @@ log_date_format = %Y-%m-%d %H:%M:%S filterwarnings = ignore::DeprecationWarning + +asyncio_default_fixture_loop_scope = function + diff --git a/tests/python_client/requirements.txt b/tests/python_client/requirements.txt index 793e07251e452..e29aea109cf63 100644 --- a/tests/python_client/requirements.txt +++ b/tests/python_client/requirements.txt @@ -4,7 +4,8 @@ requests==2.26.0 scikit-learn==1.1.3 timeout_decorator==0.5.0 ujson==5.5.0 -pytest==7.2.0 +pytest==8.3.4 +pytest-asyncio==0.24.0 pytest-assume==2.4.3 pytest-timeout==1.3.3 pytest-repeat==0.8.0 @@ -27,8 +28,8 @@ pytest-parallel pytest-random-order # pymilvus -pymilvus==2.5.1rc25 -pymilvus[bulk_writer]==2.5.1rc25 +pymilvus==2.5.2rc3 +pymilvus[bulk_writer]==2.5.2rc3 # for customize config test diff --git a/tests/python_client/testcases/async_milvus_client/test_e2e_async.py b/tests/python_client/testcases/async_milvus_client/test_e2e_async.py new file mode 100644 index 0000000000000..827282e5436aa --- /dev/null +++ b/tests/python_client/testcases/async_milvus_client/test_e2e_async.py @@ -0,0 +1,509 @@ +import random +import time +import pytest +import asyncio +from pymilvus.client.types import LoadState, DataType +from pymilvus import AnnSearchRequest, RRFRanker + +from base.client_base import TestcaseBase +from common import common_func as cf +from common import common_type as ct +from common.common_type import CaseLabel, CheckTasks +from utils.util_log import test_log as log + +pytestmark = pytest.mark.asyncio + +prefix = "async" +async_default_nb = 5000 +default_pk_name = "id" +default_vector_name = "vector" + + +class TestAsyncMilvusClient(TestcaseBase): + + def teardown_method(self, method): + loop = asyncio.get_event_loop() + loop.run_until_complete(self.async_milvus_client_wrap.close()) + 
super().teardown_method(method) + + @pytest.mark.tags(CaseLabel.L0) + async def test_async_client_default(self): + # init client + milvus_client = self._connect(enable_milvus_client_api=True) + self.init_async_milvus_client() + + # create collection + c_name = cf.gen_unique_str(prefix) + await self.async_milvus_client_wrap.create_collection(c_name, dimension=ct.default_dim) + collections, _ = self.high_level_api_wrap.list_collections(milvus_client) + assert c_name in collections + + # insert entities + rows = [ + {default_pk_name: i, default_vector_name: [random.random() for _ in range(ct.default_dim)]} + for i in range(async_default_nb)] + start_time = time.time() + tasks = [] + step = 1000 + for i in range(0, async_default_nb, step): + task = self.async_milvus_client_wrap.insert(c_name, rows[i:i + step]) + tasks.append(task) + insert_res = await asyncio.gather(*tasks) + end_time = time.time() + log.info("Total time: {:.2f} seconds".format(end_time - start_time)) + for r in insert_res: + assert r[0]['insert_count'] == step + + # dql tasks + tasks = [] + # search default + vector = cf.gen_vectors(ct.default_nq, ct.default_dim) + default_search_task = self.async_milvus_client_wrap.search(c_name, vector, limit=ct.default_limit, + check_task=CheckTasks.check_search_results, + check_items={"enable_milvus_client_api": True, + "nq": ct.default_nq, + "limit": ct.default_limit}) + tasks.append(default_search_task) + + # search with filter & search_params + sp = {"metric_type": "COSINE", "params": {"ef": "96"}} + filter_params_search_task = self.async_milvus_client_wrap.search(c_name, vector, limit=ct.default_limit, + filter=f"{default_pk_name} > 10", + search_params=sp, + check_task=CheckTasks.check_search_results, + check_items={"enable_milvus_client_api": True, + "nq": ct.default_nq, + "limit": ct.default_limit}) + tasks.append(filter_params_search_task) + + # search output fields + output_search_task = self.async_milvus_client_wrap.search(c_name, vector, 
limit=ct.default_limit, + output_fields=["*"], + check_task=CheckTasks.check_search_results, + check_items={"enable_milvus_client_api": True, + "nq": ct.default_nq, + "limit": ct.default_limit}) + tasks.append(output_search_task) + + # query with filter and default output "*" + exp_query_res = [{default_pk_name: i} for i in range(ct.default_limit)] + filter_query_task = self.async_milvus_client_wrap.query(c_name, + filter=f"{default_pk_name} < {ct.default_limit}", + output_fields=[default_pk_name], + check_task=CheckTasks.check_query_results, + check_items={"exp_res": exp_query_res, + "primary_field": default_pk_name}) + tasks.append(filter_query_task) + # query with ids and output all fields + ids_query_task = self.async_milvus_client_wrap.query(c_name, + ids=[i for i in range(ct.default_limit)], + output_fields=["*"], + check_task=CheckTasks.check_query_results, + check_items={"exp_res": rows[:ct.default_limit], + "with_vec": True, + "primary_field": default_pk_name}) + tasks.append(ids_query_task) + # get with ids + get_task = self.async_milvus_client_wrap.get(c_name, + ids=[0, 1], + output_fields=[default_pk_name, default_vector_name], + check_task=CheckTasks.check_query_results, + check_items={"exp_res": rows[:2], "with_vec": True, + "primary_field": default_pk_name}) + tasks.append(get_task) + await asyncio.gather(*tasks) + + @pytest.mark.tags(CaseLabel.L0) + async def test_async_client_partition(self): + # init client + milvus_client = self._connect(enable_milvus_client_api=True) + self.init_async_milvus_client() + + # create collection & partition + c_name = cf.gen_unique_str(prefix) + p_name = cf.gen_unique_str("par") + await self.async_milvus_client_wrap.create_collection(c_name, dimension=ct.default_dim) + collections, _ = self.high_level_api_wrap.list_collections(milvus_client) + assert c_name in collections + self.high_level_api_wrap.create_partition(milvus_client, c_name, p_name) + partitions, _ = 
self.high_level_api_wrap.list_partitions(milvus_client, c_name) + assert p_name in partitions + + # insert entities + rows = [ + {default_pk_name: i, default_vector_name: [random.random() for _ in range(ct.default_dim)]} + for i in range(async_default_nb)] + start_time = time.time() + tasks = [] + step = 1000 + for i in range(0, async_default_nb, step): + task = self.async_milvus_client_wrap.insert(c_name, rows[i:i + step], partition_name=p_name) + tasks.append(task) + insert_res = await asyncio.gather(*tasks) + end_time = time.time() + log.info("Total time: {:.2f} seconds".format(end_time - start_time)) + for r in insert_res: + assert r[0]['insert_count'] == step + + # count from default partition + count_res, _ = await self.async_milvus_client_wrap.query(c_name, output_fields=["count(*)"], partition_names=[ct.default_partition_name]) + assert count_res[0]["count(*)"] == 0 + + # dql tasks + tasks = [] + # search default + vector = cf.gen_vectors(ct.default_nq, ct.default_dim) + default_search_task = self.async_milvus_client_wrap.search(c_name, vector, limit=ct.default_limit, + partition_names=[p_name], + check_task=CheckTasks.check_search_results, + check_items={"enable_milvus_client_api": True, + "nq": ct.default_nq, + "limit": ct.default_limit}) + tasks.append(default_search_task) + + # search with filter & search_params + sp = {"metric_type": "COSINE", "params": {"ef": "96"}} + filter_params_search_task = self.async_milvus_client_wrap.search(c_name, vector, limit=ct.default_limit, + filter=f"{default_pk_name} > 10", + search_params=sp, + partition_names=[p_name], + check_task=CheckTasks.check_search_results, + check_items={"enable_milvus_client_api": True, + "nq": ct.default_nq, + "limit": ct.default_limit}) + tasks.append(filter_params_search_task) + + # search output fields + output_search_task = self.async_milvus_client_wrap.search(c_name, vector, limit=ct.default_limit, + output_fields=["*"], + partition_names=[p_name], + 
check_task=CheckTasks.check_search_results, + check_items={"enable_milvus_client_api": True, + "nq": ct.default_nq, + "limit": ct.default_limit}) + tasks.append(output_search_task) + + # query with filter and default output "*" + exp_query_res = [{default_pk_name: i} for i in range(ct.default_limit)] + filter_query_task = self.async_milvus_client_wrap.query(c_name, + filter=f"{default_pk_name} < {ct.default_limit}", + output_fields=[default_pk_name], + partition_names=[p_name], + check_task=CheckTasks.check_query_results, + check_items={"exp_res": exp_query_res, + "primary_field": default_pk_name}) + tasks.append(filter_query_task) + # query with ids and output all fields + ids_query_task = self.async_milvus_client_wrap.query(c_name, + ids=[i for i in range(ct.default_limit)], + output_fields=["*"], + partition_names=[p_name], + check_task=CheckTasks.check_query_results, + check_items={"exp_res": rows[:ct.default_limit], + "with_vec": True, + "primary_field": default_pk_name}) + tasks.append(ids_query_task) + # get with ids + get_task = self.async_milvus_client_wrap.get(c_name, + ids=[0, 1], partition_names=[p_name], + output_fields=[default_pk_name, default_vector_name], + check_task=CheckTasks.check_query_results, + check_items={"exp_res": rows[:2], "with_vec": True, + "primary_field": default_pk_name}) + tasks.append(get_task) + await asyncio.gather(*tasks) + + @pytest.mark.tags(CaseLabel.L0) + async def test_async_client_with_schema(self, schema): + # init client + milvus_client = self._connect(enable_milvus_client_api=True) + self.init_async_milvus_client() + + # create collection + c_name = cf.gen_unique_str(prefix) + schema = self.async_milvus_client_wrap.create_schema(auto_id=False, + partition_key_field=ct.default_int64_field_name) + schema.add_field(ct.default_string_field_name, DataType.VARCHAR, max_length=100, is_primary=True) + schema.add_field(ct.default_int64_field_name, DataType.INT64, is_partition_key=True) + 
schema.add_field(ct.default_float_vec_field_name, DataType.FLOAT_VECTOR, dim=ct.default_dim) + schema.add_field(default_vector_name, DataType.FLOAT_VECTOR, dim=ct.default_dim) + await self.async_milvus_client_wrap.create_collection(c_name, schema=schema) + collections, _ = self.high_level_api_wrap.list_collections(milvus_client) + assert c_name in collections + + # insert entities + rows = [ + {ct.default_string_field_name: str(i), + ct.default_int64_field_name: i, + ct.default_float_vec_field_name: [random.random() for _ in range(ct.default_dim)], + default_vector_name: [random.random() for _ in range(ct.default_dim)], + } for i in range(async_default_nb)] + start_time = time.time() + tasks = [] + step = 1000 + for i in range(0, async_default_nb, step): + task = self.async_milvus_client_wrap.insert(c_name, rows[i:i + step]) + tasks.append(task) + insert_res = await asyncio.gather(*tasks) + end_time = time.time() + log.info("Total time: {:.2f} seconds".format(end_time - start_time)) + for r in insert_res: + assert r[0]['insert_count'] == step + + # flush + self.high_level_api_wrap.flush(milvus_client, c_name) + stats, _ = self.high_level_api_wrap.get_collection_stats(milvus_client, c_name) + assert stats["row_count"] == async_default_nb + + # create index -> load + index_params, _ = self.high_level_api_wrap.prepare_index_params(milvus_client, + field_name=ct.default_float_vec_field_name, + index_type="HNSW", metric_type="COSINE", M=30, + efConstruction=200) + index_params.add_index(field_name=default_vector_name, index_type="IVF_SQ8", + metric_type="L2", nlist=32) + await self.async_milvus_client_wrap.create_index(c_name, index_params) + await self.async_milvus_client_wrap.load_collection(c_name) + + _index, _ = self.high_level_api_wrap.describe_index(milvus_client, c_name, default_vector_name) + assert _index["indexed_rows"] == async_default_nb + assert _index["state"] == "Finished" + _load, _ = self.high_level_api_wrap.get_load_state(milvus_client, c_name) + 
assert _load["state"] == LoadState.Loaded + + # dql tasks + tasks = [] + # search default + vector = cf.gen_vectors(ct.default_nq, ct.default_dim) + default_search_task = self.async_milvus_client_wrap.search(c_name, vector, limit=ct.default_limit, + anns_field=ct.default_float_vec_field_name, + search_params={"metric_type": "COSINE", + "params": {"ef": "96"}}, + check_task=CheckTasks.check_search_results, + check_items={"enable_milvus_client_api": True, + "nq": ct.default_nq, + "limit": ct.default_limit}) + tasks.append(default_search_task) + + # hybrid_search + search_param = { + "data": cf.gen_vectors(ct.default_nq, ct.default_dim, vector_data_type="FLOAT_VECTOR"), + "anns_field": ct.default_float_vec_field_name, + "param": {"metric_type": "COSINE", "params": {"ef": "96"}}, + "limit": ct.default_limit, + "expr": f"{ct.default_int64_field_name} > 10"} + req = AnnSearchRequest(**search_param) + + search_param2 = { + "data": cf.gen_vectors(ct.default_nq, ct.default_dim, vector_data_type="FLOAT_VECTOR"), + "anns_field": default_vector_name, + "param": {"metric_type": "L2", "params": {"nprobe": "32"}}, + "limit": ct.default_limit + } + req2 = AnnSearchRequest(**search_param2) + _output_fields = [ct.default_int64_field_name, ct.default_string_field_name] + filter_params_search_task = self.async_milvus_client_wrap.hybrid_search(c_name, [req, req2], RRFRanker(), + limit=5, + check_task=CheckTasks.check_search_results, + check_items={ + "enable_milvus_client_api": True, + "nq": ct.default_nq, + "limit": 5}) + tasks.append(filter_params_search_task) + + # get with ids + get_task = self.async_milvus_client_wrap.get(c_name, ids=['0', '1'], output_fields=[ct.default_int64_field_name, + ct.default_string_field_name]) + tasks.append(get_task) + await asyncio.gather(*tasks) + + @pytest.mark.tags(CaseLabel.L0) + async def test_async_client_dml(self): + # init client + milvus_client = self._connect(enable_milvus_client_api=True) + self.init_async_milvus_client() + + # create 
collection + c_name = cf.gen_unique_str(prefix) + await self.async_milvus_client_wrap.create_collection(c_name, dimension=ct.default_dim) + collections, _ = self.high_level_api_wrap.list_collections(milvus_client) + assert c_name in collections + + # insert entities + rows = [ + {default_pk_name: i, default_vector_name: [random.random() for _ in range(ct.default_dim)]} + for i in range(ct.default_nb)] + start_time = time.time() + tasks = [] + step = 1000 + for i in range(0, ct.default_nb, step): + task = self.async_milvus_client_wrap.insert(c_name, rows[i:i + step]) + tasks.append(task) + insert_res = await asyncio.gather(*tasks) + end_time = time.time() + log.info("Total time: {:.2f} seconds".format(end_time - start_time)) + for r in insert_res: + assert r[0]['insert_count'] == step + + # dml tasks + # query id -> upsert id -> query id -> delete id -> query id + _id = 10 + get_res, _ = await self.async_milvus_client_wrap.get(c_name, ids=[_id], + output_fields=[default_pk_name, default_vector_name]) + assert len(get_res) == 1 + + # upsert + upsert_row = [{ + default_pk_name: _id, default_vector_name: [random.random() for _ in range(ct.default_dim)] + }] + upsert_res, _ = await self.async_milvus_client_wrap.upsert(c_name, upsert_row) + assert upsert_res["upsert_count"] == 1 + + # get _id after upsert + get_res, _ = await self.async_milvus_client_wrap.get(c_name, ids=[_id], + output_fields=[default_pk_name, default_vector_name]) + for j in range(5): + assert abs(get_res[0][default_vector_name][j] - upsert_row[0][default_vector_name][j]) < ct.epsilon + + # delete + del_res, _ = await self.async_milvus_client_wrap.delete(c_name, ids=[_id]) + assert del_res["delete_count"] == 1 + + # query after delete + get_res, _ = await self.async_milvus_client_wrap.get(c_name, ids=[_id], + output_fields=[default_pk_name, default_vector_name]) + assert len(get_res) == 0 + + @pytest.mark.tags(CaseLabel.L2) + async def test_async_client_with_db(self): + # init client + milvus_client = 
self._connect(enable_milvus_client_api=True) + db_name = cf.gen_unique_str("db") + self.high_level_api_wrap.create_database(milvus_client, db_name) + self.high_level_api_wrap.close(milvus_client) + uri = cf.param_info.param_uri or f"http://{cf.param_info.param_host}:{cf.param_info.param_port}" + milvus_client, _ = self.connection_wrap.MilvusClient(uri=uri, db_name=db_name) + self.async_milvus_client_wrap.init_async_client(uri, db_name=db_name) + + # create collection + c_name = cf.gen_unique_str(prefix) + await self.async_milvus_client_wrap.create_collection(c_name, dimension=ct.default_dim) + collections, _ = self.high_level_api_wrap.list_collections(milvus_client) + assert c_name in collections + + # insert entities + rows = [ + {default_pk_name: i, default_vector_name: [random.random() for _ in range(ct.default_dim)]} + for i in range(async_default_nb)] + start_time = time.time() + tasks = [] + step = 1000 + for i in range(0, async_default_nb, step): + task = self.async_milvus_client_wrap.insert(c_name, rows[i:i + step]) + tasks.append(task) + insert_res = await asyncio.gather(*tasks) + end_time = time.time() + log.info("Total time: {:.2f} seconds".format(end_time - start_time)) + for r in insert_res: + assert r[0]['insert_count'] == step + + # dql tasks + tasks = [] + # search default + vector = cf.gen_vectors(ct.default_nq, ct.default_dim) + default_search_task = self.async_milvus_client_wrap.search(c_name, vector, limit=ct.default_limit, + check_task=CheckTasks.check_search_results, + check_items={"enable_milvus_client_api": True, + "nq": ct.default_nq, + "limit": ct.default_limit}) + tasks.append(default_search_task) + + # query with filter and default output "*" + exp_query_res = [{default_pk_name: i} for i in range(ct.default_limit)] + filter_query_task = self.async_milvus_client_wrap.query(c_name, + filter=f"{default_pk_name} < {ct.default_limit}", + output_fields=[default_pk_name], + check_task=CheckTasks.check_query_results, + check_items={"exp_res": 
@pytest.mark.tags(CaseLabel.L0)
async def test_async_client_close(self):
    """Check that a closed async client refuses further requests.

    Steps: initialise the async client, create a collection, close the
    client, then search. The search is expected to fail with error code 1
    and a message containing "should create connection first".
    """
    # init async client
    uri = cf.param_info.param_uri or f"http://{cf.param_info.param_host}:{cf.param_info.param_port}"
    self.async_milvus_client_wrap.init_async_client(uri)

    # create collection
    c_name = cf.gen_unique_str(prefix)
    await self.async_milvus_client_wrap.create_collection(c_name, dimension=ct.default_dim)

    # close -> search raise error
    await self.async_milvus_client_wrap.close()
    vector = cf.gen_vectors(1, ct.default_dim)
    error = {ct.err_code: 1, ct.err_msg: "should create connection first"}
    await self.async_milvus_client_wrap.search(c_name, vector, check_task=CheckTasks.err_res, check_items=error)


@pytest.mark.tags(CaseLabel.L3)
@pytest.mark.skip("connect with zilliz cloud")
async def test_async_client_with_token(self):
    """Run insert, search and query through a token-authenticated async client.

    Steps: connect the sync and async clients with the configured token,
    create a collection through the async client and confirm the sync client
    lists it, insert ``ct.default_nb`` rows in concurrent batches, then run a
    search and a filtered query concurrently.
    """
    # init client
    # The client returned by _connect() is not used below, so it is not bound.
    self._connect(enable_milvus_client_api=True)
    uri = cf.param_info.param_uri or f"http://{cf.param_info.param_host}:{cf.param_info.param_port}"
    token = cf.param_info.param_token
    milvus_client, _ = self.connection_wrap.MilvusClient(uri=uri, token=token)
    self.async_milvus_client_wrap.init_async_client(uri, token=token)

    # create collection
    c_name = cf.gen_unique_str(prefix)
    await self.async_milvus_client_wrap.create_collection(c_name, dimension=ct.default_dim)
    collections, _ = self.high_level_api_wrap.list_collections(milvus_client)
    assert c_name in collections

    # insert entities
    rows = [
        {default_pk_name: i, default_vector_name: [random.random() for _ in range(ct.default_dim)]}
        for i in range(ct.default_nb)]
    step = 1000
    batches = [rows[i:i + step] for i in range(0, ct.default_nb, step)]
    start_time = time.time()
    insert_res = await asyncio.gather(
        *(self.async_milvus_client_wrap.insert(c_name, batch) for batch in batches))
    end_time = time.time()
    log.info("Total time: {:.2f} seconds".format(end_time - start_time))
    # Each result is a (response, check_result) pair. Compare with the real
    # batch size so a final partial batch (default_nb not a multiple of step)
    # does not fail the test.
    for batch, r in zip(batches, insert_res):
        assert r[0]['insert_count'] == len(batch)

    # dql tasks
    tasks = []
    # search default
    vector = cf.gen_vectors(ct.default_nq, ct.default_dim)
    default_search_task = self.async_milvus_client_wrap.search(c_name, vector, limit=ct.default_limit,
                                                               check_task=CheckTasks.check_search_results,
                                                               check_items={"enable_milvus_client_api": True,
                                                                            "nq": ct.default_nq,
                                                                            "limit": ct.default_limit})
    tasks.append(default_search_task)

    # query with filter, returning only the primary key field
    exp_query_res = [{default_pk_name: i} for i in range(ct.default_limit)]
    filter_query_task = self.async_milvus_client_wrap.query(c_name,
                                                            filter=f"{default_pk_name} < {ct.default_limit}",
                                                            output_fields=[default_pk_name],
                                                            check_task=CheckTasks.check_query_results,
                                                            check_items={"exp_res": exp_query_res,
                                                                         "primary_field": default_pk_name})
    tasks.append(filter_query_task)
    await asyncio.gather(*tasks)
def logger_interceptor():
    """Build a decorator that logs, awaits and validates an async API call.

    The decorated coroutine function is replaced by a coroutine that returns a
    ``(result, check_result)`` pair instead of raising when the call fails:

    * on success ``result`` is the awaited return value of the wrapped call;
    * on failure ``result`` is an ``Error`` built from the caught exception.

    ``check_result`` is whatever ``ResponseChecker(...).run()`` returns. An
    exception raised by the checker itself (for example a failed assertion)
    propagates to the caller and is not treated as an API failure.

    Keyword arguments read by the wrapper:
        enable_traceback: defaults to True. When falsy, the debug
            request/response logs and the traceback dump are skipped. It is
            never forwarded to the wrapped call.
        check_task, check_items: handed to ``ResponseChecker``. They are also
            left in the kwargs forwarded to the wrapped call.
    """
    from functools import wraps

    def _shorten(text):
        # Keep log lines bounded; long payloads are cut at log_row_length.
        return text[0:log_row_length] + '......' if len(text) > log_row_length else text

    def wrapper(func):
        # Report the decorated API's own name. sys._getframe() evaluated inside
        # ``handler`` would always yield "handler".
        api_name = func.__name__

        @wraps(func)
        async def handler(*args, **kwargs):
            # Read the flag from the caller's kwargs before it is stripped, so
            # it controls request and response logging too.
            verbose = kwargs.get("enable_traceback", True)
            check_task = kwargs.get("check_task", None)
            check_items = kwargs.get("check_items", None)
            _kwargs = copy.deepcopy(kwargs)
            _kwargs.pop("enable_traceback", None)

            if verbose:
                # args[0] is dropped from the log (presumably the bound
                # instance, since the decorator is applied to methods).
                log.debug("(api_request) : [%s] args: %s, kwargs: %s"
                          % (api_name, _shorten(str(args[1:])), str(_kwargs)))

            # Only the API call is guarded: a failure here is an API error.
            try:
                res = await func(*args, **_kwargs)
            except Exception as e:
                log.error(str(e))
                if verbose:
                    log.error(traceback.format_exc())
                log.error("(api_response) : %s" % _shorten(str(e)))
                error = Error(e)
                check_res = ResponseChecker(error, api_name, check_task, check_items, False).run()
                return error, check_res

            if verbose:
                log.debug("(api_response) : [%s] %s " % (api_name, _shorten(str(res))))
            # Run the check outside the try so a failed check surfaces as-is
            # instead of being re-checked as if the API call had failed.
            check_res = ResponseChecker(res, api_name, check_task, check_items, True).run()
            return res, check_res

        return handler

    return wrapper