Skip to content

Commit

Permalink
fix: not append valid data when transfer to insert record (milvus-io#…
Browse files Browse the repository at this point in the history
…36027)

fix not append valid data when transfer to insert record and add a tiny
check when in groupBy field.
milvus-io#35924

Signed-off-by: lixinguo <[email protected]>
Co-authored-by: lixinguo <[email protected]>
  • Loading branch information
smellthemoon and lixinguo authored Sep 6, 2024
1 parent 5247631 commit 21b135c
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 10 deletions.
15 changes: 6 additions & 9 deletions internal/core/src/segcore/InsertRecord.h
Original file line number Diff line number Diff line change
Expand Up @@ -293,14 +293,15 @@ class ThreadSafeValidData {
total += field_data->get_num_rows();
}
if (length_ + total > data_.size()) {
data_.reserve(length_ + total);
data_.resize(length_ + total);
}
length_ += total;

for (auto& field_data : datas) {
auto num_row = field_data->get_num_rows();
for (size_t i = 0; i < num_row; i++) {
data_.push_back(field_data->is_valid(i));
data_[length_ + i] = field_data->is_valid(i);
}
length_ += num_row;
}
}

Expand All @@ -311,14 +312,10 @@ class ThreadSafeValidData {
std::unique_lock<std::shared_mutex> lck(mutex_);
if (field_meta.is_nullable()) {
if (length_ + num_rows > data_.size()) {
data_.reserve(length_ + num_rows);
data_.resize(length_ + num_rows);
}

auto src = data->valid_data().data();
for (size_t i = 0; i < num_rows; ++i) {
data_.push_back(src[i]);
// data_[length_ + i] = src[i];
}
std::copy_n(src, num_rows, data_.data() + length_);
length_ += num_rows;
}
}
Expand Down
3 changes: 3 additions & 0 deletions internal/proxy/search_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
if groupByFieldName != "" {
fields := schema.GetFields()
for _, field := range fields {
if field.GetNullable() {
return nil, 0, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("groupBy field(%s) not support nullable == true", groupByFieldName))
}
if field.Name == groupByFieldName {
groupByFieldId = field.FieldID
break
Expand Down
19 changes: 19 additions & 0 deletions internal/proxy/task_search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2188,6 +2188,25 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
assert.Nil(t, info)
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
})
t.Run("check nullable and groupBy", func(t *testing.T) {
normalParam := getValidSearchParams()
normalParam = append(normalParam, &commonpb.KeyValuePair{
Key: GroupByFieldKey,
Value: "string_field",
})
fields := make([]*schemapb.FieldSchema, 0)
fields = append(fields, &schemapb.FieldSchema{
FieldID: int64(101),
Name: "string_field",
Nullable: true,
})
schema := &schemapb.CollectionSchema{
Fields: fields,
}
info, _, err := parseSearchInfo(normalParam, schema, false)
assert.Nil(t, info)
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
})
t.Run("check iterator and topK", func(t *testing.T) {
normalParam := getValidSearchParams()
normalParam = append(normalParam, &commonpb.KeyValuePair{
Expand Down
10 changes: 10 additions & 0 deletions internal/storage/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -1040,6 +1040,7 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert
},
},
},
ValidData: rawData.ValidData,
}
case *Int8FieldData:
int32Data := make([]int32, len(rawData.Data))
Expand All @@ -1058,6 +1059,7 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert
},
},
},
ValidData: rawData.ValidData,
}
case *Int16FieldData:
int32Data := make([]int32, len(rawData.Data))
Expand All @@ -1076,6 +1078,7 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert
},
},
},
ValidData: rawData.ValidData,
}
case *Int32FieldData:
fieldData = &schemapb.FieldData{
Expand All @@ -1090,6 +1093,7 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert
},
},
},
ValidData: rawData.ValidData,
}
case *Int64FieldData:
fieldData = &schemapb.FieldData{
Expand All @@ -1104,6 +1108,7 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert
},
},
},
ValidData: rawData.ValidData,
}
case *FloatFieldData:
fieldData = &schemapb.FieldData{
Expand All @@ -1118,6 +1123,7 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert
},
},
},
ValidData: rawData.ValidData,
}
case *DoubleFieldData:
fieldData = &schemapb.FieldData{
Expand All @@ -1132,6 +1138,7 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert
},
},
},
ValidData: rawData.ValidData,
}
case *StringFieldData:
fieldData = &schemapb.FieldData{
Expand All @@ -1146,6 +1153,7 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert
},
},
},
ValidData: rawData.ValidData,
}
case *ArrayFieldData:
fieldData = &schemapb.FieldData{
Expand All @@ -1160,6 +1168,7 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert
},
},
},
ValidData: rawData.ValidData,
}
case *JSONFieldData:
fieldData = &schemapb.FieldData{
Expand All @@ -1174,6 +1183,7 @@ func TransferInsertDataToInsertRecord(insertData *InsertData) (*segcorepb.Insert
},
},
},
ValidData: rawData.ValidData,
}
case *FloatVectorFieldData:
fieldData = &schemapb.FieldData{
Expand Down
4 changes: 3 additions & 1 deletion pkg/util/typeutil/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -932,7 +932,9 @@ func MergeFieldData(dst []*schemapb.FieldData, src []*schemapb.FieldData) error
dst = append(dst, scalarFieldData)
fieldID2Data[srcFieldData.FieldId] = scalarFieldData
}
dstScalar := fieldID2Data[srcFieldData.FieldId].GetScalars()
fieldData := fieldID2Data[srcFieldData.FieldId]
fieldData.ValidData = append(fieldData.ValidData, srcFieldData.GetValidData()...)
dstScalar := fieldData.GetScalars()
switch srcScalar := fieldType.Scalars.Data.(type) {
case *schemapb.ScalarField_BoolData:
if dstScalar.GetBoolData() == nil {
Expand Down
132 changes: 132 additions & 0 deletions tests/integration/null_data/null_data_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,138 @@ func (s *NullDataSuite) run() {
s.Equal(commonpb.ErrorCode_Success, queryResult.GetStatus().GetErrorCode())
s.checkNullableFieldData(nullableFid.GetName(), queryResult.GetFieldsData(), start)

fieldsData[2] = integration.NewInt64FieldDataNullableWithStart(nullableFid.GetName(), rowNum, start)
fieldsDataForUpsert := make([]*schemapb.FieldData, 0)
fieldsDataForUpsert = append(fieldsDataForUpsert, integration.NewInt64FieldDataWithStart(integration.Int64Field, rowNum, start))
fieldsDataForUpsert = append(fieldsDataForUpsert, fVecColumn)
nullableFidDataForUpsert := &schemapb.FieldData{
Type: schemapb.DataType_Int64,
FieldName: nullableFid.GetName(),
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: []int64{},
},
},
},
},
ValidData: make([]bool, rowNum),
}
fieldsDataForUpsert = append(fieldsDataForUpsert, nullableFidDataForUpsert)
insertResult, err = c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
DbName: dbName,
CollectionName: collectionName,
FieldsData: fieldsData,
HashKeys: hashKeys,
NumRows: uint32(rowNum),
})
s.NoError(err)
s.Equal(insertResult.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)

upsertResult, err := c.Proxy.Upsert(ctx, &milvuspb.UpsertRequest{
DbName: dbName,
CollectionName: collectionName,
FieldsData: fieldsDataForUpsert,
HashKeys: hashKeys,
NumRows: uint32(rowNum),
})
s.NoError(err)
s.Equal(upsertResult.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)

// create index
createIndexStatus, err = c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
CollectionName: collectionName,
FieldName: fVecColumn.FieldName,
IndexName: "_default",
ExtraParams: integration.ConstructIndexParam(dim, s.indexType, s.metricType),
})
if createIndexStatus.GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("createIndexStatus fail reason", zap.String("reason", createIndexStatus.GetReason()))
}
s.NoError(err)
s.Equal(commonpb.ErrorCode_Success, createIndexStatus.GetErrorCode())

s.WaitForIndexBuilt(ctx, collectionName, fVecColumn.FieldName)

desCollResp, err = c.Proxy.DescribeCollection(ctx, &milvuspb.DescribeCollectionRequest{
CollectionName: collectionName,
})
s.NoError(err)
s.Equal(desCollResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)

compactResp, err = c.Proxy.ManualCompaction(ctx, &milvuspb.ManualCompactionRequest{
CollectionID: desCollResp.GetCollectionID(),
})

s.NoError(err)
s.Equal(compactResp.GetStatus().GetErrorCode(), commonpb.ErrorCode_Success)

compacted = func() bool {
resp, err := c.Proxy.GetCompactionState(ctx, &milvuspb.GetCompactionStateRequest{
CompactionID: compactResp.GetCompactionID(),
})
if err != nil {
return false
}
return resp.GetState() == commonpb.CompactionState_Completed
}
for !compacted() {
time.Sleep(3 * time.Second)
}

// load
loadStatus, err = c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
DbName: dbName,
CollectionName: collectionName,
})
s.NoError(err)
if loadStatus.GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("loadStatus fail reason", zap.String("reason", loadStatus.GetReason()))
}
s.Equal(commonpb.ErrorCode_Success, loadStatus.GetErrorCode())
s.WaitForLoad(ctx, collectionName)

// flush
flushResp, err = c.Proxy.Flush(ctx, &milvuspb.FlushRequest{
DbName: dbName,
CollectionNames: []string{collectionName},
})
s.NoError(err)
segmentIDs, has = flushResp.GetCollSegIDs()[collectionName]
ids = segmentIDs.GetData()
s.Require().NotEmpty(segmentIDs)
s.Require().True(has)
flushTs, has = flushResp.GetCollFlushTs()[collectionName]
s.True(has)

segments, err = c.MetaWatcher.ShowSegments()
s.NoError(err)
s.NotEmpty(segments)
for _, segment := range segments {
log.Info("ShowSegments result", zap.String("segment", segment.String()))
}
s.WaitForFlush(ctx, ids, flushTs, dbName, collectionName)

// search
searchResult, err = c.Proxy.Search(ctx, searchReq)
err = merr.CheckRPCCall(searchResult, err)
s.NoError(err)
s.checkNullableFieldData(nullableFid.GetName(), searchResult.GetResults().GetFieldsData(), start)

queryResult, err = c.Proxy.Query(ctx, &milvuspb.QueryRequest{
DbName: dbName,
CollectionName: collectionName,
Expr: expr,
OutputFields: []string{"nullableFid"},
})
if queryResult.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("searchResult fail reason", zap.String("reason", queryResult.GetStatus().GetReason()))
}
s.NoError(err)
s.Equal(commonpb.ErrorCode_Success, queryResult.GetStatus().GetErrorCode())
s.checkNullableFieldData(nullableFid.GetName(), queryResult.GetFieldsData(), start)

// // expr will not select null data
// exprResult, err := c.Proxy.Query(ctx, &milvuspb.QueryRequest{
// DbName: dbName,
Expand Down

0 comments on commit 21b135c

Please sign in to comment.