Skip to content

Commit

Permalink
fix: upsert result use the previous pk (milvus-io#34672)
Browse files Browse the repository at this point in the history
milvus-io#34668

Signed-off-by: lixinguo <[email protected]>
Co-authored-by: lixinguo <[email protected]>
  • Loading branch information
smellthemoon and lixinguo authored Jul 31, 2024
1 parent acaa78d commit 6106a48
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 69 deletions.
2 changes: 1 addition & 1 deletion internal/proxy/task_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
// check primaryFieldData whether autoID is true or not
// set rowIDs as primary data if autoID == true
// TODO(dragondriver): in fact, NumRows is not trustable, we should check all input fields
it.result.IDs, err = checkPrimaryFieldData(it.schema, it.insertMsg, true)
it.result.IDs, err = checkPrimaryFieldData(it.schema, it.insertMsg)
log := log.Ctx(ctx).With(zap.String("collectionName", collectionName))
if err != nil {
log.Warn("check primary field data and hash primary key failed",
Expand Down
7 changes: 5 additions & 2 deletions internal/proxy/task_upsert.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ type upsertTask struct {
schema *schemaInfo
partitionKeyMode bool
partitionKeys *schemapb.FieldData
// automatic generate pk as new pk wehen autoID == true
// delete task need use the oldIds
oldIds *schemapb.IDs
}

// TraceCtx returns upsertTask context
Expand Down Expand Up @@ -187,7 +190,7 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error {
// use the passed pk as new pk when autoID == false
// automatic generate pk as new pk wehen autoID == true
var err error
it.result.IDs, err = checkPrimaryFieldData(it.schema.CollectionSchema, it.upsertMsg.InsertMsg, false)
it.result.IDs, it.oldIds, err = checkUpsertPrimaryFieldData(it.schema.CollectionSchema, it.upsertMsg.InsertMsg)
log := log.Ctx(ctx).With(zap.String("collectionName", it.upsertMsg.InsertMsg.CollectionName))
if err != nil {
log.Warn("check primary field data and hash primary key failed when upsert",
Expand Down Expand Up @@ -445,7 +448,7 @@ func (it *upsertTask) deleteExecute(ctx context.Context, msgPack *msgstream.MsgP
it.result.Status = merr.Status(err)
return err
}
it.upsertMsg.DeleteMsg.PrimaryKeys = it.result.IDs
it.upsertMsg.DeleteMsg.PrimaryKeys = it.oldIds
it.upsertMsg.DeleteMsg.HashValues = typeutil.HashPK2Channels(it.upsertMsg.DeleteMsg.PrimaryKeys, channelNames)

// repack delete msg by dmChannel
Expand Down
136 changes: 87 additions & 49 deletions internal/proxy/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -1188,15 +1188,15 @@ func checkFieldsDataBySchema(schema *schemapb.CollectionSchema, insertMsg *msgst
return nil
}

func checkPrimaryFieldData(schema *schemapb.CollectionSchema, insertMsg *msgstream.InsertMsg, inInsert bool) (*schemapb.IDs, error) {
func checkPrimaryFieldData(schema *schemapb.CollectionSchema, insertMsg *msgstream.InsertMsg) (*schemapb.IDs, error) {
log := log.With(zap.String("collectionName", insertMsg.CollectionName))
rowNums := uint32(insertMsg.NRows())
// TODO(dragondriver): in fact, NumRows is not trustable, we should check all input fields
if insertMsg.NRows() <= 0 {
return nil, merr.WrapErrParameterInvalid("invalid num_rows", fmt.Sprint(rowNums), "num_rows should be greater than 0")
}

if err := checkFieldsDataBySchema(schema, insertMsg, inInsert); err != nil {
if err := checkFieldsDataBySchema(schema, insertMsg, true); err != nil {
return nil, err
}

Expand All @@ -1208,60 +1208,33 @@ func checkPrimaryFieldData(schema *schemapb.CollectionSchema, insertMsg *msgstre
if primaryFieldSchema.GetNullable() {
return nil, merr.WrapErrParameterInvalidMsg("primary field not support null")
}
// get primaryFieldData whether autoID is true or not
var primaryFieldData *schemapb.FieldData
if inInsert {
// when checkPrimaryFieldData in insert
// when checkPrimaryFieldData in insert

skipAutoIDCheck := Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() &&
primaryFieldSchema.AutoID &&
typeutil.IsPrimaryFieldDataExist(insertMsg.GetFieldsData(), primaryFieldSchema)
skipAutoIDCheck := Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() &&
primaryFieldSchema.AutoID &&
typeutil.IsPrimaryFieldDataExist(insertMsg.GetFieldsData(), primaryFieldSchema)

if !primaryFieldSchema.AutoID || skipAutoIDCheck {
primaryFieldData, err = typeutil.GetPrimaryFieldData(insertMsg.GetFieldsData(), primaryFieldSchema)
if err != nil {
log.Info("get primary field data failed", zap.Error(err))
return nil, err
}
} else {
// check primary key data not exist
if typeutil.IsPrimaryFieldDataExist(insertMsg.GetFieldsData(), primaryFieldSchema) {
return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("can not assign primary field data when auto id enabled %v", primaryFieldSchema.Name))
}
// if autoID == true, currently support autoID for int64 and varchar PrimaryField
primaryFieldData, err = autoGenPrimaryFieldData(primaryFieldSchema, insertMsg.GetRowIDs())
if err != nil {
log.Info("generate primary field data failed when autoID == true", zap.Error(err))
return nil, err
}
// if autoID == true, set the primary field data
// insertMsg.fieldsData need append primaryFieldData
insertMsg.FieldsData = append(insertMsg.FieldsData, primaryFieldData)
if !primaryFieldSchema.AutoID || skipAutoIDCheck {
primaryFieldData, err = typeutil.GetPrimaryFieldData(insertMsg.GetFieldsData(), primaryFieldSchema)
if err != nil {
log.Info("get primary field data failed", zap.Error(err))
return nil, err
}
} else {
primaryFieldID := primaryFieldSchema.FieldID
primaryFieldName := primaryFieldSchema.Name
for i, field := range insertMsg.GetFieldsData() {
if field.FieldId == primaryFieldID || field.FieldName == primaryFieldName {
primaryFieldData = field
if primaryFieldSchema.AutoID {
// use the passed pk as new pk when autoID == false
// automatic generate pk as new pk wehen autoID == true
newPrimaryFieldData, err := autoGenPrimaryFieldData(primaryFieldSchema, insertMsg.GetRowIDs())
if err != nil {
log.Info("generate new primary field data failed when upsert", zap.Error(err))
return nil, err
}
insertMsg.FieldsData = append(insertMsg.GetFieldsData()[:i], insertMsg.GetFieldsData()[i+1:]...)
insertMsg.FieldsData = append(insertMsg.FieldsData, newPrimaryFieldData)
}
break
}
// check primary key data not exist
if typeutil.IsPrimaryFieldDataExist(insertMsg.GetFieldsData(), primaryFieldSchema) {
return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("can not assign primary field data when auto id enabled %v", primaryFieldSchema.Name))
}
// must assign primary field data when upsert
if primaryFieldData == nil {
return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("must assign pk when upsert, primary field: %v", primaryFieldName))
// if autoID == true, currently support autoID for int64 and varchar PrimaryField
primaryFieldData, err = autoGenPrimaryFieldData(primaryFieldSchema, insertMsg.GetRowIDs())
if err != nil {
log.Info("generate primary field data failed when autoID == true", zap.Error(err))
return nil, err
}
// if autoID == true, set the primary field data
// insertMsg.fieldsData need append primaryFieldData
insertMsg.FieldsData = append(insertMsg.FieldsData, primaryFieldData)
}

// parse primaryFieldData to result.IDs, and as returned primary keys
Expand All @@ -1274,6 +1247,71 @@ func checkPrimaryFieldData(schema *schemapb.CollectionSchema, insertMsg *msgstre
return ids, nil
}

func checkUpsertPrimaryFieldData(schema *schemapb.CollectionSchema, insertMsg *msgstream.InsertMsg) (*schemapb.IDs, *schemapb.IDs, error) {
log := log.With(zap.String("collectionName", insertMsg.CollectionName))
rowNums := uint32(insertMsg.NRows())
// TODO(dragondriver): in fact, NumRows is not trustable, we should check all input fields
if insertMsg.NRows() <= 0 {
return nil, nil, merr.WrapErrParameterInvalid("invalid num_rows", fmt.Sprint(rowNums), "num_rows should be greater than 0")
}

if err := checkFieldsDataBySchema(schema, insertMsg, false); err != nil {
return nil, nil, err
}

primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(schema)
if err != nil {
log.Error("get primary field schema failed", zap.Any("schema", schema), zap.Error(err))
return nil, nil, err
}
if primaryFieldSchema.GetNullable() {
return nil, nil, merr.WrapErrParameterInvalidMsg("primary field not support null")
}
// get primaryFieldData whether autoID is true or not
var primaryFieldData *schemapb.FieldData
var newPrimaryFieldData *schemapb.FieldData

primaryFieldID := primaryFieldSchema.FieldID
primaryFieldName := primaryFieldSchema.Name
for i, field := range insertMsg.GetFieldsData() {
if field.FieldId == primaryFieldID || field.FieldName == primaryFieldName {
primaryFieldData = field
if primaryFieldSchema.AutoID {
// use the passed pk as new pk when autoID == false
// automatic generate pk as new pk wehen autoID == true
newPrimaryFieldData, err = autoGenPrimaryFieldData(primaryFieldSchema, insertMsg.GetRowIDs())
if err != nil {
log.Info("generate new primary field data failed when upsert", zap.Error(err))
return nil, nil, err
}
insertMsg.FieldsData = append(insertMsg.GetFieldsData()[:i], insertMsg.GetFieldsData()[i+1:]...)
insertMsg.FieldsData = append(insertMsg.FieldsData, newPrimaryFieldData)
}
break
}
}
// must assign primary field data when upsert
if primaryFieldData == nil {
return nil, nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("must assign pk when upsert, primary field: %v", primaryFieldName))
}

// parse primaryFieldData to result.IDs, and as returned primary keys
ids, err := parsePrimaryFieldData2IDs(primaryFieldData)
if err != nil {
log.Warn("parse primary field data to IDs failed", zap.Error(err))
return nil, nil, err
}
if !primaryFieldSchema.GetAutoID() {
return ids, ids, nil
}
newIds, err := parsePrimaryFieldData2IDs(newPrimaryFieldData)
if err != nil {
log.Warn("parse primary field data to IDs failed", zap.Error(err))
return nil, nil, err
}
return newIds, ids, nil
}

func getPartitionKeyFieldData(fieldSchema *schemapb.FieldSchema, insertMsg *msgstream.InsertMsg) (*schemapb.FieldData, error) {
if len(insertMsg.GetPartitionName()) > 0 && !Params.ProxyCfg.SkipPartitionKeyCheck.GetAsBool() {
return nil, errors.New("not support manually specifying the partition names if partition key mode is used")
Expand Down
26 changes: 13 additions & 13 deletions internal/proxy/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1653,7 +1653,7 @@ func Test_InsertTaskCheckPrimaryFieldData(t *testing.T) {
},
}

_, err := checkPrimaryFieldData(case1.schema, case1.insertMsg, true)
_, err := checkPrimaryFieldData(case1.schema, case1.insertMsg)
assert.NotEqual(t, nil, err)

// the num of passed fields is less than needed
Expand Down Expand Up @@ -1694,7 +1694,7 @@ func Test_InsertTaskCheckPrimaryFieldData(t *testing.T) {
Status: merr.Success(),
},
}
_, err = checkPrimaryFieldData(case2.schema, case2.insertMsg, true)
_, err = checkPrimaryFieldData(case2.schema, case2.insertMsg)
assert.NotEqual(t, nil, err)

// autoID == false, no primary field schema
Expand Down Expand Up @@ -1734,7 +1734,7 @@ func Test_InsertTaskCheckPrimaryFieldData(t *testing.T) {
Status: merr.Success(),
},
}
_, err = checkPrimaryFieldData(case3.schema, case3.insertMsg, true)
_, err = checkPrimaryFieldData(case3.schema, case3.insertMsg)
assert.NotEqual(t, nil, err)

// autoID == true, has primary field schema, but primary field data exist
Expand Down Expand Up @@ -1781,15 +1781,15 @@ func Test_InsertTaskCheckPrimaryFieldData(t *testing.T) {
case4.schema.Fields[0].IsPrimaryKey = true
case4.schema.Fields[0].AutoID = true
case4.insertMsg.FieldsData[0] = newScalarFieldData(case4.schema.Fields[0], case4.schema.Fields[0].Name, 10)
_, err = checkPrimaryFieldData(case4.schema, case4.insertMsg, true)
_, err = checkPrimaryFieldData(case4.schema, case4.insertMsg)
assert.NotEqual(t, nil, err)

// autoID == true, has primary field schema, but DataType don't match
// the data type of the data not matches the schema
case4.schema.Fields[0].IsPrimaryKey = false
case4.schema.Fields[1].IsPrimaryKey = true
case4.schema.Fields[1].AutoID = true
_, err = checkPrimaryFieldData(case4.schema, case4.insertMsg, true)
_, err = checkPrimaryFieldData(case4.schema, case4.insertMsg)
assert.NotEqual(t, nil, err)
}

Expand Down Expand Up @@ -1817,7 +1817,7 @@ func Test_UpsertTaskCheckPrimaryFieldData(t *testing.T) {
Status: merr.Success(),
},
}
_, err := checkPrimaryFieldData(task.schema, task.insertMsg, false)
_, _, err := checkUpsertPrimaryFieldData(task.schema, task.insertMsg)
assert.NotEqual(t, nil, err)
})

Expand Down Expand Up @@ -1861,7 +1861,7 @@ func Test_UpsertTaskCheckPrimaryFieldData(t *testing.T) {
Status: merr.Success(),
},
}
_, err := checkPrimaryFieldData(task.schema, task.insertMsg, false)
_, _, err := checkUpsertPrimaryFieldData(task.schema, task.insertMsg)
assert.NotEqual(t, nil, err)
})

Expand Down Expand Up @@ -1902,7 +1902,7 @@ func Test_UpsertTaskCheckPrimaryFieldData(t *testing.T) {
Status: merr.Success(),
},
}
_, err := checkPrimaryFieldData(task.schema, task.insertMsg, false)
_, _, err := checkUpsertPrimaryFieldData(task.schema, task.insertMsg)
assert.NotEqual(t, nil, err)
})

Expand Down Expand Up @@ -1947,7 +1947,7 @@ func Test_UpsertTaskCheckPrimaryFieldData(t *testing.T) {
Status: merr.Success(),
},
}
_, err := checkPrimaryFieldData(task.schema, task.insertMsg, false)
_, _, err := checkUpsertPrimaryFieldData(task.schema, task.insertMsg)
assert.NotEqual(t, nil, err)
})

Expand Down Expand Up @@ -1998,7 +1998,7 @@ func Test_UpsertTaskCheckPrimaryFieldData(t *testing.T) {
Status: merr.Success(),
},
}
_, err := checkPrimaryFieldData(task.schema, task.insertMsg, false)
_, _, err := checkUpsertPrimaryFieldData(task.schema, task.insertMsg)
assert.NotEqual(t, nil, err)
})

Expand Down Expand Up @@ -2039,7 +2039,7 @@ func Test_UpsertTaskCheckPrimaryFieldData(t *testing.T) {
Status: merr.Success(),
},
}
_, err := checkPrimaryFieldData(task.schema, task.insertMsg, false)
_, _, err := checkUpsertPrimaryFieldData(task.schema, task.insertMsg)
assert.NoError(t, nil, err)

// autoid==false
Expand Down Expand Up @@ -2078,7 +2078,7 @@ func Test_UpsertTaskCheckPrimaryFieldData(t *testing.T) {
Status: merr.Success(),
},
}
_, err = checkPrimaryFieldData(task.schema, task.insertMsg, false)
_, _, err = checkUpsertPrimaryFieldData(task.schema, task.insertMsg)
assert.NoError(t, nil, err)
})

Expand Down Expand Up @@ -2129,7 +2129,7 @@ func Test_UpsertTaskCheckPrimaryFieldData(t *testing.T) {
Status: merr.Success(),
},
}
_, err := checkPrimaryFieldData(task.schema, task.insertMsg, false)
_, _, err := checkUpsertPrimaryFieldData(task.schema, task.insertMsg)
newPK := task.insertMsg.FieldsData[0].GetScalars().GetLongData().GetData()
assert.Equal(t, newPK, task.insertMsg.RowIDs)
assert.NoError(t, nil, err)
Expand Down
Loading

0 comments on commit 6106a48

Please sign in to comment.