diff --git a/internal/rootcoord/alter_collection_task.go b/internal/rootcoord/alter_collection_task.go index dd875a8e5bb55..97dddc29312cf 100644 --- a/internal/rootcoord/alter_collection_task.go +++ b/internal/rootcoord/alter_collection_task.go @@ -66,39 +66,63 @@ func (a *alterCollectionTask) Execute(ctx context.Context) error { return err } - newColl := oldColl.Clone() + var newProperties []*commonpb.KeyValuePair if len(a.Req.Properties) > 0 { if ContainsKeyPairArray(a.Req.GetProperties(), oldColl.Properties) { log.Info("skip to alter collection due to no changes were detected in the properties", zap.Int64("collectionID", oldColl.CollectionID)) return nil } - newColl.Properties = MergeProperties(oldColl.Properties, a.Req.GetProperties()) + newProperties = MergeProperties(oldColl.Properties, a.Req.GetProperties()) } else if len(a.Req.DeleteKeys) > 0 { - newColl.Properties = DeleteProperties(oldColl.Properties, a.Req.GetDeleteKeys()) + newProperties = DeleteProperties(oldColl.Properties, a.Req.GetDeleteKeys()) } ts := a.GetTs() - redoTask := newBaseRedoTask(a.core.stepExecutor) + return executeAlterCollectionTaskSteps(ctx, a.core, oldColl, oldColl.Properties, newProperties, a.Req, ts) +} + +func (a *alterCollectionTask) GetLockerKey() LockerKey { + collection := a.core.getCollectionIDStr(a.ctx, a.Req.GetDbName(), a.Req.GetCollectionName(), a.Req.GetCollectionID()) + return NewLockerKeyChain( + NewClusterLockerKey(false), + NewDatabaseLockerKey(a.Req.GetDbName(), false), + NewCollectionLockerKey(collection, true), + ) +} + +func executeAlterCollectionTaskSteps(ctx context.Context, + core *Core, + col *model.Collection, + oldProperties []*commonpb.KeyValuePair, + newProperties []*commonpb.KeyValuePair, + request *milvuspb.AlterCollectionRequest, + ts Timestamp, +) error { + oldColl := col.Clone() + oldColl.Properties = oldProperties + newColl := col.Clone() + newColl.Properties = newProperties + redoTask := newBaseRedoTask(core.stepExecutor) redoTask.AddSyncStep(&AlterCollectionStep{ - baseStep: baseStep{core: a.core}, + baseStep: baseStep{core: core}, oldColl: oldColl, newColl: newColl, ts: ts, }) - a.Req.CollectionID = oldColl.CollectionID + request.CollectionID = oldColl.CollectionID redoTask.AddSyncStep(&BroadcastAlteredCollectionStep{ - baseStep: baseStep{core: a.core}, - req: a.Req, - core: a.core, + baseStep: baseStep{core: core}, + req: request, + core: core, }) // properties needs to be refreshed in the cache - aliases := a.core.meta.ListAliasesByID(ctx, oldColl.CollectionID) + aliases := core.meta.ListAliasesByID(ctx, oldColl.CollectionID) redoTask.AddSyncStep(&expireCacheStep{ - baseStep: baseStep{core: a.core}, - dbName: a.Req.GetDbName(), - collectionNames: append(aliases, a.Req.GetCollectionName()), + baseStep: baseStep{core: core}, + dbName: request.GetDbName(), + collectionNames: append(aliases, request.GetCollectionName()), collectionID: oldColl.CollectionID, opts: []proxyutil.ExpireCacheOpt{proxyutil.SetMsgType(commonpb.MsgType_AlterCollection)}, }) @@ -119,7 +143,7 @@ func (a *alterCollectionTask) Execute(ctx context.Context) error { zap.Strings("newResourceGroups", newResourceGroups), ) redoTask.AddAsyncStep(NewSimpleStep("", func(ctx context.Context) ([]nestedStep, error) { - resp, err := a.core.queryCoord.UpdateLoadConfig(ctx, &querypb.UpdateLoadConfigRequest{ + resp, err := core.queryCoord.UpdateLoadConfig(ctx, &querypb.UpdateLoadConfigRequest{ CollectionIDs: []int64{oldColl.CollectionID}, ReplicaNumber: int32(newReplicaNumber), ResourceGroups: newResourceGroups, @@ 
-165,22 +189,13 @@ func (a *alterCollectionTask) Execute(ctx context.Context) error { zap.String("database", newColl.DBName), zap.String("replicateID", replicateID), ) - return nil, a.core.chanTimeTick.broadcastDmlChannels(newColl.PhysicalChannelNames, msgPack) + return nil, core.chanTimeTick.broadcastDmlChannels(newColl.PhysicalChannelNames, msgPack) })) } return redoTask.Execute(ctx) } -func (a *alterCollectionTask) GetLockerKey() LockerKey { - collection := a.core.getCollectionIDStr(a.ctx, a.Req.GetDbName(), a.Req.GetCollectionName(), a.Req.GetCollectionID()) - return NewLockerKeyChain( - NewClusterLockerKey(false), - NewDatabaseLockerKey(a.Req.GetDbName(), false), - NewCollectionLockerKey(collection, true), - ) -} - func DeleteProperties(oldProps []*commonpb.KeyValuePair, deleteKeys []string) []*commonpb.KeyValuePair { propsMap := make(map[string]string) for _, prop := range oldProps { @@ -227,35 +242,66 @@ func (a *alterCollectionFieldTask) Execute(ctx context.Context) error { return err } - newColl := oldColl.Clone() - err = UpdateFieldProperties(newColl, a.Req.GetFieldName(), a.Req.GetProperties()) + oldFieldProperties, err := GetFieldProperties(oldColl, a.Req.GetFieldName()) if err != nil { + log.Warn("get field properties failed when altering collection field", zap.Error(err)) return err } ts := a.GetTs() - redoTask := newBaseRedoTask(a.core.stepExecutor) + return executeAlterCollectionFieldTaskSteps(ctx, a.core, oldColl, oldFieldProperties, a.Req, ts) +} + +func (a *alterCollectionFieldTask) GetLockerKey() LockerKey { + collection := a.core.getCollectionIDStr(a.ctx, a.Req.GetDbName(), a.Req.GetCollectionName(), 0) + return NewLockerKeyChain( + NewClusterLockerKey(false), + NewDatabaseLockerKey(a.Req.GetDbName(), false), + NewCollectionLockerKey(collection, true), + ) +} + +func executeAlterCollectionFieldTaskSteps(ctx context.Context, + core *Core, + col *model.Collection, + oldFieldProperties []*commonpb.KeyValuePair, + request *milvuspb.AlterCollectionFieldRequest, + ts Timestamp, +) error { + var err error + fieldName := request.GetFieldName() + newFieldProperties := UpdateFieldPropertyParams(oldFieldProperties, request.GetProperties()) + oldColl := col.Clone() + err = ResetFieldProperties(oldColl, fieldName, oldFieldProperties) + if err != nil { + return err + } + newColl := col.Clone() + err = ResetFieldProperties(newColl, fieldName, newFieldProperties) + if err != nil { + return err + } + redoTask := newBaseRedoTask(core.stepExecutor) redoTask.AddSyncStep(&AlterCollectionStep{ - baseStep: baseStep{core: a.core}, + baseStep: baseStep{core: core}, oldColl: oldColl, newColl: newColl, ts: ts, }) redoTask.AddSyncStep(&BroadcastAlteredCollectionStep{ - baseStep: baseStep{core: a.core}, + baseStep: baseStep{core: core}, req: &milvuspb.AlterCollectionRequest{ - Base: a.Req.Base, - DbName: a.Req.DbName, - CollectionName: a.Req.CollectionName, + Base: request.Base, + DbName: request.DbName, + CollectionName: request.CollectionName, CollectionID: oldColl.CollectionID, }, - core: a.core, + core: core, }) - collectionNames := []string{} redoTask.AddSyncStep(&expireCacheStep{ - baseStep: baseStep{core: a.core}, - dbName: a.Req.GetDbName(), - collectionNames: append(collectionNames, a.Req.GetCollectionName()), + baseStep: baseStep{core: core}, + dbName: request.GetDbName(), + collectionNames: []string{request.GetCollectionName()}, collectionID: oldColl.CollectionID, opts: []proxyutil.ExpireCacheOpt{proxyutil.SetMsgType(commonpb.MsgType_AlterCollectionField)}, }) @@ -263,16 +309,25 @@ 
func (a *alterCollectionFieldTask) Execute(ctx context.Context) error { return redoTask.Execute(ctx) } -func UpdateFieldProperties(coll *model.Collection, fieldName string, updatedProps []*commonpb.KeyValuePair) error { +func ResetFieldProperties(coll *model.Collection, fieldName string, newProps []*commonpb.KeyValuePair) error { for i, field := range coll.Fields { if field.Name == fieldName { - coll.Fields[i].TypeParams = UpdateFieldPropertyParams(field.TypeParams, updatedProps) + coll.Fields[i].TypeParams = newProps return nil } } return merr.WrapErrParameterInvalidMsg("field %s does not exist in collection", fieldName) } +func GetFieldProperties(coll *model.Collection, fieldName string) ([]*commonpb.KeyValuePair, error) { + for _, field := range coll.Fields { + if field.Name == fieldName { + return field.TypeParams, nil + } + } + return nil, merr.WrapErrParameterInvalidMsg("field %s does not exist in collection", fieldName) +} + func UpdateFieldPropertyParams(oldProps, updatedProps []*commonpb.KeyValuePair) []*commonpb.KeyValuePair { props := make(map[string]string) for _, prop := range oldProps { diff --git a/internal/rootcoord/alter_database_task.go b/internal/rootcoord/alter_database_task.go index e11e4fa058f5f..6d21e841b3355 100644 --- a/internal/rootcoord/alter_database_task.go +++ b/internal/rootcoord/alter_database_task.go @@ -79,30 +79,78 @@ func (a *alterDatabaseTask) Execute(ctx context.Context) error { return err } - newDB := oldDB.Clone() + var newProperties []*commonpb.KeyValuePair if (len(a.Req.GetProperties())) > 0 { if ContainsKeyPairArray(a.Req.GetProperties(), oldDB.Properties) { log.Info("skip to alter database due to no changes were detected in the properties", zap.String("databaseName", a.Req.GetDbName())) return nil } - ret := MergeProperties(oldDB.Properties, a.Req.GetProperties()) - newDB.Properties = ret + newProperties = MergeProperties(oldDB.Properties, a.Req.GetProperties()) } else if (len(a.Req.GetDeleteKeys())) > 0 { - ret := DeleteProperties(oldDB.Properties, a.Req.GetDeleteKeys()) - newDB.Properties = ret + newProperties = DeleteProperties(oldDB.Properties, a.Req.GetDeleteKeys()) } - ts := a.GetTs() - redoTask := newBaseRedoTask(a.core.stepExecutor) + return executeAlterDatabaseTaskSteps(ctx, a.core, oldDB, oldDB.Properties, newProperties, a.ts) +} + +func (a *alterDatabaseTask) GetLockerKey() LockerKey { + return NewLockerKeyChain( + NewClusterLockerKey(false), + NewDatabaseLockerKey(a.Req.GetDbName(), true), + ) +} + +func MergeProperties(oldProps []*commonpb.KeyValuePair, updatedProps []*commonpb.KeyValuePair) []*commonpb.KeyValuePair { + _, existEndTS := common.GetReplicateEndTS(updatedProps) + if existEndTS { + updatedProps = append(updatedProps, &commonpb.KeyValuePair{ + Key: common.ReplicateIDKey, + Value: "", + }) + } + + props := make(map[string]string) + for _, prop := range oldProps { + props[prop.Key] = prop.Value + } + + for _, prop := range updatedProps { + props[prop.Key] = prop.Value + } + + propKV := make([]*commonpb.KeyValuePair, 0) + + for key, value := range props { + propKV = append(propKV, &commonpb.KeyValuePair{ + Key: key, + Value: value, + }) + } + + return propKV +} + +func executeAlterDatabaseTaskSteps(ctx context.Context, + core *Core, + dbInfo *model.Database, + oldProperties []*commonpb.KeyValuePair, + newProperties []*commonpb.KeyValuePair, + ts Timestamp, +) error { + oldDB := dbInfo.Clone() + oldDB.Properties = oldProperties + newDB := dbInfo.Clone() + newDB.Properties = newProperties + redoTask := 
newBaseRedoTask(core.stepExecutor) redoTask.AddSyncStep(&AlterDatabaseStep{ - baseStep: baseStep{core: a.core}, + baseStep: baseStep{core: core}, oldDB: oldDB, newDB: newDB, ts: ts, }) redoTask.AddSyncStep(&expireCacheStep{ - baseStep: baseStep{core: a.core}, + baseStep: baseStep{core: core}, dbName: newDB.Name, ts: ts, // make sure to send the "expire cache" request @@ -129,7 +177,7 @@ func (a *alterDatabaseTask) Execute(ctx context.Context) error { zap.Strings("newResourceGroups", newResourceGroups), ) redoTask.AddAsyncStep(NewSimpleStep("", func(ctx context.Context) ([]nestedStep, error) { - colls, err := a.core.meta.ListCollections(ctx, oldDB.Name, a.ts, true) + colls, err := core.meta.ListCollections(ctx, oldDB.Name, ts, true) if err != nil { log.Ctx(ctx).Warn("failed to trigger update load config for database", zap.Int64("dbID", oldDB.ID), zap.Error(err)) return nil, err @@ -138,7 +186,7 @@ func (a *alterDatabaseTask) Execute(ctx context.Context) error { return nil, nil } - resp, err := a.core.queryCoord.UpdateLoadConfig(ctx, &querypb.UpdateLoadConfigRequest{ + resp, err := core.queryCoord.UpdateLoadConfig(ctx, &querypb.UpdateLoadConfigRequest{ CollectionIDs: lo.Map(colls, func(coll *model.Collection, _ int) int64 { return coll.CollectionID }), ReplicaNumber: int32(newReplicaNumber), ResourceGroups: newResourceGroups, @@ -180,46 +228,9 @@ func (a *alterDatabaseTask) Execute(ctx context.Context) error { } msgPack.Msgs = append(msgPack.Msgs, msg) log.Info("send replicate end msg for db", zap.String("db", newDB.Name), zap.String("replicateID", replicateID)) - return nil, a.core.chanTimeTick.broadcastDmlChannels(a.core.chanTimeTick.listDmlChannels(), msgPack) + return nil, core.chanTimeTick.broadcastDmlChannels(core.chanTimeTick.listDmlChannels(), msgPack) })) } return redoTask.Execute(ctx) } - -func (a *alterDatabaseTask) GetLockerKey() LockerKey { - return NewLockerKeyChain( - NewClusterLockerKey(false), - NewDatabaseLockerKey(a.Req.GetDbName(), true), - ) -} - -func MergeProperties(oldProps []*commonpb.KeyValuePair, updatedProps []*commonpb.KeyValuePair) []*commonpb.KeyValuePair { - _, existEndTS := common.GetReplicateEndTS(updatedProps) - if existEndTS { - updatedProps = append(updatedProps, &commonpb.KeyValuePair{ - Key: common.ReplicateIDKey, - Value: "", - }) - } - - props := make(map[string]string) - for _, prop := range oldProps { - props[prop.Key] = prop.Value - } - - for _, prop := range updatedProps { - props[prop.Key] = prop.Value - } - - propKV := make([]*commonpb.KeyValuePair, 0) - - for key, value := range props { - propKV = append(propKV, &commonpb.KeyValuePair{ - Key: key, - Value: value, - }) - } - - return propKV -} diff --git a/internal/rootcoord/create_collection_task.go b/internal/rootcoord/create_collection_task.go index 25b437ed91f5d..00f1804127aa4 100644 --- a/internal/rootcoord/create_collection_task.go +++ b/internal/rootcoord/create_collection_task.go @@ -620,68 +620,81 @@ func (t *createCollectionTask) Execute(ctx context.Context) error { } collInfo.StartPositions = toKeyDataPairs(startPositions) - undoTask := newBaseUndoTask(t.core.stepExecutor) + return executeCreateCollectionTaskSteps(ctx, t.core, &collInfo, t.Req.GetDbName(), t.dbProperties, ts) +} + +func (t *createCollectionTask) GetLockerKey() LockerKey { + return NewLockerKeyChain( + NewClusterLockerKey(false), + NewDatabaseLockerKey(t.Req.GetDbName(), false), + NewCollectionLockerKey(strconv.FormatInt(t.collID, 10), true), + ) +} + +func executeCreateCollectionTaskSteps(ctx context.Context, + core 
*Core, + col *model.Collection, + dbName string, + dbProperties []*commonpb.KeyValuePair, + ts Timestamp, +) error { + undoTask := newBaseUndoTask(core.stepExecutor) + collID := col.CollectionID undoTask.AddStep(&expireCacheStep{ - baseStep: baseStep{core: t.core}, - dbName: t.Req.GetDbName(), - collectionNames: []string{t.Req.GetCollectionName()}, + baseStep: baseStep{core: core}, + dbName: dbName, + collectionNames: []string{col.Name}, collectionID: collID, ts: ts, opts: []proxyutil.ExpireCacheOpt{proxyutil.SetMsgType(commonpb.MsgType_DropCollection)}, }, &nullStep{}) undoTask.AddStep(&nullStep{}, &removeDmlChannelsStep{ - baseStep: baseStep{core: t.core}, - pChannels: chanNames, + baseStep: baseStep{core: core}, + pChannels: col.PhysicalChannelNames, }) // remove dml channels if any error occurs. undoTask.AddStep(&addCollectionMetaStep{ - baseStep: baseStep{core: t.core}, - coll: &collInfo, + baseStep: baseStep{core: core}, + coll: col, }, &deleteCollectionMetaStep{ - baseStep: baseStep{core: t.core}, + baseStep: baseStep{core: core}, collectionID: collID, // When we undo createCollectionTask, this ts may be less than the ts when unwatch channels. ts: ts, }) // serve for this case: watching channels succeed in datacoord but failed due to network failure. undoTask.AddStep(&nullStep{}, &unwatchChannelsStep{ - baseStep: baseStep{core: t.core}, + baseStep: baseStep{core: core}, collectionID: collID, - channels: t.channels, - isSkip: !Params.CommonCfg.TTMsgEnabled.GetAsBool(), + channels: collectionChannels{ + virtualChannels: col.VirtualChannelNames, + physicalChannels: col.PhysicalChannelNames, + }, + isSkip: !Params.CommonCfg.TTMsgEnabled.GetAsBool(), }) undoTask.AddStep(&watchChannelsStep{ - baseStep: baseStep{core: t.core}, + baseStep: baseStep{core: core}, info: &watchInfo{ ts: ts, collectionID: collID, - vChannels: t.channels.virtualChannels, - startPositions: toKeyDataPairs(startPositions), + vChannels: col.VirtualChannelNames, + startPositions: col.StartPositions, schema: &schemapb.CollectionSchema{ - Name: collInfo.Name, - DbName: collInfo.DBName, - Description: collInfo.Description, - AutoID: collInfo.AutoID, - Fields: model.MarshalFieldModels(collInfo.Fields), - Properties: collInfo.Properties, - Functions: model.MarshalFunctionModels(collInfo.Functions), + Name: col.Name, + DbName: col.DBName, + Description: col.Description, + AutoID: col.AutoID, + Fields: model.MarshalFieldModels(col.Fields), + Properties: col.Properties, + Functions: model.MarshalFunctionModels(col.Functions), }, - dbProperties: t.dbProperties, + dbProperties: dbProperties, }, }, &nullStep{}) undoTask.AddStep(&changeCollectionStateStep{ - baseStep: baseStep{core: t.core}, + baseStep: baseStep{core: core}, collectionID: collID, state: pb.CollectionState_CollectionCreated, ts: ts, }, &nullStep{}) // We'll remove the whole collection anyway. 
- return undoTask.Execute(ctx) } - -func (t *createCollectionTask) GetLockerKey() LockerKey { - return NewLockerKeyChain( - NewClusterLockerKey(false), - NewDatabaseLockerKey(t.Req.GetDbName(), false), - NewCollectionLockerKey(t.Req.GetCollectionName(), true), - ) -} diff --git a/internal/rootcoord/create_partition_task.go b/internal/rootcoord/create_partition_task.go index 609efb2e7730a..7d76720c8006e 100644 --- a/internal/rootcoord/create_partition_task.go +++ b/internal/rootcoord/create_partition_task.go @@ -76,60 +76,71 @@ func (t *createPartitionTask) Execute(ctx context.Context) error { State: pb.PartitionState_PartitionCreating, } - undoTask := newBaseUndoTask(t.core.stepExecutor) + return executeCreatePartitionTaskSteps(ctx, t.core, partition, t.collMeta, t.Req.GetDbName(), t.GetTs()) +} + +func (t *createPartitionTask) GetLockerKey() LockerKey { + collection := t.core.getCollectionIDStr(t.ctx, t.Req.GetDbName(), t.Req.GetCollectionName(), 0) + return NewLockerKeyChain( + NewClusterLockerKey(false), + NewDatabaseLockerKey(t.Req.GetDbName(), false), + NewCollectionLockerKey(collection, true), + ) +} +func executeCreatePartitionTaskSteps(ctx context.Context, + core *Core, + partition *model.Partition, + col *model.Collection, + dbName string, + ts Timestamp, +) error { + undoTask := newBaseUndoTask(core.stepExecutor) + partID := partition.PartitionID + collectionID := partition.CollectionID undoTask.AddStep(&expireCacheStep{ - baseStep: baseStep{core: t.core}, - dbName: t.Req.GetDbName(), - collectionNames: []string{t.collMeta.Name}, - collectionID: t.collMeta.CollectionID, - partitionName: t.Req.GetPartitionName(), - ts: t.GetTs(), + baseStep: baseStep{core: core}, + dbName: dbName, + collectionNames: []string{col.Name}, + collectionID: collectionID, + partitionName: partition.PartitionName, + ts: ts, opts: []proxyutil.ExpireCacheOpt{proxyutil.SetMsgType(commonpb.MsgType_CreatePartition)}, }, &nullStep{}) undoTask.AddStep(&addPartitionMetaStep{ - baseStep: baseStep{core: t.core}, + baseStep: baseStep{core: core}, partition: partition, }, &removePartitionMetaStep{ - baseStep: baseStep{core: t.core}, - dbID: t.collMeta.DBID, + baseStep: baseStep{core: core}, + dbID: col.DBID, collectionID: partition.CollectionID, partitionID: partition.PartitionID, - ts: t.GetTs(), + ts: ts, }) if streamingutil.IsStreamingServiceEnabled() { undoTask.AddStep(&broadcastCreatePartitionMsgStep{ - baseStep: baseStep{core: t.core}, - vchannels: t.collMeta.VirtualChannelNames, + baseStep: baseStep{core: core}, + vchannels: col.VirtualChannelNames, partition: partition, - ts: t.GetTs(), + ts: ts, }, &nullStep{}) } undoTask.AddStep(&nullStep{}, &releasePartitionsStep{ - baseStep: baseStep{core: t.core}, - collectionID: t.collMeta.CollectionID, + baseStep: baseStep{core: core}, + collectionID: col.CollectionID, partitionIDs: []int64{partID}, }) undoTask.AddStep(&changePartitionStateStep{ - baseStep: baseStep{core: t.core}, - collectionID: t.collMeta.CollectionID, + baseStep: baseStep{core: core}, + collectionID: col.CollectionID, partitionID: partID, state: pb.PartitionState_PartitionCreated, - ts: t.GetTs(), + ts: ts, }, &nullStep{}) return undoTask.Execute(ctx) } - -func (t *createPartitionTask) GetLockerKey() LockerKey { - collection := t.core.getCollectionIDStr(t.ctx, t.Req.GetDbName(), t.Req.GetCollectionName(), 0) - return NewLockerKeyChain( - NewClusterLockerKey(false), - NewDatabaseLockerKey(t.Req.GetDbName(), false), - NewCollectionLockerKey(collection, true), - ) -} diff --git 
a/internal/rootcoord/drop_collection_task.go b/internal/rootcoord/drop_collection_task.go index 3ae2eca075888..795842a51b609 100644 --- a/internal/rootcoord/drop_collection_task.go +++ b/internal/rootcoord/drop_collection_task.go @@ -25,6 +25,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/metastore/model" pb "github.com/milvus-io/milvus/internal/proto/etcdpb" "github.com/milvus-io/milvus/internal/util/proxyutil" "github.com/milvus-io/milvus/pkg/log" @@ -71,46 +72,68 @@ func (t *dropCollectionTask) Execute(ctx context.Context) error { aliases := t.core.meta.ListAliasesByID(ctx, collMeta.CollectionID) ts := t.GetTs() + return executeDropCollectionTaskSteps(ctx, + t.core, collMeta, t.Req.GetDbName(), aliases, + t.Req.GetBase().GetReplicateInfo().GetIsReplicate(), + ts) +} + +func (t *dropCollectionTask) GetLockerKey() LockerKey { + collection := t.core.getCollectionIDStr(t.ctx, t.Req.GetDbName(), t.Req.GetCollectionName(), 0) + return NewLockerKeyChain( + NewClusterLockerKey(false), + NewDatabaseLockerKey(t.Req.GetDbName(), false), + NewCollectionLockerKey(collection, true), + ) +} - redoTask := newBaseRedoTask(t.core.stepExecutor) +func executeDropCollectionTaskSteps(ctx context.Context, + core *Core, + col *model.Collection, + dbName string, + alias []string, + isReplicate bool, + ts Timestamp, +) error { + redoTask := newBaseRedoTask(core.stepExecutor) redoTask.AddSyncStep(&expireCacheStep{ - baseStep: baseStep{core: t.core}, - dbName: t.Req.GetDbName(), - collectionNames: append(aliases, collMeta.Name), - collectionID: collMeta.CollectionID, + baseStep: baseStep{core: core}, + dbName: dbName, + collectionNames: append(alias, col.Name), + collectionID: col.CollectionID, ts: ts, opts: []proxyutil.ExpireCacheOpt{proxyutil.SetMsgType(commonpb.MsgType_DropCollection)}, }) redoTask.AddSyncStep(&changeCollectionStateStep{ - baseStep: baseStep{core: t.core}, - collectionID: collMeta.CollectionID, + baseStep: baseStep{core: core}, + collectionID: col.CollectionID, state: pb.CollectionState_CollectionDropping, ts: ts, }) redoTask.AddAsyncStep(&releaseCollectionStep{ - baseStep: baseStep{core: t.core}, - collectionID: collMeta.CollectionID, + baseStep: baseStep{core: core}, + collectionID: col.CollectionID, }) redoTask.AddAsyncStep(&dropIndexStep{ - baseStep: baseStep{core: t.core}, - collID: collMeta.CollectionID, + baseStep: baseStep{core: core}, + collID: col.CollectionID, partIDs: nil, }) redoTask.AddAsyncStep(&deleteCollectionDataStep{ - baseStep: baseStep{core: t.core}, - coll: collMeta, - isSkip: t.Req.GetBase().GetReplicateInfo().GetIsReplicate(), + baseStep: baseStep{core: core}, + coll: col, + isSkip: isReplicate, }) redoTask.AddAsyncStep(&removeDmlChannelsStep{ - baseStep: baseStep{core: t.core}, - pChannels: collMeta.PhysicalChannelNames, + baseStep: baseStep{core: core}, + pChannels: col.PhysicalChannelNames, }) - redoTask.AddAsyncStep(newConfirmGCStep(t.core, collMeta.CollectionID, allPartition)) + redoTask.AddAsyncStep(newConfirmGCStep(core, col.CollectionID, allPartition)) redoTask.AddAsyncStep(&deleteCollectionMetaStep{ - baseStep: baseStep{core: t.core}, - collectionID: collMeta.CollectionID, + baseStep: baseStep{core: core}, + collectionID: col.CollectionID, // This ts is less than the ts when we notify data nodes to drop collection, but it's OK since we have already // marked this collection as deleted. 
If we want to make this ts greater than the notification's ts, we should // wrap a step who will have these three children and connect them with ts. @@ -119,11 +142,3 @@ func (t *dropCollectionTask) Execute(ctx context.Context) error { return redoTask.Execute(ctx) } - -func (t *dropCollectionTask) GetLockerKey() LockerKey { - return NewLockerKeyChain( - NewClusterLockerKey(false), - NewDatabaseLockerKey(t.Req.GetDbName(), false), - NewCollectionLockerKey(t.Req.GetCollectionName(), true), - ) -} diff --git a/internal/rootcoord/drop_db_task.go b/internal/rootcoord/drop_db_task.go index bdc1cc035db32..e071f3b3dfab4 100644 --- a/internal/rootcoord/drop_db_task.go +++ b/internal/rootcoord/drop_db_task.go @@ -39,16 +39,28 @@ func (t *dropDatabaseTask) Prepare(ctx context.Context) error { } func (t *dropDatabaseTask) Execute(ctx context.Context) error { - redoTask := newBaseRedoTask(t.core.stepExecutor) dbName := t.Req.GetDbName() ts := t.GetTs() + return executeDropDatabaseTaskSteps(ctx, t.core, dbName, ts) +} + +func (t *dropDatabaseTask) GetLockerKey() LockerKey { + return NewLockerKeyChain(NewClusterLockerKey(true)) +} + +func executeDropDatabaseTaskSteps(ctx context.Context, + core *Core, + dbName string, + ts Timestamp, +) error { + redoTask := newBaseRedoTask(core.stepExecutor) redoTask.AddSyncStep(&deleteDatabaseMetaStep{ - baseStep: baseStep{core: t.core}, + baseStep: baseStep{core: core}, databaseName: dbName, ts: ts, }) redoTask.AddSyncStep(&expireCacheStep{ - baseStep: baseStep{core: t.core}, + baseStep: baseStep{core: core}, dbName: dbName, ts: ts, // make sure to send the "expire cache" request @@ -60,7 +72,3 @@ func (t *dropDatabaseTask) Execute(ctx context.Context) error { }) return redoTask.Execute(ctx) } - -func (t *dropDatabaseTask) GetLockerKey() LockerKey { - return NewLockerKeyChain(NewClusterLockerKey(true)) -} diff --git a/internal/rootcoord/drop_partition_task.go b/internal/rootcoord/drop_partition_task.go index d25265f05da55..648b3c74ddd6f 100644 --- a/internal/rootcoord/drop_partition_task.go +++ b/internal/rootcoord/drop_partition_task.go @@ -67,56 +67,71 @@ func (t *dropPartitionTask) Execute(ctx context.Context) error { return nil } - redoTask := newBaseRedoTask(t.core.stepExecutor) + return executeDropPartitionTaskSteps(ctx, t.core, + t.Req.GetPartitionName(), partID, + t.collMeta, t.Req.GetDbName(), + t.Req.GetBase().GetReplicateInfo().GetIsReplicate(), t.GetTs()) +} + +func (t *dropPartitionTask) GetLockerKey() LockerKey { + collection := t.core.getCollectionIDStr(t.ctx, t.Req.GetDbName(), t.Req.GetCollectionName(), 0) + return NewLockerKeyChain( + NewClusterLockerKey(false), + NewDatabaseLockerKey(t.Req.GetDbName(), false), + NewCollectionLockerKey(collection, true), + ) +} + +func executeDropPartitionTaskSteps(ctx context.Context, + core *Core, + partitionName string, + partitionID UniqueID, + col *model.Collection, + dbName string, + isReplicate bool, + ts Timestamp, +) error { + redoTask := newBaseRedoTask(core.stepExecutor) redoTask.AddSyncStep(&expireCacheStep{ - baseStep: baseStep{core: t.core}, - dbName: t.Req.GetDbName(), - collectionNames: []string{t.collMeta.Name}, - collectionID: t.collMeta.CollectionID, - partitionName: t.Req.GetPartitionName(), - ts: t.GetTs(), + baseStep: baseStep{core: core}, + dbName: dbName, + collectionNames: []string{col.Name}, + collectionID: col.CollectionID, + partitionName: partitionName, + ts: ts, opts: []proxyutil.ExpireCacheOpt{proxyutil.SetMsgType(commonpb.MsgType_DropPartition)}, }) 
redoTask.AddSyncStep(&changePartitionStateStep{ - baseStep: baseStep{core: t.core}, - collectionID: t.collMeta.CollectionID, - partitionID: partID, + baseStep: baseStep{core: core}, + collectionID: col.CollectionID, + partitionID: partitionID, state: pb.PartitionState_PartitionDropping, - ts: t.GetTs(), + ts: ts, }) redoTask.AddAsyncStep(&deletePartitionDataStep{ - baseStep: baseStep{core: t.core}, - pchans: t.collMeta.PhysicalChannelNames, - vchans: t.collMeta.VirtualChannelNames, + baseStep: baseStep{core: core}, + pchans: col.PhysicalChannelNames, + vchans: col.VirtualChannelNames, partition: &model.Partition{ - PartitionID: partID, - PartitionName: t.Req.GetPartitionName(), - CollectionID: t.collMeta.CollectionID, + PartitionID: partitionID, + PartitionName: partitionName, + CollectionID: col.CollectionID, }, - isSkip: t.Req.GetBase().GetReplicateInfo().GetIsReplicate(), + isSkip: isReplicate, }) - redoTask.AddAsyncStep(newConfirmGCStep(t.core, t.collMeta.CollectionID, partID)) + redoTask.AddAsyncStep(newConfirmGCStep(core, col.CollectionID, partitionID)) redoTask.AddAsyncStep(&removePartitionMetaStep{ - baseStep: baseStep{core: t.core}, - dbID: t.collMeta.DBID, - collectionID: t.collMeta.CollectionID, - partitionID: partID, + baseStep: baseStep{core: core}, + dbID: col.DBID, + collectionID: col.CollectionID, + partitionID: partitionID, // This ts is less than the ts when we notify data nodes to drop partition, but it's OK since we have already // marked this partition as deleted. If we want to make this ts greater than the notification's ts, we should // wrap a step who will have these children and connect them with ts. - ts: t.GetTs(), + ts: ts, }) return redoTask.Execute(ctx) } - -func (t *dropPartitionTask) GetLockerKey() LockerKey { - collection := t.core.getCollectionIDStr(t.ctx, t.Req.GetDbName(), t.Req.GetCollectionName(), 0) - return NewLockerKeyChain( - NewClusterLockerKey(false), - NewDatabaseLockerKey(t.Req.GetDbName(), false), - NewCollectionLockerKey(collection, true), - ) -} diff --git a/internal/rootcoord/rbac_task.go b/internal/rootcoord/rbac_task.go new file mode 100644 index 0000000000000..20bacc559be50 --- /dev/null +++ b/internal/rootcoord/rbac_task.go @@ -0,0 +1,370 @@ +/* + * Licensed to the LF AI & Data foundation under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package rootcoord + +import ( + "context" + + "github.com/cockroachdb/errors" + "github.com/samber/lo" + "go.uber.org/zap" + "google.golang.org/protobuf/proto" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/proto/proxypb" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util" + "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +func executeDeleteCredentialTaskSteps(ctx context.Context, core *Core, username string) error { + redoTask := newBaseRedoTask(core.stepExecutor) + redoTask.AddSyncStep(NewSimpleStep("delete credential meta data", func(ctx context.Context) ([]nestedStep, error) { + err := core.meta.DeleteCredential(ctx, username) + if err != nil { + log.Ctx(ctx).Warn("delete credential meta data failed", zap.String("username", username), zap.Error(err)) + } + return nil, err + })) + redoTask.AddAsyncStep(NewSimpleStep("delete credential cache", func(ctx context.Context) ([]nestedStep, error) { + err := core.ExpireCredCache(ctx, username) + if err != nil { + log.Ctx(ctx).Warn("delete credential cache failed", zap.String("username", username), zap.Error(err)) + } + return nil, err + })) + redoTask.AddAsyncStep(NewSimpleStep("delete user role cache for the user", func(ctx context.Context) ([]nestedStep, error) { + err := core.proxyClientManager.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{ + OpType: int32(typeutil.CacheDeleteUser), + OpKey: username, + }) + if err != nil { + log.Ctx(ctx).Warn("delete user role cache failed for the user", zap.String("username", username), zap.Error(err)) + } + return nil, err + })) + + return redoTask.Execute(ctx) +} + +func executeDropRoleTaskSteps(ctx context.Context, core *Core, roleName string, forceDrop bool) error { + redoTask := newBaseRedoTask(core.stepExecutor) + redoTask.AddSyncStep(NewSimpleStep("drop role meta data", func(ctx context.Context) ([]nestedStep, error) { + err := core.meta.DropRole(ctx, util.DefaultTenant, roleName) + if err != nil { + log.Ctx(ctx).Warn("drop role meta data failed", zap.String("role_name", roleName), zap.Error(err)) + } + return nil, err + })) + redoTask.AddAsyncStep(NewSimpleStep("drop the privilege list of this role", func(ctx context.Context) ([]nestedStep, error) { + if !forceDrop { + return nil, nil + } + err := core.meta.DropGrant(ctx, util.DefaultTenant, &milvuspb.RoleEntity{Name: roleName}) + if err != nil { + log.Ctx(ctx).Warn("drop the privilege list failed for the role", zap.String("role_name", roleName), zap.Error(err)) + } + return nil, err + })) + redoTask.AddAsyncStep(NewSimpleStep("drop role cache", func(ctx context.Context) ([]nestedStep, error) { + err := core.proxyClientManager.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{ + OpType: int32(typeutil.CacheDropRole), + OpKey: roleName, + }) + if err != nil { + log.Ctx(ctx).Warn("delete user role cache failed for the role", zap.String("role_name", roleName), zap.Error(err)) + } + return nil, err + })) + return redoTask.Execute(ctx) +} + +func executeOperateUserRoleTaskSteps(ctx context.Context, core *Core, in *milvuspb.OperateUserRoleRequest) error { + username := in.Username + roleName := in.RoleName + operateType := in.Type + redoTask := newBaseRedoTask(core.stepExecutor) + redoTask.AddSyncStep(NewSimpleStep("operate user role meta data", func(ctx context.Context) ([]nestedStep, error) { + err := core.meta.OperateUserRole(ctx, 
util.DefaultTenant, &milvuspb.UserEntity{Name: username}, &milvuspb.RoleEntity{Name: roleName}, operateType) + if err != nil && !common.IsIgnorableError(err) { + log.Ctx(ctx).Warn("operate user role meta data failed", + zap.String("username", username), zap.String("role_name", roleName), + zap.Any("operate_type", operateType), + zap.Error(err)) + return nil, err + } + return nil, nil + })) + redoTask.AddAsyncStep(NewSimpleStep("operate user role cache", func(ctx context.Context) ([]nestedStep, error) { + var opType int32 + switch operateType { + case milvuspb.OperateUserRoleType_AddUserToRole: + opType = int32(typeutil.CacheAddUserToRole) + case milvuspb.OperateUserRoleType_RemoveUserFromRole: + opType = int32(typeutil.CacheRemoveUserFromRole) + default: + errMsg := "invalid operate type for the OperateUserRole api" + log.Ctx(ctx).Warn(errMsg, + zap.String("username", username), zap.String("role_name", roleName), + zap.Any("operate_type", operateType), + ) + return nil, nil + } + if err := core.proxyClientManager.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{ + OpType: opType, + OpKey: funcutil.EncodeUserRoleCache(username, roleName), + }); err != nil { + log.Ctx(ctx).Warn("fail to refresh policy info cache", + zap.String("username", username), zap.String("role_name", roleName), + zap.Any("operate_type", operateType), + zap.Error(err), + ) + return nil, err + } + return nil, nil + })) + return redoTask.Execute(ctx) +} + +func executeOperatePrivilegeTaskSteps(ctx context.Context, core *Core, in *milvuspb.OperatePrivilegeRequest) error { + privName := in.Entity.Grantor.Privilege.Name + redoTask := newBaseRedoTask(core.stepExecutor) + redoTask.AddSyncStep(NewSimpleStep("operate privilege meta data", func(ctx context.Context) ([]nestedStep, error) { + if !util.IsAnyWord(privName) { + // set up privilege name for metastore + dbPrivName, err := core.getMetastorePrivilegeName(ctx, privName) + if err != nil { + return nil, err + } + in.Entity.Grantor.Privilege.Name = dbPrivName + } + + err := core.meta.OperatePrivilege(ctx, util.DefaultTenant, in.Entity, in.Type) + if err != nil && !common.IsIgnorableError(err) { + log.Ctx(ctx).Warn("fail to operate the privilege", zap.Any("in", in), zap.Error(err)) + return nil, err + } + return nil, nil + })) + redoTask.AddAsyncStep(NewSimpleStep("operate privilege cache", func(ctx context.Context) ([]nestedStep, error) { + // set back to expand privilege group + in.Entity.Grantor.Privilege.Name = privName + var opType int32 + switch in.Type { + case milvuspb.OperatePrivilegeType_Grant: + opType = int32(typeutil.CacheGrantPrivilege) + case milvuspb.OperatePrivilegeType_Revoke: + opType = int32(typeutil.CacheRevokePrivilege) + default: + log.Ctx(ctx).Warn("invalid operate type for the OperatePrivilege api", zap.Any("in", in)) + return nil, nil + } + grants := []*milvuspb.GrantEntity{in.Entity} + + allGroups, err := core.meta.ListPrivilegeGroups(ctx) + allGroups = append(allGroups, core.initBuiltinPrivilegeGroups()...) 
+ if err != nil { + return nil, err + } + groups := lo.SliceToMap(allGroups, func(group *milvuspb.PrivilegeGroupInfo) (string, []*milvuspb.PrivilegeEntity) { + return group.GroupName, group.Privileges + }) + expandGrants, err := core.expandPrivilegeGroups(ctx, grants, groups) + if err != nil { + return nil, err + } + // if there is same grant in the other privilege groups, the grant should not be removed from the cache + if in.Type == milvuspb.OperatePrivilegeType_Revoke { + metaGrants, err := core.meta.SelectGrant(ctx, util.DefaultTenant, &milvuspb.GrantEntity{ + Role: in.Entity.Role, + DbName: in.Entity.DbName, + }) + if err != nil { + return nil, err + } + metaExpandGrants, err := core.expandPrivilegeGroups(ctx, metaGrants, groups) + if err != nil { + return nil, err + } + expandGrants = lo.Filter(expandGrants, func(g1 *milvuspb.GrantEntity, _ int) bool { + return !lo.ContainsBy(metaExpandGrants, func(g2 *milvuspb.GrantEntity) bool { + return proto.Equal(g1, g2) + }) + }) + } + if err := core.proxyClientManager.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{ + OpType: opType, + OpKey: funcutil.PolicyForPrivileges(expandGrants), + }); err != nil { + log.Ctx(ctx).Warn("fail to refresh policy info cache", zap.Any("in", in), zap.Error(err)) + return nil, err + } + return nil, nil + })) + + return redoTask.Execute(ctx) +} + +func executeRestoreRBACTaskSteps(ctx context.Context, core *Core, in *milvuspb.RestoreRBACMetaRequest) error { + redoTask := newBaseRedoTask(core.stepExecutor) + redoTask.AddSyncStep(NewSimpleStep("restore rbac meta data", func(ctx context.Context) ([]nestedStep, error) { + if err := core.meta.RestoreRBAC(ctx, util.DefaultTenant, in.RBACMeta); err != nil { + log.Ctx(ctx).Warn("fail to restore rbac meta data", zap.Any("in", in), zap.Error(err)) + return nil, err + } + return nil, nil + })) + redoTask.AddAsyncStep(NewSimpleStep("operate privilege cache", func(ctx context.Context) ([]nestedStep, error) { + if err := core.proxyClientManager.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{ + OpType: int32(typeutil.CacheRefresh), + }); err != nil { + log.Ctx(ctx).Warn("fail to refresh policy info cache", zap.Any("in", in), zap.Error(err)) + return nil, err + } + return nil, nil + })) + + return redoTask.Execute(ctx) +} + +func executeOperatePrivilegeGroupTaskSteps(ctx context.Context, core *Core, in *milvuspb.OperatePrivilegeGroupRequest) error { + redoTask := newBaseRedoTask(core.stepExecutor) + redoTask.AddSyncStep(NewSimpleStep("operate privilege group", func(ctx context.Context) ([]nestedStep, error) { + groups, err := core.meta.ListPrivilegeGroups(ctx) + if err != nil && !common.IsIgnorableError(err) { + log.Ctx(ctx).Warn("fail to list privilege groups", zap.Error(err)) + return nil, err + } + currGroups := lo.SliceToMap(groups, func(group *milvuspb.PrivilegeGroupInfo) (string, []*milvuspb.PrivilegeEntity) { + return group.GroupName, group.Privileges + }) + + // get roles granted to the group + roles, err := core.meta.GetPrivilegeGroupRoles(ctx, in.GroupName) + if err != nil { + return nil, err + } + + newGroups := make(map[string][]*milvuspb.PrivilegeEntity) + for k, v := range currGroups { + if k != in.GroupName { + newGroups[k] = v + continue + } + switch in.Type { + case milvuspb.OperatePrivilegeGroupType_AddPrivilegesToGroup: + newPrivs := lo.Union(v, in.Privileges) + newGroups[k] = lo.UniqBy(newPrivs, func(p *milvuspb.PrivilegeEntity) string { + return p.Name + }) + + // check if privileges are the same object type + 
objectTypes := lo.SliceToMap(newPrivs, func(p *milvuspb.PrivilegeEntity) (string, struct{}) { + return util.GetObjectType(p.Name), struct{}{} + }) + if len(objectTypes) > 1 { + return nil, errors.New("privileges are not the same object type") + } + case milvuspb.OperatePrivilegeGroupType_RemovePrivilegesFromGroup: + newPrivs, _ := lo.Difference(v, in.Privileges) + newGroups[k] = newPrivs + default: + return nil, errors.New("invalid operate type") + } + } + + var rolesToRevoke []*milvuspb.GrantEntity + var rolesToGrant []*milvuspb.GrantEntity + compareGrants := func(a, b *milvuspb.GrantEntity) bool { + return a.Role.Name == b.Role.Name && + a.Object.Name == b.Object.Name && + a.ObjectName == b.ObjectName && + a.Grantor.User.Name == b.Grantor.User.Name && + a.Grantor.Privilege.Name == b.Grantor.Privilege.Name && + a.DbName == b.DbName + } + for _, role := range roles { + grants, err := core.meta.SelectGrant(ctx, util.DefaultTenant, &milvuspb.GrantEntity{ + Role: role, + DbName: util.AnyWord, + }) + if err != nil { + return nil, err + } + currGrants, err := core.expandPrivilegeGroups(ctx, grants, currGroups) + if err != nil { + return nil, err + } + newGrants, err := core.expandPrivilegeGroups(ctx, grants, newGroups) + if err != nil { + return nil, err + } + + toRevoke := lo.Filter(currGrants, func(item *milvuspb.GrantEntity, _ int) bool { + return !lo.ContainsBy(newGrants, func(newItem *milvuspb.GrantEntity) bool { + return compareGrants(item, newItem) + }) + }) + + toGrant := lo.Filter(newGrants, func(item *milvuspb.GrantEntity, _ int) bool { + return !lo.ContainsBy(currGrants, func(currItem *milvuspb.GrantEntity) bool { + return compareGrants(item, currItem) + }) + }) + + rolesToRevoke = append(rolesToRevoke, toRevoke...) + rolesToGrant = append(rolesToGrant, toGrant...) 
+ } + + if len(rolesToRevoke) > 0 { + opType := int32(typeutil.CacheRevokePrivilege) + if err := core.proxyClientManager.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{ + OpType: opType, + OpKey: funcutil.PolicyForPrivileges(rolesToRevoke), + }); err != nil { + log.Ctx(ctx).Warn("fail to refresh policy info cache for revoke privileges in operate privilege group", zap.Any("in", in), zap.Error(err)) + return nil, err + } + } + + if len(rolesToGrant) > 0 { + opType := int32(typeutil.CacheGrantPrivilege) + if err := core.proxyClientManager.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{ + OpType: opType, + OpKey: funcutil.PolicyForPrivileges(rolesToGrant), + }); err != nil { + log.Ctx(ctx).Warn("fail to refresh policy info cache for grants privilege in operate privilege group", zap.Any("in", in), zap.Error(err)) + return nil, err + } + } + return nil, nil + })) + + redoTask.AddSyncStep(NewSimpleStep("operate privilege group meta data", func(ctx context.Context) ([]nestedStep, error) { + err := core.meta.OperatePrivilegeGroup(ctx, in.GroupName, in.Privileges, in.Type) + if err != nil && !common.IsIgnorableError(err) { + log.Ctx(ctx).Warn("fail to operate privilege group", zap.Error(err)) + } + return nil, err + })) + + return redoTask.Execute(ctx) +} diff --git a/internal/rootcoord/root_coord.go b/internal/rootcoord/root_coord.go index d4ac677b6e038..76c4e0e5b8d36 100644 --- a/internal/rootcoord/root_coord.go +++ b/internal/rootcoord/root_coord.go @@ -32,7 +32,6 @@ import ( clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/atomic" "go.uber.org/zap" - "google.golang.org/protobuf/proto" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" @@ -2271,33 +2270,7 @@ func (c *Core) DeleteCredential(ctx context.Context, in *milvuspb.DeleteCredenti } }() - redoTask := newBaseRedoTask(c.stepExecutor) - redoTask.AddSyncStep(NewSimpleStep("delete credential meta data", func(ctx context.Context) ([]nestedStep, error) { - err := c.meta.DeleteCredential(ctx, in.Username) - if err != nil { - ctxLog.Warn("delete credential meta data failed", zap.Error(err)) - } - return nil, err - })) - redoTask.AddAsyncStep(NewSimpleStep("delete credential cache", func(ctx context.Context) ([]nestedStep, error) { - err := c.ExpireCredCache(ctx, in.Username) - if err != nil { - ctxLog.Warn("delete credential cache failed", zap.Error(err)) - } - return nil, err - })) - redoTask.AddAsyncStep(NewSimpleStep("delete user role cache for the user", func(ctx context.Context) ([]nestedStep, error) { - err := c.proxyClientManager.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{ - OpType: int32(typeutil.CacheDeleteUser), - OpKey: in.Username, - }) - if err != nil { - ctxLog.Warn("delete user role cache failed for the user", zap.Error(err)) - } - return nil, err - })) - - err := redoTask.Execute(ctx) + err := executeDeleteCredentialTaskSteps(ctx, c, in.Username) if err != nil { errMsg := "fail to execute task when deleting the user" ctxLog.Warn(errMsg, zap.Error(err)) @@ -2412,35 +2385,7 @@ func (c *Core) DropRole(ctx context.Context, in *milvuspb.DropRoleRequest) (*com return merr.StatusWithErrorCode(errors.New(errMsg), commonpb.ErrorCode_DropRoleFailure), nil } } - redoTask := newBaseRedoTask(c.stepExecutor) - redoTask.AddSyncStep(NewSimpleStep("drop role meta data", func(ctx context.Context) ([]nestedStep, error) { - err := c.meta.DropRole(ctx, util.DefaultTenant, in.RoleName) - if err != nil { - ctxLog.Warn("drop role mata 
data failed", zap.Error(err)) - } - return nil, err - })) - redoTask.AddAsyncStep(NewSimpleStep("drop the privilege list of this role", func(ctx context.Context) ([]nestedStep, error) { - if !in.ForceDrop { - return nil, nil - } - err := c.meta.DropGrant(ctx, util.DefaultTenant, &milvuspb.RoleEntity{Name: in.RoleName}) - if err != nil { - ctxLog.Warn("drop the privilege list failed for the role", zap.Error(err)) - } - return nil, err - })) - redoTask.AddAsyncStep(NewSimpleStep("drop role cache", func(ctx context.Context) ([]nestedStep, error) { - err := c.proxyClientManager.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{ - OpType: int32(typeutil.CacheDropRole), - OpKey: in.RoleName, - }) - if err != nil { - ctxLog.Warn("delete user role cache failed for the role", zap.Error(err)) - } - return nil, err - })) - err := redoTask.Execute(ctx) + err := executeDropRoleTaskSteps(ctx, c, in.RoleName, in.ForceDrop) if err != nil { errMsg := "fail to execute task when dropping the role" ctxLog.Warn(errMsg, zap.Error(err)) @@ -2484,37 +2429,7 @@ func (c *Core) OperateUserRole(ctx context.Context, in *milvuspb.OperateUserRole } } - redoTask := newBaseRedoTask(c.stepExecutor) - redoTask.AddSyncStep(NewSimpleStep("operate user role meta data", func(ctx context.Context) ([]nestedStep, error) { - err := c.meta.OperateUserRole(ctx, util.DefaultTenant, &milvuspb.UserEntity{Name: in.Username}, &milvuspb.RoleEntity{Name: in.RoleName}, in.Type) - if err != nil && !common.IsIgnorableError(err) { - ctxLog.Warn("operate user role mata data failed", zap.Error(err)) - return nil, err - } - return nil, nil - })) - redoTask.AddAsyncStep(NewSimpleStep("operate user role cache", func(ctx context.Context) ([]nestedStep, error) { - var opType int32 - switch in.Type { - case milvuspb.OperateUserRoleType_AddUserToRole: - opType = int32(typeutil.CacheAddUserToRole) - case milvuspb.OperateUserRoleType_RemoveUserFromRole: - opType = int32(typeutil.CacheRemoveUserFromRole) - default: - errMsg := "invalid operate type for the OperateUserRole api" - ctxLog.Warn(errMsg, zap.Any("in", in)) - return nil, nil - } - if err := c.proxyClientManager.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{ - OpType: opType, - OpKey: funcutil.EncodeUserRoleCache(in.Username, in.RoleName), - }); err != nil { - ctxLog.Warn("fail to refresh policy info cache", zap.Any("in", in), zap.Error(err)) - return nil, err - } - return nil, nil - })) - err := redoTask.Execute(ctx) + err := executeOperateUserRoleTaskSteps(ctx, c, in) if err != nil { errMsg := "fail to execute task when operate the user and role" ctxLog.Warn(errMsg, zap.Error(err)) @@ -2729,83 +2644,7 @@ func (c *Core) OperatePrivilege(ctx context.Context, in *milvuspb.OperatePrivile } } - privName = in.Entity.Grantor.Privilege.Name - - redoTask := newBaseRedoTask(c.stepExecutor) - redoTask.AddSyncStep(NewSimpleStep("operate privilege meta data", func(ctx context.Context) ([]nestedStep, error) { - if !util.IsAnyWord(privName) { - // set up privilege name for metastore - dbPrivName, err := c.getMetastorePrivilegeName(ctx, privName) - if err != nil { - return nil, err - } - in.Entity.Grantor.Privilege.Name = dbPrivName - } - - err := c.meta.OperatePrivilege(ctx, util.DefaultTenant, in.Entity, in.Type) - if err != nil && !common.IsIgnorableError(err) { - ctxLog.Warn("fail to operate the privilege", zap.Any("in", in), zap.Error(err)) - return nil, err - } - return nil, nil - })) - redoTask.AddAsyncStep(NewSimpleStep("operate privilege cache", func(ctx 
context.Context) ([]nestedStep, error) { - // set back to expand privilege group - in.Entity.Grantor.Privilege.Name = privName - var opType int32 - switch in.Type { - case milvuspb.OperatePrivilegeType_Grant: - opType = int32(typeutil.CacheGrantPrivilege) - case milvuspb.OperatePrivilegeType_Revoke: - opType = int32(typeutil.CacheRevokePrivilege) - default: - log.Warn("invalid operate type for the OperatePrivilege api", zap.Any("in", in)) - return nil, nil - } - grants := []*milvuspb.GrantEntity{in.Entity} - - allGroups, err := c.meta.ListPrivilegeGroups(ctx) - allGroups = append(allGroups, c.initBuiltinPrivilegeGroups()...) - if err != nil { - return nil, err - } - groups := lo.SliceToMap(allGroups, func(group *milvuspb.PrivilegeGroupInfo) (string, []*milvuspb.PrivilegeEntity) { - return group.GroupName, group.Privileges - }) - expandGrants, err := c.expandPrivilegeGroups(ctx, grants, groups) - if err != nil { - return nil, err - } - // if there is same grant in the other privilege groups, the grant should not be removed from the cache - if in.Type == milvuspb.OperatePrivilegeType_Revoke { - metaGrants, err := c.meta.SelectGrant(ctx, util.DefaultTenant, &milvuspb.GrantEntity{ - Role: in.Entity.Role, - DbName: in.Entity.DbName, - }) - if err != nil { - return nil, err - } - metaExpandGrants, err := c.expandPrivilegeGroups(ctx, metaGrants, groups) - if err != nil { - return nil, err - } - expandGrants = lo.Filter(expandGrants, func(g1 *milvuspb.GrantEntity, _ int) bool { - return !lo.ContainsBy(metaExpandGrants, func(g2 *milvuspb.GrantEntity) bool { - return proto.Equal(g1, g2) - }) - }) - } - if err := c.proxyClientManager.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{ - OpType: opType, - OpKey: funcutil.PolicyForPrivileges(expandGrants), - }); err != nil { - log.Warn("fail to refresh policy info cache", zap.Any("in", in), zap.Error(err)) - return nil, err - } - return nil, nil - })) - - err := redoTask.Execute(ctx) + err := executeOperatePrivilegeTaskSteps(ctx, c, in) if err != nil { errMsg := "fail to execute task when operating the privilege" ctxLog.Warn(errMsg, zap.Error(err)) @@ -3022,25 +2861,7 @@ func (c *Core) RestoreRBAC(ctx context.Context, in *milvuspb.RestoreRBACMetaRequ return merr.Status(err), nil } - redoTask := newBaseRedoTask(c.stepExecutor) - redoTask.AddSyncStep(NewSimpleStep("restore rbac meta data", func(ctx context.Context) ([]nestedStep, error) { - if err := c.meta.RestoreRBAC(ctx, util.DefaultTenant, in.RBACMeta); err != nil { - ctxLog.Warn("fail to restore rbac meta data", zap.Any("in", in), zap.Error(err)) - return nil, err - } - return nil, nil - })) - redoTask.AddAsyncStep(NewSimpleStep("operate privilege cache", func(ctx context.Context) ([]nestedStep, error) { - if err := c.proxyClientManager.RefreshPolicyInfoCache(c.ctx, &proxypb.RefreshPolicyInfoCacheRequest{ - OpType: int32(typeutil.CacheRefresh), - }); err != nil { - ctxLog.Warn("fail to refresh policy info cache", zap.Any("in", in), zap.Error(err)) - return nil, err - } - return nil, nil - })) - - err := redoTask.Execute(ctx) + err := executeRestoreRBACTaskSteps(ctx, c, in) if err != nil { errMsg := "fail to execute task when restore rbac meta data" ctxLog.Warn(errMsg, zap.Error(err)) @@ -3265,127 +3086,7 @@ func (c *Core) OperatePrivilegeGroup(ctx context.Context, in *milvuspb.OperatePr return merr.Status(err), nil } - redoTask := newBaseRedoTask(c.stepExecutor) - redoTask.AddSyncStep(NewSimpleStep("operate privilege group", func(ctx context.Context) ([]nestedStep, error) { - groups, 
err := c.meta.ListPrivilegeGroups(ctx) - if err != nil && !common.IsIgnorableError(err) { - ctxLog.Warn("fail to list privilege groups", zap.Error(err)) - return nil, err - } - currGroups := lo.SliceToMap(groups, func(group *milvuspb.PrivilegeGroupInfo) (string, []*milvuspb.PrivilegeEntity) { - return group.GroupName, group.Privileges - }) - - // get roles granted to the group - roles, err := c.meta.GetPrivilegeGroupRoles(ctx, in.GroupName) - if err != nil { - return nil, err - } - - newGroups := make(map[string][]*milvuspb.PrivilegeEntity) - for k, v := range currGroups { - if k != in.GroupName { - newGroups[k] = v - continue - } - switch in.Type { - case milvuspb.OperatePrivilegeGroupType_AddPrivilegesToGroup: - newPrivs := lo.Union(v, in.Privileges) - newGroups[k] = lo.UniqBy(newPrivs, func(p *milvuspb.PrivilegeEntity) string { - return p.Name - }) - - // check if privileges are the same object type - objectTypes := lo.SliceToMap(newPrivs, func(p *milvuspb.PrivilegeEntity) (string, struct{}) { - return util.GetObjectType(p.Name), struct{}{} - }) - if len(objectTypes) > 1 { - return nil, errors.New("privileges are not the same object type") - } - case milvuspb.OperatePrivilegeGroupType_RemovePrivilegesFromGroup: - newPrivs, _ := lo.Difference(v, in.Privileges) - newGroups[k] = newPrivs - default: - return nil, errors.New("invalid operate type") - } - } - - rolesToRevoke := []*milvuspb.GrantEntity{} - rolesToGrant := []*milvuspb.GrantEntity{} - compareGrants := func(a, b *milvuspb.GrantEntity) bool { - return a.Role.Name == b.Role.Name && - a.Object.Name == b.Object.Name && - a.ObjectName == b.ObjectName && - a.Grantor.User.Name == b.Grantor.User.Name && - a.Grantor.Privilege.Name == b.Grantor.Privilege.Name && - a.DbName == b.DbName - } - for _, role := range roles { - grants, err := c.meta.SelectGrant(ctx, util.DefaultTenant, &milvuspb.GrantEntity{ - Role: role, - DbName: util.AnyWord, - }) - if err != nil { - return nil, err - } - currGrants, err := c.expandPrivilegeGroups(ctx, grants, currGroups) - if err != nil { - return nil, err - } - newGrants, err := c.expandPrivilegeGroups(ctx, grants, newGroups) - if err != nil { - return nil, err - } - - toRevoke := lo.Filter(currGrants, func(item *milvuspb.GrantEntity, _ int) bool { - return !lo.ContainsBy(newGrants, func(newItem *milvuspb.GrantEntity) bool { - return compareGrants(item, newItem) - }) - }) - - toGrant := lo.Filter(newGrants, func(item *milvuspb.GrantEntity, _ int) bool { - return !lo.ContainsBy(currGrants, func(currItem *milvuspb.GrantEntity) bool { - return compareGrants(item, currItem) - }) - }) - - rolesToRevoke = append(rolesToRevoke, toRevoke...) - rolesToGrant = append(rolesToGrant, toGrant...) 
- } - - if len(rolesToRevoke) > 0 { - opType := int32(typeutil.CacheRevokePrivilege) - if err := c.proxyClientManager.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{ - OpType: opType, - OpKey: funcutil.PolicyForPrivileges(rolesToRevoke), - }); err != nil { - ctxLog.Warn("fail to refresh policy info cache for revoke privileges in operate privilege group", zap.Any("in", in), zap.Error(err)) - return nil, err - } - } - - if len(rolesToGrant) > 0 { - opType := int32(typeutil.CacheGrantPrivilege) - if err := c.proxyClientManager.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{ - OpType: opType, - OpKey: funcutil.PolicyForPrivileges(rolesToGrant), - }); err != nil { - ctxLog.Warn("fail to refresh policy info cache for grants privilege in operate privilege group", zap.Any("in", in), zap.Error(err)) - return nil, err - } - } - return nil, nil - })) - - redoTask.AddSyncStep(NewSimpleStep("operate privilege group meta data", func(ctx context.Context) ([]nestedStep, error) { - err := c.meta.OperatePrivilegeGroup(ctx, in.GroupName, in.Privileges, in.Type) - if err != nil && !common.IsIgnorableError(err) { - ctxLog.Warn("fail to operate privilege group", zap.Error(err)) - } - return nil, err - })) - - err := redoTask.Execute(ctx) + err := executeOperatePrivilegeGroupTaskSteps(ctx, c, in) if err != nil { errMsg := "fail to execute task when operate privilege group" ctxLog.Warn(errMsg, zap.Error(err)) diff --git a/internal/rootcoord/task_test.go b/internal/rootcoord/task_test.go index 09f0eb5a9771d..043885eab9e58 100644 --- a/internal/rootcoord/task_test.go +++ b/internal/rootcoord/task_test.go @@ -156,9 +156,10 @@ func TestGetLockerKey(t *testing.T) { DbName: "foo", CollectionName: "bar", }, + collID: 10, } key := tt.GetLockerKey() - assert.Equal(t, GetLockerKeyString(key), "$-0-false|foo-1-false|bar-2-true") + assert.Equal(t, GetLockerKeyString(key), "$-0-false|foo-1-false|10-2-true") }) t.Run("create database task locker key", func(t *testing.T) { tt := &createDatabaseTask{ @@ -259,14 +260,26 @@ func TestGetLockerKey(t *testing.T) { assert.Equal(t, GetLockerKeyString(key), "$-0-false|foo-1-false|111-2-true") }) t.Run("drop collection task locker key", func(t *testing.T) { + metaMock := mockrootcoord.NewIMetaTable(t) + metaMock.EXPECT().GetCollectionByName(mock.Anything, mock.Anything, mock.Anything, mock.Anything). + RunAndReturn(func(ctx context.Context, s string, s2 string, u uint64) (*model.Collection, error) { + return &model.Collection{ + Name: "bar", + CollectionID: 111, + }, nil + }) + c := &Core{ + meta: metaMock, + } tt := &dropCollectionTask{ + baseTask: baseTask{core: c}, Req: &milvuspb.DropCollectionRequest{ DbName: "foo", CollectionName: "bar", }, } key := tt.GetLockerKey() - assert.Equal(t, GetLockerKeyString(key), "$-0-false|foo-1-false|bar-2-true") + assert.Equal(t, GetLockerKeyString(key), "$-0-false|foo-1-false|111-2-true") }) t.Run("drop database task locker key", func(t *testing.T) { tt := &dropDatabaseTask{