Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

fix: simplified the logic to check if the insert/request data matches the schema #2303

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pymilvus/client/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
BOUNDED_TS = 2
DEFAULT_CONSISTENCY_LEVEL = ConsistencyLevel.Bounded
DEFAULT_RESOURCE_GROUP = "__default_resource_group"
DYNAMIC_FIELD_NAME = "$meta"
REDUCE_STOP_FOR_BEST = "reduce_stop_for_best"
GROUP_BY_FIELD = "group_by_field"
GROUP_SIZE = "group_size"
Expand Down
104 changes: 47 additions & 57 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .check import check_pass_param, is_legal_collection_properties
from .constants import (
DEFAULT_CONSISTENCY_LEVEL,
DYNAMIC_FIELD_NAME,
GROUP_BY_FIELD,
GROUP_SIZE,
GROUP_STRICT_SIZE,
Expand All @@ -29,7 +30,7 @@
ResourceGroupConfig,
get_consistency_level,
)
from .utils import traverse_info, traverse_rows_info, traverse_upsert_info
from .utils import traverse_info, traverse_upsert_info


class Prepare:
Expand Down Expand Up @@ -386,29 +387,35 @@ def partition_name(cls, collection_name: str, partition_name: str):
return milvus_types.PartitionName(collection_name=collection_name, tag=partition_name)

@staticmethod
def _num_input_fields(fields_info: List[Dict]):
return len(fields_info) - len(
[field for field in fields_info if field.get("is_function_output", False)]
def _is_input_field(field: Dict, is_upsert: bool):
return (not field.get("auto_id", False) or is_upsert) and not field.get(
"is_function_output", False
)

@staticmethod
def _num_input_fields(fields_info: List[Dict], is_upsert: bool):
return len([field for field in fields_info if Prepare._is_input_field(field, is_upsert)])

@staticmethod
def _parse_row_request(
request: Union[milvus_types.InsertRequest, milvus_types.UpsertRequest],
fields_info: dict,
fields_info: List[Dict],
enable_dynamic: bool,
entities: List,
):
input_fields_info = [
field for field in fields_info if Prepare._is_input_field(field, is_upsert=False)
]
fields_data = {
field["name"]: schema_types.FieldData(field_name=field["name"], type=field["type"])
for field in fields_info
if not field.get("auto_id", False)
}
field_info_map = {
field["name"]: field for field in fields_info if not field.get("auto_id", False)
for field in input_fields_info
}
field_info_map = {field["name"]: field for field in input_fields_info}

if enable_dynamic:
d_field = schema_types.FieldData(is_dynamic=True, type=DataType.JSON)
d_field = schema_types.FieldData(
field_name=DYNAMIC_FIELD_NAME, is_dynamic=True, type=DataType.JSON
)
fields_data[d_field.field_name] = d_field
field_info_map[d_field.field_name] = d_field

Expand All @@ -430,13 +437,9 @@ def _parse_row_request(
):
field_data.valid_data.append(v is not None)
entity_helper.pack_field_value_to_field_data(v, field_data, field_info)
for field in fields_info:
for field in input_fields_info:
key = field["name"]
if (
key in entity
or field.get("auto_id", False)
or field.get("is_function_output", False)
):
if key in entity:
continue

field_info, field_data = field_info_map[key], fields_data[key]
Expand All @@ -458,44 +461,35 @@ def _parse_row_request(
except (TypeError, ValueError) as e:
raise DataNotMatchException(message=ExceptionsMessage.DataTypeInconsistent) from e

request.fields_data.extend(
[
fields_data[field["name"]]
for field in fields_info
if not field.get("is_function_output", False) and not field.get("auto_id", False)
]
)

if enable_dynamic:
request.fields_data.append(d_field)
request.fields_data.extend(fields_data.values())

num_input_fields = Prepare._num_input_fields(fields_info)
expected_num_input_fields = len(input_fields_info) + (1 if enable_dynamic else 0)

_, _, auto_id_loc = traverse_rows_info(fields_info, entities)
if auto_id_loc is not None:
if (enable_dynamic and len(fields_data) != num_input_fields) or (
not enable_dynamic and len(fields_data) + 1 != num_input_fields
):
raise ParamError(ExceptionsMessage.FieldsNumInconsistent)
elif enable_dynamic and len(fields_data) != num_input_fields + 1:
if len(fields_data) != expected_num_input_fields:
raise ParamError(ExceptionsMessage.FieldsNumInconsistent)

return request

@staticmethod
def _parse_upsert_row_request(
request: Union[milvus_types.InsertRequest, milvus_types.UpsertRequest],
fields_info: dict,
fields_info: List[Dict],
enable_dynamic: bool,
entities: List,
):
input_fields_info = [
field for field in fields_info if Prepare._is_input_field(field, is_upsert=True)
]
fields_data = {
field["name"]: schema_types.FieldData(field_name=field["name"], type=field["type"])
for field in fields_info
for field in input_fields_info
}
field_info_map = {field["name"]: field for field in fields_info}
field_info_map = {field["name"]: field for field in input_fields_info}

if enable_dynamic:
d_field = schema_types.FieldData(is_dynamic=True, type=DataType.JSON)
d_field = schema_types.FieldData(
field_name=DYNAMIC_FIELD_NAME, is_dynamic=True, type=DataType.JSON
)
fields_data[d_field.field_name] = d_field
field_info_map[d_field.field_name] = d_field

Expand All @@ -519,7 +513,7 @@ def _parse_upsert_row_request(
entity_helper.pack_field_value_to_field_data(v, field_data, field_info)
for field in fields_info:
key = field["name"]
if key in entity or field.get("is_function_output", False):
if key in entity:
continue

field_info, field_data = field_info_map[key], fields_data[key]
Expand All @@ -541,10 +535,7 @@ def _parse_upsert_row_request(
except (TypeError, ValueError) as e:
raise DataNotMatchException(message=ExceptionsMessage.DataTypeInconsistent) from e

request.fields_data.extend([fields_data[field["name"]] for field in fields_info])

if enable_dynamic:
request.fields_data.append(d_field)
request.fields_data.extend(fields_data.values())

for _, field in enumerate(fields_info):
is_dynamic = False
Expand All @@ -558,10 +549,12 @@ def _parse_upsert_row_request(
raise ParamError(
message=f"dynamic field enabled, {field_name} shouldn't in entities[{j}]"
)
if (enable_dynamic and len(fields_data) != len(fields_info) + 1) or (
not enable_dynamic and len(fields_data) != len(fields_info)
):

expected_num_input_fields = len(input_fields_info) + (1 if enable_dynamic else 0)

if len(fields_data) != expected_num_input_fields:
raise ParamError(ExceptionsMessage.FieldsNumInconsistent)

return request

@classmethod
Expand Down Expand Up @@ -625,21 +618,18 @@ def _pre_insert_batch_check(
if not fields_info:
raise ParamError(message="Missing collection meta to validate entities")

location, primary_key_loc, auto_id_loc = traverse_info(fields_info)
location, primary_key_loc, _ = traverse_info(fields_info)

# Should be unreachable through the SDK, but guard against collection meta that has no primary key.
if primary_key_loc is None:
raise ParamError(message="primary key not found")

num_input_fields = Prepare._num_input_fields(fields_info)
expected_num_input_fields = Prepare._num_input_fields(fields_info, is_upsert=False)

if auto_id_loc is None and len(entities) != num_input_fields:
msg = f"number of fields: {num_input_fields}, number of entities: {len(entities)}"
if len(entities) != expected_num_input_fields:
msg = f"expected number of fields: {expected_num_input_fields}, actual number of fields in entities: {len(entities)}"
raise ParamError(msg)

if auto_id_loc is not None and len(entities) + 1 != num_input_fields:
msg = f"number of fields: {num_input_fields}, number of entities: {len(entities)}"
raise ParamError(msg)
return location

@staticmethod
Expand All @@ -665,10 +655,10 @@ def _pre_upsert_batch_check(
if primary_key_loc is None:
raise ParamError(message="primary key not found")

num_input_fields = Prepare._num_input_fields(fields_info)
expected_num_input_fields = Prepare._num_input_fields(fields_info, is_upsert=True)

if len(entities) != num_input_fields:
msg = f"number of fields: {num_input_fields}, number of entities: {len(entities)}"
if len(entities) != expected_num_input_fields:
msg = f"expected number of fields: {expected_num_input_fields}, actual number of fields in entities: {len(entities)}"
raise ParamError(msg)
return location

Expand Down