From 70b3f9f457a74661eb846b5f0d33912a879defce Mon Sep 17 00:00:00 2001
From: Buqian Zheng
Date: Fri, 18 Oct 2024 17:34:23 +0800
Subject: [PATCH] simplified the logic to check if the insert/request data
 matches the schema

Signed-off-by: Buqian Zheng
---
 pymilvus/client/constants.py |   1 +
 pymilvus/client/prepare.py   | 104 ++++++++++++++++-------------------
 2 files changed, 48 insertions(+), 57 deletions(-)

diff --git a/pymilvus/client/constants.py b/pymilvus/client/constants.py
index 1aed48898..d9e5aa715 100644
--- a/pymilvus/client/constants.py
+++ b/pymilvus/client/constants.py
@@ -8,6 +8,7 @@
 BOUNDED_TS = 2
 DEFAULT_CONSISTENCY_LEVEL = ConsistencyLevel.Bounded
 DEFAULT_RESOURCE_GROUP = "__default_resource_group"
+DYNAMIC_FIELD_NAME = "$meta"
 REDUCE_STOP_FOR_BEST = "reduce_stop_for_best"
 GROUP_BY_FIELD = "group_by_field"
 GROUP_SIZE = "group_size"
diff --git a/pymilvus/client/prepare.py b/pymilvus/client/prepare.py
index 09a02a316..b368a233f 100644
--- a/pymilvus/client/prepare.py
+++ b/pymilvus/client/prepare.py
@@ -15,6 +15,7 @@
 from .check import check_pass_param, is_legal_collection_properties
 from .constants import (
     DEFAULT_CONSISTENCY_LEVEL,
+    DYNAMIC_FIELD_NAME,
     GROUP_BY_FIELD,
     GROUP_SIZE,
     GROUP_STRICT_SIZE,
@@ -29,7 +30,7 @@
     ResourceGroupConfig,
     get_consistency_level,
 )
-from .utils import traverse_info, traverse_rows_info, traverse_upsert_info
+from .utils import traverse_info, traverse_upsert_info
 
 
 class Prepare:
@@ -386,29 +387,35 @@ def partition_name(cls, collection_name: str, partition_name: str):
         return milvus_types.PartitionName(collection_name=collection_name, tag=partition_name)
 
     @staticmethod
-    def _num_input_fields(fields_info: List[Dict]):
-        return len(fields_info) - len(
-            [field for field in fields_info if field.get("is_function_output", False)]
+    def _is_input_field(field: Dict, is_upsert: bool):
+        return (not field.get("auto_id", False) or is_upsert) and not field.get(
+            "is_function_output", False
         )
 
+    @staticmethod
+    def _num_input_fields(fields_info: List[Dict], is_upsert: bool):
+        return len([field for field in fields_info if Prepare._is_input_field(field, is_upsert)])
+
     @staticmethod
     def _parse_row_request(
         request: Union[milvus_types.InsertRequest, milvus_types.UpsertRequest],
-        fields_info: dict,
+        fields_info: List[Dict],
         enable_dynamic: bool,
         entities: List,
     ):
+        input_fields_info = [
+            field for field in fields_info if Prepare._is_input_field(field, is_upsert=False)
+        ]
         fields_data = {
             field["name"]: schema_types.FieldData(field_name=field["name"], type=field["type"])
-            for field in fields_info
-            if not field.get("auto_id", False)
-        }
-        field_info_map = {
-            field["name"]: field for field in fields_info if not field.get("auto_id", False)
+            for field in input_fields_info
         }
+        field_info_map = {field["name"]: field for field in input_fields_info}
 
         if enable_dynamic:
-            d_field = schema_types.FieldData(is_dynamic=True, type=DataType.JSON)
+            d_field = schema_types.FieldData(
+                field_name=DYNAMIC_FIELD_NAME, is_dynamic=True, type=DataType.JSON
+            )
             fields_data[d_field.field_name] = d_field
             field_info_map[d_field.field_name] = d_field
 
@@ -430,13 +437,9 @@ def _parse_row_request(
                 ):
                     field_data.valid_data.append(v is not None)
                 entity_helper.pack_field_value_to_field_data(v, field_data, field_info)
-            for field in fields_info:
+            for field in input_fields_info:
                 key = field["name"]
-                if (
-                    key in entity
-                    or field.get("auto_id", False)
-                    or field.get("is_function_output", False)
-                ):
+                if key in entity:
                     continue
 
                 field_info, field_data = field_info_map[key], fields_data[key]
@@ -458,44 +461,35 @@ def _parse_row_request(
         except (TypeError, ValueError) as e:
             raise DataNotMatchException(message=ExceptionsMessage.DataTypeInconsistent) from e
 
-        request.fields_data.extend(
-            [
-                fields_data[field["name"]]
-                for field in fields_info
-                if not field.get("is_function_output", False) and not field.get("auto_id", False)
-            ]
-        )
-
-        if enable_dynamic:
-            request.fields_data.append(d_field)
+        request.fields_data.extend(fields_data.values())
 
-        num_input_fields = Prepare._num_input_fields(fields_info)
+        expected_num_input_fields = len(input_fields_info) + (1 if enable_dynamic else 0)
 
-        _, _, auto_id_loc = traverse_rows_info(fields_info, entities)
-        if auto_id_loc is not None:
-            if (enable_dynamic and len(fields_data) != num_input_fields) or (
-                not enable_dynamic and len(fields_data) + 1 != num_input_fields
-            ):
-                raise ParamError(ExceptionsMessage.FieldsNumInconsistent)
-        elif enable_dynamic and len(fields_data) != num_input_fields + 1:
+        if len(fields_data) != expected_num_input_fields:
             raise ParamError(ExceptionsMessage.FieldsNumInconsistent)
+
         return request
 
     @staticmethod
     def _parse_upsert_row_request(
         request: Union[milvus_types.InsertRequest, milvus_types.UpsertRequest],
-        fields_info: dict,
+        fields_info: List[Dict],
         enable_dynamic: bool,
         entities: List,
     ):
+        input_fields_info = [
+            field for field in fields_info if Prepare._is_input_field(field, is_upsert=True)
+        ]
         fields_data = {
             field["name"]: schema_types.FieldData(field_name=field["name"], type=field["type"])
-            for field in fields_info
+            for field in input_fields_info
         }
-        field_info_map = {field["name"]: field for field in fields_info}
+        field_info_map = {field["name"]: field for field in input_fields_info}
 
         if enable_dynamic:
-            d_field = schema_types.FieldData(is_dynamic=True, type=DataType.JSON)
+            d_field = schema_types.FieldData(
+                field_name=DYNAMIC_FIELD_NAME, is_dynamic=True, type=DataType.JSON
+            )
             fields_data[d_field.field_name] = d_field
             field_info_map[d_field.field_name] = d_field
 
@@ -519,7 +513,7 @@ def _parse_upsert_row_request(
                 entity_helper.pack_field_value_to_field_data(v, field_data, field_info)
             for field in fields_info:
                 key = field["name"]
-                if key in entity or field.get("is_function_output", False):
+                if key in entity:
                     continue
 
                 field_info, field_data = field_info_map[key], fields_data[key]
@@ -541,10 +535,7 @@ def _parse_upsert_row_request(
         except (TypeError, ValueError) as e:
             raise DataNotMatchException(message=ExceptionsMessage.DataTypeInconsistent) from e
 
-        request.fields_data.extend([fields_data[field["name"]] for field in fields_info])
-
-        if enable_dynamic:
-            request.fields_data.append(d_field)
+        request.fields_data.extend(fields_data.values())
 
         for _, field in enumerate(fields_info):
             is_dynamic = False
@@ -558,10 +549,12 @@ def _parse_upsert_row_request(
                     raise ParamError(
                         message=f"dynamic field enabled, {field_name} shouldn't in entities[{j}]"
                     )
-        if (enable_dynamic and len(fields_data) != len(fields_info) + 1) or (
-            not enable_dynamic and len(fields_data) != len(fields_info)
-        ):
+
+        expected_num_input_fields = len(input_fields_info) + (1 if enable_dynamic else 0)
+
+        if len(fields_data) != expected_num_input_fields:
             raise ParamError(ExceptionsMessage.FieldsNumInconsistent)
+
         return request
 
     @classmethod
@@ -625,21 +618,18 @@ def _pre_insert_batch_check(
         if not fields_info:
             raise ParamError(message="Missing collection meta to validate entities")
 
-        location, primary_key_loc, auto_id_loc = traverse_info(fields_info)
+        location, primary_key_loc, _ = traverse_info(fields_info)
 
         # though impossible from sdk
         if primary_key_loc is None:
             raise ParamError(message="primary key not found")
 
-        num_input_fields = Prepare._num_input_fields(fields_info)
+        expected_num_input_fields = Prepare._num_input_fields(fields_info, is_upsert=False)
 
-        if auto_id_loc is None and len(entities) != num_input_fields:
-            msg = f"number of fields: {num_input_fields}, number of entities: {len(entities)}"
+        if len(entities) != expected_num_input_fields:
+            msg = f"expected number of fields: {expected_num_input_fields}, actual number of fields in entities: {len(entities)}"
             raise ParamError(msg)
 
-        if auto_id_loc is not None and len(entities) + 1 != num_input_fields:
-            msg = f"number of fields: {num_input_fields}, number of entities: {len(entities)}"
-            raise ParamError(msg)
         return location
 
     @staticmethod
@@ -665,10 +655,10 @@ def _pre_upsert_batch_check(
         if primary_key_loc is None:
             raise ParamError(message="primary key not found")
 
-        num_input_fields = Prepare._num_input_fields(fields_info)
+        expected_num_input_fields = Prepare._num_input_fields(fields_info, is_upsert=True)
 
-        if len(entities) != num_input_fields:
-            msg = f"number of fields: {num_input_fields}, number of entities: {len(entities)}"
+        if len(entities) != expected_num_input_fields:
+            msg = f"expected number of fields: {expected_num_input_fields}, actual number of fields in entities: {len(entities)}"
             raise ParamError(msg)
         return location
 
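
Note (illustrative sketch, not part of the patch): the rule the patch converges on is a single predicate. A field must be supplied by the caller unless it is an auto_id primary key on insert (it must still be supplied on upsert) or a function-output field, and one extra "$meta" field is expected when dynamic schema is enabled. The toy script below restates that rule outside of pymilvus; the field dicts are simplified stand-ins for the SDK's fields_info metadata and the names here are hypothetical, not library API.

from typing import Dict, List


def is_input_field(field: Dict, is_upsert: bool) -> bool:
    # Mirrors Prepare._is_input_field from the patch: auto_id primary keys are
    # generated server-side on insert but must be given on upsert; function-output
    # fields are never supplied by the caller.
    return (not field.get("auto_id", False) or is_upsert) and not field.get(
        "is_function_output", False
    )


def expected_num_input_fields(fields_info: List[Dict], is_upsert: bool, enable_dynamic: bool) -> int:
    # Row-based requests carry one extra dynamic "$meta" field when dynamic
    # schema is enabled, hence the +1.
    return sum(is_input_field(f, is_upsert) for f in fields_info) + (1 if enable_dynamic else 0)


# Toy schema: an auto_id primary key, a regular vector field, and a field
# produced by a server-side function (e.g. a BM25 output field).
schema = [
    {"name": "id", "auto_id": True},
    {"name": "embedding"},
    {"name": "sparse", "is_function_output": True},
]

assert expected_num_input_fields(schema, is_upsert=False, enable_dynamic=False) == 1  # insert
assert expected_num_input_fields(schema, is_upsert=True, enable_dynamic=False) == 2   # upsert
assert expected_num_input_fields(schema, is_upsert=False, enable_dynamic=True) == 2   # insert + $meta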